Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/storeado.py

Revision 246 (checked in by fumanchu, 7 years ago)

Added a sync method to all SMs (storage managers; the default implementation is a no-op `pass`) with new semantics: pass it a list of the classes you want to sync. The old approach didn't allow some classes to live in one store and others in another. Arena.storage(cls) now also calls sync whenever cls has no corresponding store in _registered_classes, which is further proof that you should always call storage() rather than inspecting _registered_classes directly.

Also removed arena.defaultStore in favor of inspecting SM.classnames.

  • Property svn:eol-style set to native
Line 
1 import sys
2 # Put COM in free-threaded mode. This first thread will have
3 # CoInitializeEx called automatically when pythoncom is imported.
4 sys.coinit_flags = 0
5 import pythoncom
6
7 import win32com.client
8 import pywintypes
9 import datetime
10
11 try:
12     import cPickle as pickle
13 except ImportError:
14     import pickle
15
16 import warnings
17
18 import dejavu
19 from dejavu import errors, logic, storage
20 from dejavu.storage import db
21
# ADO cursor types, passed as the CursorType argument of Recordset.Open.
(adOpenForwardOnly, adOpenKeyset,
 adOpenDynamic, adOpenStatic) = (0, 1, 2, 3)

# ADO lock types, passed as the LockType argument of Recordset.Open.
(adLockReadOnly, adLockPessimistic,
 adLockOptimistic, adLockBatchOptimistic) = (1, 2, 3, 4)

# Schema rowsets requested through Connection.OpenSchema.
adSchemaColumns, adSchemaIndexes, adSchemaTables = 4, 12, 20

# ADO cursor location: client-side cursor.
adUseClient = 3

# 12/30/1899, the zero-Date for ADO = 693594
zeroHour = datetime.date(1899, 12, 30).toordinal()

# ADO/OLE DB data type codes (as found in the DATA_TYPE column of the
# columns schema rowset) mapped to the type names used throughout this module.
dbtypes = dict([
    (0, 'EMPTY'), (2, 'SMALLINT'), (3, 'INTEGER'), (4, 'SINGLE'),
    (5, 'DOUBLE'), (6, 'CURRENCY'), (7, 'DATE'), (8, 'BSTR'),
    (9, 'IDISPATCH'), (10, 'ERROR'), (11, 'BOOLEAN'), (12, 'VARIANT'),
    (13, 'IUNKNOWN'), (14, 'DECIMAL'), (16, 'TINYINT'),
    (17, 'UNSIGNEDTINYINT'), (18, 'UNSIGNEDSMALLINT'), (19, 'UNSIGNEDINT'),
    (20, 'BIGINT'), (21, 'UNSIGNEDBIGINT'), (72, 'GUID'), (128, 'BINARY'),
    (129, 'CHAR'), (130, 'WCHAR'), (131, 'NUMERIC'), (132, 'USERDEFINED'),
    (133, 'DBDATE'), (134, 'DBTIME'), (135, 'DBTIMESTAMP'),
    (200, 'VARCHAR'), (201, 'LONGVARCHAR'), (202, 'VARWCHAR'),
    (203, 'LONGVARWCHAR'), (204, 'VARBINARY'), (205, 'LONGVARBINARY'),
])

# OLE DB column flag bits, tested against the COLUMN_FLAGS schema value.
DBCOLUMNFLAGS_WRITE = 1 << 2
DBCOLUMNFLAGS_WRITEUNKNOWN = 1 << 3
DBCOLUMNFLAGS_ISFIXEDLENGTH = 1 << 4
DBCOLUMNFLAGS_ISNULLABLE = 1 << 5
DBCOLUMNFLAGS_MAYBENULL = 1 << 6
DBCOLUMNFLAGS_ISLONG = 1 << 7
DBCOLUMNFLAGS_ISROWID = 1 << 8
DBCOLUMNFLAGS_ISROWVER = 1 << 9
DBCOLUMNFLAGS_CACHEDEFERRED = 1 << 12
71
72
def time_from_com(com_date):
    """Return a valid datetime.time from a COM date or time object.
    
    A COM (OLE Automation) date is a float count of days since 12/30/1899.
    The absolute value of its fractional part encodes the time of day, so
    -1.25 and 1.25 both mean 06:00:00. The result is rounded to the
    nearest whole second.
    
    A time that rounds up to 24:00:00 wraps to 00:00:00. A datetime.time
    cannot carry the extra day, so that carry is dropped.
    """
    day_fraction = abs(float(com_date)) % 1
    # Round once on the total number of seconds. Carries from seconds to
    # minutes to hours then fall out of divmod; no manual fix-ups needed.
    total_seconds = int(round(86400 * day_fraction)) % 86400
    hour, remainder = divmod(total_seconds, 3600)
    minute, second = divmod(remainder, 60)
    return datetime.time(hour, minute, second)
101
102 class AdapterFromADO(db.AdapterFromDB):
103     """Coerce incoming values from ADO to Dejavu datatypes."""
104    
105     encoding = 'ISO-8859-1'
106    
107     def coerce_any_to_datetime_datetime(self, value):
108         # Illegal Date/Time values will crash the
109         # app when using value.Format(). Therefore,
110         # grab the value and figure the date ourselves.
111         # Use 1-second resolution only.
112         if isinstance(value, basestring):
113             if value:
114                 try:
115                     return datetime.datetime(int(value[0:4]), int(value[4:6]),
116                                              int(value[6:8]))
117                 except Exception:
118                     raise ValueError("'%s' %s" % (value, type(value)))
119             else:
120                 return None
121         else:
122             # For some reason, we need both float and int.
123             aDate = datetime.date.fromordinal(int(float(value)) + zeroHour)
124             return datetime.datetime.combine(aDate, time_from_com(value))
125    
126     def coerce_any_to_datetime_date(self, value):
127         # See coerce_any_to_datetime
128         if isinstance(value, basestring):
129             if value:
130                 try:
131                     return datetime.date(int(value[0:4]), int(value[4:6]),
132                                          int(value[6:8]))
133                 except Exception:
134                     raise ValueError("'%s' %s" % (value, type(value)))
135             else:
136                 return None
137         else:
138             return datetime.date.fromordinal(int(float(value)) + zeroHour)
139    
140     def coerce_any_to_datetime_time(self, value):
141         # See coerce_any_to_datetime
142         return time_from_com(value)
143    
144     def coerce_any_to_decimal_Decimal(self, value):
145         # pywin32 build 205 began support for returning
146         # COM Currency objects as decimal objects.
147         # See http://pywin32.cvs.sourceforge.net/pywin32/pywin32/CHANGES.txt?view=markup
148         if not isinstance(value, db.decimal.Decimal):
149             value = str(value)
150             value = db.decimal.Decimal(str(value))
151         return value
152    
153     coerce_any_to_float = float
154    
155     def coerce_CURRENCY_to_float(self, value):
156         if isinstance(value, tuple):
157             # See http://groups.google.com/group/comp.lang.python/
158             #           browse_frm/thread/fed03c64735c9e9c
159             value = map(long, value)
160             return ((value[1] & 0xFFFFFFFFL) | (value[0] << 32)) / 1e4
161         return float(value)
162    
163     def coerce_CURRENCY_to_decimal_Decimal(self, value):
164         # pywin32 build 205 began support for returning
165         # COM Currency objects as decimal objects.
166         # See http://pywin32.cvs.sourceforge.net/pywin32/pywin32/CHANGES.txt?view=markup
167         if not isinstance(value, db.decimal.Decimal):
168             # See http://groups.google.com/group/comp.lang.python/
169             #           browse_frm/thread/fed03c64735c9e9c
170             value = map(long, value)
171             value = (value[1] & 0xFFFFFFFFL) | (value[0] << 32)
172             return db.decimal.Decimal(value) / 10000
173         return value
174    
175     def coerce_CURRENCY_to_fixedpoint_FixedPoint(self, value):
176         if isinstance(value, db.decimal.Decimal):
177             value = str(value)
178             scale = 0
179             atoms = value.rsplit(".", 1)
180             if len(atoms) > 1:
181                 scale = len(atoms[-1])
182             return db.fixedpoint.FixedPoint(value, scale)
183         else:
184             # See http://groups.google.com/group/comp.lang.python/
185             #           browse_frm/thread/fed03c64735c9e9c
186             value = map(long, value)
187             value = (value[1] & 0xFFFFFFFFL) | (value[0] << 32)
188             return db.fixedpoint.FixedPoint(value, 4) / 1e4
189    
190     coerce_any_to_int = int
191     coerce_any_to_bool = bool
192    
193     def coerce_any_to_unicode(self, value):
194         if isinstance(value, unicode):
195             # For some reason, inValue is already a unicode object.
196             return value
197         if isinstance(value, (basestring, buffer)):
198             try:
199                 return unicode(value, self.encoding)
200             except UnicodeError:
201                 raise StandardError(type(value))
202         return unicode(value)
203
204
205
class ADOSQLDecompiler(db.SQLDecompiler):
    """SQLDecompiler adjustments shared by ADO-backed databases.
    
    ADO string comparison (LIKE and =) is treated as case-insensitive here,
    so most string matches set self.imperfect to flag the generated SQL
    as approximate.
    """
    
    def visit_COMPARE_OP(self, lo, hi):
        op2, op1 = self.stack.pop(), self.stack.pop()
        if op1 is db.cannot_represent or op2 is db.cannot_represent:
            self.stack.append(db.cannot_represent)
            return
        
        op = lo + (hi << 8)
        if op in (6, 7):     # in, not in
            # Looking for text in a field. Use Like (reverse terms).
            # LIKE is case-insensitive in MS SQL Server (and there
            # doesn't seem to be a way around it). Use icontainedby
            # and just mark imperfect.
            value = self.dejavu_icontainedby(op1, op2)
            if op == 7:
                value = "NOT " + value
            self.stack.append(value)
            self.imperfect = True
        elif op1 == 'NULL' or op2 == 'NULL':
            # SQL requires IS [NOT] NULL instead of =/<> against NULL.
            # Test the non-NULL side (op2 if op1 is the NULL literal).
            if op1 == 'NULL':
                operand = op2
            else:
                operand = op1
            if op in (2, 8):    # '==', 'is'
                self.stack.append(operand + " IS NULL")
            elif op in (3, 9):  # '!=', 'is not'
                self.stack.append(operand + " IS NOT NULL")
            else:
                raise ValueError("Non-equality Null comparisons not allowed.")
        else:
            if (isinstance(op2, db.ConstWrapper)
                and isinstance(op2.basevalue, basestring)):
                # ADO comparison operators for strings are case-insensitive
                # by default. Rather than determine which columns in the DB
                # might be case-sensitive, just flag them all as imperfect.
                # TODO: might be possible to cast both to varbinary, but
                # that may cause problems with unicode columns.
                self.imperfect = True
            self.stack.append(op1 + " " + self.sql_cmp_op[op] + " " + op2)
    
    # --------------------------- Dispatchees --------------------------- #
    
    def attr_startswith(self, tos, arg):
        self.imperfect = True
        return tos + " LIKE '" + self.adapter.escape_like(arg) + "%'"
    
    def attr_endswith(self, tos, arg):
        self.imperfect = True
        return tos + " LIKE '%" + self.adapter.escape_like(arg) + "'"
    
    def containedby(self, op1, op2):
        self.imperfect = True
        # ConstWrapper lives in the db module; the bare name is undefined here.
        if isinstance(op1, db.ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            return op2 + " LIKE '%" + self.adapter.escape_like(op1) + "%'"
        else:
            # Looking for field in (a, b, c)
            atoms = [self.adapter.coerce(x) for x in op2.basevalue]
            return op1 + " IN (" + ", ".join(atoms) + ")"
    
    def dejavu_icontainedby(self, op1, op2):
        if isinstance(op1, db.ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            # LIKE is already case-insensitive in MS SQL Server;
            # so don't use LOWER().
            value = op2 + " LIKE '%" + self.adapter.escape_like(op1) + "%'"
        else:
            # Looking for field in (a, b, c)
            atoms = [self.adapter.coerce(x) for x in op2.basevalue]
            value = op1 + " IN (" + ", ".join(atoms) + ")"
        return value
    
    def dejavu_istartswith(self, x, y):
        # Like is already case-insensitive in ADO; so don't use LOWER().
        return x + " LIKE '" + self.adapter.escape_like(y) + "%'"
    
    def dejavu_iendswith(self, x, y):
        # Like is already case-insensitive in ADO; so don't use LOWER().
        return x + " LIKE '%" + self.adapter.escape_like(y) + "'"
    
    def dejavu_ieq(self, x, y):
        # = is already case-insensitive in ADO.
        return x + " = " + y
    
    def dejavu_now(self):
        return "getdate()"
    
    def dejavu_today(self):
        # Strip the time portion: midnight of the current day.
        return "DATEADD(dd, DATEDIFF(dd,0,getdate()), 0)"
    
    def func__builtin___len(self, x):
        return "Len(" + x + ")"
302
303
class ADOColumnSet(db.ColumnSet):
    """ColumnSet whose column renames go through an ADOX.Catalog."""
    
    def _rename(self, oldcol, newcol):
        """Rename column oldcol.name to newcol.name in self.table via ADOX."""
        catalog = None
        connection = self.table.db.connection()
        try:
            catalog = win32com.client.Dispatch(r'ADOX.Catalog')
            catalog.ActiveConnection = connection
            catalog.Tables(self.table.name).Columns(oldcol.name).Name = newcol.name
        finally:
            # Drop our references to the COM objects (connection first).
            connection = None
            catalog = None
315
316
def connatoms(connstring):
    """Parse an ADO connection string into a dict of its key=value pairs.
    
    Keys are upper-cased and stripped; values are stripped. Empty or
    whitespace-only segments (e.g. from a trailing "; ") are ignored.
    Only the first '=' in a segment separates key from value, so values
    may themselves contain '='.
    
    Raises ValueError if a non-blank segment contains no '='.
    """
    atoms = {}
    for pair in connstring.split(";"):
        # Strip before testing, so " " between/after separators is skipped
        # instead of failing to unpack.
        pair = pair.strip()
        if not pair:
            continue
        key, value = pair.split("=", 1)
        atoms[key.upper().strip()] = value.strip()
    return atoms
324
325
326 class ADODatabase(db.Database):
327    
328     decompiler = ADOSQLDecompiler
329     adapterfromdb = AdapterFromADO()
330     columnsetclass = ADOColumnSet
331    
332     def _get_tables(self, conn=None):
333         # cols will be
334         # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202),
335         # (u'TABLE_TYPE', 202), (u'TABLE_GUID', 72), (u'DESCRIPTION', 203),
336         # (u'TABLE_PROPID', 19), (u'DATE_CREATED', 7), (u'DATE_MODIFIED', 7)]
337         data, cols = self.fetch(adSchemaTables, conn=conn, schema=True)
338         return [db.Table(self, str(row[2]), self.quote(str(row[2])))
339                 for row in data]
340    
341     def _get_columns(self, tablename, conn=None):
342         # columns will be
343         # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202),
344         # (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72), (u'COLUMN_PROPID', 19),
345         # (u'ORDINAL_POSITION', 19), (u'COLUMN_HASDEFAULT', 11),
346         # (u'COLUMN_DEFAULT', 203), (u'COLUMN_FLAGS', 19), (u'IS_NULLABLE', 11),
347         # (u'DATA_TYPE', 18), (u'TYPE_GUID', 72), (u'CHARACTER_MAXIMUM_LENGTH', 19),
348         # (u'CHARACTER_OCTET_LENGTH', 19), (u'NUMERIC_PRECISION', 18),
349         # (u'NUMERIC_SCALE', 2), (u'DATETIME_PRECISION', 19),
350         # (u'CHARACTER_SET_CATALOG', 202), (u'CHARACTER_SET_SCHEMA', 202),
351         # (u'CHARACTER_SET_NAME', 202), (u'COLLATION_CATALOG', 202),
352         # (u'COLLATION_SCHEMA', 202), (u'COLLATION_NAME', 202),
353         # (u'DOMAIN_CATALOG', 202), (u'DOMAIN_SCHEMA', 202),
354         # (u'DOMAIN_NAME', 202), (u'DESCRIPTION', 203)]
355         data, coldefs = self.fetch(adSchemaColumns, conn=conn, schema=True)
356        
357         cols = []
358         for row in data:
359             # I tried passing criteria to OpenSchema, but passing None is
360             # not the same as passing pythoncom.Empty (which errors).
361             if tablename and row[2] != tablename:
362                 continue
363            
364             dbtype = dbtypes[row[11]]
365             default = row[8]
366             if default is not None:
367                 default = self.python_type(dbtype)(default)
368            
369             name = str(row[3])
370             c = db.Column(name, self.quote(name), dbtype, default)
371            
372             # This only works for SQL Server. The MSAccessDatabase will
373             # wrap this method and override autoincrement.
374             colflags = int(row[9])
375             if ((colflags & DBCOLUMNFLAGS_ISFIXEDLENGTH)
376                 and not (colflags & DBCOLUMNFLAGS_WRITE)):
377                 c.autoincrement = True
378            
379             if dbtype in ("SMALLINT", "INTEGER", "TINYINT",
380                           "UNSIGNEDTINYINT", "UNSIGNEDSMALLINT",
381                           "UNSIGNEDINT", "BIGINT", "UNSIGNEDBIGINT"):
382                 c.hints['bytes'] = row[15]
383             elif dbtype in ("SINGLE", "DOUBLE"):
384                 c.hints['precision'] = row[15]
385                 c.hints['scale'] = row[16]
386             elif dbtype == "CURRENCY":
387                 # CURRENCY allows 15 places to the left of the decimal point,
388                 # and 4 places to the right.
389                 c.hints['precision'] = 19
390                 c.hints['scale'] = 4
391             elif dbtype in ("DECIMAL", "NUMERIC"):
392                 c.hints['precision'] = row[15]
393                 c.hints['scale'] = row[16]
394                 c.dbtype = "%s(%s, %s)" % (dbtype, row[15], row[16])
395             elif dbtype in ("BSTR", "VARIANT", "BINARY", "CHAR",
396                             "VARCHAR", "VARBINARY", "WCHAR", "VARWCHAR"):
397                 if row[13]:
398                     # row[13] will be a float
399                     c.hints['bytes'] = b = int(row[13])
400                 else:
401                     # I'm kinda guessing on this. If we use "MEMO" in an
402                     # MSAccess CREATE statement, it comes back as "WCHAR",
403                     # and seems to support over 65536 bytes.
404                     c.hints['bytes'] = b = (2 ** 31) - 1
405                 c.dbtype = "%s(%s)" % (c.dbtype, b)
406             elif dbtype in ("LONGVARCHAR", "LONGVARBINARY", "LONGVARWCHAR"):
407                 if row[13]:
408                     # row[13] will be a float
409                     c.hints['bytes'] = b = int(row[13])
410                     c.dbtype = "%s(%s)" % (c.dbtype, b)
411                 else:
412                     c.hints['bytes'] = 65535
413            
414             cols.append(c)
415         return cols
416    
417     def _get_indices(self, tablename=None, conn=None):
418         # cols will be
419         # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202),
420         # (u'INDEX_CATALOG', 202), (u'INDEX_SCHEMA', 202), (u'INDEX_NAME', 202),
421         # (u'PRIMARY_KEY', 11), (u'UNIQUE', 11), (u'CLUSTERED', 11), (u'TYPE', 18),
422         # (u'FILL_FACTOR', 3), (u'INITIAL_SIZE', 3), (u'NULLS', 3),
423         # (u'SORT_BOOKMARKS', 11), (u'AUTO_UPDATE', 11), (u'NULL_COLLATION', 3),
424         # (u'ORDINAL_POSITION', 19), (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72),
425         # (u'COLUMN_PROPID', 19), (u'COLLATION', 2), (u'CARDINALITY', 21),
426         # (u'PAGES', 3), (u'FILTER_CONDITION', 202), (u'INTEGRATED', 11)]
427         data, _ = self.fetch(adSchemaIndexes, conn=conn, schema=True)
428         indices = []
429         for row in data:
430             # I tried passing criteria to OpenSchema, but passing None is
431             # not the same as passing pythoncom.Empty (which errors).
432             if tablename and row[2] != tablename:
433                 continue
434             i = db.Index(row[5], self.quote(row[5]),
435                          row[2], row[17], row[6], row[7])
436             indices.append(i)
437         return indices
438    
439     def python_type(self, dbtype):
440         """Return a Python type which can store values of the given dbtype."""
441         if dbtype in ("DATE", "DBDATE"):
442             return datetime.date
443         elif dbtype == "DBTIME":
444             return datetime.time
445         elif dbtype in ("DATETIME", "DBTIMESTAMP"):
446             return datetime.datetime
447         elif dbtype in ("SMALLINT", "INTEGER", "TINYINT",
448                         "UNSIGNEDTINYINT", "UNSIGNEDSMALLINT",
449                         "UNSIGNEDINT"):
450             return int
451         elif dbtype in ("BIT", "BOOLEAN"):
452             return bool
453         elif dbtype in ("BIGINT", "UNSIGNEDBIGINT", "LONG"):
454             return long
455         elif dbtype in ("SINGLE", "DOUBLE", "DOUBLE PRECISION", "REAL"):
456             return float
457        
458         for t in ("DECIMAL", "NUMERIC", "CURRENCY"):
459             if dbtype.startswith(t):
460                 if db.decimal:
461                     return db.decimal.Decimal
462                 elif db.fixedpoint:
463                     return db.fixedpoint.FixedPoint
464        
465         for t in ("BSTR", "VARIANT", "BINARY", "CHAR", "MEMO", "TEXT",
466                   "VARCHAR", "LONGVARCHAR", "VARBINARY", "LONGVARBINARY"):
467             if dbtype.startswith(t):
468                 return str
469        
470         for t in ("WCHAR", "VARWCHAR", "LONGVARWCHAR"):
471             if dbtype.startswith(t):
472                 return unicode
473        
474         raise TypeError("Database type %s could not be converted "
475                         "to a Python type." % repr(dbtype))
476    
477     def _rename(self, oldtable, newtable):
478         conn = self.connection()
479         try:
480             cat = win32com.client.Dispatch(r'ADOX.Catalog')
481             cat.ActiveConnection = conn
482             cat.Tables(oldtable.name).Name = newtable.name
483         finally:
484             conn = None
485             cat = None
486    
487     def quote(self, name):
488         """Return name, quoted for use in an SQL statement."""
489         return '[' + name + ']'
490    
491     def _get_conn(self):
492         conn = win32com.client.Dispatch(r'ADODB.Connection')
493         conn.Open(self.Connect)
494         return conn
495    
496     def _del_conn(self, conn):
497         conn.Close()
498    
499     def execute(self, query, conn=None):
500         if conn is None:
501             conn = self.connection()
502         if isinstance(query, unicode):
503             query = query.encode(self.adaptertosql.encoding)
504         self.log(query, dejavu.logflags.SQL)
505         try:
506             conn.Execute(query)
507         except pywintypes.com_error, x:
508             x.args += (query, )
509             conn = None
510             raise
511    
512     def fetch(self, query, conn=None, schema=False):
513         """fetch(query, conn=None) -> rowdata, columns."""
514         if conn is None:
515             conn = self.connection()
516        
517         try:
518             if schema:
519                 res = conn.OpenSchema(query)
520             else:
521                 self.log(query, dejavu.logflags.SQL)
522                 res = win32com.client.Dispatch(r'ADODB.Recordset')
523                 if hasattr(conn, "conn"):
524                     # 'conn' is a ConnectionWrapper object, which .Open
525                     # won't accept. Pass the unwrapped connection instead.
526                     res.Open(query, conn.conn, adOpenForwardOnly, adLockReadOnly)
527                 else:
528                     res.Open(query, conn, adOpenForwardOnly, adLockReadOnly)
529         except pywintypes.com_error, x:
530             try:
531                 res.Close()
532             except:
533                 pass
534             x.args += (query, )
535             conn = None
536             # "raise x" here or we could get the traceback of the inner try.
537             raise x
538        
539         columns = [(x.Name, x.Type) for x in res.Fields]
540        
541         data = []
542         if not(res.BOF and res.EOF):
543             # We tried .MoveNext() and lots of Fields.Item() calls.
544             # Using GetRows() beats that time by about 2/3.
545             data = res.GetRows()
546             # Convert cols x rows -> rows x cols
547             data = zip(*data)
548         try:
549             res.Close()
550         except:
551             pass
552         conn = None
553        
554         return data, columns
555
556
class StorageManagerADO(db.StorageManagerDB):
    """StoreManager to save and retrieve Units via ADO 2.7.
    
    You must run makepy on ADO 2.7 before installing.
    """
    
    databaseclass = ADODatabase
    
    def version(self):
        """Report the Version property of a fresh, unopened ADODB.Connection."""
        ado_version = win32com.client.Dispatch(r'ADODB.Connection').Version
        return "ADO Version: %s" % ado_version
568
569
570
571 ###########################################################################
572 ##                                                                       ##
573 ##                             SQL Server                                ##
574 ##                                                                       ##
575 ###########################################################################
576
577
class AdapterToADOSQL_SQLServer(db.AdapterToSQL):
    """Adapter that renders Python values as SQL Server SQL (via ADO)."""
    
    encoding = 'ISO-8859-1'
    
    # (find, replace) pairs: double single quotes inside string literals.
    escapes = [("'", "''")]
    # (find, replace) pairs: bracket LIKE wildcards so they match literally.
    like_escapes = [("%", "[%]"), ("_", "[_]")]
    
    # Whole-(sub)expression truth values, e.g. "WHERE TRUE" or
    # "WHERE TRUE and 'a'.'b' = 3". These differ from coerce_bool_to_any,
    # which renders one operand of a comparison.
    bool_true = "(1=1)"
    bool_false = "(1=0)"
    
    def coerce_bool_to_any(self, value):
        """Return '1' for a truthy value and '0' otherwise."""
        if not value:
            return '0'
        return '1'
595
596
class TypeAdapter_SQLServer(db.TypeAdapter):
    """Pick SQL Server column types for Dejavu column definitions."""
    
    # Hm. Docs say 38, but I can't seem to get more than 12 working.
    # They must mean 38 binary digits; math.log(2 ** 38, 10) = 11.4+
    numeric_max_precision = 12
    numeric_max_bytes = 6
    
    def coerce_bool(self, col):
        return "BIT"
    
    def coerce_datetime_datetime(self, col):
        return "DATETIME"
    
    # Dates and times are stored in DATETIME columns as well.
    coerce_datetime_date = coerce_datetime_datetime
    coerce_datetime_time = coerce_datetime_datetime
    
    def int_type(self, bytes):
        """Return a datatype which can handle the given number of bytes."""
        for limit, typename in ((2, "SMALLINT"), (4, "INTEGER"), (8, "BIGINT")):
            if bytes <= limit:
                return typename
        # Wider than 8 bytes: fall back to NUMERIC, allowing two decimal
        # digits of precision per requested byte.
        return "NUMERIC(%s, 0)" % (bytes * 2)
    
    def coerce_str(self, col):
        # The bytes hint does not reflect the usual 4-byte base for varchar.
        size = int(col.hints.get('bytes', 255))
        
        if size > 8000 or size == 0:
            # An unlimited (0) or oversized length can't be compared on
            # SQL Server, so cap it at 8000 and tell the deployer.
            warnings.warn("You have defined a string property without "
                          "limiting its length. Microsoft SQL Server does "
                          "not allow comparisons on string fields larger "
                          "than 8000 characters. Some of your data may be "
                          "truncated.", errors.StorageWarning)
            size = 8000
        
        # 8000 bytes is the varchar/varbinary upper limit per the T-SQL docs.
        # More fields in the class, or a double-byte codepage, could still
        # push a record past the 8060-byte row maximum; that total is not
        # calculated here.
        return "VARCHAR(%s)" % size
653
654
class SQLServerColumnSet(ADOColumnSet):
    """ColumnSet that issues SQL Server DDL when columns are added/renamed."""
    
    def __setitem__(self, key, column):
        """Add column to the table with ALTER TABLE ... ADD, then store it.
        
        Raises ValueError if column.autoincrement is set and its dbtype
        cannot be an IDENTITY column.
        """
        t = self.table
        
        dbtype = column.dbtype
        if column.autoincrement:
            if dbtype not in ("BOOLEAN", "SMALLINT", "INTEGER", "BIGINT"):
                raise ValueError("SQL Server does not allow IDENTITY "
                                 "columns of type %s" % repr(dbtype))
            # column.default is used as the IDENTITY seed.
            dbtype = "%s IDENTITY(%s, 1) NOT NULL" % (dbtype, column.default)
        else:
            default = column.default or ""
            if default:
                # Use the owning database's SQL adapter (SQLServerDatabase
                # declares adaptertosql); nothing visible defines one on a
                # ColumnSet itself.
                default = t.db.adaptertosql.coerce(default, dbtype)
                dbtype = "%s DEFAULT %s" % (dbtype, default)
        
        # SQL Server doesn't use the "COLUMN" keyword with "ADD"
        t.db.execute("ALTER TABLE %s ADD %s %s;" %
                     (t.qname, column.qname, dbtype))
        dict.__setitem__(self, key, column)
    
    def _rename(self, oldcol, newcol):
        """Rename a column with sp_rename."""
        t = self.table
        
        def literal(name):
            # Double embedded single quotes so a name can't end the literal.
            return name.replace("'", "''")
        
        t.db.execute("EXEC sp_rename '%s.%s', '%s', 'COLUMN'" %
                     (literal(t.name), literal(oldcol.name),
                      literal(newcol.name)))
681
682
class SQLServerDatabase(ADODatabase):
    """ADODatabase specialized for Microsoft SQL Server (and MSDE)."""
    
    columnsetclass = SQLServerColumnSet
    adaptertosql = AdapterToADOSQL_SQLServer()
    typeadapter = TypeAdapter_SQLServer()
    
    def _execute_in_tempdb(self, sql):
        """Run sql on a new connection whose INITIAL CATALOG is tempdb.
        
        The target database may not exist yet (CREATE) or must not be in
        use (DROP), so connect to tempdb instead. The connection is closed
        even if the statement fails.
        """
        adoconn = win32com.client.Dispatch(r'ADODB.Connection')
        atoms = connatoms(self.Connect)
        atoms['INITIAL CATALOG'] = "tempdb"
        adoconn.Open("; ".join(["%s=%s" % (k, v) for k, v in atoms.iteritems()]))
        try:
            adoconn.Execute(sql)
        finally:
            adoconn.Close()
    
    def create_database(self):
        # This method hasn't been tested yet for SQL server (only MSDE).
        self._execute_in_tempdb("CREATE DATABASE %s" % self.qname)
        self.clear()
    
    def drop_database(self):
        self._execute_in_tempdb("DROP DATABASE %s;" % self.qname)
        self.clear()
    
    def __setitem__(self, key, table):
        """CREATE TABLE (plus its indices) for table and store it under key.
        
        Any existing entry for key is deleted first.
        Raises ValueError for an autoincrement column whose dbtype cannot
        be an IDENTITY column.
        """
        if key in self:
            del self[key]
        
        fields = []
        for column in table.columns.itervalues():
            dbtype = column.dbtype
            if column.autoincrement:
                if dbtype not in ("BOOLEAN", "SMALLINT", "INTEGER", "BIGINT"):
                    raise ValueError("SQL Server does not allow IDENTITY "
                                     "columns of type %s" % repr(dbtype))
                # column.default is used as the IDENTITY seed.
                dbtype = "%s IDENTITY(%s, 1) NOT NULL" % (dbtype, column.default)
            else:
                default = column.default or ""
                if default:
                    default = self.adaptertosql.coerce(default, dbtype)
                    dbtype = "%s DEFAULT %s" % (dbtype, default)
            fields.append('%s %s' % (column.qname, dbtype))
        
        self.execute('CREATE TABLE %s (%s);' %
                     (table.qname, ", ".join(fields)))
        
        for index in table.columns.indices.itervalues():
            self.execute('CREATE INDEX %s ON %s (%s);' %
                         (index.qname, table.qname,
                          self.quote(index.colname)))
        
        dict.__setitem__(self, key, table)
736
737
class StorageManagerADO_SQLServer(StorageManagerADO):
    """ADO storage manager backed by a SQLServerDatabase."""
    
    databaseclass = SQLServerDatabase
    
    def __init__(self, name, arena, allOptions={}):
        # The database name is the connection string's INITIAL CATALOG,
        # falling back to its DSN. NOTE: this writes 'name' into the
        # caller's allOptions dict.
        atoms = connatoms(allOptions['Connect'])
        dbname = atoms.get('INITIAL CATALOG') or atoms.get('DSN')
        allOptions['name'] = dbname
        db.StorageManagerDB.__init__(self, name, arena, allOptions)
    
    def _seq_UnitSequencerInteger(self, unit):
        """Reserve a unit using the table's AUTOINCREMENT field."""
        cls = unit.__class__
        table = self.db[cls.__name__]
        
        names, literals = [], []
        for key in cls.properties:
            col = table.columns[key]
            if col.autoincrement:
                # The server fills IDENTITY columns; leave them out.
                continue
            literals.append(self.db.adaptertosql.coerce(getattr(unit, key),
                                                        col.dbtype))
            names.append(col.qname)
        
        self.db.execute('INSERT INTO %s (%s) VALUES (%s);' %
                        (table.qname, ", ".join(names), ", ".join(literals)))
        
        # Grab the new ID. This is threadsafe because db.reserve has a mutex.
        # For some reason, using SCOPE_IDENTITY or IDENTITY failed (returned
        # None) when retrieving ID's just after a 99-thread-test ran. Moving
        # the multithreading test fixed it. IDENT_CURRENT worked regardless.
        # NOTE(review): IDENT_CURRENT is not limited to this session, so an
        # insert from another process/connection could interleave.
        rows, _ = self.db.fetch("SELECT IDENT_CURRENT('%s');" % table.qname)
        setattr(unit, cls.identifiers[0], rows[0][0])
774
775
776
777 ###########################################################################
778 ##                                                                       ##
779 ##                             MS Access                                 ##
780 ##                                                                       ##
781 ###########################################################################
782
783
class ADOSQLDecompiler_MSAccess(ADOSQLDecompiler):
    """Decompiler that renders dejavu expressions as Jet (MS Access) SQL."""
    
    # Comparison-operator spellings used for Jet SQL ("<>" for inequality).
    sql_cmp_op = ('<', '<=', '=', '<>', '>', '>=', 'in', 'not in')
    
    def dejavu_now(self):
        """Return the Jet SQL expression for the current date and time."""
        return "Now()"
    
    def dejavu_today(self):
        """Return the Jet SQL expression for today's date.

        DateValue() keeps only the date portion of Now().
        """
        return "DateValue(Now())"
    
    def dejavu_year(self, x):
        """Wrap the SQL expression *x* in Jet's Year() function."""
        return "".join(("Year(", x, ")"))
795
796
class TypeAdapter_MSAccess(db.TypeAdapter):
    """Choose MS Access (Jet) column types for dejavu property types."""
    
    # Hm. Docs say 28/38, but I can't seem to get more than 12 working.
    numeric_max_precision = 12
    numeric_max_bytes = 6
    
    def coerce_bool(self, col):
        """Booleans are stored in BIT columns."""
        return "BIT"
    
    # All three datetime kinds map to the single DATETIME column type.
    def coerce_datetime_datetime(self, col):
        return "DATETIME"
    
    def coerce_datetime_date(self, col):
        return "DATETIME"
    
    def coerce_datetime_time(self, col):
        return "DATETIME"
    
    def int_type(self, bytes):
        """Return the column type used for an integer of *bytes* bytes.

        Up to 2 bytes -> "INTEGER", up to 4 -> "LONG"; anything larger
        falls back to decimal/numeric ("DECIMAL").
        NOTE(review): in Jet DDL "INTEGER" may itself denote a 4-byte Long;
        confirm the 2-byte mapping is intended.
        """
        for limit, typename in ((2, "INTEGER"), (4, "LONG")):
            if bytes <= limit:
                return typename
        return "DECIMAL"
    
    def coerce_str(self, col):
        """Return VARCHAR(n) for bounded short strings, otherwise MEMO."""
        # The bytes hint shall not reflect the usual 4-byte base for varchar.
        size = int(col.hints.get('bytes', 255))
        
        # 255 chars is the upper limit for TEXT / VARCHAR in MS Access.
        if size != 0 and size <= 255:
            return "VARCHAR(%s)" % size
        
        # Zero (unbounded) or more than 255: MEMO. MEMO is 1 GB max when set
        # programatically (only 64K when set in Access UI). But then, 1 GB is
        # the limit for the whole DB. Note that OpenSchema will return a
        # DATA_TYPE of "WCHAR".
        return "MEMO"
830
831
832
class AdapterToADOSQL_MSAccess(db.AdapterToSQL):
    """Coerce Expression constants to ADO SQL."""
    
    encoding = 'ISO-8859-1'
    
    # Single quotes are doubled inside string literals; in LIKE patterns
    # the % and _ wildcards are bracketed so they match literally.
    escapes = [("'", "''")]
    like_escapes = [("%", "[%]"), ("_", "[_]")]
    
    def coerce_datetime_datetime_to_any(self, value):
        """Render a datetime as a #month/day/year hh:mm:ss# literal.

        Sub-second precision is not included.
        """
        datepart = '%s/%s/%s' % (value.month, value.day, value.year)
        timepart = '%02d:%02d:%02d' % (value.hour, value.minute, value.second)
        return '#' + datepart + ' ' + timepart + '#'
    
    def coerce_datetime_date_to_any(self, value):
        """Render a date as a #month/day/year# literal."""
        parts = [str(value.month), str(value.day), str(value.year)]
        return '#' + '/'.join(parts) + '#'
    
    def coerce_datetime_time_to_any(self, value):
        """Render a time as a #hh:mm:ss# literal (seconds precision)."""
        fields = [value.hour, value.minute, value.second]
        return '#%s#' % ':'.join(['%02d' % n for n in fields])
851
852
class MSAccessColumnSet(ADOColumnSet):
    """Column set whose insertions add columns to an Access table."""
    
    def __setitem__(self, key, column):
        """Add *column* to the underlying table, then register it by *key*.

        Autoincrement columns are declared as ``AUTOINCREMENT(seed, 1)``
        using ``column.default`` as the seed (1 if the default is None).
        Other columns get a DEFAULT clause only when ``column.default`` is
        truthy.
        """
        t = self.table
        
        dbtype = column.dbtype
        if column.autoincrement:
            seed = column.default
            if seed is None:
                # "AUTOINCREMENT(None, 1)" would be invalid SQL; use seed 1.
                seed = 1
            dbtype = "AUTOINCREMENT(%s, 1)" % seed
        else:
            default = column.default or ""
            if default:
                # Coerce with the owning database's adapter, the same object
                # that provides execute() below. The column set itself is not
                # shown to carry an adaptertosql attribute.
                default = t.db.adaptertosql.coerce(default, dbtype)
                dbtype = "%s DEFAULT %s" % (dbtype, default)
        
        # The optional COLUMN keyword is omitted after ADD.
        # NOTE(review): Jet documents the form "ADD COLUMN"; confirm the bare
        # "ADD" form is accepted by the target provider.
        t.db.execute("ALTER TABLE %s ADD %s %s;" %
                     (t.qname, column.qname, dbtype))
        dict.__setitem__(self, key, column)
871
872
class MSAccessDatabase(ADODatabase):
    """Database wrapper for Microsoft Access (Jet) accessed through ADO."""
    
    decompiler = ADOSQLDecompiler_MSAccess
    adaptertosql = AdapterToADOSQL_MSAccess()
    typeadapter = TypeAdapter_MSAccess()
    
    columnsetclass = MSAccessColumnSet
    
    poolsize = 0
    
    def connect(self):
        """Use one shared connection (db.SingleConnection) instead of a pool."""
        # MS Access can't use a pool, because there doesn't seem
        # to be a commit timeout.
        self.connection = db.SingleConnection(self._get_conn, self._del_conn)
    
    def _get_columns(self, tablename, conn=None):
        """Return the columns of *tablename* with ``autoincrement`` filled in.

        The column objects come from ADODatabase._get_columns. This override
        then sets each column's ``autoincrement`` from the ADO field property
        "ISAUTOINCREMENT" of a zero-row recordset over the table.

        If creating or opening that recordset raises pywintypes.com_error,
        the query text is appended to the exception's args before it is
        re-raised. The recordset is closed (best effort) on every path.
        """
        cols = ADODatabase._get_columns(self, tablename, conn)
        if conn is None:
            conn = self.connection()
        
        # Horrible hack to get autoincrement property: a query that returns
        # no rows still exposes per-field metadata on the recordset.
        query = "SELECT * FROM %s WHERE FALSE" % self.quote(tablename)
        res = None
        try:
            try:
                res = win32com.client.Dispatch(r'ADODB.Recordset')
                if hasattr(conn, "conn"):
                    # 'conn' is a ConnectionWrapper object, which .Open
                    # won't accept. Pass the unwrapped connection instead.
                    res.Open(query, conn.conn, adOpenForwardOnly, adLockReadOnly)
                else:
                    res.Open(query, conn, adOpenForwardOnly, adLockReadOnly)
            except pywintypes.com_error:
                # Attach the failing SQL for diagnosis. Nothing else has been
                # raised in between, so a bare "raise" re-raises this same
                # exception with its original traceback.
                err = sys.exc_info()[1]
                err.args += (query, )
                raise
            
            for c in cols:
                c.autoincrement = res.Fields(c.name).Properties("ISAUTOINCREMENT").Value
        finally:
            # Always close the recordset, including when reading the field
            # properties fails. Close errors are deliberately ignored.
            if res is not None:
                try:
                    res.Close()
                except Exception:
                    pass
            # Drop our reference to the connection.
            conn = None
        
        return cols
    
    def python_type(self, dbtype):
        """Return the Python type for *dbtype*; Jet's "LONG" maps to int."""
        if dbtype == "LONG":
            return int
        return ADODatabase.python_type(self, dbtype)
    
    def create_database(self):
        """Create the database via ADOX Catalog.Create, then forget tables."""
        # By not providing an Engine Type, it defaults to 5 = Access 2000.
        cat = win32com.client.Dispatch(r'ADOX.Catalog')
        cat.Create(self.Connect)
        cat.ActiveConnection.Close()
        self.clear()
    
    def drop_database(self):
        """Delete the database file named by self.name (if it exists)."""
        import os
        # This should accept relative or absolute paths
        if os.path.exists(self.name):
            os.remove(self.name)
        self.clear()
    
    def __setitem__(self, key, table):
        """Create *table* in the Access database and register it by *key*.

        Any table already registered under *key* is deleted first. All
        columns go into one CREATE TABLE statement, followed by one
        CREATE INDEX per entry in ``table.columns.indices``.

        Autoincrement columns become ``AUTOINCREMENT(seed, 1)`` using
        ``column.default`` as the seed (1 if the default is None). Other
        columns get a DEFAULT clause only when ``column.default`` is truthy.
        """
        if key in self:
            del self[key]
        
        fields = []
        for column in table.columns.itervalues():
            dbtype = column.dbtype
            if column.autoincrement:
                seed = column.default
                if seed is None:
                    # "AUTOINCREMENT(None, 1)" would be invalid SQL.
                    seed = 1
                dbtype = "AUTOINCREMENT(%s, 1)" % seed
            else:
                default = column.default or ""
                if default:
                    default = self.adaptertosql.coerce(default, dbtype)
                    dbtype = "%s DEFAULT %s" % (dbtype, default)
            fields.append('%s %s' % (column.qname, dbtype))
        
        self.execute('CREATE TABLE %s (%s);' %
                     (table.qname, ", ".join(fields)))
        
        for index in table.columns.indices.itervalues():
            self.execute('CREATE INDEX %s ON %s (%s);' %
                         (index.qname, table.qname,
                          self.quote(index.colname)))
        
        dict.__setitem__(self, key, table)
968
969
class StorageManagerADO_MSAccess(StorageManagerADO):
    """ADO storage manager for Microsoft Access (Jet) databases."""
    # Jet Connections and Recordsets are always free-threaded.
    
    databaseclass = MSAccessDatabase
    use_asterisk_to_get_all = True
    
    def __init__(self, name, arena, allOptions={}):
        # The database name is the first non-empty value among these
        # connect-string keys (or, failing that, the value of the last one).
        atoms = connatoms(allOptions['Connect'])
        dbname = None
        for atomname in ('DATA SOURCE', 'DATA SOURCE NAME', 'DBQ'):
            dbname = atoms.get(atomname)
            if dbname:
                break
        allOptions['name'] = dbname
        db.StorageManagerDB.__init__(self, name, arena, allOptions)
    
    def _seq_UnitSequencerInteger(self, unit):
        """Reserve a unit using the table's AUTOINCREMENT field.

        Writes every non-autoincrement property of *unit* in one INSERT,
        then reads ``SELECT @@IDENTITY`` and assigns the result to the
        unit's first identifier attribute.
        """
        cls = unit.__class__
        table = self.db[cls.__name__]
        
        names, literals = [], []
        for key in cls.properties:
            col = table.columns[key]
            if not col.autoincrement:
                # AUTOINCREMENT columns are left for the database to fill.
                names.append(col.qname)
                literals.append(self.db.adaptertosql.coerce(getattr(unit, key),
                                                            col.dbtype))
        
        sql = 'INSERT INTO %s (%s) VALUES (%s);' % (table.qname,
                                                     ", ".join(names),
                                                     ", ".join(literals))
        self.db.execute(sql)
        
        # Fetch the generated key. Per the original note, this is threadsafe
        # because db.reserve holds a mutex.
        rows, _coldefs = self.db.fetch("SELECT @@IDENTITY;")
        setattr(unit, cls.identifiers[0], rows[0][0])
    
    def _make_column(self, cls, key):
        """Build the column for cls.key; warn if a MEMO column is a join key."""
        col = StorageManagerADO._make_column(self, cls, key)
        if col.dbtype != "MEMO":
            return col
        
        for assoc in cls._associations.itervalues():
            if assoc.nearKey != key:
                continue
            warnings.warn("Memo fields cannot be used as join keys. "
                          "You should set %s.%s(hints={'bytes': 255})"
                          % (cls.__name__, key), errors.StorageWarning)
        return col
1017
1018
def gen_py():
    """Ensure win32com has generated Python support for ADO 2.7+.

    Prints a progress message, then returns the result of
    win32com.client.gencache.EnsureModule for the type library CLSID below
    (LCID 0, version 2.7).
    """
    typelib_clsid = '{EF53050B-882E-4776-B643-EDA472E8E3F2}'
    lcid, major, minor = 0, 2, 7
    print('Please wait while support for ADO 2.7+ is verified...')
    return win32com.client.gencache.EnsureModule(typelib_clsid, lcid, major, minor)
1024
1025
if __name__ == "__main__":
    # Run as a script: only verify/generate the ADO COM support module.
    gen_py()
Note: See TracBrowser for help on using the browser.