Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

root/trunk/geniusql/codewalk.py

Revision 215 (checked in by fumanchu, 1 year ago)

Allow named_opcodes to take a string (e.g. a co_code string).

  • Property svn:eol-style set to native
Line 
1 """Bytecode visitors, including rewriters and decompilers.
2
3 This work, including the source code, documentation
4 and related data, is placed into the public domain.
5
6 The original author is Robert Brewer.
7
8 THIS SOFTWARE IS PROVIDED AS-IS, WITHOUT WARRANTY
9 OF ANY KIND, NOT EVEN THE IMPLIED WARRANTY OF
10 MERCHANTABILITY. THE AUTHOR OF THIS SOFTWARE
11 ASSUMES _NO_ RESPONSIBILITY FOR ANY CONSEQUENCE
12 RESULTING FROM THE USE, MODIFICATION, OR
13 REDISTRIBUTION OF THIS SOFTWARE.
14
15 """
16
17 from opcode import cmp_op, opname, opmap, HAVE_ARGUMENT
18 _visit_map = [name.replace('+', '_PLUS_') for name in opname]
19
20 import operator
21
22 try:
23     # Builtin in Python 2.4+
24     set
25 except NameError:
26     try:
27         # Module in Python 2.3
28         from sets import Set as set
29     except ImportError:
30         set = None
31
32 import types
33
34 from compiler.consts import *
35 CO_NOFREE = 0x0040
36
37
38 def named_opcodes(bits):
39     """Change initial numeric opcode bits to their named equivalents."""
40     if isinstance(bits, basestring):
41         bits = map(ord, bits)
42     bitnums = []
43     bits = iter(bits)
44     for x in bits:
45         bitnums.append(opname[x])
46         if x >= HAVE_ARGUMENT:
47             try:
48                 bitnums.append(bits.next())
49                 bitnums.append(bits.next())
50             except StopIteration:
51                 break
52     return bitnums
53
54 def numeric_opcodes(bits):
55     """Change named opcode bits to their numeric equivalents."""
56     bitnums = []
57     for x in bits:
58         if isinstance(x, basestring):
59             x = opmap[x]
60         bitnums.append(x)
61     return bitnums
62
# A two-instruction program: LOAD_DEREF with argument bytes (0, 0),
# followed by RETURN_VALUE. It is used as the body of _derefblock below.
_deref_bytecode = numeric_opcodes(['LOAD_DEREF', 0, 0, 'RETURN_VALUE'])
# CodeType(argcount, nlocals, stacksize, flags, codestring, constants,
#          names, varnames, filename, name, firstlineno,
#          lnotab[, freevars[, cellvars]])
# A hand-built code object taking no arguments and declaring a single
# free variable named 'cell' (the final argument); cellvars is omitted.
# NOTE(review): flags value 3 is presumably CO_OPTIMIZED | CO_NEWLOCALS;
# confirm against compiler.consts.
_derefblock = types.CodeType(0, 0, 1, 3, ''.join(map(chr, _deref_bytecode)),
                       (None,), ('cell',), (), '', '', 2, '', ('cell',))
def deref_cell(cell):
    """Return the value of 'cell' (an object from a func_closure).
    
    A throwaway function is built around _derefblock with (cell,) as its
    closure and immediately called; the call's result is returned.
    """
    # FunctionType(code, globals[, name[, argdefs[, closure]]])
    return types.FunctionType(_derefblock, {}, "", (), (cell,))()
73
74 def make_closure(*args):
75     def inner():
76         args
77     return inner.func_closure
78
79
80 binary_operators = {'BINARY_POWER': operator.pow,
81                     'BINARY_MULTIPLY': operator.mul,
82                     'BINARY_DIVIDE': operator.div,
83                     'BINARY_FLOOR_DIVIDE': operator.floordiv,
84                     'BINARY_TRUE_DIVIDE': operator.truediv,
85                     'BINARY_MODULO': operator.mod,
86                     'BINARY_ADD': operator.add,
87                     'BINARY_SUBTRACT': operator.sub,
88                     'BINARY_SUBSCR': operator.getitem,
89                     'BINARY_LSHIFT': operator.lshift,
90                     'BINARY_RSHIFT': operator.rshift,
91                     'BINARY_AND': operator.and_,
92                     'BINARY_XOR': operator.xor,
93                     'BINARY_OR': operator.or_,
94                     }
95 inplace_operators = dict([('INPLACE_' + k.split('_')[1], v)
96                           for k, v in binary_operators.iteritems()
97                           if k not in ('BINARY_SUBSCR',)
98                           ])
99
# Map each BINARY_* opcode name to the source-code symbol for its operator.
binary_repr = {'BINARY_POWER': '**',
               'BINARY_MULTIPLY': '*',
               'BINARY_DIVIDE': '/',
               'BINARY_FLOOR_DIVIDE': '//',
               'BINARY_TRUE_DIVIDE': '/',
               'BINARY_MODULO': '%',
               'BINARY_ADD': '+',
               'BINARY_SUBTRACT': '-',
               'BINARY_LSHIFT': '<<',
               'BINARY_RSHIFT': '>>',
               'BINARY_AND': '&',
               'BINARY_XOR': '^',
               'BINARY_OR': '|',
               }

# Map each INPLACE_* opcode name to its augmented-assignment symbol.
# Split on the first underscore only, so that multi-word names survive
# intact: 'BINARY_FLOOR_DIVIDE' -> 'INPLACE_FLOOR_DIVIDE' (splitting on
# every underscore produced the bogus key 'INPLACE_FLOOR').
inplace_repr = dict([('INPLACE_' + k.split('_', 1)[1], v + '=')
                     for k, v in binary_repr.items()])
117
# Map each comparison symbol to a two-argument function f(x, y) which
# evaluates "x <symbol> y".
comparisons = {'<': operator.lt,
               '<=': operator.le,
               '==': operator.eq,
               '!=': operator.ne,
               '>': operator.gt,
               '>=': operator.ge,
               # operator.contains(a, b) tests "b in a", which is the
               # reverse of the operand order used by every other entry,
               # so spell the membership tests out explicitly.
               'in': lambda x, y: x in y,
               'not in': lambda x, y: x not in y,
               'is': operator.is_,
               'is not': operator.is_not,
               }
129
# Build a {name: type} table of every co_* attribute found on a real
# code object (deref_cell's own), for use when copying code attributes.
_co_code_attrs = dict([(name, type(getattr(deref_cell.func_code, name)))
                       for name in dir(deref_cell.func_code)
                       if name.startswith("co_")])
135
136
137 class Visitor(object):
138     """A visitor class for bytecode sequences.
139     
140     obj: a function, code object, string, or list of opcodes.
141     """
142    
143     def __init__(self, obj):
144         self.verbose = False
145        
146         # Distill supplied 'obj' arg to a code block string.
147         if isinstance(obj, types.MethodType):
148             obj = obj.im_func
149         if isinstance(obj, types.FunctionType):
150             self._func = obj
151             obj = obj.func_code
152         if hasattr(types, 'GeneratorType') and isinstance(obj, types.GeneratorType):
153             obj = obj.gi_frame.f_code
154        
155         # Copy code object attributes (if present).
156         selfdict = self.__dict__
157         try:
158             for name, _type in _co_code_attrs.iteritems():
159                 value = getattr(obj, name)
160                 if _type is tuple:
161                     value = list(value)
162                 selfdict[name] = value
163         except AttributeError:
164             pass
165        
166         try:
167             obj = obj.co_code
168         except AttributeError:
169             pass
170        
171         # Map the code block string to a list of opcode numbers.
172         if isinstance(obj, basestring):
173             bytecode = map(ord, obj)
174         elif isinstance(obj, list):
175             bytecode = obj[:]
176         else:
177             raise TypeError("obj arg of incorrect type '%s'" % type(obj))
178        
179         self._bytecode = bytecode
180    
181     def debug(self, *messages):
182         for term in messages:
183             print term,
184    
185     def walk(self):
186         verbose = self.verbose
187        
188         self.cursor = 0
189         b = self._bytecode
190         if verbose:
191             self.debug("\n\nWALKING: ", b)
192         b_len = len(b)      # Speed hack
193         while self.cursor < b_len:
194             if verbose:
195                 self.debug("\n", self.cursor)
196            
197             op = b[self.cursor]
198             self.cursor += 1
199             if op >= HAVE_ARGUMENT:
200                 lo = b[self.cursor]
201                 self.cursor += 1
202                 hi = b[self.cursor]
203                 self.cursor += 1
204                 args = (lo, hi)
205             else:
206                 args = ()
207            
208             if verbose:
209                 self.debug("visit (%s, %s)" % (op, repr(args)))
210             self.visit_instruction(op, *args)
211            
212             instruction = _visit_map[op]
213             handler = getattr(self, 'visit_' + instruction, None)
214             if handler:
215                 if verbose:
216                     self.debug("=> %s%s" % (instruction, repr(args)))
217                 handler(*args)
218                 if verbose:
219                     self.debug("\n    %r" % self.stack)
220    
221     def visit_instruction(self, op, lo=None, hi=None):
222         pass
223
224
class JumpCodeAdjuster(Visitor):
    """JumpCodeAdjuster(obj=[func|co|str|list], start, end, newlength).
    
    Rewrites the targets of jump instructions which are affected when
    bytecode[start:end] is about to be replaced by codes of another length.
    
    start, end: The range of the original bytecode in question.
    newlength: Length of the codes which overwrote bytecode[start:end].
    """
    
    def __init__(self, obj, start, end, newlength):
        Visitor.__init__(self, obj)
        self.start, self.end = start, end
        # Net change in length caused by the replacement.
        self.offset = newlength - (end - start)
    
    def bytecode(self):
        """Walk self and return the adjusted bytecode (a list)."""
        self.walk()
        return self.newcode
    
    def walk(self):
        """Populate self.newcode, walking only if the length changes."""
        if self.offset != 0:
            self.newcode = []
            Visitor.walk(self)
        else:
            # Same length means no jump target moves; skip the costly walk.
            self.newcode = self._bytecode
    
    def visit_instruction(self, op, lo=None, hi=None):
        """Copy the instruction (and any argument bytes) to self.newcode."""
        self.newcode.extend([op] + [arg for arg in (lo, hi)
                                    if arg is not None])
    
    def visit_CONTINUE_LOOP(self, lo, hi):
        """Treat the argument exactly like a JUMP_ABSOLUTE target."""
        self.visit_JUMP_ABSOLUTE(lo, hi)
    
    def visit_JUMP_ABSOLUTE(self, lo, hi):
        """Shift an absolute target which lies beyond self.start."""
        target = (hi << 8) + lo
        if target <= self.start:
            return
        moved = target + self.offset
        # visit_instruction has already appended (lo, hi); overwrite them.
        self.newcode[-2:] = [moved & 0xFF, moved >> 8]
    
    def visit_JUMP_FORWARD(self, lo, hi):
        """Resize a relative jump which starts before self.end and lands
        beyond self.start."""
        here = self.cursor
        target = here + lo + (hi << 8)
        if here < self.end and target > self.start:
            delta = (target + self.offset) - here
            # visit_instruction has already appended (lo, hi); overwrite them.
            self.newcode[-2:] = [delta & 0xFF, delta >> 8]
    
    def visit_JUMP_IF_FALSE(self, lo, hi):
        """Treat the argument exactly like a JUMP_FORWARD delta."""
        self.visit_JUMP_FORWARD(lo, hi)
    
    def visit_JUMP_IF_TRUE(self, lo, hi):
        """Treat the argument exactly like a JUMP_FORWARD delta."""
        self.visit_JUMP_FORWARD(lo, hi)
282
283
def safe_tuple(seq):
    """Return a tuple holding str(item) for each item in 'seq'.
    
    Several func_code attributes have to be tuples rather than lists, and
    their items have to be plain strings: unicode items must be cast to
    strings first or the interpreter will crash.
    """
    return tuple([str(item) for item in seq])
293
294
class Rewriter(Visitor):
    """Rewriter(obj=function or code object).
    
    Builds a new function or code object by rewriting an existing one.
    Subclasses supply visit_* handlers which edit self.newcode (usually
    through put() or tail()) while the original bytecode is being walked.
    
    Notice that, unlike the base Visitor class, Rewriter does not accept a
    string or list of opcodes as an initial argument.
    """
    
    def bytecode(self):
        """Walk self and return the rewritten bytecode (a list)."""
        self.walk()
        return self.newcode
    
    def code_object(self):
        """Walk self and return a new code object built from the result."""
        self.walk()
        codestr = ''.join([chr(byte) for byte in self.newcode])
        code_args = [
            self.co_argcount,
            self.co_nlocals,
            self.co_stacksize,
            self.co_flags,
            codestr,
            # Notice co_consts should *not* be safe_tupled.
            tuple(self.co_consts),
            safe_tuple(self.co_names),
            safe_tuple(self.co_varnames),
            self.co_filename,
            self.co_name,
            self.co_firstlineno,
            self.co_lnotab,
            safe_tuple(self.co_freevars),
            safe_tuple(self.co_cellvars),
            ]
        return types.CodeType(*code_args)
    
    def function(self, newname=None):
        """Walk self and return a new function wrapping the rewritten code.
        
        newname: name for the new function. It defaults to the original
            function's name, or to '' when self was not built from a
            function.
        """
        if not hasattr(self, '_func'):
            # Built from a bare code object: there are no globals,
            # defaults or closure to carry over.
            if newname is None:
                newname = ''
            return types.FunctionType(self.code_object(), {}, newname)
        
        original = self._func
        if newname is None:
            newname = original.func_name
        return types.FunctionType(self.code_object(), original.func_globals,
                                  newname, original.func_defaults,
                                  original.func_closure)
    
    def const_index(self, value):
        """The index of value in co_consts, appending it if not found."""
        for pos, item in enumerate(self.co_consts):
            try:
                # Compare types too, so that equal values of different
                # types are not treated as the same constant.
                if type(value) == type(item) and value == item:
                    return pos
            except TypeError:
                pass
        
        self.co_consts.append(value)
        return len(self.co_consts) - 1
    
    def name_index(self, value):
        """The index of value in co_names, appending it if not found."""
        for pos, item in enumerate(self.co_names):
            try:
                if type(value) == type(item) and value == item:
                    break
            except TypeError:
                pass
        else:
            pos = len(self.co_names)
            self.co_names.append(value)
        return pos
    
    def walk(self):
        """Start with an empty self.newcode, then visit every instruction."""
        self.newcode = []
        Visitor.walk(self)
    
    def visit_instruction(self, op, lo=None, hi=None):
        """Copy the instruction (and any argument bytes) to self.newcode."""
        self.newcode.extend([op] + [arg for arg in (lo, hi)
                                    if arg is not None])
    
    def put(self, start, end, *bits):
        """Overwrite self.newcode[start:end] with new opcodes.
        
        bits: the replacement codes, as numbers or opcode names.
        
        If the new codes are of different quantity than the old,
        any jump codes affected are modified to match.
        """
        replacement = numeric_opcodes(bits)
        
        # Jump targets must be fixed up against the *old* layout,
        # so adjust them before splicing in the replacement.
        adjuster = JumpCodeAdjuster(self.newcode, start, end,
                                    len(replacement))
        adjusted = adjuster.bytecode()
        adjusted[start:end] = replacement
        self.newcode = adjusted
    
    def tail(self, length, *bits):
        """Overwrite self.newcode[-length:] with bits."""
        stop = len(self.newcode)
        self.put(stop - length, stop, *bits)
395
396
class Localizer(Rewriter):
    """Localizer(func, builtin_only=False, stoplist=None, verbose=False)
    
    If a global or builtin is known at compile time, replace it with a constant.
    
    func: the function to rewrite.
    builtin_only: if True, only builtin names are replaced; otherwise the
        function's globals are candidates too (and shadow builtins).
    stoplist: names which must not be replaced. Defaults to a new,
        empty list for each instance.
    verbose: if True, each replacement is reported via self.debug.
    
    This duplicates (and borrows from) Raymond Hettinger's Cookbook recipe
    at: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/277940
    """
    def __init__(self, func, builtin_only=False, stoplist=None, verbose=False):
        Rewriter.__init__(self, func)
        
        # The known values: builtins first, then (optionally) globals.
        import __builtin__
        self.env = vars(__builtin__).copy()
        if not builtin_only:
            self.env.update(func.func_globals)
        
        # A mutable default argument would be shared by every instance
        # which did not pass its own stoplist, so create one per instance.
        if stoplist is None:
            stoplist = []
        self.stoplist = stoplist
        self.verbose = verbose
    
    def visit_LOAD_GLOBAL(self, lo, hi):
        """Replace LOAD_GLOBAL with LOAD_CONST when the name's value is known."""
        name = self.co_names[lo + (hi << 8)]
        if name in self.env and name not in self.stoplist:
            value = self.env[name]
            pos = self.const_index(value)
            # The LOAD_GLOBAL just copied to newcode occupies its last 3 bytes.
            self.tail(3, 'LOAD_CONST', pos & 0xFF, pos >> 8)
            if self.verbose:
                self.debug(name, ' --> ', value)
424
425
class TaintableStack(list):
    """A list used as a stack, whose items can be individually 'tainted'.
    
    Taint marks are tracked by position. self.maxsize records the greatest
    length reached through construction and append().
    """
    
    def __init__(self, seq=()):
        list.__init__(self, seq)
        # Absolute (non-negative) indexes of the tainted items.
        self._taintindex = set()
        self.maxsize = len(self)
    
    def _absolute(self, index):
        """Convert a possibly negative index to one counted from the start."""
        if index < 0:
            index += len(self)
        return index
    
    def taint(self, index=-1):
        """Mark the item at index (default: the top of the stack) as tainted."""
        self._taintindex.add(self._absolute(index))
    
    def tainted(self, index=-1):
        """Return True if the item at index (default: the top) is tainted."""
        return (self._absolute(index) in self._taintindex)
    
    def pop(self, index=-1):
        """pop(index) -> Returns a tuple!! of (value, tainted)."""
        index = self._absolute(index)
        # Pop first: if the index is out of range, IndexError is raised
        # before any taint bookkeeping has been touched.
        value = list.pop(self, index)
        
        is_tainted = (index in self._taintindex)
        if is_tainted:
            self._taintindex.remove(index)
        
        if index < len(self):
            # An item below the top was removed, so every item above it
            # moved down one slot; move their taint marks along with them.
            shifted = set()
            for pos in self._taintindex:
                if pos > index:
                    pos -= 1
                shifted.add(pos)
            self._taintindex = shifted
        
        return value, is_tainted
    
    def append(self, obj):
        """Push obj onto the stack, updating self.maxsize if it grew."""
        list.append(self, obj)
        l = len(self)
        if l > self.maxsize:
            self.maxsize = l
456
457
458 class EarlyBinder(Rewriter):
459     """Deep-evaluate a function, replacing free vars with constants.
460     
461     reduce_getattr: If True (the default), getattr(x, y) will be
462         replaced with x.y where possible.
463     
464     bind_late: a list of objects (globals, freevars, or attributes)
465         which should not be early-bound. For example, if you want
466         datetime.date.today() to be bound late, include it in bind_late.
467     
468     Example: k = lambda x: x.Date == datetime.date(2004, 1, 1)
469              r = EarlyBinder(k).function()
470             
471     _____ k _____                _____ r _____
472      0 LOAD_FAST     0 (x)        0 LOAD_FAST   0 (x)
473      3 LOAD_ATTR     1 (Date)     3 LOAD_ATTR   1 (Date)
474      6 LOAD_GLOBAL   2 (datetime) 6 LOAD_CONST  6 (datetime.date(2004, 1, 1))
475      9 LOAD_ATTR     3 (date)
476     12 LOAD_CONST    1 (2004)
477     15 LOAD_CONST    2 (1)
478     18 LOAD_CONST    2 (1)
479     21 CALL_FUNCTION 3
480     24 COMPARE_OP    2 (==)       9 COMPARE_OP  2 (==)
481     27 RETURN_VALUE              12 RETURN_VALUE
482     
483     This also pre-computes binary operations *and all other builtin or free
484     functions* where all operands are constants, globals, or freevars.
485     For example:
486         LOAD_CONST        1 (3)
487         LOAD_CONST        2 (4)
488         BINARY_MULTIPLY
489