Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/geniusql.py

Revision 250 (checked in by fumanchu, 7 years ago)

Renamed dbmodel.py to geniusql.py.

  • Property svn:eol-style set to native
Line 
1 """Classes to model database realities.
2
3 Dejavu's Units are storage-agnostic, and the db.StorageManagerDB class is
4 DB-provider-agnostic. However, each application is deployed in the real
5 world, where we have to care about shifting column datatypes, indices, and
6 names. These objects model those realities, and provide StorageManagerDB a
7 way to create, read, update, and delete them.
8
9 The Table, Column, and Index objects present this metadata, and are
10 intentionally abstract. They should never contain any SQL or "smarts"
11 of any kind, besides the "qname", the quoted name, of the column or table.
12 At most, subclasses and consumers might put implementation-specific data
13 into them; for example, PostgreSQL uses a separate CREATE SEQUENCE statement
14 for autoincrement columns, and stores the name of the SEQUENCE in a new
15 Column.sequencename attribute.
16
17 The IndexSet, ColumnSet, and Database objects are all dict-like containers,
18 and therefore have a key for each value. Those keys should equate to things
19 at the consumer layer; for example, a Database may possess a pair of the
20 form: {'YoYo': Table('yoyo')} -- the key is the "friendly" name, but the
21 Table.name is a lowercase version of that, because that's what the database
22 uses in SQL to refer to that table.
23
24 """
25
26 import datetime
27 try:
28     import cPickle as pickle
29 except ImportError:
30     import pickle
31 import Queue
32
33 import sys
34
# Determine max bytes for a native int on this system.
# sys.maxint is the native int ceiling where it exists; interpreters
# that no longer provide it fall back to sys.maxsize.
_native_int_max = getattr(sys, "maxint", None)
if _native_int_max is None:
    _native_int_max = sys.maxsize
maxint_bytes = 1
while _native_int_max > 2 ** ((maxint_bytes * 8) - 1):
    maxint_bytes += 1

# Determine max binary digits (mantissa bits) for float on this system.
# Crude but effective: grow an all-ones integer one bit at a time until
# it no longer survives a round trip through float.
maxfloat_digits = 2
while True:
    L = (2 ** (maxfloat_digits + 1)) - 1
    if int(float(L)) != L:
        break
    maxfloat_digits += 1
49
50
51 import time
52 from types import FunctionType
53 import warnings
54 import weakref
55
56
57 try:
58     # Builtin in Python 2.5?
59     decimal
60 except NameError:
61     try:
62         # Module in Python 2.3, 2.4
63         import decimal
64     except ImportError:
65         decimal = None
66
67 try:
68     import fixedpoint
69 except ImportError:
70     fixedpoint = None
71
72
73 import dejavu
74 from dejavu import codewalk, errors, logflags
75
76
77 # ---------------------------- TYPE ADAPTERS ---------------------------- #
78
79
def getCoerceName(pytype):
    """Return the name of the coercion method for a given Python type.

    Builtin types map to their bare name (e.g. "int"); any other type maps
    to "<module>_<name>" with dots in the module path replaced by
    underscores (e.g. datetime.date -> "datetime_date").
    """
    mod = pytype.__module__
    # The builtin module is "__builtin__" on Python 2 and "builtins" on
    # later interpreters; treat both the same so builtin types never
    # pick up a module prefix.
    if mod in ("__builtin__", "builtins"):
        xform = pytype.__name__
    else:
        xform = "%s_%s" % (mod, pytype.__name__)
    return xform.replace(".", "_")
89
def getCoerceMethod(adapter, value, totype, fromtype):
    """Return the coercion method for a given value [by type].

    totype and fromtype may each be either a string (used as-is; an empty
    string means "unspecified") or a type object (named via getCoerceName,
    with its direct base classes used as fallbacks).

    Lookup order on the adapter:
        1. coerce_<from>_to_<to>, coerce_any_to_<to>, coerce_<from>_to_any
        2. for each base of totype: coerce_<from>_to_<base>,
           coerce_any_to_<base>
        3. for each base of fromtype: coerce_<base>_to_<to>,
           coerce_<base>_to_any

    Raises TypeError, listing every name tried, if nothing matches.
    """
    if isinstance(fromtype, str):
        from_parents = ()
    else:
        from_parents = fromtype.__bases__
        fromtype = getCoerceName(fromtype)

    if isinstance(totype, str):
        to_parents = ()
    else:
        to_parents = totype.__bases__
        totype = getCoerceName(totype)

    def candidate_names():
        # Exact and wildcard matches on the types themselves.
        if fromtype and totype:
            yield "coerce_" + fromtype + "_to_" + totype
        if totype:
            yield "coerce_any_to_" + totype
        if fromtype:
            yield "coerce_" + fromtype + "_to_any"
        # Fall back to the direct bases of the target type...
        for parent in to_parents:
            parent_name = getCoerceName(parent)
            if fromtype:
                yield "coerce_" + fromtype + "_to_" + parent_name
            yield "coerce_any_to_" + parent_name
        # ...and then to the direct bases of the source type.
        for parent in from_parents:
            parent_name = getCoerceName(parent)
            if totype:
                yield "coerce_" + parent_name + "_to_" + totype
            yield "coerce_" + parent_name + "_to_any"

    tried = []
    for candidate in candidate_names():
        tried.append(candidate)
        if hasattr(adapter, candidate):
            return getattr(adapter, candidate)

    raise TypeError("%s -> %s is not handled by %s.  Looked for: %s" %
                    (fromtype, totype, adapter.__class__, ", ".join(tried)))
142
143
class AdapterToSQL(object):
    """Coerce Expression constants to SQL.

    This base class is designed to work out-of-the-box with PostgreSQL 8.

    Each "coerce_<pytype>_to_<dbtype>" attribute turns one Python value
    into a string of SQL; coerce() picks the attribute to call by handing
    the value's type and the database type to getCoerceMethod.
    """

    # You should REALLY check into your DB's encoding and override this.
    encoding = 'utf8'

    # Notice these are ordered pairs. Escape \ before introducing new ones.
    # Values in these two lists should be strings encoded with self.encoding.
    escapes = [("'", "''"), ("\\", r"\\")]
    like_escapes = [("%", r"\%"), ("_", r"\_")]

    # These are not the same as coerce_bool (which is used on one side of
    # a comparison). Instead, these are used when the whole (sub)expression
    # is True or False, e.g. "WHERE TRUE", or "WHERE TRUE and 'a'.'b' = 3".
    bool_true = "TRUE"
    bool_false = "FALSE"

    def escape_like(self, value):
        """Prepare a string value for use in a LIKE comparison.

        Non-str values are encoded with self.encoding, surrounding quote
        marks are stripped, and each pair in like_escapes is applied.
        """
        if not isinstance(value, str):
            value = value.encode(self.encoding)
        # Notice we strip leading and trailing quote-marks.
        value = value.strip("'\"")
        for pat, repl in self.like_escapes:
            value = value.replace(pat, repl)
        return value

    def coerce(self, value, dbtype="", pytype=None):
        """Return value, coerced from (optional pytype) to dbtype.

        pytype defaults to type(value). Any parenthesized suffix on dbtype
        (e.g. the "(255)" of "VARCHAR(255)") is dropped before lookup.
        Raises TypeError (from getCoerceMethod) if no coercer matches.
        """
        if pytype is None:
            pytype = type(value)
        if "(" in dbtype:
            dbtype = dbtype[:dbtype.find("(")]
        meth = getCoerceMethod(self, value, dbtype, pytype)
        return meth(value)

    def tostr(self, value):
        """Return value as a byte string.

        A str is returned untouched: calling .encode() on a byte string
        first decodes it as ASCII implicitly, which fails on non-ASCII
        data. Other string types are encoded with self.encoding, and
        anything else goes through str().
        """
        if isinstance(value, str):
            return value
        if isinstance(value, basestring):
            return value.encode(self.encoding)
        return str(value)

    def coerce_NoneType_to_any(self, value):
        return "NULL"

    def coerce_bool_to_any(self, value):
        if value:
            return 'TRUE'
        return 'FALSE'

    # The great thing about these 3 date coercers is that you can use
    # them with (VAR)CHAR columns just as well as with DATETIME, etc.
    # and comparisons will still work!
    def coerce_datetime_datetime_to_any(self, value):
        return ("'%04d-%02d-%02d %02d:%02d:%02d'" %
                (value.year, value.month, value.day,
                 value.hour, value.minute, value.second))

    def coerce_datetime_date_to_any(self, value):
        return "'%04d-%02d-%02d'" % (value.year, value.month, value.day)

    def coerce_datetime_time_to_any(self, value):
        return "'%02d:%02d:%02d'" % (value.hour, value.minute, value.second)

    def coerce_datetime_timedelta_to_any(self, value):
        # Expressed as a (fractional) number of days; only the days and
        # seconds fields of the timedelta are used.
        float_val = value.days + (value.seconds / 86400.0)
        return repr(float_val)

    coerce_decimal_to_any = str
    coerce_decimal_Decimal_to_any = str
    def coerce_decimal_to_TEXT(self, value):
        return "'%s'" % str(value)
    coerce_decimal_Decimal_to_TEXT = coerce_decimal_to_TEXT

    def do_pickle(self, value):
        """Return value pickled and quoted as an SQL string literal."""
        # Note: dumps with protocol 0 uses the 'raw-unicode-escape'
        # encoding, and we take pains not to re-encode it with
        # self.encoding.
        value = pickle.dumps(value)
        value = self.coerce_str_to_any(value, skip_encoding=True)
        return value

    coerce_dict_to_any = do_pickle

    coerce_fixedpoint_FixedPoint_to_any = tostr
    def coerce_fixedpoint_FixedPoint_to_TEXT(self, value):
        return "'%s'" % str(value)

    # Very important we use repr here so we get all 17 decimal digits.
    coerce_float_to_any = repr
    def coerce_float_to_TEXT(self, value):
        return "'%s'" % repr(value)

    coerce_int_to_any = tostr

    coerce_list_to_any = do_pickle

    coerce_long_to_any = tostr
    def coerce_long_to_TEXT(self, value):
        return "'%s'" % str(value)

    def coerce_str_to_any(self, value, skip_encoding=False):
        """Return value as a single-quoted, escaped SQL string literal.

        Unless skip_encoding is true, non-str values are first encoded
        with self.encoding. Each pair in self.escapes is then applied in
        order.
        """
        if not skip_encoding and not isinstance(value, str):
            value = value.encode(self.encoding)
        for pat, repl in self.escapes:
            value = value.replace(pat, repl)
        return "'" + value + "'"

    coerce_tuple_to_any = do_pickle

    coerce_unicode_to_any = coerce_str_to_any
258
259
class AdapterFromDB(object):
    """Coerce incoming values from DB types to Python datatypes.

    This base class is designed to work out-of-the-box with PostgreSQL 8.

    Each "coerce_any_to_<pytype>" attribute converts one database value
    into the named Python type; coerce() picks the attribute to call by
    handing the target type and the database type to getCoerceMethod.
    """

    # You should REALLY check into your DB's encoding and override this.
    encoding = 'utf8'

    def coerce(self, value, dbtype, pytype):
        """Return value, coerced from dbtype to pytype.

        None is always returned as None. Any parenthesized suffix on
        dbtype is dropped before lookup. Raises TypeError (from
        getCoerceMethod) if no coercer matches.
        """
        # All columns could conceivably hold NULL => Python None
        if value is None:
            return None

        if "(" in dbtype:
            dbtype = dbtype[:dbtype.find("(")]

        meth = getCoerceMethod(self, value, pytype, dbtype)
        return meth(value)

    def do_pickle(self, value):
        """Return the object unpickled from value."""
        # Coerce to str for pickle.loads restriction.
        if isinstance(value, unicode):
            value = value.encode(self.encoding)
        value = str(value)
        # NOTE(review): pickle.loads executes whatever the pickle stream
        # asks for; this is only safe if the column contents are trusted.
        return pickle.loads(value)

    coerce_any_to_bool = bool

    def coerce_any_to_datetime_datetime(self, value):
        # Fixed-width slices of 'YYYY-MM-DD HH:MM:SS'.
        chunks = (value[0:4], value[5:7], value[8:10],
                  value[11:13], value[14:16], value[17:19])
        return datetime.datetime(*map(int, chunks))

    def coerce_any_to_datetime_date(self, value):
        # Fixed-width slices of 'YYYY-MM-DD'.
        chunks = (value[0:4], value[5:7], value[8:10])
        return datetime.date(*map(int, chunks))

    def coerce_any_to_datetime_time(self, value):
        # Fixed-width slices of 'HH:MM:SS'.
        chunks = (value[0:2], value[3:5], value[6:8])
        return datetime.time(*map(int, chunks))

    def coerce_any_to_datetime_timedelta(self, value):
        """Return a timedelta from a (fractional) number of days."""
        days, fraction = divmod(value, 1)
        # Round rather than truncate: fraction * 86400 can land a hair
        # below the intended whole number of seconds because of binary
        # float round-off, and int() alone would then lose a second.
        return datetime.timedelta(days, int(round(fraction * 86400)))

    def coerce_any_to_decimal(self, value):
        return decimal(str(value))

    def coerce_any_to_decimal_Decimal(self, value):
        return decimal.Decimal(str(value))

    coerce_any_to_dict = do_pickle

    def coerce_any_to_fixedpoint_FixedPoint(self, value):
        if isinstance(value, basestring):
            # Unicode really screws up fixedpoint; for example:
            # >>> fixedpoint.FixedPoint(u'111111111111111111111111111.1')
            # FixedPoint('111111111111111104952008704.00', 2)
            value = str(value)

            # Scale is the number of digits after the decimal point.
            scale = 0
            atoms = value.rsplit(".", 1)
            if len(atoms) > 1:
                scale = len(atoms[-1])
            return fixedpoint.FixedPoint(value, scale)
        else:
            return fixedpoint.FixedPoint(value)

    coerce_any_to_float = float
    coerce_any_to_int = int
    coerce_any_to_list = do_pickle
    # Interpreters without a separate long type use int instead, so the
    # class body can still be evaluated there.
    try:
        coerce_any_to_long = long
    except NameError:
        coerce_any_to_long = int

    def coerce_any_to_str(self, value):
        """Return value as a byte string.

        A str is returned untouched: calling .encode() on a byte string
        first decodes it as ASCII implicitly, which fails on non-ASCII
        data. Other string types are encoded with self.encoding, and
        anything else goes through str().
        """
        if isinstance(value, str):
            return value
        if isinstance(value, basestring):
            return value.encode(self.encoding)
        return str(value)

    coerce_any_to_tuple = do_pickle

    def coerce_any_to_unicode(self, value):
        if isinstance(value, unicode):
            return value
        else:
            return unicode(value, self.encoding)
348
349
class TypeAdapter(object):
    """Determine the best database type for a given column + Python type.

    This base class is designed to work out-of-the-box with PostgreSQL 8.

    Each "coerce_<pytype>" method receives a column object and returns the
    name of a database type as a string; sizing comes from the column's
    "hints" mapping ('bytes', 'precision', 'scale').
    """

    # Max binary precision for floating-point columns (= 53 for PostgreSQL 8).
    # Python floats are implemented using C doubles; actual precision
    # depends on platform (but is usually 53 binary digits, see maxfloat_digits).
    # PostgreSQL DOUBLE is 53 binary-digit precision.
    float_max_precision = 53

    # Max decimal precision for NUMERIC columns (= 1000 for PostgreSQL 8).
    numeric_max_precision = 1000

    # "The actual storage requirement is two bytes for each group of four
    # decimal digits, plus eight bytes overhead." Note we omit the overhead.
    numeric_max_bytes = 500

    def coerce(self, col, pytype):
        """Return a database type for the given column object and Python type."""
        handler_name = "coerce_" + getCoerceName(pytype)
        try:
            handler = getattr(self, handler_name)
        except AttributeError:
            raise TypeError("'%s' is not handled by %s." %
                            (pytype, self.__class__))
        return handler(col)

    # ------------------------------ floats ------------------------------ #

    def float_type(self, precision):
        """Return a datatype which can handle floats of the given binary precision."""
        if precision <= 24:
            return "REAL"
        return "DOUBLE PRECISION"

    def coerce_float(self, col):
        # The 'precision' hint counts binary digits, not decimal ones.
        wanted = int(col.hints.get('precision', maxfloat_digits))
        if wanted > self.float_max_precision:
            return "TEXT"
        return self.float_type(wanted)

    # I was seriously uninterested in writing a parser for interval.
    def coerce_datetime_timedelta(self, col):
        return self.coerce_float(col)

    # ----------------------------- strings ------------------------------ #

    def coerce_str(self, col):
        # The bytes hint shall not reflect the usual 4-byte base for varchar.
        size = int(col.hints.get('bytes', 255))
        if not size or size > 255:
            # TEXT is not an SQL standard, but it's common.
            return "TEXT"
        return "VARCHAR(%s)" % size

    # These types are all stored the way strings are.
    def coerce_dict(self, col):
        return self.coerce_str(col)

    def coerce_list(self, col):
        return self.coerce_str(col)

    def coerce_tuple(self, col):
        return self.coerce_str(col)

    def coerce_unicode(self, col):
        return self.coerce_str(col)

    # ------------------------ booleans and dates ------------------------ #

    def coerce_bool(self, col):
        return "BOOLEAN"

    def coerce_datetime_datetime(self, col):
        return "TIMESTAMP"

    def coerce_datetime_date(self, col):
        return "DATE"

    def coerce_datetime_time(self, col):
        return "TIME"

    # ----------------------------- decimals ----------------------------- #

    def decimal_type(self, colname, precision, scale):
        """Return a NUMERIC type for the precision/scale, or TEXT if too wide."""
        if precision > self.numeric_max_precision:
            return "TEXT"
        # Scale may never exceed precision.
        return "NUMERIC(%s, %s)" % (precision, min(scale, precision))

    def _numeric_from_hints(self, col):
        """Return decimal_type() for the column's precision and scale hints."""
        precision = int(col.hints.get('precision', self.numeric_max_precision))
        # Assume most people use these types for money; default scale = 2.
        scale = int(col.hints.get('scale', 2))
        return self.decimal_type(col.name, precision, scale)

    def coerce_decimal_Decimal(self, col):
        return self._numeric_from_hints(col)

    def coerce_decimal(self, col):
        # If decimal ever becomes a builtin. Python 2.5?
        return self.coerce_decimal_Decimal(col)

    def coerce_fixedpoint_FixedPoint(self, col):
        # Note that fixedpoint has no theoretical precision limit.
        return self._numeric_from_hints(col)

    # ----------------------------- integers ----------------------------- #

    def int_type(self, bytes):
        """Return a datatype which can handle the given number of bytes."""
        # BIGINT is usually 8 bytes.
        for ceiling, typename in ((2, "SMALLINT"),
                                  (4, "INTEGER"),
                                  (8, "BIGINT")):
            if bytes <= ceiling:
                return typename
        # Anything larger than 8 bytes, use decimal/numeric.
        # For PostgreSQL, "The actual storage requirement is two bytes
        # for each group of four decimal digits, plus eight bytes
        # overhead." Note we omit the overhead in our calculation.
        return "NUMERIC(%s, 0)" % (bytes * 2)

    def coerce_long(self, col):
        size = int(col.hints.get('bytes', self.numeric_max_bytes))
        if self.numeric_max_bytes < size:
            return "TEXT"
        return self.int_type(size)

    def coerce_int(self, col):
        size = int(col.hints.get('bytes', maxint_bytes))
        if maxint_bytes < size:
            # Too wide for a native int; size it the way longs are sized.
            return self.coerce_long(col)
        return self.int_type(size)
471
472
473
474 # -------------------------- SQL DECOMPILATION -------------------------- #
475
476
class ConstWrapper(str):
    """Wraps a constant for use in SQLDecompiler's stack.

    When we hit LOAD_CONST while decompiling, we occasionally need to keep
    both the base and the coerced value around (see COMPARE_OP for use
    of ConstWrapper.basevalue).

    The instance itself is the coerced (SQL) string; the original Python
    value is kept in the "basevalue" attribute.
    """
    def __new__(cls, basevalue, coerced_value):
        # Build via cls (not a hard-coded ConstWrapper) so that subclasses
        # get instances of their own type.
        newobj = str.__new__(cls, coerced_value)
        newobj.basevalue = basevalue
        return newobj
488
489
490 # Stack sentinels
class Sentinel(object):
    """A named marker object for the decompiler's stack."""

    _repr_template = 'Stack Sentinel: %s'

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self._repr_template % self.name


# kw_arg marks a reference to a keyword argument of the lambda.
kw_arg = Sentinel('Keyword Arg')

# cannot_represent exists so that a portion of an Expression can be
# labeled imperfect. For example, the function dejavu.iscurrentweek
# rarely has an SQL equivalent. All Units (which match the rest of the
# Expression) will be recalled; they can then be compared in expr(unit).
cannot_represent = Sentinel('Cannot Repr')
505
506
class SQLDecompiler(codewalk.LambdaDecompiler):
    """SQLDecompiler(tables, expr, adapter=AdapterToSQL()).

    Produce SQL from a supplied Expression object, with a lambda of the form:
        lambda x, **kw: ...

    Attributes of each argument in the signature will be mapped to table
    columns. Keyword arguments should be bound using Expression.bind_args
    before calling this decompiler.

    Call code() to obtain the SQL string. Afterwards, self.imperfect is
    True if some part of the Expression had no SQL equivalent (in which
    case the SQL matches a superset of the Units the Expression matches).

    NOTE(review): self.stack, self.targets, self.cursor, self.last_target_ip,
    self.debug and the co_* attributes are not defined here; presumably
    they come from codewalk.LambdaDecompiler -- confirm there.
    """

    # Some constants are function or class objects,
    # which should not be coerced.
    no_coerce = (FunctionType,
                 type,
                 type(len),       # <type 'builtin_function_or_method'>
                 )

    # Indexed by the COMPARE_OP argument.
    sql_cmp_op = ('<', '<=', '=', '!=', '>', '>=', 'in', 'not in')

    def __init__(self, tables, expr, adapter=AdapterToSQL()):
        self.tables = tables
        self.expr = expr
        self.adapter = adapter
        # Cache coerced booleans
        self.T = adapter.coerce_bool_to_any(True)
        self.F = adapter.coerce_bool_to_any(False)
        obj = expr.func
        codewalk.LambdaDecompiler.__init__(self, obj)

    def code(self):
        """Walk the Expression's bytecode and return its SQL as a string."""
        self.imperfect = False
        self.walk()
        # After walk(), self.stack should be reduced to a single string,
        # which is the SQL representation of our Expression.
        result = self.stack[0]
        if result is cannot_represent:
            # The entire expression could not be evaluated.
            result = self.adapter.bool_true
        # A bare coerced boolean must be spelled as a whole-expression
        # boolean (see AdapterToSQL.bool_true / bool_false).
        if result == self.T:
            result = self.adapter.bool_true
        if result == self.F:
            result = self.adapter.bool_false
        return result

    def visit_instruction(self, op, lo=None, hi=None):
        """Combine pending boolean terms that target this instruction."""
        # Get the instruction pointer for the current instruction.
        # Instructions without arguments are 1 byte long instead of 3.
        ip = self.cursor - 3
        if hi is None:
            ip += 1
            if lo is None:
                ip += 1

        terms = self.targets.get(ip)
        if terms:
            trueval = self.adapter.bool_true
            falseval = self.adapter.bool_false
            clause = self.stack[-1]
            while terms:
                term, oper = terms.pop()
                if term is cannot_represent:
                    # Use TRUE for the term, so all records are returned.
                    term = trueval
                if clause is cannot_represent:
                    # Use TRUE for the clause, so all records are returned.
                    clause = trueval

                # Blurg. SQL Server is *so* picky.
                if term == self.T:
                    term = trueval
                elif term == self.F:
                    term = falseval
                if clause == self.T:
                    clause = trueval
                elif clause == self.F:
                    clause = falseval

                clause = "(%s) %s (%s)" % (term, oper.upper(), clause)

            # Replace TOS with the new clause, so that further
            # combinations have access to it.
            self.stack[-1] = clause
            self.debug("clause:", clause, "\n")

            if op == 1:
                # Py2.4: The current instruction is POP_TOP, which means
                # the previous is probably JUMP_*. If so, we're going to
                # pop the value we just placed on the stack and lose it.
                # We need to replace the entry that the JUMP_* made in
                # self.targets with our new TOS.
                target = self.targets[self.last_target_ip]
                target[-1] = ((clause, target[-1][1]))
                self.debug("newtarget:", self.last_target_ip, target)

    def visit_LOAD_DEREF(self, lo, hi):
        raise ValueError("Illegal reference found in %s." % self.expr)

    def visit_LOAD_GLOBAL(self, lo, hi):
        raise ValueError("Illegal global found in %s." % self.expr)

    def visit_LOAD_FAST(self, lo, hi):
        arg_index = lo + (hi << 8)
        if arg_index < self.co_argcount:
            # We've hit a reference to a positional arg, which in our
            # case implies a reference to a DB table.
            self.stack.append(self.tables[arg_index])
        else:
            # Since lambdas don't support local bindings,
            # any remaining local name must be a keyword arg.
            self.stack.append(kw_arg)

    def visit_LOAD_ATTR(self, lo, hi):
        name = self.co_names[lo + (hi << 8)]
        tos = self.stack.pop()
        if isinstance(tos, tuple):
            # The name in question refers to a DB column.
            tablename, table = tos
            col = table.columns[name]
            if col.imperfect_type:
                atom = cannot_represent
                self.imperfect = True
            else:
                atom = '%s.%s' % (tablename, col.qname)
        else:
            # 'tos.name' will reference an attribute of the tos object.
            # Stick the tos and name in a tuple for later processing.
            atom = (tos, name)
        self.stack.append(atom)

    def visit_LOAD_CONST(self, lo, hi):
        val = self.co_consts[lo + (hi << 8)]
        if not isinstance(val, self.no_coerce):
            # Keep the original value alongside its SQL form.
            val = ConstWrapper(val, self.adapter.coerce(val))
        self.stack.append(val)

    def visit_BUILD_TUPLE(self, lo, hi):
        # Parenthesize the shift: "lo + hi << 8" would parse as
        # "(lo + hi) << 8" and pop far too many items off the stack.
        count = lo + (hi << 8)
        terms = [self.stack.pop() for i in range(count)]
        # Items come off the stack last-first; restore source order.
        terms.reverse()
        self.stack.append("(" + ", ".join(terms) + ")")

    visit_BUILD_LIST = visit_BUILD_TUPLE

    def visit_CALL_FUNCTION(self, lo, hi):
        """Dispatch a call to a matching attr_* or <module>_<name> method.

        lo is the number of positional arguments on the stack and hi the
        number of keyword pairs. If no dispatchee exists, cannot_represent
        is pushed and the result is flagged imperfect.
        """
        kwargs = {}
        for i in xrange(hi):
            val = self.stack.pop()
            key = self.stack.pop()
            kwargs[key] = val
        kwargs = [k + "=" + v for k, v in kwargs.iteritems()]

        args = []
        for i in xrange(lo):
            arg = self.stack.pop()
            args.append(arg)
        args.reverse()

        if kwargs:
            args += kwargs

        func = self.stack.pop()

        # Handle function objects.
        if isinstance(func, tuple):
            # (tos, name) pair built by visit_LOAD_ATTR: a method call.
            tos, name = func
            dispatch = getattr(self, "attr_" + name, None)
            if dispatch:
                self.stack.append(dispatch(tos, *args))
                return
        else:
            funcname = func.__module__ + "_" + func.__name__
            funcname = funcname.replace(".", "_")
            # e.g. "__builtin___len" becomes "func__builtin___len".
            if funcname.startswith("_"):
                funcname = "func" + funcname
            dispatch = getattr(self, funcname, None)
            if dispatch:
                self.stack.append(dispatch(*args))
                return

        self.stack.append(cannot_represent)
        self.imperfect = True

    def visit_COMPARE_OP(self, lo, hi):
        op2, op1 = self.stack.pop(), self.stack.pop()
        if op1 is cannot_represent or op2 is cannot_represent:
            self.stack.append(cannot_represent)
            return

        op = lo + (hi << 8)
        if op in (6, 7):     # in, not in
            value = self.containedby(op1, op2)
            if op == 7:
                value = "NOT " + value
            self.stack.append(value)
        elif op1 == 'NULL':
            if op in (2, 8):    # '==', is
                self.stack.append(op2 + " IS NULL")
            elif op in (3, 9):  # '!=', 'is not'
                self.stack.append(op2 + " IS NOT NULL")
            else:
                raise ValueError("Non-equality Null comparisons not allowed.")
        elif op2 == 'NULL':
            if op in (2, 8):    # '==', 'is'
                self.stack.append(op1 + " IS NULL")
            elif op in (3, 9):  # '!=', 'is not'
                self.stack.append(op1 + " IS NOT NULL")
            else:
                raise ValueError("Non-equality Null comparisons not allowed.")
        else:
            # Comparison operators for strings are case-sensitive in PG et al.
            self.stack.append(op1 + " " + self.sql_cmp_op[op] + " " + op2)

    def binary_op(self, op):
        op2, op1 = self.stack.pop(), self.stack.pop()
        self.stack.append(op1 + " " + op + " " + op2)

    def visit_BINARY_SUBSCR(self):
        # The only BINARY_SUBSCR used in Expressions should be kwargs[key].
        name = self.stack.pop()
        tos = self.stack.pop()
        if tos is not kw_arg:
            raise ValueError("Subscript %s of %s object not allowed."
                             % (name, tos))
        # name, since formed in LOAD_CONST, may have extraneous quotes.
        name = name.strip("'\"")
        value = self.expr.kwargs[name]
        if not isinstance(value, self.no_coerce):
            value = ConstWrapper(value, self.adapter.coerce(value))
        self.stack.append(value)

    def visit_UNARY_NOT(self):
        op = self.stack.pop()
        if op is cannot_represent:
            self.stack.append(cannot_represent)
        else:
            self.stack.append("NOT (" + op + ")")

    # --------------------------- Dispatchees --------------------------- #

    def attr_startswith(self, tos, arg):
        return tos + " LIKE '" + self.adapter.escape_like(arg) + "%'"

    def attr_endswith(self, tos, arg):
        return tos + " LIKE '%" + self.adapter.escape_like(arg) + "'"

    def containedby(self, op1, op2):
        if isinstance(op1, ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            return op2 + " LIKE '%" + self.adapter.escape_like(op1) + "%'"
        else:
            # Looking for field in (a, b, c)
            atoms = [self.adapter.coerce(x) for x in op2.basevalue]
            return op1 + " IN (" + ", ".join(atoms) + ")"

    def dejavu_icontainedby(self, op1, op2):
        if isinstance(op1, ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            return ("LOWER(" + op2 + ") LIKE '%" +
                    self.adapter.escape_like(op1).lower() + "%'")
        else:
            # Looking for field in (a, b, c).
            # Force all args to lowercase for case-insensitive comparison.
            atoms = [self.adapter.coerce(x).lower() for x in op2.basevalue]
            return "LOWER(%s) IN (%s)" % (op1, ", ".join(atoms))

    def dejavu_icontains(self, x, y):
        return self.dejavu_icontainedby(y, x)

    def dejavu_istartswith(self, x, y):
        # The column is lowercased, so the pattern must be too (as in
        # dejavu_icontainedby); otherwise a pattern containing uppercase
        # letters could never match.
        return ("LOWER(" + x + ") LIKE '" +
                self.adapter.escape_like(y).lower() + "%'")

    def dejavu_iendswith(self, x, y):
        # Lowercase the pattern to match the lowercased column.
        return ("LOWER(" + x + ") LIKE '%" +
                self.adapter.escape_like(y).lower() + "'")

    def dejavu_ieq(self, x, y):
        return "LOWER(" + x + ") = LOWER(" + y + ")"

    def dejavu_now(self):
        return "NOW()"

    def dejavu_today(self):
        return "CURRENT_DATE"

    def dejavu_year(self, x):
        return "YEAR(" + x + ")"

    def dejavu_month(self, x):
        return "MONTH(" + x + ")"

    def func__builtin___len(self, x):
        return "LENGTH(" + x + ")"
796
797
798
799 # ------------------------- Connection Factories ------------------------- #
800
801
class ConnectionWrapper(object):
    """Connection object wrapper, so it can be used as a weak reference.

    Any attribute not found on the wrapper itself is looked up on the
    wrapped connection (self.conn).
    """

    def __init__(self, conn=None):
        self.conn = conn

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. If "conn" itself
        # is missing (e.g. an instance created without __init__), reading
        # self.conn below would re-enter this method without end, so fail
        # cleanly instead.
        if attr == "conn":
            raise AttributeError(attr)
        return getattr(self.conn, attr)
810
811
class OutOfConnectionsError(errors.DejavuError):
    """Raised when a database store has no more connections to hand out."""
815
816
class ConnectionFactory(object):
    """A connection factory which opens a new connection for each request.

    open: callable taking no arguments and returning a new connection.
    close: callable taking a connection and closing it.
    retry: how many times __call__ tries 'open' before giving up.

    Each connection handed out is wrapped in a ConnectionWrapper; the
    'refs' dict maps a weak reference to that wrapper onto the underlying
    connection, so the connection can be closed once the wrapper is gone.
    """

    def __init__(self, open, close, retry=5):
        self.open = open
        self.close = close
        self.retry = retry
        self.refs = {}

    def __call__(self):
        """Return a connection."""
        for attempt in xrange(self.retry):
            try:
                raw = self.open()
            except OutOfConnectionsError:
                # Back off a little longer after each failed attempt.
                time.sleep(attempt + 1)
                continue
            wrapper = ConnectionWrapper(raw)
            self.refs[weakref.ref(wrapper, self._release)] = wrapper.conn
            return wrapper
        raise OutOfConnectionsError()

    def _release(self, ref):
        """Release a connection."""
        # Weakref callback: invoked with the (now dead) reference object.
        conn = self.refs.pop(ref)
        self.close(conn)

    def shutdown(self):
        """Release all database connections."""
        # Close every connection still registered, emptying self.refs.
        while self.refs:
            conn = self.refs.popitem()[1]
            self.close(conn)
849
850
class ConnectionPool(object):
    """A database connection factory which keeps a pool of connections.

    open: callable taking no arguments and returning a new connection.
    close: callable taking a connection and closing it.
    size: maximum number of idle connections kept in the pool.
    retry: how many times __call__ tries to obtain a connection.

    Connections handed out are wrapped in a ConnectionWrapper. The 'refs'
    dict maps a weak reference to each live wrapper onto its underlying
    connection; when the wrapper goes away, the connection is returned
    to the pool (or closed if the pool is full).
    """

    # Seconds shutdown() waits for the pool to yield another idle
    # connection before it considers the pool empty.
    shutdown_wait = 0.5

    def __init__(self, open, close, size=10, retry=5):
        self.open = open
        self.close = close
        self.refs = {}
        self.pool = Queue.Queue(size)
        self.retry = retry

    def _wrap(self, conn):
        """Wrap conn, register it in self.refs, and return the wrapper."""
        w = ConnectionWrapper(conn)
        self.refs[weakref.ref(w, self._release)] = w.conn
        return w

    def __call__(self):
        """Return a connection from the pool.

        An idle pooled connection is preferred; otherwise a new one is
        opened. Raises OutOfConnectionsError if 'retry' attempts fail.
        """
        for i in xrange(self.retry):
            try:
                conn = self.pool.get_nowait()
            except Queue.Empty:
                pass
            else:
                # Okay, this is freaky. If we wrap here, all goes well.
                # If we wrap on Queue.put(), mysql crashes after 1700
                # or so inserts (when migrating Access tables to MySQL).
                # Go figure.
                return self._wrap(conn)

            try:
                conn = self.open()
            except OutOfConnectionsError:
                # Back off a little longer after each failed attempt.
                time.sleep(i + 1)
            else:
                return self._wrap(conn)
        raise OutOfConnectionsError()

    def _release(self, ref):
        """Release a connection."""
        # Weakref callback: invoked with the (now dead) reference object.
        conn = self.refs.pop(ref)
        try:
            self.pool.put_nowait(conn)
        except Queue.Full:
            # No room to keep it idle, so close it instead.
            self.close(conn)

    def shutdown(self):
        """Release all database connections."""
        # Empty the pool. Idle connections live only in the pool (they were
        # removed from self.refs on release), so they must be closed here
        # or they would leak.
        while True:
            try:
                conn = self.pool.get(True, self.shutdown_wait)
            except Queue.Empty:
                break
            self.close(conn)

        # Empty self.refs.
        while self.refs:
            ref, conn = self.refs.popitem()
            self.close(conn)
909
910
class SingleConnection(object):
    """A single database connection for all consumers.

    Use this when your database cannot handle multiple connections at once,
    but can handle multiple threads using the same connection.

    open: callable taking no arguments and returning a new connection.
    close: callable taking a connection and closing it.
    """

    def __init__(self, open, close):
        self.open = open
        self.close = close
        # The connection is opened lazily on first use, because the
        # SM may need to create the database first.
        self._conn = None

    def __call__(self):
        """Return our only connection."""
        conn = self._conn
        if conn is None:
            conn = self._conn = self.open()
        return conn

    def shutdown(self):
        """Release all database connections."""
        conn = self._conn
        if conn is None:
            return
        self.close(conn)
        self._conn = None
936
937
938
939 # -------------------------- DATABASE OBJECTS -------------------------- #
940
941
class Index:
    """An index on a table column (or columns) in a database.

    name: the SQL name of the index (unquoted).
    qname: the quoted SQL name of the index.
    tablename: the SQL name of the indexed table.
    colname: the SQL name of the indexed column.
    pk, unique: flags stored as given.
    """

    def __init__(self, name, qname, tablename, colname, pk=True, unique=True):
        self.name = name
        self.qname = qname
        self.tablename = tablename
        self.colname = colname
        self.pk = pk
        self.unique = unique

    def __repr__(self):
        # Note that qname is not part of the representation.
        cls = self.__class__
        return ("%s.%s(%r, %r, %r, pk=%r, unique=%r)" %
                (self.__module__, cls.__name__,
                 self.name, self.tablename, self.colname,
                 self.pk, self.unique))

    def __copy__(self):
        make = self.__class__
        return make(self.name, self.qname, self.tablename,
                    self.colname, self.pk, self.unique)
    copy = __copy__
963
964
class IndexSet(dict):
    """A dict of {key: Index} for the indices of a single table.

    Deleting a key drops the corresponding index in the database.
    """

    def __new__(cls, table):
        # The 'table' argument is consumed by __init__.
        return dict.__new__(cls)

    def __init__(self, table):
        dict.__init__(self)
        self.table = table

    def __delitem__(self, key):
        """Drop the specified index."""
        t = self.table
        # The ON clause must name the table that owns the index,
        # not the database.
        t.db.execute('DROP INDEX %s ON %s;' % (self[key].qname, t.qname))
        # Forget the index once it has been dropped, as ColumnSet and
        # Database do for columns and tables.
        dict.__delitem__(self, key)
978
979
class Column:
    """A column in a table in a database.

    name: the SQL name of the column (unquoted).
    qname: the quoted SQL name of the column.
    dbtype: the database type of the column.
    default: the default value, or None.
    hints: a dict of hints; a fresh dict is used when None is given.
    autoincrement, imperfect_type: flags, both initially False.
    """

    def __init__(self, name, qname, dbtype, default=None, hints=None):
        self.name = name
        self.qname = qname
        self.dbtype = dbtype
        self.default = default
        if hints is None:
            hints = {}
        self.hints = hints
        # If autoincrement, the initial value should be put in self.default.
        self.autoincrement = False
        self.imperfect_type = False

    def __repr__(self):
        # dbtype is shown as-is (not repr'd); qname is not shown.
        return ("%s.%s(%s, dbtype=%s, default=%s, hints=%s)" %
                (self.__module__, self.__class__.__name__,
                 repr(self.name), self.dbtype,
                 repr(self.default), repr(self.hints))
                )

    def __copy__(self):
        """Return a copy of this column with its own (shallow) hints dict."""
        newcol = self.__class__(self.name, self.qname, self.dbtype,
                                self.default, self.hints.copy())
        # The constructor resets these flags to False, so carry them over
        # explicitly; otherwise a copy would silently lose them.
        newcol.autoincrement = self.autoincrement
        newcol.imperfect_type = self.imperfect_type
        return newcol
    copy = __copy__
1006
1007
class ColumnSet(dict):
    """A dict of {key: Column} for a single table.

    Setting, deleting, or renaming a key issues the matching ALTER TABLE
    statement through the table's database. The 'indices' attribute is
    an instance of indexsetclass for the same table.
    """

    indexsetclass = IndexSet

    def __new__(cls, table):
        # The 'table' argument is consumed by __init__.
        return dict.__new__(cls)

    def __init__(self, table):
        dict.__init__(self)
        self.table = table
        self.indices = self.indexsetclass(self.table)

    def __setitem__(self, key, column):
        """Add the column to the table (replacing any existing one)."""
        t = self.table
        if key in self:
            # Drops the existing column (and its index, if any) first.
            del self[key]

        # Falsy defaults (None, 0, "", ...) produce no DEFAULT clause.
        default = column.default or ""
        if default:
            default = " DEFAULT %s" % t.db.adaptertosql.coerce(default, column.dbtype)

        t.db.execute("ALTER TABLE %s ADD COLUMN %s %s%s;" %
                     (t.qname, column.qname, column.dbtype, default))
        dict.__setitem__(self, key, column)

    def __delitem__(self, key):
        """Drop the column (and any index stored under the same key)."""
        if key in self.indices:
            del self.indices[key]
        t = self.table
        t.db.execute("ALTER TABLE %s DROP COLUMN %s;" %
                     (t.qname, self[key].qname))
        dict.__delitem__(self, key)

    def _rename(self, oldcol, newcol):
        # Override this to do the actual rename at the DB level.
        t = self.table
        t.db.execute("ALTER TABLE %s RENAME COLUMN %s TO %s;" %
                     (t.qname, oldcol.qname, newcol.qname))

    def rename(self, oldkey, newkey):
        """Rename a Column.

        The column stored under oldkey is re-filed under newkey. The
        database-level rename is only performed when the SQL column name
        derived from newkey differs from the current one.
        """
        oldcol = self[oldkey]
        oldname = oldcol.name
        t = self.table
        newname = t.db.column_name(self.table.name, newkey)

        # When the SQL name is unchanged, the same Column object is simply
        # stored under the new key.
        newcol = oldcol
        if oldname != newname:
            newcol = oldcol.copy()
            newcol.name = newname
            newcol.qname = t.db.quote(newname)
            self._rename(oldcol, newcol)

        # Use the superclass calls to avoid DROP COLUMN/ADD COLUMN.
        dict.__delitem__(self, oldkey)
        dict.__setitem__(self, newkey, newcol)
1063
1064
class Table(object):
    """A table in a database.

    db: the database for this table.
    name: the SQL name for this table (unquoted).
    qname: the quoted SQL name for this table.
    columns: a dict of {key: Column object} for this table, created by
        calling db.columnsetclass with this table. The 'key' should be a
        name you use for the column in your Python code, whereas the
        Column.name is the SQL name for the column (unquoted).
    """

    def __init__(self, db, name, qname):
        self.db = db
        self.name = name
        self.qname = qname
        self.columns = db.columnsetclass(self)

    def __repr__(self):
        cls = self.__class__
        return "%s.%s(%r)" % (self.__module__, cls.__name__, self.name)

    def __copy__(self):
        """Return a new table holding copies of the columns and indices."""
        clone = self.__class__(self.db, self.name, self.qname)
        source = self.columns
        target = clone.columns
        # dict.__setitem__ is called directly so that the column set's own
        # __setitem__ is not invoked while filling the copy.
        for key, column in source.iteritems():
            dict.__setitem__(target, key, column.copy())
        for key, index in source.indices.iteritems():
            dict.__setitem__(target.indices, key, index.copy())
        return clone
    copy = __copy__
1093
1094
1095
class Database(dict):
    """A dict for managing a set of tables.

    Keys are the consumer-level table keys; values are Table objects.
    Setting, deleting, or renaming a key issues the matching CREATE TABLE,
    DROP TABLE, or rename statement. Subclasses are expected to override
    the methods which raise NotImplementedError.
    """

    decompiler = SQLDecompiler
    adaptertosql = AdapterToSQL()
    adapterfromdb = AdapterFromDB()
    typeadapter = TypeAdapter()

    columnsetclass = ColumnSet

    def __new__(cls, name, **kwargs):
        # 'name' and kwargs are consumed by __init__.
        return dict.__new__(cls)

    def __init__(self, name, **kwargs):
        dict.__init__(self)
        # Keyword arguments become instance attributes. They are applied
        # before the name is processed, so they can affect sql_name/quote.
        for k, v in kwargs.iteritems():
            setattr(self, k, v)

        self.name = self.sql_name(name)
        self.qname = self.quote(self.name)
        self.connect()

    def _get_tables(self, conn=None):
        raise NotImplementedError

    def _get_columns(self, tablename, conn=None):
        raise NotImplementedError

    def _get_indices(self, tablename, conn=None):
        raise NotImplementedError

    def python_type(self, dbtype):
        """Return a Python type which can store values of the given dbtype."""
        # This base class knows no database types at all.
        raise TypeError("Database type %s could not be converted "
                        "to a Python type." % repr(dbtype))

    def db_type(self, col, pytype):
        """Return a database type which can store values of the given pytype."""
        return self.typeadapter.coerce(col, pytype)

    def isrelatedtype(self, pytype1, pytype2):
        """If values of both types are expressed with the same SQL, return True."""
        if issubclass(pytype1, pytype2) or issubclass(pytype2, pytype1):
            return True
        if issubclass(pytype1, basestring) and issubclass(pytype2, basestring):
            return True
        if (issubclass(pytype1, (int, long)) and
            issubclass(pytype2, (int, long))):
            return True

        # Fixed-point/decimal types are related to each other, for
        # whichever of the two optional modules is available.
        exact = []
        if fixedpoint:
            exact.append(fixedpoint.FixedPoint)
        if decimal:
            exact.append(decimal.Decimal)
        if exact:
            exact = tuple(exact)
            if issubclass(pytype1, exact) and issubclass(pytype2, exact):
                return True
        return False

    def __setitem__(self, key, table):
        """Create the table (and its indices), replacing any existing one."""
        if key in self:
            # Drops the existing table first.
            del self[key]

        fields = []
        for col in table.columns.itervalues():
            # Falsy defaults (None, 0, "", ...) produce no DEFAULT clause.
            default = col.default or ""
            if default:
                default = " DEFAULT %s" % self.adaptertosql.coerce(default, col.dbtype)
            fields.append('%s %s%s' % (col.qname, col.dbtype, default))

        self.execute('CREATE TABLE %s (%s);' %
                     (table.qname, ", ".join(fields)))

        for index in table.columns.indices.itervalues():
            self.execute('CREATE INDEX %s ON %s (%s);' %
                         (index.qname, table.qname,
                          self.quote(index.colname)))

        dict.__setitem__(self, key, table)

    def __delitem__(self, key):
        """Drop the table stored under key."""
        self.execute('DROP TABLE %s;' % self[key].qname)
        dict.__delitem__(self, key)

    def _rename(self, oldtable, newtable):
        # Override this to do the actual rename at the DB level.
        raise NotImplementedError

    def rename(self, oldkey, newkey):
        """Rename a Table.

        The table stored under oldkey is re-filed under newkey. The
        database-level rename (self._rename) is only performed when the
        SQL table name derived from newkey differs from the current one.
        """
        oldtable = self[oldkey]
        oldname = oldtable.name
        newname = self.table_name(newkey)

        # When the SQL name is unchanged, the same Table object is simply
        # stored under the new key.
        newtable = oldtable
        if oldname != newname:
            newtable = oldtable.copy()
            newtable.name = newname
            newtable.qname = self.quote(newname)
            self._rename(oldtable, newtable)

        # Use the superclass calls to avoid DROP TABLE/CREATE TABLE.
        dict.__delitem__(self, oldkey)
        dict.__setitem__(self, newkey, newtable)

    #                               Naming                               #

    sql_name_max_length = 64
    sql_name_caseless = False
    Prefix = ""

    def quote(self, name):
        """Return name, quoted for use in an SQL statement."""
        # This base class doesn't use "quote",
        # but most subclasses will.
        return name

    def sql_name(self, key):
        """Return the native SQL version of key.

        The key is lower-cased if sql_name_caseless is set, and truncated
        (with a StorageWarning) if longer than sql_name_max_length.
        """
        if self.sql_name_caseless:
            key = key.lower()

        maxlen = self.sql_name_max_length
        if maxlen and len(key) > maxlen:
            warnings.warn("The name '%s' is longer than the maximum of "
                          "%s characters." % (key, maxlen),
                          errors.StorageWarning)
            key = key[:maxlen]

        return key

    def column_name(self, tablekey, columnkey):
        """Return the SQL column name for the given table and column keys."""
        # If you want to use a map from UnitProperty names
        # to DB column names, override this method (that's why
        # the tablename must be included in the args).
        return self.sql_name(columnkey)

    def make_column(self, tablekey, columnkey, pytype, default, hints):
        """Return a Column object from the given table and column keys."""
        name = self.column_name(tablekey, columnkey)
        # The Column is built first without a dbtype because db_type
        # takes the column itself as an argument.
        col = Column(name, self.quote(name), None, default, hints.copy())
        col.dbtype = self.db_type(col, pytype)
        # Flag the column if the type the database hands back is not
        # related to the type that was requested.
        pytype2 = self.python_type(col.dbtype)
        col.imperfect_type = not self.isrelatedtype(pytype, pytype2)
        return col

    def table_name(self, key):
        """Return the SQL table name for the given key."""
        # If you want to use a map from Unit class names
        # to DB table names, override this method.
        return self.sql_name(self.Prefix + key)

    def make_table(self, tablekey):
        """Return a new Table object for the given table key."""
        name = self.table_name(tablekey)
        return Table(self, name, self.quote(name))

    def make_index(self, tablekey, columnkey):
        """Return a new Index object for the given table and column keys."""
        name = self.table_name("i" + tablekey + columnkey)
        return Index(name, self.quote(name), self.table_name(tablekey),
                     self.column_name(tablekey, columnkey))

    #                              Retrieval                              #

    def select(self, tablekey, expr, columnkeys=None, distinct=False):
        """Return an SQL SELECT statement, and an 'imperfect' flag.

        imperfect: True or False depending on whether the generated SQL
            perfectly satisfies the given expression.
        """
        t = self[tablekey]
        if columnkeys:
            colnames = [t.columns[x].qname for x in columnkeys]
            if distinct:
                sql = 'SELECT DISTINCT %s FROM %s'
            else:
                sql = 'SELECT %s FROM %s'
            sql = sql % (', '.join(colnames), t.qname)
        else:
            # Note that 'distinct' is not applied when selecting all columns.
            sql = 'SELECT * FROM %s' % t.qname

        w, i = self.where(t, expr)
        if len(w) > 0:
            w = " WHERE " + w
        else:
            w = ""

        sql += w + ";"
        return sql, i

    def where(self, tables, expr):
        """Return an SQL WHERE clause, and an 'imperfect' flag.

        tables: a Table object, a list of Table objects,
            or a list of (quoted-name-or-alias, Table) tuples

        imperfect: True or False depending on whether the generated SQL
            perfectly satisfies the given expression.
        """
        if not isinstance(tables, list):
            tables = [tables]
        # Normalize bare Table entries to (qname, Table) pairs. When the
        # caller passed a list, it is updated in place.
        for i, t in enumerate(tables):
            if not isinstance(t, (tuple, list)):
                tables[i] = (t.qname, t)

        decom = self.decompiler(tables, expr, self.adaptertosql)
        return decom.code(), decom.imperfect

    #                             Connecting                              #

    poolsize = 10

    def connect(self):
        """Set self.connection to a pool, or a plain factory if poolsize <= 0."""
        if self.poolsize > 0:
            self.connection = ConnectionPool(self._get_conn, self._del_conn,
                                             self.poolsize)
        else:
            self.connection = ConnectionFactory(self._get_conn, self._del_conn)

    def _get_conn(self):
        # Override this with the connection call for your DB. Example:
        #     return libpq.PQconnectdb(self.connstring)
        raise NotImplementedError

    def _del_conn(self, conn):
        # Override this with the close call (if any) for your DB.
        conn.close()

    def disconnect(self):
        """Release all database connections."""
        self.connection.shutdown()

    def log(self, msg, level):
        # No-op hook; execute() calls it with each query.
        pass

    def execute(self, query, conn=None):
        """execute(query, conn=None) -> result set."""
        if conn is None:
            conn = self.connection()
        if isinstance(query, unicode):
            query = query.encode(self.adaptertosql.encoding)
        self.log(query, logflags.SQL)
        return conn.query(query)

    def fetch(self, query, conn=None):
        """fetch(query, conn=None) -> rowdata, columns.

        query should be a SQL query in string format
        rowdata will be an iterable of iterables containing the result values.
        columns will be an iterable of (column name, data type) pairs.

        This base class uses SQLite3 syntax.
        """
        res = self.execute(query, conn)
        return res.row_list, res.col_defs

    def create_database(self):
        """Issue CREATE DATABASE, then forget all known tables."""
        self.execute("CREATE DATABASE %s;" % self.qname)
        self.clear()

    def drop_database(self):
        """Issue DROP DATABASE, then forget all known tables."""
        self.execute("DROP DATABASE %s;" % self.qname)
        self.clear()
1366
Note: See TracBrowser for help on using the browser.