Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/db.py

Revision 69 (checked in by fumanchu, 8 years ago)

1. Fixed bug in db.SQLDecompiler.visit_CALL_FUNCTION, affecting Expressions with kwargs inside function calls.
2. db stack sentinels now have reprs to help debugging.
3. Changed Adapter.pickle to .do_pickle to help avoid shadowing the pickle module.
4. Put a connection timeout in sockets.SocketClient.query.

Line 
1 """Base classes and tools for writing database Storage Managers.
2
3 Unit type -> [SQL repr ->] DB -> incoming Python value -> Unit type
4
5
6 DATA TYPES
7 ==========
8 Since Dejavu relies on external database servers for its persistence,
9 Python datatypes must be converted to column types in the DB. When writing
10 a StorageManager, you should make sure that your type conversions can handle
11 at least the following limitations. If possible, implement the type with no
12 limits. Also, follow UnitProperty.hints['bytes'] where possible. A value
13 of zero for hints['bytes'] implies no limit. If no value is given, try to
14 assume no limit, although you may choose whatever default size you wish
15 (255 is common for strings).
16 """
17
18 import time
19 import weakref
20 import Queue
21 from types import FunctionType
22 import dejavu
23 from dejavu import codewalk, logic, storage
24
25 try:
26     import cPickle as pickle
27 except ImportError:
28     import pickle
29
30 import datetime
31
32 try:
33     import fixedpoint
34 except ImportError:
35     pass
36
37 try:
38     # Builtin in Python 2.5?
39     decimal
40 except NameError:
41     try:
42         # Module in Python 2.3, 2.4
43         import decimal
44     except ImportError:
45         pass
46
47 import warnings
48 import threading
49
50
class FieldTypeAdapter(object):
    """For a Python type, return the SQL column type name.

    coerce(cls, key) looks up the declared type of property `key` on Unit
    class `cls` and dispatches to a coerce_<type> method. That method reads
    the property's hints ('bytes', 'precision', 'scale') and returns the
    column type to use in CREATE TABLE.

    This base class is designed to work out-of-the-box with PostgreSQL 8.
    """

    # 1000 is the max precision for NUMERIC columns for PostgreSQL 8.
    # Override in subclasses.
    numeric_max_precision = 1000

    def coerce(self, cls, key):
        """coerce(cls, key) -> SQL typename for valuetype.

        Builtin types dispatch to coerce_<name>. Other types dispatch to
        coerce_<module>_<name>, with dots replaced by underscores
        (e.g. coerce_datetime_date).

        Raises TypeError if this adapter has no method for the type.
        """
        valuetype = cls.property_type(key)
        mod = valuetype.__module__
        # Builtins report "__builtin__" on Python 2 and "builtins" on Python 3.
        if mod in ("__builtin__", "builtins"):
            xform = "coerce_%s" % valuetype.__name__
        else:
            xform = "coerce_%s_%s" % (mod, valuetype.__name__)
        xform = xform.replace(".", "_")
        try:
            xform = getattr(self, xform)
        except AttributeError:
            raise TypeError("'%s' is not handled by %s." %
                            (valuetype, self.__class__))
        return xform(cls, key)

    def coerce_float(self, cls, key):
        """REAL for a 'bytes' hint below 5, else DOUBLE PRECISION (default 8)."""
        size = int(getattr(cls, key).hints.get(u'bytes', '8'))
        if size < 5:
            # In MySQL, REAL is still probably 8 bytes. Meh.
            return u"REAL"
        # Python floats are implemented using C doubles; actual precision
        # depends on platform. PostgreSQL DOUBLE is 8 bytes (15
        # decimal-digit precision).
        return u"DOUBLE PRECISION"

    def coerce_str(self, cls, key):
        """VARCHAR(n) for a nonzero 'bytes' hint n, else TEXT (no limit)."""
        # The bytes hint shall not reflect the usual 4-byte base for varchar.
        size = int(getattr(cls, key).hints.get(u'bytes', '0'))
        if size:
            return u"VARCHAR(%s)" % size
        # TEXT is not an SQL standard, but it's common.
        return u"TEXT"

    # dict/list/tuple values are pickled to strings by AdapterToSQL, and
    # unicode is text, so all of these use string columns.
    def coerce_dict(self, cls, key):
        return self.coerce_str(cls, key)
    def coerce_list(self, cls, key):
        return self.coerce_str(cls, key)
    def coerce_tuple(self, cls, key):
        return self.coerce_str(cls, key)
    def coerce_unicode(self, cls, key):
        return self.coerce_str(cls, key)

    def coerce_bool(self, cls, key): return u"BOOLEAN"

    def coerce_datetime_datetime(self, cls, key): return u"TIMESTAMP"
    def coerce_datetime_date(self, cls, key): return u"DATE"
    def coerce_datetime_time(self, cls, key): return u"TIME"

    # I was seriously disinterested in writing a parser for interval.
    # timedeltas are stored as a float number of days instead.
    def coerce_datetime_timedelta(self, cls, key):
        return self.coerce_float(cls, key)

    def coerce_decimal_Decimal(self, cls, key):
        """NUMERIC(precision, scale) for decimal.Decimal properties.

        'precision' defaults to the current decimal context's precision and
        is capped at numeric_max_precision, with a warning. 'scale'
        defaults to 2.
        """
        prop = getattr(cls, key)
        precision = int(prop.hints.get('precision', '0'))
        if precision == 0:
            precision = decimal.getcontext().prec
        if precision > self.numeric_max_precision:
            warnings.warn("Decimal precision %s > maximum %s for %s.%s, "
                          "using %s. Values may be stored incorrectly."
                          % (precision, self.numeric_max_precision,
                             cls.__name__, key, self.__class__.__name__))
            precision = self.numeric_max_precision
        # Assume most people use decimal for money; default scale = 2.
        scale = int(prop.hints.get(u'scale', 2))
        return u"NUMERIC(%s, %s)" % (precision, scale)

    def coerce_decimal(self, cls, key):
        # If decimal ever becomes a builtin. Python 2.5?
        return self.coerce_decimal_Decimal(cls, key)

    def coerce_fixedpoint_FixedPoint(self, cls, key):
        """NUMERIC(precision, scale); precision defaults to the maximum."""
        prop = getattr(cls, key)
        precision = int(prop.hints.get('precision', '0'))
        if precision == 0:
            precision = self.numeric_max_precision
        # Assume most people use decimal for money; default scale = 2.
        scale = int(prop.hints.get(u'scale', 2))
        return u"NUMERIC(%s, %s)" % (precision, scale)

    def coerce_long(self, cls, key):
        """Pick the smallest integer column for the 'bytes' hint.

        Up to 4 bytes (including the default of 0) delegates to coerce_int.
        Up to 8 bytes gives BIGINT. Larger values give NUMERIC(p, 0), where
        p is the number of decimal digits needed for an integer of that
        many bytes. p is capped at numeric_max_precision, with a warning.
        """
        import math
        size = int(getattr(cls, key).hints.get(u'bytes', 0))
        if size <= 4:
            return self.coerce_int(cls, key)
        elif size <= 8:
            # BIGINT is usually 8 bytes
            return "BIGINT"
        # Anything larger than 8 bytes, use decimal/numeric. NUMERIC
        # precision counts decimal digits, not bytes: each byte needs
        # log10(256) ~= 2.41 digits (9 bytes -> 22 digits).
        precision = int(math.ceil(size * math.log10(256)))
        if precision > self.numeric_max_precision:
            warnings.warn("Long precision %s (%s bytes) > maximum %s for "
                          "%s.%s, using %s. Values may be stored incorrectly."
                          % (precision, size, self.numeric_max_precision,
                             cls.__name__, key, self.__class__.__name__))
            precision = self.numeric_max_precision
        return "NUMERIC(%s, 0)" % precision

    def coerce_int(self, cls, key):
        """SMALLINT for a 'bytes' hint of 1 or 2, else INTEGER (default 4)."""
        size = int(getattr(cls, key).hints.get(u'bytes', '4'))
        if size in (1, 2):
            # PostgreSQL has no 1-byte integer type, and a BOOLEAN column
            # rejects the integer literals AdapterToSQL emits for ints.
            return "SMALLINT"
        return u"INTEGER"
174
175
class AdapterToSQL(object):
    """Coerce Expression constants to SQL.

    coerce(value) dispatches on the value's type to a coerce_<type> method
    that returns the SQL literal text for that value.

    This base class is designed to work out-of-the-box with PostgreSQL 8.
    """

    # Notice these are ordered pairs, applied in sequence. Escape \ before
    # introducing new ones. (Neither pair here introduces the other's
    # pattern, so these two happen to be order-independent.)
    escapes = [("'", "''"), ("\\", r"\\")]
    like_escapes = [("%", r"\%"), ("_", r"\_")]

    # These are not the same as coerce_bool (which is used on one side of
    # a comparison). Instead, these are used when the whole (sub)expression
    # is True or False, e.g. "WHERE TRUE", or "WHERE TRUE and 'a'.'b' = 3".
    bool_true = "TRUE"
    bool_false = "FALSE"

    def escape_like(self, value):
        """Prepare a string value for use in a LIKE comparison.

        value is normally an SQL literal already produced by coerce_str, so
        one enclosing pair of quote marks is removed before the LIKE
        wildcards are escaped. Only a single matching pair is removed:
        stripping every leading/trailing quote would also eat escaped
        quotes ('') that belong to the value itself.
        """
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
            value = value[1:-1]
        for pat, repl in self.like_escapes:
            value = value.replace(pat, repl)
        return value

    def coerce(self, value, valuetype=None):
        """coerce(value, valuetype=None) -> value, coerced by valuetype.

        valuetype defaults to type(value). Builtin types dispatch to
        coerce_<name>, others to coerce_<module>_<name> (dots replaced by
        underscores). Raises TypeError for unsupported types.
        """
        if valuetype is None:
            valuetype = type(value)

        mod = valuetype.__module__
        # Builtins report "__builtin__" on Python 2 and "builtins" on Python 3.
        if mod in ("__builtin__", "builtins"):
            xform = "coerce_%s" % valuetype.__name__
        else:
            xform = "coerce_%s_%s" % (mod, valuetype.__name__)
        xform = xform.replace(".", "_")
        try:
            xform = getattr(self, xform)
        except AttributeError:
            raise TypeError("'%s' is not handled by %s." %
                            (valuetype, self.__class__))
        return xform(value)

    def tostr(self, value):
        return str(value)

    def coerce_NoneType(self, value):
        return "NULL"

    def coerce_bool(self, value):
        if value:
            return 'TRUE'
        return 'FALSE'

    # The great thing about these 3 date coercers is that you can use
    # them with (VAR)CHAR columns just as well as with DATETIME, etc.
    # and comparisons will still work! (Microseconds are not emitted.)
    def coerce_datetime_datetime(self, value):
        return (u"'%04d-%02d-%02d %02d:%02d:%02d'" %
                (value.year, value.month, value.day,
                 value.hour, value.minute, value.second))

    def coerce_datetime_date(self, value):
        return u"'%04d-%02d-%02d'" % (value.year, value.month, value.day)

    def coerce_datetime_time(self, value):
        return u"'%02d:%02d:%02d'" % (value.hour, value.minute, value.second)

    def coerce_datetime_timedelta(self, value):
        # Stored as a float number of days; microseconds are dropped.
        float_val = value.days + (value.seconds / 86400.0)
        return repr(float_val)

    coerce_decimal = tostr
    coerce_decimal_Decimal = tostr

    def do_pickle(self, value):
        """Pickle value and quote the result as an SQL string literal."""
        return self.coerce_str(pickle.dumps(value))

    coerce_dict = do_pickle

    coerce_fixedpoint_FixedPoint = tostr
    coerce_float = tostr
    coerce_int = tostr

    coerce_list = do_pickle

    coerce_long = tostr

    def coerce_str(self, value):
        """Escape value (see escapes) and wrap it in single quotes."""
        for pat, repl in self.escapes:
            value = value.replace(pat, repl)
        return "'" + value + "'"

    coerce_tuple = do_pickle

    coerce_unicode = coerce_str
271
272
class AdapterFromDB(object):
    """Coerce incoming values from DB types to Dejavu datatypes.

    You might notice that we pass coltype around a lot, but don't
    refer to it in this base class. It's there so subclasses can
    make decisions about coercion when they don't have control over
    the types of database columns, and must make do with legacy
    database implementations.

    This base class is designed to work out-of-the-box with PostgreSQL 8.
    """

    def coerce(self, value, coltype, valuetype=None):
        """coerce(value, coltype, valuetype=None) -> value, coerced by valuetype.

        None (SQL NULL) is returned unchanged. Otherwise dispatches like
        AdapterToSQL.coerce: coerce_<name> for builtin types,
        coerce_<module>_<name> for others. Raises TypeError for
        unsupported types.
        """
        if value is None:
            return None

        if valuetype is None:
            valuetype = type(value)

        mod = valuetype.__module__
        # Builtins report "__builtin__" on Python 2 and "builtins" on Python 3.
        if mod in ("__builtin__", "builtins"):
            xform = "coerce_%s" % valuetype.__name__
        else:
            xform = "coerce_%s_%s" % (mod, valuetype.__name__)
        xform = xform.replace(".", "_")
        try:
            xform = getattr(self, xform)
        except AttributeError:
            raise TypeError("'%s' is not handled by %s." %
                            (valuetype, self.__class__))
        return xform(value, coltype)

    def consume(self, unit, key, value, coltype):
        """Coerce value to unit's declared type for key; store in unit._properties."""
        expectedType = unit.__class__.property_type(key)
        value = self.coerce(value, coltype, expectedType)
        unit._properties[key] = value

    def do_pickle(self, value, coltype):
        # NOTE: pickle.loads can execute arbitrary code embedded in the data;
        # this is only safe while the database contents are trusted.
        # Coerce to str for pickle.loads restriction.
        value = str(value)
        return pickle.loads(value)

    def coerce_bool(self, value, coltype):
        # NOTE(review): a driver that returns PostgreSQL's textual 't'/'f'
        # would make bool('f') True here -- confirm the driver returns
        # Python bools or ints.
        return bool(value)

    # The date/time parsers below expect fixed-position text
    # ('YYYY-MM-DD HH:MM:SS'); anything after the seconds is ignored.
    def coerce_datetime_datetime(self, value, coltype):
        chunks = (value[0:4], value[5:7], value[8:10],
                  value[11:13], value[14:16], value[17:19])
        return datetime.datetime(*map(int, chunks))

    def coerce_datetime_date(self, value, coltype):
        chunks = (value[0:4], value[5:7], value[8:10])
        return datetime.date(*map(int, chunks))

    def coerce_datetime_time(self, value, coltype):
        chunks = (value[0:2], value[3:5], value[6:8])
        return datetime.time(*map(int, chunks))

    def coerce_datetime_timedelta(self, value, coltype):
        """Convert a float number of days back into a timedelta (whole seconds)."""
        days, fraction = divmod(value, 1)
        # Round instead of truncating: after the float round trip the
        # fraction can sit a hair below a whole second (e.g. 0.99999999).
        return datetime.timedelta(days, int(round(fraction * 86400)))

    def coerce_decimal(self, value, coltype):
        # Only dispatched for a *builtin* type named 'decimal' (see
        # coerce()); in that case the name decimal is that builtin.
        return decimal(str(value))

    def coerce_decimal_Decimal(self, value, coltype):
        return decimal.Decimal(str(value))

    coerce_dict = do_pickle

    def coerce_fixedpoint_FixedPoint(self, value, coltype):
        return fixedpoint.FixedPoint(value)

    def coerce_float(self, value, coltype):
        return float(value)

    def coerce_int(self, value, coltype):
        return int(value)

    coerce_list = do_pickle

    def coerce_long(self, value, coltype):
        return long(value)

    def coerce_str(self, value, coltype):
        return str(value)

    coerce_tuple = do_pickle

    def coerce_unicode(self, value, coltype):
        # You should REALLY check into your DB's encoding and override this.
        if isinstance(value, unicode):
            # Some drivers already return decoded text; decoding it again
            # raises TypeError ("decoding Unicode is not supported").
            return value
        return unicode(value, "utf8")
366
367
368 # -------------------------- SQL DECOMPILATION -------------------------- #
369
class ConstWrapper(unicode):
    """Wraps a constant for use in SQLDecompiler's stack.

    The wrapper *is* the coerced SQL text (it subclasses unicode) and also
    carries the original Python value as `basevalue`. When we hit
    LOAD_CONST while decompiling, we occasionally need to keep both the
    base and the coerced value around (SQLDecompiler.containedby, reached
    from COMPARE_OP, reads ConstWrapper.basevalue).
    """

    def __new__(cls, basevalue, coerced_value):
        # Use cls (not a hard-coded ConstWrapper) so subclasses work.
        newobj = unicode.__new__(cls, coerced_value)
        newobj.basevalue = basevalue
        return newobj

    def __str__(self):
        # str(self) here would call this method again and recurse forever;
        # delegate to the plain unicode -> str conversion instead.
        return unicode.__str__(self)

    def __unicode__(self):
        return self

    def __repr__(self):
        return self
390
391
392 # Stack sentinels
class Sentinel(object):
    """A named marker object placed on SQLDecompiler's stack.

    Sentinels are always tested by identity (`is`). The name exists only
    to make repr() readable while debugging.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        label = 'Stack Sentinel: %s' % self.name
        return label


# Stands for a positional lambda argument (the table being queried);
# attribute lookups on it become column names.
table_arg = Sentinel('Table Arg')

# Stands for the lambda's keyword-argument mapping; only subscripting
# (kw['name']) is allowed on it.
kw_arg = Sentinel('Keyword Arg')

# Marks a portion of an Expression as imperfect, i.e. with no SQL
# equivalent (the function dejavu.iscurrentweek, for example, rarely has
# one). All Units matching the rest of the Expression are then recalled
# and can be re-checked with expr.evaluate(unit).
cannot_represent = Sentinel('Cannot Repr')
409
410
class SQLDecompiler(codewalk.LambdaDecompiler):
    """SQLDecompiler(tablename, expr, adapter=AdapterToSQL()).

    Produce SQL from a supplied Expression object, with a lambda of the form:
        lambda x, **kw: ...

    Attributes of x (or whatever the name of the first argument is) will be
    mapped to table columns. Keyword arguments should be bound using
    Expression.bind_args before calling this decompiler.

    The visit_* methods (presumably dispatched per opcode by
    codewalk.LambdaDecompiler.walk -- confirm) keep SQL fragments on
    self.stack. A sub-expression with no SQL equivalent becomes the
    cannot_represent sentinel and sets self.imperfect. The Units recalled
    should then be re-checked with expr.evaluate(unit).
    """

    # Some constants are function or class objects,
    # which should not be coerced.
    no_coerce = (FunctionType,
                 type,
                 type(len),       # <type 'builtin_function_or_method'>
                 )

    # Indexed by the COMPARE_OP argument (Python's cmp_op order). 'in' and
    # 'not in' (6, 7) are handled by containedby() instead.
    sql_cmp_op = ('<', '<=', '=', '!=', '>', '>=', 'in', 'not in')

    def __init__(self, tablename, expr, adapter=AdapterToSQL()):
        self.tablename = tablename
        self.expr = expr
        self.adapter = adapter
        obj = expr.func
        codewalk.LambdaDecompiler.__init__(self, obj)

    def code(self):
        """Return the SQL text for self.expr.

        Also sets self.imperfect to True when some part of the Expression
        could not be represented in SQL.
        """
        self.imperfect = False
        self.walk()
        # After walk(), self.stack should be reduced to a single string,
        # which is the SQL representation of our Expression. If the entire
        # expression could not be represented, it matches everything.
        return self._sql_bool_term(self.stack[0])

    def _sql_bool_term(self, term):
        """Map cannot_represent and bare boolean literals to bool_true/bool_false."""
        if term is cannot_represent:
            term = self.adapter.bool_true
        if term == self.adapter.coerce_bool(True):
            term = self.adapter.bool_true
        if term == self.adapter.coerce_bool(False):
            term = self.adapter.bool_false
        return term

    def visit_target(self, terms):
        """A target is an AND or OR test.

        terms is a list of (term, operation) pairs. They are popped (last
        first) and folded with the top of the stack into
        "(term) OPERATION (comp)".
        """
        comp = self.stack.pop()
        while terms:
            term, operation = terms.pop()
            # Blurg. SQL Server is *so* picky: bare boolean literals and
            # unrepresentable terms must become bool_true/bool_false.
            term = self._sql_bool_term(term)
            comp = self._sql_bool_term(comp)
            comp = "(%s) %s (%s)" % (term, operation.upper(), comp)
        self.stack.append(comp)

    def visit_LOAD_DEREF(self, lo, hi):
        raise ValueError("Illegal reference found in %s." % self.expr)

    def visit_LOAD_GLOBAL(self, lo, hi):
        raise ValueError("Illegal global found in %s." % self.expr)

    def visit_LOAD_FAST(self, lo, hi):
        # Positional arguments are the table; anything after is **kw.
        if lo + (hi << 8) < self.co_argcount:
            self.stack.append(table_arg)
        else:
            self.stack.append(kw_arg)

    def visit_LOAD_ATTR(self, lo, hi):
        name = self.co_names[lo + (hi << 8)]
        tos = self.stack.pop()
        if tos is table_arg:
            # Call another function to make subclassing easier.
            self.stack.append(self.column_name(name))
        elif tos is cannot_represent:
            # An attribute of something unrepresentable is unrepresentable.
            self.stack.append(cannot_represent)
        else:
            # tos.name will reference an attribute of the tos object.
            # Stick the tos and name in a tuple for later processing.
            self.stack.append((tos, name))

    def visit_LOAD_CONST(self, lo, hi):
        val = self.co_consts[lo + (hi << 8)]
        if not isinstance(val, self.no_coerce):
            val = ConstWrapper(val, self.adapter.coerce(val))
        self.stack.append(val)

    def visit_BUILD_TUPLE(self, lo, hi):
        """Replace the top N stack items with one "(a, b, c)" fragment.

        N is the opcode argument lo + (hi << 8), and items keep their
        source order. When every item is a constant, the result is a
        ConstWrapper whose basevalue is the tuple of original values, so
        containedby() can build an IN list from it.
        """
        count = lo + (hi << 8)
        items = []
        for i in range(count):
            items.append(self.stack.pop())
        items.reverse()

        all_consts = True
        for item in items:
            if item is cannot_represent:
                self.stack.append(cannot_represent)
                self.imperfect = True
                return
            if not isinstance(item, ConstWrapper):
                all_consts = False

        text = "(" + ", ".join(items) + ")"
        if all_consts:
            basevalues = [item.basevalue for item in items]
            text = ConstWrapper(tuple(basevalues), text)
        self.stack.append(text)

    visit_BUILD_LIST = visit_BUILD_TUPLE

    def visit_CALL_FUNCTION(self, lo, hi):
        """Translate a call through a dispatch method, or mark it unrepresentable.

        The stack holds the callable, then lo positional arguments, then hi
        (keyword name, value) pairs. A (tos, name) tuple from LOAD_ATTR
        dispatches to attr_<name>(tos, *args). Other callables dispatch to
        <module>_<name>(*args), with dots turned into underscores and a
        'func' prefix when the result starts with '_' (e.g.
        func__builtin___len).
        """
        kwargs = {}
        for i in range(hi):
            val = self.stack.pop()
            key = self.stack.pop()
            kwargs[key] = val

        args = []
        for i in range(lo):
            args.append(self.stack.pop())
        args.reverse()

        func = self.stack.pop()

        # Dispatchers build SQL by string concatenation, which cannot
        # include an unrepresentable argument.
        for arg in args + list(kwargs.values()):
            if arg is cannot_represent:
                self.stack.append(cannot_represent)
                self.imperfect = True
                return

        if kwargs:
            args += [k + "=" + v for k, v in kwargs.items()]

        if isinstance(func, tuple):
            tos, name = func
            dispatch = getattr(self, "attr_" + name, None)
            if dispatch:
                self.stack.append(dispatch(tos, *args))
                return
        else:
            # Not every callable (or stack item) has __module__/__name__.
            modname = getattr(func, "__module__", None)
            basename = getattr(func, "__name__", None)
            if modname and basename:
                funcname = (modname + "_" + basename).replace(".", "_")
                if funcname.startswith("_"):
                    funcname = "func" + funcname
                dispatch = getattr(self, funcname, None)
                if dispatch:
                    self.stack.append(dispatch(*args))
                    return

        self.stack.append(cannot_represent)
        self.imperfect = True

    def visit_COMPARE_OP(self, lo, hi):
        """Translate a comparison; NULL operands become IS [NOT] NULL."""
        op2, op1 = self.stack.pop(), self.stack.pop()
        if op1 is cannot_represent or op2 is cannot_represent:
            self.stack.append(cannot_represent)
            return

        op = lo + (hi << 8)
        if op in (6, 7):     # in, not in
            value = self.containedby(op1, op2)
            if op == 7:
                value = "NOT " + value
            self.stack.append(value)
        elif op1 == 'NULL' or op2 == 'NULL':
            if op1 == 'NULL':
                operand = op2
            else:
                operand = op1
            if op in (2, 8):      # ==, is
                self.stack.append(operand + " IS NULL")
            elif op in (3, 9):    # !=, is not
                self.stack.append(operand + " IS NOT NULL")
            else:
                raise ValueError("Non-equality Null comparisons not allowed.")
        elif op < len(self.sql_cmp_op):
            # Comparison operators for strings are case-sensitive in PG et al.
            self.stack.append(op1 + " " + self.sql_cmp_op[op] + " " + op2)
        else:
            # e.g. 'is' / 'is not' against non-NULL values: no SQL form, so
            # let the caller evaluate it in Python.
            self.stack.append(cannot_represent)
            self.imperfect = True

    def binary_op(self, op):
        op2, op1 = self.stack.pop(), self.stack.pop()
        if op1 is cannot_represent or op2 is cannot_represent:
            self.stack.append(cannot_represent)
            return
        self.stack.append(op1 + " " + op + " " + op2)

    def visit_BINARY_SUBSCR(self):
        # The only BINARY_SUBSCR used in Expressions should be kwargs[key].
        name = self.stack.pop()
        tos = self.stack.pop()
        if tos is not kw_arg:
            raise ValueError("Subscript %s of %s object not allowed."
                             % (name, tos))
        if isinstance(name, ConstWrapper):
            # Use the original key: its SQL rendering has added quotes and
            # escapes that stripping cannot reliably undo.
            key = name.basevalue
        else:
            key = name.strip("'\"")
        value = self.expr.kwargs[key]
        if not isinstance(value, self.no_coerce):
            value = ConstWrapper(value, self.adapter.coerce(value))
        self.stack.append(value)

    def visit_UNARY_NOT(self):
        op = self.stack.pop()
        if op is cannot_represent:
            self.stack.append(cannot_represent)
        else:
            self.stack.append("NOT (" + op + ")")

    def column_name(self, name):
        # This is valid SQL for PostgreSQL only and should be overridden.
        # If you want to use a map from UnitProperty names to legacy DB
        # column names, override this method. You will probably also
        # want to override StorageManager.identifier and perform the
        # same map lookup there.
        return '%s."%s"' % (self.tablename, name)

    # --------------------------- Dispatchees --------------------------- #

    def attr_startswith(self, tos, arg):
        return tos + " LIKE '" + self.adapter.escape_like(arg) + "%'"

    def attr_endswith(self, tos, arg):
        return tos + " LIKE '%" + self.adapter.escape_like(arg) + "'"

    def containedby(self, op1, op2):
        if isinstance(op1, ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            return op2 + " LIKE '%" + self.adapter.escape_like(op1) + "%'"
        # Looking for field in (a, b, c)
        atoms = [self.adapter.coerce(x) for x in op2.basevalue]
        if not atoms:
            # "x IN ()" is a syntax error; nothing is in an empty sequence.
            return self.adapter.bool_false
        return op1 + " IN (" + ", ".join(atoms) + ")"

    def dejavu_icontainedby(self, op1, op2):
        if isinstance(op1, ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            return ("LOWER(" + op2 + ") LIKE '%" +
                    self.adapter.escape_like(op1).lower() + "%'")
        # Looking for field in (a, b, c).
        # Force all args to lowercase for case-insensitive comparison.
        atoms = [self.adapter.coerce(x).lower() for x in op2.basevalue]
        if not atoms:
            return self.adapter.bool_false
        return "LOWER(%s) IN (%s)" % (op1, ", ".join(atoms))

    def dejavu_icontains(self, x, y):
        return self.dejavu_icontainedby(y, x)

    # The pattern must be lowercased too, or LOWER(x) can never match
    # an uppercase letter in it.
    def dejavu_istartswith(self, x, y):
        return ("LOWER(" + x + ") LIKE '" +
                self.adapter.escape_like(y).lower() + "%'")

    def dejavu_iendswith(self, x, y):
        return ("LOWER(" + x + ") LIKE '%" +
                self.adapter.escape_like(y).lower() + "'")

    def dejavu_ieq(self, x, y):
        return "LOWER(" + x + ") = LOWER(" + y + ")"

    def dejavu_now(self):
        return "NOW()"

    def dejavu_today(self):
        return "CURRENT_DATE"

    def dejavu_year(self, x):
        # NOTE(review): YEAR() is MySQL/SQL Server syntax; PostgreSQL uses
        # EXTRACT(YEAR FROM x). Confirm PG subclasses override this.
        return "YEAR(" + x + ")"

    def func__builtin___len(self, x):
        return "LENGTH(" + x + ")"
666
667
class ConnectionWrapper(object):
    """Proxy that forwards attribute access to a wrapped DB connection.

    StorageManagerDB.connection() keys a weakref to each wrapper in
    self.refs, with release() as the callback. The wrapped connection is
    therefore released when the wrapper is garbage collected.
    """

    def __init__(self, conn=None):
        self.conn = conn

    def __getattr__(self, attr):
        # __getattr__ only runs for attributes not found normally. If
        # 'conn' itself is missing (e.g. the instance was built without
        # __init__, as copy/pickle do), self.conn would re-enter this
        # method forever; fail with a clean AttributeError instead.
        if attr == 'conn':
            raise AttributeError(attr)
        return getattr(self.conn, attr)
674
675
676 class StorageManagerDB(storage.StorageManager):
677     """StoreManager to save and retrieve Units using a DB."""
678    
679     identifier_length = 64
680     identifier_caseless = False
681     close_connection_method = 'close'
682    
683     decompiler = SQLDecompiler
684     typeAdapter = FieldTypeAdapter()
685     toAdapter = AdapterToSQL()
686     fromAdapter = AdapterFromDB()
687     debug_connections = False
688    
689     def __init__(self, name, arena, allOptions={}):
690         storage.StorageManager.__init__(self, name, arena, allOptions)
691        
692         # Adapter Overrides
693         def get_adapter_option(name):
694             adapter_class = allOptions.get(name)
695             if isinstance(adapter_class, basestring):
696                 import xray
697                 adapter_class = xray.classes(adapter_class)
698             return adapter_class
699        
700         adapter = get_adapter_option(u'Type Adapter')
701         if adapter: self.typeAdapter = adapter
702         adapter = get_adapter_option(u'To Adapter')
703         if adapter: self.toAdapter = adapter
704         adapter = get_adapter_option(u'From Adapter')
705         if adapter: self.fromAdapter = adapter
706        
707         self.CreateIfMissing = bool(allOptions.get(u'Create If Missing', 'True'))
708         self.pool_size = int(allOptions.get(u'Pool Size', '10'))
709        
710         self.refs = {}
711         if self.pool_size > 0:
712             self.pool = Queue.Queue(self.pool_size)
713         else:
714             self.pool = None
715         self.retry = 5
716         self.threaded = True
717        
718         self.prefix = allOptions.get(u'Prefix', u"djv")
719         self.reserve_lock = threading.Lock()
720        
721         ec = {}
722         for prop in allOptions.get(u'Expanded Columns', '').split(","):
723             if prop:
724                 cls, type = [x.strip() for x in prop.split(":", 1)]
725                 lastdot = cls.rfind(".")
726                 clsname, key = cls[:lastdot], cls[lastdot + 1:]
727                 ec[(clsname, key)] = type
728         self.expanded_columns = ec
729    
730     def identifier(self, *atoms):
731         ident = ''.join(map(str, atoms)).replace('"', '""')
732         if self.identifier_caseless:
733             ident = ident.lower()
734         idlen = self.identifier_length
735         if idlen and len(ident) > idlen:
736             warnings.warn("Identifier is longer than %s characters." % idlen)
737             ident = ident[:idlen]
738         return '"' + ident + '"'
739    
740     def tablename(self, cls):
741         if isinstance(cls, type):
742             name = cls.__name__
743         elif isinstance(cls, dejavu.Unit):
744             name = cls.__class__.__name__
745         elif isinstance(cls, basestring):
746             name = cls
747         else:
748             raise TypeError("Cannot form tablenames from %s" % cls)
749         return self.identifier(self.prefix, name)
750    
751     def _get_conn(self):
752         # Override this with the connection call for your DB. Example follows:
753 ##        try:
754 ##            conn = libpq.PQconnectdb(self.connstring)
755 ##        except Exception, x:
756 ##            if self.CreateIfMissing:
757 ##                self.create_database()
758 ##                conn = libpq.PQconnectdb(self.connstring)
759 ##            else:
760 ##                raise
761 ##        return conn
762         raise NotImplementedError
763    
764     def connection(self):
765         if not self.threaded:
766             try:
767                 return self.refs['conn']
768             except KeyError:
769                 self.refs['conn'] = conn = self._get_conn()
770                 return conn
771        
772         retry = 0
773         while True:
774             if self.pool is not None:
775                 try:
776                     conn = self.pool.get_nowait()
777                     # Okay, this is freaky. If we wrap here, all goes well.
778                     # If we wrap on Queue.put(), mysql crashes after 1700
779                     # or so inserts (when migrating Access tables to MySQL).
780                     # Go figure.
781                     w = ConnectionWrapper(conn)
782                     self.refs[weakref.ref(w, self.release)] = w.conn
783                     if self.debug_connections:
784                         print "-->get %s" % self.__class__.__name__
785                     return w
786                 except Queue.Empty:
787                     pass
788            
789             try:
790                 conn = self._get_conn()
791                 w = ConnectionWrapper(conn)
792                 self.refs[weakref.ref(w, self.release)] = w.conn
793                 if self.debug_connections:
794                     print "create %s" % self.__class__.__name__
795                 return w
796             except OutOfConnectionsError:
797                 retry += 1
798                 if retry < self.retry:
799                     time.sleep(retry * 1)
800                     conn = None
801                     continue
802                 raise OutOfConnectionsError()
803    
804     def release(self, ref):
805         conn = self.refs[ref]
806         del self.refs[ref]
807        
808         if self.pool is not None:
809             try:
810                 self.pool.put_nowait(conn)
811                 if self.debug_connections:
812                     print "<--put %s" % self.__class__.__name__
813                 return
814             except Queue.Full:
815                 pass
816        
817         getattr(conn, self.close_connection_method)()
818         if self.debug_connections:
819             print "___close___ %s" % self.__class__.__name__
820    
821     def shutdown(self):
822         # Empty the pool.
823         if self.pool:
824             while True:
825                 try:
826                     self.pool.get(True, 0.5)
827                 except Queue.Empty:
828                     break
829        
830         # Empty self.refs.
831         while self.refs:
832             ref, conn = self.refs.popitem()
833             getattr(conn, self.close_connection_method)()
834    
835     def create_database(self):
836         self.execute("CREATE DATABASE %s;" % self.identifier(self.dbname))
837    
838     def drop_database(self):
839         self.execute("DROP DATABASE %s;" % self.identifier(self.dbname))
840    
841     def create_storage(self, unitClass):
842         tablename = self.tablename(unitClass)
843        
844         coerce = self.typeAdapter.coerce
845         fields = []
846         for key in unitClass.properties():
847             fields.append(u'%s %s' % (self.identifier(key),
848                                       coerce(unitClass, key)))
849         self.execute(u'CREATE TABLE %s (%s);' % (tablename, ", ".join(fields)))
850         for index in unitClass.indices():
851             i = self.identifier(self.prefix, "i", unitClass.__name__, index)
852             self.execute(u'CREATE INDEX %s ON %s (%s);' %
853                          (i, tablename, self.identifier(index)))
854    
855     def select(self, unitClass, expr, distinct_fields=None):
856         tablename = self.tablename(unitClass)
857         if distinct_fields:
858             distinct_fields = [self.identifier(x) for x in distinct_fields]
859             sql = (u'SELECT DISTINCT %s FROM %s' %
860                    (u', '.join(distinct_fields), tablename))
861         else:
862             sql = u'SELECT * FROM %s' % tablename
863         w, i = self.where(unitClass, expr)
864         if len(w) > 0:
865             w = u" WHERE " + w
866         else:
867             w = u""
868         sql += w + ";"
869         return sql, i
870    
871     def where(self, cls, expr):
872         decom = self.decompiler(self.tablename(cls), expr, self.toAdapter)
873         return decom.code(), decom.imperfect
874    
875     def execute(self, query, conn=None):
876         """execute(query, conn=None) -> result set."""
877         if conn is None:
878             conn = self.connection()
879         return conn.query(query.encode('utf8'))
880    
881     def fetch(self, query, conn=None):
882         """fetch(query, conn=None) -> rowdata, columns.
883         
884         This base class uses SQLite3 syntax."""
885         res = self.execute(query, conn)
886         return res.row_list, res.col_defs
887    
888     def recall(self, cls, expr=None):
889         if expr is None:
890             expr = logic.Expression(lambda x: True)
891         sql, imperfect = self.select(cls, expr)
892         data, col_defs = self.fetch(sql)
893        
894         columns = dict([(col[0], (index, col[1])) for index, col
895                         in enumerate(col_defs)])
896        
897         # Get specs on properties. Get the ID property first in case other
898         # fields depend upon it. See load_expanded, for example.
899         props = []
900         for key in ['ID'] + [x for x in cls.properties() if x != "ID"]:
901             if self.identifier_caseless:
902                 index, ftype = columns[key.lower()]
903             else:
904                 index, ftype = columns[key]
905             subtype = self.expanded_columns.get((cls.__name__, key))
906             props.append((key, index, ftype, subtype))
907        
908         consume = self.fromAdapter.consume
909         for row in data:
910             unit = cls()
911             for key, index, ftype, subtype in props:
912                 value = row[index]
913                 if subtype:
914                     self.load_expanded(unit, key, subtype)
915                 else:
916                     try:
917                         consume(unit, key, value, ftype)
918                     except UnicodeDecodeError, x:
919                         x.reason += "[%s][%s][%s]" % (key, value, ftype)
920                         raise x
921                     except Exception, x:
922                         x.args += (key, value, ftype)
923                         raise x
924            
925             # If our SQL is imperfect, don't yield it to the
926             # caller unless it passes evaluate().
927             if (not imperfect) or expr.evaluate(unit):
928                 unit.cleanse()
929                 yield unit
930    
931     def reserve(self, unit):
932         """reserve(unit). -> Reserve a persistent slot for unit.
933         
934         Notice in particular that we do not use the auto-number or
935         sequence generation capabilities within some databases, etc.
936         The ID should be supplied by UnitSequencers via reserve().
937         """
938         cls = unit.__class__
939         tablename = self.tablename(unit)
940         i = self.identifier
941         self.reserve_lock.acquire()
942         try:
943             if unit.ID is None:
944                 # Examine all existing IDs and grant the "next" one.
945                 data, cols = self.fetch(u'SELECT %s FROM %s;' % (i('ID'), tablename))
946                 if data:
947                     # sqlite 2, for example, has empty cols tuple if no data.
948                     coerce = self.fromAdapter.coerce
949                     coltype = cols[0][1]
950                     expectedType = cls.property_type("ID")
951                     data = [coerce(row[0], coltype, expectedType) for row in data]
952                     unit.ID = unit.sequencer.next(data)
953                 else:
954                     unit.ID = unit.sequencer.next([])
955                 del data
956                 del cols
957            
958             fields = []
959             values = []
960             for key in cls.properties():
961                 subtype = self.expanded_columns.get((cls.__name__, key))
962                 if subtype:
963                     self.save_expanded(unit, key, subtype)
964                 else:
965                     val = self.toAdapter.coerce(getattr(unit, key))
966                     fields.append(i(key))
967                     values.append(val)
968            
969             self.execute('INSERT INTO %s (%s) VALUES (%s);' %
970                          (tablename, u", ".join(fields),
971                           u", ".join(values)))
972             unit.cleanse()
973         finally:
974             self.reserve_lock.release()
975    
976     def save(self, unit, forceSave=False):
977         """save(unit, forceSave=False) -> Update storage from unit's data."""
978         if unit.dirty() or forceSave:
979             cls = unit.__class__
980            
981             parms = []
982             for key in cls.properties():
983                 if key != "ID":
984                     subtype = self.expanded_columns.get((cls.__name__, key))
985                     if subtype:
986                         self.save_expanded(unit, key, subtype)
987                     else:
988                         val = self.toAdapter.coerce(getattr(unit, key))
989                         parms.append('%s = %s' % (self.identifier(key), val))
990            
991             sql = ('UPDATE %s SET %s WHERE %s = %s;' %
992                    (self.tablename(unit), u", ".join(parms),
993                     self.identifier("ID"),
994                     self.toAdapter.coerce(unit.ID)))
995             self.execute(sql)
996             unit.cleanse()
997    
998     def save_expanded(self, unit, key, subtype):
999         """save_expanded(unit, key, subtype). Save list in separate table."""
1000         unitcls = unit.__class__
1001         table = self.identifier(self.prefix, "_", unitcls.__name__,
1002                                 "_", unit.ID, "_", key)
1003        
1004         # Just drop the old table and start with a new one.
1005         try:
1006             self.execute(u"DROP TABLE %s;" % table)
1007         except:
1008             pass
1009        
1010         val = getattr(unit, key)
1011         if val is None:
1012             # Don't create a new table at all. This will signal
1013             # recall() to set the attribute to None on load.
1014             pass
1015         else:
1016             ftype = getattr(self.typeAdapter, "coerce_" + subtype)(unitcls, key)
1017             self.execute(u"CREATE TABLE %s (EXPVAL %s);" % (table, ftype))
1018            
1019             for v in val:
1020                 self.execute(u"INSERT INTO %s (EXPVAL) VALUES ('%s');"
1021                              % (table, self.toAdapter.coerce(v)))
1022    
1023     def load_expanded(self, unit, key, subtype):
1024         """load_expanded(unit, key, subtype). Load list from separate table."""
1025         table = self.identifier(self.prefix, "_", unit.__class__.__name__,
1026                                 "_", unit.ID, "_", key)
1027         try:
1028             data, col_defs = self.fetch(u"SELECT EXPVAL FROM %s" % table)
1029         except:
1030             values = None
1031         else:
1032             coltype = col_defs[0][1]
1033             coercer = getattr(self.fromAdapter, "coerce_" + subtype)
1034             values = [coercer(row[0], coltype) for row in data]
1035            
1036             expected_type = unit.__class__.property_type(key)
1037             values = expected_type(values)
1038        
1039         # Set the attribute directly to avoid __set__ overhead.
1040         unit._properties[key] = values
1041    
1042     def destroy(self, unit):
1043         """destroy(unit). Delete the unit."""
1044         self.execute(u'DELETE * FROM %s WHERE %s = %s;' %
1045                      (self.tablename(unit), self.identifier("ID"),
1046                       self.toAdapter.coerce(unit.ID)))
1047    
1048     def distinct(self, cls, fields, expr=None):
1049         """distinct(cls, fields, expr=None) -> Distinct values for given fields."""
1050         if expr is None:
1051             expr = logic.Expression(lambda x: True)
1052        
1053         sql, imperfect = self.select(cls, expr, fields)
1054         if imperfect:
1055             # ^%$#@! There's no way to handle imperfect queries without
1056             # creating all involved Units, which defeats the purpose of
1057             # distinct, which was a speed issue more than anything.
1058             warnings.warn("The requested distinct() query for %s Units "
1059                           "cannot produce perfect SQL with a %s datasource. "
1060                           "It may take an absurd amount of time to run, "
1061                           "since each unit must be fully-formed. %s"
1062                           % (cls.__name__, self.__class__.__name__, expr))
1063             distvals = {}
1064             for unit in self.recall(cls, expr):
1065                 # Must use inner tuples for hashability in Sandbox.distinct()
1066                 val = tuple([getattr(unit, f) for f in fields])
1067                 distvals[val] = None
1068             return distvals.keys()
1069         else:
1070             data, columns = self.fetch(sql)
1071             actualTypes = [x[1] for x in columns]
1072             expectedTypes = [cls.property_type(x) for x in fields]
1073            
1074             coerce = self.fromAdapter.coerce
1075             # Must use inner tuples for hashability in Sandbox.distinct()
1076             return [tuple([coerce(val, actualTypes[i], expectedTypes[i])
1077                            for i, val in enumerate(row)])
1078                      for row in data]
1079    
1080     def multiselect(self, pairs):
1081         t = self.tablename
1082         i = self.identifier
1083         firstcls = pairs[0][0]
1084         tablenames = []
1085         # Because various databases may mangle column names, we explicitly
1086         # order the requested columns (instead of using *).
1087         columns = []
1088         wheres = []
1089         imps = []
1090        
1091         for cls, expr in pairs:
1092             tablenames.append(t(cls))
1093             # Place the ID property first in case others depend upon it.
1094             keys = ['ID'] + [k for k in cls.properties() if k != 'ID']
1095             columns.extend([(cls, k) for k in keys])
1096            
1097             if expr is None:
1098                 expr = logic.Expression(lambda x: True)
1099             w, imp = self.where(cls, expr)
1100             wheres.append(w)
1101             imps.append(imp)
1102            
1103             if cls is not firstcls:
1104                 spath = self.arena.associations.shortest_path(firstcls, cls)
1105                 # This should be firstcls in every case.
1106                 cls1 = spath.pop(0)
1107                 for cls2 in spath:
1108                     leftkey, rightkey = cls1._associations[cls2]
1109                     wheres.append("(%s.%s = %s.%s)" %
1110                                   (t(cls1), i(leftkey),
1111                                    t(cls2), i(rightkey)))
1112                     cls1 = cls2
1113        
1114         # Remove any duplicate entries in the where clauses
1115         # (there may be several from the _join_clauses).
1116         wheres = dict.fromkeys(wheres).keys()
1117        
1118         names = ["%s.%s" % (t(cls), i(key)) for cls, key in columns]
1119         tbls = u', '.join(tablenames)
1120         w = u' AND '.join(wheres)
1121         statement = "SELECT %s FROM %s WHERE %s" % (u', '.join(names), tbls, w)
1122         return statement, imps, columns
1123    
1124     def multirecall(self, *pairs):
1125         """multirecall(*pairs) -> Full inner join units for each (cls, expr) pair."""
1126         sql, imps, supplied_cols = self.multiselect(pairs)
1127         data, recvd_cols = self.fetch(sql)
1128        
1129         # Get specs on properties.
1130         props = []
1131         for sup, rec in zip(supplied_cols, recvd_cols):
1132             c, key = sup
1133             name, ftype = rec[0], rec[1]
1134             subtype = self.expanded_columns.get((c.__name__, key))
1135             props.append((c, key, ftype, subtype))
1136        
1137         consume = self.fromAdapter.consume
1138         for row in data:
1139             index = 0
1140             units = {}
1141             for c, key, ftype, subtype in props:
1142                 if c in units:
1143                     unit = units[c]
1144                 else:
1145                     units[c] = unit = c()
1146                 value = row[index]
1147                 if subtype:
1148                     self.load_expanded(unit, key, subtype)
1149                 else:
1150                     try:
1151                         consume(unit, key, value, ftype)
1152                     except UnicodeDecodeError, x:
1153                         x.reason += "[%s][%s][%s]" % (key, value, ftype)
1154                         raise x
1155                     except Exception, x:
1156                         x.args += (key, value, ftype)
1157                         raise x
1158                 index += 1
1159            
1160             # If our SQL is imperfect, don't yield units to the
1161             # caller unless they pass evaluate().
1162             acceptable = True
1163             unitset = []
1164             for pair, imp in zip(pairs, imps):
1165                 c, e = pair
1166                 unit = units[c]
1167                 unit.cleanse()
1168                 if imp:
1169                     acceptable &= e.evaluate(unit)
1170                     if not acceptable:
1171                         break
1172                 unitset.append(unit)
1173             if acceptable:
1174                 yield unitset
1175
1176
class OutOfConnectionsError(dejavu.DejavuError):
    """Raised when a database store has run out of connections.

    connection() catches this when it comes from _get_conn() and retries a
    limited number of times before letting it propagate to the caller.
    """
1180
Note: See TracBrowser for help on using the browser.