Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/dejavu/storage/db.py

Revision 481 (checked in by fumanchu, 6 years ago)

New xmultirecall method on SMs (storage managers), which now also take order, limit, and offset.

  • Property svn:eol-style set to native
Line 
1 """Base classes and tools for writing database Storage Managers.
2
3 DATA TYPES
4 ==========
5 Database Storage Manager modules are mostly adapters to support round-trip
6 data coercion:
7
8 Unit type -> [SQL repr ->] DB -> incoming Python value -> Unit type
9
10 Since Dejavu relies on external database servers for its persistence,
11 Python datatypes must be converted to column types in the DB. When writing
12 a StorageManager, you should make sure that your type conversions can handle
13 at least the following limitations: If possible, implement the type with no
14 limits. Also, follow UnitProperty.hints['bytes'] where possible. A value
15 of zero for hints['bytes'] implies no limit. If no value is given, try to
16 assume no limit, although you may choose whatever default size you wish
17 (255 is common for strings).
18
19 ENCODING ISSUES
20 ===============
21 All SQL sent to the database must be strings, not unicode. You can set the
22 encoding of the Adapters (I may add a more centralized encoding context in
23 the future). We must use encoded strings so that we can mix encodings
24 within the same string; for example, we might have a DB which understands
25 utf8, but a pickle value which will be encoded in raw-unicode-escape inline
26 with that. All values, therefore, must be coerced before we try to join
27 them into an SQL statement string.
28
29 """
30
31 import threading
32 import warnings
33
34
35 import geniusql
36 from geniusql import logic, logicfuncs
37
38 import dejavu
39 from dejavu import storage, logflags, xray
40 from dejavu.errors import StorageWarning, MappingError
41
42
43 # --------------------------- Storage Manager --------------------------- #
44
45
46 class StorageManagerDB(storage.StorageManager):
47     """StoreManager base class to save and retrieve Units using a DB."""
48    
49     databaseclass = geniusql.Database
50    
51     def __init__(self, allOptions={}):
52         storage.StorageManager.__init__(self, allOptions)
53         self.reserve_lock = threading.Lock()
54        
55         # Config Overrides
56         def get_option(name):
57             item = allOptions.get(name)
58             if isinstance(item, basestring):
59                 item = xray.classes(item)
60             return item
61        
62         dbclass = get_option('Database Class')
63         if dbclass:
64             self.databaseclass = dbclass
65        
66         allOptions = dict([(str(k), v) for k, v in allOptions.iteritems()])
67        
68         self.db = self.databaseclass(**allOptions)
69         self.schema = self.db.schema()
70        
71         if 'Prefix' in allOptions:
72             self.schema.prefix = allOptions['Prefix']
73        
74         def logger(msg):
75             if self.logflags & logflags.SQL:
76                 self.log(logflags.SQL.message(msg))
77         self.db.log = logger
78    
79     def version(self):
80         return self.db.version()
81    
82     def shutdown(self):
83         self.db.connections.shutdown()
84    
85     def xrecall(self, cls, expr=None, order=None, limit=None, offset=None):
86         """Yield a sequence of Unit instances which satisfy the expression."""
87         if self.logflags & logflags.RECALL:
88             self.log(logflags.RECALL.message(cls, expr))
89        
90         clsname = cls.__name__
91        
92         # Put the identifier properties first, in case other fields
93         # depend upon them.
94         idnames = list(cls.identifiers)
95         attrs = idnames + [x for x in cls.properties if x not in idnames]
96         coercers = [getattr(cls, key).coerce for key in attrs]
97        
98         data = self.select((cls, attrs, expr), order=order,
99                            limit=limit, offset=offset)
100         for row in data:
101             unit = cls.__new__(cls)
102             unit._zombie = True
103             unit.__init__()
104            
105             for key, value, propcoerce in zip(attrs, row, coercers):
106                 try:
107                     if propcoerce:
108                         value = propcoerce(unit, value)
109                     unit._properties[key] = value
110                 except UnicodeDecodeError, x:
111                     x.reason += " [%r: %r]" % (key, value)
112                     raise
113                 except Exception, x:
114                     x.args += (key, value)
115                     raise
116            
117             # If our SQL is imperfect, don't yield it to the
118             # caller unless it passes expr(unit).
119             if expr and data.statement.imperfect:
120                 if not expr(unit):
121                     continue
122            
123             unit.cleanse()
124             yield unit
125    
126     def reserve(self, unit):
127         """Reserve a persistent slot for unit."""
128         self.reserve_lock.acquire()
129         try:
130             # First, see if our db subclass has a handler that
131             # uses the DB to generate the appropriate identifier(s).
132             seqclass = unit.sequencer.__class__.__name__
133             seq_handler = getattr(self, "_seq_%s" % seqclass, None)
134             if seq_handler:
135                 seq_handler(unit)
136             else:
137                 self._manual_reserve(unit)
138             unit.cleanse()
139         finally:
140             self.reserve_lock.release()
141        
142         # Usually we log ASAP, but here we log after
143         # the unit has had a chance to get an auto ID.
144         if self.logflags & logflags.RESERVE:
145             self.log(logflags.RESERVE.message(unit))
146    
147     def _seq_UnitSequencerInteger(self, unit):
148         """Reserve a unit (using the table's autoincrement fields)."""
149         cls = unit.__class__
150        
151         # Grab the new ID. This is threadsafe because reserve has a mutex.
152         newids = self.schema[cls.__name__].insert(**unit._properties)
153         for k, v in newids.iteritems():
154             setattr(unit, k, v)
155    
156     def _manual_reserve(self, unit):
157         """Use when the DB cannot automatically generate an identifier.
158         The identifiers will be supplied by UnitSequencer.assign().
159         """
160         cls = unit.__class__
161         t = self.schema[cls.__name__]
162         if not unit.sequencer.valid_id(unit.identity()):
163             # Examine all existing IDs and grant the "next" one.
164             data = list(self.db.select((t, cls.identifiers)))
165             cls.sequencer.assign(unit, data)
166         t.insert(**unit._properties)
167    
168     def save(self, unit, forceSave=False):
169         """Update storage from unit's data (if unit.dirty())."""
170         if self.logflags & logflags.SAVE:
171             self.log(logflags.SAVE.message(unit, forceSave))
172        
173         if forceSave or unit.dirty():
174             self.schema[unit.__class__.__name__].save(**unit._properties)
175             unit.cleanse()
176    
177     def destroy(self, unit):
178         """Delete the unit."""
179         if self.logflags & logflags.DESTROY:
180             self.log(logflags.DESTROY.message(unit))
181        
182         table = self.schema[unit.__class__.__name__]
183         table.delete(**unit._properties)
184    
185    
186     #                                Views                                #
187    
188     def tablejoin(self, join):
189         """Return a geniusql Join tree for the given UnitJoin."""
190         t1, t2 = join.class1, join.class2
191        
192         if isinstance(t1, dejavu.UnitJoin):
193             wt1 = self.tablejoin(t1)
194         else:
195             wt1 = self.schema[t1.__name__]
196        
197         if isinstance(t2, dejavu.UnitJoin):
198             wt2 = self.tablejoin(t2)
199         else:
200             wt2 = self.schema[t2.__name__]
201        
202         uj = geniusql.sqlwriters.Join(wt1, wt2, join.leftbiased)
203         # if the original UnitJoin had a custom association path,
204         # copy it to the new Join instance
205         uj.path = join.path
206         return uj
207    
208     def _geniusql_query(self, query):
209         """Return a Geniusql Query object for the given Dejavu Query."""
210         rel = query.relation
211         if isinstance(rel, dejavu.UnitJoin):
212             rel = self.tablejoin(rel)
213         elif rel is None:
214             # This is a Geniusql-ism: send the schema when we have no FROM.
215             rel = self.schema
216         else:
217             rel = self.schema[rel.__name__]
218         return geniusql.sqlwriters.Query(rel, query.attributes, query.restriction)
219    
220     def select(self, query, order=None, distinct=False, limit=None, offset=None):
221         """Return a geniusql Dataset for the given Query object."""
222         if not isinstance(query, dejavu.Query):
223             query = dejavu.Query(*query)
224        
225         return self.db.select(self._geniusql_query(query),
226                               order=order, distinct=distinct,
227                               limit=limit, offset=offset, strict=False)
228    
229     def insert_into(self, name, query, distinct=False):
230         """INSERT matching data INTO a new class and return the class."""
231         if not isinstance(query, dejavu.Query):
232             query = dejavu.Query(*query)
233        
234         self.db.insert_into(name, self._geniusql_query(query),
235                             distinct=distinct)
236         return Modeler(self.schema).make_class(name)
237    
238     def make_class(self, name):
239         """Return a (new) Unit class for the given storage name."""
240         return Modeler(self.schema).make_class(name)
241    
242     def xview(self, query, distinct=False):
243         """Yield value tuples for the given query."""
244         if not isinstance(query, dejavu.Query):
245             query = dejavu.Query(*query)
246        
247         if self.logflags & logflags.VIEW:
248             self.log(logflags.VIEW.message(query, distinct))
249        
250         data = self.select(query, distinct=distinct)
251         if data.statement.imperfect:
252             # ^%$#@! There's no way to handle imperfect queries without
253             # creating all involved Units, which defeats the performance
254             # benefits of view.
255             clsname = self.__class__.__name__
256             warnings.warn("The requested query cannot produce perfect SQL "
257                           "with a %s datasource. It may take an absurdly "
258                           "long time to run, since each unit must be fully-"
259                           "formed. %s" % (clsname, query), StorageWarning)
260             for row in storage.StorageManager.xview(self, query, distinct):
261                 yield row
262         else:
263             # Use tuples for hashability
264             for row in data:
265                 yield tuple(row)
266    
267     def count(self, cls, expr=None):
268         """Number of Units of the given cls which match the given expr."""
269         if cls.identifiers:
270             uniq = cls.identifiers
271         else:
272             uniq = cls._properties.keys()
273         # TODO: handle multiple args to count()
274         counter = lambda x: [logicfuncs.count(getattr(x, uniq[0]))]
275        
276         query = dejavu.Query(cls, counter, expr)
277        
278         if self.logflags & logflags.VIEW:
279             self.log(logflags.VIEW.message(query, False))
280        
281         data = self.select(query)
282         if data.statement.imperfect:
283             # ^%$#@! There's no way to handle imperfect queries without
284             # creating all involved Units, which defeats the performance
285             # benefits of view.
286             clsname = self.__class__.__name__
287             warnings.warn("The requested query cannot produce perfect SQL "
288                           "with a %s datasource. It may take an absurdly "
289                           "long time to run, since each unit must be fully-"
290                           "formed. %s" % (clsname, query), StorageWarning)
291             return storage.StorageManager.count(self, cls, expr)
292         else:
293             return data.scalar()
294    
295     def xmultirecall(self, classes, expr=None, order=None, limit=None, offset=None):
296         """Yield Unit instance sets which satisfy the expression."""
297         if self.logflags & logflags.RECALL:
298             self.log(logflags.RECALL.message(classes, expr))
299        
300         # Gather attribute list.
301         allattrs = []
302         props = []
303         for cls in classes:
304             t = self.schema[cls.__name__]
305             attrs = []
306             for key in cls.properties:
307                 attrs.append(key)
308                 props.append((cls, key, getattr(cls, key).coerce))
309             allattrs.append(attrs)
310        
311         data = self.select((classes, allattrs, expr), order=order,
312                            limit=limit, offset=offset)
313         for row in data:
314             # TODO: This is broken; won't work if same cls appears twice.
315             units = {}
316             for i, (cls, key, propcoerce) in enumerate(props):
317                 if cls in units:
318                     unit = units[cls]
319                 else:
320                     unit = cls.__new__(cls)
321                     unit._zombie = True
322                     unit.__init__()
323                     units[cls] = unit
324                
325                 value = row[i]
326                 try:
327                     if propcoerce:
328                         value = propcoerce(unit, value)
329                     unit._properties[key] = value
330                 except Exception, x:
331                     x.args += (cls, key)
332                     raise
333            
334             unitset = []
335             for cls in classes:
336                 unit = units[cls]
337                 unit.cleanse()
338                 unitset.append(unit)
339            
340             # If our SQL is imperfect, don't yield units to the
341             # caller unless they pass expr(unit).
342             acceptable = True
343             if expr and data.statement.imperfect:
344                 acceptable = expr(*unitset)
345             if acceptable:
346                 yield unitset
347    
348     #                               Schemas                               #
349    
350     def create_database(self):
351         if self.logflags & logflags.DDL:
352             self.log(logflags.DDL.message("create database"))
353         self.db.create()
354         self.schema.create()
355    
356     def drop_database(self):
357         if self.logflags & logflags.DDL:
358             self.log(logflags.DDL.message("drop database"))
359         self.schema.drop()
360         self.db.drop()
361    
362     def _make_table(self, cls):
363         """Create and return a Table object for the given class."""
364         t = self.schema.table(cls.__name__)
365        
366         indices = cls.indices()
367         fields = []
368         for key in cls.properties:
369             t[key] = self._make_column(cls, key)
370             if key in indices:
371                 t.add_index(key)
372        
373         # Copy associations to table.references.
374         for k, v in cls._associations.iteritems():
375             t.references[k] = (v.nearKey, v.farClass.__name__, v.farKey)
376        
377         return t
378    
379     def create_storage(self, cls):
380         """Create storage for the given class."""
381         if self.logflags & logflags.DDL:
382             self.log(logflags.DDL.message("create storage %s" % cls))
383         # Attach to self.schema, which should call CREATE TABLE.
384         self.schema[cls.__name__] = self._make_table(cls)
385    
386     def _make_column(self, cls, key):
387         prop = getattr(cls, key)
388         col = self.schema.column(prop.type, default=prop.default, hints=prop.hints)
389         if key in cls.identifiers:
390             col.key = True
391             if isinstance(cls.sequencer, dejavu.UnitSequencerInteger):
392                 col.autoincrement = True
393                 col.initial = cls.sequencer.initial
394         return col
395    
396     def has_storage(self, cls):
397         return cls.__name__ in self.schema
398    
399     def drop_storage(self, cls):
400         if self.logflags & logflags.DDL:
401             self.log(logflags.DDL.message("drop storage %s" % cls))
402         del self.schema[cls.__name__]
403    
404     def rename_storage(self, oldname, newname):
405         if self.logflags & logflags.DDL:
406             self.log(logflags.DDL.message("rename storage from %s to %s"
407                                           % (oldname, newname)))
408         self.schema.rename(oldname, newname)
409    
410     def add_property(self, cls, name):
411         if self.logflags & logflags.DDL:
412             self.log(logflags.DDL.message("add property %s %s" %
413                                           (cls, name)))
414         if not self.has_property(cls, name):
415             table = self.schema[cls.__name__]
416             table[name] = self._make_column(cls, name)
417    
418     def has_property(self, cls, name):
419         return name in self.schema[cls.__name__]
420    
421     def drop_property(self, cls, name):
422         if self.logflags & logflags.DDL:
423             self.log(logflags.DDL.message("drop property %s %s" %
424                                           (cls, name)))
425         if self.has_property(cls, name):
426             del self.schema[cls.__name__][name]
427    
428     def rename_property(self, cls, oldname, newname):
429         if self.logflags & logflags.DDL:
430             self.log(logflags.DDL.message(
431                 "rename property %s from %s to %s" %
432                 (cls, oldname, newname)))
433         t = self.schema[cls.__name__]
434        
435         # Sometimes, a Dejavu Schema will change a code model first, and
436         # then change the database afterward. So it's possible that the
437         # column we're trying to rename hasn't been loaded, because the
438         # model layer no longer references it. So if table[oldname]
439         # raises a KeyError, try to find a column that matches oldkey.
440         tempcol = None
441         try:
442             t[oldname]
443         except KeyError:
444             c = [x for x in self.schema._get_columns(t.name)
445                  if x.name == self.schema._column_name(t.name, oldname)]
446             if not c:
447                 raise KeyError("Rename failed. Old column %r not found in %r."
448                                % (oldname, t.name))
449             oldcol = c[0]
450             # Use the superclass call to avoid DROP COLUMN/ADD COLUMN.
451             dict.__setitem__(t, oldname, oldcol)
452        
453         t.rename(oldname, newname)
454    
455     def add_index(self, cls, name):
456         self.schema[cls.__name__].add_index(name)
457    
458     def has_index(self, cls, name):
459         return name in self.schema[cls.__name__].indices
460    
461     def drop_index(self, cls, name):
462         del self.schema[cls.__name__].indices[name]
463    
464     auto_discover = True
465    
466     def map(self, classes, conflict_mode='error'):
467         """Map classes to internal storage.
468         
469         If self.auto_discover is True (the default), then Table/Column/Index
470         objects will be formed by inspecting the underlying database using
471         self.sync().
472         
473         If auto_discover is False, then mock Table/Column/Index objects
474         will be used instead; this provides a performance improvement
475         in scenarios where the model maps perfectly to the database
476         and changes to the database are not expected outside the model.
477         
478         conflict_mode: This argument determines what happens when there are
479         discrepancies between the Dejavu model and the actual database.
480             
481             If 'error' (the default), MappingError is raised for the
482             first issue and the sync process is aborted.
483             
484             If 'warn', then a warning is raised (instead of an error)
485             for each issue, and the sync process is not aborted. This
486             allows you to see all errors at once, without having to stop
487             and fix each one and then execute the process again.
488             
489             If 'repair', then each issue will be resolved by changing
490             the database to match the model.
491         """
492         if self.auto_discover:
493             self.sync(classes, conflict_mode)
494         else:
495             for cls in classes:
496                 if self.has_storage(cls):
497                     # If our consumer-side key is already present, skip this cls.
498                     # This allows callers to auto-sync class by class
499                     # without making a new Table object each time.
500                     continue
501                
502                 t = self._make_table(cls)
503                
504                 # Use the superclass call to avoid DROP/CREATE TABLE
505                 dict.__setitem__(self.schema, cls.__name__, t)
506    
507     def sync(self, classes, conflict_mode='error'):
508         """Map classes to existing Table objects (found via discovery).
509         
510         conflict_mode: This argument determines what happens when there are
511         discrepancies between the Dejavu model and the actual database.
512             
513             If 'error' (the default), MappingError is raised for the
514             first issue and the sync process is aborted.
515             
516             If 'warn', then a warning is raised (instead of an error)
517             for each issue, and the sync process is not aborted. This
518             allows you to see all errors at once, without having to stop
519             and fix each one and then execute the process again.
520             
521             If 'repair', then each issue will be resolved by changing
522             the database to match the model.
523             
524             If 'ignore', then each issue will be silently ignored.
525         """
526         for cls in classes:
527             if cls.__name__ in self.schema:
528                 # If our consumer-side key is already present, skip this cls.
529                 # This allows callers to auto-sync class by class
530                 # without calling the expensive discover() func each time.
531                 continue
532             self._find_table(self.schema, cls, conflict_mode)
533    
534     def _find_table(self, schema, cls, conflict_mode='error'):
535         # This is broken out to make multi-schema subclasses easier to write.
536        
537         def notify(msg):
538             if conflict_mode == 'warn':
539                 warnings.warn(msg)
540             elif conflict_mode == 'ignore':
541                 pass
542             else:
543                 raise MappingError(msg)
544        
545         # Try to find a matching Table object using the DB-side key.
546         clsname = cls.__name__
547         tablename = schema.table_name(clsname)
548         try:
549             # Do we already have a map using the DB name?
550             table = schema[tablename]
551             schema.alias(table.name, clsname)
552         except KeyError:
553             # Can we create a map? Discover the DB table and try again.
554             try:
555                 table = schema.discover(tablename)
556                 schema.alias(table.name, clsname)
557             except geniusql.errors.MappingError:
558                 msg = "%s: no such table %r." % (clsname, tablename)
559                 if conflict_mode == 'repair':
560                     self.create_storage(cls)
561                     table = schema[clsname]
562                 else:
563                     notify(msg)
564                     return
565        
566         # Match Column objects with class properties.
567         dbcols = dict([(c.name, c) for c in table.itervalues()])
568         indices = cls.indices()
569         for pkey in cls.properties:
570             colname = schema._column_name(table.name, pkey)
571             try:
572                 col = dbcols[colname]
573                 table.alias(colname, pkey)
574             except KeyError, x:
575                 msg = "%s: no column found for %r." % (clsname, pkey)
576                 if conflict_mode == 'repair':
577                     self.add_property(cls, pkey)
578                     if pkey in cls.indices() and pkey not in table.indices:
579                         self.add_index(cls, pkey)
580                     col = table[pkey]
581                 else:
582                     notify(msg)
583                     continue
584            
585             # Check that the column.key matches our identifiers list;
586             # this is crucial for the proper operation of OLTP methods
587             # in geniusql.Table, which uses column.key to decide
588             # the unique identifiers for a given row of data.
589             if pkey in cls.identifiers and not col.key:
590                 msg = ("%s: %r is an identifier, but the "
591                        "column is not marked as a primary key."
592                        % (clsname, pkey))
593                 if conflict_mode == 'repair':
594                     col.key = True
595                     table.set_primary()
596                 else:
597                     notify(msg)
598                     continue
599             elif col.key and not pkey in cls.identifiers:
600                 msg = ("%s: %r is not an identifier, but the "
601                        "column is marked as a primary key."
602                        % (clsname, pkey))
603                 if conflict_mode == 'repair':
604                     col.key = False
605                     # Just because the current pkey is not an identifier
606                     # doesn't mean we have *no* identifiers.
607                     table.set_primary()
608                 else:
609                     notify(msg)
610                     continue
611            
612             col.pytype = getattr(cls, pkey).type
613            
614             # Override the default adapter (since it guessed an adapter
615             # using the default pytype, and we know better).
616             try:
617                 col.adapter = col.dbtype.default_adapter(col.pytype)
618             except TypeError, x:
619                 x.args += ("%s.%s" % (table.name, col.name),)
620                 raise
621            
622             # Try to find matching Index objects. Because index names are
623             # so platform-specific, we match attributes rather than names.
624             if pkey in indices:
625                 for ikey, idx in table.indices.items():
626                     if idx.colname == colname:
627                         a = schema.table_name("i" + clsname + pkey)
628                         table.indices.alias(ikey, a)
629                         break
630                 else:
631                     msg = "%s: no index found for %r." % (clsname, pkey)
632                     if conflict_mode == 'repair':
633                         self.add_index(cls, pkey)
634                     else:
635                         notify(msg)
636                         continue
637             else:
638                 if pkey in cls.identifiers and self.db.pks_must_be_indexed:
639                     pass
640                 else:
641                     for ikey, idx in table.indices.items():
642                         if idx.colname == colname:
643                             msg = ("%s: index found for non-indexed %r."
644                                    % (clsname, pkey))
645                             if conflict_mode == 'repair':
646                                 self.drop_index(cls, ikey)
647                             else:
648                                 notify(msg)
649                                 continue
650        
651         # Set Table.references
652         for k, v in cls._associations.iteritems():
653             table.references[k] = (v.nearKey, v.farClass.__name__, v.farKey)
654        
655         return table
656    
657     #                            Transactions                             #
658    
659     def start(self, isolation=None):
660         "Start a transaction (not needed if db.connections.implicit_trans)."
661         self.db.connections.start(isolation)
662    
663     def rollback(self):
664         """Roll back the current transaction."""
665         self.db.connections.rollback()
666    
667     def commit(self):
668         """Commit the current transaction."""
669         self.db.connections.commit()
670
671
672 class Modeler(object):
673     """Tool to automatically form Unit classes or source from existing DB's."""
674    
675     ignore = ['Unit', 'DeployedVersion',
676               'UnitEngine', 'UnitEngineRule', 'UnitCollection',
677               ]
678    
679     def __init__(self, schema):
680         self.schema = schema
681         self.ignore = self.ignore[:]
682    
683     def all_classes(self):
684         """Return a list of new classes for all tables in the Database."""
685         ignore = dict.fromkeys([self.schema.table_name(x) for x in self.ignore]
686                                + self.ignore).keys()
687        
688         self.schema.discover_all(ignore=ignore)
689        
690         classes = []
691         seen = {}
692         for key, table in self.schema.items():
693             if key not in ignore and table.name not in seen:
694                 cls = self.make_class(key)
695                 classes.append(cls)
696                 seen[table.name] = None
697         return classes
698    
699     def make_class(self, tablename, newclassname=None):
700         """Create a Unit class automatically from the named table."""
701         if tablename not in self.schema:
702             self.schema.discover(tablename)
703         table = self.schema[tablename]
704        
705         class AutoUnitClass(dejavu.Unit):
706             sequencer = dejavu.UnitSequencer()
707             identifiers = tuple([k for k in table if table[k].key])
708        
709         if newclassname is None:
710             newclassname = table.name
711             # The key is probably better than the table.name. Try it.
712             for key, t in self.schema.iteritems():
713                 if t.name == newclassname:
714                     newclassname = key
715                     break
716         AutoUnitClass.__name__ = newclassname
717        
718         indices = [idx.colname for idx in table.indices.itervalues()]
719         for cname, c in table.iteritems():
720             ptype = c.pytype
721             if ptype == int and c.dbtype.bytes == 1:
722                 # This is probably a bool
723                 ptype = bool
724             p = AutoUnitClass.set_property(cname, ptype)
725             if c.autoincrement:
726                 AutoUnitClass.sequencer = dejavu.UnitSequencerInteger(int, c.initial)
727             p.default = c.default
728            
729             p.hints = dict([(k, getattr(c.dbtype, k))
730                             for k in ("bytes", "precision", "scale")
731                             if hasattr(c.dbtype, k)])
732             if p.hints:
733                 # Postgresql hack: replace bytes=ComparableInfinity with 0,
734                 # since 0 signifies "no limit".
735                 for k, v in p.hints.iteritems():
736                     if v.__class__.__name__ == 'ComparableInfinity':
737                         p.hints[k] = 0
738            
739             p.index = (cname in indices)
740        
741         # Remove default ID property if necessary.
742         if "ID" not in table:
743             AutoUnitClass.properties.remove('ID')
744             AutoUnitClass.ID = None
745        
746         return AutoUnitClass
747    
748     def all_source(self):
749         """Return a list of strings of Unit source code for all tables."""
750         ignore = dict.fromkeys([self.schema.table_name(x) for x in self.ignore]
751                                + self.ignore).keys()
752        
753         self.schema.discover_all(ignore=ignore)
754        
755         allcode = []
756         seen = {}
757         tables = self.schema.items()
758         tables.sort()
759         for key, table in tables:
760             if key not in ignore and table.name not in seen:
761                 code = self.make_source(key)
762                 allcode.append(code)
763                 seen[table.name] = None
764         return allcode
765    
766     def make_source(self, tablename, newclassname=None):
767         """Create source code for a Unit class from the named table."""
768         if tablename not in self.schema:
769             self.schema.discover(tablename)
770         table = self.schema[tablename]
771        
772         code = []
773        
774         if newclassname is None:
775             newclassname = table.name
776             # The key is probably better than the table.name. Try it.
777             for key, t in self.schema.iteritems():
778                 if t.name == newclassname:
779                     newclassname = key
780                     break
781             # Make the name safe for use as a Python class name
782             newclassname = newclassname.replace(".", "_")
783         code.append("class %s(Unit):" % newclassname)
784        
785         if table.description:
786             import textwrap
787             block = textwrap.fill(table.description, subsequent_indent='    ')
788             code.append('    """%s"""' % block)
789        
790         sequencer = None
791         indices = [idx.colname for idx in table.indices.itervalues()]
792        
793         # iterate over all columns
794         columns = table.items()
795         columns.sort()
796         for cname, c in columns:
797             prop, seq = self._make_column_source(cname, c, cname in indices)
798             if not prop.startswith("    ID = UnitProperty(int"):
799                 code.append(prop)
800             if seq:
801                 sequencer = seq
802        
803         # Remove default ID property if necessary.
804         if "ID" not in table:
805             code.append("    # Remove the default 'ID' property.")
806             code.append("    ID = None")
807        
808         pk = tuple([k for k in table if table[k].key])
809         if pk not in [("ID",), ("id",)]:
810             code.append("    identifiers = %s" % repr(pk))
811        
812         if sequencer:
813             if sequencer != "    sequencer = UnitSequencerInteger(int, 1)":
814                 code.append(sequencer)
815         else:
816             code.append("    sequencer = UnitSequencer()")
817        
818         if len(code) == 1:
819             code.append("    pass")
820        
821         return "\n".join(code)
822    
823     def _make_column_source(self, colname, column, has_index):
824         ptype = column.pytype
825         if ptype == int and column.dbtype.bytes == 1:
826             # This is probably a bool
827             ptype = bool
828        
829         mod = ptype.__module__
830         if mod == '__builtin__':
831             ptype = ptype.__name__
832         else:
833             ptype = mod + "." + ptype.__name__
834        
835         seq = None
836         if column.autoincrement:
837             seq = ("    sequencer = UnitSequencerInteger(int, %r)" %
838                    column.initial)
839        
840         default = column.default
841         if default is None:
842             default = ""
843         else:
844             default = ", default=%r" % default
845        
846         index = ""
847         if has_index:
848             index = ", index=True"
849        
850         hints = dict([(k, getattr(column.dbtype, k))
851                       for k in ("bytes", "precision", "scale")
852                       if hasattr(column.dbtype, k)])
853         if hints:
854             # Postgresql hack: replace bytes=ComparableInfinity with 0,
855             # since 0 signifies "no limit".
856             for k, v in hints.iteritems():
857                 if v.__class__.__name__ == 'ComparableInfinity':
858                     hints[k] = 0
859             hints = ", hints=%r" % hints
860         else:
861             hints = ""
862        
863         return ("    %s = UnitProperty(%s%s%s%s)" %
864                 (colname, ptype, index, hints, default)), seq
865
Note: See TracBrowser for help on using the browser.