Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/storeado.py

Revision 25 (checked in by fumanchu, 9 years ago)

1. Moved UnitCollection._IDs to a UnitProperty, .Members.
2. Fixed a bug in CachingProxy save (dirty unit data was not being stored).
3. Generalized the storage strategy of UnitCollection.ID (use a new table instead of a field) to any property, configurable by the deployer.
4. Moved the common try/except handling into SM.recordset and SM.execute.
5. Reworked column types (createCoercion) into FieldTypeAdapters.

Line 
1 import sys
2 # Put COM in free-threaded mode
3 sys.coinit_flags = 0
4
5 import win32com.client
6 import pywintypes
7 import pythoncom
8 import threading
9 import datetime
10 try:
11     import cPickle as pickle
12 except ImportError:
13     import pickle
14 from types import FunctionType
15
16 try:
17     import fixedpoint
18 except ImportError:
19     pass
20
21 import dejavu
22 from dejavu import storage, codewalk, logic
23 import recur
24
# ADO CursorTypeEnum values.
adOpenForwardOnly, adOpenKeyset, adOpenDynamic, adOpenStatic = range(4)

# ADO LockTypeEnum values.
(adLockReadOnly, adLockPessimistic,
 adLockOptimistic, adLockBatchOptimistic) = range(1, 5)

# ADO ConnectModeEnum value.
adModeShareExclusive = 12

# ADO ObjectStateEnum values (bit flags; a Connection or Recordset's .State).
adStateClosed = 0
adStateOpen, adStateConnecting, adStateExecuting, adStateFetching = 1, 2, 4, 8

# ADO/COM dates count days from 12/30/1899; this is that day's ordinal
# (693594), so ordinal = int(com_date) + zeroHour.
zeroHour = datetime.date(1899, 12, 30).toordinal()
45
46
def time_from_com(com_date):
    """Return a (day, datetime.time) pair from a COM date or time value.

    COM (OLE Automation) dates are floats counting days since 12/30/1899,
    with the time of day as the fractional part. For negative values the
    sign applies only to the whole-day part (-1.25 is 12/29/1899 06:00),
    so the time is taken from the absolute value's fraction. Using the
    signed `% 1` would have turned -1.25 into 18:00.

    The hour/minute/second components are rounded to whole seconds and
    passed to recur.sane_time; its (day, time) result is returned as-is.
    NOTE(review): `day` presumably carries overflow when seconds round up
    past midnight -- confirm against recur.sane_time.
    """
    day_fraction = abs(float(com_date)) % 1
    hour, mins = divmod(86400 * day_fraction, 3600)
    mins, sec = divmod(mins, 60)
    # Must do both int() and round() or we'll be up to 1 second off.
    hour = int(round(hour))
    mins = int(round(mins))
    sec = int(round(sec))
    return recur.sane_time(0, hour, mins, sec)
56
57
class AdapterFromADO(storage.Adapter):
    """Coerce incoming values from ADO to Dejavu datatypes.

    Each coerce_* method (and pickle) receives an (adoType, value) pair,
    where adoType is the ADO Field.Type of the source column, and returns
    the Python value. A database Null (None) stays None.
    """

    def __init__(self, unit=None):
        # The Unit that consume() writes coerced values into.
        self.unit = unit

    def consume(self, key, value):
        """Coerce an (adoType, value) pair for property `key` into self.unit."""
        expectedType = self.unit.__class__.property_type(key)
        value = self.coerce(value, expectedType)
        # Set the attribute directly to avoid __set__ overhead.
        self.unit._properties[key] = value

    def _currency(self, value):
        """Return an ADO Currency (adCurrency, type 0x06) value as a float.

        Currency is a signed 64-bit count of ten-thousandths. This code
        previously indexed it as a (hi, lo) pair of 32-bit words and used
        only `lo`, which is wrong for negative amounts and for any
        magnitude that needs the high word. Both words are combined here.
        NOTE(review): some pywin32 builds presumably deliver Currency as a
        number (e.g. decimal.Decimal) instead of a tuple; that is passed
        through float() -- confirm against the installed pywin32.
        """
        if isinstance(value, tuple):
            hi, lo = value
            # lo is the unsigned low word; mask it in case it arrives
            # sign-extended.
            return (hi * 4294967296 + (lo & 0xFFFFFFFF)) / 10000.0
        return float(value)

    def pickle(self, value):
        """Unpickle a text column; used for dict, list and tuple properties."""
        aType, value = value
        if value is None:
            return None
        # Coerce to str for pickle.loads restriction.
        # NOTE(review): pickle.loads can execute code from crafted data;
        # this is only safe if the database contents are trusted.
        return pickle.loads(str(value))

    def coerce_datetime_datetime(self, value):
        # Illegal Date/Time values will crash the
        # app when using value.Format(). Therefore,
        # grab the value and figure the date ourselves.
        # Use 1-second resolution only.
        aType, value = value
        if value is None:
            return None
        elif isinstance(value, basestring):
            # 'YYYYMMDD...' text: only the date part is read.
            return datetime.datetime(int(value[0:4]), int(value[4:6]),
                                     int(value[6:8]))
        else:
            # For some reason, we need both float and int.
            aDate = datetime.date.fromordinal(int(float(value)) + zeroHour)
            day, aTime = time_from_com(value)
            # Apply any whole-day carry reported by time_from_com (e.g. a
            # time that rounds up past midnight) instead of dropping it.
            return (datetime.datetime.combine(aDate, aTime)
                    + datetime.timedelta(days=day))

    def coerce_datetime_date(self, value):
        # See coerce_datetime_datetime
        aType, value = value
        if value is None:
            return None
        elif isinstance(value, basestring):
            return datetime.date(int(value[0:4]), int(value[4:6]),
                                 int(value[6:8]))
        else:
            return datetime.date.fromordinal(int(float(value)) + zeroHour)

    def coerce_datetime_time(self, value):
        # See coerce_datetime_datetime. Any day carry is irrelevant here.
        aType, value = value
        if value is None:
            return None
        day, aTime = time_from_com(value)
        return aTime

    coerce_dict = pickle

    def coerce_fixedpoint_FixedPoint(self, value):
        aType, value = value
        if value is None:
            return None
        if aType == 0x06:
            value = self._currency(value)
        return fixedpoint.FixedPoint(value)

    def coerce_float(self, value):
        aType, value = value
        if value is None:
            return None
        if aType == 0x06:
            value = self._currency(value)
        return float(value)

    def coerce_int(self, value):
        aType, value = value
        if value is None:
            return None
        if aType == 0x0b:
            # Boolean column: map to a bool.
            return value != 0
        return int(value)

    coerce_bool = coerce_int
    coerce_list = pickle

    def coerce_long(self, value):
        aType, value = value
        if value is None:
            return None
        return long(value)

    def coerce_str(self, value):
        aType, value = value
        if value is None:
            return None
        return str(value)

    coerce_tuple = pickle

    def coerce_unicode(self, value):
        aType, value = value
        if value is None:
            return None
        if isinstance(value, unicode):
            # The value often arrives from ADO as unicode already.
            return value
        if isinstance(value, str):
            # Latin-1 maps every byte value, so this decode cannot fail.
            return unicode(value, "ISO-8859-1")
        return unicode(value)
174
175
176
class AdapterToADOFields(storage.Adapter):
    """Coerce outgoing values from Dejavu datatypes to ADO.Field types.

    Dates and times become ADO date doubles: days since 12/30/1899, with
    the time of day as the fraction. dict/list/tuple values are pickled to
    text. Numbers and strings pass through unchanged.
    """

    def noop(self, value):
        return value

    def coerce_bool(self, value):
        # NOTE: None is stored as False here, not as Null.
        return bool(value)

    def coerce_datetime_datetime(self, value):
        if value is None:
            return None
        # Whole days plus fractional day; microseconds are dropped.
        return self.coerce_datetime_date(value) + self.coerce_datetime_time(value)

    def coerce_datetime_date(self, value):
        if value is None:
            return None
        return value.toordinal() - zeroHour

    def coerce_datetime_time(self, value):
        if value is None:
            return None
        return ((value.second + (value.minute * 60) + (value.hour * 3600))
                / 86400.0)

    def pickle(self, value):
        # We must not use a pickle format other than 0, because binary
        # strings are not safe for all DB string fields. Pin protocol 0
        # explicitly rather than relying on the default, which is a binary
        # protocol in later Pythons.
        return pickle.dumps(value, 0)

    coerce_dict = pickle

    def coerce_fixedpoint_FixedPoint(self, value):
        if value is None:
            return None
        return float(value)

    coerce_float = noop
    coerce_int = noop

    coerce_list = pickle

    coerce_long = noop
    coerce_str = noop

    coerce_tuple = pickle

    coerce_unicode = noop
227
228
class AdapterToADOSQL(storage.Adapter):
    """Coerce Expression constants to ADO SQL literals.

    Strings become single-quoted literals with embedded quotes doubled.
    LIKE wildcards are deliberately NOT escaped here; like_pattern() does
    that only where a literal is spliced into a LIKE pattern.
    """

    def tostr(self, value):
        return str(value)

    def coerce_NoneType(self, value):
        return "Null"

    def coerce_bool(self, value):
        if value:
            return 'True'
        return 'False'

    def coerce_datetime_datetime(self, value):
        # #m/d/yyyy hh:mm:ss# literal; microseconds are dropped.
        return (u'#%s/%s/%s %02d:%02d:%02d#' %
                (value.month, value.day, value.year,
                 value.hour, value.minute, value.second))

    def coerce_datetime_date(self, value):
        return u'#%s/%s/%s#' % (value.month, value.day, value.year)

    def coerce_datetime_time(self, value):
        return u'#%02d:%02d:%02d#' % (value.hour, value.minute, value.second)

    def coerce_datetime_timedelta(self, value):
        # Expressed as a (fractional) number of days.
        float_val = value.days + (value.seconds / 86400.0)
        return repr(float_val)

    coerce_fixedpoint_FixedPoint = tostr
    coerce_float = tostr
    coerce_int = tostr

    def coerce_list(self, value):
        return "(" + ", ".join([self.coerce(x) for x in value]) + ")"

    coerce_long = tostr

    def coerce_str(self, value):
        # Only quotes are escaped. Bracket-escaping '%' and '_' is meaningful
        # only inside LIKE patterns; doing it here made '=' and IN
        # comparisons (and kwargs key lookups) see 'a_b' as 'a[_]b'.
        return "'" + value.replace(u"'", u"''") + "'"

    coerce_tuple = coerce_list

    coerce_unicode = coerce_str


def like_pattern(literal):
    """Return the body of a quoted SQL string literal, escaped for LIKE.

    `literal` is a string produced by AdapterToADOSQL.coerce_str: wrapped
    in single quotes, with embedded quotes already doubled. The enclosing
    quotes are stripped, and the LIKE metacharacters '[', '%' and '_' are
    bracket-escaped so they match literally.
    """
    body = literal[1:-1]
    # '[' must be escaped first, so the brackets added below are not
    # themselves re-escaped.
    body = body.replace(u"[", u"[[]")
    body = body.replace(u"%", u"[%]")
    return body.replace(u"_", u"[_]")


def icontainedby(op1, op2, notin=False):
    """Return SQL for `op1 in op2` (or `op1 not in op2` when notin is true).

    If op2 looks like a bracketed field reference, this is a substring test
    (op2 Like '%op1%'). Otherwise it is a membership test (op1 in op2).
    """
    # This test doesn't work right, now that we use lists as
    # well as tuples with IN. Need a way to mark field refs.
    if op2.startswith("[") and op2.endswith("]"):
        # Looking for text in a field. Use Like (reverse terms).
        value = op2 + " Like '%" + like_pattern(op1) + "%'"
    else:
        # Looking for field in (a, b, c)
        value = op1 + " in " + op2
    if notin:
        value = "not " + value
    return value


class ADOSQLDecompiler(codewalk.LambdaDecompiler):
    """ADOSQLDecompiler(store, unitClass, expr, adapter=AdapterToADOSQL()).

    Produce SQL from a supplied Expression object, with a lambda of the form:
        lambda x, **kw: ...

    Attributes of x (or whatever the name of the first argument is) will be
    mapped to table columns. Keyword arguments should be bound to the
    Expression before calling this decompiler.
    """

    # Indexed by the COMPARE_OP argument, in the order of Python's
    # dis.cmp_op. '<>' is the portable SQL spelling of inequality; 'is' and
    # 'is not' are mapped in visit_COMPARE_OP.
    sql_cmp_op = ('<', '<=', '=', '<>', '>', '>=', 'in', 'not in',
                  'is', 'is not')
    functions = {
        dejavu.icontains: lambda x, y: x + " Like '%" + like_pattern(y) + "%'",
        dejavu.icontainedby: icontainedby,
        dejavu.istartswith: lambda x, y: x + " Like '" + like_pattern(y) + "%'",
        dejavu.iendswith: lambda x, y: x + " Like '%" + like_pattern(y) + "'",
        dejavu.ieq: lambda x, y: x + " = " + y,
        dejavu.now: lambda: "getdate()",
        dejavu.today: lambda: "DATEADD(dd, DATEDIFF(dd,0,getdate()), 0)",
        dejavu.year: lambda x: "YEAR(" + x + ")",
        }

    def __init__(self, store, unitClass, expr, adapter=AdapterToADOSQL()):
        self.store = store
        self.unitClass = unitClass
        self.expr = expr
        self.adapter = adapter
        obj = expr.func
        codewalk.LambdaDecompiler.__init__(self, obj)

    def code(self):
        """Return (sql, imperfect). sql is 'True' if nothing was translatable."""
        self.imperfect = False
        self.walk()
        result = self.stack[0]
        if result is None:
            result = 'True'
        return result, self.imperfect

    def visit_target(self, terms):
        """A target is an AND or OR test."""
        comp = self.stack.pop()
        while terms:
            term, operation = terms.pop()
            # All this checking of None is done so that a function
            # (like dejavu.iscurrentweek) can be labeled imperfect--
            # all Units (which match the rest of the Expression)
            # will be recalled. They can then be compared in
            # expr.evaluate(unit).
            if comp is None:
                if term is not None:
                    comp = term
            else:
                if term is not None:
                    comp = "(%s) %s (%s)" % (term, operation, comp)
        self.stack.append(comp)

    def visit_LOAD_DEREF(self, lo, hi):
        raise ValueError("Illegal reference found in %s." % self.expr)

    def visit_LOAD_GLOBAL(self, lo, hi):
        raise ValueError("Illegal global found in %s." % self.expr)

    def visit_LOAD_FAST(self, lo, hi):
        pass

    def visit_LOAD_ATTR(self, lo, hi):
        name = self.co_names[lo + (hi << 8)]
        # Table names are prefix + safe_name(class name), matching the
        # FROM clause built by StorageManagerADO.select.
        self.stack.append("[%s%s].[%s]" %
                          (self.store.prefix,
                           safe_name(self.unitClass.__name__), name))

    def visit_LOAD_CONST(self, lo, hi):
        val = self.co_consts[lo + (hi << 8)]
        # Some constants are function or class objects,
        # which should not be coerced.
        no_coerce = (FunctionType, type)
        if isinstance(val, no_coerce):
            pass
        elif isinstance(val, type(len)):
            # Builtins are matched by their str() in visit_CALL_FUNCTION.
            val = str(val)
        else:
            val = self.adapter.coerce(val)
        self.stack.append(val)

    def visit_BUILD_TUPLE(self, lo, hi):
        # The parentheses matter: '+' binds tighter than '<<', so the old
        # 'lo + hi << 8' popped 256 times too many items.
        count = lo + (hi << 8)
        items = [self.stack.pop() for i in range(count)]
        # Items come off the stack last-first; restore source order.
        items.reverse()
        self.stack.append("(" + ", ".join(items) + ")")

    def visit_BUILD_LIST(self, lo, hi):
        self.visit_BUILD_TUPLE(lo, hi)

    def visit_CALL_FUNCTION(self, lo, hi):
        # lo positional args and hi keyword (name, value) pairs sit above
        # the callable on the stack.
        kwargs = {}
        for i in range(hi):
            val = self.stack.pop()
            key = self.stack.pop()
            kwargs[key] = val
        kwargs = [k + "=" + v for k, v in kwargs.items()]

        args = []
        for i in range(lo):
            args.append(self.stack.pop())
        args.reverse()

        if kwargs:
            args += kwargs

        func = self.stack.pop()

        # Handle function objects.
        if func in self.functions:
            self.stack.append(self.functions[func](*args))
        else:
            if isinstance(func, basestring):
                # String method call: LOAD_ATTR left the receiver column
                # just below the method name, so rewrite it in place.
                if func.endswith("[startswith]"):
                    self.stack[-1] = (self.stack[-1] + " Like '" +
                                      like_pattern(args[0]) + "%'")
                    self.imperfect = True
                    return
                elif func.endswith("[endswith]"):
                    self.stack[-1] = (self.stack[-1] + " Like '%" +
                                      like_pattern(args[0]) + "'")
                    self.imperfect = True
                    return
                elif func == '<built-in function len>':
                    self.stack.append("Len(" + args[0] + ")")
                    return
                # NOTE(review): any other string method falls through,
                # leaving the receiver column on the stack unchanged.
            else:
                # Not expressible in SQL: blank the term (visit_target drops
                # None terms) and post-filter the results in Python.
                if self.stack:
                    self.stack[-1] = None
                else:
                    self.stack = [None]
                self.imperfect = True

    def visit_COMPARE_OP(self, lo, hi):
        op2, op1 = self.stack.pop(), self.stack.pop()
        op = self.sql_cmp_op[lo + (hi << 8)]
        if op == 'in':
            self.stack.append(icontainedby(op1, op2))
            self.imperfect = True
        elif op == 'not in':
            self.stack.append(icontainedby(op1, op2, True))
            self.imperfect = True
        elif op in ('=', '<>', 'is', 'is not') and 'Null' in (op1, op2):
            # "x = Null" / "x <> Null" are never true in SQL.
            if op2 == 'Null':
                operand = op1
            else:
                operand = op2
            if op in ('=', 'is'):
                self.stack.append(operand + " Is Null")
            else:
                self.stack.append(operand + " Is Not Null")
        else:
            if op == 'is':
                op = '='
            elif op == 'is not':
                op = '<>'
            if op2.startswith("'") and op2.endswith("'"):
                # All ADO comparison operators for strings are case-insensitive
                # by default. Rather than determine column-by-column which
                # might be case-sensitive, just flag them all as imperfect.
                self.imperfect = True
            self.stack.append(op1 + " " + op + " " + op2)

    def binary_op(self, op):
        op2, op1 = self.stack.pop(), self.stack.pop()
        self.stack.append(op1 + " " + op + " " + op2)

    def visit_BINARY_SUBSCR(self):
        """The only BINARY_SUBSCR used in Expressions should be kwargs[key]."""
        name = self.stack.pop()
        # name, since formed in LOAD_CONST, has extraneous single-quotes.
        # (coerce_str no longer LIKE-escapes, so keys with '_' now resolve.)
        value = self.expr.kwargs[name[1:-1]]
        value = self.adapter.coerce(value)
        self.stack.append(value)

    def visit_UNARY_NOT(self):
        op = self.stack.pop()
        if op is None:
            # Usually as a result of has(farClassName).
            self.stack.append(None)
        else:
            self.stack.append("not (" + op + ")")
464
465
def safe_name(content):
    """Return `content` as unicode text with every underscore removed."""
    text = unicode(content)
    return u"".join(text.split(u"_"))
468
469
class StoreIteratorADO(object):
    """Iterator for populating Units from storage.

    The SELECT for (unitClass, expr) is built by store.select() on
    construction. units() runs it and yields the populated Units.
    """

    def __init__(self, store, unitClass, expr):
        self.store = store
        self.unitClass = unitClass
        self.expr = expr
        # Column name -> column index in the fetched data.
        self.colIndices = {}
        # ADO Field.Type of each column, by column index.
        self.fieldTypes = []

        self.sql, self.imperfect = store.select(unitClass, expr)

    def field(self, key, row):
        """Return the (adoType, value) pair for column `key` of row `row`.

        Raises KeyError (with key and class name appended to its args) for
        an unknown column.
        """
        try:
            col = self.colIndices[key]
        except KeyError:
            err = sys.exc_info()[1]
            err.args += (key, self.unitClass.__name__)
            # Bare raise keeps the original traceback.
            raise
        return (self.fieldTypes[col], self.data[col][row])

    def load_data(self):
        """Run self.sql and cache the column layout and GetRows() data.

        self.data is column-major (self.data[col][row]), or [] if there
        were no rows. The recordset is always closed.
        """
        anRS = self.store.recordset(self.sql, adOpenForwardOnly,
                                    adLockReadOnly)
        try:
            self.colIndices = {}
            self.fieldTypes = []
            for col, fld in enumerate(anRS.Fields):
                self.colIndices[fld.Name] = col
                self.fieldTypes.append(fld.Type)

            self.data = []
            if not (anRS.BOF and anRS.EOF):
                # We tried .MoveNext() and lots of Fields.Item() calls.
                # Using GetRows() beats that time by about 2/3.
                self.data = anRS.GetRows()
        finally:
            anRS.Close()

    def _expanded_values(self, tbl, row, key):
        """Return the unpickled EXPVAL entries for expanded column `key`."""
        query = (u"SELECT EXPVAL FROM [%s_%s_%s]"
                 % (tbl, safe_name(self.field('ID', row)[1]), safe_name(key)))
        try:
            rs = self.store.recordset(query)
        except pywintypes.com_error:
            # This usually occurs because the parent Unit
            # was reserved but no table yet made for these
            # expanded values. This is OK. TODO: trap this
            # more specifically by examining the errmsg.
            return []
        try:
            if rs.BOF and rs.EOF:
                # GetRows() raises on an empty recordset.
                return []
            # NOTE(review): pickle.loads on stored data is only safe if the
            # database contents are trusted.
            return [pickle.loads(str(v)) for v in rs.GetRows()[0]]
        finally:
            rs.Close()

    def units(self):
        """Yield a populated Unit for each matching row.

        When the SQL is imperfect, rows are additionally filtered through
        self.expr.evaluate().
        """
        s = self.store
        clsname = self.unitClass.__name__
        tbl = "%s_%s" % (s.prefix, safe_name(clsname))
        self.load_data()
        if not self.data:
            return
        for row in range(len(self.data[0])):
            unit = self.unitClass()
            coercer = AdapterFromADO(unit)
            for key in unit.__class__.properties():
                if (clsname, key) in s.expanded_columns:
                    # Grab the expanded data from its own table.
                    values = self._expanded_values(tbl, row, key)
                    expectedType = unit.__class__.property_type(key)
                    # Set the attribute directly to avoid __set__ overhead.
                    unit._properties[key] = expectedType(values)
                else:
                    coercer.consume(key, self.field(key, row))
            # If our SQL is imperfect, don't yield it to the
            # caller unless it passes evaluate().
            if (not self.imperfect) or self.expr.evaluate(unit):
                yield unit
545
546
class StoreMultiIteratorADO(StoreIteratorADO):
    """Iterator for populating (Unit, far Unit) pairs from a joined SELECT."""

    def __init__(self, store, unitClass, expr, pairs):
        self.store = store
        self.unitClass = unitClass
        self.expr = expr
        self.pairs = pairs
        self.fieldTypes = []

        sel = store.multiselect(unitClass, expr, pairs)
        # self.columns lists the (class, key) behind each selected column.
        self.sql, self.imperfect, self.columns = sel
        # First column position for each (class, key), built once instead
        # of scanning self.columns for every property of every row.
        self._colIndex = {}
        for col, pair in enumerate(self.columns):
            if pair not in self._colIndex:
                self._colIndex[pair] = col

    def populate_unit(self, unit, row):
        """Populate a Unit from a database row.

        Raises ValueError (with cls and key in its args) if a property of
        the unit's class has no selected column.
        """
        coercer = AdapterFromADO(unit)
        cls = unit.__class__
        for key in cls.properties():
            try:
                col = self._colIndex[(cls, key)]
            except KeyError:
                raise ValueError("%r is not in list" % ((cls, key),), cls, key)
            coercer.consume(key, (self.fieldTypes[col], self.data[col][row]))

    def load_data(self):
        """Run self.sql and cache field types and column-major row data."""
        anRS = self.store.recordset(self.sql, adOpenForwardOnly,
                                    adLockReadOnly)
        try:
            self.fieldTypes = [fld.Type for fld in anRS.Fields]
            self.data = []
            if not (anRS.BOF and anRS.EOF):
                # GetRows() is much faster than MoveNext()/Fields.Item().
                self.data = anRS.GetRows()
        finally:
            anRS.Close()

    def units(self):
        """Yield (unit, farUnit) pairs; farUnit is None when no far row joined."""
        self.load_data()
        if not self.data:
            return
        for row in range(len(self.data[0])):
            unit = self.unitClass()
            self.populate_unit(unit, row)
            # If our SQL is imperfect, don't yield it to the
            # caller unless it passes evaluate().
            if self.imperfect and not self.expr.evaluate(unit):
                continue
            cls, expr = self.pairs[0]
            farUnit = cls()
            self.populate_unit(farUnit, row)
            if farUnit.ID is None:
                # The LEFT JOIN matched no far row.
                yield unit, None
            elif ((not self.imperfect) or expr is None
                  or expr.evaluate(farUnit)):
                yield unit, farUnit
607
608
class FieldTypeAdapter(object):
    """Map a Unit property's Python type to the SQL type name of its column.

    coerce(cls, key) looks up cls.property_type(key) and dispatches to a
    coerce_<typename> method for builtins, or coerce_<module>_<typename>
    for other types (with dots replaced by underscores).
    """

    def coerce(self, cls, key):
        """coerce(cls, key) -> SQL typename for valuetype."""
        valuetype = cls.property_type(key)
        if valuetype.__module__ == "__builtin__":
            qualified = valuetype.__name__
        else:
            qualified = "%s_%s" % (valuetype.__module__, valuetype.__name__)
        handler_name = ("coerce_" + qualified).replace(".", "_")
        handler = getattr(self, handler_name, None)
        if handler is None:
            raise TypeError("'%s' is not handled by %s." %
                            (valuetype, self.__class__))
        return handler(cls, key)

    def _create_str_storage(self, cls, key):
        """Return a VARCHAR sized by the property's 'Size' hint (default 255).

        This basic handler knows nothing about the size limits of any
        particular database. Use a database-specific subclass if you need
        storage for strings over 255 characters.
        """
        hints = getattr(cls, key).hints
        return u"VARCHAR(%s)" % hints.get(u'Size', '255')

    # Containers are stored pickled, so they share the text column type.
    coerce_dict = coerce_list = coerce_str = _create_str_storage
    coerce_tuple = coerce_unicode = _create_str_storage

    def coerce_bool(self, cls, key):
        return u"BIT"

    def coerce_datetime_datetime(self, cls, key):
        return u"TIMESTAMP"

    def coerce_datetime_date(self, cls, key):
        return u"DATE"

    def coerce_datetime_time(self, cls, key):
        return u"TIME"

    def coerce_fixedpoint_FixedPoint(self, cls, key):
        return u"FLOAT"

    def coerce_float(self, cls, key):
        return u"FLOAT"

    def coerce_int(self, cls, key):
        return u"INTEGER"
653
654
655 class StorageManagerADO(storage.StorageManager):
656     """StoreManager to save and retrieve Units via ADO 2.7.
657     
658     You must run makepy on ADO 2.7 before installing.
659     """
660    
661     decompiler = ADOSQLDecompiler
662     createAdapter = FieldTypeAdapter()
663     threaded = False
664    
665     def __init__(self, name, arena, allOptions={}):
666         pythoncom.CoInitialize()
667        
668         storage.StorageManager.__init__(self, name, arena, allOptions)
669        
670         self.connstring = allOptions[u'Connect']
671         if allOptions.get(u'Threaded', ''):
672             self.threaded = True
673             self._connection = None
674         else:
675             self._connection = win32com.client.Dispatch(r'ADODB.Connection')
676             self._connection.Open(self.connstring)
677        
678         self.prefix = allOptions.get(u'Prefix', u"djv")
679         self.cursorType = int(allOptions.get(u'CursorType', adOpenDynamic))
680         self.lockType = int(allOptions.get(u'LockType', adLockOptimistic))
681        
682         ec = []
683         for prop in allOptions.get(u'Expanded Columns', '').split(","):
684             if prop:
685                 lastdot = prop.rfind(".")
686                 clsname, key = prop[:lastdot], prop[lastdot + 1:]
687                 ec.append((clsname, key))
688         self.expanded_columns = ec
689        
690         self.reserve_lock = threading.Lock()
691    
692     def shutdown(self):
693         if self._connection is not None:
694             self._connection.Close()
695    
696     def connection(self):
697         if self.threaded:
698             t = threading.currentThread()
699             if not hasattr(t, 'SMADOconn'):
700                 t.SMADOconn = win32com.client.Dispatch(r'ADODB.Connection')
701             if t.SMADOconn.State == adStateClosed:
702                 t.SMADOconn.Open(self.connstring)
703             return t.SMADOconn
704         else:
705             return self._connection
706    
707     def recordset(self, aQuery, cursorType=None, lockType=None):
708         anRS = win32com.client.Dispatch(r'ADODB.Recordset')
709 ##        anRS.Cursorlocation = 3     # adUseClient; Use to obtain .Recordcount
710         if cursorType is None:
711             cursorType = self.cursorType
712         if lockType is None:
713             lockType = self.lockType
714        
715         try:
716             anRS.Open(aQuery, self.connection(), cursorType, lockType)
717         except pywintypes.com_error, x:
718             try:
719                 anRS.Close()
720             except:
721                 pass
722             x.args += (aQuery, )
723             raise x
724         return anRS
725    
726     def _join(self, path=[]):
727         if not path: return u''
728         firstcls = path.pop(0)
729         if not path: return firstcls.__name__
730        
731         spath = self.arena.associations.shortest_path(firstcls, path[0])
732         spath.pop(0)
733         cls = spath[0]
734         leftkey, rightkey = firstcls._associations[cls]
735         params = {u'prefix': self.prefix,
736                   u'left': firstcls.__name__,
737                   u'right': cls.__name__,
738                   u'leftkey': leftkey,
739                   u'rightkey': rightkey,
740                   }
741         if len(spath) == 1:
742             params[u'child'] = u"[%(prefix)s%(right)s]" % params
743         else:
744             params[u'child'] = u"(%s)" % self._join(spath)
745        
746         return (u"[%(prefix)s%(left)s] LEFT JOIN %(child)s"
747                 u" ON [%(prefix)s%(left)s].[%(leftkey)s] = "
748                 u"[%(prefix)s%(right)s].[%(rightkey)s]" % params)
749    
750     def multiselect(self, firstcls, firstexpr, pairs):
751         firstwhere, imp = self.where(firstcls, firstexpr)
752         cols = [(firstcls, k) for k in firstcls.properties()]
753        
754         # TODO: concat multiple pairs.
755         assert len(pairs) == 1
756         for cls, expr in pairs:
757             if expr is None:
758                 expr = logic.Expression(lambda x: True)
759             j = self._join([firstcls, cls])
760            
761             w, new_imp = self.where(cls, expr)
762             imp |= new_imp
763             if w and w != "True":
764                 w = " WHERE %s AND %s" % (w, firstwhere)
765             else:
766                 w = " WHERE %s" % firstwhere
767            
768             cols += [(cls, k) for k in cls.properties()]
769             colnames = ["[%s%s].[%s]" % (self.prefix, colcls.__name__, k)
770                         for colcls, k in cols]
771            
772             statement = "SELECT %s FROM %s%s" % (u', '.join(colnames), j, w)
773            
774             return statement, imp, cols
775    
776     def select(self, unitClass, expr, distinct_fields=None):
777         tablename = self.prefix + safe_name(unitClass.__name__)
778         if distinct_fields:
779             distinct_fields = [u'[%s]' % x for x in distinct_fields]
780             sql = (u"SELECT DISTINCT %s FROM [%s]" %
781                    (u', '.join(distinct_fields), tablename))
782         else:
783             sql = u"SELECT * FROM [%s]" % tablename
784         w, i = self.where(unitClass, expr)
785         if len(w) > 0:
786             w = u" WHERE " + w
787         else:
788             w = u""
789         sql += w
790         return sql, i
791    
792     def where(self, cls, expr):
793         return self.decompiler(self, cls, expr).code()
794    
795     def execute(self, aQuery, conn=None):
796         if conn is None:
797             conn = self.connection()
798         try:
799             conn.Execute(aQuery)
800         except pywintypes.com_error, x:
801             x.args += (aQuery, )
802             raise x
803    
804     def recall(self, cls, expr=None, pairs=None):
805         if expr is None:
806             expr = logic.Expression(lambda x: True)
807        
808         if pairs is not None:
809             return StoreMultiIteratorADO(self, cls, expr, pairs).units()
810         else:
811             return StoreIteratorADO(self, cls, expr).units()
812    
813     def reserve(self, unit):
814         """reserve(unit). -> Reserve a persistent slot for unit."""
815         self.reserve_lock.acquire()
816         try:
817             if unit.ID is None:
818                 data = []
819                 clsname = unit.__class__.__name__
820                 anRS = self.recordset(u"SELECT ID FROM [%s%s];" %
821                                       (self.prefix, safe_name(clsname)))
822                 if not (anRS.BOF and anRS.EOF):
823                     data = anRS.GetRows()[0]
824                 unit.ID = unit.sequencer.next(data)
825                
826                 anRS.AddNew()
827                 anRS.Fields(u'ID').Value = unit.ID
828                 anRS.Update()
829                 anRS.Close()
830         finally:
831             self.reserve_lock.release()
832    
833     def save(self, unit, forceSave=False):
834         """save(unit, forceSave=False). -> Update storage from unit's data.
835         
836         Notice in particular that we do not use the auto-number or
837         sequence generation capabilities within some databases, etc.
838         The ID should be supplied by UnitSequencers via reserve().
839         """
840         if unit.dirty or forceSave:
841             cls = unit.__class__
842             clsname = cls.__name__
843             # Use a cursor always--makes mixed-quotes, newline, etc easier.
844             anRS = self.recordset("SELECT * FROM [%s%s] WHERE ID = %s" %
845                                   (self.prefix, safe_name(clsname),
846                                    AdapterToADOSQL().coerce(unit.ID)))
847             if anRS.EOF and anRS.BOF:
848                 anRS.AddNew()
849                 anRS.Fields(u'ID').Value = unit.ID
850             fmt = AdapterToADOFields()
851             for key in cls.properties():
852                 if (clsname, key) in self.expanded_columns:
853                     # Special-case this field into its own table.
854                     self.save_expanded(unit, key)
855                 else:
856                     eachType = cls.property_type(key)
857                     newValue = fmt.coerce(getattr(unit, key), eachType)
858                     try:
859                         anRS.Fields(key).Value = newValue
860                     except pywintypes.com_error, x:
861                         try:
862                             anRS.Close()
863                         except:
864                             pass
865                         x.args += (clsname, key, eachType, newValue)
866                         raise x
867             anRS.Update()
868             anRS.Close()
869             unit.dirty = False
870    
871     def save_expanded(self, unit, key):
872         """Save a field using a table specifically for that purpose."""
873         unitcls = unit.__class__
874         table = ("%s_%s_%s_%s" % (self.prefix, safe_name(unitcls.__name__),
875                                   safe_name(unit.ID), safe_name(key)))
876        
877         conn = self.connection()
878         try:
879             self.execute((u"DROP TABLE [%s];" % table), conn)
880         except pywintypes.com_error, x:
881             pass
882        
883         # Ugly, ugly hack to get NTEXT or MEMO as appropriate. The point
884         # is, we want a large text field so we can pickle each item.
885         ftype = self.createAdapter.coerce_list(None, None)
886         self.execute(u"CREATE TABLE [%s] (EXPVAL %s);" % (table, ftype), conn)
887        
888         ins = u"INSERT INTO [" + table + "] (EXPVAL) VALUES ('%s');"
889         for v in getattr(unit, key):
890             # Create a row for the unit.
891             # Use an INSERT command (not a cursor) for better performance.
892             v = pickle.dumps(v).replace("'", "''")
893             self.execute(ins % v, conn)
894    
895     def destroy(self, unit):
896         """Delete the unit."""
897         # Use a DELETE command instead of a cursor for better performance.
898         deleteStatement = (u"DELETE * FROM [%s%s] WHERE ID = %s" %
899                            (self.prefix, safe_name(unit.__class__.__name__),
900                             AdapterToADOSQL().coerce(unit.ID)))
901         self.execute(deleteStatement)
902    
903     def create_storage(self, unitClass):
904         clsname = safe_name(unitClass.__name__)
905        
906         coerce = self.createAdapter.coerce
907         fields = []
908         for key in unitClass.properties():
909             if (unitClass.__name__, key) not in self.expanded_columns:
910                 fields.append(u"[%s] %s" % (key, coerce(unitClass, key)))
911         self.execute(u"CREATE TABLE [%s%s] (%s)" %
912                      (self.prefix, clsname, ", ".join(fields)))
913        
914         for index in unitClass.indices():
915             self.execute(u"CREATE INDEX [%si%s%s] ON [%s%s] (%s ASC)"
916                          % (self.prefix, clsname, safe_name(index),
917                             self.prefix, clsname, index))
918    
919     def distinct(self, cls, fields, expr=None):
920         """Return distinct values for specified fields."""
921         if expr is None:
922             expr = logic.Expression(lambda x: True)
923        
924         # ^%$#@! There's no way to handle imperfect queries without
925         # creating all involved Units, which defeats the purpose of
926         # distinct, which was a speed issue more than anything. Grr.
927         sql, imperfect = self.select(cls, expr, fields)
928         # Ignore for now.
929 ##        if imperfect:
930 ##            raise ValueError(u"The following query cannot be reliably "
931 ##                             u"returned from an ADO data source.",
932 ##                             u"distinct()", cls, fields, expr)
933        
934         anRS = self.recordset(sql, adOpenForwardOnly, adLockReadOnly)
935        
936         fieldTypes = [x.Type for x in anRS.Fields]
937         data = []
938         if not (anRS.BOF and anRS.EOF):
939             # We tried .MoveNext() and lots of Fields.Item() calls.
940             # Using GetRows() beats that time by about 2/3.
941             data = anRS.GetRows()
942         anRS.Close()
943        
944         if data:
945             coerced_data = []
946             coerce = AdapterFromADO().coerce
947             for col, field in enumerate(fields):
948                 expectedType = cls.property_type(field)
949                 actualType = fieldTypes[col]
950                 coerced_row = [coerce((actualType, val), expectedType)
951                                for val in data[col]]
952                 coerced_data.append(coerced_row)
953             data = zip(*coerced_data)
954         return data
955
956
957 ###########################################################################
958 ##                                                                       ##
959 ##                             SQL Server                                ##
960 ##                                                                       ##
961 ###########################################################################
962
963
class FieldTypeAdapter_SQLServer(FieldTypeAdapter):
    """Choose SQL Server column types for unit properties."""

    def _create_str_storage(self, cls, key):
        """Return the column type for a str/unicode property.

        The property's 'Size' hint (default 255) selects VARCHAR(size); a
        size of 0, or one above 8000, selects NTEXT instead.
        NOTE(review): unicode properties also get VARCHAR, which cannot
        hold characters outside the server code page; NVARCHAR may have
        been intended -- confirm.
        """
        prop = getattr(cls, key)
        # The default must be an int. Under CPython 2 the former str
        # default '255' compared greater than every int, so each property
        # without an explicit Size hint silently became NTEXT. int() also
        # accepts hints supplied as numeric strings.
        size = int(prop.hints.get(u'Size', 255))
        if size == 0 or size > 8000:
            # 8000 *bytes* is the absolute upper limit, based on T_SQL docs
            # for varchar. If there are further fields defined for the class,
            # or the code page uses a double-byte character set, we still
            # might exceed the max size (8060) for a record. We could calc
            # the total requested record size, and adjust accordingly. For
            # now, we just trust that units generally use a size of 0 to
            # bump up to NTEXT (1 gig characters).
            return u"NTEXT"
        return u"VARCHAR(%s)" % size

    # dict, list, and tuple will all be pickled in AdapterToADO
    def coerce_dict(self, cls, key): return u"NTEXT"
    def coerce_list(self, cls, key): return u"NTEXT"
    coerce_str = _create_str_storage
    def coerce_tuple(self, cls, key): return u"NTEXT"
    coerce_unicode = _create_str_storage
986
987
class StorageManagerADO_SQLServer(StorageManagerADO):
    """Storage Manager for Microsoft SQL Server, accessed through ADO.

    The only thing it overrides is the column-type adapter, which maps
    large text and pickled containers to NTEXT.
    """

    # Column types used by create_storage() and save_expanded().
    createAdapter = FieldTypeAdapter_SQLServer()
990
991
992 ###########################################################################
993 ##                                                                       ##
994 ##                             MS Access                                 ##
995 ##                                                                       ##
996 ###########################################################################
997
998
def _msaccess_like_literal(y):
    """Unquote the SQL string literal y and escape it for use in Like.

    y is expected to be a quoted SQL string literal; its first and last
    characters are dropped. Then every '[', '%' and '_' is wrapped in
    brackets so Like matches it literally instead of treating it as a
    character list or wildcard (e.g. "'50%'" -> "50[%]").
    '[' is escaped first so the brackets added for '%' and '_' are not
    escaped a second time.
    """
    text = y[1:-1]
    for special in "[%_":
        text = text.replace(special, "[" + special + "]")
    return text


class ADOSQLDecompiler_MSAccess(ADOSQLDecompiler):
    """Decompile dejavu expressions into MS Access (Jet) SQL.

    Pattern functions use Like with the '%' wildcard. The literal text
    spliced into each pattern is escaped first, so search values that
    contain '%', '_' or '[' match only themselves.
    """
    sql_cmp_op = ('<', '<=', '=', '<>', '>', '>=', 'in', 'not in')
    functions = {
        dejavu.icontains:
            lambda x, y: x + " Like '%" + _msaccess_like_literal(y) + "%'",
        dejavu.icontainedby: icontainedby,
        dejavu.istartswith:
            lambda x, y: x + " Like '" + _msaccess_like_literal(y) + "%'",
        dejavu.iendswith:
            lambda x, y: x + " Like '%" + _msaccess_like_literal(y) + "'",
        dejavu.ieq: lambda x, y: x + " = " + y,
        dejavu.now: lambda: "Now()",
        dejavu.today: lambda: "DateValue(Now())",
        dejavu.year: lambda x: "Year(" + x + ")",
        }
1010
1011
class FieldTypeAdapter_MSAccess(FieldTypeAdapter):
    """Choose MS Access (Jet) column types for unit properties."""

    def _create_str_storage(self, cls, key):
        """Return the column type for a str/unicode property.

        The property's 'Size' hint (default 255) selects VARCHAR(size); a
        size of 0, or one above 255, selects MEMO instead.
        """
        prop = getattr(cls, key)
        # The default must be an int. Under CPython 2 the former str
        # default '255' compared greater than every int, so each property
        # without an explicit Size hint silently became MEMO. int() also
        # accepts hints supplied as numeric strings.
        size = int(prop.hints.get(u'Size', 255))
        if size == 0 or size > 255:
            # 255 chars is the upper limit for TEXT / VARCHAR in MS Access.
            # MEMO is 1 gigabyte when set programatically (only 64K when set
            # in Access UI). But then, 1 GB is the limit for the whole DB.
            return u"MEMO"
        return u"VARCHAR(%s)" % size

    # dict, list, and tuple will all be pickled in AdapterToADO
    def coerce_dict(self, cls, key): return u"MEMO"
    def coerce_list(self, cls, key): return u"MEMO"
    coerce_str = _create_str_storage
    def coerce_tuple(self, cls, key): return u"MEMO"
    coerce_unicode = _create_str_storage
1030
1031
class StorageManagerADO_MSAccess(StorageManagerADO):
    """Storage Manager for Microsoft Access (Jet) databases via ADO.

    It overrides two hooks of the generic ADO manager:
    - the column-type adapter, which maps large text and pickled
      containers to MEMO;
    - the expression decompiler, which uses the Access SQL dialect
      (Like, Now(), DateValue(), Year()).
    """

    createAdapter = FieldTypeAdapter_MSAccess()
    decompiler = ADOSQLDecompiler_MSAccess
1036
1037
if __name__ == '__main__':
    # Ensure win32com has generated .py support for the ADO 2.7 type
    # library: typelib CLSID below, LCID 0, major version 2, minor 7.
    ADO27_TYPELIB = '{EF53050B-882E-4776-B643-EDA472E8E3F2}'
    print('Please wait while support for ADO 2.7 is verified...')
    win32com.client.gencache.EnsureModule(ADO27_TYPELIB, 0, 2, 7)
1043
Note: See TracBrowser for help on using the browser.