Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

Changeset 143

Show
Ignore:
Timestamp:
08/16/07 08:46:12
Author:
fumanchu
Message:

New Table.description attribute, plus lots of postgres reflection fixes.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/geniusql/adapters.py

    r142 r143  
    176176        if value is None: 
    177177            return None 
     178        if value in ('false', 'False'): 
     179            return False 
    178180        return bool(value) 
    179181 
  • trunk/geniusql/objects.py

    r142 r143  
    184184    """ 
    185185     
    186     def __new__(cls, name, qname, schema, created=False): 
     186    def __new__(cls, name, qname, schema, created=False, description=None): 
    187187        return dict.__new__(cls) 
    188188     
    189     def __init__(self, name, qname, schema, created=False): 
     189    def __init__(self, name, qname, schema, created=False, description=None): 
    190190        dict.__init__(self) 
    191191         
     
    194194        self.schema = schema 
    195195        self.created = created 
     196        self.description = description 
    196197         
    197198        self.indices = schema.indexsetclass(self) 
  • trunk/geniusql/providers/postgres.py

    r142 r143  
    3131        raise TypeError("unsupported operand type(s) for %s: " 
    3232                        "%r and %r" % (op, op1.pytype, op2.pytype)) 
     33 
     34 
     35class PgTIMESTAMPTZ_Adapter(adapters.Adapter): 
     36     
     37    def push(self, value, dbtype): 
     38        if value is None: 
     39            return 'NULL' 
     40        return ("'%04d-%02d-%02d %02d:%02d:%02d'" % 
     41                (value.year, value.month, value.day, 
     42                 value.hour, value.minute, value.second)) 
     43     
     44    def pull(self, value, dbtype): 
     45        if value is None: 
     46            return None 
     47        if isinstance(value, datetime.datetime): 
     48            return value 
     49        chunks = (value[0:4], value[5:7], value[8:10], 
     50                  value[11:13], value[14:16], value[17:19], 
     51                  value[20:26] or 0) 
     52        return datetime.datetime(*map(int, chunks)) 
    3353 
    3454 
     
    223243# cidr        IPv4 or IPv6 network address 
    224244# circle      circle in the plane 
    225 # inet        IPv4 or IPv6 host address 
    226245# line        infinite line in the plane 
    227246# lseg        line segment in the plane 
     
    231250# polygon     closed geometric path in the plane 
    232251# timetz      time of day, including time zone 
    233 # timestamptz date and time, including time zone 
    234252 
    235253 
     
    263281class CHAR(dbtypes.SQL92CHAR): 
    264282    """A fixed-length character string.""" 
    265     synonyms = ['CHARACTER'
     283    synonyms = ['CHARACTER', 'BPCHAR'
    266284    default_adapters = {str: Pg_str_to_VARCHAR(), 
    267285                        unicode: Pg_unicode_to_VARCHAR(), 
     
    285303            return False 
    286304        return True 
     305     
     306    def __str__(self): 
     307        return "Infinity" 
     308     
     309    def __repr__(self): 
     310        return "%s.%s()" % (self.__module__, self.__class__.__name__) 
     311 
    287312 
    288313 
     
    339364    pass 
    340365 
     366class TIMESTAMPTZ(dbtypes.SQL92TIMESTAMP): 
     367    """A date and time.""" 
     368    default_adapters = {datetime.datetime: PgTIMESTAMPTZ_Adapter()} 
     369 
    341370class DATE(dbtypes.SQL92DATE): 
    342371    """A calendar date (year, month, day).""" 
     
    371400 
    372401 
     402class INET(dbtypes.FrozenByteType): 
     403    """An IPv4 or IPv6 host address, and optionally the subnet.""" 
     404    # "The inet type holds an IPv4 or IPv6 host address, and optionally 
     405    # the identity of the subnet it is in, all in one field. The subnet 
     406    # identity is represented by stating how many bits of the host address 
     407    # represent the network address (the "netmask"). If the netmask is 32 
     408    # and the address is IPv4, then the value does not indicate a subnet, 
     409    # only a single host. In IPv6, the address length is 128 bits, so 128 
     410    # bits specify a unique host address. Note that if you want to accept 
     411    # networks only, you should use the cidr type rather than inet. 
     412    # 
     413    # The input format for this type is address/y where address is an IPv4 
     414    # or IPv6 address and y is the number of bits in the netmask. If the /y 
     415    # part is left off, then the netmask is 32 for IPv4 and 128 for IPv6, 
     416    # so the value represents just a single host. On display, the /y 
     417    # portion is suppressed if the netmask specifies a single host." 
     418     
     419    variable = False 
     420    encoding = 'utf8' 
     421     
     422    default_pytype = str 
     423    default_adapters = {str: adapters.str_to_SQL92VARCHAR(), 
     424                        unicode: adapters.unicode_to_SQL92VARCHAR(), 
     425                        None: adapters.Pickler(), 
     426                        } 
     427 
    373428 
    374429class PgTypeSet(dbtypes.DatabaseTypeSet): 
     
    379434                   'int': [INT2, INT4, INT8], 
    380435                   'bool': [BOOLEAN], 
    381                    'datetime': [TIMESTAMP], 
     436                   'datetime': [TIMESTAMP, TIMESTAMPTZ], 
    382437                   'date': [DATE], 
    383438                   'time': [TIME], 
    384439                   'timedelta': [INTERVAL], 
    385440                   'numeric': [DECIMAL], 
    386                    'other': [MONEY], 
     441                   'other': [MONEY, INET], 
    387442                   } 
    388443 
     
    495550     
    496551    def _get_tables(self, conn=None): 
     552        data, _ = self.db.fetch("SELECT oid FROM pg_class WHERE " 
     553                                "relname = 'pg_class'", conn=conn) 
     554        pgclass_OID = data[0][0] 
     555         
    497556        data, _ = self.db.fetch( 
    498             "SELECT tablename FROM pg_tables WHERE schemaname" 
    499             " not in ('information_schema', 'pg_catalog')", 
     557            "SELECT pg_tables.tablename, descr.description FROM " 
     558            "(pg_tables LEFT JOIN pg_class ON pg_tables.tablename = " 
     559            "pg_class.relname) LEFT JOIN (SELECT description, objoid " 
     560            "FROM pg_description WHERE classoid = %s) AS descr " 
     561            "ON pg_class.oid = descr.objoid WHERE pg_tables.schemaname " 
     562            "not in ('information_schema', 'pg_catalog')" % pgclass_OID, 
    500563            conn=conn) 
    501564        return [self.tableclass(row[0], self.db.quote(row[0]), 
    502                                 self, created=True
     565                                self, created=True, description=row[1]
    503566                for row in data] 
    504567     
     
    508571        for name, in data: 
    509572            if name == tablename: 
    510                 return self.tableclass(name, self.db.quote(name), 
    511                                        self, created=True) 
     573                t = self.tableclass(name, self.db.quote(name), 
     574                                    self, created=True) 
     575                 
     576                # Get the description of the table, if any 
     577                data, _ = self.db.fetch("SELECT oid FROM pg_class WHERE " 
     578                                        "relname = '%s'" % tablename, conn=conn) 
     579                table_OID = data[0][0] 
     580                data, _ = self.db.fetch("SELECT oid FROM pg_class WHERE " 
     581                                        "relname = 'pg_class'", conn=conn) 
     582                pgclass_OID = data[0][0] 
     583                data, _ = self.db.fetch("SELECT description FROM pg_shdescription " 
     584                                        "WHERE objoid = %s and classoid = %s" % 
     585                                        (table_OID, pgclass_OID), conn=conn) 
     586                for cell, in data: 
     587                    t.description = cell 
     588                    break 
     589                 
     590                return t 
    512591        raise errors.MappingError(tablename) 
    513592     
     
    545624            dbtype, _ = self.db.fetch("SELECT typname, typlen FROM pg_type " 
    546625                                      "WHERE oid = %s" % row[1], conn=conn) 
    547             dbtypetype = typeset.canonicalize(dbtype[0][0].upper()) 
     626            try: 
     627                dbtypetype = typeset.canonicalize(dbtype[0][0].upper()) 
     628            except KeyError, x: 
     629                x.args += ("%s.%s" % (tablename, name),) 
     630                raise 
    548631            dbtype = dbtypetype() 
    549632             
     
    554637             
    555638            if dbtypetype in (FLOAT4, FLOAT8): 
    556                 dbtype.precision = row[3] 
     639                dbtype.precision = int(row[3]) 
    557640            elif dbtypetype in (MONEY, DECIMAL): 
    558                 dbtype.precision = (row[4] >> 16) & 65535 
    559                 dbtype.scale = (row[4] & 65535) - 4 
     641                dbtype.precision = int((row[4] >> 16) & 65535) 
     642                dbtype.scale = int((row[4] & 65535) - 4) 
    560643             
    561644            if dbtypetype is VARCHAR: 
    562645                # See http://archives.postgresql.org/pgsql-interfaces/2004-07/msg00021.php 
    563                 dbtype.bytes = row[4] - 4 
     646                bytes = int(row[4] - 4) 
     647                if bytes > 0: 
     648                    dbtype.bytes = bytes 
     649                else: 
     650                    raise ValueError("Column %r has illegal size %r" % (name, bytes)) 
    564651            else: 
    565                 bytes = row[3] 
     652                bytes = int(row[3]) 
    566653                if bytes > 0: 
    567654                    dbtype.bytes = bytes 
     
    574661                default = default[0][0] 
    575662                if default.startswith("nextval("): 
    576                     # Grab seqname from "nextval(seqname::[text|regclass])" 
     663                    # Grab seqname from "nextval('seqname'::[text|regclass])" 
    577664                    c.autoincrement = True 
    578665                    c.sequence_name = seq_name.search(default).group(1) 
    579666                    c.initial = self.db.fetch("SELECT min_value FROM %s" % 
    580                                               c.sequence_name, conn=conn)[0][0] 
     667                                              c.sequence_name, conn=conn)[0][0][0] 
    581668                    c.default = None 
    582669                else: 
    583670                    # adsrc is always a string, so we must cast it using 
    584671                    # our guessed type. Be sure to strip any ::typename 
    585                     default = default.split("::", 1)[0] 
    586                     c.default = c.adapter.pull(default, c.dbtype) 
     672                    defval = default.split("::", 1)[0] 
     673                    try: 
     674                        # String defaults have quotes we need to strip 
     675                        defval = defval.strip("'") 
     676                        c.default = c.adapter.pull(defval, c.dbtype) 
     677                    except ValueError: 
     678                        # The default is probably a function like 'now()'. 
     679                        # Keep the whole unmunged string for now. 
     680                        # TODO: set default to an equivalent lambda? 
     681                        c.default = default 
    587682            else: 
    588683                c.default = None 
     
    606701            "pg_index.indrelid = %s" % table_OID, conn=conn) 
    607702        for row in data: 
     703            iname = row[0] 
     704            q_iname = self.db.quote(iname) 
     705            uniq = bool(row[3]) 
    608706            # indkey is an "array" (we get a space-separated string of ints). 
    609707            cols = map(int, row[1].split(" ")) 
     
    612710                                     "WHERE attrelid = %s AND attnum = %s" 
    613711                                     % (table_OID, col), conn=conn) 
    614                 i = geniusql.Index(row[0], self.db.quote(row[0]), tablename, 
    615                                    d[0][0], bool(row[3])) 
    616                 indices.append(i) 
     712                if not d: 
     713                    # This is probably an index that was added by hand, 
     714                    # without reference to a single existing column. 
     715                    indices.append(geniusql.Index(iname, q_iname, tablename, 
     716                                                  "<unknown>", uniq)) 
     717                else: 
     718                    attname = d[0][0] 
     719                    indices.append(geniusql.Index(iname, q_iname, tablename, 
     720                                                  attname, uniq)) 
    617721         
    618722        return indices