Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/storeado.py

Revision 123 (checked in by fumanchu, 8 years ago)

Fix for #36 (column names not correctly escaped). Changes to StorageManagerDB (and its subclasses):

  1. identifier(*atoms) method changed to sql_name(name, quoted=True).
  2. New column_name(classname, name, full=False, quoted=True) method.
  3. SQLDecompiler now calls the StorageManager's column_name method (so decompilers now take the StorageManager as a constructor argument).
  4. identifier_length is now sql_name_max_length.
  5. identifier_caseless is now sql_name_caseless.
  • Property svn:eol-style set to native
Line 
1 import sys
2 # Put COM in free-threaded mode. This first thread will have
3 # CoInitializeEx called automatically when pythoncom is imported.
4 sys.coinit_flags = 0
5 import pythoncom
6
7 import win32com.client
8 import pywintypes
9 import datetime
10
11 try:
12     import cPickle as pickle
13 except ImportError:
14     import pickle
15
16 try:
17     import fixedpoint
18 except ImportError:
19     pass
20
21 try:
22     # Builtin in Python 2.5?
23     decimal
24 except NameError:
25     try:
26         # Module in Python 2.3, 2.4
27         import decimal
28     except ImportError:
29         pass
30
31 import warnings
32
33 import dejavu
34 from dejavu import storage, logic
35 from dejavu.storage import db
36
# ADO cursor types, passed as the CursorType argument of Recordset.Open.
adOpenForwardOnly, adOpenKeyset, adOpenDynamic, adOpenStatic = 0, 1, 2, 3

# ADO lock types, passed as the LockType argument of Recordset.Open.
(adLockReadOnly, adLockPessimistic,
 adLockOptimistic, adLockBatchOptimistic) = 1, 2, 3, 4

# ADO cursor location: client-side cursor (see the commented-out
# CursorLocation line in StorageManagerADO.fetch).
adUseClient = 3

# Ordinal of 12/30/1899, the zero-Date for ADO (= 693594). COM dates count
# days from this day, so ordinal = com_days + zeroHour.
zeroHour = datetime.date(1899, 12, 30).toordinal()
51
52
def time_from_com(com_date):
    """Return a valid datetime.time from a COM date or time object.

    com_date is anything float() accepts: a COM date counting days from
    12/30/1899, whose fractional part is the time of day. For dates before
    that day (negative values) the time of day is the absolute value of the
    fractional part, so -1.25 and 1.25 both mean 06:00:00.

    The result has 1-second resolution, rounded to the nearest second. A
    value that rounds up to midnight wraps to 00:00:00; this function has
    no date to carry the extra day into. Callers building a full datetime
    must carry it themselves.
    """
    fraction = abs(float(com_date)) % 1
    # Round once, on the total number of seconds, so minute/hour carries
    # are exact instead of being patched up field by field.
    seconds = int(round(86400 * fraction)) % 86400
    hour, remainder = divmod(seconds, 3600)
    minute, second = divmod(remainder, 60)
    return datetime.time(hour, minute, second)
81
82
class AdapterFromADO(db.AdapterFromDB):
    """Coerce incoming values from ADO to Dejavu datatypes.

    Each coerce_* method takes the raw value read from an ADO recordset and
    coltype, the ADO field type code of its column. Codes used here: 0x06
    is Currency, 0x0b is Boolean.
    """
    
    def _date_parts(self, value):
        """Return (year, month, day) ints from a 'YYYYMMDD...' string."""
        return int(value[0:4]), int(value[4:6]), int(value[6:8])
    
    def _com_days_seconds(self, value):
        """Split a COM date into (whole days, seconds into that day).
        
        Whole days truncate toward zero, as int() does. The time of day is
        the absolute value of the fractional part (dates before 12/30/1899
        still carry a positive time of day), rounded to the nearest second.
        It may therefore equal 86400 when the value rounds up to the next
        midnight.
        """
        # For some reason, we need both float and int.
        number = float(value)
        days = int(number)
        seconds = int(round(abs(number - days) * 86400))
        return days, seconds
    
    def _from_currency(self, value):
        """Return an ADO Currency value as a float number of units.
        
        A pair is decoded as the (hi, lo) 32-bit halves of a 64-bit count of
        1/10000 units. Both halves are combined, so amounts that do not fit
        in the low word, and negative amounts, decode correctly.
        NOTE(review): the (hi, lo) layout is assumed from older pywin32
        builds. A non-pair value is assumed to be already scaled; confirm
        against your pywin32 version.
        """
        if isinstance(value, (tuple, list)):
            hi, lo = value
            return ((hi << 32) + (lo & 0xFFFFFFFF)) / 10000.0
        return float(value)
    
    def coerce_datetime_datetime(self, value, coltype):
        """Return a datetime.datetime, or None for an empty string.
        
        Strings are parsed as 'YYYYMMDD' (date only). Other values are
        treated as COM dates. Raises ValueError for unparseable strings.
        """
        # Illegal Date/Time values will crash the
        # app when using value.Format(). Therefore,
        # grab the value and figure the date ourselves.
        # Use 1-second resolution only.
        if isinstance(value, basestring):
            if not value:
                return None
            try:
                return datetime.datetime(*self._date_parts(value))
            except (ValueError, TypeError):
                raise ValueError("'%s' %s" % (value, type(value)))
        days, seconds = self._com_days_seconds(value)
        midnight = datetime.datetime.combine(
            datetime.date.fromordinal(days + zeroHour), datetime.time())
        # Adding a timedelta carries a rounded-up 86400 seconds into the
        # next day instead of overflowing the time of day.
        return midnight + datetime.timedelta(seconds=seconds)
    
    def coerce_datetime_date(self, value, coltype):
        """Return a datetime.date, or None for an empty string."""
        # See coerce_datetime_datetime
        if isinstance(value, basestring):
            if not value:
                return None
            try:
                return datetime.date(*self._date_parts(value))
            except (ValueError, TypeError):
                raise ValueError("'%s' %s" % (value, type(value)))
        return datetime.date.fromordinal(int(float(value)) + zeroHour)
    
    def coerce_datetime_time(self, value, coltype):
        """Return a datetime.time from a COM date/time value."""
        # See coerce_datetime_datetime
        return time_from_com(value)
    
    def coerce_fixedpoint_FixedPoint(self, value, coltype):
        """Return a fixedpoint.FixedPoint, decoding Currency columns."""
        if coltype == 0x06:
            # Currency
            value = self._from_currency(value)
        return fixedpoint.FixedPoint(value)
    
    def coerce_float(self, value, coltype):
        """Return a float, decoding Currency columns."""
        if coltype == 0x06:
            # Currency
            value = self._from_currency(value)
        return float(value)
    
    def coerce_int(self, value, coltype):
        """Return an int, or a bool for Boolean (0x0b) columns."""
        if coltype == 0x0b:
            # Boolean
            return value != 0
        return int(value)
    
    coerce_bool = coerce_int
    
    def coerce_unicode(self, value, coltype):
        """Return a unicode object, decoding byte strings as ISO-8859-1."""
        if isinstance(value, unicode):
            # For some reason, inValue is already a unicode object.
            return value
        if isinstance(value, (basestring, buffer)):
            try:
                return unicode(value, "ISO-8859-1")
            except UnicodeError:
                raise StandardError(type(value))
        return unicode(value)
153
154
155
class AdapterToADOFields(storage.Adapter):
    """Coerce outgoing values from Dejavu datatypes to ADO.Field types.
    
    Dates and times are converted to COM date numbers: days since
    12/30/1899 plus the time of day as a fraction of a day.
    """
    
    def noop(self, value):
        """Return value unchanged."""
        return value
    
    def coerce_bool(self, value):
        """Return True or False according to value's truth."""
        if value:
            return True
        return False
    
    def coerce_datetime_datetime(self, value):
        """Return a COM date number for a datetime, or None."""
        if value is None:
            return None
        days = self.coerce_datetime_date(value)
        fraction = self.coerce_datetime_time(value)
        # Before 12/30/1899 the day count is negative while the time of day
        # is still stored as a positive fraction. So the fraction must move
        # the number further from zero (-1.25 is 12/29/1899 06:00).
        if days < 0:
            return days - fraction
        return days + fraction
    
    def coerce_datetime_date(self, value):
        """Return the whole number of days from 12/30/1899 to value, or None."""
        if value is None:
            return None
        return value.toordinal() - zeroHour
    
    def coerce_datetime_time(self, value):
        """Return value's time of day as a fraction of a day, or None.
        
        Only whole seconds are used; microseconds are ignored.
        """
        if value is None:
            return None
        return ((value.second + (value.minute * 60) + (value.hour * 3600))
                / 86400.0)
    
    def do_pickle(self, value):
        """Pickle value using protocol 0."""
        # We must not use a pickle format other than 0, because binary
        # strings are not safe for all DB string fields. Pass the protocol
        # explicitly rather than relying on the interpreter's default.
        return pickle.dumps(value, 0)
    
    coerce_dict = do_pickle
    
    def coerce_fixedpoint_FixedPoint(self, value):
        """Return a FixedPoint as a float, or None."""
        if value is None:
            return None
        return float(value)
    
    coerce_float = noop
    coerce_int = noop
    
    coerce_list = do_pickle
    
    coerce_long = noop
    coerce_str = noop
    
    coerce_tuple = do_pickle
    
    coerce_unicode = noop
206
207
class ADOSQLDecompiler(db.SQLDecompiler):
    """Decompile Dejavu expressions into SQL for ADO back ends.
    
    String comparisons and LIKE are treated as case-insensitive here (see
    the comments below). Whenever a case-sensitive Python expression is
    translated into one of them, self.imperfect is set to True.
    NOTE(review): presumably db.SQLDecompiler then re-checks such results
    in Python; confirm there.
    """
    
    def _is_string_const(self, operand):
        """Return True if operand is a wrapped string constant."""
        return (isinstance(operand, db.ConstWrapper)
                and isinstance(operand.basevalue, basestring))
    
    def visit_COMPARE_OP(self, lo, hi):
        """Pop two operands and push the SQL for their comparison.
        
        lo and hi are the bytes of the COMPARE_OP argument. The combined
        opcode indexes self.sql_cmp_op (2 is ==, 3 is !=, 6 is 'in', 7 is
        'not in').
        
        Raises ValueError for any comparison against NULL other than
        == or !=.
        """
        op2, op1 = self.stack.pop(), self.stack.pop()
        if op1 is db.cannot_represent or op2 is db.cannot_represent:
            self.stack.append(db.cannot_represent)
            return
        
        op = lo + (hi << 8)
        if op in (6, 7):     # in, not in
            # Looking for text in a field. Use Like (reverse terms).
            # LIKE is case-insensitive in MS SQL Server (and there
            # doesn't seem to be a way around it). Use icontainedby
            # and just mark imperfect.
            value = self.dejavu_icontainedby(op1, op2)
            if op == 7:
                value = "NOT " + value
            self.stack.append(value)
            self.imperfect = True
        elif op1 == 'NULL' or op2 == 'NULL':
            # Test the other operand with IS [NOT] NULL.
            if op1 == 'NULL':
                operand = op2
            else:
                operand = op1
            if op == 2:
                self.stack.append(operand + " IS NULL")
            elif op == 3:
                self.stack.append(operand + " IS NOT NULL")
            else:
                raise ValueError("Non-equality Null comparisons not allowed.")
        else:
            if self._is_string_const(op1) or self._is_string_const(op2):
                # ADO comparison operators for strings are case-insensitive
                # by default. Rather than determine which columns in the DB
                # might be case-sensitive, just flag them all as imperfect.
                # The constant may be on either side of the operator.
                # TODO: might be possible to cast both to varbinary, but
                # that may cause problems with unicode columns.
                self.imperfect = True
            self.stack.append(op1 + " " + self.sql_cmp_op[op] + " " + op2)
    
    # --------------------------- Dispatchees --------------------------- #
    
    def attr_startswith(self, tos, arg):
        """x.startswith(arg) as LIKE 'arg%' (imperfect: LIKE ignores case)."""
        self.imperfect = True
        return tos + " LIKE '" + self.adapter.escape_like(arg) + "%'"
    
    def attr_endswith(self, tos, arg):
        """x.endswith(arg) as LIKE '%arg' (imperfect: LIKE ignores case)."""
        self.imperfect = True
        return tos + " LIKE '%" + self.adapter.escape_like(arg) + "'"
    
    def containedby(self, op1, op2):
        """Case-sensitive 'op1 in op2'; rendered case-insensitively, so imperfect."""
        self.imperfect = True
        return self.dejavu_icontainedby(op1, op2)
    
    def dejavu_icontainedby(self, op1, op2):
        """Return SQL for a case-insensitive 'op1 in op2'.
        
        A constant op1 is a substring search in the op2 column (LIKE).
        Otherwise op2 wraps a sequence and the result is 'op1 IN (...)'.
        """
        if isinstance(op1, db.ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            # LIKE is already case-insensitive in MS SQL Server;
            # so don't use LOWER().
            value = op2 + " LIKE '%" + self.adapter.escape_like(op1) + "%'"
        else:
            # Looking for field in (a, b, c)
            atoms = [self.adapter.coerce(x) for x in op2.basevalue]
            value = op1 + " IN (" + ", ".join(atoms) + ")"
        return value
    
    def dejavu_istartswith(self, x, y):
        # Like is already case-insensitive in ADO; so don't use LOWER().
        return x + " LIKE '" + self.adapter.escape_like(y) + "%'"
    
    def dejavu_iendswith(self, x, y):
        # Like is already case-insensitive in ADO; so don't use LOWER().
        return x + " LIKE '%" + self.adapter.escape_like(y) + "'"
    
    def dejavu_ieq(self, x, y):
        # = is already case-insensitive in ADO.
        return x + " = " + y
    
    def dejavu_now(self):
        return "getdate()"
    
    def dejavu_today(self):
        # Truncate the current date/time to midnight.
        return "DATEADD(dd, DATEDIFF(dd,0,getdate()), 0)"
    
    def func__builtin___len(self, x):
        return "Len(" + x + ")"
304
305
class StorageManagerADO(db.StorageManagerDB):
    """StoreManager to save and retrieve Units via ADO 2.7.
    
    You must run makepy on ADO 2.7 before installing.
    """
    
    close_connection_method = 'Close'
    decompiler = ADOSQLDecompiler
    fromAdapter = AdapterFromADO()
    
    def connatoms(self):
        """Parse self.connstring into a dict of {UPPERCASE KEY: value}.
        
        Keys and values are stripped of surrounding whitespace. Empty or
        whitespace-only segments (e.g. from a trailing "; ") are skipped.
        """
        atoms = {}
        for pair in self.connstring.split(";"):
            if not pair.strip():
                continue
            k, v = pair.split("=", 1)
            atoms[k.upper().strip()] = v.strip()
        return atoms
    
    def sql_name(self, name, quoted=True):
        """Return name as an SQL identifier, in [brackets] if quoted.
        
        A ']' inside the name is doubled so it cannot end the bracketed
        identifier early.
        """
        if quoted:
            name = '[' + name.replace(']', ']]') + ']'
        return name
    
    def _is_missing_database(self, err):
        """Return True if com_error err reports that the database is missing."""
        # args is (hresult, message, excepinfo, argerror); excepinfo may be
        # None when the provider supplies no details.
        if len(err.args) < 3 or not err.args[2]:
            return False
        excepinfo = err.args[2]
        # excepinfo[5] is the scode; -2147467259 is 0x80004005 (E_FAIL).
        if excepinfo[5] != -2147467259:
            return False
        msg = excepinfo[2] or ""
        return (
            # SQL Server: "Cannot open database requested in login
            # 'dejavu_test'. Login fails."
            msg.startswith("Cannot open database") or
            # MSAccess: "Could not find file
            # 'C:\Python23\Lib\site-packages\dejavu\storage\zoo.mdb'."
            msg.startswith("Could not find file"))
    
    def _get_conn(self):
        """Open and return a new ADODB.Connection for self.connstring.
        
        If the database does not exist and self.CreateIfMissing is true,
        create it and open once more. Any other com_error propagates.
        """
        conn = win32com.client.Dispatch(r'ADODB.Connection')
        try:
            conn.Open(self.connstring)
            return conn
        except pywintypes.com_error:
            if not (self._is_missing_database(sys.exc_info()[1])
                    and self.CreateIfMissing):
                raise
        self.create_database()
        conn.Open(self.connstring)
        return conn
    
    def execute(self, query, conn=None):
        """Execute query (no result set) on conn or a pooled connection.
        
        On com_error the query is appended to the error's args before the
        error is re-raised with its original traceback.
        """
        if conn is None:
            conn = self.connection()
        self.arena.log(query, dejavu.LOGSQL)
        try:
            conn.Execute(query)
        except pywintypes.com_error:
            sys.exc_info()[1].args += (query, )
            # Drop our COM reference before the traceback keeps it alive.
            conn = None
            raise
    
    def fetch(self, query, conn=None):
        """fetch(query, conn=None) -> rowdata, columns.
        
        rowdata is a list of row tuples (empty when there are no rows);
        columns is a list of (name, ADO type code) pairs. The recordset is
        always closed, even if reading it fails.
        """
        if conn is None:
            conn = self.connection()
        self.arena.log(query, dejavu.LOGSQL)
        
        res = win32com.client.Dispatch(r'ADODB.Recordset')
        # Uncomment the following to get .Recordcount
        # res.CursorLocation = adUseClient
        try:
            if self.threaded:
                # 'conn' will be a ConnectionWrapper object, which .Open
                # won't accept. Pass the unwrapped connection instead.
                res.Open(query, conn.conn, adOpenForwardOnly, adLockReadOnly)
            else:
                res.Open(query, conn, adOpenForwardOnly, adLockReadOnly)
        except pywintypes.com_error:
            err = sys.exc_info()[1]
            try:
                res.Close()
            except Exception:
                # Best effort: the recordset may never have opened.
                pass
            err.args += (query, )
            conn = None
            # Not a bare raise: the handled Close() error above would have
            # replaced the "current" exception under Python 2.
            raise err
        
        try:
            columns = [(f.Name, f.Type) for f in res.Fields]
            
            data = []
            if not (res.BOF and res.EOF):
                # We tried .MoveNext() and lots of Fields.Item() calls.
                # Using GetRows() beats that time by about 2/3.
                # GetRows() is cols x rows; zip converts to rows x cols.
                data = zip(*res.GetRows())
        finally:
            res.Close()
        conn = None
        
        return data, columns
    
    def version(self):
        """Return a string naming the installed ADO version."""
        adoconn = win32com.client.Dispatch(r'ADODB.Connection')
        return "ADO Version: %s" % adoconn.Version
403
404
405 ###########################################################################
406 ##                                                                       ##
407 ##                             SQL Server                                ##
408 ##                                                                       ##
409 ###########################################################################
410
411
class AdapterToADOSQL_SQLServer(db.AdapterToSQL):
    """Coerce Expression constants to SQL literals for MS SQL Server."""
    
    # (old, new) replacement pairs: single quotes in string literals are
    # doubled; LIKE wildcards are neutralized by wrapping them in brackets.
    escapes = [("'", "''")]
    like_escapes = [("%", "[%]"), ("_", "[_]")]
    
    # Unlike coerce_bool, which renders one side of a comparison, these
    # stand in for a whole (sub)expression that is constant True or False,
    # e.g. "WHERE TRUE" or "WHERE TRUE and 'a'.'b' = 3".
    bool_true = "(1=1)"
    bool_false = "(1=0)"
    
    def coerce_bool(self, value):
        """Render a truth value as the literal '1' or '0'."""
        if not value:
            return '0'
        return '1'
427
428
class FieldTypeAdapter_SQLServer(db.FieldTypeAdapter):
    """Map Dejavu property types to MS SQL Server column types."""
    
    numeric_max_precision = 38
    
    # Largest length (in bytes) accepted in a VARCHAR(n) declaration.
    varchar_max_bytes = 8000
    
    def coerce_bool(self, cls, key): return u"BIT"
    
    def coerce_datetime_datetime(self, cls, key):
        return u"DATETIME"
    
    def coerce_datetime_date(self, cls, key):
        return u"DATETIME"
    
    def coerce_datetime_time(self, cls, key):
        return u"DATETIME"
    
    def coerce_str(self, cls, key):
        """Return a VARCHAR type sized by the property's 'bytes' hint.
        
        A missing/zero hint, or one above varchar_max_bytes, warns
        (dejavu.StorageWarning) and uses varchar_max_bytes instead.
        """
        # The bytes hint does not reflect the usual 4-byte base for varchar.
        prop = getattr(cls, key)
        size = int(prop.hints.get(u'bytes', '0'))
        if size == 0:
            # TEXT/NTEXT columns cannot be compared, so unlimited strings
            # are denied: fall back to the largest VARCHAR and warn.
            warnings.warn("You have defined a string property without "
                          "limiting its length. Microsoft SQL Server does "
                          "not allow comparisons on string fields larger "
                          "than 8000 characters. Some of your data may be "
                          "truncated.", dejavu.StorageWarning)
            size = self.varchar_max_bytes
        elif size > self.varchar_max_bytes:
            # A larger VARCHAR(n) would be rejected outright by CREATE TABLE.
            warnings.warn("String property %s.%s requests %s bytes, but "
                          "Microsoft SQL Server limits VARCHAR columns to "
                          "%s bytes. Some of your data may be truncated."
                          % (cls.__name__, key, size, self.varchar_max_bytes),
                          dejavu.StorageWarning)
            size = self.varchar_max_bytes
        # 8000 *bytes* is the absolute upper limit, based on T_SQL docs for
        # varchar/varbinary. If there are further fields defined for the
        # class, or the codepage uses a double-byte character set, we still
        # might exceed the max size (8060) for a record. We could calc the
        # total requested record size, and adjust accordingly. Meh.
        return u"VARCHAR(%s)" % size
464
465
class StorageManagerADO_SQLServer(StorageManagerADO):
    """ADO storage manager for Microsoft SQL Server."""
    
    typeAdapter = FieldTypeAdapter_SQLServer()
    toAdapter = AdapterToADOSQL_SQLServer()
    
    def __init__(self, name, arena, allOptions=None):
        """Read the 'Connect' option; dbname is its INITIAL CATALOG.
        
        Raises KeyError if either is missing.
        """
        if allOptions is None:
            allOptions = {}
        db.StorageManagerDB.__init__(self, name, arena, allOptions)
        
        self.connstring = allOptions[u'Connect']
        atoms = self.connatoms()
        self.dbname = atoms[u'INITIAL CATALOG']
    
    def _execute_outside_db(self, sql):
        """Run sql on a fresh connection to tempdb, always closing it.
        
        Used to create or drop self.dbname, which cannot be the catalog
        of the connection doing the work.
        """
        atoms = self.connatoms()
        atoms['INITIAL CATALOG'] = "tempdb"
        adoconn = win32com.client.Dispatch(r'ADODB.Connection')
        adoconn.Open("; ".join(["%s=%s" % (k, v) for k, v in atoms.items()]))
        try:
            adoconn.Execute(sql)
        finally:
            adoconn.Close()
    
    def create_database(self):
        """Create the database named by INITIAL CATALOG."""
        # This method hasn't been tested yet for SQL server.
        self._execute_outside_db("CREATE DATABASE %s" % self.sql_name(self.dbname))
    
    def drop_database(self):
        """Drop the database named by INITIAL CATALOG."""
        self._execute_outside_db("DROP DATABASE %s;" % self.sql_name(self.dbname))
494
495
496 ###########################################################################
497 ##                                                                       ##
498 ##                             MS Access                                 ##
499 ##                                                                       ##
500 ###########################################################################
501
502
class ADOSQLDecompiler_MSAccess(ADOSQLDecompiler):
    """SQL decompiler using MS Access (Jet) date and comparison syntax."""
    
    # SQL operator for each comparison opcode 0..7 used by visit_COMPARE_OP.
    sql_cmp_op = ('<', '<=', '=', '<>', '>', '>=', 'in', 'not in')
    
    def dejavu_year(self, x):
        """SQL for the year of date expression x."""
        return "Year(" + x + ")"
    
    def dejavu_today(self):
        """SQL for the current date (no time part)."""
        return "DateValue(Now())"
    
    def dejavu_now(self):
        """SQL for the current date and time."""
        return "Now()"
514
515
class FieldTypeAdapter_MSAccess(db.FieldTypeAdapter):
    """Map Dejavu property types to MS Access (Jet SQL) column types."""
    
    numeric_max_precision = 15
    
    def coerce_bool(self, cls, key): return u"BIT"
    
    def coerce_datetime_datetime(self, cls, key): return u"DATETIME"
    def coerce_datetime_date(self, cls, key): return u"DATETIME"
    def coerce_datetime_time(self, cls, key): return u"DATETIME"
    
    def numeric_type(self, cls, key, precision, scale):
        """Return the Access column type for the given precision and scale.
        
        Precision above numeric_max_precision is clamped, with a
        dejavu.StorageWarning. Scale above 4 only warns.
        """
        if precision > self.numeric_max_precision:
            warnings.warn("Decimal precision %s > maximum %s for %s.%s, "
                          "using %s. Values may be stored incorrectly."
                          % (precision, self.numeric_max_precision,
                             cls.__name__, key, self.__class__.__name__),
                          dejavu.StorageWarning)
            precision = self.numeric_max_precision
        if scale > 4:
            warnings.warn("Decimal scale %s > maximum 4 for %s.%s, "
                          "using %s. Values may be stored incorrectly."
                          % (scale, cls.__name__, key,
                             self.__class__.__name__),
                          dejavu.StorageWarning)
        
        # MS Access doesn't let us control precision and scale directly.
        # From http://support.microsoft.com/?kbid=104977
        # ORACLE number            Microsoft Access data type
        # ---------------------------------------------------
        # Scale = 0 and
        #     precision <= 4       Integer
        #     precision <= 9       Long Integer
        #     precision <= 15      Double
        # Scale > 0 and  <= 4
        #     precision <= 15      Double
        # Scale > 4 and/or
        #     precision > 15       Text
        if scale == 0:
            if precision <= 4:
                return "INTEGER"
            elif precision <= 9:
                return "LONG"
        return "DOUBLE"
    
    def coerce_decimal_Decimal(self, cls, key):
        """Column type for a Decimal, from its 'precision'/'scale' hints."""
        prop = getattr(cls, key)
        precision = int(prop.hints.get('precision', '0'))
        if precision == 0:
            precision = decimal.getcontext().prec
        # Assume most people use decimal for money; default scale = 2.
        scale = int(prop.hints.get(u'scale', 2))
        return self.numeric_type(cls, key, precision, scale)
    
    def coerce_fixedpoint_FixedPoint(self, cls, key):
        """Column type for a FixedPoint, from its 'precision'/'scale' hints."""
        prop = getattr(cls, key)
        precision = int(prop.hints.get('precision', '0'))
        if precision == 0:
            precision = self.numeric_max_precision
        # Assume most people use decimal for money; default scale = 2.
        scale = int(prop.hints.get(u'scale', 2))
        return self.numeric_type(cls, key, precision, scale)
    
    def coerce_int(self, cls, key):
        """Column type for an int, from its 'bytes' hint (default 4)."""
        prop = getattr(cls, key)
        size = int(prop.hints.get(u'bytes', '4'))
        if size == 1:
            # BIT is the boolean column type here (see coerce_bool), so a
            # 1-byte int stored in it would come back as True/False.
            # SMALLINT holds any signed or unsigned 1-byte value.
            return u"SMALLINT"
        return u"INTEGER"
    
    def coerce_long(self, cls, key):
        """Column type for a long, sized from its 'bytes' hint.
        
        Without a hint the widest numeric precision is assumed.
        """
        prop = getattr(cls, key)
        size = int(prop.hints.get(u'bytes', 0))
        if size:
            # Decimal digits in the largest unsigned value of that width.
            precision = len(str(256 ** size - 1))
        else:
            precision = self.numeric_max_precision
        return self.numeric_type(cls, key, precision, 0)
    
    def coerce_str(self, cls, key):
        """Column type for a str: VARCHAR(n) up to 255 bytes, else MEMO."""
        # The bytes hint shall not reflect the usual 4-byte base for varchar.
        prop = getattr(cls, key)
        size = int(prop.hints.get(u'bytes', '0'))
        if size and size <= 255:
            # 255 chars is the upper limit for TEXT / VARCHAR in MS Access.
            return u"VARCHAR(%s)" % size
        else:
            # MEMO is 1 GB max when set programatically (only 64K when set
            # in Access UI). But then, 1 GB is the limit for the whole DB.
            for assoc in cls._associations.itervalues():
                if assoc.nearKey == key:
                    warnings.warn("Memo fields cannot be used as join keys. "
                                  "You should set %s.%s(hints={'bytes': 255})"
                                  % (cls.__name__, key),
                                  dejavu.StorageWarning)
            return u"MEMO"
608
609
class AdapterToADOSQL_MSAccess(db.AdapterToSQL):
    """Coerce Expression constants to ADO SQL.
    
    Date and time literals use the Access #...# delimiters, with dates
    written month/day/year.
    """
    
    # (old, new) replacement pairs: doubled single quotes in literals;
    # bracketed LIKE wildcards.
    escapes = [("'", "''")]
    like_escapes = [("%", "[%]"), ("_", "[_]")]
    
    def coerce_datetime_datetime(self, value):
        """Render a datetime as #m/d/yyyy hh:mm:ss#."""
        datepart = u'%s/%s/%s' % (value.month, value.day, value.year)
        timepart = u'%02d:%02d:%02d' % (value.hour, value.minute, value.second)
        return u'#' + datepart + u' ' + timepart + u'#'
    
    def coerce_datetime_date(self, value):
        """Render a date as #m/d/yyyy#."""
        month, day, year = value.month, value.day, value.year
        return u'#%s/%s/%s#' % (month, day, year)
    
    def coerce_datetime_time(self, value):
        """Render a time as #hh:mm:ss#."""
        hms = (value.hour, value.minute, value.second)
        return u'#%02d:%02d:%02d#' % hms
626
627
class StorageManagerADO_MSAccess(StorageManagerADO):
    """ADO storage manager for Microsoft Access database files."""
    # Jet Connections and Recordsets are always free-threaded.
    
    decompiler = ADOSQLDecompiler_MSAccess
    typeAdapter = FieldTypeAdapter_MSAccess()
    toAdapter = AdapterToADOSQL_MSAccess()
    
    def __init__(self, name, arena, allOptions=None):
        """Read the 'Connect' option (KeyError if missing) and the file path."""
        if allOptions is None:
            allOptions = {}
        db.StorageManagerDB.__init__(self, name, arena, allOptions)
        
        self.connstring = allOptions[u'Connect']
        atoms = self.connatoms()
        # Database file path; None if the connection string names none.
        self.dbname = (atoms.get(u'DATA SOURCE') or
                       atoms.get(u'DATA SOURCE NAME') or
                       atoms.get(u'DBQ'))
        # MS Access can't use a pool, because there doesn't seem
        # to be a commit timeout.
        self.pool = None
        self.threaded = False
        self.debug_connections = True
    
    def create_database(self):
        """Create the database file named in the connection string via ADOX."""
        # By not providing an Engine Type, it defaults to 5 = Access 2000.
        cat = win32com.client.Dispatch(r'ADOX.Catalog')
        cat.Create(self.connstring)
        cat.ActiveConnection.Close()
    
    def drop_database(self):
        """Delete the database file, if one is named and exists."""
        import os
        # This should accept relative or absolute paths. dbname is None
        # when the connection string named no file; nothing to delete.
        if self.dbname and os.path.exists(self.dbname):
            os.remove(self.dbname)
660
661
def gen_py():
    """Ensure makepy-generated support exists for the ADO 2.7+ type library.
    
    Returns the result of win32com.client.gencache.EnsureModule for that
    type library (CLSID below, LCID 0, version 2.7).
    """
    # Auto generate .py support for ADO 2.7+
    print('Please wait while support for ADO 2.7+ is verified...')
    ado_typelib = '{EF53050B-882E-4776-B643-EDA472E8E3F2}'
    major, minor = 2, 7
    return win32com.client.gencache.EnsureModule(ado_typelib, 0, major, minor)


if __name__ == '__main__':
    gen_py()
Note: See TracBrowser for help on using the browser.