Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

Changeset 119

Show
Ignore:
Timestamp:
12/07/05 02:13:22
Author:
fumanchu
Message:

Fix for #27 (OUTER JOIN capability).

  1. New UnitJoin? class, which is automatically returned from Unit class & Unit class (for INNER), and from << and >> (LEFT and RIGHT joins).
  2. All warnings changed from UserWarning? to new dejavu.StorageWarning?.
  3. New warning for MSAccess if MEMO is used for join.
Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/arenas.py

    r118 r119  
    301301                       LOGRECALL) 
    302302         
    303         store = self.arena.storage(classes[0]) 
    304         for c in classes[1:]: 
    305             if self.arena.storage(c) is not store: 
     303        stores = [self.arena.storage(cls) for cls in classes] 
     304        firststore = stores[0] 
     305        for s in stores: 
     306            if s is not firststore: 
    306307                raise ValueError(u"multirecall() does not support multiple" 
    307308                                 u" classes in disparate stores.") 
     
    312313        # in read-only scripts, it should be OK for now. But if you mutate 
    313314        # Units and then call multirecall, expect inconsistent results. 
    314         for unitset in store.multirecall(classes, expr): 
     315        for unitset in firststore.multirecall(classes, expr): 
    315316            confirmed = True 
    316317            for index in xrange(len(unitset)): 
  • trunk/errors.py

    r118 r119  
    11 
    2 __all__ = ['DejavuError', 'AssociationError', 'UnrecallableError'] 
     2__all__ = ['DejavuError', 'AssociationError', 'UnrecallableError', 
     3           'StorageWarning'] 
    34 
    45class DejavuError(Exception): 
     
    1920    pass 
    2021 
     22class StorageWarning(UserWarning): 
     23    """Warning about functionality which is not supported by all SM's.""" 
     24    pass 
  • trunk/storage/db.py

    r118 r119  
    153153                          "using %s. Values may be stored incorrectly." 
    154154                          % (precision, self.numeric_max_precision, 
    155                              cls.__name__, key, self.__class__.__name__)) 
     155                             cls.__name__, key, self.__class__.__name__), 
     156                          dejavu.StorageWarning) 
    156157            precision = self.numeric_max_precision 
    157158        # Assume most people use decimal for money; default scale = 2. 
     
    185186                          "using %s. Values may be stored incorrectly." 
    186187                          % (bytes, self.numeric_max_precision, 
    187                              cls.__name__, key, self.__class__.__name__)) 
     188                             cls.__name__, key, self.__class__.__name__), 
     189                          dejavu.StorageWarning) 
    188190            bytes = self.numeric_max_precision 
    189191        return "NUMERIC(%s, 0)" % bytes 
     
    761763        idlen = self.identifier_length 
    762764        if idlen and len(ident) > idlen: 
    763             warnings.warn("Identifier is longer than %s characters." % idlen) 
     765            warnings.warn(("Identifier is longer than %s characters." 
     766                           % idlen), dejavu.StorageWarning) 
    764767            ident = ident[:idlen] 
    765768        return '"' + ident + '"' 
     
    11071110                          "It may take an absurd amount of time to run, " 
    11081111                          "since each unit must be fully-formed. %s" 
    1109                           % (cls.__name__, self.__class__.__name__, expr)) 
     1112                          % (cls.__name__, self.__class__.__name__, expr), 
     1113                          dejavu.StorageWarning) 
    11101114            for unit in self.recall(cls, expr): 
    11111115                # Use tuples for hashability 
     
    11361140                          "It may take an absurd amount of time to run, " 
    11371141                          "since each unit must be fully-formed. %s" 
    1138                           % (cls.__name__, self.__class__.__name__, expr)) 
     1142                          % (cls.__name__, self.__class__.__name__, expr), 
     1143                          dejavu.StorageWarning) 
    11391144            vals = {} 
    11401145            for unit in self.recall(cls, expr): 
     
    11541159                     for row in data] 
    11551160     
    1156     def multiselect(self, classes, expr): 
     1161    def join(self, unitjoin): 
    11571162        t = self.tablename 
    11581163        i = self.identifier 
    11591164         
    1160         tablenames = [t(cls) for cls in classes] 
     1165        cls1, cls2 = unitjoin.class1, unitjoin.class2 
     1166        if isinstance(cls1, dejavu.UnitJoin): 
     1167            name1 = self.join(cls1) 
     1168            classlist1 = iter(cls1) 
     1169        else: 
     1170            # cls1 is a Unit class. 
     1171            name1 = t(cls1) 
     1172            classlist1 = [cls1] 
     1173         
     1174        if isinstance(cls2, dejavu.UnitJoin): 
     1175            name2 = self.join(cls2) 
     1176            classlist2 = iter(cls2) 
     1177        else: 
     1178            # cls2 is a Unit class. 
     1179            name2 = t(cls2) 
     1180            classlist2 = [cls2] 
     1181         
     1182        # Find an association between the two halves. 
     1183        ua = None 
     1184        for clsA in classlist1: 
     1185            for clsB in classlist2: 
     1186                ua = clsA._associations.get(clsB.__name__, None) 
     1187                if ua: break 
     1188                ua = clsB._associations.get(clsA.__name__, None) 
     1189                if ua: break 
     1190            if ua: break 
     1191        if ua is None: 
     1192            msg = ("No association found between %s and %s." % (cls1, cls2)) 
     1193            raise dejavu.AssociationError(msg) 
     1194         
     1195        if unitjoin.leftbiased is None: 
     1196            j = "INNER" 
     1197        elif unitjoin.leftbiased is True: 
     1198            j = "LEFT" 
     1199        else: 
     1200            j = "RIGHT" 
     1201        return ("(%s %s JOIN %s ON %s.%s = %s.%s)" % 
     1202                (name1, j, name2, 
     1203                 t(ua.nearClass), i(ua.nearKey), 
     1204                 t(ua.farClass), i(ua.farKey))) 
     1205     
     1206    def multiselect(self, classes, expr): 
     1207        tablenames = [self.tablename(cls) for cls in classes] 
    11611208        if expr is None: 
    11621209            expr = logic.Expression(lambda *args: True) 
    11631210        w, imp = self.where(tablenames, expr) 
    11641211         
    1165         # Because various databases may mangle column names, we explicitly 
    1166         # order the requested columns (instead of using *). 
     1212        joins = self.join(classes) 
     1213         
     1214        # Determine output columns. 
     1215        # Because various databases may mangle column names, 
     1216        # we explicitly order them (instead of using *). 
     1217        # Note that, if a class is repeated in the classes tree, 
     1218        # it will be repeated in the output. 
    11671219        columns = [] 
    1168         joins = [] 
    1169         basecls = firstcls = classes[0] 
    11701220        for cls in classes: 
    11711221            # Place the identifier properties first 
    1172             # in case others depend upon it
     1222            # in case others depend upon them
    11731223            idnames = [prop.key for prop in cls.identifiers] 
     1224             
    11741225            keys = idnames + [k for k in cls.properties() if k not in idnames] 
    11751226            columns.extend([(cls, k) for k in keys]) 
    1176              
    1177             if cls is not firstcls: 
    1178                 spath = self.arena.associations.shortest_path(basecls, cls) 
    1179                 # cls1 should be firstcls in every case. 
    1180                 cls1 = spath.pop(0) 
    1181                 for cls2 in spath: 
    1182                     ua = cls1._associations[cls2.__name__] 
    1183                     joins.append("(%s.%s = %s.%s)" % (t(cls1), i(ua.nearKey), 
    1184                                                       t(cls2), i(ua.farKey))) 
    1185                     tablenames.append(t(cls2)) 
    1186                     cls1 = cls2 
    1187                 basecls = cls 
    1188          
    1189         # Remove any duplicate entries in the join clauses. 
    1190         # Note that we assume join clauses are perfect. 
    1191         joins = dict.fromkeys(joins).keys() 
    1192         tablenames = dict.fromkeys(tablenames).keys() 
    1193          
    1194         w = u' AND '.join([w] + joins) 
    1195          
    1196         colnames = ["%s.%s" % (t(cls), i(key)) for cls, key in columns] 
     1227         
     1228        colnames = ["%s.%s" % (self.tablename(cls), self.identifier(key)) 
     1229                    for cls, key in columns] 
    11971230        statement = ("SELECT %s FROM %s WHERE %s" % 
    1198                      (u', '.join(colnames), u', '.join(tablenames), w)) 
     1231                     (u', '.join(colnames), joins, w)) 
    11991232        return statement, imp, columns 
    12001233     
  • trunk/storage/storeado.py

    r118 r119  
    456456                          "not allow comparisons on string fields larger " 
    457457                          "than 8000 characters. Some of your data may be " 
    458                           "truncated."
     458                          "truncated.", dejavu.StorageWarning
    459459            bytes = 8000 
    460460        # 8000 *bytes* is the absolute upper limit, based on T_SQL docs for 
     
    531531                          "using %s. Values may be stored incorrectly." 
    532532                          % (precision, self.numeric_max_precision, 
    533                              cls.__name__, key, self.__class__.__name__)) 
     533                             cls.__name__, key, self.__class__.__name__), 
     534                          dejavu.StorageWarning) 
    534535            precision = self.numeric_max_precision 
    535536        if scale > 4: 
     
    537538                          "using %s. Values may be stored incorrectly." 
    538539                          % (scale, cls.__name__, key, 
    539                              self.__class__.__name__)) 
     540                             self.__class__.__name__), 
     541                          dejavu.StorageWarning) 
    540542         
    541543        # MS Access doesn't let us control precision and scale directly. 
     
    599601            # MEMO is 1 GB max when set programatically (only 64K when set 
    600602            # in Access UI). But then, 1 GB is the limit for the whole DB. 
     603            for assoc in cls._associations.itervalues(): 
     604                if assoc.nearKey == key: 
     605                    warnings.warn("Memo fields cannot be used as join keys. " 
     606                                  "You should set %s.%s(hints={'bytes': 255})" 
     607                                  % (cls.__name__, key), 
     608                                  dejavu.StorageWarning) 
    601609            return u"MEMO" 
    602610 
  • trunk/storage/storemysql.py

    r118 r119  
    158158        idlen = self.identifier_length 
    159159        if idlen and len(ident) > idlen: 
    160             warnings.warn("Identifier is longer than %s characters." % idlen) 
     160            warnings.warn(("Identifier is longer than %s characters." 
     161                           % idlen), dejavu.StorageWarning) 
    161162            ident = ident[:idlen] 
    162163        return '`' + ident + '`' 
  • trunk/storage/storeodbc.py

    r118 r119  
    1919        warnings.warn("The precision of %s.%s cannot be determined for " 
    2020                      "ODBC stores. Values may be stored incorrectly." 
    21                       % (cls.__name__, key)
     21                      % (cls.__name__, key), dejavu.StorageWarning
    2222        return u"NUMERIC" 
    2323     
     
    2525        warnings.warn("The precision of %s.%s cannot be determined for " 
    2626                      "ODBC stores. Values may be stored incorrectly." 
    27                       % (cls.__name__, key)
     27                      % (cls.__name__, key), dejavu.StorageWarning
    2828        return u"NUMERIC" 
    2929     
     
    3939        warnings.warn("The precision of %s.%s cannot be determined for " 
    4040                      "ODBC stores. Values may be stored incorrectly." 
    41                       % (cls.__name__, key)
     41                      % (cls.__name__, key), dejavu.StorageWarning
    4242        return u"NUMERIC" 
    4343 
  • trunk/storage/storeshelve.py

    r118 r119  
    168168    def multirecall(self, classes, expr): 
    169169        """multirecall(classes, expr) -> Full inner join units.""" 
    170          
    171         firstcls = classes[0] 
    172         tables = {} 
    173         joins = dict([(cls, None) for cls in classes]) 
     170        if expr is None: 
     171            expr = logic.Expression(lambda *args: True) 
     172         
     173        firstcls = list(classes)[0] 
    174174        # TODO: deconstruct expr into a set of subexpr's, one for 
    175175        # each class in classes. 
    176176        filters = dict([(cls, None) for cls in classes]) 
    177177         
    178         def combine(nearValue, farKey, *classes): 
    179             classes = list(classes) 
    180             thiscls = classes.pop(0) 
    181              
    182             # Use cached table if present 
    183             cached = (thiscls in tables) 
    184             if cached: 
    185                 table = tables[thiscls] 
     178        def combine(unitjoin): 
     179            cls1, cls2 = unitjoin.class1, unitjoin.class2 
     180             
     181            if isinstance(cls1, dejavu.UnitJoin): 
     182                table1 = combine(cls1) 
     183                classlist1 = iter(cls1) 
    186184            else: 
    187                 table = self.recall(thiscls, filters[thiscls]) 
    188                 tables[thiscls] = newcache = [] 
    189              
    190             nextClass = None 
    191             if classes: 
    192                 nextClass = classes[0] 
    193                 ua = thiscls._associations[nextClass.__name__] 
    194                 nextNearKey, nextFarKey = ua.nearKey, ua.farKey 
    195              
    196             for unit in table: 
    197                 # Note that the caching happens only if the optimization 
    198                 # filters succeed; however, it doesn't depend on whether 
    199                 # the join test succeeds or fails. 
    200                 if not cached: 
    201                     newcache.append(unit) 
    202                  
    203                 # Test against join constraint 
    204                 if farKey and getattr(unit, farKey) != nearValue: 
    205                     continue 
    206                  
    207                 if nextClass: 
    208                     newNearVal = getattr(unit, nextNearKey) 
    209                     for subunits in combine(newNearVal, nextFarKey, *classes): 
    210                         yield [unit,] + subunits 
    211                 else: 
    212                     yield [unit,] 
    213          
    214         for unitrow in combine(None, None, *classes): 
     185                table1 = [[x] for x in self.recall(cls1, filters[cls1])] 
     186                classlist1 = [cls1] 
     187             
     188            if isinstance(cls2, dejavu.UnitJoin): 
     189                table2 = combine(cls2) 
     190                classlist2 = iter(cls2) 
     191            else: 
     192                table2 = [[x] for x in self.recall(cls2, filters[cls2])] 
     193                classlist2 = [cls2] 
     194             
     195            # Find an association between the two halves. 
     196            ua = None 
     197            for indexA, clsA in enumerate(classlist1): 
     198                for indexB, clsB in enumerate(classlist2): 
     199                    ua = clsA._associations.get(clsB.__name__, None) 
     200                    if ua: 
     201                        nearKey, farKey = ua.nearKey, ua.farKey 
     202                        break 
     203                    ua = clsB._associations.get(clsA.__name__, None) 
     204                    if ua: 
     205                        nearKey, farKey = ua.farKey, ua.nearKey 
     206                        break 
     207                if ua: break 
     208            if ua is None: 
     209                msg = ("No association found between %s and %s." % (cls1, cls2)) 
     210                raise dejavu.AssociationError(msg) 
     211             
     212            unitrows = [] 
     213            if unitjoin.leftbiased is None: 
     214                # INNER JOIN 
     215                for row1 in table1: 
     216                    nearVal = getattr(row1[indexA], nearKey) 
     217                    for row2 in table2: 
     218                        # Test against join constraint 
     219                        farVal = getattr(row2[indexB], farKey) 
     220                        if nearVal == farVal: 
     221                            unitrows.append(row1 + row2) 
     222            elif unitjoin.leftbiased is True: 
     223                # LEFT JOIN 
     224                for row1 in table1: 
     225                    nearVal = getattr(row1[indexA], nearKey) 
     226                    found = False 
     227                    for row2 in table2: 
     228                        # Test against join constraint 
     229                        farVal = getattr(row2[indexB], farKey) 
     230                        if nearVal == farVal: 
     231                            unitrows.append(row1 + row2) 
     232                            found = True 
     233                    if not found: 
     234                        unitrows.append(row1 + [unit.__class__() for unit in row2]) 
     235            else: 
     236                # RIGHT JOIN 
     237                for row2 in table2: 
     238                    farVal = getattr(row2[indexB], farKey) 
     239                    found = False 
     240                    for row1 in table1: 
     241                        # Test against join constraint 
     242                        nearVal = getattr(row1[indexA], nearKey) 
     243                        if nearVal == farVal: 
     244                            unitrows.append(row1 + row2) 
     245                            found = True 
     246                    if not found: 
     247                        unitrows.append([unit.__class__() for unit in row1] + row2) 
     248            return unitrows 
     249         
     250        for unitrow in combine(classes): 
    215251            if expr(*unitrow): 
    216252                yield unitrow 
  • trunk/storage/storesqlite.py

    r118 r119  
    1515                       "wildcard literals." % _version) 
    1616    import warnings 
    17     warnings.warn(_escape_warning
     17    warnings.warn(_escape_warning, dejavu.StorageWarning
    1818 
    1919 
     
    204204            self.execute(u'CREATE INDEX %s ON %s (%s);' % 
    205205                         (i, tablename, self.identifier(index))) 
    206  
     206     
     207    def join(self, unitjoin): 
     208        t = self.tablename 
     209        i = self.identifier 
     210         
     211        on_clauses = [] 
     212         
     213        cls1, cls2 = unitjoin.class1, unitjoin.class2 
     214        if isinstance(cls1, dejavu.UnitJoin): 
     215            name1, w = self.join(cls1) 
     216            on_clauses.extend(w) 
     217            classlist1 = iter(cls1) 
     218        else: 
     219            # cls1 is a Unit class. 
     220            name1 = t(cls1) 
     221            classlist1 = [cls1] 
     222         
     223        if isinstance(cls2, dejavu.UnitJoin): 
     224            name2, w = self.join(cls2) 
     225            on_clauses.extend(w) 
     226            classlist2 = iter(cls2) 
     227        else: 
     228            # cls2 is a Unit class. 
     229            name2 = t(cls2) 
     230            classlist2 = [cls2] 
     231         
     232        # Find an association between the two halves. 
     233        ua = None 
     234        for cls1 in classlist1: 
     235            for cls2 in classlist2: 
     236                ua = cls1._associations.get(cls2.__name__, None) 
     237                if ua: break 
     238                ua = cls2._associations.get(cls1.__name__, None) 
     239                if ua: break 
     240            if ua: break 
     241        if ua is None: 
     242            msg = ("No association found between %s and %s." % (cls1, cls2)) 
     243            raise dejavu.AssociationError(msg) 
     244         
     245        if unitjoin.leftbiased is None: 
     246            j = "%s INNER JOIN %s" % (name1, name2) 
     247        elif unitjoin.leftbiased is True: 
     248            j = "%s LEFT JOIN %s" % (name1, name2) 
     249        else: 
     250            # My version (3.0.8) of SQLite says: 
     251            # "RIGHT and FULL OUTER JOINs are not currently supported". 
     252            # TODO: find out if any versions do support it. 
     253            j = "%s LEFT JOIN %s" % (name2, name1) 
     254        w = ("%s.%s = %s.%s" % (t(ua.nearClass), i(ua.nearKey), 
     255                                t(ua.farClass), i(ua.farKey))) 
     256        on_clauses.append(w) 
     257        return j, on_clauses 
     258     
     259    def multiselect(self, classes, expr): 
     260        tablenames = [self.tablename(cls) for cls in classes] 
     261        if expr is None: 
     262            expr = logic.Expression(lambda *args: True) 
     263        w, imp = self.where(tablenames, expr) 
     264         
     265        # SQLite doesn't do nested JOINs, but instead applies them 
     266        # in order. Therefore, we need a single ON-clause at the 
     267        # end of the list of tables. For example: 
     268        # "From a LEFT JOIN b LEFT JOIN c ON a.ID = b.ID AND b.Name = c.Name 
     269        joins, on_clauses = self.join(classes) 
     270        joins += " ON " + " AND ".join(on_clauses) 
     271         
     272        # Determine output columns. 
     273        # Because various databases may mangle column names, 
     274        # we explicitly order them (instead of using *). 
     275        # Note that, if a class is repeated in the classes tree, 
     276        # it will be repeated in the output. 
     277        columns = [] 
     278        for cls in classes: 
     279            # Place the identifier properties first 
     280            # in case others depend upon them. 
     281            idnames = [prop.key for prop in cls.identifiers] 
     282             
     283            keys = idnames + [k for k in cls.properties() if k not in idnames] 
     284            columns.extend([(cls, k) for k in keys]) 
     285         
     286        colnames = ["%s.%s" % (self.tablename(cls), self.identifier(key)) 
     287                    for cls, key in columns] 
     288        statement = ("SELECT %s FROM %s WHERE %s" % 
     289                     (u', '.join(colnames), joins, w)) 
     290        return statement, imp, columns 
     291 
  • trunk/test/test_dejavu.py

    r118 r119  
    11import datetime 
    22import unittest 
     3import warnings 
    34 
    45import dejavu 
    56from dejavu import storage 
    6 from dejavu.test import zoo_fixture 
     7from dejavu.test.zoo_fixture import * 
    78 
    8 zoo_fixture.arena.add_store("default", "dejavu.storage.CachingProxy") 
     9 
     10arena.add_store("default", "dejavu.storage.CachingProxy") 
    911 
    1012 
     
    1315    def setUp(self): 
    1416        # CleanUP The Database! 
    15         box = zoo_fixture.arena.new_sandbox() 
    16         for animal in box.recall(zoo_fixture.Animal): 
     17        box = arena.new_sandbox() 
     18        for animal in box.recall(Animal): 
    1719            animal.forget() 
    18         for zoo_thing in box.recall(zoo_fixture.Zoo): 
     20        for zoo_thing in box.recall(Zoo): 
    1921            zoo_thing.forget() 
    2022     
     
    2224        # Instance creation and population 
    2325        f = datetime.date(1916, 10, 2) 
    24         z = zoo_fixture.Zoo(Name='San Diego Zoo', Founded=f) 
     26        z = Zoo(Name='San Diego Zoo', Founded=f) 
    2527        self.assertEqual(z.dirty(), True) 
    26         self.assertEqual(zoo_fixture.Zoo.ID.type, int) 
     28        self.assertEqual(Zoo.ID.type, int) 
    2729        self.assertEqual(z.ID, None) 
    2830        self.assertEqual(z.Name, 'San Diego Zoo') 
     
    3133        self.assertEqual(z.__class__.ID.index, True) 
    3234         
    33         a = zoo_fixture.Animal(Name='Giraffe', Legs=4) 
     35        a = Animal(Name='Giraffe', Legs=4) 
    3436        self.assertEqual(a.dirty(), True) 
    3537        self.assertEqual(a.ID, None) 
     
    3941         
    4042        # Sandboxing 
    41         s = zoo_fixture.arena.new_sandbox() 
     43        s = arena.new_sandbox() 
    4244        s.memorize(z) 
    4345        self.assertEqual(z.ID, 1) 
     
    5961         
    6062        # Create some animals in a sandbox 
    61         box = zoo_fixture.arena.new_sandbox() 
    62         box.memorize(zoo_fixture.Animal(Name='Wombat', Legs=4)) 
    63         box.memorize(zoo_fixture.Animal(Name='Lizard', Legs=4)) 
     63        box = arena.new_sandbox() 
     64        box.memorize(Animal(Name='Wombat', Legs=4)) 
     65        box.memorize(Animal(Name='Lizard', Legs=4)) 
    6466         
    6567        animals = [] 
     
    6769         
    6870        # Start a new sandbox with no cache 
    69         box = zoo_fixture.arena.new_sandbox() 
     71        box = arena.new_sandbox() 
    7072         
    7173        # get animals alternating from two different xrecalls 
    7274        animals_is_stopped = False 
    7375        animals2_is_stopped = False 
    74         animals_iter = box.xrecall(zoo_fixture.Animal) 
    75         animals2_iter = box.xrecall(zoo_fixture.Animal) 
     76        animals_iter = box.xrecall(Animal) 
     77        animals2_iter = box.xrecall(Animal) 
    7678        while not (animals_is_stopped and animals2_is_stopped): 
    7779            try: 
     
    99101         
    100102        # Create an animal in a sandbox, but retain a reference to it 
    101         box = zoo_fixture.arena.new_sandbox() 
    102         bat = zoo_fixture.Animal(Name='Bat', Legs=4) 
     103        box = arena.new_sandbox() 
     104        bat = Animal(Name='Bat', Legs=4) 
    103105        box.memorize(bat) 
    104106         
     
    107109         
    108110        # Retrieve the Unit from the same sandbox again. 
    109         self.assert_(box.unit(zoo_fixture.Animal) is bat) 
     111        self.assert_(box.unit(Animal) is bat) 
    110112         
    111113        # Retrieve the Unit from a new sandbox. 
    112114        # Units should be different, and their 
    113115        # UnitProperties should be different. 
    114         bat3 = zoo_fixture.arena.new_sandbox().unit(zoo_fixture.Animal) 
     116        bat3 = arena.new_sandbox().unit(Animal) 
    115117        self.assert_(bat3 is not bat) 
    116118        self.assertEqual(bat3.Legs, 4) 
     119     
     120    def test_UnitJoin(self): 
     121        box = arena.new_sandbox() 
     122        tree = Animal & Zoo 
     123        self.assertEqual(str(tree), "(Animal & Zoo)") 
     124        tree = Animal << Zoo 
     125        self.assertEqual(str(tree), "(Animal << Zoo)") 
     126        tree = Animal >> Zoo 
     127        self.assertEqual(str(tree), "(Animal >> Zoo)") 
     128         
     129        trees = [] 
     130        def make_tree(): 
     131            trees.append( (Animal & Zoo) >> Exhibit ) 
     132         
     133        warnings.filterwarnings("error", category=dejavu.StorageWarning) 
     134        try: 
     135            self.assertRaises(dejavu.StorageWarning, make_tree) 
     136        finally: 
     137            warnings.filters.pop(0) 
     138         
     139        # Since we raised the warning, our first make_tree failed. 
     140        warnings.filterwarnings("ignore", category=dejavu.StorageWarning) 
     141        try: 
     142            make_tree() 
     143        finally: 
     144            warnings.filters.pop(0) 
     145         
     146        self.assertEqual(str(trees[0]), "((Animal & Zoo) >> Exhibit)") 
     147        tree = trees[0] & (Visit << Vet) & Exhibit 
     148        self.assertEqual(str(tree), "((((Animal & Zoo) >> Exhibit) & " 
     149                                    "(Visit << Vet)) & Exhibit)") 
     150        self.assertEqual(list(tree), [Animal, Zoo, Exhibit, Visit, Vet, Exhibit]) 
    117151 
    118152 
  • trunk/test/zoo_fixture.py

    r118 r119  
    33import datetime 
    44import os 
     5try: 
     6    import pythoncom 
     7except ImportError: 
     8    pythoncom = None 
     9 
     10try: 
     11    set 
     12except NameError: 
     13    from sets import Set as set 
     14 
    515import threading 
    616import unittest 
     17import warnings 
    718 
    819try: 
     
    431442        f = logic.Expression(lambda z, a: z.Name == 'San Diego Zoo') 
    432443        zooed_animals = [(z, a) for z, a in 
    433                          box.multirecall([Zoo, Animal], f)] 
     444                         box.multirecall(Zoo & Animal, f)] 
    434445        self.assertEqual(len(zooed_animals), 2) 
    435446         
     
    448459        leo = logic.Expression(lambda z, a: a.Species == 'Leopard') 
    449460        zooed_animals = [(z, a) for z, a in 
    450                          box.multirecall([Zoo, Animal], sdexpr + leo)] 
     461                         box.multirecall(Zoo & Animal, sdexpr + leo)] 
    451462        self.assertEqual(len(zooed_animals), 0) 
     463         
     464        # Now try the same expr with INNER, LEFT, and RIGHT JOINs. 
     465        zooed_animals = list(box.multirecall(Zoo & Animal)) 
     466        self.assertEqual(len(zooed_animals), 6) 
     467        self.assertEqual(set([(z.Name, a.Species) for z, a in zooed_animals]), 
     468                         set([("Wild Animal Park", "Leopard"), 
     469                              ("Wild Animal Park", "Lion"), 
     470                              ("San Diego Zoo", "Tiger"), 
     471                              ("San Diego Zoo", "Millipede"), 
     472                              ("Sea_World", "Emperor Penguin"), 
     473                              ("Sea_World", "Adelie Penguin")])) 
     474         
     475        zooed_animals = list(box.multirecall(Zoo >> Animal)) 
     476        self.assertEqual(len(zooed_animals), 12) 
     477        self.assertEqual(set([(z.Name, a.Species) for z, a in zooed_animals]), 
     478                         set([("Wild Animal Park", "Leopard"), 
     479                              ("Wild Animal Park", "Lion"), 
     480                              ("San Diego Zoo", "Tiger"), 
     481                              ("San Diego Zoo", "Millipede"), 
     482                              ("Sea_World", "Emperor Penguin"), 
     483                              ("Sea_World", "Adelie Penguin"), 
     484                              (None, "Slug"), 
     485                              (None, "Bear"), 
     486                              (None, "Ostrich"), 
     487                              (None, "Centipede"), 
     488                              (None, "Ape"), 
     489                              (None, "Ape"), 
     490                              ])) 
     491         
     492        zooed_animals = list(box.multirecall(Zoo << Animal)) 
     493        self.assertEqual(len(zooed_animals), 7) 
     494        self.assertEqual(set([(z.Name, a.Species) for z, a in zooed_animals]), 
     495                         set([("Wild Animal Park", "Leopard"), 
     496                              ("Wild Animal Park", "Lion"), 
     497                              ("San Diego Zoo", "Tiger"), 
     498                              ("San Diego Zoo", "Millipede"), 
     499                              ("Sea_World", "Emperor Penguin"), 
     500                              ("Sea_World", "Adelie Penguin"), 
     501                              (u'Montr\xe9al Biod\xf4me', None), 
     502                              ])) 
    452503         
    453504        # Try a multiple-arg expression 
    454505        f = logic.Expression(lambda a, z: a.Legs >= 4 and z.Admission < 10) 
    455         animal_zoos = [(a, z) for a, z in box.multirecall([Animal, Zoo], f)] 
     506        animal_zoos = [(a, z) for a, z in box.multirecall(Animal & Zoo, f)] 
    456507        self.assertEqual(len(animal_zoos), 4) 
    457508        names = [a.Species for a, z in animal_zoos] 
     
    460511         
    461512        # Let's try three joined classes just for the sadistic fun of it. 
     513        tree = (Animal >> Zoo) >> Vet 
    462514        f = logic.Expression(lambda a, z, v: z.Name == 'Sea_World') 
    463         azv = [(a, z, v) for a, z, v in 
    464                box.multirecall([Animal, Zoo, Vet], f)] 
     515        azv = list(box.multirecall(tree, f)) 
    465516        self.assertEqual(len(azv), 2) 
     517         
     518        # MSAccess can't handle an INNER JOIN nested in an OUTER JOIN. 
     519        # Test that this fails for MSAccess, but works for other SM's. 
     520        trees = [] 
     521        def make_tree(): 
     522            trees.append( (Animal & Zoo) >> Vet ) 
     523        warnings.filterwarnings("ignore", category=dejavu.StorageWarning) 
     524        try: 
     525            make_tree() 
     526        finally: 
     527            warnings.filters.pop(0) 
     528         
     529        azv = [] 
     530        def set_azv(): 
     531            f = logic.Expression(lambda a, z, v: z.Name == 'Sea_World') 
     532            azv.append([(a, z, v) for a, z, v in 
     533                        box.multirecall(trees[0], f)]) 
     534         
     535        smname = arena.stores['testSM'].__class__.__name__ 
     536        if smname in ("StorageManagerADO_MSAccess",): 
     537            self.assertRaises(pythoncom.com_error, set_azv) 
     538        else: 
     539            set_azv() 
     540            self.assertEqual(len(azv[0]), 2) 
    466541     
    467542    def test_Editing(self): 
     
    589664        self.assertEqual(len(snaps[1]), 2) 
    590665        self.assertEqual(eng.last_snapshot(), snaps[1]) 
    591      
     666##     
    592667##    def test_Transactions(self): 
    593668##        box = arena.new_sandbox() 
  • trunk/units.py

    r118 r119  
    1 import sha 
    2 import types 
    3  
    41try: 
    52    import cPickle as pickle 
     
    74    import pickle 
    85 
    9 import logic as logic 
    10 import errors as errors 
    11  
    12  
    13 __all__ = ['UnitAssociation', 'ToMany', 'ToOne', 
     6import sha 
     7import types 
     8import warnings 
     9 
     10import logic 
     11import errors 
     12 
     13 
     14__all__ = ['UnitAssociation', 'ToMany', 'ToOne', 'UnitJoin', 
    1415           'Unit', 'UnitProperty', 'MetaUnit', 
    1516           'UnitSequencerInteger', 'UnitSequencerNull', 
     
    109110            f += expr 
    110111        return unit.sandbox.recall(self.farClass, f) 
     112 
     113 
     114class UnitJoin(object): 
     115     
     116    def __init__(self, class1, class2, leftbiased=None): 
     117        self.class1 = class1 
     118        self.class2 = class2 
     119        self.leftbiased = leftbiased 
     120         
     121        # From http://msdn.microsoft.com/library/en-us/ 
     122        #           dnacc2k/html/acintsql.asp#acintsql_joins 
     123        # "OUTER JOINs can be nested inside INNER JOINs in a multi-table 
     124        # join, but INNER JOINs cannot be nested inside OUTER JOINs." 
     125        if leftbiased is not None: 
     126            if ((isinstance(class1, UnitJoin) and class1.leftbiased is None) 
     127                or (isinstance(class1, UnitJoin) and class1.leftbiased is None)): 
     128                warnings.warn("Some StorageManagers cannot nest an INNER " 
     129                              "JOIN within an OUTER JOIN. Consider rewriting " 
     130                              "your join tree.", errors.StorageWarning) 
     131     
     132    def __str__(self): 
     133        if self.leftbiased is None: 
     134            op = "&" 
     135        elif self.leftbiased is True: 
     136            op = "<<" 
     137        else: 
     138            op = ">>" 
     139        if isinstance(self.class1, UnitJoin): 
     140            name1 = str(self.class1) 
     141        else: 
     142            name1 = self.class1.__name__ 
     143        if isinstance(self.class2, UnitJoin): 
     144            name2 = str(self.class2) 
     145        else: 
     146            name2 = self.class2.__name__ 
     147        return "(%s %s %s)" % (name1, op, name2) 
     148    __repr__ = __str__ 
     149     
     150    def __iter__(self): 
     151        def genclasses(): 
     152            if isinstance(self.class1, MetaUnit): 
     153                yield self.class1 
     154            else: 
     155                for cls in iter(self.class1): 
     156                    yield cls 
     157            if isinstance(self.class2, MetaUnit): 
     158                yield self.class2 
     159            else: 
     160                for cls in iter(self.class2): 
     161                    yield cls 
     162        return genclasses() 
     163     
     164    def __lshift__(self, other): 
     165        if isinstance(other, (MetaUnit, UnitJoin)): 
     166            return UnitJoin(self, other, leftbiased=True) 
     167        else: 
     168            raise TypeError("Joined classes must be UnitJoin or Unit subclasses.") 
     169    __rrshift__ = __lshift__ 
     170     
     171    def __rshift__(self, other): 
     172        if isinstance(other, (MetaUnit, UnitJoin)): 
     173            return UnitJoin(self, other, leftbiased=False) 
     174        else: 
     175            raise TypeError("Joined classes must be UnitJoin or Unit subclasses.") 
     176    __rlshift__ = __rshift__ 
     177     
     178    def __add__(self, other): 
     179        if isinstance(other, (MetaUnit, UnitJoin)): 
     180            return UnitJoin(self, other) 
     181        else: 
     182            raise TypeError("Joined classes must be UnitJoin or Unit subclasses.") 
     183    __and__ = __add__ 
     184     
     185    def __radd__(self, other): 
     186        if isinstance(other, (MetaUnit, UnitJoin)): 
     187            return UnitJoin(other, self) 
     188        else: 
     189            raise TypeError("Joined classes must be UnitJoin or Unit subclasses.") 
     190    __rand__ = __radd__ 
    111191 
    112192 
     
    310390        cls._properties = props 
    311391        cls._associations = assocs 
     392     
     393    def __lshift__(self, other): 
     394        if isinstance(other, (MetaUnit, UnitJoin)): 
     395            return UnitJoin(self, other, leftbiased=True) 
     396        else: 
     397            raise TypeError("Joined classes must be UnitJoin or Unit subclasses.") 
     398    __rrshift__ = __lshift__ 
     399     
     400    def __rshift__(self, other): 
     401        if isinstance(other, (MetaUnit, UnitJoin)): 
     402            return UnitJoin(self, other, leftbiased=False) 
     403        else: 
     404            raise TypeError("Joined classes must be UnitJoin or Unit subclasses.") 
     405    __rlshift__ = __rshift__ 
     406     
     407    def __add__(self, other): 
     408        if isinstance(other, (MetaUnit, UnitJoin)): 
     409            return UnitJoin(self, other) 
     410        else: 
     411            raise TypeError("Joined classes must be UnitJoin or Unit subclasses.") 
     412    __and__ = __add__ 
     413     
     414    def __radd__(self, other): 
     415        if isinstance(other, (MetaUnit, UnitJoin)): 
     416            return UnitJoin(other, self) 
     417        else: 
     418            raise TypeError("Joined classes must be UnitJoin or Unit subclasses.") 
     419    __rand__ = __radd__ 
    312420 
    313421 
     
    519627                # If far key is already set, it will simply be overwritten. 
    520628                setattr(unit, ua.farKey, nearval) 
    521