Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

Changeset 11

Show
Ignore:
Timestamp:
02/16/07 23:56:58
Author:
fumanchu
Message:

Split Schema from Database. Table objects now reference a parent schema, not a parent db. New Database.multischema attribute. Changed providers.providers to providers.registry. Inlined SelectWriter?.quote.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/geniusql/__init__.py

    r10 r11  
    2121from dejavu import logic 
    2222 
    23 from geniusql import xray 
    2423from geniusql import errors, typerefs 
    2524 
     
    3231 
    3332 
    34 def db(cls, name, options): 
    35     """Create a Database model object for the given cls, name, and options. 
    36      
    37     cls: Either a subclass of geniusql.Database or a 'shortcut name' 
    38         registered in geniusql.providers.providers. 
    39     name: the database name as used by the underlying database. 
    40      
    41     This function does not call CREATE DATABASE, nor should it open any 
    42     database connections. It simply instantiates the proper Database class. 
     33def db(provider, **options): 
     34    """Return a Database and Schema object for the given provider. 
     35     
     36    provider: A 'shortcut name' registered in geniusql.providers.registry. 
     37     
     38    This function does not call CREATE DATABASE (although it may open a 
     39    database connection). 
    4340    """ 
    44     if isinstance(cls, basestring): 
    45         try: 
    46             cls = providers.providers[cls] 
    47         except KeyError: 
    48             pass 
    49      
    50     if isinstance(cls, basestring): 
    51         cls = xray.classes(cls) 
    52      
    53     opts = dict([(str(k), v) for k, v in options.iteritems()]) 
    54     opts.pop('name', None) 
    55      
    56     return cls(name, **opts) 
     41    return providers.registry.open(provider, **options) 
    5742 
    5843 
     
    9681        Index object. 
    9782        """ 
     83        if oldname == newname: 
     84            return 
    9885        obj = self[oldname] 
    9986        if newname in self: 
    10087            dict.__delitem__(self, newname) 
     88        dict.__delitem__(self, oldname) 
    10189        dict.__setitem__(self, newname, obj) 
    10290     
    10391    def __setitem__(self, key, index): 
    104         """Drop the specified index.""" 
     92        """Create the specified index.""" 
    10593        t = self.table 
    10694        if t.created: 
    107             t.db.execute_ddl('CREATE INDEX %s ON %s (%s);' % 
    108                              (index.qname, t.qname, 
    109                               t.db.quote(index.colname))) 
     95            t.schema.db.execute_ddl('CREATE INDEX %s ON %s (%s);' % 
     96                                    (index.qname, t.qname, 
     97                                     t.schema.db.quote(index.colname))) 
    11098        dict.__setitem__(self, key, index) 
    11199     
     
    114102        t = self.table 
    115103        if t.created: 
    116             t.db.execute_ddl('DROP INDEX %s ON %s;' % 
     104            t.schema.db.execute_ddl('DROP INDEX %s ON %s;' % 
    117105                             (self[key].qname, t.qname)) 
    118106        dict.__delitem__(self, key) 
     
    132120    imperfect_type: if True, signals that we are deliberately using a 
    133121        database type other than the default (usually in order to handle 
    134         irregular values, such as huge numbers). When comparing database 
    135         values with constant values in SQL, such columns must have an 
    136         explicit adaptertosql.cast_dbtype_to_pytype method to cast 
    137         the column value to one which can successfully be compared 
    138         with the constant. If there is no matching cast_* method, 
    139         then the query will be marked imperfect. 
     122        irregular values, such as huge numbers). When comparing imperfect 
     123        column values with constant values in SQL, the database must be 
     124        able to cast the column value to the constant's type. If that 
     125        cannot be done for the given types, then the query will be marked 
     126        imperfect. 
    140127    autoincrement: if True, uses the database's built-in sequencing. 
    141128    sequence_name: for databases that use separate statements to create and 
     
    192179    name: the SQL name for this table (unquoted). 
    193180    qname: the SQL name for this table (quoted). 
    194     db: the database for this table. If None (the default), then changes to 
    195         Table items can be made with impunity. If not None, then appropriat
    196         ALTER TABLE commands are executed whenever a consumer adds or deletes 
    197         items from the Table, or calls methods like 'rename'. Therefore, 
    198         when creating Table objects from an existing database, you should 
    199         set the 'db' arg late
     181    schema: the schema for this table. 
     182    created: whether or not this Table has a concrete implementation in th
     183        database. If False (the default), then changes to Table items can be 
     184        made with impunity. If True, then appropriate ALTER TABLE commands 
     185        are executed whenever a consumer adds or deletes items from the 
     186        Table, or calls methods like 'rename'
    200187    indices: a dict-like IndexSet of Index objects. 
    201188    references: a dict of the form: {name: (nearColKey, farTableKey, farColKey)}. 
    202189    """ 
    203190     
    204     indexsetclass = IndexSet 
    205      
    206     def __new__(cls, name, qname, db, created=False): 
     191    def __new__(cls, name, qname, schema, created=False): 
    207192        return dict.__new__(cls) 
    208193     
    209     def __init__(self, name, qname, db, created=False): 
     194    def __init__(self, name, qname, schema, created=False): 
    210195        dict.__init__(self) 
    211196         
    212197        self.name = name 
    213198        self.qname = qname 
    214         self.db = db 
     199        self.schema = schema 
    215200        self.created = created 
    216201         
    217         self.indices = self.indexsetclass(self) 
     202        self.indices = schema.indexsetclass(self) 
    218203        self.references = {} 
    219204     
     
    226211    def __copy__(self): 
    227212        # Don't set 'created' when copying! 
    228         newtable = self.__class__(self.name, self.qname, self.db
     213        newtable = self.__class__(self.name, self.qname, self.schema
    229214        for key, c in self.iteritems(): 
    230215            dict.__setitem__(newtable, key, c.copy()) 
     
    242227        Column object. 
    243228        """ 
     229        if oldname == newname: 
     230            return 
     231         
    244232        obj = self[oldname] 
    245233        if newname in self: 
    246234            dict.__delitem__(self, newname) 
     235        dict.__delitem__(self, oldname) 
    247236        dict.__setitem__(self, newname, obj) 
    248237     
    249238    def _add_column(self, column): 
    250239        """Internal function to add the column to the database.""" 
    251         coldef = self.db.columnclause(column) 
    252         self.db.execute("ALTER TABLE %s ADD COLUMN %s;" % (self.qname, coldef)) 
     240        coldef = self.schema.columnclause(column) 
     241        self.schema.db.execute("ALTER TABLE %s ADD COLUMN %s;" % 
     242                               (self.qname, coldef)) 
    253243     
    254244    def __setitem__(self, key, column): 
    255245        if column.name is None: 
    256             column.name = self.db._column_name(self.name, key) 
    257             column.qname = self.db.quote(column.name) 
     246            column.name = self.schema._column_name(self.name, key) 
     247            column.qname = self.schema.db.quote(column.name) 
    258248         
    259249        if not self.created: 
     
    266256        if column.autoincrement: 
    267257            # This may or may not be a no-op, depending on the DB. 
    268             self.db.create_sequence(self, column) 
     258            self.schema.create_sequence(self, column) 
    269259        self._add_column(column) 
    270260        dict.__setitem__(self, key, column) 
     
    272262    def _drop_column(self, column): 
    273263        """Internal function to drop the column from the database.""" 
    274         self.db.execute_ddl("ALTER TABLE %s DROP COLUMN %s;" % 
    275                             (self.qname, column.qname)) 
     264        self.schema.db.execute_ddl("ALTER TABLE %s DROP COLUMN %s;" % 
     265                                   (self.qname, column.qname)) 
    276266     
    277267    def __delitem__(self, key): 
     
    287277        if column.autoincrement: 
    288278            # This may or may not be a no-op, depending on the DB. 
    289             self.db.drop_sequence(column) 
     279            self.schema.drop_sequence(column) 
    290280        dict.__delitem__(self, key) 
    291281     
    292282    def _rename(self, oldcol, newcol): 
    293283        # Override this to do the actual rename at the DB level. 
    294         self.db.execute_ddl("ALTER TABLE %s RENAME COLUMN %s TO %s;" % 
    295                             (self.qname, oldcol.qname, newcol.qname)) 
     284        self.schema.db.execute_ddl("ALTER TABLE %s RENAME COLUMN %s TO %s;" % 
     285                                   (self.qname, oldcol.qname, newcol.qname)) 
    296286     
    297287    def rename(self, oldkey, newkey): 
     
    305295         
    306296        oldname = oldcol.name 
    307         newname = self.db._column_name(self.name, newkey) 
     297        newname = self.schema._column_name(self.name, newkey) 
    308298         
    309299        if oldname != newname: 
    310300            newcol = oldcol.copy() 
    311301            newcol.name = newname 
    312             newcol.qname = self.db.quote(newname) 
     302            newcol.qname = self.schema.db.quote(newname) 
    313303            self._rename(oldcol, newcol) 
    314304         
     
    325315        """ 
    326316        colname = self[columnkey].name 
    327         name = self.db.table_name("i" + self.name + colname) 
    328         i = Index(name, self.db.quote(name), self.name, colname) 
     317        name = self.schema.table_name("i" + self.name + colname) 
     318        i = Index(name, self.schema.db.quote(name), self.name, colname) 
    329319        self.indices[columnkey] = i 
    330320        return i 
     
    339329        """ 
    340330        tpair = [(self.qname, self)] 
    341         decom = self.db.decompiler(tpair, logic.filter(**inputs), 
    342                                    self.db.adaptertosql) 
     331        decom = self.schema.db.decompiler(tpair, logic.filter(**inputs), 
     332                                          self.schema.db.adaptertosql) 
    343333        code = decom.code() 
    344334        if decom.imperfect: 
     
    356346    def insert(self, **inputs): 
    357347        """Insert a row and return {idcolkey: newid}.""" 
    358         coerce_out = self.db.adaptertosql.coerce 
    359         coerce_in = self.db.adapterfromdb.coerce 
     348        coerce_out = self.schema.db.adaptertosql.coerce 
     349        coerce_in = self.schema.db.adapterfromdb.coerce 
    360350         
    361351        fields = [] 
     
    372362                values.append(val) 
    373363         
    374         conn = self.db.connections.get() 
     364        conn = self.schema.db.connections.get() 
    375365         
    376366        fields = ", ".join(fields) 
    377367        values = ", ".join(values) 
    378         self.db.execute('INSERT INTO %s (%s) VALUES (%s);' % 
    379                         (self.qname, fields, values), conn) 
     368        self.schema.db.execute('INSERT INTO %s (%s) VALUES (%s);' % 
     369                               (self.qname, fields, values), conn) 
    380370         
    381371        if idkeys: 
     
    395385        """Update a row using the given inputs.""" 
    396386        parms = [] 
    397         coerce = self.db.adaptertosql.coerce 
     387        coerce = self.schema.db.adaptertosql.coerce 
    398388        for key, val in inputs.iteritems(): 
    399389            col = self[key] 
     
    408398            sql = ('UPDATE %s SET %s WHERE %s;' % 
    409399                   (self.qname, ", ".join(parms), self.id_clause(**inputs))) 
    410             self.db.execute(sql) 
     400            self.schema.db.execute(sql) 
    411401     
    412402    use_asterisk_to_delete_all = False 
     
    418408        else: 
    419409            star = "" 
    420         self.db.execute('DELETE%s FROM %s WHERE %s;' % 
    421                         (star, self.qname, self.id_clause(**inputs))) 
     410        self.schema.db.execute('DELETE%s FROM %s WHERE %s;' % 
     411                               (star, self.qname, self.id_clause(**inputs))) 
    422412     
    423413    def delete_all(self, **inputs): 
     
    427417        else: 
    428418            star = "" 
    429         self.db.execute('DELETE%s FROM %s WHERE %s;' % 
    430                         (star, self.qname, self.whereclause(**inputs))) 
     419        self.schema.db.execute('DELETE%s FROM %s WHERE %s;' % 
     420                               (star, self.qname, 
     421                                self.whereclause(**inputs))) 
    431422     
    432423    def select_all(self, restriction=None, **kwargs): 
     
    442433         
    443434        attrs = self.keys() 
    444         data = self.db.select(self, attrs, restriction) 
     435        data = self.schema.db.select(self, attrs, restriction) 
    445436        for row in data: 
    446437            row = dict(zip(attrs, row)) 
     
    466457 
    467458 
    468 class Database(dict): 
     459class Schema(dict): 
    469460    """A dict for managing a set of tables. 
    470461     
     
    475466    refer to each table. 
    476467     
    477     When a consumer adds and deletes items from a Database object, 
     468    When a consumer adds and deletes items from a Schema object, 
    478469    appropriate CREATE TABLE/DROP TABLE commands are executed. 
    479470    This means that a Table object to be added should have all 
    480     of its columns populated before adding it to the Database
     471    of its columns populated before adding it to the Schema
    481472    """ 
    482473     
    483     adaptertosql = AdapterToSQL() 
    484     adapterfromdb = AdapterFromDB() 
    485     typeadapter = TypeAdapter() 
    486      
    487     decompiler = SQLDecompiler 
    488     joinwrapper = TableWrapper 
    489      
    490     selectwriter = SelectWriter 
    491474    tableclass = Table 
    492     connectionmanager = ConnectionManager 
    493      
    494     def __new__(cls, name, **kwargs): 
     475    indexsetclass = IndexSet 
     476     
     477    def __new__(cls, db, name): 
    495478        return dict.__new__(cls) 
    496479     
    497     def __init__(self, name, **kwargs): 
     480    def __init__(self, db, name): 
     481        dict.__init__(self) 
     482         
     483        self.db = db 
     484        self.name = self.db.sql_name(name) 
     485        self.qname = self.db.quote(self.name) 
    498486        self._discover_lock = threading.Lock() 
    499          
    500         dict.__init__(self) 
    501         for k, v in kwargs.iteritems(): 
    502             setattr(self, k, v) 
    503          
    504         self.name = self.sql_name(name) 
    505         self.qname = self.quote(self.name) 
    506          
    507         poolsize = kwargs.get('poolsize', 10) 
    508         self.connections = self.connectionmanager(self, poolsize) 
    509          
    510487        self.discover_dbinfo() 
    511488     
     
    513490        name = getattr(self, "name", "<unknown>") 
    514491        return "%s.%s(%r)" % (self.__module__, self.__class__.__name__, name) 
    515      
    516     def version(self): 
    517         """Return a string containing version info for this database.""" 
    518         raise NotImplementedError 
    519      
    520     def log(self, msg): 
    521         pass 
    522      
    523492     
    524493    #                              Discovery                              # 
     
    568537        added to self using keys that match the database's names. 
    569538        Consumers should call the "alias(oldname, newname)" method 
    570         of Database, Table, and IndexSet in order to re-map the 
     539        of Schema, Table, and IndexSet in order to re-map the 
    571540        discovered objects using consumer-friendly names. 
    572541         
     
    576545        try: 
    577546            table = self._get_table(tablename) 
    578              
    579547            self._discover_table(table, conn) 
    580548             
     
    594562        added to self using keys that match the database's names. 
    595563        Consumers should call the "alias(oldname, newname)" method 
    596         of Database, Table, and IndexSet in order to re-map the 
     564        of Schema, Table, and IndexSet in order to re-map the 
    597565        discovered objects using consumer-friendly names. 
    598566         
     
    623591        Table object. 
    624592        """ 
     593        if oldname == newname: 
     594            return 
     595         
    625596        obj = self[oldname] 
    626597        if newname in self: 
    627598            dict.__delitem__(self, newname) 
     599        dict.__delitem__(self, oldname) 
    628600        dict.__setitem__(self, newname, obj) 
     601     
     602    def _column_name(self, tablename, columnkey): 
     603        "Return the SQL column name for the given table name and column key." 
     604        # If you want to use a map from your ORM's property names 
     605        # to DB column names, override this method (that's why 
     606        # the tablename must be included in the args). 
     607        return self.db.sql_name(columnkey) 
     608     
     609    def column(self, pytype=unicode, dbtype=None, default=None, hints=None, 
     610               key=False, autoincrement=False): 
     611        """Return a Column object from the given arguments.""" 
     612        col = Column(pytype, dbtype, default, hints, key) 
     613        col.autoincrement = autoincrement 
     614         
     615        if dbtype is None: 
     616            col.dbtype = self.db.typeadapter.coerce(col, pytype) 
     617        pytype2 = self.db.python_type(col.dbtype) 
     618        col.imperfect_type = not self.db.isrelatedtype(pytype, pytype2) 
     619         
     620        return col 
     621     
     622    prefix = "" 
     623     
     624    def table_name(self, key): 
     625        """Return the SQL table name for the given key.""" 
     626        # If you want to use a map from your ORM's class names 
     627        # to DB table names, override this method. 
     628        return self.db.sql_name(self.prefix + key) 
     629     
     630    def table(self, name): 
     631        """Create and return a Table object for the given name.""" 
     632        name = self.table_name(name) 
     633        return self.tableclass(name, self.db.quote(name), self) 
     634     
     635    def create_sequence(self, table, column): 
     636        """Create a SEQUENCE for the given column and set its sequence_name.""" 
     637        # By default, this does nothing. Databases which require a separate 
     638        # statement to create a sequence generator should override this. 
     639        pass 
     640     
     641    def drop_sequence(self, column): 
     642        """Drop a SEQUENCE for the given column and remove its sequence_name.""" 
     643        # By default, this does nothing. Databases which require a separate 
     644        # statement to drop a sequence generator should override this. 
     645        pass 
     646     
     647    def columnclause(self, column): 
     648        """Return a clause for the given column for CREATE or ALTER TABLE. 
     649         
     650        This will be of the form "name type [DEFAULT x]". 
     651         
     652        Most subclasses will override this to add autoincrement support. 
     653        """ 
     654        dbtype = column.dbtype 
     655         
     656        default = column.default or "" 
     657        if default: 
     658            default = self.db.adaptertosql.coerce(default, dbtype) 
     659            default = " DEFAULT %s" % default 
     660         
     661        return "%s %s%s" % (column.qname, dbtype, default) 
     662     
     663    def __setitem__(self, key, table): 
     664        if key in self: 
     665            del self[key] 
     666         
     667        # Set table.created to True, which should "turn on" 
     668        # any future ALTER TABLE statements. 
     669        table.created = True 
     670         
     671        fields = [] 
     672        pk = [] 
     673        for column in table.itervalues(): 
     674            if column.autoincrement: 
     675                # This may or may not be a no-op, depending on the DB. 
     676                self.create_sequence(table, column) 
     677             
     678            fields.append(self.columnclause(column)) 
     679            if column.key: 
     680                pk.append(column.qname) 
     681         
     682        if pk: 
     683            pk = ", PRIMARY KEY (%s)" % ", ".join(pk) 
     684        else: 
     685            pk = "" 
     686         
     687        self.db.execute_ddl('CREATE TABLE %s (%s%s);' % 
     688                            (table.qname, ", ".join(fields), pk)) 
     689         
     690        for index in table.indices.itervalues(): 
     691            self.db.execute_ddl('CREATE INDEX %s ON %s (%s);' % 
     692                                (index.qname, table.qname, 
     693                                 self.db.quote(index.colname))) 
     694         
     695        dict.__setitem__(self, key, table) 
     696     
     697    def __delitem__(self, key): 
     698        table = self[key] 
     699        self.db.execute_ddl('DROP TABLE %s;' % table.qname) 
     700        for col in table.itervalues(): 
     701            if col.autoincrement: 
     702                self.drop_sequence(col) 
     703        dict.__delitem__(self, key) 
     704     
     705    def _rename(self, oldtable, newtable): 
     706        # Override this to do the actual rename at the DB level. 
     707        raise NotImplementedError 
     708        newtable.created = True 
     709     
     710    def rename(self, oldkey, newkey): 
     711        """Rename a Table.""" 
     712        oldtable = self[oldkey] 
     713        oldname = oldtable.name 
     714        newname = self.db.table_name(newkey) 
     715         
     716        if oldname != newname: 
     717            newtable = oldtable.copy() 
     718            newtable.schema = self.schema 
     719            newtable.name = newname 
     720            newtable.qname = self.db.quote(newname) 
     721            self._rename(oldtable, newname) 
     722         
     723        # Use the superclass calls to avoid DROP TABLE/CREATE TABLE. 
     724        dict.__delitem__(self, oldkey) 
     725        dict.__setitem__(self, newkey, newtable) 
     726     
     727    def create_database(self): 
     728        self.db.execute_ddl("CREATE DATABASE %s;" % self.qname) 
     729        self.clear() 
     730     
     731    def drop_database(self): 
     732        # Must shut down all connections to avoid 
     733        # "being accessed by other users" error. 
     734        self.db.connections.shutdown() 
     735        self.db.execute_ddl("DROP DATABASE %s;" % self.qname) 
     736        self.clear() 
     737 
     738 
     739class Database(object): 
     740     
     741    adaptertosql = AdapterToSQL() 
     742    adapterfromdb = AdapterFromDB() 
     743    typeadapter = TypeAdapter() 
     744    decompiler = SQLDecompiler 
     745    joinwrapper = TableWrapper 
     746    selectwriter = SelectWriter 
     747    connectionmanager = ConnectionManager 
     748    schemaclass = Schema 
     749     
     750    multischema = True 
     751    multischema__doc = """If True, instances of this Database class may 
     752    spawn multiple Schema instances. This is False, for example, when 
     753    the underlying Database engine binds connections to individual files. 
     754    In most applications (that use a single schema) this presents no 
     755    problems; applications that need to handle more than one schema 
     756    at a time should inspect this value to determine whether they 
     757    need a separate Database instance per Schema instance. 
     758    """ 
     759     
     760    def __init__(self, **kwargs): 
     761        for k, v in kwargs.iteritems(): 
     762            setattr(self, k, v) 
     763         
     764        poolsize = kwargs.get('poolsize', 10) 
     765        self.connections = self.connectionmanager(self, poolsize) 
     766     
     767    def version(self): 
     768        """Return a string containing version info for this database.""" 
     769        raise NotImplementedError 
     770     
     771    def log(self, msg): 
     772        pass 
    629773     
    630774    def python_type(self, dbtype): 
     
    660804        return False 
    661805     
    662      
    663     #                              Container                              # 
    664      
    665     def columnclause(self, column): 
    666         """Return a clause for the given column for CREATE or ALTER TABLE. 
    667          
    668         This will be of the form "name type [DEFAULT x]". 
    669          
    670         Most subclasses will override this to add autoincrement support. 
    671         """ 
    672         dbtype = column.dbtype 
    673          
    674         default = column.default or "" 
    675         if default: 
    676             default = self.adaptertosql.coerce(default, dbtype) 
    677             default = " DEFAULT %s" % default 
    678          
    679         return "%s %s%s" % (column.qname, dbtype, default) 
    680      
    681     def create_sequence(self, table, column): 
    682         """Create a SEQUENCE for the given column and set its sequence_name.""" 
    683         # By default, this does nothing. Databases which require a separate 
    684         # statement to create a sequence generator should override this. 
    685         pass 
    686      
    687     def drop_sequence(self, column): 
    688         """Drop a SEQUENCE for the given column and remove its sequence_name.""" 
    689         # By default, this does nothing. Databases which require a separate 
    690         # statement to drop a sequence generator should override this. 
    691         pass 
    692      
    693     def __setitem__(self, key, table): 
    694         if key in self: 
    695             del self[key] 
    696          
    697         # Set table.created to True, which should "turn on" 
    698         # any future ALTER TABLE statements. 
    699         table.created = True 
    700          
    701         fields = [] 
    702         pk = [] 
    703         for column in table.itervalues(): 
    704             if column.autoincrement: 
    705                 # This may or may not be a no-op, depending on the DB. 
    706                 self.create_sequence(table, column) 
    707              
    708             fields.append(self.columnclause(column)) 
    709             if column.key: 
    710                 pk.append(column.qname) 
    711          
    712         if pk: 
    713             pk = ", PRIMARY KEY (%s)" % ", ".join(pk) 
    714         else: 
    715             pk = "" 
    716          
    717         self.execute_ddl('CREATE TABLE %s (%s%s);' % 
    718                          (table.qname, ", ".join(fields), pk)) 
    719          
    720         for index in table.indices.itervalues(): 
    721             self.execute_ddl('CREATE INDEX %s ON %s (%s);' % 
    722                              (index.qname, table.qname, 
    723                               self.quote(index.colname))) 
    724          
    725         dict.__setitem__(self, key, table) 
    726      
    727     def __delitem__(self, key): 
    728         table = self[key] 
    729         self.execute_ddl('DROP TABLE %s;' % table.qname) 
    730         for col in table.itervalues(): 
    731             if col.autoincrement: 
    732                 self.drop_sequence(col) 
    733         dict.__delitem__(self, key) 
    734      
    735     def _rename(self, oldtable, newtable): 
    736         # Override this to do the actual rename at the DB level. 
    737         raise NotImplementedError 
    738         newtable.created = True 
    739      
    740     def rename(self, oldkey, newkey): 
    741         """Rename a Table.""" 
    742         oldtable = self[oldkey] 
    743         oldname = oldtable.name 
    744         newname = self.table_name(newkey) 
    745          
    746         if oldname != newname: 
    747             newtable = oldtable.copy() 
    748             newtable.db = self 
    749             newtable.name = newname 
    750             newtable.qname = self.quote(newname) 
    751             self._rename(oldtable, newname) 
    752          
    753         # Use the superclass calls to avoid DROP TABLE/CREATE TABLE. 
    754         dict.__delitem__(self, oldkey) 
    755         dict.__setitem__(self, newkey, newtable) 
    756      
    757806    #                               Naming                               # 
    758807     
    759808    sql_name_max_length = 64 
    760809    sql_name_caseless = False 
    761     Prefix = "" 
    762810     
    763811    def quote(self, name): 
     
    779827         
    780828        return key 
    781      
    782     def _column_name(self, tablename, columnkey): 
    783         "Return the SQL column name for the given table name and column key." 
    784         # If you want to use a map from your ORM's property names 
    785         # to DB column names, override this method (that's why 
    786         # the tablename must be included in the args). 
    787         return self.sql_name(columnkey) 
    788      
    789     def column(self, pytype=unicode, dbtype=None, default=None, hints=None, 
    790                key=False, autoincrement=False): 
    791         """Return a Column object from the given arguments.""" 
    792         col = Column(pytype, dbtype, default, hints, key) 
    793         col.autoincrement = autoincrement 
    794          
    795         if dbtype is None: 
    796             col.dbtype = self.typeadapter.coerce(col, pytype) 
    797         pytype2 = self.python_type(col.dbtype) 
    798         col.imperfect_type = not self.isrelatedtype(pytype, pytype2) 
    799          
    800         return col 
    801      
    802     def table_name(self, key): 
    803         """Return the SQL table name for the given key.""" 
    804         # If you want to use a map from your ORM's class names 
    805         # to DB table names, override this method. 
    806         return self.sql_name(self.Prefix + key) 
    807      
    808     def table(self, name): 
    809         """Create and return a Table object for the given name.""" 
    810         name = self.table_name(name) 
    811         return self.tableclass(name, self.quote(name), self) 
    812829     
    813830    def is_timeout_error(self, exc): 
     
    867884        return ResultSet(data, sel.columns, sel.imperfect) 
    868885     
    869     def create_database(self): 
    870         self.execute_ddl("CREATE DATABASE %s;" % self.qname) 
    871         self.clear() 
    872      
    873     def drop_database(self): 
    874         # Must shut down all connections to avoid 
    875         # "being accessed by other users" error. 
    876         self.connections.shutdown() 
    877         self.execute_ddl("DROP DATABASE %s;" % self.qname) 
    878         self.clear() 
     886    def schema(self, name): 
     887        return self.schemaclass(self, name) 
    879888 
    880889 
     
    901910            val = row[i] 
    902911            if table and col: 
    903                 val = table.db.adapterfromdb.coerce(val, col.dbtype, col.pytype) 
     912                val = table.schema.db.adapterfromdb.coerce(val, col.dbtype, col.pytype) 
    904913            coerced_row.append(val) 
    905914        return coerced_row 
  • trunk/geniusql/conn.py

    r10 r11  
    165165        self.db = db 
    166166        self.poolsize = poolsize 
    167         if poolsize > 0: 
     167        self._set_factory() 
     168     
     169    def _set_factory(self): 
     170        if self.poolsize > 0: 
    168171            self._factory = ConnectionPool(self._get_conn, self._del_conn, 
    169                                            poolsize) 
     172                                           self.poolsize) 
    170173        else: 
    171174            self._factory = ConnectionFactory(self._get_conn, self._del_conn) 
  • trunk/geniusql/providers/__init__.py

    r10 r11  
    22 
    33import re 
     4from geniusql import xray 
    45 
    56 
     
    4041 
    4142 
    42 providers = { 
     43class _Registry(dict): 
     44     
     45    def open(self, key, **kwargs): 
     46        opener = self[key] 
     47        if isinstance(opener, basestring): 
     48            opener = xray.attributes(opener) 
     49        return opener(**kwargs) 
     50 
     51registry = _Registry({ 
    4352    "access": "geniusql.providers.ado.MSAccessDatabase", 
    4453    "msaccess": "geniusql.providers.ado.MSAccessDatabase", 
     
    5867    "sqlserver": "geniusql.providers.ado.SQLServerDatabase", 
    5968    "mssql": "geniusql.providers.ado.SQLServerDatabase", 
    60     } 
     69    }) 
  • trunk/geniusql/providers/ado.py

    r10 r11  
    363363    def _add_column(self, column): 
    364364        """Internal function to add the column to the database.""" 
    365         coldef = self.db.columnclause(column) 
     365        coldef = self.schema.columnclause(column) 
    366366        # SQL Server doesn't use the "COLUMN" keyword with "ADD" 
    367         self.db.execute_ddl("ALTER TABLE %s ADD %s;" % (self.qname, coldef)) 
     367        self.schema.db.execute_ddl("ALTER TABLE %s ADD %s;" % 
     368                                   (self.qname, coldef)) 
    368369     
    369370    def _rename(self, oldcol, newcol): 
    370         conn = self.db.connections.get() 
     371        conn = self.schema.db.connections.get() 
    371372        try: 
    372373            cat = win32com.client.Dispatch(r'ADOX.Catalog') 
     
    446447 
    447448 
    448 class ADODatabase(geniusql.Database): 
    449      
    450     decompiler = ADOSQLDecompiler 
    451     adapterfromdb = AdapterFromADO() 
     449class ADOSchema(geniusql.Schema): 
     450     
    452451    tableclass = ADOTable 
    453      
    454     def __init__(self, name, **kwargs): 
    455         geniusql.Database.__init__(self, name, **kwargs) 
    456         self.connections.Connect = self.Connect 
    457452     
    458453    #                              Discovery                              # 
     
    463458        # (u'TABLE_TYPE', 202), (u'TABLE_GUID', 72), (u'DESCRIPTION', 203), 
    464459        # (u'TABLE_PROPID', 19), (u'DATE_CREATED', 7), (u'DATE_MODIFIED', 7)] 
    465         data, _ = self.fetch(adSchemaTables, conn=conn, schema=True) 
    466         return [self.tableclass(str(row[2]), self.quote(str(row[2])), 
     460        data, _ = self.db.fetch(adSchemaTables, conn=conn, schema=True) 
     461        return [self.tableclass(str(row[2]), self.db.quote(str(row[2])), 
    467462                                self, created=True) 
    468463                for row in data 
     
    475470        # (u'TABLE_TYPE', 202), (u'TABLE_GUID', 72), (u'DESCRIPTION', 203), 
    476471        # (u'TABLE_PROPID', 19), (u'DATE_CREATED', 7), (u'DATE_MODIFIED', 7)] 
    477         data, _ = self.fetch(adSchemaTables, conn=conn, schema=True) 
     472        data, _ = self.db.fetch(adSchemaTables, conn=conn, schema=True) 
    478473        for row in data: 
    479474            name = str(row[2]) 
    480475            if name == tablename: 
    481                 return self.tableclass(name, self.quote(name), 
     476                return self.tableclass(name, self.db.quote(name), 
    482477                                       self, created=True) 
    483478        raise errors.MappingError(tablename) 
     
    495490        # (u'COLUMN_PROPID', 19), (u'COLLATION', 2), (u'CARDINALITY', 21), 
    496491        # (u'PAGES', 3), (u'FILTER_CONDITION', 202), (u'INTEGRATED', 11)] 
    497         data, _ = self.fetch(adSchemaIndexes, conn=conn, schema=True) 
     492        data, _ = self.db.fetch(adSchemaIndexes, conn=conn, schema=True) 
    498493        pknames = [row[17] for row in data 
    499494                   if (tablename == row[2]) and row[6]] 
     
    512507        # (u'DOMAIN_CATALOG', 202), (u'DOMAIN_SCHEMA', 202), 
    513508        # (u'DOMAIN_NAME', 202), (u'DESCRIPTION', 203)] 
    514         data, _ = self.fetch(adSchemaColumns, conn=conn, schema=True) 
     509        data, _ = self.db.fetch(adSchemaColumns, conn=conn, schema=True) 
    515510         
    516511        cols = [] 
     
    524519            default = row[8] 
    525520            if default is not None: 
    526                 deftype = self.python_type(dbtype) 
     521                deftype = self.db.python_type(dbtype) 
    527522                if issubclass(deftype, (int, long)): 
    528523                    # We may have stuck extraneous quotes in the default 
     
    533528             
    534529            name = str(row[3]) 
    535             c = geniusql.Column(self.python_type(dbtype), dbtype, 
     530            c = geniusql.Column(self.db.python_type(dbtype), dbtype, 
    536531                                default, hints={}, key=(name in pknames), 
    537                                 name=name, qname=self.quote(name)) 
     532                                name=name, qname=self.db.quote(name)) 
    538533             
    539534            # This only works for SQL Server. The MSAccessDatabase will 
     
    592587        # (u'COLUMN_PROPID', 19), (u'COLLATION', 2), (u'CARDINALITY', 21), 
    593588        # (u'PAGES', 3), (u'FILTER_CONDITION', 202), (u'INTEGRATED', 11)] 
    594         data, _ = self.fetch(adSchemaIndexes, conn=conn, schema=True) 
     589        data, _ = self.db.fetch(adSchemaIndexes, conn=conn, schema=True) 
    595590        indices = [] 
    596591        for row in data: 
     
    599594            if tablename and row[2] != tablename: 
    600595                continue 
    601             i = geniusql.Index(row[5], self.quote(row[5]), row[2], row[17], row[7]) 
     596            i = geniusql.Index(row[5], self.db.quote(row[5]), 
     597                               row[2], row[17], row[7]) 
    602598            indices.append(i) 
    603599        return indices 
     600     
     601    #                              Container                              # 
     602     
     603    def _rename(self, oldtable, newtable): 
     604        conn = self.db.connections.get() 
     605        try: 
     606            cat = win32com.client.Dispatch(r'ADOX.Catalog') 
     607            cat.ActiveConnection = conn 
     608            cat.Tables(oldtable.name).Name = newtable.name 
     609        finally: 
     610            conn = None 
     611            cat = None 
     612 
     613 
     614class ADODatabase(geniusql.Database): 
     615     
     616    decompiler = ADOSQLDecompiler 
     617    adapterfromdb = AdapterFromADO() 
     618     
     619    def __init__(self, **kwargs): 
     620        geniusql.Database.__init__(self, **kwargs) 
     621        self.connections.Connect = self.Connect 
    604622     
    605623    def python_type(self, dbtype): 
     
    641659                        "to a Python type." % dbtype) 
    642660     
    643      
    644     #                              Container                              # 
    645      
    646     def _rename(self, oldtable, newtable): 
    647         conn = self.connections.get() 
    648         try: 
    649             cat = win32com.client.Dispatch(r'ADOX.Catalog') 
    650             cat.ActiveConnection = conn 
    651             cat.Tables(oldtable.name).Name = newtable.name 
    652         finally: 
    653             conn = None 
    654             cat = None 
    655661     
    656662    #                               Naming                                # 
     
    861867     
    862868    def _rename(self, oldcol, newcol): 
    863         self.db.execute_ddl("EXEC sp_rename '%s.%s', '%s', 'COLUMN'" % 
    864                             (self.name, oldcol.name, newcol.name)) 
     869        self.schema.db.execute_ddl("EXEC sp_rename '%s.%s', '%s', 'COLUMN'" % 
     870                                   (self.name, oldcol.name, newcol.name)) 
    865871     
    866872    def _grab_new_ids(self, idkeys, conn): 
     
    869875        # None) when retrieving ID's just after a 99-thread-test ran. Moving 
    870876        # the multithreading test fixed it. IDENT_CURRENT worked regardless. 
    871         data, _ = self.db.fetch("SELECT IDENT_CURRENT('%s');" % self.qname, 
    872                                 conn) 
     877        data, _ = self.schema.db.fetch("SELECT IDENT_CURRENT('%s');" 
     878                                       % self.qname, conn) 
    873879        return {idkeys[0]: data[0][0]} 
    874880 
     
    879885 
    880886 
    881 class SQLServerDatabase(ADODatabase): 
    882      
    883     decompiler = ADOSQLDecompiler_SQLServer 
     887class SQLServerSchema(ADOSchema): 
     888     
    884889    tableclass = SQLServerTable 
    885     adaptertosql = AdapterToADOSQL_SQLServer() 
    886     typeadapter = TypeAdapter_SQLServer() 
    887     connectionmanager = SQLServerConnectionManager 
    888      
    889     def __init__(self, name, **kwargs): 
    890         ADODatabase.__init__(self, name, **kwargs) 
    891         if "2005" in self.version(): 
    892             self.connections.isolation_levels.append("SNAPSHOT") 
    893      
    894     def version(self): 
    895         conn = self.connections._get_conn(master=True) 
    896         adov = conn.Version 
    897         data, coldefs = self.fetch("SELECT @@VERSION;", conn) 
    898         sqlv, = data[0] 
    899         conn.Close() 
    900         del conn 
    901         return "ADO Version: %s\n%s" % (adov, sqlv) 
    902890     
    903891    def create_database(self): 
    904         conn = self.connections._get_conn(master=True) 
    905         self.execute_ddl("CREATE DATABASE %s;" % self.qname, conn) 
     892        conn = self.db.connections._get_conn(master=True) 
     893        self.db.execute_ddl("CREATE DATABASE %s;" % self.qname, conn) 
    906894        conn.Close() 
    907895        self.clear() 
    908896     
    909897    def drop_database(self): 
    910         conn = self.connections._get_conn(master=True) 
    911         self.execute_ddl("DROP DATABASE %s;" % self.qname, conn) 
     898        conn = self.db.connections._get_conn(master=True) 
     899        self.db.execute_ddl("DROP DATABASE %s;" % self.qname, conn) 
    912900        conn.Close() 
    913901        self.clear() 
     
    932920            default = column.default or "" 
    933921            if default: 
    934                 clause = self.adaptertosql.coerce(default, dbtype) 
     922                clause = self.db.adaptertosql.coerce(default, dbtype) 
    935923                clause = " DEFAULT %s" % clause 
    936924         
    937925        return '%s %s%s' % (column.qname, dbtype, clause) 
     926 
     927 
     928class SQLServerDatabase(ADODatabase): 
     929     
     930    decompiler = ADOSQLDecompiler_SQLServer 
     931    adaptertosql = AdapterToADOSQL_SQLServer() 
     932    typeadapter = TypeAdapter_SQLServer() 
     933    connectionmanager = SQLServerConnectionManager 
     934    schemaclass = SQLServerSchema 
     935     
     936    def __init__(self, **kwargs): 
     937        ADODatabase.__init__(self, **kwargs) 
     938        if "2005" in self.version(): 
     939            self.connections.isolation_levels.append("SNAPSHOT") 
     940     
     941    def version(self): 
     942        conn = self.connections._get_conn(master=True) 
     943        adov = conn.Version 
     944        data, coldefs = self.fetch("SELECT @@VERSION;", conn) 
     945        sqlv, = data[0] 
     946        conn.Close() 
     947        del conn 
     948        return "ADO Version: %s\n%s" % (adov, sqlv) 
    938949     
    939950    def is_timeout_error(self, exc): 
     
    10281039 
    10291040 
    1030  
    10311041class AdapterToADOSQL_MSAccess(geniusql.AdapterToSQL): 
    10321042    """Coerce Expression constants to ADO SQL.""" 
     
    10551065     
    10561066    def _grab_new_ids(self, idkeys, conn): 
    1057         data, _ = self.db.fetch("SELECT @@IDENTITY;", conn) 
     1067        data, _ = self.schema.db.fetch("SELECT @@IDENTITY;", conn) 
    10581068        return {idkeys[0]: data[0][0]} 
    10591069 
     
    10991109 
    11001110 
    1101 class MSAccessDatabase(ADODatabase): 
    1102      
    1103     decompiler = ADOSQLDecompiler_MSAccess 
    1104     adaptertosql = AdapterToADOSQL_MSAccess() 
    1105     typeadapter = TypeAdapter_MSAccess() 
     1111class MSAccessSchema(ADOSchema): 
     1112     
    11061113    tableclass = MSAccessTable 
    1107     connectionmanager = MSAccessConnectionManager 
    1108      
    1109     def version(self): 
    1110         conn = win32com.client.Dispatch(r'ADODB.Connection') 
    1111         v = conn.Version 
    1112         del conn 
    1113         return "ADO Version: %s" % v 
    11141114     
    11151115    def _get_columns(self, tablename, conn=None): 
    1116         cols = ADODatabase._get_columns(self, tablename, conn) 
     1116        cols = ADOSchema._get_columns(self, tablename, conn) 
    11171117        if conn is None: 
    1118             conn = self.connections._factory() 
     1118            conn = self.db.connections._factory() 
    11191119         
    11201120        try: 
    11211121            # Horrible hack to get autoincrement property 
    1122             query = "SELECT * FROM %s WHERE FALSE" % self.quote(tablename) 
     1122            query = "SELECT * FROM %s WHERE FALSE" % self.db.quote(tablename) 
    11231123            bareconn = conn 
    11241124            if hasattr(conn, 'conn'): 
     
    11661166        return cols 
    11671167     
    1168     def python_type(self, dbtype): 
    1169         if dbtype == "LONG": 
    1170             return int 
    1171         return ADODatabase.python_type(self, dbtype) 
    1172      
    11731168    def columnclause(self, column): 
    11741169        """Return a clause for the given column for CREATE or ALTER TABLE. 
     
    11871182            default = column.default or "" 
    11881183            if default: 
    1189                 defspec = self.adaptertosql.coerce(default, dbtype) 
     1184                defspec = self.db.adaptertosql.coerce(default, dbtype) 
    11901185                if isinstance(default, (int, long)): 
    11911186                    # Crazy quote hack to get a numeric default to work. 
     
    11981193        # By not providing an Engine Type, it defaults to 5 = Access 2000. 
    11991194        cat = win32com.client.Dispatch(r'ADOX.Catalog') 
    1200         cat.Create(self.Connect) 
     1195        cat.Create(self.db.Connect) 
    12011196        cat.ActiveConnection.Close() 
    12021197        self.clear() 
     
    12051200        # Must shut down our only connection to avoid 
    12061201        # "Permission denied" error on os.remove call below. 
    1207         self.connections.shutdown() 
     1202        self.db.connections.shutdown() 
    12081203         
    12091204        import os 
     
    12111206        if os.path.exists(self.name): 
    12121207            os.remove(self.name) 
     1208         
    12131209        self.clear() 
     1210 
     1211 
     1212class MSAccessDatabase(ADODatabase): 
     1213     
     1214    decompiler = ADOSQLDecompiler_MSAccess 
     1215    adaptertosql = AdapterToADOSQL_MSAccess() 
     1216    typeadapter = TypeAdapter_MSAccess() 
     1217    connectionmanager = MSAccessConnectionManager 
     1218    schemaclass = MSAccessSchema 
     1219     
     1220    def version(self): 
     1221        conn = win32com.client.Dispatch(r'ADODB.Connection') 
     1222        v = conn.Version 
     1223        del conn 
     1224        return "ADO Version: %s" % v 
     1225     
     1226    def python_type(self, dbtype): 
     1227        if dbtype == "LONG": 
     1228            return int 
     1229        return ADODatabase.python_type(self, dbtype) 
    12141230 
    12151231 
  • trunk/geniusql/providers/firebird.py

    r10 r11  
    211211            wt1 = self.db.joinwrapper(t1) 
    212212            self.aliascount += 1 
    213             wt1.alias = self.quote("t%d" % self.aliascount) 
     213            alias = "t%d" % self.aliascount 
     214            wt1.alias = self.db.quote(t1.schema.table_name(alias)) 
    214215            self.seen[t1.name] = None 
    215216         
     
    219220            wt2 = self.db.joinwrapper(t2) 
    220221            self.aliascount += 1 
    221             wt2.alias = self.quote("t%d" % self.aliascount) 
     222            alias = "t%d" % self.aliascount 
     223            wt2.alias = self.db.quote(t2.schema.table_name(alias)) 
    222224            self.seen[t2.name] = None 
    223225         
     
    241243    def _add_column(self, column): 
    242244        """Internal function to add the column to the database.""" 
    243         coldef = self.db.columnclause(column) 
     245        coldef = self.schema.columnclause(column) 
    244246        # FB doesn't recognize the keyword "COLUMN" in "ADD". 
    245         self.db.execute_ddl("ALTER TABLE %s ADD %s;" % (self.qname, coldef)) 
     247        self.schema.db.execute_ddl("ALTER TABLE %s ADD %s;" % 
     248                                   (self.qname, coldef)) 
    246249     
    247250    def _drop_column(self, column): 
    248251        """Internal function to drop the column from the database.""" 
    249252        # FB doesn't recognize the keyword "COLUMN" in "DROP". 
    250         self.db.execute_ddl("ALTER TABLE %s DROP %s;" % 
    251                             (self.qname, column.qname)) 
     253        self.schema.db.execute_ddl("ALTER TABLE %s DROP %s;" % 
     254                                   (self.qname, column.qname)) 
    252255     
    253256    def _rename(self, oldcol, newcol): 
    254257        # FB doesn't use the keyword "RENAME". 
    255         self.db.execute_ddl("ALTER TABLE %s ALTER COLUMN %s TO %s;" % 
    256                             (self.qname, oldcol.qname, newcol.qname)) 
     258        self.schema.db.execute_ddl("ALTER TABLE %s ALTER COLUMN %s TO %s;" % 
     259                                   (self.qname, oldcol.qname, newcol.qname)) 
    257260     
    258261    def insert(self, **inputs): 
    259262        """Insert a row and return {idcolkey: newid}.""" 
    260         coerce_out = self.db.adaptertosql.coerce 
    261         coerce_in = self.db.adapterfromdb.coerce 
     263        coerce_out = self.schema.db.adaptertosql.coerce 
     264        coerce_in = self.schema.db.adapterfromdb.coerce 
    262265         
    263266        newids = {} 
     
    267270            if col.autoincrement: 
    268271                # This advances the generator and returns its new value. 
    269                 data, _ = self.db.fetch("SELECT GEN_ID(%s, 1) FROM RDB$DATABASE;" 
    270                                         % col.sequence_name) 
     272                sql = ("SELECT GEN_ID(%s, 1) FROM RDB$DATABASE;" 
     273                       % col.sequence_name) 
     274                data, _ = self.schema.db.fetch(sql) 
    271275                newid = coerce_in(data[0][0], col.dbtype, col.pytype) 
    272276                newids[key] = newid 
     
    282286        fields = ", ".join(fields) 
    283287        values = ", ".join(values) 
    284         self.db.execute('INSERT INTO %s (%s) VALUES (%s);' 
    285                         % (self.qname, fields, values)) 
     288        self.schema.db.execute('INSERT INTO %s (%s) VALUES (%s);' 
     289                               % (self.qname, fields, values)) 
    286290         
    287291        return newids 
     
    366370 
    367371 
    368 class FirebirdDatabase(geniusql.Database): 
    369      
    370     selectwriter = FirebirdSelectWriter 
    371     decompiler = FirebirdSQLDecompiler 
    372      
    373     adaptertosql = AdapterToFireBirdSQL() 
    374     adapterfromdb = AdapterFromFirebirdDB() 
    375     typeadapter = TypeAdapterFirebird() 
     372class FirebirdSchema(geniusql.Schema): 
     373     
    376374    tableclass = FirebirdTable 
    377     connectionmanager = FirebirdConnectionManager 
    378      
    379     sql_name_max_length = 31 
    380     encoding = 'utf8' 
    381      
    382     def __init__(self, name, **kwargs): 
    383         self._discover_lock = threading.Lock() 
    384          
     375     
     376    def __init__(self, db, name): 
    385377        dict.__init__(self) 
    386         for k, v in kwargs.iteritems(): 
    387             setattr(self, k, v) 
     378         
     379        self.db = db 
    388380         
    389381        # Here's where we differ from the superclass. 
     
    391383        # so we don't set self.name = sql_name(name). 
    392384        self.name = name 
    393          
    394         self.qname = self.quote(self.name) 
    395         self.connections = self.connectionmanager(self, kwargs.get('poolsize', 10)) 
     385        # We also have to set our parent's "name" so conns can use it. 
     386        self.db.name = name 
     387         
     388        self.qname = self.db.quote(self.name) 
     389        self._discover_lock = threading.Lock() 
    396390        self.discover_dbinfo() 
    397391     
    398     def _get_dbinfo(self, conn=None): 
    399         return {} 
    400      
    401392    def _get_tables(self, conn=None): 
    402         data, _ = self.fetch("SELECT RDB$RELATION_NAME FROM RDB$RELATIONS " 
    403                              "WHERE RDB$SYSTEM_FLAG=0 AND RDB$VIEW_BLR IS NULL;", 
    404                             conn=conn) 
    405         return [self.tableclass(name.strip(), self.quote(name.strip()), 
     393        data, _ = self.db.fetch( 
     394            "SELECT RDB$RELATION_NAME FROM RDB$RELATIONS " 
     395            "WHERE RDB$SYSTEM_FLAG=0 AND RDB$VIEW_BLR IS NULL;", conn=conn) 
     396        return [self.tableclass(name.strip(), self.db.quote(name.strip()), 
    406397                                self, created=True) 
    407398                for name, in data] 
    408399     
    409400    def _get_table(self, tablename, conn=None): 
    410         data, _ = self.fetch("SELECT RDB$RELATION_NAME FROM RDB$RELATIONS " 
    411                              "WHERE RDB$SYSTEM_FLAG=0 AND RDB$VIEW_BLR IS NULL
    412                              "AND RDB$RELATION_NAME = '%s';" % tablename, 
    413                             conn=conn) 
     401        data, _ = self.db.fetch( 
     402            "SELECT RDB$RELATION_NAME FROM RDB$RELATIONS
     403            "WHERE RDB$SYSTEM_FLAG=0 AND RDB$VIEW_BLR IS NULL " 
     404            "AND RDB$RELATION_NAME = '%s';" % tablename, conn=conn) 
    414405        for name, in data: 
    415406            name = name.strip() 
    416407            if name == tablename: 
    417                 return self.tableclass(name, self.quote(name), 
     408                return self.tableclass(name, self.db.quote(name), 
    418409                                       self, created=True) 
    419410        raise errors.MappingError(tablename) 
     
    424415         
    425416        # Get Primary Key names first 
    426         data, _ = self.fetch( 
     417        data, _ = self.db.fetch( 
    427418            "SELECT S.RDB$FIELD_NAME AS COLUMN_NAME " 
    428419            "FROM RDB$RELATION_CONSTRAINTS RC " 
     
    430421            "LEFT JOIN RDB$INDEX_SEGMENTS S ON (S.RDB$INDEX_NAME = I.RDB$INDEX_NAME) " 
    431422            "WHERE (RC.RDB$CONSTRAINT_TYPE = 'PRIMARY KEY') " 
    432             "AND (I.RDB$RELATION_NAME = '%s')" % tablename, 
    433             conn=conn 
    434             ) 
     423            "AND (I.RDB$RELATION_NAME = '%s')" % tablename, conn=conn) 
    435424        pks = [row[0].rstrip() for row in data] 
    436425         
    437426        # Now get the rest of the col data 
    438         data, _ = self.fetch("SELECT RF.RDB$FIELD_NAME, T.RDB$TYPE_NAME, " 
    439                              "F.RDB$FIELD_LENGTH, RF.RDB$DEFAULT_SOURCE, " 
    440                              "F.RDB$FIELD_PRECISION, F.RDB$FIELD_SCALE " 
    441                              "FROM RDB$RELATION_FIELDS RF LEFT JOIN " 
    442                              "RDB$FIELDS F ON F.RDB$FIELD_NAME = RF.RDB$FIELD_SOURCE " 
    443                              "LEFT JOIN RDB$TYPES T ON T.RDB$TYPE = F.RDB$FIELD_TYPE " 
    444                              "WHERE RF.RDB$RELATION_NAME='%s' AND " 
    445                              "T.RDB$FIELD_NAME='RDB$FIELD_TYPE';" % tablename, 
    446                              conn=conn) 
     427        data, _ = self.db.fetch( 
     428            "SELECT RF.RDB$FIELD_NAME, T.RDB$TYPE_NAME, F.RDB$FIELD_LENGTH, " 
     429            "RF.RDB$DEFAULT_SOURCE, F.RDB$FIELD_PRECISION, F.RDB$FIELD_SCALE " 
     430            "FROM RDB$RELATION_FIELDS RF LEFT JOIN RDB$FIELDS F " 
     431            "ON F.RDB$FIELD_NAME = RF.RDB$FIELD_SOURCE " 
     432            "LEFT JOIN RDB$TYPES T ON T.RDB$TYPE = F.RDB$FIELD_TYPE " 
     433            "WHERE RF.RDB$RELATION_NAME='%s' AND " 
     434            "T.RDB$FIELD_NAME='RDB$FIELD_TYPE';" % tablename, conn=conn) 
    447435        cols = [] 
    448436        for name, dbtype, fieldlen, default, prec, scale in data: 
     
    474462             
    475463            # Column(pytype, dbtype, default=None, hints=None, key=False, name, qname) 
    476             col = geniusql.Column(self.python_type(dbtype), dbtype, default, 
    477                                   hints, key, name, self.quote(name)) 
     464            col = geniusql.Column(self.db.python_type(dbtype), dbtype, default, 
     465                                  hints, key, name, self.db.quote(name)) 
    478466            cols.append(col) 
    479467        return cols 
    480468     
    481469    def _get_indices(self, tablename, conn=None): 
    482         data, _ = self.fetch("SELECT I.RDB$INDEX_NAME, S.RDB$FIELD_NAME, " 
    483                              "I.RDB$UNIQUE_FLAG " 
    484                              "FROM RDB$INDICES I LEFT JOIN RDB$INDEX_SEGMENTS S " 
    485                              "ON (S.RDB$INDEX_NAME = I.RDB$INDEX_NAME) " 
    486                              "WHERE I.RDB$RELATION_NAME = '%s';" 
    487                              % tablename.ljust(31, " "), 
    488                              conn=conn) 
     470        data, _ = self.db.fetch( 
     471            "SELECT I.RDB$INDEX_NAME, S.RDB$FIELD_NAME, I.RDB$UNIQUE_FLAG " 
     472            "FROM RDB$INDICES I LEFT JOIN RDB$INDEX_SEGMENTS S " 
     473            "ON (S.RDB$INDEX_NAME = I.RDB$INDEX_NAME) " 
     474            "WHERE I.RDB$RELATION_NAME = '%s';" % tablename.ljust(31, " "), 
     475            conn=conn) 
     476         
    489477        indices = [] 
    490478        for name, colname, unique in data: 
     
    492480            colname = colname.rstrip() 
    493481            unique = bool(unique) 
    494             ind = geniusql.Index(name, self.quote(name), tablename, colname, unique) 
     482            ind = geniusql.Index(name, self.db.quote(name), 
     483                                 tablename, colname, unique) 
    495484            indices.append(ind) 
    496485         
    497486        return indices 
     487     
     488    def columnclause(self, column): 
     489        """Return a clause for the given column for CREATE or ALTER TABLE. 
     490         
     491        This will be of the form "name type [DEFAULT x] [NOT NULL]". 
     492         
     493        Firebird needs the sequence created in a separate SQL statement. 
     494        """ 
     495        dbtype = column.dbtype 
     496         
     497        default = column.default or "" 
     498        if default: 
     499            default = self.db.adaptertosql.coerce(default, dbtype) 
     500            default = " DEFAULT %s" % default 
     501         
     502        notnull = "" 
     503        if column.key: 
     504            # Firebird PK's must be NOT NULL 
     505            notnull = " NOT NULL" 
     506         
     507        return '%s %s%s%s' % (column.qname, dbtype, default, notnull) 
     508     
     509    def create_sequence(self, table, column): 
     510        """Create a SEQUENCE for the given column and set its sequence_name.""" 
     511        sname = column.sequence_name 
     512        if sname is None: 
     513            sname = self.db.quote("%s_%s_seq" % (table.name, column.name)) 
     514            column.sequence_name = sname 
     515        self.db.execute_ddl("CREATE GENERATOR %s;" % sname) 
     516        self.db.execute_ddl("SET GENERATOR %s TO %s;" % 
     517                            (sname, column.initial - 1)) 
     518     
     519    def drop_sequence(self, column): 
     520        """Drop a SEQUENCE for the given column and remove its sequence_name.""" 
     521        if column.sequence_name is not None: 
     522            self.db.execute_ddl("DROP GENERATOR %s;" % column.sequence_name) 
     523            column.sequence_name = None 
     524     
     525    def create_database(self): 
     526        # Firebird DB 'names' are actually filesystem paths. 
     527        sql = ("CREATE DATABASE %s USER '%s' PASSWORD '%s';" 
     528               % (self.qname, self.db.user, self.db.password)) 
     529         
     530        # Use the kinterbasdb helper methods for cleaner create and drop. 
     531        # We also use dialect 3 *always* to help with quoted identifiers. 
     532        conn = kinterbasdb.create_database(sql, 3) 
     533        conn.close() 
     534         
     535        self.clear() 
     536     
     537    def drop_database(self): 
     538        # Must shut down all connections to avoid 
     539        # "being accessed by other users" error. 
     540        self.db.connections.shutdown() 
     541         
     542        conn = self.db.connections._get_conn() 
     543        conn.drop_database() 
     544        # For some reason, the conn is already closed... 
     545##        conn.close() 
     546        self.clear() 
     547 
     548 
     549class FirebirdDatabase(geniusql.Database): 
     550     
     551    selectwriter = FirebirdSelectWriter 
     552    decompiler = FirebirdSQLDecompiler 
     553     
     554    adaptertosql = AdapterToFireBirdSQL() 
     555    adapterfromdb = AdapterFromFirebirdDB() 
     556    typeadapter = TypeAdapterFirebird() 
     557    connectionmanager = FirebirdConnectionManager 
     558     
     559    schemaclass = FirebirdSchema 
     560    multischema = False 
     561     
     562    sql_name_max_length = 31 
     563    encoding = 'utf8' 
    498564     
    499565    def python_type(self, dbtype): 
     
    528594                        "to a Python type." % dbtype) 
    529595     
    530     def columnclause(self, column): 
    531         """Return a clause for the given column for CREATE or ALTER TABLE. 
    532          
    533         This will be of the form "name type [DEFAULT x] [NOT NULL]". 
    534          
    535         Firebird needs the sequence created in a separate SQL statement. 
    536         """ 
    537         dbtype = column.dbtype 
    538          
    539         default = column.default or "" 
    540         if default: 
    541             default = self.adaptertosql.coerce(default, dbtype) 
    542             default = " DEFAULT %s" % default 
    543          
    544         notnull = "" 
    545         if column.key: 
    546             # Firebird PK's must be NOT NULL 
    547             notnull = " NOT NULL" 
    548          
    549         return '%s %s%s%s' % (column.qname, dbtype, default, notnull) 
    550      
    551     def create_sequence(self, table, column): 
    552         """Create a SEQUENCE for the given column and set its sequence_name.""" 
    553         sname = column.sequence_name 
    554         if sname is None: 
    555             sname = self.quote("%s_%s_seq" % (table.name, column.name)) 
    556             column.sequence_name = sname 
    557         self.execute_ddl("CREATE GENERATOR %s;" % sname) 
    558         self.execute_ddl("SET GENERATOR %s TO %s;" % (sname, column.initial - 1)) 
    559      
    560     def drop_sequence(self, column): 
    561         """Drop a SEQUENCE for the given column and remove its sequence_name.""" 
    562         if column.sequence_name is not None: 
    563             self.execute_ddl("DROP GENERATOR %s;" % column.sequence_name) 
    564             column.sequence_name = None 
    565      
    566596    #                               Naming                               # 
    567597     
     
    636666        return data, desc 
    637667     
    638     #                               Schemas                               # 
    639      
    640     def create_database(self): 
    641         # Firebird DB 'names' are actually filesystem paths. 
    642         sql = ("CREATE DATABASE %s USER '%s' PASSWORD '%s';" 
    643                % (self.qname, self.user, self.password)) 
    644          
    645         # Use the kinterbasdb helper methods for cleaner create and drop. 
    646         # We also use dialect 3 *always* to help with quoted identifiers. 
    647         conn = kinterbasdb.create_database(sql, 3) 
    648         conn.close() 
    649          
    650         self.clear() 
    651      
    652     def drop_database(self): 
    653         # Must shut down all connections to avoid 
    654         # "being accessed by other users" error. 
    655         self.connections.shutdown() 
    656          
    657         conn = self.connections._get_conn() 
    658         conn.drop_database() 
    659         # For some reason, the conn is already closed... 
    660 ##        conn.close() 
    661         self.clear() 
    662      
    663668    def is_timeout_error(self, exc): 
    664669        """If the given exception instance is a lock timeout, return True. 
  • trunk/geniusql/providers/mysql.py

    r10 r11  
    142142        t = self.table 
    143143        # MySQL might rename multiple-column indices to "PRIMARY" 
    144         for i in t.db._get_indices(t.name): 
     144        for i in t.schema.db._get_indices(t.name): 
    145145            if i.colname == self[key].colname: 
    146                 t.db.execute_ddl('DROP INDEX %s ON %s;' % (i.qname, t.qname)) 
     146                t.schema.db.execute_ddl('DROP INDEX %s ON %s;' % 
     147                                        (i.qname, t.qname)) 
    147148 
    148149 
    149150class MySQLTable(geniusql.Table): 
    150151     
    151     indexsetclass = MySQLIndexSet 
    152      
    153152    def _rename(self, oldcol, newcol): 
    154         self.db.execute_ddl("ALTER TABLE %s CHANGE %s %s %s;" % 
    155                             (self.qname, oldcol.qname, newcol.qname, 
    156                              oldcol.dbtype)) 
     153        self.schema.db.execute_ddl("ALTER TABLE %s CHANGE %s %s %s;" % 
     154                                   (self.qname, oldcol.qname, newcol.qname, 
     155                                    oldcol.dbtype)) 
    157156     
    158157    def _grab_new_ids(self, idkeys, conn): 
     
    188187 
    189188 
    190 class MySQLDatabase(geniusql.Database): 
    191      
    192     sql_name_max_length = 64 
    193     # MySQL uses case-sensitive database and table names on Unix, but 
    194     # not on Windows. Use all-lowercase identifiers to work around the 
    195     # problem. "Column names, index names, and column aliases are not 
    196     # case sensitive on any platform." 
    197     # If deployers set lower_case_table_names to 1, it would help. 
    198     sql_name_caseless = True 
    199     encoding = "utf8" 
    200      
    201     adaptertosql = AdapterToMySQL() 
    202     adapterfromdb = AdapterFromMySQL() 
    203     typeadapter = TypeAdapterMySQL() 
     189class MySQLSchema(geniusql.Schema): 
    204190     
    205191    tableclass = MySQLTable 
    206192    indexsetclass = MySQLIndexSet 
    207     connectionmanager = MySQLConnectionManager 
    208      
    209     def __init__(self, name, **kwargs): 
    210         geniusql.Database.__init__(self, name, **kwargs) 
    211          
    212         self.connections.connargs = dict([(k, v) for k, v in kwargs.iteritems() 
    213                                           if k in connargs]) 
    214          
    215         self.decompiler = MySQLDecompiler 
    216          
    217         # Get the version string from MySQL, to see if we need 
    218         # a different decompiler. 
    219         conn = self.connections._get_conn(master=True) 
    220         rowdata, cols = self.fetch("SELECT version();", conn) 
    221         conn.close() 
    222         v = rowdata[0][0] 
    223         self._version = providers.Version(v) 
    224          
    225         # decompiler 
    226         if self._version > providers.Version("4.1.1"): 
    227             self.decompiler = MySQLDecompiler411 
    228          
    229         # type adapter 
    230         if self._version >= providers.Version("4.1"): 
    231             self.typeadapter = TypeAdapterMySQL41() 
    232      
    233     def version(self): 
    234         return "MySQL Version: %s" % self._version 
    235193     
    236194    def columnclause(self, column): 
     
    247205        default = column.default or "" 
    248206        if default: 
    249             default = self.adaptertosql.coerce(default, dbtype) 
     207            default = self.db.adaptertosql.coerce(default, dbtype) 
    250208            default = " DEFAULT %s" % default 
    251209         
     
    253211     
    254212    def __setitem__(self, key, table): 
    255         q = self.quote 
     213        q = self.db.quote 
    256214        if key in self: 
    257215            del self[key] 
     
    286244            pk = "" 
    287245         
    288         encoding = self.encoding 
     246        encoding = self.db.encoding 
    289247        if encoding: 
    290248            encoding = " CHARACTER SET %s" % encoding 
    291249         
    292         self.execute_ddl('CREATE TABLE %s (%s%s)%s;' % 
    293                          (table.qname, ", ".join(fields), pk, encoding)) 
     250        self.db.execute_ddl('CREATE TABLE %s (%s%s)%s;' % 
     251                            (table.qname, ", ".join(fields), pk, encoding)) 
    294252         
    295253        if incr_fields: 
     
    299257            fields = ", ".join([col.qname for col in incr_fields]) 
    300258            values = ", ".join([str(col.initial - 1) for col in incr_fields]) 
    301             self.execute_ddl("INSERT INTO %s (%s) VALUES (%s);" 
    302                              % (table.qname, fields, values)) 
     259            self.db.execute_ddl("INSERT INTO %s (%s) VALUES (%s);" 
     260                                % (table.qname, fields, values)) 
    303261         
    304262        for k, index in table.indices.iteritems(): 
     
    307265                # MySQL won't allow indexes on a BLOB field without a 
    308266                # specific index prefix length. We choose 255 just for fun. 
    309                 self.execute_ddl('CREATE INDEX %s ON %s (%s(255));' % 
    310                                  (index.qname, table.qname, q(index.colname))) 
     267                self.db.execute_ddl('CREATE INDEX %s ON %s (%s(255));' % 
     268                                    (index.qname, table.qname, q(index.colname))) 
    311269            else: 
    312                 self.execute_ddl('CREATE INDEX %s ON %s (%s);' % 
    313                                  (index.qname, table.qname, q(index.colname))) 
     270                self.db.execute_ddl('CREATE INDEX %s ON %s (%s);' % 
     271                                    (index.qname, table.qname, q(index.colname))) 
    314272         
    315273        if incr_fields: 
    316             self.execute_ddl("DELETE FROM %s" % table.qname) 
     274            self.db.execute_ddl("DELETE FROM %s" % table.qname) 
    317275         
    318276        dict.__setitem__(self, key, table) 
    319277     
    320278    def _get_tables(self, conn=None): 
    321         data, _ = self.fetch("SHOW TABLES FROM %s" % self.qname, conn=conn) 
    322         return [self.tableclass(row[0], self.quote(row[0]), 
     279        data, _ = self.db.fetch("SHOW TABLES FROM %s" % self.qname, conn=conn) 
     280        return [self.tableclass(row[0], self.db.quote(row[0]), 
    323281                                self, created=True) 
    324282                for row in data] 
    325283     
    326284    def _get_table(self, tablename, conn=None): 
    327         data, _ = self.fetch("SHOW TABLES FROM %s LIKE '%s'" 
     285        data, _ = self.db.fetch("SHOW TABLES FROM %s LIKE '%s'" 
    328286                             % (self.qname, tablename), conn=conn) 
    329287        for row in data: 
    330288            name = row[0] 
    331289            if name == tablename: 
    332                 return self.tableclass(name, self.quote(name), 
     290                return self.tableclass(name, self.db.quote(name), 
    333291                                       self, created=True) 
    334292        raise errors.MappingError(tablename) 
     
    337295        # cols are: Field, Type, Null, Key, Default, Extra. 
    338296        # See http://dev.mysql.com/doc/refman/4.1/en/describe.html 
    339         data, _ = self.fetch("SHOW COLUMNS FROM %s.%s" % 
    340                              (self.qname, self.quote(tablename)), conn=conn) 
     297        data, _ = self.db.fetch("SHOW COLUMNS FROM %s.%s" % 
     298                                (self.qname, self.db.quote(tablename)), 
     299                                conn=conn) 
    341300        cols = [] 
    342301        for row in data: 
     
    366325             
    367326            key = (row[3] == "PRI") 
    368             pytype = self.python_type(dbtype) 
     327            pytype = self.db.python_type(dbtype) 
    369328             
    370329            col = geniusql.Column(pytype, dbtype, None, hints, key, 
    371                                   row[0], self.quote(row[0])) 
     330                                  row[0], self.db.quote(row[0])) 
    372331             
    373332            if row[4]: 
     
    384343            # cols are: Table, Non_unique, Key_name, Seq_in_index, Column_name, 
    385344            # Collation, Cardinality, Sub_part, Packed, Null, Index_type, Comment 
    386             data, _ = self.fetch("SHOW INDEX FROM %s.%s" 
    387                                  % (self.qname, self.quote(tablename)), 
    388                                  conn=conn) 
     345            data, _ = self.db.fetch("SHOW INDEX FROM %s.%s" 
     346                                    % (self.qname, self.db.quote(tablename)), 
     347                                    conn=conn) 
    389348        except _mysql.ProgrammingError, x: 
    390349            if x.args[0] != 1146: 
     
    392351        else: 
    393352            for row in data: 
    394                 i = geniusql.Index(row[2], self.quote(row[2]), 
     353                i = geniusql.Index(row[2], self.db.quote(row[2]), 
    395354                                   row[0], row[4], not row[1]) 
    396355                indices.append(i) 
    397356        return indices 
     357     
     358    def create_database(self): 
     359        # _mysql has create_db and drop_db commands, but they're deprecated. 
     360        encoding = self.db.encoding 
     361        if encoding: 
     362            encoding = " CHARACTER SET %s" % encoding 
     363        sql = 'CREATE DATABASE %s%s;' % (self.qname, encoding) 
     364        conn = self.db.connections._get_conn(master=True) 
     365        self.db.execute_ddl(sql, conn) 
     366        conn.close() 
     367        self.clear() 
     368     
     369    def drop_database(self): 
     370        conn = self.db.connections._get_conn(master=True) 
     371        self.db.execute_ddl('DROP DATABASE %s;' % self.qname, conn) 
     372        conn.close() 
     373        self.clear() 
     374 
     375 
     376class MySQLDatabase(geniusql.Database): 
     377     
     378    sql_name_max_length = 64 
     379    # MySQL uses case-sensitive database and table names on Unix, but 
     380    # not on Windows. Use all-lowercase identifiers to work around the 
     381    # problem. "Column names, index names, and column aliases are not 
     382    # case sensitive on any platform." 
     383    # If deployers set lower_case_table_names to 1, it would help. 
     384    sql_name_caseless = True 
     385    encoding = "utf8" 
     386     
     387    adaptertosql = AdapterToMySQL() 
     388    adapterfromdb = AdapterFromMySQL() 
     389    typeadapter = TypeAdapterMySQL() 
     390     
     391    connectionmanager = MySQLConnectionManager 
     392    schemaclass = MySQLSchema 
     393     
     394    def __init__(self, **kwargs): 
     395        geniusql.Database.__init__(self, **kwargs) 
     396         
     397        self.connections.connargs = dict([(k, v) for k, v in kwargs.iteritems() 
     398                                          if k in connargs]) 
     399         
     400        self.decompiler = MySQLDecompiler 
     401         
     402        # Get the version string from MySQL, to see if we need 
     403        # a different decompiler. 
     404        conn = self.connections._get_conn(master=True) 
     405        rowdata, cols = self.fetch("SELECT version();", conn) 
     406        conn.close() 
     407        v = rowdata[0][0] 
     408        self._version = providers.Version(v) 
     409         
     410        # decompiler 
     411        if self._version > providers.Version("4.1.1"): 
     412            self.decompiler = MySQLDecompiler411 
     413         
     414        # type adapter 
     415        if self._version >= providers.Version("4.1"): 
     416            self.typeadapter = TypeAdapterMySQL41() 
     417     
     418    def version(self): 
     419        return "MySQL Version: %s" % self._version 
    398420     
    399421    def python_type(self, dbtype): 
     
    486508        return res.fetch_row(0, 0), res.describe() 
    487509     
    488     def create_database(self): 
    489         # _mysql has create_db and drop_db commands, but they're deprecated. 
    490         encoding = self.encoding 
    491         if encoding: 
    492             encoding = " CHARACTER SET %s" % encoding 
    493         sql = 'CREATE DATABASE %s%s;' % (self.qname, encoding) 
    494         conn = self.connections._get_conn(master=True) 
    495         self.execute_ddl(sql, conn) 
    496         conn.close() 
    497         self.clear() 
    498      
    499     def drop_database(self): 
    500         conn = self.connections._get_conn(master=True) 
    501         self.execute_ddl('DROP DATABASE %s;' % self.qname, conn) 
    502         conn.close() 
    503         self.clear() 
    504      
    505510    def is_timeout_error(self, exc): 
    506511        # OperationalError: (1205, 'Lock wait timeout exceeded; try restarting transaction') 
  • trunk/geniusql/providers/postgres.py

    r10 r11  
    1010unescape_oct = re.compile(r"\\(\d\d\d)") 
    1111replace_unoct = lambda m: chr(int(m.group(1), 8)) 
     12import threading 
    1213 
    1314import geniusql 
     
    117118        """Drop the specified index.""" 
    118119        # PG doesn't use DROP INDEX .. ON .. 
    119         self.table.db.execute_ddl('DROP INDEX %s;' % self[key].qname) 
     120        self.table.schema.db.execute_ddl('DROP INDEX %s;' % self[key].qname) 
    120121 
    121122 
    122123class PgTable(geniusql.Table): 
    123      
    124     indexsetclass = PgIndexSet 
    125124     
    126125    def _grab_new_ids(self, idkeys, conn): 
     
    129128            col = self[idkey] 
    130129            seq = col.sequence_name 
    131             data, _ = self.db.fetch("SELECT last_value FROM %s;" % seq, conn) 
     130            data, _ = self.schema.db.fetch("SELECT last_value FROM %s;" % seq, conn) 
    132131            newids[idkey] = data[0][0] 
    133132        return newids 
    134133 
    135134 
    136 class PgDatabase(geniusql.Database): 
    137      
    138     sql_name_max_length = 63 
    139     quote_all = True 
    140     poolsize = 10 
    141     encoding = 'SQL_ASCII' 
    142      
    143     decompiler = PgDecompiler 
    144     adaptertosql = AdapterToPgSQL() 
    145     adapterfromdb = AdapterFromPg() 
     135class PgSchema(geniusql.Schema): 
     136     
    146137    tableclass = PgTable 
    147      
    148     def __init__(self, name, **kwargs): 
    149         import threading 
    150         self._discover_lock = threading.Lock() 
    151          
    152         dict.__init__(self) 
    153         for k, v in kwargs.iteritems(): 
    154             setattr(self, k, v) 
    155          
    156         self.name = self.sql_name(name) 
    157         self.qname = self.quote(self.name) 
    158          
    159         poolsize = kwargs.get('poolsize', 10) 
    160         self.connections = self.connectionmanager(self, poolsize) 
    161         self.connections.Connect = self.Connect 
    162          
    163         self.discover_dbinfo() 
     138    indexsetclass = PgIndexSet 
    164139     
    165140    def _get_tables(self, conn=None): 
    166         data, _ = self.fetch("SELECT tablename FROM pg_tables WHERE schemaname" 
    167                              " not in ('information_schema', 'pg_catalog')", 
    168                              conn=conn) 
    169         return [self.tableclass(row[0], self.quote(row[0]), 
     141        data, _ = self.db.fetch( 
     142            "SELECT tablename FROM pg_tables WHERE schemaname" 
     143            " not in ('information_schema', 'pg_catalog')", 
     144            conn=conn) 
     145        return [self.tableclass(row[0], self.db.quote(row[0]), 
    170146                                self, created=True) 
    171147                for row in data] 
    172148     
    173149    def _get_table(self, tablename, conn=None): 
    174         data, _ = self.fetch("SELECT tablename FROM pg_tables WHERE " 
    175                              "tablename = '%s'" % tablename, 
    176                              conn=conn) 
     150        data, _ = self.db.fetch("SELECT tablename FROM pg_tables WHERE " 
     151                                "tablename = '%s'" % tablename, conn=conn) 
    177152        for name, in data: 
    178153            if name == tablename: 
    179                 return self.tableclass(name, self.quote(name), 
     154                return self.tableclass(name, self.db.quote(name), 
    180155                                       self, created=True) 
    181156        raise errors.MappingError(tablename) 
     
    183158    def _get_columns(self, tablename, conn=None): 
    184159        # Get the OID of the table 
    185         data, _ = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'
    186                              % tablename, conn=conn) 
     160        data, _ = self.db.fetch("SELECT oid FROM pg_class WHERE
     161                                "relname = '%s'" % tablename, conn=conn) 
    187162        table_OID = data[0][0] 
    188163         
    189164        # Get index data so we can set col.key if pg_index.indisprimary 
    190         data, _ = self.fetch("SELECT indkey FROM pg_index WHERE indrelid " 
    191                              "= %s AND indisprimary" % table_OID, conn=conn) 
     165        data, _ = self.db.fetch("SELECT indkey FROM pg_index WHERE indrelid " 
     166                                "= %s AND indisprimary" % table_OID, conn=conn) 
    192167        if data: 
    193168            # indkey is an "array" (we get a space-separated string of ints). 
     
    201176               "FROM pg_attribute WHERE attisdropped = False AND " 
    202177               "attrelid = %s" % table_OID) 
    203         data, _ = self.fetch(sql, conn=conn) 
     178        data, _ = self.db.fetch(sql, conn=conn) 
    204179        cols = [] 
    205180        for row in data: 
     
    211186             
    212187            # Data type 
    213             dbtype, _ = self.fetch("SELECT typname, typlen FROM pg_type " 
    214                                    "WHERE oid = %s" % row[1], conn=conn) 
     188            dbtype, _ = self.db.fetch("SELECT typname, typlen FROM pg_type " 
     189                                      "WHERE oid = %s" % row[1], conn=conn) 
    215190            if dbtype: 
    216191                dbtype = dbtype[0][0].upper() 
    217192            else: 
    218193                dbtype = None 
    219             c = geniusql.Column(self.python_type(dbtype), dbtype, 
     194            c = geniusql.Column(self.db.python_type(dbtype), dbtype, 
    220195                                None, {}, row[2] in indices, 
    221                                 row[0], self.quote(row[0])) 
     196                                row[0], self.db.quote(row[0])) 
    222197             
    223198            if dbtype in ('FLOAT4', 'FLOAT8'): 
     
    228203             
    229204            # Default value 
    230             default, _ = self.fetch("SELECT adsrc FROM pg_attrdef " 
    231                                     "WHERE adnum = %s AND adrelid = %s" 
    232                                     % (row[2], table_OID), conn=conn) 
     205            default, _ = self.db.fetch("SELECT adsrc FROM pg_attrdef " 
     206                                       "WHERE adnum = %s AND adrelid = %s" 
     207                                       % (row[2], table_OID), conn=conn) 
    233208            if default: 
    234209                default = default[0][0] 
     
    237212                    c.autoincrement = True 
    238213                    c.sequence_name = seq_name.search(default).group(1) 
    239                     c.initial = self.fetch("SELECT min_value FROM %s" % 
    240                                            c.sequence_name, conn=conn)[0][0] 
     214                    c.initial = self.db.fetch("SELECT min_value FROM %s" % 
     215                                              c.sequence_name, conn=conn)[0][0] 
    241216                    c.default = None 
    242217                else: 
     
    244219                    # our guessed type. Be sure to strip any ::typename 
    245220                    default = default.split("::", 1)[0] 
    246                     c.default = self.python_type(dbtype)(default) 
     221                    c.default = self.db.python_type(dbtype)(default) 
    247222            else: 
    248223                c.default = None 
     
    263238    def _get_indices(self, tablename, conn=None): 
    264239        # Get the OID of the parent table. 
    265         data, _ = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'
    266                              % tablename, conn=conn) 
     240        data, _ = self.db.fetch("SELECT oid FROM pg_class WHERE
     241                                "relname = '%s'" % tablename, conn=conn) 
    267242        if not data: 
    268243            return [] 
     
    270245        table_OID = data[0][0] 
    271246        indices = [] 
    272         data, _ = self.fetch("SELECT pg_class.relname, indkey, indisprimary, " 
    273                              "indisunique FROM pg_index LEFT JOIN pg_class " 
    274                              "ON pg_index.indexrelid = pg_class.oid WHERE " 
    275                              "pg_index.indrelid = %s" % table_OID, conn=conn) 
     247        data, _ = self.db.fetch( 
     248            "SELECT pg_class.relname, indkey, indisprimary, " 
     249            "indisunique FROM pg_index LEFT JOIN pg_class " 
     250            "ON pg_index.indexrelid = pg_class.oid WHERE " 
     251            "pg_index.indrelid = %s" % table_OID, conn=conn) 
    276252        for row in data: 
    277253            # indkey is an "array" (we get a space-separated string of ints). 
    278254            cols = map(int, row[1].split(" ")) 
    279255            for col in cols: 
    280                 d, _ = self.fetch("SELECT attname FROM pg_attribute " 
    281                                   "WHERE attrelid = %s AND attnum = %s" 
    282                                   % (table_OID, col), conn=conn) 
    283                 i = geniusql.Index(row[0], self.quote(row[0]), tablename, 
     256                d, _ = self.db.fetch("SELECT attname FROM pg_attribute " 
     257                                     "WHERE attrelid = %s AND attnum = %s" 
     258                                     % (table_OID, col), conn=conn) 
     259                i = geniusql.Index(row[0], self.db.quote(row[0]), tablename, 
    284260                                   d[0][0], bool(row[3])) 
    285261                indices.append(i) 
    286262         
    287263        return indices 
     264     
     265    def columnclause(self, column): 
     266        """Return a clause for the given column for CREATE or ALTER TABLE. 
     267         
     268        This will be of the form "name type [DEFAULT [x | nextval('seq')]]". 
     269         
     270        PostgreSQL creates the sequence in a separate statement. 
     271        """ 
     272        if column.autoincrement: 
     273            default = "nextval('%s')" % column.sequence_name 
     274        else: 
     275            default = column.default or "" 
     276            if not isinstance(default, str): 
     277                default = self.db.adaptertosql.coerce(default, column.dbtype) 
     278         
     279        if default: 
     280            default = " DEFAULT %s" % default 
     281         
     282        return '%s %s%s' % (column.qname, column.dbtype, default) 
     283     
     284    def create_sequence(self, table, column): 
     285        """Create a SEQUENCE for the given column and set its sequence_name.""" 
     286        sname = column.sequence_name 
     287        if sname is None: 
     288            sname = self.db.quote("%s_%s_seq" % (table.name, column.name)) 
     289            column.sequence_name = sname 
     290        self.db.execute_ddl("CREATE SEQUENCE %s START %s;" % 
     291                            (sname, column.initial)) 
     292     
     293    def drop_sequence(self, column): 
     294        """Drop a SEQUENCE for the given column and remove its sequence_name.""" 
     295        if column.sequence_name is not None: 
     296            self.db.execute_ddl("DROP SEQUENCE %s;" % column.sequence_name) 
     297            column.sequence_name = None 
     298     
     299    def create_database(self): 
     300        c = self.db.connections._get_conn(master=True) 
     301        encoding = self.db.encoding 
     302        if encoding: 
     303            encoding = " WITH ENCODING '%s'" % encoding 
     304        self.db.execute_ddl("CREATE DATABASE %s%s" % (self.qname, encoding), c) 
     305        self.db.connections._del_conn(c) 
     306        del c 
     307        self.clear() 
     308     
     309    def drop_database(self): 
     310        c = self.db.connections._get_conn(master=True) 
     311        self.db.execute_ddl("DROP DATABASE %s;" % self.qname, c) 
     312        self.db.connections._del_conn(c) 
     313        del c 
     314        self.clear() 
     315 
     316 
     317class PgDatabase(geniusql.Database): 
     318     
     319    sql_name_max_length = 63 
     320    quote_all = True 
     321    poolsize = 10 
     322    encoding = 'SQL_ASCII' 
     323     
     324    adaptertosql = AdapterToPgSQL() 
     325    adapterfromdb = AdapterFromPg() 
     326     
     327    decompiler = PgDecompiler 
     328    schemaclass = PgSchema 
     329     
     330    def __init__(self, **kwargs): 
     331        geniusql.Database.__init__(self, **kwargs) 
     332        self.connections.Connect = self.Connect 
    288333     
    289334    def python_type(self, dbtype): 
     
    318363                        "to a Python type." % dbtype) 
    319364     
    320     def columnclause(self, column): 
    321         """Return a clause for the given column for CREATE or ALTER TABLE. 
    322          
    323         This will be of the form "name type [DEFAULT [x | nextval('seq')]]". 
    324          
    325         PostgreSQL creates the sequence in a separate statement. 
    326         """ 
    327         if column.autoincrement: 
    328             default = "nextval('%s')" % column.sequence_name 
    329         else: 
    330             default = column.default or "" 
    331             if not isinstance(default, str): 
    332                 default = self.adaptertosql.coerce(default, column.dbtype) 
    333          
    334         if default: 
    335             default = " DEFAULT %s" % default 
    336          
    337         return '%s %s%s' % (column.qname, column.dbtype, default) 
    338      
    339     def create_sequence(self, table, column): 
    340         """Create a SEQUENCE for the given column and set its sequence_name.""" 
    341         sname = column.sequence_name 
    342         if sname is None: 
    343             sname = self.quote("%s_%s_seq" % (table.name, column.name)) 
    344             column.sequence_name = sname 
    345         self.execute_ddl("CREATE SEQUENCE %s START %s;" % (sname, column.initial)) 
    346      
    347     def drop_sequence(self, column): 
    348         """Drop a SEQUENCE for the given column and remove its sequence_name.""" 
    349         if column.sequence_name is not None: 
    350             self.execute_ddl("DROP SEQUENCE %s;" % column.sequence_name) 
    351             column.sequence_name = None 
    352      
    353365    def quote(self, name): 
    354366        if self.quote_all: 
     
    361373            name = name.lower() 
    362374        return name 
    363      
    364     def create_database(self): 
    365         c = self.connections._get_conn(master=True) 
    366         encoding = self.encoding 
    367         if encoding: 
    368             encoding = " WITH ENCODING '%s'" % encoding 
    369         self.execute_ddl("CREATE DATABASE %s%s" % (self.qname, encoding), c) 
    370         self.connections._del_conn(c) 
    371         del c 
    372         self.clear() 
    373      
    374     def drop_database(self): 
    375         c = self.connections._get_conn(master=True) 
    376         self.execute_ddl("DROP DATABASE %s;" % self.qname, c) 
    377         self.connections._del_conn(c) 
    378         del c 
    379         self.clear() 
    380  
     375 
  • trunk/geniusql/providers/psycopg.py

    r10 r11  
    5858 
    5959 
    60 class PsycoPgDatabase(postgres.PgDatabase): 
    61      
    62     adapterfromdb = AdapterFromPsycoPg() 
    63     connectionmanager = PsycoPgConnectionManager 
     60class PsycoPgSchema(postgres.PgSchema): 
    6461     
    6562    def _get_dbinfo(self, conn=None): 
    6663        dbinfo = {} 
    6764        try: 
    68             data, _ = self.fetch("SELECT pg_encoding_to_char(encoding) " 
    69                                  "FROM pg_database;", conn=conn) 
     65            data, _ = self.db.fetch("SELECT pg_encoding_to_char(encoding) " 
     66                                    "FROM pg_database;", conn=conn) 
    7067            dbinfo['encoding'] = data[0][0] 
    7168        except _psycopg.DatabaseError, x: 
     
    7370                raise 
    7471        return dbinfo 
     72 
     73 
     74class PsycoPgDatabase(postgres.PgDatabase): 
     75     
     76    adapterfromdb = AdapterFromPsycoPg() 
     77    connectionmanager = PsycoPgConnectionManager 
     78    schemaclass = PsycoPgSchema 
    7579     
    7680    def version(self): 
  • trunk/geniusql/providers/pypgsql.py

    r10 r11  
    3737 
    3838 
    39 class PyPgDatabase(postgres.PgDatabase): 
    40      
    41     connectionmanager = PyPgConnectionManager 
     39class PyPgSchema(postgres.PgSchema): 
    4240     
    4341    def _get_dbinfo(self, conn=None): 
    4442        dbinfo = {} 
    4543        try: 
    46             data, _ = self.fetch("SELECT pg_encoding_to_char(encoding) " 
    47                                  "FROM pg_database;", conn=conn) 
     44            data, _ = self.db.fetch("SELECT pg_encoding_to_char(encoding) " 
     45                                    "FROM pg_database;", conn=conn) 
    4846            dbinfo['encoding'] = data[0][0] 
    4947        except libpq.DatabaseError, x: 
     
    5149                raise 
    5250        return dbinfo 
     51 
     52 
     53class PyPgDatabase(postgres.PgDatabase): 
     54     
     55    connectionmanager = PyPgConnectionManager 
     56    schemaclass = PyPgSchema 
    5357     
    5458    def fetch(self, query, conn=None): 
  • trunk/geniusql/providers/sqlite.py

    r10 r11  
    278278    def _parent_key(self): 
    279279        """Return the key of this Table in its parent Database.""" 
    280         names = [x for x in self.db if self.db[x].name == self.name] 
     280        names = [x for x in self.schema if self.schema[x].name == self.name] 
    281281        return names[0] 
    282282     
     
    285285        temptable = self.copy() 
    286286        tempkey = "temp_" + self._parent_key() 
    287         temptable.name = self.db.table_name(tempkey) 
    288         temptable.qname = self.db.quote(temptable.name) 
     287        temptable.name = self.schema.table_name(tempkey) 
     288        temptable.qname = self.schema.db.quote(temptable.name) 
    289289        temptable.indices.clear() 
    290290        return tempkey, temptable 
     
    297297        newtable.name = self.name 
    298298        newtable.qname = self.qname 
    299         self.db[thiskey] = newtable 
     299        self.schema[thiskey] = newtable 
    300300         
    301301        # Copy data from the temp table to the final table. 
     
    303303        # mixes up the fields (during rename, at least). 
    304304        selfields = ", ".join([c.qname for c in temptable.values()]) 
    305         self.db.execute("INSERT INTO %s (%s) SELECT %s FROM %s;" % 
    306                         (newtable.qname, selfields, selfields, 
    307                          temptable.qname)) 
     305        self.schema.db.execute("INSERT INTO %s (%s) SELECT %s FROM %s;" % 
     306                               (newtable.qname, selfields, selfields, 
     307                                temptable.qname)) 
    308308         
    309309        # Drop the intermediate table. 
    310         del self.db[tempkey] 
     310        del self.schema[tempkey] 
    311311     
    312312    if not _add_column_support: 
     
    319319                del self[key] 
    320320             
    321             db = self.db 
    322321            if column.autoincrement: 
    323322                # This may or may not be a no-op, depending on the DB. 
    324                 db.create_sequence(self, column) 
     323                self.schema.create_sequence(self, column) 
    325324             
    326325            # Make a temporary copy. 
     
    329328            dict.__setitem__(temptable, key, column) 
    330329            # Bind the temp table to the DB. 
    331             db[tempkey] = temptable 
     330            self.schema[tempkey] = temptable 
    332331             
    333332            # Copy data from the old table to the temp table. 
     
    339338                    qname = "NULL AS %s" % qname 
    340339                selfields.append(qname) 
    341             db.execute_ddl("INSERT INTO %s SELECT %s FROM %s;" % 
    342                            (temptable.qname, ", ".join(selfields), self.qname)) 
     340            self.schema.db.execute_ddl("INSERT INTO %s SELECT %s FROM %s;" % 
     341                                       (temptable.qname, ", ".join(selfields), 
     342                                        self.qname)) 
    343343             
    344344            # Copy data from the temp table to a new table for self. 
     
    353353            return 
    354354         
    355         db = self.db 
    356355        column = self[key] 
    357356         
     
    361360        dict.__delitem__(temptable, key) 
    362361        # Bind the temp table to the DB. 
    363         db[tempkey] = temptable 
     362        self.schema[tempkey] = temptable 
    364363         
    365364        # Copy data from the old table to the temp table. 
     
    368367            qname = c.qname 
    369368            selfields.append(qname) 
    370         db.execute_ddl("INSERT INTO %s SELECT %s FROM %s;" % 
    371                        (temptable.qname, ", ".join(selfields), self.qname)) 
     369        self.schema.db.execute_ddl("INSERT INTO %s SELECT %s FROM %s;" % 
     370                                   (temptable.qname, ", ".join(selfields), 
     371                                    self.qname)) 
    372372         
    373373        self._copy_from_temp(temptable, self._parent_key(), tempkey) 
     
    375375        if column.autoincrement: 
    376376            # This may or may not be a no-op, depending on the DB. 
    377             db.drop_sequence(column) 
     377            self.schema.drop_sequence(column) 
    378378     
    379379    def rename(self, oldkey, newkey): 
     
    381381        oldcol = self[oldkey] 
    382382        oldname = oldcol.name 
    383         db = self.db 
    384         newname = db._column_name(self.name, newkey) 
     383        newname = self.schema._column_name(self.name, newkey) 
    385384         
    386385        if oldname != newname: 
     
    388387            dict.__setitem__(self, newkey, oldcol) 
    389388            oldcol.name = newname 
    390             oldcol.qname = db.quote(newname) 
     389            oldcol.qname = self.schema.db.quote(newname) 
    391390             
    392391            # Make a temporary copy. 
    393392            tempkey, temptable = self._temp_copy() 
    394393            # Bind the temp table to the DB. 
    395             db[tempkey] = temptable 
     394            self.schema[tempkey] = temptable 
    396395             
    397396            # Copy data from the old table to the temp table. 
     
    400399                qname = c.qname 
    401400                if k == newkey: 
    402                     qname = "%s AS %s" % (db.quote(oldname), qname) 
     401                    qname = "%s AS %s" % (self.schema.db.quote(oldname), qname) 
    403402                selfields.append(qname) 
    404             db.execute_ddl("INSERT INTO %s SELECT %s FROM %s;" % 
    405                            (temptable.qname, ", ".join(selfields), self.qname)) 
     403            self.schema.db.execute_ddl("INSERT INTO %s SELECT %s FROM %s;" % 
     404                                       (temptable.qname, ", ".join(selfields), 
     405                                        self.qname)) 
    406406             
    407407            self._copy_from_temp(temptable, self._parent_key(), tempkey) 
     
    482482        self.db = db 
    483483        self.poolsize = poolsize 
     484        # Can't set up the factory until we have a schema. 
     485     
     486    def _set_factory(self): 
    484487        if self.db.name == ":memory:": 
    485488            # "Multiple connections to ":memory:" within a single process 
     
    489492            self._factory = geniusql.SingleConnection(self._get_conn, self._del_conn) 
    490493        else: 
    491             return geniusql.ConnectionManager.__init__(self, db, poolsize
     494            geniusql.ConnectionManager._set_factory(self
    492495     
    493496    if _cursor_required: 
     
    557560 
    558561 
    559 class SQLiteDatabase(geniusql.Database): 
    560      
    561     sql_name_max_length = 0 
    562      
    563     decompiler = SQLiteDecompiler 
    564     selectwriter = SQLiteSelectWriter 
    565     adaptertosql = AdapterToSQLite() 
    566     adapterfromdb = AdapterFromSQLite() 
    567     typeadapter = TypeAdapterSQLite() 
    568     connectionmanager = SQLiteConnectionManager 
     562class SQLiteSchema(geniusql.Schema): 
    569563     
    570564    tableclass = SQLiteTable 
    571565     
    572     def __init__(self, name, **kwargs): 
     566    def __init__(self, db, name): 
    573567        if name != ':memory:': 
    574568            if not os.path.isabs(name): 
    575569                name = os.path.join(os.getcwd(), name) 
    576         kwargs['mode'] = int(kwargs.pop('mode', '0755'), 8) 
    577         geniusql.Database.__init__(self, name, **kwargs) 
    578      
    579     def _get_upd(self): 
    580         return self.adaptertosql.using_perfect_dates 
    581     def _set_upd(self, value): 
    582         self.adaptertosql.using_perfect_dates = value 
    583     using_perfect_dates = property(_get_upd, _set_upd) 
    584      
    585     def isrelatedtype(self, pytype1, pytype2): 
    586         if (self.using_perfect_dates and 
    587             issubclass(pytype1, (datetime.date, datetime.time, datetime.datetime)) and 
    588             issubclass(pytype2, self.python_type(self.typeadapter.coerce(None, pytype1)))): 
    589             return True 
    590         return geniusql.Database.isrelatedtype(self, pytype1, pytype2) 
     570        geniusql.Schema.__init__(self, db, name) 
     571        self.db.name = self.name 
    591572     
    592573    def _get_tables(self, conn=None): 
    593         data, _ = self.fetch("SELECT name FROM sqlite_master " 
    594                              "WHERE type = 'table';", conn) 
     574        data, _ = self.db.fetch("SELECT name FROM sqlite_master " 
     575                                "WHERE type = 'table';", conn) 
    595576        # Note that we set Table.created here, since these already exist in the DB. 
    596         return [self.tableclass(row[0], self.quote(row[0]), 
     577        return [self.tableclass(row[0], self.db.quote(row[0]), 
    597578                                self, created=True) 
    598579                for row in data] 
    599580     
    600581    def _get_table(self, tablename, conn=None): 
    601         data, _ = self.fetch("SELECT name FROM sqlite_master WHERE name = " 
    602                              "'%s' AND type = 'table';" % tablename, conn) 
     582        data, _ = self.db.fetch("SELECT name FROM sqlite_master WHERE name = " 
     583                                "'%s' AND type = 'table';" % tablename, conn) 
    603584        # Note that we set Table.created here, since these already exist in the DB. 
    604585        for name, in data: 
    605586            if name == tablename: 
    606                 return self.tableclass(name, self.quote(name), 
     587                return self.tableclass(name, self.db.quote(name), 
    607588                                       self, created=True) 
    608589        raise errors.MappingError(tablename) 
     
    610591    def _get_columns(self, tablename, conn=None): 
    611592        # cid, name, type, notnull, dflt_value, pk 
    612         data, _ = self.fetch("PRAGMA table_info(%s);" % tablename, conn=conn) 
     593        data, _ = self.db.fetch("PRAGMA table_info(%s);" % tablename, 
     594                                conn=conn) 
    613595         
    614596        cols = [] 
    615597        for row in data: 
    616598            dbtype = row[2].upper() 
    617             c = geniusql.Column(self.python_type(dbtype), dbtype, 
     599            c = geniusql.Column(self.db.python_type(dbtype), dbtype, 
    618600                                row[4], {}, bool(row[5]), 
    619                                 row[1], self.quote(row[1])) 
     601                                row[1], self.db.quote(row[1])) 
    620602             
    621603            # "A single row can hold up to 2 ** 30 bytes of data 
     
    647629     
    648630    def _get_indices(self, tablename, conn=None): 
    649         data, _ = self.fetch("SELECT name, tbl_name, sql FROM sqlite_master " 
    650                              "WHERE type = 'index';", conn) 
     631        data, _ = self.db.fetch( 
     632            "SELECT name, tbl_name, sql FROM sqlite_master " 
     633            "WHERE type = 'index';", conn) 
     634         
    651635        indices = [] 
    652636        for row in data: 
    653637            if row[2]: 
    654638                colname = row[2].split("(")[-1] 
    655                 i = geniusql.Index(row[0], self.quote(row[0]), 
     639                i = geniusql.Index(row[0], self.db.quote(row[0]), 
    656640                                   row[1], colname[1:-2]) 
    657641                indices.append(i) 
    658642        return indices 
     643     
     644    def create_sequence(self, table, column): 
     645        """Create a SEQUENCE for the given column and set its sequence_name.""" 
     646        sname = column.sequence_name 
     647        if sname is None: 
     648            column.sequence_name = sname = table.name 
     649         
     650        # SQLite AUTOINCREMENT columns start at 1 by default. 
     651        # Manhandle the special SQLITE_SEQUENCE table to include 
     652        # the value of sequencer.initial - 1. 
     653        prev = column.initial - 1 
     654        data, coldefs = self.db.fetch("SELECT * FROM SQLITE_SEQUENCE " 
     655                                      "WHERE name = '%s';" % sname) 
     656        if data: 
     657            self.db.execute("UPDATE SQLITE_SEQUENCE SET seq = %s " 
     658                            "WHERE name = '%s';" % (prev, sname)) 
     659        else: 
     660            self.db.execute("INSERT INTO SQLITE_SEQUENCE (seq, name) " 
     661                            "VALUES (%s, '%s');" % (prev, sname)) 
     662     
     663    def drop_sequence(self, column): 
     664        """Drop a SEQUENCE for the given column and remove its sequence_name.""" 
     665        if column.sequence_name is not None: 
     666            self.db.execute("DELETE FROM SQLITE_SEQUENCE WHERE name = '%s';" 
     667                            % column.sequence_name) 
     668            column.sequence_name = None 
     669     
     670    def columnclause(self, column): 
     671        """Return a clause for the given column for CREATE or ALTER TABLE. 
     672         
     673        This will be of the form: 
     674            "name type [DEFAULT x | PRIMARY KEY AUTOINCREMENT]" 
     675        """ 
     676        dbtype = coldef = column.dbtype 
     677         
     678        if column.autoincrement: 
     679            coldef = "INTEGER PRIMARY KEY AUTOINCREMENT" 
     680        else: 
     681            default = column.default or "" 
     682            if not isinstance(default, str): 
     683                default = self.db.adaptertosql.coerce(default, dbtype) 
     684            if default: 
     685                coldef += " DEFAULT %s" % default 
     686         
     687        return '%s %s' % (column.qname, coldef) 
     688     
     689    def __setitem__(self, key, table): 
     690        if key in self: 
     691            del self[key] 
     692         
     693        # Set table.created to True, which should "turn on" 
     694        # any future ALTER TABLE statements. 
     695        table.created = True 
     696         
     697        fields = [] 
     698        pk = [] 
     699        autoincr_col = None 
     700        for col in table.itervalues(): 
     701            fields.append(self.columnclause(col)) 
     702             
     703            if col.autoincrement: 
     704                # MUST create the sequence after the table is created, 
     705                # or we get into a "no such table" loop inside execute. 
     706                autoincr_col = col 
     707             
     708            if col.key: 
     709                pk.append(col.qname) 
     710         
     711        if (autoincr_col is None) and pk: 
     712            # Seems we can't have both an AUTOINCREMENT and another PK 
     713            pk = ", PRIMARY KEY (%s)" % ", ".join(pk) 
     714        else: 
     715            pk = "" 
     716         
     717        self.db.execute_ddl('CREATE TABLE %s (%s%s);' % 
     718                            (table.qname, ", ".join(fields), pk)) 
     719         
     720        for index in table.indices.itervalues(): 
     721            self.db.execute_ddl('CREATE INDEX %s ON %s (%s);' % 
     722                                (index.qname, table.qname, 
     723                                 self.db.quote(index.colname))) 
     724         
     725        if autoincr_col: 
     726            self.create_sequence(table, autoincr_col) 
     727         
     728        dict.__setitem__(self, key, table) 
     729     
     730    def _rename(self, oldtable, newtable): 
     731        if _rename_table_support: 
     732            self.db.execute_ddl("ALTER TABLE %s RENAME TO %s" % 
     733                                (oldtable.qname, newtable.qname)) 
     734        else: 
     735            raise NotImplementedError 
     736     
     737    def create_database(self): 
     738        self.db.connections.get() 
     739     
     740    def drop_database(self): 
     741        self.db.connections.shutdown() 
     742        if self.name != ":memory:": 
     743            # This should accept relative or absolute paths 
     744            os.remove(self.name) 
     745        self.clear() 
     746 
     747 
     748class SQLiteDatabase(geniusql.Database): 
     749     
     750    sql_name_max_length = 0 
     751     
     752    decompiler = SQLiteDecompiler 
     753    selectwriter = SQLiteSelectWriter 
     754    adaptertosql = AdapterToSQLite() 
     755    adapterfromdb = AdapterFromSQLite() 
     756    typeadapter = TypeAdapterSQLite() 
     757    connectionmanager = SQLiteConnectionManager 
     758     
     759    schemaclass = SQLiteSchema 
     760    multischema = False 
     761     
     762    def __init__(self, **kwargs): 
     763        kwargs['mode'] = int(kwargs.pop('mode', '0755'), 8) 
     764        geniusql.Database.__init__(self, **kwargs) 
     765     
     766    def _get_upd(self): 
     767        return self.adaptertosql.using_perfect_dates 
     768    def _set_upd(self, value): 
     769        self.adaptertosql.using_perfect_dates = value 
     770    using_perfect_dates = property(_get_upd, _set_upd) 
     771     
     772    def isrelatedtype(self, pytype1, pytype2): 
     773        if (self.using_perfect_dates and 
     774            issubclass(pytype1, (datetime.date, datetime.time, datetime.datetime)) and 
     775            issubclass(pytype2, self.python_type(self.typeadapter.coerce(None, pytype1)))): 
     776            return True 
     777        return geniusql.Database.isrelatedtype(self, pytype1, pytype2) 
    659778     
    660779    def python_type(self, dbtype): 
     
    677796        return str 
    678797     
    679     def create_sequence(self, table, column): 
    680         """Create a SEQUENCE for the given column and set its sequence_name.""" 
    681         sname = column.sequence_name 
    682         if sname is None: 
    683             column.sequence_name = sname = table.name 
    684          
    685         # SQLite AUTOINCREMENT columns start at 1 by default. 
    686         # Manhandle the special SQLITE_SEQUENCE table to include 
    687         # the value of sequencer.initial - 1. 
    688         prev = column.initial - 1 
    689         data, coldefs = self.fetch("SELECT * FROM SQLITE_SEQUENCE " 
    690                                    "WHERE name = '%s';" % sname) 
    691         if data: 
    692             self.execute("UPDATE SQLITE_SEQUENCE SET seq = %s " 
    693                          "WHERE name = '%s';" % (prev, sname)) 
    694         else: 
    695             self.execute("INSERT INTO SQLITE_SEQUENCE (seq, name) " 
    696                          "VALUES (%s, '%s');" % (prev, sname)) 
    697      
    698     def drop_sequence(self, column): 
    699         """Drop a SEQUENCE for the given column and remove its sequence_name.""" 
    700         if column.sequence_name is not None: 
    701             self.execute("DELETE FROM SQLITE_SEQUENCE WHERE name = '%s';" 
    702                          % column.sequence_name) 
    703             column.sequence_name = None 
    704      
    705     def columnclause(self, column): 
    706         """Return a clause for the given column for CREATE or ALTER TABLE. 
    707          
    708         This will be of the form: 
    709             "name type [DEFAULT x | PRIMARY KEY AUTOINCREMENT]" 
    710         """ 
    711         dbtype = coldef = column.dbtype 
    712          
    713         if column.autoincrement: 
    714             coldef = "INTEGER PRIMARY KEY AUTOINCREMENT" 
    715         else: 
    716             default = column.default or "" 
    717             if not isinstance(default, str): 
    718                 default = self.adaptertosql.coerce(default, dbtype) 
    719             if default: 
    720                 coldef += " DEFAULT %s" % default 
    721          
    722         return '%s %s' % (column.qname, coldef) 
    723      
    724     def __setitem__(self, key, table): 
    725         if key in self: 
    726             del self[key] 
    727          
    728         # Set table.created to True, which should "turn on" 
    729         # any future ALTER TABLE statements. 
    730         table.created = True 
    731          
    732         fields = [] 
    733         pk = [] 
    734         autoincr_col = None 
    735         for col in table.itervalues(): 
    736             fields.append(self.columnclause(col)) 
    737              
    738             if col.autoincrement: 
    739                 # MUST create the sequence after the table is created, 
    740                 # or we get into a "no such table" loop inside execute. 
    741                 autoincr_col = col 
    742              
    743             if col.key: 
    744                 pk.append(col.qname) 
    745          
    746         if (autoincr_col is None) and pk: 
    747             # Seems we can't have both an AUTOINCREMENT and another PK 
    748             pk = ", PRIMARY KEY (%s)" % ", ".join(pk) 
    749         else: 
    750             pk = "" 
    751          
    752         self.execute_ddl('CREATE TABLE %s (%s%s);' % 
    753                          (table.qname, ", ".join(fields), pk)) 
    754          
    755         for index in table.indices.itervalues(): 
    756             self.execute_ddl('CREATE INDEX %s ON %s (%s);' % 
    757                              (index.qname, table.qname, 
    758                               self.quote(index.colname))) 
    759          
    760         if autoincr_col: 
    761             self.create_sequence(table, autoincr_col) 
    762          
    763         dict.__setitem__(self, key, table) 
    764      
    765     def _rename(self, oldtable, newtable): 
    766         if _rename_table_support: 
    767             self.execute_ddl("ALTER TABLE %s RENAME TO %s" % 
    768                              (oldtable.qname, newtable.qname)) 
    769         else: 
    770             raise NotImplementedError 
    771      
    772798    def quote(self, name): 
    773799        """Return name, quoted for use in an SQL statement. 
     
    785811        """ 
    786812        return "[" + name + "]" 
    787      
    788     def create_database(self): 
    789         self.connections.get() 
    790      
    791     def drop_database(self): 
    792         self.connections.shutdown() 
    793         if self.name != ":memory:": 
    794             # This should accept relative or absolute paths 
    795             os.remove(self.name) 
    796         self.clear() 
    797813     
    798814    def is_timeout_error(self, exc): 
     
    853869    def version(self): 
    854870        return "SQLite Version: %s" % _version 
    855  
     871     
     872    def schema(self, name): 
     873        s = self.schemaclass(self, name) 
     874        self.name = name 
     875        self.connections._set_factory() 
     876        return s 
     877 
  • trunk/geniusql/select.py

    r10 r11  
    576576        return decom.code(), decom.imperfect 
    577577     
    578     def quote(self, name): 
    579         """Return the given table name, quoted for SQL.""" 
    580         return self.db.quote(self.db.table_name(name)) 
    581      
    582578    def wrap(self, join): 
    583579        """Return the given Join with each node wrapped.""" 
     
    590586            if t1.name in self.seen: 
    591587                self.aliascount += 1 
    592                 wt1.alias = self.quote("t%d" % self.aliascount) 
     588                alias = "t%d" % self.aliascount 
     589                wt1.alias = self.db.quote(t1.schema.table_name(alias)) 
    593590            else: 
    594591                self.seen[t1.name] = None 
     
    600597            if t2.name in self.seen: 
    601598                self.aliascount += 1 
    602                 wt2.alias = self.quote("t%d" % self.aliascount) 
     599                alias = "t%d" % self.aliascount 
     600                wt2.alias = self.db.quote(t2.schema.table_name(alias)) 
    603601            else: 
    604602                self.seen[t2.name] = None 
     
    624622        """ 
    625623        if path is None: 
    626             db = B.table.db 
    627             for tkey in db
    628                 if db[tkey] is B.table: 
     624            schema = B.table.schema 
     625            for tkey in schema
     626                if schema[tkey] is B.table: 
    629627                    path = tkey 
    630628                    break 
  • trunk/geniusql/test/test_msaccess.py

    r10 r11  
    3030             
    3131        DB_class = "access" 
    32         opts = {u'Connect': "PROVIDER=MICROSOFT.JET.OLEDB.4.0;" 
    33                             "DATA SOURCE=zoo.mdb;", 
     32        opts = {'Connect': "PROVIDER=MICROSOFT.JET.OLEDB.4.0;" 
     33                           "DATA SOURCE=zoo.mdb;", 
    3434                } 
    3535         
  • trunk/geniusql/test/test_psycopg.py

    r10 r11  
    1919        passwd = raw_input("Enter the password for the PostgreSQL '%s' user:" % user) 
    2020     
    21     opts = {u'Connect': ("host=localhost dbname=geniusql_test " 
    22                         "user=%s password=%s" % (user, passwd)), 
     21    opts = {'Connect': ("host=localhost dbname=geniusql_test " 
     22                        "user=%s password=%s" % (user, passwd)), 
    2323            } 
    2424    DB_class = "psycopg" 
  • trunk/geniusql/test/test_pypgsql.py

    r10 r11  
    1414        passwd = raw_input("Enter the password for the PostgreSQL '%s' user:" % user) 
    1515     
    16     opts = {u'Connect': ("host=localhost dbname=geniusql_test " 
     16    opts = {'Connect': ("host=localhost dbname=geniusql_test " 
    1717                         "user=%s password=%s" % (user, passwd)), 
    1818            } 
  • trunk/geniusql/test/test_sqlserver.py

    r10 r11  
    1616                          "The SQLServer test will not be run.") 
    1717    else: 
    18         opts = {u'Connect': ("Provider=SQLOLEDB.1; Integrated Security=SSPI; " 
    19                             "Initial Catalog=geniusql_test; " 
    20                             "Data Source=(local)"), 
     18        opts = {'Connect': ("Provider=SQLOLEDB.1; Integrated Security=SSPI; " 
     19                            "Initial Catalog=geniusql_test; " 
     20                            "Data Source=(local)"), 
    2121                # Shorten the transaction deadlock timeout. 
    2222                # You may need to adjust this for your system. 
    23                 u'CommandTimeout': 10, 
     23                'CommandTimeout': 10, 
    2424                } 
    2525        DB_class = "sqlserver" 
  • trunk/geniusql/test/zoo_fixture.py

    r10 r11  
    4141     
    4242    def test_1_create_tables(self): 
    43         Animal = db.table('Animal') 
    44         Animal['ID'] = db.column(int, autoincrement=True, key=True) 
     43        Animal = schema.table('Animal') 
     44        Animal['ID'] = schema.column(int, autoincrement=True, key=True) 
    4545        Animal.add_index('ID') 
    46         Animal['ZooID'] = db.column(int) 
    47         Animal['Species'] = db.column(hints={'bytes': 100}) 
    48         Animal['Legs'] = db.column(int, default=4) 
    49         Animal['PreviousZoos'] = db.column(list, hints={'bytes': 8000}) 
    50         Animal['LastEscape'] = db.column(datetime.datetime) 
    51         Animal['Lifespan'] = db.column(float, hints={'precision': 4}) 
    52         Animal['Age'] = db.column(float, default=1, hints={'precision': 4}) 
    53         Animal['MotherID'] = db.column(int) 
    54         Animal['PreferredFoodID'] = db.column(int) 
    55         Animal['AlternateFoodID'] = db.column(int) 
     46        Animal['ZooID'] = schema.column(int) 
     47        Animal['Species'] = schema.column(hints={'bytes': 100}) 
     48        Animal['Legs'] = schema.column(int, default=4) 
     49        Animal['PreviousZoos'] = schema.column(list, hints={'bytes': 8000}) 
     50        Animal['LastEscape'] = schema.column(datetime.datetime) 
     51        Animal['Lifespan'] = schema.column(float, hints={'precision': 4}) 
     52        Animal['Age'] = schema.column(float, default=1, hints={'precision': 4}) 
     53        Animal['MotherID'] = schema.column(int) 
     54        Animal['PreferredFoodID'] = schema.column(int) 
     55        Animal['AlternateFoodID'] = schema.column(int) 
    5656        Animal.add_index('ZooID') 
    5757        Animal.references['Animal'] = ('ID', 'Animal', 'MotherID') 
    5858        Animal.references['Visit'] = ('ID', 'Visit', 'AnimalID') 
    59         db['Animal'] = Animal 
    60          
    61         Zoo = db.table('Zoo') 
    62         Zoo['ID'] = db.column(int, autoincrement=True, key=True) 
     59        schema['Animal'] = Animal 
     60         
     61        Zoo = schema.table('Zoo') 
     62        Zoo['ID'] = schema.column(int, autoincrement=True, key=True) 
    6363        Zoo.add_index('ID') 
    64         Zoo['Name'] = db.column() 
    65         Zoo['Founded'] = db.column(datetime.date) 
    66         Zoo['Opens'] = db.column(datetime.time) 
    67         Zoo['LastEscape'] = db.column(datetime.datetime) 
     64        Zoo['Name'] = schema.column() 
     65        Zoo['Founded'] = schema.column(datetime.date) 
     66        Zoo['Opens'] = schema.column(datetime.time) 
     67        Zoo['LastEscape'] = schema.column(datetime.datetime) 
    6868         
    6969        if typerefs.fixedpoint: 
    7070            # Explicitly set precision and scale so test_msaccess 
    7171            # can test CURRENCY type 
    72             Zoo['Admission'] = db.column(typerefs.fixedpoint.FixedPoint, 
     72            Zoo['Admission'] = schema.column(typerefs.fixedpoint.FixedPoint, 
    7373                                         hints={'precision': 4, 'scale': 2}) 
    7474        else: 
    75             Zoo['Admission'] = db.column(float) 
     75            Zoo['Admission'] = schema.column(float) 
    7676         
    7777        Zoo.references['Animal'] = ('ID', 'Animal', 'ZooID') 
    78         db['Zoo'] = Zoo 
    79          
    80         Food = db.table('Food') 
    81         Food['ID'] = db.column(int, autoincrement=True, key=True) 
     78        schema['Zoo'] = Zoo 
     79         
     80        Food = schema.table('Food') 
     81        Food['ID'] = schema.column(int, autoincrement=True, key=True) 
    8282        Food.add_index('ID') 
    83         Food['Name'] = db.column() 
    84         Food['NutritionValue'] = db.column(int) 
     83        Food['Name'] = schema.column() 
     84        Food['NutritionValue'] = schema.column(int) 
    8585        Food.references['Animal'] = ('ID', 'Animal', 'PreferredFoodID') 
    8686        Animal.references['Alternate Food'] = ('AlternateFoodID', 'Food', 'ID') 
    87         db['Food'] = Food 
    88          
    89         Vet = db.table('Vet') 
    90         Vet['ID'] = c = db.column(int, autoincrement=True, key=True) 
     87        schema['Food'] = Food 
     88         
     89        Vet = schema.table('Vet') 
     90        Vet['ID'] = c = schema.column(int, autoincrement=True, key=True) 
    9191        c.initial = 200 
    9292        Vet.add_index('ID') 
    93         Vet['Name'] = db.column() 
    94         Vet['ZooID'] = db.column(int) 
     93        Vet['Name'] = schema.column() 
     94        Vet['ZooID'] = schema.column(int) 
    9595        Vet.add_index('ZooID') 
    96         Vet['City'] = db.column() 
     96        Vet['City'] = schema.column() 
    9797        Vet.references['Zoo'] = ('ZooID', 'Zoo', 'ID') 
    9898        Vet.references['Visit'] = ('ID', 'Visit', 'VetID') 
    99         db['Vet'] = Vet 
    100          
    101         Visit = db.table('Visit') 
    102         Visit['ID'] = db.column(int, autoincrement=True, key=True) 
     99        schema['Vet'] = Vet 
     100         
     101        Visit = schema.table('Visit') 
     102        Visit['ID'] = schema.column(int, autoincrement=True, key=True) 
    103103        Visit.add_index('ID') 
    104         Visit['VetID'] = db.column(int) 
     104        Visit['VetID'] = schema.column(int) 
    105105        Visit.add_index('VetID') 
    106         Visit['ZooID'] = db.column(int) 
     106        Visit['ZooID'] = schema.column(int) 
    107107        Visit.add_index('ZooID') 
    108         Visit['AnimalID'] = db.column(int) 
     108        Visit['AnimalID'] = schema.column(int) 
    109109        Visit.add_index('AnimalID') 
    110         Visit['Date'] = db.column(datetime.date) 
    111         db['Visit'] = Visit 
    112          
    113         Exhibit = db.table('Exhibit') 
     110        Visit['Date'] = schema.column(datetime.date) 
     111        schema['Visit'] = Visit 
     112         
     113        Exhibit = schema.table('Exhibit') 
    114114        # Make this a string to help test vs unicode. 
    115         Exhibit['Name'] = db.column(str, key=True) 
     115        Exhibit['Name'] = schema.column(str, key=True) 
    116116        Exhibit.add_index('Name') 
    117         Exhibit['ZooID'] = db.column(int, key=True) 
     117        Exhibit['ZooID'] = schema.column(int, key=True) 
    118118        Exhibit.add_index('ZooID') 
    119         Exhibit['Animals'] = db.column(list) 
    120         Exhibit['PettingAllowed'] = db.column(bool) 
    121         Exhibit['Creators'] = db.column(tuple) 
     119        Exhibit['Animals'] = schema.column(list) 
     120        Exhibit['PettingAllowed'] = schema.column(bool) 
     121        Exhibit['Creators'] = schema.column(tuple) 
    122122         
    123123        if typerefs.decimal: 
    124             Exhibit['Acreage'] = db.column(typerefs.decimal.Decimal) 
     124            Exhibit['Acreage'] = schema.column(typerefs.decimal.Decimal) 
    125125        else: 
    126             Exhibit['Acreage'] = db.column(float) 
     126            Exhibit['Acreage'] = schema.column(float) 
    127127         
    128128        Exhibit.references['Zoo'] = ('ZooID', 'Zoo', 'ID') 
    129         db['Exhibit'] = Exhibit 
    130          
    131         t = db.table('NothingToDoWithZoos') 
    132         t['ALong'] = db.column(long, hints={'precision': 1}) 
    133         t['AFloat'] = db.column(float, hints={'precision': 1}) 
     129        schema['Exhibit'] = Exhibit 
     130         
     131        t = schema.table('NothingToDoWithZoos') 
     132        t['ALong'] = schema.column(long, hints={'precision': 1}) 
     133        t['AFloat'] = schema.column(float, hints={'precision': 1}) 
    134134        if typerefs.decimal: 
    135             t['ADecimal'] = db.column(typerefs.decimal.Decimal, 
     135            t['ADecimal'] = schema.column(typerefs.decimal.Decimal, 
    136136                                      hints={'precision': 1, 'scale': 1}) 
    137137        if typerefs.fixedpoint: 
    138             t['AFixed'] = db.column(typerefs.fixedpoint.FixedPoint, 
     138            t['AFixed'] = schema.column(typerefs.fixedpoint.FixedPoint, 
    139139                                    hints={'precision': 1, 'scale': 1}) 
    140         db['NothingToDoWithZoos'] = t 
     140        schema['NothingToDoWithZoos'] = t 
    141141     
    142142    def test_2_populate(self): 
    143         wap = db['Zoo'].insert(Name='Wild Animal Park', 
     143        wap = schema['Zoo'].insert(Name='Wild Animal Park', 
    144144                           Founded=datetime.date(2000, 1, 1), 
    145145                           # 59 can give rounding errors with divmod, which 
     
    150150                           )['ID'] 
    151151         
    152         sdz = db['Zoo'].insert(Name = 'San Diego Zoo', 
     152        sdz = schema['Zoo'].insert(Name = 'San Diego Zoo', 
    153153                           # This early date should play havoc with a number 
    154154                           # of implementations. 
     
    158158                           )['ID'] 
    159159         
    160         db['Zoo'].insert(Name = u'Montr\xe9al Biod\xf4me', 
     160        schema['Zoo'].insert(Name = u'Montr\xe9al Biod\xf4me', 
    161161                  Founded = datetime.date(1992, 6, 19), 
    162162                  Opens = datetime.time(9, 0, 0), 
     
    164164                  ) 
    165165         
    166         seaworld = db['Zoo'].insert(Name = 'Sea_World', Admission = 60)['ID'] 
     166        seaworld = schema['Zoo'].insert(Name = 'Sea_World', Admission = 60)['ID'] 
    167167         
    168168        # Animals 
    169         leopardid = db['Animal'].insert(Species='Leopard', Lifespan=73.5)['ID'] 
     169        leopardid = schema['Animal'].insert(Species='Leopard', Lifespan=73.5)['ID'] 
    170170        self.assertEqual(leopardid, 1) 
    171         db['Animal'].save(ID=leopardid, ZooID=wap, 
     171        schema['Animal'].save(ID=leopardid, ZooID=wap, 
    172172                LastEscape=datetime.datetime(2004, 12, 21, 8, 15, 0, 999907)) 
    173173         
    174         lion = db['Animal'].insert(Species='Lion', ZooID=wap)['ID'] 
    175         db['Animal'].insert(Species='Slug', Legs=1, Lifespan=.75, 
     174        lion = schema['Animal'].insert(Species='Lion', ZooID=wap)['ID'] 
     175        schema['Animal'].insert(Species='Slug', Legs=1, Lifespan=.75, 
    176176                  # Test our 8000-byte limit 
    177177                  PreviousZoos=["f" * (8000 - 14)]) 
    178178         
    179         tiger = db['Animal'].insert(Species='Tiger', ZooID=sdz, 
     179        tiger = schema['Animal'].insert(Species='Tiger', ZooID=sdz, 
    180180                                    PreviousZoos=['animal\\universe'])['ID'] 
    181181         
    182182        # Override Legs.default with itself just to make sure it works. 
    183         db['Animal'].insert(Species='Bear', Legs=4) 
     183        schema['Animal'].insert(Species='Bear', Legs=4) 
    184184        # Notice that ostrich.PreviousZoos is [], whereas leopard is None. 
    185         db['Animal'].insert(Species='Ostrich', Legs=2, PreviousZoos=[], 
     185        schema['Animal'].insert(Species='Ostrich', Legs=2, PreviousZoos=[], 
    186186                            Lifespan=103.2) 
    187         db['Animal'].insert(Species='Centipede', Legs=100) 
    188          
    189         emp = db['Animal'].insert(Species='Emperor Penguin', Legs=2, ZooID=seaworld)['ID'] 
    190         adelie = db['Animal'].insert(Species='Adelie Penguin', Legs=2, ZooID=seaworld)['ID'] 
    191          
    192         db['Animal'].insert(Species='Millipede', Legs=1000000, ZooID=sdz, 
     187        schema['Animal'].insert(Species='Centipede', Legs=100) 
     188         
     189        emp = schema['Animal'].insert(Species='Emperor Penguin', Legs=2, ZooID=seaworld)['ID'] 
     190        adelie = schema['Animal'].insert(Species='Adelie Penguin', Legs=2, ZooID=seaworld)['ID'] 
     191         
     192        schema['Animal'].insert(Species='Millipede', Legs=1000000, ZooID=sdz, 
    193193                  PreviousZoos=['Wild Animal Park']) 
    194194         
    195195        # Add a mother and child to test relationships 
    196         bai_yun = db['Animal'].insert(Species='Ape', Legs=2)['ID'] 
    197         db['Animal'].insert(Species='Ape', Legs=2, MotherID=bai_yun) 
     196        bai_yun = schema['Animal'].insert(Species='Ape', Legs=2)['ID'] 
     197        schema['Animal'].insert(Species='Ape', Legs=2, MotherID=bai_yun) 
    198198         
    199199        # Exhibits 
    200         db['Exhibit'].insert(Name = 'The Penguin Encounter', 
     200        schema['Exhibit'].insert(Name = 'The Penguin Encounter', 
    201201                             ZooID = seaworld, 
    202202                             Animals = [emp, adelie], 
     
    207207                             ) 
    208208         
    209         db['Exhibit'].insert(Name = 'Tiger River', 
     209        schema['Exhibit'].insert(Name = 'Tiger River', 
    210210                  ZooID = sdz, 
    211211                  Animals = [tiger], 
     
    215215         
    216216        # Vets 
    217         cs = db['Vet'].insert(Name = 'Charles Schroeder', ZooID = sdz)['ID'] 
    218         self.assertEqual(cs, db['Vet']['ID'].initial) 
    219          
    220         jm = db['Vet'].insert(Name = 'Jim McBain', ZooID = seaworld)['ID'] 
     217        cs = schema['Vet'].insert(Name = 'Charles Schroeder', ZooID = sdz)['ID'] 
     218        self.assertEqual(cs, schema['Vet']['ID'].initial) 
     219         
     220        jm = schema['Vet'].insert(Name = 'Jim McBain', ZooID = seaworld)['ID'] 
    221221         
    222222        # Visits 
    223223        for d in every13days: 
    224             db['Visit'].insert(VetID=cs, AnimalID=tiger, Date=d) 
     224            schema['Visit'].insert(VetID=cs, AnimalID=tiger, Date=d) 
    225225        for d in every17days: 
    226             db['Visit'].insert(VetID=jm, AnimalID=emp, Date=d) 
     226            schema['Visit'].insert(VetID=jm, AnimalID=emp, Date=d) 
    227227         
    228228        # Foods 
    229         dead_fish = db['Food'].insert(Name="Dead Fish", Nutrition=5)['ID'] 
    230         live_fish = db['Food'].insert(Name="Live Fish", Nutrition=10)['ID'] 
    231         bunnies = db['Food'].insert(Name="Live Bunny Wabbit", Nutrition=10)['ID'] 
    232         steak = db['Food'].insert(Name="T-Bone", Nutrition=7)['ID'] 
     229        dead_fish = schema['Food'].insert(Name="Dead Fish", Nutrition=5)['ID'] 
     230        live_fish = schema['Food'].insert(Name="Live Fish", Nutrition=10)['ID'] 
     231        bunnies = schema['Food'].insert(Name="Live Bunny Wabbit", Nutrition=10)['ID'] 
     232        steak = schema['Food'].insert(Name="T-Bone", Nutrition=7)['ID'] 
    233233         
    234234        # Foods --> add preferred and alternate foods 
    235         db['Animal'].save(ID=lion, 
     235        schema['Animal'].save(ID=lion, 
    236236                PreferredFoodID=steak, AlternateFoodID=bunnies) 
    237         db['Animal'].save(ID=tiger, 
     237        schema['Animal'].save(ID=tiger, 
    238238                PreferredFoodID=bunnies, AlternateFoodID=steak) 
    239         db['Animal'].save(ID=emp, 
     239        schema['Animal'].save(ID=emp, 
    240240                PreferredFoodID=live_fish, AlternateFoodID=dead_fish) 
    241         db['Animal'].save(ID=adelie, 
     241        schema['Animal'].save(ID=adelie, 
    242242                PreferredFoodID=live_fish, AlternateFoodID=dead_fish) 
    243243     
    244244    def test_3_Properties(self): 
    245245        # Zoos 
    246         WAP = db['Zoo'].select_one(Name='Wild Animal Park') 
     246        WAP = schema['Zoo'].select_one(Name='Wild Animal Park') 
    247247        self.assertEqual(WAP['Founded'], datetime.date(2000, 1, 1)) 
    248248        self.assertEqual(WAP['Opens'], datetime.time(8, 15, 59)) 
     
    252252            self.assertEqual(WAP['Admission'], 4.95) 
    253253         
    254         SDZ = db['Zoo'].select_one(Founded=datetime.date(1835, 9, 13)) 
     254        SDZ = schema['Zoo'].select_one(Founded=datetime.date(1835, 9, 13)) 
    255255        self.assertEqual(SDZ['Founded'], datetime.date(1835, 9, 13)) 
    256256        self.assertEqual(SDZ['Opens'], datetime.time(9, 0, 0)) 
     
    258258        self.assertEqual(float(SDZ['Admission']), 0) 
    259259         
    260         Biodome = db['Zoo'].select_one(Name=u'Montr\xe9al Biod\xf4me') 
     260        Biodome = schema['Zoo'].select_one(Name=u'Montr\xe9al Biod\xf4me') 
    261261        self.assertEqual(Biodome['Name'], u'Montr\xe9al Biod\xf4me') 
    262262        self.assertEqual(Biodome['Founded'], datetime.date(1992, 6, 19)) 
     
    266266         
    267267        if typerefs.fixedpoint: 
    268             seaworld = db['Zoo'].select_one(lambda z: z.Admission == 
     268            seaworld = schema['Zoo'].select_one(lambda z: z.Admission == 
    269269                                            typerefs.fixedpoint.FixedPoint(60)) 
    270270        else: 
    271             seaworld = db['Zoo'].select_one(lambda z: z.Admission == float(60)) 
     271            seaworld = schema['Zoo'].select_one(lambda z: z.Admission == float(60)) 
    272272        self.assertEqual(seaworld['Name'], u'Sea_World') 
    273273         
    274274        # Animals 
    275         leopard = db['Animal'].select_one(lambda a: a.Species == 'Leopard') 
     275        leopard = schema['Animal'].select_one(lambda a: a.Species == 'Leopard') 
    276276        self.assertEqual(leopard['Species'], 'Leopard') 
    277277        self.assertEqual(leopard['Legs'], 4) 
     
    280280        self.assertEqual(leopard['PreviousZoos'], None) 
    281281         
    282         ostrich = db['Animal'].select_one(Species='Ostrich') 
     282        ostrich = schema['Animal'].select_one(Species='Ostrich') 
    283283        self.assertEqual(ostrich['Species'], 'Ostrich') 
    284284        self.assertEqual(ostrich['Legs'], 2) 
     
    287287        self.assertEqual(ostrich['LastEscape'], None) 
    288288         
    289         millipede = db['Animal'].select_one(Legs=1000000) 
     289        millipede = schema['Animal'].select_one(Legs=1000000) 
    290290        self.assertEqual(millipede['Species'], 'Millipede') 
    291291        self.assertEqual(millipede['Legs'], 1000000) 
     
    296296        # Test that strings in a list get decoded correctly. 
    297297        # See http://projects.amor.org/dejavu/ticket/50 
    298         tiger = db['Animal'].select_one(Species='Tiger') 
     298        tiger = schema['Animal'].select_one(Species='Tiger') 
    299299        self.assertEqual(tiger['PreviousZoos'], ["animal\\universe"]) 
    300300         
    301301        # Test our 8000-byte limit. 
    302302        # len(pickle.dumps(["f" * (8000 - 14)]) == 8000 
    303         slug = db['Animal'].select_one(Species='Slug') 
     303        slug = schema['Animal'].select_one(Species='Slug') 
    304304        self.assertEqual(len(slug['PreviousZoos'][0]), 8000 - 14) 
    305305         
    306306        # Exhibits 
    307         exes = list(db['Exhibit'].select_all()) 
     307        exes = list(schema['Exhibit'].select_all()) 
    308308        self.assertEqual(len(exes), 2) 
    309309        if exes[0]['Name'] == 'The Penguin Encounter': 
     
    326326    def test_4_Expressions(self): 
    327327        def matches(lam, tkey='Animal'): 
    328             return len(list(db[tkey].select_all(lam))) 
     328            return len(list(schema[tkey].select_all(lam))) 
    329329         
    330330        self.assertEqual(matches(None, 'Zoo'), 4) 
     
    454454    def test_6_Editing(self): 
    455455        # Edit 
    456         SDZ = db['Zoo'].select_one(Name='San Diego Zoo') 
    457         db['Zoo'].save(ID=SDZ['ID'], Name='The San Diego Zoo', 
     456        SDZ = schema['Zoo'].select_one(Name='San Diego Zoo') 
     457        schema['Zoo'].save(ID=SDZ['ID'], Name='The San Diego Zoo', 
    458458                       Founded = datetime.date(1900, 1, 1), 
    459459                       Opens = datetime.time(7, 30, 0), 
     
    461461         
    462462        # Test edits 
    463         SDZ = db['Zoo'].select_one(Name='The San Diego Zoo') 
     463        SDZ = schema['Zoo'].select_one(Name='The San Diego Zoo') 
    464464        self.assertEqual(SDZ['Name'], 'The San Diego Zoo') 
    465465        self.assertEqual(SDZ['Founded'], datetime.date(1900, 1, 1)) 
     
    471471         
    472472        # Change it back 
    473         db['Zoo'].save(ID=SDZ['ID'], Name = 'San Diego Zoo', 
     473        schema['Zoo'].save(ID=SDZ['ID'], Name = 'San Diego Zoo', 
    474474                       Founded = datetime.date(1835, 9, 13), 
    475475                       Opens = datetime.time(9, 0, 0), 
     
    477477         
    478478        # Test re-edits 
    479         SDZ = db['Zoo'].select_one(Name='San Diego Zoo') 
     479        SDZ = schema['Zoo'].select_one(Name='San Diego Zoo') 
    480480        self.assertEqual(SDZ['Name'], 'San Diego Zoo') 
    481481        self.assertEqual(SDZ['Founded'], datetime.date(1835, 9, 13)) 
     
    619619     
    620620    def test_9_delete(self): 
    621         ostrich = db['Animal'].select_one(Species='Ostrich') 
     621        ostrich = schema['Animal'].select_one(Species='Ostrich') 
    622622        self.assert_(ostrich is not None) 
    623623         
    624         db['Animal'].delete(**ostrich) 
    625          
    626         ostrich = db['Animal'].select_one(Species='Ostrich') 
     624        schema['Animal'].delete(**ostrich) 
     625         
     626        ostrich = schema['Animal'].select_one(Species='Ostrich') 
    627627        self.assertEqual(ostrich, None) 
    628628         
    629629        # Re-create the ostrich and try deleting it with a non-ID kwarg. 
    630         db['Animal'].insert(Species='Ostrich', Legs=2, PreviousZoos=[], 
     630        schema['Animal'].insert(Species='Ostrich', Legs=2, PreviousZoos=[], 
    631631                            Lifespan=103.2) 
    632         ostrich = db['Animal'].select_one(Species='Ostrich') 
     632        ostrich = schema['Animal'].select_one(Species='Ostrich') 
    633633        self.assert_(ostrich is not None) 
    634634         
    635         db['Animal'].delete_all(Species='Ostrich') 
    636          
    637         ostrich = db['Animal'].select_one(Species='Ostrich') 
     635        schema['Animal'].delete_all(Species='Ostrich') 
     636         
     637        ostrich = schema['Animal'].select_one(Species='Ostrich') 
    638638        self.assertEqual(ostrich, None) 
    639639     
     
    673673     
    674674    def setUp(self): 
    675         s = arena.stores.values()[0] 
    676         if hasattr(s, "db"): 
    677             self.db = s.db 
    678         else: 
    679             self.db = None 
    680          
    681675        try: 
    682             self.old_implicit = s.db.implicit_trans 
    683             s.db.implicit_trans = False 
    684             self.old_tkey = s.db.transaction_key 
     676            self.old_implicit = db.connections.implicit_trans 
     677            db.connections.implicit_trans = False 
     678            self.old_tkey = db.connections.id 
    685679            # Use an explicit 'boxid' for the transaction key 
    686             s.db.transaction_key = lambda: self.boxid 
     680            db.connections.id = lambda: self.boxid 
    687681        except AttributeError: 
    688682            self.old_implicit = None 
     
    690684     
    691685    def tearDown(self): 
    692         if self.db and self.old_implicit is not None: 
    693             self.db.implicit_trans = self.old_implicit 
    694             self.db.transaction_key = self.old_tkey 
     686        if self.old_implicit is not None: 
     687            db.connections.implicit_trans = self.old_implicit 
     688            db.connections.id = self.old_tkey 
    695689     
    696690    def restore(self): 
     
    734728                              (level, anomaly_name)) 
    735729        except: 
    736             if self.db.is_lock_error(sys.exc_info()[1]): 
     730            if db.is_timeout_error(sys.exc_info()[1]): 
    737731                self.cleanup_boxes() 
    738732                if not level.forbids(anomaly_name): 
     
    774768                print 
    775769                print level, 
    776             if level.name in self.db.isolation_levels: 
     770            if level.name in db.connections.isolation_levels: 
    777771                self.attempt(dirty_read, "Dirty Read", level) 
    778772     
     
    802796                print 
    803797                print level, 
    804             if level.name in self.db.isolation_levels: 
     798            if level.name in db.connections.isolation_levels: 
    805799                self.attempt(nonrepeatable_read, "Nonrepeatable Read", level) 
    806800     
     
    828822                print 
    829823                print level, 
    830             if level.name in self.db.isolation_levels: 
     824            if level.name in db.connections.isolation_levels: 
    831825                self.attempt(phantom, "Phantom", level) 
    832826 
     
    869863            return 
    870864         
    871         old_implicit = zoostore.db.implicit_trans 
     865        old_implicit = db.connections.implicit_trans 
    872866        try: 
    873867            def commit_test(): 
     
    899893                    box.flush_all() 
    900894             
    901             zoostore.db.implicit_trans = True 
     895            db.connections.implicit_trans = True 
    902896            commit_test() 
    903897            if zoostore.rollback: 
    904898                rollback_test() 
    905899             
    906             zoostore.db.implicit_trans = False 
     900            db.connections.implicit_trans = False 
    907901            zoostore.start() 
    908902            commit_test() 
     
    911905                rollback_test() 
    912906        finally: 
    913             zoostore.db.implicit_trans = old_implicit 
     907            db.connections.implicit_trans = old_implicit 
    914908     
    915909    def test_ContextManagement(self): 
     
    10871081 
    10881082db = None 
     1083schema = None 
    10891084 
    10901085def _geniusqllog(message): 
     
    10971092    f.close() 
    10981093 
    1099 def setup(DB_class, name, opts): 
     1094def setup(provider, name, opts): 
    11001095    """Set up storage for Zoo classes.""" 
    1101     global db 
    1102     db = geniusql.db(DB_class, name, opts) 
     1096    global db, schema 
     1097    db = geniusql.db(provider, **opts) 
     1098    schema = db.schema(name) 
    11031099    db.log = _geniusqllog 
    11041100    print db.version() 
    1105     db.create_database() 
     1101    schema.create_database() 
    11061102 
    11071103def teardown(): 
    11081104    """Tear down storage for Zoo classes.""" 
    11091105    try: 
    1110         db.drop_database() 
     1106        schema.drop_database() 
    11111107    except (AttributeError, NotImplementedError): 
    11121108        pass