Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

Changeset 229

Show
Ignore:
Timestamp:
07/18/06 20:28:17
Author:
fumanchu
Message:

Bah. The DB Introspection test wasn't running. Here are fixes for MySQL, PostreSQL. More fixes coming.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/storage/db.py

    r228 r229  
    731731        del self.db[cls.__name__].columns.indices[name] 
    732732     
    733     def sync(self, conn=None): 
     733    def sync(self): 
    734734        """Map new Table objects to all registered classes.""" 
    735735        # Use the superclass call to avoid DROP TABLE. 
    736736        dict.clear(self.db) 
    737         dbtables = self.db._get_tables(conn
     737        dbtables = self.db._get_tables(
    738738        for cls in self.arena._registered_classes: 
    739739            # Try to find a matching Table object from _get_tables. 
    740             t = [x for x in dbtables if x.name == self.table_name(cls.__name__)] 
     740            t = [x for x in dbtables if x.name == self.db.table_name(cls.__name__)] 
    741741            if t: 
    742742                t = t[0] 
     
    744744                for ckey in cls.properties: 
    745745                    # Try to find a matching Column object from _get_columns. 
    746                     c = [x for x in dbcols if x.name == self.column_name(cls.__name__, ckey)] 
     746                    c = [x for x in dbcols if x.name == self.db.column_name(cls.__name__, ckey)] 
    747747                    if c: 
    748748                        c = c[0] 
     
    752752                dbindices = self.db._get_indices(t.name) 
    753753                for ikey in cls.indices(): 
    754                     iname = self.table_name("i" + cls.__name__ + ikey) 
     754                    iname = self.db.table_name("i" + cls.__name__ + ikey) 
    755755                    # Try to find a matching Column object from _get_columns. 
    756756                    i = [x for x in dbindices if x.name == iname] 
  • trunk/storage/dbmodel.py

    r228 r229  
    839839         
    840840        t.db.execute("ALTER TABLE %s ADD COLUMN %s %s%s;" % 
    841                            (t.qname, column.qname, column.dbtype, default)) 
     841                     (t.qname, column.qname, column.dbtype, default)) 
    842842        dict.__setitem__(self, key, column) 
    843843     
     
    854854        t = self.table 
    855855        t.db.execute("ALTER TABLE %s RENAME COLUMN %s TO %s;" % 
    856                            (t.qname, oldcol.qname, newcol.qname)) 
     856                     (t.qname, oldcol.qname, newcol.qname)) 
    857857     
    858858    def rename(self, oldkey, newkey): 
  • trunk/storage/storemysql.py

    r228 r229  
    9595    def float_type(self, precision): 
    9696        """Return a datatype which can handle the given precision.""" 
    97         if precision <= 23: 
    98             return "FLOAT" 
    99         else: 
    100             return "DOUBLE PRECISION" 
     97        # "p represents the precision in bits, but MySQL uses this value 
     98        # only to determine whether to use FLOAT or DOUBLE for the 
     99        # resulting data type. If p is from 0 to 24, the data type 
     100        # becomes FLOAT with no M or D values. If p is from 25 to 53, 
     101        # the data type becomes DOUBLE with no M or D values." 
     102        return "FLOAT(%s)" % precision 
    101103     
    102104    def coerce_str(self, cls, key): 
     
    109111                return "VARBINARY(%s)" % bytes 
    110112            elif bytes < 2 ** 16: 
    111                 return "BLOB" 
     113                if self.db._version >= storage.Version("4.1"): 
     114                    return "BLOB(%s)" % bytes 
     115                else: 
     116                    return "BLOB" 
    112117            elif bytes < 2 ** 24: 
    113118                return "MEDIUMBLOB" 
     
    176181    indexsetclass = MySQLIndexSet 
    177182     
     183    def __init__(self, name, **kwargs): 
     184        db.Database.__init__(self, name, **kwargs) 
     185         
     186        self.decompiler = MySQLDecompiler 
     187         
     188        # Get the version string from MySQL, to see if we need 
     189        # a different decompiler. 
     190        conn = self._template_conn() 
     191        rowdata, cols = self.fetch("SELECT version();", conn) 
     192        conn.close() 
     193        v = rowdata[0][0] 
     194        self._version = storage.Version(v) 
     195        if self._version > storage.Version("4.1.1"): 
     196            self.decompiler = MySQLDecompiler411 
     197     
     198    def version(self): 
     199        return "MySQL Version: %s" % self._version 
     200     
    178201    def __setitem__(self, key, table): 
    179202        q = self.quote 
     
    184207            qname = col.qname 
    185208            dbtype = col.dbtype 
    186             fields.append('%s %s' % (qname, dbtype)) 
     209             
     210            default = col.default or "" 
     211            if default: 
     212                default = " DEFAULT %s" % self.adaptertosql.coerce(default) 
     213             
     214            f = '%s %s%s' % (qname, dbtype, default) 
     215            fields.append(f) 
     216             
    187217            # See create_storage for the other half of this hack. 
    188218            if colname in table.mysql_identifiers: 
     
    233263        # cols are: Field, Type, Null, Key, Default, Extra. 
    234264        # See http://dev.mysql.com/doc/refman/4.1/en/describe.html 
    235         q = self.quote 
    236265        data, _ = self.fetch("SHOW COLUMNS FROM %s.%s" % 
    237                              (self.qname, q(tablename)), conn=conn) 
     266                             (self.qname, self.quote(tablename)), conn=conn) 
    238267        cols = [] 
    239268        for row in data: 
    240             c = db.Column(row[0], self.quote(row[0]), None, row[4]
    241              
    242             dbtype = row[1] 
     269            c = db.Column(row[0], self.quote(row[0]), None, None
     270             
     271            dbtype = row[1].upper() 
    243272            parenpos = dbtype.find("(") 
    244273            if parenpos > -1: 
    245                 c.hints['bytes'] = dbtype[parenpos+1:-1] 
    246                 dbtype = dbtype[:parenpos] 
     274                args = dbtype[parenpos+1:-1] 
     275                baretype = dbtype[:parenpos] 
     276                if baretype in ("DECIMAL", "NUMERIC"): 
     277                    args = [x.strip() for x in args.split(",")] 
     278                    c.hints['precision'], c.hints['scale'] = args 
     279                else: 
     280                    c.hints['bytes'] = args 
     281            elif dbtype == "FLOAT": 
     282                c.hints['precision'] = 24 
     283            elif dbtype.startswith("DOUBLE"): 
     284                c.hints['precision'] = 53 
     285            elif dbtype in ("TINYBLOB", "TINYTEXT"): 
     286                c.hints['bytes'] = (2 ** 8) - 1 
     287            elif dbtype in ("BLOB", "TEXT"): 
     288                c.hints['bytes'] = (2 ** 16) - 1 
     289            elif dbtype in ("MEDIUMBLOB", "MEDIUMTEXT"): 
     290                c.hints['bytes'] = (2 ** 24) - 1 
     291            elif dbtype in ("LONGBLOB", "LONGTEXT"): 
     292                c.hints['bytes'] = (2 ** 32) - 1 
     293             
     294            if row[4]: 
     295                c.default = self.python_type(dbtype)(row[4]) 
     296             
     297            if row[5]: 
     298                # Usually auto_increment 
     299                dbtype += " " + row[5] 
    247300            c.dbtype = dbtype 
    248301             
    249302            cols.append(c) 
    250303        return cols 
     304     
     305    def _get_indices(self, tablename, conn=None): 
     306        indices = [] 
     307        try: 
     308            # cols are: Table, Non_unique, Key_name, Seq_in_index, Column_name, 
     309            # Collation, Cardinality, Sub_part, Packed, Null, Index_type, Comment 
     310            data, _ = self.fetch("SHOW INDEX FROM %s.%s" 
     311                                 % (self.qname, self.quote(tablename)), 
     312                                 conn=conn) 
     313        except _mysql.ProgrammingError, x: 
     314            if x.args[0] != 1146: 
     315                raise 
     316        else: 
     317            for row in data: 
     318                i = db.Index(row[2], self.quote(row[2]), 
     319                             row[0], row[4], None, not row[1]) 
     320                indices.append(i) 
     321        return indices 
    251322     
    252323    def python_type(self, dbtype): 
    253324        """Return a Python type which can store values of the given dbtype.""" 
    254325        dbtype = dbtype.upper() 
    255          
    256         if dbtype.endswith("AUTO_INCREMENT"): 
    257             return int 
    258         elif dbtype in ('TINYINT', 'SMALLINT', 'MEDIUMINT', 'INT', 'INTEGER'): 
     326        dbtype = dbtype.replace(" AUTO_INCREMENT", "") 
     327        parenpos = dbtype.find("(") 
     328        if parenpos > -1: 
     329            dbtype = dbtype[:parenpos] 
     330         
     331        if dbtype in ('TINYINT', 'SMALLINT', 'MEDIUMINT', 'INT', 'INTEGER'): 
    259332            return int 
    260333        elif dbtype == 'BIGINT': 
     
    262335        elif dbtype in ('FLOAT', 'DOUBLE', 'DOUBLE PRECISION', 'REAL'): 
    263336            return float 
    264         elif dbtype.startswith('DECIMAL') or dbtype.startswith('NUMERIC'): 
     337        elif dbtype in ('DECIMAL', 'NUMERIC'): 
    265338            if db.decimal: 
    266339                return db.decimal.Decimal 
     
    273346        elif dbtype == 'TIME': 
    274347            return datetime.time 
    275         for t in ('CHAR', 'VARCHAR', 'BINARY', 'VARBINARY', 
    276                   'TINYBLOB', 'TINYTEXT', 'BLOB', 'TEXT', 
    277                   'MEDIUMBLOB', 'MEDIUMTEXT', 'LONGBLOB', 'LONGTEXT'): 
    278             if dbtype.startswith(t): 
    279                 return str 
     348        elif dbtype in ('CHAR', 'VARCHAR', 'BINARY', 'VARBINARY', 
     349                        'TINYBLOB', 'TINYTEXT', 'BLOB', 'TEXT', 
     350                        'MEDIUMBLOB', 'MEDIUMTEXT', 'LONGBLOB', 'LONGTEXT'): 
     351            return str 
    280352         
    281353        raise TypeError("Database type %s could not be converted " 
    282354                        "to a Python type." % repr(dbtype)) 
    283      
    284     def _get_indices(self, tablename, conn=None): 
    285         indices = [] 
    286         try: 
    287             # cols are: Table, Non_unique, Key_name, Seq_in_index, Column_name, 
    288             # Collation, Cardinality, Sub_part, Packed, Null, Index_type, Comment 
    289             q = self.sm.quote 
    290             data, _ = self.fetch("SHOW INDEX FROM %s.%s" 
    291                                  % (self.qname, q(tablename)), conn=conn) 
    292         except _mysql.ProgrammingError, x: 
    293             if x.args[0] != 1146: 
    294                 raise 
    295         else: 
    296             for row in data: 
    297                 i = db.Index(row[2], self.quote(row[2]), 
    298                              row[0], row[4], None, not row[1]) 
    299                 indices.append(i) 
    300         return indices 
    301355     
    302356    def quote(self, name): 
     
    363417                         if k in connargs]) 
    364418        allOptions['connargs'] = connargs 
    365          
    366419        allOptions['name'] = connargs['db'] 
    367          
    368420        db.StorageManagerDB.__init__(self, name, arena, allOptions) 
    369          
    370         self.db.decompiler = MySQLDecompiler 
    371         # Get the version string from MySQL, to see if we need 
    372         # a different decompiler. 
    373         conn = self.db._template_conn() 
    374         rowdata, cols = self.db.fetch("SELECT version();", conn) 
    375         conn.close() 
    376         v = rowdata[0][0] 
    377         self._version = storage.Version(v) 
    378         if self._version > storage.Version("4.1.1"): 
    379             self.db.decompiler = MySQLDecompiler411 
     421        self.typeAdapter.db = self.db 
    380422     
    381423    def destroy(self, unit): 
     
    383425        t = self.db[unit.__class__.__name__].qname 
    384426        self.db.execute('DELETE FROM %s WHERE %s;' % (t, self.id_clause(unit))) 
    385      
    386     def version(self): 
    387         return "MySQL Version: %s" % self._version 
    388427     
    389428    def _seq_UnitSequencerInteger(self, unit): 
  • trunk/storage/storepypgsql.py

    r228 r229  
    1212    like_escapes = [("%", r"\\%"), ("_", r"\\_")] 
    1313 
    14  
    15 class FieldTypeAdapterPgSQL(db.FieldTypeAdapter): 
    16      
    17     def coerce_int(self, cls, key): 
    18         prop = getattr(cls, key) 
    19         if isinstance(cls.sequencer, dejavu.UnitSequencerInteger): 
    20             if key in cls.identifiers: 
    21                 seqname = self.sm.db.quote("%s_%s_seq" % (cls.__name__, key)) 
    22                 return "INTEGER DEFAULT nextval('%s') NOT NULL" % seqname 
    23         bytes = int(prop.hints.get('bytes', db.maxint_bytes)) 
    24         return self.int_type(bytes) 
    2514 
    2615 
     
    6857        self.table = table 
    6958        self.indices = PgIndexSet(self.table) 
     59     
     60    def __setitem__(self, key, column): 
     61        t = self.table 
     62        if key in self: 
     63            del self[key] 
     64         
     65        default = column.default or "" 
     66        if default: 
     67            if not(isinstance(default, str) 
     68                   and default.startswith("nextval(")): 
     69                default = t.db.adaptertosql.coerce(default) 
     70            default = " DEFAULT %s" % default 
     71         
     72        t.db.execute("ALTER TABLE %s ADD COLUMN %s %s%s;" % 
     73                     (t.qname, column.qname, column.dbtype, default)) 
     74        dict.__setitem__(self, key, column) 
    7075 
    7176 
     
    8994                             % tablename, conn=conn) 
    9095        table_OID = data[0][0] 
     96         
    9197        sql = ("SELECT attname, atttypid, attnum, attlen, atttypmod " 
    9298               "FROM pg_attribute WHERE attrelid = %s" % table_OID) 
     
    110116             
    111117            if dbtype in ('float4', 'float8', 'money', 'numeric'): 
    112                 c.hints['precision'] = row[4
     118                c.hints['precision'] = row[3
    113119             
    114120            # Default value 
     
    118124            if default: 
    119125                default = default[0][0] 
    120                 # Sequences 
    121126                if default.startswith("nextval("): 
    122                     default = None 
     127                    # Sequence. Strip the trailing "::text" if present 
     128                    default = default.replace("::text", "") 
    123129                else: 
    124130                    # adsrc is always a string, so we must cast 
     
    139145            cols.append(c) 
    140146        return cols 
    141          
    142         raise TypeError("Database type %s could not be converted " 
    143                         "to a Python type." % repr(dbtype)) 
    144147     
    145148    def _get_indices(self, tablename, conn=None): 
     
    171174    def python_type(self, dbtype): 
    172175        """Return a Python type which can store values of the given dbtype.""" 
    173         if "nextval(" in dbtype: 
    174             return int 
    175          
    176176        dbtype = dbtype.upper() 
    177177        if dbtype in ('INT2', 'INT4', 'INTEGER'): 
     
    181181        elif dbtype == 'INT8': 
    182182            return long 
    183         elif dbtype in ('FLOAT4', 'FLOAT8', 'MONEY', 'DOUBLE PRECISION'): 
     183        elif dbtype in ('FLOAT4', 'FLOAT8', 'MONEY', 'DOUBLE PRECISION', 'REAL'): 
    184184            return float 
    185185        elif dbtype.startswith('NUMERIC'): 
     
    197197            if dbtype.startswith(t): 
    198198                return str 
     199         
     200        raise TypeError("Database type %s could not be converted " 
     201                        "to a Python type." % repr(dbtype)) 
     202     
     203    def __setitem__(self, key, table): 
     204        if key in self: 
     205            del self[key] 
     206         
     207        fields = [] 
     208        for col in table.columns.itervalues(): 
     209            default = col.default or "" 
     210            if default: 
     211                if not(isinstance(default, str) 
     212                       and default.startswith("nextval(")): 
     213                    default = self.adaptertosql.coerce(default) 
     214                default = " DEFAULT %s" % default 
     215            f = '%s %s%s' % (col.qname, col.dbtype, default) 
     216            fields.append(f) 
     217         
     218        self.execute('CREATE TABLE %s (%s);' % 
     219                     (table.qname, ", ".join(fields))) 
     220         
     221        for index in table.columns.indices.itervalues(): 
     222            self.execute('CREATE INDEX %s ON %s (%s);' % 
     223                         (index.qname, table.qname, 
     224                          self.quote(index.colname))) 
     225         
     226        dict.__setitem__(self, key, table) 
    199227     
    200228    def quote(self, name): 
     
    266294     
    267295    databaseclass = PgDatabase 
    268     typeAdapter = FieldTypeAdapterPgSQL() 
    269296     
    270297    def __init__(self, name, arena, allOptions={}): 
     
    286313        for key in cls.properties: 
    287314            col = t.columns[key] 
    288             if 'nextval' in col.dbtype
     315            if isinstance(col.default, str) and col.default.startswith('nextval')
    289316                # Skip this field, since we're using a sequencer 
    290317                continue 
     
    313340        fields = [] 
    314341        for key in cls.properties: 
    315             dbtype = self.typeAdapter.coerce(cls, key) 
    316              
    317             col = self.db.make_column(cls.__name__, key, dbtype) 
    318              
    319             prop = cls.property(key) 
    320             col.default = prop.default 
    321             col.hints = prop.hints.copy() 
    322              
    323             # Here's where we differ from the superclass: 
    324             # we have to manually CREATE SEQUENCE, 
    325             # and use class attributes to do so. 
    326             if 'nextval' in dbtype: 
    327                 seqname = self.db.quote("%s_%s_seq" % (t.name, col.name)) 
    328                 self.db.execute("CREATE SEQUENCE %s START %s;" 
    329                                 % (seqname, cls.sequencer.initial)) 
    330              
    331342            # Use the superclass call to avoid ALTER TABLE. 
    332             dict.__setitem__(t.columns, key, col
     343            dict.__setitem__(t.columns, key, self.make_column(cls, key)
    333344             
    334345            if key in indices: 
     
    339350        # Attach to self.db, which should call CREATE TABLE. 
    340351        self.db[cls.__name__] = t 
    341  
     352     
     353    def make_column(self, cls, key): 
     354        dbtype = self.typeAdapter.coerce(cls, key) 
     355         
     356        col = self.db.make_column(cls.__name__, key, dbtype) 
     357        prop = getattr(cls, key) 
     358        col.default = prop.default 
     359         
     360        # Here's where we differ from the superclass: 
     361        # we have to manually CREATE SEQUENCE, 
     362        # and use class attributes to do so. 
     363        if key in cls.identifiers: 
     364            if isinstance(cls.sequencer, dejavu.UnitSequencerInteger): 
     365                tname = self.db.table_name(cls.__name__) 
     366                seqname = self.db.quote("%s_%s_seq" % (tname, col.name)) 
     367                col.default = "nextval('%s')" % seqname 
     368                self.db.execute("CREATE SEQUENCE %s START %s;" 
     369                                % (seqname, cls.sequencer.initial)) 
     370         
     371        col.hints = prop.hints.copy() 
     372        return col 
     373     
     374    def add_property(self, cls, name): 
     375        if not self.has_property(cls, name): 
     376            self.db[cls.__name__].columns[name] = self.make_column(cls, name) 
     377     
     378    def autoclass(self, table, newclassname=None): 
     379        """Create a Unit class automatically from this table and its columns.""" 
     380        class AutoUnitClass(dejavu.Unit): 
     381            pass 
     382        for cname, c in table.columns.iteritems(): 
     383            ptype = self.db.python_type(c.dbtype) 
     384            p = AutoUnitClass.set_property(cname, ptype) 
     385            if isinstance(c.default, str) and c.default.startswith("nextval("): 
     386                AutoUnitClass.identifiers += (cname,) 
     387                initial = self.db.fetch("SELECT min_value FROM %s" % 
     388                                        c.default[9:-2])[0][0] 
     389                AutoUnitClass.sequencer = dejavu.UnitSequencerInteger(int, initial) 
     390                p.default = None 
     391            else: 
     392                p.default = c.default 
     393            p.hints = c.hints.copy() 
     394         
     395        if newclassname is None: 
     396            newclassname = table.name 
     397        AutoUnitClass.__name__ = newclassname 
     398         
     399        return AutoUnitClass 
     400 
  • trunk/test/zoo_fixture.py

    r228 r229  
    5959    PreviousZoos = UnitProperty(list, hints={'bytes': 1000}) 
    6060    LastEscape = EscapeProperty(datetime.datetime) 
    61     Lifespan = UnitProperty(float, hints={'bytes': 4}) 
    62     Age = UnitProperty(float, hints={'bytes': 4}, default=1) 
     61    Lifespan = UnitProperty(float, hints={'precision': 4}) 
     62    Age = UnitProperty(float, hints={'precision': 4}, default=1) 
    6363    MotherID = UnitProperty(int) 
    6464 
     
    786786    def test_DB_Introspection(self): 
    787787        s = arena.stores.values()[0] 
    788         if not hasattr(s, "tables"): 
     788        if not hasattr(s, "db"): 
    789789            return 
    790790         
    791         zootable = s.tables['Zoo'] 
     791        zootable = s.db['Zoo'] 
    792792        cols = zootable.columns 
    793793        self.assertEqual(len(cols), 6) 
    794794        idcol = cols['ID'] 
    795795        # Since SQLite is typeless, it always returns 'str' 
    796         self.assert_(s.tables.python_type(idcol.dbtype) in (int, str)) 
    797         self.assertEqual(idcol.default, None) 
     796        self.assert_(s.db.python_type(idcol.dbtype) in (int, str)) 
    798797         
    799798        # Test the automatic construction of Unit classes. 
     
    801800            for cls in (Zoo, Animal): 
    802801                print cls.__name__, 
    803                 t = s.tables[cls.__name__] 
     802                t = s.db[cls.__name__] 
    804803                uc = s.autoclass(t, cls.__name__) 
    805804                self.assert_(not issubclass(uc, cls)) 
     
    814813                    for k, v in orig.hints.iteritems(): 
    815814                        if isinstance(v, (int, long)): 
    816                             self.assert_(copy.hints[k] >= v) 
     815                            self.assert_(copy.hints[k] >= v, 
     816                                         "%s not >= %s" % (copy.hints[k], v)) 
    817817                        else: 
    818818                            self.assertEqual(copy.hints[k], v) 
     
    961961    engines.register_classes(arena) 
    962962     
    963     if hasattr(arena.stores['testSM'], "tables"): 
     963    if hasattr(arena.stores['testSM'], "db"): 
    964964        arena.stores['testSM'].sync() 
    965965