Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

Changeset 360

Show
Ignore:
Timestamp:
12/18/06 06:21:04
Author:
fumanchu
Message:

MUCH work done on the firebird store. Now passes all tests but threading, transactions, and precision (and autosource, but only because FB has no bool type).

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/storage/storefirebird.py

    r359 r360  
    33 
    44import dejavu 
     5from dejavu import logic 
    56from dejavu.storage import db 
    67 
     
    1112 
    1213 
    13 ## {SMALLINT | INTEGER | FLOAT | DOUBLE PRECISION}[<array_dim>] 
    14 ## 
    15 ##| (DATE | TIME | TIMESTAMP} [<array_dim>] 
    16 ## 
    17 ##| {DECIMAL | NUMERIC} [(precision [, scale])] [<array_dim>] 
    18 ##| {CHAR | CHARACTER | CHARACTER VARYING | VARCHAR} [(int)] 
    19 ##[<array_dim>] [CHARACTER SET charname] 
    20 ## 
    21 ##| {NCHAR | NATIONAL CHARACTER | NATIONAL CHAR} 
    22 ##[VARYING] [(int)] [<array_dim>] 
    23 ## 
    24 ##| BLOB [SUB_TYPE {int | subtype_name}] [SEGMENT SIZE int] 
    25 ##[CHARACTER SET charname] 
    26 ## 
    27 ##| BLOB [(seglen [, subtype])]<array_dim> = [[x:]y [, [x:]y ]] 
    28  
    29 class FieldTypeAdapterFirebird(db.FieldTypeAdapter): 
     14class AdapterToFireBirdSQL(db.AdapterToSQL): 
     15     
     16    # Notice these are ordered pairs. Escape \ before introducing new ones. 
     17    # Values in these two lists should be strings encoded with self.encoding. 
     18    escapes = [("'", "''")] 
     19    like_escapes = [("\\", r"\\"), ("%", r"\%"), ("_", r"\_")] 
     20     
     21    # Firebird doesn't have true or false keywords. 
     22    bool_true = "1=1" 
     23    bool_false = "1=0" 
     24     
     25    def coerce_bool_to_any(self, value): 
     26        if value: 
     27            return '1' 
     28        return '0' 
     29 
     30 
     31class AdapterFromFirebirdDB(db.AdapterFromDB): 
     32     
     33    # !!?!??! kinterbasdb already converts to datetime? Rapture! 
     34    def coerce_any_to_datetime_datetime(self, value): 
     35        return value 
     36     
     37    def coerce_any_to_datetime_date(self, value): 
     38        return value 
     39     
     40    def coerce_any_to_datetime_time(self, value): 
     41        return value 
     42 
     43 
     44class TypeAdapterFirebird(db.TypeAdapter): 
    3045    """Return the SQL typename of a DB column.""" 
    3146     
    32     def coerce_str(self, cls, key): 
    33         prop = getattr(cls, key) 
    34         bytes = int(prop.hints.get('bytes', '0')) 
     47    # Max decimal precision for NUMERIC columns. 
     48    numeric_max_precision = 18 
     49    numeric_max_bytes = 9 
     50     
     51    def coerce_str(self, col): 
     52        # The bytes hint shall not reflect the usual 4-byte base for varchar. 
     53         
     54        # Although Firebird allows VARCHAR of 32765, 255 is usually the max 
     55        # for which an index can be created. 
     56        default = 127 
     57         
     58        bytes = int(col.hints.get('bytes', default)) 
    3559        if 1 <= bytes <= 32765: 
    3660            return "VARCHAR(%s)" % bytes 
    3761        return "BLOB" 
    3862     
    39     def coerce_bool(self, cls, key): 
     63    def coerce_bool(self, col): 
    4064        return "INTEGER" 
    4165 
    4266 
    43 class StorageManagerFirebird(db.StorageManagerDB): 
    44     """StoreManager to save and retrieve Units via Firebird 1.5.""" 
     67class FirebirdSQLDecompiler(db.SQLDecompiler): 
     68     
     69    # --------------------------- Dispatchees --------------------------- # 
     70     
     71    def attr_startswith(self, tos, arg): 
     72        return tos + " STARTING WITH " + arg 
     73     
     74    def attr_endswith(self, tos, arg): 
     75        return tos + " LIKE '%" + self.adapter.escape_like(arg) + "' ESCAPE '\\'" 
     76     
     77    def containedby(self, op1, op2): 
     78        if isinstance(op1, db.ConstWrapper): 
     79            # Looking for text in a field. Use Like (reverse terms). 
     80            return op2 + " LIKE '%" + self.adapter.escape_like(op1) + "%' ESCAPE '\\'" 
     81        else: 
     82            # Looking for field in (a, b, c) 
     83            atoms = [self.adapter.coerce(x) for x in op2.basevalue] 
     84            if atoms: 
     85                return op1 + " IN (" + ", ".join(atoms) + ")" 
     86            else: 
     87                # Nothing will match the empty list, so return none. 
     88                return self.adapter.bool_false 
     89     
     90    def dejavu_icontainedby(self, op1, op2): 
     91        if isinstance(op1, db.ConstWrapper): 
     92            return op2 + " CONTAINING " + op1 
     93        else: 
     94            # Looking for field in (a, b, c). 
     95            # Force all args to uppercase for case-insensitive comparison. 
     96            atoms = [self.adapter.coerce(x).upper() for x in op2.basevalue] 
     97            return "UPPER(%s) IN (%s)" % (op1, ", ".join(atoms)) 
     98     
     99    def dejavu_icontains(self, x, y): 
     100        return self.dejavu_icontainedby(y, x) 
     101     
     102    # Firebird has no LOWER function, but it does have an UPPER. Funky. 
     103     
     104    def dejavu_istartswith(self, x, y): 
     105        return "UPPER(" + x + ") LIKE '" + self.adapter.escape_like(y) + "%' ESCAPE '\\'" 
     106     
     107    def dejavu_iendswith(self, x, y): 
     108        return "UPPER(" + x + ") LIKE '%" + self.adapter.escape_like(y) + "' ESCAPE '\\'" 
     109     
     110    def dejavu_ieq(self, x, y): 
     111        return "UPPER(" + x + ") = UPPER(" + y + ")" 
     112     
     113    # Firebird 1.5 doesn't seem to have any date functions 
     114    dejavu_now = None 
     115    dejavu_today = None 
     116    dejavu_year = None 
     117    dejavu_month = None 
     118    dejavu_day = None 
     119     
     120    # Firebird 1.5 has no LENGTH function 
     121    func__builtin___len = None 
     122 
     123 
     124 
     125class FirebirdColumnSet(db.ColumnSet): 
     126     
     127    def __setitem__(self, key, column): 
     128        t = self.table 
     129        if key in self: 
     130            del self[key] 
     131         
     132        default = column.default or "" 
     133        if default: 
     134            default = " DEFAULT %s" % t.db.adaptertosql.coerce(default, column.dbtype) 
     135         
     136        t.db.lock("Adding property. Transactions not allowed.") 
     137        try: 
     138            # FB doesn't recognize the keyword "COLUMN" in "ADD". 
     139            t.db.execute("ALTER TABLE %s ADD %s %s%s;" % 
     140                         (t.qname, column.qname, column.dbtype, default)) 
     141            dict.__setitem__(self, key, column) 
     142        finally: 
     143            t.db.unlock() 
     144     
     145    def __delitem__(self, key): 
     146        if key in self.indices: 
     147            del self.indices[key] 
     148        t = self.table 
     149        t.db.lock("Dropping property. Transactions not allowed.") 
     150        try: 
     151            # FB doesn't recognize the keyword "COLUMN" in "DROP". 
     152            t.db.execute("ALTER TABLE %s DROP %s;" % 
     153                               (t.qname, self[key].qname)) 
     154            dict.__delitem__(self, key) 
     155        finally: 
     156            t.db.unlock() 
     157     
     158    def _rename(self, oldcol, newcol): 
     159        # Override this to do the actual rename at the DB level. 
     160        t = self.table 
     161        # FB doesn't use the keyword "RENAME". 
     162        t.db.execute("ALTER TABLE %s ALTER COLUMN %s TO %s;" % 
     163                     (t.qname, oldcol.qname, newcol.qname)) 
     164 
     165 
     166class FirebirdDatabase(db.Database): 
     167     
     168    decompiler = FirebirdSQLDecompiler 
     169    adaptertosql = AdapterToFireBirdSQL() 
     170    adapterfromdb = AdapterFromFirebirdDB() 
     171    typeadapter = TypeAdapterFirebird() 
     172    columnsetclass = FirebirdColumnSet 
    45173     
    46174    sql_name_max_length = 63 
    47     typeAdapter = FieldTypeAdapterFirebird() 
    48      
    49     def __init__(self, arena, allOptions={}): 
    50         # DSN = "host:database" 
    51         self.DSN = dsn = allOptions['DSN'] 
    52         self.dbname = dsn.split(":", 1)[-1] 
    53         self.user = allOptions['user'] 
    54         self.password = allOptions['password'] 
    55         self.encoding = allOptions.get('encoding', 'utf8') 
    56         db.StorageManagerDB.__init__(self, arena, allOptions) 
    57      
    58     def sql_name(self, name, quoted=True): 
    59         name = db.StorageManagerDB.sql_name(self, name, quoted) 
    60         if quoted: 
    61             name = '"' + name.replace('"', '""') + '"' 
    62         return name 
     175    encoding = 'utf8' 
     176     
     177    def _get_tables(self, conn=None): 
     178        data, _ = self.fetch("SELECT RDB$RELATION_NAME FROM RDB$RELATIONS " 
     179                             "WHERE RDB$SYSTEM_FLAG=0 AND RDB$VIEW_BLR IS NULL;", 
     180                             conn=conn) 
     181        return [db.Table(self, name, self.quote(name)) for name, in data] 
     182     
     183    def _get_columns(self, tablename, conn=None): 
     184        data, _ = self.fetch("SELECT RF.RDB$FIELD_NAME, T.RDB$TYPE_NAME, " 
     185                             "F.RDB$FIELD_LENGTH, RF.RDB$DEFAULT_VALUE, " 
     186                             "F.RDB$FIELD_PRECISION, F.RDB$FIELD_SCALE " 
     187                             "FROM RDB$RELATION_FIELDS RF LEFT JOIN " 
     188                             "RDB$FIELDS F ON (F.RDB$FIELD_NAME = RF.RDB$FIELD_SOURCE) " 
     189                             "LEFT JOIN RDB$TYPES T ON (T.RDB$TYPE = F.RDB$FIELD_TYPE) " 
     190                             "WHERE RF.RDB$RELATION_NAME = %s AND " 
     191                             "T.RDB$FIELD_NAME = 'RDB$FIELD_TYPE'" % tablename, 
     192                             conn=conn) 
     193        print tablename, data 
     194        cols = [] 
     195        for name, dbtype, fieldlen, default, prec, scale in data: 
     196            hints = {} 
     197            hints['precision'] = prec 
     198            hints['scale'] = scale 
     199            print hints 
     200            key = False 
     201            # Column(name, qname, dbtype, default=None, hints=None, key=False) 
     202            col = db.Column(name, self.quote(name), dbtype, default, hints, key) 
     203            cols.append(col) 
     204        return cols 
     205     
     206    def _get_indices(self, tablename, conn=None): 
     207        data, _ = self.fetch("SELECT RDB$INDEX_NAME, RDB$UNIQUE_FLAG " 
     208                             "FROM RDB$INDICES WHERE RDB$RELATION_NAME=%s " 
     209                             "AND RDB$FOREIGN_KEY IS NULL;" 
     210                             % self.quote(tablename), 
     211                             conn=conn) 
     212        return [db.Index(name, self.quote(name), tablename, colname, bool(unique)) 
     213                for name, unique in data] 
     214     
     215    def python_type(self, dbtype): 
     216        """Return a Python type which can store values of the given dbtype.""" 
     217        dbtype = dbtype.upper() 
     218         
     219        if dbtype in ('INTEGER', 'SMALLINT'): 
     220            return int 
     221        elif dbtype in ('BIGINT', ): 
     222            return long 
     223        elif dbtype in ('FLOAT', 'DOUBLE PRECISION', 'REAL'): 
     224            return float 
     225        elif dbtype.startswith('NUMERIC') or dbtype.startswith('DECIMAL'): 
     226            if db.decimal: 
     227                return db.decimal.Decimal 
     228            elif db.fixedpoint: 
     229                return db.fixedpoint.FixedPoint 
     230        elif dbtype == 'DATE': 
     231            return datetime.date 
     232        elif dbtype == 'TIMESTAMP': 
     233            return datetime.datetime 
     234        elif dbtype == 'TIME': 
     235            return datetime.time 
     236        for t in ('CHAR', 'VARCHAR', 'BLOB'): 
     237            if dbtype.startswith(t): 
     238                return str 
     239        for t in ('NCHAR', 'NATIONAL'): 
     240            if dbtype.startswith(t): 
     241                return unicode 
     242         
     243        raise TypeError("Database type %r could not be converted " 
     244                        "to a Python type." % dbtype) 
     245     
     246    def __setitem__(self, key, table): 
     247        if key in self: 
     248            del self[key] 
     249         
     250        fields = [] 
     251        pk = [] 
     252        for colkey, col in table.columns.iteritems(): 
     253            default = col.default or "" 
     254            if default: 
     255                default = " DEFAULT %s" % self.adaptertosql.coerce(default, col.dbtype) 
     256             
     257            notnull = "" 
     258            if col.key: 
     259                pk.append(col.qname) 
     260                # Firebird PK's must be NOT NULL 
     261                notnull = " NOT NULL" 
     262             
     263            fields.append('%s %s%s%s' % (col.qname, col.dbtype, default, notnull)) 
     264         
     265        if pk: 
     266            pk = ", PRIMARY KEY (%s)" % ", ".join(pk) 
     267        else: 
     268            pk = "" 
     269         
     270        self.lock("Creating storage. Transactions not allowed.") 
     271        try: 
     272            self.execute('CREATE TABLE %s (%s%s);' % 
     273                         (table.qname, ", ".join(fields), pk)) 
     274             
     275            for index in table.columns.indices.itervalues(): 
     276                self.execute('CREATE INDEX %s ON %s (%s);' % 
     277                             (index.qname, table.qname, 
     278                              self.quote(index.colname))) 
     279            dict.__setitem__(self, key, table) 
     280        finally: 
     281            self.unlock() 
     282     
     283    #                               Naming                               # 
     284     
     285    def quote(self, name): 
     286        """Return name, quoted for use in an SQL statement.""" 
     287        return '"' + name.replace('"', '""') + '"' 
    63288     
    64289    def _get_conn(self): 
     
    74299                conn = self.connection() 
    75300            if isinstance(query, unicode): 
    76                 query = query.encode(self.toAdapter.encoding) 
     301                query = query.encode(self.adaptertosql.encoding) 
    77302            self.log(query) 
    78303            cur = conn.cursor() 
    79304            cur.execute(query) 
    80305            conn.commit() 
    81 ##            conn.close() 
    82306        except Exception, x: 
    83307            x.args += (query,) 
     
    87311     
    88312    def fetch(self, query, conn=None): 
    89         """fetch(query, conn=None) -> rowdata, columns. 
    90          
    91         rowdata: a nested list (or tuples), column values within rows. 
    92         columns: a series of 2-tuples (or more). The first tuple value 
    93             will be the column name, the second value will be the column 
    94             type. 
     313        """Return rowdata, columns (name, type) for the given query. 
     314         
     315        query should be a SQL query in string format 
     316        rowdata will be an iterable of iterables containing the result values. 
     317        columns will be an iterable of (column name, data type) pairs. 
    95318        """ 
    96319        try: 
     
    98321                conn = self.connection() 
    99322            if isinstance(query, unicode): 
    100                 query = query.encode(self.toAdapter.encoding) 
     323                query = query.encode(self.adaptertosql.encoding) 
    101324            self.log(query) 
    102325            cur = conn.cursor() 
    103326            cur.execute(query) 
     327             
    104328            data = cur.fetchall() 
    105329            desc = cur.description 
     
    112336        return data, desc 
    113337     
     338    #                               Schemas                               # 
     339     
     340    def create_database(self): 
     341        self.lock("Creating database. Transactions not allowed.") 
     342        try: 
     343            # Firebird DB 'names' are actually filesystem paths. 
     344            sql = ("CREATE DATABASE %s USER '%s' PASSWORD '%s';" 
     345                   % (self.qname, self.user, self.password)) 
     346             
     347            # Use the kinterbasdb helper methods for cleaner create and drop. 
     348            # We also use dialect 3 *always* to help with quoted identifiers. 
     349            conn = kinterbasdb.create_database(sql, 3) 
     350            conn.close() 
     351             
     352            self.clear() 
     353        finally: 
     354            self.unlock() 
     355     
     356    def drop_database(self): 
     357        self.lock("Dropping database. Transactions not allowed.") 
     358        try: 
     359            # Must shut down all connections to avoid 
     360            # "being accessed by other users" error. 
     361            self.connection.shutdown() 
     362             
     363            if os.path.exists(self.qname): 
     364                c = self._get_conn() 
     365                c.drop_database() 
     366             
     367            self.clear() 
     368        finally: 
     369            self.unlock() 
     370     
     371    #                            Transactions                             # 
     372     
     373    def start(self): 
     374        """Start a transaction. Not needed if self.implicit_trans is True.""" 
     375        conn = self.get_transaction(new=True) 
     376        conn.begin() 
     377     
     378    def rollback(self): 
     379        """Roll back the current transaction, if any.""" 
     380        key = self.transaction_key() 
     381        try: 
     382            conn = self.transactions.pop(key) 
     383        except KeyError: 
     384            pass 
     385        else: 
     386            conn.rollback() 
     387     
     388    def commit(self): 
     389        """Commit the current transaction, if any.""" 
     390        key = self.transaction_key() 
     391        try: 
     392            conn = self.transactions.pop(key) 
     393        except KeyError: 
     394            pass 
     395        else: 
     396            conn.commit() 
     397 
     398 
     399 
     400class FirebirdJoinWrapper(db.UnitClassWrapper): 
     401    """Unit class wrapper, for use in parsing multiselect joins.""" 
     402     
     403    def _joinname(self): 
     404        if self.alias: 
     405            # Firebird doesn't use the "AS" keyword 
     406            return "%s %s" % (self.table.qname, self.alias) 
     407        else: 
     408            return self.table.qname 
     409    joinname = property(_joinname, doc=("Quoted table name for use in " 
     410                                        "JOIN clause (read-only).")) 
     411 
     412 
     413class StorageManagerFirebird(db.StorageManagerDB): 
     414    """StoreManager to save and retrieve Units via Firebird 1.5.""" 
     415     
     416    databaseclass = FirebirdDatabase 
     417    joinwrapper = FirebirdJoinWrapper 
     418     
     419    def __init__(self, arena, allOptions={}): 
     420        # DSN = "host:database" 
     421        self.host, self.dbname = allOptions['DSN'].split(":", 1) 
     422        allOptions['name'] = self.dbname 
     423        db.StorageManagerDB.__init__(self, arena, allOptions) 
     424     
    114425    def version(self): 
    115         return "KInterbasDB Version: %s" % repr(kinterbasdb.__version__) 
    116      
    117     #                               Schemas                               # 
    118      
    119     def create_database(self): 
    120         # Firebird DB 'names' are actually filesystem paths. 
    121         sql = ("CREATE DATABASE '%s' USER '%s' PASSWORD '%s';" 
    122                % (self.dbname, self.user, self.password)) 
    123          
    124         # Use the kinterbasdb helper methods for cleaner create and drop. 
    125         # We also use dialect 3 *always* to help with quoted identifiers. 
    126         conn = kinterbasdb.create_database(sql, 3) 
    127         conn.close() 
    128      
    129     def drop_database(self): 
    130         if os.path.exists(self.dbname): 
    131             # Close any open connections. 
    132             self.shutdown() 
    133              
    134             c = self._get_conn() 
    135             c.drop_database() 
     426        import kinterbasdb.services 
     427        svcCon = kinterbasdb.services.connect(host=self.host, 
     428                                              user=self.db.user, 
     429                                              password=self.db.password) 
     430##        conn = self.db._get_conn() 
     431        return ("KInterbasDB Version: %r\nServer Version: %r" 
     432                % (kinterbasdb.__version__, svcCon.getServerVersion())) 
     433     
     434    def multiselect(self, classes, expr): 
     435        """Return an SQL SELECT statement, an imperfect flag, and column names.""" 
     436         
     437        # Create a new unitjoin tree where each class is wrapped. 
     438        # Then we can tag the wrappers with metadata with impunity. 
     439         
     440        # Firebird 1.5 won't accept the same table twice in a JOIN 
     441        # unless *both* table names are aliased. 
     442        # seen = {} 
     443        aliascount = [0] 
     444        q = lambda name: self.db.quote(self.db.table_name(name)) 
     445         
     446        def wrap(unitjoin): 
     447            cls1, cls2 = unitjoin.class1, unitjoin.class2 
     448            if isinstance(cls1, dejavu.UnitJoin): 
     449                wclass1 = wrap(cls1) 
     450            else: 
     451                wclass1 = self.joinwrapper(cls1, self.db[cls1.__name__]) 
     452                aliascount[0] += 1 
     453                wclass1.alias = q("t%d" % aliascount[0]) 
     454            if isinstance(cls2, dejavu.UnitJoin): 
     455                wclass2 = wrap(cls2) 
     456            else: 
     457                wclass2 = self.joinwrapper(cls2, self.db[cls2.__name__]) 
     458                aliascount[0] += 1 
     459                wclass2.alias = q("t%d" % aliascount[0]) 
     460            uj = dejavu.UnitJoin(wclass1, wclass2, unitjoin.leftbiased) 
     461            # if the unitjoin had a custom association path, set it on 
     462            # the new UnitJoin instance 
     463            uj.path = unitjoin.path 
     464            return uj 
     465        classes = wrap(classes) 
     466         
     467        joins = self.join(classes) 
     468         
     469        if expr is None: 
     470            expr = logic.Expression(lambda *args: True) 
     471         
     472        wheretables = [] 
     473        for c in classes: 
     474            alias = getattr(c, "alias", None) 
     475            if alias is None: 
     476                t = self.db[c.__name__] 
     477                qname = t.qname 
     478            else: 
     479                # c is an instance of self.joinwrapper 
     480                t = c.table 
     481                qname = c.alias or t.qname 
     482             
     483            wheretables.append((qname, t)) 
     484         
     485        w, imp = self.db.where(wheretables, expr) 
     486         
     487        cols = [] 
     488        colnames = [] 
     489        for wrapper in classes: 
     490            c, names = wrapper.columns() 
     491            cols.extend(c) 
     492            colnames.extend(names) 
     493         
     494        statement = ("SELECT %s FROM %s WHERE %s" % 
     495                     (', '.join(colnames), joins, w)) 
     496        return statement, imp, cols 
     497