Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/db.py

Revision 193 (checked in by fumanchu, 7 years ago)

Fix for #51 (remove expanded columns). If anyone objects, this can be reinstated with very little work.

  • Property svn:eol-style set to native
Line 
1 """Base classes and tools for writing database Storage Managers.
2
3 DATA TYPES
4 ==========
5 Database Storage Manager modules are mostly adapters to support round-trip
6 data coercion:
7
8 Unit type -> [SQL repr ->] DB -> incoming Python value -> Unit type
9
10 Since Dejavu relies on external database servers for its persistence,
11 Python datatypes must be converted to column types in the DB. When writing
12 a StorageManager, you should make sure that your type conversions can handle
13 at least the following limitations: If possible, implement the type with no
14 limits. Also, follow UnitProperty.hints['bytes'] where possible. A value
15 of zero for hints['bytes'] implies no limit. If no value is given, try to
16 assume no limit, although you may choose whatever default size you wish
17 (255 is common for strings).
18
19 ENCODING ISSUES
20 ===============
21 All SQL sent to the database must be strings, not unicode. You can set the
22 encoding of the Adapters (I may add a more centralized encoding context in
23 the future). We must use encoded strings so that we can mix encodings
24 within the same string; for example, we might have a DB which understands
25 utf8, but a pickle value which will be encoded in raw-unicode-escape inline
26 with that. All values, therefore, must be coerced before we try to join
27 them into an SQL statement string.
28
29 """
30
31 import datetime
32
33 try:
34     # Builtin in Python 2.5?
35     decimal
36 except NameError:
37     try:
38         # Module in Python 2.3, 2.4
39         import decimal
40     except ImportError:
41         pass
42
43 try:
44     import fixedpoint
45 except ImportError:
46     pass
47
48 try:
49     import cPickle as pickle
50 except ImportError:
51     import pickle
52
53 import Queue
54 import threading
55 import time
56 from types import FunctionType
57 import warnings
58 import weakref
59
60
61 import dejavu
62 from dejavu import codewalk, logic, storage, LOGCONN, LOGSQL, xray
63
64
def getCoerceName(valuetype):
    """Build the name of the adapter method that coerces *valuetype*.

    Builtin types map to "coerce_<typename>"; every other type maps to
    "coerce_<module>_<typename>". Dots from dotted module paths are turned
    into underscores so the result is usable as an attribute name.
    """
    module = valuetype.__module__
    typename = valuetype.__name__
    if module == "__builtin__":
        qualified = typename
    else:
        qualified = "%s_%s" % (module, typename)
    return ("coerce_%s" % qualified).replace(".", "_")
74
def getCoerceMethod(obj, value, valuetype=None):
    """Return the coercion method of *obj* for a given value [or type].

    The coerce_* name for valuetype itself is tried first (see
    getCoerceName), then the name for each ancestor in method resolution
    order. A subclass of a subclass of a handled type therefore resolves to
    its nearest handled ancestor instead of failing.

    obj: an adapter exposing coerce_* methods.
    value: the value to coerce; only used to infer valuetype when that
        argument is None.
    valuetype: optional explicit type to dispatch on.

    Raises TypeError if neither valuetype nor any ancestor has a coerce_*
    method on obj.
    """
    # inspect.getmro works for new-style and old-style (classic) classes.
    import inspect

    if valuetype is None:
        valuetype = type(value)

    tried = []
    for candidate in inspect.getmro(valuetype):
        meth = getCoerceName(candidate)
        tried.append(meth)
        if hasattr(obj, meth):
            return getattr(obj, meth)

    raise TypeError("'%s' is not handled by %s.  Looked for: %s " %
                    (valuetype, obj.__class__, ",".join(tried)))
93
94
class FieldTypeAdapter(object):
    """For a Python type, return the SQL column type name.
    
    This base class is designed to work out-of-the-box with PostgreSQL 8.
    Each coerce_<type> method takes a Unit class and a property key. It
    reads that property's ``hints`` dict ('bytes', 'precision', 'scale')
    and returns a column type string.
    """
    
    # 1000 is the max precision for NUMERIC columns for PostgreSQL 8.
    # Override in subclasses.
    numeric_max_precision = 1000
    
    def coerce(self, cls, key):
        """coerce(cls, key) -> SQL typename for valuetype.

        Dispatches on cls.property_type(key) through getCoerceMethod, the
        same lookup the value adapters use. A property whose type subclasses
        a handled type therefore gets its parent type's column mapping.

        Raises TypeError if neither the type nor any ancestor is handled.
        """
        valuetype = cls.property_type(key)
        xform = getCoerceMethod(self, None, valuetype)
        return xform(cls, key)
    
    def _clamp_precision(self, label, precision, cls, key):
        """Cap *precision* at numeric_max_precision, warning if it was larger.

        label starts the warning text, e.g. "Decimal precision".
        """
        if precision > self.numeric_max_precision:
            warnings.warn("%s %s > maximum %s for %s.%s, "
                          "using %s. Values may be stored incorrectly."
                          % (label, precision, self.numeric_max_precision,
                             cls.__name__, key, self.__class__.__name__),
                          dejavu.StorageWarning)
            precision = self.numeric_max_precision
        return precision
    
    def coerce_float(self, cls, key):
        nbytes = int(getattr(cls, key).hints.get('bytes', '8'))
        # bytes == 0 means "no limit" (module docstring), so it must not
        # select the narrow type.
        if nbytes and nbytes < 5:
            # In MySQL, REAL is still probably 8 bytes. Meh.
            return "REAL"
        # Python floats are implemented using C doubles;
        # actual precision depends on platform. PostgreSQL
        # DOUBLE is 8 bytes (15 decimal-digit precision).
        return "DOUBLE PRECISION"
    
    def coerce_str(self, cls, key):
        # The bytes hint shall not reflect the usual 4-byte base for varchar.
        nbytes = int(getattr(cls, key).hints.get('bytes', '0'))
        if nbytes:
            return "VARCHAR(%s)" % nbytes
        # TEXT is not an SQL standard, but it's common.
        return "TEXT"
    
    # Containers are stored pickled (see AdapterToSQL.do_pickle), and
    # unicode is stored encoded, so all of them use string columns.
    def coerce_dict(self, cls, key):
        return self.coerce_str(cls, key)
    def coerce_list(self, cls, key):
        return self.coerce_str(cls, key)
    def coerce_tuple(self, cls, key):
        return self.coerce_str(cls, key)
    def coerce_unicode(self, cls, key):
        return self.coerce_str(cls, key)
    
    def coerce_bool(self, cls, key): return "BOOLEAN"
    
    def coerce_datetime_datetime(self, cls, key): return "TIMESTAMP"
    def coerce_datetime_date(self, cls, key): return "DATE"
    def coerce_datetime_time(self, cls, key): return "TIME"
    
    # I was seriously disinterested in writing a parser for interval.
    # timedeltas are stored as a float number of days instead.
    def coerce_datetime_timedelta(self, cls, key):
        return self.coerce_float(cls, key)
    
    def coerce_decimal_Decimal(self, cls, key):
        hints = getattr(cls, key).hints
        precision = int(hints.get('precision', '0'))
        if precision == 0:
            precision = decimal.getcontext().prec
        precision = self._clamp_precision("Decimal precision", precision,
                                          cls, key)
        # Assume most people use decimal for money; default scale = 2.
        scale = int(hints.get('scale', 2))
        return "NUMERIC(%s, %s)" % (precision, scale)
    
    def coerce_decimal(self, cls, key):
        # If decimal ever becomes a builtin. Python 2.5?
        return self.coerce_decimal_Decimal(cls, key)
    
    def coerce_fixedpoint_FixedPoint(self, cls, key):
        hints = getattr(cls, key).hints
        precision = int(hints.get('precision', '0'))
        if precision == 0:
            precision = self.numeric_max_precision
        # Same cap as Decimal: a hint above the DB maximum would otherwise
        # produce a NUMERIC(p, s) the database rejects.
        precision = self._clamp_precision("FixedPoint precision", precision,
                                          cls, key)
        # Assume most people use decimal for money; default scale = 2.
        scale = int(hints.get('scale', 2))
        return "NUMERIC(%s, %s)" % (precision, scale)
    
    def coerce_long(self, cls, key):
        hint = getattr(cls, key).hints.get('bytes', None)
        if hint is None:
            # No hint: keep the historical default and delegate to
            # coerce_int (INTEGER unless the property says otherwise).
            return self.coerce_int(cls, key)
        nbytes = int(hint)
        if nbytes == 0:
            # Zero means "no limit" (see module docstring): widest NUMERIC.
            return "NUMERIC(%s, 0)" % self.numeric_max_precision
        if nbytes <= 4:
            return self.coerce_int(cls, key)
        if nbytes <= 8:
            # BIGINT is usually 8 bytes
            return "BIGINT"
        # Anything larger than 8 bytes, use decimal/numeric. NUMERIC
        # precision counts decimal digits, not bytes: an nbytes-byte integer
        # needs as many digits as 2 ** (8 * nbytes) has.
        digits = len(str(2 ** (8 * nbytes)))
        digits = self._clamp_precision("Long precision", digits, cls, key)
        return "NUMERIC(%s, 0)" % digits
    
    def coerce_int(self, cls, key):
        nbytes = int(getattr(cls, key).hints.get('bytes', '4'))
        if nbytes in (1, 2):
            # PostgreSQL has no 1-byte integer type, so SMALLINT is the
            # narrowest fit. (BOOLEAN would reject every value except 0/1.)
            return "SMALLINT"
        # MEDIUMINT (3 bytes) is MySQL-only, so 3-byte hints get INTEGER.
        return "INTEGER"
220
221
class AdapterToSQL(object):
    """Coerce Expression constants to SQL.
    
    This base class is designed to work out-of-the-box with PostgreSQL 8.
    Each coerce_<type> method takes a Python value and returns the SQL
    literal for it as a byte string (unicode is encoded with self.encoding).
    """
    
    # You should REALLY check into your DB's encoding and override this.
    encoding = 'utf8'
    
    # Notice these are ordered pairs. Escape \ before introducing new ones.
    # Values in these two lists should be strings encoded with self.encoding.
    escapes = [("'", "''"), ("\\", r"\\")]
    like_escapes = [("%", r"\%"), ("_", r"\_")]
    
    # These are not the same as coerce_bool (which is used on one side of
    # a comparison). Instead, these are used when the whole (sub)expression
    # is True or False, e.g. "WHERE TRUE", or "WHERE TRUE and 'a'.'b' = 3".
    bool_true = "TRUE"
    bool_false = "FALSE"
    
    def escape_like(self, value):
        """Prepare a string value for use in a LIKE comparison.

        value is normally the output of coerce_str: an SQL literal wrapped
        in single quotes, with embedded quotes already doubled. Exactly one
        enclosing pair of quote marks is removed. Quote characters that are
        part of the value itself are kept, and LIKE wildcards are escaped.
        """
        if not isinstance(value, str):
            value = value.encode(self.encoding)
        # Remove one enclosing pair only. str.strip would also eat quotes
        # that belong to the value (e.g. the doubled '' of a trailing
        # apostrophe), silently changing what the pattern matches.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
            value = value[1:-1]
        for pat, repl in self.like_escapes:
            value = value.replace(pat, repl)
        return value
    
    def coerce(self, value, valuetype=None):
        """coerce(value, valuetype=None) -> value, coerced by valuetype."""
        meth = getCoerceMethod(self, value, valuetype)
        return meth(value)
    
    def tostr(self, value):
        """Return value as a byte string; unicode is encoded first."""
        if isinstance(value, unicode):
            return value.encode(self.encoding)
        # Byte strings come back unchanged. Calling .encode on them (as
        # before) makes Python 2 decode them as ASCII first, which fails on
        # any non-ASCII byte.
        return str(value)
    
    def coerce_NoneType(self, value):
        return "NULL"
    
    def coerce_bool(self, value):
        if value:
            return 'TRUE'
        return 'FALSE'
    
    # The great thing about these 3 date coercers is that you can use
    # them with (VAR)CHAR columns just as well as with DATETIME, etc.
    # and comparisons will still work! (Microseconds are not emitted.)
    def coerce_datetime_datetime(self, value):
        return ("'%04d-%02d-%02d %02d:%02d:%02d'" %
                (value.year, value.month, value.day,
                 value.hour, value.minute, value.second))
    
    def coerce_datetime_date(self, value):
        return "'%04d-%02d-%02d'" % (value.year, value.month, value.day)
    
    def coerce_datetime_time(self, value):
        return "'%02d:%02d:%02d'" % (value.hour, value.minute, value.second)
    
    def coerce_datetime_timedelta(self, value):
        # Stored as a float number of days (microseconds are dropped).
        float_val = value.days + (value.seconds / 86400.0)
        return repr(float_val)
    
    coerce_decimal = tostr
    coerce_decimal_Decimal = tostr
    
    def do_pickle(self, value):
        # Note: dumps with protocol 0 uses the 'raw-unicode-escape'
        # encoding, and we take pains not to re-encode it with
        # self.encoding.
        value = pickle.dumps(value)
        value = self.coerce_str(value, skip_encoding=True)
        return value
    
    coerce_dict = do_pickle
    
    coerce_fixedpoint_FixedPoint = tostr
    
    def coerce_float(self, value):
        # repr, not str: Python 2's str(float) rounds to 12 significant
        # digits, which would silently lose precision on the way to the DB.
        return repr(value)
    
    coerce_int = tostr
    
    coerce_list = do_pickle
    
    coerce_long = tostr
    
    def coerce_str(self, value, skip_encoding=False):
        """Return value as a single-quoted, escaped SQL string literal."""
        if not skip_encoding and not isinstance(value, str):
            value = value.encode(self.encoding)
        for pat, repl in self.escapes:
            value = value.replace(pat, repl)
        return "'" + value + "'"
    
    coerce_tuple = do_pickle
    
    coerce_unicode = coerce_str
320
321
class AdapterFromDB(object):
    """Coerce incoming values from DB types to Dejavu datatypes.
    
    You might notice that we pass coltype around a lot, but don't
    refer to it in this base class. It's there so subclasses can
    make decisions about coercion when they don't have control over
    the types of database columns, and must make do with legacy
    database implementations.
    
    This base class is designed to work out-of-the-box with PostgreSQL 8.
    """
    
    # You should REALLY check into your DB's encoding and override this.
    encoding = 'utf8'
    
    def coerce(self, value, coltype, valuetype=None):
        """coerce(value, coltype, valuetype=None) -> value, coerced by valuetype.

        None (SQL NULL) is returned untouched without dispatching.
        """
        if value is None:
            return None
        
        meth = getCoerceMethod(self, value, valuetype)
        return meth(value, coltype)
    
    def consume(self, unit, key, value, coltype):
        """Coerce value for unit's property *key* and store it in
        unit._properties.

        On failure, key, value and coltype are appended to the exception
        (to .reason for UnicodeDecodeError, to .args otherwise) and the
        exception is re-raised.
        """
        try:
            expectedType = unit.__class__.property_type(key)
            value = self.coerce(value, coltype, expectedType)
            unit._properties[key] = value
        except UnicodeDecodeError:
            # sys.exc_info() instead of "except E, x" keeps this valid on
            # every Python 2 release as well as Python 3.
            import sys
            x = sys.exc_info()[1]
            x.reason += "[%s][%s][%s]" % (key, value, coltype)
            raise
        except Exception:
            import sys
            x = sys.exc_info()[1]
            x.args += (key, value, coltype)
            raise
    
    def do_pickle(self, value, coltype):
        """Unpickle a container written by AdapterToSQL.do_pickle."""
        # NOTE(review): pickle.loads can execute code embedded in the data;
        # this is only safe while untrusted parties cannot write these
        # columns.
        # Coerce to str for pickle.loads restriction.
        if isinstance(value, unicode):
            value = value.encode(self.encoding)
        value = str(value)
        return pickle.loads(value)
    
    def coerce_bool(self, value, coltype):
        return bool(value)
    
    def coerce_datetime_datetime(self, value, coltype):
        """Parse 'YYYY-MM-DD HH:MM:SS' (anything after the seconds is
        ignored). datetime objects from the driver are passed through."""
        if isinstance(value, datetime.datetime):
            return value
        chunks = (value[0:4], value[5:7], value[8:10],
                  value[11:13], value[14:16], value[17:19])
        return datetime.datetime(*map(int, chunks))
    
    def coerce_datetime_date(self, value, coltype):
        """Parse 'YYYY-MM-DD'; date/datetime objects are accepted as-is."""
        if isinstance(value, datetime.datetime):
            return value.date()
        if isinstance(value, datetime.date):
            return value
        chunks = (value[0:4], value[5:7], value[8:10])
        return datetime.date(*map(int, chunks))
    
    def coerce_datetime_time(self, value, coltype):
        """Parse 'HH:MM:SS'; time objects are passed through."""
        if isinstance(value, datetime.time):
            return value
        chunks = (value[0:2], value[3:5], value[6:8])
        return datetime.time(*map(int, chunks))
    
    def coerce_datetime_timedelta(self, value, coltype):
        """Convert a number of days (as written by AdapterToSQL) to a
        timedelta, rounding the fractional day to the nearest second.

        Rounding rather than truncating matters because float noise such
        as 2.9999999999 seconds must come back as 3, not 2.
        """
        days, fraction = divmod(value, 1)
        return datetime.timedelta(days, int(round(fraction * 86400)))
    
    def coerce_decimal(self, value, coltype):
        # Only reachable if decimal ever becomes a builtin type. The name
        # "decimal" here is the module, which is not callable, so delegate.
        return self.coerce_decimal_Decimal(value, coltype)
    
    def coerce_decimal_Decimal(self, value, coltype):
        return decimal.Decimal(str(value))
    
    coerce_dict = do_pickle
    
    def coerce_fixedpoint_FixedPoint(self, value, coltype):
        return fixedpoint.FixedPoint(value)
    
    def coerce_float(self, value, coltype):
        return float(value)
    
    def coerce_int(self, value, coltype):
        return int(value)
    
    coerce_list = do_pickle
    
    def coerce_long(self, value, coltype):
        return long(value)
    
    def coerce_str(self, value, coltype):
        """Return value as a byte string encoded with self.encoding."""
        if isinstance(value, unicode):
            return value.encode(self.encoding)
        # Byte strings are returned as-is. On Python 2, str.encode would
        # first decode them as ASCII and fail on any non-ASCII byte.
        return str(value)
    
    coerce_tuple = do_pickle
    
    def coerce_unicode(self, value, coltype):
        if isinstance(value, unicode):
            return value
        else:
            return unicode(value, self.encoding)
419
420
421 # -------------------------- SQL DECOMPILATION -------------------------- #
422
class ConstWrapper(str):
    """Wraps a constant for use in SQLDecompiler's stack.
    
    When we hit LOAD_CONST while decompiling, we occasionally need to keep
    both the base and the coerced value around (see COMPARE_OP for use
    of ConstWrapper.basevalue).

    The instance behaves as the coerced SQL string. The original Python
    value is kept in .basevalue.
    """
    def __new__(cls, basevalue, coerced_value):
        # Build from cls rather than a hard-coded ConstWrapper so that
        # subclasses construct instances of themselves.
        newobj = str.__new__(cls, coerced_value)
        newobj.basevalue = basevalue
        return newobj
434
435
class TableRef:
    """Wraps a table reference for use in SQLDecompiler's stack.
    
    When we hit LOAD_FAST while decompiling, that should always be translated
    into a table reference in the SQL.

    classname: name of the Unit class whose table is referenced.
    """
    def __init__(self, classname):
        self.classname = classname

    def __repr__(self):
        # Readable in decompiler debug output, in the same style as Sentinel.
        return 'TableRef: %s' % (self.classname,)
444
# Stack sentinels
class Sentinel(object):
    """A named marker object pushed onto the SQLDecompiler stack.

    Sentinels are tested by identity (``is``); the name only feeds repr().
    """
    
    def __init__(self, name):
        self.name = name
    
    def __repr__(self):
        return 'Stack Sentinel: ' + ('%s' % self.name)

# Pushed by SQLDecompiler.visit_LOAD_FAST for the non-positional (**kw)
# argument of the lambda; visit_BINARY_SUBSCR then resolves kw[name].
kw_arg = Sentinel('Keyword Arg')
# cannot_represent exists so that a portion of an Expression can be
# labeled imperfect. For example, the function dejavu.iscurrentweek
# rarely has an SQL equivalent. All Units (which match the rest of the
# Expression) will be recalled; they can then be compared in expr(unit).
cannot_represent = Sentinel('Cannot Repr')
460
461
class SQLDecompiler(codewalk.LambdaDecompiler):
    """SQLDecompiler(classnames, expr, sm, adapter=AdapterToSQL()).
    
    Produce SQL from a supplied Expression object, with a lambda of the form:
        lambda x, **kw: ...
    
    Attributes of each argument in the signature will be mapped to table
    columns. Keyword arguments should be bound using Expression.bind_args
    before calling this decompiler.
    
    Calls with no SQL translation are pushed as the cannot_represent
    sentinel (visit_CALL_FUNCTION also sets self.imperfect). Wherever such
    a term ends up in a combined clause, it is rendered as the adapter's
    bool_true.
    """
    
    # Some constants are function or class objects,
    # which should not be coerced.
    no_coerce = (FunctionType,
                 type,
                 type(len),       # <type 'builtin_function_or_method'>
                 )
    
    # Indexed by the COMPARE_OP oparg. 6 ('in') and 7 ('not in') are
    # special-cased in visit_COMPARE_OP and never looked up here.
    sql_cmp_op = ('<', '<=', '=', '!=', '>', '>=', 'in', 'not in')
    
    def __init__(self, classnames, expr, sm, adapter=AdapterToSQL()):
        """classnames: Unit class names, one per positional lambda argument.
        expr: the Expression whose .func (the lambda) is decompiled.
        sm: the storage manager, used for column_name() lookups.
        adapter: converts constants to SQL literals. The default instance
            is created once and shared by all decompilers.
        """
        self.classnames = classnames
        self.expr = expr
        self.adapter = adapter
        self.sm = sm
        obj = expr.func
        codewalk.LambdaDecompiler.__init__(self, obj)
    
    def code(self):
        """Walk the lambda's bytecode and return its SQL representation."""
        self.imperfect = False
        self.walk()
        # After walk(), self.stack should be reduced to a single string,
        # which is the SQL representation of our Expression.
        result = self.stack[0]
        if result is cannot_represent:
            # The entire expression could not be evaluated.
            result = self.adapter.bool_true
        if result == self.adapter.coerce_bool(True):
            result = self.adapter.bool_true
        if result == self.adapter.coerce_bool(False):
            result = self.adapter.bool_false
        return result
    
    def visit_instruction(self, op, lo=None, hi=None):
        """Before each instruction, fold pending and/or terms into TOS.

        self.targets maps an instruction pointer to the (term, operator)
        pairs recorded by earlier JUMP_* instructions that target it.
        """
        # Get the instruction pointer for the current instruction.
        ip = self.cursor - 3
        if hi is None:
            ip += 1
            if lo is None:
                ip += 1
        
        terms = self.targets.get(ip)
        if terms:
            trueval = self.adapter.bool_true
            falseval = self.adapter.bool_false
            clause = self.stack[-1]
            while terms:
                term, oper = terms.pop()
                if term is cannot_represent:
                    term = trueval
                if clause is cannot_represent:
                    clause = trueval
                
                # Blurg. SQL Server is *so* picky.
                if term == self.adapter.coerce_bool(True):
                    term = trueval
                elif term == self.adapter.coerce_bool(False):
                    term = falseval
                if clause == self.adapter.coerce_bool(True):
                    clause = trueval
                elif clause == self.adapter.coerce_bool(False):
                    clause = falseval
                
                clause = "(%s) %s (%s)" % (term, oper.upper(), clause)
            
            # Replace TOS with the new clause, so that further
            # combinations have access to it.
            self.stack[-1] = clause
            self.debug("clause:", clause, "\n")
            
            if op == 1:
                # Py2.4: The current instruction is POP_TOP, which means
                # the previous is probably JUMP_*. If so, we're going to
                # pop the value we just placed on the stack and lose it.
                # We need to replace the entry that the JUMP_* made in
                # self.targets with our new TOS.
                target = self.targets[self.last_target_ip]
                target[-1] = ((clause, target[-1][1]))
                self.debug("newtarget:", self.last_target_ip, target)
    
    def visit_LOAD_DEREF(self, lo, hi):
        raise ValueError("Illegal reference found in %s." % self.expr)
    
    def visit_LOAD_GLOBAL(self, lo, hi):
        raise ValueError("Illegal global found in %s." % self.expr)
    
    def visit_LOAD_FAST(self, lo, hi):
        # Positional args become table references; anything beyond them
        # is the **kw dict.
        arg_index = lo + (hi << 8)
        if arg_index < self.co_argcount:
            self.stack.append(TableRef(self.classnames[arg_index]))
        else:
            self.stack.append(kw_arg)
    
    def visit_LOAD_ATTR(self, lo, hi):
        name = self.co_names[lo + (hi << 8)]
        tos = self.stack.pop()
        if isinstance(tos, TableRef):
            atom = self.sm.column_name(tos.classname, name, full=True)
        else:
            # tos.name will reference an attribute of the tos object.
            # Stick the tos and name in a tuple for later processing.
            atom = (tos, name)
        self.stack.append(atom)
    
    def visit_LOAD_CONST(self, lo, hi):
        val = self.co_consts[lo + (hi << 8)]
        if not isinstance(val, self.no_coerce):
            val = ConstWrapper(val, self.adapter.coerce(val))
        self.stack.append(val)
    
    def visit_BUILD_TUPLE(self, lo, hi):
        # The item count is a 16-bit oparg split into lo/hi bytes. Note the
        # parentheses: "lo + hi << 8" would parse as "(lo + hi) << 8" and
        # try to pop 256 times too many items.
        count = lo + (hi << 8)
        items = [self.stack.pop() for i in range(count)]
        # Popping yields the items last-first; restore source order.
        items.reverse()
        self.stack.append("(" + ", ".join(items) + ")")
    
    visit_BUILD_LIST = visit_BUILD_TUPLE
    
    def visit_CALL_FUNCTION(self, lo, hi):
        # lo = number of positional args, hi = number of keyword args
        # (each keyword arg occupies two stack slots: name, value).
        kwargs = {}
        for i in xrange(hi):
            val = self.stack.pop()
            key = self.stack.pop()
            kwargs[key] = val
        kwargs = [k + "=" + v for k, v in kwargs.iteritems()]
        
        args = []
        for i in xrange(lo):
            arg = self.stack.pop()
            args.append(arg)
        args.reverse()
        
        if kwargs:
            args += kwargs
        
        func = self.stack.pop()
        
        # Handle function objects.
        if isinstance(func, tuple):
            # Method call on a loaded attribute: dispatch to attr_<name>.
            tos, name = func
            dispatch = getattr(self, "attr_" + name, None)
            if dispatch:
                self.stack.append(dispatch(tos, *args))
                return
        else:
            # Plain function: dispatch to <module>_<name> (prefixed with
            # "func" when that would start with an underscore).
            funcname = func.__module__ + "_" + func.__name__
            funcname = funcname.replace(".", "_")
            if funcname.startswith("_"):
                funcname = "func" + funcname
            dispatch = getattr(self, funcname, None)
            if dispatch:
                self.stack.append(dispatch(*args))
                return
        
        self.stack.append(cannot_represent)
        self.imperfect = True
    
    def visit_COMPARE_OP(self, lo, hi):
        op2, op1 = self.stack.pop(), self.stack.pop()
        if op1 is cannot_represent or op2 is cannot_represent:
            self.stack.append(cannot_represent)
            return
        
        op = lo + (hi << 8)
        if op in (6, 7):     # in, not in
            value = self.containedby(op1, op2)
            if op == 7:
                value = "NOT " + value
            self.stack.append(value)
        elif op1 == 'NULL':
            if op == 2:
                self.stack.append(op2 + " IS NULL")
            elif op == 3:
                self.stack.append(op2 + " IS NOT NULL")
            else:
                raise ValueError("Non-equality Null comparisons not allowed.")
        elif op2 == 'NULL':
            if op == 2:
                self.stack.append(op1 + " IS NULL")
            elif op == 3:
                self.stack.append(op1 + " IS NOT NULL")
            else:
                raise ValueError("Non-equality Null comparisons not allowed.")
        else:
            # Comparison operators for strings are case-sensitive in PG et al.
            self.stack.append(op1 + " " + self.sql_cmp_op[op] + " " + op2)
    
    def binary_op(self, op):
        op2, op1 = self.stack.pop(), self.stack.pop()
        self.stack.append(op1 + " " + op + " " + op2)
    
    def visit_BINARY_SUBSCR(self):
        # The only BINARY_SUBSCR used in Expressions should be kwargs[key].
        name = self.stack.pop()
        tos = self.stack.pop()
        if tos is not kw_arg:
            raise ValueError("Subscript %s of %s object not allowed."
                             % (name, tos))
        # name, since formed in LOAD_CONST, may have extraneous quotes.
        name = name.strip("'\"")
        value = self.expr.kwargs[name]
        if not isinstance(value, self.no_coerce):
            value = ConstWrapper(value, self.adapter.coerce(value))
        self.stack.append(value)
    
    def visit_UNARY_NOT(self):
        op = self.stack.pop()
        if op is cannot_represent:
            self.stack.append(cannot_represent)
        else:
            self.stack.append("NOT (" + op + ")")
    
    # --------------------------- Dispatchees --------------------------- #
    
    def attr_startswith(self, tos, arg):
        return tos + " LIKE '" + self.adapter.escape_like(arg) + "%'"
    
    def attr_endswith(self, tos, arg):
        return tos + " LIKE '%" + self.adapter.escape_like(arg) + "'"
    
    def containedby(self, op1, op2):
        if isinstance(op1, ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            return op2 + " LIKE '%" + self.adapter.escape_like(op1) + "%'"
        else:
            # Looking for field in (a, b, c)
            atoms = [self.adapter.coerce(x) for x in op2.basevalue]
            return op1 + " IN (" + ", ".join(atoms) + ")"
    
    def dejavu_icontainedby(self, op1, op2):
        if isinstance(op1, ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            return ("LOWER(" + op2 + ") LIKE '%" +
                    self.adapter.escape_like(op1).lower() + "%'")
        else:
            # Looking for field in (a, b, c).
            # Force all args to lowercase for case-insensitive comparison.
            atoms = [self.adapter.coerce(x).lower() for x in op2.basevalue]
            return "LOWER(%s) IN (%s)" % (op1, ", ".join(atoms))
    
    def dejavu_icontains(self, x, y):
        return self.dejavu_icontainedby(y, x)
    
    def dejavu_istartswith(self, x, y):
        # Lower the pattern too, as dejavu_icontainedby does; otherwise an
        # argument containing capitals could never match LOWER(x).
        return ("LOWER(" + x + ") LIKE '" +
                self.adapter.escape_like(y).lower() + "%'")
    
    def dejavu_iendswith(self, x, y):
        # Lower the pattern too (see dejavu_istartswith).
        return ("LOWER(" + x + ") LIKE '%" +
                self.adapter.escape_like(y).lower() + "'")
    
    def dejavu_ieq(self, x, y):
        return "LOWER(" + x + ") = LOWER(" + y + ")"
    
    def dejavu_now(self):
        return "NOW()"
    
    def dejavu_today(self):
        return "CURRENT_DATE"
    
    def dejavu_year(self, x):
        # NOTE(review): YEAR() is not a PostgreSQL function (PG uses
        # EXTRACT(YEAR FROM ...)); presumably DB-specific subclasses
        # override this -- confirm.
        return "YEAR(" + x + ")"
    
    def func__builtin___len(self, x):
        return "LENGTH(" + x + ")"
733
734
class ConnectionWrapper(object):
    """Connection object wrapper, so it can be used as a weak reference.

    Attributes not found on the wrapper itself are looked up on the
    wrapped connection (self.conn).
    """
    
    def __init__(self, conn=None):
        self.conn = conn
    
    def __getattr__(self, attr):
        # __getattr__ only runs for attributes not found normally. If "conn"
        # itself is missing (e.g. an instance created via __new__ without
        # __init__), forwarding would recurse forever; fail cleanly instead.
        if attr == 'conn':
            raise AttributeError(attr)
        return getattr(self.conn, attr)
743
744
class UnitClassWrapper(object):
    """Pair a Unit class with its storage manager's table naming.

    Used when parsing multiselect joins. It exposes the table name, an
    optional alias, the column list, and association lookup.
    """
    
    def __init__(self, wclass, sm):
        self.cls = wclass
        self.sm = sm
        self.tablename = sm.table_name(wclass.__name__)
        self.alias = ""
    
    def columns(self):
        """Return (cols, colnames) for every property of the wrapped class.

        cols is a list of (unitclass, key) pairs. colnames holds the
        matching "<alias or table>.<column>" strings. Identifier properties
        come first, in case others depend upon them.
        """
        unitcls = self.cls
        ids = unitcls.identifiers
        keys = list(ids)
        for k in unitcls.properties:
            if k not in ids:
                keys.append(k)
        
        prefix = self.alias or self.tablename
        cols = []
        colnames = []
        for k in keys:
            cols.append((unitcls, k))
            colnames.append('%s.%s' % (prefix,
                                       self.sm.column_name(unitcls.__name__, k)))
        return cols, colnames
    
    def _joinname(self):
        if not self.alias:
            return self.tablename
        return "%s AS %s" % (self.tablename, self.alias)
    joinname = property(_joinname,
                        doc="Table name for use in JOIN clause (read-only).")
    
    def association(self, classes):
        """Find the first association between this class and any of *classes*.

        Returns (association, near, far), where near/far are the alias (or
        table name) of the class owning the association and of the other
        class, respectively. Returns None when nothing is associated.
        """
        for other in classes:
            forward = self.cls._associations.get(other.cls.__name__, None)
            if forward:
                return (forward, self.alias or self.tablename,
                        other.alias or other.tablename)
            backward = other.cls._associations.get(self.cls.__name__, None)
            if backward:
                return (backward, other.alias or other.tablename,
                        self.alias or self.tablename)
        return None
790
791
class StorageManagerDB(storage.StorageManager):
    """StoreManager base class to save and retrieve Units using a DB.

    Each Unit class maps to one table (see table_name) and each of its
    properties to one column (see column_name). Subclasses supply the real
    database connection by overriding _get_conn, and presumably adjust the
    naming/quoting rules and the class attributes below for their DB.

    In threaded mode (the default) connections are handed out wrapped in a
    ConnectionWrapper; when a wrapper is garbage-collected, release() returns
    its raw connection to a bounded pool, or closes it if the pool is full
    or disabled.
    """

    # Maximum SQL identifier length; longer names are truncated with a
    # StorageWarning. A false value disables the check.
    sql_name_max_length = 64
    # If true, every SQL identifier is lowercased.
    sql_name_caseless = False
    # Name of the raw connection's method that closes it.
    close_connection_method = 'close'
    # If true, destroy() emits "DELETE * FROM ..." instead of "DELETE FROM ...".
    use_asterisk_to_get_all = False

    decompiler = SQLDecompiler
    typeAdapter = FieldTypeAdapter()
    toAdapter = AdapterToSQL()
    fromAdapter = AdapterFromDB()
    debug_connections = False

    def __init__(self, name, arena, allOptions=None):
        """Configure adapters, connection pooling and naming.

        Recognized keys in allOptions (all optional):
            'Type Adapter', 'To Adapter', 'From Adapter': an adapter object,
                or a string resolved with xray.classes, overriding the
                corresponding class-level adapter.
            'Pool Size': maximum number of pooled connections (default 10).
                A value <= 0 disables pooling, so released connections are
                closed immediately.
            'Prefix': string prepended to every table name (default "djv").
        """
        # Avoid a shared mutable default argument.
        if allOptions is None:
            allOptions = {}
        storage.StorageManager.__init__(self, name, arena, allOptions)

        # Adapter Overrides
        def get_adapter_option(key):
            adapter = allOptions.get(key)
            if isinstance(adapter, basestring):
                adapter = xray.classes(adapter)
            return adapter

        for option, attr in (('Type Adapter', 'typeAdapter'),
                             ('To Adapter', 'toAdapter'),
                             ('From Adapter', 'fromAdapter')):
            adapter = get_adapter_option(option)
            if adapter:
                setattr(self, attr, adapter)

        self.pool_size = int(allOptions.get('Pool Size', '10'))

        # Threaded mode: maps weakref(ConnectionWrapper) -> raw connection for
        # every connection currently handed out. Non-threaded mode: holds the
        # single shared connection under the key 'conn'.
        self.refs = {}
        if self.pool_size > 0:
            self.pool = Queue.Queue(self.pool_size)
        else:
            self.pool = None
        # Maximum number of _get_conn attempts connection() makes while the
        # DB keeps raising OutOfConnectionsError.
        self.retry = 5
        self.threaded = True

        self.prefix = allOptions.get('Prefix', "djv")
        # Serializes reserve() calls (identifier assignment) across threads.
        self.reserve_lock = threading.Lock()

    #                               Naming                               #

    def sql_name(self, name, quoted=True):
        """Return name escaped for use as an SQL identifier.

        Lowercases the name if sql_name_caseless is set, and truncates it to
        sql_name_max_length characters with a StorageWarning. This base class
        ignores 'quoted'; subclasses that quote identifiers use it.
        """
        if self.sql_name_caseless:
            name = name.lower()

        maxlen = self.sql_name_max_length
        if maxlen and len(name) > maxlen:
            warnings.warn("The name '%s' is longer than the maximum of "
                          "%s characters." % (name, maxlen),
                          dejavu.StorageWarning)
            name = name[:maxlen]
        return name

    def column_name(self, classname, name, full=False, quoted=True):
        """The column name, escaped for SQL. If full, include tablename.

        classname is either a Unit class name (string) or a class wrapper
        with 'alias' and 'tablename' attributes. With full=True, a wrapper
        qualifies the column by its alias if it has one, else by its
        tablename. A plain class name is qualified by its table_name().

        To map UnitProperty names to different DB column names, override
        this method.
        """
        name = self.sql_name(name, quoted=quoted)
        if not full:
            return name

        if isinstance(classname, basestring):
            tname = self.table_name(classname, quoted=quoted)
        else:
            # A class wrapper. Its alias may be None/empty when the table
            # appears only once, so fall back to its tablename rather than
            # passing the wrapper object to table_name().
            tname = classname.alias or classname.tablename
        return '%s.%s' % (tname, name)

    def table_name(self, name, quoted=True):
        """Return the table name for Unit class name, escaped for SQL.

        The name is self.prefix + name. To map Unit class names to different
        DB table names, override this method.
        """
        return self.sql_name(self.prefix + name, quoted=quoted)

    #                             Connecting                              #

    def _get_conn(self):
        """Return a new raw DB connection. Subclasses must override this.

        Example:
            return libpq.PQconnectdb(self.connstring)

        connection() retries when this raises OutOfConnectionsError.
        """
        raise NotImplementedError

    def _track_connection(self, conn, logmsg):
        """Wrap conn, register it so release() runs when the wrapper dies, log."""
        w = ConnectionWrapper(conn)
        self.refs[weakref.ref(w, self.release)] = w.conn
        self.arena.log(logmsg % self.__class__.__name__, LOGCONN)
        return w

    def connection(self):
        """Return a connection, reusing a pooled one when available.

        In non-threaded mode one raw connection is created on first use and
        returned on every call. In threaded mode a ConnectionWrapper is
        returned; when it is garbage-collected, release() pools or closes
        the underlying connection.

        Raises OutOfConnectionsError once _get_conn has failed with that
        error self.retry times.
        """
        if not self.threaded:
            # Place a single 'conn' entry in self.refs.
            try:
                return self.refs['conn']
            except KeyError:
                conn = self.refs['conn'] = self._get_conn()
                return conn

        attempts = 0
        while True:
            if self.pool is not None:
                try:
                    conn = self.pool.get_nowait()
                except Queue.Empty:
                    pass
                else:
                    # Okay, this is freaky. If we wrap here, all goes well.
                    # If we wrap on Queue.put(), mysql crashes after 1700
                    # or so inserts (when migrating Access tables to MySQL).
                    # Go figure.
                    return self._track_connection(conn, "-->get %s")

            try:
                conn = self._get_conn()
            except OutOfConnectionsError:
                attempts += 1
                if attempts >= self.retry:
                    raise OutOfConnectionsError()
                # Back off a little longer after each failed attempt.
                time.sleep(attempts)
                continue
            return self._track_connection(conn, "create %s")

    def release(self, ref):
        """Weakref callback: return ref's connection to the pool, or close it.

        Only used in threaded mode. If shutdown() already closed the
        connection, ref is no longer registered and this does nothing.
        Without that check the pop would raise KeyError inside the weakref
        callback.
        """
        conn = self.refs.pop(ref, None)
        if conn is None:
            return

        if self.pool is not None:
            try:
                self.pool.put_nowait(conn)
                self.arena.log("<--put %s" % self.__class__.__name__, LOGCONN)
                return
            except Queue.Full:
                pass

        getattr(conn, self.close_connection_method)()
        self.arena.log("___close___ %s" % self.__class__.__name__, LOGCONN)

    def shutdown(self):
        """Close all database connections, both pooled and handed-out ones."""
        close = self.close_connection_method

        # Empty the pool. Pooled connections are not tracked in self.refs,
        # so they must be closed here or they would leak.
        if self.pool is not None:
            while True:
                try:
                    conn = self.pool.get(True, 0.5)
                except Queue.Empty:
                    break
                getattr(conn, close)()

        # Empty self.refs.
        while self.refs:
            ref, conn = self.refs.popitem()
            getattr(conn, close)()

    def select(self, cls, expr, fields=None, distinct=False):
        """Return an SQL SELECT statement, and an 'imperfect' flag.

        imperfect: True or False depending on whether the generated SQL
            perfectly satisfies the given expression.

        With fields, only those columns are selected (DISTINCT if distinct
        is true); otherwise all columns are selected and distinct is ignored.
        """
        clsname = cls.__name__
        tablename = self.table_name(clsname)
        if fields:
            colnames = ', '.join([self.column_name(clsname, x)
                                  for x in fields])
            if distinct:
                sql = 'SELECT DISTINCT %s FROM %s' % (colnames, tablename)
            else:
                sql = 'SELECT %s FROM %s' % (colnames, tablename)
        else:
            sql = 'SELECT * FROM %s' % tablename

        w, imperfect = self.where((clsname,), expr)
        if w:
            sql += " WHERE " + w
        return sql + ";", imperfect

    def where(self, classnames, expr):
        """Return an SQL WHERE clause, and an 'imperfect' flag.

        imperfect: True or False depending on whether the generated SQL
            perfectly satisfies the given expression.
        """
        decom = self.decompiler(classnames, expr, self, self.toAdapter)
        return decom.code(), decom.imperfect

    def execute(self, query, conn=None):
        """execute(query, conn=None) -> result set.

        Unicode queries are encoded with self.toAdapter.encoding first. If
        conn is None, a connection is taken from connection().
        """
        if conn is None:
            conn = self.connection()
        if isinstance(query, unicode):
            query = query.encode(self.toAdapter.encoding)
        self.arena.log(query, LOGSQL)
        return conn.query(query)

    def fetch(self, query, conn=None):
        """fetch(query, conn=None) -> rowdata, columns.

        rowdata will be an iterable of iterables containing the result values.
        columns will be an iterable of (column name, data type) pairs.

        This base class uses SQLite3 syntax: it reads the result's row_list
        and col_defs attributes.
        """
        res = self.execute(query, conn)
        return res.row_list, res.col_defs

    def recall(self, cls, expr=None):
        """Yield a sequence of Unit instances which satisfy the expression."""
        clsname = cls.__name__

        if expr is None:
            expr = logic.Expression(lambda x: True)
        sql, imperfect = self.select(cls, expr)
        data, col_defs = self.fetch(sql)
        if not data:
            return

        # Map column name -> (position in row, DB type).
        columns = dict([(col[0], (index, col[1]))
                        for index, col in enumerate(col_defs)])

        # Get specs on properties. Put the identifier properties
        # first, in case other fields depend upon them.
        idnames = list(cls.identifiers)
        props = []
        for key in idnames + [x for x in cls.properties if x not in idnames]:
            index, ftype = columns[self.column_name(clsname, key, quoted=False)]
            props.append((key, index, ftype))

        consume = self.fromAdapter.consume
        for row in data:
            unit = cls()
            for key, index, ftype in props:
                consume(unit, key, row[index], ftype)

            # If our SQL is imperfect, don't yield it to the
            # caller unless it passes expr(unit).
            if (not imperfect) or expr(unit):
                unit.cleanse()
                yield unit

    def reserve(self, unit):
        """reserve(unit). -> Reserve a persistent slot for unit.

        If the subclass defines a method named "_seq_<SequencerClassName>",
        that handler lets the DB generate the identifier(s); otherwise
        _manual_reserve is used. Calls are serialized by reserve_lock.
        """
        self.reserve_lock.acquire()
        try:
            seqclass = unit.sequencer.__class__.__name__
            seq_handler = getattr(self, "_seq_%s" % seqclass, None)
            if seq_handler:
                seq_handler(unit)
            else:
                self._manual_reserve(unit)
            unit.cleanse()
        finally:
            self.reserve_lock.release()

    def _manual_reserve(self, unit):
        """Use when the DB cannot automatically generate an identifier.
        The identifiers will be supplied by UnitSequencer.assign().

        If the unit's identity is not yet valid, all existing identifier
        rows are read, coerced to the identifier properties' types, and
        passed to the sequencer. The unit is then INSERTed.
        """
        cls = unit.__class__
        clsname = cls.__name__
        tablename = self.table_name(clsname)
        if not unit.sequencer.valid_id(unit.identity()):
            # Examine all existing IDs and grant the "next" one.
            id_fields = [self.column_name(clsname, key)
                         for key in cls.identifiers]
            data, cols = self.fetch('SELECT %s FROM %s;' %
                                    (', '.join(id_fields), tablename))
            if data:
                # sqlite 2, for example, has empty cols tuple if no data.
                coerce = self.fromAdapter.coerce
                coltypes = [col[1] for col in cols]
                expectedTypes = [getattr(cls, key).type
                                 for key in cls.identifiers]
                data = [[coerce(cell, coltypes[x], expectedTypes[x])
                         for x, cell in enumerate(row)]
                        for row in data]
            cls.sequencer.assign(unit, data)
            del data, cols

        fields = []
        values = []
        for key in cls.properties:
            values.append(self.toAdapter.coerce(getattr(unit, key)))
            fields.append(self.column_name(clsname, key))

        self.execute('INSERT INTO %s (%s) VALUES (%s);' %
                     (str(tablename), ", ".join(fields), ", ".join(values)))

    def id_clause(self, unit):
        """Return an SQL expression for the identifiers of the given Unit."""
        clsname = unit.__class__.__name__
        col = self.column_name
        c = self.toAdapter.coerce
        return " AND ".join(["%s = %s" % (col(clsname, key),
                                          c(getattr(unit, key)))
                             for key in unit.identifiers])

    def save(self, unit, forceSave=False):
        """save(unit, forceSave=False) -> Update storage from unit's data.

        Only non-identifier properties are written, and only if the unit is
        dirty or forceSave is true.
        """
        if not (unit.dirty() or forceSave):
            return
        cls = unit.__class__
        clsname = cls.__name__

        parms = []
        for key in cls.properties:
            if key not in cls.identifiers:
                val = self.toAdapter.coerce(getattr(unit, key))
                parms.append('%s = %s' %
                             (self.column_name(clsname, key), val))

        if parms:
            sql = ('UPDATE %s SET %s WHERE %s;' %
                   (self.table_name(clsname), ", ".join(parms),
                    self.id_clause(unit)))
            self.execute(sql)
        unit.cleanse()

    def destroy(self, unit):
        """destroy(unit). Delete the unit."""
        if self.use_asterisk_to_get_all:
            star = " *"
        else:
            star = ""
        self.execute('DELETE%s FROM %s WHERE %s;' %
                     (star, self.table_name(unit.__class__.__name__),
                      self.id_clause(unit)))

    def view(self, cls, fields, expr=None):
        """view(cls, fields, expr=None) -> All value-tuples for given fields."""
        if expr is None:
            expr = logic.Expression(lambda x: True)

        sql, imperfect = self.select(cls, expr, fields)
        if imperfect:
            # ^%$#@! There's no way to handle imperfect queries without
            # creating all involved Units, which defeats the purpose of
            # view, which was a speed issue more than anything else.
            warnings.warn("The requested view() query for %s Units "
                          "cannot produce perfect SQL with a %s datasource. "
                          "It may take an absurd amount of time to run, "
                          "since each unit must be fully-formed. %s"
                          % (cls.__name__, self.__class__.__name__, expr),
                          dejavu.StorageWarning)
            for unit in self.recall(cls, expr):
                # Use tuples for hashability
                yield tuple([getattr(unit, f) for f in fields])
        else:
            data, columns = self.fetch(sql)
            actualTypes = [x[1] for x in columns]
            expectedTypes = [cls.property_type(x) for x in fields]

            coerce = self.fromAdapter.coerce
            # Use tuples for hashability
            for row in data:
                yield tuple([coerce(val, actualTypes[i], expectedTypes[i])
                             for i, val in enumerate(row)])

    def distinct(self, cls, fields, expr=None):
        """distinct(cls, fields, expr=None) -> Distinct values for given fields."""
        if expr is None:
            expr = logic.Expression(lambda x: True)

        sql, imperfect = self.select(cls, expr, fields, distinct=True)
        if imperfect:
            # ^%$#@! There's no way to handle imperfect queries without
            # creating all involved Units, which defeats the purpose of
            # distinct, which was a speed issue more than anything.
            warnings.warn("The requested distinct() query for %s Units "
                          "cannot produce perfect SQL with a %s datasource. "
                          "It may take an absurd amount of time to run, "
                          "since each unit must be fully-formed. %s"
                          % (cls.__name__, self.__class__.__name__, expr),
                          dejavu.StorageWarning)
            vals = {}
            for unit in self.recall(cls, expr):
                # Must use tuples for hashability
                vals[tuple([getattr(unit, f) for f in fields])] = None
            return vals.keys()

        data, columns = self.fetch(sql)
        actualTypes = [x[1] for x in columns]
        expectedTypes = [cls.property_type(x) for x in fields]

        coerce = self.fromAdapter.coerce
        # Must use inner tuples for hashability in Sandbox.distinct()
        return [tuple([coerce(val, actualTypes[i], expectedTypes[i])
                       for i, val in enumerate(row)])
                for row in data]

    def join(self, unitjoin):
        """Return an SQL FROM clause for the given unitjoin.

        Nested UnitJoins are rendered recursively as parenthesized joins.
        The ON condition comes from the first association found between a
        class on the left side and any class on the right side.

        Raises dejavu.AssociationError if the two sides are unrelated.
        """
        def side(part):
            # Return (SQL text, list of class wrappers) for one join operand.
            if isinstance(part, dejavu.UnitJoin):
                # Materialize the classes. The right-hand list is scanned
                # once per left-hand class, so a one-shot iterator would be
                # exhausted after the first scan.
                return self.join(part), list(part)
            # part is a Unit class wrapper.
            return part.joinname, [part]

        name1, classlist1 = side(unitjoin.class1)
        name2, classlist2 = side(unitjoin.class2)

        j = {None: "INNER", True: "LEFT", False: "RIGHT"}[unitjoin.leftbiased]

        # Find an association between the two halves.
        found = None
        for clsA in classlist1:
            found = clsA.association(classlist2)
            if found:
                break
        if not found:
            msg = ("No association found between %s and %s." % (name1, name2))
            raise dejavu.AssociationError(msg)
        ua, nearClass, farClass = found
        near = '%s.%s' % (nearClass, self.column_name(nearClass, ua.nearKey))
        far = '%s.%s' % (farClass, self.column_name(farClass, ua.farKey))

        return "(%s %s JOIN %s ON %s = %s)" % (name1, j, name2, near, far)

    def multiselect(self, classes, expr):
        """Return an SQL SELECT statement, an imperfect flag, and column names.

        The third item is actually the list of (class, property key) pairs
        in the order of the selected columns. The WHERE clause is omitted
        when the expression produces no SQL condition.
        """
        # Create a new unitjoin tree where each class is wrapped.
        # Then we can tag the wrappers with metadata with impunity.
        seen = {}
        aliascount = [0]

        def wrap_class(cls):
            # Wrap one Unit class. Every repeat occurrence of a class gets a
            # unique alias (t1, t2, ...) so its table can appear twice.
            wclass = UnitClassWrapper(cls, self)
            if cls in seen:
                aliascount[0] += 1
                wclass.alias = "t%d" % aliascount[0]
            else:
                seen[cls] = None
            return wclass

        def wrap(unitjoin):
            cls1, cls2 = unitjoin.class1, unitjoin.class2
            if isinstance(cls1, dejavu.UnitJoin):
                wclass1 = wrap(cls1)
            else:
                wclass1 = wrap_class(cls1)
            if isinstance(cls2, dejavu.UnitJoin):
                wclass2 = wrap(cls2)
            else:
                wclass2 = wrap_class(cls2)
            return dejavu.UnitJoin(wclass1, wclass2, unitjoin.leftbiased)
        classes = wrap(classes)

        joins = self.join(classes)

        if expr is None:
            expr = logic.Expression(lambda *args: True)
        w, imp = self.where(list(classes), expr)

        cols = []
        colnames = []
        for wrapper in classes:
            c, names = wrapper.columns()
            cols.extend(c)
            colnames.extend(names)

        statement = "SELECT %s FROM %s" % (', '.join(colnames), joins)
        # An empty condition must not produce a dangling "WHERE ".
        if w:
            statement += " WHERE %s" % w
        return statement, imp, cols

    def multirecall(self, classes, expr):
        """Yield Unit instance sets which satisfy the expression.

        Each yielded item is a list of Units, one per class in classes (in
        iteration order). If expr is None, every row is accepted.
        """
        if expr is None:
            expr = logic.Expression(lambda *args: True)
        sql, imp, supplied_cols = self.multiselect(classes, expr)
        data, recvd_cols = self.fetch(sql)
        if not data:
            return

        # Pair each supplied (class, key) with the DB-reported column type.
        props = [(c, key, rec[1])
                 for (c, key), rec in zip(supplied_cols, recvd_cols)]

        consume = self.fromAdapter.consume
        for row in data:
            units = {}
            for index, (c, key, ftype) in enumerate(props):
                if c in units:
                    unit = units[c]
                else:
                    units[c] = unit = c()
                consume(unit, key, row[index], ftype)

            unitset = []
            for cls in classes:
                unit = units[cls]
                unit.cleanse()
                unitset.append(unit)

            # If our SQL is imperfect, don't yield units to the
            # caller unless they pass expr(unit).
            if (not imp) or expr(*unitset):
                yield unitset

    #                               Schemas                               #

    def create_database(self):
        """Create the database named by self.dbname (set by subclasses)."""
        self.execute("CREATE DATABASE %s;" % self.sql_name(self.dbname))

    def drop_database(self):
        """Drop the database named by self.dbname (set by subclasses)."""
        self.execute("DROP DATABASE %s;" % self.sql_name(self.dbname))

    def create_storage(self, cls):
        """Create storage for the given class.

        Creates one column per property, typed via typeAdapter, and one
        index per name in cls.indices().
        """
        clsname = cls.__name__
        tablename = self.table_name(clsname)
        typename = self.typeAdapter.coerce

        fields = ['%s %s' % (self.column_name(clsname, key),
                             typename(cls, key))
                  for key in cls.properties]
        self.execute('CREATE TABLE %s (%s);' % (tablename, ", ".join(fields)))

        for index in cls.indices():
            i = self.table_name("i" + clsname + index)
            self.execute('CREATE INDEX %s ON %s (%s);' %
                         (i, tablename, self.column_name(clsname, index)))

    def has_storage(self, cls):
        """Return True if the table for cls can be queried."""
        try:
            # Must use fetch here instead of execute, because e.g. MySQL
            # must call store_result if the query has a result set
            # (or it will crash on a subsequent execute).
            self.fetch("SELECT * FROM %s;" % self.table_name(cls.__name__))
        except Exception:
            # Any DB error (DB-API errors derive from Exception) means
            # "no such table".
            return False
        return True

    def drop_storage(self, cls):
        """Drop the table for cls."""
        self.execute('DROP TABLE %s;' % self.table_name(cls.__name__))

    def add_property(self, cls, name):
        """Add a column for property name unless it already exists."""
        if not self.has_property(cls, name):
            clsname = cls.__name__
            self.execute("ALTER TABLE %s ADD COLUMN %s %s;" %
                         (self.table_name(clsname),
                          self.column_name(clsname, name),
                          self.typeAdapter.coerce(cls, name),
                          ))

    def has_property(self, cls, name):
        """Return True if the column for property name can be queried."""
        clsname = cls.__name__
        try:
            # Must use fetch here instead of execute, because e.g. MySQL
            # must call store_result if the query has a result set
            # (or it will crash on a subsequent execute).
            self.fetch("SELECT %s FROM %s;" %
                       (self.column_name(clsname, name),
                        self.table_name(clsname)))
        except Exception:
            return False
        return True

    def drop_property(self, cls, name):
        """Drop the column for property name (and its index) if present."""
        if self.has_property(cls, name):
            clsname = cls.__name__
            if self.has_index(cls, name):
                self.drop_index(cls, name)
            self.execute("ALTER TABLE %s DROP COLUMN %s;" %
                         (self.table_name(clsname),
                          self.column_name(clsname, name)))

    def rename_property(self, cls, oldname, newname):
        """Rename a property's column; no-op if the SQL names are equal."""
        clsname = cls.__name__
        oldname = self.column_name(clsname, oldname)
        newname = self.column_name(clsname, newname)
        if oldname != newname:
            self.execute("ALTER TABLE %s RENAME COLUMN %s TO %s;" %
                         (self.table_name(clsname), oldname, newname))

    def has_index(self, cls, name):
        """Return True if an index reported by get_indices covers name.

        NOTE(review): get_indices is not defined in this class; subclasses
        presumably provide it, returning objects with a 'colname' attribute.
        """
        tablename = self.table_name(cls.__name__, quoted=False)
        indices = [i.colname for i in self.get_indices(tablename)]
        return (name in indices)

    def drop_index(self, cls, name):
        """Drop the index that create_storage created for property name.

        The index name is built with table_name(), exactly as in
        create_storage, so the table prefix is included and the statement
        targets the index that actually exists.
        """
        clsname = cls.__name__
        self.execute('DROP INDEX %s ON %s;' %
                     (self.table_name("i" + clsname + name),
                      self.table_name(clsname)))
1414
1415
class Table:
    """A database table: a name plus a list of columns (initially empty)."""

    def __init__(self, name):
        self.columns = []
        self.name = name

    def __repr__(self):
        return "dejavu.db.Table(%r)" % (self.name,)
1425
1426
class Column:
    """One column of a database table.

    key is the column name, type its column type, and default the column
    default (None when not given). hints is a dict for additional column
    hints and starts out empty.
    """

    def __init__(self, key, type, default=None):
        self.key, self.type, self.default = key, type, default
        self.hints = {}

    def __repr__(self):
        parts = (self.key, self.type, self.default, self.hints)
        return "dejavu.db.Column(%r, %r, default=%r, hints=%r)" % parts
1441
1442
class Index:
    """An index on a table column (or columns) in a database.

    pk and unique flag whether the index is a primary key and whether it
    is unique; both default to True.
    """

    def __init__(self, name, tablename, colname, pk=True, unique=True):
        self.name, self.tablename, self.colname = name, tablename, colname
        self.pk, self.unique = pk, unique

    def __repr__(self):
        parts = (self.name, self.tablename, self.colname,
                 self.pk, self.unique)
        return "dejavu.db.Index(%r, %r, %r, pk=%r, unique=%r)" % parts
1457
1458
class OutOfConnectionsError(dejavu.DejavuError):
    """Raised when a database store cannot obtain another connection.

    StorageManagerDB.connection() catches this from _get_conn, retries, and
    raises it again once its retry budget is exhausted.
    """
1462
Note: See TracBrowser for help on using the browser.