Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/storepypgsql.py

Revision 193 (checked in by fumanchu, 7 years ago)

Fix for #51 (remove expanded columns). If anyone objects, this can be reinstated with very little work.

  • Property svn:eol-style set to native
Line 
1 # Use libpq directly to avoid all of the DB-API overhead.
2 from pyPgSQL import libpq
3
4 import datetime
5
6 try:
7     # Builtin in Python 2.5?
8     decimal
9 except NameError:
10     try:
11         # Module in Python 2.3, 2.4
12         import decimal
13     except ImportError:
14         pass
15
16 import dejavu
17 from dejavu.storage import db
18
19
class AdapterToPgSQL(db.AdapterToSQL):
    """Value-to-SQL adapter for PostgreSQL.

    Overrides only ``like_escapes``; everything else is inherited from
    ``db.AdapterToSQL``.
    """

    # (LIKE wildcard, replacement) pairs used when escaping LIKE patterns.
    # Each replacement is two backslash characters followed by the wildcard,
    # presumably so that one backslash survives PostgreSQL's string-literal
    # unescaping. NOTE(review): confirm behaviour when the server runs with
    # standard_conforming_strings = on.
    like_escapes = [
        ("%", "\\\\%"),
        ("_", "\\\\_"),
    ]
23
24
class FieldTypeAdapterPgSQL(db.FieldTypeAdapter):
    """Map dejavu property types to PostgreSQL column types."""

    def coerce_int(self, cls, key):
        """Return the PostgreSQL column type for the int property cls.key.

        When cls uses dejavu.UnitSequencerInteger and key is one of its
        identifiers, the column is an INTEGER whose default is
        nextval('<ClassName>_<key>_seq').

        Otherwise the type is chosen from the property's 'bytes' hint
        (default 4):
            1      -> BOOLEAN
            2      -> SMALLINT
            4 / other small values -> INTEGER
            > 4    -> BIGINT (INTEGER is only 4 bytes wide and would
                      overflow)
        """
        prop = getattr(cls, key)
        if (isinstance(cls.sequencer, dejavu.UnitSequencerInteger)
                and key in cls.identifiers):
            return ("INTEGER DEFAULT nextval('%s_%s_seq') NOT NULL"
                    % (cls.__name__, key))

        size = int(prop.hints.get('bytes', '4'))
        if size == 1:
            # NOTE(review): BOOLEAN only accepts true/false values; confirm
            # that 1-byte ints are meant to be stored as flags.
            return "BOOLEAN"
        if size == 2:
            return "SMALLINT"
        if size > 4:
            # Values wider than 4 bytes do not fit in INTEGER.
            return "BIGINT"
        return "INTEGER"
40
41
class PgSQLDecompiler(db.SQLDecompiler):
    """SQL decompiler that renders dejavu's case-insensitive operators using
    PostgreSQL's ILIKE operator and LOWER() function."""

    def _pg_ilike(self, field, pattern):
        # Render "<field> ILIKE '<pattern>'". Any constant text inside
        # `pattern` must already have had its LIKE wildcards escaped.
        return field + " ILIKE '" + pattern + "'"

    def dejavu_icontainedby(self, op1, op2):
        if not isinstance(op1, db.ConstWrapper):
            # Field op1 in a sequence of constants (op2): compare lowercased
            # forms on both sides for a case-insensitive membership test.
            lowered = [self.adapter.coerce(item).lower()
                       for item in op2.basevalue]
            return "LOWER(%s) IN (%s)" % (op1, ", ".join(lowered))
        # Constant text op1 searched for inside field op2, so the operands
        # swap roles and op1 becomes the (wildcard-wrapped) ILIKE pattern.
        return self._pg_ilike(op2, "%" + self.adapter.escape_like(op1) + "%")

    def dejavu_istartswith(self, x, y):
        # Prefix match: escaped y followed by the "any suffix" wildcard.
        return self._pg_ilike(x, self.adapter.escape_like(y) + "%")

    def dejavu_iendswith(self, x, y):
        # Suffix match: the "any prefix" wildcard followed by escaped y.
        return self._pg_ilike(x, "%" + self.adapter.escape_like(y))

    def dejavu_ieq(self, x, y):
        # With every wildcard escaped, ILIKE acts as a case-insensitive "=".
        return self._pg_ilike(x, self.adapter.escape_like(y))

    def dejavu_year(self, x):
        # Extract the year component of the date/time expression x.
        return "date_part('year', " + x + ")"
66
67
68
class StorageManagerPgSQL(db.StorageManagerDB):
    """StoreManager to save and retrieve Units via pyPgSQL 1.35.

    Talks to PostgreSQL through pyPgSQL's low-level ``libpq`` module.
    """

    sql_name_max_length = 63
    close_connection_method = 'finish'
    decompiler = PgSQLDecompiler
    toAdapter = AdapterToPgSQL()
    typeAdapter = FieldTypeAdapterPgSQL()

    def __init__(self, name, arena, allOptions=None):
        """Configure the store from allOptions['Connect'].

        allOptions['Connect'] is a libpq connection string of
        whitespace-separated key=value pairs, e.g.
        "host=h port=p dbname=d user=u password=p options=o tty=t".
        Every pair is also stored as an attribute on self (self.dbname, ...).
        """
        # Avoid a shared mutable default argument.
        if allOptions is None:
            allOptions = {}
        db.StorageManagerDB.__init__(self, name, arena, allOptions)

        self.connstring = allOptions['Connect']
        for k, v in self._pg_parse_connstring(self.connstring):
            setattr(self, k, v)

    def _pg_parse_connstring(self, connstring):
        """Return the (key, value) pairs of a libpq connection string.

        Pairs are separated by any run of whitespace (so doubled or trailing
        spaces are harmless); each value is everything after the first '='.
        Quoted values containing spaces are not supported.
        """
        pairs = []
        for atom in connstring.split():
            k, v = atom.split("=", 1)
            pairs.append((k, v))
        return pairs

    def _pg_quote(self, value):
        """Return value as text safe inside a single-quoted SQL literal."""
        return ("%s" % (value,)).replace("'", "''")

    def sql_name(self, name, quoted=True):
        """Return the SQL identifier for name, double-quoted when `quoted`."""
        name = db.StorageManagerDB.sql_name(self, name, quoted)
        if quoted:
            name = '"' + name.replace('"', '""') + '"'
        return name

    def _get_conn(self):
        """Open a new libpq connection using self.connstring.

        Raises db.OutOfConnectionsError when libpq reports that it could not
        connect; any other DatabaseError propagates unchanged.
        """
        import sys
        try:
            return libpq.PQconnectdb(self.connstring)
        except libpq.DatabaseError:
            # sys.exc_info() works on every Python version, unlike the
            # Python-2-only "except E, x" form.
            x = sys.exc_info()[1]
            if x.args and x.args[0].startswith('could not connect'):
                raise db.OutOfConnectionsError
            raise

    def _template_conn(self):
        """Open a libpq connection to the 'template1' database.

        Uses the parameters from self.connstring, with dbname replaced by
        (or, when absent, set to) 'template1'.
        """
        atoms = []
        has_dbname = False
        for k, v in self._pg_parse_connstring(self.connstring):
            if k == 'dbname':
                v = 'template1'
                has_dbname = True
            atoms.append("%s=%s" % (k, v))
        if not has_dbname:
            atoms.append("dbname=template1")
        return libpq.PQconnectdb(" ".join(atoms))

    def version(self):
        """Return the server version reported on a template1 connection."""
        c = self._template_conn()
        try:
            return c.version
        finally:
            c.finish()

    def fetch(self, query, conn=None):
        """fetch(query, conn=None) -> rowdata, columns.

        rowdata is a list of rows, each a list of res.getvalue() values.
        columns is a list of (res.fname(i), res.ftype(i)) pairs; it is empty
        for an empty query.
        """
        res = self.execute(query, conn)
        try:
            columns = []
            if res.resultType != libpq.EMPTY_QUERY:
                for index in xrange(res.nfields):
                    columns.append((res.fname(index), res.ftype(index)))

            data = [[res.getvalue(row, col) for col in xrange(res.nfields)]
                    for row in xrange(res.ntuples)]
        finally:
            # Release the libpq result even if reading it fails.
            res.clear()

        return data, columns

    def _seq_UnitSequencerInteger(self, unit):
        """Reserve a unit using the table's SERIAL field.

        INSERTs every non-sequenced column of unit, lets the identifier take
        its nextval() default, then copies the sequence's last_value onto
        the unit's first identifier.
        """
        cls = unit.__class__
        clsname = cls.__name__
        tablename = self.table_name(clsname)

        fields = []
        values = []
        for key in cls.properties:
            typename = self.typeAdapter.coerce(cls, key)
            if 'nextval' in typename:
                # Skip this field, since we're using a sequencer
                continue
            val = self.toAdapter.coerce(getattr(unit, key))
            fields.append(self.column_name(clsname, key))
            values.append(val)

        if fields:
            sql = ('INSERT INTO %s (%s) VALUES (%s);'
                   % (str(tablename), ", ".join(fields), ", ".join(values)))
        else:
            # Only sequenced columns: "INSERT INTO t () VALUES ()" is not
            # valid PostgreSQL, so let every column take its default.
            sql = 'INSERT INTO %s DEFAULT VALUES;' % str(tablename)
        self.execute(sql)

        # Grab the new ID. This is threadsafe because db.reserve has a mutex.
        # NOTE(review): last_value is shared across sessions, so an insert
        # from another process between these two statements could be seen.
        data = self.fetch("SELECT last_value FROM %s_%s_seq;"
                          % (clsname, cls.identifiers[0]))[0]
        setattr(unit, cls.identifiers[0], data[0][0])

    #                               Schemas                               #

    def create_database(self):
        """Create the database named self.dbname (issued via template1)."""
        c = self._template_conn()
        try:
            self.execute('CREATE DATABASE %s' % self.sql_name(self.dbname), c)
        finally:
            c.finish()

    def drop_database(self):
        """Drop the database named self.dbname (issued via template1)."""
        c = self._template_conn()
        try:
            self.execute("DROP DATABASE %s;" % self.sql_name(self.dbname), c)
        finally:
            c.finish()

    def has_storage(self, cls):
        """Return True if pg_tables lists the table for cls."""
        # For some odd reason, libpq errors if you try to filter by tablename.
        data = self.fetch("SELECT tablename FROM pg_tables")[0]
        return [self.table_name(cls.__name__, quoted=False)] in data

    def create_storage(self, cls):
        """Create storage for the given class.

        Creates a sequence for every column whose type uses nextval(), then
        the table, then one index per name returned by cls.indices().
        """
        clsname = cls.__name__
        tablename = self.table_name(clsname)

        fields = []
        for key in cls.properties:
            dbtype = self.typeAdapter.coerce(cls, key)
            if 'nextval' in dbtype:
                # The column default names this sequence, so create it first.
                self.execute("CREATE SEQUENCE %s_%s_seq START %s;"
                             % (clsname, key, cls.sequencer.initial))
            fields.append('%s %s' % (self.column_name(clsname, key), dbtype))
        self.execute('CREATE TABLE %s (%s);' % (tablename, ", ".join(fields)))

        for index in cls.indices():
            i = self.table_name("i" + clsname + index)
            self.execute('CREATE INDEX %s ON %s (%s);' %
                         (i, tablename, self.column_name(clsname, index)))

    def drop_index(self, cls, name):
        """Drop each index on cls's table whose column name equals `name`."""
        # get_indices looks the table up by its on-disk relname, which is the
        # unquoted table name, not necessarily the bare class name.
        tablename = self.table_name(cls.__name__, quoted=False)
        for i in self.get_indices(tablename):
            if i.colname == name:
                self.execute('DROP INDEX %s;' % self.sql_name(i.name))

    def get_tables(self, conn=None):
        """Return a db.Table for each table outside the system schemas."""
        data = self.fetch("SELECT tablename FROM pg_tables WHERE "
                          "schemaname not in ('information_schema', 'pg_catalog')",
                          conn=conn)[0]
        return [db.Table(row[0]) for row in data]

    def get_columns(self, tablename=None, conn=None):
        """Return a db.Column for each user column of table `tablename`.

        System columns (attnum <= 0) and dropped columns are skipped, and
        columns are returned in attnum order. Each Column gets:
        - a Python type (see _pg_python_type);
        - its default expression (None for nextval() defaults);
        - hints['bytes'] for fixed-width types.
        Returns [] when no table named `tablename` exists.
        """
        data = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'"
                          % self._pg_quote(tablename), conn=conn)[0]
        if not data:
            return []
        table_OID = data[0][0]

        sql = ("SELECT attname, atttypid, attnum, attlen "
               "FROM pg_attribute WHERE attrelid = %s "
               "AND attnum > 0 AND NOT attisdropped ORDER BY attnum"
               % table_OID)
        cols = []
        for colname, type_oid, attnum, attlen in self.fetch(sql, conn=conn)[0]:
            dbtype = self._pg_python_type(type_oid, conn)
            default = self._pg_column_default(table_OID, attnum, conn)
            c = db.Column(colname, dbtype, default)
            # attlen is negative for variable-length types.
            if attlen > 0:
                c.hints['bytes'] = attlen
            cols.append(c)
        return cols

    def _pg_python_type(self, type_oid, conn=None):
        """Map a pg_type OID to a Python type.

        Returns the raw typname string for unmapped types, and None when the
        OID is not found in pg_type.
        """
        data = self.fetch("SELECT typname, typlen FROM pg_type "
                          "WHERE oid = %s" % type_oid, conn=conn)[0]
        if not data:
            return None
        typname = data[0][0]
        if typname in ('int2', 'int4'):
            return int
        if typname == 'bool':
            return bool
        if typname == 'int8':
            return long
        if typname in ('float4', 'float8', 'money'):
            return float
        if typname == 'numeric':
            return decimal.Decimal
        if typname == 'date':
            return datetime.date
        if typname in ('timestamp', 'timestamptz'):
            return datetime.datetime
        if typname in ('time', 'timetz'):
            return datetime.time
        if typname in ('char', 'varchar', 'bpchar', 'text'):
            return str
        return typname

    def _pg_column_default(self, table_OID, attnum, conn=None):
        """Return the default expression text of a column, or None.

        Sequence-backed defaults (expressions starting with "nextval(") are
        reported as None.
        """
        # pg_get_expr() replaces the pg_attrdef.adsrc column, which was not
        # kept up to date and was removed in PostgreSQL 12.
        data = self.fetch("SELECT pg_get_expr(adbin, adrelid) FROM pg_attrdef "
                          "WHERE adnum = %s AND adrelid = %s"
                          % (attnum, table_OID), conn=conn)[0]
        if not data:
            return None
        default = data[0][0]
        if default.startswith("nextval("):
            return None
        return default

    def get_indices(self, tablename, conn=None):
        """Return a db.Index for each (index, column) pair on `tablename`.

        A multi-column index yields one Index per column. Expression columns
        (indkey entry 0) have no column name and are skipped. Returns []
        when no table named `tablename` exists.
        """
        # Get the OID of the parent table.
        data = self.fetch("SELECT oid FROM pg_class WHERE relname = '%s'"
                          % self._pg_quote(tablename), conn=conn)[0]
        if not data:
            return []

        table_OID = data[0][0]
        indices = []
        data = self.fetch("SELECT pg_class.relname, indkey, indisprimary, "
                          "indisunique FROM pg_index LEFT JOIN pg_class "
                          "ON pg_index.indexrelid = pg_class.oid WHERE "
                          "pg_index.indrelid = %s" % table_OID, conn=conn)[0]
        for indexname, indkey, isprimary, isunique in data:
            # indkey is the text form of an int2vector, e.g. "1 3".
            for col in map(int, indkey.split()):
                if col == 0:
                    continue
                d = self.fetch("SELECT attname FROM pg_attribute "
                               "WHERE attrelid = %s AND attnum = %s"
                               % (table_OID, col), conn=conn)[0]
                if not d:
                    continue
                indices.append(db.Index(indexname, tablename, d[0][0],
                                        bool(isprimary), bool(isunique)))

        return indices
293
Note: See TracBrowser for help on using the browser.