Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

Changeset 54

Show
Ignore:
Timestamp:
04/02/07 01:31:20
Author:
fumanchu
Message:

Deeper rabbit hole: New DatabaseType? replaces column.dbtype, imperfect_type and hints; they also hold more type metadata, plus the cast method and default pytype and default adapters. Only SQLServer has been updated to this. Hopefully all this will allow implicit conversions, too. New Adapters have push and pull methods which replace the old coerce in and out, so that Databases have one AdapterSet? now instead of 3 separate adapters. Also split sqlserver and msaccess from ado.py.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/geniusql/adapters.py

    r53 r54  
    1 """Adapters from Python to SQL (and back) for the geniusql package.""" 
    2  
    3 __all__ = [ 
    4     'AdapterFromDB', 'AdapterToSQL', 'TypeAdapter', 
    5     'getCoerceMethod', 'getCoerceName', 'maxfloat_digits', 'maxint_bytes', 
    6     'localtime_offset', 
    7 
     1"""Adapters from Python to SQL (and back) for the geniusql package. 
     2 
     3Adaptation is tricky because semantic adaptation and (server-specific) 
     4syntactic adaptation need to be taken care of for every value in both 
     5directions. For example, when we convert a datetime.date to SQL, we 
     6must both convert the Python value to a string (for example, of the 
     7form '2004-03-01') and apply server-specific formatting (for example, 
     8'#2004-03-01#' for Microsoft Access). 
     9 
     10This is an extremely thorny issue and really requires the user to manually 
     11form and apply custom adapters completely by hand. However, in the vast 
     12majority of cases, a reasonable set of default adapters can be generated 
     13by Geniusql. For example, a Column of pytype "datetime.date" can default 
     14to an SQL92DATE adapter (but an MSAccess_Date adapter if necessary), based 
     15entirely on the Python type. These default adapters should be parameterized 
     16to the hilt so that the user can tweak the default adapters easily. 
     17That is, the user should be able to write things like: 
     18 
     19    col = schema.column() 
     20    col.adapter.encoding = 'ASCII' 
     21 
     22...rather than passing all such settings as args to the column() call, 
     23or forcing the user to select an adapter subclass with the desired 
     24characteristics. 
     25 
     26For these reasons, entries in the DatabaseTypes.default_adapters dict 
     27are adapter classes, not instances. This allows each back end to customize 
     28its AdapterSet.default() method; for example, the SQL Server AdapterSet 
     29will set adapter.escapes to [("'", "''")]. This technique also avoids 
     30any problems that might arise from sharing an Adapter instance among 
     31multiple columns. 
     32""" 
    833 
    934import datetime 
     
    1237except ImportError: 
    1338    import pickle 
     39 
     40try: 
     41    set 
     42except NameError: 
     43    from sets import Set as set 
    1444 
    1545import sys 
     
    3161import warnings 
    3262 
    33  
    34 from geniusql import errors 
    35 from geniusql.typerefs import * 
     63from geniusql import errors, typerefs 
     64 
     65 
     66# ----------------------------- Adapter Sets ----------------------------- # 
     67 
     68 
     69def getCoerceName(pytype): 
     70    """Return the name of the coercion method for a given Python type.""" 
     71    mod = pytype.__module__ 
     72    if mod == "__builtin__": 
     73        xform = "%s" % pytype.__name__ 
     74    else: 
     75        xform = "%s_%s" % (mod, pytype.__name__) 
     76    xform = xform.replace(".", "_") 
     77    return xform 
     78 
     79 
     80class AdapterSet(object): 
     81    """Determine the best database type for a given column + Python type. 
     82     
     83    When Geniusql is asked to create database tables, it must choose an 
     84    appropriate column data type for each UnitProperty based on the 
     85    type (and hints) of that property. This class recommends such 
     86    database types by returning a new instance of DatabaseType. 
     87    """ 
     88     
     89    known_types = None 
     90     
     91    # You should REALLY check into your DB's encoding and override this. 
     92    encoding = 'utf8' 
     93     
     94    # Default escapes for string values. 
     95    escapes = [("'", "''"), ("\\", r"\\")] 
     96     
     97    def __copy__(self): 
     98        newset = self.__class__() 
     99        newset.update(self) 
     100        newset.encoding = self.encoding 
     101        newset.escapes = self.escapes 
     102        newset.known_types = self.known_types.copy() 
     103        return newset 
     104    copy = __copy__ 
     105     
     106    def canonicalize(self, dbtypename): 
     107        """Return the canonical DatabaseType for the given synonym. 
     108         
     109        In order to avoid large amounts of code (in each provider module!) 
     110        that merely looks up synonyms, database types MUST be 
     111        canonicalized for all Column and SQLExpression objects. 
     112        """ 
     113        for typeset in self.known_types.itervalues(): 
     114            for dbtype in typeset: 
     115                if (dbtypename == dbtype.__name__ 
     116                    or dbtypename in dbtype.synonyms): 
     117                    return dbtype 
     118        raise KeyError("No canonical name found for %r." % dbtypename) 
     119     
     120    def default(self, pytype, dbtype=None): 
     121        """Return a default adapter instance for the given pytype, dbtype.""" 
     122        if dbtype is None: 
     123            dbtype = self.database_type(pytype) 
     124         
     125        ptypes = [pytype, None] 
     126        for base in pytype.__bases__: 
     127            ptypes.append(base) 
     128         
     129        for p in ptypes: 
     130            if p in dbtype.default_adapters: 
     131                a = dbtype.default_adapters[p]() 
     132                if hasattr(a, "encoding"): 
     133                    a.encoding = self.encoding 
     134                if hasattr(a, "escapes"): 
     135                    a.escapes = self.escapes 
     136                return a 
     137         
     138        raise TypeError("%s has no default adapter for %s. Looked for: %s" % 
     139                        (dbtype, pytype, "\n".join([repr(x) for x in ptypes]))) 
     140     
     141     
     142    # ------------------------- Database types ------------------------- # 
     143     
     144    def database_type(self, pytype, hints=None): 
     145        """Recommend a DatabaseType for the given Python type. 
     146         
     147        hints: if provided, this should be a dict of property attributes 
     148            which can be used to distinguish between similar database types. 
     149            Canonical keys include 'bytes', 'precision', and 'scale'. 
     150        """ 
     151        xform = "dbtype_for_" + getCoerceName(pytype) 
     152        try: 
     153            xform = getattr(self, xform) 
     154        except AttributeError: 
     155            raise TypeError("%r is not handled by %s. Tried %r" % 
     156                            (pytype, self.__class__, xform)) 
     157        if hints is None: 
     158            hints = {} 
     159        return xform(hints) 
     160     
     161    def dbtype_for_float(self, hints): 
     162        """Return a DatabaseType for floats of the given binary precision.""" 
     163        # Note that 'precision' is binary digits, not decimal. 
     164        precision = int(hints.get('precision', maxfloat_digits)) 
     165        for dbtype in self.known_types['float']: 
     166            if precision <= dbtype.max_precision: 
     167                return dbtype(precision=precision) 
     168        return self.decimal_type(precision=precision) 
     169     
     170    def dbtype_for_str(self, hints): 
     171        # The bytes hint shall not reflect the usual 4-byte base for varchar. 
     172        bytes = int(hints.get('bytes', 255)) 
     173        for dbtype in self.known_types['varchar']: 
     174            if bytes <= dbtype.max_bytes: 
     175                print dbtype, bytes 
     176                return dbtype(bytes=bytes) 
     177        raise ValueError("%r is greater than the maximum bytes %r." 
     178                         % (bytes, dbtype.max_bytes)) 
     179     
     180    def dbtype_for_dict(self, hints): 
     181        return self.dbtype_for_str(hints) 
     182    def dbtype_for_list(self, hints): 
     183        return self.dbtype_for_str(hints) 
     184    def dbtype_for_tuple(self, hints): 
     185        return self.dbtype_for_str(hints) 
     186    def dbtype_for_unicode(self, hints): 
     187        return self.dbtype_for_str(hints) 
     188     
     189    def dbtype_for_geniusql_logic_Expression(self, hints): 
     190        return self.dbtype_for_str(hints) 
     191     
     192    def dbtype_for_bool(self, hints): 
     193        return self.known_types['bool'][0]() 
     194     
     195    def dbtype_for_datetime_datetime(self, hints): 
     196        return self.known_types['datetime'][0]() 
     197     
     198    def dbtype_for_datetime_date(self, hints): 
     199        return self.known_types['date'][0]() 
     200     
     201    def dbtype_for_datetime_time(self, hints): 
     202        return self.known_types['time'][0]() 
     203     
     204    def dbtype_for_datetime_timedelta(self, hints): 
     205        try: 
     206            # If your DB has an INTERVAL datatype, you should provide a 
     207            # native INTERVAL type. You'll also have to update the date 
     208            # arithmetic inside the decompiler and add a timedelta adapter. 
     209            return self.known_types['timedelta'][0]() 
     210        except (KeyError, IndexError): 
     211            # Fallback for DB's which do not have an INTERVAL data type. 
     212            # Use decimal instead of float to avoid rounding errors. 
     213            # Using precision of 12 should allow +/- 31688 years. 
     214            return self.decimal_type(12, 0) 
     215     
     216    def numeric_max_precision(self): 
     217        return max([0] + [t.max_precision for t in self.known_types['numeric']]) 
     218     
     219    def decimal_type(self, precision, scale): 
     220        if scale > precision: 
     221            scale = precision 
     222         
     223        for dbtype in self.known_types['numeric']: 
     224            if precision <= dbtype.max_precision: 
     225                return dbtype(precision=precision, scale=scale) 
     226         
     227        # Use a VARCHAR type (add 1 char for the decimal point and 1 for sign). 
     228        bytes = precision + 1 
     229        if scale: 
     230            bytes += 1 
     231        dbtype = self.dbtype_for_str({'bytes': bytes}) 
     232         
     233        errors.warn("The given precision (%s) is greater than the " 
     234                    "maximum numeric precision (%s). Using %s instead." 
     235                    % (precision, self.numeric_max_precision(), dbtype)) 
     236        return dbtype 
     237     
     238    if typerefs.decimal: 
     239        if hasattr(typerefs.decimal, "Decimal"): 
     240            def dbtype_for_decimal_Decimal(self, hints): 
     241                precision = int(hints.get('precision', 
     242                                          self.numeric_max_precision())) 
     243                # Assume most people use decimal for money; default scale = 2. 
     244                scale = int(hints.get('scale', 2)) 
     245                return self.decimal_type(precision, scale) 
     246        else: 
     247            def dbtype_for_decimal(self, hints): 
     248                precision = int(hints.get('precision', 
     249                                          self.numeric_max_precision())) 
     250                # Assume most people use decimal for money; default scale = 2. 
     251                scale = int(hints.get('scale', 2)) 
     252                return self.decimal_type(precision, scale) 
     253     
     254    if typerefs.fixedpoint: 
     255        def dbtype_for_fixedpoint_FixedPoint(self, hints): 
     256            # Note that fixedpoint has no theoretical precision limit. 
     257            precision = int(hints.get('precision', 
     258                                      self.numeric_max_precision())) 
     259            # Assume most people use fixedpoint for money; default scale = 2. 
     260            scale = int(hints.get('scale', 2)) 
     261            return self.decimal_type(precision, scale) 
     262     
     263    def dbtype_for_long(self, hints): 
     264        if 'bytes' in hints: 
     265            bytes = int(hints['bytes']) 
     266        else: 
     267            bytes = self.numeric_max_precision() / 2 
     268         
     269        for dbtype in self.known_types['int']: 
     270            if bytes <= dbtype.max_bytes: 
     271                return dbtype(bytes=bytes) 
     272        return self.decimal_type(precision=bytes * 2, scale=0) 
     273     
     274    def dbtype_for_int(self, hints): 
     275        bytes = int(hints.get('bytes', maxint_bytes)) 
     276        for dbtype in self.known_types['int']: 
     277            if bytes <= dbtype.max_bytes: 
     278                return dbtype(bytes=bytes) 
     279        return self.decimal_type(precision=bytes * 2, scale=0) 
     280 
     281 
     282# ------------------------------- Adapters ------------------------------- # 
    36283 
    37284 
     
    59306 
    60307 
    61 def getCoerceName(pytype): 
    62     """Return the name of the coercion method for a given Python type.""" 
    63     mod = pytype.__module__ 
    64     if mod == "__builtin__": 
    65         xform = "%s" % pytype.__name__ 
    66     else: 
    67         xform = "%s_%s" % (mod, pytype.__name__) 
    68     xform = xform.replace(".", "_") 
    69     return xform 
    70  
    71 def getCoerceMethod(adapter, totype, fromtype, prefix="coerce_"): 
    72     """Return the coercion method for a given 'from' and 'to' type. 
    73      
    74     Possible coercion methods are searched in the following order: 
    75       1. Exact match:    coerce       <fromtype> to <totype> 
    76       2. Exact totype:   coerce              any to <totype> 
    77       3. Exact fromtype: coerce       <fromtype> to any 
    78       4. totype.bases:   coerce       <fromtype> to <totype.base1> 
    79                          coerce              any to <totype.base1> 
    80                          coerce       <fromtype> to <totype.base2>... 
    81       5. fromtype.bases: coerce <fromtype.base1> to <totype> 
    82                          coerce <fromtype.base1> to any 
    83                          coerce <fromtype.base2> to <totype>... 
    84      
    85     If no matching coercion method is found, a TypeError is raised. 
    86     """ 
    87     if isinstance(fromtype, str): 
    88         frombases = () 
    89     else: 
    90         frombases = fromtype.__bases__ 
    91         fromtype = getCoerceName(fromtype) 
    92      
    93     if isinstance(totype, str): 
    94         tobases = () 
    95     else: 
    96         tobases = totype.__bases__ 
    97         totype = getCoerceName(totype) 
    98      
    99     methods = [] 
    100     if fromtype and totype: 
    101         methods.append(prefix + fromtype + "_to_" + totype) 
    102     if totype: 
    103         methods.append(prefix + "any_to_" + totype) 
    104     if fromtype: 
    105         methods.append(prefix + fromtype + "_to_any") 
    106      
    107     for meth in methods: 
    108         if hasattr(adapter, meth): 
    109             return getattr(adapter, meth) 
    110      
    111     for base in tobases: 
    112         base = getCoerceName(base) 
    113         if fromtype: 
    114             meth = prefix + fromtype + "_to_" + base 
    115             methods.append(meth) 
    116             if hasattr(adapter, meth): 
    117                 return getattr(adapter, meth) 
    118         meth = prefix + "any_to_" + base 
    119         methods.append(meth) 
    120         if hasattr(adapter, meth): 
    121             return getattr(adapter, meth) 
    122      
    123     for base in frombases: 
    124         base = getCoerceName(base) 
    125         if totype: 
    126             meth = prefix + base + "_to_" + totype 
    127             methods.append(meth) 
    128             if hasattr(adapter, meth): 
    129                 return getattr(adapter, meth) 
    130         meth = prefix + base + "_to_any" 
    131         methods.append(meth) 
    132         if hasattr(adapter, meth): 
    133             return getattr(adapter, meth) 
    134      
    135     raise TypeError("%s -> %s is not handled by %s.  Looked for: %s" % 
    136                     (fromtype, totype, adapter.__class__, ", ".join(methods))) 
    137  
    138  
    139 class AdapterToSQL(object): 
    140     """Coerce Expression constants to SQL. 
    141      
    142     This base class is designed to work out-of-the-box with PostgreSQL 8. 
    143     """ 
    144      
    145     # You should REALLY check into your DB's encoding and override this. 
    146     encoding = 'utf8' 
    147      
    148     # Notice these are ordered pairs. Escape \ before introducing new ones. 
    149     # Values in these two lists should be strings encoded with self.encoding. 
    150     escapes = [("'", "''"), ("\\", r"\\")] 
    151     like_escapes = [("%", r"\%"), ("_", r"\_")] 
    152      
    153     # These are not the same as coerce_bool (which is used on one side of  
    154     # a comparison). Instead, these are used when the whole (sub)expression 
    155     # is True or False, e.g. "WHERE TRUE", or "WHERE TRUE and 'a'.'b' = 3". 
    156     bool_true = "TRUE" 
    157     bool_false = "FALSE" 
    158      
    159     def __init__(self): 
    160         self._memoized_methods = {} 
    161      
    162     def escape_like(self, value): 
    163         """Prepare a string value for use in a LIKE comparison.""" 
    164         if not isinstance(value, str): 
    165             value = value.encode(self.encoding) 
    166         # Notice we strip leading and trailing quote-marks. 
    167         value = value.strip("'\"") 
    168         for pat, repl in self.like_escapes: 
    169             value = value.replace(pat, repl) 
    170         return value 
    171      
    172     def coerce(self, value, dbtype="", pytype=None): 
    173         """Return value, coerced from (optional pytype) to dbtype.""" 
    174         if pytype is None: 
    175             pytype = type(value) 
    176         if "(" in dbtype: 
    177             dbtype = dbtype[:dbtype.find("(")] 
    178          
    179         key = (dbtype, pytype) 
    180         try: 
    181             meth = self._memoized_methods[key] 
    182         except KeyError: 
    183             meth = getCoerceMethod(self, dbtype, pytype) 
    184             self._memoized_methods[key] = meth 
    185          
    186         return meth(value) 
    187      
    188     def coerce_NoneType_to_any(self, value): 
    189         return "NULL" 
    190      
    191     def coerce_bool_to_any(self, value): 
     308class Adapter(object): 
     309     
     310    def push(self, value): 
     311        """Coerce the given Python value to SQL.""" 
     312        raise NotImplementedError 
     313     
     314    def pull(self, value): 
     315        """Coerce the given database value to a Python value.""" 
     316        raise NotImplementedError 
     317     
     318    def cast(self, sql, fromtype, totype): 
     319        """Cast the given SQL expression from one database type to another. 
     320         
     321        If the given types are equal (or synonyms), this should generally 
     322        return the original SQL. This may also return the original SQL if 
     323        the database implicitly converts one or both types before comparing 
     324        them. 
     325        """ 
     326        if fromtype != totype: 
     327            raise TypeError("Could not cast %r from %r to %r." 
     328                            % (sql, fromtype, totype)) 
     329        return sql 
     330     
     331    def binary_op(self, op1, op, op2): 
     332        """Return the SQL for the given binary operation.""" 
     333        return "%s %s %s" % (op1.sql, op, op2.sql) 
     334 
     335 
     336class BOOLEAN(Adapter): 
     337     
     338    def push(self, value): 
    192339        if value: 
    193340            return 'TRUE' 
    194341        return 'FALSE' 
    195342     
    196     # The great thing about these 3 date coercers is that you can use 
    197     # them with (VAR)CHAR columns just as well as with DATETIME, etc. 
    198     # and comparisons will still work! 
    199     def coerce_datetime_datetime_to_any(self, value): 
     343    pull = bool 
     344 
     345# SQL92 types: INTEGER (INT), SMALLINT, NUMERIC, DECIMAL, REAL, 
     346#   DOUBLE PRECISION (DOUBLE), BIT, BIT VARYING, DATE, TIME, TIMESTAMP, 
     347#   CHARACTER (CHAR), CHARACTER VARYING (VARCHAR), INTERVAL 
     348 
     349class SQL92BIT(Adapter): 
     350    def push(self, value): 
     351        if value: 
     352            return '1' 
     353        return '0' 
     354     
     355    def pull(self, value): 
     356        return bool(int(value)) 
     357 
     358 
     359# The great thing about these 3 date coercers is that you can use 
     360# them with (VAR)CHAR columns just as well as with DATETIME, etc. 
     361# and comparisons will still work! 
     362class SQL92TIMESTAMP(Adapter): 
     363     
     364    def push(self, value): 
    200365        return ("'%04d-%02d-%02d %02d:%02d:%02d'" % 
    201366                (value.year, value.month, value.day, 
    202367                 value.hour, value.minute, value.second)) 
    203368     
    204     def coerce_datetime_date_to_any(self, value): 
     369    def pull(self, value): 
     370        if isinstance(value, datetime.datetime): 
     371            return value 
     372        chunks = (value[0:4], value[5:7], value[8:10], 
     373                  value[11:13], value[14:16], value[17:19], 
     374                  value[20:26] or 0) 
     375        return datetime.datetime(*map(int, chunks)) 
     376 
     377 
     378class SQL92DATE(Adapter): 
     379     
     380    def push(self, value): 
    205381        return "'%04d-%02d-%02d'" % (value.year, value.month, value.day) 
    206382     
    207     def coerce_datetime_time_to_any(self, value): 
     383    def pull(self, value): 
     384        # These are in order for a reason: datetime is a subclass of date! 
     385        if isinstance(value, datetime.datetime): 
     386            # Psycopg might do this when adding date + timedelta, for instance. 
     387            return value.date() 
     388        elif isinstance(value, datetime.date): 
     389            return value 
     390         
     391        chunks = (value[0:4], value[5:7], value[8:10]) 
     392        return datetime.date(*map(int, chunks)) 
     393 
     394 
     395class SQL92TIME(Adapter): 
     396     
     397    def push(self, value): 
    208398        return "'%02d:%02d:%02d'" % (value.hour, value.minute, value.second) 
    209399     
    210     def coerce_datetime_timedelta_to_any(self, value): 
     400    def pull(self, value): 
     401        if isinstance(value, datetime.time): 
     402            return value 
     403        chunks = (value[0:2], value[3:5], value[6:8]) 
     404        return datetime.time(*map(int, chunks)) 
     405 
     406 
     407class INTERVAL(Adapter): 
     408    """Adapter for storing datetime.timedelta values in whole seconds. 
     409     
     410    SQL-92 defines an INTERVAL type, but few commercial databases 
     411    implement it in a reasonable manner. This adapter stores the 
     412    value (days * 86400) + seconds in a NUMERIC field instead, 
     413    which should work with most databases. Note that a custom 
     414    binary_op method MUST be written for each DB which subclasses 
     415    this adapter; there is no default because each RDBMS implements 
     416    date (and especially date interval) arithmetic in its own way. 
     417     
     418    This adapter uses whole seconds only to avoid problems many 
     419    databases exhibit when comparing two FLOATs for equality in SQL. 
     420    """ 
     421     
     422    def push(self, value): 
    211423        dec_val = (value.days * 86400) + value.seconds 
    212424        return repr(dec_val) 
    213425     
    214     coerce_decimal_to_any = str 
    215     coerce_decimal_Decimal_to_any = str 
    216      
    217     def do_pickle(self, value): 
     426    def pull(self, value): 
     427        days, seconds = divmod(long(value), 86400) 
     428        return datetime.timedelta(int(days), int(seconds)) 
     429 
     430 
     431class SQL92REAL(Adapter): 
     432    # Very important we use repr here so we get all 17 decimal digits. 
     433    push = repr 
     434    pull = float 
     435 
     436class SQL92DOUBLE(Adapter): 
     437    # Very important we use repr here so we get all 17 decimal digits. 
     438    push = repr 
     439    pull = float 
     440 
     441class SQL92SMALLINT(Adapter): 
     442    push = str 
     443    # SQL-92 SMALLINT should be 2 bytes 
     444    if maxint_bytes >= 2: 
     445        pull = int 
     446    else: 
     447        pull = long 
     448 
     449class SQL92INTEGER(Adapter): 
     450    push = str 
     451    # SQL-92 INTEGER should be 4 bytes 
     452    if maxint_bytes >= 4: 
     453        pull = int 
     454    else: 
     455        pull = long 
     456 
     457 
     458class SQL92VARCHAR(Adapter): 
     459     
     460    encoding = 'utf8' 
     461    escapes = [("'", "''"), ("\\", r"\\")] 
     462     
     463    def push(self, value): 
     464        if not isinstance(value, str): 
     465            value = value.encode(self.encoding) 
     466        for pat, repl in self.escapes: 
     467            value = value.replace(pat, repl) 
     468        return "'" + value + "'" 
     469     
     470    def pull(self, value): 
     471        if isinstance(value, unicode): 
     472            return value.encode(self.encoding) 
     473        else: 
     474            return str(value) 
     475 
     476 
     477class UNICODE(SQL92VARCHAR): 
     478     
     479    def pull(self, value): 
     480        if isinstance(value, unicode): 
     481            return value 
     482        if isinstance(value, (basestring, buffer)): 
     483            return unicode(value, self.encoding) 
     484        return unicode(value) 
     485 
     486 
     487class Pickler(SQL92VARCHAR): 
     488     
     489    def push(self, value): 
    218490        # dumps with protocol 0 uses the 'raw-unicode-escape' encoding, 
    219491        # and we take pains not to re-encode it with self.encoding. 
     
    221493        # that introduces null bytes into the SQL, which is a no-no. 
    222494        value = pickle.dumps(value) 
    223         value = self.coerce_str_to_any(value, skip_encoding=True) 
    224         return value 
    225      
    226     coerce_dict_to_any = do_pickle 
    227      
    228     coerce_fixedpoint_FixedPoint_to_any = str 
    229      
    230     # Very important we use repr here so we get all 17 decimal digits. 
    231     coerce_float_to_any = repr 
    232     coerce_int_to_any = str 
    233     coerce_list_to_any = do_pickle 
    234     coerce_geniusql_logic_Expression_to_any = do_pickle 
    235     coerce_long_to_any = str 
    236      
    237     def coerce_str_to_any(self, value, skip_encoding=False): 
    238         if not skip_encoding and not isinstance(value, str): 
    239             value = value.encode(self.encoding) 
    240495        for pat, repl in self.escapes: 
    241496            value = value.replace(pat, repl) 
    242497        return "'" + value + "'" 
    243498     
    244     coerce_tuple_to_any = do_pickle 
    245      
    246     coerce_unicode_to_any = coerce_str_to_any 
    247      
    248     def cast(self, colref, dbtype, pytype): 
    249         """Return the column reference, cast from dbtype to pytype.""" 
    250         if "(" in dbtype: 
    251             dbtype = dbtype[:dbtype.find("(")] 
    252          
    253         meth = getCoerceMethod(self, pytype, dbtype, "cast_") 
    254         return meth(colref) 
    255      
    256     def _to_TEXT(self, value): 
    257         return "'%s'" % str(value) 
    258      
    259     def add_pickled_type(self, pytype): 
    260         name = "coerce_%s_to_any" % getCoerceName(pytype) 
    261         setattr(self, name, self.do_pickle) 
    262  
    263  
    264 for fromtype in ('decimal', 'decimal_Decimal', 'fixedpoint_FixedPoint', 
    265                  'float', 'int', 'long'): 
    266     setattr(AdapterToSQL, 'coerce_%s_to_TEXT' % fromtype, AdapterToSQL._to_TEXT) 
    267     setattr(AdapterToSQL, 'coerce_%s_to_VARCHAR' % fromtype, AdapterToSQL._to_TEXT) 
    268  
    269  
    270 class AdapterFromDB(object): 
    271     """Coerce incoming values from DB types to Python datatypes. 
    272      
    273     This base class is designed to work out-of-the-box with PostgreSQL 8. 
    274     """ 
    275      
    276     # You should REALLY check into your DB's encoding and override this. 
    277     encoding = 'utf8' 
    278      
    279     def __init__(self): 
    280         self._memoized_methods = {} 
    281      
    282     def coerce(self, value, dbtype, pytype): 
    283         """Return value, coerced from dbtype to pytype.""" 
    284         # All columns could conceivably hold NULL => Python None 
    285         if value is None: 
    286             return None 
    287          
    288         if "(" in dbtype: 
    289             dbtype = dbtype[:dbtype.find("(")] 
    290          
    291         key = (pytype, dbtype) 
    292         try: 
    293             meth = self._memoized_methods[key] 
    294         except KeyError: 
    295             meth = getCoerceMethod(self, pytype, dbtype) 
    296             self._memoized_methods[key] = meth 
    297          
    298         return meth(value) 
    299      
    300     def do_pickle(self, value): 
     499    def pull(self, value): 
    301500        # Coerce to str for pickle.loads restriction. 
    302         value = self.coerce_any_to_str(value) 
     501        if isinstance(value, unicode): 
     502            value = value.encode(self.encoding) 
     503        else: 
     504            value = str(value) 
    303505        return pickle.loads(value) 
    304      
    305     coerce_any_to_bool = bool 
    306      
    307     def coerce_any_to_datetime_datetime(self, value): 
    308         chunks = (value[0:4], value[5:7], value[8:10], 
    309                   value[11:13], value[14:16], value[17:19], 
    310                   value[20:26] or 0) 
    311         return datetime.datetime(*map(int, chunks)) 
    312      
    313     def coerce_any_to_datetime_date(self, value): 
    314         chunks = (value[0:4], value[5:7], value[8:10]) 
    315         return datetime.date(*map(int, chunks)) 
    316      
    317     def coerce_any_to_datetime_time(self, value): 
    318         chunks = (value[0:2], value[3:5], value[6:8]) 
    319         return datetime.time(*map(int, chunks)) 
    320      
    321     def coerce_any_to_datetime_timedelta(self, value): 
    322         days, seconds = divmod(long(value), 86400) 
    323         return datetime.timedelta(int(days), int(seconds)) 
    324      
    325     def coerce_any_to_decimal(self, value): 
    326         return decimal(str(value)) 
    327      
    328     def coerce_any_to_decimal_Decimal(self, value): 
    329         return decimal.Decimal(str(value)) 
    330      
    331     coerce_any_to_dict = do_pickle 
    332      
    333     def coerce_any_to_fixedpoint_FixedPoint(self, value): 
    334         if (isinstance(value, basestring) or 
    335             decimal and isinstance(value, decimal.Decimal)): 
    336             # Unicode really screws up fixedpoint; for example: 
    337             # >>> fixedpoint.FixedPoint(u'111111111111111111111111111.1') 
    338             # FixedPoint('111111111111111104952008704.00', 2) 
    339             value = str(value) 
    340              
    341             scale = 0 
    342             atoms = value.rsplit(".", 1) 
    343             if len(atoms) > 1: 
    344                 scale = len(atoms[-1]) 
    345             return fixedpoint.FixedPoint(value, scale) 
    346         else: 
    347             return fixedpoint.FixedPoint(value) 
    348      
    349     coerce_any_to_float = float 
    350     coerce_any_to_int = int 
    351     coerce_any_to_list = do_pickle 
    352     coerce_any_to_geniusql_logic_Expression = do_pickle 
    353     coerce_any_to_long = long 
    354      
    355     def coerce_any_to_str(self, value): 
    356         if isinstance(value, unicode): 
    357             return value.encode(self.encoding) 
    358         else: 
    359             return str(value) 
    360      
    361     coerce_any_to_tuple = do_pickle 
    362      
    363     def coerce_any_to_unicode(self, value): 
    364         if isinstance(value, unicode): 
    365             return value 
    366         else: 
    367             return unicode(value, self.encoding) 
    368      
    369     def add_pickled_type(self, pytype): 
    370         name = "coerce_any_to_%s" % getCoerceName(pytype) 
    371         setattr(self, name, self.do_pickle) 
    372  
    373  
    374 class TypeAdapter(object): 
    375     """Determine the best database type for a given column + Python type. 
    376      
    377     When Geniusql is asked to create database tables, it must choose an 
    378     appropriate column data type for each UnitProperty based on the 
    379     type (and hints) of that property. This class recommends such 
    380     database types, usually by returning the type name as a string 
    381     (so it can be inserted into SQL statements). 
    382      
    383     This base class is designed to work out-of-the-box with PostgreSQL 8. 
    384     """ 
    385      
    386     # Max binary precision for floating-point columns (= 53 for PostgreSQL 8). 
    387     # Python floats are implemented using C doubles; actual precision 
    388     # depends on platform (but is usually 53 binary digits, see maxfloat_digits). 
    389     # PostgreSQL DOUBLE is 53 binary-digit precision. 
    390     float_max_precision = 53 
    391      
    392     # Max decimal precision for NUMERIC columns (= 1000 for PostgreSQL 8). 
    393     numeric_max_precision = 1000 
    394      
    395     # "The actual storage requirement is two bytes for each group of four 
    396     # decimal digits, plus eight bytes overhead." Note we omit the overhead. 
    397     numeric_max_bytes = 500 
    398      
    399     # This type name will be returned when falling back to a character type 
    400     # from a numeric type which cannot support the desired precision. 
    401     # TEXT is not an SQL standard, but it's common. 
    402     numeric_text_type = "TEXT" 
    403      
    404     _reverse_types = { 
    405         "DATE": datetime.date, 
    406         "DATETIME": datetime.datetime, 
    407         "TIMESTAMP": datetime.datetime, 
    408         "TIME": datetime.time, 
    409          
    410         "INT": int, 
    411         "INTEGER": int, 
    412         "SMALLINT": int, 
    413          
    414         "BOOL": bool, 
    415         "BOOLEAN": bool, 
    416          
    417         "BIGINT": long, 
    418         "LONG": long, 
    419          
    420         "DOUBLE": float, 
    421         "DOUBLE PRECISION": float, 
    422         "FLOAT": float, 
    423         "REAL": float, 
    424         "SINGLE": float, 
    425          
    426         "CHAR": str, 
    427         "BLOB": str, 
    428         "TEXT": str, 
    429         "VARCHAR": str, 
    430         } 
    431      
    432     if decimal: 
    433         _reverse_types["DECIMAL"] = decimal.Decimal 
    434         _reverse_types["NUMERIC"] = decimal.Decimal 
    435     elif fixedpoint: 
    436         _reverse_types["DECIMAL"] = fixedpoint.FixedPoint 
    437         _reverse_types["NUMERIC"] = fixedpoint.FixedPoint 
    438      
    439     def __init__(self): 
    440         # Make a copy of the class-level dict 
    441         self._reverse_types = self._reverse_types.copy() 
    442      
    443     def python_type(self, dbtype): 
    444         """Return a Python type which can store values of the given dbtype.""" 
    445         # Strip any size argument (e.g. "VARCHAR(255)"). 
    446         key = dbtype.upper().split("(", 1)[0] 
    447         try: 
    448             return self._reverse_types[key] 
    449         except KeyError: 
    450             raise TypeError("Database type %r could not be converted " 
    451                             "to a Python type." % dbtype) 
    452      
    453     def related(self, pytype1, pytype2): 
    454         """If values of both types are expressed with the same SQL, return True.""" 
    455         if issubclass(pytype1, pytype2) or issubclass(pytype2, pytype1): 
    456             return True 
    457         if issubclass(pytype1, basestring) and issubclass(pytype2, basestring): 
    458             return True 
    459         if ((issubclass(pytype1, int) or issubclass(pytype1, long)) and 
    460             (issubclass(pytype2, int) or issubclass(pytype2, long))): 
    461             return True 
    462         if fixedpoint: 
    463             if decimal: 
    464                 if ((issubclass(pytype1, fixedpoint.FixedPoint) 
    465                      or issubclass(pytype1, decimal.Decimal)) and 
    466                     (issubclass(pytype2, fixedpoint.FixedPoint) 
    467                      or issubclass(pytype2, decimal.Decimal))): 
    468                     return True 
     506 
     507 
     508def make_numeric_text_adapter(pytype): 
     509    """Return a new Adapter class (between the given pytype and TEXT).""" 
     510    class NumericTextAdapter(Adapter): 
     511        pull = pytype 
     512        def push(self, value): 
     513            return "'%s'" % str(value) 
     514    return NumericTextAdapter 
     515 
     516 
     517if typerefs.decimal: 
     518    if hasattr(typerefs.decimal, "Decimal"): 
     519        class DECIMAL(Adapter): 
     520            push = str 
     521            def pull(self, value): 
     522                # pywin32 build 205 began support for returning 
     523                # COM Currency objects as decimal objects. 
     524                # See http://pywin32.cvs.sourceforge.net/pywin32/pywin32/CHANGES.txt?view=markup 
     525                if not isinstance(value, typerefs.decimal.Decimal): 
     526                    return typerefs.decimal.Decimal(str(value)) 
     527                return value 
     528    else: 
     529        class DECIMAL(Adapter): 
     530            push = str 
     531            def pull(self, value): 
     532                if not isinstance(value, typerefs.decimal): 
     533                    return typerefs.decimal(str(value)) 
     534                return value 
     535 
     536if typerefs.fixedpoint: 
     537    class FIXEDPOINT(Adapter): 
     538        push = str 
     539        def pull(self, value): 
     540            if (isinstance(value, basestring) or 
     541                (typerefs.decimal and 
     542                 isinstance(value, typerefs.decimal.Decimal))): 
     543                # Unicode really screws up fixedpoint; for example: 
     544                # >>> fixedpoint.FixedPoint(u'111111111111111111111111111.1') 
     545                # FixedPoint('111111111111111104952008704.00', 2) 
     546                value = str(value) 
     547                 
     548                scale = 0 
     549                atoms = value.rsplit(".", 1) 
     550                if len(atoms) > 1: 
     551                    scale = len(atoms[-1]) 
     552                return typerefs.fixedpoint.FixedPoint(value, scale) 
    469553            else: 
    470                 if (issubclass(pytype1, fixedpoint.FixedPoint) and 
    471                     issubclass(pytype2, fixedpoint.FixedPoint)): 
    472                     return True 
    473         else: 
    474             if decimal: 
    475                 if (issubclass(pytype1, decimal.Decimal) and 
    476                     issubclass(pytype2, decimal.Decimal)): 
    477                     return True 
    478         return False 
    479      
    480     def coerce(self, pytype, hints=None): 
    481         """Return a database type for the given Python type. 
    482          
    483         hints: if provided, this should be a dict of property attributes 
    484             which can be used to distinguish between similar database types. 
    485             Canonical keys include 'bytes', 'precision', and 'scale'. 
    486         """ 
    487         xform = "coerce_" + getCoerceName(pytype) 
    488         try: 
    489             xform = getattr(self, xform) 
    490         except AttributeError: 
    491             raise TypeError("'%s' is not handled by %s." % 
    492                             (pytype, self.__class__)) 
    493         if hints is None: 
    494             hints = {} 
    495         return xform(hints) 
    496      
    497     def float_type(self, precision): 
    498         """Return a datatype which can handle floats of the given binary precision.""" 
    499         if precision <= 24: 
    500             return "REAL" 
    501         else: 
    502             return "DOUBLE PRECISION" 
    503      
    504     def coerce_float(self, hints): 
    505         # Note that 'precision' is binary digits, not decimal. 
    506         precision = int(hints.get('precision', maxfloat_digits)) 
    507         if precision > self.float_max_precision: 
    508             return self.numeric_text_type 
    509         return self.float_type(precision) 
    510      
    511     def coerce_str(self, hints): 
    512         # The bytes hint shall not reflect the usual 4-byte base for varchar. 
    513         bytes = int(hints.get('bytes', 255)) 
    514         if bytes and bytes <= 255: 
    515             return "VARCHAR(%s)" % bytes 
    516         return "TEXT" 
    517      
    518     def coerce_dict(self, hints): 
    519         return self.coerce_str(hints) 
    520     def coerce_list(self, hints): 
    521         return self.coerce_str(hints) 
    522     def coerce_tuple(self, hints): 
    523         return self.coerce_str(hints) 
    524     def coerce_unicode(self, hints): 
    525         return self.coerce_str(hints) 
    526      
    527     def coerce_geniusql_logic_Expression(self, hints): 
    528         return self.coerce_str(hints) 
    529      
    530     def coerce_bool(self, hints): return "BOOLEAN" 
    531      
    532     def coerce_datetime_datetime(self, hints): return "TIMESTAMP" 
    533     def coerce_datetime_date(self, hints): return "DATE" 
    534     def coerce_datetime_time(self, hints): return "TIME" 
    535      
    536     # Use decimal instead of float to avoid rounding errors. 
    537     def coerce_datetime_timedelta(self, hints): 
    538         # This is a fallback for DB's which do not have an INTERVAL data type. 
    539         # If your DB has an INTERVAL datatype, you should override this and 
    540         # use the native INTERVAL type instead. You'll usually also have to 
    541         # update the date arithmetic inside the decompiler. 
    542         return self.int_type(self.numeric_max_bytes) 
    543      
    544     def decimal_type(self, precision, scale): 
    545         if precision > self.numeric_max_precision: 
    546             errors.warn("The given precision (%s) is greater than the " 
    547                         "maximum numeric precision (%s). Using %s instead." 
    548                         % (precision, self.numeric_max_precision, 
    549                            self.numeric_text_type)) 
    550             return self.numeric_text_type 
    551         if scale > precision: 
    552             scale = precision 
    553         return "NUMERIC(%s, %s)" % (precision, scale) 
    554      
    555     def coerce_decimal_Decimal(self, hints): 
    556         precision = int(hints.get('precision', self.numeric_max_precision)) 
    557         # Assume most people use decimal for money; default scale = 2. 
    558         scale = int(hints.get('scale', 2)) 
    559         return self.decimal_type(precision, scale) 
    560      
    561     def coerce_decimal(self, hints): 
    562         # If decimal ever becomes a builtin. Python 2.5? 
    563         return self.coerce_decimal_Decimal(hints) 
    564      
    565     def coerce_fixedpoint_FixedPoint(self, hints): 
    566         # Note that fixedpoint has no theoretical precision limit. 
    567         precision = int(hints.get('precision', self.numeric_max_precision)) 
    568         # Assume most people use fixedpoint for money; default scale = 2. 
    569         scale = int(hints.get('scale', 2)) 
    570         return self.decimal_type(precision, scale) 
    571      
    572     def int_type(self, bytes): 
    573         """Return a datatype which can handle the given number of bytes.""" 
    574         if bytes <= 2: 
    575             return "SMALLINT" 
    576         elif bytes <= 4: 
    577             return "INTEGER" 
    578         elif bytes <= 8: 
    579             # BIGINT is usually 8 bytes 
    580             return "BIGINT" 
    581         else: 
    582             # Anything larger than 8 bytes, use decimal/numeric. 
    583             # For PostgreSQL, "The actual storage requirement is two bytes 
    584             # for each group of four decimal digits, plus eight bytes 
    585             # overhead." Note we omit the overhead in our calculation. 
    586             return "NUMERIC(%s, 0)" % (bytes * 2) 
    587      
    588     def coerce_long(self, hints): 
    589         bytes = int(hints.get('bytes', self.numeric_max_bytes)) 
    590         if bytes > self.numeric_max_bytes: 
    591             return self.numeric_text_type 
    592         return self.int_type(bytes) 
    593      
    594     def coerce_int(self, hints): 
    595         bytes = int(hints.get('bytes', maxint_bytes)) 
    596         if bytes > maxint_bytes: 
    597             return self.coerce_long(hints) 
    598         return self.int_type(bytes) 
    599      
    600     def add_pickled_type(self, pytype): 
    601         name = "coerce_%s" % getCoerceName(pytype) 
    602         setattr(self, name, self.coerce_str) 
    603  
     554                return typerefs.fixedpoint.FixedPoint(value) 
     555 
  • trunk/geniusql/decompile.py

    r53 r54  
    2828        self.dbtype = dbtype 
    2929        self.pytype = pytype 
    30         self.imperfect_type = Fals
     30        self.adapter = Non
    3131         
    3232        self.value = value 
     
    4040     
    4141    def __repr__(self): 
    42         return "%s.%s(%r)" % (self.__module__, self.__class__.__name__, 
    43                               self.sql) 
     42        return ("%s.%s(%r, dbtype=%s)" % 
     43                (self.__module__, self.__class__.__name__, self.sql, 
     44                 self.dbtype.__class__.__name__)) 
    4445 
    4546 
     
    7879                 ) 
    7980     
     81    # SQL comparison operators (matching the order of opcode.cmp_op). 
    8082    sql_cmp_op = ('<', '<=', '=', '!=', '>', '>=', 'in', 'not in') 
    8183     
    82     def __init__(self, tables, expr, adapter, typeadapter): 
     84    # These are not adapter.push(bool) (which are used on one side of  
     85    # a comparison). Instead, these are used when the whole (sub)expression 
     86    # is True or False, e.g. "WHERE TRUE", or "WHERE TRUE and 'a'.'b' = 3". 
     87    bool_true = "TRUE" 
     88    bool_false = "FALSE" 
     89     
     90    def __init__(self, tables, expr, adapterset): 
    8391        self.tables = tables 
    8492        self.expr = expr 
    85         self.adapter = adapter 
    86         self.typeadapter = typeadapter 
     93        self.adapterset = adapterset 
    8794         
    8895        self.groups = [] 
    8996         
    9097        # Cache coerced booleans 
    91         self.true_expr = self.const(True, self.adapter.bool_true) 
    92         self.false_expr = self.const(False, self.adapter.bool_false) 
     98        self.true_expr = self.const(True, self.bool_true) 
     99        self.false_expr = self.const(False, self.bool_false) 
     100        self.T = self.const(True, self.true_expr.adapter.push(True)) 
     101        self.F = self.const(False, self.false_expr.adapter.push(False)) 
    93102        self.none_expr = SQLExpression("NULL", "expr0", None, type(None)) 
    94         self.T = self.const(True, adapter.coerce_bool_to_any(True)) 
    95         self.F = self.const(False, adapter.coerce_bool_to_any(False)) 
    96103         
    97104        codewalk.LambdaDecompiler.__init__(self, expr.func) 
     
    99106    exprcount = 0 
    100107     
    101     def get_expr(self, sql, pytype): 
     108    def get_expr(self, sql, pytype, adapter=None): 
    102109        """Return an SQLExpression for the given sql of the given pytype.""" 
    103         typer = self.typeadapter 
    104         dbtype = typer.coerce(pytype) 
     110        typer = self.adapterset 
     111        dbtype = typer.database_type(pytype) 
    105112         
    106113        self.exprcount += 1 
    107114        name = "expr%s" % self.exprcount 
    108115        e = SQLExpression(sql, name, dbtype, pytype) 
    109          
    110         reverse_type = typer.python_type(dbtype) 
    111         e.imperfect_type = not typer.related(pytype, reverse_type) 
     116        e.adapter = adapter or typer.default(pytype, dbtype) 
    112117         
    113118        return e 
     
    117122        if value is None: 
    118123            return self.none_expr 
    119         pytype = type(value) 
    120         dbtype = self.typeadapter.coerce(pytype) 
     124         
     125        e = self.get_expr(sql, type(value)) 
     126        e.value = value 
    121127        if sql is None: 
    122             sql = self.adapter.coerce(value, dbtype, pytype) 
    123          
    124         e = self.get_expr(sql, pytype) 
    125         e.value = value 
     128            e.sql = e.adapter.push(value) 
    126129        return e 
     130     
     131    def append_expr(self, sql, pytype): 
     132        """Syntactic sugar for self.stack.append(self.get_expr(sql, pytype)).""" 
     133        self.stack.append(self.get_expr(sql, pytype)) 
    127134     
    128135    def code(self): 
     
    228235            atom = SQLExpression('%s.%s' % (alias, col.qname), 
    229236                                 name, col.dbtype, col.pytype) 
    230             atom.imperfect_type = col.imperfect_type 
     237            atom.adapter = col.adapter 
    231238        else: 
    232239            # 'tos.name' will reference an attribute of the tos object. 
     
    312319        elif op1.sql == 'NULL': 
    313320            if op in (2, 8):    # '==', is 
    314                 self.stack.append(self.get_expr(op2.sql + " IS NULL", bool)
     321                self.append_expr(op2.sql + " IS NULL", bool
    315322            elif op in (3, 9):  # '!=', 'is not' 
    316                 self.stack.append(self.get_expr(op2.sql + " IS NOT NULL", bool)
     323                self.append_expr(op2.sql + " IS NOT NULL", bool
    317324            else: 
    318325                raise ValueError("Non-equality Null comparisons not allowed.") 
    319326        elif op2.sql == 'NULL': 
    320327            if op in (2, 8):    # '==', 'is' 
    321                 self.stack.append(self.get_expr(op1.sql + " IS NULL", bool)
     328                self.append_expr(op1.sql + " IS NULL", bool
    322329            elif op in (3, 9):  # '!=', 'is not' 
    323                 self.stack.append(self.get_expr(op1.sql + " IS NOT NULL", bool)
     330                self.append_expr(op1.sql + " IS NOT NULL", bool
    324331            else: 
    325332                raise ValueError("Non-equality Null comparisons not allowed.") 
    326333        else: 
     334            # Try to cast from one to the other. Try in both directions 
     335            # (but try to cast op2 first, since most of *my* expressions 
     336            # put the column first and a constant second ("Field < 3")). 
    327337            try: 
    328                 op1, op2 = self._compare_constants(op1, op2) 
    329             except TypeError: 
    330                 self.stack.append(cannot_represent) 
    331                 self.imperfect = True 
    332                 return 
     338                op2.sql = op2.dbtype.cast(op2.sql, op1.dbtype) 
     339            except (AttributeError, TypeError): 
     340                try: 
     341                    op1.sql = op1.dbtype.cast(op1.sql, op2.dbtype) 
     342                except (AttributeError, TypeError): 
     343                    self.stack.append(cannot_represent) 
     344                    self.imperfect = True 
     345                    return 
     346             
    333347            # Comparison operators for strings are case-sensitive in PG et al. 
    334348            e = op1.sql + " " + self.sql_cmp_op[op] + " " + op2.sql 
    335             self.stack.append(self.get_expr(e, bool)) 
    336      
    337     def _compare_constants(self, op1, op2): 
    338         """Coerce/cast compared types. 
    339          
    340         If a column value is compared to a constant and no coerce or cast 
    341         adapter function is available, a TypeError is raised. 
    342         """ 
    343         if op1.value is None: 
    344             if op2.value is not None: 
    345                 # op2 is a constant 
    346                 if op1.imperfect_type: 
    347                     # Try to cast the column to op2's type 
    348                     op1.sql = self.adapter.cast(op1, op1.dbtype, op2.pytype) 
    349                 else: 
    350                     # Try to coerce op2 to the column's type 
    351                     op2.sql = self.adapter.coerce(op2.value, op1.dbtype) 
    352         else: 
    353             if op2.imperfect_type: 
    354                 # Try to cast the column to op1's type 
    355                 op2.sql = self.adapter.cast(op2, op2.dbtype, op1.pytype) 
    356             else: 
    357                 # Try to coerce op1 to the column's type 
    358                 op1.sql = self.adapter.coerce(op1.value, op2.dbtype) 
    359         return op1, op2 
     349            self.append_expr(e, bool) 
    360350     
    361351    def visit_BINARY_SUBSCR(self): 
     
    378368            self.stack.append(cannot_represent) 
    379369        else: 
    380             self.stack.append(self.get_expr("NOT (" + op.sql + ")", bool)
     370            self.append_expr("NOT (" + op.sql + ")", bool
    381371     
    382372    # --------------------------- Dispatchees --------------------------- # 
    383373     
     374    # Notice these are ordered pairs. Escape \ before introducing new ones. 
     375    # Values in these two lists should be strings encoded with self.encoding. 
     376    like_escapes = [("%", r"\%"), ("_", r"\_")] 
     377     
     378    def escape_like(self, value): 
     379        """Prepare a string value for use in a LIKE comparison.""" 
     380        if not isinstance(value, str): 
     381            value = value.encode(self.encoding) 
     382        # Notice we strip leading and trailing quote-marks. 
     383        value = value.strip("'\"") 
     384        for pat, repl in self.like_escapes: 
     385            value = value.replace(pat, repl) 
     386        return value 
     387     
    384388    def attr_startswith(self, tos, arg): 
    385         return self.get_expr(tos.sql + " LIKE '" + self.adapter.escape_like(arg.sql) + "%'", bool) 
     389        return self.get_expr(tos.sql + " LIKE '" + self.escape_like(arg.sql) + "%'", bool) 
    386390     
    387391    def attr_endswith(self, tos, arg): 
    388         return self.get_expr(tos.sql + " LIKE '%" + self.adapter.escape_like(arg.sql) + "'", bool) 
     392        return self.get_expr(tos.sql + " LIKE '%" + self.escape_like(arg.sql) + "'", bool) 
    389393     
    390394    def containedby(self, op1, op2): 
    391395        if op1.value is not None: 
    392396            # Looking for text in a field. Use Like (reverse terms). 
    393             like = self.adapter.escape_like(op1.sql) 
     397            like = self.escape_like(op1.sql) 
    394398            return self.get_expr(op2.sql + " LIKE '%" + like + "%'", bool) 
    395399        else: 
    396400            # Looking for field in (a, b, c) 
    397             atoms = [self.adapter.coerce(x) for x in op2.value] 
     401            atoms = [self.adapterset.default(type(x), op1.dbtype).push(x) 
     402                     for x in op2.value] 
    398403            if atoms: 
    399404                return self.get_expr(op1.sql + " IN (" + ", ".join(atoms) + ")", bool) 
     
    406411            # Looking for text in a field. Use Like (reverse terms). 
    407412            return self.get_expr("LOWER(" + op2.sql + ") LIKE '%" + 
    408                                  self.adapter.escape_like(op1.sql).lower() 
     413                                 self.escape_like(op1.sql).lower() 
    409414                                 + "%'", bool) 
    410415        else: 
    411416            # Looking for field in (a, b, c). 
    412417            # Force all args to lowercase for case-insensitive comparison. 
    413             atoms = [self.adapter.coerce(x).lower() for x in op2.value] 
     418            atoms = [self.adapterset.default(type(x), op1.dbtype).push(x).lower() 
     419                     for x in op2.value] 
    414420            return self.get_expr("LOWER(%s) IN (%s)" % 
    415421                                 (op1.sql, ", ".join(atoms)), bool) 
     
    420426    def builtins_istartswith(self, x, y): 
    421427        return self.get_expr("LOWER(" + x.sql + ") LIKE '" + 
    422                              self.adapter.escape_like(y.sql) + "%'", bool) 
     428                             self.escape_like(y.sql) + "%'", bool) 
    423429     
    424430    def builtins_iendswith(self, x, y): 
    425431        return self.get_expr("LOWER(" + x.sql + ") LIKE '%" + 
    426                              self.adapter.escape_like(y.sql) + "'", bool) 
     432                             self.escape_like(y.sql) + "'", bool) 
    427433     
    428434    def builtins_ieq(self, x, y): 
     
    481487            return 
    482488         
     489        try: 
     490            newsql = op1.adapter.binary_op(op1, op, op2) 
     491        except TypeError: 
     492            self.stack.append(cannot_represent) 
     493            return 
     494         
     495        newpytype = self.result_type[(op1.pytype, op, op2.pytype)] 
     496         
    483497        # re-use op1 
    484         op1.pytype = self.result_type[(op1.pytype, op, op2.pytype)] 
    485         op1.sql = "%s %s %s" % (op1.sql, op, op2.sql) 
     498        op1.sql = newsql 
     499        if newpytype != op1.pytype: 
     500            op1.pytype = newpytype 
     501            op1.dbtype = self.adapterset.database_type(newpytype) 
     502            op1.adapter = self.adapterset.default(newpytype, op1.dbtype) 
    486503        if not op1.name.startswith("expr_"): 
    487504            op1.name = "expr_%s" % op1.name 
  • trunk/geniusql/objects.py

    r53 r54  
    116116    name: the SQL name for this table (unquoted). 
    117117    qname: the SQL name for this table (quoted). 
     118     
    118119    pytype: the Python type (the actual type object, not its name). 
    119     dbtype: the database type name (as used in a CREATE TABLE statement). 
     120    dbtype: a DatabaseType instance. 
     121    adapter: the object whose push and pull methods will convert Python 
     122        values to and from SQL for values in this Column. 
     123     
    120124    default: default Python value for this column for new rows. 
    121     hints: a dict of implementation hints, such as precision, scale, or bytes. 
    122125    key: True if this column is part of the table's primary key. 
    123126     
    124     imperfect_type: if True, signals that we are deliberately using a 
    125         database type other than the default (usually in order to handle 
    126         irregular values, such as huge numbers). When comparing imperfect 
    127         column values with constant values in SQL, the database must be 
    128         able to cast the column value to the constant's type. If that 
    129         cannot be done for the given types, then the query will be marked 
    130         imperfect. 
    131127    autoincrement: if True, uses the database's built-in sequencing. 
    132128    sequence_name: for databases that use separate statements to create and 
     
    135131    """ 
    136132     
    137     def __init__(self, pytype, dbtype, default=None, hints=None, key=False, 
     133    def __init__(self, pytype, dbtype, default=None, key=False, 
    138134                 name=None, qname=None): 
    139135        self.pytype = pytype 
    140136        self.dbtype = dbtype 
     137        self.adapter = None 
     138         
    141139        self.name = name 
    142140        self.qname = qname 
    143141        self.default = default 
    144         if hints is None: 
    145             hints = {} 
    146         else: 
    147             hints = hints.copy() 
    148         self.hints = hints 
    149142        self.key = key 
    150143         
     
    153146        self.sequence_name = None 
    154147        self.initial = 1 
    155          
    156         self.imperfect_type = False 
    157148     
    158149    def __repr__(self): 
    159         return ("%s.%s(%r, %r, default=%r, hints=%r, key=%r, name=%r, qname=%r)" % 
     150        return ("%s.%s(%r, %r, default=%r, key=%r, name=%r, qname=%r)" % 
    160151                (self.__module__, self.__class__.__name__, 
    161                  self.pytype, self.dbtype, 
    162                  self.default, self.hints, self.key, 
     152                 self.pytype, self.dbtype, self.default, self.key, 
    163153                 self.name, self.qname) 
    164154                ) 
    165155     
    166156    def __copy__(self): 
    167         newcol = self.__class__(self.pytype, self.dbtype, 
    168                                 self.default, self.hints, self.key, 
    169                                 self.name, self.qname) 
     157        newcol = self.__class__(self.pytype, self.dbtype, self.default, 
     158                                self.key, self.name, self.qname) 
    170159        newcol.autoincrement = self.autoincrement 
    171160        newcol.initial = self.initial 
    172         newcol.imperfect_type = self.imperfect_type 
     161        newcol.adapter = self.adapter 
    173162        return newcol 
    174163    copy = __copy__ 
     
    341330        tpair = [(self.qname, self)] 
    342331        decom = self.schema.db.decompiler(tpair, logic.filter(**inputs), 
    343                                           self.schema.db.adaptertosql, 
    344                                           self.schema.db.typeadapter) 
     332                                          self.schema.db.adapterset) 
     333##        decom.verbose = True 
    345334        code = decom.code() 
    346335        if decom.imperfect: 
     
    363352    def insert(self, **inputs): 
    364353        """Insert a row, then return inputs including any new identifiers.""" 
    365         coerce_out = self.schema.db.adaptertosql.coerce 
    366         coerce_in = self.schema.db.adapterfromdb.coerce 
    367          
    368354        fields = [] 
    369355        idkeys = [] 
     
    375361                continue 
    376362            if key in inputs: 
    377                 val = coerce_out(inputs[key], col.dbtype
     363                val = col.adapter.push(inputs[key]
    378364                fields.append(col.qname) 
    379365                values.append(val) 
     
    390376            for k, v in self._grab_new_ids(idkeys, conn).iteritems(): 
    391377                col = self[k] 
    392                 base[k] = coerce_in(v, col.dbtype, col.pytype
     378                base[k] = col.adapter.pull(v
    393379        return base 
    394380     
     
    400386        """Update a row using the given inputs.""" 
    401387        parms = [] 
    402         coerce = self.schema.db.adaptertosql.coerce 
    403388        for key, val in inputs.iteritems(): 
    404389            col = self[key] 
     
    407392                pass 
    408393            else: 
    409                 val = coerce(val, col.dbtype
     394                val = col.adapter.push(val
    410395                parms.append('%s = %s' % (col.qname, val)) 
    411396         
     
    418403        """Update all rows (with 'data' dict) matching the given inputs.""" 
    419404        parms = [] 
    420         coerce = self.schema.db.adaptertosql.coerce 
    421405        for key, val in data.iteritems(): 
    422406            col = self[key] 
    423             val = coerce(val, col.dbtype
     407            val = col.adapter.push(val
    424408            parms.append('%s = %s' % (col.qname, val)) 
    425409         
     
    635619        return self.db.sql_name(columnkey) 
    636620     
    637     def column(self, pytype=unicode, dbtype=None, default=None, hints=None, 
    638                key=False, autoincrement=False): 
     621    def column(self, pytype=unicode, dbtype=None, default=None, 
     622               key=False, autoincrement=False, hints=None): 
    639623        """Return a Column object from the given arguments.""" 
    640         col = Column(pytype, dbtype, default, hints, key) 
     624        col = Column(pytype, dbtype, default, key) 
    641625        col.autoincrement = autoincrement 
    642626         
    643         typer = self.db.typeadapter 
    644         if dbtype is None: 
    645             col.dbtype = typer.coerce(pytype, col.hints) 
    646         pytype2 = typer.python_type(col.dbtype) 
    647         col.imperfect_type = not typer.related(pytype, pytype2) 
     627        typer = self.db.adapterset 
     628        if col.dbtype is None: 
     629            col.dbtype = typer.database_type(pytype, hints or {}) 
     630        col.adapter = typer.default(pytype, col.dbtype) 
    648631         
    649632        return col 
     
    681664        Most subclasses will override this to add autoincrement support. 
    682665        """ 
    683         dbtype = column.dbtype 
     666        ddltype = column.dbtype.ddl() 
    684667         
    685668        default = column.default or "" 
    686669        if default: 
    687             default = self.db.adaptertosql.coerce(default, dbtype
     670            default = column.adapter.push(default
    688671            default = " DEFAULT %s" % default 
    689672         
    690         return "%s %s%s" % (column.qname, dbtype, default) 
     673        return "%s %s%s" % (column.qname, ddltype, default) 
    691674     
    692675    def __setitem__(self, key, table): 
     
    770753    __metaclass__ = geniusql._AttributeDocstrings 
    771754     
    772     adaptertosql = adapters.AdapterToSQL() 
    773     adapterfromdb = adapters.AdapterFromDB() 
    774     typeadapter = adapters.TypeAdapter() 
     755    adapterset = adapters.AdapterSet() 
    775756    decompiler = decompile.SQLDecompiler 
    776757    joinwrapper = select.TableWrapper 
     
    860841            conn = self.connections.get() 
    861842        if isinstance(query, unicode): 
    862             query = query.encode(self.adaptertosql.encoding) 
     843            query = query.encode(self.adapterset.encoding) 
    863844        self.log(query) 
    864845        return conn.query(query) 
     
    954935        self.selector = selector 
    955936        self.data = data 
    956         self.coerce = self.selector.result.schema.db.adapterfromdb.coerce 
    957937        self.cursor = 0 
    958938     
     
    972952            val = row[i] 
    973953            col = self.selector.result[colkey] 
    974             val = self.coerce(val, col.dbtype, col.pytype) 
     954            if val is not None: 
     955                # Any column value can be None. Don't coerce it. 
     956                val = col.adapter.pull(val) 
    975957            coerced_row.append(val) 
    976958        return coerced_row 
  • trunk/geniusql/providers/__init__.py

    r53 r54  
    6565    "sqlite": "geniusql.providers.sqlite.SQLiteDatabase", 
    6666     
    67     "sqlserver": "geniusql.providers.ado.SQLServerDatabase", 
    68     "mssql": "geniusql.providers.ado.SQLServerDatabase", 
     67    "sqlserver": "geniusql.providers.sqlserver.SQLServerDatabase", 
     68    "mssql": "geniusql.providers.sqlserver.SQLServerDatabase", 
    6969    }) 
  • trunk/geniusql/providers/ado.py

    r53 r54  
    2727 
    2828import time 
    29  
    30 try: 
    31     import cPickle as pickle 
    32 except ImportError: 
    33     import pickle 
    3429 
    3530import threading 
     
    8277        132: 'USERDEFINED', 
    8378        133: 'DBDATE', 134: 'DBTIME', 
    84         135: 'DBTIMESTAMP',   # DATETIME, SMALLDATETIME   DATETIME (ODBC 97) 
     79        135: 'DBTIMESTAMP',   # DATETIME, SMALLDATETIME DATETIME (ODBC 97) 
    8580        200: 'VARCHAR',       # VARCHAR                 TEXT (Access 97) 
    8681        201: 'LONGVARCHAR',   # TEXT                    MEMO (Access 97) 
     
    10297 
    10398 
    104 class AdapterFromADO(adapters.AdapterFromDB): 
    105     """Coerce incoming values from ADO to Python datatypes.""" 
    106      
    107     encoding = 'ISO-8859-1' 
    108     epoch = datetime.datetime(1899, 12, 30) 
    109      
    110     def timedelta_from_com(self, com_date): 
    111         """Return a valid datetime.timedelta from a COM date/time object.""" 
    112         com_date = float(com_date) 
    113          
    114         # MS Access represents dates and times as floats. If the value is 
    115         # before the epoch (12/30/1899), the seconds will be SUBTRACTED 
    116         # from the float. For example, -2.01 is in the morning and -2.99 
    117         # is in the evening of the same day. Therefore, when we split off 
    118         # our seconds we must use the abs value of the fractional portion. 
    119         neg = (com_date < 0) 
    120         com_date = abs(com_date) 
    121          
    122         days = int(com_date) 
    123         # Must do both int() and round() or we'll be up to 1 second off. 
    124         secs = int(round(86400 * (com_date - days))) 
    125          
    126         result = datetime.timedelta(days, secs) 
    127         if neg: 
    128             return -result 
     99def timedelta_from_com(com_date): 
     100    """Return a valid datetime.timedelta from a COM date/time object.""" 
     101    com_date = float(com_date) 
     102     
     103    # MS Access represents dates and times as floats. If the value is 
     104    # before the epoch (12/30/1899), the seconds will be SUBTRACTED 
     105    # from the float. For example, -2.01 is in the morning and -2.99 
     106    # is in the evening of the same day. Therefore, when we split off 
     107    # our seconds we must use the abs value of the fractional portion. 
     108    neg = (com_date < 0) 
     109    com_date = abs(com_date) 
     110     
     111    days = int(com_date) 
     112    # Must do both int() and round() or we'll be up to 1 second off. 
     113    secs = int(round(86400 * (com_date - days))) 
     114     
     115    result = datetime.timedelta(days, secs) 
     116    if neg: 
     117        return -result 
     118    else: 
     119        return result 
     120 
     121 
     122class COM_timedelta(adapters.INTERVAL): 
     123     
     124    def pull(self, value): 
     125        if isinstance(value, unicode): 
     126            # The value is a stringified NUMERIC of seconds. 
     127            days, secs = divmod(long(value), 86400) 
     128            return datetime.timedelta(int(days), int(secs)) 
     129        return timedelta_from_com(value) 
     130     
     131    def TIMEDELTAADD(op1, op, op2): 
     132        return "(%s %s %s)" % (op1.sql, op, op2.sql) 
     133    TIMEDELTAADD = staticmethod(TIMEDELTAADD) 
     134     
     135    def DATEADD(dt, td): 
     136        """Return the SQL to add a timedelta to a date.""" 
     137        # Days, seconds seems like a good way to avoid overflow. 
     138        return ("DATEADD(dd, FLOOR(%s / 86400), " 
     139                "DATEADD(ss, (%s %% 86400), %s))" 
     140                % (td, td, dt)) 
     141    DATEADD = staticmethod(DATEADD) 
     142     
     143    def DATETIMEADD(dt, td): 
     144        """Return the SQL to add a timedelta to a datetime.""" 
     145        return "(%s + (%s / 86400.0))" % (dt, td) 
     146    DATETIMEADD = staticmethod(DATETIMEADD) 
     147     
     148    def binary_op(self, op1, op, op2): 
     149        if op2.pytype is datetime.timedelta: 
     150            return self.TIMEDELTAADD(op1, op, op2) 
    129151        else: 
    130             return result 
    131      
    132     def coerce_any_to_datetime_timedelta(self, value): 
    133         # Assume pywintypes.TimeType 
    134         return self.timedelta_from_com(value) 
    135      
    136     def coerce_any_to_datetime_time(self, value): 
    137         t = self.timedelta_from_com(value) 
     152            if op == "+": 
     153                if op2.pytype is datetime.date: 
     154                    return self.DATEADD(op2.sql, op1.sql) 
     155                elif op2.pytype is datetime.datetime: 
     156                    return self.DATETIMEADD(op2.sql, op1.sql) 
     157         
     158        raise TypeError("unsupported operand type(s) for %s: " 
     159                        "%r and %r" % (op, op1.pytype, op2.pytype)) 
     160 
     161 
     162class COM_time(adapters.SQL92TIME): 
     163    def pull(self, value): 
     164        t = timedelta_from_com(value) 
    138165        if t.days: 
    139166            raise ValueError("Time values greater than 23:59:59 not allowed.") 
     
    141168        m, s = divmod(m, 60) 
    142169        return datetime.time(int(h), int(m), int(s)) 
    143      
    144     def datetime_from_com(self, com_date): 
     170 
     171 
     172class COM_datetime(adapters.SQL92TIMESTAMP): 
     173     
     174    epoch = datetime.datetime(1899, 12, 30) 
     175     
     176    def pull(self, value): 
    145177        """Return a valid datetime.datetime from a COM date/time object.""" 
    146         com_date = float(com_date) 
     178        # Illegal Date/Time values will crash the app when using 
     179        # value.Format(). Therefore, grab the float value and figure 
     180        # the date ourselves. Use 1-second resolution only. 
     181        com_date = float(value) 
    147182         
    148183        # MS Access represents dates and times as floats. If the value is 
     
    160195        return self.epoch + datetime.timedelta(days, secs) 
    161196     
    162     def coerce_any_to_datetime_datetime(self, value): 
    163         if isinstance(value, basestring): 
    164             if value: 
    165                 try: 
    166                     return datetime.datetime(int(value[0:4]), int(value[4:6]), 
    167                                              int(value[6:8])) 
    168                 except Exception: 
    169                     raise ValueError("'%s' %s" % (value, type(value))) 
    170             else: 
    171                 return None 
    172         else: 
    173             # Illegal Date/Time values will crash the app when using 
    174             # value.Format(). Therefore, grab the float value and figure 
    175             # the date ourselves. Use 1-second resolution only. 
    176             return self.datetime_from_com(value) 
    177      
    178     def coerce_any_to_datetime_date(self, value): 
    179         if isinstance(value, basestring): 
    180             if value: 
    181                 try: 
    182                     return datetime.date(int(value[0:4]), int(value[4:6]), 
    183                                          int(value[6:8])) 
    184                 except Exception: 
    185                     raise ValueError("'%s' %s" % (value, type(value))) 
    186             else: 
    187                 return None 
    188         else: 
    189             value = float(value) 
    190             days = int(value) 
    191             return self.epoch.date() + datetime.timedelta(days) 
    192      
    193     def coerce_any_to_decimal_Decimal(self, value): 
    194         # pywin32 build 205 began support for returning 
    195         # COM Currency objects as decimal objects. 
    196         # See http://pywin32.cvs.sourceforge.net/pywin32/pywin32/CHANGES.txt?view=markup 
    197         if not isinstance(value, typerefs.decimal.Decimal): 
    198             value = str(value) 
    199             value = typerefs.decimal.Decimal(value) 
    200         return value 
    201      
    202     def coerce_CURRENCY_to_float(self, value): 
     197    def DATETIMEADD(dt, td): 
     198        """Return the SQL to add a timedelta to a datetime.""" 
     199        return "(%s + (%s / 86400.0))" % (dt, td) 
     200    DATETIMEADD = staticmethod(DATETIMEADD) 
     201     
     202    def DATETIMEDIFF(d1, d2): 
     203        """Return the SQL to subtract one datetime from another.""" 
     204        return "CAST(CAST(%s - %s AS FLOAT) * 86400 AS NUMERIC)" % (d1, d2) 
     205    DATETIMEDIFF = staticmethod(DATETIMEDIFF) 
     206     
     207    def DATETIMESUB(dt, td): 
     208        """Return the SQL to subtract a timedelta from a datetime.""" 
     209        return "(%s - (%s / 86400.0))" % (dt, td) 
     210    DATETIMESUB = staticmethod(DATETIMESUB) 
     211     
     212    def binary_op(self, op1, op, op2): 
     213        if op2.pytype is datetime.datetime: 
     214            if op == "-": 
     215                return self.DATETIMEDIFF(op1.sql, op2.sql) 
     216        elif op2.pytype is datetime.timedelta: 
     217            if op == "+": 
     218                return self.DATETIMEADD(op1.sql, op2.sql) 
     219            elif op == "-": 
     220                return self.DATETIMESUB(op1.sql, op2.sql) 
     221         
     222        raise TypeError("unsupported operand type(s) for %s: " 
     223                        "%r and %r" % (op, op1.pytype, op2.pytype)) 
     224 
     225 
     226class COM_date(adapters.SQL92DATE): 
     227     
     228    epoch = datetime.datetime(1899, 12, 30) 
     229     
     230    def pull(self, value): 
     231        value = float(value) 
     232        days = int(value) 
     233        return self.epoch.date() + datetime.timedelta(days) 
     234     
     235    def DATEDIFF(d1, d2): 
     236        """Return the SQL to subtract one date from another.""" 
     237        # Amazing what a difference a little ".0" can make. 
     238        return "CAST(DATEDIFF(dd, %s, %s) * 86400.0 AS NUMERIC)" % (d2, d1) 
     239    DATEDIFF = staticmethod(DATEDIFF) 
     240     
     241    def DATEADD(dt, td): 
     242        """Return the SQL to add a timedelta to a date.""" 
     243        # Days, seconds seems like a good way to avoid overflow. 
     244        return ("DATEADD(dd, FLOOR(%s / 86400), " 
     245                "DATEADD(ss, (%s %% 86400), %s))" 
     246                % (td, td, dt)) 
     247    DATEADD = staticmethod(DATEADD) 
     248     
     249    def DATESUB(dt, td): 
     250        """Return the SQL to subtract a timedelta from a date.""" 
     251        return "(%s - FLOOR(%s / 86400.0))" % (dt, td) 
     252    DATESUB = staticmethod(DATESUB) 
     253     
     254    def binary_op(self, op1, op, op2): 
     255        if op2.pytype is datetime.date: 
     256            if op == "-": 
     257                return self.DATEDIFF(op1.sql, op2.sql) 
     258        elif op2.pytype is datetime.timedelta: 
     259            if op == "+": 
     260                return self.DATEADD(op1.sql, op2.sql) 
     261            elif op == "-": 
     262                return self.DATESUB(op1.sql, op2.sql) 
     263         
     264        raise TypeError("unsupported operand type(s) for %s: " 
     265                        "%r and %r" % (op, op1.pytype, op2.pytype)) 
     266 
     267 
     268class CURRENCY(adapters.SQL92DOUBLE): 
     269    def pull(self, value): 
    203270        if isinstance(value, tuple): 
    204271            # See http://groups.google.com/group/comp.lang.python/ 
     
    207274            return ((value[1] & 0xFFFFFFFFL) | (value[0] << 32)) / 1e4 
    208275        return float(value) 
    209      
    210     def coerce_CURRENCY_to_decimal_Decimal(self, value): 
     276 
     277class decimalCURRENCY(adapters.DECIMAL): 
     278    def pull(self, value): 
    211279        # pywin32 build 205 began support for returning 
    212280        # COM Currency objects as decimal objects. 
     
    219287            return typerefs.decimal.Decimal(value) / 10000 
    220288        return value 
    221      
    222     def coerce_CURRENCY_to_fixedpoint_FixedPoint(self, value): 
     289 
     290class fpCURRENCY(adapters.FIXEDPOINT): 
     291    def pull(self, value): 
    223292        if isinstance(value, typerefs.decimal.Decimal): 
    224293            value = str(value) 
     
    234303            value = (value[1] & 0xFFFFFFFFL) | (value[0] << 32) 
    235304            return typerefs.fixedpoint.FixedPoint(value, 4) / 1e4 
    236      
    237     def coerce_any_to_unicode(self, value): 
    238         if isinstance(value, unicode): 
    239             # For some reason, value is already a unicode object. 
    240             return value 
    241          
    242         if isinstance(value, (basestring, buffer)): 
    243             try: 
    244                 return unicode(value, self.encoding) 
    245             except UnicodeError, exc: 
    246                 exc.args += (type(value),) 
    247         return unicode(value) 
     305 
     306 
     307class ADOAdapterSet(adapters.AdapterSet): 
     308     
     309    encoding = 'ISO-8859-1' 
     310    escapes = [("'", "''")] 
    248311 
    249312 
     
    270333        elif op1.sql == 'NULL': 
    271334            if op in (2, 8):    # '==', is 
    272                 self.stack.append(self.get_expr(op2.sql + " IS NULL", bool)
     335                self.append_expr(op2.sql + " IS NULL", bool
    273336            elif op in (3, 9):  # '!=', 'is not' 
    274                 self.stack.append(self.get_expr(op2.sql + " IS NOT NULL", bool)
     337                self.append_expr(op2.sql + " IS NOT NULL", bool
    275338            else: 
    276339                raise ValueError("Non-equality Null comparisons not allowed.") 
    277340        elif op2.sql == 'NULL': 
    278341            if op in (2, 8):    # '==', 'is' 
    279                 self.stack.append(self.get_expr(op1.sql + " IS NULL", bool)
     342                self.append_expr(op1.sql + " IS NULL", bool
    280343            elif op in (3, 9):  # '!=', 'is not' 
    281                 self.stack.append(self.get_expr(op1.sql + " IS NOT NULL", bool)
     344                self.append_expr(op1.sql + " IS NOT NULL", bool
    282345            else: 
    283346                raise ValueError("Non-equality Null comparisons not allowed.") 
    284347        else: 
     348            # Try to cast from one to the other. Try in both directions 
     349            # (but try to cast op2 first, since most of *my* expressions 
     350            # put the column first and a constant second ("Field < 3")). 
    285351            try: 
    286                 op1, op2 = self._compare_constants(op1, op2
     352                op2.sql = op2.dbtype.cast(op2.sql, op1.dbtype
    287353            except TypeError: 
    288                 self.stack.append(decompile.cannot_represent) 
    289                 self.imperfect = True 
    290                 return 
    291              
    292             if (isinstance(op2, decompile.SQLExpression) 
    293                 and issubclass(op2.pytype, basestring)): 
     354                try: 
     355                    op1.sql = op1.dbtype.cast(op1.sql, op2.dbtype) 
     356                except TypeError: 
     357                    self.stack.append(decompile.cannot_represent) 
     358                    self.imperfect = True 
     359                    return 
     360             
     361            if issubclass(op1.pytype, basestring) and issubclass(op2.pytype, basestring): 
    294362                atom = self._compare_strings(op1, op, op2) 
    295363                if atom is not None: 
     
    298366             
    299367            e = op1.sql + " " + self.sql_cmp_op[op] + " " + op2.sql 
    300             self.stack.append(self.get_expr(e, bool)
     368            self.append_expr(e, bool
    301369     
    302370    def _compare_strings(self, op1, op, op2): 
     
    308376        self.imperfect = True 
    309377     
    310     def binary_op(self, op): 
    311         op2, op1 = self.stack.pop(), self.stack.pop() 
    312         if op1 is decompile.cannot_represent or op2 is decompile.cannot_represent: 
    313             self.stack.append(decompile.cannot_represent) 
    314             return 
    315          
    316         t1, t2 = op1.pytype, op2.pytype 
    317          
    318         newsql = None 
    319         if t1 is datetime.date: 
    320             if t2 is datetime.date: 
    321                 if op == "-": 
    322                     newsql = self.DATEDIFF(op1.sql, op2.sql) 
    323             elif t2 is datetime.timedelta: 
    324                 if op == "+": 
    325                     newsql = self.DATEADD(op1.sql, op2.sql) 
    326                 elif op == "-": 
    327                     newsql = self.DATESUB(op1.sql, op2.sql) 
    328         elif t1 is datetime.datetime: 
    329             if t2 is datetime.datetime: 
    330                 if op == "-": 
    331                     newsql = self.DATETIMEDIFF(op1.sql, op2.sql) 
    332             elif t2 is datetime.timedelta: 
    333                 if op == "+": 
    334                     newsql = self.DATETIMEADD(op1.sql, op2.sql) 
    335                 elif op == "-": 
    336                     newsql = self.DATETIMESUB(op1.sql, op2.sql) 
    337         elif t1 is datetime.timedelta: 
    338             if t2 is datetime.timedelta: 
    339                 newsql = self.TIMEDELTAADD(op1, op, op2) 
    340             else: 
    341                 if op == "+": 
    342                     if t2 is datetime.date: 
    343                         newsql = self.DATEADD(op2.sql, op1.sql) 
    344                     elif t2 is datetime.datetime: 
    345                         newsql = self.DATETIMEADD(op2.sql, op1.sql) 
    346         else: 
    347             newsql = "(%s %s %s)" % (op1.sql, op, op2.sql) 
    348          
    349         if newsql is None: 
    350             raise TypeError("unsupported operand type(s) for %s: " 
    351                             "%r and %r" % (op, t1, t2)) 
    352          
    353         # re-use op1 
    354         op1.pytype = self.result_type[(t1, op, t2)] 
    355         op1.sql = newsql 
    356         if not op1.name.startswith("expr_"): 
    357             op1.name = "expr_%s" % op1.name 
    358         self.stack.append(op1) 
    359      
    360378    # --------------------------- Dispatchees --------------------------- # 
    361379     
    362380    def attr_startswith(self, tos, arg): 
    363381        self.imperfect = True 
    364         return self.get_expr(tos.sql + " LIKE '" + self.adapter.escape_like(arg.sql) + "%'", bool) 
     382        return self.get_expr(tos.sql + " LIKE '" + self.escape_like(arg.sql) + "%'", bool) 
    365383     
    366384    def attr_endswith(self, tos, arg): 
    367385        self.imperfect = True 
    368         return self.get_expr(tos.sql + " LIKE '%" + self.adapter.escape_like(arg.sql) + "'", bool) 
     386        return self.get_expr(tos.sql + " LIKE '%" + self.escape_like(arg.sql) + "'", bool) 
    369387     
    370388    def containedby(self, op1, op2): 
     
    378396            # Looking for text in a field. Use Like (reverse terms). 
    379397            return self.get_expr(op2.sql + " LIKE '%" + 
    380                                  self.adapter.escape_like(op1.sql) 
     398                                 self.escape_like(op1.sql) 
    381399                                 + "%'", bool) 
    382400        else: 
    383401            # Looking for field in (a, b, c) 
    384             atoms = [self.adapter.coerce(x) for x in op2.value] 
     402            atoms = [self.adapterset.default(type(x)).push(x) for x in op2.value] 
    385403            if atoms: 
    386404                return self.get_expr("%s IN (%s)" % 
     
    393411    def builtins_istartswith(self, x, y): 
    394412        # Like is already case-insensitive in ADO; so don't use LOWER(). 
    395         return self.get_expr(x.sql + " LIKE '" + self.adapter.escape_like(y.sql) + "%'", bool) 
     413        return self.get_expr(x.sql + " LIKE '" + self.escape_like(y.sql) + "%'", bool) 
    396414     
    397415    def builtins_iendswith(self, x, y): 
    398416        # Like is already case-insensitive in ADO; so don't use LOWER(). 
    399         return self.get_expr(x.sql + " LIKE '%" + self.adapter.escape_like(y.sql) + "'", bool) 
     417        return self.get_expr(x.sql + " LIKE '%" + self.escape_like(y.sql) + "'", bool) 
    400418     
    401419    def builtins_ieq(self, x, y): 
     
    569587         
    570588        cols = [] 
    571         get_pytype = self.db.typeadapter.python_type 
     589        typer = self.db.adapterset 
    572590        for row in data: 
    573591            # I tried passing criteria to OpenSchema, but passing None is 
     
    576594                continue 
    577595             
    578             dbtype = dbtypes[row[11]] 
    579             pytype = get_pytype(dbtype) 
     596            dbtypename = dbtypes[row[11]] 
     597            dbtype = typer.canonicalize(dbtypename)() 
     598            pytype = dbtype.default_pytype 
     599            if pytype is None: 
     600                raise TypeError("%r has no default pytype." % dbtype) 
    580601             
    581602            default = row[8] 
     
    589610             
    590611            name = str(row[3]) 
    591             c = geniusql.Column(pytype, dbtype, 
    592                                 default, hints={}, key=(name in pknames), 
     612            c = geniusql.Column(pytype, dbtype, default, 
     613                                key=(name in pknames), 
    593614                                name=name, qname=self.db.quote(name)) 
    594615             
     
    600621                c.autoincrement = True 
    601622             
    602             if dbtype in ("SMALLINT", "INTEGER", "TINYINT", 
    603                           "UNSIGNEDTINYINT", "UNSIGNEDSMALLINT", 
    604                           "UNSIGNEDINT", "BIGINT", "UNSIGNEDBIGINT"): 
    605                 c.hints['bytes'] = row[15] 
    606             elif dbtype in ("SINGLE", "DOUBLE"): 
    607                 c.hints['precision'] = row[15] 
    608                 c.hints['scale'] = row[16] 
    609             elif dbtype == "CURRENCY": 
    610                 # CURRENCY allows 15 places to the left of the decimal point, 
    611                 # and 4 places to the right. 
    612                 c.hints['precision'] = 19 
    613                 c.hints['scale'] = 4 
    614             elif dbtype in ("DECIMAL", "NUMERIC"): 
    615                 c.hints['precision'] = row[15] 
    616                 c.hints['scale'] = row[16] 
    617                 c.dbtype = "%s(%s, %s)" % (dbtype, row[15], row[16]) 
    618             elif dbtype in ("BSTR", "VARIANT", "BINARY", "CHAR", 
    619                             "VARCHAR", "VARBINARY", "WCHAR", "VARWCHAR"): 
     623            if dbtype in typer.known_types['int']: 
     624                dbtype.bytes = row[15] 
     625            elif dbtype in typer.known_types['float']: 
     626                dbtype.precision = row[15] 
     627                dbtype.scale = row[16] 
     628##            elif dbtype == "CURRENCY": 
     629##                # CURRENCY allows 15 places to the left of the decimal point, 
     630##                # and 4 places to the right. 
     631##                c.hints['precision'] = 19 
     632##                c.hints['scale'] = 4 
     633            elif dbtype in typer.known_types['numeric']: 
     634                dbtype.precision = row[15] 
     635                dbtype.scale = row[16] 
     636            elif (dbtype in typer.known_types['char'] or 
     637                  dbtype in typer.known_types['varchar']): 
    620638                if row[13]: 
    621639                    # row[13] will be a float 
    622                     c.hints['bytes'] = b = int(row[13]) 
    623                 else: 
    624                     # I'm kinda guessing on this. If we use "MEMO" in an 
    625                     # MSAccess CREATE statement, it comes back as "WCHAR", 
    626                     # and seems to support over 65536 bytes. 
    627                     c.hints['bytes'] = b = (2 ** 31) - 1 
    628                 c.dbtype = "%s(%s)" % (c.dbtype, b) 
    629             elif dbtype in ("LONGVARCHAR", "LONGVARBINARY", "LONGVARWCHAR")
    630                 if row[13]: 
    631                     # row[13] will be a float 
    632                     c.hints['bytes'] = b = int(row[13]
    633                     c.dbtype = "%s(%s)" % (c.dbtype, b) 
    634                 else: 
    635                     c.hints['bytes'] = 65535 
    636              
     640                    dbtype.bytes = b = int(row[13]) 
     641##                else: 
     642##                    # I'm kinda guessing on this. If we use "MEMO" in an 
     643##                    # MSAccess CREATE statement, it comes back as "WCHAR", 
     644##                    # and seems to support over 65536 bytes. 
     645##                    dbtype.bytes = b = (2 ** 31) - 1 
     646##            elif dbtype in ("LONGVARCHAR", "LONGVARBINARY", "LONGVARWCHAR"): 
     647##                if row[13]
     648##                    # row[13] will be a float 
     649##                    c.hints['bytes'] = b = int(row[13]) 
     650##                    c.dbtype = "%s(%s)" % (c.dbtype, b
     651##                else: 
     652##                    c.hints['bytes'] = 65535 
     653             
     654            c.adapter = typer.default(pytype, dbtype) 
    637655            cols.append(c) 
    638656        return cols 
     
    673691 
    674692 
    675 class ADOTypeAdapter(adapters.TypeAdapter): 
    676      
    677     _reverse_types = adapters.TypeAdapter._reverse_types.copy() 
    678     _reverse_types.update({ 
    679         "DBDATE": datetime.date, 
    680         "DBTIME": datetime.time, 
    681         "DBTIMESTAMP": datetime.datetime, 
    682          
    683         "UNSIGNEDTINYINT": int, 
    684         "UNSIGNEDSMALLINT": int, 
    685         "UNSIGNEDINT": int, 
    686         "BIT": bool, 
    687          
    688         "UNSIGNEDBIGINT": long, 
    689          
    690         "BSTR": str, 
    691         "VARIANT": str, 
    692         "BINARY": str, 
    693         "LONGVARCHAR": str, 
    694         "VARBINARY": str, 
    695         "LONGVARBINARY": str, 
    696          
    697         "WCHAR": unicode, 
    698         "VARWCHAR": unicode, 
    699         "LONGVARWCHAR": unicode, 
    700         }) 
    701      
    702     if typerefs.decimal: 
    703         _reverse_types["CURRENCY"] = typerefs.decimal.Decimal 
    704     elif typerefs.fixedpoint: 
    705         _reverse_types["CURRENCY"] = typerefs.fixedpoint.FixedPoint 
    706  
    707  
    708693class ADODatabase(geniusql.Database): 
    709694     
    710695    decompiler = ADOSQLDecompiler 
    711     adapterfromdb = AdapterFromADO() 
    712     typeadapter = ADOTypeAdapter() 
     696    adapterset = ADOAdapterSet() 
    713697     
    714698    #                               Naming                                # 
     
    722706            conn = self.connections.get() 
    723707        if isinstance(query, unicode): 
    724             query = query.encode(self.adaptertosql.encoding) 
     708            query = query.encode(self.adapterset.encoding) 
    725709         
    726710        self.log(query) 
     
    815799 
    816800 
    817  
    818 ########################################################################### 
    819 ##                                                                       ## 
    820 ##                             SQL Server                                ## 
    821 ##                                                                       ## 
    822 ########################################################################### 
    823  
    824  
    825 # "Sure, there are two 4-byte integers stored. But they are 
    826 # packed together into a BINARY(8). The first 4-byte being 
    827 # the elapsed number days since SQL Server's base date of 
    828 # 1900-01-01. The Second 4-bytes Store the Time of Day 
    829 # Represented as the Number of Milliseconds After Midnight." 
    830 # http://www.sql-server-performance.com/fk_datetime.asp 
    831  
    832 # Note also that SQL Server allows DATETIME in the range: 
    833 # "1753-01-01 00:00:00.0" to "9999-12-31 23:59:59.997". 
    834  
    835  
    836 class ADOSQLDecompiler_SQLServer(ADOSQLDecompiler): 
    837      
    838     def _compare_strings(self, op1, op, op2): 
    839         # ADO comparison operators for strings are case-insensitive. 
    840         if op < 6: 
    841             # ('<', '<=', '==', '!=', '>', '>=') 
    842             # Some operations on strings can be emulated with the 
    843             # Convert function. 
    844             return self.get_expr("Convert(binary, %s) %s Convert(binary, %s)" 
    845                                  % (op1.sql, self.sql_cmp_op[op], op2.sql), 
    846                                  bool) 
    847         else: 
    848             return ADOSQLDecompiler._compare_strings(self, op1, op, op2) 
    849      
    850     def DATEADD(dt, td): 
    851         """Return the SQL to add a timedelta to a date.""" 
    852         # Days, seconds seems like a good way to avoid overflow. 
    853         return ("DATEADD(dd, FLOOR(%s / 86400), " 
    854                 "DATEADD(ss, (%s %% 86400), %s))" 
    855                 % (td, td, dt)) 
    856     DATEADD = staticmethod(DATEADD) 
    857      
    858     def DATESUB(dt, td): 
    859         """Return the SQL to subtract a timedelta from a date.""" 
    860         return "(%s - FLOOR(%s / 86400.0))" % (dt, td) 
    861     DATESUB = staticmethod(DATESUB) 
    862      
    863     def DATEDIFF(d1, d2): 
    864         """Return the SQL to subtract one date from another.""" 
    865         # Amazing what a difference a little ".0" can make. 
    866         return "CAST(DATEDIFF(dd, %s, %s) * 86400.0 AS NUMERIC)" % (d2, d1) 
    867     DATEDIFF = staticmethod(DATEDIFF) 
    868      
    869     def DATETIMEADD(dt, td): 
    870         """Return the SQL to add a timedelta to a datetime.""" 
    871         return "(%s + (%s / 86400.0))" % (dt, td) 
    872     DATETIMEADD = staticmethod(DATETIMEADD) 
    873      
    874     def DATETIMEDIFF(d1, d2): 
    875         """Return the SQL to subtract one datetime from another.""" 
    876         return "CAST(CAST(%s - %s AS FLOAT) * 86400 AS NUMERIC)" % (d1, d2) 
    877     DATETIMEDIFF = staticmethod(DATETIMEDIFF) 
    878      
    879     def DATETIMESUB(dt, td): 
    880         """Return the SQL to subtract a timedelta from a datetime.""" 
    881         return "(%s - (%s / 86400.0))" % (dt, td) 
    882     DATETIMESUB = staticmethod(DATETIMESUB) 
    883      
    884     def TIMEDELTAADD(op1, op, op2): 
    885         return "(%s %s %s)" % (op1.sql, op, op2.sql) 
    886     TIMEDELTAADD = staticmethod(TIMEDELTAADD) 
    887      
    888     def builtins_now(self): 
    889         return self.get_expr("GETDATE()", datetime.datetime) 
    890      
    891     def builtins_today(self): 
    892         return self.get_expr("DATEADD(dd, DATEDIFF(dd, 0, getdate()), 0)", 
    893                              datetime.date) 
    894      
    895     def builtins_year(self, x): 
    896         return self.get_expr("DATEPART(year, " + x.sql + ")", int) 
    897      
    898     def builtins_month(self, x): 
    899         return self.get_expr("DATEPART(month, " + x.sql + ")", int) 
    900      
    901     def builtins_day(self, x): 
    902         return self.get_expr("DATEPART(day, " + x.sql + ")", int) 
    903      
    904     def builtins_utcnow(self): 
    905         return self.get_expr("GETUTCDATE()", datetime.datetime) 
    906  
    907  
    908 class AdapterToADOSQL_SQLServer(adapters.AdapterToSQL): 
    909      
    910     encoding = 'ISO-8859-1' 
    911      
    912     escapes = [("'", "''")] 
    913     like_escapes = [("[", "[[]"), ("%", "[%]"), ("_", "[_]"), 
    914                     ("?", "[?]"), ("#", "[#]")] 
    915      
    916     # These are not the same as coerce_bool_to_any (which is used on one side of  
    917     # a comparison). Instead, these are used when the whole (sub)expression 
    918     # is True or False, e.g. "WHERE TRUE", or "WHERE TRUE and 'a'.'b' = 3". 
    919     bool_true = "(1=1)" 
    920     bool_false = "(1=0)" 
    921      
    922     def coerce_bool_to_any(self, value): 
    923         if value: 
    924             return '1' 
    925         return '0' 
    926      
    927     def cast_VARCHAR_to_int(self, colref): 
    928         return ("(CASE WHEN ISNUMERIC(%s)=1 THEN CAST(%s AS int) END)" 
    929                 % (colref, colref)) 
    930  
    931  
    932 class AdapterFromADOSQL_SQLServer(AdapterFromADO): 
    933      
    934     def coerce_any_to_datetime_time(self, value): 
    935         # Floats returned from SQL Server will be 2 days off 
    936         # because its epoch is 2 days later than MS Access. 
    937         return AdapterFromADO.coerce_any_to_datetime_time(self, float(value) - 2) 
    938      
    939     def coerce_any_to_datetime_timedelta(self, value): 
    940         # We're using the fallback type for timedelta (secs * 86400). 
    941         days, secs = divmod(long(value), 86400) 
    942         return datetime.timedelta(int(days), int(secs)) 
    943  
    944  
    945 class TypeAdapter_SQLServer(ADOTypeAdapter): 
    946      
    947     # Hm. Docs say 38, but I can't seem to get more than 12 working. 
    948     # They must mean 38 binary digits; math.log(2 ** 38, 10) = 11.4+ 
    949     numeric_max_precision = 12 
    950     numeric_max_bytes = 6 
    951      
    952     def coerce_bool(self, hints): 
    953         return "BIT" 
    954      
    955     def coerce_datetime_datetime(self, hints): 
    956         return "DATETIME" 
    957      
    958     def coerce_datetime_date(self, hints): 
    959         return "DATETIME" 
    960      
    961     def coerce_datetime_time(self, hints): 
    962         return "DATETIME" 
    963      
    964     def int_type(self, bytes): 
    965         """Return a datatype which can handle the given number of bytes.""" 
    966         if bytes <= 2: 
    967             return "SMALLINT" 
    968         elif bytes <= 4: 
    969             return "INTEGER" 
    970         elif bytes <= 8: 
    971             # BIGINT is usually 8 bytes 
    972             return "BIGINT" 
    973         else: 
    974             # Anything larger than 8 bytes, use decimal/numeric. 
    975             return "NUMERIC(%s, 0)" % (bytes * 2) 
    976      
    977     def coerce_str(self, hints): 
    978         # The bytes hint does not reflect the usual 4-byte base for varchar. 
    979         bytes = int(hints.get('bytes', 255)) 
    980          
    981         if bytes == 0 or bytes > 8000: 
    982             # Okay, what the @#$%& is wrong with Redmond??!?! We can't even 
    983             # compare TEXT or NTEXT fields??!? Fine. We'll deny such, and 
    984             # warn the deployer with less swearing and exclamation points. 
    985             errors.warn("You have defined a string property without " 
    986                         "limiting its length. Microsoft SQL Server does " 
    987                         "not allow comparisons on string fields larger " 
    988                         "than 8000 characters. Some of your data may be " 
    989                         "truncated.") 
    990             bytes = 8000 
    991          
    992         # 8000 *bytes* is the absolute upper limit, based on T_SQL docs for 
    993         # varchar/varbinary. If there are further fields defined for the 
    994         # class, or the codepage uses a double-byte character set, we still 
    995         # might exceed the max size (8060) for a record. We could calc the 
    996         # total requested record size, and adjust accordingly. Meh. 
    997         return "VARCHAR(%s)" % bytes 
    998  
    999  
    1000 class SQLServerTable(ADOTable): 
    1001      
    1002     def _rename(self, oldcol, newcol): 
    1003         self.schema.db.execute_ddl("EXEC sp_rename '%s.%s', '%s', 'COLUMN'" % 
    1004                                    (self.name, oldcol.name, newcol.name)) 
    1005      
    1006     def _grab_new_ids(self, idkeys, conn): 
    1007         """Insert a row using the table's SERIAL field.""" 
    1008         # For some reason, using SCOPE_IDENTITY or IDENTITY failed (returned 
    1009         # None) when retrieving ID's just after a 99-thread-test ran. Moving 
    1010         # the multithreading test fixed it. IDENT_CURRENT worked regardless. 
    1011         data, _ = self.schema.db.fetch("SELECT IDENT_CURRENT('%s');" 
    1012                                        % self.qname, conn) 
    1013         return {idkeys[0]: data[0][0]} 
    1014  
    1015  
    1016 class SQLServerConnectionManager(ADOConnectionManager): 
    1017      
    1018     default_isolation = "READ COMMITTED" 
    1019  
    1020  
    1021 class SQLServerSchema(ADOSchema): 
    1022      
    1023     tableclass = SQLServerTable 
    1024      
    1025     def create_database(self): 
    1026         conn = self.db.connections._get_conn(master=True) 
    1027         self.db.execute_ddl("CREATE DATABASE %s;" % self.qname, conn) 
    1028         conn.Close() 
    1029         self.clear() 
    1030      
    1031     def drop_database(self): 
    1032         conn = self.db.connections._get_conn(master=True) 
    1033         self.db.execute_ddl("DROP DATABASE %s;" % self.qname, conn) 
    1034         conn.Close() 
    1035         self.clear() 
    1036      
    1037     def columnclause(self, column): 
    1038         """Return a clause for the given column for CREATE or ALTER TABLE. 
    1039          
    1040         This will be of the form: 
    1041             name type [DEFAULT x|IDENTITY(initial, 1) NOT NULL] 
    1042         """ 
    1043         dbtype = column.dbtype 
    1044          
    1045         clause = "" 
    1046         if column.autoincrement: 
    1047             if dbtype not in ("BOOLEAN", "SMALLINT", "INTEGER", "BIGINT"): 
    1048                 raise ValueError("SQL Server does not allow IDENTITY " 
    1049                                  "columns of type %r" % dbtype) 
    1050             clause = " IDENTITY(%s, 1) NOT NULL" % column.initial 
    1051         else: 
    1052             # SQL Server does not allow a column to have 
    1053             # both an IDENTITY clause and a DEFAULT clause. 
    1054             default = column.default or "" 
    1055             if default: 
    1056                 clause = self.db.adaptertosql.coerce(default, dbtype) 
    1057                 clause = " DEFAULT %s" % clause 
    1058          
    1059         return '%s %s%s' % (column.qname, dbtype, clause) 
    1060  
    1061  
    1062 class SQLServerDatabase(ADODatabase): 
    1063      
    1064     decompiler = ADOSQLDecompiler_SQLServer 
    1065     adaptertosql = AdapterToADOSQL_SQLServer() 
    1066     adapterfromdb = AdapterFromADOSQL_SQLServer() 
    1067     typeadapter = TypeAdapter_SQLServer() 
    1068     connectionmanager = SQLServerConnectionManager 
    1069     schemaclass = SQLServerSchema 
    1070      
    1071     def __init__(self, **kwargs): 
    1072         ADODatabase.__init__(self, **kwargs) 
    1073         if "2005" in self.version(): 
    1074             self.connections.isolation_levels.append("SNAPSHOT") 
    1075      
    1076     def version(self): 
    1077         conn = self.connections._get_conn(master=True) 
    1078         adov = conn.Version 
    1079         data, coldefs = self.fetch("SELECT @@VERSION;", conn) 
    1080         sqlv, = data[0] 
    1081         conn.Close() 
    1082         del conn 
    1083         return "ADO Version: %s\n%s" % (adov, sqlv) 
    1084      
    1085     def is_timeout_error(self, exc): 
    1086         """If the given exception instance is a lock timeout, return True. 
    1087          
    1088         This should return True for errors which arise from transaction 
    1089         locking timeouts; for example, if the database prevents 'dirty 
    1090         reads' by raising an error. 
    1091         """ 
    1092         # com_error: (-2147352567, 'Exception occurred.', 
    1093         #   (0, 'Microsoft OLE DB Provider for SQL Server', 
    1094         #    'Timeout expired', None, 0, -2147217871), None, 
    1095         #    "UPDATE [testVet] SET [City] = 'Tehachapi' ... ;") 
    1096         if not isinstance(exc, pywintypes.com_error): 
    1097             return False 
    1098         return exc.args[2][5] == -2147217871 
    1099  
    1100  
    1101  
    1102 ########################################################################### 
    1103 ##                                                                       ## 
    1104 ##                             MS Access                                 ## 
    1105 ##                                                                       ## 
    1106 ########################################################################### 
    1107  
    1108  
    1109 class ADOSQLDecompiler_MSAccess(ADOSQLDecompiler): 
    1110     sql_cmp_op = ('<', '<=', '=', '<>', '>', '>=', 'in', 'not in') 
    1111      
    1112     epoch = datetime.datetime(1899, 12, 30) 
    1113      
    1114     def _compare_strings(self, op1, op, op2): 
    1115         # ADO comparison operators for strings are case-insensitive. 
    1116         if op < 6: 
    1117             # ('<', '<=', '==', '!=', '>', '>=') 
    1118             # Some operations on strings can be emulated with the 
    1119             # StrComp function. Oddly enough, "StrComp(x, y) op 0" 
    1120             # is the same as "x op y" in most cases. 
    1121             return self.get_expr("StrComp(%s, %s) %s 0" % 
    1122                                  (op1.sql, op2.sql, self.sql_cmp_op[op]), 
    1123                                  bool) 
    1124         else: 
    1125             return ADOSQLDecompiler._compare_strings(self, op1, op, op2) 
    1126      
    1127     def builtins_now(self): 
    1128         return self.get_expr("Now()", datetime.datetime) 
    1129      
    1130     def builtins_today(self): 
    1131         return self.get_expr("DateValue(Now())", datetime.date) 
    1132      
    1133     def builtins_year(self, x): 
    1134         return self.get_expr("Year(" + x.sql + ")", int) 
    1135      
    1136     def builtins_month(self, x): 
    1137         return self.get_expr("Month(" + x.sql + ")", int) 
    1138      
    1139     def builtins_day(self, x): 
    1140         return self.get_expr("Day(" + x.sql + ")", int) 
    1141      
    1142     def DATEADD(dt, td): 
    1143         """Return the SQL to add a timedelta to a date.""" 
    1144         # Important to use Fix (instead of CLng, for example) 
    1145         # for negative numbers. 
    1146         return "DateAdd('d', Fix(%s), %s)" % (td, dt) 
    1147     DATEADD = staticmethod(DATEADD) 
    1148      
    1149     def DATEDIFF(d1, d2): 
    1150         """Return the SQL to subtract one date from another.""" 
    1151         # Important to use Fix (instead of CLng, for example) 
    1152         # for negative numbers. 
    1153         return "CDate(Fix(%s) - Fix(%s))" % (d1, d2) 
    1154     DATEDIFF = staticmethod(DATEDIFF) 
    1155     DATESUB = DATEDIFF 
    1156      
    1157     def DATETIMEADD(dt, td): 
    1158         """Return the SQL to add a timedelta to a datetime.""" 
    1159         return "CDate(%s + %s)" % (dt, td) 
    1160     DATETIMEADD = staticmethod(DATETIMEADD) 
    1161      
    1162     def DATETIMEDIFF(d1, d2): 
    1163         """Return the SQL to subtract one (datetime or date expr) from another.""" 
    1164         return "CDate(%s - %s)" % (d1, d2) 
    1165     DATETIMEDIFF = staticmethod(DATETIMEDIFF) 
    1166     DATETIMESUB = DATETIMEDIFF 
    1167      
    1168     def TIMEDELTAADD(op1, op, op2): 
    1169         return "CDate(%s %s %s)" % (op1.sql, op, op2.sql) 
    1170     TIMEDELTAADD = staticmethod(TIMEDELTAADD) 
    1171  
    1172  
    1173 class TypeAdapter_MSAccess(ADOTypeAdapter): 
    1174     # http://msdn2.microsoft.com/en-us/library/ms714540.aspx 
    1175     # http://office.microsoft.com/en-us/access/HP010322481033.aspx 
    1176      
    1177     # Hm. Docs say 28/38, but I can't seem to get more than 12 working. 
    1178     numeric_max_precision = 12 
    1179     numeric_max_bytes = 6 
    1180      
    1181     _reverse_types = ADOTypeAdapter._reverse_types.copy() 
    1182     _reverse_types.update({ 
    1183         "LONG": int, 
    1184         "MEMO": str, 
    1185         }) 
    1186      
    1187     def coerce_bool(self, hints): return "BIT" 
    1188      
    1189     def coerce_datetime_datetime(self, hints): return "DATETIME" 
    1190     def coerce_datetime_date(self, hints): return "DATETIME" 
    1191     def coerce_datetime_time(self, hints): return "DATETIME" 
    1192     def coerce_datetime_timedelta(self, hints): return "DATETIME" 
    1193      
    1194     def int_type(self, bytes): 
    1195         if bytes <= 2: 
    1196             return "INTEGER" 
    1197         elif bytes <= 4: 
    1198             return "LONG" 
    1199         else: 
    1200             # Anything larger than 4 bytes, use decimal/numeric. 
    1201             return "DECIMAL" 
    1202      
    1203     def coerce_str(self, hints): 
    1204         # The bytes hint shall not reflect the usual 4-byte base for varchar. 
    1205         bytes = int(hints.get('bytes', 255)) 
    1206          
    1207         # 255 chars is the upper limit for TEXT / VARCHAR in MS Access. 
    1208         if bytes == 0 or bytes > 255: 
    1209             # MEMO is 1 GB max when set programatically (only 64K when set 
    1210             # in Access UI). But then, 1 GB is the limit for the whole DB. 
    1211             # Note that OpenSchema will return a DATA_TYPE of "WCHAR". 
    1212             return "MEMO" 
    1213          
    1214         return "VARCHAR(%s)" % bytes 
    1215  
    1216  
    1217 class AdapterToADOSQL_MSAccess(adapters.AdapterToSQL): 
    1218     """Coerce Expression constants to ADO SQL.""" 
    1219      
    1220     encoding = 'ISO-8859-1' 
    1221      
    1222     escapes = [("'", "''")] 
    1223     like_escapes = [("[", "[[]"), ("%", "[%]"), ("_", "[_]"), 
    1224                     ("?", "[?]"), ("#", "[#]")] 
    1225      
    1226     def coerce_datetime_datetime_to_any(self, value): 
    1227         return ('#%s/%s/%s %02d:%02d:%02d#' % 
    1228                 (value.month, value.day, value.year, 
    1229                  value.hour, value.minute, value.second)) 
    1230      
    1231     def coerce_datetime_date_to_any(self, value): 
    1232         return '#%s/%s/%s#' % (value.month, value.day, value.year) 
    1233      
    1234     def coerce_datetime_time_to_any(self, value): 
    1235         return '#%02d:%02d:%02d#' % (value.hour, value.minute, value.second) 
    1236      
    1237     def coerce_datetime_timedelta_to_any(self, value): 
    1238         # This took a lot of work to get right, because timedelta 
    1239         # seconds are positive even if the days are negative. 
    1240         # So is the fractional portion of a negative Access Date! 
    1241         # Very important we use repr here so we get all 17 decimal 
    1242         # digits in the float. 
    1243         return ("CDate(#12/30/1899# + (%r) + %r)" % 
    1244                 (value.days, (value.seconds / 86400.0))) 
    1245  
    1246  
    1247 class MSAccessTable(ADOTable): 
    1248      
    1249     def delete(self, **inputs): 
    1250         """Delete all rows matching the given identifier inputs.""" 
    1251         # MS Access needs an asterisk to delete 
    1252         self.schema.db.execute('DELETE * FROM %s WHERE %s;' % 
    1253                                (self.qname, self.id_clause(**inputs))) 
    1254      
    1255     def delete_all(self, **inputs): 
    1256         """Delete all rows matching the given inputs.""" 
    1257         # MS Access needs an asterisk to delete 
    1258         self.schema.db.execute('DELETE * FROM %s WHERE %s;' % 
    1259                                (self.qname, self.whereclause(**inputs))) 
    1260      
    1261     def _grab_new_ids(self, idkeys, conn): 
    1262         data, _ = self.schema.db.fetch("SELECT @@IDENTITY;", conn) 
    1263         return {idkeys[0]: data[0][0]} 
    1264  
    1265  
    1266 class MSAccessConnectionManager(ADOConnectionManager): 
    1267      
    1268     poolsize = 0 
    1269     default_isolation = "READ UNCOMMITTED" 
    1270     isolation_levels = ["READ UNCOMMITTED",] 
    1271      
    1272     def _set_factory(self): 
    1273         # MS Access can't use a pool, because there doesn't seem 
    1274         # to be a commit timeout. See http://support.microsoft.com/kb/200300 
    1275         # for additional synchronization issues. 
    1276         self._factory = conns.SingleConnection(self._get_conn, self._del_conn) 
    1277      
    1278     def isolate(self, conn, isolation=None): 
    1279         """Set the isolation level of the given connection. 
    1280          
    1281         If 'isolation' is None, our default_isolation will be used for new 
    1282         connections. Valid values for the 'isolation' argument may be native 
    1283         values for your particular database. However, it is recommended you 
    1284         pass items from the global 'levels' list instead; these will be 
    1285         automatically replaced with native values. 
    1286          
    1287         For many databases, this must be executed after START TRANSACTION. 
    1288         """ 
    1289         if isolation is None: 
    1290             isolation = self.default_isolation 
    1291          
    1292         if isinstance(isolation, _isolation.IsolationLevel): 
    1293             # Map the given IsolationLevel object to a native value. 
    1294             isolation = isolation.name 
    1295             if isolation not in self.isolation_levels: 
    1296                 raise ValueError("IsolationLevel %r not allowed by %s." 
    1297                                  % (isolation, self.__class__.__name__)) 
    1298          
    1299         # No action to take, since you can't actually set iso level. 
    1300         pass 
    1301  
    1302  
    1303 class MSAccessSchema(ADOSchema): 
    1304      
    1305     tableclass = MSAccessTable 
    1306      
    1307     def _get_columns(self, tablename, conn=None): 
    1308         cols = ADOSchema._get_columns(self, tablename, conn) 
    1309         if conn is None: 
    1310             conn = self.db.connections._factory() 
    1311          
    1312         try: 
    1313             # Horrible hack to get autoincrement property 
    1314             query = "SELECT * FROM %s WHERE FALSE;" % self.db.quote(tablename) 
    1315             bareconn = conn 
    1316             if hasattr(conn, 'conn'): 
    1317                 # 'conn' is a ConnectionWrapper object, which .Open 
    1318                 # won't accept. Pass the unwrapped connection instead. 
    1319                 bareconn = conn.conn 
    1320              
    1321             # Call conn.Open(query) directly, skipping win32com overhead. 
    1322             res, rows_affected = conn._oleobj_.InvokeTypes(6, 0, 1, (9, 0), 
    1323                                             ((8, 1), (16396, 18), (3, 49)), 
    1324                                             # *args = 
    1325                                             query, pythoncom.Missing, -1) 
    1326         except pywintypes.com_error, x: 
    1327             try: 
    1328                 res.InvokeTypes(*Recordset_Close) 
    1329             except: 
    1330                 pass 
    1331             res = None 
    1332             x.args += (query, ) 
    1333             conn = None 
    1334              
    1335             try: 
    1336                 if "no read permission" in x.args[2][2]: 
    1337                     conn = None 
    1338                     return [] 
    1339             except IndexError: 
    1340                 pass 
    1341              
    1342             # "raise x" here or we could get the traceback of the inner try. 
    1343             raise x 
    1344          
    1345         resFields = res.InvokeTypes(*Recordset_Fields) 
    1346         for c in cols: 
    1347             f = resFields.InvokeTypes(0, 0, 2, (9, 0), ((12, 1),), c.name) 
    1348             fprops = f.InvokeTypes(*Field_Properties) 
    1349             fprop = fprops.InvokeTypes(0, 0, 2, (9, 0), ((12, 1), ), "ISAUTOINCREMENT") 
    1350             c.autoincrement = fprop.InvokeTypes(*Property_Value) 
    1351          
    1352         try: 
    1353             res.InvokeTypes(*Recordset_Close) 
    1354         except: 
    1355             pass 
    1356         conn = None 
    1357          
    1358         return cols 
    1359      
    1360     def columnclause(self, column): 
    1361         """Return a clause for the given column for CREATE or ALTER TABLE. 
    1362          
    1363         This will be of the form: 
    1364             name type [DEFAULT x|AUTOINCREMENT(initial, 1)] 
    1365         """ 
    1366         dbtype = column.dbtype 
    1367          
    1368         if column.autoincrement: 
    1369             # MS Access does not allow a column to have 
    1370             # both an AUTOINCREMENT clause and a DEFAULT clause. 
    1371             # It also needs no type in this case. 
    1372             dbtype = "AUTOINCREMENT(%s, 1)" % column.initial 
    1373         else: 
    1374             default = column.default or "" 
    1375             if default: 
    1376                 defspec = self.db.adaptertosql.coerce(default, dbtype) 
    1377                 if isinstance(default, (int, long)): 
    1378                     # Crazy quote hack to get a numeric default to work. 
    1379                     defspec = "'%s'" % defspec 
    1380                 dbtype = "%s DEFAULT %s" % (dbtype, defspec) 
    1381          
    1382         return '%s %s' % (column.qname, dbtype) 
    1383      
    1384     def create_database(self): 
    1385         # By not providing an Engine Type, it defaults to 5 = Access 2000. 
    1386         cat = win32com.client.Dispatch(r'ADOX.Catalog') 
    1387         cat.Create(self.db.connections.Connect) 
    1388         cat.ActiveConnection.Close() 
    1389         self.clear() 
    1390      
    1391     def drop_database(self): 
    1392         # Must shut down our only connection to avoid 
    1393         # "Permission denied" error on os.remove call below. 
    1394         self.db.connections.shutdown() 
    1395          
    1396         import os 
    1397         # This should accept relative or absolute paths 
    1398         if os.path.exists(self.name): 
    1399             os.remove(self.name) 
    1400          
    1401         self.clear() 
    1402  
    1403  
    1404 class MSAccessDatabase(ADODatabase): 
    1405      
    1406     decompiler = ADOSQLDecompiler_MSAccess 
    1407     adaptertosql = AdapterToADOSQL_MSAccess() 
    1408     typeadapter = TypeAdapter_MSAccess() 
    1409     connectionmanager = MSAccessConnectionManager 
    1410     schemaclass = MSAccessSchema 
    1411      
    1412     def version(self): 
    1413         conn = win32com.client.Dispatch(r'ADODB.Connection') 
    1414         v = conn.Version 
    1415         del conn 
    1416         return "ADO Version: %s" % v 
    1417  
    1418  
    1419801def gen_py(): 
    1420802    """Auto generate .py support for ADO 2.7+""" 
  • trunk/geniusql/providers/firebird.py

    r53 r54  
    2121    # Values in these two lists should be strings encoded with self.encoding. 
    2222    escapes = [("'", "''")] 
    23     like_escapes = [("\\", r"\\"), ("%", r"\%"), ("_", r"\_")] 
    24      
    25     # Firebird doesn't have true or false keywords. 
    26     bool_true = "1=1" 
    27     bool_false = "1=0" 
    2823     
    2924    def coerce_bool_to_any(self, value): 
     
    182177class FirebirdSQLDecompiler(decompile.SQLDecompiler): 
    183178     
     179    # Firebird doesn't have true or false keywords. 
     180    bool_true = "1=1" 
     181    bool_false = "1=0" 
     182     
     183    like_escapes = [("\\", r"\\"), ("%", r"\%"), ("_", r"\_")] 
     184     
    184185    # --------------------------- Dispatchees --------------------------- # 
    185186     
     
    189190    def attr_endswith(self, tos, arg): 
    190191        return self.get_expr(tos.sql + " LIKE '%" + 
    191                              self.adapter.escape_like(arg.sql) + 
     192                             self.escape_like(arg.sql) + 
    192193                             "' ESCAPE '\\'", bool) 
    193194     
     
    195196        if op1.value is not None: 
    196197            # Looking for text in a field. Use Like (reverse terms). 
    197             like = self.adapter.escape_like(op1.sql) 
     198            like = self.escape_like(op1.sql) 
    198199            return self.get_expr(op2.sql + " LIKE '%" + like + 
    199200                                 "%' ESCAPE '\\'", bool) 
     
    222223    def builtins_istartswith(self, x, y): 
    223224        return self.get_expr("UPPER(" + x.sql + ") LIKE '" + 
    224                              self.adapter.escape_like(y.sql) + 
     225                             self.escape_like(y.sql) + 
    225226                             "%' ESCAPE '\\'", bool) 
    226227     
    227228    def builtins_iendswith(self, x, y): 
    228229        return self.get_expr("UPPER(" + x.sql + ") LIKE '%" + 
    229                              self.adapter.escape_like(y.sql) + 
     230                             self.escape_like(y.sql) + 
    230231                             "' ESCAPE '\\'", bool) 
    231232     
  • trunk/geniusql/providers/mysql.py

    r53 r54  
    2323     
    2424    escapes = [("'", "''"), ("\\", r"\\")] 
    25     like_escapes = [("%", r"\%"), ("_", r"\_")] 
    26      
    27     # TRUE and FALSE only work with 4.1 or better. 
    28     bool_true = "1" 
    29     bool_false = "0" 
    3025     
    3126    def coerce_str_to_any(self, value, skip_encoding=False): 
     
    5146 
    5247class MySQLDecompiler(decompile.SQLDecompiler): 
     48     
     49    # TRUE and FALSE only work with 4.1 or better. 
     50    bool_true = "1" 
     51    bool_false = "0" 
     52     
     53    like_escapes = [("%", r"\%"), ("_", r"\_")] 
    5354     
    5455    def DAY_SECOND(td): 
     
    117118 
    118119class MySQLDecompiler411(MySQLDecompiler): 
     120     
     121    # TRUE and FALSE only work with 4.1 or better. 
     122    bool_true = "TRUE" 
     123    bool_false = "FALSE" 
     124     
    119125    # Before MySQL 4.1.1, BINARY comparisons could use UPPER() 
    120126    # or LOWER() to perform case-insensitive comparisons. Newer 
     
    127133            return self.get_expr("CONVERT(" + op2.sql + 
    128134                                 " USING utf8) LIKE '%" + 
    129                                  self.adapter.escape_like(op1.sql) 
     135                                 self.escape_like(op1.sql) 
    130136                                 + "%'", bool) 
    131137        else: 
     
    138144    def builtins_istartswith(self, x, y): 
    139145        return self.get_expr("CONVERT(" + x.sql + " USING utf8) LIKE '" + 
    140                              self.adapter.escape_like(y.sql) + "%'", bool) 
     146                             self.escape_like(y.sql) + "%'", bool) 
    141147     
    142148    def builtins_iendswith(self, x, y): 
    143149        return self.get_expr("CONVERT(" + x.sql + " USING utf8) LIKE '%" + 
    144                              self.adapter.escape_like(y.sql) + "'", bool) 
     150                             self.escape_like(y.sql) + "'", bool) 
    145151     
    146152    def builtins_ieq(self, x, y): 
  • trunk/geniusql/providers/postgres.py

    r53 r54  
    1616 
    1717 
    18 class AdapterToPgSQL(adapters.AdapterToSQL): 
    19      
    20     like_escapes = [("%", r"\\%"), ("_", r"\\_")] 
    21      
    22     # Do these need to know if "SHOW DateStyle;" != "ISO, MDY" ? 
    23     def coerce_datetime_datetime_to_any(self, value): 
    24         return ("'%04d-%02d-%02d %02d:%02d:%02d.%06d'" % 
    25                 (value.year, value.month, value.day, 
    26                  value.hour, value.minute, value.second, 
    27                  value.microsecond)) 
    28      
    29     def coerce_datetime_date_to_any(self, value): 
    30         return "'%04d-%02d-%02d'" % (value.year, value.month, value.day) 
    31      
    32     def coerce_datetime_time_to_any(self, value): 
    33         return ("'%02d:%02d:%02d.%06d'" % 
    34                 (value.hour, value.minute, value.second, value.microsecond)) 
    35      
    36     def coerce_datetime_timedelta_to_interval(self, value): 
     18class PgTIMESTAMP(adapters.SQL92TIMESTAMP): 
     19    def binary_op(self, op1, op, op2): 
     20        sql1 = op1.sql 
     21        if op1.value is not None and op2.pytype is datetime.date: 
     22            # Postgres assumes a "date" is actually midnight, so we 
     23            # need to drop any h:m:s from our interval. 
     24            sql1 = "interval '%s days'" % op1.value.days 
     25        return "(%s %s %s)" % (sql1, op, op2.sql) 
     26 
     27 
     28class PgDATE(adapters.SQL92DATE): 
     29     
     30    def binary_op(self, op1, op, op2): 
     31        sql1, sql2 = op1.sql, op2.sql 
     32        if op2.pytype is datetime.timedelta and op2.value is not None: 
     33            # Postgres assumes a "date" is actually midnight, so we 
     34            # need to drop any h:m:s from our interval. 
     35            sql2 = "interval '%s days'" % op2.value.days 
     36        elif op2.pytype is datetime.date: 
     37            # Cast to timestamp to achieve an INTERVAL result 
     38            sql1 = "%s::TIMESTAMP" % sql1 
     39            sql2 = "%s::TIMESTAMP" % sql2 
     40        return "(%s %s %s)" % (sql1, op, sql2) 
     41 
     42 
     43class INTERVAL(adapters.Adapter): 
     44     
     45    def push(self, value): 
    3746        # Ignore microseconds for now 
    3847        h, m = divmod(value.seconds, 3600) 
     
    4049        return "interval '%s %s:%s:%s'" % (value.days, h, m, s) 
    4150     
    42     def coerce_any_to_bytea(self, value): 
    43         # See http://www.postgresql.org/docs/8.1/interactive/datatype-binary.html 
    44         value = pickle.dumps(value, 2) 
    45         def repl(char): 
    46             o = ord(char) 
    47             if o <= 31 or o == 39 or o == 92 or o >= 127: 
    48                 return r"\\%03d" % int(oct(o)) 
    49             return char 
    50         return "'%s'::bytea" % "".join(map(repl, value)) 
    51      
    52     def do_pickle(self, value): 
    53         value = pickle.dumps(value, 2) 
    54         value = self.coerce_str_to_any(value, skip_encoding=False) 
    55         return value 
    56     coerce_dict_to_any = do_pickle 
    57     coerce_list_to_any = do_pickle 
    58     coerce_tuple_to_any = do_pickle 
    59      
    60     def coerce_str_to_any(self, value, skip_encoding=False): 
    61         if not skip_encoding and not isinstance(value, str): 
    62             value = value.encode(self.encoding) 
    63         for pat, repl in self.escapes: 
    64             value = value.replace(pat, repl) 
    65          
    66         # Escape octal sequences 
    67         value = escape_oct.sub(replace_oct, value) 
    68         return "'" + value + "'" 
    69      
    70     def coerce_float_to_REAL(self, value): 
    71         # Use quotes to restrict the value to single precision, so that 
    72         # comparisons work between existing values and supplied constants. 
    73         # See http://archives.postgresql.org/pgsql-bugs/2004-02/msg00062.php 
    74         return "'%r'" % value 
    75     coerce_float_to_FLOAT4 = coerce_float_to_REAL 
    76  
    77  
    78 class AdapterFromPg(adapters.AdapterFromDB): 
    79      
    80     def coerce_any_to_str(self, value): 
    81         # Unescape octal sequences 
    82         value = unescape_oct.sub(replace_unoct, value) 
    83         if isinstance(value, unicode): 
    84             return value.encode(self.encoding) 
    85         else: 
    86             return str(value) 
    87      
    88     def coerce_any_to_datetime_timedelta(self, value): 
     51    def pull(self, value): 
     52        if isinstance(value, datetime.timedelta): 
     53            return value 
     54         
    8955        # When an interval is returned, it will be of typename 
    9056        # "interval" or "TIMESTAMP". 
     
    12187 
    12288 
     89class BYTEA(adapters.Pickler): 
     90     
     91    def push(self, value): 
     92        # See http://www.postgresql.org/docs/8.1/interactive/datatype-binary.html 
     93        value = pickle.dumps(value, 2) 
     94        def repl(char): 
     95            o = ord(char) 
     96            if o <= 31 or o == 39 or o == 92 or o >= 127: 
     97                return r"\\%03d" % int(oct(o)) 
     98            return char 
     99        return "'%s'::bytea" % "".join(map(repl, value)) 
     100 
     101 
     102class PgVARCHAR(adapters.SQL92VARCHAR): 
     103     
     104    def push(self, value): 
     105        if not isinstance(value, str): 
     106            value = value.encode(self.encoding) 
     107        for pat, repl in self.escapes: 
     108            value = value.replace(pat, repl) 
     109         
     110        # Escape octal sequences 
     111        value = escape_oct.sub(replace_oct, value) 
     112        return "'" + value + "'" 
     113     
     114    def pull(self, value): 
     115        # Unescape octal sequences 
     116        value = unescape_oct.sub(replace_unoct, value) 
     117        if isinstance(value, unicode): 
     118            return value.encode(self.encoding) 
     119        else: 
     120            return str(value) 
     121 
     122 
     123class PgUNICODE(adapters.UNICODE): 
     124     
     125    def push(self, value): 
     126        if not isinstance(value, str): 
     127            value = value.encode(self.encoding) 
     128        for pat, repl in self.escapes: 
     129            value = value.replace(pat, repl) 
     130         
     131        # Escape octal sequences 
     132        value = escape_oct.sub(replace_oct, value) 
     133        return "'" + value + "'" 
     134     
     135    def pull(self, value): 
     136        # Unescape octal sequences 
     137        value = unescape_oct.sub(replace_unoct, value) 
     138        if isinstance(value, unicode): 
     139            return value 
     140        else: 
     141            return unicode(value, self.encoding) 
     142 
     143 
     144class Pickler(adapters.Pickler): 
     145     
     146    def push(self, value): 
     147        value = pickle.dumps(value, 2) 
     148         
     149        if not isinstance(value, str): 
     150            value = value.encode(self.encoding) 
     151        for pat, repl in self.escapes: 
     152            value = value.replace(pat, repl) 
     153         
     154        # Escape octal sequences 
     155        value = escape_oct.sub(replace_oct, value) 
     156        return "'" + value + "'" 
     157     
     158    def pull(self, value): 
     159        # Unescape octal sequences 
     160        value = unescape_oct.sub(replace_unoct, value) 
     161         
     162        # Coerce to str for pickle.loads restriction. 
     163        if isinstance(value, unicode): 
     164            value = value.encode(self.encoding) 
     165        else: 
     166            value = str(value) 
     167        return pickle.loads(value) 
     168 
     169 
     170 
     171class REAL(adapters.SQL92REAL): 
     172     
     173    def push(self, value): 
     174        # Use quotes to restrict the value to single precision, so that 
     175        # comparisons work between existing values and supplied constants. 
     176        # See http://archives.postgresql.org/pgsql-bugs/2004-02/msg00062.php 
     177        return "'%r'" % value 
     178 
     179 
     180class PgAdapterSet(adapters.AdapterSet): 
     181     
     182    def pytype_for_TIMESTAMPTZ(self, hints): 
     183        return datetime.datetime 
     184    def pytype_for_TIMETZ(self, hints): 
     185        return datetime.time 
     186    def pytype_for_INT2(self, hints): 
     187        return int 
     188    def pytype_for_INT4(self, hints): 
     189        return int 
     190    def pytype_for_INT8(self, hints): 
     191        return long 
     192    def pytype_for_FLOAT4(self, hints): 
     193        return float 
     194    def pytype_for_FLOAT8(self, hints): 
     195        return float 
     196    def pytype_for_MONEY(self, hints): 
     197        return float 
     198    def pytype_for_BYTEA(self, hints): 
     199        return str 
     200    def pytype_for_BPCHAR(self, hints): 
     201        return str 
     202    def pytype_for_INTERVAL(self, hints): 
     203        return datetime.timedelta 
     204     
     205    # Postgres has a wonderful interval type we can use. 
     206    def dbtype_for_datetime_timedelta(self, hints): 
     207        return "INTERVAL" 
     208 
     209 
     210PgAdapterSet.defaults.update({(str, "BYTEA"): BYTEA, 
     211                              (float, "REAL"): REAL, 
     212                              (float, "FLOAT4"): REAL, 
     213                              (str, "any"): PgVARCHAR, 
     214                              (unicode, "any"): PgUNICODE, 
     215                              (datetime.datetime, "any"): PgTIMESTAMP, 
     216                              (datetime.date, "any"): PgDATE, 
     217                              (datetime.timedelta, "INTERVAL"): INTERVAL, 
     218                              }) 
     219for k, v in PgAdapterSet.defaults.items(): 
     220    if issubclass(v, adapters.Pickler): 
     221        PgAdapterSet.defaults[k] = Pickler 
     222del k, v 
     223 
     224 
    123225class PgDecompiler(decompile.SQLDecompiler): 
     226     
     227    like_escapes = [("%", r"\\%"), ("_", r"\\_")] 
    124228     
    125229    def builtins_icontainedby(self, op1, op2): 
     
    127231            # Looking for text in a field. Use ILike (reverse terms). 
    128232            return self.get_expr(op2.sql + " ILIKE '%" + 
    129                                  self.adapter.escape_like(op1.sql) + "%'", 
     233                                 self.escape_like(op1.sql) + "%'", 
    130234                                 bool) 
    131235        else: 
    132236            # Looking for field in (a, b, c). 
    133237            # Force all args to lowercase for case-insensitive comparison. 
    134             atoms = [self.adapter.coerce(x).lower() for x in op2.value] 
     238            atoms = [self.adapterset.default(type(x), op1.dbtype).push(x).lower() 
     239                     for x in op2.value] 
    135240            return self.get_expr("LOWER(%s) IN (%s)" % 
    136241                                 (op1.sql, ", ".join(atoms)), bool) 
     
    138243    def builtins_istartswith(self, x, y): 
    139244        return self.get_expr(x.sql + " ILIKE '" + 
    140                              self.adapter.escape_like(y.sql) + "%'", bool) 
     245                             self.escape_like(y.sql) + "%'", bool) 
    141246     
    142247    def builtins_iendswith(self, x, y): 
    143248        return self.get_expr(x.sql + " ILIKE '%" + 
    144                              self.adapter.escape_like(y.sql) + "'", bool) 
     249                             self.escape_like(y.sql) + "'", bool) 
    145250     
    146251    def builtins_ieq(self, x, y): 
    147252        # ILIKE with no wildcards should behave like ieq. 
    148253        return self.get_expr(x.sql + " ILIKE '" + 
    149                              self.adapter.escape_like(y.sql) + "'", bool) 
     254                             self.escape_like(y.sql) + "'", bool) 
    150255     
    151256    def builtins_year(self, x): 
     
    169274    def builtins_utcnow(self): 
    170275        return self.get_expr("NOW()", datetime.datetime) 
    171      
    172     def binary_op(self, op): 
    173         op2, op1 = self.stack.pop(), self.stack.pop() 
    174         if op1 is decompile.cannot_represent or op2 is decompile.cannot_represent: 
    175             self.stack.append(decompile.cannot_represent) 
    176             return 
    177          
    178         t1, t2 = op1.pytype, op2.pytype 
    179          
    180         sql1, sql2 = op1.sql, op2.sql 
    181         if t1 is datetime.timedelta and op1.value is not None: 
    182             if t2 is datetime.date: 
    183                 # Postgres assumes a "date" is actually midnight, so we 
    184                 # need to drop any h:m:s from our interval. 
    185                 sql1 = "interval '%s days'" % op1.value.days 
    186         if t2 is datetime.timedelta and op2.value is not None: 
    187             if t1 is datetime.date: 
    188                 # Postgres assumes a "date" is actually midnight, so we 
    189                 # need to drop any h:m:s from our interval. 
    190                 sql2 = "interval '%s days'" % op2.value.days 
    191         if t1 is datetime.date and t2 is datetime.date: 
    192             # Cast to timestamp to achieve an INTERVAL result 
    193             newsql = "(%s::TIMESTAMP %s %s::TIMESTAMP)" % (sql1, op, sql2) 
    194         else: 
    195             newsql = "(%s %s %s)" % (sql1, op, sql2) 
    196          
    197         # re-use op1 
    198         op1.pytype = self.result_type[(t1, op, t2)] 
    199         op1.sql = newsql 
    200         if not op1.name.startswith("expr_"): 
    201             op1.name = "expr_%s" % op1.name 
    202         self.stack.append(op1) 
    203276 
    204277 
     
    284357        data, _ = self.db.fetch(sql, conn=conn) 
    285358        cols = [] 
    286         pytype = self.db.typeadapter.python_type 
     359        adapterset = self.db.adapterset 
    287360        for row in data: 
    288361            name = row[0] 
     
    299372            else: 
    300373                dbtype = None 
    301             c = geniusql.Column(pytype(dbtype), dbtype, 
     374            c = geniusql.Column(adapterset.python_type(dbtype), dbtype, 
    302375                                None, {}, row[2] in indices, 
    303376                                row[0], self.db.quote(row[0])) 
     377            c.adapter = adapterset.default(c.pytype, dbtype) 
    304378             
    305379            if dbtype in ('FLOAT4', 'FLOAT8'): 
     
    326400                    # our guessed type. Be sure to strip any ::typename 
    327401                    default = default.split("::", 1)[0] 
    328                     c.default = pytype(dbtype)(default) 
     402                    c.default = c.adapter.pull(default) 
    329403            else: 
    330404                c.default = None 
     
    382456            default = column.default or "" 
    383457            if not isinstance(default, str): 
    384                 default = self.db.adaptertosql.coerce(default, column.dbtype
     458                default = column.adapter.push(default
    385459         
    386460        if default: 
     
    422496 
    423497 
    424 class PgTypeAdapter(adapters.TypeAdapter): 
    425      
    426     _reverse_types = adapters.TypeAdapter._reverse_types.copy() 
    427     _reverse_types.update({ 
    428         "TIMESTAMPTZ": datetime.datetime, 
    429         "TIMETZ": datetime.time, 
    430         "INT2": int, 
    431         "INT4": int, 
    432         "INT8": long, 
    433         "FLOAT4": float, 
    434         "FLOAT8": float, 
    435         "MONEY": float, 
    436         "BYTEA": str, 
    437         "BPCHAR": str, 
    438         "INTERVAL": datetime.timedelta, 
    439         }) 
    440      
    441     # Postgres has a wonderful interval type we can use. 
    442     def coerce_datetime_timedelta(self, hints): 
    443         return "interval" 
    444  
    445  
    446498class PgDatabase(geniusql.Database): 
    447499     
     
    451503    encoding = 'SQL_ASCII' 
    452504     
    453     adaptertosql = AdapterToPgSQL() 
    454     adapterfromdb = AdapterFromPg() 
    455     typeadapter = PgTypeAdapter() 
    456      
    457505    decompiler = PgDecompiler 
    458506    schemaclass = PgSchema 
     507    adapterset = PgAdapterSet() 
    459508     
    460509    def quote(self, name): 
  • trunk/geniusql/providers/psycopg.py

    r53 r54  
    77    from psycopg2 import _psycopg 
    88 
    9 import datetime 
    10  
    11 import geniusql 
    129from geniusql import conns, errors 
    1310from geniusql.providers import postgres 
    14  
    15  
    16 class AdapterFromPsycoPg(postgres.AdapterFromPg): 
    17      
    18     def coerce_any_to_datetime_datetime(self, value): 
    19         return value 
    20      
    21     def coerce_any_to_datetime_date(self, value): 
    22         if isinstance(value, datetime.datetime): 
    23             # Psycopg might do this when adding date + timedelta, for instance. 
    24             return value.date() 
    25         return value 
    26     coerce_DATE_to_datetime_date = coerce_any_to_datetime_date 
    27      
    28     def coerce_any_to_datetime_time(self, value): 
    29         return value 
    30      
    31     def coerce_any_to_datetime_timedelta(self, value): 
    32         return value 
    3311 
    3412 
     
    8260class PsycoPgDatabase(postgres.PgDatabase): 
    8361     
    84     adapterfromdb = AdapterFromPsycoPg() 
    8562    connectionmanager = PsycoPgConnectionManager 
    8663    schemaclass = PsycoPgSchema 
     
    9875            conn = self.connections.get() 
    9976        if isinstance(query, unicode): 
    100             query = query.encode(self.adaptertosql.encoding) 
     77            query = query.encode(self.adapterset.encoding) 
    10178        self.log(query) 
    10279        cursor = conn.cursor() 
     
    11188            conn = self.connections.get() 
    11289        if isinstance(query, unicode): 
    113             query = query.encode(self.adaptertosql.encoding) 
     90            query = query.encode(self.adapterset.encoding) 
    11491        self.log(query) 
    11592         
  • trunk/geniusql/providers/pypgsql.py

    r53 r54  
    22from pyPgSQL import libpq 
    33 
    4 import geniusql 
    5 from geniusql import conns, errors, typerefs 
     4from geniusql import conns, errors 
    65from geniusql.providers import postgres 
    76 
  • trunk/geniusql/providers/sqlite.py

    r53 r54  
    4444    # See http://www.sqlite.org/lang_expr.html 
    4545    escapes = [("'", "''")] 
    46     like_escapes = [("%", "\%"), ("_", "\_")] 
    47      
    48     bool_true = "1" 
    49     bool_false = "0" 
    5046     
    5147    def coerce_bool_to_any(self, value): 
     
    6763class SQLiteDecompiler(decompile.SQLDecompiler): 
    6864     
     65    bool_true = "1" 
     66    bool_false = "0" 
     67     
     68    like_escapes = [("%", "\%"), ("_", "\_")] 
     69     
    6970    def attr_startswith(self, tos, arg): 
    7071        if _escape_support: 
    7172            return self.get_expr(tos.sql + " LIKE '" + 
    72                                  self.adapter.escape_like(arg.sql) + 
     73                                 self.escape_like(arg.sql) + 
    7374                                 r"%' ESCAPE '\'", bool) 
    7475        else: 
     
    8283        if _escape_support: 
    8384            return self.get_expr(tos.sql + " LIKE '%" + 
    84                                  self.adapter.escape_like(arg.sql) + 
     85                                 self.escape_like(arg.sql) + 
    8586                                 r"' ESCAPE '\'", bool) 
    8687        else: 
     
    9697            if _escape_support: 
    9798                return self.get_expr(op2.sql + " LIKE '%" + 
    98                                      self.adapter.escape_like(op1.sql) + 
     99                                     self.escape_like(op1.sql) + 
    99100                                     r"%' ESCAPE '\'", bool) 
    100101            else: 
     
    115116            if _escape_support: 
    116117                return self.get_expr("LOWER(" + op2.sql + ") LIKE '%" + 
    117                                      self.adapter.escape_like(op1.sql).lower() + 
     118                                     self.escape_like(op1.sql).lower() + 
    118119                                     r"%' ESCAPE '\'", bool) 
    119120            else: 
     
    134135        if _escape_support: 
    135136            return self.get_expr("LOWER(" + x.sql + ") LIKE '" + 
    136                                  self.adapter.escape_like(y.sql) 
     137                                 self.escape_like(y.sql) 
    137138                                 + r"%' ESCAPE '\'", bool) 
    138139        else: 
     
    146147        if _escape_support: 
    147148            return self.get_expr("LOWER(" + x.sql + ") LIKE '%" + 
    148                                  self.adapter.escape_like(y.sql) 
     149                                 self.escape_like(y.sql) 
    149150                                 + r"%' ESCAPE '\'", bool) 
    150151        else: 
  • trunk/geniusql/select.py

    r53 r54  
    247247        """Return an SQL WHERE clause, and an 'imperfect' flag.""" 
    248248        tpairs = [(t.alias or t.qname, t.table) for t in self.tables] 
    249         decom = self.db.decompiler(tpairs, self.restriction, 
    250                                    self.db.adaptertosql, 
    251                                    self.db.typeadapter) 
     249        decom = self.db.decompiler(tpairs, self.restriction, self.db.adapterset) 
    252250        return decom.code(), decom.imperfect 
    253251     
     
    344342        tpairs = [(t.alias or t.qname, t.table) for t in self.tables] 
    345343        decom = self.db.decompiler 
    346         decom = decom(tpairs, self.attributes, self.db.adaptertosql, 
    347                       self.db.typeadapter) 
     344        decom = decom(tpairs, self.attributes, self.db.adapterset) 
     345##        decom.verbose = True 
    348346         
    349347        from geniusql import objects, decompile 
     
    368366            col = objects.Column(atom.pytype, atom.dbtype, name=atom.name) 
    369367            col.qname = self.db.quote(col.name) 
    370             col.imperfect_type = atom.imperfect_type 
     368            col.adapter = atom.adapter 
    371369            self.result[atom.name] = col 
    372370 
  • trunk/geniusql/test/test_sqlserver.py

    r53 r54  
    88                      "The SQL Server test will not be run.") 
    99else: 
    10     from geniusql.providers import ado 
     10    from geniusql.providers import sqlserver 
    1111    try: 
    12         ado.gen_py() 
     12        sqlserver.ado.gen_py() 
    1313    except ImportError: 
    1414        def run(): 
  • trunk/geniusql/test/zoo_fixture.py

    r53 r54  
    130130         
    131131        t = schema.table('NothingToDoWithZoos') 
    132         t['ALong'] = schema.column(long, hints={'precision': 1}) 
     132        t['ALong'] = schema.column(long, hints={'bytes': 1}) 
    133133        t['AFloat'] = schema.column(float, hints={'precision': 1}) 
    134134        if typerefs.decimal: 
     
    340340        self.assertEqual(matches(lambda x: x.Species.startswith('L')), 2) 
    341341        self.assertEqual(matches(lambda x: x.Species.endswith('pede')), 2) 
     342         
    342343        self.assertEqual(matches(lambda x: x.LastEscape != None), 1) 
    343344        self.assertEqual(matches(lambda x: x.LastEscape is not None), 1) 
     
    10991100            if db: 
    11001101                import math 
    1101                 maxprec = db.typeadapter.numeric_max_precision 
     1102                maxprec = db.adapterset.numeric_max_precision 
    11021103                if maxprec == 0: 
    11031104                    # SQLite, for example, must always use TEXT. 
     
    12951296            # Each thread opens a new SQLite :memory: database, 
    12961297            # so the concept of "isolation" is pretty meaningless. 
    1297             if name != ':memory:': 
    1298                 tools.TestRunner.run(loader(IsolationTests)) 
     1298##            if name != ':memory:': 
     1299##                tools.TestRunner.run(loader(IsolationTests)) 
    12991300             
    13001301##            tools.TestRunner.run(loader(DiscoveryTests))