Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

Changeset 123

Show
Ignore:
Timestamp:
12/11/05 21:26:16
Author:
fumanchu
Message:

Fix for #36 (column names not correctly escaped). Changes to StorageManagerDB (and its subclasses):

  1. identifier(*atoms) method changed to sql_name(name, quoted=True).
  2. New column_name(classname, name, full=False, quoted=True) method.
  3. SQLDecompiler now calls the StorageManager's column_name method (so decompilers now take the SM as a constructor arg).
  4. identifier_length is now sql_name_max_length.
  5. identifier_caseless is now sql_name_caseless.
Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/doc/advanced.html

    r122 r123  
    298298inventory, invoice, and scheduling software.</p> 
    299299 
     300<p>One of the more important parts of wrapping existing tables is getting 
     301your pretty Python names mapped to ugly database names. Do this by making 
     302a custom StorageManager: override the <tt>column_name</tt> and 
     303<tt>table_name</tt> methods to do the mapping.</p> 
     304 
    300305<h4>Other Serialization Mechanisms</h4> 
    301306<h5>sockets</h5> 
  • trunk/storage/db.py

    r122 r123  
    395395 
    396396class TableRef: 
    397     def __init__(self, tablename): 
    398         self.tablename = tablename 
     397    def __init__(self, classname): 
     398        self.classname = classname 
    399399 
    400400# Stack sentinels 
     
    416416 
    417417class SQLDecompiler(codewalk.LambdaDecompiler): 
    418     """SQLDecompiler(tablename, expr, adapter=AdapterToSQL()). 
     418    """SQLDecompiler(classnames, expr, sm, adapter=AdapterToSQL()). 
    419419     
    420420    Produce SQL from a supplied Expression object, with a lambda of the form: 
    421421        lambda x, **kw: ... 
    422422     
    423     Attributes of x (or whatever the name of the first argument is) will b
    424     mapped to table columns. Keyword arguments should be bound using 
    425     Expression.bind_args before calling this decompiler. 
     423    Attributes of each argument in the signature will be mapped to tabl
     424    columns. Keyword arguments should be bound using Expression.bind_args 
     425    before calling this decompiler. 
    426426    """ 
    427427     
     
    435435    sql_cmp_op = ('<', '<=', '=', '!=', '>', '>=', 'in', 'not in') 
    436436     
    437     def __init__(self, tablenames, expr, adapter=AdapterToSQL()): 
    438         self.tablenames = tablenames 
     437    def __init__(self, classnames, expr, sm, adapter=AdapterToSQL()): 
     438        self.classnames = classnames 
    439439        self.expr = expr 
    440440        self.adapter = adapter 
     441        self.sm = sm 
    441442        obj = expr.func 
    442443        codewalk.LambdaDecompiler.__init__(self, obj) 
     
    513514        arg_index = lo + (hi << 8) 
    514515        if arg_index < self.co_argcount: 
    515             self.stack.append(TableRef(self.tablenames[arg_index])) 
     516            self.stack.append(TableRef(self.classnames[arg_index])) 
    516517        else: 
    517518            self.stack.append(kw_arg) 
     
    521522        tos = self.stack.pop() 
    522523        if isinstance(tos, TableRef): 
    523             # Call another function to make subclassing easier. 
    524             self.stack.append(self.column_name(tos.tablename, name)) 
     524            atom = self.sm.column_name(tos.classname, name, full=True) 
    525525        else: 
    526526            # tos.name will reference an attribute of the tos object. 
    527527            # Stick the tos and name in a tuple for later processing. 
    528             self.stack.append((tos, name)) 
     528            atom = (tos, name) 
     529        self.stack.append(atom) 
    529530     
    530531    def visit_LOAD_CONST(self, lo, hi): 
     
    634635            self.stack.append("NOT (" + op + ")") 
    635636     
    636     def column_name(self, tablename, name): 
    637         # This is valid SQL for PostgreSQL only and should be overridden. 
    638         # If you want to use a map from UnitProperty names to legacy DB 
    639         # column names, override this method. You will probably also 
    640         # want to override StorageManager.identifier and perform the 
    641         # same map lookup there. 
    642         return '%s."%s"' % (tablename, name) 
    643      
    644637    # --------------------------- Dispatchees --------------------------- # 
    645638     
     
    706699    """StoreManager to save and retrieve Units using a DB.""" 
    707700     
    708     identifier_length = 64 
    709     identifier_caseless = False 
     701    sql_name_max_length = 64 
     702    sql_name_caseless = False 
    710703    close_connection_method = 'close' 
    711704     
     
    757750        self.expanded_columns = ec 
    758751     
    759     def identifier(self, *atoms): 
    760         ident = ''.join(map(str, atoms)).replace('"', '""') 
    761         if self.identifier_caseless: 
    762             ident = ident.lower() 
    763         idlen = self.identifier_length 
    764         if idlen and len(ident) > idlen: 
    765             warnings.warn(("Identifier is longer than %s characters." 
    766                            % idlen), dejavu.StorageWarning) 
    767             ident = ident[:idlen] 
    768         return '"' + ident + '"' 
    769      
    770     def tablename(self, cls): 
    771         if isinstance(cls, type): 
    772             name = cls.__name__ 
    773         elif isinstance(cls, dejavu.Unit): 
    774             name = cls.__class__.__name__ 
    775         elif isinstance(cls, basestring): 
    776             name = cls 
    777         else: 
    778             raise TypeError("Cannot form tablenames from %s" % cls) 
    779         return self.identifier(self.prefix, name) 
     752    #                               Naming                               # 
     753     
     754    def sql_name(self, name, quoted=True): 
     755        """The name, escaped for SQL.""" 
     756        if self.sql_name_caseless: 
     757            name = name.lower() 
     758         
     759        maxlen = self.sql_name_max_length 
     760        if maxlen and len(name) > maxlen: 
     761            warnings.warn("The name '%s' is longer than the maximum of " 
     762                          "%s characters." % (name, maxlen), 
     763                          dejavu.StorageWarning) 
     764            name = name[:maxlen] 
     765         
     766        # This base class doesn't use the "quoted" arg, 
     767        # but most subclasses will. 
     768        return name 
     769     
     770    def column_name(self, classname, name, full=False, quoted=True): 
     771        """The column name, escaped for SQL. If full, include tablename.""" 
     772        # If you want to use a map from UnitProperty names 
     773        # to DB column names, override this method. 
     774        name = self.sql_name(name, quoted=quoted) 
     775        if full: 
     776            return '%s.%s' % (self.table_name(classname, quoted=quoted), name) 
     777        else: 
     778            return name 
     779     
     780    def table_name(self, name, quoted=True): 
     781        """The table name, escaped for SQL.""" 
     782        # If you want to use a map from Unit class names 
     783        # to DB table names, override this method. 
     784        return self.sql_name(self.prefix + name, quoted=quoted) 
     785     
     786    #                             Connecting                              # 
    780787     
    781788    def _get_conn(self): 
    782         # Override this with the connection call for your DB. Example follows
    783 ##        try: 
    784 ##            conn = libpq.PQconnectdb(self.connstring) 
    785 ##        except Exception, x: 
    786 ##            if self.CreateIfMissing: 
    787 ##                self.create_database() 
    788 ##                conn = libpq.PQconnectdb(self.connstring) 
    789 ##            else: 
    790 ##                raise 
    791 ##        return conn 
     789        # Override this with the connection call for your DB. Example
     790        # try: 
     791        #     conn = libpq.PQconnectdb(self.connstring) 
     792        # except Exception, x: 
     793        #     if self.CreateIfMissing: 
     794        #         self.create_database() 
     795        #         conn = libpq.PQconnectdb(self.connstring) 
     796        #     else: 
     797        #         raise 
     798        # return conn 
    792799        raise NotImplementedError 
    793800     
     
    861868     
    862869    def create_database(self): 
    863         self.execute("CREATE DATABASE %s;" % self.identifier(self.dbname)) 
     870        self.execute("CREATE DATABASE %s;" % self.sql_name(self.dbname)) 
    864871     
    865872    def drop_database(self): 
    866         self.execute("DROP DATABASE %s;" % self.identifier(self.dbname)) 
     873        self.execute("DROP DATABASE %s;" % self.sql_name(self.dbname)) 
    867874     
    868875    def create_storage(self, unitClass): 
    869         tablename = self.tablename(unitClass) 
    870          
    871         coerce = self.typeAdapter.coerce 
     876        clsname = unitClass.__name__ 
     877        tablename = self.table_name(clsname) 
     878        typename = self.typeAdapter.coerce 
     879         
    872880        fields = [] 
    873881        for key in unitClass.properties(): 
    874             fields.append(u'%s %s' % (self.identifier(key), 
    875                                       coerce(unitClass, key))) 
     882            fields.append(u'%s %s' % (self.column_name(clsname, key), 
     883                                      typename(unitClass, key))) 
    876884        self.execute(u'CREATE TABLE %s (%s);' % (tablename, ", ".join(fields))) 
     885         
    877886        for index in unitClass.indices(): 
    878             i = self.identifier(self.prefix, "i", unitClass.__name__, index) 
     887            i = self.table_name("i" + clsname + index) 
    879888            self.execute(u'CREATE INDEX %s ON %s (%s);' % 
    880                          (i, tablename, self.identifier(index))) 
     889                         (i, tablename, self.column_name(clsname, index))) 
    881890     
    882891    def select(self, unitClass, expr, fields=None, distinct=False): 
    883         tablename = self.tablename(unitClass) 
     892        clsname = unitClass.__name__ 
     893        tablename = self.table_name(clsname) 
    884894        if fields: 
    885             fields = [self.identifier(x) for x in fields] 
     895            fields = [self.column_name(clsname, x) for x in fields] 
    886896            if distinct: 
    887897                sql = u'SELECT DISTINCT %s FROM %s' 
     
    891901        else: 
    892902            sql = u'SELECT * FROM %s' % tablename 
    893         w, i = self.where((self.tablename(unitClass),), expr) 
     903         
     904        w, i = self.where((clsname,), expr) 
    894905        if len(w) > 0: 
    895906            w = u" WHERE " + w 
     
    899910        return sql, i 
    900911     
    901     def where(self, tablenames, expr): 
    902         decom = self.decompiler(tablenames, expr, self.toAdapter) 
     912    def where(self, classnames, expr): 
     913        decom = self.decompiler(classnames, expr, self, self.toAdapter) 
    903914        return decom.code(), decom.imperfect 
    904915     
     
    913924        """fetch(query, conn=None) -> rowdata, columns. 
    914925         
    915         This base class uses SQLite3 syntax.""" 
     926        rowdata will be an iterable of iterables containing the result values. 
     927        columns will be an iterable of (column name, data type) pairs. 
     928         
     929        This base class uses SQLite3 syntax. 
     930        """ 
    916931        res = self.execute(query, conn) 
    917932        return res.row_list, res.col_defs 
    918933     
    919934    def recall(self, cls, expr=None): 
     935        clsname = cls.__name__ 
     936         
    920937        if expr is None: 
    921938            expr = logic.Expression(lambda x: True) 
     
    926943                        in enumerate(col_defs)]) 
    927944         
    928         # Get specs on properties. Get the identifier properties 
     945        # Get specs on properties. Put the identifier properties 
    929946        # first, in case other fields depend upon them. 
    930947        # See load_expanded, for example. 
     
    932949        idnames = [prop.key for prop in cls.identifiers] 
    933950        for key in idnames + [x for x in cls.properties() if x not in idnames]: 
    934             if self.identifier_caseless: 
    935                 index, ftype = columns[key.lower()] 
    936             else: 
    937                 index, ftype = columns[key] 
    938             subtype = self.expanded_columns.get((cls.__name__, key)) 
     951            index, ftype = columns[self.column_name(clsname, key, quoted=False)] 
     952            subtype = self.expanded_columns.get((clsname, key)) 
    939953            props.append((key, index, ftype, subtype)) 
    940954         
     
    970984        """ 
    971985        cls = unit.__class__ 
    972         tablename = self.tablename(unit) 
    973         i = self.identifier 
     986        clsname = cls.__name__ 
     987        tablename = self.table_name(clsname) 
    974988        self.reserve_lock.acquire() 
    975989        try: 
    976990            if not unit.sequencer.valid_id(unit.identity()): 
    977991                # Examine all existing IDs and grant the "next" one. 
    978                 id_fields = [i(prop.key) for prop in cls.identifiers] 
     992                id_fields = [self.column_name(clsname, prop.key) 
     993                             for prop in cls.identifiers] 
    979994                data, cols = self.fetch(u'SELECT %s FROM %s;' % 
    980995                                        (', '.join(id_fields), tablename)) 
     
    10001015            values = [] 
    10011016            for key in cls.properties(): 
    1002                 subtype = self.expanded_columns.get((cls.__name__, key)) 
     1017                subtype = self.expanded_columns.get((clsname, key)) 
    10031018                if subtype: 
    10041019                    self.save_expanded(unit, key, subtype) 
    10051020                else: 
    10061021                    val = self.toAdapter.coerce(getattr(unit, key)) 
    1007                     fields.append(i(key)) 
     1022                    fields.append(self.column_name(clsname, key)) 
    10081023                    values.append(val) 
    10091024             
     
    10171032     
    10181033    def id_clause(self, unit): 
    1019         i = self.identifier 
     1034        clsname = unit.__class__.__name__ 
     1035        col = self.column_name 
    10201036        c = self.toAdapter.coerce 
    10211037        idnames = [prop.key for prop in unit.identifiers] 
    1022         return " AND ".join(["%s = %s" % (i(key), c(getattr(unit, key))) 
     1038        return " AND ".join(["%s = %s" % (col(clsname, key), 
     1039                                          c(getattr(unit, key))) 
    10231040                             for key in idnames]) 
    10241041     
     
    10271044        if unit.dirty() or forceSave: 
    10281045            cls = unit.__class__ 
     1046            clsname = cls.__name__ 
    10291047             
    10301048            parms = [] 
     
    10321050            for key in cls.properties(): 
    10331051                if key not in idnames: 
    1034                     subtype = self.expanded_columns.get((cls.__name__, key)) 
     1052                    subtype = self.expanded_columns.get((clsname, key)) 
    10351053                    if subtype: 
    10361054                        self.save_expanded(unit, key, subtype) 
    10371055                    else: 
    10381056                        val = self.toAdapter.coerce(getattr(unit, key)) 
    1039                         parms.append('%s = %s' % (self.identifier(key), val)) 
     1057                        parms.append('%s = %s' % 
     1058                                     (self.column_name(clsname, key), val)) 
    10401059             
    10411060            sql = ('UPDATE %s SET %s WHERE %s;' % 
    1042                    (self.tablename(unit), u", ".join(parms), 
     1061                   (self.table_name(clsname), u", ".join(parms), 
    10431062                    self.id_clause(unit))) 
    10441063            self.execute(sql) 
     
    10491068        unitcls = unit.__class__ 
    10501069        id = "_".join(map(str, unit.identity())) 
    1051         table = self.identifier(self.prefix, "_", unitcls.__name__, 
    1052                                 "_", id, "_", key) 
     1070        table = self.table_name("_%s_%s_%s" % (unitcls.__name__, id, key)) 
    10531071         
    10541072        # Just drop the old table and start with a new one. 
     
    10731091    def load_expanded(self, unit, key, subtype): 
    10741092        """load_expanded(unit, key, subtype). Load list from separate table.""" 
     1093        unitcls = unit.__class__ 
    10751094        id = "_".join(map(str, unit.identity())) 
    1076         table = self.identifier(self.prefix, "_", unit.__class__.__name__, 
    1077                                 "_", id, "_", key) 
     1095        table = self.table_name("_%s_%s_%s" % (unitcls.__name__, id, key)) 
    10781096        try: 
    10791097            data, col_defs = self.fetch(u"SELECT EXPVAL FROM %s" % table) 
     
    10851103            values = [coercer(row[0], coltype) for row in data] 
    10861104             
    1087             expected_type = unit.__class__.property_type(key) 
     1105            expected_type = unitcls.property_type(key) 
    10881106            values = expected_type(values) 
    10891107         
     
    10941112        """destroy(unit). Delete the unit.""" 
    10951113        self.execute(u'DELETE * FROM %s WHERE %s;' % 
    1096                      (self.tablename(unit), self.id_clause(unit))) 
     1114                     (self.table_name(unit), self.id_clause(unit))) 
    10971115     
    10981116    def view(self, cls, fields, expr=None): 
     
    11601178     
    11611179    def join(self, unitjoin): 
    1162         t = self.tablename 
    1163         i = self.identifier 
    1164          
    11651180        cls1, cls2 = unitjoin.class1, unitjoin.class2 
    11661181        if isinstance(cls1, dejavu.UnitJoin): 
     
    11691184        else: 
    11701185            # cls1 is a Unit class. 
    1171             name1 = t(cls1
     1186            name1 = self.table_name(cls1.__name__
    11721187            classlist1 = [cls1] 
    11731188         
     
    11771192        else: 
    11781193            # cls2 is a Unit class. 
    1179             name2 = t(cls2
     1194            name2 = self.table_name(cls2.__name__
    11801195            classlist2 = [cls2] 
    11811196         
     
    11991214        else: 
    12001215            j = "RIGHT" 
    1201         return ("(%s %s JOIN %s ON %s.%s = %s.%s)" % 
    1202                 (name1, j, name2, 
    1203                  t(ua.nearClass), i(ua.nearKey), 
    1204                  t(ua.farClass), i(ua.farKey))
     1216         
     1217        near = self.column_name(ua.nearClass.__name__, ua.nearKey, full=True) 
     1218        far = self.column_name(ua.farClass.__name__, ua.farKey, full=True) 
     1219        return "(%s %s JOIN %s ON %s = %s)" % (name1, j, name2, near, far
    12051220     
    12061221    def multiselect(self, classes, expr): 
    1207         tablenames = [self.tablename(cls) for cls in classes] 
    12081222        if expr is None: 
    12091223            expr = logic.Expression(lambda *args: True) 
    1210         w, imp = self.where(tablenames, expr) 
     1224        w, imp = self.where([cls.__name__ for cls in classes], expr) 
    12111225         
    12121226        joins = self.join(classes) 
     
    12261240            columns.extend([(cls, k) for k in keys]) 
    12271241         
    1228         colnames = ["%s.%s" % (self.tablename(cls), self.identifier(key)
     1242        colnames = [self.column_name(cls.__name__, key, full=True
    12291243                    for cls, key in columns] 
    12301244        statement = ("SELECT %s FROM %s WHERE %s" % 
  • trunk/storage/storeado.py

    r122 r123  
    250250            self.stack.append(op1 + " " + self.sql_cmp_op[op] + " " + op2) 
    251251     
    252     def column_name(self, tablename, name): 
    253         return '%s.[%s]' % (tablename, name) 
    254      
    255252    # --------------------------- Dispatchees --------------------------- # 
    256253     
     
    325322        return atoms 
    326323     
    327     def identifier(self, *atoms): 
    328         ident = ''.join(map(str, atoms)) 
    329         return '[' + ident + ']' 
     324    def sql_name(self, name, quoted=True): 
     325        if quoted: 
     326            name = '[' + name + ']' 
     327        return name 
    330328     
    331329    def _get_conn(self): 
     
    484482        atoms['INITIAL CATALOG'] = "tempdb" 
    485483        adoconn.Open("; ".join(["%s=%s" % (k, v) for k, v in atoms.iteritems()])) 
    486         adoconn.Execute("CREATE DATABASE %s" % self.identifier(self.dbname)) 
     484        adoconn.Execute("CREATE DATABASE %s" % self.sql_name(self.dbname)) 
    487485        adoconn.Close() 
    488486     
     
    492490        atoms['INITIAL CATALOG'] = "tempdb" 
    493491        adoconn.Open("; ".join(["%s=%s" % (k, v) for k, v in atoms.iteritems()])) 
    494         adoconn.Execute("DROP DATABASE %s;" % self.identifier(self.dbname)) 
     492        adoconn.Execute("DROP DATABASE %s;" % self.sql_name(self.dbname)) 
    495493        adoconn.Close() 
    496494 
  • trunk/storage/storemysql.py

    r122 r123  
    5252 
    5353class MySQLDecompiler(db.SQLDecompiler): 
    54      
    55     def column_name(self, tablename, name): 
    56         # MySQL forces lowercase column names. 
    57         return '%s.`%s`' % (tablename, name.lower()) 
    58      
    59     # --------------------------- Dispatchees --------------------------- # 
    6054     
    6155    def dejavu_today(self): 
     
    122116    """StoreManager to save and retrieve Units via _mysql.""" 
    123117     
    124     identifier_length = 64 
    125     identifier_caseless = True 
     118    sql_name_max_length = 64 
     119    # MySQL uses case-sensitive database and table names on Unix, but 
     120    # not on Windows. Use all-lowercase identifiers to work around the 
     121    # problem. "Column names, index names, and column aliases are not 
     122    # case sensitive on any platform." 
     123    # If deployers set lower_case_table_names to 1, it would help. 
     124    sql_name_caseless = True 
     125     
    126126    typeAdapter = FieldTypeAdapterMySQL() 
    127127    toAdapter = AdapterToMySQL() 
     
    149149                self.decompiler = MySQLDecompiler411 
    150150     
    151     def identifier(self, *atoms): 
    152         # MySQL uses case-sensitive database and table names on Unix, but 
    153         # not on Windows. Use all-lowercase identifiers to work around the 
    154         # problem. "Column names, index names, and column aliases are not 
    155         # case sensitive on any platform." 
    156         # If deployers set lower_case_table_names to 1, it would help. 
    157         ident = ''.join(map(str, atoms)).replace('`', '``').lower() 
    158         idlen = self.identifier_length 
    159         if idlen and len(ident) > idlen: 
    160             warnings.warn(("Identifier is longer than %s characters." 
    161                            % idlen), dejavu.StorageWarning) 
    162             ident = ident[:idlen] 
    163         return '`' + ident + '`' 
     151    def sql_name(self, name, quoted=True): 
     152        name = db.StorageManagerDB.sql_name(self, name, quoted) 
     153        if quoted: 
     154            name = '`' + name.replace('`', '``') + '`' 
     155        return name 
    164156     
    165157    def _get_conn(self): 
     
    183175    def create_database(self): 
    184176        # _mysql has create_db and drop_db commands, but they're deprecated. 
    185         sql = 'CREATE DATABASE %s;' % self.identifier(self.dbname) 
     177        sql = 'CREATE DATABASE %s;' % self.sql_name(self.dbname) 
    186178        conn = self._template_conn() 
    187179        self.execute(sql, conn) 
     
    189181     
    190182    def drop_database(self): 
    191         sql = 'DROP DATABASE %s;' % self.identifier(self.dbname) 
     183        sql = 'DROP DATABASE %s;' % self.sql_name(self.dbname) 
    192184        conn = self._template_conn() 
    193185        self.execute(sql, conn) 
     
    201193     
    202194    def create_storage(self, unitClass): 
    203         # MySQL won't allow indexes on a BLOB field without a specific length. 
    204         tablename = self.tablename(unitClass
     195        clsname = unitClass.__name__ 
     196        tablename = self.table_name(clsname
    205197         
    206198        coerce = self.typeAdapter.coerce 
    207199        fields = [] 
    208200        for key in unitClass.properties(): 
    209             fields.append(u'%s %s' % (self.identifier(key), 
     201            fields.append(u'%s %s' % (self.column_name(clsname, key), 
    210202                                      coerce(unitClass, key))) 
    211203        self.execute(u'CREATE TABLE %s (%s);' % (tablename, ", ".join(fields))) 
    212204         
    213205        for index in unitClass.indices(): 
    214             i = self.identifier(self.prefix, "i", unitClass.__name__, index) 
     206            i = self.table_name("i" + clsname + index) 
    215207             
    216208            dbtype = coerce(unitClass, index) 
    217209            if dbtype.endswith('BLOB') or dbtype == 'TEXT': 
     210                # MySQL won't allow indexes on a BLOB field 
     211                # without a specific length. 
    218212                self.execute(u'CREATE INDEX %s ON %s (%s(%s));' % 
    219                              (i, tablename, self.identifier(index), 255)) 
     213                             (i, tablename, 
     214                              self.column_name(clsname, index), 255)) 
    220215            else: 
    221216                self.execute(u'CREATE INDEX %s ON %s (%s);' % 
    222                              (i, tablename, self.identifier(index))) 
     217                             (i, tablename, 
     218                              self.column_name(clsname, index))) 
    223219     
    224220    def fetch(self, query, conn=None): 
     
    240236        """destroy(unit). Delete the unit.""" 
    241237        self.execute(u'DELETE FROM %s WHERE %s;' % 
    242                      (self.tablename(unit), self.id_clause(unit))) 
    243  
     238                     (self.table_name(unit.__class__.__name__), 
     239                      self.id_clause(unit))) 
     240 
  • trunk/storage/storepypgsql.py

    r122 r123  
    4141    """StoreManager to save and retrieve Units via pyPgSQL 1.35.""" 
    4242     
    43     identifier_length = 63 
     43    sql_name_max_length = 63 
    4444    close_connection_method = 'finish' 
    4545    decompiler = PgSQLDecompiler 
     
    5555            k, v = atom.split("=", 1) 
    5656            setattr(self, k, v) 
     57     
     58    def sql_name(self, name, quoted=True): 
     59        name = db.StorageManagerDB.sql_name(self, name, quoted) 
     60        if quoted: 
     61            name = '"' + name.replace('"', '""') + '"' 
     62        return name 
    5763     
    5864    def _get_conn(self): 
     
    8086    def create_database(self): 
    8187        c = self._template_conn() 
    82         self.execute('CREATE DATABASE %s' % self.identifier(self.dbname), c) 
     88        self.execute('CREATE DATABASE %s' % self.sql_name(self.dbname), c) 
    8389        c.finish() 
    8490     
    8591    def drop_database(self): 
    8692        c = self._template_conn() 
    87         self.execute("DROP DATABASE %s;" % self.identifier(self.dbname), c) 
     93        self.execute("DROP DATABASE %s;" % self.sql_name(self.dbname), c) 
    8894        c.finish() 
    8995     
  • trunk/storage/storesqlite.py

    r122 r123  
    4343class SQLiteDecompiler(db.SQLDecompiler): 
    4444     
    45     def column_name(self, tablename, name): 
    46         return '%s.[%s]' % (tablename, name) 
    47      
    48     # --------------------------- Dispatchees --------------------------- # 
    49      
    5045    def attr_startswith(self, tos, arg): 
    5146        if _escape_support: 
     
    10499    def dejavu_istartswith(self, x, y): 
    105100        if _escape_support: 
    106             return "LOWER(" + x + ") LIKE '" + self.adapter.escape_like(y) + r"%' ESCAPE '\'" 
     101            return ("LOWER(" + x + ") LIKE '" + self.adapter.escape_like(y) 
     102                    + r"%' ESCAPE '\'") 
    107103        else: 
    108104            if "%" in y or "_" in y: 
     
    113109    def dejavu_iendswith(self, x, y): 
    114110        if _escape_support: 
    115             return "LOWER(" + x + ") LIKE '%" + self.adapter.escape_like(y) + r"%' ESCAPE '\'" 
     111            return ("LOWER(" + x + ") LIKE '%" + self.adapter.escape_like(y) 
     112                    + r"%' ESCAPE '\'") 
    116113        else: 
    117114            if "%" in y or "_" in y: 
     
    134131    """StoreManager to save and retrieve Units via _sqlite.""" 
    135132     
    136     identifier_length = 0 
     133    sql_name_max_length = 0 
    137134    decompiler = SQLiteDecompiler 
    138135    toAdapter = AdapterToSQLite() 
     
    145142        self.mode = int(allOptions.get(u'Mode', '0755'), 8) 
    146143     
    147     def identifier(self, *atoms): 
    148         """identifier(*atoms) -> return atoms joined into a legal identifier. 
     144    def sql_name(self, name, quoted=True): 
     145        """sql_name(name, quoted=True) -> return name as a legal SQL identifier. 
    149146         
    150147        From the SQLite docs: 
     
    159156        ...we'll use the third option (square brackets). 
    160157        """ 
    161         return "[" + ''.join(map(str, atoms)) + "]" 
     158        if quoted: 
     159            name = "[" + name + "]" 
     160        return name 
    162161     
    163162    def _get_conn(self): 
     
    191190        """destroy(unit). Delete the unit.""" 
    192191        self.execute(u'DELETE FROM %s WHERE %s = %s;' % 
    193                      (self.tablename(unit), self.id_clause(unit))) 
     192                     (self.table_name(unit.__class__.__name__), 
     193                      self.id_clause(unit))) 
    194194     
    195195    def create_storage(self, unitClass): 
    196         tablename = self.tablename(unitClass) 
     196        clsname = unitClass.__name__ 
     197        tablename = self.table_name(clsname) 
    197198         
    198199        # SQLite is typeless. 
    199         fields = [self.identifier(key) for key in unitClass.properties()] 
     200        fields = [self.column_name(clsname, key) 
     201                  for key in unitClass.properties()] 
    200202         
    201203        self.execute(u'CREATE TABLE %s (%s);' % (tablename, ", ".join(fields))) 
    202204        for index in unitClass.indices(): 
    203             i = self.identifier(self.prefix, "i", unitClass.__name__, index) 
     205            i = self.table_name("i" + clsname + index) 
    204206            self.execute(u'CREATE INDEX %s ON %s (%s);' % 
    205                          (i, tablename, self.identifier(index))) 
     207                         (i, tablename, self.column_name(clsname, index))) 
    206208     
    207209    def join(self, unitjoin): 
    208         t = self.tablename 
    209         i = self.identifier 
    210          
    211210        on_clauses = [] 
    212211         
     
    218217        else: 
    219218            # cls1 is a Unit class. 
    220             name1 = t(cls1
     219            name1 = self.table_name(cls1.__name__
    221220            classlist1 = [cls1] 
    222221         
     
    227226        else: 
    228227            # cls2 is a Unit class. 
    229             name2 = t(cls2
     228            name2 = self.table_name(cls2.__name__
    230229            classlist2 = [cls2] 
    231230         
    232231        # Find an association between the two halves. 
    233232        ua = None 
    234         for cls1 in classlist1: 
    235             for cls2 in classlist2: 
    236                 ua = cls1._associations.get(cls2.__name__, None) 
     233        for clsA in classlist1: 
     234            for clsB in classlist2: 
     235                ua = clsA._associations.get(clsB.__name__, None) 
    237236                if ua: break 
    238                 ua = cls2._associations.get(cls1.__name__, None) 
     237                ua = clsB._associations.get(clsA.__name__, None) 
    239238                if ua: break 
    240239            if ua: break 
     
    250249            # My version (3.0.8) of SQLite says: 
    251250            # "RIGHT and FULL OUTER JOINs are not currently supported". 
    252             # TODO: find out if any versions do support it. 
    253251            j = "%s LEFT JOIN %s" % (name2, name1) 
    254         w = ("%s.%s = %s.%s" % (t(ua.nearClass), i(ua.nearKey), 
    255                                 t(ua.farClass), i(ua.farKey))) 
    256         on_clauses.append(w) 
     252         
     253        near = self.column_name(ua.nearClass.__name__, ua.nearKey, full=True) 
     254        far = self.column_name(ua.farClass.__name__, ua.farKey, full=True) 
     255        on_clauses.append("%s = %s" % (near, far)) 
    257256        return j, on_clauses 
    258257     
    259258    def multiselect(self, classes, expr): 
    260         tablenames = [self.tablename(cls) for cls in classes] 
    261259        if expr is None: 
    262260            expr = logic.Expression(lambda *args: True) 
    263         w, imp = self.where(tablenames, expr) 
     261        w, imp = self.where([cls.__name__ for cls in classes], expr) 
    264262         
    265263        # SQLite doesn't do nested JOINs, but instead applies them 
     
    284282            columns.extend([(cls, k) for k in keys]) 
    285283         
    286         colnames = ["%s.%s" % (self.tablename(cls), self.identifier(key)
     284        colnames = [self.column_name(cls.__name__, key, full=True
    287285                    for cls, key in columns] 
    288286        statement = ("SELECT %s FROM %s WHERE %s" %