Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

root/trunk/geniusql/__init__.py

Revision 7 (checked in by fumanchu, 6 years ago)

Various bugfixes related to Table.created.

  • Property svn:eol-style set to native
Line 
"""Geniusql, a Python database library.

The Column and Index classes model corresponding database objects, and are
intentionally simple. They should rarely contain any SQL or "smarts" of
any kind, besides the "qname", the quoted name, of the column or index.
At most, subclasses and consumers might put implementation-specific data
into them.

The IndexSet, Table, and Database objects are all dict-like containers,
and therefore have a key for each value. Those keys should equate to things
at the consumer layer; for example, a Database may possess a pair of the
form: {'YoYo': Table('yoyo')} -- the key is the "friendly" name, but the
Table.name is a lowercase version of that, because that's what the database
uses in SQL to refer to that table.

Mutating these containers can issue DDL. Setting or deleting Database items
runs CREATE TABLE / DROP TABLE. Once a Table's 'created' flag is True
(Database.__setitem__ sets it), setting or deleting its items runs
ALTER TABLE, and setting or deleting IndexSet items runs CREATE INDEX /
DROP INDEX. Methods such as 'alias' and the discovery methods use the
plain dict methods instead, so they never issue DDL.
"""

__version__ = "1.0alpha"
18
19 import threading
20 from geniusql import xray
21
22 from geniusql import errors, typerefs
23
24 from geniusql.adapters import *
25 from geniusql.conn import *
26 from geniusql.isolation import *
27 from geniusql.select import *
28
29 from geniusql import providers
30
31
32 def db(cls, name, options):
33     """Create a Database model object for the given cls, name, and options.
34     
35     cls: Either a subclass of geniusql.Database or a 'shortcut name'
36         registered in geniusql.providers.providers.
37     name: the database name as used by the underlying database.
38     
39     This function does not call CREATE DATABASE, nor should it open any
40     database connections. It simply instantiates the proper Database class.
41     """
42     if isinstance(cls, basestring):
43         try:
44             cls = providers.providers[cls]
45         except KeyError:
46             pass
47    
48     if isinstance(cls, basestring):
49         cls = xray.classes(cls)
50    
51     opts = dict([(str(k), v) for k, v in options.iteritems()])
52     opts.pop('name', None)
53    
54     return cls(name, **opts)
55
56
class Index:
    """An index on a table column (or columns) in a database.

    name: the SQL name of the index (unquoted).
    qname: the SQL name of the index (quoted).
    tablename: the SQL name of the indexed table.
    colname: the SQL name of the indexed column.
    unique: a flag stored on the index (default True). Note that the
        CREATE INDEX statements built in this module do not emit UNIQUE.
    """

    def __init__(self, name, qname, tablename, colname, unique=True):
        self.name, self.qname = name, qname
        self.tablename, self.colname = tablename, colname
        self.unique = unique

    def __repr__(self):
        fields = (self.__module__, self.__class__.__name__,
                  self.name, self.qname, self.tablename, self.colname,
                  self.unique)
        return "%s.%s(%r, %r, %r, %r, unique=%r)" % fields

    def __copy__(self):
        """Return a new instance of this class with identical attributes."""
        klass = self.__class__
        return klass(self.name, self.qname, self.tablename,
                     self.colname, self.unique)
    copy = __copy__
77
78
class IndexSet(dict):
    """A dict of Index objects belonging to one Table.

    Keys are consumer-friendly names (Table.add_index uses the indexed
    column's key); values are Index objects. While the owning table's
    'created' flag is False this behaves like a plain dict. Once it is
    True, setting an item issues CREATE INDEX and deleting one issues
    DROP INDEX, each with the table's database locked.
    """

    def __new__(cls, table):
        return dict.__new__(cls)

    def __init__(self, table):
        dict.__init__(self)
        # The owning Table: supplies 'created', 'db' and 'qname' for DDL.
        self.table = table

    def alias(self, oldname, newname):
        """Add a new key for the Index with the given, existing key.

        Consumer code should call this method when user-supplied index
        names do not match the names in the database. This does not
        remove the old key; both keys may be used to refer to the same
        Index object. No SQL is issued (plain dict calls are used).
        """
        obj = self[oldname]
        if newname in self:
            dict.__delitem__(self, newname)
        dict.__setitem__(self, newname, obj)

    def __setitem__(self, key, index):
        """Store index under key, creating it in the DB if the table exists.

        When self.table.created is True, CREATE INDEX is issued on
        index.colname (quoted by the database) while the database is
        locked, before the index is stored in this dict.
        """
        t = self.table
        if t.created:
            t.db.lock("Creating index. Transactions not allowed.")
            try:
                t.db.execute('CREATE INDEX %s ON %s (%s);' %
                             (index.qname, t.qname,
                              t.db.quote(index.colname)))
            finally:
                t.db.unlock()
        dict.__setitem__(self, key, index)

    def __delitem__(self, key):
        """Remove the index under key, dropping it in the DB if the table exists.

        The Index is dropped from the database even if other (alias) keys
        still refer to the same Index object.
        """
        t = self.table
        if t.created:
            t.db.lock("Dropping index. Transactions not allowed.")
            try:
                # NOTE(review): 'DROP INDEX name ON table' is not accepted by
                # every DB; such DBs can substitute Table.indexsetclass.
                t.db.execute('DROP INDEX %s ON %s;' % (self[key].qname, t.qname))
            finally:
                t.db.unlock()
        dict.__delitem__(self, key)
124
125
class Column:
    """A column in a table in a database.

    name: the SQL name for this column (unquoted).
    qname: the SQL name for this column (quoted).
    pytype: the Python type (the actual type object, not its name).
    dbtype: the database type name (as used in a CREATE TABLE statement).
    default: default Python value for this column for new rows.
    hints: a dict of implementation hints, such as precision, scale, or bytes.
        The constructor stores a copy, never the caller's dict.
    key: True if this column is part of the table's primary key.

    imperfect_type: if True, signals that we are deliberately using a
        database type other than the default (usually in order to handle
        irregular values, such as huge numbers).
    autoincrement: if True, uses the database's built-in sequencing.
    sequence_name: for databases that use separate statements to create and
        drop sequences, this stores the name of the sequence.
    initial: if autoincrement, holds the initial value for the sequence.
    """

    def __init__(self, pytype, dbtype, default=None, hints=None, key=False,
                 name=None, qname=None):
        self.pytype = pytype
        self.dbtype = dbtype
        self.name = name
        self.qname = qname
        self.default = default
        # Copy so later mutation of the caller's dict cannot leak in.
        if hints is None:
            hints = {}
        else:
            hints = hints.copy()
        self.hints = hints
        self.key = key

        # If autoincrement, the initial value should be put in self.initial.
        self.autoincrement = False
        self.sequence_name = None
        self.initial = 1

        self.imperfect_type = False

    def __repr__(self):
        return ("%s.%s(%r, %r, default=%r, hints=%r, key=%r, name=%r, qname=%r)" %
                (self.__module__, self.__class__.__name__,
                 self.pytype, self.dbtype,
                 self.default, self.hints, self.key,
                 self.name, self.qname)
                )

    def __copy__(self):
        """Return an independent copy of this Column.

        Every attribute set in __init__ is carried over, including the
        sequencing bookkeeping (autoincrement, sequence_name, initial), so
        code that later reads column.sequence_name (e.g. a DB's
        drop_sequence) sees the same value on the copy. The hints dict is
        copied rather than shared.
        """
        newcol = self.__class__(self.pytype, self.dbtype,
                                self.default, self.hints, self.key,
                                self.name, self.qname)
        newcol.autoincrement = self.autoincrement
        newcol.sequence_name = self.sequence_name
        newcol.initial = self.initial
        newcol.imperfect_type = self.imperfect_type
        return newcol
    copy = __copy__
184
185
class Table(dict):
    """A table in a database; a dict of Column objects.

    Values in this dict must be instances of Column (or a subclass of it).
    Keys should be consumer-friendly names for each Column value.

    name: the SQL name for this table (unquoted).
    qname: the SQL name for this table (quoted).
    db: the Database for this table. It derives and quotes SQL names for
        new columns (even before the table is created) and executes DDL.
    created: if False (the default), adding, deleting, or renaming items
        only changes this dict. Once True (Database.__setitem__ sets it when
        it creates the table), appropriate ALTER TABLE commands are executed
        whenever a consumer adds or deletes items from the Table, or calls
        methods like 'rename'. Therefore, when building Table objects that
        mirror an existing database, populate them before setting 'created'.
    indices: a dict-like IndexSet of Index objects.
    references: a dict of the form: {name: (nearColKey, farTableKey, farColKey)}.
    """

    indexsetclass = IndexSet

    def __new__(cls, name, qname, db, created=False):
        return dict.__new__(cls)

    def __init__(self, name, qname, db, created=False):
        dict.__init__(self)

        self.name = name
        self.qname = qname
        self.db = db
        self.created = created

        self.indices = self.indexsetclass(self)
        self.references = {}

    def __repr__(self):
        # getattr fallbacks keep repr usable on a partially-built instance.
        name = getattr(self, "name", "<unknown>")
        qname = getattr(self, "qname", "<unknown>")
        return ("%s.%s(%r, %r)" %
                (self.__module__, self.__class__.__name__, name, qname))

    def __copy__(self):
        """Return a copy of this Table with copied Columns and Indices.

        The copy shares self.db and gets a shallow copy of self.references.
        Its 'created' flag is deliberately left at the constructor default,
        so changes to the copy issue no DDL until it is marked as created.
        """
        # Don't set 'created' when copying!
        newtable = self.__class__(self.name, self.qname, self.db)
        for key, c in self.iteritems():
            dict.__setitem__(newtable, key, c.copy())
        for key, i in self.indices.iteritems():
            dict.__setitem__(newtable.indices, key, i.copy())
        # Carry the declared references over too (Database.rename relies on
        # copy(), and the renamed table should keep them).
        newtable.references = self.references.copy()
        return newtable
    copy = __copy__

    def alias(self, oldname, newname):
        """Add a new key for the Column with the given, existing key.

        Consumer code should call this method when user-supplied column
        names do not match the names in the database. This does not
        remove the old key; both keys may be used to refer to the same
        Column object. No SQL is issued (plain dict calls are used).
        """
        obj = self[oldname]
        if newname in self:
            dict.__delitem__(self, newname)
        dict.__setitem__(self, newname, obj)

    def _add_column(self, column):
        """Internal function to add the column to the database."""
        coldef = self.db.columnclause(column)
        self.db.execute("ALTER TABLE %s ADD COLUMN %s;" % (self.qname, coldef))

    def __setitem__(self, key, column):
        """Store column under key, adding it to the DB table if created.

        A column without a name is given db._column_name(self.name, key)
        and a matching quoted qname, even before the table is created.
        Once created, any existing column under key is dropped first, then
        (with the DB locked) a sequence is created for autoincrement columns
        and ALTER TABLE ... ADD COLUMN is issued.
        """
        if column.name is None:
            column.name = self.db._column_name(self.name, key)
            column.qname = self.db.quote(column.name)

        if not self.created:
            dict.__setitem__(self, key, column)
            return

        if key in self:
            del self[key]

        self.db.lock("Adding property. Transactions not allowed.")
        try:
            if column.autoincrement:
                # This may or may not be a no-op, depending on the DB.
                self.db.create_sequence(self, column)
            self._add_column(column)
            dict.__setitem__(self, key, column)
        finally:
            self.db.unlock()

    def _drop_column(self, column):
        """Internal function to drop the column from the database."""
        self.db.execute("ALTER TABLE %s DROP COLUMN %s;" %
                        (self.qname, column.qname))

    def __delitem__(self, key):
        """Remove the column under key (and any index under the same key).

        Once created, ALTER TABLE ... DROP COLUMN is issued with the DB
        locked, followed by drop_sequence for autoincrement columns.
        """
        if key in self.indices:
            del self.indices[key]

        if not self.created:
            dict.__delitem__(self, key)
            return

        self.db.lock("Dropping property. Transactions not allowed.")
        try:
            column = self[key]
            self._drop_column(column)
            if column.autoincrement:
                # This may or may not be a no-op, depending on the DB.
                self.db.drop_sequence(column)
            dict.__delitem__(self, key)
        finally:
            self.db.unlock()

    def _rename(self, oldcol, newcol):
        # Override this to do the actual rename at the DB level.
        self.db.execute("ALTER TABLE %s RENAME COLUMN %s TO %s;" %
                        (self.qname, oldcol.qname, newcol.qname))

    def rename(self, oldkey, newkey):
        """Rename a Column, re-keying it from oldkey to newkey.

        If the table has been created and the SQL column name derived from
        newkey differs from the current one, the column is renamed in the
        database (via _rename, with the DB locked) and a renamed copy of the
        Column replaces the old object. Otherwise only the key changes.
        """
        oldcol = self[oldkey]

        if not self.created:
            dict.__delitem__(self, oldkey)
            dict.__setitem__(self, newkey, oldcol)
            return

        oldname = oldcol.name
        newname = self.db._column_name(self.name, newkey)

        # If the SQL name is unchanged there is nothing to do in the DB;
        # the existing Column just moves to its new key.
        newcol = oldcol
        if oldname != newname:
            newcol = oldcol.copy()
            newcol.name = newname
            newcol.qname = self.db.quote(newname)
            self.db.lock("Renaming property. Transactions not allowed.")
            try:
                self._rename(oldcol, newcol)
            finally:
                self.db.unlock()

        # Use the superclass calls to avoid DROP COLUMN/ADD COLUMN.
        dict.__delitem__(self, oldkey)
        dict.__setitem__(self, newkey, newcol)

    def add_index(self, columnkey):
        """Add and return a new Index for the given column key.

        The new Index object will possess the same key as the column.
        Its SQL name is "i" + table.name + column.name passed through
        db.table_name, so the DB's table-naming rules also apply to it.
        """
        colname = self[columnkey].name
        name = self.db.table_name("i" + self.name + colname)
        i = Index(name, self.db.quote(name), self.name, colname)
        self.indices[columnkey] = i
        return i

    def select_all(self, restriction=None):
        """Yield data dicts (column key -> value) matching the restriction."""
        attrs = self.keys()
        data = self.db.select(self, attrs, restriction)
        for row in data:
            row = dict(zip(attrs, row))
            if restriction and data.imperfect:
                # Run a dummy object through our restriction before yielding.
                if not restriction(ImperfectDummy(**row)):
                    continue
            yield row

    def select_one(self, restriction=None):
        """Return a single data dict matching the given restriction (or None)."""
        try:
            return self.select_all(restriction).next()
        except StopIteration:
            return None
362
363
class ImperfectDummy(object):
    """Stand-in object whose attributes are the given keyword arguments.

    Table.select_all builds one from a row dict and passes it to the
    restriction when the result set is flagged 'imperfect'.
    """
    def __init__(self, **kwargs):
        for attrname in kwargs:
            setattr(self, attrname, kwargs[attrname])
369
370
371 class Database(dict):
372     """A dict for managing a set of tables.
373     
374     Values in this dict must be instances of Table. Keys should be
375     consumer-friendly names for each Table value. For example, it's
376     easiest to use all lowercase table names in MySQL; however, a
377     geniusql consumer might want their code to use TitledNames to
378     refer to each table.
379     
380     When a consumer adds and deletes items from a Database object,
381     appropriate CREATE TABLE/DROP TABLE commands are executed.
382     This means that a Table object to be added should have all
383     of its columns populated before adding it to the Database.
384     """
385    
386     adaptertosql = AdapterToSQL()
387     adapterfromdb = AdapterFromDB()
388     typeadapter = TypeAdapter()
389    
390     decompiler = SQLDecompiler
391     joinwrapper = TableWrapper
392    
393     selectwriter = SelectWriter
394     tableclass = Table
395    
396     def __new__(cls, name, **kwargs):
397         return dict.__new__(cls)
398    
399     def __init__(self, name, **kwargs):
400         self._discover_lock = threading.Lock()
401        
402         dict.__init__(self)
403         for k, v in kwargs.iteritems():
404             setattr(self, k, v)
405        
406         self.name = self.sql_name(name)
407         self.qname = self.quote(self.name)
408         self.transactions = {}
409         self.connect()
410         self.discover_dbinfo()
411    
412     def __repr__(self):
413         name = getattr(self, "name", "<unknown>")
414         return "%s.%s(%r)" % (self.__module__, self.__class__.__name__, name)
415    
416     def version(self):
417         """Return a string containing version info for this database."""
418         raise NotImplementedError
419    
420     def log(self, msg):
421         pass
422    
423    
424     #                              Discovery                              #
425    
426     def _get_dbinfo(self, conn=None):
427         return {}
428    
429     def discover_dbinfo(self, conn=None):
430         """Set attributes on self with actual DB metadata, where possible."""
431         for k, v in self._get_dbinfo().iteritems():
432             setattr(self, k, v)
433    
434     def _get_tables(self, conn=None):
435         raise NotImplementedError
436    
437     def _get_table(self, tablename, conn=None):
438         # Fallback behavior. This is slow and should be optimized by each DB.
439         for t in self._get_tables():
440             if t.name == tablename:
441                 return t
442         raise errors.MappingError(tablename)
443    
444     def _get_columns(self, tablename, conn=None):
445         raise NotImplementedError
446    
447     def _get_indices(self, tablename, conn=None):
448         raise NotImplementedError
449    
450     def _discover_table(self, table, conn=None):
451         """Populate the columns and indices of the given Table object."""
452         for col in self._get_columns(table.name, conn):
453             # Use the superclass call to avoid ALTER TABLE
454             if col.name in table:
455                 dict.__delitem__(table, col.name)
456             dict.__setitem__(table, col.name, col)
457        
458         for idx in self._get_indices(table.name, conn):
459             # Use the superclass call to avoid CREATE INDEX
460             if idx.name in table.indices:
461                 dict.__delitem__(table.indices, idx.name)
462             dict.__setitem__(table.indices, idx.name, idx)
463    
464     def discover(self, tablename, conn=None):
465         """Attach a new Table from the underlying DB to self (and return it).
466         
467         Table objects (and their Column and Index subobjects) will be
468         added to self using keys that match the database's names.
469         Consumers should call the "alias(oldname, newname)" method
470         of Database, Table, and IndexSet in order to re-map the
471         discovered objects using consumer-friendly names.
472         
473         If no such table exists, a MappingError should be raised.
474         """
475         self._discover_lock.acquire()
476         try:
477             table = self._get_table(tablename)
478            
479             self._discover_table(table, conn)
480            
481             # Use the superclass calls to avoid CREATE TABLE
482             if table.name in self:
483                 dict.__delitem__(self, table.name)
484             dict.__setitem__(self, table.name, table)
485            
486             return table
487         finally:
488             self._discover_lock.release()
489    
490     def discover_all(self, conn=None):
491         """(Re-)populate self (all table items) from the underlying DB.
492         
493         Table objects (and their Column and Index subobjects) will be
494         added to self using keys that match the database's names.
495         Consumers should call the "alias(oldname, newname)" method
496         of Database, Table, and IndexSet in order to re-map the
497         discovered objects using consumer-friendly names.
498         
499         This method is idempotent, but that doesn't mean cheap. Try not
500         to call it very often (once at app startup is usually enough).
501         If you already know the names of all the tables you want to
502         discover, it's often faster to skip this method and just use
503         the discover(tablename) method for each known name instead.
504         """
505         self._discover_lock.acquire()
506         try:
507             for table in self._get_tables(conn):
508                 self._discover_table(table, conn)
509                
510                 # Use the superclass calls to avoid CREATE TABLE
511                 if table.name in self:
512                     dict.__delitem__(self, table.name)
513                 dict.__setitem__(self, table.name, table)
514         finally:
515             self._discover_lock.release()
516    
517     def alias(self, oldname, newname):
518         """Add a new key for the Table with the given, existing key.
519         
520         Consumer code should call this method when user-supplied table
521         names do not match the names in the database. This does not
522         remove the old key; both keys may be used to refer to the same
523         Table object.
524         """
525         obj = self[oldname]
526         if newname in self:
527             dict.__delitem__(self, newname)
528         dict.__setitem__(self, newname, obj)
529    
530     def python_type(self, dbtype):
531         """Return a Python type which can store values of the given dbtype."""
532         raise TypeError("Database type %r could not be converted "
533                         "to a Python type." % dbtype)
534    
535     def isrelatedtype(self, pytype1, pytype2):
536         """If values of both types are expressed with the same SQL, return True."""
537         if issubclass(pytype1, pytype2) or issubclass(pytype2, pytype1):
538             return True
539         if issubclass(pytype1, basestring) and issubclass(pytype2, basestring):
540             return True
541         if ((issubclass(pytype1, int) or issubclass(pytype1, long)) and
542             (issubclass(pytype2, int) or issubclass(pytype2, long))):
543             return True
544         if typerefs.fixedpoint:
545             if typerefs.decimal:
546                 if ((issubclass(pytype1, typerefs.fixedpoint.FixedPoint)
547                      or issubclass(pytype1, typerefs.decimal.Decimal)) and
548                     (issubclass(pytype2, typerefs.fixedpoint.FixedPoint)
549                      or issubclass(pytype2, typerefs.decimal.Decimal))):
550                     return True
551             else:
552                 if (issubclass(pytype1, typerefs.fixedpoint.FixedPoint) and
553                     issubclass(pytype2, typerefs.fixedpoint.FixedPoint)):
554                     return True
555         else:
556             if typerefs.decimal:
557                 if (issubclass(pytype1, typerefs.decimal.Decimal) and
558                     issubclass(pytype2, typerefs.decimal.Decimal)):
559                     return True
560         return False
561    
562    
563     #                              Container                              #
564    
565     def columnclause(self, column):
566         """Return a clause for the given column for CREATE or ALTER TABLE.
567         
568         This will be of the form "name type [DEFAULT x]".
569         
570         Most subclasses will override this to add autoincrement support.
571         """
572         dbtype = column.dbtype
573        
574         default = column.default or ""
575         if default:
576             default = self.adaptertosql.coerce(default, dbtype)
577             default = " DEFAULT %s" % default
578        
579         return "%s %s%s" % (column.qname, dbtype, default)
580    
581     def create_sequence(self, table, column):
582         """Create a SEQUENCE for the given column and set its sequence_name."""
583         # By default, this does nothing. Databases which require a separate
584         # statement to create a sequence generator should override this.
585         pass
586    
587     def drop_sequence(self, column):
588         """Drop a SEQUENCE for the given column and remove its sequence_name."""
589         # By default, this does nothing. Databases which require a separate
590         # statement to drop a sequence generator should override this.
591         pass
592    
593     def __setitem__(self, key, table):
594         if key in self:
595             del self[key]
596        
597         # Set table.created to True, which should "turn on"
598         # any future ALTER TABLE statements.
599         table.created = True
600        
601         self.lock("Creating storage. Transactions not allowed.")
602         try:
603             fields = []
604             pk = []
605             for column in table.itervalues():
606                 if column.autoincrement:
607                     # This may or may not be a no-op, depending on the DB.
608                     self.create_sequence(table, column)
609                
610                 fields.append(self.columnclause(column))
611                 if column.key:
612                     pk.append(column.qname)
613            
614             if pk:
615                 pk = ", PRIMARY KEY (%s)" % ", ".join(pk)
616             else:
617                 pk = ""
618            
619             self.execute('CREATE TABLE %s (%s%s);' %
620                          (table.qname, ", ".join(fields), pk))
621            
622             for index in table.indices.itervalues():
623                 self.execute('CREATE INDEX %s ON %s (%s);' %
624                              (index.qname, table.qname,
625                               self.quote(index.colname)))
626             dict.__setitem__(self, key, table)
627         finally:
628             self.unlock()
629    
630     def __delitem__(self, key):
631         self.lock("Dropping storage. Transactions not allowed.")
632         try:
633             table = self[key]
634             self.execute('DROP TABLE %s;' % table.qname)
635             for col in table.itervalues():
636                 if col.autoincrement:
637                     self.drop_sequence(col)
638             dict.__delitem__(self, key)
639         finally:
640             self.unlock()
641    
642     def _rename(self, oldtable, newtable):
643         # Override this to do the actual rename at the DB level.
644         raise NotImplementedError
645         newtable.created = True
646    
647     def rename(self, oldkey, newkey):
648         """Rename a Table."""
649         oldtable = self[oldkey]
650         oldname = oldtable.name
651         newname = self.table_name(newkey)
652        
653         if oldname != newname:
654             newtable = oldtable.copy()
655             newtable.db = self
656             newtable.name = newname
657             newtable.qname = self.quote(newname)
658             self.lock("Renaming storage. Transactions not allowed.")
659             try:
660                 self._rename(oldtable, newname)
661             finally:
662                 self.unlock()
663        
664         # Use the superclass calls to avoid DROP TABLE/CREATE TABLE.
665         dict.__delitem__(self, oldkey)
666         dict.__setitem__(self, newkey, newtable)
667    
668     #                               Naming                               #
669    
670     sql_name_max_length = 64
671     sql_name_caseless = False
672     Prefix = ""
673    
674     def quote(self, name):
675         """Return name, quoted for use in an SQL statement."""
676         # This base class doesn't use "quote",
677         # but most subclasses will.
678         return name
679    
680     def sql_name(self, key):
681         """Return the native SQL version of key."""
682         if self.sql_name_caseless:
683             key = key.lower()
684        
685         maxlen = self.sql_name_max_length
686         if maxlen and len(key) > maxlen:
687             errors.warn("The name '%s' is longer than the maximum of "
688                         "%s characters." % (key, maxlen))
689             key = key[:maxlen]
690        
691         return key
692    
693     def _column_name(self, tablename, columnkey):
694         "Return the SQL column name for the given table name and column key."
695         # If you want to use a map from your ORM's property names
696         # to DB column names, override this method (that's why
697         # the tablename must be included in the args).
698         return self.sql_name(columnkey)
699    
700     def column(self, pytype=unicode, dbtype=None, default=None, hints=None,
701                key=False, autoincrement=False):
702         """Return a Column object from the given arguments."""
703         col = Column(pytype, dbtype, default, hints, key)
704         col.autoincrement = autoincrement
705        
706         if dbtype is None:
707             col.dbtype = self.typeadapter.coerce(col, pytype)
708         pytype2 = self.python_type(col.dbtype)
709         col.imperfect_type = not self.isrelatedtype(pytype, pytype2)
710        
711         return col
712    
713     def table_name(self, key):
714         """Return the SQL table name for the given key."""
715         # If you want to use a map from your ORM's class names
716         # to DB table names, override this method.
717         return self.sql_name(self.Prefix + key)
718    
719     def table(self, name):
720         """Create and return a Table object for the given name."""
721         name = self.table_name(name)
722         return self.tableclass(name, self.quote(name), self)
723    
724     #                             Connecting                              #
725    
726     poolsize = 10
727    
728     def connect(self):
729         if self.poolsize > 0:
730             self.connection = ConnectionPool(self._get_conn, self._del_conn,
731                                              self.poolsize)
732         else:
733             self.connection = ConnectionFactory(self._get_conn, self._del_conn)
734    
735     def _get_conn(self):
736         """Create and return a connection object."""
737         # Override this with the connection call for your DB. Example:
738         #     return libpq.PQconnectdb(self.connstring)
739         raise NotImplementedError
740    
741     def _del_conn(self, conn):
742         """Close a connection object."""
743         # Override this with the close call (if any) for your DB.
744         conn.close()
745    
746     def disconnect(self):
747         """Release all database connections."""
748         self.connection.shutdown()
749    
750     def execute(self, query, conn=None):
751         """Return a native response for the given query."""
752         if conn is None:
753             conn = self.connection()
754         if isinstance(query, unicode):
755             query = query.encode(self.adaptertosql.encoding)
756         self.log(query)
757         return conn.query(query)
758    
759     def fetch(self, query, conn=None):
760         """Return rowdata, columns (name, type) for the given query.
761         
762         query should be a SQL query in string format
763         rowdata will be an iterable of iterables containing the result values.
764         columns will be an iterable of (column name, data type) pairs.
765         
766         This base class uses _sqlite syntax.
767         """
768         res = self.execute(query, conn)
769         return res.row_list, res.col_defs
770    
771     def select(self, relation, attributes, restriction=None, distinct=False):
772         """Yield matching data, coerced to Python types (where known)."""
773         sel = self.selectwriter(self, relation, attributes, restriction)
774         data, _ = self.fetch(sel.sql(distinct), self.get_transaction())
775         return ResultSet(data, sel.columns, sel.imperfect)
776    
777     def create_database(self):
778         self.lock("Creating database. Transactions not allowed.")
779         try:
780             self.execute("CREATE DATABASE %s;" % self.qname)
781             self.clear()
782         finally:
783             self.unlock()
784    
785     def drop_database(self):
786         self.lock("Dropping database. Transactions not allowed.")
787         try:
788             # Must shut down all connections to avoid
789             # "being accessed by other users" error.
790             self.connection.shutdown()
791             self.execute("DROP DATABASE %s;" % self.qname)
792             self.clear()
793         finally:
794             self.unlock()
795    
796     #                            Transactions                             #
797    
798     transaction_key = threading._get_ident
799     implicit_trans = False
800    
801     # The "default_isolation" value should be a value native to the DB.
802     default_isolation = None
803    
804     # The values in "isolation_levels" should match the names of
805     # IsolationLevel objects in isolation.py
806     isolation_levels = ["READ UNCOMMITTED", "READ COMMITTED",
807                         "REPEATABLE READ", "SERIALIZABLE"]
808    
809     def get_transaction(self, new=False, isolation=None):
810         """Return the (possibly new) connection for the current transaction.
811         
812         If we are already in a transaction, this returns the connection for
813         that transaction. The "current transaction" is determined by a key
814         (obtained by a call to self.transaction_key); by default, the key
815         is the current thread ID (but subclasses are free to change this).
816         
817         If there is no "current transaction", a new connection object is
818         obtained by calling self.connection (which is usually a connection
819         pool object). If self.implicit_trans is True, new connections will
820         be associated with self.transaction_key(), and repeated calls to
821         get_transaction will then return the same connection object.
822         If self.implicit_trans is False, you'll get a new connection
823         (from the pool) each time.
824         """
825         key = self.transaction_key()
826         if key in self.transactions:
827             conn = self.transactions[key]
828             if isinstance(conn, errors.TransactionLock):
829                 raise conn
830         else:
831             conn = self.connection()
832             if self.implicit_trans or new:
833                 self.transactions[key] = conn
834                 if not new:
835                     self.start(isolation)
836         return conn
837    
838     def is_lock_error(self, exc):
839         """If the given exception instance is a lock timeout, return True.
840         
841         This should return True for errors which arise from transaction
842         locking timeouts; for example, if the database prevents 'dirty
843         reads' by raising an error.
844         """
845         # You should definitely override this for your database.
846         return False
847    
848     def isolate(self, conn, isolation=None):
849         """Set the isolation level of the given connection.
850         
851         If 'isolation' is None, our default_isolation will be used for new
852         connections. Valid values for the 'isolation' argument may be native
853         values for your particular database. However, it is recommended you
854         pass items from the global 'levels' list instead; these will be
855         automatically replaced with native values.
856         
857         For many databases, this must be executed after START TRANSACTION.
858         """
859         if isolation is None:
860             isolation = self.default_isolation
861        
862         if isinstance(isolation, IsolationLevel):
863             # Map the given IsolationLevel object to a native value.
864             isolation = isolation.name
865             if isolation not in self.isolation_levels:
866                 raise ValueError("IsolationLevel %r not allowed by %s. "
867                                  "Try one of %r instead."
868                                  % (isolation, self.__class__.__name__,
869                                     self.isolation_levels))
870        
871         # This is SQL92 syntax, and should work with most DB's.
872         self.execute("SET TRANSACTION ISOLATION LEVEL %s;" % isolation, conn)
873    
874     def start(self, isolation=None):
875         """Start a transaction. Not needed if self.implicit_trans is True."""
876         conn = self.get_transaction(new=True)
877         self.execute("START TRANSACTION;", conn)
878         self.isolate(conn, isolation)
879    
880     def rollback(self):
881         """Roll back the current transaction, if any."""
882         key = self.transaction_key()
883         if key in self.transactions:
884             self.execute("ROLLBACK;", self.transactions[key])
885             del self.transactions[key]
886         else:
887             # This is critical in order to support polygonal SM structures
888             # (same store being called twice by separate proxies).
889             pass
890    
891     def commit(self):
892         """Commit the current transaction, if any."""
893         key = self.transaction_key()
894         try:
895             conn = self.transactions.pop(key)
896         except KeyError:
897             # This is critical in order to support polygonal SM structures
898             # (same store being called twice by separate proxies).
899             pass
900         else:
901             self.execute("COMMIT;", conn)
902    
903     # Change this to 'error' if you don't want autocommit on schema ops.
904     lock_contention = 'commit'
905    
906     def lock(self, msg=None):
907         """Deny transactions during schema operations (DDL statements).
908         
909         Any code which calls this should also call 'unlock' in a try/finally:
910         
911         db.lock('dropping storage')
912         try:
913             drop_storage(cls)
914         finally:
915             db.unlock()
916         """
917         key = self.transaction_key()
918         if key in self.transactions:
919             if isinstance(self.transactions[key], errors.TransactionLock):
920                 return
921             if self.lock_contention == 'error':
922                 raise errors.TransactionLock("Schema operations are not "
923                                              "allowed inside transactions.")
924             self.commit()
925        
926         if msg is None:
927             msg = "Transactions not allowed at the moment."
928         self.transactions[key] = errors.TransactionLock(msg)
929    
930     def unlock(self):
931         """Allow transactions."""
932         key = self.transaction_key()
933         trans = self.transactions.get(key, None)
934         if trans is None:
935             return
936         if not isinstance(trans, errors.TransactionLock):
937             raise errors.TransactionLock("Unlock called inside transaction.")
938         del self.transactions[key]
939    
940     #                              OLTP/CRUD                               #
941    
942     def id_clause(self, tablekey, **inputs):
943         """Return an SQL expression for the identifiers of the given table."""
944         t = self[tablekey]
945         coerce = self.adaptertosql.coerce
946         pairs = []
947         for key, col in t.iteritems():
948             if col.key:
949                 val = coerce(inputs[key], col.dbtype)
950                 pairs.append("%s = %s" % (col.qname, val))
951         return " AND ".join(pairs)
952    
953     def insert(self, tablekey, **inputs):
954         """Insert a row and return {idcolkey: newid}."""
955         t = self[tablekey]
956         coerce_out = self.adaptertosql.coerce
957         coerce_in = self.adapterfromdb.coerce
958        
959         fields = []
960         idkeys = []
961         values = []
962         for key, col in t.iteritems():
963             if col.autoincrement:
964                 # Skip this field, since we're using a sequencer
965                 idkeys.append(key)
966                 continue
967             if key in inputs:
968                 val = coerce_out(inputs[key], col.dbtype)
969                 fields.append(col.qname)
970                 values.append(val)
971        
972         transconn = self.get_transaction()
973        
974         fields = ", ".join(fields)
975         values = ", ".join(values)
976         self.execute('INSERT INTO %s (%s) VALUES (%s);' %
977                      (t.qname, fields, values), transconn)
978        
979         if idkeys:
980             newids = self._grab_new_ids(t, idkeys, transconn)
981             for key in newids.keys():
982                 col = t[key]
983                 newids[key] = coerce_in(newids[key], col.dbtype, col.pytype)
984             return newids
985         else:
986             return {}
987    
988     def _grab_new_ids(self, table, idkeys):
989         # Override this to fetch and return new autoincrement values.
990         raise NotImplementedError
991    
992     def save(self, tablekey, **inputs):
993         """Update a row using the given inputs."""
994         t = self[tablekey]
995        
996         parms = []
997         coerce = self.adaptertosql.coerce
998         for key, val in inputs.iteritems():
999             col = t[key]
1000             if col.autoincrement:
1001                 # Skip this field, since we're using a sequencer
1002                 pass
1003             else:
1004                 val = coerce(val, col.dbtype)
1005                 parms.append('%s = %s' % (col.qname, val))
1006        
1007         if parms:
1008             sql = ('UPDATE %s SET %s WHERE %s;' %
1009                    (t.qname, ", ".join(parms), self.id_clause(tablekey, **inputs)))
1010             self.execute(sql, self.get_transaction())
1011
1012
class ResultSet(object):
    """Iterator over fetched rows, coercing values to Python types.

    data: an indexable sequence of rows, each indexable by column position.
    columns: one (table, col, qname) triple per result column. When both
        table and col are set, values in that position are coerced with
        table.db.adapterfromdb.coerce(value, col.dbtype, col.pytype);
        otherwise they are passed through unchanged.
    imperfect: stored unchanged for consumers; select() passes the select
        writer's 'imperfect' flag here.

    Supports both the Python 2 (next) and Python 3 (__next__) iterator
    protocols. Once exhausted, it stays exhausted.
    """

    def __init__(self, data, columns, imperfect):
        self.data = data
        self.columns = columns
        self.imperfect = imperfect
        # Index of the next row to yield.
        self.cursor = 0

    def __iter__(self):
        return self

    def next(self):
        """Return the next row as a list of (possibly coerced) values."""
        try:
            row = self.data[self.cursor]
        except IndexError:
            raise StopIteration
        self.cursor += 1

        coerced_row = []
        for i, (table, col, qname) in enumerate(self.columns):
            val = row[i]
            if table and col:
                val = table.db.adapterfromdb.coerce(val, col.dbtype, col.pytype)
            coerced_row.append(val)
        return coerced_row

    # Python 3 spells the iterator protocol method __next__.
    __next__ = next
1038
Note: See TracBrowser for help on using the browser.