Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

Changeset 44

Show
Ignore:
Timestamp:
12/31/04 03:43:46
Author:
fumanchu
Message:

1. Changed all storage tests to use common zoo_fixture.py.
2. Added SQLite Storage Manager.
3. Fixed icontains, icontainedby.
4. Added storage.Version class for comparing version strings.
5. Moved decompiler code into common db.SQLDecompiler.
6. SQLDecompiler.code() sets imperfect, but doesn't return it anymore.
7. Fixed bug in storeado expanded columns (no repr of None).
8. All db SM's now have a _get_conn method in prep for a common db.StorageManagerDB class.
9. All db SM's now have create_database and drop_database.
10. Added separate decom for mysql 4.1.1. Autodetects version.
11. Bugfix: SM.shutdown() methods weren't closing current thread's connection (if threaded).
12. Started cleaning up storeodbc.
13. Bugfix in storeshelve if Unit.ID was non-string.
14. Changed Zoo.Animal.Options to .PreviousZoos? (a list).

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/__init__.py

    r43 r44  
    908908def icontains(a, b): 
    909909    """Case-insensitive test b in a. Note the operand order.""" 
    910     if a is None or b is None: 
    911         return False 
    912     return b.lower() in a.lower() 
     910    return icontainedby(b, a) 
    913911 
    914912def icontainedby(a, b): 
    915913    """Case-insensitive test a in b. Note the operand order.""" 
    916     if a is None or b is None: 
    917         return False 
    918     return a.lower() in b.lower() 
     914    if isinstance(b, basestring): 
     915        # Looking for text in a string. 
     916        if a is None or b is None: 
     917            return False 
     918        return a.lower() in b.lower() 
     919    else: 
     920        # Looking for field in (a, b, c). 
     921        # Force all args to lowercase for case-insensitive comparison. 
     922        if a is None or not b: 
     923            return False 
     924        return a.lower() in [x.lower() for x in b] 
    919925 
    920926def istartswith(a, b): 
  • trunk/containers.py

    r43 r44  
    369369     
    370370    def add(self, **facets): 
     371        """add(**facets) -> add a row to the Prism.""" 
    371372        for name, row in self.facets.iteritems(): 
    372373            row.append(facets.get(name)) 
    373374     
    374375    def row_number(self, **facet): 
     376        """row_number(**facet) -> row number, given one facet's key+value.""" 
    375377        k, v = facet.popitem() 
    376378        f = self.facets[k] 
     
    381383     
    382384    def row(self, **facet): 
     385        """row(**facet) -> all facets of a row given one facet's key+value.""" 
    383386        number = self.row_number(**facet) 
    384387        return dict([(k, v[number]) for k, v in self.facets.iteritems()]) 
  • trunk/storage/__init__.py

    r43 r44  
    11"""Storage Managers for Dejavu.""" 
    22 
     3import re 
    34import datetime 
    45import threading 
     
    310311        return xform(value) 
    311312 
     313 
     314class Version(object): 
     315     
     316    def __init__(self, atoms): 
     317        if isinstance(atoms, basestring): 
     318            self.atoms = re.split(r'\W', atoms) 
     319        else: 
     320            self.atoms = [str(x) for x in atoms] 
     321     
     322    def __cmp__(self, other): 
     323        index = 0 
     324        while index < len(self.atoms) and index < len(other.atoms): 
     325            mine, theirs = self.atoms[index], other.atoms[index] 
     326            if mine.isdigit() and theirs.isdigit(): 
     327                mine, theirs = int(mine), int(theirs) 
     328            if mine < theirs: 
     329                return -1 
     330            if mine > theirs: 
     331                return 1 
     332            index += 1 
     333        if index < len(other.atoms): 
     334            return -1 
     335        if index < len(self.atoms): 
     336            return 1 
     337        return 0 
     338 
  • trunk/storage/db.py

    r43 r44  
    1515""" 
    1616 
     17from types import FunctionType 
     18from dejavu import codewalk 
     19 
    1720try: 
    1821    import cPickle as pickle 
     
    98101        # precision, but... meh. That's why they're called "hints". 
    99102        return u"NUMERIC" 
     103    coerce_decimal_Decimal = coerce_decimal 
    100104     
    101105    def coerce_fixedpoint_FixedPoint(self, cls, key): 
     
    292296 
    293297 
     298# -------------------------- SQL DECOMPILATION -------------------------- # 
     299 
    294300class ConstWrapper(str): 
    295     """Wraps a constant for use in a decompiler stack. 
     301    """Wraps a constant for use in SQLDecompiler's stack. 
    296302     
    297303    When we hit LOAD_CONST while decompiling, we occasionally need to keep 
     
    314320 
    315321 
     322# Stack sentinels 
     323table_arg = object() 
     324kw_arg = object() 
     325# cannot_represent exists so that a portion of an Expression can be 
     326# labeled imperfect. For example, the function dejavu.iscurrentweek 
     327# rarely has an SQL equivalent. All Units (which match the rest of  
     328# the Expression) will be recalled; they can then be compared in 
     329# expr.evaluate(unit). 
     330cannot_represent = object() 
     331 
     332 
     333class SQLDecompiler(codewalk.LambdaDecompiler): 
     334    """SQLDecompiler(tablename, expr, adapter=AdapterToSQL). 
     335     
     336    Produce SQL from a supplied Expression object, with a lambda of the form: 
     337        lambda x, **kw: ... 
     338     
     339    Attributes of x (or whatever the name of the first argument is) will be 
     340    mapped to table columns. Keyword arguments should be bound using 
     341    Expression.bind_args before calling this decompiler. 
     342    """ 
     343     
     344    # Some constants are function or class objects, 
     345    # which should not be coerced. 
     346    no_coerce = (FunctionType, 
     347                 type, 
     348                 type(len),       # <type 'builtin_function_or_method'> 
     349                 ) 
     350     
     351    sql_cmp_op = ('<', '<=', '=', '!=', '>', '>=', 'in', 'not in') 
     352     
     353    def __init__(self, tablename, expr, adapter=AdapterToSQL): 
     354        self.tablename = tablename 
     355        self.expr = expr 
     356        self.adapter = adapter() 
     357        obj = expr.func 
     358        codewalk.LambdaDecompiler.__init__(self, obj) 
     359     
     360    def code(self): 
     361        self.imperfect = False 
     362        self.walk() 
     363        # After walk(), self.stack should be reduced to a single string, 
     364        # which is the SQL representation of our Expression. 
     365        result = self.stack[0] 
     366        if result is cannot_represent: 
     367            # The entire expression could not be evaluated. 
     368            result = self.adapter.coerce(True) 
     369        return result 
     370     
     371    def visit_target(self, terms): 
     372        """A target is an AND or OR test.""" 
     373        comp = self.stack.pop() 
     374        while terms: 
     375            term, operation = terms.pop() 
     376            if term is not cannot_represent: 
     377                if comp is cannot_represent: 
     378                    comp = term 
     379                else: 
     380                    comp = "(%s) %s (%s)" % (term, operation.upper(), comp) 
     381        self.stack.append(comp) 
     382     
     383    def visit_LOAD_DEREF(self, lo, hi): 
     384        raise ValueError("Illegal reference found in %s." % self.expr) 
     385     
     386    def visit_LOAD_GLOBAL(self, lo, hi): 
     387        raise ValueError("Illegal global found in %s." % self.expr) 
     388     
     389    def visit_LOAD_FAST(self, lo, hi): 
     390        if lo + (hi << 8) < self.co_argcount: 
     391            self.stack.append(table_arg) 
     392        else: 
     393            self.stack.append(kw_arg) 
     394     
     395    def visit_LOAD_ATTR(self, lo, hi): 
     396        name = self.co_names[lo + (hi << 8)] 
     397        tos = self.stack.pop() 
     398        if tos is table_arg: 
     399            # Call another function to make subclassing easier. 
     400            self.stack.append(self.column_name(name)) 
     401        else: 
     402            # tos.name will reference an attribute of the tos object. 
     403            # Stick the tos and name in a tuple for later processing. 
     404            self.stack.append((tos, name)) 
     405     
     406    def visit_LOAD_CONST(self, lo, hi): 
     407        val = self.co_consts[lo + (hi << 8)] 
     408        if not isinstance(val, self.no_coerce): 
     409            val = ConstWrapper(val, self.adapter.coerce(val)) 
     410        self.stack.append(val) 
     411     
     412    def visit_BUILD_TUPLE(self, lo, hi): 
     413        terms = ", ".join([self.stack.pop() for i in range(lo + hi << 8)]) 
     414        self.stack.append("(" + terms + ")") 
     415     
     416    visit_BUILD_LIST = visit_BUILD_TUPLE 
     417     
     418    def visit_CALL_FUNCTION(self, lo, hi): 
     419        kwargs = {} 
     420        for i in range(hi): 
     421            val = self.stack.pop() 
     422            key = self.stack.pop() 
     423            kwargs[key] = val 
     424        kwargs = [k + "=" + v for k, v in kwargs.iteritems()] 
     425         
     426        args = [] 
     427        for i in range(lo): 
     428            arg = self.stack.pop() 
     429            args.append(arg) 
     430        args.reverse() 
     431         
     432        if kwargs: 
     433            args += kwargs 
     434         
     435        func = self.stack.pop() 
     436         
     437        # Handle function objects. 
     438        if isinstance(func, tuple): 
     439            tos, name = func 
     440            dispatch = getattr(self, "attr_" + name, None) 
     441            if dispatch: 
     442                self.stack.append(dispatch(tos, *args)) 
     443                return 
     444        else: 
     445            funcname = func.__module__ + "_" + func.__name__ 
     446            funcname = funcname.replace(".", "_") 
     447            if funcname.startswith("_"): 
     448                funcname = "func" + funcname 
     449            dispatch = getattr(self, funcname, None) 
     450            if dispatch: 
     451                self.stack.append(dispatch(*args)) 
     452                return 
     453         
     454        if self.stack: 
     455            self.stack[-1] = cannot_represent 
     456        else: 
     457            self.stack = [cannot_represent] 
     458        self.imperfect = True 
     459     
     460    def visit_COMPARE_OP(self, lo, hi): 
     461        op2, op1 = self.stack.pop(), self.stack.pop() 
     462        if op1 is cannot_represent or op2 is cannot_represent: 
     463            self.stack.append(cannot_represent) 
     464        else: 
     465            op = lo + (hi << 8) 
     466            if op in (6, 7):     # in, not in 
     467                value = self.containedby(op1, op2) 
     468                if op == 7: 
     469                    value = "NOT " + value 
     470                self.stack.append(value) 
     471            elif op1 == 'NULL': 
     472                if op == 2: 
     473                    self.stack.append(op2 + " IS NULL") 
     474                elif op == 3: 
     475                    self.stack.append(op2 + " IS NOT NULL") 
     476                else: 
     477                    raise ValueError("Non-equality Null comparisons not allowed.") 
     478            elif op2 == 'NULL': 
     479                if op == 2: 
     480                    self.stack.append(op1 + " IS NULL") 
     481                elif op == 3: 
     482                    self.stack.append(op1 + " IS NOT NULL") 
     483                else: 
     484                    raise ValueError("Non-equality Null comparisons not allowed.") 
     485            else: 
     486                # Comparison operators for strings are case-sensitive in PG et al. 
     487                self.stack.append(op1 + " " + self.sql_cmp_op[op] + " " + op2) 
     488     
     489    def binary_op(self, op): 
     490        op2, op1 = self.stack.pop(), self.stack.pop() 
     491        self.stack.append(op1 + " " + op + " " + op2) 
     492     
     493    def visit_BINARY_SUBSCR(self): 
     494        # The only BINARY_SUBSCR used in Expressions should be kwargs[key]. 
     495        name = self.stack.pop() 
     496        tos = self.stack.pop() 
     497        if tos is not kw_arg: 
     498            raise ValueError("Subscript %s of %s object not allowed." 
     499                             % (name, tos)) 
     500        # name, since formed in LOAD_CONST, may have extraneous quotes. 
     501        name = name.strip("'\"") 
     502        value = self.expr.kwargs[name] 
     503        value = self.adapter.coerce(value) 
     504        self.stack.append(value) 
     505     
     506    def visit_UNARY_NOT(self): 
     507        op = self.stack.pop() 
     508        if op is cannot_represent: 
     509            self.stack.append(cannot_represent) 
     510        else: 
     511            self.stack.append("NOT (" + op + ")") 
     512     
     513     
     514    def column_name(self, name): 
     515        # This is valid SQL for PostgreSQL only and should be overridden. 
     516        return '"%s"."%s"' % (self.tablename, name) 
     517     
     518    # --------------------------- Dispatchees --------------------------- # 
     519     
     520    def attr_startswith(self, tos, arg): 
     521        return tos + " LIKE '" + arg.strip("'\"") + "%'" 
     522     
     523    def attr_endswith(self, tos, arg): 
     524        return tos + " LIKE '%" + arg.strip("'\"") + "'" 
     525     
     526    def containedby(self, op1, op2): 
     527        if isinstance(op1, ConstWrapper): 
     528            # Looking for text in a field. Use Like (reverse terms). 
     529            return op2 + " LIKE '%" + op1.strip("'\"") + "%'" 
     530        else: 
     531            # Looking for field in (a, b, c) 
     532            atoms = [self.adapter.coerce(x) for x in op2.basevalue] 
     533            return op1 + " IN (" + ", ".join(atoms) + ")" 
     534     
     535    def dejavu_icontainedby(self, op1, op2): 
     536        if isinstance(op1, ConstWrapper): 
     537            # Looking for text in a field. Use Like (reverse terms). 
     538            return "LOWER(" + op2 + ") LIKE '%" + op1.strip("'\"").lower() + "%'" 
     539        else: 
     540            # Looking for field in (a, b, c). 
     541            # Force all args to lowercase for case-insensitive comparison. 
     542            atoms = [self.adapter.coerce(x).lower() for x in op2.basevalue] 
     543            return "LOWER(%s) IN (%s)" % (op1, ", ".join(atoms)) 
     544     
     545    def dejavu_icontains(self, x, y): 
     546        return self.dejavu_icontainedby(y, x) 
     547     
     548    def dejavu_istartswith(self, x, y): 
     549        y = y.strip("'\"") 
     550        return "LOWER(" + x + ") LIKE '" + y + "%'" 
     551     
     552    def dejavu_iendswith(self, x, y): 
     553        y = y.strip("'\"") 
     554        return "LOWER(" + x + ") LIKE '%" + y + "'" 
     555     
     556    def dejavu_ieq(self, x, y): 
     557        return "LOWER(" + x + ") = LOWER(" + y + ")" 
     558     
     559    def dejavu_now(self): 
     560        return "NOW()" 
     561     
     562    def dejavu_today(self): 
     563        return "CURRENT_DATE" 
     564     
     565    def dejavu_year(self, x): 
     566        return "YEAR(" + x + ")" 
     567     
     568    def func__builtin___len(self, x): 
     569        return "LENGTH(" + x + ")" 
     570 
  • trunk/storage/storeado.py

    r43 r44  
    2020 
    2121import dejavu 
    22 from dejavu import storage, codewalk, logic 
     22from dejavu import storage, logic 
    2323from dejavu.storage import db 
    2424import recur 
     
    191191 
    192192 
    193 def icontainedby(op1, op2, notin=False): 
    194     if isinstance(op1, db.ConstWrapper): 
    195         # Looking for text in a field. Use Like (reverse terms). 
    196         # LIKE is case-insensitive in MS SQL Server. 
    197         value = op2 + " Like '%" + op1[1:-1] + "%'" 
    198     else: 
    199         # Looking for field in (a, b, c) 
    200         value = op1 + " in (" + ", ".join([AdapterToADOSQL().coerce(x) 
    201                                            for x in op2.basevalue]) + ")" 
    202     if notin: 
    203         value = "not " + value 
    204     return value 
    205  
    206  
    207 # Stack sentinels 
    208 cannot_represent = object() 
    209 table_arg = object() 
    210 kw_arg = object() 
    211  
    212 class ADOSQLDecompiler(codewalk.LambdaDecompiler): 
    213     """ADOSQLDecompiler(store, unitClass, expr, adapter=AdapterToADOSQL()). 
    214      
    215     Produce SQL from a supplied Expression object, with a lambda of the form: 
    216         lambda x, **kw: ... 
    217      
    218     Attributes of x (or whatever the name of the first argument is) will be 
    219     mapped to table columns. Keyword arguments should be bound to the 
    220     Expression before calling this decompiler. 
    221     """ 
    222      
    223     sql_cmp_op = ('<', '<=', '=', '!=', '>', '>=', 'in', 'not in') 
    224     functions = {dejavu.icontains: lambda x, y: x + " Like '%" + y[1:-1] + "%'", 
    225                  dejavu.icontainedby: icontainedby, 
    226                  dejavu.istartswith: lambda x, y: x + " Like '" + y[1:-1] + "%'", 
    227                  dejavu.iendswith: lambda x, y: x + " Like '%" + y[1:-1] + "'", 
    228                  dejavu.ieq: lambda x, y: x + " = " + y, 
    229                  dejavu.now: lambda: "getdate()", 
    230                  dejavu.today: lambda: "DATEADD(dd, DATEDIFF(dd,0,getdate()), 0)", 
    231                  dejavu.year: lambda x: "YEAR(" + x + ")", 
    232                  len: lambda x: "Len(" + x + ")", 
    233                  } 
    234      
    235     def __init__(self, store, unitClass, expr, adapter=AdapterToADOSQL()): 
    236         self.tablename = store.prefix + unitClass.__name__ 
    237         self.expr = expr 
    238         self.adapter = adapter 
    239         obj = expr.func 
    240         codewalk.LambdaDecompiler.__init__(self, obj) 
    241      
    242     def code(self): 
    243         self.imperfect = False 
    244         self.walk() 
    245         result = self.stack[0] 
    246         if result is cannot_represent: 
    247             result = 'TRUE' 
    248         return result, self.imperfect 
    249      
    250     def visit_target(self, terms): 
    251         """A target is an AND or OR test.""" 
    252         comp = self.stack.pop() 
    253         while terms: 
    254             term, operation = terms.pop() 
    255             # All this checking of cannot_represent is done so that a 
    256             # function can be labeled imperfect. For example, the function 
    257             # dejavu.iscurrentweek has no ADO SQL equivalent. All Units 
    258             # (which match the rest of the Expression) will be recalled. 
    259             # They can then be compared in expr.evaluate(unit). 
    260             if term is not cannot_represent: 
    261                 if comp is cannot_represent: 
    262                     comp = term 
    263                 else: 
    264                     comp = "(%s) %s (%s)" % (term, operation, comp) 
    265         self.stack.append(comp) 
    266      
    267     def visit_LOAD_DEREF(self, lo, hi): 
    268         raise ValueError("Illegal reference found in %s." % self.expr) 
    269      
    270     def visit_LOAD_GLOBAL(self, lo, hi): 
    271         raise ValueError("Illegal global found in %s." % self.expr) 
    272      
    273     def visit_LOAD_FAST(self, lo, hi): 
    274         if lo + (hi << 8) < self.co_argcount: 
    275             self.stack.append(table_arg) 
    276         else: 
    277             self.stack.append(kw_arg) 
    278      
    279     def visit_LOAD_ATTR(self, lo, hi): 
    280         name = self.co_names[lo + (hi << 8)] 
    281         tos = self.stack.pop() 
    282         if tos is table_arg: 
    283             self.stack.append("[%s].[%s]" % (self.tablename, name)) 
    284         else: 
    285             self.stack.append((tos, name)) 
    286      
    287     def visit_LOAD_CONST(self, lo, hi): 
    288         val = self.co_consts[lo + (hi << 8)] 
    289         # Some constants are function or class objects, 
    290         # which should not be coerced. 
    291         no_coerce = (FunctionType, type, 
    292                      type(len),       # <type 'builtin_function_or_method'> 
    293                      ) 
    294         if not isinstance(val, no_coerce): 
    295             val = db.ConstWrapper(val, self.adapter.coerce(val)) 
    296         self.stack.append(val) 
    297      
    298     def visit_BUILD_TUPLE(self, lo, hi): 
    299         terms = ", ".join([self.stack.pop() for i in range(lo + hi << 8)]) 
    300         self.stack.append("(" + terms + ")") 
    301      
    302     visit_BUILD_LIST = visit_BUILD_TUPLE 
    303      
    304     def visit_CALL_FUNCTION(self, lo, hi): 
    305         kwargs = {} 
    306         for i in range(hi): 
    307             val = self.stack.pop() 
    308             key = self.stack.pop() 
    309             kwargs[key] = val 
    310         kwargs = [k + "=" + v for k, v in kwargs.iteritems()] 
    311          
    312         args = [] 
    313         for i in range(lo): 
    314             arg = self.stack.pop() 
    315             args.append(arg) 
    316         args.reverse() 
    317          
    318         if kwargs: 
    319             args += kwargs 
    320          
    321         func = self.stack.pop() 
    322          
    323         # Handle function objects. 
    324         if func in self.functions: 
    325             self.stack.append(self.functions[func](*args)) 
    326         else: 
    327             if isinstance(func, tuple): 
    328                 tos, func = func 
    329                 if func == "startswith": 
    330                     self.stack.append(tos + " Like '" + args[0][1:-1] + "%'") 
    331                     self.imperfect = True 
    332                     return 
    333                 elif func == "endswith": 
    334                     self.stack.append(tos + " Like '%" + args[0][1:-1] + "'") 
    335                     self.imperfect = True 
    336                     return 
    337              
    338             if self.stack: 
    339                 self.stack[-1] = cannot_represent 
    340             else: 
    341                 self.stack = [cannot_represent] 
    342             self.imperfect = True 
     193class ADOSQLDecompiler(db.SQLDecompiler): 
    343194     
    344195    def visit_COMPARE_OP(self, lo, hi): 
     
    346197        op = lo + (hi << 8) 
    347198        if op in (6, 7):     # in, not in 
    348             if isinstance(op1, db.ConstWrapper): 
    349                 # Looking for text in a field. Use Like (reverse terms). 
    350                 # LIKE is case-insensitive in MS SQL Server (and there 
    351                 # doesn't seem to be a way around it). Mark imperfect, 
    352                 # but use the imperfect, insensitive filter at least. 
    353                 value = op2 + " Like '%" + op1[1:-1] + "%'" 
    354             else: 
    355                 # Looking for field in (a, b, c) 
    356                 value = op1 + " in (" + ", ".join([AdapterToADOSQL().coerce(x) 
    357                                                    for x in op2.basevalue]) + ")" 
     199            # Looking for text in a field. Use Like (reverse terms). 
     200            # LIKE is case-insensitive in MS SQL Server (and there 
     201            # doesn't seem to be a way around it). Use icontainedby 
     202            # and just mark imperfect. 
     203            value = self.dejavu_icontainedby(op1, op2) 
    358204            if op == 7: 
    359                 value = "not " + value 
     205                value = "NOT " + value 
    360206            self.stack.append(value) 
    361207            self.imperfect = True 
     
    385231            self.stack.append(op1 + " " + self.sql_cmp_op[op] + " " + op2) 
    386232     
    387     def binary_op(self, op): 
    388         op2, op1 = self.stack.pop(), self.stack.pop() 
    389         self.stack.append(op1 + " " + op + " " + op2) 
    390      
    391     def visit_BINARY_SUBSCR(self): 
    392         # The only BINARY_SUBSCR used in Expressions should be kwargs[key]. 
    393         # TODO: provide string slicing? 
    394         name = self.stack.pop() 
    395         tos = self.stack.pop() 
    396         if tos is not kw_arg: 
    397             raise ValueError(tos, name) 
    398         # name, since formed in LOAD_CONST, has extraneous single-quotes. 
    399         name = name[1:-1] 
    400         value = self.expr.kwargs[name] 
    401         value = self.adapter.coerce(value) 
    402         self.stack.append(value) 
    403      
    404     def visit_UNARY_NOT(self): 
    405         op = self.stack.pop() 
    406         if op is cannot_represent: 
    407             self.stack.append(cannot_represent) 
    408         else: 
    409             self.stack.append("not (" + op + ")") 
     233    def column_name(self, name): 
     234        # This is valid SQL for PostgreSQL only and should be overridden. 
     235        return '[%s].[%s]' % (self.tablename, name) 
     236     
     237    # --------------------------- Dispatchees --------------------------- # 
     238     
     239    def attr_startswith(self, tos, arg): 
     240        self.imperfect = True 
     241        return tos + " LIKE '" + arg.strip("'\"") + "%'" 
     242     
     243    def attr_endswith(self, tos, arg): 
     244        self.imperfect = True 
     245        return tos + " LIKE '%" + arg.strip("'\"") + "'" 
     246     
     247    def containedby(self, op1, op2): 
     248        self.imperfect = True 
     249        if isinstance(op1, ConstWrapper): 
     250            # Looking for text in a field. Use Like (reverse terms). 
     251            return op2 + " LIKE '%" + op1.strip("'\"") + "%'" 
     252        else: 
     253            # Looking for field in (a, b, c) 
     254            atoms = [self.adapter.coerce(x) for x in op2.basevalue] 
     255            return op1 + " IN (" + ", ".join(atoms) + ")" 
     256     
     257    def dejavu_icontainedby(self, op1, op2): 
     258        if isinstance(op1, db.ConstWrapper): 
     259            # Looking for text in a field. Use Like (reverse terms). 
     260            # LIKE is already case-insensitive in MS SQL Server; 
     261            # so don't use LOWER(). 
     262            value = op2 + " LIKE '%" + op1.strip("'\"") + "%'" 
     263        else: 
     264            # Looking for field in (a, b, c) 
     265            atoms = [self.adapter.coerce(x) for x in op2.basevalue] 
     266            value = op1 + " IN (" + ", ".join(atoms) + ")" 
     267        return value 
     268     
     269    def dejavu_istartswith(self, x, y): 
     270        # Like is already case-insensitive in ADO; so don't use LOWER(). 
     271        y = y.strip("'\"") 
     272        return x + " LIKE '" + y + "%'" 
     273     
     274    def dejavu_iendswith(self, x, y): 
     275        # Like is already case-insensitive in ADO; so don't use LOWER(). 
     276        y = y.strip("'\"") 
     277        return x + " LIKE '%" + y + "'" 
     278     
     279    def dejavu_ieq(self, x, y): 
     280        # = is already case-insensitive in ADO. 
     281        return x + " = " + y 
     282     
     283    def dejavu_now(self): 
     284        return "getdate()" 
     285     
     286    def dejavu_today(self): 
     287        return "DATEADD(dd, DATEDIFF(dd,0,getdate()), 0)" 
     288     
     289    def func__builtin___len(self, x): 
     290        return "Len(" + x + ")" 
     291 
    410292 
    411293 
     
    451333                            col = self.colIndices['ID'] 
    452334                            ID = self.data[col][row] 
    453                             rs = s.recordset(u"SELECT EXPVAL FROM %s" % 
    454                                              s.identifier(s.prefix, "_", 
    455                                                           clsname, "_", 
    456                                                           ID, "_", key)) 
     335                            i = s.identifier(s.prefix, "_", clsname, 
     336                                             "_", ID, "_", key) 
     337                            rs = s.recordset(u"SELECT EXPVAL FROM %s" % i) 
    457338                        except pywintypes.com_error, x: 
    458                             # This usually occurs because the parent Unit 
     339                            # This usually occurs because 1) the parent Unit 
    459340                            # was reserved but no table yet made for these 
    460                             # expanded values. This is OK. TODO: trap this 
    461                             # more specifically by examining the errmsg. 
    462                             values = [] 
     341                            # expanded values, or 2) no table was made 
     342                            # because no values were present. This is OK. 
     343                            # TODO: trap this more specifically by examining 
     344                            # the errmsg. 
     345                            values = None 
    463346                        else: 
    464347                            if rs.BOF and rs.EOF: 
    465348                                values = [] 
    466349                            else: 
    467                                 values = [pickle.loads(str(x)) for x in rs.GetRows()[0]] 
     350                                values = [pickle.loads(str(x)) 
     351                                          for x in rs.GetRows()[0]] 
    468352                            rs.Close() 
    469                         expectedType = unit.__class__.property_type(key) 
    470                         values = expectedType(values) 
     353                            expectedType = unit.__class__.property_type(key) 
     354                            values = expectedType(values) 
    471355                        # Set the attribute directly to avoid __set__ overhead. 
    472356                        unit._properties[key] = values 
     
    561445        self.connstring = allOptions[u'Connect'] 
    562446        self.CreateIfMissing = allOptions.get(u'Create If Missing', '') 
    563         self.DBName = allOptions.get(u'DBName', None) 
    564         if allOptions.get(u'Threaded', ''): 
    565             self.threaded = True 
    566             self._connection = None 
    567         else: 
    568             try: 
    569                 self._connection = win32com.client.Dispatch(r'ADODB.Connection') 
    570                 self._connection.Open(self.connstring) 
    571             except pywintypes.com_error, x: 
    572                 if x.args[2][5] == -2147467259 and self.CreateIfMissing: 
    573                     self.create_database() 
    574                     self._connection.Open(self.connstring) 
    575                 else: 
    576                     raise 
     447         
     448        atoms = dict([pair.lower().split("=", 1) 
     449                      for pair in self.connstring.split(";") 
     450                      if pair]) 
     451        self.DBName = atoms.get(u'data source', None) 
     452         
     453        self.threaded = bool(allOptions.get(u'Threaded', '')) 
     454        self._connection = None 
    577455         
    578456        self.prefix = allOptions.get(u'Prefix', u"djv") 
     
    598476            self._connection.Close() 
    599477     
     478    def _get_conn(self): 
     479        conn = win32com.client.Dispatch(r'ADODB.Connection') 
     480        try: 
     481            conn.Open(self.connstring) 
     482        except pywintypes.com_error, x: 
     483            if x.args[2][5] == -2147467259 and self.CreateIfMissing: 
     484                self.create_database() 
     485                conn.Open(self.connstring) 
     486            else: 
     487                raise 
     488        return conn 
     489     
    600490    def connection(self): 
    601491        if self.threaded: 
    602492            t = threading.currentThread() 
    603             if not hasattr(t, 'SMADOconn'): 
    604                 t.SMADOconn = win32com.client.Dispatch(r'ADODB.Connection') 
    605             if t.SMADOconn.State == adStateClosed: 
    606                 try: 
    607                     t.SMADOconn.Open(self.connstring) 
    608                 except pywintypes.com_error, x: 
    609                     if x.args[2][5] == -2147467259 and self.CreateIfMissing: 
    610                         self.create_database() 
    611                         t.SMADOconn.Open(self.connstring) 
    612                     else: 
    613                         raise 
    614             return t.SMADOconn 
    615         else: 
     493            if not hasattr(t, 'dejavu_storage_connection'): 
     494                t.dejavu_storage_connection = self._get_conn() 
     495            return t.dejavu_storage_connection 
     496        else: 
     497            if self._connection is None: 
     498                self._connection = self._get_conn() 
    616499            return self._connection 
    617500     
     
    704587     
    705588    def where(self, cls, expr): 
    706         return self.decompiler(self, cls, expr).code() 
     589        decom = self.decompiler(self.prefix + cls.__name__, expr, AdapterToADOSQL) 
     590        return decom.code(), decom.imperfect 
    707591     
    708592    def execute(self, aQuery, conn=None): 
     
    796680            pass 
    797681         
    798         # Ugly, ugly hack to get NTEXT or MEMO as appropriate. The point 
    799         # is, we want a large text field so we can pickle each item. 
    800         ftype = self.createAdapter.coerce_list(unitcls, key) 
    801         self.execute(u"CREATE TABLE %s (EXPVAL %s);" % (table, ftype), conn) 
    802          
    803         ins = u"INSERT INTO " + table + " (EXPVAL) VALUES ('%s');" 
    804682        val = getattr(unit, key) 
    805         if val: 
     683        if val is None: 
     684            # Don't create a new table at all. This will signal 
     685            # our iterator to set the attribute to None on load. 
     686            pass 
     687        else: 
     688            # Ugly, ugly hack to get NTEXT or MEMO as appropriate. The point 
     689            # is, we want a large text field so we can pickle each item. 
     690            ftype = self.createAdapter.coerce_list(unitcls, key) 
     691            self.execute(u"CREATE TABLE %s (EXPVAL %s);" % (table, ftype), conn) 
     692             
     693            ins = u"INSERT INTO " + table + " (EXPVAL) VALUES ('%s');" 
    806694            for v in val: 
    807695                # Create a row for the unit. 
     
    906794        # This method hasn't been tested yet for SQL server. 
    907795        adoconn = win32com.client.Dispatch(r'ADODB.Connection') 
    908         adoconn.Open(self.connstring) 
     796        atoms = dict([pair.upper().split("=", 1) 
     797                      for pair in self.connstring.split(";") 
     798                      if pair]) 
     799        adoconn.Open("Provider=SQLOLEDB;Data Source=(local);User ID=%s;Password=%s;" 
     800                     % (atoms.get('USER ID') or atoms.get('UID'), 
     801                        atoms.get('PASSWORD') or atoms.get('PWD')) 
     802                     ) 
    909803        adoconn.Execute("CREATE DATABASE %s" % self.DBName) 
    910804 
     
    919813class ADOSQLDecompiler_MSAccess(ADOSQLDecompiler): 
    920814    sql_cmp_op = ('<', '<=', '=', '<>', '>', '>=', 'in', 'not in') 
    921     functions = {dejavu.icontains: lambda x, y: x + " Like '%" + y[1:-1] + "%'", 
    922                  dejavu.icontainedby: icontainedby, 
    923                  dejavu.istartswith: lambda x, y: x + " Like '" + y[1:-1] + "%'", 
    924                  dejavu.iendswith: lambda x, y: x + " Like '%" + y[1:-1] + "'", 
    925                  dejavu.ieq: lambda x, y: x + " = " + y, 
    926                  dejavu.now: lambda: "Now()", 
    927                  dejavu.today: lambda: "DateValue(Now())", 
    928                  dejavu.year: lambda x: "Year(" + x + ")", 
    929                  len: lambda x: "Len(" + x + ")", 
    930                  } 
     815     
     816    def dejavu_now(self): 
     817        return "Now()" 
     818     
     819    def dejavu_today(self): 
     820        return "DateValue(Now())" 
     821     
     822    def dejavu_year(self, x): 
     823        return "Year(" + x + ")" 
    931824 
    932825 
     
    952845     
    953846    def create_database(self): 
     847        # By not providing an Engine Type, it defaults to 5 = Access 2000. 
    954848        cat = win32com.client.Dispatch(r'ADOX.Catalog') 
    955849        cat.Create(self.connstring) 
     850     
     851    def drop_database(self): 
     852        atoms = dict([pair.upper().split("=", 1) 
     853                      for pair in self.connstring.split(";") 
     854                      if pair]) 
     855        dbname = atoms.get("DATA SOURCE") or atoms.get("DATA SOURCE NAME") 
     856        import os 
     857        # This should accept relative or absolute paths 
     858        os.remove(dbname) 
    956859 
    957860 
  • trunk/storage/storemysql.py

    r43 r44  
    1515import threading 
    1616import datetime 
    17 try: 
    18     import cPickle as pickle 
    19 except ImportError: 
    20     import pickle 
    21 from types import FunctionType 
    22  
    23 try: 
    24     import fixedpoint 
    25 except ImportError: 
    26     pass 
    2717 
    2818import dejavu 
    29 from dejavu import storage, codewalk, logic 
     19from dejavu import storage, logic 
    3020from dejavu.storage import db 
    31 import recur 
    3221 
    3322 
     
    3625 
    3726 
    38 def containedby(op1, op2, notin=False): 
    39     if isinstance(op1, db.ConstWrapper): 
    40         # Looking for text in a field. Use Like (reverse terms). 
    41         value = op2 + " LIKE '%" + op1[1:-1] + "%'" 
    42     else: 
    43         # Looking for field in (a, b, c) 
    44         atoms = [AdapterToMySQL.coerce(x) for x in op2.basevalue] 
    45         value = op1 + " IN (" + ", ".join(atoms) + ")" 
    46     if notin: 
    47         value = "NOT " + value 
    48     return value 
    49  
    50 def icontainedby(op1, op2, notin=False): 
    51     if isinstance(op1, db.ConstWrapper): 
    52         # Looking for text in a field. Use Like (reverse terms). 
    53         value = "LOWER(" + op2 + ") LIKE '%" + op1[1:-1].lower() + "%'" 
    54     else: 
    55         # Looking for field in (a, b, c) 
    56         atoms = [AdapterToMySQL.coerce(x).lower() for x in op2.basevalue] 
    57         value = "LOWER(" + op1 + ") IN (" + ", ".join(atoms) + ")" 
    58     if notin: 
    59         value = "NOT " + value 
    60     return value 
    61  
    62  
    63 # Stack sentinels 
    64 cannot_represent = object() 
    65 table_arg = object() 
    66 kw_arg = object() 
    67  
    68 class MySQLDecompiler(codewalk.LambdaDecompiler): 
    69     """MySQLDecompiler(store, unitClass, expr, adapter=AdapterToMySQL). 
    70      
    71     Produce SQL from a supplied Expression object, with a lambda of the form: 
    72         lambda x, **kw: ... 
    73      
    74     Attributes of x (or whatever the name of the first argument is) will be 
    75     mapped to table columns. Keyword arguments should be bound to the 
    76     Expression before calling this decompiler. 
    77     """ 
    78      
    79     sql_cmp_op = ('<', '<=', '=', '!=', '>', '>=', 'in', 'not in') 
    80     functions = {dejavu.icontains: lambda x, y: "LOWER(" + x + ") LIKE '%" + y[1:-1] + "%'", 
    81                  dejavu.icontainedby: icontainedby, 
    82                  dejavu.istartswith: lambda x, y: "LOWER(" + x + ") LIKE '" + y[1:-1] + "%'", 
    83                  dejavu.iendswith: lambda x, y: "LOWER(" + x + ") LIKE '%" + y[1:-1] + "'", 
    84                  # This is a test of ILIKE with no wildcards, 
    85                  # to see if it behaves like ieq. 
    86                  dejavu.ieq: lambda x, y: "LOWER(" + x + ") = LOWER(" + y + ")", 
    87                  dejavu.now: lambda: "now()", 
    88                  dejavu.today: lambda: "CURDATE", 
    89                  dejavu.year: lambda x: "YEAR(" + x + ")", 
    90                  len: lambda x: "LENGTH(" + x + ")", 
    91                  } 
    92      
    93     def __init__(self, store, unitClass, expr, adapter=AdapterToMySQL): 
    94         self.tablename = store.identifier(store.prefix, unitClass.__name__) 
    95         self.expr = expr 
    96         self.adapter = adapter 
    97         obj = expr.func 
    98         codewalk.LambdaDecompiler.__init__(self, obj) 
    99      
    100     def code(self): 
    101         self.imperfect = False 
    102         self.walk() 
    103         result = self.stack[0] 
    104         if result is cannot_represent: 
    105             result = 'TRUE' 
    106         return result, self.imperfect 
    107      
    108     def visit_target(self, terms): 
    109         """A target is an AND or OR test.""" 
    110         comp = self.stack.pop() 
    111         while terms: 
    112             term, operation = terms.pop() 
    113             # All this checking of cannot_represent is done so that a 
    114             # function can be labeled imperfect. For example, the function 
    115             # dejavu.iscurrentweek has no MySQL equivalent. All Units 
    116             # (which match the rest of the Expression) will be recalled. 
    117             # They can then be compared in expr.evaluate(unit). 
    118             if term is not cannot_represent: 
    119                 if comp is cannot_represent: 
    120                     comp = term 
    121                 else: 
    122                     comp = "(%s) %s (%s)" % (term, operation.upper(), comp) 
    123         self.stack.append(comp) 
    124      
    125     def visit_LOAD_DEREF(self, lo, hi): 
    126         raise ValueError("Illegal reference found in %s." % self.expr) 
    127      
    128     def visit_LOAD_GLOBAL(self, lo, hi): 
    129         raise ValueError("Illegal global found in %s." % self.expr) 
    130      
    131     def visit_LOAD_FAST(self, lo, hi): 
    132         if lo + (hi << 8) < self.co_argcount: 
    133             self.stack.append(table_arg) 
    134         else: 
    135             self.stack.append(kw_arg) 
    136      
    137     def visit_LOAD_ATTR(self, lo, hi): 
    138         name = self.co_names[lo + (hi << 8)] 
    139         tos = self.stack.pop() 
    140         if tos is table_arg: 
    141             self.stack.append('%s.`%s`' % (self.tablename, name.lower())) 
    142         else: 
    143             self.stack.append((tos, name)) 
    144      
    145     def visit_LOAD_CONST(self, lo, hi): 
    146         val = self.co_consts[lo + (hi << 8)] 
    147         # Some constants are function or class objects, 
    148         # which should not be coerced. 
    149         no_coerce = (FunctionType, type, 
    150                      type(len),       # <type 'builtin_function_or_method'> 
    151                      ) 
    152         if not isinstance(val, no_coerce): 
    153             val = db.ConstWrapper(val, self.adapter.coerce(val)) 
    154         self.stack.append(val) 
    155      
    156     def visit_BUILD_TUPLE(self, lo, hi): 
    157         terms = ", ".join([self.stack.pop() for i in range(lo + hi << 8)]) 
    158         self.stack.append("(" + terms + ")") 
    159      
    160     visit_BUILD_LIST = visit_BUILD_TUPLE 
    161      
    162     def visit_CALL_FUNCTION(self, lo, hi): 
    163         kwargs = {} 
    164         for i in range(hi): 
    165             val = self.stack.pop() 
    166             key = self.stack.pop() 
    167             kwargs[key] = val 
    168         kwargs = [k + "=" + v for k, v in kwargs.iteritems()] 
    169          
    170         args = [] 
    171         for i in range(lo): 
    172             arg = self.stack.pop() 
    173             args.append(arg) 
    174         args.reverse() 
    175          
    176         if kwargs: 
    177             args += kwargs 
    178          
    179         func = self.stack.pop() 
    180          
    181         # Handle function objects. 
    182         if func in self.functions: 
    183             self.stack.append(self.functions[func](*args)) 
    184         else: 
    185             if isinstance(func, tuple): 
    186                 tos, func = func 
    187                 if func == "startswith": 
    188                     self.stack.append(tos + " LIKE '" + args[0][1:-1] + "%'") 
    189                     return 
    190                 elif func == "endswith": 
    191                     self.stack.append(tos + " LIKE '%" + args[0][1:-1] + "'") 
    192                     return 
    193                 return 
    194              
    195             if self.stack: 
    196                 self.stack[-1] = cannot_represent 
    197             else: 
    198                 self.stack = [cannot_represent] 
    199             self.imperfect = True 
    200      
    201     def visit_COMPARE_OP(self, lo, hi): 
    202         op2, op1 = self.stack.pop(), self.stack.pop() 
    203         op = lo + (hi << 8) 
    204         if op in (6, 7):     # in, not in 
    205             self.stack.append(containedby(op1, op2, op == 7)) 
    206             self.imperfect = True 
    207         elif op1 == 'NULL': 
    208             if op == 2: 
    209                 self.stack.append(op2 + " IS NULL") 
    210             elif op == 3: 
    211                 self.stack.append(op2 + " IS NOT NULL") 
    212             else: 
    213                 raise ValueError("Non-equality Null comparisons not allowed.") 
    214         elif op2 == 'NULL': 
    215             if op == 2: 
    216                 self.stack.append(op1 + " IS NULL") 
    217             elif op == 3: 
    218                 self.stack.append(op1 + " IS NOT NULL") 
    219             else: 
    220                 raise ValueError("Non-equality Null comparisons not allowed.") 
    221         else: 
    222             # Comparison operators for strings are case-sensitive in MySQL. 
    223             self.stack.append(op1 + " " + self.sql_cmp_op[op] + " " + op2) 
    224      
    225     def binary_op(self, op): 
    226         op2, op1 = self.stack.pop(), self.stack.pop() 
    227         self.stack.append(op1 + " " + op + " " + op2) 
    228      
    229     def visit_BINARY_SUBSCR(self): 
    230         # The only BINARY_SUBSCR used in Expressions should be kwargs[key]. 
    231         name = self.stack.pop() 
    232         tos = self.stack.pop() 
    233         if tos is not kw_arg: 
    234             raise ValueError(tos, name) 
    235         # name, since formed in LOAD_CONST, has extraneous single-quotes. 
    236         name = name[1:-1] 
    237         value = self.expr.kwargs[name] 
    238         value = self.adapter.coerce(value) 
    239         self.stack.append(value) 
    240      
    241     def visit_UNARY_NOT(self): 
    242         op = self.stack.pop() 
    243         if op is cannot_represent: 
    244             # Usually as a result of has(farClassName). 
    245             self.stack.append(cannot_represent) 
    246         else: 
    247             self.stack.append("NOT (" + op + ")") 
     27class MySQLDecompiler(db.SQLDecompiler): 
     28     
     29    def column_name(self, name): 
     30        # MySQL forces lowercase column names. 
     31        return '`%s`.`%s`' % (self.tablename, name.lower()) 
     32     
     33    # --------------------------- Dispatchees --------------------------- # 
     34     
     35    def dejavu_today(self): 
     36        return "CURDATE()" 
     37 
     38 
     39class MySQLDecompiler411(MySQLDecompiler): 
     40    # Before MySQL 4.1.1, BINARY comparisons could use UPPER() 
     41    # or LOWER() to perform case-insensitive comparisons. Newer 
     42    # versions must use CONVERT() to obtain a case-sensitive 
     43    # encoding, like utf8. 
     44     
     45    def dejavu_icontainedby(self, op1, op2): 
     46        if isinstance(op1, db.ConstWrapper): 
     47            # Looking for text in a field. Use Like (reverse terms). 
     48            return "CONVERT("+ op2 + " USING utf8) LIKE '%" + op1.strip("'\"") + "%'" 
     49        else: 
     50            # Looking for field in (a, b, c). 
     51            atoms = [self.adapter.coerce(x) for x in op2.basevalue] 
     52            return "CONVERT(%s USING utf8) IN (%s)" % (op1, ", ".join(atoms)) 
     53     
     54    def dejavu_istartswith(self, x, y): 
     55        y = y.strip("'\"") 
     56        return "CONVERT(" + x + " USING utf8) LIKE '" + y + "%'" 
     57     
     58    def dejavu_iendswith(self, x, y): 
     59        y = y.strip("'\"") 
     60        return "CONVERT(" + x + " USING utf8) LIKE '%" + y + "'" 
     61     
     62    def dejavu_ieq(self, x, y): 
     63        return "CONVERT(" + x + " USING utf8) = " + y 
    24864 
    24965 
     
    303119        return u"LONGBLOB" 
    304120     
    305     def coerce_datetime_datetime(self, cls, key): return u"DATETIME" 
     121    def coerce_datetime_datetime(self, cls, key): 
     122        return u"DATETIME" 
    306123 
    307124 
     
    309126    """StoreManager to save and retrieve Units via _mysql .""" 
    310127     
    311     decompiler = MySQLDecompiler 
    312128    createAdapter = FieldTypeAdapterMySQL() 
    313129    threaded = False 
     
    325141         
    326142        self.CreateIfMissing = allOptions.get(u'Create If Missing', '') 
    327         if allOptions.get(u'Threaded', '1'): 
    328             self.threaded = True 
    329             self._connection = None 
    330         else: 
    331             try: 
    332                 self._connection = _mysql.connect(**self.connargs) 
    333             except Exception, x: 
    334                 if False and self.CreateIfMissing: 
    335                     self.create_database() 
    336                     self._connection = _mysql.connect(**self.connargs) 
    337                 else: 
    338                     raise 
     143        self.threaded = bool(allOptions.get(u'Threaded', '1')) 
     144        self._connection = None 
    339145         
    340146        self.prefix = allOptions.get(u'Prefix', u"djv") 
    341147        self.reserve_lock = threading.Lock() 
     148         
     149        self.decompiler = MySQLDecompiler 
     150        # Try to get the version string from MySQL, to see if we need 
     151        # a different decompiler. 
     152        try: 
     153            res = self.execute("SELECT VERSION();") 
     154            if res.num_rows(): 
     155                version = storage.Version(res.fetch_row(1, 0)[0][0]) 
     156                if version > storage.Version("4.1.1"): 
     157                    self.decompiler = MySQLDecompiler411 
     158        except: 
     159            pass 
    342160     
    343161    def identifier(self, *atoms): 
     
    354172     
    355173    def shutdown(self): 
    356         if self._connection is not None: 
    357             self._connection.close() 
     174        if self.threaded: 
     175            t = threading.currentThread() 
     176            conn = getattr(t, "dejavu_storage_connection", None) 
     177            if conn is not None: 
     178                conn.close() 
     179        else: 
     180            if self._connection is not None: 
     181                self._connection.close() 
     182     
     183    def _get_conn(self): 
     184        try: 
     185            conn = _mysql.connect(**self.connargs) 
     186        except _mysql.OperationalError, x: 
     187            if self.CreateIfMissing: 
     188                self.create_database() 
     189                conn = _mysql.connect(**self.connargs) 
     190            else: 
     191                raise 
     192        return conn 
    358193     
    359194    def connection(self): 
    360195        if self.threaded: 
    361196            t = threading.currentThread() 
    362             if not hasattr(t, 'SMMySQLconn'): 
    363                 try: 
    364                     t.SMMySQLconn = _mysql.connect(**self.connargs) 
    365                 except _mysql.OperationalError, x: 
    366                     if False and self.CreateIfMissing: 
    367                         self.create_database() 
    368                         t.SMMySQLconn = _mysql.connect(**self.connargs) 
    369                     else: 
    370                         raise 
    371             return t.SMMySQLconn 
    372         else: 
     197            if not hasattr(t, 'dejavu_storage_connection'): 
     198                t.dejavu_storage_connection = self._get_conn() 
     199            return t.dejavu_storage_connection 
     200        else: 
     201            if self._connection is None: 
     202                self._connection = self._get_conn() 
    373203            return self._connection 
    374204     
     
    377207        tmplconn['db'] = '' 
    378208        conn = _mysql.connect(**tmplconn) 
    379         self.execute('CREATE DATABASE %s' % 
     209        self.execute('CREATE DATABASE %s;' % 
    380210                     self.identifier(self.connargs['db']), conn) 
    381211        conn.close() 
     212     
     213    def drop_database(self): 
     214        self.execute("DROP DATABASE %s;" % 
     215                     self.identifier(self.connargs['db'])) 
    382216     
    383217    def select(self, unitClass, expr, distinct_fields=None): 
     
    398232     
    399233    def where(self, cls, expr): 
    400         return self.decompiler(self, cls, expr).code() 
     234        tablename = self.prefix + cls.__name__ 
     235        decom = self.decompiler(tablename.lower(), expr) 
     236        return decom.code(), decom.imperfect 
    401237     
    402238    def execute(self, query, conn=None): 
  • trunk/storage/storeodbc.py

    r43 r44  
    33""" 
    44 
    5 import fixedpoint 
     5import dbi, odbc 
     6 
     7import warnings 
     8import threading 
    69import datetime 
    7 import pickle 
    8 import dbi, odbc 
     10 
     11try: 
     12    import fixedpoint 
     13except ImportError: 
     14    pass 
    915 
    1016import dejavu 
    11 from dejavu import storage, codewalk 
    12  
    13  
    14 class AdapterFromODBC(storage.Adapter): 
    15     """Transform incoming values from ODBC to Dejavu datatypes.""" 
    16      
    17     def __init__(self, unit): 
    18         self.unit = unit 
    19      
    20     def consume(self, key, value): 
    21         expectedType = self.unit.__class__.property_type(key) 
    22         setattr(self.unit, key, self.coerce(value, expectedType)) 
    23      
    24     def to_uni(self, value): 
    25         if value is None: 
    26             return None 
    27         return unicode(value) 
    28      
    29     def pickle(self, value): 
    30         return pickle.loads(value, 2) 
    31      
    32     def coerce_datetime_datetime(self, value): 
    33         # Illegal Date/Time values will crash the 
    34         # app when using value.Format(). Therefore, 
    35         # grab the value and figure the date ourselves. 
    36         # Use 1-second resolution only. 
    37         if value is None: 
    38             return None 
    39         else: 
    40             aDate, aTime = divmod(float(value), 1) 
    41             aDate = datetime.date.fromordinal(int(aDate) + zeroHour) 
    42             hour, min = divmod(86400 * aTime, 3600) 
    43             min, sec = divmod(min, 60) 
    44             aTime = datetime.time(int(hour), int(min), int(sec)) 
    45             return datetime.datetime.combine(aDate, aTime) 
    46      
    47     def coerce_datetime_date(self, value): 
    48         # See coerce_datetime 
    49         if value is None: 
    50             return None 
    51         else: 
    52             aDate, aTime = divmod(float(value), 1) 
    53             return datetime.date.fromordinal(int(aDate) + zeroHour) 
    54      
    55     def coerce_datetime_time(self, value): 
    56         # See coerce_datetime 
    57         if value is None: 
    58             return None 
    59         else: 
    60             aDate, aTime = divmod(float(value), 1) 
    61             hour, min = divmod(86400 * aTime, 3600) 
    62             min, sec = divmod(min, 60) 
    63             return datetime.time(int(hour), int(min), int(sec)) 
    64      
    65     coerce_dict = pickle 
    66      
    67     def coerce_fixedpoint_FixedPoint(self, value): 
    68         if value is None: 
    69             return None 
    70         return fixedpoint.FixedPoint(value) 
    71      
    72     def coerce_float(self, value): 
    73         if value is None: 
    74             return None 
    75         return float(value) 
    76      
    77     def coerce_int(self, value): 
    78         if value is None: 
    79             return None 
    80         return int(value) 
    81     coerce_bool = coerce_int 
    82      
    83     coerce_str = to_uni 
    84     coerce_unicode = to_uni 
    85  
    86  
    87 class AdapterToODBCSQL(storage.Adapter): 
     17from dejavu import storage, logic 
     18from dejavu.storage import db 
     19 
     20 
     21AdapterToPgSQL = db.AdapterToSQL() 
     22AdapterFromPg = db.AdapterFromDB 
     23 
     24 
     25class AdapterToODBCSQL(db.AdapterToSQL): 
    8826    """Transform Expression values according to their type for ODBC SQL.""" 
    89      
    90     def to_str(self, value): 
    91         return str(value) 
    92      
    93     def coerce_NoneType(self, value): 
    94         return "Null" 
    95      
    96     def coerce_bool(self, value): 
    97         if value: 
    98             return 'True' 
    99         return 'False' 
    10027     
    10128    def coerce_datetime_datetime(self, value): 
     
    10734    def coerce_datetime_time(self, value): 
    10835        return u"{t '%s'}" % value.strftime('%H:%M:%S') 
    109      
    110     coerce_int = to_str 
    111     coerce_float = to_str 
    112     coerce_long = to_str 
    113      
    114     def coerce_str(self, value): 
    115         return "'" + value.replace(u"'", u"''") + "'" 
    116      
    117     def coerce_tuple(self, value): 
    118         return "(" + ", ".join([self.coerce(x) for x in value]) + ")" 
    119      
    120     coerce_unicode = coerce_str 
    121  
    122  
    123 def _icontainedby(op1, op2, notin=False): 
    124     if op2.startswith("[") and op2.endswith("]"): 
    125         # Looking for text in a field. Use Like (reverse terms). 
    126         value = op2 + " Like '%" + op1[1:-1] + "%'" 
    127     else: 
    128         # Looking for field in (a, b, c) 
    129         value = op1 + " in " + op2 
    130     if notin: 
    131         value = "not " + value 
    132     return value 
    133  
    134  
    135 class ODBCSQLDecompiler(codewalk.LambdaDecompiler): 
    136     """ODBCSQLDecompiler(expr=logic.Expression). 
    137      
    138     Produce ODBC SQL from a supplied lambda of the form: 
    139         lambda x, **kw: ... 
    140      
    141     Attributes of x (or whatever the name of the first argument is) will be 
    142     mapped to table columns. Keyword arguments should be bound to the 
    143     Expression before calling this decompiler. 
    144     """ 
    145      
    146     sql_cmp_op = ('<', '<=', '=', '<>', '>', '>=', 'in', 'not in') 
    147     functions = {dejavu.icontains: lambda x, y: x + " Like '%" + y[1:-1] + "%'", 
    148                  dejavu.icontainedby: _icontainedby, 
    149                  dejavu.istartswith: lambda x, y: x + " Like '" + y[1:-1] + "%'", 
    150                  dejavu.iendswith: lambda x, y: x + " Like '%" + y[1:-1] + "'", 
    151                  dejavu.ieq: lambda x, y: x + " = " + y, 
    152                  } 
    153      
    154     def __init__(self, expr): 
    155         self.expr = expr 
    156         obj = expr.func 
    157         codewalk.LambdaDecompiler.__init__(self, obj) 
    158      
    159     def code(self): 
    160         self.imperfect = False 
    161         self.walk() 
    162         return self.stack[0], self.imperfect 
    163      
    164     def visit_LOAD_GLOBAL(self, lo, hi): 
    165         pass 
    166      
    167     def visit_LOAD_FAST(self, lo, hi): 
    168         pass 
    169      
    170     def visit_LOAD_ATTR(self, lo, hi): 
    171         self.stack.append("[" + self.co_names[lo + (hi << 8)] + "]") 
    172      
    173     def visit_LOAD_CONST(self, lo, hi): 
    174         value = self.co_consts[lo + (hi << 8)] 
    175 ##        # Handle logic functions 
    176 ##        try: 
    177 ##            is_logic_func = (value.__module__ == 'logic') 
    178 ##        except AttributeError: 
    179 ##            is_logic_func = False 
    180 ##        if not is_logic_func: 
    181         value = AdapterToODBCSQL().coerce(value) 
    182         self.stack.append(value) 
    183      
    184     def visit_BUILD_TUPLE(self, lo, hi): 
    185         terms = ", ".join([self.stack.pop() for i in range(lo + hi << 8)]) 
    186         self.stack.append("(" + terms + ")") 
    187      
    188     def visit_BUILD_LIST(self, lo, hi): 
    189         self.visit_BUILD_TUPLE(lo, hi) 
    190      
    191     def visit_CALL_FUNCTION(self, lo, hi): 
    192         kwargs = {} 
    193         for i in range(hi): 
    194             val = self.stack.pop() 
    195             key = self.stack.pop() 
    196             kwargs[key] = val 
    197         kwargs = [k + "=" + v for k, v in kwargs.iteritems()] 
    198          
    199         args = [] 
    200         for i in range(lo): 
    201             arg = self.stack.pop() 
    202             args.append(arg) 
    203         args.reverse() 
    204          
    205         if kwargs: 
    206             args += kwargs 
    207          
    208         func = self.stack.pop() 
    209          
    210         # Handle logic functions 
    211         if func in self.functions: 
    212             self.stack.append(self.functions[func](*args)) 
    213         else: 
    214             args = ", ".join(args) 
    215             if func == "[startswith]": 
    216                 self.stack[-1] = self.stack[-1] + " Like '" + args[1:-1] + "%'" 
    217                 self.imperfect = True 
    218             elif func == "[endswith]": 
    219                 self.stack[-1] = self.stack[-1] + " Like '%" + args[1:-1] + "'" 
    220                 self.imperfect = True 
    221             else: 
    222                 self.stack.append(func + "(" + args + ")") 
    223      
    224     def visit_COMPARE_OP(self, lo, hi): 
    225         op2, op1 = self.stack.pop(), self.stack.pop() 
    226         op = self.sql_cmp_op[lo + (hi << 8)] 
    227         if op == 'in': 
    228             self.stack.append(_icontainedby(op1, op2)) 
    229             self.imperfect = True 
    230         elif op == 'not in': 
    231             self.stack.append(_icontainedby(op1, op2, True)) 
    232             self.imperfect = True 
    233         else: 
    234             if op2.startswith("'") and op2.endswith("'"): 
    235                 # All ODBC comparison operators for strings are case-insensitive 
    236                 # by default. Rather than determine column-by-column which 
    237                 # might be case-sensitive, just flag them all as imperfect. 
    238                 self.imperfect = True 
    239             self.stack.append(op1 + " " + op + " " + op2) 
    240      
    241     def binary_op(self, op): 
    242         op2, op1 = self.stack.pop(), self.stack.pop() 
    243         self.stack.append(op1 + " " + op + " " + op2) 
    244      
    245     def visit_BINARY_SUBSCR(self): 
    246         name = self.stack.pop() 
    247         # name, since formed in LOAD_CONST, has extraneous single-quotes. 
    248         value = self.expr.kwargs[name[1:-1]] 
    249         value = AdapterToODBCSQL().coerce(value) 
    250         self.stack.append(value) 
    251  
    252  
    253 def safe_name(content): 
    254     return content.replace(u"_", u"") 
     36 
     37 
     38dbi_datetype = type(dbi.dbiDate(0)) 
     39 
     40class AdapterFromODBC(db.AdapterFromDB): 
     41    """Transform incoming values from ODBC to Dejavu datatypes.""" 
     42     
     43    def coerce_datetime_datetime(self, value, coltype): 
     44        if isinstance(value, dbi_datetype): 
     45            return datetime.datetime.utcfromtimestamp(int(value)) 
     46     
     47    def coerce_datetime_date(self, value, coltype): 
     48        if isinstance(value, dbi_datetype): 
     49            return datetime.datetime.utcfromtimestamp(int(value)).date() 
     50     
     51    def coerce_datetime_time(self, value, coltype): 
    25552 
    25653 
     
    25855    """Iterator for populating Units from storage.""" 
    25956     
    260     recordset = None 
    261     unitClass = None 
    262     server = None 
    263     fieldNames = None 
    264      
    265     def __init__(self, store, unitClass, expr, server): 
     57    def __init__(self, store, unitClass, expr): 
    26658        self.store  = store 
    26759        self.unitClass = unitClass 
    26860        self.expr = expr 
    269         self.server = server 
    270         self.colIndices = {} 
    271         self.fieldTypes = [] 
    272          
    27361        self.sql, self.imperfect = store.select(unitClass, expr) 
    27462     
    275     def populate_unit(self, unit, row): 
    276         coercer = AdapterFromODBC(unit) 
    277         for eachKey in unit.__class__.properties(): 
    278             coercer.consume(eachKey, row[self.colIndices[eachKey.lower()]]) 
    279         unit.concrete = True 
    280         unit.cleanse() 
    281         return True 
    282      
     63    def units(self): 
     64        s = self.store 
     65         
     66        res = s.execute(self.sql) 
     67        if res.num_rows(): 
     68            columns = {} 
     69            for index in xrange(res.nfields): 
     70                columns[res.fname(index)] = (index, res.ftype(index)) 
     71             
     72            for row in xrange(res.ntuples): 
     73                unit = self.unitClass() 
     74                coercer = AdapterFromPg(unit) 
     75                for key in unit.__class__.properties(): 
     76                    index, ftype = columns[key] 
     77                    value = res.getvalue(row, index) 
     78                    try: 
     79                        coercer.consume(key, value, ftype) 
     80                    except Exception, x: 
     81                        x.args += (key, ftype, value) 
     82                        raise x 
     83                # If our SQL is imperfect, don't yield it to the 
     84                # caller unless it passes evaluate(). 
     85                if (not self.imperfect) or self.expr.evaluate(unit): 
     86                    yield unit 
     87        res.clear() 
     88 
    28389    def load_data(self): 
    28490        anRS = self.store.recordset(self.sql) 
     
    305111                    unit = self.unitClass(server.namespace) 
    306112                    self.populate_unit(unit, row) 
    307                     cache.store(unit) 
    308                 else: 
    309                     unit = unit[0] 
    310                 # If our SQL is imperfect, it's OK to ask our server 
    311                 # to accept() our new Unit, but don't yield it to the 
    312                 # caller unless it passes evaluate(). 
    313                 if (not self.imperfect) or self.expr.evaluate(unit): 
    314                     yield unit 
    315  
    316  
    317 class CollectionIteratorODBC(StoreIteratorODBC): 
    318     """Iterator for populating Unit Collections from storage.""" 
    319      
    320     storageManager = None 
    321      
    322     def load_collection(self, unit): 
    323         # Grab the data dictionary (list of Unit ID's) 
    324         rsource = (u"SELECT ID FROM %s__%s" %  
    325                   (self.storageManager.prefix, safe_name(unit.ID))) 
    326         dataRS = self.storageManager.recordset(rsource) 
    327         while 1: 
    328             data = dataRS.fetchone() 
    329             if data is None: 
    330                 break 
    331             fieldNames = [x[0] for x in dataRS.description] 
    332             datadict = dict(zip(fieldNames, data)) 
    333             unit[unicode(datadict(u'id'))] = None 
    334         dataRS.close() 
    335      
    336     def units(self): 
    337         self.load_data() 
    338         if len(self.data) > 0: 
    339             server = self.server 
    340             cache = server.cache(self.unitClass) 
    341             for row in self.data: 
    342                 # Notice odbc field names are lower case. 
    343                 ID = unicode(row[self.colIndices[u'id']]) 
    344                 # Search the cache to see if we've already attached this unit. 
    345                 # Use has_key() instead of 'is' or '==' because the Unit may 
    346                 # have changed its _properties since the last load. 
    347                 unit = cache['ID'].get(ID, None) 
    348                 if unit is None: 
    349                     unit = self.unitClass(server.namespace) 
    350                     self.populate_unit(unit, row) 
    351                     self.load_collection(unit) 
    352113                    cache.store(unit) 
    353114                else: 
     
    485246                       bool: lambda x, y: u"BIT", 
    486247                       } 
     248 
     249 
     250class StorageManagerPgSQL(storage.StorageManager): 
     251    """StoreManager to save and retrieve Units via pyPgSQL 1.35.""" 
     252     
     253    decompiler = PgSQLDecompiler 
     254    createAdapter = db.FieldTypeAdapter() 
     255     
     256    def __init__(self, name, arena, allOptions={}): 
     257        storage.StorageManager.__init__(self, name, arena, allOptions) 
     258         
     259        # connstring = (host=h port=p dbname=d user=u password=p options=o tty=t) 
     260        self.connstring = allOptions[u'Connect'] 
     261        atoms = self.connstring.split(" ") 
     262        for atom in atoms: 
     263            k, v = atom.split("=", 1) 
     264            setattr(self, k, v) 
     265        self.CreateIfMissing = allOptions.get(u'Create If Missing', '') 
     266        self.threaded = bool(allOptions.get(u'Threaded', '1')) 
     267        self._connection = None 
     268         
     269        self.prefix = allOptions.get(u'Prefix', u"djv") 
     270        self.reserve_lock = threading.Lock() 
     271     
     272    def identifier(self, *atoms): 
     273        ident = '"' + ''.join(map(str, atoms)).replace('"', '""') + '"' 
     274        if len(ident) > 63: 
     275            warnings.warn("Identifier is longer than 63 characters. Most " 
     276                          "installations of Postgres are limited to 63. " 
     277                          "See NAMEDATALEN.") 
     278        return ident 
     279     
     280    def shutdown(self): 
     281        if self.threaded: 
     282            t = threading.currentThread() 
     283            conn = getattr(t, "dejavu_storage_connection", None) 
     284            if conn is not None: 
     285                conn.finish() 
     286        else: 
     287            if self._connection is not None: 
     288                self._connection.finish() 
     289     
     290    def _get_conn(self): 
     291        try: 
     292            conn = libpq.PQconnectdb(self.connstring) 
     293        except Exception, x: 
     294            if self.CreateIfMissing: 
     295                self.create_database() 
     296                conn = libpq.PQconnectdb(self.connstring) 
     297            else: 
     298                raise 
     299        return conn 
     300     
     301    def connection(self): 
     302        if self.threaded: 
     303            t = threading.currentThread() 
     304            if not hasattr(t, 'dejavu_storage_connection'): 
     305                t.dejavu_storage_connection = self._get_conn() 
     306            return t.dejavu_storage_connection 
     307        else: 
     308            if self._connection is None: 
     309                self._connection = self._get_conn() 
     310            return self._connection 
     311     
     312    def _template_conn(self): 
     313        atoms = self.connstring.split(" ") 
     314        tmplconn = "" 
     315        for atom in atoms: 
     316            k, v = atom.split("=", 1) 
     317            if k == 'dbname': v = 'template1' 
     318            tmplconn += "%s=%s " % (k, v) 
     319        return libpq.PQconnectdb(tmplconn) 
     320     
     321    def create_database(self): 
     322        self.execute('CREATE DATABASE %s' % self.identifier(self.dbname), 
     323                     self._template_conn()) 
     324     
     325    def drop_database(self): 
     326        self.execute("DROP DATABASE %s;" % self.identifier(self.dbname), 
     327                     self._template_conn()) 
     328     
     329    def select(self, unitClass, expr, distinct_fields=None): 
     330        tablename = self.identifier(self.prefix, unitClass.__name__) 
     331        if distinct_fields: 
     332            distinct_fields = [self.identifier(x) for x in distinct_fields] 
     333            sql = (u'SELECT DISTINCT %s FROM %s' % 
     334                   (u', '.join(distinct_fields), tablename)) 
     335        else: 
     336            sql = u'SELECT * FROM %s' % tablename 
     337        w, i = self.where(unitClass, expr) 
     338        if len(w) > 0: 
     339            w = u" WHERE " + w 
     340        else: 
     341            w = u"" 
     342        sql += w 
     343        return sql, i 
     344     
     345    def where(self, cls, expr): 
     346        decom = self.decompiler(self.prefix + cls.__name__, expr) 
     347        return decom.code(), decom.imperfect 
     348     
     349    def execute(self, query, conn=None): 
     350        if conn is None: 
     351            conn = self.connection() 
     352        try: 
     353            return conn.query(query) 
     354        except Exception, x: 
     355            x.args += (query,) 
     356            raise x 
     357     
     358    def recall(self, cls, expr=None, pairs=None): 
     359        if expr is None: 
     360            expr = logic.Expression(lambda x: True) 
     361        return StoreIteratorPgSQL(self, cls, expr).units() 
     362     
     363    def reserve(self, unit): 
     364        """reserve(unit). -> Reserve a persistent slot for unit. 
     365         
     366        Notice in particular that we do not use the auto-number or 
     367        sequence generation capabilities within some databases, etc. 
     368        The ID should be supplied by UnitSequencers via reserve(). 
     369        """ 
     370        clsname = unit.__class__.__name__ 
     371        tblname = self.identifier(self.prefix, clsname) 
     372        id = self.identifier("ID") 
     373        self.reserve_lock.acquire() 
     374        try: 
     375            if unit.ID is None: 
     376                data = [] 
     377                res = self.execute(u'SELECT %s FROM %s;' % (id, tblname)) 
     378                if res.resultType != libpq.EMPTY_QUERY: 
     379                    data = [res.getvalue(row, 0) for row in xrange(res.ntuples)] 
     380                unit.ID = unit.sequencer.next(data) 
     381             
     382            self.execute('INSERT INTO %s (%s) VALUES (%s)' % 
     383                         (tblname, id, AdapterToPgSQL.coerce(unit.ID))) 
     384        finally: 
     385            self.reserve_lock.release() 
     386     
     387    def save(self, unit, forceSave=False): 
     388        """save(unit, forceSave=False) -> Update storage from unit's data.""" 
     389        if unit.dirty() or forceSave: 
     390            cls = unit.__class__ 
     391            clsname = cls.__name__ 
     392            tablename = self.identifier(self.prefix, clsname) 
     393             
     394            parms = [] 
     395            for key in cls.properties(): 
     396                if key != "ID": 
     397                    val = AdapterToPgSQL.coerce(getattr(unit, key)) 
     398                    parms.append('%s = %s' % (self.identifier(key), val)) 
     399            sql = ('UPDATE %s SET %s WHERE %s = %s' % 
     400                   (tablename, u", ".join(parms), 
     401                    self.identifier("ID"), 
     402                    AdapterToPgSQL.coerce(unit.ID, cls.property_type("ID")))) 
     403            self.execute(sql) 
     404            unit.cleanse() 
     405     
     406    def destroy(self, unit): 
     407        """Delete the unit.""" 
     408        # Use a DELETE command instead of a cursor for better performance. 
     409        deleteStatement = (u'DELETE * FROM %s WHERE %s = %s' % 
     410                           (self.identifier(self.prefix, unit.__class__.__name__), 
     411                            self.identifier("ID"), 
     412                            AdapterToPgSQL.coerce(unit.ID))) 
     413        self.execute(deleteStatement) 
    487414     
    488415    def create_storage(self, unitClass): 
     416        tblname = self.identifier(self.prefix, unitClass.__name__) 
     417         
     418        coerce = self.createAdapter.coerce 
    489419        fields = [] 
    490         for eachKey in unitClass.properties(): 
    491             eachType = unitClass.property_type(eachKey) 
    492             aType = self.createCoercions[eachType](unitClass, eachKey) 
    493             fields.append(u"[%s] %s" % (eachKey, aType)) 
    494         indices = [x + " ASC" for x in unitClass.indices()] 
    495          
    496         tablename = self.prefix + safe_name(unitClass.__name__) 
    497         createStatement = u"CREATE TABLE [%s] (%s)" % (tablename, ", ".join(fields)) 
     420        for key in unitClass.properties(): 
     421            fields.append(u'%s %s' % (self.identifier(key), 
     422                                      coerce(unitClass, key))) 
    498423        try: 
    499             self.execute(createStatement) 
    500         except Exception, x: 
    501             x.args += (createStatement, ) 
    502             raise x 
    503          
    504         for index in indices: 
    505             indexStatement = (u"CREATE INDEX [%si%s%s] ON [%s%s] (%s)" 
    506                               % (self.prefix, safe_name(unitClass.__name__), safe_name(index), 
    507                                  self.prefix, safe_name(unitClass.__name__), index)) 
    508             try: 
    509                 self.execute(indexStatement) 
    510             except Exception, x: 
    511                 x.args += (indexStatement, ) 
    512                 raise x 
    513          
    514         return True 
    515  
     424            self.execute(u'CREATE TABLE %s (%s)' % (tblname, ", ".join(fields))) 
     425        except libpq.OperationalError, x: 
     426            if not x.args[0].endswith(' already exists\n'): 
     427                raise 
     428        else: 
     429            for index in unitClass.indices(): 
     430                indexname = self.identifier(self.prefix, "i", 
     431                                            unitClass.__name__, index) 
     432                self.execute(u'CREATE INDEX %s ON %s (%s)' 
     433                             % (indexname, tblname, self.identifier(index))) 
     434     
     435    def distinct(self, cls, fields, expr=None): 
     436        """Return distinct values for specified fields.""" 
     437        if expr is None: 
     438            expr = logic.Expression(lambda x: True) 
     439         
     440        # ^%$#@! There's no way to handle imperfect queries without 
     441        # creating all involved Units, which defeats the purpose of 
     442        # distinct, which was a speed issue more than anything. Grr. 
     443        sql, imperfect = self.select(cls, expr, fields) 
     444        # Ignore for now. 
     445##        if imperfect: 
     446##            raise ValueError(u"The following query cannot be reliably " 
     447##                             u"returned from a Postgres data source.", 
     448##                             u"distinct()", cls, fields, expr) 
     449         
     450        res = self.execute(sql) 
     451        if res.resultType == libpq.EMPTY_QUERY: 
     452            return [] 
     453         
     454        coerce = AdapterFromPg().coerce 
     455        data = [] 
     456        for row in xrange(res.ntuples): 
     457            coerced_row = [] 
     458            for i in xrange(len(fields)): 
     459                expectedType = cls.property_type(field[i]) 
     460                actualType = res.ftype(i) 
     461                val = coerce(res.getvalue(row, i), actualType, expectedType) 
     462                coerced_row.append(val) 
     463            data.append(coerced_row) 
     464        return zip(*data) 
     465 
  • trunk/storage/storepypgsql.py

    r43 r44  
    55import threading 
    66import datetime 
    7 try: 
    8     import cPickle as pickle 
    9 except ImportError: 
    10     import pickle 
    11 from types import FunctionType 
    12  
    13 try: 
    14     import fixedpoint 
    15 except ImportError: 
    16     pass 
    177 
    188import dejavu 
    19 from dejavu import storage, codewalk, logic 
     9from dejavu import storage, logic 
    2010from dejavu.storage import db 
    21 import recur 
    2211 
    2312 
    2413AdapterToPgSQL = db.AdapterToSQL() 
    2514AdapterFromPg = db.AdapterFromDB 
    26  
    27  
    28 def containedby(op1, op2, notin=False): 
    29     if isinstance(op1, db.ConstWrapper): 
    30         # Looking for text in a field. Use Like (reverse terms). 
    31         value = op2 + " LIKE '%" + op1[1:-1] + "%'" 
    32     else: 
    33         # Looking for field in (a, b, c) 
    34         atoms = [AdapterToPgSQL.coerce(x) for x in op2.basevalue] 
    35         value = op1 + " IN (" + ", ".join(atoms) + ")" 
    36     if notin: 
    37         value = "NOT " + value 
    38     return value 
    39  
    40 def icontainedby(op1, op2, notin=False): 
    41     if isinstance(op1, db.ConstWrapper): 
    42         # Looking for text in a field. Use Like (reverse terms). 
    43         value = op2 + " ILIKE '%" + op1[1:-1] + "%'" 
    44     else: 
    45         # Looking for field in (a, b, c) 
    46         atoms = [AdapterToPgSQL.coerce(x).lower() for x in op2.basevalue] 
    47         value = "LOWER(" + op1 + ") IN (" + ", ".join(atoms) + ")" 
    48     if notin: 
    49         value = "NOT " + value 
    50     return value 
    51  
    52  
    53 # Stack sentinels 
    54 cannot_represent = object() 
    55 table_arg = object() 
    56 kw_arg = object() 
    57  
    58 class PgSQLDecompiler(codewalk.LambdaDecompiler): 
    59     """PgSQLDecompiler(store, unitClass, expr, adapter=AdapterToPgSQL). 
    60      
    61     Produce SQL from a supplied Expression object, with a lambda of the form: 
    62         lambda x, **kw: ... 
    63      
    64     Attributes of x (or whatever the name of the first argument is) will be 
    65     mapped to table columns. Keyword arguments should be bound to the 
    66     Expression before calling this decompiler. 
    67     """ 
    68      
    69     sql_cmp_op = ('<', '<=', '=', '!=', '>', '>=', 'in', 'not in') 
    70     functions = {dejavu.icontains: lambda x, y: x + " ILIKE '%" + y[1:-1] + "%'", 
    71                  dejavu.icontainedby: icontainedby, 
    72                  dejavu.istartswith: lambda x, y: x + " ILIKE '" + y[1:-1] + "%'", 
    73                  dejavu.iendswith: lambda x, y: x + " ILIKE '%" + y[1:-1] + "'", 
    74                  # This is a test of ILIKE with no wildcards, 
    75                  # to see if it behaves like ieq. 
    76                  dejavu.ieq: lambda x, y: x + " ILIKE " + y, 
    77                  dejavu.now: lambda: "now()", 
    78                  dejavu.today: lambda: "CURRENT_DATE", 
    79                  dejavu.year: lambda x: "date_part('year', " + x + ")", 
    80                  len: lambda x: "length(" + x + ")", 
    81                  } 
    82      
    83     def __init__(self, store, unitClass, expr, adapter=AdapterToPgSQL): 
    84         self.tablename = store.prefix + unitClass.__name__ 
    85         self.expr = expr 
    86         self.adapter = adapter 
    87         obj = expr.func 
    88         codewalk.LambdaDecompiler.__init__(self, obj) 
    89      
    90     def code(self): 
    91         self.imperfect = False 
    92         self.walk() 
    93         result = self.stack[0] 
    94         if result is cannot_represent: 
    95             result = 'TRUE' 
    96         return result, self.imperfect 
    97      
    98     def visit_target(self, terms): 
    99         """A target is an AND or OR test.""" 
    100         comp = self.stack.pop() 
    101         while terms: 
    102             term, operation = terms.pop() 
    103             # All this checking of cannot_represent is done so that a 
    104             # function can be labeled imperfect. For example, the function 
    105             # dejavu.iscurrentweek has no PG SQL equivalent. All Units 
    106             # (which match the rest of the Expression) will be recalled. 
    107             # They can then be compared in expr.evaluate(unit). 
    108             if term is not cannot_represent: 
    109                 if comp is cannot_represent: 
    110                     comp = term 
    111                 else: 
    112                     comp = "(%s) %s (%s)" % (term, operation.upper(), comp) 
    113         self.stack.append(comp) 
    114      
    115     def visit_LOAD_DEREF(self, lo, hi): 
    116         raise ValueError("Illegal reference found in %s." % self.expr) 
    117      
    118     def visit_LOAD_GLOBAL(self, lo, hi): 
    119         raise ValueError("Illegal global found in %s." % self.expr) 
    120      
    121     def visit_LOAD_FAST(self, lo, hi): 
    122         if lo + (hi << 8) < self.co_argcount: 
    123             self.stack.append(table_arg) 
    124         else: 
    125             self.stack.append(kw_arg) 
    126      
    127     def visit_LOAD_ATTR(self, lo, hi): 
    128         name = self.co_names[lo + (hi << 8)] 
    129         tos = self.stack.pop() 
    130         if tos is table_arg: 
    131             self.stack.append('"%s"."%s"' % (self.tablename, name)) 
    132         else: 
    133             self.stack.append((tos, name)) 
    134      
    135     def visit_LOAD_CONST(self, lo, hi): 
    136         val = self.co_consts[lo + (hi << 8)] 
    137         # Some constants are function or class objects, 
    138         # which should not be coerced. 
    139         no_coerce = (FunctionType, type, 
    140                      type(len),       # <type 'builtin_function_or_method'> 
    141                      ) 
    142         if not isinstance(val, no_coerce): 
    143             val = db.ConstWrapper(val, self.adapter.coerce(val)) 
    144         self.stack.append(val) 
    145      
    146     def visit_BUILD_TUPLE(self, lo, hi): 
    147         terms = ", ".join([self.stack.pop() for i in range(lo + hi << 8)]) 
    148         self.stack.append("(" + terms + ")") 
    149      
    150     visit_BUILD_LIST = visit_BUILD_TUPLE 
    151      
    152     def visit_CALL_FUNCTION(self, lo, hi): 
    153         kwargs = {} 
    154         for i in range(hi): 
    155             val = self.stack.pop() 
    156             key = self.stack.pop() 
    157             kwargs[key] = val 
    158         kwargs = [k + "=" + v for k, v in kwargs.iteritems()] 
    159          
    160         args = [] 
    161         for i in range(lo): 
    162             arg = self.stack.pop() 
    163             args.append(arg) 
    164         args.reverse() 
    165          
    166         if kwargs: 
    167             args += kwargs 
    168          
    169         func = self.stack.pop() 
    170          
    171         # Handle function objects. 
    172         if func in self.functions: 
    173             self.stack.append(self.functions[func](*args)) 
    174         else: 
    175             if isinstance(func, tuple): 
    176                 tos, func = func 
    177                 if func == "startswith": 
    178                     self.stack.append(tos + " LIKE '" + args[0][1:-1] + "%'") 
    179                     return 
    180                 elif func == "endswith": 
    181                     self.stack.append(tos + " LIKE '%" + args[0][1:-1] + "'") 
    182                     return 
    183                 return 
    184              
    185             if self.stack: 
    186                 self.stack[-1] = cannot_represent 
    187             else: 
    188                 self.stack = [cannot_represent] 
    189             self.imperfect = True 
    190      
    191     def visit_COMPARE_OP(self, lo, hi): 
    192         op2, op1 = self.stack.pop(), self.stack.pop() 
    193         op = lo + (hi << 8) 
    194         if op in (6, 7):     # in, not in 
    195             self.stack.append(containedby(op1, op2, op == 7)) 
    196             self.imperfect = True 
    197         elif op1 == 'NULL': 
    198             if op == 2: 
    199                 self.stack.append(op2 + " IS NULL") 
    200             elif op == 3: 
    201                 self.stack.append(op2 + " IS NOT NULL") 
    202             else: 
    203                 raise ValueError("Non-equality Null comparisons not allowed.") 
    204         elif op2 == 'NULL': 
    205             if op == 2: 
    206                 self.stack.append(op1 + " IS NULL") 
    207             elif op == 3: 
    208                 self.stack.append(op1 + " IS NOT NULL") 
    209             else: 
    210                 raise ValueError("Non-equality Null comparisons not allowed.") 
    211         else: 
    212             # Comparison operators for strings are case-sensitive in PG. 
    213             self.stack.append(op1 + " " + self.sql_cmp_op[op] + " " + op2) 
    214      
    215     def binary_op(self, op): 
    216         op2, op1 = self.stack.pop(), self.stack.pop() 
    217         self.stack.append(op1 + " " + op + " " + op2) 
    218      
    219     def visit_BINARY_SUBSCR(self): 
    220         # The only BINARY_SUBSCR used in Expressions should be kwargs[key]. 
    221         name = self.stack.pop() 
    222         tos = self.stack.pop() 
    223         if tos is not kw_arg: 
    224             raise ValueError(tos, name) 
    225         # name, since formed in LOAD_CONST, has extraneous single-quotes. 
    226         name = name[1:-1] 
    227         value = self.expr.kwargs[name] 
    228         value = self.adapter.coerce(value) 
    229         self.stack.append(value) 
    230      
    231     def visit_UNARY_NOT(self): 
    232         op = self.stack.pop() 
    233         if op is cannot_represent: 
    234             # Usually as a result of has(farClassName). 
    235             self.stack.append(cannot_represent) 
    236         else: 
    237             self.stack.append("NOT (" + op + ")") 
    23815 
    23916 
     
    27451 
    27552 
     53class PgSQLDecompiler(db.SQLDecompiler): 
     54     
     55    def dejavu_icontainedby(self, op1, op2): 
     56        if isinstance(op1, db.ConstWrapper): 
     57            # Looking for text in a field. Use ILike (reverse terms). 
     58            return op2 + " ILIKE '%" + op1.strip("'\"") + "%'" 
     59        else: 
     60            # Looking for field in (a, b, c). 
     61            # Force all args to lowercase for case-insensitive comparison. 
     62            atoms = [self.adapter.coerce(x).lower() for x in op2.basevalue] 
     63            return "LOWER(%s) IN (%s)" % (op1, ", ".join(atoms)) 
     64     
     65    def dejavu_istartswith(self, x, y): 
     66        return x + " ILIKE '" + y.strip("'\"") + "%'" 
     67     
     68    def dejavu_iendswith(self, x, y): 
     69        return x + " ILIKE '%" + y.strip("'\"") + "'" 
     70     
     71    def dejavu_ieq(self, x, y): 
     72        # ILIKE with no wildcards should behave like ieq. 
     73        return x + " ILIKE " + y 
     74     
     75    def dejavu_year(self, x): 
     76        return "date_part('year', " + x + ")" 
     77 
     78 
     79 
    27680class StorageManagerPgSQL(storage.StorageManager): 
    27781    """StoreManager to save and retrieve Units via pyPgSQL 1.35.""" 
     
    27983    decompiler = PgSQLDecompiler 
    28084    createAdapter = db.FieldTypeAdapter() 
    281     threaded = False 
    28285     
    28386    def __init__(self, name, arena, allOptions={}): 
     
    29194            setattr(self, k, v) 
    29295        self.CreateIfMissing = allOptions.get(u'Create If Missing', '') 
    293         if allOptions.get(u'Threaded', '1'): 
    294             self.threaded = True 
    295             self._connection = None 
    296         else: 
    297             try: 
    298                 self._connection = libpq.PQconnectdb(self.connstring) 
    299             except Exception, x: 
    300                 if False and self.CreateIfMissing: 
    301                     self.create_database() 
    302                     self._connection = libpq.PQconnectdb(self.connstring) 
    303                 else: 
    304                     raise 
     96        self.threaded = bool(allOptions.get(u'Threaded', '1')) 
     97        self._connection = None 
    30598         
    30699        self.prefix = allOptions.get(u'Prefix', u"djv") 
     
    316109     
    317110    def shutdown(self): 
    318         if self._connection is not None: 
    319             self._connection.finish() 
     111        if self.threaded: 
     112            t = threading.currentThread() 
     113            conn = getattr(t, "dejavu_storage_connection", None) 
     114            if conn is not None: 
     115                conn.finish() 
     116        else: 
     117            if self._connection is not None: 
     118                self._connection.finish() 
     119     
     120    def _get_conn(self): 
     121        try: 
     122            conn = libpq.PQconnectdb(self.connstring) 
     123        except Exception, x: 
     124            if self.CreateIfMissing: 
     125                self.create_database() 
     126                conn = libpq.PQconnectdb(self.connstring) 
     127            else: 
     128                raise 
     129        return conn 
    320130     
    321131    def connection(self): 
    322132        if self.threaded: 
    323133            t = threading.currentThread() 
    324             if not hasattr(t, 'SMPgSQLconn'): 
    325                 try: 
    326                     t.SMPgSQLconn = libpq.PQconnectdb(self.connstring) 
    327                 except Exception, x: 
    328                     if False and self.CreateIfMissing: 
    329                         self.create_database() 
    330                         t.SMPgSQLconn = libpq.PQconnectdb(self.connstring) 
    331                     else: 
    332                         raise 
    333             return t.SMPgSQLconn 
    334         else: 
     134            if not hasattr(t, 'dejavu_storage_connection'): 
     135                t.dejavu_storage_connection = self._get_conn() 
     136            return t.dejavu_storage_connection 
     137        else: 
     138            if self._connection is None: 
     139                self._connection = self._get_conn() 
    335140            return self._connection 
    336141     
    337     def create_database(self): 
     142    def _template_conn(self): 
    338143        atoms = self.connstring.split(" ") 
    339144        tmplconn = "" 
     
    342147            if k == 'dbname': v = 'template1' 
    343148            tmplconn += "%s=%s " % (k, v) 
    344         conn = libpq.PQconnectdb(tmplconn) 
    345         self.execute('CREATE DATABASE %s' % self.identifier(self.dbname), conn) 
     149        return libpq.PQconnectdb(tmplconn) 
     150     
     151    def create_database(self): 
     152        self.execute('CREATE DATABASE %s' % self.identifier(self.dbname), 
     153                     self._template_conn()) 
     154     
     155    def drop_database(self): 
     156        self.execute("DROP DATABASE %s;" % self.identifier(self.dbname), 
     157                     self._template_conn()) 
    346158     
    347159    def select(self, unitClass, expr, distinct_fields=None): 
     
    362174     
    363175    def where(self, cls, expr): 
    364         return self.decompiler(self, cls, expr).code() 
     176        decom = self.decompiler(self.prefix + cls.__name__, expr) 
     177        return decom.code(), decom.imperfect 
    365178     
    366179    def execute(self, query, conn=None): 
  • trunk/storage/storeshelve.py

    r43 r44  
    1 import os.path 
     1import os 
    22import shelve 
    33import threading 
     
    1010    """StoreManager to save and retrieve Units via stdlib shelve.""" 
    1111     
    12     shelvepath = None 
    13      
    14     # A dictionary whose keys are classnames and whose 
    15     # values are objects returned by shelve.open(). 
    16     shelves = {} 
    17      
    1812    def __init__(self, name, arena, allOptions={}): 
    1913        storage.StorageManager.__init__(self, name, arena, allOptions) 
    2014        self.shelvepath = allOptions['Path'] 
     15         
     16        # A dictionary whose keys are classnames and whose 
     17        # values are objects returned by shelve.open(). 
    2118        self.shelves = {} 
     19         
    2220        self.locks = {} 
    2321     
     
    6058        try: 
    6159            if unit.ID is None: 
    62                 unit.ID = unit.sequencer.next(data.keys()) 
     60                ids = [x['ID'] for x in data.itervalues()] 
     61                unit.ID = unit.sequencer.next(ids) 
    6362            data[str(unit.ID)] = unit._properties 
    6463        finally: 
     
    8483            lock.release() 
    8584     
     85    def create_database(self): 
     86        pass 
     87     
     88    def drop_database(self): 
     89        for clsname, shelf in self.shelves.iteritems(): 
     90            shelf.close() 
     91            tbl = os.path.join(self.shelvepath, clsname) 
     92            os.remove(tbl) 
     93     
    8694    def create_storage(self, unitClass): 
    8795        pass 
  • trunk/storage/test_storeado.py

    r43 r44  
    1 import unittest 
    2 import pywintypes 
    3 import datetime 
    4 import storeado 
    5 import dejavu 
    6 from dejavu import logic, Unit, UnitProperty, zoo 
     1import os 
     2from dejavu.storage import zoo_fixture 
    73 
     4if __name__ == '__main__': 
     5    # Microsoft Access 
     6    opts = {u'Connect': "PROVIDER=MICROSOFT.JET.OLEDB.4.0;DATA SOURCE=zoo.mdb;", 
     7            u'Expanded Columns': "Animal.PreviousZoos", 
     8            u'Create If Missing': True, 
     9            } 
    810 
    9 conn = "PROVIDER=MICROSOFT.JET.OLEDB.4.0;DATA SOURCE=zoo.mdb;" 
    10 smOptions = {u'Connect': conn, 
    11              u'Expanded Columns': "Animal.Options", 
    12              u'Create If Missing': True, 
    13              } 
    14 testSM = storeado.StorageManagerADO_MSAccess("test", zoo.arena, smOptions) 
    15 zoo.arena.add_store('testSM', testSM) 
    16  
    17 for cls in (zoo.Animal, zoo.Zoo, zoo.Exhibit): 
    1811    try: 
    19         zoo.arena.create_storage(cls) 
    20     except: 
    21         pass 
    22  
    23  
    24 class StorageManagerTests(unittest.TestCase): 
    25      
    26     def test_select(self): 
    27         def sel(f, sql, imp): 
    28             e = logic.Expression(f) 
    29             self.assertEqual(testSM.select(zoo.Animal, e), (sql, imp)) 
    30          
    31         sel(lambda x: x.Group == 3, 
    32             u"SELECT * FROM [djvAnimal] WHERE [djvAnimal].[Group] = 3", False) 
    33         sel(lambda x: x.Group.startswith('ex-'), 
    34             u"SELECT * FROM [djvAnimal] WHERE [djvAnimal].[Group] Like 'ex-%'", True) 
    35          
    36         # Test select all 
    37         sel(lambda x: True, u"SELECT * FROM [djvAnimal] WHERE TRUE", False) 
    38          
    39         # Test now(), today(), year() 
    40         sel(lambda x: x.FirstDate > dejavu.today(), 
    41             u"SELECT * FROM [djvAnimal] WHERE [djvAnimal].[FirstDate] > DateValue(Now())", False) 
    42         sel(lambda x: x.Event == dejavu.now(), 
    43             u"SELECT * FROM [djvAnimal] WHERE [djvAnimal].[Event] = Now()", False) 
    44         sel(lambda x: dejavu.year(x.Event) == 2004, 
    45             u"SELECT * FROM [djvAnimal] WHERE Year([djvAnimal].[Event]) = 2004", False) 
    46      
    47     def test_multiselect(self): 
    48         f = logic.Expression(lambda x: x.Legs == 4) 
    49         sql = testSM.multiselect(zoo.Animal, f, [(zoo.Zoo, None)])[0] 
    50         expected = ("SELECT [djvAnimal].[Legs], [djvAnimal].[Name], " 
    51                     "[djvAnimal].[ZooID], [djvAnimal].[ID], " 
    52                     "[djvAnimal].[LastEscape], [djvAnimal].[Options], " 
    53                     "[djvZoo].[Founded], [djvZoo].[LastEscape], " 
    54                     "[djvZoo].[ID], [djvZoo].[Opens], [djvZoo].[Name] " 
    55                     "FROM [djvAnimal] LEFT JOIN [djvZoo] ON " 
    56                     "[djvAnimal].[ZooID] = [djvZoo].[ID] WHERE " 
    57                     "[djvAnimal].[Legs] = 4") 
    58         self.assertEqual(sql, expected) 
    59      
    60     def test_create_storage(self): 
    61         try: 
    62             testSM.execute("DROP TABLE djvAnimal") 
    63         except pywintypes.com_error: 
    64             pass 
    65         testSM.create_storage(zoo.Animal) 
    66      
    67     def test_expanded_columns(self): 
    68         # Notice this also tests that: a Unit which is only 
    69         # dirtied via __init__ is still saved. 
    70         o = zoo.Exhibit(ID=1, Animals=[1, 2, 3]) 
    71         box = zoo.arena.new_sandbox() 
    72         box.memorize(o) 
    73         box.flush_all() 
    74          
    75         o = box.unit(zoo.Exhibit, ID=1) 
    76         self.assertNotEqual(o, None) 
    77         self.assertEqual(o.Animals, [1, 2, 3]) 
    78      
    79     def test_unit_roundtrip(self): 
    80         """Assert that a Unit can be loaded and saved.""" 
    81         box = zoo.arena.new_sandbox() 
    82          
    83         e = logic.Expression(lambda x: x.Name == 'Cat') 
    84         for unit in box.recall(zoo.Animal, e): 
    85             unit.forget() 
    86          
    87         cat = zoo.Animal(Name='Cat', Legs=1) 
    88         self.assertEqual(cat.Name, 'Cat') 
    89         legs = cat.Legs + 1 
    90         if legs > 10: 
    91             legs = 4 
    92         cat.Legs = legs 
    93         box.memorize(cat) 
    94         box.flush_all() 
    95          
    96         # Now, do the whole thing again to see if our save worked. 
    97         box = zoo.arena.new_sandbox() 
    98         u = [x for x in box.recall(zoo.Animal, e)] 
    99         self.assertEqual(len(u), 1) 
    100         cat = u[0] 
    101         self.assertEqual(cat.Name, 'Cat') 
    102         self.assertEqual(cat.Legs, legs) 
    103          
    104         # Now, do the whole thing again just for kicks. 
    105         box = zoo.arena.new_sandbox() 
    106         u = [x for x in box.recall(zoo.Animal, e)] 
    107         self.assertEqual(len(u), 1) 
    108         cat = u[0] 
    109         self.assertEqual(cat.Name, 'Cat') 
    110         self.assertEqual(cat.Legs, legs) 
    111  
    112  
    113 class ExpressionTests(unittest.TestCase): 
    114      
    115     def test_Adapter(self): 
    116         adapter = storeado.AdapterToADOSQL() 
    117         pairs = [(3, '3'), 
    118                  (3.1, '3.1'), 
    119                  (u'down the Stra\u00DFe', u"'down the Stra\u00DFe'"), 
    120                  ('a salted peanut', "'a salted peanut'"), 
    121                  (datetime.datetime(2001, 11, 15, 14, 15, 16), 
    122                   '#11/15/2001 14:15:16#'), 
    123                  (datetime.date(2001, 11, 15), '#11/15/2001#'), 
    124                  (datetime.time(6, 30), '#06:30:00#'), 
    125                  (True, 'TRUE'), 
    126                  (None, 'NULL'), 
    127                  ] 
    128         for initial, final in pairs: 
    129             self.assertEqual(adapter.coerce(initial), final) 
    130      
    131     def test_Decompiler(self): 
    132         def trial(lam, sql, imperfect): 
    133             e = logic.Expression(lam) 
    134             decom = storeado.ADOSQLDecompiler(testSM, zoo.Animal, e) 
    135             self.assertEqual(decom.code(), (sql, imperfect)) 
    136         trial(lambda x: x.Date == 3, "[djvAnimal].[Date] = 3", False) 
    137         trial(lambda x, **kw: (x.a == 3) and ((x.b > 1) or (x.b < -10)), 
    138               u'([djvAnimal].[a] = 3) and (([djvAnimal].[b] > 1) or ' 
    139               u'([djvAnimal].[b] < -10))', False) 
    140         trial(lambda x, **kw: (x.Group == 3) and not (x.Name.startswith("_")) 
    141               and not (x.Name.endswith('test')), 
    142               u"([djvAnimal].[Group] = 3) " 
    143               u"and ((not ([djvAnimal].[Name] Like '[_]%')) " 
    144               u"and (not ([djvAnimal].[Name] Like '%test')))", True) 
    145         trial(lambda x: (x.Group == '3') and (x.Date > datetime.date(2004, 2, 14)), 
    146               u"([djvAnimal].[Group] = '3') and ([djvAnimal].[Date] > #2/14/2004#)", True) 
    147          
    148         # None values 
    149         trial(lambda x: x.Date != None and None != x.Date, 
    150               u"([djvAnimal].[Date] IS NOT NULL) and ([djvAnimal].[Date] " 
    151               u"IS NOT NULL)", False) 
    152          
    153         # In operator 
    154         trial(lambda x: 'tool' in x.Name or 'tool' in x.Content, 
    155               u"([djvAnimal].[Name] Like '%tool%') " 
    156               u"or ([djvAnimal].[Content] Like '%tool%')", True) 
    157         trial(lambda x: x.Name in ('Johann', 'Gambolputty', 'de von Ausfern'), 
    158               u"[djvAnimal].[Name] in ('Johann', 'Gambolputty', 'de von Ausfern')", True) 
    159         # Try In with cell references 
    160         pet, pet2 = zoo.Animal(), zoo.Animal() 
    161         pet.Name, pet2.Name = 'Pony', 'Iguana' 
    162         trial(lambda x: x.Name in (pet.Name, pet2.Name), 
    163               u"[djvAnimal].[Name] in ('Pony', 'Iguana')", True) 
    164          
    165         # logic and other functions 
    166         trial(lambda x: dejavu.ieq(x.Name, 'Johann'), u"[djvAnimal].[Name] = 'Johann'", False) 
    167         trial(lambda x: dejavu.icontains(x.Name, 'tool'), u"[djvAnimal].[Name] Like '%tool%'", False) 
    168         trial(lambda x: dejavu.icontainedby(x.Name, ('Johann', 'Gambolputty', 'de von Ausfern')), 
    169               u"[djvAnimal].[Name] in ('Johann', 'Gambolputty', 'de von Ausfern')", False) 
    170         reqZip = '92104' 
    171         trial(lambda x: len(x.ZipStart) == len(reqZip), u"Len([djvAnimal].[ZipStart]) = 5", False) 
    172          
    173         # This broke on 5/10/04, because "== None" wasn't succeeding as "= NULL". 
    174         trial(lambda x: x.DateTo == None, u"[djvAnimal].[DateTo] IS NULL", False) 
    175          
    176         # Another one that broke sometime in 2004. Rev 32 seems to have fixed it. 
    177         trial(lambda x: 'C' in x.Plan, "[djvAnimal].[Plan] Like '%C%'", True) 
    178          
    179         # Multiple arguments (? Why should this be supported?) 
    180         trial(lambda x, y, z: x.Date == 3 and y.Qty > 4 and z.Qty < 20, 
    181               "([djvAnimal].[Date] = 3) and (([djvAnimal].[Qty] > 4) and ([djvAnimal].[Qty] < 20))", False) 
    182          
    183         # Pickled types 
    184         e = logic.Expression(lambda x: x.Animals == [1, 2, '3']) 
    185         decom = storeado.ADOSQLDecompiler(testSM, zoo.Exhibit, e) 
    186         self.assertEqual(decom.code(), 
    187                          ("[djvExhibit].[Animals] = '(lp1\nI1\naI2\naS''3''\na.'", False)) 
    188  
    189  
    190 class AdapterTests(unittest.TestCase): 
    191      
    192     def test_dates(self): 
    193         box = zoo.arena.new_sandbox() 
    194          
    195         WAP = zoo.Zoo() 
    196         WAP.Name = 'Wild Animal Park' 
    197         WAP.Founded = d = datetime.date(2000, 1, 1) 
    198         # 59 should give rounding errors with divmod, which 
    199         # AdapterFromADO needs to correct. 
    200         WAP.Opens = t = datetime.time(8, 15, 59) 
    201         WAP.LastEscape = dt = datetime.datetime(2004, 7, 29, 5, 6, 7) 
    202         box.memorize(WAP) 
    203          
    204         box.flush_all() 
    205          
    206         WAP = box.unit(zoo.Zoo, Name='Wild Animal Park') 
    207         self.assertNotEqual(WAP, None) 
    208         self.assertEqual(WAP.Founded, d) 
    209         self.assertEqual(WAP.Opens, t) 
    210         self.assertEqual(WAP.LastEscape, dt) 
    211  
    212  
    213 if __name__ == "__main__": 
    214     unittest.main() 
    215  
     12        testSM = zoo_fixture.setup_SM("dejavu.storage.storeado.StorageManagerADO_MSAccess", opts) 
     13        zoo_fixture.run_tests() 
     14    finally: 
     15        zoo_fixture.zoo.arena.shutdown() 
     16        testSM.drop_database() 
  • trunk/storage/test_storemysql.py

    r43 r44  
    1 import traceback 
    2 import unittest 
    3 import datetime 
    4 import storemysql 
    5 import dejavu 
    6 from dejavu import logic, Unit, UnitProperty, zoo 
     1from dejavu.storage import zoo_fixture 
    72 
    8  
    9 class StorageManagerTests(unittest.TestCase): 
    10      
    11     def test_select(self): 
    12         def sel(f, sql, imp): 
    13             e = logic.Expression(f) 
    14             self.assertEqual(testSM.select(zoo.Animal, e), (sql, imp)) 
    15          
    16         sel(lambda x: x.Group == 3, 
    17             """SELECT * FROM `djvanimal` WHERE `djvanimal`.`group` = 3""", False) 
    18         sel(lambda x: x.Group.startswith('ex-'), 
    19             """SELECT * FROM `djvanimal` WHERE `djvanimal`.`group` LIKE 'ex-%'""", False) 
    20          
    21         # Test select all 
    22         sel(lambda x: True, """SELECT * FROM `djvanimal` WHERE TRUE""", False) 
    23          
    24         # Test now(), today(), year() 
    25         sel(lambda x: x.FirstDate > dejavu.today(), 
    26             """SELECT * FROM `djvanimal` WHERE `djvanimal`.`firstdate` > CURDATE""", False) 
    27         sel(lambda x: x.Event == dejavu.now(), 
    28             """SELECT * FROM `djvanimal` WHERE `djvanimal`.`event` = now()""", False) 
    29         sel(lambda x: dejavu.year(x.Event) == 2004, 
    30             """SELECT * FROM `djvanimal` WHERE YEAR(`djvanimal`.`event`) = 2004""", False) 
    31      
    32     def test_create_storage(self): 
    33         testSM.execute('DROP TABLE `djvanimal`;') 
    34         testSM.create_storage(zoo.Animal) 
    35      
    36     def test_expanded_columns(self): 
    37         # Notice this also tests that: a Unit which is only 
    38         # dirtied via __init__ is still saved. 
    39         o = zoo.Exhibit(ID=1, Animals=[1, 2, 3]) 
    40         self.assertEqual(o.dirty(), True) 
    41         box = zoo.arena.new_sandbox() 
    42         box.memorize(o) 
    43         self.assertEqual(o.ID, 1) 
    44         box.flush_all() 
    45          
    46         o = box.unit(zoo.Exhibit, ID=1) 
    47         self.assertNotEqual(o, None) 
    48         self.assertEqual(o.Animals, [1, 2, 3]) 
    49      
    50     def test_unit_roundtrip(self): 
    51         """Assert that a Unit can be loaded and saved.""" 
    52         box = zoo.arena.new_sandbox() 
    53          
    54         e = logic.Expression(lambda x: x.Name == 'Cat') 
    55         for unit in box.recall(zoo.Animal, e): 
    56             unit.forget() 
    57          
    58         cat = zoo.Animal(Name='Cat', Legs=1) 
    59         self.assertEqual(cat.Name, 'Cat') 
    60         legs = cat.Legs + 1 
    61         if legs > 10: 
    62             legs = 4 
    63         cat.Legs = legs 
    64         box.memorize(cat) 
    65         box.flush_all() 
    66          
    67         # Now, do the whole thing again to see if our save worked. 
    68         box = zoo.arena.new_sandbox() 
    69         u = [x for x in box.recall(zoo.Animal, e)] 
    70         self.assertEqual(len(u), 1) 
    71         cat = u[0] 
    72         self.assertEqual(cat.Name, 'Cat') 
    73         self.assertEqual(cat.Legs, legs) 
    74          
    75         # Now, do the whole thing again just for kicks. 
    76         box = zoo.arena.new_sandbox() 
    77         u = [x for x in box.recall(zoo.Animal, e)] 
    78         self.assertEqual(len(u), 1) 
    79         cat = u[0] 
    80         self.assertEqual(cat.Name, 'Cat') 
    81         self.assertEqual(cat.Legs, legs) 
    82  
    83  
    84 class ExpressionTests(unittest.TestCase): 
    85      
    86     def test_AdapterToMySQL(self): 
    87         adapter = storemysql.AdapterToMySQL 
    88         pairs = [(3, '3'), 
    89                  (3.1, '3.1'), 
    90                  (u'down the Stra\u00DFe', u"'down the Stra\xdfe'"), 
    91                  ('a salted peanut', "'a salted peanut'"), 
    92                  (datetime.datetime(2001, 11, 15, 14, 15, 16), 
    93                   "'2001-11-15 14:15:16'"), 
    94                  (datetime.date(2001, 11, 15), "'2001-11-15'"), 
    95                  (datetime.time(6, 30), "'06:30:00'"), 
    96                  (True, 'TRUE'), 
    97                  (None, 'NULL'), 
    98                  ] 
    99         for initial, final in pairs: 
    100             self.assertEqual(adapter.coerce(initial), final) 
    101      
    102     def test_Decompiler(self): 
    103         def trial(lam, sql, imperfect): 
    104             e = logic.Expression(lam) 
    105             decom = storemysql.MySQLDecompiler(testSM, zoo.Animal, e) 
    106             self.assertEqual(decom.code(), (sql, imperfect)) 
    107         trial(lambda x: x.Date == 3, '`djvanimal`.`date` = 3', False) 
    108         trial(lambda x, **kw: (x.a == 3) and ((x.b > 1) or (x.b < -10)), 
    109               u'(`djvanimal`.`a` = 3) AND ((`djvanimal`.`b` > 1) OR (`djvanimal`.`b` < -10))', False) 
    110         trial(lambda x, **kw: (x.Group == 3) and not (x.Name.startswith("_")) 
    111               and not (x.Name.endswith('test')), 
    112               """(`djvanimal`.`group` = 3) AND ((NOT (`djvanimal`.`name` LIKE '\_%')) """ 
    113               """AND (NOT (`djvanimal`.`name` LIKE '%test')))""", False) 
    114         trial(lambda x: (x.Group == '3') and (x.Date > datetime.date(2004, 2, 14)), 
    115               """(`djvanimal`.`group` = '3') AND (`djvanimal`.`date` > '2004-02-14')""", False) 
    116          
    117         # None values 
    118         trial(lambda x: x.Date != None and None != x.Date, 
    119               """(`djvanimal`.`date` IS NOT NULL) AND (`djvanimal`.`date` IS NOT NULL)""", False) 
    120          
    121         # In operator 
    122         trial(lambda x: 'tool' in x.Name or 'tool' in x.Content, 
    123               """(`djvanimal`.`name` LIKE '%tool%') OR (`djvanimal`.`content` LIKE '%tool%')""", True) 
    124         trial(lambda x: x.Name in ('Johann', 'Gambolputty', 'de von Ausfern'), 
    125               """`djvanimal`.`name` IN ('Johann', 'Gambolputty', 'de von Ausfern')""", True) 
    126         # Try In with cell references 
    127         pet, pet2 = zoo.Animal(), zoo.Animal() 
    128         pet.Name, pet2.Name = 'Pony', 'Iguana' 
    129         trial(lambda x: x.Name in (pet.Name, pet2.Name), 
    130               """`djvanimal`.`name` IN ('Pony', 'Iguana')""", True) 
    131          
    132         # logic and other functions 
    133         trial(lambda x: dejavu.ieq(x.Name, 'Johann'), """LOWER(`djvanimal`.`name`) = LOWER('Johann')""", False) 
    134         trial(lambda x: dejavu.icontains(x.Name, 'tool'), """LOWER(`djvanimal`.`name`) LIKE '%tool%'""", False) 
    135         trial(lambda x: dejavu.icontainedby(x.Name, ('Johann', 'Gambolputty', 'de von Ausfern')), 
    136               """LOWER(`djvanimal`.`name`) IN ('johann', 'gambolputty', 'de von ausfern')""", False) 
    137         reqZip = '92104' 
    138         trial(lambda x: len(x.ZipStart) == len(reqZip), """LENGTH(`djvanimal`.`zipstart`) = 5""", False) 
    139          
    140         # This broke on 5/10/04, because "== None" wasn't succeeding as "= Null". 
    141         trial(lambda x: x.DateTo == None, """`djvanimal`.`dateto` IS NULL""", False) 
    142          
    143         # Another one that broke sometime in 2004. Rev 32 seems to have fixed it. 
    144         trial(lambda x: 'C' in x.Plan, """`djvanimal`.`plan` LIKE '%C%'""", True) 
    145          
    146         # Multiple arguments (? Why should this be supported?) 
    147         trial(lambda x, y, z: x.Date == 3 and y.Qty > 4 and z.Qty < 20, 
    148               """(`djvanimal`.`date` = 3) AND ((`djvanimal`.`qty` > 4) AND (`djvanimal`.`qty` < 20))""", False) 
    149  
    150  
    151 class AdapterTests(unittest.TestCase): 
    152      
    153     def test_dates(self): 
    154         box = zoo.arena.new_sandbox() 
    155          
    156         WAP = zoo.Zoo() 
    157         WAP.Name = 'Wild Animal Park' 
    158         WAP.Founded = d = datetime.date(2000, 1, 1) 
    159         # 59 should give rounding errors with divmod, 
    160         # which the Adapter needs to correct. 
    161         WAP.Opens = t = datetime.time(8, 15, 59) 
    162         WAP.LastEscape = dt = datetime.datetime(2004, 7, 29, 5, 6, 7) 
    163         box.memorize(WAP) 
    164          
    165         box.flush_all() 
    166          
    167         WAP = box.unit(zoo.Zoo, Name='Wild Animal Park') 
    168         self.assertNotEqual(WAP, None) 
    169         self.assertEqual(WAP.Founded, d) 
    170         self.assertEqual(WAP.Opens, t) 
    171         self.assertEqual(WAP.LastEscape, dt) 
    172  
    173  
    174 if __name__ == "__main__": 
     3if __name__ == '__main__': 
    1754    dbname = raw_input("Database name [dejavu_test]:") or "dejavu_test" 
    1765    pword = raw_input("Password for the root user:") 
    177     smOptions = {"host": "localhost", 
    178                  "db": dbname, 
    179                  "user": "root", 
    180                  "passwd": pword, 
    181                  u'Create If Missing': True, 
    182                  } 
    183     testSM = storemysql.StorageManagerMySQL("test", zoo.arena, smOptions) 
    184     zoo.arena.add_store('testSM', testSM) 
     6    opts = {"host": "localhost", 
     7            "db": dbname, 
     8            "user": "root", 
     9            "passwd": pword, 
     10            u'Create If Missing': True, 
     11            } 
    18512     
    186     # Create the database and our tables if necessary. 
    187     for cls in (zoo.Animal, zoo.Zoo, zoo.Exhibit): 
     13    try: 
     14        testSM = zoo_fixture.setup_SM("dejavu.storage.storemysql.StorageManagerMySQL", opts) 
     15        zoo_fixture.run_tests() 
     16    finally: 
    18817        try: 
    189             zoo.arena.create_storage(cls) 
    190         except Exception, x: 
    191             if x.args[1] == ("Unknown database '%s'" % dbname): 
    192                 testSM.create_database() 
    193                 zoo.arena.create_storage(cls) 
    194             else: 
    195                 traceback.print_exc() 
    196      
    197     unittest.main() 
    198  
     18            testSM.drop_database() 
     19        except NameError: 
     20            pass 
     21        zoo_fixture.zoo.arena.shutdown() 
  • trunk/storage/test_storeodbc.py

    r43 r44  
    1 import unittest 
    2 import fixedpoint 
    3 import storeodbc 
    4 import datetime 
    5 import dejavu 
    6 import dbi 
    7 from dejavu import servers, logic 
     1from dejavu.storage import zoo_fixture 
    82 
    9  
    10 # Once again, we find that the first param must be repeated 
    11 # in the connection string. Not sure why. 
    12 allOptions = {u'Connect': (u"Provider=MSDASQL;Driver={Microsoft Access " 
    13                            u"Driver (*.mdb)};DBQ=test.mdb;" 
    14                            u"Provider=MSDASQL;")} 
    15 testSM = storeodbc.StorageManagerODBC(allOptions) 
    16  
    17 ns = dejavu.Namespace() 
    18 ns.stores['testSM'] = testSM 
    19  
    20 allOptions = {u'Litmus': '.*', 
    21               u'StorageManager': 'testSM', 
    22               } 
    23 testServer = servers.UnitServer(ns, allOptions) 
    24  
    25  
    26 class Things(dejavu.Unit): pass 
    27 Things.set_properties({"Name": unicode, 
    28                        "Size": int, 
    29                        "Date": datetime.date, 
    30                        }) 
    31  
    32 class Animals(dejavu.Unit): pass 
    33 Animals.set_properties({"Name": unicode, 
    34                         "Legs": int, 
    35                         }) 
    36  
    37 class StorageManagerTests(unittest.TestCase): 
     3if __name__ == '__main__': 
     4    # Once again, we find that the first param must be repeated 
     5    # in the connection string. Not sure why. 
     6    opts = {u'Connect': ("Provider=MSDASQL;" 
     7                         "Driver={Microsoft Access Driver (*.mdb)};" 
     8                         "DBQ=zoo.mdb;Provider=MSDASQL;"), 
     9            } 
    3810     
    39     def test_select(self): 
    40         e = logic.Expression(lambda x: x.Group == 3) 
    41         self.assertEqual(testSM.select(Things, e), 
    42                          (u"SELECT * FROM [djvThings] WHERE [Group] = 3", False)) 
    43         e = logic.Expression(lambda x: x.Group.startswith('ex-')) 
    44         self.assertEqual(testSM.select(Things, e), 
    45                          (u"SELECT * FROM [djvThings] WHERE [Group] Like 'ex-%'", True)) 
     11    # Create the database. 
     12    import win32com.client 
     13    cat = win32com.client.Dispatch(r'ADOX.Catalog') 
     14    cat.Create(opts['Connect']) 
    4615     
    47     def test_create_storage(self): 
     16    try: 
     17        testSM = zoo_fixture.setup_SM("dejavu.storage.storeodbc.StorageManagerODBC", opts) 
     18        zoo_fixture.run_tests() 
     19    finally: 
     20        zoo_fixture.zoo.arena.shutdown() 
    4821        try: 
    49             testSM.execute("DROP TABLE djvThings"
    50         except dbi.progError: 
     22            testSM.drop_database(
     23        except NameError: 
    5124            pass 
    52         testSM.create_storage(Things) 
    53      
    54     def test_recordset(self): 
    55         rs = testSM.recordset("SELECT * FROM [djvAnimals]") 
    56         data = rs.fetchall() 
    57         # Assert num columns == 4 
    58         self.assertEqual(len(data), 4) 
    59         # Assert num rows == 3 
    60         self.assertEqual(len(data[0]), 3) 
    61      
    62     def test_max_id(self): 
    63         self.assertEqual(testSM.max_id(Animals), 4) 
    64         self.test_create_storage() 
    65         self.assertEqual(testSM.max_id(Things), 0) 
    66      
    67     def test_unit_roundtrip(self): 
    68         """Assert that a Unit can be loaded and saved.""" 
    69         e = logic.Expression(lambda x: x.Name == 'Cat') 
    70         it = testSM.loader(testServer, Animals, e) 
    71         self.assertEqual(it.sql, "SELECT * FROM [djvAnimals] WHERE [Name] = 'Cat'") 
    72         u = [x for x in it.units()] 
    73         self.assertEqual(len(u), 1) 
    74         cat = u[0] 
    75         self.assertEqual(cat.Name, 'Cat') 
    76         legs = cat.Legs + 1 
    77         if legs > 10: 
    78             legs = 4 
    79         cat.Legs = legs 
    80         testSM.save(cat, True) 
    81          
    82         # Now, do the whole thing again to see if our save worked. 
    83         testServer.forget(cat, False) 
    84         it = testSM.loader(testServer, Animals, e) 
    85         u = [x for x in it.units()] 
    86         self.assertEqual(len(u), 1) 
    87         cat = u[0] 
    88         self.assertEqual(cat.Name, 'Cat') 
    89         self.assertEqual(cat.Legs, legs) 
    90  
    91  
    92 class ExpressionTests(unittest.TestCase): 
    93      
    94     def test_Adapter(self): 
    95         adapter = storeodbc.AdapterToODBCSQL() 
    96         pairs = [(3, '3'), 
    97                  (3.1, '3.1'), 
    98                  (fixedpoint.FixedPoint(5.2, 3), '5.200'), 
    99                  (u'down the Stra\u00DFe', u"'down the Stra\u00DFe'"), 
    100                  ('a salted peanut', "'a salted peanut'"), 
    101                  (datetime.datetime(2001, 11, 15, 14, 15, 16), 
    102                   "{ts '2001-11-15 14:15:16'}"), 
    103                  (datetime.date(2001, 11, 15), "{d '2001-11-15'}"), 
    104                  (datetime.time(6, 30), "{t '06:30:00'}"), 
    105                  (True, 'True'), 
    106                  (None, 'Null'), 
    107                  ] 
    108         for initial, final in pairs: 
    109             self.assertEqual(adapter.coerce(initial), final) 
    110      
    111     def test_Decompiler(self): 
    112         def trial(lam, sql, imperfect): 
    113             e = logic.Expression(lam) 
    114             decom = storeodbc.ODBCSQLDecompiler(e) 
    115             self.assertEqual(decom.code(), (sql, imperfect)) 
    116         trial(lambda x: x.Date == 3, "[Date] = 3", False) 
    117         trial(lambda x, **kw: (x.a == 3) and ((x.b > 1) or (x.b < -10)), 
    118               u'([a] = 3) and (([b] > 1) or ([b] < -10))', False) 
    119         trial(lambda x, **kw: (x.Group == 3) and not (x.Name.startswith("_")) 
    120               and not (x.Name.endswith('test')), 
    121               u"([Group] = 3) and ((not ([Name] Like '_%')) " 
    122               u"and (not ([Name] Like '%test')))", True) 
    123         trial(lambda x: (x.Group == '3') and (x.Date > datetime.date(2004, 2, 14)), 
    124               u"([Group] = '3') and ([Date] > {d '2004-02-14'})", True) 
    125          
    126         # In operator 
    127         trial(lambda x: 'tool' in x.Name or 'tool' in x.Content, 
    128               u"([Name] Like '%tool%') or ([Content] Like '%tool%')", True) 
    129         trial(lambda x: x.Name in ('Johann', 'Gambolputty', 'de von Ausfern'), 
    130               u"[Name] in ('Johann', 'Gambolputty', 'de von Ausfern')", True) 
    131         # Try In with cell references 
    132         pet, pet2 = Animals(), Animals() 
    133         pet.Name, pet2.Name = 'Pony', 'Iguana' 
    134         trial(lambda x: x.Name in (pet.Name, pet2.Name), 
    135               u"[Name] in ('Pony', 'Iguana')", True) 
    136          
    137         # logic functions 
    138         trial(lambda x: dejavu.ieq(x.Name, 'Johann'), u"[Name] = 'Johann'", False) 
    139         trial(lambda x: dejavu.icontains(x.Name, 'tool'), u"[Name] Like '%tool%'", False) 
    140         trial(lambda x: dejavu.icontainedby(x.Name, ('Johann', 'Gambolputty', 'de von Ausfern')), 
    141               u"[Name] in ('Johann', 'Gambolputty', 'de von Ausfern')", False) 
    142  
    143  
    144 if __name__ == "__main__": 
    145     unittest.main() 
    146  
  • trunk/storage/test_storepypgsql.py

    r43 r44  
    1 import traceback 
    2 import unittest 
    3 import datetime 
    4 import storepypgsql 
    5 import dejavu 
    6 from dejavu import logic, Unit, UnitProperty, zoo 
     1from dejavu.storage import zoo_fixture 
    72 
    8  
    9 class StorageManagerTests(unittest.TestCase): 
    10      
    11     def test_select(self): 
    12         def sel(f, sql, imp): 
    13             e = logic.Expression(f) 
    14             self.assertEqual(testSM.select(zoo.Animal, e), (sql, imp)) 
    15          
    16         sel(lambda x: x.Group == 3, 
    17             """SELECT * FROM "djvAnimal" WHERE "djvAnimal"."Group" = 3""", False) 
    18         sel(lambda x: x.Group.startswith('ex-'), 
    19             """SELECT * FROM "djvAnimal" WHERE "djvAnimal"."Group" LIKE 'ex-%'""", False) 
    20          
    21         # Test select all 
    22         sel(lambda x: True, """SELECT * FROM "djvAnimal" WHERE TRUE""", False) 
    23          
    24         # Test now(), today(), year() 
    25         sel(lambda x: x.FirstDate > dejavu.today(), 
    26             """SELECT * FROM "djvAnimal" WHERE "djvAnimal"."FirstDate" > CURRENT_DATE""", False) 
    27         sel(lambda x: x.Event == dejavu.now(), 
    28             """SELECT * FROM "djvAnimal" WHERE "djvAnimal"."Event" = now()""", False) 
    29         sel(lambda x: dejavu.year(x.Event) == 2004, 
    30             """SELECT * FROM "djvAnimal" WHERE date_part('year', "djvAnimal"."Event") = 2004""", False) 
    31      
    32     def test_create_storage(self): 
    33         testSM.execute('DROP TABLE "djvAnimal"') 
    34         testSM.create_storage(zoo.Animal) 
    35      
    36     def test_expanded_columns(self): 
    37         # Notice this also tests that: a Unit which is only 
    38         # dirtied via __init__ is still saved. 
    39         o = zoo.Exhibit(ID=1, Animals=[1, 2, 3]) 
    40         self.assertEqual(o.dirty(), True) 
    41         box = zoo.arena.new_sandbox() 
    42         box.memorize(o) 
    43         self.assertEqual(o.ID, 1) 
    44         box.flush_all() 
    45          
    46         o = box.unit(zoo.Exhibit, ID=1) 
    47         self.assertNotEqual(o, None) 
    48         self.assertEqual(o.Animals, [1, 2, 3]) 
    49      
    50     def test_unit_roundtrip(self): 
    51         """Assert that a Unit can be loaded and saved.""" 
    52         box = zoo.arena.new_sandbox() 
    53          
    54         e = logic.Expression(lambda x: x.Name == 'Cat') 
    55         for unit in box.recall(zoo.Animal, e): 
    56             unit.forget() 
    57          
    58         cat = zoo.Animal(Name='Cat', Legs=1) 
    59         self.assertEqual(cat.Name, 'Cat') 
    60         legs = cat.Legs + 1 
    61         if legs > 10: 
    62             legs = 4 
    63         cat.Legs = legs 
    64         box.memorize(cat) 
    65         box.flush_all() 
    66          
    67         # Now, do the whole thing again to see if our save worked. 
    68         box = zoo.arena.new_sandbox() 
    69         u = [x for x in box.recall(zoo.Animal, e)] 
    70         self.assertEqual(len(u), 1) 
    71         cat = u[0] 
    72         self.assertEqual(cat.Name, 'Cat') 
    73         self.assertEqual(cat.Legs, legs) 
    74          
    75         # Now, do the whole thing again just for kicks. 
    76         box = zoo.arena.new_sandbox() 
    77         u = [x for x in box.recall(zoo.Animal, e)] 
    78         self.assertEqual(len(u), 1) 
    79         cat = u[0] 
    80         self.assertEqual(cat.Name, 'Cat') 
    81         self.assertEqual(cat.Legs, legs) 
    82  
    83  
    84 class ExpressionTests(unittest.TestCase): 
    85      
    86     def test_AdapterToPgSQL(self): 
    87         adapter = storepypgsql.AdapterToPgSQL 
    88         pairs = [(3, '3'), 
    89                  (3.1, '3.1'), 
    90                  (u'down the Stra\u00DFe', u"'down the Stra\xdfe'"), 
    91                  ('a salted peanut', "'a salted peanut'"), 
    92                  (datetime.datetime(2001, 11, 15, 14, 15, 16), 
    93                   "'2001-11-15 14:15:16'"), 
    94                  (datetime.date(2001, 11, 15), "'2001-11-15'"), 
    95                  (datetime.time(6, 30), "'06:30:00'"), 
    96                  (True, 'TRUE'), 
    97                  (None, 'NULL'), 
    98                  ] 
    99         for initial, final in pairs: 
    100             self.assertEqual(adapter.coerce(initial), final) 
    101      
    102     def test_Decompiler(self): 
    103         def trial(lam, sql, imperfect): 
    104             e = logic.Expression(lam) 
    105             decom = storepypgsql.PgSQLDecompiler(testSM, zoo.Animal, e) 
    106             self.assertEqual(decom.code(), (sql, imperfect)) 
    107         trial(lambda x: x.Date == 3, '"djvAnimal"."Date" = 3', False) 
    108         trial(lambda x, **kw: (x.a == 3) and ((x.b > 1) or (x.b < -10)), 
    109               u'("djvAnimal"."a" = 3) AND (("djvAnimal"."b" > 1) OR ("djvAnimal"."b" < -10))', False) 
    110         trial(lambda x, **kw: (x.Group == 3) and not (x.Name.startswith("_")) 
    111               and not (x.Name.endswith('test')), 
    112               """("djvAnimal"."Group" = 3) AND ((NOT ("djvAnimal"."Name" LIKE '\_%')) """ 
    113               """AND (NOT ("djvAnimal"."Name" LIKE '%test')))""", False) 
    114         trial(lambda x: (x.Group == '3') and (x.Date > datetime.date(2004, 2, 14)), 
    115               """("djvAnimal"."Group" = '3') AND ("djvAnimal"."Date" > '2004-02-14')""", False) 
    116          
    117         # None values 
    118         trial(lambda x: x.Date != None and None != x.Date, 
    119               """("djvAnimal"."Date" IS NOT NULL) AND ("djvAnimal"."Date" IS NOT NULL)""", False) 
    120          
    121         # In operator 
    122         trial(lambda x: 'tool' in x.Name or 'tool' in x.Content, 
    123               """("djvAnimal"."Name" LIKE '%tool%') OR ("djvAnimal"."Content" LIKE '%tool%')""", True) 
    124         trial(lambda x: x.Name in ('Johann', 'Gambolputty', 'de von Ausfern'), 
    125               """"djvAnimal"."Name" IN ('Johann', 'Gambolputty', 'de von Ausfern')""", True) 
    126         # Try In with cell references 
    127         pet, pet2 = zoo.Animal(), zoo.Animal() 
    128         pet.Name, pet2.Name = 'Pony', 'Iguana' 
    129         trial(lambda x: x.Name in (pet.Name, pet2.Name), 
    130               """"djvAnimal"."Name" IN ('Pony', 'Iguana')""", True) 
    131          
    132         # logic and other functions 
    133         trial(lambda x: dejavu.ieq(x.Name, 'Johann'), """"djvAnimal"."Name" ILIKE 'Johann'""", False) 
    134         trial(lambda x: dejavu.icontains(x.Name, 'tool'), """"djvAnimal"."Name" ILIKE '%tool%'""", False) 
    135         trial(lambda x: dejavu.icontainedby(x.Name, ('Johann', 'Gambolputty', 'de von Ausfern')), 
    136               """LOWER("djvAnimal"."Name") IN ('johann', 'gambolputty', 'de von ausfern')""", False) 
    137         reqZip = '92104' 
    138         trial(lambda x: len(x.ZipStart) == len(reqZip), """length("djvAnimal"."ZipStart") = 5""", False) 
    139          
    140         # This broke on 5/10/04, because "== None" wasn't succeeding as "= Null". 
    141         trial(lambda x: x.DateTo == None, """"djvAnimal"."DateTo" IS NULL""", False) 
    142          
    143         # Another one that broke sometime in 2004. Rev 32 seems to have fixed it. 
    144         trial(lambda x: 'C' in x.Plan, """"djvAnimal"."Plan" LIKE '%C%'""", True) 
    145          
    146         # Multiple arguments (? Why should this be supported?) 
    147         trial(lambda x, y, z: x.Date == 3 and y.Qty > 4 and z.Qty < 20, 
    148               """("djvAnimal"."Date" = 3) AND (("djvAnimal"."Qty" > 4) AND ("djvAnimal"."Qty" < 20))""", False) 
    149  
    150  
    151 class AdapterTests(unittest.TestCase): 
    152      
    153     def test_dates(self): 
    154         box = zoo.arena.new_sandbox() 
    155          
    156         WAP = zoo.Zoo() 
    157         WAP.Name = 'Wild Animal Park' 
    158         WAP.Founded = d = datetime.date(2000, 1, 1) 
    159         # 59 should give rounding errors with divmod, 
    160         # which the Adapter needs to correct. 
    161         WAP.Opens = t = datetime.time(8, 15, 59) 
    162         WAP.LastEscape = dt = datetime.datetime(2004, 7, 29, 5, 6, 7) 
    163         box.memorize(WAP) 
    164          
    165         box.flush_all() 
    166          
    167         WAP = box.unit(zoo.Zoo, Name='Wild Animal Park') 
    168         self.assertNotEqual(WAP, None) 
    169         self.assertEqual(WAP.Founded, d) 
    170         self.assertEqual(WAP.Opens, t) 
    171         self.assertEqual(WAP.LastEscape, dt) 
    172  
    173  
    174 if __name__ == "__main__": 
     3if __name__ == '__main__': 
    1754    dbname = raw_input("Database name [dejavu_test]:") or "dejavu_test" 
    1765    pword = raw_input("Password for the postgres user:") 
    177     conn = "host=localhost dbname=%s user=postgres password=%s" % (dbname, pword) 
    178     smOptions = {u'Connect': conn, 
    179                  u'Create If Missing': True, 
    180                  } 
    181     testSM = storepypgsql.StorageManagerPgSQL("test", zoo.arena, smOptions) 
    182     zoo.arena.add_store('testSM', testSM) 
     6    opts = {u'Connect': ("host=localhost dbname=%s user=postgres password=%s" 
     7                         % (dbname, pword)), 
     8            u'Create If Missing': True, 
     9            } 
    18310     
    184     # Create the database and our tables if necessary. 
    185     for cls in (zoo.Animal, zoo.Zoo, zoo.Exhibit): 
     11    try: 
     12        testSM = zoo_fixture.setup_SM("dejavu.storage.storepypgsql.StorageManagerPgSQL", opts) 
     13        zoo_fixture.run_tests() 
     14    finally: 
     15        zoo_fixture.zoo.arena.shutdown() 
    18616        try: 
    187             zoo.arena.create_storage(cls) 
    188         except Exception, x: 
    189             if x.args[0] == ('FATAL:  database "%s" does not exist\n' % dbname): 
    190                 testSM.create_database() 
    191                 zoo.arena.create_storage(cls) 
    192             else: 
    193                 traceback.print_exc() 
    194      
    195     unittest.main() 
     17            testSM.drop_database() 
     18        except NameError: 
     19            pass 
    19620 
  • trunk/storage/test_storeshelve.py

    r43 r44  
    1 import unittest 
    2 import storeshelve 
     1"""Test the shelve Storage Manager for dejavu. 
    32 
    4 from dejavu import logic, zoo 
     3Notice that, since StorageManagerShelve doesn't decompile any Expressions, 
     4this will also test all native dejavu logic functions and any other aspects 
     5of Expression.evaluate(unit). 
     6""" 
    57 
    6 path = r"C:\Python23\Lib\site-packages\dejavu\storage" 
    7 testSM = storeshelve.StorageManagerShelve("test", zoo.arena, {u'Path': path}) 
    8 zoo.arena.add_store('testSM', testSM) 
     8import os 
     9from dejavu.storage import zoo_fixture 
    910 
    10  
    11 class StorageManagerTests(unittest.TestCase): 
     11if __name__ == '__main__': 
     12    opts = {u'Path': os.getcwd()} 
    1213     
    13     def test_unit_roundtrip(self): 
    14         """Assert that a Unit can be loaded and saved.""" 
    15         self.assertEqual(zoo.arena.defaultStore, testSM) 
    16         self.assertEqual(zoo.arena.defaultStore.shelvepath, path) 
    17         self.assertEqual(zoo.arena.storage(zoo.Animal), testSM) 
    18          
    19         box = zoo.arena.new_sandbox() 
    20          
    21         e = logic.Expression(lambda x: x.Name == 'Cat') 
    22         for unit in box.recall(zoo.Animal, e): 
    23             unit.forget() 
    24          
    25         cat = zoo.Animal(Name='Cat', Legs=1) 
    26         self.assertEqual(cat.Name, 'Cat') 
    27         legs = cat.Legs + 1 
    28         if legs > 10: 
    29             legs = 4 
    30         cat.Legs = legs 
    31         box.memorize(cat) 
    32         box.flush_all() 
    33          
    34         # Now, do the whole thing again to see if our save worked. 
    35         box = zoo.arena.new_sandbox() 
    36         u = [x for x in box.recall(zoo.Animal, e)] 
    37         self.assertEqual(len(u), 1) 
    38         cat = u[0] 
    39         self.assertEqual(cat.Name, 'Cat') 
    40         self.assertEqual(cat.Legs, legs) 
    41          
    42         # Now, do the whole thing again just for kicks. 
    43         box = zoo.arena.new_sandbox() 
    44         u = [x for x in box.recall(zoo.Animal, e)] 
    45         self.assertEqual(len(u), 1) 
    46         cat = u[0] 
    47         self.assertEqual(cat.Name, 'Cat') 
    48         self.assertEqual(cat.Legs, legs) 
    49  
    50  
    51 if __name__ == "__main__": 
    52     unittest.main() 
    53  
     14    try: 
     15        testSM = zoo_fixture.setup_SM("dejavu.storage.storeshelve.StorageManagerShelve", opts) 
     16        zoo_fixture.run_tests() 
     17    finally: 
     18        zoo_fixture.zoo.arena.shutdown() 
     19        try: 
     20            testSM.drop_database() 
     21        except NameError: 
     22            pass 
  • trunk/zoo.py

    r43 r44  
    2121    ZooID = UnitProperty(int, index=True) 
    2222    Legs = UnitProperty(int) 
    23     Options = UnitProperty(dict) 
     23    PreviousZoos = UnitProperty(list) 
    2424    LastEscape = EscapeProperty(datetime.datetime) 
    2525