Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

root/trunk/geniusql/providers/pypgsql.py

Revision 7 (checked in by fumanchu, 6 years ago)

Various bugfixes related to Table.created.

  • Property svn:eol-style set to native
Line 
1 # Use libpq directly to avoid all of the DB-API overhead.
2 from pyPgSQL import libpq
3
4 import datetime
5 try:
6     import cPickle as pickle
7 except ImportError:
8     import pickle
9
10 import re
# Pulls the sequence name out of a column default such as
# "nextval('seqname'::regclass)" or "nextval('seqname'::text)".
seq_name = re.compile(r"nextval\('([^:]+)'.*\)")

# Control characters and bytes >= 0177 are written as octal escapes.
escape_oct = re.compile(r"[\000-\037\177-\377]")


def replace_oct(m):
    """Return the doubled-backslash, 3-digit octal escape for a match."""
    return r"\\%03o" % ord(m.group(0))


# Only genuine byte escapes (\000 through \377) are unescaped. Matching
# any three decimal digits would make replace_unoct raise ValueError on
# text such as "\189" (not octal) or "\777" (not a byte).
unescape_oct = re.compile(r"\\([0-3][0-7][0-7])")


def replace_unoct(m):
    """Return the single character named by a matched octal escape."""
    return chr(int(m.group(1), 8))
16
17
18 import geniusql
19 from geniusql import errors, typerefs
20
21
class AdapterToPgSQL(geniusql.AdapterToSQL):
    """Adapter which renders Python values as PostgreSQL literal text."""

    # LIKE wildcard characters paired with their backslash-escaped form.
    like_escapes = [("%", r"\\%"), ("_", r"\\_")]

    # Do these need to know if "SHOW DateStyle;" != "ISO, MDY" ?
    def coerce_datetime_datetime_to_any(self, value):
        """Return value as a quoted 'YYYY-MM-DD HH:MM:SS.ffffff' literal."""
        return ("'%04d-%02d-%02d %02d:%02d:%02d.%06d'" %
                (value.year, value.month, value.day,
                 value.hour, value.minute, value.second,
                 value.microsecond))

    def coerce_datetime_date_to_any(self, value):
        """Return value as a quoted 'YYYY-MM-DD' literal."""
        return "'%04d-%02d-%02d'" % (value.year, value.month, value.day)

    def coerce_datetime_time_to_any(self, value):
        """Return value as a quoted 'HH:MM:SS.ffffff' literal."""
        return ("'%02d:%02d:%02d.%06d'" %
                (value.hour, value.minute, value.second, value.microsecond))

    def coerce_any_to_bytea(self, value):
        """Pickle value (protocol 2) and return it as a '...'::bytea literal.

        Bytes <= 31 or >= 127, the single quote (39) and the backslash (92)
        are written as a doubled backslash plus three octal digits.
        """
        # See http://www.postgresql.org/docs/8.1/interactive/datatype-binary.html
        value = pickle.dumps(value, 2)

        def repl(char):
            o = ord(char)
            if o <= 31 or o == 39 or o == 92 or o >= 127:
                # Format the octal digits directly rather than round-tripping
                # through the string form of oct().
                return r"\\%03o" % o
            return char
        return "'%s'::bytea" % "".join(map(repl, value))

    def do_pickle(self, value):
        """Pickle value (protocol 2) and return it as a quoted string literal."""
        value = pickle.dumps(value, 2)
        # The pickle is already a byte str, so coerce_str_to_any does not
        # encode it even though skip_encoding is False.
        value = self.coerce_str_to_any(value, skip_encoding=False)
        return value
    coerce_dict_to_any = do_pickle
    coerce_list_to_any = do_pickle
    coerce_tuple_to_any = do_pickle

    def coerce_str_to_any(self, value, skip_encoding=False):
        """Return value as a single-quoted, escaped string literal.

        Non-str values (e.g. unicode) are first encoded with self.encoding
        unless skip_encoding is true. Each (pattern, replacement) pair in
        self.escapes is then applied, and finally control characters and
        bytes >= 0177 are replaced with octal escapes.
        """
        if not skip_encoding and not isinstance(value, str):
            value = value.encode(self.encoding)
        for pat, repl in self.escapes:
            value = value.replace(pat, repl)

        # Escape octal sequences
        value = escape_oct.sub(replace_oct, value)
        return "'" + value + "'"
67
68
class AdapterFromPgSQL(geniusql.AdapterFromDB):
    """Adapter which converts values read from PostgreSQL into Python values."""

    def coerce_any_to_str(self, value):
        """Return value as a byte str with octal escapes expanded.

        A unicode result is encoded using self.encoding; anything else is
        passed through str().
        """
        # Turn each backslash + 3-digit escape back into its character.
        unescaped = unescape_oct.sub(replace_unoct, value)
        if not isinstance(unescaped, unicode):
            return str(unescaped)
        return unescaped.encode(self.encoding)
78
79
class PgSQLDecompiler(geniusql.SQLDecompiler):
    """Decompiler producing PostgreSQL ILIKE and date_part() expressions."""

    def _ilike(self, field, text, head="", tail=""):
        """Return "<field> ILIKE '<head><escaped text><tail>'"."""
        escaped = self.adapter.escape_like(text)
        return field + " ILIKE '" + head + escaped + tail + "'"

    def _date_part(self, part, field):
        """Return a date_part() call extracting `part` from `field`."""
        return "date_part('" + part + "', " + field + ")"

    def dejavu_icontainedby(self, op1, op2):
        """Return SQL for a case-insensitive "op1 in op2" test."""
        if not isinstance(op1, geniusql.ConstWrapper):
            # Looking for field in (a, b, c).
            # Force all args to lowercase for case-insensitive comparison.
            lowered = []
            for member in op2.basevalue:
                lowered.append(self.adapter.coerce(member).lower())
            return "LOWER(%s) IN (%s)" % (op1, ", ".join(lowered))

        # Looking for text in a field. Use ILike (reverse terms).
        return self._ilike(op2, op1, head="%", tail="%")

    def dejavu_istartswith(self, x, y):
        """Return SQL matching x values which start with y, ignoring case."""
        return self._ilike(x, y, tail="%")

    def dejavu_iendswith(self, x, y):
        """Return SQL matching x values which end with y, ignoring case."""
        return self._ilike(x, y, head="%")

    def dejavu_ieq(self, x, y):
        """Return SQL matching x values equal to y, ignoring case."""
        # ILIKE with no wildcards should behave like ieq.
        return self._ilike(x, y)

    def dejavu_year(self, x):
        """Return SQL extracting the year of x."""
        return self._date_part("year", x)

    def dejavu_month(self, x):
        """Return SQL extracting the month of x."""
        return self._date_part("month", x)

    def dejavu_day(self, x):
        """Return SQL extracting the day of x."""
        return self._date_part("day", x)
110
111
class PgIndexSet(geniusql.IndexSet):
    """IndexSet which drops indices using PostgreSQL's DROP INDEX syntax."""

    def __delitem__(self, key):
        """Drop the specified index."""
        table = self.table
        table.db.lock("Dropping index. Transactions not allowed.")
        try:
            # PG doesn't use DROP INDEX .. ON ..
            statement = 'DROP INDEX %s;' % self[key].qname
            table.db.execute(statement)
        finally:
            # Always release the lock, even if the DROP fails.
            table.db.unlock()
123
124
class PgTable(geniusql.Table):
    """Table whose index collection is a PgIndexSet."""

    # Class used for this table's set of indices.
    indexsetclass = PgIndexSet
128
129
130 class PgDatabase(geniusql.Database):
131    
132     sql_name_max_length = 63
133     quote_all = True
134     poolsize = 10
135     encoding = 'SQL_ASCII'
136    
137     decompiler = PgSQLDecompiler
138     adaptertosql = AdapterToPgSQL()
139     adapterfromdb = AdapterFromPgSQL()
140     tableclass = PgTable
141    
142     def _get_dbinfo(self, conn=None):
143         dbinfo = {}
144         try:
145             data, _ = self.fetch("SELECT pg_encoding_to_char(encoding) "
146                                  "FROM pg_database;", conn=conn)
147             dbinfo['encoding'] = data[0][0]
148         except libpq.DatabaseError, x:
149             if "does not exist" not in x.args[0]:
150                 raise
151         return dbinfo
152    
153     def _get_tables(self, conn=None):
154         data, _ = self.fetch("SELECT tablename FROM pg_tables WHERE schemaname"
155                              " not in ('information_schema', 'pg_catalog')",
156                              conn=conn)
157         return [self.tableclass(row[0], self.quote(row[0]),
158                                 self, created=True)
159                 for row in data]
160    
161     def _get_table(self, tablename, conn=None):
162         data, _ = self.fetch("SELECT tablename FROM pg_tables WHERE "
163                              "tablename = '%s'" % tablename,
164                              conn=conn)
165         for name, in data:
166             if name == tablename:
167                 return self.tableclass(name, self.quote(name),
168                                        self, created=True)
169         raise errors.MappingError(tablename)
170    
171     def _get_columns(self, tablename, conn=None):
172         # Get the OID of the table
173         data, _ = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'"
174                              % tablename, conn=conn)
175         table_OID = data[0][0]
176        
177         # Get index data so we can set col.key if pg_index.indisprimary
178         data, _ = self.fetch("SELECT indkey FROM pg_index WHERE indrelid "
179                              "= %s AND indisprimary" % table_OID, conn=conn)
180         if data:
181             # indkey is an "array" (we get a space-separated string of ints).
182             # These will equal pg_attribute.attnum, below.
183             indices = map(int, data[0][0].split(" "))
184         else:
185             indices = []
186        
187         # Get column data
188         sql = ("SELECT attname, atttypid, attnum, attlen, atttypmod "
189                "FROM pg_attribute WHERE attisdropped = False AND "
190                "attrelid = %s" % table_OID)
191         data, _ = self.fetch(sql, conn=conn)
192         cols = []
193         for row in data:
194             name = row[0]
195             if name in ('tableoid', 'cmax', 'xmax', 'cmin', 'xmin',
196                         'oid', 'ctid'):
197                 # This is a column which PostgreSQL defines automatically
198                 continue
199            
200             # Data type
201             dbtype, _ = self.fetch("SELECT typname, typlen FROM pg_type "
202                                    "WHERE oid = %s" % row[1])
203             if dbtype:
204                 dbtype = dbtype[0][0].upper()
205             else:
206                 dbtype = None
207             c = geniusql.Column(self.python_type(dbtype), dbtype,
208                                 None, {}, row[2] in indices,
209                                 row[0], self.quote(row[0]))
210            
211             if dbtype in ('FLOAT4', 'FLOAT8'):
212                 c.hints['precision'] = row[3]
213             elif dbtype in ('MONEY', 'NUMERIC'):
214                 c.hints['precision'] = (row[4] >> 16) & 65535
215                 c.hints['scale'] = (row[4] & 65535) - 4
216            
217             # Default value
218             default, _ = self.fetch("SELECT adsrc FROM pg_attrdef "
219                                     "WHERE adnum = %s AND adrelid = %s"
220                                     % (row[2], table_OID))
221             if default:
222                 default = default[0][0]
223                 if default.startswith("nextval("):
224                     # Grab seqname from "nextval(seqname::[text|regclass])"
225                     c.autoincrement = True
226                     c.sequence_name = seq_name.search(default).group(1)
227                     c.initial = self.fetch("SELECT min_value FROM %s" %
228                                            c.sequence_name)[0][0]
229                     c.default = None
230                 else:
231                     # adsrc is always a string, so we must cast
232                     # it using our guessed type.
233                     c.default = self.python_type(dbtype)(default)
234             else:
235                 c.default = None
236            
237             if dbtype.startswith('BPCHAR') or dbtype.startswith('VARCHAR'):
238                 # See http://archives.postgresql.org/pgsql-interfaces/2004-07/msg00021.php
239                 c.hints['bytes'] = row[4] - 4
240             else:
241                 bytes = row[3]
242                 if bytes > 0:
243                     c.hints['bytes'] = bytes
244                 elif dbtype == 'TEXT':
245                     c.hints['bytes'] = 0
246            
247             cols.append(c)
248         return cols
249    
250     def _get_indices(self, tablename, conn=None):
251         # Get the OID of the parent table.
252         data, _ = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'"
253                              % tablename, conn=conn)
254         if not data:
255             return []
256        
257         table_OID = data[0][0]
258         indices = []
259         data, _ = self.fetch("SELECT pg_class.relname, indkey, indisprimary, "
260                              "indisunique FROM pg_index LEFT JOIN pg_class "
261                              "ON pg_index.indexrelid = pg_class.oid WHERE "
262                              "pg_index.indrelid = %s" % table_OID, conn=conn)
263         for row in data:
264             # indkey is an "array" (we get a space-separated string of ints).
265             cols = map(int, row[1].split(" "))
266             for col in cols:
267                 d, _ = self.fetch("SELECT attname FROM pg_attribute "
268                                   "WHERE attrelid = %s AND attnum = %s"
269                                   % (table_OID, col), conn=conn)
270                 i = geniusql.Index(row[0], self.quote(row[0]), tablename,
271                                    d[0][0], bool(row[3]))
272                 indices.append(i)
273        
274         return indices
275    
276     def python_type(self, dbtype):
277         """Return a Python type which can store values of the given dbtype."""
278         dbtype = dbtype.upper()
279         if dbtype in ('INT2', 'INT4', 'INTEGER', 'SMALLINT'):
280             return int
281         elif dbtype in ('BOOL', 'BOOLEAN'):
282             return bool
283         elif dbtype in ('INT8', 'BIGINT'):
284             return long
285         elif dbtype in ('FLOAT4', 'FLOAT8', 'MONEY', 'DOUBLE PRECISION', 'REAL'):
286             return float
287         elif dbtype.startswith('NUMERIC'):
288             if typerefs.decimal:
289                 return typerefs.decimal.Decimal
290             elif typerefs.fixedpoint:
291                 return typerefs.fixedpoint.FixedPoint
292         elif dbtype == 'DATE':
293             return datetime.date
294         elif dbtype in ('TIMESTAMP', 'TIMESTAMPTZ'):
295             return datetime.datetime
296         elif dbtype in ('TIME', 'TIMETZ'):
297             return datetime.time
298         elif dbtype in ('BYTEA'):
299             return str
300         for t in ('CHAR', 'VARCHAR', 'BPCHAR', 'TEXT'):
301             if dbtype.startswith(t):
302                 return str
303        
304         raise TypeError("Database type %r could not be converted "
305                         "to a Python type." % dbtype)
306    
307     def columnclause(self, column):
308         """Return a clause for the given column for CREATE or ALTER TABLE.
309         
310         This will be of the form "name type [DEFAULT [x | nextval('seq')]]".
311         
312         PostgreSQL creates the sequence in a separate statement.
313         """
314         if column.autoincrement:
315             default = "nextval('%s')" % column.sequence_name
316         else:
317             default = column.default or ""
318             if not isinstance(default, str):
319                 default = self.adaptertosql.coerce(default, column.dbtype)
320        
321         if default:
322             default = " DEFAULT %s" % default
323        
324         return '%s %s%s' % (column.qname, column.dbtype, default)
325    
326     def create_sequence(self, table, column):
327         """Create a SEQUENCE for the given column and set its sequence_name."""
328         sname = column.sequence_name
329         if sname is None:
330             sname = self.quote("%s_%s_seq" % (table.name, column.name))
331             column.sequence_name = sname
332         self.execute("CREATE SEQUENCE %s START %s;" % (sname, column.initial))
333    
334     def drop_sequence(self, column):
335         """Drop a SEQUENCE for the given column and remove its sequence_name."""
336         if column.sequence_name is not None:
337             self.execute("DROP SEQUENCE %s;" % column.sequence_name)
338             column.sequence_name = None
339    
340     def quote(self, name):
341         if self.quote_all:
342             name = '"' + name.replace('"', '""') + '"'
343         return name
344    
345     def sql_name(self, name):
346         name = geniusql.Database.sql_name(self, name)
347         if not self.quote_all:
348             name = name.lower()
349         return name
350    
351     default_isolation = "READ COMMITTED"
352    
353     def _get_conn(self):
354         try:
355             return libpq.PQconnectdb(self.Connect)
356         except libpq.DatabaseError, x:
357             if x.args[0].startswith('could not connect'):
358                 raise errors.OutOfConnectionsError()
359             raise
360    
361     def _del_conn(self, conn):
362         conn.finish()
363    
364     def _template_conn(self):
365         atoms = self.Connect.split(" ")
366         tmplconn = ""
367         for atom in atoms:
368             k, v = atom.split("=", 1)
369             if k == 'dbname': v = 'template1'
370             tmplconn += "%s=%s " % (k, v)
371         return libpq.PQconnectdb(tmplconn)
372    
373     def fetch(self, query, conn=None):
374         """fetch(query, conn=None) -> rowdata, columns."""
375         res = self.execute(query, conn)
376        
377         columns = []
378         if res.resultType != libpq.EMPTY_QUERY:
379             for index in xrange(res.nfields):
380                 columns.append((res.fname(index), res.ftype(index)))
381        
382         data = [[res.getvalue(row, col) for col in xrange(res.nfields)]
383                 for row in xrange(res.ntuples)]
384         res.clear()
385        
386         return data, columns
387    
388     def _grab_new_ids(self, table, idkeys, conn):
389         """Insert a row using the table's SERIAL field."""
390         newids = {}
391         for idkey in idkeys:
392             col = table[idkey]
393             seq = col.sequence_name
394             data, _ = self.fetch("SELECT last_value FROM %s;" % seq, conn)
395             newids[idkey] = data[0][0]
396         return newids
397    
398     def create_database(self):
399         self.lock("Creating database. Transactions not allowed.")
400         try:
401             c = self._template_conn()
402             encoding = self.encoding
403             if encoding:
404                 encoding = " WITH ENCODING '%s'" % encoding
405             self.execute("CREATE DATABASE %s%s" % (self.qname, encoding), c)
406             c.finish()
407             self.clear()
408         finally:
409             self.unlock()
410    
411     def drop_database(self):
412         self.lock("Dropping database. Transactions not allowed.")
413         try:
414             # Must shut down all connections to avoid
415             # "being accessed by other users" error.
416             self.connection.shutdown()
417            
418             c = self._template_conn()
419             self.execute("DROP DATABASE %s;" % self.qname, c)
420             c.finish()
421             self.clear()
422         finally:
423             self.unlock()
424    
425     def version(self):
426         c = self._template_conn()
427         v = c.version
428         c.finish()
429         return v
430
Note: See TracBrowser for help on using the browser.