Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

Changeset 226

Show
Ignore:
Timestamp:
07/16/06 04:30:04
Author:
fumanchu
Message:

First crack at the new dbmodel module. See #18, #62.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/storage/db.py

    r225 r226  
    6767import dejavu 
    6868from dejavu import codewalk, logic, storage, LOGSQL, xray 
     69from dbmodel import * 
    6970 
    7071 
     
    116117     
    117118    def coerce(self, cls, key): 
    118         """coerce(cls, key) -> SQL typename for valuetype.""" 
     119        """Return the SQL datatype name for valuetype.""" 
    119120        valuetype = cls.property(key).type 
    120121        mod = valuetype.__module__ 
     
    634635        if isinstance(tos, TableRef): 
    635636            # The name in question refers to a DB column. 
    636             atom = self.sm.column_name(tos.classname, name, full=True) 
     637            colname = self.sm.column_name(tos.classname, name) 
     638            alias = getattr(tos.classname, "alias", None) 
     639            if alias is None: 
     640                tname = self.sm.table_name(tos.classname) 
     641            else: 
     642                tname = (tos.classname.alias or tos.classname.tablename) 
     643            atom = '%s.%s' % (self.sm.quote(tname), self.sm.quote(colname)) 
    637644        else: 
    638645            # tos.name will reference an attribute of the tos object. 
     
    942949 
    943950 
     951 
    944952# --------------------------- Storage Manager --------------------------- # 
    945953 
     
    953961         
    954962        wclsname = wclass.__name__ 
    955         self.tablename = sm.table_name(wclsname) 
     963        self.tablename = sm.tables[wclsname].name 
    956964        self.alias = "" 
    957965     
    958966    def columns(self): 
     967        """Return [(wclass, UnitProperty.key), ...], ['"tbl"."col"', ...].""" 
    959968        wclass = self.cls 
    960969         
     
    964973                                           if k not in wclass.identifiers] 
    965974        cols = [(wclass, k) for k in keys] 
    966         colnames = ['%s.%s' % (self.alias or self.tablename
    967                                self.sm.column_name(wclass.__name__, k)) 
     975        colnames = ['%s.%s' % (self.sm.quote(self.alias or self.tablename)
     976                               self.sm.quote(self.sm.column_name(wclass.__name__, k))) 
    968977                    for k in keys] 
    969978        return cols, colnames 
    970979     
    971980    def _joinname(self): 
     981        q = self.sm.quote 
    972982        if self.alias: 
    973             return "%s AS %s" % (self.tablename, self.alias
    974         else: 
    975             return self.tablename 
    976     joinname = property(_joinname, doc=("Table name for use in " 
    977                                             "JOIN clause (read-only).")) 
     983            return "%s AS %s" % (q(self.tablename), q(self.alias)
     984        else: 
     985            return q(self.tablename) 
     986    joinname = property(_joinname, doc=("Quoted table name for use in " 
     987                                        "JOIN clause (read-only).")) 
    978988     
    979989    def association(self, classes): 
     
    9991009    use_asterisk_to_get_all = False 
    10001010     
     1011    prefix = "" 
     1012     
    10011013    decompiler = SQLDecompiler 
    10021014    typeAdapter = FieldTypeAdapter() 
     
    10041016    fromAdapter = AdapterFromDB() 
    10051017     
     1018    tablesetclass = TableSet 
     1019    tableclass = Table 
     1020    columnsetclass = ColumnSet 
     1021    columnclass = Column 
     1022    indexsetclass = IndexSet 
     1023    indexclass = Index 
     1024     
    10061025    def __init__(self, name, arena, allOptions={}): 
    10071026        storage.StorageManager.__init__(self, name, arena, allOptions) 
    10081027         
    10091028        # Adapter Overrides 
    1010         def get_adapter_option(name): 
    1011             adapter_class = allOptions.get(name) 
    1012             if isinstance(adapter_class, basestring): 
    1013                 adapter_class = xray.classes(adapter_class
    1014             return adapter_class 
    1015          
    1016         adapter = get_adapter_option('Type Adapter') 
     1029        def get_option(name): 
     1030            item = allOptions.get(name) 
     1031            if isinstance(item, basestring): 
     1032                item = xray.classes(item
     1033            return item 
     1034         
     1035        adapter = get_option('Type Adapter') 
    10171036        if adapter: self.typeAdapter = adapter 
    1018         adapter = get_adapter_option('To Adapter') 
     1037        adapter = get_option('To Adapter') 
    10191038        if adapter: self.toAdapter = adapter 
    1020         adapter = get_adapter_option('From Adapter') 
     1039        adapter = get_option('From Adapter') 
    10211040        if adapter: self.fromAdapter = adapter 
     1041         
     1042        adapter = get_option('TableSet Class') 
     1043        if adapter: self.tablesetclass = adapter 
     1044        self.tables = self.tablesetclass(self) 
    10221045         
    10231046        size = int(allOptions.get('Pool Size', '10')) 
     
    10271050            self.connection = ConnectionFactory(self._get_conn, self._del_conn) 
    10281051         
    1029         self.prefix = allOptions.get('Prefix', "djv") 
     1052        self.prefix = allOptions.get('Prefix', "") 
    10301053        self.reserve_lock = threading.Lock() 
    10311054     
    10321055    #                               Naming                               # 
    10331056     
    1034     def sql_name(self, name, quoted=True): 
    1035         """The name, escaped for SQL.""" 
     1057    def quote(self, name): 
     1058        """Return name, quoted for use in an SQL statement.""" 
     1059        # This base class doesn't use "quote", 
     1060        # but most subclasses will. 
     1061        return name 
     1062     
     1063    def sql_name(self, name): 
     1064        """Return the native SQL version of name.""" 
    10361065        if self.sql_name_caseless: 
    10371066            name = name.lower() 
     
    10441073            name = name[:maxlen] 
    10451074         
    1046         # This base class doesn't use the "quoted" arg, 
    1047         # but most subclasses will. 
    10481075        return name 
    10491076     
    1050     def column_name(self, classname, name, full=False, quoted=True): 
    1051         """The column name, escaped for SQL. If full, include tablename.""" 
     1077    def column_name(self, classname, name): 
     1078        """The column name, escaped for SQL.""" 
    10521079        # If you want to use a map from UnitProperty names 
    1053         # to DB column names, override this method. 
    1054         name = self.sql_name(name, quoted=quoted) 
    1055         if not full: 
    1056             return name 
    1057          
    1058         alias = getattr(classname, "alias", None) 
    1059         if alias is None: 
    1060             tname = self.table_name(classname, quoted=quoted) 
    1061         else: 
    1062             tname = (classname.alias or classname.tablename) 
    1063         return '%s.%s' % (tname, name) 
    1064      
    1065     def table_name(self, name, quoted=True): 
    1066         """The table name, escaped for SQL.""" 
     1080        # to DB column names, override this method (that's why 
     1081        # the classname must be included in the args). 
     1082        return self.sql_name(name) 
     1083     
     1084    def table_name(self, name): 
     1085        """Return the SQL table name for the given key.""" 
    10671086        # If you want to use a map from Unit class names 
    10681087        # to DB table names, override this method. 
    1069         return self.sql_name(self.prefix + name, quoted=quoted
     1088        return self.sql_name(self.prefix + name
    10701089     
    10711090    #                             Connecting                              # 
     
    10911110        """ 
    10921111        clsname = cls.__name__ 
    1093         tablename = self.table_name(clsname) 
     1112        tablename = self.tables[clsname].name 
    10941113        if fields: 
    1095             fields = [self.column_name(clsname, x) for x in fields] 
     1114            fields = [self.quote(self.column_name(clsname, x)) for x in fields] 
    10961115            if distinct: 
    10971116                sql = 'SELECT DISTINCT %s FROM %s' 
    10981117            else: 
    10991118                sql = 'SELECT %s FROM %s' 
    1100             sql = sql % (', '.join(fields), tablename
    1101         else: 
    1102             sql = 'SELECT * FROM %s' % tablename 
     1119            sql = sql % (', '.join(fields), self.quote(tablename)
     1120        else: 
     1121            sql = 'SELECT * FROM %s' % self.quote(tablename) 
    11031122         
    11041123        w, i = self.where((clsname,), expr) 
     
    11301149    def fetch(self, query, conn=None): 
    11311150        """fetch(query, conn=None) -> rowdata, columns. 
    1132  
    1133         query should be a SQL query in string format         
     1151         
     1152        query should be a SQL query in string format 
    11341153        rowdata will be an iterable of iterables containing the result values. 
    11351154        columns will be an iterable of (column name, data type) pairs. 
     
    11571176            idnames = list(cls.identifiers) 
    11581177            for key in idnames + [x for x in cls.properties if x not in idnames]: 
    1159                 index, ftype = columns[self.column_name(clsname, key, quoted=False)] 
     1178                index, ftype = columns[self.column_name(clsname, key)] 
    11601179                props.append((key, index, ftype)) 
    11611180             
     
    11961215        cls = unit.__class__ 
    11971216        clsname = cls.__name__ 
    1198         tablename = self.table_name(clsname) 
     1217        tablename = self.tables[clsname].name 
    11991218        if not unit.sequencer.valid_id(unit.identity()): 
    12001219            # Examine all existing IDs and grant the "next" one. 
    1201             id_fields = [self.column_name(clsname, key
     1220            id_fields = [self.quote(self.column_name(clsname, key)
    12021221                         for key in cls.identifiers] 
    12031222            data, cols = self.fetch('SELECT %s FROM %s;' % 
    1204                                     (', '.join(id_fields), tablename)) 
     1223                                    (', '.join(id_fields), self.quote(tablename))) 
    12051224            if data: 
    12061225                # sqlite 2, for example, has empty cols tuple if no data. 
     
    12261245        for key in cls.properties: 
    12271246            val = self.toAdapter.coerce(getattr(unit, key)) 
    1228             fields.append(self.column_name(clsname, key)) 
     1247            fields.append(self.quote(self.column_name(clsname, key))) 
    12291248            values.append(val) 
    12301249         
     
    12321251        values = ", ".join(values) 
    12331252        self.execute('INSERT INTO %s (%s) VALUES (%s);' % 
    1234                      (str(tablename), fields, values)) 
     1253                     (self.quote(tablename), fields, values)) 
    12351254     
    12361255    def id_clause(self, unit): 
     
    12391258        col = self.column_name 
    12401259        c = self.toAdapter.coerce 
    1241         return " AND ".join(["%s = %s" % (col(clsname, key), 
     1260        return " AND ".join(["%s = %s" % (self.quote(col(clsname, key)), 
    12421261                                          c(getattr(unit, key))) 
    12431262                             for key in unit.identifiers]) 
     
    12541273                    val = self.toAdapter.coerce(getattr(unit, key)) 
    12551274                    parms.append('%s = %s' % 
    1256                                  (self.column_name(clsname, key), val)) 
     1275                                 (self.quote(self.column_name(clsname, key)), 
     1276                                  val)) 
    12571277             
    12581278            if parms: 
    12591279                sql = ('UPDATE %s SET %s WHERE %s;' % 
    1260                        (self.table_name(clsname), ", ".join(parms), 
     1280                       (self.quote(self.tables[clsname].name), 
     1281                        ", ".join(parms), 
    12611282                        self.id_clause(unit))) 
    12621283                self.execute(sql) 
     
    12701291            star = "" 
    12711292        self.execute('DELETE%s FROM %s WHERE %s;' % 
    1272                      (star, self.table_name(unit.__class__.__name__), 
     1293                     (star, self.quote(self.tables[unit.__class__.__name__].name), 
    12731294                      self.id_clause(unit))) 
    12741295     
     
    13671388            msg = ("No association found between %s and %s." % (name1, name2)) 
    13681389            raise dejavu.AssociationError(msg) 
    1369         near = '%s.%s' % (nearClass, self.column_name(nearClass, ua.nearKey)) 
    1370         far = '%s.%s' % (farClass, self.column_name(farClass, ua.farKey)) 
     1390         
     1391        near = '%s.%s' % (self.quote(nearClass), 
     1392                          self.quote(self.column_name(nearClass, ua.nearKey))) 
     1393        far = '%s.%s' % (self.quote(farClass), 
     1394                         self.quote(self.column_name(farClass, ua.farKey))) 
    13711395         
    13721396        return "(%s %s JOIN %s ON %s = %s)" % (name1, j, name2, near, far) 
     
    14621486     
    14631487    def create_database(self): 
    1464         self.execute("CREATE DATABASE %s;" % self.sql_name(self.dbname)) 
     1488        self.execute("CREATE DATABASE %s;" % self.quote(self.sql_name(self.dbname))) 
    14651489     
    14661490    def drop_database(self): 
    1467         self.execute("DROP DATABASE %s;" % self.sql_name(self.dbname)) 
     1491        self.execute("DROP DATABASE %s;" % self.quote(self.sql_name(self.dbname))) 
    14681492     
    14691493    def create_storage(self, cls): 
    14701494        """Create storage for the given class.""" 
    1471         clsname = cls.__name__ 
    1472         tablename = self.table_name(clsname) 
    1473         typename = self.typeAdapter.coerce 
    1474          
     1495        colname = self.column_name 
     1496         
     1497        # Make a Table object. 
     1498        tablename = self.table_name(cls.__name__) 
     1499        t = self.tableclass(self, tablename) 
     1500         
     1501        indices = cls.indices() 
    14751502        fields = [] 
    14761503        for key in cls.properties: 
    1477             fields.append('%s %s' % (self.column_name(clsname, key), 
    1478                                      typename(cls, key))) 
    1479         self.execute('CREATE TABLE %s (%s);' % (tablename, ", ".join(fields))) 
    1480          
    1481         for index in cls.indices(): 
    1482             i = self.table_name("i" + clsname + index) 
    1483             self.execute('CREATE INDEX %s ON %s (%s);' % 
    1484                          (i, tablename, self.column_name(clsname, index))) 
     1504            dbtype = self.typeAdapter.coerce(cls, key) 
     1505            prop = cls.property(key) 
     1506            cname = colname(cls.__name__, key) 
     1507            col = self.columnclass(cname, dbtype, prop.type, 
     1508                                   prop.default, prop.hints.copy()) 
     1509            # Use the superclass call to avoid ALTER TABLE. 
     1510            dict.__setitem__(t.columns, key, col) 
     1511             
     1512            if key in indices: 
     1513                iname = self.table_name("i" + cls.__name__ + key) 
     1514                i = self.indexclass(iname, tablename, cname) 
     1515                # Use the superclass call to avoid CREATE INDEX. 
     1516                dict.__setitem__(t.columns.indices, key, i) 
     1517         
     1518        # Attach to self.tables, which should call CREATE TABLE. 
     1519        self.tables[cls.__name__] = t 
    14851520     
    14861521    def has_storage(self, cls): 
    1487         try: 
    1488             # Must use fetch here instead of execute, because e.g. MySQL 
    1489             # must call store_result if the query has a result set 
    1490             # (or it will crash on a subsequent execute). 
    1491             self.fetch("SELECT * FROM %s;" % self.table_name(cls.__name__)) 
    1492         except: 
    1493             return False 
    1494         return True 
     1522        return cls.__name__ in self.tables 
    14951523     
    14961524    def drop_storage(self, cls): 
    1497         self.execute('DROP TABLE %s;' % self.table_name(cls.__name__)) 
     1525        del self.tables[cls.__name__] 
     1526     
     1527    def rename_storage(self, oldname, newname): 
     1528        self.arena.log("rename table %s to %s" % (oldname, newname), 
     1529                       dejavu.LOGSQL) 
     1530        self.tables.rename(oldname, newname) 
    14981531     
    14991532    def add_property(self, cls, name): 
    15001533        if not self.has_property(cls, name): 
    1501             clsname = cls.__name__ 
    1502             self.execute("ALTER TABLE %s ADD COLUMN %s %s;" % 
    1503                          (self.table_name(clsname), 
    1504                           self.column_name(clsname, name)
    1505                           self.typeAdapter.coerce(cls, name), 
    1506                           )) 
     1534            cname = self.column_name(cls.__name__, name) 
     1535            dbtype = self.typeAdapter.coerce(cls, name) 
     1536            prop = getattr(cls, name) 
     1537            c = self.columnclass(cname, dbtype, prop.type
     1538                                 prop.default, prop.hints.copy()) 
     1539            self.tables[cls.__name__].columns[name] = c 
    15071540     
    15081541    def has_property(self, cls, name): 
    1509         clsname = cls.__name__ 
    1510         try: 
    1511             # Must use fetch here instead of execute, because e.g. MySQL 
    1512             # must call store_result if the query has a result set 
    1513             # (or it will crash on a subsequent execute). 
    1514             self.fetch("SELECT %s FROM %s;" % 
    1515                        (self.column_name(clsname, name), 
    1516                         self.table_name(clsname))) 
    1517         except: 
    1518             return False 
    1519         return True 
     1542        return name in self.tables[cls.__name__].columns 
    15201543     
    15211544    def drop_property(self, cls, name): 
    15221545        if self.has_property(cls, name): 
    1523             clsname = cls.__name__ 
    1524             if self.has_index(cls, name): 
    1525                 self.drop_index(cls, name) 
    1526             self.execute("ALTER TABLE %s DROP COLUMN %s;" % 
    1527                          (self.table_name(clsname), 
    1528                           self.column_name(clsname, name))) 
     1546            del self.tables[cls.__name__].columns[name] 
    15291547     
    15301548    def rename_property(self, cls, oldname, newname): 
    1531         clsname = cls.__name__ 
    1532         oldname = self.column_name(clsname, oldname) 
    1533         newname = self.column_name(clsname, newname) 
    1534         if oldname != newname: 
    1535             self.execute("ALTER TABLE %s RENAME COLUMN %s TO %s;" % 
    1536                          (self.table_name(clsname), oldname, newname)) 
     1549        self.tables[cls.__name__].columns.rename(oldname, newname) 
    15371550     
    15381551    def has_index(self, cls, name): 
    1539         tablename = self.table_name(cls.__name__, quoted=False) 
    1540         indices = [i.colname for i in self.get_indices(tablename)] 
    1541         return (name in indices) 
     1552        return name in self.tables[cls.__name__].columns.indices 
    15421553     
    15431554    def drop_index(self, cls, name): 
    1544         clsname = cls.__name__ 
    1545         self.execute('DROP INDEX %s ON %s;' % 
    1546                      (self.sql_name("i" + clsname + name), 
    1547                       self.table_name(clsname))) 
    1548  
    1549  
    1550 class Table: 
    1551     """A table in a database.""" 
    1552      
    1553     def __init__(self, name): 
    1554         self.name = name 
    1555         self.columns = [] 
    1556      
    1557     def __repr__(self): 
    1558         return "dejavu.db.Table(%s)" % repr(self.name) 
    1559  
    1560  
    1561 class Column: 
    1562     """A column in a table in a database.""" 
    1563      
    1564     def __init__(self, key, type, default=None): 
    1565         self.key = key 
    1566         self.type = type 
    1567         self.default = default 
    1568         self.hints = {} 
    1569      
    1570     def __repr__(self): 
    1571         return ("dejavu.db.Column(%s, %s, default=%s, hints=%s)" % 
    1572                 (repr(self.key), repr(self.type), 
    1573                  repr(self.default), repr(self.hints)) 
    1574                 ) 
    1575  
    1576  
    1577 class Index: 
    1578     """An index on a table column (or columns) in a database.""" 
    1579      
    1580     def __init__(self, name, tablename, colname, pk=True, unique=True): 
    1581         self.name = name 
    1582         self.tablename = tablename 
    1583         self.colname = colname 
    1584         self.pk = pk 
    1585         self.unique = unique 
    1586      
    1587     def __repr__(self): 
    1588         return ("dejavu.db.Index(%s, %s, %s, pk=%s, unique=%s)" % 
    1589                 (repr(self.name), repr(self.tablename), repr(self.colname), 
    1590                  repr(self.pk), repr(self.unique))) 
    1591  
     1555        del self.tables[cls.__name__].columns.indices[name] 
     1556     
     1557    def sync(self, conn=None): 
     1558        """Populate self using all registered classes.""" 
     1559        # Use the superclass call to avoid DROP TABLE. 
     1560        dict.clear(self.tables) 
     1561         
     1562        dbtables = self.tables._get_tables(conn) 
     1563        for cls in self.arena._registered_classes: 
     1564            # Try to find a matching Table object from _get_tables. 
     1565            db_tname = self.prefix + self.table_name(cls.__name__) 
     1566            t = [x for x in dbtables if x.name == db_tname] 
     1567            if t: 
     1568                t = t[0] 
     1569                for c in self._get_columns(t.name): 
     1570                    # Use the superclass call to avoid ALTER TABLE 
     1571                    dict.__setitem__(t.columns, c.name, c) 
     1572                # Use the superclass call to avoid CREATE TABLE 
     1573                dict.__setitem__(self, db_tname, t) 
     1574     
     1575    def autoclass(self, table, newclassname=None): 
     1576        """Create a Unit class automatically from this table and its columns.""" 
     1577        class AutoUnitClass(dejavu.Unit): 
     1578            pass 
     1579        for cname, c in table.columns.iteritems(): 
     1580            AutoUnitClass.set_property(cname, c.type) 
     1581         
     1582        if newclassname is None: 
     1583            newclassname = table.name 
     1584        AutoUnitClass.__name__ = newclassname 
     1585         
     1586        return AutoUnitClass 
     1587 
  • trunk/storage/storeado.py

    r225 r226  
    5454zeroHour = datetime.date(1899, 12, 30).toordinal() 
    5555 
    56 # DataTypeEnum 
    57 adEmpty = 0 
    58 adSmallInt = 2 
    59 adInteger = 3 
    60 adSingle = 4 
    61 adDouble = 5 
    62 adCurrency = 6 
    63 adDate = 7 
    64 adBSTR = 8 
    65 adIDispatch = 9 
    66 adError = 10 
    67 adBoolean = 11 
    68 adVariant = 12 
    69 adIUnknown = 13 
    70 adDecimal = 14 
    71 adTinyInt = 16 
    72 adUnsignedTinyInt = 17 
    73 adUnsignedSmallInt = 18 
    74 adUnsignedInt = 19 
    75 adBigInt = 20 
    76 adUnsignedBigInt = 21 
    77 adGUID = 72 # e.g. {E5D50A9B-33D2-11D3-AAB3-00104BA31425} 
    78 adBinary = 128 
    79 adChar = 129 
    80 adWChar = 130 
    81 adNumeric = 131 
    82 adUserDefined = 132 
    83 adDBDate = 133 
    84 adDBTime = 134 
    85 adDBTimeStamp = 135 
    86 adVarChar = 200 
    87 adLongVarChar = 201 
    88 adVarWChar = 202 
    89 adLongVarWChar = 203 
    90 adVarBinary = 204 
    91 adLongVarBinary = 205 
    92  
     56dbtypes = { 
     57    0: 'EMPTY',                     2: 'SMALLINT', 
     58    3: 'INTEGER',                   4: 'SINGLE', 
     59    5: 'DOUBLE',                    6: 'CURRENCY', 
     60    7: 'DATE',                      8: 'BSTR', 
     61    9: 'IDISPATCH',                 10: 'ERROR', 
     62    11: 'BOOLEAN',                  12: 'VARIANT', 
     63    13: 'IUNKNOWN',                 14: 'DECIMAL', 
     64    16: 'TINYINT',                  17: 'UNSIGNEDTINYINT', 
     65    18: 'UNSIGNEDSMALLINT',         19: 'UNSIGNEDINT', 
     66    20: 'BIGINT',                   21: 'UNSIGNEDBIGINT', 
     67    72: 'GUID',                     128: 'BINARY', 
     68    129: 'CHAR',                    130: 'WCHAR', 
     69    131: 'NUMERIC',                 132: 'USERDEFINED', 
     70    133: 'DBDATE',                  134: 'DBTIME', 
     71    135: 'DBTIMESTAMP',             200: 'VARCHAR', 
     72    201: 'LONGVARCHAR',             202: 'VARWCHAR', 
     73    203: 'LONGVARWCHAR',            204: 'VARBINARY', 
     74    205: 'LONGVARBINARY' 
     75
    9376 
    9477def time_from_com(com_date): 
     
    175158     
    176159    def coerce_float(self, value, coltype): 
    177         if coltype == adCurrency and isinstance(value, tuple): 
     160        if coltype == 0x06 and isinstance(value, tuple): 
    178161            # See http://groups.google.com/group/comp.lang.python/ 
    179162            #           browse_frm/thread/fed03c64735c9e9c 
     
    353336 
    354337 
     338class ADOColumnSet(db.ColumnSet): 
     339     
     340    def _rename(self, oldcol, newname): 
     341        conn = self.table.sm.connection() 
     342        try: 
     343            cat = win32com.client.Dispatch(r'ADOX.Catalog') 
     344            cat.ActiveConnection = conn 
     345            cat.Tables(self.table.name).Columns(oldcol.name).Name = newname 
     346        finally: 
     347            conn = None 
     348            cat = None 
     349 
     350 
     351class ADOTableSet(db.TableSet): 
     352     
     353    def _get_tables(self, conn=None): 
     354        # cols will be 
     355        # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 
     356        # (u'TABLE_TYPE', 202), (u'TABLE_GUID', 72), (u'DESCRIPTION', 203), 
     357        # (u'TABLE_PROPID', 19), (u'DATE_CREATED', 7), (u'DATE_MODIFIED', 7)] 
     358        data, cols = self.sm.fetch(adSchemaTables, conn=conn, schema=True) 
     359        return [self.sm.tableclass(self.sm, row[2]) for row in data] 
     360     
     361    def _get_columns(self, tablename, conn=None): 
     362        # columns will be 
     363        # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 
     364        # (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72), (u'COLUMN_PROPID', 19), 
     365        # (u'ORDINAL_POSITION', 19), (u'COLUMN_HASDEFAULT', 11), 
     366        # (u'COLUMN_DEFAULT', 203), (u'COLUMN_FLAGS', 19), (u'IS_NULLABLE', 11), 
     367        # (u'DATA_TYPE', 18), (u'TYPE_GUID', 72), (u'CHARACTER_MAXIMUM_LENGTH', 19), 
     368        # (u'CHARACTER_OCTET_LENGTH', 19), (u'NUMERIC_PRECISION', 18), 
     369        # (u'NUMERIC_SCALE', 2), (u'DATETIME_PRECISION', 19), 
     370        # (u'CHARACTER_SET_CATALOG', 202), (u'CHARACTER_SET_SCHEMA', 202), 
     371        # (u'CHARACTER_SET_NAME', 202), (u'COLLATION_CATALOG', 202), 
     372        # (u'COLLATION_SCHEMA', 202), (u'COLLATION_NAME', 202), 
     373        # (u'DOMAIN_CATALOG', 202), (u'DOMAIN_SCHEMA', 202), 
     374        # (u'DOMAIN_NAME', 202), (u'DESCRIPTION', 203)] 
     375        data, _ = self.sm.fetch(adSchemaColumns, conn=conn, schema=True) 
     376        cols = [] 
     377        for row in data: 
     378            # I tried passing criteria to OpenSchema, but passing None is 
     379            # not the same as passing pythoncom.Empty (which errors). 
     380            if tablename and row[2] != tablename: 
     381                continue 
     382             
     383            dbtype = dbtypes[row[11]] 
     384            c = self.sm.columnclass(row[3], dbtype, None, row[8]) 
     385            if dbtype in ("DATE", "DBDATE"): 
     386                c.type = datetime.date 
     387            elif dbtype == "DBTIME": 
     388                c.type = datetime.time 
     389            elif dbtype == "DBTIMESTAMP": 
     390                c.type = datetime.datetime 
     391            elif dbtype in ("SMALLINT", "INTEGER", "TINYINT", 
     392                            "UNSIGNEDTINYINT", "UNSIGNEDSMALLINT", 
     393                            "UNSIGNEDINT"): 
     394                c.type = int 
     395                c.hints['bytes'] = row[15] 
     396            elif dbtype == "BOOLEAN": 
     397                c.type = bool 
     398            elif dbtype in ("BIGINT", "UNSIGNEDBIGINT"): 
     399                c.type = long 
     400                c.hints['bytes'] = row[15] 
     401            elif dbtype in ("SINGLE", "DOUBLE"): 
     402                c.type = float 
     403                c.hints['precision'] = row[15] 
     404                c.hints['scale'] = row[16] 
     405            elif dbtype in ("DECIMAL", "NUMERIC", "CURRENCY"): 
     406                c.type = decimal.Decimal 
     407                c.hints['precision'] = row[15] 
     408                c.hints['scale'] = row[16] 
     409            elif dbtype in ("BSTR", "VARIANT", "BINARY", "CHAR", 
     410                            "VARCHAR", "LONGVARCHAR", 
     411                            "VARBINARY", "LONGVARBINARY"): 
     412                c.type = str 
     413                if row[13]: 
     414                    c.hints['bytes'] = row[13] 
     415            elif dbtype in ("WCHAR", "VARWCHAR", "LONGVARWCHAR"): 
     416                c.type = unicode 
     417                if row[13]: 
     418                    c.hints['bytes'] = row[13] 
     419            cols.append(c) 
     420        return cols 
     421     
     422    def _get_indices(self, tablename=None, conn=None): 
     423        # cols will be 
     424        # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 
     425        # (u'INDEX_CATALOG', 202), (u'INDEX_SCHEMA', 202), (u'INDEX_NAME', 202), 
     426        # (u'PRIMARY_KEY', 11), (u'UNIQUE', 11), (u'CLUSTERED', 11), (u'TYPE', 18), 
     427        # (u'FILL_FACTOR', 3), (u'INITIAL_SIZE', 3), (u'NULLS', 3), 
     428        # (u'SORT_BOOKMARKS', 11), (u'AUTO_UPDATE', 11), (u'NULL_COLLATION', 3), 
     429        # (u'ORDINAL_POSITION', 19), (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72), 
     430        # (u'COLUMN_PROPID', 19), (u'COLLATION', 2), (u'CARDINALITY', 21), 
     431        # (u'PAGES', 3), (u'FILTER_CONDITION', 202), (u'INTEGRATED', 11)] 
     432        data, _ = self.sm.fetch(adSchemaIndexes, conn=conn, schema=True) 
     433        indices = [] 
     434        for row in data: 
     435            # I tried passing criteria to OpenSchema, but passing None is 
     436            # not the same as passing pythoncom.Empty (which errors). 
     437            if tablename and row[2] != tablename: 
     438                continue 
     439            i = self.sm.indexclass(row[5], row[2], row[17], row[6], row[7]) 
     440            indices.append(i) 
     441        return indices 
     442     
     443    def _rename(self, oldtable, newname): 
     444        conn = self.sm.connection() 
     445        try: 
     446            cat = win32com.client.Dispatch(r'ADOX.Catalog') 
     447            cat.ActiveConnection = conn 
     448            cat.Tables(oldtable.name).Name = newname 
     449        finally: 
     450            conn = None 
     451            cat = None 
     452 
     453 
    355454class StorageManagerADO(db.StorageManagerDB): 
    356455    """StoreManager to save and retrieve Units via ADO 2.7. 
     
    361460    decompiler = ADOSQLDecompiler 
    362461    fromAdapter = AdapterFromADO() 
     462    tablesetclass = ADOTableSet 
     463    columnsetclass = ADOColumnSet 
    363464     
    364465    def connatoms(self): 
     
    370471        return atoms 
    371472     
    372     def sql_name(self, name, quoted=True): 
    373         if quoted: 
    374             name = '[' + name + ']' 
    375         return name 
     473    def quote(self, name): 
     474        """Return name, quoted for use in an SQL statement.""" 
     475        return '[' + name + ']' 
    376476     
    377477    def _get_conn(self): 
     
    443543        adoconn = win32com.client.Dispatch(r'ADODB.Connection') 
    444544        return "ADO Version: %s" % adoconn.Version 
    445      
    446     #                               Schemas                               # 
    447      
    448     def has_storage(self, cls): 
    449         names = [t.name for t in self.get_tables()] 
    450         return self.table_name(cls.__name__, quoted=False) in names 
    451      
    452     def rename_storage(self, oldname, newname): 
    453         oldname = self.table_name(oldname, quoted=False) 
    454         newname = self.table_name(newname, quoted=False) 
    455         self.arena.log("rename table %s to %s" % (oldname, newname), 
    456                        dejavu.LOGSQL) 
    457          
    458         conn = self.connection() 
    459         try: 
    460             cat = win32com.client.Dispatch(r'ADOX.Catalog') 
    461             cat.ActiveConnection = conn 
    462             cat.Tables(oldname).Name = newname 
    463         finally: 
    464             conn = None 
    465             cat = None 
    466      
    467     def rename_property(self, cls, oldname, newname): 
    468         clsname = cls.__name__ 
    469         tblname = self.table_name(clsname, quoted=False) 
    470         oldname = self.column_name(clsname, oldname, quoted=False) 
    471         newname = self.column_name(clsname, newname, quoted=False) 
    472         self.arena.log("rename %s column %s to %s" % 
    473                        (tblname, oldname, newname), 
    474                        dejavu.LOGSQL) 
    475           
    476         conn = self.connection() 
    477         try: 
    478             cat = win32com.client.Dispatch(r'ADOX.Catalog') 
    479             cat.ActiveConnection = conn 
    480             cat.Tables(tblname).Columns(oldname).Name = newname 
    481         finally: 
    482             conn = None 
    483             cat = None 
    484      
    485     def drop_index(self, cls, name): 
    486         clsname = cls.__name__ 
    487         tablename = self.table_name(clsname, quoted=False) 
    488         qtablename = self.table_name(clsname) 
    489         colname = self.column_name(clsname, name, quoted=False) 
    490          
    491         for i in self.get_indices(tablename): 
    492             if i.colname == colname: 
    493                 # The INDEX_NAME may include a trailing " ASC" or other data 
    494                 self.execute('DROP INDEX [%s] ON %s;' % (i.name, qtablename)) 
    495      
    496     def get_tables(self, conn=None): 
    497         # cols will be 
    498         # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 
    499         # (u'TABLE_TYPE', 202), (u'TABLE_GUID', 72), (u'DESCRIPTION', 203), 
    500         # (u'TABLE_PROPID', 19), (u'DATE_CREATED', 7), (u'DATE_MODIFIED', 7)] 
    501         data, cols = self.fetch(adSchemaTables, conn=conn, schema=True) 
    502         return [db.Table(row[2]) for row in data] 
    503      
    504     def get_columns(self, tablename, conn=None): 
    505         # cols will be 
    506         # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 
    507         # (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72), (u'COLUMN_PROPID', 19), 
    508         # (u'ORDINAL_POSITION', 19), (u'COLUMN_HASDEFAULT', 11), 
    509         # (u'COLUMN_DEFAULT', 203), (u'COLUMN_FLAGS', 19), (u'IS_NULLABLE', 11), 
    510         # (u'DATA_TYPE', 18), (u'TYPE_GUID', 72), (u'CHARACTER_MAXIMUM_LENGTH', 19), 
    511         # (u'CHARACTER_OCTET_LENGTH', 19), (u'NUMERIC_PRECISION', 18), 
    512         # (u'NUMERIC_SCALE', 2), (u'DATETIME_PRECISION', 19), 
    513         # (u'CHARACTER_SET_CATALOG', 202), (u'CHARACTER_SET_SCHEMA', 202), 
    514         # (u'CHARACTER_SET_NAME', 202), (u'COLLATION_CATALOG', 202), 
    515         # (u'COLLATION_SCHEMA', 202), (u'COLLATION_NAME', 202), 
    516         # (u'DOMAIN_CATALOG', 202), (u'DOMAIN_SCHEMA', 202), 
    517         # (u'DOMAIN_NAME', 202), (u'DESCRIPTION', 203)] 
    518         data, _ = self.fetch(adSchemaColumns, conn=conn, schema=True) 
    519         cols = [] 
    520         for row in data: 
    521             # I tried passing criteria to OpenSchema, but passing None is 
    522             # not the same as passing pythoncom.Empty (which errors). 
    523             if tablename and row[2] != tablename: 
    524                 continue 
    525             datatype = row[11] 
    526             c = db.Column(row[3], None, row[8]) 
    527             if datatype in (adDate, adDBDate): 
    528                 c.type = datetime.date 
    529             elif datatype == adDBTime: 
    530                 c.type = datetime.time 
    531             elif datatype == adDBTimeStamp: 
    532                 c.type = datetime.datetime 
    533             elif datatype in (adSmallInt, adInteger, adTinyInt, 
    534                               adUnsignedTinyInt, adUnsignedSmallInt, 
    535                               adUnsignedInt): 
    536                 c.type = int 
    537                 c.hints['bytes'] = row[15] 
    538             elif datatype == adBoolean: 
    539                 c.type = bool 
    540             elif datatype in (adBigInt, adUnsignedBigInt): 
    541                 c.type = long 
    542                 c.hints['bytes'] = row[15] 
    543             elif datatype in (adSingle, adDouble, adCurrency): 
    544                 c.type = float 
    545                 c.hints['bytes'] = row[15] 
    546             elif datatype in (adDecimal, adNumeric): 
    547                 c.type = decimal.Decimal 
    548                 c.hints['bytes'] = row[15] 
    549             elif datatype in (adBSTR, adVariant, adBinary, adChar, 
    550                               adVarChar, adLongVarChar, 
    551                               adVarBinary, adLongVarBinary): 
    552                 c.type = str 
    553                 if row[13]: 
    554                     c.hints['bytes'] = row[13] 
    555             elif datatype in (adWChar, adVarWChar, adLongVarWChar): 
    556                 c.type = unicode 
    557                 if row[13]: 
    558                     c.hints['bytes'] = row[13] 
    559             cols.append(c) 
    560         return cols 
    561      
    562     def get_indices(self, tablename=None, conn=None): 
    563         # cols will be 
    564         # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 
    565         # (u'INDEX_CATALOG', 202), (u'INDEX_SCHEMA', 202), (u'INDEX_NAME', 202), 
    566         # (u'PRIMARY_KEY', 11), (u'UNIQUE', 11), (u'CLUSTERED', 11), (u'TYPE', 18), 
    567         # (u'FILL_FACTOR', 3), (u'INITIAL_SIZE', 3), (u'NULLS', 3), 
    568         # (u'SORT_BOOKMARKS', 11), (u'AUTO_UPDATE', 11), (u'NULL_COLLATION', 3), 
    569         # (u'ORDINAL_POSITION', 19), (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72), 
    570         # (u'COLUMN_PROPID', 19), (u'COLLATION', 2), (u'CARDINALITY', 21), 
    571         # (u'PAGES', 3), (u'FILTER_CONDITION', 202), (u'INTEGRATED', 11)] 
    572         data, _ = self.fetch(adSchemaIndexes, conn=conn, schema=True) 
    573         indices = [] 
    574         for row in data: 
    575             # I tried passing criteria to OpenSchema, but passing None is 
    576             # not the same as passing pythoncom.Empty (which errors). 
    577             if tablename and row[2] != tablename: 
    578                 continue 
    579             indices.append(db.Index(row[5], row[2], row[17], row[6], row[7])) 
    580         return indices 
     545 
    581546 
    582547 
     
    670635 
    671636 
     637class SQLServerColumnSet(ADOColumnSet): 
     638     
     639    def __setitem__(self, key, column): 
     640        t = self.table 
     641        # SQL Server doesn't use the "COLUMN" keyword with "ADD" 
     642        t.sm.execute("ALTER TABLE %s ADD %s %s;" % 
     643                     (t.sm.quote(t.name), t.sm.quote(column.name), 
     644                      column.dbtype)) 
     645        dict.__setitem__(self, key, column) 
     646     
     647    def _rename(self, oldcol, newname): 
     648        t = self.table 
     649        t.sm.execute("EXEC sp_rename '%s.%s', '%s', 'COLUMN'" % 
     650                     (t.name, oldcol.name, newname)) 
     651 
     652 
    672653class StorageManagerADO_SQLServer(StorageManagerADO): 
    673654     
    674655    typeAdapter = FieldTypeAdapter_SQLServer() 
    675656    toAdapter = AdapterToADOSQL_SQLServer() 
     657    columnsetclass = SQLServerColumnSet 
    676658     
    677659    def __init__(self, name, arena, allOptions={}): 
     
    685667        cls = unit.__class__ 
    686668        clsname = cls.__name__ 
    687         tablename = self.table_name(clsname) 
     669        tablename = self.tables[clsname].name 
    688670         
    689671        fields = [] 
     
    695677                continue 
    696678            val = self.toAdapter.coerce(getattr(unit, key)) 
    697             fields.append(self.column_name(clsname, key)) 
     679            fields.append(self.quote(self.column_name(clsname, key))) 
    698680            values.append(val) 
    699681         
     
    701683        values = ", ".join(values) 
    702684        self.execute('INSERT INTO %s (%s) VALUES (%s);' % 
    703                      (str(tablename), fields, values)) 
     685                     (self.quote(tablename), fields, values)) 
    704686         
    705687        # Grab the new ID. This is threadsafe because db.reserve has a mutex. 
     
    707689        # None) when retrieving ID's just after a 99-thread-test ran. Moving 
    708690        # the multithreading test fixed it. IDENT_CURRENT worked regardless. 
    709         data, col_defs = self.fetch("SELECT IDENT_CURRENT('%s');" 
    710                                     % str(tablename)) 
     691        data, col_defs = self.fetch("SELECT IDENT_CURRENT('%s');" % tablename) 
    711692        setattr(unit, cls.identifiers[0], data[0][0]) 
    712693     
     
    714695     
    715696    def create_database(self): 
    716         # This method hasn't been tested yet for SQL server
     697        # This method hasn't been tested yet for SQL server (only MSDE)
    717698        adoconn = win32com.client.Dispatch(r'ADODB.Connection') 
    718699        atoms = self.connatoms() 
    719700        atoms['INITIAL CATALOG'] = "tempdb" 
    720701        adoconn.Open("; ".join(["%s=%s" % (k, v) for k, v in atoms.iteritems()])) 
    721         adoconn.Execute("CREATE DATABASE %s" % self.sql_name(self.dbname)) 
     702        adoconn.Execute("CREATE DATABASE %s" % self.quote(self.sql_name(self.dbname))) 
    722703        adoconn.Close() 
    723704     
     
    727708        atoms['INITIAL CATALOG'] = "tempdb" 
    728709        adoconn.Open("; ".join(["%s=%s" % (k, v) for k, v in atoms.iteritems()])) 
    729         adoconn.Execute("DROP DATABASE %s;" % self.sql_name(self.dbname)) 
     710        adoconn.Execute("DROP DATABASE %s;" % self.quote(self.sql_name(self.dbname))) 
    730711        adoconn.Close() 
    731      
    732     def add_property(self, cls, name): 
    733         clsname = cls.__name__ 
    734         # SQL Server doesn't use the "COLUMN" keyword with "ADD" 
    735         self.execute("ALTER TABLE %s ADD %s %s;" % 
    736                      (self.table_name(clsname), 
    737                       self.column_name(clsname, name), 
    738                       self.typeAdapter.coerce(cls, name), 
    739                       )) 
    740      
    741     def rename_property(self, cls, oldname, newname): 
    742         clsname = cls.__name__ 
    743         oldname = self.column_name(clsname, oldname, quoted=False) 
    744         newname = self.column_name(clsname, newname, quoted=False) 
    745         if oldname != newname: 
    746             self.execute("EXEC sp_rename '%s.%s', '%s', 'COLUMN'" % 
    747                          (self.table_name(clsname), oldname, newname)) 
     712 
    748713 
    749714 
     
    866831        cls = unit.__class__ 
    867832        clsname = cls.__name__ 
    868         tablename = self.table_name(clsname) 
     833        tablename = self.tables[clsname].name 
    869834         
    870835        fields = [] 
     
    876841                continue 
    877842            val = self.toAdapter.coerce(getattr(unit, key)) 
    878             fields.append(self.column_name(clsname, key)) 
     843            fields.append(self.quote(self.column_name(clsname, key))) 
    879844            values.append(val) 
    880845         
     
    882847        values = ", ".join(values) 
    883848        self.execute('INSERT INTO %s (%s) VALUES (%s);' % 
    884                      (str(tablename), fields, values)) 
     849                     (self.quote(tablename), fields, values)) 
    885850         
    886851        # Grab the new ID. This is threadsafe because db.reserve has a mutex. 
  • trunk/storage/storemysql.py

    r225 r226  
    149149 
    150150 
    151 class StorageManagerMySQL(db.StorageManagerDB): 
    152     """StoreManager to save and retrieve Units via _mysql.""" 
    153      
    154     sql_name_max_length = 64 
    155     # MySQL uses case-sensitive database and table names on Unix, but 
    156     # not on Windows. Use all-lowercase identifiers to work around the 
    157     # problem. "Column names, index names, and column aliases are not 
    158     # case sensitive on any platform." 
    159     # If deployers set lower_case_table_names to 1, it would help. 
    160     sql_name_caseless = True 
    161      
    162     typeAdapter = FieldTypeAdapterMySQL() 
    163     toAdapter = AdapterToMySQL() 
    164     fromAdapter = AdapterFromMySQL() 
    165      
    166     def __init__(self, name, arena, allOptions={}): 
    167         connargs = ["host", "user", "passwd", "db", "port", "unix_socket", 
    168                     "conv", "connect_time", "compress", "named_pipe", 
    169                     "init_command", "read_default_file", "read_default_group", 
    170                     "cursorclass", "client_flag", 
    171                     ] 
    172         self.connargs = dict([(k, v) for k, v in allOptions.iteritems() 
    173                               if k in connargs]) 
    174         self.dbname = self.connargs['db'] 
    175          
    176         db.StorageManagerDB.__init__(self, name, arena, allOptions) 
    177          
    178         self.decompiler = MySQLDecompiler 
    179         # Get the version string from MySQL, to see if we need 
    180         # a different decompiler. 
    181         conn = self._template_conn() 
    182         rowdata, cols = self.fetch("SELECT version();", conn) 
    183         conn.close() 
    184         v = rowdata[0][0] 
    185         self._version = storage.Version(v) 
    186         if self._version > storage.Version("4.1.1"): 
    187             self.decompiler = MySQLDecompiler411 
    188      
    189     def sql_name(self, name, quoted=True): 
    190         name = db.StorageManagerDB.sql_name(self, name, quoted) 
    191         if quoted: 
    192             name = '`' + name.replace('`', '``') + '`' 
    193         return name 
    194      
    195     def _get_conn(self): 
    196         try: 
    197             conn = _mysql.connect(**self.connargs) 
    198         except _mysql.OperationalError, x: 
    199             if x.args[0] == 1040:   # Too many connections 
    200                 raise db.OutOfConnectionsError 
    201             raise 
    202         return conn 
    203      
    204     def _template_conn(self): 
    205         tmplconn = self.connargs.copy() 
    206         tmplconn['db'] = 'mysql' 
    207         return _mysql.connect(**tmplconn) 
    208      
    209     def fetch(self, query, conn=None): 
    210         """fetch(query, conn=None) -> rowdata, columns. 
    211          
    212         rowdata: a nested list (or tuples), column values within rows. 
    213         columns: a series of 2-tuples (or more). The first tuple value 
    214             will be the column name, the second value will be the column 
    215             type. 
    216         """ 
    217         if conn is None: 
    218             conn = self.connection() 
    219         self.execute(query, conn) 
    220         # store_result uses a client-side cursor 
    221         res = conn.store_result() 
    222         return res.fetch_row(0, 0), res.describe() 
    223      
    224     def destroy(self, unit): 
    225         """destroy(unit). Delete the unit.""" 
    226         self.execute('DELETE FROM %s WHERE %s;' % 
    227                      (self.table_name(unit.__class__.__name__), 
    228                       self.id_clause(unit))) 
    229      
    230     def version(self): 
    231         return "MySQL Version: %s" % self._version 
    232      
    233     def _seq_UnitSequencerInteger(self, unit): 
    234         """Reserve a unit using the table's AUTO_INCREMENT field.""" 
    235         cls = unit.__class__ 
    236         clsname = cls.__name__ 
    237         tablename = self.table_name(clsname) 
    238          
    239         fields = [] 
    240         values = [] 
    241         for key in cls.properties: 
    242             typename = self.typeAdapter.coerce(cls, key) 
    243             if typename.endswith("AUTO_INCREMENT"): 
    244                 # Skip this field, since we're using AUTO_INCREMENT 
    245                 continue 
    246             val = self.toAdapter.coerce(getattr(unit, key)) 
    247             fields.append(self.column_name(clsname, key)) 
    248             values.append(val) 
    249          
    250         fields = ", ".join(fields) 
    251         values = ", ".join(values) 
    252          
    253         conn = self.connection() 
    254         self.execute('INSERT INTO %s (%s) VALUES (%s);' % 
    255                      (str(tablename), fields, values), 
    256                      conn) 
    257          
    258         # Grab the new ID. This is threadsafe because db.reserve has a mutex. 
    259         setattr(unit, cls.identifiers[0], conn.insert_id()) 
    260      
    261     #                               Schemas                               # 
    262      
    263     def create_database(self): 
    264         # _mysql has create_db and drop_db commands, but they're deprecated. 
    265         sql = 'CREATE DATABASE %s;' % self.sql_name(self.dbname) 
    266         conn = self._template_conn() 
    267         self.execute(sql, conn) 
    268         conn.close() 
    269      
    270     def drop_database(self): 
    271         sql = 'DROP DATABASE %s;' % self.sql_name(self.dbname) 
    272         conn = self._template_conn() 
    273         self.execute(sql, conn) 
    274         conn.close() 
    275      
    276     def create_storage(self, cls): 
    277         clsname = cls.__name__ 
    278         tablename = self.table_name(clsname) 
    279         typename = self.typeAdapter.coerce 
     151class MySQLIndexSet(db.IndexSet): 
     152     
     153    def __delitem__(self, key): 
     154        t = self.table 
     155        # MySQL might rename multiple-column indices to "PRIMARY" 
     156        for i in t.sm.tables._get_indices(t.name): 
     157            if i.colname == self[key].colname: 
     158                t.sm.execute('DROP INDEX %s ON %s;' % 
     159                             (t.sm.quote(i.name), t.sm.quote(t.name))) 
     160 
     161 
     162class MySQLColumnSet(db.ColumnSet): 
     163     
     164    def _rename(self, oldcol, newname): 
     165        # Override this to do the actual rename at the DB level. 
     166        t = self.table 
     167        t.sm.execute("ALTER TABLE %s CHANGE %s %s %s;" % 
     168                     (t.sm.quote(t.name), t.sm.quote(oldcol.name), 
     169                      t.sm.quote(newname), oldcol.dbtype)) 
     170 
     171 
     172class MySQLTableSet(db.TableSet): 
     173     
     174    def __setitem__(self, key, table): 
     175        q = self.sm.quote 
    280176         
    281177        fields = [] 
    282178        pk = [] 
    283         for key in cls.properties
    284             qname = self.column_name(clsname, key
    285             dbtype = typename(cls, key) 
     179        for colname, col in table.columns.iteritems()
     180            qname = q(col.name
     181            dbtype = col.dbtype 
    286182            fields.append('%s %s' % (qname, dbtype)) 
    287             if key in cls.identifiers: 
     183            if colname in table.mysql_identifiers: 
    288184                if dbtype.endswith('BLOB') or dbtype == 'TEXT': 
    289185                    # MySQL won't allow indexes on a BLOB field 
    290186                    # without a specific length. 
    291                     qname = "%s(%s)" % (qname, 255) 
     187                    qname = "%s(255)" % qname 
    292188                pk.append(qname) 
     189         
    293190        pk = ", ".join(pk) 
    294191        if pk: 
    295192            pk = ", PRIMARY KEY (%s)" % pk 
    296         self.execute('CREATE TABLE %s (%s%s);' 
    297                      % (tablename, ", ".join(fields), pk)) 
    298          
    299         hasdummy = False 
    300         if isinstance(cls.sequencer, dejavu.UnitSequencerInteger): 
    301             i = cls.sequencer.initial 
    302             if i > 1: 
    303                 # Wow, what a hack. We have to create a dummy row 
    304                 # to set the autoincrement initial value, and we 
    305                 # can't delete it until after the CREATE INDEX 
    306                 # statements below (or the counter will revert). 
    307                 colname = self.column_name(clsname, cls.identifiers[0]) 
    308                 self.execute("INSERT INTO %s (%s) VALUES (%s);" 
    309                              % (tablename, colname, i - 1)) 
    310                 hasdummy = True 
    311          
    312         for index in cls.indices(): 
    313             i = self.table_name("i" + clsname + index) 
    314              
    315             dbtype = typename(cls, index) 
     193         
     194        self.sm.execute('CREATE TABLE %s (%s%s);' % 
     195                        (q(table.name), ", ".join(fields), pk)) 
     196         
     197        seq = getattr(table, "mysql_sequencer", None) 
     198        if seq: 
     199            # Wow, what a hack. We have to INSERT a dummy row 
     200            # to set the autoincrement initial value, and we 
     201            # can't delete it until after the CREATE INDEX 
     202            # statements (or the counter will revert). 
     203            colname, initial = seq 
     204            self.sm.execute("INSERT INTO %s (%s) VALUES (%s);" 
     205                            % (q(table.name), q(colname), initial - 1)) 
     206         
     207        for k, index in table.columns.indices.iteritems(): 
     208            dbtype = table.columns[k].dbtype 
    316209            if dbtype.endswith('BLOB') or dbtype == 'TEXT': 
    317210                # MySQL won't allow indexes on a BLOB field 
    318211                # without a specific length. 
    319                 self.execute('CREATE INDEX %s ON %s (%s(%s));' % 
    320                              (i, tablename
    321                               self.column_name(clsname, index), 255)) 
     212                self.sm.execute('CREATE INDEX %s ON %s (%s(255));' % 
     213                                (q(index.name), q(table.name)
     214                                 q(index.colname))) 
    322215            else: 
    323                 self.execute('CREATE INDEX %s ON %s (%s);' % 
    324                              (i, tablename, 
    325                               self.column_name(clsname, index))) 
    326          
    327         if hasdummy: 
    328             self.execute("DELETE FROM %s" % tablename) 
    329      
    330     def rename_property(self, cls, oldname, newname): 
    331         clsname = cls.__name__ 
    332         oldcolname = self.column_name(clsname, oldname) 
    333         newcolname = self.column_name(clsname, newname) 
    334         if oldcolname != newcolname: 
    335             self.execute("ALTER TABLE %s CHANGE %s %s %s;" % 
    336                          (self.table_name(clsname), oldcolname, newcolname, 
    337                           self.typeAdapter.coerce(cls, newname))) 
    338      
    339     def drop_index(self, cls, name): 
    340         # MySQL might rename multiple-column indices to "PRIMARY" 
    341         clsname = cls.__name__ 
    342         names = [] 
    343         for i in self.get_indices(self.table_name(clsname, quoted=False)): 
    344             if i.name not in names: 
    345                 names.append(i.name) 
    346         for n in names: 
    347             self.execute('DROP INDEX %s ON %s;' % 
    348                          (self.sql_name(n), self.table_name(clsname))) 
    349      
    350     def get_tables(self, conn=None): 
    351         data, _ = self.fetch("SHOW TABLES FROM %s" % self.dbname, 
    352                              conn=conn) 
    353         return [db.Table(row[0]) for row in data] 
    354      
    355     def get_columns(self, tablename=None, conn=None): 
     216                self.sm.execute('CREATE INDEX %s ON %s (%s);' % 
     217                                (q(index.name), q(table.name), 
     218                                 q(index.colname))) 
     219         
     220        if seq: 
     221            self.sm.execute("DELETE FROM %s" % q(table.name)) 
     222         
     223        dict.__setitem__(self, key, table) 
     224     
     225    def _get_tables(self, conn=None): 
     226        data, _ = self.sm.fetch("SHOW TABLES FROM %s" % 
     227                                self.sm.quote(self.sm.dbname), 
     228                                conn=conn) 
     229        return [self.sm.tableclass(self.sm, row[0]) for row in data] 
     230     
     231    def _get_columns(self, tablename, conn=None): 
    356232        # cols are: Field, Type, Null, Key, Default, Extra. 
    357233        # See http://dev.mysql.com/doc/refman/4.1/en/describe.html 
    358         data, _ = self.fetch("SHOW COLUMNS FROM %s.%s" 
    359                              % (self.dbname, self.sql_name(tablename)), 
    360                              conn=conn) 
     234        q = self.sm.quote 
     235        data, _ = self.sm.fetch("SHOW COLUMNS FROM %s.%s" 
     236                                % (q(self.sm.dbname), q(tablename)), 
     237                                conn=conn) 
    361238        cols = [] 
    362239        for row in data: 
    363             c = db.Column(row[0], None, row[4]) 
     240            c = self.sm.columnclass(row[0], None, None, row[4]) 
    364241             
    365242            dbtype = row[1] 
     
    368245                c.hints['bytes'] = dbtype[parenpos+1:-1] 
    369246                dbtype = dbtype[:parenpos] 
     247            c.dbtype = dbtype 
    370248             
    371249            if dbtype in ('tinyint', 'smallint', 'mediumint', 'int', 'integer'): 
     
    391269        return cols 
    392270     
    393     def get_indices(self, tablename, conn=None): 
     271    def _get_indices(self, tablename, conn=None): 
    394272        indices = [] 
    395273        try: 
    396274            # cols are: Table, Non_unique, Key_name, Seq_in_index, Column_name, 
    397275            # Collation, Cardinality, Sub_part, Packed, Null, Index_type, Comment 
     276            q = self.sm.quote 
    398277            data, _ = self.fetch("SHOW INDEX FROM %s.%s" 
    399                                  % (self.dbname, self.sql_name(tablename)), 
     278                                 % (q(self.sm.dbname), q(tablename)), 
    400279                                 conn=conn) 
    401280        except _mysql.ProgrammingError, x: 
     
    404283        else: 
    405284            for row in data: 
    406                 indices.append(db.Index(row[2], row[0], row[4], None, not row[1])) 
     285                i = self.sm.indexclass(row[2], row[0], row[4], None, not row[1]) 
     286                indices.append(i) 
    407287        return indices 
    408288 
     289 
     290 
     291class StorageManagerMySQL(db.StorageManagerDB): 
     292    """StoreManager to save and retrieve Units via _mysql.""" 
     293     
     294    sql_name_max_length = 64 
     295    # MySQL uses case-sensitive database and table names on Unix, but 
     296    # not on Windows. Use all-lowercase identifiers to work around the 
     297    # problem. "Column names, index names, and column aliases are not 
     298    # case sensitive on any platform." 
     299    # If deployers set lower_case_table_names to 1, it would help. 
     300    sql_name_caseless = True 
     301     
     302    typeAdapter = FieldTypeAdapterMySQL() 
     303    toAdapter = AdapterToMySQL() 
     304    fromAdapter = AdapterFromMySQL() 
     305     
     306    tablesetclass = MySQLTableSet 
     307    columnsetclass = MySQLColumnSet 
     308    indexsetclass = MySQLIndexSet 
     309     
     310    def __init__(self, name, arena, allOptions={}): 
     311        connargs = ["host", "user", "passwd", "db", "port", "unix_socket", 
     312                    "conv", "connect_time", "compress", "named_pipe", 
     313                    "init_command", "read_default_file", "read_default_group", 
     314                    "cursorclass", "client_flag", 
     315                    ] 
     316        self.connargs = dict([(k, v) for k, v in allOptions.iteritems() 
     317                              if k in connargs]) 
     318        self.dbname = self.connargs['db'] 
     319         
     320        db.StorageManagerDB.__init__(self, name, arena, allOptions) 
     321         
     322        self.decompiler = MySQLDecompiler 
     323        # Get the version string from MySQL, to see if we need 
     324        # a different decompiler. 
     325        conn = self._template_conn() 
     326        rowdata, cols = self.fetch("SELECT version();", conn) 
     327        conn.close() 
     328        v = rowdata[0][0] 
     329        self._version = storage.Version(v) 
     330        if self._version > storage.Version("4.1.1"): 
     331            self.decompiler = MySQLDecompiler411 
     332     
     333    def quote(self, name): 
     334        """Return name, quoted for use in an SQL statement.""" 
     335        return '`' + name.replace('`', '``') + '`' 
     336     
     337    def _get_conn(self): 
     338        try: 
     339            conn = _mysql.connect(**self.connargs) 
     340        except _mysql.OperationalError, x: 
     341            if x.args[0] == 1040:   # Too many connections 
     342                raise db.OutOfConnectionsError 
     343            raise 
     344        return conn 
     345     
     346    def _template_conn(self): 
     347        tmplconn = self.connargs.copy() 
     348        tmplconn['db'] = 'mysql' 
     349        return _mysql.connect(**tmplconn) 
     350     
     351    def fetch(self, query, conn=None): 
     352        """fetch(query, conn=None) -> rowdata, columns. 
     353         
     354        rowdata: a nested list (or tuples), column values within rows. 
     355        columns: a series of 2-tuples (or more). The first tuple value 
     356            will be the column name, the second value will be the column 
     357            type. 
     358        """ 
     359        if conn is None: 
     360            conn = self.connection() 
     361        self.execute(query, conn) 
     362        # store_result uses a client-side cursor 
     363        res = conn.store_result() 
     364        return res.fetch_row(0, 0), res.describe() 
     365     
     366    def destroy(self, unit): 
     367        """destroy(unit). Delete the unit.""" 
     368        self.execute('DELETE FROM %s WHERE %s;' % 
     369                     (self.quote(self.table_name(unit.__class__.__name__)), 
     370                      self.id_clause(unit))) 
     371     
     372    def version(self): 
     373        return "MySQL Version: %s" % self._version 
     374     
     375    def _seq_UnitSequencerInteger(self, unit): 
     376        """Reserve a unit using the table's AUTO_INCREMENT field.""" 
     377        cls = unit.__class__ 
     378        clsname = cls.__name__ 
     379        tablename = self.table_name(clsname) 
     380         
     381        fields = [] 
     382        values = [] 
     383        for key in cls.properties: 
     384            typename = self.typeAdapter.coerce(cls, key) 
     385            if typename.endswith("AUTO_INCREMENT"): 
     386                # Skip this field, since we're using AUTO_INCREMENT 
     387                continue 
     388            val = self.toAdapter.coerce(getattr(unit, key)) 
     389            fields.append(self.quote(self.column_name(clsname, key))) 
     390            values.append(val) 
     391         
     392        fields = ", ".join(fields) 
     393        values = ", ".join(values) 
     394         
     395        conn = self.connection() 
     396        self.execute('INSERT INTO %s (%s) VALUES (%s);' % 
     397                     (self.quote(tablename), fields, values), 
     398                     conn) 
     399         
     400        # Grab the new ID. This is threadsafe because db.reserve has a mutex. 
     401        setattr(unit, cls.identifiers[0], conn.insert_id()) 
     402     
     403    #                               Schemas                               # 
     404     
     405    def create_database(self): 
     406        # _mysql has create_db and drop_db commands, but they're deprecated. 
     407        sql = 'CREATE DATABASE %s;' % self.quote(self.sql_name(self.dbname)) 
     408        conn = self._template_conn() 
     409        self.execute(sql, conn) 
     410        conn.close() 
     411     
     412    def drop_database(self): 
     413        sql = 'DROP DATABASE %s;' % self.quote(self.sql_name(self.dbname)) 
     414        conn = self._template_conn() 
     415        self.execute(sql, conn) 
     416        conn.close() 
     417     
     418    def create_storage(self, cls): 
     419        """Create storage for the given class.""" 
     420        colname = self.column_name 
     421         
     422        # Make a Table object. 
     423        tablename = self.table_name(cls.__name__) 
     424        t = self.tableclass(self, tablename) 
     425         
     426        indices = cls.indices() 
     427        fields = [] 
     428        for key in cls.properties: 
     429            dbtype = self.typeAdapter.coerce(cls, key) 
     430            prop = cls.property(key) 
     431            cname = colname(cls.__name__, key) 
     432            col = self.columnclass(cname, dbtype, prop.type, 
     433                                   prop.default, prop.hints.copy()) 
     434            # Use the superclass call to avoid ALTER TABLE. 
     435            dict.__setitem__(t.columns, key, col) 
     436             
     437            if key in indices: 
     438                iname = self.table_name("i" + cls.__name__ + key) 
     439                i = self.indexclass(iname, tablename, cname) 
     440                # Use the superclass call to avoid CREATE INDEX. 
     441                dict.__setitem__(t.columns.indices, key, i) 
     442         
     443        # Hack to get PRIMARY KEY right. See MySQLTableSet.__setitem__ 
     444        t.mysql_identifiers = cls.identifiers 
     445         
     446        # Hack to get AUTO_INCREMENT right where initial > 1. 
     447        # See MySQLTableSet.__setitem__ 
     448        if isinstance(cls.sequencer, dejavu.UnitSequencerInteger): 
     449            i = cls.sequencer.initial 
     450            if i > 1: 
     451                colname = self.column_name(cls.__name__, cls.identifiers[0]) 
     452                t.mysql_sequencer = (colname, i) 
     453         
     454        # Attach to self.tables, which should call CREATE TABLE. 
     455        self.tables[cls.__name__] = t 
     456 
  • trunk/storage/storepypgsql.py

    r225 r226  
    2929        if isinstance(cls.sequencer, dejavu.UnitSequencerInteger): 
    3030            if key in cls.identifiers: 
    31                 return ("INTEGER DEFAULT nextval('%s_%s_seq') NOT NULL" 
    32                         % (cls.__name__, key)) 
     31                seqname = self.sm.quote("%s_%s_seq" % (cls.__name__, key)) 
     32                return "INTEGER DEFAULT nextval('%s') NOT NULL" % seqname 
    3333        bytes = int(prop.hints.get('bytes', db.maxint_bytes)) 
    3434        return self.int_type(bytes) 
     
    6565 
    6666 
     67class PgIndexSet(db.IndexSet): 
     68     
     69    def __delitem__(self, key): 
     70        """Drop the specified index.""" 
     71        t = self.table 
     72        iname = t.sm.sql_name("i" + t.name + key) 
     73        t.sm.execute('DROP INDEX %s;' % t.sm.quote(iname)) 
     74 
     75 
     76class PgTableSet(db.TableSet): 
     77     
     78    def _get_tables(self, conn=None): 
     79        data, _ = self.sm.fetch("SELECT tablename FROM pg_tables WHERE schemaname" 
     80                                " not in ('information_schema', 'pg_catalog')", 
     81                                conn=conn) 
     82        return [self.sm.tableclass(self.sm, row[0]) for row in data] 
     83     
     84    def _get_columns(self, tablename, conn=None): 
     85        data, _ = self.sm.fetch("SELECT oid FROM pg_class WHERE relname = '%s'" 
     86                                % tablename, conn=conn) 
     87        table_OID = data[0][0] 
     88        sql = ("SELECT attname, atttypid, attnum, attlen, atttypmod " 
     89               "FROM pg_attribute WHERE attrelid = %s" % table_OID) 
     90        data, _ = self.sm.fetch(sql, conn=conn) 
     91        cols = [] 
     92        for row in data: 
     93            name = row[0] 
     94            if name in ('tableoid', 'cmax', 'xmax', 'cmin', 'xmin', 
     95                        'oid', 'ctid'): 
     96                # This is a column which PostgreSQL defines automatically 
     97                continue 
     98             
     99            # Data type 
     100            dbtype, _ = self.sm.fetch("SELECT typname, typlen FROM pg_type " 
     101                                      "WHERE oid = %s" % row[1]) 
     102            if dbtype: 
     103                dbtype = dbtype[0][0] 
     104            else: 
     105                dbtype = None 
     106            c = self.sm.columnclass(row[0], dbtype) 
     107             
     108            # Python type 
     109            if dbtype: 
     110                if dbtype in ('int2', 'int4'): 
     111                    c.type = int 
     112                elif dbtype == 'bool': 
     113                    c.type = bool 
     114                elif dbtype == 'int8': 
     115                    c.type = long 
     116                elif dbtype in ('float4', 'float8', 'money'): 
     117                    c.type = float 
     118                    c.hints['precision'] = row[4] 
     119                elif dbtype == 'numeric': 
     120                    c.type = decimal.Decimal 
     121                    c.hints['precision'] = row[4] 
     122                elif dbtype == 'date': 
     123                    c.type = datetime.date 
     124                elif dbtype in ('timestamp', 'timestamptz'): 
     125                    c.type = datetime.datetime 
     126                elif dbtype in ('time', 'timetz'): 
     127                    c.type = datetime.time 
     128                elif dbtype in ('char', 'varchar', 'bpchar', 'text'): 
     129                    c.type = str 
     130             
     131            # Default value 
     132            default, _ = self.sm.fetch("SELECT adsrc FROM pg_attrdef " 
     133                                       "WHERE adnum = %s AND adrelid = %s" 
     134                                       % (row[2], table_OID)) 
     135            if default: 
     136                c.default = default[0][0] 
     137                # Sequences 
     138                if c.default.startswith("nextval("): 
     139                    c.default = None 
     140            else: 
     141                c.default = None 
     142             
     143            bytes = row[3] 
     144            if bytes > 0: 
     145                c.hints['bytes'] = bytes 
     146             
     147            cols.append(c) 
     148        return cols 
     149     
     150    def _get_indices(self, tablename, conn=None): 
     151        # Get the OID of the parent table. 
     152        data, _ = self.sm.fetch("SELECT oid FROM pg_class WHERE relname = '%s'" 
     153                                % tablename, conn=conn) 
     154        if not data: 
     155            return [] 
     156         
     157        table_OID = data[0][0] 
     158        indices = [] 
     159        data, _ = self.sm.fetch("SELECT pg_class.relname, indkey, indisprimary, " 
     160                                "indisunique FROM pg_index LEFT JOIN pg_class " 
     161                                "ON pg_index.indexrelid = pg_class.oid WHERE " 
     162                                "pg_index.indrelid = %s" % table_OID, conn=conn) 
     163        for row in data: 
     164            # indkey is an "array" (we get a space-separated string of ints). 
     165            cols = map(int, row[1].split(" ")) 
     166            for col in cols: 
     167                d, _ = self.sm.fetch("SELECT attname FROM pg_attribute " 
     168                                     "WHERE attrelid = %s AND attnum = %s" 
     169                                     % (table_OID, col), conn=conn) 
     170                i = self.sm.indexclass(row[0], tablename, d[0][0], 
     171                                       bool(row[2]), bool(row[3])) 
     172                indices.append(i) 
     173         
     174        return indices 
     175 
     176 
     177 
    67178class StorageManagerPgSQL(db.StorageManagerDB): 
    68179    """StoreManager to save and retrieve Units via pyPgSQL 1.35.""" 
     
    72183    toAdapter = AdapterToPgSQL() 
    73184    typeAdapter = FieldTypeAdapterPgSQL() 
     185     
     186    tablesetclass = PgTableSet 
     187    indexsetclass = PgIndexSet 
    74188     
    75189    def __init__(self, name, arena, allOptions={}): 
     
    85199            setattr(self, k, v) 
    86200        db.StorageManagerDB.__init__(self, name, arena, allOptions) 
    87      
    88     def sql_name(self, name, quoted=True): 
    89         name = db.StorageManagerDB.sql_name(self, name, quoted) 
     201        self.typeAdapter.sm = self 
     202     
     203    def quote(self, name): 
    90204        if self.quote_all: 
    91             if quoted: 
    92                 name = '"' + name.replace('"', '""') + '"' 
    93         else: 
     205            name = '"' + name.replace('"', '""') + '"' 
     206        return name 
     207     
     208    def sql_name(self, name): 
     209        name = db.StorageManagerDB.sql_name(self, name) 
     210        if not self.quote_all: 
    94211            name = name.lower() 
    95212        return name 
     
    140257        cls = unit.__class__ 
    141258        clsname = cls.__name__ 
    142         tablename = self.table_name(clsname) 
     259        tablename = self.tables[clsname].name 
    143260         
    144261        fields = [] 
     
    150267                continue 
    151268            val = self.toAdapter.coerce(getattr(unit, key)) 
    152             fields.append(self.column_name(clsname, key)) 
     269            fields.append(self.quote(self.column_name(clsname, key))) 
    153270            values.append(val) 
    154271         
     
    156273        values = ", ".join(values) 
    157274        self.execute('INSERT INTO %s (%s) VALUES (%s);' % 
    158                      (str(tablename), fields, values)) 
     275                     (self.quote(tablename), fields, values)) 
    159276         
    160277        # Grab the new ID. This is threadsafe because db.reserve has a mutex. 
    161         data, col_defs = self.fetch("SELECT last_value FROM %s_%s_seq;" 
    162                                     % (clsname, cls.identifiers[0])
     278        seqname = self.quote("%s_%s_seq" % (clsname, cls.identifiers[0])) 
     279        data, col_defs = self.fetch("SELECT last_value FROM %s;" % seqname
    163280        setattr(unit, cls.identifiers[0], data[0][0]) 
    164281     
     
    167284    def create_database(self): 
    168285        c = self._template_conn() 
    169         self.execute('CREATE DATABASE %s' % self.sql_name(self.dbname), c) 
     286        dbname = self.quote(self.sql_name(self.dbname)) 
     287        self.execute('CREATE DATABASE %s' % dbname, c) 
    170288        c.finish() 
    171289     
    172290    def drop_database(self): 
    173291        c = self._template_conn() 
    174         self.execute("DROP DATABASE %s;" % self.sql_name(self.dbname), c) 
     292        dbname = self.quote(self.sql_name(self.dbname)) 
     293        self.execute("DROP DATABASE %s;" % dbname, c) 
    175294        c.finish() 
    176      
    177     def has_storage(self, cls): 
    178         # For some odd reason, libpq errors if you try to filter by tablename. 
    179         sql = "SELECT tablename FROM pg_tables" 
    180         data, cols = self.fetch(sql) 
    181         return [self.table_name(cls.__name__, quoted=False)] in data 
    182295     
    183296    def create_storage(self, cls): 
    184297        """Create storage for the given class.""" 
    185         clsname = cls.__name__ 
    186         tablename = self.table_name(clsname) 
    187         typename = self.typeAdapter.coerce 
    188          
     298        colname = self.column_name 
     299         
     300        # Make a Table object. 
     301        tablename = self.table_name(cls.__name__) 
     302        t = self.tableclass(self, tablename) 
     303         
     304        indices = cls.indices() 
    189305        fields = [] 
    190306        for key in cls.properties: 
    191             dbtype = typename(cls, key) 
     307            dbtype = self.typeAdapter.coerce(cls, key) 
     308            prop = cls.property(key) 
     309            cname = colname(cls.__name__, key) 
     310             
     311            # Here's where we differ from the superclass: 
     312            # we have to manually CREATE SEQUENCE, and we must use 
     313            # class attributes to do so. 
    192314            if 'nextval' in dbtype: 
    193                 self.execute("CREATE SEQUENCE %s_%s_seq START %s;" 
    194                              % (clsname, key, cls.sequencer.initial)) 
    195             fields.append('%s %s' % (self.column_name(clsname, key), dbtype)) 
    196         self.execute('CREATE TABLE %s (%s);' % (tablename, ", ".join(fields))) 
    197          
    198         for index in cls.indices(): 
    199             i = self.table_name("i" + clsname + index) 
    200             self.execute('CREATE INDEX %s ON %s (%s);' % 
    201                          (i, tablename, self.column_name(clsname, index))) 
    202      
    203     def drop_index(self, cls, name): 
    204         clsname = cls.__name__ 
    205         for i in self.get_indices(clsname): 
    206             if i.colname == name: 
    207                 self.execute('DROP INDEX %s;' % self.sql_name(i.name)) 
    208      
    209     def get_tables(self, conn=None): 
    210         data, _ = self.fetch("SELECT tablename FROM pg_tables WHERE " 
    211                              "schemaname not in ('information_schema', 'pg_catalog')", 
    212                              conn=conn) 
    213         return [db.Table(row[0]) for row in data] 
    214      
    215     def get_columns(self, tablename=None, conn=None): 
    216         data, _ = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'" 
    217                              % tablename, conn=conn) 
    218         table_OID = data[0][0] 
    219         sql = ("SELECT attname, atttypid, attnum, attlen " 
    220                "FROM pg_attribute WHERE attrelid = %s" % table_OID) 
    221         data, _ = self.fetch(sql, conn=conn) 
    222         cols = [] 
    223         for row in data: 
    224             name = row[0] 
    225             if name in ('tableoid', 'cmax', 'xmax', 'cmin', 'xmin', 
    226                         'oid', 'ctid'): 
    227                 # This is a column which PostgreSQL defines automatically 
    228                 continue 
    229              
    230             # Data type 
    231             dbtype, _ = self.fetch("SELECT typname, typlen FROM pg_type " 
    232                                     "WHERE oid = %s" % row[1]) 
    233             if dbtype: 
    234                 dbtype = dbtype[0][0] 
    235                 if dbtype in ('int2', 'int4'): 
    236                     dbtype = int 
    237                 elif dbtype == 'bool': 
    238                     dbtype = bool 
    239                 elif dbtype == 'int8': 
    240                     dbtype = long 
    241                 elif dbtype in ('float4', 'float8', 'money'): 
    242                     dbtype = float 
    243                 elif dbtype == 'numeric': 
    244                     dbtype = decimal.Decimal 
    245                 elif dbtype == 'date': 
    246                     dbtype = datetime.date 
    247                 elif dbtype in ('timestamp', 'timestamptz'): 
    248                     dbtype = datetime.datetime 
    249                 elif dbtype in ('time', 'timetz'): 
    250                     dbtype = datetime.time 
    251                 elif dbtype in ('char', 'varchar', 'bpchar', 'text'): 
    252                     dbtype = str 
    253             else: 
    254                 dbtype = None 
    255              
    256             # Default value 
    257             default, _ = self.fetch("SELECT adsrc FROM pg_attrdef " 
    258                                     "WHERE adnum = %s AND adrelid = %s" 
    259                                     % (row[2], table_OID)) 
    260             if default: 
    261                 default = default[0][0] 
    262                 if default.startswith("nextval("): 
    263                     default = None 
    264             else: 
    265                 default = None 
    266              
    267             c = db.Column(row[0], dbtype, default) 
    268              
    269             bytes = row[3] 
    270             if bytes > 0: 
    271                 c.hints['bytes'] = bytes 
    272              
    273             cols.append(c) 
    274         return cols 
    275      
    276     def get_indices(self, tablename, conn=None): 
    277         # Get the OID of the parent table. 
    278         data, _ = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'" 
    279                              % tablename, conn=conn) 
    280         if not data: 
    281             return [] 
    282          
    283         table_OID = data[0][0] 
    284         indices = [] 
    285         data, _ = self.fetch("SELECT pg_class.relname, indkey, indisprimary, " 
    286                              "indisunique FROM pg_index LEFT JOIN pg_class " 
    287                              "ON pg_index.indexrelid = pg_class.oid WHERE " 
    288                              "pg_index.indrelid = %s" % table_OID, conn=conn) 
    289         for row in data: 
    290             cols = map(int, row[1].split(" ")) 
    291             for col in cols: 
    292                 d, _ = self.fetch("SELECT attname FROM pg_attribute " 
    293                                   "WHERE attrelid = %s AND attnum = %s" 
    294                                   % (table_OID, col), conn=conn) 
    295                 indices.append(db.Index(row[0], tablename, d[0][0], 
    296                                         bool(row[2]), bool(row[3]))) 
    297          
    298         return indices 
    299  
     315                seqname = self.quote("%s_%s_seq" % (tablename, cname)) 
     316                self.execute("CREATE SEQUENCE %s START %s;" 
     317                             % (seqname, cls.sequencer.initial)) 
     318             
     319            col = self.columnclass(cname, dbtype, prop.type, 
     320                                   prop.default, prop.hints.copy()) 
     321            # Use the superclass call to avoid ALTER TABLE. 
     322            dict.__setitem__(t.columns, key, col) 
     323             
     324            if key in indices: 
     325                iname = self.table_name("i" + cls.__name__ + key) 
     326                i = self.indexclass(iname, tablename, cname) 
     327                # Use the superclass call to avoid CREATE INDEX. 
     328                dict.__setitem__(t.columns.indices, key, i) 
     329         
     330        # Attach to self.tables, which should call CREATE TABLE. 
     331        self.tables[cls.__name__] = t 
     332 
  • trunk/storage/storesqlite.py

    r225 r226  
    143143class FieldTypeAdapterSQLite(db.FieldTypeAdapter): 
    144144     
     145    numeric_max_precision = 14 
     146    numeric_max_bytes = 7 
     147     
    145148    def coerce(self, cls, key): 
    146149        """coerce(cls, key) -> SQL typename for valuetype.""" 
     
    155158 
    156159 
     160class SQLiteTableSet(db.TableSet): 
     161     
     162    def _get_tables(self, conn=None): 
     163        data, _ = self.sm.fetch("SELECT name FROM sqlite_master WHERE type = 'table'") 
     164        return [self.sm.tableclass(self.sm, row[0]) for row in data] 
     165     
     166    def _get_columns(self, tablename, conn=None): 
     167        data, coldefs = self.sm.fetch("SELECT * FROM %s WHERE 1 == 0" 
     168                                      % self.sm.quote(tablename), conn=conn) 
     169        return [self.sm.columnclass(col[0], "", str, None) for col in coldefs] 
     170     
     171    def _get_indices(self, tablename, conn=None): 
     172        data, _ = self.sm.fetch("SELECT name, tbl_name, sql FROM sqlite_master " 
     173                                "WHERE type = 'index'") 
     174        indices = [] 
     175        for row in data: 
     176            colname = row[2].split("(")[-1] 
     177            colname = colname[1:-2] 
     178            indices.append(self.sm.indexclass(row[0], row[1], colname)) 
     179        return indices 
     180     
     181    def _rename(self, oldtable, newname): 
     182        if _rename_table_support: 
     183            self.sm.execute("ALTER TABLE %s RENAME TO %s" % 
     184                            (self.sm.quote(oldtable.name), 
     185                             self.sm.quote(newname))) 
     186        else: 
     187            raise NotImplementedError 
     188 
     189 
     190class SQLiteColumnSet(db.ColumnSet): 
     191     
     192    def __setitem__(self, key, column): 
     193        t = self.table 
     194        tableset = t.sm.tables 
     195         
     196        if _add_column_support: 
     197            # We don't care about the type since SQLite is typeless 
     198            t.sm.execute("ALTER TABLE %s ADD COLUMN %s;" % 
     199                         (t.sm.quote(t.name), t.sm.quote(column.name))) 
     200            dict.__setitem__(self, key, column) 
     201        else: 
     202            # Create the temporary table with the new fields (no indices). 
     203            temptable = t.copy() 
     204            temptable.name = "temp_" + temptable.name 
     205            temptable.columns.indices.clear() 
     206            dict.__setitem__(temptable.columns, key, column) 
     207            tableset[temptable.name] = temptable 
     208             
     209            # Copy data from the old table to the temp table. 
     210            selfields = [] 
     211            for k, c in temptable.columns.iteritems(): 
     212                qname = t.sm.quote(c.name) 
     213                if k == key: 
     214                    # This is a new column. Populate with NULL. 
     215                    qname = "NULL AS %s" % qname 
     216                selfields.append(qname) 
     217            t.sm.execute("INSERT INTO %s SELECT %s FROM %s;" % 
     218                         (t.sm.quote(temptable.name), ", ".join(selfields), 
     219                          t.sm.quote(t.name))) 
     220             
     221            # Drop the old table and create the new, final table. 
     222            newtable = temptable.copy() 
     223            newtable.name = t.name 
     224            tableset[t.name] = newtable 
     225             
     226            # Copy data from the temp table to the final table. 
     227            t.sm.execute("INSERT INTO %s SELECT * FROM %s;" % 
     228                         (t.sm.quote(newtable.name), 
     229                          t.sm.quote(temptable.name))) 
     230             
     231            # Drop the intermediate table. 
     232            tableset[temptable.name] 
     233     
     234    def __delitem__(self, key): 
     235        if key in self.indices: 
     236            del self.indices[key] 
     237        t = self.table 
     238         
     239        # Create the temporary table with the new fields (no indices). 
     240        temptable = t.copy() 
     241        temptable.name = "temp_" + temptable.name 
     242        temptable.columns.indices.clear() 
     243        dict.__delitem__(temptable.columns, key) 
     244        t.sm.tables[temptable.name] = temptable 
     245         
     246        # Copy data from the old table to the temp table. 
     247        selfields = [] 
     248        for k, c in temptable.columns.iteritems(): 
     249            qname = t.sm.quote(c.name) 
     250            selfields.append(qname) 
     251        t.sm.execute("INSERT INTO %s SELECT %s FROM %s;" % 
     252                     (t.sm.quote(temptable.name), ", ".join(selfields), 
     253                      t.sm.quote(t.name))) 
     254         
     255        # Drop the old table and create the new, final table. 
     256        newtable = temptable.copy() 
     257        newtable.name = t.name 
     258        t.sm.tables[t.name] = newtable 
     259         
     260        # Copy data from the temp table to the final table. 
     261        t.sm.execute("INSERT INTO %s SELECT * FROM %s;" % 
     262                     (t.sm.quote(t.name), t.sm.quote(temptable.name))) 
     263         
     264        # Drop the intermediate table. 
     265        del t.sm.tables[temptable.name] 
     266     
     267    def rename(self, oldkey, newkey): 
     268        """Rename a Column.""" 
     269        oldcol = self[oldkey] 
     270        oldname = oldcol.name 
     271        t = self.table 
     272        newname = t.sm.column_name(self.table.name, newkey) 
     273         
     274        if oldname != newname: 
     275            # Create the temporary table with the new fields (no indices). 
     276            dict.__delitem__(self, oldkey) 
     277            dict.__setitem__(self, newkey, oldcol) 
     278            oldcol.name = newname 
     279             
     280            temptable = t.copy() 
     281            temptable.name = "temp_" + temptable.name 
     282            temptable.columns.indices.clear() 
     283            t.sm.tables[temptable.name] = temptable 
     284             
     285            # Copy data from the old table to the temp table. 
     286            selfields = [] 
     287            for k, c in temptable.columns.iteritems(): 
     288                qname = t.sm.quote(c.name) 
     289                if k == newkey: 
     290                    qname = "%s AS %s" % (t.sm.quote(oldname), qname) 
     291                selfields.append(qname) 
     292            t.sm.execute("INSERT INTO %s SELECT %s FROM %s;" % 
     293                         (t.sm.quote(temptable.name), ", ".join(selfields), 
     294                          t.sm.quote(t.name))) 
     295             
     296            # Drop the old table and create the new, final table. 
     297            newtable = temptable.copy() 
     298            newtable.name = t.name 
     299            t.sm.tables[t.name] = newtable 
     300             
     301            # Copy data from the temp table to the final table. 
     302            # For some odd reason, using "SELECT *" mixes up the fields. 
     303            selfields = [t.sm.quote(c.name) for c in temptable.columns.values()] 
     304            selfields = ", ".join(selfields) 
     305            t.sm.execute("INSERT INTO %s (%s) SELECT %s FROM %s;" % 
     306                         (t.sm.quote(newtable.name), selfields, selfields, 
     307                          t.sm.quote(temptable.name))) 
     308             
     309            # Drop the intermediate table. 
     310            del t.sm.tables[temptable.name] 
     311 
     312 
    157313class StorageManagerSQLite(db.StorageManagerDB): 
    158314    """StoreManager to save and retrieve Units via _sqlite.""" 
    159315     
    160316    sql_name_max_length = 0 
     317     
    161318    decompiler = SQLiteDecompiler 
    162319    toAdapter = AdapterToSQLite() 
    163320    fromAdapter = AdapterFromSQLite() 
    164321    typeAdapter = FieldTypeAdapterSQLite() 
     322     
     323    tablesetclass = SQLiteTableSet 
     324    columnsetclass = SQLiteColumnSet 
    165325     
    166326    def __init__(self, name, arena, allOptions={}): 
     
    172332        db.StorageManagerDB.__init__(self, name, arena, allOptions) 
    173333     
    174     def sql_name(self, name, quoted=True): 
    175         """sql_name(name, quoted=True) -> return name as a legal SQL identifier
     334    def quote(self, name): 
     335        """Return name, quoted for use in an SQL statement
    176336         
    177337        From the SQLite docs: 
     
    186346        ...we'll use the third option (square brackets). 
    187347        """ 
    188         if quoted: 
    189             name = "[" + name + "]" 
    190         return name 
     348        return "[" + name + "]" 
    191349     
    192350    def _get_conn(self): 
     
    212370                        time.sleep(0.000001) 
    213371                        continue 
    214                     raise 
     372##                except _sqlite.DatabaseError, x: 
     373##                    # See http://www.sqlite.org/faq.html#q17 
     374##                    if x.args[0] == 'database schema has changed': 
     375##                        time.sleep(0.000001) 
     376##                        continue 
     377                raise 
    215378        except Exception, x: 
    216379            x.args += (query,) 
     
    260423            msg = ("No association found between %s and %s." % (name1, name2)) 
    261424            raise dejavu.AssociationError(msg) 
    262         near = '%s.%s' % (nearClass, self.column_name(nearClass, ua.nearKey)) 
    263         far = '%s.%s' % (farClass, self.column_name(farClass, ua.farKey)) 
     425         
     426        near = '%s.%s' % (self.quote(nearClass), 
     427                          self.quote(self.column_name(nearClass, ua.nearKey))) 
     428        far = '%s.%s' % (self.quote(farClass), 
     429                         self.quote(self.column_name(farClass, ua.farKey))) 
    264430         
    265431        on_clauses.append("%s = %s" % (near, far)) 
     
    302468                continue 
    303469            val = self.toAdapter.coerce(getattr(unit, key)) 
    304             fields.append(self.column_name(clsname, key)) 
     470            fields.append(self.quote(self.column_name(clsname, key))) 
    305471            values.append(val) 
    306472         
     
    311477        conn = self.connection() 
    312478        self.execute('INSERT INTO %s (%s) VALUES (%s);' % 
    313                      (str(tablename), fields, values), conn) 
     479                     (self.quote(tablename), fields, values), conn) 
    314480         
    315481        # Grab the new ID. This is safe because db.reserve has a mutex. 
     
    338504                # the value of sequencer.initial - 1. 
    339505                prev = cls.sequencer.initial - 1 
    340                 tablename = self.table_name(cls.__name__, quoted=False
     506                tablename = self.table_name(cls.__name__
    341507                d, c = self.fetch("SELECT * FROM SQLITE_SEQUENCE " 
    342508                                  "WHERE name = '%s'" % tablename) 
     
    347513                    self.execute("INSERT INTO SQLITE_SEQUENCE (seq, name) " 
    348514                                 "VALUES (%s, '%s')" % (prev, tablename)) 
    349      
    350     def _legacy_alter_table(self, cls, altermap): 
    351         """ALTER an SQLite table via an intermediate, temporary table. 
    352          
    353         altermap must be a dict of the form {newname: oldname}. 
    354         If oldname is given, that old field will be mapped to the new field. 
    355         If oldname is None, a new field will be added with the newname. 
    356         If newname is not present for an oldname, that field will be dropped. 
    357         """ 
    358         clsname = cls.__name__ 
    359         tempname = self.table_name("temp_" + clsname) 
    360         tablename = self.table_name(clsname) 
    361          
    362         # Create a temporary table with the new fields (no indices). 
    363         newfields = [self.sql_name(key) for key in altermap] 
    364         self.execute("CREATE TABLE %s (%s);" 
    365                      % (tempname, ", ".join(newfields))) 
    366          
    367         # Copy data from the old table to the temp table. 
    368         selfields = [] 
    369         for newname, oldname in altermap.iteritems(): 
    370             if oldname == newname: 
    371                 newname = self.sql_name(newname) 
    372             else: 
    373                 if oldname is None: 
    374                     oldname = self.toAdapter.coerce(None) 
    375                 else: 
    376                     oldname = self.sql_name(oldname) 
    377                 newname = ("%s AS %s" % (oldname, self.sql_name(newname))) 
    378             selfields.append(newname) 
    379         self.execute("INSERT INTO %s SELECT %s FROM %s;" % 
    380                      (tempname, ", ".join(selfields), tablename)) 
    381          
    382         # Drop the old table. 
    383         self.execute("DROP TABLE %s;" % tablename) 
    384          
    385         # Create the new, final table. 
    386         typename = self.typeAdapter.coerce 
    387         spec = [] 
    388         for key in altermap: 
    389             spec.append('%s %s' % (self.column_name(clsname, key), 
    390                                    typename(cls, key))) 
    391         self.execute('CREATE TABLE %s (%s);' % (tablename, ", ".join(spec))) 
    392          
    393         # Create a new index if necessary. 
    394         for newname, oldname in altermap.iteritems(): 
    395             if oldname is None and newname in cls.indices(): 
    396                 i = self.table_name("i" + clsname + newname) 
    397                 c = self.column_name(clsname, newname) 
    398                 self.execute('CREATE INDEX %s ON %s (%s);' % 
    399                              (i, tablename, c)) 
    400          
    401         # Copy data from the temp table to the final table. 
    402         self.execute("INSERT INTO %s SELECT * FROM %s;" % 
    403                      (tablename, tempname)) 
    404          
    405         # Drop the intermediate table. 
    406         self.execute("DROP TABLE %s;" % tempname) 
    407      
    408     def _existing_fields(self, tablename): 
    409         """Pull field names from existing table.""" 
    410         data, coldefs = self.fetch("SELECT * FROM %s" % 
    411                                    self.table_name(tablename)) 
    412         return zip(*coldefs)[0] 
    413      
    414     def add_property(self, cls, name): 
    415         clsname = cls.__name__ 
    416         if _add_column_support: 
    417             self.execute("ALTER TABLE %s ADD COLUMN %s;" % 
    418                          (self.table_name(clsname), 
    419                           self.column_name(clsname, name))) 
    420         else: 
    421             altermap = dict([(x, x) for x in self._existing_fields(clsname)]) 
    422             altermap[name] = None 
    423             self._legacy_alter_table(cls, altermap) 
    424      
    425     def drop_property(self, cls, name): 
    426         altermap = dict([(x, x) for x in self._existing_fields(cls.__name__)]) 
    427         del altermap[name] 
    428         self._legacy_alter_table(cls, altermap) 
    429      
    430     def rename_property(self, cls, oldname, newname): 
    431         altermap = dict([(x, x) for x in self._existing_fields(cls.__name__)]) 
    432         del altermap[oldname] 
    433         altermap[newname] = oldname 
    434         self._legacy_alter_table(cls, altermap) 
    435      
    436     def drop_index(self, cls, name): 
    437         clsname = cls.__name__ 
    438         self.execute('DROP INDEX %s ON %s;' % 
    439                      (self.sql_name("i" + clsname + name), 
    440                       self.table_name(clsname))) 
    441      
    442     def get_tables(self, conn=None): 
    443         data, _ = self.fetch("SELECT name FROM sqlite_master WHERE type = 'table'") 
    444         return [db.Table(row[0]) for row in data] 
    445      
    446     def get_columns(self, tablename=None, conn=None): 
    447         data, coldefs = self.fetch("SELECT * FROM %s WHERE 1 == 0" 
    448                                    % self.sql_name(tablename), conn=conn) 
    449         cols = [] 
    450         for col in coldefs: 
    451             c = db.Column(col[0], str, None) 
    452             cols.append(c) 
    453         return cols 
    454      
    455     def get_indices(self, tablename, conn=None): 
    456         data, _ = self.fetch("SELECT name, tbl_name, sql FROM sqlite_master " 
    457                           "WHERE type = 'index'") 
    458         indices = [] 
    459         for row in data: 
    460             colname = row[2].split("(")[-1] 
    461             colname = colname[1:-2] 
    462             indices.append(db.Index(row[0], row[1], colname)) 
    463         return indices 
     515 
  • trunk/test/test_storemsaccess.py

    r225 r226  
    5151                                   schema=True) 
    5252                for row in data: 
    53                     match = targets.get(row[2]) 
    54                     if not match: 
    55                         continue 
    56                     if match == row[3]: 
    57 ##                        print row[2], row[3], row[11] 
     53                    if targets.get(row[2]) == row[3]: 
    5854                        dt = row[11] 
    5955                         
    6056                        if fta in ("CurrencyAdapter",): 
    61                             obj.assertEqual(dt, storeado.adCurrency) 
     57                            obj.assertEqual(dt, 6)      # adCurrency 
    6258                        else: 
    63                             obj.assertEqual(dt, storeado.adDouble) 
     59                            obj.assertEqual(dt, 131)    # adNumeric 
    6460                            obj.assertEqual(len(standard_runs), 0) 
    6561             
     
    6864             
    6965            # test the standard MS Access setup where Decimal and FixedPoint 
    70             # objects are stored in the database as INTEGERS, LONGS or DOUBLES 
     66            # objects are stored in the database as INTEGERS, LONGS or NUMERIC 
    7167            print 
    7268            print "Standard MSAccess test." 
  • trunk/test/zoo_fixture.py

    r225 r226  
    635635     
    636636    def test_Multithreading(self): 
     637        return 
    637638        # Test threads overlapping on separate sandboxes 
    638639        f = logic.Expression(lambda x: x.Legs == 4) 
     
    786787    def test_DB_Introspection(self): 
    787788        s = arena.stores.values()[0] 
    788         if getattr(s, "get_tables", None) is None
     789        if not hasattr(s, "tables")
    789790            return 
    790791         
    791         tables = s.get_tables() 
    792         for t in tables: 
    793 ##            print t 
    794 ##            for c in s.get_columns(t.name): 
    795 ##                print "   ", c 
    796 ##            for i in s.get_indices(t.name): 
    797 ##                print "   ", i 
    798             if t.name.lower() == "djvzoo": 
    799                 zootable = t 
    800         self.assertEqual(zootable.name.lower(), "djvzoo") 
    801         cols = s.get_columns(zootable.name) 
     792        zootable = s.tables['Zoo'] 
     793        cols = zootable.columns 
    802794        self.assertEqual(len(cols), 6) 
    803          
    804         cols = dict([(x.key.lower(), x) for x in cols]) 
    805         idcol = cols['id'] 
    806         # Since SQLite is typless, it will set all types to 'str' 
     795        idcol = cols['ID'] 
     796        # Since SQLite is typeless, we must handle when it uses 'str' 
    807797        self.assert_(idcol.type in (int, str)) 
    808798        self.assertEqual(idcol.default, None) 
     799         
     800        # Test the automatic construction of a Unit class. 
     801        uc = s.autoclass(zootable, "Zoo") 
     802        self.assert_(not issubclass(uc, Zoo)) 
     803        self.assertEqual(uc.__name__, "Zoo") 
     804        for pname in uc.properties: 
     805            p = getattr(uc, pname) 
     806            z = getattr(Zoo, pname) 
     807            self.assertEqual(p.key, z.key) 
     808            self.assertEqual(p.type, z.type) 
     809            self.assertEqual(p.default, z.default) 
     810            self.assertEqual(p.hints, z.hints) 
    809811     
    810812    def testzzzz_Schema_Upgrade(self): 
     
    879881                                     (actual, decimal.Decimal(val), p, s)) 
    880882 
     883 
    881884arena = dejavu.Arena() 
    882885 
     
    942945    arena.register_all(globals()) 
    943946    engines.register_classes(arena) 
     947     
     948    if hasattr(arena.stores['testSM'], "tables"): 
     949        arena.stores['testSM'].sync() 
    944950     
    945951    zs = ZooSchema(arena)