Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/storeado.py

Revision 40 (checked in by fumanchu, 8 years ago)

1. Added storepypgsql (Postgres) + tests.
2. Minor bugs, omissions in storeado.

Line 
1 import sys
2 # Put COM in free-threaded mode
3 sys.coinit_flags = 0
4
5 import win32com.client
6 import pywintypes
7 import pythoncom
8 import threading
9 import datetime
10 try:
11     import cPickle as pickle
12 except ImportError:
13     import pickle
14 from types import FunctionType
15
16 try:
17     import fixedpoint
18 except ImportError:
19     pass
20
21 import dejavu
22 from dejavu import storage, codewalk, logic
23 import recur
24
# ADO cursor types (Recordset.Open cursorType argument).
(adOpenForwardOnly, adOpenKeyset,
 adOpenDynamic, adOpenStatic) = range(4)

# ADO lock types (Recordset.Open lockType argument).
(adLockReadOnly, adLockPessimistic,
 adLockOptimistic, adLockBatchOptimistic) = range(1, 5)

adModeShareExclusive = 12

# ADO object states (Connection.State / Recordset.State).
adStateClosed = 0
adStateOpen, adStateConnecting, adStateExecuting, adStateFetching = 1, 2, 4, 8

# COM/ADO dates count days from midnight, 30 December 1899; that day's
# proleptic Gregorian ordinal is 693594.
zeroHour = datetime.date(1899, 12, 30).toordinal()
45
46
def time_from_com(com_date):
    """Return a valid (day, datetime.time) from a COM date or time object.

    COM dates keep the time of day as the *absolute value* of the
    fractional part, even for negative (pre-1899-12-30) dates: -1.25 is
    29 Dec 1899 at 06:00. Python's ``%`` on a negative float returns the
    complement (-1.25 % 1 == 0.75, i.e. 18:00), so take abs() first.

    The result is whatever recur.sane_time(0, hour, mins, sec) returns;
    presumably it normalizes a seconds value that rounded up to 60.
    """
    fraction = abs(float(com_date)) % 1
    hour, mins = divmod(86400 * fraction, 3600)
    mins, sec = divmod(mins, 60)
    # Must do both int() and round() or we'll be up to 1 second off.
    hour = int(round(hour))
    mins = int(round(mins))
    sec = int(round(sec))
    return recur.sane_time(0, hour, mins, sec)
56
57
class AdapterFromADO(storage.Adapter):
    """Coerce incoming values from ADO to Dejavu datatypes.

    Each coerce_* method receives a (fieldType, value) pair: the ADO
    Field.Type code of the column and the raw value read from the
    recordset. A value of None always coerces to None.
    """
    
    def __init__(self, unit=None):
        # The Unit whose properties consume() populates.
        self.unit = unit
    
    def consume(self, key, value):
        """Coerce a (fieldType, value) pair and store it as unit property key."""
        expectedType = self.unit.__class__.property_type(key)
        value = self.coerce(value, expectedType)
        # Set the attribute directly to avoid __set__ overhead.
        self.unit._properties[key] = value
    
    def pickle(self, value):
        """Unpickle a stored value (used for dict, list and tuple properties)."""
        aType, value = value
        if value is None:
            return None
        # Coerce to str for pickle.loads restriction.
        # NOTE(review): this trusts anyone able to write to the database.
        value = str(value)
        return pickle.loads(value)
    
    def coerce_datetime_datetime(self, value):
        # Illegal Date/Time values will crash the
        # app when using value.Format(). Therefore,
        # grab the value and figure the date ourselves.
        # Use 1-second resolution only.
        aType, value = value
        if value is None:
            return None
        elif isinstance(value, basestring):
            # 'YYYYMMDD...' strings: only the date portion is used.
            return datetime.datetime(int(value[0:4]), int(value[4:6]),
                                     int(value[6:8]))
        # For some reason, we need both float and int. int() truncates
        # toward zero, which is how COM separates the day of a negative date.
        aDate = datetime.date.fromordinal(int(float(value)) + zeroHour)
        day, aTime = time_from_com(value)
        # The day element from time_from_com is presumably the carry when
        # the seconds round up past 23:59:59; apply it instead of dropping it.
        if day:
            aDate += datetime.timedelta(days=day)
        return datetime.datetime.combine(aDate, aTime)
    
    def coerce_datetime_date(self, value):
        # See coerce_datetime_datetime
        aType, value = value
        if value is None:
            return None
        elif isinstance(value, basestring):
            return datetime.date(int(value[0:4]), int(value[4:6]),
                                 int(value[6:8]))
        else:
            return datetime.date.fromordinal(int(float(value)) + zeroHour)
    
    def coerce_datetime_time(self, value):
        # See coerce_datetime_datetime
        aType, value = value
        if value is None:
            return None
        day, aTime = time_from_com(value)
        return aTime
    
    def coerce_datetime_timedelta(self, value):
        """Decode float days into a timedelta (1-second resolution)."""
        aType, value = value
        if value is None:
            return None
        days, fraction = divmod(float(value), 1)
        # Round rather than truncate: floating-point error can leave a
        # whole number of seconds as N - 1e-10, which int() turns into N - 1.
        return datetime.timedelta(days, int(round(fraction * 86400)))
    
    coerce_dict = pickle
    
    def coerce_fixedpoint_FixedPoint(self, value):
        aType, value = value
        if value is None:
            return None
        if aType == 0x06:
            # Currency: only element [1] of the value is used.
            # NOTE(review): if this tuple is a (high, low) 32-bit split,
            # large or negative amounts also need the high word -- confirm
            # against the pywin32 version in use.
            value = value[1] / 10000.0
        return fixedpoint.FixedPoint(value)
    
    def coerce_float(self, value):
        aType, value = value
        if value is None:
            return None
        if aType == 0x06:
            # Currency; see coerce_fixedpoint_FixedPoint.
            value = value[1] / 10000.0
        return float(value)
    
    def coerce_int(self, value):
        aType, value = value
        if value is None:
            return None
        if aType == 0x0b:
            # Boolean
            return value != 0
        return int(value)
    
    def coerce_bool(self, value):
        """Return a real bool, whatever the column type."""
        aType, value = value
        if value is None:
            return None
        if aType == 0x0b:
            # Boolean
            return value != 0
        # Same truthiness as int(value), but typed as bool.
        return bool(int(value))
    
    coerce_list = pickle
    
    def coerce_long(self, value):
        aType, value = value
        if value is None:
            return None
        return long(value)
    
    def coerce_str(self, value):
        aType, value = value
        if value is None:
            return None
        return str(value)
    
    coerce_tuple = pickle
    
    def coerce_unicode(self, value):
        aType, value = value
        if value is None:
            return None
        if isinstance(value, unicode):
            # For some reason, inValue is already a unicode object.
            return value
        if isinstance(value, str):
            # ISO-8859-1 maps every byte to a code point, so this decode
            # cannot raise UnicodeError.
            return unicode(value, "ISO-8859-1")
        return unicode(value)
183
184
185
class AdapterToADOFields(storage.Adapter):
    """Coerce outgoing values from Dejavu datatypes to ADO.Field types."""
    
    def noop(self, value):
        """Pass the value through unchanged."""
        return value
    
    def coerce_bool(self, value):
        # Reduce any value (including None) to a plain True/False.
        if value:
            return True
        return False
    
    def coerce_datetime_datetime(self, value):
        """Return a datetime as a COM date: float days since 1899-12-30."""
        if value is None:
            return None
        days = self.coerce_datetime_date(value)
        fraction = self.coerce_datetime_time(value)
        # COM stores the time of day as the absolute value of the fractional
        # part, so before 1899-12-30 the fraction extends away from zero:
        # 29 Dec 1899 06:00 is -1.25, whereas -1 + 0.25 would decode as
        # 30 Dec 1899 18:00.
        if days < 0:
            return days - fraction
        return days + fraction
    
    def coerce_datetime_date(self, value):
        """Return whole days since 1899-12-30 (the COM zero date)."""
        if value is None:
            return None
        return value.toordinal() - zeroHour
    
    def coerce_datetime_time(self, value):
        """Return the time of day as a fraction of a day (1-second resolution)."""
        if value is None:
            return None
        return ((value.second + (value.minute * 60) + (value.hour * 3600))
                / 86400.0)
    
    def coerce_datetime_timedelta(self, value):
        """Return a timedelta as float days (1-second resolution).

        This matches AdapterToADOSQL, and AdapterFromADO decodes it back.
        """
        if value is None:
            return None
        return value.days + (value.seconds / 86400.0)
    
    def pickle(self, value):
        # We must not use a pickle format other than 0, because binary
        # strings are not safe for all DB string fields. Pass the protocol
        # explicitly rather than relying on the interpreter's default.
        return pickle.dumps(value, 0)
    
    coerce_dict = pickle
    
    def coerce_fixedpoint_FixedPoint(self, value):
        if value is None:
            return None
        return float(value)
    
    coerce_float = noop
    coerce_int = noop
    
    coerce_list = pickle
    
    coerce_long = noop
    coerce_str = noop
    
    coerce_tuple = pickle
    
    coerce_unicode = noop
236
237
class AdapterToADOSQL(storage.Adapter):
    """Render Expression constants as SQL literal text for ADO queries.

    Each coerce_* method returns the SQL source text for one value type.
    """
    
    def tostr(self, value):
        """Render a number through its plain str() form."""
        return str(value)
    
    def coerce_NoneType(self, value):
        return "Null"
    
    def coerce_bool(self, value):
        if not value:
            return 'False'
        return 'True'
    
    def coerce_datetime_datetime(self, value):
        # #month/day/year hh:mm:ss# literal.
        parts = (value.month, value.day, value.year,
                 value.hour, value.minute, value.second)
        return u'#%s/%s/%s %02d:%02d:%02d#' % parts
    
    def coerce_datetime_date(self, value):
        parts = (value.month, value.day, value.year)
        return u'#%s/%s/%s#' % parts
    
    def coerce_datetime_time(self, value):
        parts = (value.hour, value.minute, value.second)
        return u'#%02d:%02d:%02d#' % parts
    
    def coerce_datetime_timedelta(self, value):
        # Float days; microseconds are ignored.
        return repr(value.days + (value.seconds / 86400.0))
    
    coerce_fixedpoint_FixedPoint = tostr
    coerce_float = tostr
    coerce_int = tostr
    
    def coerce_list(self, value):
        rendered = [self.coerce(item) for item in value]
        return "(%s)" % ", ".join(rendered)
    
    coerce_long = tostr
    
    def coerce_str(self, value):
        # Double embedded quotes, then bracket-escape the LIKE wildcards.
        # NOTE(review): the wildcard escaping is applied to every string
        # literal, including ones compared with '=' (x.Name == 'a_b' yields
        # "... = 'a[_]b'"), where brackets are not special. Confirm this
        # against the target database.
        for old, new in ((u"'", u"''"), ("%", "[%]"), ("_", "[_]")):
            value = value.replace(old, new)
        return "'" + value + "'"
    
    coerce_tuple = coerce_list
    
    coerce_unicode = coerce_str
285
286
def icontainedby(op1, op2, notin=False):
    """Return SQL testing whether op1 is contained in op2.

    op1 and op2 are SQL fragments already built by the decompiler:
    - if op2 is a bracketed field reference ("[tbl].[col]"), op1 is a
      quoted string literal and the test is a Like '%...%' match;
    - if op2 is "()" (an empty sequence), the test is always false;
    - otherwise op2 is a parenthesized value list and "in" is used.
    If notin is true, the result is negated with "not".
    """
    # This test doesn't work right, now that we use lists as
    # well as tuples with IN. Need a way to mark field refs.
    if op2.startswith("[") and op2.endswith("]"):
        # Looking for text in a field. Use Like (reverse terms).
        value = op2 + " Like '%" + op1[1:-1] + "%'"
    elif op2 == "()":
        # "x in ()" is a SQL syntax error, while in Python it is always
        # False: emit an always-false predicate instead.
        value = "1=0"
    else:
        # Looking for field in (a, b, c)
        value = op1 + " in " + op2
    if notin:
        value = "not " + value
    return value
299
300
# Stack sentinels: unique marker objects that the decompiler pushes onto
# its stack and later recognizes by identity.
#   cannot_represent -- a sub-expression with no SQL equivalent
#   table_arg        -- the lambda's positional (Unit) argument
#   kw_arg           -- the lambda's keyword-argument mapping
cannot_represent, table_arg, kw_arg = object(), object(), object()
305
class ADOSQLDecompiler(codewalk.LambdaDecompiler):
    """ADOSQLDecompiler(store, unitClass, expr, adapter=AdapterToADOSQL()).
    
    Produce SQL from a supplied Expression object, with a lambda of the form:
        lambda x, **kw: ...
    
    Attributes of x (or whatever the name of the first argument is) will be
    mapped to table columns. Keyword arguments should be bound to the
    Expression before calling this decompiler.
    
    code() returns (sql, imperfect). When imperfect is True the SQL only
    approximates the Expression, and each Unit it returns must still be
    checked with expr.evaluate(unit).
    """
    
    sql_cmp_op = ('<', '<=', '=', '!=', '>', '>=', 'in', 'not in')
    functions = {dejavu.icontains: lambda x, y: x + " Like '%" + y[1:-1] + "%'",
                 dejavu.icontainedby: icontainedby,
                 dejavu.istartswith: lambda x, y: x + " Like '" + y[1:-1] + "%'",
                 dejavu.iendswith: lambda x, y: x + " Like '%" + y[1:-1] + "'",
                 dejavu.ieq: lambda x, y: x + " = " + y,
                 dejavu.now: lambda: "getdate()",
                 dejavu.today: lambda: "DATEADD(dd, DATEDIFF(dd,0,getdate()), 0)",
                 dejavu.year: lambda x: "YEAR(" + x + ")",
                 }
    
    def __init__(self, store, unitClass, expr, adapter=AdapterToADOSQL()):
        # Name the table exactly as StorageManagerADO.select() and reserve()
        # do (prefix + safe_name). Column references in the WHERE clause
        # then match the FROM table even when the class name contains
        # underscores.
        self.tablename = store.prefix + safe_name(unitClass.__name__)
        self.expr = expr
        # The default adapter instance is shared between decompilers; the
        # coerce_* methods defined in AdapterToADOSQL keep no state.
        self.adapter = adapter
        obj = expr.func
        codewalk.LambdaDecompiler.__init__(self, obj)
    
    def code(self):
        """Walk the lambda's bytecode and return (sql, imperfect)."""
        self.imperfect = False
        self.walk()
        result = self.stack[0]
        if result is cannot_represent:
            # Nothing was expressible in SQL: match every row and rely on
            # the caller's expr.evaluate() filtering.
            result = 'True'
        return result, self.imperfect
    
    def visit_target(self, terms):
        """A target is an AND or OR test."""
        comp = self.stack.pop()
        while terms:
            term, operation = terms.pop()
            # All this checking of cannot_represent is done so that a
            # function (like dejavu.iscurrentweek) can be labeled
            # imperfect--all Units (which match the rest of the
            # Expression) will be recalled. They can then be
            # compared in expr.evaluate(unit).
            if term is not cannot_represent:
                if comp is cannot_represent:
                    comp = term
                else:
                    comp = "(%s) %s (%s)" % (term, operation, comp)
        self.stack.append(comp)
    
    def visit_LOAD_DEREF(self, lo, hi):
        raise ValueError("Illegal reference found in %s." % self.expr)
    
    def visit_LOAD_GLOBAL(self, lo, hi):
        raise ValueError("Illegal global found in %s." % self.expr)
    
    def visit_LOAD_FAST(self, lo, hi):
        # Positional arguments stand for the table row; any other local
        # is treated as the keyword-argument mapping.
        if lo + (hi << 8) < self.co_argcount:
            self.stack.append(table_arg)
        else:
            self.stack.append(kw_arg)
    
    def visit_LOAD_ATTR(self, lo, hi):
        name = self.co_names[lo + (hi << 8)]
        tos = self.stack.pop()
        if tos is table_arg:
            self.stack.append("[%s].[%s]" % (self.tablename, name))
        else:
            # Keep (object, attribute) so CALL_FUNCTION can recognize
            # method calls such as .startswith().
            self.stack.append((tos, name))
    
    def visit_LOAD_CONST(self, lo, hi):
        val = self.co_consts[lo + (hi << 8)]
        # Some constants are function or class objects,
        # which should not be coerced.
        no_coerce = (FunctionType, type)
        if isinstance(val, no_coerce):
            pass
        elif isinstance(val, type(len)):
            # Builtins are kept as their repr; CALL_FUNCTION matches on it.
            val = str(val)
        else:
            val = self.adapter.coerce(val)
        self.stack.append(val)
    
    def visit_BUILD_TUPLE(self, lo, hi):
        # The item count is the 16-bit argument lo + (hi << 8). The
        # parentheses matter: "lo + hi << 8" parses as "(lo + hi) << 8".
        count = lo + (hi << 8)
        items = [self.stack.pop() for i in range(count)]
        # Items come off the stack last-first; restore source order.
        items.reverse()
        self.stack.append("(" + ", ".join(items) + ")")
    
    visit_BUILD_LIST = visit_BUILD_TUPLE
    
    def visit_CALL_FUNCTION(self, lo, hi):
        # lo = number of positional args, hi = number of keyword pairs.
        kwargs = {}
        for i in range(hi):
            val = self.stack.pop()
            key = self.stack.pop()
            kwargs[key] = val
        kwargs = [k + "=" + v for k, v in kwargs.iteritems()]
        
        args = []
        for i in range(lo):
            arg = self.stack.pop()
            args.append(arg)
        args.reverse()
        
        if kwargs:
            args += kwargs
        
        func = self.stack.pop()
        
        # Handle function objects.
        if func in self.functions:
            self.stack.append(self.functions[func](*args))
        else:
            if isinstance(func, tuple):
                tos, func = func
                if func == "startswith":
                    self.stack.append(tos + " Like '" + args[0][1:-1] + "%'")
                    self.imperfect = True
                    return
                elif func == "endswith":
                    self.stack.append(tos + " Like '%" + args[0][1:-1] + "'")
                    self.imperfect = True
                    return
            elif isinstance(func, basestring) and func == '<built-in function len>':
                self.stack.append("Len(" + args[0] + ")")
                return
            
            # Unsupported call: its result cannot be expressed in SQL.
            # NOTE(review): this overwrites the element *below* the popped
            # call (or seeds an empty stack) rather than pushing a result;
            # it relies on how codewalk arranges the stack -- confirm before
            # changing.
            if self.stack:
                self.stack[-1] = cannot_represent
            else:
                self.stack = [cannot_represent]
            self.imperfect = True
    
    def visit_COMPARE_OP(self, lo, hi):
        op2, op1 = self.stack.pop(), self.stack.pop()
        op = self.sql_cmp_op[lo + (hi << 8)]
        if op == 'in':
            self.stack.append(icontainedby(op1, op2))
            self.imperfect = True
        elif op == 'not in':
            self.stack.append(icontainedby(op1, op2, True))
            self.imperfect = True
        elif op == '=' and op2 == 'Null':
            self.stack.append(op1 + " Is Null")
        elif op == '=' and op1 == 'Null':
            self.stack.append(op2 + " Is Null")
        elif op == '!=' and op2 == 'Null':
            self.stack.append(op1 + " Is Not Null")
        elif op == '!=' and op1 == 'Null':
            self.stack.append(op2 + " Is Not Null")
        else:
            if op2.startswith("'") and op2.endswith("'"):
                # ADO comparison operators for strings are case-insensitive
                # by default. Rather than determine which columns in the DB
                # might be case-sensitive, just flag them all as imperfect.
                self.imperfect = True
            self.stack.append(op1 + " " + op + " " + op2)
    
    def binary_op(self, op):
        op2, op1 = self.stack.pop(), self.stack.pop()
        self.stack.append(op1 + " " + op + " " + op2)
    
    def visit_BINARY_SUBSCR(self):
        # The only BINARY_SUBSCR used in Expressions should be kwargs[key].
        name = self.stack.pop()
        tos = self.stack.pop()
        if tos is not kw_arg:
            raise ValueError(tos, name)
        # name was rendered by LOAD_CONST through self.adapter, so it is
        # quoted and may be escaped (AdapterToADOSQL turns "a_b" into
        # "'a[_]b'"). Find the key that renders to the same text, and fall
        # back to simply stripping the quotes.
        kwargs = self.expr.kwargs
        for key in kwargs:
            if self.adapter.coerce(key) == name:
                break
        else:
            key = name[1:-1]
        value = kwargs[key]
        value = self.adapter.coerce(value)
        self.stack.append(value)
    
    def visit_UNARY_NOT(self):
        op = self.stack.pop()
        if op is cannot_represent:
            # Usually as a result of has(farClassName).
            self.stack.append(cannot_represent)
        else:
            self.stack.append("not (" + op + ")")
490
491
492 def safe_name(content):
493     return unicode(content).replace(u"_", u"")
494
495
class StoreIteratorADO(object):
    """Iterator for populating Units from storage.

    Runs the SELECT built by store.select(unitClass, expr), and units()
    yields one populated Unit per result row.
    """
    
    def __init__(self, store, unitClass, expr):
        self.store = store
        self.unitClass = unitClass
        self.expr = expr
        self.colIndices = {}   # column name -> column index
        self.fieldTypes = []   # ADO Field.Type, indexed by column
        
        self.sql, self.imperfect = store.select(unitClass, expr)
    
    def field(self, key, row):
        """Return (fieldType, value) for column key of the given row.

        Raises KeyError(key, key, class name) if the result has no such
        column.
        """
        try:
            col = self.colIndices[key]
        except KeyError:
            raise KeyError(key, key, self.unitClass.__name__)
        return (self.fieldTypes[col], self.data[col][row])
    
    def load_data(self):
        """Run self.sql and cache the result as self.data[col][row]."""
        anRS = self.store.recordset(self.sql, adOpenForwardOnly,
                                    adLockReadOnly)
        try:
            for col, x in enumerate(anRS.Fields):
                self.colIndices[x.Name] = col
                self.fieldTypes.append(x.Type)
            
            self.data = []
            if not (anRS.BOF and anRS.EOF):
                # We tried .MoveNext() and lots of Fields.Item() calls.
                # Using GetRows() beats that time by about 2/3.
                self.data = anRS.GetRows()
        finally:
            # Release the recordset even if reading fields or rows fails.
            anRS.Close()
    
    def _expanded_values(self, tbl, key, row):
        """Return the unpickled values of expanded column key for row."""
        try:
            rs = self.store.recordset(u"SELECT EXPVAL FROM [%s_%s_%s]"
                                      % (tbl,
                                         safe_name(self.field('ID', row)[1]),
                                         safe_name(key)))
        except pywintypes.com_error:
            # This usually occurs because the parent Unit
            # was reserved but no table yet made for these
            # expanded values. This is OK. TODO: trap this
            # more specifically by examining the errmsg.
            return []
        try:
            if rs.BOF and rs.EOF:
                return []
            return [pickle.loads(str(x)) for x in rs.GetRows()[0]]
        finally:
            rs.Close()
    
    def units(self):
        """Yield a populated Unit for each row that matches self.expr."""
        s = self.store
        clsname = self.unitClass.__name__
        # Name prefix of the per-Unit tables holding expanded values.
        tbl = "%s_%s" % (s.prefix, safe_name(clsname))
        self.load_data()
        if not self.data:
            return
        for row in range(len(self.data[0])):
            unit = self.unitClass()
            coercer = AdapterFromADO(unit)
            for key in unit.__class__.properties():
                if (clsname, key) in s.expanded_columns:
                    # Grab the expanded data
                    values = self._expanded_values(tbl, key, row)
                    expectedType = unit.__class__.property_type(key)
                    # Set the attribute directly to avoid __set__ overhead.
                    unit._properties[key] = expectedType(values)
                else:
                    coercer.consume(key, self.field(key, row))
            # If our SQL is imperfect, don't yield it to the
            # caller unless it passes evaluate().
            if (not self.imperfect) or self.expr.evaluate(unit):
                yield unit
574
575
class StoreMultiIteratorADO(StoreIteratorADO):
    """Iterator for populating Units (from multiple classes) from storage.

    Runs the joined SELECT from store.multiselect. units() yields
    (unit, farUnit) pairs, where farUnit is None when its ID column is
    None (no joined row).
    """
    
    def __init__(self, store, unitClass, expr, pairs):
        self.store = store
        self.unitClass = unitClass
        self.expr = expr
        self.pairs = pairs
        self.fieldTypes = []
        
        sel = store.multiselect(unitClass, expr, pairs)
        self.sql, self.imperfect, self.columns = sel
        # (cls, key) -> first column index. This replaces a list.index()
        # scan per property per row.
        self._colmap = {}
        for col, colkey in enumerate(self.columns):
            self._colmap.setdefault(colkey, col)
    
    def populate_unit(self, unit, row):
        """Populate a Unit from a database row.

        Raises ValueError(message, cls, key) if a property of the Unit's
        class was not selected.
        """
        coercer = AdapterFromADO(unit)
        cls = unit.__class__
        for key in cls.properties():
            try:
                col = self._colmap[(cls, key)]
            except KeyError:
                raise ValueError("column not in multiselect result", cls, key)
            coercer.consume(key, (self.fieldTypes[col],
                                  self.data[col][row]))
    
    def load_data(self):
        """Run self.sql and cache the result as self.data[col][row]."""
        anRS = self.store.recordset(self.sql, adOpenForwardOnly,
                                    adLockReadOnly)
        try:
            for x in anRS.Fields:
                self.fieldTypes.append(x.Type)
            
            self.data = []
            if not (anRS.BOF and anRS.EOF):
                # We tried .MoveNext() and lots of Fields.Item() calls.
                # Using GetRows() beats that time by about 2/3.
                self.data = anRS.GetRows()
        finally:
            # Release the recordset even if reading fields or rows fails.
            anRS.Close()
    
    def units(self):
        """Yield (unit, farUnit) pairs; see the class docstring."""
        self.load_data()
        if not self.data:
            return
        # multiselect() only accepts a single (cls, expr) pair.
        farcls, farexpr = self.pairs[0]
        for row in range(len(self.data[0])):
            unit = self.unitClass()
            self.populate_unit(unit, row)
            # If our SQL is imperfect, don't yield it to the
            # caller unless it passes evaluate().
            if self.imperfect and not self.expr.evaluate(unit):
                continue
            farUnit = farcls()
            self.populate_unit(farUnit, row)
            if farUnit.ID is None:
                yield unit, None
            elif ((not self.imperfect) or farexpr is None
                  or farexpr.evaluate(farUnit)):
                yield unit, farUnit
636
637
class FieldTypeAdapter(object):
    """Return the SQL typename of a DB column."""
    
    def coerce(self, cls, key):
        """coerce(cls, key) -> SQL typename for valuetype.

        Dispatches on valuetype = cls.property_type(key) to a method named
        coerce_<type> for builtins, or coerce_<module>_<type> otherwise
        (dots become underscores). Raises TypeError if there is none.
        """
        valuetype = cls.property_type(key)
        mod = valuetype.__module__
        # Builtin types report "__builtin__" on Python 2, "builtins" on 3.
        if mod in ("__builtin__", "builtins"):
            xform = "coerce_%s" % valuetype.__name__
        else:
            xform = "coerce_%s_%s" % (mod, valuetype.__name__)
        xform = xform.replace(".", "_")
        try:
            xform = getattr(self, xform)
        except AttributeError:
            raise TypeError("'%s' is not handled by %s." %
                            (valuetype, self.__class__))
        return xform(cls, key)
    
    def _create_str_storage(self, cls, key):
        """This basic string handler does not know anything about the size
        limitations of the particular database. You should use one of the
        subclasses for your particular database if you need storage for
        strings over 255 characters.
        """
        prop = getattr(cls, key)
        size = int(prop.hints.get(u'Size', '255'))
        return u"VARCHAR(%s)" % size
    
    def coerce_bool(self, cls, key): return u"BIT"
    
    def coerce_datetime_datetime(self, cls, key): return u"TIMESTAMP"
    def coerce_datetime_date(self, cls, key): return u"DATE"
    def coerce_datetime_time(self, cls, key): return u"TIME"
    
    def coerce_datetime_timedelta(self, cls, key):
        # Stored as float days, as AdapterToADOSQL renders timedelta
        # constants and AdapterFromADO decodes them.
        return u"FLOAT"
    
    coerce_dict = _create_str_storage
    
    def coerce_fixedpoint_FixedPoint(self, cls, key): return u"FLOAT"
    def coerce_float(self, cls, key): return u"FLOAT"
    def coerce_int(self, cls, key): return u"INTEGER"
    
    coerce_list = _create_str_storage
    coerce_str = _create_str_storage
    coerce_tuple = _create_str_storage
    coerce_unicode = _create_str_storage
683
684
685 class StorageManagerADO(storage.StorageManager):
686     """StoreManager to save and retrieve Units via ADO 2.7.
687     
688     You must run makepy on ADO 2.7 before installing.
689     """
690    
691     decompiler = ADOSQLDecompiler
692     createAdapter = FieldTypeAdapter()
693     threaded = False
694    
695     def __init__(self, name, arena, allOptions={}):
696         pythoncom.CoInitialize()
697        
698         storage.StorageManager.__init__(self, name, arena, allOptions)
699        
700         self.connstring = allOptions[u'Connect']
701         self.CreateIfMissing = allOptions.get(u'Create If Missing', '')
702         self.DBName = allOptions.get(u'DBName', None)
703         if allOptions.get(u'Threaded', ''):
704             self.threaded = True
705             self._connection = None
706         else:
707             try:
708                 self._connection = win32com.client.Dispatch(r'ADODB.Connection')
709                 self._connection.Open(self.connstring)
710             except pywintypes.com_error, x:
711                 if x.args[2][5] == -2147467259 and self.CreateIfMissing:
712                     self.create_database()
713                     self._connection.Open(self.connstring)
714                 else:
715                     raise
716        
717         self.prefix = allOptions.get(u'Prefix', u"djv")
718         self.cursorType = int(allOptions.get(u'CursorType', adOpenDynamic))
719         self.lockType = int(allOptions.get(u'LockType', adLockOptimistic))
720        
721         ec = []
722         for prop in allOptions.get(u'Expanded Columns', '').split(","):
723             if prop:
724                 lastdot = prop.rfind(".")
725                 clsname, key = prop[:lastdot], prop[lastdot + 1:]
726                 ec.append((clsname, key))
727         self.expanded_columns = ec
728        
729         self.reserve_lock = threading.Lock()
730    
731     def shutdown(self):
732         if self._connection is not None:
733             self._connection.Close()
734    
735     def connection(self):
736         if self.threaded:
737             t = threading.currentThread()
738             if not hasattr(t, 'SMADOconn'):
739                 t.SMADOconn = win32com.client.Dispatch(r'ADODB.Connection')
740             if t.SMADOconn.State == adStateClosed:
741                 try:
742                     t.SMADOconn.Open(self.connstring)
743                 except pywintypes.com_error, x:
744                     if x.args[2][5] == -2147467259 and self.CreateIfMissing:
745                         self.create_database()
746                         t.SMADOconn.Open(self.connstring)
747                     else:
748                         raise
749             return t.SMADOconn
750         else:
751             return self._connection
752    
753     def recordset(self, aQuery, cursorType=None, lockType=None):
754         anRS = win32com.client.Dispatch(r'ADODB.Recordset')
755 ##        anRS.Cursorlocation = 3     # adUseClient; Use to obtain .Recordcount
756         if cursorType is None:
757             cursorType = self.cursorType
758         if lockType is None:
759             lockType = self.lockType
760        
761         try:
762             anRS.Open(aQuery, self.connection(), cursorType, lockType)
763         except pywintypes.com_error, x:
764             try:
765                 anRS.Close()
766             except:
767                 pass
768             x.args += (aQuery, )
769             raise x
770         return anRS
771    
772     def _join(self, path=[]):
773         if not path: return u''
774         firstcls = path.pop(0)
775         if not path: return firstcls.__name__
776        
777         spath = self.arena.associations.shortest_path(firstcls, path[0])
778         spath.pop(0)
779         cls = spath[0]
780         leftkey, rightkey = firstcls._associations[cls]
781         params = {u'prefix': self.prefix,
782                   u'left': firstcls.__name__,
783                   u'right': cls.__name__,
784                   u'leftkey': leftkey,
785                   u'rightkey': rightkey,
786                   }
787         if len(spath) == 1:
788             params[u'child'] = u"[%(prefix)s%(right)s]" % params
789         else:
790             params[u'child'] = u"(%s)" % self._join(spath)
791        
792         return (u"[%(prefix)s%(left)s] LEFT JOIN %(child)s"
793                 u" ON [%(prefix)s%(left)s].[%(leftkey)s] = "
794                 u"[%(prefix)s%(right)s].[%(rightkey)s]" % params)
795    
796     def multiselect(self, firstcls, firstexpr, pairs):
797         firstwhere, imp = self.where(firstcls, firstexpr)
798         cols = [(firstcls, k) for k in firstcls.properties()]
799        
800         # TODO: concat multiple pairs.
801         if len(pairs) != 1:
802             raise ValueError("Multiselect does not yet work on multiple pairs.")
803         for cls, expr in pairs:
804             if expr is None:
805                 expr = logic.Expression(lambda x: True)
806             j = self._join([firstcls, cls])
807            
808             w, new_imp = self.where(cls, expr)
809             imp |= new_imp
810             if w and w != "True":
811                 w = " WHERE %s AND %s" % (w, firstwhere)
812             else:
813                 w = " WHERE %s" % firstwhere
814            
815             cols += [(cls, k) for k in cls.properties()]
816             colnames = ["[%s%s].[%s]" % (self.prefix, colcls.__name__, k)
817                         for colcls, k in cols]
818            
819             statement = "SELECT %s FROM %s%s" % (u', '.join(colnames), j, w)
820            
821             return statement, imp, cols
822    
823     def select(self, unitClass, expr, distinct_fields=None):
824         tablename = self.prefix + safe_name(unitClass.__name__)
825         if distinct_fields:
826             distinct_fields = [u'[%s]' % x for x in distinct_fields]
827             sql = (u"SELECT DISTINCT %s FROM [%s]" %
828                    (u', '.join(distinct_fields), tablename))
829         else:
830             sql = u"SELECT * FROM [%s]" % tablename
831         w, i = self.where(unitClass, expr)
832         if len(w) > 0:
833             w = u" WHERE " + w
834         else:
835             w = u""
836         sql += w
837         return sql, i
838    
839     def where(self, cls, expr):
840         return self.decompiler(self, cls, expr).code()
841    
842     def execute(self, aQuery, conn=None):
843         if conn is None:
844             conn = self.connection()
845         try:
846             conn.Execute(aQuery)
847         except pywintypes.com_error, x:
848             x.args += (aQuery, )
849             raise x
850    
851     def recall(self, cls, expr=None, pairs=None):
852         if expr is None:
853             expr = logic.Expression(lambda x: True)
854        
855         if pairs is not None:
856             return StoreMultiIteratorADO(self, cls, expr, pairs).units()
857         else:
858             return StoreIteratorADO(self, cls, expr).units()
859    
860     def reserve(self, unit):
861         """reserve(unit). -> Reserve a persistent slot for unit."""
862         self.reserve_lock.acquire()
863         try:
864             if unit.ID is None:
865                 data = []
866                 clsname = unit.__class__.__name__
867                 anRS = self.recordset(u"SELECT ID FROM [%s%s];" %
868                                       (self.prefix, safe_name(clsname)))
869                 if not (anRS.BOF and anRS.EOF):
870                     data = anRS.GetRows()[0]
871                 unit.ID = unit.sequencer.next(data)
872                
873                 anRS.AddNew()
874                 anRS.Fields(u'ID').Value = unit.ID
875                 anRS.Update()
876                 anRS.Close()
877         finally:
878             self.reserve_lock.release()
879    
880     def save(self, unit, forceSave=False):
881         """save(unit, forceSave=False). -> Update storage from unit's data.
882         
883         Notice in particular that we do not use the auto-number or
884         sequence generation capabilities within some databases, etc.
885         The ID should be supplied by UnitSequencers via reserve().
886         """
887         if unit.dirty() or forceSave:
888             cls = unit.__class__
889             clsname = cls.__name__
890             # Use a cursor always--makes mixed-quotes, newline, etc easier.
891             anRS = self.recordset("SELECT * FROM [%s%s] WHERE ID = %s" %
892                                   (self.prefix, safe_name(clsname),
893                                    AdapterToADOSQL().coerce(unit.ID)))
894             if anRS.EOF and anRS.BOF:
895                 anRS.AddNew()
896                 anRS.Fields(u'ID').Value = unit.ID
897             fmt = AdapterToADOFields()
898             for key in cls.properties():
899                 if (clsname, key) in self.expanded_columns:
900                     # Special-case this field into its own table.
901                     self.save_expanded(unit, key)
902                 else:
903                     eachType = cls.property_type(key)
904                     newValue = fmt.coerce(getattr(unit, key), eachType)
905                     try:
906                         anRS.Fields(key).Value = newValue
907                     except pywintypes.com_error, x:
908                         try:
909                             anRS.Close()
910                         except:
911                             pass
912                         x.args += (clsname, key, eachType, newValue)
913                         raise x
914             anRS.Update()
915             anRS.Close()
916             unit.cleanse()
917    
918     def save_expanded(self, unit, key):
919         """Save a field using a table specifically for that purpose."""
920         unitcls = unit.__class__
921         table = ("%s_%s_%s_%s" % (self.prefix, safe_name(unitcls.__name__),
922                                   safe_name(unit.ID), safe_name(key)))
923        
924         conn = self.connection()
925        
926         # Just drop the old table and start with a new one.
927         try:
928             self.execute((u"DROP TABLE [%s];" % table), conn)
929         except pywintypes.com_error, x:
930             pass
931        
932         # Ugly, ugly hack to get NTEXT or MEMO as appropriate. The point
933         # is, we want a large text field so we can pickle each item.
934         ftype = self.createAdapter.coerce_list(None, None)
935         self.execute(u"CREATE TABLE [%s] (EXPVAL %s);" % (table, ftype), conn)
936        
937         ins = u"INSERT INTO [" + table + "] (EXPVAL) VALUES ('%s');"
938         val = getattr(unit, key)
939         if val:
940             for v in val:
941                 # Create a row for the unit.
942                 # Use an INSERT command (not a cursor) for better performance.
943                 v = pickle.dumps(v).replace("'", "''")
944                 self.execute(ins % v, conn)
945    
946     def destroy(self, unit):
947         """Delete the unit."""
948         # Use a DELETE command instead of a cursor for better performance.
949         deleteStatement = (u"DELETE * FROM [%s%s] WHERE ID = %s" %
950                            (self.prefix, safe_name(unit.__class__.__name__),
951                             AdapterToADOSQL().coerce(unit.ID)))
952         self.execute(deleteStatement)
953    
954     def create_storage(self, unitClass):
955         clsname = safe_name(unitClass.__name__)
956        
957         coerce = self.createAdapter.coerce
958         fields = []
959         for key in unitClass.properties():
960             if (unitClass.__name__, key) not in self.expanded_columns:
961                 fields.append(u"[%s] %s" % (key, coerce(unitClass, key)))
962         self.execute(u"CREATE TABLE [%s%s] (%s)" %
963                      (self.prefix, clsname, ", ".join(fields)))
964        
965         for index in unitClass.indices():
966             self.execute(u"CREATE INDEX [%si%s%s] ON [%s%s] (%s ASC)"
967                          % (self.prefix, clsname, safe_name(index),
968                             self.prefix, clsname, index))
969    
970     def distinct(self, cls, fields, expr=None):
971         """Return distinct values for specified fields."""
972         if expr is None:
973             expr = logic.Expression(lambda x: True)
974        
975         # ^%$#@! There's no way to handle imperfect queries without
976         # creating all involved Units, which defeats the purpose of
977         # distinct, which was a speed issue more than anything. Grr.
978         sql, imperfect = self.select(cls, expr, fields)
979         # Ignore for now.
980 ##        if imperfect:
981 ##            raise ValueError(u"The following query cannot be reliably "
982 ##                             u"returned from an ADO data source.",
983 ##                             u"distinct()", cls, fields, expr)
984        
985         anRS = self.recordset(sql, adOpenForwardOnly, adLockReadOnly)
986        
987         fieldTypes = [x.Type for x in anRS.Fields]
988         data = []
989         if not (anRS.BOF and anRS.EOF):
990             # We tried .MoveNext() and lots of Fields.Item() calls.
991             # Using GetRows() beats that time by about 2/3.
992             data = anRS.GetRows()
993         anRS.Close()
994        
995         if data:
996             coerced_data = []
997             coerce = AdapterFromADO().coerce
998             for col, field in enumerate(fields):
999                 expectedType = cls.property_type(field)
1000                 actualType = fieldTypes[col]
1001                 coerced_row = [coerce((actualType, val), expectedType)
1002                                for val in data[col]]
1003                 coerced_data.append(coerced_row)
1004             data = zip(*coerced_data)
1005         return data
1006
1007
1008 ###########################################################################
1009 ##                                                                       ##
1010 ##                             SQL Server                                ##
1011 ##                                                                       ##
1012 ###########################################################################
1013
1014
class FieldTypeAdapter_SQLServer(FieldTypeAdapter):
    """Pick SQL Server column types for Unit properties."""

    def _create_str_storage(self, cls, key):
        """Return VARCHAR(Size) for a str/unicode property, else NTEXT.

        Size comes from the property's 'Size' hint (default 255). A Size of
        0, or anything above 8000, falls back to NTEXT.
        """
        size = int(getattr(cls, key).hints.get(u'Size', '255'))
        # 8000 *bytes* is the absolute upper limit, based on T_SQL docs
        # for varchar. If there are further fields defined for the class,
        # or the code page uses a double-byte character set, we still
        # might exceed the max size (8060) for a record. We could calc
        # the total requested record size, and adjust accordingly. For
        # now, we just trust that units generally use a size of 0 to
        # bump up to NTEXT (1 gig characters).
        if size != 0 and size <= 8000:
            return u"VARCHAR(%s)" % size
        return u"NTEXT"

    def _create_pickle_storage(self, cls, key):
        """Large-text column for values that get pickled (dict/list/tuple)."""
        return u"NTEXT"

    # dict, list, and tuple will all be pickled in AdapterToADO
    coerce_dict = coerce_list = coerce_tuple = _create_pickle_storage
    coerce_str = coerce_unicode = _create_str_storage
1037
1038
class StorageManagerADO_SQLServer(StorageManagerADO):
    """StorageManagerADO flavour for Microsoft SQL Server."""

    createAdapter = FieldTypeAdapter_SQLServer()

    def create_database(self):
        """Issue CREATE DATABASE for self.DBName over self.connstring."""
        # This method hasn't been tested yet for SQL server.
        adoconn = win32com.client.Dispatch(r'ADODB.Connection')
        adoconn.Open(self.connstring)
        try:
            # NOTE(review): DBName is interpolated unquoted; a name that
            # needs [brackets] would fail here -- confirm expected names.
            adoconn.Execute("CREATE DATABASE %s" % self.DBName)
        finally:
            # Close explicitly instead of relying on the COM object's
            # release, which can be delayed (e.g. while a traceback still
            # references this frame).
            adoconn.Close()
1047
1048
1049 ###########################################################################
1050 ##                                                                       ##
1051 ##                             MS Access                                 ##
1052 ##                                                                       ##
1053 ###########################################################################
1054
1055
class ADOSQLDecompiler_MSAccess(ADOSQLDecompiler):
    """SQL decompiler variant that emits MS Access (Jet) SQL."""

    # Comparison operators this dialect expresses directly.
    sql_cmp_op = ('<', '<=', '=', '<>', '>', '>=',
                  'in', 'not in')

    # dejavu helper function -> callable producing the Access SQL for it.
    # The y argument appears to be a quoted SQL string literal: y[1:-1]
    # strips its surrounding quote characters so '%' wildcards can be
    # spliced inside a new literal.
    # NOTE(review): '%' or '_' occurring inside y are not escaped and will
    # act as Like wildcards -- confirm such queries are flagged imperfect.
    functions = {
        # x contains y (case-insensitively)
        dejavu.icontains:
            lambda x, y: x + " Like '%" + y[1:-1] + "%'",
        dejavu.icontainedby: icontainedby,
        # x starts with y
        dejavu.istartswith:
            lambda x, y: x + " Like '" + y[1:-1] + "%'",
        # x ends with y
        dejavu.iendswith:
            lambda x, y: x + " Like '%" + y[1:-1] + "'",
        # Plain equality (relies on the database's comparison rules for
        # case-insensitivity).
        dejavu.ieq:
            lambda x, y: x + " = " + y,
        dejavu.now:
            lambda: "Now()",
        dejavu.today:
            lambda: "DateValue(Now())",
        dejavu.year:
            lambda x: "Year(" + x + ")",
    }
1067
1068
class FieldTypeAdapter_MSAccess(FieldTypeAdapter):
    """Pick MS Access (Jet) column types for Unit properties."""

    def _create_str_storage(self, cls, key):
        """Return VARCHAR(Size) for a str/unicode property, else MEMO.

        Size comes from the property's 'Size' hint (default 255). A Size of
        0, or anything above 255, falls back to MEMO.
        """
        size = int(getattr(cls, key).hints.get(u'Size', '255'))
        # 255 chars is the upper limit for TEXT / VARCHAR in MS Access.
        # MEMO is 1 gigabyte when set programmatically (only 64K when set
        # in Access UI). But then, 1 GB is the limit for the whole DB.
        if size != 0 and size <= 255:
            return u"VARCHAR(%s)" % size
        return u"MEMO"

    def _create_pickle_storage(self, cls, key):
        """Large-text column for values that get pickled (dict/list/tuple)."""
        return u"MEMO"

    # dict, list, and tuple will all be pickled in AdapterToADO
    coerce_dict = coerce_list = coerce_tuple = _create_pickle_storage
    coerce_str = coerce_unicode = _create_str_storage
1087
1088
class StorageManagerADO_MSAccess(StorageManagerADO):
    """StorageManagerADO flavour for MS Access (Jet) database files."""

    # Access-specific SQL generation and column typing.
    decompiler = ADOSQLDecompiler_MSAccess
    createAdapter = FieldTypeAdapter_MSAccess()

    def create_database(self):
        """Create a new, empty database through ADOX using self.connstring."""
        win32com.client.Dispatch(r'ADOX.Catalog').Create(self.connstring)
1097
1098
if __name__ == '__main__':
    # Have win32com generate (or verify) its Python support module for the
    # ADO 2.7 type library.
    print('Please wait while support for ADO 2.7 is verified...')
    CLSID = '{EF53050B-882E-4776-B643-EDA472E8E3F2}'
    lcid, major, minor = 0, 2, 7
    win32com.client.gencache.EnsureModule(CLSID, lcid, major, minor)
1104
Note: See TracBrowser for help on using the browser.