Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

root/trunk/geniusql/providers/mysql.py

Revision 7 (checked in by fumanchu, 6 years ago)

Various bugfixes related to Table.created.

  • Property svn:eol-style set to native
Line 
1 """
2 Uses the MySQLdb package at:
3 http://sourceforge.net/projects/mysql-python
4
5 From the MySQL manual:
6
7 "If the server SQL mode has ANSI_QUOTES enabled, string literals can be
8 quoted only with single quotes. A string quoted with double quotes will be
9 interpreted as an identifier."
10
11 So use single quotes throughout.
12 """
13
14 # Use _mysql directly to avoid all of the DB-API overhead.
15 import _mysql
16 import datetime
17
18 import geniusql
19 from geniusql import errors, typerefs, providers
20
21
class AdapterToMySQL(geniusql.AdapterToSQL):
    """Render Python values as MySQL SQL literals."""
    
    # (find, replace) pairs for string literals and for LIKE patterns;
    # presumably applied by the geniusql base adapter.
    escapes = [("'", "''"), ("\\", r"\\")]
    like_escapes = [("%", r"\%"), ("_", r"\_")]
    
    # TRUE and FALSE only work with 4.1 or better.
    bool_true = "1"
    bool_false = "0"
    
    def coerce_str_to_any(self, value, skip_encoding=False):
        """Return value as a single-quoted SQL string literal.
        
        Values that are not already byte strings (str) are first encoded
        with self.encoding, unless skip_encoding is true. Escaping is done
        by _mysql.escape_string.
        """
        must_encode = not (skip_encoding or isinstance(value, str))
        if must_encode:
            value = value.encode(self.encoding)
        return "'%s'" % _mysql.escape_string(value)
    
    def coerce_bool_to_any(self, value):
        """Return '1' for a true value and '0' for a false one."""
        # TRUE and FALSE only work with 4.1 or better.
        if not value:
            return '0'
        return '1'
41
42
class AdapterFromMySQL(geniusql.AdapterFromDB):
    """Convert values fetched from MySQL into Python objects."""
    
    def coerce_any_to_bool(self, value):
        """Return value as a bool.
        
        String values are treated as the literals '0'/'1': only '1' is
        True (plain bool() would make '0' True). Anything else is passed
        through bool().
        """
        if not isinstance(value, basestring):
            return bool(value)
        return bool(value == '1')
50
51
class MySQLDecompiler(geniusql.SQLDecompiler):
    """SQL decompiler using MySQL spellings of dejavu functions."""
    
    def dejavu_today(self):
        """Return the SQL for today's date (MySQL's CURDATE())."""
        today_sql = "CURDATE()"
        return today_sql
56
57
class MySQLDecompiler411(MySQLDecompiler):
    """Decompiler for case-insensitive operators on MySQL 4.1.1 and newer.
    
    Before MySQL 4.1.1, BINARY comparisons could use UPPER() or LOWER()
    to perform case-insensitive comparisons. Newer versions must instead
    CONVERT() the binary value into a character set such as utf8, so the
    comparison follows that charset's collation rather than raw bytes
    (this relies on utf8's default collation being case-insensitive).
    """
    
    def dejavu_icontainedby(self, op1, op2):
        """Case-insensitive containment: substring search or IN list."""
        if not isinstance(op1, geniusql.ConstWrapper):
            # Looking for field in (a, b, c).
            coerce = self.adapter.coerce
            choices = ", ".join([coerce(item) for item in op2.basevalue])
            return "CONVERT(%s USING utf8) IN (%s)" % (op1, choices)
        # Looking for text in a field. Use LIKE (reverse terms).
        left = "CONVERT(" + op2 + " USING utf8) LIKE '%"
        return left + self.adapter.escape_like(op1) + "%'"
    
    def dejavu_istartswith(self, x, y):
        """Case-insensitive 'x starts with y' via LIKE 'y%'."""
        left = "CONVERT(" + x + " USING utf8) LIKE '"
        return left + self.adapter.escape_like(y) + "%'"
    
    def dejavu_iendswith(self, x, y):
        """Case-insensitive 'x ends with y' via LIKE '%y'."""
        left = "CONVERT(" + x + " USING utf8) LIKE '%"
        return left + self.adapter.escape_like(y) + "'"
    
    def dejavu_ieq(self, x, y):
        """Case-insensitive equality; only x is converted."""
        left = "CONVERT(" + x + " USING utf8)"
        return left + " = " + y
84
85
class TypeAdapterMySQL(geniusql.TypeAdapter):
    """Choose MySQL column types for geniusql columns."""
    
    numeric_max_precision = 16
    numeric_max_bytes = 8
    
    def float_type(self, precision):
        """Return a datatype which can handle the given precision."""
        # "p represents the precision in bits, but MySQL uses this value
        # only to determine whether to use FLOAT or DOUBLE for the
        # resulting data type. If p is from 0 to 24, the data type
        # becomes FLOAT with no M or D values. If p is from 25 to 53,
        # the data type becomes DOUBLE with no M or D values."
        return "FLOAT(%s)" % precision
    
    def coerce_str(self, col):
        """Return a binary string type sized by col.hints['bytes'].
        
        The hint defaults to 255. A hint of 0, or one too large for
        MEDIUMBLOB, yields LONGBLOB.
        """
        size = int(col.hints.get('bytes', 255))
        
        if size:
            # MySQL VARBINARY/BLOBs will do case-sensitive comparisons.
            # They also won't truncate trailing spaces like VARCHAR does.
            if size <= 255:
                return "VARBINARY(%s)" % size
            elif size < 2 ** 16:
                return "BLOB"
            elif size < 2 ** 24:
                return "MEDIUMBLOB"
        return "LONGBLOB"
    
    def coerce_bool(self, col):
        """Return the boolean column type."""
        # We could use BOOLEAN, but it wasn't introduced until 4.1.0.
        return "BOOL"
    
    def coerce_datetime_datetime(self, col):
        """Return the column type for datetime.datetime values."""
        return "DATETIME"
    
    def coerce_int(self, col):
        """Return an integer type wide enough for col.hints['bytes'].
        
        The hint defaults to 4. It maps to SMALLINT for up to 2 bytes,
        MEDIUMINT for 3, INTEGER for 4, and BIGINT for anything wider
        (numeric_max_bytes allows up to 8).
        """
        size = int(col.hints.get('bytes', '4'))
        if size <= 2:
            return "SMALLINT"
        elif size == 3:
            return "MEDIUMINT"
        elif size <= 4:
            return "INTEGER"
        # INTEGER only holds 4 bytes; wider hints previously got INTEGER
        # too and could overflow. BIGINT round-trips to long via
        # MySQLDatabase.python_type.
        return "BIGINT"
128
129
class TypeAdapterMySQL41(TypeAdapterMySQL):
    """Type adapter for MySQL 4.1+, which accepts a length on BLOB."""
    
    def coerce_str(self, col):
        """Like TypeAdapterMySQL.coerce_str, but size plain BLOBs.
        
        A bare "BLOB" result becomes "BLOB(<bytes hint>)". The parent only
        returns "BLOB" when a 'bytes' hint above 255 was supplied.
        """
        base_type = TypeAdapterMySQL.coerce_str(self, col)
        if base_type != "BLOB":
            return base_type
        return "BLOB(%s)" % col.hints['bytes']
137
138
class MySQLIndexSet(geniusql.IndexSet):
    """IndexSet whose deletions drop the matching MySQL index(es)."""
    
    def __delitem__(self, key):
        """Drop every index on this table whose column is self[key].colname.
        
        Indexes are matched by column rather than by name because MySQL
        might rename multiple-column indices to "PRIMARY".
        
        Raises KeyError for an unknown key, before any database access.
        """
        t = self.table
        # Resolve the key before locking. An unknown key then fails fast,
        # even when the table has no indices, and the lookup isn't
        # repeated for every index row.
        colname = self[key].colname
        t.db.lock("Dropping index. Transactions not allowed.")
        try:
            for i in t.db._get_indices(t.name):
                if i.colname == colname:
                    t.db.execute('DROP INDEX %s ON %s;' % (i.qname, t.qname))
        finally:
            t.db.unlock()
        # NOTE(review): this override never removes key from the mapping
        # itself (compare the dict.__setitem__ call in
        # MySQLDatabase.__setitem__). Confirm whether geniusql.IndexSet
        # expects subclasses to do so.
151
152
class MySQLTable(geniusql.Table):
    """geniusql Table with MySQL-specific index handling and renames."""
    
    indexsetclass = MySQLIndexSet
    
    def _rename(self, oldcol, newcol):
        """Rename column oldcol to newcol.qname with ALTER TABLE ... CHANGE.
        
        MySQL's CHANGE clause replaces the entire column definition, so any
        attribute not restated is lost. The type, DEFAULT and AUTO_INCREMENT
        of oldcol are therefore restated for the renamed column.
        """
        dbtype = oldcol.dbtype
        definition = "%s %s" % (newcol.qname, dbtype)
        if oldcol.default is not None:
            default = self.db.adaptertosql.coerce(oldcol.default, dbtype)
            definition += " DEFAULT %s" % default
        if oldcol.autoincrement:
            definition += " AUTO_INCREMENT"
        self.db.execute("ALTER TABLE %s CHANGE %s %s;" %
                        (self.qname, oldcol.qname, definition))
161
162
# Keyword arguments that MySQLDatabase.__init__ keeps in self.connargs;
# those are passed straight to _mysql.connect().
connargs = [
    "host",
    "user",
    "passwd",
    "db",
    "port",
    "unix_socket",
    "conv",
    "connect_time",
    "compress",
    "named_pipe",
    "init_command",
    "read_default_file",
    "read_default_group",
    "cursorclass",
    "client_flag",
]
168
169 class MySQLDatabase(geniusql.Database):
170    
171     sql_name_max_length = 64
172     # MySQL uses case-sensitive database and table names on Unix, but
173     # not on Windows. Use all-lowercase identifiers to work around the
174     # problem. "Column names, index names, and column aliases are not
175     # case sensitive on any platform."
176     # If deployers set lower_case_table_names to 1, it would help.
177     sql_name_caseless = True
178     encoding = "utf8"
179    
180     adaptertosql = AdapterToMySQL()
181     adapterfromdb = AdapterFromMySQL()
182     typeadapter = TypeAdapterMySQL()
183    
184     tableclass = MySQLTable
185     indexsetclass = MySQLIndexSet
186    
187     # InnoDB default
188     default_isolation = "REPEATABLE READ"
189    
190     def __init__(self, name, **kwargs):
191         geniusql.Database.__init__(self, name, **kwargs)
192        
193         self.connargs = dict([(k, v) for k, v in kwargs.iteritems()
194                               if k in connargs])
195        
196         self.decompiler = MySQLDecompiler
197        
198         # Get the version string from MySQL, to see if we need
199         # a different decompiler.
200         conn = self._template_conn()
201         rowdata, cols = self.fetch("SELECT version();", conn)
202         conn.close()
203         v = rowdata[0][0]
204         self._version = providers.Version(v)
205        
206         # decompiler
207         if self._version > providers.Version("4.1.1"):
208             self.decompiler = MySQLDecompiler411
209        
210         # type adapter
211         if self._version >= providers.Version("4.1"):
212             self.typeadapter = TypeAdapterMySQL41()
213    
214     def version(self):
215         return "MySQL Version: %s" % self._version
216    
217     def columnclause(self, column):
218         """Return a clause for the given column for CREATE or ALTER TABLE.
219         
220         This will be of the form "name type [DEFAULT x] [AUTO_INCREMENT]"
221         """
222         dbtype = column.dbtype
223        
224         autoincr = ""
225         if column.autoincrement:
226             autoincr = " AUTO_INCREMENT"
227        
228         default = column.default or ""
229         if default:
230             default = self.adaptertosql.coerce(default, dbtype)
231             default = " DEFAULT %s" % default
232        
233         return "%s %s%s%s" % (column.qname, dbtype, default, autoincr)
234    
235     def __setitem__(self, key, table):
236         q = self.quote
237         if key in self:
238             del self[key]
239        
240         # Set table.created to True, which should "turn on"
241         # any future ALTER TABLE statements.
242         table.created = True
243        
244         fields = []
245         incr_fields = []
246         pk = []
247         for colkey, col in table.iteritems():
248             fields.append(self.columnclause(col))
249            
250             if col.autoincrement:
251                 # INSERT INTO t (c) VALUES(0) doesn't work for some reason
252                 if col.initial > 1:
253                     incr_fields.append(col)
254            
255             if col.key:
256                 qname = col.qname
257                 dbtype = col.dbtype
258                 if dbtype.endswith('BLOB') or dbtype == 'TEXT':
259                     # MySQL won't allow indexes on a BLOB field without a
260                     # specific index prefix length. We choose 255 just for fun.
261                     qname = "%s(255)" % qname
262                 pk.append(qname)
263        
264         if pk:
265             pk = ", PRIMARY KEY (%s)" % ", ".join(pk)
266         else:
267             pk = ""
268        
269         encoding = self.encoding
270         if encoding:
271             encoding = " CHARACTER SET %s" % encoding
272        
273         self.lock("Creating storage. Transactions not allowed.")
274         try:
275             self.execute('CREATE TABLE %s (%s%s)%s;' %
276                          (table.qname, ", ".join(fields), pk, encoding))
277            
278             if incr_fields:
279                 # Wow, what a hack. We have to INSERT a dummy row to set the
280                 # autoincrement initial value(s), and we can't delete it until
281                 # after the CREATE INDEX statements (or the counter will revert).
282                 fields = ", ".join([col.qname for col in incr_fields])
283                 values = ", ".join([str(col.initial - 1) for col in incr_fields])
284                 self.execute("INSERT INTO %s (%s) VALUES (%s);"
285                              % (table.qname, fields, values))
286            
287             for k, index in table.indices.iteritems():
288                 dbtype = table[k].dbtype
289                 if dbtype.endswith('BLOB') or dbtype == 'TEXT':
290                     # MySQL won't allow indexes on a BLOB field without a
291                     # specific index prefix length. We choose 255 just for fun.
292                     self.execute('CREATE INDEX %s ON %s (%s(255));' %
293                                  (index.qname, table.qname, q(index.colname)))
294                 else:
295                     self.execute('CREATE INDEX %s ON %s (%s);' %
296                                  (index.qname, table.qname, q(index.colname)))
297            
298             if incr_fields:
299                 self.execute("DELETE FROM %s" % table.qname)
300         finally:
301             self.unlock()
302        
303         dict.__setitem__(self, key, table)
304    
305     def _get_tables(self, conn=None):
306         data, _ = self.fetch("SHOW TABLES FROM %s" % self.qname, conn=conn)
307         return [self.tableclass(row[0], self.quote(row[0]),
308                                 self, created=True)
309                 for row in data]
310    
311     def _get_table(self, tablename, conn=None):
312         data, _ = self.fetch("SHOW TABLES FROM %s LIKE '%s'"
313                              % (self.qname, tablename), conn=conn)
314         for row in data:
315             name = row[0]
316             if name == tablename:
317                 return self.tableclass(name, self.quote(name),
318                                        self, created=True)
319         raise errors.MappingError(tablename)
320    
321     def _get_columns(self, tablename, conn=None):
322         # cols are: Field, Type, Null, Key, Default, Extra.
323         # See http://dev.mysql.com/doc/refman/4.1/en/describe.html
324         data, _ = self.fetch("SHOW COLUMNS FROM %s.%s" %
325                              (self.qname, self.quote(tablename)), conn=conn)
326         cols = []
327         for row in data:
328             hints = {}
329             dbtype = row[1].upper()
330             parenpos = dbtype.find("(")
331             if parenpos > -1:
332                 args = dbtype[parenpos+1:-1]
333                 baretype = dbtype[:parenpos]
334                 if baretype in ("DECIMAL", "NUMERIC"):
335                     args = [x.strip() for x in args.split(",")]
336                     hints['precision'], hints['scale'] = args
337                 else:
338                     hints['bytes'] = args
339             elif dbtype == "FLOAT":
340                 hints['precision'] = 24
341             elif dbtype.startswith("DOUBLE"):
342                 hints['precision'] = 53
343             elif dbtype in ("TINYBLOB", "TINYTEXT"):
344                 hints['bytes'] = (2 ** 8) - 1
345             elif dbtype in ("BLOB", "TEXT"):
346                 hints['bytes'] = (2 ** 16) - 1
347             elif dbtype in ("MEDIUMBLOB", "MEDIUMTEXT"):
348                 hints['bytes'] = (2 ** 24) - 1
349             elif dbtype in ("LONGBLOB", "LONGTEXT"):
350                 hints['bytes'] = (2 ** 32) - 1
351            
352             key = (row[3] == "PRI")
353             pytype = self.python_type(dbtype)
354            
355             col = geniusql.Column(pytype, dbtype, None, hints, key,
356                                   row[0], self.quote(row[0]))
357            
358             if row[4]:
359                 col.default = pytype(row[4])
360             if "auto_increment" in row[5].lower():
361                 col.autoincrement = True
362            
363             cols.append(col)
364         return cols
365    
366     def _get_indices(self, tablename, conn=None):
367         indices = []
368         try:
369             # cols are: Table, Non_unique, Key_name, Seq_in_index, Column_name,
370             # Collation, Cardinality, Sub_part, Packed, Null, Index_type, Comment
371             data, _ = self.fetch("SHOW INDEX FROM %s.%s"
372                                  % (self.qname, self.quote(tablename)),
373                                  conn=conn)
374         except _mysql.ProgrammingError, x:
375             if x.args[0] != 1146:
376                 raise
377         else:
378             for row in data:
379                 i = geniusql.Index(row[2], self.quote(row[2]),
380                                    row[0], row[4], not row[1])
381                 indices.append(i)
382         return indices
383    
384     def python_type(self, dbtype):
385         """Return a Python type which can store values of the given dbtype."""
386         dbtype = dbtype.upper()
387         parenpos = dbtype.find("(")
388         if parenpos > -1:
389             dbtype = dbtype[:parenpos]
390        
391         if dbtype in ('TINYINT', 'SMALLINT', 'MEDIUMINT', 'INT', 'INTEGER'):
392             return int
393         elif dbtype == 'BIGINT':
394             return long
395         elif dbtype in ('BOOL', 'BOOLEAN'):
396             return bool
397         elif dbtype in ('FLOAT', 'DOUBLE', 'DOUBLE PRECISION', 'REAL'):
398             return float
399         elif dbtype in ('DECIMAL', 'NUMERIC'):
400             if typerefs.decimal:
401                 return typerefs.decimal.Decimal
402             elif typerefs.fixedpoint:
403                 return typerefs.fixedpoint.Fixedpoint
404         elif dbtype == 'DATE':
405             return datetime.date
406         elif dbtype in ('DATETIME', 'TIMESTAMP'):
407             return datetime.datetime
408         elif dbtype == 'TIME':
409             return datetime.time
410         elif dbtype in ('CHAR', 'VARCHAR', 'BINARY', 'VARBINARY',
411                         'TINYBLOB', 'TINYTEXT', 'BLOB', 'TEXT',
412                         'MEDIUMBLOB', 'MEDIUMTEXT', 'LONGBLOB', 'LONGTEXT'):
413             return str
414        
415         raise TypeError("Database type %r could not be converted "
416                         "to a Python type." % dbtype)
417    
418     def quote(self, name):
419         """Return name, quoted for use in an SQL statement."""
420         return '`' + name.replace('`', '``') + '`'
421    
422     def _get_conn(self):
423         try:
424             conn = _mysql.connect(**self.connargs)
425         except _mysql.OperationalError, x:
426             if x.args[0] == 1040:   # Too many connections
427                 raise errors.OutOfConnectionsError
428             raise
429         return conn
430    
431     def _template_conn(self):
432         tmplconn = self.connargs.copy()
433         tmplconn['db'] = 'mysql'
434         return _mysql.connect(**tmplconn)
435    
436     def execute(self, query, conn=None):
437         """execute(query, conn=None) -> result set."""
438         if conn is None:
439             conn = self.connection()
440         if isinstance(query, unicode):
441             query = query.encode(self.adaptertosql.encoding)
442         self.log(query)
443         try:
444             return conn.query(query)
445         except _mysql.OperationalError, x:
446             if x.args[0] == 1030 and x.args[1] == 'Got error 139 from storage engine':
447                 raise ValueError("row length exceeds 8000 byte limit")
448             raise
449    
450     def fetch(self, query, conn=None):
451         """fetch(query, conn=None) -> rowdata, columns.
452         
453         rowdata: a nested list (or tuples), column values within rows.
454         columns: a series of 2-tuples (or more). The first tuple value
455             will be the column name, the second value will be the column
456             type.
457         """
458         if conn is None:
459             conn = self.connection()
460         self.execute(query, conn)
461        
462         # store_result uses a client-side cursor
463         res = conn.store_result()
464        
465         # The Python MySQLdb library swallows lock timeouts and returns []
466         # (for example, when deadlocked during a SERIALIZABLE transaction).
467         # Raise an error instead.
468         # Oddly, although the deadlock will stall the conn.query() call,
469         # the error message is only available after store_result().
470         err = conn.error()
471         if err == "Lock wait timeout exceeded; try restarting transaction":
472             raise _mysql.OperationalError(1205, err)
473        
474         if res is None:
475             return [], []
476         return res.fetch_row(0, 0), res.describe()
477    
478     def _grab_new_ids(self, table, idkeys, conn):
479         return {idkeys[0]: conn.insert_id()}
480    
481     def create_database(self):
482         self.lock("Creating database. Transactions not allowed.")
483         try:
484             # _mysql has create_db and drop_db commands, but they're deprecated.
485             encoding = self.encoding
486             if encoding:
487                 encoding = " CHARACTER SET %s" % encoding
488             sql = 'CREATE DATABASE %s%s;' % (self.qname, encoding)
489             conn = self._template_conn()
490             self.execute(sql, conn)
491             conn.close()
492             self.clear()
493         finally:
494             self.unlock()
495    
496     def drop_database(self):
497         self.lock("Dropping database. Transactions not allowed.")
498         try:
499             sql = 'DROP DATABASE %s;' % self.qname
500             conn = self._template_conn()
501             self.execute(sql, conn)
502             conn.close()
503             self.clear()
504         finally:
505             self.unlock()
506    
507     def is_lock_error(self, exc):
508         # OperationalError: (1205, 'Lock wait timeout exceeded; try restarting transaction')
509         if not isinstance(exc, _mysql.OperationalError):
510             return False
511         return exc.args[0] == 1205
512
Note: See TracBrowser for help on using the browser.