[
  {
    "path": ".gitignore",
    "content": "*.exe\n*.o\n*.s\n"
  },
  {
    "path": "LICENSE.txt",
    "content": "MIT License\n\nCopyright (c) 2017 Ben Hoyt\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "\npyast64\n=======\n\npyast64 is a Python 3 program that compiles a subset of the Python AST to x64-64 assembler. It's extremely restricted (read \"a toy\") but it's a nice proof of concept in any case. [Read more about pyast64 here.](http://benhoyt.com/writings/pyast64/)\n"
  },
  {
    "path": "arrays.p64",
    "content": "# Test arrays\n\ndef fetch(array, ofs):\n    return array[ofs]\n\ndef store(array, ofs, value):\n    array[ofs] = value\n\ndef print_num(n):\n    if n == 0:\n        putc(48)  # '0'\n        return\n    if n < 0:\n        putc(45)  # '-' sign\n        n = -n\n    div = n // 10\n    if div != 0:\n        print_num(div)\n    putc(48 + n % 10)\n\ndef main():\n    size = 100\n    a = array(size)\n    for i in range(size):\n        store(a, i, i)\n    sum = 0\n    for i in range(size):\n        sum = sum + fetch(a, i)\n    print_num(sum)\n"
  },
  {
    "path": "benchmark_for.c",
    "content": "#include <stdio.h>\n\nint main() {\n    long long sum = 0;\n    for (long long i = 0; i < 100000000; i++) {\n        sum += i;\n    }\n    return (int)sum;\n}\n"
  },
  {
    "path": "benchmark_for.p64",
    "content": "# Benchmark a for loop\n\n# python3.5: 7.002s\n# pyast64: 0.548s (12.7x)\n# pyast64_peephole: 0.236s (29.7x, 2.3x)\n\ndef main():\n    sum = 0\n    for i in range(100000000):\n        sum += i\n"
  },
  {
    "path": "compile_and_run.sh",
    "content": "#!/usr/bin/env bash\n\nset -e\n\nfullname=\"$1\"\nfname=${fullname%.*}\n\npython pyast64.py $1 > $fname.s\nas $fname.s -o $fname.o\nld $fname.o -e _main -o $fname.exe -w\n./$fname.exe\necho\n"
  },
  {
    "path": "forloop.p64",
    "content": "# Basic for loop\n\ndef loop():\n    for i in range(10):\n        putc(65 + i)\n\ndef main():\n    loop()\n"
  },
  {
    "path": "print_num.p64",
    "content": "# Convert a number to decimal and print it\n\ndef print_num(n):\n    if n == 0:\n        putc(48)  # '0'\n        return\n    if n < 0:\n        putc(45)  # '-' sign\n        n = -n\n    div = n // 10\n    if div != 0:\n        print_num(div)\n    putc(48 + n % 10)\n\ndef newline():\n    putc(10)\n\ndef main():\n    print_num(1234)\n    newline()\n    print_num(-404)\n    newline()\n    print_num(0)\n"
  },
  {
    "path": "pyast64.py",
    "content": "\"\"\"Compile a subset of the Python AST to x64-64 assembler.\n\nRead more about it here: http://benhoyt.com/writings/pyast64/\n\nReleased under a permissive MIT license (see LICENSE.txt).\n\"\"\"\n\nimport argparse\nimport ast\nimport sys\n\n\nclass Assembler:\n    \"\"\"The Assembler takes care of outputting instructions, labels, etc.,\n    as well as a simple peephole optimization to combine sequences of pushes\n    and pops.\n    \"\"\"\n\n    def __init__(self, output_file=sys.stdout, peephole=True):\n        self.output_file = output_file\n        self.peephole = peephole\n        # Current batch of instructions, flushed on label and end of function\n        self.batch = []\n\n    def flush(self):\n        if self.peephole:\n            self.optimize_pushes_pops()\n        for opcode, args in self.batch:\n            print('\\t{}\\t{}'.format(opcode, ', '.join(str(a) for a in args)),\n                  file=self.output_file)\n        self.batch = []\n\n    def optimize_pushes_pops(self):\n        \"\"\"This finds runs of push(es) followed by pop(s) and combines\n        them into simpler, faster mov instructions. For example:\n\n        pushq   8(%rbp)\n        pushq   $100\n        popq    %rdx\n        popq    %rax\n\n        Will be turned into:\n\n        movq    $100, %rdx\n        movq    8(%rbp), %rax\n        \"\"\"\n        state = 'default'\n        optimized = []\n        pushes = 0\n        pops = 0\n\n        # This nested function combines a sequence of pushes and pops\n        def combine():\n            mid = len(optimized) - pops\n            num = min(pushes, pops)\n            moves = []\n            for i in range(num):\n                pop_arg = optimized[mid + i][1][0]\n                push_arg = optimized[mid - i - 1][1][0]\n                if push_arg != pop_arg:\n                    moves.append(('movq', [push_arg, pop_arg]))\n            optimized[mid - num:mid + num] = moves\n\n        # This loop actually finds the sequences\n        for opcode, args in self.batch:\n            if state == 'default':\n                if opcode == 'pushq':\n                    state = 'push'\n                    pushes += 1\n                else:\n                    pushes = 0\n                    pops = 0\n                optimized.append((opcode, args))\n            elif state == 'push':\n                if opcode == 'pushq':\n                    pushes += 1\n                elif opcode == 'popq':\n                    state = 'pop'\n                    pops += 1\n                else:\n                    state = 'default'\n                    pushes = 0\n                    pops = 0\n                optimized.append((opcode, args))\n            elif state == 'pop':\n                if opcode == 'popq':\n                    pops += 1\n                elif opcode == 'pushq':\n                    combine()\n                    state = 'push'\n                    pushes = 1\n                    pops = 0\n                else:\n                    combine()\n                    state = 'default'\n                    pushes = 0\n                    pops = 0\n                optimized.append((opcode, args))\n            else:\n                assert False, 'bad state: {}'.format(state)\n        if state == 'pop':\n            combine()\n        self.batch = optimized\n\n    def instr(self, opcode, *args):\n        self.batch.append((opcode, args))\n\n    def label(self, name):\n        self.flush()\n        print('{}:'.format(name), file=self.output_file)\n\n    def directive(self, line):\n        self.flush()\n        print(line, file=self.output_file)\n\n    def comment(self, text):\n        self.flush()\n        print('# {}'.format(text), file=self.output_file)\n\n\nclass LocalsVisitor(ast.NodeVisitor):\n    \"\"\"Recursively visit a FunctionDef node to find all the locals\n    (so we can allocate the right amount of stack space for them).\n    \"\"\"\n\n    def __init__(self):\n        self.local_names = []\n        self.global_names = []\n        self.function_calls = []\n\n    def add(self, name):\n        if name not in self.local_names and name not in self.global_names:\n            self.local_names.append(name)\n\n    def visit_Global(self, node):\n        self.global_names.extend(node.names)\n\n    def visit_Assign(self, node):\n        assert len(node.targets) == 1, \\\n            'can only assign one variable at a time'\n        self.visit(node.value)\n        target = node.targets[0]\n        if isinstance(target, ast.Subscript):\n            self.add(target.value.id)\n        else:\n            self.add(target.id)\n\n    def visit_For(self, node):\n        self.add(node.target.id)\n        for statement in node.body:\n            self.visit(statement)\n\n    def visit_Call(self, node):\n        self.function_calls.append(node.func.id)\n\n\nclass Compiler:\n    \"\"\"The main Python AST -> x86-64 compiler.\"\"\"\n\n    def __init__(self, assembler=None, peephole=True):\n        if assembler is None:\n            assembler = Assembler(peephole=peephole)\n        self.asm = assembler\n        self.func = None\n\n    def compile(self, node):\n        self.header()\n        self.visit(node)\n        self.footer()\n\n    def visit(self, node):\n        # We could have subclassed ast.NodeVisitor, but it's better to fail\n        # hard on AST nodes we don't support\n        name = node.__class__.__name__\n        visit_func = getattr(self, 'visit_' + name, None)\n        assert visit_func is not None, '{} not supported - node {}'.format(\n                name, ast.dump(node))\n        visit_func(node)\n\n    def header(self):\n        self.asm.directive('.section __TEXT, __text')\n        self.asm.comment('')\n\n    def footer(self):\n        self.compile_putc()\n        self.asm.flush()\n\n    def compile_putc(self):\n        # Insert this into every program so it can call putc() for output\n        self.asm.label('putc')\n        self.compile_enter()\n        self.asm.instr('movl', '$0x2000004', '%eax')    # write\n        self.asm.instr('movl', '$1', '%edi')            # stdout\n        self.asm.instr('movq', '%rbp', '%rsi')          # address\n        self.asm.instr('addq', '$16', '%rsi')\n        self.asm.instr('movq', '$1', '%rdx')            # length\n        self.asm.instr('syscall')\n        self.compile_return(has_arrays=False)\n\n    def visit_Module(self, node):\n        for statement in node.body:\n            self.visit(statement)\n\n    def visit_FunctionDef(self, node):\n        assert self.func is None, 'nested functions not supported'\n        assert node.args.vararg is None, '*args not supported'\n        assert not node.args.kwonlyargs, 'keyword-only args not supported'\n        assert not node.args.kwarg, 'keyword args not supported'\n\n        self.func = node.name\n        self.label_num = 1\n        self.locals = {a.arg: i for i, a in enumerate(node.args.args)}\n\n        # Find names of additional locals assigned in this function\n        locals_visitor = LocalsVisitor()\n        locals_visitor.visit(node)\n        for name in locals_visitor.local_names:\n            if name not in self.locals:\n                self.locals[name] = len(self.locals) + 1\n        if 'array' in locals_visitor.function_calls:\n            self.locals['_array_size'] = len(self.locals) + 1\n        self.globals = set(locals_visitor.global_names)\n        self.break_labels = []\n\n        # Function label and header\n        if node.name == 'main':\n            self.asm.directive('.globl _main')\n            self.asm.label('_main')\n        else:\n            self.asm.label(node.name)\n        self.num_extra_locals = len(self.locals) - len(node.args.args)\n        self.compile_enter(self.num_extra_locals)\n\n        # Now compile all the statements in the function body\n        for statement in node.body:\n            self.visit(statement)\n\n        if not isinstance(node.body[-1], ast.Return):\n            # Function didn't have explicit return at the end,\n            # compile return now (or exit for \"main\")\n            if self.func == 'main':\n                self.compile_exit(0)\n            else:\n                self.compile_return(self.num_extra_locals)\n\n        self.asm.comment('')\n        self.func = None\n\n    def compile_enter(self, num_extra_locals=0):\n        # Make space for extra locals (in addition to the arguments)\n        for i in range(num_extra_locals):\n            self.asm.instr('pushq', '$0')\n        # Use rbp for a stack frame pointer\n        self.asm.instr('pushq', '%rbp')\n        self.asm.instr('movq', '%rsp', '%rbp')\n\n    def compile_return(self, num_extra_locals=0, has_arrays=None):\n        if has_arrays is None:\n            has_arrays = '_array_size' in self.locals\n        if has_arrays:\n            offset = self.local_offset('_array_size')\n            self.asm.instr('movq', '{}(%rbp)'.format(offset), '%rbx')\n            self.asm.instr('addq', '%rbx', '%rsp')\n        self.asm.instr('popq', '%rbp')\n        if num_extra_locals > 0:\n            self.asm.instr('leaq', '{}(%rsp),%rsp'.format(\n                    num_extra_locals * 8))\n        self.asm.instr('ret')\n\n    def compile_exit(self, return_code):\n        if return_code is None:\n            self.asm.instr('popq', '%rdi')\n        else:\n            self.asm.instr('movl', '${}'.format(return_code), '%edi')\n        self.asm.instr('movl', '$0x2000001', '%eax')\n        self.asm.instr('syscall')\n\n    def visit_Return(self, node):\n        if node.value:\n            self.visit(node.value)\n        if self.func == 'main':\n            # Returning from main, exit with that return code\n            self.compile_exit(None if node.value else 0)\n        else:\n            if node.value:\n                self.asm.instr('popq', '%rax')\n            self.compile_return(self.num_extra_locals)\n\n    def visit_Num(self, node):\n        self.asm.instr('pushq', '${}'.format(node.n))\n\n    def local_offset(self, name):\n        index = self.locals[name]\n        return (len(self.locals) - index) * 8 + 8\n\n    def visit_Name(self, node):\n        # Only supports locals, not globals\n        offset = self.local_offset(node.id)\n        self.asm.instr('pushq', '{}(%rbp)'.format(offset))\n\n    def visit_Assign(self, node):\n        # Only supports assignment of (a single) local variable\n        assert len(node.targets) == 1, \\\n            'can only assign one variable at a time'\n        self.visit(node.value)\n        target = node.targets[0]\n        if isinstance(target, ast.Subscript):\n            # array[offset] = value\n            self.visit(target.slice.value)\n            self.asm.instr('popq', '%rax')\n            self.asm.instr('popq', '%rbx')\n            local_offset = self.local_offset(target.value.id)\n            self.asm.instr('movq', '{}(%rbp)'.format(local_offset), '%rdx')\n            self.asm.instr('movq', '%rbx', '(%rdx,%rax,8)')\n        else:\n            # variable = value\n            offset = self.local_offset(node.targets[0].id)\n            self.asm.instr('popq', '{}(%rbp)'.format(offset))\n\n    def visit_AugAssign(self, node):\n        # Handles \"n += 1\" and the like\n        self.visit(node.target)\n        self.visit(node.value)\n        self.visit(node.op)\n        offset = self.local_offset(node.target.id)\n        self.asm.instr('popq', '{}(%rbp)'.format(offset))\n\n    def simple_binop(self, op):\n        self.asm.instr('popq', '%rdx')\n        self.asm.instr('popq', '%rax')\n        self.asm.instr(op, '%rdx', '%rax')\n        self.asm.instr('pushq', '%rax')\n\n    def visit_Mult(self, node):\n        self.asm.instr('popq', '%rdx')\n        self.asm.instr('popq', '%rax')\n        self.asm.instr('imulq', '%rdx')\n        self.asm.instr('pushq', '%rax')\n\n    def compile_divide(self, push_reg):\n        self.asm.instr('popq', '%rbx')\n        self.asm.instr('popq', '%rax')\n        self.asm.instr('cqo')\n        self.asm.instr('idiv', '%rbx')\n        self.asm.instr('pushq', push_reg)\n\n    def visit_Mod(self, node):\n        self.compile_divide('%rdx')\n\n    def visit_FloorDiv(self, node):\n        self.compile_divide('%rax')\n\n    def visit_Add(self, node):\n        self.simple_binop('addq')\n\n    def visit_Sub(self, node):\n        self.simple_binop('subq')\n\n    def visit_BinOp(self, node):\n        self.visit(node.left)\n        self.visit(node.right)\n        self.visit(node.op)\n\n    def visit_UnaryOp(self, node):\n        assert isinstance(node.op, ast.USub), \\\n            'only unary minus is supported, not {}'.format(node.op.__class__.__name__)\n        self.visit(ast.Num(n=0))\n        self.visit(node.operand)\n        self.visit(ast.Sub())\n\n    def visit_Expr(self, node):\n        self.visit(node.value)\n        self.asm.instr('popq', '%rax')\n\n    def visit_And(self, node):\n        self.simple_binop('and')\n\n    def visit_BitAnd(self, node):\n        self.simple_binop('and')\n\n    def visit_Or(self, node):\n        self.simple_binop('or')\n\n    def visit_BitOr(self, node):\n        self.simple_binop('or')\n\n    def visit_BitXor(self, node):\n        self.simple_binop('xor')\n\n    def visit_BoolOp(self, node):\n        self.visit(node.values[0])\n        for value in node.values[1:]:\n            self.visit(value)\n            self.visit(node.op)\n\n    def builtin_array(self, args):\n        assert len(args) == 1, 'array(len) expected 1 arg, not {}'.format(len(args))\n        self.visit(args[0])\n        # Allocate array on stack, add size to _array_size, push address\n        self.asm.instr('popq', '%rax')\n        self.asm.instr('shlq', '$3', '%rax')  # len*8 to get size in bytes\n        offset = self.local_offset('_array_size')\n        self.asm.instr('addq', '%rax', '{}(%rbp)'.format(offset))\n        self.asm.instr('subq', '%rax', '%rsp')\n        self.asm.instr('movq', '%rsp', '%rax')\n        self.asm.instr('pushq', '%rax')\n\n    def visit_Call(self, node):\n        assert not node.keywords, 'keyword args not supported'\n        builtin = getattr(self, 'builtin_{}'.format(node.func.id), None)\n        if builtin is not None:\n            builtin(node.args)\n        else:\n            for arg in node.args:\n                self.visit(arg)\n            self.asm.instr('call', node.func.id)\n            if node.args:\n                # Caller cleans up the arguments from the stack\n                self.asm.instr('addq', '${}'.format(8 * len(node.args)), '%rsp')\n            # Return value is in rax, so push it on the stack now\n            self.asm.instr('pushq', '%rax')\n\n    def label(self, slug):\n        label = '{}_{}_{}'.format(self.func, self.label_num, slug)\n        self.label_num += 1\n        return label\n\n    def visit_Compare(self, node):\n        assert len(node.ops) == 1, 'only single comparisons supported'\n        self.visit(node.left)\n        self.visit(node.comparators[0])\n        self.visit(node.ops[0])\n\n    def compile_comparison(self, jump_not, slug):\n        self.asm.instr('popq', '%rdx')\n        self.asm.instr('popq', '%rax')\n        self.asm.instr('cmpq', '%rdx', '%rax')\n        self.asm.instr('movq', '$0', '%rax')\n        label = self.label(slug)\n        self.asm.instr(jump_not, label)\n        self.asm.instr('incq', '%rax')\n        self.asm.label(label)\n        self.asm.instr('pushq', '%rax')\n\n    def visit_Lt(self, node):\n        self.compile_comparison('jnl', 'less')\n\n    def visit_LtE(self, node):\n        self.compile_comparison('jnle', 'less_or_equal')\n\n    def visit_Gt(self, node):\n        self.compile_comparison('jng', 'greater')\n\n    def visit_GtE(self, node):\n        self.compile_comparison('jnge', 'greater_or_equal')\n\n    def visit_Eq(self, node):\n        self.compile_comparison('jne', 'equal')\n\n    def visit_NotEq(self, node):\n        self.compile_comparison('je', 'not_equal')\n\n    def visit_If(self, node):\n        # Handles if, elif, and else\n        self.visit(node.test)\n        self.asm.instr('popq', '%rax')\n        self.asm.instr('cmpq', '$0', '%rax')\n        label_else = self.label('else')\n        label_end = self.label('end')\n        self.asm.instr('jz', label_else)\n        for statement in node.body:\n            self.visit(statement)\n        if node.orelse:\n            self.asm.instr('jmp', label_end)\n        self.asm.label(label_else)\n        for statement in node.orelse:\n            self.visit(statement)\n        if node.orelse:\n            self.asm.label(label_end)\n\n    def visit_While(self, node):\n        # Handles while and break (also used for \"for\" -- see below)\n        while_label = self.label('while')\n        break_label = self.label('break')\n        self.break_labels.append(break_label)\n        self.asm.label(while_label)\n        self.visit(node.test)\n        self.asm.instr('popq', '%rax')\n        self.asm.instr('cmpq', '$0', '%rax')\n        self.asm.instr('jz', break_label)\n        for statement in node.body:\n            self.visit(statement)\n        self.asm.instr('jmp', while_label)\n        self.asm.label(break_label)\n        self.break_labels.pop()\n\n    def visit_Break(self, node):\n        self.asm.instr('jmp', self.break_labels[-1])\n\n    def visit_Pass(self, node):\n        pass\n\n    def visit_For(self, node):\n        # Turn for+range loop into a while loop:\n        #   i = start\n        #   while i < stop:\n        #       body\n        #       i = i + step\n        assert isinstance(node.iter, ast.Call) and \\\n            node.iter.func.id == 'range', \\\n            'for can only be used with range()'\n        range_args = node.iter.args\n        if len(range_args) == 1:\n            start = ast.Num(n=0)\n            stop = range_args[0]\n            step = ast.Num(n=1)\n        elif len(range_args) == 2:\n            start, stop = range_args\n            step = ast.Num(n=1)\n        else:\n            start, stop, step = range_args\n            if (isinstance(step, ast.UnaryOp) and\n                    isinstance(step.op, ast.USub) and\n                    isinstance(step.operand, ast.Num)):\n                # Handle negative step\n                step = ast.Num(n=-step.operand.n)\n            assert isinstance(step, ast.Num) and step.n != 0, \\\n                'range() step must be a nonzero integer constant'\n        self.visit(ast.Assign(targets=[node.target], value=start))\n        test = ast.Compare(\n            left=node.target,\n            ops=[ast.Lt() if step.n > 0 else ast.Gt()],\n            comparators=[stop],\n        )\n        incr = ast.Assign(\n            targets=[node.target],\n            value=ast.BinOp(left=node.target, op=ast.Add(), right=step),\n        )\n        self.visit(ast.While(test=test, body=node.body + [incr]))\n\n    def visit_Global(self, node):\n        # Global names are already collected by LocalsVisitor\n        pass\n\n    def visit_Subscript(self, node):\n        self.visit(node.slice.value)\n        self.asm.instr('popq', '%rax')\n        local_offset = self.local_offset(node.value.id)\n        self.asm.instr('movq', '{}(%rbp)'.format(local_offset), '%rdx')\n        self.asm.instr('pushq', '(%rdx,%rax,8)')\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('filename', help='filename to compile')\n    parser.add_argument('-n', '--no-peephole', action='store_true',\n                        help='enable peephole assembler optimizer')\n    args = parser.parse_args()\n\n    with open(args.filename) as f:\n        source = f.read()\n    node = ast.parse(source, filename=args.filename)\n    compiler = Compiler(peephole=not args.no_peephole)\n    compiler.compile(node)\n"
  }
]