Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/storemysql.py

Revision 227 (checked in by fumanchu, 7 years ago)

Table, Column, and Index objects now all have a qname attribute (the quoted version of their name).

  • Property svn:eol-style set to native
Line 
1 """
2 References the MySQLdb package at:
3 http://sourceforge.net/projects/mysql-python
4
5 From the MySQL manual:
6
7 "If the server SQL mode has ANSI_QUOTES enabled, string literals can be
8 quoted only with single quotes. A string quoted with double quotes will be
9 interpreted as an identifier."
10
11 So use single quotes throughout.
12 """
13
14 # Use _mysql directly to avoid all of the DB-API overhead.
15 import _mysql
16 import warnings
17 import datetime
18
19 try:
20     # Builtin in Python 2.5?
21     decimal
22 except NameError:
23     try:
24         # Module in Python 2.3, 2.4
25         import decimal
26     except ImportError:
27         pass
28
29 import dejavu
30 from dejavu import storage, logic
31 from dejavu.storage import db
32
33
class AdapterToMySQL(db.AdapterToSQL):
    """Coerce Python values into MySQL SQL literals."""

    # (search, replacement) pairs for escaping string content: single
    # quotes are doubled and backslashes are doubled.
    escapes = [("'", "''"), ("\\", r"\\")]
    # LIKE wildcards which must be escaped to match literally.
    like_escapes = [("%", r"\%"), ("_", r"\_")]

    # TRUE and FALSE only work with 4.1 or better.
    bool_true = "1"
    bool_false = "0"

    def coerce_str(self, value, skip_encoding=False):
        """Return value as a single-quoted, escaped MySQL string literal.

        Values which are not already byte strings (e.g. unicode) are first
        encoded with self.encoding, unless skip_encoding is True.
        """
        if not skip_encoding and not isinstance(value, str):
            value = value.encode(self.encoding)
        return "'" + _mysql.escape_string(value) + "'"

    def coerce_bool(self, value):
        """Return the SQL literal for the truth value of the given value.

        Reuses bool_true/bool_false so the literals live in one place
        (TRUE and FALSE only work with 4.1 or better).
        """
        if value:
            return self.bool_true
        return self.bool_false
53
54
class AdapterFromMySQL(db.AdapterFromDB):
    """Coerce raw MySQL column values back into Python values."""

    def coerce_bool(self, value, coltype):
        """Return the bool stored in a MySQL BOOL column.

        String values are expected to be '0' or '1'; only '1' is true.
        Any non-string value is simply passed through bool().
        """
        if not isinstance(value, basestring):
            return bool(value)
        return value == '1'
62
63
class MySQLDecompiler(db.SQLDecompiler):
    """SQL decompiler with MySQL-specific function translations."""

    def dejavu_today(self):
        """Translate dejavu's today() into MySQL's current-date function."""
        current_date_sql = "CURDATE()"
        return current_date_sql
68
69
class MySQLDecompiler411(MySQLDecompiler):
    """Decompiler for MySQL servers newer than 4.1.1.

    Before MySQL 4.1.1, BINARY comparisons could use UPPER() or LOWER()
    to perform case-insensitive comparisons. Newer versions must first
    CONVERT() the operand into a character set such as utf8, so that the
    comparison works on characters (with a case-insensitive collation)
    rather than on raw bytes.
    """

    def _as_utf8(self, expr):
        """Wrap the SQL expression expr in CONVERT(... USING utf8)."""
        return "CONVERT(" + expr + " USING utf8)"

    def dejavu_icontainedby(self, op1, op2):
        if not isinstance(op1, db.ConstWrapper):
            # Looking for a field in a constant sequence: field IN (a, b, c).
            coerce = self.adapter.coerce
            atoms = [coerce(item) for item in op2.basevalue]
            return "CONVERT(%s USING utf8) IN (%s)" % (op1, ", ".join(atoms))
        # Looking for constant text inside a field: LIKE, terms reversed.
        return (self._as_utf8(op2) + " LIKE '%" +
                self.adapter.escape_like(op1) + "%'")

    def dejavu_istartswith(self, x, y):
        return (self._as_utf8(x) + " LIKE '" +
                self.adapter.escape_like(y) + "%'")

    def dejavu_iendswith(self, x, y):
        return (self._as_utf8(x) + " LIKE '%" +
                self.adapter.escape_like(y) + "'")

    def dejavu_ieq(self, x, y):
        return self._as_utf8(x) + " = " + y
96
97
class FieldTypeAdapterMySQL(db.FieldTypeAdapter):
    """Return the SQL typename of a DB column."""

    float_max_precision = 53
    numeric_max_precision = 16
    numeric_max_bytes = 8

    def float_type(self, precision):
        """Return a datatype which can handle the given precision."""
        if precision <= 23:
            return "FLOAT"
        else:
            return "DOUBLE PRECISION"

    def coerce_str(self, cls, key):
        """Return a binary string type sized by the property's 'bytes' hint.

        A missing or zero hint yields LONGBLOB.
        """
        prop = getattr(cls, key)
        size = int(prop.hints.get('bytes', '0'))
        if size:
            # MySQL VARBINARY/BLOBs will do case-sensitive comparisons.
            # They also won't truncate trailing spaces like VARCHAR does.
            if size <= 255:
                return "VARBINARY(%s)" % size
            elif size < 2 ** 16:
                return "BLOB"
            elif size < 2 ** 24:
                return "MEDIUMBLOB"
        return "LONGBLOB"

    def coerce_bool(self, cls, key):
        # We could use BOOLEAN, but it wasn't introduced until 4.1.0.
        return "BOOL"

    def coerce_datetime_datetime(self, cls, key):
        return "DATETIME"

    def coerce_int(self, cls, key):
        """Return an integer type sized by the property's 'bytes' hint.

        The hint defaults to 4 (INTEGER). Identifier columns of classes
        using UnitSequencerInteger get an AUTO_INCREMENT suffix.
        """
        prop = getattr(cls, key)
        size = int(prop.hints.get('bytes', '4'))
        if size == 1:
            # Use TINYINT rather than BOOLEAN: BOOLEAN only exists from
            # 4.1.0 (see coerce_bool), and returning early here used to
            # skip the AUTO_INCREMENT suffix for 1-byte identifiers.
            typename = "TINYINT"
        elif size == 2:
            typename = "SMALLINT"
        elif size == 3:
            typename = "MEDIUMINT"
        else:
            typename = "INTEGER"
        if isinstance(cls.sequencer, dejavu.UnitSequencerInteger):
            if key in cls.identifiers:
                typename += " AUTO_INCREMENT"
        return typename
148
149
150
class MySQLIndexSet(db.IndexSet):
    """IndexSet which drops indices using MySQL's DROP INDEX ... ON syntax."""

    def __delitem__(self, key):
        """Drop every DB index on the column indexed under key, then forget key.

        MySQL might rename multiple-column indices to "PRIMARY", so the
        index names actually present in the database are looked up by
        column name rather than trusting self[key].qname.
        """
        t = self.table
        target_colname = self[key].colname
        for i in t.sm.tables._get_indices(t.name):
            if i.colname == target_colname:
                t.sm.execute('DROP INDEX %s ON %s;' % (i.qname, t.qname))
        # This override replaces the base behavior entirely, so the mapping
        # entry must be removed here too (IndexSet is a dict: see the
        # dict.__setitem__ calls on indices in create_storage).
        dict.__delitem__(self, key)
159
160
class MySQLColumnSet(db.ColumnSet):
    """ColumnSet which renames columns with ALTER TABLE ... CHANGE."""

    def _rename(self, oldcol, newname):
        """Rename oldcol to newname at the DB level.

        MySQL's CHANGE clause needs the complete column definition, so the
        old column's dbtype is repeated after the new name.
        """
        table = self.table
        sm = table.sm
        statement = "ALTER TABLE %s CHANGE %s %s %s;" % (
            table.qname, oldcol.qname, sm.quote(newname), oldcol.dbtype)
        sm.execute(statement)
169
170
171 class MySQLTableSet(db.TableSet):
172    
173     def __setitem__(self, key, table):
174         q = self.sm.quote
175        
176         fields = []
177         pk = []
178         for colname, col in table.columns.iteritems():
179             qname = col.qname
180             dbtype = col.dbtype
181             fields.append('%s %s' % (qname, dbtype))
182             if colname in table.mysql_identifiers:
183                 if dbtype.endswith('BLOB') or dbtype == 'TEXT':
184                     # MySQL won't allow indexes on a BLOB field
185                     # without a specific length.
186                     qname = "%s(255)" % qname
187                 pk.append(qname)
188        
189         pk = ", ".join(pk)
190         if pk:
191             pk = ", PRIMARY KEY (%s)" % pk
192        
193         self.sm.execute('CREATE TABLE %s (%s%s);' %
194                         (table.qname, ", ".join(fields), pk))
195        
196         seq = getattr(table, "mysql_sequencer", None)
197         if seq:
198             # Wow, what a hack. We have to INSERT a dummy row
199             # to set the autoincrement initial value, and we
200             # can't delete it until after the CREATE INDEX
201             # statements (or the counter will revert).
202             colname, initial = seq
203             self.sm.execute("INSERT INTO %s (%s) VALUES (%s);"
204                             % (table.qname, q(colname), initial - 1))
205        
206         for k, index in table.columns.indices.iteritems():
207             dbtype = table.columns[k].dbtype
208             if dbtype.endswith('BLOB') or dbtype == 'TEXT':
209                 # MySQL won't allow indexes on a BLOB field
210                 # without a specific length.
211                 self.sm.execute('CREATE INDEX %s ON %s (%s(255));' %
212                                 (index.qname, table.qname, q(index.colname)))
213             else:
214                 self.sm.execute('CREATE INDEX %s ON %s (%s);' %
215                                 (index.qname, table.qname, q(index.colname)))
216        
217         if seq:
218             self.sm.execute("DELETE FROM %s" % table.qname)
219        
220         dict.__setitem__(self, key, table)
221    
222     def _get_tables(self, conn=None):
223         data, _ = self.sm.fetch("SHOW TABLES FROM %s" %
224                                 self.sm.quote(self.sm.dbname),
225                                 conn=conn)
226         return [self.sm.tableclass(self.sm, row[0], self.sm.quote(row[0]))
227                 for row in data]
228    
229     def _get_columns(self, tablename, conn=None):
230         # cols are: Field, Type, Null, Key, Default, Extra.
231         # See http://dev.mysql.com/doc/refman/4.1/en/describe.html
232         q = self.sm.quote
233         data, _ = self.sm.fetch("SHOW COLUMNS FROM %s.%s"
234                                 % (q(self.sm.dbname), q(tablename)),
235                                 conn=conn)
236         cols = []
237         for row in data:
238             c = self.sm.columnclass(row[0], self.sm.quote(row[0]),
239                                     None, None, row[4])
240            
241             dbtype = row[1]
242             parenpos = dbtype.find("(")
243             if parenpos > -1:
244                 c.hints['bytes'] = dbtype[parenpos+1:-1]
245                 dbtype = dbtype[:parenpos]
246             c.dbtype = dbtype
247            
248             if dbtype in ('tinyint', 'smallint', 'mediumint', 'int', 'integer'):
249                 c.type = int
250             elif dbtype == 'bigint':
251                 c.type = long
252             elif dbtype in ('float', 'double', 'real'):
253                 c.type = float
254             elif dbtype in ('decimal', 'numeric'):
255                 c.type = decimal.Decimal
256             elif dbtype == 'date':
257                 c.type = datetime.date
258             elif dbtype in ('datetime', 'timestamp'):
259                 c.type = datetime.datetime
260             elif dbtype == 'time':
261                 c.type = datetime.time
262             elif dbtype in ('char', 'varchar', 'binary', 'varbinary',
263                             'tinyblob', 'tinytext', 'blob', 'text',
264                             'mediumblob', 'mediumtext',
265                             'longblob', 'longtext'):
266                 c.type = str
267             cols.append(c)
268         return cols
269    
270     def _get_indices(self, tablename, conn=None):
271         indices = []
272         try:
273             # cols are: Table, Non_unique, Key_name, Seq_in_index, Column_name,
274             # Collation, Cardinality, Sub_part, Packed, Null, Index_type, Comment
275             q = self.sm.quote
276             data, _ = self.fetch("SHOW INDEX FROM %s.%s"
277                                  % (q(self.sm.dbname), q(tablename)),
278                                  conn=conn)
279         except _mysql.ProgrammingError, x:
280             if x.args[0] != 1146:
281                 raise
282         else:
283             for row in data:
284                 i = self.sm.indexclass(row[2], self.sm.quote(row[2]),
285                                        row[0], row[4], None, not row[1])
286                 indices.append(i)
287         return indices
288
289
290
291 class StorageManagerMySQL(db.StorageManagerDB):
292     """StoreManager to save and retrieve Units via _mysql."""
293    
294     sql_name_max_length = 64
295     # MySQL uses case-sensitive database and table names on Unix, but
296     # not on Windows. Use all-lowercase identifiers to work around the
297     # problem. "Column names, index names, and column aliases are not
298     # case sensitive on any platform."
299     # If deployers set lower_case_table_names to 1, it would help.
300     sql_name_caseless = True
301    
302     typeAdapter = FieldTypeAdapterMySQL()
303     toAdapter = AdapterToMySQL()
304     fromAdapter = AdapterFromMySQL()
305    
306     tablesetclass = MySQLTableSet
307     columnsetclass = MySQLColumnSet
308     indexsetclass = MySQLIndexSet
309    
310     def __init__(self, name, arena, allOptions={}):
311         connargs = ["host", "user", "passwd", "db", "port", "unix_socket",
312                     "conv", "connect_time", "compress", "named_pipe",
313                     "init_command", "read_default_file", "read_default_group",
314                     "cursorclass", "client_flag",
315                     ]
316         self.connargs = dict([(k, v) for k, v in allOptions.iteritems()
317                               if k in connargs])
318         self.dbname = self.connargs['db']
319        
320         db.StorageManagerDB.__init__(self, name, arena, allOptions)
321        
322         self.decompiler = MySQLDecompiler
323         # Get the version string from MySQL, to see if we need
324         # a different decompiler.
325         conn = self._template_conn()
326         rowdata, cols = self.fetch("SELECT version();", conn)
327         conn.close()
328         v = rowdata[0][0]
329         self._version = storage.Version(v)
330         if self._version > storage.Version("4.1.1"):
331             self.decompiler = MySQLDecompiler411
332    
333     def quote(self, name):
334         """Return name, quoted for use in an SQL statement."""
335         return '`' + name.replace('`', '``') + '`'
336    
337     def _get_conn(self):
338         try:
339             conn = _mysql.connect(**self.connargs)
340         except _mysql.OperationalError, x:
341             if x.args[0] == 1040:   # Too many connections
342                 raise db.OutOfConnectionsError
343             raise
344         return conn
345    
346     def _template_conn(self):
347         tmplconn = self.connargs.copy()
348         tmplconn['db'] = 'mysql'
349         return _mysql.connect(**tmplconn)
350    
351     def fetch(self, query, conn=None):
352         """fetch(query, conn=None) -> rowdata, columns.
353         
354         rowdata: a nested list (or tuples), column values within rows.
355         columns: a series of 2-tuples (or more). The first tuple value
356             will be the column name, the second value will be the column
357             type.
358         """
359         if conn is None:
360             conn = self.connection()
361         self.execute(query, conn)
362         # store_result uses a client-side cursor
363         res = conn.store_result()
364         return res.fetch_row(0, 0), res.describe()
365    
366     def destroy(self, unit):
367         """destroy(unit). Delete the unit."""
368         t = self.tables[unit.__class__.__name__].qname
369         self.execute('DELETE FROM %s WHERE %s;' % (t, self.id_clause(unit)))
370    
371     def version(self):
372         return "MySQL Version: %s" % self._version
373    
374     def _seq_UnitSequencerInteger(self, unit):
375         """Reserve a unit using the table's AUTO_INCREMENT field."""
376         cls = unit.__class__
377         t = self.tables[cls.__name__]
378        
379         fields = []
380         values = []
381         for key in cls.properties:
382             typename = self.typeAdapter.coerce(cls, key)
383             if typename.endswith("AUTO_INCREMENT"):
384                 # Skip this field, since we're using AUTO_INCREMENT
385                 continue
386             val = self.toAdapter.coerce(getattr(unit, key))
387             fields.append(t.columns[key].qname)
388             values.append(val)
389        
390         fields = ", ".join(fields)
391         values = ", ".join(values)
392        
393         conn = self.connection()
394         self.execute('INSERT INTO %s (%s) VALUES (%s);' %
395                      (t.qname, fields, values), conn)
396        
397         # Grab the new ID. This is threadsafe because db.reserve has a mutex.
398         setattr(unit, cls.identifiers[0], conn.insert_id())
399    
400     #                               Schemas                               #
401    
402     def create_database(self):
403         # _mysql has create_db and drop_db commands, but they're deprecated.
404         sql = 'CREATE DATABASE %s;' % self.quote(self.sql_name(self.dbname))
405         conn = self._template_conn()
406         self.execute(sql, conn)
407         conn.close()
408    
409     def drop_database(self):
410         sql = 'DROP DATABASE %s;' % self.quote(self.sql_name(self.dbname))
411         conn = self._template_conn()
412         self.execute(sql, conn)
413         conn.close()
414    
415     def create_storage(self, cls):
416         """Create storage for the given class."""
417         colname = self.column_name
418        
419         # Make a Table object.
420         tablename = self.table_name(cls.__name__)
421         t = self.tableclass(self, tablename, self.quote(tablename))
422        
423         indices = cls.indices()
424         fields = []
425         for key in cls.properties:
426             dbtype = self.typeAdapter.coerce(cls, key)
427             prop = cls.property(key)
428             cname = colname(cls.__name__, key)
429             col = self.columnclass(cname, self.quote(cname), dbtype,
430                                    prop.type, prop.default, prop.hints.copy())
431             # Use the superclass call to avoid ALTER TABLE.
432             dict.__setitem__(t.columns, key, col)
433            
434             if key in indices:
435                 iname = self.table_name("i" + cls.__name__ + key)
436                 i = self.indexclass(iname, self.quote(iname), tablename, cname)
437                 # Use the superclass call to avoid CREATE INDEX.
438                 dict.__setitem__(t.columns.indices, key, i)
439        
440         # Hack to get PRIMARY KEY right. See MySQLTableSet.__setitem__
441         t.mysql_identifiers = cls.identifiers
442        
443         # Hack to get AUTO_INCREMENT right where initial > 1.
444         # See MySQLTableSet.__setitem__
445         if isinstance(cls.sequencer, dejavu.UnitSequencerInteger):
446             i = cls.sequencer.initial
447             if i > 1:
448                 t.mysql_sequencer = (t.columns[cls.identifiers[0]].name, i)
449        
450         # Attach to self.tables, which should call CREATE TABLE.
451         self.tables[cls.__name__] = t
452
Note: See TracBrowser for help on using the browser.