Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/storemysql.py

Revision 265 (checked in by fumanchu, 7 years ago)

Fixed up primary key support by moving Index.pk to Column.key, and by adding PRIMARY KEY clauses to CREATE TABLE. Also added a SM.autosource method (which cleaned up a lot of minor bugs).

  • Property svn:eol-style set to native
Line 
1 """
2 References the MySQLdb package at:
3 http://sourceforge.net/projects/mysql-python
4
5 From the MySQL manual:
6
7 "If the server SQL mode has ANSI_QUOTES enabled, string literals can be
8 quoted only with single quotes. A string quoted with double quotes will be
9 interpreted as an identifier."
10
11 So use single quotes throughout.
12 """
13
14 # Use _mysql directly to avoid all of the DB-API overhead.
15 import _mysql
16 import warnings
17 import datetime
18
19 import dejavu
20 from dejavu import storage, logic, logflags
21 from dejavu.storage import db
22
23
class AdapterToMySQL(db.AdapterToSQL):
    """Coerce Python values into MySQL SQL literals."""

    # (original, replacement) pairs for escaping string content;
    # presumably consumed by the base adapter's escaping helpers.
    escapes = [("'", "''"), ("\\", r"\\")]
    # LIKE wildcards, escaped with MySQL's default LIKE escape char (\).
    like_escapes = [("%", r"\%"), ("_", r"\_")]

    # TRUE and FALSE only work with 4.1 or better.
    bool_true = "1"
    bool_false = "0"

    def coerce_str_to_any(self, value, skip_encoding=False):
        """Return value as a single-quoted, escaped MySQL string literal.

        Values that are not already str (e.g. unicode) are first encoded
        with self.encoding, unless skip_encoding is True.
        """
        if not skip_encoding and not isinstance(value, str):
            value = value.encode(self.encoding)
        return "'" + _mysql.escape_string(value) + "'"

    def coerce_bool_to_any(self, value):
        """Return the MySQL literal for the truth value of 'value'."""
        # TRUE and FALSE only work with 4.1 or better, so use the numeric
        # literals declared on the class instead of repeating them here.
        if value:
            return self.bool_true
        return self.bool_false
43
44
class AdapterFromMySQL(db.AdapterFromDB):
    """Coerce values fetched from MySQL into Python values."""

    def coerce_any_to_bool(self, value):
        """Return the bool for a BOOL (TINYINT(1)) column value.

        String values (presumably how _mysql hands back column data) are
        compared numerically: MySQL treats any nonzero TINYINT as true,
        so '2' or '-1' must not be read as False.
        """
        if isinstance(value, basestring):
            try:
                value = int(value)
            except ValueError:
                # Not a plain integer string; fall back to the '1' test.
                value = (value == '1')
        return bool(value)
52
53
class MySQLDecompiler(db.SQLDecompiler):
    """SQL decompiler with MySQL-specific function translations."""

    def dejavu_today(self):
        """Translate dejavu's today() into MySQL's current-date function."""
        sql = "CURDATE()"
        return sql
58
59
class MySQLDecompiler411(MySQLDecompiler):
    """Decompiler for MySQL 4.1.1 and later.

    Before MySQL 4.1.1, BINARY comparisons could use UPPER() or LOWER()
    to perform case-insensitive comparisons. Newer versions must use
    CONVERT() to obtain a nonbinary string in a character set like utf8,
    whose default collation compares case-insensitively.
    """

    def _utf8(self, expr):
        """Wrap the SQL expression 'expr' in CONVERT(... USING utf8)."""
        return "CONVERT(%s USING utf8)" % (expr,)

    def dejavu_icontainedby(self, op1, op2):
        """Return SQL for a case-insensitive 'op1 in op2' test."""
        if isinstance(op1, db.ConstWrapper):
            # Looking for text in a field. Use LIKE (reverse terms).
            return (self._utf8(op2) + " LIKE '%" +
                    self.adapter.escape_like(op1) + "%'")

        # Looking for field in (a, b, c).
        atoms = [self.adapter.coerce(x) for x in op2.basevalue]
        if not atoms:
            # "IN ()" is a syntax error in MySQL, and nothing is contained
            # in an empty sequence, so emit a constant false instead.
            return "0"
        return "%s IN (%s)" % (self._utf8(op1), ", ".join(atoms))

    def dejavu_istartswith(self, x, y):
        """Return SQL for a case-insensitive 'x starts with y' test."""
        return (self._utf8(x) + " LIKE '" +
                self.adapter.escape_like(y) + "%'")

    def dejavu_iendswith(self, x, y):
        """Return SQL for a case-insensitive 'x ends with y' test."""
        return (self._utf8(x) + " LIKE '%" +
                self.adapter.escape_like(y) + "'")

    def dejavu_ieq(self, x, y):
        """Return SQL for a case-insensitive equality test of x and y.

        Both operands are converted: if either side stayed a binary string
        (e.g. a VARBINARY column), MySQL would compare the pair as binary,
        i.e. case-sensitively.
        """
        return self._utf8(x) + " = " + self._utf8(y)
86
87
class TypeAdapterMySQL(db.TypeAdapter):
    """Choose MySQL column types for Python column types."""

    numeric_max_precision = 16
    numeric_max_bytes = 8

    def float_type(self, precision):
        """Return a datatype which can handle the given precision."""
        # "p represents the precision in bits, but MySQL uses this value
        # only to determine whether to use FLOAT or DOUBLE for the
        # resulting data type. If p is from 0 to 24, the data type
        # becomes FLOAT with no M or D values. If p is from 25 to 53,
        # the data type becomes DOUBLE with no M or D values."
        return "FLOAT(%s)" % precision

    def coerce_str(self, col):
        """Return a binary string type sized for col.hints['bytes'].

        The hint defaults to 255. A hint of 0, or of 2 ** 24 or more,
        yields LONGBLOB.
        """
        nbytes = int(col.hints.get('bytes', 255))

        if nbytes:
            # MySQL VARBINARY/BLOBs will do case-sensitive comparisons.
            # They also won't truncate trailing spaces like VARCHAR does.
            if nbytes <= 255:
                return "VARBINARY(%s)" % nbytes
            elif nbytes < 2 ** 16:
                return "BLOB"
            elif nbytes < 2 ** 24:
                return "MEDIUMBLOB"
        return "LONGBLOB"

    def coerce_bool(self, col):
        """Return the MySQL type used for bool columns."""
        # We could use BOOLEAN, but it wasn't introduced until 4.1.0.
        return "BOOL"

    def coerce_datetime_datetime(self, col):
        """Return the MySQL type used for datetime.datetime columns."""
        return "DATETIME"

    def coerce_int(self, col):
        """Return an integer type wide enough for col.hints['bytes'].

        The hint defaults to 4 (INTEGER). Hints above 4 bytes need BIGINT
        (8 bytes, matching numeric_max_bytes); INTEGER would overflow.
        """
        nbytes = int(col.hints.get('bytes', '4'))
        if nbytes <= 2:
            return "SMALLINT"
        elif nbytes == 3:
            return "MEDIUMINT"
        elif nbytes <= 4:
            return "INTEGER"
        return "BIGINT"
130
131
class TypeAdapterMySQL41(TypeAdapterMySQL):
    """Type adapter for MySQL 4.1+, where BLOB accepts a length argument."""

    def coerce_str(self, col):
        """Like TypeAdapterMySQL.coerce_str, but give plain BLOBs a length."""
        dbtype = TypeAdapterMySQL.coerce_str(self, col)
        if dbtype != "BLOB":
            return dbtype
        # The parent only picks plain BLOB when a 'bytes' hint was given
        # (its 255-byte default maps to VARBINARY), so the key exists here.
        return "BLOB(%s)" % (col.hints['bytes'],)
139
140
class MySQLIndexSet(db.IndexSet):
    """IndexSet which drops MySQL indices by column name."""

    def __delitem__(self, key):
        """Drop every DB index on the column indexed by 'key', then remove it.

        Raises KeyError (from self[key]) if 'key' is not in this set.
        """
        t = self.table
        colname = self[key].colname
        # MySQL might rename multiple-column indices to "PRIMARY", so match
        # the live indices by column name rather than by index name.
        for i in t.db._get_indices(t.name):
            if i.colname == colname:
                t.db.execute('DROP INDEX %s ON %s;' % (i.qname, t.qname))
        # Keep the in-memory set in sync with the database.
        dict.__delitem__(self, key)
149
150
class MySQLColumnSet(db.ColumnSet):
    """ColumnSet which applies column changes with MySQL ALTER TABLE."""

    indexsetclass = MySQLIndexSet

    def _coldef(self, column):
        """Return the SQL column definition (type plus attributes) for column."""
        dbtype = column.dbtype
        if column.autoincrement:
            dbtype += " AUTO_INCREMENT"
        else:
            default = column.default
            # Skip only missing/empty defaults; 0 and False are real ones.
            if default is not None and default != "":
                default = self.table.db.adaptertosql.coerce(default, dbtype)
                dbtype = "%s DEFAULT %s" % (dbtype, default)
        return dbtype

    def __setitem__(self, key, column):
        """Add 'column' to the table in the database, then store it under key."""
        t = self.table
        t.db.execute("ALTER TABLE %s ADD COLUMN %s %s;" %
                     (t.qname, column.qname, self._coldef(column)))
        dict.__setitem__(self, key, column)

    def _rename(self, oldcol, newcol):
        """Rename oldcol to newcol at the DB level."""
        # CHANGE requires the complete column definition: attributes not
        # restated (AUTO_INCREMENT, DEFAULT) would be silently dropped.
        t = self.table
        t.db.execute("ALTER TABLE %s CHANGE %s %s %s;" %
                     (t.qname, oldcol.qname, newcol.qname,
                      self._coldef(oldcol)))
176
177
178 class MySQLDatabase(db.Database):
179    
180     sql_name_max_length = 64
181     # MySQL uses case-sensitive database and table names on Unix, but
182     # not on Windows. Use all-lowercase identifiers to work around the
183     # problem. "Column names, index names, and column aliases are not
184     # case sensitive on any platform."
185     # If deployers set lower_case_table_names to 1, it would help.
186     sql_name_caseless = True
187    
188     adaptertosql = AdapterToMySQL()
189     adapterfromdb = AdapterFromMySQL()
190     typeadapter = TypeAdapterMySQL()
191    
192     columnsetclass = MySQLColumnSet
193     indexsetclass = MySQLIndexSet
194    
195     def __init__(self, name, **kwargs):
196         db.Database.__init__(self, name, **kwargs)
197        
198         self.decompiler = MySQLDecompiler
199        
200         # Get the version string from MySQL, to see if we need
201         # a different decompiler.
202         conn = self._template_conn()
203         rowdata, cols = self.fetch("SELECT version();", conn)
204         conn.close()
205         v = rowdata[0][0]
206         self._version = storage.Version(v)
207        
208         # decompiler
209         if self._version > storage.Version("4.1.1"):
210             self.decompiler = MySQLDecompiler411
211        
212         # type adapter
213         if self._version >= storage.Version("4.1"):
214             self.typeadapter = TypeAdapterMySQL41()
215    
216     def version(self):
217         return "MySQL Version: %s" % self._version
218    
219     def __setitem__(self, key, table):
220         q = self.quote
221        
222         fields = []
223         incr_fields = []
224         pk = []
225         for colkey, col in table.columns.iteritems():
226             qname = col.qname
227             dbtype = col.dbtype
228            
229             if col.autoincrement:
230                 dbtype += " AUTO_INCREMENT"
231                 # INSERT INTO t (c) VALUES(0) doesn't work for some reason
232                 if col.default > 1:
233                     incr_fields.append(col)
234             else:
235                 default = col.default or ""
236                 if default:
237                     default = self.adaptertosql.coerce(default, dbtype)
238                     dbtype = "%s DEFAULT %s" % (dbtype, default)
239            
240             fields.append('%s %s' % (qname, dbtype))
241            
242             if col.key:
243                 if dbtype.endswith('BLOB') or dbtype == 'TEXT':
244                     # MySQL won't allow indexes on a BLOB field without a
245                     # specific index prefix length. We choose 255 just for fun.
246                     qname = "%s(255)" % qname
247                 pk.append(qname)
248        
249         if pk:
250             pk = ", PRIMARY KEY (%s)" % ", ".join(pk)
251         else:
252             pk = ""
253        
254         self.execute('CREATE TABLE %s (%s%s);' %
255                      (table.qname, ", ".join(fields), pk))
256        
257         if incr_fields:
258             # Wow, what a hack. We have to INSERT a dummy row to set the
259             # autoincrement initial value(s), and we can't delete it until
260             # after the CREATE INDEX statements (or the counter will revert).
261             fields = ", ".join([col.qname for col in incr_fields])
262             values = ", ".join([str(col.default - 1) for col in incr_fields])
263             self.execute("INSERT INTO %s (%s) VALUES (%s);"
264                          % (table.qname, fields, values))
265        
266         for k, index in table.columns.indices.iteritems():
267             dbtype = table.columns[k].dbtype
268             if dbtype.endswith('BLOB') or dbtype == 'TEXT':
269                 # MySQL won't allow indexes on a BLOB field without a
270                 # specific index prefix length. We choose 255 just for fun.
271                 self.execute('CREATE INDEX %s ON %s (%s(255));' %
272                              (index.qname, table.qname, q(index.colname)))
273             else:
274                 self.execute('CREATE INDEX %s ON %s (%s);' %
275                              (index.qname, table.qname, q(index.colname)))
276        
277         if incr_fields:
278             self.execute("DELETE FROM %s" % table.qname)
279        
280         dict.__setitem__(self, key, table)
281    
282     def _get_tables(self, conn=None):
283         data, _ = self.fetch("SHOW TABLES FROM %s" % self.qname, conn=conn)
284         return [db.Table(self, row[0], self.quote(row[0])) for row in data]
285    
286     def _get_columns(self, tablename, conn=None):
287         # cols are: Field, Type, Null, Key, Default, Extra.
288         # See http://dev.mysql.com/doc/refman/4.1/en/describe.html
289         data, _ = self.fetch("SHOW COLUMNS FROM %s.%s" %
290                              (self.qname, self.quote(tablename)), conn=conn)
291         cols = []
292         for row in data:
293             c = db.Column(row[0], self.quote(row[0]), None, None)
294            
295             dbtype = row[1].upper()
296             parenpos = dbtype.find("(")
297             if parenpos > -1:
298                 args = dbtype[parenpos+1:-1]
299                 baretype = dbtype[:parenpos]
300                 if baretype in ("DECIMAL", "NUMERIC"):
301                     args = [x.strip() for x in args.split(",")]
302                     c.hints['precision'], c.hints['scale'] = args
303                 else:
304                     c.hints['bytes'] = args
305             elif dbtype == "FLOAT":
306                 c.hints['precision'] = 24
307             elif dbtype.startswith("DOUBLE"):
308                 c.hints['precision'] = 53
309             elif dbtype in ("TINYBLOB", "TINYTEXT"):
310                 c.hints['bytes'] = (2 ** 8) - 1
311             elif dbtype in ("BLOB", "TEXT"):
312                 c.hints['bytes'] = (2 ** 16) - 1
313             elif dbtype in ("MEDIUMBLOB", "MEDIUMTEXT"):
314                 c.hints['bytes'] = (2 ** 24) - 1
315             elif dbtype in ("LONGBLOB", "LONGTEXT"):
316                 c.hints['bytes'] = (2 ** 32) - 1
317            
318             c.key = bool(row[3])
319            
320             if row[4]:
321                 c.default = self.python_type(dbtype)(row[4])
322            
323             if "auto_increment" in row[5].lower():
324                 c.autoincrement = True
325            
326             c.dbtype = dbtype
327            
328             cols.append(c)
329         return cols
330    
331     def _get_indices(self, tablename, conn=None):
332         indices = []
333         try:
334             # cols are: Table, Non_unique, Key_name, Seq_in_index, Column_name,
335             # Collation, Cardinality, Sub_part, Packed, Null, Index_type, Comment
336             data, _ = self.fetch("SHOW INDEX FROM %s.%s"
337                                  % (self.qname, self.quote(tablename)),
338                                  conn=conn)
339         except _mysql.ProgrammingError, x:
340             if x.args[0] != 1146:
341                 raise
342         else:
343             for row in data:
344                 i = db.Index(row[2], self.quote(row[2]),
345                              row[0], row[4], not row[1])
346                 indices.append(i)
347         return indices
348    
349     def python_type(self, dbtype):
350         """Return a Python type which can store values of the given dbtype."""
351         dbtype = dbtype.upper()
352         parenpos = dbtype.find("(")
353         if parenpos > -1:
354             dbtype = dbtype[:parenpos]
355        
356         if dbtype in ('TINYINT', 'SMALLINT', 'MEDIUMINT', 'INT', 'INTEGER'):
357             return int
358         elif dbtype == 'BIGINT':
359             return long
360         elif dbtype in ('BOOL', 'BOOLEAN'):
361             return bool
362         elif dbtype in ('FLOAT', 'DOUBLE', 'DOUBLE PRECISION', 'REAL'):
363             return float
364         elif dbtype in ('DECIMAL', 'NUMERIC'):
365             if db.decimal:
366                 return db.decimal.Decimal
367             elif db.fixedpoint:
368                 return db.fixedpoint.Fixedpoint
369         elif dbtype == 'DATE':
370             return datetime.date
371         elif dbtype in ('DATETIME', 'TIMESTAMP'):
372             return datetime.datetime
373         elif dbtype == 'TIME':
374             return datetime.time
375         elif dbtype in ('CHAR', 'VARCHAR', 'BINARY', 'VARBINARY',
376                         'TINYBLOB', 'TINYTEXT', 'BLOB', 'TEXT',
377                         'MEDIUMBLOB', 'MEDIUMTEXT', 'LONGBLOB', 'LONGTEXT'):
378             return str
379        
380         raise TypeError("Database type %r could not be converted "
381                         "to a Python type." % dbtype)
382    
383     def quote(self, name):
384         """Return name, quoted for use in an SQL statement."""
385         return '`' + name.replace('`', '``') + '`'
386    
387     def _get_conn(self):
388         try:
389             conn = _mysql.connect(**self.connargs)
390         except _mysql.OperationalError, x:
391             if x.args[0] == 1040:   # Too many connections
392                 raise db.OutOfConnectionsError
393             raise
394         return conn
395    
396     def _template_conn(self):
397         tmplconn = self.connargs.copy()
398         tmplconn['db'] = 'mysql'
399         return _mysql.connect(**tmplconn)
400    
401     def execute(self, query, conn=None):
402         """execute(query, conn=None) -> result set."""
403         if conn is None:
404             conn = self.connection()
405         if isinstance(query, unicode):
406             query = query.encode(self.adaptertosql.encoding)
407         self.log(query, logflags.SQL)
408         try:
409             return conn.query(query)
410         except _mysql.OperationalError, x:
411             if x.args[0] == 1030 and x.args[1] == 'Got error 139 from storage engine':
412                 raise ValueError("row length exceeds 8000 byte limit")
413             raise
414    
415     def fetch(self, query, conn=None):
416         """fetch(query, conn=None) -> rowdata, columns.
417         
418         rowdata: a nested list (or tuples), column values within rows.
419         columns: a series of 2-tuples (or more). The first tuple value
420             will be the column name, the second value will be the column
421             type.
422         """
423         if conn is None:
424             conn = self.connection()
425         self.execute(query, conn)
426         # store_result uses a client-side cursor
427         res = conn.store_result()
428         return res.fetch_row(0, 0), res.describe()
429    
430     def create_database(self):
431         # _mysql has create_db and drop_db commands, but they're deprecated.
432         sql = 'CREATE DATABASE %s;' % self.qname
433         conn = self._template_conn()
434         self.execute(sql, conn)
435         conn.close()
436         self.clear()
437    
438     def drop_database(self):
439         sql = 'DROP DATABASE %s;' % self.qname
440         conn = self._template_conn()
441         self.execute(sql, conn)
442         conn.close()
443         self.clear()
444
445
446
class StorageManagerMySQL(db.StorageManagerDB):
    """StoreManager to save and retrieve Units via _mysql."""

    databaseclass = MySQLDatabase

    def __init__(self, arena, allOptions=None):
        """Create the store; the 'db' option names the MySQL database.

        Options understood by _mysql.connect() are gathered into
        allOptions['connargs'] before calling the base constructor. The
        caller's dict is copied rather than mutated.
        """
        if allOptions is None:
            allOptions = {}
        else:
            allOptions = dict(allOptions)
        connargnames = ["host", "user", "passwd", "db", "port", "unix_socket",
                        "conv", "connect_time", "compress", "named_pipe",
                        "init_command", "read_default_file",
                        "read_default_group", "cursorclass", "client_flag",
                        ]
        connargs = dict([(k, v) for k, v in allOptions.iteritems()
                         if k in connargnames])
        if 'db' not in connargs:
            raise KeyError("StorageManagerMySQL requires a 'db' option.")
        allOptions['connargs'] = connargs
        allOptions['name'] = connargs['db']
        db.StorageManagerDB.__init__(self, arena, allOptions)

    def destroy(self, unit):
        """destroy(unit). Delete the unit."""
        t = self.db[unit.__class__.__name__].qname
        self.db.execute('DELETE FROM %s WHERE %s;' % (t, self.id_clause(unit)))

    def _seq_UnitSequencerInteger(self, unit):
        """Reserve a unit using the table's AUTO_INCREMENT field."""
        cls = unit.__class__
        t = self.db[cls.__name__]

        fields = []
        values = []
        for key in cls.properties:
            col = t.columns[key]
            if col.autoincrement:
                # Skip this field, since we're using a sequencer
                continue
            val = self.db.adaptertosql.coerce(getattr(unit, key), col.dbtype)
            fields.append(col.qname)
            values.append(val)

        fields = ", ".join(fields)
        values = ", ".join(values)

        conn = self.db.connection()
        self.db.execute('INSERT INTO %s (%s) VALUES (%s);' %
                        (t.qname, fields, values), conn)

        # Grab the new ID. This is threadsafe because db.reserve has a mutex.
        setattr(unit, cls.identifiers[0], conn.insert_id())
Note: See TracBrowser for help on using the browser.