Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

root/trunk/geniusql/decompile.py

Revision 62 (checked in by fumanchu, 6 years ago)

Postgres fixes (new adapters and dbtypes).

  • Property svn:eol-style set to native
Line 
1 import datetime
2 from types import FunctionType
3 from geniusql import logic, codewalk
4
# Comparison operator order from opcode.cmp_op:
#            0   1   2   3  4   5
#            <  <=  ==  !=  >  >=
# Comparison operator order when terms are swapped:
#            >  >=  ==  !=  <  <=
# reverseop[i] is therefore the cmp_op index to use when the two operands
# of operator i trade places (for example, "a < b" becomes "b > a").
reverseop = (4,  5,  2,  3, 0,  1)


# Names exported by "from geniusql.decompile import *".
__all__ = [
    'SQLExpression', 'Sentinel',
    'cannot_represent', 'kw_arg', 'SQLDecompiler',
    ]
17
18
class SQLExpression(object):
    """A column or other SQL fragment held on SQLDecompiler's stack.

    sql: the SQL text for this expression. It must not include an alias
        ("AS" clause); the consumer adds one, usually from 'name'.
    name: a name for the expression; may be used as an alias (with "AS").
    dbtype: the database type object associated with the expression.
    pytype: the Python type associated with the expression.
    adapter: starts as None; callers assign the adapter afterwards.
    value: when not None, the expression is a "constant": its Python
        value is already known and does not depend on column values.
    aggregate: starts as False; callers set it for aggregate expressions.
    """

    def __init__(self, sql, name, dbtype, pytype, value=None):
        self.sql, self.name = sql, name
        self.dbtype, self.pytype = dbtype, pytype
        # The adapter is attached by whoever creates the expression.
        self.adapter = None
        self.value = value
        self.aggregate = False

    def __cmp__(self, other):
        # Expressions are ordered (and compared equal) by their SQL text;
        # comparing against any other kind of object is an error.
        if not isinstance(other, SQLExpression):
            raise TypeError("can't compare %s to %s" % (type(self), type(other)),
                            other)
        return cmp(self.sql, other.sql)

    def __repr__(self):
        template = "%s.%s(%r, dbtype=%s)"
        dbtype_name = self.dbtype.__class__.__name__
        return template % (self.__module__, self.__class__.__name__,
                           self.sql, dbtype_name)
52
53
54 # Stack sentinels
class Sentinel(object):
    """A named marker object that can be placed on the decompiler stack."""

    def __init__(self, name):
        # The name is only used to identify the marker in repr() output.
        self.name = name

    def __repr__(self):
        return 'Stack Sentinel: ' + '%s' % self.name
62
# kw_arg is pushed onto the stack in place of any local name that is not
# a positional argument (see SQLDecompiler.visit_LOAD_FAST); the only
# operation allowed on it is a subscript (see visit_BINARY_SUBSCR).
kw_arg = Sentinel('Keyword Arg')
# cannot_represent exists so that a portion of an Expression can be
# labeled imperfect. For example, the function "iscurrentweek"
# rarely has an SQL equivalent. All rows (which match the rest of the
# Expression) will be recalled; they can then be compared in expr(unit).
cannot_represent = Sentinel('Cannot Repr')
69
70
class SQLDecompiler(codewalk.LambdaDecompiler):
    """Produce SQL from a supplied logic.Expression object.

    Attributes of each argument in the Expression's function signature
    will be mapped to table columns. Keyword arguments should be bound
    using Expression.bind_args before calling this decompiler.

    Call code() to obtain the SQL for a WHERE clause, or field_list() to
    obtain the stack of expression objects. A subexpression that has no
    SQL equivalent is represented on the stack by the cannot_represent
    sentinel; whenever such an entry is replaced by TRUE, self.imperfect
    is set to True.
    """

    # Whether or not the SQL perfectly matches the Python Expression.
    # In many cases, a provider may be able to return an imperfect subset
    # of the rows; they should generate the SQL for that and set imperfect
    # to True. Once the decompiler is running, no code should set imperfect
    # to False.
    # Note that this is not the same as cannot_represent, which is the stack
    # value for an imperfect SUBexpression. But in general, this base class
    # will set imperfect for you when computing AND and OR, and when the
    # final result is examined. You should only have to set imperfect
    # when you put actual SQL on the stack that is imperfect.
    imperfect = False

    # Some constants are function or class objects,
    # which should not be coerced.
    no_coerce = (FunctionType,
                 type,
                 type(len),       # <type 'builtin_function_or_method'>
                 )

    # SQL comparison operators (matching the order of opcode.cmp_op).
    sql_cmp_op = ('<', '<=', '=', '!=', '>', '>=', 'in', 'not in')

    # SQL binary operators; a map from values in codewalk.binary_operators
    # to their SQL equivalents. The default map is isomorphic.
    sql_bin_op = dict([(v, v) for v in codewalk.binary_repr.itervalues()])

    # These are not adapter.push(bool) (which are used on one side of
    # a comparison). Instead, these are used when the whole (sub)expression
    # is True or False, e.g. "WHERE TRUE", or "WHERE TRUE and 'a'.'b' = 3".
    bool_true = "TRUE"
    bool_false = "FALSE"

    def __init__(self, tables, expr, adapterset):
        """Prepare to decompile expr.func against the given tables.

        tables: a sequence indexed by positional-argument number; each
            entry is pushed on the stack by visit_LOAD_FAST and unpacked
            as an (alias, table) pair by visit_LOAD_ATTR.
        expr: the Expression; its 'func' and 'kwargs' attributes are used.
        adapterset: supplies database_type(pytype) and
            default(pytype, dbtype) for type and adapter lookup.
        """
        self.tables = tables
        self.expr = expr
        self.adapterset = adapterset

        self.groups = []

        # Cache coerced booleans
        self.true_expr = self.const(True, self.bool_true)
        self.false_expr = self.const(False, self.bool_false)

        # T and F hold the adapter's rendering of True/False (what a
        # boolean constant looks like when pushed as a value), so they
        # can be recognized and swapped for true_expr/false_expr.
        booldbtype = self.adapterset.database_type(bool)
        booladapter = self.adapterset.default(bool, booldbtype)
        self.T = self.const(True, booladapter.push(True, booldbtype))
        self.F = self.const(False, booladapter.push(False, booldbtype))

        self.none_expr = SQLExpression("NULL", "expr0", None, type(None))

        codewalk.LambdaDecompiler.__init__(self, expr.func)

    # Counter used to generate unique "exprN" names in get_expr.
    exprcount = 0

    def get_expr(self, sql, pytype, adapter=None):
        """Return an SQLExpression for the given sql of the given pytype.

        The expression gets a fresh "exprN" name. Its dbtype comes from
        the adapterset; its adapter is the one supplied or, failing that,
        the adapterset default for (pytype, dbtype).
        """
        self.exprcount += 1
        name = "expr%s" % self.exprcount

        dbtype = self.adapterset.database_type(pytype)
        e = SQLExpression(sql, name, dbtype, pytype)
        e.adapter = adapter or self.adapterset.default(pytype, dbtype)

        return e

    def const(self, value, sql=None):
        """Return an SQLExpression for the given constant value.

        None always maps to the shared NULL expression. Otherwise, when
        no sql is supplied, the SQL text is produced by pushing the value
        through the expression's adapter.
        """
        if value is None:
            return self.none_expr

        e = self.get_expr(sql, type(value))
        e.value = value
        if sql is None:
            e.sql = e.adapter.push(value, e.dbtype)
        return e

    def append_expr(self, sql, pytype):
        """Syntactic sugar for self.stack.append(self.get_expr(sql, pytype))."""
        self.stack.append(self.get_expr(sql, pytype))

    def code(self):
        """Walk self and return a suitable WHERE clause.

        Returns the SQL text of the first stack entry after walking.
        Sets self.imperfect if the whole expression could not be
        represented (in which case the clause is TRUE).
        """
        self.imperfect = False
        self.walk()
        # After walk(), self.stack should be reduced to a single entry,
        # which holds the SQL representation of our Expression.
        result = self.stack[0]
        if result is cannot_represent:
            # The entire expression could not be evaluated.
            result = self.true_expr
            self.imperfect = True
        elif result == self.T:
            result = self.true_expr
        elif result == self.F:
            result = self.false_expr
        return result.sql

    # When True, visit_BUILD_TUPLE leaves the final tuple unbuilt so that
    # the individual fields remain on the stack (see field_list).
    _ignore_final_build = False

    def field_list(self):
        """Walk self and return a list of field objects."""
        self._ignore_final_build = True
        self.walk()
        return self.stack

    def visit_instruction(self, op, lo=None, hi=None):
        """Combine any pending terms that target the current instruction.

        self.targets (maintained outside this class) is consulted for the
        current instruction pointer; each pending (term, oper) pair found
        there is joined to the clause on top of the stack as
        "(term) OPER (clause)", and the result replaces the top of stack.
        """
        # Get the instruction pointer for the current instruction.
        # The cursor is assumed to already be past the instruction, which
        # occupies 3 bytes with an argument and 1 byte without.
        ip = self.cursor - 3
        if hi is None:
            ip += 1
            if lo is None:
                ip += 1

        terms = self.targets.get(ip)
        if terms:
            clause = self.stack[-1]
            while terms:
                term, oper = terms.pop()
                if term is cannot_represent:
                    # Use TRUE for the term, so all records are returned.
                    term = self.true_expr
                    self.imperfect = True
                if clause is cannot_represent:
                    # Use TRUE for the clause, so all records are returned.
                    clause = self.true_expr
                    self.imperfect = True

                # Blurg. SQL Server is *so* picky.
                if term == self.T:
                    term = self.true_expr
                elif term == self.F:
                    term = self.false_expr
                if clause == self.T:
                    clause = self.true_expr
                elif clause == self.F:
                    clause = self.false_expr

                clause = self.get_expr("(%s) %s (%s)" %
                                       (term.sql, oper.upper(), clause.sql),
                                       bool)

            # Replace TOS with the new clause, so that further
            # combinations have access to it.
            self.stack[-1] = clause
            if self.verbose:
                self.debug("clause:", clause.sql, "\n")

            if op == 1:
                # Py2.4: The current instruction is POP_TOP, which means
                # the previous is probably JUMP_*. If so, we're going to
                # pop the value we just placed on the stack and lose it.
                # We need to replace the entry that the JUMP_* made in
                # self.targets with our new TOS.
                target = self.targets[self.last_target_ip]
                target[-1] = ((clause, target[-1][1]))
                if self.verbose:
                    self.debug("newtarget:", self.last_target_ip, target)

    def visit_LOAD_DEREF(self, lo, hi):
        """Reject closure references; they cannot be decompiled."""
        raise ValueError("Illegal reference found in %s." % self.expr)

    def visit_LOAD_GLOBAL(self, lo, hi):
        """Reject global references; they cannot be decompiled."""
        raise ValueError("Illegal global found in %s." % self.expr)

    def visit_LOAD_FAST(self, lo, hi):
        """Push the table for a positional arg, or kw_arg for other locals."""
        arg_index = lo + (hi << 8)
        if arg_index < self.co_argcount:
            # We've hit a reference to a positional arg, which in our case
            # implies a reference to a DB table. Append the (qname, table)
            # tuple for later unpacking inside visit_LOAD_ATTR.
            self.stack.append(self.tables[arg_index])
        else:
            # Since lambdas don't support local bindings,
            # any remaining local name must be a keyword arg.
            self.stack.append(kw_arg)

    def visit_LOAD_ATTR(self, lo, hi):
        """Replace TOS with a column expression or a (tos, name) pair."""
        name = self.co_names[lo + (hi << 8)]
        tos = self.stack.pop()
        if isinstance(tos, tuple):
            # The name in question refers to a DB column (see visit_LOAD_FAST).
            alias, table = tos
            col = table[name]
            atom = SQLExpression('%s.%s' % (alias, col.qname),
                                 name, col.dbtype, col.pytype)
            atom.adapter = col.adapter
        else:
            # 'tos.name' will reference an attribute of the tos object.
            # Stick the tos and name in a tuple for later processing
            # (for example, in visit_CALL_FUNCTION).
            atom = (tos, name)
        self.stack.append(atom)

    def visit_LOAD_CONST(self, lo, hi):
        """Push a constant, coerced to an SQLExpression unless no_coerce."""
        val = self.co_consts[lo + (hi << 8)]
        if not isinstance(val, self.no_coerce):
            val = self.const(val)
        self.stack.append(val)

    def visit_BUILD_TUPLE(self, lo, hi):
        """Replace the top N stack entries with one "(a, b, ...)" expression.

        The elements appear in the SQL in the same order as in the
        Python source.
        """
        if self.cursor == len(self._bytecode) - 1 and self._ignore_final_build:
            # When building a field list, ignore the last BUILD_TUPLE.
            return

        # The last element is on top of the stack, so the popped items
        # come off in reverse; restore source order before joining.
        items = [self.stack.pop() for i in range(lo + (hi << 8))]
        items.reverse()
        terms = ", ".join([item.sql for item in items])
        # NOTE(review): value is the flag True rather than the element
        # values, yet containedby() iterates op2.value -- confirm intent.
        self.stack.append(SQLExpression("(" + terms + ")", "tuple",
                                        None, None, True))

    visit_BUILD_LIST = visit_BUILD_TUPLE

    def visit_CALL_FUNCTION(self, lo, hi):
        """Dispatch a call to an attr_*, builtins_* or module-named method.

        lo is the number of positional args and hi the number of keyword
        args on the stack. Keyword args are rendered as "key=value"
        strings and passed after the positional args, in source order.
        If no dispatch method exists, cannot_represent is pushed.
        """
        # Each keyword arg is a key entry followed by a value entry; the
        # last pair is on top, so reverse after popping to keep source
        # order (a dict here would make the order arbitrary).
        kwargs = []
        for i in xrange(hi):
            val = self.stack.pop()
            key = self.stack.pop()
            kwargs.append(key.sql + "=" + val.sql)
        kwargs.reverse()

        args = []
        for i in xrange(lo):
            arg = self.stack.pop()
            args.append(arg)
        args.reverse()

        if kwargs:
            args += kwargs

        func = self.stack.pop()

        # Handle function objects.
        if isinstance(func, tuple):
            # A function which was an attribute of another object;
            # for example, "x.Field.startswith". The tuple will be of
            # the form (tos, name) where "tos" is the object and 'name'
            # is the name of the desired attribute of that object.
            # See visit_LOAD_ATTR.
            tos, name = func
            dispatch = getattr(self, "attr_" + name, None)
            if dispatch:
                self.stack.append(dispatch(tos, *args))
                return
        elif logic.builtins.get(func.__name__, None) is func:
            dispatch = getattr(self, "builtins_" + func.__name__, None)
            if dispatch:
                self.stack.append(dispatch(*args))
                return
        else:
            # Any other function is looked up by "<module>_<name>" with
            # dots turned into underscores, e.g. func__builtin___len.
            funcname = func.__module__ + "_" + func.__name__
            funcname = funcname.replace(".", "_")
            if funcname.startswith("_"):
                funcname = "func" + funcname
            dispatch = getattr(self, funcname, None)
            if dispatch:
                self.stack.append(dispatch(*args))
                return

        self.stack.append(cannot_represent)

    def visit_COMPARE_OP(self, lo, hi):
        """Replace the top two stack entries with their SQL comparison.

        Raises ValueError for a non-equality comparison against NULL and
        for operators this method does not handle.
        """
        op2, op1 = self.stack.pop(), self.stack.pop()
        if op1 is cannot_represent or op2 is cannot_represent:
            self.stack.append(cannot_represent)
            return

        op = lo + (hi << 8)
        if op in (6, 7):     # in, not in
            value = self.containedby(op1, op2)
            if op == 7:
                # Wrap in a new expression rather than editing value.sql:
                # containedby() may hand back the cached self.false_expr,
                # which must never be modified.
                value = self.get_expr("NOT " + value.sql, bool)
            self.stack.append(value)
        elif op1.sql == 'NULL':
            if op in (2, 8):    # '==', is
                self.append_expr(op2.sql + " IS NULL", bool)
            elif op in (3, 9):  # '!=', 'is not'
                self.append_expr(op2.sql + " IS NOT NULL", bool)
            else:
                raise ValueError("Non-equality Null comparisons not allowed.")
        elif op2.sql == 'NULL':
            if op in (2, 8):    # '==', 'is'
                self.append_expr(op1.sql + " IS NULL", bool)
            elif op in (3, 9):  # '!=', 'is not'
                self.append_expr(op1.sql + " IS NOT NULL", bool)
            else:
                raise ValueError("Non-equality Null comparisons not allowed.")
        elif 0 <= op <= 5:
            try:
                sql = op1.adapter.compare_op(op1, op, self.sql_cmp_op[op], op2)
            except TypeError:
                # Retry with the operands swapped and the operator
                # mirrored (see reverseop).
                # NOTE(review): the retry still uses op1.adapter; confirm
                # whether op2.adapter was intended.
                try:
                    rop = reverseop[op]
                    sql = op1.adapter.compare_op(op2, rop, self.sql_cmp_op[rop], op1)
                except TypeError:
                    self.stack.append(cannot_represent)
                    return
            self.append_expr(sql, bool)
        else:
            import opcode
            raise ValueError("Operator %r not handled." % opcode.cmp_op[op])

    def visit_BINARY_SUBSCR(self):
        """Resolve kwargs[key] to the bound keyword value from self.expr.

        Raises ValueError if the subscripted object is not a keyword arg.
        """
        # The only BINARY_SUBSCR used in Expressions should be kwargs[key].
        name = self.stack.pop()
        tos = self.stack.pop()
        if tos is not kw_arg:
            raise ValueError("Subscript %s of %s object not allowed."
                             % (name, tos))
        # name, since formed in LOAD_CONST, may have extraneous quotes.
        name = name.sql.strip("'\"")
        value = self.expr.kwargs[name]
        if not isinstance(value, self.no_coerce):
            value = self.const(value)
        self.stack.append(value)

    def visit_UNARY_NOT(self):
        """Replace TOS with its SQL negation (or keep cannot_represent)."""
        op = self.stack.pop()
        if op is cannot_represent:
            self.stack.append(cannot_represent)
        else:
            self.append_expr("NOT (" + op.sql + ")", bool)

    # --------------------------- Dispatchees --------------------------- #

    # Notice these are ordered pairs. Escape \ before introducing new ones.
    # Values in these two lists should be strings encoded with self.encoding.
    like_escapes = [("%", r"\%"), ("_", r"\_")]

    def escape_like(self, value):
        """Prepare a string value for use in a LIKE comparison.

        Surrounding quote marks are removed and each pattern in
        like_escapes is replaced by its escaped form.
        """
        if not isinstance(value, str):
            # NOTE(review): self.encoding is not defined in this class;
            # presumably a base class or subclass provides it -- confirm.
            value = value.encode(self.encoding)
        # Notice we strip leading and trailing quote-marks.
        # NOTE(review): strip() removes *every* leading/trailing quote, so
        # a value that itself begins or ends with a quote loses it.
        value = value.strip("'\"")
        for pat, repl in self.like_escapes:
            value = value.replace(pat, repl)
        return value

    def attr_startswith(self, tos, arg):
        """SQL for tos.startswith(arg)."""
        return self.get_expr(tos.sql + " LIKE '" + self.escape_like(arg.sql) + "%'", bool)

    def attr_endswith(self, tos, arg):
        """SQL for tos.endswith(arg)."""
        return self.get_expr(tos.sql + " LIKE '%" + self.escape_like(arg.sql) + "'", bool)

    def containedby(self, op1, op2):
        """SQL for "op1 in op2".

        If op1 is a constant, this is a substring search in op2 (LIKE);
        otherwise op1 is tested against each item of op2.value (IN).
        An empty op2.value yields the cached FALSE expression, which the
        caller must not modify.
        """
        if op1.value is not None:
            # Looking for text in a field. Use Like (reverse terms).
            like = self.escape_like(op1.sql)
            return self.get_expr(op2.sql + " LIKE '%" + like + "%'", bool)
        else:
            # Looking for field in (a, b, c)
            atoms = []
            for x in op2.value:
                adapter = self.adapterset.default(type(x), op1.dbtype)
                atoms.append(adapter.push(x, op1.dbtype))
            if atoms:
                return self.get_expr(op1.sql + " IN (" + ", ".join(atoms) + ")", bool)
            else:
                # Nothing will match the empty list, so return none.
                return self.false_expr

    def builtins_icontainedby(self, op1, op2):
        """SQL for a case-insensitive "op1 in op2" (see containedby)."""
        if op1.value is not None:
            # Looking for text in a field. Use Like (reverse terms).
            return self.get_expr("LOWER(" + op2.sql + ") LIKE '%" +
                                 self.escape_like(op1.sql).lower()
                                 + "%'", bool)
        else:
            # Looking for field in (a, b, c).
            # Force all args to lowercase for case-insensitive comparison.
            atoms = []
            for x in op2.value:
                adapter = self.adapterset.default(type(x), op1.dbtype)
                atoms.append(adapter.push(x.lower(), op1.dbtype))
            if not atoms:
                # "IN ()" is not valid SQL, and nothing can match the
                # empty list; behave like containedby().
                return self.false_expr
            return self.get_expr("LOWER(%s) IN (%s)" %
                                 (op1.sql, ", ".join(atoms)), bool)

    def builtins_icontains(self, x, y):
        """SQL for a case-insensitive "y in x"."""
        return self.builtins_icontainedby(y, x)

    def builtins_istartswith(self, x, y):
        """SQL for a case-insensitive x.startswith(y)."""
        # The column is lower-cased, so the pattern must be as well or a
        # pattern containing upper-case characters could never match.
        return self.get_expr("LOWER(" + x.sql + ") LIKE '" +
                             self.escape_like(y.sql).lower() + "%'", bool)

    def builtins_iendswith(self, x, y):
        """SQL for a case-insensitive x.endswith(y)."""
        # Lower-case the pattern to match the lower-cased column.
        return self.get_expr("LOWER(" + x.sql + ") LIKE '%" +
                             self.escape_like(y.sql).lower() + "'", bool)

    def builtins_ieq(self, x, y):
        """SQL for a case-insensitive x == y."""
        return self.get_expr("LOWER(" + x.sql + ") = LOWER(" + y.sql + ")", bool)

    def builtins_now(self):
        """Return a datetime.datetime for the current time in the local TZ."""
        return self.get_expr("NOW()", datetime.datetime)

    def builtins_utcnow(self):
        """Return a datetime.datetime for the current time in the UTC TZ."""
        # No SQL equivalent is provided here.
        return cannot_represent

    def builtins_today(self):
        """Return a datetime.date expression for the current date."""
        return self.get_expr("CURRENT_DATE", datetime.date)

    def builtins_year(self, x):
        """SQL for the year of x."""
        return self.get_expr("YEAR(" + x.sql + ")", int)

    def builtins_month(self, x):
        """SQL for the month of x."""
        return self.get_expr("MONTH(" + x.sql + ")", int)

    def builtins_day(self, x):
        """SQL for the day of x."""
        return self.get_expr("DAY(" + x.sql + ")", int)

    def func__builtin___len(self, x):
        """SQL for len(x)."""
        return self.get_expr("LENGTH(" + x.sql + ")", int)

    def func__builtin___min(self, x):
        """SQL for min(x); x itself is modified and returned."""
        x.aggregate = True
        x.name = "min_%s" % x.name
        x.sql = "MIN(" + x.sql + ")"
        return x

    def func__builtin___max(self, x):
        """SQL for max(x); x itself is modified and returned."""
        x.aggregate = True
        x.name = "max_%s" % x.name
        x.sql = "MAX(" + x.sql + ")"
        return x

    def builtins_count(self, x):
        """SQL for count(x), as a new aggregate expression."""
        e = self.get_expr("COUNT(" + x.sql + ")", int)
        e.aggregate = True
        return e

    #                           Binary operations                         #

    # Resultant type for a binary operation between two types.
    # Keys are (left pytype, operator, right pytype); values are the
    # resulting pytype. Filled in at module level after the class body.
    result_type = {}

    def binary_op(self, op):
        """Replace the top two stack entries with "op1 <op> op2".

        The SQL is produced by op1's adapter; if it raises TypeError,
        cannot_represent is pushed instead. The left operand object is
        re-used (modified in place) for the result.
        """
        op2, op1 = self.stack.pop(), self.stack.pop()
        if op1 is cannot_represent or op2 is cannot_represent:
            self.stack.append(cannot_represent)
            return

        try:
            newsql = op1.adapter.binary_op(op1, op, self.sql_bin_op[op], op2)
        except TypeError:
            self.stack.append(cannot_represent)
            return

        newpytype = self.result_type[(op1.pytype, op, op2.pytype)]

        # re-use op1
        op1.sql = newsql
        if newpytype != op1.pytype:
            # The operation changed the Python type, so the database type
            # and adapter must be looked up again.
            op1.pytype = newpytype
            op1.dbtype = self.adapterset.database_type(newpytype)
            op1.adapter = self.adapterset.default(newpytype, op1.dbtype)
        if not op1.name.startswith("expr_"):
            op1.name = "expr_%s" % op1.name
        self.stack.append(op1)
538
# Give SQLDecompiler one visit_<NAME> method for every entry in
# codewalk.binary_repr. Each method forwards to binary_op with that
# entry's operator symbol, bound as a default argument so that every
# lambda keeps its own symbol instead of the loop's final one.
for opcode_name, sql_symbol in codewalk.binary_repr.iteritems():
    method_name = "visit_" + opcode_name
    setattr(SQLDecompiler, method_name,
            lambda self, op=sql_symbol: self.binary_op(op))
# Do not leave the loop variables behind in the module namespace.
del opcode_name, sql_symbol, method_name
544
545 def _binary_operation_result_types():
546     """Return a dict of (type(A), op, type(B)): type(op(A, B)) for known types."""
547     results = {}
548    
549     knowntypes = [3, 3L, 3.0, 'a', u'b']
550     try:
551         import datetime
552         knowntypes.extend([datetime.date(2004, 1, 1),
553                            datetime.datetime(2004, 1, 31),
554                            datetime.timedelta(3)])
555     except ImportError:
556         pass
557    
558     try:
559         import decimal
560         knowntypes.append(decimal.Decimal(3))
561     except ImportError:
562         pass
563    
564     ops = [(symbol, codewalk.binary_operators[name])
565            for name, symbol in codewalk.binary_repr.iteritems()]
566    
567     for A in knowntypes:
568         for B in knowntypes:
569             for symbol, op in ops:
570                 try:
571                     result = op(A, B)
572                 except TypeError:
573                     pass
574                 else:
575                     results[(type(A), symbol, type(B))] = type(result)
576    
577     return results
578 SQLDecompiler.result_type = _binary_operation_result_types()
579
Note: See TracBrowser for help on using the browser.