Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/db.py

Revision 229 (checked in by fumanchu, 7 years ago)

Bah. The DB Introspection test wasn't running. Here are fixes for MySQL and PostgreSQL. More fixes coming.

  • Property svn:eol-style set to native
Line 
1 """Base classes and tools for writing database Storage Managers.
2
3 DATA TYPES
4 ==========
5 Database Storage Manager modules are mostly adapters to support round-trip
6 data coercion:
7
8 Unit type -> [SQL repr ->] DB -> incoming Python value -> Unit type
9
10 Since Dejavu relies on external database servers for its persistence,
11 Python datatypes must be converted to column types in the DB. When writing
12 a StorageManager, you should make sure that your type conversions can handle
13 at least the following limitations: If possible, implement the type with no
14 limits. Also, follow UnitProperty.hints['bytes'] where possible. A value
15 of zero for hints['bytes'] implies no limit. If no value is given, try to
16 assume no limit, although you may choose whatever default size you wish
17 (255 is common for strings).
18
19 ENCODING ISSUES
20 ===============
21 All SQL sent to the database must be strings, not unicode. You can set the
22 encoding of the Adapters (I may add a more centralized encoding context in
23 the future). We must use encoded strings so that we can mix encodings
24 within the same string; for example, we might have a DB which understands
25 utf8, but a pickle value which will be encoded in raw-unicode-escape inline
26 with that. All values, therefore, must be coerced before we try to join
27 them into an SQL statement string.
28
29 """
30
31
32 try:
33     # Builtin in Python 2.5?
34     decimal
35 except NameError:
36     try:
37         # Module in Python 2.3, 2.4
38         import decimal
39     except ImportError:
40         decimal = None
41
42 try:
43     import fixedpoint
44 except ImportError:
45     fixedpoint = None
46
47 import sys
# Number of bytes in a native int: the smallest n in 0..8 for which
# sys.maxint <= 2 ** (8 * n - 1). If no n qualifies, the scan stops at 8.
maxint_bytes = 0
while maxint_bytes < 8 and sys.maxint > 2 ** ((maxint_bytes * 8) - 1):
    maxint_bytes += 1
51
52
53 import threading
54 import warnings
55
56
57 import dejavu
58 from dejavu import logic, storage, LOGSQL, xray
59 from dbmodel import *
60
61
class FieldTypeAdapter(object):
    """For a UnitProperty, return a database type.

    This base class is designed to work out-of-the-box with PostgreSQL 8.

    coerce(cls, key) is the entry point: it looks at the property's Python
    type and dispatches to the matching coerce_<type> method, each of which
    returns the column type as an SQL string. Subclasses for other
    databases can override the class-level limits and the *_type methods.
    """

    # Max binary precision for floating-point columns (= 53 for PostgreSQL 8).
    float_max_precision = 53

    # Max decimal precision for NUMERIC columns (= 1000 for PostgreSQL 8).
    numeric_max_precision = 1000

    # "The actual storage requirement is two bytes for each group of four
    # decimal digits, plus eight bytes overhead." Note we omit the overhead.
    numeric_max_bytes = 500

    def coerce(self, cls, key):
        """Return a column object for the given UnitProperty (cls and key).

        The handler name is "coerce_<typename>" for builtin types and
        "coerce_<module>_<typename>" otherwise, with dots replaced by
        underscores. Raises TypeError if this adapter has no such handler.
        """
        prop = cls.property(key)

        # Obtain the DB type
        valuetype = prop.type
        mod = valuetype.__module__
        if mod == "__builtin__":
            xform = "coerce_%s" % valuetype.__name__
        else:
            xform = "coerce_%s_%s" % (mod, valuetype.__name__)
        xform = xform.replace(".", "_")
        try:
            xform = getattr(self, xform)
        except AttributeError:
            raise TypeError("'%s' is not handled by %s." %
                            (valuetype, self.__class__))
        return xform(cls, key)

    def _limit(self, label, value, maximum, cls, key):
        """Return value constrained to maximum.

        A value of 0 means "no limit requested" and yields maximum. A value
        above maximum also yields maximum, after emitting a StorageWarning
        whose text starts with label. Any other value is returned as is.
        """
        if value == 0:
            return maximum
        if value > maximum:
            # stacklevel=2 attributes the warning to the method that
            # called this helper rather than to the helper itself.
            warnings.warn("%s %s > maximum %s for %s.%s, "
                          "using %s. Values may be stored incorrectly."
                          % (label, value, maximum,
                             cls.__name__, key, self.__class__.__name__),
                          dejavu.StorageWarning, stacklevel=2)
            return maximum
        return value

    def _numeric_type(self, cls, key, label, default_precision):
        """Return "NUMERIC(precision, scale)" from the property's hints.

        The 'precision' hint falls back to default_precision and is limited
        to numeric_max_precision; the 'scale' hint falls back to 2 and is
        never allowed to exceed the precision.
        """
        prop = getattr(cls, key)
        precision = int(prop.hints.get('precision', default_precision))
        precision = self._limit(label, precision,
                                self.numeric_max_precision, cls, key)

        scale = int(prop.hints.get('scale', 2))
        if scale > precision:
            scale = precision
        return "NUMERIC(%s, %s)" % (precision, scale)

    def float_type(self, precision):
        """Return a datatype which can handle floats of the given binary precision."""
        if precision <= 24:
            return "REAL"
        else:
            # Python floats are implemented using C doubles;
            # actual precision depends on platform.
            # PostgreSQL DOUBLE is 53 binary-digit precision.
            return "DOUBLE PRECISION"

    def coerce_float(self, cls, key):
        """Return the float column type for the property's 'precision' hint."""
        prop = getattr(cls, key)
        # Note that 'precision' is binary digits, not decimal
        precision = int(prop.hints.get('precision', 0))
        precision = self._limit("Float precision", precision,
                                self.float_max_precision, cls, key)
        return self.float_type(precision)

    def coerce_str(self, cls, key):
        """Return VARCHAR(n) when a nonzero 'bytes' hint is given, else TEXT."""
        # The bytes hint shall not reflect the usual 4-byte base for varchar.
        prop = getattr(cls, key)
        bytes = int(prop.hints.get('bytes', 0))
        if bytes:
            return "VARCHAR(%s)" % bytes
        else:
            # TEXT is not an SQL standard, but it's common.
            return "TEXT"

    # These types are all given the same column type as str.
    def coerce_dict(self, cls, key):
        return self.coerce_str(cls, key)
    def coerce_list(self, cls, key):
        return self.coerce_str(cls, key)
    def coerce_tuple(self, cls, key):
        return self.coerce_str(cls, key)
    def coerce_unicode(self, cls, key):
        return self.coerce_str(cls, key)

    def coerce_bool(self, cls, key): return "BOOLEAN"

    def coerce_datetime_datetime(self, cls, key): return "TIMESTAMP"
    def coerce_datetime_date(self, cls, key): return "DATE"
    def coerce_datetime_time(self, cls, key): return "TIME"

    # I was seriously disinterested in writing a parser for interval.
    def coerce_datetime_timedelta(self, cls, key):
        return self.coerce_float(cls, key)

    def coerce_decimal_Decimal(self, cls, key):
        """Return a NUMERIC type; precision defaults to the decimal context's."""
        # Assume most people use decimal for money; default scale = 2.
        return self._numeric_type(cls, key, "Decimal precision",
                                  decimal.getcontext().prec)

    def coerce_decimal(self, cls, key):
        # If decimal ever becomes a builtin. Python 2.5?
        return self.coerce_decimal_Decimal(cls, key)

    def coerce_fixedpoint_FixedPoint(self, cls, key):
        """Return a NUMERIC type; with no hint, use the maximum precision."""
        # Note that fixedpoint has no theoretical precision limit.
        # Assume most people use fixedpoint for money; default scale = 2.
        return self._numeric_type(cls, key, "Fixedpoint precision", 0)

    def int_type(self, bytes):
        """Return a datatype which can handle the given number of bytes."""
        if bytes == 1:
            # NOTE(review): a 1-byte hint yields BOOLEAN, which looks too
            # narrow for general integers -- confirm this is intended.
            return "BOOLEAN"
        elif bytes == 2:
            return "SMALLINT"
        elif bytes <= 4:
            return "INTEGER"
        elif bytes <= 8:
            # BIGINT is usually 8 bytes
            return "BIGINT"
        else:
            # Anything larger than 8 bytes, use decimal/numeric.
            # For PostgreSQL, "The actual storage requirement is two bytes
            # for each group of four decimal digits, plus eight bytes
            # overhead." Note we omit the overhead in our calculation.
            return "NUMERIC(%s, 0)" % (bytes * 2)

    def coerce_long(self, cls, key):
        """Return the integer column type for a long property.

        With no 'bytes' hint the largest supported size is used.
        """
        prop = getattr(cls, key)
        # The fallback must be expressed in bytes. Falling back to
        # numeric_max_precision (a digit count) exceeded numeric_max_bytes
        # and so warned on every un-hinted property.
        bytes = int(prop.hints.get('bytes', self.numeric_max_bytes))
        bytes = self._limit("Long bytes", bytes,
                            self.numeric_max_bytes, cls, key)
        return self.int_type(bytes)

    def coerce_int(self, cls, key):
        """Return the integer column type, limited to the native int size."""
        prop = getattr(cls, key)
        bytes = int(prop.hints.get('bytes', maxint_bytes))
        bytes = self._limit("Integer bytes", bytes, maxint_bytes, cls, key)
        return self.int_type(bytes)
243
244
245
246 # --------------------------- Storage Manager --------------------------- #
247
248
class UnitClassWrapper(object):
    """Unit class wrapper, for use in parsing multiselect joins.

    Holds the wrapped Unit class (cls), the database object (db), the
    table name looked up from db by class name, and an alias which is
    empty until a caller assigns one.
    """

    def __init__(self, wclass, db):
        self.cls = wclass
        self.db = db
        self.tablename = db[wclass.__name__].name
        self.alias = ""

    def columns(self):
        """Return [(wclass, UnitProperty.key), ...], ['"tbl"."col"', ...]."""
        unitclass = self.cls
        db = self.db

        # Identifier properties go first in case others depend upon them.
        ordered = list(unitclass.identifiers)
        for key in unitclass.properties:
            if key not in unitclass.identifiers:
                ordered.append(key)

        pairs = []
        qualified = []
        for key in ordered:
            pairs.append((unitclass, key))
            table = db.quote(self.alias or self.tablename)
            column = db.quote(db.column_name(unitclass.__name__, key))
            qualified.append('%s.%s' % (table, column))
        return pairs, qualified

    def _joinname(self):
        quoted = self.db.quote(self.tablename)
        if not self.alias:
            return quoted
        return "%s AS %s" % (quoted, self.db.quote(self.alias))
    joinname = property(_joinname, doc=("Quoted table name for use in "
                                        "JOIN clause (read-only)."))

    def association(self, classes):
        """Return (association, near name, far name) for the first match.

        Each wrapper in classes is checked in both directions: first for an
        association declared on this class toward the other, then for one
        declared on the other toward this class. The "near" name belongs to
        whichever side declared the association. Names are the alias when
        set, otherwise the table name. Returns None if nothing matches.
        """
        ours = self.alias or self.tablename
        for other in classes:
            theirs = other.alias or other.tablename

            forward = self.cls._associations.get(other.cls.__name__, None)
            if forward:
                return forward, ours, theirs

            backward = other.cls._associations.get(self.cls.__name__, None)
            if backward:
                return backward, theirs, ours
        return None
297
298
class StorageManagerDB(storage.StorageManager):
    """StorageManager base class to save and retrieve Units using a DB.

    All SQL is produced and executed through self.db, an instance of
    databaseclass. Column types for new storage come from typeAdapter.
    Both can be replaced per instance through the 'Database Class' and
    'Type Adapter' options.
    """

    use_asterisk_to_get_all = False

    typeAdapter = FieldTypeAdapter()
    databaseclass = Database

    def __init__(self, name, arena, allOptions={}):
        """Set up the manager and create its database object.

        allOptions must contain a 'name' entry, which is passed to
        databaseclass as its first argument; the remaining options are
        passed as keyword arguments. The default dict is never mutated
        here (a copy is made before 'name' is popped).
        """
        storage.StorageManager.__init__(self, name, arena, allOptions)
        self.reserve_lock = threading.Lock()

        # Adapter Overrides
        def get_option(optname):
            # A string option is resolved to an object via xray.classes.
            item = allOptions.get(optname)
            if isinstance(item, basestring):
                item = xray.classes(item)
            return item

        adapter = get_option('Type Adapter')
        if adapter:
            self.typeAdapter = adapter

        adapter = get_option('Database Class')
        if adapter:
            self.databaseclass = adapter

        # Keys are passed through str() since they are used as keyword
        # argument names below.
        allOptions = dict([(str(k), v) for k, v in allOptions.iteritems()])
        name = allOptions.pop('name')
        self.db = self.databaseclass(name, **allOptions)
        self.db.log = self.arena.log

    def version(self):
        """Return whatever self.db.version() reports."""
        return self.db.version()

    def shutdown(self):
        """Disconnect the database object."""
        self.db.disconnect()

    def consume(self, unit, key, value, coltype):
        """Coerce a raw DB value and store it in unit._properties[key].

        On failure the exception is re-raised with key, value and coltype
        added to it, to help locate the offending cell.
        """
        try:
            expectedType = unit.__class__.property(key).type
            value = self.db.adapterfromdb.coerce(value, coltype, expectedType)
            unit._properties[key] = value
        # sys.exc_info() is used instead of binding the exception in the
        # except clause so that this parses under any Python version.
        except UnicodeDecodeError:
            x = sys.exc_info()[1]
            x.reason += "[%s][%s][%s]" % (key, value, coltype)
            raise
        except Exception:
            x = sys.exc_info()[1]
            x.args += (key, value, coltype)
            raise

    def recall(self, cls, expr=None):
        """Yield a sequence of Unit instances which satisfy the expression."""
        clsname = cls.__name__

        if expr is None:
            expr = logic.Expression(lambda x: True)
        sql, imperfect = self.db.select(clsname, expr)
        data, col_defs = self.db.fetch(sql)
        if data:
            t = self.db[clsname]
            # Map column name -> (position in each row, column type).
            columns = dict([(col[0], (index, col[1])) for index, col
                            in enumerate(col_defs)])

            # Get specs on properties. Put the identifier properties
            # first, in case other fields depend upon them.
            props = []
            idnames = list(cls.identifiers)
            for key in idnames + [x for x in cls.properties if x not in idnames]:
                index, ftype = columns[t.columns[key].name]
                props.append((key, index, ftype))

            for row in data:
                unit = cls()
                for key, index, ftype in props:
                    self.consume(unit, key, row[index], ftype)

                # If our SQL is imperfect, don't yield it to the
                # caller unless it passes expr(unit).
                if (not imperfect) or expr(unit):
                    unit.cleanse()
                    yield unit

    def reserve(self, unit):
        """reserve(unit). -> Reserve a persistent slot for unit."""
        self.reserve_lock.acquire()
        try:
            # First, see if our db subclass has a handler that
            # uses the DB to generate the appropriate identifier(s).
            seqclass = unit.sequencer.__class__.__name__
            seq_handler = getattr(self, "_seq_%s" % seqclass, None)
            if seq_handler:
                seq_handler(unit)
            else:
                self._manual_reserve(unit)
            unit.cleanse()
        finally:
            self.reserve_lock.release()

    def _manual_reserve(self, unit):
        """Use when the DB cannot automatically generate an identifier.
        The identifiers will be supplied by UnitSequencer.assign().
        """
        cls = unit.__class__
        t = self.db[cls.__name__]
        if not unit.sequencer.valid_id(unit.identity()):
            # Examine all existing IDs and grant the "next" one.
            id_fields = [t.columns[key].qname for key in cls.identifiers]
            data, cols = self.db.fetch('SELECT %s FROM %s;' %
                                       (', '.join(id_fields), t.qname))
            if data:
                # sqlite 2, for example, has empty cols tuple if no data.
                coerce = self.db.adapterfromdb.coerce
                coltypes = [col[1] for col in cols]
                expectedTypes = [getattr(cls, key).type
                                 for key in cls.identifiers]
                data = [[coerce(cell, coltypes[x], expectedTypes[x])
                         for x, cell in enumerate(row)]
                        for row in data]
            cls.sequencer.assign(unit, data)
            # Release the fetched ID list before building the INSERT.
            del data
            del cols

        fields = []
        values = []
        for key in cls.properties:
            val = self.db.adaptertosql.coerce(getattr(unit, key))
            fields.append(t.columns[key].qname)
            values.append(val)

        fields = ", ".join(fields)
        values = ", ".join(values)
        self.db.execute('INSERT INTO %s (%s) VALUES (%s);' %
                        (t.qname, fields, values))

    def id_clause(self, unit):
        """Return an SQL expression for the identifiers of the given Unit."""
        cols = self.db[unit.__class__.__name__].columns
        c = self.db.adaptertosql.coerce
        pairs = ["%s = %s" % (cols[key].qname, c(getattr(unit, key)))
                 for key in unit.identifiers]
        return " AND ".join(pairs)

    def save(self, unit, forceSave=False):
        """save(unit, forceSave=False) -> Update storage from unit's data."""
        if unit.dirty() or forceSave:
            cls = unit.__class__
            t = self.db[cls.__name__]

            # Identifier columns are only used in the WHERE clause.
            parms = []
            for key in cls.properties:
                if key not in cls.identifiers:
                    val = self.db.adaptertosql.coerce(getattr(unit, key))
                    parms.append('%s = %s' % (t.columns[key].qname, val))

            if parms:
                sql = ('UPDATE %s SET %s WHERE %s;' %
                       (t.qname, ", ".join(parms), self.id_clause(unit)))
                self.db.execute(sql)
            unit.cleanse()

    def destroy(self, unit):
        """destroy(unit). Delete the unit."""
        if self.use_asterisk_to_get_all:
            star = " *"
        else:
            star = ""
        self.db.execute('DELETE%s FROM %s WHERE %s;' %
                        (star, self.db[unit.__class__.__name__].qname,
                         self.id_clause(unit)))

    def view(self, cls, fields, expr=None):
        """view(cls, fields, expr=None) -> All value-tuples for given fields."""
        if expr is None:
            expr = logic.Expression(lambda x: True)

        sql, imperfect = self.db.select(cls.__name__, expr, fields)
        if imperfect:
            # ^%$#@! There's no way to handle imperfect queries without
            # creating all involved Units, which defeats the purpose of
            # view, which was a speed issue more than anything else.
            warnings.warn("The requested view() query for %s Units "
                          "cannot produce perfect SQL with a %s datasource. "
                          "It may take an absurd amount of time to run, "
                          "since each unit must be fully-formed. %s"
                          % (cls.__name__, self.__class__.__name__, expr),
                          dejavu.StorageWarning)
            for unit in self.recall(cls, expr):
                # Use tuples for hashability
                yield tuple([getattr(unit, f) for f in fields])
        else:
            data, columns = self.db.fetch(sql)
            actualTypes = [x[1] for x in columns]
            expectedTypes = [cls.property(x).type for x in fields]

            coerce = self.db.adapterfromdb.coerce
            # Use tuples for hashability
            for row in data:
                yield tuple([coerce(val, actualTypes[i], expectedTypes[i])
                             for i, val in enumerate(row)])

    def distinct(self, cls, fields, expr=None):
        """distinct(cls, fields, expr=None) -> Distinct values for given fields."""
        if expr is None:
            expr = logic.Expression(lambda x: True)

        sql, imperfect = self.db.select(cls.__name__, expr,
                                        fields, distinct=True)
        if imperfect:
            # ^%$#@! There's no way to handle imperfect queries without
            # creating all involved Units, which defeats the purpose of
            # distinct, which was a speed issue more than anything.
            warnings.warn("The requested distinct() query for %s Units "
                          "cannot produce perfect SQL with a %s datasource. "
                          "It may take an absurd amount of time to run, "
                          "since each unit must be fully-formed. %s"
                          % (cls.__name__, self.__class__.__name__, expr),
                          dejavu.StorageWarning)
            vals = {}
            for unit in self.recall(cls, expr):
                # Must use tuples for hashability
                val = tuple([getattr(unit, f) for f in fields])
                vals[val] = None
            return vals.keys()
        else:
            data, columns = self.db.fetch(sql)
            actualTypes = [x[1] for x in columns]
            expectedTypes = [cls.property(x).type for x in fields]

            coerce = self.db.adapterfromdb.coerce
            # Must use inner tuples for hashability in Sandbox.distinct()
            return [tuple([coerce(val, actualTypes[i], expectedTypes[i])
                           for i, val in enumerate(row)])
                    for row in data]

    def join(self, unitjoin):
        """Return an SQL FROM clause for the given unitjoin.

        Both halves of unitjoin are either nested UnitJoins or Unit class
        wrappers. Raises dejavu.AssociationError if no class in the first
        half is associated with any class in the second half.
        """
        cls1, cls2 = unitjoin.class1, unitjoin.class2
        # The class lists are materialised with list(): the right-hand one
        # is scanned once per left-hand class, and a bare iterator would be
        # exhausted after the first scan, hiding later associations.
        if isinstance(cls1, dejavu.UnitJoin):
            name1 = self.join(cls1)
            classlist1 = list(cls1)
        else:
            # cls1 is a Unit class wrapper.
            name1 = cls1.joinname
            classlist1 = [cls1]

        if isinstance(cls2, dejavu.UnitJoin):
            name2 = self.join(cls2)
            classlist2 = list(cls2)
        else:
            # cls2 is a Unit class wrapper.
            name2 = cls2.joinname
            classlist2 = [cls2]

        j = {None: "INNER", True: "LEFT", False: "RIGHT"}[unitjoin.leftbiased]

        # Find an association between the two halves.
        ua = None
        for clsA in classlist1:
            ua = clsA.association(classlist2)
            if ua:
                ua, nearClass, farClass = ua
                break
        if ua is None:
            msg = ("No association found between %s and %s." % (name1, name2))
            raise dejavu.AssociationError(msg)

        # NOTE(review): nearClass/farClass are aliases or table names as
        # returned by association(), yet they are passed to table_name()
        # and column_name() here -- confirm those accept such names.
        t = self.db
        near = '%s.%s' % (t.quote(t.table_name(nearClass)),
                          t.quote(t.column_name(nearClass, ua.nearKey)))
        far = '%s.%s' % (t.quote(t.table_name(farClass)),
                         t.quote(t.column_name(farClass, ua.farKey)))

        return "(%s %s JOIN %s ON %s = %s)" % (name1, j, name2, near, far)

    def multiselect(self, classes, expr):
        """Return an SQL SELECT statement, an imperfect flag, and column names."""

        # Create a new unitjoin tree where each class is wrapped.
        # Then we can tag the wrappers with metadata with impunity.
        seen = {}
        aliascount = [0]

        def wrap_class(cls):
            # The second and later occurrences of a class get an alias
            # ("t1", "t2", ...); the first occurrence keeps the table name.
            wrapper = UnitClassWrapper(cls, self.db)
            if cls in seen:
                aliascount[0] += 1
                wrapper.alias = "t%d" % aliascount[0]
            else:
                seen[cls] = None
            return wrapper

        def wrap(unitjoin):
            halves = []
            for half in (unitjoin.class1, unitjoin.class2):
                if isinstance(half, dejavu.UnitJoin):
                    halves.append(wrap(half))
                else:
                    halves.append(wrap_class(half))
            return dejavu.UnitJoin(halves[0], halves[1], unitjoin.leftbiased)
        classes = wrap(classes)

        joins = self.join(classes)

        if expr is None:
            expr = logic.Expression(lambda *args: True)
        w, imp = self.db.where(list(classes), expr)

        cols = []
        colnames = []
        for wrapper in classes:
            c, names = wrapper.columns()
            cols.extend(c)
            colnames.extend(names)

        statement = ("SELECT %s FROM %s WHERE %s" %
                     (', '.join(colnames), joins, w))
        return statement, imp, cols

    def multirecall(self, classes, expr):
        """Yield Unit instance sets which satisfy the expression."""
        sql, imp, supplied_cols = self.multiselect(classes, expr)
        data, recvd_cols = self.db.fetch(sql)
        if data:
            # Get specs on properties: (Unit class, key, column type),
            # in the same order as the cells of each row.
            props = []
            for sup, rec in zip(supplied_cols, recvd_cols):
                c, key = sup
                props.append((c, key, rec[1]))

            for row in data:
                # NOTE(review): units are keyed by class, so a class that
                # occurs twice in the join appears to share a single unit
                # -- confirm whether self-joins are expected to work here.
                units = {}
                for index, (c, key, ftype) in enumerate(props):
                    if c in units:
                        unit = units[c]
                    else:
                        units[c] = unit = c()
                    self.consume(unit, key, row[index], ftype)

                unitset = []
                for cls in classes:
                    unit = units[cls]
                    unit.cleanse()
                    unitset.append(unit)

                # If our SQL is imperfect, don't yield units to the
                # caller unless they pass expr(unit).
                acceptable = True
                if imp:
                    acceptable = expr(*unitset)
                if acceptable:
                    yield unitset

    #                               Schemas                               #

    def create_database(self):
        self.db.create_database()

    def drop_database(self):
        self.db.drop_database()

    def create_storage(self, cls):
        """Create storage for the given class."""
        # Make a Table object.
        t = self.db.make_table(cls.__name__)

        indices = cls.indices()
        for key in cls.properties:
            dbtype = self.typeAdapter.coerce(cls, key)
            col = self.db.make_column(cls.__name__, key, dbtype)
            prop = cls.property(key)
            col.default = prop.default
            col.hints = prop.hints.copy()
            # Use the superclass call to avoid ALTER TABLE.
            dict.__setitem__(t.columns, key, col)

            if key in indices:
                i = self.db.make_index(cls.__name__, key)
                # Use the superclass call to avoid CREATE INDEX.
                dict.__setitem__(t.columns.indices, key, i)

        # Attach to self.db, which should call CREATE TABLE.
        self.db[cls.__name__] = t

    def has_storage(self, cls):
        return cls.__name__ in self.db

    def drop_storage(self, cls):
        del self.db[cls.__name__]

    def rename_storage(self, oldname, newname):
        self.arena.log("rename table %s to %s" % (oldname, newname), LOGSQL)
        self.db.rename(oldname, newname)

    def add_property(self, cls, name):
        """Add a column for the named property unless one already exists."""
        if not self.has_property(cls, name):
            dbtype = self.typeAdapter.coerce(cls, name)
            c = self.db.make_column(cls.__name__, name, dbtype)
            prop = getattr(cls, name)
            c.default = prop.default
            c.hints = prop.hints.copy()
            self.db[cls.__name__].columns[name] = c

    def has_property(self, cls, name):
        return name in self.db[cls.__name__].columns

    def drop_property(self, cls, name):
        """Remove the column for the named property if it exists."""
        if self.has_property(cls, name):
            del self.db[cls.__name__].columns[name]

    def rename_property(self, cls, oldname, newname):
        self.db[cls.__name__].columns.rename(oldname, newname)

    def has_index(self, cls, name):
        return name in self.db[cls.__name__].columns.indices

    def drop_index(self, cls, name):
        del self.db[cls.__name__].columns.indices[name]

    def sync(self):
        """Map new Table objects to all registered classes."""
        # Use the superclass call to avoid DROP TABLE.
        dict.clear(self.db)
        dbtables = self.db._get_tables()
        for cls in self.arena._registered_classes:
            # Try to find a matching Table object from _get_tables.
            tablename = self.db.table_name(cls.__name__)
            t = [x for x in dbtables if x.name == tablename]
            if t:
                t = t[0]
                dbcols = self.db._get_columns(t.name)
                for ckey in cls.properties:
                    # Try to find a matching Column object from _get_columns.
                    colname = self.db.column_name(cls.__name__, ckey)
                    c = [x for x in dbcols if x.name == colname]
                    if c:
                        c = c[0]
                        # Use the superclass call to avoid ALTER TABLE
                        dict.__setitem__(t.columns, ckey, c)

                dbindices = self.db._get_indices(t.name)
                for ikey in cls.indices():
                    iname = self.db.table_name("i" + cls.__name__ + ikey)
                    # Try to find a matching index object from _get_indices.
                    i = [x for x in dbindices if x.name == iname]
                    if i:
                        i = i[0]
                        # Use the superclass call to avoid CREATE INDEX
                        dict.__setitem__(t.columns.indices, ikey, i)

                # Use the superclass call to avoid CREATE TABLE
                dict.__setitem__(self.db, cls.__name__, t)

    def autoclass(self, table, newclassname=None):
        """Create a Unit class automatically from this table and its columns.

        The class is named newclassname, or table.name if that is None.
        """
        class AutoUnitClass(dejavu.Unit):
            pass
        for cname, c in table.columns.iteritems():
            ptype = self.db.python_type(c.dbtype)
            p = AutoUnitClass.set_property(cname, ptype)
            p.default = c.default
            p.hints = c.hints.copy()

        if newclassname is None:
            newclassname = table.name
        AutoUnitClass.__name__ = newclassname

        return AutoUnitClass
780
Note: See TracBrowser for help on using the browser.