Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/tags/1.4.0/storage/db.py

Revision 168 (checked in by fumanchu, 6 years ago)

Fixed some errors when fetch returns empty col_defs.

  • Property svn:eol-style set to native
Line 
1 """Base classes and tools for writing database Storage Managers.
2
3 Unit type -> [SQL repr ->] DB -> incoming Python value -> Unit type
4
5
6 DATA TYPES
7 ==========
8 Since Dejavu relies on external database servers for its persistence,
9 Python datatypes must be converted to column types in the DB. When writing
10 a StorageManager, make sure your type conversions can handle the full range
11 of values of each Python type. If possible, implement the type with no
12 limits. Also, follow UnitProperty.hints['bytes'] where possible. A value
13 of zero for hints['bytes'] implies no limit. If no value is given, try to
14 assume no limit, although you may choose whatever default size you wish
15 (255 is common for strings).
16 """
17
18 import datetime
19
20 try:
21     # Builtin in Python 2.5?
22     decimal
23 except NameError:
24     try:
25         # Module in Python 2.3, 2.4
26         import decimal
27     except ImportError:
28         pass
29
30 try:
31     import fixedpoint
32 except ImportError:
33     pass
34
35 try:
36     import cPickle as pickle
37 except ImportError:
38     import pickle
39
40 import Queue
41 import threading
42 import time
43 from types import FunctionType
44 import warnings
45 import weakref
46
47
48 import dejavu
49 from dejavu import codewalk, logic, storage, LOGCONN, LOGSQL, xray
50
51
def getCoerceName(value, valuetype=None):
    """Return the name of the coerce_* method that handles valuetype.

    Builtin types map to "coerce_<typename>" (e.g. "coerce_int"); other
    types map to "coerce_<module>_<typename>", with dots replaced by
    underscores (e.g. "coerce_datetime_datetime", "coerce_decimal_Decimal").

    If valuetype is None, type(value) is used.
    """
    if valuetype is None:
        valuetype = type(value)
    mod = valuetype.__module__
    # The builtins module is "__builtin__" in Python 2, "builtins" in 3.
    if mod in ("__builtin__", "builtins"):
        xform = "coerce_%s" % valuetype.__name__
    else:
        xform = "coerce_%s_%s" % (mod, valuetype.__name__)
    return xform.replace(".", "_")


def getCoerceMethod(obj, value, valuetype=None):
    """Return the bound coerce_* method of obj which handles value.

    The exact type (valuetype, or type(value) if None) is tried first, then
    its direct base classes in declaration order, then any remaining
    ancestors in method-resolution order. This lets subclasses of handled
    types, even indirect ones, fall back to an ancestor's coercer.

    Raises TypeError naming every method that was looked for.
    """
    if valuetype is None:
        valuetype = type(value)

    candidates = [valuetype] + list(valuetype.__bases__)
    # Old-style classes have no __mro__; they only get their direct bases.
    for ancestor in getattr(valuetype, "__mro__", ()):
        if ancestor not in candidates:
            candidates.append(ancestor)

    tried = []
    for candidate in candidates:
        meth = getCoerceName(value, candidate)
        tried.append(meth)
        if hasattr(obj, meth):
            return getattr(obj, meth)

    raise TypeError("'%s' is not handled by %s.  Looked for: %s " %
                    (valuetype, obj.__class__, ",".join(tried)))
78
79
class FieldTypeAdapter(object):
    """For a Python type, return the SQL column type name.
    
    This base class is designed to work out-of-the-box with PostgreSQL 8.
    Each coerce_<type>(cls, key) method reads the property cls.<key>
    (chiefly its hints dict) and returns an SQL column type name.
    """
    
    # 1000 is the max precision for NUMERIC columns for PostgreSQL 8.
    # Override in subclasses.
    numeric_max_precision = 1000
    
    def coerce(self, cls, key):
        """coerce(cls, key) -> SQL typename for valuetype.
        
        Dispatches on cls.property_type(key) to the coerce_* method named
        by getCoerceName; raises TypeError if this adapter has none.
        """
        valuetype = cls.property_type(key)
        xform = getCoerceName(None, valuetype)
        try:
            xform = getattr(self, xform)
        except AttributeError:
            raise TypeError("'%s' is not handled by %s." %
                            (valuetype, self.__class__))
        return xform(cls, key)
    
    def coerce_float(self, cls, key):
        prop = getattr(cls, key)
        nbytes = int(prop.hints.get(u'bytes', '8'))
        if nbytes < 5:
            # In MySQL, REAL is still probably 8 bytes. Meh.
            return u"REAL"
        else:
            # Python floats are implemented using C doubles;
            # actual precision depends on platform. PostgreSQL
            # DOUBLE is 8 bytes (15 decimal-digit precision).
            return u"DOUBLE PRECISION"
    
    def coerce_str(self, cls, key):
        # The bytes hint shall not reflect the usual 4-byte base for varchar.
        prop = getattr(cls, key)
        nbytes = int(prop.hints.get(u'bytes', '0'))
        if nbytes:
            return u"VARCHAR(%s)" % nbytes
        else:
            # TEXT is not an SQL standard, but it's common.
            return u"TEXT"
    
    # Containers are pickled (see AdapterToSQL.do_pickle) into text columns.
    def coerce_dict(self, cls, key):
        return self.coerce_str(cls, key)
    def coerce_list(self, cls, key):
        return self.coerce_str(cls, key)
    def coerce_tuple(self, cls, key):
        return self.coerce_str(cls, key)
    def coerce_unicode(self, cls, key):
        return self.coerce_str(cls, key)
    
    def coerce_bool(self, cls, key): return u"BOOLEAN"
    
    def coerce_datetime_datetime(self, cls, key): return u"TIMESTAMP"
    def coerce_datetime_date(self, cls, key): return u"DATE"
    def coerce_datetime_time(self, cls, key): return u"TIME"
    
    # I was seriously disinterested in writing a parser for interval.
    # Timedeltas are stored as a float number of days.
    def coerce_datetime_timedelta(self, cls, key):
        return self.coerce_float(cls, key)
    
    def coerce_decimal_Decimal(self, cls, key):
        prop = getattr(cls, key)
        precision = int(prop.hints.get('precision', '0'))
        if precision == 0:
            # No hint: use the current decimal context's precision.
            precision = decimal.getcontext().prec
        if precision > self.numeric_max_precision:
            warnings.warn("Decimal precision %s > maximum %s for %s.%s, "
                          "using %s. Values may be stored incorrectly."
                          % (precision, self.numeric_max_precision,
                             cls.__name__, key, self.__class__.__name__),
                          dejavu.StorageWarning)
            precision = self.numeric_max_precision
        # Assume most people use decimal for money; default scale = 2.
        scale = int(prop.hints.get(u'scale', 2))
        return u"NUMERIC(%s, %s)" % (precision, scale)
    
    def coerce_decimal(self, cls, key):
        # If decimal ever becomes a builtin. Python 2.5?
        return self.coerce_decimal_Decimal(cls, key)
    
    def coerce_fixedpoint_FixedPoint(self, cls, key):
        prop = getattr(cls, key)
        precision = int(prop.hints.get('precision', '0'))
        if precision == 0:
            precision = self.numeric_max_precision
        # Assume most people use decimal for money; default scale = 2.
        scale = int(prop.hints.get(u'scale', 2))
        return u"NUMERIC(%s, %s)" % (precision, scale)
    
    def coerce_long(self, cls, key):
        prop = getattr(cls, key)
        nbytes = int(prop.hints.get(u'bytes', 0))
        if nbytes <= 4:
            # Includes nbytes == 0 (no hint), which uses the int column type.
            return self.coerce_int(cls, key)
        elif nbytes <= 8:
            # BIGINT is usually 8 bytes
            return "BIGINT"
        # Anything larger than 8 bytes, use decimal/numeric. NUMERIC
        # precision counts decimal digits, not bytes. The largest value
        # in N bytes, 2 ** (8 * N) - 1, has as many digits as 2 ** (8 * N).
        # Cap N first so we never build an enormous number; any N at or
        # above the cap exceeds the maximum precision anyway.
        capped = min(nbytes, self.numeric_max_precision)
        precision = len(str(2 ** (8 * capped)))
        if precision > self.numeric_max_precision:
            warnings.warn("Long bytes %s (%s digits) > maximum precision %s "
                          "for %s.%s, using %s. Values may be stored "
                          "incorrectly."
                          % (nbytes, precision, self.numeric_max_precision,
                             cls.__name__, key, self.__class__.__name__),
                          dejavu.StorageWarning)
            precision = self.numeric_max_precision
        return "NUMERIC(%s, 0)" % precision
    
    def coerce_int(self, cls, key):
        prop = getattr(cls, key)
        nbytes = int(prop.hints.get(u'bytes', '4'))
        if nbytes in (1, 2):
            # PostgreSQL has no 1-byte integer type, and a BOOLEAN column
            # cannot hold arbitrary small ints; use the smallest int type.
            return "SMALLINT"
        elif nbytes > 4:
            # Wider than INTEGER (4 bytes) would overflow.
            return "BIGINT"
        else:
            # Includes 0 (no limit) and 3 bytes.
            return u"INTEGER"
206
class AdapterToSQL(object):
    """Coerce Expression constants to SQL.
    
    This base class is designed to work out-of-the-box with PostgreSQL 8.
    coerce(value) dispatches on the value's type to a coerce_<type>
    method, each of which returns the SQL literal for that value.
    """
    
    # (pattern, replacement) pairs applied in order by coerce_str. A pair
    # whose replacement introduces a backslash must come after the
    # backslash pair, or its backslashes would be doubled again.
    escapes = [("'", "''"), ("\\", r"\\")]
    # Escapes for LIKE wildcards, applied by escape_like.
    like_escapes = [("%", r"\%"), ("_", r"\_")]
    
    # These are not the same as coerce_bool (which is used on one side of
    # a comparison). Instead, these are used when the whole (sub)expression
    # is True or False, e.g. "WHERE TRUE", or "WHERE TRUE and 'a'.'b' = 3".
    bool_true = "TRUE"
    bool_false = "FALSE"
    
    def escape_like(self, value):
        """Prepare a string value for use in a LIKE comparison.
        
        value is normally an SQL string literal from coerce_str (e.g.
        "'abc'"). Exactly one pair of enclosing quote-marks is removed, so
        escaped quotes ("''") at either end of the content survive. LIKE
        wildcards in the content are then escaped.
        """
        if (len(value) >= 2 and value[0] == value[-1]
                and value[0] in "'\""):
            value = value[1:-1]
        for pat, repl in self.like_escapes:
            value = value.replace(pat, repl)
        return value
    
    def coerce(self, value, valuetype=None):
        """coerce(value, valuetype=None) -> value, coerced by valuetype."""
        meth = getCoerceMethod(self, value, valuetype)
        return meth(value)
    
    def tostr(self, value):
        return str(value)
    
    def coerce_NoneType(self, value):
        return "NULL"
    
    def coerce_bool(self, value):
        if value:
            return 'TRUE'
        return 'FALSE'
    
    # The great thing about these 3 date coercers is that you can use
    # them with (VAR)CHAR columns just as well as with DATETIME, etc.
    # and comparisons will still work! (Microseconds are dropped.)
    def coerce_datetime_datetime(self, value):
        return (u"'%04d-%02d-%02d %02d:%02d:%02d'" %
                (value.year, value.month, value.day,
                 value.hour, value.minute, value.second))
    
    def coerce_datetime_date(self, value):
        return u"'%04d-%02d-%02d'" % (value.year, value.month, value.day)
    
    def coerce_datetime_time(self, value):
        return u"'%02d:%02d:%02d'" % (value.hour, value.minute, value.second)
    
    def coerce_datetime_timedelta(self, value):
        # Stored as a float number of days; microseconds are dropped.
        float_val = value.days + (value.seconds / 86400.0)
        return repr(float_val)
    
    coerce_decimal = tostr
    coerce_decimal_Decimal = tostr
    
    def do_pickle(self, value):
        # Containers are stored as their pickle, quoted as a string.
        return self.coerce_str(pickle.dumps(value))
    
    coerce_dict = do_pickle
    
    coerce_fixedpoint_FixedPoint = tostr
    
    def coerce_float(self, value):
        # repr, not str: in Python 2, str(float) rounds to 12 significant
        # digits, which silently loses precision on store and compare.
        return repr(value)
    
    coerce_int = tostr
    
    coerce_list = do_pickle
    
    coerce_long = tostr
    
    def coerce_str(self, value):
        for pat, repl in self.escapes:
            value = value.replace(pat, repl)
        return "'" + value + "'"
    
    coerce_tuple = do_pickle
    
    coerce_unicode = coerce_str
289
290
class AdapterFromDB(object):
    """Coerce incoming values from DB types to Dejavu datatypes.
    
    You might notice that we pass coltype around a lot, but don't
    refer to it in this base class. It's there so subclasses can
    make decisions about coercion when they don't have control over
    the types of database columns, and must make do with legacy
    databases implementations.
    
    This base class is designed to work out-of-the-box with PostgreSQL 8.
    """
    
    def coerce(self, value, coltype, valuetype=None):
        """coerce(value, coltype, valuetype=None) -> value, coerced by valuetype.
        
        NULL (None) always comes back as None.
        """
        if value is None:
            return None
        
        meth = getCoerceMethod(self, value, valuetype)
        return meth(value, coltype)
    
    def consume(self, unit, key, value, coltype):
        """Coerce value to key's property type; store it in unit._properties.
        
        On failure the key, value and coltype are appended to the
        exception (for easier debugging) and it is re-raised.
        """
        # sys.exc_info() instead of "except X, x" so this parses under
        # every Python version.
        import sys
        try:
            expectedType = unit.__class__.property_type(key)
            value = self.coerce(value, coltype, expectedType)
            unit._properties[key] = value
        except UnicodeDecodeError:
            x = sys.exc_info()[1]
            x.reason += "[%s][%s][%s]" % (key, value, coltype)
            raise
        except Exception:
            x = sys.exc_info()[1]
            x.args += (key, value, coltype)
            raise
    
    def do_pickle(self, value, coltype):
        # SECURITY NOTE: pickle.loads can execute arbitrary code. This is
        # only safe if nobody untrusted can write to these columns.
        # Coerce to str for pickle.loads restriction.
        value = str(value)
        return pickle.loads(value)
    
    def coerce_bool(self, value, coltype):
        return bool(value)
    
    def coerce_datetime_datetime(self, value, coltype):
        if isinstance(value, datetime.datetime):
            # Some DB drivers already return datetime objects.
            return value
        # Parse "YYYY-MM-DD HH:MM:SS"; anything after the seconds is ignored.
        chunks = (value[0:4], value[5:7], value[8:10],
                  value[11:13], value[14:16], value[17:19])
        return datetime.datetime(*map(int, chunks))
    
    def coerce_datetime_date(self, value, coltype):
        # datetime is a subclass of date, so test for it first.
        if isinstance(value, datetime.datetime):
            return value.date()
        if isinstance(value, datetime.date):
            return value
        # Parse "YYYY-MM-DD".
        chunks = (value[0:4], value[5:7], value[8:10])
        return datetime.date(*map(int, chunks))
    
    def coerce_datetime_time(self, value, coltype):
        if isinstance(value, datetime.time):
            return value
        # Parse "HH:MM:SS".
        chunks = (value[0:2], value[3:5], value[6:8])
        return datetime.time(*map(int, chunks))
    
    def coerce_datetime_timedelta(self, value, coltype):
        # Stored as a float number of days (see AdapterToSQL); float()
        # also accepts string or Decimal values from the driver.
        days, fraction = divmod(float(value), 1)
        # round(), not int(): the day fraction rarely converts back to an
        # exact whole number of seconds, and truncating could drop one.
        return datetime.timedelta(days, int(round(fraction * 86400)))
    
    def coerce_decimal(self, value, coltype):
        # Only dispatched to if "decimal" is itself a builtin type (the
        # case the module-level import probes for); values of the decimal
        # module's type go to coerce_decimal_Decimal.
        return decimal(str(value))
    
    def coerce_decimal_Decimal(self, value, coltype):
        return decimal.Decimal(str(value))
    
    coerce_dict = do_pickle
    
    def coerce_fixedpoint_FixedPoint(self, value, coltype):
        return fixedpoint.FixedPoint(value)
    
    def coerce_float(self, value, coltype):
        return float(value)
    
    def coerce_int(self, value, coltype):
        return int(value)
    
    coerce_list = do_pickle
    
    def coerce_long(self, value, coltype):
        return long(value)
    
    def coerce_str(self, value, coltype):
        return str(value)
    
    coerce_tuple = do_pickle
    
    def coerce_unicode(self, value, coltype):
        # You should REALLY check into your DB's encoding and override this.
        if isinstance(value, unicode):
            # Already decoded by the DB driver; decoding again would fail.
            return value
        return unicode(value, "utf8")
378
379
380 # -------------------------- SQL DECOMPILATION -------------------------- #
381
382 class ConstWrapper(unicode):
383     """Wraps a constant for use in SQLDecompiler's stack.
384     
385     When we hit LOAD_CONST while decompiling, we occasionally need to keep
386     both the base and the coerced value around (see COMPARE_OP for use
387     of ConstWrapper.basevalue).
388     """
389     def __new__(self, basevalue, coerced_value):
390         newobj = unicode.__new__(ConstWrapper, coerced_value)
391         newobj.basevalue = basevalue
392         return newobj
393    
394     def __str__(self):
395         return str(self)
396    
397     def __unicode__(self):
398         return self
399    
400     def __repr__(self):
401         return self
402
403
class TableRef:
    """Stack entry for a positional lambda argument: a reference to a table.

    classname is the Unit class name bound to that argument; LOAD_ATTR on
    a TableRef becomes a fully-qualified column name.
    """
    def __init__(self, classname):
        self.classname = classname


class Sentinel(object):
    """A unique, named marker object placed on the decompiler stack."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return 'Stack Sentinel: %s' % self.name


# Pushed by LOAD_FAST for any lambda argument beyond the positional ones,
# i.e. the **kw dict; only kw['name'] subscripts of it are decompiled.
kw_arg = Sentinel('Keyword Arg')

# Marks a portion of an Expression that has no SQL equivalent (for
# example, the function dejavu.iscurrentweek rarely has one). It is
# rendered as TRUE, so all Units matching the rest of the Expression are
# recalled and can then be compared in expr(unit).
cannot_represent = Sentinel('Cannot Repr')
423
424
class SQLDecompiler(codewalk.LambdaDecompiler):
    """SQLDecompiler(classnames, expr, sm, adapter=AdapterToSQL()).
    
    Produce SQL from a supplied Expression object, with a lambda of the form:
        lambda x, **kw: ...
    
    Attributes of each argument in the signature will be mapped to table
    columns. Keyword arguments should be bound using Expression.bind_args
    before calling this decompiler.
    
    SQL fragments are built on self.stack as the bytecode is walked.
    Sub-expressions with no SQL equivalent become the cannot_represent
    sentinel, which is rendered as TRUE, and self.imperfect is set.
    """
    
    # Some constants are function or class objects,
    # which should not be coerced.
    no_coerce = (FunctionType,
                 type,
                 type(len),       # <type 'builtin_function_or_method'>
                 )
    
    # SQL operators indexed by the COMPARE_OP argument, in the order of
    # Python's dis.cmp_op. 'in' / 'not in' are rendered by containedby.
    sql_cmp_op = ('<', '<=', '=', '!=', '>', '>=', 'in', 'not in')
    
    def __init__(self, classnames, expr, sm, adapter=AdapterToSQL()):
        # classnames[i] is the Unit class name for positional lambda
        # argument i (see visit_LOAD_FAST).
        self.classnames = classnames
        self.expr = expr
        self.adapter = adapter
        self.sm = sm
        obj = expr.func
        codewalk.LambdaDecompiler.__init__(self, obj)
    
    def code(self):
        """Walk the Expression's bytecode and return the SQL WHERE text."""
        self.imperfect = False
        self.walk()
        # After walk(), self.stack should be reduced to a single string,
        # which is the SQL representation of our Expression.
        result = self.stack[0]
        if result is cannot_represent:
            # The entire expression could not be evaluated.
            result = self.adapter.bool_true
        if result == self.adapter.coerce_bool(True):
            result = self.adapter.bool_true
        if result == self.adapter.coerce_bool(False):
            result = self.adapter.bool_false
        return result
    
    def visit_instruction(self, op, lo=None, hi=None):
        # Get the instruction pointer for the current instruction.
        ip = self.cursor - 3
        if hi is None:
            ip += 1
            if lo is None:
                ip += 1
        
        # NOTE(review): self.targets presumably maps an instruction pointer
        # to (term, 'and'/'or') pairs recorded by the base class's jump
        # handling -- confirm against codewalk.LambdaDecompiler.
        terms = self.targets.get(ip)
        if terms:
            trueval = self.adapter.bool_true
            falseval = self.adapter.bool_false
            clause = self.stack[-1]
            while terms:
                term, oper = terms.pop()
                if term is cannot_represent:
                    term = trueval
                if clause is cannot_represent:
                    clause = trueval
                
                # Blurg. SQL Server is *so* picky.
                if term == self.adapter.coerce_bool(True):
                    term = trueval
                elif term == self.adapter.coerce_bool(False):
                    term = falseval
                if clause == self.adapter.coerce_bool(True):
                    clause = trueval
                elif clause == self.adapter.coerce_bool(False):
                    clause = falseval
                
                clause = "(%s) %s (%s)" % (term, oper.upper(), clause)
            
            # Replace TOS with the new clause, so that further
            # combinations have access to it.
            self.stack[-1] = clause
            self.debug("clause:", clause, "\n")
            
            if op == 1:
                # Py2.4: The current instruction is POP_TOP, which means
                # the previous is probably JUMP_*. If so, we're going to
                # pop the value we just placed on the stack and lose it.
                # We need to replace the entry that the JUMP_* made in
                # self.targets with our new TOS.
                target = self.targets[self.last_target_ip]
                target[-1] = ((clause, target[-1][1]))
                self.debug("newtarget:", self.last_target_ip, target)
    
    def visit_LOAD_DEREF(self, lo, hi):
        raise ValueError("Illegal reference found in %s." % self.expr)
    
    def visit_LOAD_GLOBAL(self, lo, hi):
        raise ValueError("Illegal global found in %s." % self.expr)
    
    def visit_LOAD_FAST(self, lo, hi):
        arg_index = lo + (hi << 8)
        if arg_index < self.co_argcount:
            self.stack.append(TableRef(self.classnames[arg_index]))
        else:
            self.stack.append(kw_arg)
    
    def visit_LOAD_ATTR(self, lo, hi):
        name = self.co_names[lo + (hi << 8)]
        tos = self.stack.pop()
        if isinstance(tos, TableRef):
            atom = self.sm.column_name(tos.classname, name, full=True)
        else:
            # tos.name will reference an attribute of the tos object.
            # Stick the tos and name in a tuple for later processing.
            atom = (tos, name)
        self.stack.append(atom)
    
    def visit_LOAD_CONST(self, lo, hi):
        val = self.co_consts[lo + (hi << 8)]
        if not isinstance(val, self.no_coerce):
            val = ConstWrapper(val, self.adapter.coerce(val))
        self.stack.append(val)
    
    def visit_BUILD_TUPLE(self, lo, hi):
        # Parenthesized: "lo + hi << 8" would parse as "(lo + hi) << 8".
        count = lo + (hi << 8)
        # Items were pushed left to right; pop them, then restore order.
        items = [self.stack.pop() for i in range(count)]
        items.reverse()
        for item in items:
            if item is cannot_represent:
                self.stack.append(cannot_represent)
                return
        sql = "(" + ", ".join(items) + ")"
        allconst = True
        for item in items:
            if not isinstance(item, ConstWrapper):
                allconst = False
                break
        if allconst:
            # Keep the Python values so containedby can use them.
            sql = ConstWrapper(tuple([item.basevalue for item in items]), sql)
        self.stack.append(sql)
    
    visit_BUILD_LIST = visit_BUILD_TUPLE
    
    def visit_CALL_FUNCTION(self, lo, hi):
        # lo = number of positional args, hi = number of keyword args.
        kwargs = {}
        for i in xrange(hi):
            val = self.stack.pop()
            key = self.stack.pop()
            kwargs[key] = val
        kwargs = [k + "=" + v for k, v in kwargs.iteritems()]
        
        args = []
        for i in xrange(lo):
            arg = self.stack.pop()
            args.append(arg)
        args.reverse()
        
        if kwargs:
            args += kwargs
        
        func = self.stack.pop()
        
        # Handle function objects.
        if isinstance(func, tuple):
            # Method call on an attribute, e.g. x.Name.startswith(...):
            # dispatch to attr_<methodname>.
            tos, name = func
            dispatch = getattr(self, "attr_" + name, None)
            if dispatch:
                self.stack.append(dispatch(tos, *args))
                return
        else:
            # Plain function: dispatch to <module>_<funcname>, e.g.
            # dejavu_icontains or func__builtin___len.
            funcname = func.__module__ + "_" + func.__name__
            funcname = funcname.replace(".", "_")
            if funcname.startswith("_"):
                funcname = "func" + funcname
            dispatch = getattr(self, funcname, None)
            if dispatch:
                self.stack.append(dispatch(*args))
                return
        
        self.stack.append(cannot_represent)
        self.imperfect = True
    
    def visit_COMPARE_OP(self, lo, hi):
        op2, op1 = self.stack.pop(), self.stack.pop()
        if op1 is cannot_represent or op2 is cannot_represent:
            self.stack.append(cannot_represent)
            return
        
        op = lo + (hi << 8)
        if op in (6, 7):     # in, not in
            value = self.containedby(op1, op2)
            if op == 7:
                value = "NOT " + value
            self.stack.append(value)
        elif op1 == 'NULL':
            self.stack.append(self._null_test(op2, op))
        elif op2 == 'NULL':
            self.stack.append(self._null_test(op1, op))
        elif op < len(self.sql_cmp_op):
            # Comparison operators for strings are case-sensitive in PG et al.
            self.stack.append(op1 + " " + self.sql_cmp_op[op] + " " + op2)
        else:
            # 'is' / 'is not' between non-NULL operands (and any other
            # operator) have no SQL equivalent; leave it to expr(unit).
            self.stack.append(cannot_represent)
            self.imperfect = True
    
    def _null_test(self, operand, op):
        """Render a comparison of operand with NULL as IS [NOT] NULL."""
        if op in (2, 8):      # ==, is
            return operand + " IS NULL"
        elif op in (3, 9):    # !=, is not
            return operand + " IS NOT NULL"
        raise ValueError("Non-equality Null comparisons not allowed.")
    
    def binary_op(self, op):
        op2, op1 = self.stack.pop(), self.stack.pop()
        self.stack.append(op1 + " " + op + " " + op2)
    
    def visit_BINARY_SUBSCR(self):
        # The only BINARY_SUBSCR used in Expressions should be kwargs[key].
        name = self.stack.pop()
        tos = self.stack.pop()
        if tos is not kw_arg:
            raise ValueError("Subscript %s of %s object not allowed."
                             % (name, tos))
        if isinstance(name, ConstWrapper):
            # Use the original Python key; the coerced SQL text has quotes
            # (and possibly escapes) added.
            name = name.basevalue
        else:
            name = name.strip("'\"")
        value = self.expr.kwargs[name]
        if not isinstance(value, self.no_coerce):
            value = ConstWrapper(value, self.adapter.coerce(value))
        self.stack.append(value)
    
    def visit_UNARY_NOT(self):
        op = self.stack.pop()
        if op is cannot_represent:
            self.stack.append(cannot_represent)
        else:
            self.stack.append("NOT (" + op + ")")
    
    # --------------------------- Dispatchees --------------------------- #
    
    def attr_startswith(self, tos, arg):
        return tos + " LIKE '" + self.adapter.escape_like(arg) + "%'"
    
    def attr_endswith(self, tos, arg):
        return tos + " LIKE '%" + self.adapter.escape_like(arg) + "'"
    
    def containedby(self, op1, op2):
        """Render "op1 in op2" as SQL."""
        if isinstance(op1, ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            return op2 + " LIKE '%" + self.adapter.escape_like(op1) + "%'"
        else:
            # Looking for field in (a, b, c)
            atoms = [self.adapter.coerce(x) for x in op2.basevalue]
            if not atoms:
                # "x in ()" is always False, and "IN ()" is invalid SQL.
                return self.adapter.bool_false
            return op1 + " IN (" + ", ".join(atoms) + ")"
    
    def dejavu_icontainedby(self, op1, op2):
        if isinstance(op1, ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            return ("LOWER(" + op2 + ") LIKE '%" +
                    self.adapter.escape_like(op1).lower() + "%'")
        else:
            # Looking for field in (a, b, c).
            # Force all args to lowercase for case-insensitive comparison.
            atoms = [self.adapter.coerce(x).lower() for x in op2.basevalue]
            if not atoms:
                # "x in ()" is always False, and "IN ()" is invalid SQL.
                return self.adapter.bool_false
            return "LOWER(%s) IN (%s)" % (op1, ", ".join(atoms))
    
    def dejavu_icontains(self, x, y):
        return self.dejavu_icontainedby(y, x)
    
    def dejavu_istartswith(self, x, y):
        # Lowercase the pattern too, or it could never match LOWER(x).
        return ("LOWER(" + x + ") LIKE '" +
                self.adapter.escape_like(y).lower() + "%'")
    
    def dejavu_iendswith(self, x, y):
        # Lowercase the pattern too, or it could never match LOWER(x).
        return ("LOWER(" + x + ") LIKE '%" +
                self.adapter.escape_like(y).lower() + "'")
    
    def dejavu_ieq(self, x, y):
        return "LOWER(" + x + ") = LOWER(" + y + ")"
    
    def dejavu_now(self):
        return "NOW()"
    
    def dejavu_today(self):
        return "CURRENT_DATE"
    
    def dejavu_year(self, x):
        return "YEAR(" + x + ")"
    
    def func__builtin___len(self, x):
        return "LENGTH(" + x + ")"
696
697
class ConnectionWrapper(object):
    """Proxy that forwards attribute lookups to a wrapped DB connection.
    
    Attributes set on the wrapper itself are found normally; any other
    attribute is looked up on self.conn.
    """
    def __init__(self, conn=None):
        self.conn = conn
    
    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. If 'conn' itself
        # is missing (e.g. an instance made without __init__, as copy and
        # pickle do), delegating would recurse forever.
        if attr == 'conn':
            raise AttributeError(attr)
        return getattr(self.conn, attr)
704
705
class UnitClassWrapper(object):
    """Pairs a Unit class with its table name (and optional alias) in sm."""

    def __init__(self, wclass, sm):
        self.cls = wclass
        self.sm = sm
        self.tablename = sm.table_name(wclass.__name__)
        # An empty alias means "refer to the table by its own name".
        self.alias = ""

    def _label(wrapper):
        # The name that qualifies a wrapper's columns: its alias, if set.
        return wrapper.alias or wrapper.tablename
    _label = staticmethod(_label)

    def columns(self):
        """Return (cols, colnames) for the wrapped class's properties.

        cols is a list of (class, key) pairs, and colnames the matching
        "table.column" names. Identifier properties come first, in case
        others depend upon them.
        """
        unitcls = self.cls
        idnames = [p.key for p in unitcls.identifiers]
        keys = list(idnames)
        for k in unitcls.properties():
            if k not in idnames:
                keys.append(k)

        label = self._label(self)
        cols = []
        colnames = []
        for k in keys:
            cols.append((unitcls, k))
            colname = self.sm.column_name(unitcls.__name__, k)
            colnames.append('%s.%s' % (label, colname))
        return cols, colnames

    def _joinname(self):
        if not self.alias:
            return self.tablename
        return "%s AS %s" % (self.tablename, self.alias)
    joinname = property(_joinname, doc=("Table name for use in "
                                        "JOIN clause (read-only)."))

    def association(self, classes):
        """Find the first association between this class and one of classes.

        Returns (association, nearName, farName), where nearName labels
        the class whose _associations declared it. Returns None if no
        association is found.
        """
        for other in classes:
            mine = self.cls._associations.get(other.cls.__name__, None)
            if mine:
                return mine, self._label(self), self._label(other)
            theirs = other.cls._associations.get(self.cls.__name__, None)
            if theirs:
                return theirs, self._label(other), self._label(self)
        return None
750
751
752 class StorageManagerDB(storage.StorageManager):
753     """StoreManager to save and retrieve Units using a DB."""
754    
755     sql_name_max_length = 64
756     sql_name_caseless = False
757     close_connection_method = 'close'
758     use_asterisk_to_get_all = False
759    
760     decompiler = SQLDecompiler
761     typeAdapter = FieldTypeAdapter()
762     toAdapter = AdapterToSQL()
763     fromAdapter = AdapterFromDB()
764     debug_connections = False
765    
766     def __init__(self, name, arena, allOptions={}):
767         storage.StorageManager.__init__(self, name, arena, allOptions)
768        
769         # Adapter Overrides
770         def get_adapter_option(name):
771             adapter_class = allOptions.get(name)
772             if isinstance(adapter_class, basestring):
773                 adapter_class = xray.classes(adapter_class)
774             return adapter_class
775        
776         adapter = get_adapter_option(u'Type Adapter')
777         if adapter: self.typeAdapter = adapter
778         adapter = get_adapter_option(u'To Adapter')
779         if adapter: self.toAdapter = adapter
780         adapter = get_adapter_option(u'From Adapter')
781         if adapter: self.fromAdapter = adapter
782        
783         self.pool_size = int(allOptions.get(u'Pool Size', '10'))
784        
785         self.refs = {}
786         if self.pool_size > 0:
787             self.pool = Queue.Queue(self.pool_size)
788         else:
789             self.pool = None
790         self.retry = 5
791         self.threaded = True
792        
793         self.prefix = allOptions.get(u'Prefix', u"djv")
794         self.reserve_lock = threading.Lock()
795        
796         ec = {}
797         for prop in allOptions.get(u'Expanded Columns', '').split(","):
798             if prop:
799                 cls, type = [x.strip() for x in prop.split(":", 1)]
800                 lastdot = cls.rfind(".")
801                 clsname, key = cls[:lastdot], cls[lastdot + 1:]
802                 ec[(clsname, key)] = type
803         self.expanded_columns = ec
804    
805     #                               Naming                               #
806    
807     def sql_name(self, name, quoted=True):
808         """The name, escaped for SQL."""
809         if self.sql_name_caseless:
810             name = name.lower()
811        
812         maxlen = self.sql_name_max_length
813         if maxlen and len(name) > maxlen:
814             warnings.warn("The name '%s' is longer than the maximum of "
815                           "%s characters." % (name, maxlen),
816                           dejavu.StorageWarning)
817             name = name[:maxlen]
818        
819         # This base class doesn't use the "quoted" arg,
820         # but most subclasses will.
821         return name
822    
823     def column_name(self, classname, name, full=False, quoted=True):
824         """The column name, escaped for SQL. If full, include tablename."""
825         # If you want to use a map from UnitProperty names
826         # to DB column names, override this method.
827         name = self.sql_name(name, quoted=quoted)
828         if not full:
829             return name
830        
831         alias = getattr(classname, "alias", None)
832         if alias is None:
833             tname = self.table_name(classname, quoted=quoted)
834         else:
835             tname = (classname.alias or classname.tablename)
836         return '%s.%s' % (tname, name)
837    
838     def table_name(self, name, quoted=True):
839         """The table name, escaped for SQL."""
840         # If you want to use a map from Unit class names
841         # to DB table names, override this method.
842         return self.sql_name(self.prefix + name, quoted=quoted)
843    
844     #                             Connecting                              #
845    
846     def _get_conn(self):
847         # Override this with the connection call for your DB. Example:
848         #     return libpq.PQconnectdb(self.connstring)
849         raise NotImplementedError
850    
851     def connection(self):
852         if not self.threaded:
853             # Place a single 'conn' entry in self.refs.
854             try:
855                 return self.refs['conn']
856             except KeyError:
857                 self.refs['conn'] = conn = self._get_conn()
858                 return conn
859        
860         retry = 0
861         while True:
862             if self.pool is not None:
863                 try:
864                     conn = self.pool.get_nowait()
865                     # Okay, this is freaky. If we wrap here, all goes well.
866                     # If we wrap on Queue.put(), mysql crashes after 1700
867                     # or so inserts (when migrating Access tables to MySQL).
868                     # Go figure.
869                     w = ConnectionWrapper(conn)
870                     self.refs[weakref.ref(w, self.release)] = w.conn
871                     self.arena.log("-->get %s" % self.__class__.__name__,
872                                    LOGCONN)
873                     return w
874                 except Queue.Empty:
875                     pass
876            
877             try:
878                 conn = self._get_conn()
879                 w = ConnectionWrapper(conn)
880                 self.refs[weakref.ref(w, self.release)] = w.conn
881                 self.arena.log("create %s" % self.__class__.__name__, LOGCONN)
882                 return w
883             except OutOfConnectionsError:
884                 retry += 1
885                 if retry < self.retry:
886                     time.sleep(retry * 1)
887                     conn = None
888                     continue
889                 raise OutOfConnectionsError()
890    
891     def release(self, ref):
892         # This method should only be called if self.threaded is True
893         conn = self.refs.pop(ref)
894        
895         if self.pool is not None:
896             try:
897                 self.pool.put_nowait(conn)
898                 self.arena.log("<--put %s" % self.__class__.__name__, LOGCONN)
899                 return
900             except Queue.Full:
901                 pass
902        
903         getattr(conn, self.close_connection_method)()
904         self.arena.log("___close___ %s" % self.__class__.__name__, LOGCONN)
905    
906     def shutdown(self):
907         # Empty the pool.
908         if self.pool:
909             while True:
910                 try:
911                     self.pool.get(True, 0.5)
912                 except Queue.Empty:
913                     break
914        
915         # Empty self.refs.
916         while self.refs:
917             ref, conn = self.refs.popitem()
918             getattr(conn, self.close_connection_method)()
919    
920     def select(self, cls, expr, fields=None, distinct=False):
921         clsname = cls.__name__
922         tablename = self.table_name(clsname)
923         if fields:
924             fields = [self.column_name(clsname, x) for x in fields]
925             if distinct:
926                 sql = u'SELECT DISTINCT %s FROM %s'
927             else:
928                 sql = u'SELECT %s FROM %s'
929             sql = sql % (u', '.join(fields), tablename)
930         else:
931             sql = u'SELECT * FROM %s' % tablename
932        
933         w, i = self.where((clsname,), expr)
934         if len(w) > 0:
935             w = u" WHERE " + w
936         else:
937             w = u""
938         sql += w + ";"
939         return sql, i
940    
941     def where(self, classnames, expr):
942         decom = self.decompiler(classnames, expr, self, self.toAdapter)
943         return decom.code(), decom.imperfect
944    
945     def execute(self, query, conn=None):
946         """execute(query, conn=None) -> result set."""
947         if conn is None:
948             conn = self.connection()
949         self.arena.log(query, LOGSQL)
950         return conn.query(query.encode('utf8'))
951    
952     def fetch(self, query, conn=None):
953         """fetch(query, conn=None) -> rowdata, columns.
954         
955         rowdata will be an iterable of iterables containing the result values.
956         columns will be an iterable of (column name, data type) pairs.
957         
958         This base class uses SQLite3 syntax.
959         """
960         res = self.execute(query, conn)
961         return res.row_list, res.col_defs
962    
963     def recall(self, cls, expr=None):
964         clsname = cls.__name__
965        
966         if expr is None:
967             expr = logic.Expression(lambda x: True)
968         sql, imperfect = self.select(cls, expr)
969         data, col_defs = self.fetch(sql)
970         if data:
971             columns = dict([(col[0], (index, col[1])) for index, col
972                             in enumerate(col_defs)])
973            
974             # Get specs on properties. Put the identifier properties
975             # first, in case other fields depend upon them.
976             # See load_expanded, for example.
977             props = []
978             idnames = [prop.key for prop in cls.identifiers]
979             for key in idnames + [x for x in cls.properties() if x not in idnames]:
980                 index, ftype = columns[self.column_name(clsname, key, quoted=False)]
981                 subtype = self.expanded_columns.get((clsname, key))
982                 props.append((key, index, ftype, subtype))
983            
984             consume = self.fromAdapter.consume
985             for row in data:
986                 unit = cls()
987                 for key, index, ftype, subtype in props:
988                     value = row[index]
989                     if subtype:
990                         self.load_expanded(unit, key, subtype)
991                     else:
992                         consume(unit, key, value, ftype)
993                
994                 # If our SQL is imperfect, don't yield it to the
995                 # caller unless it passes expr(unit).
996                 if (not imperfect) or expr(unit):
997                     unit.cleanse()
998                     yield unit
999    
1000     def reserve(self, unit):
1001         """reserve(unit). -> Reserve a persistent slot for unit.
1002         
1003         Notice in particular that we do not use the auto-number or
1004         sequence generation capabilities within some databases, etc.
1005         The identifiers should be supplied by UnitSequencers via reserve().
1006         """
1007         cls = unit.__class__
1008         clsname = cls.__name__
1009         tablename = self.table_name(clsname)
1010         self.reserve_lock.acquire()
1011         try:
1012             if not unit.sequencer.valid_id(unit.identity()):
1013                 # Examine all existing IDs and grant the "next" one.
1014                 id_fields = [self.column_name(clsname, prop.key)
1015                              for prop in cls.identifiers]
1016                 data, cols = self.fetch(u'SELECT %s FROM %s;' %
1017                                         (', '.join(id_fields), tablename))
1018                 if data:
1019                     # sqlite 2, for example, has empty cols tuple if no data.
1020                     coerce = self.fromAdapter.coerce
1021                     coltypes = [cols[x][1] for x in xrange(len(cols))]
1022                     expectedTypes = [prop.type for prop in cls.identifiers]
1023                     newdata = []
1024                     for row in data:
1025                         newrow = []
1026                         for x, cell in enumerate(row):
1027                             newrow.append(coerce(cell, coltypes[x],
1028                                                  expectedTypes[x]))
1029                         newdata.append(newrow)
1030                     data = newdata
1031                     del newdata
1032                 cls.sequencer.assign(unit, data)
1033                 del data
1034                 del cols
1035            
1036             fields = []
1037             values = []
1038             for key in cls.properties():
1039                 subtype = self.expanded_columns.get((clsname, key))
1040                 if subtype:
1041                     self.save_expanded(unit, key, subtype)
1042                 else:
1043                     val = self.toAdapter.coerce(getattr(unit, key))
1044                     fields.append(self.column_name(clsname, key))
1045                     values.append(val)
1046            
1047             fields = u", ".join(fields)
1048             values = u", ".join(values)
1049             self.execute('INSERT INTO %s (%s) VALUES (%s);' %
1050                          (tablename, fields, values))
1051             unit.cleanse()
1052         finally:
1053             self.reserve_lock.release()
1054    
1055     def id_clause(self, unit):
1056         clsname = unit.__class__.__name__
1057         col = self.column_name
1058         c = self.toAdapter.coerce
1059         idnames = [prop.key for prop in unit.identifiers]
1060         return " AND ".join(["%s = %s" % (col(clsname, key),
1061                                           c(getattr(unit, key)))
1062                              for key in idnames])
1063    
1064     def save(self, unit, forceSave=False):
1065         """save(unit, forceSave=False) -> Update storage from unit's data."""
1066         if unit.dirty() or forceSave:
1067             cls = unit.__class__
1068             clsname = cls.__name__
1069            
1070             parms = []
1071             idnames = [prop.key for prop in cls.identifiers]
1072             for key in cls.properties():
1073                 if key not in idnames:
1074                     subtype = self.expanded_columns.get((clsname, key))
1075                     if subtype:
1076                         self.save_expanded(unit, key, subtype)
1077                     else:
1078                         val = self.toAdapter.coerce(getattr(unit, key))
1079                         parms.append('%s = %s' %
1080                                      (self.column_name(clsname, key), val))
1081            
1082             if parms:
1083                 sql = ('UPDATE %s SET %s WHERE %s;' %
1084                        (self.table_name(clsname), u", ".join(parms),
1085                         self.id_clause(unit)))
1086                 self.execute(sql)
1087             unit.cleanse()
1088    
1089     def save_expanded(self, unit, key, subtype):
1090         """save_expanded(unit, key, subtype). Save list in separate table."""
1091         unitcls = unit.__class__
1092         id = "_".join(map(str, unit.identity()))
1093         table = self.table_name("_%s_%s_%s" % (unitcls.__name__, id, key))
1094        
1095         # Just drop the old table and start with a new one.
1096         try:
1097             self.execute(u"DROP TABLE %s;" % table)
1098         except:
1099             pass
1100        
1101         val = getattr(unit, key)
1102         if val is None:
1103             # Don't create a new table at all. This will signal
1104             # recall() to set the attribute to None on load.
1105             pass
1106         else:
1107             ftype = getattr(self.typeAdapter, "coerce_" + subtype)(unitcls, key)
1108             self.execute(u"CREATE TABLE %s (EXPVAL %s);" % (table, ftype))
1109            
1110             for v in val:
1111                 self.execute(u"INSERT INTO %s (EXPVAL) VALUES ('%s');"
1112                              % (table, self.toAdapter.coerce(v)))
1113    
1114     def load_expanded(self, unit, key, subtype):
1115         """load_expanded(unit, key, subtype). Load list from separate table."""
1116         unitcls = unit.__class__
1117         id = "_".join(map(str, unit.identity()))
1118         table = self.table_name("_%s_%s_%s" % (unitcls.__name__, id, key))
1119         try:
1120             data, col_defs = self.fetch(u"SELECT EXPVAL FROM %s" % table)
1121         except:
1122             values = None
1123         else:
1124             coltype = col_defs[0][1]
1125             coercer = getattr(self.fromAdapter, "coerce_" + subtype)
1126             values = [coercer(row[0], coltype) for row in data]
1127            
1128             expected_type = unitcls.property_type(key)
1129             values = expected_type(values)
1130        
1131         # Set the attribute directly to avoid __set__ overhead.
1132         unit._properties[key] = values
1133    
1134     def destroy(self, unit):
1135         """destroy(unit). Delete the unit."""
1136         if self.use_asterisk_to_get_all:
1137             star = " *"
1138         else:
1139             star = ""
1140         self.execute(u'DELETE%s FROM %s WHERE %s;' %
1141                      (star, self.table_name(unit.__class__.__name__),
1142                       self.id_clause(unit)))
1143    
1144     def view(self, cls, fields, expr=None):
1145         """view(cls, fields, expr=None) -> All value-tuples for given fields."""
1146         if expr is None:
1147             expr = logic.Expression(lambda x: True)
1148        
1149         sql, imperfect = self.select(cls, expr, fields)
1150         if imperfect:
1151             # ^%$#@! There's no way to handle imperfect queries without
1152             # creating all involved Units, which defeats the purpose of
1153             # view, which was a speed issue more than anything else.
1154             warnings.warn("The requested view() query for %s Units "
1155                           "cannot produce perfect SQL with a %s datasource. "
1156                           "It may take an absurd amount of time to run, "
1157                           "since each unit must be fully-formed. %s"
1158                           % (cls.__name__, self.__class__.__name__, expr),
1159                           dejavu.StorageWarning)
1160             for unit in self.recall(cls, expr):
1161                 # Use tuples for hashability
1162                 yield tuple([getattr(unit, f) for f in fields])
1163         else:
1164             data, columns = self.fetch(sql)
1165             actualTypes = [x[1] for x in columns]
1166             expectedTypes = [cls.property_type(x) for x in fields]
1167            
1168             coerce = self.fromAdapter.coerce
1169             # Use tuples for hashability
1170             for row in data:
1171                 yield tuple([coerce(val, actualTypes[i], expectedTypes[i])
1172                              for i, val in enumerate(row)])
1173    
1174     def distinct(self, cls, fields, expr=None):
1175         """distinct(cls, fields, expr=None) -> Distinct values for given fields."""
1176         if expr is None:
1177             expr = logic.Expression(lambda x: True)
1178        
1179         sql, imperfect = self.select(cls, expr, fields, distinct=True)
1180         if imperfect:
1181             # ^%$#@! There's no way to handle imperfect queries without
1182             # creating all involved Units, which defeats the purpose of
1183             # distinct, which was a speed issue more than anything.
1184             warnings.warn("The requested distinct() query for %s Units "
1185                           "cannot produce perfect SQL with a %s datasource. "
1186                           "It may take an absurd amount of time to run, "
1187                           "since each unit must be fully-formed. %s"
1188                           % (cls.__name__, self.__class__.__name__, expr),
1189                           dejavu.StorageWarning)
1190             vals = {}
1191             for unit in self.recall(cls, expr):
1192                 # Must use tuples for hashability
1193                 val = tuple([getattr(unit, f) for f in fields])
1194                 vals[val] = None
1195             return vals.keys()
1196         else:
1197             data, columns = self.fetch(sql)
1198             actualTypes = [x[1] for x in columns]
1199             expectedTypes = [cls.property_type(x) for x in fields]
1200            
1201             coerce = self.fromAdapter.coerce
1202             # Must use inner tuples for hashability in Sandbox.distinct()
1203             return [tuple([coerce(val, actualTypes[i], expectedTypes[i])
1204                            for i, val in enumerate(row)])
1205                      for row in data]
1206    
1207     def join(self, unitjoin):
1208         cls1, cls2 = unitjoin.class1, unitjoin.class2
1209         if isinstance(cls1, dejavu.UnitJoin):
1210             name1 = self.join(cls1)
1211             classlist1 = iter(cls1)
1212         else:
1213             # cls1 is a Unit class wrapper.
1214             name1 = cls1.joinname
1215             classlist1 = [cls1]
1216        
1217         if isinstance(cls2, dejavu.UnitJoin):
1218             name2 = self.join(cls2)
1219             classlist2 = iter(cls2)
1220         else:
1221             # cls2 is a Unit class wrapper.
1222             name2 = cls2.joinname
1223             classlist2 = [cls2]
1224        
1225         j = {None: "INNER", True: "LEFT", False: "RIGHT"}[unitjoin.leftbiased]
1226        
1227         # Find an association between the two halves.
1228         ua = None
1229         for clsA in classlist1:
1230             ua = clsA.association(classlist2)
1231             if ua:
1232                 ua, nearClass, farClass = ua
1233                 break
1234         if ua is None:
1235             msg = ("No association found between %s and %s." % (name1, name2))
1236             raise dejavu.AssociationError(msg)
1237         near = '%s.%s' % (nearClass, self.column_name(nearClass, ua.nearKey))
1238         far = '%s.%s' % (farClass, self.column_name(farClass, ua.farKey))
1239        
1240         return "(%s %s JOIN %s ON %s = %s)" % (name1, j, name2, near, far)
1241    
1242     def multiselect(self, classes, expr):
1243        
1244         # Create a new unitjoin tree where each class is wrapped.
1245         # Then we can tag the wrappers with metadata with impunity.
1246         seen = {}
1247         aliascount = [0]
1248        
1249         def wrap(unitjoin):
1250             cls1, cls2 = unitjoin.class1, unitjoin.class2
1251             if isinstance(cls1, dejavu.UnitJoin):
1252                 wclass1 = wrap(cls1)
1253             else:
1254                 wclass1 = UnitClassWrapper(cls1, self)
1255                 if cls1 in seen:
1256                     aliascount[0] += 1
1257                     wclass1.alias = "t%d" % aliascount[0]
1258                 else:
1259                     seen[cls1] = None
1260             if isinstance(cls2, dejavu.UnitJoin):
1261                 wclass2 = wrap(cls2)
1262             else:
1263                 wclass2 = UnitClassWrapper(cls2, self)
1264                 if cls2 in seen:
1265                     aliascount[0] += 1
1266                     wclass2.alias = "t%d" % aliascount[0]
1267                 else:
1268                     seen[cls2] = None
1269             return dejavu.UnitJoin(wclass1, wclass2, unitjoin.leftbiased)
1270         classes = wrap(classes)
1271        
1272         joins = self.join(classes)
1273        
1274         if expr is None:
1275             expr = logic.Expression(lambda *args: True)
1276         w, imp = self.where(list(classes), expr)
1277        
1278         cols = []
1279         colnames = []
1280         for wrapper in classes:
1281             c, names = wrapper.columns()
1282             cols.extend(c)
1283             colnames.extend(names)
1284        
1285         statement = ("SELECT %s FROM %s WHERE %s" %
1286                      (u', '.join(colnames), joins, w))
1287         return statement, imp, cols
1288    
1289     def multirecall(self, classes, expr):
1290         """multirecall(classes, expr) -> Full inner join units."""
1291         sql, imp, supplied_cols = self.multiselect(classes, expr)
1292         data, recvd_cols = self.fetch(sql)
1293         if data:
1294             # Get specs on properties.
1295             props = []
1296             for sup, rec in zip(supplied_cols, recvd_cols):
1297                 c, key = sup
1298                 name, ftype = rec[0], rec[1]
1299                 subtype = self.expanded_columns.get((c.__name__, key))
1300                 props.append((c, key, ftype, subtype))
1301            
1302             consume = self.fromAdapter.consume
1303             for row in data:
1304                 index = 0
1305                 units = {}
1306                 for c, key, ftype, subtype in props:
1307                     if c in units:
1308                         unit = units[c]
1309                     else:
1310                         units[c] = unit = c()
1311                     value = row[index]
1312                     if subtype:
1313                         self.load_expanded(unit, key, subtype)
1314                     else:
1315                         consume(unit, key, value, ftype)
1316                     index += 1
1317                
1318                 unitset = []
1319                 for cls in classes:
1320                     unit = units[cls]
1321                     unit.cleanse()
1322                     unitset.append(unit)
1323                
1324                 # If our SQL is imperfect, don't yield units to the
1325                 # caller unless they pass expr(unit).
1326                 acceptable = True
1327                 if imp:
1328                     acceptable = expr(*unitset)
1329                 if acceptable:
1330                     yield unitset
1331    
1332     #                               Schemas                               #
1333    
1334     def create_database(self):
1335         self.execute("CREATE DATABASE %s;" % self.sql_name(self.dbname))
1336    
1337     def drop_database(self):
1338         self.execute("DROP DATABASE %s;" % self.sql_name(self.dbname))
1339    
1340     def create_storage(self, cls):
1341         """Create storage for the given class."""
1342         clsname = cls.__name__
1343         tablename = self.table_name(clsname)
1344         typename = self.typeAdapter.coerce
1345        
1346         fields = []
1347         for key in cls.properties():
1348             fields.append(u'%s %s' % (self.column_name(clsname, key),
1349                                       typename(cls, key)))
1350         self.execute(u'CREATE TABLE %s (%s);' % (tablename, ", ".join(fields)))
1351        
1352         for index in cls.indices():
1353             i = self.table_name("i" + clsname + index)
1354             self.execute(u'CREATE INDEX %s ON %s (%s);' %
1355                          (i, tablename, self.column_name(clsname, index)))
1356    
1357     def has_storage(self, cls):
1358         try:
1359             # Must use fetch here instead of execute, because e.g. MySQL
1360             # must call store_result if the query has a result set
1361             # (or it will crash on a subsequent execute).
1362             self.fetch("SELECT * FROM %s;" % self.table_name(cls.__name__))
1363         except:
1364             return False
1365         return True
1366    
1367     def drop_storage(self, cls):
1368         self.execute(u'DROP TABLE %s;' % self.table_name(cls.__name__))
1369    
1370     def add_property(self, cls, name):
1371         if not self.has_property(cls, name):
1372             clsname = cls.__name__
1373             self.execute("ALTER TABLE %s ADD COLUMN %s %s;" %
1374                          (self.table_name(clsname),
1375                           self.column_name(clsname, name),
1376                           self.typeAdapter.coerce(cls, name),
1377                           ))
1378    
1379     def has_property(self, cls, name):
1380         clsname = cls.__name__
1381         try:
1382             # Must use fetch here instead of execute, because e.g. MySQL
1383             # must call store_result if the query has a result set
1384             # (or it will crash on a subsequent execute).
1385             self.fetch("SELECT %s FROM %s;" %
1386                        (self.column_name(clsname, name),
1387                         self.table_name(clsname)))
1388         except:
1389             return False
1390         return True
1391    
1392     def drop_property(self, cls, name):
1393         if self.has_property(cls, name):
1394             clsname = cls.__name__
1395             self.execute("ALTER TABLE %s DROP COLUMN %s;" %
1396                          (self.table_name(clsname),
1397                           self.column_name(clsname, name)))
1398    
1399     def rename_property(self, cls, oldname, newname):
1400         clsname = cls.__name__
1401         oldname = self.column_name(clsname, oldname)
1402         newname = self.column_name(clsname, newname)
1403         if oldname != newname:
1404             self.execute("ALTER TABLE %s RENAME COLUMN %s TO %s;" %
1405                          (self.table_name(clsname), oldname, newname))
1406
1407
class OutOfConnectionsError(dejavu.DejavuError):
    """Raised when a database store has run out of connections.

    connection() retries _get_conn() a few times when it raises this,
    and raises it to the caller once those retries are used up.
    """
1411
Note: See TracBrowser for help on using the browser.