Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

I think I've seen this ORM somewhere before...

root/tags/1.4.0/codewalk.py

Revision 141 (checked in by fumanchu, 6 years ago)

Gratuitous premature optimization.

  • Property svn:eol-style set to native
Line 
1 """Bytecode visitors, including rewriters and decompilers.
2
3 This work, including the source code, documentation
4 and related data, is placed into the public domain.
5
6 The original author is Robert Brewer, Amor Ministries.
7
8 THIS SOFTWARE IS PROVIDED AS-IS, WITHOUT WARRANTY
9 OF ANY KIND, NOT EVEN THE IMPLIED WARRANTY OF
10 MERCHANTABILITY. THE AUTHOR OF THIS SOFTWARE
11 ASSUMES _NO_ RESPONSIBILITY FOR ANY CONSEQUENCE
12 RESULTING FROM THE USE, MODIFICATION, OR
13 REDISTRIBUTION OF THIS SOFTWARE.
14
15 """
16
17 from opcode import cmp_op, opname, opmap, HAVE_ARGUMENT
18 import operator
19 import sets
20 from types import CodeType, FunctionType
21
22
def named_opcodes(bits):
    """Return a copy of bits with each opcode number replaced by its name.

    Argument bytes (the two entries that follow any opcode which is
    >= HAVE_ARGUMENT) are copied through untouched. If the sequence ends
    in the middle of an argument, whatever bytes are present are kept.
    """
    codes = list(bits)
    named = []
    pos = 0
    total = len(codes)
    while pos < total:
        op = codes[pos]
        pos += 1
        named.append(opname[op])
        if op >= HAVE_ARGUMENT:
            # Slicing tolerates a truncated argument at the very end.
            named.extend(codes[pos:pos + 2])
            pos += 2
    return named
36
def numeric_opcodes(bits):
    """Return a list copy of bits with opcode names replaced by numbers.

    String entries are looked up in opmap (an unknown name raises
    KeyError); every other entry is passed through unchanged.
    """
    def as_number(bit):
        if isinstance(bit, basestring):
            return opmap[bit]
        return bit

    return [as_number(bit) for bit in bits]
45
# Bytecode of a minimal code block: load free variable #0 and return it.
_deref_bytecode = numeric_opcodes(['LOAD_DEREF', 0, 0, 'RETURN_VALUE'])
# CodeType(argcount, nlocals, stacksize, flags, codestring, constants,
#          names, varnames, filename, name, firstlineno,
#          lnotab[, freevars[, cellvars]])
_derefblock = CodeType(
    0,                                           # argcount
    0,                                           # nlocals
    1,                                           # stacksize
    3,                                           # flags
    ''.join([chr(bit) for bit in _deref_bytecode]),  # codestring
    (None,),                                     # constants
    ('cell',),                                   # names
    (),                                          # varnames
    '',                                          # filename
    '',                                          # name
    2,                                           # firstlineno
    '',                                          # lnotab
    ('cell',),                                   # freevars
)

def deref_cell(cell):
    """Return the object stored in the given closure cell.

    A throwaway function is built around _derefblock with the cell as
    its only closure entry; calling it yields the cell's contents.
    """
    # FunctionType(code, globals[, name[, argdefs[, closure]]])
    extractor = FunctionType(_derefblock, {}, "", (), (cell,))
    return extractor()
55
56 def make_closure(*args):
57     def inner():
58         args
59     return inner.func_closure
60
61
62 binary_operators = {'BINARY_POWER': operator.pow,
63                     'BINARY_MULTIPLY': operator.mul,
64                     'BINARY_DIVIDE': operator.div,
65                     'BINARY_FLOOR_DIVIDE': operator.floordiv,
66                     'BINARY_TRUE_DIVIDE': operator.truediv,
67                     'BINARY_MODULO': operator.mod,
68                     'BINARY_ADD': operator.add,
69                     'BINARY_SUBTRACT': operator.sub,
70                     'BINARY_SUBSCR': operator.getitem,
71                     'BINARY_LSHIFT': operator.lshift,
72                     'BINARY_RSHIFT': operator.rshift,
73                     'BINARY_AND': operator.and_,
74                     'BINARY_XOR': operator.xor,
75                     'BINARY_OR': operator.or_,
76                     }
77 inplace_operators = dict([('INPLACE_' + k.split('_')[1], v)
78                           for k, v in binary_operators.iteritems()
79                           if k not in ('BINARY_SUBSCR',)
80                           ])
81
# Maps each BINARY_* opcode name to the source-code operator it stands for.
binary_repr = {'BINARY_POWER': '**',
               'BINARY_MULTIPLY': '*',
               'BINARY_DIVIDE': '/',
               'BINARY_FLOOR_DIVIDE': '//',
               'BINARY_TRUE_DIVIDE': '/',
               'BINARY_MODULO': '%',
               'BINARY_ADD': '+',
               'BINARY_SUBTRACT': '-',
               'BINARY_LSHIFT': '<<',
               'BINARY_RSHIFT': '>>',
               'BINARY_AND': '&',
               'BINARY_XOR': '^',
               'BINARY_OR': '|',
               }

# The augmented-assignment spelling of every operator above. Split on the
# first underscore only, so that BINARY_FLOOR_DIVIDE yields
# INPLACE_FLOOR_DIVIDE (not INPLACE_FLOOR) and likewise for TRUE_DIVIDE.
inplace_repr = dict([('INPLACE_' + k.split('_', 1)[1], v + '=')
                     for k, v in binary_repr.items()])
99
# Maps each comparison operator (as spelled in opcode.cmp_op) to a
# two-argument callable taking the operands in source order (left, right).
comparisons = {'<': operator.lt,
               '<=': operator.le,
               '==': operator.eq,
               '!=': operator.ne,
               '>': operator.gt,
               '>=': operator.ge,
               # operator.contains(a, b) means "b in a", i.e. its operands
               # are reversed relative to source order, so spell it out.
               'in': lambda x, y: x in y,
               'not in': lambda x, y: not x in y,
               'is': operator.is_,
               'is not': operator.is_not,
               }
111
112 _co_code_attrs = ['co_argcount', 'co_cellvars', 'co_code', 'co_consts',
113                   'co_filename', 'co_firstlineno', 'co_flags', 'co_freevars',
114                   'co_lnotab', 'co_name', 'co_names', 'co_nlocals',
115                   'co_stacksize', 'co_varnames']
116
117
118 class Visitor(object):
119     """Visitor(obj=function, code object, string, or list of opcodes)."""
120    
121     def __init__(self, obj):
122         self.verbose = False
123        
124         # Distill supplied 'obj' arg to a code block string.
125         if isinstance(obj, FunctionType):
126             obj = obj.func_code
127         try:
128             obj = obj.co_code
129         except AttributeError:
130             pass
131        
132         # Map the code block string to a list of opcode numbers.
133         if isinstance(obj, basestring):
134             obj = map(ord, obj)
135         elif isinstance(obj, list):
136             obj = obj[:]
137        
138         self._bytecode = obj
139    
140     def debug(self, *messages):
141         if self.verbose:
142             for term in messages:
143                 print term,
144         else:
145             pass
146    
147     def walk(self):
148         self.cursor = 0
149         b = self._bytecode
150         self.debug("\nWALKING: ", b)
151         while self.cursor < len(b):
152             self.debug("\n", self.cursor)
153            
154             op = b[self.cursor]
155             self.cursor += 1
156             if op >= HAVE_ARGUMENT:
157                 lo = b[self.cursor]
158                 self.cursor += 1
159                 hi = b[self.cursor]
160                 self.cursor += 1
161                 args = (lo, hi)
162             else:
163                 args = ()
164            
165             self.debug("visit (%s, %s)" % (op, repr(args)))
166             self.visit_instruction(op, *args)
167            
168             name = 'visit_' + opname[op]
169             name = name.replace('+', '_PLUS_')
170             handler = getattr(self, name, None)
171             if handler:
172                 self.debug("=> %s%s" % (name[6:], repr(args)))
173                 handler(*args)
174    
175     def visit_instruction(self, op, lo=None, hi=None):
176         pass
177
178
class JumpCodeAdjuster(Visitor):
    """JumpCodeAdjuster(obj=[func|co|str|list], start, end, newlength).

    Rewrites the arguments of jump instructions whose target moves
    because bytecode[start:end] is being replaced by a run of a
    different length.

    start, end: The range of the original bytecode in question.
    newlength: Length of the codes which overwrote bytecode[start:end].
    """

    def __init__(self, obj, start, end, newlength):
        Visitor.__init__(self, obj)
        self.start, self.end = start, end
        # Net change in bytecode length caused by the replacement.
        self.offset = newlength - (end - start)

    def bytecode(self):
        """Walk self and return new bytecode."""
        self.walk()
        return self.newcode

    def walk(self):
        """Build self.newcode, a copy of the bytecode with jumps adjusted."""
        if self.offset:
            self.newcode = []
            Visitor.walk(self)
        else:
            # Same length before and after: no jump can be affected, so
            # skip the costly walk and hand back the bytecode as it is.
            self.newcode = self._bytecode

    def visit_instruction(self, op, lo=None, hi=None):
        """Copy the current instruction onto the end of self.newcode."""
        self.newcode.append(op)
        self.newcode.extend([arg for arg in (lo, hi) if arg is not None])

    def visit_CONTINUE_LOOP(self, lo, hi):
        self.visit_JUMP_ABSOLUTE(lo, hi)

    def visit_JUMP_ABSOLUTE(self, lo, hi):
        """Shift an absolute jump whose target lies beyond self.start."""
        target = (hi << 8) + lo
        if target <= self.start:
            return
        moved = target + self.offset
        # The instruction was just copied; patch its two argument bytes.
        self.newcode[-2:] = [moved & 0xFF, moved >> 8]

    def visit_JUMP_FORWARD(self, lo, hi):
        """Re-aim a relative jump that starts before self.end and lands
        beyond self.start."""
        # self.cursor is already past this instruction, which is the
        # position relative jumps are measured from.
        target = self.cursor + lo + (hi << 8)
        if self.cursor >= self.end or target <= self.start:
            return
        delta = (target + self.offset) - self.cursor
        self.newcode[-2:] = [delta & 0xFF, delta >> 8]

    def visit_JUMP_IF_FALSE(self, lo, hi):
        self.visit_JUMP_FORWARD(lo, hi)

    def visit_JUMP_IF_TRUE(self, lo, hi):
        self.visit_JUMP_FORWARD(lo, hi)
236
237
def safe_tuple(seq):
    """Return the items of seq, each coerced with str(), as a tuple.

    Many of the func_code attributes must take tuples, not lists,
    and *cannot* accept unicode items--they must be coerced to strings
    or the interpreter will crash.
    """
    return tuple([str(item) for item in seq])
247
248
class Rewriter(Visitor):
    """Rewriter(obj=function or code object).

    Produce a new function or code object by rewriting an existing one.

    Walking copies each instruction into self.newcode; opcode handlers
    defined by subclasses may then overwrite what was copied by calling
    put() or tail().

    Notice that, unlike the base Visitor class, Rewriter does not accept a
    string or list of opcodes as an initial argument.
    """

    def __init__(self, obj):
        self.verbose = False

        if isinstance(obj, FunctionType):
            # Keep the function: function() reuses its globals, defaults
            # and closure.
            self._func = obj
            obj = obj.func_code

        # Mirror the code object's attributes on self, with tuples turned
        # into lists so that they can be extended while rewriting.
        for attr in _co_code_attrs:
            field = getattr(obj, attr)
            if isinstance(field, tuple):
                field = list(field)
            setattr(self, attr, field)

        source = getattr(obj, 'co_code', obj)

        # Map the code block to a list of opcode numbers.
        if isinstance(source, basestring):
            self._bytecode = [ord(char) for char in source]
        elif isinstance(source, list):
            self._bytecode = source[:]
        else:
            raise TypeError("obj arg of incorrect type '%s'" % type(source))

    def bytecode(self):
        """Walk self and return new bytecode."""
        self.walk()
        return self.newcode

    def code_object(self):
        """Walk self and produce a new code object."""
        self.walk()
        codestring = ''.join([chr(bit) for bit in self.newcode])
        # Notice co_consts should *not* be safe_tupled.
        consts = tuple(self.co_consts)
        return CodeType(self.co_argcount, self.co_nlocals, self.co_stacksize,
                        self.co_flags, codestring, consts,
                        safe_tuple(self.co_names),
                        safe_tuple(self.co_varnames),
                        self.co_filename, self.co_name, self.co_firstlineno,
                        self.co_lnotab, safe_tuple(self.co_freevars),
                        safe_tuple(self.co_cellvars))

    def function(self, newname=None):
        """Walk self and produce a new function.

        If self was built from a function, the result shares its globals,
        defaults and closure, and newname defaults to its name. Otherwise
        the result gets empty globals and newname defaults to ''.
        """
        if hasattr(self, '_func'):
            original = self._func
            if newname is None:
                newname = original.func_name
            return FunctionType(self.code_object(), original.func_globals,
                                newname, original.func_defaults,
                                original.func_closure)

        if newname is None:
            newname = ''
        return FunctionType(self.code_object(), {}, newname)

    def const_index(self, value):
        """The index of value in co_consts, appending it if not found."""
        for pos, item in enumerate(self.co_consts):
            try:
                # Types must match too, so that e.g. 1 and 1.0 stay
                # distinct constants.
                if type(value) == type(item) and value == item:
                    return pos
            except TypeError:
                pass

        self.co_consts.append(value)
        return len(self.co_consts) - 1

    def name_index(self, value):
        """The index of value in co_names, appending it if not found."""
        for pos, item in enumerate(self.co_names):
            try:
                if type(value) == type(item) and value == item:
                    return pos
            except TypeError:
                pass

        self.co_names.append(value)
        return len(self.co_names) - 1

    def walk(self):
        """Walk the bytecode, rebuilding self.newcode from scratch."""
        self.newcode = []
        Visitor.walk(self)

    def visit_instruction(self, op, lo=None, hi=None):
        """Copy the current instruction onto the end of self.newcode."""
        self.newcode.append(op)
        self.newcode.extend([arg for arg in (lo, hi) if arg is not None])

    def put(self, start, end, *bits):
        """Overwrite self.newcode with new opcodes (numbers or names).

        If the new codes are of different quantity than the old,
        modify any jump codes affected.
        """
        replacement = numeric_opcodes(bits)

        # Jumps have to be adjusted first, while self.newcode still has
        # its old layout.
        adjuster = JumpCodeAdjuster(self.newcode, start, end,
                                    len(replacement))
        self.newcode = adjuster.bytecode()

        self.newcode[start:end] = replacement

    def tail(self, length, *bits):
        """Overwrite self.newcode[-length:] with bits."""
        stop = len(self.newcode)
        self.put(stop - length, stop, *bits)
378
379
class Localizer(Rewriter):
    """Localizer(func, builtin_only=False, stoplist=None, verbose=False)

    If a global or builtin is known at compile time, replace it with a constant.

    builtin_only: if True, only builtins are candidates for replacement;
        otherwise the function's globals are candidates as well (and
        shadow builtins of the same name).
    stoplist: names which must never be replaced. Defaults to an empty
        list, created per instance.
    verbose: if True, each replacement is reported through debug().

    This duplicates (and borrows from) Raymond Hettinger's Cookbook recipe
    at: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/277940
    """
    def __init__(self, func, builtin_only=False, stoplist=None, verbose=False):
        Rewriter.__init__(self, func)

        import __builtin__
        self.env = vars(__builtin__).copy()
        if not builtin_only:
            self.env.update(func.func_globals)

        # A fresh list per instance: a shared default list would let a
        # change to one Localizer's stoplist leak into all the others.
        if stoplist is None:
            stoplist = []
        self.stoplist = stoplist
        self.verbose = verbose

    def visit_LOAD_GLOBAL(self, lo, hi):
        """Swap LOAD_GLOBAL for LOAD_CONST when the name's value is known."""
        name = self.co_names[lo + (hi << 8)]
        if name in self.env and name not in self.stoplist:
            value = self.env[name]
            pos = self.const_index(value)
            # The LOAD_GLOBAL just copied occupies the last 3 bytes.
            self.tail(3, 'LOAD_CONST', pos & 0xFF, pos >> 8)
            self.debug(name, ' --> ', value)
406
407
408 class TaintableStack(list):
409     def __init__(self, seq=[]):
410         list.__init__(self, seq)
411         self._taintindex = sets.Set()
412         self.maxsize = len(seq)
413    
414     def taint(self, index=-1):
415         if index < 0:
416             index += len(self)
417         self._taintindex.add(index)
418    
419     def tainted(self, index=-1):
420         if index < 0:
421             index = len(self) + index
422         return (index in self._taintindex)
423    
424     def pop(self, index=-1):
425         """pop(index) -> Returns a tuple!! of (value, tainted)."""
426         if index < 0:
427             index += len(self)
428         is_tainted = (index in self._taintindex)
429         if is_tainted:
430             self._taintindex.remove(index)
431         return list.pop(self, index), is_tainted
432    
433     def append(self, obj):
434         list.append(self, obj)
435         l = len(self)
436         if l > self.maxsize:
437             self.maxsize = l
438
439
class EarlyBinder(Rewriter):
    """EarlyBinder(func, reduce_getattr=True, bind_late=None)

    Deep-evaluate function, replacing free vars with constants.

    reduce_getattr: If True (the default), getattr(x, y) will be
        replaced with x.y where possible.

    bind_late: a list of names (globals, freevars, or attributes)
        which should not be early-bound. For example, if you want
        datetime.date.today() to be bound late, include 'today'
        in the bind_late. Defaults to an empty list.

    Example: k = lambda x: x.Date == datetime.date(2004, 1, 1)
             r = EarlyBinder(k).function()

    _____ k _____                _____ r _____
     0 LOAD_FAST     0 (x)        0 LOAD_FAST   0 (x)
     3 LOAD_ATTR     1 (Date)     3 LOAD_ATTR   1 (Date)
     6 LOAD_GLOBAL   2 (datetime) 6 LOAD_CONST  6 (datetime.date(2004, 1, 1))
     9 LOAD_ATTR     3 (date)
    12 LOAD_CONST    1 (2004)
    15 LOAD_CONST    2 (1)
    18 LOAD_CONST    2 (1)
    21 CALL_FUNCTION 3
    24 COMPARE_OP    2 (==)       9 COMPARE_OP  2 (==)
    27 RETURN_VALUE              12 RETURN_VALUE

    This also pre-computes binary operations *and all other builtin or free
    functions* where all operands are constants, globals, or freevars.
    For example:
        LOAD_CONST        1 (3)
        LOAD_CONST        2 (4)
        BINARY_MULTIPLY

    is replaced with:
        LOAD_CONST        5 (12)

    However, order is important. lambda x: x * 4 * 5 won't see any
    optimization, because the order of eval is (x * 4) * 5. Rewritten
    as lambda x: 4 * 5 * x, the "4 * 5" can be replaced with "20".
    """

    def __init__(self, func, reduce_getattr=True, bind_late=None):
        Rewriter.__init__(self, func)
        self.reduce_getattr = reduce_getattr

        import __builtin__
        self.env = vars(__builtin__).copy()
        self.env.update(func.func_globals)

        # Keep a stack like the interpreter would. This does *not*
        # get overwritten when self.newcode does--it emulates the
        # original instructions (although tainted values may be dummies).
        # When a local var is pushed onto this stack, it "taints" itself
        # and any operations which depend upon it.
        # This stack is not passed out of this class in any way.
        self.stack = TaintableStack()

        # A fresh list per instance rather than a shared default.
        # NOTE(review): the checks below test the looked-up *value*
        # against bind_late, while the class docstring speaks of names --
        # confirm which is intended.
        if bind_late is None:
            bind_late = []
        self.bind_late = bind_late

    def code_object(self):
        """Walk self and produce a new code object."""
        self.walk()
        codestr = ''.join(map(chr, self.newcode))
        # Assert CO_NOFREE, since all free vars should have been made constant.
        self.co_flags |= 0x0040
        # The emulated stack's high-water mark serves as co_stacksize;
        # filename and firstlineno are reset, free/cell vars are dropped.
        co = CodeType(self.co_argcount, self.co_nlocals, self.stack.maxsize,
                      self.co_flags, codestr, tuple(self.co_consts),
                      safe_tuple(self.co_names), safe_tuple(self.co_varnames),
                      '', self.co_name, 1,
                      self.co_lnotab, (), ())
        return co

    def function(self, newname=None):
        """Walk self and produce a new function."""
        try:
            f = self._func
        except AttributeError:
            if newname is None:
                newname = ''
            co = self.code_object()
            return FunctionType(co, {}, newname)
        else:
            if newname is None:
                newname = f.func_name
            co = self.code_object()
            # All cells should be dereferenced, so force func_closure to None.
            return FunctionType(co, f.func_globals, newname, f.func_defaults)

    def reduce(self, number_of_terms, transform=None, overwrite_length=None):
        """If no stack args are to be bound late, rewrite previous opcodes.

        number_of_terms: the number of terms to pop off the stack.

        transform: a callback, to which we send the popped terms. They are
            transformed in that function as needed, and returned.

        overwrite_length: the number of previous opcodes to overwrite. If
            None, it defaults to (number of terms + 1 for the current
            instruction) * 3.

        Returns the new constant, or None if any popped term was tainted
        (in which case a tainted dummy is pushed in its place).
        """
        if overwrite_length is None:
            # +1 is for current bytecode. If any overwritten bytecode
            # is not len 3, pass in a value for overwrite_length.
            overwrite_length = (number_of_terms + 1) * 3

        # Pop the requested number of terms off the stack. They come off
        # top-first, so terms[0] is the most recently pushed.
        is_tainted = False
        terms, taints = [], []
        for i in xrange(number_of_terms):
            term, taint = self.stack.pop()
            taints.append(taint)
            is_tainted |= taint
            terms.append(term)

        # Now that all the stack-popping is done...
        if is_tainted:
            # We don't have to handle getattr if no args are
            # tainted, because CALL_FUNCTION will do it normally.
            if (self.reduce_getattr
                and len(terms) == 3 and terms[2] == getattr
                and taints[1] and not taints[0]):
                # getattr(<tainted obj>, <constant name>):
                # form a new LOAD_ATTR instruction.
                pos = self.name_index(terms[0])
                # Unlike normal CALL_FUNCTION, we can't assume each arg
                # is a constant; therefore, our overwrite_length is
                # indeterminate. We'll just cheat and keep track of
                # the last LOAD_GLOBAL where we looked up getattr. ;)
                start = self.last_getattr
                # Grab and reuse opcodes of first (LOAD_FAST) term:
                # everything after the getattr load, minus the trailing
                # name load and CALL_FUNCTION (3 bytes each).
                bits = self.newcode[start + 3:-6]
                bits += ['LOAD_ATTR', pos & 0xFF, pos >> 8]
                self.put(start, len(self.newcode), *tuple(bits))

            # Don't form the new object.
            # Replace TOS with a dummy and taint it.
            self.stack.append(None)
            self.stack.taint()
            return None

        # Callback the transform, with the terms back in push order.
        terms.reverse()
        if transform:
            result = transform(terms)
        else:
            result = terms

        # Replace TOS with result.
        self.stack.append(result)

        # Overwrite bytecodes with new CONST formed from result.
        pos = self.const_index(result)
        self.tail(overwrite_length, 'LOAD_CONST', pos & 0xFF, pos >> 8)

        return result

    def visit_BUILD_TUPLE(self, lo, hi):
        self.reduce(lo + (hi << 8), lambda terms: tuple(terms))

    def visit_BUILD_LIST(self, lo, hi):
        self.reduce(lo + (hi << 8))

    def visit_CALL_FUNCTION(self, lo, hi):
        """Pre-compute a call whose callable and arguments are all known.

        lo: number of positional arguments.
        hi: number of keyword arguments; each one occupies two stack
            slots (key, then value) above the positional arguments.
        """
        def call(terms):
            # terms, in push order: [func, arg..., key, value, ...]
            func = terms[0]
            args = tuple(terms[1:1 + lo])
            pairs = terms[1 + lo:]
            kwargs = {}
            for i in range(0, len(pairs), 2):
                kwargs[pairs[i]] = pairs[i + 1]
            return func(*args, **kwargs)
        # One slot for the callable, one per positional argument and two
        # per keyword argument.
        self.reduce(lo + 2 * hi + 1, call)

    def visit_COMPARE_OP(self, lo, hi):
        op = cmp_op[lo + (hi << 8)]
        op = comparisons[op]
        self.reduce(2, lambda terms: op(*terms))

    def visit_LOAD_ATTR(self, lo, hi):
        name = self.co_names[lo + (hi << 8)]
        result = self.reduce(1, lambda terms: getattr(terms[0], name))
        if result in self.bind_late or getattr(result, 'bind_late', False):
            self.stack.taint()

    def visit_LOAD_CONST(self, lo, hi):
        self.stack.append(self.co_consts[lo + (hi << 8)])

    def visit_LOAD_DEREF(self, lo, hi):
        # Only a real function has closure cells to read the value from.
        if hasattr(self, '_func'):
            value = self._func.func_closure[lo + (hi << 8)]
            value = deref_cell(value)
            pos = self.const_index(value)
            self.tail(3, 'LOAD_CONST', pos & 0xFF, pos >> 8)
            self.stack.append(value)
            if value in self.bind_late or getattr(value, 'bind_late', False):
                self.stack.taint()

    def visit_LOAD_FAST(self, lo, hi):
        self.stack.append(self.co_varnames[lo + (hi << 8)])
        # LOAD_FAST references our bound variable, which is always bound late.
        self.stack.taint()

    def visit_LOAD_GLOBAL(self, lo, hi):
        name = self.co_names[lo + (hi << 8)]
        if name == 'getattr':
            # Remember where this load starts; reduce() needs it if the
            # getattr call turns out to be reducible to LOAD_ATTR.
            self.last_getattr = (len(self.newcode) - 3)
        if name in self.env:
            value = self.env[name]
            pos = self.const_index(value)
            self.tail(3, 'LOAD_CONST', pos & 0xFF, pos >> 8)
            self.stack.append(value)
            if value in self.bind_late or getattr(value, 'bind_late', False):
                self.stack.taint()
        else:
            raise KeyError("'%s' is not present in supplied globals." % name)

    # The SLICE+n opcodes are 1 byte long; each operand load is 3 bytes,
    # hence the explicit overwrite lengths below.

    def visit_SLICE_PLUS_0(self):
        self.reduce(1, lambda terms: terms[0][:], 4)

    def visit_SLICE_PLUS_1(self):
        self.reduce(2, lambda terms: terms[0][terms[1]:], 7)

    def visit_SLICE_PLUS_2(self):
        self.reduce(2, lambda terms: terms[0][:terms[1]], 7)

    def visit_SLICE_PLUS_3(self):
        self.reduce(3, lambda terms: terms[0][terms[1]:terms[2]], 10)

    def binary_op(self, op):
        """Pre-compute a two-operand operation using the callable op."""
        def operate(terms):
            return op(*terms)
        # Two 3-byte operand loads plus the 1-byte operator opcode.
        self.reduce(2, operate, 7)
678
# Add visit_BINARY_*, visit_INPLACE_* methods to EarlyBinder.
# Yes, the inplace methods really do go through binary_op as well.
for k, v in binary_operators.items() + inplace_operators.items():
    setattr(EarlyBinder, "visit_" + k,
            # The operator is bound as a default argument, so that each
            # lambda keeps its own rather than the loop's final value.
            lambda self, opr=v: self.binary_op(opr))
687
688
689 class LambdaDecompiler(Visitor):
690     """LambdaDecompiler(obj=lambda function or func_code).
691     
692     Produce decompiled Python code (as a string) from a supplied lambda."""
693    
694     def __init__(self, obj):
695         self.verbose = False
696        
697         # Distill supplied 'obj' arg to a code block string.
698         if isinstance(obj, FunctionType):
699             obj = obj.func_code
700        
701         # Copy code object attributes.
702         for name in _co_code_attrs:
703             value = getattr(obj, name)
704             if isinstance(value, tuple):
705                 value = list(value)
706             setattr(self, name, value)
707        
708         try:
709             obj = obj.co_code
710         except AttributeError:
711             pass
712        
713         # Map the code block to a list of opcode numbers.
714         if isinstance(obj, basestring):
715             obj = map(ord, obj)
716         elif isinstance(obj, list):
717             obj = obj[:]
718        
719         self._bytecode = obj
720    
721     def code(self, include_func_header=True):
722         self.walk()
723         product = self.stack[0]
724         if include_func_header:
725             args = list(self.co_varnames)
726             if self.co_flags & 0x08:
727                 args[-1] = "**" + args[-1]
728                 if self.co_flags & 0x04:
729                     args[-2] = "*" + args[-2]
730             elif self.co_flags & 0x04:
731                 args[-1] = "*" + args[-1]
732             args = ", ".join(args)
733            
734             product = "lambda %s: %s" % (args, product)
735         return product
736    
737     def walk(self):
738         self.stack = []
739         self.newcode = []
740         self.targets = {}
741        
742         Visitor.walk(self)
743        
744         self.debug("stack:", self.stack)
745    
746     def visit_instruction(self, op, lo=None, hi=None):
747         # Get the instruction pointer for the current instruction.
748         ip = self.cursor - 3
749         if hi is None:
750             ip += 1
751             if lo is None:
752                 ip += 1
753        
754         terms = self.targets.get(ip)
755         if terms:
756             clause = self.stack[-1]
757             while terms:
758                 term, oper = terms.pop()
759                 clause = "(%s) %s (%s)" % (term, oper, clause)
760             # Replace TOS with the new clause, so that further
761             # combinations have access to it.
762             self.stack[-1] = clause
763             self.debug("clause:", clause, "\n")
764            
765             if op == 1:
766                 # Py2.4: The current instruction is POP_TOP, which means
767                 # the previous is probably JUMP_*. If so, we don't want to
768                 # pop the value we just placed on the stack and lose it.
769                 # We need to replace the entry that the JUMP_* made in
770                 # self.targets with our new TOS.
771                 target = self.targets[self.last_target_ip]
772                 target[-1] = ((clause, target[-1][1]))
773                 self.debug("newtarget:", self.last_target_ip, target)
774    
775     def visit_BUILD_TUPLE(self, lo, hi):
776         terms = ", ".join([self.stack.pop() for i in range(lo + hi << 8)])
777         self.stack.append("(%s)" % terms)
778    
779     def visit_BUILD_LIST(self, lo, hi):
780         terms = ", ".join([self.stack.pop() for i in range(lo + hi << 8)])
781         self.stack.append("[%s]" % terms)
782    
783     def visit_CALL_FUNCTION(self, lo, hi):
784         kwargs = {}
785         for i in range(hi):
786             val = self.stack.pop()
787             key = self.stack.pop()
788             kwargs[key] = val
789         kwargs = ", ".join([k + "=" + v for k, v in kwargs.iteritems()])
790        
791         args = []
792         for i in xrange(lo):
793             arg = self.stack.pop()
794             args.append(arg)
795         args.reverse()
796         args = ", ".join(args)
797        
798         if kwargs:
799             args += ", " + kwargs
800        
801         func = self.stack.pop()
802         self.stack.append("%s(%s)" % (func, args))
803    
804     def visit_COMPARE_OP(self, lo, hi):
805         term2, term1 = self.stack.pop(), self.stack.pop()
806         op = cmp_op[lo + (hi << 8)]
807         self.stack.append(term1 + " " + op + " " + term2)
808         self.debug(op)
809    
810     def visit_JUMP_IF_FALSE(self, lo, hi):
811         # Note that self.cursor has already advanced to the next instruction.
812         target = self.cursor + (lo + (hi << 8))
813         bucket = self.targets.setdefault(target, [])
814         bucket.append((self.stack[-1], 'and'))
815         self.debug("target:", target, bucket)
816         # Store target ip for the special code in visit_instruction
817         self.last_target_ip = target
818    
819     def visit_JUMP_IF_TRUE(self, lo, hi):
820         # Note that self.cursor has already advanced to the next instruction.
821         target = self.cursor + (lo + (hi << 8))
822         bucket = self.targets.setdefault(target, [])
823         bucket.append((self.stack[-1], 'or'))
824         self.debug("target:", target, bucket)
825         # Store target ip for the special code in visit_instruction
826         self.last_target_ip = target
827    
828     def visit_LOAD_ATTR(self, lo, hi):
829         term = self.co_names[lo + (hi << 8)]
830         self.stack[-1] += ("." + term)
831         self.debug(term)
832    
833     def visit_LOAD_CONST(self, lo, hi):
834         val = self.co_consts[lo + (hi << 8)]
835         if isinstance(val, (FunctionType, type)):
836             # The const in question is a factory function.
837             self.stack.append(val.__module__ + "." + val.__name__)
838         else:
839             term = repr(val)
840             if hasattr(val, "__module__"):
841                 if not term.startswith(val.__module__ + "."):
842                     term = val.__module__ + "." + term
843             self.stack.append(term)
844             self.debug(term)
845    
846     def visit_LOAD_FAST(self, lo, hi):
847         term = self.co_varnames[lo + (hi << 8)]
848         self.stack.append(term)
849         self.debug(term)
850    
851     def visit_LOAD_GLOBAL(self, lo, hi):
852         self.stack.append(self.co_names[lo + (hi << 8)])
853    
854     def visit_POP_TOP(self):
855         self.stack.pop()
856    
857     def visit_SLICE_PLUS_0(self):
858         arg = self.stack.pop()
859         self.stack.append("%s[:]" % arg)
860    
861     def visit_SLICE_PLUS_1(self):
862         args = tuple(self.stack[-2:])
863         del self.stack[-2:]
864         self.stack.append("%s[%s:]" % args)
865    
866     def visit_SLICE_PLUS_2(self):
867         args = tuple(self.stack[-2:])
868         del self.stack[-2:]
869         self.stack.append("%s[:%s]" % args)
870    
871     def visit_SLICE_PLUS_3(self):
872         args = tuple(self.stack[-3:])
873         del self.stack[-3:]
874         self.stack.append("%s[%s:%s]" % args)
875    
876     def visit_UNARY_CONVERT(self):
877         term = self.stack.pop()
878         self.stack.append("`(" + term + ")`")
879    
880     def visit_UNARY_INVERT(self):
881         term = self.stack.pop()
882         self.stack.append("~(" + term + ")")
883    
884     def visit_UNARY_NEGATIVE(self):
885         term = self.stack.pop()
886         self.stack.append("-(" + term + ")")
887    
888     def visit_UNARY_NOT(self):
889         term = self.stack.pop()
890         self.stack.append("not (" + term + ")")
891    
892     def visit_UNARY_POSITIVE(self):
893         term = self.stack.pop()
894         self.stack.append("+(" + term + ")")
895    
896     def binary_op(self, op):
897         op2, op1 = self.stack.pop(), self.stack.pop()
898         self.stack.append(op1 + " " + op + " " + op2)
899    
900     def visit_BINARY_SUBSCR(self):
901         op2, op1 = self.stack.pop(), self.stack.pop()
902         self.stack.append(op1 + "[" + op2 + "]")
903
# Add visit_BINARY methods to LambdaDecompiler.
# Each (key, value) entry of binary_repr becomes a "visit_<key>" method
# that forwards the value to LambdaDecompiler.binary_op.  The value is
# bound through the lambda's default argument, so every generated method
# keeps its own operator instead of sharing the loop's final one.
# NOTE: k and v stay bound at module level once the loop finishes; the
# module removes them with "del k, v" further down.
for k, v in binary_repr.iteritems():
    setattr(LambdaDecompiler, "visit_" + k,
            lambda self, op=v: self.binary_op(op))
908
909
class BranchTracker(Visitor):
    """BranchTracker(obj=[func|co|str|list]).

    Finds all possible instructions previous to the supplied instruction(s).

    During a walk, self.watch maps each watched instruction address to
    the list of addresses of instructions recorded as leading to it.
    """

    # The one opcode visit_instruction treats as never falling through to
    # the instruction that follows it.
    # NOTE(review): other opcodes that cannot fall through (JUMP_FORWARD,
    # RETURN_VALUE, ...) are still recorded as predecessors -- confirm
    # whether that over-approximation is intended.
    _JUMP_ABSOLUTE = opmap['JUMP_ABSOLUTE']

    def branches(self, instr=None):
        """Walk self and return all possible instructions previous to instr.

        If instr is None, the last instruction will be used.
        """
        if instr is None:
            instr = len(self._bytecode) - 1
        self.watch = {instr: []}
        self.walk()
        return self.watch[instr]

    def visit_instruction(self, op, lo=None, hi=None):
        """Record the current instruction if it falls through into a watched one.

        self.cursor is treated as the address of the next instruction, so
        the current one began 1 byte back when it has no argument and
        3 bytes back when it carries a two-byte argument.
        """
        if self.cursor in self.watch and op != self._JUMP_ABSOLUTE:
            if lo is None and hi is None:
                size = 1
            else:
                size = 3
            self.watch[self.cursor].append(self.cursor - size)

    def visit_CONTINUE_LOOP(self, lo, hi):
        """Handle CONTINUE_LOOP exactly like JUMP_ABSOLUTE."""
        self.visit_JUMP_ABSOLUTE(lo, hi)

    def visit_JUMP_ABSOLUTE(self, lo, hi):
        """Record this jump as a predecessor of its absolute target, if watched."""
        target = lo + (hi << 8)
        if target in self.watch:
            # self.cursor is already past this 3-byte instruction; store
            # the jump's own address, as the other recorders do.
            self.watch[target].append(self.cursor - 3)

    def visit_JUMP_FORWARD(self, lo, hi):
        """Record this jump as a predecessor of its relative target, if watched."""
        delta = lo + (hi << 8)
        # The offset is added to the address of the following instruction.
        target = self.cursor + delta
        if target in self.watch:
            self.watch[target].append(self.cursor - 3)

    def visit_JUMP_IF_FALSE(self, lo, hi):
        """Record the taken branch; same target arithmetic as JUMP_FORWARD."""
        self.visit_JUMP_FORWARD(lo, hi)

    def visit_JUMP_IF_TRUE(self, lo, hi):
        """Record the taken branch; same target arithmetic as JUMP_FORWARD."""
        self.visit_JUMP_FORWARD(lo, hi)
953
954
class KeywordInspector(Rewriter):
    """Produce a list of all keyword arguments expected.

    The walk looks for the instruction sequence LOAD_FAST (of the last
    varname), LOAD_CONST, BINARY_SUBSCR and collects each constant used
    as the subscript.
    """

    # Opcodes that make up the matched sequence.
    _LOAD_FAST = opmap['LOAD_FAST']
    _LOAD_CONST = opmap['LOAD_CONST']
    _BINARY_SUBSCR = opmap['BINARY_SUBSCR']
    # co_flags bit that is set when the code object accepts **kwargs.
    _CO_VARKEYWORDS = 0x0008

    def __init__(self, obj):
        """KeywordInspector(obj). List keyword arguments expected.

        Raises ValueError if obj does not accept **kwargs or has no more
        than one varname.
        """
        Rewriter.__init__(self, obj)
        if not (self.co_flags & self._CO_VARKEYWORDS):
            raise ValueError("'%s' does not possess **kwargs." % obj)
        if len(self.co_varnames) <= 1:
            raise ValueError("'%s' does not possess more than 1 varname." % obj)
        self._kwargs = []
        self.flag = None

    def kwargs(self):
        """kwargs() -> List of keyword arguments expected."""
        # Start every walk from a clean state so that calling kwargs()
        # more than once does not append the same names again.
        self._kwargs = []
        self.flag = None
        self.walk()
        return self._kwargs

    def visit_instruction(self, op, lo=None, hi=None):
        """Advance the LOAD_FAST / LOAD_CONST / BINARY_SUBSCR matcher.

        self.flag is None while idle, '' right after the last varname has
        been loaded, and the loaded constant once LOAD_CONST follows.  A
        BINARY_SUBSCR arriving while flag holds a non-empty constant
        records that constant; any other instruction resets the matcher.
        """
        # NOTE(review): the last entry of co_varnames is presumably the
        # **kwargs name; confirm for functions with other local variables.
        if op == self._LOAD_FAST and (lo + (hi << 8) == len(self.co_varnames) - 1):
            self.flag = ''
        elif op == self._LOAD_CONST and self.flag == '':
            self.flag = self.co_consts[lo + (hi << 8)]
        elif op == self._BINARY_SUBSCR and self.flag:
            self._kwargs.append(self.flag)
        else:
            self.flag = None
982
# The loop that registers the visit_BINARY_* methods leaves its loop
# variables bound at module level; remove them so they do not linger in
# the module namespace.
del k, v
984
Note: See TracBrowser for help on using the browser.