Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

root/trunk/geniusql/providers/psycopg.py

Revision 1 (checked in by fumanchu, 6 years ago)

Initial import.

  • Property svn:eol-style set to native
Line 
1 # Use _psycopg directly to avoid overhead.
2 try:
3     # If possible, you should copy the _psycopg.pyd file into a top level
4     # so this SM can avoid importing the entire package.
5     import _psycopg
6 except ImportError:
7     from psycopg2 import _psycopg
8
9
10 import datetime
11 try:
12     import cPickle as pickle
13 except ImportError:
14     import pickle
15 import re
# Extracts the sequence name from a column default such as
# "nextval('foo_id_seq'::regclass)".
seq_name = re.compile(r"nextval\('([^:]+)'.*\)")

# Bytes that coerce_str_to_any writes as octal escapes: control characters
# (0-31) and every byte from DEL (127) through 255.
escape_oct = re.compile(r"[\000-\037\177-\377]")


def replace_oct(m):
    """Return the matched character as a doubled backslash plus 3 octal digits."""
    return r"\\%03o" % ord(m.group(0))


# One backslash followed by a 3-digit octal escape for a byte value (000-377).
# Restricting the digits to octal byte values keeps text such as "\089" or
# "\777" from reaching int(..., 8) / chr(), which would raise ValueError.
unescape_oct = re.compile(r"\\([0-3][0-7][0-7])")


def replace_unoct(m):
    """Return the character whose octal code was captured by unescape_oct."""
    return chr(int(m.group(1), 8))
21
22 import dejavu
23 from dejavu import errors
24 from dejavu.storage import db
25
26
class AdapterToPsycoPg(db.AdapterToSQL):
    """Coerce Python values into PostgreSQL SQL literal text.

    NOTE(review): the doubled-backslash escapes produced here appear to
    assume the server processes backslash escapes inside ordinary '...'
    literals (standard_conforming_strings off) -- confirm for newer servers.
    """

    # LIKE wildcards are escaped with a doubled backslash so that, after
    # string-literal unescaping, LIKE sees "\%" and "\_".
    like_escapes = [("%", r"\\%"), ("_", r"\\_")]

    # Bytes that must be octal-escaped inside a bytea literal: control
    # characters (0-31), single quote (39), backslash (92) and 127-255.
    _bytea_special = re.compile(r"[\000-\037'\\\177-\377]")

    # Do these need to know if "SHOW DateStyle;" != "ISO, MDY" ?
    def coerce_datetime_datetime_to_any(self, value):
        """Return value as a quoted 'YYYY-MM-DD HH:MM:SS.ffffff' literal."""
        return ("'%04d-%02d-%02d %02d:%02d:%02d.%06d'" %
                (value.year, value.month, value.day,
                 value.hour, value.minute, value.second,
                 value.microsecond))

    def coerce_datetime_date_to_any(self, value):
        """Return value as a quoted 'YYYY-MM-DD' literal."""
        return "'%04d-%02d-%02d'" % (value.year, value.month, value.day)

    def coerce_datetime_time_to_any(self, value):
        """Return value as a quoted 'HH:MM:SS.ffffff' literal."""
        return ("'%02d:%02d:%02d.%06d'" %
                (value.hour, value.minute, value.second, value.microsecond))

    def coerce_any_to_bytea(self, value):
        """Pickle value (protocol 2) and return it as an escaped bytea literal.

        See http://www.postgresql.org/docs/8.1/interactive/datatype-binary.html
        """
        data = pickle.dumps(value, 2)
        # Each special byte becomes a doubled backslash followed by its three
        # octal digits (formatted directly with %03o by replace_oct).
        return "'%s'::bytea" % self._bytea_special.sub(replace_oct, data)

    def do_pickle(self, value):
        """Pickle value (protocol 2) and return it as an escaped string literal."""
        return self.coerce_str_to_any(pickle.dumps(value, 2))
    coerce_dict_to_any = do_pickle
    coerce_list_to_any = do_pickle
    coerce_tuple_to_any = do_pickle

    def coerce_str_to_any(self, value, skip_encoding=False):
        """Return value as a single-quoted, escaped SQL string literal.

        Non-str values are encoded with self.encoding unless skip_encoding
        is true. self.escapes is applied first, then control and high-bit
        bytes are written as octal escapes. The escapes must come first so
        that the backslashes added by the octal step are not escaped again.
        """
        if not skip_encoding and not isinstance(value, str):
            value = value.encode(self.encoding)
        for pat, repl in self.escapes:
            value = value.replace(pat, repl)

        # Escape octal sequences
        value = escape_oct.sub(replace_oct, value)
        return "'" + value + "'"
72
73
class AdapterFromPsycoPg(db.AdapterFromDB):
    """Convert values fetched through psycopg2 back into Python values."""

    def coerce_any_to_str(self, value):
        """Return value as a byte string with octal escape sequences decoded.

        Unicode values are encoded with self.encoding *before* the escapes
        are decoded. The octal escapes stand for raw bytes (the writing side
        escapes after encoding). In Python 2, substituting non-ASCII byte
        characters into a unicode string raises UnicodeDecodeError.
        """
        if isinstance(value, unicode):
            value = value.encode(self.encoding)
        else:
            value = str(value)
        # Unescape octal sequences
        return unescape_oct.sub(replace_unoct, value)

    def coerce_any_to_datetime_datetime(self, value):
        """Return value unchanged; the driver already supplies it as-is."""
        return value

    def coerce_any_to_datetime_date(self, value):
        """Return value unchanged; the driver already supplies it as-is."""
        return value

    def coerce_any_to_datetime_time(self, value):
        """Return value unchanged; the driver already supplies it as-is."""
        return value
92
93
class PsycoPgDecompiler(db.SQLDecompiler):
    """Decompiler that emits PostgreSQL ILIKE and date_part() SQL."""

    def _ilike(self, field, prefix, text, suffix):
        """Build "field ILIKE 'prefix<escaped text>suffix'"."""
        return (field + " ILIKE '" + prefix + self.adapter.escape_like(text)
                + suffix + "'")

    def _date_part(self, part, x):
        """Build "date_part('part', x)"."""
        return "date_part('" + part + "', " + x + ")"

    def dejavu_icontainedby(self, op1, op2):
        """Case-insensitive containment test of op1 in op2.

        A constant op1 is searched for as a substring of the op2 field;
        otherwise LOWER(op1) is compared against op2's lowercased values.
        """
        if not isinstance(op1, db.ConstWrapper):
            # Looking for field in (a, b, c); lowercase every candidate.
            coerce = self.adapter.coerce
            lowered = [coerce(item).lower() for item in op2.basevalue]
            return "LOWER(%s) IN (%s)" % (op1, ", ".join(lowered))
        # Looking for text in a field: terms are reversed.
        return self._ilike(op2, "%", op1, "%")

    def dejavu_istartswith(self, x, y):
        """Case-insensitive 'x starts with y'."""
        return self._ilike(x, "", y, "%")

    def dejavu_iendswith(self, x, y):
        """Case-insensitive 'x ends with y'."""
        return self._ilike(x, "%", y, "")

    def dejavu_ieq(self, x, y):
        """Case-insensitive equality: ILIKE with all wildcards escaped."""
        return self._ilike(x, "", y, "")

    def dejavu_year(self, x):
        """Year component of date/time expression x."""
        return self._date_part('year', x)

    def dejavu_month(self, x):
        """Month component of date/time expression x."""
        return self._date_part('month', x)

    def dejavu_day(self, x):
        """Day-of-month component of date/time expression x."""
        return self._date_part('day', x)
124
125
class PsycoPgIndexSet(db.IndexSet):
    """Index collection that drops indexes with PostgreSQL's DROP INDEX."""

    def __delitem__(self, key):
        """Issue DROP INDEX for the index stored under key.

        The database lock is held for the duration of the statement and is
        always released afterwards.
        """
        database = self.table.db
        database.lock("Dropping index. Transactions not allowed.")
        try:
            # PostgreSQL's DROP INDEX takes no "ON <table>" clause.
            statement = 'DROP INDEX %s;' % self[key].qname
            database.execute(statement)
        finally:
            database.unlock()
137
138
class PsycoPgTable(db.Table):
    """Table whose index collection is a PsycoPgIndexSet."""

    # Index set type used for this table's indexes.
    indexsetclass = PsycoPgIndexSet
142
143
144 class PsycoPgDatabase(db.Database):
145    
146     sql_name_max_length = 63
147     quote_all = True
148     poolsize = 10
149     encoding = 'SQL_ASCII'
150    
151     decompiler = PsycoPgDecompiler
152     adaptertosql = AdapterToPsycoPg()
153     adapterfromdb = AdapterFromPsycoPg()
154     tableclass = PsycoPgTable
155    
156     def _get_dbinfo(self, conn=None):
157         dbinfo = {}
158         try:
159             data, _ = self.fetch("SELECT pg_encoding_to_char(encoding) "
160                                  "FROM pg_database;", conn=conn)
161             dbinfo['encoding'] = data[0][0]
162         except _psycopg.DatabaseError, x:
163             if "does not exist" not in x.args[0]:
164                 raise
165         return dbinfo
166    
167     def _get_tables(self, conn=None):
168         data, _ = self.fetch("SELECT tablename FROM pg_tables WHERE schemaname"
169                              " not in ('information_schema', 'pg_catalog')",
170                              conn=conn)
171         return [self.tableclass(row[0], self.quote(row[0]), self)
172                 for row in data]
173    
174     def _get_table(self, tablename, conn=None):
175         data, _ = self.fetch("SELECT tablename FROM pg_tables WHERE "
176                              "tablename = '%s'" % tablename,
177                              conn=conn)
178         for name, in data:
179             if name == tablename:
180                 return self.tableclass(name, self.quote(name), self)
181         raise errors.MappingError(tablename)
182    
183     def _get_columns(self, tablename, conn=None):
184         # Get the OID of the table
185         data, _ = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'"
186                              % tablename, conn=conn)
187         table_OID = data[0][0]
188        
189         # Get index data so we can set col.key if pg_index.indisprimary
190         data, _ = self.fetch("SELECT indkey FROM pg_index WHERE indrelid "
191                              "= %s AND indisprimary" % table_OID, conn=conn)
192         if data:
193             # indkey is an "array" (we get a space-separated string of ints).
194             # These will equal pg_attribute.attnum, below.
195             indices = map(int, data[0][0].split(" "))
196         else:
197             indices = []
198        
199         # Get column data
200         sql = ("SELECT attname, atttypid, attnum, attlen, atttypmod "
201                "FROM pg_attribute WHERE attisdropped = False AND "
202                "attrelid = %s" % table_OID)
203         data, _ = self.fetch(sql, conn=conn)
204         cols = []
205         for row in data:
206             name = row[0]
207             if name in ('tableoid', 'cmax', 'xmax', 'cmin', 'xmin',
208                         'oid', 'ctid'):
209                 # This is a column which PostgreSQL defines automatically
210                 continue
211            
212             # Data type
213             dbtype, _ = self.fetch("SELECT typname, typlen FROM pg_type "
214                                    "WHERE oid = %s" % row[1])
215             if dbtype:
216                 dbtype = dbtype[0][0].upper()
217             else:
218                 dbtype = None
219             c = db.Column(row[0], self.quote(row[0]), dbtype,
220                           key=row[2] in indices)
221            
222             if dbtype in ('FLOAT4', 'FLOAT8'):
223                 c.hints['precision'] = row[3]
224             elif dbtype in ('MONEY', 'NUMERIC'):
225                 c.hints['precision'] = (row[4] >> 16) & 65535
226                 c.hints['scale'] = (row[4] & 65535) - 4
227            
228             # Default value
229             default, _ = self.fetch("SELECT adsrc FROM pg_attrdef "
230                                     "WHERE adnum = %s AND adrelid = %s"
231                                     % (row[2], table_OID))
232             if default:
233                 default = default[0][0]
234                 if default.startswith("nextval("):
235                     # Grab seqname from "nextval(seqname::[text|regclass])"
236                     c.autoincrement = True
237                     c.sequence_name = seq_name.search(default).group(1)
238                     c.initial = self.fetch("SELECT min_value FROM %s" %
239                                            c.sequence_name)[0][0]
240                     c.default = None
241                 else:
242                     # adsrc is always a string, so we must cast
243                     # it using our guessed type.
244                     c.default = self.python_type(dbtype)(default)
245             else:
246                 c.default = None
247            
248             if dbtype.startswith('BPCHAR') or dbtype.startswith('VARCHAR'):
249                 # See http://archives.postgresql.org/pgsql-interfaces/2004-07/msg00021.php
250                 c.hints['bytes'] = row[4] - 4
251             else:
252                 bytes = row[3]
253                 if bytes > 0:
254                     c.hints['bytes'] = bytes
255                 elif dbtype == 'TEXT':
256                     c.hints['bytes'] = 0
257            
258             cols.append(c)
259         return cols
260    
261     def _get_indices(self, tablename, conn=None):
262         # Get the OID of the parent table.
263         data, _ = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'"
264                              % tablename, conn=conn)
265         if not data:
266             return []
267        
268         table_OID = data[0][0]
269         indices = []
270         data, _ = self.fetch("SELECT pg_class.relname, indkey, indisprimary, "
271                              "indisunique FROM pg_index LEFT JOIN pg_class "
272                              "ON pg_index.indexrelid = pg_class.oid WHERE "
273                              "pg_index.indrelid = %s" % table_OID, conn=conn)
274         for row in data:
275             # indkey is an "array" (we get a space-separated string of ints).
276             cols = map(int, row[1].split(" "))
277             for col in cols:
278                 d, _ = self.fetch("SELECT attname FROM pg_attribute "
279                                   "WHERE attrelid = %s AND attnum = %s"
280                                   % (table_OID, col), conn=conn)
281                 i = db.Index(row[0], self.quote(row[0]), tablename,
282                              d[0][0], bool(row[3]))
283                 indices.append(i)
284        
285         return indices
286    
287     def python_type(self, dbtype):
288         """Return a Python type which can store values of the given dbtype."""
289         dbtype = dbtype.upper()
290         if dbtype in ('INT2', 'INT4', 'INTEGER', 'SMALLINT'):
291             return int
292         elif dbtype in ('BOOL', 'BOOLEAN'):
293             return bool
294         elif dbtype in ('INT8', 'BIGINT'):
295             return long
296         elif dbtype in ('FLOAT4', 'FLOAT8', 'MONEY', 'DOUBLE PRECISION', 'REAL'):
297             return float
298         elif dbtype.startswith('NUMERIC'):
299             if db.decimal:
300                 return db.decimal.Decimal
301             elif db.fixedpoint:
302                 return db.fixedpoint.FixedPoint
303         elif dbtype == 'DATE':
304             return datetime.date
305         elif dbtype in ('TIMESTAMP', 'TIMESTAMPTZ'):
306             return datetime.datetime
307         elif dbtype in ('TIME', 'TIMETZ'):
308             return datetime.time
309         elif dbtype in ('BYTEA'):
310             return str
311         for t in ('CHAR', 'VARCHAR', 'BPCHAR', 'TEXT'):
312             if dbtype.startswith(t):
313                 return str
314        
315         raise TypeError("Database type %r could not be converted "
316                         "to a Python type." % dbtype)
317    
318     def col_def(self, column):
319         """Return a clause for the given column for CREATE or ALTER TABLE.
320         
321         This will be of the form "name type [DEFAULT [x | nextval('seq')]]".
322         
323         PostgreSQL creates the sequence in a separate statement.
324         """
325         if column.autoincrement:
326             default = "nextval('%s')" % column.sequence_name
327         else:
328             default = column.default or ""
329             if not isinstance(default, str):
330                 default = self.adaptertosql.coerce(default, column.dbtype)
331        
332         if default:
333             default = " DEFAULT %s" % default
334        
335         return '%s %s%s' % (column.qname, column.dbtype, default)
336    
337     def create_sequence(self, table, column):
338         """Create a SEQUENCE for the given column and set its sequence_name."""
339         sname = column.sequence_name
340         if sname is None:
341             sname = self.quote("%s_%s_seq" % (table.name, column.name))
342             column.sequence_name = sname
343         self.execute("CREATE SEQUENCE %s START %s;" % (sname, column.initial))
344    
345     def drop_sequence(self, column):
346         """Drop a SEQUENCE for the given column and remove its sequence_name."""
347         if column.sequence_name is not None:
348             self.execute("DROP SEQUENCE %s;" % column.sequence_name)
349             column.sequence_name = None
350    
351     def quote(self, name):
352         if self.quote_all:
353             name = '"' + name.replace('"', '""') + '"'
354         return name
355    
356     def sql_name(self, name):
357         name = db.Database.sql_name(self, name)
358         if not self.quote_all:
359             name = name.lower()
360         return name
361    
362     default_isolation = "READ COMMITTED"
363    
364     def _get_conn(self):
365         try:
366             c = _psycopg.connect(self.Connect)
367             c.set_isolation_level(0)
368             return c
369         except _psycopg.DatabaseError, x:
370             if x.args[0].startswith('could not connect'):
371                 raise db.OutOfConnectionsError()
372             raise
373    
374     def _del_conn(self, conn):
375         conn.close()
376    
377     def _template_conn(self):
378         atoms = self.Connect.split(" ")
379         tmplconn = ""
380         for atom in atoms:
381             k, v = atom.split("=", 1)
382             if k == 'dbname': v = 'template1'
383             tmplconn += "%s=%s " % (k, v)
384         c = _psycopg.connect(tmplconn)
385         # Allow statements like CREATE DATABASE to run outside a transaction.
386         c.set_isolation_level(0)
387         return c
388    
389     def execute(self, query, conn=None):
390         """execute(query, conn=None) -> result set."""
391         if conn is None:
392             conn = self.connection()
393         if isinstance(query, unicode):
394             query = query.encode(self.adaptertosql.encoding)
395         self.log(query)
396         cursor = conn.cursor()
397         try:
398             cursor.execute(query)
399         finally:
400             cursor.close()
401    
402     def fetch(self, query, conn=None):
403         """fetch(query, conn=None) -> rowdata, columns."""
404         if conn is None:
405             conn = self.connection()
406         if isinstance(query, unicode):
407             query = query.encode(self.adaptertosql.encoding)
408         self.log(query)
409        
410         cursor = conn.cursor()
411         try:
412             cursor.execute(query)
413             data = cursor.fetchall()
414             coldefs = cursor.description
415         finally:
416             cursor.close()
417        
418         return data, coldefs
419    
420     def create_database(self):
421         self.lock("Creating database. Transactions not allowed.")
422         try:
423             # Must shut down all connections to avoid
424             # "being accessed by other users" error.
425             self.connection.shutdown()
426            
427             c = self._template_conn()
428             encoding = self.encoding
429             if encoding:
430                 encoding = " WITH ENCODING '%s'" % encoding
431             self.execute("CREATE DATABASE %s%s" % (self.qname, encoding), c)
432             c.close()
433             del c
434             self.clear()
435         finally:
436             self.unlock()
437    
438     def drop_database(self):
439         self.lock("Dropping database. Transactions not allowed.")
440         try:
441             # Must shut down all connections to avoid
442             # "being accessed by other users" error.
443             self.connection.shutdown()
444            
445             c = self._template_conn()
446             self.execute("DROP DATABASE %s;" % self.qname, c)
447             c.close()
448             del c
449             self.clear()
450         finally:
451             self.unlock()
452    
453     def version(self):
454         c = self._template_conn()
455         data, _ = self.fetch("SELECT version();", c)
456         v, = data[0]
457         c.close()
458         return "%s\npsycopg version: %s" % (v, _psycopg.__version__)
459
460
461
class StorageManagerPsycoPg(db.StorageManagerDB):
    """StoreManager to save and retrieve Units via psycopg2."""

    databaseclass = PsycoPgDatabase

    def __init__(self, arena, allOptions=None):
        """Initialize from allOptions, deriving 'name' from the connect string.

        allOptions must contain 'Connect', a string of whitespace-separated
        key=value pairs. If it has a dbname pair, its value is stored in
        allOptions['name'] (the caller's dict is updated in place) before
        the base initializer runs.
        """
        # A shared mutable default would accumulate 'name' across calls.
        if allOptions is None:
            allOptions = {}
        for atom in allOptions['Connect'].split():
            k, v = atom.split("=", 1)
            if k == "dbname":
                allOptions['name'] = v
        db.StorageManagerDB.__init__(self, arena, allOptions)

    def _seq_UnitSequencerInteger(self, unit):
        """Reserve a unit using the table's SERIAL field.

        INSERTs the unit's non-autoincrement properties, then sets its first
        identifier to the value the column's sequence just produced.
        """
        cls = unit.__class__
        t = self.db[cls.__name__]
        coerce = self.db.adaptertosql.coerce

        fields = []
        values = []
        for key in cls.properties:
            col = t[key]
            if col.autoincrement:
                # Skip this field; its nextval() default supplies the value.
                continue
            values.append(coerce(getattr(unit, key), col.dbtype))
            fields.append(col.qname)

        transconn = self.db.get_transaction()
        self.db.execute('INSERT INTO %s (%s) VALUES (%s);' %
                        (t.qname, ", ".join(fields), ", ".join(values)),
                        transconn)

        # currval() returns the value nextval() most recently produced in
        # *this* session, so concurrent inserts from other connections or
        # processes cannot leak in (reading the sequence's last_value could).
        idcol = cls.identifiers[0]
        seqname = t[idcol].sequence_name
        data, _ = self.db.fetch("SELECT currval('%s');"
                                % seqname.replace("'", "''"), transconn)
        setattr(unit, idcol, data[0][0])
503
Note: See TracBrowser for help on using the browser.