Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

Changeset 189

Show
Ignore:
Timestamp:
03/09/06 08:48:26
Author:
fumanchu
Message:

Ugly first crack at database introspection. I don't like the API yet, but the functionality is getting there.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/storage/db.py

    r188 r189  
    14311431        if self.has_property(cls, name): 
    14321432            clsname = cls.__name__ 
    1433             self.execute('DROP INDEX %s ON %s;' % 
    1434                          (self.table_name("i" + clsname + name), 
    1435                           self.table_name(clsname))) 
     1433            if self.has_index(cls, name): 
     1434                self.drop_index(cls, name) 
    14361435            self.execute("ALTER TABLE %s DROP COLUMN %s;" % 
    14371436                         (self.table_name(clsname), 
     
    14451444            self.execute("ALTER TABLE %s RENAME COLUMN %s TO %s;" % 
    14461445                         (self.table_name(clsname), oldname, newname)) 
     1446     
     1447    def has_index(self, cls, name): 
     1448        return name in [i.name for i in self.get_indices(cls.__name__)] 
     1449     
     1450    def drop_index(self, cls, name): 
     1451        clsname = cls.__name__ 
     1452        self.execute('DROP INDEX %s ON %s;' % 
     1453                     (self.sql_name("i" + clsname + name), 
     1454                      self.table_name(clsname))) 
     1455 
     1456 
     1457class Table: 
     1458    """A table in a database.""" 
     1459     
     1460    def __init__(self, name): 
     1461        self.name = name 
     1462        self.columns = [] 
     1463     
     1464    def __repr__(self): 
     1465        return "dejavu.db.Table(%s)" % repr(self.name) 
     1466 
     1467 
     1468class Column: 
     1469    """A column in a table in a database.""" 
     1470     
     1471    def __init__(self, key, type, default=None): 
     1472        self.key = key 
     1473        self.type = type 
     1474        self.default = default 
     1475        self.hints = {} 
     1476     
     1477    def __repr__(self): 
     1478        return ("dejavu.db.Column(%s, %s, default=%s, hints=%s)" % 
     1479                (repr(self.key), repr(self.type), 
     1480                 repr(self.default), repr(self.hints)) 
     1481                ) 
     1482 
     1483 
     1484class Index: 
     1485    """An index on a table column (or columns) in a database.""" 
     1486     
     1487    def __init__(self, name, tablename, colname, pk=True, unique=True): 
     1488        self.name = name 
     1489        self.tablename = tablename 
     1490        self.colname = colname 
     1491        self.pk = pk 
     1492        self.unique = unique 
     1493     
     1494    def __repr__(self): 
     1495        return ("dejavu.db.Index(%s, %s, %s, pk=%s, unique=%s)" % 
     1496                (repr(self.name), repr(self.tablename), repr(self.colname), 
     1497                 repr(self.pk), repr(self.unique))) 
    14471498 
    14481499 
  • trunk/storage/storeado.py

    r188 r189  
    5353# 12/30/1899, the zero-Date for ADO = 693594 
    5454zeroHour = datetime.date(1899, 12, 30).toordinal() 
     55 
     56# DataTypeEnum 
     57adEmpty = 0 
     58adSmallInt = 2 
     59adInteger = 3 
     60adSingle = 4 
     61adDouble = 5 
     62adCurrency = 6 
     63adDate = 7 
     64adBSTR = 8 
     65adIDispatch = 9 
     66adError = 10 
     67adBoolean = 11 
     68adVariant = 12 
     69adIUnknown = 13 
     70adDecimal = 14 
     71adTinyInt = 16 
     72adUnsignedTinyInt = 17 
     73adUnsignedSmallInt = 18 
     74adUnsignedInt = 19 
     75adBigInt = 20 
     76adUnsignedBigInt = 21 
     77adGUID = 72 # e.g. {E5D50A9B-33D2-11D3-AAB3-00104BA31425} 
     78adBinary = 128 
     79adChar = 129 
     80adWChar = 130 
     81adNumeric = 131 
     82adUserDefined = 132 
     83adDBDate = 133 
     84adDBTime = 134 
     85adDBTimeStamp = 135 
     86adVarChar = 200 
     87adLongVarChar = 201 
     88adVarWChar = 202 
     89adLongVarWChar = 203 
     90adVarBinary = 204 
     91adLongVarBinary = 205 
    5592 
    5693 
     
    400437     
    401438    def has_storage(self, cls): 
    402         data, col_defs = self.get_tables() 
    403         names = [x[2] for x in data] 
     439        names = [t.name for t in self.get_tables()] 
    404440        return self.table_name(cls.__name__, quoted=False) in names 
    405441     
     
    420456            raise 
    421457     
    422     def drop_property(self, cls, name): 
    423         if self.has_property(cls, name): 
    424             clsname = cls.__name__ 
    425              
    426             tablename = self.table_name(clsname, quoted=False) 
    427             qtablename = self.table_name(clsname) 
    428             colname = self.column_name(clsname, name, quoted=False) 
    429             qcolname = self.column_name(clsname, name) 
    430              
    431             data, cols = self.get_indices() 
    432             for i in data: 
    433                 if i[2] == tablename and i[17] == colname: 
    434                     # The INDEX_NAME may include a trailing " ASC" or other data 
    435                     self.execute('DROP INDEX [%s] ON %s;' % (i[5], qtablename)) 
    436             self.execute("ALTER TABLE %s DROP COLUMN %s;" % 
    437                          (qtablename, qcolname)) 
     458    def drop_index(self, cls, name): 
     459        clsname = cls.__name__ 
     460        tablename = self.table_name(clsname, quoted=False) 
     461        qtablename = self.table_name(clsname) 
     462        colname = self.column_name(clsname, name, quoted=False) 
     463         
     464        for i in self.get_indices(): 
     465            if i.tablename == tablename and i.colname == colname: 
     466                # The INDEX_NAME may include a trailing " ASC" or other data 
     467                self.execute('DROP INDEX [%s] ON %s;' % (i.name, qtablename)) 
    438468     
    439469    def get_tables(self, conn=None): 
    440         return self.fetch(adSchemaTables, conn=conn, schema=True) 
    441      
    442     def get_columns(self, conn=None): 
    443         return self.fetch(adSchemaColumns, conn=conn, schema=True) 
    444      
    445     def get_indices(self, conn=None): 
    446         # returned cols will be 
     470        # cols will be 
     471        # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 
     472        # (u'TABLE_TYPE', 202), (u'TABLE_GUID', 72), (u'DESCRIPTION', 203), 
     473        # (u'TABLE_PROPID', 19), (u'DATE_CREATED', 7), (u'DATE_MODIFIED', 7)] 
     474        data, cols = self.fetch(adSchemaTables, conn=conn, schema=True) 
     475        return [db.Table(row[2]) for row in data] 
     476     
     477    def get_columns(self, tablename, conn=None): 
     478        # cols will be 
     479        # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 
     480        # (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72), (u'COLUMN_PROPID', 19), 
     481        # (u'ORDINAL_POSITION', 19), (u'COLUMN_HASDEFAULT', 11), 
     482        # (u'COLUMN_DEFAULT', 203), (u'COLUMN_FLAGS', 19), (u'IS_NULLABLE', 11), 
     483        # (u'DATA_TYPE', 18), (u'TYPE_GUID', 72), (u'CHARACTER_MAXIMUM_LENGTH', 19), 
     484        # (u'CHARACTER_OCTET_LENGTH', 19), (u'NUMERIC_PRECISION', 18), 
     485        # (u'NUMERIC_SCALE', 2), (u'DATETIME_PRECISION', 19), 
     486        # (u'CHARACTER_SET_CATALOG', 202), (u'CHARACTER_SET_SCHEMA', 202), 
     487        # (u'CHARACTER_SET_NAME', 202), (u'COLLATION_CATALOG', 202), 
     488        # (u'COLLATION_SCHEMA', 202), (u'COLLATION_NAME', 202), 
     489        # (u'DOMAIN_CATALOG', 202), (u'DOMAIN_SCHEMA', 202), 
     490        # (u'DOMAIN_NAME', 202), (u'DESCRIPTION', 203)] 
     491        data, _ = self.fetch(adSchemaColumns, conn=conn, schema=True) 
     492        cols = [] 
     493        for row in data: 
     494            # I tried passing criteria to OpenSchema, but passing None is 
     495            # not the same as passing pythoncom.Empty (which errors). 
     496            if tablename and row[2] != tablename: 
     497                continue 
     498            datatype = row[11] 
     499            c = db.Column(row[3], None, row[8]) 
     500            if datatype in (adDate, adDBDate): 
     501                c.type = datetime.date 
     502            elif datatype == adDBTime: 
     503                c.type = datetime.time 
     504            elif datatype == adDBTimeStamp: 
     505                c.type = datetime.datetime 
     506            elif datatype in (adSmallInt, adInteger, adTinyInt, 
     507                              adUnsignedTinyInt, adUnsignedSmallInt, 
     508                              adUnsignedInt): 
     509                c.type = int 
     510                c.hints['bytes'] = row[15] 
     511            elif datatype == adBoolean: 
     512                c.type = bool 
     513            elif datatype in (adBigInt, adUnsignedBigInt): 
     514                c.type = long 
     515                c.hints['bytes'] = row[15] 
     516            elif datatype in (adSingle, adDouble, adCurrency): 
     517                c.type = float 
     518                c.hints['bytes'] = row[15] 
     519            elif datatype in (adDecimal, adNumeric): 
     520                c.type = decimal.Decimal 
     521                c.hints['bytes'] = row[15] 
     522            elif datatype in (adBSTR, adVariant, adBinary, adChar, 
     523                              adVarChar, adLongVarChar, 
     524                              adVarBinary, adLongVarBinary): 
     525                c.type = str 
     526                if row[13]: 
     527                    c.hints['bytes'] = row[13] 
     528            elif datatype in (adWChar, adVarWChar, adLongVarWChar): 
     529                c.type = unicode 
     530                if row[13]: 
     531                    c.hints['bytes'] = row[13] 
     532            cols.append(c) 
     533        return cols 
     534     
     535    def get_indices(self, tablename=None, conn=None): 
     536        # cols will be 
    447537        # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 
    448538        # (u'INDEX_CATALOG', 202), (u'INDEX_SCHEMA', 202), (u'INDEX_NAME', 202), 
     
    453543        # (u'COLUMN_PROPID', 19), (u'COLLATION', 2), (u'CARDINALITY', 21), 
    454544        # (u'PAGES', 3), (u'FILTER_CONDITION', 202), (u'INTEGRATED', 11)] 
    455         return self.fetch(adSchemaIndexes, conn=conn, schema=True) 
     545        data, _ = self.fetch(adSchemaIndexes, conn=conn, schema=True) 
     546        indices = [] 
     547        for row in data: 
     548            # I tried passing criteria to OpenSchema, but passing None is 
     549            # not the same as passing pythoncom.Empty (which errors). 
     550            if tablename and row[2] != tablename: 
     551                continue 
     552            indices.append(db.Index(row[5], row[2], row[17], row[6], row[7])) 
     553        return indices 
    456554 
    457555 
  • trunk/storage/storemysql.py

    r188 r189  
    1616import warnings 
    1717import datetime 
     18 
     19try: 
     20    # Builtin in Python 2.5? 
     21    decimal 
     22except NameError: 
     23    try: 
     24        # Module in Python 2.3, 2.4 
     25        import decimal 
     26    except ImportError: 
     27        pass 
    1828 
    1929import dejavu 
     
    159169         
    160170        self.decompiler = MySQLDecompiler 
    161         # Try to get the version string from MySQL, to see if we need 
     171        # Get the version string from MySQL, to see if we need 
    162172        # a different decompiler. 
    163173        conn = self._template_conn() 
    164         data, columns = self.fetch("SELECT VERSION();", conn) 
    165         if data: 
    166             version = storage.Version(data[0][0]) 
    167             if version > storage.Version("4.1.1"): 
    168                 self.decompiler = MySQLDecompiler411 
     174        rowdata, cols = self.fetch("SELECT version();", conn) 
    169175        conn.close() 
     176        v = rowdata[0][0] 
     177        self._version = storage.Version(v) 
     178        if self._version > storage.Version("4.1.1"): 
     179            self.decompiler = MySQLDecompiler411 
    170180     
    171181    def sql_name(self, name, quoted=True): 
     
    211221     
    212222    def version(self): 
    213         conn = self._template_conn() 
    214         rowdata, cols = self.fetch("SELECT version();", conn) 
    215         conn.close() 
    216         return "MySQL Version: %s" % rowdata[0][0] 
     223        return "MySQL Version: %s" % self._version 
    217224     
    218225    def _seq_UnitSequencerInteger(self, unit): 
     
    329336                         (self.table_name(clsname), oldcolname, newcolname, 
    330337                          self.typeAdapter.coerce(cls, newname))) 
    331  
     338     
     339    def drop_index(self, cls, name): 
     340        # MySQL might rename multiple-column indices to "PRIMARY" 
     341        clsname = cls.__name__ 
     342        names = [] 
     343        for i in self.get_indices(self.table_name(clsname, quoted=False)): 
     344            if i.name not in names: 
     345                names.append(i.name) 
     346        for n in names: 
     347            self.execute('DROP INDEX %s ON %s;' % 
     348                         (self.sql_name(n), self.table_name(clsname))) 
     349     
     350    def get_tables(self, conn=None): 
     351        data, _ = self.fetch("SHOW TABLES FROM %s" % self.dbname, 
     352                             conn=conn) 
     353        return [db.Table(row[0]) for row in data] 
     354     
     355    def get_columns(self, tablename=None, conn=None): 
     356        # cols are: Field, Type, Null, Key, Default, Extra. 
     357        # See http://dev.mysql.com/doc/refman/4.1/en/describe.html 
     358        data, _ = self.fetch("SHOW COLUMNS FROM %s.%s" 
     359                             % (self.dbname, self.sql_name(tablename)), 
     360                             conn=conn) 
     361        cols = [] 
     362        for row in data: 
     363            c = db.Column(row[0], None, row[4]) 
     364             
     365            dbtype = row[1] 
     366            parenpos = dbtype.find("(") 
     367            if parenpos > -1: 
     368                c.hints['bytes'] = dbtype[parenpos+1:-1] 
     369                dbtype = dbtype[:parenpos] 
     370             
     371            if dbtype in ('tinyint', 'smallint', 'mediumint', 'int', 'integer'): 
     372                c.type = int 
     373            elif dbtype == 'bigint': 
     374                c.type = long 
     375            elif dbtype in ('float', 'double', 'real'): 
     376                c.type = float 
     377            elif dbtype in ('decimal', 'numeric'): 
     378                c.type = decimal.Decimal 
     379            elif dbtype == 'date': 
     380                c.type = datetime.date 
     381            elif dbtype in ('datetime', 'timestamp'): 
     382                c.type = datetime.datetime 
     383            elif dbtype == 'time': 
     384                c.type = datetime.time 
     385            elif dbtype in ('char', 'varchar', 'binary', 'varbinary', 
     386                            'tinyblob', 'tinytext', 'blob', 'text', 
     387                            'mediumblob', 'mediumtext', 
     388                            'longblob', 'longtext'): 
     389                c.type = str 
     390            cols.append(c) 
     391        return cols 
     392     
     393    def get_indices(self, tablename, conn=None): 
     394        indices = [] 
     395        try: 
     396            # cols are: Table, Non_unique, Key_name, Seq_in_index, Column_name, 
     397            # Collation, Cardinality, Sub_part, Packed, Null, Index_type, Comment 
     398            data, _ = self.fetch("SHOW INDEX FROM %s.%s" 
     399                                 % (self.dbname, self.sql_name(tablename)), 
     400                                 conn=conn) 
     401        except _mysql.ProgrammingError, x: 
     402            if x.args[0] != 1146: 
     403                raise 
     404        else: 
     405            for row in data: 
     406                indices.append(db.Index(row[2], row[0], row[4], None, not row[1])) 
     407        return indices 
     408 
  • trunk/storage/storepypgsql.py

    r188 r189  
    11# Use libpq directly to avoid all of the DB-API overhead. 
    22from pyPgSQL import libpq 
     3 
    34import datetime 
     5 
     6try: 
     7    # Builtin in Python 2.5? 
     8    decimal 
     9except NameError: 
     10    try: 
     11        # Module in Python 2.3, 2.4 
     12        import decimal 
     13    except ImportError: 
     14        pass 
     15 
    416import dejavu 
    517from dejavu.storage import db 
     
    192204            self.execute('CREATE INDEX %s ON %s (%s);' % 
    193205                         (i, tablename, self.column_name(clsname, index))) 
    194  
     206     
     207    def drop_index(self, cls, name): 
     208        clsname = cls.__name__ 
     209        for i in self.get_indices(clsname): 
     210            if i.colname == name: 
     211                self.execute('DROP INDEX %s;' % self.sql_name(i.name)) 
     212     
     213    def get_tables(self, conn=None): 
     214        data, _ = self.fetch("SELECT tablename FROM pg_tables WHERE " 
     215                             "schemaname not in ('information_schema', 'pg_catalog')", 
     216                             conn=conn) 
     217        return [db.Table(row[0]) for row in data] 
     218     
     219    def get_columns(self, tablename=None, conn=None): 
     220        data, _ = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'" 
     221                             % tablename, conn=conn) 
     222        table_OID = data[0][0] 
     223        sql = ("SELECT attname, atttypid, attnum, attlen " 
     224               "FROM pg_attribute WHERE attrelid = %s" % table_OID) 
     225        data, _ = self.fetch(sql, conn=conn) 
     226        cols = [] 
     227        for row in data: 
     228            name = row[0] 
     229            if name in ('tableoid', 'cmax', 'xmax', 'cmin', 'xmin', 
     230                        'oid', 'ctid'): 
     231                # This is a column which PostgreSQL defines automatically 
     232                continue 
     233             
     234            # Data type 
     235            dbtype, _ = self.fetch("SELECT typname, typlen FROM pg_type " 
     236                                    "WHERE oid = %s" % row[1]) 
     237            if dbtype: 
     238                dbtype = dbtype[0][0] 
     239                if dbtype in ('int2', 'int4'): 
     240                    dbtype = int 
     241                elif dbtype == 'bool': 
     242                    dbtype = bool 
     243                elif dbtype == 'int8': 
     244                    dbtype = long 
     245                elif dbtype in ('float4', 'float8', 'money'): 
     246                    dbtype = float 
     247                elif dbtype == 'numeric': 
     248                    dbtype = decimal.Decimal 
     249                elif dbtype == 'date': 
     250                    dbtype = datetime.date 
     251                elif dbtype in ('timestamp', 'timestamptz'): 
     252                    dbtype = datetime.datetime 
     253                elif dbtype in ('time', 'timetz'): 
     254                    dbtype = datetime.time 
     255                elif dbtype in ('char', 'varchar', 'bpchar', 'text'): 
     256                    dbtype = str 
     257            else: 
     258                dbtype = None 
     259             
     260            # Default value 
     261            default, _ = self.fetch("SELECT adsrc FROM pg_attrdef " 
     262                                    "WHERE adnum = %s AND adrelid = %s" 
     263                                    % (row[2], table_OID)) 
     264            if default: 
     265                default = default[0][0] 
     266                if default.startswith("nextval("): 
     267                    default = None 
     268            else: 
     269                default = None 
     270             
     271            c = db.Column(row[0], dbtype, default) 
     272             
     273            bytes = row[3] 
     274            if bytes > 0: 
     275                c.hints['bytes'] = bytes 
     276             
     277            cols.append(c) 
     278        return cols 
     279     
     280    def get_indices(self, tablename, conn=None): 
     281        # Get the OID of the parent table. 
     282        data, _ = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'" 
     283                             % tablename, conn=conn) 
     284        if not data: 
     285            return [] 
     286         
     287        table_OID = data[0][0] 
     288        indices = [] 
     289        data, _ = self.fetch("SELECT pg_class.relname, indkey, indisprimary, " 
     290                             "indisunique FROM pg_index LEFT JOIN pg_class " 
     291                             "ON pg_index.indexrelid = pg_class.oid WHERE " 
     292                             "pg_index.indrelid = %s" % table_OID, conn=conn) 
     293        for row in data: 
     294            cols = map(int, row[1].split(" ")) 
     295            for col in cols: 
     296                d, _ = self.fetch("SELECT attname FROM pg_attribute " 
     297                                  "WHERE attrelid = %s AND attnum = %s" 
     298                                  % (table_OID, col), conn=conn) 
     299                indices.append(db.Index(row[0], tablename, d[0][0], 
     300                                        bool(row[2]), bool(row[3]))) 
     301         
     302        return indices 
     303 
  • trunk/storage/storesqlite.py

    r188 r189  
    430430        altermap[newname] = oldname 
    431431        self._legacy_alter_table(cls, altermap) 
    432  
     432     
     433    def drop_index(self, cls, name): 
     434        clsname = cls.__name__ 
     435        self.execute('DROP INDEX %s ON %s;' % 
     436                     (self.sql_name("i" + clsname + name), 
     437                      self.table_name(clsname))) 
     438     
     439    def get_tables(self, conn=None): 
     440        data, _ = self.fetch("SELECT name FROM sqlite_master WHERE type = 'table'") 
     441        return [db.Table(row[0]) for row in data] 
     442     
     443    def get_columns(self, tablename=None, conn=None): 
     444        data, coldefs = self.fetch("SELECT * FROM %s WHERE 1 == 0" 
     445                                   % self.sql_name(tablename), conn=conn) 
     446        cols = [] 
     447        for col in coldefs: 
     448            c = db.Column(col[0], str, None) 
     449            cols.append(c) 
     450        return cols 
     451     
     452    def get_indices(self, tablename, conn=None): 
     453        data, _ = self.fetch("SELECT name, tbl_name, sql FROM sqlite_master " 
     454                          "WHERE type = 'index'") 
     455        indices = [] 
     456        for row in data: 
     457            colname = row[2].split("(")[-1] 
     458            colname = colname[1:-2] 
     459            indices.append(db.Index(row[0], row[1], colname)) 
     460        return indices 
  • trunk/test/zoo_fixture.py

    r188 r189  
    725725##        box.commit_since("rollback point name") 
    726726     
     727    def test_DB_Introspection(self): 
     728        s = arena.stores.values()[0] 
     729        if getattr(s, "get_tables", None) is None: 
     730            return 
     731         
     732        tables = s.get_tables() 
     733        for t in tables: 
     734##            print t 
     735##            for c in s.get_columns(t.name): 
     736##                print "   ", c 
     737##            for i in s.get_indices(t.name): 
     738##                print "   ", i 
     739            if t.name.lower() == "djvzoo": 
     740                zootable = t 
     741        self.assertEqual(zootable.name.lower(), "djvzoo") 
     742        cols = s.get_columns(zootable.name) 
     743        self.assertEqual(len(cols), 6) 
     744         
     745        cols = dict([(x.key.lower(), x) for x in cols]) 
     746        idcol = cols['id'] 
     747        # Since SQLite is typless, it will set all types to 'str' 
     748        self.assert_(idcol.type in (int, str)) 
     749        self.assertEqual(idcol.default, None) 
     750     
    727751    def testzzzz_Schema_Upgrade(self): 
    728752        # Must run last.