Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets.

I think I've seen this ORM somewhere before...

root/trunk/storage/storemysql.py

Revision 303 (checked in by fumanchu, 7 years ago)

Initial fix for #4 (transaction support). Tests pass, but this is not to be used in production yet!

  • Property svn:eol-style set to native
Line 
1 """
2 References the MySQLdb package at:
3 http://sourceforge.net/projects/mysql-python
4
5 From the MySQL manual:
6
7 "If the server SQL mode has ANSI_QUOTES enabled, string literals can be
8 quoted only with single quotes. A string quoted with double quotes will be
9 interpreted as an identifier."
10
11 So use single quotes throughout.
12 """
13
14 # Use _mysql directly to avoid all of the DB-API overhead.
15 import _mysql
16 import warnings
17 import datetime
18
19 import dejavu
20 from dejavu import storage, logic
21 from dejavu.storage import db
22
23
class AdapterToMySQL(db.AdapterToSQL):
    """Adapter which renders Python values as MySQL SQL literals."""

    # (search, replacement) pairs for string values; presumably applied in
    # order by db.AdapterToSQL when escaping -- TODO confirm.
    escapes = [("'", "''"), ("\\", r"\\")]
    # Additional (search, replacement) pairs for LIKE patterns, in which
    # % and _ are wildcards.
    like_escapes = [("%", r"\%"), ("_", r"\_")]

    # TRUE and FALSE only work with 4.1 or better.
    bool_true = "1"
    bool_false = "0"

    def coerce_str_to_any(self, value, skip_encoding=False):
        """Return `value` as a single-quoted MySQL string literal.

        Anything that is not already a `str` is first encoded with
        self.encoding, unless `skip_encoding` is true. The contents are
        escaped with _mysql.escape_string.
        """
        if skip_encoding or isinstance(value, str):
            raw = value
        else:
            raw = value.encode(self.encoding)
        return "".join(["'", _mysql.escape_string(raw), "'"])

    def coerce_bool_to_any(self, value):
        """Return the SQL literal '1' for a true value, '0' otherwise.

        Numeric literals are used because TRUE and FALSE only work with
        MySQL 4.1 or better.
        """
        if not value:
            return '0'
        return '1'
43
44
class AdapterFromMySQL(db.AdapterFromDB):
    """Adapter which converts values fetched from MySQL into Python values."""

    def coerce_any_to_bool(self, value):
        """Return the boolean meaning of a value read from a BOOL column.

        In MySQL, BOOL is a synonym for TINYINT(1): zero is false and any
        nonzero integer is true. The raw _mysql API may hand values back
        as strings (normally '0' or '1'), so integer strings are parsed
        before testing. This way values such as '2' or '-1' written by
        other clients are also treated as true. Strings that are not
        integers keep the previous rule: only '1' is true.
        """
        if isinstance(value, basestring):
            try:
                value = int(value)
            except ValueError:
                value = (value == '1')
        return bool(value)
52
53
class MySQLDecompiler(db.SQLDecompiler):
    """SQL decompiler with MySQL-specific function translations."""

    def dejavu_today(self):
        """Return the SQL expression for the current date: CURDATE()."""
        current_date_sql = "CURDATE()"
        return current_date_sql
58
59
class MySQLDecompiler411(MySQLDecompiler):
    """Decompiler for MySQL 4.1.1 and later.

    Before MySQL 4.1.1, case-insensitive comparisons on BINARY strings
    could use UPPER() or LOWER(). Newer versions ignore those functions
    on binary strings, so operands are CONVERTed to nonbinary utf8
    strings, whose default collation compares case-insensitively.
    """

    def dejavu_icontainedby(self, op1, op2):
        """Case-insensitive containment test.

        If op1 is a constant, test whether the text op1 occurs inside the
        field op2 (LIKE '%...%'). Otherwise test whether the field op1 is
        one of the constants in op2.basevalue (IN (...)).
        """
        if isinstance(op1, db.ConstWrapper):
            # Looking for text in a field. Use LIKE (reverse terms).
            return ("CONVERT(" + op2 + " USING utf8) LIKE '%" +
                    self.adapter.escape_like(op1) + "%'")

        # Looking for field in (a, b, c).
        atoms = [self.adapter.coerce(x) for x in op2.basevalue]
        if not atoms:
            # "x IN ()" is a syntax error in MySQL. Nothing is contained
            # in an empty set, so emit a false condition instead.
            return "0"
        return "CONVERT(%s USING utf8) IN (%s)" % (op1, ", ".join(atoms))

    def dejavu_istartswith(self, x, y):
        """Case-insensitive 'x starts with the constant y' (LIKE 'y%')."""
        return ("CONVERT(" + x + " USING utf8) LIKE '" +
                self.adapter.escape_like(y) + "%'")

    def dejavu_iendswith(self, x, y):
        """Case-insensitive 'x ends with the constant y' (LIKE '%y')."""
        return ("CONVERT(" + x + " USING utf8) LIKE '%" +
                self.adapter.escape_like(y) + "'")

    def dejavu_ieq(self, x, y):
        """Case-insensitive equality of x and y.

        Both operands are converted. If either side stayed a binary
        string (for example a VARBINARY column on the right-hand side),
        MySQL would compare in binary, i.e. case-sensitively.
        """
        return ("CONVERT(" + x + " USING utf8) = CONVERT(" + y +
                " USING utf8)")
86
87
class TypeAdapterMySQL(db.TypeAdapter):
    """Choose MySQL column types for dejavu columns."""

    numeric_max_precision = 16
    numeric_max_bytes = 8

    def float_type(self, precision):
        """Return a datatype which can handle the given precision."""
        # "p represents the precision in bits, but MySQL uses this value
        # only to determine whether to use FLOAT or DOUBLE for the
        # resulting data type. If p is from 0 to 24, the data type
        # becomes FLOAT with no M or D values. If p is from 25 to 53,
        # the data type becomes DOUBLE with no M or D values."
        return "FLOAT(%s)" % precision

    def coerce_str(self, col):
        """Return a binary string type sized by col.hints['bytes'].

        The hint defaults to 255. A hint of 0 selects LONGBLOB.
        """
        nbytes = int(col.hints.get('bytes', 255))

        if nbytes:
            # MySQL VARBINARY/BLOBs will do case-sensitive comparisons.
            # They also won't truncate trailing spaces like VARCHAR does.
            if nbytes <= 255:
                return "VARBINARY(%s)" % nbytes
            elif nbytes < 2 ** 16:
                return "BLOB"
            elif nbytes < 2 ** 24:
                return "MEDIUMBLOB"
        return "LONGBLOB"

    def coerce_bool(self, col):
        # We could use BOOLEAN, but it wasn't introduced until 4.1.0.
        return "BOOL"

    def coerce_datetime_datetime(self, col):
        return "DATETIME"

    def coerce_int(self, col):
        """Return the integer type for col.hints['bytes'] (default 4).

        Hints of 2 or fewer bytes map to SMALLINT, 3 to MEDIUMINT, 4 to
        INTEGER, and anything wider to BIGINT.
        """
        nbytes = int(col.hints.get('bytes', '4'))
        if nbytes <= 2:
            return "SMALLINT"
        elif nbytes == 3:
            return "MEDIUMINT"
        elif nbytes == 4:
            return "INTEGER"
        # INTEGER only stores 4 bytes; wider values need 8-byte BIGINT
        # or they would be truncated.
        return "BIGINT"
130
131
class TypeAdapterMySQL41(TypeAdapterMySQL):
    """Type adapter for MySQL 4.1+, where BLOB accepts a length argument."""

    def coerce_str(self, col):
        """Return TypeAdapterMySQL's type, but with an explicit BLOB length.

        Only a plain "BLOB" result is changed, to "BLOB(n)", where n is
        col.hints['bytes'] verbatim. Every other type passes through.
        """
        basetype = TypeAdapterMySQL.coerce_str(self, col)
        if basetype != "BLOB":
            return basetype
        return "BLOB(%s)" % col.hints['bytes']
139
140
class MySQLIndexSet(db.IndexSet):
    """IndexSet which drops indices with MySQL's DROP INDEX ... ON syntax."""

    def __delitem__(self, key):
        """Drop, in the database, every index on the column indexed by `key`.

        The database is locked for the duration (no transactions allowed).
        """
        table = self.table
        table.db.lock("Dropping index. Transactions not allowed.")
        try:
            # MySQL might rename multiple-column indices to "PRIMARY", so
            # match on the indexed column rather than on the index name.
            doomed = [idx for idx in table.db._get_indices(table.name)
                      if idx.colname == self[key].colname]
            for idx in doomed:
                table.db.execute('DROP INDEX %s ON %s;'
                                 % (idx.qname, table.qname))
        finally:
            table.db.unlock()
        # NOTE(review): unlike MySQLColumnSet.__setitem__, this does not
        # update the underlying dict (no dict.__delitem__); confirm whether
        # the caller or db.IndexSet removes the entry.
153
154
class MySQLColumnSet(db.ColumnSet):
    """ColumnSet which adds and renames columns via MySQL ALTER TABLE."""

    indexsetclass = MySQLIndexSet

    def _column_definition(self, column):
        """Return column.dbtype plus its AUTO_INCREMENT or DEFAULT clause."""
        dbtype = column.dbtype
        if column.autoincrement:
            return dbtype + " AUTO_INCREMENT"
        default = column.default
        # Only None and "" mean "no default"; falsy values such as 0 and
        # False are real defaults and must be emitted.
        if default is None or default == "":
            return dbtype
        # The database's adapter renders the default as an SQL literal.
        default = self.table.db.adaptertosql.coerce(default, dbtype)
        return "%s DEFAULT %s" % (dbtype, default)

    def __setitem__(self, key, column):
        """Add `column` to the table in the database, then store it here."""
        t = self.table
        coldef = self._column_definition(column)

        t.db.lock("Adding property. Transactions not allowed.")
        try:
            t.db.execute("ALTER TABLE %s ADD COLUMN %s %s;" %
                         (t.qname, column.qname, coldef))
            dict.__setitem__(self, key, column)
        finally:
            t.db.unlock()

    def _rename(self, oldcol, newcol):
        """Rename oldcol to newcol at the DB level, keeping its definition."""
        # MySQL's CHANGE clause replaces the entire column definition:
        # attributes not restated (AUTO_INCREMENT, DEFAULT) are dropped.
        t = self.table
        t.db.execute("ALTER TABLE %s CHANGE %s %s %s;" %
                     (t.qname, oldcol.qname, newcol.qname,
                      self._column_definition(oldcol)))
184
185
186 class MySQLDatabase(db.Database):
187    
188     sql_name_max_length = 64
189     # MySQL uses case-sensitive database and table names on Unix, but
190     # not on Windows. Use all-lowercase identifiers to work around the
191     # problem. "Column names, index names, and column aliases are not
192     # case sensitive on any platform."
193     # If deployers set lower_case_table_names to 1, it would help.
194     sql_name_caseless = True
195     encoding = "utf8"
196    
197     adaptertosql = AdapterToMySQL()
198     adapterfromdb = AdapterFromMySQL()
199     typeadapter = TypeAdapterMySQL()
200    
201     columnsetclass = MySQLColumnSet
202     indexsetclass = MySQLIndexSet
203    
204     def __init__(self, name, **kwargs):
205         db.Database.__init__(self, name, **kwargs)
206        
207         self.decompiler = MySQLDecompiler
208        
209         # Get the version string from MySQL, to see if we need
210         # a different decompiler.
211         conn = self._template_conn()
212         rowdata, cols = self.fetch("SELECT version();", conn)
213         conn.close()
214         v = rowdata[0][0]
215         self._version = storage.Version(v)
216        
217         # decompiler
218         if self._version > storage.Version("4.1.1"):
219             self.decompiler = MySQLDecompiler411
220        
221         # type adapter
222         if self._version >= storage.Version("4.1"):
223             self.typeadapter = TypeAdapterMySQL41()
224    
225     def version(self):
226         return "MySQL Version: %s" % self._version
227    
228     def __setitem__(self, key, table):
229         q = self.quote
230        
231         fields = []
232         incr_fields = []
233         pk = []
234         for colkey, col in table.columns.iteritems():
235             qname = col.qname
236             dbtype = col.dbtype
237            
238             if col.autoincrement:
239                 dbtype += " AUTO_INCREMENT"
240                 # INSERT INTO t (c) VALUES(0) doesn't work for some reason
241                 if col.default > 1:
242                     incr_fields.append(col)
243             else:
244                 default = col.default or ""
245                 if default:
246                     default = self.adaptertosql.coerce(default, dbtype)
247                     dbtype = "%s DEFAULT %s" % (dbtype, default)
248            
249             fields.append('%s %s' % (qname, dbtype))
250            
251             if col.key:
252                 if dbtype.endswith('BLOB') or dbtype == 'TEXT':
253                     # MySQL won't allow indexes on a BLOB field without a
254                     # specific index prefix length. We choose 255 just for fun.
255                     qname = "%s(255)" % qname
256                 pk.append(qname)
257        
258         if pk:
259             pk = ", PRIMARY KEY (%s)" % ", ".join(pk)
260         else:
261             pk = ""
262        
263         self.lock("Creating storage. Transactions not allowed.")
264         try:
265             self.execute('CREATE TABLE %s (%s%s);' %
266                          (table.qname, ", ".join(fields), pk))
267            
268             if incr_fields:
269                 # Wow, what a hack. We have to INSERT a dummy row to set the
270                 # autoincrement initial value(s), and we can't delete it until
271                 # after the CREATE INDEX statements (or the counter will revert).
272                 fields = ", ".join([col.qname for col in incr_fields])
273                 values = ", ".join([str(col.default - 1) for col in incr_fields])
274                 self.execute("INSERT INTO %s (%s) VALUES (%s);"
275                              % (table.qname, fields, values))
276            
277             for k, index in table.columns.indices.iteritems():
278                 dbtype = table.columns[k].dbtype
279                 if dbtype.endswith('BLOB') or dbtype == 'TEXT':
280                     # MySQL won't allow indexes on a BLOB field without a
281                     # specific index prefix length. We choose 255 just for fun.
282                     self.execute('CREATE INDEX %s ON %s (%s(255));' %
283                                  (index.qname, table.qname, q(index.colname)))
284                 else:
285                     self.execute('CREATE INDEX %s ON %s (%s);' %
286                                  (index.qname, table.qname, q(index.colname)))
287            
288             if incr_fields:
289                 self.execute("DELETE FROM %s" % table.qname)
290         finally:
291             self.unlock()
292        
293         dict.__setitem__(self, key, table)
294    
295     def _get_tables(self, conn=None):
296         data, _ = self.fetch("SHOW TABLES FROM %s" % self.qname, conn=conn)
297         return [db.Table(self, row[0], self.quote(row[0])) for row in data]
298    
299     def _get_columns(self, tablename, conn=None):
300         # cols are: Field, Type, Null, Key, Default, Extra.
301         # See http://dev.mysql.com/doc/refman/4.1/en/describe.html
302         data, _ = self.fetch("SHOW COLUMNS FROM %s.%s" %
303                              (self.qname, self.quote(tablename)), conn=conn)
304         cols = []
305         for row in data:
306             c = db.Column(row[0], self.quote(row[0]), None, None)
307            
308             dbtype = row[1].upper()
309             parenpos = dbtype.find("(")
310             if parenpos > -1:
311                 args = dbtype[parenpos+1:-1]
312                 baretype = dbtype[:parenpos]
313                 if baretype in ("DECIMAL", "NUMERIC"):
314                     args = [x.strip() for x in args.split(",")]
315                     c.hints['precision'], c.hints['scale'] = args
316                 else:
317                     c.hints['bytes'] = args
318             elif dbtype == "FLOAT":
319                 c.hints['precision'] = 24
320             elif dbtype.startswith("DOUBLE"):
321                 c.hints['precision'] = 53
322             elif dbtype in ("TINYBLOB", "TINYTEXT"):
323                 c.hints['bytes'] = (2 ** 8) - 1
324             elif dbtype in ("BLOB", "TEXT"):
325                 c.hints['bytes'] = (2 ** 16) - 1
326             elif dbtype in ("MEDIUMBLOB", "MEDIUMTEXT"):
327                 c.hints['bytes'] = (2 ** 24) - 1
328             elif dbtype in ("LONGBLOB", "LONGTEXT"):
329                 c.hints['bytes'] = (2 ** 32) - 1
330            
331             c.key = bool(row[3])
332            
333             if row[4]:
334                 c.default = self.python_type(dbtype)(row[4])
335            
336             if "auto_increment" in row[5].lower():
337                 c.autoincrement = True
338            
339             c.dbtype = dbtype
340            
341             cols.append(c)
342         return cols
343    
344     def _get_indices(self, tablename, conn=None):
345         indices = []
346         try:
347             # cols are: Table, Non_unique, Key_name, Seq_in_index, Column_name,
348             # Collation, Cardinality, Sub_part, Packed, Null, Index_type, Comment
349             data, _ = self.fetch("SHOW INDEX FROM %s.%s"
350                                  % (self.qname, self.quote(tablename)),
351                                  conn=conn)
352         except _mysql.ProgrammingError, x:
353             if x.args[0] != 1146:
354                 raise
355         else:
356             for row in data:
357                 i = db.Index(row[2], self.quote(row[2]),
358                              row[0], row[4], not row[1])
359                 indices.append(i)
360         return indices
361    
362     def python_type(self, dbtype):
363         """Return a Python type which can store values of the given dbtype."""
364         dbtype = dbtype.upper()
365         parenpos = dbtype.find("(")
366         if parenpos > -1:
367             dbtype = dbtype[:parenpos]
368        
369         if dbtype in ('TINYINT', 'SMALLINT', 'MEDIUMINT', 'INT', 'INTEGER'):
370             return int
371         elif dbtype == 'BIGINT':
372             return long
373         elif dbtype in ('BOOL', 'BOOLEAN'):
374             return bool
375         elif dbtype in ('FLOAT', 'DOUBLE', 'DOUBLE PRECISION', 'REAL'):
376             return float
377         elif dbtype in ('DECIMAL', 'NUMERIC'):
378             if db.decimal:
379                 return db.decimal.Decimal
380             elif db.fixedpoint:
381                 return db.fixedpoint.Fixedpoint
382         elif dbtype == 'DATE':
383             return datetime.date
384         elif dbtype in ('DATETIME', 'TIMESTAMP'):
385             return datetime.datetime
386         elif dbtype == 'TIME':
387             return datetime.time
388         elif dbtype in ('CHAR', 'VARCHAR', 'BINARY', 'VARBINARY',
389                         'TINYBLOB', 'TINYTEXT', 'BLOB', 'TEXT',
390                         'MEDIUMBLOB', 'MEDIUMTEXT', 'LONGBLOB', 'LONGTEXT'):
391             return str
392        
393         raise TypeError("Database type %r could not be converted "
394                         "to a Python type." % dbtype)
395    
396     def quote(self, name):
397         """Return name, quoted for use in an SQL statement."""
398         return '`' + name.replace('`', '``') + '`'
399    
400     def _get_conn(self):
401         try:
402             conn = _mysql.connect(**self.connargs)
403         except _mysql.OperationalError, x:
404             if x.args[0] == 1040:   # Too many connections
405                 raise db.OutOfConnectionsError
406             raise
407         return conn
408    
409     def _template_conn(self):
410         tmplconn = self.connargs.copy()
411         tmplconn['db'] = 'mysql'
412         return _mysql.connect(**tmplconn)
413    
414     def execute(self, query, conn=None):
415         """execute(query, conn=None) -> result set."""
416         if conn is None:
417             conn = self.connection()
418         if isinstance(query, unicode):
419             query = query.encode(self.adaptertosql.encoding)
420         self.log(query)
421         try:
422             return conn.query(query)
423         except _mysql.OperationalError, x:
424             if x.args[0] == 1030 and x.args[1] == 'Got error 139 from storage engine':
425                 raise ValueError("row length exceeds 8000 byte limit")
426             raise
427    
428     def fetch(self, query, conn=None):
429         """fetch(query, conn=None) -> rowdata, columns.
430         
431         rowdata: a nested list (or tuples), column values within rows.
432         columns: a series of 2-tuples (or more). The first tuple value
433             will be the column name, the second value will be the column
434             type.
435         """
436         if conn is None:
437             conn = self.connection()
438         self.execute(query, conn)
439         # store_result uses a client-side cursor
440         res = conn.store_result()
441         return res.fetch_row(0, 0), res.describe()
442    
443     def create_database(self):
444         self.lock("Creating database. Transactions not allowed.")
445         try:
446             # _mysql has create_db and drop_db commands, but they're deprecated.
447             encoding = self.encoding
448             if encoding:
449                 encoding = " CHARACTER SET %s" % encoding
450             sql = 'CREATE DATABASE %s%s;' % (self.qname, encoding)
451             conn = self._template_conn()
452             self.execute(sql, conn)
453             conn.close()
454             self.clear()
455         finally:
456             self.unlock()
457    
458     def drop_database(self):
459         self.lock("Dropping database. Transactions not allowed.")
460         try:
461             sql = 'DROP DATABASE %s;' % self.qname
462             conn = self._template_conn()
463             self.execute(sql, conn)
464             conn.close()
465             self.clear()
466         finally:
467             self.unlock()
468
469
470
class StorageManagerMySQL(db.StorageManagerDB):
    """StoreManager to save and retrieve Units via _mysql."""

    databaseclass = MySQLDatabase

    # Options forwarded from allOptions to _mysql.connect(). Only keywords
    # that _mysql.connect() accepts belong here: passing any other name
    # (e.g. MySQLdb's "cursorclass") makes the connect call fail.
    _connarg_names = ("host", "user", "passwd", "db", "port", "unix_socket",
                      "conv", "connect_timeout", "compress", "named_pipe",
                      "init_command", "read_default_file",
                      "read_default_group", "client_flag",
                      )

    def __init__(self, arena, allOptions=None):
        """Build connection arguments from allOptions and initialize.

        allOptions must contain 'db' (the database name); a KeyError is
        raised otherwise. 'connargs' and 'name' are written into it before
        it is passed on to db.StorageManagerDB.
        """
        # Avoid a shared mutable default: allOptions is modified below.
        if allOptions is None:
            allOptions = {}
        connargs = dict([(k, v) for k, v in allOptions.iteritems()
                         if k in self._connarg_names])
        # Backward compatibility: older configs used the misspelled
        # "connect_time", which _mysql.connect() does not accept.
        if "connect_time" in allOptions and "connect_timeout" not in connargs:
            connargs["connect_timeout"] = allOptions["connect_time"]
        allOptions['connargs'] = connargs
        allOptions['name'] = connargs['db']
        db.StorageManagerDB.__init__(self, arena, allOptions)

    def _seq_UnitSequencerInteger(self, unit):
        """Reserve a unit using the table's AUTO_INCREMENT field.

        INSERTs the unit's non-autoincrement properties, then stores the
        generated id on the unit's first identifier attribute.
        """
        cls = unit.__class__
        t = self.db[cls.__name__]

        fields = []
        values = []
        for key in cls.properties:
            col = t.columns[key]
            if col.autoincrement:
                # Skip this field; MySQL generates its value.
                continue
            val = self.db.adaptertosql.coerce(getattr(unit, key), col.dbtype)
            fields.append(col.qname)
            values.append(val)

        conn = self.db.get_transaction()
        self.db.execute('INSERT INTO %s (%s) VALUES (%s);' %
                        (t.qname, ", ".join(fields), ", ".join(values)), conn)

        # Grab the new ID. This is threadsafe because db.reserve has a mutex.
        setattr(unit, cls.identifiers[0], conn.insert_id())
513
Note: See TracBrowser for help on using the browser.