Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/storeado.py

Revision 242 (checked in by fumanchu, 7 years ago)

Some API changes to make the dejavu top level cleaner:

  1. Moved dejavu.LOGxxx -> dejavu.logflags.xxx.
  2. Moved CrossTab, COUNT, SUM to the analysis namespace only.
  3. Pushed errors and containers back down one level.

Unrelated: fixed decimal tests when prec > getcontext().prec.

  • Property svn:eol-style set to native
Line 
1 import sys
2 # Put COM in free-threaded mode. This first thread will have
3 # CoInitializeEx called automatically when pythoncom is imported.
4 sys.coinit_flags = 0
5 import pythoncom
6
7 import win32com.client
8 import pywintypes
9 import datetime
10
11 try:
12     import cPickle as pickle
13 except ImportError:
14     import pickle
15
16 import warnings
17
18 import dejavu
19 from dejavu import errors, logic, storage
20 from dejavu.storage import db
21
# ADO CursorTypeEnum values.
adOpenForwardOnly, adOpenKeyset, adOpenDynamic, adOpenStatic = 0, 1, 2, 3

# ADO LockTypeEnum values.
(adLockReadOnly, adLockPessimistic,
 adLockOptimistic, adLockBatchOptimistic) = 1, 2, 3, 4

# ADO SchemaEnum values; fetch(..., schema=True) passes these to
# Connection.OpenSchema.
adSchemaColumns, adSchemaIndexes, adSchemaTables = 4, 12, 20

# ADO CursorLocationEnum value.
adUseClient = 3

# Day zero of COM/ADO dates is 1899-12-30; its proleptic-Gregorian
# ordinal is 693594.
zeroHour = datetime.date(1899, 12, 30).toordinal()

# ADO DataTypeEnum code -> type name used throughout this module.
dbtypes = {
    0: 'EMPTY',
    2: 'SMALLINT',
    3: 'INTEGER',
    4: 'SINGLE',
    5: 'DOUBLE',
    6: 'CURRENCY',
    7: 'DATE',
    8: 'BSTR',
    9: 'IDISPATCH',
    10: 'ERROR',
    11: 'BOOLEAN',
    12: 'VARIANT',
    13: 'IUNKNOWN',
    14: 'DECIMAL',
    16: 'TINYINT',
    17: 'UNSIGNEDTINYINT',
    18: 'UNSIGNEDSMALLINT',
    19: 'UNSIGNEDINT',
    20: 'BIGINT',
    21: 'UNSIGNEDBIGINT',
    72: 'GUID',
    128: 'BINARY',
    129: 'CHAR',
    130: 'WCHAR',
    131: 'NUMERIC',
    132: 'USERDEFINED',
    133: 'DBDATE',
    134: 'DBTIME',
    135: 'DBTIMESTAMP',
    200: 'VARCHAR',
    201: 'LONGVARCHAR',
    202: 'VARWCHAR',
    203: 'LONGVARWCHAR',
    204: 'VARBINARY',
    205: 'LONGVARBINARY',
}

# OLE DB DBCOLUMNFLAGS bits, tested against the COLUMN_FLAGS value of
# the adSchemaColumns rowset.
DBCOLUMNFLAGS_WRITE = 1 << 2
DBCOLUMNFLAGS_WRITEUNKNOWN = 1 << 3
DBCOLUMNFLAGS_ISFIXEDLENGTH = 1 << 4
DBCOLUMNFLAGS_ISNULLABLE = 1 << 5
DBCOLUMNFLAGS_MAYBENULL = 1 << 6
DBCOLUMNFLAGS_ISLONG = 1 << 7
DBCOLUMNFLAGS_ISROWID = 1 << 8
DBCOLUMNFLAGS_ISROWVER = 1 << 9
DBCOLUMNFLAGS_CACHEDEFERRED = 1 << 12
71
72
def time_from_com(com_date):
    """Return a valid datetime.time from a COM date or time object.

    A COM (OLE Automation) date is a float. Its integer part counts days
    from 1899-12-30 and its fractional part is the time of day. For
    negative dates the fractional part is still the *absolute* time of
    day (-1.25 is 1899-12-29 06:00), so the sign is dropped before the
    fraction is taken.

    The result has 1-second resolution. A fraction that rounds up to a
    whole day (e.g. 23:59:59.6) is clamped to 23:59:59. A datetime.time
    cannot carry into the next day, and the datetime coercions in this
    module compute the date separately from the integer part.
    """
    fraction = abs(float(com_date)) % 1
    # Round the total once. Rounding hour/minute/second individually
    # could leave the result up to 1 second off.
    seconds = int(round(86400 * fraction))
    if seconds > 86399:
        seconds = 86399
    hour, seconds = divmod(seconds, 3600)
    minute, second = divmod(seconds, 60)
    return datetime.time(hour, minute, second)
101
102 class AdapterFromADO(db.AdapterFromDB):
103     """Coerce incoming values from ADO to Dejavu datatypes."""
104    
105     encoding = 'ISO-8859-1'
106    
107     def coerce_any_to_datetime_datetime(self, value):
108         # Illegal Date/Time values will crash the
109         # app when using value.Format(). Therefore,
110         # grab the value and figure the date ourselves.
111         # Use 1-second resolution only.
112         if isinstance(value, basestring):
113             if value:
114                 try:
115                     return datetime.datetime(int(value[0:4]), int(value[4:6]),
116                                              int(value[6:8]))
117                 except Exception:
118                     raise ValueError("'%s' %s" % (value, type(value)))
119             else:
120                 return None
121         else:
122             # For some reason, we need both float and int.
123             aDate = datetime.date.fromordinal(int(float(value)) + zeroHour)
124             return datetime.datetime.combine(aDate, time_from_com(value))
125    
126     def coerce_any_to_datetime_date(self, value):
127         # See coerce_any_to_datetime
128         if isinstance(value, basestring):
129             if value:
130                 try:
131                     return datetime.date(int(value[0:4]), int(value[4:6]),
132                                          int(value[6:8]))
133                 except Exception:
134                     raise ValueError("'%s' %s" % (value, type(value)))
135             else:
136                 return None
137         else:
138             return datetime.date.fromordinal(int(float(value)) + zeroHour)
139    
140     def coerce_any_to_datetime_time(self, value):
141         # See coerce_any_to_datetime
142         return time_from_com(value)
143    
144     def coerce_any_to_decimal_Decimal(self, value):
145         # pywin32 build 205 began support for returning
146         # COM Currency objects as decimal objects.
147         # See http://pywin32.cvs.sourceforge.net/pywin32/pywin32/CHANGES.txt?view=markup
148         if not isinstance(value, db.decimal.Decimal):
149             value = str(value)
150             value = db.decimal.Decimal(str(value))
151         return value
152    
153     coerce_any_to_float = float
154    
155     def coerce_CURRENCY_to_float(self, value):
156         if isinstance(value, tuple):
157             # See http://groups.google.com/group/comp.lang.python/
158             #           browse_frm/thread/fed03c64735c9e9c
159             value = map(long, value)
160             return ((value[1] & 0xFFFFFFFFL) | (value[0] << 32)) / 1e4
161         return float(value)
162    
163     def coerce_CURRENCY_to_decimal_Decimal(self, value):
164         # pywin32 build 205 began support for returning
165         # COM Currency objects as decimal objects.
166         # See http://pywin32.cvs.sourceforge.net/pywin32/pywin32/CHANGES.txt?view=markup
167         if not isinstance(value, db.decimal.Decimal):
168             # See http://groups.google.com/group/comp.lang.python/
169             #           browse_frm/thread/fed03c64735c9e9c
170             value = map(long, value)
171             value = (value[1] & 0xFFFFFFFFL) | (value[0] << 32)
172             return db.decimal.Decimal(value) / 10000
173         return value
174    
175     def coerce_CURRENCY_to_fixedpoint_FixedPoint(self, value):
176         if isinstance(value, db.decimal.Decimal):
177             value = str(value)
178             scale = 0
179             atoms = value.rsplit(".", 1)
180             if len(atoms) > 1:
181                 scale = len(atoms[-1])
182             return db.fixedpoint.FixedPoint(value, scale)
183         else:
184             # See http://groups.google.com/group/comp.lang.python/
185             #           browse_frm/thread/fed03c64735c9e9c
186             value = map(long, value)
187             value = (value[1] & 0xFFFFFFFFL) | (value[0] << 32)
188             return db.fixedpoint.FixedPoint(value, 4) / 1e4
189    
190     coerce_any_to_int = int
191     coerce_any_to_bool = bool
192    
193     def coerce_any_to_unicode(self, value):
194         if isinstance(value, unicode):
195             # For some reason, inValue is already a unicode object.
196             return value
197         if isinstance(value, (basestring, buffer)):
198             try:
199                 return unicode(value, self.encoding)
200             except UnicodeError:
201                 raise StandardError(type(value))
202         return unicode(value)
203
204
205
class ADOSQLDecompiler(db.SQLDecompiler):
    """Translate Dejavu expressions into SQL for ADO back ends.

    ADO string comparisons and LIKE are case-insensitive by default, so
    translations that may match more rows than the Python expression
    would set self.imperfect = True.
    """

    def visit_COMPARE_OP(self, lo, hi):
        """Pop two operands and push the SQL comparing them.

        lo and hi combine into the comparison opcode, which indexes
        self.sql_cmp_op. Opcodes 6/7 ('in'/'not in') and comparisons
        against NULL get special handling.
        """
        op2, op1 = self.stack.pop(), self.stack.pop()
        if op1 is db.cannot_represent or op2 is db.cannot_represent:
            self.stack.append(db.cannot_represent)
            return

        op = lo + (hi << 8)
        if op in (6, 7):     # in, not in
            # Looking for text in a field. Use Like (reverse terms).
            # LIKE is case-insensitive in MS SQL Server (and there
            # doesn't seem to be a way around it). Use icontainedby
            # and just mark imperfect.
            value = self.dejavu_icontainedby(op1, op2)
            if op == 7:
                value = "NOT " + value
            self.stack.append(value)
            self.imperfect = True
        elif op1 == 'NULL':
            if op in (2, 8):    # '==', is
                self.stack.append(op2 + " IS NULL")
            elif op in (3, 9):  # '!=', 'is not'
                self.stack.append(op2 + " IS NOT NULL")
            else:
                raise ValueError("Non-equality Null comparisons not allowed.")
        elif op2 == 'NULL':
            if op in (2, 8):    # '==', 'is'
                self.stack.append(op1 + " IS NULL")
            elif op in (3, 9):  # '!=', 'is not'
                self.stack.append(op1 + " IS NOT NULL")
            else:
                raise ValueError("Non-equality Null comparisons not allowed.")
        else:
            if (isinstance(op2, db.ConstWrapper)
                and isinstance(op2.basevalue, basestring)):
                # ADO comparison operators for strings are case-insensitive
                # by default. Rather than determine which columns in the DB
                # might be case-sensitive, just flag them all as imperfect.
                # TODO: might be possible to cast both to varbinary, but
                # that may cause problems with unicode columns.
                self.imperfect = True
            self.stack.append(op1 + " " + self.sql_cmp_op[op] + " " + op2)

    # --------------------------- Dispatchees --------------------------- #

    def attr_startswith(self, tos, arg):
        """SQL for tos.startswith(arg); imperfect because LIKE ignores case."""
        self.imperfect = True
        return tos + " LIKE '" + self.adapter.escape_like(arg) + "%'"

    def attr_endswith(self, tos, arg):
        """SQL for tos.endswith(arg); imperfect because LIKE ignores case."""
        self.imperfect = True
        return tos + " LIKE '%" + self.adapter.escape_like(arg) + "'"

    def containedby(self, op1, op2):
        """SQL for `op1 in op2`; always marked imperfect.

        If op1 is a constant, op2 is a field searched for op1 as a
        substring. Otherwise op2 wraps a sequence and op1 is tested
        for membership with IN.
        """
        self.imperfect = True
        # Qualified as db.ConstWrapper; the bare name is not defined in
        # this module and raised NameError.
        if isinstance(op1, db.ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            return op2 + " LIKE '%" + self.adapter.escape_like(op1) + "%'"
        else:
            # Looking for field in (a, b, c)
            atoms = [self.adapter.coerce(x) for x in op2.basevalue]
            return op1 + " IN (" + ", ".join(atoms) + ")"

    def dejavu_icontainedby(self, op1, op2):
        """Case-insensitive `op1 in op2` (substring LIKE or IN list)."""
        if isinstance(op1, db.ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            # LIKE is already case-insensitive in MS SQL Server;
            # so don't use LOWER().
            value = op2 + " LIKE '%" + self.adapter.escape_like(op1) + "%'"
        else:
            # Looking for field in (a, b, c)
            atoms = [self.adapter.coerce(x) for x in op2.basevalue]
            value = op1 + " IN (" + ", ".join(atoms) + ")"
        return value

    def dejavu_istartswith(self, x, y):
        # Like is already case-insensitive in ADO; so don't use LOWER().
        return x + " LIKE '" + self.adapter.escape_like(y) + "%'"

    def dejavu_iendswith(self, x, y):
        # Like is already case-insensitive in ADO; so don't use LOWER().
        return x + " LIKE '%" + self.adapter.escape_like(y) + "'"

    def dejavu_ieq(self, x, y):
        # = is already case-insensitive in ADO.
        return x + " = " + y

    def dejavu_now(self):
        return "getdate()"

    def dejavu_today(self):
        # Strip the time part: whole days since day 0, added back to day 0.
        return "DATEADD(dd, DATEDIFF(dd,0,getdate()), 0)"

    def func__builtin___len(self, x):
        return "Len(" + x + ")"
302
303
class ADOColumnSet(db.ColumnSet):
    """ColumnSet that renames columns through the ADOX catalog."""

    def _rename(self, oldcol, newcol):
        """Rename column oldcol of this table to newcol.name using ADOX."""
        conn = self.table.db.connection()
        cat = None
        try:
            cat = win32com.client.Dispatch(r'ADOX.Catalog')
            # connection() may return a wrapper that exposes the raw ADODB
            # connection as .conn. ADO's Recordset.Open rejects such wrappers,
            # and presumably so does this COM property, so hand over the raw
            # object while the wrapper stays referenced by conn.
            if hasattr(conn, "conn"):
                cat.ActiveConnection = conn.conn
            else:
                cat.ActiveConnection = conn
            cat.Tables(self.table.name).Columns(oldcol.name).Name = newcol.name
        finally:
            # Drop the COM references promptly.
            conn = None
            cat = None
315
316
def connatoms(connstring):
    """Parse an OLE DB connection string into a {KEY: value} dict.

    Keys are stripped and upper-cased; values are stripped. Empty or
    whitespace-only segments (e.g. from a trailing "; ") are ignored.
    Only the first "=" of a segment separates key from value, so values
    may themselves contain "=". Quoted values containing ";" are not
    supported. A non-empty segment without "=" raises ValueError.
    """
    atoms = {}
    for pair in connstring.split(";"):
        pair = pair.strip()
        if not pair:
            continue
        k, v = pair.split("=", 1)
        atoms[k.upper().strip()] = v.strip()
    return atoms
324
325
326 class ADODatabase(db.Database):
327    
328     decompiler = ADOSQLDecompiler
329     adapterfromdb = AdapterFromADO()
330     columnsetclass = ADOColumnSet
331    
332     def _get_tables(self, conn=None):
333         # cols will be
334         # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202),
335         # (u'TABLE_TYPE', 202), (u'TABLE_GUID', 72), (u'DESCRIPTION', 203),
336         # (u'TABLE_PROPID', 19), (u'DATE_CREATED', 7), (u'DATE_MODIFIED', 7)]
337         data, cols = self.fetch(adSchemaTables, conn=conn, schema=True)
338         return [db.Table(self, str(row[2]), self.quote(str(row[2])))
339                 for row in data]
340    
341     def _get_columns(self, tablename, conn=None):
342         # columns will be
343         # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202),
344         # (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72), (u'COLUMN_PROPID', 19),
345         # (u'ORDINAL_POSITION', 19), (u'COLUMN_HASDEFAULT', 11),
346         # (u'COLUMN_DEFAULT', 203), (u'COLUMN_FLAGS', 19), (u'IS_NULLABLE', 11),
347         # (u'DATA_TYPE', 18), (u'TYPE_GUID', 72), (u'CHARACTER_MAXIMUM_LENGTH', 19),
348         # (u'CHARACTER_OCTET_LENGTH', 19), (u'NUMERIC_PRECISION', 18),
349         # (u'NUMERIC_SCALE', 2), (u'DATETIME_PRECISION', 19),
350         # (u'CHARACTER_SET_CATALOG', 202), (u'CHARACTER_SET_SCHEMA', 202),
351         # (u'CHARACTER_SET_NAME', 202), (u'COLLATION_CATALOG', 202),
352         # (u'COLLATION_SCHEMA', 202), (u'COLLATION_NAME', 202),
353         # (u'DOMAIN_CATALOG', 202), (u'DOMAIN_SCHEMA', 202),
354         # (u'DOMAIN_NAME', 202), (u'DESCRIPTION', 203)]
355         data, coldefs = self.fetch(adSchemaColumns, conn=conn, schema=True)
356        
357         cols = []
358         for row in data:
359             # I tried passing criteria to OpenSchema, but passing None is
360             # not the same as passing pythoncom.Empty (which errors).
361             if tablename and row[2] != tablename:
362                 continue
363            
364             dbtype = dbtypes[row[11]]
365             default = row[8]
366             if default is not None:
367                 default = self.python_type(dbtype)(default)
368            
369             name = str(row[3])
370             c = db.Column(name, self.quote(name), dbtype, default)
371            
372             # This only works for SQL Server. The MSAccessDatabase will
373             # wrap this method and override autoincrement.
374             colflags = int(row[9])
375             if ((colflags & DBCOLUMNFLAGS_ISFIXEDLENGTH)
376                 and not (colflags & DBCOLUMNFLAGS_WRITE)):
377                 c.autoincrement = True
378            
379             if dbtype in ("SMALLINT", "INTEGER", "TINYINT",
380                           "UNSIGNEDTINYINT", "UNSIGNEDSMALLINT",
381                           "UNSIGNEDINT", "BIGINT", "UNSIGNEDBIGINT"):
382                 c.hints['bytes'] = row[15]
383             elif dbtype in ("SINGLE", "DOUBLE"):
384                 c.hints['precision'] = row[15]
385                 c.hints['scale'] = row[16]
386             elif dbtype == "CURRENCY":
387                 # CURRENCY allows 15 places to the left of the decimal point,
388                 # and 4 places to the right.
389                 c.hints['precision'] = 19
390                 c.hints['scale'] = 4
391             elif dbtype in ("DECIMAL", "NUMERIC"):
392                 c.hints['precision'] = row[15]
393                 c.hints['scale'] = row[16]
394                 c.dbtype = "%s(%s, %s)" % (dbtype, row[15], row[16])
395             elif dbtype in ("BSTR", "VARIANT", "BINARY", "CHAR",
396                             "VARCHAR", "VARBINARY", "WCHAR", "VARWCHAR"):
397                 if row[13]:
398                     # row[13] will be a float
399                     c.hints['bytes'] = b = int(row[13])
400                 else:
401                     # I'm kinda guessing on this. If we use "MEMO" in an
402                     # MSAccess CREATE statement, it comes back as "WCHAR",
403                     # and seems to support over 65536 bytes.
404                     c.hints['bytes'] = b = (2 ** 31) - 1
405                 c.dbtype = "%s(%s)" % (c.dbtype, b)
406             elif dbtype in ("LONGVARCHAR", "LONGVARBINARY", "LONGVARWCHAR"):
407                 if row[13]:
408                     # row[13] will be a float
409                     c.hints['bytes'] = b = int(row[13])
410                     c.dbtype = "%s(%s)" % (c.dbtype, b)
411                 else:
412                     c.hints['bytes'] = 65535
413            
414             cols.append(c)
415         return cols
416    
417     def _get_indices(self, tablename=None, conn=None):
418         # cols will be
419         # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202),
420         # (u'INDEX_CATALOG', 202), (u'INDEX_SCHEMA', 202), (u'INDEX_NAME', 202),
421         # (u'PRIMARY_KEY', 11), (u'UNIQUE', 11), (u'CLUSTERED', 11), (u'TYPE', 18),
422         # (u'FILL_FACTOR', 3), (u'INITIAL_SIZE', 3), (u'NULLS', 3),
423         # (u'SORT_BOOKMARKS', 11), (u'AUTO_UPDATE', 11), (u'NULL_COLLATION', 3),
424         # (u'ORDINAL_POSITION', 19), (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72),
425         # (u'COLUMN_PROPID', 19), (u'COLLATION', 2), (u'CARDINALITY', 21),
426         # (u'PAGES', 3), (u'FILTER_CONDITION', 202), (u'INTEGRATED', 11)]
427         data, _ = self.fetch(adSchemaIndexes, conn=conn, schema=True)
428         indices = []
429         for row in data:
430             # I tried passing criteria to OpenSchema, but passing None is
431             # not the same as passing pythoncom.Empty (which errors).
432             if tablename and row[2] != tablename:
433                 continue
434             i = db.Index(row[5], self.quote(row[5]),
435                          row[2], row[17], row[6], row[7])
436             indices.append(i)
437         return indices
438    
439     def python_type(self, dbtype):
440         """Return a Python type which can store values of the given dbtype."""
441         if dbtype in ("DATE", "DBDATE"):
442             return datetime.date
443         elif dbtype == "DBTIME":
444             return datetime.time
445         elif dbtype in ("DATETIME", "DBTIMESTAMP"):
446             return datetime.datetime
447         elif dbtype in ("SMALLINT", "INTEGER", "TINYINT",
448                         "UNSIGNEDTINYINT", "UNSIGNEDSMALLINT",
449                         "UNSIGNEDINT"):
450             return int
451         elif dbtype in ("BIT", "BOOLEAN"):
452             return bool
453         elif dbtype in ("BIGINT", "UNSIGNEDBIGINT", "LONG"):
454             return long
455         elif dbtype in ("SINGLE", "DOUBLE", "DOUBLE PRECISION", "REAL"):
456             return float
457        
458         for t in ("DECIMAL", "NUMERIC", "CURRENCY"):
459             if dbtype.startswith(t):
460                 if db.decimal:
461                     return db.decimal.Decimal
462                 elif db.fixedpoint:
463                     return db.fixedpoint.FixedPoint
464        
465         for t in ("BSTR", "VARIANT", "BINARY", "CHAR", "MEMO", "TEXT",
466                   "VARCHAR", "LONGVARCHAR", "VARBINARY", "LONGVARBINARY"):
467             if dbtype.startswith(t):
468                 return str
469        
470         for t in ("WCHAR", "VARWCHAR", "LONGVARWCHAR"):
471             if dbtype.startswith(t):
472                 return unicode
473        
474         raise TypeError("Database type %s could not be converted "
475                         "to a Python type." % repr(dbtype))
476    
477     def _rename(self, oldtable, newtable):
478         conn = self.connection()
479         try:
480             cat = win32com.client.Dispatch(r'ADOX.Catalog')
481             cat.ActiveConnection = conn
482             cat.Tables(oldtable.name).Name = newtable.name
483         finally:
484             conn = None
485             cat = None
486    
487     def quote(self, name):
488         """Return name, quoted for use in an SQL statement."""
489         return '[' + name + ']'
490    
491     def _get_conn(self):
492         conn = win32com.client.Dispatch(r'ADODB.Connection')
493         conn.Open(self.Connect)
494         return conn
495    
496     def _del_conn(self, conn):
497         conn.Close()
498    
499     def execute(self, query, conn=None):
500         if conn is None:
501             conn = self.connection()
502         if isinstance(query, unicode):
503             query = query.encode(self.adaptertosql.encoding)
504         self.log(query, dejavu.logflags.SQL)
505         try:
506             conn.Execute(query)
507         except pywintypes.com_error, x:
508             x.args += (query, )
509             conn = None
510             raise
511    
512     def fetch(self, query, conn=None, schema=False):
513         """fetch(query, conn=None) -> rowdata, columns."""
514         if conn is None:
515             conn = self.connection()
516        
517         try:
518             if schema:
519                 res = conn.OpenSchema(query)
520             else:
521                 self.log(query, dejavu.logflags.SQL)
522                 res = win32com.client.Dispatch(r'ADODB.Recordset')
523                 if hasattr(conn, "conn"):
524                     # 'conn' is a ConnectionWrapper object, which .Open
525                     # won't accept. Pass the unwrapped connection instead.
526                     res.Open(query, conn.conn, adOpenForwardOnly, adLockReadOnly)
527                 else:
528                     res.Open(query, conn, adOpenForwardOnly, adLockReadOnly)
529         except pywintypes.com_error, x:
530             try:
531                 res.Close()
532             except:
533                 pass
534             x.args += (query, )
535             conn = None
536             # "raise x" here or we could get the traceback of the inner try.
537             raise x
538        
539         columns = [(x.Name, x.Type) for x in res.Fields]
540        
541         data = []
542         if not(res.BOF and res.EOF):
543             # We tried .MoveNext() and lots of Fields.Item() calls.
544             # Using GetRows() beats that time by about 2/3.
545             data = res.GetRows()
546             # Convert cols x rows -> rows x cols
547             data = zip(*data)
548         try:
549             res.Close()
550         except:
551             pass
552         conn = None
553        
554         return data, columns
555
556
class StorageManagerADO(db.StorageManagerDB):
    """StoreManager to save and retrieve Units via ADO 2.7.

    You must run makepy on ADO 2.7 before installing.
    """

    databaseclass = ADODatabase

    def version(self):
        """Describe the installed ADO library version."""
        ado_version = win32com.client.Dispatch(r'ADODB.Connection').Version
        return "ADO Version: %s" % ado_version
568
569
570
571 ###########################################################################
572 ##                                                                       ##
573 ##                             SQL Server                                ##
574 ##                                                                       ##
575 ###########################################################################
576
577
class AdapterToADOSQL_SQLServer(db.AdapterToSQL):
    """Render Python values as SQL Server literals (sent via ADO)."""

    encoding = 'ISO-8859-1'

    # Double single quotes inside string literals.
    escapes = [("'", "''")]
    # Bracket LIKE metacharacters so they match literally. "[" must be
    # escaped too: unescaped, it opens a character class in T-SQL LIKE.
    # It comes first because the later replacements introduce brackets
    # that must not be re-escaped. This assumes escape_like applies the
    # pairs in list order (not visible here).
    like_escapes = [("[", "[[]"), ("%", "[%]"), ("_", "[_]")]

    # These are not the same as coerce_bool_to_any (which is used on one side of
    # a comparison). Instead, these are used when the whole (sub)expression
    # is True or False, e.g. "WHERE TRUE", or "WHERE TRUE and 'a'.'b' = 3".
    bool_true = "(1=1)"
    bool_false = "(1=0)"

    def coerce_bool_to_any(self, value):
        """Render a truth value as the BIT literal '1' or '0'."""
        if value:
            return '1'
        return '0'
595
596
class TypeAdapter_SQLServer(db.TypeAdapter):
    """Choose SQL Server column types for Dejavu columns."""

    # The docs claim 38, but nothing above 12 could be made to work.
    # Perhaps they mean 38 *binary* digits: math.log(2 ** 38, 10) = 11.4+.
    numeric_max_precision = 12
    numeric_max_bytes = 6

    def coerce_bool(self, col):
        """Booleans are stored as BIT."""
        return "BIT"

    def coerce_datetime_datetime(self, col):
        """datetime.datetime is stored as DATETIME."""
        return "DATETIME"

    def coerce_datetime_date(self, col):
        """datetime.date is stored as DATETIME as well."""
        return "DATETIME"

    def coerce_datetime_time(self, col):
        """datetime.time is stored as DATETIME as well."""
        return "DATETIME"

    def int_type(self, bytes):
        """Return a datatype which can handle the given number of bytes."""
        for limit, typename in ((2, "SMALLINT"), (4, "INTEGER"), (8, "BIGINT")):
            if bytes <= limit:
                return typename
        # Wider than BIGINT (usually 8 bytes): fall back to NUMERIC with two
        # decimal digits per requested byte.
        # NOTE(review): this sizing was justified by PostgreSQL's NUMERIC
        # storage rules, and one byte spans ~2.41 decimal digits, so
        # bytes * 2 may under-size the column. Confirm before relying on it.
        return "NUMERIC(%s, 0)" % (bytes * 2)

    def coerce_str(self, col):
        """Return a VARCHAR type sized from col.hints['bytes'] (default 255).

        Unlimited (0) or over-8000 sizes are capped at 8000 with a
        StorageWarning.
        """
        # The bytes hint does not reflect the usual 4-byte base for varchar.
        size = int(col.hints.get('bytes', 255))

        if size > 8000 or size == 0:
            # SQL Server cannot compare TEXT/NTEXT columns, so refuse to use
            # them: cap the column and tell the deployer.
            warnings.warn("You have defined a string property without "
                          "limiting its length. Microsoft SQL Server does "
                          "not allow comparisons on string fields larger "
                          "than 8000 characters. Some of your data may be "
                          "truncated.", errors.StorageWarning)
            size = 8000

        # 8000 bytes is the T-SQL ceiling for varchar/varbinary. Other fields
        # or a double-byte codepage can still push a record past the 8060-byte
        # row limit; total record size is not checked here.
        return "VARCHAR(%s)" % size
653
654
class SQLServerColumnSet(ADOColumnSet):
    """ColumnSet that adds and renames columns with SQL Server DDL."""

    def __setitem__(self, key, column):
        """ALTER the table to add column, then record it under key.

        Raises ValueError for autoincrement columns whose type is not
        allowed as IDENTITY.
        """
        t = self.table

        dbtype = column.dbtype
        if column.autoincrement:
            if dbtype not in ("BOOLEAN", "SMALLINT", "INTEGER", "BIGINT"):
                raise ValueError("SQL Server does not allow IDENTITY "
                                 "columns of type %s" % repr(dbtype))
            # column.default doubles as the IDENTITY seed.
            dbtype = "%s IDENTITY(%s, 1) NOT NULL" % (dbtype, column.default)
        elif column.default is not None:
            # Test against None so falsy defaults (0, False, "") are kept.
            # The SQL adapter belongs to the database, which this ColumnSet
            # reaches through its table (as it does for execute below).
            default = t.db.adaptertosql.coerce(column.default, dbtype)
            dbtype = "%s DEFAULT %s" % (dbtype, default)

        # SQL Server doesn't use the "COLUMN" keyword with "ADD"
        t.db.execute("ALTER TABLE %s ADD %s %s;" %
                     (t.qname, column.qname, dbtype))
        dict.__setitem__(self, key, column)

    def _rename(self, oldcol, newcol):
        """Rename a column with sp_rename."""
        t = self.table

        def lit(name):
            # The names go inside T-SQL string literals; double any quotes.
            return name.replace("'", "''")

        t.db.execute("EXEC sp_rename '%s.%s', '%s', 'COLUMN'" %
                     (lit(t.name), lit(oldcol.name), lit(newcol.name)))
681
682
class SQLServerDatabase(ADODatabase):
    """ADODatabase that issues Microsoft SQL Server (and MSDE) DDL."""

    columnsetclass = SQLServerColumnSet
    adaptertosql = AdapterToADOSQL_SQLServer()
    typeadapter = TypeAdapter_SQLServer()

    def _execute_in_tempdb(self, sql):
        """Run sql on a fresh ADO connection whose catalog is tempdb.

        Presumably used because the target catalog does not exist yet
        (CREATE) or must not be the connection's own catalog (DROP).
        The connection is closed even if the statement fails.
        """
        atoms = connatoms(self.Connect)
        atoms['INITIAL CATALOG'] = "tempdb"
        adoconn = win32com.client.Dispatch(r'ADODB.Connection')
        adoconn.Open("; ".join(["%s=%s" % (k, v) for k, v in atoms.iteritems()]))
        try:
            adoconn.Execute(sql)
        finally:
            adoconn.Close()

    def create_database(self):
        """Create this database on the server, then clear cached tables."""
        # This method hasn't been tested yet for SQL server (only MSDE).
        self._execute_in_tempdb("CREATE DATABASE %s" % self.qname)
        self.clear()

    def drop_database(self):
        """Drop this database from the server, then clear cached tables."""
        self._execute_in_tempdb("DROP DATABASE %s;" % self.qname)
        self.clear()

    def _column_ddl(self, column):
        """Return the column's type plus any IDENTITY/DEFAULT clause.

        Raises ValueError for autoincrement columns whose type is not
        allowed as IDENTITY.
        """
        dbtype = column.dbtype
        if column.autoincrement:
            if dbtype not in ("BOOLEAN", "SMALLINT", "INTEGER", "BIGINT"):
                raise ValueError("SQL Server does not allow IDENTITY "
                                 "columns of type %s" % repr(dbtype))
            # column.default doubles as the IDENTITY seed.
            return "%s IDENTITY(%s, 1) NOT NULL" % (dbtype, column.default)
        if column.default is not None:
            # Test against None so falsy defaults (0, False, "") are kept.
            default = self.adaptertosql.coerce(column.default, dbtype)
            return "%s DEFAULT %s" % (dbtype, default)
        return dbtype

    def __setitem__(self, key, table):
        """(Re)create table and its indices in the database under key."""
        if key in self:
            del self[key]

        fields = ['%s %s' % (column.qname, self._column_ddl(column))
                  for column in table.columns.itervalues()]
        self.execute('CREATE TABLE %s (%s);' %
                     (table.qname, ", ".join(fields)))

        for index in table.columns.indices.itervalues():
            self.execute('CREATE INDEX %s ON %s (%s);' %
                         (index.qname, table.qname,
                          self.quote(index.colname)))

        dict.__setitem__(self, key, table)
736
737
class StorageManagerADO_SQLServer(StorageManagerADO):
    """Storage manager for Microsoft SQL Server, accessed through ADO."""

    databaseclass = SQLServerDatabase

    def __init__(self, name, arena, allOptions={}):
        """Initialize the store, naming its database after the connection
        string's INITIAL CATALOG.

        Note that *allOptions* is modified in place: its 'name' entry is set
        from the parsed 'Connect' string before being passed on to
        db.StorageManagerDB.__init__.
        """
        atoms = connatoms(allOptions['Connect'])
        allOptions['name'] = atoms['INITIAL CATALOG']
        db.StorageManagerDB.__init__(self, name, arena, allOptions)

    def _seq_UnitSequencerInteger(self, unit):
        """Reserve a unit using the table's IDENTITY (autoincrement) field.

        Every non-autoincrement property of *unit* is INSERTed, then the
        identity value SQL Server generated is read back and assigned to
        the unit's first identifier.
        """
        cls = unit.__class__
        t = self.db[cls.__name__]

        fields = []
        values = []
        for key in cls.properties:
            col = t.columns[key]
            if col.autoincrement:
                # Skip this field, since we're using IDENTITY
                continue
            val = self.db.adaptertosql.coerce(getattr(unit, key), col.dbtype)
            fields.append(col.qname)
            values.append(val)

        if fields:
            self.db.execute('INSERT INTO %s (%s) VALUES (%s);' %
                            (t.qname, ", ".join(fields), ", ".join(values)))
        else:
            # Only the IDENTITY column exists; "INSERT ... () VALUES ()" is
            # a syntax error, so let SQL Server fill in every column.
            self.db.execute('INSERT INTO %s DEFAULT VALUES;' % t.qname)

        # Grab the new ID. This is threadsafe because db.reserve has a mutex.
        # For some reason, using SCOPE_IDENTITY or IDENTITY failed (returned
        # None) when retrieving ID's just after a 99-thread-test ran. Moving
        # the multithreading test fixed it. IDENT_CURRENT worked regardless.
        # NOTE: IDENT_CURRENT is not limited to this session, so inserts
        # into the same table from another process between these two
        # statements would yield the wrong ID.
        tablename = t.qname.replace("'", "''")
        data, _ = self.db.fetch("SELECT IDENT_CURRENT('%s');" % tablename)
        setattr(unit, cls.identifiers[0], data[0][0])
754         for key in cls.properties:
755             col = t.columns[key]
756             if col.autoincrement:
757                 # Skip this field, since we're using IDENTITY
758                 continue
759             val = self.db.adaptertosql.coerce(getattr(unit, key), col.dbtype)
760             fields.append(col.qname)
761             values.append(val)
762        
763         fields = ", ".join(fields)
764         values = ", ".join(values)
765         self.db.execute('INSERT INTO %s (%s) VALUES (%s);' %
766                         (t.qname, fields, values))
767        
768         # Grab the new ID. This is threadsafe because db.reserve has a mutex.
769         # For some reason, using SCOPE_IDENTITY or IDENTITY failed (returned
770         # None) when retrieving ID's just after a 99-thread-test ran. Moving
771         # the multithreading test fixed it. IDENT_CURRENT worked regardless.
772         data, col_defs = self.db.fetch("SELECT IDENT_CURRENT('%s');" % t.qname)
773         setattr(unit, cls.identifiers[0], data[0][0])
774
775
776
777 ###########################################################################
778 ##                                                                       ##
779 ##                             MS Access                                 ##
780 ##                                                                       ##
781 ###########################################################################
782
783
784 class ADOSQLDecompiler_MSAccess(ADOSQLDecompiler):
785     sql_cmp_op = ('<', '<=', '=', '<>', '>', '>=', 'in', 'not in')
786    
787     def dejavu_now(self):
788         return "Now()"
789    
790     def dejavu_today(self):
791         return "DateValue(Now())"
792    
793     def dejavu_year(self, x):
794         return "Year(" + x + ")"
795
796
797 class TypeAdapter_MSAccess(db.TypeAdapter):
798    
799     # Hm. Docs say 28/38, but I can't seem to get more than 12 working.
800     numeric_max_precision = 12
801     numeric_max_bytes = 6
802    
803     def coerce_bool(self, col): return "BIT"
804    
805     def coerce_datetime_datetime(self, col): return "DATETIME"
806     def coerce_datetime_date(self, col): return "DATETIME"
807     def coerce_datetime_time(self, col): return "DATETIME"
808    
809     def int_type(self, bytes):
810         if bytes <= 2:
811             return "INTEGER"
812         elif bytes <= 4:
813             return "LONG"
814         else:
815             # Anything larger than 4 bytes, use decimal/numeric.
816             return "DECIMAL"
817    
818     def coerce_str(self, col):
819         # The bytes hint shall not reflect the usual 4-byte base for varchar.
820         bytes = int(col.hints.get('bytes', 255))
821        
822         # 255 chars is the upper limit for TEXT / VARCHAR in MS Access.
823         if bytes == 0 or bytes > 255:
824             # MEMO is 1 GB max when set programatically (only 64K when set
825             # in Access UI). But then, 1 GB is the limit for the whole DB.
826             # Note that OpenSchema will return a DATA_TYPE of "WCHAR".
827             return "MEMO"
828        
829         return "VARCHAR(%s)" % bytes
830
831
832
833 class AdapterToADOSQL_MSAccess(db.AdapterToSQL):
834     """Coerce Expression constants to ADO SQL."""
835    
836     encoding = 'ISO-8859-1'
837    
838     escapes = [("'", "''")]
839     like_escapes = [("%", "[%]"), ("_", "[_]")]
840    
841     def coerce_datetime_datetime_to_any(self, value):
842         return ('#%s/%s/%s %02d:%02d:%02d#' %
843                 (value.month, value.day, value.year,
844                  value.hour, value.minute, value.second))
845    
846     def coerce_datetime_date_to_any(self, value):
847         return '#%s/%s/%s#' % (value.month, value.day, value.year)
848    
849     def coerce_datetime_time_to_any(self, value):
850         return '#%02d:%02d:%02d#' % (value.hour, value.minute, value.second)
851
852
853 class MSAccessColumnSet(ADOColumnSet):
854    
855     def __setitem__(self, key, column):
856         t = self.table
857        
858         dbtype = column.dbtype
859         if column.autoincrement:
860             dbtype = "AUTOINCREMENT(%s, 1)" % column.default
861         else:
862             default = column.default or ""
863             if default:
864                 default = self.adaptertosql.coerce(default, dbtype)
865                 dbtype = "%s DEFAULT %s" % (dbtype, default)
866        
867         # SQL Server doesn't use the "COLUMN" keyword with "ADD"
868         t.db.execute("ALTER TABLE %s ADD %s %s;" %
869                      (t.qname, column.qname, dbtype))
870         dict.__setitem__(self, key, column)
871
872
873 class MSAccessDatabase(ADODatabase):
874    
875     decompiler = ADOSQLDecompiler_MSAccess
876     adaptertosql = AdapterToADOSQL_MSAccess()
877     typeadapter = TypeAdapter_MSAccess()
878    
879     columnsetclass = MSAccessColumnSet
880    
881     poolsize = 0
882    
883     def connect(self):
884         # MS Access can't use a pool, because there doesn't seem
885         # to be a commit timeout.
886         self.connection = db.SingleConnection(self._get_conn, self._del_conn)
887    
888     def _get_columns(self, tablename, conn=None):
889         cols = ADODatabase._get_columns(self, tablename, conn)
890         if conn is None:
891             conn = self.connection()
892        
893         try:
894             # Horrible hack to get autoincrement property
895             query = "SELECT * FROM %s WHERE FALSE" % self.quote(tablename)
896             res = win32com.client.Dispatch(r'ADODB.Recordset')
897             if hasattr(conn, "conn"):
898                 # 'conn' is a ConnectionWrapper object, which .Open
899                 # won't accept. Pass the unwrapped connection instead.
900                 res.Open(query, conn.conn, adOpenForwardOnly, adLockReadOnly)
901             else:
902                 res.Open(query, conn, adOpenForwardOnly, adLockReadOnly)
903         except pywintypes.com_error, x:
904             try:
905                 res.Close()
906             except:
907                 pass
908             x.args += (query, )
909             conn = None
910             # "raise x" here or we could get the traceback of the inner try.
911             raise x
912        
913         for c in cols:
914             c.autoincrement = res.Fields(c.name).Properties("ISAUTOINCREMENT").Value
915        
916         try:
917             res.Close()
918         except:
919             pass
920         conn = None
921        
922         return cols
923    
924     def python_type(self, dbtype):
925         if dbtype == "LONG":
926             return int
927         return ADODatabase.python_type(self, dbtype)
928    
929     def create_database(self):
930         # By not providing an Engine Type, it defaults to 5 = Access 2000.
931         cat = win32com.client.Dispatch(r'ADOX.Catalog')
932         cat.Create(self.Connect)
933         cat.ActiveConnection.Close()
934         self.clear()
935    
936     def drop_database(self):
937         import os
938         # This should accept relative or absolute paths
939         if os.path.exists(self.name):
940             os.remove(self.name)
941         self.clear()
942    
943     def __setitem__(self, key, table):
944         if key in self:
945             del self[key]
946        
947         fields = []
948         for column in table.columns.itervalues():
949             dbtype = column.dbtype
950             if column.autoincrement:
951                 dbtype = "AUTOINCREMENT(%s, 1)" % column.default
952             else:
953                 default = column.default or ""
954                 if default:
955                     default = self.adaptertosql.coerce(default, dbtype)
956                     dbtype = "%s DEFAULT %s" % (dbtype, default)
957             fields.append('%s %s' % (column.qname, dbtype))
958        
959         self.execute('CREATE TABLE %s (%s);' %
960                      (table.qname, ", ".join(fields)))
961        
962         for index in table.columns.indices.itervalues():
963             self.execute('CREATE INDEX %s ON %s (%s);' %
964                          (index.qname, table.qname,
965                           self.quote(index.colname)))
966        
967         dict.__setitem__(self, key, table)
968
969
970 class StorageManagerADO_MSAccess(StorageManagerADO):
971     # Jet Connections and Recordsets are always free-threaded.
972    
973     use_asterisk_to_get_all = True
974     databaseclass = MSAccessDatabase
975    
976     def __init__(self, name, arena, allOptions={}):
977         atoms = connatoms(allOptions['Connect'])
978         allOptions['name'] = (atoms.get('DATA SOURCE') or
979                               atoms.get('DATA SOURCE NAME') or
980                               atoms.get('DBQ'))
981         db.StorageManagerDB.__init__(self, name, arena, allOptions)
982    
983     def _seq_UnitSequencerInteger(self, unit):
984         """Reserve a unit using the table's AUTOINCREMENT field."""
985         cls = unit.__class__
986         t = self.db[cls.__name__]
987        
988         fields = []
989         values = []
990         for key in cls.properties:
991             col = t.columns[key]
992             if col.autoincrement:
993                 # Skip this field, since we're using AUTOINCREMENT
994                 continue
995             val = self.db.adaptertosql.coerce(getattr(unit, key), col.dbtype)
996             fields.append(col.qname)
997             values.append(val)
998        
999         fields = ", ".join(fields)
1000         values = ", ".join(values)
1001         self.db.execute('INSERT INTO %s (%s) VALUES (%s);' %
1002                         (t.qname, fields, values))
1003        
1004         # Grab the new ID. This is threadsafe because db.reserve has a mutex.
1005         data, col_defs = self.db.fetch("SELECT @@IDENTITY;")
1006         setattr(unit, cls.identifiers[0], data[0][0])
1007    
1008     def _make_column(self, cls, key):
1009         col = StorageManagerADO._make_column(self, cls, key)
1010         if col.dbtype == "MEMO":
1011             for assoc in cls._associations.itervalues():
1012                 if assoc.nearKey == key:
1013                     warnings.warn("Memo fields cannot be used as join keys. "
1014                                   "You should set %s.%s(hints={'bytes': 255})"
1015                                   % (cls.__name__, key), errors.StorageWarning)
1016         return col
1017
1018
1019 def gen_py():
1020     # Auto generate .py support for ADO 2.7+
1021     print 'Please wait while support for ADO 2.7+ is verified...'
1022     CLSID = '{EF53050B-882E-4776-B643-EDA472E8E3F2}'
1023     return win32com.client.gencache.EnsureModule(CLSID, 0, 2, 7)
1024
1025
1026 if __name__ == '__main__':
1027     gen_py()
Note: See TracBrowser for help on using the browser.