Contact: fumanchu@aminus.org

Log in as guest/dejavu to create tickets

root/tags/1.4.0/codewalk.py

Revision 141 (checked in by fumanchu, 3 years ago)

Gratuitous premature optimization.

  • Property svn:eol-style set to native
Line 
1 """Bytecode visitors, including rewriters and decompilers.
2
3 This work, including the source code, documentation
4 and related data, is placed into the public domain.
5
6 The original author is Robert Brewer, Amor Ministries.
7
8 THIS SOFTWARE IS PROVIDED AS-IS, WITHOUT WARRANTY
9 OF ANY KIND, NOT EVEN THE IMPLIED WARRANTY OF
10 MERCHANTABILITY. THE AUTHOR OF THIS SOFTWARE
11 ASSUMES _NO_ RESPONSIBILITY FOR ANY CONSEQUENCE
12 RESULTING FROM THE USE, MODIFICATION, OR
13 REDISTRIBUTION OF THIS SOFTWARE.
14
15 """
16
17 from opcode import cmp_op, opname, opmap, HAVE_ARGUMENT
18 import operator
19 import sets
20 from types import CodeType, FunctionType
21
22
def named_opcodes(bits):
    """Return a list like 'bits' with opcode numbers replaced by their names.

    Only the items in opcode position are translated (via opcode.opname).
    When an opcode takes an argument (its number is >= HAVE_ARGUMENT), the
    two items that follow it are copied through untouched. If the input
    ends in the middle of an argument, whatever is present is copied and
    the result ends there.
    """
    codes = list(bits)
    named = []
    pos = 0
    while pos < len(codes):
        op = codes[pos]
        pos += 1
        named.append(opname[op])
        if op >= HAVE_ARGUMENT:
            # A slice quietly yields fewer than two items when the
            # input is truncated, which then ends the loop.
            named.extend(codes[pos:pos + 2])
            pos += 2
    return named
36
37 def numeric_opcodes(bits):
38     """Change named opcode bits to their numeric equivalents."""
39     bitnums = []
40     for x in bits:
41         if isinstance(x, basestring):
42             x = opmap[x]
43         bitnums.append(x)
44     return bitnums
45
# Body of a tiny function that loads its first free variable and
# returns it: LOAD_DEREF 0, RETURN_VALUE.
_deref_bytecode = numeric_opcodes(['LOAD_DEREF', 0, 0, 'RETURN_VALUE'])

# CodeType(argcount, nlocals, stacksize, flags, codestring, constants,
#          names, varnames, filename, name, firstlineno,
#          lnotab[, freevars[, cellvars]])
_derefblock = CodeType(
    0, 0, 1, 3,
    ''.join(map(chr, _deref_bytecode)),
    (None,), ('cell',), (), '', '', 2, '',
    ('cell',))


def deref_cell(cell):
    """Return the object held inside the closure cell 'cell'.

    A throwaway function is built around _derefblock with 'cell' as its
    only closure entry; calling it returns the cell's contents.
    """
    # FunctionType(code, globals[, name[, argdefs[, closure]]])
    extractor = FunctionType(_derefblock, {}, "", (), (cell,))
    return extractor()
55
56 def make_closure(*args):
57     def inner():
58         args
59     return inner.func_closure
60
61
62 binary_operators = {'BINARY_POWER': operator.pow,
63                     'BINARY_MULTIPLY': operator.mul,
64                     'BINARY_DIVIDE': operator.div,
65                     'BINARY_FLOOR_DIVIDE': operator.floordiv,
66                     'BINARY_TRUE_DIVIDE': operator.truediv,
67                     'BINARY_MODULO': operator.mod,
68                     'BINARY_ADD': operator.add,
69                     'BINARY_SUBTRACT': operator.sub,
70                     'BINARY_SUBSCR': operator.getitem,
71                     'BINARY_LSHIFT': operator.lshift,
72                     'BINARY_RSHIFT': operator.rshift,
73                     'BINARY_AND': operator.and_,
74                     'BINARY_XOR': operator.xor,
75                     'BINARY_OR': operator.or_,
76                     }
77 inplace_operators = dict([('INPLACE_' + k.split('_')[1], v)
78                           for k, v in binary_operators.iteritems()
79                           if k not in ('BINARY_SUBSCR',)
80                           ])
81
# Maps each BINARY_* opcode name to the source-code symbol for it.
binary_repr = {'BINARY_POWER': '**',
               'BINARY_MULTIPLY': '*',
               'BINARY_DIVIDE': '/',
               'BINARY_FLOOR_DIVIDE': '//',
               'BINARY_TRUE_DIVIDE': '/',
               'BINARY_MODULO': '%',
               'BINARY_ADD': '+',
               'BINARY_SUBTRACT': '-',
               'BINARY_LSHIFT': '<<',
               'BINARY_RSHIFT': '>>',
               'BINARY_AND': '&',
               'BINARY_XOR': '^',
               'BINARY_OR': '|',
               }

# Maps each INPLACE_* opcode name to its augmented-assignment symbol
# ('+=', '//=', ...). The name is split on the FIRST underscore only, so
# that BINARY_FLOOR_DIVIDE becomes INPLACE_FLOOR_DIVIDE rather than the
# nonexistent INPLACE_FLOOR.
inplace_repr = dict([('INPLACE_' + k.split('_', 1)[1], v + '=')
                     for k, v in binary_repr.items()])
99
# Maps each comparison symbol to a two-argument function f(x, y) that
# evaluates "x <symbol> y" with the operands in source order.
comparisons = {'<': operator.lt,
               '<=': operator.le,
               '==': operator.eq,
               '!=': operator.ne,
               '>': operator.gt,
               '>=': operator.ge,
               # operator.contains takes (container, item), the reverse
               # of source order, so spell the test out instead.
               'in': lambda x, y: x in y,
               'not in': lambda x, y: x not in y,
               'is': operator.is_,
               'is not': operator.is_not,
               }
111
112 _co_code_attrs = ['co_argcount', 'co_cellvars', 'co_code', 'co_consts',
113                   'co_filename', 'co_firstlineno', 'co_flags', 'co_freevars',
114                   'co_lnotab', 'co_name', 'co_names', 'co_nlocals',
115                   'co_stacksize', 'co_varnames']
116
117
118 class Visitor(object):
119     """Visitor(obj=function, code object, string, or list of opcodes)."""
120    
121     def __init__(self, obj):
122         self.verbose = False
123        
124         # Distill supplied 'obj' arg to a code block string.
125         if isinstance(obj, FunctionType):
126             obj = obj.func_code
127         try:
128             obj = obj.co_code
129         except AttributeError:
130             pass
131        
132         # Map the code block string to a list of opcode numbers.
133         if isinstance(obj, basestring):
134             obj = map(ord, obj)
135         elif isinstance(obj, list):
136             obj = obj[:]
137        
138         self._bytecode = obj
139    
140     def debug(self, *messages):
141         if self.verbose:
142             for term in messages:
143                 print term,
144         else:
145             pass
146    
147     def walk(self):
148         self.cursor = 0
149         b = self._bytecode
150         self.debug("\nWALKING: ", b)
151         while self.cursor < len(b):
152             self.debug("\n", self.cursor)
153            
154             op = b[self.cursor]
155             self.cursor += 1
156             if op >= HAVE_ARGUMENT:
157                 lo = b[self.cursor]
158                 self.cursor += 1
159                 hi = b[self.cursor]
160                 self.cursor += 1
161                 args = (lo, hi)
162             else:
163                 args = ()
164            
165             self.debug("visit (%s, %s)" % (op, repr(args)))
166             self.visit_instruction(op, *args)
167            
168             name = 'visit_' + opname[op]
169             name = name.replace('+', '_PLUS_')
170             handler = getattr(self, name, None)
171             if handler:
172                 self.debug("=> %s%s" % (name[6:], repr(args)))
173                 handler(*args)
174    
175     def visit_instruction(self, op, lo=None, hi=None):
176         pass
177
178
class JumpCodeAdjuster(Visitor):
    """JumpCodeAdjuster(obj=[func|co|str|list], start, end, newlength).

    Produces a copy of the bytecode in which jump arguments are corrected
    for a pending replacement of bytecode[start:end] by 'newlength' codes.
    The replacement itself is not performed here.

    start, end: The range of the original bytecode in question.
    newlength: Length of the codes which will overwrite bytecode[start:end].

    Only CONTINUE_LOOP, JUMP_ABSOLUTE, JUMP_FORWARD, JUMP_IF_FALSE and
    JUMP_IF_TRUE are examined.
    """

    def __init__(self, obj, start, end, newlength):
        Visitor.__init__(self, obj)
        self.start, self.end = start, end
        # Net change in length: positive if the bytecode will grow,
        # negative if it will shrink.
        self.offset = newlength - (end - start)

    def bytecode(self):
        """Walk self and return the adjusted bytecode (a list)."""
        self.walk()
        return self.newcode

    def walk(self):
        """Build self.newcode, skipping the walk when nothing can move."""
        if self.offset != 0:
            self.newcode = []
            Visitor.walk(self)
        else:
            # Same length in as out: no jump needs to change.
            self.newcode = self._bytecode

    def visit_instruction(self, op, lo=None, hi=None):
        """Copy the instruction, unchanged, onto the end of self.newcode."""
        self.newcode.append(op)
        for arg in (lo, hi):
            if arg is not None:
                self.newcode.append(arg)

    def visit_CONTINUE_LOOP(self, lo, hi):
        """CONTINUE_LOOP is handled exactly like JUMP_ABSOLUTE."""
        self.visit_JUMP_ABSOLUTE(lo, hi)

    def visit_JUMP_ABSOLUTE(self, lo, hi):
        """Shift an absolute target that lies beyond self.start."""
        target = (hi << 8) + lo
        if target <= self.start:
            return
        moved = target + self.offset
        # The instruction was just copied; overwrite its two arg bytes.
        self.newcode[-2:] = [moved & 0xFF, moved >> 8]

    def visit_JUMP_FORWARD(self, lo, hi):
        """Re-measure a relative jump that reaches past self.start.

        The jump distance is counted from self.cursor, the index just
        past this instruction. It is changed only when that point lies
        before self.end and the target lies beyond self.start.
        """
        target = self.cursor + lo + (hi << 8)
        if self.cursor < self.end and target > self.start:
            distance = (target + self.offset) - self.cursor
            self.newcode[-2:] = [distance & 0xFF, distance >> 8]

    def visit_JUMP_IF_FALSE(self, lo, hi):
        """JUMP_IF_FALSE is handled exactly like JUMP_FORWARD."""
        self.visit_JUMP_FORWARD(lo, hi)

    def visit_JUMP_IF_TRUE(self, lo, hi):
        """JUMP_IF_TRUE is handled exactly like JUMP_FORWARD."""
        self.visit_JUMP_FORWARD(lo, hi)
236
237
def safe_tuple(seq):
    """Return the items of 'seq', each passed through str(), as a tuple.

    Many of the func_code attributes must take tuples, not lists,
    and *cannot* accept unicode items--they must be coerced to strings
    or the interpreter will crash.
    """
    return tuple([str(item) for item in seq])
247
248
class Rewriter(Visitor):
    """Rewriter(obj=function or code object).

    Produce a new function or code object by rewriting an existing one.

    Walking copies each instruction into self.newcode; opcode-specific
    visit_* methods (supplied by subclasses) may then call put() or tail()
    to overwrite what was just copied.

    Notice that, unlike the base Visitor class, Rewriter does not accept a
    string or list of opcodes as an initial argument.
    """

    def __init__(self, obj):
        self.verbose = False

        # Remember the original function (if that is what we were given)
        # so function() can reuse its globals, defaults and closure.
        if isinstance(obj, FunctionType):
            self._func = obj
            obj = obj.func_code

        # Mirror every code object attribute on self. Tuples become lists
        # so that they can be extended during the rewrite.
        for name in _co_code_attrs:
            value = getattr(obj, name)
            if isinstance(value, tuple):
                value = list(value)
            setattr(self, name, value)

        source = getattr(obj, 'co_code', obj)

        # Map the code block to a list of opcode numbers.
        if isinstance(source, basestring):
            self._bytecode = map(ord, source)
        elif isinstance(source, list):
            self._bytecode = source[:]
        else:
            raise TypeError("obj arg of incorrect type '%s'" % type(source))

    def bytecode(self):
        """Walk self and return the rewritten bytecode (a list)."""
        self.walk()
        return self.newcode

    def code_object(self):
        """Walk self and return a new code object built from the result."""
        self.walk()
        codestr = ''.join([chr(code) for code in self.newcode])
        # Notice co_consts should *not* be safe_tupled.
        consts = tuple(self.co_consts)
        return CodeType(self.co_argcount, self.co_nlocals, self.co_stacksize,
                        self.co_flags, codestr, consts,
                        safe_tuple(self.co_names),
                        safe_tuple(self.co_varnames),
                        self.co_filename, self.co_name, self.co_firstlineno,
                        self.co_lnotab,
                        safe_tuple(self.co_freevars),
                        safe_tuple(self.co_cellvars))

    def function(self, newname=None):
        """Walk self and return a new function built from the result.

        If self was created from a function, the new function shares its
        globals, defaults and closure, and takes its name unless 'newname'
        is given. Otherwise the new function gets empty globals and the
        name 'newname' (or '' if that is None).
        """
        original = getattr(self, '_func', None)
        if original is None:
            if newname is None:
                newname = ''
            return FunctionType(self.code_object(), {}, newname)

        if newname is None:
            newname = original.func_name
        return FunctionType(self.code_object(), original.func_globals,
                            newname, original.func_defaults,
                            original.func_closure)

    def const_index(self, value):
        """The index of value in co_consts, appending it if not found.

        An existing constant matches only if it has the same type as
        'value' and compares equal to it.
        """
        valtype = type(value)
        for pos, item in enumerate(self.co_consts):
            try:
                if valtype == type(item) and value == item:
                    return pos
            except TypeError:
                # Not comparable: treat as "not a match".
                pass

        self.co_consts.append(value)
        return len(self.co_consts) - 1

    def name_index(self, value):
        """The index of value in co_names, appending it if not found.

        An existing name matches only if it has the same type as 'value'
        and compares equal to it.
        """
        valtype = type(value)
        for pos, item in enumerate(self.co_names):
            try:
                if valtype == type(item) and value == item:
                    break
            except TypeError:
                # Not comparable: treat as "not a match".
                continue
        else:
            pos = len(self.co_names)
            self.co_names.append(value)
        return pos

    def walk(self):
        """Start a fresh self.newcode, then visit every instruction."""
        self.newcode = []
        Visitor.walk(self)

    def visit_instruction(self, op, lo=None, hi=None):
        """Copy the instruction, unchanged, onto the end of self.newcode."""
        self.newcode.append(op)
        for arg in (lo, hi):
            if arg is not None:
                self.newcode.append(arg)

    def put(self, start, end, *bits):
        """Overwrite self.newcode[start:end] with new opcodes (numbers or names).

        If the new codes are of different quantity than the old,
        modify any jump codes affected.
        """
        replacement = numeric_opcodes(bits)

        # Adjust jump codes. Notice this comes before bytecode is modified.
        adjuster = JumpCodeAdjuster(self.newcode, start, end,
                                    len(replacement))
        self.newcode = adjuster.bytecode()

        # Rewrite bytecode.
        self.newcode[start:end] = replacement

    def tail(self, length, *bits):
        """Overwrite self.newcode[-length:] with bits."""
        stop = len(self.newcode)
        self.put(stop - length, stop, *bits)
378
379
class Localizer(Rewriter):
    """Localizer(func, builtin_only=False, stoplist=None, verbose=False)

    If a global or builtin is known at compile time, replace it with a constant.

    func: the function to rewrite.
    builtin_only: if True, only names found in __builtin__ are replaced;
        otherwise the function's own globals are considered too (and
        take precedence over builtins of the same name).
    stoplist: names which must never be replaced. Defaults to none.
    verbose: if True, each replacement is reported through debug().

    This duplicates (and borrows from) Raymond Hettinger's Cookbook recipe
    at: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/277940
    """
    def __init__(self, func, builtin_only=False, stoplist=None, verbose=False):
        Rewriter.__init__(self, func)

        import __builtin__
        self.env = vars(__builtin__).copy()
        if not builtin_only:
            self.env.update(func.func_globals)

        # Give each instance its own list when none is supplied, so that
        # changing one instance's stoplist cannot affect any other.
        if stoplist is None:
            stoplist = []
        self.stoplist = stoplist
        self.verbose = verbose

    def visit_LOAD_GLOBAL(self, lo, hi):
        """Turn LOAD_GLOBAL of a known name into LOAD_CONST of its value."""
        name = self.co_names[lo + (hi << 8)]
        if name in self.env and name not in self.stoplist:
            value = self.env[name]
            pos = self.const_index(value)
            # The LOAD_GLOBAL just copied occupies the last 3 bytes.
            self.tail(3, 'LOAD_CONST', pos & 0xFF, pos >> 8)
            self.debug(name, ' --> ', value)
406
407
408 class TaintableStack(list):
409     def __init__(self, seq=[]):
410         list.__init__(self, seq)
411         self._taintindex = sets.Set()
412         self.maxsize = len(seq)
413    
414     def taint(self, index=-1):
415         if index < 0:
416             index += len(self)
417         self._taintindex.add(index)
418    
419     def tainted(self, index=-1):
420         if index < 0:
421             index = len(self) + index
422         return (index in self._taintindex)
423    
424     def pop(self, index=-1):
425         """pop(index) -> Returns a tuple!! of (value, tainted)."""
426         if index < 0:
427             index += len(self)
428         is_tainted = (index in self._taintindex)
429         if is_tainted:
430             self._taintindex.remove(index)
431         return list.pop(self, index), is_tainted
432    
433     def append(self, obj):
434         list.append(self, obj)
435         l = len(self)
436         if l > self.maxsize:
437             self.maxsize = l
438
439
440 class EarlyBinder(Rewriter):
441     """EarlyBinder(func, reduce_getattr=True, bind_late=[])
442     
443     Deep-evaluate function, replacing free vars with constants.
444     
445     reduce_getattr: If True (the default), getattr(x, y) will be
446         replaced with x.y where possible.
447     
448     bind_late: a list of names (globals, freevars, or attributes)
449         which should not be early-bound. For example, if you want
450         datetime.date.today() to be bound late, include 'today'
451         in the bind_late.
452     
453     Example: k = lambda x: x.Date == datetime.date(2004, 1, 1)
454              r = EarlyBinder(k).function()
455             
456     _____ k _____                _____ r _____
457      0 LOAD_FAST     0 (x)        0 LOAD_FAST   0 (x)
458      3 LOAD_ATTR     1 (Date)     3 LOAD_ATTR   1 (Date)
459      6 LOAD_GLOBAL   2 (datetime) 6 LOAD_CONST  6 (datetime.date(2004, 1, 1))
460      9 LOAD_ATTR     3 (date)
461     12 LOAD_CONST    1 (2004)
462     15 LOAD_CONST    2 (1)
463     18 LOAD_CONST    2 (1)
464     21 CALL_FUNCTION 3
465     24 COMPARE_OP    2 (==)       9 COMPARE_OP  2 (==)
466     27 RETURN_VALUE              12 RETURN_VALUE
467     
468     This also pre-computes binary operations *and all other builtin or free
469     functions* where all operands are constants, globals, or freevars.
470     For example:
471         LOAD_CONST        1 (3)
472         LOAD_CONST        2 (4)
473         BINARY_MULTIPLY
474         
475     is replaced with:
476         LOAD_CONST        5 (12)
477     
478     However, order is important. lambda x: x * 4 * 5 won't see any
479     optimization, because the order of eval is (x * 4) * 5. Rewritten
480     as lambda x: 4 * 5 * x, the "4 * 5" can be replaced with "20".
481     """
482    
483     def __init__(self, func, reduce_getattr=True, bind_late=[]):
484         Rewriter.__init__(self, func)
485         self.reduce_getattr = reduce_getattr
486        
487         import __builtin__
488         self.env = vars(__builtin__).copy()
489         self.env.update(func.func_globals)
490        
491         # Keep a stack like the interpreter would. This does *not*