| 1 |
import datetime |
|---|
| 2 |
try: |
|---|
| 3 |
import cPickle as pickle |
|---|
| 4 |
except ImportError: |
|---|
| 5 |
import pickle |
|---|
| 6 |
import re |
|---|
| 7 |
# Captures the quoted argument of a nextval('...') call, e.g. the
# "my_seq" in "nextval('my_seq'::regclass)".
seq_name = re.compile(r"nextval\('([^:]+)'.*\)")

# Characters that get octal-escaped: 0-31 and 127-255.
escape_oct = re.compile(r"[\000-\037\177-\377]")


def replace_oct(m):
    """Return a doubled backslash plus the 3-digit octal code of the match."""
    return r"\\%03o" % ord(m.group(0))


# Matches a doubled backslash followed by three digits (the form that
# replace_oct produces).
unescape_oct = re.compile(r"\\\\(\d\d\d)")


def replace_unoct(m):
    """Return the single character named by the octal digits in the match."""
    return chr(int(m.group(1), 8))
|---|
| 12 |
import threading |
|---|
| 13 |
|
|---|
| 14 |
import geniusql |
|---|
| 15 |
from geniusql import adapters, dbtypes, deparse, errors |
|---|
| 16 |
|
|---|
| 17 |
|
|---|
| 18 |
|
|---|
| 19 |
|
|---|
| 20 |
|
|---|
| 21 |
class PgDATE_Adapter(adapters.date_to_SQL92DATE):
    """DATE adapter that spells date arithmetic the PostgreSQL way."""

    def binary_op(self, op1, op, sqlop, op2):
        """Return SQL for 'op1 <sqlop> op2' where op1 is a date.

        A datetime.date right operand casts both sides to TIMESTAMP; a
        datetime.timedelta right operand is wrapped in date_trunc('day', ...).
        Any other right-hand pytype raises TypeError.
        """
        rhs_type = op2.pytype
        if rhs_type is datetime.date:
            return "(%s::TIMESTAMP %s %s::TIMESTAMP)" % (op1.sql, sqlop, op2.sql)
        if rhs_type is datetime.timedelta:
            return "(%s %s date_trunc('day', %s))" % (op1.sql, sqlop, op2.sql)
        raise TypeError("unsupported operand type(s) for %s: "
                        "%r and %r" % (op, op1.pytype, op2.pytype))
|---|
| 33 |
|
|---|
| 34 |
|
|---|
| 35 |
class datetime_to_PgTIMESTAMPTZ_Adapter(adapters.Adapter):
    """Adapter for timezone-naive datetime objects.

    Postgres stores all timestamps as UTC internally, adjusting inbound
    values based on their timezone component, and offsetting outbound
    values relative to the Postgres 'timezone' config entry.

    This adapter assumes you always want to push and pull datetime objects
    that have no timezone. Therefore, it doesn't supply a timezone atom
    when sending datetime data to Postgres. When retrieving values, any
    offset is silently ignored. That is, in both cases, we assume you
    correctly set the connection's timezone attribute.
    """

    def push(self, value, dbtype):
        """Return a quoted 'YYYY-MM-DD HH:MM:SS[.ffffff]' literal, or NULL.

        The fractional part is emitted only when value.microsecond is
        non-zero, so whole-second values yield the same literal as before.
        """
        if value is None:
            return 'NULL'
        literal = ("%04d-%02d-%02d %02d:%02d:%02d" %
                   (value.year, value.month, value.day,
                    value.hour, value.minute, value.second))
        if value.microsecond:
            # Keep sub-second precision instead of silently truncating it.
            literal += ".%06d" % value.microsecond
        return "'" + literal + "'"

    def pull(self, value, dbtype):
        """Return a naive datetime for 'value'.

        'value' may be None (returns None), a datetime (returned unchanged),
        or a string 'YYYY-MM-DD HH:MM:SS[.fraction][+/-offset]'. Any
        offset is discarded.
        """
        if value is None:
            return None
        if isinstance(value, datetime.datetime):
            return value
        chunks = (value[0:4], value[5:7], value[8:10],
                  value[11:13], value[14:16], value[17:19])
        args = [int(chunk) for chunk in chunks]

        # Everything after the seconds: an optional ".fraction" followed by
        # an optional "+HH[:MM]" / "-HH[:MM]" offset.
        fraction = value[19:]
        signpos = fraction.find("+")
        if signpos == -1:
            signpos = fraction.find("-")
        if signpos != -1:
            fraction = fraction[:signpos]

        fraction = fraction.strip(".")
        if fraction:
            # The digits are a decimal fraction of a second and may be
            # shorter than six digits (".5" is 500000 microseconds), so
            # right-pad (and truncate) to exactly six digits.
            args.append(int((fraction + "000000")[:6]))
        else:
            args.append(0)

        return datetime.datetime(*args)
|---|
| 84 |
|
|---|
| 85 |
|
|---|
| 86 |
class datetime_tz_to_PgTIMESTAMPTZ_Adapter(adapters.Adapter):
    """Adapter for timezone-aware datetime objects.

    Postgres stores all timestamps as UTC internally, adjusting inbound
    values based on their timezone component, and offsetting outbound
    values relative to the Postgres 'timezone' config entry.

    This adapter assumes you always want to push and pull datetime objects
    that have a valid tzinfo. Therefore, it always tries to supply a
    timezone atom when sending datetime data to Postgres. If you push
    a datetime with a tzinfo of None (or whose utcoffset() is None),
    "+00:00" is used for the timezone.
    When retrieving values, any offset in the database value is used to
    form a valid tzinfo object for the value (see dbtypes.FixedTimeZone).
    In both directions, we assume you correctly set the connection's
    timezone attribute.
    """

    def push(self, value, dbtype):
        """Return a TIMESTAMP WITH TIME ZONE literal for 'value', or NULL."""
        if value is None:
            return 'NULL'

        offset = None
        if value.tzinfo is not None:
            offset = value.tzinfo.utcoffset(value)
        if offset is None:
            minutes = 0
        else:
            # timedelta normalizes negatives as (negative days, positive
            # seconds), so this yields a correctly signed minute count.
            minutes = (offset.days * 1440) + (offset.seconds // 60)

        # Split the absolute value so that e.g. -330 becomes "-05:30"
        # (floor-dividing the signed value would give -6 hours).
        if minutes < 0:
            sign = "-"
        else:
            sign = "+"
        h, m = divmod(abs(minutes), 60)

        return ("TIMESTAMP WITH TIME ZONE "
                "'%04d-%02d-%02d %02d:%02d:%02d.%06d%s%02d:%02d'" %
                (value.year, value.month, value.day,
                 value.hour, value.minute, value.second, value.microsecond,
                 sign, h, m))

    def pull(self, value, dbtype):
        """Return an aware datetime for 'value'.

        'value' may be None (returns None), a datetime (returned unchanged),
        or a string 'YYYY-MM-DD HH:MM:SS[.fraction][+/-HH[:MM[:SS]]]'. The
        offset, if present, becomes a dbtypes.FixedTimeZone built from the
        signed offset in minutes; without one the tzinfo is None.
        """
        if value is None:
            return None
        if isinstance(value, datetime.datetime):
            return value
        chunks = (value[0:4], value[5:7], value[8:10],
                  value[11:13], value[14:16], value[17:19])
        args = [int(chunk) for chunk in chunks]

        fraction, tz = "", ""
        mstz = value[19:]
        if mstz:
            signpos = mstz.find("+")
            if signpos == -1:
                signpos = mstz.find("-")
            if signpos != -1:
                fraction, tz = mstz[:signpos], mstz[signpos:]
            else:
                fraction = mstz

        fraction = fraction.strip(".")
        if fraction:
            # Decimal fraction of a second, possibly shorter than six
            # digits (".5" is 500000 microseconds): pad/truncate to six.
            microsecond = int((fraction + "000000")[:6])
        else:
            microsecond = 0

        tzinfo = None
        if tz:
            # tz starts with the sign character found above.
            if tz.startswith("-"):
                tz_sign = -1
            else:
                tz_sign = 1
            parts = tz[1:].split(":")
            offset_minutes = int(parts[0]) * 60
            if len(parts) > 1:
                offset_minutes += int(parts[1])
            # Apply the sign to hours AND minutes: "-05:30" is -330.
            tzinfo = dbtypes.FixedTimeZone(tz_sign * offset_minutes)

        args.append(microsecond)
        args.append(tzinfo)
        return datetime.datetime(*args)
|---|
| 162 |
|
|---|
| 163 |
|
|---|
| 164 |
class PgINTERVAL_Adapter(adapters.Adapter):
    """Adapter between datetime.timedelta and PostgreSQL INTERVAL."""

    def push(self, value, dbtype):
        """Return "interval 'D H:M:S[.ffffff]'" for 'value', or NULL.

        The fractional seconds are emitted only when value.microseconds
        is non-zero, so whole-second values yield the same literal as before.
        """
        if value is None:
            return 'NULL'

        h, m = divmod(value.seconds, 3600)
        m, s = divmod(m, 60)
        if value.microseconds:
            # Keep sub-second precision instead of silently dropping it.
            s = "%d.%06d" % (s, value.microseconds)
        return "interval '%s %s:%s:%s'" % (value.days, h, m, s)

    def pull(self, value, dbtype):
        """Return a timedelta for 'value'.

        'value' may be None (returns None), a timedelta (returned
        unchanged), or a string such as "3 days", "1 day 02:00:00",
        "-01:30:00" or "-1 days +23:00:00".
        """
        if value is None:
            return None
        if isinstance(value, datetime.timedelta):
            return value

        days = 0
        atoms = re.split(r"( ?days? ?)", value)
        hms = atoms.pop()
        if atoms:
            # atoms[0] is the (possibly signed) day count.
            days = int(atoms[0])
            if not hms:
                return datetime.timedelta(days)

        h, m, s = hms.split(":", 2)
        if h.startswith("-"):
            # A leading minus applies to the whole H:M:S part.
            neg = True
            h = abs(int(h))
        else:
            neg = False
            h = int(h)
        s = (h * 3600) + (int(m) * 60) + float(s)
        if neg:
            s = -s

        return datetime.timedelta(days, s)

    def binary_op(self, op1, op, sqlop, op2):
        """Return SQL for 'op1 <sqlop> op2' where op1 is an interval.

        A datetime.date right operand is combined with date_trunc('day',
        op1); datetime.datetime and datetime.timedelta operands are used
        as-is. Any other right-hand pytype raises TypeError.
        """
        if op2.pytype is datetime.date:
            return "(date_trunc('day', %s) %s %s)" % (op1.sql, sqlop, op2.sql)
        elif op2.pytype in (datetime.datetime, datetime.timedelta):
            return "(%s %s %s)" % (op1.sql, sqlop, op2.sql)
        raise TypeError("unsupported operand type(s) for %s: "
                        "%r and %r" % (op, op1.pytype, op2.pytype))
|---|
| 222 |
|
|---|
| 223 |
|
|---|
| 224 |
class Pg_LIKE_Mixin(object):
    """Mixin that builds PostgreSQL LIKE / ILIKE SQL for adapters."""

    # (pattern, replacement) pairs applied to the operand so the LIKE
    # wildcards '%' and '_' match literally.
    like_escapes = [("%", r"\%"), ("_", r"\_")]

    def escape_like(self, sql):
        """Prepare a string value for use in a LIKE comparison.

        Removes at most ONE enclosing quote character (' or ") from each
        end of 'sql', then applies like_escapes. Only one quote is removed
        per end so that doubled (escaped) quotes at the edges of the value
        itself survive; e.g. "'it''s'''" becomes "it''s''".
        """
        if sql[:1] in ("'", '"'):
            sql = sql[1:]
        if sql[-1:] in ("'", '"'):
            sql = sql[:-1]
        for pat, repl in self.like_escapes:
            sql = sql.replace(pat, repl)
        return sql

    def like_op(self, op1, op2, ignore_case=False,
                start_only=False, end_only=False):
        """Return the SQL for 'op1 LIKE op2' (or raise TypeError).

        op1 and op2 will be SQLExpression objects.

        If 'ignore_case' is False (the default), then the LIKE comparison
        will be performed in a case-sensitive manner; otherwise (if
        ignore_case is True), the LIKE comparison will be performed in
        a case-INsensitive manner (ILIKE).

        If 'start_only' is True, then op2 will be matched only at the start
        of op1. If False (the default), then op2 will be matched anywhere.

        If 'end_only' is True, then op2 will be matched only at the end
        of op1. If False (the default), then op2 will be matched anywhere.

        If both 'start_only' and 'end_only' are True, then op2 will only
        match op1 if they are identical.
        """
        likeexpr = self.escape_like(op2.sql)
        if start_only:
            start = ''
        else:
            start = '%'
        if end_only:
            end = ''
        else:
            end = '%'
        if ignore_case:
            keyword = " ILIKE '"
        else:
            keyword = " LIKE '"
        return op1.sql + keyword + start + likeexpr + end + "'"
|---|
| 270 |
|
|---|
| 271 |
|
|---|
| 272 |
class Pg_str_to_VARCHAR(Pg_LIKE_Mixin, adapters.str_to_SQL92VARCHAR):
    """Adapter between Python str and PostgreSQL character columns."""

    def push(self, value, dbtype):
        """Return 'value' as a single-quoted, escaped SQL literal, or NULL."""
        if value is None:
            return 'NULL'
        text = value
        if not isinstance(text, str):
            text = text.encode(dbtype.encoding)
        for pattern, replacement in self.escapes:
            text = text.replace(pattern, replacement)
        return "".join(("'", text, "'"))

    def pull(self, value, dbtype):
        """Return the database value unchanged (None stays None)."""
        return value
|---|
| 287 |
|
|---|
| 288 |
|
|---|
| 289 |
class Pg_unicode_to_VARCHAR(Pg_LIKE_Mixin, adapters.unicode_to_SQL92VARCHAR):
    """Adapter between Python unicode and PostgreSQL character columns."""

    def push(self, value, dbtype):
        """Return 'value' encoded with dbtype.encoding, escaped and quoted.

        Returns 'NULL' for None.
        """
        if value is None:
            return 'NULL'
        if not isinstance(value, str):
            value = value.encode(dbtype.encoding)
        for pat, repl in self.escapes:
            value = value.replace(pat, repl)
        return "'" + value + "'"

    def pull(self, value, dbtype):
        """Return 'value' as unicode, decoding bytes with dbtype.encoding.

        None is returned as None; buffer objects are converted to str before
        decoding; values that are already unicode are returned unchanged.
        """
        if value is None:
            return None
        if isinstance(value, unicode):
            # Already decoded (e.g. by the driver); unicode(u"...", enc)
            # would raise TypeError.
            return value
        if isinstance(value, buffer):
            value = str(value)

        return unicode(value, dbtype.encoding)
|---|
| 307 |
|
|---|
| 308 |
|
|---|
| 309 |
class PgPickler(Pg_LIKE_Mixin, adapters.Pickler):
    """Store arbitrary Python objects as protocol-2 pickles in text columns.

    push() applies self.escapes and then octal-escapes control and high
    characters; pull() undoes those two steps in reverse order before
    unpickling.
    """

    def push(self, value, dbtype):
        """Return the escaped, quoted pickle of 'value', or NULL for None."""
        if value is None:
            return 'NULL'
        data = pickle.dumps(value, 2)

        if not isinstance(data, str):
            data = data.encode(dbtype.encoding)
        for pattern, replacement in self.escapes:
            data = data.replace(pattern, replacement)

        # Octal escaping runs last so that pull() can undo it first.
        return "'" + escape_oct.sub(replace_oct, data) + "'"

    def pull(self, value, dbtype):
        """Return the object unpickled from the database 'value'.

        NOTE(review): pickle.loads can execute arbitrary code; anyone who
        can write to this column can run code in this process.
        """
        if value is None:
            return None

        data = unescape_oct.sub(replace_unoct, value)
        for pattern, replacement in self.escapes:
            data = data.replace(replacement, pattern)
        return pickle.loads(data)
|---|
| 333 |
|
|---|
| 334 |
|
|---|
| 335 |
|
|---|
| 336 |
class PgFLOAT4_Adapter(adapters.float_to_SQL92REAL):
    """Adapter between Python float and PostgreSQL FLOAT4 (REAL)."""

    def push(self, value, dbtype):
        """Return repr(value) wrapped in single quotes, or NULL for None."""
        if value is None:
            return 'NULL'
        return "'" + repr(value) + "'"

    def compare_op(self, op1, op, sqlop, op2):
        """Return SQL comparing a FLOAT4 op1 with op2.

        A FLOAT8 right operand is cast to FLOAT4 so both sides compare at
        single precision; INT2/INT4/INT8/FLOAT4 operands are used as-is.
        Anything else raises TypeError.
        """
        rhs = op2.dbtype
        if isinstance(rhs, FLOAT8):
            return "(%s %s (%s)::FLOAT4)" % (op1.sql, sqlop, op2.sql)
        if isinstance(rhs, (INT2, INT4, INT8, FLOAT4)):
            return "(%s %s %s)" % (op1.sql, sqlop, op2.sql)
        raise TypeError("unsupported operand type(s) for %s: "
                        "%r and %r" % (op, op1.pytype, op2.pytype))
|---|
| 354 |
|
|---|
| 355 |
|
|---|
| 356 |
|
|---|
| 357 |
|
|---|
| 358 |
|
|---|
| 359 |
class Pg_str_to_BYTEA(Pg_LIKE_Mixin, adapters.str_to_SQL92VARCHAR):
    """Python str to PostgreSQL bytea adapter.

    For the most part, Postgres bytea works like Python's str: a sequence
    of bytes. Certain bytes have to be octal-escaped for consumption by PG.

    See http://www.postgresql.org/docs/8.1/interactive/datatype-binary.html
    """

    def push(self, value, dbtype):
        """Return "'<escaped bytes>'::bytea" for 'value', or NULL."""
        if value is None:
            return 'NULL'

        def repl(char):
            o = ord(char)
            # Control chars, quote (39), backslash (92) and bytes >= 127
            # are written as a doubled backslash plus 3 octal digits.
            if o <= 31 or o == 39 or o == 92 or o >= 127:
                return r"\\%03o" % o
            return char
        return "'%s'::bytea" % "".join(map(repl, value))

    def pull(self, value, dbtype):
        """Return the octal-unescaped bytes of 'value' as a str.

        This is the str adapter, so the raw bytes are returned rather than
        decoded to unicode; decoding arbitrary binary data as
        dbtype.encoding would fail for non-UTF-8 content.
        """
        if value is None:
            return None

        return unescape_oct.sub(replace_unoct, value)

    def escape_like(self, sql):
        """Prepare a bytea literal for use in a LIKE comparison.

        Drops a trailing "::bytea" cast (as produced by push()), removes at
        most one enclosing quote from each end, escapes the LIKE wildcards,
        then doubles every backslash.
        """
        if sql.endswith("::bytea"):
            sql = sql[:-len("::bytea")]
        if sql[:1] in ("'", '"'):
            sql = sql[1:]
        if sql[-1:] in ("'", '"'):
            sql = sql[:-1]
        for pat, repl in self.like_escapes:
            sql = sql.replace(pat, repl)

        # Doubling includes the backslashes just added by like_escapes.
        sql = sql.replace("\\", "\\\\")
        return sql
|---|
| 394 |
|
|---|
| 395 |
|
|---|
| 396 |
class Pg_unicode_to_BYTEA(Pg_unicode_to_VARCHAR):
    """Python unicode to PostgreSQL bytea adapter.

    For the most part, Postgres bytea works like Python's str: a sequence
    of bytes. Certain bytes have to be octal-escaped for consumption by PG.

    See http://www.postgresql.org/docs/8.1/interactive/datatype-binary.html
    """

    def push(self, value, dbtype):
        """Encode 'value' with dbtype.encoding and return a bytea literal.

        Returns 'NULL' for None.
        """
        if value is None:
            return 'NULL'
        if not isinstance(value, str):
            value = value.encode(dbtype.encoding)

        def repl(char):
            o = ord(char)
            # Control chars, quote (39), backslash (92) and bytes >= 127
            # are written as a doubled backslash plus 3 octal digits.
            if o <= 31 or o == 39 or o == 92 or o >= 127:
                return r"\\%03o" % o
            return char
        return "'%s'::bytea" % "".join(map(repl, value))

    def escape_like(self, sql):
        """Prepare a bytea literal for use in a LIKE comparison.

        Drops a trailing "::bytea" cast (as produced by push()), removes at
        most one enclosing quote from each end, escapes the LIKE wildcards,
        then doubles every backslash.
        """
        if sql.endswith("::bytea"):
            sql = sql[:-len("::bytea")]
        if sql[:1] in ("'", '"'):
            sql = sql[1:]
        if sql[-1:] in ("'", '"'):
            sql = sql[:-1]
        for pat, repl in self.like_escapes:
            sql = sql.replace(pat, repl)

        # Doubling includes the backslashes just added by like_escapes.
        sql = sql.replace("\\", "\\\\")
        return sql
|---|
| 429 |
|
|---|
| 430 |
|
|---|
| 431 |
class PgBYTEA_Pickler(PgPickler):
    """Python object to PostgreSQL bytea adapter.

    For the most part, Postgres bytea works like Python's str: a sequence
    of bytes. Certain bytes have to be octal-escaped for consumption by PG.

    See http://www.postgresql.org/docs/8.1/interactive/datatype-binary.html

    NOTE(review): pull() is inherited from PgPickler, which also reverses
    self.escapes even though push() here never applies them, and it calls
    pickle.loads on database contents (arbitrary code execution if the
    column can be written by an attacker). Confirm both are intended.
    """

    def push(self, value, dbtype):
        """Return the protocol-2 pickle of 'value' as a bytea literal.

        Returns 'NULL' for None.
        """
        if value is None:
            return 'NULL'

        value = pickle.dumps(value, 2)

        def repl(char):
            o = ord(char)
            # Control chars, quote (39), backslash (92) and bytes >= 127
            # are written as a doubled backslash plus 3 octal digits.
            if o <= 31 or o == 39 or o == 92 or o >= 127:
                return r"\\%03o" % o
            return char
        return "'%s'::bytea" % "".join(map(repl, value))

    def escape_like(self, sql):
        """Prepare a bytea literal for use in a LIKE comparison.

        Drops a trailing "::bytea" cast (as produced by push()), removes at
        most one enclosing quote from each end, escapes the LIKE wildcards,
        then doubles every backslash.
        """
        if sql.endswith("::bytea"):
            sql = sql[:-len("::bytea")]
        if sql[:1] in ("'", '"'):
            sql = sql[1:]
        if sql[-1:] in ("'", '"'):
            sql = sql[:-1]
        for pat, repl in self.like_escapes:
            sql = sql.replace(pat, repl)

        # Doubling includes the backslashes just added by like_escapes.
        sql = sql.replace("\\", "\\\\")
        return sql
|---|
| 462 |
|
|---|
| 463 |
|
|---|
| 464 |
|
|---|
| 465 |
|
|---|
| 466 |
|
|---|
| 467 |
|
|---|
| 468 |
|
|---|
| 469 |
|
|---|
| 470 |
|
|---|
| 471 |
|
|---|
| 472 |
|
|---|
| 473 |
|
|---|
| 474 |
|
|---|
| 475 |
|
|---|
| 476 |
|
|---|
| 477 |
|
|---|
| 478 |
|
|---|
| 479 |
|
|---|
| 480 |
|
|---|
| 481 |
|
|---|
| 482 |
class BOOLEAN(dbtypes.SQL99BOOLEAN):
    """PostgreSQL logical true/false type (alias BOOL)."""

    synonyms = ['BOOL']
|---|
| 485 |
|
|---|
| 486 |
|
|---|
| 487 |
class BYTEA(dbtypes.FrozenByteType):
    """PostgreSQL binary string ("byte array") type."""

    default_pytype = str
    encoding = 'utf8'
    default_adapters = {
        str: Pg_str_to_BYTEA(),
        unicode: Pg_unicode_to_BYTEA(),
        None: PgBYTEA_Pickler(),
    }
|---|
| 495 |
|
|---|
| 496 |
class BIT(dbtypes.SQL92VARCHAR):
    """PostgreSQL fixed-length bit string."""

    variable = False
    default_adapters = {
        str: Pg_str_to_VARCHAR(),
        unicode: Pg_unicode_to_VARCHAR(),
        None: PgPickler(),
    }
|---|
| 503 |
|
|---|
| 504 |
class VARBIT(dbtypes.SQL92VARCHAR):
    """PostgreSQL variable-length bit string (BIT VARYING)."""

    variable = True
    synonyms = ['BIT VARYING']
    default_adapters = {
        str: Pg_str_to_VARCHAR(),
        unicode: Pg_unicode_to_VARCHAR(),
        None: PgPickler(),
    }
|---|
| 512 |
|
|---|
| 513 |
class CHAR(dbtypes.SQL92CHAR):
    """PostgreSQL fixed-length character string (CHARACTER / BPCHAR)."""

    synonyms = ['CHARACTER', 'BPCHAR']
    default_adapters = {
        str: Pg_str_to_VARCHAR(),
        unicode: Pg_unicode_to_VARCHAR(),
        None: PgPickler(),
    }
|---|
| 520 |
|
|---|
| 521 |
class VARCHAR(dbtypes.SQL92VARCHAR):
    """PostgreSQL variable-length character string (CHARACTER VARYING)."""

    synonyms = ['CHARACTER VARYING']
    variable = True

    # Largest size, in bytes, this type accepts: 2**30 (1 GiB).
    max_bytes = 2**30

    default_adapters = {
        str: Pg_str_to_VARCHAR(),
        unicode: Pg_unicode_to_VARCHAR(),
        None: PgPickler(),
    }
|---|
| 533 |
|
|---|
| 534 |
|
|---|
| 535 |
class ComparableInfinity(object):
    """A value greater than everything except another ComparableInfinity.

    Instances compare equal to each other and hash alike. Both the
    Python 2 __cmp__ protocol and the rich comparison methods (the only
    ones Python 3 consults) are implemented with the same ordering.
    """

    def __cmp__(self, other):
        # Python 2 hook: False (0, equal) for another instance, otherwise
        # True (1, greater).
        if isinstance(other, self.__class__):
            return False
        return True

    def __eq__(self, other):
        return isinstance(other, self.__class__)

    def __ne__(self, other):
        return not isinstance(other, self.__class__)

    def __lt__(self, other):
        return False

    def __le__(self, other):
        return isinstance(other, self.__class__)

    def __gt__(self, other):
        return not isinstance(other, self.__class__)

    def __ge__(self, other):
        return True

    def __hash__(self):
        # All instances are equal, so they must share one hash value.
        return hash(self.__class__)

    def __str__(self):
        return "Infinity"

    def __repr__(self):
        return "%s.%s()" % (self.__module__, self.__class__.__name__)
|---|
| 547 |
|
|---|
| 548 |
|
|---|
| 549 |
|
|---|
| 550 |
class TEXT(dbtypes.TEXT):
    """PostgreSQL unlimited-length character string."""

    # No byte limit: ComparableInfinity compares greater than any size.
    _bytes = max_bytes = ComparableInfinity()

    # Start from the generic TEXT adapters, then override the str, unicode
    # and fallback (None) entries with the PostgreSQL adapters.
    default_adapters = dict(dbtypes.TEXT.default_adapters)
    default_adapters[str] = Pg_str_to_VARCHAR()
    default_adapters[unicode] = Pg_unicode_to_VARCHAR()
    default_adapters[None] = PgPickler()
|---|
| 560 |
|
|---|
| 561 |
class NAME(dbtypes.TEXT):
    """PostgreSQL system-identifier type, limited to 63 bytes."""

    _bytes = max_bytes = 63

    # Generic TEXT adapters with the str, unicode and fallback (None)
    # entries replaced by the PostgreSQL adapters.
    default_adapters = dict(dbtypes.TEXT.default_adapters)
    default_adapters[str] = Pg_str_to_VARCHAR()
    default_adapters[unicode] = Pg_unicode_to_VARCHAR()
    default_adapters[None] = PgPickler()
|---|
| 570 |
|
|---|
| 571 |
|
|---|
| 572 |
|
|---|
| 573 |
|
|---|
| 574 |
|
|---|
| 575 |
|
|---|
| 576 |
class FLOAT4(dbtypes.SQL92REAL):
    """PostgreSQL single-precision floating point (alias REAL)."""

    default_adapters = {float: PgFLOAT4_Adapter()}
    synonyms = ['REAL']
|---|
| 580 |
|
|---|
| 581 |
class FLOAT8(dbtypes.SQL92DOUBLE):
    """PostgreSQL double-precision floating point (DOUBLE PRECISION)."""

    synonyms = ['DOUBLE PRECISION']
|---|
| 584 |
|
|---|
| 585 |
|
|---|
| 586 |
class INT2(dbtypes.SQL92SMALLINT):
    """PostgreSQL signed 2-byte integer (alias SMALLINT)."""

    synonyms = ['SMALLINT']
|---|
| 589 |
|
|---|
| 590 |
class INT4(dbtypes.SQL92INTEGER):
    """PostgreSQL signed 4-byte integer.

    SERIAL / SERIAL4 are listed as synonyms so serial columns map here.
    """

    synonyms = ['INT', 'INTEGER', 'SERIAL', 'SERIAL4']
|---|
| 598 |
|
|---|
| 599 |
class INT8(dbtypes.SQL92INTEGER):
    """PostgreSQL signed 8-byte integer (BIGINT; BIGSERIAL/SERIAL8)."""

    _bytes = max_bytes = 8
    synonyms = ['BIGINT', 'BIGSERIAL', 'SERIAL8']
    default_adapters = {
        int: adapters.int_to_SQL92INTEGER(8),
        long: adapters.int_to_SQL92INTEGER(8),
    }
|---|
| 605 |
|
|---|
| 606 |
|
|---|
| 607 |
class TIMESTAMP(dbtypes.SQL92TIMESTAMP):
    """PostgreSQL date-and-time type without timezone (naive)."""

    default_adapters = {
        datetime.datetime: datetime_to_PgTIMESTAMPTZ_Adapter(),
    }
|---|
| 610 |
|
|---|
| 611 |
class TIMESTAMPTZ(dbtypes.SQL92TIMESTAMP):
    """PostgreSQL date-and-time type with timezone (aware)."""

    timezone_aware = True
    default_adapters = {
        datetime.datetime: datetime_tz_to_PgTIMESTAMPTZ_Adapter(),
    }

    def ddl(self):
        """Return the DDL spelling of this type."""
        return "TIMESTAMP WITH TIME ZONE"
|---|
| 618 |
|
|---|
| 619 |
class DATE(dbtypes.SQL92DATE):
    """PostgreSQL calendar date (year, month, day)."""

    default_adapters = {
        datetime.date: PgDATE_Adapter(),
    }
|---|
| 622 |
|
|---|
| 623 |
class TIME(dbtypes.SQL92TIME):
    """PostgreSQL time-of-day type; uses the SQL92TIME defaults as-is."""
|---|
| 626 |
|
|---|
| 627 |
class INTERVAL(dbtypes.AdjustablePrecisionType):
    """PostgreSQL time span, mapped to datetime.timedelta."""

    default_pytype = datetime.timedelta
    default_adapters = {
        datetime.timedelta: PgINTERVAL_Adapter(),
    }
|---|
| 631 |
|
|---|
| 632 |
|
|---|
| 633 |
class DECIMAL(dbtypes.SQL92DECIMAL):
    """PostgreSQL exact numeric with selectable precision (alias NUMERIC)."""

    # Precision used by default and the maximum allowed: 1000 digits.
    _precision = max_precision = 1000
    synonyms = ['NUMERIC']
|---|
| 644 |
|
|---|
| 645 |
|
|---|
| 646 |
class MONEY(dbtypes.FrozenPrecisionType):
    """PostgreSQL currency amount; same Python type as SQL92DECIMAL."""

    default_pytype = dbtypes.SQL92DECIMAL.default_pytype
|---|
| 649 |
|
|---|
| 650 |
|
|---|
| 651 |
class INET(dbtypes.FrozenByteType):
    """PostgreSQL host address: IPv4 or IPv6, with an optional subnet."""

    default_pytype = str
    encoding = 'utf8'
    variable = False

    default_adapters = {
        str: adapters.str_to_SQL92VARCHAR(),
        unicode: adapters.unicode_to_SQL92VARCHAR(),
        None: adapters.Pickler(),
    }
|---|
| 676 |
|
|---|
| 677 |
|
|---|
| 678 |
class PgTypeSet(dbtypes.DatabaseTypeSet):
    """The PostgreSQL database types, grouped by category."""

    known_types = {
        'float': [FLOAT4, FLOAT8],
        'varchar': [TEXT, VARCHAR, VARBIT, BYTEA, NAME],
        'char': [CHAR, BIT],
        'int': [INT2, INT4, INT8],
        'bool': [BOOLEAN],
        'datetime': [TIMESTAMP, TIMESTAMPTZ],
        'date': [DATE],
        'time': [TIME],
        'timedelta': [INTERVAL],
        'numeric': [DECIMAL],
        'other': [MONEY, INET],
    }
|---|
| 692 |
|
|---|
| 693 |
|
|---|
| 694 |
|
|---|
| 695 |
class PgDeparser(deparse.SQLDeparser):
    """SQL deparser that emits PostgreSQL spellings of builtin functions."""

    def _offset_parts(self):
        """Return (sign, "H:M") for adapters.localtime_offset().

        sign is "-" when the offset is negative and "" otherwise.
        """
        neg, hours, minutes = adapters.localtime_offset()
        sign = ""
        if neg:
            sign = "-"
        return sign, "%s:%s" % (hours, minutes)

    def builtins_ieq(self, op1, op2):
        """Case-insensitive equality, expressed as an anchored ILIKE."""
        sql = op1.adapter.like_op(op1, op2, ignore_case=True,
                                  start_only=True, end_only=True)
        return self.get_expr(sql, bool)

    def builtins_year(self, x):
        sql = "date_part('year', " + x.sql + ")"
        return self.get_expr(sql, int)

    def builtins_month(self, x):
        sql = "date_part('month', " + x.sql + ")"
        return self.get_expr(sql, int)

    def builtins_day(self, x):
        sql = "date_part('day', " + x.sql + ")"
        return self.get_expr(sql, int)

    def builtins_now(self):
        """NOW() shifted by the local offset from adapters.localtime_offset()."""
        sql = "(NOW() AT TIME ZONE INTERVAL '%s%s')" % self._offset_parts()
        return self.get_expr(sql, datetime.datetime)

    def builtins_utcnow(self):
        return self.get_expr("NOW()", datetime.datetime)

    def builtins_today(self):
        """The locally-shifted NOW(), truncated to the day."""
        sql = ("date_trunc('day', NOW() AT TIME ZONE INTERVAL '%s%s')"
               % self._offset_parts())
        return self.get_expr(sql, datetime.date)
|---|
| 731 |
|
|---|
| 732 |
|
|---|
| 733 |
class PgIndexSet(geniusql.IndexSet):
    """IndexSet whose deletions issue DROP INDEX statements."""

    def __delitem__(self, key):
        """Drop the index stored under 'key'."""
        index_qname = self[key].qname
        statement = 'DROP INDEX %s;' % index_qname
        self.table.schema.db.execute_ddl(statement)
|---|
| 739 |
|
|---|
| 740 |
|
|---|
| 741 |
class PgTable(geniusql.Table):
    """A PostgreSQL table whose qname is prefixed by its schema's qname."""

    implicit_pkey_indices = True

    def __init__(self, name, qname, schema, created=False, description=None):
        geniusql.Table.__init__(self, name=name, qname=qname, schema=schema,
                                created=created, description=description)
        # Schema-qualify the table name used in generated SQL.
        self.qname = self.schema.qname + "." + self.qname

    def _grab_new_ids(self, idkeys, conn):
        """Return {idkey: currval(<schema>.<sequence>)} for each idkey.

        Each column's sequence_name is qualified with the schema qname and
        currval() is read on 'conn'.
        """
        newids = {}
        for idkey in idkeys:
            col = self[idkey]
            seq = self.schema.qname + "." + col.sequence_name
            data, _ = self.schema.db.fetch("SELECT currval('%s');" % seq, conn)
            newids[idkey] = data[0][0]
        return newids

    def drop_primary(self):
        """Remove any PRIMARY KEY for this Table."""
        db = self.schema.db

        # Restrict the OID lookup to this table's schema; matching on
        # relname alone could pick a same-named table in another schema.
        data, _ = db.fetch("SELECT c.oid FROM pg_class c, pg_namespace n "
                           "WHERE c.relnamespace = n.oid "
                           "AND c.relname = '%s' AND n.nspname = '%s'"
                           % (self.name, self.schema.name))
        table_OID = data[0][0]

        data, _ = db.fetch("SELECT conname FROM pg_constraint WHERE conrelid "
                           "= %s AND contype = 'p'" % table_OID)
        for row in data:
            constraint_name = row[0]
            db.execute('ALTER TABLE %s DROP CONSTRAINT "%s";'
                       % (self.qname, constraint_name))
|---|
| 776 |
|
|---|
| 777 |
|
|---|
| 778 |
class PgSchema(geniusql.Schema): |
|---|
| 779 |
|
|---|
| 780 |
tableclass = PgTable |
|---|
| 781 |
indexsetclass = PgIndexSet |
|---|
| 782 |
|
|---|
| 783 |
discover_pg_tables = False |
|---|
| 784 |
|
|---|
| 785 |
def __init__(self, db, name=None):
    """Initialize the schema; a 'name' of None means 'public'."""
    schema_name = name
    if schema_name is None:
        schema_name = 'public'
    geniusql.Schema.__init__(self, db, schema_name)
|---|
| 789 |
|
|---|
| 790 |
def _get_tables(self, conn=None):
    """Return a tableclass instance for each ordinary table in this schema.

    Each table carries its COMMENT (or None) as its description. Tables
    whose names start with "pg_" are skipped unless
    self.discover_pg_tables is true.
    """
    data, _ = self.db.fetch("SELECT oid FROM pg_class WHERE relname = "
                            " 'pg_class' and relkind='r'", conn=conn)
    pgclass_OID = data[0][0]

    data, _ = self.db.fetch("SELECT oid FROM pg_namespace WHERE "
                            "nspname = '%s'" % self.name, conn=conn)
    nsoid = data[0][0]

    # objsubid = 0 selects the comment on the table itself. Column
    # comments share the table's objoid and would otherwise produce
    # duplicate rows (and wrong descriptions) for the same table.
    data, _ = self.db.fetch(
        "SELECT c.relname, d.description FROM pg_class c LEFT JOIN "
        "(SELECT description, objoid FROM pg_description WHERE "
        "classoid = %s AND objsubid = 0) AS d ON c.oid = d.objoid "
        "WHERE c.relnamespace = %s and c.relkind = 'r';"
        % (pgclass_OID, nsoid), conn=conn)
    return [self.tableclass(name, self.db.quote(name), self,
                            created=True, description=description)
            for name, description in data
            if self.discover_pg_tables or not name.startswith("pg_")]
|---|
| 808 |
|
|---|
| 809 |
def _get_table(self, tablename, conn=None):
    """Return the table or view named 'tablename' in this schema.

    Tables ('r') become self.tableclass and views ('v') self.viewclass;
    the table's own COMMENT, if any, is set as its description.
    Raises errors.MappingError if no such relation exists, or if the name
    starts with "pg_" while self.discover_pg_tables is false.
    """
    if (not self.discover_pg_tables) and tablename.startswith("pg_"):
        raise errors.MappingError(
            "Table %r not found. Set schema.discover_pg_tables to True "
            "if you want to discover Postgres system tables (pg_*)." %
            tablename)

    data, _ = self.db.fetch(
        "SELECT oid FROM pg_class WHERE relname = 'pg_class' "
        "and relkind='r'", conn=conn)
    pgclass_OID = data[0][0]

    data, _ = self.db.fetch(
        "SELECT oid FROM pg_namespace WHERE nspname = '%s'" % self.name,
        conn=conn)
    nsoid = data[0][0]

    data, _ = self.db.fetch(
        "SELECT c.oid, c.relname, c.relkind FROM pg_class c WHERE "
        "c.relnamespace = %s AND c.relname = '%s' AND c.relkind in ('r', 'v')" %
        (nsoid, tablename), conn=conn)
    for table_OID, name, kind in data:
        if name != tablename:
            continue
        if kind == 'r':
            t = self.tableclass(name, self.db.quote(name),
                                self, created=True)
        else:
            t = self.viewclass(name, self.db.quote(name),
                               self, created=True)

        # objsubid = 0 restricts this to the relation's own comment;
        # column comments share the same objoid and classoid.
        desc_rows, _ = self.db.fetch(
            "SELECT description FROM pg_description "
            "WHERE objoid = %s and classoid = %s AND objsubid = 0" %
            (table_OID, pgclass_OID), conn=conn)
        for cell, in desc_rows:
            t.description = cell
            break

        return t
    raise errors.MappingError("Table %r not found." % tablename)
|---|
| 848 |
|
|---|
| 849 |
def _get_columns(self, table, conn=None):
    """Return a list of geniusql.Column objects describing *table*.

    Attributes of the table (or view) are read from pg_attribute. Dropped
    attributes and the system columns listed below are skipped. For each
    remaining column:
    - its pg_type name is canonicalized through self.db.typeset;
    - size and precision come from attlen/atttypmod;
    - primary-key membership comes from pg_index.indkey;
    - its DEFAULT expression comes from pg_attrdef.

    A default that calls nextval() marks the column as autoincrement and
    records its sequence name. The sequence's min_value becomes the
    column's ``initial`` value.

    Raises KeyError (with "table.column" appended to its args) when a
    column's type is unknown to the typeset. Raises ValueError when a
    VARCHAR column's decoded size is not positive.
    """
    import sys

    data, _ = self.db.fetch(
        "SELECT oid FROM pg_namespace WHERE nspname = '%s'" % self.name,
        conn=conn)
    nsoid = data[0][0]

    data, _ = self.db.fetch(
        "SELECT c.oid FROM pg_class c WHERE c.relnamespace = %s AND "
        "c.relname = '%s' AND c.relkind in ('r', 'v')" %
        (nsoid, table.name), conn=conn)
    table_OID = data[0][0]

    # Attribute numbers (attnum) of the primary-key columns, if any.
    data, _ = self.db.fetch(
        "SELECT indkey FROM pg_index WHERE indrelid = %s AND indisprimary"
        % table_OID, conn=conn)
    if data:
        # indkey is parsed as a whitespace-separated list of attnums.
        key_attnums = [int(part) for part in data[0][0].split()]
    else:
        key_attnums = []

    sql = ("SELECT attname, atttypid, attnum, attlen, atttypmod "
           "FROM pg_attribute WHERE attisdropped = False AND "
           "attrelid = %s" % table_OID)
    data, _ = self.db.fetch(sql, conn=conn)
    cols = []
    typeset = self.db.typeset
    for name, typeoid, attnum, attlen, atttypmod in data:
        if name in ('tableoid', 'cmax', 'xmax', 'cmin', 'xmin',
                    'oid', 'ctid'):
            # Postgres system columns are not part of the reflected schema.
            continue

        typerows, _ = self.db.fetch("SELECT typname, typlen FROM pg_type "
                                    "WHERE oid = %s" % typeoid, conn=conn)
        try:
            dbtypetype = typeset.canonicalize(typerows[0][0].upper())
        except KeyError:
            # Tag the error with the offending column, then re-raise it.
            # sys.exc_info() keeps this valid on both Python 2 and 3.
            exc = sys.exc_info()[1]
            exc.args += ("%s.%s" % (table.name, name),)
            raise
        dbtype = dbtypetype()

        c = geniusql.Column(dbtype.default_pytype, dbtype,
                            None, key=attnum in key_attnums,
                            name=name, qname=self.db.quote(name))
        c.adapter = dbtype.default_adapter(c.pytype)

        if dbtypetype in (FLOAT4, FLOAT8):
            dbtype.precision = int(attlen)
        elif dbtypetype in (MONEY, DECIMAL) and atttypmod >= 0:
            # atttypmod is -1 when no type modifier was declared. Decoding
            # -1 would give precision 65535 / scale 65531, so the dbtype's
            # own defaults are kept in that case. Otherwise atttypmod is
            # decoded as ((precision << 16) | scale) + 4.
            dbtype.precision = int((atttypmod >> 16) & 65535)
            dbtype.scale = int((atttypmod & 65535) - 4)

        if dbtypetype is VARCHAR:
            # For VARCHAR the declared size is decoded as atttypmod - 4.
            nbytes = int(atttypmod - 4)
            if nbytes > 0:
                dbtype.bytes = nbytes
            else:
                raise ValueError("Column %r has illegal size %r" %
                                 (name, nbytes))
        else:
            nbytes = int(attlen)
            if nbytes > 0:
                dbtype.bytes = nbytes

        # pg_get_expr(adbin, adrelid) renders the stored default expression.
        # The older pg_attrdef.adsrc column was removed in PostgreSQL 12.
        default, _ = self.db.fetch(
            "SELECT pg_get_expr(adbin, adrelid) FROM pg_attrdef "
            "WHERE adnum = %s AND adrelid = %s" % (attnum, table_OID),
            conn=conn)
        if default:
            default = default[0][0]
            if default.startswith("nextval("):
                # A nextval() default means a sequence-backed column.
                c.autoincrement = True
                sname = seq_name.search(default).group(1)
                # Strip this schema's name (plain or quoted) from the
                # sequence name. It is re-added below when building sqname.
                if (sname.startswith(self.name + ".") or
                        sname.startswith(self.qname + ".")):
                    sname = sname.split(".", 1)[1]
                c.sequence_name = sname

                sqname = self.qname + "." + sname
                # NOTE(review): sequence relations no longer expose a
                # min_value column in PostgreSQL 10+ (it moved to the
                # pg_sequence catalog), so this query fails there.
                # Confirm which server versions must be supported.
                c.initial = self.db.fetch("SELECT min_value FROM %s" %
                                          sqname, conn=conn)[0][0][0]
                c.default = None
            else:
                # Drop any "::type" cast and the surrounding quotes before
                # handing the literal to the column's adapter.
                defval = default.split("::", 1)[0]
                try:
                    defval = defval.strip("'")
                    c.default = c.adapter.pull(defval, c.dbtype)
                except ValueError:
                    # The adapter could not parse it; keep the raw
                    # default expression text instead.
                    c.default = default
        else:
            c.default = None

        cols.append(c)
    return cols
|---|
| 958 |
|
|---|
| 959 |
def _get_indices(self, table, conn=None):
    """Return a list of geniusql.Index objects for *table*.

    Every index whose pg_index.indrelid is the table's OID is expanded
    into one Index per attribute number in its indkey. When an attribute
    number has no pg_attribute row, the column name is reported as
    "<unknown>". Uniqueness comes from pg_index.indisunique.
    """
    rows, _ = self.db.fetch("SELECT oid FROM pg_namespace WHERE "
                            "nspname = '%s'" % self.name, conn=conn)
    namespace_oid = rows[0][0]

    rows, _ = self.db.fetch("SELECT c.oid FROM pg_class c WHERE "
                            "c.relnamespace = %s AND "
                            "c.relname = '%s' AND c.relkind = 'r'" %
                            (namespace_oid, table.name), conn=conn)
    relation_oid = rows[0][0]

    found = []
    index_rows, _ = self.db.fetch(
        "SELECT pg_class.relname, indkey, indisprimary, "
        "indisunique FROM pg_index LEFT JOIN pg_class "
        "ON pg_index.indexrelid = pg_class.oid WHERE "
        "pg_index.indrelid = %s" % relation_oid, conn=conn)
    for index_row in index_rows:
        idx_name = index_row[0]
        quoted_idx_name = self.db.quote(idx_name)
        is_unique = bool(index_row[3])

        attnums = [int(piece) for piece in index_row[1].split(" ")]
        for attnum in attnums:
            attrows, _ = self.db.fetch("SELECT attname FROM pg_attribute "
                                       "WHERE attrelid = %s AND attnum = %s"
                                       % (relation_oid, attnum), conn=conn)
            if attrows:
                column_name = attrows[0][0]
            else:
                column_name = "<unknown>"
            found.append(geniusql.Index(idx_name, quoted_idx_name,
                                        table.name, column_name, is_unique))

    return found
|---|
| 998 |
|
|---|
| 999 |
def columnclause(self, column):
    """Return a clause for the given column for CREATE or ALTER TABLE.

    This will be of the form "name type [DEFAULT [x | nextval('seq')]]".

    PostgreSQL creates the sequence in a separate statement.

    Only None (or an empty string) means "no explicit default". Other
    falsy values such as 0 or False are pushed through the column's
    adapter like any other value, so they still produce a DEFAULT clause.
    """
    if column.autoincrement:
        default = "nextval('%s.%s')" % (self.qname, column.sequence_name)
    else:
        default = column.default
        if default is None or default == "":
            # NOTE(review): as before, for string-typed columns this empty
            # string is still pushed through the adapter below.
            default = ""
        if isinstance(default, str):
            # str defaults are adapted only for string-typed columns;
            # for any other column type the text is used verbatim.
            if issubclass(column.pytype, basestring):
                default = column.adapter.push(default, column.dbtype)
        else:
            default = column.adapter.push(default, column.dbtype)

    # Emit nothing (rather than e.g. "None") when there is no default text.
    if default:
        clause = " DEFAULT %s" % default
    else:
        clause = ""

    return '%s %s%s' % (column.qname, column.dbtype.ddl(), clause)
|---|
| 1020 |
|
|---|
| 1021 |
def sequence_name(self, tablename, columnkey):
    """Return the SQL sequence name for *tablename* and *columnkey*.

    The name is "<table>_<column>_seq". If self.db.sql_name_max_length is
    set and the name is longer than that, the table part is shortened so
    that the "_<column>_seq" suffix is kept intact. The result is passed
    through self.db.sql_name.
    """
    suffix = "_%s_seq" % columnkey
    limit = self.db.sql_name_max_length
    candidate = "%s%s" % (tablename, suffix)
    if limit and len(candidate) > limit:
        # Keep the column suffix; trim only the table-name prefix.
        candidate = tablename[:limit - len(suffix)] + suffix
    return self.db.sql_name(candidate)
|---|
| 1033 |
|
|---|
| 1034 |
def index_name(self, table, columnkey):
    """Return the SQL index name for *columnkey* of *table*.

    Non-key columns map to "<table>_<column>_idx". Key columns map to
    "<column>_pkey". Both names are passed through self.db.sql_name.
    """
    col = table[columnkey]
    if not col.key:
        return self.db.sql_name("%s_%s_idx" % (table.name, col.name))
    # NOTE(review): the key-column name is built from the column name, not
    # the table name. PostgreSQL's default primary-key index name is
    # "<table>_pkey", and "<column>_pkey" repeats across tables sharing a
    # key column name. Confirm this is intended.
    return self.db.sql_name("%s_pkey" % col.name)
|---|
| 1041 |
|
|---|
| 1042 |
def create_sequence(self, table, column):
    """Issue CREATE SEQUENCE for *column*'s sequence, if it has one.

    The sequence is created in this schema and starts at column.initial.
    *table* is accepted for interface compatibility and is not used here.
    """
    seq = column.sequence_name
    if seq is None:
        return
    statement = "CREATE SEQUENCE %s.%s START %s;" % (self.qname, seq,
                                                     column.initial)
    self.db.execute_ddl(statement)
|---|
| 1048 |
|
|---|
| 1049 |
def drop_sequence(self, column):
    """Issue DROP SEQUENCE for *column*'s sequence in this schema, if any."""
    seq = column.sequence_name
    if seq is None:
        return
    self.db.execute_ddl("DROP SEQUENCE %s.%s;" % (self.qname, seq))
|---|
| 1054 |
|
|---|
| 1055 |
def create(self):
    """Create this schema in the database, then call self.clear().

    No CREATE SCHEMA statement is sent for the schema named "public".
    """
    if self.name != "public":
        # NOTE(review): "public" is skipped, presumably because it already
        # exists in a new PostgreSQL database -- confirm.
        self.db.execute_ddl("CREATE SCHEMA %s" % self.qname)
    self.clear()
|---|
| 1059 |
|
|---|
| 1060 |
def drop(self, restrict=False):
    """Drop this schema (and any contained objects) from the database.

    WARNING: This method's default is to drop any objects owned by the
    schema using the CASCADE parameter to DROP SCHEMA. This is contrary
    to the PostgreSQL default! If you wish to drop with the RESTRICT
    parameter instead, set the 'restrict' argument to True.

    No DROP SCHEMA statement is sent for the schema named "public";
    self.clear() is then called.
    """
    if self.name != "public":
        # Reuse 'restrict' to hold the SQL keyword for the statement below.
        if restrict:
            restrict = 'RESTRICT'
        else:
            restrict = 'CASCADE'
        self.db.execute_ddl("DROP SCHEMA %s %s;" % (self.qname, restrict))
    self.clear()
|---|
| 1075 |
|
|---|
| 1076 |
|
|---|
| 1077 |
class PgDatabase(geniusql.Database):
    """geniusql Database subclass for PostgreSQL.

    The class-level settings below are read by this class's methods
    (quote, sql_name, schema, create) and, through ``self.db``, by the
    schema-level reflection and naming code.
    """

    # Name-length limit read via self.db.sql_name_max_length by
    # sequence_name(); presumably PostgreSQL's default NAMEDATALEN - 1.
    sql_name_max_length = 63
    # When true, quote() wraps identifiers in double quotes and sql_name()
    # keeps their case; when false, quote() returns names unchanged and
    # sql_name() lower-cases them.
    quote_all = True
    # Not referenced in this part of the file; presumably the connection
    # pool size consumed elsewhere -- TODO confirm.
    poolsize = 10
    # create() adds " WITH ENCODING '<encoding>'" to CREATE DATABASE
    # when this is truthy.
    encoding = 'UTF8'

    # Deparser, schema and type-set classes used by this database;
    # schema() instantiates schemaclass, and column reflection reads
    # self.db.typeset.
    deparser = PgDeparser
    schemaclass = PgSchema
    typeset = PgTypeSet()
|---|
| 1087 |
|
|---|
| 1088 |
def quote(self, name):
    """Return *name* as a double-quoted identifier when quote_all is set.

    Embedded double quotes are escaped by doubling them. When quote_all
    is false, *name* is returned unchanged.
    """
    if not self.quote_all:
        return name
    return '"%s"' % name.replace('"', '""')
|---|
| 1092 |
|
|---|
| 1093 |
def sql_name(self, name):
    """Return the SQL form of *name* for this database.

    Delegates to geniusql.Database.sql_name, then lower-cases the result
    unless quote_all is set.
    """
    converted = geniusql.Database.sql_name(self, name)
    if self.quote_all:
        return converted
    return converted.lower()
|---|
| 1098 |
|
|---|
| 1099 |
def schema(self, name="public"):
    """Return a new schema object named *name*, built by self.schemaclass."""
    schema_factory = self.schemaclass
    return schema_factory(self, name)
|---|
| 1101 |
|
|---|
| 1102 |
def create(self):
    """Create this database on the server with CREATE DATABASE.

    The statement runs over a connection obtained with
    ``self.connections._get_conn(master=True)``. When self.encoding is
    truthy, a ``WITH ENCODING`` clause is appended; otherwise no clause
    is added (previously a None encoding leaked the text "None" into the
    statement). The connection is always handed back via ``_del_conn``,
    even if the DDL raises.
    """
    c = self.connections._get_conn(master=True)
    try:
        if self.encoding:
            encoding = " WITH ENCODING '%s'" % self.encoding
        else:
            encoding = ""
        self.execute_ddl("CREATE DATABASE %s%s" % (self.qname, encoding), c)
    finally:
        self.connections._del_conn(c)
|---|
| 1109 |
|
|---|
| 1110 |
def exists(self):
    """Return True if this database exists, False otherwise.

    Looks for a pg_database row whose datname equals
    ``self.sql_name(self.name)``, using a connection obtained with
    ``self.connections._get_conn(master=True)``. The connection is always
    handed back via ``_del_conn``, even if the query raises.
    """
    c = self.connections._get_conn(master=True)
    try:
        data, _ = self.fetch("SELECT datname FROM pg_database "
                             "WHERE datname = '%s';" % self.sql_name(self.name),
                             conn=c)
    finally:
        self.connections._del_conn(c)
    return bool(data)
|---|
| 1118 |
|
|---|
| 1119 |
def drop(self):
    """Drop this database from the server with DROP DATABASE.

    The statement runs over a connection obtained with
    ``self.connections._get_conn(master=True)``. That connection is always
    handed back via ``_del_conn``, even if the DDL raises.
    """
    c = self.connections._get_conn(master=True)
    try:
        self.execute_ddl("DROP DATABASE %s;" % self.qname, c)
    finally:
        self.connections._del_conn(c)
|---|
| 1123 |
|
|---|
| 1124 |
def _get_schemas(self, conn=None):
    """Return the names from pg_namespace that count as user schemas.

    'information_schema' and every name starting with 'pg_' are left out.
    """
    rows, _ = self.fetch("SELECT nspname FROM pg_namespace;", conn=conn)
    names = []
    for (nspname,) in rows:
        if nspname == 'information_schema' or nspname.startswith('pg_'):
            continue
        names.append(nspname)
    return names
|---|
| 1129 |
|
|---|
| 1130 |
|
|---|