Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/storepypgsql.py

Revision 273 (checked in by fumanchu, 7 years ago)

Oops.

  • Property svn:eol-style set to native
Line 
1 # Use libpq directly to avoid all of the DB-API overhead.
2 from pyPgSQL import libpq
3
4 import datetime
5 import re
# Extracts the sequence name from a serial column's default expression,
# e.g. nextval('"t_id_seq"'::regclass) or the older double-parenthesized
# form nextval(('"t_id_seq"'::text)::regclass). Group 1 is the name as it
# appears inside the quotes (including any double-quote identifier quoting).
seq_name = re.compile(r"nextval\(\(?'([^:]+)'.*\)")
7
8 import dejavu
9 from dejavu.storage import db
10
11
class AdapterToPgSQL(db.AdapterToSQL):
    """Render Python values as PostgreSQL SQL literal text."""

    # (LIKE wildcard, escaped replacement) pairs; presumably consumed by the
    # base class's escape_like(). Each raw string holds two backslashes.
    like_escapes = [("%", r"\\%"), ("_", r"\\_")]

    # NOTE(review): open question from the original author -- do these
    # literals need to know if "SHOW DateStyle;" != "ISO, MDY"? They always
    # use the ISO layout YYYY-MM-DD.

    def coerce_datetime_datetime_to_any(self, value):
        """Return a datetime.datetime as a quoted literal of the form
        'YYYY-MM-DD HH:MM:SS.ffffff'. Any tzinfo is not rendered."""
        parts = (value.year, value.month, value.day,
                 value.hour, value.minute, value.second, value.microsecond)
        return "'%s'" % ("%04d-%02d-%02d %02d:%02d:%02d.%06d" % parts)

    def coerce_datetime_date_to_any(self, value):
        """Return a datetime.date as a quoted 'YYYY-MM-DD' literal."""
        parts = (value.year, value.month, value.day)
        return "'%s'" % ("%04d-%02d-%02d" % parts)

    def coerce_datetime_time_to_any(self, value):
        """Return a datetime.time as a quoted 'HH:MM:SS.ffffff' literal."""
        parts = (value.hour, value.minute, value.second, value.microsecond)
        return "'%s'" % ("%02d:%02d:%02d.%06d" % parts)
29
30
class PgSQLDecompiler(db.SQLDecompiler):
    """Decompiler that emits PostgreSQL syntax for dejavu's case-insensitive
    comparisons (via ILIKE / LOWER) and date-part functions."""

    def dejavu_icontainedby(self, op1, op2):
        """Case-insensitive containment test.

        If op1 is a constant, search for it as a substring of op2 with
        ILIKE (note the reversed operand order). Otherwise test whether
        LOWER(op1) is one of op2's values, each lowercased.
        """
        if isinstance(op1, db.ConstWrapper):
            # Looking for text in a field. Use ILIKE (reverse terms).
            return op2 + " ILIKE '%" + self.adapter.escape_like(op1) + "%'"

        # Looking for field in (a, b, c). Lowercase the coerced literals so
        # the comparison against LOWER(op1) is case-insensitive.
        atoms = [self.adapter.coerce(x).lower() for x in op2.basevalue]
        if not atoms:
            # PostgreSQL rejects "IN ()" as a syntax error; membership in an
            # empty sequence is simply false.
            return "FALSE"
        return "LOWER(%s) IN (%s)" % (op1, ", ".join(atoms))

    def dejavu_istartswith(self, x, y):
        """Case-insensitive prefix match of x against the constant y."""
        return x + " ILIKE '" + self.adapter.escape_like(y) + "%'"

    def dejavu_iendswith(self, x, y):
        """Case-insensitive suffix match of x against the constant y."""
        return x + " ILIKE '%" + self.adapter.escape_like(y) + "'"

    def dejavu_ieq(self, x, y):
        """Case-insensitive equality of x and the constant y."""
        # ILIKE with no wildcards (y's own wildcards are escaped) should
        # behave like a case-insensitive equality test.
        return x + " ILIKE '" + self.adapter.escape_like(y) + "'"

    def dejavu_year(self, x):
        """Extract the year of the date/time expression x."""
        return "date_part('year', " + x + ")"

    def dejavu_month(self, x):
        """Extract the month of the date/time expression x."""
        return "date_part('month', " + x + ")"
58
59
60
class PgIndexSet(db.IndexSet):
    """Index collection that issues DROP INDEX when an entry is removed."""

    def __delitem__(self, key):
        """Drop the specified index from the database and forget it.

        Raises KeyError (before any SQL runs) if *key* is not present.
        """
        self.table.db.execute('DROP INDEX %s;' % self[key].qname)
        # Remove the entry too, so the mapping keeps matching the database
        # after the DROP.
        dict.__delitem__(self, key)
66
67
class PgColumnSet(db.ColumnSet):
    """Column collection whose additions are mirrored as ALTER TABLE DDL."""

    def __init__(self, table):
        dict.__init__(self)
        self.table = table
        self.indices = PgIndexSet(self.table)

    def __setitem__(self, key, column):
        """Add *column* to the table under *key*, replacing any existing one.

        Autoincrement columns get a dedicated sequence named
        "<table>_<column>_seq", which is started at column.default when that
        is not None. Their DEFAULT is nextval() of that sequence. Other
        columns get a DEFAULT clause unless column.default is None or "".
        A str default is inserted verbatim as SQL. Any other value is
        rendered through adaptertosql.coerce().
        """
        t = self.table
        if key in self:
            del self[key]

        if column.autoincrement:
            seqname = t.db.quote("%s_%s_seq" % (t.name, column.name))
            column.sequence_name = seqname
            if column.default is None:
                # No start value given: omit START rather than emitting
                # the invalid "START None".
                t.db.execute("CREATE SEQUENCE %s;" % seqname)
            else:
                t.db.execute("CREATE SEQUENCE %s START %s;"
                             % (seqname, column.default))
            default = "nextval('%s')" % seqname
        else:
            # Test against None, not truthiness: 0, 0.0 and False are real
            # defaults that must still produce a DEFAULT clause.
            default = column.default
            if default is None:
                default = ""
            elif not isinstance(default, str):
                default = t.db.adaptertosql.coerce(default, column.dbtype)

        if default:
            default = " DEFAULT %s" % default

        t.db.execute("ALTER TABLE %s ADD COLUMN %s %s%s;" %
                     (t.qname, column.qname, column.dbtype, default))
        dict.__setitem__(self, key, column)
97
98
class PgDatabase(db.Database):
    """A PostgreSQL database driven through pyPgSQL's low-level libpq module.

    Tables are kept as dict entries (see __setitem__/__delitem__, which
    mirror each change with DDL). This class also provides catalog
    introspection, libpq connection handling and CREATE/DROP DATABASE.
    """

    sql_name_max_length = 63
    quote_all = True
    poolsize = 10

    decompiler = PgSQLDecompiler
    adaptertosql = AdapterToPgSQL()

    # ------------------------------------------------------ introspection

    def _get_tables(self, conn=None):
        """Return a db.Table for every table outside the system schemas."""
        data, _ = self.fetch("SELECT tablename FROM pg_tables WHERE schemaname"
                             " not in ('information_schema', 'pg_catalog')",
                             conn=conn)
        return [db.Table(self, row[0], self.quote(row[0])) for row in data]

    def _pg_table_oid(self, tablename, conn=None):
        """Return the pg_class OID of *tablename*, or None if it is absent."""
        # Double any single quote so the name cannot terminate the literal.
        data, _ = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'"
                             % tablename.replace("'", "''"), conn=conn)
        if not data:
            return None
        return data[0][0]

    def _get_columns(self, tablename, conn=None):
        """Return db.Column objects for the user columns of *tablename*.

        Columns are returned in attnum (definition) order. An empty list is
        returned if no relation of that name exists.
        """
        table_OID = self._pg_table_oid(tablename, conn)
        if table_OID is None:
            return []

        # Get primary-key data so we can set col.key. The values equal
        # pg_attribute.attnum of the key columns.
        data, _ = self.fetch("SELECT indkey FROM pg_index WHERE indrelid "
                             "= %s AND indisprimary" % table_OID, conn=conn)
        if data:
            # indkey is an "array" (we get a space-separated string of ints).
            pk_attnums = [int(num) for num in data[0][0].split()]
        else:
            pk_attnums = []

        # attnum > 0 excludes system columns (oid, ctid, xmin, ...).
        # NOT attisdropped excludes columns removed with ALTER TABLE ...
        # DROP COLUMN, which stay in pg_attribute with atttypid zeroed.
        data, _ = self.fetch("SELECT attname, atttypid, attnum, attlen, "
                             "atttypmod FROM pg_attribute WHERE attrelid = %s "
                             "AND attnum > 0 AND NOT attisdropped "
                             "ORDER BY attnum" % table_OID, conn=conn)
        return [self._pg_column(row, table_OID, pk_attnums, conn)
                for row in data]

    def _pg_column(self, row, table_OID, pk_attnums, conn=None):
        """Build one db.Column from an (attname, atttypid, attnum, attlen,
        atttypmod) row of pg_attribute."""
        name, type_oid, attnum, attlen, typmod = row

        # Data type
        data, _ = self.fetch("SELECT typname FROM pg_type WHERE oid = %s"
                             % type_oid, conn=conn)
        if data:
            dbtype = data[0][0].upper()
        else:
            dbtype = None
        c = db.Column(name, self.quote(name), dbtype,
                      key=attnum in pk_attnums)

        if dbtype in ('FLOAT4', 'FLOAT8'):
            c.hints['precision'] = attlen
        elif dbtype in ('MONEY', 'NUMERIC') and typmod >= 4:
            # typmod -1 means no declared precision/scale. Decoding it
            # would give nonsense (65535, 65531), so skip the hints.
            c.hints['precision'] = (typmod >> 16) & 65535
            c.hints['scale'] = (typmod & 65535) - 4

        # Default value. pg_get_expr() replaces pg_attrdef.adsrc, which is
        # deprecated and no longer exists from PostgreSQL 12 on.
        data, _ = self.fetch("SELECT pg_get_expr(adbin, adrelid) "
                             "FROM pg_attrdef WHERE adnum = %s AND adrelid = %s"
                             % (attnum, table_OID), conn=conn)
        if data:
            self._pg_read_default(c, data[0][0], dbtype, conn)
        else:
            c.default = None

        if dbtype and dbtype.startswith(('BPCHAR', 'VARCHAR')):
            if typmod >= 4:
                # atttypmod is the declared length + 4. See
                # http://archives.postgresql.org/pgsql-interfaces/2004-07/msg00021.php
                c.hints['bytes'] = typmod - 4
            else:
                # -1: no declared length (e.g. bare VARCHAR). Report 0, as
                # for TEXT below, instead of a negative size.
                c.hints['bytes'] = 0
        elif attlen > 0:
            c.hints['bytes'] = attlen
        elif dbtype == 'TEXT':
            c.hints['bytes'] = 0
        return c

    def _pg_read_default(self, c, expr, dbtype, conn=None):
        """Set c.default (and sequence info) from a default expression."""
        match = None
        if expr.startswith("nextval("):
            # Grab seqname from "nextval(seqname::[text|regclass])"
            match = seq_name.search(expr)
        if match:
            c.autoincrement = True
            c.sequence_name = match.group(1)
            # fetch() returns (rows, columns); the value is rows[0][0].
            data, _ = self.fetch("SELECT min_value FROM %s" % c.sequence_name,
                                 conn=conn)
            c.default = data[0][0]
            return

        # The expression comes back as a string, so we must cast it using
        # the type guessed from dbtype.
        pytype = self.python_type(dbtype)
        if pytype is bool:
            # bool("false") is True; parse the literal text instead.
            literal = expr.split("::", 1)[0].strip("'").lower()
            c.default = literal in ('true', 't', 'yes', 'y', 'on', '1')
        else:
            c.default = pytype(expr)

    def _get_indices(self, tablename, conn=None):
        """Return one db.Index per (index, column) pair on *tablename*.

        A multi-column index therefore yields several entries sharing a
        name. Expression columns, which have no underlying attribute, are
        skipped.
        """
        table_OID = self._pg_table_oid(tablename, conn)
        if table_OID is None:
            return []

        indices = []
        data, _ = self.fetch("SELECT pg_class.relname, indkey, indisprimary, "
                             "indisunique FROM pg_index LEFT JOIN pg_class "
                             "ON pg_index.indexrelid = pg_class.oid WHERE "
                             "pg_index.indrelid = %s" % table_OID, conn=conn)
        for row in data:
            relname, indkey, isunique = row[0], row[1], row[3]
            # indkey is an "array" (we get a space-separated string of ints).
            for attnum in [int(num) for num in indkey.split()]:
                if attnum == 0:
                    # 0 marks an expression column: no pg_attribute row.
                    continue
                d, _ = self.fetch("SELECT attname FROM pg_attribute "
                                  "WHERE attrelid = %s AND attnum = %s"
                                  % (table_OID, attnum), conn=conn)
                if not d:
                    continue
                indices.append(db.Index(relname, self.quote(relname),
                                        tablename, d[0][0], bool(isunique)))
        return indices

    def python_type(self, dbtype):
        """Return a Python type which can store values of the given dbtype.

        Raises TypeError if dbtype is empty/None or not recognized. This
        includes NUMERIC when neither db.decimal nor db.fixedpoint is
        available.
        """
        msg = ("Database type %r could not be converted "
               "to a Python type." % dbtype)
        if not dbtype:
            raise TypeError(msg)
        if dbtype in ('INT', 'INT2', 'INT4', 'INTEGER', 'SMALLINT'):
            return int
        if dbtype in ('BOOL', 'BOOLEAN'):
            return bool
        if dbtype in ('INT8', 'BIGINT'):
            return long
        if dbtype in ('FLOAT4', 'FLOAT8', 'MONEY', 'DOUBLE PRECISION', 'REAL'):
            return float
        if dbtype.startswith(('NUMERIC', 'DECIMAL')):
            if db.decimal:
                return db.decimal.Decimal
            if db.fixedpoint:
                return db.fixedpoint.FixedPoint
            raise TypeError(msg)
        if dbtype == 'DATE':
            return datetime.date
        # TIMESTAMP must be tested before TIME, which is its prefix. The
        # prefix tests also cover TIMESTAMPTZ, TIMETZ and "... WITH TIME ZONE".
        if dbtype.startswith('TIMESTAMP'):
            return datetime.datetime
        if dbtype.startswith('TIME'):
            return datetime.time
        if dbtype.startswith(('CHAR', 'VARCHAR', 'BPCHAR', 'TEXT')):
            return str
        raise TypeError(msg)

    # ---------------------------------------------------------------- DDL

    def _pg_default_clause(self, table, col):
        """Return the " DEFAULT ..." SQL fragment for *col* (maybe "").

        For autoincrement columns this first creates the backing sequence
        and records its name on col.sequence_name.
        """
        if col.autoincrement:
            seqname = table.db.quote("%s_%s_seq" % (table.name, col.name))
            col.sequence_name = seqname
            if col.default is None:
                # Omit START rather than emitting the invalid "START None".
                table.db.execute("CREATE SEQUENCE %s;" % seqname)
            else:
                table.db.execute("CREATE SEQUENCE %s START %s;"
                                 % (seqname, col.default))
            return " DEFAULT nextval('%s')" % seqname

        # Test against None, not truthiness: 0, 0.0 and False are real
        # defaults that must still produce a DEFAULT clause.
        default = col.default
        if default is None:
            return ""
        if not isinstance(default, str):
            default = table.db.adaptertosql.coerce(default, col.dbtype)
        if not default:
            return ""
        return " DEFAULT %s" % default

    def __setitem__(self, key, table):
        """CREATE *table* (and its indices), replacing any table at *key*."""
        if key in self:
            del self[key]

        fields = []
        pk = []
        for col in table.columns.itervalues():
            fields.append('%s %s%s' % (col.qname, col.dbtype,
                                       self._pg_default_clause(table, col)))
            if col.key:
                pk.append(col.qname)

        if pk:
            pk = ", PRIMARY KEY (%s)" % ", ".join(pk)
        else:
            pk = ""

        self.execute('CREATE TABLE %s (%s%s);' %
                     (table.qname, ", ".join(fields), pk))

        for index in table.columns.indices.itervalues():
            self.execute('CREATE INDEX %s ON %s (%s);' %
                         (index.qname, table.qname,
                          self.quote(index.colname)))

        dict.__setitem__(self, key, table)

    def __delitem__(self, key):
        """DROP the table at *key*, then the sequences of its autoincrement
        columns, and forget it."""
        table = self[key]
        self.execute('DROP TABLE %s;' % table.qname)
        for col in table.columns.itervalues():
            if col.autoincrement:
                self.execute("DROP SEQUENCE %s;" % col.sequence_name)
        dict.__delitem__(self, key)

    # -------------------------------------------------------------- names

    def quote(self, name):
        """Double-quote *name* (doubling embedded quotes) if quote_all."""
        if self.quote_all:
            name = '"' + name.replace('"', '""') + '"'
        return name

    def sql_name(self, name):
        """Return the base SQL name, lowercased unless names are quoted."""
        name = db.Database.sql_name(self, name)
        if not self.quote_all:
            name = name.lower()
        return name

    # -------------------------------------------------------- connections

    def _get_conn(self):
        """Open a libpq connection from the self.Connect conninfo string.

        A 'could not connect' failure is reported as
        db.OutOfConnectionsError; other DatabaseErrors propagate.
        """
        try:
            return libpq.PQconnectdb(self.Connect)
        except libpq.DatabaseError as x:
            if x.args and str(x.args[0]).startswith('could not connect'):
                raise db.OutOfConnectionsError()
            raise

    def _del_conn(self, conn):
        conn.finish()

    def _template_conn(self):
        """Connect with self.Connect's settings but to database template1."""
        atoms = []
        for atom in self.Connect.split():
            k, v = atom.split("=", 1)
            if k != 'dbname':
                atoms.append("%s=%s" % (k, v))
        # Always name template1, even if self.Connect had no dbname.
        atoms.append("dbname=template1")
        return libpq.PQconnectdb(" ".join(atoms))

    def fetch(self, query, conn=None):
        """fetch(query, conn=None) -> rowdata, columns.

        rowdata is a list of rows (each a list of values); columns is a
        list of (name, type) pairs. The libpq result is always cleared.
        """
        res = self.execute(query, conn)
        try:
            columns = []
            if res.resultType != libpq.EMPTY_QUERY:
                columns = [(res.fname(i), res.ftype(i))
                           for i in xrange(res.nfields)]
            nfields = res.nfields
            data = [[res.getvalue(r, f) for f in xrange(nfields)]
                    for r in xrange(res.ntuples)]
        finally:
            res.clear()
        return data, columns

    # ----------------------------------------------------------- database

    def create_database(self):
        """CREATE this database via a template1 connection."""
        c = self._template_conn()
        try:
            self.execute('CREATE DATABASE %s' % self.qname, c)
        finally:
            c.finish()
        self.clear()

    def drop_database(self):
        """DROP this database via a template1 connection."""
        # Must shut down all connections to avoid
        # "being accessed by other users" error.
        self.connection.shutdown()

        c = self._template_conn()
        try:
            self.execute("DROP DATABASE %s;" % self.qname, c)
        finally:
            c.finish()
        self.clear()

    def version(self):
        """Return the libpq connection's version attribute."""
        c = self._template_conn()
        try:
            return c.version
        finally:
            c.finish()
361
362
363
class StorageManagerPgSQL(db.StorageManagerDB):
    """StoreManager to save and retrieve Units via pyPgSQL 1.35."""

    databaseclass = PgDatabase

    def __init__(self, arena, allOptions=None):
        """Set the 'name' option from the dbname in the 'Connect' option.

        allOptions must hold a libpq conninfo string under 'Connect'
        (KeyError otherwise). It is copied before 'name' is added, so
        neither the caller's dict nor a shared default is mutated.
        """
        options = dict(allOptions or {})
        for atom in options['Connect'].split():
            k, v = atom.split("=", 1)
            if k == "dbname":
                options['name'] = v
        db.StorageManagerDB.__init__(self, arena, options)

    def _seq_UnitSequencerInteger(self, unit):
        """Reserve an ID for *unit* from its identifier's sequence, INSERT
        the unit's row, and set the ID on the unit."""
        cls = unit.__class__
        t = self.db[cls.__name__]
        idcol = cls.identifiers[0]

        # Draw the ID with nextval() before inserting. A sequence's
        # last_value reflects the latest value handed to *any* session, so
        # reading it after the INSERT could return another connection's or
        # process's ID. nextval() never returns the same value twice.
        seqname = t.columns[idcol].sequence_name
        data, _ = self.db.fetch("SELECT nextval('%s');" % seqname)
        newid = data[0][0]

        coerce = self.db.adaptertosql.coerce
        fields = []
        values = []
        for key in cls.properties:
            col = t.columns[key]
            if key == idcol:
                # int() so the literal is rendered from a plain integer,
                # whatever numeric type the driver returned.
                value = int(newid)
            elif col.autoincrement:
                # Other sequence-backed columns fill themselves via DEFAULT.
                continue
            else:
                value = getattr(unit, key)
            fields.append(col.qname)
            values.append(coerce(value, col.dbtype))

        self.db.execute('INSERT INTO %s (%s) VALUES (%s);' %
                        (t.qname, ", ".join(fields), ", ".join(values)))
        setattr(unit, idcol, newid)
402
Note: See TracBrowser for help on using the browser.