Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/storeado.py

Revision 8 (checked in by fumanchu, 9 years ago)

codewalk and logic are now modules in dejavu

Line 
1 import sys
2 # Put COM in free-threaded mode
3 sys.coinit_flags = 0
4
5 import win32com.client
6 import fixedpoint
7 import pywintypes
8 import pythoncom
9 import threading
10 import datetime
11 try:
12     import cPickle as pickle
13 except ImportError:
14     import pickle
15 import recur
16 from types import FunctionType
17
18 import dejavu
19 from dejavu import storage, codewalk, logic
20
# ADO CursorTypeEnum values.
(adOpenForwardOnly, adOpenKeyset, adOpenDynamic, adOpenStatic) = range(4)

# ADO LockTypeEnum values.
(adLockReadOnly, adLockPessimistic, adLockOptimistic,
 adLockBatchOptimistic) = range(1, 5)

# ADO ConnectModeEnum value.
adModeShareExclusive = 12

# ADO ObjectStateEnum values: 0 when closed, otherwise bit flags.
adStateClosed = 0
adStateOpen, adStateConnecting, adStateExecuting, adStateFetching = 1, 2, 4, 8

# ADO/COM dates count days from 12/30/1899, whose ordinal is 693594.
zeroHour = datetime.date(1899, 12, 30).toordinal()
41
42
def time_from_com(com_date):
    """Return a valid (day, datetime.time) from a COM date or time object.

    A COM (OLE Automation) date is a float whose fractional part is the
    time of day. For negative values (dates before 12/30/1899) that
    fraction is still a positive offset into the day: -1.25 means 06:00,
    not 18:00. The absolute value is therefore taken before extracting it.

    The time is rounded to whole seconds *before* it is split into
    hours/minutes/seconds, so minutes and seconds are always < 60. A
    fraction that rounds up to a full day yields hour 24. This relies on
    recur.sane_time carrying such overflow into the returned day count.
    That is an assumption: recur is not visible here, and the previous
    implementation relied on it in the same way for seconds == 60.
    """
    fraction = abs(float(com_date)) % 1
    total = int(round(86400 * fraction))
    hour, rest = divmod(total, 3600)
    mins, sec = divmod(rest, 60)
    return recur.sane_time(0, hour, mins, sec)
52
53
class AdapterFromADO(storage.Adapter):
    """Coerce incoming values from ADO to Dejavu datatypes.

    Every coerce_* method (and pickle) takes a (fieldType, rawValue) pair.
    fieldType is the ADO Field.Type code read from the recordset, and
    rawValue is the value pywin32 returned for that field. A rawValue of
    None (database NULL) always coerces to None.
    """

    # ADO DataTypeEnum codes inspected below.
    _adCurrency = 0x06
    _adBoolean = 0x0b

    def __init__(self, unit=None):
        # unit: the Unit that consume() populates. It may be None when only
        # coerce() is used.
        self.default_processor = self.coerce_unicode
        self.processors = {datetime.datetime: self.coerce_datetime,
                           datetime.date: self.coerce_date,
                           datetime.time: self.coerce_time,
                           float: self.coerce_float,
                           int: self.coerce_int,
                           bool: self.coerce_int,
                           fixedpoint.FixedPoint: self.coerce_fixed,
                           dict: self.pickle,
                           list: self.pickle,
                           str: self.coerce_str,
                           }
        self.unit = unit

    def consume(self, key, value):
        """Coerce value to the declared type of property key; store on self.unit."""
        expectedType = self.unit.__class__.property_type(key)
        value = self.coerce(value, expectedType)
        # Set the attribute directly to avoid __set__ overhead.
        self.unit._properties[key] = value

    def _currency(self, value):
        """Return an ADO Currency raw value as a number of currency units.

        Currency is a 64-bit integer scaled by 10000. When pywin32 hands it
        over as a (hi, lo) pair of 32-bit halves, both halves must be
        combined. Using only the low half is wrong for negative amounts and
        for amounts of about 429,497 or more. lo is masked because it may
        arrive as a signed 32-bit value.

        Any non-tuple value is returned unchanged for the caller to convert.
        NOTE(review): such a value is assumed to be an already-scaled number
        from pywin32 builds that do not use the tuple form; confirm.
        """
        if isinstance(value, tuple):
            hi, lo = value
            return ((hi << 32) + (lo & 0xFFFFFFFF)) / 10000.0
        return value

    def pickle(self, value):
        """Unpickle a dict/list value stored as a text pickle."""
        aType, value = value
        if value is None:
            return None
        # Coerce to str for pickle.loads restriction.
        return pickle.loads(str(value))

    def coerce_datetime(self, value):
        # Illegal Date/Time values will crash the app when using
        # value.Format(), so grab the raw value and figure the date
        # ourselves. Strings are read as YYYYMMDD (any time part is ignored);
        # numbers are COM dates. Use 1-second resolution only.
        aType, value = value
        if value is None:
            return None
        elif isinstance(value, basestring):
            return datetime.datetime(int(value[0:4]), int(value[4:6]),
                                     int(value[6:8]))
        else:
            # For some reason, we need both float and int.
            aDate = datetime.date.fromordinal(int(float(value)) + zeroHour)
            day, aTime = time_from_com(value)
            return datetime.datetime.combine(aDate, aTime)

    def coerce_date(self, value):
        # See coerce_datetime.
        aType, value = value
        if value is None:
            return None
        elif isinstance(value, basestring):
            return datetime.date(int(value[0:4]), int(value[4:6]),
                                 int(value[6:8]))
        else:
            return datetime.date.fromordinal(int(float(value)) + zeroHour)

    def coerce_time(self, value):
        # See coerce_datetime.
        aType, value = value
        if value is None:
            return None
        day, aTime = time_from_com(value)
        return aTime

    def coerce_float(self, value):
        aType, value = value
        if value is None:
            return None
        if aType == self._adCurrency:
            value = self._currency(value)
        return float(value)

    def coerce_int(self, value):
        aType, value = value
        if value is None:
            return None
        if aType == self._adBoolean:
            return value != 0
        return int(value)

    def coerce_fixed(self, value):
        aType, value = value
        if value is None:
            return None
        if aType == self._adCurrency:
            value = self._currency(value)
        return fixedpoint.FixedPoint(value)

    def coerce_str(self, value):
        aType, value = value
        if value is None:
            return None
        return str(value)

    def coerce_unicode(self, value):
        aType, value = value
        if value is None:
            return None
        if isinstance(value, unicode):
            # For some reason, inValue is already a unicode object.
            return value
        if isinstance(value, str):
            try:
                return unicode(value, "ISO-8859-1")
            except UnicodeError:
                raise StandardError(type(value))
        return unicode(value)
170
171
172
class AdapterToADOFields(storage.Adapter):
    """Coerce outgoing values from Dejavu datatypes to ADO.Field types.

    Dates and times become COM date floats: days since 12/30/1899 plus a
    fraction of a day, at 1-second resolution. Dicts and lists are pickled.
    Other values pass through unchanged.
    """

    def __init__(self):
        self.default_processor = self.coerce_str
        self.processors = {datetime.datetime: self.coerce_datetime,
                           datetime.date: self.coerce_date,
                           datetime.time: self.coerce_time,
                           fixedpoint.FixedPoint: self.coerce_fixed,
                           dict: self.pickle,
                           list: self.pickle,
                           bool: self.coerce_bool,
                           }

    def coerce_bool(self, value):
        return bool(value)

    def coerce_fixed(self, value):
        if value is None:
            return None
        return float(value)

    def coerce_datetime(self, value):
        """Return value as a COM date float.

        COM stores the time-of-day fraction as a positive offset even for
        negative (pre-12/30/1899) dates, so -1.25 is 12/29/1899 06:00. The
        fraction is therefore subtracted when the day count is negative.
        """
        if value is None:
            return None
        days = self.coerce_date(value)
        fraction = self.coerce_time(value)
        if days < 0:
            return days - fraction
        return days + fraction

    def coerce_date(self, value):
        if value is None:
            return None
        return value.toordinal() - zeroHour

    def coerce_time(self, value):
        if value is None:
            return None
        return ((value.second + (value.minute * 60) + (value.hour * 3600))
                / 86400.0)

    def coerce_str(self, value):
        return value

    def pickle(self, value):
        # We must not use a pickle format other than 0, because binary
        # strings are not safe for all DB string fields. Pass protocol 0
        # explicitly rather than relying on the pickle module's default.
        return pickle.dumps(value, 0)
220
221
class AdapterToADOSQL(storage.Adapter):
    """Coerce Expression constants to ADO SQL literals."""

    def __init__(self):
        self.default_processor = str    # NOT coerce_str
        self.processors = {datetime.datetime: self.coerce_datetime,
                           datetime.date: self.coerce_date,
                           datetime.time: self.coerce_time,
                           datetime.timedelta: self.coerce_timedelta,
                           fixedpoint.FixedPoint: self.coerce_fixed,
                           float: self.coerce_float,
                           str: self.coerce_str,
                           unicode: self.coerce_str,
                           tuple: self.coerce_tuple,
                           bool: self.coerce_bool,
                           type(None): self.coerce_none,
                           }

    def coerce_none(self, value):
        return "Null"

    def coerce_bool(self, value):
        if value:
            return 'True'
        return 'False'

    def coerce_float(self, value):
        # repr, not str: in Python 2, str() keeps only 12 significant
        # digits, so a constant like 1234567.891234 would be truncated.
        return repr(value)

    def coerce_datetime(self, value):
        return (u'#%s/%s/%s %02d:%02d:%02d#' %
                (value.month, value.day, value.year,
                 value.hour, value.minute, value.second))

    def coerce_date(self, value):
        return u'#%s/%s/%s#' % (value.month, value.day, value.year)

    def coerce_time(self, value):
        return u'#%02d:%02d:%02d#' % (value.hour, value.minute, value.second)

    def coerce_timedelta(self, value):
        # Express the interval as a (fractional) number of days.
        float_val = value.days + (value.seconds / 86400.0)
        return repr(float_val)

    def coerce_fixed(self, value):
        return str(value)

    def coerce_str(self, value):
        """Return value as a quoted SQL string literal.

        Single quotes are doubled. % and _ are bracketed ([%], [_]) so that
        the literal is safe inside a Like pattern.
        NOTE(review): in a plain = comparison those brackets are matched
        literally, so = comparisons against values containing % or _ will
        not match; confirm whether callers rely on that.
        """
        # Plain-str patterns on purpose: in Python 2, str.replace with
        # unicode arguments decodes the str as ASCII and raises
        # UnicodeDecodeError for non-ASCII byte strings.
        value = value.replace("'", "''")
        value = value.replace("%", "[%]")
        value = value.replace("_", "[_]")
        return "'" + value + "'"

    def coerce_tuple(self, value):
        return "(" + ", ".join([self.coerce(x) for x in value]) + ")"
273
274
def icontainedby(op1, op2, notin=False):
    """Return SQL testing whether op1 is contained in op2.

    When op2 looks like a bracketed field reference, op1 is taken to be a
    quoted string literal and the test becomes op2 Like '%<text>%'.
    Otherwise it becomes op1 in op2. notin prefixes the result with "not ".
    """
    # This test doesn't work right, now that we use lists as
    # well as tuples with IN. Need a way to mark field refs.
    looks_like_field = op2.startswith("[") and op2.endswith("]")
    if looks_like_field:
        # Searching for text inside a field: reverse the terms and use Like.
        parts = [op2, " Like '%", op1[1:-1], "%'"]
    else:
        # Searching for a field inside a literal sequence.
        parts = [op1, " in ", op2]
    if notin:
        parts.insert(0, "not ")
    return "".join(parts)
287
288
class ADOSQLDecompiler(codewalk.LambdaDecompiler):
    """ADOSQLDecompiler(store, unitClass, expr, adapter=None).

    Produce SQL from a supplied lambda of the form:
        lambda x, **kw: ...

    Attributes of x (or whatever the name of the first argument is) will be
    mapped to table columns. Keyword arguments should be bound to the
    Expression before calling this decompiler.

    adapter turns constants into SQL literals. When it is None, a new
    AdapterToADOSQL is used.

    Any sub-expression that cannot be translated becomes None on the stack
    and sets self.imperfect. Callers then re-check rows with
    expr.evaluate(unit).
    """

    # Indexed by the COMPARE_OP oparg, in the same order as Python 2's
    # dis.cmp_op. Inequality is spelled '<>', the standard SQL form that
    # MS Access (Jet) SQL also accepts; '!=' is not accepted there.
    sql_cmp_op = ('<', '<=', '=', '<>', '>', '>=', 'in', 'not in',
                  'is', 'is not')
    functions = {dejavu.icontains: lambda x, y: x + " Like '%" + y[1:-1] + "%'",
                 dejavu.icontainedby: icontainedby,
                 dejavu.istartswith: lambda x, y: x + " Like '" + y[1:-1] + "%'",
                 dejavu.iendswith: lambda x, y: x + " Like '%" + y[1:-1] + "'",
                 dejavu.ieq: lambda x, y: x + " = " + y,
                 dejavu.now: lambda: "getdate()",
                 dejavu.today: lambda: "DATEADD(dd, DATEDIFF(dd,0,getdate()), 0)",
                 dejavu.year: lambda x: "YEAR(" + x + ")",
                 }

    def __init__(self, store, unitClass, expr, adapter=None):
        if adapter is None:
            adapter = AdapterToADOSQL()
        self.store = store
        self.unitClass = unitClass
        self.expr = expr
        self.adapter = adapter
        obj = expr.func
        codewalk.LambdaDecompiler.__init__(self, obj)

    def code(self):
        """Return (sql, imperfect). sql is 'True' if nothing was translatable."""
        self.imperfect = False
        self.walk()
        result = self.stack[0]
        if result is None:
            result = 'True'
        return result, self.imperfect

    def visit_target(self, terms):
        # terms: (term, operation) pairs, presumably collected by codewalk
        # for and/or chains (assumption; confirm in codewalk). They are
        # folded onto the top of the stack from the last pair backwards.
        comp = self.stack.pop()
        while terms:
            term, operation = terms.pop()
            # All this checking of None is done so that a function
            # (like dejavu.iscurrentweek) can be labeled imperfect--
            # all Units (which match the rest of the Expression)
            # will be recalled. They can then be compared in
            # expr.evaluate(unit).
            if comp is None:
                if term is not None:
                    comp = term
            else:
                if term is not None:
                    comp = "(%s) %s (%s)" % (term, operation, comp)
        self.stack.append(comp)

    def visit_LOAD_GLOBAL(self, lo, hi):
        pass

    def visit_LOAD_FAST(self, lo, hi):
        pass

    def visit_LOAD_ATTR(self, lo, hi):
        # Push a qualified column reference. The table name goes through
        # safe_name, as in StorageManagerADO.select(), so the qualifier
        # matches the table named in the FROM clause.
        name = self.co_names[lo + (hi << 8)]
        table = self.store.prefix + safe_name(self.unitClass.__name__)
        self.stack.append("[%s].[%s]" % (table, name))

    def visit_LOAD_CONST(self, lo, hi):
        val = self.co_consts[lo + (hi << 8)]
        # Some constants are function or class objects,
        # which should not be coerced.
        no_coerce = (FunctionType, type)
        if not isinstance(val, no_coerce):
            val = self.adapter.coerce(val)
        self.stack.append(val)

    def visit_BUILD_TUPLE(self, lo, hi):
        # The parentheses matter: "lo + hi << 8" would be (lo + hi) << 8.
        count = lo + (hi << 8)
        # Items come off the stack last-first; restore source order.
        items = [self.stack.pop() for i in range(count)]
        items.reverse()
        self.stack.append("(" + ", ".join(items) + ")")

    def visit_BUILD_LIST(self, lo, hi):
        self.visit_BUILD_TUPLE(lo, hi)

    def visit_CALL_FUNCTION(self, lo, hi):
        # lo counts positional arguments; hi counts keyword (name, value)
        # pairs pushed after them.
        kwargs = {}
        for i in range(hi):
            val = self.stack.pop()
            key = self.stack.pop()
            kwargs[key] = val
        kwargs = [k + "=" + v for k, v in kwargs.iteritems()]

        args = []
        for i in range(lo):
            args.append(self.stack.pop())
        args.reverse()
        if kwargs:
            args += kwargs

        func = self.stack.pop()

        # Functions with a known SQL translation.
        if func in self.functions:
            self.stack.append(self.functions[func](*args))
            return

        if isinstance(func, basestring):
            # x.attr.method(...) arrives as "[table].[method]" with the
            # target column on top of the stack. Builtins arrive as their
            # str() form from LOAD_CONST.
            if func.endswith("[startswith]"):
                self.stack[-1] = self.stack[-1] + " Like '" + args[0][1:-1] + "%'"
                self.imperfect = True
                return
            elif func.endswith("[endswith]"):
                self.stack[-1] = self.stack[-1] + " Like '%" + args[0][1:-1] + "'"
                self.imperfect = True
                return
            elif func == '<built-in function len>':
                self.stack.append("Len(" + args[0] + ")")
                return

        # Anything else (unknown functions, and string methods/builtins
        # without a translation, e.g. x.Name.strip()) cannot be expressed
        # in SQL. Mark it None and let expr.evaluate(unit) decide. Dropping
        # such a call silently would compare the untransformed column.
        if self.stack:
            self.stack[-1] = None
        else:
            self.stack = [None]
        self.imperfect = True

    def visit_COMPARE_OP(self, lo, hi):
        op2, op1 = self.stack.pop(), self.stack.pop()
        op = self.sql_cmp_op[lo + (hi << 8)]
        if op1 is None or op2 is None:
            # An operand could not be translated (see visit_CALL_FUNCTION);
            # leave the whole comparison to expr.evaluate(unit).
            self.stack.append(None)
            self.imperfect = True
        elif op == 'in':
            self.stack.append(icontainedby(op1, op2))
            self.imperfect = True
        elif op == 'not in':
            self.stack.append(icontainedby(op1, op2, True))
            self.imperfect = True
        elif op in ('=', '<>', 'is', 'is not') and 'Null' in (op1, op2):
            # SQL never satisfies "= Null" or "<> Null"; use Is [Not] Null.
            if op2 == 'Null':
                operand = op1
            else:
                operand = op2
            if op in ('=', 'is'):
                self.stack.append(operand + " Is Null")
            else:
                self.stack.append(operand + " Is Not Null")
        elif op in ('is', 'is not'):
            # Object identity against a non-None value has no SQL form.
            self.stack.append(None)
            self.imperfect = True
        else:
            if op2.startswith("'") and op2.endswith("'"):
                # All ADO comparison operators for strings are case-insensitive
                # by default. Rather than determine column-by-column which
                # might be case-sensitive, just flag them all as imperfect.
                self.imperfect = True
            self.stack.append(op1 + " " + op + " " + op2)

    def binary_op(self, op):
        op2, op1 = self.stack.pop(), self.stack.pop()
        self.stack.append(op1 + " " + op + " " + op2)

    def visit_BINARY_SUBSCR(self):
        """The only BINARY_SUBSCR used in Expressions should be kwargs[key]."""
        name = self.stack.pop()
        # name was produced by LOAD_CONST through the adapter. With
        # AdapterToADOSQL that means it is quoted, with ' doubled and % / _
        # bracketed, e.g. 'start[_]date'. Undo all of that (in reverse
        # order) to recover the real key. Otherwise any key containing an
        # underscore raises KeyError.
        key = name[1:-1].replace("[_]", "_").replace("[%]", "%")
        key = key.replace("''", "'")
        value = self.expr.kwargs[key]
        value = self.adapter.coerce(value)
        self.stack.append(value)

    def visit_UNARY_NOT(self):
        op = self.stack.pop()
        if op is None:
            # Usually as a result of has(farClassName).
            self.stack.append(None)
        else:
            self.stack.append("not (" + op + ")")
454
455
def safe_name(content):
    """Return content converted to unicode with every underscore removed."""
    pieces = unicode(content).split(u"_")
    return u"".join(pieces)
458
459
460 class StoreIteratorADO(object):
461     """Iterator for populating Units from storage."""
462    
463     def __init__(self, store, unitClass, expr):
464         self.store  = store
465         self.unitClass = unitClass
466         self.expr = expr
467         self.colIndices = {}
468         self.fieldTypes = []
469        
470         self.sql, self.imperfect = store.select(unitClass, expr)
471    
472     def field(self, name, row):
473         try:
474             col = self.colIndices[name]
475         except KeyError, x:
476             x.args += (name, self.unitClass.__name__)
477             raise x
478         else:
479             return (self.fieldTypes[col], self.data[col][row])
480    
481     def load_data(self):
482         try:
483             anRS = self.store.recordset(self.sql, adOpenForwardOnly,
484                                         adLockReadOnly)
485         except pywintypes.com_error, x:
486             x.args += (self.sql, )
487             raise x
488        
489         for col, x in enumerate(anRS.Fields):
490             self.colIndices[x.Name] = col
491             self.fieldTypes.append(x.Type)
492        
493         self.data = []
494         if not(anRS.BOF and anRS.EOF):
495 ##            anRS.MoveFirst()
496 ##            if not(anRS.BOF or anRS.EOF):
497             # We tried .MoveNext() and lots of Fields.Item() calls.
498             # Using GetRows() beats that time by about 2/3.
499             self.data = anRS.GetRows()
500         anRS.Close()
501    
502     def units(self):
503         self.load_data()
504         if len(self.data) > 0:
505             for row in range(len(self.data[0])):
506                 unit = self.unitClass()
507                 coercer = AdapterFromADO(unit)
508                 for key in unit.__class__.properties():
509                     value = self.field(key, row)
510                     coercer.consume(key, value)
511                 # If our SQL is imperfect, don't yield it to the
512                 # caller unless it passes evaluate().
513                 if (not self.imperfect) or self.expr.evaluate(unit):
514                     yield unit
515
516
class CollectionLoaderADO(StoreIteratorADO):
    """Iterable Factory for populating UnitCollections from storage.

    Besides the collection's own row, each UnitCollection's member IDs are
    read from its ID table, [<prefix>__<safe_name(ID)>], and added with
    unit.add().
    """

    def units(self):
        self.load_data()
        if not self.data:
            return
        coll_coercer = AdapterFromADO().coerce
        props = self.unitClass.properties()
        for row in range(len(self.data[0])):
            unit = self.unitClass()
            coercer = AdapterFromADO(unit)
            for key in props:
                coercer.consume(key, self.field(key, row))
            # If our SQL is imperfect, don't yield it to the
            # caller unless it passes evaluate().
            if self.imperfect and not self.expr.evaluate(unit):
                continue
            self._load_ids(unit, coll_coercer)
            yield unit

    def _load_ids(self, unit, coll_coercer):
        """Add the IDs stored in unit's ID table to unit, if that table exists."""
        rsource = (u"SELECT ID FROM [%s__%s]" %
                   (self.store.prefix, safe_name(unit.ID)))
        try:
            dataRS = self.store.recordset(rsource)
        except pywintypes.com_error:
            # This usually occurs because the UnitCollection was
            # reserved but no table yet made for IDs. This is OK:
            # no IDs are added.
            return
        try:
            idtype = self.store.arena.class_by_name(unit.Type).ID.type
            while not dataRS.EOF:
                idField = dataRS.Fields.Item(u'ID')
                # Pass the field's real type and value (not the Field object
                # with a dummy type), as a normal column read would.
                unit.add(coll_coercer((idField.Type, idField.Value), idtype))
                dataRS.MoveNext()
        finally:
            dataRS.Close()
552
553
554 class StoreMultiIteratorADO(StoreIteratorADO):
555     """Iterator for populating Units (from multiple classes) from storage."""
556    
557     def __init__(self, store, unitClass, expr, pairs):
558         self.store  = store
559         self.unitClass = unitClass
560         self.expr = expr
561         self.pairs = pairs
562         self.fieldTypes = []
563        
564         sel = store.multiselect(unitClass, expr, pairs)
565         self.sql, self.imperfect, self.columns = sel
566    
567     def populate_unit(self, unit, row):
568         """Populate a Unit from a database row."""
569         coercer = AdapterFromADO(unit)
570         cls = unit.__class__
571         for key in cls.properties():
572             try:
573                 col = self.columns.index((cls, key))
574             except ValueError, x:
575                 x.args += (cls, key)
576                 raise x
577             else:
578                 coercer.consume(key, (self.fieldTypes[col],
579                                       self.data[col][row]))
580    
581     def load_data(self):
582         try:
583             anRS = self.store.recordset(self.sql, adOpenForwardOnly,
584                                         adLockReadOnly)
585         except pywintypes.com_error, x:
586             x.args += (self.sql, )
587             raise x
588        
589         for col, x in enumerate(anRS.Fields):
590             self.fieldTypes.append(x.Type)
591        
592         self.data = []
593         if not(anRS.BOF and anRS.EOF):
594 ##            anRS.MoveFirst()
595 ##            if not(anRS.BOF or anRS.EOF):
596             # We tried .MoveNext() and lots of Fields.Item() calls.
597             # Using GetRows() beats that time by about 2/3.
598             self.data = anRS.GetRows()
599         anRS.Close()
600    
601     def units(self):
602         self.load_data()
603         if len(self.data) > 0:
604             for row in range(len(self.data[0])):
605                 unit = self.unitClass()
606                 self.populate_unit(unit, row)
607                 # If our SQL is imperfect, don't yield it to the
608                 # caller unless it passes evaluate().
609                 if (not self.imperfect) or self.expr.evaluate(unit):
610                     cls, expr = self.pairs[0]
611                     farUnit = cls()
612                     self.populate_unit(farUnit, row)
613                     if farUnit.ID is None:
614                         yield unit, None
615                     elif ((not self.imperfect) or expr is None
616                           or expr.evaluate(farUnit)):
617                         yield unit, farUnit
618
619
620 class StorageManagerADO(storage.StorageManager):
621     """StoreManager to save and retrieve Units via ADO 2.7.
622     
623     You must run makepy on ADO 2.7 before installing.
624     """
625    
626     decompiler = ADOSQLDecompiler
627     monopoly = False
628     threaded = False
629    
630     def __init__(self, name, arena, allOptions={}):
631         pythoncom.CoInitialize()
632        
633         self.name = name
634         self.arena = arena
635        
636         self.connstring = allOptions[u'Connect']
637         if allOptions.get(u'Threaded', ''):
638             self.threaded = True
639             self._connection = None
640         else:
641             self._connection = win32com.client.Dispatch(r'ADODB.Connection')
642             self._connection.Open(self.connstring)
643        
644         self.prefix = allOptions.get(u'Prefix', u"djv")
645         self.cursorType = int(allOptions.get(u'CursorType', adOpenDynamic))
646         self.lockType = int(allOptions.get(u'LockType', adLockOptimistic))
647         if allOptions.get(u'Monopoly', ''):
648             self.monopoly = True
649        
650         self.reserve_lock = threading.Lock()
651    
652     def __del__(self):
653         if self._connection is not None:
654             self._connection.Close()
655    
656     def connection(self):
657         if self.threaded:
658             t = threading.currentThread()
659             if not hasattr(t, 'SMADOconn'):
660                 t.SMADOconn = win32com.client.Dispatch(r'ADODB.Connection')
661             if t.SMADOconn.State == adStateClosed:
662                 t.SMADOconn.Open(self.connstring)
663             return t.SMADOconn
664         else:
665             return self._connection
666    
667     def recordset(self, aQuery, cursorType=None, lockType=None):
668         anRS = win32com.client.Dispatch(r'ADODB.Recordset')
669 ##        anRS.Cursorlocation = 3     # adUseClient; Use to obtain .Recordcount
670         if cursorType is None:
671             cursorType = self.cursorType
672         if lockType is None:
673             lockType = self.lockType
674        
675         try:
676             anRS.Open(aQuery, self.connection(), cursorType, lockType)
677         except pywintypes.com_error, x:
678             x.args += (aQuery, )
679             raise x
680         return anRS
681    
682     def _join(self, path=[]):
683         if not path: return u''
684         firstcls = path.pop(0)
685         if not path: return firstcls.__name__
686        
687         spath = self.arena.associations.shortest_path(firstcls, path[0])
688         spath.pop(0)
689         cls = spath[0]
690         leftkey, rightkey = firstcls._associations[cls]
691         params = {u'prefix': u'djv',
692                   u'left': firstcls.__name__,
693                   u'right': cls.__name__,
694                   u'leftkey': leftkey,
695                   u'rightkey': rightkey,
696                   }
697         if len(spath) == 1:
698             params[u'child'] = u"[%(prefix)s%(right)s]" % params
699         else:
700             params[u'child'] = u"(%s)" % self._join(spath)
701        
702         return (u"[%(prefix)s%(left)s] LEFT JOIN %(child)s"
703                 u" ON [%(prefix)s%(left)s].[%(leftkey)s] = "
704                 u"[%(prefix)s%(right)s].[%(rightkey)s]" % params)
705    
706     def multiselect(self, firstcls, firstexpr, pairs):
707         firstwhere, imp = self.where(firstcls, firstexpr)
708         cols = [(firstcls, k) for k in firstcls.properties()]
709        
710         # TODO: concat multiple pairs.
711         assert len(pairs) == 1
712         for cls, expr in pairs:
713             if expr is None:
714                 expr = logic.Expression(lambda x: True)
715             j = self._join([firstcls, cls])
716            
717             w, new_imp = self.where(cls, expr)
718             imp |= new_imp
719             if w and w != "True":
720                 w = " WHERE %s AND %s" % (w, firstwhere)
721             else:
722                 w = " WHERE %s" % firstwhere
723            
724             cols += [(cls, k) for k in cls.properties()]
725             colnames = ["[%s%s].[%s]" % (self.prefix, colcls.__name__, k)
726                         for colcls, k in cols]
727            
728             statement = "SELECT %s FROM %s%s" % (u', '.join(colnames), j, w)
729            
730             return statement, imp, cols
731    
732     def select(self, unitClass, expr, distinct_fields=None):
733         tablename = self.prefix + safe_name(unitClass.__name__)
734         if distinct_fields:
735             distinct_fields = [u'[%s]' % x for x in distinct_fields]
736             sql = (u"SELECT DISTINCT %s FROM [%s]" %
737                    (u', '.join(distinct_fields), tablename))
738         else:
739             sql = u"SELECT * FROM [%s]" % tablename
740         w, i = self.where(unitClass, expr)
741         if len(w) > 0:
742             w = u" WHERE " + w
743         else:
744             w = u""
745         sql += w
746         return sql, i
747    
748     def where(self, cls, expr):
749         return self.decompiler(self, cls, expr).code()
750    
751     def execute(self, aQuery):
752         self.connection().Execute(aQuery)
753    
754     def recall(self, cls, expr=None, pairs=None):
755         if expr is None:
756             expr = logic.Expression(lambda x: True)
757         if pairs is not None:
758             return StoreMultiIteratorADO(self, cls, expr, pairs).units()
759         else:
760             if cls.__name__ == u'UnitCollection':
761                 aLoader = CollectionLoaderADO
762             else:
763                 aLoader = StoreIteratorADO
764             return aLoader(self, cls, expr).units()
765    
766     def reserve(self, unit):
767         """reserve(unit). -> Reserve a persistent slot for unit."""
768         self.reserve_lock.acquire()
769         try:
770             if unit.ID is None:
771                 data = []
772                 clsname = unit.__class__.__name__
773                 anRS = self.recordset(u"SELECT ID FROM [%s%s];" %
774                                       (self.prefix, safe_name(clsname)))
775                 if not (anRS.BOF and anRS.EOF):
776 ##                    anRS.MoveFirst()
777 ##                    if not (anRS.BOF or anRS.EOF):
778                     data = anRS.GetRows()[0]
779                 unit.ID = unit.sequencer.next(data)
780                
781                 anRS.AddNew()
782                 anRS.Fields(u'ID').Value = unit.ID
783                 anRS.Update()
784                 anRS.Close()
785         finally:
786             self.reserve_lock.release()
787    
788     def save(self, unit, forceSave=False):
789         """save(unit, forceSave=False). -> Update storage from unit's data.
790         
791         Notice in particular that we do not use the auto-number or
792         sequence generation capabilities within some databases, etc.
793         The ID should be supplied by UnitSequencers via reserve().
794         """
795         if unit.dirty or forceSave:
796             cls = unit.__class__
797             # Use a cursor always--makes mixed-quotes, newline, etc easier.
798             anRS = self.recordset("SELECT * FROM [%s%s] WHERE ID = %s" %
799                                   (self.prefix, safe_name(cls.__name__),
800                                    AdapterToADOSQL().coerce(unit.ID)))
801             if anRS.EOF and anRS.BOF:
802                 anRS.AddNew()
803                 anRS.Fields(u'ID').Value = unit.ID
804             fmt = AdapterToADOFields()
805             for key in cls.properties():
806                 eachType = cls.property_type(key)
807                 newValue = fmt.coerce(getattr(unit, key), eachType)
808                 try:
809                     anRS.Fields(key).Value = newValue
810                 except pywintypes.com_error, x:
811                     try:
812                         anRS.Close()
813                     except:
814                         pass
815                     x.args += (cls.__name__, key, eachType, newValue)
816                     raise x
817             anRS.Update()
818             # Explicitly close here, or save_collection fails on BeginTrans.
819             anRS.Close()
820             if cls.__name__ == u'UnitCollection':
821                 self.save_collection(unit)
822             unit.dirty = False
823    
824     def save_collection(self, unitColl):
825         """Update the database from the UnitCollection's data."""
826         conn = self.connection()
827 ##        # Dropped the begintrans; we were running into limits in MS Access.
828 ##        conn.BeginTrans()
829         deleteStatement = (u"DROP TABLE [%s__%s];" %
830                            (self.prefix, safe_name(unitColl.ID)))
831         conn.Execute(deleteStatement)
832        
833         cls = unitColl.unit_class()
834         fieldtype = self.createCoercions[cls.ID.type](cls, 'ID')
835         createStatement = (u"CREATE TABLE [%s__%s] (ID %s);"
836                             % (self.prefix, safe_name(unitColl.ID),
837                                fieldtype))
838         conn.Execute(createStatement)
839        
840         ins = u"INSERT INTO [%s__%s] (ID) VALUES (%s);"
841         coercer = AdapterToADOSQL().coerce
842         for eachID in unitColl.ids():
843             # Create a row for the unit.
844             # Use an INSERT command (not a cursor) for better performance.
845             # TODO: cluster inserts.
846             insStatement = ins % (self.prefix, safe_name(unitColl.ID),
847                                   coercer(eachID))
848             try:
849                 conn.Execute(insStatement)
850             except pywintypes.com_error, x:
851                 x.args += (insStatement, createStatement)
852                 raise x
853 ##        conn.CommitTrans()
854    
855     def destroy(self, unit):
856         """Delete the unit."""
857         # Use a DELETE command instead of a cursor for better performance.
858         deleteStatement = (u"DELETE * FROM [%s%s] WHERE ID = %s" %
859                            (self.prefix, safe_name(unit.__class__.__name__),
860                             AdapterToADOSQL().coerce(unit.ID)))
861         try:
862             self.execute(deleteStatement)
863         except pywintypes.com_error, x:
864             x.args += (deleteStatement, )
865             raise x
866    
867     def _create_str_storage(unitClass, key):
868         """This basic string handler does not know anything about the size
869         limitations of the particular database. You should use one of the
870         subclasses for your particular database if you need storage for
871         strings over 255 characters."""
872         # 'self' is missing from the func sig ON PURPOSE.
873         try:
874             prop = getattr(unitClass, key)
875             size = prop.hints[u'Size']
876             return u"VARCHAR(%s)" % size
877         except KeyError:
878             return u"VARCHAR(255)"
879    
880     createCoercions = {datetime.datetime: lambda x, y: u"TIMESTAMP",
881                        datetime.date: lambda x, y: u"DATE",
882                        datetime.time: lambda x, y: u"TIME",
883                        str: _create_str_storage,
884                        unicode: _create_str_storage,
885                        dict: _create_str_storage,
886                        list: _create_str_storage,
887                        fixedpoint.FixedPoint: lambda x, y: u"FLOAT",
888                        int: lambda x, y: u"INTEGER",
889                        bool: lambda x, y: u"BIT",
890                        float: lambda x, y: u"FLOAT",
891                        }
892    
893     def create_storage(self, unitClass):
894         fields = []
895         for key in unitClass.properties():
896             eachType = unitClass.property_type(key)
897             aType = self.createCoercions[eachType](unitClass, key)
898             fields.append(u"[%s] %s" % (key, aType))
899         indices = [x + " ASC" for x in unitClass.indices()]
900        
901         createStatement = (u"CREATE TABLE [%s%s] (%s)" %
902                            (self.prefix,
903                             safe_name(unitClass.__name__),
904                             ", ".join(fields)))
905         try:
906             self.execute(createStatement)
907         except Exception, x:
908             x.args += (createStatement, )
909             raise x
910        
911         for index in indices:
912             indexStatement = (u"CREATE INDEX [%si%s%s] ON [%s%s] (%s)"
913                               % (self.prefix, safe_name(unitClass.__name__),
914                                  safe_name(index),
915                                  self.prefix, safe_name(unitClass.__name__),
916                                  index))
917             try:
918                 self.execute(indexStatement)
919             except Exception, x:
920                 x.args += (indexStatement, )
921                 raise x
922        
923         return True
924    
925     def distinct(self, cls, fields, expr=None):
926         """Return distinct values for specified fields."""
927         if expr is None:
928             expr = logic.Expression(lambda x: True)
929        
930         # ^%$#@! There's no way to handle imperfect queries without
931         # creating all involved Units, which defeats the purpose of
932         # distinct, which was a speed issue more than anything. Grr.
933         sql, imperfect = self.select(cls, expr, fields)
934         # Ignore for now.
935 ##        if imperfect:
936 ##            raise ValueError(u"The following query cannot be reliably "
937 ##                             u"returned from an ADO data source.",
938 ##                             u"distinct()", cls, fields, expr)
939        
940         try:
941             anRS = self.recordset(sql, adOpenForwardOnly, adLockReadOnly)
942         except pywintypes.com_error, x:
943             x.args += (self.sql, )
944             raise x
945        
946         fieldTypes = [x.Type for x in anRS.Fields]
947         data = []
948         if not (anRS.BOF and anRS.EOF):
949             # We tried .MoveNext() and lots of Fields.Item() calls.
950             # Using GetRows() beats that time by about 2/3.
951             data = anRS.GetRows()
952         anRS.Close()
953        
954         if data:
955             coerced_data = []
956             coerce = AdapterFromADO().coerce
957             for col, field in enumerate(fields):
958                 expectedType = cls.property_type(field)
959                 actualType = fieldTypes[col]
960                 coerced_row = [coerce((actualType, val), expectedType)
961                                for val in data[col]]
962                 coerced_data.append(coerced_row)
963             data = zip(*coerced_data)
964         return data
965
966
class StorageManagerADO_SQLServer(StorageManagerADO):
    """StorageManagerADO with column types suited to Microsoft SQL Server."""

    def _create_str_storage(unitClass, key):
        """Return the SQL column type for a string-like property.

        VARCHAR(255) when the property has no 'Size' hint; NTEXT when the
        hint is 0 or exceeds 8000 (the largest n SQL Server accepts in
        VARCHAR(n)); otherwise VARCHAR(<size>).
        """
        # 'self' is missing from the func sig ON PURPOSE: this function is
        # stored in createCoercions and called as f(unitClass, key).
        try:
            prop = getattr(unitClass, key)
            size = prop.hints[u'Size']
        except KeyError:
            return u"VARCHAR(255)"
        else:
            if size == 0 or size > 8000:
                # VARCHAR(n) accepts n from 1 through 8000, so larger (or
                # unlimited, size 0) strings go to NTEXT. Separately, a
                # whole row is limited to 8060 bytes by SQL Server's page
                # size; if further fields are defined for the unitClass, a
                # record could still exceed that. Perhaps someday we can
                # calc the total requested record size and adjust. For now,
                # we trust that units generally use a size of 0 to bump up
                # to NTEXT.
                # NOTE(review): VARCHAR is non-Unicode while NTEXT is
                # Unicode; confirm that mix is intended.
                return u"NTEXT"
            return u"VARCHAR(%s)" % size

    createCoercions = {
        # In SQL Server, TIMESTAMP is a synonym for ROWVERSION (an
        # automatic binary row stamp), not a date/time type.
        datetime.datetime: lambda x, y: u"DATETIME",
        # NOTE(review): DATE and TIME exist only in SQL Server 2008 and
        # later -- confirm the target server version.
        datetime.date: lambda x, y: u"DATE",
        datetime.time: lambda x, y: u"TIME",
        str: _create_str_storage,
        unicode: _create_str_storage,
        dict: lambda x, y: u"NTEXT",
        list: lambda x, y: u"NTEXT",
        fixedpoint.FixedPoint: lambda x, y: u"FLOAT",
        float: lambda x, y: u"FLOAT",
        int: lambda x, y: u"INTEGER",
        bool: lambda x, y: u"BIT",
        }
998
999
1000 ###########################################################################
1001 ##                                                                       ##
1002 ##                             MS Access                                 ##
1003 ##                                                                       ##
1004 ###########################################################################
1005
1006
class ADOSQLDecompiler_MSAccess(ADOSQLDecompiler):
    """ADOSQLDecompiler variant that emits MS Access (Jet) SQL fragments."""

    # SQL spellings of comparison operators; Access writes inequality
    # as '<>'.
    sql_cmp_op = ('<', '<=', '=', '<>', '>', '>=', 'in', 'not in')

    # dejavu function -> builder of the equivalent SQL text. For the
    # two-argument builders, x is the left-hand SQL and y is presumably an
    # already-quoted SQL string literal (y[1:-1] strips its outer quotes).
    functions = {
        # Pattern matching with LIKE and '%' wildcards. NOTE(review):
        # relies on Access text comparison being case-insensitive by
        # default.
        dejavu.istartswith: lambda x, y: x + " Like '" + y[1:-1] + "%'",
        dejavu.iendswith: lambda x, y: x + " Like '%" + y[1:-1] + "'",
        dejavu.icontains: lambda x, y: x + " Like '%" + y[1:-1] + "%'",
        dejavu.icontainedby: icontainedby,
        dejavu.ieq: lambda x, y: x + " = " + y,
        # Date helpers.
        dejavu.now: lambda: "Now()",
        dejavu.today: lambda: "DateValue(Now())",
        dejavu.year: lambda x: "Year(" + x + ")",
        }
1018
1019
class StorageManagerADO_MSAccess(StorageManagerADO):
    """StorageManagerADO with column types suited to MS Access (Jet)."""

    decompiler = ADOSQLDecompiler_MSAccess

    def _create_str_storage(unitClass, key):
        """Return the SQL column type for a string-like property.

        VARCHAR(255) when the property has no 'Size' hint; MEMO when the
        hint is 0 or exceeds 255; otherwise VARCHAR(<size>).
        """
        # 'self' is missing from the func sig ON PURPOSE: this function is
        # stored in createCoercions and called as f(unitClass, key).
        try:
            size = getattr(unitClass, key).hints[u'Size']
        except KeyError:
            return u"VARCHAR(255)"
        if size == 0 or size > 255:
            # 255 is the max size of an Access 'Text' field, so longer (or
            # unlimited, size 0) strings become MEMO. With further fields
            # defined for the unitClass we could still exceed the max
            # record size; perhaps someday we can calc the total requested
            # record size and adjust. For now, we trust that units
            # generally use a size of 0 to bump up to MEMO.
            return u"MEMO"
        return u"VARCHAR(%s)" % size

    # Python property type -> callable(unitClass, key) giving the column
    # type.
    createCoercions = {
        # Date and time values.
        datetime.datetime: lambda x, y: u"TIMESTAMP",
        datetime.date: lambda x, y: u"DATE",
        datetime.time: lambda x, y: u"TIME",
        # Numbers; FixedPoint is stored as FLOAT.
        int: lambda x, y: u"INTEGER",
        float: lambda x, y: u"FLOAT",
        fixedpoint.FixedPoint: lambda x, y: u"FLOAT",
        bool: lambda x, y: u"BIT",
        # Text: sized strings, and MEMO for dict/list values.
        str: _create_str_storage,
        unicode: _create_str_storage,
        dict: lambda x, y: u"MEMO",
        list: lambda x, y: u"MEMO",
        }
1054
if __name__ == '__main__':
    # Run as a script: generate (or verify) the win32com makepy support
    # module for the ADO 2.7 type library, version 2.7, LCID 0.
    print('Please wait while support for ADO 2.7 is verified...')
    ado27_typelib = '{EF53050B-882E-4776-B643-EDA472E8E3F2}'
    win32com.client.gencache.EnsureModule(ado27_typelib, 0, 2, 7)
1060
Note: See TracBrowser for help on using the browser.