Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/branches/crazycache/dejavu/storage/db.py

Revision 562 (checked in by fumanchu, 4 years ago)

Using the new index_name method from Geniusql.

  • Property svn:eol-style set to native
Line 
1 """Base classes and tools for writing database Storage Managers.
2
3 DATA TYPES
4 ==========
5 Database Storage Manager modules are mostly adapters to support round-trip
6 data coercion:
7
8 Unit type -> [SQL repr ->] DB -> incoming Python value -> Unit type
9
10 Since Dejavu relies on external database servers for its persistence,
11 Python datatypes must be converted to column types in the DB. When writing
a StorageManager, make sure your type conversions impose as few limits as
possible; ideally, implement each type with no size limit at all. Also,
follow UnitProperty.hints['bytes'] where possible. A value of zero for
hints['bytes'] implies no limit. If no value is given, try to assume no
limit, although you may choose whatever default size you wish (255 is
common for strings).
18
19 ENCODING ISSUES
20 ===============
21 All SQL sent to the database must be strings, not unicode. You can set the
22 encoding of the Adapters (I may add a more centralized encoding context in
23 the future). We must use encoded strings so that we can mix encodings
24 within the same string; for example, we might have a DB which understands
25 utf8, but a pickle value which will be encoded in raw-unicode-escape inline
26 with that. All values, therefore, must be coerced before we try to join
27 them into an SQL statement string.
28
29 """
30
31 import threading
32 import warnings
33
34
35 import geniusql
36 from geniusql import logic, logicfuncs
37
38 import dejavu
39 from dejavu import storage, logflags, xray
40 from dejavu.errors import StorageWarning, MappingError, conflict
41
42
43 # --------------------------- Storage Manager --------------------------- #
44
45
46 class StorageManagerDB(storage.StorageManager):
    """StorageManager base class to save and retrieve Units using a DB."""
48    
49     databaseclass = geniusql.Database
50    
51     def __init__(self, allOptions={}):
52         storage.StorageManager.__init__(self, allOptions)
53         self.reserve_lock = threading.Lock()
54        
55         # Config Overrides
56         def get_option(name):
57             item = allOptions.get(name)
58             if isinstance(item, basestring):
59                 item = xray.classes(item)
60             return item
61        
62         dbclass = get_option('Database Class')
63         if dbclass:
64             self.databaseclass = dbclass
65        
66         allOptions = dict([(str(k), v) for k, v in allOptions.iteritems()])
67        
68         self.db = self.databaseclass(**allOptions)
69         self.schema = self.db.schema()
70        
71         if 'Prefix' in allOptions:
72             self.schema.prefix = allOptions['Prefix']
73        
74         def logger(msg):
75             if self.logflags & logflags.SQL:
76                 self.log(logflags.SQL.message(msg))
77         self.db.log = logger
78    
79     def version(self):
80         return self.db.version()
81    
82     def shutdown(self, conflicts='error'):
83         """Shut down all connections to internal storage.
84         
85         conflicts: see errors.conflict.
86         """
87         self.db.connections.shutdown()
88    
89     def xrecall(self, classes, expr=None, order=None, limit=None, offset=None):
90         """Yield a sequence of Unit instances which satisfy the expression."""
91         if isinstance(classes, dejavu.UnitJoin):
92             for unitrow in self._xmultirecall(classes, expr, order=order,
93                                               limit=limit, offset=offset):
94                 yield unitrow
95             return
96         else:
97             cls = classes
98        
99         if self.logflags & logflags.RECALL:
100             self.log(logflags.RECALL.message(cls, expr))
101        
102         clsname = cls.__name__
103        
104         # Put the identifier properties first, in case other fields
105         # depend upon them.
106         idnames = list(cls.identifiers)
107         attrs = idnames + [x for x in cls.properties if x not in idnames]
108         coercers = [getattr(cls, key).coerce for key in attrs]
109        
110         data = self.select((cls, attrs, expr), order=order,
111                            limit=limit, offset=offset)
112         for row in data:
113             unit = cls.__new__(cls)
114             unit._zombie = True
115             unit.__init__()
116            
117             for key, value, propcoerce in zip(attrs, row, coercers):
118                 try:
119                     if propcoerce:
120                         value = propcoerce(unit, value)
121                     unit._properties[key] = value
122                 except UnicodeDecodeError, x:
123                     x.reason += " [%r: %r]" % (key, value)
124                     raise
125                 except Exception, x:
126                     x.args += (key, value)
127                     raise
128            
129             # If our SQL is imperfect, don't yield it to the
130             # caller unless it passes expr(unit).
131             if expr and data.statement.imperfect:
132                 if not expr(unit):
133                     continue
134            
135             unit.cleanse()
136             yield unit
137    
138     def reserve(self, unit):
139         """Reserve a persistent slot for unit."""
140         self.reserve_lock.acquire()
141         try:
142             # First, see if our db subclass has a handler that
143             # uses the DB to generate the appropriate identifier(s).
144             seqclass = unit.sequencer.__class__.__name__
145             seq_handler = getattr(self, "_seq_%s" % seqclass, None)
146             if seq_handler:
147                 seq_handler(unit)
148             else:
149                 self._manual_reserve(unit)
150             unit.cleanse()
151         finally:
152             self.reserve_lock.release()
153        
154         # Usually we log ASAP, but here we log after
155         # the unit has had a chance to get an auto ID.
156         if self.logflags & logflags.RESERVE:
157             self.log(logflags.RESERVE.message(unit))
158    
159     def _seq_UnitSequencerInteger(self, unit):
160         """Reserve a unit (using the table's autoincrement fields)."""
161         cls = unit.__class__
162        
163         # Grab the new ID. This is threadsafe because reserve has a mutex.
164         newids = self.schema[cls.__name__].insert(**unit._properties)
165         for k, v in newids.iteritems():
166             setattr(unit, k, v)
167    
168     def _manual_reserve(self, unit):
169         """Use when the DB cannot automatically generate an identifier.
170         The identifiers will be supplied by UnitSequencer.assign().
171         """
172         cls = unit.__class__
173         t = self.schema[cls.__name__]
174         if not unit.sequencer.valid_id(unit.identity()):
175             # Examine all existing IDs and grant the "next" one.
176             data = list(self.db.select((t, cls.identifiers)))
177             cls.sequencer.assign(unit, data)
178         t.insert(**unit._properties)
179    
180     def save(self, unit, forceSave=False):
181         """Update storage from unit's data (if unit.dirty())."""
182         if self.logflags & logflags.SAVE:
183             self.log(logflags.SAVE.message(unit, forceSave))
184        
185         if forceSave or unit.dirty():
186             self.schema[unit.__class__.__name__].save(**unit._properties)
187             unit.cleanse()
188    
189     def destroy(self, unit):
190         """Delete the unit."""
191         if self.logflags & logflags.DESTROY:
192             self.log(logflags.DESTROY.message(unit))
193        
194         table = self.schema[unit.__class__.__name__]
195         table.delete(**unit._properties)
196    
197    
198     #                                Views                                #
199    
200     def tablejoin(self, join):
201         """Return a geniusql Join tree for the given UnitJoin."""
202         t1, t2 = join.class1, join.class2
203        
204         if isinstance(t1, dejavu.UnitJoin):
205             wt1 = self.tablejoin(t1)
206         else:
207             wt1 = self.schema[t1.__name__]
208        
209         if isinstance(t2, dejavu.UnitJoin):
210             wt2 = self.tablejoin(t2)
211         else:
212             wt2 = self.schema[t2.__name__]
213        
214         uj = geniusql.Join(wt1, wt2, join.leftbiased)
215         # if the original UnitJoin had a custom association path,
216         # copy it to the new Join instance
217         uj.path = join.path
218         return uj
219    
220     def _geniusql_query(self, query):
221         """Return a Geniusql Query object for the given Dejavu Query."""
222         rel = query.relation
223         if isinstance(rel, dejavu.UnitJoin):
224             rel = self.tablejoin(rel)
225         elif rel is None:
226             # This is a Geniusql-ism: send the schema when we have no FROM.
227             rel = self.schema
228         else:
229             rel = self.schema[rel.__name__]
230         return geniusql.Query(rel, query.attributes, query.restriction)
231    
232     def _geniusql_statement(self, statement):
233         """Return a Geniusql Statement object for the given Dejavu Statement."""
234         return geniusql.Statement(
235             self._geniusql_query(statement.query),
236             order=statement.order, limit=statement.limit,
237             offset=statement.offset, distinct=statement.distinct)
238    
239     def select(self, query, order=None, limit=None, offset=None, distinct=None):
240         """Return a geniusql Dataset for the given Query object."""
241         if not isinstance(query, dejavu.Query):
242             query = dejavu.Query(*query)
243        
244         return self.db.select(self._geniusql_query(query),
245                               order=order, limit=limit, offset=offset,
246                               distinct=distinct, strict=False)
247    
248     def insert_into(self, name, query, distinct=False):
249         """INSERT matching data INTO a new class and return the class."""
250         if not isinstance(query, dejavu.Query):
251             query = dejavu.Query(*query)
252        
253         self.db.insert_into(name, self._geniusql_query(query),
254                             distinct=distinct)
255         return Modeler(self.schema).make_class(name)
256    
257     def make_class(self, name):
258         """Return a (new) Unit class for the given storage name."""
259         return Modeler(self.schema).make_class(name)
260    
261     def xview(self, query, order=None, limit=None, offset=None, distinct=False):
262         """Yield value tuples for the given query."""
263         if not isinstance(query, dejavu.Query):
264             query = dejavu.Query(*query)
265        
266         if self.logflags & logflags.VIEW:
267             self.log(logflags.VIEW.message(query, distinct))
268        
269         data = self.select(query, order=order, limit=limit, offset=offset,
270                            distinct=distinct)
271         if data.statement.imperfect:
272             # ^%$#@! There's no way to handle imperfect queries without
273             # creating all involved Units, which defeats the performance
274             # benefits of view.
275             clsname = self.__class__.__name__
276             warnings.warn("The requested query cannot produce perfect SQL "
277                           "with a %s datasource. It may take an absurdly "
278                           "long time to run, since each unit must be fully-"
279                           "formed. %s" % (clsname, query), StorageWarning)
280             for row in storage.StorageManager.xview(self, query, order=order,
281                                                     limit=limit, offset=offset,
282                                                     distinct=distinct):
283                 yield row
284         else:
285             # Use tuples for hashability
286             for row in data:
287                 yield tuple(row)
288    
289     def count(self, cls, expr=None):
290         """Number of Units of the given cls which match the given expr."""
291         if cls.identifiers:
292             uniq = cls.identifiers
293         else:
294             uniq = cls._properties.keys()
295         # TODO: handle multiple args to count()
296         counter = lambda x: [logicfuncs.count(getattr(x, uniq[0]))]
297        
298         query = dejavu.Query(cls, counter, expr)
299        
300         if self.logflags & logflags.VIEW:
301             self.log(logflags.VIEW.message(query, False))
302        
303         data = self.select(query)
304         if data.statement.imperfect:
305             # ^%$#@! There's no way to handle imperfect queries without
306             # creating all involved Units, which defeats the performance
307             # benefits of view.
308             clsname = self.__class__.__name__
309             warnings.warn("The requested query cannot produce perfect SQL "
310                           "with a %s datasource. It may take an absurdly "
311                           "long time to run, since each unit must be fully-"
312                           "formed. %s" % (clsname, query), StorageWarning)
313             return storage.StorageManager.count(self, cls, expr)
314         else:
315             return data.scalar()
316    
317     def _xmultirecall(self, classes, expr=None, order=None, limit=None, offset=None):
318         """Yield Unit instance sets which satisfy the expression."""
319         if self.logflags & logflags.RECALL:
320             self.log(logflags.RECALL.message(classes, expr))
321        
322         # Gather attribute list.
323         allattrs = []
324         props = []
325         for cls in classes:
326             attrs = []
327             for key in cls.properties:
328                 attrs.append(key)
329                 props.append((cls, key, getattr(cls, key).coerce))
330             allattrs.append(attrs)
331        
332         data = self.select((classes, allattrs, expr), order=order,
333                            limit=limit, offset=offset)
334         for row in data:
335             # TODO: This is broken; won't work if same cls appears twice.
336             units = {}
337             for i, (cls, key, propcoerce) in enumerate(props):
338                 if cls in units:
339                     unit = units[cls]
340                 else:
341                     unit = cls.__new__(cls)
342                     unit._zombie = True
343                     unit.__init__()
344                     units[cls] = unit
345                
346                 value = row[i]
347                 try:
348                     if propcoerce:
349                         value = propcoerce(unit, value)
350                     unit._properties[key] = value
351                 except Exception, x:
352                     x.args += (cls, key)
353                     raise
354            
355             unitset = []
356             for cls in classes:
357                 unit = units[cls]
358                 unit.cleanse()
359                 unitset.append(unit)
360            
361             # If our SQL is imperfect, don't yield units to the
362             # caller unless they pass expr(unit).
363             acceptable = True
364             if expr and data.statement.imperfect:
365                 acceptable = expr(*unitset)
366             if acceptable:
367                 yield unitset
368    
369     #                               Schemas                               #
370    
371     def create_database(self, conflicts='error'):
372         """Create internal structures for the entire database.
373         
374         conflicts: see errors.conflict.
375         """
376         if self.logflags & logflags.DDL:
377             self.log(logflags.DDL.message("create database"))
378        
379         try:
380             self.db.create()
381             self.schema.create()
382         except geniusql.errors.MappingError, x:
383             conflict(conflicts, str(x))
384    
385     def drop_database(self, conflicts='error'):
386         """Destroy internal structures for the entire database.
387         
388         conflicts: see errors.conflict.
389         """
390         if self.logflags & logflags.DDL:
391             self.log(logflags.DDL.message("drop database"))
392        
393         try:
394             self.schema.drop()
395         except geniusql.errors.MappingError, x:
396             conflict(conflicts, str(x))
397        
398         try:
399             self.db.drop()
400         except geniusql.errors.MappingError, x:
401             conflict(conflicts, str(x))
402    
403     def _make_table(self, cls):
404         """Create and return a Table or View object for the given class."""
405         if hasattr(cls, "view_statement"):
406             gs = self._geniusql_statement(cls.view_statement)
407             t = self.schema.view(cls.__name__, gs)
408             fields = []
409             for key in cls.properties:
410                 t[key] = self._make_column(cls, key)
411         else:
412             t = self.schema.table(cls.__name__)
413             indices = cls.indices()
414             fields = []
415             for key in cls.properties:
416                 t[key] = col = self._make_column(cls, key)
417                 if key in indices:
418                     t.add_index(key)
419                 if col.autoincrement and col.sequence_name is None:
420                     # Not every database needs/uses sequence_name
421                     col.sequence_name = self.schema.sequence_name(t.name, key)
422        
423         # Copy associations to table.references.
424         for k, v in cls._associations.iteritems():
425             t.references[k] = (v.nearKey, v.farClass.__name__, v.farKey)
426        
427         return t
428    
429     def create_storage(self, cls, conflicts='error'):
430         """Create internal structures for the given class.
431         
432         conflicts: see errors.conflict.
433         """
434         if self.logflags & logflags.DDL:
435             self.log(logflags.DDL.message("create storage %s" % cls))
436        
437         try:
438             # Attach to self.schema, which should call CREATE TABLE.
439             self.schema[cls.__name__] = self._make_table(cls)
440         except geniusql.errors.MappingError, x:
441             conflict(conflicts, str(x))
442    
443     def _make_column(self, cls, key):
444         prop = getattr(cls, key)
445         col = self.schema.column(prop.type, prop.hints.get('dbtype'),
446                                  default=prop.default, hints=prop.hints)
447         if key in cls.identifiers:
448             col.key = True
449             if isinstance(cls.sequencer, dejavu.UnitSequencerInteger):
450                 col.autoincrement = True
451                 col.initial = cls.sequencer.initial
452         return col
453    
454     def has_storage(self, cls):
455         """If storage structures exist for the given class, return True."""
456         return cls.__name__ in self.schema
457    
458     def drop_storage(self, cls, conflicts='error'):
459         """Destroy internal structures for the given class.
460         
461         conflicts: see errors.conflict.
462         """
463         if self.logflags & logflags.DDL:
464             self.log(logflags.DDL.message("drop storage %s" % cls))
465        
466         try:
467             del self.schema[cls.__name__]
468         except geniusql.errors.MappingError, x:
469             conflict(conflicts, str(x))
470    
471     def rename_storage(self, oldname, newname, conflicts='error'):
472         """Rename internal structures for the given class.
473         
474         conflicts: see errors.conflict.
475         """
476         if self.logflags & logflags.DDL:
477             self.log(logflags.DDL.message("rename storage from %s to %s"
478                                           % (oldname, newname)))
479        
480         try:
481             self.schema.rename(oldname, newname)
482         except geniusql.errors.MappingError, x:
483             conflict(conflicts, str(x))
484    
485     def add_property(self, cls, name, conflicts='error'):
486         """Add internal structures for the given property.
487         
488         conflicts: see errors.conflict.
489         """
490         if self.logflags & logflags.DDL:
491             self.log(logflags.DDL.message("add property %s %s" %
492                                           (cls, name)))
493        
494         if not self.has_property(cls, name):
495             try:
496                 table = self.schema[cls.__name__]
497                 table[name] = col = self._make_column(cls, name)
498                 if col.autoincrement and col.sequence_name is None:
499                     # Not every database needs/uses sequence_name
500                     col.sequence_name = self.schema.sequence_name(table.name, name)
501             except geniusql.errors.MappingError, x:
502                 conflict(conflicts, str(x))
503    
504     def has_property(self, cls, name):
505         """If storage structures exist for the given property, return True."""
506         return name in self.schema[cls.__name__]
507    
508     def drop_property(self, cls, name, conflicts='error'):
509         """Destroy internal structures for the given property.
510         
511         conflicts: see errors.conflict.
512         """
513         if self.logflags & logflags.DDL:
514             self.log(logflags.DDL.message("drop property %s %s" %
515                                           (cls, name)))
516         if self.has_property(cls, name):
517             try:
518                 del self.schema[cls.__name__][name]
519             except geniusql.errors.MappingError, x:
520                 conflict(conflicts, str(x))
521    
522     def rename_property(self, cls, oldname, newname, conflicts='error'):
523         """Rename internal structures for the given property.
524         
525         conflicts: see errors.conflict.
526         """
527         if self.logflags & logflags.DDL:
528             self.log(logflags.DDL.message(
529                 "rename property %s from %s to %s" %
530                 (cls, oldname, newname)))
531        
532         try:
533             t = self.schema[cls.__name__]
534            
535             # Sometimes, a Dejavu Schema will change a code model first, and
536             # then change the database afterward. So it's possible that the
537             # column we're trying to rename hasn't been loaded, because the
538             # model layer no longer references it. So if table[oldname]
539             # raises a KeyError, try to find a column that matches oldkey.
540             tempcol = None
541             try:
542                 t[oldname]
543             except KeyError:
544                 c = [x for x in self.schema._get_columns(t.name)
545                      if x.name == self.schema.column_name(t.name, oldname)]
546                 if not c:
547                     raise KeyError("Rename failed. Old column %r not found in %r."
548                                    % (oldname, t.name))
549                 oldcol = c[0]
550                 # Use the superclass call to avoid DROP COLUMN/ADD COLUMN.
551                 dict.__setitem__(t, oldname, oldcol)
552            
553             t.rename(oldname, newname)
554         except geniusql.errors.MappingError, x:
555             conflict(conflicts, str(x))
556    
557     def add_index(self, cls, name, conflicts='error'):
558         """Add an index for the given property.
559         
560         conflicts: see errors.conflict.
561         """
562         try:
563             self.schema[cls.__name__].add_index(name)
564         except geniusql.errors.MappingError, x:
565             conflict(conflicts, str(x))
566    
567     def has_index(self, cls, name):
568         """If an index exists for the given property, return True."""
569         return name in self.schema[cls.__name__].indices
570    
571     def drop_index(self, cls, name, conflicts='error'):
572         """Destroy any index on the given property.
573         
574         conflicts: see errors.conflict.
575         """
576         try:
577             del self.schema[cls.__name__].indices[name]
578         except geniusql.errors.MappingError, x:
579             conflict(conflicts, str(x))
580    
581     auto_discover = True
582    
583     def map(self, classes, conflicts='error'):
584         """Map classes to internal storage.
585         
586         If self.auto_discover is True (the default), then Table/Column/Index
587         objects will be formed by inspecting the underlying database using
588         self.sync().
589         
590         If auto_discover is False, then mock Table/Column/Index objects
591         will be used instead; this provides a performance improvement
592         in scenarios where the model maps perfectly to the database
593         and changes to the database are not expected outside the model.
594         
595         conflicts: see errors.conflict.
596         """
597         # Map tables before views, because views depend on them
598         tables = [cls for cls in classes if not hasattr(cls, 'view_statement')]
599         views = [cls for cls in classes if hasattr(cls, 'view_statement')]
600         classes = tables + views
601        
602         if self.auto_discover:
603             self.db.discover_dbinfo()
604             self.sync(classes, conflicts)
605         else:
606             for cls in classes:
607                 if self.has_storage(cls):
608                     # If our consumer-side key is already present, skip this cls.
609                     # This allows callers to auto-sync class by class
610                     # without making a new Table object each time.
611                     continue
612                
613                 t = self._make_table(cls)
614                
615                 # Use the superclass call to avoid DROP/CREATE TABLE
616                 dict.__setitem__(self.schema, cls.__name__, t)
617    
618     def sync(self, classes, conflicts='error'):
619         """Map classes to existing Table objects (found via discovery).
620         
621         conflicts: see errors.conflict.
622         """
623         for cls in classes:
624             if cls.__name__ in self.schema:
625                 # If our consumer-side key is already present, skip this cls.
626                 # This allows callers to auto-sync class by class
627                 # without calling the expensive discover() func each time.
628                 continue
629             self._find_table(self.schema, cls, conflicts=conflicts)
630    
631     def _find_table(self, schema, cls, conflicts='error'):
632         # This is broken out to make multi-schema subclasses easier to write.
633        
634         # Try to find a matching Table or View object using the DB-side key.
635         clsname = cls.__name__
636         tablename = schema.table_name(clsname)
637         try:
638             # Do we already have a map using the DB name?
639             table = schema[tablename]
640             schema.alias(table.name, clsname)
641         except KeyError:
642             # Can we create a map? Discover the DB table and try again.
643             try:
644                 table = schema.discover(tablename)
645                 schema.alias(table.name, clsname)
646             except geniusql.errors.MappingError:
647                 msg = "%s: no such table %r." % (clsname, tablename)
648                 if self.logflags & logflags.DDL:
649                     self.log(logflags.DDL.message(msg))
650                 if conflicts == 'repair':
651                     # Didn't have one. Couldn't find one. Let's make one.
652                     self.create_storage(cls)
653                     table = schema[clsname]
654                 else:
655                     conflict(conflicts, msg)
656                     return
657        
658         # Match Column objects with class properties.
659         dbcols = dict([(c.name, c) for c in table.itervalues()])
660         indices = cls.indices()
661         for pkey in cls.properties:
662             colname = schema.column_name(table.name, pkey)
663             try:
664                 col = dbcols[colname]
665                 table.alias(colname, pkey)
666             except KeyError, x:
667                 msg = "%s: no column found for %r." % (clsname, pkey)
668                 if self.logflags & logflags.DDL:
669                     self.log(logflags.DDL.message(msg))
670                 if conflicts == 'repair':
671                     self.add_property(cls, pkey)
672                     if pkey in cls.indices() and pkey not in table.indices:
673                         self.add_index(cls, pkey)
674                     col = table[pkey]
675                 else:
676                     conflict(conflicts, msg)
677                     continue
678            
679             # Check that the column.key matches our identifiers list;
680             # this is crucial for the proper operation of OLTP methods
681             # in geniusql.Table, which uses column.key to decide
682             # the unique identifiers for a given row of data.
683             if not isinstance(table, geniusql.View):
684                 if pkey in cls.identifiers and not col.key:
685                     msg = ("%s: %r is an identifier, but the "
686                            "column is not marked as a primary key."
687                            % (clsname, pkey))
688                     if self.logflags & logflags.DDL:
689                         self.log(logflags.DDL.message(msg))
690                     if conflicts == 'repair':
691                         col.key = True
692                         table.set_primary()
693                     else:
694                         conflict(conflicts, msg)
695                         continue
696                 elif col.key and not pkey in cls.identifiers:
697                     msg = ("%s: %r is not an identifier, but the "
698                            "column is marked as a primary key."
699                            % (clsname, pkey))
700                     if self.logflags & logflags.DDL:
701                         self.log(logflags.DDL.message(msg))
702                     if conflicts == 'repair':
703                         col.key = False
704                         # Just because the current pkey is not an identifier
705                         # doesn't mean we have *no* identifiers.
706                         table.set_primary()
707                     else:
708                         conflict(conflicts, msg)
709                         continue
710            
711             col.pytype = getattr(cls, pkey).type
712            
713             # Override the default adapter (since it guessed an adapter
714             # using the default pytype, and we know better).
715             try:
716                 col.adapter = col.dbtype.default_adapter(col.pytype)
717             except TypeError, x:
718                 x.args += ("%s.%s" % (table.name, col.name),)
719                 raise
720            
721             if not hasattr(table, "indices"):
722                 # Skip indices for Views.
723                 continue
724            
725             # Try to find matching Index objects. Because index names are
726             # so platform-specific, we match attributes rather than names.
727             if pkey in indices:
728                 for ikey, idx in table.indices.items():
729                     if idx.colname == colname:
730                         table.indices.alias(ikey, schema.index_name(table, pkey))
731                         break
732                 else:
733                     msg = "%s: no index found for %r." % (clsname, pkey)
734                     if self.logflags & logflags.DDL:
735                         self.log(logflags.DDL.message(msg))
736                     if conflicts == 'repair':
737                         self.add_index(cls, pkey)
738                     else:
739                         conflict(conflicts, msg)
740                         continue
741             else:
742                 if pkey in cls.identifiers and self.db.pks_must_be_indexed:
743                     pass
744                 else:
745                     for ikey, idx in table.indices.items():
746                         if idx.colname == colname:
747                             msg = ("%s: index found for non-indexed %r."
748                                    % (clsname, pkey))
749                             if self.logflags & logflags.DDL:
750                                 self.log(logflags.DDL.message(msg))
751                             if conflicts == 'repair':
752                                 self.drop_index(cls, ikey)
753                             else:
754                                 conflict(conflicts, msg)
755                                 continue
756        
757         # Set Table.references
758         for k, v in cls._associations.iteritems():
759             table.references[k] = (v.nearKey, v.farClass.__name__, v.farKey)
760        
761         return table
762    
763     #                            Transactions                             #
764    
765     def start(self, isolation=None):
766         "Start a transaction (not needed if db.connections.implicit_trans)."
767         self.db.connections.start(isolation)
768    
769     def rollback(self):
770         """Roll back the current transaction."""
771         self.db.connections.rollback()
772    
773     def commit(self):
774         """Commit the current transaction."""
775         self.db.connections.commit()
776
777
class Modeler(object):
    """Tool to automatically form Unit classes or source from existing DB's.

    make_class()/all_classes() build live dejavu.Unit subclasses from the
    tables in `schema`; make_source()/all_source() return Python source
    text for equivalent class definitions instead.
    """

    # Names whose tables are never modeled by all_classes()/all_source().
    # Each name is also passed through schema.table_name(), so both the raw
    # name and its table-name form are skipped (presumably these are the
    # names of dejavu's own bookkeeping Unit classes).
    ignore = ['Unit', 'DeployedVersion',
              'UnitEngine', 'UnitEngineRule', 'UnitCollection',
              ]

    def __init__(self, schema):
        self.schema = schema
        # Copy so per-instance edits never mutate the shared class list.
        self.ignore = self.ignore[:]

    # ----------------------------- helpers ----------------------------- #

    def _ignored_names(self):
        """Return a de-duplicated list of table keys/names to skip."""
        names = [self.schema.table_name(x) for x in self.ignore] + self.ignore
        return list(dict.fromkeys(names))

    def _default_classname(self, table):
        """Return the schema key under which `table` is registered.

        The key is usually a better class name than table.name; fall back
        to table.name when no schema entry carries that name.
        """
        for key, t in self.schema.items():
            if t.name == table.name:
                return key
        return table.name

    def _safe_identifier(self, name):
        """Return `name` coerced into a valid Python identifier.

        Every non-word character (".", " ", "-", ...) becomes "_", and a
        leading digit gets a "_" prefix.
        """
        import re
        name = re.sub(r"\W", "_", name)
        if name[:1].isdigit():
            name = "_" + name
        return name

    def _indexed_columns(self, table):
        """Return the names of indexed columns of `table`.

        Views carry no `indices` attribute; they yield an empty list
        instead of raising AttributeError.
        """
        indices = getattr(table, "indices", None)
        if indices is None:
            return []
        return [idx.colname for key, idx in indices.items()]

    def _column_pytype(self, column):
        """Return the Python type used to model `column`.

        Integer columns stored in a single byte are treated as bool.
        """
        ptype = column.pytype
        if ptype == int and getattr(column.dbtype, 'bytes', None) == 1:
            # This is probably a bool
            ptype = bool
        return ptype

    def _column_hints(self, column):
        """Return the bytes/precision/scale hints defined by column.dbtype.

        Postgresql hack: a ComparableInfinity value is replaced with 0,
        since 0 signifies "no limit" in UnitProperty.hints.
        """
        hints = {}
        for k in ("bytes", "precision", "scale"):
            if hasattr(column.dbtype, k):
                v = getattr(column.dbtype, k)
                if v.__class__.__name__ == 'ComparableInfinity':
                    v = 0
                hints[k] = v
        return hints

    # --------------------------- public API ---------------------------- #

    def all_classes(self):
        """Return a list of new classes for all tables in the Database.

        Tables listed in self.ignore are skipped, as is any table already
        modeled under another key (matched by table.name).
        """
        ignore = self._ignored_names()

        self.schema.discover_all(ignore=ignore)

        classes = []
        seen = {}
        for key, table in self.schema.items():
            if key not in ignore and table.name not in seen:
                classes.append(self.make_class(key))
                seen[table.name] = None
        return classes

    def make_class(self, tablename, newclassname=None):
        """Create a Unit class automatically from the named table.

        tablename: schema key of the table; it is discovered first if the
            schema does not know it yet.
        newclassname: __name__ for the new class. If None, the schema key
            of the table is used (falling back to table.name).

        Returns a new dejavu.Unit subclass with one property per column.
        """
        if tablename not in self.schema:
            self.schema.discover(tablename)
        table = self.schema[tablename]

        class AutoUnitClass(dejavu.Unit):
            sequencer = dejavu.UnitSequencer()
            identifiers = tuple([k for k in table if table[k].key])

        if newclassname is None:
            newclassname = self._default_classname(table)
        AutoUnitClass.__name__ = newclassname

        indices = self._indexed_columns(table)
        for cname, c in table.items():
            p = AutoUnitClass.set_property(cname, self._column_pytype(c))
            if c.autoincrement:
                AutoUnitClass.sequencer = dejavu.UnitSequencerInteger(int, c.initial)
            p.default = c.default
            p.hints = self._column_hints(c)
            p.index = (cname in indices)

        # Remove default ID property if necessary.
        if "ID" not in table:
            AutoUnitClass.properties.remove('ID')
            AutoUnitClass.ID = None

        return AutoUnitClass

    def all_source(self):
        """Return a list of strings of Unit source code for all tables.

        Tables are emitted in sorted key order; ignored tables and tables
        already emitted under another key (same table.name) are skipped.
        """
        ignore = self._ignored_names()

        self.schema.discover_all(ignore=ignore)

        allcode = []
        seen = {}
        tables = list(self.schema.items())
        tables.sort()
        for key, table in tables:
            if key not in ignore and table.name not in seen:
                allcode.append(self.make_source(key))
                seen[table.name] = None
        return allcode

    def make_source(self, tablename, newclassname=None):
        """Create source code for a Unit class from the named table.

        tablename: schema key of the table; discovered first if unknown.
        newclassname: class name to emit. If None, the table's schema key
            (or table.name) is used, coerced into a valid identifier.

        Returns the class definition as a single newline-joined string.
        """
        if tablename not in self.schema:
            self.schema.discover(tablename)
        table = self.schema[tablename]

        if newclassname is None:
            # Make the name safe for use as a Python class name.
            newclassname = self._safe_identifier(self._default_classname(table))
        code = ["class %s(Unit):" % newclassname]

        if table.description:
            import textwrap
            block = textwrap.fill(table.description, subsequent_indent='    ')
            # Escape backslashes and quotes so the description can neither
            # close the docstring early nor introduce escape sequences.
            block = block.replace('\\', '\\\\').replace('"', '\\"')
            code.append('    """%s"""' % block)

        sequencer = None
        indices = self._indexed_columns(table)

        # iterate over all columns, in sorted order for stable output
        columns = list(table.items())
        columns.sort()
        for cname, c in columns:
            prop, seq = self._make_column_source(cname, c, cname in indices)
            # Unit already declares an int ID property; don't repeat it.
            if not prop.startswith("    ID = UnitProperty(int"):
                code.append(prop)
            if seq:
                sequencer = seq

        # Remove default ID property if necessary.
        if "ID" not in table:
            code.append("    # Remove the default 'ID' property.")
            code.append("    ID = None")

        pk = tuple([k for k in table if table[k].key])
        if pk != ("ID",):
            code.append("    identifiers = %s" % repr(pk))

        if sequencer:
            # UnitSequencerInteger(int, 1) is the Unit default; omit it.
            if sequencer != "    sequencer = UnitSequencerInteger(int, 1)":
                code.append(sequencer)
        else:
            code.append("    sequencer = UnitSequencer()")

        if len(code) == 1:
            code.append("    pass")

        return "\n".join(code)

    def _make_column_source(self, colname, column, has_index):
        """Return (property_source, sequencer_source) for one column.

        property_source is an indented "name = UnitProperty(...)" line.
        sequencer_source is an indented "sequencer = ..." line when the
        column autoincrements, else None.
        """
        ptype = self._column_pytype(column)
        mod = ptype.__module__
        if mod in ('__builtin__', 'builtins'):
            typename = ptype.__name__
        else:
            typename = mod + "." + ptype.__name__

        seq = None
        if column.autoincrement:
            # 1-tuple: a tuple value must not be read as format arguments.
            seq = ("    sequencer = UnitSequencerInteger(int, %r)"
                   % (column.initial,))

        default_src = ""
        if column.default is not None:
            default_src = ", default=%r" % (column.default,)

        index_src = ""
        if has_index:
            index_src = ", index=True"

        hints = self._column_hints(column)
        hints_src = ""
        if hints:
            hints_src = ", hints=%r" % (hints,)

        return ("    %s = UnitProperty(%s%s%s%s)" %
                (colname, typename, index_src, hints_src, default_src)), seq
971
Note: See TracBrowser for help on using the browser.