Changeset 226
- Timestamp:
- 07/16/06 04:30:04
- Files:
-
- trunk/storage/db.py (modified) (21 diffs)
- trunk/storage/dbmodel.py (added)
- trunk/storage/storeado.py (modified) (16 diffs)
- trunk/storage/storemysql.py (modified) (4 diffs)
- trunk/storage/storepypgsql.py (modified) (8 diffs)
- trunk/storage/storesqlite.py (modified) (10 diffs)
- trunk/test/test_storemsaccess.py (modified) (2 diffs)
- trunk/test/zoo_fixture.py (modified) (4 diffs)
Legend:
- Unmodified
- Added
- Removed
- Modified
- Copied
- Moved
trunk/storage/db.py
r225 r226 67 67 import dejavu 68 68 from dejavu import codewalk, logic, storage, LOGSQL, xray 69 from dbmodel import * 69 70 70 71 … … 116 117 117 118 def coerce(self, cls, key): 118 """ coerce(cls, key) -> SQL typename for valuetype."""119 """Return the SQL datatype name for valuetype.""" 119 120 valuetype = cls.property(key).type 120 121 mod = valuetype.__module__ … … 634 635 if isinstance(tos, TableRef): 635 636 # The name in question refers to a DB column. 636 atom = self.sm.column_name(tos.classname, name, full=True) 637 colname = self.sm.column_name(tos.classname, name) 638 alias = getattr(tos.classname, "alias", None) 639 if alias is None: 640 tname = self.sm.table_name(tos.classname) 641 else: 642 tname = (tos.classname.alias or tos.classname.tablename) 643 atom = '%s.%s' % (self.sm.quote(tname), self.sm.quote(colname)) 637 644 else: 638 645 # tos.name will reference an attribute of the tos object. … … 942 949 943 950 951 944 952 # --------------------------- Storage Manager --------------------------- # 945 953 … … 953 961 954 962 wclsname = wclass.__name__ 955 self.tablename = sm.table _name(wclsname)963 self.tablename = sm.tables[wclsname].name 956 964 self.alias = "" 957 965 958 966 def columns(self): 967 """Return [(wclass, UnitProperty.key), ...], ['"tbl"."col"', ...].""" 959 968 wclass = self.cls 960 969 … … 964 973 if k not in wclass.identifiers] 965 974 cols = [(wclass, k) for k in keys] 966 colnames = ['%s.%s' % (self. alias or self.tablename,967 self.sm. column_name(wclass.__name__, k))975 colnames = ['%s.%s' % (self.sm.quote(self.alias or self.tablename), 976 self.sm.quote(self.sm.column_name(wclass.__name__, k))) 968 977 for k in keys] 969 978 return cols, colnames 970 979 971 980 def _joinname(self): 981 q = self.sm.quote 972 982 if self.alias: 973 return "%s AS %s" % ( self.tablename, self.alias)974 else: 975 return self.tablename976 joinname = property(_joinname, doc=(" Table name for use in "977 "JOIN clause (read-only)."))983 return "%s AS %s" % (q(self.tablename), q(self.alias)) 984 else: 985 return q(self.tablename) 986 joinname = property(_joinname, doc=("Quoted table name for use in " 987 "JOIN clause (read-only).")) 978 988 979 989 def association(self, classes): … … 999 1009 use_asterisk_to_get_all = False 1000 1010 1011 prefix = "" 1012 1001 1013 decompiler = SQLDecompiler 1002 1014 typeAdapter = FieldTypeAdapter() … … 1004 1016 fromAdapter = AdapterFromDB() 1005 1017 1018 tablesetclass = TableSet 1019 tableclass = Table 1020 columnsetclass = ColumnSet 1021 columnclass = Column 1022 indexsetclass = IndexSet 1023 indexclass = Index 1024 1006 1025 def __init__(self, name, arena, allOptions={}): 1007 1026 storage.StorageManager.__init__(self, name, arena, allOptions) 1008 1027 1009 1028 # Adapter Overrides 1010 def get_ adapter_option(name):1011 adapter_class= allOptions.get(name)1012 if isinstance( adapter_class, basestring):1013 adapter_class = xray.classes(adapter_class)1014 return adapter_class1015 1016 adapter = get_ adapter_option('Type Adapter')1029 def get_option(name): 1030 item = allOptions.get(name) 1031 if isinstance(item, basestring): 1032 item = xray.classes(item) 1033 return item 1034 1035 adapter = get_option('Type Adapter') 1017 1036 if adapter: self.typeAdapter = adapter 1018 adapter = get_ adapter_option('To Adapter')1037 adapter = get_option('To Adapter') 1019 1038 if adapter: self.toAdapter = adapter 1020 adapter = get_ adapter_option('From Adapter')1039 adapter = get_option('From Adapter') 1021 1040 if adapter: self.fromAdapter = adapter 1041 1042 adapter = get_option('TableSet Class') 1043 if adapter: self.tablesetclass = adapter 1044 self.tables = self.tablesetclass(self) 1022 1045 1023 1046 size = int(allOptions.get('Pool Size', '10')) … … 1027 1050 self.connection = ConnectionFactory(self._get_conn, self._del_conn) 1028 1051 1029 self.prefix = allOptions.get('Prefix', " djv")1052 self.prefix = allOptions.get('Prefix', "") 1030 1053 self.reserve_lock = threading.Lock() 1031 1054 1032 1055 # Naming # 1033 1056 1034 def sql_name(self, name, quoted=True): 1035 """The name, escaped for SQL.""" 1057 def quote(self, name): 1058 """Return name, quoted for use in an SQL statement.""" 1059 # This base class doesn't use "quote", 1060 # but most subclasses will. 1061 return name 1062 1063 def sql_name(self, name): 1064 """Return the native SQL version of name.""" 1036 1065 if self.sql_name_caseless: 1037 1066 name = name.lower() … … 1044 1073 name = name[:maxlen] 1045 1074 1046 # This base class doesn't use the "quoted" arg,1047 # but most subclasses will.1048 1075 return name 1049 1076 1050 def column_name(self, classname, name , full=False, quoted=True):1051 """The column name, escaped for SQL. If full, include tablename."""1077 def column_name(self, classname, name): 1078 """The column name, escaped for SQL.""" 1052 1079 # If you want to use a map from UnitProperty names 1053 # to DB column names, override this method. 1054 name = self.sql_name(name, quoted=quoted) 1055 if not full: 1056 return name 1057 1058 alias = getattr(classname, "alias", None) 1059 if alias is None: 1060 tname = self.table_name(classname, quoted=quoted) 1061 else: 1062 tname = (classname.alias or classname.tablename) 1063 return '%s.%s' % (tname, name) 1064 1065 def table_name(self, name, quoted=True): 1066 """The table name, escaped for SQL.""" 1080 # to DB column names, override this method (that's why 1081 # the classname must be included in the args). 1082 return self.sql_name(name) 1083 1084 def table_name(self, name): 1085 """Return the SQL table name for the given key.""" 1067 1086 # If you want to use a map from Unit class names 1068 1087 # to DB table names, override this method. 1069 return self.sql_name(self.prefix + name , quoted=quoted)1088 return self.sql_name(self.prefix + name) 1070 1089 1071 1090 # Connecting # … … 1091 1110 """ 1092 1111 clsname = cls.__name__ 1093 tablename = self.table _name(clsname)1112 tablename = self.tables[clsname].name 1094 1113 if fields: 1095 fields = [self. column_name(clsname, x) for x in fields]1114 fields = [self.quote(self.column_name(clsname, x)) for x in fields] 1096 1115 if distinct: 1097 1116 sql = 'SELECT DISTINCT %s FROM %s' 1098 1117 else: 1099 1118 sql = 'SELECT %s FROM %s' 1100 sql = sql % (', '.join(fields), tablename)1101 else: 1102 sql = 'SELECT * FROM %s' % tablename1119 sql = sql % (', '.join(fields), self.quote(tablename)) 1120 else: 1121 sql = 'SELECT * FROM %s' % self.quote(tablename) 1103 1122 1104 1123 w, i = self.where((clsname,), expr) … … 1130 1149 def fetch(self, query, conn=None): 1131 1150 """fetch(query, conn=None) -> rowdata, columns. 1132 1133 query should be a SQL query in string format 1151 1152 query should be a SQL query in string format 1134 1153 rowdata will be an iterable of iterables containing the result values. 1135 1154 columns will be an iterable of (column name, data type) pairs. … … 1157 1176 idnames = list(cls.identifiers) 1158 1177 for key in idnames + [x for x in cls.properties if x not in idnames]: 1159 index, ftype = columns[self.column_name(clsname, key , quoted=False)]1178 index, ftype = columns[self.column_name(clsname, key)] 1160 1179 props.append((key, index, ftype)) 1161 1180 … … 1196 1215 cls = unit.__class__ 1197 1216 clsname = cls.__name__ 1198 tablename = self.table _name(clsname)1217 tablename = self.tables[clsname].name 1199 1218 if not unit.sequencer.valid_id(unit.identity()): 1200 1219 # Examine all existing IDs and grant the "next" one. 1201 id_fields = [self. column_name(clsname, key)1220 id_fields = [self.quote(self.column_name(clsname, key)) 1202 1221 for key in cls.identifiers] 1203 1222 data, cols = self.fetch('SELECT %s FROM %s;' % 1204 (', '.join(id_fields), tablename))1223 (', '.join(id_fields), self.quote(tablename))) 1205 1224 if data: 1206 1225 # sqlite 2, for example, has empty cols tuple if no data. … … 1226 1245 for key in cls.properties: 1227 1246 val = self.toAdapter.coerce(getattr(unit, key)) 1228 fields.append(self. column_name(clsname, key))1247 fields.append(self.quote(self.column_name(clsname, key))) 1229 1248 values.append(val) 1230 1249 … … 1232 1251 values = ", ".join(values) 1233 1252 self.execute('INSERT INTO %s (%s) VALUES (%s);' % 1234 (s tr(tablename), fields, values))1253 (self.quote(tablename), fields, values)) 1235 1254 1236 1255 def id_clause(self, unit): … … 1239 1258 col = self.column_name 1240 1259 c = self.toAdapter.coerce 1241 return " AND ".join(["%s = %s" % ( col(clsname, key),1260 return " AND ".join(["%s = %s" % (self.quote(col(clsname, key)), 1242 1261 c(getattr(unit, key))) 1243 1262 for key in unit.identifiers]) … … 1254 1273 val = self.toAdapter.coerce(getattr(unit, key)) 1255 1274 parms.append('%s = %s' % 1256 (self.column_name(clsname, key), val)) 1275 (self.quote(self.column_name(clsname, key)), 1276 val)) 1257 1277 1258 1278 if parms: 1259 1279 sql = ('UPDATE %s SET %s WHERE %s;' % 1260 (self.table_name(clsname), ", ".join(parms), 1280 (self.quote(self.tables[clsname].name), 1281 ", ".join(parms), 1261 1282 self.id_clause(unit))) 1262 1283 self.execute(sql) … … 1270 1291 star = "" 1271 1292 self.execute('DELETE%s FROM %s WHERE %s;' % 1272 (star, self. table_name(unit.__class__.__name__),1293 (star, self.quote(self.tables[unit.__class__.__name__].name), 1273 1294 self.id_clause(unit))) 1274 1295 … … 1367 1388 msg = ("No association found between %s and %s." % (name1, name2)) 1368 1389 raise dejavu.AssociationError(msg) 1369 near = '%s.%s' % (nearClass, self.column_name(nearClass, ua.nearKey)) 1370 far = '%s.%s' % (farClass, self.column_name(farClass, ua.farKey)) 1390 1391 near = '%s.%s' % (self.quote(nearClass), 1392 self.quote(self.column_name(nearClass, ua.nearKey))) 1393 far = '%s.%s' % (self.quote(farClass), 1394 self.quote(self.column_name(farClass, ua.farKey))) 1371 1395 1372 1396 return "(%s %s JOIN %s ON %s = %s)" % (name1, j, name2, near, far) … … 1462 1486 1463 1487 def create_database(self): 1464 self.execute("CREATE DATABASE %s;" % self. sql_name(self.dbname))1488 self.execute("CREATE DATABASE %s;" % self.quote(self.sql_name(self.dbname))) 1465 1489 1466 1490 def drop_database(self): 1467 self.execute("DROP DATABASE %s;" % self. sql_name(self.dbname))1491 self.execute("DROP DATABASE %s;" % self.quote(self.sql_name(self.dbname))) 1468 1492 1469 1493 def create_storage(self, cls): 1470 1494 """Create storage for the given class.""" 1471 clsname = cls.__name__ 1472 tablename = self.table_name(clsname) 1473 typename = self.typeAdapter.coerce 1474 1495 colname = self.column_name 1496 1497 # Make a Table object. 1498 tablename = self.table_name(cls.__name__) 1499 t = self.tableclass(self, tablename) 1500 1501 indices = cls.indices() 1475 1502 fields = [] 1476 1503 for key in cls.properties: 1477 fields.append('%s %s' % (self.column_name(clsname, key), 1478 typename(cls, key))) 1479 self.execute('CREATE TABLE %s (%s);' % (tablename, ", ".join(fields))) 1480 1481 for index in cls.indices(): 1482 i = self.table_name("i" + clsname + index) 1483 self.execute('CREATE INDEX %s ON %s (%s);' % 1484 (i, tablename, self.column_name(clsname, index))) 1504 dbtype = self.typeAdapter.coerce(cls, key) 1505 prop = cls.property(key) 1506 cname = colname(cls.__name__, key) 1507 col = self.columnclass(cname, dbtype, prop.type, 1508 prop.default, prop.hints.copy()) 1509 # Use the superclass call to avoid ALTER TABLE. 1510 dict.__setitem__(t.columns, key, col) 1511 1512 if key in indices: 1513 iname = self.table_name("i" + cls.__name__ + key) 1514 i = self.indexclass(iname, tablename, cname) 1515 # Use the superclass call to avoid CREATE INDEX. 1516 dict.__setitem__(t.columns.indices, key, i) 1517 1518 # Attach to self.tables, which should call CREATE TABLE. 1519 self.tables[cls.__name__] = t 1485 1520 1486 1521 def has_storage(self, cls): 1487 try: 1488 # Must use fetch here instead of execute, because e.g. MySQL 1489 # must call store_result if the query has a result set 1490 # (or it will crash on a subsequent execute). 1491 self.fetch("SELECT * FROM %s;" % self.table_name(cls.__name__)) 1492 except: 1493 return False 1494 return True 1522 return cls.__name__ in self.tables 1495 1523 1496 1524 def drop_storage(self, cls): 1497 self.execute('DROP TABLE %s;' % self.table_name(cls.__name__)) 1525 del self.tables[cls.__name__] 1526 1527 def rename_storage(self, oldname, newname): 1528 self.arena.log("rename table %s to %s" % (oldname, newname), 1529 dejavu.LOGSQL) 1530 self.tables.rename(oldname, newname) 1498 1531 1499 1532 def add_property(self, cls, name): 1500 1533 if not self.has_property(cls, name): 1501 c lsname = cls.__name__1502 self.execute("ALTER TABLE %s ADD COLUMN %s %s;" %1503 (self.table_name(clsname),1504 self.column_name(clsname, name),1505 self.typeAdapter.coerce(cls, name),1506 ))1534 cname = self.column_name(cls.__name__, name) 1535 dbtype = self.typeAdapter.coerce(cls, name) 1536 prop = getattr(cls, name) 1537 c = self.columnclass(cname, dbtype, prop.type, 1538 prop.default, prop.hints.copy()) 1539 self.tables[cls.__name__].columns[name] = c 1507 1540 1508 1541 def has_property(self, cls, name): 1509 clsname = cls.__name__ 1510 try: 1511 # Must use fetch here instead of execute, because e.g. MySQL 1512 # must call store_result if the query has a result set 1513 # (or it will crash on a subsequent execute). 1514 self.fetch("SELECT %s FROM %s;" % 1515 (self.column_name(clsname, name), 1516 self.table_name(clsname))) 1517 except: 1518 return False 1519 return True 1542 return name in self.tables[cls.__name__].columns 1520 1543 1521 1544 def drop_property(self, cls, name): 1522 1545 if self.has_property(cls, name): 1523 clsname = cls.__name__ 1524 if self.has_index(cls, name): 1525 self.drop_index(cls, name) 1526 self.execute("ALTER TABLE %s DROP COLUMN %s;" % 1527 (self.table_name(clsname), 1528 self.column_name(clsname, name))) 1546 del self.tables[cls.__name__].columns[name] 1529 1547 1530 1548 def rename_property(self, cls, oldname, newname): 1531 clsname = cls.__name__ 1532 oldname = self.column_name(clsname, oldname) 1533 newname = self.column_name(clsname, newname) 1534 if oldname != newname: 1535 self.execute("ALTER TABLE %s RENAME COLUMN %s TO %s;" % 1536 (self.table_name(clsname), oldname, newname)) 1549 self.tables[cls.__name__].columns.rename(oldname, newname) 1537 1550 1538 1551 def has_index(self, cls, name): 1539 tablename = self.table_name(cls.__name__, quoted=False) 1540 indices = [i.colname for i in self.get_indices(tablename)] 1541 return (name in indices) 1552 return name in self.tables[cls.__name__].columns.indices 1542 1553 1543 1554 def drop_index(self, cls, name): 1544 clsname = cls.__name__ 1545 self.execute('DROP INDEX %s ON %s;' % 1546 (self.sql_name("i" + clsname + name), 1547 self.table_name(clsname))) 1548 1549 1550 class Table: 1551 """A table in a database.""" 1552 1553 def __init__(self, name): 1554 self.name = name 1555 self.columns = [] 1556 1557 def __repr__(self): 1558 return "dejavu.db.Table(%s)" % repr(self.name) 1559 1560 1561 class Column: 1562 """A column in a table in a database.""" 1563 1564 def __init__(self, key, type, default=None): 1565 self.key = key 1566 self.type = type 1567 self.default = default 1568 self.hints = {} 1569 1570 def __repr__(self): 1571 return ("dejavu.db.Column(%s, %s, default=%s, hints=%s)" % 1572 (repr(self.key), repr(self.type), 1573 repr(self.default), repr(self.hints)) 1574 ) 1575 1576 1577 class Index: 1578 """An index on a table column (or columns) in a database.""" 1579 1580 def __init__(self, name, tablename, colname, pk=True, unique=True): 1581 self.name = name 1582 self.tablename = tablename 1583 self.colname = colname 1584 self.pk = pk 1585 self.unique = unique 1586 1587 def __repr__(self): 1588 return ("dejavu.db.Index(%s, %s, %s, pk=%s, unique=%s)" % 1589 (repr(self.name), repr(self.tablename), repr(self.colname), 1590 repr(self.pk), repr(self.unique))) 1591 1555 del self.tables[cls.__name__].columns.indices[name] 1556 1557 def sync(self, conn=None): 1558 """Populate self using all registered classes.""" 1559 # Use the superclass call to avoid DROP TABLE. 1560 dict.clear(self.tables) 1561 1562 dbtables = self.tables._get_tables(conn) 1563 for cls in self.arena._registered_classes: 1564 # Try to find a matching Table object from _get_tables. 1565 db_tname = self.prefix + self.table_name(cls.__name__) 1566 t = [x for x in dbtables if x.name == db_tname] 1567 if t: 1568 t = t[0] 1569 for c in self._get_columns(t.name): 1570 # Use the superclass call to avoid ALTER TABLE 1571 dict.__setitem__(t.columns, c.name, c) 1572 # Use the superclass call to avoid CREATE TABLE 1573 dict.__setitem__(self, db_tname, t) 1574 1575 def autoclass(self, table, newclassname=None): 1576 """Create a Unit class automatically from this table and its columns.""" 1577 class AutoUnitClass(dejavu.Unit): 1578 pass 1579 for cname, c in table.columns.iteritems(): 1580 AutoUnitClass.set_property(cname, c.type) 1581 1582 if newclassname is None: 1583 newclassname = table.name 1584 AutoUnitClass.__name__ = newclassname 1585 1586 return AutoUnitClass 1587 trunk/storage/storeado.py
r225 r226 54 54 zeroHour = datetime.date(1899, 12, 30).toordinal() 55 55 56 # DataTypeEnum 57 adEmpty = 0 58 adSmallInt = 2 59 adInteger = 3 60 adSingle = 4 61 adDouble = 5 62 adCurrency = 6 63 adDate = 7 64 adBSTR = 8 65 adIDispatch = 9 66 adError = 10 67 adBoolean = 11 68 adVariant = 12 69 adIUnknown = 13 70 adDecimal = 14 71 adTinyInt = 16 72 adUnsignedTinyInt = 17 73 adUnsignedSmallInt = 18 74 adUnsignedInt = 19 75 adBigInt = 20 76 adUnsignedBigInt = 21 77 adGUID = 72 # e.g. {E5D50A9B-33D2-11D3-AAB3-00104BA31425} 78 adBinary = 128 79 adChar = 129 80 adWChar = 130 81 adNumeric = 131 82 adUserDefined = 132 83 adDBDate = 133 84 adDBTime = 134 85 adDBTimeStamp = 135 86 adVarChar = 200 87 adLongVarChar = 201 88 adVarWChar = 202 89 adLongVarWChar = 203 90 adVarBinary = 204 91 adLongVarBinary = 205 92 56 dbtypes = { 57 0: 'EMPTY', 2: 'SMALLINT', 58 3: 'INTEGER', 4: 'SINGLE', 59 5: 'DOUBLE', 6: 'CURRENCY', 60 7: 'DATE', 8: 'BSTR', 61 9: 'IDISPATCH', 10: 'ERROR', 62 11: 'BOOLEAN', 12: 'VARIANT', 63 13: 'IUNKNOWN', 14: 'DECIMAL', 64 16: 'TINYINT', 17: 'UNSIGNEDTINYINT', 65 18: 'UNSIGNEDSMALLINT', 19: 'UNSIGNEDINT', 66 20: 'BIGINT', 21: 'UNSIGNEDBIGINT', 67 72: 'GUID', 128: 'BINARY', 68 129: 'CHAR', 130: 'WCHAR', 69 131: 'NUMERIC', 132: 'USERDEFINED', 70 133: 'DBDATE', 134: 'DBTIME', 71 135: 'DBTIMESTAMP', 200: 'VARCHAR', 72 201: 'LONGVARCHAR', 202: 'VARWCHAR', 73 203: 'LONGVARWCHAR', 204: 'VARBINARY', 74 205: 'LONGVARBINARY' 75 } 93 76 94 77 def time_from_com(com_date): … … 175 158 176 159 def coerce_float(self, value, coltype): 177 if coltype == adCurrencyand isinstance(value, tuple):160 if coltype == 0x06 and isinstance(value, tuple): 178 161 # See http://groups.google.com/group/comp.lang.python/ 179 162 # browse_frm/thread/fed03c64735c9e9c … … 353 336 354 337 338 class ADOColumnSet(db.ColumnSet): 339 340 def _rename(self, oldcol, newname): 341 conn = self.table.sm.connection() 342 try: 343 cat = win32com.client.Dispatch(r'ADOX.Catalog') 344 cat.ActiveConnection = conn 345 cat.Tables(self.table.name).Columns(oldcol.name).Name = newname 346 finally: 347 conn = None 348 cat = None 349 350 351 class ADOTableSet(db.TableSet): 352 353 def _get_tables(self, conn=None): 354 # cols will be 355 # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 356 # (u'TABLE_TYPE', 202), (u'TABLE_GUID', 72), (u'DESCRIPTION', 203), 357 # (u'TABLE_PROPID', 19), (u'DATE_CREATED', 7), (u'DATE_MODIFIED', 7)] 358 data, cols = self.sm.fetch(adSchemaTables, conn=conn, schema=True) 359 return [self.sm.tableclass(self.sm, row[2]) for row in data] 360 361 def _get_columns(self, tablename, conn=None): 362 # columns will be 363 # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 364 # (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72), (u'COLUMN_PROPID', 19), 365 # (u'ORDINAL_POSITION', 19), (u'COLUMN_HASDEFAULT', 11), 366 # (u'COLUMN_DEFAULT', 203), (u'COLUMN_FLAGS', 19), (u'IS_NULLABLE', 11), 367 # (u'DATA_TYPE', 18), (u'TYPE_GUID', 72), (u'CHARACTER_MAXIMUM_LENGTH', 19), 368 # (u'CHARACTER_OCTET_LENGTH', 19), (u'NUMERIC_PRECISION', 18), 369 # (u'NUMERIC_SCALE', 2), (u'DATETIME_PRECISION', 19), 370 # (u'CHARACTER_SET_CATALOG', 202), (u'CHARACTER_SET_SCHEMA', 202), 371 # (u'CHARACTER_SET_NAME', 202), (u'COLLATION_CATALOG', 202), 372 # (u'COLLATION_SCHEMA', 202), (u'COLLATION_NAME', 202), 373 # (u'DOMAIN_CATALOG', 202), (u'DOMAIN_SCHEMA', 202), 374 # (u'DOMAIN_NAME', 202), (u'DESCRIPTION', 203)] 375 data, _ = self.sm.fetch(adSchemaColumns, conn=conn, schema=True) 376 cols = [] 377 for row in data: 378 # I tried passing criteria to OpenSchema, but passing None is 379 # not the same as passing pythoncom.Empty (which errors). 380 if tablename and row[2] != tablename: 381 continue 382 383 dbtype = dbtypes[row[11]] 384 c = self.sm.columnclass(row[3], dbtype, None, row[8]) 385 if dbtype in ("DATE", "DBDATE"): 386 c.type = datetime.date 387 elif dbtype == "DBTIME": 388 c.type = datetime.time 389 elif dbtype == "DBTIMESTAMP": 390 c.type = datetime.datetime 391 elif dbtype in ("SMALLINT", "INTEGER", "TINYINT", 392 "UNSIGNEDTINYINT", "UNSIGNEDSMALLINT", 393 "UNSIGNEDINT"): 394 c.type = int 395 c.hints['bytes'] = row[15] 396 elif dbtype == "BOOLEAN": 397 c.type = bool 398 elif dbtype in ("BIGINT", "UNSIGNEDBIGINT"): 399 c.type = long 400 c.hints['bytes'] = row[15] 401 elif dbtype in ("SINGLE", "DOUBLE"): 402 c.type = float 403 c.hints['precision'] = row[15] 404 c.hints['scale'] = row[16] 405 elif dbtype in ("DECIMAL", "NUMERIC", "CURRENCY"): 406 c.type = decimal.Decimal 407 c.hints['precision'] = row[15] 408 c.hints['scale'] = row[16] 409 elif dbtype in ("BSTR", "VARIANT", "BINARY", "CHAR", 410 "VARCHAR", "LONGVARCHAR", 411 "VARBINARY", "LONGVARBINARY"): 412 c.type = str 413 if row[13]: 414 c.hints['bytes'] = row[13] 415 elif dbtype in ("WCHAR", "VARWCHAR", "LONGVARWCHAR"): 416 c.type = unicode 417 if row[13]: 418 c.hints['bytes'] = row[13] 419 cols.append(c) 420 return cols 421 422 def _get_indices(self, tablename=None, conn=None): 423 # cols will be 424 # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 425 # (u'INDEX_CATALOG', 202), (u'INDEX_SCHEMA', 202), (u'INDEX_NAME', 202), 426 # (u'PRIMARY_KEY', 11), (u'UNIQUE', 11), (u'CLUSTERED', 11), (u'TYPE', 18), 427 # (u'FILL_FACTOR', 3), (u'INITIAL_SIZE', 3), (u'NULLS', 3), 428 # (u'SORT_BOOKMARKS', 11), (u'AUTO_UPDATE', 11), (u'NULL_COLLATION', 3), 429 # (u'ORDINAL_POSITION', 19), (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72), 430 # (u'COLUMN_PROPID', 19), (u'COLLATION', 2), (u'CARDINALITY', 21), 431 # (u'PAGES', 3), (u'FILTER_CONDITION', 202), (u'INTEGRATED', 11)] 432 data, _ = self.sm.fetch(adSchemaIndexes, conn=conn, schema=True) 433 indices = [] 434 for row in data: 435 # I tried passing criteria to OpenSchema, but passing None is 436 # not the same as passing pythoncom.Empty (which errors). 437 if tablename and row[2] != tablename: 438 continue 439 i = self.sm.indexclass(row[5], row[2], row[17], row[6], row[7]) 440 indices.append(i) 441 return indices 442 443 def _rename(self, oldtable, newname): 444 conn = self.sm.connection() 445 try: 446 cat = win32com.client.Dispatch(r'ADOX.Catalog') 447 cat.ActiveConnection = conn 448 cat.Tables(oldtable.name).Name = newname 449 finally: 450 conn = None 451 cat = None 452 453 355 454 class StorageManagerADO(db.StorageManagerDB): 356 455 """StoreManager to save and retrieve Units via ADO 2.7. … … 361 460 decompiler = ADOSQLDecompiler 362 461 fromAdapter = AdapterFromADO() 462 tablesetclass = ADOTableSet 463 columnsetclass = ADOColumnSet 363 464 364 465 def connatoms(self): … … 370 471 return atoms 371 472 372 def sql_name(self, name, quoted=True): 373 if quoted: 374 name = '[' + name + ']' 375 return name 473 def quote(self, name): 474 """Return name, quoted for use in an SQL statement.""" 475 return '[' + name + ']' 376 476 377 477 def _get_conn(self): … … 443 543 adoconn = win32com.client.Dispatch(r'ADODB.Connection') 444 544 return "ADO Version: %s" % adoconn.Version 445 446 # Schemas # 447 448 def has_storage(self, cls): 449 names = [t.name for t in self.get_tables()] 450 return self.table_name(cls.__name__, quoted=False) in names 451 452 def rename_storage(self, oldname, newname): 453 oldname = self.table_name(oldname, quoted=False) 454 newname = self.table_name(newname, quoted=False) 455 self.arena.log("rename table %s to %s" % (oldname, newname), 456 dejavu.LOGSQL) 457 458 conn = self.connection() 459 try: 460 cat = win32com.client.Dispatch(r'ADOX.Catalog') 461 cat.ActiveConnection = conn 462 cat.Tables(oldname).Name = newname 463 finally: 464 conn = None 465 cat = None 466 467 def rename_property(self, cls, oldname, newname): 468 clsname = cls.__name__ 469 tblname = self.table_name(clsname, quoted=False) 470 oldname = self.column_name(clsname, oldname, quoted=False) 471 newname = self.column_name(clsname, newname, quoted=False) 472 self.arena.log("rename %s column %s to %s" % 473 (tblname, oldname, newname), 474 dejavu.LOGSQL) 475 476 conn = self.connection() 477 try: 478 cat = win32com.client.Dispatch(r'ADOX.Catalog') 479 cat.ActiveConnection = conn 480 cat.Tables(tblname).Columns(oldname).Name = newname 481 finally: 482 conn = None 483 cat = None 484 485 def drop_index(self, cls, name): 486 clsname = cls.__name__ 487 tablename = self.table_name(clsname, quoted=False) 488 qtablename = self.table_name(clsname) 489 colname = self.column_name(clsname, name, quoted=False) 490 491 for i in self.get_indices(tablename): 492 if i.colname == colname: 493 # The INDEX_NAME may include a trailing " ASC" or other data 494 self.execute('DROP INDEX [%s] ON %s;' % (i.name, qtablename)) 495 496 def get_tables(self, conn=None): 497 # cols will be 498 # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 499 # (u'TABLE_TYPE', 202), (u'TABLE_GUID', 72), (u'DESCRIPTION', 203), 500 # (u'TABLE_PROPID', 19), (u'DATE_CREATED', 7), (u'DATE_MODIFIED', 7)] 501 data, cols = self.fetch(adSchemaTables, conn=conn, schema=True) 502 return [db.Table(row[2]) for row in data] 503 504 def get_columns(self, tablename, conn=None): 505 # cols will be 506 # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 507 # (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72), (u'COLUMN_PROPID', 19), 508 # (u'ORDINAL_POSITION', 19), (u'COLUMN_HASDEFAULT', 11), 509 # (u'COLUMN_DEFAULT', 203), (u'COLUMN_FLAGS', 19), (u'IS_NULLABLE', 11), 510 # (u'DATA_TYPE', 18), (u'TYPE_GUID', 72), (u'CHARACTER_MAXIMUM_LENGTH', 19), 511 # (u'CHARACTER_OCTET_LENGTH', 19), (u'NUMERIC_PRECISION', 18), 512 # (u'NUMERIC_SCALE', 2), (u'DATETIME_PRECISION', 19), 513 # (u'CHARACTER_SET_CATALOG', 202), (u'CHARACTER_SET_SCHEMA', 202), 514 # (u'CHARACTER_SET_NAME', 202), (u'COLLATION_CATALOG', 202), 515 # (u'COLLATION_SCHEMA', 202), (u'COLLATION_NAME', 202), 516 # (u'DOMAIN_CATALOG', 202), (u'DOMAIN_SCHEMA', 202), 517 # (u'DOMAIN_NAME', 202), (u'DESCRIPTION', 203)] 518 data, _ = self.fetch(adSchemaColumns, conn=conn, schema=True) 519 cols = [] 520 for row in data: 521 # I tried passing criteria to OpenSchema, but passing None is 522 # not the same as passing pythoncom.Empty (which errors). 523 if tablename and row[2] != tablename: 524 continue 525 datatype = row[11] 526 c = db.Column(row[3], None, row[8]) 527 if datatype in (adDate, adDBDate): 528 c.type = datetime.date 529 elif datatype == adDBTime: 530 c.type = datetime.time 531 elif datatype == adDBTimeStamp: 532 c.type = datetime.datetime 533 elif datatype in (adSmallInt, adInteger, adTinyInt, 534 adUnsignedTinyInt, adUnsignedSmallInt, 535 adUnsignedInt): 536 c.type = int 537 c.hints['bytes'] = row[15] 538 elif datatype == adBoolean: 539 c.type = bool 540 elif datatype in (adBigInt, adUnsignedBigInt): 541 c.type = long 542 c.hints['bytes'] = row[15] 543 elif datatype in (adSingle, adDouble, adCurrency): 544 c.type = float 545 c.hints['bytes'] = row[15] 546 elif datatype in (adDecimal, adNumeric): 547 c.type = decimal.Decimal 548 c.hints['bytes'] = row[15] 549 elif datatype in (adBSTR, adVariant, adBinary, adChar, 550 adVarChar, adLongVarChar, 551 adVarBinary, adLongVarBinary): 552 c.type = str 553 if row[13]: 554 c.hints['bytes'] = row[13] 555 elif datatype in (adWChar, adVarWChar, adLongVarWChar): 556 c.type = unicode 557 if row[13]: 558 c.hints['bytes'] = row[13] 559 cols.append(c) 560 return cols 561 562 def get_indices(self, tablename=None, conn=None): 563 # cols will be 564 # [(u'TABLE_CATALOG', 202), (u'TABLE_SCHEMA', 202), (u'TABLE_NAME', 202), 565 # (u'INDEX_CATALOG', 202), (u'INDEX_SCHEMA', 202), (u'INDEX_NAME', 202), 566 # (u'PRIMARY_KEY', 11), (u'UNIQUE', 11), (u'CLUSTERED', 11), (u'TYPE', 18), 567 # (u'FILL_FACTOR', 3), (u'INITIAL_SIZE', 3), (u'NULLS', 3), 568 # (u'SORT_BOOKMARKS', 11), (u'AUTO_UPDATE', 11), (u'NULL_COLLATION', 3), 569 # (u'ORDINAL_POSITION', 19), (u'COLUMN_NAME', 202), (u'COLUMN_GUID', 72), 570 # (u'COLUMN_PROPID', 19), (u'COLLATION', 2), (u'CARDINALITY', 21), 571 # (u'PAGES', 3), (u'FILTER_CONDITION', 202), (u'INTEGRATED', 11)] 572 data, _ = self.fetch(adSchemaIndexes, conn=conn, schema=True) 573 indices = [] 574 for row in data: 575 # I tried passing criteria to OpenSchema, but passing None is 576 # not the same as passing pythoncom.Empty (which errors). 577 if tablename and row[2] != tablename: 578 continue 579 indices.append(db.Index(row[5], row[2], row[17], row[6], row[7])) 580 return indices 545 581 546 582 547 … … 670 635 671 636 637 class SQLServerColumnSet(ADOColumnSet): 638 639 def __setitem__(self, key, column): 640 t = self.table 641 # SQL Server doesn't use the "COLUMN" keyword with "ADD" 642 t.sm.execute("ALTER TABLE %s ADD %s %s;" % 643 (t.sm.quote(t.name), t.sm.quote(column.name), 644 column.dbtype)) 645 dict.__setitem__(self, key, column) 646 647 def _rename(self, oldcol, newname): 648 t = self.table 649 t.sm.execute("EXEC sp_rename '%s.%s', '%s', 'COLUMN'" % 650 (t.name, oldcol.name, newname)) 651 652 672 653 class StorageManagerADO_SQLServer(StorageManagerADO): 673 654 674 655 typeAdapter = FieldTypeAdapter_SQLServer() 675 656 toAdapter = AdapterToADOSQL_SQLServer() 657 columnsetclass = SQLServerColumnSet 676 658 677 659 def __init__(self, name, arena, allOptions={}): … … 685 667 cls = unit.__class__ 686 668 clsname = cls.__name__ 687 tablename = self.table _name(clsname)669 tablename = self.tables[clsname].name 688 670 689 671 fields = [] … … 695 677 continue 696 678 val = self.toAdapter.coerce(getattr(unit, key)) 697 fields.append(self. column_name(clsname, key))679 fields.append(self.quote(self.column_name(clsname, key))) 698 680 values.append(val) 699 681 … … 701 683 values = ", ".join(values) 702 684 self.execute('INSERT INTO %s (%s) VALUES (%s);' % 703 (s tr(tablename), fields, values))685 (self.quote(tablename), fields, values)) 704 686 705 687 # Grab the new ID. This is threadsafe because db.reserve has a mutex. … … 707 689 # None) when retrieving ID's just after a 99-thread-test ran. Moving 708 690 # the multithreading test fixed it. IDENT_CURRENT worked regardless. 709 data, col_defs = self.fetch("SELECT IDENT_CURRENT('%s');" 710 % str(tablename)) 691 data, col_defs = self.fetch("SELECT IDENT_CURRENT('%s');" % tablename) 711 692 setattr(unit, cls.identifiers[0], data[0][0]) 712 693 … … 714 695 715 696 def create_database(self): 716 # This method hasn't been tested yet for SQL server .697 # This method hasn't been tested yet for SQL server (only MSDE). 717 698 adoconn = win32com.client.Dispatch(r'ADODB.Connection') 718 699 atoms = self.connatoms() 719 700 atoms['INITIAL CATALOG'] = "tempdb" 720 701 adoconn.Open("; ".join(["%s=%s" % (k, v) for k, v in atoms.iteritems()])) 721 adoconn.Execute("CREATE DATABASE %s" % self. sql_name(self.dbname))702 adoconn.Execute("CREATE DATABASE %s" % self.quote(self.sql_name(self.dbname))) 722 703 adoconn.Close() 723 704 … … 727 708 atoms['INITIAL CATALOG'] = "tempdb" 728 709 adoconn.Open("; ".join(["%s=%s" % (k, v) for k, v in atoms.iteritems()])) 729 adoconn.Execute("DROP DATABASE %s;" % self. sql_name(self.dbname))710 adoconn.Execute("DROP DATABASE %s;" % self.quote(self.sql_name(self.dbname))) 730 711 adoconn.Close() 731 732 def add_property(self, cls, name): 733 clsname = cls.__name__ 734 # SQL Server doesn't use the "COLUMN" keyword with "ADD" 735 self.execute("ALTER TABLE %s ADD %s %s;" % 736 (self.table_name(clsname), 737 self.column_name(clsname, name), 738 self.typeAdapter.coerce(cls, name), 739 )) 740 741 def rename_property(self, cls, oldname, newname): 742 clsname = cls.__name__ 743 oldname = self.column_name(clsname, oldname, quoted=False) 744 newname = self.column_name(clsname, newname, quoted=False) 745 if oldname != newname: 746 self.execute("EXEC sp_rename '%s.%s', '%s', 'COLUMN'" % 747 (self.table_name(clsname), oldname, newname)) 712 748 713 749 714 … … 866 831 cls = unit.__class__ 867 832 clsname = cls.__name__ 868 tablename = self.table _name(clsname)833 tablename = self.tables[clsname].name 869 834 870 835 fields = [] … … 876 841 continue 877 842 val = self.toAdapter.coerce(getattr(unit, key)) 878 fields.append(self. column_name(clsname, key))843 fields.append(self.quote(self.column_name(clsname, key))) 879 844 values.append(val) 880 845 … … 882 847 values = ", ".join(values) 883 848 self.execute('INSERT INTO %s (%s) VALUES (%s);' % 884 (s tr(tablename), fields, values))849 (self.quote(tablename), fields, values)) 885 850 886 851 # Grab the new ID. This is threadsafe because db.reserve has a mutex. trunk/storage/storemysql.py
r225 r226 149 149 150 150 151 class StorageManagerMySQL(db.StorageManagerDB): 152 """StoreManager to save and retrieve Units via _mysql.""" 153 154 sql_name_max_length = 64 155 # MySQL uses case-sensitive database and table names on Unix, but 156 # not on Windows. Use all-lowercase identifiers to work around the 157 # problem. "Column names, index names, and column aliases are not 158 # case sensitive on any platform." 159 # If deployers set lower_case_table_names to 1, it would help. 160 sql_name_caseless = True 161 162 typeAdapter = FieldTypeAdapterMySQL() 163 toAdapter = AdapterToMySQL() 164 fromAdapter = AdapterFromMySQL() 165 166 def __init__(self, name, arena, allOptions={}): 167 connargs = ["host", "user", "passwd", "db", "port", "unix_socket", 168 "conv", "connect_time", "compress", "named_pipe", 169 "init_command", "read_default_file", "read_default_group", 170 "cursorclass", "client_flag", 171 ] 172 self.connargs = dict([(k, v) for k, v in allOptions.iteritems() 173 if k in connargs]) 174 self.dbname = self.connargs['db'] 175 176 db.StorageManagerDB.__init__(self, name, arena, allOptions) 177 178 self.decompiler = MySQLDecompiler 179 # Get the version string from MySQL, to see if we need 180 # a different decompiler. 181 conn = self._template_conn() 182 rowdata, cols = self.fetch("SELECT version();", conn) 183 conn.close() 184 v = rowdata[0][0] 185 self._version = storage.Version(v) 186 if self._version > storage.Version("4.1.1"): 187 self.decompiler = MySQLDecompiler411 188 189 def sql_name(self, name, quoted=True): 190 name = db.StorageManagerDB.sql_name(self, name, quoted) 191 if quoted: 192 name = '`' + name.replace('`', '``') + '`' 193 return name 194 195 def _get_conn(self): 196 try: 197 conn = _mysql.connect(**self.connargs) 198 except _mysql.OperationalError, x: 199 if x.args[0] == 1040: # Too many connections 200 raise db.OutOfConnectionsError 201 raise 202 return conn 203 204 def _template_conn(self): 205 tmplconn = self.connargs.copy() 206 tmplconn['db'] = 'mysql' 207 return _mysql.connect(**tmplconn) 208 209 def fetch(self, query, conn=None): 210 """fetch(query, conn=None) -> rowdata, columns. 211 212 rowdata: a nested list (or tuples), column values within rows. 213 columns: a series of 2-tuples (or more). The first tuple value 214 will be the column name, the second value will be the column 215 type. 216 """ 217 if conn is None: 218 conn = self.connection() 219 self.execute(query, conn) 220 # store_result uses a client-side cursor 221 res = conn.store_result() 222 return res.fetch_row(0, 0), res.describe() 223 224 def destroy(self, unit): 225 """destroy(unit). Delete the unit.""" 226 self.execute('DELETE FROM %s WHERE %s;' % 227 (self.table_name(unit.__class__.__name__), 228 self.id_clause(unit))) 229 230 def version(self): 231 return "MySQL Version: %s" % self._version 232 233 def _seq_UnitSequencerInteger(self, unit): 234 """Reserve a unit using the table's AUTO_INCREMENT field.""" 235 cls = unit.__class__ 236 clsname = cls.__name__ 237 tablename = self.table_name(clsname) 238 239 fields = [] 240 values = [] 241 for key in cls.properties: 242 typename = self.typeAdapter.coerce(cls, key) 243 if typename.endswith("AUTO_INCREMENT"): 244 # Skip this field, since we're using AUTO_INCREMENT 245 continue 246 val = self.toAdapter.coerce(getattr(unit, key)) 247 fields.append(self.column_name(clsname, key)) 248 values.append(val) 249 250 fields = ", ".join(fields) 251 values = ", ".join(values) 252 253 conn = self.connection() 254 self.execute('INSERT INTO %s (%s) VALUES (%s);' % 255 (str(tablename), fields, values), 256 conn) 257 258 # Grab the new ID. This is threadsafe because db.reserve has a mutex. 259 setattr(unit, cls.identifiers[0], conn.insert_id()) 260 261 # Schemas # 262 263 def create_database(self): 264 # _mysql has create_db and drop_db commands, but they're deprecated. 265 sql = 'CREATE DATABASE %s;' % self.sql_name(self.dbname) 266 conn = self._template_conn() 267 self.execute(sql, conn) 268 conn.close() 269 270 def drop_database(self): 271 sql = 'DROP DATABASE %s;' % self.sql_name(self.dbname) 272 conn = self._template_conn() 273 self.execute(sql, conn) 274 conn.close() 275 276 def create_storage(self, cls): 277 clsname = cls.__name__ 278 tablename = self.table_name(clsname) 279 typename = self.typeAdapter.coerce 151 class MySQLIndexSet(db.IndexSet): 152 153 def __delitem__(self, key): 154 t = self.table 155 # MySQL might rename multiple-column indices to "PRIMARY" 156 for i in t.sm.tables._get_indices(t.name): 157 if i.colname == self[key].colname: 158 t.sm.execute('DROP INDEX %s ON %s;' % 159 (t.sm.quote(i.name), t.sm.quote(t.name))) 160 161 162 class MySQLColumnSet(db.ColumnSet): 163 164 def _rename(self, oldcol, newname): 165 # Override this to do the actual rename at the DB level. 166 t = self.table 167 t.sm.execute("ALTER TABLE %s CHANGE %s %s %s;" % 168 (t.sm.quote(t.name), t.sm.quote(oldcol.name), 169 t.sm.quote(newname), oldcol.dbtype)) 170 171 172 class MySQLTableSet(db.TableSet): 173 174 def __setitem__(self, key, table): 175 q = self.sm.quote 280 176 281 177 fields = [] 282 178 pk = [] 283 for key in cls.properties:284 qname = self.column_name(clsname, key)285 dbtype = typename(cls, key)179 for colname, col in table.columns.iteritems(): 180 qname = q(col.name) 181 dbtype = col.dbtype 286 182 fields.append('%s %s' % (qname, dbtype)) 287 if key in cls.identifiers:183 if colname in table.mysql_identifiers: 288 184 if dbtype.endswith('BLOB') or dbtype == 'TEXT': 289 185 # MySQL won't allow indexes on a BLOB field 290 186 # without a specific length. 291 qname = "%s( %s)" % (qname, 255)187 qname = "%s(255)" % qname 292 188 pk.append(qname) 189 293 190 pk = ", ".join(pk) 294 191 if pk: 295 192 pk = ", PRIMARY KEY (%s)" % pk 296 self.execute('CREATE TABLE %s (%s%s);' 297 % (tablename, ", ".join(fields), pk)) 298 299 hasdummy = False 300 if isinstance(cls.sequencer, dejavu.UnitSequencerInteger): 301 i = cls.sequencer.initial 302 if i > 1: 303 # Wow, what a hack. We have to create a dummy row 304 # to set the autoincrement initial value, and we 305 # can't delete it until after the CREATE INDEX 306 # statements below (or the counter will revert). 307 colname = self.column_name(clsname, cls.identifiers[0]) 308 self.execute("INSERT INTO %s (%s) VALUES (%s);" 309 % (tablename, colname, i - 1)) 310 hasdummy = True 311 312 for index in cls.indices(): 313 i = self.table_name("i" + clsname + index) 314 315 dbtype = typename(cls, index) 193 194 self.sm.execute('CREATE TABLE %s (%s%s);' % 195 (q(table.name), ", ".join(fields), pk)) 196 197 seq = getattr(table, "mysql_sequencer", None) 198 if seq: 199 # Wow, what a hack. We have to INSERT a dummy row 200 # to set the autoincrement initial value, and we 201 # can't delete it until after the CREATE INDEX 202 # statements (or the counter will revert). 203 colname, initial = seq 204 self.sm.execute("INSERT INTO %s (%s) VALUES (%s);" 205 % (q(table.name), q(colname), initial - 1)) 206 207 for k, index in table.columns.indices.iteritems(): 208 dbtype = table.columns[k].dbtype 316 209 if dbtype.endswith('BLOB') or dbtype == 'TEXT': 317 210 # MySQL won't allow indexes on a BLOB field 318 211 # without a specific length. 319 self. execute('CREATE INDEX %s ON %s (%s(%s));' %320 (i, tablename,321 self.column_name(clsname, index), 255))212 self.sm.execute('CREATE INDEX %s ON %s (%s(255));' % 213 (q(index.name), q(table.name), 214 q(index.colname))) 322 215 else: 323 self.execute('CREATE INDEX %s ON %s (%s);' % 324 (i, tablename, 325 self.column_name(clsname, index))) 326 327 if hasdummy: 328 self.execute("DELETE FROM %s" % tablename) 329 330 def rename_property(self, cls, oldname, newname): 331 clsname = cls.__name__ 332 oldcolname = self.column_name(clsname, oldname) 333 newcolname = self.column_name(clsname, newname) 334 if oldcolname != newcolname: 335 self.execute("ALTER TABLE %s CHANGE %s %s %s;" % 336 (self.table_name(clsname), oldcolname, newcolname, 337 self.typeAdapter.coerce(cls, newname))) 338 339 def drop_index(self, cls, name): 340 # MySQL might rename multiple-column indices to "PRIMARY" 341 clsname = cls.__name__ 342 names = [] 343 for i in self.get_indices(self.table_name(clsname, quoted=False)): 344 if i.name not in names: 345 names.append(i.name) 346 for n in names: 347 self.execute('DROP INDEX %s ON %s;' % 348 (self.sql_name(n), self.table_name(clsname))) 349 350 def get_tables(self, conn=None): 351 data, _ = self.fetch("SHOW TABLES FROM %s" % self.dbname, 352 conn=conn) 353 return [db.Table(row[0]) for row in data] 354 355 def get_columns(self, tablename=None, conn=None): 216 self.sm.execute('CREATE INDEX %s ON %s (%s);' % 217 (q(index.name), q(table.name), 218 q(index.colname))) 219 220 if seq: 221 self.sm.execute("DELETE FROM %s" % q(table.name)) 222 223 dict.__setitem__(self, key, table) 224 225 def _get_tables(self, conn=None): 226 data, _ = self.sm.fetch("SHOW TABLES FROM %s" % 227 self.sm.quote(self.sm.dbname), 228 conn=conn) 229 return [self.sm.tableclass(self.sm, row[0]) for row in data] 230 231 def _get_columns(self, tablename, conn=None): 356 232 # cols are: Field, Type, Null, Key, Default, Extra. 357 233 # See http://dev.mysql.com/doc/refman/4.1/en/describe.html 358 data, _ = self.fetch("SHOW COLUMNS FROM %s.%s" 359 % (self.dbname, self.sql_name(tablename)), 360 conn=conn) 234 q = self.sm.quote 235 data, _ = self.sm.fetch("SHOW COLUMNS FROM %s.%s" 236 % (q(self.sm.dbname), q(tablename)), 237 conn=conn) 361 238 cols = [] 362 239 for row in data: 363 c = db.Column(row[0], None, row[4])240 c = self.sm.columnclass(row[0], None, None, row[4]) 364 241 365 242 dbtype = row[1] … … 368 245 c.hints['bytes'] = dbtype[parenpos+1:-1] 369 246 dbtype = dbtype[:parenpos] 247 c.dbtype = dbtype 370 248 371 249 if dbtype in ('tinyint', 'smallint', 'mediumint', 'int', 'integer'): … … 391 269 return cols 392 270 393 def get_indices(self, tablename, conn=None):271 def _get_indices(self, tablename, conn=None): 394 272 indices = [] 395 273 try: 396 274 # cols are: Table, Non_unique, Key_name, Seq_in_index, Column_name, 397 275 # Collation, Cardinality, Sub_part, Packed, Null, Index_type, Comment 276 q = self.sm.quote 398 277 data, _ = self.fetch("SHOW INDEX FROM %s.%s" 399 % ( self.dbname, self.sql_name(tablename)),278 % (q(self.sm.dbname), q(tablename)), 400 279 conn=conn) 401 280 except _mysql.ProgrammingError, x: … … 404 283 else: 405 284 for row in data: 406 indices.append(db.Index(row[2], row[0], row[4], None, not row[1])) 285 i = self.sm.indexclass(row[2], row[0], row[4], None, not row[1]) 286 indices.append(i) 407 287 return indices 408 288 289 290 291 class StorageManagerMySQL(db.StorageManagerDB): 292 """StoreManager to save and retrieve Units via _mysql.""" 293 294 sql_name_max_length = 64 295 # MySQL uses case-sensitive database and table names on Unix, but 296 # not on Windows. Use all-lowercase identifiers to work around the 297 # problem. "Column names, index names, and column aliases are not 298 # case sensitive on any platform." 299 # If deployers set lower_case_table_names to 1, it would help. 300 sql_name_caseless = True 301 302 typeAdapter = FieldTypeAdapterMySQL() 303 toAdapter = AdapterToMySQL() 304 fromAdapter = AdapterFromMySQL() 305 306 tablesetclass = MySQLTableSet 307 columnsetclass = MySQLColumnSet 308 indexsetclass = MySQLIndexSet 309 310 def __init__(self, name, arena, allOptions={}): 311 connargs = ["host", "user", "passwd", "db", "port", "unix_socket", 312 "conv", "connect_time", "compress", "named_pipe", 313 "init_command", "read_default_file", "read_default_group", 314 "cursorclass", "client_flag", 315 ] 316 self.connargs = dict([(k, v) for k, v in allOptions.iteritems() 317 if k in connargs]) 318 self.dbname = self.connargs['db'] 319 320 db.StorageManagerDB.__init__(self, name, arena, allOptions) 321 322 self.decompiler = MySQLDecompiler 323 # Get the version string from MySQL, to see if we need 324 # a different decompiler. 325 conn = self._template_conn() 326 rowdata, cols = self.fetch("SELECT version();", conn) 327 conn.close() 328 v = rowdata[0][0] 329 self._version = storage.Version(v) 330 if self._version > storage.Version("4.1.1"): 331 self.decompiler = MySQLDecompiler411 332 333 def quote(self, name): 334 """Return name, quoted for use in an SQL statement.""" 335 return '`' + name.replace('`', '``') + '`' 336 337 def _get_conn(self): 338 try: 339 conn = _mysql.connect(**self.connargs) 340 except _mysql.OperationalError, x: 341 if x.args[0] == 1040: # Too many connections 342 raise db.OutOfConnectionsError 343 raise 344 return conn 345 346 def _template_conn(self): 347 tmplconn = self.connargs.copy() 348 tmplconn['db'] = 'mysql' 349 return _mysql.connect(**tmplconn) 350 351 def fetch(self, query, conn=None): 352 """fetch(query, conn=None) -> rowdata, columns. 353 354 rowdata: a nested list (or tuples), column values within rows. 355 columns: a series of 2-tuples (or more). The first tuple value 356 will be the column name, the second value will be the column 357 type. 358 """ 359 if conn is None: 360 conn = self.connection() 361 self.execute(query, conn) 362 # store_result uses a client-side cursor 363 res = conn.store_result() 364 return res.fetch_row(0, 0), res.describe() 365 366 def destroy(self, unit): 367 """destroy(unit). Delete the unit.""" 368 self.execute('DELETE FROM %s WHERE %s;' % 369 (self.quote(self.table_name(unit.__class__.__name__)), 370 self.id_clause(unit))) 371 372 def version(self): 373 return "MySQL Version: %s" % self._version 374 375 def _seq_UnitSequencerInteger(self, unit): 376 """Reserve a unit using the table's AUTO_INCREMENT field.""" 377 cls = unit.__class__ 378 clsname = cls.__name__ 379 tablename = self.table_name(clsname) 380 381 fields = [] 382 values = [] 383 for key in cls.properties: 384 typename = self.typeAdapter.coerce(cls, key) 385 if typename.endswith("AUTO_INCREMENT"): 386 # Skip this field, since we're using AUTO_INCREMENT 387 continue 388 val = self.toAdapter.coerce(getattr(unit, key)) 389 fields.append(self.quote(self.column_name(clsname, key))) 390 values.append(val) 391 392 fields = ", ".join(fields) 393 values = ", ".join(values) 394 395 conn = self.connection() 396 self.execute('INSERT INTO %s (%s) VALUES (%s);' % 397 (self.quote(tablename), fields, values), 398 conn) 399 400 # Grab the new ID. This is threadsafe because db.reserve has a mutex. 401 setattr(unit, cls.identifiers[0], conn.insert_id()) 402 403 # Schemas # 404 405 def create_database(self): 406 # _mysql has create_db and drop_db commands, but they're deprecated. 407 sql = 'CREATE DATABASE %s;' % self.quote(self.sql_name(self.dbname)) 408 conn = self._template_conn() 409 self.execute(sql, conn) 410 conn.close() 411 412 def drop_database(self): 413 sql = 'DROP DATABASE %s;' % self.quote(self.sql_name(self.dbname)) 414 conn = self._template_conn() 415 self.execute(sql, conn) 416 conn.close() 417 418 def create_storage(self, cls): 419 """Create storage for the given class.""" 420 colname = self.column_name 421 422 # Make a Table object. 423 tablename = self.table_name(cls.__name__) 424 t = self.tableclass(self, tablename) 425 426 indices = cls.indices() 427 fields = [] 428 for key in cls.properties: 429 dbtype = self.typeAdapter.coerce(cls, key) 430 prop = cls.property(key) 431 cname = colname(cls.__name__, key) 432 col = self.columnclass(cname, dbtype, prop.type, 433 prop.default, prop.hints.copy()) 434 # Use the superclass call to avoid ALTER TABLE. 435 dict.__setitem__(t.columns, key, col) 436 437 if key in indices: 438 iname = self.table_name("i" + cls.__name__ + key) 439 i = self.indexclass(iname, tablename, cname) 440 # Use the superclass call to avoid CREATE INDEX. 441 dict.__setitem__(t.columns.indices, key, i) 442 443 # Hack to get PRIMARY KEY right. See MySQLTableSet.__setitem__ 444 t.mysql_identifiers = cls.identifiers 445 446 # Hack to get AUTO_INCREMENT right where initial > 1. 447 # See MySQLTableSet.__setitem__ 448 if isinstance(cls.sequencer, dejavu.UnitSequencerInteger): 449 i = cls.sequencer.initial 450 if i > 1: 451 colname = self.column_name(cls.__name__, cls.identifiers[0]) 452 t.mysql_sequencer = (colname, i) 453 454 # Attach to self.tables, which should call CREATE TABLE. 455 self.tables[cls.__name__] = t 456 trunk/storage/storepypgsql.py
r225 r226 29 29 if isinstance(cls.sequencer, dejavu.UnitSequencerInteger): 30 30 if key in cls.identifiers: 31 return ("INTEGER DEFAULT nextval('%s_%s_seq') NOT NULL"32 % (cls.__name__, key))31 seqname = self.sm.quote("%s_%s_seq" % (cls.__name__, key)) 32 return "INTEGER DEFAULT nextval('%s') NOT NULL" % seqname 33 33 bytes = int(prop.hints.get('bytes', db.maxint_bytes)) 34 34 return self.int_type(bytes) … … 65 65 66 66 67 class PgIndexSet(db.IndexSet): 68 69 def __delitem__(self, key): 70 """Drop the specified index.""" 71 t = self.table 72 iname = t.sm.sql_name("i" + t.name + key) 73 t.sm.execute('DROP INDEX %s;' % t.sm.quote(iname)) 74 75 76 class PgTableSet(db.TableSet): 77 78 def _get_tables(self, conn=None): 79 data, _ = self.sm.fetch("SELECT tablename FROM pg_tables WHERE schemaname" 80 " not in ('information_schema', 'pg_catalog')", 81 conn=conn) 82 return [self.sm.tableclass(self.sm, row[0]) for row in data] 83 84 def _get_columns(self, tablename, conn=None): 85 data, _ = self.sm.fetch("SELECT oid FROM pg_class WHERE relname = '%s'" 86 % tablename, conn=conn) 87 table_OID = data[0][0] 88 sql = ("SELECT attname, atttypid, attnum, attlen, atttypmod " 89 "FROM pg_attribute WHERE attrelid = %s" % table_OID) 90 data, _ = self.sm.fetch(sql, conn=conn) 91 cols = [] 92 for row in data: 93 name = row[0] 94 if name in ('tableoid', 'cmax', 'xmax', 'cmin', 'xmin', 95 'oid', 'ctid'): 96 # This is a column which PostgreSQL defines automatically 97 continue 98 99 # Data type 100 dbtype, _ = self.sm.fetch("SELECT typname, typlen FROM pg_type " 101 "WHERE oid = %s" % row[1]) 102 if dbtype: 103 dbtype = dbtype[0][0] 104 else: 105 dbtype = None 106 c = self.sm.columnclass(row[0], dbtype) 107 108 # Python type 109 if dbtype: 110 if dbtype in ('int2', 'int4'): 111 c.type = int 112 elif dbtype == 'bool': 113 c.type = bool 114 elif dbtype == 'int8': 115 c.type = long 116 elif dbtype in ('float4', 'float8', 'money'): 117 c.type = float 118 c.hints['precision'] = row[4] 119 elif dbtype == 'numeric': 120 c.type = decimal.Decimal 121 c.hints['precision'] = row[4] 122 elif dbtype == 'date': 123 c.type = datetime.date 124 elif dbtype in ('timestamp', 'timestamptz'): 125 c.type = datetime.datetime 126 elif dbtype in ('time', 'timetz'): 127 c.type = datetime.time 128 elif dbtype in ('char', 'varchar', 'bpchar', 'text'): 129 c.type = str 130 131 # Default value 132 default, _ = self.sm.fetch("SELECT adsrc FROM pg_attrdef " 133 "WHERE adnum = %s AND adrelid = %s" 134 % (row[2], table_OID)) 135 if default: 136 c.default = default[0][0] 137 # Sequences 138 if c.default.startswith("nextval("): 139 c.default = None 140 else: 141 c.default = None 142 143 bytes = row[3] 144 if bytes > 0: 145 c.hints['bytes'] = bytes 146 147 cols.append(c) 148 return cols 149 150 def _get_indices(self, tablename, conn=None): 151 # Get the OID of the parent table. 152 data, _ = self.sm.fetch("SELECT oid FROM pg_class WHERE relname = '%s'" 153 % tablename, conn=conn) 154 if not data: 155 return [] 156 157 table_OID = data[0][0] 158 indices = [] 159 data, _ = self.sm.fetch("SELECT pg_class.relname, indkey, indisprimary, " 160 "indisunique FROM pg_index LEFT JOIN pg_class " 161 "ON pg_index.indexrelid = pg_class.oid WHERE " 162 "pg_index.indrelid = %s" % table_OID, conn=conn) 163 for row in data: 164 # indkey is an "array" (we get a space-separated string of ints). 165 cols = map(int, row[1].split(" ")) 166 for col in cols: 167 d, _ = self.sm.fetch("SELECT attname FROM pg_attribute " 168 "WHERE attrelid = %s AND attnum = %s" 169 % (table_OID, col), conn=conn) 170 i = self.sm.indexclass(row[0], tablename, d[0][0], 171 bool(row[2]), bool(row[3])) 172 indices.append(i) 173 174 return indices 175 176 177 67 178 class StorageManagerPgSQL(db.StorageManagerDB): 68 179 """StoreManager to save and retrieve Units via pyPgSQL 1.35.""" … … 72 183 toAdapter = AdapterToPgSQL() 73 184 typeAdapter = FieldTypeAdapterPgSQL() 185 186 tablesetclass = PgTableSet 187 indexsetclass = PgIndexSet 74 188 75 189 def __init__(self, name, arena, allOptions={}): … … 85 199 setattr(self, k, v) 86 200 db.StorageManagerDB.__init__(self, name, arena, allOptions) 87 88 def sql_name(self, name, quoted=True):89 name = db.StorageManagerDB.sql_name(self, name, quoted)201 self.typeAdapter.sm = self 202 203 def quote(self, name): 90 204 if self.quote_all: 91 if quoted: 92 name = '"' + name.replace('"', '""') + '"' 93 else: 205 name = '"' + name.replace('"', '""') + '"' 206 return name 207 208 def sql_name(self, name): 209 name = db.StorageManagerDB.sql_name(self, name) 210 if not self.quote_all: 94 211 name = name.lower() 95 212 return name … … 140 257 cls = unit.__class__ 141 258 clsname = cls.__name__ 142 tablename = self.table _name(clsname)259 tablename = self.tables[clsname].name 143 260 144 261 fields = [] … … 150 267 continue 151 268 val = self.toAdapter.coerce(getattr(unit, key)) 152 fields.append(self. column_name(clsname, key))269 fields.append(self.quote(self.column_name(clsname, key))) 153 270 values.append(val) 154 271 … … 156 273 values = ", ".join(values) 157 274 self.execute('INSERT INTO %s (%s) VALUES (%s);' % 158 (s tr(tablename), fields, values))275 (self.quote(tablename), fields, values)) 159 276 160 277 # Grab the new ID. This is threadsafe because db.reserve has a mutex. 161 data, col_defs = self.fetch("SELECT last_value FROM %s_%s_seq;"162 % (clsname, cls.identifiers[0]))278 seqname = self.quote("%s_%s_seq" % (clsname, cls.identifiers[0])) 279 data, col_defs = self.fetch("SELECT last_value FROM %s;" % seqname) 163 280 setattr(unit, cls.identifiers[0], data[0][0]) 164 281 … … 167 284 def create_database(self): 168 285 c = self._template_conn() 169 self.execute('CREATE DATABASE %s' % self.sql_name(self.dbname), c) 286 dbname = self.quote(self.sql_name(self.dbname)) 287 self.execute('CREATE DATABASE %s' % dbname, c) 170 288 c.finish() 171 289 172 290 def drop_database(self): 173 291 c = self._template_conn() 174 self.execute("DROP DATABASE %s;" % self.sql_name(self.dbname), c) 292 dbname = self.quote(self.sql_name(self.dbname)) 293 self.execute("DROP DATABASE %s;" % dbname, c) 175 294 c.finish() 176 177 def has_storage(self, cls):178 # For some odd reason, libpq errors if you try to filter by tablename.179 sql = "SELECT tablename FROM pg_tables"180 data, cols = self.fetch(sql)181 return [self.table_name(cls.__name__, quoted=False)] in data182 295 183 296 def create_storage(self, cls): 184 297 """Create storage for the given class.""" 185 clsname = cls.__name__ 186 tablename = self.table_name(clsname) 187 typename = self.typeAdapter.coerce 188 298 colname = self.column_name 299 300 # Make a Table object. 301 tablename = self.table_name(cls.__name__) 302 t = self.tableclass(self, tablename) 303 304 indices = cls.indices() 189 305 fields = [] 190 306 for key in cls.properties: 191 dbtype = typename(cls, key) 307 dbtype = self.typeAdapter.coerce(cls, key) 308 prop = cls.property(key) 309 cname = colname(cls.__name__, key) 310 311 # Here's where we differ from the superclass: 312 # we have to manually CREATE SEQUENCE, and we must use 313 # class attributes to do so. 192 314 if 'nextval' in dbtype: 193 self.execute("CREATE SEQUENCE %s_%s_seq START %s;" 194 % (clsname, key, cls.sequencer.initial)) 195 fields.append('%s %s' % (self.column_name(clsname, key), dbtype)) 196 self.execute('CREATE TABLE %s (%s);' % (tablename, ", ".join(fields))) 197 198 for index in cls.indices(): 199 i = self.table_name("i" + clsname + index) 200 self.execute('CREATE INDEX %s ON %s (%s);' % 201 (i, tablename, self.column_name(clsname, index))) 202 203 def drop_index(self, cls, name): 204 clsname = cls.__name__ 205 for i in self.get_indices(clsname): 206 if i.colname == name: 207 self.execute('DROP INDEX %s;' % self.sql_name(i.name)) 208 209 def get_tables(self, conn=None): 210 data, _ = self.fetch("SELECT tablename FROM pg_tables WHERE " 211 "schemaname not in ('information_schema', 'pg_catalog')", 212 conn=conn) 213 return [db.Table(row[0]) for row in data] 214 215 def get_columns(self, tablename=None, conn=None): 216 data, _ = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'" 217 % tablename, conn=conn) 218 table_OID = data[0][0] 219 sql = ("SELECT attname, atttypid, attnum, attlen " 220 "FROM pg_attribute WHERE attrelid = %s" % table_OID) 221 data, _ = self.fetch(sql, conn=conn) 222 cols = [] 223 for row in data: 224 name = row[0] 225 if name in ('tableoid', 'cmax', 'xmax', 'cmin', 'xmin', 226 'oid', 'ctid'): 227 # This is a column which PostgreSQL defines automatically 228 continue 229 230 # Data type 231 dbtype, _ = self.fetch("SELECT typname, typlen FROM pg_type " 232 "WHERE oid = %s" % row[1]) 233 if dbtype: 234 dbtype = dbtype[0][0] 235 if dbtype in ('int2', 'int4'): 236 dbtype = int 237 elif dbtype == 'bool': 238 dbtype = bool 239 elif dbtype == 'int8': 240 dbtype = long 241 elif dbtype in ('float4', 'float8', 'money'): 242 dbtype = float 243 elif dbtype == 'numeric': 244 dbtype = decimal.Decimal 245 elif dbtype == 'date': 246 dbtype = datetime.date 247 elif dbtype in ('timestamp', 'timestamptz'): 248 dbtype = datetime.datetime 249 elif dbtype in ('time', 'timetz'): 250 dbtype = datetime.time 251 elif dbtype in ('char', 'varchar', 'bpchar', 'text'): 252 dbtype = str 253 else: 254 dbtype = None 255 256 # Default value 257 default, _ = self.fetch("SELECT adsrc FROM pg_attrdef " 258 "WHERE adnum = %s AND adrelid = %s" 259 % (row[2], table_OID)) 260 if default: 261 default = default[0][0] 262 if default.startswith("nextval("): 263 default = None 264 else: 265 default = None 266 267 c = db.Column(row[0], dbtype, default) 268 269 bytes = row[3] 270 if bytes > 0: 271 c.hints['bytes'] = bytes 272 273 cols.append(c) 274 return cols 275 276 def get_indices(self, tablename, conn=None): 277 # Get the OID of the parent table. 278 data, _ = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'" 279 % tablename, conn=conn) 280 if not data: 281 return [] 282 283 table_OID = data[0][0] 284 indices = [] 285 data, _ = self.fetch("SELECT pg_class.relname, indkey, indisprimary, " 286 "indisunique FROM pg_index LEFT JOIN pg_class " 287 "ON pg_index.indexrelid = pg_class.oid WHERE " 288 "pg_index.indrelid = %s" % table_OID, conn=conn) 289 for row in data: 290 cols = map(int, row[1].split(" ")) 291 for col in cols: 292 d, _ = self.fetch("SELECT attname FROM pg_attribute " 293 "WHERE attrelid = %s AND attnum = %s" 294 % (table_OID, col), conn=conn) 295 indices.append(db.Index(row[0], tablename, d[0][0], 296 bool(row[2]), bool(row[3]))) 297 298 return indices 299 315 seqname = self.quote("%s_%s_seq" % (tablename, cname)) 316 self.execute("CREATE SEQUENCE %s START %s;" 317 % (seqname, cls.sequencer.initial)) 318 319 col = self.columnclass(cname, dbtype, prop.type, 320 prop.default, prop.hints.copy()) 321 # Use the superclass call to avoid ALTER TABLE. 322 dict.__setitem__(t.columns, key, col) 323 324 if key in indices: 325 iname = self.table_name("i" + cls.__name__ + key) 326 i = self.indexclass(iname, tablename, cname) 327 # Use the superclass call to avoid CREATE INDEX. 328 dict.__setitem__(t.columns.indices, key, i) 329 330 # Attach to self.tables, which should call CREATE TABLE. 331 self.tables[cls.__name__] = t 332 trunk/storage/storesqlite.py
r225 r226 143 143 class FieldTypeAdapterSQLite(db.FieldTypeAdapter): 144 144 145 numeric_max_precision = 14 146 numeric_max_bytes = 7 147 145 148 def coerce(self, cls, key): 146 149 """coerce(cls, key) -> SQL typename for valuetype.""" … … 155 158 156 159 160 class SQLiteTableSet(db.TableSet): 161 162 def _get_tables(self, conn=None): 163 data, _ = self.sm.fetch("SELECT name FROM sqlite_master WHERE type = 'table'") 164 return [self.sm.tableclass(self.sm, row[0]) for row in data] 165 166 def _get_columns(self, tablename, conn=None): 167 data, coldefs = self.sm.fetch("SELECT * FROM %s WHERE 1 == 0" 168 % self.sm.quote(tablename), conn=conn) 169 return [self.sm.columnclass(col[0], "", str, None) for col in coldefs] 170 171 def _get_indices(self, tablename, conn=None): 172 data, _ = self.sm.fetch("SELECT name, tbl_name, sql FROM sqlite_master " 173 "WHERE type = 'index'") 174 indices = [] 175 for row in data: 176 colname = row[2].split("(")[-1] 177 colname = colname[1:-2] 178 indices.append(self.sm.indexclass(row[0], row[1], colname)) 179 return indices 180 181 def _rename(self, oldtable, newname): 182 if _rename_table_support: 183 self.sm.execute("ALTER TABLE %s RENAME TO %s" % 184 (self.sm.quote(oldtable.name), 185 self.sm.quote(newname))) 186 else: 187 raise NotImplementedError 188 189 190 class SQLiteColumnSet(db.ColumnSet): 191 192 def __setitem__(self, key, column): 193 t = self.table 194 tableset = t.sm.tables 195 196 if _add_column_support: 197 # We don't care about the type since SQLite is typeless 198 t.sm.execute("ALTER TABLE %s ADD COLUMN %s;" % 199 (t.sm.quote(t.name), t.sm.quote(column.name))) 200 dict.__setitem__(self, key, column) 201 else: 202 # Create the temporary table with the new fields (no indices). 203 temptable = t.copy() 204 temptable.name = "temp_" + temptable.name 205 temptable.columns.indices.clear() 206 dict.__setitem__(temptable.columns, key, column) 207 tableset[temptable.name] = temptable 208 209 # Copy data from the old table to the temp table. 210 selfields = [] 211 for k, c in temptable.columns.iteritems(): 212 qname = t.sm.quote(c.name) 213 if k == key: 214 # This is a new column. Populate with NULL. 215 qname = "NULL AS %s" % qname 216 selfields.append(qname) 217 t.sm.execute("INSERT INTO %s SELECT %s FROM %s;" % 218 (t.sm.quote(temptable.name), ", ".join(selfields), 219 t.sm.quote(t.name))) 220 221 # Drop the old table and create the new, final table. 222 newtable = temptable.copy() 223 newtable.name = t.name 224 tableset[t.name] = newtable 225 226 # Copy data from the temp table to the final table. 227 t.sm.execute("INSERT INTO %s SELECT * FROM %s;" % 228 (t.sm.quote(newtable.name), 229 t.sm.quote(temptable.name))) 230 231 # Drop the intermediate table. 232 tableset[temptable.name] 233 234 def __delitem__(self, key): 235 if key in self.indices: 236 del self.indices[key] 237 t = self.table 238 239 # Create the temporary table with the new fields (no indices). 240 temptable = t.copy() 241 temptable.name = "temp_" + temptable.name 242 temptable.columns.indices.clear() 243 dict.__delitem__(temptable.columns, key) 244 t.sm.tables[temptable.name] = temptable 245 246 # Copy data from the old table to the temp table. 247 selfields = [] 248 for k, c in temptable.columns.iteritems(): 249 qname = t.sm.quote(c.name) 250 selfields.append(qname) 251 t.sm.execute("INSERT INTO %s SELECT %s FROM %s;" % 252 (t.sm.quote(temptable.name), ", ".join(selfields), 253 t.sm.quote(t.name))) 254 255 # Drop the old table and create the new, final table. 256 newtable = temptable.copy() 257 newtable.name = t.name 258 t.sm.tables[t.name] = newtable 259 260 # Copy data from the temp table to the final table. 261 t.sm.execute("INSERT INTO %s SELECT * FROM %s;" % 262 (t.sm.quote(t.name), t.sm.quote(temptable.name))) 263 264 # Drop the intermediate table. 265 del t.sm.tables[temptable.name] 266 267 def rename(self, oldkey, newkey): 268 """Rename a Column.""" 269 oldcol = self[oldkey] 270 oldname = oldcol.name 271 t = self.table 272 newname = t.sm.column_name(self.table.name, newkey) 273 274 if oldname != newname: 275 # Create the temporary table with the new fields (no indices). 276 dict.__delitem__(self, oldkey) 277 dict.__setitem__(self, newkey, oldcol) 278 oldcol.name = newname 279 280 temptable = t.copy() 281 temptable.name = "temp_" + temptable.name 282 temptable.columns.indices.clear() 283 t.sm.tables[temptable.name] = temptable 284 285 # Copy data from the old table to the temp table. 286 selfields = [] 287 for k, c in temptable.columns.iteritems(): 288 qname = t.sm.quote(c.name) 289 if k == newkey: 290 qname = "%s AS %s" % (t.sm.quote(oldname), qname) 291 selfields.append(qname) 292 t.sm.execute("INSERT INTO %s SELECT %s FROM %s;" % 293 (t.sm.quote(temptable.name), ", ".join(selfields), 294 t.sm.quote(t.name))) 295 296 # Drop the old table and create the new, final table. 297 newtable = temptable.copy() 298 newtable.name = t.name 299 t.sm.tables[t.name] = newtable 300 301 # Copy data from the temp table to the final table. 302 # For some odd reason, using "SELECT *" mixes up the fields. 303 selfields = [t.sm.quote(c.name) for c in temptable.columns.values()] 304 selfields = ", ".join(selfields) 305 t.sm.execute("INSERT INTO %s (%s) SELECT %s FROM %s;" % 306 (t.sm.quote(newtable.name), selfields, selfields, 307 t.sm.quote(temptable.name))) 308 309 # Drop the intermediate table. 310 del t.sm.tables[temptable.name] 311 312 157 313 class StorageManagerSQLite(db.StorageManagerDB): 158 314 """StoreManager to save and retrieve Units via _sqlite.""" 159 315 160 316 sql_name_max_length = 0 317 161 318 decompiler = SQLiteDecompiler 162 319 toAdapter = AdapterToSQLite() 163 320 fromAdapter = AdapterFromSQLite() 164 321 typeAdapter = FieldTypeAdapterSQLite() 322 323 tablesetclass = SQLiteTableSet 324 columnsetclass = SQLiteColumnSet 165 325 166 326 def __init__(self, name, arena, allOptions={}): … … 172 332 db.StorageManagerDB.__init__(self, name, arena, allOptions) 173 333 174 def sql_name(self, name, quoted=True):175 """ sql_name(name, quoted=True) -> return name as a legal SQL identifier.334 def quote(self, name): 335 """Return name, quoted for use in an SQL statement. 176 336 177 337 From the SQLite docs: … … 186 346 ...we'll use the third option (square brackets). 187 347 """ 188 if quoted: 189 name = "[" + name + "]" 190 return name 348 return "[" + name + "]" 191 349 192 350 def _get_conn(self): … … 212 370 time.sleep(0.000001) 213 371 continue 214 raise 372 ## except _sqlite.DatabaseError, x: 373 ## # See http://www.sqlite.org/faq.html#q17 374 ## if x.args[0] == 'database schema has changed': 375 ## time.sleep(0.000001) 376 ## continue 377 raise 215 378 except Exception, x: 216 379 x.args += (query,) … … 260 423 msg = ("No association found between %s and %s." % (name1, name2)) 261 424 raise dejavu.AssociationError(msg) 262 near = '%s.%s' % (nearClass, self.column_name(nearClass, ua.nearKey)) 263 far = '%s.%s' % (farClass, self.column_name(farClass, ua.farKey)) 425 426 near = '%s.%s' % (self.quote(nearClass), 427 self.quote(self.column_name(nearClass, ua.nearKey))) 428 far = '%s.%s' % (self.quote(farClass), 429 self.quote(self.column_name(farClass, ua.farKey))) 264 430 265 431 on_clauses.append("%s = %s" % (near, far)) … … 302 468 continue 303 469 val = self.toAdapter.coerce(getattr(unit, key)) 304 fields.append(self. column_name(clsname, key))470 fields.append(self.quote(self.column_name(clsname, key))) 305 471 values.append(val) 306 472 … … 311 477 conn = self.connection() 312 478 self.execute('INSERT INTO %s (%s) VALUES (%s);' % 313 (s tr(tablename), fields, values), conn)479 (self.quote(tablename), fields, values), conn) 314 480 315 481 # Grab the new ID. This is safe because db.reserve has a mutex. … … 338 504 # the value of sequencer.initial - 1. 339 505 prev = cls.sequencer.initial - 1 340 tablename = self.table_name(cls.__name__ , quoted=False)506 tablename = self.table_name(cls.__name__) 341 507 d, c = self.fetch("SELECT * FROM SQLITE_SEQUENCE " 342 508 "WHERE name = '%s'" % tablename) … … 347 513 self.execute("INSERT INTO SQLITE_SEQUENCE (seq, name) " 348 514 "VALUES (%s, '%s')" % (prev, tablename)) 349 350 def _legacy_alter_table(self, cls, altermap): 351 """ALTER an SQLite table via an intermediate, temporary table. 352 353 altermap must be a dict of the form {newname: oldname}. 354 If oldname is given, that old field will be mapped to the new field. 355 If oldname is None, a new field will be added with the newname. 356 If newname is not present for an oldname, that field will be dropped. 357 """ 358 clsname = cls.__name__ 359 tempname = self.table_name("temp_" + clsname) 360 tablename = self.table_name(clsname) 361 362 # Create a temporary table with the new fields (no indices). 363 newfields = [self.sql_name(key) for key in altermap] 364 self.execute("CREATE TABLE %s (%s);" 365 % (tempname, ", ".join(newfields))) 366 367 # Copy data from the old table to the temp table. 368 selfields = [] 369 for newname, oldname in altermap.iteritems(): 370 if oldname == newname: 371 newname = self.sql_name(newname) 372 else: 373 if oldname is None: 374 oldname = self.toAdapter.coerce(None) 375 else: 376 oldname = self.sql_name(oldname) 377 newname = ("%s AS %s" % (oldname, self.sql_name(newname))) 378 selfields.append(newname) 379 self.execute("INSERT INTO %s SELECT %s FROM %s;" % 380 (tempname, ", ".join(selfields), tablename)) 381 382 # Drop the old table. 383 self.execute("DROP TABLE %s;" % tablename) 384 385 # Create the new, final table. 386 typename = self.typeAdapter.coerce 387 spec = [] 388 for key in altermap: 389 spec.append('%s %s' % (self.column_name(clsname, key), 390 typename(cls, key))) 391 self.execute('CREATE TABLE %s (%s);' % (tablename, ", ".join(spec))) 392 393 # Create a new index if necessary. 394 for newname, oldname in altermap.iteritems(): 395 if oldname is None and newname in cls.indices(): 396 i = self.table_name("i" + clsname + newname) 397 c = self.column_name(clsname, newname) 398 self.execute('CREATE INDEX %s ON %s (%s);' % 399 (i, tablename, c)) 400 401 # Copy data from the temp table to the final table. 402 self.execute("INSERT INTO %s SELECT * FROM %s;" % 403 (tablename, tempname)) 404 405 # Drop the intermediate table. 406 self.execute("DROP TABLE %s;" % tempname) 407 408 def _existing_fields(self, tablename): 409 """Pull field names from existing table.""" 410 data, coldefs = self.fetch("SELECT * FROM %s" % 411 self.table_name(tablename)) 412 return zip(*coldefs)[0] 413 414 def add_property(self, cls, name): 415 clsname = cls.__name__ 416 if _add_column_support: 417 self.execute("ALTER TABLE %s ADD COLUMN %s;" % 418 (self.table_name(clsname), 419 self.column_name(clsname, name))) 420 else: 421 altermap = dict([(x, x) for x in self._existing_fields(clsname)]) 422 altermap[name] = None 423 self._legacy_alter_table(cls, altermap) 424 425 def drop_property(self, cls, name): 426 altermap = dict([(x, x) for x in self._existing_fields(cls.__name__)]) 427 del altermap[name] 428 self._legacy_alter_table(cls, altermap) 429 430 def rename_property(self, cls, oldname, newname): 431 altermap = dict([(x, x) for x in self._existing_fields(cls.__name__)]) 432 del altermap[oldname] 433 altermap[newname] = oldname 434 self._legacy_alter_table(cls, altermap) 435 436 def drop_index(self, cls, name): 437 clsname = cls.__name__ 438 self.execute('DROP INDEX %s ON %s;' % 439 (self.sql_name("i" + clsname + name), 440 self.table_name(clsname))) 441 442 def get_tables(self, conn=None): 443 data, _ = self.fetch("SELECT name FROM sqlite_master WHERE type = 'table'") 444 return [db.Table(row[0]) for row in data] 445 446 def get_columns(self, tablename=None, conn=None): 447 data, coldefs = self.fetch("SELECT * FROM %s WHERE 1 == 0" 448 % self.sql_name(tablename), conn=conn) 449 cols = [] 450 for col in coldefs: 451 c = db.Column(col[0], str, None) 452 cols.append(c) 453 return cols 454 455 def get_indices(self, tablename, conn=None): 456 data, _ = self.fetch("SELECT name, tbl_name, sql FROM sqlite_master " 457 "WHERE type = 'index'") 458 indices = [] 459 for row in data: 460 colname = row[2].split("(")[-1] 461 colname = colname[1:-2] 462 indices.append(db.Index(row[0], row[1], colname)) 463 return indices 515 trunk/test/test_storemsaccess.py
r225 r226 51 51 schema=True) 52 52 for row in data: 53 match = targets.get(row[2]) 54 if not match: 55 continue 56 if match == row[3]: 57 ## print row[2], row[3], row[11] 53 if targets.get(row[2]) == row[3]: 58 54 dt = row[11] 59 55 60 56 if fta in ("CurrencyAdapter",): 61 obj.assertEqual(dt, storeado.adCurrency)57 obj.assertEqual(dt, 6) # adCurrency 62 58 else: 63 obj.assertEqual(dt, storeado.adDouble)59 obj.assertEqual(dt, 131) # adNumeric 64 60 obj.assertEqual(len(standard_runs), 0) 65 61 … … 68 64 69 65 # test the standard MS Access setup where Decimal and FixedPoint 70 # objects are stored in the database as INTEGERS, LONGS or DOUBLES66 # objects are stored in the database as INTEGERS, LONGS or NUMERIC 71 67 print 72 68 print "Standard MSAccess test." trunk/test/zoo_fixture.py
r225 r226 635 635 636 636 def test_Multithreading(self): 637 return 637 638 # Test threads overlapping on separate sandboxes 638 639 f = logic.Expression(lambda x: x.Legs == 4) … … 786 787 def test_DB_Introspection(self): 787 788 s = arena.stores.values()[0] 788 if getattr(s, "get_tables", None) is None:789 if not hasattr(s, "tables"): 789 790 return 790 791 791 tables = s.get_tables() 792 for t in tables: 793 ## print t 794 ## for c in s.get_columns(t.name): 795 ## print " ", c 796 ## for i in s.get_indices(t.name): 797 ## print " ", i 798 if t.name.lower() == "djvzoo": 799 zootable = t 800 self.assertEqual(zootable.name.lower(), "djvzoo") 801 cols = s.get_columns(zootable.name) 792 zootable = s.tables['Zoo'] 793 cols = zootable.columns 802 794 self.assertEqual(len(cols), 6) 803 804 cols = dict([(x.key.lower(), x) for x in cols]) 805 idcol = cols['id'] 806 # Since SQLite is typless, it will set all types to 'str' 795 idcol = cols['ID'] 796 # Since SQLite is typeless, we must handle when it uses 'str' 807 797 self.assert_(idcol.type in (int, str)) 808 798 self.assertEqual(idcol.default, None) 799 800 # Test the automatic construction of a Unit class. 801 uc = s.autoclass(zootable, "Zoo") 802 self.assert_(not issubclass(uc, Zoo)) 803 self.assertEqual(uc.__name__, "Zoo") 804 for pname in uc.properties: 805 p = getattr(uc, pname) 806 z = getattr(Zoo, pname) 807 self.assertEqual(p.key, z.key) 808 self.assertEqual(p.type, z.type) 809 self.assertEqual(p.default, z.default) 810 self.assertEqual(p.hints, z.hints) 809 811 810 812 def testzzzz_Schema_Upgrade(self): … … 879 881 (actual, decimal.Decimal(val), p, s)) 880 882 883 881 884 arena = dejavu.Arena() 882 885 … … 942 945 arena.register_all(globals()) 943 946 engines.register_classes(arena) 947 948 if hasattr(arena.stores['testSM'], "tables"): 949 arena.stores['testSM'].sync() 944 950 945 951 zs = ZooSchema(arena)
