Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/storemysql.py

Revision 193 (checked in by fumanchu, 7 years ago)

Fix for #51 (remove expanded columns). If anyone objects, this can be reinstated with very little work.

  • Property svn:eol-style set to native
Line 
1 """
2 References the MySQLdb package at:
3 http://sourceforge.net/projects/mysql-python
4
5 From the MySQL manual:
6
7 "If the server SQL mode has ANSI_QUOTES enabled, string literals can be
8 quoted only with single quotes. A string quoted with double quotes will be
9 interpreted as an identifier."
10
11 So use single quotes throughout.
12 """
13
14 # Use _mysql directly to avoid all of the DB-API overhead.
15 import _mysql
16 import warnings
17 import datetime
18
19 try:
20     # Builtin in Python 2.5?
21     decimal
22 except NameError:
23     try:
24         # Module in Python 2.3, 2.4
25         import decimal
26     except ImportError:
27         pass
28
29 import dejavu
30 from dejavu import storage, logic
31 from dejavu.storage import db
32
33
class AdapterToMySQL(db.AdapterToSQL):
    """Coerce Python values into MySQL SQL literals.

    String literals are always single-quoted: with ANSI_QUOTES enabled,
    MySQL reads double-quoted text as an identifier.
    """

    # (old, new) replacement pairs for quoted literals; presumably applied
    # by the base db.AdapterToSQL (not visible here).
    escapes = [("'", "''"), ("\\", r"\\")]
    # (old, new) pairs for LIKE wildcards; presumably used by escape_like().
    like_escapes = [("%", r"\%"), ("_", r"\_")]

    # TRUE and FALSE only work with 4.1 or better, so use integer literals.
    bool_true = "1"
    bool_false = "0"

    def coerce_str(self, value, skip_encoding=False):
        """Return *value* as a single-quoted, escaped MySQL string literal.

        Values that are not already ``str`` (e.g. unicode) are encoded with
        ``self.encoding`` first, unless *skip_encoding* is true. Escaping is
        delegated to ``_mysql.escape_string``.
        """
        if not skip_encoding and not isinstance(value, str):
            value = value.encode(self.encoding)
        return "'" + _mysql.escape_string(value) + "'"

    def coerce_bool(self, value):
        """Return the SQL literal for the truth value of *value*.

        Uses the bool_true/bool_false class attributes (rather than repeating
        the literals) so both stay in sync and subclasses can override them.
        """
        if value:
            return self.bool_true
        return self.bool_false
53
54
class AdapterFromMySQL(db.AdapterFromDB):
    """Coerce values fetched from MySQL into Python objects."""

    def coerce_bool(self, value, coltype):
        """Return *value*, as read from a BOOL column, as a Python bool.

        Without a conversion map, _mysql hands back column values as
        strings, so *value* may be e.g. '0' or '1'. MySQL's BOOL is a
        TINYINT(1) in which any nonzero value counts as true, so numeric
        strings are converted with int() instead of being compared only
        against '1'. Non-numeric strings are treated as false.
        """
        if isinstance(value, basestring):
            try:
                value = int(value)
            except ValueError:
                value = False
        return bool(value)
62
63
class MySQLDecompiler(db.SQLDecompiler):
    """Base SQL decompiler for MySQL servers."""

    def dejavu_today(self):
        """Return the MySQL expression for the current date."""
        current_date_sql = "CURDATE()"
        return current_date_sql
68
69
class MySQLDecompiler411(MySQLDecompiler):
    # Before MySQL 4.1.1, BINARY comparisons could use UPPER() or LOWER()
    # to perform case-insensitive comparisons. Newer versions must use
    # CONVERT(... USING utf8) to turn the binary value into a nonbinary
    # (character) string first, so that a case-insensitive collation can
    # apply to the comparison.
    # NOTE(review): this assumes the server's default utf8 collation is a
    # case-insensitive (*_ci) one -- confirm for your deployment.

    def dejavu_icontainedby(self, op1, op2):
        """Return SQL for a case-insensitive ``op1 in op2`` test.

        If op1 is a constant, search for it as a substring of the field op2
        (LIKE '%...%'). Otherwise test the field op1 for membership in the
        constant sequence op2.
        """
        if isinstance(op1, db.ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            return ("CONVERT("+ op2 + " USING utf8) LIKE '%" +
                    self.adapter.escape_like(op1) + "%'")
        # Looking for field in (a, b, c).
        atoms = [self.adapter.coerce(x) for x in op2.basevalue]
        if not atoms:
            # "x IN ()" is a MySQL syntax error; membership in an empty
            # sequence is always false.
            return "0"
        return "CONVERT(%s USING utf8) IN (%s)" % (op1, ", ".join(atoms))

    def dejavu_istartswith(self, x, y):
        """Return SQL testing whether field x starts with y, ignoring case."""
        return ("CONVERT(" + x + " USING utf8) LIKE '" +
                self.adapter.escape_like(y) + "%'")

    def dejavu_iendswith(self, x, y):
        """Return SQL testing whether field x ends with y, ignoring case."""
        return ("CONVERT(" + x + " USING utf8) LIKE '%" +
                self.adapter.escape_like(y) + "'")

    def dejavu_ieq(self, x, y):
        """Return SQL comparing x to the (already rendered) y, ignoring case."""
        return "CONVERT(" + x + " USING utf8) = " + y
96
97
class FieldTypeAdapterMySQL(db.FieldTypeAdapter):
    """Return the SQL typename of a DB column."""

    # This was determined through experimentation. Don't change it.
    numeric_max_precision = 253

    def coerce_str(self, cls, key):
        """Return a binary column type sized by the property's 'bytes' hint.

        A missing or zero hint means unbounded (LONGBLOB).
        """
        nbytes = int(getattr(cls, key).hints.get('bytes', '0'))
        if nbytes:
            # MySQL VARBINARY/BLOBs will do case-sensitive comparisons.
            # They also won't truncate trailing spaces like VARCHAR does.
            if nbytes <= 255:
                return "VARBINARY(%s)" % nbytes
            elif nbytes < 2 ** 16:
                return "BLOB"
            elif nbytes < 2 ** 24:
                return "MEDIUMBLOB"
        return "LONGBLOB"

    def coerce_bool(self, cls, key):
        # We could use BOOLEAN, but it wasn't introduced until 4.1.0.
        return "BOOL"

    def coerce_datetime_datetime(self, cls, key):
        return "DATETIME"

    def coerce_int(self, cls, key):
        """Return an integer column type sized by the 'bytes' hint (default 4).

        Identifier columns of classes using dejavu.UnitSequencerInteger get
        an " AUTO_INCREMENT" suffix.
        """
        prop = getattr(cls, key)
        nbytes = int(prop.hints.get('bytes', '4'))
        if nbytes == 1:
            # TINYINT(1) is exactly what BOOLEAN aliases, but works before
            # 4.1.0, and (unlike an early return) still gets the
            # AUTO_INCREMENT suffix below.
            typename = "TINYINT(1)"
        elif nbytes == 2:
            typename = "SMALLINT"
        elif nbytes == 3:
            typename = "MEDIUMINT"
        elif nbytes > 4:
            # INTEGER is only 4 bytes wide; larger hints need BIGINT.
            typename = "BIGINT"
        else:
            typename = "INTEGER"
        if isinstance(cls.sequencer, dejavu.UnitSequencerInteger):
            if key in cls.identifiers:
                typename += " AUTO_INCREMENT"
        return typename
140
141
142
class StorageManagerMySQL(db.StorageManagerDB):
    """StoreManager to save and retrieve Units via _mysql."""

    sql_name_max_length = 64
    # MySQL uses case-sensitive database and table names on Unix, but
    # not on Windows. Use all-lowercase identifiers to work around the
    # problem. "Column names, index names, and column aliases are not
    # case sensitive on any platform."
    # If deployers set lower_case_table_names to 1, it would help.
    sql_name_caseless = True

    typeAdapter = FieldTypeAdapterMySQL()
    toAdapter = AdapterToMySQL()
    fromAdapter = AdapterFromMySQL()

    def __init__(self, name, arena, allOptions=None):
        """Build the manager from *allOptions* and probe the server version.

        Options that are _mysql.connect() keywords are copied into
        self.connargs. The 'db' option is required (KeyError otherwise) and
        becomes self.dbname. The server version is read once, over a
        temporary connection to the 'mysql' database, to choose the
        decompiler class.
        """
        if allOptions is None:
            allOptions = {}
        db.StorageManagerDB.__init__(self, name, arena, allOptions)

        # Keyword arguments accepted by _mysql.connect(). "cursorclass" is
        # a MySQLdb.connect() option only, so it is not passed through.
        connargs = ["host", "user", "passwd", "db", "port", "unix_socket",
                    "conv", "connect_timeout", "compress", "named_pipe",
                    "init_command", "read_default_file", "read_default_group",
                    "client_flag",
                    ]
        self.connargs = dict([(k, v) for k, v in allOptions.iteritems()
                              if k in connargs])
        # "connect_time" used to be listed here, but _mysql.connect() only
        # knows "connect_timeout"; keep honoring the old name as an alias.
        if ("connect_time" in allOptions and
                "connect_timeout" not in self.connargs):
            self.connargs["connect_timeout"] = allOptions["connect_time"]
        self.dbname = self.connargs['db']

        self.decompiler = MySQLDecompiler
        # Get the version string from MySQL, to see if we need
        # a different decompiler.
        conn = self._template_conn()
        try:
            rowdata, _ = self.fetch("SELECT version();", conn)
        finally:
            conn.close()
        self._version = storage.Version(rowdata[0][0])
        # Only servers *before* 4.1.1 can use UPPER()/LOWER() on BINARY
        # comparisons, so 4.1.1 itself needs the CONVERT()-based decompiler.
        if self._version >= storage.Version("4.1.1"):
            self.decompiler = MySQLDecompiler411

    def sql_name(self, name, quoted=True):
        """Return the SQL name for *name*, backtick-quoted if *quoted*."""
        name = db.StorageManagerDB.sql_name(self, name, quoted)
        if quoted:
            name = '`' + name.replace('`', '``') + '`'
        return name

    def _quoted_dbname(self):
        """Return self.dbname, unaltered, as a backtick-quoted identifier."""
        return '`' + self.dbname.replace('`', '``') + '`'

    def _get_conn(self):
        """Open a new connection using self.connargs.

        Raises db.OutOfConnectionsError when the server reports error 1040
        ("Too many connections"); other errors propagate unchanged.
        """
        import sys
        try:
            conn = _mysql.connect(**self.connargs)
        except _mysql.OperationalError:
            # sys.exc_info() keeps this valid on Python 2.3 through 3.x.
            if sys.exc_info()[1].args[0] == 1040:   # Too many connections
                raise db.OutOfConnectionsError
            raise
        return conn

    def _template_conn(self):
        """Open a connection like _get_conn, but to the 'mysql' database."""
        tmplconn = self.connargs.copy()
        tmplconn['db'] = 'mysql'
        return _mysql.connect(**tmplconn)

    def _execute_on_template(self, sql):
        """Run *sql* over a temporary 'mysql'-database connection.

        The connection is closed even if the statement fails.
        """
        conn = self._template_conn()
        try:
            self.execute(sql, conn)
        finally:
            conn.close()

    def fetch(self, query, conn=None):
        """fetch(query, conn=None) -> rowdata, columns.
        
        rowdata: a nested list (or tuples), column values within rows.
        columns: a series of 2-tuples (or more). The first tuple value
            will be the column name, the second value will be the column
            type.
        """
        if conn is None:
            conn = self.connection()
        self.execute(query, conn)
        # store_result uses a client-side cursor
        res = conn.store_result()
        return res.fetch_row(0, 0), res.describe()

    def destroy(self, unit):
        """destroy(unit). Delete the unit."""
        self.execute('DELETE FROM %s WHERE %s;' %
                     (self.table_name(unit.__class__.__name__),
                      self.id_clause(unit)))

    def version(self):
        """Return a human-readable server version string."""
        return "MySQL Version: %s" % self._version

    def _seq_UnitSequencerInteger(self, unit):
        """Reserve a unit using the table's AUTO_INCREMENT field."""
        cls = unit.__class__
        clsname = cls.__name__
        tablename = self.table_name(clsname)

        fields = []
        values = []
        for key in cls.properties:
            typename = self.typeAdapter.coerce(cls, key)
            if typename.endswith("AUTO_INCREMENT"):
                # Skip this field, since we're using AUTO_INCREMENT
                continue
            val = self.toAdapter.coerce(getattr(unit, key))
            fields.append(self.column_name(clsname, key))
            values.append(val)

        fields = ", ".join(fields)
        values = ", ".join(values)
        self.execute('INSERT INTO %s (%s) VALUES (%s);' %
                     (str(tablename), fields, values))

        # Grab the new ID. This is threadsafe because db.reserve has a mutex.
        data, _ = self.fetch("SELECT LAST_INSERT_ID();")
        setattr(unit, cls.identifiers[0], data[0][0])

    #                               Schemas                               #

    def create_database(self):
        """Create self.dbname (via a connection to the 'mysql' database)."""
        # _mysql has create_db and drop_db commands, but they're deprecated.
        self._execute_on_template('CREATE DATABASE %s;' %
                                  self.sql_name(self.dbname))

    def drop_database(self):
        """Drop self.dbname (via a connection to the 'mysql' database)."""
        self._execute_on_template('DROP DATABASE %s;' %
                                  self.sql_name(self.dbname))

    def _index_column(self, qname, dbtype):
        """Return the key/index column spec for column *qname* of *dbtype*.

        MySQL won't allow indexes on a BLOB/TEXT field without a specific
        prefix length, so those columns are indexed on their first 255.
        """
        if dbtype.endswith('BLOB') or dbtype == 'TEXT':
            return "%s(%s)" % (qname, 255)
        return qname

    def create_storage(self, cls):
        """Create the table, primary key and indices for *cls*."""
        clsname = cls.__name__
        tablename = self.table_name(clsname)
        typename = self.typeAdapter.coerce

        fields = []
        pk = []
        for key in cls.properties:
            qname = self.column_name(clsname, key)
            dbtype = typename(cls, key)
            fields.append('%s %s' % (qname, dbtype))
            if key in cls.identifiers:
                pk.append(self._index_column(qname, dbtype))
        pk = ", ".join(pk)
        if pk:
            pk = ", PRIMARY KEY (%s)" % pk
        self.execute('CREATE TABLE %s (%s%s);'
                     % (tablename, ", ".join(fields), pk))

        hasdummy = False
        if isinstance(cls.sequencer, dejavu.UnitSequencerInteger):
            initial = cls.sequencer.initial
            if initial > 1:
                # Wow, what a hack. We have to create a dummy row
                # to set the autoincrement initial value, and we
                # can't delete it until after the CREATE INDEX
                # statements below (or the counter will revert).
                colname = self.column_name(clsname, cls.identifiers[0])
                self.execute("INSERT INTO %s (%s) VALUES (%s);"
                             % (tablename, colname, initial - 1))
                hasdummy = True

        for index in cls.indices():
            iname = self.table_name("i" + clsname + index)
            col = self._index_column(self.column_name(clsname, index),
                                     typename(cls, index))
            self.execute('CREATE INDEX %s ON %s (%s);' %
                         (iname, tablename, col))

        if hasdummy:
            self.execute("DELETE FROM %s" % tablename)

    def rename_property(self, cls, oldname, newname):
        """Rename the column for *oldname* to the column for *newname*."""
        clsname = cls.__name__
        oldcolname = self.column_name(clsname, oldname)
        newcolname = self.column_name(clsname, newname)
        if oldcolname != newcolname:
            self.execute("ALTER TABLE %s CHANGE %s %s %s;" %
                         (self.table_name(clsname), oldcolname, newcolname,
                          self.typeAdapter.coerce(cls, newname)))

    def drop_index(self, cls, name):
        # MySQL might rename multiple-column indices to "PRIMARY"
        # NOTE(review): *name* is ignored -- this drops EVERY index on the
        # table (PRIMARY included), not just the index for property *name*.
        # Confirm that is intended before relying on it.
        clsname = cls.__name__
        names = []
        for i in self.get_indices(self.table_name(clsname, quoted=False)):
            if i.name not in names:
                names.append(i.name)
        for n in names:
            self.execute('DROP INDEX %s ON %s;' %
                         (self.sql_name(n), self.table_name(clsname)))

    def get_tables(self, conn=None):
        """Return a db.Table for each table in self.dbname."""
        data, _ = self.fetch("SHOW TABLES FROM %s" % self._quoted_dbname(),
                             conn=conn)
        return [db.Table(row[0]) for row in data]

    def get_columns(self, tablename=None, conn=None):
        """Return db.Column objects describing *tablename*'s columns."""
        # cols are: Field, Type, Null, Key, Default, Extra.
        # See http://dev.mysql.com/doc/refman/4.1/en/describe.html
        data, _ = self.fetch("SHOW COLUMNS FROM %s.%s"
                             % (self._quoted_dbname(),
                                self.sql_name(tablename)),
                             conn=conn)
        cols = []
        for row in data:
            c = db.Column(row[0], None, row[4])

            # Type looks like "varchar(255)", "int(10) unsigned" or
            # "double unsigned": keep the parenthesized size as the 'bytes'
            # hint and reduce the rest to the bare base type name.
            dbtype = row[1]
            parenpos = dbtype.find("(")
            if parenpos > -1:
                closepos = dbtype.rfind(")")
                c.hints['bytes'] = dbtype[parenpos+1:closepos]
                dbtype = dbtype[:parenpos]
            else:
                dbtype = dbtype.split(" ")[0]

            if dbtype in ('tinyint', 'smallint', 'mediumint', 'int', 'integer'):
                c.type = int
            elif dbtype == 'bigint':
                c.type = long
            elif dbtype in ('float', 'double', 'real'):
                c.type = float
            elif dbtype in ('decimal', 'numeric'):
                c.type = decimal.Decimal
            elif dbtype == 'date':
                c.type = datetime.date
            elif dbtype in ('datetime', 'timestamp'):
                c.type = datetime.datetime
            elif dbtype == 'time':
                c.type = datetime.time
            elif dbtype in ('char', 'varchar', 'binary', 'varbinary',
                            'tinyblob', 'tinytext', 'blob', 'text',
                            'mediumblob', 'mediumtext',
                            'longblob', 'longtext'):
                c.type = str
            cols.append(c)
        return cols

    def get_indices(self, tablename, conn=None):
        """Return db.Index objects for *tablename* ([] if it doesn't exist)."""
        import sys
        indices = []
        try:
            # cols are: Table, Non_unique, Key_name, Seq_in_index, Column_name,
            # Collation, Cardinality, Sub_part, Packed, Null, Index_type, Comment
            data, _ = self.fetch("SHOW INDEX FROM %s.%s"
                                 % (self._quoted_dbname(),
                                    self.sql_name(tablename)),
                                 conn=conn)
        except _mysql.ProgrammingError:
            # 1146: table doesn't exist, so it has no indices.
            if sys.exc_info()[1].args[0] != 1146:
                raise
        else:
            for row in data:
                # Without a conv map _mysql returns Non_unique as '0'/'1'
                # (and "not '0'" is False), so convert before negating.
                indices.append(db.Index(row[2], row[0], row[4], None,
                                        not int(row[1])))
        return indices
398
Note: See TracBrowser for help on using the browser.