Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

root/trunk/geniusql/sqlwriters.py

Revision 282 (checked in by lakin, 2 years ago)

Fix inserts that don't specify any values. Most databases accept INSERT INTO <table> DEFAULT VALUES; MySQL instead needs a custom statement, INSERT INTO <table> VALUES ();

  • Property svn:eol-style set to native
Line 
1 from types import FunctionType
2
3 import geniusql
4 from geniusql import errors, Join, logic
5
6
# Public API of this module: statement classes and their writer delegates.
__all__ = [
    'TableWrapper',
    'SQLStatement',
    'SQLWriter',
    'SELECT',
    'SelectWriter',
    'UPDATE',
    'UpdateWriter',
    'DELETE',
    'DeleteWriter',
    'INSERT',
    'InsertWriter',
]
11
12
class TableWrapper(object):
    """Wrap a Table so that join parsing can attach per-use metadata.

    The wrapper (not the table) carries the alias, so the same Table can
    appear several times in one join under different aliases.
    """

    def __init__(self, table):
        self.table = table
        # Quoted table name, copied from the wrapped table.
        self.qname = table.qname
        # *Quoted* alias for this occurrence; "" means "no alias".
        self.alias = ""
21
22
23 # ------------------ Writer and statement base classes ------------------ #
24
25
class SQLStatement(object):
    """Base class for DB-specific SQL statements built by an SQLWriter.

    Subclasses set a per-instance ``input`` in their constructor and
    replace ``sql`` with a property that renders the statement.
    """
    # Rendered SQL text; overridden by a property in subclasses.
    sql = ""
    # Statement-specific column data; subclasses assign a list or dict.
    input = None
    # SQL text of the FROM target ("" when there is none).
    fromclause = ""
    # SQL text of the WHERE condition ("" when there is none).
    whereclause = ""
    # Set to True by writers when the deparser reports that part of the
    # expression could not be translated to SQL.
    imperfect = False
32
33
class SQLWriter(object):
    """Database delegate (base class) for writing SQL from a set of objects.

    db: a Database instance. Its ``joinwrapper``, ``deparser``, ``typeset``
        and ``quote`` members are used here.
    query: a Query instance. Its ``relation``, ``restriction`` and
        ``attributes`` are read.

    Construction builds ``self.statement`` (a ``statement_class`` instance)
    by processing, in order, the relation (FROM clause), the restriction
    (WHERE clause) and the attributes (``unpack_attributes``, which
    subclasses must implement).
    """

    statement_class = SQLStatement

    def __init__(self, db, query):
        self.db = db
        self.query = query

        # Table names already placed in the join tree; a repeated table
        # receives a generated alias ("t1", "t2", ...) in _wrap_node.
        self.seen = {}
        self.aliascount = 0

        self.statement = self.statement_class()

        self.process_relation()
        self.process_where()
        self.unpack_attributes()

    def process_relation(self):
        """Set self.tables, self.tablenames and the statement's FROM clause."""
        # Create a new join tree where each table is wrapped.
        # Then we can tag the wrappers with "alias" metadata with impunity.
        relation = self.query.relation
        if isinstance(relation, Join):
            self.tables = self.wrap(relation)
            self.statement.fromclause = self.joinclause(self.tables)
        elif isinstance(relation, geniusql.Schema):
            # This is how we say we want to SELECT scalars (no FROM clause)
            self.tables = []
            self.statement.fromclause = ""
        else:
            self.tables = [self.db.joinwrapper(relation)]
            self.statement.fromclause = relation.qname
        # (alias-or-quoted-name, Table) pairs, handed to the deparser.
        self.tablenames = [(t.alias or t.qname, t.table) for t in self.tables]

    def process_where(self):
        """Set the statement's WHERE clause (and its 'imperfect' flag).

        The flag is raised when the deparser could not translate the whole
        restriction into SQL.
        """
        dep = self.db.deparser(self.tablenames, self.query.restriction,
                               self.db.typeset)
        self.statement.whereclause = dep.code()
        if dep.imperfect:
            self.statement.imperfect = True

    def wrap(self, join):
        """Return a copy of the given Join with each table node wrapped."""
        wt1 = self._wrap_node(join.table1)
        wt2 = self._wrap_node(join.table2)

        newjoin = Join(wt1, wt2, join.leftbiased)
        # if the original Join had a custom reference path,
        # copy it to the new Join instance
        newjoin.path = join.path
        return newjoin

    def _wrap_node(self, node):
        """Wrap one side of a Join: recurse into Joins, wrap (and alias) Tables."""
        if isinstance(node, Join):
            return self.wrap(node)

        wrapper = self.db.joinwrapper(node)
        if node.name in self.seen:
            # Same table appears again in this join: give it a unique alias.
            self.aliascount += 1
            alias = "t%d" % self.aliascount
            wrapper.alias = self.db.quote(node.schema.table_name(alias))
        else:
            self.seen[node.name] = None
        return wrapper

    def joinname(self, tablewrapper):
        """Quoted table name for use in JOIN clause (with 'AS alias' if any)."""
        if tablewrapper.alias:
            return "%s AS %s" % (tablewrapper.qname, tablewrapper.alias)
        else:
            return tablewrapper.qname

    def onclause(self, A, B, path=None):
        """Return 'A.x = B.y' for tables A and B (or None).

        The returned value (if not None) is suitable for use in the 'ON'
        portion of an SQL JOIN clause.

        path: either a logic.Expression (deparsed directly into SQL) or a
            key into A.table.references. Defaults to B's key in its schema.
        """
        if path is None:
            path = B.table.schema.key_for(B.table)

        if isinstance(path, logic.Expression):
            dep = self.db.deparser(self.tablenames, path, self.db.typeset)
            dep.walk()
            atom = dep.stack[0]
            if dep.imperfect:
                self.statement.imperfect = True
            return atom.sql

        ref = A.table.references.get(path, None)
        if ref:
            nearkey, _, farkey = ref
            near = '%s.%s' % (A.alias or A.qname, A.table[nearkey].qname)
            far = '%s.%s' % (B.alias or B.qname, B.table[farkey].qname)
            return "%s = %s" % (near, far)
        return None

    def joinclause(self, join):
        """Return an SQL FROM clause for the given (wrapped) Join.

        Raises errors.ReferenceError if no reference (in either direction)
        exists between any table of the left half and any of the right half.
        """
        name1, tlist1 = self._join_side(join.table1)
        name2, tlist2 = self._join_side(join.table2)

        j = {None: "INNER", True: "LEFT", False: "RIGHT"}[join.leftbiased]

        # Find a reference between the two halves.
        for A in tlist1:
            for B in tlist2:
                on = (self.onclause(A, B, join.path) or
                      self.onclause(B, A, join.path))
                if on:
                    return "(%s %s JOIN %s ON %s)" % (name1, j, name2, on)

        raise errors.ReferenceError("No reference found between %s and %s."
                                    % (name1, name2))

    def _join_side(self, node):
        """Return (SQL text, list of table wrappers) for one side of a Join."""
        if isinstance(node, Join):
            # Must be a list, not an iterator: the right-hand side is
            # re-scanned once per left-hand table, and an iterator would be
            # exhausted after the first pass.
            return self.joinclause(node), list(node)
        # node is a Table class wrapper.
        return self.joinname(node), [node]

    def unpack_attributes(self):
        """Translate query.attributes into the statement; subclass hook."""
        raise NotImplementedError
182
183
184
185 # -------------------------- SELECT statements -------------------------- #
186
187
class SELECT(SQLStatement):
    """A SELECT SQL statement. Usually produced by an SQLWriter.

    input: a list of SQL expressions, one for each column in the
        SELECT clause. These will include any "expr AS name" alias.
    output: a list of tuples of the form:
        (column key,
         SQL name (or alias),
         quoted SQL name (or alias),
         source Column object)
        One per output column.
    groupby: a list of SQL expressions for the GROUP BY clause.
    orderby: a list of SQL expressions for ORDER BY, or None.
    distinct / limit / offset / into: optional modifiers; ``into`` is the
        target of a SELECT ... INTO clause ("" for none).
    """
    distinct = False
    into = ""
    limit = None
    offset = None
    orderby = None
    groupby = None
    output = None

    def __init__(self):
        # Per-instance lists so statements never share mutable state.
        self.input, self.output, self.groupby = [], [], []

    def _get_sql(self):
        """Return an SQL SELECT statement."""
        words = ["SELECT"]
        if self.distinct:
            words.append('DISTINCT')
        words.append(', '.join(self.input))
        if self.into:
            words += ["INTO", self.into]
        if self.fromclause:
            words += ["FROM", self.fromclause]
            # WHERE is only emitted together with a FROM clause.
            if self.whereclause:
                words += ["WHERE", self.whereclause]
        # GROUP BY only when groupby is non-empty and shorter than the
        # select list.
        grouped = self.groupby
        if grouped and len(grouped) < len(self.input):
            words += ["GROUP BY", ", ".join(grouped)]
        if self.orderby:
            words += ["ORDER BY", ", ".join(self.orderby)]
        if self.limit is not None:
            words.append("LIMIT %d" % self.limit)
        if self.offset is not None:
            words.append("OFFSET %d" % self.offset)
        return " ".join(words)
    sql = property(_get_sql, doc="The SQL string for this SELECT statement.")

    def result_table(self, schema, name):
        """Return a new Table object for the result of this SELECT.

        This is too expensive to do when you don't need it, so it's
        a separate function here. Try not to call it more than once
        for a given SelectWriter instance.
        """
        newtable = schema.table(name)
        for colkey, colname, colqname, source in self.output:
            newcol = source.copy()
            newcol.name = colname
            newcol.qname = colqname
            # Reset key/identity settings on the copied column.
            newcol.key = False
            newcol.autoincrement = False
            newcol.sequence_name = None
            newcol.initial = 1
            newtable[colkey] = newcol
        return newtable
260
261
class SelectWriter(SQLWriter):
    """Database delegate for writing SELECT statements.

    db: a Database instance.
    statement: pass in an instance of geniusql.Statement (which is DB-agnostic)
        and it will be transformed into a DB-specific SQLStatement.
    into: optional target for a SELECT ... INTO clause ("" for none).
    """

    statement_class = SELECT

    def __init__(self, db, statement, into=""):
        # Output column key -> source Column (or deparsed atom). Must exist
        # before SQLWriter.__init__, which calls unpack_attributes.
        self.output_cols = {}

        SQLWriter.__init__(self, db, statement.query)

        # Yes, we're trading one statement object for another here,
        # but the former is a DB-agnostic Statement and the latter
        # is a DB-specific SQLStatement.
        self.statement.distinct = statement.distinct
        self.statement.limit = statement.limit
        self.statement.offset = statement.offset
        self.statement.into = into

        self._process_order(statement)

    def _process_order(self, statement):
        """Set self.statement.orderby from the DB-agnostic statement.order.

        Raises TypeError if order is a bare string, and ValueError if order
        is a list of keys while selecting from a Join.
        """
        order = statement.order
        if order is None:
            self.statement.orderby = None
        elif isinstance(order, FunctionType):
            self.deparse_order(logic.Expression(order))
        elif isinstance(order, logic.Expression):
            self.deparse_order(order)
        elif isinstance(order, basestring):
            raise TypeError("The 'order' value %r is not one of the allowed "
                            "types (list, lambda, None, or Expression)." %
                            order)
        elif isinstance(statement.query.relation, Join):
            raise ValueError("order must be an Expression when "
                             "selecting from multiple tables.")
        else:
            # 'relation' is a single Table object.
            relation = statement.query.relation
            ob = []
            for key in order:
                # Handle embedded "ASC"/"DESC" atoms
                atoms = key.rsplit(" ", 1)
                key = relation[atoms.pop(0)].qname
                if atoms:
                    key += " " + atoms[0]
                ob.append(key)
            self.statement.orderby = ob

    def unpack_attributes(self):
        """Fill the statement's SELECT list (input) and output metadata."""
        if isinstance(self.query.attributes, logic.Expression):
            self.deparse_attributes()
            return

        if isinstance(self.query.relation, Join):
            # query.attributes holds one list of column keys per joined table.
            for t, attrs in zip(self.tables, self.query.attributes):
                # Add columns from the given table to our result table.
                alias = t.alias or t.qname
                table = t.table
                for colkey in attrs:
                    col = table[colkey]
                    if colkey in self.output_cols:
                        # Clash with an already-selected column: prefix with
                        # the table key/name, and if that is taken too (e.g.
                        # a table joined three or more times) add a numeric
                        # suffix so keys and SQL aliases stay unique.
                        bare_key = '%s_%s' % (table.schema.key_for(table),
                                              colkey)
                        bare_name = '%s_%s' % (table.name, col.name)
                        colkey, colname = bare_key, bare_name
                        index = 1
                        while colkey in self.output_cols:
                            colkey = '%s%d' % (bare_key, index)
                            colname = '%s%d' % (bare_name, index)
                            index += 1
                        colqname = self.db.quote(colname)
                        selname = '%s.%s AS %s' % (alias, col.qname, colqname)
                    else:
                        colname = col.name
                        colqname = col.qname
                        selname = '%s.%s' % (alias, colqname)
                    self.statement.input.append(selname)
                    self.statement.output.append((colkey, colname, colqname, col))
                    self.output_cols[colkey] = col
        else:
            # 'relation' is a single Table object.
            for colkey in self.query.attributes:
                col = self.query.relation[colkey]
                self.statement.input.append(col.qname)
                self.statement.output.append((colkey, col.name, col.qname, col))
                self.output_cols[colkey] = col

    def deparse_attributes(self):
        """Deparse an attribute Expression into SELECT items (and GROUP BY)."""
        dep = self.db.deparser(self.tablenames, self.query.attributes,
                               self.db.typeset)

        for atom in dep.field_list():
            if atom.name in self.output_cols:
                # Make the output name unique with a numeric suffix.
                bare_name = atom.name
                index = 1
                while atom.name in self.output_cols:
                    atom.name = '%s%s' % (bare_name, index)
                    index += 1

            qname = self.db.quote(atom.name)
            self.statement.input.append('%s AS %s' % (atom.sql, qname))
            self.statement.output.append((atom.name, atom.name, qname, atom))
            # Non-aggregate expressions become GROUP BY terms.
            if not atom.aggregate:
                self.statement.groupby.append(atom.sql)

            self.output_cols[atom.name] = atom

    def deparse_order(self, order):
        """Set self.statement.orderby from an order Expression."""
        dep = self.db.deparser(self.tablenames, order, self.db.typeset)
        self.statement.orderby = [atom.sql for atom in dep.field_list()]
371
372
373
374 # -------------------------- UPDATE statements -------------------------- #
375
376
class UPDATE(SQLStatement):
    """An UPDATE SQL statement. Usually produced by an UpdateWriter.

    input: a dict of SQL expressions, one for each column in the
        SET clause. Keys will be quoted column names and values
        will be the new values (in SQL syntax) for the columns.
    """

    def __init__(self):
        self.input = {}

    def _get_sql(self):
        """Return an SQL UPDATE statement."""
        assignments = ', '.join(["%s = %s" % pair
                                 for pair in self.input.iteritems()])
        sql = "UPDATE " + self.fromclause + " SET " + assignments
        if self.whereclause:
            sql += " WHERE " + self.whereclause
        return sql
    sql = property(_get_sql, doc="The SQL string for this UPDATE statement.")
398
399
class UpdateWriter(SQLWriter):
    """Database delegate for writing UPDATE statements.

    db: a Database instance.
    query: a Query instance. Its ``attributes`` map column keys to new
        values: plain Python values, lambdas, or logic.Expression objects.
    """

    statement_class = UPDATE

    def unpack_attributes(self):
        """Fill self.statement.input with {quoted column: SQL value}."""
        if isinstance(self.query.relation, Join):
            # query.attributes holds one dict of {colkey: value} per table.
            for t, attrs in zip(self.tables, self.query.attributes):
                # t is a table wrapper exposing .alias/.qname/.table (not a
                # 2-tuple), so read those attributes instead of unpacking it.
                alias = t.alias or t.qname
                table = t.table
                for colkey, val in attrs.iteritems():
                    col = table[colkey]
                    fullkey = '%s.%s' % (alias, col.qname)
                    self.statement.input[fullkey] = self._sql_value(col, val)
        else:
            # 'relation' is a single Table object.
            for colkey, val in self.query.attributes.iteritems():
                col = self.query.relation[colkey]
                self.statement.input[col.qname] = self._sql_value(col, val)

    def _sql_value(self, col, val):
        """Return SQL for val: deparse lambdas/Expressions, else adapt via col."""
        if isinstance(val, FunctionType):
            val = logic.Expression(val)
        if isinstance(val, logic.Expression):
            return self.deparse_attribute(val)
        return col.adapter.push(val, col.dbtype)

    def deparse_attribute(self, value):
        """Return SQL for an attribute Expression.

        Raises ValueError if the deparser could not fully translate it.
        """
        dep = self.db.deparser(self.tablenames, value, self.db.typeset)
        code = dep.code()
        if dep.imperfect:
            raise ValueError("The given attribute expression could not be "
                             "safely translated to SQL.", value)
        return code
446
447
448
449 # -------------------------- DELETE statements -------------------------- #
450
451
class DELETE(SQLStatement):
    """A DELETE SQL statement. Usually produced by a DeleteWriter.

    input: a list of SQL expressions, one for each column in the DELETE clause.
    """

    def __init__(self):
        self.input = []

    def _get_sql(self):
        """Return an SQL DELETE statement."""
        sql = "DELETE " + ', '.join(self.input)
        if not self.fromclause:
            return sql
        sql += " FROM " + self.fromclause
        # WHERE is only emitted together with a FROM clause.
        if self.whereclause:
            sql += " WHERE " + self.whereclause
        return sql
    sql = property(_get_sql, doc="The SQL string for this DELETE statement.")
474
475
class DeleteWriter(SQLWriter):
    """Database delegate for writing DELETE statements.

    db: a Database instance.
    query: a Query instance.

    For now, query.attributes should be empty, since many databases do not
    allow any attribute list in DELETE statements.
    """

    statement_class = DELETE

    def unpack_attributes(self):
        """Do nothing: DELETE statements take no attribute list here."""
490
491
492
493 # -------------------------- INSERT statements -------------------------- #
494
495
class INSERT(SQLStatement):
    """An INSERT SQL statement. Usually produced by an InsertWriter.

    input: a dict of SQL expressions, one for each inserted column.
        Keys will be quoted column names and values will be the new
        values (in SQL syntax) for the columns.
    """

    def __init__(self):
        self.input = {}

    def _get_sql(self):
        """Return an SQL INSERT statement.

        With no input columns, emits 'INSERT INTO <table> DEFAULT VALUES';
        databases without that form (e.g. MySQL) need their own variant.
        """
        if not self.input:
            return " ".join(["INSERT INTO", self.fromclause, "DEFAULT VALUES"])
        # zip(*items) keeps each column paired with its own value.
        keys, values = zip(*self.input.items())
        return " ".join(["INSERT INTO", self.fromclause,
                         '(%s)' % ', '.join(keys), "VALUES",
                         '(%s)' % ', '.join(values)])
    sql = property(_get_sql, doc="The SQL string for this INSERT statement.")
520
521
class InsertWriter(SQLWriter):
    """Database delegate for writing INSERT statements.

    db: a Database instance.
    query: a Query instance. Its ``attributes`` map column keys to values:
        plain Python values, lambdas, or logic.Expression objects.
    """

    statement_class = INSERT

    def unpack_attributes(self):
        """Fill self.statement.input with {quoted column: SQL value}."""
        if isinstance(self.query.relation, Join):
            # query.attributes holds one dict of {colkey: value} per table.
            for t, attrs in zip(self.tables, self.query.attributes):
                # t is a table wrapper exposing .alias/.qname/.table (not a
                # 2-tuple), so read those attributes instead of unpacking it.
                alias = t.alias or t.qname
                table = t.table
                for colkey, val in attrs.iteritems():
                    col = table[colkey]
                    fullkey = '%s.%s' % (alias, col.qname)
                    self.statement.input[fullkey] = self._sql_value(col, val)
        else:
            # 'relation' is a single Table object.
            for colkey, val in self.query.attributes.iteritems():
                col = self.query.relation[colkey]
                self.statement.input[col.qname] = self._sql_value(col, val)

    def _sql_value(self, col, val):
        """Return SQL for val: deparse lambdas/Expressions, else adapt via col."""
        if isinstance(val, FunctionType):
            val = logic.Expression(val)
        if isinstance(val, logic.Expression):
            return self.deparse_attribute(val)
        return col.adapter.push(val, col.dbtype)

    def deparse_attribute(self, value):
        """Return SQL for an attribute Expression.

        Raises ValueError if the deparser could not fully translate it.
        """
        dep = self.db.deparser(self.tablenames, value, self.db.typeset)
        code = dep.code()
        if dep.imperfect:
            raise ValueError("The given attribute expression could not be "
                             "safely translated to SQL.", value)
        return code
Note: See TracBrowser for help on using the browser.