Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/geniusql.py

Revision 310 (checked in by fumanchu, 7 years ago)

Oops. I didn't want to default to implicit_trans yet.

  • Property svn:eol-style set to native
Line 
1 """Classes to model database realities.
2
3 Dejavu's Units are storage-agnostic, and the db.StorageManagerDB class is
4 DB-provider-agnostic. However, each application is deployed in the real
5 world, where we have to care about shifting column datatypes, indices, and
6 names. These objects model those realities, and provide StorageManagerDB a
7 way to create, read, update, and delete them.
8
9 The Table, Column, and Index objects present this metadata, and are
10 intentionally abstract. They should never contain any SQL or "smarts"
11 of any kind, besides the "qname", the quoted name, of the column or table.
12 At most, subclasses and consumers might put implementation-specific data
13 into them; for example, PostgreSQL uses a separate CREATE SEQUENCE statement
14 for autoincrement columns, and stores the name of the SEQUENCE in a new
15 Column.sequencename attribute.
16
17 The IndexSet, ColumnSet, and Database objects are all dict-like containers,
18 and therefore have a key for each value. Those keys should equate to things
19 at the consumer layer; for example, a Database may possess a pair of the
20 form: {'YoYo': Table('yoyo')} -- the key is the "friendly" name, but the
21 Table.name is a lowercase version of that, because that's what the database
22 uses in SQL to refer to that table.
23
24 """
25
26 import datetime
27 try:
28     import cPickle as pickle
29 except ImportError:
30     import pickle
31 import Queue
32
33 import sys
34
35 # Determine max bytes for int on this system.
36 maxint_bytes = 1
37 while True:
38     if sys.maxint <= 2 ** ((maxint_bytes * 8) - 1):
39         break
40     maxint_bytes += 1
41
42 # Determine max digits for float on this system. Crude but effective.
43 maxfloat_digits = 2
44 while True:
45     L = (2 ** (maxfloat_digits + 1)) - 1
46     if int(float(L)) != L:
47         break
48     maxfloat_digits += 1
49
50
51 import time
52 import threading
53 from types import FunctionType
54 import warnings
55 import weakref
56
57
58 try:
59     # Builtin in Python 2.5?
60     decimal
61 except NameError:
62     try:
63         # Module in Python 2.3, 2.4
64         import decimal
65     except ImportError:
66         decimal = None
67
68 try:
69     import fixedpoint
70 except ImportError:
71     fixedpoint = None
72
73
74 import dejavu
75 from dejavu import codewalk, errors
76
77
78 # ---------------------------- TYPE ADAPTERS ---------------------------- #
79
80
def getCoerceName(pytype):
    """Return the name fragment used in coercion method names for pytype.
    
    Types from the __builtin__ module use their bare name ('int'); any
    other type is prefixed by its module name. Dots are turned into
    underscores, e.g. datetime.datetime -> 'datetime_datetime'.
    """
    parts = [pytype.__name__]
    if pytype.__module__ != "__builtin__":
        parts.insert(0, pytype.__module__)
    return "_".join(parts).replace(".", "_")
90
def getCoerceMethod(adapter, value, totype, fromtype):
    """Return the coercion method for a given value [by type].
    
    Possible coercion methods are searched in the following order:
      1. Exact match:    coerce       <fromtype> to <totype>
      2. Exact totype:   coerce              any to <totype>
      3. Exact fromtype: coerce       <fromtype> to any
      4. totype.bases:   coerce       <fromtype> to <totype.base1>
                         coerce              any to <totype.base1>
                         coerce       <fromtype> to <totype.base2>...
      5. fromtype.bases: coerce <fromtype.base1> to <totype>
                         coerce <fromtype.base1> to any
                         coerce <fromtype.base2> to <totype>...
    
    totype and fromtype may each be a type object or a plain string (for
    example a database type name). Strings are used verbatim and have no
    bases to fall back on. An empty string acts as "unconstrained", so only
    the "any" forms are tried for that side. Only direct __bases__ are
    consulted, not the full MRO. The value argument is not inspected.
    
    If no matching coercion method is found, a TypeError is raised that
    lists every method name tried.
    """
    def split(t):
        # -> (method-name fragment, direct base classes to fall back on)
        if isinstance(t, str):
            return t, ()
        return getCoerceName(t), t.__bases__
    
    fromname, frombases = split(fromtype)
    toname, tobases = split(totype)
    
    def candidates():
        # Steps 1-3: exact names.
        if fromname and toname:
            yield "coerce_" + fromname + "_to_" + toname
        if toname:
            yield "coerce_any_to_" + toname
        if fromname:
            yield "coerce_" + fromname + "_to_any"
        # Step 4: direct bases of the target type.
        for base in tobases:
            basename = getCoerceName(base)
            if fromname:
                yield "coerce_" + fromname + "_to_" + basename
            yield "coerce_any_to_" + basename
        # Step 5: direct bases of the source type.
        for base in frombases:
            basename = getCoerceName(base)
            if toname:
                yield "coerce_" + basename + "_to_" + toname
            yield "coerce_" + basename + "_to_any"
    
    tried = []
    for name in candidates():
        tried.append(name)
        if hasattr(adapter, name):
            return getattr(adapter, name)
    
    raise TypeError("%s -> %s is not handled by %s.  Looked for: %s" %
                    (fromname, toname, adapter.__class__, ", ".join(tried)))
157
158
class AdapterToSQL(object):
    """Coerce Expression constants to SQL.
    
    This base class is designed to work out-of-the-box with PostgreSQL 8.
    
    coerce() selects a coerce_<fromtype>_to_<dbtype> method through
    getCoerceMethod and returns the SQL literal text that method produces.
    """
    
    # You should REALLY check into your DB's encoding and override this.
    encoding = 'utf8'
    
    # Notice these are ordered pairs. Escape \ before introducing new ones.
    # Values in these two lists should be strings encoded with self.encoding.
    escapes = [("'", "''"), ("\\", r"\\")]
    like_escapes = [("%", r"\%"), ("_", r"\_")]
    
    # These are not the same as coerce_bool (which is used on one side of
    # a comparison). Instead, these are used when the whole (sub)expression
    # is True or False, e.g. "WHERE TRUE", or "WHERE TRUE and 'a'.'b' = 3".
    bool_true = "TRUE"
    bool_false = "FALSE"
    
    def escape_like(self, value):
        """Prepare a string value for use in a LIKE comparison.
        
        value is normally an SQL literal already produced by coerce(),
        e.g. "'abc'". One enclosing pair of quote-marks is removed and the
        LIKE wildcards listed in like_escapes are escaped. Callers wrap the
        result in their own quotes and % wildcards.
        """
        if not isinstance(value, str):
            value = value.encode(self.encoding)
        # Remove exactly one enclosing pair of quote-marks. str.strip would
        # also eat escaped quotes that belong to the data: the literal for
        # the string abc' is 'abc''' and must become abc'' (not abc), and
        # the literal for a lone quote ('''') must not collapse to "".
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
            value = value[1:-1]
        for pat, repl in self.like_escapes:
            value = value.replace(pat, repl)
        return value
    
    def coerce(self, value, dbtype="", pytype=None):
        """Return value, coerced from (optional pytype) to dbtype.
        
        Any parenthesized suffix of dbtype (e.g. the "(255)" of
        "VARCHAR(255)") is dropped before the method lookup.
        Raises TypeError if no coercion method matches.
        """
        if pytype is None:
            pytype = type(value)
        if "(" in dbtype:
            dbtype = dbtype[:dbtype.find("(")]
        meth = getCoerceMethod(self, value, dbtype, pytype)
        return meth(value)
    
    def coerce_NoneType_to_any(self, value):
        return "NULL"
    
    def coerce_bool_to_any(self, value):
        if value:
            return 'TRUE'
        return 'FALSE'
    
    # The great thing about these 3 date coercers is that you can use
    # them with (VAR)CHAR columns just as well as with DATETIME, etc.
    # and comparisons will still work!
    # Note that sub-second precision is not emitted.
    def coerce_datetime_datetime_to_any(self, value):
        return ("'%04d-%02d-%02d %02d:%02d:%02d'" %
                (value.year, value.month, value.day,
                 value.hour, value.minute, value.second))
    
    def coerce_datetime_date_to_any(self, value):
        return "'%04d-%02d-%02d'" % (value.year, value.month, value.day)
    
    def coerce_datetime_time_to_any(self, value):
        return "'%02d:%02d:%02d'" % (value.hour, value.minute, value.second)
    
    def coerce_datetime_timedelta_to_any(self, value):
        # Stored as a float number of days (microseconds are dropped).
        float_val = value.days + (value.seconds / 86400.0)
        return repr(float_val)
    
    def _to_TEXT(self, value):
        return "'%s'" % str(value)
    
    coerce_decimal_to_any = str
    coerce_decimal_Decimal_to_any = str
    coerce_decimal_to_TEXT = _to_TEXT
    coerce_decimal_Decimal_to_TEXT = _to_TEXT
    
    def do_pickle(self, value):
        # dumps with protocol 0 uses the 'raw-unicode-escape' encoding,
        # and we take pains not to re-encode it with self.encoding.
        # We can't use protocol 1 or 2 (which would use UTF-8) because
        # that introduces null bytes into the SQL, which is a no-no.
        value = pickle.dumps(value)
        value = self.coerce_str_to_any(value, skip_encoding=True)
        return value
    
    coerce_dict_to_any = do_pickle
    
    coerce_fixedpoint_FixedPoint_to_any = str
    coerce_fixedpoint_FixedPoint_to_TEXT = _to_TEXT
    
    # Very important we use repr here so we get all 17 decimal digits.
    coerce_float_to_any = repr
    coerce_float_to_TEXT = _to_TEXT
    coerce_int_to_any = str
    coerce_int_to_TEXT = _to_TEXT
    coerce_list_to_any = do_pickle
    coerce_long_to_any = str
    coerce_long_to_TEXT = _to_TEXT
    
    def coerce_str_to_any(self, value, skip_encoding=False):
        """Return value as a quoted SQL string literal.
        
        Non-str values are encoded with self.encoding unless skip_encoding
        is true; each pair in self.escapes is then applied in order.
        """
        if not skip_encoding and not isinstance(value, str):
            value = value.encode(self.encoding)
        for pat, repl in self.escapes:
            value = value.replace(pat, repl)
        return "'" + value + "'"
    
    coerce_tuple_to_any = do_pickle
    
    coerce_unicode_to_any = coerce_str_to_any
265
266
class AdapterFromDB(object):
    """Coerce incoming values from DB types to Python datatypes.
    
    This base class is designed to work out-of-the-box with PostgreSQL 8.
    
    coerce() selects a coerce_<dbtype>_to_<pytype> (or ..._any_to_...)
    method through getCoerceMethod. The date/time coercers below slice
    fixed positions, so they assume ISO-formatted strings from the driver.
    """
    
    # You should REALLY check into your DB's encoding and override this.
    encoding = 'utf8'
    
    def coerce(self, value, dbtype, pytype):
        """Return value, coerced from dbtype to pytype.
        
        Any parenthesized suffix of dbtype is dropped before the method
        lookup. Raises TypeError if no coercion method matches.
        """
        # All columns could conceivably hold NULL => Python None
        if value is None:
            return None
        
        if "(" in dbtype:
            dbtype = dbtype[:dbtype.find("(")]
        
        meth = getCoerceMethod(self, value, pytype, dbtype)
        return meth(value)
    
    def do_pickle(self, value):
        # Coerce to str for pickle.loads restriction.
        # NOTE(review): pickle.loads executes arbitrary code embedded in the
        # data; this is only safe as long as nothing untrusted can write to
        # these columns.
        value = self.coerce_any_to_str(value)
        return pickle.loads(value)
    
    coerce_any_to_bool = bool
    
    def coerce_any_to_datetime_datetime(self, value):
        """Parse 'YYYY-MM-DD HH:MM:SS[.ffffff]' into a datetime."""
        # Databases trim trailing zeros from fractional seconds (".12"
        # means 120000 microseconds), so right-pad the digits to 6.
        fraction = value[20:26]
        if fraction:
            microsecond = int(fraction.ljust(6, "0"))
        else:
            microsecond = 0
        chunks = (value[0:4], value[5:7], value[8:10],
                  value[11:13], value[14:16], value[17:19])
        args = [int(chunk) for chunk in chunks]
        args.append(microsecond)
        return datetime.datetime(*args)
    
    def coerce_any_to_datetime_date(self, value):
        chunks = (value[0:4], value[5:7], value[8:10])
        return datetime.date(*map(int, chunks))
    
    def coerce_any_to_datetime_time(self, value):
        chunks = (value[0:2], value[3:5], value[6:8])
        return datetime.time(*map(int, chunks))
    
    def coerce_any_to_datetime_timedelta(self, value):
        """Convert a (float) number of days into a timedelta."""
        days, fraction = divmod(value, 1)
        # Round rather than truncate: fraction * 86400 may come out as
        # e.g. 6.9999999 for 7 seconds, and int() would lose a second.
        return datetime.timedelta(days, int(round(fraction * 86400)))
    
    def coerce_any_to_decimal(self, value):
        # Only reachable if a builtin type named 'decimal' exists, in which
        # case the module-level name 'decimal' refers to that type.
        return decimal(str(value))
    
    def coerce_any_to_decimal_Decimal(self, value):
        return decimal.Decimal(str(value))
    
    coerce_any_to_dict = do_pickle
    
    def coerce_any_to_fixedpoint_FixedPoint(self, value):
        if isinstance(value, basestring):
            # Unicode really screws up fixedpoint; for example:
            # >>> fixedpoint.FixedPoint(u'111111111111111111111111111.1')
            # FixedPoint('111111111111111104952008704.00', 2)
            value = str(value)
            
            # Preserve the number of fractional digits as the scale.
            scale = 0
            atoms = value.rsplit(".", 1)
            if len(atoms) > 1:
                scale = len(atoms[-1])
            return fixedpoint.FixedPoint(value, scale)
        else:
            return fixedpoint.FixedPoint(value)
    
    coerce_any_to_float = float
    coerce_any_to_int = int
    coerce_any_to_list = do_pickle
    coerce_any_to_long = long
    
    def coerce_any_to_str(self, value):
        if isinstance(value, unicode):
            return value.encode(self.encoding)
        else:
            return str(value)
    
    coerce_any_to_tuple = do_pickle
    
    def coerce_any_to_unicode(self, value):
        if isinstance(value, unicode):
            return value
        else:
            return unicode(value, self.encoding)
354
355
class TypeAdapter(object):
    """Pick a database column type for a column object and a Python type.
    
    Each coerce_<typename> method takes a column object, reads its .hints
    dict (and .name for numeric types), and returns a type declaration
    string. The defaults target PostgreSQL 8.
    """
    
    # Largest binary precision the backend's floating-point columns offer
    # (53 for PostgreSQL 8, whose DOUBLE PRECISION has 53 binary digits).
    # Python floats are C doubles; their precision is platform-dependent
    # but usually 53 binary digits as well (see maxfloat_digits).
    float_max_precision = 53
    
    # Largest decimal precision a NUMERIC column accepts (1000 on PostgreSQL 8).
    numeric_max_precision = 1000
    
    # PostgreSQL: "The actual storage requirement is two bytes for each
    # group of four decimal digits, plus eight bytes overhead." The
    # overhead is deliberately left out of this figure.
    numeric_max_bytes = 500
    
    def coerce(self, col, pytype):
        """Return a database type for the given column object and Python type."""
        missing = object()
        method = getattr(self, "coerce_" + getCoerceName(pytype), missing)
        if method is missing:
            raise TypeError("'%s' is not handled by %s." %
                            (pytype, self.__class__))
        return method(col)
    
    def float_type(self, precision):
        """Return a datatype which can handle floats of the given binary precision."""
        if precision > 24:
            return "DOUBLE PRECISION"
        return "REAL"
    
    def coerce_float(self, col):
        # The 'precision' hint counts binary digits, not decimal ones.
        precision = int(col.hints.get('precision', maxfloat_digits))
        if precision <= self.float_max_precision:
            return self.float_type(precision)
        return "TEXT"
    
    def coerce_str(self, col):
        # The bytes hint does not include varchar's usual 4-byte base.
        nbytes = int(col.hints.get('bytes', 255))
        if not nbytes or nbytes > 255:
            # TEXT is not an SQL standard, but it's common.
            return "TEXT"
        return "VARCHAR(%s)" % nbytes
    
    # Pickled containers and unicode text are all stored like str.
    def coerce_dict(self, col):
        return self.coerce_str(col)
    
    def coerce_list(self, col):
        return self.coerce_str(col)
    
    def coerce_tuple(self, col):
        return self.coerce_str(col)
    
    def coerce_unicode(self, col):
        return self.coerce_str(col)
    
    def coerce_bool(self, col):
        return "BOOLEAN"
    
    def coerce_datetime_datetime(self, col):
        return "TIMESTAMP"
    
    def coerce_datetime_date(self, col):
        return "DATE"
    
    def coerce_datetime_time(self, col):
        return "TIME"
    
    def coerce_datetime_timedelta(self, col):
        # Stored as a float number of days rather than an INTERVAL, which
        # would need a parser on the way back in.
        return self.coerce_float(col)
    
    def decimal_type(self, colname, precision, scale):
        """Return a NUMERIC declaration, or TEXT (with a warning) if
        precision exceeds numeric_max_precision. scale is capped at
        precision."""
        if precision > self.numeric_max_precision:
            warnings.warn("The precision of '%s' (%s) is greater than the "
                          "maximum numeric precision (%s). Using TEXT instead."
                          % (colname, precision, self.numeric_max_precision),
                          errors.StorageWarning)
            return "TEXT"
        return "NUMERIC(%s, %s)" % (precision, min(scale, precision))
    
    def _numeric_from_hints(self, col):
        # Shared by decimal and fixedpoint columns. Default scale is 2, on
        # the assumption that these types mostly hold money.
        precision = int(col.hints.get('precision', self.numeric_max_precision))
        scale = int(col.hints.get('scale', 2))
        return self.decimal_type(col.name, precision, scale)
    
    def coerce_decimal_Decimal(self, col):
        return self._numeric_from_hints(col)
    
    def coerce_decimal(self, col):
        # For a possible future builtin 'decimal' type (Python 2.5?).
        return self.coerce_decimal_Decimal(col)
    
    def coerce_fixedpoint_FixedPoint(self, col):
        # fixedpoint itself has no theoretical precision limit.
        return self._numeric_from_hints(col)
    
    def int_type(self, bytes):
        """Return a datatype which can handle the given number of bytes."""
        for limit, typename in ((2, "SMALLINT"), (4, "INTEGER"), (8, "BIGINT")):
            if bytes <= limit:
                return typename
        # Wider than BIGINT (usually 8 bytes): use NUMERIC. For PostgreSQL,
        # "The actual storage requirement is two bytes for each group of
        # four decimal digits, plus eight bytes overhead." The overhead is
        # left out of this calculation.
        return "NUMERIC(%s, 0)" % (bytes * 2)
    
    def coerce_long(self, col):
        nbytes = int(col.hints.get('bytes', self.numeric_max_bytes))
        if nbytes <= self.numeric_max_bytes:
            return self.int_type(nbytes)
        return "TEXT"
    
    def coerce_int(self, col):
        nbytes = int(col.hints.get('bytes', maxint_bytes))
        if nbytes <= maxint_bytes:
            return self.int_type(nbytes)
        return self.coerce_long(col)
481
482
483
484 # -------------------------- SQL DECOMPILATION -------------------------- #
485
486
class ConstWrapper(str):
    """Wraps a constant for use in SQLDecompiler's stack.
    
    When we hit LOAD_CONST while decompiling, we occasionally need to keep
    both the base and the coerced value around (see COMPARE_OP for use
    of ConstWrapper.basevalue).
    
    The string value is the coerced SQL text; .basevalue holds the
    original Python object.
    """
    def __new__(cls, basevalue, coerced_value):
        # Construct through cls (not a hard-coded ConstWrapper) so that
        # subclasses produce instances of themselves.
        newobj = str.__new__(cls, coerced_value)
        newobj.basevalue = basevalue
        return newobj
498
499
# Stack sentinels: marker objects that SQLDecompiler pushes in place of SQL
# text and later recognizes by identity (with 'is').
class Sentinel(object):
    """A named, identity-compared marker object."""
    
    def __repr__(self):
        return 'Stack Sentinel: %s' % self.name
    
    def __init__(self, name):
        self.name = name

# Pushed for any lambda local that is not a positional argument (i.e. the
# **kw mapping); subscripting anything else is rejected.
kw_arg = Sentinel('Keyword Arg')

# cannot_represent exists so that a portion of an Expression can be
# labeled imperfect. For example, the function dejavu.iscurrentweek
# rarely has an SQL equivalent. All Units (which match the rest of the
# Expression) will be recalled; they can then be compared in expr(unit).
cannot_represent = Sentinel('Cannot Repr')
515
516
class SQLDecompiler(codewalk.LambdaDecompiler):
    """SQLDecompiler(tables, expr, adapter=AdapterToSQL()).
    
    Produce SQL from a supplied Expression object, with a lambda of the form:
        lambda x, **kw: ...
    
    Attributes of each argument in the signature will be mapped to table
    columns. Keyword arguments should be bound using Expression.bind_args
    before calling this decompiler.
    
    Each visit_<OPCODE> method replays one bytecode instruction of
    expr.func, pushing SQL fragments (strings), ConstWrapper constants or
    sentinels onto self.stack; call code() for the finished SQL. Parts
    with no SQL equivalent become cannot_represent (rendered as TRUE) and
    set self.imperfect, so matching rows must still be re-checked
    against expr.
    """
    
    # Some constants are function or class objects,
    # which should not be coerced.
    no_coerce = (FunctionType,
                 type,
                 type(len),       # <type 'builtin_function_or_method'>
                 )
    
    # Indexed by the COMPARE_OP argument, which follows the order of
    # dis.cmp_op ('<', '<=', '==', '!=', '>', '>=', 'in', 'not in', ...).
    sql_cmp_op = ('<', '<=', '=', '!=', '>', '>=', 'in', 'not in')
    
    # NOTE: the default adapter instance is created once, when the class is
    # defined, and shared by every decompiler that doesn't pass its own.
    def __init__(self, tables, expr, adapter=AdapterToSQL()):
        self.tables = tables
        self.expr = expr
        self.adapter = adapter
        # Cache coerced booleans
        self.T = adapter.coerce_bool_to_any(True)
        self.F = adapter.coerce_bool_to_any(False)
        obj = expr.func
        codewalk.LambdaDecompiler.__init__(self, obj)
    
    def code(self):
        """Walk expr's bytecode and return the resulting SQL string.
        
        An unrepresentable or constant-boolean result is normalized to the
        adapter's bool_true / bool_false text.
        """
        self.imperfect = False
        self.walk()
        # After walk(), self.stack should be reduced to a single string,
        # which is the SQL representation of our Expression.
        result = self.stack[0]
        if result is cannot_represent:
            # The entire expression could not be evaluated.
            result = self.adapter.bool_true
        if result == self.T:
            result = self.adapter.bool_true
        if result == self.F:
            result = self.adapter.bool_false
        return result
    
    def visit_instruction(self, op, lo=None, hi=None):
        """Combine pending and/or terms that end at this instruction.
        
        self.targets maps an instruction offset to a list of
        (term, operator) pairs left by earlier JUMP_* handlers (presumably
        in the codewalk base class). On reaching that offset, each pending
        term is folded into the top of stack as "(term) OPER (clause)".
        """
        # Get the instruction pointer for the current instruction.
        # self.cursor appears to already point past it: 3 bytes for an
        # instruction with an argument, 1 byte without.
        ip = self.cursor - 3
        if hi is None:
            ip += 1
            if lo is None:
                ip += 1
        
        terms = self.targets.get(ip)
        if terms:
            trueval = self.adapter.bool_true
            falseval = self.adapter.bool_false
            clause = self.stack[-1]
            while terms:
                term, oper = terms.pop()
                if term is cannot_represent:
                    # Use TRUE for the term, so all records are returned.
                    term = trueval
                if clause is cannot_represent:
                    # Use TRUE for the clause, so all records are returned.
                    clause = trueval
                
                # Blurg. SQL Server is *so* picky.
                if term == self.T:
                    term = trueval
                elif term == self.F:
                    term = falseval
                if clause == self.T:
                    clause = trueval
                elif clause == self.F:
                    clause = falseval
                
                clause = "(%s) %s (%s)" % (term, oper.upper(), clause)
            
            # Replace TOS with the new clause, so that further
            # combinations have access to it.
            self.stack[-1] = clause
            self.debug("clause:", clause, "\n")
            
            if op == 1:
                # Py2.4: The current instruction is POP_TOP, which means
                # the previous is probably JUMP_*. If so, we're going to
                # pop the value we just placed on the stack and lose it.
                # We need to replace the entry that the JUMP_* made in
                # self.targets with our new TOS.
                target = self.targets[self.last_target_ip]
                target[-1] = ((clause, target[-1][1]))
                self.debug("newtarget:", self.last_target_ip, target)
    
    def visit_LOAD_DEREF(self, lo, hi):
        raise ValueError("Illegal reference found in %s." % self.expr)
    
    def visit_LOAD_GLOBAL(self, lo, hi):
        raise ValueError("Illegal global found in %s." % self.expr)
    
    def visit_LOAD_FAST(self, lo, hi):
        arg_index = lo + (hi << 8)
        if arg_index < self.co_argcount:
            # We've hit a reference to a positional arg, which in our
            # case implies a reference to a DB table.
            self.stack.append(self.tables[arg_index])
        else:
            # Since lambdas don't support local bindings,
            # any remaining local name must be a keyword arg.
            self.stack.append(kw_arg)
    
    def visit_LOAD_ATTR(self, lo, hi):
        name = self.co_names[lo + (hi << 8)]
        tos = self.stack.pop()
        if isinstance(tos, tuple):
            # The name in question refers to a DB column.
            tablename, table = tos
            col = table.columns[name]
            if col.imperfect_type:
                atom = cannot_represent
                self.imperfect = True
            else:
                atom = '%s.%s' % (tablename, col.qname)
        else:
            # 'tos.name' will reference an attribute of the tos object.
            # Stick the tos and name in a tuple for later processing.
            atom = (tos, name)
        self.stack.append(atom)
    
    def visit_LOAD_CONST(self, lo, hi):
        val = self.co_consts[lo + (hi << 8)]
        if not isinstance(val, self.no_coerce):
            val = ConstWrapper(val, self.adapter.coerce(val))
        self.stack.append(val)
    
    def visit_BUILD_TUPLE(self, lo, hi):
        # The item count is split across lo/hi bytes: recombine it as
        # lo + (hi << 8). Without the parentheses '+' binds tighter than
        # '<<', which would shift the whole count left by 8 bits.
        count = lo + (hi << 8)
        items = [self.stack.pop() for i in range(count)]
        # Popping yields the items last-first; restore source order.
        items.reverse()
        self.stack.append("(" + ", ".join(items) + ")")
    
    visit_BUILD_LIST = visit_BUILD_TUPLE
    
    def visit_CALL_FUNCTION(self, lo, hi):
        # lo = number of positional args, hi = number of keyword args;
        # keyword args sit on top of the stack as (key, value) pairs.
        kwargs = {}
        for i in xrange(hi):
            val = self.stack.pop()
            key = self.stack.pop()
            kwargs[key] = val
        kwargs = [k + "=" + v for k, v in kwargs.iteritems()]
        
        args = []
        for i in xrange(lo):
            arg = self.stack.pop()
            args.append(arg)
        args.reverse()
        
        if kwargs:
            args += kwargs
        
        func = self.stack.pop()
        
        # Handle function objects.
        if isinstance(func, tuple):
            # Method call on something, e.g. x.Name.startswith('a'):
            # dispatch to attr_<methodname>.
            tos, name = func
            dispatch = getattr(self, "attr_" + name, None)
            if dispatch:
                self.stack.append(dispatch(tos, *args))
                return
        else:
            # Plain function: dispatch to <module>_<name>, with dots turned
            # into underscores and a "func" prefix for leading underscores
            # (e.g. builtin len -> func__builtin___len).
            funcname = func.__module__ + "_" + func.__name__
            funcname = funcname.replace(".", "_")
            if funcname.startswith("_"):
                funcname = "func" + funcname
            dispatch = getattr(self, funcname, None)
            if dispatch:
                self.stack.append(dispatch(*args))
                return
        
        # No SQL translation available for this call.
        self.stack.append(cannot_represent)
        self.imperfect = True
    
    def visit_COMPARE_OP(self, lo, hi):
        op2, op1 = self.stack.pop(), self.stack.pop()
        if op1 is cannot_represent or op2 is cannot_represent:
            self.stack.append(cannot_represent)
            return
        
        op = lo + (hi << 8)
        if op in (6, 7):     # in, not in
            value = self.containedby(op1, op2)
            if op == 7:
                value = "NOT " + value
            self.stack.append(value)
        elif op1 == 'NULL':
            if op in (2, 8):    # '==', is
                self.stack.append(op2 + " IS NULL")
            elif op in (3, 9):  # '!=', 'is not'
                self.stack.append(op2 + " IS NOT NULL")
            else:
                raise ValueError("Non-equality Null comparisons not allowed.")
        elif op2 == 'NULL':
            if op in (2, 8):    # '==', 'is'
                self.stack.append(op1 + " IS NULL")
            elif op in (3, 9):  # '!=', 'is not'
                self.stack.append(op1 + " IS NOT NULL")
            else:
                raise ValueError("Non-equality Null comparisons not allowed.")
        else:
            # Comparison operators for strings are case-sensitive in PG et al.
            self.stack.append(op1 + " " + self.sql_cmp_op[op] + " " + op2)
    
    def binary_op(self, op):
        op2, op1 = self.stack.pop(), self.stack.pop()
        self.stack.append(op1 + " " + op + " " + op2)
    
    def visit_BINARY_SUBSCR(self):
        # The only BINARY_SUBSCR used in Expressions should be kwargs[key].
        name = self.stack.pop()
        tos = self.stack.pop()
        if tos is not kw_arg:
            raise ValueError("Subscript %s of %s object not allowed."
                             % (name, tos))
        # name, since formed in LOAD_CONST, may have extraneous quotes.
        name = name.strip("'\"")
        value = self.expr.kwargs[name]
        if not isinstance(value, self.no_coerce):
            value = ConstWrapper(value, self.adapter.coerce(value))
        self.stack.append(value)
    
    def visit_UNARY_NOT(self):
        op = self.stack.pop()
        if op is cannot_represent:
            self.stack.append(cannot_represent)
        else:
            self.stack.append("NOT (" + op + ")")
    
    # --------------------------- Dispatchees --------------------------- #
    
    def attr_startswith(self, tos, arg):
        return tos + " LIKE '" + self.adapter.escape_like(arg) + "%'"
    
    def attr_endswith(self, tos, arg):
        return tos + " LIKE '%" + self.adapter.escape_like(arg) + "'"
    
    def containedby(self, op1, op2):
        if isinstance(op1, ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            return op2 + " LIKE '%" + self.adapter.escape_like(op1) + "%'"
        else:
            # Looking for field in (a, b, c)
            atoms = [self.adapter.coerce(x) for x in op2.basevalue]
            return op1 + " IN (" + ", ".join(atoms) + ")"
    
    def dejavu_icontainedby(self, op1, op2):
        if isinstance(op1, ConstWrapper):
            # Looking for text in a field. Use Like (reverse terms).
            return ("LOWER(" + op2 + ") LIKE '%" +
                    self.adapter.escape_like(op1).lower() + "%'")
        else:
            # Looking for field in (a, b, c).
            # Force all args to lowercase for case-insensitive comparison.
            atoms = [self.adapter.coerce(x).lower() for x in op2.basevalue]
            return "LOWER(%s) IN (%s)" % (op1, ", ".join(atoms))
    
    def dejavu_icontains(self, x, y):
        return self.dejavu_icontainedby(y, x)
    
    def dejavu_istartswith(self, x, y):
        # Lowercase the pattern too; otherwise LOWER(x) can never match a
        # pattern containing uppercase letters.
        return ("LOWER(" + x + ") LIKE '" +
                self.adapter.escape_like(y).lower() + "%'")
    
    def dejavu_iendswith(self, x, y):
        # Lowercase the pattern too, as in dejavu_istartswith.
        return ("LOWER(" + x + ") LIKE '%" +
                self.adapter.escape_like(y).lower() + "'")
    
    def dejavu_ieq(self, x, y):
        return "LOWER(" + x + ") = LOWER(" + y + ")"
    
    def dejavu_now(self):
        return "NOW()"
    
    def dejavu_today(self):
        return "CURRENT_DATE"
    
    def dejavu_year(self, x):
        return "YEAR(" + x + ")"
    
    def dejavu_month(self, x):
        return "MONTH(" + x + ")"
    
    def dejavu_day(self, x):
        return "DAY(" + x + ")"
    
    def func__builtin___len(self, x):
        return "LENGTH(" + x + ")"
809
810
811
812 # ------------------------- Connection Factories ------------------------- #
813
814
class ConnectionWrapper(object):
    """Thin proxy around a connection so it can be tracked by weak reference.

    Any attribute not found on the wrapper itself is looked up on the
    wrapped connection.
    """

    def __init__(self, conn=None):
        # The wrapped (raw) connection.
        self.conn = conn

    def __getattr__(self, name):
        # Only reached when normal lookup fails, i.e. for anything but 'conn'.
        target = self.conn
        return getattr(target, name)
823
824
class OutOfConnectionsError(errors.DejavuError):
    """Raised when no database connection could be obtained.

    Connection factories raise this once every retry has failed; a
    connection-opening callable may raise it to ask the factory to retry.
    """
828
829
class ConnectionFactory(object):
    """A connection factory which creates a new connection for each request.

    open: a callable returning a new raw connection. It may raise
        OutOfConnectionsError, in which case the factory retries.
    close: a callable which closes a raw connection.
    retry: how many times to call open() before giving up. After failed
        attempt number i (counting from 0) the factory sleeps i + 1 seconds.

    Each connection is handed out inside a ConnectionWrapper. When that
    wrapper is garbage-collected, a weakref callback closes the raw
    connection.
    """

    def __init__(self, open, close, retry=5):
        self.open = open
        self.close = close
        self.retry = retry
        # {weakref to wrapper: raw connection} for connections handed out.
        self.refs = {}

    def __call__(self):
        """Return a new (wrapped) connection.

        Raises OutOfConnectionsError if every attempt fails.
        """
        for attempt in range(self.retry):
            try:
                conn = self.open()
            except OutOfConnectionsError:
                # Back off a little longer after each failed attempt.
                time.sleep(attempt + 1)
                continue
            w = ConnectionWrapper(conn)
            self.refs[weakref.ref(w, self._release)] = conn
            return w
        raise OutOfConnectionsError()

    def _release(self, ref):
        """Close the connection whose wrapper has been garbage-collected."""
        try:
            conn = self.refs.pop(ref)
        except KeyError:
            # shutdown() has already closed this connection. Raising here
            # would only produce an "Exception ignored" weakref warning.
            return
        self.close(conn)

    def shutdown(self):
        """Release all database connections."""
        # Empty self.refs.
        while self.refs:
            ref, conn = self.refs.popitem()
            self.close(conn)
862
863
class ConnectionPool(object):
    """A database connection factory which keeps a pool of connections.

    open / close: callables which open and close a raw connection. open()
        may raise OutOfConnectionsError, in which case the pool retries.
    size: the maximum number of idle connections kept for reuse. A
        released connection that does not fit in the pool is closed.
    retry: how many times to try obtaining a connection before raising
        OutOfConnectionsError. After failed attempt number i (from 0) the
        pool sleeps i + 1 seconds.

    Connections are handed out inside a ConnectionWrapper. When the wrapper
    is garbage-collected, a weakref callback returns the raw connection to
    the pool.
    """

    def __init__(self, open, close, size=10, retry=5):
        self.open = open
        self.close = close
        # {weakref to wrapper: raw connection} for connections handed out.
        self.refs = {}
        # Idle raw connections waiting to be reused.
        self.pool = Queue.Queue(size)
        self.retry = retry

    def _wrap(self, conn):
        """Wrap conn and arrange for _release to run when the wrapper dies."""
        # Okay, this is freaky. If we wrap here, all goes well.
        # If we wrap on Queue.put(), mysql crashes after 1700
        # or so inserts (when migrating Access tables to MySQL).
        # Go figure.
        w = ConnectionWrapper(conn)
        self.refs[weakref.ref(w, self._release)] = conn
        return w

    def __call__(self):
        """Return a (wrapped) connection, reusing an idle one if possible.

        Raises OutOfConnectionsError if every attempt fails.
        """
        for attempt in range(self.retry):
            try:
                conn = self.pool.get_nowait()
            except Queue.Empty:
                pass
            else:
                return self._wrap(conn)

            try:
                conn = self.open()
            except OutOfConnectionsError:
                # Back off a little longer after each failed attempt.
                time.sleep(attempt + 1)
                continue
            return self._wrap(conn)
        raise OutOfConnectionsError()

    def _release(self, ref):
        """Return a collected wrapper's connection to the pool, or close it."""
        try:
            conn = self.refs.pop(ref)
        except KeyError:
            # shutdown() has already closed this connection.
            return
        try:
            self.pool.put_nowait(conn)
        except Queue.Full:
            # The pool is at capacity; this connection is surplus.
            self.close(conn)

    def shutdown(self):
        """Close all database connections, both idle and checked out."""
        # Drain the pool, closing each idle connection so it isn't leaked.
        while True:
            try:
                conn = self.pool.get_nowait()
            except Queue.Empty:
                break
            self.close(conn)

        # Empty self.refs.
        while self.refs:
            ref, conn = self.refs.popitem()
            self.close(conn)
922
923
class SingleConnection(object):
    """A connection factory that hands every consumer the same connection.

    Use this when your database cannot handle multiple connections at once,
    but can handle multiple threads using the same connection.
    """

    def __init__(self, open, close):
        self.open, self.close = open, close
        # Opened lazily on first call: the storage manager may still need
        # to create the database before anything can connect to it.
        self._conn = None

    def __call__(self):
        """Return the shared connection, opening it on first use."""
        conn = self._conn
        if conn is None:
            conn = self._conn = self.open()
        return conn

    def shutdown(self):
        """Close the shared connection, if one is open."""
        if self._conn is None:
            return
        self.close(self._conn)
        self._conn = None
949
950
951
952 # -------------------------- DATABASE OBJECTS -------------------------- #
953
954
class Index:
    """An index on a table column (or columns) in a database.

    name / qname: the unquoted and quoted SQL name of the index.
    tablename: the SQL name of the indexed table.
    colname: the SQL name of the indexed column.
    unique: a flag stored with the index (defaults to True).
    """

    def __init__(self, name, qname, tablename, colname, unique=True):
        self.name, self.qname = name, qname
        self.tablename, self.colname = tablename, colname
        self.unique = unique

    def __repr__(self):
        positional = (self.name, self.qname, self.tablename, self.colname)
        prefix = (self.__module__, self.__class__.__name__)
        return ("%s.%s(%r, %r, %r, %r, unique=%r)"
                % (prefix + positional + (self.unique,)))

    def __copy__(self):
        """Return a new Index of the same class with the same attributes."""
        cls = self.__class__
        return cls(self.name, self.qname, self.tablename,
                   self.colname, self.unique)
    copy = __copy__
975
976
class IndexSet(dict):
    """A dict of {key: Index} for one table.

    Deleting a key drops the index in the database (under db.lock()) and
    removes it from this dict.
    """

    def __new__(cls, table):
        return dict.__new__(cls)

    def __init__(self, table):
        dict.__init__(self)
        self.table = table

    def __delitem__(self, key):
        """Drop the specified index from the database and from this dict."""
        t = self.table
        t.db.lock("Dropping index. Transactions not allowed.")
        try:
            t.db.execute('DROP INDEX %s ON %s;' % (self[key].qname, t.qname))
            # Keep the in-memory metadata in sync with the database;
            # otherwise the dropped index would still be listed here.
            dict.__delitem__(self, key)
        finally:
            t.db.unlock()
994
995
class Column:
    """A column in a table in a database.

    name: the SQL name for this column (unquoted).
    qname: the SQL name for this column (quoted).
    dbtype: the database type name (as used in a CREATE TABLE statement).
    default: default value for this column for new rows.
    hints: a dict of implementation hints, such as precision, scale, or bytes.
    key: True if this column is part of the table's primary key.

    imperfect_type: if True, signals that we are deliberately using a
        database type other than the default (usually in order to handle
        irregular values, such as huge numbers).
    autoincrement: if True, uses the database's built-in sequencing.
    """

    def __init__(self, name, qname, dbtype, default=None, hints=None, key=False):
        self.name = name
        self.qname = qname

        self.dbtype = dbtype
        self.imperfect_type = False

        self.default = default
        if hints is None:
            hints = {}
        self.hints = hints
        # If autoincrement, the initial value should be put in self.default.
        self.autoincrement = False

        self.key = key

    def __repr__(self):
        return ("%s.%s(%r, %r, dbtype=%r, default=%r, hints=%r, key=%r)" %
                (self.__module__, self.__class__.__name__,
                 self.name, self.qname, self.dbtype,
                 self.default, self.hints, self.key)
                )

    def __copy__(self):
        """Return a copy of this column. The hints dict is copied, not shared.

        Flags that __init__ does not accept (imperfect_type, autoincrement),
        and any implementation-specific attributes added later (for example
        a sequence name), are carried over as well. Copies made when
        renaming a column therefore keep them.
        """
        newcol = self.__class__(self.name, self.qname, self.dbtype,
                                self.default, self.hints.copy(), self.key)
        newcol.imperfect_type = self.imperfect_type
        newcol.autoincrement = self.autoincrement
        for attr, value in list(self.__dict__.items()):
            # Don't clobber anything __init__ already set (e.g. the new hints).
            if attr not in newcol.__dict__:
                setattr(newcol, attr, value)
        return newcol
    copy = __copy__
1039
1040
class ColumnSet(dict):
    """A dict of {key: Column} for one table, kept in sync with the database.

    Setting a key adds the column (ALTER TABLE ... ADD COLUMN), deleting a
    key drops it (and any index stored under the same key), and rename()
    renames it. Each database change runs under table.db.lock().
    """

    indexsetclass = IndexSet

    def __new__(cls, table):
        return dict.__new__(cls)

    def __init__(self, table):
        dict.__init__(self)
        self.table = table
        self.indices = self.indexsetclass(self.table)

    def __setitem__(self, key, column):
        """Add column to the table, replacing any column already at key."""
        t = self.table
        if key in self:
            del self[key]

        # Compare against None so falsy defaults (0, False, "") still
        # produce a DEFAULT clause instead of being silently dropped.
        default = ""
        if column.default is not None:
            default = " DEFAULT %s" % t.db.adaptertosql.coerce(column.default,
                                                               column.dbtype)

        t.db.lock("Adding property. Transactions not allowed.")
        try:
            t.db.execute("ALTER TABLE %s ADD COLUMN %s %s%s;" %
                         (t.qname, column.qname, column.dbtype, default))
            dict.__setitem__(self, key, column)
        finally:
            t.db.unlock()

    def __delitem__(self, key):
        """Drop the column (and any index stored under the same key)."""
        if key in self.indices:
            del self.indices[key]
        t = self.table
        t.db.lock("Dropping property. Transactions not allowed.")
        try:
            t.db.execute("ALTER TABLE %s DROP COLUMN %s;" %
                               (t.qname, self[key].qname))
            dict.__delitem__(self, key)
        finally:
            t.db.unlock()

    def _rename(self, oldcol, newcol):
        # Override this to do the actual rename at the DB level.
        t = self.table
        t.db.execute("ALTER TABLE %s RENAME COLUMN %s TO %s;" %
                     (t.qname, oldcol.qname, newcol.qname))

    def rename(self, oldkey, newkey):
        """Rename a Column, re-keying it from oldkey to newkey.

        The column is renamed in the database only when its SQL name
        actually changes. Otherwise just the dict key changes.
        """
        oldcol = self[oldkey]
        t = self.table
        newname = t.db.column_name(self.table.name, newkey)

        if oldcol.name == newname:
            # Same SQL name: keep the existing Column object.
            newcol = oldcol
        else:
            newcol = oldcol.copy()
            newcol.name = newname
            newcol.qname = t.db.quote(newname)
            t.db.lock("Renaming property. Transactions not allowed.")
            try:
                self._rename(oldcol, newcol)
            finally:
                t.db.unlock()

        # Use the superclass calls to avoid DROP COLUMN/ADD COLUMN.
        dict.__delitem__(self, oldkey)
        dict.__setitem__(self, newkey, newcol)
1108
1109
class Table(object):
    """A table in a database.

    db: the database for this table.
    name: the SQL name for this table (unquoted).
    qname: the SQL name for this table (quoted).
    columns: a dict of {key: Column object} for this table. The 'key'
        argument should be a name you use for the column in your Python code,
        whereas the Column.name is the SQL name for the column (unquoted).
    """

    def __init__(self, db, name, qname):
        self.db = db
        self.name, self.qname = name, qname
        # The database decides which ColumnSet class to use.
        self.columns = db.columnsetclass(self)

    def __repr__(self):
        return "%s.%s(db, %r, %r)" % (self.__module__, type(self).__name__,
                                      self.name, self.qname)

    def __copy__(self):
        """Return a copy with copied columns and indices, sharing self.db.

        Entries are stored with dict.__setitem__, bypassing any overridden
        __setitem__, so no SQL is issued for the copy.
        """
        clone = self.__class__(self.db, self.name, self.qname)
        columns = clone.columns
        for key, column in self.columns.items():
            dict.__setitem__(columns, key, column.copy())
        indices = columns.indices
        for key, index in self.columns.indices.items():
            dict.__setitem__(indices, key, index.copy())
        return clone
    copy = __copy__
1140
1141
1142
class Database(dict):
    """A dict for managing a set of tables.

    Keys are the consumer-facing ("friendly") names and values are Table
    objects. Setting a key issues CREATE TABLE (plus CREATE INDEX for the
    table's indices); deleting a key issues DROP TABLE. Provider-specific
    subclasses override the connection hooks (_get_conn, _del_conn), the
    introspection hooks (_get_dbinfo, _get_tables, _get_columns,
    _get_indices), _rename, python_type and quote as needed.
    """

    decompiler = SQLDecompiler
    adaptertosql = AdapterToSQL()
    adapterfromdb = AdapterFromDB()
    typeadapter = TypeAdapter()

    columnsetclass = ColumnSet

    def __new__(cls, name, **kwargs):
        return dict.__new__(cls)

    def __init__(self, name, **kwargs):
        """Create a Database for the given (friendly) name.

        Each keyword argument is set as an attribute on self before
        connecting, so class-level settings (e.g. poolsize, implicit_trans,
        Prefix) can be overridden per instance.
        """
        dict.__init__(self)
        for k, v in kwargs.iteritems():
            setattr(self, k, v)

        self.name = self.sql_name(name)
        self.qname = self.quote(self.name)
        # {transaction_key(): connection or TransactionLock sentinel}
        self.transactions = {}
        self.connect()
        self.sync_dbinfo()

    def _get_dbinfo(self, conn=None):
        # Override to return {attribute name: value} read from the database.
        return {}

    def sync_dbinfo(self, conn=None):
        """Set attributes on self with actual DB metadata, where possible."""
        for k, v in self._get_dbinfo(conn).iteritems():
            setattr(self, k, v)

    def _get_tables(self, conn=None):
        raise NotImplementedError

    def _get_columns(self, tablename, conn=None):
        raise NotImplementedError

    def _get_indices(self, tablename, conn=None):
        raise NotImplementedError

    def python_type(self, dbtype):
        """Return a Python type which can store values of the given dbtype.

        The base class knows no types and always raises TypeError.
        """
        raise TypeError("Database type %r could not be converted "
                        "to a Python type." % dbtype)

    def db_type(self, col, pytype):
        """Return a database type which can store values of the given pytype."""
        return self.typeadapter.coerce(col, pytype)

    def isrelatedtype(self, pytype1, pytype2):
        """If values of both types are expressed with the same SQL, return True."""
        if issubclass(pytype1, pytype2) or issubclass(pytype2, pytype1):
            return True
        if issubclass(pytype1, basestring) and issubclass(pytype2, basestring):
            return True
        if ((issubclass(pytype1, int) or issubclass(pytype1, long)) and
            (issubclass(pytype2, int) or issubclass(pytype2, long))):
            return True
        if fixedpoint:
            if decimal:
                if ((issubclass(pytype1, fixedpoint.FixedPoint)
                     or issubclass(pytype1, decimal.Decimal)) and
                    (issubclass(pytype2, fixedpoint.FixedPoint)
                     or issubclass(pytype2, decimal.Decimal))):
                    return True
            else:
                if (issubclass(pytype1, fixedpoint.FixedPoint) and
                    issubclass(pytype2, fixedpoint.FixedPoint)):
                    return True
        else:
            if decimal:
                if (issubclass(pytype1, decimal.Decimal) and
                    issubclass(pytype2, decimal.Decimal)):
                    return True
        return False

    def __setitem__(self, key, table):
        """Create table (and its indices) in the database and store it at key.

        Any table already stored at key is dropped first.
        """
        if key in self:
            del self[key]

        fields = []
        pk = []
        for col in table.columns.itervalues():
            # Compare against None so falsy defaults (0, False, "") still
            # produce a DEFAULT clause instead of being silently dropped.
            default = ""
            if col.default is not None:
                default = " DEFAULT %s" % self.adaptertosql.coerce(col.default,
                                                                   col.dbtype)
            fields.append('%s %s%s' % (col.qname, col.dbtype, default))

            if col.key:
                pk.append(col.qname)

        if pk:
            pk = ", PRIMARY KEY (%s)" % ", ".join(pk)
        else:
            pk = ""

        self.lock("Creating storage. Transactions not allowed.")
        try:
            self.execute('CREATE TABLE %s (%s%s);' %
                         (table.qname, ", ".join(fields), pk))

            for index in table.columns.indices.itervalues():
                self.execute('CREATE INDEX %s ON %s (%s);' %
                             (index.qname, table.qname,
                              self.quote(index.colname)))
            dict.__setitem__(self, key, table)
        finally:
            self.unlock()

    def __delitem__(self, key):
        """Drop the table stored at key from the database and this dict."""
        self.lock("Dropping storage. Transactions not allowed.")
        try:
            self.execute('DROP TABLE %s;' % self[key].qname)
            dict.__delitem__(self, key)
        finally:
            self.unlock()

    def _rename(self, oldtable, newtable):
        # Override this to do the actual rename at the DB level.
        # Both arguments are Table objects; newtable carries the new
        # name and qname.
        raise NotImplementedError

    def rename(self, oldkey, newkey):
        """Rename a Table, re-keying it from oldkey to newkey.

        If the SQL table name changes, _rename(oldtable, newtable) is called
        under lock() to rename it in the database. Otherwise only the dict
        key changes.
        """
        oldtable = self[oldkey]
        newname = self.table_name(newkey)

        if oldtable.name == newname:
            # Same SQL name: keep the existing Table object.
            newtable = oldtable
        else:
            newtable = oldtable.copy()
            newtable.name = newname
            newtable.qname = self.quote(newname)
            self.lock("Renaming storage. Transactions not allowed.")
            try:
                self._rename(oldtable, newtable)
            finally:
                self.unlock()

        # Use the superclass calls to avoid DROP TABLE/CREATE TABLE.
        dict.__delitem__(self, oldkey)
        dict.__setitem__(self, newkey, newtable)

    #                               Naming                               #

    sql_name_max_length = 64
    sql_name_caseless = False
    Prefix = ""

    def quote(self, name):
        """Return name, quoted for use in an SQL statement."""
        # This base class doesn't use "quote",
        # but most subclasses will.
        return name

    def sql_name(self, key):
        """Return the native SQL version of key.

        Lowercased if sql_name_caseless; truncated (with a StorageWarning)
        if longer than sql_name_max_length.
        """
        if self.sql_name_caseless:
            key = key.lower()

        maxlen = self.sql_name_max_length
        if maxlen and len(key) > maxlen:
            warnings.warn("The name '%s' is longer than the maximum of "
                          "%s characters." % (key, maxlen),
                          errors.StorageWarning)
            key = key[:maxlen]

        return key

    def column_name(self, tablekey, columnkey):
        """Return the SQL column name for the given table and column keys."""
        # If you want to use a map from UnitProperty names
        # to DB column names, override this method (that's why
        # the tablename must be included in the args).
        return self.sql_name(columnkey)

    def make_column(self, tablekey, columnkey, pytype, default, hints):
        """Return a Column object from the given table and column keys.

        hints may be None (treated as an empty dict); a copy is stored.
        """
        name = self.column_name(tablekey, columnkey)
        if hints is None:
            hints = {}
        col = Column(name, self.quote(name), None, default, hints.copy())
        col.dbtype = self.db_type(col, pytype)
        pytype2 = self.python_type(col.dbtype)
        col.imperfect_type = not self.isrelatedtype(pytype, pytype2)
        return col

    def table_name(self, key):
        """Return the SQL table name for the given key."""
        # If you want to use a map from Unit class names
        # to DB table names, override this method.
        return self.sql_name(self.Prefix + key)

    def make_table(self, tablekey):
        """Return a new (not yet created) Table object for tablekey."""
        name = self.table_name(tablekey)
        return Table(self, name, self.quote(name))

    def make_index(self, tablekey, columnkey):
        """Return a new (not yet created) Index on the given column."""
        name = self.table_name("i" + tablekey + columnkey)
        return Index(name, self.quote(name), self.table_name(tablekey),
                     self.column_name(tablekey, columnkey))

    #                              Retrieval                              #

    def select(self, tablekey, expr, columnkeys=None, distinct=False):
        """Return an SQL SELECT statement, and an 'imperfect' flag.

        imperfect: True or False depending on whether the generated SQL
            perfectly satisfies the given expression.
        """
        t = self[tablekey]
        if columnkeys:
            colnames = [t.columns[x].qname for x in columnkeys]
            if distinct:
                sql = 'SELECT DISTINCT %s FROM %s'
            else:
                sql = 'SELECT %s FROM %s'
            sql = sql % (', '.join(colnames), t.qname)
        else:
            sql = 'SELECT * FROM %s' % t.qname

        w, i = self.where(t, expr)
        if w:
            w = " WHERE " + w
        else:
            w = ""

        sql += w + ";"
        return sql, i

    def where(self, tables, expr):
        """Return an SQL WHERE clause, and an 'imperfect' flag.

        tables: a Table object, a list of Table objects,
            or a list of (quoted-name-or-alias, Table) tuples

        imperfect: True or False depending on whether the generated SQL
            perfectly satisfies the given expression.
        """
        if not isinstance(tables, list):
            tables = [tables]
        # Normalize into a new list of (name-or-alias, Table) pairs so the
        # caller's list is left unmodified.
        pairs = []
        for t in tables:
            if isinstance(t, (tuple, list)):
                pairs.append(t)
            else:
                pairs.append((t.qname, t))

        decom = self.decompiler(pairs, expr, self.adaptertosql)
        return decom.code(), decom.imperfect

    #                             Connecting                              #

    poolsize = 10

    def connect(self):
        """Set self.connection to a pooled (poolsize > 0) or plain factory."""
        if self.poolsize > 0:
            self.connection = ConnectionPool(self._get_conn, self._del_conn,
                                             self.poolsize)
        else:
            self.connection = ConnectionFactory(self._get_conn, self._del_conn)

    def _get_conn(self):
        # Override this with the connection call for your DB. Example:
        #     return libpq.PQconnectdb(self.connstring)
        raise NotImplementedError

    def _del_conn(self, conn):
        # Override this with the close call (if any) for your DB.
        conn.close()

    def disconnect(self):
        """Release all database connections."""
        self.connection.shutdown()

    def log(self, msg):
        # Hook for subclasses/consumers; the base class discards messages.
        pass

    def execute(self, query, conn=None):
        """execute(query, conn=None) -> result set.

        Unicode queries are encoded with adaptertosql.encoding first. If
        conn is None, a connection is obtained from self.connection().
        """
        if conn is None:
            conn = self.connection()
        if isinstance(query, unicode):
            query = query.encode(self.adaptertosql.encoding)
        self.log(query)
        return conn.query(query)

    def fetch(self, query, conn=None):
        """Return rowdata, columns (name, type) for the given query.

        query should be a SQL query in string format
        rowdata will be an iterable of iterables containing the result values.
        columns will be an iterable of (column name, data type) pairs.

        This base class uses SQLite3 syntax: it expects the result of
        conn.query() to have row_list and col_defs attributes.
        """
        res = self.execute(query, conn)
        return res.row_list, res.col_defs

    def create_database(self):
        """Issue CREATE DATABASE and clear any known tables."""
        self.lock("Creating database. Transactions not allowed.")
        try:
            self.execute("CREATE DATABASE %s;" % self.qname)
            self.clear()
        finally:
            self.unlock()

    def drop_database(self):
        """Close all connections, issue DROP DATABASE, and clear tables."""
        self.lock("Dropping database. Transactions not allowed.")
        try:
            # Must shut down all connections to avoid
            # "being accessed by other users" error.
            self.connection.shutdown()
            self.execute("DROP DATABASE %s;" % self.qname)
            self.clear()
        finally:
            self.unlock()

    #                            Transactions                             #

    transaction_key = threading._get_ident
    implicit_trans = False

    def get_transaction(self, new=False):
        """Return the (possibly new) connection for the current transaction.

        If self.implicit_trans is True, a new connection will be associated
        with self.transaction_key(); repeated calls to this method will
        return the same connection. The default transaction key is the
        current thread ID.

        If self.implicit_trans is False, this method returns self.connection(),
        unless a transaction is already in progress.

        Raises TransactionLock if schema operations currently hold the lock
        for this transaction key.
        """
        key = self.transaction_key()
        if key in self.transactions:
            conn = self.transactions[key]
            if isinstance(conn, TransactionLock):
                raise conn
        else:
            conn = self.connection()
            if self.implicit_trans or new:
                self.transactions[key] = conn
                if not new:
                    self.start()
        return conn

    def start(self):
        """Start a transaction. Not needed if self.implicit_trans is True."""
        self.execute("START TRANSACTION;", self.get_transaction(new=True))

    def rollback(self):
        """Roll back the current transaction, if any (otherwise do nothing)."""
        key = self.transaction_key()
        if key in self.transactions:
            self.execute("ROLLBACK;", self.transactions[key])
            del self.transactions[key]

    def commit(self):
        """Commit the current transaction, if any."""
        key = self.transaction_key()
        if key in self.transactions:
            self.execute("COMMIT;", self.transactions[key])
            del self.transactions[key]

    # Change this to 'error' if you don't want autocommit on schema ops.
    lock_contention = 'commit'

    def lock(self, msg=None):
        """Deny transactions during schema operations (DDL statements).

        Any code which calls this should also call 'unlock' in a try/finally:

        db.lock('dropping storage')
        try:
            drop_storage(cls)
        finally:
            db.unlock()

        An open transaction is committed first, unless lock_contention is
        'error', in which case TransactionLock is raised.
        """
        key = self.transaction_key()
        if key in self.transactions:
            if isinstance(self.transactions[key], TransactionLock):
                return
            if self.lock_contention == 'error':
                raise TransactionLock("Schema operations are not allowed "
                                      "inside transactions.")
            self.commit()

        if msg is None:
            msg = "Transactions not allowed at the moment."
        self.transactions[key] = TransactionLock(msg)

    def unlock(self):
        """Allow transactions.

        Raises TransactionLock if called while a real transaction (rather
        than a lock) is registered for the current transaction key.
        """
        key = self.transaction_key()
        trans = self.transactions.get(key, None)
        if trans is None:
            return
        if not isinstance(trans, TransactionLock):
            raise TransactionLock("Unlock called inside transaction.")
        del self.transactions[key]
1539
1540
class TransactionLock(Exception):
    """Raised when a transaction is requested but not currently allowed.

    A Database also stores instances of this class in its transactions
    dict as a sentinel. While one is present for a thread's transaction
    key, that thread is performing schema changes (DDL statements) and may
    not start a new transaction.
    """
1549
Note: See TracBrowser for help on using the browser.