#! /usr/local/bin/python import sys import compiler from compiler.ast import * import copy debug = False def any(ls): for x in ls: if x: return True return False def prepend_stmts(ss, s): if isinstance(s, Stmt): return Stmt(ss + s.nodes) else: return Stmt(ss + [s]) def append_stmts(s1, s2): if isinstance(s1, Stmt): if isinstance(s2,Stmt): return Stmt(s1.nodes + s2.nodes) else: return Stmt(s1.nodes + s2) elif isinstance(s2,Stmt): return Stmt([s1] + s2.nodes) else: return Stmt([s1] + [s2]) # lhs : string, rhs : expr def make_assign(lhs, rhs): return Assign(nodes=[AssName(name=lhs, flags='OP_ASSIGN')], expr=rhs) def make_assign_t(lhs, rhs, t): ret = Assign(nodes=[AssName(name=lhs, flags='OP_ASSIGN')], expr=rhs) ret.type = t return ret ############################################################################### # Simplify: # comparison and logic operators # lambdas and function definitions become lambdas with statement body # New Classes for the Intermediate Representation that # provide a uniform representation for binary and unary operations class PrimitiveOp(Node): def __init__(self, name, nodes, lineno=None): self.name = name self.nodes = nodes self.lineno = lineno def getChildren(self): return self.nodes def getChildNodes(self): return self.nodes def __repr__(self): return "PrimitiveOp(%s, %s)" % (self.name, repr(self.nodes)) class Let(Node): def __init__(self, var, rhs, body, lineno=None): self.var = var self.rhs = rhs self.body = body self.lineno = lineno def getChildren(self): return self.rhs, self.body def getChildNodes(self): return self.rhs, self.body def __repr__(self): return "Let(%s, %s, %s)" % (repr(self.var), repr(self.rhs), repr(self.body)) # the following counter is for generating unique names counter = 0 def generate_name(x): global counter name = x + str(counter) counter = counter + 1 return name binary_op_classes = [Add, Sub, Mul, Div, Mod, Power] unary_op_classes = [UnaryAdd, UnarySub, Not] class_to_fun = {Add: 'add', Sub: 'sub', Mul: 'mul', 
Div: 'floordiv', \ Mod: 'mod', Power: 'power', \ UnaryAdd: 'unary_add', UnarySub: 'unary_sub', Not: 'logic_not' } compare_to_fun = {'<': 'less', '>': 'greater', '<=': 'less_equal', '>=': 'greater_equal', \ '==': 'equal', '!=': 'not_equal', 'is': 'identical' } def is_simple(n): return isinstance(n, (Const, Name)) def make_subscript(expr, index): return PrimitiveOp('deref', [PrimitiveOp('subscript', [expr, index])]) def make_get_attr(expr, attrname): return PrimitiveOp('get_attr', [expr, attrname]) def make_set_attr(lhs_expr, attrname, rhs_expr): return PrimitiveOp('set_attr', [lhs_expr, attrname, rhs_expr]) def simplify_ops(n): if debug: print 'simplify ops ', n if isinstance(n, Class): code = simplify_ops(n.code) return Class(n.name, n.bases, n.doc, code, n.lineno) elif isinstance(n, Module): node = simplify_ops(n.node) return Module(n.doc, node) elif isinstance(n, Stmt): nodes = [simplify_ops(s) for s in n.nodes] return Stmt(nodes) elif isinstance(n, Printnl): return Printnl([simplify_ops(e) for e in n.nodes], n.dest) elif isinstance(n, Discard): if isinstance(n.expr, Const) and n.expr.value == None: return Pass() else: return Discard(simplify_ops(n.expr)) elif isinstance(n, If): tests = [(simplify_ops(cond), simplify_ops(body)) for (cond,body) in n.tests] else_ = simplify_ops(n.else_) return If(tests, else_) elif n == None: # to handle when an If's else_ clause is not there return None elif isinstance(n, While): test = simplify_ops(n.test) body = simplify_ops(n.body) if n.else_ == None: return While(test, body, None) else: else_ = simplify_ops(n.else_) tmp = generate_name('__tmp') return Stmt([make_sssign(tmp, test), If(Name(tmp), While(Name(tmp), body, None), else_)]) elif isinstance(n, Pass): return n elif isinstance(n, Assign): rhs = simplify_ops(n.expr) lhs = n.nodes[0] if isinstance(lhs, AssName): return Assign([lhs], rhs) elif isinstance(lhs, AssAttr): lhs_expr = simplify_ops(lhs.expr) return Discard(make_set_attr(lhs_expr, Const(lhs.attrname), rhs)) elif 
isinstance(lhs, Subscript): return Assign([make_subscript(simplify_ops(lhs.expr), simplify_ops(lhs.subs[0]))], rhs) else: raise Exception('Error in simplify_ops: unrecognized lhs of assignment ' + repr(lhs)) elif isinstance(n, Function): code = simplify_ops(n.code) return make_assign(n.name, Lambda(n.argnames, n.defaults, n.flags, code)) elif isinstance(n, Return): return Return(simplify_ops(n.value)) elif isinstance(n, Const): return n elif isinstance(n, Name): return n elif n.__class__ in binary_op_classes: name = class_to_fun[n.__class__] left = simplify_ops(n.left) l_name = generate_name('__left') right = simplify_ops(n.right) return Let(l_name, left, PrimitiveOp(name, [Name(l_name), right])) elif n.__class__ in unary_op_classes: return PrimitiveOp(class_to_fun[n.__class__], [simplify_ops(n.expr)]) elif isinstance(n, And): nodes = [simplify_ops(e) for e in n.nodes] r = nodes[len(nodes)-1] for child in reversed(nodes[0:len(nodes)-1]): var = Name(generate_name('__tmp')) r = Let(var.name, child, IfExp(var, r, var)) return r elif isinstance(n, Or): nodes = [simplify_ops(e) for e in n.nodes] r = nodes[len(nodes)-1] for child in reversed(nodes[0:len(nodes)-1]): var = Name(generate_name('__tmp')) r = Let(var.name, child, IfExp(var, var, r)) return r elif isinstance(n, IfExp): return IfExp(simplify_ops(n.test), simplify_ops(n.then), simplify_ops(n.else_)) elif isinstance(n, Compare): def gen_compare(lhs, ops): if len(ops) == 1: op, rhs = ops[0] return PrimitiveOp(compare_to_fun[op], [lhs, simplify_ops(rhs)]) elif len(ops) > 1: op, rhs = ops[0] rhs_var = Name(generate_name('__tmp')) return Let(rhs_var.name, simplify_ops(rhs), PrimitiveOp('logic_and', [PrimitiveOp(compare_to_fun[op], [lhs, rhs_var]), gen_compare(rhs_var, ops[1:])])) else: print "error in gen_compare: zero length ops" lhs_var = Name(generate_name('__tmp')) return Let(lhs_var.name, simplify_ops(n.expr), gen_compare(lhs_var, n.ops)) elif isinstance(n, CallFunc): if isinstance(n.node, Name) and n.node.name 
== 'input': return PrimitiveOp('input', []) else: def gen_call(f, args, new_args): if len(args) == 0: return CallFunc(f, new_args) elif is_simple(args[0]): return gen_call(f, args[1:], new_args + [args[0]]) else: a = generate_name('__arg') return Let(a, args[0], gen_call(f, args[1:], new_args + [Name(a)])) node = simplify_ops(n.node) args = [simplify_ops(e) for e in n.args] if is_simple(node): return gen_call(node, args, []) else: f = generate_name('__f') return Let(f, node, gen_call(Name(f), args, [])) elif isinstance(n, List): ls_name = generate_name('__list') def gen_list(nodes, i): if len(nodes) == 0: return Name(ls_name) else: return Let('_', PrimitiveOp('set_subscript',\ [Name(ls_name), Const(i), simplify_ops(nodes[0])]), gen_list(nodes[1:], i + 1)) return Let(ls_name, PrimitiveOp('make_list', [Const(len(n.nodes))]),\ gen_list(n.nodes, 0)) elif isinstance(n, Dict): d_name = generate_name('__dict') def gen_dict(items): if len(items) == 0: return Name(d_name) else: return Let('_', PrimitiveOp('set_subscript', \ [Name(d_name), simplify_ops(items[0][0]), simplify_ops(items[0][1])]), gen_dict(items[1:])) return Let(d_name, PrimitiveOp('make_dict', []), gen_dict(n.items)) elif isinstance(n, Subscript): return make_subscript(simplify_ops(n.expr), simplify_ops(n.subs[0])) elif isinstance(n, Lambda): code = Return(simplify_ops(n.code)) return Lambda(n.argnames, n.defaults, n.flags, code) elif isinstance(n, Getattr): expr = simplify_ops(n.expr) return make_get_attr(expr, Const(n.attrname)) else: raise Exception('Error in simplify_ops: unrecognized AST node ' + repr(n)) ############################################################################### # Some helper functions def union(a,b): return a | b # This is good for figuring out the local variables of a function def assigned_vars(n): if isinstance(n, Class): return set([n.name]) elif isinstance(n, Stmt): return reduce(union, [assigned_vars(s) for s in n.nodes], set([])) elif isinstance(n, Printnl): return set([]) 
elif isinstance(n, Pass): return set([]) elif isinstance(n, If): return reduce(union,[assigned_vars(b) for (c,b) in n.tests], set([])) \ | assigned_vars(n.else_) \ | (reduce(union, [assigned_vars(s) for s in n.phis], set([])) \ if hasattr(n, 'phis') else set([])) elif n == None: return set([]) elif isinstance(n, Assign): return reduce(union, [assigned_vars(n) for n in n.nodes], set([])) elif isinstance(n, AssName): return set([n.name]) elif isinstance(n, While): return assigned_vars(n.body) \ | (reduce(union, [assigned_vars(s) for s in n.phis], set([])) \ if hasattr(n, 'phis') else set([])) elif isinstance(n, Discard): return set([]) else: return set([]) ############################################################################### # Class lowering def replace_attributes(n, attr, cls, outer): if isinstance(n, Class): return n elif isinstance(n, Stmt): return Stmt([replace_attributes(s, attr, cls, outer) for s in n.nodes]) elif isinstance(n, Printnl): return Printnl([replace_attributes(e, attr, cls, outer) for e in n.nodes], n.dest) elif isinstance(n, Pass): return n elif isinstance(n, Discard): return Discard(replace_attributes(n.expr, attr, cls, outer)) elif isinstance(n, Return): return Return(replace_attributes(n.value, attr, cls, outer)) elif isinstance(n, If): new_tests = [(replace_attributes(cond, attr, cls, outer), \ replace_attributes(body, attr, cls, outer)) \ for (cond,body) in n.tests] new_else = replace_attributes(n.else_, attr, cls, outer) return If(tests=new_tests, else_=new_else) elif isinstance(n, While): new_test = replace_attributes(n.test, attr, cls, outer) new_body = replace_attributes(n.body, attr, cls, outer) return While(test=new_test, body=new_body, else_=None) elif isinstance(n, Assign): rhs = replace_attributes(n.expr, attr, cls, outer) lhs = n.nodes[0] new_nodes = [replace_attributes(n, attr, cls, outer) for n in n.nodes] if isinstance(lhs, AssName): if lhs.name in attr: return Discard(make_set_attr(Name(cls), Const(lhs.name), rhs)) 
else: return Assign(expr=new_rhs, nodes=new_nodes) else: return Assign(expr=new_rhs, nodes=new_nodes) elif isinstance(n, AssName): return n elif n == None: return None elif isinstance(n, Const): return n elif isinstance(n, Name): if n.name in attr: if n.name in outer: return IfExp(PrimitiveOp('has_attr', [Name(cls), Const(n.name)]), make_get_attr(Name(cls), Const(n.name)), Name(n.name)) else: return make_get_attr(Name(cls), Const(n.name)) else: return n elif isinstance(n, PrimitiveOp): nodes = [replace_attributes(e, attr, cls, outer) for e in n.nodes] return PrimitiveOp(n.name, nodes) elif isinstance(n, CallFunc): node = replace_attributes(n.node, attr, cls, outer) args = [replace_attributes(e, attr, cls, outer) for e in n.args] return CallFunc(node, args) elif isinstance(n, IfExp): new_test = replace_attributes(n.test, attr, cls, outer) new_else = replace_attributes(n.else_, attr, cls, outer) new_then = replace_attributes(n.then, attr, cls, outer) return IfExp(test=new_test, else_=new_else, then=new_then) elif isinstance(n, Let): rhs = replace_attributes(n.rhs, attr, cls, outer) body = replace_attributes(n.body, attr, cls, outer) return Let(n.var, rhs, body) elif isinstance(n, Lambda): return n else: raise Exception('Error in rename attributes: unrecognized AST node ' + repr(ast)) def lower_classes(n, outer): if isinstance(n, Class): attributes = assigned_vars(n.code) newcode = replace_attributes(n.code, attributes, n.name, outer) cls = make_assign(n.name, PrimitiveOp('make_class', [simplify_ops(List(n.bases))])) return Stmt([cls,newcode]) elif isinstance(n, Module): new_outer = assigned_vars(n.node) | outer return Module(doc=n.doc, node=lower_classes(n.node, new_outer)) elif isinstance(n, Stmt): return Stmt([lower_classes(s, outer) for s in n.nodes]) elif isinstance(n, Printnl): return Printnl([lower_classes(e, outer) for e in n.nodes], n.dest) elif isinstance(n, Pass): return n elif isinstance(n, Discard): return Discard(lower_classes(n.expr, outer)) elif 
isinstance(n, Return): return Return(lower_classes(n.value, outer)) elif isinstance(n, If): new_tests = [(lower_classes(cond, outer), \ lower_classes(body, outer)) \ for (cond,body) in n.tests] new_else = lower_classes(n.else_, outer) ret = If(tests=new_tests, else_=new_else) return ret elif isinstance(n, While): new_test = lower_classes(n.test, outer) new_body = lower_classes(n.body, outer) ret = While(test=new_test, body=new_body, else_=None) return ret elif isinstance(n, Assign): new_rhs = lower_classes(n.expr, outer) new_nodes = [lower_classes(n, outer) for n in n.nodes] return Assign(expr=new_rhs, nodes=new_nodes) elif isinstance(n, AssName): return n elif isinstance(n, Subscript): return n elif n == None: return None elif isinstance(n, Const): return n elif isinstance(n, Name): return n elif isinstance(n, PrimitiveOp): nodes = [lower_classes(e, outer) for e in n.nodes] return PrimitiveOp(n.name, nodes) elif isinstance(n, CallFunc): e0 = lower_classes(n.node, outer) e1n = [lower_classes(e, outer) for e in n.args] f = generate_name('__fun') o = generate_name('__obj') if len(e1n) == 0: handle_no_attr = Name(o) else: handle_no_attr = PrimitiveOp('error', [Const("No __init__ method")]) handle_class = Let(o, PrimitiveOp('make_object', [Name(f)]), IfExp(PrimitiveOp('has_attr', \ [Name(f), Const("__init__")]), Let(generate_name('_'), CallFunc(PrimitiveOp('get_closure', [PrimitiveOp('get_attr', \ [Name(f), \ Const("__init__")])]), [Name(o)] + e1n), Name(o)), handle_no_attr)) if 0 < len(e1n): handle_unbound = \ Let(o, e1n[0],\ IfExp(PrimitiveOp('inherits',\ [PrimitiveOp('get_class', [Name(o)]),\ PrimitiveOp('get_class', [Name(f)])]),\ CallFunc(PrimitiveOp('get_closure', [Name(f)]), [Name(o)] + e1n[1:]),\ PrimitiveOp('error', [Const("unbound method error")]))) else: handle_unbound = PrimitiveOp('error', [Const('unbound call requires at least one argument')]) handle_bound = CallFunc(PrimitiveOp('get_closure', [Name(f)]), [PrimitiveOp('get_object', [Name(f)])] + e1n) 
return Let(f, e0, IfExp(PrimitiveOp('is_class', [Name(f)]), handle_class, IfExp(PrimitiveOp('is_unbound_method', [Name(f)]), handle_unbound, IfExp(PrimitiveOp('is_bound_method', [Name(f)]), handle_bound, CallFunc(Name(f), e1n))))) elif isinstance(n, IfExp): new_test = lower_classes(n.test, outer) new_else = lower_classes(n.else_, outer) new_then = lower_classes(n.then, outer) return IfExp(test=new_test, else_=new_else, then=new_then) elif isinstance(n, Let): rhs = lower_classes(n.rhs, outer) body = lower_classes(n.body, outer) return Let(n.var, rhs, body) elif isinstance(n, Lambda): new_outer = assigned_vars(n.code) | set(n.argnames) | outer return Lambda(n.argnames, n.defaults, n.flags, lower_classes(n.code, new_outer)) else: raise Exception('Error in class lowering: unrecognized AST node ' + repr(n)) ############################################################################### # Heapification def free_vars(n): if isinstance(n, Stmt): return reduce(union, map(free_vars, n.nodes), set([])) elif isinstance(n, Printnl): return reduce(union, map(free_vars, n.nodes), set([])) elif isinstance(n, Pass): return set([]) elif isinstance(n, If): def branch((c,b)): return free_vars(c) | free_vars(b) return reduce(union, map(branch, n.tests), set([])) \ | (free_vars(n.else_)) elif isinstance(n, Assign): return reduce(union, map(free_vars, n.nodes)) | free_vars(n.expr) elif isinstance(n, While): return free_vars(n.body) | free_vars(n.else_) elif isinstance(n, Break): return set([]) elif isinstance(n, Continue): return set([]) elif isinstance(n, Discard): return free_vars(n.expr) elif isinstance(n, Return): return free_vars(n.value) elif n == None: return set([]) elif isinstance(n, Const): return set([]) elif isinstance(n, Name): return set([n.name]) elif isinstance(n, AssName): return set([n.name]) elif isinstance(n, PrimitiveOp): return reduce(union, map(free_vars, n.nodes), set([])) elif isinstance(n, IfExp): return free_vars(n.test) | free_vars(n.else_) | free_vars(n.then) 
elif isinstance(n, List): return reduce(union, map(free_vars, n.nodes), set([])) elif isinstance(n, CallFunc): return free_vars(n.node) \ | reduce(union, map(free_vars, n.args), set([])) elif isinstance(n, Lambda): return (free_vars(n.code) - set(n.argnames)) \ - set(assigned_vars(n.code)) elif isinstance(n, Let): return free_vars(n.rhs) | (free_vars(n.body) - set([n.var])) else: print n raise Exception('Error in free_vars: unrecognized AST node') def compute_heapify_vars(n): if isinstance(n, Stmt): return reduce(union, map(compute_heapify_vars, n.nodes), set([])) elif isinstance(n, Printnl): return reduce(union, map(compute_heapify_vars, n.nodes), set([])) elif isinstance(n, Pass): return set([]) elif isinstance(n, If): def branch((c,b)): return compute_heapify_vars(c) | compute_heapify_vars(b) return reduce(union, map(branch, n.tests), set([])) \ | (compute_heapify_vars(n.else_)) elif isinstance(n, Assign): return reduce(union, map(compute_heapify_vars, n.nodes)) \ | compute_heapify_vars(n.expr) elif isinstance(n, While): return compute_heapify_vars(n.body) | compute_heapify_vars(n.else_) elif isinstance(n, Break): return set([]) elif isinstance(n, Continue): return set([]) elif isinstance(n, Discard): return compute_heapify_vars(n.expr) elif isinstance(n, Return): return compute_heapify_vars(n.value) elif n == None: return set([]) elif isinstance(n, Const): return set([]) elif isinstance(n, Name): return set([]) elif isinstance(n, AssName): return set([]) elif isinstance(n, PrimitiveOp): return reduce(union, map(compute_heapify_vars, n.nodes), set([])) elif isinstance(n, IfExp): return compute_heapify_vars(n.test) | compute_heapify_vars(n.else_) \ | compute_heapify_vars(n.then) elif isinstance(n, List): return reduce(union, map(compute_heapify_vars, n.nodes), set([])) elif isinstance(n, CallFunc): return compute_heapify_vars(n.node) \ | reduce(union, map(compute_heapify_vars, n.args), set([])) elif isinstance(n, Lambda): return free_vars(n) elif isinstance(n, 
Let): return (compute_heapify_vars(n.rhs) \ | compute_heapify_vars(n.body)) - set([n.var]) else: print n raise Exception('Error in compute_heapify_vars: unrecognized AST node') def heapify(ast, vars_to_heapify): if isinstance(ast, Module): local_vars = assigned_vars(ast.node) to_heapify = compute_heapify_vars(ast.node) locals_to_heapify = local_vars & to_heapify node = heapify(ast.node, locals_to_heapify) hl_assigns = [make_assign(p, simplify_ops(List([Const(4444)]))) \ for p in locals_to_heapify] newnode = prepend_stmts(hl_assigns, node) return Module(doc=ast.doc, node=newnode) elif isinstance(ast, Stmt): return Stmt([heapify(s, vars_to_heapify) for s in ast.nodes]) elif isinstance(ast, Printnl): return Printnl([heapify(e, vars_to_heapify) for e in ast.nodes], \ ast.dest) elif isinstance(ast, Pass): return ast elif isinstance(ast, Discard): return Discard(heapify(ast.expr, vars_to_heapify)) elif isinstance(ast, Return): return Return(heapify(ast.value, vars_to_heapify)) elif isinstance(ast, If): new_tests = [(heapify(cond, vars_to_heapify), \ heapify(body, vars_to_heapify)) \ for (cond,body) in ast.tests] new_else = heapify(ast.else_, vars_to_heapify) ret = If(tests=new_tests, else_=new_else) return ret elif isinstance(ast, While): new_test = heapify(ast.test, vars_to_heapify) new_body = heapify(ast.body, vars_to_heapify) ret = While(test=new_test, body=new_body, else_=None) return ret elif isinstance(ast, Assign): new_rhs = heapify(ast.expr, vars_to_heapify) new_nodes = [heapify(n, vars_to_heapify) for n in ast.nodes] return Assign(expr=new_rhs, nodes=new_nodes) elif isinstance(ast, AssName): if ast.name in vars_to_heapify: return make_subscript(Name(ast.name), Const(0)) else: return AssName(name=ast.name, flags=ast.flags) elif ast == None: return None elif isinstance(ast, Const): return ast elif isinstance(ast, Name): if ast.name == 'True' or ast.name == 'False': return ast elif ast.name in vars_to_heapify: return make_subscript(ast, Const(0)) else: return ast 
elif isinstance(ast, PrimitiveOp): nodes = [heapify(e, vars_to_heapify) for e in ast.nodes] return PrimitiveOp(ast.name, nodes) elif isinstance(ast, CallFunc): node = heapify(ast.node, vars_to_heapify) args = [heapify(e, vars_to_heapify) for e in ast.args] return CallFunc(node, args) elif isinstance(ast, IfExp): new_test = heapify(ast.test, vars_to_heapify) new_else = heapify(ast.else_, vars_to_heapify) new_then = heapify(ast.then, vars_to_heapify) return IfExp(test=new_test, else_=new_else, then=new_then) elif isinstance(ast, Let): rhs = heapify(ast.rhs, vars_to_heapify) body = heapify(ast.body, vars_to_heapify) return Let(ast.var, rhs, body) elif isinstance(ast, Lambda): params = set(ast.argnames) local_vars = assigned_vars(ast.code) - params to_heapify = compute_heapify_vars(ast.code) locals_to_heapify = local_vars & to_heapify params_to_heapify = params & to_heapify new_vth = (vars_to_heapify - local_vars - params) \ | locals_to_heapify | params_to_heapify code = heapify(ast.code, new_vth) hp_assigns = [make_assign(p, simplify_ops(List([Name(p)]))) for p in params_to_heapify] hl_assigns = [make_assign(p, simplify_ops(List([Const(4444)]))) \ for p in locals_to_heapify] newcode = prepend_stmts(hp_assigns + hl_assigns, code) return Lambda(ast.argnames, ast.defaults, ast.flags, newcode) else: raise Exception('Error in heapify: unrecognized AST node ' + repr(ast)) ############################################################################### # Closure Conversion class Program: def __init__(self, decls): self.decls = decls def __repr__(self): return 'Program(%s)' % repr(self.decls) def convert_closures(ast): if debug: print 'convert closures ', ast if isinstance(ast, Module): (node, funs) = convert_closures(ast.node) return Program(funs + [Module(doc=ast.doc, node=node)]) elif isinstance(ast, Stmt): nodes_funs = [convert_closures(n) for n in ast.nodes] nodes = [n for (n,fs) in nodes_funs] funs = reduce(lambda a,b: a + b, [fs for (n,fs) in nodes_funs], []) return 
(Stmt(nodes), funs) elif isinstance(ast, Printnl): nodes_funs = [convert_closures(n) for n in ast.nodes] nodes = [n for (n,fs) in nodes_funs] funs = reduce(lambda a,b: a + b, [fs for (n,fs) in nodes_funs], []) return (Printnl(nodes, ast.dest), funs) elif isinstance(ast, Pass): return (ast, []) elif isinstance(ast, Discard): (expr,funs) = convert_closures(ast.expr) return (Discard(expr), funs) elif isinstance(ast, Return): (expr,funs) = convert_closures(ast.value) return (Return(expr), funs) elif isinstance(ast, If): new_tests_funs = [(convert_closures(cond), \ convert_closures(body)) \ for (cond,body) in ast.tests] new_tests = [ (cond,body) \ for ((cond,cond_funs), (body, body_funs)) in new_tests_funs] test_funs = reduce(lambda a,b: a + b, [ cond_funs + body_funs \ for ((cond,cond_funs), (body, body_funs)) in new_tests_funs], []) (new_else, els_funs) = convert_closures(ast.else_) ret = If(tests=new_tests, else_=new_else) return (ret, test_funs + els_funs) elif isinstance(ast, While): (new_test, test_funs) = convert_closures(ast.test) (new_body, body_funs) = convert_closures(ast.body) ret = While(test=new_test, body=new_body, else_=None) return (ret, test_funs + body_funs) elif isinstance(ast, Assign): (new_lhs, lhs_funs) = convert_closures(ast.expr) new_nodes_funs = [convert_closures(n) for n in ast.nodes] new_nodes = [n for (n,fs) in new_nodes_funs] rhs_funs = reduce(lambda a,b: a + b, [fs for (n,fs) in new_nodes_funs], []) return (Assign(expr=new_lhs, nodes=new_nodes), lhs_funs + rhs_funs) elif isinstance(ast, AssName): return (AssName(name=ast.name, flags=ast.flags), []) elif ast == None: return (None, []) elif isinstance(ast, Const): return (ast, []) elif isinstance(ast, Name): return (ast, []) elif isinstance(ast, PrimitiveOp): nodes_funs = [convert_closures(e) for e in ast.nodes] nodes = [n for (n,fs) in nodes_funs] funs = reduce(lambda a,b: a + b, [fs for (n,fs) in nodes_funs], []) return (PrimitiveOp(ast.name, nodes), funs) elif isinstance(ast, CallFunc): 
(node,node_funs) = convert_closures(ast.node) args_funs = [convert_closures(e) for e in ast.args] args = [a for (a,fs) in args_funs] more_funs = reduce(lambda a,b: a + b, \ [fs for (a,fs) in args_funs], []) f = PrimitiveOp('get_func', [node]) fvs = PrimitiveOp('get_free', [node]) return (CallFunc(f, [fvs] + args), node_funs + more_funs) elif isinstance(ast, IfExp): (new_test, test_funs) = convert_closures(ast.test) (new_else, else_funs) = convert_closures(ast.else_) (new_then, then_funs) = convert_closures(ast.then) return (IfExp(test=new_test, else_=new_else, then=new_then), test_funs + else_funs + then_funs) elif isinstance(ast, Let): (rhs, rhs_funs) = convert_closures(ast.rhs) (body, body_funs) = convert_closures(ast.body) return (Let(ast.var, rhs, body), rhs_funs + body_funs) elif isinstance(ast, Lambda): frees = free_vars(ast) (code, code_funs) = convert_closures(ast.code) fv_assigns = [make_assign(f, make_subscript(Name('fvs'), Const(i)))\ for (f,i) in zip(frees,range(0, len(frees)))] newcode = prepend_stmts(fv_assigns, code) globalname = generate_name('__lambda') globalfun = Function(name=globalname, \ argnames=['fvs'] + list(ast.argnames), \ code=newcode, decorators='', defaults='', flags=ast.flags, doc='') fvs = simplify_ops(List([Name(f) for f in frees])) ret = PrimitiveOp('make_closure', [Name(globalname), fvs]) funs = code_funs + [globalfun] return (ret, funs) else: raise Exception('Error in convert_closures: unrecognized AST node ' + repr(ast)) ############################################################################### # Convert to static single-assignment form (SSA) def get_high(highest_version, x): if x in highest_version: v = highest_version[x] + 1 highest_version[x] = v return v else: highest_version[x] = 0 return 0 def get_current(current_version, x): if x in current_version: return current_version[x] else: return 0 def convert_to_ssa(ast, current_version={}, highest_version={}): if debug: print 'convert to ssa: ', ast if isinstance(ast, 
Program): return Program([convert_to_ssa(d, current_version) for d in ast.decls]) elif isinstance(ast, Module): return Module(doc=ast.doc, node=convert_to_ssa(ast.node, current_version)) elif isinstance(ast, Function): newscope = assigned_vars(ast.code) | set(ast.argnames) body_cv = copy.deepcopy(current_version) body_hv = copy.deepcopy(highest_version) local_vars = assigned_vars(ast.code) - set(ast.argnames) for x in ast.argnames: body_cv[x] = 0 body_hv[x] = 0 for x in local_vars: if x in body_cv: del body_cv[x] if x in body_hv: del body_hv[x] new_argnames = [x + '_' + str(get_current(body_cv, x)) \ for x in ast.argnames] new_body = convert_to_ssa(ast.code, body_cv, body_hv) name = ast.name + '_' + str(get_current(current_version, ast.name)) return Function(name=name, argnames=new_argnames, code=new_body, decorators=ast.decorators, defaults=ast.defaults, flags=ast.flags, doc=ast.doc) elif isinstance(ast, Stmt): return Stmt([convert_to_ssa(s, current_version, highest_version) for s in ast.nodes]) elif isinstance(ast, Printnl): return Printnl([convert_to_ssa(e, current_version, highest_version) \ for e in ast.nodes], ast.dest) elif isinstance(ast, Pass): return ast elif isinstance(ast, Discard): return Discard(convert_to_ssa(ast.expr, current_version, highest_version)) elif isinstance(ast, Return): return Return(convert_to_ssa(ast.value, current_version, highest_version)) elif isinstance(ast, If): new_tests = [] for (cond,body) in ast.tests: new_cond = convert_to_ssa(cond, current_version, highest_version) body_cv = copy.deepcopy(current_version) new_body = convert_to_ssa(body, body_cv, highest_version) new_tests.append((new_cond, new_body, body_cv)) else_cv = copy.deepcopy(current_version) new_else = convert_to_ssa(ast.else_, else_cv, highest_version) assigned = reduce(union, [assigned_vars(b) for (c,b) in ast.tests], \ set([])) \ | assigned_vars(ast.else_) phis = [] for x in assigned: current_version[x] = get_high(highest_version, x) phi_rhs = [Name(x + '_' + 
str(get_current(cv, x))) \ for (_,_,cv) in new_tests] phi_rhs.append(Name(x + '_' + str(get_current(else_cv, x)))) phi = make_assign(x + '_' + str(get_current(current_version, x)),\ PrimitiveOp('phi', phi_rhs)) phis.append(phi) ret = If(tests=[(c,b) for (c,b,_) in new_tests], else_=new_else) ret.phis = phis return ret elif isinstance(ast, While): pre_cv = copy.deepcopy(current_version) pre = Stmt(nodes=[]) if debug: print 'convert to ssa While ', ast.test assigned = assigned_vars(ast.body) if debug: print 'assigned = ', assigned for x in assigned: current_version[x] = get_high(highest_version, x) body_cv = copy.deepcopy(current_version) new_body = convert_to_ssa(ast.body, body_cv, highest_version) new_test = convert_to_ssa(ast.test, current_version, highest_version) phis = [] for x in assigned: body_var = Name(x + '_' + str(get_current(body_cv, x))) pre_var = Name(x + '_' + str(get_current(pre_cv, x))) phi = make_assign(x + '_' + str(get_current(current_version, x)),\ PrimitiveOp('phi', [pre_var,body_var])) phis.append(phi) ret = While(test=new_test, body=new_body, else_=None) ret.phis = phis return ret elif isinstance(ast, Assign): new_rhs = convert_to_ssa(ast.expr, current_version, highest_version) new_nodes = [] for n in ast.nodes: if isinstance(n, AssName): x = n.name x_v = get_high(highest_version, x) current_version[x] = x_v new_nodes.append(AssName(name=x + '_' + str(x_v), flags=n.flags)) else: new_nodes.append(convert_to_ssa(n, current_version, highest_version)) return Assign(expr=new_rhs, nodes=new_nodes) elif ast == None: return None elif isinstance(ast, Const): return ast elif isinstance(ast, Name): if ast.name == 'True' or ast.name == 'False': return ast else: return Name(ast.name + '_' + str(get_current(current_version, ast.name))) elif isinstance(ast, PrimitiveOp): nodes = [convert_to_ssa(e, current_version, highest_version) for e in ast.nodes] return PrimitiveOp(ast.name, nodes) elif isinstance(ast, CallFunc): node = convert_to_ssa(ast.node, 
current_version, highest_version) args = [convert_to_ssa(e, current_version, highest_version) for e in ast.args] return CallFunc(node, args) elif isinstance(ast, IfExp): new_test = convert_to_ssa(ast.test, current_version, highest_version) new_else = convert_to_ssa(ast.else_, current_version, highest_version) new_then = convert_to_ssa(ast.then, current_version, highest_version) return IfExp(test=new_test, else_=new_else, then=new_then) elif isinstance(ast, Let): rhs = convert_to_ssa(ast.rhs, current_version, highest_version) v = get_high(highest_version, ast.var) current_version[ast.var] = v body = convert_to_ssa(ast.body, current_version, highest_version) return Let(ast.var + '_' + str(v), rhs, body) else: raise Exception('Error in convert_to_ssa: unrecognized AST node ' + repr(ast)) ############################################################################### # Insert variable declarations class VarDecl(Node): def __init__(self, name, type, lineno=None): self.name = name self.type = type self.lineno = lineno def getChildren(self): return self.name, self.type def getChildNodes(self): return () def __repr__(self): return "VarDecl(%s, %s)" % (self.name, self.type) def insert_var_decls(n): if isinstance(n, Module): decls = [VarDecl(x,'undefined') for x in assigned_vars(n.node)] return Module(n.doc, prepend_stmts(decls, n.node)) elif isinstance(n, Program): return Program([insert_var_decls(d) for d in n.decls]) elif isinstance(n, Function): locals = assigned_vars(n.code) - set(n.argnames) decls = [VarDecl(x,'undefined') for x in locals] return Function(name=n.name, argnames=n.argnames, code=prepend_stmts(decls, n.code), doc=n.doc, decorators=n.decorators, defaults=n.defaults, flags=n.flags) else: raise Exception('Error in insert_var_decls: unhandled AST ' + repr(n)) ############################################################################### # Type analysis, this pass annotates the IR in-place with types in the 'type' # attribute def join(t1,t2): if t1 == 
'pyobj': return 'pyobj' elif t2 == 'pyobj': return 'pyobj' elif t1 == 'undefined': return t2 elif t2 == 'undefined': return t1 elif t1 == t2: return t1 else: return 'pyobj' arith_returns = { 'int' : 'int', 'float' : 'float', 'pyobj' : 'pyobj', 'undefined' : 'undefined' } arith_op = { ('int', 'int') : 'int' , ('int', 'float') : 'float' , ('int', 'bool') : 'int' , ('int', 'pyobj') : 'pyobj' , ('int', 'undefined') : 'undefined' , ('float', 'int') : 'float' , ('float', 'float') : 'float' , ('float', 'bool') : 'float' , ('float', 'pyobj') : 'pyobj' , ('float', 'undefined') : 'undefined' , ('bool', 'int') : 'int' , ('bool', 'float') : 'float' , ('bool', 'bool') : 'int', ('bool', 'pyobj') : 'pyobj', ('bool', 'undefined') : 'undefined', ('pyobj', 'int') : 'pyobj' , ('pyobj', 'float') : 'pyobj' , ('pyobj', 'bool') : 'pyobj' , ('pyobj', 'pyobj') : 'pyobj', ('pyobj', 'undefined') : 'undefined', ('undefined', 'int') : 'undefined' , ('undefined', 'float') : 'undefined' , ('undefined', 'bool') : 'undefined' , ('undefined', 'pyobj') : 'undefined' , ('undefined', 'undefined') : 'undefined', ('pyobj',) : 'pyobj', ('undefined',) : 'undefined', ('int',) : 'int', ('float',) : 'float', ('bool',) : 'int' } power_returns = { 'pyobj' : 'pyobj', 'int' : 'pyobj', 'float' : 'float', 'bool' : 'pyobj', 'undefined' : 'undefined' } # power is funny because whether it returns a float depends on the *values* of the input. # for example 3**-1 returns a float. 
# Specialization tag for '**', keyed by the pair of operand type tags.
# Unlike arith_op, bool ** bool keeps the 'bool' tag.
power_op = {
    ('int', 'int'): 'int',
    ('int', 'float'): 'float',
    ('int', 'bool'): 'int',
    ('int', 'pyobj'): 'pyobj',
    ('int', 'undefined'): 'undefined',
    ('float', 'int'): 'float',
    ('float', 'float'): 'float',
    ('float', 'bool'): 'float',
    ('float', 'pyobj'): 'pyobj',
    ('float', 'undefined'): 'undefined',
    ('bool', 'int'): 'int',
    ('bool', 'float'): 'float',
    ('bool', 'bool'): 'bool',
    ('bool', 'pyobj'): 'pyobj',
    ('bool', 'undefined'): 'undefined',
    ('pyobj', 'int'): 'pyobj',
    ('pyobj', 'float'): 'pyobj',
    ('pyobj', 'bool'): 'pyobj',
    ('pyobj', 'pyobj'): 'pyobj',
    ('pyobj', 'undefined'): 'undefined',
    ('undefined', 'int'): 'undefined',
    ('undefined', 'float'): 'undefined',
    ('undefined', 'bool'): 'undefined',
    ('undefined', 'pyobj'): 'undefined',
    ('undefined', 'undefined'): 'undefined',
}

# Comparison and logic operations always produce a bool, whatever their tag.
bool_returns = {'int': 'bool', 'float': 'bool', 'pyobj': 'bool',
                'bool': 'bool', 'undefined': 'bool'}

# Specialization tag for comparison/logic operations, keyed by operand type
# tags (two-element keys: binary, one-element keys: unary).  Unlike arith_op,
# bool operands keep the 'bool' tag.
bool_op = {
    ('int', 'int'): 'int',
    ('int', 'float'): 'float',
    ('int', 'bool'): 'int',
    ('int', 'pyobj'): 'pyobj',
    ('int', 'undefined'): 'undefined',
    ('float', 'int'): 'float',
    ('float', 'float'): 'float',
    ('float', 'bool'): 'float',
    ('float', 'pyobj'): 'pyobj',
    ('float', 'undefined'): 'undefined',
    ('bool', 'int'): 'int',
    ('bool', 'float'): 'float',
    ('bool', 'bool'): 'bool',
    ('bool', 'pyobj'): 'pyobj',
    ('bool', 'undefined'): 'undefined',
    ('pyobj', 'int'): 'pyobj',
    ('pyobj', 'float'): 'pyobj',
    ('pyobj', 'bool'): 'pyobj',
    ('pyobj', 'undefined'): 'undefined',
    ('pyobj', 'pyobj'): 'pyobj',
    ('undefined', 'int'): 'undefined',
    ('undefined', 'float'): 'undefined',
    ('undefined', 'bool'): 'undefined',
    ('undefined', 'pyobj'): 'undefined',
    ('undefined', 'undefined'): 'undefined',
    ('undefined',): 'undefined',
    ('pyobj',): 'pyobj',
    ('int',): 'int',
    ('float',): 'float',
    ('bool',): 'bool',
}

arithop = (lambda ts: arith_op[ts])
boolop = (lambda ts: bool_op[ts])


def idop(ts):
    """Tag for 'identical': the shared operand tag, or 'pyobj' on mismatch."""
    if ts[0] == ts[1]:
        return ts[0]
    else:
        return 'pyobj'

# Primitive name -> function(tuple of operand type tags) -> specialization tag.
# An empty tag '' means the primitive name is used unsuffixed; otherwise
# type_specialize appends '_' + tag to the name.
find_op_tag = {
    'add': arithop,
    'sub': arithop,
    'mul': arithop,
    'floordiv': arithop,
    'floor': (lambda ts: ''),
    'float_to_int': (lambda ts: ''),
    'mod': arithop,
    'power': (lambda ts: power_op[ts]),
    'unary_add': arithop,
    'unary_sub': arithop,
    'logic_not': boolop,
    'logic_and': boolop,
    'logic_or': boolop,
    'equal': boolop,
    'not_equal': boolop,
    'less': boolop,
    'less_equal': boolop,
    'greater': boolop,
    'greater_equal': boolop,
    'identical': idop,
    'input': (lambda ts: ''),
    'deref': (lambda ts: ''),
    'subscript': (lambda ts: 'pyobj'),
    'set_subscript': (lambda ts: 'pyobj'),
    'make_list': (lambda ts: ''),
    'make_dict': (lambda ts: ''),
    'assign': (lambda ts: reduce(join, ts, 'undefined')),
    'phi': (lambda ts: reduce(join, ts, 'undefined')),
    'make_closure': (lambda ts: 'pyobj'),
    'get_func': (lambda ts: 'pyobj'),
    'get_closure': (lambda ts: 'pyobj'),
    'get_free': (lambda ts: 'pyobj'),
    'set_free': (lambda ts: 'pyobj'),
    'make_class': (lambda ts: 'pyobj'),
    'make_object': (lambda ts: 'pyobj'),
    'get_class': (lambda ts: 'pyobj'),
    'get_object': (lambda ts: 'pyobj'),
    'is_class': (lambda ts: 'pyobj'),
    'is_unbound_method': (lambda ts: 'pyobj'),
    'is_bound_method': (lambda ts: 'pyobj'),
    'has_attr': (lambda ts: 'pyobj'),
    'get_attr': (lambda ts: 'pyobj'),
    'set_attr': (lambda ts: 'pyobj'),
    'inherits': (lambda ts: 'pyobj'),
    'error': (lambda ts: 'pyobj'),
}

arith_ret = (lambda ts: arith_returns[arith_op[ts]])
bool_ret = (lambda ts: bool_returns[bool_op[ts]])

# Primitive name -> function(tuple of operand type tags) -> result type tag
# (used by predict_type to set PrimitiveOp.type).
op_returns = {
    'add': arith_ret,
    'sub': arith_ret,
    'mul': arith_ret,
    'floordiv': arith_ret,
    'floor': (lambda ts: 'float'),
    'float_to_int': (lambda ts: 'int'),
    'mod': arith_ret,
    'power': (lambda ts: power_returns[power_op[ts]]),
    'unary_add': arith_ret,
    'unary_sub': arith_ret,
    'logic_not': bool_ret,
    'logic_and': bool_ret,
    'logic_or': bool_ret,
    'equal': bool_ret,
    'not_equal': bool_ret,
    'less': bool_ret,
    'less_equal': bool_ret,
    'greater': bool_ret,
    'greater_equal': bool_ret,
    'identical': bool_ret,
    'input': (lambda ts: 'pyobj'),
    'deref': (lambda ts: 'pyobj'),
    'subscript': (lambda ts: 'pyobj*'),
    'set_subscript': (lambda ts: 'pyobj'),
    'make_list': (lambda ts: 'pyobj'),
    'make_dict': (lambda ts: 'pyobj'),
    'assign': (lambda ts: reduce(join, ts, 'undefined')),
    'phi': (lambda ts: reduce(join, ts, 'undefined')),
    'make_closure': (lambda ts: 'pyobj'),
    'get_func': (lambda ts: 'void*'),
    'get_closure': (lambda ts: 'pyobj'),
    'get_free': (lambda ts: 'pyobj'),
    'set_free': (lambda ts: 'pyobj'),
    'make_class': (lambda ts: 'pyobj'),
    'make_object': (lambda ts: 'pyobj'),
    'get_class': (lambda ts: 'pyobj'),
    'get_object': (lambda ts: 'pyobj'),
    'is_class': (lambda ts: 'bool'),
    'is_unbound_method': (lambda ts: 'bool'),
    'is_bound_method': (lambda ts: 'bool'),
    'has_attr': (lambda ts: 'bool'),
    'get_attr': (lambda ts: 'pyobj'),
    'set_attr': (lambda ts: 'pyobj'),
    'inherits': (lambda ts: 'bool'),
    'error': (lambda ts: 'pyobj'),
}

# Specialized primitive name -> parameter type tags.  type_specialize zips
# these with the operands and inserts convert_to(e, t) coercions.
# Every entry must be one of 'int', 'float', 'bool', 'pyobj'.
op_params = {
    'add_int': ('int', 'int'),
    'add_float': ('float', 'float'),
    'add_bool': ('bool', 'bool'),
    'add_pyobj': ('pyobj', 'pyobj'),
    'sub_int': ('int', 'int'),
    'sub_float': ('float', 'float'),
    'sub_bool': ('bool', 'bool'),
    'sub_pyobj': ('pyobj', 'pyobj'),
    'mul_int': ('int', 'int'),
    'mul_float': ('float', 'float'),
    'mul_bool': ('bool', 'bool'),
    'mul_pyobj': ('pyobj', 'pyobj'),
    # NOTE(review): two params listed for 'floor'; zip() in type_specialize
    # truncates to the actual operand count -- confirm intended arity.
    'floor': ('float', 'float'),
    'float_to_int': ('float',),
    'floordiv_int': ('int', 'int'),
    'floordiv_bool': ('bool', 'bool'),
    'floordiv_float': ('float', 'float'),
    'floordiv_pyobj': ('pyobj', 'pyobj'),
    'mod_int': ('int', 'int'),
    'mod_bool': ('int', 'int'),
    'mod_float': ('float', 'float'),
    'mod_pyobj': ('pyobj', 'pyobj'),      # was ('pyobj', 'pyboj'): typo
    'power_int': ('int', 'int'),
    'power_bool': ('bool', 'bool'),
    'power_float': ('float', 'float'),
    'power_pyobj': ('pyobj', 'pyobj'),    # was ('pyobj', 'pyboj'): typo
    'unary_add_int': ('int',),
    'unary_add_bool': ('bool',),
    'unary_add_float': ('float',),
    'unary_add_pyobj': ('pyobj',),
    'unary_sub_int': ('int',),
    'unary_sub_bool': ('bool',),
    'unary_sub_float': ('float',),
    'unary_sub_pyobj': ('pyobj',),
    'logic_not_int': ('int',),
    'logic_not_bool': ('bool',),
    'logic_not_float': ('float',),
    'logic_not_pyobj': ('pyobj',),
    'logic_and_int': ('int', 'int'),
    'logic_and_bool': ('bool', 'bool'),
    'logic_and_pyobj': ('pyobj', 'pyobj'),
    'logic_or_int': ('int', 'int'),
    'logic_or_bool': ('bool', 'bool'),
    'logic_or_pyobj': ('pyobj', 'pyobj'),
    'equal_int': ('int', 'int'),
    'equal_float': ('float', 'float'),
    'equal_bool': ('bool', 'bool'),
    'equal_pyobj': ('pyobj', 'pyobj'),
    'not_equal_int': ('int', 'int'),
    'not_equal_float': ('float', 'float'),
    'not_equal_bool': ('bool', 'bool'),
    'not_equal_pyobj': ('pyobj', 'pyobj'),
    'less_int': ('int', 'int'),
    'less_float': ('float', 'float'),
    'less_bool': ('bool', 'bool'),
    'less_pyobj': ('pyobj', 'pyobj'),
    'less_equal_int': ('int', 'int'),
    'less_equal_float': ('float', 'float'),
    'less_equal_bool': ('bool', 'bool'),
    'less_equal_pyobj': ('pyobj', 'pyobj'),
    'greater_int': ('int', 'int'),
    'greater_float': ('float', 'float'),
    'greater_bool': ('bool', 'bool'),
    'greater_pyobj': ('pyobj', 'pyobj'),
    'greater_equal_int': ('int', 'int'),
    'greater_equal_float': ('float', 'float'),
    'greater_equal_bool': ('bool', 'bool'),
    'greater_equal_pyobj': ('pyobj', 'pyobj'),
    'identical_int': ('int', 'int'),
    'identical_float': ('float', 'float'),
    'identical_bool': ('bool', 'bool'),
    'identical_pyobj': ('pyobj', 'pyobj'),
}

# Global flag set by predict_type when a variable's predicted type changes.
type_changed = False


class LinkedEnv:
    """A scope in a chain of type environments.

    `local` maps variable names to type tags for this scope; `parent` is the
    enclosing LinkedEnv, or False at the root.
    """

    def __init__(self, parent=False):
        self.local = {}
        self.parent = parent

    def __repr__(self):
        return "LinkedEnv(%s, %s)" % (repr(self.local), repr(self.parent))


def create_frame(env):
    """Return a new, empty scope whose parent is `env`."""
    return LinkedEnv(env)


def get_var_type(env, x):
    """Type tag of `x`, searching outward through parents; 'undefined' if absent."""
    if x in env.local:
        return env.local[x]
    elif env.parent:
        return get_var_type(env.parent, x)
    else:
        return 'undefined'


def is_defined(env, x):
    """True if `x` is bound in `env` or any enclosing scope."""
    if x in env.local:
        return True
    elif env.parent:
        return is_defined(env.parent, x)
    else:
        return False

# Don't need to do a join in update_var_type because the program
# is in SSA form.
There is only one assignment to each variable. def update_var_type(env, x, t): if x in env.local: if t == env.local[x]: return False else: env.local[x] = t return True else: env.local[x] = t return True def predict_type(n, env): global type_changed if debug: print 'predict_type ', n print env if isinstance(n, Program): type_changed = True i = 0 while i < 10: i = i + 1 type_changed = False for d in n.decls: predict_type(d, env) elif isinstance(n, Module): predict_type(n.node, env) elif isinstance(n, Function): new_env = create_frame(env) for x in set(n.argnames): update_var_type(new_env, x, 'pyobj') i = 0 while i < 10: i = i + 1 predict_type(n.code, new_env) elif isinstance(n, Stmt): for s in n.nodes: predict_type(s, env) elif isinstance(n, Printnl): for e in n.nodes: predict_type(e, env) elif isinstance(n, Discard): predict_type(n.expr, env) elif isinstance(n, Return): predict_type(n.value, env) elif isinstance(n, If): for (cond,body) in n.tests: predict_type(cond,env) predict_type(body,env) predict_type(n.else_,env) for s in n.phis: predict_type(s,env) elif n == None: pass elif isinstance(n, While): predict_type(n.test,env) predict_type(n.body,env) for s in n.phis: predict_type(s,env) elif isinstance(n, Pass): pass elif isinstance(n, Assign): predict_type(n.expr, env) for a in n.nodes: if isinstance(a, AssName): type_changed += update_var_type(env, a.name, n.expr.type) a.type = get_var_type(env, a.name) else: predict_type(a, env) elif isinstance(n, VarDecl): n.type = get_var_type(env, n.name) elif isinstance(n, Const): if isinstance(n.value, float): n.type = 'float' elif isinstance(n.value, int): n.type = 'int' elif isinstance(n.value, str): n.type = 'pyobj' else: raise Exception('Error in predict_type: unhandled constant ' + repr(n)) elif isinstance(n, Name): if n.name == 'True' or n.name == 'False': n.type = 'bool' else: n.type = get_var_type(env, n.name) if debug: print 'Name: ', n.name, n.type elif isinstance(n, PrimitiveOp): for e in n.nodes: predict_type(e, 
env) n.type = op_returns[n.name](tuple([e.type for e in n.nodes])) elif isinstance(n, CallFunc): predict_type(n.node, env) for e in n.args: predict_type(e, env) n.type = 'pyobj' elif isinstance(n, IfExp): predict_type(n.test, env) predict_type(n.then, env) predict_type(n.else_, env) if n.then.type == n.else_.type: n.type = n.then.type else: n.type = 'pyobj' elif isinstance(n, Let): predict_type(n.rhs, env) if debug: print 'let type: ', n.var, n.rhs.type body_env = create_frame(env) update_var_type(body_env, n.var, n.rhs.type) predict_type(n.body, body_env) n.type = n.body.type else: raise Exception('Error in predict_type: unrecognized AST node ' + repr(n)) ############################################################################### # Type specialization # select specialized primitive operations # insert calls to is_true and make_* where appropriate def convert_to_pyobj(e): if e.type != 'pyobj' and e.type != 'undefined': new_e = PrimitiveOp(e.type + '_to_pyobj', [e]) new_e.type = 'pyobj' return new_e else: return e def convert_to(e, t): if e.type != t: new_e = PrimitiveOp(e.type + '_to_' + t, [e]) new_e.type = t return new_e else: return e def test_is_true(e): if e.type == 'pyobj': ret = PrimitiveOp('pyobj_to_bool', [e]) ret.type = 'bool' return ret else: return e def type_specialize(n): if debug: print 'type specialize ' + repr(n) if isinstance(n, Program): return Program([type_specialize(d) for d in n.decls]) elif isinstance(n, Module): return Module(n.doc, type_specialize(n.node)) elif isinstance(n, Function): return Function(name=n.name, argnames=n.argnames, code=type_specialize(n.code), decorators=n.decorators, doc=n.doc, defaults=n.defaults, flags=n.flags) elif isinstance(n, Stmt): return Stmt([type_specialize(s) for s in n.nodes]) elif isinstance(n, Printnl): # would be nice to specialize print, but not a high priority return Printnl([convert_to_pyobj(type_specialize(e)) for e in n.nodes], n.dest) elif isinstance(n, Discard): return 
Discard(type_specialize(n.expr)) elif isinstance(n, Return): return Return(convert_to_pyobj(type_specialize(n.value))) elif isinstance(n, If): tests = [(test_is_true(type_specialize(cond)), type_specialize(body)) \ for (cond,body) in n.tests] else_ = type_specialize(n.else_) phis = [type_specialize(s) for s in n.phis] ret = If(tests,else_) ret.phis = phis return ret elif n == None: return None elif isinstance(n, While): test = test_is_true(type_specialize(n.test)) body = type_specialize(n.body) phis = [type_specialize(s) for s in n.phis] ret = While(test, body, None) ret.phis = phis return ret elif isinstance(n, Pass): return n elif isinstance(n, Assign): expr = type_specialize(n.expr) nodes = [type_specialize(a) for a in n.nodes] if any([a.type == 'pyobj' for a in nodes]): expr = convert_to_pyobj(expr) return Assign(nodes, expr) elif isinstance(n, AssName): return n elif isinstance(n, VarDecl): return n elif isinstance(n, Const): return n elif isinstance(n, Name): return n elif isinstance(n, PrimitiveOp): nodes = [type_specialize(e) for e in n.nodes] tag = find_op_tag[n.name](tuple([e.type for e in n.nodes])) name = n.name if tag == '' else n.name + '_' + tag if tag == 'pyobj': nodes = [convert_to_pyobj(e) for e in nodes] else: if name in op_params: nodes = [convert_to(e,t) for (e,t) in zip(nodes, list(op_params[name]))] else: nodes = nodes r = PrimitiveOp(name, nodes) r.type = n.type return r elif isinstance(n, CallFunc): args = [type_specialize(e) for e in n.args] args = [convert_to_pyobj(e) for e in args] node = type_specialize(n.node) r = CallFunc(node, args) r.type = 'pyobj' return r elif isinstance(n, IfExp): test = type_specialize(n.test) then = type_specialize(n.then) else_ = type_specialize(n.else_) test = test_is_true(test) if any([e.type == 'pyobj' for e in [n,then,else_]]): then = convert_to_pyobj(then) else_ = convert_to_pyobj(else_) r = IfExp(test, then, else_) r.type = n.type return r elif isinstance(n, Let): rhs = type_specialize(n.rhs) body = 
type_specialize(n.body) r = Let(n.var, rhs, body) r.type = n.type return r else: raise Exception('Error in type_specialize: unrecognized AST node ' + repr(n)) ############################################################################### # Remove SSA def split_phis(phis): branch_dict = {} for phi in phis: lhs = phi.nodes[0].name i = 0 for rhs in phi.expr.nodes: if i in branch_dict: branch_dict[i].append(make_assign(lhs, rhs)) else: branch_dict[i] = [make_assign(lhs, rhs)] i = i + 1 return branch_dict def remove_ssa(n): if isinstance(n, Program): return Program([remove_ssa(d) for d in n.decls]) elif isinstance(n, Module): return Module(n.doc, remove_ssa(n.node)) elif isinstance(n, Function): return Function(name=n.name, argnames=n.argnames, code=remove_ssa(n.code), doc=n.doc, decorators=n.decorators, defaults=n.defaults, flags=n.flags) elif isinstance(n, Stmt): return Stmt([remove_ssa(s) for s in n.nodes]) elif isinstance(n, Printnl): return n elif isinstance(n, Discard): return n elif isinstance(n, Return): return n elif isinstance(n, If): tests = [(cond, remove_ssa(body)) for (cond,body) in n.tests] else_ = remove_ssa(n.else_) phis = [remove_ssa(s) for s in n.phis] branch_dict = split_phis(phis) if debug: print 'remove ssa If ' print 'branch dict: ', branch_dict b = 0 new_tests = [] for (cond,body) in tests: if 0 < len(branch_dict): new_body = append_stmts(body, Stmt(branch_dict[b])) else: new_body = body new_tests.append((cond,new_body)) b = b + 1 if 0 < len(branch_dict): new_else = append_stmts(else_, Stmt(branch_dict[b])) else: new_else = else_ ret = If(new_tests, new_else) return ret elif n == None: return None elif isinstance(n, While): test = n.test body = remove_ssa(n.body) phis = [remove_ssa(s) for s in n.phis] branch_dict = split_phis(phis) if debug: print 'remove ssa While ', phis, branch_dict if 0 < len(branch_dict): ret = Stmt(branch_dict[0] + [While(test, append_stmts(body, Stmt(branch_dict[1])), None)]) else: ret = While(test, body, None) return ret 
elif isinstance(n, Pass): return n elif isinstance(n, Assign): return n elif isinstance(n, VarDecl): return n else: raise Exception('Error in remove_ssa: unrecognized AST node ' + repr(n)) ############################################################################### # Remove Complex Opera* def remove_complex(n, need_simple=True): if debug: print 'remove complex: ', n if isinstance(n, Program): return Program([remove_complex(d) for d in n.decls]) elif isinstance(n, Module): (ss, ds) = remove_complex(n.node) return Module(n.doc, Stmt(ds + ss)) elif isinstance(n, Function): (ss, ds) = remove_complex(n.code) return Function(name=n.name, argnames=n.argnames, code=Stmt(ds + ss), doc=n.doc, decorators=n.decorators, defaults=n.defaults, flags=n.flags) elif isinstance(n, Stmt): if debug: print 'RC stmt nodes: ', n.nodes sss_dss = [remove_complex(s) for s in n.nodes] ss = reduce(lambda a,b: a + b, [s for (s,d) in sss_dss], []) ds = reduce(lambda a,b: a + b, [d for (s,d) in sss_dss], []) if debug: print 'RC stmt ss: ', ss return (ss, ds) elif isinstance(n, Printnl): es_sss_dss = [remove_complex(e) for e in n.nodes] nodes = [e for (e,ss,ds) in es_sss_dss] ss = reduce(lambda a,b: a + b, [s for (e,s,d) in es_sss_dss], []) ds = reduce(lambda a,b: a + b, [d for (e,s,d) in es_sss_dss], []) return (ss + [Printnl(nodes, n.dest)], ds) elif isinstance(n, Discard): (e, ss, ds) = remove_complex(n.expr, False) return (ss + [Discard(e)], ds) elif isinstance(n, Return): (e, ss, ds) = remove_complex(n.value, True) return (ss + [Return(e)], ds) elif isinstance(n, If): tests_ss_ds = [(remove_complex(cond), remove_complex(body)) \ for (cond,body) in n.tests] new_tests = [ (c, Stmt(ss2)) \ for ((c,ss1,ds1), (ss2,ds2)) in tests_ss_ds] ss = reduce(lambda a,b: a + b, [s for ((c,s,d), _) in tests_ss_ds], []) ds0 = reduce(lambda a,b: a + b, [ d1 + d2 for ((c,s1,d1), (s2,d2)) in tests_ss_ds], []) (sse, dse) = remove_complex(n.else_) return (ss + [If(new_tests, Stmt(sse))], ds0 + dse) elif n == None: 
return ([], []) elif isinstance(n, While): (test, sst, dst) = remove_complex(n.test) (ssb, dsb) = remove_complex(n.body) return (sst + [While(test, Stmt(ssb + sst), None)], dst + dsb) elif isinstance(n, Pass): return ([n], []) elif isinstance(n, Assign): (e,ss,ds) = remove_complex(n.expr, False) return (ss + [Assign(n.nodes, e)], ds) elif isinstance(n, VarDecl): return ([], [n]) ## expressions elif isinstance(n, Const): return (n, [], []) elif isinstance(n, Name): return (n, [], []) elif isinstance(n, PrimitiveOp): es_sss_dss = [remove_complex(e) for e in n.nodes] nodes = [e for (e,ss,ds) in es_sss_dss] ss = reduce(lambda a,b: a + b, [s for (e,s,d) in es_sss_dss], []) ds = reduce(lambda a,b: a + b, [d for (e,s,d) in es_sss_dss], []) if need_simple: tmp = generate_name('__tmp_rc') assign = make_assign_t(tmp, PrimitiveOp(n.name, nodes), n.type) result = Name(tmp) result.type = n.type return (result, ss + [assign], ds + [VarDecl(tmp, n.type)]) else: return (PrimitiveOp(n.name, nodes), ss, ds) elif isinstance(n, CallFunc): (node,ss1,ds1) = remove_complex(n.node) es_sss_dss = [remove_complex(e) for e in n.args] args = [e for (e,ss,ds) in es_sss_dss] ss2 = reduce(lambda a,b: a + b, [s for (e,s,d) in es_sss_dss], []) ds2 = reduce(lambda a,b: a + b, [d for (e,s,d) in es_sss_dss], []) if need_simple: tmp = generate_name('__tmp_rc') assign = make_assign_t(tmp, CallFunc(node, args), n.type) result = Name(tmp) result.type = n.type return (result, ss1 + ss2 + [assign], ds1 + ds2 + [VarDecl(tmp, n.type)]) else: return (CallFunc(node, args), ss1 + ss2, ds1 + ds2) elif isinstance(n, IfExp): (test, ss_test, ds_test) = remove_complex(n.test) (then, ss_then, ds_then) = remove_complex(n.then) (else_, ss_else, ds_else) = remove_complex(n.else_) tmp = generate_name('__tmp') if_stmt = If([(test, Stmt(ss_then + [make_assign_t(tmp, then, n.then.type)]))], Stmt(ss_else + [make_assign_t(tmp, else_, n.else_.type)])) result = Name(tmp) result.type = n.type return (result, ss_test + [if_stmt], 
ds_test + ds_then + ds_else + [VarDecl(tmp, then.type)]) elif isinstance(n, Let): (rhs, ss_rhs, ds_rhs) = remove_complex(n.rhs, False) (body, ss_body, ds_body) = remove_complex(n.body) ss = ss_rhs + [make_assign_t(n.var, rhs, n.rhs.type)] + ss_body return (body, ss, ds_rhs + ds_body + [VarDecl(n.var, n.rhs.type)]) else: raise Exception('Error in remove_complex: unrecognized AST node ' + repr(n)) ############################################################################### # Generate C output python_type_to_c = { 'int' : 'int', 'bool' : 'char', 'float' : 'double', 'pyobj' : 'pyobj', 'pyobj*' : 'pyobj*', 'void*' : 'void*', 'undefined' : 'undefined' } def generate_c(n): if debug: print 'generate_c', n if isinstance(n, Program): return ''' #include #include #include #include #include "hashtable.h" #include "hashtable_itr.h" #define SIGN_OF(a) (((a) < 0) ? -1 : 1) #define pyobj_to_bool(v) (!is_false(v)) #define logic_and(A, B) bool_to_pyobj(pyobj_to_bool(A) && pyobj_to_bool(B)) #define logic_or(A, B) bool_to_pyobj(pyobj_to_bool(A) || pyobj_to_bool(B)) #define logic_and_int(A, B) (A && B) #define logic_and_bool(A, B) (A && B) #define logic_or_int(A, B) (A || B) #define logic_or_bool(A, B) (A || B) #define add_int(a, b) (a + b) #define add_bool(a, b) (a + b) #define add_float(a, b) (a + b) #define sub_int(a, b) (a - b) #define sub_bool(a, b) (a - b) #define sub_float(a, b) (a - b) #define mul_int(a, b) (a * b) #define mul_bool(a, b) (a * b) #define mul_float(a, b) (a * b) #define floordiv_int(a, b) ((int)floor((double)a/ (double)b)) #define floordiv_bool(a, b) ((int)floor((double)a/ (double)b)) #define floordiv_float(a, b) (a/b) #define unary_add_int(a) (+a) #define unary_add_bool(a) (+a) #define unary_add_float(a) (+a) #define unary_sub_int(a) (-a) #define unary_sub_bool(a) (-a) #define unary_sub_float(a) (-a) #define mod_bool mod_int #define logic_not_int(a) (!a) #define logic_not_bool(a) (!a) #define logic_not_float(a) (!a) #define less_int(a,b) (a < b) #define 
less_bool(a,b) (a < b) #define less_float(a,b) (a < b) #define greater_int(a,b) (a > b) #define greater_bool(a,b) (a > b) #define greater_float(a,b) (a > b) #define less_equal_int(a,b) (a <= b) #define less_equal_bool(a,b) (a <= b) #define less_equal_float(a,b) (a <= b) #define greater_equal_int(a,b) (a >= b) #define greater_equal_bool(a,b) (a >= b) #define greater_equal_float(a,b) (a >= b) #define equal_int(a,b) (a == b) #define equal_bool(a,b) (a == b) #define equal_float(a,b) (a == b) #define identical_int(a,b) (a == b) #define identical_bool(a,b) (a == b) #define identical_float(a,b) (a == b) #define not_equal_int(a,b) (a != b) #define not_equal_bool(a,b) (a != b) #define not_equal_float(a,b) (a != b) #define int_to_float(a) ((double)a) #define float_to_int(a) ((int)a) #define bool_to_float(a) ((double)a) #define bool_to_int(a) ((int)a) enum type_tag { INT, FLOAT, BOOL, LIST, DICT, CLOSURE, CLASS, OBJECT, UBMETHOD, BMETHOD }; struct pyobj_struct; struct array_struct { struct pyobj_struct* data; unsigned int len; }; typedef struct array_struct array; struct closure_struct { void* function_ptr; struct pyobj_struct* free_vars; int num_params; }; typedef struct closure_struct closure; struct class_struct { struct hashtable *attrs; int nparents; struct class_struct *parents; }; typedef struct class_struct class; struct object_struct { struct hashtable *attrs; class cl; }; typedef struct object_struct object; struct unbound_method_struct { closure fun; class cl; }; typedef struct unbound_method_struct unbound_method; struct bound_method_struct { closure fun; object receiver; }; typedef struct bound_method_struct bound_method; struct pyobj_struct { enum type_tag tag; union { int i; /* int */ double f; /* float */ char b; /* bool */ array l; /* list */ struct hashtable* d; /* dictionary */ closure c; /* functions */ class cl; /* class objects */ object obj; /* instances */ unbound_method ubm; bound_method bm; } u; }; typedef struct pyobj_struct pyobj; char 
less_pyobj(pyobj a, pyobj b); char less_equal_pyobj(pyobj a, pyobj b); char greater_pyobj(pyobj a, pyobj b); char greater_equal_pyobj(pyobj a, pyobj b); char equal_pyobj(pyobj a, pyobj b); char not_equal_pyobj(pyobj a, pyobj b); char identical_pyobj(pyobj lhs, pyobj rhs); pyobj add_pyobj(pyobj lhs, pyobj rhs); pyobj make_list(int len); void print_pyobj(pyobj v); char printed_0; char printed_0_neg; void print_float(double in) { char outstr[128]; snprintf(outstr, 128, "%%.12g", in); char *p = outstr; if(in == 0.0) { if(printed_0 == 0) { printed_0 = 1; printed_0_neg = *p == '-'; /*see if we incremented for negative*/ } else { printf(printed_0_neg ? "-0.0" : "0.0"); return; } } if(*p == '-') p++; while(*p && isdigit(*p)) p++; printf( ( (*p) ? "%%s" : "%%s.0" ), outstr); } void print_int(int i) { printf("%%d", i); } void print_bool(char b) { if (b == 0) printf("False"); else printf("True"); } static pyobj *current_list; void print_list(pyobj pyobj_list) { if(current_list && current_list == pyobj_list.u.l.data) { printf("[...]"); return; } int will_reset = 0; if(!current_list) { current_list = pyobj_list.u.l.data; will_reset = 1; } array l = pyobj_list.u.l; printf("["); int i; for(i = 0; i < l.len; i++) { if (l.data[i].tag == LIST && l.data[i].u.l.data == l.data) printf("[...]"); else print_pyobj(l.data[i]); if(i != l.len - 1) printf(", "); } printf("]"); if(will_reset) current_list = NULL; } char is_in_list(pyobj a, pyobj b) { char ret = 0; int i; for(i = 0; i< a.u.l.len; i++) { if(identical_pyobj(a.u.l.data[i],b)) return 1; } return ret; } static char inside; static pyobj printing_list; void print_dict(pyobj dict) { char inside_reset = 0; if(!inside) { inside = 1; inside_reset = 1; printing_list = make_list(0); } if(is_in_list(printing_list,dict)) { printf("{...}"); return; } printf("{"); int i = 0; int max = hashtable_count(dict.u.d); struct hashtable_itr *itr = hashtable_iterator(dict.u.d); if (max) { do { pyobj k = *(pyobj *)hashtable_iterator_key(itr); pyobj v = 
*(pyobj *)hashtable_iterator_value(itr); print_pyobj(k); printf(": "); if(is_in_list(printing_list,v) || equal_pyobj(v,dict)) { printf("{...}"); } else { /* tally this dictionary in our list of printing dicts */ pyobj a = make_list(1); a.u.l.data[0] = dict; printing_list = add_pyobj(printing_list, a); print_pyobj(v); } if(i != max - 1) printf(", "); i++; } while (hashtable_iterator_advance(itr)); } printf("}"); if(inside_reset) { inside = 0; printing_list = make_list(0); } } void print_pyobj_rec(pyobj v) { switch (v.tag) { case INT: print_int(v.u.i); break; case FLOAT: { print_float(v.u.f); break; } case BOOL: print_bool(v.u.b); break; case LIST: { print_list(v); break; } case DICT: { print_dict(v); break; } default: printf("error, unhandled case in print_pyobj_rec\\n"); *(int*)0 = 42; } } void print_pyobj(pyobj v) { print_pyobj_rec(v); } pyobj int_to_pyobj(int x) { pyobj v; v.tag = INT; v.u.i = x; return v; } pyobj float_to_pyobj(double x) { pyobj v; v.tag = FLOAT; v.u.f = x; return v; } pyobj bool_to_pyobj(char x) { pyobj v; v.tag = BOOL; v.u.b = x; return v; } pyobj make_list(int len) { pyobj v; v.tag = LIST; v.u.l.data = (pyobj*)GC_malloc(sizeof(pyobj) * len); v.u.l.len = len; return v; } pyobj* list_nth(pyobj list, pyobj n) { switch (list.tag) { case LIST: { switch (n.tag) { case INT: { if (n.u.i < list.u.l.len) return &(list.u.l.data[n.u.i]); else { printf("ERROR: list_nth index larger than list"); exit(1); } } case BOOL: { if (n.u.b < list.u.l.len) return &(list.u.l.data[n.u.b]); else { printf("ERROR: list_nth index larger than list"); exit(1); } } default: printf("ERROR: list_nth expected integer index"); exit(1); } } default: printf("ERROR: list_nth applied to non-list"); exit(1); } } pyobj list_add(pyobj x, pyobj y) { array a = x.u.l; array b = y.u.l; pyobj v; int i; v.tag = LIST; v.u.l.data = (pyobj*)GC_malloc(sizeof(pyobj) * (a.len + b.len)); v.u.l.len = a.len + b.len; for (i = 0; i != a.len; ++i) v.u.l.data[i] = a.data[i]; for (i = 0; i != b.len; ++i) 
v.u.l.data[a.len + i] = b.data[i]; return v; } pyobj list_sub(pyobj x, pyobj y) { printf("error, unsupported operand types"); *(int*)0 = 42; } pyobj list_mult(pyobj x, int n) { int i; pyobj r = make_list(0); for (i = 0; i != n; ++i) r = list_add(x, r); return r; } pyobj list_mul(pyobj x, pyobj y) { switch (x.tag) { case INT: switch (y.tag) { case LIST: return list_mult(y, x.u.i); default: printf("error, unsupported operand types"); *(int*)0 = 42; } case BOOL: switch (y.tag) { case LIST: return list_mult(y, x.u.b); default: printf("error, unsupported operand types"); *(int*)0 = 42; } case LIST: switch (y.tag) { case INT: return list_mult(x, y.u.i); case BOOL: return list_mult(x, y.u.b); default: printf("error, unsupported operand types"); *(int*)0 = 42; } default: printf("error, unsupported operand types"); *(int*)0 = 42; } } pyobj list_divide(pyobj x, pyobj y) { printf("error, unsupported operand types"); *(int*)0 = 42; } /* This hash function was chosen more or less at random -Jeremy */ static int hash32shift(int key) { key = ~key + (key << 15); /* key = (key << 15) - key - 1; */ key = key ^ (key >> 12); key = key + (key << 2); key = key ^ (key >> 4); key = key * 2057; /* key = (key + (key << 3)) + (key << 11); */ key = key ^ (key >> 16); return key; } static unsigned int hash_any(void* o) { pyobj* obj = (pyobj*)o; switch (obj->tag) { case INT: return hash32shift(obj->u.i); case FLOAT: return hash32shift(obj->u.f); case BOOL: return hash32shift(obj->u.b); case LIST: { int i; unsigned long h = 0; for (i = 0; i != obj->u.l.len; ++i) h = 5*h + hash_any(&obj->u.l.data[i]); return h; } case DICT: { struct hashtable_itr* i; unsigned long h = 0; if (hashtable_count(obj->u.d) == 0) return h; i = hashtable_iterator(obj->u.d); do { h = 5*h + hash_any(hashtable_iterator_value(i)); } while (hashtable_iterator_advance(i)); return h; } default: printf("unrecognized tag in hash_any\\n"); *(int*)0 = 42; } } int is_false(pyobj v) { switch (v.tag) { case INT: return v.u.i == 0; 
case FLOAT: return v.u.f == 0; case BOOL: return v.u.b == 0; case LIST: return v.u.l.len == 0; case DICT: return v.u.d == 0; default: printf("error, unhandled case in is_false\\n"); *(int*)0 = 42; } } static int equal_any(void* a, void* b) { return equal_pyobj(*(pyobj*)a, *(pyobj*)b); } pyobj make_dict() { pyobj v; v.tag = DICT; v.u.d = create_hashtable(4, hash_any, equal_any); return v; } pyobj* dict_subscript(pyobj dict, pyobj key) { switch (dict.tag) { case DICT: { void* p = hashtable_search(dict.u.d, &key); if (p) return (pyobj*)p; else { pyobj* k = (pyobj*)GC_malloc(sizeof(pyobj)); *k = key; pyobj* v = (pyobj*)GC_malloc(sizeof(pyobj)); v->tag = INT; v->u.i = 444; hashtable_insert(dict.u.d, k, v); return v; } } default: printf("error in dict_get, not a dictionary\\n"); *(int*)0 = 42; } } pyobj* subscript_pyobj(pyobj c, pyobj key) { switch (c.tag) { case LIST: return list_nth(c, key); case DICT: return dict_subscript(c, key); default: printf("error in subscript, not a list or dictionary\\n"); *(int*)0 = 42; } } pyobj set_subscript_pyobj(pyobj c, pyobj key, pyobj val) { switch (c.tag) { case LIST: return *list_nth(c, key) = val; case DICT: return *dict_subscript(c, key) = val; default: printf("error in set subscript, not a list or dictionary\\n"); *(int*)0 = 42; } } #define gen_unary_op(NAME, OP) \\ pyobj NAME##_pyobj(pyobj a) { \\ switch (a.tag) { \\ case INT: \\ return int_to_pyobj(OP a.u.i); \\ case FLOAT: \\ return float_to_pyobj(OP a.u.f); \\ case BOOL: \\ return int_to_pyobj(OP a.u.b); \\ default: \\ printf("error, unhandled case in unary operator\\n"); \\ *(int*)0 = 42; \\ } \\ } gen_unary_op(unary_add, +) gen_unary_op(unary_sub, -) #define gen_binary_op(NAME, OP) \\ pyobj NAME##_pyobj(pyobj a, pyobj b) { \\ switch (a.tag) { \\ case INT: \\ switch (b.tag) { \\ case INT: \\ return int_to_pyobj(a.u.i OP b.u.i); \\ case FLOAT: \\ return float_to_pyobj(a.u.i OP b.u.f); \\ case BOOL: \\ return int_to_pyobj(a.u.i OP b.u.b); \\ case LIST: \\ return list_##NAME(a, 
b); \\ default: \\ printf("error, unhandled case in operator\\n"); \\ *(int*)0 = 42; \\ } \\ case FLOAT: \\ switch (b.tag) { \\ case INT: \\ return float_to_pyobj(a.u.f OP b.u.i); \\ case FLOAT: \\ return float_to_pyobj(a.u.f OP b.u.f); \\ case BOOL: \\ return float_to_pyobj(a.u.f OP b.u.b); \\ default: \\ printf("error, unhandled case in operator\\n"); \\ *(int*)0 = 42; \\ } \\ case BOOL: \\ switch (b.tag) { \\ case INT: \\ return int_to_pyobj(a.u.b OP b.u.i); \\ case FLOAT: \\ return float_to_pyobj(a.u.b OP b.u.f); \\ case BOOL: \\ return int_to_pyobj(a.u.b OP b.u.b); \\ case LIST: \\ return list_##NAME(a, b); \\ default: \\ printf("error, unhandled case in operator\\n"); \\ *(int*)0 = 42; \\ } \\ case LIST: \\ switch (b.tag) { \\ case LIST: \\ return list_##NAME(a, b); \\ case INT: \\ return list_##NAME(a, b); \\ case BOOL: \\ return list_##NAME(a, b); \\ default: \\ printf("error, unhandled case in operator\\n"); \\ *(int*)0 = 42; \\ } \\ default: \\ printf("error, unhandled case in operator\\n"); \\ *(int*)0 = 42; \\ } \\ } gen_binary_op(add, +) gen_binary_op(sub, -) gen_binary_op(mul, *) pyobj floordiv_pyobj(pyobj a, pyobj b) { pyobj ret; switch(a.tag) { case INT: switch(b.tag) { case INT: ret.tag = INT; ret.u.i = (int)floor((double)(a.u.i) / (double)b.u.i); break; case FLOAT: ret.tag = FLOAT; ret.u.f = ((double)a.u.i / b.u.f); break; case BOOL: ret.tag = INT; ret.u.i = (int)floor((double)a.u.i / (double)b.u.b); break; default: break; } break; case FLOAT: switch(b.tag) { case INT: ret.tag = FLOAT; ret.u.f = a.u.f / b.u.i; break; case FLOAT: ret.tag = FLOAT; ret.u.f = a.u.f / b.u.f; break; case BOOL: ret.tag = FLOAT; ret.u.f = a.u.f / (double)b.u.b; break; default: break; } break; case BOOL: switch(b.tag) { case INT: ret.tag = INT; ret.u.i = (int)floor((double)a.u.b / (double)b.u.i); break; case FLOAT: ret.tag = FLOAT; ret.u.f = ((double)a.u.b / b.u.f); break; case BOOL: ret.tag = INT; ret.u.i = (int)floor((double)a.u.b / (double)b.u.b); break; default: break; 
} break; default: break; } return ret; } char logic_not_pyobj(pyobj v) { if (is_false(v)) return 1; else return 0; } int min(int x, int y) { return y < x ? y : x; } char list_less(array x, array y) { int i; for (i = 0; i != min(x.len, y.len); ++i) { if (less_pyobj(x.data[i], y.data[i])) return 1; else if (less_pyobj(y.data[i], x.data[i])) return 0; } if (x.len < y.len) return 1; else return 0; } char list_equal(array x, array y) { char eq = 1; int i; for (i = 0; i != min(x.len, y.len); ++i) eq = eq && equal_pyobj(x.data[i], y.data[i]); if (x.len == y.len) return eq; else return 0; } char list_not_equal(array x, array y) { return !list_equal(x,y); } char list_greater(array x, array y) { return list_less(y,x); } char list_less_equal(array x, array y) { return !list_greater(y,x); } char list_greater_equal(array x, array y) { return !list_less(y,x); } static struct hashtable *current_cmp_a; static struct hashtable *current_cmp_b; char dict_equal(struct hashtable* x, struct hashtable* y) { if(hashtable_count(x) != hashtable_count(y)) return 0; if(current_cmp_a) { if(current_cmp_a == x) { return current_cmp_a == y; } else if(current_cmp_a == y) { return current_cmp_a == x; } } if(current_cmp_b) { if(current_cmp_b == y) { return current_cmp_b == x; } else if(current_cmp_b == x) { return current_cmp_b == y; } } char will_reset = 0; char same = 1; if(!current_cmp_a) { current_cmp_a = x; current_cmp_b = y; will_reset = 1; } int max = hashtable_count(x); struct hashtable_itr *itr_a = hashtable_iterator(x); struct hashtable_itr *itr_b = hashtable_iterator(y); if (max) { do { pyobj k_a = *(pyobj *)hashtable_iterator_key(itr_a); pyobj v_a = *(pyobj *)hashtable_iterator_value(itr_a); pyobj k_b = *(pyobj *)hashtable_iterator_key(itr_b); pyobj v_b = *(pyobj *)hashtable_iterator_value(itr_b); if(not_equal_pyobj(k_a,k_b) || not_equal_pyobj(v_a,v_b)) same = 0; } while (hashtable_iterator_advance(itr_a) && hashtable_iterator_advance(itr_b)); } if(will_reset) { current_cmp_a = NULL; 
current_cmp_b = NULL; } return same; } char dict_greater(struct hashtable* x, struct hashtable* y) { int size_a = hashtable_count(x); int size_b = hashtable_count(y); if(size_a == size_b) { struct hashtable_itr *itr_a = hashtable_iterator(x); struct hashtable_itr *itr_b = hashtable_iterator(y); if (size_a) { do { pyobj k_a = *(pyobj *)hashtable_iterator_key(itr_a); pyobj v_a = *(pyobj *)hashtable_iterator_value(itr_a); pyobj k_b = *(pyobj *)hashtable_iterator_key(itr_b); pyobj v_b = *(pyobj *)hashtable_iterator_value(itr_b); if(greater_pyobj(k_a,k_b) || greater_pyobj(v_a,v_b)) { return 1; } } while (hashtable_iterator_advance(itr_a) && hashtable_iterator_advance(itr_b)); } return 0; } else return size_a > size_b; } char dict_less(struct hashtable* x, struct hashtable* y) { return dict_greater(y, x); } char dict_less_equal(struct hashtable* x, struct hashtable* y) { return !dict_greater(y,x); } char dict_greater_equal(struct hashtable* x, struct hashtable* y) { return !dict_less(y,x); } char dict_not_equal(struct hashtable* x, struct hashtable* y) { return !dict_equal(x,y); } #define gen_comparison(NAME, OP) \\ char NAME##_pyobj(pyobj a, pyobj b) \\ {\\ switch (a.tag) {\\ case INT:\\ switch (b.tag) {\\ case INT:\\ return a.u.i OP b.u.i; \\ case FLOAT:\\ return a.u.i OP b.u.f; \\ case BOOL:\\ return a.u.i OP b.u.b; \\ default: \\ return 0; \\ }\\ case FLOAT: \\ switch (b.tag) {\\ case INT:\\ return a.u.f OP b.u.i; \\ case FLOAT:\\ return a.u.f OP b.u.f; \\ case BOOL:\\ return a.u.f OP b.u.b; \\ default: \\ return 0; \\ }\\ case BOOL: \\ switch (b.tag) {\\ case INT:\\ return a.u.b OP b.u.i; \\ case FLOAT:\\ return a.u.b OP b.u.f; \\ case BOOL:\\ return a.u.b OP b.u.b; \\ default: \\ return 0; \\ }\\ case LIST: \\ switch (b.tag) { \\ case LIST: \\ return list_##NAME(a.u.l, b.u.l); \\ default: \\ return 0; \\ } \\ case DICT: \\ switch (b.tag) { \\ case DICT: \\ return dict_##NAME(a.u.d, b.u.d); \\ default: \\ return 0; \\ } \\ default: \\ return 0; \\ } \\ } 
gen_comparison(less, <) gen_comparison(equal, ==) char less_equal_pyobj(pyobj a, pyobj b) { return less_pyobj(a,b) || equal_pyobj(a,b); } char greater_pyobj(pyobj a, pyobj b) { return !less_equal_pyobj(a,b); } char greater_equal_pyobj(pyobj a, pyobj b) { return !less_pyobj(a,b); } char not_equal_pyobj(pyobj a, pyobj b) { return !equal_pyobj(a,b); } char identical_pyobj(pyobj a, pyobj b) { if(a.tag != b.tag) return 0; switch(a.tag) { case INT: return (a.u.i == b.u.i); case BOOL: return (a.u.b == b.u.b); case FLOAT: return (a.u.f == b.u.f); case DICT: return (a.u.d == b.u.d); case LIST: return (a.u.l.len == b.u.l.len && a.u.l.data == b.u.l.data); case CLOSURE: return (a.u.c.free_vars == b.u.c.free_vars); case CLASS: return (a.u.cl.attrs == b.u.cl.attrs); case OBJECT: return (a.u.obj.attrs == b.u.obj.attrs); case UBMETHOD: return a.u.ubm.fun.free_vars == b.u.ubm.fun.free_vars && a.u.ubm.cl.attrs == b.u.ubm.cl.attrs; case BMETHOD: return a.u.bm.fun.free_vars == b.u.bm.fun.free_vars && a.u.bm.receiver.attrs == b.u.bm.receiver.attrs; } return 0; } pyobj power_pyobj(pyobj a, pyobj b) { pyobj ret; switch(a.tag) { case INT: switch(b.tag) { case INT: if(b.u.i >= 0) { ret.tag = INT; ret.u.i = (int)pow(a.u.i,b.u.i); } else { ret.tag = FLOAT; ret.u.f = pow(a.u.i,b.u.i); } break; case FLOAT: ret.tag = FLOAT; ret.u.f = pow(a.u.i,b.u.f); break; case BOOL: ret.tag = INT; ret.u.i = (int)pow(a.u.i,b.u.b); break; default: break; } break; case FLOAT: switch(b.tag) { case INT: ret.tag = FLOAT; ret.u.f = pow(a.u.f,b.u.i); break; case FLOAT: ret.tag = FLOAT; ret.u.f = pow(a.u.f,b.u.f); break; case BOOL: ret.tag = FLOAT; ret.u.f = pow(a.u.f,b.u.b); break; default: break; } break; case BOOL: switch(b.tag) { case INT: ret.tag = INT; ret.u.i = (int)pow(a.u.b,b.u.i); break; case FLOAT: ret.tag = FLOAT; ret.u.f = (int)pow(a.u.b,b.u.f); break; case BOOL: ret.tag = INT; ret.u.i = (int)pow(a.u.b,b.u.b); break; default: break; } break; default: break; } return ret; } pyobj power_int(int a, int b) { 
return power_pyobj(int_to_pyobj(a), int_to_pyobj(b)); } double power_float(double a, double b) { return pow(a, b); } pyobj power_bool(char a, char b) { return power_pyobj(bool_to_pyobj(a), bool_to_pyobj(b)); } int mod_int(int a, int b) { int ret = a %% b; return (SIGN_OF(ret) == SIGN_OF(b) || ret == 0) ? ret : ret + b; } double mod_float(double a, double b) { double ret = fmod(a,b); return (SIGN_OF(ret) == SIGN_OF(b) || ret == 0) ? ret : ret + b; } pyobj mod_pyobj(pyobj a, pyobj b) { pyobj ret; switch(a.tag) { case INT: switch(b.tag) { case INT: return int_to_pyobj(mod_int(a.u.i,b.u.i)); case FLOAT: return float_to_pyobj(mod_float(a.u.i,b.u.f)); case BOOL: ret.tag = INT; ret.u.i = a.u.i %% b.u.b; break; default: break; } break; case FLOAT: switch(b.tag) { case INT: ret.tag = FLOAT; ret.u.f = fabs(fmod(a.u.f, b.u.i)); if(SIGN_OF(ret.u.f) != SIGN_OF(b.u.i)) ret.u.f += b.u.i; break; case FLOAT: ret.tag = FLOAT; ret.u.f = fabs(fmod(a.u.f,b.u.f)); break; case BOOL: ret.tag = FLOAT; ret.u.f = fmod(a.u.f,(double)b.u.b); break; default: break; } break; case BOOL: switch(b.tag) { case INT: ret.tag = INT; ret.u.i = a.u.b %% b.u.i; break; case FLOAT: ret.tag = FLOAT; ret.u.f = fmod((double)a.u.b,b.u.f); break; case BOOL: ret.tag = INT; ret.u.i = a.u.b %% b.u.b; break; default: break; } break; default: break; } return ret; } pyobj input() { pyobj ret; char buf[1000]; if (fgets(buf, 1000, stdin) != NULL) { if (strcmp(buf, "True") == 0) { ret.tag = BOOL; ret.u.b = 1; } else if (strcmp(buf, "False") == 0) { ret.tag = BOOL; ret.u.b = 0; } else if (strstr(buf, ".")) { ret.tag = FLOAT; ret.u.f = atof(buf); } else { ret.tag = INT; ret.u.i = atoi(buf); } } return ret; } pyobj make_closure_pyobj(void* fptr, pyobj fvs) { pyobj ret; ret.tag = CLOSURE; ret.u.c.function_ptr = fptr; ret.u.c.free_vars = (pyobj*)GC_malloc(sizeof(pyobj)); *ret.u.c.free_vars = fvs; ret.u.c.num_params = 0; return ret; } char* tag_names[] = {"int", "float", "bool", "list", "dict", "closure", "class", "object", 
"unbound method", "bound method" }; void* get_func_pyobj(pyobj clos) { switch (clos.tag) { case CLOSURE: return clos.u.c.function_ptr; default: printf("trying to get a function pointer from closure: but got a %%s\\n", tag_names[clos.tag]); exit(-1); return NULL; } } pyobj get_closure_pyobj(pyobj clos) { pyobj ret; ret.tag = CLOSURE; switch (clos.tag) { case UBMETHOD: ret.u.c = clos.u.ubm.fun; break; case BMETHOD: ret.u.c = clos.u.bm.fun; break; default: printf("trying to get a closure from method: but got a %%s\\n", tag_names[clos.tag]); exit(-1); } return ret; } pyobj get_free_pyobj(pyobj clos) { if (clos.tag == CLOSURE) return *clos.u.c.free_vars; else exit(-1); } void set_free_pyobj(pyobj clos, pyobj fvs) { if (clos.tag == CLOSURE) *clos.u.c.free_vars = fvs; else exit(-1); } unsigned int attrname_hash(void *ptr) { unsigned char *str = (unsigned char *)ptr; unsigned long hash = 5381; int c; while(c=*str++) hash = ((hash << 5) + hash) ^ c; return hash; } int attrname_equal(void *a, void *b) { return !strcmp( (char*)a, (char*)b ); } pyobj make_class_pyobj(pyobj bases) { pyobj ret; ret.tag = CLASS; ret.u.cl.attrs = create_hashtable(2, attrname_hash, attrname_equal); switch (bases.tag) { case LIST: { int i; ret.u.cl.nparents = bases.u.l.len; ret.u.cl.parents = (class*)GC_malloc(sizeof(class) * ret.u.cl.nparents); for (i = 0; i != ret.u.cl.nparents; ++i) { pyobj* parent = &bases.u.l.data[i]; if (parent->tag == CLASS) ret.u.cl.parents[i] = parent->u.cl; else exit(-1); } break; } default: exit(-1); } return ret; } /* we leave calling the __init__ function for a separate step. 
*/ pyobj make_object_pyobj(pyobj cl) { pyobj ret; ret.tag = OBJECT; if (cl.tag == CLASS) ret.u.obj.cl = cl.u.cl; else { printf("in make object, expected a class\\n"); exit(-1); } ret.u.obj.attrs = create_hashtable(2, attrname_hash, attrname_equal); return ret; } pyobj* attrsearch_rec(class cl, char* attr) { pyobj* ptr; int i; ptr = hashtable_search(cl.attrs, attr); if(ptr == NULL) { for(i=0; i != cl.nparents; ++i) { ptr = attrsearch_rec(cl.parents[i], attr); if (ptr != NULL) return ptr; } return NULL; } else return ptr; } pyobj* attrsearch(class cl, char* attr) { pyobj* ret = attrsearch_rec(cl, attr); if (ret == NULL) { printf("attribute %%s not found\\n", attr); exit(-1); } return ret; } pyobj make_bound_method(object receiver, closure f) { pyobj ret; ret.tag = BMETHOD; ret.u.bm.fun = f; ret.u.bm.receiver = receiver; return ret; } pyobj make_unbound_method(class cl, closure f) { pyobj ret; ret.tag = UBMETHOD; ret.u.ubm.fun = f; ret.u.ubm.cl = cl; return ret; } char has_attr_pyobj(pyobj o, char* attr) { switch (o.tag) { case CLASS: { pyobj* attribute = attrsearch_rec(o.u.cl, attr); return attribute != NULL; } case OBJECT: { pyobj* attribute = hashtable_search(o.u.obj.attrs, attr); if (attribute == NULL) { attribute = attrsearch_rec(o.u.cl, attr); return attribute != NULL; } else { return 1; } } default: return 0; } } char is_class_pyobj(pyobj o) { return o.tag == CLASS; } char is_class_bool(char o) { return 0; } char is_class_int(int o) { return 0; } char is_class_float(float o) { return 0; } char is_bound_method_pyobj(pyobj o) { return o.tag == BMETHOD; } char is_unbound_method_pyobj(pyobj o) { return o.tag == UBMETHOD; } char inherits_rec(class c1, class c2) { char ret = 0; if (c1.attrs == c2.attrs) { ret = 1; } else { int i; for(i=0; i != c1.nparents; ++i) { ret = inherits_rec(c1.parents[i], c2); if (ret) break; } } return ret; } char inherits_pyobj(pyobj c1, pyobj c2) { if (c1.tag == CLASS && c2.tag == CLASS) { return inherits_rec(c1.u.cl, c2.u.cl); } else { 
printf("inherits expects classes\\n"); exit(-1); return 0; } } pyobj get_class_pyobj(pyobj o) { pyobj ret; ret.tag = CLASS; switch (o.tag) { case OBJECT: ret.u.cl = o.u.obj.cl; break; case UBMETHOD: ret.u.cl = o.u.ubm.cl; break; default: printf("get class expected object or unbound method\\n"); exit(-1); } return ret; } pyobj get_object_pyobj(pyobj o) { pyobj ret; ret.tag = OBJECT; switch (o.tag) { case BMETHOD: ret.u.obj = o.u.bm.receiver; break; default: printf("get object expected bound method\\n"); exit(-1); } return ret; } pyobj get_attr_pyobj(pyobj c, char* attr) { switch (c.tag) { case CLASS: { pyobj* attribute = attrsearch(c.u.cl, attr); if (attribute->tag == CLOSURE) { return make_unbound_method(c.u.cl, attribute->u.c); } else { return *attribute; } } case OBJECT: { pyobj* attribute = hashtable_search(c.u.obj.attrs, attr); if (attribute == NULL) { attribute = attrsearch(c.u.obj.cl, attr); if (attribute->tag == CLOSURE) { return make_bound_method(c.u.obj, attribute->u.c); } else { return *attribute; } } else { return *attribute; } } default: printf("error in get attribute, not a class or object\\n"); *(int*)0 = 42; } } void set_attr_pyobj(pyobj obj, char* attr, pyobj val) { char* k; pyobj* v; k = (char *)GC_malloc(strlen(attr)+1); v = (pyobj *)GC_malloc(sizeof(pyobj)); strcpy(k, attr); *v = val; struct hashtable* attrs; switch (obj.tag) { case CLASS: attrs = obj.u.cl.attrs; break; case OBJECT: attrs = obj.u.obj.attrs; break; default: printf("error, expected object or class in set attribute\\n"); exit(-1); } if(!hashtable_change(attrs, k, v)) if(!hashtable_insert(attrs, k, v)) { printf("out of memory"); exit(-1); } } pyobj error_pyobj(char* string) { printf(string); exit(-1); } %s''' % '\n'.join([generate_c(d) for d in n.decls]) elif isinstance(n, Module): return '''int main() { %s return 0; }''' % generate_c(n.node) elif isinstance(n, Function): return 'pyobj %s(%s){\n%s}\n' % \ (n.name, ', '.join(['pyobj ' + x for x in n.argnames]), generate_c(n.code)) 
elif isinstance(n, Stmt): return '{' + '\n'.join([generate_c(e) for e in n.nodes]) + '\n' + '}' elif isinstance(n, Printnl): space = 'printf(\" \");\n' newline = 'printf(\"\\n\");\n' nodes_in_c = ['print_%s(%s);\n' % (x.type, generate_c(x)) for x in n.nodes] return space.join(nodes_in_c) + newline elif isinstance(n, Discard): return generate_c(n.expr) + ';' elif isinstance(n, Return): return 'return ' + generate_c(n.value) + ';' elif isinstance(n, If): if n.else_ == None: else_ = '' else: else_ = 'else\n' + generate_c(n.else_) return 'if ' + '\n else if '.join(['(%s)\n%s' % (generate_c(cond), generate_c(body)) for (cond,body) in n.tests]) + else_ elif isinstance(n, While): return 'while (%s)\n%s' % (generate_c(n.test), generate_c(n.body)) elif isinstance(n, Pass): return ';' elif isinstance(n, Assign): return '='.join([generate_c(v) for v in n.nodes]) \ + ' = ' + generate_c(n.expr) + ';' elif isinstance(n, AssName): return n.name elif isinstance(n, VarDecl): return '%s %s;' % (python_type_to_c[n.type], n.name) elif isinstance(n, Const): if isinstance(n.value, str): return '\"%s\"' % n.value else: return repr(n.value) elif isinstance(n, Name): if n.name == 'True': return '1' elif n.name == 'False': return '0' else: return n.name elif isinstance(n, PrimitiveOp): if n.name == 'deref': return '*' + generate_c(n.nodes[0]) elif n.name == 'assign_pyobj' or n.name == 'assign_int' or n.name == 'assign_bool' or n.name == 'assign_float': return '(' + generate_c(n.nodes[0]) + '=' + generate_c(n.nodes[1]) + ')' else: return n.name + '(' + ', '.join([generate_c(e) for e in n.nodes]) + ')' elif isinstance(n, CallFunc): f = generate_c(n.node) args = ', '.join([generate_c(a) for a in n.args]) params = ', '.join(['pyobj' for i in range(0, len(n.args))]) return '((pyobj(*)(%s))%s)(%s)' % (params, f, args) elif isinstance(n, IfExp): return '(' + generate_c(n.test) + ' ? 
' \ + generate_c(n.then) + ':' + generate_c(n.else_) + ')' elif isinstance(n, Let): t = python_type_to_c[n.rhs.type] rhs = generate_c(n.rhs) return '({ ' + t + ' ' + n.var + ' = ' + rhs + '; ' + generate_c(n.body) + ';})' elif n == None: return '' else: raise Exception('Error in generate_c: unrecognized AST node ' + repr(n)) try: ast = compiler.parseFile(sys.argv[1]) if debug: print ast print '** simplifying ops' ir = simplify_ops(ast) if debug: print ast print '** lowering classes' ir = lower_classes(ir, set([])) if debug: print ir print '** heapifying' ir = heapify(ir, set([])) if debug: print ir print '** closure conversion' ir = convert_closures(ir) if debug: print ir print '** converting to ssa' ir = convert_to_ssa(ir) if debug: print ir print '** inserting var decls' ir = insert_var_decls(ir) if debug: print '** predicting types' predict_type(ir, LinkedEnv()) if debug: print '** type specialization' print ir ir = type_specialize(ir) if debug: print '** remove ssa' print ir ir = remove_ssa(ir) if debug: print '** remove complex' print ir ir = remove_complex(ir) if debug: print '** generate C' print ir print generate_c(ir) except EOFError: print "Could not open file %s." % sys.argv[1] except Exception, e: print "exception!" print e.args exit(-1)