import md5
import memcache
import re
import sys

try:
    set
except NameError:
    from sets import Set as set

import dejavu
from dejavu import errors, logflags, storage
from geniusql import logic


def bytecode_regex(bits):
    """Build a regular-expression string from a mix of bytecode bits.
    
    Each bit is either an integer byte value, which is converted with
    chr() and passed through re.escape() so it matches literally, or a
    string, which is taken verbatim as regex syntax (e.g. "." for "any
    byte"). The pieces are concatenated in order.
    """
    def _piece(bit):
        # Strings are already regex syntax; everything else is a byte value.
        if isinstance(bit, basestring):
            return bit
        return re.escape(chr(bit))
    
    return "".join([_piece(bit) for bit in bits])

# Bytecode patterns (CPython 2.x opcode numbers) for expression lambdas of
# the form `lambda x: x.a == 1 and x.b == 2 ...`, which extract_filters can
# turn into a dict of equality filters.
#
# One comparison: LOAD_FAST 0; LOAD_ATTR <name>; LOAD_CONST <const>;
# COMPARE_OP 2 (==). Each "." is one arbitrary operand byte.
simple_compare = bytecode_regex([124, 0, 0, 105, ".", ".", 100, ".", ".", 106, 2, 0])
# The `and` joining two comparisons: JUMP_IF_FALSE <target>; POP_TOP.
# (Opcode 111 has this meaning before Python 2.7.)
simple_and = bytecode_regex([111, ".", ".", 1])
# One or more comparisons joined by `and`, then RETURN_VALUE (chr(83) == 'S').
# re.DOTALL is required: an operand byte may be 0x0A ('\n'), which "." does
# not match otherwise, and such expressions would silently be treated as
# non-indexable.
indexable_regex = re.compile("^(%s(%s)?)+S$" % (simple_compare, simple_and),
                             re.DOTALL)


class MemcachedStorageManager(storage.StorageManager):
    """A Storage Manager which keeps all data in memcached.
    
    memcached is a high-performance, distributed memory object caching
    system, generic in nature, but intended for use in speeding up
    dynamic web applications by alleviating database load.
    
    See http://www.danga.com/memcached/
    and ftp://ftp.tummy.com/pub/python-memcached/
    
    IMPORTANT: data stuck into memcached is not guaranteed to be stable.
    It may disappear at any time, according to an internal LRU algorithm.
    In particular, you should be aware that the LRU algorithm is itself
    partitioned by object size (into "slabs"), so that a newer object
    may be removed before an older one if they are of significantly
    different sizes.
    
    Options:
        memcached.servers: a list of strings of the form 'IP-address:port'.
            These will be passed directly into the memcache.Client instance.
        
        memcached.global_index: if True (the default), this store will
            maintain a index over the identifiers of all stored objects in
            memcached itself. This is the 'safe' choice, and necessary if
            your only store is memcached. However, if you run this store as
            an ObjectCache.cache, you should turn this off, allowing
            ObjectCache.nextstore to maintain the primary indexes--this
            allows the cache to run orders of magnitude faster.
        
        memcached.index_time: the timeout, in seconds, for any cached indexes.
            Default is 300 seconds.
    """
    
    def __init__(self, allOptions={}):
        storage.StorageManager.__init__(self, allOptions)
        
        self.name = allOptions['name']
        self.global_index = allOptions.pop("memcached.global_index", True)
        self.index_time = allOptions.pop("memcached.index_time", 5 * 60)
        self.primary_keys = {}
        self.indexsets = {}
        self.index_stride = 50
        
        cache_opts = dict([(k[10:], v) for k, v in allOptions.iteritems()
                           if k.startswith("memcached.")])
        self.client = memcache.Client(**cache_opts)
    
    def hash(self, object):
        """Return a consistent hash for object (for use in a memcached key)."""
        # TODO: can we add overflow support for collisions?
        return md5.new(repr(object)).hexdigest()
    
    def _unit_key(self, unit):
        """Return (ident, memcached key) for the given unit."""
        cls = unit.__class__
        ident = tuple([getattr(unit, name) for name in self.primary_keys[cls]])
        key = "%s:%s:%s" % (self.name, cls.__name__, self.hash(ident))
        return key
    
    def unit(self, cls, **kwargs):
        """A single Unit which matches the given kwargs, else None.
        
        The first Unit matching the kwargs is returned; if no Units match,
        None is returned.
        """
        keyset = set(kwargs.keys())
        
        # Try to retrieve a matching unit using its primary_keys.
        # This will skip grabbing any indices (a HUGE optimization).
        pk = self.primary_keys[cls]
        if keyset >= set(pk):
            return self._unit_by_primary_key(cls, pk, kwargs)
        
        # Try to retrieve a matching unit using an index.
        # If self.global_index is True, the last one should
        # be an index with propnames == []. See self.register.
        indexset = self.indexsets[cls]
        for index in indexset:
            if keyset >= set(index):
                unit = indexset.unit(index, kwargs)
                if unit is not None:
                    if self.logflags & logflags.RECALL:
                        self.log(logflags.RECALL.message(cls, ('HIT', kwargs)))
                    return unit
        
        # Return None since we have no more access paths.
        if self.logflags & logflags.RECALL:
            self.log(logflags.RECALL.message(cls, ('DEFER', kwargs)))
        return None
    
    def xrecall(self, classes, expr=None, order=None, limit=None, offset=None):
        """Yield units of the given cls which match the given expr."""
        if isinstance(classes, dejavu.UnitJoin):
            for units in self._xmultirecall(classes, expr, order=order,
                                            limit=limit, offset=offset):
                yield units
            return
        
        cls = classes
        indexset = self.indexsets[cls]
        
        if not isinstance(expr, logic.Expression):
            expr = logic.Expression(expr)
        if self.logflags & logflags.RECALL:
            self.log(logflags.RECALL.message(cls, expr))
        
        if limit == 0:
            return
        
        if offset and not order:
            raise TypeError("Order argument expected when offset is provided.")
        
        filters = self.extract_filters(expr)
        
        # Try to retrieve a single matching unit using its primary_keys.
        # This will skip grabbing any indices (a HUGE optimization).
        pk = self.primary_keys[cls]
        if set(filters.keys()) >= set(pk):
            yield self._unit_by_primary_key(cls, pk, filters)
            return
        
        # Try to retrieve matching units using an index.
        # If self.global_index is True, the last one should
        # be an index with propnames == []. See self.register.
        for index in indexset:
            if set(filters.keys()) >= set(index):
                data = indexset.xrecall(index, filters)
                data = self._xrecall_inner(data, expr)
                for unit in self._paginate(data, order, limit, offset, single=True):
                    yield unit
                return
    
    def _xrecall_inner(self, units, expr=None):
        """Private helper for self.xrecall."""
        for unit in units:
            if expr is None or expr(unit):
                # Must yield a sequence for use in _paginate.
                yield (unit,)
    
    def save(self, unit, forceSave=False):
        """Store the unit."""
        if self.logflags & logflags.SAVE:
            self.log(logflags.SAVE.message(unit, forceSave))
        
        if forceSave or unit.dirty():
            # Cleanse first because pickle state
            # includes _initial_property_hash.
            unit.cleanse()
            self.client.set(self._unit_key(unit), unit)
            self.indexsets[unit.__class__].add(unit)
    
    def destroy(self, unit):
        """Delete the unit."""
        if self.logflags & logflags.DESTROY:
            self.log(logflags.DESTROY.message(unit))
        
        self.client.delete(self._unit_key(unit))
        self.indexsets[unit.__class__].discard(unit)
    
    def reserve(self, unit):
        """Reserve storage space for the Unit."""
        if unit.identifiers:
            cls = unit.__class__
            indexset = self.indexsets[cls]
            
            if not unit.sequencer.valid_id(unit.identity()):
                if () in indexset:
                    # Try to generate an identifier by looking
                    # up all units in the global index.
                    index = indexset.get({}) or []
                    ids = [u.identity()
                           for u in indexset.scan(index).itervalues()]
                    unit.sequencer.assign(unit, ids)
                else:
                    raise NotImplementedError(
                        "Unindexed memcache cannot generate identifiers.")
            
            unit.cleanse()
            
            # Add the unit to the cache.
            try:
                self.client.add(self._unit_key(unit), unit)
            except IOError, exc:
                if exc.args[0] == 'NOT_STORED':
                    pass
                raise
            
            # Add the unit to all indices.
            indexset.add(unit)
        else:
            # This class has no identifiers, so skip reserve and wait for save.
            pass
        
        # Usually we log ASAP, but here we log after
        # the unit has had a chance to get an auto ID.
        if self.logflags & logflags.RESERVE:
            self.log(logflags.RESERVE.message(unit))
    
    def shutdown(self, conflicts='error'):
        """Shut down all connections to internal storage.
        
        conflicts: see errors.conflict.
        """
        self.client.disconnect_all()
    
    def create_database(self, conflicts='error'):
        """Create internal structures for the entire database.
        
        conflicts: see errors.conflict.
        """
        pass
    
    def drop_database(self, conflicts='error'):
        """Destroy internal structures for the entire database.
        
        conflicts: see errors.conflict.
        """
        for cls in self.classes:
            self.flush(cls)
    
    def create_storage(self, cls, conflicts='error'):
        """Create internal structures for the given class.
        
        conflicts: see errors.conflict.
        """
        if self.logflags & logflags.DDL:
            self.log(logflags.DDL.message("create storage %s" % cls))
        
        indexset = self.indexsets[cls]
        if () in indexset:
            try:
                self.client.add(indexset.key({}), [])
            except IOError, exc:
                if exc.args[0] == 'NOT STORED':
                    errors.conflict(conflicts, "Class %r already has storage."
                                    % cls)
                else:
                    raise
    
    def has_storage(self, cls):
        """If storage structures exist for the given class, return True."""
        return True
    
    def drop_storage(self, cls, conflicts='error'):
        """Destroy internal structures for the given class.
        
        conflicts: see errors.conflict.
        """
        if self.logflags & logflags.DDL:
            self.log(logflags.DDL.message("drop storage %s" % cls))
        self.flush(cls)
    
    def add_property(self, cls, name, conflicts='error'):
        """Add internal structures for the given property.
        
        conflicts: see errors.conflict.
        """
        clsname = cls.__name__
        if self.logflags & logflags.DDL:
            self.log(logflags.DDL.message("add property %s %s" %
                                          (clsname, name)))
        
        indexset = self.indexsets[cls]
        if () in indexset:
            # TODO: recalculate if primary_keys changed
            ci = self.client.get(indexset.key({})) or []
            for id in ci:
                key = "%s:%s:%s" % (self.name, clsname, self.hash(id))
                unit = self.client.get(key)
                if unit is not None:
                    unit._properties[name] = None
                    unit.cleanse()
                    self.client.set(key, unit)
    
    def has_property(self, cls, name):
        """If storage structures exist for the given property, return True."""
        indexset = self.indexsets[cls]
        if () in indexset:
            clsname = cls.__name__
            ci = self.client.get(indexset.key({}))
            
            if not ci:
                # We don't have any items, so there's nothing to
                # declare as 'unprepared'.
                return True
            
            for id in ci:
                key = "%s:%s:%s" % (self.name, clsname, self.hash(id))
                unit = self.client.get(key)
                if unit is not None:
                    return name in unit._properties
        
        return True
    
    def drop_property(self, cls, name, conflicts='error'):
        """Destroy internal structures for the given property.
        
        conflicts: see errors.conflict.
        """
        clsname = cls.__name__
        if self.logflags & logflags.DDL:
            self.log(logflags.DDL.message("drop property %s %s" %
                                          (clsname, name)))
        
        indexset = self.indexsets[cls]
        if () in indexset:
            ci = self.client.get(indexset.key({})) or []
            for id in ci:
                key = "%s:%s:%s" % (self.name, clsname, self.hash(id))
                unit = self.client.get(key)
                if unit is not None:
                    del unit._properties[name]
                    unit.cleanse()
                    self.client.set(key, unit)
    
    def rename_property(self, cls, oldname, newname, conflicts='error'):
        """Rename internal structures for the given property.
        
        conflicts: see errors.conflict.
        """
        clsname = cls.__name__
        if self.logflags & logflags.DDL:
            self.log(logflags.DDL.message("rename property %s from %s to %s"
                                          % (cls, oldname, newname)))
        
        indexset = self.indexsets[cls]
        if () in indexset:
            ci = self.client.get(indexset.key({})) or []
            for id in ci:
                key = "%s:%s:%s" % (self.name, clsname, self.hash(id))
                unit = self.client.get(key)
                if unit is not None:
                    unit._properties[newname] = unit._properties[oldname]
                    del unit._properties[oldname]
                    unit.cleanse()
                    self.client.set(key, unit)
    
    
    #                   Extra methods for use as a cache                   #
    
    def cachelen(self, cls):
        indexset = self.indexsets[cls]
        if () in indexset:
            return len(self.client.get(indexset.key({})))
        else:
            return 0
    
    def cached_units(self, cls):
        units = []
        indexset = self.indexsets[cls]
        if () in indexset:
            for key in self.client.get(indexset.key({})):
                unit = self.client.get(key)
                if unit is not None:
                    unit.cleanse()
                    units.append(unit)
        return units
    
    def flush(self, cls):
        """Dump all objects of the given class."""
        clsname = cls.__name__
        
        indexset = self.indexsets[cls]
        if () in indexset:
            gi_key = indexset.key({})
            # Delete all units in the global index.
            for id in self.client.get(gi_key) or []:
                key = "%s:%s:%s" % (self.name, clsname, self.hash(id))
                self.client.delete(key)
            
            # Delete the global index.
            self.client.delete(gi_key)
        # TODO:
        # else:
        #     self.increment_generation(cls)
    
    def register(self, cls):
        """Assert that Units of class 'cls' will be handled."""
        # Set a default primary key for the class. Consumers are free to
        # change this if another unique property is looked up more often.
        self.primary_keys[cls] = tuple(cls.identifiers or cls.properties)
        
        # Add indices based on the .index attribute of each UnitProperty.
        self.indexsets[cls] = i = IndexSet(self, cls)
        for propname in cls.properties:
            prop = getattr(cls, propname)
            if prop.index:
                # There's usually no need for an index on the primary key;
                # we can just fetch each one directly by cache key.
                # Callers are free to add one in explicitly if needed,
                # for example, if wanting to retrive all units without
                # any filtering criteria.
                if propname not in cls.identifiers:
                    i.add_index(propname)
        
        # Add an index with no propnames. This is a special
        # sentinel value for the global index that keeps us DRY.
        if self.global_index:
            i.add_index()
        
        storage.StorageManager.register(self, cls)
    
    def _unit_by_primary_key(self, cls, keys, filters):
        """Return a unit (or None) by primary keys which matches the filters dict.
        
        The filters argument must contain an entry for each key in the
        given list of keys, although it may and often should contain
        additional entries.
        """
        ident = tuple([filters[k] for k in keys])
        key = "%s:%s:%s" % (self.name, cls.__name__, self.hash(ident))
        unit = self.client.get(key)
        if unit is not None:
            matching = True
            if set(filters.keys()) > set(keys):
                # We retrieved the Unit using a subset of the filters.
                # Filter in full now.
                for k, v in filters.iteritems():
                    if getattr(unit, k) != v:
                        matching = False
                        break
            
            if matching:
                if self.logflags & logflags.IO:
                    self.log(logflags.IO.message('PK HIT (%s) %s' % (key, filters)))
                unit.cleanse()
                return unit
        
        if self.logflags & logflags.IO:
            self.log(logflags.IO.message('PK MISS (%s) %s' % (key, filters)))
        return None
    
    def extract_filters(self, expr):
        """Return a dict of (key == value) pairs from the given expr.
        
        If the given Expression contains operators other than ==, or if a
        set of filters cannot be obtained for some other reason, returns {}.
        In theory, we should be able to ignore other operators but the
        simple regex we use isn't that smart; we'd have to do a full parse
        of the expr and then functionally decompose it.
        
        This function is only designed to work on Expressions for a single
        class (i.e. - no joins).
        """
        if expr.is_constant(True):
            return {}
        
        fc = expr.func.func_code
        if indexable_regex.match(fc.co_code):
            if sys.version_info >= (2, 5):
                # Python 2.5 stopped including args in co_names.
                compkeys = fc.co_names
            else:
                # The first co_names will be the positional args for the class.
                compkeys = fc.co_names[fc.co_argcount:]
            
            # "If a code object represents a function, the first item
            # in co_consts is the documentation string of the function,
            # or None if undefined."
            compvals = fc.co_consts[1:]
            
            return dict(zip(compkeys, compvals))
        
        return {}
    
    def scan(self, mainstore, cls, expr):
        """Return a list of units from a cached index (or None).
        
        The class and expression will be used to find a cached index;
        if not found, the mainstore will be used to create one, and it
        will be cached.
        
        Once an index has been obtained, it will be iterated over against
        the cache. Each unit in the index which is not available in the
        cache will be pulled from mainstore.
        
        If no index intersects with the given expression, None is returned.
        """
        filters = self.extract_filters(expr)
        indexset = self.indexsets[cls]
        keyattrs = self.primary_keys[cls]
        
        # Find the best index for the given filters.
        for index in indexset:
            if set(filters.keys()) >= set(index):
                break
        else:
            # Signal the caller that no index scan was performed.
            return None
        
        criteria = [(k, filters[k]) for k in index]
        ids = indexset.get(criteria)
        if ids is None:
            # Not in the cache. Grab the list of id-tuples from nextstore.
            # Note well: we're NOT grabbing view(.., filters), because
            # if filters > criteria, that would cache a subset of
            # the index leaf node.
            ids = mainstore.view((cls, keyattrs, dict(criteria)))
            # Then cache the list result for next time. Note that index
            # contents are unordered.
            indexset.put(criteria, ids, time=self.index_time)
        
        # Query the cache for multiple units (by id).
        units = indexset.scan(ids)
        
        # Now query the nextstore for any units that the cache missed...
        misses = [k for k in ids if k not in units]
        if self.index_stride:
            # ...in chunks of length: self.index_stride.
            for step in xrange(0, len(misses), self.index_stride):
                # TODO: allow for multiple identifiers
                misstep = zip(*misses[step:step + self.index_stride])[0]
                f = lambda x: getattr(x, keyattrs[0]) in misstep
                for unit in mainstore.recall(cls, f):
                    units[tuple([getattr(unit, a) for a in keyattrs])] = unit
                    try:
                        self.save(unit, forceSave=True)
                    except KeyError:
                        # The cache refused to save the unit (possibly full).
                        pass
        elif misses:
            # ...or all in one chunk if desired.
            # TODO: allow for multiple identifiers
            misstep = zip(*misses)[0]
            f = lambda x: getattr(x, keyattrs[0]) in misstep
            for unit in mainstore.recall(cls, f):
                units[tuple([getattr(unit, a) for a in keyattrs])] = unit
                try:
                    self.save(unit, forceSave=True)
                except KeyError:
                    # The cache refused to save the unit (possibly full).
                    pass
        
        indexset.filter(index, filters, units)
        
        return units.values()


class IndexSet(object):
    """A set of indices for a single class.
    
    Each index covers a tuple of unit attributes. The empty tuple is the
    'global index', whose single node lists every unit passed to add().
    
    Each leaf node of each index is stored in memcached under its own key;
    each value is a list of tuple([unit.k for k in primary_keys[cls]]).
    For example, given an index over ("age", ) and a single-attribute
    primary key, each distinct recall operation will produce its own
    index node:
        
        recall(Person, {age: 31}) -> ns:Person:index(age=31) = [(1132,), (663,)]
        recall(Person, {age: 25}) -> ns:Person:index(age=25) = [(12,), (34,), (22,)]
        recall(Person, {age: 64}) -> ns:Person:index(age=64) = [(7,), (17,), (27,)]
    
    Index nodes are updated eagerly on add() but lazily on discard. add()
    never removes a unit from the node for its old attribute values, so a
    node may list units that no longer match it. Readers must re-check
    each unit against its filters (see filter and unit).
    """
    
    def __init__(self, store, cls):
        self.store = store
        self.cls = cls
        self._key_template = '%s:%s:index(%%s)' % (store.name, cls.__name__)
        # Attribute tuples, ordered from most attributes to fewest.
        self._indices = []
    
    def add_index(self, *attributes):
        """Add an index over the given attributes (a no-op if it exists).
        
        Indices are kept ordered from most-specific (most attributes) to
        least, so iteration offers the narrowest usable index first.
        Indices of equal length keep their insertion order.
        """
        if attributes not in self._indices:
            # Insert after every index with at least as many attributes.
            pos = 0
            for existing in self._indices:
                if len(existing) < len(attributes):
                    break
                pos += 1
            self._indices.insert(pos, attributes)
    
    def __iter__(self):
        return iter(self._indices)
    
    def __contains__(self, attributes):
        """Return True if an index over exactly these attributes exists."""
        return attributes in self._indices
    
    def key(self, criteria):
        """Return the cache key for the index node for the given criteria.
        
        The given criteria must be an iterable of (key, value) tuples,
        and must only contain keys for an existing index. Values are
        rendered with str(), with spaces replaced by '+'.
        
        If criteria is an empty list, the 'global index' key is returned.
        """
        criteria = ["%s=%s" % (k, str(v).replace(" ", "+"))
                    for k, v in criteria]
        return self._key_template % ",".join(criteria)
    
    def get(self, criteria):
        """Return a cached list of unit ids which match the given criteria.
        
        The given criteria must be an iterable of (key, value) tuples,
        and must only contain keys for an existing index.
        
        If criteria is an empty list, the 'global index' node is returned.
        Returns None if the node is not in memcached.
        
        The ids returned will be a list of tuples of the form:
            tuple([getattr(unit, name) for name in primary_keys[cls]])
        """
        cache_key = self.key(criteria)
        ids = self.store.client.get(cache_key)
        if self.store.logflags & logflags.IO:
            if ids is None:
                idlen = None
            else:
                idlen = len(ids)
            self.store.log(logflags.IO.message("INDEX GET (%s) len %r" %
                                               (cache_key, idlen)))
        return ids
    
    def put(self, criteria, ids, time=0):
        """Cache a list of unit identifiers which match the given criteria.
        
        The given criteria must be an iterable of (key, value) tuples,
        and must only contain keys for an existing index.
        
        If criteria is an empty list, the 'global index' node is written.
        time is passed through to memcache's set() as the expiry.
        
        The ids provided MUST be a list of tuples of the form:
            tuple([getattr(unit, name) for name in primary_keys[cls]])
        """
        cache_key = self.key(criteria)
        if self.store.logflags & logflags.IO:
            self.store.log(logflags.IO.message("INDEX PUT (%s) len %r: %r" %
                                               (cache_key, len(ids), ids)))
        self.store.client.set(cache_key, ids, time=time)
    
    def scan(self, ids):
        """Return a dict of multiple units from the given list of ids.
        
        The ids provided MUST be a list of tuples of the form:
            tuple([getattr(unit, name) for name in primary_keys[cls]])
        
        The returned dict will not contain entries for any units which
        have expired from the cache. Callers may use this information
        to request missed units from another store.
        """
        clsname = self.cls.__name__
        units = {}
        if ids:
            keys = ["%s:%s:%s" % (self.store.name, clsname, self.store.hash(ident))
                    for ident in ids]
            data = self.store.client.get_multi(keys)
            
            # Transform the dict back to id keys instead of cache keys.
            for ident, cache_key in zip(ids, keys):
                unit = data.get(cache_key, None)
                if unit is not None:
                    unit.cleanse()
                    units[ident] = unit
        
        if self.store.logflags & logflags.IO:
            self.store.log(logflags.IO.message("INDEX SCAN %s (%r hits of %r)" %
                                               (clsname, len(units), len(ids or ()))))
        return units
    
    def _mismatches(self, unit, filters):
        """Return the keys of filters whose value differs from the unit's attribute."""
        return [key for key, value in filters.items()
                if getattr(unit, key) != value]
    
    def unit(self, index, filters):
        """Return a unit from the index which matches the filters dict (or None).
        
        The filters argument must contain an entry for each key in the given
        index, although it may and often should contain additional entries.
        """
        if set(filters.keys()) > set(index):
            for unit in self.xrecall(index, filters):
                return unit
        else:
            clsname = self.cls.__name__
            # If the filters and index keys are equal, it should be faster
            # to perform single gets against memcached, rather than the
            # get_multi calls that self.xrecall performs.
            criteria = [(k, filters[k]) for k in index]
            ids = self.get(criteria)
            if ids:
                for ident in ids:
                    cache_key = "%s:%s:%s" % (self.store.name, clsname,
                                              self.store.hash(ident))
                    unit = self.store.client.get(cache_key)
                    # Discards are lazy, so an index node may list a unit
                    # whose attributes have since changed; treat it as a miss.
                    if unit is None or self._mismatches(unit, filters):
                        if self.store.logflags & logflags.IO:
                            self.store.log(logflags.IO.message(
                                'INDEX MISS (%s) %s' % (cache_key, filters)))
                    else:
                        if self.store.logflags & logflags.IO:
                            self.store.log(logflags.IO.message(
                                'INDEX HIT (%s) %s' % (cache_key, filters)))
                        unit.cleanse()
                        return unit
            else:
                if self.store.logflags & logflags.IO:
                    self.store.log(logflags.IO.message(
                        'INDEX EMPTY (%s) %s' % (clsname, filters)))
        return None
    
    def xrecall(self, index, filters):
        """Yield units from the given index which match the filters dict.
        
        The filters argument must contain an entry for each key in the given
        index, although it may and often should contain additional entries.
        """
        criteria = [(k, filters[k]) for k in index]
        ids = self.get(criteria)
        if ids:
            units = self.scan(ids)
            self.filter(index, filters, units)
            for unit in units.values():
                unit.cleanse()
                yield unit
    
    def filter(self, index, filters, units):
        """Remove any units which don't match filters, and update the index.
        
        The filters argument must contain an entry for each key in the given
        index, although it may and often should contain additional entries.
        
        The 'units' arg must be a dict of (id, unit) pairs, and must be
        the complete set of units from an index node; that is, the result
        of an indexset.get() call for the same index. It is modified in
        place. If any removed unit mismatched on an attribute of the index
        itself, the index node is rewritten with the remaining ids. The
        rewrite uses store.index_time as the expiry.
        """
        ids = list(units)
        removals = False
        for ident, unit in list(units.items()):
            mismatched = self._mismatches(unit, filters)
            if not mismatched:
                continue
            del units[ident]
            # Remove any idents from the index node that no longer
            # satisfy the index criteria. This is how we update
            # index nodes--eager adds but late discards.
            for key in mismatched:
                if key in index:
                    removals = True
                    ids.remove(ident)
                    break
        if removals:
            criteria = [(k, filters[k]) for k in index]
            self.put(criteria, ids, time=self.store.index_time)
    
    def add(self, unit):
        """Add the given unit's ident to its node in every index."""
        ident = tuple([getattr(unit, name)
                       for name in self.store.primary_keys[self.cls]])
        for index in self._indices:
            criteria = [(k, getattr(unit, k)) for k in index]
            indexnode = self.get(criteria) or []
            if ident not in indexnode:
                indexnode.append(ident)
                self.put(criteria, indexnode)
    
    def discard(self, unit):
        """Discard the given unit's ident from its current node in every index."""
        ident = tuple([getattr(unit, name)
                       for name in self.store.primary_keys[self.cls]])
        for index in self._indices:
            criteria = [(k, getattr(unit, k)) for k in index]
            indexnode = self.get(criteria) or []
            if ident in indexnode:
                indexnode.remove(ident)
                self.put(criteria, indexnode)

