Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

root/trunk/geniusql/providers/postgres.py

Revision 323 (checked in by lakin, 2 months ago)

fixing the Postgresql LIKE queries.

  • Property svn:eol-style set to native
Line 
1 import datetime
2 try:
3     import cPickle as pickle
4 except ImportError:
5     import pickle
6 import re
7 seq_name = re.compile(r"nextval\('([^:]+)'.*\)")
8 escape_oct = re.compile(r"[\000-\037\177-\377]")
9 replace_oct = lambda m: r"\\%03o" % ord(m.group(0))
10 unescape_oct = re.compile(r"\\\\(\d\d\d)")
11 replace_unoct = lambda m: chr(int(m.group(1), 8))
12 import threading
13
14 import geniusql
15 from geniusql import adapters, dbtypes, deparse, errors
16
17
18 # ------------------------------ Adapters ------------------------------ #
19
20
class PgDATE_Adapter(adapters.date_to_SQL92DATE):
    """Adapter for datetime.date values that knows Postgres date arithmetic."""

    def binary_op(self, op1, op, sqlop, op2):
        """Return the SQL for 'op1 <sqlop> op2' (or raise TypeError).

        op1 and op2 are expression objects exposing .sql and .pytype.
        Only timedelta and date right-hand operands are supported.
        """
        rhs_type = op2.pytype
        if rhs_type is datetime.timedelta:
            # Postgres assumes a "date" is actually midnight, so we
            # need to drop any h:m:s from our interval.
            template = "(%s %s date_trunc('day', %s))"
        elif rhs_type is datetime.date:
            # Cast to timestamp to achieve an INTERVAL result
            template = "(%s::TIMESTAMP %s %s::TIMESTAMP)"
        else:
            raise TypeError("unsupported operand type(s) for %s: "
                            "%r and %r" % (op, op1.pytype, op2.pytype))
        return template % (op1.sql, sqlop, op2.sql)
33
34
class datetime_to_PgTIMESTAMPTZ_Adapter(adapters.Adapter):
    """Adapter for timezone-naive datetime objects.

    Postgres stores all timestamps as UTC internally, adjusting inbound
    values based on their timezone component, and offsetting outbound
    values relative to the Postgres 'timezone' config entry.

    This adapter assumes you always want to push and pull datetime objects
    that have no timezone. Therefore, it doesn't supply a timezone atom
    when sending datetime data to Postgres. When retrieving values, any
    offset is silently ignored. That is, in both cases, we assume you
    correctly set the connection's timezone attribute.
    """

    def push(self, value, dbtype):
        """Return a quoted 'YYYY-MM-DD HH:MM:SS' literal (or 'NULL').

        Microseconds and any tzinfo on the value are not emitted.
        """
        if value is None:
            return 'NULL'
        return ("'%04d-%02d-%02d %02d:%02d:%02d'" %
                (value.year, value.month, value.day,
                 value.hour, value.minute, value.second))

    def pull(self, value, dbtype):
        """Return a naive datetime parsed from a Postgres timestamp string.

        None and datetime instances are returned unchanged. Any trailing
        "+HH[:MM]" / "-HH[:MM]" offset in the string is discarded.
        """
        if value is None:
            return None
        if isinstance(value, datetime.datetime):
            return value

        # Fixed-width fields: YYYY-MM-DD HH:MM:SS
        args = [int(chunk) for chunk in
                (value[0:4], value[5:7], value[8:10],
                 value[11:13], value[14:16], value[17:19])]

        # Whatever follows the seconds is ".ffffff" and/or a UTC offset.
        fraction = value[19:]
        signpos = fraction.find("+")
        if signpos == -1:
            signpos = fraction.find("-")
        if signpos != -1:
            # We have a timezone. Drop it.
            fraction = fraction[:signpos]

        digits = fraction.strip(".")
        if digits:
            # The digits are a decimal fraction of a second, so ".5" means
            # 500000 microseconds: right-pad (or truncate) to six places.
            args.append(int(digits.ljust(6, "0")[:6]))
        else:
            args.append(0)

        return datetime.datetime(*args)
84
85
class datetime_tz_to_PgTIMESTAMPTZ_Adapter(adapters.Adapter):
    """Adapter for timezone-aware datetime objects.

    Postgres stores all timestamps as UTC internally, adjusting inbound
    values based on their timezone component, and offsetting outbound
    values relative to the Postgres 'timezone' config entry.

    This adapter assumes you always want to push and pull datetime objects
    that have a valid tzinfo. Therefore, it always tries to supply a
    timezone atom when sending datetime data to Postgres. If you push
    a datetime with a tzinfo of None, "+00" is used for the timezone.
    When retrieving values, any offset in the database value is used to
    form a valid tzinfo object for the value (see dbtypes.FixedTimeZone).
    In both directions, we assume you correctly set the connection's
    timezone attribute.
    """

    def push(self, value, dbtype):
        """Return a TIMESTAMP WITH TIME ZONE literal for value (or 'NULL')."""
        if value is None:
            return 'NULL'

        minutes = 0
        if value.tzinfo is not None:
            offset = value.tzinfo.utcoffset(value)
            if offset is not None:
                minutes = (offset.days * 1440) + (offset.seconds // 60)

        # Take the sign first and split the magnitude afterwards; floor
        # division of a negative total would turn -05:30 into -06:30.
        if minutes < 0:
            sign = "-"
        else:
            sign = "+"
        h, m = divmod(abs(minutes), 60)

        return ("TIMESTAMP WITH TIME ZONE "
                "'%04d-%02d-%02d %02d:%02d:%02d.%06d%s%02d:%02d'" %
                (value.year, value.month, value.day,
                 value.hour, value.minute, value.second, value.microsecond,
                 sign, h, m))

    def pull(self, value, dbtype):
        """Return a datetime parsed from a Postgres timestamptz string.

        None and datetime instances are returned unchanged. When the string
        carries a "+HH[:MM]" / "-HH[:MM]" offset, the result's tzinfo is a
        dbtypes.FixedTimeZone built from that offset in minutes; otherwise
        the result is naive.
        """
        if value is None:
            return None
        if isinstance(value, datetime.datetime):
            return value

        # Fixed-width fields: YYYY-MM-DD HH:MM:SS
        args = [int(chunk) for chunk in
                (value[0:4], value[5:7], value[8:10],
                 value[11:13], value[14:16], value[17:19])]

        fraction, tz = value[19:], ""
        signpos = fraction.find("+")
        if signpos == -1:
            signpos = fraction.find("-")
        if signpos != -1:
            # We have a timezone. Split it off.
            fraction, tz = fraction[:signpos], fraction[signpos:]

        digits = fraction.strip(".")
        if digits:
            # Decimal fraction of a second: ".5" is 500000 microseconds.
            microseconds = int(digits.ljust(6, "0")[:6])
        else:
            microseconds = 0

        tzinfo = None
        if tz:
            # tz is a sign followed by "HH" or "HH:MM"; the sign applies
            # to the whole offset, not just to the hours.
            magnitude = tz[1:]
            if ":" in magnitude:
                h, m = [int(atom) for atom in magnitude.split(":", 1)]
            else:
                h, m = int(magnitude), 0
            total = (h * 60) + m
            if tz.startswith("-"):
                total = -total
            tzinfo = dbtypes.FixedTimeZone(total)

        args.append(microseconds)
        args.append(tzinfo)
        return datetime.datetime(*args)
162
163
class PgINTERVAL_Adapter(adapters.Adapter):
    """Adapter between datetime.timedelta and Postgres INTERVAL values."""

    def push(self, value, dbtype):
        """Return an "interval 'D H:M:S'" literal for value (or 'NULL')."""
        if value is None:
            return 'NULL'
        # Ignore microseconds for now
        hours, remainder = divmod(value.seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        return "interval '%s %s:%s:%s'" % (value.days, hours,
                                           minutes, seconds)

    def pull(self, value, dbtype):
        """Return a timedelta parsed from an interval string.

        None and timedelta instances are returned unchanged.
        """
        if value is None:
            return None
        if isinstance(value, datetime.timedelta):
            return value

        # When an interval is returned, it will be of typename
        # "interval" or "TIMESTAMP".
        # Assume it's in ISO format; e.g. "964 days 18:29:45.4769999981"
        # Splitting on the (captured) day marker yields:
        #   "18:35:49.3222"                -> ['18:35:49.3222']
        #   "964 days 18:29:45.4769999981" -> ['964', ' days ', '18:29:45...']
        #   "964 days"                     -> ['964', ' days', '']
        #   "1 day"                        -> ['1', ' day', '']
        pieces = re.split(r"( ?days? ?)", value)
        clock = pieces.pop()
        day_count = 0
        if pieces:
            # ...then we have a day component
            day_count = int(pieces[0])
            if not clock:
                return datetime.timedelta(day_count)

        hour_text, minute_text, second_text = clock.split(":", 2)
        negative = hour_text.startswith("-")
        # The sign is carried by the hours atom but applies to the whole
        # h:m:s part, so work with the magnitude and negate at the end.
        hour_count = abs(int(hour_text))
        total = ((hour_count * 3600) + (int(minute_text) * 60)
                 + float(second_text))
        if negative:
            total = -total

        return datetime.timedelta(day_count, total)

    def binary_op(self, op1, op, sqlop, op2):
        """Return the SQL for 'op1 <sqlop> op2' (or raise TypeError)."""
        rhs_type = op2.pytype
        if rhs_type is datetime.date:
            # Postgres assumes a "date" is actually midnight, so we
            # need to drop any h:m:s from our interval.
            template = "(date_trunc('day', %s) %s %s)"
        elif rhs_type in (datetime.datetime, datetime.timedelta):
            template = "(%s %s %s)"
        else:
            raise TypeError("unsupported operand type(s) for %s: "
                            "%r and %r" % (op, op1.pytype, op2.pytype))
        return template % (op1.sql, sqlop, op2.sql)
222
223
class Pg_LIKE_Mixin(object):
    """Mixin providing Postgres LIKE / ILIKE SQL generation for adapters."""

    # (wildcard, escaped form) pairs applied to the right-hand operand.
    like_escapes = [("%", r"\%"), ("_", r"\_")]

    def escape_like(self, sql):
        """Prepare a string value for use in a LIKE comparison.

        One leading and one trailing quote-mark (the literal's delimiters)
        are removed, then LIKE wildcards in the remainder are escaped.
        """
        # Remove at most one quote from each end. Stripping every quote
        # character would also eat escaped quotes that belong to the value
        # itself (e.g. "'abc'''" must become "abc''", not "abc").
        quotes = ("'", '"')
        if sql[:1] in quotes:
            sql = sql[1:]
        if sql[-1:] in quotes:
            sql = sql[:-1]
        for pat, repl in self.like_escapes:
            sql = sql.replace(pat, repl)
        return sql

    def like_op(self, op1, op2, ignore_case=False,
                start_only=False, end_only=False):
        """Return the SQL for 'op1 LIKE op2' (or raise TypeError).

        op1 and op2 will be SQLExpression objects.

        If 'ignore_case' is False (the default), then the LIKE comparison
        will be performed in a case-sensitive manner; otherwise (if
        ignore_case is True), the LIKE comparison will be performed in
        a case-INsensitive manner.

        If 'start_only' is True, then op2 will be matched only at the start
        of op1. If False (the default), then op2 will be matched anywhere.

        If 'end_only' is True, then op2 will be matched only at the end
        of op1. If False (the default), then op2 will be matched anywhere.

        If both 'start_only' and 'end_only' are True, then op2 will only
        match op1 if they are identical.
        """
        pattern = self.escape_like(op2.sql)
        # An unanchored side gets a wildcard.
        if not start_only:
            pattern = '%' + pattern
        if not end_only:
            pattern = pattern + '%'
        if ignore_case:
            keyword = " ILIKE '"
        else:
            keyword = " LIKE '"
        return op1.sql + keyword + pattern + "'"
270
271
class Pg_str_to_VARCHAR(Pg_LIKE_Mixin, adapters.str_to_SQL92VARCHAR):
    """Adapter between Python str and Postgres character columns."""

    def push(self, value, dbtype):
        """Return value as a quoted, escaped SQL string literal (or 'NULL')."""
        if value is None:
            return 'NULL'
        text = value
        if not isinstance(text, str):
            text = text.encode(dbtype.encoding)
        # self.escapes is not defined in this class; presumably inherited
        # (pattern, replacement) pairs -- verify against the base adapter.
        for needle, substitute in self.escapes:
            text = text.replace(needle, substitute)
        return "'" + text + "'"

    def pull(self, value, dbtype):
        """Return the database value unchanged (None stays None)."""
        return value
287
288
class Pg_unicode_to_VARCHAR(Pg_LIKE_Mixin, adapters.unicode_to_SQL92VARCHAR):
    """Adapter between Python unicode and Postgres character columns."""

    def push(self, value, dbtype):
        """Return value as a quoted, escaped SQL string literal (or 'NULL').

        Non-str values are encoded with dbtype.encoding first.
        """
        if value is None:
            return 'NULL'
        encoded = value
        if not isinstance(encoded, str):
            encoded = encoded.encode(dbtype.encoding)
        # self.escapes is not defined in this class; presumably inherited
        # (pattern, replacement) pairs -- verify against the base adapter.
        for needle, substitute in self.escapes:
            encoded = encoded.replace(needle, substitute)
        return "'" + encoded + "'"

    def pull(self, value, dbtype):
        """Return the database value decoded to unicode (None stays None)."""
        if value is None:
            return None
        # Buffer objects are converted to str before decoding.
        raw = value
        if isinstance(raw, buffer):
            raw = str(raw)
        return unicode(raw, dbtype.encoding)
307
308
class PgPickler(Pg_LIKE_Mixin, adapters.Pickler):
    """Adapter that stores arbitrary Python objects as pickled strings."""

    def push(self, value, dbtype):
        """Return a quoted SQL literal holding the pickle of value.

        The pickle (protocol 2) has the adapter's escapes applied and then
        every control / high byte replaced by an octal escape sequence.
        """
        if value is None:
            return 'NULL'
        blob = pickle.dumps(value, 2)

        if not isinstance(blob, str):
            blob = blob.encode(dbtype.encoding)
        for needle, substitute in self.escapes:
            blob = blob.replace(needle, substitute)

        # Escape octal sequences
        return "'" + escape_oct.sub(replace_oct, blob) + "'"

    def pull(self, value, dbtype):
        """Return the object unpickled from a database value.

        Undoes the transformations of push in reverse order.
        """
        if value is None:
            return None
        # Unescape octal sequences
        blob = unescape_oct.sub(replace_unoct, value)
        for plain, escaped in self.escapes:
            blob = blob.replace(escaped, plain)
        # NOTE(review): unpickling executes whatever the stored bytes
        # describe, so this trusts the database contents.
        return pickle.loads(blob)
333
334
335
class PgFLOAT4_Adapter(adapters.float_to_SQL92REAL):
    """Adapter for Python floats stored in single-precision columns."""

    def push(self, value, dbtype):
        """Return the repr of value wrapped in single quotes (or 'NULL')."""
        if value is None:
            return 'NULL'
        # Use quotes to restrict the value to single precision, so that
        # comparisons work between existing values and supplied constants.
        # See http://archives.postgresql.org/pgsql-bugs/2004-02/msg00062.php
        return "'%r'" % value

    def compare_op(self, op1, op, sqlop, op2):
        """Return the SQL for 'op1 <sqlop> op2' (or raise TypeError).

        A FLOAT8 right-hand side is cast down to FLOAT4; integer and
        FLOAT4 operands are compared as-is.
        """
        other_type = op2.dbtype
        if isinstance(other_type, FLOAT8):
            # Downcast to the smaller type
            template = "(%s %s (%s)::FLOAT4)"
        elif isinstance(other_type, (INT2, INT4, INT8, FLOAT4)):
            template = "(%s %s %s)"
        else:
            raise TypeError("unsupported operand type(s) for %s: "
                            "%r and %r" % (op, op1.pytype, op2.pytype))
        return template % (op1.sql, sqlop, op2.sql)
354
355
356 # ---------------------------- BYTEA Adapters ---------------------------- #
357
358
class Pg_str_to_BYTEA(Pg_LIKE_Mixin, adapters.str_to_SQL92VARCHAR):
    """Python str to PostgreSQL bytea adapter.

    For the most part, Postgres bytea works like Python's str: a sequence
    of bytes. Certain bytes have to be octal-escaped for consumption by PG.

    See http://www.postgresql.org/docs/8.1/interactive/datatype-binary.html
    """

    def push(self, value, dbtype):
        """Return value as an octal-escaped "'...'::bytea" literal."""
        if value is None:
            return 'NULL'

        def repl(char):
            o = ord(char)
            # Control bytes, the single quote (39), the backslash (92)
            # and everything from DEL upward must be octal-escaped.
            if o <= 31 or o == 39 or o == 92 or o >= 127:
                # Format directly as three octal digits instead of
                # round-tripping through the text of oct().
                return r"\\%03o" % o
            return char
        return "'%s'::bytea" % "".join(map(repl, value))

    def pull(self, value, dbtype):
        """Return the database value with octal escapes undone."""
        if value is None:
            return None
        # Unescape octal sequences
        value = unescape_oct.sub(replace_unoct, value)
        # NOTE(review): this decodes to unicode although the adapter is
        # registered for str; confirm that is intended for binary data.
        return unicode(value, dbtype.encoding)

    def escape_like(self, sql):
        """Prepare a string value for use in a LIKE comparison."""
        # Quote stripping and wildcard escaping are shared with the mixin.
        sql = super(Pg_str_to_BYTEA, self).escape_like(sql)
        # BYTEA requires an additional set of backslashes for the RHS of LIKE
        return sql.replace("\\", "\\\\")
394
395
class Pg_unicode_to_BYTEA(Pg_unicode_to_VARCHAR):
    """Python unicode to PostgreSQL bytea adapter.

    For the most part, Postgres bytea works like Python's str: a sequence
    of bytes. Certain bytes have to be octal-escaped for consumption by PG.

    See http://www.postgresql.org/docs/8.1/interactive/datatype-binary.html
    """

    def push(self, value, dbtype):
        """Return value as an octal-escaped "'...'::bytea" literal.

        Non-str values are encoded with dbtype.encoding first.
        """
        # TODO STRABS: Can we reverse this translation?  Probably, but
        #              that still doesn't make it suitable for storing
        #              unicode text.
        if value is None:
            return 'NULL'
        if not isinstance(value, str):
            value = value.encode(dbtype.encoding)

        def repl(char):
            o = ord(char)
            # Control bytes, the single quote (39), the backslash (92)
            # and everything from DEL upward must be octal-escaped.
            if o <= 31 or o == 39 or o == 92 or o >= 127:
                # Format directly as three octal digits instead of
                # round-tripping through the text of oct().
                return r"\\%03o" % o
            return char
        return "'%s'::bytea" % "".join(map(repl, value))

    def escape_like(self, sql):
        """Prepare a string value for use in a LIKE comparison."""
        # Quote stripping and wildcard escaping come from the parent class.
        sql = super(Pg_unicode_to_BYTEA, self).escape_like(sql)
        # BYTEA requires an additional set of backslashes for the RHS of LIKE
        return sql.replace("\\", "\\\\")
429
430
class PgBYTEA_Pickler(PgPickler):
    """Python object to PostgreSQL bytea adapter.

    For the most part, Postgres bytea works like Python's str: a sequence
    of bytes. Certain bytes have to be octal-escaped for consumption by PG.

    See http://www.postgresql.org/docs/8.1/interactive/datatype-binary.html
    """

    def push(self, value, dbtype):
        """Return the pickle of value as an octal-escaped bytea literal."""
        if value is None:
            return 'NULL'

        value = pickle.dumps(value, 2)

        def repl(char):
            o = ord(char)
            # Control bytes, the single quote (39), the backslash (92)
            # and everything from DEL upward must be octal-escaped.
            if o <= 31 or o == 39 or o == 92 or o >= 127:
                # Format directly as three octal digits instead of
                # round-tripping through the text of oct().
                return r"\\%03o" % o
            return char
        return "'%s'::bytea" % "".join(map(repl, value))

    def escape_like(self, sql):
        """Prepare a string value for use in a LIKE comparison."""
        # Quote stripping and wildcard escaping come from the parent class.
        sql = super(PgBYTEA_Pickler, self).escape_like(sql)
        # BYTEA requires an additional set of backslashes for the RHS of LIKE
        return sql.replace("\\", "\\\\")
462
463
464 # ---------------------------- DatabaseTypes ---------------------------- #
465
466 # See http://www.postgresql.org/docs/8.1/static/datatype.html
467
468 # Not implemented here:
469 #
470 # box         rectangular box in the plane
471 # cidr        IPv4 or IPv6 network address
472 # circle      circle in the plane
473 # line        infinite line in the plane
474 # lseg        line segment in the plane
475 # macaddr     MAC address
476 # path        geometric path in the plane
477 # point       geometric point in the plane
478 # polygon     closed geometric path in the plane
479 # timetz      time of day, including time zone
480
481
class BOOLEAN(dbtypes.SQL99BOOLEAN):
    """Postgres logical Boolean (true/false) column type."""

    # Alternate type names that refer to this type.
    synonyms = ['BOOL']
485
486
class BYTEA(dbtypes.FrozenByteType):
    """Postgres type for binary data ("byte array")."""

    encoding = 'utf8'
    default_pytype = str

    # Adapter instance per Python type; the None key holds the pickling
    # adapter.
    default_adapters = {
        str: Pg_str_to_BYTEA(),
        unicode: Pg_unicode_to_BYTEA(),
        None: PgBYTEA_Pickler(),
    }
495
class BIT(dbtypes.SQL92VARCHAR):
    """Postgres fixed-length bit string type."""

    variable = False

    # Adapter instance per Python type; the None key holds the pickling
    # adapter.
    default_adapters = {
        str: Pg_str_to_VARCHAR(),
        unicode: Pg_unicode_to_VARCHAR(),
        None: PgPickler(),
    }
503
class VARBIT(dbtypes.SQL92VARCHAR):
    """Postgres variable-length bit string type."""

    variable = True
    synonyms = ['BIT VARYING']

    # Adapter instance per Python type; the None key holds the pickling
    # adapter.
    default_adapters = {
        str: Pg_str_to_VARCHAR(),
        unicode: Pg_unicode_to_VARCHAR(),
        None: PgPickler(),
    }
512
class CHAR(dbtypes.SQL92CHAR):
    """Postgres fixed-length character string type."""

    synonyms = ['CHARACTER', 'BPCHAR']

    # Adapter instance per Python type; the None key holds the pickling
    # adapter.
    default_adapters = {
        str: Pg_str_to_VARCHAR(),
        unicode: Pg_unicode_to_VARCHAR(),
        None: PgPickler(),
    }
520
class VARCHAR(dbtypes.SQL92VARCHAR):
    """Postgres variable-length character string type."""

    variable = True
    synonyms = ['CHARACTER VARYING']

    # http://www.postgresql.org/docs/8.0/static/datatype-character.html
    max_bytes = 2**30

    # Adapter instance per Python type; the None key holds the pickling
    # adapter.
    default_adapters = {
        str: Pg_str_to_VARCHAR(),
        unicode: Pg_unicode_to_VARCHAR(),
        None: PgPickler(),
    }
533
534
class ComparableInfinity(object):
    """A value that compares greater than anything but another infinity.

    Two instances of this class compare equal to each other (and hash
    alike); every other object compares as smaller.
    """

    def __cmp__(self, other):
        """Three-way comparison (0 for another infinity, else 1)."""
        if isinstance(other, self.__class__):
            return 0
        return 1

    # Rich comparisons give the same ordering wherever __cmp__ is not
    # consulted, and let reflected comparisons such as "5 < inf" work.

    def __eq__(self, other):
        return isinstance(other, self.__class__)

    def __ne__(self, other):
        return not isinstance(other, self.__class__)

    def __lt__(self, other):
        return False

    def __le__(self, other):
        return isinstance(other, self.__class__)

    def __gt__(self, other):
        return not isinstance(other, self.__class__)

    def __ge__(self, other):
        return True

    def __hash__(self):
        # Equal objects must hash equal.
        return hash(self.__class__)

    def __str__(self):
        return "Infinity"

    def __repr__(self):
        return "%s.%s()" % (self.__module__, self.__class__.__name__)
547
548
549
class TEXT(dbtypes.TEXT):
    """Postgres variable-length character string type."""

    # TEXT has no hard byte limit.
    _bytes = max_bytes = ComparableInfinity()

    # Start from a copy of the generic TEXT adapters and override the
    # string and pickling entries with the Postgres-specific ones.
    default_adapters = dbtypes.TEXT.default_adapters.copy()
    default_adapters.update({
        str: Pg_str_to_VARCHAR(),
        unicode: Pg_unicode_to_VARCHAR(),
        None: PgPickler(),
    })
560
class NAME(dbtypes.TEXT):
    """Postgres 63-character type for storing system identifiers."""

    _bytes = max_bytes = 63

    # Start from a copy of the generic TEXT adapters and override the
    # string and pickling entries with the Postgres-specific ones.
    default_adapters = dbtypes.TEXT.default_adapters.copy()
    default_adapters.update({
        str: Pg_str_to_VARCHAR(),
        unicode: Pg_unicode_to_VARCHAR(),
        None: PgPickler(),
    })
570
571
572 # Float types.
573 # "In addition to ordinary numeric values, the floating-point types
574 # have several special values: Infinity, -Infinity, NaN."
575
class FLOAT4(dbtypes.SQL92REAL):
    """Postgres single precision floating-point number type."""

    # Floats go through the quoting FLOAT4 adapter.
    default_adapters = {float: PgFLOAT4_Adapter()}
    synonyms = ['REAL']
580
class FLOAT8(dbtypes.SQL92DOUBLE):
    """Postgres double precision floating-point number type."""

    # Alternate type names that refer to this type.
    synonyms = ['DOUBLE PRECISION']
584
585
class INT2(dbtypes.SQL92SMALLINT):
    """Postgres signed two-byte integer type."""

    # Alternate type names that refer to this type.
    synonyms = ['SMALLINT']
589
class INT4(dbtypes.SQL92INTEGER):
    """Postgres signed four-byte integer type."""

    # SERIAL is listed as a synonym because, per the Postgres docs:
    # "The data types serial and bigserial are not true types, but merely
    # a notational convenience for setting up unique identifier columns
    # (similar to the AUTO_INCREMENT property supported by some other
    # databases)."
    synonyms = ['INT', 'INTEGER', 'SERIAL', 'SERIAL4']
598
class INT8(dbtypes.SQL92INTEGER):
    """Postgres eight-byte integer type (BIGINT and the bigserial names)."""

    _bytes = max_bytes = 8
    synonyms = ['BIGINT', 'BIGSERIAL', 'SERIAL8']

    # Both Python integer types get their own integer adapter constructed
    # with an argument of 8.
    default_adapters = {
        int: adapters.int_to_SQL92INTEGER(8),
        long: adapters.int_to_SQL92INTEGER(8),
    }
605
606
class TIMESTAMP(dbtypes.SQL92TIMESTAMP):
    """Postgres date-and-time type. Timezone naive."""

    # datetime values use the adapter that neither sends nor keeps offsets.
    default_adapters = {
        datetime.datetime: datetime_to_PgTIMESTAMPTZ_Adapter(),
    }
610
class TIMESTAMPTZ(dbtypes.SQL92TIMESTAMP):
    """Postgres date-and-time type. Timezone aware."""

    timezone_aware = True

    # datetime values use the adapter that sends and parses UTC offsets.
    default_adapters = {
        datetime.datetime: datetime_tz_to_PgTIMESTAMPTZ_Adapter(),
    }

    def ddl(self):
        """Return the type name to use in DDL statements."""
        return "TIMESTAMP WITH TIME ZONE"
618
class DATE(dbtypes.SQL92DATE):
    """Postgres calendar date (year, month, day) type."""

    # date values use the Postgres-specific date arithmetic adapter.
    default_adapters = {
        datetime.date: PgDATE_Adapter(),
    }
622
class TIME(dbtypes.SQL92TIME):
    """Postgres time-of-day type; adds nothing to the SQL92 base type."""
626
class INTERVAL(dbtypes.AdjustablePrecisionType):
    """Postgres time span type, mapped to datetime.timedelta."""

    default_pytype = datetime.timedelta
    default_adapters = {
        datetime.timedelta: PgINTERVAL_Adapter(),
    }
631
632
class DECIMAL(dbtypes.SQL92DECIMAL):
    """Postgres exact numeric type of selectable precision."""

    # From the Postgres documentation:
    # "In addition to ordinary numeric values, the numeric type allows the
    # special value NaN, meaning "not-a-number". Any operation on NaN yields
    # another NaN. When writing this value as a constant in a SQL command,
    # you must put quotes around it, for example UPDATE table SET x = 'NaN'.
    # On input, the string NaN is recognized in a case-insensitive manner."

    _precision = max_precision = 1000
    synonyms = ['NUMERIC']
644
645
class MONEY(dbtypes.FrozenPrecisionType):
    """Postgres currency amount type."""

    # Reuse whatever Python type the SQL92 DECIMAL type maps to.
    default_pytype = dbtypes.SQL92DECIMAL.default_pytype
649
650
class INET(dbtypes.FrozenByteType):
    """Postgres IPv4 or IPv6 host address type, optionally with a subnet."""

    # From the Postgres documentation:
    # "The inet type holds an IPv4 or IPv6 host address, and optionally
    # the identity of the subnet it is in, all in one field. The subnet
    # identity is represented by stating how many bits of the host address
    # represent the network address (the "netmask"). If the netmask is 32
    # and the address is IPv4, then the value does not indicate a subnet,
    # only a single host. In IPv6, the address length is 128 bits, so 128
    # bits specify a unique host address. Note that if you want to accept
    # networks only, you should use the cidr type rather than inet.
    #
    # The input format for this type is address/y where address is an IPv4
    # or IPv6 address and y is the number of bits in the netmask. If the /y
    # part is left off, then the netmask is 32 for IPv4 and 128 for IPv6,
    # so the value represents just a single host. On display, the /y
    # portion is suppressed if the netmask specifies a single host."

    encoding = 'utf8'
    variable = False
    default_pytype = str

    # Addresses are handled as plain strings by the generic adapters.
    default_adapters = {
        str: adapters.str_to_SQL92VARCHAR(),
        unicode: adapters.unicode_to_SQL92VARCHAR(),
        None: adapters.Pickler(),
    }
676
677
class PgTypeSet(dbtypes.DatabaseTypeSet):
    """The set of Postgres database types defined in this module."""

    # Category name -> list of type classes in that category.
    known_types = {
        'float': [FLOAT4, FLOAT8],
        'varchar': [TEXT, VARCHAR, VARBIT, BYTEA, NAME],
        'char': [CHAR, BIT],
        'int': [INT2, INT4, INT8],
        'bool': [BOOLEAN],
        'datetime': [TIMESTAMP, TIMESTAMPTZ],
        'date': [DATE],
        'time': [TIME],
        'timedelta': [INTERVAL],
        'numeric': [DECIMAL],
        'other': [MONEY, INET],
    }
692
693
694
class PgDeparser(deparse.SQLDeparser):
    """SQL deparser with Postgres-specific builtin function handlers."""

    def _localtime_offset_atoms(self):
        """Return (sign, "h:m") for the offset from adapters.localtime_offset."""
        neg, h, m = adapters.localtime_offset()
        if neg:
            sign = "-"
        else:
            sign = ""
        return sign, "%s:%s" % (h, m)

    def builtins_ieq(self, op1, op2):
        """Return a bool expression for a case-insensitive equality test."""
        # ILIKE with no wildcards should behave like ieq.
        sql = op1.adapter.like_op(op1, op2, ignore_case=True,
                                  start_only=True, end_only=True)
        return self.get_expr(sql, bool)

    def builtins_year(self, x):
        """Return an int expression for the year part of x."""
        sql = "date_part('year', " + x.sql + ")"
        return self.get_expr(sql, int)

    def builtins_month(self, x):
        """Return an int expression for the month part of x."""
        sql = "date_part('month', " + x.sql + ")"
        return self.get_expr(sql, int)

    def builtins_day(self, x):
        """Return an int expression for the day part of x."""
        sql = "date_part('day', " + x.sql + ")"
        return self.get_expr(sql, int)

    def builtins_now(self):
        """Return a datetime expression for NOW() shifted by the local offset."""
        sql = ("(NOW() AT TIME ZONE INTERVAL '%s%s')"
               % self._localtime_offset_atoms())
        return self.get_expr(sql, datetime.datetime)

    def builtins_utcnow(self):
        """Return a datetime expression for a bare NOW()."""
        return self.get_expr("NOW()", datetime.datetime)

    def builtins_today(self):
        """Return a date expression for the local NOW() truncated to the day."""
        sql = ("date_trunc('day', NOW() AT TIME ZONE INTERVAL '%s%s')"
               % self._localtime_offset_atoms())
        return self.get_expr(sql, datetime.date)
731
732
class PgIndexSet(geniusql.IndexSet):
    """Index collection that issues Postgres-style DROP INDEX statements."""

    def __delitem__(self, key):
        """Drop the specified index."""
        # PG doesn't use DROP INDEX .. ON ..
        statement = 'DROP INDEX %s;' % self[key].qname
        self.table.schema.db.execute_ddl(statement)
739
740
class PgTable(geniusql.Table):
    """Table whose qualified name is prefixed with its schema's qname."""

    implicit_pkey_indices = True

    def __init__(self, name, qname, schema, created=False, description=None):
        """Initialize the table, then schema-qualify its qname."""
        geniusql.Table.__init__(self, name=name, qname=qname, schema=schema,
                                created=created, description=description)
        self.qname = self.schema.qname + "." + self.qname

    def _grab_new_ids(self, idkeys, conn):
        """Return {idkey: value} read from each column's sequence.

        The values are fetched with currval() over the given connection.
        """
        newids = {}
        for idkey in idkeys:
            col = self[idkey]
            seq = self.schema.qname + "." + col.sequence_name
            # Using currval instead of "SELECT last_value FROM %s;"
            # avoids the need for permissions on the sequence.
            data, _ = self.schema.db.fetch("SELECT currval('%s');" % seq, conn)
            newids[idkey] = data[0][0]
        return newids

    def drop_primary(self):
        """Remove any PRIMARY KEY for this Table."""
        db = self.schema.db

        # Get the OID of the table. relname alone is not unique across
        # schemas, so restrict the lookup to this table's own namespace;
        # otherwise a same-named table elsewhere could be altered.
        data, _ = db.fetch("SELECT c.oid FROM pg_class c, pg_namespace n "
                           "WHERE c.relnamespace = n.oid AND "
                           "n.nspname = '%s' AND c.relname = '%s'"
                           % (self.schema.name, self.name))
        table_OID = data[0][0]

        # contype 'p' marks primary key constraints.
        data, _ = db.fetch("SELECT conname FROM pg_constraint WHERE conrelid "
                           "= %s AND contype = 'p'" % table_OID)
        for row in data:
            constraint_name = row[0]
            db.execute('ALTER TABLE %s DROP CONSTRAINT "%s";'
                       % (self.qname, constraint_name))
776
777
class PgSchema(geniusql.Schema):
    """A geniusql Schema which discovers its objects from the PostgreSQL
    system catalogs (pg_namespace, pg_class, pg_attribute, pg_index, ...).
    
    The schema name defaults to 'public'.
    """
    
    tableclass = PgTable
    indexsetclass = PgIndexSet
    
    # When False, tables whose names start with "pg_" are skipped by
    # _get_tables and rejected by _get_table.
    discover_pg_tables = False
    
    def __init__(self, db, name=None):
        """Bind this schema to `db`, using 'public' when no name is given."""
        if name is None:
            name = 'public'
        geniusql.Schema.__init__(self, db, name)
    
    # ------------------------- catalog helpers ------------------------- #
    
    def _pg_literal(self, value):
        """Return `value` as a single-quoted SQL string literal.
        
        Embedded single quotes are doubled so that names containing a
        quote cannot terminate the literal early.
        """
        return "'" + value.replace("'", "''") + "'"
    
    def _pg_class_oid(self, conn=None):
        """Return the OID of the pg_class catalog table itself.
        
        pg_description.classoid must equal this value for a comment
        to refer to a relation (table or view).
        """
        data, _ = self.db.fetch("SELECT oid FROM pg_class WHERE relname = "
                                "'pg_class' AND relkind = 'r'", conn=conn)
        return data[0][0]
    
    def _pg_namespace_oid(self, conn=None):
        """Return the pg_namespace OID for this schema.
        
        Raises errors.MappingError if no namespace has this schema's name.
        """
        data, _ = self.db.fetch(
            "SELECT oid FROM pg_namespace WHERE nspname = %s"
            % self._pg_literal(self.name), conn=conn)
        if not data:
            raise errors.MappingError("Schema %r not found." % self.name)
        return data[0][0]
    
    def _pg_table_oid(self, tablename, kinds, conn=None):
        """Return the pg_class OID of `tablename` within this schema.
        
        kinds: a sequence of pg_class.relkind codes to accept
            ('r' and 'v' are the codes used by this class).
        
        Raises errors.MappingError if no matching relation exists.
        """
        nsoid = self._pg_namespace_oid(conn)
        kindlist = ", ".join(["'%s'" % kind for kind in kinds])
        data, _ = self.db.fetch(
            "SELECT c.oid FROM pg_class c WHERE c.relnamespace = %s AND "
            "c.relname = %s AND c.relkind in (%s)"
            % (nsoid, self._pg_literal(tablename), kindlist), conn=conn)
        if not data:
            raise errors.MappingError("Table %r not found." % tablename)
        return data[0][0]
    
    # ---------------------------- discovery ---------------------------- #
    
    def _get_tables(self, conn=None):
        """Return a list of self.tableclass objects, one per table
        (relkind 'r') in this schema.
        
        Tables whose names start with "pg_" are omitted unless
        self.discover_pg_tables is True.
        """
        pgclass_OID = self._pg_class_oid(conn)
        nsoid = self._pg_namespace_oid(conn)
        
        # objsubid = 0 restricts the join to the comment on the table
        # itself. Column comments share the table's objoid/classoid and
        # would otherwise yield one extra (duplicate) row per comment.
        data, _ = self.db.fetch(
            "SELECT c.relname, d.description FROM pg_class c LEFT JOIN "
            "(SELECT description, objoid FROM pg_description WHERE "
            "classoid = %s AND objsubid = 0) AS d ON c.oid = d.objoid "
            "WHERE c.relnamespace = %s and c.relkind = 'r';"
            % (pgclass_OID, nsoid), conn=conn)
        return [self.tableclass(name, self.db.quote(name), self,
                                created=True, description=description)
                for name, description in data
                if self.discover_pg_tables or not name.startswith("pg_")]
    
    def _get_table(self, tablename, conn=None):
        """Return a table (or view) object for the given name.
        
        A relation of kind 'r' is wrapped in self.tableclass; one of
        kind 'v' is wrapped in self.viewclass. The object's description
        is set from pg_description when a comment exists.
        
        Raises errors.MappingError if the table is not found, or if it
        is a "pg_*" table and self.discover_pg_tables is False.
        """
        if (not self.discover_pg_tables) and tablename.startswith("pg_"):
            raise errors.MappingError(
                "Table %r not found. Set schema.discover_pg_tables to True "
                "if you want to discover Postgres system tables (pg_*)." %
                tablename)
        
        pgclass_OID = self._pg_class_oid(conn)
        nsoid = self._pg_namespace_oid(conn)
        
        data, _ = self.db.fetch(
            "SELECT c.oid, c.relname, c.relkind FROM pg_class c WHERE "
            "c.relnamespace = %s AND c.relname = %s AND "
            "c.relkind in ('r', 'v')"
            % (nsoid, self._pg_literal(tablename)), conn=conn)
        for table_OID, name, kind in data:
            if name == tablename:
                if kind == 'r':
                    t = self.tableclass(name, self.db.quote(name),
                                        self, created=True)
                else:
                    t = self.viewclass(name, self.db.quote(name),
                                       self, created=True)
                
                # Get the description of the table, if any. objsubid = 0
                # excludes comments attached to the table's columns.
                data, _ = self.db.fetch(
                    "SELECT description FROM pg_description "
                    "WHERE objoid = %s and classoid = %s AND objsubid = 0"
                    % (table_OID, pgclass_OID), conn=conn)
                for cell, in data:
                    t.description = cell
                    break
                
                return t
        raise errors.MappingError("Table %r not found." % tablename)
    
    def _get_columns(self, table, conn=None):
        """Return a list of geniusql.Column objects for the given table.
        
        For each non-dropped, non-system attribute this sets the column's
        dbtype (with precision/scale/bytes where the catalog supplies
        them), its adapter, whether it is part of the primary key, and
        its default (or autoincrement/sequence data for nextval defaults).
        
        Raises errors.MappingError if the table (or view) is not found,
        KeyError if a type name is unknown to the database's typeset, and
        ValueError if a VARCHAR column reports a non-positive size.
        """
        import sys
        
        table_OID = self._pg_table_oid(table.name, ('r', 'v'), conn)
        
        # Get index data so we can set col.key if pg_index.indisprimary
        data, _ = self.db.fetch(
            "SELECT indkey FROM pg_index WHERE indrelid = %s AND indisprimary"
            % table_OID, conn=conn)
        if data:
            # indkey is an "array" (we get a space-separated string of ints).
            # These will equal pg_attribute.attnum, below.
            indices = [int(n) for n in data[0][0].split(" ")]
        else:
            indices = []
        
        # Get column data
        sql = ("SELECT attname, atttypid, attnum, attlen, atttypmod "
               "FROM pg_attribute WHERE attisdropped = False AND "
               "attrelid = %s" % table_OID)
        data, _ = self.db.fetch(sql, conn=conn)
        cols = []
        typeset = self.db.typeset
        for row in data:
            name = row[0]
            if name in ('tableoid', 'cmax', 'xmax', 'cmin', 'xmin',
                        'oid', 'ctid'):
                # This is a column which PostgreSQL defines automatically
                continue
            
            # Data type
            dbtype, _ = self.db.fetch("SELECT typname, typlen FROM pg_type "
                                      "WHERE oid = %s" % row[1], conn=conn)
            try:
                dbtypetype = typeset.canonicalize(dbtype[0][0].upper())
            except KeyError:
                # Tag the error with "table.column" so the offending
                # column can be identified. sys.exc_info is used so this
                # parses on any Python version.
                x = sys.exc_info()[1]
                x.args += ("%s.%s" % (table.name, name),)
                raise
            dbtype = dbtypetype()
            
            c = geniusql.Column(dbtype.default_pytype, dbtype,
                                None, key=row[2] in indices,
                                name=row[0], qname=self.db.quote(row[0]))
            c.adapter = dbtype.default_adapter(c.pytype)
            
            if dbtypetype in (FLOAT4, FLOAT8):
                dbtype.precision = int(row[3])
            elif dbtypetype in (MONEY, DECIMAL):
                # atttypmod packs precision in the high 16 bits and
                # (scale + 4) in the low 16 bits.
                dbtype.precision = int((row[4] >> 16) & 65535)
                dbtype.scale = int((row[4] & 65535) - 4)
            
            if dbtypetype is VARCHAR:
                # See http://archives.postgresql.org/pgsql-interfaces/2004-07/msg00021.php
                nbytes = int(row[4] - 4)
                if nbytes > 0:
                    dbtype.bytes = nbytes
                else:
                    raise ValueError("Column %r has illegal size %r"
                                     % (name, nbytes))
            else:
                nbytes = int(row[3])
                if nbytes > 0:
                    dbtype.bytes = nbytes
            
            # Default value
            default, _ = self.db.fetch(
                "SELECT adsrc FROM pg_attrdef WHERE adnum = %s AND adrelid = %s"
                % (row[2], table_OID), conn=conn)
            if default:
                default = default[0][0]
                if default.startswith("nextval("):
                    # Grab seqname from "nextval('seqname'::[text|regclass])"
                    c.autoincrement = True
                    sname = seq_name.search(default).group(1)
                    if (sname.startswith(self.name + ".") or
                        sname.startswith(self.qname + ".")):
                        sname = sname.split(".", 1)[1]
                    # Don't stick the schema name into c.sequence_name...
                    c.sequence_name = sname
                    # ...but do use the schema name to get min_value
                    sqname = self.qname + "." + sname
                    c.initial = self.db.fetch("SELECT min_value FROM %s" %
                                              sqname, conn=conn)[0][0][0]
                    c.default = None
                else:
                    # adsrc is always a string, so we must cast it using
                    # our guessed type. Be sure to strip any ::typename
                    defval = default.split("::", 1)[0]
                    try:
                        # String defaults have quotes we need to strip
                        defval = defval.strip("'")
                        c.default = c.adapter.pull(defval, c.dbtype)
                    except ValueError:
                        # The default is probably a function like 'now()'.
                        # Keep the whole unmunged string for now.
                        # TODO: set default to an equivalent lambda?
                        c.default = default
            else:
                c.default = None
            
            cols.append(c)
        return cols
    
    def _get_indices(self, table, conn=None):
        """Return a list of geniusql.Index objects for the given table.
        
        One Index is produced per key position of each pg_index row, so a
        multi-column index yields several Index objects sharing one name.
        A key position with no matching pg_attribute row is reported with
        the column name "<unknown>".
        
        Raises errors.MappingError if no table (relkind 'r') of that name
        exists in this schema.
        """
        table_OID = self._pg_table_oid(table.name, ('r',), conn)
        
        indices = []
        data, _ = self.db.fetch(
            "SELECT pg_class.relname, indkey, indisprimary, "
            "indisunique FROM pg_index LEFT JOIN pg_class "
            "ON pg_index.indexrelid = pg_class.oid WHERE "
            "pg_index.indrelid = %s" % table_OID, conn=conn)
        for row in data:
            iname = row[0]
            q_iname = self.db.quote(iname)
            uniq = bool(row[3])
            # indkey is an "array" (we get a space-separated string of ints).
            for col in [int(n) for n in row[1].split(" ")]:
                d, _ = self.db.fetch("SELECT attname FROM pg_attribute "
                                     "WHERE attrelid = %s AND attnum = %s"
                                     % (table_OID, col), conn=conn)
                if d:
                    attname = d[0][0]
                else:
                    # This is probably an index that was added by hand,
                    # without reference to a single existing column.
                    attname = "<unknown>"
                indices.append(geniusql.Index(iname, q_iname, table.name,
                                              attname, uniq))
        
        return indices
    
    # ------------------------------- DDL ------------------------------- #
    
    def columnclause(self, column):
        """Return a clause for the given column for CREATE or ALTER TABLE.
        
        This will be of the form "name type [DEFAULT [x | nextval('seq')]]".
        
        PostgreSQL creates the sequence in a separate statement.
        """
        if column.autoincrement:
            default = "nextval('%s.%s')" % (self.qname, column.sequence_name)
        else:
            default = column.default
            # Only None means "no default". Falsy values such as 0 or
            # False are legitimate defaults and must reach the adapter.
            if default is None:
                default = ""
            if isinstance(default, str):
                # A str default on a non-string column is emitted verbatim
                # (e.g. a function call such as now()); on a string column
                # it is pushed through the adapter like any other value.
                if issubclass(column.pytype, basestring):
                    default = column.adapter.push(default, column.dbtype)
            else:
                default = column.adapter.push(default, column.dbtype)
        
        if default:
            default = " DEFAULT %s" % default
        
        return '%s %s%s' % (column.qname, column.dbtype.ddl(), default)
    
    def sequence_name(self, tablename, columnkey):
        "Return the SQL sequence name for the given table name and column key."
        # If you want to use a map from your ORM's property names
        # to DB sequence names, override this method (that's why
        # the tablename must be included in the args).
        sname = "%s_%s_seq" % (tablename, columnkey)
        maxlen = self.db.sql_name_max_length
        if maxlen and len(sname) > maxlen:
            # Postgres (8.2 anyway) seems to truncate the table name to fit.
            suffix = "_%s_seq" % columnkey
            # Clamp at zero: a negative bound would slice from the END of
            # tablename and make the result longer instead of shorter.
            keep = max(maxlen - len(suffix), 0)
            sname = tablename[:keep] + suffix
        return self.db.sql_name(sname)
    
    def index_name(self, table, columnkey):
        """Return the SQL index name for the given table and column key."""
        col = table[columnkey]
        if col.key:
            # NOTE(review): this builds "<column>_pkey", whereas PostgreSQL's
            # own default primary-key index name is "<table>_pkey". Confirm
            # which name the table-creation code actually produces.
            return self.db.sql_name("%s_pkey" % col.name)
        else:
            return self.db.sql_name("%s_%s_idx" % (table.name, col.name))
    
    def create_sequence(self, table, column):
        """Create a SEQUENCE for the given column.
        
        Does nothing if column.sequence_name is None. The sequence is
        created in this schema and starts at column.initial.
        """
        if column.sequence_name is not None:
            self.db.execute_ddl("CREATE SEQUENCE %s.%s START %s;" %
                                (self.qname, column.sequence_name,
                                 column.initial))
    
    def drop_sequence(self, column):
        """Drop a SEQUENCE for the given column.
        
        Does nothing if column.sequence_name is None.
        """
        if column.sequence_name is not None:
            self.db.execute_ddl("DROP SEQUENCE %s.%s;" %
                                (self.qname, column.sequence_name))
    
    def create(self):
        """Create this schema in the database, then clear this object.
        
        No CREATE SCHEMA statement is issued for the 'public' schema.
        """
        if self.name != "public":
            self.db.execute_ddl("CREATE SCHEMA %s" % self.qname)
        self.clear()
    
    def drop(self, restrict=False):
        """Drop this schema (and any contained objects) from the database.
        
        WARNING: This method's default is to drop any objects owned by the
        schema using the CASCADE parameter to DROP SCHEMA. This is contrary
        to the PostgreSQL default! If you wish to drop with the RESTRICT
        parameter instead, set the 'restrict' argument to True.
        
        No DROP SCHEMA statement is issued for the 'public' schema.
        """
        if self.name != "public":
            if restrict:
                mode = 'RESTRICT'
            else:
                mode = 'CASCADE'
            self.db.execute_ddl("DROP SCHEMA %s %s;" % (self.qname, mode))
        self.clear()
1075
1076
class PgDatabase(geniusql.Database):
    """A geniusql Database for PostgreSQL."""
    
    sql_name_max_length = 63
    # When True, quote() wraps every identifier in double quotes.
    # When False, sql_name() lowercases names instead.
    quote_all = True
    poolsize = 10
    # Encoding passed to CREATE DATABASE; a false value omits the clause.
    encoding = 'UTF8'
    
    deparser = PgDeparser
    schemaclass = PgSchema
    typeset = PgTypeSet()
    
    def quote(self, name):
        """Return `name` as a double-quoted identifier if self.quote_all.
        
        Embedded double quotes are doubled. When quote_all is False the
        name is returned unchanged.
        """
        if self.quote_all:
            name = '"' + name.replace('"', '""') + '"'
        return name
    
    def sql_name(self, name):
        """Return the SQL name for `name`, lowercased unless quote_all."""
        name = geniusql.Database.sql_name(self, name)
        if not self.quote_all:
            name = name.lower()
        return name
    
    def schema(self, name="public"):
        """Return a new self.schemaclass instance for the named schema."""
        return self.schemaclass(self, name)
    
    def create(self):
        """Issue CREATE DATABASE for this database on a master connection.
        
        The connection is released even if the statement fails.
        """
        encoding = self.encoding
        if encoding:
            encoding = " WITH ENCODING '%s'" % encoding
        else:
            # A None encoding must not be rendered as the text "None".
            encoding = ""
        
        c = self.connections._get_conn(master=True)
        try:
            self.execute_ddl("CREATE DATABASE %s%s" % (self.qname, encoding),
                             c)
        finally:
            self.connections._del_conn(c)
    
    def exists(self):
        """Return True if this database exists, False otherwise."""
        # Double any single quotes so the name cannot end the literal early.
        datname = self.sql_name(self.name).replace("'", "''")
        c = self.connections._get_conn(master=True)
        try:
            data, _ = self.fetch("SELECT datname FROM pg_database "
                                 "WHERE datname = '%s';" % datname,
                                 conn=c)
        finally:
            self.connections._del_conn(c)
        return bool(data)
    
    def drop(self):
        """Issue DROP DATABASE for this database on a master connection.
        
        The connection is released even if the statement fails.
        """
        c = self.connections._get_conn(master=True)
        try:
            self.execute_ddl("DROP DATABASE %s;" % self.qname, c)
        finally:
            self.connections._del_conn(c)
    
    def _get_schemas(self, conn=None):
        """Return a list of schema names.
        
        'information_schema' and any namespace whose name starts with
        'pg_' are excluded.
        """
        data, _ = self.fetch("SELECT nspname FROM pg_namespace;", conn=conn)
        return [name for name, in data if name != 'information_schema'
                and not name.startswith('pg_')]
1129
1130
Note: See TracBrowser for help on using the browser.