Contact: fumanchu@aminus.org

Log in as guest/geniusql to create tickets

root/trunk/geniusql/codewalk.py

Revision 215 (checked in by fumanchu, 2 years ago)

Allow named_opcodes to take a string (e.g. a co_code string).

  • Property svn:eol-style set to native
Line 
1 """Bytecode visitors, including rewriters and decompilers.
2
3 This work, including the source code, documentation
4 and related data, is placed into the public domain.
5
6 The original author is Robert Brewer.
7
8 THIS SOFTWARE IS PROVIDED AS-IS, WITHOUT WARRANTY
9 OF ANY KIND, NOT EVEN THE IMPLIED WARRANTY OF
10 MERCHANTABILITY. THE AUTHOR OF THIS SOFTWARE
11 ASSUMES _NO_ RESPONSIBILITY FOR ANY CONSEQUENCE
12 RESULTING FROM THE USE, MODIFICATION, OR
13 REDISTRIBUTION OF THIS SOFTWARE.
14
15 """
16
17 from opcode import cmp_op, opname, opmap, HAVE_ARGUMENT
18 _visit_map = [name.replace('+', '_PLUS_') for name in opname]
19
20 import operator
21
22 try:
23     # Builtin in Python 2.4+
24     set
25 except NameError:
26     try:
27         # Module in Python 2.3
28         from sets import Set as set
29     except ImportError:
30         set = None
31
32 import types
33
34 from compiler.consts import *
# co_flags bit (0x40) asserting that a code object has no free variables;
# EarlyBinder sets it on the code objects it produces.
CO_NOFREE = 64
36
37
38 def named_opcodes(bits):
39     """Change initial numeric opcode bits to their named equivalents."""
40     if isinstance(bits, basestring):
41         bits = map(ord, bits)
42     bitnums = []
43     bits = iter(bits)
44     for x in bits:
45         bitnums.append(opname[x])
46         if x >= HAVE_ARGUMENT:
47             try:
48                 bitnums.append(bits.next())
49                 bitnums.append(bits.next())
50             except StopIteration:
51                 break
52     return bitnums
53
54 def numeric_opcodes(bits):
55     """Change named opcode bits to their numeric equivalents."""
56     bitnums = []
57     for x in bits:
58         if isinstance(x, basestring):
59             x = opmap[x]
60         bitnums.append(x)
61     return bitnums
62
# Bytecode for a tiny function body: "return <free variable 0>".
_deref_bytecode = [opmap['LOAD_DEREF'], 0, 0, opmap['RETURN_VALUE']]

# A code object with no arguments or locals, stack size 1, flags 3, and a
# single free variable named 'cell'. Positional arguments, in order:
#   argcount, nlocals, stacksize, flags, codestring, constants,
#   names, varnames, filename, name, firstlineno, lnotab, freevars
_derefblock = types.CodeType(
    0, 0, 1, 3,
    ''.join([chr(byte) for byte in _deref_bytecode]),
    (None,), ('cell',), (),
    '', '', 2, '',
    ('cell',))


def deref_cell(cell):
    """Return the value held by 'cell' (an item of a func_closure tuple).

    Wraps _derefblock in a throwaway function whose closure is (cell,) and
    calls it, so its LOAD_DEREF reads the cell's contents.
    """
    # FunctionType(code, globals[, name[, argdefs[, closure]]])
    throwaway = types.FunctionType(_derefblock, {}, "", (), (cell,))
    return throwaway()
73
74 def make_closure(*args):
75     def inner():
76         args
77     return inner.func_closure
78
79
# Opcode name -> function performing that binary operation. EarlyBinder uses
# these to constant-fold operations whose operands are all known.
binary_operators = {'BINARY_POWER': operator.pow,
                    'BINARY_MULTIPLY': operator.mul,
                    # operator.div only exists where "/" may mean classic
                    # division; elsewhere "/" is true division.
                    'BINARY_DIVIDE': getattr(operator, 'div', operator.truediv),
                    'BINARY_FLOOR_DIVIDE': operator.floordiv,
                    'BINARY_TRUE_DIVIDE': operator.truediv,
                    'BINARY_MODULO': operator.mod,
                    'BINARY_ADD': operator.add,
                    'BINARY_SUBTRACT': operator.sub,
                    'BINARY_SUBSCR': operator.getitem,
                    'BINARY_LSHIFT': operator.lshift,
                    'BINARY_RSHIFT': operator.rshift,
                    'BINARY_AND': operator.and_,
                    'BINARY_XOR': operator.xor,
                    'BINARY_OR': operator.or_,
                    }
# The matching in-place opcodes, e.g. 'BINARY_FLOOR_DIVIDE' becomes
# 'INPLACE_FLOOR_DIVIDE'. Split on the first underscore only, so multi-word
# names stay whole; splitting on every "_" produced 'INPLACE_FLOOR' and
# 'INPLACE_TRUE', which are not real opcode names.
inplace_operators = dict([('INPLACE_' + k.split('_', 1)[1], v)
                          for k, v in binary_operators.items()
                          if k not in ('BINARY_SUBSCR',)
                          ])

# Opcode name -> infix source-code spelling of the operator.
binary_repr = {'BINARY_POWER': '**',
               'BINARY_MULTIPLY': '*',
               'BINARY_DIVIDE': '/',
               'BINARY_FLOOR_DIVIDE': '//',
               'BINARY_TRUE_DIVIDE': '/',
               'BINARY_MODULO': '%',
               'BINARY_ADD': '+',
               'BINARY_SUBTRACT': '-',
               'BINARY_LSHIFT': '<<',
               'BINARY_RSHIFT': '>>',
               'BINARY_AND': '&',
               'BINARY_XOR': '^',
               'BINARY_OR': '|',
               }

inplace_repr = dict([('INPLACE_' + k.split('_', 1)[1], v + '=')
                     for k, v in binary_repr.items()])

# cmp_op spelling -> function(left, right). The callers pass operands in
# source order, as COMPARE_OP sees them on the stack.
comparisons = {'<': operator.lt,
               '<=': operator.le,
               '==': operator.eq,
               '!=': operator.ne,
               '>': operator.gt,
               '>=': operator.ge,
               # operator.contains(a, b) computes "b in a", which is the
               # reverse of "left in right", so spell the test out.
               'in': lambda x, y: x in y,
               'not in': lambda x, y: x not in y,
               'is': operator.is_,
               'is not': operator.is_not,
               }
129
# Cache the co_* attribute names of a code object, each mapped to the type of
# its value, sampled from deref_cell's code object. Visitor copies these
# attributes and turns the tuple-typed ones into lists.
_co_code_attrs = dict([(attr, type(getattr(deref_cell.func_code, attr)))
                       for attr in dir(deref_cell.func_code)
                       if attr.startswith("co_")])
135
136
137 class Visitor(object):
138     """A visitor class for bytecode sequences.
139     
140     obj: a function, code object, string, or list of opcodes.
141     """
142    
143     def __init__(self, obj):
144         self.verbose = False
145        
146         # Distill supplied 'obj' arg to a code block string.
147         if isinstance(obj, types.MethodType):
148             obj = obj.im_func
149         if isinstance(obj, types.FunctionType):
150             self._func = obj
151             obj = obj.func_code
152         if hasattr(types, 'GeneratorType') and isinstance(obj, types.GeneratorType):
153             obj = obj.gi_frame.f_code
154        
155         # Copy code object attributes (if present).
156         selfdict = self.__dict__
157         try:
158             for name, _type in _co_code_attrs.iteritems():
159                 value = getattr(obj, name)
160                 if _type is tuple:
161                     value = list(value)
162                 selfdict[name] = value
163         except AttributeError:
164             pass
165        
166         try:
167             obj = obj.co_code
168         except AttributeError:
169             pass
170        
171         # Map the code block string to a list of opcode numbers.
172         if isinstance(obj, basestring):
173             bytecode = map(ord, obj)
174         elif isinstance(obj, list):
175             bytecode = obj[:]
176         else:
177             raise TypeError("obj arg of incorrect type '%s'" % type(obj))
178        
179         self._bytecode = bytecode
180    
181     def debug(self, *messages):
182         for term in messages:
183             print term,
184    
185     def walk(self):
186         verbose = self.verbose
187        
188         self.cursor = 0
189         b = self._bytecode
190         if verbose:
191             self.debug("\n\nWALKING: ", b)
192         b_len = len(b)      # Speed hack
193         while self.cursor < b_len:
194             if verbose:
195                 self.debug("\n", self.cursor)
196            
197             op = b[self.cursor]
198             self.cursor += 1
199             if op >= HAVE_ARGUMENT:
200                 lo = b[self.cursor]
201                 self.cursor += 1
202                 hi = b[self.cursor]
203                 self.cursor += 1
204                 args = (lo, hi)
205             else:
206                 args = ()
207            
208             if verbose:
209                 self.debug("visit (%s, %s)" % (op, repr(args)))
210             self.visit_instruction(op, *args)
211            
212             instruction = _visit_map[op]
213             handler = getattr(self, 'visit_' + instruction, None)
214             if handler:
215                 if verbose:
216                     self.debug("=> %s%s" % (instruction, repr(args)))
217                 handler(*args)
218                 if verbose:
219                     self.debug("\n    %r" % self.stack)
220    
221     def visit_instruction(self, op, lo=None, hi=None):
222         pass
223
224
class JumpCodeAdjuster(Visitor):
    """JumpCodeAdjuster(obj=[func|co|str|list], start, end, newlength).

    Adjusts jump codes if their target is affected by bytecode changes.

    start, end: The range of the original bytecode in question.
    newlength: Length of the codes which overwrote bytecode[start:end].

    Absolute-target jumps are JUMP_ABSOLUTE and CONTINUE_LOOP, plus, on
    Python 2.7, POP_JUMP_IF_FALSE/TRUE and JUMP_IF_FALSE/TRUE_OR_POP.
    Relative (delta) jumps are JUMP_FORWARD, JUMP_IF_FALSE/TRUE, FOR_ITER
    and the SETUP_* block openers. Handlers for opcodes the running
    interpreter does not have are simply never called.
    """

    def __init__(self, obj, start, end, newlength):
        Visitor.__init__(self, obj)
        self.start = start
        self.end = end
        self.offset = newlength - (end - start)

    def bytecode(self):
        """Walk self and return new bytecode."""
        self.walk()
        return self.newcode

    def walk(self):
        if self.offset == 0:
            # Avoid costly walk if no changes will be made.
            self.newcode = self._bytecode
        else:
            self.newcode = []
            Visitor.walk(self)

    def visit_instruction(self, op, lo=None, hi=None):
        # Copy every instruction; jump handlers then patch the last 2 bytes.
        append = self.newcode.append
        append(op)
        if lo is not None:
            append(lo)
        if hi is not None:
            append(hi)

    # ----- Absolute jumps: the argument is a bytecode offset. -----

    def visit_JUMP_ABSOLUTE(self, lo, hi):
        target = lo + (hi << 8)
        if target > self.start:
            pos = target + self.offset
            self.newcode[-2:] = [pos & 0xFF, pos >> 8]

    def visit_CONTINUE_LOOP(self, lo, hi):
        self.visit_JUMP_ABSOLUTE(lo, hi)

    def visit_POP_JUMP_IF_FALSE(self, lo, hi):
        self.visit_JUMP_ABSOLUTE(lo, hi)

    def visit_POP_JUMP_IF_TRUE(self, lo, hi):
        self.visit_JUMP_ABSOLUTE(lo, hi)

    def visit_JUMP_IF_FALSE_OR_POP(self, lo, hi):
        self.visit_JUMP_ABSOLUTE(lo, hi)

    def visit_JUMP_IF_TRUE_OR_POP(self, lo, hi):
        self.visit_JUMP_ABSOLUTE(lo, hi)

    # ----- Relative jumps: the argument is a delta from the next instruction. -----

    def visit_JUMP_FORWARD(self, lo, hi):
        delta = lo + (hi << 8)
        # self.cursor already points at the next instruction.
        target = self.cursor + delta
        if self.cursor < self.end and target > self.start:
            pos = (target + self.offset) - self.cursor
            self.newcode[-2:] = [pos & 0xFF, pos >> 8]

    def visit_JUMP_IF_FALSE(self, lo, hi):
        self.visit_JUMP_FORWARD(lo, hi)

    def visit_JUMP_IF_TRUE(self, lo, hi):
        self.visit_JUMP_FORWARD(lo, hi)

    def visit_FOR_ITER(self, lo, hi):
        self.visit_JUMP_FORWARD(lo, hi)

    def visit_SETUP_LOOP(self, lo, hi):
        self.visit_JUMP_FORWARD(lo, hi)

    def visit_SETUP_EXCEPT(self, lo, hi):
        self.visit_JUMP_FORWARD(lo, hi)

    def visit_SETUP_FINALLY(self, lo, hi):
        self.visit_JUMP_FORWARD(lo, hi)

    def visit_SETUP_WITH(self, lo, hi):
        self.visit_JUMP_FORWARD(lo, hi)
282
283
def safe_tuple(seq):
    """Return 'seq' as a tuple of plain str objects.

    Several code-object fields (co_names, co_varnames, co_freevars,
    co_cellvars) must be tuples rather than lists. They also *cannot* hold
    unicode items: those must be cast to strings or the interpreter will
    crash.
    """
    return tuple([str(item) for item in seq])
293
294
class Rewriter(Visitor):
    """Rewriter(obj=function or code object).

    Produce a new function or code object by rewriting an existing one.

    Notice that, unlike the base Visitor class, Rewriter does not accept a
    string or list of opcodes as an initial argument.

    walk() copies each instruction into self.newcode. Subclass visit_*
    handlers may then overwrite the tail of newcode via put() or tail().
    """

    def walk(self):
        self.newcode = []
        Visitor.walk(self)

    def visit_instruction(self, op, lo=None, hi=None):
        # Copy the instruction through: the opcode, then any argument bytes.
        out = self.newcode
        out.append(op)
        out.extend([byte for byte in (lo, hi) if byte is not None])

    def bytecode(self):
        """Walk self and return new bytecode."""
        self.walk()
        return self.newcode

    def code_object(self):
        """Walk self and produce a new code object.

        The object is built from the co_* attributes copied off the original
        (possibly modified by handlers), with self.newcode as its bytecode.
        """
        self.walk()
        codestr = ''.join([chr(byte) for byte in self.newcode])
        # Notice co_consts should *not* be safe_tupled: its items may be any
        # object, while the name tuples must hold plain strings.
        consts = tuple(self.co_consts)
        names = safe_tuple(self.co_names)
        varnames = safe_tuple(self.co_varnames)
        freevars = safe_tuple(self.co_freevars)
        cellvars = safe_tuple(self.co_cellvars)
        return types.CodeType(self.co_argcount, self.co_nlocals,
                              self.co_stacksize, self.co_flags, codestr,
                              consts, names, varnames,
                              self.co_filename, self.co_name,
                              self.co_firstlineno, self.co_lnotab,
                              freevars, cellvars)

    def function(self, newname=None):
        """Walk self and produce a new function.

        newname: the new function's name. Defaults to the original function's
            name, or '' when the Rewriter was built from a bare code object
            (in which case the function gets empty globals).
        """
        if not hasattr(self, '_func'):
            if newname is None:
                newname = ''
            return types.FunctionType(self.code_object(), {}, newname)

        original = self._func
        if newname is None:
            newname = original.func_name
        return types.FunctionType(self.code_object(), original.func_globals,
                                  newname, original.func_defaults,
                                  original.func_closure)

    def const_index(self, value):
        """The index of value in co_consts, appending it if not found.

        A match requires the same type and equality; items whose comparison
        raises TypeError are skipped.
        """
        consts = self.co_consts
        for index, item in enumerate(consts):
            try:
                if type(value) == type(item) and value == item:
                    return index
            except TypeError:
                pass
        consts.append(value)
        return len(consts) - 1

    def name_index(self, value):
        """The index of value in co_names, appending it if not found."""
        names = self.co_names
        wanted_type = type(value)
        for index, item in enumerate(names):
            try:
                if wanted_type == type(item) and value == item:
                    return index
            except TypeError:
                pass
        names.append(value)
        return len(names) - 1

    def put(self, start, end, *bits):
        """Overwrite self.newcode with new opcodes (numbers or names).

        If the new codes are of different quantity than the old,
        modify any jump codes affected.
        """
        replacement = numeric_opcodes(bits)
        # Fix up jump arguments first, while newcode still has its old layout.
        adjuster = JumpCodeAdjuster(self.newcode, start, end, len(replacement))
        newcode = adjuster.bytecode()
        newcode[start:end] = replacement
        self.newcode = newcode

    def tail(self, length, *bits):
        """Overwrite self.newcode[-length:] with bits."""
        stop = len(self.newcode)
        self.put(stop - length, stop, *bits)
395
396
class Localizer(Rewriter):
    """Localizer(func, builtin_only=False, stoplist=None, verbose=False)

    If a global or builtin is known at compile time, replace it with a constant.

    func: the function whose LOAD_GLOBAL instructions are rewritten.
    builtin_only: if True, only builtins are bound. Otherwise the function's
        globals are bound too, taking precedence over same-named builtins.
    stoplist: a container of names that must not be bound (default: none).
    verbose: if True, print each replaced name along with its value.

    Values are captured once, when the Localizer is created.

    This duplicates (and borrows from) Raymond Hettinger's Cookbook recipe
    at: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/277940
    """
    def __init__(self, func, builtin_only=False, stoplist=None, verbose=False):
        Rewriter.__init__(self, func)

        import __builtin__
        self.env = vars(__builtin__).copy()
        if not builtin_only:
            self.env.update(func.func_globals)

        # Default to a fresh list instead of a shared mutable [] default.
        if stoplist is None:
            stoplist = []
        self.stoplist = stoplist
        self.verbose = verbose

    def visit_LOAD_GLOBAL(self, lo, hi):
        """Replace this LOAD_GLOBAL with a LOAD_CONST of the name's value."""
        name = self.co_names[lo + (hi << 8)]
        if name in self.env and name not in self.stoplist:
            value = self.env[name]
            pos = self.const_index(value)
            # Both instructions are 3 bytes long, so no jump needs adjusting.
            self.tail(3, 'LOAD_CONST', pos & 0xFF, pos >> 8)
            if self.verbose:
                self.debug(name, ' --> ', value)
424
425
class TaintableStack(list):
    """A list used as an emulated interpreter stack, with per-slot taint marks.

    pop() returns (value, tainted) pairs. maxsize records the greatest depth
    the stack has reached (initial contents and append() both count).
    """

    def __init__(self, seq=()):
        # An immutable default avoids sharing one list literal across calls.
        list.__init__(self, seq)
        # Non-negative indexes of the tainted slots.
        self._taintindex = set()
        # len(self), not len(seq), so iterators are accepted too.
        self.maxsize = len(self)

    def _normalize(self, index):
        """Convert a negative (from-the-top) index to a non-negative one."""
        if index < 0:
            index += len(self)
        return index

    def taint(self, index=-1):
        """Mark the slot at 'index' (default: top of stack) as tainted."""
        self._taintindex.add(self._normalize(index))

    def tainted(self, index=-1):
        """Return True if the slot at 'index' (default: top) is tainted."""
        return self._normalize(index) in self._taintindex

    def pop(self, index=-1):
        """pop(index) -> Returns a tuple!! of (value, tainted)."""
        index = self._normalize(index)
        is_tainted = index in self._taintindex
        value = list.pop(self, index)
        # Slots above the removed one slide down by one, so their taint marks
        # must move with them; the removed slot's own mark is dropped.
        shifted = set()
        for i in self._taintindex:
            if i < index:
                shifted.add(i)
            elif i > index:
                shifted.add(i - 1)
        self._taintindex = shifted
        return value, is_tainted

    def append(self, obj):
        list.append(self, obj)
        depth = len(self)
        if depth > self.maxsize:
            self.maxsize = depth
456
457
class EarlyBinder(Rewriter):
    """Deep-evaluate a function, replacing free vars with constants.

    reduce_getattr: If True (the default), getattr(x, y) will be
        replaced with x.y where possible.

    bind_late: a list of objects (globals, freevars, or attributes)
        which should not be early-bound. For example, if you want
        datetime.date.today() to be bound late, include it in bind_late.

    Example: k = lambda x: x.Date == datetime.date(2004, 1, 1)
             r = EarlyBinder(k).function()

    _____ k _____                _____ r _____
     0 LOAD_FAST     0 (x)        0 LOAD_FAST   0 (x)
     3 LOAD_ATTR     1 (Date)     3 LOAD_ATTR   1 (Date)
     6 LOAD_GLOBAL   2 (datetime) 6 LOAD_CONST  6 (datetime.date(2004, 1, 1))
     9 LOAD_ATTR     3 (date)
    12 LOAD_CONST    1 (2004)
    15 LOAD_CONST    2 (1)
    18 LOAD_CONST    2 (1)
    21 CALL_FUNCTION 3
    24 COMPARE_OP    2 (==)       9 COMPARE_OP  2 (==)
    27 RETURN_VALUE              12 RETURN_VALUE

    This also pre-computes binary operations *and all other builtin or free
    functions* where all operands are constants, globals, or freevars.
    For example:
        LOAD_CONST        1 (3)
        LOAD_CONST        2 (4)
        BINARY_MULTIPLY

    is replaced with:
        LOAD_CONST        5 (12)

    However, order is important. lambda x: x * 4 * 5 won't see any
    optimization, because the order of eval is (x * 4) * 5. Rewritten
    as lambda x: 4 * 5 * x, the "4 * 5" can be replaced with "20".
    """

    def __init__(self, func, reduce_getattr=True, bind_late=None):
        Rewriter.__init__(self, func)
        self.reduce_getattr = reduce_getattr

        # self.env will be used to make consts out of globals and builtins.
        import __builtin__
        self.env = vars(__builtin__).copy()
        self.env.update(func.func_globals)

        # Keep a stack like the interpreter would. This does *not*
        # get overwritten when self.newcode does--it emulates the
        # original instructions (although tainted values may be dummies).
        # When a local var is pushed onto this stack, it "taints" itself
        # and any operations which depend upon it.
        # This stack is not passed out of this class in any way.
        self.stack = TaintableStack()

        if bind_late is None:
            bind_late = []
        self.bind_late = bind_late

    def walk(self):
        """Reset the emulated stack, then rewrite the bytecode."""
        # Without this reset, a second walk (e.g. calling code_object() or
        # function() twice) would start on the previous walk's leftover
        # stack entries, skewing the emulation and stack.maxsize.
        self.stack = TaintableStack()
        self.last_getattr = None
        Rewriter.walk(self)

    def code_object(self):
        """Walk self and produce a new code object."""
        self.walk()
        codestr = ''.join(map(chr, self.newcode))
        # Assert CO_NOFREE, since all free vars should have been made constant.
        self.co_flags |= CO_NOFREE
        co = types.CodeType(self.co_argcount, self.co_nlocals, self.stack.maxsize,
                            self.co_flags, codestr, tuple(self.co_consts),
                            safe_tuple(self.co_names), safe_tuple(self.co_varnames),
                            '', self.co_name, 1,
                            self.co_lnotab, (), ())
        return co

    def function(self, newname=None):
        """Walk self and produce a new function."""
        try:
            f = self._func
        except AttributeError:
            if newname is None:
                newname = ''
            co = self.code_object()
            return types.FunctionType(co, {}, newname)
        else:
            if newname is None:
                newname = f.func_name
            co = self.code_object()
            # All cells should be dereferenced, so force func_closure to None.
            return types.FunctionType(co, f.func_globals, newname, f.func_defaults)

    def reduce(self, number_of_terms, transform=None, overwrite_length=None):
        """If no stack args are to be bound late, rewrite previous opcodes.

        number_of_terms: the number of terms to pop off the stack.

        transform: a callback, to which we send the popped terms. They are
            transformed in that function as needed, and returned.

        overwrite_length: the number of previous opcodes to overwrite. If
            None, it defaults to (number of terms + 1 for the current
            instruction) * 3.

        Returns the folded result, or None if any term was tainted (a
        tainted dummy is then pushed in place of the result).
        """
        if overwrite_length is None:
            # +1 is for current bytecode. If any overwritten bytecode
            # is not len 3, pass in a value for overwrite_length.
            overwrite_length = (number_of_terms + 1) * 3

        # Pop the requested number of terms off the stack (top first).
        is_tainted = False
        terms, taints = [], []
        for i in xrange(number_of_terms):
            term, taint = self.stack.pop()
            taints.append(taint)
            is_tainted |= taint
            terms.append(term)

        # Now that all the stack-popping is done...
        if is_tainted:
            # We don't have to handle getattr if no args are
            # tainted, because CALL_FUNCTION will do it normally.
            # terms are in pop order, so getattr(obj, name) gives
            # [name, obj, getattr].
            start = getattr(self, 'last_getattr', None)
            if (self.reduce_getattr and start is not None
                    and len(terms) == 3 and terms[2] is getattr
                    and taints[1] and not taints[0]):
                # Form a new LOAD_ATTR instruction.
                pos = self.name_index(terms[0])
                # Unlike normal CALL_FUNCTION, we can't assume each arg
                # is a constant; therefore, our overwrite_length is
                # indeterminate. We'll just cheat and keep track of
                # the last LOAD_GLOBAL where we looked up getattr. ;)
                # Grab and reuse opcodes of first (LOAD_FAST) term: drop
                # the getattr load in front, and the name load plus the
                # CALL_FUNCTION (3 bytes each) at the end.
                bits = self.newcode[start + 3:-6]
                bits += ['LOAD_ATTR', pos & 0xFF, pos >> 8]
                self.put(start, len(self.newcode), *bits)
                self.stack.append(None)
                self.stack.taint()
                return None

            # Don't form the new object.
            # Replace TOS with a dummy and taint it.
            self.stack.append(None)
            self.stack.taint()
            return None

        # Callback the transform.
        terms.reverse()
        if transform:
            result = transform(terms)
        else:
            result = terms

        # Replace TOS with result.
        self.stack.append(result)

        # Overwrite bytecodes with new CONST formed from result.
        pos = self.const_index(result)
        self.tail(overwrite_length, 'LOAD_CONST', pos & 0xFF, pos >> 8)

        return result

    def visit_BUILD_TUPLE(self, lo, hi):
        self.reduce(lo + (hi << 8), lambda terms: tuple(terms))

    def visit_BUILD_LIST(self, lo, hi):
        self.reduce(lo + (hi << 8))

    def visit_CALL_FUNCTION(self, lo, hi):
        # lo = number of positional args, hi = number of keyword args.
        # Stack, bottom to top: func, the positional args, then one
        # (name, value) pair per keyword arg -- two slots each.
        def call(terms):
            func = terms[0]
            args = tuple(terms[1:lo + 1])
            kwterms = terms[lo + 1:]
            kwargs = {}
            for i in range(0, len(kwterms), 2):
                kwargs[kwterms[i]] = kwterms[i + 1]
            return func(*args, **kwargs)
        self.reduce(lo + 2 * hi + 1, call)

    def visit_COMPARE_OP(self, lo, hi):
        op = cmp_op[lo + (hi << 8)]
        op = comparisons[op]
        self.reduce(2, lambda terms: op(*terms))

    def visit_LOAD_ATTR(self, lo, hi):
        name = self.co_names[lo + (hi << 8)]
        result = self.reduce(1, lambda terms: getattr(terms[0], name))
        if result in self.bind_late or getattr(result, 'bind_late', False):
            self.stack.taint()

    def visit_LOAD_CONST(self, lo, hi):
        self.stack.append(self.co_consts[lo + (hi << 8)])

    def visit_LOAD_DEREF(self, lo, hi):
        index = lo + (hi << 8)
        # LOAD_DEREF numbers cell variables first, then free variables;
        # only the free variables have cells in func_closure.
        ncells = len(self.co_cellvars)
        if index < ncells or not hasattr(self, '_func'):
            # The value is not available now: a cell variable belongs to the
            # call's own frame, and without a function there is no closure.
            # Push a tainted dummy so the emulated stack stays in step.
            self.stack.append(None)
            self.stack.taint()
            return
        # name = self.co_freevars[index - ncells]
        value = self._func.func_closure[index - ncells]
        value = deref_cell(value)
        pos = self.const_index(value)
        self.tail(3, 'LOAD_CONST', pos & 0xFF, pos >> 8)
        self.stack.append(value)
        if value in self.bind_late or getattr(value, 'bind_late', False):
            self.stack.taint()

    def visit_LOAD_FAST(self, lo, hi):
        self.stack.append(self.co_varnames[lo + (hi << 8)])
        # LOAD_FAST references our bound variable, which is always bound late.
        self.stack.taint()

    def visit_LOAD_GLOBAL(self, lo, hi):
        name = self.co_names[lo + (hi << 8)]
        if name == 'getattr':
            # Start of this LOAD_GLOBAL; see the getattr folding in reduce().
            self.last_getattr = (len(self.newcode) - 3)
        if name in self.env:
            value = self.env[name]
            pos = self.const_index(value)
            self.tail(3, 'LOAD_CONST', pos & 0xFF, pos >> 8)
            self.stack.append(value)
            if value in self.bind_late or getattr(value, 'bind_late', False):
                self.stack.taint()
        else:
            raise KeyError("'%s' is not present in supplied globals." % name)

    # SLICE+n opcodes take no argument (1 byte), hence the explicit lengths.
    def visit_SLICE_PLUS_0(self):
        self.reduce(1, lambda terms: terms[0][:], 4)

    def visit_SLICE_PLUS_1(self):
        self.reduce(2, lambda terms: terms[0][terms[1]:], 7)

    def visit_SLICE_PLUS_2(self):
        self.reduce(2, lambda terms: terms[0][:terms[1]], 7)

    def visit_SLICE_PLUS_3(self):
        self.reduce(3, lambda terms: terms[0][terms[1]:terms[2]], 10)

    def binary_op(self, op):
        """Fold a two-operand, 1-byte operator opcode using function 'op'."""
        def operate(terms):
            return op(*terms)
        self.reduce(2, operate, 7)
697
# Give EarlyBinder one visit_<OPNAME> method per binary and in-place operator
# opcode. Each one folds through binary_op; the in-place opcodes are
# deliberately folded exactly like their binary counterparts.
for _operator_table in (binary_operators, inplace_operators):
    for k, v in _operator_table.iteritems():
        setattr(EarlyBinder, "visit_" + k,
                lambda self, opr=v: self.binary_op(opr))
706
707
708 class MapStackObject(dict):
709    
710     def __add__(self, other):
711         if isinstance(other, basestring):
712             return repr(self) + other
713         return dict.__add__(self, other)
714    
715     def __repr__(self):
716         atoms = []
717         for k, v in self.iteritems():
718             atoms.append("%s: %s" % (k, v))
719         return "{%s}" % ", ".join(atoms)
720
721
722 class LambdaDecompiler(Visitor):
723     """LambdaDecompiler(obj=lambda function or func_code).
724     
725     Produce decompiled Python code (as a string) from a supplied lambda."""
726    
727     def __init__(self, func, env=None):
728         Visitor.__init__(self, func)
729         if env is None:
730             self.env = {}
731         else:
732             self.env = env.copy()
733         import __builtin__
734         self.env.update(vars(__builtin__))
735         self.env.update(func.func_globals)
736    
737     def code(self, include_func_header=True):
738         self.walk()
739         product = self.stack[0]
740         if include_func_header:
741             args = list(self.co_varnames)
742             if self.co_flags & CO_VARKEYWORDS:
743                 args[-1] = "**" + args[-1]
744                 if self.co_flags & CO_VARARGS:
745                     args[-2] = "*" + args[-2]
746             elif self.co_flags & CO_VARARGS:
747                 args[-1] = "*" + args[-1]
748             args = ", ".join(args)
749            
750             product = "lambda %s: %s" % (args, product)
751         return product
752    
753     def walk(self):
754         self.stack = []
755         self.targets = {}
756        
757         Visitor.walk(self)
758        
759         if self.verbose:
760             self.debug("stack:", self.stack)
761    
762     def visit_instruction(self, op, lo=None, hi=None):
763         # Get the instruction pointer for the current instruction.
764         ip = self.cursor - 3
765         if hi is None:
766             ip += 1
767             if lo is None:
768                 ip += 1
769        
770         # This is where we do folding of logical AND and OR operators.
771         # The Python code just writes "a AND b", but the VM (bytecode)
772         # acts more like assembly, using conditional JUMP instructions to
773         # implement logical operators. The map stored in self.targets is
774         # of the form:
775         #     {JUMP target: [(self.stack[-1], 'and'), ...]}
776         # where "JUMP target" is the instruction number of the bytecode
777         # which is the target of the JUMP, and each item in the value list
778         # is a tuple of (top of the calling stack, operation).
779         # It's a list because a single bytecode may be the target of
780         # multiple JUMP instructions.
781         # See visit_JUMP_IF_FALSE / TRUE.
782         terms = self.targets.get(ip)
783         if terms:
784             clause = self.stack[-1]
785             while terms:
786                 term, oper = terms.pop()
787                 clause = "(%s) %s (%s)" % (term, oper, clause)
788             # Replace TOS with the new clause, so that further
789             # combinations have access to it.
790             self.stack[-1] = clause
791             if self.verbose:
792                 self.debug("clause:", clause, "\n")
793            
794             if op == 1:
795                 # Py2.4: The current instruction is POP_TOP, which means
796                 # the previous is probably JUMP_*. If so, we don't want to
797                 # pop the value we just placed on the stack and lose it.
798                 # We need to replace the entry that the JUMP_* made in
799                 # self.targets with our new TOS.
800                 target = self.targets[self.last_target_ip]
801                 target[-1] = ((clause, target[-1][1]))
802                 if self.verbose:
803                     self.debug("newtarget:", self.last_target_ip, target)
804    
805     def visit_BUILD_LIST(self, lo, hi):
806         terms = [str(self.stack.pop()) for i in range(lo + (hi << 8))]
807         terms.reverse()
808         self.stack.append("[%s]" % ", ".join(terms))
809    
810     def visit_BUILD_MAP(self, lo, hi):
811         # We're actually going to put a non-string object on the stack here,
812         # with the expectation that the next bytecodes will populate it.
813         self.stack.append(MapStackObject())
814    
815     def visit_BUILD_TUPLE(self, lo, hi):
816         terms = [str(self.stack.pop()) for i in range(lo + (hi << 8))]
817         terms.reverse()
818         self.stack.append("(%s)" % ", ".join(terms))
819    
820     def visit_CALL_FUNCTION(self, lo, hi):
821         kwargs = {}
822         for i in range(hi):
823             val = self.stack.pop()
824             key = self.stack.pop()
825             kwargs[key] = val
826         kwargs = ", ".join(["%s=%s" % (k, v) for k, v in kwargs.iteritems()])
827        
828         args = []
829         for i in xrange(lo):
830             arg = self.stack.pop()
831             args.append(arg)
832         args.reverse()
833         args = ", ".join([str(x) for x in args])
834        
835         if kwargs:
836             args += ", " + kwargs
837        
838         func = self.stack.pop()
839         self.stack.append("%s(%s)" % (func, args))
840    
841     def visit_COMPARE_OP(self, lo, hi):
842         term2, term1 = self.stack.pop(), self.stack.pop()
843         op = cmp_op[lo + (hi << 8)]
844         self.stack.append(term1 + " " + op + " " + term2)
845         if self.verbose:
846             self.debug(op)
847    
848     def visit_DUP_TOP(self):
849         self.stack.append(self.stack[-1])
850    
851     def visit_JUMP_IF_FALSE(self, lo, hi):
852         # Note that self.cursor has already advanced to the next instruction.
853         target = self.cursor + (lo + (hi << 8))
854         bucket = self.targets.setdefault(target, [])
855         bucket.append((self.stack[-1], 'and'))
856         if self.verbose:
857             self.debug("target:", target, bucket)
858         # Store target ip for the special code in visit_instruction
859         self.last_target_ip = target
860    
861     def visit_JUMP_IF_TRUE(self, lo, hi):
862         # Note that self.cursor has already advanced to the next instruction.
863         target = self.cursor + (lo + (hi << 8))
864         bucket = self.targets.setdefault(target, [])
865         bucket.append((self.stack[-1], 'or'))
866         if self.verbose:
867             self.debug("target:", target, bucket)
868         # Store target ip for the special code in visit_instruction
869         self.last_target_ip = target
870    
871     def visit_LOAD_ATTR(self, lo, hi):
872         term = self.co_names[lo + (hi << 8)]
873         self.stack[-1] += ("." + term)
874         if self.verbose:
875             self.debug(term)
876    
877     def visit_LOAD_CONST(self, lo, hi):
878         val = self.co_consts[lo + (hi << 8)]
879         mod = getattr(val, "__module__", None)
880         if isinstance(val, (types.FunctionType, type)):
881             # The const in question is a factory function, like int or date.
882             name = val.__name__
883             if name in self.env:
884                 term = name
885             else:
886                 term = mod + "." + name
887         else:
888             term = repr(val)
889             if mod and not mod.startswith("__"):
890                 if not term.startswith(mod + "."):
891                     term = mod + "." + term
892         self.stack.append(term)
893         if self.verbose:
894             self.debug(term)
895    
896     def visit_LOAD_FAST(self, lo, hi):
897         term = self.co_varnames[lo + (hi << 8)]
898         self.stack.append(term)
899         if self.verbose:
900             self.debug(term)
901    
902     def visit_LOAD_GLOBAL(self, lo, hi):
903         self.stack.append(self.co_names[lo + (hi << 8)])
904    
905     def visit_POP_TOP(self):
906         self.stack.pop()
907    
908     def visit_ROT_TWO(self):
909         v = self.stack.pop()
910         k = self.stack.pop()
911         self.stack.extend([v, k])
912    
913     def visit_ROT_THREE(self):
914         v = self.stack.pop()
915         k = self.stack.pop()
916         x = self.stack.pop()
917         self.stack.extend([v, x, k])
918    
919     def visit_SLICE_PLUS_0(self):
920         arg = self.stack.pop()
921         self.stack.append("%s[:]" % arg)
922    
923     def visit_SLICE_PLUS_1(self):
924         args = tuple(self.stack[-2:])
925         del self.stack[-2:]
926         self.stack.append("%s[%s:]" % args)
927    
928     def visit_SLICE_PLUS_2(self):
929         args = tuple(self.stack[-2:])
930         del self.stack[-2:]
931         self.stack.append("%s[:%s]" % args)
932    
933     def visit_SLICE_PLUS_3(self):
934         args = tuple(self.stack[-3:])
935         del self.stack[-3:]
936         self.stack.append("%s[%s:%s]" % args)
937    
938     def visit_STORE_SUBSCR(self):
939         k = self.stack.pop()
940         x = self.stack.pop()
941         v = self.stack.pop()
942         x[k] = v
943    
944     def visit_UNARY_CONVERT(self):
945         term = self.stack.pop()
946         self.stack.append("`(" + term + ")`")
947    
948     def visit_UNARY_INVERT(self):
949         term = self.stack.pop()
950         self.stack.append("~(" + term + ")")
951    
952     def visit_UNARY_NEGATIVE(self):
953         term = self.stack.pop()
954         self.stack.append("-(" + term + ")")
955    
956     def visit_UNARY_NOT(self):
957         term = self.stack.pop()
958         self.stack.append("not (" + term + ")")
959    
960     def visit_UNARY_POSITIVE(self):
961         term = self.stack.pop()
962         self.stack.append("+(" + term + ")")
963    
964     def binary_op(self, op):
965         op2, op1 = self.stack.pop(), self.stack.pop()
966         self.stack.append(op1 + " " + op + " " + op2)
967    
968     def visit_BINARY_SUBSCR(self):
969         op2, op1 = self.stack.pop(), self.stack.pop()
970         self.stack.append(op1 + "[" + op2 + "]")
971
# Install one visit_<opcode name> method on LambdaDecompiler per entry in
# binary_repr (defined outside this view; presumably opcode name -> operator
# text). Each generated method just forwards its operator to binary_op.
# The operator is bound through the op=v default so every method keeps its
# own value instead of the loop's last one.
# The loop names k and v stay at module scope and are removed by the
# `del k, v` near the end of the module, so they must keep these names.
for k, v in binary_repr.items():
    setattr(LambdaDecompiler, "visit_%s" % k,
            lambda self, op=v: self.binary_op(op))
976
977
978
class BranchTracker(Visitor):
    """BranchTracker(obj=[func|co|str|list]).

    Finds all possible instructions previous to the supplied instruction(s).

    Instruction addresses are byte offsets into the bytecode. While an
    instruction is visited, self.cursor already points at the instruction
    that follows it. The visited instruction therefore starts at
    self.cursor - 1 (no argument) or self.cursor - 3 (two-byte argument).
    """

    # Opcodes after which execution never continues with the next
    # instruction in sequence. The instruction that follows one of these is
    # not reached by falling through, so it must not be recorded as that
    # instruction's predecessor. Jump targets are recorded separately by the
    # visit_JUMP_* methods below.
    _no_fallthrough = (opmap['JUMP_ABSOLUTE'], opmap['JUMP_FORWARD'],
                       opmap['CONTINUE_LOOP'], opmap['BREAK_LOOP'],
                       opmap['RETURN_VALUE'], opmap['RAISE_VARARGS'])

    def branches(self, instr=None):
        """Walk self and return all possible instructions previous to instr.

        If instr is None, the last instruction will be used. Its address is
        taken as the last byte of the bytecode, which assumes it takes no
        argument (normally RETURN_VALUE).
        """
        if instr is None:
            instr = len(self._bytecode) - 1
        self.watch = {instr: []}
        self.walk()
        return self.watch[instr]

    def visit_instruction(self, op, lo=None, hi=None):
        """Record the visited instruction if it falls through to a watched one."""
        if self.cursor in self.watch and op not in self._no_fallthrough:
            if lo is None and hi is None:
                self.watch[self.cursor].append(self.cursor - 1)
            else:
                self.watch[self.cursor].append(self.cursor - 3)

    def visit_CONTINUE_LOOP(self, lo, hi):
        # CONTINUE_LOOP carries an absolute target, like JUMP_ABSOLUTE.
        self.visit_JUMP_ABSOLUTE(lo, hi)

    def visit_JUMP_ABSOLUTE(self, lo, hi):
        target = lo + (hi << 8)
        if target in self.watch:
            # Record the jump instruction itself, which starts 3 bytes before
            # the cursor (consistent with visit_JUMP_FORWARD). self.cursor is
            # the instruction after the jump, not a predecessor of target.
            self.watch[target].append(self.cursor - 3)

    def visit_JUMP_FORWARD(self, lo, hi):
        # Relative jumps are measured from the next instruction (the cursor).
        delta = lo + (hi << 8)
        target = self.cursor + delta
        if target in self.watch:
            self.watch[target].append(self.cursor - 3)

    def visit_JUMP_IF_FALSE(self, lo, hi):
        # The taken branch is a relative jump; visit_instruction records
        # the fall-through path.
        self.visit_JUMP_FORWARD(lo, hi)

    def visit_JUMP_IF_TRUE(self, lo, hi):
        # The taken branch is a relative jump; visit_instruction records
        # the fall-through path.
        self.visit_JUMP_FORWARD(lo, hi)
1022
1023
class KeywordInspector(Rewriter):
    """Produce a list of all keyword arguments expected.

    Scans the bytecode for the sequence
        LOAD_FAST <**kwargs parameter>, LOAD_CONST <key>, BINARY_SUBSCR
    (that is, kwargs[<key>]) and collects each <key> in order of appearance.
    """

    _LOAD_FAST = opmap['LOAD_FAST']
    _LOAD_CONST = opmap['LOAD_CONST']
    _BINARY_SUBSCR = opmap['BINARY_SUBSCR']

    def __init__(self, obj):
        """KeywordInspector(obj). List keyword arguments expected.

        Raises ValueError if obj has no **kwargs parameter, or has at most
        one varname.
        """
        Rewriter.__init__(self, obj)
        if not (self.co_flags & CO_VARKEYWORDS):
            raise ValueError("'%s' does not possess **kwargs." % obj)
        if len(self.co_varnames) <= 1:
            raise ValueError("'%s' does not possess more than 1 varname." % obj)
        # In co_varnames the **kwargs parameter follows the positional
        # parameters and *args (if present); local variables come after it,
        # so it is not necessarily the last varname.
        argcount = getattr(self, "co_argcount", None)
        if argcount is None:
            # NOTE(review): in case Rewriter does not expose co_argcount,
            # fall back to the old assumption that **kwargs is last.
            self._kwargs_index = len(self.co_varnames) - 1
        else:
            self._kwargs_index = argcount
            if self.co_flags & CO_VARARGS:
                self._kwargs_index += 1
        self._kwargs = []
        self.flag = None

    def kwargs(self):
        """kwargs() -> List of keyword arguments expected."""
        # Reset so that repeated calls do not accumulate duplicate entries
        # from earlier walks.
        self._kwargs = []
        self.flag = None
        self.walk()
        return self._kwargs

    def visit_instruction(self, op, lo=None, hi=None):
        # self.flag tracks progress through the pattern:
        #   None -> no match in progress
        #   ''   -> just saw LOAD_FAST of the **kwargs parameter
        #   key  -> then saw LOAD_CONST key; a BINARY_SUBSCR now records it
        if op == self._LOAD_FAST and (lo + (hi << 8) == self._kwargs_index):
            self.flag = ''
        elif op == self._LOAD_CONST and self.flag == '':
            self.flag = self.co_consts[lo + (hi << 8)]
        elif op == self._BINARY_SUBSCR and self.flag:
            self._kwargs.append(self.flag)
            # The match is complete; without this reset a following
            # BINARY_SUBSCR (e.g. x[kwargs['a']]) would record it again.
            self.flag = None
        else:
            self.flag = None
1051
1052
# Drop the loop variables that the loop installing LambdaDecompiler's
# visit_BINARY_* methods left at module scope.
del k
del v
1054
Note: See TracBrowser for help on using the browser.