Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

root/trunk/geniusql/codewalk.py

Revision 135 (checked in by fumanchu, 3 years ago)

Merged AST branch back into trunk.

  • Property svn:eol-style set to native
Line 
1 """Bytecode visitors, including rewriters and decompilers.
2
3 This work, including the source code, documentation
4 and related data, is placed into the public domain.
5
6 The original author is Robert Brewer.
7
8 THIS SOFTWARE IS PROVIDED AS-IS, WITHOUT WARRANTY
9 OF ANY KIND, NOT EVEN THE IMPLIED WARRANTY OF
10 MERCHANTABILITY. THE AUTHOR OF THIS SOFTWARE
11 ASSUMES _NO_ RESPONSIBILITY FOR ANY CONSEQUENCE
12 RESULTING FROM THE USE, MODIFICATION, OR
13 REDISTRIBUTION OF THIS SOFTWARE.
14
15 """
16
17 from opcode import cmp_op, opname, opmap, HAVE_ARGUMENT
18 _visit_map = [name.replace('+', '_PLUS_') for name in opname]
19
20 import operator
21
22 try:
23     # Builtin in Python 2.4+
24     set
25 except NameError:
26     try:
27         # Module in Python 2.3
28         from sets import Set as set
29     except ImportError:
30         set = None
31
32 import types
33
34 from compiler.consts import *
35 CO_NOFREE = 0x0040
36
37
def named_opcodes(bits):
    """Translate the opcode numbers in a bytecode sequence to names.

    Each opcode number is replaced by its entry in ``opname``. The two
    argument bytes that follow any opcode at or above HAVE_ARGUMENT are
    copied through unchanged. If the sequence ends part-way through an
    argument, whatever was read up to that point is returned.
    """
    codes = list(bits)
    total = len(codes)
    named = []
    pos = 0
    while pos < total:
        code = codes[pos]
        pos += 1
        named.append(opname[code])
        if code >= HAVE_ARGUMENT:
            # Copy up to two argument bytes. A short slice means the
            # input ran out, which also terminates the loop.
            named.extend(codes[pos:pos + 2])
            pos += 2
    return named
51
def numeric_opcodes(bits):
    """Translate opcode names in a sequence to their numeric values.

    String items are looked up in ``opmap``; every other item (opcodes
    that are already numeric, and argument bytes) is passed through.
    Always returns a new list.
    """
    def to_number(bit):
        if isinstance(bit, basestring):
            return opmap[bit]
        return bit
    return [to_number(bit) for bit in bits]
60
# Bytecode for a minimal function body: load the value held in closure
# slot 0 and return it.
_deref_bytecode = numeric_opcodes(['LOAD_DEREF', 0, 0, 'RETURN_VALUE'])
# CodeType(argcount, nlocals, stacksize, flags, codestring, constants,
#          names, varnames, filename, name, firstlineno,
#          lnotab[, freevars[, cellvars]])
# The trailing ('cell',) argument declares one free variable, which is the
# slot that LOAD_DEREF 0 reads.
_derefblock = types.CodeType(0, 0, 1, 3, ''.join(map(chr, _deref_bytecode)),
                       (None,), ('cell',), (), '', '', 2, '', ('cell',))
def deref_cell(cell):
    """Return the value of 'cell' (an object from a func_closure).

    A throwaway function is built from _derefblock with 'cell' as its only
    closure entry, and is called immediately to read the cell's contents.
    """
    # FunctionType(code, globals[, name[, argdefs[, closure]]])
    return types.FunctionType(_derefblock, {}, "", (), (cell,))()
71
72 def make_closure(*args):
73     def inner():
74         args
75     return inner.func_closure
76
77
78 binary_operators = {'BINARY_POWER': operator.pow,
79                     'BINARY_MULTIPLY': operator.mul,
80                     'BINARY_DIVIDE': operator.div,
81                     'BINARY_FLOOR_DIVIDE': operator.floordiv,
82                     'BINARY_TRUE_DIVIDE': operator.truediv,
83                     'BINARY_MODULO': operator.mod,
84                     'BINARY_ADD': operator.add,
85                     'BINARY_SUBTRACT': operator.sub,
86                     'BINARY_SUBSCR': operator.getitem,
87                     'BINARY_LSHIFT': operator.lshift,
88                     'BINARY_RSHIFT': operator.rshift,
89                     'BINARY_AND': operator.and_,
90                     'BINARY_XOR': operator.xor,
91                     'BINARY_OR': operator.or_,
92                     }
93 inplace_operators = dict([('INPLACE_' + k.split('_')[1], v)
94                           for k, v in binary_operators.iteritems()
95                           if k not in ('BINARY_SUBSCR',)
96                           ])
97
# Source-code spelling of each BINARY_* opcode.
binary_repr = {'BINARY_POWER': '**',
               'BINARY_MULTIPLY': '*',
               'BINARY_DIVIDE': '/',
               'BINARY_FLOOR_DIVIDE': '//',
               'BINARY_TRUE_DIVIDE': '/',
               'BINARY_MODULO': '%',
               'BINARY_ADD': '+',
               'BINARY_SUBTRACT': '-',
               'BINARY_LSHIFT': '<<',
               'BINARY_RSHIFT': '>>',
               'BINARY_AND': '&',
               'BINARY_XOR': '^',
               'BINARY_OR': '|',
               }

# Augmented-assignment spelling of each INPLACE_* opcode. Split on the
# first underscore only, so BINARY_FLOOR_DIVIDE becomes
# INPLACE_FLOOR_DIVIDE rather than the non-existent INPLACE_FLOOR.
inplace_repr = dict([('INPLACE_' + k.split('_', 1)[1], v + '=')
                     for k, v in binary_repr.items()])
115
# Map each COMPARE_OP symbol to a two-argument callable taking the
# operands in source order: comparisons[sym](left, right).
comparisons = {'<': operator.lt,
               '<=': operator.le,
               '==': operator.eq,
               '!=': operator.ne,
               '>': operator.gt,
               '>=': operator.ge,
               # operator.contains(a, b) evaluates "b in a", which is the
               # reverse of the (left, right) order used here, so spell
               # the membership test out explicitly.
               'in': lambda x, y: x in y,
               'not in': lambda x, y: not x in y,
               'is': operator.is_,
               'is not': operator.is_not,
               }
127
# Record the name and value type of every co_* attribute found on a real
# code object, so Visitor can copy them (turning tuples into lists).
_co_code_attrs = dict([(name, type(getattr(deref_cell.func_code, name)))
                       for name in dir(deref_cell.func_code)
                       if name.startswith("co_")])
133
134
135 class Visitor(object):
136     """A visitor class for bytecode sequences.
137     
138     obj: a function, code object, string, or list of opcodes.
139     """
140    
141     def __init__(self, obj):
142         self.verbose = False
143        
144         # Distill supplied 'obj' arg to a code block string.
145         if isinstance(obj, types.MethodType):
146             obj = obj.im_func
147         if isinstance(obj, types.FunctionType):
148             self._func = obj
149             obj = obj.func_code
150        
151         # Copy code object attributes (if present).
152         selfdict = self.__dict__
153         try:
154             for name, _type in _co_code_attrs.iteritems():
155                 value = getattr(obj, name)
156                 if _type is tuple:
157                     value = list(value)
158                 selfdict[name] = value
159         except AttributeError:
160             pass
161        
162         try:
163             obj = obj.co_code
164         except AttributeError:
165             pass
166        
167         # Map the code block string to a list of opcode numbers.
168         if isinstance(obj, basestring):
169             bytecode = map(ord, obj)
170         elif isinstance(obj, list):
171             bytecode = obj[:]
172         else:
173             raise TypeError("obj arg of incorrect type '%s'" % type(obj))
174        
175         self._bytecode = bytecode
176    
177     def debug(self, *messages):
178         for term in messages:
179             print term,
180    
181     def walk(self):
182         verbose = self.verbose
183        
184         self.cursor = 0
185         b = self._bytecode
186         if verbose:
187             self.debug("\n\nWALKING: ", b)
188         b_len = len(b)      # Speed hack
189         while self.cursor < b_len:
190             if verbose:
191                 self.debug("\n", self.cursor)
192            
193             op = b[self.cursor]
194             self.cursor += 1
195             if op >= HAVE_ARGUMENT:
196                 lo = b[self.cursor]
197                 self.cursor += 1
198                 hi = b[self.cursor]
199                 self.cursor += 1
200                 args = (lo, hi)
201             else:
202                 args = ()
203            
204             if verbose:
205                 self.debug("visit (%s, %s)" % (op, repr(args)))
206             self.visit_instruction(op, *args)
207            
208             instruction = _visit_map[op]
209             handler = getattr(self, 'visit_' + instruction, None)
210             if handler:
211                 if verbose:
212                     self.debug("=> %s%s" % (instruction, repr(args)))
213                 handler(*args)
214                 if verbose:
215                     self.debug("\n    %r" % self.stack)
216    
217     def visit_instruction(self, op, lo=None, hi=None):
218         pass
219
220
class JumpCodeAdjuster(Visitor):
    """JumpCodeAdjuster(obj=[func|co|str|list], start, end, newlength).
    
    Re-emits the bytecode, patching the argument of each jump whose
    target is affected by a pending replacement of bytecode[start:end].
    
    start, end: The range of the original bytecode in question.
    newlength: Length of the codes which will overwrite that range.
    """
    
    def __init__(self, obj, start, end, newlength):
        Visitor.__init__(self, obj)
        self.start, self.end = start, end
        # Net change in bytecode length caused by the replacement.
        self.offset = newlength - (end - start)
    
    def bytecode(self):
        """Walk self and return new bytecode."""
        self.walk()
        return self.newcode
    
    def walk(self):
        if self.offset:
            self.newcode = []
            Visitor.walk(self)
        else:
            # Same length in and out: no jump can move, so skip the walk.
            self.newcode = self._bytecode
    
    def visit_instruction(self, op, lo=None, hi=None):
        # Copy the instruction through; the jump handlers below may then
        # overwrite the two argument bytes just emitted.
        emitted = [op]
        if lo is not None:
            emitted.append(lo)
        if hi is not None:
            emitted.append(hi)
        self.newcode.extend(emitted)
    
    def visit_CONTINUE_LOOP(self, lo, hi):
        self.visit_JUMP_ABSOLUTE(lo, hi)
    
    def visit_JUMP_ABSOLUTE(self, lo, hi):
        target = (hi << 8) + lo
        if target <= self.start:
            return
        moved = target + self.offset
        self.newcode[-2:] = [moved & 0xFF, moved >> 8]
    
    def visit_JUMP_FORWARD(self, lo, hi):
        # The jump distance is relative to the instruction after the jump.
        here = self.cursor
        target = here + lo + (hi << 8)
        if here >= self.end or target <= self.start:
            return
        distance = (target + self.offset) - here
        self.newcode[-2:] = [distance & 0xFF, distance >> 8]
    
    def visit_JUMP_IF_FALSE(self, lo, hi):
        self.visit_JUMP_FORWARD(lo, hi)
    
    def visit_JUMP_IF_TRUE(self, lo, hi):
        self.visit_JUMP_FORWARD(lo, hi)
278
279
def safe_tuple(seq):
    """Return the items of 'seq' as a tuple of str objects.
    
    Several func_code attributes must be tuples rather than lists and,
    per the original author's note, must not contain unicode items, so
    every item is passed through str().
    """
    return tuple([str(item) for item in seq])
289
290
class Rewriter(Visitor):
    """Rewriter(obj=function or code object).
    
    Build a new function or code object by rewriting an existing one.
    Walking copies each instruction into self.newcode; handlers in
    subclasses then use put() or tail() to replace what was copied.
    
    Unlike the base Visitor class, Rewriter does not accept a string or
    list of opcodes as its initial argument.
    """
    
    def bytecode(self):
        """Walk self and return new bytecode."""
        self.walk()
        return self.newcode
    
    def code_object(self):
        """Walk self and produce a new code object."""
        self.walk()
        codestr = ''.join([chr(bit) for bit in self.newcode])
        return types.CodeType(
            self.co_argcount, self.co_nlocals, self.co_stacksize,
            self.co_flags, codestr,
            # Constants may be arbitrary objects, so co_consts must
            # *not* go through safe_tuple.
            tuple(self.co_consts),
            safe_tuple(self.co_names), safe_tuple(self.co_varnames),
            self.co_filename, self.co_name, self.co_firstlineno,
            self.co_lnotab,
            safe_tuple(self.co_freevars), safe_tuple(self.co_cellvars))
    
    def function(self, newname=None):
        """Walk self and produce a new function."""
        original = getattr(self, '_func', None)
        if original is None:
            # Built from a bare code object: no globals, defaults or
            # closure are available to carry over.
            if newname is None:
                newname = ''
            return types.FunctionType(self.code_object(), {}, newname)
        
        if newname is None:
            newname = original.func_name
        return types.FunctionType(self.code_object(), original.func_globals,
                                  newname, original.func_defaults,
                                  original.func_closure)
    
    def const_index(self, value):
        """The index of value in co_consts, appending it if not found."""
        value_type = type(value)
        for pos, item in enumerate(self.co_consts):
            try:
                # Require matching types so that e.g. 1 and 1.0 are
                # kept as distinct constants.
                if value_type == type(item) and value == item:
                    return pos
            except TypeError:
                continue
        self.co_consts.append(value)
        return len(self.co_consts) - 1
    
    def name_index(self, value):
        """The index of value in co_names, appending it if not found."""
        names = self.co_names
        for pos, item in enumerate(names):
            try:
                if type(value) == type(item) and value == item:
                    break
            except TypeError:
                pass
        else:
            pos = len(names)
            names.append(value)
        return pos
    
    def walk(self):
        self.newcode = []
        Visitor.walk(self)
    
    def visit_instruction(self, op, lo=None, hi=None):
        # Copy the instruction, and whichever argument bytes it has,
        # to the output unchanged.
        emitted = [op]
        if lo is not None:
            emitted.append(lo)
        if hi is not None:
            emitted.append(hi)
        self.newcode.extend(emitted)
    
    def put(self, start, end, *bits):
        """Overwrite self.newcode[start:end] with new opcodes.
        
        'bits' may be opcode numbers or names. If the replacement differs
        in length from the range it replaces, affected jump arguments in
        the existing code are adjusted.
        """
        replacement = numeric_opcodes(bits)
        
        # Jumps are adjusted against the code as it stands *before*
        # the replacement is spliced in.
        adjuster = JumpCodeAdjuster(self.newcode, start, end,
                                    len(replacement))
        self.newcode = adjuster.bytecode()
        
        self.newcode[start:end] = replacement
    
    def tail(self, length, *bits):
        """Overwrite self.newcode[-length:] with bits."""
        stop = len(self.newcode)
        self.put(stop - length, stop, *bits)
391
392
class Localizer(Rewriter):
    """Localizer(func, builtin_only=False, stoplist=[], verbose=False)
    
    Replace each global or builtin that is known when the rewrite runs
    with a constant. Names listed in 'stoplist' are left alone; with
    builtin_only=True only builtins (not module globals) are replaced.
    
    This duplicates (and borrows from) Raymond Hettinger's Cookbook recipe
    at: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/277940
    """
    def __init__(self, func, builtin_only=False, stoplist=[], verbose=False):
        Rewriter.__init__(self, func)
        
        import __builtin__
        # Module globals are layered over the builtins, so a global
        # shadows a builtin of the same name.
        env = dict(vars(__builtin__))
        if not builtin_only:
            env.update(func.func_globals)
        self.env = env
        
        self.stoplist = stoplist
        self.verbose = verbose
    
    def visit_LOAD_GLOBAL(self, lo, hi):
        name = self.co_names[(hi << 8) + lo]
        if name not in self.env or name in self.stoplist:
            return
        value = self.env[name]
        pos = self.const_index(value)
        # Swap the 3-byte LOAD_GLOBAL just copied for a LOAD_CONST.
        self.tail(3, 'LOAD_CONST', pos & 0xFF, pos >> 8)
        if self.verbose:
            self.debug(name, ' --> ', value)
420
421
class TaintableStack(list):
    """A list used as a value stack, with a 'tainted' flag per slot.
    
    Taint is tracked by position in self._taintindex. Only pop() and
    append() keep the bookkeeping up to date; other list mutators
    (insert, extend, slice assignment, ...) are not tracked.
    
    maxsize: the greatest length reached through __init__ and append().
    """
    def __init__(self, seq=[]):
        list.__init__(self, seq)
        self._taintindex = set()
        # Measure the stored items rather than 'seq', which may be an
        # iterator with no len().
        self.maxsize = len(self)
    
    def taint(self, index=-1):
        """Mark the slot at 'index' (default: top of stack) as tainted."""
        if index < 0:
            index += len(self)
        self._taintindex.add(index)
    
    def tainted(self, index=-1):
        """Return True if the slot at 'index' (default: top) is tainted."""
        if index < 0:
            index = len(self) + index
        return (index in self._taintindex)
    
    def pop(self, index=-1):
        """pop(index) -> Returns a tuple!! of (value, tainted)."""
        size = len(self)
        if index < 0:
            index += size
        # Pop first: if the index is invalid, IndexError is raised before
        # any taint bookkeeping has been disturbed.
        value = list.pop(self, index)
        is_tainted = (index in self._taintindex)
        if is_tainted:
            self._taintindex.remove(index)
        if index < size - 1:
            # Items above the removed slot moved down one position;
            # move their taint marks with them.
            shifted = set()
            for i in self._taintindex:
                if i > index:
                    i -= 1
                shifted.add(i)
            self._taintindex = shifted
        return value, is_tainted
    
    def append(self, obj):
        """Push 'obj' and update the maxsize high-water mark."""
        list.append(self, obj)
        l = len(self)
        if l > self.maxsize:
            self.maxsize = l
452
453
class EarlyBinder(Rewriter):
    """Deep-evaluate a function, replacing free vars with constants.
    
    reduce_getattr: If True (the default), getattr(x, y) will be
        replaced with x.y where possible.
    
    bind_late: a list of objects (globals, freevars, or attributes)
        which should not be early-bound. For example, if you want
        datetime.date.today() to be bound late, include it in bind_late.
    
    Example: k = lambda x: x.Date == datetime.date(2004, 1, 1)
             r = EarlyBinder(k).function()
             
    _____ k _____                _____ r _____
     0 LOAD_FAST     0 (x)        0 LOAD_FAST   0 (x)
     3 LOAD_ATTR     1 (Date)     3 LOAD_ATTR   1 (Date)
     6 LOAD_GLOBAL   2 (datetime) 6 LOAD_CONST  6 (datetime.date(2004, 1, 1))
     9 LOAD_ATTR     3 (date)
    12 LOAD_CONST    1 (2004)
    15 LOAD_CONST    2 (1)
    18 LOAD_CONST    2 (1)
    21 CALL_FUNCTION 3
    24 COMPARE_OP    2 (==)       9 COMPARE_OP  2 (==)
    27 RETURN_VALUE              12 RETURN_VALUE
    
    This also pre-computes binary operations *and all other builtin or free
    functions* where all operands are constants, globals, or freevars.
    For example:
        LOAD_CONST        1 (3)
        LOAD_CONST        2 (4)
        BINARY_MULTIPLY
        
    is replaced with:
        LOAD_CONST        5 (12)
    
    However, order is important. lambda x: x * 4 * 5 won't see any
    optimization, because the order of eval is (x * 4) * 5. Rewritten
    as lambda x: 4 * 5 * x, the "4 * 5" can be replaced with "20".
    """
    
    def __init__(self, func, reduce_getattr=True, bind_late=None):
        Rewriter.__init__(self, func)
        self.reduce_getattr = reduce_getattr
        
        # self.env will be used to make consts out of globals and builtins.
        import __builtin__
        self.env = vars(__builtin__).copy()
        self.env.update(func.func_globals)
        
        # Keep a stack like the interpreter would. This does *not*
        # get overwritten when self.newcode does--it emulates the
        # original instructions (although tainted values may be dummies).
        # When a local var is pushed onto this stack, it "taints" itself
        # and any operations which depend upon it.
        # This stack is not passed out of this class in any way.
        self.stack = TaintableStack()
        
        if bind_late is None:
            bind_late = []
        self.bind_late = bind_late
    
    def code_object(self):
        """Walk self and produce a new code object.
        
        The new code object has no free or cell variables, and takes its
        stack size from the high-water mark of the emulated stack.
        """
        self.walk()
        codestr = ''.join(map(chr, self.newcode))
        # Assert CO_NOFREE, since all free vars should have been made constant.
        self.co_flags |= CO_NOFREE
        co = types.CodeType(self.co_argcount, self.co_nlocals, self.stack.maxsize,
                      self.co_flags, codestr, tuple(self.co_consts),
                      safe_tuple(self.co_names), safe_tuple(self.co_varnames),
                      '', self.co_name, 1,
                      self.co_lnotab, (), ())
        return co
    
    def function(self, newname=None):
        """Walk self and produce a new function."""
        try:
            f = self._func
        except AttributeError:
            if newname is None:
                newname = ''
            co = self.code_object()
            return types.FunctionType(co, {}, newname)
        else:
            if newname is None:
                newname = f.func_name
            co = self.code_object()
            # All cells should be dereferenced, so force func_closure to None.
            return types.FunctionType(co, f.func_globals, newname, f.func_defaults)
    
    def reduce(self, number_of_terms, transform=None, overwrite_length=None):
        """If no stack args are to be bound late, rewrite previous opcodes.
        
        number_of_terms: the number of terms to pop off the stack.
        
        transform: a callback, to which we send the popped terms (as a
            list, in the order they were pushed). Its return value
            becomes the new constant. If omitted, the list of terms
            itself is the result.
        
        overwrite_length: the number of previous opcodes to overwrite. If
            None, it defaults to (number of terms + 1 for the current
            instruction) * 3.
        
        Returns the computed result, or None if any popped term was
        tainted (in which case a tainted dummy is pushed instead).
        """
        if overwrite_length is None:
            # +1 is for current bytecode. If any overwritten bytecode
            # is not len 3, pass in a value for overwrite_length.
            overwrite_length = (number_of_terms + 1) * 3
        
        # Pop the requested number of terms off the stack. Note that
        # terms[0] is the most recently pushed value.
        is_tainted = False
        terms, taints = [], []
        for i in range(number_of_terms):
            term, taint = self.stack.pop()
            taints.append(taint)
            is_tainted |= taint
            terms.append(term)
        
        # Now that all the stack-popping is done...
        if is_tainted:
            # We don't have to handle getattr if no args are
            # tainted, because CALL_FUNCTION will do it normally.
            # Here terms is (attribute name, object, function): rewrite
            # getattr(<late-bound object>, <constant name>) as LOAD_ATTR.
            if (self.reduce_getattr and len(terms) == 3
                and terms[2] == getattr and taints[1] and not taints[0]):
                # Form a new LOAD_ATTR instruction.
                pos = self.name_index(terms[0])
                # Unlike normal CALL_FUNCTION, we can't assume each arg
                # is a constant; therefore, our overwrite_length is
                # indeterminate. We'll just cheat and keep track of
                # the last LOAD_GLOBAL where we looked up getattr. ;)
                start = self.last_getattr
                # Grab and reuse the opcodes that produced the object:
                # everything after the 3-byte load of getattr, up to the
                # final 6 bytes (name constant + CALL_FUNCTION).
                bits = self.newcode[start + 3:-6]
                bits += ['LOAD_ATTR', pos & 0xFF, pos >> 8]
                bits = tuple(bits)
                self.put(start, len(self.newcode), *bits)
            
            # Don't form the new object.
            # Replace TOS with a dummy and taint it.
            self.stack.append(None)
            self.stack.taint()
            return None
        
        # Callback the transform.
        terms.reverse()
        if transform:
            result = transform(terms)
        else:
            result = terms
        
        # Replace TOS with result.
        self.stack.append(result)
        
        # Overwrite bytecodes with new CONST formed from result.
        pos = self.const_index(result)
        self.tail(overwrite_length, 'LOAD_CONST', pos & 0xFF, pos >> 8)
        
        return result
    
    def visit_BUILD_TUPLE(self, lo, hi):
        self.reduce(lo + (hi << 8), lambda terms: tuple(terms))
    
    def visit_BUILD_LIST(self, lo, hi):
        self.reduce(lo + (hi << 8))
    
    def visit_CALL_FUNCTION(self, lo, hi):
        # lo is the number of positional args and hi the number of
        # keyword args. The stack holds the callable, the positional
        # args, then a (name, value) pair for each keyword arg, so each
        # keyword occupies two slots.
        def call(terms):
            func = terms[0]
            args = tuple(terms[1:lo + 1])
            pairs = terms[lo + 1:]
            kwargs = {}
            for i in range(0, len(pairs), 2):
                kwargs[pairs[i]] = pairs[i + 1]
            return func(*args, **kwargs)
        self.reduce(lo + 2 * hi + 1, call)
    
    def visit_COMPARE_OP(self, lo, hi):
        op = cmp_op[lo + (hi << 8)]
        op = comparisons[op]
        self.reduce(2, lambda terms: op(*terms))
    
    def visit_LOAD_ATTR(self, lo, hi):
        name = self.co_names[lo + (hi << 8)]
        result = self.reduce(1, lambda terms: getattr(terms[0], name))
        if result in self.bind_late or getattr(result, 'bind_late', False):
            self.stack.taint()
    
    def visit_LOAD_CONST(self, lo, hi):
        self.stack.append(self.co_consts[lo + (hi << 8)])
    
    def visit_LOAD_DEREF(self, lo, hi):
        if hasattr(self, '_func'):
            # name = self.co_freevars[lo + (hi << 8)]
            value = self._func.func_closure[lo + (hi << 8)]
            value = deref_cell(value)
            pos = self.const_index(value)
            self.tail(3, 'LOAD_CONST', pos & 0xFF, pos >> 8)
            self.stack.append(value)
            if value in self.bind_late or getattr(value, 'bind_late', False):
                self.stack.taint()
    
    def visit_LOAD_FAST(self, lo, hi):
        self.stack.append(self.co_varnames[lo + (hi << 8)])
        # LOAD_FAST references our bound variable, which is always bound late.
        self.stack.taint()
    
    def visit_LOAD_GLOBAL(self, lo, hi):
        name = self.co_names[lo + (hi << 8)]
        if name == 'getattr':
            # Remember where this load starts; reduce() needs it to
            # rewrite getattr(x, 'y') as x.y.
            self.last_getattr = (len(self.newcode) - 3)
        if name in self.env:
            value = self.env[name]
            pos = self.const_index(value)
            self.tail(3, 'LOAD_CONST', pos & 0xFF, pos >> 8)
            self.stack.append(value)
            if value in self.bind_late or getattr(value, 'bind_late', False):
                self.stack.taint()
        else:
            raise KeyError("'%s' is not present in supplied globals." % name)
    
    # The SLICE+n opcodes take no argument, so each is 1 byte long; the
    # explicit overwrite lengths are 3 bytes per operand plus that byte.
    def visit_SLICE_PLUS_0(self):
        self.reduce(1, lambda terms: terms[0][:], 4)
    
    def visit_SLICE_PLUS_1(self):
        self.reduce(2, lambda terms: terms[0][terms[1]:], 7)
    
    def visit_SLICE_PLUS_2(self):
        self.reduce(2, lambda terms: terms[0][:terms[1]], 7)
    
    def visit_SLICE_PLUS_3(self):
        self.reduce(3, lambda terms: terms[0][terms[1]:terms[2]], 10)
    
    def binary_op(self, op):
        """Fold a two-operand opcode by applying 'op' to its operands."""
        def operate(terms):
            return op(*terms)
        # Two 3-byte operand loads plus the 1-byte operator opcode.
        self.reduce(2, operate, 7)
693
# Give EarlyBinder a visit_<OPCODE> method for every binary operator and
# every in-place operator. In-place opcodes deliberately go through
# binary_op as well.
for _table in (binary_operators, inplace_operators):
    for k, v in _table.iteritems():
        # Bind the operator as a default argument so that each lambda
        # keeps its own function rather than the loop's last value.
        setattr(EarlyBinder, "visit_" + k,
                lambda self, opr=v: self.binary_op(opr))
del _table
702
703
704 class MapStackObject(dict):
705    
706     def __add__(self, other):
707         if isinstance(other, basestring):
708             return repr(self) + other
709         return dict.__add__(self, other)
710    
711     def __repr__(self):
712         atoms = []
713         for k, v in self.iteritems():
714             atoms.append("%s: %s" % (k, v))
715         return "{%s}" % ", ".join(atoms)
716
717
718 class LambdaDecompiler(Visitor):
719     """LambdaDecompiler(obj=lambda function or func_code).
720     
721     Produce decompiled Python code (as a string) from a supplied lambda."""
722    
723     def __init__(self, func, env=None):
724         Visitor.__init__(self, func)
725         if env is None:
726             self.env = {}
727         else:
728             self.env = env.copy()
729         import __builtin__
730         self.env.update(vars(__builtin__))
731         self.env.update(func.func_globals)
732    
733     def code(self, include_func_header=True):
734         self.walk()
735         product = self.stack[0]
736         if include_func_header:
737             args = list(self.co_varnames)
738             if self.co_flags & CO_VARKEYWORDS:
739                 args[-1] = "**" + args[-1]
740                 if self.co_flags & CO_VARARGS:
741                     args[-2] = "*" + args[-2]
742             elif self.co_flags & CO_VARARGS:
743                 args[-1] = "*" + args[-1]
744             args = ", ".join(args)
745            
746             product = "lambda %s: %s" % (args, product)
747         return product
748    
749     def walk(self):
750         self.stack = []
751         self.targets = {}
752        
753         Visitor.walk(self)
754        
755         if self.verbose:
756             self.debug("stack:", self.stack)
757    
    def visit_instruction(self, op, lo=None, hi=None):
        """Fold pending 'and'/'or' terms into TOS at a jump target.
        
        Called for every instruction before its specific handler. If the
        current instruction's offset has entries in self.targets, they
        are combined with the top of the stack into one parenthesized
        clause, which replaces the top of the stack.
        """
        # Get the instruction pointer for the current instruction.
        # self.cursor is already past it: 3 bytes if it has an argument
        # (lo and hi given), otherwise 1 byte.
        ip = self.cursor - 3
        if hi is None:
            ip += 1
            if lo is None:
                ip += 1
        
        # This is where we do folding of logical AND and OR operators.
        # The Python code just writes "a AND b", but the VM (bytecode)
        # acts more like assembly, using conditional JUMP instructions to
        # implement logical operators. The map stored in self.targets is
        # of the form:
        #     {JUMP target: [(self.stack[-1], 'and'), ...]}
        # where "JUMP target" is the instruction number of the bytecode
        # which is the target of the JUMP, and each item in the value list
        # is a tuple of (top of the calling stack, operation).
        # It's a list because a single bytecode may be the target of
        # multiple JUMP instructions.
        # See visit_JUMP_IF_FALSE / TRUE.
        terms = self.targets.get(ip)
        if terms:
            clause = self.stack[-1]
            # Popping consumes the entries stored for this target, most
            # recently recorded first.
            while terms:
                term, oper = terms.pop()
                clause = "(%s) %s (%s)" % (term, oper, clause)
            # Replace TOS with the new clause, so that further
            # combinations have access to it.
            self.stack[-1] = clause
            if self.verbose:
                self.debug("clause:", clause, "\n")
            
            if op == 1:
                # Py2.4: The current instruction is POP_TOP, which means
                # the previous is probably JUMP_*. If so, we don't want to
                # pop the value we just placed on the stack and lose it.
                # We need to replace the entry that the JUMP_* made in
                # self.targets with our new TOS.
                target = self.targets[self.last_target_ip]
                target[-1] = ((clause, target[-1][1]))
                if self.verbose:
                    self.debug("newtarget:", self.last_target_ip, target)
800    
801     def visit_BUILD_LIST(self, lo, hi):
802         terms = [str(self.stack.pop()) for i in range(lo + (hi << 8))]
803         terms.reverse()
804         self.stack.append("[%s]" % ", ".join(terms))
805    
806     def visit_BUILD_MAP(self, lo, hi):
807         # We're actually going to put a non-string object on the stack here,
808         # with the expectation that the next bytecodes will populate it.
809         self.stack.append(MapStackObject())
810    
811     def visit_BUILD_TUPLE(self, lo, hi):
812         terms = [str(self.stack.pop()) for i in range(lo + (hi << 8))]
813         terms.reverse()
814         self.stack.append("(%s)" % ", ".join(terms))
815    
816     def visit_CALL_FUNCTION(self, lo, hi):
817         kwargs = {}
818         for i in range(hi):
819             val = self.stack.pop()
820             key = self.stack.pop()
821             kwargs[key] = val
822         kwargs = ", ".join(["%s=%s" % (k, v) for k, v in kwargs.iteritems()])
823        
824         args = []
825         for i in xrange(lo):
826             arg = self.stack.pop()
827             args.append(arg)
828         args.reverse()
829         args = ", ".join([str(x) for x in args])
830        
831         if kwargs:
832             args += ", " + kwargs
833        
834         func = self.stack.pop()
835         self.stack.append("%s(%s)" % (func, args))
836    
837     def visit_COMPARE_OP(self, lo, hi):
838         term2, term1 = self.stack.pop(), self.stack.pop()
839         op = cmp_op[lo + (hi << 8)]
840         self.stack.append(term1 + " " + op + " " + term2)
841         if self.verbose:
842             self.debug(op)
843    
844     def visit_DUP_TOP(self):
845         self.stack.append(self.stack[-1])
846    
847     def visit_JUMP_IF_FALSE(self, lo, hi):
848         # Note that self.cursor has already advanced to the next instruction.
849         target = self.cursor + (lo + (hi << 8))
850         bucket = self.targets.setdefault(target, [])
851         bucket.append((self.stack[-1], 'and'))
852         if self.verbose:
853             self.debug("target:", target, bucket)
854         # Store target ip for the special code in visit_instruction
855         self.last_target_ip = target
856    
857     def visit_JUMP_IF_TRUE(self, lo, hi):
858         # Note that self.cursor has already advanced to the next instruction.
859         target = self.cursor + (lo + (hi << 8))
860         bucket = self.targets.setdefault(target, [])
861         bucket.append((self.stack[-1], 'or'))
862         if self.verbose:
863             self.debug("target:", target, bucket)
864         # Store target ip for the special code in visit_instruction
865         self.last_target_ip = target
866    
867     def visit_LOAD_ATTR(self, lo, hi):
868         term = self.co_names[lo + (hi << 8)]
869         self.stack[-1] += ("." + term)
870         if self.verbose:
871             self.debug(term)
872    
873     def visit_LOAD_CONST(self, lo, hi):
874         val = self.co_consts[lo + (hi << 8)]
875         mod = getattr(val, "__module__", None)
876         if isinstance(val, (types.FunctionType, type)):
877             # The const in question is a factory function, like int or date.
878             name = val.__name__
879             if name in self.env:
880                 term = name
881             else:
882                 term = mod + "." + name
883         else:
884             term = repr(val)
885             if mod and not mod.startswith("__"):
886                 if not term.startswith(mod + "."):
887                     term = mod + "." + term
888         self.stack.append(term)
889         if self.verbose:
890             self.debug(term)
891    
892     def visit_LOAD_FAST(self, lo, hi):
893         term = self.co_varnames[lo + (hi << 8)]
894         self.stack.append(term)
895         if self.verbose:
896             self.debug(term)
897    
898     def visit_LOAD_GLOBAL(self, lo, hi):
899         self.stack.append(self.co_names[lo + (hi << 8)])
900    
901     def visit_POP_TOP(self):
902         self.stack.pop()
903    
904     def visit_ROT_TWO(self):
905         v = self.stack.pop()
906         k = self.stack.pop()
907         self.stack.extend([v, k])
908    
909     def visit_ROT_THREE(self):
910         v = self.stack.pop()
911         k = self.stack.pop()
912         x = self.stack.pop()
913         self.stack.extend([v, x, k])
914    
915     def visit_SLICE_PLUS_0(self):
916         arg = self.stack.pop()
917         self.stack.append("%s[:]" % arg)
918    
919     def visit_SLICE_PLUS_1(self):
920         args = tuple(self.stack[-2:])
921         del self.stack[-2:]
922         self.stack.append("%s[%s:]" % args)
923    
924     def visit_SLICE_PLUS_2(self):
925         args = tuple(self.stack[-2:])
926         del self.stack[-2:]
927         self.stack.append("%s[:%s]" % args)
928    
929     def visit_SLICE_PLUS_3(self):
930         args = tuple(self.stack[-3:])
931         del self.stack[-3:]
932         self.stack.append("%s[%s:%s]" % args)
933    
934     def visit_STORE_SUBSCR(self):
935         k = self.stack.pop()
936         x = self.stack.pop()
937         v = self.stack.pop()
938         x[k] = v
939    
940     def visit_UNARY_CONVERT(self):
941         term = self.stack.pop()
942         self.stack.append("`(" + term + ")`")
943    
944     def visit_UNARY_INVERT(self):
945         term = self.stack.pop()
946         self.stack.append("~(" + term + ")")
947    
948     def visit_UNARY_NEGATIVE(self):
949         term = self.stack.pop()
950         self.stack.append("-(" + term + ")")
951    
952     def visit_UNARY_NOT(self):
953         term = self.stack.pop()
954         self.stack.append("not (" + term + ")")
955    
956     def visit_UNARY_POSITIVE(self):
957         term = self.stack.pop()
958         self.stack.append("+(" + term + ")")
959    
960     def binary_op(self, op):
961         op2, op1 = self.stack.pop(), self.stack.pop()
962         self.stack.append(op1 + " " + op + " " + op2)
963    
964     def visit_BINARY_SUBSCR(self):
965         op2, op1 = self.stack.pop(), self.stack.pop()
966         self.stack.append(op1 + "[" + op2 + "]")
967
# Add visit_BINARY methods to LambdaDecompiler.
# Every key of binary_repr becomes a "visit_<key>" method that forwards the
# mapped value to binary_op.  The value is bound as a default argument
# (op=v) so that each lambda keeps its own operator instead of whatever v
# holds when the lambda is eventually called.
# The loop variables k and v stay at module level; the "del k, v" statement
# at the end of the file removes them.
for k, v in binary_repr.iteritems():
    setattr(LambdaDecompiler, "visit_" + k,
            lambda self, op=v: self.binary_op(op))
972
973
class GenexpDecompiler(Visitor):
    """GenexpDecompiler(obj=generator expression).
    
    Produce decompiled Python code (a string) from a generator expression.
    
    Progress through the bytecode is tracked in self.stage:
        0: set by code() before walking
        1: SETUP_LOOP has been visited
        2: FOR_ITER has been visited
        3: YIELD_VALUE has been visited
        4: POP_BLOCK has been visited
    """
    
    def __init__(self, obj, env=None):
        """Prepare to decompile a generator (or its frame).
        
        obj: a generator object or a frame object.
        env: optional dict of names; it is copied, then updated with the
            builtins and with the frame's globals.
        
        Raises TypeError if obj is neither a generator nor a frame.
        """
        if isinstance(obj, types.GeneratorType):
            frame = obj.gi_frame
        elif isinstance(obj, types.FrameType):
            frame = obj
        else:
            # Fail with a clear message instead of an unbound-local error.
            raise TypeError("GenexpDecompiler requires a generator or frame "
                            "object, not %r." % (obj,))
        fcode = frame.f_code
        
        Visitor.__init__(self, fcode)
        
        if env is None:
            self.env = {}
        else:
            self.env = env.copy()
        import __builtin__
        self.env.update(vars(__builtin__))
        self.env.update(frame.f_globals)
        # The iterable the genexp loops over, as stored in the frame locals.
        self.source = frame.f_locals['[outmost-iterable]']
    
    def code(self):
        """Walk the bytecode and return the decompiled genexp as a string.
        
        The result has the form "(<attrs> for <names> in <source> if <cond>)";
        the " if <cond>" part is left out when no condition was found.
        """
        self.stage = 0
        self.ifexpr = ""
        self.attrs = ""
        
        self.walk()
        
        # Skip co_varnames[0] (presumably the '[outmost-iterable]' argument
        # -- TODO confirm) and list the remaining names in reverse order.
        names = list(self.co_varnames)[1:]
        names.reverse()
        names = ', '.join(names)
        
        if isinstance(self.source, type(iter([]))):
            # A list iterator is rendered as a list display of its remaining
            # items. Note that this consumes the iterator.
            self.source = "[%s]" % ", ".join([repr(x) for x in self.source])
        
        if self.ifexpr:
            return ("(%s for %s in %s if %s)" %
                    (self.attrs, names, self.source, self.ifexpr))
        # No condition: do not emit a dangling "if".
        return "(%s for %s in %s)" % (self.attrs, names, self.source)
    
    def walk(self):
        """Reset the decompiler state and walk the bytecode."""
        self.stack = []
        self.newcode = []
        self.targets = {}
        
        Visitor.walk(self)
        
        if self.verbose:
            self.debug("stack:", self.stack)
    
    def visit_instruction(self, op, lo=None, hi=None):
        """Fold pending 'and'/'or' terms when a jump target is reached."""
        # Get the instruction pointer for the current instruction.
        # self.cursor has already advanced past it: by 3 bytes when the
        # instruction carries an argument (lo, hi), fewer otherwise.
        ip = self.cursor - 3
        if hi is None:
            ip += 1
            if lo is None:
                ip += 1
        
        # This is where we do folding of logical AND and OR operators.
        # The Python code just writes "a AND b", but the VM (bytecode)
        # acts more like assembly, using conditional JUMP instructions to
        # implement logical operators. The map stored in self.targets is
        # of the form:
        #     {JUMP target: [(self.stack[-1], 'and'), ...]}
        # where "JUMP target" is the instruction number of the bytecode
        # which is the target of the JUMP, and each item in the value list
        # is a tuple of (top of the calling stack, operation).
        # It's a list because a single bytecode may be the target of
        # multiple JUMP instructions.
        # See visit_JUMP_IF_FALSE / TRUE.
        terms = self.targets.get(ip)
        if terms:
            if self.stage == 3:
                # 'terms' is storing the complete 'if' portion of the genexp.
                clause, unnecessary_oper = terms.pop()
                while terms:
                    term, oper = terms.pop()
                    clause = "(%s) %s (%s)" % (term, oper, clause)
                self.ifexpr = clause
            else:
                clause = self.stack[-1]
                while terms:
                    term, oper = terms.pop()
                    clause = "(%s) %s (%s)" % (term, oper, clause)
                
                # Replace TOS with the new clause, so that further
                # combinations have access to it.
                self.stack[-1] = clause
                if self.verbose:
                    self.debug("clause:", clause, "\n")
                
                if op == 1:
                    # Py2.4: The current instruction is POP_TOP, which means
                    # the previous is probably JUMP_*. If so, we don't want to
                    # pop the value we just placed on the stack and lose it.
                    # We need to replace the entry that the JUMP_* made in
                    # self.targets with our new TOS.
                    target = self.targets[self.last_target_ip]
                    target[-1] = ((clause, target[-1][1]))
                    if self.verbose:
                        self.debug("newtarget:", self.last_target_ip, target)
    
    def visit_BUILD_LIST(self, lo, hi):
        """Fold the top ``lo + (hi << 8)`` entries into "[a, b, ...]"."""
        terms = [str(self.stack.pop()) for i in range(lo + (hi << 8))]
        terms.reverse()
        self.stack.append("[%s]" % ", ".join(terms))
    
    def visit_BUILD_MAP(self, lo, hi):
        """Push a fresh MapStackObject; the operand bytes are ignored."""
        # We're actually going to put a non-string object on the stack here,
        # with the expectation that the next bytecodes will populate it.
        self.stack.append(MapStackObject())
    
    def visit_BUILD_TUPLE(self, lo, hi):
        """Fold the top ``lo + (hi << 8)`` entries into "(a, b, ...)".
        
        A single entry is written as "(a,)"; without the trailing comma the
        result would be a parenthesized expression rather than a tuple.
        """
        terms = [str(self.stack.pop()) for i in range(lo + (hi << 8))]
        terms.reverse()
        if len(terms) == 1:
            self.stack.append("(%s,)" % terms[0])
        else:
            self.stack.append("(%s)" % ", ".join(terms))
    
    def visit_CALL_FUNCTION(self, lo, hi):
        """Fold a callable and its arguments into one "func(args)" string.
        
        lo: number of positional arguments on the stack.
        hi: number of keyword (name, value) pairs on the stack.
        
        Arguments are emitted in the order they were pushed, positional
        arguments first.
        """
        # Each keyword pair was pushed name first, so the value pops first.
        # NOTE(review): names are used exactly as found on the stack; if they
        # were pushed by visit_LOAD_CONST they are repr()'d (quoted) strings.
        keywords = []
        for i in range(hi):
            val = self.stack.pop()
            key = self.stack.pop()
            keywords.append("%s=%s" % (key, val))
        # Popping yields the last keyword first; restore the pushed order.
        keywords.reverse()
        
        args = [str(self.stack.pop()) for i in range(lo)]
        args.reverse()
        
        func = self.stack.pop()
        # Join everything in one pass so that a call with keywords but no
        # positional arguments does not start with a stray ", ".
        self.stack.append("%s(%s)" % (func, ", ".join(args + keywords)))
    
    def visit_COMPARE_OP(self, lo, hi):
        """Replace the top two entries with "TOS1 <cmp> TOS"."""
        term2, term1 = self.stack.pop(), self.stack.pop()
        op = cmp_op[lo + (hi << 8)]
        self.stack.append(term1 + " " + op + " " + term2)
        if self.verbose:
            self.debug(op)
    
    def visit_DUP_TOP(self):
        """Push a second reference to the current top-of-stack entry."""
        self.stack.append(self.stack[-1])
    
    def visit_JUMP_IF_FALSE(self, lo, hi):
        """Record TOS as the left side of a pending 'and' at the jump target."""
        # Note that self.cursor has already advanced to the next instruction.
        target = self.cursor + (lo + (hi << 8))
        bucket = self.targets.setdefault(target, [])
        bucket.append((self.stack[-1], 'and'))
        if self.verbose:
            self.debug("target:", target, bucket)
        # Store target ip for the special code in visit_instruction
        self.last_target_ip = target
    
    def visit_JUMP_IF_TRUE(self, lo, hi):
        """Record TOS as the left side of a pending 'or' at the jump target."""
        # Note that self.cursor has already advanced to the next instruction.
        target = self.cursor + (lo + (hi << 8))
        bucket = self.targets.setdefault(target, [])
        bucket.append((self.stack[-1], 'or'))
        if self.verbose:
            self.debug("target:", target, bucket)
        # Store target ip for the special code in visit_instruction
        self.last_target_ip = target
    
    def visit_LOAD_ATTR(self, lo, hi):
        """Append ".<name>" to TOS, where name is co_names[lo + (hi << 8)]."""
        term = self.co_names[lo + (hi << 8)]
        self.stack[-1] += ("." + term)
        if self.verbose:
            self.debug(term)
    
    def visit_LOAD_CONST(self, lo, hi):
        """Push a source-text rendering of co_consts[lo + (hi << 8)].
        
        Functions and types are rendered by name: bare when the name is a
        key of self.env, otherwise prefixed with their module. Any other
        constant is rendered with repr(), prefixed with its module when it
        has one that does not start with "__" and the repr does not already
        begin with that module name.
        """
        val = self.co_consts[lo + (hi << 8)]
        mod = getattr(val, "__module__", None)
        if isinstance(val, (types.FunctionType, type)):
            # The const in question is a factory function, like int or date.
            name = val.__name__
            if name in self.env:
                term = name
            else:
                term = mod + "." + name
        else:
            term = repr(val)
            if mod and not mod.startswith("__"):
                if not term.startswith(mod + "."):
                    term = mod + "." + term
        self.stack.append(term)
        if self.verbose:
            self.debug(term)
    
    def visit_LOAD_FAST(self, lo, hi):
        """Push the local name co_varnames[lo + (hi << 8)].
        
        Nothing is pushed while self.stage == 1, i.e. between SETUP_LOOP
        and FOR_ITER.
        """
        if self.stage == 1:
            return
        
        term = self.co_varnames[lo + (hi << 8)]
        self.stack.append(term)
        if self.verbose:
            self.debug(term)
    
    def visit_LOAD_GLOBAL(self, lo, hi):
        """Push the global name co_names[lo + (hi << 8)]."""
        self.stack.append(self.co_names[lo + (hi << 8)])
    
    def visit_POP_TOP(self):
        """Discard TOS, but only before YIELD_VALUE has been visited."""
        if self.stage < 3:
            self.stack.pop()
    
    def visit_ROT_TWO(self):
        """Swap the two topmost stack entries."""
        v = self.stack.pop()
        k = self.stack.pop()
        self.stack.extend([v, k])
    
    def visit_ROT_THREE(self):
        """Move TOS down to third place, lifting the next two entries."""
        v = self.stack.pop()
        k = self.stack.pop()
        x = self.stack.pop()
        self.stack.extend([v, x, k])
    
    def visit_SLICE_PLUS_0(self):
        """Replace TOS with "TOS[:]"."""
        arg = self.stack.pop()
        self.stack.append("%s[:]" % arg)
    
    def visit_SLICE_PLUS_1(self):
        """Replace the top two entries with "TOS1[TOS:]"."""
        args = tuple(self.stack[-2:])
        del self.stack[-2:]
        self.stack.append("%s[%s:]" % args)
    
    def visit_SLICE_PLUS_2(self):
        """Replace the top two entries with "TOS1[:TOS]"."""
        args = tuple(self.stack[-2:])
        del self.stack[-2:]
        self.stack.append("%s[:%s]" % args)
    
    def visit_SLICE_PLUS_3(self):
        """Replace the top three entries with "TOS2[TOS1:TOS]"."""
        args = tuple(self.stack[-3:])
        del self.stack[-3:]
        self.stack.append("%s[%s:%s]" % args)
    
    def visit_STORE_SUBSCR(self):
        """Pop key, container and value, then do container[key] = value."""
        k = self.stack.pop()
        x = self.stack.pop()
        v = self.stack.pop()
        x[k] = v
    
    def visit_UNARY_CONVERT(self):
        """Replace TOS with "`(TOS)`"."""
        term = self.stack.pop()
        self.stack.append("`(" + term + ")`")
    
    def visit_UNARY_INVERT(self):
        """Replace TOS with "~(TOS)"."""
        term = self.stack.pop()
        self.stack.append("~(" + term + ")")
    
    def visit_UNARY_NEGATIVE(self):
        """Replace TOS with "-(TOS)"."""
        term = self.stack.pop()
        self.stack.append("-(" + term + ")")
    
    def visit_UNARY_NOT(self):
        """Replace TOS with "not (TOS)"."""
        term = self.stack.pop()
        self.stack.append("not (" + term + ")")
    
    def visit_UNARY_POSITIVE(self):
        """Replace TOS with "+(TOS)"."""
        term = self.stack.pop()
        self.stack.append("+(" + term + ")")
    
    def binary_op(self, op):
        """Replace the top two entries with "TOS1 <op> TOS"."""
        op2, op1 = self.stack.pop(), self.stack.pop()
        self.stack.append(op1 + " " + op + " " + op2)
    
    def visit_BINARY_SUBSCR(self):
        """Replace the top two entries with "TOS1[TOS]"."""
        op2, op1 = self.stack.pop(), self.stack.pop()
        self.stack.append(op1 + "[" + op2 + "]")
    
    def visit_SETUP_LOOP(self, lo, hi):
        """Enter stage 1."""
        self.stage = 1
    
    def visit_FOR_ITER(self, lo, hi):
        """Enter stage 2 and remember the address of the FOR_ITER opcode."""
        self.stage = 2
        # self.cursor has already advanced past this 3-byte instruction.
        self.for_loop_address = self.cursor - 3
    
    def visit_UNPACK_SEQUENCE(self, lo, hi):
        """Jump the cursor over the instructions that store the unpacked names."""
        # Skip all the STORE_FAST opcodes that follow.
        numvars = lo + (hi << 8)
        self.cursor += (3 * numvars)
    
    def visit_YIELD_VALUE(self):
        """Enter stage 3."""
        self.stage = 3
    
    def visit_POP_BLOCK(self):
        """Take TOS as the yielded expression and enter stage 4."""
        self.attrs = self.stack.pop()
        self.stage = 4
1265
1266
# Add visit_BINARY methods to GenexpDecompiler.
# Every key of binary_repr becomes a "visit_<key>" method that forwards the
# mapped value to binary_op.  The value is bound as a default argument
# (op=v) so that each lambda keeps its own operator instead of whatever v
# holds when the lambda is eventually called.
# The loop variables k and v stay at module level; the "del k, v" statement
# at the end of the file removes them.
for k, v in binary_repr.iteritems():
    setattr(GenexpDecompiler, "visit_" + k,
            lambda self, op=v: self.binary_op(op))
1271
1272
1273
class BranchTracker(Visitor):
    """BranchTracker(obj=[func|co|str|list]).
    
    Finds all possible instructions previous to the supplied instruction(s).
    
    self.watch maps each watched instruction address to the list of
    addresses collected for it during the walk.
    """
    
    def branches(self, instr=None):
        """Walk self and return all possible instructions previous to instr.
        
        If instr is None, the last instruction will be used.
        """
        watched = instr
        if watched is None:
            watched = len(self._bytecode) - 1
        self.watch = {watched: []}
        self.walk()
        return self.watch[watched]
    
    def visit_instruction(self, op, lo=None, hi=None):
        """Record fall-through into a watched address.
        
        self.cursor is compared against the watched addresses; when it
        matches, the address 1 byte back (no operand) or 3 bytes back
        (with operand) is recorded. Opcode 113 (JUMP_ABSOLUTE in the
        CPython 2.x opcode table) is excluded here.
        """
        here = self.cursor
        if op == 113 or here not in self.watch:
            return
        if lo is None and hi is None:
            width = 1
        else:
            width = 3
        self.watch[here].append(here - width)
    
    def visit_CONTINUE_LOOP(self, lo, hi):
        """Treat CONTINUE_LOOP exactly like JUMP_ABSOLUTE."""
        self.visit_JUMP_ABSOLUTE(lo, hi)
    
    def visit_JUMP_ABSOLUTE(self, lo, hi):
        """Record a jump to the absolute address ``lo + (hi << 8)``."""
        destination = lo + (hi << 8)
        if destination not in self.watch:
            return
        # NOTE(review): this records self.cursor, whereas visit_JUMP_FORWARD
        # records self.cursor - 3 -- confirm which is intended.
        self.watch[destination].append(self.cursor)
    
    def visit_JUMP_FORWARD(self, lo, hi):
        """Record a jump to ``self.cursor + lo + (hi << 8)``."""
        destination = self.cursor + (lo + (hi << 8))
        if destination not in self.watch:
            return
        self.watch[destination].append(self.cursor - 3)
    
    def visit_JUMP_IF_FALSE(self, lo, hi):
        """Treat JUMP_IF_FALSE exactly like JUMP_FORWARD."""
        self.visit_JUMP_FORWARD(lo, hi)
    
    def visit_JUMP_IF_TRUE(self, lo, hi):
        """Treat JUMP_IF_TRUE exactly like JUMP_FORWARD."""
        self.visit_JUMP_FORWARD(lo, hi)
1317
1318
class KeywordInspector(Rewriter):
    """Produce a list of all keyword arguments expected."""
    
    def __init__(self, obj):
        """KeywordInspector(obj). List keyword arguments expected.
        
        Raises ValueError when the code does not have the CO_VARKEYWORDS
        flag set, or when it has at most one varname.
        """
        Rewriter.__init__(self, obj)
        accepts_kwargs = self.co_flags & CO_VARKEYWORDS
        if not accepts_kwargs:
            raise ValueError("'%s' does not possess **kwargs." % obj)
        if not len(self.co_varnames) > 1:
            raise ValueError("'%s' does not possess more than 1 varname." % obj)
        self._kwargs = []
        self.flag = None
    
    def kwargs(self):
        """kwargs() -> List of keyword arguments expected."""
        self.walk()
        return self._kwargs
    
    def visit_instruction(self, op, lo=None, hi=None):
        """Track the three-opcode sequence that subscripts the last varname.
        
        self.flag is '' after opcode 124 loads the last entry of
        co_varnames, becomes the constant loaded by a directly following
        opcode 100, and is appended to self._kwargs when opcode 25 follows
        with a truthy flag. Anything else resets the flag to None.
        (In the CPython 2.x opcode table 124, 100 and 25 are LOAD_FAST,
        LOAD_CONST and BINARY_SUBSCR.)
        """
        last_varname = len(self.co_varnames) - 1
        if op == 124 and lo + (hi << 8) == last_varname:
            self.flag = ''
            return
        if op == 100 and self.flag == '':
            self.flag = self.co_consts[lo + (hi << 8)]
            return
        if op == 25 and self.flag:
            self._kwargs.append(self.flag)
            return
        self.flag = None
1346
1347
# Remove the loop variables left at module level by the
# "for k, v in binary_repr.iteritems()" loops earlier in this file.
del k, v
1349
Note: See TracBrowser for help on using the browser.