Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/storemysql.py

Revision 232 (checked in by fumanchu, 7 years ago)

Work on MySQL, SQLite stores in preparation for moving the typeadapters into the dbmodel layer.

  • Property svn:eol-style set to native
Line 
1 """
2 References the MySQLdb package at:
3 http://sourceforge.net/projects/mysql-python
4
5 From the MySQL manual:
6
7 "If the server SQL mode has ANSI_QUOTES enabled, string literals can be
8 quoted only with single quotes. A string quoted with double quotes will be
9 interpreted as an identifier."
10
11 So use single quotes throughout.
12 """
13
14 # Use _mysql directly to avoid all of the DB-API overhead.
15 import _mysql
16 import warnings
17 import datetime
18
19 import dejavu
20 from dejavu import storage, logic
21 from dejavu.storage import db
22
23
class AdapterToMySQL(db.AdapterToSQL):
    """Coerce Python values into MySQL SQL literal text."""
    
    # (text, replacement) pairs for escaping string content. MySQL treats
    # backslash as an escape character by default, so it is escaped as well
    # as the single quote.
    escapes = [("'", "''"), ("\\", r"\\")]
    # LIKE wildcards that must be escaped to be matched literally.
    like_escapes = [("%", r"\%"), ("_", r"\_")]
    
    # TRUE and FALSE only work with 4.1 or better, so use 1 and 0.
    bool_true = "1"
    bool_false = "0"
    
    def coerce_str(self, value, skip_encoding=False):
        """Return value as a single-quoted, escaped MySQL string literal.
        
        Values that are not already str (e.g. unicode) are first encoded
        with self.encoding, unless skip_encoding is True.
        """
        if not skip_encoding and not isinstance(value, str):
            value = value.encode(self.encoding)
        return "'" + _mysql.escape_string(value) + "'"
    
    def coerce_bool(self, value):
        """Return the SQL literal for the truthiness of value."""
        # Reuse the class constants so the literal spelling lives in one
        # place (TRUE/FALSE only work with 4.1 or better).
        if value:
            return self.bool_true
        return self.bool_false
43
44
class AdapterFromMySQL(db.AdapterFromDB):
    """Coerce raw MySQL result values back into Python values."""
    
    def coerce_bool(self, value, coltype):
        """Return a BOOL column value as a Python bool.
        
        _mysql may hand the value back as a string such as '0' or '1'.
        MySQL's BOOL is TINYINT(1), and MySQL treats every nonzero value
        as true, so any string holding a nonzero integer maps to True.
        '0' and non-numeric strings map to False. Non-string values fall
        back to bool().
        """
        if isinstance(value, basestring):
            try:
                return int(value) != 0
            except ValueError:
                # Not an integer string, so it cannot be a true TINYINT.
                return False
        return bool(value)
52
53
class MySQLDecompiler(db.SQLDecompiler):
    """SQL decompiler with MySQL-specific function translations."""
    
    def dejavu_today(self):
        """Translate dejavu's today() into MySQL's current-date function."""
        current_date_sql = "CURDATE()"
        return current_date_sql
58
59
class MySQLDecompiler411(MySQLDecompiler):
    """Decompiler for newer MySQL servers.
    
    MySQLDatabase.__init__ selects this class based on the server version.
    
    Before MySQL 4.1.1, BINARY comparisons could use UPPER() or LOWER() to
    perform case-insensitive comparisons. Newer versions must instead use
    CONVERT(... USING utf8). This turns the binary value into a nonbinary
    (character) string, which is compared with the character set's default
    (case-insensitive) collation.
    """
    
    def dejavu_icontainedby(self, op1, op2):
        """Case-insensitive 'op1 in op2'.
        
        If op1 is a constant, this is a substring search of the field op2.
        Otherwise it is a membership test of the field op1 against
        op2.basevalue.
        """
        if isinstance(op1, db.ConstWrapper):
            # Looking for text in a field. Use LIKE (reverse terms).
            return ("CONVERT(" + op2 + " USING utf8) LIKE '%" +
                    self.adapter.escape_like(op1) + "%'")
        
        # Looking for field in (a, b, c).
        atoms = [self.adapter.coerce(x) for x in op2.basevalue]
        if not atoms:
            # "x IN ()" is a syntax error in MySQL. Nothing is a member of
            # an empty set, so emit an always-false condition instead.
            return "0"
        return "CONVERT(%s USING utf8) IN (%s)" % (op1, ", ".join(atoms))
    
    def dejavu_istartswith(self, x, y):
        """Case-insensitive 'x starts with y' as a LIKE prefix match."""
        return ("CONVERT(" + x + " USING utf8) LIKE '" +
                self.adapter.escape_like(y) + "%'")
    
    def dejavu_iendswith(self, x, y):
        """Case-insensitive 'x ends with y' as a LIKE suffix match."""
        return ("CONVERT(" + x + " USING utf8) LIKE '%" +
                self.adapter.escape_like(y) + "'")
    
    def dejavu_ieq(self, x, y):
        """Case-insensitive equality of x and y."""
        return "CONVERT(" + x + " USING utf8) = " + y
86
87
class FieldTypeAdapterMySQL(db.FieldTypeAdapter):
    """Return the SQL typename of a DB column."""
    
    float_max_precision = 53
    numeric_max_precision = 16
    numeric_max_bytes = 8
    
    def float_type(self, precision):
        """Return a datatype which can handle the given precision."""
        # "p represents the precision in bits, but MySQL uses this value
        # only to determine whether to use FLOAT or DOUBLE for the
        # resulting data type. If p is from 0 to 24, the data type
        # becomes FLOAT with no M or D values. If p is from 25 to 53,
        # the data type becomes DOUBLE with no M or D values."
        return "FLOAT(%s)" % precision
    
    def coerce_str(self, cls, key):
        """Return a binary string type sized by the property's 'bytes' hint.
        
        A missing or zero 'bytes' hint yields LONGBLOB.
        """
        prop = getattr(cls, key)
        nbytes = int(prop.hints.get('bytes', '0'))
        if nbytes:
            # MySQL VARBINARY/BLOBs will do case-sensitive comparisons.
            # They also won't truncate trailing spaces like VARCHAR does.
            if nbytes <= 255:
                return "VARBINARY(%s)" % nbytes
            elif nbytes < 2 ** 16:
                # BLOB(M) syntax requires MySQL 4.1 or later.
                if self.db._version >= storage.Version("4.1"):
                    return "BLOB(%s)" % nbytes
                return "BLOB"
            elif nbytes < 2 ** 24:
                return "MEDIUMBLOB"
        return "LONGBLOB"
    
    def coerce_bool(self, cls, key):
        """Return the column type for a bool property."""
        # We could use BOOLEAN, but it wasn't introduced until 4.1.0.
        return "BOOL"
    
    def coerce_datetime_datetime(self, cls, key):
        """Return the column type for a datetime.datetime property."""
        return "DATETIME"
    
    def coerce_int(self, cls, key):
        """Return the smallest MySQL integer type for the 'bytes' hint.
        
        The 'bytes' hint defaults to 4 (INTEGER).
        """
        prop = getattr(cls, key)
        nbytes = int(prop.hints.get('bytes', '4'))
        if nbytes == 1:
            # BOOLEAN is just a 4.1.0+ synonym for TINYINT(1) (see
            # coerce_bool), so spell it out to support older servers too.
            return "TINYINT(1)"
        elif nbytes == 2:
            return "SMALLINT"
        elif nbytes == 3:
            return "MEDIUMINT"
        elif nbytes <= 4:
            return "INTEGER"
        # Values wider than 4 bytes would overflow INTEGER.
        return "BIGINT"
140
141
142
class MySQLIndexSet(db.IndexSet):
    """IndexSet whose deletions drop the matching MySQL indices."""
    
    def __delitem__(self, key):
        """Drop every DB index on key's column, then forget the index.
        
        Raises KeyError (before touching the DB) if key is not present.
        """
        t = self.table
        colname = self[key].colname
        # MySQL might rename multiple-column indices to "PRIMARY", so the
        # stored index name can't be trusted. Match on the column instead.
        for i in t.db._get_indices(t.name):
            if i.colname == colname:
                t.db.execute('DROP INDEX %s ON %s;' % (i.qname, t.qname))
        # Keep the in-memory mapping in sync with the database.
        dict.__delitem__(self, key)
151
152
class MySQLColumnSet(db.ColumnSet):
    """ColumnSet that applies additions and renames with ALTER TABLE."""
    
    indexsetclass = MySQLIndexSet
    
    def _column_definition(self, column):
        """Return column's type plus its AUTO_INCREMENT or DEFAULT clause."""
        dbtype = column.dbtype
        if column.autoincrement:
            return dbtype + " AUTO_INCREMENT"
        default = column.default
        # None and the empty string mean "no default". Other falsy values
        # such as 0 and False are real defaults and must be emitted.
        if default is not None and default != "":
            dbtype = "%s DEFAULT %s" % (dbtype,
                                        self.adaptertosql.coerce(default))
        return dbtype
    
    def __setitem__(self, key, column):
        """Add column to the table via ALTER TABLE ... ADD COLUMN."""
        t = self.table
        t.db.execute("ALTER TABLE %s ADD COLUMN %s %s;" %
                     (t.qname, column.qname, self._column_definition(column)))
        dict.__setitem__(self, key, column)
    
    def _rename(self, oldcol, newcol):
        """Rename oldcol to newcol at the DB level (ALTER TABLE ... CHANGE)."""
        # CHANGE takes a full column definition. Attributes not restated
        # (AUTO_INCREMENT, DEFAULT) are not carried over, so repeat them.
        t = self.table
        t.db.execute("ALTER TABLE %s CHANGE %s %s %s;" %
                     (t.qname, oldcol.qname, newcol.qname,
                      self._column_definition(oldcol)))
178
179
180 class MySQLDatabase(db.Database):
181    
182     sql_name_max_length = 64
183     # MySQL uses case-sensitive database and table names on Unix, but
184     # not on Windows. Use all-lowercase identifiers to work around the
185     # problem. "Column names, index names, and column aliases are not
186     # case sensitive on any platform."
187     # If deployers set lower_case_table_names to 1, it would help.
188     sql_name_caseless = True
189    
190     adaptertosql = AdapterToMySQL()
191     adapterfromdb = AdapterFromMySQL()
192    
193     columnsetclass = MySQLColumnSet
194     indexsetclass = MySQLIndexSet
195    
196     def __init__(self, name, **kwargs):
197         db.Database.__init__(self, name, **kwargs)
198        
199         self.decompiler = MySQLDecompiler
200        
201         # Get the version string from MySQL, to see if we need
202         # a different decompiler.
203         conn = self._template_conn()
204         rowdata, cols = self.fetch("SELECT version();", conn)
205         conn.close()
206         v = rowdata[0][0]
207         self._version = storage.Version(v)
208         if self._version > storage.Version("4.1.1"):
209             self.decompiler = MySQLDecompiler411
210    
211     def version(self):
212         return "MySQL Version: %s" % self._version
213    
214     def __setitem__(self, key, table):
215         q = self.quote
216        
217         fields = []
218         incr_fields = []
219         pk = []
220         for colname, col in table.columns.iteritems():
221             qname = col.qname
222             dbtype = col.dbtype
223            
224             if col.autoincrement:
225                 dbtype += " AUTO_INCREMENT"
226                 # INSERT INTO t (c) VALUES(0) doesn't work for some reason
227                 if col.default > 1:
228                     incr_fields.append(col)
229             else:
230                 default = col.default or ""
231                 if default:
232                     default = self.adaptertosql.coerce(default)
233                     dbtype = "%s DEFAULT %s" % (dbtype, default)
234            
235             fields.append('%s %s' % (qname, dbtype))
236            
237             # See create_storage for the other half of this hack.
238             if col.name in table.mysql_identifiers:
239                 if dbtype.endswith('BLOB') or dbtype == 'TEXT':
240                     # MySQL won't allow indexes on a BLOB field without a
241                     # specific index prefix length. We choose 255 just for fun.
242                     qname = "%s(255)" % qname
243                 pk.append(qname)
244        
245         if pk:
246             pk = ", PRIMARY KEY (%s)" % ", ".join(pk)
247        
248         self.execute('CREATE TABLE %s (%s%s);' %
249                      (table.qname, ", ".join(fields), pk))
250        
251         if incr_fields:
252             # Wow, what a hack. We have to INSERT a dummy row to set the
253             # autoincrement initial value(s), and we can't delete it until
254             # after the CREATE INDEX statements (or the counter will revert).
255             fields = ", ".join([col.qname for col in incr_fields])
256             values = ", ".join([str(col.default - 1) for col in incr_fields])
257             self.execute("INSERT INTO %s (%s) VALUES (%s);"
258                          % (table.qname, fields, values))
259        
260         for k, index in table.columns.indices.iteritems():
261             dbtype = table.columns[k].dbtype
262             if dbtype.endswith('BLOB') or dbtype == 'TEXT':
263                 # MySQL won't allow indexes on a BLOB field without a
264                 # specific index prefix length. We choose 255 just for fun.
265                 self.execute('CREATE INDEX %s ON %s (%s(255));' %
266                              (index.qname, table.qname, q(index.colname)))
267             else:
268                 self.execute('CREATE INDEX %s ON %s (%s);' %
269                              (index.qname, table.qname, q(index.colname)))
270        
271         if incr_fields:
272             self.execute("DELETE FROM %s" % table.qname)
273        
274         dict.__setitem__(self, key, table)
275    
276     def _get_tables(self, conn=None):
277         data, _ = self.fetch("SHOW TABLES FROM %s" % self.qname, conn=conn)
278         return [db.Table(self, row[0], self.quote(row[0])) for row in data]
279    
280     def _get_columns(self, tablename, conn=None):
281         # cols are: Field, Type, Null, Key, Default, Extra.
282         # See http://dev.mysql.com/doc/refman/4.1/en/describe.html
283         data, _ = self.fetch("SHOW COLUMNS FROM %s.%s" %
284                              (self.qname, self.quote(tablename)), conn=conn)
285         cols = []
286         for row in data:
287             c = db.Column(row[0], self.quote(row[0]), None, None)
288            
289             dbtype = row[1].upper()
290             parenpos = dbtype.find("(")
291             if parenpos > -1:
292                 args = dbtype[parenpos+1:-1]
293                 baretype = dbtype[:parenpos]
294                 if baretype in ("DECIMAL", "NUMERIC"):
295                     args = [x.strip() for x in args.split(",")]
296                     c.hints['precision'], c.hints['scale'] = args
297                 else:
298                     c.hints['bytes'] = args
299             elif dbtype == "FLOAT":
300                 c.hints['precision'] = 24
301             elif dbtype.startswith("DOUBLE"):
302                 c.hints['precision'] = 53
303             elif dbtype in ("TINYBLOB", "TINYTEXT"):
304                 c.hints['bytes'] = (2 ** 8) - 1
305             elif dbtype in ("BLOB", "TEXT"):
306                 c.hints['bytes'] = (2 ** 16) - 1
307             elif dbtype in ("MEDIUMBLOB", "MEDIUMTEXT"):
308                 c.hints['bytes'] = (2 ** 24) - 1
309             elif dbtype in ("LONGBLOB", "LONGTEXT"):
310                 c.hints['bytes'] = (2 ** 32) - 1
311            
312             if row[4]:
313                 c.default = self.python_type(dbtype)(row[4])
314            
315             if "auto_increment" in row[5].lower():
316                 c.autoincrement = True
317            
318             c.dbtype = dbtype
319            
320             cols.append(c)
321         return cols
322    
323     def _get_indices(self, tablename, conn=None):
324         indices = []
325         try:
326             # cols are: Table, Non_unique, Key_name, Seq_in_index, Column_name,
327             # Collation, Cardinality, Sub_part, Packed, Null, Index_type, Comment
328             data, _ = self.fetch("SHOW INDEX FROM %s.%s"
329                                  % (self.qname, self.quote(tablename)),
330                                  conn=conn)
331         except _mysql.ProgrammingError, x:
332             if x.args[0] != 1146:
333                 raise
334         else:
335             for row in data:
336                 i = db.Index(row[2], self.quote(row[2]),
337                              row[0], row[4], None, not row[1])
338                 indices.append(i)
339         return indices
340    
341     def python_type(self, dbtype):
342         """Return a Python type which can store values of the given dbtype."""
343         dbtype = dbtype.upper()
344         parenpos = dbtype.find("(")
345         if parenpos > -1:
346             dbtype = dbtype[:parenpos]
347        
348         if dbtype in ('TINYINT', 'SMALLINT', 'MEDIUMINT', 'INT', 'INTEGER'):
349             return int
350         elif dbtype == 'BIGINT':
351             return long
352         elif dbtype in ('FLOAT', 'DOUBLE', 'DOUBLE PRECISION', 'REAL'):
353             return float
354         elif dbtype in ('DECIMAL', 'NUMERIC'):
355             if db.decimal:
356                 return db.decimal.Decimal
357             elif db.fixedpoint:
358                 return db.fixedpoint.Fixedpoint
359         elif dbtype == 'DATE':
360             return datetime.date
361         elif dbtype in ('DATETIME', 'TIMESTAMP'):
362             return datetime.datetime
363         elif dbtype == 'TIME':
364             return datetime.time
365         elif dbtype in ('CHAR', 'VARCHAR', 'BINARY', 'VARBINARY',
366                         'TINYBLOB', 'TINYTEXT', 'BLOB', 'TEXT',
367                         'MEDIUMBLOB', 'MEDIUMTEXT', 'LONGBLOB', 'LONGTEXT'):
368             return str
369        
370         raise TypeError("Database type %s could not be converted "
371                         "to a Python type." % repr(dbtype))
372    
373     def quote(self, name):
374         """Return name, quoted for use in an SQL statement."""
375         return '`' + name.replace('`', '``') + '`'
376    
377     def _get_conn(self):
378         try:
379             conn = _mysql.connect(**self.connargs)
380         except _mysql.OperationalError, x:
381             if x.args[0] == 1040:   # Too many connections
382                 raise db.OutOfConnectionsError
383             raise
384         return conn
385    
386     def _template_conn(self):
387         tmplconn = self.connargs.copy()
388         tmplconn['db'] = 'mysql'
389         return _mysql.connect(**tmplconn)
390    
391     def fetch(self, query, conn=None):
392         """fetch(query, conn=None) -> rowdata, columns.
393         
394         rowdata: a nested list (or tuples), column values within rows.
395         columns: a series of 2-tuples (or more). The first tuple value
396             will be the column name, the second value will be the column
397             type.
398         """
399         if conn is None:
400             conn = self.connection()
401         self.execute(query, conn)
402         # store_result uses a client-side cursor
403         res = conn.store_result()
404         return res.fetch_row(0, 0), res.describe()
405    
406     def create_database(self):
407         # _mysql has create_db and drop_db commands, but they're deprecated.
408         sql = 'CREATE DATABASE %s;' % self.qname
409         conn = self._template_conn()
410         self.execute(sql, conn)
411         conn.close()
412    
413     def drop_database(self):
414         sql = 'DROP DATABASE %s;' % self.qname
415         conn = self._template_conn()
416         self.execute(sql, conn)
417         conn.close()
418
419
420
class StorageManagerMySQL(db.StorageManagerDB):
    """StoreManager to save and retrieve Units via _mysql."""
    
    typeAdapter = FieldTypeAdapterMySQL()
    databaseclass = MySQLDatabase
    
    def __init__(self, name, arena, allOptions=None):
        """Create a MySQL store named 'name' for the given arena.
        
        allOptions must include 'db', the MySQL database name, which also
        becomes the 'name' option. Recognized connection keywords (host,
        user, passwd, port, ...) are collected into allOptions['connargs'].
        The caller's dict is copied, not modified.
        """
        import copy
        
        # Work on a copy so neither the caller's dict nor a shared default
        # is mutated by the 'connargs'/'name' assignments below.
        allOptions = dict(allOptions or {})
        connargnames = ("host", "user", "passwd", "db", "port", "unix_socket",
                        "conv", "connect_time", "compress", "named_pipe",
                        "init_command", "read_default_file",
                        "read_default_group", "cursorclass", "client_flag",
                        )
        connargs = dict([(k, v) for k, v in allOptions.iteritems()
                         if k in connargnames])
        allOptions['connargs'] = connargs
        allOptions['name'] = connargs['db']
        
        # The class-level typeAdapter is shared by every instance, but it is
        # bound to this store's database below (coerce_str reads
        # db._version). Give each store its own copy so stores don't clobber
        # each other.
        self.typeAdapter = copy.copy(self.typeAdapter)
        db.StorageManagerDB.__init__(self, name, arena, allOptions)
        self.typeAdapter.db = self.db
    
    def destroy(self, unit):
        """destroy(unit). Delete the unit."""
        t = self.db[unit.__class__.__name__].qname
        self.db.execute('DELETE FROM %s WHERE %s;' % (t, self.id_clause(unit)))
    
    def _seq_UnitSequencerInteger(self, unit):
        """Reserve a unit using the table's AUTO_INCREMENT field."""
        cls = unit.__class__
        t = self.db[cls.__name__]
        
        fields = []
        values = []
        for key in cls.properties:
            col = t.columns[key]
            if col.autoincrement:
                # Skip this field, since we're using a sequencer
                continue
            val = self.db.adaptertosql.coerce(getattr(unit, key))
            fields.append(col.qname)
            values.append(val)
        
        fields = ", ".join(fields)
        values = ", ".join(values)
        
        conn = self.db.connection()
        self.db.execute('INSERT INTO %s (%s) VALUES (%s);' %
                        (t.qname, fields, values), conn)
        
        # Grab the new ID. This is threadsafe because db.reserve has a mutex.
        setattr(unit, cls.identifiers[0], conn.insert_id())
    
    #                               Schemas                               #
    
    def create_storage(self, cls):
        """Create storage (a table plus its indices) for the given class."""
        # Make a Table object.
        t = self.db.make_table(cls.__name__)
        
        indices = cls.indices()
        for key in cls.properties:
            # Use the superclass call to avoid ALTER TABLE.
            dict.__setitem__(t.columns, key, self._make_column(cls, key))
            
            if key in indices:
                i = self.db.make_index(cls.__name__, key)
                # Use the superclass call to avoid CREATE INDEX.
                dict.__setitem__(t.columns.indices, key, i)
        
        # Hack to get PRIMARY KEY right: MySQLDatabase.__setitem__ reads
        # mysql_identifiers when it builds the CREATE TABLE statement.
        t.mysql_identifiers = [t.columns[i].name for i in cls.identifiers]
        
        # Attach to self.db, which issues CREATE TABLE.
        self.db[cls.__name__] = t
    
    def _make_column(self, cls, key):
        """Build a Column for property key of cls.
        
        Identifier columns of classes using UnitSequencerInteger become
        AUTO_INCREMENT columns whose default is the sequencer's initial
        value.
        """
        dbtype = self.typeAdapter.coerce(cls, key)
        col = self.db.make_column(cls.__name__, key, dbtype)
        prop = getattr(cls, key)
        col.default = prop.default
        col.hints = prop.hints.copy()
        
        if key in cls.identifiers:
            if isinstance(cls.sequencer, dejavu.UnitSequencerInteger):
                col.autoincrement = True
                col.default = cls.sequencer.initial
        return col
Note: See TracBrowser for help on using the browser.