Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

root/trunk/geniusql/deparse.py

Revision 310 (checked in by lakin, 4 months ago)

geniusql - fixing whitespace

  • Property svn:eol-style set to native
Line 
1 import datetime
2 import sys
3 import traceback
4 from compiler import ast
5 from types import FunctionType, NoneType
6 from geniusql import logic, astwalk
7
# Comparison operator order from opcode.cmp_op:
#            0   1   2   3  4   5
#            <  <=  ==  !=  >  >=
# Comparison operator order when terms are swapped:
#            >  >=  ==  !=  <  <=
# reverseop maps each operator to the one that gives the same result once
# the two operands trade places (a < b is the same test as b > a).
reverseop = {'<': '>', '<=': '>=', '==': '==', '!=': '!=', '>': '<', '>=': '<='}
14
15
# Public names exported by a star-import of this module.
__all__ = [
    'SQLExpression', 'Sentinel',
    'CannotRepresent', 'kw_arg', 'SQLDeparser',
    ]
20
21
class SQLExpression(object):
    """A column reference or computed expression handled by SQLDeparser.

    sql: the SQL text for this expression. It must not carry an alias
        ("AS" clause); the consumer adds one, usually from 'name'.
    name: the unquoted, unescaped name of the expression, usable as an
        alias (with "AS"). It is left raw because it may be merged with
        other strings first (for example, "expr_" + name); consumers must
        quote it (usually via db.quote(e.name)) before putting it in SQL.
    dbtype: the database type object given to the constructor.
    pytype: the Python type given to the constructor.
    adapter: starts out as None; whoever builds the expression assigns it.
    value: if not None, the expression is a "constant": its Python value
        is already known and does not depend on any column values.
    aggregate: True when the expression is an aggregated value such as
        MAX(colref). Consumers use this flag to write GROUP BY clauses.
    """

    def __init__(self, sql, name, dbtype, pytype, value=None):
        self.sql, self.name = sql, name
        self.dbtype, self.pytype = dbtype, pytype

        # No adapter and not an aggregate until a caller says otherwise.
        self.adapter = None
        self.value = value
        self.aggregate = False

    def __cmp__(self, other):
        """Compare two expressions by their SQL text."""
        mine, theirs = self.sql, other.sql
        return cmp(mine, theirs)

    def __repr__(self):
        dbtype_name = self.dbtype.__class__.__name__
        parts = (self.__module__, self.__class__.__name__,
                 self.sql, dbtype_name)
        return "%s.%s(%r, dbtype=%s)" % parts
59
60
class SQLTableRef(object):
    """Pairs a table object with the alias used to refer to it."""

    def __init__(self, table, alias):
        self.table, self.alias = table, alias
66
67
# AST Sentinels
class Sentinel(object):
    """A named marker object returned while walking the AST."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        label = self.name
        return 'AST Sentinel: %s' % label

# Marker returned for names that are not positional arguments.
kw_arg = Sentinel('Keyword Arg')
class CannotRepresent(Exception):
    """Raised when a portion of an Expression has no SQL equivalent.

    It lets that portion be labeled imperfect. For example, the function
    "iscurrentweek" rarely has an SQL equivalent. All rows (which match
    the rest of the Expression) will be recalled; they can then be
    compared in expr(unit).
    """
84
85
class SqlCache(dict):
    """A very basic length-limited cache.

    At most ``max`` entries (1024 by default) are kept when items are added
    through item assignment. When the cache is full, an arbitrary entry
    (whichever one dict.popitem() picks) is discarded to make room. Only
    __setitem__ is overridden; other dict mutators are not limited.
    """

    # this should probably be a LRU cache...

    def __init__(self, max=1024):
        dict.__init__(self)
        assert max > 0
        self.max = max

    def __setitem__(self, key, val):
        # Overwriting an existing key does not grow the cache, so nothing
        # has to be evicted (and the entry being replaced, or an unrelated
        # one, must not be thrown away).
        if key not in self:
            # ">=" rather than "==" so the limit is restored even if the
            # cache grew past it by some other route.
            while self and len(self) >= self.max:
                self.popitem()
        dict.__setitem__(self, key, val)

sql_cache = SqlCache()
105
106
107 class SQLDeparser(astwalk.ASTDeparser):
108     """Produce SQL from a supplied logic.Expression object.
109
110     Attributes of each argument in the Expression's function signature
111     will be mapped to table columns. Keyword arguments should be bound
112     using Expression.bind_args before calling this deparser.
113     """
114
115     # Whether or not the SQL perfectly matches the Python Expression.
116     # In many cases, a provider may be able to return an imperfect subset
117     # of the rows; they should generate the SQL for that and set imperfect
118     # to True. Once the deparser is running, no code should set imperfect
119     # to False.
120     # Note that this is not the same as CannotRepresent, which is the
121     # exception raised for an imperfect SUBexpression. But in general,
122     # this base class will set imperfect for you when computing AND
123     # and OR, and when the final result is examined. You should only
124     # have to set imperfect when you return an SQLExpression that is
125     # imperfect.
126     imperfect = False
127
128     # Some constants are function or class objects,
129     # which should not be coerced.
130     no_coerce = (FunctionType,
131                  type,
132                  type(len),       # <type 'builtin_function_or_method'>
133                  )
134
135     # SQL comparison operators (matching the order of opcode.cmp_op).
136     sql_cmp_op = {'<': '<',
137                   '<=': '<=',
138                   '==': '=',
139                   '!=': '!=',
140                   '>': '>',
141                   '>=': '>=',
142                   'in': 'in',
143                   'not in': 'not in',
144                   }
145
146     # SQL binary operators; a map from values in astwalk.binary_operators
147     # to their SQL equivalents. The default map is isomorphic.
148     sql_bin_op = dict([(k, k) for k in astwalk.repr_to_op])
149
150     none_expr = SQLExpression("NULL", "expr0", None, NoneType)
151
152     def __init__(self, tables, expr, typeset):
153         self.tables = tables
154         self.expr = expr
155         self.typeset = typeset
156
157         self.groups = []
158
159         # Cache coerced booleans and None
160         b = self.typeset.bool_exprs(SQLExpression)
161         self.expr_true, self.expr_false, self.comp_true, self.comp_false = b
162         for boolexpr in b:
163             self.exprcount += 1
164             boolexpr.name = "expr%s" % self.exprcount
165
166         astwalk.ASTDeparser.__init__(self, expr.ast)
167
168     exprcount = 0
169
170     def get_expr(self, sql, pytype, adapter=None, value=None):
171         """Return an SQLExpression for the given sql of the given pytype."""
172         self.exprcount += 1
173         dbtype = self.typeset.database_type(pytype, value=value)
174         e = SQLExpression(sql, "expr%s" % self.exprcount, dbtype, pytype)
175         e.adapter = adapter or dbtype.default_adapter(pytype)
176         return e
177
178     def const(self, value, sql=None):
179         """Return an SQLExpression for the given constant value."""
180         if value is None:
181             return self.none_expr
182
183         e = self.get_expr(sql, type(value), value=value)
184         e.value = value
185         if sql is None:
186             e.sql = e.adapter.push(value, e.dbtype)
187         return e
188
189     def code(self):
190         """Walk self and return a suitable WHERE clause."""
191         root = self.ast.root
192         rootrepr = repr(root)
193         tablenames = tuple([table.name for alias, table in self.tables])
194
195         # Grab the completed SQL from a cache, if available
196         try:
197             sql, imp = sql_cache[(self.typeset, rootrepr, tablenames)]
198         except KeyError:
199             pass
200         else:
201             self.imperfect = imp
202             return sql
203
204         self.imperfect = False
205
206         try:
207             result = self.walk(root)
208             # After walk(), the result should be a single string,
209             # which is the SQL representation of our Expression.
210         except CannotRepresent:
211             # The entire expression could not be evaluated.
212             result = self.expr_true
213             self.imperfect = True
214         else:
215             if result == self.comp_true:
216                 result = self.expr_true
217             elif result == self.comp_false:
218                 result = self.expr_false
219
220         # Cache the result
221         sql_cache[(self.typeset, rootrepr, tablenames)] = \
222                   (result.sql, self.imperfect)
223
224         return result.sql
225
226     def field_list(self):
227         """Walk self and return a list of field objects."""
228         self.imperfect = False
229         root = self.ast.root
230
231         # When building a field list, ignore the last BUILD_TUPLE.
232         if not isinstance(root, (astwalk.ast.Tuple, astwalk.ast.List)):
233             raise ValueError("Attribute AST roots must be Tuple or List, "
234                              "not %s" % root.__class__.__name__)
235         result = []
236         for term in root.getChildren():
237             self.aggregate = False
238             e = self.walk(term)
239             e.aggregate = self.aggregate
240             result.append(e)
241         return result
242
243     def walk(self, node):
244         """Walk the AST and return a string of code."""
245         nodetype = node.__class__.__name__
246         method = getattr(self, "visit_" + nodetype)
247         args = node.getChildren()
248         if self.verbose:
249             self.debug(nodetype, args)
250         return method(*args)
251
252     def _walk_terms(self, *terms):
253         newterms = []
254         for term in terms:
255             # TODO STRABS - The test in SQL injection are tests against using
256             #               strings in scalar values for stuff like ands and ors
257             #               this accomplishes that, but I'd like it reviewed.
258             if not self.imperfect and \
259                 isinstance(term, ast.Const) and isinstance(term.value, basestring
260             ):
261                 self.imperfect = True
262
263             try:
264                 term = self.walk(term)
265             except CannotRepresent:
266                 self.imperfect = True
267                 # Use TRUE for the term, so all records are returned.
268                 term = self.expr_true
269             else:
270                 # Blurg. SQL Server is *so* picky.
271                 if term == self.comp_true:
272                     term = self.expr_true
273                 elif term == self.comp_false:
274                     term = self.expr_false
275             newterms.append("(%s)" % term.sql)
276         return newterms
277
278     def visit_And(self, *terms):
279         newterms = self._walk_terms(*terms)
280         clause = self.get_expr(" AND ".join(newterms), bool)
281
282         if self.verbose:
283             self.debug("clause:", clause.sql, "\n")
284
285         return clause
286
287     def visit_Or(self, *terms):
288         newterms = self._walk_terms(*terms)
289         clause = self.get_expr(" OR ".join(newterms), bool)
290
291         if self.verbose:
292             self.debug("clause:", clause.sql, "\n")
293
294         return clause
295
296     def visit_Name(self, name):
297         if name in self.ast.args:
298             # We've hit a reference to a positional arg, which in our case
299             # implies a reference to a DB table.
300             alias, table = self.tables[self.ast.args.index(name)]
301             return SQLTableRef(table, alias)
302         else:
303             # Since lambdas don't support local bindings,
304             # any remaining local name must be a keyword arg.
305             return kw_arg
306
307     def visit_Getattr(self, expr, attrname):
308         expr = self.walk(expr)
309         if isinstance(expr, SQLTableRef):
310             # The name in question refers to a DB column (see visit_Name).
311             col = expr.table[attrname]
312             atom = SQLExpression('%s.%s' % (expr.alias, col.qname),
313                                  attrname, col.dbtype, col.pytype)
314             atom.adapter = col.adapter
315         else:
316             # 'expr.name' will reference an attribute of the expr object.
317             # Stick the expr and name in a tuple for later processing
318             # (for example, in visit_CallFunc).
319             atom = (expr, attrname)
320         return atom
321
322     def visit_Const(self, value):
323         if not isinstance(value, self.no_coerce):
324             value = self.const(value)
325         return value
326
327     def visit_Tuple(self, *terms):
328         val = []
329         newterms = []
330         for term in terms:
331             term = self.walk(term)
332             val.append(term.value)
333             newterms.append(term.sql)
334         return SQLExpression("(" + ", ".join(newterms) + ")",
335                              "tuple", None, tuple, tuple(val))
336
337     # Assume all DB's have a tuple () syntax but no list [] syntax
338     visit_List = visit_Tuple
339
    def visit_CallFunc(self, func, *args):
        """Deparse a function call by dispatching to a handler method.

        func: AST node for the callable.
        args: the call's argument nodes followed by two trailing entries
            (read below as star_args and dstar_args).

        Handler lookup, in order:
          1. If func walks to an (obj, name) tuple, "attr_<name>" is called
             with obj and the positional args.
          2. If func is the object registered under its own name in
             logic.builtins, "builtins_<name>" is tried.
          3. "<module>_<name>" with dots turned into underscores, prefixed
             with "func" when it would start with an underscore
             (e.g. func__builtin___len).

        Raises CannotRepresent when no handler is found.
        """
        # e.g. CallFunc(Name('min'), [Getattr(Name('v'), 'Date')], None, None)
        # NOTE(review): both trailing entries are read but never used.
        dstar_args = args[-1]
        star_args = args[-2]

        posargs = []
        # NOTE(review): keyword arguments are walked and collected here but
        # are not forwarded to any handler below.
        kwargs = {}
        for arg in args[:-2]:
            if isinstance(arg, astwalk.ast.Keyword):
                kwargs[arg.name] = self.walk(arg.value)
            else:
                posargs.append(self.walk(arg))

        # The arguments are walked before the callable itself.
        func = self.walk(func)

        # Handle function objects.
        if isinstance(func, tuple):
            # A function which was an attribute of another object;
            # for example, "x.Field.startswith". The tuple will be of
            # the form (obj, name). See visit_Getattr.
            obj, name = func
            dispatch = getattr(self, "attr_" + name, None)
            if dispatch:
                return dispatch(obj, *posargs)
            raise CannotRepresent("No handler found for function %r.%r." %
                                  (obj, name))

        # Identity check: the registry entry of that name must be func itself.
        if logic.builtins.get(func.__name__, None) is func:
            dispatch = getattr(self, "builtins_" + func.__name__, None)
            if dispatch:
                return dispatch(*posargs)

        # Fall back to a handler named after the function's module and name.
        funcname = func.__module__ + "_" + func.__name__
        funcname = funcname.replace(".", "_")
        if funcname.startswith("_"):
            funcname = "func" + funcname
        dispatch = getattr(self, funcname, None)
        if dispatch:
            return dispatch(*posargs)

        raise CannotRepresent(func)
381
382     # Validity for a comparison operation between two types.
383     compare_types = {}
384
385     def visit_Compare(self, op1, *ops):
386         op1 = self.walk(op1)
387
388         newterms = []
389         i = 0
390         while i < len(ops):
391             op, op2 = ops[i:i+2]
392             i += 2
393             op2 = self.walk(op2)
394
395             if not self.compare_types.get((op1.pytype, op, op2.pytype), False):
396                 raise CannotRepresent("No comparison function %r between %r and %r" %
397                                       (op, op1, op2))
398
399             if op == 'in':
400                 term = self.containedby(op1, op2)
401             elif op == 'not in':
402                 term = self.containedby(op1, op2)
403                 term.sql = "NOT " + term.sql
404             elif op1.sql == 'NULL':
405                 if op in ('==', 'is'):
406                     term = self.get_expr(op2.sql + " IS NULL", bool)
407                 elif op in ('!=', 'is not'):
408                     term = self.get_expr(op2.sql + " IS NOT NULL", bool)
409                 else:
410                     raise ValueError("Non-equality Null comparisons not allowed.")
411             elif op2.sql == 'NULL':
412                 if op in ('==', 'is'):
413                     term = self.get_expr(op1.sql + " IS NULL", bool)
414                 elif op in ('!=', 'is not'):
415                     term = self.get_expr(op1.sql + " IS NOT NULL", bool)
416                 else:
417                     raise ValueError("Non-equality Null comparisons not allowed.")
418             elif op in reverseop:
419                 try:
420                     sql = op1.adapter.compare_op(op1, op, self.sql_cmp_op[op], op2)
421                 except TypeError, exc:
422                     if self.verbose:
423                         self.debug("".join(traceback.format_exception(*sys.exc_info())))
424                     rop = reverseop[op]
425                     try:
426                         sql = op1.adapter.compare_op(op2, rop, self.sql_cmp_op[rop], op1)
427                     except TypeError, exc:
428                         if self.verbose:
429                             self.debug("".join(traceback.format_exception(*sys.exc_info())))
430                         raise CannotRepresent("No comparison function %r "
431                                               "between %r and %r." %
432                                               (op, op1, op2))
433                 term = self.get_expr(sql, bool)
434             else:
435                 raise ValueError("Operator %r not handled." % op)
436
437             newterms.append("(%s)" % term.sql)
438             op1 = op2
439         return self.get_expr(" and ".join(newterms), bool)
440
441     def visit_Subscript(self, expr, flags, *subs):
442         expr = self.walk(expr)
443         # The only Subscript used in Expressions should be kwargs[key].
444         if expr is not kw_arg:
445             raise ValueError("Subscript %r of %s object not allowed." %
446                              (subs, expr))
447         if len(subs) > 1:
448             raise ValueError("Multiple subscripts %r of %s not supported."  %
449                              (subs, expr))
450
451         name = subs[0].value
452
453         value = self.expr.kwargs[name]
454         if not isinstance(value, self.no_coerce):
455             value = self.const(value)
456         return value
457
458     def visit_Not(self, expr):
459         expr = self.walk(expr)
460         return self.get_expr("NOT (" + expr.sql + ")", bool)
461
462     # --------------------------- Dispatchees --------------------------- #
463
464     def attr_startswith(self, op1, op2):
465         return self.get_expr(op1.adapter.like_op(op1, op2, start_only=True),
466                              bool)
467
468     def attr_endswith(self, op1, op2):
469         return self.get_expr(op1.adapter.like_op(op1, op2, end_only=True),
470                              bool)
471
472     def containedby(self, op1, op2):
473         if op1.value is not None:
474             # Looking for text in a field. Use Like (reverse terms).
475             return self.get_expr(op2.adapter.like_op(op2, op1), bool)
476         else:
477             # Looking for field in (a, b, c)
478             atoms = []
479             for x in op2.value:
480                 adapter = op1.dbtype.default_adapter(type(x))
481                 atoms.append(adapter.push(x, op1.dbtype))
482             if atoms:
483                 return self.get_expr(op1.sql + " IN (" + ", ".join(atoms) + ")", bool)
484             else:
485                 # Nothing will match the empty list, so return none.
486                 return self.expr_false
487
488     def builtins_icontainedby(self, op1, op2):
489         if op1.value is not None:
490             # Looking for text in a field. Use Like (reverse terms).
491             return self.get_expr(op2.adapter.like_op(
492                 op2, op1, ignore_case=True), bool)
493         else:
494             # Looking for field in (a, b, c).
495             # Force all args to lowercase for case-insensitive comparison.
496             atoms = []
497             for x in op2.value:
498                 adapter = op1.dbtype.default_adapter(type(x))
499                 atoms.append(adapter.push(x.lower(), op1.dbtype))
500             return self.get_expr("LOWER(%s) IN (%s)" %
501                                  (op1.sql, ", ".join(atoms)), bool)
502
503     def builtins_icontains(self, x, y):
504         return self.builtins_icontainedby(y, x)
505
506     def builtins_istartswith(self, op1, op2):
507         return self.get_expr(op1.adapter.like_op(
508             op1, op2, ignore_case=True, start_only=True), bool)
509
510     def builtins_iendswith(self, x, y):
511         return self.get_expr(op1.adapter.like_op(
512             op1, op2, ignore_case=True, end_only=True), bool)
513
514     def builtins_ieq(self, x, y):
515         return self.get_expr("LOWER(" + x.sql + ") = LOWER(" + y.sql + ")", bool)
516
517     def builtins_now(self):
518         """Return a datetime.datetime for the current time in the local TZ."""
519         return self.get_expr("NOW()", datetime.datetime)
520
521     def builtins_utcnow(self):
522         """Return a datetime.datetime for the current time in the UTC TZ."""
523         raise CannotRepresent("utcnow not implemented")
524
525     def builtins_today(self):
526         """Return a datetime.datetime for the current time in the local TZ."""
527         return self.get_expr("CURRENT_DATE", datetime.date)
528
529     def builtins_year(self, x):
530         return self.get_expr("YEAR(" + x.sql + ")", int)
531
532     def builtins_month(self, x):
533         return self.get_expr("MONTH(" + x.sql + ")", int)
534
535     def builtins_day(self, x):
536         return self.get_expr("DAY(" + x.sql + ")", int)
537
538     def func__builtin___len(self, x):
539         return self.get_expr("LENGTH(" + x.sql + ")", int)
540
541     def func__builtin___min(self, x):
542         self.aggregate = True
543         x.name = "min_%s" % x.name
544         x.sql = "MIN(" + x.sql + ")"
545         return x
546
547     def func__builtin___max(self, x):
548         self.aggregate = True
549         x.name = "max_%s" % x.name
550         x.sql = "MAX(" + x.sql + ")"
551         return x
552
553     def builtins_count(self, x):
554         self.aggregate = True
555         return self.get_expr("COUNT(" + x.sql + ")", int)
556
557     def func__builtin___reversed(self, x):
558         # Assume reversed is always used for DESC ordering.
559         x.sql += " DESC"
560         return x
561     # For version of Python which did not possess the 'reversed' builtin.
562     builtins_reversed = func__builtin___reversed
563
564     def builtins_alias(self, x, y):
565         # We don't need to modify x.sql here; SelectWriter.deparse_attributes
566         # will include the " AS name" clause for us.
567         x.name = y.sql.strip("\"'")
568         return x
569
570     #                           Binary operations                         #
571
572     # Resultant type for a binary operation between two types.
573     result_type = {}
574
575     def binary_op(self, left, op, right):
576         left = self.walk(left)
577         right = self.walk(right)
578
579         try:
580             newsql = left.adapter.binary_op(left, op,
581                                             self.sql_bin_op[op], right)
582         except TypeError:
583             raise CannotRepresent("No binary function %r between %r and %r" %
584                                   (op, left, right))
585
586         newpytype = self.result_type[(left.pytype, op, right.pytype)]
587
588         # re-use left
589         left.sql = newsql
590         if newpytype != left.pytype:
591             left.pytype = newpytype
592             left.dbtype = self.typeset.database_type(newpytype)
593             left.adapter = left.dbtype.default_adapter(newpytype)
594         if not left.name.startswith("expr_"):
595             left.name = "expr_%s" % left.name
596         return left
597
598
def _binary_operation_result_types():
    """Return a dict of (type(A), op, type(B)): type(op(A, B)) for known types.

    One sample value of each known type is combined pairwise with every
    operator in astwalk.repr_to_op; combinations that raise TypeError are
    left out of the result.
    """
    samples = [3, long(3), 3.0, 'a', u'b', True]
    try:
        import datetime
    except ImportError:
        pass
    else:
        samples.extend([datetime.date(2004, 1, 1),
                        datetime.datetime(2004, 1, 31),
                        datetime.timedelta(3)])
    try:
        import decimal
    except ImportError:
        pass
    else:
        samples.append(decimal.Decimal(3))

    table = {}
    for left in samples:
        for right in samples:
            for symbol, op in astwalk.repr_to_op.iteritems():
                try:
                    outcome = op(left, right)
                except TypeError:
                    continue
                table[(type(left), symbol, type(right))] = type(outcome)

    return table
SQLDeparser.result_type = _binary_operation_result_types()
629
def _comparison_operation_types():
    """Return a dict of {(type(A), op, type(B)): can compare?} for known types.

    Covers ordering (<, <=, >, >=), equality/identity (==, !=, is, is not)
    and containment (in, not in) between one sample value of each known
    type, plus containment of any known type in a list or tuple.
    """
    results = {}

    knowntypes = [3, long(3), 3.0, 'a', u'b', None, True]
    numtypes = [int, long, float]
    strtypes = [str, unicode]
    try:
        import datetime
    except ImportError:
        datetypes = []
    else:
        knowntypes.extend([datetime.date(2004, 1, 1),
                           datetime.datetime(2004, 1, 31),
                           datetime.timedelta(3)])
        datetypes = [datetime.date, datetime.datetime, datetime.timedelta]

    try:
        import decimal
    except ImportError:
        pass
    else:
        knowntypes.append(decimal.Decimal(3))
        numtypes.append(decimal.Decimal)

    import operator
    import opcode

    def applies(op, a, b):
        """Return True unless op(a, b) raises TypeError."""
        try:
            op(a, b)
        except TypeError:
            return False
        return True

    ordering = [('<', operator.lt), ('<=', operator.le),
                ('>', operator.gt), ('>=', operator.ge)]
    equality = [('==', operator.eq), ('!=', operator.ne),
                ('is', operator.is_), ('is not', operator.is_not)]

    for A in knowntypes:
        type_a = type(A)

        # All types should allow unrestricted containment comparisons,
        # e.g. A in (1, 2, 3). The type of each element in the list will
        # have to be checked inside the Deparser.
        for symbol in ('in', 'not in'):
            for container in (list, tuple):
                results[(type_a, symbol, container)] = True

        for B in knowntypes:
            type_b = type(B)
            both_numeric = type_a in numtypes and type_b in numtypes
            both_text = type_a in strtypes and type_b in strtypes

            # Python versions previous to 2.6 allowed comparisons between
            # unrelated types, like 'abc' > 12. Manually munge known
            # incompatibilities in the results by special-casing the
            # comparison operators for dissimilar types.
            for symbol, op in ordering:
                if both_numeric or both_text:
                    allowed = True
                elif type_a in datetypes or type_b in datetypes:
                    # The datetime types are very strict about comparisons.
                    allowed = applies(op, A, B)
                else:
                    allowed = False
                results[(type_a, symbol, type_b)] = allowed

            # However, all types should allow equality comparison.
            for symbol, op in equality:
                results[(type_a, symbol, type_b)] = applies(op, A, B)

            # ...and only the string types should allow A LIKE B
            if both_text:
                results[(type_a, 'in', type_b)] = True
                results[(type_a, 'not in', type_b)] = True

    return results
SQLDeparser.compare_types = _comparison_operation_types()
702
Note: See TracBrowser for help on using the browser.