Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/trunk/storage/storeodbc.py

Revision 43 (checked in by fumanchu, 8 years ago)

1. Changed UnitProperty.hints['Size'] to 'bytes'. SMs (storage managers) should now assume infinite bytes unless told otherwise.
2. Abstracted adapters into db.py.
3. Adapter coerce methods now take a coltype arg.
4. Changed safe_name functions to SM.identifier methods.
5. Bugfix: COMPARE_OP now uses op indices not values.
6. Moved len() from CALL_FUNCTION to decompiler.functions.
7. Added db.ConstWrapper to help with LOAD_CONST corner cases.
8. Added MySQL SM and test suite.

Line 
1 """This module is seriously broken; it hasn't been updated to match
2 framework redesign. But we didn't want to destroy what we learned so far.
3 """
4
5 import fixedpoint
6 import datetime
7 import pickle
8 import dbi, odbc
9
10 import dejavu
11 from dejavu import storage, codewalk
12
13
class AdapterFromODBC(storage.Adapter):
    """Transform incoming values from ODBC to Dejavu datatypes.
    
    An instance is bound to one Unit; consume() coerces a single column value
    to the Unit's declared property type and assigns it to that Unit.
    """
    
    # Ordinal (datetime.date.toordinal) of day zero for the float date values
    # decoded below. NOTE(review): this name was previously an undefined
    # global (every date/time coercion raised NameError). 1899-12-30 is the
    # OLE Automation date epoch used by Microsoft ODBC drivers and is assumed
    # here -- confirm for your driver.
    zeroHour = datetime.date(1899, 12, 30).toordinal()
    
    def __init__(self, unit):
        self.unit = unit
    
    def consume(self, key, value):
        """Coerce value to the declared type of property key; set it on the Unit."""
        expectedType = self.unit.__class__.property_type(key)
        setattr(self.unit, key, self.coerce(value, expectedType))
    
    def to_uni(self, value):
        """Return value as unicode; None passes through."""
        if value is None:
            return None
        return unicode(value)
    
    def pickle(self, value):
        """Unpickle a stored value; None passes through.
        
        SECURITY: pickle.loads can execute arbitrary code embedded in the
        data. Only read pickled columns from a database you trust.
        """
        if value is None:
            return None
        # loads() takes no protocol argument (the protocol is read from the
        # data itself); passing one raises TypeError.
        return pickle.loads(value)
    
    def _split_date_value(self, value):
        """Split a float day count into (date ordinal, seconds past midnight).
        
        Rounding happens on the total number of seconds, so float noise such
        as 43200.9999999 s becomes 43201 s instead of being truncated, and a
        value that rounds up to midnight carries into the next day.
        """
        days, seconds = divmod(int(round(float(value) * 86400)), 86400)
        return days + self.zeroHour, seconds
    
    def _time_from_seconds(self, seconds):
        """Build a datetime.time from a whole number of seconds past midnight."""
        hour, rest = divmod(seconds, 3600)
        minute, sec = divmod(rest, 60)
        return datetime.time(int(hour), int(minute), int(sec))
    
    def coerce_datetime_datetime(self, value):
        # Illegal Date/Time values will crash the app when using
        # value.Format(). Therefore, grab the value as a float day count
        # and figure the date ourselves. Use 1-second resolution only.
        if value is None:
            return None
        ordinal, seconds = self._split_date_value(value)
        return datetime.datetime.combine(datetime.date.fromordinal(ordinal),
                                         self._time_from_seconds(seconds))
    
    def coerce_datetime_date(self, value):
        # See coerce_datetime_datetime; the time-of-day part is discarded.
        if value is None:
            return None
        ordinal, seconds = self._split_date_value(value)
        return datetime.date.fromordinal(ordinal)
    
    def coerce_datetime_time(self, value):
        # See coerce_datetime_datetime; the date part is discarded.
        if value is None:
            return None
        ordinal, seconds = self._split_date_value(value)
        return self._time_from_seconds(seconds)
    
    # dict properties are stored pickled (see pickle() above).
    coerce_dict = pickle
    
    def coerce_fixedpoint_FixedPoint(self, value):
        if value is None:
            return None
        return fixedpoint.FixedPoint(value)
    
    def coerce_float(self, value):
        if value is None:
            return None
        return float(value)
    
    def coerce_int(self, value):
        if value is None:
            return None
        return int(value)
    
    def coerce_bool(self, value):
        """Return value (presumably a BIT column's 0/1) as a bool; None passes through."""
        if value is None:
            return None
        return bool(int(value))
    
    coerce_str = to_uni
    coerce_unicode = to_uni
85
86
class AdapterToODBCSQL(storage.Adapter):
    """Transform Expression values according to their type for ODBC SQL.
    
    Each coerce_* method returns the SQL literal text for one Python value.
    Dates and times use the ODBC escape syntax ({ts ...}, {d ...}, {t ...}).
    """
    
    def to_str(self, value):
        return str(value)
    
    def coerce_NoneType(self, value):
        return "Null"
    
    def coerce_bool(self, value):
        if value:
            return 'True'
        return 'False'
    
    # The date/time literals are built with explicit %-formatting rather than
    # strftime(), because Python 2's strftime() raises ValueError for years
    # before 1900. Sub-second precision is dropped, as before.
    def coerce_datetime_datetime(self, value):
        return (u"{ts '%04d-%02d-%02d %02d:%02d:%02d'}"
                % (value.year, value.month, value.day,
                   value.hour, value.minute, value.second))
    
    def coerce_datetime_date(self, value):
        return u"{d '%04d-%02d-%02d'}" % (value.year, value.month, value.day)
    
    def coerce_datetime_time(self, value):
        return u"{t '%02d:%02d:%02d'}" % (value.hour, value.minute, value.second)
    
    coerce_int = to_str
    coerce_float = to_str
    coerce_long = to_str
    # FixedPoint properties are stored in FLOAT columns; str() of a FixedPoint
    # is presumably its plain decimal form (e.g. "3.14") -- confirm.
    coerce_fixedpoint_FixedPoint = to_str
    
    def coerce_str(self, value):
        """Quote value as an SQL string literal, doubling embedded single quotes."""
        return "'" + value.replace(u"'", u"''") + "'"
    
    def coerce_tuple(self, value):
        """Render a tuple as a parenthesized, comma-separated list of literals."""
        return "(" + ", ".join([self.coerce(x) for x in value]) + ")"
    
    coerce_unicode = coerce_str
121
122
123 def _icontainedby(op1, op2, notin=False):
124     if op2.startswith("[") and op2.endswith("]"):
125         # Looking for text in a field. Use Like (reverse terms).
126         value = op2 + " Like '%" + op1[1:-1] + "%'"
127     else:
128         # Looking for field in (a, b, c)
129         value = op1 + " in " + op2
130     if notin:
131         value = "not " + value
132     return value
133
134
class ODBCSQLDecompiler(codewalk.LambdaDecompiler):
    """ODBCSQLDecompiler(expr=logic.Expression).
    
    Produce ODBC SQL from a supplied lambda of the form:
        lambda x, **kw: ...
    
    Attributes of x (or whatever the name of the first argument is) will be
    mapped to table columns. Keyword arguments should be bound to the
    Expression before calling this decompiler.
    
    The visit_* methods push and pop SQL text fragments on self.stack.
    Whenever the generated SQL may match rows that the lambda itself would
    reject, self.imperfect is set so the caller can re-check each row with
    expr.evaluate().
    """
    
    # Indexed by the COMPARE_OP argument, in the order of Python's dis.cmp_op:
    # <, <=, ==, !=, >, >=, in, not in, is, is not.
    sql_cmp_op = ('<', '<=', '=', '<>', '>', '>=', 'in', 'not in',
                  'is', 'is not')
    # NOTE(review): these keys are function objects, but visit_CALL_FUNCTION
    # looks them up with the SQL text popped from self.stack (e.g.
    # "[icontains]"). With the visitors defined here they therefore never
    # match -- confirm intent.
    functions = {dejavu.icontains: lambda x, y: x + " Like '%" + y[1:-1] + "%'",
                 dejavu.icontainedby: _icontainedby,
                 dejavu.istartswith: lambda x, y: x + " Like '" + y[1:-1] + "%'",
                 dejavu.iendswith: lambda x, y: x + " Like '%" + y[1:-1] + "'",
                 dejavu.ieq: lambda x, y: x + " = " + y,
                 }
    
    def __init__(self, expr):
        self.expr = expr
        obj = expr.func
        codewalk.LambdaDecompiler.__init__(self, obj)
    
    def code(self):
        """Walk the lambda's bytecode; return (sql_text, imperfect)."""
        self.imperfect = False
        self.walk()
        return self.stack[0], self.imperfect
    
    def visit_LOAD_GLOBAL(self, lo, hi):
        pass
    
    def visit_LOAD_FAST(self, lo, hi):
        pass
    
    def visit_LOAD_ATTR(self, lo, hi):
        # Attribute names become bracketed column references.
        self.stack.append("[" + self.co_names[lo + (hi << 8)] + "]")
    
    def visit_LOAD_CONST(self, lo, hi):
        value = self.co_consts[lo + (hi << 8)]
        self.stack.append(AdapterToODBCSQL().coerce(value))
    
    def visit_BUILD_TUPLE(self, lo, hi):
        """Replace the top N stack items with one parenthesized list."""
        # The shift must be parenthesized: "lo + hi << 8" parses as
        # (lo + hi) << 8.
        count = lo + (hi << 8)
        if count:
            # Slice (rather than pop one at a time) so items keep source order.
            terms = self.stack[-count:]
            del self.stack[-count:]
        else:
            terms = []
        self.stack.append("(" + ", ".join(terms) + ")")
    
    def visit_BUILD_LIST(self, lo, hi):
        self.visit_BUILD_TUPLE(lo, hi)
    
    def visit_CALL_FUNCTION(self, lo, hi):
        # hi keyword arguments sit above lo positional arguments on the stack,
        # each as a (key, value) pair.
        kwargs = {}
        for i in range(hi):
            val = self.stack.pop()
            key = self.stack.pop()
            kwargs[key] = val
        kwargs = [k + "=" + v for k, v in kwargs.iteritems()]
        
        args = []
        for i in range(lo):
            args.append(self.stack.pop())
        args.reverse()
        
        if kwargs:
            args += kwargs
        
        func = self.stack.pop()
        
        # Handle logic functions
        if func in self.functions:
            self.stack.append(self.functions[func](*args))
        else:
            args = ", ".join(args)
            if func == "[startswith]":
                # The column the method was called on is still on the stack.
                self.stack[-1] = self.stack[-1] + " Like '" + args[1:-1] + "%'"
                self.imperfect = True
            elif func == "[endswith]":
                self.stack[-1] = self.stack[-1] + " Like '%" + args[1:-1] + "'"
                self.imperfect = True
            else:
                self.stack.append(func + "(" + args + ")")
    
    def visit_COMPARE_OP(self, lo, hi):
        op2, op1 = self.stack.pop(), self.stack.pop()
        op = self.sql_cmp_op[lo + (hi << 8)]
        if op == 'in':
            self.stack.append(_icontainedby(op1, op2))
            self.imperfect = True
        elif op == 'not in':
            self.stack.append(_icontainedby(op1, op2, True))
            self.imperfect = True
        elif op in ('=', '<>', 'is', 'is not') and "Null" in (op1, op2):
            # None is coerced to the text "Null". In SQL "col = Null" is never
            # true, so == / is None must become "Is Null" (and != / is not
            # None "Is Not Null").
            if op1 == "Null":
                op1 = op2
            if op in ('=', 'is'):
                self.stack.append(op1 + " Is Null")
            else:
                self.stack.append(op1 + " Is Not Null")
        elif op in ('is', 'is not'):
            raise ValueError("Identity comparisons are only supported "
                             "against None, not %r" % op2)
        else:
            if op2.startswith("'") and op2.endswith("'"):
                # All ODBC comparison operators for strings are case-insensitive
                # by default. Rather than determine column-by-column which
                # might be case-sensitive, just flag them all as imperfect.
                self.imperfect = True
            self.stack.append(op1 + " " + op + " " + op2)
    
    def binary_op(self, op):
        op2, op1 = self.stack.pop(), self.stack.pop()
        self.stack.append(op1 + " " + op + " " + op2)
    
    def visit_BINARY_SUBSCR(self):
        name = self.stack.pop()
        # name, since formed in LOAD_CONST, has extraneous single-quotes.
        value = self.expr.kwargs[name[1:-1]]
        self.stack.append(AdapterToODBCSQL().coerce(value))
251
252
def safe_name(content):
    """Return content with every underscore removed.
    
    Used when deriving table and index names from class names and Unit IDs.
    """
    return u"".join(content.split(u"_"))
255
256
class StoreIteratorODBC(object):
    """Iterator for populating Units from storage."""
    
    recordset = None
    unitClass = None
    server = None
    fieldNames = None
    
    def __init__(self, store, unitClass, expr, server):
        self.store = store
        self.unitClass = unitClass
        self.expr = expr
        self.server = server
        self.colIndices = {}
        self.fieldTypes = []
        
        # imperfect is true when the SQL may return rows that expr rejects;
        # units() then re-checks each row with expr.evaluate().
        self.sql, self.imperfect = store.select(unitClass, expr)
    
    def populate_unit(self, unit, row):
        """Copy each property of unit from its column in row, then mark it clean."""
        coercer = AdapterFromODBC(unit)
        for eachKey in unit.__class__.properties():
            # ODBC column names come back lower case.
            coercer.consume(eachKey, row[self.colIndices[eachKey.lower()]])
        unit.concrete = True
        unit.cleanse()
        return True
    
    def load_data(self):
        """Run self.sql; store column names, their indices, and all rows."""
        anRS = self.store.recordset(self.sql)
        try:
            self.fieldNames = [x[0] for x in anRS.description]
            for col, name in enumerate(self.fieldNames):
                self.colIndices[name] = col
            self.data = anRS.fetchall()
        finally:
            # Every row is now in memory; release the cursor.
            anRS.close()
    
    def units(self):
        """Yield Units for the loaded rows, reusing cached Units by ID."""
        self.load_data()
        if not self.data:
            return
        server = self.server
        cache = server.cache(self.unitClass)
        for row in self.data:
            # Notice odbc field names are lower case.
            ID = unicode(row[self.colIndices[u'id']])
            # Look the ID up in the cache rather than comparing Units,
            # because the Unit may have changed its _properties since the
            # last load.
            unit = cache['ID'].get(ID, None)
            if unit is None:
                unit = self.unitClass(server.namespace)
                self.populate_unit(unit, row)
                cache.store(unit)
            else:
                unit = unit[0]
            # If our SQL is imperfect, it's OK to ask our server
            # to accept() our new Unit, but don't yield it to the
            # caller unless it passes evaluate().
            if (not self.imperfect) or self.expr.evaluate(unit):
                yield unit
315
316
class CollectionIteratorODBC(StoreIteratorODBC):
    """Iterator for populating Unit Collections from storage."""
    
    # Set by StorageManagerODBC.loader(), which also passes the same manager
    # as self.store; load_collection() falls back to self.store when unset.
    storageManager = None
    
    def load_collection(self, unit):
        """Load the member IDs of a collection Unit.
        
        Reads the ID column of the table "<prefix>__<unit.ID without
        underscores>" and sets unit[<ID as unicode>] = None for each row.
        """
        sm = self.storageManager
        if sm is None:
            sm = self.store
        rsource = (u"SELECT ID FROM %s__%s" %
                   (sm.prefix, safe_name(unit.ID)))
        dataRS = sm.recordset(rsource)
        try:
            # Column names are the same for every row: compute them once.
            fieldNames = [x[0] for x in dataRS.description]
            for data in iter(dataRS.fetchone, None):
                datadict = dict(zip(fieldNames, data))
                # ODBC field names are lower case.
                unit[unicode(datadict[u'id'])] = None
        finally:
            dataRS.close()
    
    def units(self):
        """Yield collection Units, loading members for newly created ones."""
        self.load_data()
        if not self.data:
            return
        server = self.server
        cache = server.cache(self.unitClass)
        for row in self.data:
            # Notice odbc field names are lower case.
            ID = unicode(row[self.colIndices[u'id']])
            # Look the ID up in the cache rather than comparing Units,
            # because the Unit may have changed its _properties since the
            # last load.
            unit = cache['ID'].get(ID, None)
            if unit is None:
                unit = self.unitClass(server.namespace)
                self.populate_unit(unit, row)
                self.load_collection(unit)
                cache.store(unit)
            else:
                unit = unit[0]
            # If our SQL is imperfect, it's OK to ask our server
            # to accept() our new Unit, but don't yield it to the
            # caller unless it passes evaluate().
            if (not self.imperfect) or self.expr.evaluate(unit):
                yield unit
360
361
362 savecoercer = AdapterToODBCSQL()
363 class StorageManagerODBC(storage.StorageManager):
364     """StoreManager to save and retrieve Dejavu Units via ODBC."""
365    
366     connection = None
367     prefix = 'djv'
368    
369     def __init__(self, allOptions):
370         self.connection = None
371         try:
372             self.connect(allOptions['Connect'])
373         except KeyError:
374             pass
375        
376         self.prefix = allOptions.get(u'Prefix', u"djv")
377    
378     def __del__(self):
379         if self.connection is not None:
380             self.connection.close()
381    
382     def connect(self, connectString):
383         self.connection = odbc.odbc(connectString)
384    
385     def recordset(self, aQuery):
386         anRS = self.connection.cursor()
387 ##        try:
388         anRS.execute(aQuery)
389 ##        except:
390 ##            raise storage.StorageError(aQuery)
391         return anRS
392    
393     def select(self, unitClass, expr):
394         sql = u"SELECT * FROM [%s]" % (self.prefix + safe_name(unitClass.__name__))
395         w, i = self.where(expr)
396         sql += w
397         return sql, i
398    
399     def where(self, expr):
400         atoms, i = ODBCSQLDecompiler(expr).code()
401         if len(atoms) > 0:
402             return (u" WHERE " + atoms, i)
403         else:
404             return (u"", i)
405    
406     def execute(self, aQuery):
407         cur = self.connection.cursor()
408         cur.execute(aQuery)
409    
410     def loader(self, server, unitClass, expr):
411         if unitClass.__name__ == u'UnitCollection':
412             aLoader = CollectionIteratorODBC
413             aLoader.storageManager = self
414         else:
415             aLoader = StoreIteratorODBC
416         return aLoader(self, unitClass, expr, server)
417    
418     def save(self, unit, forceSave=False):
419         """Update the recordset from the Unit's data.
420         
421         Notice in particular that we do not use the auto-number or
422         sequence generation capabilities within some databases, etc.
423         The ID should be already supplied by the UnitServer(s).
424         """
425         if unit.dirty() or forceSave:
426             # Use an UPDATE command.
427             SETAtoms = [u"%s = %s" % (eachKey, savecoercer.coerce(getattr(unit, eachKey)))
428                         for eachKey in unit.__class__.properties()]
429             tablename = self.prefix + safe_name(unit.__class__.__name__)
430             if len(SETAtoms) > 0:
431                 data = self.recordset("SELECT * FROM %s WHERE ID = '%s';"
432                                       % (tablename, unit.ID)).fetchone()
433                 updateStatement = (u"UPDATE %s SET %s WHERE ID = '%s';"
434                                    % (tablename, u", ".join(SETAtoms), unit.ID))
435                 if data:
436                     self.execute(updateStatement)
437                 else:
438                     # Create a row for the unit.
439                     # Use an INSERT (not a cursor) for better performance.
440                     insertStatement = (u"INSERT INTO %s (ID) VALUES ('%s');"
441                                        % (tablename, unit.ID))
442                     self.execute(insertStatement)
443                     self.execute(updateStatement)
444             else:
445                 # These Units have no data other than IDs.
446                 # Create a row for the unit.
447                 # Use an INSERT (not a cursor) for better performance.
448                 insertStatement = (u"INSERT INTO %s%s (ID) VALUES ('%s');"
449                                    % (tablename, unit.ID))
450                 self.execute(insertStatement)
451             unit.cleanse()
452         return True
453    
454     def max_id(self, unitClass):
455         top1 = u"SELECT TOP 1 ID FROM [%s%s] ORDER BY Val(ID) DESC;"
456         recordsource = top1 % (self.prefix, safe_name(unitClass.__name__))
457         anRS = self.recordset(recordsource)
458         data = anRS.fetchone()
459         if data:
460             fieldNames = [x[0] for x in anRS.description]
461             # ODBC field names are lower case.
462             id_index = fieldNames.index('id')
463             val = long(data[id_index])
464             return val
465         else:
466             return 0
467    
468     def _create_str_storage(unitClass, key):
469         """This basic string handler does not know anything about the size
470         limitations of the particular database. You should use one of the
471         subclasses for your particular database if you need storage for
472         strings over 255 characters."""
473         prop = getattr(unitClass, key)
474         size = prop.hints.get(u'bytes', '255')
475         return u"VARCHAR(%s)" % size
476    
477     createCoercions = {datetime.datetime: lambda x, y: u"TIMESTAMP",
478                        datetime.date: lambda x, y: u"DATE",
479                        datetime.time: lambda x, y: u"TIME",
480                        str: _create_str_storage,
481                        unicode: _create_str_storage,
482                        dict: _create_str_storage,
483                        fixedpoint.FixedPoint: lambda x, y: u"FLOAT",
484                        int: lambda x, y: u"INTEGER",
485                        bool: lambda x, y: u"BIT",
486                        }
487    
488     def create_storage(self, unitClass):
489         fields = []
490         for eachKey in unitClass.properties():
491             eachType = unitClass.property_type(eachKey)
492             aType = self.createCoercions[eachType](unitClass, eachKey)
493             fields.append(u"[%s] %s" % (eachKey, aType))
494         indices = [x + " ASC" for x in unitClass.indices()]
495        
496         tablename = self.prefix + safe_name(unitClass.__name__)
497         createStatement = u"CREATE TABLE [%s] (%s)" % (tablename, ", ".join(fields))
498         try:
499             self.execute(createStatement)
500         except Exception, x:
501             x.args += (createStatement, )
502             raise x
503        
504         for index in indices:
505             indexStatement = (u"CREATE INDEX [%si%s%s] ON [%s%s] (%s)"
506                               % (self.prefix, safe_name(unitClass.__name__), safe_name(index),
507                                  self.prefix, safe_name(unitClass.__name__), index))
508             try:
509                 self.execute(indexStatement)
510             except Exception, x:
511                 x.args += (indexStatement, )
512                 raise x
513        
514         return True
515
Note: See TracBrowser for help on using the browser.