[
  {
    "path": ".gitignore",
    "content": "hello\n.vscode\ncore\n"
  },
  {
    "path": "README.md",
    "content": "# The Pretty Laughable Programming Language\n\nAn educational C-like toy programming language that compiles to x64 binary.\n\nThe [compiler](pl_comp.py) is a self-contained Python program that weighs about 1000 LoC.\nIt's part of an online [tutorial](https://build-your-own.org/compiler/) on compilers and interpreters.\n\n## Introduction\n\nThe hello world looks like this:\n\n```clojure\n; the write() syscall:\n; ssize_t write(int fd, const void *buf, size_t count);\n(syscall 1 1 \"Hello world!\\n\" 13)\n0\n```\n\nCompile and run the program:\n```sh\n$ ./pl_comp.py ./samples/hello.txt -o ./hello\n$ ./hello\nHello world!\n```\n\nThe output is a tiny freestanding x64 Linux ELF binary.\n```sh\n$ file hello\nhello: ELF 64-bit LSB executable, x86-64, version 1 (SYSV), statically linked, no section header\n$ wc -c hello\n288 hello\n```\n\n## The Language\n\nThe syntax is just S-expression, parsing strings is too boring for me.\n\nThe semantics are C-like. The only data types are integers and pointers. This should be enough to write any program in.\n\n### 01. Pointers\n\nThe `peek` command reads data from a pointer and the `poke` command writes to a pointer.\n\n```clojure\n; copy data byte by byte\n(def (memcpy void) ((dst ptr byte) (src ptr byte) (n int)) (do\n    (loop n (do\n        (poke dst (peek src))\n        (set dst (+ 1 dst))\n        (set src (+ 1 src))\n        (set n (- n 1))\n    ))\n))\n```\n\n### 02. Control Flows\n\nList of control flow structures:\n\n```clojure\n(? cond yes no)\n(if cond (then yes blah blah) (else no no no))\n(do a b c...)\n(loop cond body)\n(break)\n(continue)\n(call f a b c...)\n(return val)\n```\n\nSome examples:\n\n```clojure\n(def (fib int) ((n int))\n    (if (le n 0) (then 0) (else (+ n (call fib (- n 1))))))\n```\n\n```clojure\n(def (fib int) ((n int)) (do\n    (var r 0)\n    (loop (gt n 0) (do\n        (set r (+ r n))\n        (set n (- n 1))\n    ))\n    (return r)\n))\n```\n\n### 03. 
Data Types\n\nThe only data types are:\n\n- `byte`: unsigned 8-bit integer.\n- `int`:  signed 64-bit integer.\n- `ptr elem_type`: pointer to `elem_type`.\n\nVariable types are automatically inferred:\n\n```clojure\n(var a 123)         ; int\n(var b 45u8)        ; byte\n(var p (ptr int))   ; a null pointer to int\n(var s \"asdf\")      ; ptr byte\n```\n\nThe types of a function's return value and arguments must be specified explicitly:\n\n```clojure\n(def (memcpy void) ((dst ptr byte) (src ptr byte) (n int)) (do\n    ; ...\n))\n```\n\n`int` can be cast to any pointer type and vice versa.\n\n```clojure\n(var i 0x1234)                  ; int\n(var p (cast (ptr int) i))      ; ptr int\n(var a (cast (int) (+ 1 p)))    ; int\n```\n\n### 04. Memory Management\n\nMemory management is very simple at this point, because it doesn't exist at all.\n\nHowever, the language doesn't prevent you from building your own memory management routines. This usually starts with the `mmap` syscall.\n\n```clojure\n(var heap (ptr byte))\n\n; a fake malloc\n(def (malloc ptr byte) ((n int)) (do\n    (if (not heap) (do\n        ; create the heap via mmap()\n        (var heapsz 1048576)    ; 1M\n        (var prot 3)            ; PROT_READ|PROT_WRITE\n        (var flags 0x22)        ; MAP_PRIVATE|MAP_ANONYMOUS\n        (var fd -1)\n        (var offset 0)\n        (var r (syscall 9 0 heapsz prot flags fd offset))\n        (set heap (cast (ptr byte) r))\n    ))\n    ; just move the heap pointer forward\n    (var r heap)\n    (set heap (+ n heap))\n    (return r)\n))\n\n; TODO: figure out how to recycle the memory\n(def (free void) ((p ptr byte)) (do))\n```\n\n### 05. The stdlib\n\nThe Pretty Laughable Language comes with the world's smallest standard library &mdash; no standard library &mdash; not even a builtin `print` function.\n\nBut with the ability to make arbitrary syscalls and peek-poke the memory, you can build your own stdlibs. 
Let's add the `print` function:\n\n```clojure\n(def (strlen int) ((s ptr byte)) (do\n    (var start s)\n    (loop (peek s) (set s (+ 1 s)))\n    (return (- s start))\n))\n\n(def (print void) ((s ptr byte)) (do\n    (syscall 1 1 s (call strlen s))\n))\n\n(call print \"Yes!\\n\")\n0\n```\n\n[Here](samples/malloc_and_strings.txt) is a more sophisticated program you can play with.\n\n## Roadmaps\n\nLanguage features:\n\n- [x] int, byte\n- [x] pointer\n- [x] syscall\n- [x] if-then-else, loop\n- [x] function\n- [x] nested function, nonlocal variable\n- [ ] array\n- [ ] struct, class\n- [ ] function pointer\n\nExplorations:\n\n- [ ] module or `include` directive\n- [ ] macro?\n- [ ] alternative syntax?\n- [ ] Windows\n- [ ] ARM64\n- [ ] WASM\n\nOptimizations:\n\n- [ ] register allocation\n- [ ] constants\n- [ ] tail call\n\n## The Design\n\n### 01. The Goal\n\nTBA\n\n### 02. The IR (Intermediate Representation)\n\nTBA\n\n### 03. Machine Code Generation\n\nTBA\n\n## The Implementation\n\nTo be added.\n\nBut you can learn how to do it by reading the source code.\n\nOr you might like the book [From Source Code To Machine Code](https://build-your-own.org/compiler/), which this repo is based on.\n"
  },
  {
    "path": "pl_comp.py",
    "content": "#!/usr/bin/env python3\n\nimport os\nimport sys\nimport mmap\nimport ctypes\nimport struct\nimport platform\n\n\n# TODO: function pointers\n# TODO: pointer to variables\n# TODO: binop8, unop8\n# TODO: array\n# TODO: class or struct\n# TODO: node comment\n# TODO: module or include directive\n\n# TODO: register allocation\n# TODO: constant prop\n# TODO: tail call\n\n\ndef skip_space(s, idx):\n    while True:\n        save = idx\n        # spaces\n        while idx < len(s) and s[idx].isspace():\n            idx += 1\n        # line comment\n        if idx < len(s) and s[idx] == ';':\n            idx += 1\n            while idx < len(s) and s[idx] != '\\n':\n                idx += 1\n        if idx == save:\n            break\n    return idx\n\n\ndef parse_expr(s, idx):\n    idx = skip_space(s, idx)\n    if s[idx] == '(':\n        idx += 1\n        l = []\n        while True:\n            idx = skip_space(s, idx)\n            if idx >= len(s):\n                raise Exception('unbalanced parenthesis')\n            if s[idx] == ')':\n                idx += 1\n                break\n            idx, v = parse_expr(s, idx)\n            l.append(v)\n        return idx, l\n    elif s[idx] == ')':\n        raise Exception('bad parenthesis')\n    elif s[idx] == '\"' or s[idx] == \"'\":\n        # string or u8\n        return parse_quotes(s, idx)\n    else:\n        # constant or name\n        start = idx\n        while idx < len(s) and (not s[idx].isspace()) and s[idx] not in '()':\n            idx += 1\n        if start == idx:\n            raise Exception('empty program')\n        return idx, parse_value(s[start:idx])\n\n\ndef parse_quotes(s, idx):\n    term = s[idx]\n    end = idx + 1\n    while end < len(s):\n        if s[end] == term:\n            break\n        if s[end] == '\\\\':\n            end += 1\n        end += 1\n    if end < len(s) and s[end] == term:\n        # TODO: actually implement this\n        import json\n        v = json.loads('\"' + 
s[idx+1:end] + '\"')\n        if term == '\"':\n            v = ['str', v]\n        else:\n            if len(v) != 1:\n                raise Exception('bad char')\n            v = ord(v)\n            if not (0 <= v < 256):\n                raise ValueError('bad integer range')\n            v = ['val8', v]\n        return end + 1, v\n\n\n# a single constant, or a name\ndef parse_value(s):\n    # int\n    try:\n        v = try_int(s)\n    except ValueError:\n        pass\n    else:\n        if not (-(1 << 63) <= v < (1 << 63)):\n            raise ValueError('bad integer range')\n        return ['val', v]\n\n    # u8\n    if s.endswith('u8'):\n        try:\n            v = try_int(s[:-2])\n        except ValueError:\n            pass\n        else:\n            if not (0 <= v < 256):\n                raise ValueError('bad integer range')\n            return ['val8', v]\n\n    # other\n    if s[0].isdigit():\n        raise ValueError('bad name')\n    return s\n\n\ndef try_int(s):\n    base = 10\n    if s[:2].lower() == '0x':\n        base = 16\n    # TODO: other bases\n    return int(s, base)\n\n\ndef pl_parse(s):\n    idx, node = parse_expr(s, 0)\n    idx = skip_space(s, idx)\n    if idx < len(s):\n        raise ValueError('trailing garbage')\n    return node\n\n\ndef pl_parse_main(s):\n    return pl_parse('(def (main int) () (do ' + s + '))')\n\n\n# the compiler state for functions\nclass Func:\n    def __init__(self, prev):\n        # the parent function (linked list)\n        self.prev = prev\n        # nested function level. the level of `main` is 1.\n        self.level = (prev.level + 1) if prev else 0\n        # the return type of this function\n        self.rtype = None\n        # a list of all functions. 
shared by all functions in a program.\n        self.funcs = prev.funcs if prev else []\n        # the name scope\n        self.scope = Scope(None)\n        # the output: a list of instructions\n        self.code = []\n        # current number of local variable in the stack (non-temporary)\n        self.nvar = 0\n        # current number of variables (both locals and temporaries)\n        self.stack = 0\n        # label IDs to instruction locations\n        self.labels = []\n\n    # enter a new scope\n    def scope_enter(self):\n        self.scope = Scope(self.scope)  # new list head\n        self.scope.save = self.stack\n\n    # exit a scope and revert the stack\n    def scope_leave(self):\n        self.stack = self.scope.save\n        self.nvar -= self.scope.nlocal\n        self.scope = self.scope.prev\n\n    # allocate a new local variable in the current scope\n    def add_var(self, name, tp):\n        # add it to the map\n        if name in self.scope.names:\n            raise ValueError('duplicated name')\n        self.scope.names[name] = (tp, self.nvar)    # (type, index)\n        self.scope.nlocal += 1\n        # assign the index\n        assert self.stack == self.nvar\n        dst = self.stack\n        self.stack += 1\n        self.nvar += 1\n        return dst\n\n    # lookup a name. 
returns a tuple of (function_level, type, index)\n    def get_var(self, name):\n        tp, var = scope_get_var(self.scope, name)\n        if var >= 0:\n            return self.level, tp, var\n        if not self.prev:\n            raise ValueError('undefined name')\n        return self.prev.get_var(name)\n\n    # allocate a temporary variable on the stack top and return its index\n    def tmp(self):\n        dst = self.stack\n        self.stack += 1\n        return dst\n\n    # allocate a new label ID\n    def new_label(self):\n        l = len(self.labels)\n        self.labels.append(None)    # filled later\n        return l\n\n    # associate the label ID to the current location\n    def set_label(self, l):\n        assert l < len(self.labels)\n        self.labels[l] = len(self.code)\n\n\n# the name scope linked list\nclass Scope:\n    def __init__(self, prev):\n        # the parent scope\n        self.prev = prev\n        # the number of local variables seen\n        self.nlocal = 0\n        # Variable names to (type, index) tuples.\n        # For functions, the key includes argument types\n        # and the index is the index of `Func.funcs`.\n        self.names = dict()\n        # the label IDs of the nearest loop\n        self.loop_start = prev.loop_start if prev else -1\n        self.loop_end = prev.loop_end if prev else -1\n\n\n# lookup a name from a scope. returns a (type, index) tuple.\ndef scope_get_var(scope, name):\n    while scope:    # linked list\n        if name in scope.names:\n            return scope.names[name]\n        scope = scope.prev\n    return None, -1 # not found\n\n\n# the entry point of compilation.\n# returns a (type, index) tuple. 
the index is -1 if the type is `('void',)`\ndef pl_comp_expr(fenv: Func, node, *, allow_var=False):\n    if allow_var:\n        assert fenv.stack == fenv.nvar\n    save = fenv.stack\n\n    # the actual implementation\n    tp, var = pl_comp_expr_tmp(fenv, node, allow_var=allow_var)\n    assert var < fenv.stack\n\n    # Discard temporaries from the above compilation:\n    if allow_var:\n        # The stack is either local variables only\n        fenv.stack = fenv.nvar\n    else:\n        # or reverts to its previous state.\n        fenv.stack = save\n\n    # The result is either a temporary stored at the top of the stack\n    # or a local variable.\n    assert var <= fenv.stack\n    return tp, var\n\n\ndef pl_comp_getvar(fenv: Func, node):\n    assert isinstance(node, str)\n    flevel, tp, var = fenv.get_var(node)\n    if flevel == fenv.level:\n        # local variable\n        return tp, var\n    else:\n        # non-local\n        dst = fenv.tmp()\n        fenv.code.append(('get_env', flevel, var, dst))\n        return tp, dst\n\n\ndef pl_comp_const(fenv: Func, node):\n    _, kid = node\n    assert isinstance(kid, (int, str))\n    dst = fenv.tmp()\n    fenv.code.append(('const', kid, dst))\n    tp = dict(val='int', val8='byte', str='ptr byte')[node[0]]\n    tp = tuple(tp.split())\n    return tp, dst\n\n\ndef pl_comp_binop(fenv: Func, node):\n    op, lhs, rhs = node\n\n    # compile subexpressions\n    # FIXME: boolean short circuit\n    save = fenv.stack\n    t1, a1 = pl_comp_expr_tmp(fenv, lhs)\n    t2, a2 = pl_comp_expr_tmp(fenv, rhs)\n    fenv.stack = save   # discard temporaries\n\n    # pointers\n    if op == '+' and (t1[0], t2[0]) == ('int', 'ptr'):\n        # rewrite `offset + ptr` into `ptr + offset`\n        t1, a1, t2, a2 = t2, a2, t1, a1\n    if op in '+-' and (t1[0], t2[0]) == ('ptr', 'int'):\n        # ptr + offset\n        scale = 8\n        if t1 == ('ptr', 'byte'):\n            scale = 1\n        if op == '-':\n            scale = -scale\n        # 
output to a new temporary\n        dst = fenv.tmp()\n        fenv.code.append(('lea', a1, a2, scale, dst))\n        return t1, dst\n    if op == '-' and (t1[0], t2[0]) == ('ptr', 'ptr'):\n        # ptr - ptr\n        if t1 != t2:\n            raise ValueError('comparison of different pointer types')\n        if t1 != ('ptr', 'byte'):\n            # TODO: ptr int\n            raise NotImplementedError\n        dst = fenv.tmp()\n        fenv.code.append(('binop', '-', a1, a2, dst))\n        return ('int',), dst\n\n    # check types\n    # TODO: allow different types\n    cmp = {'eq', 'ge', 'gt', 'le', 'lt', 'ne'}\n    ints = (t1 == t2 and t1[0] in ('int', 'byte'))\n    ptr_cmp = (t1 == t2 and t1[0] == 'ptr' and op in cmp)\n    if not (ints or ptr_cmp):\n        raise ValueError('bad binop types')\n    rtype = t1\n    if op in cmp:\n        rtype = ('int',)    # boolean\n\n    suffix = ''\n    if t1 == t2 and t1 == ('byte',):\n        suffix = '8'\n    # output to a new temporary\n    dst = fenv.tmp()\n    fenv.code.append(('binop' + suffix, op, a1, a2, dst))\n    return rtype, dst\n\n\ndef pl_comp_unop(fenv: Func, node):\n    op, arg = node\n    t1, a1 = pl_comp_expr(fenv, arg)\n\n    suffix = ''\n    rtype = t1\n    if op == '-':\n        if t1[0] not in ('int', 'byte'):\n            raise ValueError('bad unop types')\n        if t1 == ('byte',):\n            suffix = '8'\n    elif op == 'not':\n        if t1[0] not in ('int', 'byte', 'ptr'):\n            raise ValueError('bad unop types')\n        rtype = ('int',)    # boolean\n    dst = fenv.tmp()\n    fenv.code.append(('unop' + suffix, op, a1, dst))\n    return rtype, dst\n\n\n# The actual implementation of `pl_comp_expr`.\n# This preserves temporaries while `pl_comp` discards temporaries.\ndef pl_comp_expr_tmp(fenv: Func, node, *, allow_var=False):\n    # read a variable\n    if not isinstance(node, list):\n        return pl_comp_getvar(fenv, node)\n\n    # anything else\n    if len(node) == 0:\n        raise 
ValueError('empty list')\n\n    # constant\n    if len(node) == 2 and node[0] in ('val', 'val8', 'str'):\n        return pl_comp_const(fenv, node)\n    # binary operators\n    binops = {\n        '%', '*', '+', '-', '/',\n        'and', 'or',\n        'eq', 'ge', 'gt', 'le', 'lt', 'ne',\n    }\n    if len(node) == 3 and node[0] in binops:\n        return pl_comp_binop(fenv, node)\n    # unary operators\n    if len(node) == 2 and node[0] in {'-', 'not'}:\n        return pl_comp_unop(fenv, node)\n    # new scope\n    if node[0] in ('do', 'then', 'else'):\n        return pl_comp_scope(fenv, node)\n    # new variable\n    if node[0] == 'var' and len(node) == 3:\n        if not allow_var:\n            # Variable declarations are allowed only as\n            # children of scopes and conditions.\n            raise ValueError('variable declaration not allowed here')\n        return pl_comp_newvar(fenv, node)\n    # update a variable\n    if node[0] == 'set' and len(node) == 3:\n        return pl_comp_setvar(fenv, node)\n    # conditional\n    if len(node) in (3, 4) and node[0] in ('?', 'if'):\n        return pl_comp_cond(fenv, node)\n    # loop\n    if node[0] == 'loop' and len(node) == 3:\n        return pl_comp_loop(fenv, node)\n    # break & continue\n    if node == ['break']:\n        if fenv.scope.loop_end < 0:\n            raise ValueError('`break` outside a loop')\n        fenv.code.append(('jmp', fenv.scope.loop_end))\n        return ('void',), -1\n    if node == ['continue']:\n        if fenv.scope.loop_start < 0:\n            raise ValueError('`continue` outside a loop')\n        fenv.code.append(('jmp', fenv.scope.loop_start))\n        return ('void',), -1\n    # function call\n    if node[0] == 'call' and len(node) >= 2:\n        return pl_comp_call(fenv, node)\n    if node[0] == 'syscall' and len(node) >= 2:\n        return pl_comp_syscall(fenv, node)\n    # return\n    if node[0] == 'return' and len(node) in (1, 2):\n        return pl_comp_return(fenv, node)\n  
  # null pointer\n    if node[0] == 'ptr':\n        tp = validate_type(node)\n        dst = fenv.tmp()\n        fenv.code.append(('const', 0, dst))\n        return tp, dst\n    # cast\n    if node[0] == 'cast' and len(node) == 3:\n        return pl_comp_cast(fenv, node)\n    # peek & poke\n    if node[0] == 'peek' and len(node) == 2:\n        return pl_comp_peek(fenv, node)\n    if node[0] == 'poke' and len(node) == 3:\n        return pl_comp_poke(fenv, node)\n    # ref\n    if node[0] == 'ref' and len(node) == 2:\n        return pl_comp_ref(fenv, node)\n    # debug\n    if node == ['debug']:\n        fenv.code.append(('debug',))\n        return ('void',), -1\n\n    raise ValueError('unknown expression')\n\n\ndef pl_comp_cast(fenv: Func, node):\n    _, tp, value = node\n    tp = validate_type(tp)\n    val_tp, var = pl_comp_expr_tmp(fenv, value)\n\n    # to, from\n    free = [\n        ('int', 'ptr'),\n        ('ptr', 'int'),\n        ('ptr', 'ptr'),\n        ('int', 'byte'),\n        ('int', 'int'),\n        ('byte', 'byte'),\n    ]\n    if (tp[0], val_tp[0]) in free:\n        return tp, var\n    if (tp[0], val_tp[0]) == ('byte', 'int'):\n        fenv.code.append(('cast8', var))\n        return tp, var\n\n    raise ValueError('bad cast')\n\n\ndef pl_comp_peek(fenv: Func, node):\n    _, ptr = node\n    tp, var = pl_comp_expr(fenv, ptr)\n    head, *tail = tp\n    tail = tuple(tail)\n    if head != 'ptr':\n        raise ValueError('not a pointer')\n    suffix = ''\n    if tail == ('byte',):\n        suffix = '8'\n    fenv.code.append(('peek' + suffix, var, fenv.stack))\n    return tail, fenv.tmp()\n\n\ndef pl_comp_poke(fenv: Func, node):\n    _, ptr, value = node\n\n    save = fenv.stack\n    t2, var_val = pl_comp_expr_tmp(fenv, value)\n    t1, var_ptr = pl_comp_expr_tmp(fenv, ptr)\n    if t1 != ('ptr', *t2):\n        raise ValueError('pointer type mismatch')\n    fenv.stack = save\n\n    suffix = ''\n    if t2 == ('byte',):\n        suffix = '8'\n    
fenv.code.append(('poke' + suffix, var_ptr, var_val))\n    return t2, move_to(fenv, var_val, fenv.tmp())\n\n\ndef pl_comp_ref(fenv: Func, node):\n    _, name = node\n\n    flevel, var_tp, var = fenv.get_var(name)\n    dst = fenv.tmp()\n    if flevel == fenv.level:\n        fenv.code.append(('ref_var', var, dst))         # local\n    else:\n        fenv.code.append(('ref_env', flevel, var, dst)) # non-local\n    return ('ptr', *var_tp), dst\n\n\ndef pl_comp_main(fenv: Func, node):\n    assert node[:3] == ['def', ['main', 'int'], []]\n    func = pl_scan_func(fenv, node)\n    return pl_comp_func(func, node)\n\n\ndef pl_comp_return(fenv: Func, node):\n    _, *kid = node\n    tp, var = ('void',), -1\n    if kid:\n        tp, var = pl_comp_expr_tmp(fenv, kid[0])\n    if tp != fenv.rtype:\n        raise ValueError('bad return type')\n    fenv.code.append(('ret', var))\n    return tp, var\n\n\ndef pl_comp_call(fenv: Func, node):\n    _, name, *args = node\n\n    # compile arguments\n    arg_types = []\n    for kid in args:\n        tp, var = pl_comp_expr(fenv, kid)\n        arg_types.append(tp)\n        move_to(fenv, var, fenv.tmp())  # stored continuously\n    fenv.stack -= len(args) # points to the first argument\n\n    # look up the target `Func`\n    key = (name, tuple(arg_types))\n    _, _, idx = fenv.get_var(key)\n    func = fenv.funcs[idx]\n\n    fenv.code.append(('call', idx, fenv.stack, fenv.level, func.level))\n    dst = -1\n    if func.rtype != ('void',):\n        dst = fenv.tmp()    # the return value on the stack top\n    return func.rtype, dst\n\n\ndef pl_comp_scope(fenv: Func, node):\n    fenv.scope_enter()\n    tp, var = ('void',), -1\n\n    # split kids into groups separated by variable declarations\n    groups = [[]]\n    for kid in node[1:]:\n        groups[-1].append(kid)\n        if kid[0] == 'var':\n            groups.append([])\n\n    # Functions are visible before they are defined,\n    # as long as they don't cross a variable declaration.\n    # 
This allows adjacent functions to call each other mutually.\n    for g in groups:\n        # preprocess functions\n        funcs = [\n            pl_scan_func(fenv, kid)\n            for kid in g if kid[0] == 'def' and len(kid) == 4\n        ]\n        # compile subexpressions\n        for kid in g:\n            if kid[0] == 'def' and len(kid) == 4:\n                target, *funcs = funcs\n                tp, var = pl_comp_func(target, kid)\n            else:\n                tp, var = pl_comp_expr(fenv, kid, allow_var=True)\n\n    fenv.scope_leave()\n\n    # the return is either a local variable or a new temporary\n    if var >= fenv.stack:\n        var = move_to(fenv, var, fenv.tmp())\n    return tp, var\n\n\ndef move_to(fenv, var, dst):\n    if dst != var:\n        fenv.code.append(('mov', var, dst))\n    return dst\n\n\ndef pl_comp_newvar(fenv: Func, node):\n    _, name, kid = node\n    # compile the initialization expression\n    tp, var = pl_comp_expr(fenv, kid)\n    if var < 0: # void\n        raise ValueError('bad variable init type')\n    # store the initialization value into the new variable\n    dst = fenv.add_var(name, tp)\n    return tp, move_to(fenv, var, dst)\n\n\ndef pl_comp_setvar(fenv: Func, node):\n    _, name, kid = node\n\n    flevel, dst_tp, dst = fenv.get_var(name)\n    tp, var = pl_comp_expr(fenv, kid)\n    if dst_tp != tp:\n        raise ValueError('bad variable set type')\n\n    if flevel == fenv.level:\n        # local\n        return dst_tp, move_to(fenv, var, dst)\n    else:\n        # non-local\n        fenv.code.append(('set_env', flevel, dst, var))\n        return dst_tp, move_to(fenv, var, fenv.tmp())\n\n\ndef pl_comp_cond(fenv: Func, node):\n    _, cond, yes, *no = node\n    l_true = fenv.new_label()   # then\n    l_false = fenv.new_label()  # else\n    fenv.scope_enter()  # a variable declaration is allowed on the condition\n\n    # the condition expression\n    tp, var = pl_comp_expr(fenv, cond, allow_var=True)\n    if tp == 
('void',):\n        raise ValueError('expect boolean condition')\n    fenv.code.append(('jmpf', var, l_false))    # go to `else` if false\n\n    # then\n    t1, a1 = pl_comp_expr(fenv, yes)\n    if a1 >= 0:\n        # Both `then` and `else` go to the same variable,\n        # thus a temporary is needed.\n        move_to(fenv, a1, fenv.stack)\n\n    # else, optional\n    t2, a2 = ('void',), -1\n    if no:\n        fenv.code.append(('jmp', l_true))   # skip `else` after `then`\n    fenv.set_label(l_false)\n    if no:\n        t2, a2 = pl_comp_expr(fenv, no[0])\n        if a2 >= 0:\n            move_to(fenv, a2, fenv.stack)   # the same variable for `then`\n    fenv.set_label(l_true)\n\n    fenv.scope_leave()\n    if a1 < 0 or a2 < 0 or t1 != t2:\n        return ('void',), -1    # different types, no return value\n    else:\n        return t1, fenv.tmp()   # allocate the temporary for the result\n\n\ndef pl_comp_loop(fenv: Func, node):\n    _, cond, body = node\n    fenv.scope.loop_start = fenv.new_label()\n    fenv.scope.loop_end = fenv.new_label()\n\n    # enter\n    fenv.scope_enter()  # allow_var=True\n    fenv.set_label(fenv.scope.loop_start)\n    # cond\n    _, var = pl_comp_expr(fenv, cond, allow_var=True)\n    if var < 0: # void\n        raise ValueError('bad condition type')\n    fenv.code.append(('jmpf', var, fenv.scope.loop_end))\n    # body\n    _, _ = pl_comp_expr(fenv, body)\n    # loop\n    fenv.code.append(('jmp', fenv.scope.loop_start))\n    # leave\n    fenv.set_label(fenv.scope.loop_end)\n    fenv.scope_leave()\n\n    return ('void',), -1\n\n\n# check for accepted types. 
returns a tuple.\ndef validate_type(tp):\n    if len(tp) == 0:\n        raise ValueError('type missing')\n    head, *body = tp\n    if head == 'ptr':\n        body = validate_type(body)\n        if body == ('void',):\n            raise ValueError('bad pointer element')\n    elif head in ('void', 'int', 'byte'):\n        if body:\n            raise ValueError('bad scalar type')\n    else:\n        raise ValueError('unknown type')\n    return (head, *body)\n\n\n# function preprocessing:\n# make the function visible to the whole scope before its definition.\ndef pl_scan_func(fenv: Func, node):\n    _, (name, *rtype), args, _ = node\n    rtype = validate_type(rtype)\n\n    # add the (name, arg-types) pair to the map\n    arg_type_list = tuple(validate_type(arg_type) for _, *arg_type in args)\n    key = (name, arg_type_list) # allows overloading by argument types\n    if key in fenv.scope.names:\n        raise ValueError('duplicated function')\n    fenv.scope.names[key] = (rtype, len(fenv.funcs))\n\n    # the new function\n    func = Func(fenv)\n    func.rtype = rtype\n    fenv.funcs.append(func)\n    return func\n\n\n# actually compile the function definition.\n# note that the `fenv` argument is the target function!\ndef pl_comp_func(fenv: Func, node):\n    _, _, args, body = node\n\n    # treat arguments as local variables\n    for arg_name, *arg_type in args:\n        if not isinstance(arg_name, str):\n            raise ValueError('bad argument name')\n        arg_type = validate_type(arg_type)\n        if arg_type == ('void',):\n            raise ValueError('bad argument type')\n        fenv.add_var(arg_name, arg_type)\n    assert fenv.stack == len(args)\n\n    # compile the function body\n    body_type, var = pl_comp_expr(fenv, body)\n    if fenv.rtype != ('void',) and fenv.rtype != body_type:\n        raise ValueError('bad body type')\n    if fenv.rtype == ('void',):\n        var = -1\n    fenv.code.append(('ret', var))  # the implicit return\n    return 
('void',), -1\n\n\ndef pl_comp_syscall(fenv: Func, node):\n    _, num, *args = node\n    if isinstance(num, list) and num[0] == 'val':\n        _, num = num\n    if not isinstance(num, int) or num < 0:\n        raise ValueError('bad syscall number')\n\n    save = fenv.stack\n    sys_vars = []\n    for kid in args:\n        arg_tp, var = pl_comp_expr_tmp(fenv, kid)\n        if arg_tp == ('void',):\n            raise ValueError('bad syscall argument type')\n        sys_vars.append(var)\n    fenv.stack = save\n\n    fenv.code.append(('syscall', fenv.stack, num, *sys_vars))\n    return ('int',), fenv.tmp()\n\n\n# execute the program as a ctype function\nclass MemProgram:\n    def __init__(self, code):\n        # copy the code to an executable memory buffer\n        flags = mmap.MAP_PRIVATE|mmap.MAP_ANONYMOUS\n        prot = mmap.PROT_EXEC|mmap.PROT_READ|mmap.PROT_WRITE\n        self.code = mmap.mmap(-1, len(code), flags=flags, prot=prot)\n        self.code[:] = code\n\n        # ctype function: int64_t (*)(void *stack)\n        func_type = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_void_p)\n        cbuf = ctypes.c_void_p.from_buffer(self.code)\n        self.cfunc = func_type(ctypes.addressof(cbuf))\n\n        # create the data stack\n        flags = mmap.MAP_PRIVATE|mmap.MAP_ANONYMOUS\n        prot = mmap.PROT_READ|mmap.PROT_WRITE\n        self.stack = mmap.mmap(-1, 8 << 20, flags=flags, prot=prot)\n        cbuf = ctypes.c_void_p.from_buffer(self.stack)\n        self.stack_addr = ctypes.addressof(cbuf)\n        # TODO: mprotect\n\n    def invoke(self):\n        return self.cfunc(self.stack_addr)\n\n    def close(self):\n        self.code.close()\n        self.stack.close()\n\n\n# execute the program as a ctype function\nclass MemProgramWindows:\n    def __init__(self, code):\n        self.kernel32 = ctypes.CDLL('kernel32', use_last_error=True)\n\n        MEM_COMMIT = 0x00001000\n        MEM_RESERVE = 0x00002000\n        PAGE_READWRITE = 0x04\n        
PAGE_EXECUTE_READWRITE = 0x40\n\n        VirtualAlloc = self.kernel32.VirtualAlloc\n        VirtualAlloc.restype = ctypes.c_void_p\n\n        # copy the code to an executable memory buffer\n        self.code = VirtualAlloc(\n            None, len(code),\n            MEM_COMMIT | MEM_RESERVE,\n            PAGE_EXECUTE_READWRITE,\n        )\n        cbuf = ctypes.c_void_p.from_buffer(code)\n        ctypes.memmove(self.code, ctypes.addressof(cbuf), len(code))\n\n        # ctype function: int64_t (*)(void *stack)\n        func_type = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_void_p)\n        self.cfunc = func_type(self.code)\n\n        # create the data stack\n        self.stack = VirtualAlloc(\n            None, 8 << 20,\n            MEM_COMMIT | MEM_RESERVE,\n            PAGE_READWRITE,\n        )\n        # TODO: mprotect\n\n    def invoke(self):\n        return self.cfunc(self.stack)\n\n    def close(self):\n        MEM_RELEASE = 0x00008000\n\n        VirtualFree = self.kernel32.VirtualFree\n        VirtualFree.argtypes = (ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int)\n        VirtualFree.restype = ctypes.c_bool\n\n        ok = VirtualFree(self.code, 0, MEM_RELEASE)\n        assert ok\n        ok = VirtualFree(self.stack, 0, MEM_RELEASE)\n        assert ok\n\n\n# ELF dissambler:\n# objdump -b binary -M intel,x86-64 -m i386 \\\n#   --adjust-vma=0x1000 --start-address=0x1080 -D ELF_FILE\nclass CodeGen:\n    # register encodings\n    A = 0\n    C = 1\n    D = 2\n    B = 3\n    SP = 4\n    BP = 5\n    SI = 6\n    DI = 7\n\n    def __init__(self):\n        # params\n        self.vaddr = 0x1000     # the virtual address for the program\n        self.alignment = 16\n        # output\n        self.buf = bytearray()\n        # states\n        self.jmps = dict()      # label -> offset list\n        self.calls = dict()     # function index -> offset list\n        self.strings = dict()   # string literal -> offset list\n        self.func2off = []      # func idx -> offset\n  
      self.fields = dict()    # ELF field name -> (size, offset)\n\n    # append a placeholder field\n    def f16(self, name):\n        self.fields[name] = (2, len(self.buf))\n        self.buf.extend(b'\\0\\0')\n    def f32(self, name):\n        self.fields[name] = (4, len(self.buf))\n        self.buf.extend(b'\\0\\0\\0\\0')\n    def f64(self, name):\n        self.fields[name] = (8, len(self.buf))\n        self.buf.extend(b'\\0' * 8)\n\n    # fill in the placeholder\n    def setf(self, name, i):\n        sz, off = self.fields[name]\n        fmt = {2: '<H', 4: '<I', 8: '<Q'}[sz]\n        self.buf[off:off+sz] = struct.pack(fmt, i)\n\n    def elf_begin(self):\n        self.elf_header()\n\n        phdr_start = len(self.buf)  # the program header starts here\n        self.elf_program_header()\n        # program header size\n        self.setf('e_phentsize', len(self.buf) - phdr_start)\n        # number of program headers: 1\n        self.setf('e_phnum', 1)\n\n        self.padding()\n        # the entry point: the virtual address where the program start\n        self.setf('e_entry', self.vaddr + len(self.buf))\n\n    def elf_header(self):\n        # ref: https://www.muppetlabs.com/~breadbox/software/tiny/tiny-elf64.asm.txt\n        self.buf.extend(bytes.fromhex('7F 45 4C 46 02 01 01 00'))\n        self.buf.extend(bytes.fromhex('00 00 00 00 00 00 00 00'))\n        # e_type, e_machine, e_version\n        self.buf.extend(bytes.fromhex('02 00 3E 00 01 00 00 00'))\n        self.f64('e_entry')\n        self.f64('e_phoff')\n        self.f64('e_shoff')\n        self.f32('e_flags')\n        self.f16('e_ehsize')\n        self.f16('e_phentsize')\n        self.f16('e_phnum')\n        self.f16('e_shentsize')\n        self.f16('e_shnum')\n        self.f16('e_shstrndx')\n        self.setf('e_phoff', len(self.buf))     # offset of the program header\n        self.setf('e_ehsize', len(self.buf))    # size of the ELF header\n\n    def elf_program_header(self):\n        # p_type, p_flags\n  
      self.buf.extend(bytes.fromhex('01 00 00 00 05 00 00 00'))\n        # p_offset\n        self.i64(0)\n        # p_vaddr, p_paddr\n        self.i64(self.vaddr)\n        self.i64(self.vaddr)    # useless\n        self.f64('p_filesz')\n        self.f64('p_memsz')\n        # p_align\n        self.i64(0x1000)\n\n    # compile the program to an ELF executable\n    def output_elf(self, root: Func):\n        # ELF header + program header\n        self.elf_begin()\n        # machine code\n        self.code_entry()\n        for func in root.funcs:\n            self.func(func)\n        self.code_end()\n        # fill in some ELF fields\n        self.elf_end()\n\n    def elf_end(self):\n        # fields in program header:\n        # the size of the mapping. we're mapping the whole file here.\n        self.setf('p_filesz', len(self.buf))\n        self.setf('p_memsz', len(self.buf))\n\n    def create_stack(self, data):\n        def operand(i):\n            return struct.pack('<i', i)\n\n        # syscall ref: https://blog.rchapman.org/posts/Linux_System_Call_Table_for_x86_64/\n        # syscall abi: https://github.com/torvalds/linux/blob/v5.0/arch/x86/entry/entry_64.S#L107\n        # mmap\n        self.buf.extend(\n            b\"\\xb8\\x09\\x00\\x00\\x00\"     # mov eax, 9\n            # b\"\\x31\\xff\"                 # xor edi, edi      // addr = NULL\n            b\"\\xbf\\x00\\x10\\x00\\x00\"     # mov edi, 4096     // addr\n            b\"\\x48\\xc7\\xc6%s\"           # mov rsi, xxx      // len\n            b\"\\xba\\x03\\x00\\x00\\x00\"     # mov edx, 3        // prot = PROT_READ|PROT_WRITE\n            b\"\\x41\\xba\\x22\\x00\\x00\\x00\" # mov r10d, 0x22    // flags = MAP_PRIVATE|MAP_ANONYMOUS\n            b\"\\x49\\x83\\xc8\\xff\"         # or r8, -1         // fd = -1\n            b\"\\x4d\\x31\\xc9\"             # xor r9, r9        // offset = 0\n            b\"\\x0f\\x05\"                 # syscall\n            b\"\\x48\\x89\\xc3\"             # mov rbx, rax      
// the data stack\n            % operand(data + 4096)\n        )\n\n        # mprotect\n        self.buf.extend(\n            b\"\\xb8\\x0a\\x00\\x00\\x00\"     # mov eax, 10\n            b\"\\x48\\x8d\\xbb%s\"           # lea rdi, [rbx + data]\n            b\"\\xbe\\x00\\x10\\x00\\x00\"     # mov esi, 4096\n            b\"\\x31\\xd2\"                 # xor edx, edx\n            b\"\\x0f\\x05\"                 # syscall\n            % operand(data)\n        )\n        # FIXME: check the syscall return value\n\n    def code_entry(self):\n        # create the data stack (8M)\n        self.create_stack(0x800000)\n        # call the main function\n        self.asm_call(0)\n        # exit\n        self.buf.extend(\n            b\"\\xb8\\x3c\\x00\\x00\\x00\"     # mov eax, 60\n            b\"\\x48\\x8b\\x3b\"             # mov rdi, [rbx]\n            b\"\\x0f\\x05\"                 # syscall\n        )\n\n    # easier to find things in hexdump\n    def padding(self):\n        if self.alignment == 0:\n            return\n        self.buf.append(0xcc)   # int3\n        while len(self.buf) % self.alignment:\n            self.buf.append(0xcc)\n\n    # compile to a callable function\n    def output_mem(self, root: Func):\n        self.mem_entry()\n        for func in root.funcs:\n            self.func(func)\n        self.code_end()\n\n    # C function: int64_t (*)(void *stack)\n    def mem_entry(self):\n        # the first argument is the data stack\n        self.buf.extend(b\"\\x53\")            # push rbx\n        system = platform.system()\n        if system == 'Windows' or system.startswith('CYGWIN'):\n            self.buf.extend(b\"\\x48\\x89\\xCB\")    # mov rbx, rcx\n        else:\n            self.buf.extend(b\"\\x48\\x89\\xFB\")    # mov rbx, rdi\n        # call the main function\n        self.asm_call(0)\n        # the return value\n        self.buf.extend(b\"\\x48\\x8b\\x03\")    # mov rax, [rbx]\n        self.buf.extend(b\"\\x5b\")            # pop rbx\n        
self.buf.extend(b\"\\xc3\")            # ret\n\n    # compile a function\n    def func(self, func: Func):\n        self.padding()\n\n        # offsets\n        self.func2off.append(len(self.buf)) # function index -> code offset\n        pos2off = []    # virtual instruction -> code offset\n\n        # call the method for each instruction\n        for instr_name, *instr_args in func.code:\n            pos2off.append(len(self.buf))\n            method = getattr(self.__class__, instr_name)\n            method(self, *instr_args)\n\n        # fill in the jmp address\n        for L, off_list in self.jmps.items():\n            dst_off = pos2off[func.labels[L]]\n            for patch_off in off_list:\n                self.patch_addr(patch_off, dst_off)\n        self.jmps.clear()\n\n    # fill in a 4-byte `rip` relative offset\n    def patch_addr(self, patch_off, dst_off):\n        src_off = patch_off + 4     # rip\n        relative = struct.pack('<i', dst_off - src_off)\n        self.buf[patch_off:patch_off+4] = relative\n\n    def code_end(self):\n        # fill in the call address\n        for L, off_list in self.calls.items():\n            dst_off = self.func2off[L]\n            for patch_off in off_list:\n                self.patch_addr(patch_off, dst_off)\n        self.calls.clear()\n        self.padding()\n        # strings\n        for s, off_list in self.strings.items():\n            dst_off = len(self.buf)\n            for patch_off in off_list:\n                self.patch_addr(patch_off, dst_off)\n            self.buf.extend(s.encode('utf-8') + b'\\0')\n        self.strings.clear()\n\n    # append a signed integer\n    def i8(self, i):\n        self.buf.append(i if i >= 0 else (256 + i))\n    def i32(self, i):\n        self.buf.extend(struct.pack('<i', i))\n    def i64(self, i):\n        self.buf.extend(struct.pack('<q', i))\n\n    # instr reg, [rm + disp]\n    # instr [rm + disp], reg\n    def asm_disp(self, lead, reg, rm, disp):\n        assert reg < 16 and rm 
< 16 and rm != CodeGen.SP\n\n        lead = bytearray(lead)  # optional prefix + opcode\n        if reg >= 8 or rm >= 8:\n            assert (lead[0] >> 4) == 0b0100 # REX\n            lead[0] |= (reg >> 3) << 2      # REX.R\n            lead[0] |= (rm >> 3) << 0       # REX.B\n            reg &= 0b111\n            rm &= 0b111\n\n        self.buf.extend(lead)\n        if disp == 0:\n            mod = 0     # [rm]\n        elif -128 <= disp < 128:\n            mod = 1     # [rm + disp8]\n        else:\n            mod = 2     # [rm + disp32]\n        self.buf.append((mod << 6) | (reg << 3) | rm)  # ModR/M\n        if mod == 1:\n            self.i8(disp)\n        if mod == 2:\n            self.i32(disp)\n\n    # mov reg, [rm + disp]\n    def asm_load(self, reg, rm, disp):\n        self.asm_disp(b'\\x48\\x8b', reg, rm, disp)\n\n    # mov [rm + disp], reg\n    def asm_store(self, rm, disp, reg):\n        self.asm_disp(b'\\x48\\x89', reg, rm, disp)\n\n    def store_rax(self, dst):\n        # mov [rbx + dst*8], rax\n        self.asm_store(CodeGen.B, dst * 8, CodeGen.A)\n\n    def load_rax(self, src):\n        # mov rax, [rbx + src*8]\n        self.asm_load(CodeGen.A, CodeGen.B, src * 8)\n\n    def const(self, val, dst):\n        assert isinstance(val, (int, str))\n        if isinstance(val, str):\n            # lea rax, [rip + offset]\n            self.buf.extend(b\"\\x48\\x8d\\x05\")\n            self.strings.setdefault(val, []).append(len(self.buf))\n            self.buf.extend(b\"\\0\\0\\0\\0\")\n        elif val == 0:\n            self.buf.extend(b\"\\x31\\xc0\")            # xor eax, eax\n        elif val == -1:\n            self.buf.extend(b\"\\x48\\x83\\xc8\\xff\")    # or rax, -1\n        elif (val >> 31) == 0:\n            self.buf.extend(b\"\\xb8\")                # mov eax, imm32\n            self.i32(val)\n        elif (val >> 31) == -1:\n            # sign-extended\n            self.buf.extend(b\"\\x48\\xc7\\xc0\")        # mov rax, imm32\n            
self.i32(val)\n        else:\n            self.buf.extend(b\"\\x48\\xb8\")            # mov rax, imm64\n            self.i64(val)\n        self.store_rax(dst)\n\n    def mov(self, src, dst):\n        if src == dst:\n            return\n        self.load_rax(src)\n        self.store_rax(dst)\n\n    def binop(self, op, a1, a2, dst):\n        self.load_rax(a1)\n\n        arith = {\n            '+': b'\\x48\\x03',       # add  reg, rm\n            '-': b'\\x48\\x2b',       # sub  reg, rm\n            '*': b'\\x48\\x0f\\xaf',   # imul reg, rm\n        }\n        cmp = {\n            'eq': b'\\x0f\\x94\\xc0',  # sete  al\n            'ne': b'\\x0f\\x95\\xc0',  # setne al\n            'ge': b'\\x0f\\x9d\\xc0',  # setge al\n            'gt': b'\\x0f\\x9f\\xc0',  # setg  al\n            'le': b'\\x0f\\x9e\\xc0',  # setle al\n            'lt': b'\\x0f\\x9c\\xc0',  # setl  al\n        }\n\n        if op in ('/', '%'):\n            # cqo: sign-extend rax into rdx:rax, since idiv is a signed divide\n            self.buf.extend(b\"\\x48\\x99\")\n            # idiv rax, [rbx + a2*8]\n            self.buf.extend(b'\\x48\\xf7\\xbb')\n            self.i32(a2 * 8)\n            if op == '%':\n                # mov rax, rdx\n                self.buf.extend(b\"\\x48\\x89\\xd0\")\n        elif op in arith:\n            # op rax, [rbx + a2*8]\n            self.asm_disp(arith[op], CodeGen.A, CodeGen.B, a2 * 8)\n        elif op in cmp:\n            # cmp rax, [rbx + a2*8]\n            self.asm_disp(b'\\x48\\x3b', CodeGen.A, CodeGen.B, a2 * 8)\n            # setcc al\n            self.buf.extend(cmp[op])\n            # movzx eax, al\n            self.buf.extend(b\"\\x0f\\xb6\\xc0\")\n        elif op == 'and':\n            self.buf.extend(\n                b\"\\x48\\x85\\xc0\"     # test rax, rax\n                b\"\\x0f\\x95\\xc0\"     # setne al\n            )\n            # mov rdx, [rbx + a2*8]\n            self.asm_load(CodeGen.D, CodeGen.B, a2 * 8)\n            self.buf.extend(\n                b\"\\x48\\x85\\xd2\"     # test rdx, rdx\n 
               b\"\\x0f\\x95\\xc2\"     # setne dl\n                b\"\\x21\\xd0\"         # and eax, edx\n                b\"\\x0f\\xb6\\xc0\"     # movzx eax, al\n            )\n        elif op == 'or':\n            # or rax, [rbx + a2*8]\n            self.asm_disp(b\"\\x48\\x0b\", CodeGen.A, CodeGen.B, a2 * 8)\n            self.buf.extend(\n                b\"\\x0f\\x95\\xc0\"     # setne al\n                b\"\\x0f\\xb6\\xc0\"     # movzx eax, al\n            )\n        else:\n            raise NotImplementedError\n\n        self.store_rax(dst)\n\n    def unop(self, op, a1, dst):\n        self.load_rax(a1)\n        if op == '-':\n            self.buf.extend(b\"\\x48\\xf7\\xd8\")    # neg rax\n        elif op == 'not':\n            self.buf.extend(\n                b\"\\x48\\x85\\xc0\"     # test rax, rax\n                b\"\\x0f\\x94\\xc0\"     # sete al\n                b\"\\x0f\\xb6\\xc0\"     # movzx eax, al\n            )\n        else:\n            raise NotImplementedError\n        self.store_rax(dst)\n\n    def jmpf(self, a1, L):\n        self.load_rax(a1)\n        self.buf.extend(\n            b\"\\x48\\x85\\xc0\"         # test rax, rax\n            b\"\\x0f\\x84\"             # je\n        )\n        self.jmps.setdefault(L, []).append(len(self.buf))\n        self.buf.extend(b'\\0\\0\\0\\0')\n\n    def jmp(self, L):\n        self.buf.extend(b\"\\xe9\")    # jmp\n        self.jmps.setdefault(L, []).append(len(self.buf))\n        self.buf.extend(b'\\0\\0\\0\\0')\n\n    def asm_call(self, L):\n        self.buf.extend(b\"\\xe8\")    # call\n        self.calls.setdefault(L, []).append(len(self.buf))\n        self.buf.extend(b'\\0\\0\\0\\0')\n\n    def call(self, func, arg_start, level_cur, level_new):\n        assert 1 <= level_cur\n        assert 1 <= level_new <= level_cur + 1\n\n        # put a list of pointers to outer frames in the `rsp` stack\n        if level_new > level_cur:\n            # grow the list by one\n            self.buf.append(0x53)   
            # push rbx\n        for _ in range(min(level_new, level_cur) - 1):\n            # copy the previous list\n            self.buf.extend(b\"\\xff\\xb4\\x24\")    # push [rsp + (level_new-1)*8]\n            self.i32((level_new - 1) * 8)\n\n        # make a new frame and call the target\n        if arg_start != 0:\n            self.buf.extend(b\"\\x48\\x81\\xc3\")    # add rbx, arg_start*8\n            self.i32(arg_start * 8)\n        self.asm_call(func)                     # call func\n        if arg_start != 0:\n            self.buf.extend(b\"\\x48\\x81\\xc3\")    # add rbx, -arg_start*8\n            self.i32(-arg_start * 8)\n\n        # discard the list of pointers\n        self.buf.extend(b\"\\x48\\x81\\xc4\")        # add rsp, (level_new - 1)*8\n        self.i32((level_new - 1) * 8)\n\n    def ret(self, a1):\n        if a1 > 0:\n            self.load_rax(a1)\n            self.store_rax(0)\n        self.buf.append(0xc3)       # ret\n\n    def load_env_addr(self, level_var):\n        self.buf.extend(b\"\\x48\\x8b\\x84\\x24\")    # mov rax, [rsp + level_var*8]\n        self.i32(level_var * 8)\n\n    def get_env(self, level_var, var, dst):\n        self.load_env_addr(level_var)\n        # mov rax, [rax + var*8]\n        self.asm_load(CodeGen.A, CodeGen.A, var * 8)\n        # mov [rbx + dst*8], rax\n        self.store_rax(dst)\n\n    def set_env(self, level_var, var, src):\n        self.load_env_addr(level_var)\n        # mov rdx, [rbx + src*8]\n        self.asm_load(CodeGen.D, CodeGen.B, src * 8)\n        # mov [rax + var*8], rdx\n        self.asm_store(CodeGen.A, var * 8, CodeGen.D)\n\n    def lea(self, a1, a2, scale, dst):\n        self.load_rax(a1)\n        self.asm_load(CodeGen.D, CodeGen.B, a2 * 8) # mov rdx, [rbx + a2*8]\n        if scale < 0:\n            self.buf.extend(b\"\\x48\\xf7\\xda\")        # neg rdx\n        self.buf.extend({\n            1: b\"\\x48\\x8d\\x04\\x10\",                 # lea rax, [rax + rdx]\n            2: 
b\"\\x48\\x8d\\x04\\x50\",                 # lea rax, [rax + rdx*2]\n            4: b\"\\x48\\x8d\\x04\\x90\",                 # lea rax, [rax + rdx*4]\n            8: b\"\\x48\\x8d\\x04\\xd0\",                 # lea rax, [rax + rdx*8]\n        }[abs(scale)])\n        self.store_rax(dst)\n\n    def peek(self, var, dst):\n        self.load_rax(var)\n        # mov rax, [rax]\n        self.asm_load(CodeGen.A, CodeGen.A, 0)\n        self.store_rax(dst)\n\n    def peek8(self, var, dst):\n        self.load_rax(var)\n        # movzx eax, byte ptr [rax]\n        self.buf.extend(b\"\\x0f\\xb6\\x00\")\n        self.store_rax(dst)\n\n    def poke(self, ptr, val):\n        self.load_rax(val)\n        # mov rdx, [rbx + ptr*8]\n        self.asm_load(CodeGen.D, CodeGen.B, ptr * 8)\n        # mov [rdx], rax\n        self.asm_store(CodeGen.D, 0, CodeGen.A)\n\n    def poke8(self, ptr, val):\n        self.load_rax(val)\n        # mov rdx, [rbx + ptr*8]\n        self.asm_load(CodeGen.D, CodeGen.B, ptr * 8)\n        # mov [rdx], al\n        self.buf.extend(b\"\\x88\\x02\")\n\n    def ref_var(self, var, dst):\n        # lea rax, [rbx + var*8]\n        self.buf.extend(b\"\\x48\\x8D\\x83\")\n        self.i32(var * 8)\n        self.store_rax(dst)\n\n    def ref_env(self, level_var, var, dst):\n        # mov rax, [rsp + level_var*8]\n        self.load_env_addr(level_var)\n        # add rax, var*8\n        self.buf.extend(b\"\\x48\\x05\")\n        self.i32(var * 8)\n        self.store_rax(dst)\n\n    def cast8(self, var):\n        # and qword ptr [rbx + var*8], 0xff\n        self.asm_disp(b\"\\x48\\x81\", 4, CodeGen.B, var * 8)\n        self.i32(0xff)\n\n    def syscall(self, dst, num, *arg_list):\n        # syscall ref: https://blog.rchapman.org/posts/Linux_System_Call_Table_for_x86_64/\n        self.buf.extend(b\"\\xb8\")                # mov eax, imm32\n        self.i32(num)\n        arg_regs = [CodeGen.DI, CodeGen.SI, CodeGen.D, 10, 8, 9]\n        assert len(arg_list) <= len(arg_regs)\n  
      for i, arg in enumerate(arg_list):\n            # mov reg, [rbx + arg*8]\n            self.asm_load(arg_regs[i], CodeGen.B, arg * 8)\n        self.buf.extend(b\"\\x0f\\x05\")            # syscall\n        self.store_rax(dst)                     # mov [rbx + dst*8], rax\n\n    def debug(self):\n        self.buf.append(0xcc)                   # int3\n\n\n# ir\n'''\nconst val dst\nmov src dst\nbinop op a1 a2 dst\nunop op a1 dst\nbinop8 op a1 a2 dst\nunop8 op a1 dst\njmpf a1 L\njmp L\nret a1\nret -1\ncall func arg_start level_cur level_new\nget_env level_var var dst\nset_env level_var var src\nref_var var dst\nref_env level_var var dst\nlea\npeek\npoke\npeek8\npoke8\ncast8\nsyscall\ndebug\n'''\n\n\n# syntax\n'''\n(+ a b)\n(- a b)\n(* a b)\n(/ a b)\n...\n\n(eq a b)\n(ne a b)\n(ge a b)\n(gt a b)\n(le a b)\n(lt a b)\n\n(not b)\n(and a b)\n(or a b)\n\n(? cond yes no)\n(if cond (then yes blah blah) (else no no no))\n(do a b c...)\n(var name val)\n(set name val)\n(loop cond body)\n(break)\n(continue)\n\n(def (name rtype) ((a1 a1type) (a2 a2type)...) 
body)\n(call f a b c...)\n(return val)\n\n(ptr elem_type)\n(peek ptr)\n(poke ptr value)\n(ref name)\n(syscall num args...)\n(cast type val)\n'''\n\n\n# types\n'''\nvoid\nint\nbyte\nptr int\nptr byte\n'''\n\n\ndef ir_dump(root: Func):\n    out = []\n    for i, func in enumerate(root.funcs):\n        out.append(f'func{i}:')\n        pos2labels = dict()\n        for label, pos in enumerate(func.labels):\n            pos2labels.setdefault(pos, []).append(label)\n        for pos, instr in enumerate(func.code):\n            for label in pos2labels.get(pos, []):\n                out.append(f'L{label}:')\n            if instr[0].startswith('jmp'):\n                instr = instr[:-1] + (f'L{instr[-1]}',)\n            if instr[0] == 'const' and isinstance(instr[1], str):\n                import json\n                instr = list(instr)\n                instr[1] = json.dumps(instr[1])\n            out.append('    ' + ' '.join(map(str, instr)))\n        out.append('')\n\n    return '\\n'.join(out)\n\n\ndef test_comp():\n    def f(s):\n        node = pl_parse_main(s)\n        fenv = Func(None)\n        pl_comp_main(fenv, node)\n        return [x.code for x in fenv.funcs]\n\n    def asm(s):\n        node = pl_parse_main(s)\n        fenv = Func(None)\n        pl_comp_main(fenv, node)\n        return ir_dump(fenv)\n\n    assert f('1') == [[\n        ('const', 1, 0),\n        ('ret', 0),\n    ]]\n    assert f('1 3') == [[\n        ('const', 1, 0),\n        ('const', 3, 0),\n        ('ret', 0),\n    ]]\n    assert f('(+ (- 1 2) 3)') == [[\n        ('const', 1, 0),\n        ('const', 2, 1),\n        ('binop', '-', 0, 1, 0),\n        ('const', 3, 1),\n        ('binop', '+', 0, 1, 0),\n        ('ret', 0),\n    ]]\n    assert f('(return 1)') == [[\n        ('const', 1, 0),\n        ('ret', 0),\n        ('ret', 0),\n    ]]\n    assert asm('(if 1 2 3)').split() == '''\n        func0:\n            const 1 0\n            jmpf 0 L1\n            const 2 0\n            jmp L0\n        L1:\n    
        const 3 0\n        L0:\n            ret 0\n    '''.split()\n    assert asm('''\n        (loop (var a 1) (do\n            (var b a)\n            (if (gt a 11)\n                (break))\n            (var c (set a (+ 2 b)))\n            (if (lt c 100)\n                (continue))\n            (set b 5)\n        ))\n        0''').split() == '''\n        func0:\n        L0:\n            const 1 0\n            jmpf 0 L1\n            mov 0 1\n            const 11 2\n            binop gt 0 2 2\n            jmpf 2 L3\n            jmp L1\n        L2:\n        L3:\n            const 2 2\n            binop + 2 1 2\n            mov 2 0\n            mov 0 2\n            const 100 3\n            binop lt 2 3 3\n            jmpf 3 L5\n            jmp L0\n        L4:\n        L5:\n            const 5 3\n            mov 3 1\n            jmp L0\n        L1:\n            const 0 0\n            ret 0\n    '''.split()\n    assert asm('(if 1 (return 2)) 0').split() == '''\n        func0:\n            const 1 0\n            jmpf 0 L1\n            const 2 0\n            ret 0\n        L0:\n        L1:\n            const 0 0\n            ret 0\n    '''.split()\n    assert asm('(var a 1) (set a (+ 3 a)) (var b 2) (- b a)').split() == '''\n        func0:\n            const 1 0\n            const 3 1\n            binop + 1 0 1\n            mov 1 0\n            const 2 1\n            binop - 1 0 2\n            mov 2 0\n            ret 0\n    '''.split()\n    assert asm('(var a 1) (return (+ 3 a))').split() == '''\n        func0:\n            const 1 0\n            const 3 1\n            binop + 1 0 1\n            ret 1\n            mov 1 0\n            ret 0\n    '''.split()\n    assert asm('(var a 1) (+ 3 a)').split() == '''\n        func0:\n            const 1 0\n            const 3 1\n            binop + 1 0 1\n            mov 1 0\n            ret 0\n    '''.split()\n    assert asm('''\n        (def (fib int) ((n int)) (if (le n 0) (then 0) (else (call fib (- n 1)))))\n        (call 
fib 5)\n        ''').split() == '''\n        func0:\n            const 5 0\n            call 1 0 1 2\n            ret 0\n        func1:\n            const 0 1\n            binop le 0 1 1\n            jmpf 1 L1\n            const 0 1\n            jmp L0\n        L1:\n            const 1 1\n            binop - 0 1 1\n            call 1 1 2 2\n        L0:\n            ret 1\n    '''.split()\n    assert asm('''\n        (var b 456)\n        (def (f void) () (do\n            (var a 123)\n            (def (g void) () (do\n                (set a (+ b a))\n            ))\n            (call g)\n        ))\n\n        (call f)\n        0\n        ''').split() == '''\n        func0:\n            const 456 0\n            call 1 1 1 2\n            const 0 1\n            mov 1 0\n            ret 0\n        func1:\n            const 123 0\n            call 2 1 2 3\n            ret -1\n        func2:\n            get_env 1 0 0\n            get_env 2 0 1\n            binop + 0 1 0\n            set_env 2 0 0\n            ret -1\n    '''.split()\n    assert asm('''\n        (var p (ptr int))\n        (poke (cast (ptr byte) p) 124u8)\n        (peek (cast (ptr byte) p))\n        (poke p 123)\n    ''').split() == '''\n        func0:\n            const 0 0\n            const 124 1\n            poke8 0 1\n            peek8 0 1\n            const 123 1\n            poke 0 1\n            mov 1 0\n            ret 0\n    '''.split()\n\n\ndef main():\n    # args\n    import argparse\n    ap = argparse.ArgumentParser()\n    ap.add_argument('file', nargs='?', help='the input source file')\n    ap.add_argument('--exec', action='store_true', help='compile to memory and execute it')\n    ap.add_argument('-o', '--output', help='the output path')\n    ap.add_argument('--print-ir', action='store_true', help='print the intermediate representation')\n    ap.add_argument('--alignment', type=int, default=16)\n    ap.add_argument('--vaddr', type=int, default=0x1000, help='the virtual address for the 
program')\n    args = ap.parse_args()\n    if not (args.file or args.output or args.exec):\n        ap.print_help()\n        test_comp()\n        return\n\n    # source text\n    with open(args.file, 'rt', encoding='utf-8') as fp:\n        text = fp.read()\n\n    # parse & compile\n    node = pl_parse_main(text)\n    root = Func(None)\n    _ = pl_comp_main(root, node)\n    if args.print_ir:\n        print(ir_dump(root))\n\n    # generate output\n    if args.output:\n        gen = CodeGen()\n        gen.vaddr = args.vaddr\n        gen.alignment = args.alignment\n        gen.output_elf(root)\n        fd = os.open(args.output, os.O_WRONLY|os.O_CREAT|os.O_TRUNC, 0o755)\n        with os.fdopen(fd, 'wb', closefd=True) as fp:\n            fp.write(gen.buf)\n\n    # execute\n    if args.exec:\n        gen = CodeGen()\n        gen.alignment = args.alignment\n        gen.output_mem(root)\n        if platform.system() == 'Windows':\n            prog = MemProgramWindows(gen.buf)\n        else:\n            prog = MemProgram(gen.buf)\n        try:\n            sys.exit(prog.invoke())\n        finally:\n            prog.close()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "samples/42.txt",
    "content": "(def (foo int) ((n int))\n    (if (le n 0)\n        (then 1)\n        (else (* n (call foo (- n 1))))))\n\n(/ (call foo 7) (call foo 5))\n"
  },
  {
    "path": "samples/fizzbuzz.txt",
    "content": "\n(def (mmap ptr byte) ((n int)) (do\n    (var prot 3)            ;; PROT_READ|PROT_WRITE\n    (var flags 0x22)        ;; MAP_PRIVATE|MAP_ANONYMOUS\n    (var fd -1)\n    (var offset 0)\n    (var r (syscall 9 0 n prot flags fd offset))\n    (return (cast (ptr byte) r))\n))\n\n(var g_str_buf (call mmap 24))\n\n;; convert a number to string\n(def (str ptr byte) ((i int)) (do\n    ;; number to digits\n    (var s g_str_buf)\n    (loop i (do\n        (var d (+ 48 (% i 10)))\n        (set i (/ i 10))\n        (poke s (cast (byte) d))\n        (set s (+ 1 s))\n    ))\n    (poke s '\\u0000')\n\n    ;; reverse the digits\n    (set s (- s 1))\n    (var l g_str_buf)\n    (loop (lt l s) (do\n        (var t (peek l))\n        (poke l (peek s))\n        (poke s t)\n        (set l (+ l 1))\n        (set s (- s 1))\n    ))\n    (return g_str_buf)\n))\n\n(def (strlen int) ((s ptr byte)) (do\n    (var start s)\n    (loop (peek s) (set s (+ 1 s)))\n    (return (- s start))\n))\n\n(def (print void) ((s ptr byte)) (do\n    (syscall 1 1 s (call strlen s))\n))\n(def (print void) ((i int)) (do\n    (call print (call str i))\n))\n\n;; a stupid way of fizzbuzzing.\n;; notice the mutual function calls.\n(def (fizz void) ((n int)) (do\n    (call number (- n 1))\n    (call print (? (not (% n 5)) \"fizzbuzz\\n\" \"fizz\\n\"))\n))\n(def (buzz void) ((n int)) (do\n    (call number (- n 1))\n    (call print \"buzz\\n\")\n))\n(def (number void) ((n int)) (do\n    (if (le n 0) (return))\n    (if (not (% n 3)) (return (call fizz n)))\n    (if (not (% n 5)) (return (call buzz n)))\n    (call number (- n 1))\n    (call print n)\n    (call print \"\\n\")\n))\n\n(call number 101)\n0\n"
  },
  {
    "path": "samples/hello.txt",
    "content": ";; the write() syscall:\n;; ssize_t write(int fd, const void *buf, size_t count);\n(syscall 1 1 \"Hello world!\\n\" 13)\n0\n"
  },
  {
    "path": "samples/malloc_and_strings.txt",
    "content": "\n(var heap (ptr byte))\n\n; a fake malloc\n(def (malloc ptr byte) ((n int)) (do\n    (if (not heap) (do\n        ; create the heap via mmap()\n        (var heapsz 1048576)    ; 1M\n        (var prot 3)            ; PROT_READ|PROT_WRITE\n        (var flags 0x22)        ; MAP_PRIVATE|MAP_ANONYMOUS\n        (var fd -1)\n        (var offset 0)\n        (var r (syscall 9 0 heapsz prot flags fd offset))\n        (set heap (cast (ptr byte) r))\n    ))\n    ; just move the heap pointer forward\n    (var r heap)\n    (set heap (+ n heap))\n    (return r)\n))\n\n; never free anything\n(def (free void) ((p ptr byte)) (do))\n\n; allocate a new string.\n; the length and capacity are stored before the string data.\n; | len | cap | data\n; |  8  |  8  | ....\n(def (strnew ptr byte) ((cap int)) (do\n    (var addr (call malloc (+ 16 cap)))\n    (var iaddr (cast (ptr int) addr))\n    (poke iaddr 0)\n    (poke (+ 1 iaddr) cap)\n    (return (+ 16 addr))\n))\n\n; free the string\n(def (strdel void) ((s ptr byte)) (do\n    (call free (- s 16))\n))\n\n; access the len and the cap\n(def (strlen int) ((s ptr byte)) (do\n    (var iaddr (cast (ptr int) s))\n    (return (peek (- iaddr 2)))\n))\n(def (strcap int) ((s ptr byte)) (do\n    (var iaddr (cast (ptr int) s))\n    (return (peek (- iaddr 1)))\n))\n\n; copy data byte by byte\n(def (memcpy void) ((dst ptr byte) (src ptr byte) (n int)) (do\n    (loop n (do\n        (poke dst (peek src))\n        (set dst (+ 1 dst))\n        (set src (+ 1 src))\n        (set n (- n 1))\n    ))\n))\n\n; append a character to a string, growing it if necessary.\n(def (append ptr byte) ((s ptr byte) (ch byte)) (do\n    (var len (call strlen s))\n    (var cap (call strcap s))\n    (if (eq len cap) (do\n        ; create a new string with double the capacity\n        (set cap (* 2 cap))\n        (if (lt cap 8) (set cap 8))\n        (var new (call strnew cap))\n        ; copy the data to the new string and replace the old one\n        (call memcpy 
new s len)\n        (call strdel s)\n        (set s new)\n    ))\n    ; write the character\n    (poke (+ len s) ch)\n    ; update the length field\n    (poke (cast (ptr int) (- s 16)) (+ 1 len))\n    (return s)\n))\n\n; print a string to stdout\n(def (print void) ((s ptr byte)) (do\n    (var len (call strlen s))\n    (syscall 1 1 s len)\n))\n\n; reverse a string in place\n(def (strrev ptr byte) ((s ptr byte)) (do\n    (var l s)\n    (var r (- (+ s (call strlen s)) 1))\n    (loop (lt l r) (do\n        (var t (peek l))\n        (poke l (peek r))\n        (poke r t)\n        (set l (+ l 1))\n        (set r (- r 1))\n    ))\n    (return s)\n))\n\n; convert an int to string\n; FIXME: negative numbers\n(def (str ptr byte) ((i int)) (do\n    (var s (call strnew 24))\n    (if (eq 0 i) (call append s '0'))\n    (loop i (do\n        (var d (+ 48 (% i 10)))\n        (set i (/ i 10))\n        (set s (call append s (cast (byte) d)))\n    ))\n    (call strrev s)\n    (return s)\n))\n\n; print an int\n(def (print void) ((i int)) (do\n    (var s (call str i))\n    (call print s)\n    (call strdel s)\n))\n\n; hello world\n(var s (call strnew 0))\n(set s (call append s 72u8))\n(set s (call append s 101u8))\n(set s (call append s 108u8))\n(set s (call append s 108u8))\n(set s (call append s 111u8))\n\n(set s (call append s '_'))\n(set s (call append s 'w'))\n(set s (call append s 'o'))\n(set s (call append s 'r'))\n(set s (call append s 'l'))\n(set s (call append s 'd'))\n(set s (call append s 10u8))\n\n(call print s)\n(call print (call strlen s))\n(call print (call strcap s))\n(call print s)\n\n; return\n0\n"
  },
  {
    "path": "samples/null_deref.txt",
    "content": "(peek (cast (ptr int) 0))   ;; dereference the NULL pointer\n"
  },
  {
    "path": "samples/print.txt",
    "content": "(def (strlen int) ((s ptr byte)) (do\n    (var start s)\n    (loop (peek s) (set s (+ 1 s)))\n    (return (- s start))\n))\n\n(def (print void) ((s ptr byte)) (do\n    (syscall 1 1 s (call strlen s))\n))\n\n(call print \"Yes!\\n\")\n0\n"
  }
]