Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

root/trunk/geniusql/providers/ado.py

Revision 47 (checked in by fumanchu, 6 years ago)

Fixed SQL Server binary ops.

  • Property svn:eol-style set to native
Line 
1 import sys
2 # Put COM in free-threaded mode. This first thread will have
3 # CoInitializeEx called automatically when pythoncom is imported.
4 sys.coinit_flags = 0
5
6 import pythoncom
7 Empty = pythoncom.Empty
8 clsctx = pythoncom.CLSCTX_SERVER
9
10 import win32com.client
11
# InvokeTypes args (always pass as *args)
# Each tuple is a ready-made positional argument list for a COM object's
# _oleobj_.InvokeTypes(...) call, i.e. it is splatted as
# obj._oleobj_.InvokeTypes(*BOF).
# NOTE(review): the layout appears to be (dispid, lcid, wFlags, return type,
# argument types, *call arguments); the property reads below carry 2 in the
# third slot and the method calls (GetRows, Close) carry 1 -- confirm
# against the pywin32 InvokeTypes documentation.
BOF = (1002, 0, 2, (11, 0), ())
EOF = (1006, 0, 2, (11, 0), ())
Recordset_Fields = (0, 0, 2, (9, 0), ())
# This assumes no arguments passed to GetRows
# (the trailing -1, Empty, Empty are the values supplied in their place).
Recordset_GetRows = (1016, 0, 1, (12, 0), ((3, 49), (12, 17), (12, 17)), -1, Empty, Empty)
Recordset_Close = (1014, 0, 1, (24, 0), (),)
Fields_Count = (1, 0, 2, (3, 0), ())
Field_Name = (1100, 0, 2, (8, 0), ())
Field_Type = (1102, 0, 2, (3, 0), ())
Field_Properties = (500, 0, 2, (9, 0), ())
Property_Value = (0, 0, 2, (12, 0), ())
24
25 import pywintypes
26 import datetime
27
28 import time
29
30 try:
31     import cPickle as pickle
32 except ImportError:
33     import pickle
34
35 import threading
36 import warnings
37
38
39 import geniusql
40 from geniusql import adapters, conns, decompile, errors, select, typerefs
41 from geniusql import isolation as _isolation
42
# Cursor types (names follow ADO's CursorTypeEnum).
adOpenForwardOnly = 0
adOpenKeyset = 1
adOpenDynamic = 2
adOpenStatic = 3

# Record locking modes (names follow ADO's LockTypeEnum).
adLockReadOnly = 1
adLockPessimistic = 2
adLockOptimistic = 3
adLockBatchOptimistic = 4

# Schema rowset selectors. These are what the schema-discovery code hands
# to Database.fetch(..., schema=True), which forwards them to OpenSchema.
adSchemaColumns = 4
adSchemaIndexes = 12
adSchemaTables = 20
adSchemaPrimaryKeys = 28

# Cursor location (name follows ADO's CursorLocationEnum).
adUseClient = 3
59
# Map of numeric ADO data type codes to their type names. The schema
# discovery code looks up the DATA_TYPE value of each adSchemaColumns row
# here; a code that is not listed raises KeyError at that lookup.
# The trailing comments give the native SQL Server / MS Access column types
# that surface as each ADO type.
# See http://www.carlprothman.net/Technology/DataTypeMapping/tabid/97/Default.aspx
dbtypes = {# ADO Name           SQL Server              MS Access
        0: 'EMPTY',
        2: 'SMALLINT',        # SMALLINT                INTEGER
        3: 'INTEGER',         # IDENTITY (6.5), INT     AUTONUMBER, LONG
        4: 'SINGLE',          # REAL                    SINGLE
        5: 'DOUBLE',          # FLOAT                   DOUBLE
        6: 'CURRENCY',        # MONEY, SMALLMONEY       CURRENCY
        7: 'DATE',            #                         DATETIME (Access 97)
        8: 'BSTR', 9: 'IDISPATCH', 10: 'ERROR',
        11: 'BOOLEAN',        # BIT                     YESNO
        12: 'VARIANT',        # SQL_VARIANT (2000 +)
        13: 'IUNKNOWN', 14: 'DECIMAL', 16: 'TINYINT',
        17: 'UNSIGNEDTINYINT',# TINYINT                 BYTE
        18: 'UNSIGNEDSMALLINT', 19: 'UNSIGNEDINT',
        20: 'BIGINT',         # BIGINT
        21: 'UNSIGNEDBIGINT',
        72: 'GUID',
        128: 'BINARY',        # BINARY, TIMESTAMP
        129: 'CHAR',          # CHAR
        130: 'WCHAR',         # NCHAR (7.0+)
        131: 'NUMERIC',       # DECIMAL, NUMERIC        DECIMAL (Access 2000)
        132: 'USERDEFINED',
        133: 'DBDATE', 134: 'DBTIME',
        135: 'DBTIMESTAMP',   # DATETIME, SMALLDATETIME   DATETIME (ODBC 97)
        200: 'VARCHAR',       # VARCHAR                 TEXT (Access 97)
        201: 'LONGVARCHAR',   # TEXT                    MEMO (Access 97)
        202: 'VARWCHAR',      # NVARCHAR                TEXT (Access 2000)
        203: 'LONGVARWCHAR',  # NTEXT (7.0+)            MEMO (Access 2000+)
        204: 'VARBINARY',     # VARBINARY
        205: 'LONGVARBINARY', # IMAGE                   OLEOBJECT
}
92
# Bit masks for the COLUMN_FLAGS value of an adSchemaColumns row.
# In the schema discovery code visible in this file only ISFIXEDLENGTH and
# WRITE are tested (a fixed-length, non-writable column is treated as
# autoincrement); the remaining masks are defined for completeness.
DBCOLUMNFLAGS_WRITE = 0x4
DBCOLUMNFLAGS_WRITEUNKNOWN = 0x8
DBCOLUMNFLAGS_ISFIXEDLENGTH = 0x10
DBCOLUMNFLAGS_ISNULLABLE = 0x20
DBCOLUMNFLAGS_MAYBENULL = 0x40
DBCOLUMNFLAGS_ISLONG = 0x80
DBCOLUMNFLAGS_ISROWID = 0x100
DBCOLUMNFLAGS_ISROWVER = 0x200
DBCOLUMNFLAGS_CACHEDEFERRED = 0x1000
102
103
104 class AdapterFromADO(adapters.AdapterFromDB):
105     """Coerce incoming values from ADO to Python datatypes."""
106    
107     encoding = 'ISO-8859-1'
108     epoch = datetime.datetime(1899, 12, 30)
109    
110     def timedelta_from_com(self, com_date):
111         """Return a valid datetime.timedelta from a COM date/time object."""
112         com_date = float(com_date)
113        
114         # MS Access represents dates and times as floats. If the value is
115         # before the epoch (12/30/1899), the seconds will be SUBTRACTED
116         # from the float. For example, -2.01 is in the morning and -2.99
117         # is in the evening of the same day. Therefore, when we split off
118         # our seconds we must use the abs value of the fractional portion.
119         neg = (com_date < 0)
120         com_date = abs(com_date)
121        
122         days = int(com_date)
123         # Must do both int() and round() or we'll be up to 1 second off.
124         secs = int(round(86400 * (com_date - days)))
125        
126         result = datetime.timedelta(days, secs)
127         if neg:
128             return -result
129         else:
130             return result
131    
132     def coerce_any_to_datetime_timedelta(self, value):
133         # Assume pywintypes.TimeType
134         return self.timedelta_from_com(value)
135    
136     def coerce_any_to_datetime_time(self, value):
137         t = self.timedelta_from_com(value)
138         if t.days:
139             raise ValueError("Time values greater than 23:59:59 not allowed.")
140         h, m = divmod(t.seconds, 3600)
141         m, s = divmod(m, 60)
142         return datetime.time(int(h), int(m), int(s))
143    
144     def datetime_from_com(self, com_date):
145         """Return a valid datetime.datetime from a COM date/time object."""
146         com_date = float(com_date)
147        
148         # MS Access represents dates and times as floats. If the value is
149         # before the epoch (12/30/1899), the seconds will be SUBTRACTED
150         # from the float. For example, -2.01 is in the morning and -2.99
151         # is in the evening of the same day. Therefore, when we split off
152         # our seconds we must use the abs value of the fractional portion.
153         # Note that we do this differently from timedelta_from_com,
154         # because there we need to subtract seconds, and here we add them.
155         days = int(com_date)
156        
157         # Must do both int() and round() or we'll be up to 1 second off.
158         secs = int(round(86400 * abs(com_date - days)))
159        
160         return self.epoch + datetime.timedelta(days, secs)
161    
162     def coerce_any_to_datetime_datetime(self, value):
163         if isinstance(value, basestring):
164             if value:
165                 try:
166                     return datetime.datetime(int(value[0:4]), int(value[4:6]),
167                                              int(value[6:8]))
168                 except Exception:
169                     raise ValueError("'%s' %s" % (value, type(value)))
170             else:
171                 return None
172         else:
173             # Illegal Date/Time values will crash the app when using
174             # value.Format(). Therefore, grab the float value and figure
175             # the date ourselves. Use 1-second resolution only.
176             return self.datetime_from_com(value)
177    
178     def coerce_any_to_datetime_date(self, value):
179         if isinstance(value, basestring):
180             if value:
181                 try:
182                     return datetime.date(int(value[0:4]), int(value[4:6]),
183                                          int(value[6:8]))
184                 except Exception:
185                     raise ValueError("'%s' %s" % (value, type(value)))
186             else:
187                 return None
188         else:
189             value = float(value)
190             days = int(value)
191             return self.epoch.date() + datetime.timedelta(days)
192    
193     def coerce_any_to_decimal_Decimal(self, value):
194         # pywin32 build 205 began support for returning
195         # COM Currency objects as decimal objects.
196         # See http://pywin32.cvs.sourceforge.net/pywin32/pywin32/CHANGES.txt?view=markup
197         if not isinstance(value, typerefs.decimal.Decimal):
198             value = str(value)
199             value = typerefs.decimal.Decimal(value)
200         return value
201    
202     def coerce_CURRENCY_to_float(self, value):
203         if isinstance(value, tuple):
204             # See http://groups.google.com/group/comp.lang.python/
205             #           browse_frm/thread/fed03c64735c9e9c
206             value = map(long, value)
207             return ((value[1] & 0xFFFFFFFFL) | (value[0] << 32)) / 1e4
208         return float(value)
209    
210     def coerce_CURRENCY_to_decimal_Decimal(self, value):
211         # pywin32 build 205 began support for returning
212         # COM Currency objects as decimal objects.
213         # See http://pywin32.cvs.sourceforge.net/pywin32/pywin32/CHANGES.txt?view=markup
214         if not isinstance(value, typerefs.decimal.Decimal):
215             # See http://groups.google.com/group/comp.lang.python/
216             #           browse_frm/thread/fed03c64735c9e9c
217             value = map(long, value)
218             value = (value[1] & 0xFFFFFFFFL) | (value[0] << 32)
219             return typerefs.decimal.Decimal(value) / 10000
220         return value
221    
222     def coerce_CURRENCY_to_fixedpoint_FixedPoint(self, value):
223         if isinstance(value, typerefs.decimal.Decimal):
224             value = str(value)
225             scale = 0
226             atoms = value.rsplit(".", 1)
227             if len(atoms) > 1:
228                 scale = len(atoms[-1])
229             return typerefs.fixedpoint.FixedPoint(value, scale)
230         else:
231             # See http://groups.google.com/group/comp.lang.python/
232             #           browse_frm/thread/fed03c64735c9e9c
233             value = map(long, value)
234             value = (value[1] & 0xFFFFFFFFL) | (value[0] << 32)
235             return typerefs.fixedpoint.FixedPoint(value, 4) / 1e4
236    
237     def coerce_any_to_unicode(self, value):
238         if isinstance(value, unicode):
239             # For some reason, value is already a unicode object.
240             return value
241        
242         if isinstance(value, (basestring, buffer)):
243             try:
244                 return unicode(value, self.encoding)
245             except UnicodeError, exc:
246                 exc.args += (type(value),)
247         return unicode(value)
248
249
250
class ADOSQLDecompiler(decompile.SQLDecompiler):
    """SQL decompiler tuned for ADO providers.

    String comparisons and LIKE are treated as case-insensitive, so any
    expression that relies on them marks the result as imperfect.
    """

    def _null_test(self, op, operand):
        """Return the IS [NOT] NULL expression for comparing operand to NULL.

        Raises ValueError for any comparison other than ==, is, != and
        'is not'.
        """
        if op in (2, 8):    # '==', 'is'
            return self.get_expr(operand.sql + " IS NULL", bool)
        if op in (3, 9):    # '!=', 'is not'
            return self.get_expr(operand.sql + " IS NOT NULL", bool)
        raise ValueError("Non-equality Null comparisons not allowed.")

    def visit_COMPARE_OP(self, lo, hi):
        """Pop two operands and push the SQL for their comparison."""
        op2, op1 = self.stack.pop(), self.stack.pop()
        if op1 is decompile.cannot_represent or op2 is decompile.cannot_represent:
            self.stack.append(decompile.cannot_represent)
            return

        op = lo + (hi << 8)
        if op in (6, 7):     # in, not in
            # Looking for text in a field. Use Like (reverse terms).
            # LIKE is case-insensitive in MS SQL Server (and there
            # doesn't seem to be a way around it). Use icontainedby
            # and just mark imperfect.
            value = self.builtins_icontainedby(op1, op2)
            if op == 7:
                # NOTE(review): builtins_icontainedby may hand back
                # self.adapter.false_expr; if that object is shared, this
                # assignment mutates it -- confirm.
                value.sql = "NOT " + value.sql
            self.stack.append(value)
            self.imperfect = True
        elif op1.sql == 'NULL':
            self.stack.append(self._null_test(op, op2))
        elif op2.sql == 'NULL':
            self.stack.append(self._null_test(op, op1))
        else:
            try:
                op1, op2 = self._compare_constants(op1, op2)
            except TypeError:
                self.stack.append(decompile.cannot_represent)
                self.imperfect = True
                return

            if (isinstance(op2, decompile.SQLExpression)
                and issubclass(op2.pytype, basestring)):
                atom = self._compare_strings(op1, op, op2)
                if atom is not None:
                    self.stack.append(atom)
                    return

            e = op1.sql + " " + self.sql_cmp_op[op] + " " + op2.sql
            self.stack.append(self.get_expr(e, bool))

    def _compare_strings(self, op1, op, op2):
        """Flag string comparisons as imperfect; always returns None.

        Returning None makes visit_COMPARE_OP fall back to the plain
        comparison SQL.
        """
        # ADO comparison operators for strings are case-insensitive
        # by default. Rather than determine which columns in the DB
        # might be case-sensitive, just flag them all as imperfect.
        # TODO: might be possible to cast both to varbinary, but
        # that may cause problems with unicode columns.
        self.imperfect = True

    def binary_op(self, op):
        """Pop two operands and push the SQL for 'op1 <op> op2'.

        Date, datetime and timedelta operands are routed to the DATE* /
        DATETIME* / TIMEDELTAADD builders; every other type combination is
        emitted as a parenthesized infix expression. Raises TypeError when
        a date/time combination has no builder for the operator.
        """
        op2, op1 = self.stack.pop(), self.stack.pop()
        if op1 is decompile.cannot_represent or op2 is decompile.cannot_represent:
            self.stack.append(decompile.cannot_represent)
            return

        t1, t2 = op1.pytype, op2.pytype

        newsql = None
        if t1 is datetime.date:
            if t2 is datetime.date:
                if op == "-":
                    newsql = self.DATEDIFF(op1.sql, op2.sql)
            elif t2 is datetime.timedelta:
                if op == "+":
                    newsql = self.DATEADD(op1.sql, op2.sql)
                elif op == "-":
                    newsql = self.DATESUB(op1.sql, op2.sql)
        elif t1 is datetime.datetime:
            if t2 is datetime.datetime:
                if op == "-":
                    newsql = self.DATETIMEDIFF(op1.sql, op2.sql)
            elif t2 is datetime.timedelta:
                if op == "+":
                    newsql = self.DATETIMEADD(op1.sql, op2.sql)
                elif op == "-":
                    newsql = self.DATETIMESUB(op1.sql, op2.sql)
        elif t1 is datetime.timedelta:
            if t2 is datetime.timedelta:
                newsql = self.TIMEDELTAADD(op1, op, op2)
            else:
                if op == "+":
                    # timedelta + date(time): same as date(time) + timedelta.
                    if t2 is datetime.date:
                        newsql = self.DATEADD(op2.sql, op1.sql)
                    elif t2 is datetime.datetime:
                        newsql = self.DATETIMEADD(op2.sql, op1.sql)
        else:
            newsql = "(%s %s %s)" % (op1.sql, op, op2.sql)

        if newsql is None:
            raise TypeError("unsupported operand type(s) for %s: "
                            "%r and %r" % (op, t1, t2))

        # re-use op1
        op1.pytype = self.result_type[(t1, op, t2)]
        op1.sql = newsql
        if not op1.name.startswith("expr_"):
            op1.name = "expr_%s" % op1.name
        self.stack.append(op1)

    # --------------------------- Dispatchees --------------------------- #

    def attr_startswith(self, tos, arg):
        """Return a LIKE 'arg%' expression; flagged imperfect."""
        self.imperfect = True
        return self.get_expr(tos.sql + " LIKE '" + self.adapter.escape_like(arg.sql) + "%'", bool)

    def attr_endswith(self, tos, arg):
        """Return a LIKE '%arg' expression; flagged imperfect."""
        self.imperfect = True
        return self.get_expr(tos.sql + " LIKE '%" + self.adapter.escape_like(arg.sql) + "'", bool)

    def containedby(self, op1, op2):
        """Defer to the base implementation, flagging the result imperfect."""
        self.imperfect = True
        return decompile.SQLDecompiler.containedby(self, op1, op2)

    def builtins_icontainedby(self, op1, op2):
        """Return SQL for a case-insensitive 'op1 in op2' test."""
        # LIKE is already case-insensitive in MS SQL Server;
        # so don't use LOWER().
        if op1.value is not None:
            # Looking for text in a field. Use Like (reverse terms).
            return self.get_expr(op2.sql + " LIKE '%" +
                                 self.adapter.escape_like(op1.sql)
                                 + "%'", bool)

        # Looking for field in (a, b, c)
        atoms = [self.adapter.coerce(x) for x in op2.value]
        if atoms:
            return self.get_expr("%s IN (%s)" %
                                 (op1.sql, ", ".join(atoms)), bool)
        # Nothing will match the empty list, so return none.
        return self.adapter.false_expr

    def builtins_istartswith(self, x, y):
        """Return a LIKE 'y%' expression."""
        # Like is already case-insensitive in ADO; so don't use LOWER().
        return self.get_expr(x.sql + " LIKE '" + self.adapter.escape_like(y.sql) + "%'", bool)

    def builtins_iendswith(self, x, y):
        """Return a LIKE '%y' expression."""
        # Like is already case-insensitive in ADO; so don't use LOWER().
        return self.get_expr(x.sql + " LIKE '%" + self.adapter.escape_like(y.sql) + "'", bool)

    def builtins_ieq(self, x, y):
        """Return an 'x = y' expression."""
        # = is already case-insensitive in ADO.
        return self.get_expr(x.sql + " = " + y.sql, bool)

    def func__builtin___len(self, x):
        """Return a Len(x) expression typed as int."""
        return self.get_expr("Len(" + x.sql + ")", int)
407
408
class ADOTable(geniusql.Table):
    """Table whose DDL operations are issued through ADO / ADOX."""

    def _add_column(self, column):
        """Internal function to add the column to the database."""
        coldef = self.schema.columnclause(column)
        # SQL Server doesn't use the "COLUMN" keyword with "ADD"
        self.schema.db.execute_ddl("ALTER TABLE %s ADD %s;" %
                                   (self.qname, coldef))

    def _rename(self, oldcol, newcol):
        """Rename a column of this table through an ADOX.Catalog object."""
        conn = self.schema.db.connections.get()
        try:
            cat = win32com.client.Dispatch(r'ADOX.Catalog')
            cat.ActiveConnection = conn
            cat.Tables(self.name).Columns(oldcol.name).Name = newcol.name
        finally:
            # Drop our references to the connection and the catalog.
            conn = None
            cat = None

    def drop_primary(self):
        """Remove any PRIMARY KEY for this Table."""
        db = self.schema.db

        # In each adSchemaIndexes row, [2] is TABLE_NAME, [5] is INDEX_NAME
        # and [6] is PRIMARY_KEY. Rows also carry a COLUMN_NAME, so a key
        # spanning several columns presumably repeats its INDEX_NAME once
        # per column; collect each name only once so the constraint is not
        # dropped twice.
        data, _ = db.fetch(adSchemaIndexes, schema=True)
        pknames = []
        for row in data:
            if (self.name == row[2]) and row[6] and row[5] not in pknames:
                pknames.append(row[5])

        for name in pknames:
            db.execute('ALTER TABLE %s DROP CONSTRAINT %s;'
                       % (self.qname, name))
438
439
def connatoms(connstring):
    """Split an ADO connection string into a dict of its settings.

    The string is cut on ";" and each non-blank segment on its first "=".
    Keys are upper-cased and stripped, values are stripped. Empty and
    whitespace-only segments (for example after a trailing "; ") are
    ignored. A non-blank segment without "=" raises ValueError.
    """
    atoms = {}
    for pair in connstring.split(";"):
        pair = pair.strip()
        if not pair:
            continue
        k, v = pair.split("=", 1)
        atoms[k.upper().strip()] = v.strip()
    return atoms
447
448
449 class ADOConnectionManager(conns.ConnectionManager):
450    
451     # the amount of time to try to close the db connection
452     # before raising an exception
453     shutdowntimeout = 1 # sec.
454    
455     ConnectionTimeout = None
456     CommandTimeout = None
457    
458     def _get_conn(self, master=False):
459         if master:
460             # Must shut down all connections to avoid
461             # "being accessed by other users" error.
462             self.shutdown()
463            
464             atoms = connatoms(self.Connect)
465             atoms['INITIAL CATALOG'] = "tempdb"
466             connectstr = "; ".join(["%s=%s" % (k, v)
467                                     for k, v in atoms.iteritems()])
468         else:
469             connectstr = self.Connect
470        
471         conn = win32com.client.Dispatch(r'ADODB.Connection')
472         conn.Open(connectstr)
473         if self.ConnectionTimeout is not None:
474             conn.ConnectionTimeout = self.ConnectionTimeout
475         if self.CommandTimeout is not None:
476             conn.CommandTimeout = self.CommandTimeout
477         return conn
478    
479     def _del_conn(self, conn):
480         for trial in xrange(self.shutdowntimeout * 10):
481             try:
482                 # This may raise "Operation cannot be performed
483                 # while executing asynchronously"
484                 # if a prior operation has not yet completed.
485                 conn.Close()
486                 return
487             except pywintypes.com_error, e:
488                 try:
489                     ecode = e.args[2][-1]
490                 except IndexError:
491                     ecode = None
492                 if ecode == -2146824577:
493                     # "Operation cannot be performed while executing asynchronously"
494                     # Try again...
495                     time.sleep(0.1)
496                     continue
497                 raise
498    
499     #                            Transactions                             #
500    
501     def start(self, isolation=None):
502         """Start a transaction. Not needed if self.implicit_trans is True."""
503         conn = self.get(started=True)
504         self.db.execute("BEGIN TRANSACTION;", conn)
505         self.isolate(conn, isolation)
506
507
class ADOSchema(geniusql.Schema):
    """Schema whose discovery is driven by ADO OpenSchema rowsets."""

    tableclass = ADOTable

    #                              Discovery                              #

    # Layout of an adSchemaTables row:
    # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202),
    # (u'TABLE_TYPE', 202), (u'TABLE_GUID', 72), (u'DESCRIPTION', 203),
    # (u'TABLE_PROPID', 19), (u'DATE_CREATED', 7), (u'DATE_MODIFIED', 7)]
    #
    # Layout of an adSchemaIndexes row:
    # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202),
    # (u'INDEX_CATALOG', 202), (u'INDEX_SCHEMA', 202), (u'INDEX_NAME', 202),
    # (u'PRIMARY_KEY', 11), (u'UNIQUE', 11), (u'CLUSTERED', 11), (u'TYPE', 18),
    # (u'FILL_FACTOR', 3), (u'INITIAL_SIZE', 3), (u'NULLS', 3),
    # (u'SORT_BOOKMARKS', 11), (u'AUTO_UPDATE', 11), (u'NULL_COLLATION', 3),
    # (u'ORDINAL_POSITION', 19), (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72),
    # (u'COLUMN_PROPID', 19), (u'COLLATION', 2), (u'CARDINALITY', 21),
    # (u'PAGES', 3), (u'FILTER_CONDITION', 202), (u'INTEGRATED', 11)]

    def _get_tables(self, conn=None):
        """Return a tableclass instance for every plain table in the DB."""
        rows, _ = self.db.fetch(adSchemaTables, conn=conn, schema=True)
        tables = []
        for row in rows:
            # Ignore linked and system tables
            if row[3] == "TABLE":
                tables.append(self.tableclass(str(row[2]),
                                              self.db.quote(str(row[2])),
                                              self, created=True))
        return tables

    def _get_table(self, tablename, conn=None):
        """Return the tableclass instance for the table named tablename.

        Raises errors.MappingError if no such table is listed.
        """
        rows, _ = self.db.fetch(adSchemaTables, conn=conn, schema=True)
        for row in rows:
            found = str(row[2])
            if found != tablename:
                continue
            return self.tableclass(found, self.db.quote(found),
                                   self, created=True)
        raise errors.MappingError(tablename)

    def _get_columns(self, tablename, conn=None):
        """Return a list of geniusql.Column objects for the given table."""
        db = self.db

        # For some reason, adSchemaPrimaryKeys would only return a single
        # record for a PK that had multiple columns. Use adSchemaIndexes.
        index_rows, _ = db.fetch(adSchemaIndexes, conn=conn, schema=True)
        pknames = []
        for irow in index_rows:
            if (tablename == irow[2]) and irow[6]:
                pknames.append(irow[17])

        # Layout of an adSchemaColumns row:
        # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202),
        # (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72), (u'COLUMN_PROPID', 19),
        # (u'ORDINAL_POSITION', 19), (u'COLUMN_HASDEFAULT', 11),
        # (u'COLUMN_DEFAULT', 203), (u'COLUMN_FLAGS', 19), (u'IS_NULLABLE', 11),
        # (u'DATA_TYPE', 18), (u'TYPE_GUID', 72), (u'CHARACTER_MAXIMUM_LENGTH', 19),
        # (u'CHARACTER_OCTET_LENGTH', 19), (u'NUMERIC_PRECISION', 18),
        # (u'NUMERIC_SCALE', 2), (u'DATETIME_PRECISION', 19),
        # (u'CHARACTER_SET_CATALOG', 202), (u'CHARACTER_SET_SCHEMA', 202),
        # (u'CHARACTER_SET_NAME', 202), (u'COLLATION_CATALOG', 202),
        # (u'COLLATION_SCHEMA', 202), (u'COLLATION_NAME', 202),
        # (u'DOMAIN_CATALOG', 202), (u'DOMAIN_SCHEMA', 202),
        # (u'DOMAIN_NAME', 202), (u'DESCRIPTION', 203)]
        column_rows, _ = db.fetch(adSchemaColumns, conn=conn, schema=True)

        integer_types = ("SMALLINT", "INTEGER", "TINYINT",
                         "UNSIGNEDTINYINT", "UNSIGNEDSMALLINT",
                         "UNSIGNEDINT", "BIGINT", "UNSIGNEDBIGINT")
        sized_types = ("BSTR", "VARIANT", "BINARY", "CHAR",
                       "VARCHAR", "VARBINARY", "WCHAR", "VARWCHAR")
        long_types = ("LONGVARCHAR", "LONGVARBINARY", "LONGVARWCHAR")

        python_type = db.typeadapter.python_type
        result = []
        for row in column_rows:
            # I tried passing criteria to OpenSchema, but passing None is
            # not the same as passing pythoncom.Empty (which errors).
            if row[2] != tablename:
                continue

            dbtype = dbtypes[row[11]]
            pytype = python_type(dbtype)

            default = row[8]
            if default is not None:
                numeric = issubclass(pytype, (int, long, float))
                # We may have stuck extraneous quotes in the default
                # value when using numeric defaults with MSAccess.
                if (numeric and default.startswith("'")
                    and default.endswith("'")):
                    default = default[1:-1]
                default = pytype(default)

            name = str(row[3])
            col = geniusql.Column(pytype, dbtype,
                                  default, hints={}, key=(name in pknames),
                                  name=name, qname=db.quote(name))

            # This only works for SQL Server. The MSAccessDatabase will
            # wrap this method and override autoincrement.
            colflags = int(row[9])
            fixed_length = colflags & DBCOLUMNFLAGS_ISFIXEDLENGTH
            writable = colflags & DBCOLUMNFLAGS_WRITE
            if fixed_length and not writable:
                col.autoincrement = True

            if dbtype in integer_types:
                col.hints['bytes'] = row[15]
            elif dbtype in ("SINGLE", "DOUBLE"):
                col.hints['precision'] = row[15]
                col.hints['scale'] = row[16]
            elif dbtype == "CURRENCY":
                # CURRENCY allows 15 places to the left of the decimal point,
                # and 4 places to the right.
                col.hints['precision'] = 19
                col.hints['scale'] = 4
            elif dbtype in ("DECIMAL", "NUMERIC"):
                col.hints['precision'] = row[15]
                col.hints['scale'] = row[16]
                col.dbtype = "%s(%s, %s)" % (dbtype, row[15], row[16])
            elif dbtype in sized_types:
                if row[13]:
                    # row[13] will be a float
                    size = int(row[13])
                else:
                    # I'm kinda guessing on this. If we use "MEMO" in an
                    # MSAccess CREATE statement, it comes back as "WCHAR",
                    # and seems to support over 65536 bytes.
                    size = (2 ** 31) - 1
                col.hints['bytes'] = size
                col.dbtype = "%s(%s)" % (col.dbtype, size)
            elif dbtype in long_types:
                if row[13]:
                    # row[13] will be a float
                    size = int(row[13])
                    col.hints['bytes'] = size
                    col.dbtype = "%s(%s)" % (col.dbtype, size)
                else:
                    col.hints['bytes'] = 65535

            result.append(col)
        return result

    def _get_indices(self, tablename=None, conn=None):
        """Return geniusql.Index objects, optionally for one table only."""
        rows, _ = self.db.fetch(adSchemaIndexes, conn=conn, schema=True)
        # I tried passing criteria to OpenSchema, but passing None is
        # not the same as passing pythoncom.Empty (which errors).
        return [geniusql.Index(row[5], self.db.quote(row[5]),
                               row[2], row[17], row[7])
                for row in rows
                if not (tablename and row[2] != tablename)]

    #                              Container                              #

    def _rename(self, oldtable, newtable):
        """Rename a table through an ADOX.Catalog object."""
        conn = self.db.connections.get()
        try:
            catalog = win32com.client.Dispatch(r'ADOX.Catalog')
            catalog.ActiveConnection = conn
            catalog.Tables(oldtable.name).Name = newtable.name
        finally:
            # Drop our references to the connection and the catalog.
            conn = None
            catalog = None
673
674
class ADOTypeAdapter(adapters.TypeAdapter):
    """Type adapter extended with the ADO type names.

    _reverse_types maps a database type name to the Python type used for
    it; the entries added here sit on top of a copy of the base mapping.
    """

    _reverse_types = adapters.TypeAdapter._reverse_types.copy()

    # Dates and times.
    _reverse_types["DBDATE"] = datetime.date
    _reverse_types["DBTIME"] = datetime.time
    _reverse_types["DBTIMESTAMP"] = datetime.datetime

    # Integers and booleans.
    _reverse_types.update(dict.fromkeys(
        ("UNSIGNEDTINYINT", "UNSIGNEDSMALLINT", "UNSIGNEDINT"), int))
    _reverse_types["BIT"] = bool
    _reverse_types["UNSIGNEDBIGINT"] = long

    # Byte strings.
    _reverse_types.update(dict.fromkeys(
        ("BSTR", "VARIANT", "BINARY", "LONGVARCHAR",
         "VARBINARY", "LONGVARBINARY"), str))

    # Unicode strings.
    _reverse_types.update(dict.fromkeys(
        ("WCHAR", "VARWCHAR", "LONGVARWCHAR"), unicode))

    # CURRENCY: prefer Decimal, fall back to FixedPoint if only that
    # is available.
    if typerefs.decimal:
        _reverse_types["CURRENCY"] = typerefs.decimal.Decimal
    elif typerefs.fixedpoint:
        _reverse_types["CURRENCY"] = typerefs.fixedpoint.FixedPoint
706
707
708 class ADODatabase(geniusql.Database):
709    
710     decompiler = ADOSQLDecompiler
711     adapterfromdb = AdapterFromADO()
712     typeadapter = ADOTypeAdapter()
713    
714     #                               Naming                                #
715    
716     def quote(self, name):
717         """Return name, quoted for use in an SQL statement."""
718         return '[' + name + ']'
719    
720     def execute(self, query, conn=None):
721         if conn is None:
722             conn = self.connections.get()
723         if isinstance(query, unicode):
724             query = query.encode(self.adaptertosql.encoding)
725        
726         self.log(query)
727         try:
728             bareconn = conn
729             if hasattr(conn, 'conn'):
730                 # 'conn' is a ConnectionWrapper object, which .Open
731                 # won't accept. Pass the unwrapped connection instead.
732                 # Note that we CANNOT write "conn = conn.conn", because
733                 # if we called get() above, we'd lose our only
734                 # reference to the wrapper and our weakref callback
735                 # would close the conn before we've executed the SQL.
736                 bareconn = conn.conn
737            
738             # Call Execute directly, skipping win32com overhead.
739             bareconn._oleobj_.InvokeTypes(6, 0, 1, (9, 0),
740                                           ((8, 1), (16396, 18), (3, 49)),
741                                           query, pythoncom.Missing, -1)
742         except pywintypes.com_error, x:
743             x.args += (query, )
744             conn = None
745             raise
746    
747     def fetch(self, query, conn=None, schema=False):
748         """fetch(query, conn=None) -> rowdata, columns."""
749         if conn is None:
750             conn = self.connections.get()
751        
752         try:
753             if schema:
754                 # Call OpenSchema(query) directly, skipping win32com overhead.
755                 res = conn._oleobj_.InvokeTypes(19, 0, 1, (9, 0),
756                                                 ((3, 1), (12, 17), (12, 17)),
757                                                 query, Empty, Empty)
758             else:
759                 self.log(query)
760                 bareconn = conn
761                 if hasattr(conn, 'conn'):
762                     # 'conn' is a ConnectionWrapper object, which .Open
763                     # won't accept. Pass the unwrapped connection instead.
764                     bareconn = conn.conn
765                
766                 # Call conn.Open(query) directly, skipping win32com overhead.
767                 res, rows_affected = bareconn._oleobj_.InvokeTypes(6, 0, 1, (9, 0),
768                                                 ((8, 1), (16396, 18), (3, 49)),
769                                                 # *args =
770                                                 query, pythoncom.Missing, -1)
771         except pywintypes.com_error, x:
772             try:
773                 # Close
774                 res.InvokeTypes(*Recordset_Close)
775             except:
776                 pass
777             res = None
778             x.args += (query, )
779             conn = None
780             # "raise x" here or we could get the traceback of the inner try.
781             raise x
782        
783         # Using xrange(Count) is slightly faster than "for x in resFields".
784         resFields = res.InvokeTypes(*Recordset_Fields)
785         fieldcount = resFields.InvokeTypes(*Fields_Count)
786         columns = []
787         for i in xrange(fieldcount):
788             # Wow. Calling this directly (instead of resFields(i))
789             # results in a 29% speedup for a 1-row fetch() of 48 fields.
790             x = resFields.InvokeTypes(0, 0, 2, (9, 0), ((12, 1),), i)
791            
792             # Wow. Calling these directly (instead of x.Name, x.Type)
793             # results in a 40% speedup for a 1-row fetch() of 48 fields.
794             name = x.InvokeTypes(*Field_Name)
795             typ = x.InvokeTypes(*Field_Type)
796             columns.append((name, typ))
797        
798         data = []
799         if not (res.InvokeTypes(*BOF) and res.InvokeTypes(*EOF)):
800             # We tried .MoveNext() and lots of Fields.Item() calls.
801             # Using GetRows() beats that time by about 2/3.
802             # Inlining GetRows results in a 14% speedup for fetch().
803             data = res.InvokeTypes(*Recordset_GetRows)
804            
805             # Convert cols x rows -> rows x cols
806             data = zip(*data)
807         try:
808             # Close
809             res.InvokeTypes(*Recordset_Close)
810         except:
811             pass
812         conn = None
813        
814         return data, columns
815
816
817
818 ###########################################################################
819 ##                                                                       ##
820 ##                             SQL Server                                ##
821 ##                                                                       ##
822 ###########################################################################
823
824
825 # "Sure, there are two 4-byte integers stored. But they are
826 # packed together into a BINARY(8). The first 4-byte being
827 # the elapsed number days since SQL Server's base date of
828 # 1900-01-01. The Second 4-bytes Store the Time of Day
829 # Represented as the Number of Milliseconds After Midnight."
830 # http://www.sql-server-performance.com/fk_datetime.asp
831
832 # Note also that SQL Server allows DATETIME in the range:
833 # "1753-01-01 00:00:00.0" to "9999-12-31 23:59:59.997".
834
835
class ADOSQLDecompiler_SQLServer(ADOSQLDecompiler):
    """ADOSQLDecompiler which produces SQL for Microsoft SQL Server.
    
    The upper-case date helpers are static: they take SQL fragments and
    return a new SQL fragment. Timedelta operands are treated as a
    number of seconds (86400 per day).
    """
    
    def _compare_strings(self, op1, op, op2):
        """Return an expression comparing two string operands.
        
        op is an index into self.sql_cmp_op. The first six operators are
        emulated with Convert(binary, ...); anything else is handled by
        ADOSQLDecompiler._compare_strings.
        """
        # ADO comparison operators for strings are case-insensitive.
        if not op < 6:
            return ADOSQLDecompiler._compare_strings(self, op1, op, op2)
        
        # ('<', '<=', '==', '!=', '>', '>=')
        # Some operations on strings can be emulated with the
        # Convert function.
        sql = ("Convert(binary, %s) %s Convert(binary, %s)"
               % (op1.sql, self.sql_cmp_op[op], op2.sql))
        return self.get_expr(sql, bool)
    
    def DATEADD(dt, td):
        """Return the SQL to add a timedelta to a date."""
        # Days, seconds seems like a good way to avoid overflow:
        # add the leftover seconds first, then the whole days.
        with_seconds = "DATEADD(ss, (%s %% 86400), %s)" % (td, dt)
        return "DATEADD(dd, FLOOR(%s / 86400), %s)" % (td, with_seconds)
    DATEADD = staticmethod(DATEADD)
    
    def DATESUB(dt, td):
        """Return the SQL to subtract a timedelta from a date."""
        whole_days = "FLOOR(%s / 86400.0)" % (td,)
        return "(%s - %s)" % (dt, whole_days)
    DATESUB = staticmethod(DATESUB)
    
    def DATEDIFF(d1, d2):
        """Return the SQL to subtract one date from another."""
        day_count = "DATEDIFF(dd, %s, %s)" % (d2, d1)
        # Amazing what a difference a little ".0" can make.
        return "CAST(%s * 86400.0 AS NUMERIC)" % (day_count,)
    DATEDIFF = staticmethod(DATEDIFF)
    
    def DATETIMEADD(dt, td):
        """Return the SQL to add a timedelta to a datetime."""
        day_fraction = "(%s / 86400.0)" % (td,)
        return "(%s + %s)" % (dt, day_fraction)
    DATETIMEADD = staticmethod(DATETIMEADD)
    
    def DATETIMEDIFF(d1, d2):
        """Return the SQL to subtract one datetime from another."""
        days_apart = "CAST(%s - %s AS FLOAT)" % (d1, d2)
        return "CAST(%s * 86400 AS NUMERIC)" % (days_apart,)
    DATETIMEDIFF = staticmethod(DATETIMEDIFF)
    
    def DATETIMESUB(dt, td):
        """Return the SQL to subtract a timedelta from a datetime."""
        day_fraction = "(%s / 86400.0)" % (td,)
        return "(%s - %s)" % (dt, day_fraction)
    DATETIMESUB = staticmethod(DATETIMESUB)
    
    def TIMEDELTAADD(op1, op, op2):
        """Return the SQL to combine two timedelta expressions with op."""
        operands = (op1.sql, op, op2.sql)
        return "(%s %s %s)" % operands
    TIMEDELTAADD = staticmethod(TIMEDELTAADD)
    
    def _datepart(self, part, x):
        """Return an int expression for DATEPART(part, x.sql)."""
        return self.get_expr("DATEPART(" + part + ", " + x.sql + ")", int)
    
    def builtins_now(self):
        """Return a datetime expression for the current local time."""
        return self.get_expr("GETDATE()", datetime.datetime)
    
    def builtins_today(self):
        """Return a date expression for today (time portion zeroed)."""
        midnight = "DATEADD(dd, DATEDIFF(dd, 0, getdate()), 0)"
        return self.get_expr(midnight, datetime.date)
    
    def builtins_year(self, x):
        """Return an int expression for the year of x."""
        return self._datepart("year", x)
    
    def builtins_month(self, x):
        """Return an int expression for the month of x."""
        return self._datepart("month", x)
    
    def builtins_day(self, x):
        """Return an int expression for the day of x."""
        return self._datepart("day", x)
    
    def builtins_utcnow(self):
        """Return a datetime expression for the current UTC time."""
        return self.get_expr("GETUTCDATE()", datetime.datetime)
906
907
class AdapterToADOSQL_SQLServer(adapters.AdapterToSQL):
    """Coerce Expression constants to SQL for Microsoft SQL Server."""
    
    encoding = 'ISO-8859-1'
    
    # (text, replacement) pairs for ordinary string literals...
    escapes = [("'", "''")]
    # ...and for LIKE patterns, where each special character is
    # wrapped in square brackets.
    like_escapes = [
        ("[", "[[]"),
        ("%", "[%]"),
        ("_", "[_]"),
        ("?", "[?]"),
        ("#", "[#]"),
        ]
    
    # These are not the same as coerce_bool_to_any (which is used on one
    # side of a comparison). Instead, these are used when the whole
    # (sub)expression is True or False, e.g. "WHERE TRUE", or
    # "WHERE TRUE and 'a'.'b' = 3".
    bool_true = "(1=1)"
    bool_false = "(1=0)"
    
    def coerce_bool_to_any(self, value):
        """Return '1' for a true value, '0' for a false one."""
        if not value:
            return '0'
        return '1'
    
    def cast_VARCHAR_to_int(self, colref):
        """Return SQL which casts colref to int when ISNUMERIC(colref)=1.
        
        The CASE has no ELSE branch for the non-numeric case.
        """
        template = "(CASE WHEN ISNUMERIC(%s)=1 THEN CAST(%s AS int) END)"
        return template % (colref, colref)
930
931
class AdapterFromADOSQL_SQLServer(AdapterFromADO):
    """AdapterFromADO with SQL Server-specific time and timedelta handling."""
    
    def coerce_any_to_datetime_time(self, value):
        """Return value as a datetime.time, correcting the SQL Server epoch."""
        # Floats returned from SQL Server will be 2 days off
        # because its epoch is 2 days later than MS Access.
        shifted = float(value) - 2
        return AdapterFromADO.coerce_any_to_datetime_time(self, shifted)
    
    def coerce_any_to_datetime_timedelta(self, value):
        """Return value (a whole number of seconds) as a datetime.timedelta."""
        # We're using the fallback type for timedelta, so split the
        # total back into days and leftover seconds.
        total = long(value)
        days, secs = divmod(total, 86400)
        return datetime.timedelta(days=int(days), seconds=int(secs))
943
944
class TypeAdapter_SQLServer(ADOTypeAdapter):
    """ADOTypeAdapter which names column types for Microsoft SQL Server."""
    
    # Hm. Docs say 38, but I can't seem to get more than 12 working.
    # They must mean 38 binary digits; math.log(2 ** 38, 10) = 11.4+
    numeric_max_precision = 12
    numeric_max_bytes = 6
    
    def coerce_bool(self, hints):
        """Return the database type for bool values ("BIT")."""
        return "BIT"
    
    # date, time and datetime values all share the DATETIME column type.
    
    def coerce_datetime_datetime(self, hints):
        """Return the database type for datetime.datetime values."""
        return "DATETIME"
    
    def coerce_datetime_date(self, hints):
        """Return the database type for datetime.date values."""
        return "DATETIME"
    
    def coerce_datetime_time(self, hints):
        """Return the database type for datetime.time values."""
        return "DATETIME"
    
    def int_type(self, bytes):
        """Return a datatype which can handle the given number of bytes."""
        # Smallest type first. BIGINT is usually 8 bytes.
        for limit, typename in ((2, "SMALLINT"),
                                (4, "INTEGER"),
                                (8, "BIGINT")):
            if bytes <= limit:
                return typename
        
        # Anything larger than 8 bytes, use decimal/numeric.
        return "NUMERIC(%s, 0)" % (bytes * 2)
    
    def coerce_str(self, hints):
        """Return a VARCHAR type sized from hints['bytes'] (default 255).
        
        A size of 0 or of more than 8000 triggers a warning and is
        replaced by 8000.
        """
        # The bytes hint does not reflect the usual 4-byte base for varchar.
        size = int(hints.get('bytes', 255))
        
        unbounded = (size == 0)
        if unbounded or size > 8000:
            # Okay, what the @#$%& is wrong with Redmond??!?! We can't even
            # compare TEXT or NTEXT fields??!? Fine. We'll deny such, and
            # warn the deployer with less swearing and exclamation points.
            message = ("You have defined a string property without "
                       "limiting its length. Microsoft SQL Server does "
                       "not allow comparisons on string fields larger "
                       "than 8000 characters. Some of your data may be "
                       "truncated.")
            errors.warn(message)
            size = 8000
        
        # 8000 *bytes* is the absolute upper limit, based on T_SQL docs for
        # varchar/varbinary. If there are further fields defined for the
        # class, or the codepage uses a double-byte character set, we still
        # might exceed the max size (8060) for a record. We could calc the
        # total requested record size, and adjust accordingly. Meh.
        return "VARCHAR(%s)" % (size,)
998
999
class SQLServerTable(ADOTable):
    """ADOTable with SQL Server-specific column renaming and ID lookup."""
    
    def _rename(self, oldcol, newcol):
        """Rename oldcol to newcol.name in this table via sp_rename."""
        statement = ("EXEC sp_rename '%s.%s', '%s', 'COLUMN'"
                     % (self.name, oldcol.name, newcol.name))
        self.schema.db.execute_ddl(statement)
    
    def _grab_new_ids(self, idkeys, conn):
        """Return {idkeys[0]: IDENT_CURRENT value for this table}.
        
        Only the first key in idkeys is populated.
        """
        # For some reason, using SCOPE_IDENTITY or IDENTITY failed (returned
        # None) when retrieving ID's just after a 99-thread-test ran. Moving
        # the multithreading test fixed it. IDENT_CURRENT worked regardless.
        query = "SELECT IDENT_CURRENT('%s');" % (self.qname,)
        rows, columns = self.schema.db.fetch(query, conn)
        first_row = rows[0]
        return {idkeys[0]: first_row[0]}
1014
1015
class SQLServerConnectionManager(ADOConnectionManager):
    """ADOConnectionManager for Microsoft SQL Server connections."""
    
    # Isolation level name used when none is requested explicitly.
    default_isolation = "READ COMMITTED"
1019
1020
class SQLServerSchema(ADOSchema):
    """ADOSchema which creates/drops databases and columns on SQL Server."""
    
    tableclass = SQLServerTable
    
    def create_database(self):
        """Create this database, then clear this schema object.
        
        The statement is run on a separate connection obtained with
        master=True; that connection is closed even if the DDL fails.
        """
        conn = self.db.connections._get_conn(master=True)
        try:
            self.db.execute_ddl("CREATE DATABASE %s;" % self.qname, conn)
        finally:
            # Don't leave the master connection open if the DDL raised.
            conn.Close()
        self.clear()
    
    def drop_database(self):
        """Drop this database, then clear this schema object.
        
        The statement is run on a separate connection obtained with
        master=True; that connection is closed even if the DDL fails.
        """
        conn = self.db.connections._get_conn(master=True)
        try:
            self.db.execute_ddl("DROP DATABASE %s;" % self.qname, conn)
        finally:
            # Don't leave the master connection open if the DDL raised.
            conn.Close()
        self.clear()
    
    def columnclause(self, column):
        """Return a clause for the given column for CREATE or ALTER TABLE.
        
        This will be of the form:
            name type [DEFAULT x|IDENTITY(initial, 1) NOT NULL]
        
        Raises ValueError if the column is autoincrement but its dbtype
        is not one of BOOLEAN, SMALLINT, INTEGER or BIGINT.
        """
        dbtype = column.dbtype
        
        clause = ""
        if column.autoincrement:
            if dbtype not in ("BOOLEAN", "SMALLINT", "INTEGER", "BIGINT"):
                raise ValueError("SQL Server does not allow IDENTITY "
                                 "columns of type %r" % dbtype)
            clause = " IDENTITY(%s, 1) NOT NULL" % column.initial
        else:
            # SQL Server does not allow a column to have
            # both an IDENTITY clause and a DEFAULT clause.
            # Note that any false default (None, 0, "", False)
            # results in no DEFAULT clause at all.
            default = column.default or ""
            if default:
                clause = self.db.adaptertosql.coerce(default, dbtype)
                clause = " DEFAULT %s" % clause
        
        return '%s %s%s' % (column.qname, dbtype, clause)
1060
1061
class SQLServerDatabase(ADODatabase):
    """ADODatabase for Microsoft SQL Server."""
    
    decompiler = ADOSQLDecompiler_SQLServer
    adaptertosql = AdapterToADOSQL_SQLServer()
    adapterfromdb = AdapterFromADOSQL_SQLServer()
    typeadapter = TypeAdapter_SQLServer()
    connectionmanager = SQLServerConnectionManager
    schemaclass = SQLServerSchema
    
    def __init__(self, **kwargs):
        """Initialize the database; allow SNAPSHOT isolation on 2005.
        
        All keyword arguments are passed through to ADODatabase.__init__.
        """
        ADODatabase.__init__(self, **kwargs)
        if "2005" in self.version():
            levels = self.connections.isolation_levels
            # The list may be shared between connection managers
            # (NOTE(review): confirm where isolation_levels is defined),
            # so don't add a duplicate entry for every new database.
            if "SNAPSHOT" not in levels:
                levels.append("SNAPSHOT")
    
    def version(self):
        """Return a string giving the ADO version and SQL Server @@VERSION.
        
        A separate connection (obtained with master=True) is opened for
        the query and closed again, even if the query fails.
        """
        conn = self.connections._get_conn(master=True)
        try:
            adov = conn.Version
            data, coldefs = self.fetch("SELECT @@VERSION;", conn)
            sqlv, = data[0]
        finally:
            conn.Close()
            del conn
        return "ADO Version: %s\n%s" % (adov, sqlv)
    
    def is_timeout_error(self, exc):
        """If the given exception instance is a lock timeout, return True.
        
        This should return True for errors which arise from transaction
        locking timeouts; for example, if the database prevents 'dirty
        reads' by raising an error.
        
        Any exception which is not a com_error, or which carries no
        usable exception info, is not a timeout (returns False).
        """
        # com_error: (-2147352567, 'Exception occurred.',
        #   (0, 'Microsoft OLE DB Provider for SQL Server',
        #    'Timeout expired', None, 0, -2147217871), None,
        #    "UPDATE [testVet] SET [City] = 'Tehachapi' ... ;")
        if not isinstance(exc, pywintypes.com_error):
            return False
        try:
            return exc.args[2][5] == -2147217871
        except (IndexError, TypeError):
            # args is too short, or the exception info slot is None
            # or shorter than expected.
            return False
1099
1100
1101
1102 ###########################################################################
1103 ##                                                                       ##
1104 ##                             MS Access                                 ##
1105 ##                                                                       ##
1106 ###########################################################################
1107
1108
class ADOSQLDecompiler_MSAccess(ADOSQLDecompiler):
    """ADOSQLDecompiler which produces SQL for Microsoft Access.
    
    The upper-case date helpers are static: they take SQL fragments and
    return a new SQL fragment.
    """
    
    sql_cmp_op = ('<', '<=', '=', '<>', '>', '>=', 'in', 'not in')
    
    epoch = datetime.datetime(1899, 12, 30)
    
    def _compare_strings(self, op1, op, op2):
        """Return an expression comparing two string operands.
        
        op is an index into self.sql_cmp_op. The first six operators are
        emulated with StrComp; anything else is handled by
        ADOSQLDecompiler._compare_strings.
        """
        # ADO comparison operators for strings are case-insensitive.
        if not op < 6:
            return ADOSQLDecompiler._compare_strings(self, op1, op, op2)
        
        # ('<', '<=', '==', '!=', '>', '>=')
        # Some operations on strings can be emulated with the
        # StrComp function. Oddly enough, "StrComp(x, y) op 0"
        # is the same as "x op y" in most cases.
        sql = ("StrComp(%s, %s) %s 0"
               % (op1.sql, op2.sql, self.sql_cmp_op[op]))
        return self.get_expr(sql, bool)
    
    def _datefunc(self, funcname, x):
        """Return an int expression for funcname(x.sql)."""
        return self.get_expr(funcname + "(" + x.sql + ")", int)
    
    def builtins_now(self):
        """Return a datetime expression for Now()."""
        return self.get_expr("Now()", datetime.datetime)
    
    def builtins_today(self):
        """Return a date expression for the date portion of Now()."""
        return self.get_expr("DateValue(Now())", datetime.date)
    
    def builtins_year(self, x):
        """Return an int expression for the year of x."""
        return self._datefunc("Year", x)
    
    def builtins_month(self, x):
        """Return an int expression for the month of x."""
        return self._datefunc("Month", x)
    
    def builtins_day(self, x):
        """Return an int expression for the day of x."""
        return self._datefunc("Day", x)
    
    def DATEADD(dt, td):
        """Return the SQL to add a timedelta to a date."""
        # Important to use Fix (instead of CLng, for example)
        # for negative numbers.
        whole_days = "Fix(%s)" % (td,)
        return "DateAdd('d', %s, %s)" % (whole_days, dt)
    DATEADD = staticmethod(DATEADD)
    
    def DATEDIFF(d1, d2):
        """Return the SQL to subtract one date from another."""
        # Important to use Fix (instead of CLng, for example)
        # for negative numbers.
        difference = "Fix(%s) - Fix(%s)" % (d1, d2)
        return "CDate(%s)" % (difference,)
    DATEDIFF = staticmethod(DATEDIFF)
    # Subtracting a timedelta from a date uses the very same SQL.
    DATESUB = DATEDIFF
    
    def DATETIMEADD(dt, td):
        """Return the SQL to add a timedelta to a datetime."""
        total = "%s + %s" % (dt, td)
        return "CDate(%s)" % (total,)
    DATETIMEADD = staticmethod(DATETIMEADD)
    
    def DATETIMEDIFF(d1, d2):
        """Return the SQL to subtract one (datetime or date expr) from another."""
        difference = "%s - %s" % (d1, d2)
        return "CDate(%s)" % (difference,)
    DATETIMEDIFF = staticmethod(DATETIMEDIFF)
    # Subtracting a timedelta from a datetime uses the very same SQL.
    DATETIMESUB = DATETIMEDIFF
    
    def TIMEDELTAADD(op1, op, op2):
        """Return the SQL to combine two timedelta expressions with op."""
        operands = (op1.sql, op, op2.sql)
        return "CDate(%s %s %s)" % operands
    TIMEDELTAADD = staticmethod(TIMEDELTAADD)
1171
1172
class TypeAdapter_MSAccess(ADOTypeAdapter):
    """ADOTypeAdapter which names column types for Microsoft Access."""
    # http://msdn2.microsoft.com/en-us/library/ms714540.aspx
    # http://office.microsoft.com/en-us/access/HP010322481033.aspx
    
    # Hm. Docs say 28/38, but I can't seem to get more than 12 working.
    numeric_max_precision = 12
    numeric_max_bytes = 6
    
    # Start from the generic ADO mapping and add the Access-only names.
    _reverse_types = ADOTypeAdapter._reverse_types.copy()
    _reverse_types["LONG"] = int
    _reverse_types["MEMO"] = str
    
    def coerce_bool(self, hints):
        """Return the database type for bool values ("BIT")."""
        return "BIT"
    
    # All of the datetime module's types share the DATETIME column type.
    
    def coerce_datetime_datetime(self, hints):
        """Return the database type for datetime.datetime values."""
        return "DATETIME"
    
    def coerce_datetime_date(self, hints):
        """Return the database type for datetime.date values."""
        return "DATETIME"
    
    def coerce_datetime_time(self, hints):
        """Return the database type for datetime.time values."""
        return "DATETIME"
    
    def coerce_datetime_timedelta(self, hints):
        """Return the database type for datetime.timedelta values."""
        return "DATETIME"
    
    def int_type(self, bytes):
        """Return a datatype which can handle the given number of bytes."""
        for limit, typename in ((2, "INTEGER"), (4, "LONG")):
            if bytes <= limit:
                return typename
        
        # Anything larger than 4 bytes, use decimal/numeric.
        return "DECIMAL"
    
    def coerce_str(self, hints):
        """Return MEMO or a VARCHAR type sized from hints['bytes'].
        
        The size defaults to 255. A size of 0 or of more than 255
        yields "MEMO".
        """
        # The bytes hint shall not reflect the usual 4-byte base for varchar.
        size = int(hints.get('bytes', 255))
        
        # 255 chars is the upper limit for TEXT / VARCHAR in MS Access.
        unbounded = (size == 0)
        if unbounded or size > 255:
            # MEMO is 1 GB max when set programatically (only 64K when set
            # in Access UI). But then, 1 GB is the limit for the whole DB.
            # Note that OpenSchema will return a DATA_TYPE of "WCHAR".
            return "MEMO"
        
        return "VARCHAR(%s)" % (size,)
1215
1216
class AdapterToADOSQL_MSAccess(adapters.AdapterToSQL):
    """Coerce Expression constants to ADO SQL for Microsoft Access."""
    
    encoding = 'ISO-8859-1'
    
    # (text, replacement) pairs for ordinary string literals...
    escapes = [("'", "''")]
    # ...and for LIKE patterns, where each special character is
    # wrapped in square brackets.
    like_escapes = [
        ("[", "[[]"),
        ("%", "[%]"),
        ("_", "[_]"),
        ("?", "[?]"),
        ("#", "[#]"),
        ]
    
    def coerce_datetime_datetime_to_any(self, value):
        """Return value as a #month/day/year hh:mm:ss# literal."""
        date_part = '%s/%s/%s' % (value.month, value.day, value.year)
        time_part = ('%02d:%02d:%02d'
                     % (value.hour, value.minute, value.second))
        return '#%s %s#' % (date_part, time_part)
    
    def coerce_datetime_date_to_any(self, value):
        """Return value as a #month/day/year# literal."""
        fields = (value.month, value.day, value.year)
        return '#%s/%s/%s#' % fields
    
    def coerce_datetime_time_to_any(self, value):
        """Return value as a #hh:mm:ss# literal."""
        fields = (value.hour, value.minute, value.second)
        return '#%02d:%02d:%02d#' % fields
    
    def coerce_datetime_timedelta_to_any(self, value):
        """Return value as a CDate expression offset from #12/30/1899#."""
        # This took a lot of work to get right, because timedelta
        # seconds are positive even if the days are negative.
        # So is the fractional portion of a negative Access Date!
        day_fraction = value.seconds / 86400.0
        # Very important we use repr (%r) here so we get all 17 decimal
        # digits in the float.
        return ("CDate(#12/30/1899# + (%r) + %r)"
                % (value.days, day_fraction))
1245
1246
class MSAccessTable(ADOTable):
    """ADOTable with MS Access-specific DELETE syntax and ID lookup."""
    
    def delete(self, **inputs):
        """Delete all rows matching the given identifier inputs."""
        # MS Access needs an asterisk to delete
        sql = ('DELETE * FROM %s WHERE %s;'
               % (self.qname, self.id_clause(**inputs)))
        self.schema.db.execute(sql)
    
    def delete_all(self, **inputs):
        """Delete all rows matching the given inputs."""
        # MS Access needs an asterisk to delete
        sql = ('DELETE * FROM %s WHERE %s;'
               % (self.qname, self.whereclause(**inputs)))
        self.schema.db.execute(sql)
    
    def _grab_new_ids(self, idkeys, conn):
        """Return {idkeys[0]: the @@IDENTITY value fetched on conn}.
        
        Only the first key in idkeys is populated.
        """
        rows, columns = self.schema.db.fetch("SELECT @@IDENTITY;", conn)
        first_row = rows[0]
        return {idkeys[0]: first_row[0]}
1264
1265
class MSAccessConnectionManager(ADOConnectionManager):
    """ADOConnectionManager for MS Access: one connection, one iso level."""
    
    poolsize = 0
    default_isolation = "READ UNCOMMITTED"
    isolation_levels = ["READ UNCOMMITTED",]
    
    def _set_factory(self):
        """Use a single shared connection rather than a pool."""
        # MS Access can't use a pool, because there doesn't seem
        # to be a commit timeout. See http://support.microsoft.com/kb/200300
        # for additional synchronization issues.
        factory = conns.SingleConnection(self._get_conn, self._del_conn)
        self._factory = factory
    
    def isolate(self, conn, isolation=None):
        """Validate the requested isolation level for the given connection.
        
        If 'isolation' is None, our default_isolation is used. Valid values
        for the 'isolation' argument may be native values for your
        particular database. However, it is recommended you pass items
        from the global 'levels' list instead.
        
        An IsolationLevel object whose name is not in our isolation_levels
        raises ValueError; any other value is accepted unchecked. Nothing
        is ever sent to the connection, since you can't actually set the
        isolation level.
        """
        requested = isolation
        if requested is None:
            requested = self.default_isolation
        
        if not isinstance(requested, _isolation.IsolationLevel):
            # A native value: nothing to map, nothing to check.
            return
        
        # Map the given IsolationLevel object to a native value.
        native = requested.name
        if native in self.isolation_levels:
            # No action to take, since you can't actually set iso level.
            return
        
        raise ValueError("IsolationLevel %r not allowed by %s."
                         % (native, self.__class__.__name__))
1301
1302
1303 class MSAccessSchema(ADOSchema):
1304    
1305     tableclass = MSAccessTable
1306    
1307     def _get_columns(self, tablename, conn=None):
1308         cols = ADOSchema._get_columns(self, tablename, conn)
1309         if conn is None:
1310             conn = self.db.connections._factory()
1311        
1312         try:
1313             # Horrible hack to get autoincrement property
1314             query = "SELECT * FROM %s WHERE FALSE;" % self.db.quote(tablename)
1315             bareconn = conn
1316             if hasattr(conn, 'conn'):
1317                 # 'conn' is a ConnectionWrapper object, which .Open
1318                 # won't accept. Pass the unwrapped connection instead.
1319                 bareconn = conn.conn
1320            
1321             # Call conn.Open(query) directly, skipping win32com overhead.
1322             res, rows_affected = conn._oleobj_.InvokeTypes(6, 0, 1, (9, 0),
1323                                             ((8, 1), (16396, 18), (3, 49)),
1324                                             # *args =
1325                                             query, pythoncom.Missing, -1)
1326         except pywintypes.com_error, x:
1327             try:
1328                 res.InvokeTypes(*Recordset_Close)
1329             except:
1330                 pass
1331             res = None
1332             x.args += (query, )
1333             conn = None
1334            
1335             try:
1336                 if "no read permission" in x.args[2][2]:
1337                     conn = None
1338                     return []
1339             except IndexError:
1340                 pass
1341            
1342             # "raise x" here or we could get the traceback of the inner try.
1343             raise x
1344        
1345         resFields = res.InvokeTypes(*Recordset_Fields)
1346         for c in cols:
1347             f = resFields.InvokeTypes(0, 0, 2, (9, 0), ((12, 1),), c.name)
1348             fprops = f.InvokeTypes(*Field_Properties)
1349             fprop = fprops.InvokeTypes(0, 0, 2, (9, 0), ((12, 1), ), "ISAUTOINCREMENT")
1350             c.autoincrement = fprop.InvokeTypes(*Property_Value)
1351        
1352         try:
1353             res.InvokeTypes(*Recordset_Close)
1354         except:
1355             pass
1356         conn = None
1357        
1358         return cols
1359    
1360     def columnclause(self, column):
1361         """Return a clause for the given column for CREATE or ALTER TABLE.
1362         
1363         This will be of the form:
1364             name type [DEFAULT x|AUTOINCREMENT(initial, 1)]
1365         """
1366         dbtype = column.dbtype
1367        
1368         if column.autoincrement:
1369             # MS Access does not allow a column to have
1370             # both an AUTOINCREMENT clause and a DEFAULT clause.
1371             # It also needs no type in this case.
1372             dbtype = "AUTOINCREMENT(%s, 1)" % column.initial
1373         else:
1374             default = column.default or ""
1375             if default:
1376                 defspec = self.db.adaptertosql.coerce(default, dbtype)
1377                 if isinstance(default, (int, long)):
1378                     # Crazy quote hack to get a numeric default to work.
1379                     defspec = "'%s'" % defspec
1380                 dbtype = "%s DEFAULT %s" % (dbtype, defspec)
1381        
1382         return '%s %s' % (column.qname, dbtype)
1383    
1384     def create_database(self):
1385         # By not providing an Engine Type, it defaults to 5 = Access 2000.
1386         cat = win32com.client.Dispatch(r'ADOX.Catalog')
1387         cat.Create(self.db.connections.Connect)
1388         cat.ActiveConnection.Close()
1389         self.clear()
1390    
1391     def drop_database(self):
1392         # Must shut down our only connection to avoid
1393         # "Permission denied" error on os.remove call below.
1394         self.db.connections.shutdown()
1395        
1396         import os
1397         # This should accept relative or absolute paths
1398         if os.path.exists(self.name):
1399             os.remove(self.name)
1400        
1401         self.clear()
1402
1403
class MSAccessDatabase(ADODatabase):
    """ADODatabase for Microsoft Access."""
    
    decompiler = ADOSQLDecompiler_MSAccess
    adaptertosql = AdapterToADOSQL_MSAccess()
    typeadapter = TypeAdapter_MSAccess()
    connectionmanager = MSAccessConnectionManager
    schemaclass = MSAccessSchema
    
    def version(self):
        """Return a string giving the version of ADO in use."""
        # A fresh, unopened Connection object is enough to ask ADO
        # for its version.
        ado_conn = win32com.client.Dispatch(r'ADODB.Connection')
        ado_version = ado_conn.Version
        del ado_conn
        return "ADO Version: %s" % (ado_version,)
1417
1418
1419 def gen_py():
1420     """Auto generate .py support for ADO 2.7+"""
1421     print 'Please wait while support for ADO 2.7+ is verified...'
1422    
1423     # Microsoft ActiveX Data Objects 2.8 Library
1424     result = win32com.client.gencache.EnsureModule('{2A75196C-D9EB-4129-B803-931327F72D5C}', 0, 2, 8)
1425     if result is not None:
1426         return
1427    
1428     # Microsoft ActiveX Data Objects 2.7 Library
1429     result = win32com.client.gencache.EnsureModule('{EF53050B-882E-4776-B643-EDA472E8E3F2}', 0, 2, 7)
1430     if result is not None:
1431         return
1432    
1433     raise ImportError("ADO 2.7 support could not be imported/cached")
1434
1435
if __name__ == '__main__':
    # Run as a script: make sure the generated ADO support modules
    # exist (see gen_py above).
    gen_py()
Note: See TracBrowser for help on using the browser.