Repository: exaloop/codon Branch: develop Commit: f1cd7006ad65 Files: 612 Total size: 6.3 MB Directory structure: gitextract_o7k99b34/ ├── .clang-format ├── .clang-tidy ├── .gitattributes ├── .github/ │ ├── dependabot.yml │ └── workflows/ │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CMakeLists.txt ├── CODEOWNERS ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── bench/ │ ├── README.md │ ├── codon/ │ │ ├── binary_trees.codon │ │ ├── binary_trees.cpp │ │ ├── binary_trees.py │ │ ├── chaos.codon │ │ ├── chaos.py │ │ ├── fannkuch.codon │ │ ├── fannkuch.py │ │ ├── float.py │ │ ├── go.codon │ │ ├── go.py │ │ ├── mandelbrot.codon │ │ ├── mandelbrot.py │ │ ├── nbody.cpp │ │ ├── nbody.py │ │ ├── npbench.codon │ │ ├── npbench_lib.codon │ │ ├── primes.codon │ │ ├── primes.py │ │ ├── set_partition.cpp │ │ ├── set_partition.py │ │ ├── spectral_norm.py │ │ ├── sum.py │ │ ├── taq.cpp │ │ ├── taq.py │ │ ├── word_count.cpp │ │ └── word_count.py │ └── run.sh ├── cmake/ │ ├── CMakeRC.cmake │ ├── backtrace-config.h.in │ ├── backtrace-supported.h.in │ ├── config.h.in │ ├── config.py.in │ └── deps.cmake ├── codon/ │ ├── app/ │ │ └── main.cpp │ ├── cir/ │ │ ├── analyze/ │ │ │ ├── analysis.cpp │ │ │ ├── analysis.h │ │ │ ├── dataflow/ │ │ │ │ ├── capture.cpp │ │ │ │ ├── capture.h │ │ │ │ ├── cfg.cpp │ │ │ │ ├── cfg.h │ │ │ │ ├── dominator.cpp │ │ │ │ ├── dominator.h │ │ │ │ ├── reaching.cpp │ │ │ │ └── reaching.h │ │ │ └── module/ │ │ │ ├── global_vars.cpp │ │ │ ├── global_vars.h │ │ │ ├── side_effect.cpp │ │ │ └── side_effect.h │ │ ├── attribute.cpp │ │ ├── attribute.h │ │ ├── base.cpp │ │ ├── base.h │ │ ├── cir.h │ │ ├── const.cpp │ │ ├── const.h │ │ ├── dsl/ │ │ │ ├── codegen.h │ │ │ ├── nodes.cpp │ │ │ └── nodes.h │ │ ├── flow.cpp │ │ ├── flow.h │ │ ├── func.cpp │ │ ├── func.h │ │ ├── instr.cpp │ │ ├── instr.h │ │ ├── llvm/ │ │ │ ├── gpu.cpp │ │ │ ├── gpu.h │ │ │ ├── llvisitor.cpp │ │ │ ├── llvisitor.h │ │ │ ├── llvm.h │ │ │ ├── native/ │ │ │ │ ├── native.cpp │ │ │ │ ├── 
native.h │ │ │ │ └── targets/ │ │ │ │ ├── aarch64.cpp │ │ │ │ ├── aarch64.h │ │ │ │ ├── arm.cpp │ │ │ │ ├── arm.h │ │ │ │ ├── target.h │ │ │ │ ├── x86.cpp │ │ │ │ └── x86.h │ │ │ ├── optimize.cpp │ │ │ └── optimize.h │ │ ├── module.cpp │ │ ├── module.h │ │ ├── pyextension.h │ │ ├── transform/ │ │ │ ├── cleanup/ │ │ │ │ ├── canonical.cpp │ │ │ │ ├── canonical.h │ │ │ │ ├── dead_code.cpp │ │ │ │ ├── dead_code.h │ │ │ │ ├── global_demote.cpp │ │ │ │ ├── global_demote.h │ │ │ │ ├── replacer.cpp │ │ │ │ └── replacer.h │ │ │ ├── folding/ │ │ │ │ ├── const_fold.cpp │ │ │ │ ├── const_fold.h │ │ │ │ ├── const_prop.cpp │ │ │ │ ├── const_prop.h │ │ │ │ ├── folding.cpp │ │ │ │ ├── folding.h │ │ │ │ └── rule.h │ │ │ ├── lowering/ │ │ │ │ ├── async_for.cpp │ │ │ │ ├── async_for.h │ │ │ │ ├── await.cpp │ │ │ │ ├── await.h │ │ │ │ ├── imperative.cpp │ │ │ │ ├── imperative.h │ │ │ │ ├── pipeline.cpp │ │ │ │ └── pipeline.h │ │ │ ├── manager.cpp │ │ │ ├── manager.h │ │ │ ├── numpy/ │ │ │ │ ├── expr.cpp │ │ │ │ ├── forward.cpp │ │ │ │ ├── indexing.cpp │ │ │ │ ├── indexing.h │ │ │ │ ├── numpy.cpp │ │ │ │ └── numpy.h │ │ │ ├── parallel/ │ │ │ │ ├── openmp.cpp │ │ │ │ ├── openmp.h │ │ │ │ ├── schedule.cpp │ │ │ │ └── schedule.h │ │ │ ├── pass.cpp │ │ │ ├── pass.h │ │ │ ├── pythonic/ │ │ │ │ ├── dict.cpp │ │ │ │ ├── dict.h │ │ │ │ ├── generator.cpp │ │ │ │ ├── generator.h │ │ │ │ ├── io.cpp │ │ │ │ ├── io.h │ │ │ │ ├── list.cpp │ │ │ │ ├── list.h │ │ │ │ ├── str.cpp │ │ │ │ └── str.h │ │ │ └── rewrite.h │ │ ├── types/ │ │ │ ├── types.cpp │ │ │ └── types.h │ │ ├── util/ │ │ │ ├── cloning.cpp │ │ │ ├── cloning.h │ │ │ ├── context.h │ │ │ ├── format.cpp │ │ │ ├── format.h │ │ │ ├── inlining.cpp │ │ │ ├── inlining.h │ │ │ ├── irtools.cpp │ │ │ ├── irtools.h │ │ │ ├── iterators.h │ │ │ ├── matching.cpp │ │ │ ├── matching.h │ │ │ ├── operator.h │ │ │ ├── outlining.cpp │ │ │ ├── outlining.h │ │ │ ├── packs.h │ │ │ ├── side_effect.cpp │ │ │ ├── side_effect.h │ │ │ ├── visitor.cpp │ │ │ └── 
visitor.h │ │ ├── value.cpp │ │ ├── value.h │ │ ├── var.cpp │ │ └── var.h │ ├── compiler/ │ │ ├── compiler.cpp │ │ ├── compiler.h │ │ ├── debug_listener.cpp │ │ ├── debug_listener.h │ │ ├── engine.cpp │ │ ├── engine.h │ │ ├── error.cpp │ │ ├── error.h │ │ ├── jit.cpp │ │ ├── jit.h │ │ ├── jit_extern.h │ │ ├── memory_manager.cpp │ │ └── memory_manager.h │ ├── config/ │ │ └── .gitignore │ ├── dsl/ │ │ ├── dsl.h │ │ ├── plugins.cpp │ │ └── plugins.h │ ├── parser/ │ │ ├── ast/ │ │ │ ├── attr.cpp │ │ │ ├── attr.h │ │ │ ├── error.h │ │ │ ├── expr.cpp │ │ │ ├── expr.h │ │ │ ├── node.h │ │ │ ├── stmt.cpp │ │ │ ├── stmt.h │ │ │ ├── types/ │ │ │ │ ├── class.cpp │ │ │ │ ├── class.h │ │ │ │ ├── function.cpp │ │ │ │ ├── function.h │ │ │ │ ├── link.cpp │ │ │ │ ├── link.h │ │ │ │ ├── static.cpp │ │ │ │ ├── static.h │ │ │ │ ├── traits.cpp │ │ │ │ ├── traits.h │ │ │ │ ├── type.cpp │ │ │ │ ├── type.h │ │ │ │ ├── union.cpp │ │ │ │ └── union.h │ │ │ └── types.h │ │ ├── ast.h │ │ ├── cache.cpp │ │ ├── cache.h │ │ ├── common.cpp │ │ ├── common.h │ │ ├── ctx.h │ │ ├── match.cpp │ │ ├── match.h │ │ ├── peg/ │ │ │ ├── grammar.peg │ │ │ ├── openmp.peg │ │ │ ├── peg.cpp │ │ │ ├── peg.h │ │ │ └── rules.h │ │ └── visitors/ │ │ ├── doc/ │ │ │ ├── doc.cpp │ │ │ └── doc.h │ │ ├── format/ │ │ │ ├── format.cpp │ │ │ └── format.h │ │ ├── scoping/ │ │ │ ├── scoping.cpp │ │ │ └── scoping.h │ │ ├── translate/ │ │ │ ├── translate.cpp │ │ │ ├── translate.h │ │ │ ├── translate_ctx.cpp │ │ │ └── translate_ctx.h │ │ ├── typecheck/ │ │ │ ├── access.cpp │ │ │ ├── assign.cpp │ │ │ ├── basic.cpp │ │ │ ├── call.cpp │ │ │ ├── class.cpp │ │ │ ├── collections.cpp │ │ │ ├── cond.cpp │ │ │ ├── ctx.cpp │ │ │ ├── ctx.h │ │ │ ├── error.cpp │ │ │ ├── function.cpp │ │ │ ├── import.cpp │ │ │ ├── infer.cpp │ │ │ ├── loops.cpp │ │ │ ├── op.cpp │ │ │ ├── special.cpp │ │ │ ├── typecheck.cpp │ │ │ └── typecheck.h │ │ ├── visitor.cpp │ │ └── visitor.h │ ├── runtime/ │ │ ├── exc.cpp │ │ ├── floatlib/ │ │ │ ├── extenddftf2.c │ │ │ 
├── extendhfsf2.c │ │ │ ├── extendhftf2.c │ │ │ ├── extendsfdf2.c │ │ │ ├── extendsftf2.c │ │ │ ├── fp_extend.h │ │ │ ├── fp_extend_impl.inc │ │ │ ├── fp_lib.h │ │ │ ├── fp_trunc.h │ │ │ ├── fp_trunc_impl.inc │ │ │ ├── int_endianness.h │ │ │ ├── int_lib.h │ │ │ ├── int_math.h │ │ │ ├── int_types.h │ │ │ ├── int_util.h │ │ │ ├── truncdfbf2.c │ │ │ ├── truncdfhf2.c │ │ │ ├── truncdfsf2.c │ │ │ ├── truncsfbf2.c │ │ │ ├── truncsfhf2.c │ │ │ ├── trunctfdf2.c │ │ │ ├── trunctfhf2.c │ │ │ └── trunctfsf2.c │ │ ├── lib.cpp │ │ ├── lib.h │ │ ├── numpy/ │ │ │ ├── loops.cpp │ │ │ ├── sort.cpp │ │ │ └── zmath.cpp │ │ └── re.cpp │ └── util/ │ ├── common.cpp │ ├── common.h │ ├── jupyter.cpp │ ├── jupyter.h │ ├── peg2cpp.cpp │ ├── serialize.h │ └── tser.h ├── docs/ │ ├── css/ │ │ └── extra.css │ ├── developers/ │ │ ├── build.md │ │ ├── compilation.md │ │ ├── contribute.md │ │ ├── extend.md │ │ ├── ir.md │ │ └── roadmap.md │ ├── img/ │ │ └── image.avif │ ├── index.md │ ├── integrations/ │ │ ├── cpp/ │ │ │ ├── codon-from-cpp.md │ │ │ ├── cpp-from-codon.md │ │ │ └── jit.md │ │ ├── jupyter.md │ │ └── python/ │ │ ├── codon-from-python.md │ │ ├── extensions.md │ │ └── python-from-codon.md │ ├── js/ │ │ └── mathjax.js │ ├── labs/ │ │ ├── catalog/ │ │ │ └── start.md │ │ └── index.md │ ├── language/ │ │ ├── classes.md │ │ ├── generics.md │ │ ├── llvm.md │ │ ├── lowlevel.md │ │ ├── meta.md │ │ └── overview.md │ ├── libraries/ │ │ ├── api/ │ │ │ └── .gitignore │ │ ├── numpy.md │ │ └── stdlib.md │ ├── overrides/ │ │ └── main.html │ ├── parallel/ │ │ ├── gpu.md │ │ ├── multithreading.md │ │ └── simd.md │ └── start/ │ ├── changelog.md │ ├── faq.md │ ├── install.md │ └── usage.md ├── jit/ │ ├── .gitignore │ ├── MANIFEST.in │ ├── README.md │ ├── codon/ │ │ ├── __init__.py │ │ ├── decorator.py │ │ ├── jit.pxd │ │ └── jit.pyx │ ├── pyproject.toml │ └── setup.py ├── jupyter/ │ ├── CMakeLists.txt │ ├── jupyter.cpp │ ├── jupyter.h │ ├── share/ │ │ └── jupyter/ │ │ └── kernels/ │ │ └── codon/ │ │ └── 
kernel.json.in │ └── xeus.patch ├── mkdocs.yml ├── scripts/ │ ├── Dockerfile.codon-build │ ├── Dockerfile.codon-jupyter │ ├── Dockerfile.gpu │ ├── Dockerfile.llvm-build │ ├── deps.sh │ ├── docgen.py │ ├── fix_loader_paths.sh │ ├── get_system_libs.sh │ └── install.sh ├── stdlib/ │ ├── algorithms/ │ │ ├── heapsort.codon │ │ ├── insertionsort.codon │ │ ├── pdqsort.codon │ │ ├── qsort.codon │ │ ├── strings.codon │ │ └── timsort.codon │ ├── asyncio.codon │ ├── bisect.codon │ ├── bz2.codon │ ├── cmath.codon │ ├── codon/ │ │ └── static.codon │ ├── collections.codon │ ├── copy.codon │ ├── datetime.codon │ ├── functools.codon │ ├── getopt.codon │ ├── gpu.codon │ ├── gzip.codon │ ├── heapq.codon │ ├── internal/ │ │ ├── __init__.codon │ │ ├── __init_test__.codon │ │ ├── attributes.codon │ │ ├── builtin.codon │ │ ├── c_stubs.codon │ │ ├── core.codon │ │ ├── dlopen.codon │ │ ├── file.codon │ │ ├── format.codon │ │ ├── gc.codon │ │ ├── gpu.codon │ │ ├── internal.codon │ │ ├── khash.codon │ │ ├── pynumerics.codon │ │ ├── python.codon │ │ ├── sort.codon │ │ ├── static.codon │ │ ├── str.codon │ │ └── types/ │ │ ├── any.codon │ │ ├── array.codon │ │ ├── bool.codon │ │ ├── byte.codon │ │ ├── collections/ │ │ │ ├── dict.codon │ │ │ ├── list.codon │ │ │ ├── set.codon │ │ │ └── tuple.codon │ │ ├── complex.codon │ │ ├── ellipsis.codon │ │ ├── error.codon │ │ ├── float.codon │ │ ├── function.codon │ │ ├── generator.codon │ │ ├── import_.codon │ │ ├── int.codon │ │ ├── intn.codon │ │ ├── optional.codon │ │ ├── ptr.codon │ │ ├── range.codon │ │ ├── rtti.codon │ │ ├── slice.codon │ │ ├── str.codon │ │ ├── strbuf.codon │ │ ├── tuple.codon │ │ ├── type.codon │ │ └── union.codon │ ├── itertools.codon │ ├── math.codon │ ├── numpy/ │ │ ├── __init__.codon │ │ ├── const.codon │ │ ├── dragon4.codon │ │ ├── dtype.codon │ │ ├── emath.codon │ │ ├── fft/ │ │ │ ├── __init__.codon │ │ │ └── pocketfft.codon │ │ ├── format.codon │ │ ├── functional.codon │ │ ├── fusion.codon │ │ ├── indexing.codon │ │ ├── 
interp.codon │ │ ├── lib/ │ │ │ ├── __init__.codon │ │ │ ├── arraysetops.codon │ │ │ └── stride_tricks.codon │ │ ├── linalg/ │ │ │ ├── __init__.codon │ │ │ ├── blas.codon │ │ │ └── linalg.codon │ │ ├── linalg_sym.codon │ │ ├── misc.codon │ │ ├── ndarray.codon │ │ ├── ndgpu.codon │ │ ├── ndmath.codon │ │ ├── npdatetime.codon │ │ ├── npio.codon │ │ ├── operators.codon │ │ ├── pybridge.codon │ │ ├── random/ │ │ │ ├── __init__.codon │ │ │ ├── bitgen.codon │ │ │ ├── logfactorial.codon │ │ │ ├── mt19937.codon │ │ │ ├── pcg64.codon │ │ │ ├── philox.codon │ │ │ ├── seed.codon │ │ │ ├── sfc64.codon │ │ │ ├── splitmix64.codon │ │ │ └── ziggurat.codon │ │ ├── reductions.codon │ │ ├── routines.codon │ │ ├── sorting.codon │ │ ├── statistics.codon │ │ ├── ufunc.codon │ │ ├── util.codon │ │ ├── window.codon │ │ └── zmath.codon │ ├── openmp.codon │ ├── operator.codon │ ├── os/ │ │ ├── __init__.codon │ │ └── path.codon │ ├── pickle.codon │ ├── python.codon │ ├── random.codon │ ├── re.codon │ ├── simd.codon │ ├── sortedlist.codon │ ├── statistics.codon │ ├── string.codon │ ├── sys.codon │ ├── threading.codon │ ├── time.codon │ ├── typing.codon │ └── unittest.codon └── test/ ├── CMakeLists.txt.in ├── app/ │ ├── argv.codon │ ├── build.codon │ ├── exit.codon │ ├── export.codon │ ├── input.codon │ ├── input.txt │ ├── test.c │ └── test.sh ├── cir/ │ ├── analyze/ │ │ ├── dominator.cpp │ │ └── reaching.cpp │ ├── base.cpp │ ├── constant.cpp │ ├── flow.cpp │ ├── func.cpp │ ├── instr.cpp │ ├── module.cpp │ ├── test.h │ ├── transform/ │ │ └── manager.cpp │ ├── types/ │ │ └── types.cpp │ ├── util/ │ │ └── matching.cpp │ ├── value.cpp │ └── var.cpp ├── core/ │ ├── arguments.codon │ ├── arithmetic.codon │ ├── bltin.codon │ ├── containers.codon │ ├── empty.codon │ ├── exceptions.codon │ ├── generators.codon │ ├── generics.codon │ ├── helloworld.codon │ ├── match.codon │ ├── numerics.codon │ ├── parser.codon │ ├── pipeline.codon │ ├── range.codon │ ├── serialization.codon │ ├── sort.codon │ ├── 
trees.codon │ └── vec_simd.codon ├── main.cpp ├── numpy/ │ ├── data/ │ │ └── .gitignore │ ├── random_tests/ │ │ ├── test_mt19937.codon │ │ ├── test_pcg64.codon │ │ ├── test_philox.codon │ │ └── test_sfc64.codon │ ├── test_dtype.codon │ ├── test_elision.codon │ ├── test_fft.codon │ ├── test_functional.codon │ ├── test_fusion.codon │ ├── test_indexing.codon │ ├── test_io.codon │ ├── test_lib.codon │ ├── test_linalg.codon │ ├── test_loops.codon │ ├── test_misc.codon │ ├── test_ndmath.codon │ ├── test_npdatetime.codon │ ├── test_pybridge.codon │ ├── test_reductions.codon │ ├── test_routines.codon │ ├── test_sorting.codon │ ├── test_statistics.codon │ ├── test_ufunc.codon │ └── test_window.codon ├── parser/ │ ├── llvm.codon │ ├── simplify_expr.codon │ ├── simplify_stmt.codon │ ├── typecheck/ │ │ ├── a/ │ │ │ ├── __init__.codon │ │ │ ├── b/ │ │ │ │ ├── __init__.codon │ │ │ │ ├── rec1.codon │ │ │ │ ├── rec1_err.codon │ │ │ │ ├── rec2.codon │ │ │ │ └── rec2_err.codon │ │ │ └── sub/ │ │ │ └── __init__.codon │ │ ├── test_access.codon │ │ ├── test_assign.codon │ │ ├── test_basic.codon │ │ ├── test_call.codon │ │ ├── test_class.codon │ │ ├── test_collections.codon │ │ ├── test_cond.codon │ │ ├── test_ctx.codon │ │ ├── test_error.codon │ │ ├── test_function.codon │ │ ├── test_import.codon │ │ ├── test_infer.codon │ │ ├── test_loops.codon │ │ ├── test_op.codon │ │ ├── test_parser.codon │ │ ├── test_python.codon │ │ └── test_typecheck.codon │ ├── typecheck_expr.codon │ ├── typecheck_stmt.codon │ └── types.codon ├── python/ │ ├── __init__.py │ ├── cython_jit.py │ ├── find-python-library.py │ ├── myextension.codon │ ├── myextension2.codon │ ├── mymodule.py │ ├── pybridge.codon │ ├── pyext.py │ └── setup.py ├── stdlib/ │ ├── asyncio_test.codon │ ├── bisect_test.codon │ ├── cmath_test.codon │ ├── cmath_testcases.txt │ ├── datetime_test.codon │ ├── heapq_test.codon │ ├── itertools_test.codon │ ├── llvm_test.codon │ ├── math_test.codon │ ├── operator_test.codon │ ├── random_test.codon 
│ ├── re_test.codon │ ├── sort_test.codon │ ├── statistics_test.codon │ └── str_test.codon ├── transform/ │ ├── canonical.codon │ ├── dict_opt.codon │ ├── escapes.codon │ ├── folding.codon │ ├── for_lowering.codon │ ├── inlining.codon │ ├── io_opt.codon │ ├── kernels.codon │ ├── list_opt.codon │ ├── omp.codon │ ├── outlining.codon │ └── str_opt.codon └── types.cpp ================================================ FILE CONTENTS ================================================ ================================================ FILE: .clang-format ================================================ --- BasedOnStyle: LLVM ColumnLimit: 88 ================================================ FILE: .clang-tidy ================================================ --- Checks: 'clang-diagnostic-*,clang-analyzer-*,cppcoreguidelines-*,modernize-*,bugprone-*,concurrency-*,performance-*,portability-*,-modernize-use-nodiscard,-modernize-use-trailing-return-type,-cppcoreguidelines-special-member-functions,-bugprone-easily-swappable-parameters,-bugprone-assignment-in-if-condition,-modernize-use-nodiscard' WarningsAsErrors: false HeaderFilterRegex: '(build/.+)|(codon/util/.+)' AnalyzeTemporaryDtors: false FormatStyle: llvm CheckOptions: - key: cppcoreguidelines-macro-usage.CheckCapsOnly value: '1' ================================================ FILE: .gitattributes ================================================ *.codon linguist-language=Python *.png binary *.jpg binary *.jpeg binary *.gif binary *.ico binary *.mov binary *.mp4 binary *.mp3 binary *.flv binary *.fla binary *.swf binary *.gz binary *.zip binary *.7z binary *.ttf binary *.eot binary *.woff binary *.pyc binary *.pdf binary *.gz binary *.bam binary *.bam.bai binary *.cram binary *.cram.crai binary ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: - package-ecosystem: "github-actions" directory: "/" schedule: interval: "monthly" 
================================================ FILE: .github/workflows/ci.yml ================================================ name: Codon CI on: push: branches: - master - develop tags: - '*' pull_request: branches: - develop jobs: create_release: name: GitHub Release runs-on: ubuntu-latest outputs: upload_url: ${{ steps.create_release.outputs.upload_url }} permissions: contents: write steps: - name: Create Release if: contains(github.ref, 'tags/v') id: create_release uses: ncipollo/release-action@v1 build: strategy: matrix: include: - os: ubuntu-latest arch: linux-x86_64 - os: ubuntu-latest arch: manylinux2014-x86_64 - os: ubuntu-24.04-arm arch: linux-aarch64 - os: ubuntu-24.04-arm arch: manylinux2014-aarch64 - os: macos-15-intel arch: darwin-x86_64 - os: macos-14 arch: darwin-arm64 runs-on: ${{ matrix.os }} name: Build Codon needs: create_release permissions: contents: write id-token: write steps: - uses: actions/checkout@v6 - name: Set up Python uses: actions/setup-python@v6 with: python-version: '3.11' - name: Build (Ubuntu) if: startsWith(matrix.os, 'ubuntu') run: | (cd .github/build-linux && docker build -t local -f Dockerfile.${{ matrix.arch }} .) 
docker run -v $(pwd):/github/workspace local /github/workspace ${{ matrix.arch }} yes - name: Build (macOS) if: startsWith(matrix.os, 'macos') run: | sudo mkdir -p /opt/llvm-codon sudo chown -R $(whoami) /opt/llvm-codon curl -L https://github.com/exaloop/llvm-project/releases/download/codon-20.1.7/llvm-codon-20.1.7-${{ matrix.arch }}.tar.bz2 | tar jxf - -C /opt brew install gcc bash .github/build-linux/entrypoint.sh ${{ github.workspace }} ${{ matrix.arch }} yes - name: Upload Artifacts uses: actions/upload-artifact@v7 with: name: codon-${{ matrix.arch }}.tar.gz path: codon-${{ matrix.arch }}.tar.gz - name: Upload Release Asset if: contains(github.ref, 'tags/v') uses: actions/upload-release-asset@v1.0.2 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: upload_url: ${{ needs.create_release.outputs.upload_url }} asset_path: ./codon-${{ matrix.arch }}.tar.gz asset_name: codon-${{ matrix.arch }}.tar.gz asset_content_type: application/gzip - name: Publish PyPI Package if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') && matrix.arch == 'linux-x86_64' uses: pypa/gh-action-pypi-publish@release/v1 build_documentation: name: Build Docs runs-on: ubuntu-latest needs: build permissions: contents: write steps: - name: Checkout repository uses: actions/checkout@v6 - name: Set up Python uses: actions/setup-python@v6 with: python-version: '3.11' - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install mkdocs \ mkdocs-autorefs \ mkdocs-macros-plugin \ mkdocs-material \ mkdocs-material-extensions \ mkdocs-redirects sudo apt-get update sudo apt-get install -y pngquant - name: Download Artifact uses: actions/download-artifact@v8 with: name: codon-linux-x86_64.tar.gz path: ./downloaded-artifact - name: Build API reference run: | mv downloaded-artifact/* . 
tar -xvzf codon-linux-x86_64.tar.gz codon-deploy-linux-x86_64/bin/codon doc codon-deploy-linux-x86_64/lib/codon/stdlib > docs.json python scripts/docgen.py docs.json docs/libraries/api $(pwd)/codon-deploy-linux-x86_64/lib/codon/stdlib - name: Build MkDocs site run: mkdocs build --strict - name: Deploy to GitHub Pages if: github.ref == 'refs/heads/master' uses: peaceiris/actions-gh-pages@v4 with: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: ./site force_orphan: true cname: docs.exaloop.io ================================================ FILE: .gitignore ================================================ ###################### # Generic .gitignore # ###################### # Compiled source # ################### *.com *.class *.dll *.exe *.o *.a *.obj *.so *.dylib *.pyc build*/ install*/ extra/python/src/jit.cpp extra/jupyter/build/ /site # Packages # ############ # it's better to unpack these files and commit the raw source # git has its own built-in compression methods *.7z *.dmg *.iso *.jar *.rar *.tar *.zip **/**.egg-info # Logs and databases # ###################### *.log *.sql *.sqlite # OS generated files # ###################### .DS_Store .DS_Store? 
._* .Spotlight-V100 .Trashes ehthumbs.db Thumbs.db # IDE generated files # ####################### .idea .mypy_cache .vscode .cache .ipynb_checkpoints # CMake generated files # ######################### jupyter/share/jupyter/kernels/codon/kernel.json jit/codon/version.py # Testing files # ################# temp/ playground/ scratch*.* /_* ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - id: trailing-whitespace - repo: https://github.com/pre-commit/mirrors-clang-format rev: v17.0.2 hooks: - id: clang-format types: - c++ ================================================ FILE: CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.14) project( Codon VERSION "0.19.6" HOMEPAGE_URL "https://github.com/exaloop/codon" DESCRIPTION "high-performance, extensible Python compiler") set(CODON_JIT_PYTHON_VERSION "0.4.6") configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in" "${PROJECT_SOURCE_DIR}/codon/config/config.h") configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in" "${PROJECT_SOURCE_DIR}/jit/codon/version.py") if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") cmake_policy(SET CMP0135 NEW) endif() set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pedantic -fvisibility-inlines-hidden -Wno-return-type-c-linkage -Wno-gnu-zero-variadic-macro-arguments -Wno-deprecated-declarations" ) else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-return-type") endif() set(CMAKE_CXX_FLAGS_DEBUG "-g") if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fno-limit-debug-info") endif() set(CMAKE_CXX_FLAGS_RELEASE "-O3") include_directories(.) 
set(APPLE_ARM OFF) if (APPLE AND CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "arm64") set(APPLE_ARM ON) endif() set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) find_package(LLVM REQUIRED) message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") include(${CMAKE_SOURCE_DIR}/cmake/deps.cmake) include(${CMAKE_SOURCE_DIR}/cmake/CMakeRC.cmake) set(CMAKE_BUILD_WITH_INSTALL_RPATH ON) if(APPLE) set(CMAKE_INSTALL_RPATH "@loader_path;@loader_path/../lib/codon") else() set(CMAKE_INSTALL_RPATH "$ORIGIN:$ORIGIN/../lib/codon") endif() add_executable(peg2cpp codon/util/peg2cpp.cpp) target_include_directories(peg2cpp PRIVATE ${peglib_SOURCE_DIR}) target_link_libraries(peg2cpp PRIVATE Threads::Threads fmt) add_custom_command( OUTPUT codon_rules.cpp COMMAND peg2cpp ${CMAKE_SOURCE_DIR}/codon/parser/peg/grammar.peg codon_rules.cpp codon DEPENDS peg2cpp codon/parser/peg/grammar.peg) add_custom_command( OUTPUT omp_rules.cpp COMMAND peg2cpp ${CMAKE_SOURCE_DIR}/codon/parser/peg/openmp.peg omp_rules.cpp omp DEPENDS peg2cpp codon/parser/peg/openmp.peg) # Codon Jupyter library set(CODON_JUPYTER_FILES codon/util/jupyter.h codon/util/jupyter.cpp) add_library(codon_jupyter SHARED ${CODON_JUPYTER_FILES}) # Codon runtime library add_library(codonfloat STATIC codon/runtime/floatlib/extenddftf2.c codon/runtime/floatlib/fp_trunc.h codon/runtime/floatlib/truncdfhf2.c codon/runtime/floatlib/extendhfsf2.c codon/runtime/floatlib/int_endianness.h codon/runtime/floatlib/truncdfsf2.c codon/runtime/floatlib/extendhftf2.c codon/runtime/floatlib/int_lib.h # codon/runtime/floatlib/truncsfbf2.c codon/runtime/floatlib/extendsfdf2.c codon/runtime/floatlib/int_math.h codon/runtime/floatlib/truncsfhf2.c codon/runtime/floatlib/extendsftf2.c codon/runtime/floatlib/int_types.h codon/runtime/floatlib/trunctfdf2.c codon/runtime/floatlib/fp_extend.h codon/runtime/floatlib/int_util.h codon/runtime/floatlib/trunctfhf2.c codon/runtime/floatlib/fp_lib.h # 
codon/runtime/floatlib/truncdfbf2.c codon/runtime/floatlib/trunctfsf2.c) target_compile_options(codonfloat PRIVATE -O3) target_compile_definitions(codonfloat PRIVATE COMPILER_RT_HAS_FLOAT16) set(CODONRT_FILES codon/runtime/lib.h codon/runtime/lib.cpp codon/runtime/re.cpp codon/runtime/exc.cpp codon/runtime/numpy/sort.cpp codon/runtime/numpy/loops.cpp codon/runtime/numpy/zmath.cpp) add_library(codonrt SHARED ${CODONRT_FILES}) add_dependencies(codonrt zlibstatic gc backtrace bz2 liblzma re2 hwy hwy_contrib fast_float codonfloat) if(DEFINED ENV{CODON_SYSTEM_LIBRARIES}) if(APPLE) set(copied_libgfortran "${CMAKE_BINARY_DIR}/libgfortran.5${CMAKE_SHARED_LIBRARY_SUFFIX}") set(copied_libquadmath "${CMAKE_BINARY_DIR}/libquadmath.0${CMAKE_SHARED_LIBRARY_SUFFIX}") set(copied_libgcc "${CMAKE_BINARY_DIR}/libgcc_s.1.1${CMAKE_SHARED_LIBRARY_SUFFIX}") else() set(copied_libgfortran "${CMAKE_BINARY_DIR}/libgfortran${CMAKE_SHARED_LIBRARY_SUFFIX}.5") set(copied_libquadmath "${CMAKE_BINARY_DIR}/libquadmath${CMAKE_SHARED_LIBRARY_SUFFIX}.0") set(copied_libgcc "${CMAKE_BINARY_DIR}/libgcc_s${CMAKE_SHARED_LIBRARY_SUFFIX}.1") endif() add_custom_command( OUTPUT ${copied_libgfortran} DEPENDS "${CMAKE_SOURCE_DIR}/scripts/get_system_libs.sh" COMMAND ${CMAKE_SOURCE_DIR}/scripts/get_system_libs.sh "$ENV{CODON_SYSTEM_LIBRARIES}" ${CMAKE_BINARY_DIR} COMMENT "Copying system libraries to build directory") add_custom_target(copy_libraries ALL DEPENDS ${copied_libgfortran}) add_dependencies(codonrt copy_libraries) add_library(libgfortran SHARED IMPORTED) set_target_properties(libgfortran PROPERTIES IMPORTED_LOCATION ${copied_libgfortran}) target_link_libraries(codonrt PRIVATE libgfortran) else() message(FATAL_ERROR "Set 'CODON_SYSTEM_LIBRARIES' to the directory containing system libraries.") endif() target_include_directories(codonrt PRIVATE ${backtrace_SOURCE_DIR} ${re2_SOURCE_DIR} ${highway_SOURCE_DIR} "${gc_SOURCE_DIR}/include" "${fast_float_SOURCE_DIR}/include" runtime) target_link_libraries(codonrt 
PRIVATE fmt omp backtrace LLVMSupport) if(APPLE) target_link_libraries( codonrt PRIVATE -Wl,-force_load,$<TARGET_FILE:zlibstatic> -Wl,-force_load,$<TARGET_FILE:gc> -Wl,-force_load,$<TARGET_FILE:bz2> -Wl,-force_load,$<TARGET_FILE:liblzma> -Wl,-force_load,$<TARGET_FILE:re2> -Wl,-force_load,$<TARGET_FILE:hwy> -Wl,-force_load,$<TARGET_FILE:hwy_contrib> -Wl,-force_load,$<TARGET_FILE:codonfloat>) target_link_libraries(codonrt PUBLIC "-framework Accelerate") else() add_dependencies(codonrt openblas) target_link_libraries( codonrt PRIVATE -Wl,--whole-archive $<TARGET_FILE:zlibstatic> $<TARGET_FILE:gc> $<TARGET_FILE:bz2> $<TARGET_FILE:liblzma> $<TARGET_FILE:re2> $<TARGET_FILE:openblas> $<TARGET_FILE:hwy> $<TARGET_FILE:hwy_contrib> $<TARGET_FILE:codonfloat> -Wl,--no-whole-archive) endif() if(ASAN) target_compile_options( codonrt PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address" "-fsanitize-recover=address") target_link_libraries( codonrt PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address" "-fsanitize-recover=address") endif() add_custom_command( TARGET codonrt POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:omp> ${CMAKE_BINARY_DIR}) # Codon compiler library include_directories(${LLVM_INCLUDE_DIRS}) add_definitions(${LLVM_DEFINITIONS}) set(CODON_HPPFILES codon/compiler/compiler.h codon/compiler/debug_listener.h codon/compiler/engine.h codon/compiler/error.h codon/compiler/jit.h codon/compiler/jit_extern.h codon/compiler/memory_manager.h codon/dsl/dsl.h codon/dsl/plugins.h codon/parser/ast.h codon/parser/match.h codon/parser/ast/node.h codon/parser/ast/expr.h codon/parser/ast/stmt.h codon/parser/ast/types.h codon/parser/ast/attr.h codon/parser/ast/types/type.h codon/parser/ast/types/link.h codon/parser/ast/types/class.h codon/parser/ast/types/function.h codon/parser/ast/types/union.h codon/parser/ast/types/static.h codon/parser/ast/types/traits.h codon/parser/cache.h codon/parser/common.h codon/parser/ctx.h codon/parser/peg/peg.h codon/parser/peg/rules.h codon/parser/visitors/doc/doc.h codon/parser/visitors/format/format.h codon/parser/visitors/scoping/scoping.h codon/parser/visitors/translate/translate.h codon/parser/visitors/translate/translate_ctx.h codon/parser/visitors/typecheck/typecheck.h codon/parser/visitors/typecheck/ctx.h codon/parser/visitors/visitor.h codon/cir/analyze/analysis.h 
codon/cir/analyze/dataflow/capture.h codon/cir/analyze/dataflow/cfg.h codon/cir/analyze/dataflow/dominator.h codon/cir/analyze/dataflow/reaching.h codon/cir/analyze/module/global_vars.h codon/cir/analyze/module/side_effect.h codon/cir/attribute.h codon/cir/base.h codon/cir/const.h codon/cir/dsl/codegen.h codon/cir/dsl/nodes.h codon/cir/flow.h codon/cir/func.h codon/cir/instr.h codon/cir/llvm/gpu.h codon/cir/llvm/llvisitor.h codon/cir/llvm/llvm.h codon/cir/llvm/optimize.h codon/cir/module.h codon/cir/pyextension.h codon/cir/cir.h codon/cir/transform/cleanup/canonical.h codon/cir/transform/cleanup/dead_code.h codon/cir/transform/cleanup/global_demote.h codon/cir/transform/cleanup/replacer.h codon/cir/transform/folding/const_fold.h codon/cir/transform/folding/const_prop.h codon/cir/transform/folding/folding.h codon/cir/transform/folding/rule.h codon/cir/transform/lowering/async_for.h codon/cir/transform/lowering/await.h codon/cir/transform/lowering/imperative.h codon/cir/transform/lowering/pipeline.h codon/cir/transform/manager.h codon/cir/transform/parallel/openmp.h codon/cir/transform/parallel/schedule.h codon/cir/transform/pass.h codon/cir/transform/pythonic/dict.h codon/cir/transform/pythonic/generator.h codon/cir/transform/pythonic/io.h codon/cir/transform/pythonic/list.h codon/cir/transform/pythonic/str.h codon/cir/transform/rewrite.h codon/cir/types/types.h codon/cir/util/cloning.h codon/cir/util/context.h codon/cir/util/format.h codon/cir/util/inlining.h codon/cir/util/irtools.h codon/cir/util/iterators.h codon/cir/util/matching.h codon/cir/util/operator.h codon/cir/util/outlining.h codon/cir/util/packs.h codon/cir/util/side_effect.h codon/cir/util/visitor.h codon/cir/value.h codon/cir/llvm/native/native.h codon/cir/llvm/native/targets/aarch64.h codon/cir/llvm/native/targets/arm.h codon/cir/llvm/native/targets/target.h codon/cir/llvm/native/targets/x86.h codon/cir/transform/numpy/numpy.h codon/cir/transform/numpy/indexing.h codon/cir/var.h codon/util/common.h 
codon/util/serialize.h codon/util/tser.h) set(CODON_CPPFILES codon/compiler/compiler.cpp codon/compiler/debug_listener.cpp codon/compiler/engine.cpp codon/compiler/error.cpp codon/compiler/jit.cpp codon/compiler/memory_manager.cpp codon/dsl/plugins.cpp codon/parser/ast/expr.cpp codon/parser/ast/attr.cpp codon/parser/ast/stmt.cpp codon/parser/ast/types/type.cpp codon/parser/ast/types/link.cpp codon/parser/ast/types/class.cpp codon/parser/ast/types/function.cpp codon/parser/ast/types/union.cpp codon/parser/ast/types/static.cpp codon/parser/ast/types/traits.cpp codon/parser/cache.cpp codon/parser/match.cpp codon/parser/common.cpp codon/parser/peg/peg.cpp codon/parser/visitors/doc/doc.cpp codon/parser/visitors/format/format.cpp codon/parser/visitors/scoping/scoping.cpp codon/parser/visitors/translate/translate.cpp codon/parser/visitors/translate/translate_ctx.cpp codon/parser/visitors/typecheck/typecheck.cpp codon/parser/visitors/typecheck/infer.cpp codon/parser/visitors/typecheck/ctx.cpp codon/parser/visitors/typecheck/assign.cpp codon/parser/visitors/typecheck/basic.cpp codon/parser/visitors/typecheck/call.cpp codon/parser/visitors/typecheck/class.cpp codon/parser/visitors/typecheck/collections.cpp codon/parser/visitors/typecheck/cond.cpp codon/parser/visitors/typecheck/function.cpp codon/parser/visitors/typecheck/access.cpp codon/parser/visitors/typecheck/import.cpp codon/parser/visitors/typecheck/loops.cpp codon/parser/visitors/typecheck/op.cpp codon/parser/visitors/typecheck/error.cpp codon/parser/visitors/typecheck/special.cpp codon/parser/visitors/visitor.cpp codon/cir/attribute.cpp codon/cir/analyze/analysis.cpp codon/cir/analyze/dataflow/capture.cpp codon/cir/analyze/dataflow/cfg.cpp codon/cir/analyze/dataflow/dominator.cpp codon/cir/analyze/dataflow/reaching.cpp codon/cir/analyze/module/global_vars.cpp codon/cir/analyze/module/side_effect.cpp codon/cir/base.cpp codon/cir/const.cpp codon/cir/dsl/nodes.cpp codon/cir/flow.cpp codon/cir/func.cpp 
codon/cir/instr.cpp codon/cir/llvm/gpu.cpp codon/cir/llvm/llvisitor.cpp codon/cir/llvm/optimize.cpp codon/cir/module.cpp codon/cir/transform/cleanup/canonical.cpp codon/cir/transform/cleanup/dead_code.cpp codon/cir/transform/cleanup/global_demote.cpp codon/cir/transform/cleanup/replacer.cpp codon/cir/transform/folding/const_fold.cpp codon/cir/transform/folding/const_prop.cpp codon/cir/transform/folding/folding.cpp codon/cir/transform/lowering/async_for.cpp codon/cir/transform/lowering/await.cpp codon/cir/transform/lowering/imperative.cpp codon/cir/transform/lowering/pipeline.cpp codon/cir/transform/manager.cpp codon/cir/transform/parallel/openmp.cpp codon/cir/transform/parallel/schedule.cpp codon/cir/transform/pass.cpp codon/cir/transform/pythonic/dict.cpp codon/cir/transform/pythonic/generator.cpp codon/cir/transform/pythonic/io.cpp codon/cir/transform/pythonic/list.cpp codon/cir/transform/pythonic/str.cpp codon/cir/types/types.cpp codon/cir/util/cloning.cpp codon/cir/util/format.cpp codon/cir/util/inlining.cpp codon/cir/util/irtools.cpp codon/cir/util/matching.cpp codon/cir/util/outlining.cpp codon/cir/util/side_effect.cpp codon/cir/util/visitor.cpp codon/cir/value.cpp codon/cir/var.cpp codon/cir/llvm/native/native.cpp codon/cir/llvm/native/targets/aarch64.cpp codon/cir/llvm/native/targets/arm.cpp codon/cir/llvm/native/targets/x86.cpp codon/cir/transform/numpy/expr.cpp codon/cir/transform/numpy/forward.cpp codon/cir/transform/numpy/indexing.cpp codon/cir/transform/numpy/numpy.cpp codon/util/common.cpp) add_library(codonc SHARED ${CODON_HPPFILES}) target_include_directories(codonc PRIVATE ${peglib_SOURCE_DIR} ${toml_SOURCE_DIR}/include ${semver_SOURCE_DIR}/include ${fast_float_SOURCE_DIR}/include) target_sources(codonc PRIVATE ${CODON_CPPFILES} codon_rules.cpp omp_rules.cpp) if(ASAN) target_compile_options( codonc PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address" "-fsanitize-recover=address") target_link_libraries( codonc PRIVATE "-fno-omit-frame-pointer" 
"-fsanitize=address" "-fsanitize-recover=address") endif() if(CMAKE_BUILD_TYPE MATCHES Debug) set_source_files_properties(codon_rules.cpp codon/parser/peg/peg.cpp PROPERTIES COMPILE_FLAGS "-O2") endif() llvm_map_components_to_libnames( LLVM_LIBS AllTargetsAsmParsers AllTargetsCodeGens AllTargetsDescs AllTargetsInfos AggressiveInstCombine Analysis AsmParser BitWriter CodeGen Core Extensions IPO IRReader InstCombine Instrumentation MC MCJIT ObjCARCOpts OrcJIT Remarks ScalarOpts Support Symbolize Target TransformUtils Vectorize Passes) file(GLOB_RECURSE CODON_STDLIB_RESOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/stdlib/*.codon" ) cmrc_add_resource_library( codon-stdlib NAMESPACE codon ${CODON_STDLIB_RESOURCES} ) set_property(TARGET codon-stdlib PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(codonc PRIVATE ${LLVM_LIBS} fmt dl codonrt codon-stdlib) # Gather headers add_custom_target( headers ALL COMMENT "Collecting headers" BYPRODUCTS "${CMAKE_BINARY_DIR}/include" VERBATIM COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/include/codon" COMMAND ${CMAKE_COMMAND} -E copy_directory "${CMAKE_SOURCE_DIR}/codon" "${CMAKE_BINARY_DIR}/include/codon" COMMAND find "${CMAKE_BINARY_DIR}/include" -type f ! 
-name "*.h" -exec rm {} \\;) add_dependencies(headers codonrt codonc) # Prepare lib directory for plugin compilation add_custom_target( libs ALL COMMENT "Collecting libraries" BYPRODUCTS "${CMAKE_BINARY_DIR}/lib" VERBATIM COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/lib/codon" COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_BINARY_DIR}/libcodonc${CMAKE_SHARED_LIBRARY_SUFFIX}" "${CMAKE_BINARY_DIR}/lib/codon" COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_BINARY_DIR}/libcodonrt${CMAKE_SHARED_LIBRARY_SUFFIX}" "${CMAKE_BINARY_DIR}/lib/codon" COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_BINARY_DIR}/libomp${CMAKE_SHARED_LIBRARY_SUFFIX}" "${CMAKE_BINARY_DIR}/lib/codon" COMMAND ${CMAKE_COMMAND} -E copy ${copied_libgfortran} "${CMAKE_BINARY_DIR}/lib/codon" COMMAND /bin/sh -c "test -f '${copied_libquadmath}' && ${CMAKE_COMMAND} -E copy '${copied_libquadmath}' '${CMAKE_BINARY_DIR}/lib/codon' || true" COMMAND ${CMAKE_COMMAND} -E copy ${copied_libgcc} "${CMAKE_BINARY_DIR}/lib/codon") add_dependencies(libs codonrt codonc) # Codon command-line tool add_executable(codon codon/app/main.cpp) target_link_libraries(codon PUBLIC fmt codonc codon_jupyter Threads::Threads) # Codon test Download and unpack googletest at configure time include(FetchContent) FetchContent_Declare( googletest URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip ) # For Windows: Prevent overriding the parent project's compiler/linker settings set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) option(INSTALL_GTEST "Enable installation of googletest." 
OFF) FetchContent_MakeAvailable(googletest) enable_testing() set(CODON_TEST_CPPFILES test/main.cpp test/cir/analyze/dominator.cpp test/cir/analyze/reaching.cpp test/cir/base.cpp test/cir/constant.cpp test/cir/flow.cpp test/cir/func.cpp test/cir/instr.cpp test/cir/module.cpp test/cir/transform/manager.cpp test/cir/types/types.cpp test/cir/util/matching.cpp test/cir/value.cpp test/cir/var.cpp test/types.cpp) add_executable(codon_test ${CODON_TEST_CPPFILES}) target_include_directories(codon_test PRIVATE test/cir "${gc_SOURCE_DIR}/include") target_link_libraries(codon_test fmt codonc codonrt gtest_main) target_compile_definitions(codon_test PRIVATE TEST_DIR="${CMAKE_CURRENT_SOURCE_DIR}/test") install(TARGETS codonrt codonc codon_jupyter DESTINATION lib/codon) install(FILES ${CMAKE_BINARY_DIR}/libomp${CMAKE_SHARED_LIBRARY_SUFFIX} DESTINATION lib/codon) install(FILES ${copied_libgfortran} DESTINATION lib/codon) # only install libquadmath if it exists at build time install(CODE " file(GLOB _quadmath \"${copied_libquadmath}\") if(EXISTS \"\${_quadmath}\") file(INSTALL DESTINATION \"\${CMAKE_INSTALL_PREFIX}/lib/codon\" TYPE FILE FILES \"\${_quadmath}\") endif() ") install(FILES ${copied_libgcc} DESTINATION lib/codon) install(TARGETS codon DESTINATION bin) install(DIRECTORY ${CMAKE_BINARY_DIR}/include/codon DESTINATION include) install(DIRECTORY ${LLVM_INCLUDE_DIRS}/llvm DESTINATION include) install(DIRECTORY ${LLVM_INCLUDE_DIRS}/llvm-c DESTINATION include) install(DIRECTORY ${CMAKE_SOURCE_DIR}/stdlib DESTINATION lib/codon) install(DIRECTORY ${CMAKE_SOURCE_DIR}/jit/ DESTINATION python) install(DIRECTORY DESTINATION lib/codon/plugins) install(CODE [[ if(APPLE) # Compute the real install root (supports DESTDIR) set(_root "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}") message(STATUS "fix_loader_paths.sh on: ${_root}") execute_process( COMMAND /bin/bash "${CMAKE_SOURCE_DIR}/scripts/fix_loader_paths.sh" "${_root}" RESULT_VARIABLE rc ) if(NOT rc EQUAL 0) message(FATAL_ERROR 
"fix_loader_paths.sh failed with code ${rc}") endif() endif() ]]) ================================================ FILE: CODEOWNERS ================================================ * @arshajii @inumanag /codon/ @arshajii /codon/parser/ @inumanag ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to Codon Thank you for considering contributing to Codon! This document contains some helpful information for getting started. The best place to ask questions or get feedback is [our Discord](https://discord.gg/HeWRhagCmP). ## Development workflow All development is done on the [`develop`](https://github.com/exaloop/codon/tree/develop) branch. Just before release, we bump the version number, merge into [`master`](https://github.com/exaloop/codon/tree/master) and tag the build with a tag of the form `vX.Y.Z` where `X`, `Y` and `Z` are the [SemVer](https://semver.org) major, minor and patch numbers, respectively. Our CI build process automatically builds and deploys tagged commits as a new GitHub release. ## Coding standards All C++ code should be formatted with [ClangFormat](https://clang.llvm.org/docs/ClangFormat.html) using the `.clang-format` file in the root of the repository. ## Writing tests Tests are written as Codon programs. The [`test/core/`](https://github.com/exaloop/codon/tree/master/test/core) directory contains some examples. If you add a new test file, be sure to add it to [`test/main.cpp`](https://github.com/exaloop/codon/blob/master/test/main.cpp) so that it will be executed as part of the test suite. There are two ways to write tests for Codon: #### New style Example: ```python @test def my_test(): assert 2 + 2 == 4 my_test() ``` **Semantics:** `assert` statements in functions marked `@test` are not compiled to standard assertions: they don't terminate the program when the condition fails, but instead print source information, fail the test, and move on. 
#### Old style Example: ```python print(2 + 2) # EXPECT: 4 ``` **Semantics:** The source file is scanned for `EXPECT`s, executed, then the output is compared to the "expected" output. Note that if you have, for example, an `EXPECT` in a loop, you will need to duplicate it however many times the loop is executed. Using `EXPECT` is helpful mainly in cases where you need to test control flow, **otherwise prefer the new style**. ## Pull requests Pull requests should generally be based on the `develop` branch. Before submitting a pull request, please make sure... - ... to provide a clear description of the purpose of the pull request. - ... to include tests for any new or changed code. - ... that all code is formatted as per the guidelines above. Please be patient with pull request reviews, as our throughput is limited. ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================

Codon banner

Docs  ·  FAQ  ·  Blog  ·  Discord  ·  Roadmap  ·  Benchmarks

Build Status # What is Codon? Codon is a high-performance Python implementation that compiles to native machine code without any runtime overhead. Typical speedups over vanilla Python are on the order of 10-100x or more, on a single thread. Codon's performance is typically on par with (and sometimes better than) that of C/C++. Unlike Python, Codon supports native multithreading, which can lead to speedups many times higher still. *Think of Codon as Python reimagined for static, ahead-of-time compilation, built from the ground up with best possible performance in mind.* ## Goals - :bulb: **No learning curve:** Be as close to CPython as possible in terms of syntax, semantics and libraries - :rocket: **Top-notch performance:** At *least* on par with low-level languages like C, C++ or Rust - :computer: **Hardware support:** Full, seamless support for multicore programming, multithreading (no GIL!), GPU and more - :chart_with_upwards_trend: **Optimizations:** Comprehensive optimization framework that can target high-level Python constructs and libraries - :battery: **Interoperability:** Full interoperability with Python's ecosystem of packages and libraries ## Non-goals - :x: *Drop-in replacement for CPython:* Codon is not a drop-in replacement for CPython. There are some aspects of Python that are not suitable for static compilation — we don't support these in Codon. There are ways to use Codon in larger Python codebases via its [JIT decorator](https://docs.exaloop.io/codon/interoperability/decorator) or [Python extension backend](https://docs.exaloop.io/codon/interoperability/pyext). Codon also supports calling any Python module via its [Python interoperability](https://docs.exaloop.io/codon/interoperability/python). See also [*"Differences with Python"*](https://docs.exaloop.io/codon/general/differences) in the docs. - :x: *New syntax and language constructs:* We try to avoid adding new syntax, keywords or other language features as much as possible. 
While Codon does add some new syntax in a couple places (e.g. to express parallelism), we try to make it as familiar and intuitive as possible. ## How it works

Codon figure

# Quick start Download and install Codon with this command: ```bash /bin/bash -c "$(curl -fsSL https://exaloop.io/install.sh)" ``` After following the prompts, the `codon` command will be available to use. For example: - To run a program: `codon run file.py` - To run a program with optimizations enabled: `codon run -release file.py` - To compile to an executable: `codon build -release file.py` - To generate LLVM IR: `codon build -release -llvm file.py` Many more options are available and described in [the docs](https://docs.exaloop.io/codon/general/intro). Alternatively, you can [build from source](https://docs.exaloop.io/codon/advanced/build). # Examples ## Basics Codon supports much of Python, and many Python programs will work with few if any modifications. Here's a simple script `fib.py` that computes the 40th Fibonacci number... ``` python from time import time def fib(n): return n if n < 2 else fib(n - 1) + fib(n - 2) t0 = time() ans = fib(40) t1 = time() print(f'Computed fib(40) = {ans} in {t1 - t0} seconds.') ``` ... run through Python and Codon: ``` $ python3 fib.py Computed fib(40) = 102334155 in 17.979357957839966 seconds. $ codon run -release fib.py Computed fib(40) = 102334155 in 0.275645 seconds. ``` ## Using Python libraries You can import and use any Python package from Codon via `from python import`. For example: ```python from python import matplotlib.pyplot as plt data = [x**2 for x in range(10)] plt.plot(data) plt.show() ``` (Just remember to set the `CODON_PYTHON` environment variable to the CPython shared library, as explained in [the Python interoperability docs](https://docs.exaloop.io/codon/interoperability/python).) ## Parallelism Codon supports native multithreading via [OpenMP](https://www.openmp.org/). The `@par` annotation in the code below tells the compiler to parallelize the following `for`-loop, in this case using a dynamic schedule, chunk size of 100, and 16 threads. 
```python from sys import argv def is_prime(n): factors = 0 for i in range(2, n): if n % i == 0: factors += 1 return factors == 0 limit = int(argv[1]) total = 0 @par(schedule='dynamic', chunk_size=100, num_threads=16) for i in range(2, limit): if is_prime(i): total += 1 print(total) ``` Note that Codon automatically turns the `total += 1` statement in the loop body into an atomic reduction to avoid race conditions. Learn more in the [multithreading docs](https://docs.exaloop.io/codon/advanced/parallel). Codon also supports writing and executing GPU kernels. Here's an example that computes the [Mandelbrot set](https://en.wikipedia.org/wiki/Mandelbrot_set): ```python import gpu MAX = 1000 # maximum Mandelbrot iterations N = 4096 # width and height of image pixels = [0 for _ in range(N * N)] def scale(x, a, b): return a + (x/N)*(b - a) @gpu.kernel def mandelbrot(pixels): idx = (gpu.block.x * gpu.block.dim.x) + gpu.thread.x i, j = divmod(idx, N) c = complex(scale(j, -2.00, 0.47), scale(i, -1.12, 1.12)) z = 0j iteration = 0 while abs(z) <= 2 and iteration < MAX: z = z**2 + c iteration += 1 pixels[idx] = int(255 * iteration/MAX) mandelbrot(pixels, grid=(N*N)//1024, block=1024) ``` GPU programming can also be done using the `@par` syntax with `@par(gpu=True)`. See the [GPU programming docs](https://docs.exaloop.io/codon/advanced/gpu) for more details. ## NumPy support Codon includes a feature-complete, fully-compiled native NumPy implementation. It uses the same API as NumPy, but re-implements everything in Codon itself, allowing for a range of optimizations and performance improvements. Here's an example NumPy program that approximates $\pi$ using random numbers... 
``` python import time import numpy as np rng = np.random.default_rng(seed=0) x = rng.random(500_000_000) y = rng.random(500_000_000) t0 = time.time() # pi ~= 4 x (fraction of points in circle) pi = ((x-1)**2 + (y-1)**2 < 1).sum() * (4 / len(x)) t1 = time.time() print(f'Computed pi~={pi:.4f} in {t1 - t0:.2f} sec') ``` ... run through Python and Codon: ``` $ python3 pi.py Computed pi~=3.1417 in 2.25 sec $ codon run -release pi.py Computed pi~=3.1417 in 0.43 sec ``` Codon can speed up NumPy code through general-purpose and NumPy-specific compiler optimizations, including inlining, fusion, memory allocation elision and more. Furthermore, Codon's NumPy implementation works with its multithreading and GPU capabilities, and can even integrate with [PyTorch](https://pytorch.org). Learn more in the [Codon-NumPy docs](https://docs.exaloop.io/codon/interoperability/numpy). # Documentation Please see [docs.exaloop.io](https://docs.exaloop.io) for in-depth documentation. # Acknowledgements This project would not be possible without: - **Funding**: - National Science Foundation (NSF) 🇺🇸 - National Institutes of Health (NIH) 🇺🇸 - MIT 🇺🇸 - MIT E14 Fund 🇺🇸 - Natural Sciences and Engineering Research Council (NSERC) 🇨🇦 - Canada Research Chairs 🇨🇦 - Canada Foundation for Innovation 🇨🇦 - B.C. 
Knowledge Development Fund 🇨🇦 - University of Victoria 🇨🇦 - **Libraries**: [LLVM Compiler Infrastructure](https://llvm.org/), [yhirose's peglib](https://github.com/yhirose/cpp-peglib), [Boehm-Demers-Weiser Garbage Collector](https://github.com/ivmai/bdwgc), [KonanM's tser](https://github.com/KonanM/tser), [{fmt}](https://github.com/fmtlib/fmt), [toml++](https://marzer.github.io/tomlplusplus/), [semver](https://github.com/Neargye/semver), [zlib-ng](https://github.com/zlib-ng/zlib-ng), [xz](https://github.com/tukaani-project/xz), [bz2](https://sourceware.org/bzip2/), [Google RE2](https://github.com/google/re2), [libbacktrace](https://github.com/ianlancetaylor/libbacktrace), [fast_float](https://github.com/fastfloat/fast_float), [Google Highway](https://github.com/google/highway), [NumPy](https://numpy.org/) ================================================ FILE: bench/README.md ================================================ # Codon benchmark suite This folder contains a number of Codon benchmarks. Some are taken from the [pyperformance suite](https://github.com/python/pyperformance) while others are adaptations of applications we've encountered in the wild. Further, some of the benchmarks are identical in both Python and Codon, some are changed slightly to work with Codon's type system, and some use Codon-specific features like parallelism or GPU. Some of the pyperformance benchmarks can be made (much) faster in Codon by using various Codon-specific features, but their adaptations here are virtually identical to the original implementations (mainly just the use of the `pyperf` module is removed). ## Benchmarks - `chaos`: [Pyperformance's `chaos` benchmark](https://github.com/python/pyperformance/blob/main/pyperformance/data-files/benchmarks/bm_chaos/run_benchmark.py). - `float`: [Pyperformance's `float` benchmark](https://github.com/python/pyperformance/blob/main/pyperformance/data-files/benchmarks/bm_float/run_benchmark.py). 
- `go`: [Pyperformance's `go` benchmark](https://github.com/python/pyperformance/blob/main/pyperformance/data-files/benchmarks/bm_go/run_benchmark.py). - `nbody`: [Pyperformance's `nbody` benchmark](https://github.com/python/pyperformance/blob/main/pyperformance/data-files/benchmarks/bm_nbody/run_benchmark.py). - `spectral_norm`: [Pyperformance's `spectral_norm` benchmark](https://github.com/python/pyperformance/blob/main/pyperformance/data-files/benchmarks/bm_spectral_norm/run_benchmark.py). - `mandelbrot`: Generates an image of the Mandelbrot set. Codon version uses GPU via one additional `@par(gpu=True, collapse=2)` line. - `set_partition`: Calculates set partitions. Code taken from [this Stack Overflow answer](https://stackoverflow.com/a/73549333). - `sum`: Computes sum of integers from 1 to 50000000 with a loop. Code taken from [this article](https://towardsdatascience.com/getting-started-with-pypy-ef4ba5cb431c). - `taq`: Performs volume peak detection on an NYSE TAQ file. Sample TAQ files can be downloaded and uncompressed from [here](https://ftp.nyse.com/Historical%20Data%20Samples/DAILY%20TAQ/) (e.g. `EQY_US_ALL_NBBO_20220705.gz`). We recommend using the first 10M lines for benchmarking purposes. The TAQ file path should be passed to the benchmark script through the `DATA_TAQ` environment variable. - `binary_trees`: [Boehm's binary trees benchmark](https://hboehm.info/gc/gc_bench.html). - `fannkuch`: See [*Performing Lisp analysis of the FANNKUCH benchmark*](https://dl.acm.org/doi/10.1145/382109.382124) by Kenneth R. Anderson and Duane Rettig. Benchmark involves generating permutations and repeatedly reversing elements of a list. Codon version is multithreaded with a dynamic schedule via one additional `@par(schedule='dynamic')` line. - `word_count`: Counts occurrences of words in a file using a dictionary. The file should be passed to the benchmark script through the `DATA_WORD_COUNT` environment variable. 
- `primes`: Counts the number of prime numbers below a threshold. Codon version is multithreaded with a dynamic schedule via one additional `@par(schedule='dynamic')` line. ================================================ FILE: bench/codon/binary_trees.codon ================================================ # The Computer Language Benchmarks Game # http://benchmarksgame.alioth.debian.org/ # # contributed by Antoine Pitrou # modified by Dominique Wahli and Daniel Nanz # modified by Joerg Baumann # modified by @arshajii for Codon import sys import time class Node: left: Optional[Node] = None right: Optional[Node] = None def make_tree(d): return Node(make_tree(d - 1), make_tree(d - 1)) if d > 0 else Node() def check_tree(node): l, r = node.left, node.right if l is None: return 1 else: return 1 + check_tree(l) + check_tree(r) def make_check(itde, make=make_tree, check=check_tree): i, d = itde return check(make(d)) def get_argchunks(i, d, chunksize=5000): assert chunksize % 2 == 0 chunk = [] for k in range(1, i + 1): chunk.append((k, d)) if len(chunk) == chunksize: yield chunk chunk = [] if len(chunk) > 0: yield chunk def main(n, min_depth=4): max_depth = max(min_depth + 2, n) stretch_depth = max_depth + 1 print(f'stretch tree of depth {stretch_depth}\t check: {make_check((0, stretch_depth))}') long_lived_tree = make_tree(max_depth) mmd = max_depth + min_depth for d in range(min_depth, stretch_depth, 2): i = 2 ** (mmd - d) cs = 0 for argchunk in get_argchunks(i, d): cs += sum(map(make_check, argchunk)) print(f'{i}\t trees of depth {d}\t check: {cs}') print(f'long lived tree of depth {max_depth}\t check: {check_tree(long_lived_tree)}') t0 = time.time() main(int(sys.argv[1])) t1 = time.time() print(t1 - t0) ================================================ FILE: bench/codon/binary_trees.cpp ================================================ #include #include #include #include #include #include struct Node { std::unique_ptr left{}; std::unique_ptr right{}; }; inline 
std::unique_ptr make_tree(int d) { if (d > 0) { return std::make_unique(Node{make_tree(d - 1), make_tree(d - 1)}); } else { return std::make_unique(); } } inline int check_tree(const std::unique_ptr &node) { if (!node->left) return 1; else return 1 + check_tree(node->left) + check_tree(node->right); } inline int make_check(const std::pair &itde) { int i = itde.first, d = itde.second; auto tree = make_tree(d); return check_tree(tree); } struct ArgChunks { int i, k, d, chunksize; std::vector> chunk; ArgChunks(int i, int d, int chunksize = 5000) : i(i), k(1), d(d), chunksize(chunksize), chunk() { assert(chunksize % 2 == 0); } bool next() { chunk.clear(); while (k <= i) { chunk.emplace_back(k++, d); if (chunk.size() == chunksize) return true; } return !chunk.empty(); } }; int main(int argc, char *argv[]) { using clock = std::chrono::high_resolution_clock; using std::chrono::duration_cast; using std::chrono::milliseconds; auto t = clock::now(); int min_depth = 4; int n = std::stoi(argv[1]); int max_depth = std::max(min_depth + 2, n); int stretch_depth = max_depth + 1; std::cout << "stretch tree of depth " << stretch_depth << "\t check: " << make_check({0, stretch_depth}) << '\n'; auto long_lived_tree = make_tree(max_depth); int mmd = max_depth + min_depth; for (int d = min_depth; d < stretch_depth; d += 2) { int i = (1 << (mmd - d)); int cs = 0; ArgChunks iter(i, d); while (iter.next()) { for (auto &argchunk : iter.chunk) { cs += make_check(argchunk); } } std::cout << i << "\t trees of depth " << d << "\t check: " << cs << '\n'; } std::cout << "long lived tree of depth " << max_depth << "\t check: " << check_tree(long_lived_tree) << '\n'; std::cout << (duration_cast(clock::now() - t).count() / 1e3) << std::endl; } ================================================ FILE: bench/codon/binary_trees.py ================================================ # The Computer Language Benchmarks Game # http://benchmarksgame.alioth.debian.org/ # # contributed by Antoine Pitrou # modified 
# by Dominique Wahli and Daniel Nanz
# modified by Joerg Baumann
# modified by @arshajii for Codon

import sys
import time

class Node:
    def __init__(self, left = None, right = None):
        self.left = left
        self.right = right

def make_tree(d):
    """Build a tree with both children populated down to depth d; d <= 0 gives a leaf."""
    return Node(make_tree(d - 1), make_tree(d - 1)) if d > 0 else Node()

def check_tree(node):
    """Return the number of nodes in the tree rooted at node."""
    l, r = node.left, node.right
    if l is None:
        return 1
    else:
        return 1 + check_tree(l) + check_tree(r)

def make_check(itde, make=make_tree, check=check_tree):
    """Build a tree of depth itde[1] and return its node count; itde[0] is unused."""
    i, d = itde
    return check(make(d))

def get_argchunks(i, d, chunksize=5000):
    """Yield lists of (k, d) pairs for k = 1..i, each list at most chunksize long."""
    assert chunksize % 2 == 0
    chunk = []
    for k in range(1, i + 1):
        chunk.append((k, d))
        if len(chunk) == chunksize:
            yield chunk
            chunk = []
    if len(chunk) > 0:
        # Final, partially filled chunk.
        yield chunk

def main(n, min_depth=4):
    """Run the benchmark for maximum depth max(min_depth + 2, n) and print the checks."""
    max_depth = max(min_depth + 2, n)
    stretch_depth = max_depth + 1
    print(f'stretch tree of depth {stretch_depth}\t check: {make_check((0, stretch_depth))}')

    # Kept alive across the loop below and checked at the very end.
    long_lived_tree = make_tree(max_depth)

    mmd = max_depth + min_depth
    for d in range(min_depth, stretch_depth, 2):
        i = 2 ** (mmd - d)
        cs = 0
        for argchunk in get_argchunks(i, d):
            cs += sum(map(make_check, argchunk))
        print(f'{i}\t trees of depth {d}\t check: {cs}')

    print(f'long lived tree of depth {max_depth}\t check: {check_tree(long_lived_tree)}')

t0 = time.time()
main(int(sys.argv[1]))
t1 = time.time()
print(t1 - t0)

================================================
FILE: bench/codon/chaos.codon
================================================
"""create chaosgame-like fractals
Copyright (C) 2005 Carl Friedrich Bolz

adapted by @arshajii for Codon
"""

import math
import random
import sys
import time

DEFAULT_THICKNESS = 1.0
DEFAULT_WIDTH = 2048 #256
DEFAULT_HEIGHT = 2048 #256
DEFAULT_ITERATIONS = 1000000 #5000
DEFAULT_RNG_SEED = 1234

class GVector(object):
    """Mutable 3-component vector of floats."""
    x: float
    y: float
    z: float

    def __init__(self, x=0, y=0, z=0):
        self.x = x
        self.y = y
        self.z = z

    def Mag(self):
        """Return the Euclidean length of the vector."""
        return math.sqrt(self.x ** 2 + self.y ** 2 + self.z ** 2)

    def dist(self, other):
        """Return the Euclidean distance between self and other."""
        return math.sqrt((self.x - other.x) ** 2 +
                         (self.y - other.y) ** 2 +
                         (self.z - other.z) ** 2)

    def __add__(self, other):
        if not isinstance(other, GVector):
            raise ValueError("Can't add GVector to " + str(type(other)))
        v = GVector(self.x + other.x, self.y + other.y, self.z + other.z)
        return v

    def __sub__(self, other):
        # Implemented through __mul__ and __add__: self + (other * -1).
        return self + other * -1

    def __mul__(self, other):
        # Scales every component by `other`.
        v = GVector(self.x * other, self.y * other, self.z * other)
        return v
    #__rmul__ = __mul__

    def linear_combination(self, other, l1, l2=None):
        """Return self * l1 + other * l2; l2 defaults to 1 - l1."""
        if l2 is None:
            l2 = 1 - l1
        v = GVector(self.x * l1 + other.x * l2,
                    self.y * l1 + other.y * l2,
                    self.z * l1 + other.z * l2)
        return v

    #def __str__(self):
    #    return "<%f, %f, %f>" % (self.x, self.y, self.z)

    #def __repr__(self):
    #    return "GVector(%f, %f, %f)" % (self.x, self.y, self.z)


class Spline(object):
    """Class for representing B-Splines and NURBS of arbitrary degree"""
    knots: List[int]
    degree: int
    points: List[GVector]

    def __init__(self, points, degree, knots):
        """Creates a Spline.

        points is a list of GVector, degree is the degree of the Spline.
        Raises ValueError unless len(points) == len(knots) - degree + 1,
        or if any knot is smaller than its predecessor.
        """
        if len(points) > len(knots) - degree + 1:
            raise ValueError("too many control points")
        elif len(points) < len(knots) - degree + 1:
            raise ValueError("not enough control points")
        last = knots[0]
        for cur in knots[1:]:
            # Only a decrease is rejected; repeated knots (as used by main())
            # pass this check despite the wording of the message.
            if cur < last:
                raise ValueError("knots not strictly increasing")
            last = cur
        self.knots = knots
        self.points = points
        self.degree = degree

    def GetDomain(self):
        """Returns the domain of the B-Spline"""
        return (self.knots[self.degree - 1],
                self.knots[len(self.knots) - self.degree])

    def __call__(self, u):
        """Calculates a point of the B-Spline using de Boors Algorithm"""
        dom = self.GetDomain()
        if u < dom[0] or u > dom[1]:
            raise ValueError("Function value not in domain")
        # The domain end points map directly to the first/last control point.
        if u == dom[0]:
            return self.points[0]
        if u == dom[1]:
            return self.points[-1]
        I = self.GetIndex(u)
        # Working copy of the degree + 1 control points selected by I.
        d = [self.points[I - self.degree + 1 + ii]
             for ii in range(self.degree + 1)]
        U = self.knots
        for ik in range(1, self.degree + 1):
            for ii in range(I - self.degree + ik + 1, I + 2):
                ua = U[ii + self.degree - ik]
                ub = U[ii - 1]
                co1 = (ua - u) / (ua - ub)
                co2 = (u - ub) / (ua - ub)
                index = ii - I + self.degree - ik - 1
                d[index] = d[index].linear_combination(d[index + 1], co1, co2)
        return d[0]

    def GetIndex(self, u):
        """Return the index ii of the knot interval with knots[ii] <= u < knots[ii + 1]."""
        dom = self.GetDomain()
        for ii in range(self.degree - 1, len(self.knots) - self.degree):
            if u >= self.knots[ii] and u < self.knots[ii + 1]:
                I = ii
                break
        else:
            # No interval matched: fall back to the upper domain bound minus one.
            I = dom[1] - 1
        return I

    def __len__(self):
        return len(self.points)

    #def __repr__(self):
    #    return "Spline(%r, %r, %r)" % (self.points, self.degree, self.knots)


def write_ppm(im, filename):
    """Write im (indexed im[x][y]) to filename with a 'P6' PPM header."""
    magic = 'P6\n'
    maxval = 255
    w = len(im)
    h = len(im[0])

    #with open(filename, "w", encoding="latin1", newline='') as fp:
    with open(filename, "w") as fp:
        fp.write(magic)
        #fp.write('%i %i\n%i\n' % (w, h, maxval))
        fp.write(f'{w} {h}\n{maxval}\n')
        for j in range(h):
            for i in range(w):
                val = im[i][j]
                # create_image_chaos only stores 0 or 1, so c is 0 or 255.
                c = val * 255
                #fp.write('%c%c%c' % (c, c, c))
                c = chr(c)
                # Same character written three times per pixel.
                fp.write(f'{c}{c}{c}')


class Chaosgame(object):
    """Iterates random spline-based transformations of a point and rasterises the result."""
    splines: List[Spline]
    thickness: float
    minx: float
    miny: float
    maxx: float
    maxy: float
    height: float
    width: float
    num_trafos: List[int]
    num_total: int

    def __init__(self, splines, thickness=0.1):
        self.splines = splines
        self.thickness = thickness
        # Bounding box of all control points.
        self.minx = min([p.x for spl in splines for p in spl.points])
        self.miny = min([p.y for spl in splines for p in spl.points])
        self.maxx = max([p.x for spl in splines for p in spl.points])
        self.maxy = max([p.y for spl in splines for p in spl.points])
        self.height = self.maxy - self.miny
        self.width = self.maxx - self.minx
        self.num_trafos = []
        maxlength = thickness * self.width / self.height
        for spl in splines:
            # Approximate the curve length by sampling t = i / 999, i = 0..999.
            length = 0.
            curr = spl(0)
            for i in range(1, 1000):
                last = curr
                t = 1 / 999 * i
                curr = spl(t)
                length += curr.dist(last)
            # At least one transformation per spline.
            self.num_trafos.append(max(1, int(length / maxlength * 1.5)))
        self.num_total = sum(self.num_trafos)

    def get_random_trafo(self):
        """Return a random (spline index, segment index) pair."""
        # r is drawn from 0..num_total inclusive; r == num_total matches no
        # bucket in the loop and is handled by the final return.
        r = random.randrange(int(self.num_total) + 1)
        l = 0
        for i in range(len(self.num_trafos)):
            if r >= l and r < l + self.num_trafos[i]:
                return i, random.randrange(self.num_trafos[i])
            l += self.num_trafos[i]
        return len(self.num_trafos) - 1, random.randrange(self.num_trafos[-1])

    def transform_point(self, point):
        """Apply one randomly chosen transformation to point and return the new point."""
        # Position of the point inside the bounding box.
        x = (point.x - self.minx) / self.width
        y = (point.y - self.miny) / self.height
        #if trafo is None:
        trafo = self.get_random_trafo()
        start, end = self.splines[trafo[0]].GetDomain()
        length = end - start
        seg_length = length / self.num_trafos[trafo[0]]
        t = start + seg_length * trafo[1] + seg_length * x
        basepoint = self.splines[trafo[0]](t)
        # Finite-difference direction; step backwards when t + 1/50000 would
        # leave the spline domain.
        if t + 1 / 50000 > end:
            neighbour = self.splines[trafo[0]](t - 1 / 50000)
            derivative = neighbour - basepoint
        else:
            neighbour = self.splines[trafo[0]](t + 1 / 50000)
            derivative = basepoint - neighbour
        if derivative.Mag() != 0:
            # Offset along (derivative.y, -derivative.x), scaled by (y - 0.5) * thickness.
            basepoint.x += derivative.y / derivative.Mag() * (y - 0.5) * \
                self.thickness
            basepoint.y += -derivative.x / derivative.Mag() * (y - 0.5) * \
                self.thickness
        else:
            print("r", end='')
        self.truncate(basepoint)
        return basepoint

    def truncate(self, point):
        """Clamp point in place to the bounding box [minx, maxx] x [miny, maxy]."""
        if point.x >= self.maxx:
            point.x = self.maxx
        if point.y >= self.maxy:
            point.y = self.maxy
        if point.x < self.minx:
            point.x = self.minx
        if point.y < self.miny:
            point.y = self.miny

    def create_image_chaos(self, w, h, iterations, filename, rng_seed):
        """Run `iterations` transformations and write the w x h image to filename (if truthy)."""
        # Always use the same sequence of random numbers
        # to get reproducible benchmark
        random.seed(rng_seed)

        # im[x][y] starts at 1 and is set to 0 wherever a point lands.
        im = [[1] * h for i in range(w)]
        point = GVector((self.maxx + self.minx) / 2,
                        (self.maxy + self.miny) / 2, 0)
        for _ in range(iterations):
            point = self.transform_point(point)
            x = (point.x - self.minx) / self.width * w
            y = (point.y - self.miny) / self.height * h
            x = int(x)
            y = int(y)
            # A point on the max edge maps to index w / h; pull it back inside.
            if x == w:
                x -= 1
            if y == h:
                y -= 1
            im[x][h - y - 1] = 0

        if filename:
            write_ppm(im, filename)


def main():
    """Build the three benchmark splines and render the image to sys.argv[1]."""
    splines = [
        Spline([
            GVector(1.597350, 3.304460, 0.000000),
            GVector(1.575810, 4.123260, 0.000000),
            GVector(1.313210, 5.288350, 0.000000),
            GVector(1.618900, 5.329910, 0.000000),
            GVector(2.889940, 5.502700, 0.000000),
            GVector(2.373060, 4.381830, 0.000000),
            GVector(1.662000, 4.360280, 0.000000)],
            3, [0, 0, 0, 1, 1, 1, 2, 2, 2]),
        Spline([
            GVector(2.804500, 4.017350, 0.000000),
            GVector(2.550500, 3.525230, 0.000000),
            GVector(1.979010, 2.620360, 0.000000),
            GVector(1.979010, 2.620360, 0.000000)],
            3, [0, 0, 0, 1, 1, 1]),
        Spline([
            GVector(2.001670, 4.011320, 0.000000),
            GVector(2.335040, 3.312830, 0.000000),
            GVector(2.366800, 3.233460, 0.000000),
            GVector(2.366800, 3.233460, 0.000000)],
            3, [0, 0, 0, 1, 1, 1])
    ]
    chaos = Chaosgame(splines, DEFAULT_THICKNESS)
    chaos.create_image_chaos(DEFAULT_WIDTH, DEFAULT_HEIGHT, DEFAULT_ITERATIONS, sys.argv[1], DEFAULT_RNG_SEED)

t0 = time.time()
main()
t1 = time.time()
print(t1 - t0)

================================================
FILE: bench/codon/chaos.py
================================================
"""create chaosgame-like fractals
Copyright (C) 2005 Carl Friedrich Bolz

adapted by @arshajii for Codon
"""

import math
import random
import sys
import time

DEFAULT_THICKNESS = 1.0
DEFAULT_WIDTH = 2048 #256
DEFAULT_HEIGHT = 2048 #256
DEFAULT_ITERATIONS = 1000000 #5000 DEFAULT_RNG_SEED = 1234 class GVector(object): def __init__(self, x=0, y=0, z=0): self.x = x self.y = y self.z = z def Mag(self): return math.sqrt(self.x ** 2 + self.y ** 2 + self.z ** 2) def dist(self, other): return math.sqrt((self.x - other.x) ** 2 + (self.y - other.y) ** 2 + (self.z - other.z) ** 2) def __add__(self, other): if not isinstance(other, GVector): raise ValueError("Can't add GVector to " + str(type(other))) v = GVector(self.x + other.x, self.y + other.y, self.z + other.z) return v def __sub__(self, other): return self + other * -1 def __mul__(self, other): v = GVector(self.x * other, self.y * other, self.z * other) return v __rmul__ = __mul__ def linear_combination(self, other, l1, l2=None): if l2 is None: l2 = 1 - l1 v = GVector(self.x * l1 + other.x * l2, self.y * l1 + other.y * l2, self.z * l1 + other.z * l2) return v def __str__(self): return "<%f, %f, %f>" % (self.x, self.y, self.z) def __repr__(self): return "GVector(%f, %f, %f)" % (self.x, self.y, self.z) class Spline(object): """Class for representing B-Splines and NURBS of arbitrary degree""" def __init__(self, points, degree, knots): """Creates a Spline. points is a list of GVector, degree is the degree of the Spline. 
""" if len(points) > len(knots) - degree + 1: raise ValueError("too many control points") elif len(points) < len(knots) - degree + 1: raise ValueError("not enough control points") last = knots[0] for cur in knots[1:]: if cur < last: raise ValueError("knots not strictly increasing") last = cur self.knots = knots self.points = points self.degree = degree def GetDomain(self): """Returns the domain of the B-Spline""" return (self.knots[self.degree - 1], self.knots[len(self.knots) - self.degree]) def __call__(self, u): """Calculates a point of the B-Spline using de Boors Algorithm""" dom = self.GetDomain() if u < dom[0] or u > dom[1]: raise ValueError("Function value not in domain") if u == dom[0]: return self.points[0] if u == dom[1]: return self.points[-1] I = self.GetIndex(u) d = [self.points[I - self.degree + 1 + ii] for ii in range(self.degree + 1)] U = self.knots for ik in range(1, self.degree + 1): for ii in range(I - self.degree + ik + 1, I + 2): ua = U[ii + self.degree - ik] ub = U[ii - 1] co1 = (ua - u) / (ua - ub) co2 = (u - ub) / (ua - ub) index = ii - I + self.degree - ik - 1 d[index] = d[index].linear_combination(d[index + 1], co1, co2) return d[0] def GetIndex(self, u): dom = self.GetDomain() for ii in range(self.degree - 1, len(self.knots) - self.degree): if u >= self.knots[ii] and u < self.knots[ii + 1]: I = ii break else: I = dom[1] - 1 return I def __len__(self): return len(self.points) def __repr__(self): return "Spline(%r, %r, %r)" % (self.points, self.degree, self.knots) def write_ppm(im, filename): magic = 'P6\n' maxval = 255 w = len(im) h = len(im[0]) with open(filename, "w", encoding="latin1", newline='') as fp: fp.write(magic) fp.write('%i %i\n%i\n' % (w, h, maxval)) for j in range(h): for i in range(w): val = im[i][j] c = val * 255 fp.write('%c%c%c' % (c, c, c)) class Chaosgame(object): def __init__(self, splines, thickness=0.1): self.splines = splines self.thickness = thickness self.minx = min([p.x for spl in splines for p in spl.points]) 
self.miny = min([p.y for spl in splines for p in spl.points]) self.maxx = max([p.x for spl in splines for p in spl.points]) self.maxy = max([p.y for spl in splines for p in spl.points]) self.height = self.maxy - self.miny self.width = self.maxx - self.minx self.num_trafos = [] maxlength = thickness * self.width / self.height for spl in splines: length = 0 curr = spl(0) for i in range(1, 1000): last = curr t = 1 / 999 * i curr = spl(t) length += curr.dist(last) self.num_trafos.append(max(1, int(length / maxlength * 1.5))) self.num_total = sum(self.num_trafos) def get_random_trafo(self): r = random.randrange(int(self.num_total) + 1) l = 0 for i in range(len(self.num_trafos)): if r >= l and r < l + self.num_trafos[i]: return i, random.randrange(self.num_trafos[i]) l += self.num_trafos[i] return len(self.num_trafos) - 1, random.randrange(self.num_trafos[-1]) def transform_point(self, point, trafo=None): x = (point.x - self.minx) / self.width y = (point.y - self.miny) / self.height if trafo is None: trafo = self.get_random_trafo() start, end = self.splines[trafo[0]].GetDomain() length = end - start seg_length = length / self.num_trafos[trafo[0]] t = start + seg_length * trafo[1] + seg_length * x basepoint = self.splines[trafo[0]](t) if t + 1 / 50000 > end: neighbour = self.splines[trafo[0]](t - 1 / 50000) derivative = neighbour - basepoint else: neighbour = self.splines[trafo[0]](t + 1 / 50000) derivative = basepoint - neighbour if derivative.Mag() != 0: basepoint.x += derivative.y / derivative.Mag() * (y - 0.5) * \ self.thickness basepoint.y += -derivative.x / derivative.Mag() * (y - 0.5) * \ self.thickness else: print("r", end='') self.truncate(basepoint) return basepoint def truncate(self, point): if point.x >= self.maxx: point.x = self.maxx if point.y >= self.maxy: point.y = self.maxy if point.x < self.minx: point.x = self.minx if point.y < self.miny: point.y = self.miny def create_image_chaos(self, w, h, iterations, filename, rng_seed): # Always use the same 
sequence of random numbers # to get reproducible benchmark random.seed(rng_seed) im = [[1] * h for i in range(w)] point = GVector((self.maxx + self.minx) / 2, (self.maxy + self.miny) / 2, 0) for _ in range(iterations): point = self.transform_point(point) x = (point.x - self.minx) / self.width * w y = (point.y - self.miny) / self.height * h x = int(x) y = int(y) if x == w: x -= 1 if y == h: y -= 1 im[x][h - y - 1] = 0 if filename: write_ppm(im, filename) def main(): splines = [ Spline([ GVector(1.597350, 3.304460, 0.000000), GVector(1.575810, 4.123260, 0.000000), GVector(1.313210, 5.288350, 0.000000), GVector(1.618900, 5.329910, 0.000000), GVector(2.889940, 5.502700, 0.000000), GVector(2.373060, 4.381830, 0.000000), GVector(1.662000, 4.360280, 0.000000)], 3, [0, 0, 0, 1, 1, 1, 2, 2, 2]), Spline([ GVector(2.804500, 4.017350, 0.000000), GVector(2.550500, 3.525230, 0.000000), GVector(1.979010, 2.620360, 0.000000), GVector(1.979010, 2.620360, 0.000000)], 3, [0, 0, 0, 1, 1, 1]), Spline([ GVector(2.001670, 4.011320, 0.000000), GVector(2.335040, 3.312830, 0.000000), GVector(2.366800, 3.233460, 0.000000), GVector(2.366800, 3.233460, 0.000000)], 3, [0, 0, 0, 1, 1, 1]) ] chaos = Chaosgame(splines, DEFAULT_THICKNESS) chaos.create_image_chaos(DEFAULT_WIDTH, DEFAULT_HEIGHT, DEFAULT_ITERATIONS, sys.argv[1], DEFAULT_RNG_SEED) t0 = time.time() main() t1 = time.time() print(t1 - t0) ================================================ FILE: bench/codon/fannkuch.codon ================================================ # FANNKUCH benchmark from math import factorial as fact from sys import argv from time import time def perm(n, i): p = [0] * n for k in range(n): f = fact(n - 1 - k) p[k] = i // f i = i % f for k in range(n - 1, -1, -1): for j in range(k - 1, -1, -1): if p[j] <= p[k]: p[k] += 1 return p n = int(argv[1]) max_flips = 0 t0 = time() @par(schedule='dynamic', num_threads=4) for idx in range(fact(n)): p = perm(n, idx) flips = 0 k = p[0] while k: i = 0 j = k while i < j: p[i], p[j] = 
p[j], p[i] i += 1 j -= 1 k = p[0] flips += 1 max_flips = max(flips, max_flips) print(f'Pfannkuchen({n}) = {max_flips}') t1 = time() print(t1 - t0) ================================================ FILE: bench/codon/fannkuch.py ================================================ # FANNKUCH benchmark from math import factorial as fact from sys import argv from time import time def perm(n, i): p = [0] * n for k in range(n): f = fact(n - 1 - k) p[k] = i // f i = i % f for k in range(n - 1, -1, -1): for j in range(k - 1, -1, -1): if p[j] <= p[k]: p[k] += 1 return p n = int(argv[1]) max_flips = 0 t0 = time() for idx in range(fact(n)): p = perm(n, idx) flips = 0 k = p[0] while k: i = 0 j = k while i < j: p[i], p[j] = p[j], p[i] i += 1 j -= 1 k = p[0] flips += 1 max_flips = max(flips, max_flips) print(f'Pfannkuchen({n}) = {max_flips}') t1 = time() print(t1 - t0) ================================================ FILE: bench/codon/float.py ================================================ from math import sin, cos, sqrt from time import time POINTS = 10000000 class Point: x: float y: float z: float def __init__(self, i): self.x = x = sin(i) self.y = cos(i) * 3 self.z = (x * x) / 2 def __repr__(self): return f"" def normalize(self): x = self.x y = self.y z = self.z norm = sqrt(x * x + y * y + z * z) self.x /= norm self.y /= norm self.z /= norm def maximize(self, other): self.x = self.x if self.x > other.x else other.x self.y = self.y if self.y > other.y else other.y self.z = self.z if self.z > other.z else other.z return self def maximize(points): next = points[0] for p in points[1:]: next = next.maximize(p) return next def benchmark(n): points = [None] * n for i in range(n): points[i] = Point(i) for p in points: p.normalize() return maximize(points) t0 = time() print(benchmark(POINTS)) t1 = time() print(t1 - t0) ================================================ FILE: bench/codon/go.codon ================================================ """ Go board game """ import math import 
random from time import time SIZE = 9 GAMES = 200 KOMI = 7.5 EMPTY, WHITE, BLACK = 0, 1, 2 SHOW = {EMPTY: '.', WHITE: 'o', BLACK: 'x'} PASS = -1 MAXMOVES = SIZE * SIZE * 3 TIMESTAMP = 0 MOVES = 0 def to_pos(x, y): return y * SIZE + x def to_xy(pos): y, x = divmod(pos, SIZE) return x, y @dataclass(init=False) class Square[Board]: board: Board pos: int timestamp: int removestamp: int zobrist_strings: List[int] neighbours: Optional[List[Square[Board]]] color: int used: bool reference: Optional[Square[Board]] ledges: int temp_ledges: int def __init__(self, board, pos): self.board = board self.pos = pos self.timestamp = TIMESTAMP self.removestamp = TIMESTAMP self.zobrist_strings = [random.randrange(9223372036854775807) for i in range(3)] self.neighbours = None self.color = EMPTY self.used = False self.reference = None self.ledges = 0 self.temp_ledges = 0 def set_neighbours(self): x, y = self.pos % SIZE, self.pos // SIZE self.neighbours = [] for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1)]: newx, newy = x + dx, y + dy if 0 <= newx < SIZE and 0 <= newy < SIZE: self.neighbours.append(self.board.squares[to_pos(newx, newy)]) def move(self, color): global TIMESTAMP, MOVES TIMESTAMP += 1 MOVES += 1 self.board.zobrist.update(self, color) self.color = color self.reference = self self.ledges = 0 self.used = True for neighbour in self.neighbours: neighcolor = neighbour.color if neighcolor == EMPTY: self.ledges += 1 else: neighbour_ref = neighbour.find(update=True) if neighcolor == color: if neighbour_ref.reference.pos != self.pos: self.ledges += neighbour_ref.ledges neighbour_ref.reference = self self.ledges -= 1 else: neighbour_ref.ledges -= 1 if neighbour_ref.ledges == 0: neighbour.remove(neighbour_ref) self.board.zobrist.add() def remove(self, reference, update=True): self.board.zobrist.update(self, EMPTY) self.removestamp = TIMESTAMP if update: self.color = EMPTY self.board.emptyset.add(self.pos) # if color == BLACK: # self.board.black_dead += 1 # else: # self.board.white_dead 
+= 1 for neighbour in self.neighbours: if neighbour.color != EMPTY and neighbour.removestamp != TIMESTAMP: neighbour_ref = neighbour.find(update) if neighbour_ref.pos == reference.pos: neighbour.remove(reference, update) else: if update: neighbour_ref.ledges += 1 def find(self, update=False): reference = self.reference if reference.pos != self.pos: reference = reference.find(update) if update: self.reference = reference return reference def __repr__(self): return repr(to_xy(self.pos)) class EmptySet[Board]: board: Board empties: List[int] empty_pos: List[int] def __init__(self, board): self.board = board self.empties = list(range(SIZE * SIZE)) self.empty_pos = list(range(SIZE * SIZE)) def random_choice(self): choices = len(self.empties) while choices: i = int(random.random() * choices) pos = self.empties[i] if self.board.useful(pos): return pos choices -= 1 self.set(i, self.empties[choices]) self.set(choices, pos) return PASS def add(self, pos): self.empty_pos[pos] = len(self.empties) self.empties.append(pos) def remove(self, pos): self.set(self.empty_pos[pos], self.empties[len(self.empties) - 1]) self.empties.pop() def set(self, i, pos): self.empties[i] = pos self.empty_pos[pos] = i class ZobristHash[Board]: board: Board hash_set: Set[int] hash: int def __init__(self, board): self.board = board self.hash_set = set() self.hash = 0 for square in self.board.squares: self.hash ^= square.zobrist_strings[EMPTY] self.hash_set.clear() self.hash_set.add(self.hash) def update(self, square, color): self.hash ^= square.zobrist_strings[square.color] self.hash ^= square.zobrist_strings[color] def add(self): self.hash_set.add(self.hash) def dupe(self): return self.hash in self.hash_set class Board: squares: List[Square[Board]] emptyset: EmptySet[Board] zobrist: ZobristHash[Board] color: int finished: bool lastmove: int history: List[int] white_dead: int black_dead: int def __init__(self): self.squares = [Square(self, pos) for pos in range(SIZE * SIZE)] for square in 
self.squares: square.set_neighbours() self.reset() def reset(self): for square in self.squares: square.color = EMPTY square.used = False self.emptyset = EmptySet(self) self.zobrist = ZobristHash(self) self.color = BLACK self.finished = False self.lastmove = -2 self.history = [] self.white_dead = 0 self.black_dead = 0 def move(self, pos): square = self.squares[pos] if pos != PASS: square.move(self.color) self.emptyset.remove(square.pos) elif self.lastmove == PASS: self.finished = True if self.color == BLACK: self.color = WHITE else: self.color = BLACK self.lastmove = pos self.history.append(pos) def random_move(self): return self.emptyset.random_choice() def useful_fast(self, square): if not square.used: for neighbour in square.neighbours: if neighbour.color == EMPTY: return True return False def useful(self, pos): global TIMESTAMP TIMESTAMP += 1 square = self.squares[pos] if self.useful_fast(square): return True old_hash = self.zobrist.hash self.zobrist.update(square, self.color) empties = opps = weak_opps = neighs = weak_neighs = 0 for neighbour in square.neighbours: neighcolor = neighbour.color if neighcolor == EMPTY: empties += 1 continue neighbour_ref = neighbour.find() if neighbour_ref.timestamp != TIMESTAMP: if neighcolor == self.color: neighs += 1 else: opps += 1 neighbour_ref.timestamp = TIMESTAMP neighbour_ref.temp_ledges = neighbour_ref.ledges neighbour_ref.temp_ledges -= 1 if neighbour_ref.temp_ledges == 0: if neighcolor == self.color: weak_neighs += 1 else: weak_opps += 1 neighbour_ref.remove(neighbour_ref, update=False) dupe = self.zobrist.dupe() self.zobrist.hash = old_hash strong_neighs = neighs - weak_neighs strong_opps = opps - weak_opps return not dupe and \ (empties or weak_opps or (strong_neighs and (strong_opps or weak_neighs))) def useful_moves(self): return [pos for pos in self.emptyset.empties if self.useful(pos)] def replay(self, history): for pos in history: self.move(pos) def score(self, color): if color == WHITE: count = KOMI + 
self.black_dead else: count = float(self.white_dead) for square in self.squares: squarecolor = square.color if squarecolor == color: count += 1 elif squarecolor == EMPTY: surround = 0 for neighbour in square.neighbours: if neighbour.color == color: surround += 1 if surround == len(square.neighbours): count += 1 return count def check(self): for square in self.squares: if square.color == EMPTY: continue members1 = set([square]) changed = True while changed: changed = False for member in members1.copy(): for neighbour in member.neighbours: if neighbour.color == square.color and neighbour not in members1: changed = True members1.add(neighbour) ledges1 = 0 for member in members1: for neighbour in member.neighbours: if neighbour.color == EMPTY: ledges1 += 1 root = square.find() # print 'members1', square, root, members1 # print 'ledges1', square, ledges1 members2 = set() for square2 in self.squares: if square2.color != EMPTY and square2.find() == root: members2.add(square2) ledges2 = root.ledges # print 'members2', square, root, members1 # print 'ledges2', square, ledges2 assert members1 == members2 assert ledges1 == ledges2 set(self.emptyset.empties) empties2 = set() for square in self.squares: if square.color == EMPTY: empties2.add(square.pos) def __repr__(self): result = [] for y in range(SIZE): start = to_pos(0, y) result.append(''.join( [SHOW[square.color] + ' ' for square in self.squares[start:start + SIZE]])) return '\n'.join(result) class UCTNode: bestchild: Optional[UCTNode] pos: int wins: int losses: int pos_child: List[Optional[UCTNode]] parent: Optional[UCTNode] unexplored: List[int] def __init__(self): self.bestchild = None self.pos = -1 self.wins = 0 self.losses = 0 self.pos_child = [None for x in range(SIZE * SIZE)] self.parent = None self.unexplored = [] def play(self, board): """ uct tree search """ color = board.color node = self path = [node] while True: pos = node.select(board) if pos == PASS: break board.move(pos) child = node.pos_child[pos] if not 
child: child = node.pos_child[pos] = UCTNode() child.unexplored = board.useful_moves() child.pos = pos child.parent = node path.append(child) break path.append(child) node = child self.random_playout(board) self.update_path(board, color, path) def select(self, board): """ select move; unexplored children first, then according to uct value """ if self.unexplored: i = random.randrange(len(self.unexplored)) pos = self.unexplored[i] self.unexplored[i] = self.unexplored[len(self.unexplored) - 1] self.unexplored.pop() return pos elif self.bestchild: return self.bestchild.pos else: return PASS def random_playout(self, board): """ random play until both players pass """ for x in range(MAXMOVES): # XXX while not self.finished? if board.finished: break board.move(board.random_move()) def update_path(self, board, color, path): """ update win/loss count along path """ wins = board.score(BLACK) >= board.score(WHITE) for node in path: if color == BLACK: color = WHITE else: color = BLACK if wins == (color == BLACK): node.wins += 1 else: node.losses += 1 if node.parent: node.parent.bestchild = node.parent.best_child() def score(self): winrate = self.wins / float(self.wins + self.losses) parentvisits = self.parent.wins + self.parent.losses if not parentvisits: return winrate nodevisits = self.wins + self.losses return winrate + math.sqrt((math.log(parentvisits)) / (5 * nodevisits)) def best_child(self): maxscore = -1. 
maxchild = None for child in self.pos_child: if child and child.score() > maxscore: maxchild = child maxscore = child.score() return maxchild def best_visited(self): maxvisits = -1 maxchild = None for child in self.pos_child: # if child: # print to_xy(child.pos), child.wins, child.losses, child.score() if child and (child.wins + child.losses) > maxvisits: maxvisits, maxchild = (child.wins + child.losses), child return maxchild # def user_move(board): # while True: # text = input('?').strip() # if text == 'p': # return PASS # if text == 'q': # raise EOFError # try: # x, y = [int(i) for i in text.split()] # except ValueError: # continue # if not (0 <= x < SIZE and 0 <= y < SIZE): # continue # pos = to_pos(x, y) # if board.useful(pos): # return pos def computer_move(board): pos = board.random_move() if pos == PASS: return PASS tree = UCTNode() tree.unexplored = board.useful_moves() nboard = Board() for game in range(GAMES): node = tree nboard.reset() nboard.replay(board.history) node.play(nboard) return tree.best_visited().pos def versus_cpu(): for i in range(100): random.seed(i) board = Board() computer_move(board) if __name__ == "__main__": t0 = time() versus_cpu() t1 = time() print(t1 - t0) ================================================ FILE: bench/codon/go.py ================================================ """ Go board game """ import math import random from time import time SIZE = 9 GAMES = 200 KOMI = 7.5 EMPTY, WHITE, BLACK = 0, 1, 2 SHOW = {EMPTY: '.', WHITE: 'o', BLACK: 'x'} PASS = -1 MAXMOVES = SIZE * SIZE * 3 TIMESTAMP = 0 MOVES = 0 def to_pos(x, y): return y * SIZE + x def to_xy(pos): y, x = divmod(pos, SIZE) return x, y class Square: def __init__(self, board, pos): self.board = board self.pos = pos self.timestamp = TIMESTAMP self.removestamp = TIMESTAMP self.zobrist_strings = [random.randrange(9223372036854775807) for i in range(3)] def set_neighbours(self): x, y = self.pos % SIZE, self.pos // SIZE self.neighbours = [] for dx, dy in [(-1, 0), (1, 0), 
(0, -1), (0, 1)]: newx, newy = x + dx, y + dy if 0 <= newx < SIZE and 0 <= newy < SIZE: self.neighbours.append(self.board.squares[to_pos(newx, newy)]) def move(self, color): global TIMESTAMP, MOVES TIMESTAMP += 1 MOVES += 1 self.board.zobrist.update(self, color) self.color = color self.reference = self self.ledges = 0 self.used = True for neighbour in self.neighbours: neighcolor = neighbour.color if neighcolor == EMPTY: self.ledges += 1 else: neighbour_ref = neighbour.find(update=True) if neighcolor == color: if neighbour_ref.reference.pos != self.pos: self.ledges += neighbour_ref.ledges neighbour_ref.reference = self self.ledges -= 1 else: neighbour_ref.ledges -= 1 if neighbour_ref.ledges == 0: neighbour.remove(neighbour_ref) self.board.zobrist.add() def remove(self, reference, update=True): self.board.zobrist.update(self, EMPTY) self.removestamp = TIMESTAMP if update: self.color = EMPTY self.board.emptyset.add(self.pos) # if color == BLACK: # self.board.black_dead += 1 # else: # self.board.white_dead += 1 for neighbour in self.neighbours: if neighbour.color != EMPTY and neighbour.removestamp != TIMESTAMP: neighbour_ref = neighbour.find(update) if neighbour_ref.pos == reference.pos: neighbour.remove(reference, update) else: if update: neighbour_ref.ledges += 1 def find(self, update=False): reference = self.reference if reference.pos != self.pos: reference = reference.find(update) if update: self.reference = reference return reference def __repr__(self): return repr(to_xy(self.pos)) class EmptySet: def __init__(self, board): self.board = board self.empties = list(range(SIZE * SIZE)) self.empty_pos = list(range(SIZE * SIZE)) def random_choice(self): choices = len(self.empties) while choices: i = int(random.random() * choices) pos = self.empties[i] if self.board.useful(pos): return pos choices -= 1 self.set(i, self.empties[choices]) self.set(choices, pos) return PASS def add(self, pos): self.empty_pos[pos] = len(self.empties) self.empties.append(pos) def remove(self, 
pos): self.set(self.empty_pos[pos], self.empties[len(self.empties) - 1]) self.empties.pop() def set(self, i, pos): self.empties[i] = pos self.empty_pos[pos] = i class ZobristHash: def __init__(self, board): self.board = board self.hash_set = set() self.hash = 0 for square in self.board.squares: self.hash ^= square.zobrist_strings[EMPTY] self.hash_set.clear() self.hash_set.add(self.hash) def update(self, square, color): self.hash ^= square.zobrist_strings[square.color] self.hash ^= square.zobrist_strings[color] def add(self): self.hash_set.add(self.hash) def dupe(self): return self.hash in self.hash_set class Board: def __init__(self): self.squares = [Square(self, pos) for pos in range(SIZE * SIZE)] for square in self.squares: square.set_neighbours() self.reset() def reset(self): for square in self.squares: square.color = EMPTY square.used = False self.emptyset = EmptySet(self) self.zobrist = ZobristHash(self) self.color = BLACK self.finished = False self.lastmove = -2 self.history = [] self.white_dead = 0 self.black_dead = 0 def move(self, pos): square = self.squares[pos] if pos != PASS: square.move(self.color) self.emptyset.remove(square.pos) elif self.lastmove == PASS: self.finished = True if self.color == BLACK: self.color = WHITE else: self.color = BLACK self.lastmove = pos self.history.append(pos) def random_move(self): return self.emptyset.random_choice() def useful_fast(self, square): if not square.used: for neighbour in square.neighbours: if neighbour.color == EMPTY: return True return False def useful(self, pos): global TIMESTAMP TIMESTAMP += 1 square = self.squares[pos] if self.useful_fast(square): return True old_hash = self.zobrist.hash self.zobrist.update(square, self.color) empties = opps = weak_opps = neighs = weak_neighs = 0 for neighbour in square.neighbours: neighcolor = neighbour.color if neighcolor == EMPTY: empties += 1 continue neighbour_ref = neighbour.find() if neighbour_ref.timestamp != TIMESTAMP: if neighcolor == self.color: neighs += 1 
else: opps += 1 neighbour_ref.timestamp = TIMESTAMP neighbour_ref.temp_ledges = neighbour_ref.ledges neighbour_ref.temp_ledges -= 1 if neighbour_ref.temp_ledges == 0: if neighcolor == self.color: weak_neighs += 1 else: weak_opps += 1 neighbour_ref.remove(neighbour_ref, update=False) dupe = self.zobrist.dupe() self.zobrist.hash = old_hash strong_neighs = neighs - weak_neighs strong_opps = opps - weak_opps return not dupe and \ (empties or weak_opps or (strong_neighs and (strong_opps or weak_neighs))) def useful_moves(self): return [pos for pos in self.emptyset.empties if self.useful(pos)] def replay(self, history): for pos in history: self.move(pos) def score(self, color): if color == WHITE: count = KOMI + self.black_dead else: count = self.white_dead for square in self.squares: squarecolor = square.color if squarecolor == color: count += 1 elif squarecolor == EMPTY: surround = 0 for neighbour in square.neighbours: if neighbour.color == color: surround += 1 if surround == len(square.neighbours): count += 1 return count def check(self): for square in self.squares: if square.color == EMPTY: continue members1 = set([square]) changed = True while changed: changed = False for member in members1.copy(): for neighbour in member.neighbours: if neighbour.color == square.color and neighbour not in members1: changed = True members1.add(neighbour) ledges1 = 0 for member in members1: for neighbour in member.neighbours: if neighbour.color == EMPTY: ledges1 += 1 root = square.find() # print 'members1', square, root, members1 # print 'ledges1', square, ledges1 members2 = set() for square2 in self.squares: if square2.color != EMPTY and square2.find() == root: members2.add(square2) ledges2 = root.ledges # print 'members2', square, root, members1 # print 'ledges2', square, ledges2 assert members1 == members2 assert ledges1 == ledges2, ('ledges differ at %r: %d %d' % ( square, ledges1, ledges2)) set(self.emptyset.empties) empties2 = set() for square in self.squares: if square.color == 
EMPTY: empties2.add(square.pos) def __repr__(self): result = [] for y in range(SIZE): start = to_pos(0, y) result.append(''.join( [SHOW[square.color] + ' ' for square in self.squares[start:start + SIZE]])) return '\n'.join(result) class UCTNode: def __init__(self): self.bestchild = None self.pos = -1 self.wins = 0 self.losses = 0 self.pos_child = [None for x in range(SIZE * SIZE)] self.parent = None def play(self, board): """ uct tree search """ color = board.color node = self path = [node] while True: pos = node.select(board) if pos == PASS: break board.move(pos) child = node.pos_child[pos] if not child: child = node.pos_child[pos] = UCTNode() child.unexplored = board.useful_moves() child.pos = pos child.parent = node path.append(child) break path.append(child) node = child self.random_playout(board) self.update_path(board, color, path) def select(self, board): """ select move; unexplored children first, then according to uct value """ if self.unexplored: i = random.randrange(len(self.unexplored)) pos = self.unexplored[i] self.unexplored[i] = self.unexplored[len(self.unexplored) - 1] self.unexplored.pop() return pos elif self.bestchild: return self.bestchild.pos else: return PASS def random_playout(self, board): """ random play until both players pass """ for x in range(MAXMOVES): # XXX while not self.finished? 
if board.finished: break board.move(board.random_move()) def update_path(self, board, color, path): """ update win/loss count along path """ wins = board.score(BLACK) >= board.score(WHITE) for node in path: if color == BLACK: color = WHITE else: color = BLACK if wins == (color == BLACK): node.wins += 1 else: node.losses += 1 if node.parent: node.parent.bestchild = node.parent.best_child() def score(self): winrate = self.wins / float(self.wins + self.losses) parentvisits = self.parent.wins + self.parent.losses if not parentvisits: return winrate nodevisits = self.wins + self.losses return winrate + math.sqrt((math.log(parentvisits)) / (5 * nodevisits)) def best_child(self): maxscore = -1 maxchild = None for child in self.pos_child: if child and child.score() > maxscore: maxchild = child maxscore = child.score() return maxchild def best_visited(self): maxvisits = -1 maxchild = None for child in self.pos_child: # if child: # print to_xy(child.pos), child.wins, child.losses, child.score() if child and (child.wins + child.losses) > maxvisits: maxvisits, maxchild = (child.wins + child.losses), child return maxchild # def user_move(board): # while True: # text = input('?').strip() # if text == 'p': # return PASS # if text == 'q': # raise EOFError # try: # x, y = [int(i) for i in text.split()] # except ValueError: # continue # if not (0 <= x < SIZE and 0 <= y < SIZE): # continue # pos = to_pos(x, y) # if board.useful(pos): # return pos def computer_move(board): pos = board.random_move() if pos == PASS: return PASS tree = UCTNode() tree.unexplored = board.useful_moves() nboard = Board() for game in range(GAMES): node = tree nboard.reset() nboard.replay(board.history) node.play(nboard) return tree.best_visited().pos def versus_cpu(): for i in range(100): random.seed(i) board = Board() computer_move(board) if __name__ == "__main__": t0 = time() versus_cpu() t1 = time() print(t1 - t0) ================================================ FILE: bench/codon/mandelbrot.codon 
================================================ import time MAX = 1000 # maximum Mandelbrot iterations N = 4096 # width and height of image pixels = [0 for _ in range(N * N)] def scale(x, a, b): return a + (x/N)*(b - a) t0 = time.time() @par(gpu=True, collapse=2) for i in range(N): for j in range(N): c = complex(scale(j, -2.00, 0.47), scale(i, -1.12, 1.12)) z = 0j iteration = 0 while abs(z) <= 2 and iteration < MAX: z = z**2 + c iteration += 1 pixels[i*N + j] = int(255 * iteration/MAX) print(sum(pixels)) print(time.time() - t0) ================================================ FILE: bench/codon/mandelbrot.py ================================================ import time MAX = 1000 # maximum Mandelbrot iterations N = 4096 # width and height of image pixels = [0 for _ in range(N * N)] def scale(x, a, b): return a + (x/N)*(b - a) t0 = time.time() for i in range(N): for j in range(N): c = complex(scale(j, -2.00, 0.47), scale(i, -1.12, 1.12)) z = 0j iteration = 0 while abs(z) <= 2 and iteration < MAX: z = z**2 + c iteration += 1 pixels[i*N + j] = int(255 * iteration/MAX) print(sum(pixels)) print(time.time() - t0) ================================================ FILE: bench/codon/nbody.cpp ================================================ #include #include #include #include #include #include #include namespace { const double PI = 3.14159265358979323; const double SOLAR_MASS = 4 * PI * PI; const double DAYS_PER_YEAR = 365.24; struct Body { std::vector r, v; double m; }; std::unordered_map BODIES = { {"sun", {{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, SOLAR_MASS}}, {"jupiter", {{4.84143144246472090e+00, -1.16032004402742839e+00, -1.03622044471123109e-01}, {1.66007664274403694e-03 * DAYS_PER_YEAR, 7.69901118419740425e-03 * DAYS_PER_YEAR, -6.90460016972063023e-05 * DAYS_PER_YEAR}, 9.54791938424326609e-04 * SOLAR_MASS}}, {"saturn", {{8.34336671824457987e+00, 4.12479856412430479e+00, -4.03523417114321381e-01}, {-2.76742510726862411e-03 * DAYS_PER_YEAR, 4.99852801234917238e-03 * 
DAYS_PER_YEAR, 2.30417297573763929e-05 * DAYS_PER_YEAR}, 2.85885980666130812e-04 * SOLAR_MASS}}, {"uranus", {{1.28943695621391310e+01, -1.51111514016986312e+01, -2.23307578892655734e-01}, {2.96460137564761618e-03 * DAYS_PER_YEAR, 2.37847173959480950e-03 * DAYS_PER_YEAR, -2.96589568540237556e-05 * DAYS_PER_YEAR}, 4.36624404335156298e-05 * SOLAR_MASS}}, {"neptune", {{1.53796971148509165e+01, -2.59193146099879641e+01, 1.79258772950371181e-01}, {2.68067772490389322e-03 * DAYS_PER_YEAR, 1.62824170038242295e-03 * DAYS_PER_YEAR, -9.51592254519715870e-05 * DAYS_PER_YEAR}, 5.15138902046611451e-05 * SOLAR_MASS}}, }; template auto values(std::unordered_map &m) { std::vector v; v.reserve(m.size()); for (auto &e : m) v.push_back(&e.second); return v; } template auto combinations(const std::vector &v) { std::vector> p; auto n = v.size(); p.reserve(n); for (auto i = 0; i < n - 1; i++) for (auto j = i + 1; j < n; j++) p.push_back({v[i], v[j]}); return p; } std::vector SYSTEM = values(BODIES); auto PAIRS = combinations(SYSTEM); void advance(double dt, int n, std::vector &bodies = SYSTEM, std::vector> &pairs = PAIRS) { for (int i = 0; i < n; i++) { for (auto &pair : pairs) { double x1 = pair.first->r[0], y1 = pair.first->r[1], z1 = pair.first->r[2]; auto &v1 = pair.first->v; double m1 = pair.first->m; double x2 = pair.second->r[0], y2 = pair.second->r[1], z2 = pair.second->r[2]; auto &v2 = pair.second->v; double m2 = pair.second->m; double dx = x1 - x2, dy = y1 - y2, dz = z1 - z2; double mag = dt * std::pow((dx * dx + dy * dy + dz * dz), -1.5); double b1m = m1 * mag; double b2m = m2 * mag; v1[0] -= dx * b2m; v1[1] -= dy * b2m; v1[2] -= dz * b2m; v2[0] += dx * b1m; v2[1] += dy * b1m; v2[2] += dz * b1m; } for (auto *body : bodies) { auto &r = body->r; double vx = body->v[0], vy = body->v[1], vz = body->v[2]; r[0] += dt * vx; r[1] += dt * vy; r[2] += dt * vz; } } } void report_energy(std::vector &bodies = SYSTEM, std::vector> &pairs = PAIRS, double e = 0.0) { for (auto &pair : pairs) { 
double x1 = pair.first->r[0], y1 = pair.first->r[1], z1 = pair.first->r[2]; auto &v1 = pair.first->v; double m1 = pair.first->m; double x2 = pair.second->r[0], y2 = pair.second->r[1], z2 = pair.second->r[2]; auto &v2 = pair.second->v; double m2 = pair.second->m; double dx = x1 - x2, dy = y1 - y2, dz = z1 - z2; e -= (m1 * m2) / std::pow((dx * dx + dy * dy + dz * dz), 0.5); } for (auto *body : bodies) { double vx = body->v[0], vy = body->v[1], vz = body->v[2]; double m = body->m; e += m * (vx * vx + vy * vy + vz * vz) / 2.; } std::cout << e << std::endl; } void offset_momentum(Body &ref, std::vector &bodies = SYSTEM, double px = 0.0, double py = 0.0, double pz = 0.0) { for (auto *body : bodies) { double vx = body->v[0], vy = body->v[1], vz = body->v[2]; double m = body->m; px -= vx * m; py -= vy * m; pz -= vz * m; } auto &v = ref.v; double m = ref.m; v[0] = px / m; v[1] = py / m; v[2] = pz / m; } } // namespace int main(int argc, char *argv[]) { using clock = std::chrono::high_resolution_clock; using std::chrono::duration_cast; using std::chrono::milliseconds; auto t = clock::now(); std::string ref = "sun"; offset_momentum(BODIES[ref]); report_energy(); advance(0.01, std::atoi(argv[1])); report_energy(); std::cout << (duration_cast(clock::now() - t).count() / 1e3) << std::endl; } ================================================ FILE: bench/codon/nbody.py ================================================ # The Computer Language Benchmarks Game # http://benchmarksgame.alioth.debian.org/ # # originally by Kevin Carson # modified by Tupteq, Fredrik Johansson, and Daniel Nanz # modified by Maciej Fijalkowski # modified by @arshajii # 2to3 from time import time import sys def combinations(l): result = [] for x in range(len(l) - 1): ls = l[x+1:] for y in ls: result.append((l[x],y)) return result PI = 3.14159265358979323 SOLAR_MASS = 4 * PI * PI DAYS_PER_YEAR = 365.24 BODIES = { 'sun': ([0.0, 0.0, 0.0], [0.0, 0.0, 0.0], SOLAR_MASS), 'jupiter': ([4.84143144246472090e+00, 
-1.16032004402742839e+00, -1.03622044471123109e-01], [1.66007664274403694e-03 * DAYS_PER_YEAR, 7.69901118419740425e-03 * DAYS_PER_YEAR, -6.90460016972063023e-05 * DAYS_PER_YEAR], 9.54791938424326609e-04 * SOLAR_MASS), 'saturn': ([8.34336671824457987e+00, 4.12479856412430479e+00, -4.03523417114321381e-01], [-2.76742510726862411e-03 * DAYS_PER_YEAR, 4.99852801234917238e-03 * DAYS_PER_YEAR, 2.30417297573763929e-05 * DAYS_PER_YEAR], 2.85885980666130812e-04 * SOLAR_MASS), 'uranus': ([1.28943695621391310e+01, -1.51111514016986312e+01, -2.23307578892655734e-01], [2.96460137564761618e-03 * DAYS_PER_YEAR, 2.37847173959480950e-03 * DAYS_PER_YEAR, -2.96589568540237556e-05 * DAYS_PER_YEAR], 4.36624404335156298e-05 * SOLAR_MASS), 'neptune': ([1.53796971148509165e+01, -2.59193146099879641e+01, 1.79258772950371181e-01], [2.68067772490389322e-03 * DAYS_PER_YEAR, 1.62824170038242295e-03 * DAYS_PER_YEAR, -9.51592254519715870e-05 * DAYS_PER_YEAR], 5.15138902046611451e-05 * SOLAR_MASS) } SYSTEM = list(BODIES.values()) PAIRS = combinations(SYSTEM) def advance(dt, n, bodies=SYSTEM, pairs=PAIRS): for i in range(n): for (([x1, y1, z1], v1, m1), ([x2, y2, z2], v2, m2)) in pairs: dx = x1 - x2 dy = y1 - y2 dz = z1 - z2 mag = dt * ((dx * dx + dy * dy + dz * dz) ** (-1.5)) b1m = m1 * mag b2m = m2 * mag v1[0] -= dx * b2m v1[1] -= dy * b2m v1[2] -= dz * b2m v2[0] += dx * b1m v2[1] += dy * b1m v2[2] += dz * b1m for (r, [vx, vy, vz], m) in bodies: r[0] += dt * vx r[1] += dt * vy r[2] += dt * vz def report_energy(bodies=SYSTEM, pairs=PAIRS, e=0.0): for (((x1, y1, z1), v1, m1), ((x2, y2, z2), v2, m2)) in pairs: dx = x1 - x2 dy = y1 - y2 dz = z1 - z2 e -= (m1 * m2) / ((dx * dx + dy * dy + dz * dz) ** 0.5) for (r, [vx, vy, vz], m) in bodies: e += m * (vx * vx + vy * vy + vz * vz) / 2. 
print(e) def offset_momentum(ref, bodies=SYSTEM, px=0.0, py=0.0, pz=0.0): for (r, [vx, vy, vz], m) in bodies: px -= vx * m py -= vy * m pz -= vz * m (r, v, m) = ref v[0] = px / m v[1] = py / m v[2] = pz / m def main(n, ref='sun'): offset_momentum(BODIES[ref]) report_energy() advance(0.01, n) report_energy() t0 = time() main(int(sys.argv[1])) t1 = time() print(t1 - t0) ================================================ FILE: bench/codon/npbench.codon ================================================ import numpy as np from numpy.random import default_rng import npbench_lib as bench import time def run(b, prep, **kwargs): n = prep.__class__.__name__.split("(")[0] with time.timing(f"{n}.prep"): data = prep(**kwargs) with time.timing(f"{n}.run"): if isinstance(data, Tuple): return b(*data) else: return b(data) def rng_complex(shape, rng): return (rng.random(shape) + rng.random(shape) * 1j) def adi(N, TSTEPS, datatype=np.float64): u = np.fromfunction(lambda i, j: (i + N - j) / N, (N, N), dtype=datatype) return TSTEPS, N, u run(bench.adi, adi, N=500, TSTEPS=50) def arc_distance(N): rng = default_rng(42) t0, p0, t1, p1 = rng.random((N, )), rng.random((N, )), rng.random( (N, )), rng.random((N, )) return t0, p0, t1, p1 run(bench.arc_distance, arc_distance, N=10000000) def azimint_naive(N, npt): rng = default_rng(42) data, radius = rng.random((N, )), rng.random((N, )) return data, radius, npt run(bench.azimint_naive, azimint_naive, N=4000000, npt=1000) def azimint_hist(N, npt): rng = default_rng(42) data, radius = rng.random((N, )), rng.random((N, )) return data, radius, npt run(bench.azimint_hist, azimint_hist, N=40000000, npt=1000) def atax(M, N, datatype=np.float64): fn = datatype(N) x = np.fromfunction(lambda i: 1 + (i / fn), (N, ), dtype=datatype) A = np.fromfunction(lambda i, j: ((i + j) % N) / (5 * M), (M, N), dtype=datatype) return A, x run(bench.atax, atax, M=20000, N=25000) def bicg(M, N, datatype=np.float64): A = np.fromfunction(lambda i, j: (i * (j + 1) % N) / N, 
(N, M), dtype=datatype) p = np.fromfunction(lambda i: (i % M) / M, (M, ), dtype=datatype) r = np.fromfunction(lambda i: (i % N) / N, (N, ), dtype=datatype) return A, p, r run(bench.bicg, bicg, M=20000, N=25000) def cavity_flow(ny, nx, nt, nit, rho, nu): u = np.zeros((ny, nx), dtype=np.float64) v = np.zeros((ny, nx), dtype=np.float64) p = np.zeros((ny, nx), dtype=np.float64) dx = 2 / (nx - 1) dy = 2 / (ny - 1) dt = .1 / ((nx - 1) * (ny - 1)) return nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu run(bench.cavity_flow, cavity_flow, ny=101, nx=101, nt=700, nit=50, rho=1.0, nu=0.1) def channel_flow(ny, nx, nit, rho, nu, F): u = np.zeros((ny, nx), dtype=np.float64) v = np.zeros((ny, nx), dtype=np.float64) p = np.ones((ny, nx), dtype=np.float64) dx = 2 / (nx - 1) dy = 2 / (ny - 1) dt = .1 / ((nx - 1) * (ny - 1)) return nit, u, v, dt, dx, dy, p, rho, nu, F run(bench.channel_flow, channel_flow, ny=101, nx=101, nit=50, rho=1.0, nu=0.1, F=1.0) def cholesky(N, datatype=np.float64): A = np.empty((N, N), dtype=datatype) for i in range(N): A[i, :i + 1] = np.fromfunction(lambda j: (-j % N) / N + 1, (i + 1, ), dtype=datatype) A[i, i + 1:] = 0.0 A[i, i] = 1.0 A[:] = A @ np.transpose(A) return A run(bench.cholesky, cholesky, N=2000) def cholesky2(N, datatype=np.float64): A = np.empty((N, N), dtype=datatype) for i in range(N): A[i, :i + 1] = np.fromfunction(lambda j: (-j % N) / N + 1, (i + 1, ), dtype=datatype) A[i, i + 1:] = 0.0 A[i, i] = 1.0 A[:] = A @ np.transpose(A) return A run(bench.cholesky2, cholesky2, N=8000) def compute(M, N): rng = default_rng(42) array_1 = rng.uniform(0, 1000, size=(M, N)).astype(np.int64) array_2 = rng.uniform(0, 1000, size=(M, N)).astype(np.int64) a = np.int64(4) b = np.int64(3) c = np.int64(9) return array_1, array_2, a, b, c run(bench.compute, compute, M=16000, N=16000) def contour_integral(NR, NM, slab_per_bc, num_int_pts): rng = default_rng(42) Ham = rng_complex((slab_per_bc + 1, NR, NR), rng) int_pts = rng_complex((num_int_pts, ), rng) Y = 
rng_complex((NR, NM), rng) return NR, NM, slab_per_bc, Ham, int_pts, Y run(bench.contour_integral, contour_integral, NR=600, NM=1000, slab_per_bc=2, num_int_pts=32) def conv2d_bias(C_in, C_out, H, K, N, W): rng = default_rng(42) # NHWC data layout input = rng.random((N, H, W, C_in), dtype=np.float32) # Weights weights = rng.random((K, K, C_in, C_out), dtype=np.float32) bias = rng.random((C_out, ), dtype=np.float32) return input, weights, bias run(bench.conv2d_bias, conv2d_bias, N=8, C_in=3, C_out=16, K=20, H=256, W=256) def correlation(M, N, datatype=np.float64): float_n = datatype(N) data = np.fromfunction(lambda i, j: (i * j) / M + i, (N, M), dtype=datatype) return M, float_n, data run(bench.correlation, correlation, M=3200, N=4000) def covariance(M, N, datatype=np.float64): float_n = datatype(N) data = np.fromfunction(lambda i, j: (i * j) / M, (N, M), dtype=datatype) return M, float_n, data run(bench.covariance, covariance, M=3200, N=4000) def crc16(N): rng = default_rng(42) data = rng.integers(0, 256, size=(N, ), dtype=np.uint8) return data run(bench.crc16, crc16, N=1000000) def deriche(W, H, datatype=np.float64): alpha = datatype(0.25) imgIn = np.fromfunction(lambda i, j: ((313 * i + 991 * j) % 65536) / 65535.0, (W, H), dtype=datatype) return alpha, imgIn run(bench.deriche, deriche, W=7680, H=4320) def doitgen(NR, NQ, NP, datatype=np.float64): A = np.fromfunction(lambda i, j, k: ((i * j + k) % NP) / NP, (NR, NQ, NP), dtype=datatype) C4 = np.fromfunction(lambda i, j: (i * j % NP) / NP, (NP, NP), dtype=datatype) return NR, NQ, NP, A, C4 run(bench.doitgen, doitgen, NR=220, NQ=250, NP=512) def durbin(N, datatype=np.float64): r = np.fromfunction(lambda i: N + 1 - i, (N, ), dtype=datatype) return r run(bench.durbin, durbin, N=20000) def fdtd_2d(TMAX, NX, NY, datatype=np.float64): ex = np.fromfunction(lambda i, j: (i * (j + 1)) / NX, (NX, NY), dtype=datatype) ey = np.fromfunction(lambda i, j: (i * (j + 2)) / NY, (NX, NY), dtype=datatype) hz = np.fromfunction(lambda 
i, j: (i * (j + 3)) / NX, (NX, NY), dtype=datatype) _fict_ = np.fromfunction(lambda i: i, (TMAX, ), dtype=datatype) return TMAX, ex, ey, hz, _fict_ run(bench.fdtd_2d, fdtd_2d, TMAX=500, NX=1000, NY=1200) def floyd_warshall(N, datatype=np.int32): path = np.fromfunction(lambda i, j: i * j % 7i32 + 1i32, (N, N), dtype=datatype) for i in range(N): for j in range(N): if (i + j) % 13 == 0 or (i + j) % 7 == 0 or (i + j) % 11 == 0: path[i, j] = 999 return path run(bench.floyd_warshall, floyd_warshall, N=850) def gemm(NI, NJ, NK, datatype=np.float64): alpha = datatype(1.5) beta = datatype(1.2) C = np.fromfunction(lambda i, j: ((i * j + 1) % NI) / NI, (NI, NJ), dtype=datatype) A = np.fromfunction(lambda i, k: (i * (k + 1) % NK) / NK, (NI, NK), dtype=datatype) B = np.fromfunction(lambda k, j: (k * (j + 2) % NJ) / NJ, (NK, NJ), dtype=datatype) return alpha, beta, C, A, B run(bench.gemm, gemm, NI=7000, NJ=7500, NK=8000) def gemver(N, datatype=np.float64): alpha = datatype(1.5) beta = datatype(1.2) fn = datatype(N) A = np.fromfunction(lambda i, j: (i * j % N) / N, (N, N), dtype=datatype) u1 = np.fromfunction(lambda i: i, (N, ), dtype=datatype) u2 = np.fromfunction(lambda i: ((i + 1) / fn) / 2.0, (N, ), dtype=datatype) v1 = np.fromfunction(lambda i: ((i + 1) / fn) / 4.0, (N, ), dtype=datatype) v2 = np.fromfunction(lambda i: ((i + 1) / fn) / 6.0, (N, ), dtype=datatype) w = np.zeros((N, ), dtype=datatype) x = np.zeros((N, ), dtype=datatype) y = np.fromfunction(lambda i: ((i + 1) / fn) / 8.0, (N, ), dtype=datatype) z = np.fromfunction(lambda i: ((i + 1) / fn) / 9.0, (N, ), dtype=datatype) return alpha, beta, A, u1, v1, u2, v2, w, x, y, z run(bench.gemver, gemver, N=10000) def gesummv(N, datatype=np.float64): alpha = datatype(1.5) beta = datatype(1.2) A = np.fromfunction(lambda i, j: ((i * j + 1) % N) / N, (N, N), dtype=datatype) B = np.fromfunction(lambda i, j: ((i * j + 2) % N) / N, (N, N), dtype=datatype) x = np.fromfunction(lambda i: (i % N) / N, (N, ), dtype=datatype) return 
alpha, beta, A, B, x run(bench.gesummv, gesummv, N=14000) def go_fast(N): rng = default_rng(42) a = rng.random((N, N), dtype=np.float64) return a run(bench.go_fast, go_fast, N=20000) def gramschmidt(M, N, datatype=np.float64): rng = default_rng(42) A = rng.random((M, N), dtype=datatype) while np.linalg.matrix_rank(A) < N: A = rng.random((M, N), dtype=datatype) return A run(bench.gramschmidt, gramschmidt, M=600, N=500) def hdiff(I, J, K): rng = default_rng(42) # Define arrays in_field = rng.random((I + 4, J + 4, K)) out_field = rng.random((I, J, K)) coeff = rng.random((I, J, K)) return in_field, out_field, coeff run(bench.hdiff, hdiff, I=384, J=384, K=160) def heat_3d(N, TSTEPS, datatype=np.float64): A = np.fromfunction(lambda i, j, k: (i + j + (N - k)) * 10 / N, (N, N, N), dtype=datatype) B = A.copy() # TODO: np.copy(A) return TSTEPS, A, B run(bench.heat_3d, heat_3d, N=70, TSTEPS=100) def jacobi_1d(N, TSTEPS, datatype=np.float64): A = np.fromfunction(lambda i: (i + 2) / N, (N, ), dtype=datatype) B = np.fromfunction(lambda i: (i + 3) / N, (N, ), dtype=datatype) return TSTEPS, A, B run(bench.jacobi_1d, jacobi_1d, N=34000, TSTEPS=8500) def jacobi_2d(N, TSTEPS, datatype=np.float64): A = np.fromfunction(lambda i, j: i * (j + 2) / N, (N, N), dtype=datatype) B = np.fromfunction(lambda i, j: i * (j + 3) / N, (N, N), dtype=datatype) return TSTEPS, A, B run(bench.jacobi_2d, jacobi_2d, N=700, TSTEPS=200) def k2mm(NI, NJ, NK, NL, datatype=np.float64): alpha = datatype(1.5) beta = datatype(1.2) A = np.fromfunction(lambda i, j: ((i * j + 1) % NI) / NI, (NI, NK), dtype=datatype) B = np.fromfunction(lambda i, j: (i * (j + 1) % NJ) / NJ, (NK, NJ), dtype=datatype) C = np.fromfunction(lambda i, j: ((i * (j + 3) + 1) % NL) / NL, (NJ, NL), dtype=datatype) D = np.fromfunction(lambda i, j: (i * (j + 2) % NK) / NK, (NI, NL), dtype=datatype) return alpha, beta, A, B, C, D run(bench.k2mm, k2mm, NI=6000, NJ=6500, NK=7000, NL=7500) def k3mm(NI, NJ, NK, NL, NM, datatype=np.float64): A = 
np.fromfunction(lambda i, j: ((i * j + 1) % NI) / (5 * NI), (NI, NK), dtype=datatype) B = np.fromfunction(lambda i, j: ((i * (j + 1) + 2) % NJ) / (5 * NJ), (NK, NJ), dtype=datatype) C = np.fromfunction(lambda i, j: (i * (j + 3) % NL) / (5 * NL), (NJ, NM), dtype=datatype) D = np.fromfunction(lambda i, j: ((i * (j + 2) + 2) % NK) / (5 * NK), (NM, NL), dtype=datatype) return A, B, C, D run(bench.k3mm, k3mm, NI=5500, NJ=6000, NK=6500, NL=7000, NM=7500) def lenet(N, H, W): rng = default_rng(42) H_conv1 = H - 4 W_conv1 = W - 4 H_pool1 = H_conv1 // 2 W_pool1 = W_conv1 // 2 H_conv2 = H_pool1 - 4 W_conv2 = W_pool1 - 4 H_pool2 = H_conv2 // 2 W_pool2 = W_conv2 // 2 C_before_fc1 = 16 * H_pool2 * W_pool2 # NHWC data layout input = rng.random((N, H, W, 1), dtype=np.float32) # Weights conv1 = rng.random((5, 5, 1, 6), dtype=np.float32) conv1bias = rng.random((6, ), dtype=np.float32) conv2 = rng.random((5, 5, 6, 16), dtype=np.float32) conv2bias = rng.random((16, ), dtype=np.float32) fc1w = rng.random((C_before_fc1, 120), dtype=np.float32) fc1b = rng.random((120, ), dtype=np.float32) fc2w = rng.random((120, 84), dtype=np.float32) fc2b = rng.random((84, ), dtype=np.float32) fc3w = rng.random((84, 10), dtype=np.float32) fc3b = rng.random((10, ), dtype=np.float32) return ( input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b, N, C_before_fc1 ) run(bench.lenet, lenet, N=16, H=256, W=256) def lu(N, datatype=np.float64): A = np.empty((N, N), dtype=datatype) for i in range(N): A[i, :i + 1] = np.fromfunction(lambda j: (-j % N) / N + 1, (i + 1, ), dtype=datatype) A[i, i + 1:] = 0.0 A[i, i] = 1.0 A[:] = A @ np.transpose(A) return A run(bench.lu, lu, N=2000) def ludcmp(N, datatype=np.float64): A = np.empty((N, N), dtype=datatype) for i in range(N): A[i, :i + 1] = np.fromfunction(lambda j: (-j % N) / N + 1, (i + 1, ), dtype=datatype) A[i, i + 1:] = 0.0 A[i, i] = 1.0 A[:] = A @ np.transpose(A) fn = datatype(N) b = np.fromfunction(lambda i: (i + 1) / fn / 2.0 + 4.0, (N, 
), dtype=datatype) return A, b run(bench.ludcmp, ludcmp, N=2000) def mandelbrot1(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon): return xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon run(bench.mandelbrot1, mandelbrot1, xmin=-2.25, xmax=0.75, xn=1000, ymin=-1.25, ymax=1.25, yn=1000, maxiter=200, horizon=2.0) def mandelbrot2(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon): return xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon run(bench.mandelbrot2, mandelbrot2, xmin=-2.25, xmax=0.75, xn=1000, ymin=-1.25, ymax=1.25, yn=1000, maxiter=200, horizon=2.0) def mlp(C_in, N, S0, S1, S2): rng = default_rng(42) mlp_sizes = [S0, S1, S2] # [300, 100, 10] # Inputs input = np.random.rand(N, C_in).astype(np.float32) # Weights w1 = rng.random((C_in, mlp_sizes[0]), dtype=np.float32) b1 = rng.random((mlp_sizes[0], ), dtype=np.float32) w2 = rng.random((mlp_sizes[0], mlp_sizes[1]), dtype=np.float32) b2 = rng.random((mlp_sizes[1], ), dtype=np.float32) w3 = rng.random((mlp_sizes[1], mlp_sizes[2]), dtype=np.float32) b3 = rng.random((mlp_sizes[2], ), dtype=np.float32) return input, w1, b1, w2, b2, w3, b3 run(bench.mlp, mlp, C_in=3, N=8, S0=30000, S1=30000, S2=30000) def mvt(N, datatype=np.float64): x1 = np.fromfunction(lambda i: (i % N) / N, (N, ), dtype=datatype) x2 = np.fromfunction(lambda i: ((i + 1) % N) / N, (N, ), dtype=datatype) y_1 = np.fromfunction(lambda i: ((i + 3) % N) / N, (N, ), dtype=datatype) y_2 = np.fromfunction(lambda i: ((i + 4) % N) / N, (N, ), dtype=datatype) A = np.fromfunction(lambda i, j: (i * j % N) / N, (N, N), dtype=datatype) return x1, x2, y_1, y_2, A run(bench.mvt, mvt, N=22000) def nbody(N, tEnd, dt, softening, G): rng = default_rng(42) mass = 20.0 * np.ones((N, 1)) / N # total mass of particles is 20 pos = rng.random((N, 3)) # randomly selected positions and velocities vel = rng.random((N, 3)) Nt = int(np.ceil(tEnd / dt)) return mass, pos, vel, N, Nt, dt, G, softening run(bench.nbody, nbody, N=100, tEnd=10.0, dt=0.01, softening=0.1, G=1.0) def 
nussinov(N, datatype=np.int32): seq = np.fromfunction(lambda i: (i + 1i32) % 4i32, (N, ), dtype=datatype) return N, seq run(bench.nussinov, nussinov, N=500) def resnet(N, W, H, C1, C2): rng = default_rng(42) # Input input = rng.random((N, H, W, C1), dtype=np.float32) # Weights conv1 = rng.random((1, 1, C1, C2), dtype=np.float32) conv2 = rng.random((3, 3, C2, C2), dtype=np.float32) conv3 = rng.random((1, 1, C2, C1), dtype=np.float32) return input, conv1, conv2, conv3 run(bench.resnet, resnet, N=8, W=56, H=56, C1=256, C2=64) def scattering_self_energies(Nkz, NE, Nqz, Nw, N3D, NA, NB, Norb): rng = default_rng(42) neigh_idx = np.empty((NA, NB), dtype=np.int32) for i in range(NA): neigh_idx[i] = np.positive(np.arange(i - NB / 2, i + NB / 2) % NA) dH = rng_complex((NA, NB, N3D, Norb, Norb), rng) G = rng_complex((Nkz, NE, NA, Norb, Norb), rng) D = rng_complex((Nqz, Nw, NA, NB, N3D, N3D), rng) Sigma = np.zeros((Nkz, NE, NA, Norb, Norb), dtype=np.complex128) return neigh_idx, dH, G, D, Sigma run(bench.scattering_self_energies, scattering_self_energies, Nkz=4, NE=10, Nqz=4, Nw=3, N3D=3, NA=20, NB=4, Norb=4) def seidel_2d(N, TSTEPS, datatype=np.float64): A = np.fromfunction(lambda i, j: (i * (j + 2) + 2) / N, (N, N), dtype=datatype) return TSTEPS, N, A run(bench.seidel_2d, seidel_2d, N=400, TSTEPS=100) def softmax(N, H, SM): rng = default_rng(42) x = rng.random((N, H, SM, SM), dtype=np.float32) return x run(bench.softmax, softmax, N=64, H=16, SM=512) def spmv(M, N, nnz): from python import numpy as NP from python import numpy.random as NR from python import scipy.sparse as SS rng = NR.default_rng(42) x: np.ndarray[float,1] = rng.random((N, )) matrix = SS.random( M, N, density=nnz / (M * N), format='csr', dtype=NP.float64, random_state=rng ) rows: np.ndarray[u32,1] = NP.uint32(matrix.indptr) cols: np.ndarray[u32,1] = NP.uint32(matrix.indices) vals: np.ndarray[float,1] = matrix.data return rows, cols, vals, x run(bench.spmv, spmv, M=131072, N=131072, nnz=262144) def 
stockham_fft(R, K): rng = default_rng(42) N = R**K X = rng_complex((N, ), rng) Y = np.zeros_like(X, dtype=np.complex128) return N, R, K, X, Y run(bench.stockham_fft, stockham_fft, R=2, K=21) def symm(M, N, datatype=np.float64): alpha = datatype(1.5) beta = datatype(1.2) C = np.fromfunction(lambda i, j: ((i + j) % 100) / M, (M, N), dtype=datatype) B = np.fromfunction(lambda i, j: ((N + i - j) % 100) / M, (M, N), dtype=datatype) A = np.empty((M, M), dtype=datatype) for i in range(M): A[i, :i + 1] = np.fromfunction(lambda j: ((i + j) % 100) / M, (i + 1, ), dtype=datatype) A[i, i + 1:] = -999 return alpha, beta, C, A, B run(bench.symm, symm, M=1000, N=1200) def syr2k(M, N, datatype=np.float64): alpha = datatype(1.5) beta = datatype(1.2) C = np.fromfunction(lambda i, j: ((i * j + 3) % N) / M, (N, N), dtype=datatype) A = np.fromfunction(lambda i, j: ((i * j + 1) % N) / N, (N, M), dtype=datatype) B = np.fromfunction(lambda i, j: ((i * j + 2) % M) / M, (N, M), dtype=datatype) return alpha, beta, C, A, B run(bench.syr2k, syr2k, M=350, N=400) def syrk(M, N, datatype=np.float64): alpha = datatype(1.5) beta = datatype(1.2) C = np.fromfunction(lambda i, j: ((i * j + 2) % N) / M, (N, N), dtype=datatype) A = np.fromfunction(lambda i, j: ((i * j + 1) % N) / N, (N, M), dtype=datatype) return alpha, beta, C, A run(bench.syr2k, syr2k, M=1000, N=1200) def trisolv(N, datatype=np.float64): L = np.fromfunction(lambda i, j: (i + N - j + 1) * 2 / N, (N, N), dtype=datatype) x = np.full((N, ), -999, dtype=datatype) b = np.fromfunction(lambda i: i, (N, ), dtype=datatype) return L, x, b run(bench.trisolv, trisolv, N=16000) def trmm(M, N, datatype=np.float64): alpha = datatype(1.5) A = np.fromfunction(lambda i, j: ((i * j) % M) / M, (M, M), dtype=datatype) for i in range(M): A[i, i] = 1.0 B = np.fromfunction(lambda i, j: ((N + i - j) % N) / N, (M, N), dtype=datatype) return alpha, A, B run(bench.trmm, trmm, M=1000, N=1200) def vadv(I, J, K): rng = default_rng(42) dtr_stage = 3. / 20. 
# Define arrays utens_stage = rng.random((I, J, K)) u_stage = rng.random((I, J, K)) wcon = rng.random((I + 1, J, K)) u_pos = rng.random((I, J, K)) utens = rng.random((I, J, K)) return utens_stage, u_stage, wcon, u_pos, utens, dtr_stage run(bench.vadv, vadv, I=256, J=256, K=160) ================================================ FILE: bench/codon/npbench_lib.codon ================================================ # Copyright 2021 ETH Zurich and the NPBench authors. All rights reserved. __CODON_RET__: Literal[bool] = True import numpy as np import numpy.pybridge def adi(TSTEPS: int, N: int, u: np.ndarray[float, 2]): v = np.empty(u.shape, dtype=u.dtype) p = np.empty(u.shape, dtype=u.dtype) q = np.empty(u.shape, dtype=u.dtype) DX = 1.0 / N DY = 1.0 / N DT = 1.0 / TSTEPS B1 = 2.0 B2 = 1.0 mul1 = B1 * DT / (DX * DX) mul2 = B2 * DT / (DY * DY) a = -mul1 / 2.0 b = 1.0 + mul2 c = a d = -mul2 / 2.0 e = 1.0 + mul2 f = d for t in range(1, TSTEPS + 1): v[0, 1:N - 1] = 1.0 p[1:N - 1, 0] = 0.0 q[1:N - 1, 0] = v[0, 1:N - 1] for j in range(1, N - 1): p[1:N - 1, j] = -c / (a * p[1:N - 1, j - 1] + b) q[1:N - 1, j] = (-d *u[j, 0:N - 2] + (1.0 + 2.0 * d) *u[j, 1:N - 1] - f *u[j, 2:N] - a * q[1:N - 1, j - 1]) / (a * p[1:N - 1, j - 1] + b) v[N - 1, 1:N - 1] = 1.0 for j in range(N - 2, 0, -1): v[j, 1:N - 1] = p[1:N - 1, j] * v[j + 1, 1:N - 1] + q[1:N - 1, j] u[1:N - 1, 0] = 1.0 p[1:N - 1, 0] = 0.0 q[1:N - 1, 0] = u[1:N - 1, 0] for j in range(1, N - 1): p[1:N - 1, j] = -f / (d * p[1:N - 1, j - 1] + e) q[1:N - 1, j] = (-a * v[0:N - 2, j] + (1.0 + 2.0 * a) * v[1:N - 1, j] - c * v[2:N, j] - d * q[1:N - 1, j - 1]) / (d * p[1:N - 1, j - 1] + e) u[1:N - 1, N - 1] = 1.0 for j in range(N - 2, 0, -1): u[1:N - 1, j] = p[1:N - 1, j] *u[1:N - 1, j + 1] + q[1:N - 1, j] def arc_distance(theta_1: np.ndarray[float, 1], phi_1: np.ndarray[float, 1], theta_2: np.ndarray[float, 1], phi_2: np.ndarray[float, 1]): """ Calculates the pairwise arc distance between all points in vector a and b. 
""" temp = np.sin((theta_2 - theta_1) / 2)**2 + np.cos(theta_1) * np.cos(theta_2) * np.sin( (phi_2 - phi_1) / 2)**2 distance_matrix = 2 * (np.arctan2(np.sqrt(temp), np.sqrt(1 - temp))) if __CODON_RET__: return distance_matrix def azimint_naive(data: np.ndarray[float, 1], radius: np.ndarray[float, 1], npt: int): rmax = radius.max() res = np.zeros(npt, dtype=np.float64) for i in range(npt): r1 = rmax * i / npt r2 = rmax * (i + 1) / npt mask_r12 = np.logical_and((r1 <= radius), (radius < r2)) values_r12 = data[mask_r12] res[i] = values_r12.mean() if __CODON_RET__: return res def azimint_hist(data: np.ndarray[float,1], radius: np.ndarray[float,1], npt: int): histu = np.histogram(radius, npt)[0] histw = np.histogram(radius, npt, weights=data)[0] res = histw / histu if __CODON_RET__: return res def atax(A: np.ndarray[float,2], x: np.ndarray[float,1]): res = (A @ x) @ A if __CODON_RET__: return res def bicg(A: np.ndarray[float,2], p: np.ndarray[float,1], r: np.ndarray[float,1]): res = r @ A, A @ p if __CODON_RET__: return res def cavity_flow(nx: int, ny: int, nt: int, nit: int, u: np.ndarray[float,2], v: np.ndarray[float,2], dt: float, dx: float, dy: float, p: np.ndarray[float,2], rho: float, nu: float): def build_up_b(b, rho, dt, u, v, dx, dy): b[1:-1, 1:-1] = (rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) + (v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) - ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 * ((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) * (v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) - ((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2)) def pressure_poisson(nit, p, dx, dy, b): pn = np.empty_like(p) pn = p.copy() for q in range(nit): pn = p.copy() p[1:-1, 1:-1] = (((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 + (pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) / (2 * (dx**2 + dy**2)) - dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, 1:-1]) p[:, -1] = p[:, -2] # dp/dx = 0 at x = 2 p[0, :] = p[1, :] # dp/dy = 0 at y = 0 p[:, 0] = p[:, 1] # dp/dx = 0 at x = 0 p[-1, :] = 0 # p = 
0 at y = 2 un = np.empty_like(u) vn = np.empty_like(v) b = np.zeros((ny, nx)) for n in range(nt): un = u.copy() vn = v.copy() build_up_b(b, rho, dt, u, v, dx, dy) pressure_poisson(nit, p, dx, dy, b) u[1:-1, 1:-1] = (un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * (un[1:-1, 1:-1] - un[1:-1, 0:-2]) - vn[1:-1, 1:-1] * dt / dy * (un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) * (p[1:-1, 2:] - p[1:-1, 0:-2]) + nu * (dt / dx**2 * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) + dt / dy**2 * (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1]))) v[1:-1, 1:-1] = (vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * (vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) - vn[1:-1, 1:-1] * dt / dy * (vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) * (p[2:, 1:-1] - p[0:-2, 1:-1]) + nu * (dt / dx**2 * (vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) + dt / dy**2 * (vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1]))) u[0, :] = 0 u[:, 0] = 0 u[:, -1] = 0 u[-1, :] = 1 # set velocity on cavity lid equal to 1 v[0, :] = 0 v[-1, :] = 0 v[:, 0] = 0 v[:, -1] = 0 def channel_flow(nit: int, u: np.ndarray[float, 2], v: np.ndarray[float, 2], dt: float, dx:float, dy: float, p: np.ndarray[float, 2], rho: float, nu: float, F: float): def build_up_b(rho, dt, dx, dy, u, v): b = np.zeros_like(u) b[1:-1, 1:-1] = (rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) + (v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) - ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 * ((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) * (v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) - ((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2)) # Periodic BC Pressure @ x = 2 b[1:-1, -1] = (rho * (1 / dt * ((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx) + (v[2:, -1] - v[0:-2, -1]) / (2 * dy)) - ((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx))**2 - 2 * ((u[2:, -1] - u[0:-2, -1]) / (2 * dy) * (v[1:-1, 0] - v[1:-1, -2]) / (2 * dx)) - ((v[2:, -1] - v[0:-2, -1]) / (2 * dy))**2)) # Periodic BC Pressure @ x = 0 b[1:-1, 0] = (rho * (1 / dt * ((u[1:-1, 1] - u[1:-1, -1]) / (2 
* dx) + (v[2:, 0] - v[0:-2, 0]) / (2 * dy)) - ((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx))**2 - 2 * ((u[2:, 0] - u[0:-2, 0]) / (2 * dy) * (v[1:-1, 1] - v[1:-1, -1]) / (2 * dx)) - ((v[2:, 0] - v[0:-2, 0]) / (2 * dy))**2)) return b def pressure_poisson_periodic(nit, p, dx, dy, b): pn = np.empty_like(p) for q in range(nit): pn = p.copy() p[1:-1, 1:-1] = (((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 + (pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) / (2 * (dx**2 + dy**2)) - dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, 1:-1]) # Periodic BC Pressure @ x = 2 p[1:-1, -1] = (((pn[1:-1, 0] + pn[1:-1, -2]) * dy**2 + (pn[2:, -1] + pn[0:-2, -1]) * dx**2) / (2 * (dx**2 + dy**2)) - dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, -1]) # Periodic BC Pressure @ x = 0 p[1:-1, 0] = (((pn[1:-1, 1] + pn[1:-1, -1]) * dy**2 + (pn[2:, 0] + pn[0:-2, 0]) * dx**2) / (2 * (dx**2 + dy**2)) - dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, 0]) # Wall boundary conditions, pressure p[-1, :] = p[-2, :] # dp/dy = 0 at y = 2 p[0, :] = p[1, :] # dp/dy = 0 at y = 0 udiff = 1.0 stepcount = 0 while udiff > .001: un = u.copy() vn = v.copy() b = build_up_b(rho, dt, dx, dy, u, v) pressure_poisson_periodic(nit, p, dx, dy, b) u[1:-1, 1:-1] = (un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * (un[1:-1, 1:-1] - un[1:-1, 0:-2]) - vn[1:-1, 1:-1] * dt / dy * (un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) * (p[1:-1, 2:] - p[1:-1, 0:-2]) + nu * (dt / dx**2 * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) + dt / dy**2 * (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1])) + F * dt) v[1:-1, 1:-1] = (vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * (vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) - vn[1:-1, 1:-1] * dt / dy * (vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) * (p[2:, 1:-1] - p[0:-2, 1:-1]) + nu * (dt / dx**2 * (vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) + dt / dy**2 * (vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1]))) # Periodic BC u @ x = 2 u[1:-1, -1] = ( un[1:-1, -1] - un[1:-1, -1] * dt / dx * (un[1:-1, 
-1] - un[1:-1, -2]) - vn[1:-1, -1] * dt / dy * (un[1:-1, -1] - un[0:-2, -1]) - dt / (2 * rho * dx) * (p[1:-1, 0] - p[1:-1, -2]) + nu * (dt / dx**2 * (un[1:-1, 0] - 2 * un[1:-1, -1] + un[1:-1, -2]) + dt / dy**2 * (un[2:, -1] - 2 * un[1:-1, -1] + un[0:-2, -1])) + F * dt) # Periodic BC u @ x = 0 u[1:-1, 0] = (un[1:-1, 0] - un[1:-1, 0] * dt / dx * (un[1:-1, 0] - un[1:-1, -1]) - vn[1:-1, 0] * dt / dy * (un[1:-1, 0] - un[0:-2, 0]) - dt / (2 * rho * dx) * (p[1:-1, 1] - p[1:-1, -1]) + nu * (dt / dx**2 * (un[1:-1, 1] - 2 * un[1:-1, 0] + un[1:-1, -1]) + dt / dy**2 * (un[2:, 0] - 2 * un[1:-1, 0] + un[0:-2, 0])) + F * dt) # Periodic BC v @ x = 2 v[1:-1, -1] = ( vn[1:-1, -1] - un[1:-1, -1] * dt / dx * (vn[1:-1, -1] - vn[1:-1, -2]) - vn[1:-1, -1] * dt / dy * (vn[1:-1, -1] - vn[0:-2, -1]) - dt / (2 * rho * dy) * (p[2:, -1] - p[0:-2, -1]) + nu * (dt / dx**2 * (vn[1:-1, 0] - 2 * vn[1:-1, -1] + vn[1:-1, -2]) + dt / dy**2 * (vn[2:, -1] - 2 * vn[1:-1, -1] + vn[0:-2, -1]))) # Periodic BC v @ x = 0 v[1:-1, 0] = (vn[1:-1, 0] - un[1:-1, 0] * dt / dx * (vn[1:-1, 0] - vn[1:-1, -1]) - vn[1:-1, 0] * dt / dy * (vn[1:-1, 0] - vn[0:-2, 0]) - dt / (2 * rho * dy) * (p[2:, 0] - p[0:-2, 0]) + nu * (dt / dx**2 * (vn[1:-1, 1] - 2 * vn[1:-1, 0] + vn[1:-1, -1]) + dt / dy**2 * (vn[2:, 0] - 2 * vn[1:-1, 0] + vn[0:-2, 0]))) # Wall BC: u,v = 0 @ y = 0,2 u[0, :] = 0 u[-1, :] = 0 v[0, :] = 0 v[-1, :] = 0 udiff = (np.sum(u) - np.sum(un)) / np.sum(u) stepcount += 1 if __CODON_RET__: return stepcount def cholesky2(A: np.ndarray[float,2]): A[:] = np.linalg.cholesky(A) + np.triu(A, k=1) def cholesky(A: np.ndarray[float,2]): A[0, 0] = np.sqrt(A[0, 0]) for i in range(1, A.shape[0]): for j in range(i): A[i, j] -= np.dot(A[i, :j], A[j, :j]) A[i, j] /= A[j, j] A[i, i] -= np.dot(A[i, :i], A[i, :i]) A[i, i] = np.sqrt(A[i, i]) def compute(array_1: np.ndarray[int,2], array_2: np.ndarray[int,2], a: int, b: int, c: int): res = np.clip(array_1, 2, 10) * a + array_2 * b + c if __CODON_RET__: return res def contour_integral(NR: 
int, NM: int, slab_per_bc: int,
                     Ham: np.ndarray[complex, 3],
                     int_pts: np.ndarray[complex, 1],
                     Y: np.ndarray[complex, 2]):
    """Accumulate P0 and P1 over the points in int_pts.

    For every z, Tz is the sum of z**(slab_per_bc / 2 - n) * Ham[n] for
    n in 0..slab_per_bc. X is inv(Tz) when NR == NM (Y is not used in that
    case), otherwise solve(Tz, Y); X is negated when abs(z) < 1. P0 sums X
    and P1 sums z * X.
    """
    P0 = np.zeros((NR, NM), dtype=np.complex128)
    P1 = np.zeros((NR, NM), dtype=np.complex128)
    for z in int_pts:
        Tz = np.zeros((NR, NR), dtype=np.complex128)
        for n in range(slab_per_bc + 1):
            zz = np.power(z, slab_per_bc / 2 - n)
            Tz += zz * Ham[n]
        if NR == NM:
            X = np.linalg.inv(Tz)
        else:
            X = np.linalg.solve(Tz, Y)
        if abs(z) < 1.0:
            X = -X
        P0 += X
        P1 += z * X

    res = P0, P1
    if __CODON_RET__:
        return res

def conv2d_bias(input: np.ndarray[float32,4], weights: np.ndarray[float32,4], bias: np.ndarray[float32,1]):
    """Apply the nested conv2d to input/weights and add bias to the result."""
    def conv2d(input, weights):
        # Slides a K x K window over axes 1 and 2 of input; every output
        # position sums window * weights over the window and input-channel axes.
        K = weights.shape[0]  # Assuming square kernel
        N = input.shape[0]
        H_out = input.shape[1] - K + 1
        W_out = input.shape[2] - K + 1
        C_out = weights.shape[3]
        output = np.empty((N, H_out, W_out, C_out), dtype=np.float32)

        # Loop structure adapted from https://github.com/SkalskiP/ILearnDeepLearning.py/blob/ba0b5ba589d4e656141995e8d1a06d44db6ce58d/01_mysteries_of_neural_networks/06_numpy_convolutional_neural_net/src/layers/convolutional.py#L88
        for i in range(H_out):
            for j in range(W_out):
                output[:, i, j, :] = np.sum(
                    input[:, i:i + K, j:j + K, :, np.newaxis] *
                    weights[np.newaxis, :, :, :],
                    axis=(1, 2, 3),
                )

        return output

    res = conv2d(input, weights) + bias
    if __CODON_RET__:
        return res

def correlation(M: int, float_n: float, data: np.ndarray[float,2]):
    """Return an M x M matrix built from column products of the normalised data.

    data is modified in place: the column mean is subtracted and each column
    is divided by sqrt(float_n) * stddev, with stddev values <= 0.1 replaced
    by 1.0. The result starts as the identity and is filled symmetrically.
    """
    mean = np.mean(data, axis=0)
    stddev = np.std(data, axis=0)
    stddev[stddev <= 0.1] = 1.0
    data -= mean
    data /= np.sqrt(float_n) * stddev
    corr = np.eye(M, dtype=data.dtype)
    for i in range(M - 1):
        # One product fills both the column below and the row right of (i, i).
        corr[i + 1:M, i] = corr[i, i + 1:M] = data[:, i] @ data[:, i + 1:M]

    if __CODON_RET__:
        return corr

def covariance(M: int, float_n: float, data: np.ndarray[float, 2]):
    """Return an M x M matrix of column products of the centred data / (float_n - 1).

    data is centred in place (column means subtracted).
    """
    mean = np.mean(data, axis=0)
    data -= mean
    cov = np.zeros((M, M), dtype=data.dtype)
    for i in range(M):
        # Row i and column i (from the diagonal on) receive the same values.
        cov[i:M, i] = cov[i, i:M] = data[:, i] @ data[:, i:M] / (float_n - 1.0)

    if __CODON_RET__:
        return cov

def crc16(data: 
np.ndarray[np.uint8, 1]):
    ''' CRC-16-CCITT Algorithm '''
    poly=0x8408
    crc = 0xFFFF
    for b in data:
        cur_byte = 0xFF & int(b)
        # Consume the byte one bit at a time, least-significant bit first.
        for _ in range(0, 8):
            if (crc & 0x0001) ^ (cur_byte & 0x0001):
                crc = (crc >> 1) ^ poly
            else:
                crc >>= 1
            cur_byte >>= 1
    # Invert, then swap the two bytes; the final mask keeps 16 bits.
    crc = (~crc & 0xFFFF)
    crc = (crc << 8) | ((crc >> 8) & 0xFF)

    res = crc & 0xFFFF
    if __CODON_RET__:
        return res

def deriche(alpha: float, imgIn: np.ndarray[float,2]):
    """Filter imgIn with forward/backward recursions along columns, then rows.

    y1 holds the forward recursion and y2 the backward one; their sum after
    the column pass is imgOut, which is then filtered the same way along
    rows. Returns imgOut; imgIn is only read.
    """
    k = (1.0 - np.exp(-alpha)) * (1.0 - np.exp(-alpha)) / (
        1.0 + alpha * np.exp(-alpha) - np.exp(2.0 * alpha))
    a1 = a5 = k
    a2 = a6 = k * np.exp(-alpha) * (alpha - 1.0)
    a3 = a7 = k * np.exp(-alpha) * (alpha + 1.0)
    a4 = a8 = -k * np.exp(-2.0 * alpha)
    b1 = 2.0**(-alpha)
    b2 = -np.exp(-2.0 * alpha)
    c1 = c2 = 1

    # Forward recursion over columns (left to right).
    y1 = np.empty_like(imgIn)
    y1[:, 0] = a1 * imgIn[:, 0]
    y1[:, 1] = a1 * imgIn[:, 1] + a2 * imgIn[:, 0] + b1 * y1[:, 0]
    for j in range(2, imgIn.shape[1]):
        y1[:, j] = (a1 * imgIn[:, j] + a2 * imgIn[:, j - 1] +
                    b1 * y1[:, j - 1] + b2 * y1[:, j - 2])

    # Backward recursion over columns (right to left).
    y2 = np.empty_like(imgIn)
    y2[:, -1] = 0.0
    y2[:, -2] = a3 * imgIn[:, -1]
    for j in range(imgIn.shape[1] - 3, -1, -1):
        y2[:, j] = (a3 * imgIn[:, j + 1] + a4 * imgIn[:, j + 2] +
                    b1 * y2[:, j + 1] + b2 * y2[:, j + 2])

    imgOut = c1 * (y1 + y2)

    # Forward recursion over rows (top to bottom), reusing y1.
    y1[0, :] = a5 * imgOut[0, :]
    y1[1, :] = a5 * imgOut[1, :] + a6 * imgOut[0, :] + b1 * y1[0, :]
    for i in range(2, imgIn.shape[0]):
        y1[i, :] = (a5 * imgOut[i, :] + a6 * imgOut[i - 1, :] +
                    b1 * y1[i - 1, :] + b2 * y1[i - 2, :])

    # Backward recursion over rows (bottom to top), reusing y2.
    y2[-1, :] = 0.0
    y2[-2, :] = a7 * imgOut[-1, :]
    for i in range(imgIn.shape[0] - 3, -1, -1):
        y2[i, :] = (a7 * imgOut[i + 1, :] + a8 * imgOut[i + 2, :] +
                    b1 * y2[i + 1, :] + b2 * y2[i + 2, :])

    imgOut[:] = c2 * (y1 + y2)

    if __CODON_RET__:
        return imgOut

def doitgen(NR: int, NQ: int, NP: int, A: np.ndarray[float,3], C4: np.ndarray[float,2]):
    """Overwrite A with A viewed as (NR, NQ, 1, NP) multiplied by C4, reshaped to (NR, NQ, NP)."""
    A[:] = np.reshape(np.reshape(A, (NR, NQ, 1, NP)) @ C4, (NR, NQ, NP))

def durbin(r: np.ndarray[float,1]):
    y = np.empty_like(r)
    alpha = -r[0]
    beta = 1.0
    y[0] = -r[0]

    for k in range(1, r.shape[0]):
        beta *= 1.0 - alpha * alpha
        alpha = -(r[k] + 
np.dot(np.flip(r[:k]), y[:k])) / beta
        # Update the first k entries from their reversed copy, then append alpha.
        y[:k] += alpha * np.flip(y[:k])
        y[k] = alpha

    if __CODON_RET__:
        return y

def fdtd_2d(TMAX: int, ex: np.ndarray[float,2], ey: np.ndarray[float,2], hz: np.ndarray[float,2], _fict_: np.ndarray[float,1]):
    """Update ex, ey and hz in place for TMAX steps; row 0 of ey is set to _fict_[t] each step."""
    for t in range(TMAX):
        ey[0, :] = _fict_[t]
        ey[1:, :] -= 0.5 * (hz[1:, :] - hz[:-1, :])
        ex[:, 1:] -= 0.5 * (hz[:, 1:] - hz[:, :-1])
        hz[:-1, :-1] -= 0.7 * (ex[:-1, 1:] - ex[:-1, :-1] + ey[1:, :-1] -
                               ey[:-1, :-1])

def floyd_warshall(path: np.ndarray[Int[32],2]):
    """For each k, replace path in place by min(path, path[:, k] + path[k, :]) taken over all pairs."""
    for k in range(path.shape[0]):
        path[:] = np.minimum(path[:], np.add.outer(path[:, k], path[k, :]))

def gemm(alpha: float, beta: float, C: np.ndarray[float,2], A: np.ndarray[float,2], B: np.ndarray[float,2]):
    """Overwrite C with alpha * A @ B + beta * C."""
    C[:] = alpha * A @ B + beta * C

def gemver(alpha: float, beta: float, A: np.ndarray[float,2], u1: np.ndarray[float,1], v1: np.ndarray[float,1], u2: np.ndarray[float,1], v2: np.ndarray[float,1], w: np.ndarray[float,1], x: np.ndarray[float,1], y: np.ndarray[float,1], z: np.ndarray[float,1]):
    """Update A, x and w in place, in that order; each step uses the values written by the previous one."""
    A += np.outer(u1, v1) + np.outer(u2, v2)
    x += beta * y @ A + z
    w += alpha * A @ x

def gesummv(alpha: float, beta: float, A: np.ndarray[float,2], B: np.ndarray[float,2], x: np.ndarray[float,1]):
    """Return alpha * A @ x + beta * B @ x."""
    res = alpha * A @ x + beta * B @ x
    if __CODON_RET__:
        return res

def go_fast(a: np.ndarray[float,2]):
    """Return a plus the sum of tanh over a's diagonal entries a[i, i] for i < a.shape[0]."""
    trace = 0.0
    for i in range(a.shape[0]):
        trace += np.tanh(a[i, i])

    res = a + trace
    if __CODON_RET__:
        return res

def gramschmidt(A: np.ndarray[float,2]):
    """Return (Q, R) built column by column from A.

    A is modified in place: after column k is normalised into Q, its
    component R[k, j] * Q[:, k] is subtracted from every later column j.
    """
    Q = np.zeros_like(A)
    R = np.zeros((A.shape[1], A.shape[1]), dtype=A.dtype)

    for k in range(A.shape[1]):
        nrm = np.dot(A[:, k], A[:, k])
        R[k, k] = np.sqrt(nrm)
        Q[:, k] = A[:, k] / R[k, k]
        for j in range(k + 1, A.shape[1]):
            R[k, j] = np.dot(Q[:, k], A[:, j])
            A[:, j] -= Q[:, k] * R[k, j]

    res = Q, R
    if __CODON_RET__:
        return res

def hdiff(in_field: np.ndarray[float,3], out_field: np.ndarray[float,3], coeff: np.ndarray[float,3]):
    I, J, K = out_field.shape[0], out_field.shape[1], out_field.shape[2]
    lap_field = 4.0 * in_field[1:I + 3, 1:J + 3, :] - (
in_field[2:I + 4, 1:J + 3, :] + in_field[0:I + 2, 1:J + 3, :] + in_field[1:I + 3, 2:J + 4, :] + in_field[1:I + 3, 0:J + 2, :]) res = lap_field[1:, 1:J + 1, :] - lap_field[:-1, 1:J + 1, :] flx_field = np.where( (res * (in_field[2:I + 3, 2:J + 2, :] - in_field[1:I + 2, 2:J + 2, :])) > 0, 0, res, ) res = lap_field[1:I + 1, 1:, :] - lap_field[1:I + 1, :-1, :] fly_field = np.where( (res * (in_field[2:I + 2, 2:J + 3, :] - in_field[2:I + 2, 1:J + 2, :])) > 0, 0, res, ) out_field[:, :, :] = in_field[2:I + 2, 2:J + 2, :] - coeff[:, :, :] * ( flx_field[1:, :, :] - flx_field[:-1, :, :] + fly_field[:, 1:, :] - fly_field[:, :-1, :]) def heat_3d(TSTEPS: int, A: np.ndarray[float, 3], B: np.ndarray[float, 3]): for t in range(1, TSTEPS): B[1:-1, 1:-1, 1:-1] = (0.125 * (A[2:, 1:-1, 1:-1] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[:-2, 1:-1, 1:-1]) + 0.125 * (A[1:-1, 2:, 1:-1] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[1:-1, :-2, 1:-1]) + 0.125 * (A[1:-1, 1:-1, 2:] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[1:-1, 1:-1, 0:-2]) + A[1:-1, 1:-1, 1:-1]) A[1:-1, 1:-1, 1:-1] = (0.125 * (B[2:, 1:-1, 1:-1] - 2.0 * B[1:-1, 1:-1, 1:-1] + B[:-2, 1:-1, 1:-1]) + 0.125 * (B[1:-1, 2:, 1:-1] - 2.0 * B[1:-1, 1:-1, 1:-1] + B[1:-1, :-2, 1:-1]) + 0.125 * (B[1:-1, 1:-1, 2:] - 2.0 * B[1:-1, 1:-1, 1:-1] + B[1:-1, 1:-1, 0:-2]) + B[1:-1, 1:-1, 1:-1]) def jacobi_1d(TSTEPS: int, A: np.ndarray[float, 1], B: np.ndarray[float, 1]): for t in range(1, TSTEPS): B[1:-1] = 0.33333 * (A[:-2] + A[1:-1] + A[2:]) A[1:-1] = 0.33333 * (B[:-2] + B[1:-1] + B[2:]) def jacobi_2d(TSTEPS: int, A: np.ndarray[float, 2], B: np.ndarray[float, 2]): for t in range(1, TSTEPS): B[1:-1, 1:-1] = 0.2 * (A[1:-1, 1:-1] + A[1:-1, :-2] + A[1:-1, 2:] + A[2:, 1:-1] + A[:-2, 1:-1]) A[1:-1, 1:-1] = 0.2 * (B[1:-1, 1:-1] + B[1:-1, :-2] + B[1:-1, 2:] + B[2:, 1:-1] + B[:-2, 1:-1]) def k2mm(alpha: float, beta: float, A: np.ndarray[float,2], B: np.ndarray[float,2], C: np.ndarray[float,2], D: np.ndarray[float,2]): D[:] = alpha * A @ B @ C + beta * D def k3mm(A: np.ndarray[float,2], B: 
np.ndarray[float,2], C: np.ndarray[float,2], D: np.ndarray[float,2]):
    """Return the chained product A @ B @ C @ D."""
    res = A @ B @ C @ D
    if __CODON_RET__:
        return res

def lenet(input: np.ndarray[float32,4], conv1: np.ndarray[float32,4], conv1bias: np.ndarray[float32,1], conv2: np.ndarray[float32,4], conv2bias: np.ndarray[float32,1], fc1w: np.ndarray[float32,2], fc1b: np.ndarray[float32,1], fc2w: np.ndarray[float32,2], fc2b: np.ndarray[float32,1], fc3w: np.ndarray[float32,2], fc3b: np.ndarray[float32,1], N:int, C_before_fc1:int):
    """Forward pass: two conv2d + bias + relu + maxpool2d stages, then three dense layers.

    The pooled activations are reshaped to (N, C_before_fc1) before the dense
    layers; relu is applied after the first two dense layers only.
    """
    def relu(x):
        return np.maximum(x, 0)

    def conv2d(input, weights):
        # Slides a K x K window over axes 1 and 2 of input; every output
        # position sums window * weights over the window and input-channel axes.
        K = weights.shape[0]  # Assuming square kernel
        N = input.shape[0]
        H_out = input.shape[1] - K + 1
        W_out = input.shape[2] - K + 1
        C_out = weights.shape[3]
        output = np.empty((N, H_out, W_out, C_out), dtype=np.float32)

        # Loop structure adapted from https://github.com/SkalskiP/ILearnDeepLearning.py/blob/ba0b5ba589d4e656141995e8d1a06d44db6ce58d/01_mysteries_of_neural_networks/06_numpy_convolutional_neural_net/src/layers/convolutional.py#L88
        for i in range(H_out):
            for j in range(W_out):
                output[:, i, j, :] = np.sum(
                    input[:, i:i + K, j:j + K, :, np.newaxis] *
                    weights[np.newaxis, :, :, :],
                    axis=(1, 2, 3),
                )

        return output

    def maxpool2d(x):
        # Maximum over non-overlapping 2 x 2 windows along axes 1 and 2.
        output = np.empty(
            (x.shape[0], x.shape[1] // 2, x.shape[2] // 2, x.shape[3]),
            dtype=x.dtype)
        for i in range(x.shape[1] // 2):
            for j in range(x.shape[2] // 2):
                output[:, i, j, :] = np.max(x[:, 2 * i:2 * i + 2,
                                              2 * j:2 * j + 2, :],
                                            axis=(1, 2))
        return output

    x = relu(conv2d(input, conv1) + conv1bias)
    x = maxpool2d(x)
    x = relu(conv2d(x, conv2) + conv2bias)
    x = maxpool2d(x)
    x = np.reshape(x, (N, C_before_fc1))
    x = relu(x @ fc1w + fc1b)
    x = relu(x @ fc2w + fc2b)

    res = x @ fc3w + fc3b
    if __CODON_RET__:
        return res

def lu(A: np.ndarray[float, 2]):
    """Factorise A in place, row by row.

    Entries left of the diagonal are reduced by the partial dot product and
    divided by the pivot A[j, j]; entries from the diagonal on are only
    reduced.
    """
    for i in range(A.shape[0]):
        for j in range(i):
            A[i, j] -= A[i, :j] @ A[:j, j]
            A[i, j] /= A[j, j]
        for j in range(i, A.shape[0]):
            A[i, j] -= A[i, :i] @ A[:i, j]

def ludcmp(A: np.ndarray[float, 2], b: np.ndarray[float,1]):
    x = np.zeros_like(b)
    y = np.zeros_like(b)
# Factorise A in place: entries left of the diagonal are divided by the pivot A[j, j].
    for i in range(A.shape[0]):
        for j in range(i):
            A[i, j] -= A[i, :j] @ A[:j, j]
            A[i, j] /= A[j, j]
        for j in range(i, A.shape[0]):
            A[i, j] -= A[i, :i] @ A[:i, j]
    # Forward pass fills y from b using the entries left of the diagonal.
    for i in range(A.shape[0]):
        y[i] = b[i] - A[i, :i] @ y[:i]
    # Backward pass fills x from y using the diagonal and the entries right of it.
    for i in range(A.shape[0] - 1, -1, -1):
        x[i] = (y[i] - A[i, i + 1:] @ x[i + 1:]) / A[i, i]

    res = x, y
    if __CODON_RET__:
        return res

def mandelbrot1(xmin: float, xmax: float, ymin: float, ymax: float, xn: int, yn: int, maxiter: int, horizon: float):
    """Iterate Z = Z**2 + C on a yn x xn grid and return (Z, N).

    N records, per point, the last iteration index at which abs(Z) was still
    below horizon; entries equal to maxiter - 1 are reset to 0.
    """
    # NOTE(review): this assignment overwrites the horizon argument, so the
    # value passed by the caller is ignored -- confirm this is intended.
    horizon = 2.0
    X = np.linspace(xmin, xmax, xn, dtype=np.float64)
    Y = np.linspace(ymin, ymax, yn, dtype=np.float64)
    C = X + Y[:, None] * 1j
    N = np.zeros(C.shape, dtype=np.int64)
    Z = np.zeros(C.shape, dtype=np.complex128)
    for n in range(maxiter):
        I = np.less(abs(Z), horizon)
        N[I] = n
        Z[I] = Z[I]**2 + C[I]
    N[N == maxiter - 1] = 0

    res = Z, N
    if __CODON_RET__:
        return res

def mandelbrot2(xmin: float, xmax: float, ymin: float, ymax: float, xn: int, yn: int, itermax: int, horizon: float):
    """Iterate Z = Z*Z + C, dropping points once abs(Z) exceeds horizon.

    Returns (Z_.T, N_.T): for each point that exceeded horizon, Z_ holds its
    value at that moment and N_ the 1-based iteration count; other entries
    stay zero.
    """
    def mgrid(xn, yn):
        # Index grids: Xi[i, j] == i and Yi[i, j] == j.
        Xi = np.empty((xn, yn), dtype=np.int64)
        Yi = np.empty((xn, yn), dtype=np.int64)
        for i in range(xn):
            Xi[i, :] = i
        for j in range(yn):
            Yi[:, j] = j
        return Xi, Yi

    Xi, Yi = mgrid(xn, yn)
    X = np.linspace(xmin, xmax, xn, dtype=np.float64)[Xi]
    Y = np.linspace(ymin, ymax, yn, dtype=np.float64)[Yi]
    C = X + Y * 1j
    N_ = np.zeros(C.shape, dtype=np.int64)
    Z_ = np.zeros(C.shape, dtype=np.complex128)
    # Xi.shape = Yi.shape = C.shape = xn * yn
    Xi = Xi.reshape(xn * yn)
    Yi = Yi.reshape(xn * yn)
    C = C.reshape(xn * yn)

    Z = np.zeros(C.shape, np.complex128)
    for i in range(itermax):
        if not len(Z):
            break

        # Compute for relevant points only
        np.multiply(Z, Z, Z)
        np.add(Z, C, Z)

        # Failed convergence
        I = abs(Z) > horizon
        N_[Xi[I], Yi[I]] = i + 1
        Z_[Xi[I], Yi[I]] = Z[I]

        # Keep going with those who have not diverged yet
        np.logical_not(I, I)  # np.negative(I, I) not working any longer
        Z = Z[I]
        Xi, Yi = Xi[I], Yi[I]
        C = C[I]

    res = Z_.T, N_.T
    if __CODON_RET__:
        return res

def mlp(input: np.ndarray[float32, 2], w1: np.ndarray[float32,2], b1: 
np.ndarray[float32,1], w2: np.ndarray[float32,2], b2: np.ndarray[float32,1], w3: np.ndarray[float32, 2], b3: np.ndarray[float32, 1]): def relu(x): return np.maximum(x, 0) def softmax(x): tmp_max = np.max(x, axis=-1, keepdims=True) tmp_out = np.exp(x - tmp_max) tmp_sum = np.sum(tmp_out, axis=-1, keepdims=True) return tmp_out / tmp_sum x = relu(input @ w1 + b1) # x = np.array(x, dtype=np.float32) x = relu(x @ w2 + b2) # x = np.array(x, dtype=np.float32) x = softmax(x @ w3 + b3) # Softmax call can be omitted if necessary if __CODON_RET__: return x def mvt(x1: np.ndarray[float, 1], x2: np.ndarray[float, 1], y_1: np.ndarray[float,1], y_2: np.ndarray[float,1], A: np.ndarray[float, 2]): x1 += A @ y_1 x2 += y_2 @ A def nbody(mass: np.ndarray[float, 2], pos: np.ndarray[float, 2], vel: np.ndarray[float, 2], N: int, Nt: int, dt: float, G: float, softening: float): def getAcc(pos, mass, G, softening): """ Calculate the acceleration on each particle due to Newton's Law pos is an N x 3 matrix of positions mass is an N x 1 vector of masses G is Newton's Gravitational constant softening is the softening length a is N x 3 matrix of accelerations """ # positions r = [x,y,z] for all particles x = pos[:, 0:1] y = pos[:, 1:2] z = pos[:, 2:3] # matrix that stores all pairwise particle separations: r_j - r_i dx = x.T - x dy = y.T - y dz = z.T - z # matrix that stores 1/r^3 for all particle pairwise particle separations inv_r3 = (dx**2 + dy**2 + dz**2 + softening**2) inv_r3[inv_r3 > 0] = inv_r3[inv_r3 > 0]**(-1.5) ax = G * (dx * inv_r3) @ mass ay = G * (dy * inv_r3) @ mass az = G * (dz * inv_r3) @ mass # pack together the acceleration components a = np.hstack((ax, ay, az)) return a def getEnergy(pos, vel, mass, G): """ Get kinetic energy (KE) and potential energy (PE) of simulation pos is N x 3 matrix of positions vel is N x 3 matrix of velocities mass is an N x 1 vector of masses G is Newton's Gravitational constant KE is the kinetic energy of the system PE is the potential energy of the 
system """ # Kinetic Energy: # KE = 0.5 * np.sum(np.sum( mass * vel**2 )) KE = 0.5 * np.sum(mass * vel**2) # Potential Energy: # positions r = [x,y,z] for all particles x = pos[:, 0:1] y = pos[:, 1:2] z = pos[:, 2:3] # matrix that stores all pairwise particle separations: r_j - r_i dx = x.T - x dy = y.T - y dz = z.T - z # matrix that stores 1/r for all particle pairwise particle separations inv_r = np.sqrt(dx**2 + dy**2 + dz**2) inv_r[inv_r > 0] = 1.0 / inv_r[inv_r > 0] # sum over upper triangle, to count each interaction only once # PE = G * np.sum(np.sum(np.triu(-(mass*mass.T)*inv_r,1))) PE = G * np.sum(np.triu(-(mass * mass.T) * inv_r, 1)) return KE, PE # Convert to Center-of-Mass frame vel -= np.mean(mass * vel, axis=0) / np.mean(mass) # calculate initial gravitational accelerations acc = getAcc(pos, mass, G, softening) # calculate initial energy of system KE = np.empty(Nt + 1, dtype=np.float64) PE = np.empty(Nt + 1, dtype=np.float64) KE[0], PE[0] = getEnergy(pos, vel, mass, G) t = 0.0 # Simulation Main Loop for i in range(Nt): # (1/2) kick vel += acc * dt / 2.0 # drift pos += vel * dt # update accelerations acc = getAcc(pos, mass, G, softening) # (1/2) kick vel += acc * dt / 2.0 # update time t += dt # get energy of system KE[i + 1], PE[i + 1] = getEnergy(pos, vel, mass, G) res = KE, PE if __CODON_RET__: return res def nussinov(N: int, seq: np.ndarray[Int[32], 1]): def match(b1, b2): if b1 + b2 == 3: return 1 else: return 0 table = np.zeros((N, N), np.int32) for i in range(N - 1, -1, -1): for j in range(i + 1, N): if j - 1 >= 0: table[i, j] = max(table[i, j], table[i, j - 1]) if i + 1 < N: table[i, j] = max(table[i, j], table[i + 1, j]) if j - 1 >= 0 and i + 1 < N: if i < j - 1: table[i, j] = max(table[i, j], table[i + 1, j - 1] + Int[32](match(int(seq[i]), int(seq[j])))) else: table[i, j] = max(table[i, j], table[i + 1, j - 1]) for k in range(i + 1, j): table[i, j] = max(table[i, j], table[i, k] + table[k + 1, j]) if __CODON_RET__: return table def 
resnet(input: np.ndarray[float32, 4], conv1: np.ndarray[float32, 4], conv2: np.ndarray[float32, 4], conv3: np.ndarray[float32, 4]):
    """Forward pass: three conv2d + batchnorm2d stages, then relu(x + input).

    The output of the first convolution is written into the interior of a
    zero array that is two elements larger along axes 1 and 2; relu follows
    the first two batchnorm2d calls, and input is added before the last relu.
    """
    def relu(x):
        return np.maximum(x, 0)

    def conv2d(input, weights):
        # Same sliding-window sum as a single np.sum over axes (1, 2, 3),
        # written here as three nested reductions over axis 1.
        K = weights.shape[0]  # Assuming square kernel
        N = input.shape[0]
        H_out = input.shape[1] - K + 1
        W_out = input.shape[2] - K + 1
        C_out = weights.shape[3]
        output = np.empty((N, H_out, W_out, C_out), dtype=np.float32)

        # Loop structure adapted from https://github.com/SkalskiP/ILearnDeepLearning.py/blob/ba0b5ba589d4e656141995e8d1a06d44db6ce58d/01_mysteries_of_neural_networks/06_numpy_convolutional_neural_net/src/layers/convolutional.py#L88
        for i in range(H_out):
            for j in range(W_out):
                output[:, i, j, :] = np.sum(np.sum(
                    np.sum(input[:, i:i + K, j:j + K, :, np.newaxis] *
                           weights[np.newaxis, :, :, :],
                           axis=1),
                    axis=1),
                    axis=1)

        return output

    def batchnorm2d(x):
        # Subtracts the axis-0 mean and divides by sqrt(std + eps).
        # mean = np.mean(x, axis=0, keepdims=True)
        eps=1e-5
        mean = np.mean(x, axis=0)[np.newaxis, :, :, :]
        # std = np.std(x, axis=0, keepdims=True)
        std = np.std(x, axis=0)[np.newaxis, :, :, :]
        return (x - mean) / np.sqrt(std + eps)

    # Pad output of first convolution for second convolution
    padded = np.zeros((input.shape[0], input.shape[1] + 2, input.shape[2] + 2,
                       conv1.shape[3]))

    padded[:, 1:-1, 1:-1, :] = conv2d(input, conv1)
    x = batchnorm2d(padded)
    x = relu(x)

    x = conv2d(x, conv2)
    x = batchnorm2d(x)
    x = relu(x)
    x = conv2d(x, conv3)
    x = batchnorm2d(x)

    res = relu(x + input)
    if __CODON_RET__:
        return res

def scattering_self_energies(neigh_idx: np.ndarray[Int[32], 2], dH: np.ndarray[complex, 5], G: np.ndarray[complex, 5], D: np.ndarray[complex, 6], Sigma: np.ndarray[complex, 5]):
    """Accumulate into Sigma in place over every (k, E, q, w, i, j, a, b) index combination.

    Combinations with E - w < 0 are skipped. neigh_idx[a, b] selects which
    slice of G is multiplied with dH[a, b, i].
    """
    for k in range(G.shape[0]):
        for E in range(G.shape[1]):
            for q in range(D.shape[0]):
                for w in range(D.shape[1]):
                    for i in range(D.shape[-2]):
                        for j in range(D.shape[-1]):
                            for a in range(neigh_idx.shape[0]):
                                for b in range(neigh_idx.shape[1]):
                                    if E - w >= 0:
                                        dHG = G[k, E - w, int(neigh_idx[a, b])] @ dH[a, b, i]
                                        dHD = dH[a, b, j] * D[q, w, a, b, i, j]
                                        Sigma[k, E, 
a] += dHG @ dHD

def seidel_2d(TSTEPS: int, N: int, A: np.ndarray[float,2]):
    """Sweep A in place TSTEPS - 1 times.

    For each interior row the seven neighbours that do not depend on the
    current sweep of that row are added as a vector; the left neighbour is
    then added element by element, because each A[i, j] needs the already
    updated A[i, j - 1], before dividing by 9.
    """
    for t in range(0, TSTEPS - 1):
        for i in range(1, N - 1):
            A[i, 1:-1] += (A[i - 1, :-2] + A[i - 1, 1:-1] + A[i - 1, 2:] +
                           A[i, 2:] + A[i + 1, :-2] + A[i + 1, 1:-1] +
                           A[i + 1, 2:])
            for j in range(1, N - 1):
                A[i, j] += A[i, j - 1]
                A[i, j] /= 9.0

def softmax(x: np.ndarray[float32, 4]):
    """Return exp(x - max) / sum(exp(x - max)), with max and sum taken over the last axis."""
    tmp_max = np.max(x, axis=-1, keepdims=True)
    tmp_out = np.exp(x - tmp_max)
    tmp_sum = np.sum(tmp_out, axis=-1, keepdims=True)

    res = tmp_out / tmp_sum
    if __CODON_RET__:
        return res

def spmv(A_row: np.ndarray[u32,1], A_col: np.ndarray[u32,1], A_val: np.ndarray[float,1], x: np.ndarray[float,1]):
    """Return y with y[i] = A_val[s:e] @ x[A_col[s:e]], where s, e = A_row[i], A_row[i + 1]."""
    y = np.empty(A_row.size - 1, A_val.dtype)

    for i in range(A_row.size - 1):
        # The column indices are converted to int64 before indexing x.
        cols = A_col[int(A_row[i]):int(A_row[i + 1])].astype(np.int64)
        vals = A_val[int(A_row[i]):int(A_row[i + 1])]
        y[i] = vals @ x[cols]

    if __CODON_RET__:
        return y

def stockham_fft(N: int, R: int, K: int, x: np.ndarray[complex, 1], y: np.ndarray[complex, 1]):
    def mgrid_stockham(xn, yn):
        # Index grids: Xi[i, j] == i and Yi[i, j] == j (int32).
        Xi = np.empty((xn, yn), dtype=np.int32)
        Yi = np.empty((xn, yn), dtype=np.int32)
        for i in range(xn):
            Xi[i, :] = i
        for j in range(yn):
            Yi[:, j] = j
        return Xi, Yi

    # Generate DFT matrix for radix R.
    # Define transient variable for matrix.
    # i_coord, j_coord = np.mgrid[0:R, 0:R]
    i_coord, j_coord = mgrid_stockham(R, R)
    dft_mat = np.empty((R, R), dtype=np.complex128)
    dft_mat = np.exp(-2.0j * np.pi * i_coord * j_coord / R)
    # Move input x to output y
    # to avoid overwriting the input.
y[:] = x[:] # ii_coord, jj_coord = np.mgrid[0:R, 0:R**K] ii_coord, jj_coord = mgrid_stockham(R, R**K) # Main Stockham loop for i in range(K): # Stride permutation yv = np.reshape(y, (R**i, R, R**(K-i-1))) tmp_perm = np.transpose(yv, axes=(1, 0, 2)) # Twiddle Factor multiplication D = np.empty((R, R**i, R**(K - i - 1)), dtype=np.complex128) tmp = np.exp(-2.0j * np.pi * ii_coord[:, :R**i] * jj_coord[:, :R**i] / R**(i + 1)) D[:] = np.repeat(np.reshape(tmp, (R, R**i, 1)), R ** (K-i-1), axis=2) tmp_twid = np.reshape(tmp_perm, (N, )) * np.reshape(D, (N, )) # Product with Butterfly y[:] = np.reshape(dft_mat @ np.reshape(tmp_twid, (R, R**(K-1))), (N, )) def symm(alpha: float, beta: float, C: np.ndarray[float, 2], A: np.ndarray[float, 2], B: np.ndarray[float, 2]): temp2 = np.empty((C.shape[1], ), dtype=C.dtype) C *= beta for i in range(C.shape[0]): for j in range(C.shape[1]): C[:i, j] += alpha * B[i, j] * A[i, :i] temp2[j] = B[:i, j] @ A[i, :i] C[i, :] += alpha * B[i, :] * A[i, i] + alpha * temp2 def syr2k(alpha: float, beta: float, C: np.ndarray[float,2], A: np.ndarray[float,2], B: np.ndarray[float,2]): for i in range(A.shape[0]): C[i, :i + 1] *= beta for k in range(A.shape[1]): C[i, :i + 1] += (A[:i + 1, k] * alpha * B[i, k] + B[:i + 1, k] * alpha * A[i, k]) def syrk(alpha: float, beta: float, C: np.ndarray[float, 2], A: np.ndarray[float, 2]): for i in range(A.shape[0]): C[i, :i + 1] *= beta for k in range(A.shape[1]): C[i, :i + 1] += alpha * A[i, k] * A[:i + 1, k] def trisolv(L: np.ndarray[float, 2], x: np.ndarray[float, 1], b: np.ndarray[float,1]): for i in range(x.shape[0]): x[i] = (b[i] - L[i, :i] @ x[:i]) / L[i, i] def trmm(alpha: float, A: np.ndarray[float,2], B: np.ndarray[float,2]): for i in range(B.shape[0]): for j in range(B.shape[1]): B[i, j] += np.dot(A[i + 1:, i], B[i + 1:, j]) B *= alpha def vadv(utens_stage: np.ndarray[float, 3], u_stage: np.ndarray[float,3], wcon: np.ndarray[float,3], u_pos: np.ndarray[float,3], utens: np.ndarray[float,3], dtr_stage: 
float): BET_M = 0.5 BET_P = 0.5 I, J, K = utens_stage.shape[0], utens_stage.shape[1], utens_stage.shape[2] # ccol = np.ndarray((I, J, K), dtype=utens_stage.dtype) # dcol = np.ndarray((I, J, K), dtype=utens_stage.dtype) # data_col = np.ndarray((I, J), dtype=utens_stage.dtype) ccol = np.empty((I, J, K), dtype=utens_stage.dtype) dcol = np.empty((I, J, K), dtype=utens_stage.dtype) data_col = np.empty((I, J), dtype=utens_stage.dtype) for k in range(1): gcv = 0.25 * (wcon[1:, :, k + 1] + wcon[:-1, :, k + 1]) cs = gcv * BET_M ccol[:, :, k] = gcv * BET_P bcol = dtr_stage - ccol[:, :, k] # update the d column correction_term = -cs * (u_stage[:, :, k + 1] - u_stage[:, :, k]) dcol[:, :, k] = (dtr_stage * u_pos[:, :, k] + utens[:, :, k] + utens_stage[:, :, k] + correction_term) # Thomas forward divided = 1.0 / bcol ccol[:, :, k] = ccol[:, :, k] * divided dcol[:, :, k] = dcol[:, :, k] * divided for k in range(1, K - 1): gav = -0.25 * (wcon[1:, :, k] + wcon[:-1, :, k]) gcv = 0.25 * (wcon[1:, :, k + 1] + wcon[:-1, :, k + 1]) as_ = gav * BET_M cs = gcv * BET_M acol = gav * BET_P ccol[:, :, k] = gcv * BET_P bcol = dtr_stage - acol - ccol[:, :, k] # update the d column correction_term = -as_ * (u_stage[:, :, k - 1] - u_stage[:, :, k]) - cs * ( u_stage[:, :, k + 1] - u_stage[:, :, k]) dcol[:, :, k] = (dtr_stage * u_pos[:, :, k] + utens[:, :, k] + utens_stage[:, :, k] + correction_term) # Thomas forward divided = 1.0 / (bcol - ccol[:, :, k - 1] * acol) ccol[:, :, k] = ccol[:, :, k] * divided dcol[:, :, k] = (dcol[:, :, k] - (dcol[:, :, k - 1]) * acol) * divided for k in range(K - 1, K): gav = -0.25 * (wcon[1:, :, k] + wcon[:-1, :, k]) as_ = gav * BET_M acol = gav * BET_P bcol = dtr_stage - acol # update the d column correction_term = -as_ * (u_stage[:, :, k - 1] - u_stage[:, :, k]) dcol[:, :, k] = (dtr_stage * u_pos[:, :, k] + utens[:, :, k] + utens_stage[:, :, k] + correction_term) # Thomas forward divided = 1.0 / (bcol - ccol[:, :, k - 1] * acol) dcol[:, :, k] = (dcol[:, :, k] - 
(dcol[:, :, k - 1]) * acol) * divided for k in range(K - 1, K - 2, -1): datacol = dcol[:, :, k] data_col[:] = datacol utens_stage[:, :, k] = dtr_stage * (datacol - u_pos[:, :, k]) for k in range(K - 2, -1, -1): datacol = dcol[:, :, k] - ccol[:, :, k] * data_col[:, :] data_col[:] = datacol utens_stage[:, :, k] = dtr_stage * (datacol - u_pos[:, :, k]) ================================================ FILE: bench/codon/primes.codon ================================================ from sys import argv from time import time def is_prime(n): factors = 0 for i in range(2, n): if n % i == 0: factors += 1 return factors == 0 limit = int(argv[1]) total = 0 t0 = time() @par(schedule='dynamic') for i in range(2, limit): if is_prime(i): total += 1 t1 = time() print(total) print(t1 - t0) ================================================ FILE: bench/codon/primes.py ================================================ from sys import argv from time import time def is_prime(n): factors = 0 for i in range(2, n): if n % i == 0: factors += 1 return factors == 0 limit = int(argv[1]) total = 0 t0 = time() for i in range(2, limit): if is_prime(i): total += 1 t1 = time() print(total) print(t1 - t0) ================================================ FILE: bench/codon/set_partition.cpp ================================================ #include #include #include #include #include template using vec = std::vector; inline vec range(int start, int stop) { vec v(stop - start); uint j = 0; for (int i = start; i < stop; i++) v[j++] = i; return v; } inline bool conforms(const vec> &candidate, int minsize, int forgive) { int deficit = 0; for (const auto &p : candidate) { int need = minsize - static_cast(p.size()); if (need > 0) deficit += need; } return deficit <= forgive; } inline void partition_filtered(const vec &collection, std::function> &)> callback, int minsize = 1, int forgive = 0) { if (collection.size() == 1) { callback({collection}); return; } auto first = collection[0]; auto loop = [&](const vec> 
&smaller) { int n = 0; vec> candidate; candidate.reserve(smaller.size() + 1); vec rep; for (const auto &subset : smaller) { candidate.resize(n); for (int i = 0; i < n; i++) candidate[i] = smaller[i]; rep.clear(); rep.reserve(subset.size() + 1); rep.push_back(first); rep.insert(rep.end(), subset.begin(), subset.end()); candidate.push_back({rep}); for (int i = n + 1; i < smaller.size(); i++) candidate.push_back(smaller[i]); if (conforms(candidate, minsize, forgive)) callback(candidate); ++n; } candidate.clear(); candidate.push_back({first}); candidate.insert(candidate.end(), smaller.begin(), smaller.end()); if (conforms(candidate, minsize, forgive)) callback(candidate); }; vec new_collection(collection.begin() + 1, collection.end()); partition_filtered(new_collection, loop, minsize, forgive + 1); } int main(int argc, char *argv[]) { using clock = std::chrono::high_resolution_clock; using std::chrono::duration_cast; using std::chrono::milliseconds; auto t = clock::now(); int n = 1; int x = 0; auto callback = [&](const vec> &p) { auto copy = p; std::sort(copy.begin(), copy.end()); x += copy[copy.size() / 3][0]; }; auto something = range(1, std::atoi(argv[1])); partition_filtered(something, callback, 2); std::cout << x << std::endl; std::cout << (duration_cast(clock::now() - t).count() / 1e3) << std::endl; } ================================================ FILE: bench/codon/set_partition.py ================================================ # https://stackoverflow.com/questions/73473074/speed-up-set-partition-generation-by-skipping-ones-with-subsets-smaller-or-large import sys import time def conforms(candidate, minsize, forgive): """ Check if partition `candidate` is at most `forgive` additions from making all its elements conform to having minimum size `minsize` """ deficit = 0 for p in candidate: need = minsize - len(p) if need > 0: deficit += need # Is the deficit small enough? 
return (deficit <= forgive) def partition_filtered(collection, minsize=1, forgive=0): """ Generate partitions that contain at least `minsize` elements per set; allow `forgive` missing elements, which can get added in subsequent steps """ if len(collection) == 1: yield [ collection ] return first = collection[0] for smaller in partition_filtered(collection[1:], minsize, forgive=forgive+1): # insert `first` in each of the subpartition's subsets for n, subset in enumerate(smaller): candidate = smaller[:n] + [[ first ] + subset] + smaller[n+1:] if conforms(candidate, minsize, forgive): yield candidate # put `first` in its own subset candidate = [ [ first ] ] + smaller if conforms(candidate, minsize, forgive): yield candidate import time t = time.time() something = list(range(1, int(sys.argv[1]))) v = partition_filtered(something, minsize=2) x = 0 for p in v: p.sort() x += p[len(p) // 3][0] print(x) print(time.time() - t) ================================================ FILE: bench/codon/spectral_norm.py ================================================ """ MathWorld: "Hundred-Dollar, Hundred-Digit Challenge Problems", Challenge #3. http://mathworld.wolfram.com/Hundred-DollarHundred-DigitChallengeProblems.html The Computer Language Benchmarks Game http://benchmarksgame.alioth.debian.org/u64q/spectralnorm-description.html#spectralnorm Contributed by Sebastien Loisel Fixed by Isaac Gouy Sped up by Josh Goldfoot Dirtily sped up by Simon Descarpentries Concurrency by Jason Stitt Adapted for Codon by @arshajii """ from time import time DEFAULT_N = 260 def eval_A(i, j): return 1.0 / ((i + j) * (i + j + 1) // 2 + i + 1) def eval_times_u(func, u): return [func((i, u)) for i in range(len(list(u)))] def part_At_times_u(i_u): i, u = i_u partial_sum = 0. for j, u_j in enumerate(u): partial_sum += eval_A(j, i) * u_j return partial_sum def part_A_times_u(i_u): i, u = i_u partial_sum = 0. 
for j, u_j in enumerate(u): partial_sum += eval_A(i, j) * u_j return partial_sum def eval_AtA_times_u(u): return eval_times_u(part_At_times_u, eval_times_u(part_A_times_u, u)) def bench_spectral_norm(loops): range_it = range(loops) total = 0. for _ in range_it: u = [1.] * DEFAULT_N v = None for dummy in range(10): v = eval_AtA_times_u(u) u = eval_AtA_times_u(v) vBv = vv = 0. for ue, ve in zip(u, v): vBv += ue * ve vv += ve * ve total += vBv + vv return total t0 = time() print(bench_spectral_norm(100)) t1 = time() print(t1 - t0) ================================================ FILE: bench/codon/sum.py ================================================ # https://towardsdatascience.com/getting-started-with-pypy-ef4ba5cb431c import time t1 = time.time() nums = range(50000000) sum = 0 for k in nums: sum = sum + k print("Sum of 50000000 numbers is : ", sum) t2 = time.time() t = t2 - t1 print(t) ================================================ FILE: bench/codon/taq.cpp ================================================ #include #include #include #include #include #include #include #include #include #include namespace { template double mean(It begin, It end) { double sum = std::accumulate(begin, end, 0.0); double mean = sum / std::distance(begin, end); return mean; } template double stdev(It begin, It end) { auto n = std::distance(begin, end); double sum = std::accumulate(begin, end, 0.0); double mean = sum / n; double sq_sum = std::inner_product(begin, end, begin, 0.0); double stdev = std::sqrt(sq_sum / n - mean * mean); return stdev; } std::vector find_peaks(const std::vector &y) { int lag = 100; double threshold = 10.0; double influence = 0.5; int t = y.size(); std::vector signals(t); if (t <= lag) return signals; std::vector filtered_y; filtered_y.reserve(t); for (int i = 0; i < t; i++) filtered_y.push_back(i < lag ? 
y[i] : 0.0); std::vector avg_filter(t); std::vector std_filter(t); avg_filter[lag] = mean(y.begin(), y.begin() + lag); avg_filter[lag] = stdev(y.begin(), y.begin() + lag); for (int i = lag; i < t; i++) { if (std::abs(y[i] - avg_filter[i - 1]) > threshold * std_filter[i - 1]) { signals[i] = y[i] > avg_filter[i - 1] ? +1 : -1; filtered_y[i] = influence * y[i] + (1 - influence) * filtered_y[i - 1]; } else { signals[i] = 0; filtered_y[i] = y[i]; } avg_filter[i] = mean(filtered_y.begin() + (i - lag), filtered_y.begin() + i); std_filter[i] = stdev(filtered_y.begin() + (i - lag), filtered_y.begin() + i); } return signals; } std::pair, std::vector> process_data(const std::vector> &series) { std::unordered_map grouped; for (const auto &p : series) { auto bucket = p.first; auto volume = p.second; grouped[bucket] += volume; } std::vector> temp; temp.reserve(grouped.size()); for (const auto &p : grouped) temp.emplace_back(p.first, p.second); std::sort(temp.begin(), temp.end()); std::vector y; y.reserve(grouped.size()); for (const auto &p : temp) y.push_back(p.second); return {y, find_peaks(y)}; } const uint64_t BUCKET_SIZE = 1000000000; } // namespace int main(int argc, char *argv[]) { using clock = std::chrono::high_resolution_clock; using std::chrono::duration_cast; using std::chrono::milliseconds; auto t = clock::now(); std::unordered_map>> data; std::ifstream file(argv[1]); bool header = true; for (std::string line; std::getline(file, line);) { if (header) { header = false; continue; } std::stringstream ss(line); std::vector x; for (std::string field; std::getline(ss, field, '|');) x.push_back(field); if (x[0] == "END" || x[4] == "ENDP") continue; uint64_t timestamp = std::stoull(x[0]); std::string symbol = x[2]; long volume = std::stol(x[4]); data[symbol].emplace_back(timestamp / BUCKET_SIZE, volume); } for (auto &e : data) { auto symbol = e.first; auto &series = e.second; auto p = process_data(series); auto &signals = p.second; std::cout << symbol << " " << 
std::reduce(signals.begin(), signals.end()) << std::endl; } std::cout << (duration_cast(clock::now() - t).count() / 1e3) << std::endl; } ================================================ FILE: bench/codon/taq.py ================================================ # Parses TAQ file and performs volume peak detection from sys import argv from time import time from statistics import mean, stdev # https://stackoverflow.com/questions/22583391/peak-signal-detection-in-realtime-timeseries-data def find_peaks(y): lag = 100 threshold = 10.0 influence = 0.5 t = len(y) signals = [0. for _ in range(t)] if t <= lag: return signals filtered_y = [y[i] if i < lag else 0. for i in range(t)] avg_filter = [0. for _ in range(t)] std_filter = [0. for _ in range(t)] avg_filter[lag] = mean(y[:lag]) std_filter[lag] = stdev(y[:lag]) for i in range(lag, t): if abs(y[i] - avg_filter[i-1]) > threshold * std_filter[i-1]: signals[i] = +1 if y[i] > avg_filter[i-1] else -1 filtered_y[i] = influence*y[i] + (1 - influence)*filtered_y[i-1] else: signals[i] = 0 filtered_y[i] = y[i] avg_filter[i] = mean(filtered_y[i-lag:i]) std_filter[i] = stdev(filtered_y[i-lag:i]) return signals def process_data(series): grouped = {} for bucket, volume in series: grouped[bucket] = grouped.get(bucket, 0) + volume y = [float(t[1]) for t in sorted(grouped.items())] return y, find_peaks(y) BUCKET_SIZE = 1_000_000_000 t0 = time() data = {} with open(argv[1]) as f: header = True for line in f: if header: header = False continue x = line.split('|') if x[0] == 'END' or x[4] == 'ENDP': continue timestamp = int(x[0]) symbol = x[2] volume = int(x[4]) series = data.setdefault(symbol, []) series.append((timestamp // BUCKET_SIZE, volume)) for symbol, series in data.items(): y, signals = process_data(series) print(symbol, sum(signals)) t1 = time() print(t1 - t0) ================================================ FILE: bench/codon/word_count.cpp ================================================ #include #include #include #include #include 
#include using namespace std; int main(int argc, char *argv[]) { using clock = chrono::high_resolution_clock; using chrono::duration_cast; using chrono::milliseconds; auto t = clock::now(); cin.tie(nullptr); cout.sync_with_stdio(false); if (argc != 2) { cerr << "Expected one argument." << endl; return -1; } ifstream file(argv[1]); if (!file.is_open()) { cerr << "Could not open file: " << argv[1] << endl; return -1; } unordered_map map; for (string line; getline(file, line);) { istringstream sin(line); for (string word; sin >> word;) map[word] += 1; } cout << map.size() << endl; cout << (duration_cast(clock::now() - t).count() / 1e3) << endl; } ================================================ FILE: bench/codon/word_count.py ================================================ from sys import argv from time import time t0 = time() wc = {} filename = argv[-1] with open(filename) as f: for l in f: for w in l.split(): wc[w] = wc.get(w, 0) + 1 print(len(wc)) t1 = time() print(t1 - t0) ================================================ FILE: bench/run.sh ================================================ # set -x trap 'echo Exiting...; exit' INT get_data() { git clone https://github.com/exaloop/seq mkdir -p test cp -r seq/test/* test/ cp -r ../test/* test/ mkdir -p build mkdir -p data curl -L https://ftp.nyse.com/Historical%20Data%20Samples/DAILY%20TAQ/EQY_US_ALL_NBBO_20250102.gz | gzip -d -c | head -n10000000 > data/taq.txt curl -L https://hgdownload.soe.ucsc.edu/goldenPath/hg38/chromosomes/chr22.fa.gz | gzip -d -c > data/chr22.fa samtools faidx data/chr22.fa curl -L https://github.com/lh3/biofast/releases/download/biofast-data-v1/biofast-data-v1.tar.gz | tar zxvf - -C data curl -L http://cb.csail.mit.edu/cb/seq/nbt/sw-data.tar.bz2 | tar jxvf - -C data curl -L http://cb.csail.mit.edu/cb/seq/nbt/umi-data.bz2 | bzip2 -c -d > data/hgmm_100_R1.fastq run/exe/bench_fasta 25000000 > data/three_.fa samtools faidx data/three_.fa samtools faidx data/three_.fa THREE > data/three.fa rm -f 
data/three_.fa samtools view https://hgdownload.cse.ucsc.edu/goldenPath/hg19/encodeDCC/wgEncodeSydhRnaSeq/wgEncodeSydhRnaSeqK562Ifna6hPolyaAln.bam chr22 -b -o data/rnaseq.bam curl -L https://hgdownload.soe.ucsc.edu/goldenPath/hg19/chromosomes/chr22.fa.gz | gzip -d -c > data/chr22_hg19.fa samtools faidx data/chr22_hg19.fa samtools index data/rnaseq.bam } compile() { name=$1 path=$2 extra=$3 echo -n "====> C: ${name} ${path} " start=$(date +%s.%N) CODON_DEBUG=lt /usr/bin/time -f 'time=%e mem=%M exit=%x' \ codon build -release $extra $path -o run/exe/${name}.exe \ >run/log/${name}.compile.txt 2>&1 duration=$(echo "$(date +%s.%N) $start" | awk '{printf "%.1f", $1-$2}') echo "[$? run/log/${name}.compile.txt ${duration}]" } run() { name=$1 args="${@:2}" echo -n " R: $name $args " start=$(date +%s.%N) eval "/usr/bin/time -o run/log/${name}.time.txt -f 'time=%e mem=%M exit=%x' run/exe/${name}.exe $args >run/log/${name}.run.txt 2>&1" duration=$(echo "$(date +%s.%N) $start" | awk '{printf "%.1f", $1-$2}') echo "[$? 
run/log/${name}.run.txt ${duration}]" } # get_data mkdir -p run/exe mkdir -p run/log mkdir -p build bench() { path=$1 args=$2 dirname=$(basename $(dirname $path)) filename=$(basename $path) filename=${filename%.*} extra="" if [[ $path == *"seq/"* ]] then dirname="seq_${dirname}" extra="-plugin seq" fi name="${dirname}_${filename}" istart=$(date +%s.%N) compile $name $path "$extra" run $name $args duration=$(echo "$(date +%s.%N) $istart" | awk '{printf "%.1f", $1-$2}') echo " T: ${duration}" } bench ../test/stdlib/bisect_test.codon bench ../test/stdlib/cmath_test.codon bench ../test/stdlib/datetime_test.codon bench ../test/stdlib/heapq_test.codon bench ../test/stdlib/itertools_test.codon bench ../test/stdlib/math_test.codon bench ../test/stdlib/operator_test.codon bench ../test/stdlib/random_test.codon bench ../test/stdlib/re_test.codon bench ../test/stdlib/sort_test.codon bench ../test/stdlib/statistics_test.codon bench ../test/stdlib/str_test.codon bench ../test/numpy/test_dtype.codon bench ../test/numpy/test_fft.codon bench ../test/numpy/test_functional.codon bench ../test/numpy/test_fusion.codon bench ../test/numpy/test_indexing.codon bench ../test/numpy/test_io.codon bench ../test/numpy/test_lib.codon bench ../test/numpy/test_linalg.codon bench ../test/numpy/test_loops.codon bench ../test/numpy/test_misc.codon bench ../test/numpy/test_ndmath.codon bench ../test/numpy/test_npdatetime.codon bench ../test/numpy/test_pybridge.codon bench ../test/numpy/test_reductions.codon bench ../test/numpy/test_routines.codon bench ../test/numpy/test_sorting.codon bench ../test/numpy/test_statistics.codon bench ../test/numpy/test_window.codon bench ../test/numpy/random_tests/test_mt19937.codon bench ../test/numpy/random_tests/test_pcg64.codon bench ../test/numpy/random_tests/test_philox.codon bench ../test/numpy/random_tests/test_sfc64.codon bench codon/binary_trees.codon 20 # 6s bench codon/chaos.codon /dev/null # 1s bench codon/fannkuch.codon 11 # 6s bench codon/float.py bench 
codon/go.codon # TODO: bench codon/mandelbrot.codon bench codon/nbody.py 10000000 # 6s bench codon/npbench.codon # ...s bench codon/set_partition.py 15 # 15s bench codon/spectral_norm.py bench codon/primes.codon 100000 # 3s bench codon/sum.py bench codon/taq.py data/taq.txt # 10s bench codon/word_count.py data/taq.txt # 10s bench ../seq/test/core/align.codon # 10s bench ../seq/test/core/big.codon bench ../seq/test/core/bltin.codon bench ../seq/test/core/bwtsa.codon # 10s bench ../seq/test/core/containers.codon bench ../seq/test/core/formats.codon bench ../seq/test/core/kmers.codon bench ../seq/test/core/match.codon bench ../seq/test/core/proteins.codon bench ../seq/test/core/serialization.codon bench ../seq/test/pipeline/canonical_opt.codon bench ../seq/test/pipeline/interalign.codon # 25s bench ../seq/test/pipeline/prefetch.codon bench ../seq/test/pipeline/revcomp_opt.codon # 1s bench ../seq/test/bench/16mer.codon data/chr22.fa # 10s bench ../seq/test/bench/bedcov.codon "data/biofast-data-v1/ex-anno.bed data/biofast-data-v1/ex-rna.bed" # 25s bench ../seq/test/bench/cpg.codon data/chr22.fa # 1s bench ../seq/test/bench/fasta.codon 25000000 # 10s bench ../seq/test/bench/fastx.codon "data/chr22.fa data/biofast-data-v1/M_abscessus_HiSeq.fq" # 15s # TODO: ../seq/test/bench/fmindex.codon bench ../seq/test/bench/fqcnt.codon data/biofast-data-v1/M_abscessus_HiSeq.fq # 1s bench ../seq/test/bench/hamming.codon data/chr22.fa # 20s bench ../seq/test/bench/hash.codon data/chr22.fa # 15s bench ../seq/test/bench/kmercnt.codon data/biofast-data-v1/M_abscessus_HiSeq.fq # 25s bench ../seq/test/bench/knucleotide.codon " #include #include #include #include #include #include #include #include #if !(defined(__EXCEPTIONS) || defined(__cpp_exceptions) || defined(_CPPUNWIND) || defined(CMRC_NO_EXCEPTIONS)) #define CMRC_NO_EXCEPTIONS 1 #endif namespace cmrc { namespace detail { struct dummy; } } #define CMRC_DECLARE(libid) \ namespace cmrc { namespace detail { \ struct dummy; \ 
static_assert(std::is_same::value, "CMRC_DECLARE() must only appear at the global namespace"); \ } } \ namespace cmrc { namespace libid { \ cmrc::embedded_filesystem get_filesystem(); \ } } static_assert(true, "") namespace cmrc { class file { const char* _begin = nullptr; const char* _end = nullptr; public: using iterator = const char*; using const_iterator = iterator; iterator begin() const noexcept { return _begin; } iterator cbegin() const noexcept { return _begin; } iterator end() const noexcept { return _end; } iterator cend() const noexcept { return _end; } std::size_t size() const { return static_cast(std::distance(begin(), end())); } file() = default; file(iterator beg, iterator end) noexcept : _begin(beg), _end(end) {} }; class directory_entry; namespace detail { class directory; class file_data; class file_or_directory { union _data_t { class file_data* file_data; class directory* directory; } _data; bool _is_file = true; public: explicit file_or_directory(file_data& f) { _data.file_data = &f; } explicit file_or_directory(directory& d) { _data.directory = &d; _is_file = false; } bool is_file() const noexcept { return _is_file; } bool is_directory() const noexcept { return !is_file(); } const directory& as_directory() const noexcept { assert(!is_file()); return *_data.directory; } const file_data& as_file() const noexcept { assert(is_file()); return *_data.file_data; } }; class file_data { public: const char* begin_ptr; const char* end_ptr; file_data(const file_data&) = delete; file_data(const char* b, const char* e) : begin_ptr(b), end_ptr(e) {} }; inline std::pair split_path(const std::string& path) { auto first_sep = path.find("/"); if (first_sep == path.npos) { return std::make_pair(path, ""); } else { return std::make_pair(path.substr(0, first_sep), path.substr(first_sep + 1)); } } struct created_subdirectory { class directory& directory; class file_or_directory& index_entry; }; class directory { std::list _files; std::list _dirs; std::map _index; 
using base_iterator = std::map::const_iterator; public: directory() = default; directory(const directory&) = delete; created_subdirectory add_subdir(std::string name) & { _dirs.emplace_back(); auto& back = _dirs.back(); auto& fod = _index.emplace(name, file_or_directory{back}).first->second; return created_subdirectory{back, fod}; } file_or_directory* add_file(std::string name, const char* begin, const char* end) & { assert(_index.find(name) == _index.end()); _files.emplace_back(begin, end); return &_index.emplace(name, file_or_directory{_files.back()}).first->second; } const file_or_directory* get(const std::string& path) const { auto pair = split_path(path); auto child = _index.find(pair.first); if (child == _index.end()) { return nullptr; } auto& entry = child->second; if (pair.second.empty()) { // We're at the end of the path return &entry; } if (entry.is_file()) { // We can't traverse into a file. Stop. return nullptr; } // Keep going down return entry.as_directory().get(pair.second); } class iterator { base_iterator _base_iter; base_iterator _end_iter; public: using value_type = directory_entry; using difference_type = std::ptrdiff_t; using pointer = const value_type*; using reference = const value_type&; using iterator_category = std::input_iterator_tag; iterator() = default; explicit iterator(base_iterator iter, base_iterator end) : _base_iter(iter), _end_iter(end) {} iterator begin() const noexcept { return *this; } iterator end() const noexcept { return iterator(_end_iter, _end_iter); } inline value_type operator*() const noexcept; bool operator==(const iterator& rhs) const noexcept { return _base_iter == rhs._base_iter; } bool operator!=(const iterator& rhs) const noexcept { return !(*this == rhs); } iterator& operator++() noexcept { ++_base_iter; return *this; } iterator operator++(int) noexcept { auto cp = *this; ++_base_iter; return cp; } }; using const_iterator = iterator; iterator begin() const noexcept { return iterator(_index.begin(), 
_index.end()); } iterator end() const noexcept { return iterator(); } }; inline std::string normalize_path(std::string path) { while (path.find("/") == 0) { path.erase(path.begin()); } while (!path.empty() && (path.rfind("/") == path.size() - 1)) { path.pop_back(); } auto off = path.npos; while ((off = path.find("//")) != path.npos) { path.erase(path.begin() + static_cast(off)); } return path; } using index_type = std::map; } // detail class directory_entry { std::string _fname; const detail::file_or_directory* _item; public: directory_entry() = delete; explicit directory_entry(std::string filename, const detail::file_or_directory& item) : _fname(filename) , _item(&item) {} const std::string& filename() const & { return _fname; } std::string filename() const && { return std::move(_fname); } bool is_file() const { return _item->is_file(); } bool is_directory() const { return _item->is_directory(); } }; directory_entry detail::directory::iterator::operator*() const noexcept { assert(begin() != end()); return directory_entry(_base_iter->first, _base_iter->second); } using directory_iterator = detail::directory::iterator; class embedded_filesystem { // Never-null: const cmrc::detail::index_type* _index; const detail::file_or_directory* _get(std::string path) const { path = detail::normalize_path(path); auto found = _index->find(path); if (found == _index->end()) { return nullptr; } else { return found->second; } } public: explicit embedded_filesystem(const detail::index_type& index) : _index(&index) {} file open(const std::string& path) const { auto entry_ptr = _get(path); if (!entry_ptr || !entry_ptr->is_file()) { #ifdef CMRC_NO_EXCEPTIONS fprintf(stderr, "Error no such file or directory: %s\n", path.c_str()); abort(); #else throw std::system_error(make_error_code(std::errc::no_such_file_or_directory), path); #endif } auto& dat = entry_ptr->as_file(); return file{dat.begin_ptr, dat.end_ptr}; } bool is_file(const std::string& path) const noexcept { auto entry_ptr = 
_get(path); return entry_ptr && entry_ptr->is_file(); } bool is_directory(const std::string& path) const noexcept { auto entry_ptr = _get(path); return entry_ptr && entry_ptr->is_directory(); } bool exists(const std::string& path) const noexcept { return !!_get(path); } directory_iterator iterate_directory(const std::string& path) const { auto entry_ptr = _get(path); if (!entry_ptr) { #ifdef CMRC_NO_EXCEPTIONS fprintf(stderr, "Error no such file or directory: %s\n", path.c_str()); abort(); #else throw std::system_error(make_error_code(std::errc::no_such_file_or_directory), path); #endif } if (!entry_ptr->is_directory()) { #ifdef CMRC_NO_EXCEPTIONS fprintf(stderr, "Error not a directory: %s\n", path.c_str()); abort(); #else throw std::system_error(make_error_code(std::errc::not_a_directory), path); #endif } return entry_ptr->as_directory().begin(); } }; } #endif // CMRC_CMRC_HPP_INCLUDED ]==]) set(cmrc_hpp "${CMRC_INCLUDE_DIR}/cmrc/cmrc.hpp" CACHE INTERNAL "") set(_generate 1) if(EXISTS "${cmrc_hpp}") file(READ "${cmrc_hpp}" _current) if(_current STREQUAL hpp_content) set(_generate 0) endif() endif() file(GENERATE OUTPUT "${cmrc_hpp}" CONTENT "${hpp_content}" CONDITION ${_generate}) add_library(cmrc-base INTERFACE) target_include_directories(cmrc-base INTERFACE $) # Signal a basic C++11 feature to require C++11. target_compile_features(cmrc-base INTERFACE cxx_nullptr) set_property(TARGET cmrc-base PROPERTY INTERFACE_CXX_EXTENSIONS OFF) add_library(cmrc::base ALIAS cmrc-base) function(cmrc_add_resource_library name) set(args ALIAS NAMESPACE TYPE) cmake_parse_arguments(ARG "" "${args}" "" "${ARGN}") # Generate the identifier for the resource library's namespace set(ns_re "[a-zA-Z_][a-zA-Z0-9_]*") if(NOT DEFINED ARG_NAMESPACE) # Check that the library name is also a valid namespace if(NOT name MATCHES "${ns_re}") message(SEND_ERROR "Library name is not a valid namespace. 
Specify the NAMESPACE argument") endif() set(ARG_NAMESPACE "${name}") else() if(NOT ARG_NAMESPACE MATCHES "${ns_re}") message(SEND_ERROR "NAMESPACE for ${name} is not a valid C++ namespace identifier (${ARG_NAMESPACE})") endif() endif() set(libname "${name}") # Check that type is either "STATIC" or "OBJECT", or default to "STATIC" if # not set if(NOT DEFINED ARG_TYPE) set(ARG_TYPE STATIC) elseif(NOT "${ARG_TYPE}" MATCHES "^(STATIC|OBJECT)$") message(SEND_ERROR "${ARG_TYPE} is not a valid TYPE (STATIC and OBJECT are acceptable)") set(ARG_TYPE STATIC) endif() # Generate a library with the compiled in character arrays. string(CONFIGURE [=[ #include #include #include namespace cmrc { namespace @ARG_NAMESPACE@ { namespace res_chars { // These are the files which are available in this resource library $, > } namespace { const cmrc::detail::index_type& get_root_index() { static cmrc::detail::directory root_directory_; static cmrc::detail::file_or_directory root_directory_fod{root_directory_}; static cmrc::detail::index_type root_index; root_index.emplace("", &root_directory_fod); struct dir_inl { class cmrc::detail::directory& directory; }; dir_inl root_directory_dir{root_directory_}; (void)root_directory_dir; $, > $, > return root_index; } } cmrc::embedded_filesystem get_filesystem() { static auto& index = get_root_index(); return cmrc::embedded_filesystem{index}; } } // @ARG_NAMESPACE@ } // cmrc ]=] cpp_content @ONLY) get_filename_component(libdir "${CMAKE_CURRENT_BINARY_DIR}/__cmrc_${name}" ABSOLUTE) get_filename_component(lib_tmp_cpp "${libdir}/lib_.cpp" ABSOLUTE) string(REPLACE "\n " "\n" cpp_content "${cpp_content}") file(GENERATE OUTPUT "${lib_tmp_cpp}" CONTENT "${cpp_content}") get_filename_component(libcpp "${libdir}/lib.cpp" ABSOLUTE) add_custom_command(OUTPUT "${libcpp}" DEPENDS "${lib_tmp_cpp}" "${cmrc_hpp}" COMMAND ${CMAKE_COMMAND} -E copy_if_different "${lib_tmp_cpp}" "${libcpp}" COMMENT "Generating ${name} resource loader" ) # Generate the actual static 
library. Each source file is just a single file # with a character array compiled in containing the contents of the # corresponding resource file. add_library(${name} ${ARG_TYPE} ${libcpp}) set_property(TARGET ${name} PROPERTY CMRC_LIBDIR "${libdir}") set_property(TARGET ${name} PROPERTY CMRC_NAMESPACE "${ARG_NAMESPACE}") target_link_libraries(${name} PUBLIC cmrc::base) set_property(TARGET ${name} PROPERTY CMRC_IS_RESOURCE_LIBRARY TRUE) if(ARG_ALIAS) add_library("${ARG_ALIAS}" ALIAS ${name}) endif() cmrc_add_resources(${name} ${ARG_UNPARSED_ARGUMENTS}) endfunction() function(_cmrc_register_dirs name dirpath) if(dirpath STREQUAL "") return() endif() # Skip this dir if we have already registered it get_target_property(registered "${name}" _CMRC_REGISTERED_DIRS) if(dirpath IN_LIST registered) return() endif() # Register the parent directory first get_filename_component(parent "${dirpath}" DIRECTORY) if(NOT parent STREQUAL "") _cmrc_register_dirs("${name}" "${parent}") endif() # Now generate the registration set_property(TARGET "${name}" APPEND PROPERTY _CMRC_REGISTERED_DIRS "${dirpath}") _cm_encode_fpath(sym "${dirpath}") if(parent STREQUAL "") set(parent_sym root_directory) else() _cm_encode_fpath(parent_sym "${parent}") endif() get_filename_component(leaf "${dirpath}" NAME) set_property( TARGET "${name}" APPEND PROPERTY CMRC_MAKE_DIRS "static auto ${sym}_dir = ${parent_sym}_dir.directory.add_subdir(\"${leaf}\")\;" "root_index.emplace(\"${dirpath}\", &${sym}_dir.index_entry)\;" ) endfunction() function(cmrc_add_resources name) get_target_property(is_reslib ${name} CMRC_IS_RESOURCE_LIBRARY) if(NOT TARGET ${name} OR NOT is_reslib) message(SEND_ERROR "cmrc_add_resources called on target '${name}' which is not an existing resource library") return() endif() set(options) set(args WHENCE PREFIX) set(list_args) cmake_parse_arguments(ARG "${options}" "${args}" "${list_args}" "${ARGN}") if(NOT ARG_WHENCE) set(ARG_WHENCE ${CMAKE_CURRENT_SOURCE_DIR}) endif() 
_cmrc_normalize_path(ARG_WHENCE) get_filename_component(ARG_WHENCE "${ARG_WHENCE}" ABSOLUTE) # Generate the identifier for the resource library's namespace get_target_property(lib_ns "${name}" CMRC_NAMESPACE) get_target_property(libdir ${name} CMRC_LIBDIR) get_target_property(target_dir ${name} SOURCE_DIR) file(RELATIVE_PATH reldir "${target_dir}" "${CMAKE_CURRENT_SOURCE_DIR}") if(reldir MATCHES "^\\.\\.") message(SEND_ERROR "Cannot call cmrc_add_resources in a parent directory from the resource library target") return() endif() foreach(input IN LISTS ARG_UNPARSED_ARGUMENTS) _cmrc_normalize_path(input) get_filename_component(abs_in "${input}" ABSOLUTE) # Generate a filename based on the input filename that we can put in # the intermediate directory. file(RELATIVE_PATH relpath "${ARG_WHENCE}" "${abs_in}") if(relpath MATCHES "^\\.\\.") # For now we just error on files that exist outside of the soure dir. message(SEND_ERROR "Cannot add file '${input}': File must be in a subdirectory of ${ARG_WHENCE}") continue() endif() if(DEFINED ARG_PREFIX) _cmrc_normalize_path(ARG_PREFIX) endif() if(ARG_PREFIX AND NOT ARG_PREFIX MATCHES "/$") set(ARG_PREFIX "${ARG_PREFIX}/") endif() get_filename_component(dirpath "${ARG_PREFIX}${relpath}" DIRECTORY) _cmrc_register_dirs("${name}" "${dirpath}") get_filename_component(abs_out "${libdir}/intermediate/${ARG_PREFIX}${relpath}.cpp" ABSOLUTE) # Generate a symbol name relpath the file's character array _cm_encode_fpath(sym "${relpath}") # Get the symbol name for the parent directory if(dirpath STREQUAL "") set(parent_sym root_directory) else() _cm_encode_fpath(parent_sym "${dirpath}") endif() # Generate the rule for the intermediate source file _cmrc_generate_intermediate_cpp(${lib_ns} ${sym} "${abs_out}" "${abs_in}") target_sources(${name} PRIVATE "${abs_out}") set_property(TARGET ${name} APPEND PROPERTY CMRC_EXTERN_DECLS "// Pointers to ${input}" "extern const char* const ${sym}_begin\;" "extern const char* const ${sym}_end\;" ) 
get_filename_component(leaf "${relpath}" NAME) set_property( TARGET ${name} APPEND PROPERTY CMRC_MAKE_FILES "root_index.emplace(" " \"${ARG_PREFIX}${relpath}\"," " ${parent_sym}_dir.directory.add_file(" " \"${leaf}\"," " res_chars::${sym}_begin," " res_chars::${sym}_end" " )" ")\;" ) endforeach() endfunction() function(_cmrc_generate_intermediate_cpp lib_ns symbol outfile infile) add_custom_command( # This is the file we will generate OUTPUT "${outfile}" # These are the primary files that affect the output DEPENDS "${infile}" "${_CMRC_SCRIPT}" COMMAND "${CMAKE_COMMAND}" -D_CMRC_GENERATE_MODE=TRUE -DNAMESPACE=${lib_ns} -DSYMBOL=${symbol} "-DINPUT_FILE=${infile}" "-DOUTPUT_FILE=${outfile}" -P "${_CMRC_SCRIPT}" COMMENT "Generating intermediate file for ${infile}" ) endfunction() function(_cm_encode_fpath var fpath) string(MAKE_C_IDENTIFIER "${fpath}" ident) string(MD5 hash "${fpath}") string(SUBSTRING "${hash}" 0 4 hash) set(${var} f_${hash}_${ident} PARENT_SCOPE) endfunction() ================================================ FILE: cmake/backtrace-config.h.in ================================================ /* config.h.cmake */ /* ELF size: 32 or 64 */ #define BACKTRACE_ELF_SIZE ${BACKTRACE_ELF_SIZE} /* XCOFF size: 32 or 64 */ #cmakedefine BACKTRACE_XCOFF_SIZE /* Define to 1 if you have the __atomic functions */ #cmakedefine HAVE_ATOMIC_FUNCTIONS 1 /* Define to 1 if you have the `clock_gettime' function. */ #cmakedefine HAVE_CLOCK_GETTIME 1 /* Define to 1 if you have the declaration of `strnlen', and to 0 if you don't. */ #cmakedefine HAVE_DECL_STRNLEN 1 /* Define to 1 if you have the header file. */ #cmakedefine HAVE_DLFCN_H 1 /* Define if dl_iterate_phdr is available. */ #cmakedefine HAVE_DL_ITERATE_PHDR 1 /* Define to 1 if you have the fcntl function */ #cmakedefine HAVE_FCNTL 1 /* Define if getexecname is available. */ #cmakedefine HAVE_GETEXECNAME 1 /* Define if _Unwind_GetIPInfo is available. 
*/
#cmakedefine HAVE_GETIPINFO 1

/* Define to 1 if you have the <inttypes.h> header file. */
#cmakedefine HAVE_INTTYPES_H 1

/* Define to 1 if you have the `z' library (-lz). */
#cmakedefine HAVE_LIBZ 1

/* Define to 1 if you have the <link.h> header file. */
#cmakedefine HAVE_LINK_H 1

/* Define if AIX loadquery is available. */
#cmakedefine HAVE_LOADQUERY 1

/* Define to 1 if you have the `lstat' function. */
#cmakedefine HAVE_LSTAT 1

/* Define to 1 if you have the <mach-o/dyld.h> header file. */
#cmakedefine HAVE_MACH_O_DYLD_H 1

/* Define to 1 if you have the <memory.h> header file. */
#cmakedefine HAVE_MEMORY_H 1

/* Define to 1 if you have the `readlink' function. */
#cmakedefine HAVE_READLINK 1

/* Define to 1 if you have the <stdint.h> header file. */
#cmakedefine HAVE_STDINT_H 1

/* Define to 1 if you have the <stdlib.h> header file. */
#cmakedefine HAVE_STDLIB_H 1

/* Define to 1 if you have the <strings.h> header file. */
#cmakedefine HAVE_STRINGS_H 1

/* Define to 1 if you have the <string.h> header file. */
#cmakedefine HAVE_STRING_H 1

/* Define to 1 if you have the __sync functions */
#cmakedefine HAVE_SYNC_FUNCTIONS 1

/* Define to 1 if you have the <sys/ldr.h> header file. */
#cmakedefine HAVE_SYS_LDR_H 1

/* Define to 1 if you have the <sys/mman.h> header file. */
#cmakedefine HAVE_SYS_MMAN_H 1

/* Define to 1 if you have the <sys/stat.h> header file. */
#cmakedefine HAVE_SYS_STAT_H 1

/* Define to 1 if you have the <sys/types.h> header file. */
#cmakedefine HAVE_SYS_TYPES_H 1

/* Define to 1 if you have the <unistd.h> header file. */
#cmakedefine HAVE_UNISTD_H 1

/* Define if -lz is available. */
#cmakedefine HAVE_ZLIB 1

================================================
FILE: cmake/backtrace-supported.h.in
================================================
/* backtrace-supported.h.cmake -- Whether stack backtrace is supported.
   Copyright (C) 2012-2016 Free Software Foundation, Inc.
   Based on backtrace-supported.h.in, written by Ian Lance Taylor, Google.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: (1) Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. (2) Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. (3) The name of the author may not be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ /* The file backtrace-supported.h.in is used by configure to generate the file backtrace-supported.h. The file backtrace-supported.h may be #include'd to see whether the backtrace library will be able to get a backtrace and produce symbolic information. */ /* BACKTRACE_SUPPORTED will be #define'd as 1 if the backtrace library should work, 0 if it will not. Libraries may #include this to make other arrangements. */ #cmakedefine01 BACKTRACE_SUPPORTED /* BACKTRACE_USES_MALLOC will be #define'd as 1 if the backtrace library will call malloc as it works, 0 if it will call mmap instead. 
This may be used to determine whether it is safe to call the backtrace functions from a signal handler. In general this only applies to calls like backtrace and backtrace_pcinfo. It does not apply to backtrace_simple, which never calls malloc. It does not apply to backtrace_print, which always calls fprintf and therefore malloc. */ #cmakedefine01 BACKTRACE_USES_MALLOC /* BACKTRACE_SUPPORTS_THREADS will be #define'd as 1 if the backtrace library is configured with threading support, 0 if not. If this is 0, the threaded parameter to backtrace_create_state must be passed as 0. */ #cmakedefine01 BACKTRACE_SUPPORTS_THREADS /* BACKTRACE_SUPPORTS_DATA will be #defined'd as 1 if the backtrace_syminfo will work for variables. It will always work for functions. */ #cmakedefine01 BACKTRACE_SUPPORTS_DATA ================================================ FILE: cmake/config.h.in ================================================ #pragma once #define CODON_VERSION "@PROJECT_VERSION@" #define CODON_VERSION_MAJOR @PROJECT_VERSION_MAJOR@ #define CODON_VERSION_MINOR @PROJECT_VERSION_MINOR@ #define CODON_VERSION_PATCH @PROJECT_VERSION_PATCH@ ================================================ FILE: cmake/config.py.in ================================================ __version__ = "@CODON_JIT_PYTHON_VERSION@" CODON_VERSION = "@PROJECT_VERSION@" CODON_VERSION_MAJOR = @PROJECT_VERSION_MAJOR@ CODON_VERSION_MINOR = @PROJECT_VERSION_MINOR@ CODON_VERSION_PATCH = @PROJECT_VERSION_PATCH@ ================================================ FILE: cmake/deps.cmake ================================================ set(CPM_DOWNLOAD_VERSION 0.40.8) set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake") if(NOT (EXISTS ${CPM_DOWNLOAD_LOCATION})) message(STATUS "Downloading CPM.cmake...") file(DOWNLOAD https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake ${CPM_DOWNLOAD_LOCATION}) endif() include(${CPM_DOWNLOAD_LOCATION}) CPMAddPackage( NAME 
peglib GITHUB_REPOSITORY "exaloop/cpp-peglib" GIT_TAG codon OPTIONS "BUILD_TESTS OFF") CPMAddPackage( NAME fmt GITHUB_REPOSITORY "fmtlib/fmt" GIT_TAG 11.1.0 OPTIONS "CMAKE_POSITION_INDEPENDENT_CODE ON" "FMT_INSTALL ON") CPMAddPackage( NAME toml GITHUB_REPOSITORY "marzer/tomlplusplus" GIT_TAG v3.2.0) CPMAddPackage( NAME semver GITHUB_REPOSITORY "Neargye/semver" GIT_TAG v0.3.0) CPMAddPackage( NAME zlibng GITHUB_REPOSITORY "zlib-ng/zlib-ng" VERSION 2.0.5 GIT_TAG 2.0.5 EXCLUDE_FROM_ALL YES OPTIONS "HAVE_OFF64_T ON" "ZLIB_COMPAT ON" "ZLIB_ENABLE_TESTS OFF" "CMAKE_POSITION_INDEPENDENT_CODE ON") if(zlibng_ADDED) set_target_properties(zlib PROPERTIES EXCLUDE_FROM_ALL ON) endif() CPMAddPackage( NAME xz GITHUB_REPOSITORY "xz-mirror/xz" VERSION 5.2.5 GIT_TAG e7da44d5151e21f153925781ad29334ae0786101 EXCLUDE_FROM_ALL YES OPTIONS "BUILD_SHARED_LIBS OFF" "CMAKE_POSITION_INDEPENDENT_CODE ON") if(xz_ADDED) set_target_properties(xz PROPERTIES EXCLUDE_FROM_ALL ON) set_target_properties(xzdec PROPERTIES EXCLUDE_FROM_ALL ON) endif() CPMAddPackage( NAME bz2 URL "https://www.sourceware.org/pub/bzip2/bzip2-1.0.8.tar.gz" DOWNLOAD_ONLY YES) if(bz2_ADDED) add_library(bz2 STATIC "${bz2_SOURCE_DIR}/blocksort.c" "${bz2_SOURCE_DIR}/huffman.c" "${bz2_SOURCE_DIR}/crctable.c" "${bz2_SOURCE_DIR}/randtable.c" "${bz2_SOURCE_DIR}/compress.c" "${bz2_SOURCE_DIR}/decompress.c" "${bz2_SOURCE_DIR}/bzlib.c" "${bz2_SOURCE_DIR}/libbz2.def") set_target_properties(bz2 PROPERTIES COMPILE_FLAGS "-D_FILE_OFFSET_BITS=64" POSITION_INDEPENDENT_CODE ON) endif() CPMAddPackage( NAME bdwgc GITHUB_REPOSITORY "exaloop/bdwgc" VERSION 8.0.5 GIT_TAG e16c67244aff26802203060422545d38305e0160 EXCLUDE_FROM_ALL YES OPTIONS "CMAKE_POSITION_INDEPENDENT_CODE ON" "BUILD_SHARED_LIBS OFF" "enable_threads ON" "enable_large_config ON" "enable_thread_local_alloc ON" "enable_handle_fork ON") if(bdwgc_ADDED) set_target_properties(cord PROPERTIES EXCLUDE_FROM_ALL ON) endif() CPMAddPackage( NAME backtrace GITHUB_REPOSITORY 
"ianlancetaylor/libbacktrace" GIT_TAG d0f5e95a87a4d3e0a1ed6c069b5dae7cbab3ed2a DOWNLOAD_ONLY YES) if(backtrace_ADDED) set(backtrace_SOURCES "${backtrace_SOURCE_DIR}/atomic.c" "${backtrace_SOURCE_DIR}/backtrace.c" "${backtrace_SOURCE_DIR}/dwarf.c" "${backtrace_SOURCE_DIR}/fileline.c" "${backtrace_SOURCE_DIR}/mmapio.c" "${backtrace_SOURCE_DIR}/mmap.c" "${backtrace_SOURCE_DIR}/posix.c" "${backtrace_SOURCE_DIR}/print.c" "${backtrace_SOURCE_DIR}/simple.c" "${backtrace_SOURCE_DIR}/sort.c" "${backtrace_SOURCE_DIR}/state.c") # https://go.googlesource.com/gollvm/+/refs/heads/master/cmake/modules/LibbacktraceUtils.cmake set(BACKTRACE_SUPPORTED 1) set(BACKTRACE_ELF_SIZE 64) set(HAVE_GETIPINFO 1) set(BACKTRACE_USES_MALLOC 1) set(BACKTRACE_SUPPORTS_THREADS 1) set(BACKTRACE_SUPPORTS_DATA 1) set(HAVE_SYNC_FUNCTIONS 1) if(APPLE) set(HAVE_MACH_O_DYLD_H 1) list(APPEND backtrace_SOURCES "${backtrace_SOURCE_DIR}/macho.c") else() set(HAVE_MACH_O_DYLD_H 0) list(APPEND backtrace_SOURCES "${backtrace_SOURCE_DIR}/elf.c") endif() # Generate backtrace-supported.h based on the above. 
configure_file( ${CMAKE_SOURCE_DIR}/cmake/backtrace-supported.h.in ${backtrace_SOURCE_DIR}/backtrace-supported.h) configure_file( ${CMAKE_SOURCE_DIR}/cmake/backtrace-config.h.in ${backtrace_SOURCE_DIR}/config.h) add_library(backtrace STATIC ${backtrace_SOURCES}) target_include_directories(backtrace BEFORE PRIVATE "${backtrace_SOURCE_DIR}") set_target_properties(backtrace PROPERTIES COMPILE_FLAGS "-funwind-tables -D_GNU_SOURCE" POSITION_INDEPENDENT_CODE ON) endif() CPMAddPackage( NAME re2 GITHUB_REPOSITORY "google/re2" VERSION 2022-06-01 GIT_TAG 5723bb8950318135ed9cf4fc76bed988a087f536 EXCLUDE_FROM_ALL YES OPTIONS "CMAKE_POSITION_INDEPENDENT_CODE ON" "BUILD_SHARED_LIBS OFF" "RE2_BUILD_TESTING OFF") CPMAddPackage( NAME fast_float GITHUB_REPOSITORY "fastfloat/fast_float" GIT_TAG v6.1.1 EXCLUDE_FROM_ALL YES) if(NOT APPLE) enable_language(Fortran) CPMAddPackage( NAME openblas GITHUB_REPOSITORY "OpenMathLib/OpenBLAS" GIT_TAG v0.3.29 EXCLUDE_FROM_ALL YES OPTIONS "DYNAMIC_ARCH ON" "BUILD_TESTING OFF" "BUILD_BENCHMARKS OFF" "NUM_THREADS 64" "CCOMMON_OPT -O3") endif() CPMAddPackage( NAME highway GITHUB_REPOSITORY "google/highway" GIT_TAG 1.3.0 EXCLUDE_FROM_ALL YES OPTIONS "HWY_ENABLE_CONTRIB ON" "HWY_ENABLE_EXAMPLES OFF" "HWY_ENABLE_INSTALL OFF" "HWY_ENABLE_TESTS OFF" "BUILD_TESTING OFF") ================================================ FILE: codon/app/main.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include #include #include #include #include #include #include #include #include #include #include "codon/cir/util/format.h" #include "codon/compiler/compiler.h" #include "codon/compiler/error.h" #include "codon/compiler/jit.h" #include "codon/parser/common.h" #include "codon/util/common.h" #include "codon/util/jupyter.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" namespace { void versMsg(llvm::raw_ostream &out) { out << CODON_VERSION_MAJOR << "." << CODON_VERSION_MINOR << "." 
<< CODON_VERSION_PATCH << "\n"; } bool isMacOS() { #ifdef __APPLE__ return true; #else return false; #endif } const std::vector &supportedExtensions() { static const std::vector extensions = {".codon", ".py", ".seq"}; return extensions; } bool hasExtension(const std::string &filename, const std::string &extension) { return filename.size() >= extension.size() && filename.compare(filename.size() - extension.size(), extension.size(), extension) == 0; } std::string trimExtension(const std::string &filename, const std::string &extension) { if (hasExtension(filename, extension)) { return filename.substr(0, filename.size() - extension.size()); } else { return filename; } } std::string makeOutputFilename(const std::string &filename, const std::string &extension) { for (const auto &ext : supportedExtensions()) { if (hasExtension(filename, ext)) return trimExtension(filename, ext) + extension; } return filename + extension; } void display(const codon::error::ParserErrorInfo &e) { using codon::MessageGroupPos; std::unordered_set seen; for (auto &group : e.getErrors()) { int i = 0; for (auto &msg : group) { auto t = msg.toString(); if (seen.find(t) != seen.end()) { continue; } seen.insert(t); MessageGroupPos pos = MessageGroupPos::NONE; if (i == 0) { pos = MessageGroupPos::HEAD; } else if (i == group.size() - 1) { pos = MessageGroupPos::LAST; } else { pos = MessageGroupPos::MID; } i++; codon::compilationError(msg.getMessage(), msg.getFile(), msg.getLine(), msg.getColumn(), msg.getLength(), msg.getErrorCode(), /*terminate=*/false, pos); } } } void initLogFlags(const llvm::cl::opt &log) { codon::getLogger().parse(log); if (auto *d = getenv("CODON_DEBUG")) codon::getLogger().parse(std::string(d)); } enum BuildKind { LLVM, Bitcode, Object, Assembly, Executable, Library, PyExtension, Detect, CIR }; enum OptMode { Debug, Release }; enum Numerics { C, Python }; } // namespace int docMode(const std::vector &args, const std::string &argv0) { llvm::cl::opt input(llvm::cl::Positional, 
llvm::cl::desc(""), llvm::cl::init("-")); llvm::cl::ParseCommandLineOptions(static_cast(args.size()), args.data()); std::vector files; auto collectPaths = [&files](const std::string &path) { llvm::sys::fs::file_status status; llvm::sys::fs::status(path, status); if (!llvm::sys::fs::exists(status)) { codon::compilationError(fmt::format("'{}' does not exist", path), "", 0, 0, 0, -1, false); } if (llvm::sys::fs::is_regular_file(status)) { files.emplace_back(path); } else if (llvm::sys::fs::is_directory(status)) { std::error_code ec; for (llvm::sys::fs::recursive_directory_iterator it(path, ec), e; it != e; it.increment(ec)) { auto status = it->status(); if (!status) continue; if (status->type() == llvm::sys::fs::file_type::regular_file) if (!codon::ast::endswith(it->path(), "__init_test__.codon")) files.emplace_back(it->path()); } } }; if (args.size() > 1) collectPaths(args[1]); auto compiler = std::make_unique(args[0]); bool failed = false; std::sort(files.begin(), files.end()); auto result = compiler->docgen(files); llvm::handleAllErrors(result.takeError(), [&failed](const codon::error::ParserErrorInfo &e) { display(e); failed = true; }); if (failed) return EXIT_FAILURE; fmt::print("{}\n", *result); return EXIT_SUCCESS; } std::unique_ptr processSource( const std::vector &args, bool standalone, std::function pyExtension = [] { return false; }) { llvm::cl::opt input(llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-")); auto regs = llvm::cl::getRegisteredOptions(); llvm::cl::opt optMode( llvm::cl::desc("optimization mode"), llvm::cl::values( clEnumValN(Debug, regs.find("debug") != regs.end() ? "default" : "debug", "Turn off compiler optimizations and show backtraces"), clEnumValN(Release, "release", "Turn on compiler optimizations and disable debug info")), llvm::cl::init(Debug)); llvm::cl::list defines( "D", llvm::cl::Prefix, llvm::cl::desc("Add static variable definitions. 
The syntax is =")); llvm::cl::list disabledOpts( "disable-opt", llvm::cl::desc("Disable the specified IR optimization")); llvm::cl::list plugins("plugin", llvm::cl::desc("Load specified plugin")); llvm::cl::opt log("log", llvm::cl::desc("Enable given log streams")); llvm::cl::opt numerics( "numerics", llvm::cl::desc("numerical semantics"), llvm::cl::values( clEnumValN(C, "c", "C semantics: best performance but deviates from Python"), clEnumValN(Python, "py", "Python semantics: mirrors Python but might disable optimizations " "like vectorization")), llvm::cl::init(C)); llvm::cl::ParseCommandLineOptions(args.size(), args.data()); initLogFlags(log); std::unordered_map defmap; for (const auto &define : defines) { auto eq = define.find('='); if (eq == std::string::npos || !eq) { codon::compilationWarning("ignoring malformed definition: " + define); continue; } auto name = define.substr(0, eq); auto value = define.substr(eq + 1); if (defmap.find(name) != defmap.end()) { codon::compilationWarning("ignoring duplicate definition: " + define); continue; } defmap.emplace(name, value); } const bool isDebug = (optMode == OptMode::Debug); std::vector disabledOptsVec(disabledOpts); auto compiler = std::make_unique( args[0], isDebug, disabledOptsVec, /*isTest=*/false, (numerics == Numerics::Python), pyExtension()); compiler->getLLVMVisitor()->setStandalone(standalone); // load plugins for (const auto &plugin : plugins) { bool failed = false; llvm::handleAllErrors( compiler->load(plugin), [&failed](const codon::error::PluginErrorInfo &e) { codon::compilationError(e.getMessage(), /*file=*/"", /*line=*/0, /*col=*/0, /*len=*/0, /*errorCode=*/-1, /*terminate=*/false); failed = true; }); if (failed) return {}; } bool failed = false; int testFlags = 0; if (auto *tf = getenv("CODON_TEST_FLAGS")) testFlags = std::atoi(tf); llvm::handleAllErrors(compiler->parseFile(input, /*testFlags=*/testFlags, defmap), [&failed](const codon::error::ParserErrorInfo &e) { display(e); failed = true; }); if 
(failed) return {}; { TIME("compile"); llvm::cantFail(compiler->compile()); } return compiler; } int runMode(const std::vector &args) { llvm::cl::list libs( "l", llvm::cl::desc("Load and link the specified library")); llvm::cl::list progArgs(llvm::cl::ConsumeAfter, llvm::cl::desc("...")); auto compiler = processSource(args, /*standalone=*/false); if (!compiler) return EXIT_FAILURE; std::vector libsVec(libs); std::vector argsVec(progArgs); argsVec.insert(argsVec.begin(), compiler->getInput()); compiler->getLLVMVisitor()->run(argsVec, libsVec); return EXIT_SUCCESS; } namespace { std::string jitExec(codon::jit::JIT *jit, const std::string &code) { auto result = jit->execute(code); if (auto err = result.takeError()) { std::string output; llvm::handleAllErrors( std::move(err), [](const codon::error::ParserErrorInfo &e) { display(e); }, [&output](const codon::error::RuntimeErrorInfo &e) { std::stringstream buf; buf << e.getOutput(); buf << "\n\033[1mBacktrace:\033[0m\n"; for (const auto &line : e.getBacktrace()) { buf << " " << line << "\n"; } output = buf.str(); }); return output; } return *result; } void jitLoop(codon::jit::JIT *jit, std::istream &fp) { std::string code; for (std::string line; std::getline(fp, line);) { if (line != "#%%") { code += line + "\n"; } else { fmt::print("{}[done]\n", jitExec(jit, code)); code = ""; fflush(stdout); } } if (!code.empty()) fmt::print("{}[done]\n", jitExec(jit, code)); } } // namespace int jitMode(const std::vector &args) { llvm::cl::opt input(llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-")); llvm::cl::list plugins("plugin", llvm::cl::desc("Load specified plugin")); llvm::cl::opt log("log", llvm::cl::desc("Enable given log streams")); llvm::cl::ParseCommandLineOptions(args.size(), args.data()); initLogFlags(log); codon::jit::JIT jit(args[0]); // load plugins for (const auto &plugin : plugins) { bool failed = false; llvm::handleAllErrors(jit.getCompiler()->load(plugin), [&failed](const codon::error::PluginErrorInfo 
&e) { codon::compilationError(e.getMessage(), /*file=*/"", /*line=*/0, /*col=*/0, /*len=*/0, /*errorCode=*/-1, /*terminate=*/false); failed = true; }); if (failed) return EXIT_FAILURE; } llvm::cantFail(jit.init()); fmt::print(">>> Codon JIT v{} <<<\n", CODON_VERSION); if (input == "-") { jitLoop(&jit, std::cin); } else { std::ifstream fileInput(input); jitLoop(&jit, fileInput); } return EXIT_SUCCESS; } int buildMode(const std::vector &args, const std::string &argv0) { llvm::cl::list libs( "l", llvm::cl::desc("Link the specified library (only for executables)")); llvm::cl::opt lflags("linker-flags", llvm::cl::desc("Pass given flags to linker")); llvm::cl::opt buildKind( llvm::cl::desc("output type"), llvm::cl::values( clEnumValN(LLVM, "llvm", "Generate LLVM IR"), clEnumValN(Bitcode, "bc", "Generate LLVM bitcode"), clEnumValN(Object, "obj", "Generate native object file"), clEnumValN(Assembly, "asm", "Generate assembly code"), clEnumValN(Executable, "exe", "Generate executable"), clEnumValN(Library, "lib", "Generate shared library"), clEnumValN(PyExtension, "pyext", "Generate Python extension module"), clEnumValN(CIR, "cir", "Generate Codon Intermediate Representation"), clEnumValN(Detect, "detect", "Detect output type based on output file extension")), llvm::cl::init(Detect)); llvm::cl::opt output( "o", llvm::cl::desc( "Write compiled output to specified file. 
Supported extensions: " "none (executable), .o (object file), .ll (LLVM IR), .bc (LLVM bitcode)")); llvm::cl::opt pyModule( "module", llvm::cl::desc("Python extension module name (only applicable when " "building Python extension module)")); auto compiler = processSource(args, /*standalone=*/true, [&] { return buildKind == BuildKind::PyExtension; }); if (!compiler) return EXIT_FAILURE; std::vector libsVec(libs); if (output.empty() && compiler->getInput() == "-") codon::compilationError("output file must be specified when reading from stdin"); std::string extension; switch (buildKind) { case BuildKind::LLVM: extension = ".ll"; break; case BuildKind::Bitcode: extension = ".bc"; break; case BuildKind::Object: case BuildKind::PyExtension: extension = ".o"; break; case BuildKind::Assembly: extension = ".s"; break; case BuildKind::Library: extension = isMacOS() ? ".dylib" : ".so"; break; case BuildKind::Executable: case BuildKind::Detect: extension = ""; break; case BuildKind::CIR: extension = ".cir"; break; default: seqassertn(0, "unknown build kind"); } const std::string filename = output.empty() ? makeOutputFilename(compiler->getInput(), extension) : output; switch (buildKind) { case BuildKind::LLVM: compiler->getLLVMVisitor()->writeToLLFile(filename); break; case BuildKind::Bitcode: compiler->getLLVMVisitor()->writeToBitcodeFile(filename); break; case BuildKind::Object: compiler->getLLVMVisitor()->writeToObjectFile(filename); break; case BuildKind::Assembly: compiler->getLLVMVisitor()->writeToObjectFile(filename, /*pic=*/false, /*assembly=*/true); break; case BuildKind::Executable: compiler->getLLVMVisitor()->writeToExecutable(filename, argv0, false, libsVec, lflags); break; case BuildKind::Library: compiler->getLLVMVisitor()->writeToExecutable(filename, argv0, true, libsVec, lflags); break; case BuildKind::PyExtension: compiler->getCache()->pyModule->name = pyModule.empty() ? 
llvm::sys::path::stem(compiler->getInput()).str() : pyModule; compiler->getLLVMVisitor()->writeToPythonExtension(*compiler->getCache()->pyModule, filename); break; case BuildKind::CIR: { std::ofstream out(filename); codon::ir::util::format(out, compiler->getModule()); break; } case BuildKind::Detect: compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags); break; default: seqassertn(0, "unknown build kind"); } return EXIT_SUCCESS; } int jupyterMode(const std::vector &args) { llvm::cl::list plugins("plugin", llvm::cl::desc("Load specified plugin")); llvm::cl::opt input(llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("connection.json")); llvm::cl::ParseCommandLineOptions(args.size(), args.data()); int code = codon::startJupyterKernel(args[0], plugins, input); return code; } void showCommandsAndExit() { codon::compilationError("Available commands: codon "); } int otherMode(const std::vector &args) { llvm::cl::opt input(llvm::cl::Positional, llvm::cl::desc("")); llvm::cl::extrahelp("\nMODES:\n\n" " run - run a program interactively\n" " build - build a program\n" " doc - generate program documentation\n"); llvm::cl::ParseCommandLineOptions(args.size(), args.data()); if (!input.empty()) showCommandsAndExit(); return EXIT_SUCCESS; } int main(int argc, const char **argv) { if (argc < 2) showCommandsAndExit(); llvm::cl::SetVersionPrinter(versMsg); std::vector args{argv[0]}; for (int i = 2; i < argc; i++) args.push_back(argv[i]); std::string mode(argv[1]); std::string argv0 = std::string(args[0]) + " " + mode; if (mode == "run") { args[0] = argv0.data(); return runMode(args); } if (mode == "build") { const char *oldArgv0 = args[0]; args[0] = argv0.data(); return buildMode(args, oldArgv0); } if (mode == "doc") { const char *oldArgv0 = args[0]; args[0] = argv0.data(); return docMode(args, oldArgv0); } if (mode == "jit") { args[0] = argv0.data(); return jitMode(args); } if (mode == "jupyter") { args[0] = argv0.data(); return jupyterMode(args); } return 
otherMode({argv, argv + argc}); } ================================================ FILE: codon/cir/analyze/analysis.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "analysis.h" #include "codon/cir/transform/manager.h" namespace codon { namespace ir { namespace analyze { Result *Analysis::doGetAnalysis(const std::string &key) { return manager ? manager->getAnalysisResult(key) : nullptr; } } // namespace analyze } // namespace ir } // namespace codon ================================================ FILE: codon/cir/analyze/analysis.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include "codon/cir/module.h" #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace analyze { /// Analysis result base struct. struct Result { virtual ~Result() noexcept = default; }; /// Base class for IR analyses. class Analysis { private: transform::PassManager *manager = nullptr; public: virtual ~Analysis() noexcept = default; /// @return a unique key for this pass virtual std::string getKey() const = 0; /// Execute the analysis. /// @param module the module virtual std::unique_ptr run(const Module *module) = 0; /// Sets the manager. /// @param mng the new manager void setManager(transform::PassManager *mng) { manager = mng; } /// Returns the result of a given analysis. /// @param key the analysis key template AnalysisType *getAnalysisResult(const std::string &key) { return static_cast(doGetAnalysis(key)); } private: analyze::Result *doGetAnalysis(const std::string &key); }; } // namespace analyze } // namespace ir } // namespace codon ================================================ FILE: codon/cir/analyze/dataflow/capture.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "capture.h" #include #include #include #include "codon/cir/analyze/dataflow/reaching.h" #include "codon/cir/util/irtools.h" #include "codon/cir/util/side_effect.h" namespace codon { namespace ir { namespace analyze { namespace dataflow { namespace { template bool contains(const S &x, T i) { for (auto a : x) { if (a == i) return true; } return false; } template bool containsId(const S &x, T i) { for (auto a : x) { if (a->getId() == i->getId()) return true; } return false; } template bool shouldTrack(const T *x) { // We only care about things with pointers, // since you can't capture primitive types // like int, float, etc. return x && !x->getType()->isAtomic(); } template <> bool shouldTrack(const types::Type *x) { return x && !x->isAtomic(); } struct CaptureContext; bool extractVars(CaptureContext &cc, const Value *v, std::vector &result); bool reachable(CFBlock *start, CFBlock *end, std::unordered_set &seen) { if (start == end) return true; if (seen.count(start)) return false; seen.insert(start); for (auto it = start->successors_begin(); it != start->successors_end(); ++it) { if (reachable(*it, end, seen)) return true; } return false; } // Check if one value must always be encountered before another, if // it is to be encountered at all. This is NOT the same as domination // since we can have "(if _: B) ; A", where B does not dominate A yet // must always occur before A is it does occur. bool happensBefore(const Value *before, const Value *after, CFGraph *cfg, DominatorInspector *dom) { auto *beforeBlock = cfg->getBlock(before); auto *afterBlock = cfg->getBlock(after); if (!beforeBlock || !afterBlock) return false; // If values are in the same block we just need to see // which one shows up first. 
if (beforeBlock == afterBlock) { for (auto *val : *beforeBlock) { if (val->getId() == before->getId()) return true; else if (val->getId() == after->getId()) return false; } seqassertn(false, "could not find values in CFG block"); return false; } // If we have different blocks, then either 'before' dominates // 'after', in which case the answer is true, or there must be // no paths from 'afterBlock' to 'beforeBlock'. std::unordered_set seen; return dom->isDominated(after, before) || !reachable(afterBlock, beforeBlock, seen); } struct RDManager { struct IDPairHash { template std::size_t operator()(const std::pair &pair) const { return (std::hash()(pair.first) << 32) ^ std::hash()(pair.second); } }; RDInspector *rd; std::unordered_map, std::unordered_set, IDPairHash> cache; explicit RDManager(RDInspector *rd) : rd(rd), cache() {} std::unordered_set getReachingDefinitions(const Var *var, const Value *loc) { auto key = std::make_pair(var->getId(), loc->getId()); auto it = cache.find(key); if (it == cache.end()) { auto defs = rd->getReachingDefinitions(var, loc); std::unordered_set dset; for (auto &def : defs) { dset.insert(def.assignment->getId()); } cache.emplace(key, dset); return dset; } else { return it->second; } } bool isInvalid(const Var *v) { return rd->isInvalid(v); } }; struct DerivedSet { const Func *func; const Var *root; std::vector args; std::unordered_set derivedVals; std::unordered_map> derivedVars; CaptureInfo result; void setReturnCaptured() { if (shouldTrack(util::getReturnType(func))) result.returnCaptures = true; } void setExternCaptured() { setReturnCaptured(); result.externCaptures = true; } bool isDerived(const Var *v, const Value *loc, RDManager &rd) const { auto it = derivedVars.find(v->getId()); if (it == derivedVars.end()) return false; // We assume global references are always derived // if the var is derived, since they can change // at any point as far as we know. Same goes for // vars untracked by the reaching-def analysis. 
if (v->isGlobal() || rd.isInvalid(v)) return true; // Make sure the var at this point is reached by // at least one definition that has led to a // derived value. auto mySet = rd.getReachingDefinitions(v, loc); for (auto *cause : it->second) { auto otherSet = rd.getReachingDefinitions(v, cause); for (auto &elem : mySet) { if (otherSet.count(elem)) return true; } } return false; } bool isDerived(const Value *v) const { return derivedVals.find(v->getId()) != derivedVals.end(); } void setDerived(const Var *v, const Value *cause, bool shouldArgCapture = true) { if (!shouldTrack(v)) return; if (v->isGlobal()) setExternCaptured(); auto id = v->getId(); if (shouldArgCapture && root && id != root->getId()) { for (unsigned i = 0; i < args.size(); i++) { if (args[i] == id && !contains(result.argCaptures, i)) result.argCaptures.push_back(i); } } auto it = derivedVars.find(id); if (it == derivedVars.end()) { std::vector info = {cause}; derivedVars.emplace(id, info); } else { if (!containsId(it->second, cause)) it->second.push_back(cause); } } void setDerived(const Value *v) { if (!shouldTrack(v)) return; derivedVals.insert(v->getId()); } unsigned size() const { unsigned total = derivedVals.size(); for (auto &e : derivedVars) { total += e.second.size(); } return total; } explicit DerivedSet(const Func *func, const Var *root = nullptr) : func(func), root(root), args(), derivedVals(), derivedVars(), result() {} // Set for function argument DerivedSet(const Func *func, const Var *root, const Value *cause) : DerivedSet(func, root) { // extract arguments for (auto it = func->arg_begin(); it != func->arg_end(); ++it) { args.push_back((*it)->getId()); } setDerived(root, cause); } // Set for function argument DerivedSet(const Func *func, const Value *value, CaptureContext &cc) : DerivedSet(func) { std::vector vars; bool escapes = extractVars(cc, value, vars); if (escapes) setExternCaptured(); setDerived(value); for (auto *var : vars) { setDerived(var, value); } } }; bool 
noCaptureByAnnotation(const Func *func) { return util::hasAttribute(func, util::PURE_ATTR) || util::hasAttribute(func, util::NO_SIDE_EFFECT_ATTR) || util::hasAttribute(func, util::NO_CAPTURE_ATTR); } std::vector makeAllCaptureInfo(const Func *func) { std::vector result; for (auto it = func->arg_begin(); it != func->arg_end(); ++it) { result.push_back(CaptureInfo::unknown(func, (*it)->getType())); } return result; } std::vector makeNoCaptureInfo(const Func *func, bool derives) { std::vector result; for (auto it = func->arg_begin(); it != func->arg_end(); ++it) { auto info = CaptureInfo::nothing(); if (derives && shouldTrack(*it)) info.returnCaptures = true; result.push_back(info); } return result; } struct CaptureContext { RDResult *reaching; DominatorResult *dominating; std::unordered_map> results; CaptureContext(RDResult *reaching, DominatorResult *dominating) : reaching(reaching), dominating(dominating), results() {} std::vector get(const Func *func); void set(const Func *func, const std::vector &result); CFGraph *getCFGraph(const Func *func) { auto it = reaching->cfgResult->graphs.find(func->getId()); seqassertn(it != reaching->cfgResult->graphs.end(), "could not find function in CFG results"); return it->second.get(); } RDInspector *getRDInspector(const Func *func) { auto it = reaching->results.find(func->getId()); seqassertn(it != reaching->results.end(), "could not find function in reaching-definitions results"); return it->second.get(); } DominatorInspector *getDomInspector(const Func *func) { auto it = dominating->results.find(func->getId()); seqassertn(it != dominating->results.end(), "could not find function in dominator results"); return it->second.get(); } }; // This visitor answers the questions of what vars are // relevant to track in a capturing expression. For // example, in "a[i] = x", the expression "a[i]" captures // "x"; in this case we need to track "a" but the variable // "i" (typically) we would not care about. 
struct ExtractVars : public util::ConstVisitor { CaptureContext &cc; std::unordered_set vars; bool escapes; explicit ExtractVars(CaptureContext &cc) : util::ConstVisitor(), cc(cc), vars(), escapes(false) {} template void process(const Node *v) { v->accept(*this); } void add(const Var *v) { if (shouldTrack(v)) vars.insert(v->getId()); } void defaultVisit(const Node *) override {} void visit(const VarValue *v) override { add(v->getVar()); } void visit(const PointerValue *v) override { add(v->getVar()); } void visit(const CallInstr *v) override { if (auto *func = util::getFunc(v->getCallee())) { auto capInfo = cc.get(util::getFunc(v->getCallee())); unsigned i = 0; for (auto *arg : *v) { // note possibly capInfo.size() != v->numArgs() if calling vararg C function auto info = (i < capInfo.size()) ? capInfo[i] : CaptureInfo::unknown(func, arg->getType()); if (shouldTrack(arg) && capInfo[i].returnCaptures) process(arg); ++i; } } else { for (auto *arg : *v) { if (shouldTrack(arg)) process(arg); } } } void visit(const YieldInInstr *v) override { // We have no idea what the yield-in // value could be, so just assume we // escape in this case. 
escapes = true; } void visit(const TernaryInstr *v) override { process(v->getTrueValue()); process(v->getFalseValue()); } void visit(const ExtractInstr *v) override { process(v->getVal()); } void visit(const FlowInstr *v) override { process(v->getValue()); } void visit(const dsl::CustomInstr *v) override { // TODO } }; bool extractVars(CaptureContext &cc, const Value *v, std::vector &result) { auto *M = v->getModule(); ExtractVars ev(cc); v->accept(ev); for (auto id : ev.vars) { result.push_back(M->getVar(id)); } return ev.escapes; } struct CaptureTracker : public util::Operator { CaptureContext &cc; CFGraph *cfg; RDManager rd; DominatorInspector *dom; std::vector dsets; CaptureTracker(CaptureContext &cc, const Func *func, bool isArg) : Operator(), cc(cc), cfg(cc.getCFGraph(func)), rd(cc.getRDInspector(func)), dom(cc.getDomInspector(func)), dsets() {} CaptureTracker(CaptureContext &cc, const BodiedFunc *func) : CaptureTracker(cc, func, /*isArg=*/true) { // find synthetic assignments in CFG for argument vars auto *entry = cfg->getEntryBlock(); std::unordered_map synthAssigns; for (auto *v : *entry) { if (auto *synth = cast(v)) { if (shouldTrack(synth->getLhs())) synthAssigns[synth->getLhs()->getId()] = synth; } } // extract arguments std::vector args; for (auto it = func->arg_begin(); it != func->arg_end(); ++it) { args.push_back((*it)->getId()); } // make a derived set for each function argument for (auto it = func->arg_begin(); it != func->arg_end(); ++it) { if (!shouldTrack(*it)) continue; auto it2 = synthAssigns.find((*it)->getId()); seqassertn(it2 != synthAssigns.end(), "could not find synthetic assignment for arg var"); dsets.push_back(DerivedSet(func, *it, it2->second)); } } CaptureTracker(CaptureContext &cc, const BodiedFunc *func, const Value *value) : CaptureTracker(cc, func, /*isArg=*/false) { dsets.push_back(DerivedSet(func, value, cc)); } unsigned size() const { unsigned total = 0; for (auto &dset : dsets) { total += dset.size(); } return total; } void 
forEachDSetOf(Value *v, std::function func) {
    // Calls `func` on every derived set in which value `v` is derived;
    // a null `v` is ignored.
    if (!v)
      return;

    for (auto &dset : dsets) {
      if (dset.isDerived(v))
        func(dset);
    }
  }

  // Calls `func` on every derived set in which var `v` is derived at
  // location `loc`; a null `v` is ignored.
  void forEachDSetOf(Var *v, Value *loc, std::function func) {
    if (!v)
      return;

    for (auto &dset : dsets) {
      if (dset.isDerived(v, loc, rd))
        func(dset);
    }
  }

  // For every derived set containing `from`: marks each var in `toVars` as
  // derived with the given cause, and marks the set externally captured when
  // `toEscapes` is set.
  void forwardLink(Value *from, Value *cause, const std::vector &toVars,
                   bool toEscapes, bool shouldArgCapture) {
    forEachDSetOf(from, [&](DerivedSet &dset) {
      if (toEscapes)
        dset.setExternCaptured();

      for (auto *toVar : toVars) {
        dset.setDerived(toVar, cause, shouldArgCapture);
      }
    });
  }

  // Propagates derivation in the opposite direction within one derived set:
  // for each var in `toVars` that is already derived, every recorded cause
  // that passes the filter below and shares a reaching definition with
  // `cause` makes all `fromVars` derived as well.
  void backwardLinkFunc(DerivedSet &dset, Value *cause,
                        const std::vector &toVars,
                        const std::vector &fromVars, bool fromEscapes) {
    if (fromEscapes)
      dset.setExternCaptured();

    for (auto *toVar : toVars) {
      auto it = dset.derivedVars.find(toVar->getId());
      if (it == dset.derivedVars.end())
        continue;
      auto &toCauses = it->second;

      for (auto *toCause : toCauses) {
        // Skip causes excluded by the isA checks and causes that happen
        // before `cause` according to the CFG/dominator information.
        if (isA(toCause) || isA(toCause) ||
            happensBefore(toCause, cause, cfg, dom))
          continue;

        // Globals and vars the RD manager reports as invalid are treated as
        // derived; otherwise require a common reaching definition.
        bool derived = false;
        if (toVar->isGlobal() || rd.isInvalid(toVar)) {
          derived = true;
        } else {
          auto mySet = rd.getReachingDefinitions(toVar, cause);
          auto otherSet = rd.getReachingDefinitions(toVar, toCause);
          for (auto &elem : mySet) {
            if (otherSet.count(elem)) {
              derived = true;
              break;
            }
          }
        }

        if (derived) {
          for (auto *fromVar : fromVars) {
            dset.setDerived(fromVar, toCause);
          }
        }
      }
    }
  }

  // Links value `from` into value `to` (both directions), with `cause` as the
  // instruction responsible.
  void link(Value *from, Value *to, Value *cause) {
    std::vector fromVars, toVars;
    bool fromEscapes = extractVars(cc, from, fromVars);
    bool toEscapes = extractVars(cc, to, toVars);

    forwardLink(from, cause, toVars, toEscapes, /*shouldArgCapture=*/true);
    forEachDSetOf(to, [&](DerivedSet &dset) {
      backwardLinkFunc(dset, cause, toVars, fromVars, fromEscapes);
    });
  }

  // Links value `from` into var `to`; the target is the single var itself and
  // argument capture is not recorded for this form.
  void link(Value *from, Var *to, Value *cause) {
    std::vector fromVars, toVars = {to};
    bool fromEscapes = extractVars(cc, from, fromVars);
    bool toEscapes = false;

    forwardLink(from, cause, toVars, toEscapes, /*shouldArgCapture=*/false);
forEachDSetOf(to, cause, [&](DerivedSet &dset) {
      backwardLinkFunc(dset, cause, toVars, fromVars, fromEscapes);
    });
  }

  // A use of a derived var is itself a derived value.
  void handle(VarValue *v) override {
    forEachDSetOf(v->getVar(), v, [&](DerivedSet &dset) { dset.setDerived(v); });
  }

  // A pointer to a derived var is itself a derived value.
  void handle(PointerValue *v) override {
    forEachDSetOf(v->getVar(), v, [&](DerivedSet &dset) { dset.setDerived(v); });
  }

  // Assignment links the right-hand value into the left-hand var.
  void handle(AssignInstr *v) override { link(v->getRhs(), v->getLhs(), v); }

  // Extracting from a derived value yields a derived value (tracked results
  // only).
  void handle(ExtractInstr *v) override {
    if (!shouldTrack(v))
      return;

    forEachDSetOf(v->getVal(), [&](DerivedSet &dset) { dset.setDerived(v); });
  }

  // Insertion links the inserted value into the target and flags the target
  // as modified.
  void handle(InsertInstr *v) override {
    link(v->getRhs(), v->getLhs(), v);
    forEachDSetOf(v->getLhs(), [&](DerivedSet &dset) { dset.result.modified = true; });
  }

  // Applies the callee's per-argument capture information to the call's
  // arguments.
  void handle(CallInstr *v) override {
    std::vector args(v->begin(), v->end());
    std::vector capInfo;
    auto *func = util::getFunc(v->getCallee());

    if (func) {
      capInfo = cc.get(func);
    } else {
      // Callee is not a known function: assume every tracked argument is
      // captured by every tracked argument, by the return value (if
      // tracked), and externally, and that it is modified.
      std::vector argCaptures;
      unsigned i = 0;
      for (auto *arg : args) {
        if (shouldTrack(arg))
          argCaptures.push_back(i);
        ++i;
      }

      const bool returnCaptures = shouldTrack(v);
      for (auto *arg : args) {
        CaptureInfo info = CaptureInfo::nothing();
        if (shouldTrack(arg)) {
          info.argCaptures = argCaptures;
          info.returnCaptures = returnCaptures;
          info.externCaptures = true;
          info.modified = true;
        }
        capInfo.push_back(info);
      }
    }

    unsigned i = 0;
    for (auto *arg : args) {
      // note possibly capInfo.size() != v->numArgs() if calling vararg C function
      auto info = (i < capInfo.size()) ? capInfo[i]
                                       : CaptureInfo::unknown(func, arg->getType());
      // Link this argument into every other argument that captures it.
      for (auto argno : info.argCaptures) {
        Value *other = args[argno];
        link(arg, other, v);
      }

      forEachDSetOf(arg, [&](DerivedSet &dset) {
        // Check if the return value captures.
        if (info.returnCaptures)
          dset.setDerived(v);

        // Check if we're externally captured.
if (info.externCaptures)
          dset.setExternCaptured();

        if (info.modified)
          dset.result.modified = true;
      });
      ++i;
    }
  }

  // If the iterated value is derived, the loop var becomes derived with the
  // CFG's synthetic NEXT_VALUE assignment to that var as the cause.
  void handle(ForFlow *v) override {
    auto *var = v->getVar();
    if (!shouldTrack(var))
      return;

    forEachDSetOf(v->getIter(), [&](DerivedSet &dset) {
      bool found = false;
      for (auto it = cfg->synth_begin(); it != cfg->synth_end(); ++it) {
        if (auto *synth = cast(*it)) {
          if (synth->getKind() == SyntheticAssignInstr::Kind::NEXT_VALUE &&
              synth->getLhs()->getId() == var->getId()) {
            seqassertn(!found, "found multiple synthetic assignments for loop var");
            dset.setDerived(var, synth);
            found = true;
          }
        }
      }
    });
  }

  // The ternary's result is derived if either branch value is.
  void handle(TernaryInstr *v) override {
    forEachDSetOf(v->getTrueValue(), [&](DerivedSet &dset) { dset.setDerived(v); });
    forEachDSetOf(v->getFalseValue(), [&](DerivedSet &dset) { dset.setDerived(v); });
  }

  // The flow instruction's result is derived if its value is.
  void handle(FlowInstr *v) override {
    forEachDSetOf(v->getValue(), [&](DerivedSet &dset) { dset.setDerived(v); });
  }

  void handle(dsl::CustomInstr *v) override {
    // TODO
  }

  // Actual capture points:

  void handle(ReturnInstr *v) override {
    forEachDSetOf(v->getValue(),
                  [&](DerivedSet &dset) { dset.result.returnCaptures = true; });
  }

  void handle(YieldInstr *v) override {
    forEachDSetOf(v->getValue(),
                  [&](DerivedSet &dset) { dset.result.returnCaptures = true; });
  }

  void handle(AwaitInstr *v) override {
    forEachDSetOf(v->getValue(),
                  [&](DerivedSet &dset) { dset.result.returnCaptures = true; });
  }

  // A thrown derived value counts as externally captured.
  void handle(ThrowInstr *v) override {
    forEachDSetOf(v->getValue(), [&](DerivedSet &dset) { dset.setExternCaptured(); });
  }

  // Helper to run to completion
  // Re-visits the function until a pass no longer changes size().
  void runToCompletion(const Func *func) {
    unsigned oldSize = 0;
    do {
      oldSize = size();
      const_cast(func)->accept(*this);
      reset();
    } while (size() != oldSize);
  }
};

std::vector CaptureContext::get(const Func *func) {
  // Don't know anything about external/LLVM funcs so use annotations.
if (isA(func) || isA(func)) {
    bool derives = util::hasAttribute(func, util::DERIVES_ATTR);

    // Self-capturing functions: the first argument is flagged as modified
    // and every other tracked argument is recorded as captured by it.
    if (util::hasAttribute(func, util::SELF_CAPTURES_ATTR)) {
      auto ans = makeNoCaptureInfo(func, derives);
      if (!ans.empty())
        ans[0].modified = true;

      std::vector argVars(func->arg_begin(), func->arg_end());
      for (unsigned i = 1; i < ans.size(); i++) {
        if (shouldTrack(argVars[i]))
          ans[i].argCaptures.push_back(0);
      }
      return ans;
    }

    // Otherwise: no captures if annotated as such, else assume the worst.
    return noCaptureByAnnotation(func) ? makeNoCaptureInfo(func, derives)
                                       : makeAllCaptureInfo(func);
  }

  // Only Tuple.__new__(...) and Generator.__promise__(self) capture.
  if (isA(func)) {
    bool isTupleNew = func->getUnmangledName() == "__new__" &&
                      isA(util::getReturnType(func));

    bool isPromise = func->getUnmangledName() == "__promise__" &&
                     std::distance(func->arg_begin(), func->arg_end()) == 1 &&
                     isA(func->arg_front()->getType());

    bool derives = (isTupleNew || isPromise);
    return makeNoCaptureInfo(func, derives);
  }

  // Bodied function
  if (isA(func)) {
    auto it = results.find(func->getId());
    if (it != results.end())
      return it->second;

    // Store a worst-case placeholder before analyzing the body, so that a
    // lookup of this same function made while the analysis is running finds
    // an entry in `results`.
    set(func, makeAllCaptureInfo(func));

    CaptureTracker ct(*this, cast(func));
    ct.runToCompletion(func);

    // dsets holds one entry per tracked argument, in argument order;
    // untracked arguments get an empty CaptureInfo.
    std::vector answer;
    unsigned i = 0;
    for (auto it = func->arg_begin(); it != func->arg_end(); ++it) {
      if (shouldTrack(*it)) {
        answer.push_back(ct.dsets[i++].result);
      } else {
        answer.push_back(CaptureInfo::nothing());
      }
    }

    set(func, answer);
    return answer;
  }

  seqassertn(false, "unknown function type");
  return {};
}

// Stores (or overwrites) the capture information recorded for `func`.
void CaptureContext::set(const Func *func, const std::vector &result) {
  results[func->getId()] = result;
}

} // namespace

// Worst-case capture information for an argument of the given type:
// untracked types capture nothing; otherwise every tracked argument of
// `func` may capture it, the return value may capture it if the return type
// is tracked, and it is assumed externally captured and modified.
CaptureInfo CaptureInfo::unknown(const Func *func, types::Type *type) {
  if (!shouldTrack(type))
    return CaptureInfo::nothing();

  CaptureInfo c;
  unsigned i = 0;
  for (auto it = func->arg_begin(); it != func->arg_end(); ++it) {
    if (shouldTrack(*it))
      c.argCaptures.push_back(i);
    ++i;
  }
  c.returnCaptures = shouldTrack(util::getReturnType(func));
  c.externCaptures = true;
  c.modified = true;
  return c;
}

const std::string
CaptureAnalysis::KEY = "core-analyses-capture";

// Computes capture information for the module's main function (when the cast
// succeeds) and for every var in the module for which the cast succeeds,
// keyed by function id.
std::unique_ptr CaptureAnalysis::run(const Module *m) {
  auto res = std::make_unique();
  auto *rdResult = getAnalysisResult(rdAnalysisKey);
  auto *domResult = getAnalysisResult(domAnalysisKey);
  res->rdResult = rdResult;
  res->domResult = domResult;
  CaptureContext cc(rdResult, domResult);

  if (const auto *main = cast(m->getMainFunc())) {
    auto ans = cc.get(main);
    res->results.emplace(main->getId(), ans);
  }

  for (const auto *var : *m) {
    if (const auto *f = cast(var)) {
      auto ans = cc.get(f);
      res->results.emplace(f->getId(), ans);
    }
  }

  return res;
}

// Determines how `value` inside `func` is captured, reusing the per-function
// results already stored in `cr`. Untracked values capture nothing.
CaptureInfo escapes(const BodiedFunc *func, const Value *value, CaptureResult *cr) {
  if (!shouldTrack(value))
    return CaptureInfo::nothing();

  CaptureContext cc(cr->rdResult, cr->domResult);
  cc.results = cr->results;
  CaptureTracker ct(cc, cast(func), value);
  ct.runToCompletion(func);
  seqassertn(ct.dsets.size() == 1, "unexpected dsets size");
  return ct.dsets[0].result;
}

} // namespace dataflow
} // namespace analyze
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/analyze/dataflow/capture.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include
#include
#include

#include "codon/cir/analyze/analysis.h"
#include "codon/cir/analyze/dataflow/dominator.h"
#include "codon/cir/analyze/dataflow/reaching.h"
#include "codon/cir/cir.h"

namespace codon {
namespace ir {
namespace analyze {
namespace dataflow {

/// Information about how a function argument is captured.
struct CaptureInfo {
  /// vector of other argument indices capturing this one
  std::vector argCaptures;
  /// true if the return value of the function captures this argument
  bool returnCaptures = false;
  /// true if this argument is externally captured e.g.
  /// by assignment to global
  bool externCaptures = false;
  /// true if this argument is modified
  bool modified = false;

  /// Note that `modified` alone does not count as a capture.
  /// @return true if anything captures
  operator bool() const {
    return !argCaptures.empty() || returnCaptures || externCaptures;
  }

  /// Returns an instance denoting no captures.
  /// @return an instance denoting no captures
  static CaptureInfo nothing() { return {}; }

  /// Returns an instance denoting unknown capture status.
  /// @param func the function containing this argument
  /// @param type the argument's type
  /// @return an instance denoting unknown capture status
  static CaptureInfo unknown(const Func *func, types::Type *type);
};

/// Capture analysis result.
struct CaptureResult : public Result {
  /// the corresponding reaching definitions result
  RDResult *rdResult = nullptr;

  /// the corresponding dominator result
  DominatorResult *domResult = nullptr;

  /// map from function id to capture information, where
  /// each element of the value vector corresponds to an
  /// argument of the function
  std::unordered_map> results;
};

/// Capture analysis that runs on all functions.
class CaptureAnalysis : public Analysis {
private:
  /// the reaching definitions analysis key
  std::string rdAnalysisKey;
  /// the dominator analysis key
  std::string domAnalysisKey;

public:
  static const std::string KEY;
  /// @return the key identifying this analysis
  std::string getKey() const override { return KEY; }

  /// Initializes a capture analysis.
/// @param rdAnalysisKey the reaching definitions analysis key /// @param domAnalysisKey the dominator analysis key explicit CaptureAnalysis(std::string rdAnalysisKey, std::string domAnalysisKey) : rdAnalysisKey(std::move(rdAnalysisKey)), domAnalysisKey(std::move(domAnalysisKey)) {} std::unique_ptr run(const Module *m) override; }; CaptureInfo escapes(const BodiedFunc *func, const Value *value, CaptureResult *cr); } // namespace dataflow } // namespace analyze } // namespace ir } // namespace codon ================================================ FILE: codon/cir/analyze/dataflow/cfg.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "cfg.h" #include #include "codon/cir/dsl/codegen.h" #include "codon/cir/dsl/nodes.h" namespace codon { namespace ir { namespace analyze { namespace dataflow { namespace { // TODO: this logic is very similar to lowering/pipeline -- unify somehow? Value *callStage(analyze::dataflow::CFGraph *cfg, PipelineFlow::Stage *stage, Value *last) { std::vector args; for (auto *arg : *stage) { args.push_back(arg ? 
arg : last);
  }
  return cfg->N(stage->getCallee(), args);
}

// Recursively rewrites stages[idx..] into nested loops/calls. `last` is the
// value produced by the previous stage. Returns `last` once all stages are
// consumed, which is nullptr for an empty pipeline.
Value *convertPipelineToForLoopsHelper(analyze::dataflow::CFGraph *cfg,
                                       std::vector &stages,
                                       unsigned idx = 0, Value *last = nullptr) {
  if (idx >= stages.size())
    return last;

  auto *stage = stages[idx];
  // The first stage contributes its callee as the initial value.
  if (idx == 0)
    return convertPipelineToForLoopsHelper(cfg, stages, idx + 1, stage->getCallee());

  auto *prev = stages[idx - 1];
  if (prev->isGenerator()) {
    // Previous stage is a generator: loop over it with a fresh var and feed
    // that var to the current stage inside the loop body.
    auto *var = cfg->N(prev->getOutputElementType());
    auto *body = convertPipelineToForLoopsHelper(
        cfg, stages, idx + 1, callStage(cfg, stage, cfg->N(var)));
    auto *series = cfg->N();
    series->push_back(body);
    return cfg->N(last, series, var);
  } else {
    return convertPipelineToForLoopsHelper(cfg, stages, idx + 1,
                                           callStage(cfg, stage, last));
  }
}
} // namespace

// Rewrites a pipeline into nested for-loops/calls so it can be processed
// like ordinary control flow.
const Value *convertPipelineToForLoops(analyze::dataflow::CFGraph *cfg,
                                       const PipelineFlow *p) {
  std::vector stages;
  for (const auto &stage : *p) {
    stages.push_back(const_cast(&stage));
  }
  return convertPipelineToForLoopsHelper(cfg, stages);
}

// Records this block as the location of value `v` in the owning graph.
void CFBlock::reg(const Value *v) { graph->valueLocations[v->getId()] = this; }

const char SyntheticAssignInstr::NodeId = 0;

// Replaces the argument if it is set and has the given id; returns the
// number of replacements made.
int SyntheticAssignInstr::doReplaceUsedValue(id_t id, Value *newValue) {
  if (arg && arg->getId() == id) {
    arg = newValue;
    return 1;
  }
  return 0;
}

// Replaces the assigned var if it has the given id; returns the number of
// replacements made.
int SyntheticAssignInstr::doReplaceUsedVariable(id_t id, Var *newVar) {
  if (lhs->getId() == id) {
    lhs = newVar;
    return 1;
  }
  return 0;
}

const char SyntheticPhiInstr::NodeId = 0;

// The used values are the results of all predecessors.
std::vector SyntheticPhiInstr::doGetUsedValues() const {
  std::vector ret;
  for (auto &p : *this) {
    ret.push_back(const_cast(p.getResult()));
  }
  return ret;
}

// Replaces every predecessor result with the given id; returns the number of
// replacements made.
int SyntheticPhiInstr::doReplaceUsedValue(id_t id, Value *newValue) {
  auto res = 0;
  for (auto &p : *this) {
    if (p.getResult()->getId() == id) {
      p.setResult(newValue);
      ++res;
    }
  }
  return res;
}

// Every graph starts with a block named "entry".
CFGraph::CFGraph(const BodiedFunc *f) : func(f) { newBlock("entry", true); }

// Writes the graph as a "digraph" description: one node per block, one edge
// per successor relation.
std::ostream &operator<<(std::ostream &os, const CFGraph &cfg) {
  os << "digraph \"" << cfg.func->getName() <<
"\" {\n";
  // Node declarations; names are made unique by appending the block address.
  for (auto *block : cfg) {
    os << " ";
    os << block->getName() << "_" << reinterpret_cast(block);
    os << " [ label=\"" << block->getName() << "\"";
    if (block == cfg.getEntryBlock()) {
      os << " shape=square";
    }
    os << " ];\n";
  }

  // Edges.
  for (auto *block : cfg) {
    for (auto next = block->successors_begin(); next != block->successors_end();
         ++next) {
      CFBlock *succ = *next;
      os << " ";
      os << block->getName() << "_" << reinterpret_cast(block);
      os << " -> ";
      os << succ->getName() << "_" << reinterpret_cast(succ);
      os << ";\n";
    }
  }
  os << "}";
  return os;
}

// Builds the control-flow graph of a single function.
std::unique_ptr buildCFGraph(const BodiedFunc *f) {
  auto ret = std::make_unique(f);
  CFVisitor v(ret.get());
  v.process(f);
  return ret;
}

const std::string CFAnalysis::KEY = "core-analyses-cfg";

// Builds a graph for the module's main function (when the cast succeeds) and
// for every var in the module for which the cast succeeds, keyed by id.
std::unique_ptr CFAnalysis::run(const Module *m) {
  auto res = std::make_unique();
  if (const auto *main = cast(m->getMainFunc())) {
    res->graphs.insert(std::make_pair(main->getId(), buildCFGraph(main)));
  }

  for (const auto *var : *m) {
    if (const auto *f = cast(var)) {
      res->graphs.insert(std::make_pair(f->getId(), buildCFGraph(f)));
    }
  }
  return res;
}

// Emits one synthetic assignment per argument into the current block, then
// processes the function body.
void CFVisitor::visit(const BodiedFunc *f) {
  auto *blk = graph->getCurrentBlock();
  for (auto it = f->arg_begin(); it != f->arg_end(); it++) {
    blk->push_back(graph->N(
        f, const_cast(*it), const_cast(graph->N(*it))));
  }
  process(f->getBody());
}

void CFVisitor::visit(const SeriesFlow *v) {
  for (auto *c : *v) {
    process(c);
  }
}

// The condition runs in the current block; each branch gets its own block
// that flows into "endIf". Without a false branch the original block flows
// directly to "endIf".
void CFVisitor::visit(const IfFlow *v) {
  process(v->getCond());
  auto *original = graph->getCurrentBlock();
  auto *end = graph->newBlock("endIf");

  auto *tBranch = graph->newBlock("trueBranch", true);
  process(v->getTrueBranch());
  graph->getCurrentBlock()->successors_insert(end);

  analyze::dataflow::CFBlock *fBranch = nullptr;
  if (v->getFalseBranch()) {
    fBranch = graph->newBlock("falseBranch", true);
    process(v->getFalseBranch());
    graph->getCurrentBlock()->successors_insert(end);
  }

  original->successors_insert(tBranch);
  if (fBranch)
    original->successors_insert(fBranch);
  else
original->successors_insert(end);

  graph->setCurrentBlock(end);
}

// whileBegin evaluates the condition and flows to endWhile or whileBody; the
// body flows back to whileBegin.
void CFVisitor::visit(const WhileFlow *v) {
  auto *original = graph->getCurrentBlock();
  auto *end = graph->newBlock("endWhile");

  auto *loopBegin = graph->newBlock("whileBegin", true);
  original->successors_insert(loopBegin);
  process(v->getCond());
  graph->getCurrentBlock()->successors_insert(end);

  // Register the loop (continue target, break target, loop id, enclosing
  // try-catch depth) for break/continue handling inside the body.
  loopStack.emplace_back(loopBegin, end, v->getId(), tryCatchStack.size() - 1);
  auto *body = graph->newBlock("whileBody", true);
  loopBegin->successors_insert(body);
  process(v->getBody());
  loopStack.pop_back();
  graph->getCurrentBlock()->successors_insert(loopBegin);

  graph->setCurrentBlock(end);
}

// forBegin evaluates the iterator; forCheck flows to endFor or forNext;
// forNext holds the synthetic NEXT_VALUE assignment of the loop var and
// flows to forBody; the body flows back to forCheck.
void CFVisitor::visit(const ForFlow *v) {
  if (v->isParallel()) {
    // Values used by the parallel schedule are processed before the loop.
    for (auto *v : v->getSchedule()->getUsedValues()) {
      process(v);
    }
  }

  auto *original = graph->getCurrentBlock();
  auto *end = graph->newBlock("endFor");

  auto *loopBegin = graph->newBlock("forBegin", true);
  original->successors_insert(loopBegin);
  process(v->getIter());

  auto *loopCheck = graph->newBlock("forCheck");
  graph->getCurrentBlock()->successors_insert(loopCheck);
  loopCheck->successors_insert(end);

  auto *loopNext = graph->newBlock("forNext");
  loopCheck->successors_insert(loopNext);
  loopNext->push_back(graph->N(
      v, const_cast(v->getVar()), const_cast(v->getIter()),
      analyze::dataflow::SyntheticAssignInstr::NEXT_VALUE));

  loopStack.emplace_back(loopCheck, end, v->getId(), tryCatchStack.size() - 1);
  auto *loopBody = graph->newBlock("forBody", true);
  loopNext->successors_insert(loopBody);
  process(v->getBody());
  graph->getCurrentBlock()->successors_insert(loopCheck);
  loopStack.pop_back();

  graph->setCurrentBlock(end);
}

// forBegin holds a synthetic KNOWN assignment of the loop var to the start
// value, followed by the start and end expressions.
void CFVisitor::visit(const ImperativeForFlow *v) {
  if (v->isParallel()) {
    // Values used by the parallel schedule are processed before the loop.
    for (auto *v : v->getSchedule()->getUsedValues()) {
      process(v);
    }
  }

  auto *original = graph->getCurrentBlock();
  auto *end = graph->newBlock("endFor");

  auto *loopBegin = graph->newBlock("forBegin", true);
  original->successors_insert(loopBegin);
  loopBegin->push_back(graph->N(
v, const_cast(v->getVar()), const_cast(v->getStart()),
      analyze::dataflow::SyntheticAssignInstr::KNOWN));
  process(v->getStart());
  process(v->getEnd());

  auto *loopCheck = graph->newBlock("forCheck");
  graph->getCurrentBlock()->successors_insert(loopCheck);
  loopCheck->successors_insert(end);

  // forUpdate holds the synthetic ADD assignment (loop var += step).
  // NOTE(review): in this function no block is given forUpdate as a
  // successor (the body and `continue` both target forCheck) -- confirm
  // whether the body was meant to flow through forUpdate.
  auto *loopNext = graph->newBlock("forUpdate");
  loopNext->push_back(graph->N(
      v, const_cast(v->getVar()), v->getStep()));
  loopNext->successors_insert(loopCheck);

  loopStack.emplace_back(loopCheck, end, v->getId(), tryCatchStack.size() - 1);
  auto *loopBody = graph->newBlock("forBody", true);
  loopCheck->successors_insert(loopBody);
  process(v->getBody());
  graph->getCurrentBlock()->successors_insert(loopCheck);
  loopStack.pop_back();

  graph->setCurrentBlock(end);
}

// Blocks created: tcRoute (dispatch to the catch handlers), tcEnd, and
// optionally tcElse / tcFinally. While the body, handlers and else part are
// processed, (tcRoute, tcFinally) is on top of tryCatchStack.
void CFVisitor::visit(const TryCatchFlow *v) {
  auto *routeBlock = graph->newBlock("tcRoute");
  auto *end = graph->newBlock("tcEnd");
  analyze::dataflow::CFBlock *else_ = nullptr;
  analyze::dataflow::CFBlock *finally = nullptr;

  if (v->getElse())
    else_ = graph->newBlock("tcElse");
  if (v->getFinally())
    finally = graph->newBlock("tcFinally");

  // Where normal completion of a handler / else part continues.
  auto *dst = finally ? finally : end;

  tryCatchStack.emplace_back(routeBlock, finally);
  process(v->getBody());
  graph->getCurrentBlock()->successors_insert(else_ ?
else_ : dst);

  // One block per handler, reachable from tcRoute; a handler's catch var (if
  // any) receives a synthetic assignment with no argument.
  for (auto &c : *v) {
    auto *cBlock = graph->newBlock("catch", true);
    if (c.getVar())
      cBlock->push_back(graph->N(
          v, const_cast(c.getVar())));
    process(c.getHandler());
    routeBlock->successors_insert(cBlock);
    graph->getCurrentBlock()->successors_insert(dst);
  }

  if (v->getElse()) {
    graph->setCurrentBlock(else_);
    process(v->getElse());
    graph->getCurrentBlock()->successors_insert(dst);
  }

  tryCatchStack.pop_back();

  if (v->getFinally()) {
    graph->setCurrentBlock(finally);
    process(v->getFinally());
    graph->getCurrentBlock()->successors_insert(end);
    routeBlock->successors_insert(finally);
  }

  // Inside an enclosing try-catch, continue on to the enclosing route block.
  if (!tryCatchStack.empty()) {
    if (finally)
      finally->successors_insert(tryCatchStack.back().first);
    else
      routeBlock->successors_insert(tryCatchStack.back().first);
  }

  graph->setCurrentBlock(end);
}

// Pipelines are processed in their equivalent for-loop form.
void CFVisitor::visit(const PipelineFlow *v) {
  if (auto *loops = convertPipelineToForLoops(graph, v)) {
    process(loops);
  } else {
    // pipeline is empty
  }
}

// Custom flows build their own CFG nodes.
void CFVisitor::visit(const dsl::CustomFlow *v) {
  v->getCFBuilder()->buildCFNodes(this);
}

// For the instruction visitors below, operands are processed first and the
// instruction itself is inserted last.
void CFVisitor::visit(const AssignInstr *v) {
  process(v->getRhs());
  defaultInsert(v);
}

void CFVisitor::visit(const ExtractInstr *v) {
  process(v->getVal());
  defaultInsert(v);
}

void CFVisitor::visit(const InsertInstr *v) {
  process(v->getLhs());
  process(v->getRhs());
  defaultInsert(v);
}

void CFVisitor::visit(const CallInstr *v) {
  process(v->getCallee());
  for (auto *a : *v)
    process(a);
  defaultInsert(v);
}

// The condition runs in the current block, each branch value in its own
// block; both branch ends flow into "ternaryDone".
void CFVisitor::visit(const TernaryInstr *v) {
  auto *end = graph->newBlock("ternaryDone");
  auto *tBranch = graph->newBlock("ternaryTrue");
  auto *fBranch = graph->newBlock("ternaryFalse");

  process(v->getCond());
  graph->getCurrentBlock()->successors_insert(tBranch);
  graph->getCurrentBlock()->successors_insert(fBranch);

  graph->setCurrentBlock(tBranch);
  process(v->getTrueValue());
  graph->getCurrentBlock()->successors_insert(end);

  graph->setCurrentBlock(fBranch);
  process(v->getFalseValue());
  graph->getCurrentBlock()->successors_insert(end);

  auto *phi =
graph->N(v);
  // The ternary is represented by a synthetic phi in "ternaryDone" that
  // selects between the two branch values; the original value is remapped
  // to it.
  phi->emplace_back(tBranch, const_cast(v->getTrueValue()));
  phi->emplace_back(fBranch, const_cast(v->getFalseValue()));

  end->push_back(phi);
  graph->remapValue(v, phi);
  graph->setCurrentBlock(end);
}

// Jumps to the end block of the targeted loop (the innermost one when the
// instruction names no loop).
void CFVisitor::visit(const BreakInstr *v) {
  auto &loop = v->getLoop() ? findLoop(v->getLoop()->getId()) : loopStack.back();
  defaultJump(loop.end, loop.tcIndex);
  defaultInsert(v);
}

// Jumps to the next-iteration block of the targeted loop (the innermost one
// when the instruction names no loop).
void CFVisitor::visit(const ContinueInstr *v) {
  auto &loop = v->getLoop() ? findLoop(v->getLoop()->getId()) : loopStack.back();
  defaultJump(loop.nextIt, loop.tcIndex);
  defaultInsert(v);
}

// Processes the returned value, if any; the jump has no target block and
// leaves every enclosing try-catch level.
void CFVisitor::visit(const ReturnInstr *v) {
  if (v->getValue())
    process(v->getValue());
  defaultJump(nullptr, -1);
  defaultInsert(v);
}

void CFVisitor::visit(const YieldInstr *v) {
  if (v->getValue())
    process(v->getValue());
  defaultInsert(v);
}

void CFVisitor::visit(const AwaitInstr *v) {
  process(v->getValue());
  defaultInsert(v);
}

void CFVisitor::visit(const ThrowInstr *v) {
  if (v->getValue())
    process(v->getValue());
  defaultInsert(v);
}

void CFVisitor::visit(const FlowInstr *v) {
  process(v->getFlow());
  if (v->getValue())
    process(v->getValue());
  defaultInsert(v);
}

// Custom instructions build their own CFG nodes.
void CFVisitor::visit(const dsl::CustomInstr *v) {
  v->getCFBuilder()->buildCFNodes(this);
}

// Appends `v` to the graph. Outside any try-catch it goes into the current
// block; inside one it gets a fresh block that also has an edge to the
// innermost try-catch's route block.
void CFVisitor::defaultInsert(const Value *v) {
  if (tryCatchStack.empty()) {
    graph->getCurrentBlock()->push_back(v);
  } else {
    auto *original = graph->getCurrentBlock();
    auto *newBlock = graph->newBlock("default", true);
    original->successors_insert(newBlock);
    newBlock->successors_insert(tryCatchStack.back().first);
    graph->getCurrentBlock()->push_back(v);
  }
  seenIds.insert(v->getId());
}

// Adds the edge(s) for a jump from the current block to `cf` (may be null),
// where `newTcLevel` is the try-catch depth of the jump target.
void CFVisitor::defaultJump(const CFBlock *cf, int newTcLevel) {
  int curTc = tryCatchStack.size() - 1;

  if (curTc == -1 || curTc <= newTcLevel) {
    // Not leaving any try-catch: jump directly.
    if (cf)
      graph->getCurrentBlock()->successors_insert(const_cast(cf));
  } else {
    // Leaving one or more try-catch levels: find the first level being left
    // (scanning from newTcLevel + 1 up to curTc) that has a finally block.
    CFBlock *nearestFinally = nullptr;
    for (auto i = newTcLevel + 1; i <= curTc; ++i) {
      if (auto *n = tryCatchStack[i].second) {
        nearestFinally = n;
break;
      }
    }
    if (nearestFinally) {
      // Go through the innermost route block; the finally block then
      // continues to the jump target.
      graph->getCurrentBlock()->successors_insert(tryCatchStack.back().first);
      if (cf)
        nearestFinally->successors_insert(const_cast(cf));
    } else {
      if (cf)
        graph->getCurrentBlock()->successors_insert(const_cast(cf));
    }
  }
}

// Returns the loop on the loop stack with the given id.
// NOTE(review): the find_if result is dereferenced without checking for
// end(), so the id is expected to always be present -- confirm with callers.
CFVisitor::Loop &CFVisitor::findLoop(id_t id) {
  return *std::find_if(loopStack.begin(), loopStack.end(),
                       [=](auto &it) { return it.loopId == id; });
}

} // namespace dataflow
} // namespace analyze
} // namespace ir
} // namespace codon

#undef DEFAULT_VISIT

================================================
FILE: codon/cir/analyze/dataflow/cfg.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include
#include
#include
#include
#include

#include "codon/cir/analyze/analysis.h"
#include "codon/cir/cir.h"
#include "codon/cir/util/iterators.h"

// Declares a visit() overload that simply inserts the node into the graph.
#define DEFAULT_VISIT(x)                                                               \
  void visit(const x *v) override { defaultInsert(v); }

namespace codon {
namespace ir {
namespace analyze {
namespace dataflow {

class CFGraph;

/// A block of the control-flow graph: an ordered list of values plus its
/// successor and predecessor blocks.
class CFBlock : public IdMixin {
private:
  /// the in-order list of values in this block
  std::list values;
  /// an un-ordered list of successor blocks
  std::unordered_set successors;
  /// an un-ordered list of predecessor blocks
  std::unordered_set predecessors;
  /// the block's name
  std::string name;
  /// the graph
  CFGraph *graph;

public:
  /// Constructs a control-flow block.
/// @param graph the parent graph /// @param name the block's name explicit CFBlock(CFGraph *graph, std::string name = "") : name(std::move(name)), graph(graph) {} virtual ~CFBlock() noexcept = default; /// @return this block's name std::string getName() const { return name; } /// @return an iterator to the first value auto begin() { return values.begin(); } /// @return an iterator beyond the last value auto end() { return values.end(); } /// @return an iterator to the first value auto begin() const { return values.begin(); } /// @return an iterator beyond the last value auto end() const { return values.end(); } /// @return a pointer to the first value const Value *front() const { return values.front(); } /// @return a pointer to the last value const Value *back() const { return values.back(); } /// Inserts a value at a given position. /// @param it the position /// @param v the new value /// @param an iterator to the new value template auto insert(It it, const Value *v) { values.insert(it, v); reg(v); } /// Inserts a value at the back. /// @param v the new value void push_back(const Value *v) { values.push_back(v); reg(v); } /// Erases a value at the given position. /// @param it the position /// @return an iterator following the removed value template auto erase(It it) { values.erase(it); } /// @return an iterator to the first successor auto successors_begin() { return successors.begin(); } /// @return an iterator beyond the last successor auto successors_end() { return successors.end(); } /// @return an iterator to the first successor auto successors_begin() const { return successors.begin(); } /// @return an iterator beyond the last successor auto successors_end() const { return successors.end(); } /// Inserts a successor at some position. /// @param v the new successor /// @return an iterator to the new successor auto successors_insert(CFBlock *v) { successors.insert(v); v->predecessors.insert(this); } /// Removes a given successor. 
/// @param v the successor to remove
  auto successors_erase(CFBlock *v) {
    successors.erase(v);
    // keep the reverse edge in sync
    v->predecessors.erase(this);
  }

  /// @return an iterator to the first predecessor
  auto predecessors_begin() { return predecessors.begin(); }
  /// @return an iterator beyond the last predecessor
  auto predecessors_end() { return predecessors.end(); }
  /// @return an iterator to the first predecessor
  auto predecessors_begin() const { return predecessors.begin(); }
  /// @return an iterator beyond the last predecessor
  auto predecessors_end() const { return predecessors.end(); }

  /// @return the graph
  CFGraph *getGraph() { return graph; }
  /// @return the graph
  const CFGraph *getGraph() const { return graph; }
  /// Sets the graph.
  /// @param g the new graph
  void setGraph(CFGraph *g) { graph = g; }

private:
  void reg(const Value *v);
};

/// Base class for instructions that are synthesized while building the CFG.
/// Remembers the node that gave rise to the instruction.
class SyntheticInstr : public AcceptorExtend {
private:
  /// the node that gave rise to this instruction
  const Node *source;

public:
  /// Constructs a synthetic instruction.
  /// @param source the node that gave rise to this instruction
  /// @param name the name of the instruction
  SyntheticInstr(const Node *source, std::string name = "")
      : AcceptorExtend(std::move(name)), source(source) {}

  /// Gets the source of this synthetic instruction, i.e. the
  /// node that gave rise to it when constructing the CFG.
  /// @return the node that gave rise to this synthetic instruction
  const Node *getSource() const { return source; }
  /// Sets the source of this synthetic instruction
  /// @param s the new source node
  void setSource(const Node *s) { source = s; }
};

/// Synthetic assignment to a variable. Depending on the kind, it carries
/// a value argument, an integer difference, or neither.
class SyntheticAssignInstr : public AcceptorExtend {
public:
  enum Kind { UNKNOWN, KNOWN, NEXT_VALUE, ADD };

private:
  /// the left-hand side
  Var *lhs;
  /// the kind of synthetic assignment
  Kind kind;
  /// any argument to the synthetic assignment
  Value *arg = nullptr;
  /// the difference
  int64_t diff = 0;

public:
  static const char NodeId;

  /// Constructs a synthetic assignment.
  /// @param source the node that gave rise to this instruction
  /// @param lhs the variable being assigned
  /// @param arg the argument
  /// @param k the kind of assignment
  /// @param name the name of the instruction
  SyntheticAssignInstr(const Node *source, Var *lhs, Value *arg, Kind k = KNOWN,
                       std::string name = "")
      : AcceptorExtend(source, std::move(name)), lhs(lhs), kind(k), arg(arg) {}

  /// Constructs an unknown synthetic assignment (kind UNKNOWN, null argument).
  /// @param source the node that gave rise to this instruction
  /// @param lhs the variable being assigned
  /// @param name the name of the instruction
  explicit SyntheticAssignInstr(const Node *source, Var *lhs, std::string name = "")
      : SyntheticAssignInstr(source, lhs, nullptr, UNKNOWN, std::move(name)) {}

  /// Constructs an addition synthetic assignment (kind ADD).
  /// @param source the node that gave rise to this instruction
  /// @param lhs the variable being assigned
  /// @param diff the difference
  /// @param name the name of the instruction
  SyntheticAssignInstr(const Node *source, Var *lhs, int64_t diff, std::string name = "")
      : AcceptorExtend(source, std::move(name)), lhs(lhs), kind(ADD), diff(diff) {}

  /// @return the variable being assigned
  Var *getLhs() { return lhs; }
  /// @return the variable being assigned
  const Var *getLhs() const { return lhs; }
  /// Sets the variable being assigned.
  /// @param v the variable
  void setLhs(Var *v) { lhs = v; }

  /// @return the argument (null unless one was supplied)
  Value *getArg() { return arg; }
  /// @return the argument (null unless one was supplied)
  const Value *getArg() const { return arg; }
  /// Sets the argument.
  /// @param v the new value
  void setArg(Value *v) { arg = v; }

  /// @return the diff
  int64_t getDiff() const { return diff; }
  /// Sets the diff.
  /// @param v the new value
  void setDiff(int64_t v) { diff = v; }

  /// @return the kind of synthetic assignment
  Kind getKind() const { return kind; }
  /// Sets the kind.
  /// @param k the new value
  void setKind(Kind k) { kind = k; }

protected:
  // NOTE(review): `arg` is returned even when it is null (UNKNOWN / ADD
  // constructors leave it as nullptr) -- confirm callers tolerate that.
  std::vector doGetUsedValues() const override { return {arg}; }
  int doReplaceUsedValue(id_t id, Value *newValue) override;

  std::vector doGetUsedVariables() const override { return {lhs}; }
  int doReplaceUsedVariable(id_t id, Var *newVar) override;
};

/// Synthetic phi instruction: a list of (predecessor block, value) pairs.
class SyntheticPhiInstr : public AcceptorExtend {
public:
  class Predecessor {
  private:
    /// the predecessor block
    CFBlock *pred;
    /// the value
    Value *result;

  public:
    /// Constructs a predecessor.
    /// @param pred the predecessor block
    /// @param result the result of this predecessor.
    Predecessor(CFBlock *pred, Value *result) : pred(pred), result(result) {}

    /// @return the predecessor block
    CFBlock *getPred() { return pred; }
    /// @return the predecessor block
    const CFBlock *getPred() const { return pred; }
    /// Sets the predecessor.
    /// @param v the new value
    void setPred(CFBlock *v) { pred = v; }

    /// @return the result
    Value *getResult() { return result; }
    /// @return the result
    const Value *getResult() const { return result; }
    /// Sets the result
    /// @param v the new value
    void setResult(Value *v) { result = v; }
  };

private:
  /// the (block, value) pairs of this phi
  std::list preds;

public:
  static const char NodeId;

  /// Constructs a synthetic phi instruction.
  /// @param source the node that gave rise to this instruction
  /// @param name the name of the instruction
  explicit SyntheticPhiInstr(const Node *source, std::string name = "")
      : AcceptorExtend(source, std::move(name)) {}

  /// @return an iterator to the first predecessor
  auto begin() { return preds.begin(); }
  /// @return an iterator beyond the last predecessor
  auto end() { return preds.end(); }
  /// @return an iterator to the first predecessor
  auto begin() const { return preds.begin(); }
  /// @return an iterator beyond the last predecessor
  auto end() const { return preds.end(); }

  /// @return a reference to the first predecessor
  Predecessor &front() { return preds.front(); }
  /// @return a reference to the last predecessor
  Predecessor &back() { return preds.back(); }
  /// @return a reference to the first predecessor
  const Predecessor &front() const { return preds.front(); }
  /// @return a reference to the last predecessor
  const Predecessor &back() const { return preds.back(); }

  /// Inserts a predecessor.
  /// @param pos the position
  /// @param v the predecessor
  /// @return an iterator to the newly added predecessor
  template auto insert(It pos, Predecessor v) { return preds.insert(pos, v); }
  /// Appends a predecessor.
  /// @param v the predecessor
  void push_back(Predecessor v) { preds.push_back(v); }

  /// Erases the item at the supplied position.
  /// @param pos the position
  /// @return the iterator beyond the removed predecessor
  template auto erase(It pos) { return preds.erase(pos); }

  /// Emplaces a predecessor at the end of the list.
  /// @param args the args forwarded to the Predecessor constructor
  template void emplace_back(Args &&...args) {
    preds.emplace_back(std::forward(args)...);
  }

protected:
  std::vector doGetUsedValues() const override;
  int doReplaceUsedValue(id_t id, Value *newValue) override;
};

/// Control-flow graph of a single function: a list of blocks plus the
/// synthetic values/variables created while building it.
class CFGraph {
private:
  /// owned list of blocks
  std::list> blocks;
  /// the current block
  CFBlock *cur = nullptr;
  /// the function being analyzed
  const BodiedFunc *func;
  /// a list of synthetic values
  std::list> syntheticValues;
  /// a map of synthetic values
  std::unordered_map valueMapping;
  /// a list of synthetic variables
  std::list> syntheticVars;
  /// a mapping from value id to block
  std::unordered_map valueLocations;

public:
  /// Constructs a control-flow graph.
  /// @param f the function being analyzed
  explicit CFGraph(const BodiedFunc *f);

  /// @return number of blocks in this CFG
  auto size() const { return blocks.size(); }

  /// @return an iterator to the first block
  auto begin() { return util::raw_ptr_adaptor(blocks.begin()); }
  /// @return an iterator beyond the last block
  auto end() { return util::raw_ptr_adaptor(blocks.end()); }
  /// @return an iterator to the first block
  auto begin() const { return util::raw_ptr_adaptor(blocks.begin()); }
  /// @return an iterator beyond the last block
  auto end() const { return util::raw_ptr_adaptor(blocks.end()); }

  /// @return an iterator to the first synthetic value
  auto synth_begin() { return util::raw_ptr_adaptor(syntheticValues.begin()); }
  /// @return an iterator beyond the last synthetic value
  auto synth_end() { return util::raw_ptr_adaptor(syntheticValues.end()); }
  /// @return an iterator to the first synthetic value
  auto synth_begin() const { return util::raw_ptr_adaptor(syntheticValues.begin()); }
  /// @return an iterator beyond the last synthetic value
  auto synth_end() const { return util::raw_ptr_adaptor(syntheticValues.end()); }

  /// @return the entry block (the first block in the list)
  CFBlock *getEntryBlock() { return blocks.front().get(); }
  /// @return the entry block (the first block in the list)
  const CFBlock *getEntryBlock() const { return blocks.front().get(); }

  /// @return the current block
  CFBlock *getCurrentBlock() { return cur; }
  /// @return the current block
  const CFBlock *getCurrentBlock() const { return cur; }
  /// Sets the current block.
  /// @param v the new value
  void setCurrentBlock(CFBlock *v) { cur = v; }

  /// @return the function
  const BodiedFunc *getFunc() const { return func; }
  /// Sets the function.
  /// @param f the new value
  void setFunc(BodiedFunc *f) { func = f; }

  /// Gets the block containing a value. If the value's id has been
  /// remapped, the remapped value is looked up instead.
  /// @param v the value
  /// @return the block, or nullptr if no location is recorded
  CFBlock *getBlock(const Value *v) {
    auto vmIt = valueMapping.find(v->getId());
    if (vmIt != valueMapping.end())
      v = vmIt->second;
    auto it = valueLocations.find(v->getId());
    return it != valueLocations.end() ? it->second : nullptr;
  }
  /// Gets the block containing a value. If the value's id has been
  /// remapped, the remapped value is looked up instead.
  /// @param v the value
  /// @return the block, or nullptr if no location is recorded
  const CFBlock *getBlock(const Value *v) const {
    auto vmIt = valueMapping.find(v->getId());
    if (vmIt != valueMapping.end())
      v = vmIt->second;
    auto it = valueLocations.find(v->getId());
    return it != valueLocations.end() ? it->second : nullptr;
  }

  /// Creates and inserts a new block
  /// @param name the name
  /// @param setCur true if the block should be made the current one
  /// @return a newly inserted block
  CFBlock *newBlock(std::string name = "", bool setCur = false) {
    auto *ret = new CFBlock(this, std::move(name));
    blocks.emplace_back(ret);
    if (setCur)
      setCurrentBlock(ret);
    return ret;
  }

  /// Creates a node from the forwarded arguments, registers it with this
  /// graph via reg(), and sets its module to that of the analyzed function.
  /// @param args the constructor arguments
  /// @return the new node
  template NodeType *N(Args &&...args) {
    auto *ret = new NodeType(std::forward(args)...);
    reg(ret);
    ret->setModule(func->getModule());
    return ret;
  }

  /// Remaps a value.
  /// @param id original id
  /// @param newValue the new value
  void remapValue(id_t id, Value *newValue) { valueMapping[id] = newValue; }
  /// Remaps a value.
  /// @param original the original value
  /// @param newValue the new value
  void remapValue(const Value *original, Value *newValue) {
    remapValue(original->getId(), newValue);
  }

  /// Gets a value by id. The graph's own mapping is consulted first;
  /// otherwise the lookup is delegated to the function's module.
  /// @param id the id
  /// @return the value or nullptr
  Value *getValue(id_t id) {
    auto it = valueMapping.find(id);
    return it != valueMapping.end() ? it->second : func->getModule()->getValue(id);
  }

  friend std::ostream &operator<<(std::ostream &os, const CFGraph &cfg);
  friend class CFBlock;

private:
  /// Records a synthetic variable.
  void reg(Var *v) { syntheticVars.emplace_back(v); }
  /// Records a synthetic value and maps its id to itself.
  void reg(Value *v) {
    syntheticValues.emplace_back(v);
    valueMapping[v->getId()] = v;
  }
};

/// Builds a control-flow graph from a given function.
/// @param f the function
/// @return the control-flow graph
std::unique_ptr buildCFGraph(const BodiedFunc *f);

/// Control-flow analysis result.
struct CFResult : public Result {
  /// map from function id to control-flow graph
  std::unordered_map> graphs;
};

/// Control-flow analysis that runs on all functions.
class CFAnalysis : public Analysis {
public:
  static const std::string KEY;
  std::string getKey() const override { return KEY; }

  std::unique_ptr run(const Module *m) override;
};

/// Visitor that populates a CFGraph while walking a function's IR.
class CFVisitor : public util::ConstVisitor {
private:
  /// Bookkeeping for a loop that is currently being visited.
  struct Loop {
    analyze::dataflow::CFBlock *nextIt;
    analyze::dataflow::CFBlock *end;
    id_t loopId;
    int tcIndex;

    Loop(analyze::dataflow::CFBlock *nextIt, analyze::dataflow::CFBlock *end,
         id_t loopId, int tcIndex = -1)
        : nextIt(nextIt), end(end), loopId(loopId), tcIndex(tcIndex) {}
  };

  analyze::dataflow::CFGraph *graph;
  std::vector> tryCatchStack;
  /// ids of nodes that process() has already visited
  std::unordered_set seenIds;
  std::vector loopStack;

public:
  explicit CFVisitor(analyze::dataflow::CFGraph *graph) : graph(graph) {}

  void visit(const BodiedFunc *f) override;

  DEFAULT_VISIT(VarValue)
  DEFAULT_VISIT(PointerValue)

  void visit(const SeriesFlow *v) override;
  void visit(const IfFlow *v) override;
  void visit(const WhileFlow *v) override;
  void visit(const ForFlow *v) override;
  void visit(const ImperativeForFlow *v) override;

  void visit(const TryCatchFlow *v) override;
  void visit(const PipelineFlow *v) override;
  void visit(const dsl::CustomFlow *v) override;

  DEFAULT_VISIT(TemplatedConst);
  DEFAULT_VISIT(TemplatedConst);
  DEFAULT_VISIT(TemplatedConst);
  DEFAULT_VISIT(TemplatedConst);
  DEFAULT_VISIT(dsl::CustomConst);

  void visit(const AssignInstr *v) override;
  void visit(const ExtractInstr *v) override;
  void visit(const InsertInstr *v) override;
  void visit(const CallInstr *v) override;
  DEFAULT_VISIT(StackAllocInstr);
  DEFAULT_VISIT(TypePropertyInstr);
  DEFAULT_VISIT(YieldInInstr);
  void visit(const TernaryInstr *v) override;
  void visit(const BreakInstr *v) override;
  void visit(const ContinueInstr *v) override;
  void visit(const ReturnInstr *v) override;
  void visit(const YieldInstr *v) override;
  void visit(const AwaitInstr *v) override;
  void visit(const ThrowInstr *v) override;
  void visit(const FlowInstr *v) override;
  void visit(const dsl::CustomInstr *v) override;

  /// Visits a node at most once: null nodes and nodes whose id has
  /// already been seen are skipped.
  /// @param v the node to visit (may be null)
  template void process(const NodeType *v) {
    if (!v)
      return;
    if (seenIds.find(v->getId()) != seenIds.end())
      return;
    seenIds.insert(v->getId());
    v->accept(*this);
  }

  void defaultInsert(const Value *v);
  void defaultJump(const CFBlock *cf, int newTcLevel = -1);

private:
  Loop &findLoop(id_t id);
};

} // namespace dataflow
} // namespace analyze
} // namespace ir
} // namespace codon

template <> struct fmt::formatter : fmt::ostream_formatter {};

#undef DEFAULT_VISIT

================================================
FILE: codon/cir/analyze/dataflow/dominator.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include "dominator.h"

#include "codon/cir/llvm/llvm.h"

namespace codon {
namespace ir {
namespace analyze {
namespace dataflow {

/// Computes, for every block of the CFG, the set of blocks that dominate it.
/// Block ids are first compacted to sequential indices so that each
/// dominator set can be held in an llvm::BitVector; the usual iterative
/// algorithm is then run to a fixed point and the result is translated
/// back to block ids in `sets`.
void DominatorInspector::analyze() {
  const auto numBlocks = cfg->size();
  std::unordered_map mapping;        // id -> sequential id
  std::vector mappingInv(numBlocks); // sequential id -> id
  std::vector bitvecs(numBlocks);    // sequential id -> bitvector

  mapping.reserve(numBlocks);
  size_t next = 0;
  for (auto *blk : *cfg) {
    auto id = blk->getId();
    mapping[id] = next;
    mappingInv[next] = id;
    ++next;
  }

  // Initialize: all blocks dominate themselves; others start with universal set
  for (auto *blk : *cfg) {
    auto id = mapping[blk->getId()];
    llvm::BitVector &bv = bitvecs[id];
    bv.resize(numBlocks, true);
    if (blk == cfg->getEntryBlock()) {
      // entry block only dominated by itself
      bv.reset();
      bv.set(id);
    }
  }

  // Run simple domination algorithm
  bool changed = true;
  while (changed) {
    changed = false;
    for (auto *blk : *cfg) {
      auto id = mapping[blk->getId()];
      llvm::BitVector old = bitvecs[id];
      llvm::BitVector newSet;
      // a block without predecessors starts from the empty set
      if (blk->predecessors_begin() == blk->predecessors_end())
        newSet.resize(numBlocks);

      // intersect the dominator sets of all predecessors
      bool first = true;
      for (auto it = blk->predecessors_begin(); it != blk->predecessors_end(); ++it) {
        auto predId = mapping[(*it)->getId()];
        const auto &predSet = bitvecs[predId];
        if (first) {
          newSet = predSet;
          first = false;
        } else {
          newSet &= predSet;
        }
      }

      newSet.set(id); // a block always dominates itself

      if (newSet != old) {
        bitvecs[id] = newSet;
        changed = true;
      }
    }
  }

  // Map back to canonical id
  sets.reserve(numBlocks);
  for (unsigned id = 0; id < numBlocks; id++) {
    auto &bv = bitvecs[id];
    auto &set = sets[mappingInv[id]];
    for (auto n = bv.find_first(); n != -1; n = bv.find_next(n)) {
      set.insert(mappingInv[n]);
    }
  }
}

/// Checks whether `dominator` dominates `v`.
/// Values in the same block are compared by their position in the block;
/// values in different blocks are compared using the computed block sets.
/// @param v the value
/// @param dominator the dominator value
/// @return true if `dominator` dominates `v`; false otherwise, including
///         when either value has no block recorded in the CFG
bool DominatorInspector::isDominated(const Value *v, const Value *dominator) {
  auto *vBlock = cfg->getBlock(v);
  auto *dBlock = cfg->getBlock(dominator);

  // CFGraph::getBlock() yields nullptr for values without a recorded
  // location; domination cannot be established for those.
  if (!vBlock || !dBlock)
    return false;

  if (vBlock->getId() == dBlock->getId()) {
    auto vDist = std::distance(vBlock->begin(),
                               std::find(vBlock->begin(), vBlock->end(), v));
    auto dDist = std::distance(vBlock->begin(),
                               std::find(vBlock->begin(), vBlock->end(), dominator));
    return dDist <= vDist;
  }

  // look up without inserting an empty entry for unknown blocks
  auto it = sets.find(vBlock->getId());
  return it != sets.end() && it->second.find(dBlock->getId()) != it->second.end();
}

const std::string DominatorAnalysis::KEY = "core-analyses-dominator";

/// Runs the dominator analysis on every graph of the control-flow result
/// registered under `cfAnalysisKey`.
/// @param m the module (unused here; graphs come from the CFG result)
/// @return the dominator result, keyed like the control-flow result
std::unique_ptr DominatorAnalysis::run(const Module *m) {
  auto *cfgResult = getAnalysisResult(cfAnalysisKey);
  auto ret = std::make_unique(cfgResult);
  for (const auto &graph : cfgResult->graphs) {
    auto inspector = std::make_unique(graph.second.get());
    inspector->analyze();
    ret->results[graph.first] = std::move(inspector);
  }
  return ret;
}

} // namespace dataflow
} // namespace analyze
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/analyze/dataflow/dominator.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include
#include
#include

#include "codon/cir/analyze/analysis.h"
#include "codon/cir/analyze/dataflow/cfg.h"

namespace codon {
namespace ir {
namespace analyze {
namespace dataflow {

/// Helper to query the dominators of a particular function.
class DominatorInspector {
private:
  /// block id -> ids of the blocks dominating it (filled by analyze())
  std::unordered_map> sets;
  /// the control-flow graph being inspected
  CFGraph *cfg;

public:
  explicit DominatorInspector(CFGraph *cfg) : cfg(cfg) {}

  /// Do the analysis.
  void analyze();

  /// Checks if one value dominates another.
  /// @param v the value
  /// @param dominator the dominator value
  /// @return true if `dominator` dominates `v`
  bool isDominated(const Value *v, const Value *dominator);
};

/// Result of a dominator analysis.
struct DominatorResult : public Result {
  /// the corresponding control flow result
  const CFResult *cfgResult;
  /// the dominator inspectors
  std::unordered_map> results;

  explicit DominatorResult(const CFResult *cfgResult) : cfgResult(cfgResult) {}
};

/// Dominator analysis. Must have control flow-graph available.
class DominatorAnalysis : public Analysis { private: /// the control-flow analysis key std::string cfAnalysisKey; public: static const std::string KEY; /// Initializes a dominator analysis. /// @param cfAnalysisKey the control-flow analysis key explicit DominatorAnalysis(std::string cfAnalysisKey) : cfAnalysisKey(std::move(cfAnalysisKey)) {} std::string getKey() const override { return KEY; } std::unique_ptr run(const Module *m) override; }; } // namespace dataflow } // namespace analyze } // namespace ir } // namespace codon ================================================ FILE: codon/cir/analyze/dataflow/reaching.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "reaching.h" #include #include namespace codon { namespace ir { namespace analyze { namespace dataflow { namespace { id_t getKilled(const Value *val) { if (auto *assign = cast(val)) { return assign->getLhs()->getId(); } else if (auto *synthAssign = cast(val)) { return synthAssign->getLhs()->getId(); } return -1; } std::pair getGenerated(const Value *val) { if (auto *assign = cast(val)) { return std::make_pair(assign->getLhs()->getId(), ReachingDef(assign, assign->getRhs())); } else if (auto *synthAssign = cast(val)) { if (synthAssign->getKind() == analyze::dataflow::SyntheticAssignInstr::KNOWN) return std::make_pair(synthAssign->getLhs()->getId(), ReachingDef(synthAssign, synthAssign->getArg())); else return std::make_pair(synthAssign->getLhs()->getId(), ReachingDef(synthAssign)); } return std::make_pair(-1, ReachingDef(nullptr)); } template struct WorkList { std::unordered_set have; std::deque queue; void push(T *a) { auto id = a->getId(); if (have.count(id)) return; have.insert(id); queue.push_back(a); } T *pop() { if (queue.empty()) return nullptr; auto *a = queue.front(); queue.pop_front(); have.erase(a->getId()); return a; } template WorkList(S *x) : have(), queue() { for (T *a : *x) { push(a); } } }; struct BitSet { static constexpr unsigned B = 
64; static unsigned allocSize(unsigned size) { return (size + B - 1) / B; } std::vector words; explicit BitSet(unsigned size) : words(allocSize(size), 0) {} BitSet copy(unsigned size) const { auto res = BitSet(size); std::memcpy(res.words.data(), words.data(), allocSize(size) * (B / 8)); return res; } void set(unsigned bit) { words.data()[bit / B] |= (1UL << (bit % B)); } bool get(unsigned bit) const { return (words.data()[bit / B] & (1UL << (bit % B))) != 0; } bool equals(const BitSet &other, unsigned size) { return std::memcmp(words.data(), other.words.data(), allocSize(size) * (B / 8)) == 0; } void clear(unsigned size) { std::memset(words.data(), 0, allocSize(size) * (B / 8)); } void setAll(unsigned size) { std::memset(words.data(), 0xff, allocSize(size) * (B / 8)); } void overwrite(const BitSet &other, unsigned size) { std::memcpy(words.data(), other.words.data(), allocSize(size) * (B / 8)); } void update(const BitSet &other, unsigned size) { auto *p = words.data(); auto *q = other.words.data(); auto n = allocSize(size); for (unsigned i = 0; i < n; i++) { p[i] |= q[i]; } } void subtract(const BitSet &other, unsigned size) { auto *p = words.data(); auto *q = other.words.data(); auto n = allocSize(size); for (unsigned i = 0; i < n; i++) { p[i] &= ~q[i]; } } }; template struct BlockBitSets { T *blk; BitSet gen; BitSet kill; BitSet in; BitSet out; BlockBitSets(T *blk, BitSet gen, BitSet kill, BitSet in, BitSet out) : blk(blk), gen(std::move(gen)), kill(std::move(kill)), in(std::move(in)), out(std::move(out)) {} }; } // namespace void RDInspector::analyze() { std::vector ordering; std::unordered_map lookup; std::unordered_map> varToAssignments; for (auto *blk : *cfg) { for (auto *val : *blk) { auto k = getKilled(val); if (k != -1) { lookup.emplace(val->getId(), ordering.size()); ordering.push_back(val); varToAssignments[k].push_back(val); } } } unsigned n = ordering.size(); std::unordered_map> bitsets; // construct initial gen and kill sets for (auto *blk : *cfg) { 
auto gen = BitSet(n); auto kill = BitSet(n); std::unordered_map generated; for (auto *val : *blk) { // vars that are used by pointer may change at any time, so don't track them if (auto *ptr = cast(val)) { invalid.insert(ptr->getVar()->getId()); continue; } auto g = getGenerated(val); if (g.first != -1) { // generated map will store latest generated assignment, as desired generated[g.first] = val->getId(); } auto k = getKilled(val); if (k != -1) { // all other assignments that use the var are killed for (auto *assign : varToAssignments[k]) { if (assign->getId() != val->getId()) kill.set(lookup[assign->getId()]); } } } // set gen for the last assignment of each var in the block for (auto &entry : generated) { gen.set(lookup[entry.second]); } auto in = BitSet(n); auto out = gen.copy(n); // out = gen is an optimization over out = {} bitsets.emplace(std::piecewise_construct, std::forward_as_tuple(blk->getId()), std::forward_as_tuple(blk, std::move(gen), std::move(kill), std::move(in), std::move(out))); } WorkList worklist(cfg); while (auto *blk = worklist.pop()) { auto &data = bitsets.find(blk->getId())->second; // IN[blk] = U OUT[pred], for all predecessors pred data.in.clear(n); for (auto it = blk->predecessors_begin(); it != blk->predecessors_end(); ++it) { data.in.update(bitsets.find((*it)->getId())->second.out, n); } // OUT[blk] = GEN[blk] U (IN[blk] - KILL[blk]) auto oldout = data.out.copy(n); auto tmp = data.in.copy(n); tmp.subtract(data.kill, n); tmp.update(data.gen, n); data.out.overwrite(tmp, n); // if OUT changed, add all successors to worklist if (!data.out.equals(oldout, n)) { for (auto it = blk->successors_begin(); it != blk->successors_end(); ++it) { worklist.push(*it); } } } // reconstruct final sets in more convenient format for (auto &elem : bitsets) { auto &data = elem.second; auto &entry = sets[data.blk->getId()]; for (unsigned i = 0; i < n; i++) { if (data.in.get(i)) { auto g = getGenerated(ordering[i]); entry.in[g.first].insert(g.second); } } } } 
/// Gets the definitions of `var` that reach `loc`.
/// Starts from the IN set of the block containing `loc` and replays that
/// block's values up to and including `loc`, clearing the set on each
/// assignment to `var` and adding the definition it generates.
/// @param var the variable being inspected
/// @param loc the location
/// @return the reaching definitions; empty if `var` is global or was marked
///         invalid, if `loc` has no block, or if `loc` is in the entry block
///         and no assignment to `var` was replayed
std::vector RDInspector::getReachingDefinitions(const Var *var, const Value *loc) {
  if (invalid.find(var->getId()) != invalid.end() || var->isGlobal())
    return {};

  auto *blk = cfg->getBlock(loc);
  if (!blk)
    return {};
  auto &entry = sets[blk->getId()];
  // deliberate copy: the set is modified below while replaying the block
  auto defs = entry.in[var->getId()];
  bool needClear = (blk->getId() == cfg->getEntryBlock()->getId());
  bool didClear = false;

  auto done = false;
  for (auto *val : *blk) {
    if (done)
      break;
    // `loc` itself is still processed; the loop stops on the next iteration
    if (val->getId() == loc->getId())
      done = true;

    auto killed = getKilled(val);
    if (killed == var->getId()) {
      defs.clear();
      didClear = true;
    }
    auto gen = getGenerated(val);
    if (gen.first == var->getId())
      defs.insert(gen.second);
  }

  if (needClear && !didClear)
    return {};

  return std::vector(defs.begin(), defs.end());
}

const std::string RDAnalysis::KEY = "core-analyses-rd";

/// Runs the reaching-definitions analysis on every graph of the
/// control-flow result registered under `cfAnalysisKey`.
/// @param m the module (unused here; graphs come from the CFG result)
/// @return the reaching-definitions result, keyed like the control-flow result
std::unique_ptr RDAnalysis::run(const Module *m) {
  auto *cfgResult = getAnalysisResult(cfAnalysisKey);
  auto ret = std::make_unique(cfgResult);
  for (const auto &graph : cfgResult->graphs) {
    auto inspector = std::make_unique(graph.second.get());
    inspector->analyze();
    ret->results[graph.first] = std::move(inspector);
  }
  return ret;
}

} // namespace dataflow
} // namespace analyze
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/analyze/dataflow/reaching.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include

#include "codon/cir/analyze/analysis.h"
#include "codon/cir/analyze/dataflow/cfg.h"

namespace codon {
namespace ir {
namespace analyze {
namespace dataflow {

/// Single answer to a reaching definition query
struct ReachingDef {
  /// Assignment instruction, which can be a `AssignInstr` or
  /// e.g. a `SyntheticAssignInstr` to represent loop variable
  /// assignment etc.
  const Instr *assignment;
  /// The value being assigned, or null if unknown. The assigned
  /// value is unknown when, for example, assigning the next value
  /// of a loop variable.
  const Value *assignee;

  explicit ReachingDef(const Instr *assignment, const Value *assignee = nullptr)
      : assignment(assignment), assignee(assignee) {}

  /// @return true if the assigned value is known (non-null)
  bool known() const { return assignee != nullptr; }

  /// @return the id of the assigned value if known, else the id of the
  ///         assignment instruction
  id_t getId() const { return known() ? assignee->getId() : assignment->getId(); }

  /// Two definitions are equal if both are known with the same assigned
  /// value id, or both are unknown with the same assignment id.
  bool operator==(const ReachingDef &other) const {
    if (known() != other.known())
      return false;

    return known() ? (assignee->getId() == other.assignee->getId())
                   : (assignment->getId() == other.assignment->getId());
  }
};

} // namespace dataflow
} // namespace analyze
} // namespace ir
} // namespace codon

namespace std {
/// Hashes the same id that ReachingDef::operator== compares.
template <> struct hash {
  size_t operator()(const codon::ir::analyze::dataflow::ReachingDef &d) const {
    return d.known() ? hash{}(d.assignee->getId()) : hash{}(d.assignment->getId());
  }
};
} // namespace std

namespace codon {
namespace ir {
namespace analyze {
namespace dataflow {

/// Helper to query the reaching definitions of a particular function.
class RDInspector {
private:
  struct BlockData {
    /// variable id -> definitions reaching the start of the block
    std::unordered_map> in;
    BlockData() = default;
  };
  /// ids of variables that are not tracked (see isInvalid())
  std::unordered_set invalid;
  /// block id -> data for that block
  std::unordered_map sets;
  CFGraph *cfg;

public:
  explicit RDInspector(CFGraph *cfg) : cfg(cfg) {}

  /// Do the analysis.
  void analyze();

  /// Gets the reaching definitions at a particular location.
  /// @param var the variable being inspected
  /// @param loc the location
  /// @return a vector of reaching definitions
  std::vector getReachingDefinitions(const Var *var, const Value *loc);

  /// @param var the variable
  /// @return true if the variable was marked invalid by the analysis
  bool isInvalid(const Var *var) const { return invalid.count(var->getId()) != 0; }
};

/// Result of a reaching definition analysis.
struct RDResult : public Result {
  /// the corresponding control flow result
  const CFResult *cfgResult;
  /// the reaching definition inspectors
  std::unordered_map> results;

  explicit RDResult(const CFResult *cfgResult) : cfgResult(cfgResult) {}
};

/// Reaching definition analysis. Must have control flow-graph available.
class RDAnalysis : public Analysis {
private:
  /// the control-flow analysis key
  std::string cfAnalysisKey;

public:
  static const std::string KEY;

  /// Initializes a reaching definition analysis.
  /// @param cfAnalysisKey the control-flow analysis key
  explicit RDAnalysis(std::string cfAnalysisKey)
      : cfAnalysisKey(std::move(cfAnalysisKey)) {}

  std::string getKey() const override { return KEY; }

  std::unique_ptr run(const Module *m) override;
};

} // namespace dataflow
} // namespace analyze
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/analyze/module/global_vars.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "global_vars.h"

#include "codon/cir/util/operator.h"

namespace codon {
namespace ir {
namespace analyze {
namespace module {
namespace {
/// Collects, for each global variable, the id of the value assigned to it.
/// The entry is -1 when the global is referenced through a PointerValue or
/// when an assignment is seen for a global that already has an entry.
struct GlobalVarAnalyzer : public util::Operator {
  /// global variable id -> assigned value id, or -1 (see above)
  std::unordered_map assignments;

  void handle(PointerValue *v) override {
    if (v->getVar()->isGlobal())
      assignments[v->getVar()->getId()] = -1;
  }

  void handle(AssignInstr *v) override {
    auto *lhs = v->getLhs();
    auto id = lhs->getId();
    if (lhs->isGlobal()) {
      if (assignments.find(id) != assignments.end()) {
        // an entry already exists, so a single assigned value cannot be named
        assignments[id] = -1;
      } else {
        assignments[id] = v->getRhs()->getId();
      }
    }
  }
};
} // namespace

const std::string GlobalVarsAnalyses::KEY = "core-analyses-global-vars";

/// Walks the module with a GlobalVarAnalyzer and returns what it collected.
/// @param m the module to analyze
/// @return the collected global-variable assignments
std::unique_ptr GlobalVarsAnalyses::run(const Module *m) {
  GlobalVarAnalyzer gva;
  gva.visit(const_cast(m)); // TODO: any way around this cast?
  return std::make_unique(std::move(gva.assignments));
}

} // namespace module
} // namespace analyze
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/analyze/module/global_vars.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once #include #include "codon/cir/analyze/analysis.h" namespace codon { namespace ir { namespace analyze { namespace module { struct GlobalVarsResult : public Result { std::unordered_map assignments; explicit GlobalVarsResult(std::unordered_map assignments) : assignments(std::move(assignments)) {} }; class GlobalVarsAnalyses : public Analysis { static const std::string KEY; std::string getKey() const override { return KEY; } std::unique_ptr run(const Module *m) override; }; } // namespace module } // namespace analyze } // namespace ir } // namespace codon ================================================ FILE: codon/cir/analyze/module/side_effect.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "side_effect.h" #include #include #include "codon/cir/analyze/dataflow/capture.h" #include "codon/cir/util/irtools.h" #include "codon/cir/util/operator.h" namespace codon { namespace ir { namespace analyze { namespace module { namespace { template T max(T &&t) { return std::forward(t); } template typename std::common_type::type max(T0 &&val1, T1 &&val2, Ts &&...vs) { return (val1 > val2) ? max(val1, std::forward(vs)...) 
: max(val2, std::forward(vs)...); } struct VarUseAnalyzer : public util::Operator { std::unordered_map varCounts; std::unordered_map varAssignCounts; void preHook(Node *v) override { for (auto *var : v->getUsedVariables()) { ++varCounts[var->getId()]; } } void handle(AssignInstr *v) override { ++varAssignCounts[v->getLhs()->getId()]; } }; struct SideEfectAnalyzer : public util::ConstVisitor { using Status = util::SideEffectStatus; static Status getFunctionStatusFromAttributes(const Func *v, bool *force = nullptr) { auto attr = [v](const auto &s) { return util::hasAttribute(v, s); }; if (attr(util::PURE_ATTR)) { if (force) *force = true; return Status::PURE; } if (attr(util::NO_SIDE_EFFECT_ATTR)) { if (force) *force = true; return Status::NO_SIDE_EFFECT; } if (attr(util::NO_CAPTURE_ATTR)) { if (force) *force = true; return Status::NO_CAPTURE; } if (attr(util::NON_PURE_ATTR)) { if (force) *force = true; return Status::UNKNOWN; } if (force) *force = false; return Status::UNKNOWN; } VarUseAnalyzer &vua; dataflow::CaptureResult *cr; bool globalAssignmentHasSideEffects; std::unordered_map result; std::vector funcStack; Status exprStatus; Status funcStatus; // We have to sometimes be careful with globals since future // IR passes might introduce globals that we've eliminated // or demoted earlier. Hence the distinction with whether // global assignments are considered to have side effects. 
SideEfectAnalyzer(VarUseAnalyzer &vua, dataflow::CaptureResult *cr, bool globalAssignmentHasSideEffects) : util::ConstVisitor(), vua(vua), cr(cr), globalAssignmentHasSideEffects(globalAssignmentHasSideEffects), result(), funcStack(), exprStatus(Status::PURE), funcStatus(Status::PURE) {} template bool has(const T *v) { return result.find(v->getId()) != result.end(); } template void set(const T *v, Status expr, Status func = Status::PURE) { result[v->getId()] = exprStatus = expr; funcStatus = max(funcStatus, func); } template Status process(const T *v) { if (!v) return Status::PURE; if (has(v)) return result[v->getId()]; v->accept(*this); seqassertn(has(v), "node not added to results"); return result[v->getId()]; } std::pair getVarAssignStatus(const Var *var) { if (!var) return {Status::PURE, Status::PURE}; auto id = var->getId(); auto it1 = vua.varCounts.find(id); auto it2 = vua.varAssignCounts.find(id); auto count1 = (it1 != vua.varCounts.end()) ? it1->second : 0; auto count2 = (it2 != vua.varAssignCounts.end()) ? it2->second : 0; bool global = var->isGlobal(); bool used = (count1 != count2); Status defaultStatus = global ? Status::UNKNOWN : Status::NO_CAPTURE; auto se2stat = [&](bool b) { return b ? 
defaultStatus : Status::PURE; }; if (globalAssignmentHasSideEffects || var->isExternal()) { return {se2stat(used || global), se2stat(global)}; } else { return {se2stat(used), se2stat(used && global)}; } } void handleVarAssign(const Value *v, const Var *var, Status base) { auto pair = getVarAssignStatus(var); set(v, max(pair.first, base), pair.second); } void visit(const Module *v) override { process(v->getMainFunc()); for (auto *x : *v) { process(x); } } void visit(const Var *v) override { set(v, Status::PURE); } void visit(const BodiedFunc *v) override { bool force; auto s = getFunctionStatusFromAttributes(v, &force); set(v, s, s); // avoid infinite recursion auto oldFuncStatus = funcStatus; funcStatus = Status::PURE; funcStack.push_back(v); process(v->getBody()); funcStack.pop_back(); if (force) funcStatus = s; set(v, funcStatus); funcStatus = oldFuncStatus; } void visit(const ExternalFunc *v) override { set(v, getFunctionStatusFromAttributes(v)); } void visit(const InternalFunc *v) override { set(v, Status::PURE); } void visit(const LLVMFunc *v) override { set(v, getFunctionStatusFromAttributes(v)); } void visit(const VarValue *v) override { set(v, Status::PURE); } void visit(const PointerValue *v) override { set(v, Status::PURE); } void visit(const SeriesFlow *v) override { Status s = Status::PURE; for (auto *x : *v) { s = max(s, process(x)); } set(v, s); } void visit(const IfFlow *v) override { set(v, max(process(v->getCond()), process(v->getTrueBranch()), process(v->getFalseBranch()))); } void visit(const WhileFlow *v) override { set(v, max(process(v->getCond()), process(v->getBody()))); } void visit(const ForFlow *v) override { auto s = max(process(v->getIter()), process(v->getBody())); if (auto *sched = v->getSchedule()) { for (auto *x : sched->getUsedValues()) { s = max(s, process(x)); } } handleVarAssign(v, v->getVar(), s); } void visit(const ImperativeForFlow *v) override { auto s = max(process(v->getStart()), process(v->getEnd()), 
process(v->getBody())); if (auto *sched = v->getSchedule()) { for (auto *x : sched->getUsedValues()) { s = max(s, process(x)); } } handleVarAssign(v, v->getVar(), s); } void visit(const TryCatchFlow *v) override { auto s = max(process(v->getBody()), process(v->getElse()), process(v->getFinally())); auto callStatus = Status::PURE; for (auto &x : *v) { auto pair = getVarAssignStatus(x.getVar()); s = max(s, pair.first, process(x.getHandler())); callStatus = max(callStatus, pair.second); } set(v, s, callStatus); } void visit(const PipelineFlow *v) override { auto s = Status::PURE; auto callStatus = Status::PURE; for (auto &stage : *v) { // make sure we're treating this as a call if (auto *f = util::getFunc(stage.getCallee())) { auto stageCallStatus = process(f); callStatus = max(callStatus, stageCallStatus); s = max(s, stageCallStatus); } else { // unknown function process(stage.getCallee()); callStatus = Status::UNKNOWN; s = Status::UNKNOWN; } for (auto *arg : stage) { s = max(s, process(arg)); } } set(v, s, callStatus); } void visit(const dsl::CustomFlow *v) override { set(v, v->getSideEffectStatus(/*local=*/true), v->getSideEffectStatus(/*local=*/false)); } void visit(const IntConst *v) override { set(v, Status::PURE); } void visit(const FloatConst *v) override { set(v, Status::PURE); } void visit(const BoolConst *v) override { set(v, Status::PURE); } void visit(const StringConst *v) override { set(v, Status::PURE); } void visit(const dsl::CustomConst *v) override { set(v, Status::PURE); } void visit(const AssignInstr *v) override { handleVarAssign(v, v->getLhs(), process(v->getRhs())); } void visit(const ExtractInstr *v) override { set(v, process(v->getVal())); } void visit(const InsertInstr *v) override { process(v->getLhs()); process(v->getRhs()); auto *func = funcStack.back(); auto it = cr->results.find(func->getId()); seqassertn(it != cr->results.end(), "function not found in capture results"); auto captureInfo = it->second; bool pure = true; for (auto &info : 
captureInfo) { if (info.externCaptures || info.modified || !info.argCaptures.empty()) { pure = false; break; } } if (pure) { // make sure the lhs does not escape auto escapeInfo = escapes(func, v->getLhs(), cr); pure = (!escapeInfo || (escapeInfo.returnCaptures && !escapeInfo.externCaptures && escapeInfo.argCaptures.empty())); } set(v, Status::UNKNOWN, pure ? Status::PURE : Status::UNKNOWN); } void visit(const CallInstr *v) override { auto s = process(v->getCallee()); auto callStatus = Status::UNKNOWN; for (auto *x : *v) { s = max(s, process(x)); } if (auto *f = util::getFunc(v->getCallee())) { callStatus = process(f); s = max(s, callStatus); } else { // unknown function s = Status::UNKNOWN; } set(v, s, callStatus); } void visit(const StackAllocInstr *v) override { set(v, Status::PURE); } void visit(const TypePropertyInstr *v) override { set(v, Status::PURE); } void visit(const YieldInInstr *v) override { set(v, Status::NO_CAPTURE); } void visit(const TernaryInstr *v) override { set(v, max(process(v->getCond()), process(v->getTrueValue()), process(v->getFalseValue()))); } void visit(const BreakInstr *v) override { set(v, Status::NO_CAPTURE); } void visit(const ContinueInstr *v) override { set(v, Status::NO_CAPTURE); } void visit(const ReturnInstr *v) override { set(v, max(Status::NO_CAPTURE, process(v->getValue()))); } void visit(const YieldInstr *v) override { set(v, max(Status::NO_CAPTURE, process(v->getValue()))); } void visit(const AwaitInstr *v) override { set(v, max(Status::NO_CAPTURE, process(v->getValue()))); } void visit(const ThrowInstr *v) override { process(v->getValue()); set(v, Status::UNKNOWN, Status::NO_CAPTURE); } void visit(const FlowInstr *v) override { set(v, max(process(v->getFlow()), process(v->getValue()))); } void visit(const dsl::CustomInstr *v) override { set(v, v->getSideEffectStatus(/*local=*/true), v->getSideEffectStatus(/*local=*/false)); } }; } // namespace const std::string SideEffectAnalysis::KEY = "core-analyses-side-effect"; bool 
SideEffectResult::hasSideEffect(const Value *v) const { auto it = result.find(v->getId()); return it == result.end() || it->second != util::SideEffectStatus::PURE; } std::unique_ptr SideEffectAnalysis::run(const Module *m) { auto *capResult = getAnalysisResult(capAnalysisKey); VarUseAnalyzer vua; const_cast(m)->accept(vua); SideEfectAnalyzer sea(vua, capResult, globalAssignmentHasSideEffects); m->accept(sea); return std::make_unique(sea.result); } } // namespace module } // namespace analyze } // namespace ir } // namespace codon ================================================ FILE: codon/cir/analyze/module/side_effect.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include "codon/cir/analyze/analysis.h" #include "codon/cir/util/side_effect.h" namespace codon { namespace ir { namespace analyze { namespace module { struct SideEffectResult : public Result { /// mapping of ID to corresponding node's side effect status std::unordered_map result; SideEffectResult(std::unordered_map result) : result(std::move(result)) {} /// @param v the value to check /// @return true if the node has side effects (false positives allowed) bool hasSideEffect(const Value *v) const; }; class SideEffectAnalysis : public Analysis { private: /// the capture analysis key std::string capAnalysisKey; /// true if assigning to a global variable automatically has side effects bool globalAssignmentHasSideEffects; public: static const std::string KEY; /// Constructs a side effect analysis. 
/// @param capAnalysisKey the capture analysis key /// @param globalAssignmentHasSideEffects true if global variable assignment /// automatically has side effects explicit SideEffectAnalysis(const std::string &capAnalysisKey, bool globalAssignmentHasSideEffects = true) : Analysis(), capAnalysisKey(capAnalysisKey), globalAssignmentHasSideEffects(globalAssignmentHasSideEffects) {} std::string getKey() const override { return KEY; } std::unique_ptr run(const Module *m) override; }; } // namespace module } // namespace analyze } // namespace ir } // namespace codon ================================================ FILE: codon/cir/attribute.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "attribute.h" #include "codon/cir/func.h" #include "codon/cir/util/cloning.h" #include "codon/cir/value.h" #include namespace codon { namespace ir { /* Attribute IDs are the keys of a node's attribute store, so every ID must be unique. */ const int StringValueAttribute::AttributeID = 101; const int IntValueAttribute::AttributeID = 102; const int StringListAttribute::AttributeID = 103; const int KeyValueAttribute::AttributeID = 104; const int MemberAttribute::AttributeID = 105; const int PythonWrapperAttribute::AttributeID = 106; const int SrcInfoAttribute::AttributeID = 107; const int DocstringAttribute::AttributeID = 108; const int TupleLiteralAttribute::AttributeID = 109; const int ListLiteralAttribute::AttributeID = 110; const int SetLiteralAttribute::AttributeID = 111; const int DictLiteralAttribute::AttributeID = 112; const int PartialFunctionAttribute::AttributeID = 113; std::ostream &StringListAttribute::doFormat(std::ostream &os) const { fmt::print(os, FMT_STRING("{}"), fmt::join(values.begin(), values.end(), ",")); return os; } bool KeyValueAttribute::has(const std::string &key) const { return attributes.find(key) != attributes.end(); } std::string KeyValueAttribute::get(const std::string &key) const { auto it = attributes.find(key); return it != attributes.end() ?
it->second : ""; } std::ostream &KeyValueAttribute::doFormat(std::ostream &os) const { std::vector keys; for (auto &val : attributes) keys.push_back(val.second); fmt::print(os, FMT_STRING("{}"), fmt::join(keys.begin(), keys.end(), ",")); return os; } std::ostream &MemberAttribute::doFormat(std::ostream &os) const { std::vector strings; for (auto &val : memberSrcInfo) strings.push_back(fmt::format(FMT_STRING("{}={}"), val.first, val.second)); fmt::print(os, FMT_STRING("({})"), fmt::join(strings.begin(), strings.end(), ",")); return os; } std::unique_ptr PythonWrapperAttribute::clone(util::CloneVisitor &cv) const { return std::make_unique(cast(cv.clone(original))); } std::unique_ptr PythonWrapperAttribute::forceClone(util::CloneVisitor &cv) const { return std::make_unique(cv.forceClone(original)); } std::ostream &PythonWrapperAttribute::doFormat(std::ostream &os) const { fmt::print(os, FMT_STRING("(pywrap {})"), original->referenceString()); return os; } std::unique_ptr TupleLiteralAttribute::clone(util::CloneVisitor &cv) const { std::vector elementsCloned; for (auto *val : elements) elementsCloned.push_back(cv.clone(val)); return std::make_unique(elementsCloned); } std::unique_ptr TupleLiteralAttribute::forceClone(util::CloneVisitor &cv) const { std::vector elementsCloned; for (auto *val : elements) elementsCloned.push_back(cv.forceClone(val)); return std::make_unique(elementsCloned); } std::ostream &TupleLiteralAttribute::doFormat(std::ostream &os) const { std::vector strings; for (auto *val : elements) strings.push_back(fmt::format(FMT_STRING("{}"), *val)); fmt::print(os, FMT_STRING("({})"), fmt::join(strings.begin(), strings.end(), ",")); return os; } std::unique_ptr ListLiteralAttribute::clone(util::CloneVisitor &cv) const { std::vector elementsCloned; for (auto &e : elements) elementsCloned.push_back({cv.clone(e.value), e.star}); return std::make_unique(elementsCloned); } std::unique_ptr ListLiteralAttribute::forceClone(util::CloneVisitor &cv) const { 
std::vector elementsCloned; for (auto &e : elements) elementsCloned.push_back({cv.forceClone(e.value), e.star}); return std::make_unique(elementsCloned); } std::ostream &ListLiteralAttribute::doFormat(std::ostream &os) const { std::vector strings; for (auto &e : elements) strings.push_back(fmt::format(FMT_STRING("{}{}"), e.star ? "*" : "", *e.value)); fmt::print(os, FMT_STRING("[{}]"), fmt::join(strings.begin(), strings.end(), ",")); return os; } std::unique_ptr SetLiteralAttribute::clone(util::CloneVisitor &cv) const { std::vector elementsCloned; for (auto &e : elements) elementsCloned.push_back({cv.clone(e.value), e.star}); return std::make_unique(elementsCloned); } std::unique_ptr SetLiteralAttribute::forceClone(util::CloneVisitor &cv) const { std::vector elementsCloned; for (auto &e : elements) elementsCloned.push_back({cv.forceClone(e.value), e.star}); return std::make_unique(elementsCloned); } std::ostream &SetLiteralAttribute::doFormat(std::ostream &os) const { std::vector strings; for (auto &e : elements) strings.push_back(fmt::format(FMT_STRING("{}{}"), e.star ? "*" : "", *e.value)); fmt::print(os, FMT_STRING("set([{}])"), fmt::join(strings.begin(), strings.end(), ",")); return os; } std::unique_ptr DictLiteralAttribute::clone(util::CloneVisitor &cv) const { std::vector elementsCloned; for (auto &val : elements) elementsCloned.push_back( {cv.clone(val.key), val.value ? cv.clone(val.value) : nullptr}); return std::make_unique(elementsCloned); } std::unique_ptr DictLiteralAttribute::forceClone(util::CloneVisitor &cv) const { std::vector elementsCloned; for (auto &val : elements) elementsCloned.push_back( {cv.forceClone(val.key), val.value ? 
cv.forceClone(val.value) : nullptr}); return std::make_unique(elementsCloned); } std::ostream &DictLiteralAttribute::doFormat(std::ostream &os) const { std::vector strings; for (auto &val : elements) { if (val.value) { strings.push_back(fmt::format(FMT_STRING("{}:{}"), *val.key, *val.value)); } else { strings.push_back(fmt::format(FMT_STRING("**{}"), *val.key)); } } fmt::print(os, FMT_STRING("dict([{}])"), fmt::join(strings.begin(), strings.end(), ",")); return os; } std::unique_ptr PartialFunctionAttribute::clone(util::CloneVisitor &cv) const { std::vector argsCloned; for (auto *val : args) argsCloned.push_back(cv.clone(val)); return std::make_unique(name, argsCloned); } std::unique_ptr PartialFunctionAttribute::forceClone(util::CloneVisitor &cv) const { std::vector argsCloned; for (auto *val : args) argsCloned.push_back(cv.forceClone(val)); return std::make_unique(name, argsCloned); } std::ostream &PartialFunctionAttribute::doFormat(std::ostream &os) const { std::vector strings; for (auto *val : args) strings.push_back(val ? fmt::format(FMT_STRING("{}"), *val) : "..."); fmt::print(os, FMT_STRING("{}({})"), name, fmt::join(strings.begin(), strings.end(), ",")); return os; } } // namespace ir std::unordered_map> clone(const std::unordered_map> &t) { std::unordered_map> r; for (auto &[k, v] : t) r[k] = v->clone(); return r; } } // namespace codon ================================================ FILE: codon/cir/attribute.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include #include #include #include "codon/util/common.h" namespace codon { namespace ir { class Func; class Value; namespace util { class CloneVisitor; } /// Base for CIR attributes. 
struct Attribute { virtual ~Attribute() noexcept = default; /// @return true if the attribute should be propagated across clones virtual bool needsClone() const { return true; } friend std::ostream &operator<<(std::ostream &os, const Attribute &a) { return a.doFormat(os); } /// @return a clone of the attribute virtual std::unique_ptr clone() const { return std::make_unique(); } /// @param cv the clone visitor /// @return a clone of the attribute virtual std::unique_ptr clone(util::CloneVisitor &cv) const { return clone(); } /// @param cv the clone visitor /// @return a clone of the attribute virtual std::unique_ptr forceClone(util::CloneVisitor &cv) const { return clone(cv); } private: virtual std::ostream &doFormat(std::ostream &os) const { return os; } }; /// Attribute containing SrcInfo struct SrcInfoAttribute : public Attribute { static const int AttributeID; /// source info codon::SrcInfo info; SrcInfoAttribute() = default; /// Constructs a SrcInfoAttribute. /// @param info the source info explicit SrcInfoAttribute(codon::SrcInfo info) : info(std::move(info)) {} std::unique_ptr clone() const override { return std::make_unique(*this); } private: std::ostream &doFormat(std::ostream &os) const override { return os << info; } }; /// Attribute containing a single string value struct StringValueAttribute : public Attribute { static const int AttributeID; /// the string value std::string value; StringValueAttribute() = default; /// Constructs a StringValueAttribute. /// @param value the string value explicit StringValueAttribute(const std::string &value) : value(value) {} std::unique_ptr clone() const override { return std::make_unique(*this); } private: std::ostream &doFormat(std::ostream &os) const override { return os << value; } }; /// Attribute containing docstring from source struct DocstringAttribute : public Attribute { static const int AttributeID; /// the docstring std::string docstring; DocstringAttribute() = default; /// Constructs a DocstringAttribute.
/// @param docstring the docstring explicit DocstringAttribute(const std::string &docstring) : docstring(docstring) {} std::unique_ptr clone() const override { return std::make_unique(*this); } private: std::ostream &doFormat(std::ostream &os) const override { return os << docstring; } }; /// Attribute containing key-value string pairs struct KeyValueAttribute : public Attribute { static const int AttributeID; /// attributes map std::unordered_map attributes; KeyValueAttribute() = default; /// Constructs a KeyValueAttribute. /// @param attributes the map of attributes explicit KeyValueAttribute(std::unordered_map attributes) : attributes(std::move(attributes)) {} /// @param key the key /// @return true if the map contains key, false otherwise bool has(const std::string &key) const; /// @param key the key /// @return the value associated with the given key, or empty /// string if none std::string get(const std::string &key) const; std::unique_ptr clone() const override { return std::make_unique(*this); } private: std::ostream &doFormat(std::ostream &os) const override; }; /// Attribute containing a list of values struct StringListAttribute : public Attribute { static const int AttributeID; /// the list of values std::vector values; StringListAttribute() = default; /// Constructs a StringListAttribute. /// @param values the list of values explicit StringListAttribute(std::vector values) : values(std::move(values)) {} std::unique_ptr clone() const override { return std::make_unique(*this); } private: std::ostream &doFormat(std::ostream &os) const override; }; /// Attribute containing type member information struct MemberAttribute : public Attribute { static const int AttributeID; /// member source info map std::map memberSrcInfo; MemberAttribute() = default; /// Constructs a MemberAttribute.
/// @param attributes the map of attributes explicit MemberAttribute(std::map memberSrcInfo) : memberSrcInfo(std::move(memberSrcInfo)) {} std::unique_ptr clone() const override { return std::make_unique(*this); } private: std::ostream &doFormat(std::ostream &os) const override; }; /// Attribute used to mark Python wrappers of Codon functions struct PythonWrapperAttribute : public Attribute { static const int AttributeID; /// the function being wrapped Func *original; /// Constructs a PythonWrapperAttribute. /// @param original the function being wrapped explicit PythonWrapperAttribute(Func *original) : original(original) {} bool needsClone() const override { return false; } std::unique_ptr clone() const override { seqassertn(false, "cannot operate without CloneVisitor"); return nullptr; } std::unique_ptr clone(util::CloneVisitor &cv) const override; std::unique_ptr forceClone(util::CloneVisitor &cv) const override; private: std::ostream &doFormat(std::ostream &os) const override; }; /// Attribute attached to IR structures corresponding to tuple literals struct TupleLiteralAttribute : public Attribute { static const int AttributeID; /// values contained in tuple literal std::vector elements; explicit TupleLiteralAttribute(std::vector elements) : elements(std::move(elements)) {} std::unique_ptr clone() const override { seqassertn(false, "cannot operate without CloneVisitor"); return nullptr; } std::unique_ptr clone(util::CloneVisitor &cv) const override; std::unique_ptr forceClone(util::CloneVisitor &cv) const override; private: std::ostream &doFormat(std::ostream &os) const override; }; /// Information about an element in a collection literal struct LiteralElement { /// the element value Value *value; /// true if preceded by "*", as in "[*x]" bool star; }; /// Attribute attached to IR structures corresponding to list literals struct ListLiteralAttribute : public Attribute { static const int AttributeID; /// elements contained in list literal std::vector elements; 
explicit ListLiteralAttribute(std::vector elements) : elements(std::move(elements)) {} std::unique_ptr clone() const override { seqassertn(false, "cannot operate without CloneVisitor"); return nullptr; } std::unique_ptr clone(util::CloneVisitor &cv) const override; std::unique_ptr forceClone(util::CloneVisitor &cv) const override; private: std::ostream &doFormat(std::ostream &os) const override; }; /// Attribute attached to IR structures corresponding to set literals struct SetLiteralAttribute : public Attribute { static const int AttributeID; /// elements contained in set literal std::vector elements; explicit SetLiteralAttribute(std::vector elements) : elements(std::move(elements)) {} std::unique_ptr clone() const override { seqassertn(false, "cannot operate without CloneVisitor"); return nullptr; } std::unique_ptr clone(util::CloneVisitor &cv) const override; std::unique_ptr forceClone(util::CloneVisitor &cv) const override; private: std::ostream &doFormat(std::ostream &os) const override; }; /// Attribute attached to IR structures corresponding to dict literals struct DictLiteralAttribute : public Attribute { struct KeyValuePair { /// the key in the literal Value *key; /// the value in the literal, or null if key is being star-unpacked Value *value; }; static const int AttributeID; /// keys and values contained in dict literal std::vector elements; explicit DictLiteralAttribute(std::vector elements) : elements(std::move(elements)) {} std::unique_ptr clone() const override { seqassertn(false, "cannot operate without CloneVisitor"); return nullptr; } std::unique_ptr clone(util::CloneVisitor &cv) const override; std::unique_ptr forceClone(util::CloneVisitor &cv) const override; private: std::ostream &doFormat(std::ostream &os) const override; }; /// Attribute attached to IR structures corresponding to partial functions struct PartialFunctionAttribute : public Attribute { static const int AttributeID; /// base name of the function being used in the partial 
std::string name; /// partial arguments, or null if none /// e.g. "f(a, ..., b)" has elements [a, null, b] std::vector args; PartialFunctionAttribute(const std::string &name, std::vector args) : name(name), args(std::move(args)) {} std::unique_ptr clone() const override { seqassertn(false, "cannot operate without CloneVisitor"); return nullptr; } std::unique_ptr clone(util::CloneVisitor &cv) const override; std::unique_ptr forceClone(util::CloneVisitor &cv) const override; private: std::ostream &doFormat(std::ostream &os) const override; }; struct IntValueAttribute : public Attribute { static const int AttributeID; int64_t value; IntValueAttribute() = default; /// Constructs a IntValueAttribute. explicit IntValueAttribute(int64_t value) : value(value) {} std::unique_ptr clone() const override { return std::make_unique(*this); } private: std::ostream &doFormat(std::ostream &os) const override { return os << value; } }; } // namespace ir std::unordered_map> clone(const std::unordered_map> &t); } // namespace codon template <> struct fmt::formatter : fmt::ostream_formatter {}; ================================================ FILE: codon/cir/base.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "base.h" #include "codon/cir/types/types.h" #include "codon/cir/util/format.h" #include "codon/cir/value.h" #include "codon/cir/var.h" namespace codon { namespace ir { id_t IdMixin::currentId = 0; void IdMixin::resetId() { currentId = 0; } const char Node::NodeId = 0; std::ostream &operator<<(std::ostream &os, const Node &other) { return util::format(os, &other); } Node::Node(const Node &n) : name(n.name), module(n.module), replacement(n.replacement), attributes(codon::clone(n.attributes)) {} int Node::replaceUsedValue(Value *old, Value *newValue) { return replaceUsedValue(old->getId(), newValue); } int Node::replaceUsedType(types::Type *old, types::Type *newType) { return replaceUsedType(old->getName(), newType); } int Node::replaceUsedVariable(Var *old, Var *newVar) { return replaceUsedVariable(old->getId(), newVar); } } // namespace ir } // namespace codon ================================================ FILE: codon/cir/base.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include #include #include "codon/cir/attribute.h" #include "codon/cir/util/iterators.h" #include "codon/cir/util/visitor.h" #include "codon/util/common.h" #include #include namespace codon { namespace ir { using id_t = std::int64_t; class Func; class Module; /// Mixin class for IR nodes that need ids. class IdMixin { private: /// the global id counter static id_t currentId; protected: /// the instance's id id_t id; public: /// Resets the global id counter. static void resetId(); IdMixin() : id(currentId++) {} /// @return the node's id. virtual id_t getId() const { return id; } }; /// Base for named IR nodes. 
class Node { private: /// the node's name std::string name; /// the module Module *module = nullptr; /// a replacement, if set Node *replacement = nullptr; protected: /// key-value attribute store std::unordered_map> attributes; public: // RTTI is implemented using a port of LLVM's Extensible RTTI // For more details, see // https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html#rtti-for-open-class-hierarchies static const char NodeId; /// Constructs a node. /// @param name the node's name explicit Node(std::string name = "") : name(std::move(name)) {} /// Constructs a node as a copy of another node; the name, module and /// replacement are copied and the attributes are cloned. /// @param n the node to copy explicit Node(const Node &n); /// See LLVM documentation. static const void *nodeId() { return &NodeId; } /// See LLVM documentation. virtual const void *dynamicNodeId() const = 0; /// See LLVM documentation. virtual bool isConvertible(const void *other) const { if (hasReplacement()) return getActual()->isConvertible(other); return other == nodeId(); } /// See LLVM documentation. template bool is() const { return isConvertible(Target::nodeId()); } /// See LLVM documentation. template Target *as() { return isConvertible(Target::nodeId()) ? static_cast(getActual()) : nullptr; } /// See LLVM documentation. template const Target *as() const { return isConvertible(Target::nodeId()) ? static_cast(getActual()) : nullptr; } /// @return the node's name const std::string &getName() const { return getActual()->name; } /// Sets the node's name /// @param n the new name void setName(std::string n) { getActual()->name = std::move(n); } /// Accepts visitors.
/// @param v the visitor virtual void accept(util::ConstVisitor &v) const {} /// Sets an attribute /// @param the attribute key /// @param value the attribute void setAttribute(std::unique_ptr value, int key) { getActual()->attributes[key] = std::move(value); } /// Sets an attribute /// @param value the attribute template void setAttribute(std::unique_ptr value) { setAttribute(std::move(value), AttributeType::AttributeID); } /// @param n attribute ID /// @return true if the attribute is in the store bool hasAttribute(int n) const { auto *actual = getActual(); return actual->attributes.find(n) != actual->attributes.end(); } /// @return true if the attribute is in the store template bool hasAttribute() const { return hasAttribute(AttributeType::AttributeID); } /// Gets the appropriate attribute. /// @param key the attribute key Attribute *getAttribute(int key) { auto *actual = getActual(); auto it = actual->attributes.find(key); return it != actual->attributes.end() ? it->second.get() : nullptr; } /// Gets the appropriate attribute. /// @param key the attribute key const Attribute *getAttribute(int key) const { auto *actual = getActual(); auto it = actual->attributes.find(key); return it != actual->attributes.end() ? it->second.get() : nullptr; } /// Gets the appropriate attribute. /// @tparam AttributeType the return type template AttributeType *getAttribute() { return static_cast(getAttribute(AttributeType::AttributeID)); } /// Gets the appropriate attribute. 
/// @tparam AttributeType the return type template const AttributeType *getAttribute() const { return static_cast(getAttribute(AttributeType::AttributeID)); } template AttributeType *getAttribute(int key) { return static_cast(getAttribute(key)); } template const AttributeType *getAttribute(int key) const { return static_cast(getAttribute(key)); } /// Removes an attribute from the store, if present. Like the other attribute /// accessors, this operates on the node's replacement when one is set. /// @param key the attribute key void eraseAttribute(int key) { getActual()->attributes.erase(key); } /// Replaces this node's attributes with clones of another node's attributes. /// Both nodes are resolved through their replacements, matching the getters. /// @param n the node whose attributes are cloned void cloneAttributesFrom(Node *n) { getActual()->attributes = codon::clone(n->getActual()->attributes); } /// @return iterator to the first attribute auto attributes_begin() { return util::map_key_adaptor(getActual()->attributes.begin()); } /// @return iterator beyond the last attribute auto attributes_end() { return util::map_key_adaptor(getActual()->attributes.end()); } /// @return iterator to the first attribute auto attributes_begin() const { return util::const_map_key_adaptor(getActual()->attributes.begin()); } /// @return iterator beyond the last attribute auto attributes_end() const { return util::const_map_key_adaptor(getActual()->attributes.end()); } /// Helper to add source information. /// @param s the source information void setSrcInfo(codon::SrcInfo s) { setAttribute(std::make_unique(std::move(s))); } /// @return the src info codon::SrcInfo getSrcInfo() const { auto a = getAttribute(); return a ? a->info : codon::SrcInfo(); } /// @return a text representation of a reference to the object virtual std::string referenceString() const { return getActual()->name; } /// @return the IR module Module *getModule() const { return getActual()->module; } /// Sets the module.
/// @param m the new module void setModule(Module *m) { getActual()->module = m; } friend std::ostream &operator<<(std::ostream &os, const Node &a); /// @return true if a replacement has been set for this node bool hasReplacement() const { return replacement != nullptr; } /// @return a vector of all the node's children virtual std::vector getUsedValues() { return {}; } /// @return a vector of all the node's children virtual std::vector getUsedValues() const { return {}; } /// Physically replaces all instances of a child value. /// @param id the id of the value to be replaced /// @param newValue the new value /// @return number of replacements virtual int replaceUsedValue(id_t id, Value *newValue) { return 0; } /// Physically replaces all instances of a child value. /// @param old the old value /// @param newValue the new value /// @return number of replacements int replaceUsedValue(Value *old, Value *newValue); /// @return a vector of all the node's used types virtual std::vector getUsedTypes() const { return {}; } /// Physically replaces all instances of a used type. /// @param name the name of the type being replaced /// @param newType the new type /// @return number of replacements virtual int replaceUsedType(const std::string &name, types::Type *newType) { return 0; } /// Physically replaces all instances of a used type. /// @param old the old type /// @param newType the new type /// @return number of replacements int replaceUsedType(types::Type *old, types::Type *newType); /// @return a vector of all the node's used variables virtual std::vector getUsedVariables() { return {}; } /// @return a vector of all the node's used variables virtual std::vector getUsedVariables() const { return {}; } /// Physically replaces all instances of a used variable. /// @param id the id of the variable /// @param newVar the new variable /// @return number of replacements virtual int replaceUsedVariable(id_t id, Var *newVar) { return 0; } /// Physically replaces all instances of a used variable.
/// @param old the old variable /// @param newVar the new variable /// @return number of replacements int replaceUsedVariable(Var *old, Var *newVar); template friend class AcceptorExtend; template friend class ReplaceableNodeBase; private: Node *getActual() { return replacement ? replacement->getActual() : this; } const Node *getActual() const { return replacement ? replacement->getActual() : this; } }; template class AcceptorExtend : public Parent { public: using Parent::Parent; /// See LLVM documentation. static const void *nodeId() { return &Derived::NodeId; } /// See LLVM documentation. const void *dynamicNodeId() const override { return &Derived::NodeId; } /// See LLVM documentation. virtual bool isConvertible(const void *other) const override { if (Node::hasReplacement()) return Node::getActual()->isConvertible(other); return other == nodeId() || Parent::isConvertible(other); } void accept(util::Visitor &v) override { if (Node::hasReplacement()) Node::getActual()->accept(v); else v.visit(static_cast(this)); } void accept(util::ConstVisitor &v) const override { if (Node::hasReplacement()) Node::getActual()->accept(v); else v.visit(static_cast(this)); } }; template class ReplaceableNodeBase : public AcceptorExtend { private: /// true if the node can be lazily replaced bool replaceable = true; public: using AcceptorExtend::AcceptorExtend; static const char NodeId; /// @return the logical value of the node Derived *getActual() { return Node::replacement ? static_cast(Node::replacement)->getActual() : static_cast(this); } /// @return the logical value of the node const Derived *getActual() const { return Node::replacement ? static_cast(Node::replacement)->getActual() : static_cast(this); } /// Lazily replaces all instances of the node. 
/// @param v the new value void replaceAll(Derived *v) { /* the flag checked belongs to this node, so report this node (not the replacement) on failure */ seqassertn(replaceable, "node {} not replaceable", *static_cast<Derived *>(this)); Node::replacement = v; } /// @return true if the object can be replaced bool isReplaceable() const { return replaceable; } /// Sets the object's replaceable flag. /// @param v the new value void setReplaceable(bool v = true) { replaceable = v; } }; template const char ReplaceableNodeBase::NodeId = 0; template Desired *cast(Node *other) { return other != nullptr ? other->as() : nullptr; } template const Desired *cast(const Node *other) { return other != nullptr ? other->as() : nullptr; } template bool isA(Node *other) { return other && other->is(); } template bool isA(const Node *other) { return other && other->is(); } } // namespace ir } // namespace codon template <> struct fmt::formatter : fmt::ostream_formatter {}; ================================================ FILE: codon/cir/cir.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/const.h" #include "codon/cir/dsl/nodes.h" #include "codon/cir/flow.h" #include "codon/cir/func.h" #include "codon/cir/instr.h" #include "codon/cir/module.h" #include "codon/cir/types/types.h" #include "codon/cir/value.h" #include "codon/cir/var.h" ================================================ FILE: codon/cir/const.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "const.h" namespace codon { namespace ir { const char Const::NodeId = 0; int Const::doReplaceUsedType(const std::string &name, types::Type *newType) { if (type->getName() == name) { type = newType; return 1; } return 0; } const char TemplatedConst::NodeId = 0; } // namespace ir } // namespace codon ================================================ FILE: codon/cir/const.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc.
#pragma once #include "codon/cir/module.h" #include "codon/cir/value.h" namespace codon { namespace ir { /// CIR constant base. Once created, constants are immutable. class Const : public AcceptorExtend { private: /// the type types::Type *type; public: static const char NodeId; /// Constructs a constant. /// @param type the type /// @param name the name explicit Const(types::Type *type, std::string name = "") : AcceptorExtend(std::move(name)), type(type) {} private: types::Type *doGetType() const override { return type; } std::vector doGetUsedTypes() const override { return {type}; } int doReplaceUsedType(const std::string &name, types::Type *newType) override; }; template class TemplatedConst : public AcceptorExtend, Const> { private: ValueType val; public: static const char NodeId; using AcceptorExtend, Const>::getModule; using AcceptorExtend, Const>::getSrcInfo; using AcceptorExtend, Const>::getType; TemplatedConst(ValueType v, types::Type *type, std::string name = "") : AcceptorExtend, Const>(type, std::move(name)), val(v) {} /// @return the internal value. ValueType getVal() const { return val; } /// Sets the value. /// @param v the value void setVal(ValueType v) { val = v; } }; using IntConst = TemplatedConst; using FloatConst = TemplatedConst; using BoolConst = TemplatedConst; using StringConst = TemplatedConst; template const char TemplatedConst::NodeId = 0; template <> class TemplatedConst : public AcceptorExtend, Const> { private: std::string val; public: static const char NodeId; TemplatedConst(std::string v, types::Type *type, std::string name = "") : AcceptorExtend(type, std::move(name)), val(std::move(v)) {} /// @return the internal value. std::string getVal() const { return val; } /// Sets the value. 
/// @param v the value void setVal(std::string v) { val = std::move(v); } }; } // namespace ir } // namespace codon ================================================ FILE: codon/cir/dsl/codegen.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include "codon/cir/llvm/llvm.h" #include "codon/cir/types/types.h" namespace codon { namespace ir { namespace analyze { namespace dataflow { class CFVisitor; } // namespace dataflow } // namespace analyze class LLVMVisitor; namespace dsl { namespace codegen { /// Builder for LLVM types. struct TypeBuilder { virtual ~TypeBuilder() noexcept = default; /// Construct the LLVM type. /// @param the LLVM visitor /// @return the LLVM type virtual llvm::Type *buildType(LLVMVisitor *visitor) = 0; /// Construct the LLVM debug type. /// @param the LLVM visitor /// @return the LLVM debug type virtual llvm::DIType *buildDebugType(LLVMVisitor *visitor) = 0; }; /// Builder for LLVM values. struct ValueBuilder { virtual ~ValueBuilder() noexcept = default; /// Construct the LLVM value. /// @param the LLVM visitor /// @return the LLVM value virtual llvm::Value *buildValue(LLVMVisitor *visitor) = 0; }; /// Builder for control flow graphs. struct CFBuilder { virtual ~CFBuilder() noexcept = default; /// Construct the control-flow nodes. /// @param graph the graph virtual void buildCFNodes(analyze::dataflow::CFVisitor *visitor) = 0; }; } // namespace codegen } // namespace dsl } // namespace ir } // namespace codon ================================================ FILE: codon/cir/dsl/nodes.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "nodes.h" namespace codon { namespace ir { namespace dsl { namespace types { const char CustomType::NodeId = 0; } const char CustomConst::NodeId = 0; const char CustomFlow::NodeId = 0; const char CustomInstr::NodeId = 0; } // namespace dsl } // namespace ir } // namespace codon ================================================ FILE: codon/cir/dsl/nodes.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include "codon/cir/base.h" #include "codon/cir/const.h" #include "codon/cir/instr.h" #include "codon/cir/util/side_effect.h" namespace codon { namespace ir { namespace util { class CloneVisitor; } // namespace util namespace dsl { namespace codegen { struct CFBuilder; struct TypeBuilder; struct ValueBuilder; } // namespace codegen namespace types { /// DSL type. class CustomType : public AcceptorExtend { public: static const char NodeId; using AcceptorExtend::AcceptorExtend; /// @return the type builder virtual std::unique_ptr getBuilder() const = 0; /// Compares DSL nodes. /// @param v the other node /// @return true if they match virtual bool match(const Type *v) const = 0; /// Format the DSL node. /// @param os the output stream virtual std::ostream &doFormat(std::ostream &os) const = 0; }; } // namespace types /// DSL constant. class CustomConst : public AcceptorExtend { public: static const char NodeId; using AcceptorExtend::AcceptorExtend; /// @return the value builder virtual std::unique_ptr getBuilder() const = 0; /// Compares DSL nodes. /// @param v the other node /// @return true if they match virtual bool match(const Value *v) const = 0; /// Clones the value. /// @param cv the clone visitor /// @return a clone of the object virtual Value *doClone(util::CloneVisitor &cv) const = 0; /// Format the DSL node. /// @param os the output stream virtual std::ostream &doFormat(std::ostream &os) const = 0; }; /// DSL flow. 
class CustomFlow : public AcceptorExtend { public: static const char NodeId; using AcceptorExtend::AcceptorExtend; /// @return the value builder virtual std::unique_ptr getBuilder() const = 0; /// Compares DSL nodes. /// @param v the other node /// @return true if they match virtual bool match(const Value *v) const = 0; /// Clones the value. /// @param cv the clone visitor /// @return a clone of the object virtual Value *doClone(util::CloneVisitor &cv) const = 0; /// @return the control-flow builder virtual std::unique_ptr getCFBuilder() const = 0; /// Query this custom node for its side effect properties. If "local" /// is true, then the return value should reflect this node and this /// node alone, otherwise the value should reflect functions containing /// this node in their bodies. For example, a "break" instruction has /// side effects locally, but functions containing "break" might still /// be side effect free, hence the distinction. /// @param local true if result should reflect only this node /// @return this node's side effect status virtual util::SideEffectStatus getSideEffectStatus(bool local = true) const { return util::SideEffectStatus::UNKNOWN; } /// Format the DSL node. /// @param os the output stream virtual std::ostream &doFormat(std::ostream &os) const = 0; }; /// DSL instruction. class CustomInstr : public AcceptorExtend { public: static const char NodeId; using AcceptorExtend::AcceptorExtend; /// @return the value builder virtual std::unique_ptr getBuilder() const = 0; /// Compares DSL nodes. /// @param v the other node /// @return true if they match virtual bool match(const Value *v) const = 0; /// Clones the value. /// @param cv the clone visitor /// @return a clone of the object virtual Value *doClone(util::CloneVisitor &cv) const = 0; /// @return the control-flow builder virtual std::unique_ptr getCFBuilder() const = 0; /// Query this custom node for its side effect properties. 
If "local" /// is true, then the return value should reflect this node and this /// node alone, otherwise the value should reflect functions containing /// this node in their bodies. For example, a "break" instruction has /// side effects locally, but functions containing "break" might still /// be side effect free, hence the distinction. /// @param local true if result should reflect only this node /// @return this node's side effect status virtual util::SideEffectStatus getSideEffectStatus(bool local = true) const { return util::SideEffectStatus::UNKNOWN; } /// Format the DSL node. /// @param os the output stream virtual std::ostream &doFormat(std::ostream &os) const = 0; }; } // namespace dsl } // namespace ir } // namespace codon ================================================ FILE: codon/cir/flow.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "flow.h" #include "codon/cir/module.h" #include "codon/cir/util/iterators.h" #include namespace codon { namespace ir { namespace { int findAndReplace(id_t id, codon::ir::Value *newVal, std::list &values) { auto replacements = 0; for (auto &value : values) { if (value->getId() == id) { value = newVal; ++replacements; } } return replacements; } } // namespace const char Flow::NodeId = 0; types::Type *Flow::doGetType() const { return getModule()->getNoneType(); } const char SeriesFlow::NodeId = 0; int SeriesFlow::doReplaceUsedValue(id_t id, Value *newValue) { return findAndReplace(id, newValue, series); } const char WhileFlow::NodeId = 0; int WhileFlow::doReplaceUsedValue(id_t id, Value *newValue) { auto replacements = 0; if (cond->getId() == id) { cond = newValue; ++replacements; } if (body->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); body = f; ++replacements; } return replacements; } const char ForFlow::NodeId = 0; std::vector ForFlow::doGetUsedValues() const { std::vector ret; if (isParallel()) ret = 
getSchedule()->getUsedValues(); ret.push_back(iter); ret.push_back(body); return ret; } int ForFlow::doReplaceUsedValue(id_t id, Value *newValue) { auto count = 0; if (isParallel()) count += getSchedule()->replaceUsedValue(id, newValue); if (iter->getId() == id) { iter = newValue; ++count; } if (body->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); body = f; ++count; } return count; } int ForFlow::doReplaceUsedVariable(id_t id, Var *newVar) { if (var->getId() == id) { var = newVar; return 1; } return 0; } const char ImperativeForFlow::NodeId = 0; std::vector ImperativeForFlow::doGetUsedValues() const { std::vector ret; if (isParallel()) ret = getSchedule()->getUsedValues(); ret.push_back(start); ret.push_back(end); ret.push_back(body); return ret; } int ImperativeForFlow::doReplaceUsedValue(id_t id, Value *newValue) { auto count = 0; if (isParallel()) count += getSchedule()->replaceUsedValue(id, newValue); if (body->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); body = f; ++count; } if (start->getId() == id) { start = newValue; ++count; } if (end->getId() == id) { end = newValue; ++count; } return count; } int ImperativeForFlow::doReplaceUsedVariable(id_t id, Var *newVar) { if (var->getId() == id) { var = newVar; return 1; } return 0; } const char IfFlow::NodeId = 0; std::vector IfFlow::doGetUsedValues() const { std::vector ret = {cond, trueBranch}; if (falseBranch) ret.push_back(falseBranch); return ret; } int IfFlow::doReplaceUsedValue(id_t id, Value *newValue) { auto replacements = 0; if (cond->getId() == id) { cond = newValue; ++replacements; } if (trueBranch->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); trueBranch = f; ++replacements; } if (falseBranch && falseBranch->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); falseBranch = f; ++replacements; } return replacements; } const char 
TryCatchFlow::NodeId = 0; std::vector TryCatchFlow::doGetUsedValues() const { std::vector ret = {body}; if (else_) ret.push_back(else_); if (finally) ret.push_back(finally); for (auto &c : catches) ret.push_back(const_cast(static_cast(c.getHandler()))); return ret; } int TryCatchFlow::doReplaceUsedValue(id_t id, Value *newValue) { auto replacements = 0; if (body->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); body = f; ++replacements; } if (else_ && else_->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); else_ = f; ++replacements; } if (finally && finally->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); finally = f; ++replacements; } for (auto &c : catches) { if (c.getHandler()->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); c.setHandler(f); ++replacements; } } return replacements; } std::vector TryCatchFlow::doGetUsedTypes() const { std::vector ret; for (auto &c : catches) { if (auto *t = c.getType()) ret.push_back(const_cast(t)); } return ret; } int TryCatchFlow::doReplaceUsedType(const std::string &name, types::Type *newType) { auto count = 0; for (auto &c : catches) { if (c.getType()->getName() == name) { c.setType(newType); ++count; } } return count; } std::vector TryCatchFlow::doGetUsedVariables() const { std::vector ret; for (auto &c : catches) { if (auto *t = c.getVar()) ret.push_back(const_cast(t)); } return ret; } int TryCatchFlow::doReplaceUsedVariable(id_t id, Var *newVar) { auto count = 0; for (auto &c : catches) { if (c.getVar()->getId() == id) { c.setVar(newVar); ++count; } } return count; } const char PipelineFlow::NodeId = 0; types::Type *PipelineFlow::Stage::getOutputType() const { if (args.empty()) { return callee->getType(); } else { auto *funcType = cast(callee->getType()); seqassertn(funcType, "{} is not a function type", *callee->getType()); return funcType->getReturnType(); } } 
types::Type *PipelineFlow::Stage::getOutputElementType() const { if (isGenerator()) { types::GeneratorType *genType = nullptr; if (args.empty()) { genType = cast(callee->getType()); return genType->getBase(); } else { auto *funcType = cast(callee->getType()); seqassertn(funcType, "{} is not a function type", *callee->getType()); genType = cast(funcType->getReturnType()); } seqassertn(genType, "generator type not found"); return genType->getBase(); } else if (args.empty()) { return callee->getType(); } else { auto *funcType = cast(callee->getType()); seqassertn(funcType, "{} is not a function type", *callee->getType()); return funcType->getReturnType(); } } std::vector PipelineFlow::doGetUsedValues() const { std::vector ret; for (auto &s : stages) { ret.push_back(const_cast(s.getCallee())); for (auto *arg : s.args) if (arg) ret.push_back(arg); } return ret; } int PipelineFlow::doReplaceUsedValue(id_t id, Value *newValue) { auto replacements = 0; for (auto &c : stages) { if (c.getCallee()->getId() == id) { c.setCallee(newValue); ++replacements; } for (auto &s : c.args) if (s && s->getId() == id) { s = newValue; ++replacements; } } return replacements; } } // namespace ir } // namespace codon ================================================ FILE: codon/cir/flow.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include "codon/cir/base.h" #include "codon/cir/transform/parallel/schedule.h" #include "codon/cir/value.h" #include "codon/cir/var.h" namespace codon { namespace ir { /// Base for flows, which represent control flow. class Flow : public AcceptorExtend { public: static const char NodeId; using AcceptorExtend::AcceptorExtend; protected: types::Type *doGetType() const final; }; /// Flow that contains a series of flows or instructions. 
class SeriesFlow : public AcceptorExtend { private: std::list series; public: static const char NodeId; using AcceptorExtend::AcceptorExtend; /// @return an iterator to the first instruction/flow auto begin() { return series.begin(); } /// @return an iterator beyond the last instruction/flow auto end() { return series.end(); } /// @return an iterator to the first instruction/flow auto begin() const { return series.begin(); } /// @return an iterator beyond the last instruction/flow auto end() const { return series.end(); } /// @return a pointer to the first instruction/flow Value *front() { return series.front(); } /// @return a pointer to the last instruction/flow Value *back() { return series.back(); } /// @return a pointer to the first instruction/flow const Value *front() const { return series.front(); } /// @return a pointer to the last instruction/flow const Value *back() const { return series.back(); } /// Inserts an instruction/flow at the given position. /// @param pos the position /// @param v the flow or instruction /// @return an iterator to the newly added instruction/flow template auto insert(It pos, Value *v) { return series.insert(pos, v); } /// Appends an instruction/flow. /// @param f the flow or instruction void push_back(Value *f) { series.push_back(f); } /// Erases the item at the supplied position. /// @param pos the position /// @return the iterator beyond the removed flow or instruction template auto erase(It pos) { return series.erase(pos); } protected: std::vector doGetUsedValues() const override { return std::vector(series.begin(), series.end()); } int doReplaceUsedValue(id_t id, Value *newValue) override; }; /// Flow representing a while loop. class WhileFlow : public AcceptorExtend { private: /// the condition Value *cond; /// the body Value *body; public: static const char NodeId; /// Constructs a while loop. 
/// @param cond the condition /// @param body the body /// @param name the flow's name WhileFlow(Value *cond, Flow *body, std::string name = "") : AcceptorExtend(std::move(name)), cond(cond), body(body) {} /// @return the condition Value *getCond() { return cond; } /// @return the condition const Value *getCond() const { return cond; } /// Sets the condition. /// @param c the new condition void setCond(Value *c) { cond = c; } /// @return the body Flow *getBody() { return cast(body); } /// @return the body const Flow *getBody() const { return cast(body); } /// Sets the body. /// @param f the new value void setBody(Flow *f) { body = f; } protected: std::vector doGetUsedValues() const override { return {cond, body}; } int doReplaceUsedValue(id_t id, Value *newValue) override; }; /// Flow representing a for loop. class ForFlow : public AcceptorExtend { private: /// the iterator Value *iter; /// the body Value *body; /// the variable Var *var; /// parallel loop schedule, or null if none std::unique_ptr schedule; /// true if loop is async bool async; public: static const char NodeId; /// Constructs a for loop. /// @param iter the iterator /// @param body the body /// @param var the variable /// @param schedule the parallel schedule /// @param async true if loop is async /// @param name the flow's name ForFlow(Value *iter, Flow *body, Var *var, std::unique_ptr schedule = {}, bool async = false, std::string name = "") : AcceptorExtend(std::move(name)), iter(iter), body(body), var(var), schedule(std::move(schedule)), async(async) {} /// @return the iter Value *getIter() { return iter; } /// @return the iter const Value *getIter() const { return iter; } /// Sets the iter. /// @param f the new iter void setIter(Value *f) { iter = f; } /// @return the body Flow *getBody() { return cast(body); } /// @return the body const Flow *getBody() const { return cast(body); } /// Sets the body. 
/// @param f the new body void setBody(Flow *f) { body = f; } /// @return the var Var *getVar() { return var; } /// @return the var const Var *getVar() const { return var; } /// Sets the var. /// @param c the new var void setVar(Var *c) { var = c; } /// @return true if parallel bool isParallel() const { return bool(schedule); } /// Sets parallel status. /// @param a true if parallel void setParallel(bool a = true) { if (a) schedule = std::make_unique(); else schedule = std::unique_ptr(); } /// @return the parallel loop schedule, or null if none transform::parallel::OMPSched *getSchedule() { return schedule.get(); } /// @return the parallel loop schedule, or null if none const transform::parallel::OMPSched *getSchedule() const { return schedule.get(); } /// Sets the parallel loop schedule /// @param s the schedule string (e.g. OpenMP pragma) void setSchedule(std::unique_ptr s) { schedule = std::move(s); } /// @return true if async bool isAsync() const { return async; } /// Sets async status. /// @param a true if async void setAsync(bool a = true) { async = a; } protected: std::vector doGetUsedValues() const override; int doReplaceUsedValue(id_t id, Value *newValue) override; std::vector doGetUsedVariables() const override { return {var}; } int doReplaceUsedVariable(id_t id, Var *newVar) override; }; /// Flow representing an imperative for loop. class ImperativeForFlow : public AcceptorExtend { private: /// the initial value Value *start; /// the step value int64_t step; /// the end value Value *end; /// the body Value *body; /// the variable, must be integer type Var *var; /// parallel loop schedule, or null if none std::unique_ptr schedule; public: static const char NodeId; /// Constructs an imperative for loop. 
/// @param body the body /// @param start the start value /// @param step the step value /// @param end the end value /// @param var the end variable, must be integer /// @param name the flow's name ImperativeForFlow(Value *start, int64_t step, Value *end, Flow *body, Var *var, std::unique_ptr schedule = {}, std::string name = "") : AcceptorExtend(std::move(name)), start(start), step(step), end(end), body(body), var(var), schedule(std::move(schedule)) {} /// @return the start value Value *getStart() const { return start; } /// Sets the start value. /// @param v the new value void setStart(Value *val) { start = val; } /// @return the step value int64_t getStep() const { return step; } /// Sets the step value. /// @param v the new value void setStep(int64_t val) { step = val; } /// @return the end value Value *getEnd() const { return end; } /// Sets the end value. /// @param v the new value void setEnd(Value *val) { end = val; } /// @return the body Flow *getBody() { return cast(body); } /// @return the body const Flow *getBody() const { return cast(body); } /// Sets the body. /// @param f the new body void setBody(Flow *f) { body = f; } /// @return the var Var *getVar() { return var; } /// @return the var const Var *getVar() const { return var; } /// Sets the var. /// @param c the new var void setVar(Var *c) { var = c; } /// @return true if parallel bool isParallel() const { return bool(schedule); } /// Sets parallel status. /// @param a true if parallel void setParallel(bool a = true) { if (a) schedule = std::make_unique(); else schedule = std::unique_ptr(); } /// @return the parallel loop schedule, or null if none transform::parallel::OMPSched *getSchedule() { return schedule.get(); } /// @return the parallel loop schedule, or null if none const transform::parallel::OMPSched *getSchedule() const { return schedule.get(); } /// Sets the parallel loop schedule /// @param s the schedule string (e.g. 
OpenMP pragma) void setSchedule(std::unique_ptr s) { schedule = std::move(s); } protected: std::vector doGetUsedValues() const override; int doReplaceUsedValue(id_t id, Value *newValue) override; std::vector doGetUsedVariables() const override { return {var}; } int doReplaceUsedVariable(id_t id, Var *newVar) override; }; /// Flow representing an if statement. class IfFlow : public AcceptorExtend { private: /// the condition Value *cond; /// the true branch Value *trueBranch; /// the false branch Value *falseBranch; public: static const char NodeId; /// Constructs an if. /// @param cond the condition /// @param trueBranch the true branch /// @param falseBranch the false branch /// @param name the flow's name IfFlow(Value *cond, Flow *trueBranch, Flow *falseBranch = nullptr, std::string name = "") : AcceptorExtend(std::move(name)), cond(cond), trueBranch(trueBranch), falseBranch(falseBranch) {} /// @return the true branch Flow *getTrueBranch() { return cast(trueBranch); } /// @return the true branch const Flow *getTrueBranch() const { return cast(trueBranch); } /// Sets the true branch. /// @param f the new true branch void setTrueBranch(Flow *f) { trueBranch = f; } /// @return the false branch Flow *getFalseBranch() { return cast(falseBranch); } /// @return the false branch const Flow *getFalseBranch() const { return cast(falseBranch); } /// Sets the false. /// @param f the new false void setFalseBranch(Flow *f) { falseBranch = f; } /// @return the condition Value *getCond() { return cond; } /// @return the condition const Value *getCond() const { return cond; } /// Sets the condition. /// @param c the new condition void setCond(Value *c) { cond = c; } protected: std::vector doGetUsedValues() const override; int doReplaceUsedValue(id_t id, Value *newValue) override; }; /// Flow representing a try-catch statement. class TryCatchFlow : public AcceptorExtend { public: /// Class representing a catch clause. 
class Catch {
  private:
    /// the handler
    Value *handler;
    /// the catch type, may be nullptr
    types::Type *type;
    /// the catch variable, may be nullptr
    Var *catchVar;

  public:
    /// Constructs a catch clause.
    /// @param handler the handler
    /// @param type the catch type, nullptr for catch all
    /// @param catchVar the catch variable, may be nullptr
    explicit Catch(Flow *handler, types::Type *type = nullptr, Var *catchVar = nullptr)
        : handler(handler), type(type), catchVar(catchVar) {}

    /// @return the handler
    Flow *getHandler() { return cast(handler); }
    /// @return the handler
    const Flow *getHandler() const { return cast(handler); }
    /// Sets the handler.
    /// @param h the new value
    void setHandler(Flow *h) { handler = h; }

    /// @return the catch type, may be nullptr
    types::Type *getType() const { return type; }
    /// Sets the catch type.
    /// @param t the new type, nullptr for catch all
    void setType(types::Type *t) { type = t; }

    /// @return the variable, may be nullptr
    Var *getVar() { return catchVar; }
    /// @return the variable, may be nullptr
    const Var *getVar() const { return catchVar; }
    /// Sets the variable.
    /// @param v the new value, may be nullptr
    void setVar(Var *v) { catchVar = v; }
  };

private:
  /// the catch clauses
  std::list catches;

  /// the body
  Value *body;
  /// the else block, may be nullptr
  Value *else_;
  /// the finally, may be nullptr
  Value *finally;

public:
  static const char NodeId;

  /// Constructs a try-catch.
  /// @param body the body
  /// @param finally the finally, may be nullptr
  /// @param else_ the else block, may be nullptr
  /// @param name the flow's name
  explicit TryCatchFlow(Flow *body, Flow *finally = nullptr, Flow *else_ = nullptr,
                        std::string name = "")
      : AcceptorExtend(std::move(name)), body(body), else_(else_), finally(finally) {}

  /// @return the body
  Flow *getBody() { return cast(body); }
  /// @return the body
  const Flow *getBody() const { return cast(body); }
  /// Sets the body.
  /// @param f the new body
  void setBody(Flow *f) { body = f; }

  /// @return the else block
  Flow *getElse() { return cast(else_); }
  /// @return the else block
  const Flow *getElse() const { return cast(else_); }
  /// Sets the else block.
/// @param f the new else block void setElse(Flow *f) { else_ = f; } /// @return the finally Flow *getFinally() { return cast(finally); } /// @return the finally const Flow *getFinally() const { return cast(finally); } /// Sets the finally. /// @param f the new finally void setFinally(Flow *f) { finally = f; } /// @return an iterator to the first catch auto begin() { return catches.begin(); } /// @return an iterator beyond the last catch auto end() { return catches.end(); } /// @return an iterator to the first catch auto begin() const { return catches.begin(); } /// @return an iterator beyond the last catch auto end() const { return catches.end(); } /// @return a reference to the first catch auto &front() { return catches.front(); } /// @return a reference to the last catch auto &back() { return catches.back(); } /// @return a reference to the first catch auto &front() const { return catches.front(); } /// @return a reference to the last catch auto &back() const { return catches.back(); } /// Inserts a catch at the given position. /// @param pos the position /// @param v the catch /// @return an iterator to the newly added catch template auto insert(It pos, Catch v) { return catches.insert(pos, v); } /// Appends a catch. /// @param v the catch void push_back(Catch v) { catches.push_back(v); } /// Emplaces a catch. /// @tparam Args the catch constructor args template void emplace_back(Args &&...args) { catches.emplace_back(std::forward(args)...); } /// Erases a catch at the given position. 
/// @param pos the position /// @return the iterator beyond the erased catch template auto erase(It pos) { return catches.erase(pos); } protected: std::vector doGetUsedValues() const override; int doReplaceUsedValue(id_t id, Value *newValue) override; std::vector doGetUsedTypes() const override; int doReplaceUsedType(const std::string &name, types::Type *newType) override; std::vector doGetUsedVariables() const override; int doReplaceUsedVariable(id_t id, Var *newVar) override; }; /// Flow that represents a pipeline. Pipelines with only function /// stages are expressions and have a concrete type. Pipelines with /// generator stages are not expressions and have no type. This /// representation allows for stages that output generators but do /// not get explicitly iterated in the pipeline, since generator /// stages are denoted by a separate flag. class PipelineFlow : public AcceptorExtend { public: /// Represents a single stage in a pipeline. class Stage { private: /// the function being (partially) called in this stage Value *callee; /// the function arguments, where null represents where /// previous pipeline output should go std::vector args; /// true if this stage is a generator bool generator; /// true if this stage is marked parallel bool parallel; public: /// Constructs a pipeline stage. 
/// @param callee the function being called /// @param args call arguments, with exactly one null entry /// @param generator whether this stage is a generator stage /// @param parallel whether this stage is parallel Stage(Value *callee, std::vector args, bool generator, bool parallel) : callee(callee), args(std::move(args)), generator(generator), parallel(parallel) {} /// @return an iterator to the first argument auto begin() { return args.begin(); } /// @return an iterator beyond the last argument auto end() { return args.end(); } /// @return an iterator to the first argument auto begin() const { return args.begin(); } /// @return an iterator beyond the last argument auto end() const { return args.end(); } /// @return a pointer to the first argument Value *front() { return args.front(); } /// @return a pointer to the last argument Value *back() { return args.back(); } /// @return a pointer to the first argument const Value *front() const { return args.front(); } /// @return a pointer to the last argument const Value *back() const { return args.back(); } /// Inserts an argument. /// @param pos the position /// @param v the argument /// @return an iterator to the newly added argument template auto insert(It pos, Value *v) { return args.insert(pos, v); } /// Appends an argument. /// @param v the argument void push_back(Value *v) { args.push_back(v); } /// Erases the item at the supplied position. /// @param pos the position /// @return the iterator beyond the removed argument template auto erase(It pos) { return args.erase(pos); } /// Sets the called function. /// @param c the callee void setCallee(Value *c) { callee = c; } /// @return the called function Value *getCallee() { return callee; } /// @return the called function const Value *getCallee() const { return callee; } /// Sets the stage's generator flag. 
/// @param v the new value void setGenerator(bool v = true) { generator = v; } /// @return whether this stage is a generator stage bool isGenerator() const { return generator; } /// Sets the stage's parallel flag. /// @param v the new value void setParallel(bool v = true) { parallel = v; } /// @return whether this stage is parallel bool isParallel() const { return parallel; } /// @return the output type of this stage types::Type *getOutputType() const; /// @return the output element type of this stage types::Type *getOutputElementType() const; friend class PipelineFlow; }; private: /// pipeline stages std::list stages; public: static const char NodeId; /// Constructs a pipeline flow. /// @param stages vector of pipeline stages /// @param name the name explicit PipelineFlow(std::vector stages = {}, std::string name = "") : AcceptorExtend(std::move(name)), stages(stages.begin(), stages.end()) {} /// @return an iterator to the first stage auto begin() { return stages.begin(); } /// @return an iterator beyond the last stage auto end() { return stages.end(); } /// @return an iterator to the first stage auto begin() const { return stages.begin(); } /// @return an iterator beyond the last stage auto end() const { return stages.end(); } /// @return a pointer to the first stage Stage &front() { return stages.front(); } /// @return a pointer to the last stage Stage &back() { return stages.back(); } /// @return a pointer to the first stage const Stage &front() const { return stages.front(); } /// @return a pointer to the last stage const Stage &back() const { return stages.back(); } /// Inserts a stage /// @param pos the position /// @param v the stage /// @return an iterator to the newly added stage template auto insert(It pos, Stage v) { return stages.insert(pos, v); } /// Appends an stage. /// @param v the stage void push_back(Stage v) { stages.push_back(std::move(v)); } /// Erases the item at the supplied position. 
/// @param pos the position /// @return the iterator beyond the removed stage template auto erase(It pos) { return stages.erase(pos); } /// Emplaces a stage. /// @param args the args template void emplace_back(Args &&...args) { stages.emplace_back(std::forward(args)...); } protected: std::vector doGetUsedValues() const override; int doReplaceUsedValue(id_t id, Value *newValue) override; }; } // namespace ir } // namespace codon ================================================ FILE: codon/cir/func.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "func.h" #include #include "codon/cir/module.h" #include "codon/cir/util/iterators.h" #include "codon/cir/util/operator.h" #include "codon/cir/util/visitor.h" #include "codon/cir/var.h" #include "codon/parser/common.h" namespace codon { namespace ir { namespace { int findAndReplace(id_t id, codon::ir::Var *newVal, std::list &values) { auto replacements = 0; for (auto &value : values) { if (value->getId() == id) { value = newVal; ++replacements; } } return replacements; } } // namespace const char Func::NodeId = 0; void Func::realize(types::Type *newType, const std::vector &names) { auto *funcType = cast(newType); seqassert(funcType, "{} is not a function type", *newType); setType(funcType); args.clear(); auto i = 0; for (auto *t : *funcType) { args.push_back(getModule()->Nr(t, false, false, false, names[i])); ++i; } } Var *Func::getArgVar(const std::string &n) { auto it = std::find_if(args.begin(), args.end(), [n](auto *other) { return other->getName() == n; }); return (it != args.end()) ? 
*it : nullptr; } std::vector Func::doGetUsedVariables() const { std::vector ret(args.begin(), args.end()); return ret; } int Func::doReplaceUsedVariable(id_t id, Var *newVar) { return findAndReplace(id, newVar, args); } std::vector Func::doGetUsedTypes() const { std::vector ret; for (auto *t : Var::getUsedTypes()) ret.push_back(const_cast(t)); if (parentType) ret.push_back(parentType); return ret; } int Func::doReplaceUsedType(const std::string &name, types::Type *newType) { auto count = Var::replaceUsedType(name, newType); if (parentType && parentType->getName() == name) { parentType = newType; ++count; } return count; } const char BodiedFunc::NodeId = 0; int BodiedFunc::doReplaceUsedValue(id_t id, Value *newValue) { if (body && body->getId() == id) { auto *flow = cast(newValue); seqassert(flow, "{} is not a flow", *newValue); body = flow; return 1; } return 0; } std::vector BodiedFunc::doGetUsedVariables() const { auto ret = Func::doGetUsedVariables(); ret.insert(ret.end(), symbols.begin(), symbols.end()); return ret; } int BodiedFunc::doReplaceUsedVariable(id_t id, Var *newVar) { return Func::doReplaceUsedVariable(id, newVar) + findAndReplace(id, newVar, symbols); } const char ExternalFunc::NodeId = 0; const char InternalFunc::NodeId = 0; const char LLVMFunc::NodeId = 0; std::vector LLVMFunc::doGetUsedTypes() const { std::vector ret; for (auto *t : Func::getUsedTypes()) ret.push_back(const_cast(t)); for (auto &l : llvmLiterals) if (l.isType()) ret.push_back(const_cast(l.getTypeValue())); return ret; } int LLVMFunc::doReplaceUsedType(const std::string &name, types::Type *newType) { auto count = Var::doReplaceUsedType(name, newType); for (auto &l : llvmLiterals) if (l.isType() && l.getTypeValue()->getName() == name) { l = newType; ++count; } return count; } } // namespace ir } // namespace codon ================================================ FILE: codon/cir/func.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once #include "codon/cir/flow.h" #include "codon/cir/util/iterators.h" #include "codon/cir/var.h" namespace codon { namespace ir { /// CIR function class Func : public AcceptorExtend { private: /// unmangled (source code) name of the function std::string unmangledName; /// whether the function is a generator bool generator; /// whether the function is an async function bool async; /// Parent type if func is a method, or null if not types::Type *parentType; protected: /// list of arguments std::list args; std::vector doGetUsedVariables() const override; int doReplaceUsedVariable(id_t id, Var *newVar) override; std::vector doGetUsedTypes() const override; int doReplaceUsedType(const std::string &name, types::Type *newType) override; public: static const char NodeId; /// Constructs an unrealized CIR function. /// @param name the function's name explicit Func(std::string name = "") : AcceptorExtend(nullptr, true, false, false, std::move(name)), generator(false), async(false), parentType(nullptr) {} /// Re-initializes the function with a new type and names. 
/// @param newType the function's new type /// @param names the function's new argument names void realize(types::Type *newType, const std::vector &names); /// @return iterator to the first arg auto arg_begin() { return args.begin(); } /// @return iterator beyond the last arg auto arg_end() { return args.end(); } /// @return iterator to the first arg auto arg_begin() const { return args.begin(); } /// @return iterator beyond the last arg auto arg_end() const { return args.end(); } /// @return a pointer to the last arg Var *arg_front() { return args.front(); } /// @return a pointer to the last arg Var *arg_back() { return args.back(); } /// @return a pointer to the last arg const Var *arg_back() const { return args.back(); } /// @return a pointer to the first arg const Var *arg_front() const { return args.front(); } /// @return the function's unmangled (source code) name std::string getUnmangledName() const { return unmangledName; } /// Sets the unmangled name. /// @param v the new value void setUnmangledName(std::string v) { unmangledName = std::move(v); } /// @return true if the function is a generator bool isGenerator() const { return generator; } /// Sets the function's generator flag. /// @param v the new value void setGenerator(bool v = true) { generator = v; } /// @return true if the function is an async function bool isAsync() const { return async; } /// Sets the function's async flag. /// @param v the new value void setAsync(bool v = true) { async = v; } /// @return the variable corresponding to the given argument name /// @param n the argument name Var *getArgVar(const std::string &n); /// @return the parent type types::Type *getParentType() const { return parentType; } /// Sets the parent type. 
/// @param p the new parent void setParentType(types::Type *p) { parentType = p; } }; class BodiedFunc : public AcceptorExtend { private: /// list of variables defined and used within the function std::list symbols; /// the function body Value *body = nullptr; /// whether the function is a JIT input bool jit = false; public: static const char NodeId; using AcceptorExtend::AcceptorExtend; /// @return iterator to the first symbol auto begin() { return symbols.begin(); } /// @return iterator beyond the last symbol auto end() { return symbols.end(); } /// @return iterator to the first symbol auto begin() const { return symbols.begin(); } /// @return iterator beyond the last symbol auto end() const { return symbols.end(); } /// @return a pointer to the first symbol Var *front() { return symbols.front(); } /// @return a pointer to the last symbol Var *back() { return symbols.back(); } /// @return a pointer to the first symbol const Var *front() const { return symbols.front(); } /// @return a pointer to the last symbol const Var *back() const { return symbols.back(); } /// Inserts an symbol at the given position. /// @param pos the position /// @param v the symbol /// @return an iterator to the newly added symbol template auto insert(It pos, Var *v) { return symbols.insert(pos, v); } /// Appends an symbol. /// @param v the new symbol void push_back(Var *v) { symbols.push_back(v); } /// Erases the symbol at the given position. /// @param pos the position /// @return symbol_iterator following the removed symbol. template auto erase(It pos) { return symbols.erase(pos); } /// @return the function body Flow *getBody() { return cast(body); } /// @return the function body const Flow *getBody() const { return cast(body); } /// Sets the function's body. /// @param b the new body void setBody(Flow *b) { body = b; } /// @return true if the function is a JIT input bool isJIT() const { return jit; } /// Changes the function's JIT input status. 
/// @param v true if JIT input, false otherwise void setJIT(bool v = true) { jit = v; } protected: std::vector doGetUsedValues() const override { return body ? std::vector{body} : std::vector{}; } int doReplaceUsedValue(id_t id, Value *newValue) override; std::vector doGetUsedVariables() const override; int doReplaceUsedVariable(id_t id, Var *newVar) override; }; class ExternalFunc : public AcceptorExtend { public: static const char NodeId; using AcceptorExtend::AcceptorExtend; /// @return true if the function is variadic bool isVariadic() const { return cast(getType())->isVariadic(); } }; /// Internal, LLVM-only function. class InternalFunc : public AcceptorExtend { public: static const char NodeId; using AcceptorExtend::AcceptorExtend; }; /// LLVM function defined in Seq source. class LLVMFunc : public AcceptorExtend { private: /// literals that must be formatted into the body std::vector llvmLiterals; /// declares for llvm-only function std::string llvmDeclares; /// body of llvm-only function std::string llvmBody; public: static const char NodeId; using AcceptorExtend::AcceptorExtend; /// Sets the LLVM literals. /// @param v the new values. 
void setLLVMLiterals(std::vector v) { llvmLiterals = std::move(v); } /// @return iterator to the first literal auto literal_begin() { return llvmLiterals.begin(); } /// @return iterator beyond the last literal auto literal_end() { return llvmLiterals.end(); } /// @return iterator to the first literal auto literal_begin() const { return llvmLiterals.begin(); } /// @return iterator beyond the last literal auto literal_end() const { return llvmLiterals.end(); } /// @return a reference to the first literal auto &literal_front() { return llvmLiterals.front(); } /// @return a reference to the last literal auto &literal_back() { return llvmLiterals.back(); } /// @return a reference to the first literal auto &literal_front() const { return llvmLiterals.front(); } /// @return a reference to the last literal auto &literal_back() const { return llvmLiterals.back(); } /// @return the LLVM declarations const std::string &getLLVMDeclarations() const { return llvmDeclares; } /// Sets the LLVM declarations. /// @param v the new value void setLLVMDeclarations(std::string v) { llvmDeclares = std::move(v); } /// @return the LLVM body const std::string &getLLVMBody() const { return llvmBody; } /// Sets the LLVM body. /// @param v the new value void setLLVMBody(std::string v) { llvmBody = std::move(v); } protected: std::vector doGetUsedTypes() const override; int doReplaceUsedType(const std::string &name, types::Type *newType) override; }; } // namespace ir } // namespace codon template <> struct fmt::formatter : fmt::ostream_formatter {}; ================================================ FILE: codon/cir/instr.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "instr.h"

#include "codon/cir/module.h"
#include "codon/cir/util/iterators.h"

namespace codon {
namespace ir {
namespace {
/// Replaces every entry of `values` whose id equals `id` with `newVal`.
/// @param id the id of the value to replace
/// @param newVal the replacement value
/// @param values the vector to update in place
/// @return the number of entries replaced
int findAndReplace(id_t id, codon::ir::Value *newVal,
                   std::vector<codon::ir::Value *> &values) {
  auto replacements = 0;
  for (auto &value : values) {
    if (value->getId() == id) {
      value = newVal;
      ++replacements;
    }
  }
  return replacements;
}
} // namespace

const char Instr::NodeId = 0;

types::Type *Instr::doGetType() const { return getModule()->getNoneType(); }

const char AssignInstr::NodeId = 0;

int AssignInstr::doReplaceUsedValue(id_t id, Value *newValue) {
  if (rhs->getId() == id) {
    rhs = newValue;
    return 1;
  }
  return 0;
}

int AssignInstr::doReplaceUsedVariable(id_t id, Var *newVar) {
  if (lhs->getId() == id) {
    lhs = newVar;
    return 1;
  }
  return 0;
}

const char ExtractInstr::NodeId = 0;

/// The type of an extract is the type of the named member of the operand,
/// which must have a membered type.
types::Type *ExtractInstr::doGetType() const {
  auto *memberedType = cast<types::MemberedType>(val->getType());
  seqassert(memberedType, "{} is not a membered type", *val->getType());
  return memberedType->getMemberType(field);
}

int ExtractInstr::doReplaceUsedValue(id_t id, Value *newValue) {
  if (val->getId() == id) {
    val = newValue;
    return 1;
  }
  return 0;
}

const char InsertInstr::NodeId = 0;

int InsertInstr::doReplaceUsedValue(id_t id, Value *newValue) {
  auto replacements = 0;
  if (lhs->getId() == id) {
    lhs = newValue;
    ++replacements;
  }
  if (rhs->getId() == id) {
    rhs = newValue;
    ++replacements;
  }
  return replacements;
}

const char CallInstr::NodeId = 0;

/// The type of a call is the return type of the callee, which must have a
/// function type.
types::Type *CallInstr::doGetType() const {
  auto *funcType = cast<types::FuncType>(callee->getType());
  seqassert(funcType, "{} is not a function type", *callee->getType());
  return funcType->getReturnType();
}

/// @return the call arguments followed by the callee
std::vector<Value *> CallInstr::doGetUsedValues() const {
  std::vector<Value *> ret(args.begin(), args.end());
  ret.push_back(callee);
  return ret;
}

int CallInstr::doReplaceUsedValue(id_t id, Value *newValue) {
  auto replacements = 0;
  if (callee->getId() == id) {
    callee = newValue;
    ++replacements;
  }
  replacements += findAndReplace(id, newValue, args);
  return replacements;
}

const char StackAllocInstr::NodeId = 0;

int StackAllocInstr::doReplaceUsedType(const std::string &name,
                                       types::Type *newType) {
  if (arrayType->getName() == name) {
    arrayType = newType;
    return 1;
  }
  return 0;
}

const char TypePropertyInstr::NodeId = 0;

/// @return bool for the atomicity queries, int for SIZEOF, and the none type
///         for any other property value
types::Type *TypePropertyInstr::doGetType() const {
  switch (property) {
  case Property::IS_ATOMIC:
    return getModule()->getBoolType();
  case Property::IS_CONTENT_ATOMIC:
    return getModule()->getBoolType();
  case Property::SIZEOF:
    return getModule()->getIntType();
  default:
    return getModule()->getNoneType();
  }
}

int TypePropertyInstr::doReplaceUsedType(const std::string &name,
                                         types::Type *newType) {
  if (inspectType->getName() == name) {
    inspectType = newType;
    return 1;
  }
  return 0;
}

const char YieldInInstr::NodeId = 0;

int YieldInInstr::doReplaceUsedType(const std::string &name, types::Type *newType) {
  if (type->getName() == name) {
    type = newType;
    return 1;
  }
  return 0;
}

const char AwaitInstr::NodeId = 0;

int AwaitInstr::doReplaceUsedValue(id_t id, Value *newValue) {
  if (value->getId() == id) {
    value = newValue;
    return 1;
  }
  return 0;
}

int AwaitInstr::doReplaceUsedType(const std::string &name, types::Type *newType) {
  if (type->getName() == name) {
    type = newType;
    return 1;
  }
  return 0;
}

const char TernaryInstr::NodeId = 0;

int TernaryInstr::doReplaceUsedValue(id_t id, Value *newValue) {
  auto replacements = 0;
  if (cond->getId() == id) {
    cond = newValue;
    ++replacements;
  }
  if (trueValue->getId() == id) {
    trueValue = newValue;
    ++replacements;
  }
  if (falseValue->getId() == id) {
    falseValue = newValue;
    ++replacements;
  }
  return replacements;
}

const char ControlFlowInstr::NodeId = 0;

const char BreakInstr::NodeId = 0;

std::vector<Value *> BreakInstr::doGetUsedValues() const {
  if (loop)
    return {loop};
  return {};
}

int BreakInstr::doReplaceUsedValue(id_t id, Value *newValue) {
  if (loop && loop->getId() == id) {
    // The replacement for a loop target must itself be a flow.
    auto *f = cast<Flow>(newValue);
    seqassert(f, "{} is not a flow", *newValue);
    loop = f;
    return 1;
  }
  return 0;
}

const char ContinueInstr::NodeId = 0;
std::vector ContinueInstr::doGetUsedValues() const { if (loop) return {loop}; return {}; } int ContinueInstr::doReplaceUsedValue(id_t id, Value *newValue) { if (loop && loop->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); loop = f; return 1; } return 0; } const char ReturnInstr::NodeId = 0; std::vector ReturnInstr::doGetUsedValues() const { if (value) return {value}; return {}; } int ReturnInstr::doReplaceUsedValue(id_t id, Value *newValue) { auto replacements = 0; if (value && value->getId() == id) { setValue(newValue); ++replacements; } return replacements; } const char YieldInstr::NodeId = 0; std::vector YieldInstr::doGetUsedValues() const { if (value) return {value}; return {}; } int YieldInstr::doReplaceUsedValue(id_t id, Value *newValue) { if (value && value->getId() == id) { setValue(newValue); return 1; } return 0; } const char ThrowInstr::NodeId = 0; std::vector ThrowInstr::doGetUsedValues() const { if (value) return {value}; return {}; } int ThrowInstr::doReplaceUsedValue(id_t id, Value *newValue) { if (value && value->getId() == id) { setValue(newValue); return 1; } return 0; } const char FlowInstr::NodeId = 0; int FlowInstr::doReplaceUsedValue(id_t id, Value *newValue) { auto replacements = 0; if (flow->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); setFlow(f); ++replacements; } if (val->getId() == id) { setValue(newValue); ++replacements; } return replacements; } } // namespace ir } // namespace codon ================================================ FILE: codon/cir/instr.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once #include #include #include "codon/cir/flow.h" #include "codon/cir/types/types.h" #include "codon/cir/util/iterators.h" #include "codon/cir/value.h" #include "codon/cir/var.h" namespace codon { namespace ir { /// CIR object representing an "instruction," or discrete operation in the context of a /// block. class Instr : public AcceptorExtend { public: static const char NodeId; using AcceptorExtend::AcceptorExtend; private: types::Type *doGetType() const override; }; /// Instr representing setting a memory location. class AssignInstr : public AcceptorExtend { private: /// the left-hand side Var *lhs; /// the right-hand side Value *rhs; public: static const char NodeId; /// Constructs an assign instruction. /// @param lhs the left-hand side /// @param rhs the right-hand side /// @param field the field being set, may be empty /// @param name the instruction's name AssignInstr(Var *lhs, Value *rhs, std::string name = "") : AcceptorExtend(std::move(name)), lhs(lhs), rhs(rhs) {} /// @return the left-hand side Var *getLhs() { return lhs; } /// @return the left-hand side const Var *getLhs() const { return lhs; } /// Sets the left-hand side /// @param l the new value void setLhs(Var *v) { lhs = v; } /// @return the right-hand side Value *getRhs() { return rhs; } /// @return the right-hand side const Value *getRhs() const { return rhs; } /// Sets the right-hand side /// @param l the new value void setRhs(Value *v) { rhs = v; } protected: std::vector doGetUsedValues() const override { return {rhs}; } int doReplaceUsedValue(id_t id, Value *newValue) override; std::vector doGetUsedVariables() const override { return {lhs}; } int doReplaceUsedVariable(id_t id, Var *newVar) override; }; /// Instr representing loading the field of a value. class ExtractInstr : public AcceptorExtend { private: /// the value being manipulated Value *val; /// the field std::string field; public: static const char NodeId; /// Constructs a load instruction. 
/// @param val the value being manipulated /// @param field the field /// @param name the instruction's name explicit ExtractInstr(Value *val, std::string field, std::string name = "") : AcceptorExtend(std::move(name)), val(val), field(std::move(field)) {} /// @return the location Value *getVal() { return val; } /// @return the location const Value *getVal() const { return val; } /// Sets the location. /// @param p the new value void setVal(Value *p) { val = p; } /// @return the field const std::string &getField() const { return field; } /// Sets the field. /// @param f the new field void setField(std::string f) { field = std::move(f); } protected: types::Type *doGetType() const override; std::vector doGetUsedValues() const override { return {val}; } int doReplaceUsedValue(id_t id, Value *newValue) override; }; /// Instr representing setting the field of a value. class InsertInstr : public AcceptorExtend { private: /// the value being manipulated Value *lhs; /// the field std::string field; /// the value being inserted Value *rhs; public: static const char NodeId; /// Constructs a load instruction. /// @param lhs the value being manipulated /// @param field the field /// @param rhs the new value /// @param name the instruction's name explicit InsertInstr(Value *lhs, std::string field, Value *rhs, std::string name = "") : AcceptorExtend(std::move(name)), lhs(lhs), field(std::move(field)), rhs(rhs) {} /// @return the left-hand side Value *getLhs() { return lhs; } /// @return the left-hand side const Value *getLhs() const { return lhs; } /// Sets the left-hand side. /// @param p the new value void setLhs(Value *p) { lhs = p; } /// @return the right-hand side Value *getRhs() { return rhs; } /// @return the right-hand side const Value *getRhs() const { return rhs; } /// Sets the right-hand side. /// @param p the new value void setRhs(Value *p) { rhs = p; } /// @return the field const std::string &getField() const { return field; } /// Sets the field. 
/// @param f the new field void setField(std::string f) { field = std::move(f); } protected: types::Type *doGetType() const override { return lhs->getType(); } std::vector doGetUsedValues() const override { return {lhs, rhs}; } int doReplaceUsedValue(id_t id, Value *newValue) override; }; /// Instr representing calling a function. class CallInstr : public AcceptorExtend { private: /// the function Value *callee; /// the arguments std::vector args; public: static const char NodeId; /// Constructs a call instruction. /// @param callee the function /// @param args the arguments /// @param name the instruction's name CallInstr(Value *callee, std::vector args, std::string name = "") : AcceptorExtend(std::move(name)), callee(callee), args(std::move(args)) {} /// Constructs a call instruction with no arguments. /// @param callee the function /// @param name the instruction's name explicit CallInstr(Value *callee, std::string name = "") : CallInstr(callee, {}, std::move(name)) {} /// @return the callee Value *getCallee() { return callee; } /// @return the callee const Value *getCallee() const { return callee; } /// Sets the callee. /// @param c the new value void setCallee(Value *c) { callee = c; } /// @return an iterator to the first argument auto begin() { return args.begin(); } /// @return an iterator beyond the last argument auto end() { return args.end(); } /// @return an iterator to the first argument auto begin() const { return args.begin(); } /// @return an iterator beyond the last argument auto end() const { return args.end(); } /// @return a pointer to the first argument Value *front() { return args.front(); } /// @return a pointer to the last argument Value *back() { return args.back(); } /// @return a pointer to the first argument const Value *front() const { return args.front(); } /// @return a pointer to the last argument const Value *back() const { return args.back(); } /// Inserts an argument at the given position. 
/// @param pos the position /// @param v the argument /// @return an iterator to the newly added argument template auto insert(It pos, Value *v) { return args.insert(pos, v); } /// Appends an argument. /// @param v the argument void push_back(Value *v) { args.push_back(v); } /// Sets the args. /// @param v the new args vector void setArgs(std::vector v) { args = std::move(v); } /// @return the number of arguments int numArgs() const { return args.size(); } protected: types::Type *doGetType() const override; std::vector doGetUsedValues() const override; int doReplaceUsedValue(id_t id, Value *newValue) override; }; /// Instr representing allocating an array on the stack. class StackAllocInstr : public AcceptorExtend { private: /// the array type types::Type *arrayType; /// number of elements to allocate int64_t count; public: static const char NodeId; /// Constructs a stack allocation instruction. /// @param arrayType the type of the array /// @param count the number of elements /// @param name the name StackAllocInstr(types::Type *arrayType, int64_t count, std::string name = "") : AcceptorExtend(std::move(name)), arrayType(arrayType), count(count) {} /// @return the count int64_t getCount() const { return count; } /// Sets the count. /// @param c the new value void setCount(int64_t c) { count = c; } /// @return the array type types::Type *getArrayType() { return arrayType; } /// @return the array type types::Type *getArrayType() const { return arrayType; } /// Sets the array type. /// @param t the new type void setArrayType(types::Type *t) { arrayType = t; } protected: types::Type *doGetType() const override { return arrayType; } std::vector doGetUsedTypes() const override { return {arrayType}; } int doReplaceUsedType(const std::string &name, types::Type *newType) override; }; /// Instr representing getting information about a type. 
class TypePropertyInstr : public AcceptorExtend { public: enum Property { IS_ATOMIC, IS_CONTENT_ATOMIC, SIZEOF }; private: /// the type being inspected types::Type *inspectType; /// the property being checked Property property; public: static const char NodeId; /// Constructs a type property instruction. /// @param type the type being inspected /// @param name the name explicit TypePropertyInstr(types::Type *type, Property property, std::string name = "") : AcceptorExtend(std::move(name)), inspectType(type), property(property) {} /// @return the type being inspected types::Type *getInspectType() { return inspectType; } /// @return the type being inspected types::Type *getInspectType() const { return inspectType; } /// Sets the type being inspected /// @param t the new type void setInspectType(types::Type *t) { inspectType = t; } /// @return the property being inspected Property getProperty() const { return property; } /// Sets the property. /// @param p the new value void setProperty(Property p) { property = p; } protected: types::Type *doGetType() const override; std::vector doGetUsedTypes() const override { return {inspectType}; } int doReplaceUsedType(const std::string &name, types::Type *newType) override; }; /// Instr representing a Python yield expression. class YieldInInstr : public AcceptorExtend { private: /// the type of the value being yielded in. types::Type *type; /// whether or not to suspend bool suspend; public: static const char NodeId; /// Constructs a yield in instruction. /// @param type the type of the value being yielded in /// @param suspend whether to suspend /// @param name the instruction's name explicit YieldInInstr(types::Type *type, bool suspend = true, std::string name = "") : AcceptorExtend(std::move(name)), type(type), suspend(suspend) {} /// @return true if the instruction suspends bool isSuspending() const { return suspend; } /// Sets the instruction suspending flag. 
/// @param v the new value void setSuspending(bool v = true) { suspend = v; } /// Sets the type. /// @param t the new type void setType(types::Type *t) { type = t; } protected: types::Type *doGetType() const override { return type; } std::vector doGetUsedTypes() const override { return {type}; } int doReplaceUsedType(const std::string &name, types::Type *newType) override; }; /// Instr representing a ternary operator. class TernaryInstr : public AcceptorExtend { private: /// the condition Value *cond; /// the true value Value *trueValue; /// the false value Value *falseValue; public: static const char NodeId; /// Constructs a ternary instruction. /// @param cond the condition /// @param trueValue the true value /// @param falseValue the false value /// @param name the instruction's name TernaryInstr(Value *cond, Value *trueValue, Value *falseValue, std::string name = "") : AcceptorExtend(std::move(name)), cond(cond), trueValue(trueValue), falseValue(falseValue) {} /// @return the condition Value *getCond() { return cond; } /// @return the condition const Value *getCond() const { return cond; } /// Sets the condition. /// @param v the new value void setCond(Value *v) { cond = v; } /// @return the condition Value *getTrueValue() { return trueValue; } /// @return the condition const Value *getTrueValue() const { return trueValue; } /// Sets the true value. /// @param v the new value void setTrueValue(Value *v) { trueValue = v; } /// @return the false value Value *getFalseValue() { return falseValue; } /// @return the false value const Value *getFalseValue() const { return falseValue; } /// Sets the value. 
/// @param v the new value void setFalseValue(Value *v) { falseValue = v; } protected: types::Type *doGetType() const override { return trueValue->getType(); } std::vector doGetUsedValues() const override { return {cond, trueValue, falseValue}; } int doReplaceUsedValue(id_t id, Value *newValue) override; }; /// Base for control flow instructions class ControlFlowInstr : public AcceptorExtend { public: static const char NodeId; using AcceptorExtend::AcceptorExtend; }; /// Instr representing a break statement. class BreakInstr : public AcceptorExtend { private: /// the loop being broken, nullptr if the immediate ancestor Value *loop; public: static const char NodeId; /// Constructs a break instruction. /// @param loop the loop being broken, nullptr if immediate ancestor /// @param name the instruction's name explicit BreakInstr(Value *loop = nullptr, std::string name = "") : AcceptorExtend(std::move(name)), loop(loop) {} /// @return the loop, nullptr if immediate ancestor Value *getLoop() const { return loop; } /// Sets the loop id. /// @param v the new loop, nullptr if immediate ancestor void setLoop(Value *v) { loop = v; } std::vector doGetUsedValues() const override; int doReplaceUsedValue(id_t id, Value *newValue) override; }; /// Instr representing a continue statement. class ContinueInstr : public AcceptorExtend { private: /// the loop being continued, nullptr if the immediate ancestor Value *loop; public: static const char NodeId; /// Constructs a continue instruction. /// @param loop the loop being continued, nullptr if immediate ancestor /// @param name the instruction's name explicit ContinueInstr(Value *loop = nullptr, std::string name = "") : AcceptorExtend(std::move(name)), loop(loop) {} /// @return the loop, nullptr if immediate ancestor Value *getLoop() const { return loop; } /// Sets the loop id. 
/// @param v the new loop, -1 if immediate ancestor void setLoop(Value *v) { loop = v; } std::vector doGetUsedValues() const override; int doReplaceUsedValue(id_t id, Value *newValue) override; }; /// Instr representing a return statement. class ReturnInstr : public AcceptorExtend { private: /// the value Value *value; public: static const char NodeId; explicit ReturnInstr(Value *value = nullptr, std::string name = "") : AcceptorExtend(std::move(name)), value(value) {} /// @return the value Value *getValue() { return value; } /// @return the value const Value *getValue() const { return value; } /// Sets the value. /// @param v the new value void setValue(Value *v) { value = v; } protected: std::vector doGetUsedValues() const override; int doReplaceUsedValue(id_t id, Value *newValue) override; }; /// Instr representing a yield statement. class YieldInstr : public AcceptorExtend { private: /// the value Value *value; /// whether this yield is final bool final; public: static const char NodeId; explicit YieldInstr(Value *value = nullptr, bool final = false, std::string name = "") : AcceptorExtend(std::move(name)), value(value), final(final) {} /// @return the value Value *getValue() { return value; } /// @return the value const Value *getValue() const { return value; } /// Sets the value. /// @param v the new value void setValue(Value *v) { value = v; } /// @return if this yield is final bool isFinal() const { return final; } /// Sets whether this yield is final. /// @param f true if final void setFinal(bool f = true) { final = f; } protected: std::vector doGetUsedValues() const override; int doReplaceUsedValue(id_t id, Value *newValue) override; }; /// Instr representing an await statement. class AwaitInstr : public AcceptorExtend { private: /// the value Value *value; /// the type of the result types::Type *type; /// true if argument is a generator (e.g. 
custom __await__) bool generator; public: static const char NodeId; explicit AwaitInstr(Value *value, types::Type *type, bool generator = false, std::string name = "") : AcceptorExtend(std::move(name)), value(value), type(type), generator(generator) {} /// @return the value Value *getValue() { return value; } /// @return the value const Value *getValue() const { return value; } /// Sets the value. /// @param v the new value void setValue(Value *v) { value = v; } /// Sets the type. /// @param t the new type void setType(types::Type *t) { type = t; } /// @return whether argument is a generator bool isGenerator() const { return generator; } /// Sets generator status /// @param g the new value void setGenerator(bool g = true) { generator = g; } protected: types::Type *doGetType() const override { return type; } std::vector doGetUsedValues() const override { return {value}; } std::vector doGetUsedTypes() const override { return {type}; } int doReplaceUsedValue(id_t id, Value *newValue) override; int doReplaceUsedType(const std::string &name, types::Type *newType) override; }; class ThrowInstr : public AcceptorExtend { private: /// the value Value *value; public: static const char NodeId; explicit ThrowInstr(Value *value = nullptr, std::string name = "") : AcceptorExtend(std::move(name)), value(value) {} /// @return the value Value *getValue() { return value; } /// @return the value const Value *getValue() const { return value; } /// Sets the value. /// @param v the new value void setValue(Value *v) { value = v; } protected: std::vector doGetUsedValues() const override; int doReplaceUsedValue(id_t id, Value *newValue) override; }; /// Instr that contains a flow and value. class FlowInstr : public AcceptorExtend { private: /// the flow Value *flow; /// the output value Value *val; public: static const char NodeId; /// Constructs a flow value. 
/// @param flow the flow /// @param val the output value /// @param name the name explicit FlowInstr(Flow *flow, Value *val, std::string name = "") : AcceptorExtend(std::move(name)), flow(flow), val(val) {} /// @return the flow Flow *getFlow() { return cast(flow); } /// @return the flow const Flow *getFlow() const { return cast(flow); } /// Sets the flow. /// @param f the new flow void setFlow(Flow *f) { flow = f; } /// @return the value Value *getValue() { return val; } /// @return the value const Value *getValue() const { return val; } /// Sets the value. /// @param v the new value void setValue(Value *v) { val = v; } protected: types::Type *doGetType() const override { return val->getType(); } std::vector doGetUsedValues() const override { return {flow, val}; } int doReplaceUsedValue(id_t id, Value *newValue) override; }; } // namespace ir } // namespace codon ================================================ FILE: codon/cir/llvm/gpu.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "gpu.h" #include #include #include #include "codon/cir/llvm/optimize.h" #include "codon/util/common.h" namespace codon { namespace ir { namespace { const std::string GPU_TRIPLE = "nvptx64-nvidia-cuda"; const std::string GPU_DL = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-" "f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"; llvm::cl::opt libdevice("libdevice", llvm::cl::desc("libdevice path for GPU kernels"), llvm::cl::init("/usr/local/cuda/nvvm/libdevice/libdevice.10.bc")); llvm::cl::opt ptxOutput("ptx", llvm::cl::desc("Output PTX to specified file")); llvm::cl::opt gpuName( "gpu-name", llvm::cl::desc( "Target GPU architecture or compute capability (e.g. sm_70, sm_80, etc.)"), llvm::cl::init("sm_30")); llvm::cl::opt gpuFeatures( "gpu-features", llvm::cl::desc("GPU feature flags passed (e.g. 
+ptx42 to enable PTX 4.2 features)"), llvm::cl::init("+ptx42")); // Adapted from LLVM's GVExtractorPass, which is not externally available // as a pass for the new pass manager. class GVExtractor : public llvm::PassInfoMixin { llvm::SetVector named; bool deleteStuff; bool keepConstInit; public: // If deleteS is true, this pass deletes the specified global values. // Otherwise, it deletes as much of the module as possible, except for the // global values specified. explicit GVExtractor(std::vector &GVs, bool deleteS = true, bool keepConstInit = false) : named(GVs.begin(), GVs.end()), deleteStuff(deleteS), keepConstInit(keepConstInit) {} // Make sure GV is visible from both modules. Delete is true if it is // being deleted from this module. // This also makes sure GV cannot be dropped so that references from // the split module remain valid. static void makeVisible(llvm::GlobalValue &GV, bool del) { bool local = GV.hasLocalLinkage(); if (local || del) { GV.setLinkage(llvm::GlobalValue::ExternalLinkage); if (local) GV.setVisibility(llvm::GlobalValue::HiddenVisibility); return; } if (!GV.hasLinkOnceLinkage()) { seqassertn(!GV.isDiscardableIfUnused(), "bad global in extractor"); return; } // Map linkonce* to weak* so that llvm doesn't drop this GV. switch (GV.getLinkage()) { default: seqassertn(false, "unexpected linkage"); case llvm::GlobalValue::LinkOnceAnyLinkage: GV.setLinkage(llvm::GlobalValue::WeakAnyLinkage); return; case llvm::GlobalValue::LinkOnceODRLinkage: GV.setLinkage(llvm::GlobalValue::WeakODRLinkage); return; } } llvm::PreservedAnalyses run(llvm::Module &M, llvm::ModuleAnalysisManager &) { // Visit the global inline asm. if (!deleteStuff) M.setModuleInlineAsm(""); // For simplicity, just give all GlobalValues ExternalLinkage. 
A trickier // implementation could figure out which GlobalValues are actually // referenced by the 'named' set, and which GlobalValues in the rest of // the module are referenced by the NamedSet, and get away with leaving // more internal and private things internal and private. But for now, // be conservative and simple. // Visit the GlobalVariables. for (auto &GV : M.globals()) { bool del = deleteStuff == (bool)named.count(&GV) && !GV.isDeclaration() && (!GV.isConstant() || !keepConstInit); if (!del) { if (GV.hasAvailableExternallyLinkage()) continue; if (GV.getName() == "llvm.global_ctors") continue; } makeVisible(GV, del); if (del) { // Make this a declaration and drop it's comdat. GV.setInitializer(nullptr); GV.setComdat(nullptr); } } // Visit the Functions. for (auto &F : M) { bool del = deleteStuff == (bool)named.count(&F) && !F.isDeclaration(); if (!del) { if (F.hasAvailableExternallyLinkage()) continue; } makeVisible(F, del); if (del) { // Make this a declaration and drop it's comdat. F.deleteBody(); F.setComdat(nullptr); } } // Visit the Aliases. 
for (auto &GA : llvm::make_early_inc_range(M.aliases())) { bool del = deleteStuff == (bool)named.count(&GA); makeVisible(GA, del); if (del) { auto *ty = GA.getValueType(); GA.removeFromParent(); llvm::Value *decl; if (auto *funcTy = llvm::dyn_cast(ty)) { decl = llvm::Function::Create(funcTy, llvm::GlobalValue::ExternalLinkage, GA.getAddressSpace(), GA.getName(), &M); } else { decl = new llvm::GlobalVariable( M, ty, false, llvm::GlobalValue::ExternalLinkage, nullptr, GA.getName()); } GA.replaceAllUsesWith(decl); delete &GA; } } return llvm::PreservedAnalyses::none(); } }; std::string cleanUpName(llvm::StringRef name) { std::string validName; llvm::raw_string_ostream validNameStream(validName); auto valid = [](char c, bool first) { bool ok = ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || (c == '_'); if (!first) ok = ok || ('0' <= c && c <= '9'); return ok; }; bool first = true; for (char c : name) { validNameStream << (valid(c, first) ? c : '_'); first = false; } return validNameStream.str(); } void linkLibdevice(llvm::Module *M, const std::string &path) { llvm::SMDiagnostic err; auto libdevice = llvm::parseIRFile(path, err, M->getContext()); if (!libdevice) compilationError(err.getMessage().str(), err.getFilename().str(), err.getLineNo(), err.getColumnNo()); libdevice->setDataLayout(M->getDataLayout()); libdevice->setTargetTriple(M->getTargetTriple()); llvm::Linker L(*M); const bool fail = L.linkInModule(std::move(libdevice)); seqassertn(!fail, "linking libdevice failed"); } llvm::Function *copyPrototype(llvm::Function *F, const std::string &name, bool external = false) { auto *M = F->getParent(); return llvm::Function::Create(F->getFunctionType(), external ? llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::PrivateLinkage, name.empty() ? F->getName() : name, *M); } llvm::Function *makeNoOp(llvm::Function *F) { auto *M = F->getParent(); auto &context = M->getContext(); auto dummyName = (".codon.gpu.dummy." 
+ F->getName()).str(); auto *dummy = M->getFunction(dummyName); if (!dummy) { dummy = copyPrototype(F, dummyName); auto *entry = llvm::BasicBlock::Create(context, "entry", dummy); llvm::IRBuilder<> B(entry); auto *retType = F->getReturnType(); if (retType->isVoidTy()) { B.CreateRetVoid(); } else { B.CreateRet(llvm::UndefValue::get(retType)); } } return dummy; } using Codegen = std::function &, const std::vector &)>; void codegenVectorizedUnaryLoop(llvm::IRBuilder<> &B, const std::vector &args, llvm::Function *func) { // Create IR to represent: // p_in = in // p_out = out // for i in range(n): // *p_out = func(*p_in) // p_in += is // p_out += os auto &context = B.getContext(); auto *parent = B.GetInsertBlock()->getParent(); auto *ty = func->getReturnType(); auto *in = args[0]; auto *is = args[1]; auto *out = args[2]; auto *os = args[3]; auto *n = args[4]; auto *loop = llvm::BasicBlock::Create(context, "loop", parent); auto *exit = llvm::BasicBlock::Create(context, "exit", parent); auto *pinStore = B.CreateAlloca(B.getPtrTy()); auto *poutStore = B.CreateAlloca(B.getPtrTy()); auto *idxStore = B.CreateAlloca(B.getInt64Ty()); // p_in = in B.CreateStore(in, pinStore); // p_out = out B.CreateStore(out, poutStore); // i = 0 B.CreateStore(B.getInt64(0), idxStore); // if n > 0: goto loop; else: goto exit B.CreateCondBr(B.CreateICmpSGT(n, B.getInt64(0)), loop, exit); // load pointers B.SetInsertPoint(loop); auto *pin = B.CreateLoad(B.getPtrTy(), pinStore); auto *pout = B.CreateLoad(B.getPtrTy(), poutStore); // y = func(x) auto *x = B.CreateLoad(ty, pin); auto *y = B.CreateCall(func, x); B.CreateStore(y, pout); auto *idx = B.CreateLoad(B.getInt64Ty(), idxStore); // i += 1 B.CreateStore(B.CreateAdd(idx, B.getInt64(1)), idxStore); // p_in += is B.CreateStore(B.CreateGEP(B.getInt8Ty(), pin, is), pinStore); // p_out += os B.CreateStore(B.CreateGEP(B.getInt8Ty(), pout, os), poutStore); idx = B.CreateLoad(B.getInt64Ty(), idxStore); // if i < n: goto loop; else: goto exit 
B.CreateCondBr(B.CreateICmpSLT(idx, n), loop, exit); B.SetInsertPoint(exit); B.CreateRet(llvm::UndefValue::get(parent->getReturnType())); } void codegenVectorizedBinaryLoop(llvm::IRBuilder<> &B, const std::vector &args, llvm::Function *func) { // Create IR to represent: // p_in1 = in1 // p_in2 = in2 // p_out = out // for i in range(n): // *p_out = func(*p_in1, *p_in2) // p_in1 += is1 // p_in2 += is2 // p_out += os auto &context = B.getContext(); auto *parent = B.GetInsertBlock()->getParent(); auto *ty = func->getReturnType(); auto *in1 = args[0]; auto *is1 = args[1]; auto *in2 = args[2]; auto *is2 = args[3]; auto *out = args[4]; auto *os = args[5]; auto *n = args[6]; auto *loop = llvm::BasicBlock::Create(context, "loop", parent); auto *exit = llvm::BasicBlock::Create(context, "exit", parent); auto *pin1Store = B.CreateAlloca(B.getPtrTy()); auto *pin2Store = B.CreateAlloca(B.getPtrTy()); auto *poutStore = B.CreateAlloca(B.getPtrTy()); auto *idxStore = B.CreateAlloca(B.getInt64Ty()); // p_in1 = in1 B.CreateStore(in1, pin1Store); // p_in2 = in2 B.CreateStore(in2, pin2Store); // p_out = out B.CreateStore(out, poutStore); // i = 0 B.CreateStore(B.getInt64(0), idxStore); // if n > 0: goto loop; else: goto exit B.CreateCondBr(B.CreateICmpSGT(n, B.getInt64(0)), loop, exit); // load pointers B.SetInsertPoint(loop); auto *pin1 = B.CreateLoad(B.getPtrTy(), pin1Store); auto *pin2 = B.CreateLoad(B.getPtrTy(), pin2Store); auto *pout = B.CreateLoad(B.getPtrTy(), poutStore); // y = func(x1, x2) auto *x1 = B.CreateLoad(ty, pin1); auto *x2 = B.CreateLoad(ty, pin2); auto *y = B.CreateCall(func, {x1, x2}); B.CreateStore(y, pout); auto *idx = B.CreateLoad(B.getInt64Ty(), idxStore); // i += 1 B.CreateStore(B.CreateAdd(idx, B.getInt64(1)), idxStore); // p_in1 += is1 B.CreateStore(B.CreateGEP(B.getInt8Ty(), pin1, is1), pin1Store); // p_in2 += is2 B.CreateStore(B.CreateGEP(B.getInt8Ty(), pin2, is2), pin2Store); // p_out += os B.CreateStore(B.CreateGEP(B.getInt8Ty(), pout, os), poutStore); 
idx = B.CreateLoad(B.getInt64Ty(), idxStore); // if i < n: goto loop; else: goto exit B.CreateCondBr(B.CreateICmpSLT(idx, n), loop, exit); B.SetInsertPoint(exit); B.CreateRet(llvm::UndefValue::get(parent->getReturnType())); } llvm::Function *makeFillIn(llvm::Function *F, Codegen codegen) { auto *M = F->getParent(); auto &context = M->getContext(); auto fillInName = (".codon.gpu.fillin." + F->getName()).str(); auto *fillIn = M->getFunction(fillInName); if (!fillIn) { fillIn = copyPrototype(F, fillInName); std::vector args; for (auto it = fillIn->arg_begin(); it != fillIn->arg_end(); ++it) { args.push_back(it); } auto *entry = llvm::BasicBlock::Create(context, "entry", fillIn); llvm::IRBuilder<> B(entry); codegen(B, args); } return fillIn; } llvm::Function *makeMalloc(llvm::Module *M) { auto &context = M->getContext(); llvm::IRBuilder<> B(context); auto F = M->getOrInsertFunction("malloc", B.getPtrTy(), B.getInt64Ty()); auto *G = llvm::cast(F.getCallee()); G->setLinkage(llvm::GlobalValue::ExternalLinkage); G->setDoesNotThrow(); G->setReturnDoesNotAlias(); G->setOnlyAccessesInaccessibleMemory(); G->setWillReturn(); return G; } void remapFunctions(llvm::Module *M) { // simple name-to-name remappings static const std::vector> remapping = { // 64-bit float intrinsics {"llvm.ceil.f64", "__nv_ceil"}, {"llvm.floor.f64", "__nv_floor"}, {"llvm.fabs.f64", "__nv_fabs"}, {"llvm.exp.f64", "__nv_exp"}, {"llvm.log.f64", "__nv_log"}, {"llvm.log2.f64", "__nv_log2"}, {"llvm.log10.f64", "__nv_log10"}, {"llvm.sqrt.f64", "__nv_sqrt"}, {"llvm.pow.f64", "__nv_pow"}, {"llvm.sin.f64", "__nv_sin"}, {"llvm.cos.f64", "__nv_cos"}, {"llvm.copysign.f64", "__nv_copysign"}, {"llvm.trunc.f64", "__nv_trunc"}, {"llvm.rint.f64", "__nv_rint"}, {"llvm.nearbyint.f64", "__nv_nearbyint"}, {"llvm.round.f64", "__nv_round"}, {"llvm.minnum.f64", "__nv_fmin"}, {"llvm.maxnum.f64", "__nv_fmax"}, {"llvm.copysign.f64", "__nv_copysign"}, {"llvm.fma.f64", "__nv_fma"}, // 64-bit float math functions {"expm1", 
"__nv_expm1"}, {"ldexp", "__nv_ldexp"}, {"acos", "__nv_acos"}, {"asin", "__nv_asin"}, {"atan", "__nv_atan"}, {"atan2", "__nv_atan2"}, {"hypot", "__nv_hypot"}, {"tan", "__nv_tan"}, {"cosh", "__nv_cosh"}, {"sinh", "__nv_sinh"}, {"tanh", "__nv_tanh"}, {"acosh", "__nv_acosh"}, {"asinh", "__nv_asinh"}, {"atanh", "__nv_atanh"}, {"erf", "__nv_erf"}, {"erfc", "__nv_erfc"}, {"tgamma", "__nv_tgamma"}, {"lgamma", "__nv_lgamma"}, {"remainder", "__nv_remainder"}, {"frexp", "__nv_frexp"}, {"modf", "__nv_modf"}, // 32-bit float intrinsics {"llvm.ceil.f32", "__nv_ceilf"}, {"llvm.floor.f32", "__nv_floorf"}, {"llvm.fabs.f32", "__nv_fabsf"}, {"llvm.exp.f32", "__nv_expf"}, {"llvm.log.f32", "__nv_logf"}, {"llvm.log2.f32", "__nv_log2f"}, {"llvm.log10.f32", "__nv_log10f"}, {"llvm.sqrt.f32", "__nv_sqrtf"}, {"llvm.pow.f32", "__nv_powf"}, {"llvm.sin.f32", "__nv_sinf"}, {"llvm.cos.f32", "__nv_cosf"}, {"llvm.copysign.f32", "__nv_copysignf"}, {"llvm.trunc.f32", "__nv_truncf"}, {"llvm.rint.f32", "__nv_rintf"}, {"llvm.nearbyint.f32", "__nv_nearbyintf"}, {"llvm.round.f32", "__nv_roundf"}, {"llvm.minnum.f32", "__nv_fminf"}, {"llvm.maxnum.f32", "__nv_fmaxf"}, {"llvm.copysign.f32", "__nv_copysignf"}, {"llvm.fma.f32", "__nv_fmaf"}, // 32-bit float math functions {"expm1f", "__nv_expm1f"}, {"ldexpf", "__nv_ldexpf"}, {"acosf", "__nv_acosf"}, {"asinf", "__nv_asinf"}, {"atanf", "__nv_atanf"}, {"atan2f", "__nv_atan2f"}, {"hypotf", "__nv_hypotf"}, {"tanf", "__nv_tanf"}, {"coshf", "__nv_coshf"}, {"sinhf", "__nv_sinhf"}, {"tanhf", "__nv_tanhf"}, {"acoshf", "__nv_acoshf"}, {"asinhf", "__nv_asinhf"}, {"atanhf", "__nv_atanhf"}, {"erff", "__nv_erff"}, {"erfcf", "__nv_erfcf"}, {"tgammaf", "__nv_tgammaf"}, {"lgammaf", "__nv_lgammaf"}, {"remainderf", "__nv_remainderf"}, {"frexpf", "__nv_frexpf"}, {"modff", "__nv_modff"}, // runtime library functions {"seq_free", "free"}, {"seq_register_finalizer", ""}, {"seq_gc_add_roots", ""}, {"seq_gc_remove_roots", ""}, {"seq_gc_clear_roots", ""}, {"seq_gc_exclude_static_roots", 
""}, }; // functions that need to be generated as they're not available on GPU static const std::vector> fillins = { {"seq_alloc", [](llvm::IRBuilder<> &B, const std::vector &args) { auto *M = B.GetInsertBlock()->getModule(); llvm::Value *mem = B.CreateCall(makeMalloc(M), args[0]); B.CreateRet(mem); }}, {"seq_alloc_uncollectable", [](llvm::IRBuilder<> &B, const std::vector &args) { auto *M = B.GetInsertBlock()->getModule(); llvm::Value *mem = B.CreateCall(makeMalloc(M), args[0]); B.CreateRet(mem); }}, {"seq_alloc_atomic", [](llvm::IRBuilder<> &B, const std::vector &args) { auto *M = B.GetInsertBlock()->getModule(); llvm::Value *mem = B.CreateCall(makeMalloc(M), args[0]); B.CreateRet(mem); }}, {"seq_alloc_atomic_uncollectable", [](llvm::IRBuilder<> &B, const std::vector &args) { auto *M = B.GetInsertBlock()->getModule(); llvm::Value *mem = B.CreateCall(makeMalloc(M), args[0]); B.CreateRet(mem); }}, {"seq_realloc", [](llvm::IRBuilder<> &B, const std::vector &args) { auto *M = B.GetInsertBlock()->getModule(); llvm::Value *mem = B.CreateCall(makeMalloc(M), args[1]); auto F = llvm::Intrinsic::getDeclaration( M, llvm::Intrinsic::memcpy, {B.getPtrTy(), B.getPtrTy(), B.getInt64Ty()}); B.CreateCall(F, {mem, args[0], args[2], B.getFalse()}); B.CreateRet(mem); }}, {"seq_calloc", [](llvm::IRBuilder<> &B, const std::vector &args) { auto *M = B.GetInsertBlock()->getModule(); llvm::Value *size = B.CreateMul(args[0], args[1]); llvm::Value *mem = B.CreateCall(makeMalloc(M), size); auto F = llvm::Intrinsic::getDeclaration(M, llvm::Intrinsic::memset, {B.getPtrTy(), B.getInt64Ty()}); B.CreateCall(F, {mem, B.getInt8(0), size, B.getFalse()}); B.CreateRet(mem); }}, {"seq_calloc_atomic", [](llvm::IRBuilder<> &B, const std::vector &args) { auto *M = B.GetInsertBlock()->getModule(); llvm::Value *size = B.CreateMul(args[0], args[1]); llvm::Value *mem = B.CreateCall(makeMalloc(M), size); auto F = llvm::Intrinsic::getDeclaration(M, llvm::Intrinsic::memset, {B.getPtrTy(), B.getInt64Ty()}); 
B.CreateCall(F, {mem, B.getInt8(0), size, B.getFalse()}); B.CreateRet(mem); }}, {"seq_alloc_exc", [](llvm::IRBuilder<> &B, const std::vector &args) { // TODO: print error message and abort if in debug mode B.CreateUnreachable(); }}, {"seq_throw", [](llvm::IRBuilder<> &B, const std::vector &args) { B.CreateUnreachable(); }}, #define FILLIN_VECLOOP_UNARY32(loop, func) \ { \ loop, [](llvm::IRBuilder<> &B, const std::vector &args) { \ auto *M = B.GetInsertBlock()->getModule(); \ auto f = llvm::cast( \ M->getOrInsertFunction(func, B.getFloatTy(), B.getFloatTy()).getCallee()); \ f->setWillReturn(); \ codegenVectorizedUnaryLoop(B, args, f); \ } \ } #define FILLIN_VECLOOP_UNARY64(loop, func) \ { \ loop, [](llvm::IRBuilder<> &B, const std::vector &args) { \ auto *M = B.GetInsertBlock()->getModule(); \ auto f = llvm::cast( \ M->getOrInsertFunction(func, B.getDoubleTy(), B.getDoubleTy()).getCallee()); \ f->setWillReturn(); \ codegenVectorizedUnaryLoop(B, args, f); \ } \ } #define FILLIN_VECLOOP_BINARY32(loop, func) \ { \ loop, [](llvm::IRBuilder<> &B, const std::vector &args) { \ auto *M = B.GetInsertBlock()->getModule(); \ auto f = llvm::cast( \ M->getOrInsertFunction(func, B.getFloatTy(), B.getFloatTy(), B.getFloatTy()) \ .getCallee()); \ f->setWillReturn(); \ codegenVectorizedBinaryLoop(B, args, f); \ } \ } #define FILLIN_VECLOOP_BINARY64(loop, func) \ { \ loop, [](llvm::IRBuilder<> &B, const std::vector &args) { \ auto *M = B.GetInsertBlock()->getModule(); \ auto f = llvm::cast( \ M->getOrInsertFunction(func, B.getDoubleTy(), B.getDoubleTy(), \ B.getDoubleTy()) \ .getCallee()); \ f->setWillReturn(); \ codegenVectorizedBinaryLoop(B, args, f); \ } \ } FILLIN_VECLOOP_UNARY64("cnp_acos_float64", "__nv_acos"), FILLIN_VECLOOP_UNARY64("cnp_acosh_float64", "__nv_acosh"), FILLIN_VECLOOP_UNARY64("cnp_asin_float64", "__nv_asin"), FILLIN_VECLOOP_UNARY64("cnp_asinh_float64", "__nv_asinh"), FILLIN_VECLOOP_UNARY64("cnp_atan_float64", "__nv_atan"), 
FILLIN_VECLOOP_UNARY64("cnp_atanh_float64", "__nv_atanh"), FILLIN_VECLOOP_BINARY64("cnp_atan2_float64", "__nv_atan2"), FILLIN_VECLOOP_UNARY64("cnp_exp_float64", "__nv_exp"), FILLIN_VECLOOP_UNARY64("cnp_exp2_float64", "__nv_exp2"), FILLIN_VECLOOP_UNARY64("cnp_expm1_float64", "__nv_expm1"), FILLIN_VECLOOP_UNARY64("cnp_log_float64", "__nv_log"), FILLIN_VECLOOP_UNARY64("cnp_log10_float64", "__nv_log10"), FILLIN_VECLOOP_UNARY64("cnp_log1p_float64", "__nv_log1p"), FILLIN_VECLOOP_UNARY64("cnp_log2_float64", "__nv_log2"), FILLIN_VECLOOP_UNARY64("cnp_sin_float64", "__nv_sin"), FILLIN_VECLOOP_UNARY64("cnp_sinh_float64", "__nv_sinh"), FILLIN_VECLOOP_UNARY64("cnp_tan_float64", "__nv_tan"), FILLIN_VECLOOP_UNARY64("cnp_tanh_float64", "__nv_tanh"), FILLIN_VECLOOP_BINARY64("cnp_hypot_float64", "__nv_hypot"), FILLIN_VECLOOP_UNARY32("cnp_acos_float32", "__nv_acosf"), FILLIN_VECLOOP_UNARY32("cnp_acosh_float32", "__nv_acoshf"), FILLIN_VECLOOP_UNARY32("cnp_asin_float32", "__nv_asinf"), FILLIN_VECLOOP_UNARY32("cnp_asinh_float32", "__nv_asinhf"), FILLIN_VECLOOP_UNARY32("cnp_atan_float32", "__nv_atanf"), FILLIN_VECLOOP_UNARY32("cnp_atanh_float32", "__nv_atanhf"), FILLIN_VECLOOP_BINARY32("cnp_atan2_float32", "__nv_atan2f"), FILLIN_VECLOOP_UNARY32("cnp_exp_float32", "__nv_expf"), FILLIN_VECLOOP_UNARY32("cnp_exp2_float32", "__nv_exp2f"), FILLIN_VECLOOP_UNARY32("cnp_expm1_float32", "__nv_expm1f"), FILLIN_VECLOOP_UNARY32("cnp_log_float32", "__nv_logf"), FILLIN_VECLOOP_UNARY32("cnp_log10_float32", "__nv_log10f"), FILLIN_VECLOOP_UNARY32("cnp_log1p_float32", "__nv_log1pf"), FILLIN_VECLOOP_UNARY32("cnp_log2_float32", "__nv_log2f"), FILLIN_VECLOOP_UNARY32("cnp_sin_float32", "__nv_sinf"), FILLIN_VECLOOP_UNARY32("cnp_sinh_float32", "__nv_sinhf"), FILLIN_VECLOOP_UNARY32("cnp_tan_float32", "__nv_tanf"), FILLIN_VECLOOP_UNARY32("cnp_tanh_float32", "__nv_tanhf"), FILLIN_VECLOOP_BINARY32("cnp_hypot_float32", "__nv_hypotf"), }; for (auto &pair : remapping) { if (auto *F = M->getFunction(pair.first)) { 
llvm::Function *G = nullptr; if (pair.second.empty()) { G = makeNoOp(F); } else { G = M->getFunction(pair.second); if (!G) G = copyPrototype(F, pair.second, /*external=*/true); } G->setWillReturn(); F->replaceAllUsesWith(G); F->dropAllReferences(); F->eraseFromParent(); } } for (auto &pair : fillins) { if (auto *F = M->getFunction(pair.first)) { llvm::Function *G = makeFillIn(F, pair.second); F->replaceAllUsesWith(G); F->dropAllReferences(); F->eraseFromParent(); } } } void exploreGV(llvm::GlobalValue *G, llvm::SmallPtrSetImpl &keep) { if (keep.contains(G)) return; keep.insert(G); if (auto *F = llvm::dyn_cast(G)) { for (auto I = llvm::inst_begin(F), E = inst_end(F); I != E; ++I) { for (auto &U : I->operands()) { if (auto *G2 = llvm::dyn_cast(U.get())) exploreGV(G2, keep); } } } } std::vector getRequiredGVs(const std::vector &kernels) { llvm::SmallPtrSet keep; for (auto *G : kernels) { exploreGV(G, keep); } return std::vector(keep.begin(), keep.end()); } std::string moduleToPTX(llvm::Module *M, std::vector &kernels) { llvm::Triple triple(llvm::Triple::normalize(GPU_TRIPLE)); llvm::TargetLibraryInfoImpl tlii(triple); std::string err; const llvm::Target *target = llvm::TargetRegistry::lookupTarget("nvptx64", triple, err); seqassertn(target, "couldn't lookup target: {}", err); const llvm::TargetOptions options = llvm::codegen::InitTargetOptionsFromCodeGenFlags(triple); std::unique_ptr machine(target->createTargetMachine( triple.getTriple(), gpuName, gpuFeatures, options, llvm::codegen::getExplicitRelocModel(), llvm::codegen::getExplicitCodeModel(), llvm::CodeGenOptLevel::Aggressive)); // Remove personality functions for (auto &F : *M) { F.setDoesNotThrow(); F.setPersonalityFn(nullptr); } M->setDataLayout(machine->createDataLayout()); auto keep = getRequiredGVs(kernels); auto prune = [&](std::vector keep) { llvm::LoopAnalysisManager lam; llvm::FunctionAnalysisManager fam; llvm::CGSCCAnalysisManager cgam; llvm::ModuleAnalysisManager mam; llvm::ModulePassManager mpm; 
llvm::PassBuilder pb; pb.registerModuleAnalyses(mam); pb.registerCGSCCAnalyses(cgam); pb.registerFunctionAnalyses(fam); pb.registerLoopAnalyses(lam); pb.crossRegisterProxies(lam, fam, cgam, mam); mpm.addPass(GVExtractor(keep, false)); mpm.addPass(llvm::GlobalDCEPass()); mpm.addPass(llvm::StripDeadDebugInfoPass()); mpm.addPass(llvm::StripDeadPrototypesPass()); mpm.run(*M, mam); }; // Remove non-kernel functions. prune(keep); // Link libdevice and other cleanup. linkLibdevice(M, libdevice); remapFunctions(M); // Strip debug info and remove noinline from functions (added in debug mode). // Also, tell LLVM that all functions will return. for (auto &F : *M) { F.removeFnAttr(llvm::Attribute::AttrKind::NoInline); F.setWillReturn(); } llvm::StripDebugInfo(*M); // Run NVPTX passes and general opt pipeline. { llvm::LoopAnalysisManager lam; llvm::FunctionAnalysisManager fam; llvm::CGSCCAnalysisManager cgam; llvm::ModuleAnalysisManager mam; llvm::PassBuilder pb(machine.get()); llvm::TargetLibraryInfoImpl tlii(triple); fam.registerPass([&] { return llvm::TargetLibraryAnalysis(tlii); }); pb.registerModuleAnalyses(mam); pb.registerCGSCCAnalyses(cgam); pb.registerFunctionAnalyses(fam); pb.registerLoopAnalyses(lam); pb.crossRegisterProxies(lam, fam, cgam, mam); pb.registerPipelineStartEPCallback( [&](llvm::ModulePassManager &pm, llvm::OptimizationLevel opt) { pm.addPass(llvm::InternalizePass([&](const llvm::GlobalValue &gv) { return std::find(keep.begin(), keep.end(), &gv) != keep.end(); })); }); llvm::ModulePassManager mpm = pb.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3); mpm.run(*M, mam); } // Prune again after optimizations. keep = getRequiredGVs(kernels); prune(keep); // Clean up names. 
{ for (auto &G : M->globals()) { G.setName(cleanUpName(G.getName())); } for (auto &F : M->functions()) { if (F.getInstructionCount() > 0) F.setName(cleanUpName(F.getName())); } for (auto *S : M->getIdentifiedStructTypes()) { S->setName(cleanUpName(S->getName())); } } // Generate PTX code. { llvm::SmallVector ptx; llvm::raw_svector_ostream os(ptx); auto *mmiwp = new llvm::MachineModuleInfoWrapperPass(machine.get()); llvm::legacy::PassManager pm; pm.add(new llvm::TargetLibraryInfoWrapperPass(tlii)); bool fail = machine->addPassesToEmitFile(pm, os, nullptr, llvm::CodeGenFileType::AssemblyFile, /*DisableVerify=*/false, mmiwp); seqassertn(!fail, "could not add passes"); const_cast(machine->getObjFileLowering()) ->Initialize(mmiwp->getMMI().getContext(), *machine); pm.run(*M); return std::string(ptx.data(), ptx.size()); } } void cleanUpIntrinsics(llvm::Module *M) { llvm::LLVMContext &context = M->getContext(); llvm::SmallVector remove; for (auto &F : *M) { if (F.getIntrinsicID() != llvm::Intrinsic::not_intrinsic && F.getName().starts_with("llvm.nvvm")) remove.push_back(&F); } for (auto *F : remove) { F->replaceAllUsesWith(makeNoOp(F)); F->dropAllReferences(); F->eraseFromParent(); } } void patchPTXVar(llvm::Module *M, llvm::GlobalValue *ptxVar, const std::string &ptxTarget = "__codon_ptx__") { // Find and patch direct calls to cuModuleLoadData() llvm::SmallVector callsToReplace; for (auto &F : *M) { for (auto &BB : F) { for (auto &I : BB) { auto *call = llvm::dyn_cast(&I); if (!call) continue; auto *callee = call->getCalledFunction(); if (!callee) continue; if (callee->getName() == ptxTarget && call->arg_size() == 0) callsToReplace.push_back(call); } } } for (auto *call : callsToReplace) { if (ptxVar) { call->replaceAllUsesWith(ptxVar); } else { call->replaceAllUsesWith( llvm::ConstantPointerNull::get(llvm::PointerType::get(M->getContext(), 0))); } call->dropAllReferences(); call->eraseFromParent(); } // Delete __codon_ptx__() stub if (auto *F = 
M->getFunction(ptxTarget)) { seqassertn(F->use_empty(), "some __codon_ptx__() calls not replaced in module"); F->eraseFromParent(); } } } // namespace void applyGPUTransformations(llvm::Module *M, const std::string &ptxFilename) { llvm::LLVMContext &context = M->getContext(); std::unique_ptr clone = llvm::CloneModule(*M); clone->setTargetTriple(llvm::Triple::normalize(GPU_TRIPLE)); clone->setDataLayout(GPU_DL); if (isFastMathOn()) { clone->addModuleFlag(llvm::Module::ModFlagBehavior::Override, "nvvm-reflect-ftz", 1); } llvm::NamedMDNode *nvvmAnno = clone->getOrInsertNamedMetadata("nvvm.annotations"); std::vector kernels; for (auto &F : *clone) { if (!F.hasFnAttribute("kernel")) continue; llvm::Metadata *nvvmElem[] = { llvm::ConstantAsMetadata::get(&F), llvm::MDString::get(context, "kernel"), llvm::ConstantAsMetadata::get( llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 1)), }; nvvmAnno->addOperand(llvm::MDNode::get(context, nvvmElem)); kernels.push_back(&F); } if (kernels.empty()) { patchPTXVar(M, nullptr); return; } auto ptx = moduleToPTX(clone.get(), kernels); cleanUpIntrinsics(M); if (ptxOutput.getNumOccurrences() > 0) { std::error_code err; llvm::ToolOutputFile out(ptxOutput, err, llvm::sys::fs::OF_Text); seqassertn(!err, "Could not open file: {}", err.message()); llvm::raw_ostream &os = out.os(); os << ptx; os.flush(); out.keep(); } // Add ptx code as a global var auto *ptxVar = new llvm::GlobalVariable( *M, llvm::ArrayType::get(llvm::Type::getInt8Ty(context), ptx.length() + 1), /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, llvm::ConstantDataArray::getString(context, ptx), ".ptx"); ptxVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); patchPTXVar(M, ptxVar); } } // namespace ir } // namespace codon ================================================ FILE: codon/cir/llvm/gpu.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once #include #include "codon/cir/llvm/llvm.h" namespace codon { namespace ir { /// Applies GPU-specific transformations and generates PTX /// code from kernel functions in the given LLVM module. /// @param module LLVM module containing GPU kernel functions (marked with "kernel" /// annotation) /// @param ptxFilename Filename for output PTX code; empty to use filename based on /// module void applyGPUTransformations(llvm::Module *module, const std::string &ptxFilename = ""); } // namespace ir } // namespace codon ================================================ FILE: codon/cir/llvm/llvisitor.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "llvisitor.h" #include #include #include #include #include #include #include #include "codon/cir/dsl/codegen.h" #include "codon/cir/llvm/optimize.h" #include "codon/cir/util/irtools.h" #include "codon/compiler/debug_listener.h" #include "codon/compiler/memory_manager.h" #include "codon/parser/common.h" #include "codon/runtime/lib.h" #include "codon/util/common.h" namespace codon { namespace ir { namespace { const std::string EXPORT_ATTR = ast::getMangledFunc("std.internal.attributes", "export"); const std::string INLINE_ATTR = ast::getMangledFunc("std.internal.attributes", "inline"); const std::string NOINLINE_ATTR = ast::getMangledFunc("std.internal.attributes", "noinline"); const std::string GPU_KERNEL_ATTR = ast::getMangledFunc("std.internal.gpu", "kernel"); const std::string MAIN_UNCLASH = ".main.unclash"; const std::string MAIN_CTOR = ".main.ctor"; enum GlobalCTORMode { No, Yes, Auto }; llvm::cl::opt GlobalCTOR( "global-ctor", llvm::cl::desc("generate global constructor with main code"), llvm::cl::values(clEnumValN(No, "no", "Keep main code in main() function"), clEnumValN(Yes, "yes", "Put main code in global constructor"), clEnumValN(Auto, "auto", "'yes' if shared library output, 'no' otherwise")), llvm::cl::init(Auto)); llvm::cl::opt 
DisableExceptions("disable-exceptions", llvm::cl::desc("Disable exception handling"), llvm::cl::init(false)); } // namespace llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) { std::string filename; std::string directory; auto pos = path.find_last_of("/"); if (pos != std::string::npos) { filename = path.substr(pos + 1); directory = path.substr(0, pos); } else { filename = path; directory = "."; } return builder->createFile(filename, directory); } std::string LLVMVisitor::getGlobalCtorName() { return MAIN_CTOR; } std::string LLVMVisitor::getNameForFunction(const Func *x) { if (isA(x) || util::hasAttribute(x, EXPORT_ATTR)) { return x->getUnmangledName(); } else if (util::hasAttribute(x, GPU_KERNEL_ATTR)) { return x->getName(); } else { return x->referenceString(); } } std::string LLVMVisitor::getNameForVar(const Var *x) { if (auto *f = cast(x)) return getNameForFunction(f); auto name = x->getName(); if (x->isExternal()) { return name; } else { // ".Lxxx" is a linker-local name, so add an underscore if needed return ((!name.empty() && name[0] == 'L') ? 
"._" : ".") + name; } } LLVMVisitor::LLVMVisitor() : util::ConstVisitor(), context(std::make_unique()), M(), B(std::make_unique>(*context)), func(nullptr), block(nullptr), value(nullptr), vars(), funcs(), coro(), loops(), trycatch(), finally(), catches(), db(), plugins(nullptr) { llvm::InitializeAllTargets(); llvm::InitializeAllTargetMCs(); llvm::InitializeAllAsmPrinters(); llvm::InitializeAllAsmParsers(); // Initialize passes auto ®istry = *llvm::PassRegistry::getPassRegistry(); llvm::initializeCore(registry); llvm::initializeScalarOpts(registry); llvm::initializeVectorization(registry); llvm::initializeIPO(registry); llvm::initializeAnalysis(registry); llvm::initializeTransformUtils(registry); llvm::initializeInstCombine(registry); llvm::initializeTarget(registry); llvm::initializeExpandLargeDivRemLegacyPassPass(registry); llvm::initializeExpandLargeFpConvertLegacyPassPass(registry); llvm::initializeExpandMemCmpLegacyPassPass(registry); llvm::initializeScalarizeMaskedMemIntrinLegacyPassPass(registry); llvm::initializeSelectOptimizePass(registry); llvm::initializeCallBrPreparePass(registry); llvm::initializeCodeGenPrepareLegacyPassPass(registry); llvm::initializeAtomicExpandLegacyPass(registry); llvm::initializeWinEHPreparePass(registry); llvm::initializeDwarfEHPrepareLegacyPassPass(registry); llvm::initializeSafeStackLegacyPassPass(registry); llvm::initializeSjLjEHPreparePass(registry); llvm::initializePreISelIntrinsicLoweringLegacyPassPass(registry); llvm::initializeGlobalMergePass(registry); llvm::initializeIndirectBrExpandLegacyPassPass(registry); llvm::initializeInterleavedLoadCombinePass(registry); llvm::initializeInterleavedAccessPass(registry); llvm::initializePostInlineEntryExitInstrumenterPass(registry); llvm::initializeUnreachableBlockElimLegacyPassPass(registry); llvm::initializeExpandReductionsPass(registry); llvm::initializeWasmEHPreparePass(registry); llvm::initializeWriteBitcodePassPass(registry); 
llvm::initializeReplaceWithVeclibLegacyPass(registry);
  llvm::initializeJMCInstrumenterPass(registry);
}

/// Creates the LLVM-level entity for a global IR variable and records it.
/// Non-global variables are ignored. Functions get an LLVM function, variables
/// of void type get the dummy void value, and every other variable gets a new
/// llvm::GlobalVariable.
/// @param var IR variable (or function) to register
void LLVMVisitor::registerGlobal(const Var *var) {
  if (!var->isGlobal())
    return;

  if (auto *f = cast(var)) {
    insertFunc(f, makeLLVMFunction(f));
  } else {
    auto *llvmType = getLLVMType(var->getType());
    if (llvmType->isVoidTy()) {
      insertVar(var, getDummyVoidValue());
    } else {
      bool external = var->isExternal();
      bool tls = var->isThreadLocal();
      // in JIT mode, or for external variables, the symbol has external linkage
      auto linkage = (db.jit || external) ? llvm::GlobalValue::ExternalLinkage
                                          : llvm::GlobalValue::PrivateLinkage;
      // external variables are declarations only, so they get no initializer
      auto *storage = new llvm::GlobalVariable(
          *M, llvmType, /*isConstant=*/false, linkage,
          external ? nullptr : llvm::Constant::getNullValue(llvmType),
          getNameForVar(var));
      insertVar(var, storage);

      if (external) {
        if (db.jit) {
          storage->setDSOLocal(true);
        } else {
          storage->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Local);
        }
      } else {
        // debug info
        auto *srcInfo = getSrcInfo(var);
        auto *file = db.getFile(srcInfo->file);
        auto *scope = db.unit;
        auto *debugVar = db.builder->createGlobalVariableExpression(
            scope, getDebugNameForVariable(var), var->getName(), file, srcInfo->line,
            getDIType(var->getType()), !var->isExternal());
        storage->addDebugInfo(debugVar);
      }

      if (tls)
        storage->setThreadLocal(true);
    }
  }
}

/// Looks up the LLVM value recorded for an IR variable.
/// Outside JIT mode, or for non-global variables, this is a plain lookup that
/// yields null when nothing is recorded. In JIT mode a global variable is
/// registered on first use, and a recorded null entry is turned into an
/// external declaration in the current module.
/// @param var IR variable to look up
/// @return LLVM value for the variable, or null if none is recorded
llvm::Value *LLVMVisitor::getVar(const Var *var) {
  auto it = vars.find(var->getId());
  if (db.jit && var->isGlobal()) {
    if (it != vars.end()) {
      if (!it->second) {
        // if value is null, it's from another module
        // see if it's in the module already
        auto name = var->getName();
        auto privName = getNameForVar(var);
        if (auto *global = M->getNamedValue(privName))
          return global;

        // not declared here yet: add an initializer-less external declaration
        auto *llvmType = getLLVMType(var->getType());
        auto *storage = new llvm::GlobalVariable(*M, llvmType, /*isConstant=*/false,
                                                 llvm::GlobalValue::ExternalLinkage,
                                                 /*Initializer=*/nullptr, privName);
        storage->setExternallyInitialized(true);

        // debug info
        auto *srcInfo = getSrcInfo(var);
        auto *file = db.getFile(srcInfo->file);
        auto *scope = db.unit;
        auto *debugVar =
db.builder->createGlobalVariableExpression( scope, getDebugNameForVariable(var), name, file, srcInfo->line, getDIType(var->getType()), /*IsLocalToUnit=*/true); storage->addDebugInfo(debugVar); insertVar(var, storage); return storage; } } else { registerGlobal(var); it = vars.find(var->getId()); return it->second; } } return (it != vars.end()) ? it->second : nullptr; } llvm::Function *LLVMVisitor::getFunc(const Func *func) { auto it = funcs.find(func->getId()); if (db.jit) { if (it != funcs.end()) { if (!it->second) { // if value is null, it's from another module // see if it's in the module already const std::string name = getNameForFunction(func); if (auto *g = M->getFunction(name)) return g; auto *funcType = cast(func->getType()); auto *returnType = getLLVMType(funcType->getReturnType()); std::vector argTypes; for (const auto &argType : *funcType) { argTypes.push_back(getLLVMType(argType)); } auto *llvmFuncType = llvm::FunctionType::get(returnType, argTypes, funcType->isVariadic()); auto *g = llvm::Function::Create(llvmFuncType, llvm::Function::ExternalLinkage, name, M.get()); insertFunc(func, g); return g; } } else { registerGlobal(func); it = funcs.find(func->getId()); return it->second; } } return (it != funcs.end()) ? it->second : nullptr; } std::unique_ptr LLVMVisitor::makeModule(llvm::LLVMContext &context, const SrcInfo *src) { auto builder = llvm::EngineBuilder(); builder.setMArch(llvm::codegen::getMArch()); builder.setMCPU(llvm::codegen::getCPUStr()); builder.setMAttrs(llvm::codegen::getFeatureList()); auto target = builder.selectTarget(); auto M = std::make_unique("codon", context); M->setTargetTriple(target->getTargetTriple().str()); M->setDataLayout(target->createDataLayout()); B = std::make_unique>(context); auto *srcInfo = src ? 
src : getDefaultSrcInfo(); M->setSourceFileName(srcInfo->file); // debug info setup db.builder = std::make_unique(*M); auto *file = db.getFile(srcInfo->file); db.unit = db.builder->createCompileUnit(llvm::dwarf::DW_LANG_C, file, ("codon version " CODON_VERSION), !db.debug, db.flags, /*RV=*/0); M->addModuleFlag(llvm::Module::Warning, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); // darwin only supports dwarf2 if (llvm::Triple(M->getTargetTriple()).isOSDarwin()) { M->addModuleFlag(llvm::Module::Warning, "Dwarf Version", 2); } return M; } void LLVMVisitor::clearLLVMData() { B = {}; func = nullptr; block = nullptr; value = nullptr; for (auto it = funcs.begin(); it != funcs.end();) { if (it->second && it->second->hasPrivateLinkage()) { it = funcs.erase(it); } else { it->second = nullptr; ++it; } } for (auto it = vars.begin(); it != vars.end();) { if (it->second && !llvm::isa(it->second)) { it = vars.erase(it); } else { it->second = nullptr; ++it; } } coro.reset(); loops.clear(); trycatch.clear(); finally.clear(); catches.clear(); db.reset(); context = {}; M = {}; } std::pair, std::unique_ptr> LLVMVisitor::takeModule(Module *module, const SrcInfo *src) { // process any new functions or globals if (module) { std::unordered_set funcsToProcess; for (auto *var : *module) { auto id = var->getId(); if (auto *func = cast(var)) { if (funcs.find(id) != funcs.end()) continue; else funcsToProcess.insert(id); } else { if (vars.find(id) != vars.end()) continue; } registerGlobal(var); } for (auto *var : *module) { if (auto *func = cast(var)) { if (funcsToProcess.find(func->getId()) != funcsToProcess.end()) { process(func); } } } } db.builder->finalize(); auto currentContext = std::move(context); auto currentModule = std::move(M); // reset all LLVM fields/data -- they are owned by the context clearLLVMData(); context = std::make_unique(); M = makeModule(*context, src); return {std::move(currentModule), std::move(currentContext)}; } void LLVMVisitor::setDebugInfoForNode(const 
Node *x) {
  if (x && func) {
    // attribute subsequently emitted instructions to x's line/column within
    // the subprogram of the function currently being generated
    auto *srcInfo = getSrcInfo(x);
    B->SetCurrentDebugLocation(llvm::DILocation::get(
        *context, srcInfo->line, srcInfo->col, func->getSubprogram()));
  } else {
    B->SetCurrentDebugLocation(llvm::DebugLoc());
  }
}

/// Generates code for a node: updates the debug location, then dispatches the
/// node to this visitor.
/// @param x node to process
void LLVMVisitor::process(const Node *x) {
  setDebugInfoForNode(x);
  x->accept(*this);
}

/// Writes the module as textual LLVM IR to `filename` without running the
/// optimization pipeline.
void LLVMVisitor::dump(const std::string &filename) { writeToLLFile(filename, false); }

/// Finalizes debug info and runs optimize() on the current module.
void LLVMVisitor::runLLVMPipeline() {
  db.builder->finalize();
  optimize(M.get(), db.debug, db.jit, plugins);
}

/// Optimizes the module and emits it to `filename` as an object file or as
/// assembly.
/// @param filename output path
/// @param pic passed through to getTargetMachine()
/// @param assembly emit an assembly file instead of an object file
void LLVMVisitor::writeToObjectFile(const std::string &filename, bool pic,
                                    bool assembly) {
  // must happen before the pipeline runs
  if (GlobalCTOR == GlobalCTORMode::Yes)
    setupGlobalCtor();
  runLLVMPipeline();

  std::error_code err;
  auto out = std::make_unique(filename, err, llvm::sys::fs::OF_None);
  if (err)
    compilationError(err.message());
  auto *os = &out->os();

  auto machine = getTargetMachine(M.get(), /*setFunctionAttributes=*/false, pic);
  auto *mmiwp = new llvm::MachineModuleInfoWrapperPass(machine.get());
  llvm::legacy::PassManager pm;

  llvm::TargetLibraryInfoImpl tlii(llvm::Triple(M->getTargetTriple()));
  pm.add(new llvm::TargetLibraryInfoWrapperPass(tlii));
  // addPassesToEmitFile() returns true on failure
  if (machine->addPassesToEmitFile(pm, *os, nullptr,
                                   assembly ?
llvm::CodeGenFileType::AssemblyFile : llvm::CodeGenFileType::ObjectFile, /*DisableVerify=*/true, mmiwp)) seqassertn(false, "could not add passes"); const_cast(machine->getObjFileLowering()) ->Initialize(mmiwp->getMMI().getContext(), *machine); pm.run(*M); out->keep(); } void LLVMVisitor::writeToBitcodeFile(const std::string &filename) { runLLVMPipeline(); std::error_code err; llvm::raw_fd_ostream stream(filename, err, llvm::sys::fs::OF_None); llvm::WriteBitcodeToFile(*M, stream); if (err) { compilationError(err.message()); } } void LLVMVisitor::writeToLLFile(const std::string &filename, bool optimize) { if (GlobalCTOR == GlobalCTORMode::Yes) setupGlobalCtor(); if (optimize) runLLVMPipeline(); auto fo = fopen(filename.c_str(), "w"); llvm::raw_fd_ostream fout(fileno(fo), true); fout << *M; fout.close(); } namespace { void executeCommand(const std::vector &args) { std::vector cArgs; for (auto &arg : args) { cArgs.push_back(arg.c_str()); } LOG_USER("Executing '{}'", fmt::join(cArgs, " ")); cArgs.push_back(nullptr); if (fork() == 0) { int status = execvp(cArgs[0], (char *const *)&cArgs[0]); exit(status); } else { int status; if (wait(&status) < 0) { compilationError("process for '" + args[0] + "' encountered an error in wait"); } if (WEXITSTATUS(status) != 0) { compilationError("process for '" + args[0] + "' exited with status " + std::to_string(WEXITSTATUS(status))); } } } } // namespace void LLVMVisitor::setupGlobalCtor() { const std::string llvmCtor = "llvm.global_ctors"; if (M->getNamedValue(llvmCtor)) return; auto *main = M->getFunction(MAIN_UNCLASH); if (!main) { main = M->getFunction("main"); if (!main) return; main->setName(MAIN_UNCLASH); // avoid clash with other main main->setLinkage(llvm::GlobalValue::PrivateLinkage); } auto *ctorFuncTy = llvm::FunctionType::get(B->getVoidTy(), {}, /*isVarArg=*/false); auto *ctorEntryTy = llvm::StructType::get(B->getInt32Ty(), ctorFuncTy->getPointerTo(), B->getPtrTy()); auto *ctorArrayTy = llvm::ArrayType::get(ctorEntryTy, 
1);

  // the constructor function: calls main with (0, null) and returns
  auto *ctor =
      cast(M->getOrInsertFunction(MAIN_CTOR, ctorFuncTy).getCallee());
  ctor->setLinkage(llvm::GlobalValue::PrivateLinkage);
  auto *entry = llvm::BasicBlock::Create(*context, "entry", ctor);
  B->SetInsertPoint(entry);
  B->CreateCall(
      {main->getFunctionType(), main},
      {B->getInt32(0), llvm::ConstantPointerNull::get(B->getPtrTy()->getPointerTo())});
  B->CreateRetVoid();

  // single {priority, function, data} entry of the llvm.global_ctors array
  const int priority = 65535; // default
  auto *ctorEntry = llvm::ConstantStruct::get(
      ctorEntryTy,
      {B->getInt32(priority), ctor, llvm::ConstantPointerNull::get(B->getPtrTy())});
  new llvm::GlobalVariable(*M, ctorArrayTy,
                           /*isConstant=*/true, llvm::GlobalValue::AppendingLinkage,
                           llvm::ConstantArray::get(ctorArrayTy, {ctorEntry}), llvmCtor);
}

/// Compiles the module to an object file and links it into an executable or
/// shared library by invoking "g++". The temporary object file
/// (`filename` + ".o") is removed afterwards.
/// @param filename output path
/// @param argv0 used to locate the running executable; library directories
///              are searched relative to its parent directory
/// @param library build a shared library instead of an executable
/// @param libs additional libraries, each passed as "-l<lib>"
/// @param lflags extra linker flags, split on spaces
void LLVMVisitor::writeToExecutable(const std::string &filename,
                                    const std::string &argv0, bool library,
                                    const std::vector &libs,
                                    const std::string &lflags) {
  if (library && GlobalCTOR != GlobalCTORMode::No)
    setupGlobalCtor();

  const std::string objFile = filename + ".o";
  writeToObjectFile(objFile, /*pic=*/library);

  // Candidate library directories relative to the executable; only those that
  // exist are used, falling back to the executable's own directory.
  const std::string base = ast::Filesystem::executable_path(argv0.c_str());
  auto path = llvm::SmallString<128>(llvm::sys::path::parent_path(base));

  std::vector relatives = {"../lib", "../lib/codon"};
  std::vector rpaths;

  for (const auto &rel : relatives) {
    auto newPath = path;
    llvm::sys::path::append(newPath, rel);
    llvm::sys::path::remove_dots(newPath, /*remove_dot_dot=*/true);
    if (llvm::sys::fs::exists(newPath)) {
      rpaths.push_back(std::string(newPath));
    }
  }

  if (rpaths.empty()) {
    rpaths.push_back(std::string(path));
  }

  std::vector command = {"g++"};

  // Avoid "argument unused during compilation" warning
  command.push_back("-Wno-unused-command-line-argument");
  // MUST go before -llib to compile on Linux
  command.push_back(objFile);

  if (library)
    command.push_back("-shared");

  for (const auto &rpath : rpaths) {
    if (!rpath.empty()) {
      command.push_back("-L" + rpath);
      command.push_back("-Wl,-rpath," + rpath);
    }
  }

  // library search paths for plugins that ship a dynamic library
  if (plugins) {
    for (auto *plugin : *plugins) {
      auto dylibPath = plugin->info.dylibPath;
      if (dylibPath.empty())
        continue;

      llvm::SmallString<128> rpath0 = llvm::sys::path::parent_path(dylibPath);
      llvm::sys::fs::make_absolute(rpath0);
      llvm::StringRef rpath = rpath0.str();
      if (!rpath.empty()) {
        command.push_back("-L" + rpath.str());
        command.push_back("-Wl,-rpath," + rpath.str());
      }
    }
  }

  for (const auto &lib : libs) {
    command.push_back("-l" + lib);
  }

  // Link each plugin: with its explicit link arguments if it has any,
  // otherwise against its dynamic library ("libfoo" becomes "-lfoo").
  if (plugins) {
    for (auto *plugin : *plugins) {
      if (plugin->info.linkArgs.empty()) {
        auto dylibPath = plugin->info.dylibPath;
        if (dylibPath.empty())
          continue;
        auto stem = llvm::sys::path::stem(dylibPath);
        if (stem.starts_with("lib"))
          stem = stem.substr(3);
        command.push_back("-l" + stem.str());
      } else {
        for (auto &l : plugin->info.linkArgs)
          command.push_back(l);
      }
    }
  }

  std::vector extraArgs = {
      "-lcodonrt", "-lomp", "-lpthread", "-ldl", "-lz", "-lm", "-lc", "-o", filename};

  for (const auto &arg : extraArgs) {
    command.push_back(arg);
  }

  // NOTE(review): userFlags looks like it is constructed with 16 (empty)
  // elements before split() appends to it; the empty() check below skips them.
  llvm::SmallVector userFlags(16);
  llvm::StringRef(lflags).split(userFlags, " ", /*MaxSplit=*/-1, /*KeepEmpty=*/false);

  for (const auto &uflag : userFlags) {
    if (!uflag.empty())
      command.push_back(uflag.str());
  }

  // Avoid "relocation R_X86_64_32 against `.bss' can not be used when making a PIE
  // object" complaints by gcc when it is built with --enable-default-pie
  if (!library)
    command.push_back("-no-pie");

  executeCommand(command);

#if __APPLE__
  if (db.debug)
    executeCommand({"dsymutil", filename});
#endif

  llvm::sys::fs::remove(objFile);
}

namespace {
// Method flag values for PyMethodDef entries; see
// https://github.com/python/cpython/blob/main/Include/methodobject.h
constexpr int PYEXT_METH_VARARGS = 0x0001;
constexpr int PYEXT_METH_KEYWORDS = 0x0002;
constexpr int PYEXT_METH_NOARGS = 0x0004;
constexpr int PYEXT_METH_O = 0x0008;
constexpr int PYEXT_METH_CLASS = 0x0010;
constexpr int PYEXT_METH_STATIC = 0x0020;
constexpr int PYEXT_METH_COEXIST = 0x0040;
constexpr int PYEXT_METH_FASTCALL = 0x0080;
constexpr int PYEXT_METH_METHOD = 0x0200;

//
https://github.com/python/cpython/blob/main/Include/modsupport.h
constexpr int PYEXT_PYTHON_ABI_VERSION = 1013;

// https://github.com/python/cpython/blob/main/Include/descrobject.h
constexpr int PYEXT_READONLY = 1;
} // namespace

/// Creates a wrapper "<name>.tc_wrap" with the same signature as `func` that
/// invokes `func` and, if the call unwinds, reports the failure through
/// PyErr_SetString() instead of letting it propagate. On that path the
/// wrapper returns -1 when its return type is i32 and a null/zero value of
/// the return type otherwise.
/// @param func function to wrap
/// @return the wrapper function
llvm::Function *LLVMVisitor::createPyTryCatchWrapper(llvm::Function *func) {
  auto *wrap =
      cast(M->getOrInsertFunction((func->getName() + ".tc_wrap").str(),
                                  func->getFunctionType())
               .getCallee());
  wrap->setPersonalityFn(llvm::cast(makePersonalityFunc().getCallee()));
  auto *entry = llvm::BasicBlock::Create(*context, "entry", wrap);
  auto *normal = llvm::BasicBlock::Create(*context, "normal", wrap);
  auto *unwind = llvm::BasicBlock::Create(*context, "unwind", wrap);

  // entry: forward all wrapper arguments to the wrapped function
  B->SetInsertPoint(entry);
  std::vector args;
  for (auto &arg : wrap->args()) {
    args.push_back(&arg);
  }
  auto *result = B->CreateInvoke(func, normal, unwind, args);

  // normal: pass the result through
  B->SetInsertPoint(normal);
  B->CreateRet(result);

  // unwind: landing pad for the invoke above
  B->SetInsertPoint(unwind);
  auto *caughtResult = B->CreateLandingPad(getPadType(), 1);
  caughtResult->setCleanup(true);
  caughtResult->addClause(getTypeIdxVar(nullptr));
  auto *unwindType = llvm::StructType::get(B->getInt64Ty()); // header only
  auto *unwindException = B->CreateExtractValue(caughtResult, 0);
  // NOTE(review): unwindExceptionClass is not used further in this function
  auto *unwindExceptionClass = B->CreateLoad(
      B->getInt64Ty(), B->CreateStructGEP(unwindType, unwindException, 0));

  // locate the exception object at seq_exc_offset() bytes from the unwind
  // exception and load its header
  unwindException = B->CreateExtractValue(caughtResult, 0);
  auto *excVal = B->CreateConstGEP1_64(B->getInt8Ty(), unwindException,
                                       (uint64_t)seq_exc_offset());
  auto *loadedExc = B->CreateLoad(B->getPtrTy(), excVal);
  auto *strType = llvm::StructType::get(B->getInt64Ty(), B->getPtrTy());
  auto *excHeader =
      llvm::StructType::get(strType, strType, strType, B->getInt64Ty(),
                            B->getInt64Ty(), B->getPtrTy(), B->getPtrTy());
  auto *header = B->CreateLoad(excHeader, B->CreateLoad(B->getPtrTy(), loadedExc));
  // field 0 is used as the message (length, pointer), field 5 as the Python
  // exception type
  auto *msg = B->CreateExtractValue(header, 0);
  auto *msgLen = B->CreateExtractValue(msg, 0);
  auto *msgPtr = B->CreateExtractValue(msg, 1);
  auto *pyType = B->CreateExtractValue(header, 5);

  // copy msg into new null-terminated buffer
  auto alloc = makeAllocFunc(/*atomic=*/true);
  auto *buf = B->CreateCall(alloc, B->CreateAdd(msgLen, B->getInt64(1)));
  B->CreateMemCpy(buf, {}, msgPtr, {}, msgLen);
  auto *last = B->CreateInBoundsGEP(B->getInt8Ty(), buf, msgLen);
  B->CreateStore(B->getInt8(0), last);

  auto *pyErrSetString = llvm::cast(
      M->getOrInsertFunction("PyErr_SetString", B->getVoidTy(), B->getPtrTy(),
                             B->getPtrTy())
          .getCallee());

  // declare PyExc_RuntimeError if the module does not know it yet
  const std::string pyExcRuntimeErrorName = "PyExc_RuntimeError";
  llvm::Value *pyExcRuntimeError = M->getNamedValue(pyExcRuntimeErrorName);
  if (!pyExcRuntimeError) {
    auto *pyExcRuntimeErrorVar = new llvm::GlobalVariable(
        *M, B->getPtrTy(), /*isConstant=*/false, llvm::GlobalValue::ExternalLinkage,
        /*Initializer=*/nullptr, pyExcRuntimeErrorName);
    pyExcRuntimeErrorVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
    pyExcRuntimeError = pyExcRuntimeErrorVar;
  }
  pyExcRuntimeError = B->CreateLoad(B->getPtrTy(), pyExcRuntimeError);

  // use the exception's own Python type if it has one, else RuntimeError
  auto *havePyType =
      B->CreateICmpNE(pyType, llvm::ConstantPointerNull::get(B->getPtrTy()));
  B->CreateCall(pyErrSetString,
                {B->CreateSelect(havePyType, pyType, pyExcRuntimeError), buf});

  auto *retType = wrap->getReturnType();
  if (retType == B->getInt32Ty()) {
    B->CreateRet(B->getInt32(-1));
  } else {
    B->CreateRet(llvm::Constant::getNullValue(retType));
  }

  return wrap;
}

void LLVMVisitor::writeToPythonExtension(const PyModule &pymod,
                                         const std::string &filename) {
  // Setup LLVM types & constants
  auto *i64 = B->getInt64Ty();
  auto *i32 = B->getInt32Ty();
  auto *i8 = B->getInt8Ty();
  auto *ptr = B->getPtrTy();

  auto *pyMethodDefType = llvm::StructType::create("PyMethodDef", ptr, ptr, i32, ptr);
  auto *pyObjectType = llvm::StructType::create("PyObject", i64, ptr);
  auto *pyVarObjectType = llvm::StructType::create("PyVarObject", pyObjectType, i64);
  auto *pyModuleDefBaseType =
      llvm::StructType::create("PyMethodDefBase", pyObjectType, ptr, i64, ptr);
  auto *pyModuleDefType =
llvm::StructType::create("PyModuleDef", pyModuleDefBaseType, ptr, ptr, i64, pyMethodDefType->getPointerTo(), ptr, ptr, ptr, ptr); auto *pyMemberDefType = llvm::StructType::create("PyMemberDef", ptr, i32, i64, i32, ptr); auto *pyGetSetDefType = llvm::StructType::create("PyGetSetDef", ptr, ptr, ptr, ptr, ptr); std::vector pyNumberMethodsFields(36, ptr); auto *pyNumberMethodsType = llvm::StructType::create(*context, pyNumberMethodsFields, "PyNumberMethods"); std::vector pySequenceMethodsFields(10, ptr); auto *pySequenceMethodsType = llvm::StructType::create(*context, pySequenceMethodsFields, "PySequenceMethods"); std::vector pyMappingMethodsFields(3, ptr); auto *pyMappingMethodsType = llvm::StructType::create(*context, pyMappingMethodsFields, "PyMappingMethods"); std::vector pyAsyncMethodsFields(4, ptr); auto *pyAsyncMethodsType = llvm::StructType::create(*context, pyAsyncMethodsFields, "PyAsyncMethods"); auto *pyBufferProcsType = llvm::StructType::create("PyBufferProcs", ptr, ptr); auto *pyTypeObjectType = llvm::StructType::create( "PyTypeObject", pyVarObjectType, ptr, i64, i64, ptr, i64, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, i64, ptr, ptr, ptr, ptr, i64, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, i64, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, i32, ptr, ptr, i8); auto *zero64 = B->getInt64(0); auto *zero32 = B->getInt32(0); auto *zero8 = B->getInt8(0); auto *null = llvm::Constant::getNullValue(ptr); auto *pyTypeType = new llvm::GlobalVariable(*M, ptr, /*isConstant=*/false, llvm::GlobalValue::ExternalLinkage, /*Initializer=*/nullptr, "PyType_Type"); auto allocUncollectable = llvm::cast( makeAllocFunc(/*atomic=*/false, /*uncollectable=*/true).getCallee()); auto free = llvm::cast(makeFreeFunc().getCallee()); // Helpers auto pyFuncWrap = [&](Func *func, bool wrap) -> llvm::Constant * { if (!func) return null; auto llvmName = getNameForFunction(func); auto *llvmFunc = M->getFunction(llvmName); seqassertn(llvmFunc, "function {} not 
found in LLVM module", llvmName); if (wrap) llvmFunc = createPyTryCatchWrapper(llvmFunc); return llvmFunc; }; auto pyFunc = [&](Func *func) -> llvm::Constant * { return pyFuncWrap(func, true); }; auto pyString = [&](const std::string &str) -> llvm::Constant * { if (str.empty()) return null; auto *var = new llvm::GlobalVariable( *M, llvm::ArrayType::get(i8, str.length() + 1), /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, llvm::ConstantDataArray::getString(*context, str), ".pyext_str"); var->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); return var; }; auto pyFunctions = [&](const std::vector &functions) -> llvm::Constant * { if (functions.empty()) return null; std::vector pyMethods; for (auto &pyfunc : functions) { int flag = 0; if (pyfunc.keywords) { flag = PYEXT_METH_FASTCALL | PYEXT_METH_KEYWORDS; } else { switch (pyfunc.nargs) { case 0: flag = PYEXT_METH_NOARGS; break; case 1: flag = PYEXT_METH_O; break; default: flag = PYEXT_METH_FASTCALL; break; } } switch (pyfunc.type) { case PyFunction::CLASS: flag |= PYEXT_METH_CLASS; break; case PyFunction::STATIC: flag |= PYEXT_METH_STATIC; break; default: break; } if (pyfunc.coexist) flag |= PYEXT_METH_COEXIST; pyMethods.push_back(llvm::ConstantStruct::get( pyMethodDefType, pyString(pyfunc.name), pyFunc(pyfunc.func), B->getInt32(flag), pyString(pyfunc.doc))); } pyMethods.push_back( llvm::ConstantStruct::get(pyMethodDefType, null, null, zero32, null)); auto *pyMethodDefArrayType = llvm::ArrayType::get(pyMethodDefType, pyMethods.size()); auto *pyMethodDefArray = new llvm::GlobalVariable( *M, pyMethodDefArrayType, /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, llvm::ConstantArray::get(pyMethodDefArrayType, pyMethods), ".pyext_methods"); return pyMethodDefArray; }; auto pyMembers = [&](const std::vector &members, llvm::StructType *type) -> llvm::Constant * { if (members.empty()) return null; std::vector pyMemb; for (auto &memb : members) { // Calculate offset by creating const GEP into null ptr 
std::vector indexes = {zero64, B->getInt32(1)}; for (auto idx : memb.indexes) { indexes.push_back(B->getInt32(idx)); } auto offset = llvm::ConstantExpr::getPtrToInt( llvm::ConstantExpr::getGetElementPtr(type, null, indexes), i64); pyMemb.push_back(llvm::ConstantStruct::get( pyMemberDefType, pyString(memb.name), B->getInt32(memb.type), offset, B->getInt32(memb.readonly ? PYEXT_READONLY : 0), pyString(memb.doc))); } pyMemb.push_back( llvm::ConstantStruct::get(pyMemberDefType, null, zero32, zero64, zero32, null)); auto *pyMemberDefArrayType = llvm::ArrayType::get(pyMemberDefType, pyMemb.size()); auto *pyMemberDefArray = new llvm::GlobalVariable( *M, pyMemberDefArrayType, /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, llvm::ConstantArray::get(pyMemberDefArrayType, pyMemb), ".pyext_members"); return pyMemberDefArray; }; auto pyGetSet = [&](const std::vector &getset) -> llvm::Constant * { if (getset.empty()) return null; std::vector pyGS; for (auto &gs : getset) { pyGS.push_back(llvm::ConstantStruct::get(pyGetSetDefType, pyString(gs.name), pyFunc(gs.get), pyFunc(gs.set), pyString(gs.doc), null)); } pyGS.push_back( llvm::ConstantStruct::get(pyGetSetDefType, null, null, null, null, null)); auto *pyGetSetDefArrayType = llvm::ArrayType::get(pyGetSetDefType, pyGS.size()); auto *pyGetSetDefArray = new llvm::GlobalVariable( *M, pyGetSetDefArrayType, /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, llvm::ConstantArray::get(pyGetSetDefArrayType, pyGS), ".pyext_getset"); return pyGetSetDefArray; }; // Construct PyModuleDef array auto *pyObjectConst = llvm::ConstantStruct::get(pyObjectType, B->getInt64(1), null); auto *pyModuleDefBaseConst = llvm::ConstantStruct::get(pyModuleDefBaseType, pyObjectConst, null, zero64, null); auto *pyModuleDef = llvm::ConstantStruct::get( pyModuleDefType, pyModuleDefBaseConst, pyString(pymod.name), pyString(pymod.doc), B->getInt64(-1), pyFunctions(pymod.functions), null, null, null, null); auto *pyModuleVar = new 
llvm::GlobalVariable(*M, pyModuleDef->getType(), /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, pyModuleDef, ".pyext_module"); std::unordered_map typeVars; for (auto &pytype : pymod.types) { std::vector numberSlots = { pyFunc(pytype.add), // nb_add pyFunc(pytype.sub), // nb_subtract pyFunc(pytype.mul), // nb_multiply pyFunc(pytype.mod), // nb_remainder pyFunc(pytype.divmod), // nb_divmod pyFunc(pytype.pow), // nb_power pyFunc(pytype.neg), // nb_negative pyFunc(pytype.pos), // nb_positive pyFunc(pytype.abs), // nb_absolute pyFunc(pytype.bool_), // nb_bool pyFunc(pytype.invert), // nb_invert pyFunc(pytype.lshift), // nb_lshift pyFunc(pytype.rshift), // nb_rshift pyFunc(pytype.and_), // nb_and pyFunc(pytype.xor_), // nb_xor pyFunc(pytype.or_), // nb_or pyFunc(pytype.int_), // nb_int null, // nb_reserved pyFunc(pytype.float_), // nb_float pyFunc(pytype.iadd), // nb_inplace_add pyFunc(pytype.isub), // nb_inplace_subtract pyFunc(pytype.imul), // nb_inplace_multiply pyFunc(pytype.imod), // nb_inplace_remainder pyFunc(pytype.ipow), // nb_inplace_power pyFunc(pytype.ilshift), // nb_inplace_lshift pyFunc(pytype.irshift), // nb_inplace_rshift pyFunc(pytype.iand), // nb_inplace_and pyFunc(pytype.ixor), // nb_inplace_xor pyFunc(pytype.ior), // nb_inplace_or pyFunc(pytype.floordiv), // nb_floor_divide pyFunc(pytype.truediv), // nb_true_divide pyFunc(pytype.ifloordiv), // nb_inplace_floor_divide pyFunc(pytype.itruediv), // nb_inplace_true_divide pyFunc(pytype.index), // nb_index pyFunc(pytype.matmul), // nb_matrix_multiply pyFunc(pytype.imatmul), // nb_inplace_matrix_multiply }; std::vector sequenceSlots = { pyFunc(pytype.len), // sq_length null, // sq_concat null, // sq_repeat null, // sq_item null, // was_sq_slice null, // sq_ass_item null, // was_sq_ass_slice pyFunc(pytype.contains), // sq_contains null, // sq_inplace_concat null, // sq_inplace_repeat }; std::vector mappingSlots = { null, // mp_length pyFunc(pytype.getitem), // mp_subscript pyFunc(pytype.setitem), // 
mp_ass_subscript }; bool needNumberSlots = std::find_if(numberSlots.begin(), numberSlots.end(), [&](auto *v) { return v != null; }) != numberSlots.end(); bool needSequenceSlots = std::find_if(sequenceSlots.begin(), sequenceSlots.end(), [&](auto *v) { return v != null; }) != sequenceSlots.end(); bool needMappingSlots = std::find_if(mappingSlots.begin(), mappingSlots.end(), [&](auto *v) { return v != null; }) != mappingSlots.end(); auto *numberSlotsConst = null; auto *sequenceSlotsConst = null; auto *mappingSlotsConst = null; if (needNumberSlots) { auto *pyNumberSlotsVar = new llvm::GlobalVariable( *M, pyNumberMethodsType, /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, llvm::ConstantStruct::get(pyNumberMethodsType, numberSlots), ".pyext_number_slots." + pytype.name); numberSlotsConst = pyNumberSlotsVar; } if (needSequenceSlots) { auto *pySequenceSlotsVar = new llvm::GlobalVariable( *M, pySequenceMethodsType, /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, llvm::ConstantStruct::get(pySequenceMethodsType, sequenceSlots), ".pyext_sequence_slots." + pytype.name); sequenceSlotsConst = pySequenceSlotsVar; } if (needMappingSlots) { auto *pyMappingSlotsVar = new llvm::GlobalVariable( *M, pyMappingMethodsType, /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, llvm::ConstantStruct::get(pyMappingMethodsType, mappingSlots), ".pyext_mapping_slots." + pytype.name); mappingSlotsConst = pyMappingSlotsVar; } auto *refType = cast(pytype.type); if (refType) { seqassertn(!refType->isPolymorphic(), "Python extension types cannot be polymorphic"); } auto *llvmType = getLLVMType(pytype.type); auto *objectType = llvm::StructType::get(pyObjectType, llvmType); auto codonSize = refType ? 
M->getDataLayout().getTypeAllocSize(getLLVMType(refType->getContents())) : 0; auto pySize = M->getDataLayout().getTypeAllocSize(objectType); auto *alloc = llvm::cast( M->getOrInsertFunction(pytype.name + ".py_alloc", ptr, ptr, i64).getCallee()); { auto *entry = llvm::BasicBlock::Create(*context, "entry", alloc); B->SetInsertPoint(entry); auto *pythonObject = B->CreateCall(allocUncollectable, B->getInt64(pySize)); auto *header = B->CreateInsertValue( llvm::ConstantStruct::get(pyObjectType, B->getInt64(1), null), alloc->arg_begin(), 1); B->CreateStore(header, pythonObject); if (refType) { auto *codonObject = B->CreateCall( makeAllocFunc(refType->getContents()->isAtomic()), B->getInt64(codonSize)); B->CreateStore(codonObject, B->CreateGEP(objectType, pythonObject, {zero64, B->getInt32(1)})); } B->CreateRet(pythonObject); } auto *delFn = pyFuncWrap(pytype.del, /*wrap=*/false); auto *dealloc = llvm::cast( M->getOrInsertFunction(pytype.name + ".py_dealloc", B->getVoidTy(), ptr) .getCallee()); { auto *obj = dealloc->arg_begin(); auto *entry = llvm::BasicBlock::Create(*context, "entry", dealloc); B->SetInsertPoint(entry); if (delFn != null) B->CreateCall(llvm::FunctionCallee(dealloc->getFunctionType(), delFn), obj); B->CreateCall(free, obj); B->CreateRetVoid(); } auto *pyNew = llvm::cast( M->getOrInsertFunction("PyType_GenericNew", ptr, ptr, ptr, ptr).getCallee()); std::vector typeSlots = { llvm::ConstantStruct::get( pyVarObjectType, llvm::ConstantStruct::get(pyObjectType, B->getInt64(1), pyTypeType), zero64), // PyObject_VAR_HEAD pyString(pymod.name + "." 
+ pytype.name), // tp_name B->getInt64(pySize), // tp_basicsize zero64, // tp_itemsize dealloc, // tp_dealloc zero64, // tp_vectorcall_offset null, // tp_getattr null, // tp_setattr null, // tp_as_async pyFunc(pytype.repr), // tp_repr numberSlotsConst, // tp_as_number sequenceSlotsConst, // tp_as_sequence mappingSlotsConst, // tp_as_mapping pyFunc(pytype.hash), // tp_hash pyFunc(pytype.call), // tp_call pyFunc(pytype.str), // tp_str null, // tp_getattro null, // tp_setattro null, // tp_as_buffer zero64, // tp_flags pyString(pytype.doc), // tp_doc null, // tp_traverse null, // tp_clear pyFunc(pytype.cmp), // tp_richcompare zero64, // tp_weaklistoffset pyFunc(pytype.iter), // tp_iter pyFunc(pytype.iternext), // tp_iternext pyFunctions(pytype.methods), // tp_methods pyMembers(pytype.members, objectType), // tp_members pyGetSet(pytype.getset), // tp_getset null, // tp_base null, // tp_dict null, // tp_descr_get null, // tp_descr_set zero64, // tp_dictoffset pyFunc(pytype.init), // tp_init alloc, // tp_alloc pyNew, // tp_new free, // tp_free null, // tp_is_gc null, // tp_bases null, // tp_mro null, // tp_cache null, // tp_subclasses null, // tp_weaklist null, // tp_del zero32, // tp_version_tag free, // tp_finalize null, // tp_vectorcall B->getInt8(0), // tp_watched }; auto *pyTypeObjectVar = new llvm::GlobalVariable( *M, pyTypeObjectType, /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, llvm::ConstantStruct::get(pyTypeObjectType, typeSlots), ".pyext_type." 
+ pytype.name); if (pytype.typePtrHook) { auto *hook = llvm::cast(pyFuncWrap(pytype.typePtrHook, false)); for (auto it = llvm::inst_begin(hook), end = llvm::inst_end(hook); it != end; ++it) { if (auto *ret = llvm::dyn_cast(&*it)) ret->setOperand(0, pyTypeObjectVar); } } typeVars.emplace(pytype.type, pyTypeObjectVar); } // Construct initialization hook auto pyIncRef = llvm::cast( M->getOrInsertFunction("Py_IncRef", B->getVoidTy(), ptr).getCallee()); pyIncRef->setDoesNotThrow(); auto pyDecRef = llvm::cast( M->getOrInsertFunction("Py_DecRef", B->getVoidTy(), ptr).getCallee()); pyDecRef->setDoesNotThrow(); auto *pyModuleCreate = llvm::cast( M->getOrInsertFunction("PyModule_Create2", ptr, ptr, i32).getCallee()); pyModuleCreate->setDoesNotThrow(); auto *pyTypeReady = llvm::cast( M->getOrInsertFunction("PyType_Ready", i32, ptr).getCallee()); pyTypeReady->setDoesNotThrow(); auto *pyModuleAddObject = llvm::cast( M->getOrInsertFunction("PyModule_AddObject", i32, ptr, ptr, ptr).getCallee()); pyModuleAddObject->setDoesNotThrow(); auto *pyModuleInit = llvm::cast( M->getOrInsertFunction("PyInit_" + pymod.name, ptr).getCallee()); auto *block = llvm::BasicBlock::Create(*context, "entry", pyModuleInit); B->SetInsertPoint(block); if (auto *main = M->getFunction("main")) { main->setName(MAIN_UNCLASH); B->CreateCall({main->getFunctionType(), main}, {zero32, null}); } // Set base types for (auto &pytype : pymod.types) { if (pytype.base) { auto subcIt = typeVars.find(pytype.type); auto baseIt = typeVars.find(pytype.base->type); seqassertn(subcIt != typeVars.end() && baseIt != typeVars.end(), "types not found"); // 30 is the index of tp_base B->CreateStore(baseIt->second, B->CreateConstInBoundsGEP2_64( pyTypeObjectType, subcIt->second, 0, 30)); } } // Call PyType_Ready for (auto &pytype : pymod.types) { auto it = typeVars.find(pytype.type); seqassertn(it != typeVars.end(), "type not found"); auto *typeVar = it->second; auto *fail = llvm::BasicBlock::Create(*context, "failure", 
pyModuleInit);
    block = llvm::BasicBlock::Create(*context, "success", pyModuleInit);
    auto *status = B->CreateCall(pyTypeReady, typeVar);
    // a negative status from PyType_Ready takes the failure branch
    B->CreateCondBr(B->CreateICmpSLT(status, zero32), fail, block);

    B->SetInsertPoint(fail);
    B->CreateRet(null);

    B->SetInsertPoint(block);
  }

  // Create module
  auto *mod = B->CreateCall(pyModuleCreate,
                            {pyModuleVar, B->getInt32(PYEXT_PYTHON_ABI_VERSION)});
  auto *fail = llvm::BasicBlock::Create(*context, "failure", pyModuleInit);
  block = llvm::BasicBlock::Create(*context, "success", pyModuleInit);
  B->CreateCondBr(B->CreateICmpEQ(mod, null), fail, block);

  B->SetInsertPoint(fail);
  B->CreateRet(null);

  B->SetInsertPoint(block);

  // Add types
  for (auto &pytype : pymod.types) {
    auto it = typeVars.find(pytype.type);
    seqassertn(it != typeVars.end(), "type not found");
    auto *typeVar = it->second;
    B->CreateCall(pyIncRef, typeVar);
    auto *status =
        B->CreateCall(pyModuleAddObject, {mod, pyString(pytype.name), typeVar});
    fail = llvm::BasicBlock::Create(*context, "failure", pyModuleInit);
    block = llvm::BasicBlock::Create(*context, "success", pyModuleInit);
    B->CreateCondBr(B->CreateICmpSLT(status, zero32), fail, block);

    // on failure: undo the Py_IncRef above, release the module, return `null`
    B->SetInsertPoint(fail);
    B->CreateCall(pyDecRef, typeVar);
    B->CreateCall(pyDecRef, mod);
    B->CreateRet(null);

    B->SetInsertPoint(block);
  }
  B->CreateRet(mod);

  writeToObjectFile(filename);
}

// Writes the current module to `filename`, choosing the writer from the file
// extension:
//   .ll              -> writeToLLFile
//   .bc              -> writeToBitcodeFile
//   .o / .obj        -> writeToObjectFile
//   .s / .S / .asm   -> writeToObjectFile with pic=false, assembly=true
//   .so / .dylib     -> writeToExecutable with library=true
//   anything else    -> writeToExecutable with library=false
// `argv0`, `libs` and `lflags` are only forwarded to writeToExecutable.
void LLVMVisitor::compile(const std::string &filename, const std::string &argv0,
                          const std::vector &libs, const std::string &lflags) {
  llvm::StringRef f(filename);
  if (f.ends_with(".ll")) {
    writeToLLFile(filename);
  } else if (f.ends_with(".bc")) {
    writeToBitcodeFile(filename);
  } else if (f.ends_with(".o") || f.ends_with(".obj")) {
    writeToObjectFile(filename);
  } else if (f.ends_with(".s") || f.ends_with(".S") || f.ends_with(".asm")) {
    writeToObjectFile(filename, /*pic=*/false, /*assembly=*/true);
  } else if (f.ends_with(".so") || f.ends_with(".dylib")) {
    writeToExecutable(filename, argv0, /*library=*/true, libs, lflags);
  } else {
    writeToExecutable(filename, argv0, /*library=*/false, libs, lflags);
  }
}

// Runs the module in-process through an ORC LLJIT instance.
//
// Steps, in order: run the LLVM pipeline, load every entry of `libs` into the
// process, build the JIT (with a custom object linking layer), hand the module
// and context over to it, look up "main", install a JIT error callback, and
// finally run "main" with `args`.
//
// `envp` is not referenced anywhere in this body.
void LLVMVisitor::run(const std::vector &args, const std::vector &libs,
                      const char *const *envp) {
  runLLVMPipeline();

  Timer t1("llvm/jitlink");
  for (auto &lib : libs) {
    std::string err;
    // a true return value is treated as failure; `err` carries the message
    if (llvm::sys::DynamicLibrary::LoadLibraryPermanently(lib.c_str(), &err)) {
      compilationError(err);
    }
  }

  // filled in by the linking-layer creator below, once the JIT is created
  DebugPlugin *dbp = nullptr;
  llvm::Triple triple(M->getTargetTriple());
  auto epc = llvm::cantFail(llvm::orc::SelfExecutorProcessControl::Create(
      std::make_shared()));

  llvm::orc::LLJITBuilder builder;
  builder.setDataLayout(M->getDataLayout());
  builder.setObjectLinkingLayerCreator(
      [&epc, &dbp](llvm::orc::ExecutionSession &es, const llvm::Triple &triple)
          -> llvm::Expected> {
        auto L = std::make_unique(
            es, llvm::cantFail(BoehmGCJITLinkMemoryManager::Create()));
        L->addPlugin(std::make_unique(
            es, llvm::cantFail(llvm::orc::createJITLoaderGDBRegistrar(es))));
        auto dbPlugin = std::make_unique();
        // keep a raw pointer: ownership moves into the layer on the next line
        dbp = dbPlugin.get();
        L->addPlugin(std::move(dbPlugin));
        return L;
      });
  builder.setJITTargetMachineBuilder(llvm::orc::JITTargetMachineBuilder(triple));

  auto jit = llvm::cantFail(builder.create());
  jit->getMainJITDylib().addGenerator(
      llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
          jit->getDataLayout().getGlobalPrefix())));

  // M and context are moved into the JIT here; the visitor's own LLVM state is
  // cleared right after
  llvm::cantFail(jit->addIRModule({std::move(M), std::move(context)}));
  clearLLVMData();
  auto mainAddr = llvm::cantFail(jit->lookup("main"));

  // both callbacks print the error output and abort; the debug variant also
  // prints a backtrace obtained through the debug plugin
  if (db.debug) {
    runtime::setJITErrorCallback([dbp](const runtime::JITError &e) {
      fmt::print(stderr, "{}\n{}", e.getOutput(),
                 dbp->getPrettyBacktrace(e.getBacktrace()));
      std::abort();
    });
  } else {
    runtime::setJITErrorCallback([](const runtime::JITError &e) {
      fmt::print(stderr, "{}", e.getOutput());
      std::abort();
    });
  }
  t1.log();

  try {
    llvm::cantFail(epc->runAsMain(mainAddr, args));
  } catch (const runtime::JITError &e) {
    fmt::print(stderr, "{}\n", e.getOutput());
    std::abort();
  }
}

// "alloc-family" attribute value shared by the alloc/realloc/free declarations
// below; undefined again after makeFreeFunc
#define ALLOC_FAMILY "seq_alloc"

// Returns the declaration of the runtime allocation function selected by the
// two flags ("seq_alloc", "seq_alloc_atomic", "seq_alloc_uncollectable" or
// "seq_alloc_atomic_uncollectable"), declared as `ptr (i64)`.
// The declaration is marked nothrow, noalias-return, inaccessible-memory-only,
// noundef/nonnull return, and tagged as an uninitialized allocation whose size
// is argument 0.
llvm::FunctionCallee LLVMVisitor::makeAllocFunc(bool atomic, bool uncollectable) {
  const std::string name =
      atomic ? (uncollectable ? "seq_alloc_atomic_uncollectable" : "seq_alloc_atomic")
             : (uncollectable ? "seq_alloc_uncollectable" : "seq_alloc");
  auto f = M->getOrInsertFunction(name, B->getPtrTy(), B->getInt64Ty());
  auto *g = cast(f.getCallee());
  g->setDoesNotThrow();
  g->setReturnDoesNotAlias();
  g->setOnlyAccessesInaccessibleMemory();
  g->addRetAttr(llvm::Attribute::AttrKind::NoUndef);
  g->addRetAttr(llvm::Attribute::AttrKind::NonNull);
  g->addFnAttrs(
      llvm::AttrBuilder(*context)
          .addAllocKindAttr(llvm::AllocFnKind::Alloc | llvm::AllocFnKind::Uninitialized)
          .addAllocSizeAttr(0, {})
          .addAttribute("alloc-family", ALLOC_FAMILY));
  return f;
}

// Returns the declaration of "seq_realloc" (`ptr (ptr, i64, i64)`), marked
// nothrow with a noundef/nonnull return. Argument 0 is the allocated pointer
// and argument 1 is the allocation size.
llvm::FunctionCallee LLVMVisitor::makeReallocFunc() {
  // note that seq_realloc takes arguments (ptr, new_size, old_size)
  auto f = M->getOrInsertFunction("seq_realloc", B->getPtrTy(), B->getPtrTy(),
                                  B->getInt64Ty(), B->getInt64Ty());
  auto *g = cast(f.getCallee());
  g->setDoesNotThrow();
  g->addRetAttr(llvm::Attribute::AttrKind::NoUndef);
  g->addRetAttr(llvm::Attribute::AttrKind::NonNull);
  g->addParamAttr(0, llvm::Attribute::AttrKind::AllocatedPointer);
  g->addFnAttrs(llvm::AttrBuilder(*context)
                    .addAllocKindAttr(llvm::AllocFnKind::Realloc |
                                      llvm::AllocFnKind::Uninitialized)
                    .addAllocSizeAttr(1, {})
                    .addAttribute("alloc-family", ALLOC_FAMILY));
  return f;
}

// Returns the declaration of "seq_free" (`void (ptr)`), marked nothrow and
// tagged as the free function of the same allocation family.
llvm::FunctionCallee LLVMVisitor::makeFreeFunc() {
  auto f = M->getOrInsertFunction("seq_free", B->getVoidTy(), B->getPtrTy());
  auto *g = cast(f.getCallee());
  g->setDoesNotThrow();
  g->addParamAttr(0, llvm::Attribute::AttrKind::AllocatedPointer);
  g->addFnAttrs(llvm::AttrBuilder(*context)
                    .addAllocKindAttr(llvm::AllocFnKind::Free)
                    .addAttribute("alloc-family", ALLOC_FAMILY));
  return f;
}

#undef ALLOC_FAMILY

// Returns the declaration of "seq_personality":
// `i32 (i32, i32, i64, ptr, ptr)`.
llvm::FunctionCallee LLVMVisitor::makePersonalityFunc() {
  return M->getOrInsertFunction("seq_personality", B->getInt32Ty(), B->getInt32Ty(),
                                B->getInt32Ty(), B->getInt64Ty(), B->getPtrTy(),
                                B->getPtrTy());
}

// Returns the declaration of "seq_alloc_exc" (`ptr (ptr)`), marked nothrow.
llvm::FunctionCallee
LLVMVisitor::makeExcAllocFunc() { auto f = M->getOrInsertFunction("seq_alloc_exc", B->getPtrTy(), B->getPtrTy()); auto *g = cast(f.getCallee()); g->setDoesNotThrow(); return f; } llvm::FunctionCallee LLVMVisitor::makeThrowFunc() { auto f = M->getOrInsertFunction("seq_throw", B->getVoidTy(), B->getPtrTy()); auto *g = cast(f.getCallee()); g->setDoesNotReturn(); return f; } llvm::FunctionCallee LLVMVisitor::makeTerminateFunc() { auto f = M->getOrInsertFunction("seq_terminate", B->getVoidTy(), B->getPtrTy()); auto *g = cast(f.getCallee()); g->setDoesNotReturn(); return f; } llvm::StructType *LLVMVisitor::getTypeInfoType() { return llvm::StructType::get(B->getInt32Ty()); } llvm::StructType *LLVMVisitor::getPadType() { return llvm::StructType::get(B->getPtrTy(), B->getInt32Ty()); } namespace { int typeIdxLookup(types::Type *type) { if (!type) return 0; auto *M = type->getModule(); return M->getCache()->getRealizationId(type->getAstType()->getClass()); } } // namespace llvm::GlobalVariable *LLVMVisitor::getTypeIdxVar(types::Type *type) { auto *typeInfoType = getTypeInfoType(); const std::string name = type ? type->getName() : ""; const std::string typeVarName = "codon.typeidx." + (type ? 
name : ""); auto *tidx = M->getGlobalVariable(typeVarName); int idx = typeIdxLookup(type); if (!tidx) { tidx = new llvm::GlobalVariable( *M, typeInfoType, /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, llvm::ConstantStruct::get(typeInfoType, B->getInt32(idx)), typeVarName); tidx->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); } return tidx; } int LLVMVisitor::getTypeIdx(types::Type *catchType) { return typeIdxLookup(catchType); } llvm::Value *LLVMVisitor::call(llvm::FunctionCallee callee, llvm::ArrayRef args) { B->SetInsertPoint(block); if ((trycatch.empty() && finally.empty()) || DisableExceptions) { return B->CreateCall(callee, args); } else { auto *normalBlock = llvm::BasicBlock::Create(*context, "invoke.normal", func); // use non-empty of finally-stack and try-stack, or whichever is most recent if both // are non-empty auto *unwindBlock = finally.empty() ? trycatch.back().exceptionBlock : (trycatch.empty() ? finally.back().finallyExceptionBlock : (trycatch.back().sequenceNumber > finally.back().sequenceNumber ? 
trycatch.back().exceptionBlock : finally.back().finallyExceptionBlock)); auto *result = B->CreateInvoke(callee, normalBlock, unwindBlock, args); block = normalBlock; return result; } } static int nextSequenceNumber = 0; void LLVMVisitor::enterLoop(LoopData data) { loops.push_back(std::move(data)); loops.back().sequenceNumber = nextSequenceNumber++; } void LLVMVisitor::exitLoop() { seqassertn(!loops.empty(), "no loops present"); loops.pop_back(); } void LLVMVisitor::enterTry(TryCatchData data) { trycatch.push_back(std::move(data)); trycatch.back().sequenceNumber = nextSequenceNumber++; } void LLVMVisitor::exitTry() { seqassertn(!trycatch.empty(), "no try catches present"); trycatch.pop_back(); } void LLVMVisitor::enterFinally(TryCatchData data) { finally.push_back(std::move(data)); finally.back().sequenceNumber = nextSequenceNumber++; } void LLVMVisitor::exitFinally() { seqassertn(!finally.empty(), "no finally present"); finally.pop_back(); } void LLVMVisitor::enterCatch(CatchData data) { catches.push_back(std::move(data)); catches.back().sequenceNumber = nextSequenceNumber++; } void LLVMVisitor::exitCatch() { seqassertn(!catches.empty(), "no catches present"); catches.pop_back(); } LLVMVisitor::TryCatchData *LLVMVisitor::getInnermostTryCatch() { return trycatch.empty() ? 
nullptr : &trycatch.back();
}

// Returns the innermost try-catch record if it was entered after the innermost
// loop (or if there is no loop at all); otherwise nullptr.
LLVMVisitor::TryCatchData *LLVMVisitor::getInnermostTryCatchBeforeLoop() {
  if (!trycatch.empty() &&
      (loops.empty() ||
       trycatch.back().sequenceNumber > loops.back().sequenceNumber))
    return &trycatch.back();
  return nullptr;
}

/*
 * General values, module, functions, vars
 */

// Generates the LLVM module for IR module `x`.
//
// Registers the argument variable and all module-level variables, processes
// every function and the IR main function, then builds a C-style entry point
// `i32 (i32 argc, argv)` that:
//   1. copies the program arguments into an array of {length, pointer} pairs,
//   2. stores that array into the storage of the module's argument variable,
//   3. calls seq_init with the debug/capture/standalone flags,
//   4. calls a private "codon.proxy_main" wrapper that invokes the IR main and
//      routes an escaping exception to the terminate function,
//   5. returns 0.
void LLVMVisitor::visit(const Module *x) {
  // initialize module
  M = makeModule(*context, getSrcInfo(x));

  // args variable
  seqassertn(x->getArgVar()->isGlobal(), "arg var is not global");
  registerGlobal(x->getArgVar());

  // set up global variables and initialize functions
  for (auto *var : *x) {
    registerGlobal(var);
  }

  // process functions
  for (auto *var : *x) {
    if (auto *f = cast(var)) {
      process(f);
    }
  }

  const Func *main = x->getMainFunc();
  llvm::FunctionCallee realMain = makeLLVMFunction(main);
  process(main);
  setDebugInfoForNode(nullptr);

  // build canonical main function
  auto *strType = llvm::StructType::get(*context, {B->getInt64Ty(), B->getPtrTy()});
  auto *arrType =
      llvm::StructType::get(*context, {B->getInt64Ty(), strType->getPointerTo()});

  auto *initFunc = llvm::cast(
      M->getOrInsertFunction("seq_init", B->getVoidTy(), B->getInt32Ty()).getCallee());
  auto *strlenFunc = llvm::cast(
      M->getOrInsertFunction("strlen", B->getInt64Ty(), B->getPtrTy()).getCallee());

  // check if main exists already as an exported function
  const std::string mainName = M->getFunction("main") ? MAIN_UNCLASH : "main";

  auto *canonicalMainFunc = llvm::cast(
      M->getOrInsertFunction(mainName, B->getInt32Ty(), B->getInt32Ty(),
                             B->getPtrTy()->getPointerTo())
          .getCallee());

  canonicalMainFunc->setPersonalityFn(
      llvm::cast(makePersonalityFunc().getCallee()));
  auto argiter = canonicalMainFunc->arg_begin();
  auto *argc = argiter++;
  auto *argv = argiter;
  argc->setName("argc");
  argv->setName("argv");

  // The following generates code to put program arguments in an array, i.e.:
  //    for (int i = 0; i < argc; i++)
  //      array[i] = {strlen(argv[i]), argv[i]}

  auto *entryBlock = llvm::BasicBlock::Create(*context, "entry", canonicalMainFunc);
  auto *loopBlock = llvm::BasicBlock::Create(*context, "loop", canonicalMainFunc);
  auto *bodyBlock = llvm::BasicBlock::Create(*context, "body", canonicalMainFunc);
  auto *exitBlock = llvm::BasicBlock::Create(*context, "exit", canonicalMainFunc);

  // entry: allocate argc * sizeof(str) bytes and build the {len, ptr} array value
  B->SetInsertPoint(entryBlock);
  auto allocFunc = makeAllocFunc(/*atomic=*/false);
  auto *len = B->CreateZExt(argc, B->getInt64Ty());
  auto *elemSize = B->getInt64(M->getDataLayout().getTypeAllocSize(strType));
  auto *allocSize = B->CreateMul(len, elemSize);
  auto *ptr = B->CreateCall(allocFunc, allocSize);
  llvm::Value *arr = llvm::UndefValue::get(arrType);
  arr = B->CreateInsertValue(arr, len, 0);
  arr = B->CreateInsertValue(arr, ptr, 1);
  B->CreateBr(loopBlock);

  // loop header: i = phi(0 from entry, i + 1 from body); continue while i < argc
  B->SetInsertPoint(loopBlock);
  auto *control = B->CreatePHI(B->getInt32Ty(), 2, "i");
  auto *next = B->CreateAdd(control, B->getInt32(1), "next");
  auto *cond = B->CreateICmpSLT(control, argc);
  control->addIncoming(B->getInt32(0), entryBlock);
  control->addIncoming(next, bodyBlock);
  B->CreateCondBr(cond, bodyBlock, exitBlock);

  // loop body: array[i] = {strlen(argv[i]), argv[i]}
  B->SetInsertPoint(bodyBlock);
  auto *arg = B->CreateLoad(B->getPtrTy(), B->CreateGEP(B->getPtrTy(), argv, control));
  auto *argLen =
      B->CreateZExtOrTrunc(B->CreateCall(strlenFunc, arg), B->getInt64Ty());
  llvm::Value *str = llvm::UndefValue::get(strType);
  str = B->CreateInsertValue(str, argLen, 0);
  str = B->CreateInsertValue(str, arg, 1);
  B->CreateStore(str, B->CreateGEP(strType, ptr, control));
  B->CreateBr(loopBlock);

  // exit: publish the argument array and initialize the runtime
  B->SetInsertPoint(exitBlock);
  auto *argStorage = getVar(x->getArgVar());
  seqassertn(argStorage, "argument storage missing");
  B->CreateStore(arr, argStorage);
  const int flags = (db.debug ? SEQ_FLAG_DEBUG : 0) |
                    (db.capture ? SEQ_FLAG_CAPTURE_OUTPUT : 0) |
                    (db.standalone ? SEQ_FLAG_STANDALONE : 0);
  B->CreateCall(initFunc, B->getInt32(flags));

  // Put the entire program in a new function
  {
    auto *proxyMainTy = llvm::FunctionType::get(B->getVoidTy(), {}, false);
    auto *proxyMain = llvm::cast(
        M->getOrInsertFunction("codon.proxy_main", proxyMainTy).getCallee());
    proxyMain->setLinkage(llvm::GlobalValue::PrivateLinkage);
    proxyMain->setPersonalityFn(
        llvm::cast(makePersonalityFunc().getCallee()));
    auto *proxyBlockEntry = llvm::BasicBlock::Create(*context, "entry", proxyMain);
    auto *proxyBlockMain = llvm::BasicBlock::Create(*context, "main", proxyMain);
    auto *proxyBlockExit = llvm::BasicBlock::Create(*context, "exit", proxyMain);
    B->SetInsertPoint(proxyBlockEntry);

    // the condition is the constant false, so "entry" always continues to "main"
    auto *shouldExit = B->getFalse();
    B->CreateCondBr(shouldExit, proxyBlockExit, proxyBlockMain);

    B->SetInsertPoint(proxyBlockExit);
    B->CreateRetVoid();

    // invoke real main
    auto *normal = llvm::BasicBlock::Create(*context, "normal", proxyMain);
    auto *unwind = llvm::BasicBlock::Create(*context, "unwind", proxyMain);
    B->SetInsertPoint(proxyBlockMain);
    B->CreateInvoke(realMain, normal, unwind);

    // unwind: hand the exception object (landing pad field 0) to the terminate
    // function, after which control does not continue
    B->SetInsertPoint(unwind);
    auto *caughtResult = B->CreateLandingPad(getPadType(), 1);
    caughtResult->setCleanup(true);
    caughtResult->addClause(getTypeIdxVar(nullptr));
    auto *unwindException = B->CreateExtractValue(caughtResult, 0);
    B->CreateCall(makeTerminateFunc(), unwindException);
    B->CreateUnreachable();

    B->SetInsertPoint(normal);
    B->CreateRetVoid();

    // actually make the call
    B->SetInsertPoint(exitBlock);
    B->CreateCall(proxyMain);
  }

  B->SetInsertPoint(exitBlock);
  B->CreateRet(B->getInt32(0));

  // make sure allocation functions have the correct attributes
  if (M->getFunction("seq_alloc"))
    makeAllocFunc(/*atomic=*/false, /*uncollectable=*/false);
  if (M->getFunction("seq_alloc_atomic"))
    makeAllocFunc(/*atomic=*/true, /*uncollectable=*/false);
  if (M->getFunction("seq_alloc_uncollectable"))
    makeAllocFunc(/*atomic=*/false, /*uncollectable=*/true);
  if (M->getFunction("seq_alloc_atomic_uncollectable"))
    makeAllocFunc(/*atomic=*/true, /*uncollectable=*/true);
  if (M->getFunction("seq_realloc"))
    makeReallocFunc();
  if (M->getFunction("seq_free"))
    makeFreeFunc();
}

// Creates the debug-info subprogram for IR function `x`.
//
// The display name is the unmangled name, prefixed with "<parent type>." when
// the function has a parent type; the linkage name is getNameForFunction(x).
// The subprogram is flagged local-to-unit and as a definition, and as
// optimized exactly when debug mode is off.
llvm::DISubprogram *LLVMVisitor::getDISubprogramForFunc(const Func *x) {
  auto *srcInfo = getSrcInfo(x);
  auto *file = db.getFile(srcInfo->file);
  auto *derivedType = llvm::cast(getDIType(x->getType()));
  auto *subroutineType =
      llvm::cast(derivedType->getRawBaseType());

  std::string baseName = x->getUnmangledName();
  if (auto *parent = x->getParentType())
    baseName = parent->getName() + "." + baseName;
  auto *subprogram = db.builder->createFunction(
      file, baseName, getNameForFunction(x), file, srcInfo->line, subroutineType,
      /*ScopeLine=*/0, llvm::DINode::FlagZero,
      llvm::DISubprogram::toSPFlags(/*IsLocalToUnit=*/true,
                                    /*IsDefinition=*/true, /*IsOptimized=*/!db.debug));
  return subprogram;
}

// Returns the LLVM function corresponding to IR function `x`.
//
// For the function kind matched by the first cast, the function is processed
// in full right away, with the visitor's current `func` saved and restored
// around it. For every other kind only a declaration with the translated
// signature is obtained; a debug subprogram is attached unless the final cast
// matches.
llvm::Function *LLVMVisitor::makeLLVMFunction(const Func *x) {
  // process LLVM functions in full immediately
  if (auto *llvmFunc = cast(x)) {
    auto *oldFunc = func;
    process(llvmFunc);
    setDebugInfoForNode(nullptr);
    auto *newFunc = func;
    func = oldFunc;
    return newFunc;
  }

  auto *funcType = cast(x->getType());
  auto *returnType = getLLVMType(funcType->getReturnType());
  std::vector argTypes;
  for (const auto &argType : *funcType) {
    argTypes.push_back(getLLVMType(argType));
  }

  auto *llvmFuncType =
      llvm::FunctionType::get(returnType, argTypes, funcType->isVariadic());
  const std::string functionName = getNameForFunction(x);
  auto *f = llvm::cast(
      M->getOrInsertFunction(functionName, llvmFuncType).getCallee());
  if (!cast(x)) {
    f->setSubprogram(getDISubprogramForFunc(x));
  }
  return f;
}

// Emits a coroutine suspension point at the end of the current block.
// A non-null `value` is first stored into the coroutine promise. The suspend
// result is switched on: 0 continues in a fresh "yield.new" block (which
// becomes the current block), 1 goes to the cleanup block, and anything else
// goes to the suspend block. `finalYield` is passed through to coro.suspend.
void
LLVMVisitor::makeYield(llvm::Value *value, bool finalYield) {
  B->SetInsertPoint(block);
  if (value) {
    seqassertn(coro.promise, "promise is null");
    B->CreateStore(value, coro.promise);
  }
  llvm::FunctionCallee coroSuspend =
      llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_suspend);
  auto *suspendResult = B->CreateCall(
      coroSuspend, {llvm::ConstantTokenNone::get(*context), B->getInt1(finalYield)});

  // code emitted after this call continues in the resume block
  block = llvm::BasicBlock::Create(*context, "yield.new", func);

  seqassertn(coro.suspend && coro.cleanup, "suspend and/or cleanup is null");
  auto *inst = B->CreateSwitch(suspendResult, coro.suspend, 2);
  inst->addCase(B->getInt8(0), block);
  inst->addCase(B->getInt8(1), coro.cleanup);
}

// Looks up the LLVM function for external function `x`, creating and
// registering the declaration if it does not exist yet, resets the coroutine
// state, and marks the function nothrow and willreturn.
void LLVMVisitor::visit(const ExternalFunc *x) {
  func = M->getFunction(getNameForFunction(x));
  if (!func) {
    func = makeLLVMFunction(x);
    insertFunc(x, func);
  }
  coro = {};
  func->setDoesNotThrow();
  func->setWillReturn();
}

namespace {
// internal function type checking
// NOTE(review): the template parameter lists and explicit template arguments
// of these helpers (and of their call sites) are not visible in this view --
// confirm against the original file before editing.

// True when `name` equals the unmangled name of `x` and the cast of x's parent
// type succeeds; argument types are not inspected.
template bool internalFuncMatchesIgnoreArgs(const std::string &name,
                                            const InternalFunc *x) {
  return name == x->getUnmangledName() && cast(x->getParentType());
}

// True when `name` equals the unmangled name of `x`, the number of function
// arguments equals sizeof...(ArgTypes), and the casts of the parent type and
// of every argument type all succeed.
template bool internalFuncMatches(const std::string &name, const InternalFunc *x,
                                  std::index_sequence) {
  auto *funcType = cast(x->getType());
  if (name != x->getUnmangledName() ||
      std::distance(funcType->begin(), funcType->end()) != sizeof...(ArgTypes))
    return false;

  std::vector argTypes(funcType->begin(), funcType->end());
  std::vector m = {bool(cast(x->getParentType())),
                   bool(cast(argTypes[Index]))...};
  const bool match = std::all_of(m.begin(), m.end(), [](bool b) { return b; });
  return match;
}

// Convenience overload that supplies the index sequence for the overload above.
template bool internalFuncMatches(const std::string &name, const InternalFunc *x) {
  return internalFuncMatches(
      name, x, std::make_index_sequence());
}
} // namespace

// Generates the body of built-in function `x`.
//
// The function must already exist in the module. It is given private linkage
// and the always-inline attribute, and its body is a single "entry" block that
// returns the value computed by whichever pattern below matches. It is an
// assertion failure if no pattern matches.
void LLVMVisitor::visit(const InternalFunc *x) {
  using namespace types;
  func = M->getFunction(getNameForFunction(x));
  coro = {};
  seqassertn(func, "{} not inserted", *x);
  setDebugInfoForNode(x);

  Type *parentType = x->getParentType();
  auto *funcType = cast(x->getType());
  std::vector argTypes(funcType->begin(), funcType->end());

  func->setLinkage(llvm::GlobalValue::PrivateLinkage);
  func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline);
  std::vector args;
  for (auto it = func->arg_begin(); it != func->arg_end(); ++it) {
    args.push_back(it);
  }
  block = llvm::BasicBlock::Create(*context, "entry", func);
  B->SetInsertPoint(block);
  llvm::Value *result = nullptr;

  if (internalFuncMatches("__new__", x)) {
    // allocate args[0] elements of the parent type's base type
    auto *pointerType = cast(parentType);
    Type *baseType = pointerType->getBase();
    auto *llvmBaseType = getLLVMType(baseType);
    auto allocFunc = makeAllocFunc(baseType->isAtomic());
    auto *elemSize = B->getInt64(M->getDataLayout().getTypeAllocSize(llvmBaseType));
    auto *allocSize = B->CreateMul(elemSize, args[0]);
    result = B->CreateCall(allocFunc, allocSize);
  }

  else if (internalFuncMatches("__new__", x)) {
    // convert args[0] to i64, sign- or zero-extending per the argument type
    auto *intNType = cast(argTypes[0]);
    if (intNType->isSigned()) {
      result = B->CreateSExtOrTrunc(args[0], B->getInt64Ty());
    } else {
      result = B->CreateZExtOrTrunc(args[0], B->getInt64Ty());
    }
  }

  else if (internalFuncMatches("__new__", x)) {
    // convert args[0] to the parent type's width, per the parent's signedness
    auto *intNType = cast(parentType);
    if (intNType->isSigned()) {
      result = B->CreateSExtOrTrunc(args[0], getLLVMType(intNType));
    } else {
      result = B->CreateZExtOrTrunc(args[0], getLLVMType(intNType));
    }
  }

  else if (internalFuncMatches("__promise__", x)) {
    // pointer to the coroutine promise; null when the base type is void
    auto *generatorType = cast(parentType);
    auto *baseType = getLLVMType(generatorType->getBase());
    if (baseType->isVoidTy()) {
      result = llvm::ConstantPointerNull::get(B->getVoidTy()->getPointerTo());
    } else {
      llvm::FunctionCallee coroPromise =
          llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_promise);
      auto *aln =
          B->getInt32(M->getDataLayout().getPrefTypeAlign(baseType).value());
      auto *from = B->getFalse();
      auto *ptr = B->CreateCall(coroPromise, {args[0], aln, from});
      result = ptr;
    }
  }

  else if (internalFuncMatchesIgnoreArgs("__new__", x)) {
    // build the returned aggregate by inserting each argument as member i
    auto *recordType = cast(cast(x->getType())->getReturnType());
    seqassertn(args.size() == std::distance(recordType->begin(), recordType->end()),
               "args size does not match: {} vs {}", args.size(),
               std::distance(recordType->begin(), recordType->end()));
    result = llvm::UndefValue::get(getLLVMType(recordType));
    for (auto i = 0; i < args.size(); i++) {
      result = B->CreateInsertValue(result, args[i], i);
    }
  }

  seqassertn(result, "internal function {} not found", *x);
  B->CreateRet(result);
}

// Builds the textual LLVM IR for inline-LLVM function `x`:
// its declarations, a blank line, then a `define` with the translated
// signature wrapping the user-provided body.
//
// The returned string is a format template: literal braces in the signature
// and around the body are doubled ("{{", "}}"), while the declarations and the
// body are inserted unchanged.
std::string LLVMVisitor::buildLLVMCodeString(const LLVMFunc *x) {
  auto *funcType = cast(x->getType());
  seqassertn(funcType, "{} is not a function type", *x->getType());

  std::string bufStr;
  llvm::raw_string_ostream buf(bufStr);

  // build function signature
  buf << "define ";
  getLLVMType(funcType->getReturnType())->print(buf);
  buf << " @\"" << getNameForFunction(x) << "\"(";
  const int numArgs = std::distance(x->arg_begin(), x->arg_end());
  int argIndex = 0;
  for (auto it = x->arg_begin(); it != x->arg_end(); ++it) {
    getLLVMType((*it)->getType())->print(buf);
    buf << " %" << (*it)->getName();
    if (argIndex < numArgs - 1)
      buf << ", ";
    ++argIndex;
  }
  buf << ")";
  // take a copy of the signature, then reuse the same buffer for the full code
  std::string signature = buf.str();
  bufStr.clear();

  // replace literal '{' and '}'
  std::string::size_type n = 0;
  while ((n = signature.find("{", n)) != std::string::npos) {
    signature.replace(n, 1, "{{");
    n += 2; // skip past the pair just written
  }
  n = 0;
  while ((n = signature.find("}", n)) != std::string::npos) {
    signature.replace(n, 1, "}}");
    n += 2;
  }

  // build remaining code
  auto body = x->getLLVMBody();
  buf << x->getLLVMDeclarations() << "\n" << signature << " {{\n" << body << "\n}}";
  return buf.str();
}

// Generates inline-LLVM function `x` unless it already exists in the module.
//
// The code template from buildLLVMCodeString is formatted with the function's
// literals (static values, static strings, or printed LLVM types), parsed as
// IR, and linked into the current module. The linked function gets private
// linkage, the always-inline attribute and a debug subprogram, and every
// instruction without a debug location is given the function's own location.
void LLVMVisitor::visit(const LLVMFunc *x) {
  func = M->getFunction(getNameForFunction(x));
  coro = {};
  if (func)
    return;

  // build code
  std::string code = buildLLVMCodeString(x);

  // format code
  fmt::dynamic_format_arg_store store;
  for (auto it = x->literal_begin(); it != x->literal_end(); ++it) {
    if (it->isStatic()) {
      store.push_back(it->getStaticValue());
    } else if (it->isStaticStr()) {
store.push_back(it->getStaticStringValue()); } else if (it->isType()) { auto *llvmType = getLLVMType(it->getTypeValue()); std::string bufStr; llvm::raw_string_ostream buf(bufStr); llvmType->print(buf); store.push_back(buf.str()); } else { seqassertn(0, "formatting failed"); } } code = fmt::vformat(code, store); llvm::SMDiagnostic err; std::unique_ptr buf = llvm::MemoryBuffer::getMemBuffer(code); seqassertn(buf, "could not create buffer"); std::unique_ptr sub = llvm::parseIR(buf->getMemBufferRef(), err, *context); if (!sub) { std::string bufStr; llvm::raw_string_ostream buf(bufStr); err.print("LLVM", buf); compilationError(fmt::format("{} ({})", buf.str(), x->getName())); } sub->setDataLayout(M->getDataLayout()); llvm::Linker L(*M); const bool fail = L.linkInModule(std::move(sub)); seqassertn(!fail, "linking failed"); func = M->getFunction(getNameForFunction(x)); seqassertn(func, "function not linked in"); func->setLinkage(llvm::GlobalValue::PrivateLinkage); func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline); func->setSubprogram(getDISubprogramForFunc(x)); // set up debug info // for now we just set all to func's source location auto *srcInfo = getSrcInfo(x); for (auto &block : *func) { for (auto &inst : block) { if (!inst.getDebugLoc()) { inst.setDebugLoc(llvm::DebugLoc(llvm::DILocation::get( *context, srcInfo->line, srcInfo->col, func->getSubprogram()))); } } } } void LLVMVisitor::visit(const BodiedFunc *x) { func = M->getFunction(getNameForFunction(x)); coro = {}; seqassertn(func, "{} not inserted", *x); setDebugInfoForNode(x); auto *fnAttributes = x->getAttribute(); if (x->isJIT()) { func->addFnAttr(llvm::Attribute::get(*context, "jit")); } if (x->isJIT() || (fnAttributes && fnAttributes->has(EXPORT_ATTR))) { func->setLinkage(llvm::GlobalValue::ExternalLinkage); } else { func->setLinkage(llvm::GlobalValue::PrivateLinkage); } if (fnAttributes && fnAttributes->has(INLINE_ATTR)) { func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline); } if (fnAttributes && 
fnAttributes->has(NOINLINE_ATTR)) { func->addFnAttr(llvm::Attribute::AttrKind::NoInline); } if (fnAttributes && fnAttributes->has(GPU_KERNEL_ATTR)) { func->addFnAttr(llvm::Attribute::AttrKind::NoInline); func->addFnAttr(llvm::Attribute::get(*context, "kernel")); func->setLinkage(llvm::GlobalValue::ExternalLinkage); } if (!DisableExceptions) func->setPersonalityFn( llvm::cast(makePersonalityFunc().getCallee())); auto *funcType = cast(x->getType()); seqassertn(funcType, "{} is not a function type", *x->getType()); auto *returnType = funcType->getReturnType(); auto *entryBlock = llvm::BasicBlock::Create(*context, "entry", func); B->SetInsertPoint(entryBlock); // set up arguments and other symbols seqassertn(std::distance(func->arg_begin(), func->arg_end()) == std::distance(x->arg_begin(), x->arg_end()), "argument length does not match"); unsigned argIdx = 1; auto argIter = func->arg_begin(); for (auto varIter = x->arg_begin(); varIter != x->arg_end(); ++varIter) { const Var *var = *varIter; auto *storage = B->CreateAlloca(getLLVMType(var->getType())); B->CreateStore(argIter, storage); insertVar(var, storage); // debug info auto *srcInfo = getSrcInfo(var); auto *file = db.getFile(srcInfo->file); auto *scope = func->getSubprogram(); auto *debugVar = db.builder->createParameterVariable( scope, getDebugNameForVariable(var), argIdx, file, srcInfo->line, getDIType(var->getType()), db.debug); db.builder->insertDeclare( storage, debugVar, db.builder->createExpression(), llvm::DILocation::get(*context, srcInfo->line, srcInfo->col, scope), entryBlock); ++argIter; ++argIdx; } for (auto *var : *x) { auto *llvmType = getLLVMType(var->getType()); if (llvmType->isVoidTy()) { insertVar(var, getDummyVoidValue()); } else { auto *storage = B->CreateAlloca(llvmType); insertVar(var, storage); // debug info auto *srcInfo = getSrcInfo(var); auto *file = db.getFile(srcInfo->file); auto *scope = func->getSubprogram(); auto *debugVar = db.builder->createAutoVariable( scope, 
getDebugNameForVariable(var), file, srcInfo->line, getDIType(var->getType()), db.debug); db.builder->insertDeclare( storage, debugVar, db.builder->createExpression(), llvm::DILocation::get(*context, srcInfo->line, srcInfo->col, scope), entryBlock); } } auto *startBlock = llvm::BasicBlock::Create(*context, "start", func); const bool generator = x->isGenerator() || x->isAsync(); if (generator) { func->setPresplitCoroutine(); auto *generatorType = cast(returnType); seqassertn(generatorType, "{} is not a generator type", *returnType); llvm::FunctionCallee coroId = llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_id); llvm::FunctionCallee coroBegin = llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_begin); llvm::FunctionCallee coroSize = llvm::Intrinsic::getDeclaration( M.get(), llvm::Intrinsic::coro_size, {B->getInt64Ty()}); llvm::FunctionCallee coroEnd = llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_end); llvm::FunctionCallee coroAlloc = llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_alloc); llvm::FunctionCallee coroFree = llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_free); coro.cleanup = llvm::BasicBlock::Create(*context, "coro.cleanup", func); coro.suspend = llvm::BasicBlock::Create(*context, "coro.suspend", func); coro.exit = llvm::BasicBlock::Create(*context, "coro.exit", func); coro.async = x->isAsync(); auto *allocBlock = llvm::BasicBlock::Create(*context, "coro.alloc", func); auto *freeBlock = llvm::BasicBlock::Create(*context, "coro.free", func); // coro ID and promise llvm::Value *id = nullptr; auto *nullPtr = llvm::ConstantPointerNull::get(B->getPtrTy()); if (!cast(generatorType->getBase())) { coro.promise = B->CreateAlloca(getLLVMType(generatorType->getBase())); coro.promise->setName("coro.promise"); id = B->CreateCall(coroId, {B->getInt32(0), coro.promise, nullPtr, nullPtr}); } else { id = B->CreateCall(coroId, {B->getInt32(0), nullPtr, nullPtr, nullPtr}); } 
id->setName("coro.id");
    // branch on the result of llvm.coro.alloc: allocate the frame in
    // 'coro.alloc', or go straight to 'start'
    auto *needAlloc = B->CreateCall(coroAlloc, id);
    B->CreateCondBr(needAlloc, allocBlock, startBlock);

    // coro alloc
    B->SetInsertPoint(allocBlock);
    auto *size = B->CreateCall(coroSize);
    auto allocFunc = makeAllocFunc(/*atomic=*/false);
    auto *alloc = B->CreateCall(allocFunc, size);
    B->CreateBr(startBlock);

    // coro start
    B->SetInsertPoint(startBlock);
    // frame pointer: null when arriving from 'entry', otherwise the allocation
    // made in 'coro.alloc'
    auto *phi = B->CreatePHI(B->getPtrTy(), 2);
    phi->addIncoming(nullPtr, entryBlock);
    phi->addIncoming(alloc, allocBlock);
    coro.handle = B->CreateCall(coroBegin, {id, phi});
    coro.handle->setName("coro.handle");

    // coro cleanup
    B->SetInsertPoint(coro.cleanup);
    auto *mem = B->CreateCall(coroFree, {id, coro.handle});
    auto *needFree = B->CreateIsNotNull(mem);
    B->CreateCondBr(needFree, freeBlock, coro.suspend);

    // coro free
    B->SetInsertPoint(freeBlock);
    // no-op: GC will free automatically
    B->CreateBr(coro.suspend);

    // coro suspend
    B->SetInsertPoint(coro.suspend);
    B->CreateCall(coroEnd, {coro.handle, B->getFalse(),
                            llvm::ConstantTokenNone::get(*context)});
    B->CreateRet(coro.handle);

    // coro exit
    block = coro.exit;
    makeYield(nullptr, /*finalYield=*/true);
    B->SetInsertPoint(block);
    B->CreateUnreachable();

    // initial yield
    block = startBlock;
    makeYield(); // coroutine will be initially suspended
  } else {
    // non-generator: 'entry' simply falls through to 'start'
    B->CreateBr(startBlock);
    block = startBlock;
  }

  seqassertn(x->getBody(), "{} has no body [{}]", x->getName(), x->getSrcInfo());
  process(x->getBody());
  B->SetInsertPoint(block);

  // terminate whatever block the body left us in: generators branch to the
  // coroutine exit block, other functions return void or a null value of the
  // return type
  if (generator) {
    B->CreateBr(coro.exit);
  } else {
    if (cast(returnType)) {
      B->CreateRetVoid();
    } else {
      B->CreateRet(llvm::Constant::getNullValue(getLLVMType(returnType)));
    }
  }
}

// Variables are never visited directly; this always trips the assertion.
void LLVMVisitor::visit(const Var *x) { seqassertn(0, "cannot visit var"); }

// Produces the value of a variable reference. When the cast of the variable
// succeeds, the result of getFunc() is used directly; otherwise the value is
// loaded from the variable's storage (through the thread-local address when
// the variable is thread-local).
void LLVMVisitor::visit(const VarValue *x) {
  if (auto *f = cast(x->getVar())) {
    value = getFunc(f);
    seqassertn(value, "{} value not found", *x);
  } else {
    auto *varPtr = getVar(x->getVar());
    seqassertn(varPtr, "{} value not found", *x);
    B->SetInsertPoint(block);
    if
(x->getVar()->isThreadLocal()) varPtr = B->CreateThreadLocalAddress(varPtr); value = B->CreateLoad(getLLVMType(x->getType()), varPtr); } } void LLVMVisitor::visit(const PointerValue *x) { const auto &fields = x->getFields(); auto *var = getVar(x->getVar()); seqassertn(var, "{} variable not found", *x); B->SetInsertPoint(block); if (x->getVar()->isThreadLocal()) var = B->CreateThreadLocalAddress(var); if (fields.empty()) { value = var; // note: we don't load the pointer return; } auto *type = x->getVar()->getType(); std::vector gepIndices = {B->getInt32(0)}; for (auto &field : x->getFields()) { if (auto *ref = cast(type)) { auto membIndex = ref->getMemberIndex(field); auto membType = ref->getMemberType(field); seqassertn(membIndex >= 0 && membType, "field {} not found in referecne type", field); gepIndices.push_back(B->getInt32(0)); gepIndices.push_back(B->getInt32(membIndex)); type = membType; } else if (auto *rec = cast(type)) { auto membIndex = rec->getMemberIndex(field); auto membType = rec->getMemberType(field); seqassertn(membIndex >= 0 && membType, "field {} not found in record type", field); gepIndices.push_back(B->getInt32(membIndex)); type = membType; } else { seqassertn(false, "type in pointer value was not a record or reference type"); } } value = B->CreateInBoundsGEP(getLLVMType(x->getVar()->getType()), var, gepIndices); } /* * Types */ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) { if (auto *x = cast(t)) { return B->getInt64Ty(); } if (auto *x = cast(t)) { return B->getDoubleTy(); } if (auto *x = cast(t)) { return B->getFloatTy(); } if (auto *x = cast(t)) { return B->getHalfTy(); } if (auto *x = cast(t)) { return B->getBFloatTy(); } if (auto *x = cast(t)) { return llvm::Type::getFP128Ty(*context); } if (auto *x = cast(t)) { return B->getInt8Ty(); } if (auto *x = cast(t)) { return B->getInt8Ty(); } if (auto *x = cast(t)) { return B->getVoidTy(); } if (auto *x = cast(t)) { std::vector body; for (const auto &field : *x) { 
body.push_back(getLLVMType(field.getType()));
    }
    return llvm::StructType::get(*context, body);
  }

  if (auto *x = cast(t)) {
    return B->getPtrTy();
  }

  // function-typed values are represented as a pointer to the function type
  if (auto *x = cast(t)) {
    return getLLVMFuncType(x)->getPointerTo();
  }

  // when the inner cast of the base succeeds, the base's own type is reused;
  // otherwise the value is wrapped as {i1, base}
  if (auto *x = cast(t)) {
    if (cast(x->getBase())) {
      return getLLVMType(x->getBase());
    } else {
      return llvm::StructType::get(B->getInt1Ty(), getLLVMType(x->getBase()));
    }
  }

  if (auto *x = cast(t)) {
    return getLLVMType(x->getBase())->getPointerTo();
  }

  if (auto *x = cast(t)) {
    return B->getPtrTy();
  }

  // arbitrary-width integer
  if (auto *x = cast(t)) {
    return B->getIntNTy(x->getLen());
  }

  // fixed-length (non-scalable) vector of the base type
  if (auto *x = cast(t)) {
    return llvm::VectorType::get(getLLVMType(x->getBase()), x->getCount(),
                                 /*Scalable=*/false);
  }

  // {i8, largest member}: the second element is the member type with the
  // greatest alloc size (an empty struct if there are no members); the leading
  // i8 is presumably a tag selecting the active member -- TODO confirm
  if (auto *x = cast(t)) {
    auto &layout = M->getDataLayout();
    llvm::Type *largest = nullptr;
    size_t maxSize = 0;
    for (auto *t : *x) {
      auto *llvmType = getLLVMType(t);
      size_t size = layout.getTypeAllocSizeInBits(llvmType);
      if (!largest || size > maxSize) {
        largest = llvmType;
        maxSize = size;
      }
    }

    if (!largest)
      largest = llvm::StructType::get(*context, {});

    return llvm::StructType::get(*context, {B->getInt8Ty(), largest});
  }

  // custom types delegate to their own builder
  if (auto *x = cast(t)) {
    return x->getBuilder()->buildType(this);
  }

  seqassertn(0, "unknown type: {}", *t);
  return nullptr;
}

// Builds the LLVM function type (return type, argument types, variadic flag)
// for an IR type; asserts if the cast of the input type fails.
llvm::FunctionType *LLVMVisitor::getLLVMFuncType(types::Type *t) {
  auto *x = cast(t);
  seqassertn(x, "input type was not a func type");
  auto *returnType = getLLVMType(x->getReturnType());
  std::vector argTypes;
  for (auto *argType : *x) {
    argTypes.push_back(getLLVMType(argType));
  }

  return llvm::FunctionType::get(returnType, argTypes, x->isVariadic());
}

// Builds the debug-info type for an IR type. 'cache' is keyed by type name and
// is consulted/filled by the struct case so that recursive types terminate.
// Sizes come from the module data layout applied to the type's LLVM type.
llvm::DIType *LLVMVisitor::getDITypeHelper(
    types::Type *t, std::unordered_map &cache) {
  auto *type = getLLVMType(t);
  auto &layout = M->getDataLayout();

  if (auto *x = cast(t)) {
    return db.builder->createBasicType(
        x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_signed);
  }

  if (auto *x = cast(t)) {
    return db.builder->createBasicType(
        x->getName(),
layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
  }

  if (auto *x = cast(t)) {
    return db.builder->createBasicType(
        x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
  }

  if (auto *x = cast(t)) {
    return db.builder->createBasicType(
        x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
  }

  if (auto *x = cast(t)) {
    return db.builder->createBasicType(
        x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
  }

  if (auto *x = cast(t)) {
    return db.builder->createBasicType(x->getName(),
                                       layout.getTypeAllocSizeInBits(type),
                                       llvm::dwarf::DW_ATE_HP_float128);
  }

  if (auto *x = cast(t)) {
    return db.builder->createBasicType(
        x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_boolean);
  }

  if (auto *x = cast(t)) {
    return db.builder->createBasicType(x->getName(),
                                       layout.getTypeAllocSizeInBits(type),
                                       llvm::dwarf::DW_ATE_signed_char);
  }

  // this kind has no debug type
  if (auto *x = cast(t)) {
    return nullptr;
  }

  // struct-like type: reuse the cached node if this name was already built
  if (auto *x = cast(t)) {
    auto it = cache.find(x->getName());
    if (it != cache.end()) {
      return it->second;
    } else {
      auto *structType = llvm::cast(type);
      auto *structLayout = layout.getStructLayout(structType);
      auto *srcInfo = getSrcInfo(x);
      auto *memberInfo = x->getAttribute();
      auto *file = db.getFile(srcInfo->file);
      std::vector members;
      // create the struct with an (as yet) empty member array; the real members
      // are attached afterwards via replaceArrays
      auto *diType = db.builder->createStructType(
          file, x->getName(), file, srcInfo->line, structLayout->getSizeInBits(),
          /*AlignInBits=*/0, llvm::DINode::FlagZero, /*DerivedFrom=*/nullptr,
          db.builder->getOrCreateArray(members));

      // prevent infinite recursion on recursive types
      cache.emplace(x->getName(), diType);

      unsigned memberIdx = 0;
      for (const auto &field : *x) {
        // default to the struct's own source location; use the per-member
        // location when the attribute records one for this field
        auto *subSrcInfo = srcInfo;
        auto *subFile = file;
        if (memberInfo) {
          auto it = memberInfo->memberSrcInfo.find(field.getName());
          if (it != memberInfo->memberSrcInfo.end()) {
            subSrcInfo = &it->second;
            subFile = db.getFile(subSrcInfo->file);
          }
        }
        members.push_back(db.builder->createMemberType(
            diType, field.getName(),
subFile, subSrcInfo->line,
            layout.getTypeAllocSizeInBits(getLLVMType(field.getType())),
            /*AlignInBits=*/0, structLayout->getElementOffsetInBits(memberIdx),
            llvm::DINode::FlagZero, getDITypeHelper(field.getType(), cache)));
        ++memberIdx;
      }

      // attach the collected members to the struct node created above
      db.builder->replaceArrays(diType, db.builder->getOrCreateArray(members));
      return diType;
    }
  }

  // described as a DWARF reference to the debug type of its contents
  if (auto *x = cast(t)) {
    auto *ref = db.builder->createReferenceType(
        llvm::dwarf::DW_TAG_reference_type, getDITypeHelper(x->getContents(), cache));
    return ref;
  }

  // pointer to a subroutine type; the type list starts with the return type,
  // followed by the argument types
  if (auto *x = cast(t)) {
    std::vector argTypes = {
        getDITypeHelper(x->getReturnType(), cache)};
    for (auto *argType : *x) {
      argTypes.push_back(getDITypeHelper(argType, cache));
    }
    return db.builder->createPointerType(
        db.builder->createSubroutineType(llvm::MDTuple::get(*context, argTypes)),
        layout.getTypeAllocSizeInBits(type));
  }

  // when the inner cast of the base succeeds, the base's debug type is reused;
  // otherwise a two-member struct {"has": i1, "val": base} is described,
  // mirroring the {i1, base} LLVM struct built here
  if (auto *x = cast(t)) {
    if (cast(x->getBase())) {
      return getDITypeHelper(x->getBase(), cache);
    } else {
      auto *baseType = getLLVMType(x->getBase());
      auto *structType = llvm::StructType::get(B->getInt1Ty(), baseType);
      auto *structLayout = layout.getStructLayout(structType);
      auto *srcInfo = getSrcInfo(x);
      auto i1SizeInBits = layout.getTypeAllocSizeInBits(B->getInt1Ty());
      auto *i1DebugType =
          db.builder->createBasicType("i1", i1SizeInBits, llvm::dwarf::DW_ATE_boolean);
      auto *file = db.getFile(srcInfo->file);
      std::vector members;
      auto *diType = db.builder->createStructType(
          file, x->getName(), file, srcInfo->line, structLayout->getSizeInBits(),
          /*AlignInBits=*/0, llvm::DINode::FlagZero, /*DerivedFrom=*/nullptr,
          db.builder->getOrCreateArray(members));

      members.push_back(db.builder->createMemberType(
          diType, "has", file, srcInfo->line, i1SizeInBits,
          /*AlignInBits=*/0, structLayout->getElementOffsetInBits(0),
          llvm::DINode::FlagZero, i1DebugType));

      members.push_back(db.builder->createMemberType(
          diType, "val", file, srcInfo->line, layout.getTypeAllocSizeInBits(baseType),
          /*AlignInBits=*/0, structLayout->getElementOffsetInBits(1),
          llvm::DINode::FlagZero,
getDITypeHelper(x->getBase(), cache))); db.builder->replaceArrays(diType, db.builder->getOrCreateArray(members)); return diType; } } if (auto *x = cast(t)) { return db.builder->createPointerType(getDITypeHelper(x->getBase(), cache), layout.getTypeAllocSizeInBits(type)); } if (auto *x = cast(t)) { return db.builder->createBasicType( x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_address); } if (auto *x = cast(t)) { return db.builder->createBasicType( x->getName(), layout.getTypeAllocSizeInBits(type), x->isSigned() ? llvm::dwarf::DW_ATE_signed : llvm::dwarf::DW_ATE_unsigned); } if (auto *x = cast(t)) { return db.builder->createBasicType(x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_unsigned); } if (auto *x = cast(t)) { return db.builder->createBasicType(x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_unsigned); } if (auto *x = cast(t)) { return x->getBuilder()->buildDebugType(this); } seqassertn(0, "unknown type"); return nullptr; } llvm::DIType *LLVMVisitor::getDIType(types::Type *t) { std::unordered_map cache; return getDITypeHelper(t, cache); } LLVMVisitor::LoopData *LLVMVisitor::getLoopData(id_t loopId) { for (auto &d : loops) { if (d.loopId == loopId) return &d; } return nullptr; } /* * Constants */ void LLVMVisitor::visit(const IntConst *x) { B->SetInsertPoint(block); value = B->getInt64(x->getVal()); } void LLVMVisitor::visit(const FloatConst *x) { B->SetInsertPoint(block); value = llvm::ConstantFP::get(B->getDoubleTy(), x->getVal()); } void LLVMVisitor::visit(const BoolConst *x) { B->SetInsertPoint(block); value = B->getInt8(x->getVal() ? 
1 : 0); } void LLVMVisitor::visit(const StringConst *x) { B->SetInsertPoint(block); std::string s = x->getVal(); auto *strVar = new llvm::GlobalVariable(*M, llvm::ArrayType::get(B->getInt8Ty(), s.length() + 1), /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, llvm::ConstantDataArray::getString(*context, s), ".str"); strVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); auto *strType = llvm::StructType::get(B->getInt64Ty(), B->getPtrTy()); auto *ptr = B->CreateBitCast(strVar, B->getPtrTy()); auto *len = B->getInt64(s.length()); llvm::Value *str = llvm::UndefValue::get(strType); str = B->CreateInsertValue(str, len, 0); str = B->CreateInsertValue(str, ptr, 1); value = str; } void LLVMVisitor::visit(const dsl::CustomConst *x) { x->getBuilder()->buildValue(this); } /* * Control flow */ void LLVMVisitor::visit(const SeriesFlow *x) { for (auto *value : *x) { process(value); } } void LLVMVisitor::visit(const IfFlow *x) { auto *trueBlock = llvm::BasicBlock::Create(*context, "if.true", func); auto *falseBlock = llvm::BasicBlock::Create(*context, "if.false", func); auto *exitBlock = llvm::BasicBlock::Create(*context, "if.exit", func); process(x->getCond()); auto *cond = value; B->SetInsertPoint(block); cond = B->CreateTrunc(cond, B->getInt1Ty()); B->CreateCondBr(cond, trueBlock, falseBlock); block = trueBlock; if (x->getTrueBranch()) { process(x->getTrueBranch()); } B->SetInsertPoint(block); B->CreateBr(exitBlock); block = falseBlock; if (x->getFalseBranch()) { process(x->getFalseBranch()); } B->SetInsertPoint(block); B->CreateBr(exitBlock); block = exitBlock; } void LLVMVisitor::visit(const WhileFlow *x) { auto *condBlock = llvm::BasicBlock::Create(*context, "while.cond", func); auto *bodyBlock = llvm::BasicBlock::Create(*context, "while.body", func); auto *exitBlock = llvm::BasicBlock::Create(*context, "while.exit", func); B->SetInsertPoint(block); B->CreateBr(condBlock); block = condBlock; process(x->getCond()); auto *cond = value; 
B->SetInsertPoint(block);
  cond = B->CreateTrunc(cond, B->getInt1Ty());
  B->CreateCondBr(cond, bodyBlock, exitBlock);

  // within the body, 'break' targets the exit block and 'continue' re-enters
  // the condition block
  block = bodyBlock;
  enterLoop(
      {/*breakBlock=*/exitBlock, /*continueBlock=*/condBlock, /*loopId=*/x->getId()});
  process(x->getBody());
  exitLoop();
  B->SetInsertPoint(block);
  B->CreateBr(condBlock);

  block = exitBlock;
}

// Lowers a for loop over a coroutine handle: 'for.cond' resumes the coroutine
// and tests llvm.coro.done; 'for.body' copies the promise into the loop
// variable and runs the body; 'for.cleanup' destroys the coroutine before
// continuing at 'for.exit'. Parallel loops must have been lowered already.
void LLVMVisitor::visit(const ForFlow *x) {
  seqassertn(!x->isParallel(), "parallel for-loop not lowered");
  auto *loopVarType = getLLVMType(x->getVar()->getType());
  auto *loopVar = getVar(x->getVar());
  seqassertn(loopVar, "{} loop variable not found", *x);

  auto *condBlock = llvm::BasicBlock::Create(*context, "for.cond", func);
  auto *bodyBlock = llvm::BasicBlock::Create(*context, "for.body", func);
  auto *cleanupBlock = llvm::BasicBlock::Create(*context, "for.cleanup", func);
  auto *exitBlock = llvm::BasicBlock::Create(*context, "for.exit", func);

  // LLVM coroutine intrinsics
  // https://prereleases.llvm.org/6.0.0/rc3/docs/Coroutines.html
  llvm::FunctionCallee coroResume =
      llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_resume);
  llvm::FunctionCallee coroDone =
      llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_done);
  llvm::FunctionCallee coroPromise =
      llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_promise);
  llvm::FunctionCallee coroDestroy =
      llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_destroy);

  process(x->getIter());
  auto *iter = value;
  B->SetInsertPoint(block);
  B->CreateBr(condBlock);

  // resume goes through call(), which may change 'block', hence the insert
  // point is re-established afterwards
  block = condBlock;
  call(coroResume, {iter});
  B->SetInsertPoint(block);
  auto *done = B->CreateCall(coroDone, iter);
  B->CreateCondBr(done, cleanupBlock, bodyBlock);

  // a void loop variable has nothing to load from the promise
  if (!loopVarType->isVoidTy()) {
    B->SetInsertPoint(bodyBlock);
    auto *alignment =
        B->getInt32(M->getDataLayout().getPrefTypeAlign(loopVarType).value());
    auto *from = B->getFalse();
    auto *promise = B->CreateCall(coroPromise, {iter, alignment, from});
    auto *generatedValue = B->CreateLoad(loopVarType, promise);
    B->CreateStore(generatedValue,
loopVar);
  }

  // within the body, 'break' targets the exit block (skipping the cleanup
  // block) and 'continue' goes back to the condition block
  block = bodyBlock;
  enterLoop(
      {/*breakBlock=*/exitBlock, /*continueBlock=*/condBlock, /*loopId=*/x->getId()});
  process(x->getBody());
  exitLoop();
  B->SetInsertPoint(block);
  B->CreateBr(condBlock);

  // reached once the coroutine reports done
  B->SetInsertPoint(cleanupBlock);
  B->CreateCall(coroDestroy, iter);
  B->CreateBr(exitBlock);

  block = exitBlock;
}

// Lowers a counted loop with a compile-time step into 'imp_for.cond' /
// 'imp_for.body' / 'imp_for.update' / 'imp_for.exit' blocks. The induction
// value is a 64-bit PHI in the condition block. Asserts that the loop is not
// parallel, that the loop variable has storage, and that the step is non-zero.
void LLVMVisitor::visit(const ImperativeForFlow *x) {
  seqassertn(!x->isParallel(), "parallel for-loop not lowered");
  auto *loopVar = getVar(x->getVar());
  seqassertn(loopVar, "{} loop variable not found", *x);
  seqassertn(x->getStep() != 0, "step cannot be 0");

  auto *condBlock = llvm::BasicBlock::Create(*context, "imp_for.cond", func);
  auto *bodyBlock = llvm::BasicBlock::Create(*context, "imp_for.body", func);
  auto *updateBlock = llvm::BasicBlock::Create(*context, "imp_for.update", func);
  auto *exitBlock = llvm::BasicBlock::Create(*context, "imp_for.exit", func);

  // start and end are each evaluated once, before the loop is entered
  process(x->getStart());
  auto *start = value;

  process(x->getEnd());
  auto *end = value;

  B->SetInsertPoint(block);
  B->CreateBr(condBlock);
  B->SetInsertPoint(condBlock);

  // first incoming edge: the start value from the block that jumps into the
  // loop; the second edge is added from the update block
  auto *phi = B->CreatePHI(B->getInt64Ty(), 2);
  phi->addIncoming(start, block);

  // the direction of the exit comparison depends on the sign of the step
  auto *done = (x->getStep() > 0) ?
B->CreateICmpSGE(phi, end) : B->CreateICmpSLE(phi, end); B->CreateCondBr(done, exitBlock, bodyBlock); B->SetInsertPoint(bodyBlock); B->CreateStore(phi, loopVar); block = bodyBlock; enterLoop( {/*breakBlock=*/exitBlock, /*continueBlock=*/updateBlock, /*loopId=*/x->getId()}); process(x->getBody()); exitLoop(); B->SetInsertPoint(block); B->CreateBr(updateBlock); B->SetInsertPoint(updateBlock); phi->addIncoming(B->CreateAdd(phi, B->getInt64(x->getStep())), updateBlock); B->CreateBr(condBlock); block = exitBlock; } namespace { bool anyMatch(types::Type *type, std::vector types) { if (type) { for (auto *t : types) { if (t && t->getName() == type->getName()) return true; } } else { for (auto *t : types) { if (!t) return true; } } return false; } } // namespace void LLVMVisitor::visit(const TryCatchFlow *x) { const bool isRoot = trycatch.empty(); const bool supportBreakAndContinue = !loops.empty(); auto *finallyFlow = cast(x->getFinally()); const bool haveFinally = (finallyFlow && finallyFlow->begin() != finallyFlow->end()); B->SetInsertPoint(block); auto *entryBlock = llvm::BasicBlock::Create(*context, "trycatch.entry", func); B->CreateBr(entryBlock); TryCatchData tc; tc.exceptionBlock = llvm::BasicBlock::Create(*context, "trycatch.exception", func); tc.exceptionRouteBlock = llvm::BasicBlock::Create(*context, "trycatch.exception_route", func); tc.finallyBlock = llvm::BasicBlock::Create(*context, "trycatch.finally", func); tc.finallyExceptionBlock = llvm::BasicBlock::Create(*context, "trycatch.finally.exception", func); auto *externalExcBlock = llvm::BasicBlock::Create(*context, "trycatch.exception_external", func); auto *unwindResumeBlock = llvm::BasicBlock::Create(*context, "trycatch.unwind_resume", func); auto *rethrowBlock = llvm::BasicBlock::Create(*context, "trycatch.rethrow", func); auto *endBlock = llvm::BasicBlock::Create(*context, "trycatch.end", func); B->SetInsertPoint(func->getEntryBlock().getTerminator()); auto *excStateNotThrown = 
B->getInt8(TryCatchData::State::NOT_THROWN);
  auto *excStateThrown = B->getInt8(TryCatchData::State::THROWN);
  auto *excStateCaught = B->getInt8(TryCatchData::State::CAUGHT);
  auto *excStateReturn = B->getInt8(TryCatchData::State::RETURN);
  auto *excStateBreak = B->getInt8(TryCatchData::State::BREAK);
  auto *excStateContinue = B->getInt8(TryCatchData::State::CONTINUE);
  auto *excStateRethrow = B->getInt8(TryCatchData::State::RETHROW);

  auto *padType = getPadType();
  auto *unwindType = llvm::StructType::get(B->getInt64Ty()); // header only

  // the outermost try-catch allocates and initializes the shared state slots;
  // nested ones reuse the slots of the outermost entry (trycatch[0])
  if (isRoot) {
    tc.excFlag = B->CreateAlloca(B->getInt8Ty());
    tc.catchStore = B->CreateAlloca(padType);
    tc.delegateDepth = B->CreateAlloca(B->getInt64Ty());
    // no return slot for coroutines or void functions
    tc.retStore = (coro.exit || func->getReturnType()->isVoidTy())
                      ? nullptr
                      : B->CreateAlloca(func->getReturnType());
    tc.loopSequence = B->CreateAlloca(B->getInt64Ty());
    B->CreateStore(excStateNotThrown, tc.excFlag);
    B->CreateStore(llvm::ConstantAggregateZero::get(padType), tc.catchStore);
    B->CreateStore(B->getInt64(0), tc.delegateDepth);
    B->CreateStore(B->getInt64(-1), tc.loopSequence);
  } else {
    tc.excFlag = trycatch[0].excFlag;
    tc.catchStore = trycatch[0].catchStore;
    tc.delegateDepth = trycatch[0].delegateDepth;
    tc.retStore = trycatch[0].retStore;
    tc.loopSequence = trycatch[0].loopSequence;
  }

  // translate finally
  block = tc.finallyBlock;
  process(x->getFinally());
  // the 'finally' body may have ended in a different block than it started in
  auto *finallyBlock = block;
  B->SetInsertPoint(finallyBlock);
  auto *excFlagRead = B->CreateLoad(B->getInt8Ty(), tc.excFlag);

  // nested try-catch: while delegateDepth is positive, control is handed on to
  // the enclosing try-catch instead of being dispatched here
  if (!isRoot) {
    auto *depthRead = B->CreateLoad(B->getInt64Ty(), tc.delegateDepth);
    auto *delegate = B->CreateICmpSGT(depthRead, B->getInt64(0));
    auto *finallyNormal =
        llvm::BasicBlock::Create(*context, "trycatch.finally.normal", func);
    auto *finallyDelegate =
        llvm::BasicBlock::Create(*context, "trycatch.finally.delegate", func);
    B->CreateCondBr(delegate, finallyDelegate, finallyNormal);

    B->SetInsertPoint(finallyDelegate);
    auto *depthNew = B->CreateSub(depthRead, B->getInt64(1));
    auto *delegateNew =
B->CreateICmpSGT(depthNew, B->getInt64(0));
    B->CreateStore(depthNew, tc.delegateDepth);
    // still positive after decrementing: continue into the parent's 'finally';
    // otherwise the parent's exception routing takes over
    B->CreateCondBr(delegateNew, trycatch.back().finallyBlock,
                    trycatch.back().exceptionRouteBlock);

    finallyBlock = finallyNormal;
    B->SetInsertPoint(finallyNormal);
  }

  // handle exceptions that must route through 'finally'
  B->SetInsertPoint(tc.finallyExceptionBlock);
  if (!DisableExceptions) {
    // record the landing pad, mark the state as RETHROW and set the delegation
    // depth to the current nesting level before running 'finally'
    auto *finallyCaughtResult = B->CreateLandingPad(padType, 1);
    finallyCaughtResult->setCleanup(true);
    finallyCaughtResult->addClause(getTypeIdxVar(nullptr));
    B->CreateStore(finallyCaughtResult, tc.catchStore);
    B->CreateStore(excStateRethrow, tc.excFlag);
    auto *depthMax = B->getInt64(trycatch.size());
    B->CreateStore(depthMax, tc.delegateDepth);
    B->CreateBr(tc.finallyBlock);
  } else {
    B->CreateUnreachable();
  }

  // dispatch at the end of 'finally' on the state flag read above; the default
  // destination is the end block
  B->SetInsertPoint(finallyBlock);
  auto *theSwitch =
      B->CreateSwitch(excFlagRead, endBlock, supportBreakAndContinue ? 6 : 4);
  theSwitch->addCase(excStateCaught, endBlock);
  theSwitch->addCase(excStateThrown, unwindResumeBlock);
  theSwitch->addCase(excStateRethrow, rethrowBlock);

  // a pending return is completed here for the root, and forwarded to the
  // parent's 'finally' otherwise
  if (isRoot) {
    auto *finallyReturn =
        llvm::BasicBlock::Create(*context, "trycatch.finally.return", func);
    theSwitch->addCase(excStateReturn, finallyReturn);
    B->SetInsertPoint(finallyReturn);
    if (coro.exit) {
      B->CreateBr(coro.exit);
    } else if (tc.retStore) {
      auto *retVal = B->CreateLoad(func->getReturnType(), tc.retStore);
      B->CreateRet(retVal);
    } else {
      B->CreateRetVoid();
    }
  } else {
    theSwitch->addCase(excStateReturn, trycatch.back().finallyBlock);
  }

  if (supportBreakAndContinue) {
    auto prevSeq = isRoot ?
-1 : trycatch.back().sequenceNumber;
    auto *finallyBreak =
        llvm::BasicBlock::Create(*context, "trycatch.finally.break", func);
    auto *finallyBreakDone =
        llvm::BasicBlock::Create(*context, "trycatch.finally.break.done", func);
    auto *finallyContinue =
        llvm::BasicBlock::Create(*context, "trycatch.finally.continue", func);
    auto *finallyContinueDone =
        llvm::BasicBlock::Create(*context, "trycatch.finally.continue.done", func);

    // each of these blocks switches on the stored loop sequence number; the
    // cases are filled in by the loop over 'loops' below. The '.done' variants
    // first reset the state flag to NOT_THROWN.
    B->SetInsertPoint(finallyBreak);
    auto *breakSwitch =
        B->CreateSwitch(B->CreateLoad(B->getInt64Ty(), tc.loopSequence), endBlock, 0);

    B->SetInsertPoint(finallyBreakDone);
    B->CreateStore(excStateNotThrown, tc.excFlag);
    auto *breakDoneSwitch =
        B->CreateSwitch(B->CreateLoad(B->getInt64Ty(), tc.loopSequence), endBlock, 0);

    B->SetInsertPoint(finallyContinue);
    auto *continueSwitch =
        B->CreateSwitch(B->CreateLoad(B->getInt64Ty(), tc.loopSequence), endBlock, 0);

    B->SetInsertPoint(finallyContinueDone);
    B->CreateStore(excStateNotThrown, tc.excFlag);
    auto *continueDoneSwitch =
        B->CreateSwitch(B->CreateLoad(B->getInt64Ty(), tc.loopSequence), endBlock, 0);

    for (auto &l : loops) {
      // loops whose sequence number is below that of the enclosing try-catch
      // are reached via the parent's 'finally'; the remaining loops jump to
      // their own break/continue blocks once the state has been reset
      if (!trycatch.empty() && l.sequenceNumber < prevSeq) {
        breakSwitch->addCase(B->getInt64(l.sequenceNumber),
                             trycatch.back().finallyBlock);
        continueSwitch->addCase(B->getInt64(l.sequenceNumber),
                                trycatch.back().finallyBlock);
      } else {
        breakSwitch->addCase(B->getInt64(l.sequenceNumber), finallyBreakDone);
        breakDoneSwitch->addCase(B->getInt64(l.sequenceNumber), l.breakBlock);
        continueSwitch->addCase(B->getInt64(l.sequenceNumber), finallyContinueDone);
        continueDoneSwitch->addCase(B->getInt64(l.sequenceNumber), l.continueBlock);
      }
    }
    theSwitch->addCase(excStateBreak, finallyBreak);
    theSwitch->addCase(excStateContinue, finallyContinue);
  }

  // try and catch translate
  std::vector catches;
  for (auto &c : *x) {
    catches.push_back(&c);
  }
  // one handler block per catch clause; a clause without a type is the
  // catch-all, of which there may be at most one
  llvm::BasicBlock *catchAll = nullptr;

  for (auto *c : catches) {
    auto *catchBlock = llvm::BasicBlock::Create(*context, "trycatch.catch", func);
tc.catchTypes.push_back(c->getType());
    tc.handlers.push_back(catchBlock);

    if (!c->getType()) {
      seqassertn(!catchAll, "cannot be catch all");
      catchAll = catchBlock;
    }
  }

  // translate try
  if (!loops.empty()) {
    // make sure we reset the state to avoid issues with 'break'/'continue'
    B->SetInsertPoint(entryBlock);
    B->CreateStore(excStateNotThrown, tc.excFlag);
  }
  block = entryBlock;
  if (haveFinally)
    enterFinally(tc);
  enterTry(tc); // this is last so as to have larger sequence number
  process(x->getBody());
  exitTry();

  // translate else
  if (x->getElse()) {
    B->SetInsertPoint(block);
    auto *elseBlock = llvm::BasicBlock::Create(*context, "trycatch.else", func);
    B->CreateBr(elseBlock);
    block = elseBlock;
    process(x->getElse());
  }

  // make sure we always get to finally block
  B->SetInsertPoint(block);
  B->CreateBr(tc.finallyBlock);

  // resume if uncaught
  B->SetInsertPoint(unwindResumeBlock);
  if (DisableExceptions) {
    B->CreateUnreachable();
  } else {
    B->CreateResume(B->CreateLoad(padType, tc.catchStore));
  }

  // make sure we delegate to parent try-catch if necessary
  // catchTypesFull / handlersFull / depths are parallel lists: this try-catch's
  // own clauses come first (depth 0), followed by clauses of enclosing
  // try-catches tagged with how many levels up they live
  std::vector catchTypesFull(tc.catchTypes);
  std::vector handlersFull(tc.handlers);
  std::vector depths(tc.catchTypes.size(), 0);
  unsigned depth = 1;

  unsigned catchAllDepth = 0;
  // walk the enclosing try-catches from innermost to outermost
  for (auto it = trycatch.rbegin(); it != trycatch.rend(); ++it) {
    if (catchAll) // can't ever delegate past catch-all
      break;

    seqassertn(it->catchTypes.size() == it->handlers.size(), "handler mismatch");
    for (unsigned i = 0; i < it->catchTypes.size(); i++) {
      // only add clauses not already covered by a nearer try-catch
      if (!anyMatch(it->catchTypes[i], catchTypesFull)) {
        catchTypesFull.push_back(it->catchTypes[i]);
        depths.push_back(depth);

        if (!it->catchTypes[i] && !catchAll) {
          // catch-all is in parent; set finally depth
          catchAll =
              llvm::BasicBlock::Create(*context, "trycatch.fdepth_catchall", func);
          B->SetInsertPoint(catchAll);
          B->CreateStore(B->getInt64(depth), tc.delegateDepth);
          B->CreateBr(it->handlers[i]);
          handlersFull.push_back(catchAll);
          catchAllDepth = depth;
        } else {
          handlersFull.push_back(it->handlers[i]);
        }
} } ++depth; } // exception handling B->SetInsertPoint(tc.exceptionBlock); llvm::LandingPadInst *caughtResult = nullptr; if (!DisableExceptions) { caughtResult = B->CreateLandingPad(padType, catches.size()); caughtResult->setCleanup(true); } std::vector typeIndices; for (auto *catchType : catchTypesFull) { seqassertn(!catchType || cast(catchType), "invalid catch type"); const std::string typeVarName = "codon.typeidx." + (catchType ? catchType->getName() : ""); auto *tidx = getTypeIdxVar(catchType); typeIndices.push_back(tidx); if (caughtResult) caughtResult->addClause(tidx); } auto *caughtResultOrUndef = caughtResult ? llvm::cast(caughtResult) : llvm::UndefValue::get(padType); auto *unwindException = B->CreateExtractValue(caughtResultOrUndef, 0); B->CreateStore(caughtResultOrUndef, tc.catchStore); B->CreateStore(excStateThrown, tc.excFlag); auto *depthMax = B->getInt64(trycatch.size()); B->CreateStore(depthMax, tc.delegateDepth); auto *unwindExceptionClass = B->CreateLoad( B->getInt64Ty(), B->CreateStructGEP(unwindType, unwindException, 0)); // check for foreign exceptions B->CreateCondBr( B->CreateICmpEQ(unwindExceptionClass, B->getInt64(SEQ_EXCEPTION_CLASS)), tc.exceptionRouteBlock, externalExcBlock); // external exception (currently assumed to be unreachable) B->SetInsertPoint(externalExcBlock); B->CreateUnreachable(); // reroute Codon exceptions B->SetInsertPoint(tc.exceptionRouteBlock); unwindException = B->CreateExtractValue(B->CreateLoad(padType, tc.catchStore), 0); auto *excVal = B->CreateConstGEP1_64(B->getInt8Ty(), unwindException, (uint64_t)seq_exc_offset()); auto *loadedExc = B->CreateLoad(B->getPtrTy(), excVal); // set depth when catch-all entered auto *defaultRouteBlock = llvm::BasicBlock::Create(*context, "trycatch.fdepth", func); B->SetInsertPoint(defaultRouteBlock); if (catchAll) B->CreateStore(B->getInt64(catchAllDepth), tc.delegateDepth); B->CreateBr(catchAll ? (catchAllDepth > 0 ? 
tc.finallyBlock : catchAll) : tc.finallyBlock);

  // switch on the second element of the saved pad to pick a handler; anything
  // unmatched goes to the default route block
  B->SetInsertPoint(tc.exceptionRouteBlock);
  auto *objType = B->CreateExtractValue(B->CreateLoad(padType, tc.catchStore), 1);
  auto *switchToCatchBlock =
      B->CreateSwitch(objType, defaultRouteBlock, (unsigned)handlersFull.size());
  for (unsigned i = 0; i < handlersFull.size(); i++) {
    // set finally depth
    auto *depthSet = llvm::BasicBlock::Create(*context, "trycatch.fdepth", func);
    B->SetInsertPoint(depthSet);
    B->CreateStore(B->getInt64(depths[i]), tc.delegateDepth);
    // our own handlers are entered directly; delegated ones are reached by
    // first running our 'finally'
    B->CreateBr((i < tc.handlers.size()) ? handlersFull[i] : tc.finallyBlock);

    // a null catch type (catch-all) gets no switch case
    if (catchTypesFull[i]) {
      switchToCatchBlock->addCase(
          B->getInt32((uint64_t)getTypeIdx(catchTypesFull[i])), depthSet);
    }

    // translate catch body if this block is ours (vs. a parent's)
    if (i < catches.size()) {
      block = handlersFull[i];
      B->SetInsertPoint(block);
      // bind the exception object to the catch variable, if the clause has one
      const Var *var = catches[i]->getVar();

      if (var) {
        auto *varPtr = getVar(var);
        seqassertn(varPtr, "could not get catch var");
        B->CreateStore(loadedExc, varPtr);
      }

      B->CreateStore(excStateCaught, tc.excFlag);
      CatchData cd;
      cd.exception = loadedExc;
      cd.typeId = objType;
      enterCatch(cd);
      process(catches[i]->getHandler());
      exitCatch();
      B->SetInsertPoint(block);
      B->CreateBr(tc.finallyBlock);
    }
  }

  if (haveFinally)
    exitFinally();

  // rethrow if handling 'finally' after exception raised from 'except'/'else'
  B->SetInsertPoint(rethrowBlock);
  if (!haveFinally || DisableExceptions) {
    B->CreateUnreachable();
  } else {
    auto throwFunc = makeThrowFunc();
    unwindException = B->CreateExtractValue(B->CreateLoad(padType, tc.catchStore), 0);
    block = rethrowBlock;
    call(throwFunc, unwindException);
    B->SetInsertPoint(block);
    B->CreateUnreachable();
  }

  block = endBlock;
}

// Calls one pipeline stage. The value current on entry is the previous stage's
// output; it is substituted for every null argument slot of the stage, while
// non-null arguments are evaluated normally. The call result becomes 'value'.
void LLVMVisitor::callStage(const PipelineFlow::Stage *stage) {
  auto *output = value;
  process(stage->getCallee());
  auto *f = value;
  std::vector args;
  for (const auto *arg : *stage) {
    if (arg) {
      process(arg);
      args.push_back(value);
    } else {
      args.push_back(output);
    }
  }

  auto *funcType =
getLLVMFuncType(stage->getCallee()->getType());
  value = call({funcType, f}, args);
}

// Recursively generates code for stages[where..]. Stage 0 is only evaluated
// (its callee is the pipeline input). For later stages: if the previous stage
// is a generator, a resume/done/promise loop is emitted and the rest of the
// pipeline is generated inside the loop body; otherwise the stage is called
// once and generation continues with the next stage.
void LLVMVisitor::codegenPipeline(
    const std::vector &stages, unsigned where) {
  if (where >= stages.size()) {
    return;
  }

  auto *stage = stages[where];

  if (where == 0) {
    process(stage->getCallee());
    codegenPipeline(stages, where + 1);
    return;
  }

  auto *prevStage = stages[where - 1];
  const bool generator = prevStage->isGenerator();

  if (generator) {
    auto *generatorType = cast(prevStage->getOutputType());
    seqassertn(generatorType, "{} is not a generator type",
               *prevStage->getOutputType());
    auto *baseType = getLLVMType(generatorType->getBase());

    auto *condBlock = llvm::BasicBlock::Create(*context, "pipeline.cond", func);
    auto *bodyBlock = llvm::BasicBlock::Create(*context, "pipeline.body", func);
    auto *cleanupBlock = llvm::BasicBlock::Create(*context, "pipeline.cleanup", func);
    auto *exitBlock = llvm::BasicBlock::Create(*context, "pipeline.exit", func);

    // LLVM coroutine intrinsics
    // https://prereleases.llvm.org/6.0.0/rc3/docs/Coroutines.html
    llvm::FunctionCallee coroResume =
        llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_resume);
    llvm::FunctionCallee coroDone =
        llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_done);
    llvm::FunctionCallee coroPromise =
        llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_promise);
    llvm::FunctionCallee coroDestroy =
        llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_destroy);

    // the previous stage's result is the coroutine handle to iterate over
    auto *iter = value;
    B->SetInsertPoint(block);
    B->CreateBr(condBlock);

    // resume goes through call(), which may change 'block', hence the insert
    // point is re-established afterwards
    block = condBlock;
    call(coroResume, {iter});
    B->SetInsertPoint(block);
    auto *done = B->CreateCall(coroDone, iter);
    B->CreateCondBr(done, cleanupBlock, bodyBlock);

    // load the yielded value from the promise; it becomes the stage's input
    B->SetInsertPoint(bodyBlock);
    auto *alignment =
        B->getInt32(M->getDataLayout().getPrefTypeAlign(baseType).value());
    auto *from = B->getFalse();
    auto *promise = B->CreateCall(coroPromise, {iter, alignment, from});
    value = B->CreateLoad(baseType, promise);

    block = bodyBlock;
    callStage(stage);
codegenPipeline(stages, where + 1); B->SetInsertPoint(block); B->CreateBr(condBlock); B->SetInsertPoint(cleanupBlock); B->CreateCall(coroDestroy, iter); B->CreateBr(exitBlock); block = exitBlock; } else { callStage(stage); codegenPipeline(stages, where + 1); } } void LLVMVisitor::visit(const PipelineFlow *x) { std::vector stages; for (const auto &stage : *x) { stages.push_back(&stage); } codegenPipeline(stages); } void LLVMVisitor::visit(const dsl::CustomFlow *x) { B->SetInsertPoint(block); value = x->getBuilder()->buildValue(this); } /* * Instructions */ void LLVMVisitor::visit(const AssignInstr *x) { auto *var = getVar(x->getLhs()); seqassertn(var, "could not find {} var", *x->getLhs()); process(x->getRhs()); if (var != getDummyVoidValue()) { B->SetInsertPoint(block); if (x->getLhs()->isThreadLocal()) var = B->CreateThreadLocalAddress(var); B->CreateStore(value, var); } } void LLVMVisitor::visit(const ExtractInstr *x) { auto *memberedType = cast(x->getVal()->getType()); seqassertn(memberedType, "{} is not a membered type", *x->getVal()->getType()); const int index = memberedType->getMemberIndex(x->getField()); seqassertn(index >= 0, "invalid index"); process(x->getVal()); B->SetInsertPoint(block); if (auto *refType = cast(memberedType)) { if (refType->isPolymorphic()) { // polymorphic ref type is ref to (data, rtti) value = B->CreateLoad(B->getPtrTy(), value); } value = B->CreateLoad(getLLVMType(refType->getContents()), value); } value = B->CreateExtractValue(value, index); } void LLVMVisitor::visit(const InsertInstr *x) { auto *refType = cast(x->getLhs()->getType()); seqassertn(refType, "{} is not a reference type", *x->getLhs()->getType()); const int index = refType->getMemberIndex(x->getField()); seqassertn(index >= 0, "invalid index"); process(x->getLhs()); auto *lhs = value; process(x->getRhs()); auto *rhs = value; B->SetInsertPoint(block); if (refType->isPolymorphic()) { // polymorphic ref type is ref to (data, rtti) lhs = B->CreateLoad(B->getPtrTy(), lhs); 
} llvm::Value *load = B->CreateLoad(getLLVMType(refType->getContents()), lhs); load = B->CreateInsertValue(load, rhs, index); B->CreateStore(load, lhs); } void LLVMVisitor::visit(const CallInstr *x) { B->SetInsertPoint(block); process(x->getCallee()); auto *f = value; std::vector args; for (auto *arg : *x) { B->SetInsertPoint(block); process(arg); args.push_back(value); } auto *funcType = getLLVMFuncType(x->getCallee()->getType()); value = call({funcType, f}, args); } void LLVMVisitor::visit(const TypePropertyInstr *x) { B->SetInsertPoint(block); switch (x->getProperty()) { case TypePropertyInstr::Property::SIZEOF: value = B->getInt64( M->getDataLayout().getTypeAllocSize(getLLVMType(x->getInspectType()))); break; case TypePropertyInstr::Property::IS_ATOMIC: value = B->getInt8(x->getInspectType()->isAtomic() ? 1 : 0); break; case TypePropertyInstr::Property::IS_CONTENT_ATOMIC: value = B->getInt8(x->getInspectType()->isContentAtomic() ? 1 : 0); break; default: seqassertn(0, "unknown type property"); } } void LLVMVisitor::visit(const YieldInInstr *x) { B->SetInsertPoint(block); if (x->isSuspending()) { llvm::FunctionCallee coroSuspend = llvm::Intrinsic::getDeclaration(M.get(), llvm::Intrinsic::coro_suspend); auto *tok = llvm::ConstantTokenNone::get(*context); auto *final = B->getFalse(); auto *susp = B->CreateCall(coroSuspend, {tok, final}); block = llvm::BasicBlock::Create(*context, "yieldin.new", func); auto *inst = B->CreateSwitch(susp, coro.suspend, 2); inst->addCase(B->getInt8(0), block); inst->addCase(B->getInt8(1), coro.cleanup); B->SetInsertPoint(block); } value = B->CreateLoad(getLLVMType(x->getType()), coro.promise); } void LLVMVisitor::visit(const StackAllocInstr *x) { auto *recordType = cast(x->getType()); seqassertn(recordType, "stack alloc does not have record type"); auto *ptrType = cast(recordType->back().getType()); seqassertn(ptrType, "array did not have ptr type"); auto *arrayType = llvm::cast(getLLVMType(x->getType())); 
B->SetInsertPoint(func->getEntryBlock().getTerminator()); auto *len = B->getInt64(x->getCount()); auto *ptr = B->CreateAlloca(getLLVMType(ptrType->getBase()), len); llvm::Value *arr = llvm::UndefValue::get(arrayType); arr = B->CreateInsertValue(arr, len, 0); arr = B->CreateInsertValue(arr, ptr, 1); value = arr; } void LLVMVisitor::visit(const TernaryInstr *x) { auto *trueBlock = llvm::BasicBlock::Create(*context, "ternary.true", func); auto *falseBlock = llvm::BasicBlock::Create(*context, "ternary.false", func); auto *exitBlock = llvm::BasicBlock::Create(*context, "ternary.exit", func); auto *valueType = getLLVMType(x->getType()); process(x->getCond()); auto *cond = value; B->SetInsertPoint(block); cond = B->CreateTrunc(cond, B->getInt1Ty()); B->CreateCondBr(cond, trueBlock, falseBlock); block = trueBlock; process(x->getTrueValue()); auto *trueValue = value; trueBlock = block; B->SetInsertPoint(trueBlock); B->CreateBr(exitBlock); block = falseBlock; process(x->getFalseValue()); auto *falseValue = value; falseBlock = block; B->SetInsertPoint(falseBlock); B->CreateBr(exitBlock); B->SetInsertPoint(exitBlock); auto *phi = B->CreatePHI(valueType, 2); phi->addIncoming(trueValue, trueBlock); phi->addIncoming(falseValue, falseBlock); value = phi; block = exitBlock; } void LLVMVisitor::visit(const BreakInstr *x) { seqassertn(!loops.empty(), "not in a loop"); B->SetInsertPoint(block); auto *loop = !x->getLoop() ? 
&loops.back() : getLoopData(x->getLoop()->getId()); if (finally.empty() || finally.back().sequenceNumber < loop->sequenceNumber) { B->CreateBr(loop->breakBlock); } else { auto *tc = &finally.back(); auto *excStateBreak = B->getInt8(TryCatchData::State::BREAK); B->CreateStore(excStateBreak, tc->excFlag); B->CreateStore(B->getInt64(loop->sequenceNumber), tc->loopSequence); B->CreateBr(tc->finallyBlock); } block = llvm::BasicBlock::Create(*context, "break.new", func); } void LLVMVisitor::visit(const ContinueInstr *x) { seqassertn(!loops.empty(), "not in a loop"); B->SetInsertPoint(block); auto *loop = !x->getLoop() ? &loops.back() : getLoopData(x->getLoop()->getId()); if (finally.empty() || finally.back().sequenceNumber < loop->sequenceNumber) { B->CreateBr(loop->continueBlock); } else { auto *tc = &finally.back(); auto *excStateContinue = B->getInt8(TryCatchData::State::CONTINUE); B->CreateStore(excStateContinue, tc->excFlag); B->CreateStore(B->getInt64(loop->sequenceNumber), tc->loopSequence); B->CreateBr(tc->finallyBlock); } block = llvm::BasicBlock::Create(*context, "continue.new", func); } void LLVMVisitor::visit(const ReturnInstr *x) { if (x->getValue()) { process(x->getValue()); } B->SetInsertPoint(block); if (coro.exit) { if (coro.async) B->CreateStore(value, coro.promise); if (auto *tc = getInnermostTryCatch()) { auto *excStateReturn = B->getInt8(TryCatchData::State::RETURN); B->CreateStore(excStateReturn, tc->excFlag); B->CreateBr(tc->finallyBlock); } else { B->CreateBr(coro.exit); } } else { if (auto *tc = getInnermostTryCatch()) { auto *excStateReturn = B->getInt8(TryCatchData::State::RETURN); B->CreateStore(excStateReturn, tc->excFlag); if (tc->retStore) { seqassertn(value, "no return value storage"); B->CreateStore(value, tc->retStore); } B->CreateBr(tc->finallyBlock); } else { if (x->getValue()) { B->CreateRet(value); } else { B->CreateRetVoid(); } } } block = llvm::BasicBlock::Create(*context, "return.new", func); } void LLVMVisitor::visit(const 
YieldInstr *x) { if (x->isFinal()) { if (x->getValue()) { seqassertn(coro.promise, "no coroutine promise"); process(x->getValue()); B->SetInsertPoint(block); B->CreateStore(value, coro.promise); } B->SetInsertPoint(block); if (auto *tc = getInnermostTryCatch()) { auto *excStateReturn = B->getInt8(TryCatchData::State::RETURN); B->CreateStore(excStateReturn, tc->excFlag); B->CreateBr(tc->finallyBlock); } else { B->CreateBr(coro.exit); } block = llvm::BasicBlock::Create(*context, "yield.new", func); } else { if (x->getValue()) { process(x->getValue()); makeYield(value); } else { makeYield(nullptr); } } } void LLVMVisitor::visit(const AwaitInstr *x) { seqassertn(false, "await instruction not lowered"); } void LLVMVisitor::visit(const ThrowInstr *x) { if (DisableExceptions) { B->SetInsertPoint(block); B->CreateUnreachable(); block = llvm::BasicBlock::Create(*context, "throw_unreachable.new", func); return; } // note: exception header should be set in the frontend auto excAllocFunc = makeExcAllocFunc(); auto throwFunc = makeThrowFunc(); llvm::Value *obj = nullptr; llvm::Value *typ = nullptr; if (x->getValue()) { process(x->getValue()); obj = value; typ = B->getInt32(getTypeIdx(x->getValue()->getType())); } else { seqassertn(!catches.empty(), "empty raise outside of except block"); obj = catches.back().exception; typ = catches.back().typeId; } B->SetInsertPoint(block); auto *exc = B->CreateCall(excAllocFunc, {obj}); call(throwFunc, exc); } void LLVMVisitor::visit(const FlowInstr *x) { process(x->getFlow()); process(x->getValue()); } void LLVMVisitor::visit(const dsl::CustomInstr *x) { B->SetInsertPoint(block); value = x->getBuilder()->buildValue(this); } } // namespace ir } // namespace codon ================================================ FILE: codon/cir/llvm/llvisitor.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once #include "codon/cir/cir.h" #include "codon/cir/llvm/llvm.h" #include "codon/cir/pyextension.h" #include "codon/dsl/plugins.h" #include "codon/util/common.h" #include #include #include #include namespace codon { namespace ir { class LLVMVisitor : public util::ConstVisitor { private: struct CoroData { /// Coroutine promise (where yielded values are stored) llvm::Value *promise; /// Coroutine handle llvm::Value *handle; /// Coroutine cleanup block llvm::BasicBlock *cleanup; /// Coroutine suspend block llvm::BasicBlock *suspend; /// Coroutine exit block llvm::BasicBlock *exit; /// True if coroutine represents 'async' function bool async; void reset() { promise = handle = cleanup = suspend = exit = nullptr; async = false; } }; struct NestableData { int sequenceNumber; NestableData() : sequenceNumber(-1) {} }; struct LoopData : NestableData { /// Block to branch to in case of "break" llvm::BasicBlock *breakBlock; /// Block to branch to in case of "continue" llvm::BasicBlock *continueBlock; /// Loop id id_t loopId; LoopData(llvm::BasicBlock *breakBlock, llvm::BasicBlock *continueBlock, id_t loopId) : NestableData(), breakBlock(breakBlock), continueBlock(continueBlock), loopId(loopId) {} void reset() { breakBlock = continueBlock = nullptr; } }; struct TryCatchData : NestableData { /// Possible try-catch states when reaching finally block enum State { NOT_THROWN = 0, THROWN, CAUGHT, RETURN, BREAK, CONTINUE, RETHROW }; /// Exception block llvm::BasicBlock *exceptionBlock; /// Exception route block llvm::BasicBlock *exceptionRouteBlock; /// Finally start block llvm::BasicBlock *finallyBlock; /// Block to support exceptions raised from 'except' and 'else' blocks llvm::BasicBlock *finallyExceptionBlock; /// Try-catch catch types std::vector catchTypes; /// Try-catch handlers, corresponding to catch types std::vector handlers; /// Exception state flag (see "State") llvm::Value *excFlag; /// Storage for caught exception llvm::Value *catchStore; /// How far to 
delegate up the finally chain llvm::Value *delegateDepth; /// Storage for postponed return llvm::Value *retStore; /// Loop being manipulated llvm::Value *loopSequence; TryCatchData() : NestableData(), exceptionBlock(nullptr), exceptionRouteBlock(nullptr), finallyBlock(nullptr), finallyExceptionBlock(nullptr), catchTypes(), handlers(), excFlag(nullptr), catchStore(nullptr), delegateDepth(nullptr), retStore(nullptr), loopSequence(nullptr) {} void reset() { exceptionBlock = exceptionRouteBlock = finallyBlock = finallyExceptionBlock = nullptr; catchTypes.clear(); handlers.clear(); excFlag = catchStore = delegateDepth = loopSequence = nullptr; } }; struct CatchData : NestableData { /// Exception object llvm::Value *exception; /// Exception type ID llvm::Value *typeId; }; struct DebugInfo { /// LLVM debug info builder std::unique_ptr builder; /// Current compilation unit llvm::DICompileUnit *unit; /// Whether we are compiling in debug mode bool debug; /// Whether we are compiling in JIT mode bool jit; /// Whether we are compiling a standalone object/executable bool standalone; /// Whether to capture writes to stdout/stderr bool capture; /// Program command-line flags std::string flags; DebugInfo() : builder(), unit(nullptr), debug(false), jit(false), standalone(false), capture(false), flags() {} llvm::DIFile *getFile(const std::string &path); void reset() { builder = {}; unit = nullptr; } }; /// LLVM context used for compilation std::unique_ptr context; /// Module we are compiling std::unique_ptr M; /// LLVM IR builder used for constructing LLVM IR std::unique_ptr> B; /// Current function we are compiling llvm::Function *func; /// Current basic block we are compiling llvm::BasicBlock *block; /// Last compiled value llvm::Value *value; /// LLVM values corresponding to IR variables std::unordered_map vars; /// LLVM functions corresponding to IR functions std::unordered_map funcs; /// Coroutine data, if current function is a coroutine CoroData coro; /// Loop data stack, 
containing break/continue blocks std::vector loops; /// Try-block data stack std::vector trycatch; /// Finally-block data stack std::vector finally; /// Catch-block data stack std::vector catches; /// Debug information DebugInfo db; /// Plugin manager PluginManager *plugins; llvm::DIType * getDITypeHelper(types::Type *t, std::unordered_map &cache); /// GC allocation functions llvm::FunctionCallee makeAllocFunc(bool atomic, bool uncollectable = false); // GC reallocation function llvm::FunctionCallee makeReallocFunc(); // GC free function llvm::FunctionCallee makeFreeFunc(); /// Personality function for exception handling llvm::FunctionCallee makePersonalityFunc(); /// Exception allocation function llvm::FunctionCallee makeExcAllocFunc(); /// Exception throw function llvm::FunctionCallee makeThrowFunc(); /// Program termination function llvm::FunctionCallee makeTerminateFunc(); // Try-catch types and utilities llvm::StructType *getTypeInfoType(); llvm::StructType *getPadType(); llvm::StructType *getExceptionType(); llvm::GlobalVariable *getTypeIdxVar(types::Type *catchType); int getTypeIdx(types::Type *catchType = nullptr); // General function helpers llvm::Value *call(llvm::FunctionCallee callee, llvm::ArrayRef args); llvm::Function *makeLLVMFunction(const Func *); void makeYield(llvm::Value *value = nullptr, bool finalYield = false); std::string buildLLVMCodeString(const LLVMFunc *); void callStage(const PipelineFlow::Stage *stage); void codegenPipeline(const std::vector &stages, unsigned where = 0); // Loop and try-catch state void enterLoop(LoopData data); void exitLoop(); void enterTry(TryCatchData data); void exitTry(); void enterFinally(TryCatchData data); void exitFinally(); void enterCatch(CatchData data); void exitCatch(); TryCatchData *getInnermostTryCatch(); TryCatchData *getInnermostTryCatchBeforeLoop(); // Global constructor setup void setupGlobalCtor(); // Python extension setup llvm::Function *createPyTryCatchWrapper(llvm::Function *func); // LLVM 
passes void runLLVMPipeline(); llvm::Value *getVar(const Var *var); void insertVar(const Var *var, llvm::Value *x) { vars.emplace(var->getId(), x); } llvm::Function *getFunc(const Func *func); void insertFunc(const Func *func, llvm::Function *x) { funcs.emplace(func->getId(), x); } llvm::Value *getDummyVoidValue() { return llvm::ConstantTokenNone::get(*context); } llvm::DISubprogram *getDISubprogramForFunc(const Func *x); void clearLLVMData(); public: static std::string getGlobalCtorName(); static std::string getNameForFunction(const Func *x); static std::string getNameForVar(const Var *x); static std::string getDebugNameForVariable(const Var *x) { std::string name = x->getName(); auto pos = name.find("."); if (pos != 0 && pos != std::string::npos) { return name.substr(0, pos); } else { return name; } } static const SrcInfo *getDefaultSrcInfo() { static SrcInfo defaultSrcInfo("", 0, 0, 0); return &defaultSrcInfo; } static const SrcInfo *getSrcInfo(const Node *x) { if (auto *srcInfo = x->getAttribute()) { return &srcInfo->info; } else { return getDefaultSrcInfo(); } } /// Constructs an LLVM visitor. LLVMVisitor(); /// @return true if in debug mode, false otherwise bool getDebug() const { return db.debug; } /// Sets debug status. /// @param d true if debug mode void setDebug(bool d = true) { db.debug = d; } /// @return true if in JIT mode, false otherwise bool getJIT() const { return db.jit; } /// Sets JIT status. /// @param j true if JIT mode void setJIT(bool j = true) { db.jit = j; } /// @return true if in standalone mode, false otherwise bool getStandalone() const { return db.standalone; } /// Sets standalone status. /// @param s true if standalone void setStandalone(bool s = true) { db.standalone = s; } /// @return true if capturing outputs, false otherwise bool getCapture() const { return db.capture; } /// Sets capture status. 
/// @param c true to capture void setCapture(bool c = true) { db.capture = c; } /// @return program flags std::string getFlags() const { return db.flags; } /// Sets program flags. /// @param f flags void setFlags(const std::string &f) { db.flags = f; } llvm::LLVMContext &getContext() { return *context; } llvm::IRBuilder<> &getBuilder() { return *B; } llvm::Module *getModule() { return M.get(); } llvm::FunctionCallee getFunc() { return func; } llvm::BasicBlock *getBlock() { return block; } llvm::Value *getValue() { return value; } std::unordered_map &getVars() { return vars; } std::unordered_map &getFuncs() { return funcs; } CoroData &getCoro() { return coro; } std::vector &getLoops() { return loops; } std::vector &getTryCatch() { return trycatch; } DebugInfo &getDebugInfo() { return db; } void setFunc(llvm::Function *f) { func = f; } void setBlock(llvm::BasicBlock *b) { block = b; } void setValue(llvm::Value *v) { value = v; } /// Registers a new global variable or function with /// this visitor. /// @param var the global variable (or function) to register void registerGlobal(const Var *var); /// Returns a new LLVM module initialized for the host /// architecture. /// @param context LLVM context used for creating module /// @param src source information for the new module /// @return a new module std::unique_ptr makeModule(llvm::LLVMContext &context, const SrcInfo *src = nullptr); /// Returns the current module/LLVM context and replaces them /// with new, fresh ones. References to variables or functions /// from the old module will be included as "external". /// @param module the IR module /// @param src source information for the new module /// @return the current module/context, replaced internally std::pair, std::unique_ptr> takeModule(Module *module, const SrcInfo *src = nullptr); /// Sets current debug info based on a given node. 
/// @param node the node whose debug info to use void setDebugInfoForNode(const Node *node); /// Compiles a given IR node, updating the internal /// LLVM value and/or function as a result. /// @param node the node to compile void process(const Node *node); /// Dumps the unoptimized module IR to a file. /// @param filename name of file to write IR to void dump(const std::string &filename = "_dump.ll"); /// Writes module as native object file. /// @param filename the .o file to write to /// @param pic true to write position-independent code /// @param assembly true to write assembly instead of binary void writeToObjectFile(const std::string &filename, bool pic = false, bool assembly = false); /// Writes module as LLVM bitcode file. /// @param filename the .bc file to write to void writeToBitcodeFile(const std::string &filename); /// Writes module as LLVM IR file. /// @param filename the .ll file to write to void writeToLLFile(const std::string &filename, bool optimize = true); /// Writes module as native executable. Invokes an /// external linker to generate the final executable. /// @param filename the file to write to /// @param argv0 compiler's argv[0] used to set rpath /// @param library whether to make a shared library /// @param libs library names to link /// @param lflags extra flags to pass linker void writeToExecutable(const std::string &filename, const std::string &argv0, bool library = false, const std::vector &libs = {}, const std::string &lflags = ""); /// Writes module as Python extension object. /// @param pymod extension module /// @param filename the file to write to void writeToPythonExtension(const PyModule &pymod, const std::string &filename); /// Runs optimization passes on module and writes the result /// to the specified file. The output type is determined by /// the file extension (.ll for LLVM IR, .bc for LLVM bitcode /// .o or .obj for object file, other for executable). 
/// @param filename name of the file to write to /// @param argv0 compiler's argv[0] used to set rpath /// @param libs library names to link to, if creating executable /// @param lflags extra flags to pass linker, if creating executable void compile(const std::string &filename, const std::string &argv0, const std::vector &libs = {}, const std::string &lflags = ""); /// Runs optimization passes on module and executes it. /// @param args vector of arguments to program /// @param libs vector of libraries to load /// @param envp program environment void run(const std::vector &args = {}, const std::vector &libs = {}, const char *const *envp = nullptr); /// Gets LLVM type from IR type /// @param t the IR type /// @return corresponding LLVM type llvm::Type *getLLVMType(types::Type *t); /// Gets LLVM function type from IR function type /// @param t the IR type (must be FuncType) /// @return corresponding LLVM function type llvm::FunctionType *getLLVMFuncType(types::Type *t); /// Gets the LLVM debug info type from the IR type /// @param t the IR type /// @return corresponding LLVM DI type llvm::DIType *getDIType(types::Type *t); /// Gets loop data for a given loop id /// @param loopId the IR id of the loop /// @return the loop's datas LoopData *getLoopData(id_t loopId); /// Sets the plugin manager /// @param p the plugin manager void setPluginManager(PluginManager *p) { plugins = p; } /// @return the plugin manager PluginManager *getPluginManager() { return plugins; } void visit(const Module *) override; void visit(const BodiedFunc *) override; void visit(const ExternalFunc *) override; void visit(const InternalFunc *) override; void visit(const LLVMFunc *) override; void visit(const Var *) override; void visit(const VarValue *) override; void visit(const PointerValue *) override; void visit(const IntConst *) override; void visit(const FloatConst *) override; void visit(const BoolConst *) override; void visit(const StringConst *) override; void visit(const dsl::CustomConst 
*) override; void visit(const SeriesFlow *) override; void visit(const IfFlow *) override; void visit(const WhileFlow *) override; void visit(const ForFlow *) override; void visit(const ImperativeForFlow *) override; void visit(const TryCatchFlow *) override; void visit(const PipelineFlow *) override; void visit(const dsl::CustomFlow *) override; void visit(const AssignInstr *) override; void visit(const ExtractInstr *) override; void visit(const InsertInstr *) override; void visit(const CallInstr *) override; void visit(const StackAllocInstr *) override; void visit(const TypePropertyInstr *) override; void visit(const YieldInInstr *) override; void visit(const TernaryInstr *) override; void visit(const BreakInstr *) override; void visit(const ContinueInstr *) override; void visit(const ReturnInstr *) override; void visit(const YieldInstr *) override; void visit(const AwaitInstr *) override; void visit(const ThrowInstr *) override; void visit(const FlowInstr *) override; void visit(const dsl::CustomInstr *) override; }; } // namespace ir } // namespace codon ================================================ FILE: codon/cir/llvm/llvm.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once #include "llvm/ADT/BitVector.h" #include "llvm/ADT/FunctionExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/CycleAnalysis.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/RegionPass.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/AsmParser/Parser.h" #include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/CodeGen/CommandFlags.h" #include "llvm/CodeGen/MachineModuleInfo.h" #include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/DebugInfo/Symbolize/Symbolize.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/ExecutionEngine/GenericValue.h" #include "llvm/ExecutionEngine/JITEventListener.h" #include "llvm/ExecutionEngine/JITLink/JITLink.h" #include "llvm/ExecutionEngine/JITLink/JITLinkDylib.h" #include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/MCJIT.h" #include "llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h" #include "llvm/ExecutionEngine/Orc/CompileUtils.h" #include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/DebugObjectManagerPlugin.h" #include "llvm/ExecutionEngine/Orc/DebugUtils.h" #include "llvm/ExecutionEngine/Orc/ELFNixPlatform.h" #include "llvm/ExecutionEngine/Orc/EPCDebugObjectRegistrar.h" #include "llvm/ExecutionEngine/Orc/EPCDynamicLibrarySearchGenerator.h" #include "llvm/ExecutionEngine/Orc/EPCEHFrameRegistrar.h" #include "llvm/ExecutionEngine/Orc/EPCGenericRTDyldMemoryManager.h" #include "llvm/ExecutionEngine/Orc/EPCIndirectionUtils.h" #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" 
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" #include "llvm/ExecutionEngine/Orc/LLJIT.h" #include "llvm/ExecutionEngine/Orc/MachOPlatform.h" #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" #include "llvm/ExecutionEngine/Orc/Shared/AllocationActions.h" #include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" #include "llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h" #include "llvm/ExecutionEngine/Orc/SymbolStringPool.h" #include "llvm/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.h" #include "llvm/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.h" #include "llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h" #include "llvm/ExecutionEngine/RuntimeDyld.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DIBuilder.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LLVMRemarkStreamer.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/LegacyPassNameParser.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Verifier.h" #include "llvm/IRReader/IRReader.h" #include "llvm/InitializePasses.h" #include "llvm/LinkAllIR.h" #include "llvm/LinkAllPasses.h" #include "llvm/Linker/Linker.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DynamicLibrary.h" #include "llvm/Support/Error.h" #include "llvm/Support/FileSystem.h" #include 
"llvm/Support/FormatVariadic.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/Memory.h" #include "llvm/Support/Process.h" #include "llvm/Support/RecyclingAllocator.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/SystemUtils.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/YAMLTraits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetLoweringObjectFile.h" #include "llvm/Target/TargetMachine.h" #include "llvm/TargetParser/Host.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/IPO/GlobalDCE.h" #include "llvm/Transforms/IPO/Internalize.h" #include "llvm/Transforms/IPO/StripDeadPrototypes.h" #include "llvm/Transforms/IPO/StripSymbols.h" #include "llvm/Transforms/IPO/WholeProgramDevirt.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Debugify.h" ================================================ FILE: codon/cir/llvm/native/native.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "native.h"

#include "codon/cir/llvm/llvm.h"
#include "codon/cir/llvm/native/targets/aarch64.h"
#include "codon/cir/llvm/native/targets/arm.h"
#include "codon/cir/llvm/native/targets/x86.h"

// Targets adapted from
// https://github.com/llvm/llvm-project/tree/main/clang/lib/Driver/ToolChains/Arch

namespace codon {
namespace ir {
namespace {
/// Picks the target description matching the triple's architecture.
/// @param triple the target triple
/// @return a target object for the ARM, AArch64 and x86 families; an empty
///         pointer for every other architecture
std::unique_ptr<Target> getNativeTarget(const llvm::Triple &triple) {
  std::unique_ptr<Target> result = std::unique_ptr<Target>();

  switch (triple.getArch()) {
  default:
    break;

  case llvm::Triple::mips:
  case llvm::Triple::mipsel:
  case llvm::Triple::mips64:
  case llvm::Triple::mips64el:
    // nothing
    break;

  case llvm::Triple::arm:
  case llvm::Triple::armeb:
  case llvm::Triple::thumb:
  case llvm::Triple::thumbeb:
    result = std::make_unique<ARM>();
    break;

  case llvm::Triple::ppc:
  case llvm::Triple::ppcle:
  case llvm::Triple::ppc64:
  case llvm::Triple::ppc64le:
    // nothing
    break;

  case llvm::Triple::riscv32:
  case llvm::Triple::riscv64:
    // nothing
    break;

  case llvm::Triple::systemz:
    // nothing
    break;

  case llvm::Triple::aarch64:
  case llvm::Triple::aarch64_32:
  case llvm::Triple::aarch64_be:
    result = std::make_unique<Aarch64>();
    break;

  case llvm::Triple::x86:
  case llvm::Triple::x86_64:
    result = std::make_unique<X86>();
    break;

  case llvm::Triple::hexagon:
    // nothing
    break;

  case llvm::Triple::wasm32:
  case llvm::Triple::wasm64:
    // nothing
    break;

  case llvm::Triple::sparc:
  case llvm::Triple::sparcel:
  case llvm::Triple::sparcv9:
    // nothing
    break;

  case llvm::Triple::r600:
  case llvm::Triple::amdgcn:
    // nothing
    break;

  case llvm::Triple::msp430:
    // nothing
    break;

  case llvm::Triple::ve:
    // nothing
    break;
  }

  return result;
}

/// Function pass that stamps target attributes onto every function it visits:
/// "target-cpu" and "target-features" when the corresponding string is
/// non-empty, and "frame-pointer"="none" unconditionally.
class ArchNativePass : public llvm::PassInfoMixin<ArchNativePass> {
private:
  std::string cpu;
  std::string features;

public:
  explicit ArchNativePass(const std::string &cpu = "",
                          const std::string &features = "")
      : cpu(cpu), features(features) {}

  llvm::PreservedAnalyses run(llvm::Function &F, llvm::FunctionAnalysisManager &) {
    if (!cpu.empty())
      F.addFnAttr("target-cpu", cpu);
    if (!features.empty())
      F.addFnAttr("target-features", features);
    F.addFnAttr("frame-pointer", "none");
    return llvm::PreservedAnalyses::all();
  }
};
} // namespace

/// Registers an early-simplification callback on the pass builder that runs
/// ArchNativePass over every function, using the CPU name and feature string
/// reported by the native target for the selected triple. Does nothing when
/// no target machine can be selected or the architecture has no native target
/// description.
/// @param pb the pass builder to register the callback with
void addNativeLLVMPasses(llvm::PassBuilder *pb) {
  // selectTarget() returns a TargetMachine owned by the caller and may return
  // null; hold it in a unique_ptr so it is released, and bail out if absent.
  std::unique_ptr<llvm::TargetMachine> machine(llvm::EngineBuilder().selectTarget());
  if (!machine)
    return;

  // copy the triple: `machine` is destroyed when this function returns
  const llvm::Triple triple = machine->getTargetTriple();
  auto target = getNativeTarget(triple);
  if (!target)
    return;
  std::string cpu = target->getCPU(triple);
  std::string features = target->getFeatures(triple);

  // cpu/features are captured by value because the callback outlives this call
  pb->registerPipelineEarlySimplificationEPCallback(
      [cpu, features](llvm::ModulePassManager &pm, llvm::OptimizationLevel opt,
                      llvm::ThinOrFullLTOPhase lto) {
        pm.addPass(
            llvm::createModuleToFunctionPassAdaptor(ArchNativePass(cpu, features)));
      });
}

} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/llvm/native/native.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include "codon/cir/llvm/llvm.h"

namespace codon {
namespace ir {

/// Registers passes that tag functions with host CPU/feature attributes.
/// @param pb the pass builder to register with
void addNativeLLVMPasses(llvm::PassBuilder *pb);

} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/llvm/native/targets/aarch64.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include "aarch64.h"

#include "llvm/TargetParser/AArch64TargetParser.h"

namespace codon {
namespace ir {
namespace {
/// Concatenates the elements of `v`, each converted to std::string, placing
/// `delim` between consecutive elements.
template <typename T>
std::string join(const T &v, const std::string &delim = ",") {
  std::ostringstream out;
  bool first = true;
  for (const auto &item : v) {
    if (!first)
      out << delim;
    first = false;
    out << std::string(item);
  }
  return out.str();
}
} // namespace

/// Returns the host CPU name as reported by LLVM; `triple` is not consulted.
std::string Aarch64::getCPU(const llvm::Triple &triple) const {
  return std::string(llvm::sys::getHostCPUName());
}

/// Builds the comma-separated feature string for the host CPU.
///
/// Returns an empty string when the host CPU name is not recognised by the
/// AArch64 target parser. Otherwise the list is, in order: the CPU and
/// architecture default extensions, the host CPU features prefixed with
/// '+' or '-', and a few CPU/OS specific additions.
std::string Aarch64::getFeatures(const llvm::Triple &triple) const {
  const std::string hostCPU(llvm::sys::getHostCPUName());
  const auto parsed = llvm::AArch64::parseCpu(hostCPU);
  if (!parsed)
    return "";

  const auto *arch = llvm::AArch64::getArchForCpu(hostCPU);

  std::vector<std::string> result;
  llvm::AArch64::ExtensionSet exts;
  exts.addCPUDefaults(*parsed);
  exts.addArchDefaults(*arch);
  exts.toLLVMFeatureList(result);

  for (auto &entry : llvm::sys::getHostCPUFeatures()) {
    const char *sign = entry.second ? "+" : "-";
    result.push_back(sign + entry.first().str());
  }

  const bool isAppleCore =
      hostCPU == "cyclone" || llvm::StringRef(hostCPU).starts_with("apple");
  if (isAppleCore) {
    result.push_back("+zcm");
    result.push_back("+zcz");
  }

  // The A53 erratum 835769 workaround is always enabled on Android/OHOS, and
  // on Fuchsia only for an empty, "generic" or "cortex-a53" CPU name.
  const bool wantsA53Fix =
      triple.isAndroid() || triple.isOHOSFamily() ||
      (triple.isOSFuchsia() &&
       (hostCPU.empty() || hostCPU == "generic" || hostCPU == "cortex-a53"));
  if (wantsA53Fix)
    result.push_back("+fix-cortex-a53-835769");

  if (triple.isOSOpenBSD())
    result.push_back("+strict-align");

  return join(result);
}

} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/llvm/native/targets/aarch64.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once #include "codon/cir/llvm/native/targets/target.h" namespace codon { namespace ir { class Aarch64 : public Target { public: std::string getCPU(const llvm::Triple &triple) const override; std::string getFeatures(const llvm::Triple &triple) const override; }; } // namespace ir } // namespace codon ================================================ FILE: codon/cir/llvm/native/targets/arm.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "arm.h" #include "llvm/TargetParser/ARMTargetParser.h" namespace codon { namespace ir { namespace { template std::string join(const T &v, const std::string &delim = ",") { std::ostringstream s; for (const auto &i : v) { if (&i != &v[0]) s << delim; s << std::string(i); } return s.str(); } int getARMSubArchVersionNumber(const llvm::Triple &triple) { auto arch = triple.getArchName(); return llvm::ARM::parseArchVersion(arch); } bool isARMMProfile(const llvm::Triple &triple) { auto arch = triple.getArchName(); return llvm::ARM::parseArchProfile(arch) == llvm::ARM::ProfileKind::M; } bool isARMBigEndian(const llvm::Triple &triple) { return triple.getArch() == llvm::Triple::armeb || triple.getArch() == llvm::Triple::thumbeb; } bool isARMAProfile(const llvm::Triple &triple) { auto arch = triple.getArchName(); return llvm::ARM::parseArchProfile(arch) == llvm::ARM::ProfileKind::A; } bool isARMEABIBareMetal(const llvm::Triple &triple) { auto arch = triple.getArch(); if (arch != llvm::Triple::arm && arch != llvm::Triple::thumb && arch != llvm::Triple::armeb && arch != llvm::Triple::thumbeb) return false; if (triple.getVendor() != llvm::Triple::UnknownVendor) return false; if (triple.getOS() != llvm::Triple::UnknownOS) return false; if (triple.getEnvironment() != llvm::Triple::EABI && triple.getEnvironment() != llvm::Triple::EABIHF) return false; return true; } bool useAAPCSForMachO(const llvm::Triple &triple) { return triple.getEnvironment() == llvm::Triple::EABI || 
triple.getEnvironment() == llvm::Triple::EABIHF || triple.getOS() == llvm::Triple::UnknownOS || isARMMProfile(triple); } enum FloatABI { Invalid, Hard, Soft, SoftFP }; FloatABI getDefaultFloatABI(const llvm::Triple &triple) { auto sub = getARMSubArchVersionNumber(triple); switch (triple.getOS()) { case llvm::Triple::Darwin: case llvm::Triple::MacOSX: case llvm::Triple::IOS: case llvm::Triple::TvOS: case llvm::Triple::DriverKit: case llvm::Triple::XROS: // Darwin defaults to "softfp" for v6 and v7. if (triple.isWatchABI()) return FloatABI::Hard; else return (sub == 6 || sub == 7) ? FloatABI::SoftFP : FloatABI::Soft; case llvm::Triple::WatchOS: return FloatABI::Hard; // FIXME: this is invalid for WindowsCE case llvm::Triple::Win32: // It is incorrect to select hard float ABI on MachO platforms if the ABI is // "apcs-gnu". if (triple.isOSBinFormatMachO() && !useAAPCSForMachO(triple)) return FloatABI::Soft; return FloatABI::Hard; case llvm::Triple::NetBSD: switch (triple.getEnvironment()) { case llvm::Triple::EABIHF: case llvm::Triple::GNUEABIHF: return FloatABI::Hard; default: return FloatABI::Soft; } break; case llvm::Triple::FreeBSD: switch (triple.getEnvironment()) { case llvm::Triple::GNUEABIHF: return FloatABI::Hard; default: // FreeBSD defaults to soft float return FloatABI::Soft; } break; case llvm::Triple::Haiku: case llvm::Triple::OpenBSD: return FloatABI::SoftFP; default: if (triple.isOHOSFamily()) return FloatABI::Soft; switch (triple.getEnvironment()) { case llvm::Triple::GNUEABIHF: case llvm::Triple::GNUEABIHFT64: case llvm::Triple::MuslEABIHF: case llvm::Triple::EABIHF: return FloatABI::Hard; case llvm::Triple::Android: case llvm::Triple::GNUEABI: case llvm::Triple::GNUEABIT64: case llvm::Triple::MuslEABI: case llvm::Triple::EABI: // EABI is always AAPCS, and if it was not marked 'hard', it's softfp return FloatABI::SoftFP; default: return FloatABI::Invalid; } } return FloatABI::Invalid; } FloatABI getARMFloatABI(const llvm::Triple &triple) { FloatABI 
abi = getDefaultFloatABI(triple); if (abi == FloatABI::Invalid) { // Assume "soft", but warn the user we are guessing. if (triple.isOSBinFormatMachO() && triple.getSubArch() == llvm::Triple::ARMSubArch_v7em) abi = FloatABI::Hard; else abi = FloatABI::Soft; } return abi; } llvm::ARM::ArchKind getLLVMArchKindForARM(llvm::StringRef cpu, llvm::StringRef arch, const llvm::Triple &triple) { return (arch == "armv7k" || arch == "thumbv7k") ? llvm::ARM::ArchKind::ARMV7K : llvm::ARM::parseCPUArch(cpu); } llvm::StringRef getLLVMArchSuffixForARM(llvm::StringRef cpu, llvm::StringRef arch, const llvm::Triple &triple) { llvm::ARM::ArchKind archKind = getLLVMArchKindForARM(cpu, arch, triple); if (archKind == llvm::ARM::ArchKind::INVALID) return ""; return llvm::ARM::getSubArch(archKind); } bool hasIntegerMVE(const std::vector &F) { auto MVE = llvm::find(llvm::reverse(F), "+mve"); auto NoMVE = llvm::find(llvm::reverse(F), "-mve"); return MVE != F.rend() && (NoMVE == F.rend() || std::distance(MVE, NoMVE) > 0); } } // namespace std::string ARM::getCPU(const llvm::Triple &triple) const { return llvm::sys::getHostCPUName().str(); } std::string ARM::getFeatures(const llvm::Triple &triple) const { std::vector features; auto abi = getARMFloatABI(triple); // uint64_t HWDivID = llvm::ARM::parseHWDiv(HWDiv); // Use software floating point operations? if (abi == FloatABI::Soft) features.push_back("+soft-float"); // Use software floating point argument passing? if (abi != FloatABI::Hard) features.push_back("+soft-float-abi"); std::vector tmp; // make sure we don't delete string data for (auto &f : llvm::sys::getHostCPUFeatures()) { tmp.push_back((f.second ? "+" : "-") + f.first().str()); features.push_back(tmp.back()); } llvm::ARM::FPUKind fpu = llvm::ARM::FK_INVALID; auto arch = triple.getArchName(); // Honor -mfpu=. ClangAs gives preference to -Wa,-mfpu=. 
if (triple.isAndroid() && getARMSubArchVersionNumber(triple) == 7) { const char *androidFPU = "neon"; fpu = llvm::ARM::parseFPU(androidFPU); } else { std::string cpu = getCPU(triple); llvm::ARM::ArchKind archkind = getLLVMArchKindForARM(cpu, arch, triple); fpu = llvm::ARM::getDefaultFPU(cpu, archkind); } (void)llvm::ARM::getFPUFeatures(fpu, features); // Handle (arch-dependent) fp16fml/fullfp16 relationship. // Must happen before any features are disabled due to soft-float. // FIXME: this fp16fml option handling will be reimplemented after the // TargetParser rewrite. const auto ItRNoFullFP16 = std::find(features.rbegin(), features.rend(), "-fullfp16"); const auto ItRFP16FML = std::find(features.rbegin(), features.rend(), "+fp16fml"); if (triple.getSubArch() == llvm::Triple::SubArchType::ARMSubArch_v8_4a) { const auto ItRFullFP16 = std::find(features.rbegin(), features.rend(), "+fullfp16"); if (ItRFullFP16 < ItRNoFullFP16 && ItRFullFP16 < ItRFP16FML) { // Only entangled feature that can be to the right of this +fullfp16 is -fp16fml. // Only append the +fp16fml if there is no -fp16fml after the +fullfp16. if (std::find(features.rbegin(), ItRFullFP16, "-fp16fml") == ItRFullFP16) features.push_back("+fp16fml"); } else goto fp16_fml_fallthrough; } else { fp16_fml_fallthrough: // In both of these cases, putting the 'other' feature on the end of the vector will // result in the same effect as placing it immediately after the current feature. if (ItRNoFullFP16 < ItRFP16FML) features.push_back("-fp16fml"); else if (ItRNoFullFP16 > ItRFP16FML) features.push_back("+fullfp16"); } // Setting -msoft-float/-mfloat-abi=soft, -mfpu=none, or adding +nofp to // -march/-mcpu effectively disables the FPU (GCC ignores the -mfpu options in // this case). Note that the ABI can also be set implicitly by the target // selected. 
bool HasFPRegs = true; if (abi == FloatABI::Soft) { llvm::ARM::getFPUFeatures(llvm::ARM::FK_NONE, features); // Disable all features relating to hardware FP, not already disabled by the // above call. features.insert(features.end(), {"-dotprod", "-fp16fml", "-bf16", "-mve", "-mve.fp"}); HasFPRegs = false; fpu = llvm::ARM::FK_NONE; } else if (fpu == llvm::ARM::FK_NONE) { // -mfpu=none, -march=armvX+nofp or -mcpu=X+nofp is *very* similar to // -mfloat-abi=soft, only that it should not disable MVE-I. They disable the // FPU, but not the FPU registers, thus MVE-I, which depends only on the // latter, is still supported. features.insert(features.end(), {"-dotprod", "-fp16fml", "-bf16", "-mve.fp"}); HasFPRegs = hasIntegerMVE(features); } if (!HasFPRegs) features.emplace_back("-fpregs"); // Invalid value of the __ARM_FEATURE_MVE macro when an explicit -mfpu= option // disables MVE-FP -mfpu=fpv5-d16 or -mfpu=fpv5-sp-d16 disables the scalar // half-precision floating-point operations feature. Therefore, because the // M-profile Vector Extension (MVE) floating-point feature requires the scalar // half-precision floating-point operations, this option also disables the MVE // floating-point feature: -mve.fp if (fpu == llvm::ARM::FK_FPV5_D16 || fpu == llvm::ARM::FK_FPV5_SP_D16) features.push_back("-mve.fp"); // For Arch >= ARMv8.0 && A or R profile: crypto = sha2 + aes // Rather than replace within the feature vector, determine whether each // algorithm is enabled and append this to the end of the vector. // The algorithms can be controlled by their specific feature or the crypto // feature, so their status can be determined by the last occurance of // either in the vector. This allows one to supercede the other. // e.g. 
+crypto+noaes in -march/-mcpu should enable sha2, but not aes // FIXME: this needs reimplementation after the TargetParser rewrite bool HasSHA2 = false; bool HasAES = false; const auto ItCrypto = llvm::find_if(llvm::reverse(features), [](const llvm::StringRef F) { return F.contains("crypto"); }); const auto ItSHA2 = llvm::find_if(llvm::reverse(features), [](const llvm::StringRef F) { return F.contains("crypto") || F.contains("sha2"); }); const auto ItAES = llvm::find_if(llvm::reverse(features), [](const llvm::StringRef F) { return F.contains("crypto") || F.contains("aes"); }); const bool FoundSHA2 = ItSHA2 != features.rend(); const bool FoundAES = ItAES != features.rend(); if (FoundSHA2) HasSHA2 = ItSHA2->take_front() == "+"; if (FoundAES) HasAES = ItAES->take_front() == "+"; if (ItCrypto != features.rend()) { if (HasSHA2 && HasAES) features.push_back("+crypto"); else features.push_back("-crypto"); if (HasSHA2) features.push_back("+sha2"); else features.push_back("-sha2"); if (HasAES) features.push_back("+aes"); else features.push_back("-aes"); } if (HasSHA2 || HasAES) { auto ArchSuffix = getLLVMArchSuffixForARM(getCPU(triple), arch, triple); llvm::ARM::ProfileKind ArchProfile = llvm::ARM::parseArchProfile(ArchSuffix); if (!((llvm::ARM::parseArchVersion(ArchSuffix) >= 8) && (ArchProfile == llvm::ARM::ProfileKind::A || ArchProfile == llvm::ARM::ProfileKind::R))) { features.push_back("-sha2"); features.push_back("-aes"); } } // Assume pre-ARMv6 doesn't support unaligned accesses. // // ARMv6 may or may not support unaligned accesses depending on the // SCTLR.U bit, which is architecture-specific. We assume ARMv6 // Darwin and NetBSD targets support unaligned accesses, and others don't. // // ARMv7 always has SCTLR.U set to 1, but it has a new SCTLR.A bit which // raises an alignment fault on unaligned accesses. Assume ARMv7+ supports // unaligned accesses, except ARMv6-M, and ARMv8-M without the Main // Extension. 
This aligns with the default behavior of ARM's downstream // versions of GCC and Clang. // // Users can change the default behavior via -m[no-]unaliged-access. int versionNum = getARMSubArchVersionNumber(triple); if (triple.isOSDarwin() || triple.isOSNetBSD()) { if (versionNum < 6 || triple.getSubArch() == llvm::Triple::SubArchType::ARMSubArch_v6m) features.push_back("+strict-align"); } else if (triple.getVendor() == llvm::Triple::Apple && triple.isOSBinFormatMachO()) { // Firmwares on Apple platforms are strict-align by default. features.push_back("+strict-align"); } else if (versionNum < 7 || triple.getSubArch() == llvm::Triple::SubArchType::ARMSubArch_v6m || triple.getSubArch() == llvm::Triple::SubArchType::ARMSubArch_v8m_baseline) { features.push_back("+strict-align"); } return join(features); } } // namespace ir } // namespace codon ================================================ FILE: codon/cir/llvm/native/targets/arm.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/llvm/native/targets/target.h" namespace codon { namespace ir { class ARM : public Target { public: std::string getCPU(const llvm::Triple &triple) const override; std::string getFeatures(const llvm::Triple &triple) const override; }; } // namespace ir } // namespace codon ================================================ FILE: codon/cir/llvm/native/targets/target.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once

// NOTE(review): the header name of the next include is missing in this copy
// of the file -- restore it from the upstream source.
#include

#include "codon/cir/llvm/llvm.h"

namespace codon {
namespace ir {

/// Interface implemented once per supported host architecture. Both queries
/// return strings and receive the target triple.
class Target {
public:
  virtual ~Target() {}
  /// CPU name to use for `triple`.
  virtual std::string getCPU(const llvm::Triple &triple) const = 0;
  /// Feature string to use for `triple`.
  virtual std::string getFeatures(const llvm::Triple &triple) const = 0;
};

} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/llvm/native/targets/x86.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "x86.h"

namespace codon {
namespace ir {
namespace {
/// Concatenates the elements of `v`, each converted to std::string, writing
/// `delim` before every element except the one stored at `v[0]`.
template std::string join(const T &v, const std::string &delim = ",") {
  std::ostringstream s;
  for (const auto &i : v) {
    // The first element is recognised by address, so no separator precedes it.
    if (&i != &v[0])
      s << delim;
    s << std::string(i);
  }
  return s.str();
}
} // namespace

/// Returns the host CPU name when LLVM reports one other than "generic";
/// otherwise falls back to a per-OS default derived from `triple`. Returns ""
/// in the fallback path when `triple` is not an x86 triple.
std::string X86::getCPU(const llvm::Triple &triple) const {
  auto CPU = llvm::sys::getHostCPUName();
  if (!CPU.empty() && CPU != "generic")
    return std::string(CPU);

  // Select the default CPU if none was given (or detection failed).

  if (!triple.isX86())
    return ""; // This routine is only handling x86 targets.

  bool is64Bit = triple.getArch() == llvm::Triple::x86_64;

  // FIXME: Need target hooks.
  if (triple.isOSDarwin()) {
    if (triple.getArchName() == "x86_64h")
      return "core-avx2";
    // macosx10.12 drops support for all pre-Penryn Macs.
    // Simulators can still run on 10.11 though, like Xcode.
    if (triple.isMacOSX() && !triple.isOSVersionLT(10, 12))
      return "penryn";

    if (triple.isDriverKit())
      return "nehalem";

    // The oldest x86_64 Macs have core2/Merom; the oldest x86 Macs have Yonah.
    return is64Bit ? "core2" : "yonah";
  }

  // Set up default CPU name for PS4/PS5 compilers.
  if (triple.isPS4())
    return "btver2";
  if (triple.isPS5())
    return "znver2";

  // On Android use targets compatible with gcc
  if (triple.isAndroid())
    return is64Bit ? "x86-64" : "i686";

  // Everything else goes to x86-64 in 64-bit mode.
  if (is64Bit)
    return "x86-64";

  // 32-bit defaults per operating system.
  switch (triple.getOS()) {
  case llvm::Triple::NetBSD:
    return "i486";
  case llvm::Triple::Haiku:
  case llvm::Triple::OpenBSD:
    return "i586";
  case llvm::Triple::FreeBSD:
    return "i686";
  default:
    // Fallback to p4.
    return "pentium4";
  }
}

/// Builds the comma-separated feature string: every host CPU feature prefixed
/// with '+' or '-', followed by the x86_64h opt-outs and the Android
/// additions where they apply.
std::string X86::getFeatures(const llvm::Triple &triple) const {
  std::vector features;
  for (auto &f : llvm::sys::getHostCPUFeatures()) {
    features.push_back((f.second ? "+" : "-") + f.first().str());
  }

  if (triple.getArchName() == "x86_64h") {
    // x86_64h implies quite a few of the more modern subtarget features
    // for Haswell class CPUs, but not all of them. Opt-out of a few.
    features.push_back("-rdrnd");
    features.push_back("-aes");
    features.push_back("-pclmul");
    features.push_back("-rtm");
    features.push_back("-fsgsbase");
  }

  const llvm::Triple::ArchType ArchType = triple.getArch();
  // Add features to be compatible with gcc for Android.
  if (triple.isAndroid()) {
    if (ArchType == llvm::Triple::x86_64) {
      features.push_back("+sse4.2");
      features.push_back("+popcnt");
      features.push_back("+cx16");
    } else
      features.push_back("+ssse3");
  }
  return join(features);
}

} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/llvm/native/targets/x86.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include "codon/cir/llvm/native/targets/target.h"

namespace codon {
namespace ir {

/// Target implementation for x86 and x86-64 hosts.
class X86 : public Target {
public:
  std::string getCPU(const llvm::Triple &triple) const override;
  std::string getFeatures(const llvm::Triple &triple) const override;
};

} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/llvm/optimize.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include "optimize.h"

#include

#include "codon/cir/llvm/gpu.h"
#include "codon/cir/llvm/llvisitor.h"
#include "codon/cir/llvm/native/native.h"
#include "codon/util/common.h"

static llvm::codegen::RegisterCodeGenFlags CFG;

namespace codon {
namespace ir {
namespace {
llvm::cl::opt
    AutoFree("auto-free",
             llvm::cl::desc("Insert free() calls on allocated memory automatically"),
             llvm::cl::init(false), llvm::cl::Hidden);

llvm::cl::opt FastMath("fast-math", llvm::cl::desc("Apply fastmath optimizations"),
                       llvm::cl::init(false));
} // namespace

/// Creates a target machine for `triple` at aggressive codegen optimization
/// level.
///
/// @param triple      target triple
/// @param cpuStr      CPU name handed to the target
/// @param featuresStr feature string handed to the target
/// @param options     target options
/// @param pic         when true the PIC relocation model is forced; otherwise
///                    the explicit relocation model from the codegen flags is
///                    used
/// @return the new machine, or a null pointer when target lookup fails
std::unique_ptr
getTargetMachine(llvm::Triple triple, llvm::StringRef cpuStr,
                 llvm::StringRef featuresStr, const llvm::TargetOptions &options,
                 bool pic) {
  std::string err;
  const llvm::Target *target =
      llvm::TargetRegistry::lookupTarget(llvm::codegen::getMArch(), triple, err);

  if (!target)
    return nullptr;

  return std::unique_ptr(target->createTargetMachine(
      triple.getTriple(), cpuStr, featuresStr, options,
      pic ? llvm::Reloc::Model::PIC_ : llvm::codegen::getExplicitRelocModel(),
      llvm::codegen::getExplicitCodeModel(), llvm::CodeGenOptLevel::Aggressive));
}

/// Creates a target machine for the triple recorded in `module`, using the
/// CPU and feature strings from the codegen flags.
///
/// @param module                module whose target triple is used
/// @param setFunctionAttributes when true, the CPU/feature strings are also
///                              applied to the module's functions
/// @param pic                   forwarded to the triple-based overload
/// @return the new machine, or an empty pointer when the module's triple has
///         no architecture
std::unique_ptr getTargetMachine(llvm::Module *module,
                                 bool setFunctionAttributes, bool pic) {
  llvm::Triple moduleTriple(module->getTargetTriple());
  std::string cpuStr, featuresStr;
  const llvm::TargetOptions options =
      llvm::codegen::InitTargetOptionsFromCodeGenFlags(moduleTriple);

  if (moduleTriple.getArch()) {
    cpuStr = llvm::codegen::getCPUStr();
    featuresStr = llvm::codegen::getFeaturesStr();
    // Pass `pic` through; previously it was accepted but never used, so the
    // caller's request had no effect on the relocation model.
    auto machine = getTargetMachine(moduleTriple, cpuStr, featuresStr, options, pic);
    if (setFunctionAttributes)
      llvm::codegen::setFunctionAttributes(cpuStr, featuresStr, *module);
    return machine;
  }
  return {};
}

namespace {
/// In debug mode, adjusts linkage and attributes of every function and clears
/// the tail-call flag on calls; otherwise strips debug info from the module.
///
/// @param module module to modify
/// @param debug  selects between the two behaviours above
/// @param jit    when true, linkage is left untouched
void applyDebugTransformations(llvm::Module *module, bool debug, bool jit) {
  if (debug) {
    auto ctor = LLVMVisitor::getGlobalCtorName();
    // remove tail calls and fix linkage for stack traces
    for (auto &f : *module) {
      // needed for debug symbols
      if (!jit && f.getName() != ctor)
        f.setLinkage(llvm::GlobalValue::ExternalLinkage);
      if (!f.hasFnAttribute(llvm::Attribute::AttrKind::AlwaysInline))
        f.addFnAttr(llvm::Attribute::AttrKind::NoInline);
      f.setUWTableKind(llvm::UWTableKind::Default);
      f.addFnAttr("no-frame-pointer-elim", "true");
      f.addFnAttr("no-frame-pointer-elim-non-leaf");
      f.addFnAttr("no-jump-tables", "false");

      for (auto &block : f) {
        for (auto &inst : block) {
          if (auto *call = llvm::dyn_cast(&inst)) {
            call->setTailCall(false);
          }
        }
      }
    }
  } else {
    llvm::StripDebugInfo(*module);
  }
}

/// When the `fast-math` option is set, marks floating-point-typed
/// instructions matched by the two casts below as fast. No-op otherwise.
void applyFastMathTransformations(llvm::Module *module) {
  if (!FastMath)
    return;

  for (auto &f : *module) {
    for (auto &block : f) {
      for (auto &inst : block) {
        if (auto *binop = llvm::dyn_cast(&inst)) {
          if (binop->getType()->isFloatingPointTy())
            binop->setFast(true);
        }

        if (auto *intrinsic = llvm::dyn_cast(&inst)) {
          if (intrinsic->getType()->isFloatingPointTy())
            intrinsic->setFast(true);
        }
      }
    }
  }
}

/// Names of the allocation, reallocation and free functions, plus the
/// use-walking predicates that decide whether an allocation call can be
/// removed, demoted or hoisted.
struct AllocInfo {
  std::vector allocators;
  std::string realloc;
  std::string free;

  AllocInfo(std::vector allocators, const std::string &realloc,
            const std::string &free)
      : allocators(std::move(allocators)), realloc(realloc), free(free) {}

  /// Reads call argument `idx` of `cb` as a constant integer.
  ///
  /// @param cb   call to inspect
  /// @param size receives the zero-extended constant on success
  /// @param idx  argument index
  /// @return false when the call has no argument at `idx` or the cast of that
  ///         argument below fails
  static bool getFixedArg(llvm::CallBase &cb, uint64_t &size, unsigned idx = 0) {
    // Reject any out-of-range index, not just an empty argument list.
    if (idx >= cb.arg_size())
      return false;

    if (auto *ci = llvm::dyn_cast(cb.getArgOperand(idx))) {
      size = ci->getZExtValue();
      return true;
    }

    return false;
  }

  /// True when `value` directly calls a one-parameter function whose name is
  /// in `allocators`.
  bool isAlloc(const llvm::Value *value) {
    if (auto *func = getCalledFunction(value)) {
      return func->arg_size() == 1 &&
             std::find(allocators.begin(), allocators.end(), func->getName()) !=
                 allocators.end();
    }
    return false;
  }

  /// True when `value` directly calls the three-parameter function named
  /// `realloc`.
  bool isRealloc(const llvm::Value *value) {
    if (auto *func = getCalledFunction(value)) {
      // Note: 3 args are (ptr, new_size, old_size)
      return func->arg_size() == 3 && func->getName() == realloc;
    }
    return false;
  }

  /// True when `value` directly calls the one-parameter function named `free`.
  bool isFree(const llvm::Value *value) {
    if (auto *func = getCalledFunction(value)) {
      return func->arg_size() == 1 && func->getName() == free;
    }
    return false;
  }

  /// Returns the callee of `value`, or nullptr when either check below
  /// rejects it or the call has no known callee.
  static const llvm::Function *getCalledFunction(const llvm::Value *value) {
    // Don't care about intrinsics in this case.
    if (llvm::isa(value))
      return nullptr;

    const auto *cb = llvm::dyn_cast(value);
    if (!cb)
      return nullptr;

    if (const llvm::Function *callee = cb->getCalledFunction())
      return callee;
    return nullptr;
  }

  /// Helper for the ICmp cases below: decides whether `value` can never
  /// compare equal to the unescaped allocation `ai`.
  bool isNeverEqualToUnescapedAlloc(llvm::Value *value, llvm::Instruction *ai) {
    using namespace llvm;

    if (isa(value))
      return true;
    if (auto *li = dyn_cast(value))
      return isa(li->getPointerOperand());
    // Two distinct allocations will never be equal.
    return isAlloc(value) && value != ai;
  }

  /// Walks all transitive uses of the allocation `ai` and reports whether
  /// every one is of a kind handled below, i.e. the allocation can be deleted
  /// outright.
  ///
  /// @param ai    allocation call
  /// @param users receives the users encountered; contents are meaningful
  ///              only when true is returned
  bool isAllocSiteRemovable(llvm::Instruction *ai,
                            llvm::SmallVectorImpl &users) {
    using namespace llvm;

    // Should never be an invoke, so just check right away.
    if (isa(ai))
      return false;

    SmallVector worklist;
    worklist.push_back(ai);

    do {
      Instruction *pi = worklist.pop_back_val();
      for (User *u : pi->users()) {
        Instruction *instr = cast(u);
        switch (instr->getOpcode()) {
        default:
          // Give up the moment we see something we can't handle.
          return false;

        case Instruction::AddrSpaceCast:
        case Instruction::BitCast:
        case Instruction::GetElementPtr:
          users.emplace_back(instr);
          worklist.push_back(instr);
          continue;

        case Instruction::ICmp: {
          ICmpInst *cmp = cast(instr);
          // We can fold eq/ne comparisons with null to false/true, respectively.
          // We also fold comparisons in some conditions provided the alloc has
          // not escaped (see isNeverEqualToUnescapedAlloc).
          if (!cmp->isEquality())
            return false;
          unsigned otherIndex = (cmp->getOperand(0) == pi) ? 1 : 0;
          if (!isNeverEqualToUnescapedAlloc(cmp->getOperand(otherIndex), ai))
            return false;
          users.emplace_back(instr);
          continue;
        }

        case Instruction::Call:
          // Ignore no-op and store intrinsics.
          if (IntrinsicInst *intrinsic = dyn_cast(instr)) {
            switch (intrinsic->getIntrinsicID()) {
            default:
              return false;

            case Intrinsic::memmove:
            case Intrinsic::memcpy:
            case Intrinsic::memset: {
              MemIntrinsic *MI = cast(intrinsic);
              if (MI->isVolatile() || MI->getRawDest() != pi)
                return false;
              LLVM_FALLTHROUGH;
            }
            case Intrinsic::assume:
            case Intrinsic::invariant_start:
            case Intrinsic::invariant_end:
            case Intrinsic::lifetime_start:
            case Intrinsic::lifetime_end:
              users.emplace_back(instr);
              continue;
            case Intrinsic::launder_invariant_group:
            case Intrinsic::strip_invariant_group:
              users.emplace_back(instr);
              worklist.push_back(instr);
              continue;
            }
          }

          if (isFree(instr)) {
            users.emplace_back(instr);
            continue;
          }

          if (isRealloc(instr)) {
            users.emplace_back(instr);
            worklist.push_back(instr);
            continue;
          }

          return false;

        case Instruction::Store: {
          StoreInst *si = cast(instr);
          if (si->isVolatile() || si->getPointerOperand() != pi)
            return false;
          users.emplace_back(instr);
          continue;
        }
        }
        seqassertn(false, "missing a return?");
      }
    } while (!worklist.empty());

    return true;
  }

  /// Reports whether the allocation `ai` has a constant size in (0, maxSize]
  /// and only uses of the kinds handled below, i.e. it can be replaced by a
  /// stack allocation.
  ///
  /// @param ai      allocation call
  /// @param size    receives the allocation size, raised to the largest
  ///                constant realloc size seen
  /// @param users   receives the call users encountered (intrinsics, free,
  ///                realloc)
  /// @param maxSize largest size accepted
  bool isAllocSiteDemotable(llvm::Instruction *ai, uint64_t &size,
                            llvm::SmallVectorImpl &users,
                            uint64_t maxSize = 1024) {
    using namespace llvm;

    // Should never be an invoke, so just check right away.
    if (isa(ai))
      return false;

    if (!(getFixedArg(*dyn_cast(&*ai), size) && 0 < size && size <= maxSize))
      return false;

    SmallVector worklist;
    worklist.push_back(ai);

    do {
      Instruction *pi = worklist.pop_back_val();
      for (User *u : pi->users()) {
        Instruction *instr = cast(u);
        switch (instr->getOpcode()) {
        default:
          // Give up the moment we see something we can't handle.
          return false;

        case Instruction::AddrSpaceCast:
        case Instruction::BitCast:
        case Instruction::GetElementPtr:
          worklist.push_back(instr);
          continue;

        case Instruction::ICmp: {
          ICmpInst *cmp = cast(instr);
          // We can fold eq/ne comparisons with null to false/true, respectively.
          // We also fold comparisons in some conditions provided the alloc has
          // not escaped (see isNeverEqualToUnescapedAlloc).
          if (!cmp->isEquality())
            return false;
          unsigned otherIndex = (cmp->getOperand(0) == pi) ? 1 : 0;
          if (!isNeverEqualToUnescapedAlloc(cmp->getOperand(otherIndex), ai))
            return false;
          continue;
        }

        case Instruction::Call:
          // Ignore no-op and store intrinsics.
          if (IntrinsicInst *intrinsic = dyn_cast(instr)) {
            switch (intrinsic->getIntrinsicID()) {
            default:
              return false;

            case Intrinsic::memmove:
            case Intrinsic::memcpy:
            case Intrinsic::memset: {
              MemIntrinsic *MI = cast(intrinsic);
              if (MI->isVolatile())
                return false;
              LLVM_FALLTHROUGH;
            }
            case Intrinsic::assume:
            case Intrinsic::invariant_start:
            case Intrinsic::invariant_end:
            case Intrinsic::lifetime_start:
            case Intrinsic::lifetime_end:
              users.emplace_back(instr);
              continue;
            case Intrinsic::launder_invariant_group:
            case Intrinsic::strip_invariant_group:
              users.emplace_back(instr);
              worklist.push_back(instr);
              continue;
            }
          }

          if (isFree(instr)) {
            users.emplace_back(instr);
            continue;
          }

          if (isRealloc(instr)) {
            // If the realloc also has constant small size,
            // then we can just update the assumed size to be
            // max of original alloc's and this realloc's.
            uint64_t newSize = 0;
            if (getFixedArg(*dyn_cast(instr), newSize, 1) &&
                (0 < newSize && newSize <= maxSize)) {
              size = std::max(size, newSize);
            } else {
              return false;
            }

            users.emplace_back(instr);
            worklist.push_back(instr);
            continue;
          }

          return false;

        case Instruction::Store: {
          StoreInst *si = cast(instr);
          if (si->isVolatile() || si->getPointerOperand() != pi)
            return false;
          continue;
        }

        case Instruction::Load: {
          LoadInst *li = cast(instr);
          if (li->isVolatile())
            return false;
          continue;
        }
        }
        seqassertn(false, "missing a return?");
      }
    } while (!worklist.empty());

    return true;
  }

  /// Reports whether the allocation `ai` can be moved out of `loop`: it must
  /// pass the preliminary checks below, and every transitive use must stay
  /// inside the loop and be of a kind handled by the switch.
  ///
  /// @param ai     allocation call inside `loop`
  /// @param loop   loop the allocation would be hoisted out of
  /// @param cycles cycle info used to reject irreducible control flow
  bool isAllocSiteHoistable(llvm::Instruction *ai, llvm::Loop &loop,
                            llvm::CycleInfo &cycles) {
    using namespace llvm;

    auto inIrreducibleCycle = [&](Instruction *ins) {
      auto *cycle = cycles.getCycle(ins->getParent());
      while (cycle) {
        if (!cycle->isReducible())
          return true;
        cycle = cycle->getParentCycle();
      }
      return false;
    };

    auto anySubLoopContains = [&](Instruction *ins) {
      for (auto *sub : loop.getSubLoops()) {
        if (sub->contains(ins))
          return true;
      }
      return false;
    };

    // Some preliminary checks
    if (isa(ai) || !loop.hasLoopInvariantOperands(ai) ||
        anySubLoopContains(ai) || inIrreducibleCycle(ai))
      return false;

    // Need to track insertvalue/extractvalue to make this effective.
    // This maps each "insertvalue" of the pointer (or derived value)
    // to a list of indices at which it is inserted (usually there will
    // be just one).
    SmallDenseMap, 1>> inserts;

    std::deque worklist;
    SmallSet visited;
    auto add_to_worklist = [&](Instruction *instr) {
      if (!visited.contains(instr)) {
        visited.insert(instr);
        worklist.push_front(instr);
      }
    };
    add_to_worklist(ai);

    do {
      Instruction *pi = worklist.back();
      worklist.pop_back();
      for (User *u : pi->users()) {
        Instruction *instr = cast(u);
        if (!loop.contains(instr))
          return false;

        switch (instr->getOpcode()) {
        default:
          // Give up the moment we see something we can't handle.
          return false;

        case Instruction::PHI:
          if (instr->getParent() == loop.getHeader())
            return false;
          LLVM_FALLTHROUGH;

        case Instruction::PtrToInt:
        case Instruction::IntToPtr:
        case Instruction::Add:
        case Instruction::Sub:
        case Instruction::AddrSpaceCast:
        case Instruction::BitCast:
        case Instruction::GetElementPtr:
          add_to_worklist(instr);
          continue;

        case Instruction::InsertValue: {
          auto *op0 = instr->getOperand(0);
          auto *op1 = instr->getOperand(1);
          if (isa(op0) || isa(op0) || isa(op0)) {
            // Add for this insertvalue
            if (op1 == pi) {
              auto *insertValueInst = cast(instr);
              inserts[instr].push_back(insertValueInst->getIndices());
            }
            // Add for previous insertvalue
            if (auto *instrOp = dyn_cast(op0)) {
              auto it = inserts.find(instrOp);
              if (it != inserts.end()) {
                // Copy first: inserts[instr] may add an entry and rehash the
                // map, which would invalidate `it`.
                auto previous = it->second;
                inserts[instr].append(previous);
              }
            }
          }
          add_to_worklist(instr);
          continue;
        }

        case Instruction::ExtractValue: {
          auto *extractValueInst = cast(instr);
          auto it = inserts.end();
          if (auto *instrOp = dyn_cast(instr->getOperand(0)))
            it = inserts.find(instrOp);

          if (it != inserts.end()) {
            // Only follow the extract when it reads one of the index paths
            // the pointer was inserted at.
            for (auto &indices : it->second) {
              if (indices == extractValueInst->getIndices()) {
                add_to_worklist(instr);
                break;
              }
            }
          } else {
            add_to_worklist(instr);
          }
          continue;
        }

        case Instruction::Freeze: {
          if (auto *instrOp = dyn_cast(instr->getOperand(0))) {
            auto it = inserts.find(instrOp);
            if (it != inserts.end()) {
              // Copy first: inserts[instr] may add an entry and rehash the
              // map, which would invalidate `it`.
              auto previous = it->second;
              inserts[instr] = std::move(previous);
            }
          }
          add_to_worklist(instr);
          continue;
        }

        case Instruction::ICmp:
          continue;

        case Instruction::Call:
        case Instruction::Invoke:
          // Ignore no-op and store intrinsics.
          if (IntrinsicInst *intrinsic = dyn_cast(instr)) {
            switch (intrinsic->getIntrinsicID()) {
            default:
              return false;

            case Intrinsic::memmove:
            case Intrinsic::memcpy:
            case Intrinsic::memset: {
              MemIntrinsic *MI = cast(intrinsic);
              if (MI->isVolatile())
                return false;
              LLVM_FALLTHROUGH;
            }
            case Intrinsic::assume:
            case Intrinsic::invariant_start:
            case Intrinsic::invariant_end:
            case Intrinsic::lifetime_start:
            case Intrinsic::lifetime_end:
              continue;
            case Intrinsic::launder_invariant_group:
            case Intrinsic::strip_invariant_group:
              add_to_worklist(instr);
              continue;
            }
          }
          return false;

        case Instruction::Store: {
          StoreInst *si = cast(instr);
          if (si->isVolatile() || si->getPointerOperand() != pi)
            return false;
          continue;
        }

        case Instruction::Load: {
          LoadInst *li = cast(instr);
          if (li->isVolatile())
            return false;
          continue;
        }
        }
        seqassertn(false, "missing a return?");
      }
    } while (!worklist.empty());

    return true;
  }
};

/// Lowers allocations of known, small size to alloca when possible.
/// Also removes unused allocations.
/// All changes are first collected for the whole function and only applied at
/// the end of `run`.
struct AllocationRemover : public llvm::PassInfoMixin {
  /// Classifier for the allocation / realloc / free functions this pass handles.
  AllocInfo info;

  /// @param allocators names of the allocation functions to consider
  /// @param realloc name of the reallocation function
  /// @param free name of the deallocation function
  explicit AllocationRemover(
      std::vector allocators = {"seq_alloc", "seq_alloc_atomic",
                                "seq_alloc_uncollectable",
                                "seq_alloc_atomic_uncollectable"},
      const std::string &realloc = "seq_realloc", const std::string &free = "seq_free")
      : info(std::move(allocators), realloc, free) {}

  /// Decides how the allocation `mi` is handled and records the required work
  /// in the output collections. Existing instructions are not modified here;
  /// `run` applies the recorded work afterwards.
  /// @param mi the allocation call being examined
  /// @param erase receives instructions that must be deleted
  /// @param replace receives (instruction, replacement value) pairs
  /// @param alloca receives newly created allocas that still need inserting
  /// @param untail receives calls whose tail-call marker must be cleared
  void getErasesAndReplacementsForAlloc(
      llvm::Instruction &mi, llvm::SmallPtrSetImpl &erase,
      llvm::SmallVectorImpl> &replace, llvm::SmallVectorImpl &alloca,
      llvm::SmallVectorImpl &untail) {
    using namespace llvm;

    uint64_t size = 0;
    SmallVector users;

    // Case 1: the allocation site is removable, so it and all collected users
    // are scheduled for deletion.
    if (info.isAllocSiteRemovable(&mi, users)) {
      for (unsigned i = 0, e = users.size(); i != e; ++i) {
        // Entries that evaluate to null are skipped.
        if (!users[i])
          continue;

        Instruction *instr = cast(&*users[i]);
        if (ICmpInst *cmp = dyn_cast(instr)) {
          // A comparison is replaced by an i1 constant whose value is
          // `cmp->isFalseWhenEqual()`.
          replace.emplace_back(cmp, ConstantInt::get(Type::getInt1Ty(cmp->getContext()),
                                                     cmp->isFalseWhenEqual()));
        } else if (!isa(instr)) {
          // Casts, GEP, or anything else: we're about to delete this instruction,
          // so it can not have any valid uses.
          replace.emplace_back(instr, PoisonValue::get(instr->getType()));
        }
        erase.insert(instr);
      }
      erase.insert(&mi);
      return;
    } else {
      // Not removable: empty `users` again before it is handed to the next query.
      users.clear();
    }

    // Case 2: the allocation site is demotable; `size` is filled in by the
    // query and used as the element count of an i8 alloca.
    if (info.isAllocSiteDemotable(&mi, size, users)) {
      // Created without an insertion point; `run` inserts it later.
      auto *replacement = new AllocaInst(
          Type::getInt8Ty(mi.getContext()), 0,
          ConstantInt::get(Type::getInt64Ty(mi.getContext()), size), Align());
      alloca.push_back(replacement);
      replace.emplace_back(&mi, replacement);
      erase.insert(&mi);

      for (unsigned i = 0, e = users.size(); i != e; ++i) {
        if (!users[i])
          continue;

        Instruction *instr = cast(&*users[i]);
        if (info.isFree(instr)) {
          // Frees of the demoted allocation are simply deleted.
          erase.insert(instr);
        } else if (info.isRealloc(instr)) {
          // Reallocs are replaced by the alloca itself and deleted.
          replace.emplace_back(instr, replacement);
          erase.insert(instr);
        } else if (auto *ci = dyn_cast(&*instr)) {
          // Remaining calls that are marked tail / musttail get the marker
          // cleared in `run`.
          if (ci->isTailCall() || ci->isMustTailCall())
            untail.push_back(ci);
        }
      }
    }
  }

  /// Collects the work for every allocation call in `func`, then applies it.
  /// @return `PreservedAnalyses::none()` if anything was recorded, otherwise
  ///         `PreservedAnalyses::all()`
  llvm::PreservedAnalyses run(llvm::Function &func, llvm::FunctionAnalysisManager &am) {
    using namespace llvm;

    SmallSet erase;
    SmallVector, 32> replace;
    SmallVector alloca;
    SmallVector untail;

    // Phase 1: gather; only instructions accepted by `info.isAlloc` are examined.
    for (inst_iterator instr = inst_begin(func), end = inst_end(func); instr != end;
         ++instr) {
      auto *cb = dyn_cast(&*instr);
      if (!cb || !info.isAlloc(cb))
        continue;

      getErasesAndReplacementsForAlloc(*cb, erase, replace, alloca, untail);
    }

    // Phase 2: apply. New allocas go in front of the entry block's first
    // non-PHI instruction.
    for (auto *A : alloca) {
      A->insertBefore(func.getEntryBlock().getFirstNonPHI());
    }

    for (auto *C : untail) {
      C->setTailCall(false);
    }

    for (auto &P : replace) {
      P.first->replaceAllUsesWith(P.second);
    }

    // References are dropped for every doomed instruction before any of them
    // is erased (presumably so the erase order does not matter).
    for (auto *I : erase) {
      I->dropAllReferences();
    }

    for (auto *I : erase) {
      I->eraseFromParent();
    }

    if (!erase.empty() || !replace.empty() || !alloca.empty() || !untail.empty())
      return PreservedAnalyses::none();
    else
      return PreservedAnalyses::all();
  }
};

/// Hoists allocations that are inside a loop out of the loop.
struct AllocationHoister : public llvm::PassInfoMixin { AllocInfo info; explicit AllocationHoister(std::vector allocators = {"seq_alloc", "seq_alloc_atomic"}, const std::string &realloc = "seq_realloc", const std::string &free = "seq_free") : info(std::move(allocators), realloc, free) {} bool processLoop(llvm::Loop &loop, llvm::LoopInfo &loops, llvm::CycleInfo &cycles, llvm::PostDominatorTree &postdom) { llvm::SmallSet hoist; for (auto *block : loop.blocks()) { for (auto &ins : *block) { if (info.isAlloc(&ins) && info.isAllocSiteHoistable(&ins, loop, cycles)) hoist.insert(llvm::cast(&ins)); } } if (hoist.empty()) return false; auto *preheader = loop.getLoopPreheader(); auto *terminator = preheader->getTerminator(); auto *parent = preheader->getParent(); auto *M = preheader->getModule(); auto &C = preheader->getContext(); llvm::IRBuilder<> B(C); auto *ptr = B.getPtrTy(); llvm::DomTreeUpdater dtu(postdom, llvm::DomTreeUpdater::UpdateStrategy::Lazy); for (auto *ins : hoist) { if (postdom.dominates(ins, preheader->getTerminator())) { // Simple case - loop must execute allocation, so // just hoist it directly. ins->removeFromParent(); ins->insertBefore(preheader->getTerminator()); } else { // Complex case - loop might not execute allocation, // so have to keep it where it is but cache it. 
// Transformation is as follows: // Before: // p = alloc(n) // After: // if cache is null: // p = alloc(n) // cache = p // else: // p = cache B.SetInsertPointPastAllocas(parent); auto *cache = B.CreateAlloca(ptr); cache->setName("alloc_hoist.cache"); B.CreateStore(llvm::ConstantPointerNull::get(ptr), cache); B.SetInsertPoint(ins); auto *cachedAlloc = B.CreateLoad(ptr, cache); // Split the block at the call site llvm::BasicBlock *allocYes = nullptr; llvm::BasicBlock *allocNo = nullptr; llvm::SplitBlockAndInsertIfThenElse(B.CreateIsNull(cachedAlloc), ins, &allocYes, &allocNo, /*UnreachableThen=*/false, /*UnreachableElse=*/false, /*BranchWeights=*/nullptr, &dtu, &loops); B.SetInsertPoint(&allocYes->getSingleSuccessor()->front()); llvm::PHINode *phi = B.CreatePHI(ptr, 2); ins->replaceAllUsesWith(phi); ins->removeFromParent(); ins->insertBefore(allocYes->getTerminator()); B.SetInsertPoint(allocYes->getTerminator()); B.CreateStore(ins, cache); phi->addIncoming(ins, allocYes); phi->addIncoming(cachedAlloc, allocNo); } } dtu.flush(); return true; } llvm::PreservedAnalyses run(llvm::Function &F, llvm::FunctionAnalysisManager &am) { auto &loops = am.getResult(F); auto &cycles = am.getResult(F); auto &postdom = am.getResult(F); bool changed = false; llvm::SmallPriorityWorklist worklist; llvm::appendLoopsToWorklist(loops, worklist); while (!worklist.empty()) changed |= processLoop(*worklist.pop_back_val(), loops, cycles, postdom); return changed ? 
llvm::PreservedAnalyses::none() : llvm::PreservedAnalyses::all(); } }; struct AllocationAutoFree : public llvm::PassInfoMixin { AllocInfo info; explicit AllocationAutoFree( std::vector allocators = {"seq_alloc", "seq_alloc_atomic", "seq_alloc_uncollectable", "seq_alloc_atomic_uncollectable"}, const std::string &realloc = "seq_realloc", const std::string &free = "seq_free") : info(std::move(allocators), realloc, free) {} llvm::PreservedAnalyses run(llvm::Function &F, llvm::FunctionAnalysisManager &FAM) { // Get the necessary analysis results. auto &MSSA = FAM.getResult(F); auto &TLI = FAM.getResult(F); auto &AA = FAM.getResult(F); auto &DT = FAM.getResult(F); auto &PDT = FAM.getResult(F); auto &LI = FAM.getResult(F); auto &CI = FAM.getResult(F); bool Changed = false; // Traverse the function to find allocs and insert corresponding frees. for (auto &BB : F) { for (auto &I : BB) { if (auto *Alloc = llvm::dyn_cast(&I)) { auto *Callee = Alloc->getCalledFunction(); if (!Callee || !Callee->isDeclaration()) continue; if (info.isAlloc(Alloc)) { if (llvm::PointerMayBeCaptured(Alloc, /*ReturnCaptures=*/true, /*StoreCaptures=*/true)) continue; Changed |= insertFree(Alloc, F, DT, PDT, LI, CI); } } } } return (Changed ? llvm::PreservedAnalyses::none() : llvm::PreservedAnalyses::all()); } bool insertFree(llvm::Instruction *Alloc, llvm::Function &F, llvm::DominatorTree &DT, llvm::PostDominatorTree &PDT, llvm::LoopInfo &LI, llvm::CycleInfo &CI) { llvm::SmallVector Worklist; llvm::SmallPtrSet Visited; llvm::SmallVector UseBlocks; // We need to find a basic block that: // 1. Post-dominates the allocation block (so we always free it) // 2. Is dominated by the allocation block (so the use is valid) // 3. Post-dominates all uses // Start with the original pointer. Worklist.push_back(Alloc); UseBlocks.push_back(Alloc->getParent()); // Track all blocks where the pointer or its derived values are used. 
while (!Worklist.empty()) { auto *CurrentPtr = Worklist.pop_back_val(); if (!Visited.insert(CurrentPtr).second) continue; // Traverse all users of the current pointer. for (auto *U : CurrentPtr->users()) { if (auto *Inst = llvm::dyn_cast(U)) { if (auto *call = llvm::dyn_cast(Inst)) if (call->getCalledFunction() && info.isFree(call->getCalledFunction())) return false; if (llvm::isa(Inst) || llvm::isa(Inst) || llvm::isa(Inst) || llvm::isa(Inst)) { Worklist.push_back(Inst); } else { // If this is a real use, record the block. UseBlocks.push_back(Inst->getParent()); } } } } // Find the closest post-dominating block of all the use blocks. llvm::BasicBlock *PostDomBlock = nullptr; for (auto *BB : UseBlocks) { if (!PostDomBlock) { PostDomBlock = BB; } else { PostDomBlock = PDT.findNearestCommonDominator(PostDomBlock, BB); if (!PostDomBlock) { return false; } } } auto *allocLoop = LI.getLoopFor(Alloc->getParent()); auto *freeLoop = LI.getLoopFor(PostDomBlock); while (allocLoop != freeLoop) { if (!freeLoop) return false; PostDomBlock = freeLoop->getExitBlock(); if (!PostDomBlock) return false; freeLoop = LI.getLoopFor(PostDomBlock); } if (!DT.dominates(Alloc->getParent(), PostDomBlock)) { return false; } llvm::IRBuilder<> B(PostDomBlock->getTerminator()); auto *FreeFunc = F.getParent()->getFunction(info.free); if (!FreeFunc) { FreeFunc = llvm::Function::Create( llvm::FunctionType::get(B.getVoidTy(), {B.getPtrTy()}, false), llvm::Function::ExternalLinkage, info.free, F.getParent()); FreeFunc->setWillReturn(); FreeFunc->setDoesNotThrow(); } // Add free B.CreateCall(FreeFunc, Alloc); return true; } }; /// Sometimes coroutine lowering produces hard-to-analyze loops involving /// function pointer comparisons. This pass puts them into a somewhat /// easier-to-analyze form. 
struct CoroBranchSimplifier : public llvm::PassInfoMixin { static llvm::Value *getNonNullOperand(llvm::Value *op1, llvm::Value *op2) { auto *ptr = llvm::dyn_cast(op1->getType()); if (!ptr) return nullptr; auto *c1 = llvm::dyn_cast(op1); auto *c2 = llvm::dyn_cast(op2); const bool isNull1 = (c1 && c1->isNullValue()); const bool isNull2 = (c2 && c2->isNullValue()); if (!(isNull1 ^ isNull2)) return nullptr; return isNull1 ? op2 : op1; } llvm::PreservedAnalyses run(llvm::Loop &loop, llvm::LoopAnalysisManager &am, llvm::LoopStandardAnalysisResults &ar, llvm::LPMUpdater &u) { if (auto *exit = loop.getExitingBlock()) { if (auto *br = llvm::dyn_cast(exit->getTerminator())) { if (!br->isConditional() || br->getNumSuccessors() != 2 || loop.contains(br->getSuccessor(0)) || !loop.contains(br->getSuccessor(1))) return llvm::PreservedAnalyses::all(); auto *cond = br->getCondition(); if (auto *cmp = llvm::dyn_cast(cond)) { if (cmp->getPredicate() != llvm::CmpInst::Predicate::ICMP_EQ) return llvm::PreservedAnalyses::all(); if (auto *f = getNonNullOperand(cmp->getOperand(0), cmp->getOperand(1))) { if (auto *sel = llvm::dyn_cast(f)) { if (auto *g = getNonNullOperand(sel->getTrueValue(), sel->getFalseValue())) { // If we can deduce that g is not null, we can replace the condition. if (auto *phi = llvm::dyn_cast(g)) { bool ok = true; for (unsigned i = 0; i < phi->getNumIncomingValues(); i++) { auto *phiBlock = phi->getIncomingBlock(i); auto *phiValue = phi->getIncomingValue(i); if (auto *c = llvm::dyn_cast(phiValue)) { if (c->isNullValue()) { ok = false; break; } } else { // There is no way for the value to be null if the incoming phi // value is predicated on this exit condition, which checks for a // non-null function pointer. 
if (phiBlock != exit || phiValue != f) { ok = false; break; } } } if (!ok) return llvm::PreservedAnalyses::all(); br->setCondition(sel->getCondition()); return llvm::PreservedAnalyses::none(); } } } } } } } return llvm::PreservedAnalyses::all(); } }; llvm::cl::opt DisableNative("disable-native", llvm::cl::desc("Disable architecture-specific optimizations"), llvm::cl::init(false)); void runLLVMOptimizationPasses(llvm::Module *module, bool debug, bool jit, PluginManager *plugins) { applyDebugTransformations(module, debug, jit); applyFastMathTransformations(module); llvm::LoopAnalysisManager lam; llvm::FunctionAnalysisManager fam; llvm::CGSCCAnalysisManager cgam; llvm::ModuleAnalysisManager mam; auto machine = getTargetMachine(module, /*setFunctionAttributes=*/true); llvm::PassBuilder pb(machine.get()); llvm::Triple moduleTriple(module->getTargetTriple()); llvm::TargetLibraryInfoImpl tlii(moduleTriple); fam.registerPass([&] { return llvm::TargetLibraryAnalysis(tlii); }); pb.registerModuleAnalyses(mam); pb.registerCGSCCAnalyses(cgam); pb.registerFunctionAnalyses(fam); pb.registerLoopAnalyses(lam); pb.crossRegisterProxies(lam, fam, cgam, mam); pb.registerLateLoopOptimizationsEPCallback( [&](llvm::LoopPassManager &pm, llvm::OptimizationLevel opt) { if (opt.isOptimizingForSpeed()) pm.addPass(CoroBranchSimplifier()); }); pb.registerPeepholeEPCallback( [&](llvm::FunctionPassManager &pm, llvm::OptimizationLevel opt) { if (opt.isOptimizingForSpeed()) { pm.addPass(AllocationRemover()); pm.addPass(llvm::LoopSimplifyPass()); pm.addPass(llvm::LCSSAPass()); pm.addPass(AllocationHoister()); if (AutoFree) pm.addPass(AllocationAutoFree()); } }); if (!DisableNative) addNativeLLVMPasses(&pb); if (plugins) { for (auto *plugin : *plugins) { plugin->dsl->addLLVMPasses(&pb, debug); } } if (debug) { llvm::ModulePassManager mpm = pb.buildO0DefaultPipeline(llvm::OptimizationLevel::O0); mpm.run(*module, mam); } else { llvm::ModulePassManager mpm = 
pb.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3); mpm.run(*module, mam); } applyDebugTransformations(module, debug, jit); } void verify(llvm::Module *module) { const bool broken = llvm::verifyModule(*module, &llvm::errs()); if (broken) { auto fo = fopen("_dump.ll", "w"); llvm::raw_fd_ostream fout(fileno(fo), true); fout << *module; fout.close(); } seqassertn(!broken, "Generated LLVM IR is invalid and has been dumped to '_dump.ll'. " "Please submit a bug report at https://github.com/exaloop/codon " "including the code and generated LLVM IR."); } } // namespace void optimize(llvm::Module *module, bool debug, bool jit, PluginManager *plugins) { verify(module); { TIME("llvm/opt1"); runLLVMOptimizationPasses(module, debug, jit, plugins); } if (!debug) { TIME("llvm/opt2"); runLLVMOptimizationPasses(module, debug, jit, plugins); } { TIME("llvm/gpu"); applyGPUTransformations(module); } verify(module); } bool isFastMathOn() { return FastMath; } } // namespace ir } // namespace codon ================================================ FILE: codon/cir/llvm/optimize.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include "codon/cir/llvm/llvm.h" #include "codon/dsl/plugins.h" namespace codon { namespace ir { std::unique_ptr getTargetMachine(llvm::Triple triple, llvm::StringRef cpuStr, llvm::StringRef featuresStr, const llvm::TargetOptions &options, bool pic = false); std::unique_ptr getTargetMachine(llvm::Module *module, bool setFunctionAttributes = false, bool pic = false); void optimize(llvm::Module *module, bool debug, bool jit = false, PluginManager *plugins = nullptr); bool isFastMathOn(); } // namespace ir } // namespace codon ================================================ FILE: codon/cir/module.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "module.h"

#include
#include

#include "codon/cir/func.h"
#include "codon/parser/cache.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

namespace codon {
namespace ir {
namespace {
/// Converts IR generics into their type-checker counterparts.
/// Static string generics, other static generics and type generics each take
/// their own branch; every generic must be static or carry a type value.
std::vector translateGenerics(codon::ast::Cache *cache, std::vector &generics) {
  std::vector ret;
  for (auto &g : generics) {
    seqassertn(g.isStatic() || g.getTypeValue(), "generic must be static or a type");
    if (g.isStaticStr())
      ret.push_back(std::make_shared(
          std::make_shared(
              cache, g.getStaticStringValue())));
    else if (g.isStatic())
      ret.push_back(std::make_shared(
          std::make_shared(cache, g.getStaticValue())));
    else
      ret.push_back(std::make_shared(
          g.getTypeValue()->getAstType()));
  }
  return ret;
}

/// Collects the raw AST type pointer of every IR type in `types`.
/// Each type must have an AST type.
std::vector generateDummyNames(std::vector &types) {
  std::vector ret;
  for (auto *t : types) {
    seqassertn(t->getAstType(), "{} must have an ast type", *t);
    ret.emplace_back(t->getAstType().get()));
  }
  return ret;
}

/// Builds the argument list handed to the type checker: slot 0 is a fresh
/// unbound link type, followed by the AST type of every IR type in `types`.
/// Each type must have an AST type.
std::vector translateArgs(codon::ast::Cache *cache, std::vector &types) {
  std::vector ret = {
      std::make_shared(
          cache, codon::ast::types::LinkType::Kind::Unbound, 0)};
  for (auto *t : types) {
    seqassertn(t->getAstType(), "{} must have an ast type", *t);
    // Function and non-function types are passed through identically.
    ret.push_back(t->getAstType());
  }
  return ret;
}
} // namespace

const std::string Module::VOID_NAME = "void";
const std::string Module::BOOL_NAME = "bool";
const std::string Module::BYTE_NAME = "byte";
const std::string Module::INT_NAME = "int";
const std::string Module::FLOAT_NAME = "float";
const std::string Module::FLOAT32_NAME = "float32";
const std::string Module::FLOAT16_NAME = "float16";
const std::string Module::BFLOAT16_NAME = "bfloat16";
const std::string Module::FLOAT128_NAME = "float128";
const std::string Module::STRING_NAME = "str";

const std::string Module::EQ_MAGIC_NAME = "__eq__";
const std::string Module::NE_MAGIC_NAME = "__ne__";
const std::string Module::LT_MAGIC_NAME = "__lt__";
const std::string Module::GT_MAGIC_NAME = "__gt__";
const std::string Module::LE_MAGIC_NAME = "__le__";
const std::string Module::GE_MAGIC_NAME = "__ge__";

const std::string Module::POS_MAGIC_NAME = "__pos__";
const std::string Module::NEG_MAGIC_NAME = "__neg__";
const std::string Module::INVERT_MAGIC_NAME = "__invert__";
const std::string Module::ABS_MAGIC_NAME = "__abs__";

const std::string Module::ADD_MAGIC_NAME = "__add__";
const std::string Module::SUB_MAGIC_NAME = "__sub__";
const std::string Module::MUL_MAGIC_NAME = "__mul__";
const std::string Module::MATMUL_MAGIC_NAME = "__matmul__";
const std::string Module::TRUE_DIV_MAGIC_NAME = "__truediv__";
const std::string Module::FLOOR_DIV_MAGIC_NAME = "__floordiv__";
const std::string Module::MOD_MAGIC_NAME = "__mod__";
const std::string Module::POW_MAGIC_NAME = "__pow__";
const std::string Module::LSHIFT_MAGIC_NAME = "__lshift__";
const std::string Module::RSHIFT_MAGIC_NAME = "__rshift__";
const std::string Module::AND_MAGIC_NAME = "__and__";
const std::string Module::OR_MAGIC_NAME = "__or__";
const std::string Module::XOR_MAGIC_NAME = "__xor__";

const std::string Module::IADD_MAGIC_NAME = "__iadd__";
const std::string Module::ISUB_MAGIC_NAME = "__isub__";
const std::string Module::IMUL_MAGIC_NAME = "__imul__";
const std::string Module::IMATMUL_MAGIC_NAME = "__imatmul__";
const std::string Module::ITRUE_DIV_MAGIC_NAME = "__itruediv__";
const std::string Module::IFLOOR_DIV_MAGIC_NAME = "__ifloordiv__";
const std::string Module::IMOD_MAGIC_NAME = "__imod__";
const std::string Module::IPOW_MAGIC_NAME = "__ipow__";
const std::string Module::ILSHIFT_MAGIC_NAME = "__ilshift__";
const std::string Module::IRSHIFT_MAGIC_NAME = "__irshift__";
const std::string Module::IAND_MAGIC_NAME = "__iand__";
const std::string Module::IOR_MAGIC_NAME = "__ior__";
const std::string Module::IXOR_MAGIC_NAME = "__ixor__";

const std::string Module::RADD_MAGIC_NAME = "__radd__";
const std::string Module::RSUB_MAGIC_NAME = "__rsub__";
const std::string Module::RMUL_MAGIC_NAME = "__rmul__";
const std::string Module::RMATMUL_MAGIC_NAME = "__rmatmul__";
const std::string Module::RTRUE_DIV_MAGIC_NAME = "__rtruediv__";
const std::string Module::RFLOOR_DIV_MAGIC_NAME = "__rfloordiv__";
const std::string Module::RMOD_MAGIC_NAME = "__rmod__";
const std::string Module::RPOW_MAGIC_NAME = "__rpow__";
const std::string Module::RLSHIFT_MAGIC_NAME = "__rlshift__";
const std::string Module::RRSHIFT_MAGIC_NAME = "__rrshift__";
const std::string Module::RAND_MAGIC_NAME = "__rand__";
const std::string Module::ROR_MAGIC_NAME = "__ror__";
const std::string Module::RXOR_MAGIC_NAME = "__rxor__";

const std::string Module::INT_MAGIC_NAME = "__int__";
const std::string Module::FLOAT_MAGIC_NAME = "__float__";
const std::string Module::BOOL_MAGIC_NAME = "__bool__";
const std::string Module::STR_MAGIC_NAME = "__str__";
const std::string Module::REPR_MAGIC_NAME = "__repr__";
const std::string Module::CALL_MAGIC_NAME = "__call__";

const std::string Module::GETITEM_MAGIC_NAME = "__getitem__";
const std::string Module::SETITEM_MAGIC_NAME = "__setitem__";
const std::string Module::ITER_MAGIC_NAME = "__iter__";
const std::string Module::LEN_MAGIC_NAME = "__len__";

const std::string Module::NEW_MAGIC_NAME = "__new__";
const std::string Module::INIT_MAGIC_NAME = "__init__";

const char Module::NodeId = 0;

/// Creates the module together with its "main" function and its ".argv"
/// global, both owned by the module and marked non-replaceable.
Module::Module(const std::string &name) : AcceptorExtend(name) {
  mainFunc = std::make_unique("main");
  mainFunc->realize(cast(unsafeGetDummyFuncType()), {});
  mainFunc->setModule(this);
  mainFunc->setReplaceable(false);
  argVar = std::make_unique(unsafeGetArrayType(getStringType()), /*global=*/true,
                            /*external=*/false, /*tls=*/false, ".argv");
  argVar->setModule(this);
  argVar->setReplaceable(false);
}

/// Forwards `code` to the type-checker cache for parsing.
void Module::parseCode(const std::string &code) { cache->parseCode(code); }

/// Looks up `methodName` on the class of `parent` and realizes it.
/// @return the realized function, or nullptr if the method is not found or
///         realization raises a parser exception (which is logged)
Func *Module::getOrRealizeMethod(types::Type *parent, const std::string &methodName,
                                 std::vector args, std::vector generics) {
  auto cls = std::const_pointer_cast(parent->getAstType())->getClass();
  auto method = cache->findMethod(cls, methodName, generateDummyNames(args));
  if (!method)
    return nullptr;
  try {
    return cache->realizeFunction(method, translateArgs(cache, args),
                                  translateGenerics(cache, generics), cls);
  } catch (const exc::ParserException &e) {
    for (auto &trace : e.getErrors())
      for (auto &msg : trace)
        LOG_IR("getOrRealizeMethod parser error at {}: {}", msg.getSrcInfo(),
               msg.getMessage());
    return nullptr;
  }
}

/// Looks up the function `funcName` (qualified with `module` when that is
/// non-empty), retrying with a ".0" suffix, and realizes it.
/// @return the realized function, or nullptr if the function is not found or
///         realization raises a parser exception (which is logged)
Func *Module::getOrRealizeFunc(const std::string &funcName, std::vector args,
                               std::vector generics, const std::string &module) {
  auto fqName =
      module.empty() ? funcName : fmt::format(FMT_STRING("{}.{}"), module, funcName);
  auto func = cache->findFunction(fqName);
  if (!func)
    func = cache->findFunction(fqName + ".0");
  if (!func)
    return nullptr;
  auto arg = translateArgs(cache, args);
  auto gens = translateGenerics(cache, generics);
  try {
    return cache->realizeFunction(func, arg, gens);
  } catch (const exc::ParserException &e) {
    // Logged through LOG_IR like the sibling getOrRealize* functions.
    for (auto &trace : e.getErrors())
      for (auto &msg : trace)
        LOG_IR("getOrRealizeFunc parser error at {}: {}", msg.getSrcInfo(),
               msg.getMessage());
    return nullptr;
  }
}

/// Looks up the class `typeName` and realizes it with the given generics.
/// @return the realized type, or nullptr if the class is not found or
///         realization raises a parser exception (which is logged)
types::Type *Module::getOrRealizeType(const std::string &typeName,
                                      std::vector generics) {
  auto type = cache->findClass(typeName);
  if (!type)
    return nullptr;
  try {
    return cache->realizeType(type, translateGenerics(cache, generics));
  } catch (const exc::ParserException &e) {
    for (auto &trace : e.getErrors())
      for (auto &msg : trace)
        LOG_IR("getOrRealizeType parser error at {}: {}", msg.getSrcInfo(),
               msg.getMessage());
    return nullptr;
  }
}

// The getters below return the type already known under the given name and
// otherwise obtain a new one through Nr.

types::Type *Module::getVoidType() {
  if (auto *rVal = getType(VOID_NAME))
    return rVal;
  return Nr();
}

types::Type *Module::getBoolType() {
  if (auto *rVal = getType(BOOL_NAME))
    return rVal;
  return Nr();
}

types::Type *Module::getByteType() {
  if (auto *rVal = getType(BYTE_NAME))
    return rVal;
  return Nr();
}

types::Type *Module::getIntType() {
  if (auto *rVal = getType(INT_NAME))
    return rVal;
  return Nr();
}

types::Type *Module::getFloatType() {
  if (auto *rVal = getType(FLOAT_NAME))
    return rVal;
  return Nr();
}

types::Type *Module::getFloat32Type() {
  if (auto *rVal = getType(FLOAT32_NAME))
    return rVal;
  return Nr();
}

types::Type *Module::getFloat16Type() {
  if (auto *rVal = getType(FLOAT16_NAME))
    return rVal;
  return Nr();
}

types::Type *Module::getBFloat16Type() {
  if (auto *rVal = getType(BFLOAT16_NAME))
    return rVal;
  return Nr();
}

types::Type *Module::getFloat128Type() {
  if (auto *rVal = getType(FLOAT128_NAME))
    return rVal;
  return Nr();
}

/// The string type has two members: "len" (int) and "ptr" (pointer to byte).
types::Type *Module::getStringType() {
  if (auto *rVal = getType(STRING_NAME))
    return rVal;
  return Nr(
      STRING_NAME,
      std::vector{getIntType(), unsafeGetPointerType(getByteType())},
      std::vector{"len", "ptr"});
}

// The getters below go through the type checker via getOrRealizeType.

types::Type *Module::getPointerType(types::Type *base) {
  return getOrRealizeType("Ptr", {base});
}

types::Type *Module::getArrayType(types::Type *base) {
  return getOrRealizeType("Array", {base});
}

types::Type *Module::getGeneratorType(types::Type *base) {
  return getOrRealizeType("Generator", {base});
}

types::Type *Module::getOptionalType(types::Type *base) {
  return getOrRealizeType("Optional", {base});
}

/// Builds a function type from a return type and argument types.
/// @param rType the return type; its AST type fills slot 0 of the translated
///              argument list
/// @param argTypes the argument types
/// @param variadic true if the function is variadic
types::Type *Module::getFuncType(types::Type *rType, std::vector argTypes,
                                 bool variadic) {
  auto args = translateArgs(cache, argTypes);
  args[0] = std::make_shared(rType->getAstType());
  auto *result = cache->makeFunction(args);
  if (variadic) {
    // Type checker types have no concept of variadic functions, so we will
    // create a new IR type here with the same AST type.
    auto *f = cast(result);
    result = unsafeGetFuncType(f->getName() + "$variadic", f->getReturnType(),
                               std::vector(f->begin(), f->end()),
                               /*variadic=*/true);
    result->setAstType(f->getAstType());
  }
  return result;
}

/// @param len the bit width
/// @param sign true selects "Int", false selects "UInt"
types::Type *Module::getIntNType(unsigned int len, bool sign) {
  return getOrRealizeType(sign ? "Int" : "UInt", {len});
}

types::Type *Module::getVectorType(unsigned count, types::Type *base) {
  return getOrRealizeType(ast::getMangledClass("std.simd", "Vec"), {base, count});
}

/// Builds a tuple type from `args`; every element must have an AST type.
types::Type *Module::getTupleType(std::vector args) {
  std::vector argTypes;
  for (auto *t : args) {
    seqassertn(t->getAstType(), "{} must have an ast type", *t);
    argTypes.push_back(t->getAstType());
  }
  return cache->makeTuple(argTypes);
}

/// Builds a union type from `types`; every element must have an AST type.
types::Type *Module::getUnionType(std::vector types) {
  std::vector argTypes;
  for (auto *t : types) {
    seqassertn(t->getAstType(), "{} must have an ast type", *t);
    argTypes.push_back(t->getAstType());
  }
  return cache->makeUnion(argTypes);
}

types::Type *Module::getNoneType() { return getOrRealizeType("NoneType"); }

// Constant helpers: each pairs the value with the matching module type.

Value *Module::getInt(int64_t v) { return Nr(v, getIntType()); }

Value *Module::getFloat(double v) { return Nr(v, getFloatType()); }

Value *Module::getBool(bool v) { return Nr(v, getBoolType()); }

Value *Module::getString(std::string v) {
  return Nr(std::move(v), getStringType());
}

// The unsafeGet* functions below do not consult the type-checker cache: they
// look the type up by its instance name and otherwise create it through Nr.

/// @return a function type with an empty name, void return and no arguments
types::Type *Module::unsafeGetDummyFuncType() {
  return unsafeGetFuncType("", getVoidType(), {});
}

types::Type *Module::unsafeGetPointerType(types::Type *base) {
  auto name = types::PointerType::getInstanceName(base);
  if (auto *rVal = getType(name))
    return rVal;
  return Nr(base);
}

/// The array type has two members: "len" (int) and "ptr" (pointer to `base`).
types::Type *Module::unsafeGetArrayType(types::Type *base) {
  auto name = fmt::format(FMT_STRING(".Array[{}]"), base->referenceString());
  if (auto *rVal = getType(name))
    return rVal;
  std::vector members = {getIntType(), unsafeGetPointerType(base)};
  std::vector names = {"len", "ptr"};
  return Nr(name, members, names);
}

types::Type *Module::unsafeGetGeneratorType(types::Type *base) {
  auto name = types::GeneratorType::getInstanceName(base);
  if (auto *rVal = getType(name))
    return rVal;
  return Nr(base);
}

types::Type *Module::unsafeGetOptionalType(types::Type *base) {
  auto name = types::OptionalType::getInstanceName(base);
  if (auto *rVal = getType(name))
    return rVal;
  return Nr(base);
}

types::Type *Module::unsafeGetFuncType(const std::string &name, types::Type *rType,
                                       std::vector argTypes, bool variadic) {
  if (auto *rVal = getType(name))
    return rVal;
  return Nr(name, rType, std::move(argTypes), variadic);
}

/// @param name the type name
/// @param ref if true, the new type wraps a record named `name + ".contents"`
///            (reusing that record if it already exists)
types::Type *Module::unsafeGetMemberedType(const std::string &name, bool ref) {
  auto *rVal = getType(name);

  if (!rVal) {
    if (ref) {
      auto contentName = name + ".contents";
      auto *record = getType(contentName);
      if (!record) {
        record = Nr(contentName);
      }
      rVal = Nr(name, cast(record));
    } else {
      rVal = Nr(name);
    }
  }

  return rVal;
}

types::Type *Module::unsafeGetIntNType(unsigned int len, bool sign) {
  auto name = types::IntNType::getInstanceName(len, sign);
  if (auto *rVal = getType(name))
    return rVal;
  return Nr(len, sign);
}

/// @param count the number of vector elements
/// @param base the element type; must pass the `cast` below
types::Type *Module::unsafeGetVectorType(unsigned int count, types::Type *base) {
  auto *primitive = cast(base);
  // Validate before `primitive` is handed to getInstanceName.
  seqassertn(primitive, "base type must be a primitive type");
  auto name = types::VectorType::getInstanceName(count, primitive);
  if (auto *rVal = getType(name))
    return rVal;
  return Nr(count, primitive);
}

types::Type *Module::unsafeGetUnionType(const std::vector &types) {
  auto name = types::UnionType::getInstanceName(types);
  if (auto *rVal = getType(name))
    return rVal;
  return Nr(types);
}

/// Opens a new, empty arena on top of the arena stack.
void Module::pushArena() { arenas.emplace_back(); }

/// Removes every value, var and type recorded in the topmost arena from the
/// module and pops that arena. Entries that are no longer present in the
/// corresponding map are skipped.
void Module::popArena() {
  seqassertn(!arenas.empty(), "popArena called without a matching pushArena");
  auto &arena = arenas.back();

  for (auto id : arena.values) {
    auto it = valueMap.find(id);
    if (it == valueMap.end())
      continue;
    values.erase(it->second);
    valueMap.erase(it);
  }

  for (auto id : arena.vars) {
    auto it = varMap.find(id);
    if (it == varMap.end())
      continue;
    vars.erase(it->second);
    varMap.erase(it);
  }

  for (auto &type : arena.types) {
    auto it = typesMap.find(type);
    if (it == typesMap.end())
      continue;
    types.erase(it->second);
    typesMap.erase(it);
  }

  arenas.pop_back();
}

} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/module.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once

#include
#include
#include
#include

#include "codon/cir/func.h"
#include "codon/cir/util/iterators.h"
#include "codon/cir/value.h"
#include "codon/cir/var.h"
#include "codon/util/common.h"
#include
#include

namespace codon {

namespace ast {
struct Cache;
class TranslateVisitor;
class TypecheckVisitor;
} // namespace ast

namespace ir {

/// CIR object representing a program.
class Module : public AcceptorExtend {
public:
  // String constants naming the primitive types; the values are defined out
  // of line.
  static const std::string VOID_NAME;
  static const std::string BOOL_NAME;
  static const std::string BYTE_NAME;
  static const std::string INT_NAME;
  static const std::string FLOAT_NAME;
  static const std::string FLOAT32_NAME;
  static const std::string FLOAT16_NAME;
  static const std::string BFLOAT16_NAME;
  static const std::string FLOAT128_NAME;
  static const std::string STRING_NAME;

  // String constants naming magic methods: comparisons first.
  static const std::string EQ_MAGIC_NAME;
  static const std::string NE_MAGIC_NAME;
  static const std::string LT_MAGIC_NAME;
  static const std::string GT_MAGIC_NAME;
  static const std::string LE_MAGIC_NAME;
  static const std::string GE_MAGIC_NAME;

  // Unary operators.
  static const std::string POS_MAGIC_NAME;
  static const std::string NEG_MAGIC_NAME;
  static const std::string INVERT_MAGIC_NAME;
  static const std::string ABS_MAGIC_NAME;

  // Binary operators.
  static const std::string ADD_MAGIC_NAME;
  static const std::string SUB_MAGIC_NAME;
  static const std::string MUL_MAGIC_NAME;
  static const std::string MATMUL_MAGIC_NAME;
  static const std::string TRUE_DIV_MAGIC_NAME;
  static const std::string FLOOR_DIV_MAGIC_NAME;
  static const std::string MOD_MAGIC_NAME;
  static const std::string POW_MAGIC_NAME;
  static const std::string LSHIFT_MAGIC_NAME;
  static const std::string RSHIFT_MAGIC_NAME;
  static const std::string AND_MAGIC_NAME;
  static const std::string OR_MAGIC_NAME;
  static const std::string XOR_MAGIC_NAME;

  // "I"-prefixed counterparts of the binary operators.
  static const std::string IADD_MAGIC_NAME;
  static const std::string ISUB_MAGIC_NAME;
  static const std::string IMUL_MAGIC_NAME;
  static const std::string IMATMUL_MAGIC_NAME;
  static const std::string ITRUE_DIV_MAGIC_NAME;
  static const std::string IFLOOR_DIV_MAGIC_NAME;
  static const std::string IMOD_MAGIC_NAME;
  static const std::string IPOW_MAGIC_NAME;
  static const std::string ILSHIFT_MAGIC_NAME;
  static const std::string IRSHIFT_MAGIC_NAME;
  static const std::string IAND_MAGIC_NAME;
  static const std::string IOR_MAGIC_NAME;
  static const std::string IXOR_MAGIC_NAME;

  // "R"-prefixed counterparts of the binary operators.
  static const std::string RADD_MAGIC_NAME;
  static const std::string RSUB_MAGIC_NAME;
  static const std::string RMUL_MAGIC_NAME;
  static const std::string RMATMUL_MAGIC_NAME;
  static const std::string RTRUE_DIV_MAGIC_NAME;
  static const std::string RFLOOR_DIV_MAGIC_NAME;
  static const std::string RMOD_MAGIC_NAME;
  static const std::string RPOW_MAGIC_NAME;
  static const std::string RLSHIFT_MAGIC_NAME;
  static const std::string RRSHIFT_MAGIC_NAME;
  static const std::string RAND_MAGIC_NAME;
  static const std::string ROR_MAGIC_NAME;
  static const std::string RXOR_MAGIC_NAME;

  // Conversions and miscellaneous magic methods.
  static const std::string INT_MAGIC_NAME;
  static const std::string FLOAT_MAGIC_NAME;
  static const std::string BOOL_MAGIC_NAME;
  static const std::string STR_MAGIC_NAME;
  static const std::string REPR_MAGIC_NAME;

  static const std::string CALL_MAGIC_NAME;

  static const std::string GETITEM_MAGIC_NAME;
  static const std::string SETITEM_MAGIC_NAME;
  static const std::string ITER_MAGIC_NAME;
  static const std::string LEN_MAGIC_NAME;

  static const std::string NEW_MAGIC_NAME;
  static const std::string INIT_MAGIC_NAME;

private:
  /// Record of the nodes created while an arena was on top of the arena
  /// stack: ids for values and vars, names for types (see store()).
  struct Arena {
    std::vector values;
    std::vector vars;
    std::vector types;
  };

  /// the module's "main" function
  std::unique_ptr mainFunc;
  /// the module's argv variable
  std::unique_ptr argVar;
  /// the global variables list
  std::list> vars;
  /// the global variables map
  std::unordered_map>::iterator> varMap;
  /// the global value list
  std::list> values;
  /// the global value map
  std::unordered_map>::iterator> valueMap;
  /// the global types list
  std::list> types;
  /// the global types map
  std::unordered_map>::iterator> typesMap;
  /// the arena stack
  std::vector arenas;

  /// the type-checker cache
  ast::Cache *cache = nullptr;

public:
  static const char NodeId;

  /// Constructs a CIR module.
  /// @param name the module name
  explicit Module(const std::string &name = "");

  virtual ~Module() noexcept = default;

  /// @return the main function
  Func *getMainFunc() { return mainFunc.get(); }
  /// @return the main function
  const Func *getMainFunc() const { return mainFunc.get(); }

  /// @return the arg var
  Var *getArgVar() { return argVar.get(); }
  /// @return the arg var
  const Var *getArgVar() const { return argVar.get(); }

  // Iteration over the module itself walks the global variables list.

  /// @return iterator to the first symbol
  auto begin() { return util::raw_ptr_adaptor(vars.begin()); }
  /// @return iterator beyond the last symbol
  auto end() { return util::raw_ptr_adaptor(vars.end()); }
  /// @return iterator to the first symbol
  auto begin() const { return util::const_raw_ptr_adaptor(vars.begin()); }
  /// @return iterator beyond the last symbol
  auto end() const { return util::const_raw_ptr_adaptor(vars.end()); }

  /// @return a pointer to the first symbol
  Var *front() { return vars.front().get(); }
  /// @return a pointer to the last symbol
  Var *back() { return vars.back().get(); }
  /// @return a pointer to the first symbol
  const Var *front() const { return vars.front().get(); }
  /// @return a pointer to the last symbol
  const Var *back() const { return vars.back().get(); }

  /// Gets a var by id.
  /// @param id the id
  /// @return the variable or nullptr
  Var *getVar(id_t id) {
    auto it = varMap.find(id);
    return it != varMap.end() ? it->second->get() : nullptr;
  }

  /// Gets a var by id.
  /// @param id the id
  /// @return the variable or nullptr
  const Var *getVar(id_t id) const {
    auto it = varMap.find(id);
    return it != varMap.end() ? it->second->get() : nullptr;
  }

  /// Removes a given var.
/// @param v the var void remove(const Var *v) { auto it = varMap.find(v->getId()); vars.erase(it->second); varMap.erase(it); } /// @return iterator to the first value auto values_begin() { return util::raw_ptr_adaptor(values.begin()); } /// @return iterator beyond the last value auto values_end() { return util::raw_ptr_adaptor(values.end()); } /// @return iterator to the first value auto values_begin() const { return util::const_raw_ptr_adaptor(values.begin()); } /// @return iterator beyond the last value auto values_end() const { return util::const_raw_ptr_adaptor(values.end()); } /// @return a pointer to the first value Value *values_front() { return values.front().get(); } /// @return a pointer to the last value Value *values_back() { return values.back().get(); } /// @return a pointer to the first value const Value *values_front() const { return values.front().get(); } /// @return a pointer to the last value const Value *values_back() const { return values.back().get(); } /// Gets a value by id. /// @param id the id /// @return the value or nullptr Value *getValue(id_t id) { auto it = valueMap.find(id); return it != valueMap.end() ? it->second->get() : nullptr; } /// Gets a value by id. /// @param id the id /// @return the value or nullptr const Value *getValue(id_t id) const { auto it = valueMap.find(id); return it != valueMap.end() ? it->second->get() : nullptr; } /// Removes a given value. 
/// @param v the value void remove(const Value *v) { auto it = valueMap.find(v->getId()); values.erase(it->second); valueMap.erase(it); } /// @return iterator to the first type auto types_begin() { return util::raw_ptr_adaptor(types.begin()); } /// @return iterator beyond the last type auto types_end() { return util::raw_ptr_adaptor(types.end()); } /// @return iterator to the first type auto types_begin() const { return util::const_raw_ptr_adaptor(types.begin()); } /// @return iterator beyond the last type auto types_end() const { return util::const_raw_ptr_adaptor(types.end()); } /// @return a pointer to the first type types::Type *types_front() const { return types.front().get(); } /// @return a pointer to the last type types::Type *types_back() const { return types.back().get(); } /// @param name the type's name /// @return the type with the given name types::Type *getType(const std::string &name) { auto it = typesMap.find(name); return it == typesMap.end() ? nullptr : it->second->get(); } /// @param name the type's name /// @return the type with the given name types::Type *getType(const std::string &name) const { auto it = typesMap.find(name); return it == typesMap.end() ? nullptr : it->second->get(); } /// Removes a given type. /// @param t the type void remove(types::Type *t) { auto it = typesMap.find(t->getName()); types.erase(it->second); typesMap.erase(it); } /// Constructs and registers an IR node with provided source information. /// @param s the source information /// @param args the arguments /// @return the new node template DesiredType *N(codon::SrcInfo s, Args &&...args) { auto *ret = new DesiredType(std::forward(args)...); ret->setModule(this); ret->setSrcInfo(s); store(ret); return ret; } /// Constructs and registers an IR node with provided source node. 
/// @param s the source node /// @param args the arguments /// @return the new node template DesiredType *N(const codon::SrcObject *s, Args &&...args) { return N(s->getSrcInfo(), std::forward(args)...); } /// Constructs and registers an IR node with provided source node. /// @param s the source node /// @param args the arguments /// @return the new node template DesiredType *N(const Node *s, Args &&...args) { return N(s->getSrcInfo(), std::forward(args)...); } /// Constructs and registers an IR node with no source information. /// @param args the arguments /// @return the new node template DesiredType *Nr(Args &&...args) { return N(codon::SrcInfo(), std::forward(args)...); } /// @return the type-checker cache ast::Cache *getCache() const { return cache; } /// Sets the type-checker cache. /// @param c the cache void setCache(ast::Cache *c) { cache = c; } /// Parse a codon code block. void parseCode(const std::string &code); /// Gets or realizes a method. /// @param parent the parent class /// @param methodName the method name /// @param args the argument types /// @param generics the generics /// @return the method or nullptr Func *getOrRealizeMethod(types::Type *parent, const std::string &methodName, std::vector args, std::vector generics = {}); /// Gets or realizes a function. /// @param funcName the function name /// @param args the argument types /// @param generics the generics /// @param module the module of the function /// @return the function or nullptr Func *getOrRealizeFunc(const std::string &funcName, std::vector args, std::vector generics = {}, const std::string &module = ""); /// Gets or realizes a type. 
/// @param typeName the type name /// @param generics the generics /// @param module the module of the type /// @return the function or nullptr types::Type *getOrRealizeType(const std::string &typeName, std::vector generics = {}); /// @return the void type types::Type *getVoidType(); /// @return the bool type types::Type *getBoolType(); /// @return the byte type types::Type *getByteType(); /// @return the int type types::Type *getIntType(); /// @return the float type types::Type *getFloatType(); /// @return the float32 type types::Type *getFloat32Type(); /// @return the float16 type types::Type *getFloat16Type(); /// @return the bfloat16 type types::Type *getBFloat16Type(); /// @return the float128 type types::Type *getFloat128Type(); /// @return the string type types::Type *getStringType(); /// Gets a pointer type. /// @param base the base type /// @return a pointer type that references the base types::Type *getPointerType(types::Type *base); /// Gets an array type. /// @param base the base type /// @return an array type that contains the base types::Type *getArrayType(types::Type *base); /// Gets a generator type. /// @param base the base type /// @return a generator type that yields the base types::Type *getGeneratorType(types::Type *base); /// Gets an optional type. /// @param base the base type /// @return an optional type that contains the base types::Type *getOptionalType(types::Type *base); /// Gets a function type. /// @param rType the return type /// @param argTypes the argument types /// @param variadic true if variadic (e.g. "printf" in C) /// @return the void type types::Type *getFuncType(types::Type *rType, std::vector argTypes, bool variadic = false); /// Gets a variable length integer type. /// @param len the length /// @param sign true if signed /// @return a variable length integer type types::Type *getIntNType(unsigned len, bool sign); /// Gets a vector type. 
/// @param count the vector size /// @param base the vector base type (MUST be a primitive type) /// @return a vector type types::Type *getVectorType(unsigned count, types::Type *base); /// Gets a tuple type. /// @param args the arg types /// @return the tuple type types::Type *getTupleType(std::vector args); /// Gets a union type. /// @param types the alternative types /// @return the union type types::Type *getUnionType(std::vector types); /// Gets the "none" type (i.e. empty tuple). /// @return none type types::Type *getNoneType(); /// @param v the value /// @return an int constant Value *getInt(int64_t v); /// @param v the value /// @return a float constant Value *getFloat(double v); /// @param v the value /// @return a bool constant Value *getBool(bool v); /// @param v the value /// @return a string constant Value *getString(std::string v); /// Gets a dummy function type. Should generally not be used as no type-checker /// information is generated. /// @return a func type with no args and void return type. types::Type *unsafeGetDummyFuncType(); /// Gets a pointer type. Should generally not be used as no type-checker /// information is generated. /// @param base the base type /// @return a pointer type that references the base types::Type *unsafeGetPointerType(types::Type *base); /// Gets an array type. Should generally not be used as no type-checker /// information is generated. /// @param base the base type /// @return an array type that contains the base types::Type *unsafeGetArrayType(types::Type *base); /// Gets a generator type. Should generally not be used as no type-checker /// information is generated. /// @param base the base type /// @return a generator type that yields the base types::Type *unsafeGetGeneratorType(types::Type *base); /// Gets an optional type. Should generally not be used as no type-checker /// information is generated. 
/// @param base the base type /// @return an optional type that contains the base types::Type *unsafeGetOptionalType(types::Type *base); /// Gets a function type. Should generally not be used as no type-checker /// information is generated. /// @param rType the return type /// @param argTypes the argument types /// @param variadic true if variadic (e.g. "printf" in C) /// @return the void type types::Type *unsafeGetFuncType(const std::string &name, types::Type *rType, std::vector argTypes, bool variadic = false); /// Gets a membered type. Should generally not be used as no type-checker /// information is generated. /// @param name the type's name /// @param ref whether the type should be a ref /// @return an empty membered/ref type types::Type *unsafeGetMemberedType(const std::string &name, bool ref = false); /// Gets a variable length integer type. Should generally not be used as no /// type-checker information is generated. /// @param len the length /// @param sign true if signed /// @return a variable length integer type types::Type *unsafeGetIntNType(unsigned len, bool sign); /// Gets a vector type. Should generally not be used as no /// type-checker information is generated. /// @param count the vector size /// @param base the vector base type (MUST be a primitive type) /// @return a vector type types::Type *unsafeGetVectorType(unsigned count, types::Type *base); /// Gets a union type. Should generally not be used as no /// type-checker information is generated. /// @param types the alternative types /// @return a union type types::Type *unsafeGetUnionType(const std::vector &types); /// Push an arena on the arena stack that stores all nodes /// that are created subsequently. void pushArena(); /// Pop the top arena of the arena stack, deallocating all /// the nodes stored therein. 
  void popArena();

private:
  /// Appends a type to the global types list, indexes it by name and, if an
  /// arena is active, records its name in the top arena.
  /// @param t the type to register
  void store(types::Type *t) {
    types.emplace_back(t);
    // the map stores an iterator to the element that was just appended
    typesMap[t->getName()] = std::prev(types.end());
    if (!arenas.empty())
      arenas.back().types.push_back(t->getName());
  }

  /// Appends a value to the global value list, indexes it by id and, if an
  /// arena is active, records its id in the top arena.
  /// @param v the value to register
  void store(Value *v) {
    values.emplace_back(v);
    valueMap[v->getId()] = std::prev(values.end());
    if (!arenas.empty())
      arenas.back().values.push_back(v->getId());
  }

  /// Appends a var to the global variables list, indexes it by id and, if an
  /// arena is active, records its id in the top arena.
  /// @param v the var to register
  void store(Var *v) {
    vars.emplace_back(v);
    varMap[v->getId()] = std::prev(vars.end());
    if (!arenas.empty())
      arenas.back().vars.push_back(v->getId());
  }
};

} // namespace ir
} // namespace codon

// Formatting goes through the ostream-based formatter.
template <> struct fmt::formatter : fmt::ostream_formatter {};

================================================
FILE: codon/cir/pyextension.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include
#include

#include "codon/cir/func.h"
#include "codon/cir/types/types.h"

namespace codon {
namespace ir {

/// Description of a function exposed through a Python extension.
struct PyFunction {
  /// How the function is bound.
  enum Type { TOPLEVEL, METHOD, CLASS, STATIC };
  std::string name;
  std::string doc;
  /// the IR function backing this entry
  Func *func = nullptr;
  Type type = Type::TOPLEVEL;
  int nargs = 0;
  bool keywords = false;
  bool coexist = false;
};

/// Description of a data member exposed through a Python extension.
struct PyMember {
  /// Member storage kind.
  // NOTE(review): the numeric values (including the gap at 15) look like
  // they mirror CPython's structmember.h T_* codes -- confirm.
  enum Type {
    SHORT = 0,
    INT = 1,
    LONG = 2,
    FLOAT = 3,
    DOUBLE = 4,
    STRING = 5,
    OBJECT = 6,
    CHAR = 7,
    BYTE = 8,
    UBYTE = 9,
    USHORT = 10,
    UINT = 11,
    ULONG = 12,
    STRING_INPLACE = 13,
    BOOL = 14,
    OBJECT_EX = 16,
    LONGLONG = 17,
    ULONGLONG = 18,
    PYSSIZET = 19,
  };

  std::string name;
  std::string doc;
  Type type = Type::SHORT;
  bool readonly = false;
  /// Indexes of the member. For example, in the
  /// tuple (a, (b, c, (d,))), 'a' would have indexes
  /// [0], 'b' would have indexes [1, 0], 'c' would
  /// have indexes [1, 1], and 'd' would have indexes
  /// [1, 2, 0]. This corresponds to an LLVM GEP.
  std::vector indexes;
};

/// Description of a property: a name plus optional getter/setter functions.
struct PyGetSet {
  std::string name;
  std::string doc;
  Func *get = nullptr;
  Func *set = nullptr;
};

/// Description of a type exposed through a Python extension. Every function
/// slot defaults to nullptr.
struct PyType {
  std::string name;
  std::string doc;
  /// the IR type backing this entry
  types::Type *type = nullptr;
  PyType *base = nullptr;
  Func *repr = nullptr;
  Func *add = nullptr;
  Func *iadd = nullptr;
  Func *sub = nullptr;
  Func *isub = nullptr;
  Func *mul = nullptr;
  Func *imul = nullptr;
  Func *mod = nullptr;
  Func *imod = nullptr;
  Func *divmod = nullptr;
  Func *pow = nullptr;
  Func *ipow = nullptr;
  Func *neg = nullptr;
  Func *pos = nullptr;
  Func *abs = nullptr;
  Func *bool_ = nullptr;
  Func *invert = nullptr;
  Func *lshift = nullptr;
  Func *ilshift = nullptr;
  Func *rshift = nullptr;
  Func *irshift = nullptr;
  Func *and_ = nullptr;
  Func *iand = nullptr;
  Func *xor_ = nullptr;
  Func *ixor = nullptr;
  Func *or_ = nullptr;
  Func *ior = nullptr;
  Func *int_ = nullptr;
  Func *float_ = nullptr;
  Func *floordiv = nullptr;
  Func *ifloordiv = nullptr;
  Func *truediv = nullptr;
  Func *itruediv = nullptr;
  Func *index = nullptr;
  Func *matmul = nullptr;
  Func *imatmul = nullptr;
  Func *len = nullptr;
  Func *getitem = nullptr;
  Func *setitem = nullptr;
  Func *contains = nullptr;
  Func *hash = nullptr;
  Func *call = nullptr;
  Func *str = nullptr;
  Func *cmp = nullptr;
  Func *iter = nullptr;
  Func *iternext = nullptr;
  Func *del = nullptr;
  Func *init = nullptr;
  std::vector methods;
  std::vector members;
  std::vector getset;
  Func *typePtrHook = nullptr;
};

/// Description of a Python extension module: its functions and types.
struct PyModule {
  std::string name;
  std::string doc;
  std::vector functions;
  std::vector types;
};

} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/transform/cleanup/canonical.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include "canonical.h"

#include
#include
#include
#include
#include

#include "codon/cir/analyze/module/side_effect.h"
#include "codon/cir/transform/rewrite.h"
#include "codon/cir/util/irtools.h"
#include "codon/cir/util/matching.h"

namespace codon {
namespace ir {
namespace transform {
namespace cleanup {
namespace {
/// Visitor that computes an ordering key for the node it is applied to.
struct NodeRanker : public util::Operator {
  // Nodes are ranked lexicographically by:
  //   - Whether the node is constant (constants come last)
  //   - Max node depth (deeper nodes first)
  //   - Node hash
  // The hash imposes an arbitrary but well-defined ordering
  // to ensure a single canonical representation for (most)
  // nodes.
  using Rank = std::tuple;
  /// first node visited, i.e. the node being ranked
  Node *root = nullptr;
  /// largest depth() observed during the traversal
  int maxDepth = 0;
  /// running hash over the names of used variables and types
  uint64_t hash = 0;

  // boost's hash_combine
  template void hash_combine(const T &v) {
    std::hash hasher;
    hash ^= hasher(v) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
  }

  /// Records the root, tracks the maximum depth and folds the names of the
  /// node's used variables and types into the hash.
  void preHook(Node *node) {
    if (!root)
      root = node;
    maxDepth = std::max(maxDepth, depth());
    for (auto *v : node->getUsedVariables()) {
      hash_combine(v->getName());
    }
    for (auto *v : node->getUsedTypes()) {
      hash_combine(v->getName());
    }
  }

  /// @return the rank tuple; depth is negated so that deeper nodes sort first
  Rank getRank() {
    return std::make_tuple((isA(root) ? 1 : -1), -maxDepth, hash);
  }
};

/// @param node the node to rank
/// @return the rank computed by running a NodeRanker over the node
NodeRanker::Rank getRank(Node *node) {
  NodeRanker ranker;
  node->accept(ranker);
  return ranker.getRank();
}

/// @return true if fn is non-null and carries the "commutative" attribute
bool isCommutativeOp(Func *fn) {
  return fn && util::hasAttribute(
                   fn, ast::getMangledFunc("std.internal.attributes", "commutative"));
}

/// @return true if fn is non-null and carries the "associative" attribute
bool isAssociativeOp(Func *fn) {
  return fn && util::hasAttribute(
                   fn, ast::getMangledFunc("std.internal.attributes", "associative"));
}

/// @return true if fn is non-null and carries the "distributive" attribute
bool isDistributiveOp(Func *fn) {
  return fn && util::hasAttribute(
                   fn, ast::getMangledFunc("std.internal.attributes", "distributive"));
}

/// @return true if fn is non-null and its unmangled name is one of the six
///         comparison magic names (==, !=, <, <=, >, >=)
bool isInequalityOp(Func *fn) {
  static const std::unordered_set ops = {
      Module::EQ_MAGIC_NAME, Module::NE_MAGIC_NAME, Module::LT_MAGIC_NAME,
      Module::LE_MAGIC_NAME, Module::GT_MAGIC_NAME, Module::GE_MAGIC_NAME};
  return fn && ops.find(fn->getUnmangledName()) != ops.end();
}

// c + b + a --> a + b + c
struct CanonOpChain : public RewriteRule {
  /// Recursively flattens nested calls of `op` over (type, type) -> type into
  /// `result`, left operand first. Anything that is not such a call is
  /// appended as a leaf.
  static void extractAssociativeOpChain(Value *v, const std::string &op,
                                        types::Type *type,
                                        std::vector &result) {
    if (util::isCallOf(v, op, {type, type}, type, /*method=*/true)) {
      auto *call = cast(v);
      extractAssociativeOpChain(call->front(), op, type, result);
      extractAssociativeOpChain(call->back(), op, type, result);
    } else {
      result.push_back(v);
    }
  }

  /// Sorts the operands in place by their rank (see NodeRanker).
  static void orderOperands(std::vector &operands) {
    std::vector> rankedOperands;
    for (auto *v : operands) {
      rankedOperands.push_back({getRank(v), v});
    }
    std::sort(rankedOperands.begin(), rankedOperands.end());

    operands.clear();
    for (auto &p : rankedOperands) {
      operands.push_back(std::get<1>(p));
    }
  }

  void visit(CallInstr *v) override {
    auto *fn = util::getFunc(v->getCallee());
    if (!fn)
      return;

    std::string op = fn->getUnmangledName();
    types::Type *type = v->getType();
    const bool isAssociative = isAssociativeOp(fn);
    const bool isCommutative = isCommutativeOp(fn);

    // only binary methods whose operands and result share one type qualify
    if (util::isCallOf(v, op, {type, type}, type, /*method=*/true)) {
      std::vector operands;
      if (isAssociative) {
        extractAssociativeOpChain(v, op, type, operands);
      } else {
        operands.push_back(v->front());
        operands.push_back(v->back());
      }
      seqassertn(operands.size() >= 2, "bad call canonicalization");

      if (isCommutative)
        orderOperands(operands);

      // rebuild the chain as a left-nested sequence of calls
      Value *newCall = util::call(fn, {operands[0], operands[1]});
      for (auto it = operands.begin() + 2; it != operands.end(); ++it) {
        newCall = util::call(fn, {newCall, *it});
      }

      // only report a result if the rebuilt call differs from the original
      if (!util::match(v, newCall, /*checkNames=*/false, /*varIdMatch=*/true))
        return setResult(newCall);
    }
  }
};

// b > a --> a < b (etc.)
struct CanonInequality : public RewriteRule {
  void visit(CallInstr *v) override {
    auto *fn = util::getFunc(v->getCallee());
    if (!fn)
      return;

    std::string op = fn->getUnmangledName();
    types::Type *type = v->getType();

    // canonicalize inequalities
    if (v->numArgs() == 2 && isInequalityOp(fn)) {
      Value *newCall = nullptr;
      auto *lhs = v->front();
      auto *rhs = v->back();

      if (getRank(lhs) > getRank(rhs)) { // are we out of order?
        // re-order
        if (op == Module::EQ_MAGIC_NAME) { // lhs == rhs
          newCall = *rhs == *lhs;
        } else if (op == Module::NE_MAGIC_NAME) { // lhs != rhs
          newCall = *rhs != *lhs;
        } else if (op == Module::LT_MAGIC_NAME) { // lhs < rhs
          newCall = *rhs > *lhs;
        } else if (op == Module::LE_MAGIC_NAME) { // lhs <= rhs
          newCall = *rhs >= *lhs;
        } else if (op == Module::GT_MAGIC_NAME) { // lhs > rhs
          newCall = *rhs < *lhs;
        } else if (op == Module::GE_MAGIC_NAME) { // lhs >= rhs
          newCall = *rhs <= *lhs;
        } else {
          seqassertn(false, "unknown comparison op: {}", op);
        }

        // keep the rewrite only if it exists, preserves the result type and
        // actually differs from the original call
        if (newCall && newCall->getType()->is(type) &&
            !util::match(v, newCall, /*checkNames=*/false, /*varIdMatch=*/true))
          return setResult(newCall);
      }
    }
  }
};

// a*x + b*x --> (a + b) * x
struct CanonAddMul : public RewriteRule {
  /// @return true if both values cast successfully and refer to vars with
  ///         the same id
  static bool varMatch(Value *a, Value *b) {
    auto *v1 = cast(a);
    auto *v2 = cast(b);
    return v1 && v2 && v1->getVar()->getId() == v2->getVar()->getId();
  }

  /// @return the function called by v, or nullptr if the isA check fails
  static Func *getOp(Value *v) {
    return isA(v) ? util::getFunc(cast(v)->getCallee()) : nullptr;
  }

  // (a + b) * x, or null if invalid
  static Value *addMul(Value *a, Value *b, Value *x) {
    if (!a || !b || !x)
      return nullptr;

    auto *y = (*a + *b);
    if (!y) {
      // retry with swapped operands; only valid if that op is commutative
      y = (*b + *a);
      if (y && !isCommutativeOp(getOp(y)))
        return nullptr;
    }
    if (!y)
      return nullptr;

    auto *z = (*y) * (*x);
    if (!z) {
      // same fallback for the multiplication
      z = (*x) * (*y);
      if (z && !isCommutativeOp(getOp(z)))
        return nullptr;
    }
    if (!z)
      return nullptr;

    return z;
  }

  void visit(CallInstr *v) override {
    auto *M = v->getModule();
    auto *fn = util::getFunc(v->getCallee());

    // only commutative two-argument additions are considered
    if (!isCommutativeOp(fn) ||
        !util::isCallOf(v, Module::ADD_MAGIC_NAME, 2, /*output=*/nullptr,
                        /*method=*/true))
      return;

    // decompose the operation
    Value *lhs = v->front();
    Value *rhs = v->back();
    Value *lhs1 = nullptr, *lhs2 = nullptr, *rhs1 = nullptr, *rhs2 = nullptr;

    // an operand that is not a multiplication is treated as (operand * 1)
    if (util::isCallOf(lhs, Module::MUL_MAGIC_NAME, 2, /*output=*/nullptr,
                       /*method=*/true)) {
      auto *lhsCall = cast(lhs);
      lhs1 = lhsCall->front();
      lhs2 = lhsCall->back();
    } else {
      lhs1 = lhs;
      lhs2 = M->getInt(1);
    }

    if (util::isCallOf(rhs, Module::MUL_MAGIC_NAME, 2, /*output=*/nullptr,
                       /*method=*/true)) {
      auto *rhsCall = cast(rhs);
      rhs1 = rhsCall->front();
      rhs2 = rhsCall->back();
    } else {
      rhs1 = rhs;
      rhs2 = M->getInt(1);
    }

    // find a factor shared by both sides (first match wins) and factor it out
    Value *newCall = nullptr;
    if (varMatch(lhs1, rhs1)) {
      newCall = addMul(lhs2, rhs2, lhs1);
    } else if (varMatch(lhs1, rhs2)) {
      newCall = addMul(lhs2, rhs1, lhs1);
    } else if (varMatch(lhs2, rhs1)) {
      newCall = addMul(lhs1, rhs2, lhs2);
    } else if (varMatch(lhs2, rhs2)) {
      newCall = addMul(lhs1, rhs1, lhs2);
    }

    // the outer op of the result must be distributive and the result type
    // must be unchanged
    if (newCall && isDistributiveOp(getOp(newCall)) &&
        newCall->getType()->is(v->getType()) &&
        !util::match(v, newCall, /*checkNames=*/false, /*varIdMatch=*/true))
      return setResult(newCall);
  }
};

// x - c --> x + (-c)
struct CanonConstSub : public RewriteRule {
  void visit(CallInstr *v) override {
    auto *M = v->getModule();
    auto *type = v->getType();

    if (!util::isCallOf(v, Module::SUB_MAGIC_NAME, 2, /*output=*/nullptr,
                        /*method=*/true))
      return;

    Value *lhs = v->front();
    Value *rhs = v->back();

    // both operands must have the same type
    if (!lhs->getType()->is(rhs->getType()))
      return;

    Value *newCall = nullptr;
    // first branch rebuilds with an int constant, second with a float constant
    if (util::isConst(rhs)) {
      auto c = util::getConst(rhs);
      if (c != -(1ull << 63)) // ensure no overflow
        newCall = *lhs + *(M->getInt(-c));
    } else if (util::isConst(rhs)) {
      auto c = util::getConst(rhs);
      newCall = *lhs + *(M->getFloat(-c));
    }

    if (newCall && newCall->getType()->is(type) &&
        !util::match(v, newCall, /*checkNames=*/false, /*varIdMatch=*/true))
      return setResult(newCall);
  }
};

} // namespace

const std::string CanonicalizationPass::KEY = "core-cleanup-canon";

/// Registers the standard rules, resets the rewriter and runs the pass over
/// the module.
void CanonicalizationPass::run(Module *m) {
  registerStandardRules(m);
  Rewriter::reset();
  OperatorPass::run(m);
}

/// Applies the rewrite rules to calls that the side effect analysis reports
/// as free of side effects.
void CanonicalizationPass::handle(CallInstr *v) {
  auto *r = getAnalysisResult(sideEffectsKey);
  if (!r->hasSideEffect(v))
    rewrite(v);
}

/// Flattens a series flow in place: nested series are spliced into the
/// parent, and flow-instructions are split into their flow and value.
void CanonicalizationPass::handle(SeriesFlow *v) {
  auto it = v->begin();
  while (it != v->end()) {
    if (auto *series = cast(*it)) {
      it = v->erase(it);
      for (auto *x : *series) {
        it = v->insert(it, x);
        ++it;
      }
    } else if (auto *flowInstr = cast(*it)) {
      it = v->erase(it);
      // inserting in reverse order causes [flow, value] to be added
      it = v->insert(it, flowInstr->getValue());
      it = v->insert(it, flowInstr->getFlow());
      // don't increment; re-traverse in case a new series flow added
    } else {
      ++it;
    }
  }
}

/// Registers the four built-in canonicalization rules.
void CanonicalizationPass::registerStandardRules(Module *m) {
  registerRule("op-chain", std::make_unique());
  registerRule("inequality", std::make_unique());
  registerRule("add-mul", std::make_unique());
  registerRule("const-sub", std::make_unique());
}

} // namespace cleanup
} // namespace transform
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/transform/cleanup/canonical.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once #include "codon/cir/transform/pass.h" #include "codon/cir/transform/rewrite.h" namespace codon { namespace ir { namespace transform { namespace cleanup { /// Canonicalization pass that flattens nested series /// flows, puts operands in a predefined order, etc. class CanonicalizationPass : public OperatorPass, public Rewriter { private: std::string sideEffectsKey; public: /// Constructs a canonicalization pass /// @param sideEffectsKey the side effect analysis' key CanonicalizationPass(const std::string &sideEffectsKey) : OperatorPass(/*childrenFirst=*/true), sideEffectsKey(sideEffectsKey) {} static const std::string KEY; std::string getKey() const override { return KEY; } void run(Module *m) override; void handle(CallInstr *) override; void handle(SeriesFlow *) override; private: void registerStandardRules(Module *m); }; } // namespace cleanup } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/cleanup/dead_code.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "dead_code.h"

#include "codon/cir/analyze/module/side_effect.h"
#include "codon/cir/util/cloning.h"

namespace codon {
namespace ir {
namespace transform {
namespace cleanup {
namespace {
/// @return v cast to a bool constant (callers treat a null result as
///         "not a constant")
BoolConst *boolConst(Value *v) { return cast(v); }

/// @return v cast to an int constant (callers treat a null result as
///         "not a constant")
IntConst *intConst(Value *v) { return cast(v); }
} // namespace

const std::string DeadCodeCleanupPass::KEY = "core-cleanup-dce";

/// Resets the replacement counter and runs the pass over the module.
void DeadCodeCleanupPass::run(Module *m) {
  numReplacements = 0;
  OperatorPass::run(m);
}

/// Erases every element of the series that the side effect analysis reports
/// as having no side effect.
void DeadCodeCleanupPass::handle(SeriesFlow *v) {
  auto *r = getAnalysisResult(sideEffectsKey);
  auto it = v->begin();
  while (it != v->end()) {
    if (!r->hasSideEffect(*it)) {
      LOG_IR("[{}] no side effect, deleting: {}", KEY, **it);
      numReplacements++;
      it = v->erase(it);
    } else {
      ++it;
    }
  }
}

/// Replaces an if-flow whose condition is a bool constant with a clone of
/// the branch that is taken, or with a freshly created node when the
/// condition is false and there is no false branch.
void DeadCodeCleanupPass::handle(IfFlow *v) {
  auto *cond = boolConst(v->getCond());
  if (!cond)
    return;

  auto *M = v->getModule();
  auto condVal = cond->getVal();

  util::CloneVisitor cv(M);
  if (condVal) {
    doReplacement(v, cv.clone(v->getTrueBranch()));
  } else if (auto *f = v->getFalseBranch()) {
    doReplacement(v, cv.clone(f));
  } else {
    doReplacement(v, M->Nr());
  }
}

/// Replaces a while-flow whose condition is the constant false with a
/// freshly created node. Loops with a constant true condition are left
/// untouched.
void DeadCodeCleanupPass::handle(WhileFlow *v) {
  auto *cond = boolConst(v->getCond());
  if (!cond)
    return;

  auto *M = v->getModule();
  auto condVal = cond->getVal();
  if (!condVal) {
    doReplacement(v, M->Nr());
  }
}

/// Replaces an imperative for-flow with constant bounds by a freshly created
/// node when the bounds and step direction mean the body never runs.
void DeadCodeCleanupPass::handle(ImperativeForFlow *v) {
  auto *start = intConst(v->getStart());
  auto *end = intConst(v->getEnd());
  if (!start || !end)
    return;

  auto stepVal = v->getStep();
  auto startVal = start->getVal();
  auto endVal = end->getVal();

  auto *M = v->getModule();
  // counting down from at or below the end, or up from at or beyond it
  if ((stepVal < 0 && startVal <= endVal) || (stepVal > 0 && startVal >= endVal)) {
    doReplacement(v, M->Nr());
  }
}

/// Replaces a ternary whose condition is a bool constant with a clone of the
/// selected value.
void DeadCodeCleanupPass::handle(TernaryInstr *v) {
  auto *cond = boolConst(v->getCond());
  if (!cond)
    return;

  auto *M = v->getModule();
  auto condVal = cond->getVal();

  util::CloneVisitor cv(M);
  if (condVal) {
    doReplacement(v, cv.clone(v->getTrueValue()));
  } else {
    doReplacement(v, cv.clone(v->getFalseValue()));
  }
}

/// Counts the replacement and replaces all uses of `og` with `v`.
void DeadCodeCleanupPass::doReplacement(Value *og, Value *v) {
  numReplacements++;
  og->replaceAll(v);
}

} // namespace cleanup
} // namespace transform
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/transform/cleanup/dead_code.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include "codon/cir/transform/pass.h"

namespace codon {
namespace ir {
namespace transform {
namespace cleanup {

/// Cleanup pass that removes dead code.
class DeadCodeCleanupPass : public OperatorPass {
private:
  /// key of the side effect analysis whose result this pass queries
  std::string sideEffectsKey;
  /// number of replacements/deletions; reset at the start of each run()
  int numReplacements;

public:
  static const std::string KEY;

  /// Constructs a dead code elimination pass
  /// @param sideEffectsKey the side effect analysis' key
  DeadCodeCleanupPass(std::string sideEffectsKey)
      : OperatorPass(), sideEffectsKey(std::move(sideEffectsKey)),
        numReplacements(0) {}

  /// @return the pass key
  std::string getKey() const override { return KEY; }

  void run(Module *m) override;

  void handle(SeriesFlow *v) override;
  void handle(IfFlow *v) override;
  void handle(WhileFlow *v) override;
  void handle(ImperativeForFlow *v) override;
  void handle(TernaryInstr *v) override;

  /// @return the number of replacements
  int getNumReplacements() const { return numReplacements; }

private:
  /// Replaces `og` with `v` and counts the replacement.
  void doReplacement(Value *og, Value *v);
};

} // namespace cleanup
} // namespace transform
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/transform/cleanup/global_demote.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include "global_demote.h"

namespace codon {
namespace ir {
namespace transform {
namespace cleanup {
namespace {
/// Visitor that collects the global, non-thread-local variables used by the
/// nodes it visits. A variable is recorded once per use, so `vars` may
/// contain duplicates.
struct GetUsedGlobals : public util::Operator {
  std::vector vars;
  void preHook(Node *v) override {
    for (auto *var : v->getUsedVariables()) {
      if (!isA(var) && var->isGlobal() && !var->isThreadLocal())
        vars.push_back(var);
    }
  }
};
} // namespace

const std::string GlobalDemotionPass::KEY = "core-cleanup-global-demote";

/// Finds globals that are used by exactly one function and turns each of
/// them into a local of that function.
void GlobalDemotionPass::run(Module *M) {
  numDemotions = 0;
  // Maps each used global to the single function using it; the entry is set
  // to nullptr once a second, different function is seen.
  std::unordered_map localGlobals;

  // the main function is not part of the module's var list, so seed it here
  std::vector worklist = {M->getMainFunc()};
  for (auto *var : *M) {
    if (auto *func = cast(var))
      worklist.push_back(func);
  }

  for (auto *var : worklist) {
    if (auto *func = cast(var)) {
      GetUsedGlobals globals;
      func->accept(globals);

      for (auto *g : globals.vars) {
        LOG_IR("[{}] global {} used in {}", KEY, *g, func->getName());
        auto it = localGlobals.find(g);
        if (it == localGlobals.end()) {
          localGlobals.emplace(g, func);
        } else if (it->second && it->second != func) {
          it->second = nullptr;
        }
      }
    }
  }

  for (auto it : localGlobals) {
    // skip globals shared between functions, the module's arg var and
    // external vars
    if (!it.second || it.first->getId() == M->getArgVar()->getId() ||
        it.first->isExternal())
      continue;
    seqassertn(it.first->isGlobal(), "var was not global [{}]", it.first->getSrcInfo());
    it.first->setGlobal(false);
    if (auto *func = cast(it.second)) {
      func->push_back(it.first);
      ++numDemotions;
      LOG_IR("[{}] demoted {} to a local of {}", KEY, *it.first, func->getName());
    }
  }
}

} // namespace cleanup
} // namespace transform
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/transform/cleanup/global_demote.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include "codon/cir/transform/pass.h"

namespace codon {
namespace ir {
namespace transform {
namespace cleanup {

/// Demotes global variables that are used in only one
/// function to locals of that function.
class GlobalDemotionPass : public Pass { private: /// number of variables we've demoted int numDemotions; public: static const std::string KEY; /// Constructs a global variable demotion pass GlobalDemotionPass() : Pass(), numDemotions(0) {} std::string getKey() const override { return KEY; } void run(Module *v) override; /// @return number of variables we've demoted int getNumDemotions() const { return numDemotions; } }; } // namespace cleanup } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/cleanup/replacer.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "replacer.h" #include #include "codon/cir/types/types.h" #include "codon/cir/value.h" #include "codon/cir/var.h" namespace codon { namespace ir { namespace transform { namespace cleanup { const std::string ReplaceCleanupPass::KEY = "core-cleanup-physical-replace"; void ReplaceCleanupPass::run(Module *module) { std::unordered_set valuesToDelete; std::unordered_set typesToDelete; std::unordered_set varsToDelete; { auto *f = module->getMainFunc(); for (auto *c : f->getUsedValues()) { if (c->hasReplacement()) { f->replaceUsedValue(c, c->getActual()); valuesToDelete.insert(c); } } for (auto *t : f->getUsedTypes()) { if (t->hasReplacement()) { f->replaceUsedType(t, t->getActual()); typesToDelete.insert(t); } } for (auto *v : f->getUsedVariables()) { if (v->hasReplacement()) { f->replaceUsedVariable(v, v->getActual()); varsToDelete.insert(v); } } } { auto *v = module->getArgVar(); for (auto *c : v->getUsedValues()) { if (c->hasReplacement()) { v->replaceUsedValue(c, c->getActual()); valuesToDelete.insert(c); } } for (auto *t : v->getUsedTypes()) { if (t->hasReplacement()) { v->replaceUsedType(t, t->getActual()); typesToDelete.insert(t); } } for (auto *v2 : v->getUsedVariables()) { if (v2->hasReplacement()) { v->replaceUsedVariable(v2, v2->getActual()); varsToDelete.insert(v2); } 
} } for (auto it = module->values_begin(); it != module->values_end(); ++it) { for (auto *c : it->getUsedValues()) { if (c->hasReplacement()) { it->replaceUsedValue(c, c->getActual()); valuesToDelete.insert(c); } } for (auto *t : it->getUsedTypes()) { if (t->hasReplacement()) { it->replaceUsedType(t, t->getActual()); typesToDelete.insert(t); } } for (auto *v : it->getUsedVariables()) { if (v->hasReplacement()) { it->replaceUsedVariable(v, v->getActual()); varsToDelete.insert(v); } } } for (auto it = module->begin(); it != module->end(); ++it) { for (auto *c : it->getUsedValues()) { if (c->hasReplacement()) { it->replaceUsedValue(c, c->getActual()); valuesToDelete.insert(c); } } for (auto *t : it->getUsedTypes()) { if (t->hasReplacement()) { it->replaceUsedType(t, t->getActual()); typesToDelete.insert(t); } } for (auto *v : it->getUsedVariables()) { if (v->hasReplacement()) { it->replaceUsedVariable(v, v->getActual()); varsToDelete.insert(v); } } } for (auto it = module->types_begin(); it != module->types_end(); ++it) { for (auto *c : it->getUsedValues()) { if (c->hasReplacement()) { it->replaceUsedValue(c, c->getActual()); valuesToDelete.insert(c); } } for (auto *t : it->getUsedTypes()) { if (t->hasReplacement()) { it->replaceUsedType(t, t->getActual()); typesToDelete.insert(t); } } for (auto *v : it->getUsedVariables()) { if (v->hasReplacement()) { it->replaceUsedVariable(v, v->getActual()); varsToDelete.insert(v); } } } for (auto *v : valuesToDelete) module->remove(v); for (auto *v : varsToDelete) module->remove(v); for (auto *t : typesToDelete) module->remove(t); } } // namespace cleanup } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/cleanup/replacer.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once

#include "codon/cir/transform/pass.h"

namespace codon {
namespace ir {
namespace transform {
namespace cleanup {

/// Cleanup pass that physically replaces nodes.
class ReplaceCleanupPass : public Pass {
public:
  static const std::string KEY;
  /// @return the pass key (KEY)
  std::string getKey() const override { return KEY; }
  void run(Module *module) override;
};

} // namespace cleanup
} // namespace transform
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/transform/folding/const_fold.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "const_fold.h"

#include
#include

#include "codon/cir/util/cloning.h"
#include "codon/cir/util/irtools.h"

// Generic lambdas wrapping a binary / unary operator token.
#define BINOP(o) [](auto x, auto y) -> auto { return x o y; }
#define UNOP(o) [](auto x) -> auto { return o x; }

namespace codon {
namespace ir {
namespace transform {
namespace folding {
namespace {
/// Computes quotient and remainder of `self / other`, adjusting the
/// truncated result when the remainder is non-zero and its sign differs from
/// the sign of `other` (so the remainder takes the divisor's sign).
/// @param self the dividend
/// @param other the divisor; the division below is not guarded against zero
/// @return a pair (quotient, remainder)
auto pyDivmod(int64_t self, int64_t other) {
  // NOTE(review): INT64_MIN / -1 overflows here; the callers visible in this
  // file only exclude a zero divisor - confirm whether that case can occur.
  auto d = self / other;
  auto m = self - d * other;
  if (m && ((other ^ m) < 0)) {
    m += other;
    d -= 1;
  }
  return std::make_pair(d, m);
}

/// Binary rule folding a two-argument call to `magic` whose arguments are
/// both constants of two different kinds; the operand that is not already a
/// double is cast to double before `f` is applied.
template
class IntFloatBinaryRule : public RewriteRule {
private:
  /// the calculator
  Func f;
  /// the magic method name
  std::string magic;
  /// the type given to the folded constant
  types::Type *out;
  /// when set, calls whose right-hand constant equals zero are not folded
  bool excludeRHSZero;

public:
  /// @param f the calculator
  /// @param magic the magic method name
  /// @param out the output type
  /// @param excludeRHSZero whether to skip folding when the RHS is zero
  IntFloatBinaryRule(Func f, std::string magic, types::Type *out,
                     bool excludeRHSZero = false)
      : f(std::move(f)), magic(std::move(magic)), out(out),
        excludeRHSZero(excludeRHSZero) {}

  virtual ~IntFloatBinaryRule() noexcept = default;

  void visit(CallInstr *v) override {
    if (!util::isCallOf(v, magic, 2, /*output=*/nullptr, /*method=*/true))
      return;

    auto *leftConst = cast(v->front());
    auto *rightConst = cast(v->back());

    // Both operands must be constants for the call to be foldable.
    if (!leftConst || !rightConst)
      return;

    auto *M = v->getModule();
    if (isA(leftConst) && isA(rightConst)) {
      // The right operand is the one converted to double in this branch.
      auto left = cast(leftConst)->getVal();
      auto right = cast(rightConst)->getVal();
      if (excludeRHSZero && right == 0)
        return;
      return setResult(M->template N>(v->getSrcInfo(), f(left, (double)right), out));
    } else if (isA(leftConst) && isA(rightConst)) {
      // Mirror case: the left operand is converted to double.
      auto left = cast(leftConst)->getVal();
      auto right = cast(rightConst)->getVal();
      if (excludeRHSZero && right == 0.0)
        return;
      return setResult(M->template N>(v->getSrcInfo(), f((double)left, right), out));
    }
  }
};

/// Binary rule that requires two constants and declines to fold when the
/// right-hand constant equals a value-initialized `ConstantType()`.
template
class DoubleConstantBinaryRuleExcludeRHSZero : public DoubleConstantBinaryRule {
public:
  /// @param f the calculator
  /// @param magic the magic method name
  /// @param inputType the input type
  /// @param resultType the output type
  DoubleConstantBinaryRuleExcludeRHSZero(Func f, std::string magic,
                                         types::Type *inputType,
                                         types::Type *resultType)
      : DoubleConstantBinaryRule(f, magic, inputType, resultType) {}

  virtual ~DoubleConstantBinaryRuleExcludeRHSZero() noexcept = default;

  void visit(CallInstr *v) override {
    if (v->numArgs() == 2) {
      auto *rightConst = cast>(v->back());
      // Leave the call in place rather than evaluate `f` with a zero RHS.
      if (rightConst && rightConst->getVal() == ConstantType())
        return;
    }
    DoubleConstantBinaryRule::visit(v);
  }
};

/// @return a calculator that returns a clone of its argument made with a
///         CloneVisitor on module `m`
auto id_val(Module *m) {
  return [=](Value *v) -> Value * {
    util::CloneVisitor cv(m);
    return cv.clone(v);
  };
}

/// Integer power by repeated squaring.
/// @param base the base
/// @param exp the exponent
/// @return base raised to exp, or 0 when exp is negative
int64_t int_pow(int64_t base, int64_t exp) {
  if (exp < 0)
    return 0;
  int64_t result = 1;
  while (true) {
    if (exp & 1) {
      result *= base;
    }
    exp = exp >> 1;
    if (!exp)
      break;
    base = base * base;
  }
  return result;
}

/// Converts `x` with a functional cast `To(x)`.
template To convert(From x) { return To(x); }

// Factory helpers: each builds one rule object bound to the module's
// int / float / bool types.

template auto intSingleRule(Module *m, Args &&...args) {
  return std::make_unique>(
      std::forward(args)..., m->getIntType());
}

auto intNoOp(Module *m, std::string magic) {
  return std::make_unique(std::move(magic), m->getIntType());
}

auto intDoubleApplyNoOp(Module *m, std::string magic) {
  return std::make_unique(std::move(magic), m->getIntType());
}

template auto intToIntBinary(Module *m, Func f, std::string magic) {
  return std::make_unique>(
      std::move(f), std::move(magic), m->getIntType(), m->getIntType());
}

template auto intToIntBinaryNoZeroRHS(Module *m, Func f, std::string magic) {
  return std::make_unique<
      DoubleConstantBinaryRuleExcludeRHSZero>(
      std::move(f), std::move(magic), m->getIntType(), m->getIntType());
}

template auto intToBoolBinary(Module *m, Func f, std::string magic) {
  return std::make_unique>(
      std::move(f), std::move(magic), m->getIntType(), m->getBoolType());
}

template auto boolToBoolBinary(Module *m, Func f, std::string magic) {
  return std::make_unique>(
      std::move(f), std::move(magic), m->getBoolType(), m->getBoolType());
}

template auto floatToFloatBinary(Module *m, Func f, std::string magic) {
  return std::make_unique>(
      std::move(f), std::move(magic), m->getFloatType(), m->getFloatType());
}

template
auto floatToFloatBinaryNoZeroRHS(Module *m, Func f, std::string magic) {
  return std::make_unique>(
      std::move(f), std::move(magic), m->getFloatType(), m->getFloatType());
}

template auto floatToBoolBinary(Module *m, Func f, std::string magic) {
  return std::make_unique>(
      std::move(f), std::move(magic), m->getFloatType(), m->getBoolType());
}

template
auto intFloatToFloatBinary(Module *m, Func f, std::string magic,
                           bool excludeRHSZero = false) {
  return std::make_unique>(
      std::move(f), std::move(magic), m->getFloatType(), excludeRHSZero);
}

template auto intFloatToBoolBinary(Module *m, Func f, std::string magic) {
  return std::make_unique>(
      std::move(f), std::move(magic), m->getBoolType());
}

template auto intToIntUnary(Module *m, Func f, std::string magic) {
  return std::make_unique>(
      std::move(f), std::move(magic), m->getIntType(), m->getIntType());
}

template auto floatToFloatUnary(Module *m, Func f, std::string magic) {
  return std::make_unique>(
      std::move(f), std::move(magic), m->getFloatType(), m->getFloatType());
}

template auto boolToBoolUnary(Module *m, Func f, std::string magic) {
  return std::make_unique>(
      std::move(f), std::move(magic), m->getBoolType(), m->getBoolType());
}

auto identityConvert(Module *m, std::string magic, types::Type *type) {
  return std::make_unique>(id_val(m), std::move(magic), type);
}

template
auto typeConvert(Module *m, std::string magic, types::Type *fromType,
                 types::Type *toType) {
  return std::make_unique<
      SingleConstantUnaryRule)>>>(
      convert, std::move(magic), fromType, toType);
}
} // namespace

const std::string FoldingPass::KEY = "core-folding-const-fold";

/// Registers the standard rules, resets the rewriter, then runs the
/// operator pass over the module.
void FoldingPass::run(Module *m) {
  registerStandardRules(m);
  Rewriter::reset();
  OperatorPass::run(m);
}

/// Hands every call instruction to the rewriter.
void FoldingPass::handle(CallInstr *v) { rewrite(v); }

/// Registers every built-in folding rule under its string key.
/// When `pyNumerics` is set, integer floor division and modulo are folded
/// with pyDivmod() and float true division is not folded for a zero RHS.
/// @param m the module whose int/float/bool types the rules are bound to
void FoldingPass::registerStandardRules(Module *m) {
  // binary, single constant, int->int
  using Kind = SingleConstantCommutativeRule::Kind;
  registerRule("int-multiply-by-zero",
               intSingleRule(m, 0, 0, Module::MUL_MAGIC_NAME, Kind::COMMUTATIVE));
  registerRule(
      "int-multiply-by-one",
      intSingleRule(m, 1, id_val(m), Module::MUL_MAGIC_NAME, Kind::COMMUTATIVE));
  registerRule("int-subtract-zero",
               intSingleRule(m, 0, id_val(m), Module::SUB_MAGIC_NAME, Kind::RIGHT));
  registerRule("int-add-zero", intSingleRule(m, 0, id_val(m), Module::ADD_MAGIC_NAME,
                                             Kind::COMMUTATIVE));
  registerRule(
      "int-floor-div-by-one",
      intSingleRule(m, 1, id_val(m), Module::FLOOR_DIV_MAGIC_NAME, Kind::RIGHT));
  registerRule("int-zero-floor-div",
               intSingleRule(m, 0, 0, Module::FLOOR_DIV_MAGIC_NAME, Kind::LEFT));
  registerRule("int-pos", intNoOp(m, Module::POS_MAGIC_NAME));
  registerRule("int-double-neg", intDoubleApplyNoOp(m, Module::NEG_MAGIC_NAME));
  registerRule("int-double-inv", intDoubleApplyNoOp(m, Module::INVERT_MAGIC_NAME));

  // binary, double constant, int->int
  registerRule("int-constant-addition",
               intToIntBinary(m, BINOP(+), Module::ADD_MAGIC_NAME));
  registerRule("int-constant-subtraction",
               intToIntBinary(m, BINOP(-), Module::SUB_MAGIC_NAME));
  if (pyNumerics) {
    registerRule("int-constant-floor-div",
                 intToIntBinaryNoZeroRHS(
                     m, [](auto x, auto y) -> auto { return pyDivmod(x, y).first; },
                     Module::FLOOR_DIV_MAGIC_NAME));
  } else {
    registerRule("int-constant-floor-div",
                 intToIntBinaryNoZeroRHS(m, BINOP(/), Module::FLOOR_DIV_MAGIC_NAME));
  }
  registerRule("int-constant-mul", intToIntBinary(m, BINOP(*), Module::MUL_MAGIC_NAME));
  registerRule("int-constant-lshift",
               intToIntBinary(m, BINOP(<<), Module::LSHIFT_MAGIC_NAME));
  registerRule("int-constant-rshift",
               intToIntBinary(m, BINOP(>>), Module::RSHIFT_MAGIC_NAME));
  registerRule("int-constant-pow", intToIntBinary(m, int_pow, Module::POW_MAGIC_NAME));
  registerRule("int-constant-xor", intToIntBinary(m, BINOP(^), Module::XOR_MAGIC_NAME));
  registerRule("int-constant-or", intToIntBinary(m, BINOP(|), Module::OR_MAGIC_NAME));
  registerRule("int-constant-and", intToIntBinary(m, BINOP(&), Module::AND_MAGIC_NAME));
  if (pyNumerics) {
    registerRule("int-constant-mod",
                 intToIntBinaryNoZeroRHS(
                     m, [](auto x, auto y) -> auto { return pyDivmod(x, y).second; },
                     Module::MOD_MAGIC_NAME));
  } else {
    registerRule("int-constant-mod",
                 intToIntBinaryNoZeroRHS(m, BINOP(%), Module::MOD_MAGIC_NAME));
  }

  // binary, double constant, int->bool
  registerRule("int-constant-eq", intToBoolBinary(m, BINOP(==), Module::EQ_MAGIC_NAME));
  registerRule("int-constant-ne", intToBoolBinary(m, BINOP(!=), Module::NE_MAGIC_NAME));
  registerRule("int-constant-gt", intToBoolBinary(m, BINOP(>), Module::GT_MAGIC_NAME));
  registerRule("int-constant-ge", intToBoolBinary(m, BINOP(>=), Module::GE_MAGIC_NAME));
  registerRule("int-constant-lt", intToBoolBinary(m, BINOP(<), Module::LT_MAGIC_NAME));
  registerRule("int-constant-le", intToBoolBinary(m, BINOP(<=), Module::LE_MAGIC_NAME));

  // binary, double constant, bool->bool
  registerRule("bool-constant-xor",
               boolToBoolBinary(m, BINOP(^), Module::XOR_MAGIC_NAME));
  registerRule("bool-constant-or", boolToBoolBinary(m, BINOP(|), Module::OR_MAGIC_NAME));
  registerRule("bool-constant-and",
               boolToBoolBinary(m, BINOP(&), Module::AND_MAGIC_NAME));

  // unary, single constant, int->int
  registerRule("int-constant-pos", intToIntUnary(m, UNOP(+), Module::POS_MAGIC_NAME));
  registerRule("int-constant-neg", intToIntUnary(m, UNOP(-), Module::NEG_MAGIC_NAME));
  registerRule("int-constant-inv", intToIntUnary(m, UNOP(~), Module::INVERT_MAGIC_NAME));

  // unary, single constant, float->float
  registerRule("float-constant-pos",
               floatToFloatUnary(m, UNOP(+), Module::POS_MAGIC_NAME));
  registerRule("float-constant-neg",
               floatToFloatUnary(m, UNOP(-), Module::NEG_MAGIC_NAME));

  // unary, single constant, bool->bool
  registerRule("bool-constant-inv",
               boolToBoolUnary(m, UNOP(!), Module::INVERT_MAGIC_NAME));

  // binary, double constant, float->float
  registerRule("float-constant-addition",
               floatToFloatBinary(m, BINOP(+), Module::ADD_MAGIC_NAME));
  registerRule("float-constant-subtraction",
               floatToFloatBinary(m, BINOP(-), Module::SUB_MAGIC_NAME));
  // NOTE(review): the key says "floor-div" but the rule is registered on
  // TRUE_DIV_MAGIC_NAME - presumably a naming leftover; confirm before
  // renaming, since the key is a runtime string.
  if (pyNumerics) {
    registerRule("float-constant-floor-div",
                 floatToFloatBinaryNoZeroRHS(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME));
  } else {
    registerRule("float-constant-floor-div",
                 floatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME));
  }
  registerRule("float-constant-mul",
               floatToFloatBinary(m, BINOP(*), Module::MUL_MAGIC_NAME));
  registerRule(
      "float-constant-pow",
      floatToFloatBinary(
          m, [](auto a, auto b) { return std::pow(a, b); }, Module::POW_MAGIC_NAME));

  // binary, double constant, float->bool
  registerRule("float-constant-eq",
               floatToBoolBinary(m, BINOP(==), Module::EQ_MAGIC_NAME));
  registerRule("float-constant-ne",
               floatToBoolBinary(m, BINOP(!=), Module::NE_MAGIC_NAME));
  registerRule("float-constant-gt",
               floatToBoolBinary(m, BINOP(>), Module::GT_MAGIC_NAME));
  registerRule("float-constant-ge",
               floatToBoolBinary(m, BINOP(>=), Module::GE_MAGIC_NAME));
  registerRule("float-constant-lt",
               floatToBoolBinary(m, BINOP(<), Module::LT_MAGIC_NAME));
  registerRule("float-constant-le",
               floatToBoolBinary(m, BINOP(<=), Module::LE_MAGIC_NAME));

  // binary, double constant, int,float->float
  registerRule("int-float-constant-addition",
               intFloatToFloatBinary(m, BINOP(+), Module::ADD_MAGIC_NAME));
  registerRule("int-float-constant-subtraction",
               intFloatToFloatBinary(m, BINOP(-), Module::SUB_MAGIC_NAME));
  // pyNumerics is passed as excludeRHSZero: a zero RHS is only left unfolded
  // under Python numerics.
  registerRule(
      "int-float-constant-floor-div",
      intFloatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME, pyNumerics));
  registerRule("int-float-constant-mul",
               intFloatToFloatBinary(m, BINOP(*), Module::MUL_MAGIC_NAME));

  // binary, double constant, int,float->bool
  registerRule("int-float-constant-eq",
               intFloatToBoolBinary(m, BINOP(==), Module::EQ_MAGIC_NAME));
  registerRule("int-float-constant-ne",
               intFloatToBoolBinary(m, BINOP(!=), Module::NE_MAGIC_NAME));
  registerRule("int-float-constant-gt",
               intFloatToBoolBinary(m, BINOP(>), Module::GT_MAGIC_NAME));
  registerRule("int-float-constant-ge",
               intFloatToBoolBinary(m, BINOP(>=), Module::GE_MAGIC_NAME));
  registerRule("int-float-constant-lt",
               intFloatToBoolBinary(m, BINOP(<), Module::LT_MAGIC_NAME));
  registerRule("int-float-constant-le",
               intFloatToBoolBinary(m, BINOP(<=), Module::LE_MAGIC_NAME));

  // type conversions, identity
  registerRule("int-constant-int",
               identityConvert(m, Module::INT_MAGIC_NAME, m->getIntType()));
  registerRule("float-constant-float",
               identityConvert(m, Module::FLOAT_MAGIC_NAME, m->getFloatType()));
  registerRule("bool-constant-bool",
               identityConvert(m, Module::BOOL_MAGIC_NAME, m->getBoolType()));

  // type conversions, distinct
  registerRule("float-constant-int",
               typeConvert(m, Module::INT_MAGIC_NAME, m->getFloatType(),
                           m->getIntType()));
  registerRule("bool-constant-int",
               typeConvert(m, Module::INT_MAGIC_NAME, m->getBoolType(),
                           m->getIntType()));
  registerRule("int-constant-float",
               typeConvert(m, Module::FLOAT_MAGIC_NAME, m->getIntType(),
                           m->getFloatType()));
  registerRule("bool-constant-float",
               typeConvert(m, Module::FLOAT_MAGIC_NAME, m->getBoolType(),
                           m->getFloatType()));
  registerRule("int-constant-bool",
               typeConvert(m, Module::BOOL_MAGIC_NAME, m->getIntType(),
                           m->getBoolType()));
  registerRule("float-constant-bool",
               typeConvert(m, Module::BOOL_MAGIC_NAME, m->getFloatType(),
                           m->getBoolType()));
}

} // namespace folding
} // namespace transform
} // namespace ir
} // namespace codon

#undef BINOP
#undef UNOP

================================================
FILE: codon/cir/transform/folding/const_fold.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once #include #include #include "codon/cir/transform/folding/rule.h" #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace transform { namespace folding { class FoldingPass : public OperatorPass, public Rewriter { private: bool pyNumerics; void registerStandardRules(Module *m); public: /// Constructs a folding pass. FoldingPass(bool pyNumerics = false) : OperatorPass(/*childrenFirst=*/true), pyNumerics(pyNumerics) {} static const std::string KEY; std::string getKey() const override { return KEY; } void run(Module *m) override; void handle(CallInstr *v) override; }; } // namespace folding } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/folding/const_prop.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "const_prop.h" #include "codon/cir/analyze/dataflow/reaching.h" #include "codon/cir/analyze/module/global_vars.h" #include "codon/cir/util/cloning.h" #include "codon/cir/util/irtools.h" namespace codon { namespace ir { namespace transform { namespace folding { namespace { bool okConst(const Value *v) { return v && (isA(v) || isA(v) || isA(v)); } } // namespace const std::string ConstPropPass::KEY = "core-folding-const-prop"; void ConstPropPass::handle(VarValue *v) { auto *M = v->getModule(); auto *var = v->getVar(); if (var->isThreadLocal()) return; Value *replacement; if (var->isGlobal()) { auto *r = getAnalysisResult(globalVarsKey); if (!r) return; auto it = r->assignments.find(var->getId()); if (it == r->assignments.end()) return; auto *constDef = M->getValue(it->second); if (!okConst(constDef)) return; util::CloneVisitor cv(M); replacement = cv.clone(constDef); } else { auto *r = getAnalysisResult(reachingDefKey); if (!r) return; auto *c = r->cfgResult; auto it = r->results.find(getParentFunc()->getId()); if (it == r->results.end()) return; auto *rd = it->second.get(); auto reaching = 
rd->getReachingDefinitions(var, v); if (reaching.size() != 1 || !reaching[0].known()) return; auto *constDef = reaching[0].assignee; if (!okConst(constDef)) return; util::CloneVisitor cv(M); replacement = cv.clone(constDef); } v->replaceAll(replacement); } void ConstPropPass::handle(ExtractInstr *v) { // Propagate constant tuples if (!isA(v->getVal()) || !isA(v->getVal()->getType())) return; auto *var = cast(v->getVal())->getVar(); if (var->isGlobal()) return; auto *r = getAnalysisResult(reachingDefKey); if (!r) return; auto *c = r->cfgResult; auto it = r->results.find(getParentFunc()->getId()); if (it == r->results.end()) return; auto *rd = it->second.get(); auto reaching = rd->getReachingDefinitions(var, v); if (reaching.size() != 1 || !reaching[0].known()) return; auto *call = cast(reaching[0].assignee); if (!call) return; auto *func = util::getFunc(call->getCallee()); if (!func || func->getUnmangledName() != Module::NEW_MAGIC_NAME) return; auto *tuple = cast(func->getParentType()); if (!tuple || tuple->getName() != "Tuple") return; auto idx = cast(v->getVal()->getType())->getMemberIndex(v->getField()); if (idx < 0 || idx >= call->numArgs() || !okConst(*(call->begin() + idx))) return; util::CloneVisitor cv(v->getModule()); auto *replacement = cv.clone(*(call->begin() + idx)); v->replaceAll(replacement); } } // namespace folding } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/folding/const_prop.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace transform { namespace folding { /// Constant propagation pass. 
class ConstPropPass : public OperatorPass { private: /// Key of the reaching definition analysis std::string reachingDefKey; /// Key of the global variables analysis std::string globalVarsKey; public: static const std::string KEY; /// Constructs a constant propagation pass. /// @param reachingDefKey the reaching definition analysis' key /// @param globalVarsKey global variables analysis' key ConstPropPass(const std::string &reachingDefKey, const std::string &globalVarsKey) : reachingDefKey(reachingDefKey), globalVarsKey(globalVarsKey) {} std::string getKey() const override { return KEY; } void handle(VarValue *v) override; void handle(ExtractInstr *v) override; }; } // namespace folding } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/folding/folding.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "folding.h" #include "codon/cir/transform/folding/const_fold.h" #include "codon/cir/transform/folding/const_prop.h" namespace codon { namespace ir { namespace transform { namespace folding { const std::string FoldingPassGroup::KEY = "core-folding-pass-group"; FoldingPassGroup::FoldingPassGroup(const std::string &sideEffectsPass, const std::string &reachingDefPass, const std::string &globalVarPass, int repeat, bool runGlobalDemotion, bool pyNumerics) : PassGroup(repeat) { auto gdUnique = runGlobalDemotion ? 
std::make_unique() : std::unique_ptr(); auto canonUnique = std::make_unique(sideEffectsPass); auto fpUnique = std::make_unique(pyNumerics); auto dceUnique = std::make_unique(sideEffectsPass); gd = gdUnique.get(); canon = canonUnique.get(); fp = fpUnique.get(); dce = dceUnique.get(); if (runGlobalDemotion) push_back(std::move(gdUnique)); push_back(std::make_unique(reachingDefPass, globalVarPass)); push_back(std::move(canonUnique)); push_back(std::move(fpUnique)); push_back(std::move(dceUnique)); } bool FoldingPassGroup::shouldRepeat(int num) const { return PassGroup::shouldRepeat(num) && ((gd && gd->getNumDemotions() != 0) || canon->getNumReplacements() != 0 || fp->getNumReplacements() != 0 || dce->getNumReplacements() != 0); } } // namespace folding } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/folding/folding.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/transform/cleanup/canonical.h" #include "codon/cir/transform/cleanup/dead_code.h" #include "codon/cir/transform/cleanup/global_demote.h" #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace transform { namespace folding { class FoldingPass; /// Group of constant folding passes. class FoldingPassGroup : public PassGroup { private: cleanup::GlobalDemotionPass *gd; cleanup::CanonicalizationPass *canon; FoldingPass *fp; cleanup::DeadCodeCleanupPass *dce; public: static const std::string KEY; std::string getKey() const override { return KEY; } /// @param sideEffectsPass the key of the side effects pass /// @param reachingDefPass the key of the reaching definitions pass /// @param globalVarPass the key of the global variables pass /// @param repeat default number of times to repeat the pass /// @param runGlobalDemotion whether to demote globals if possible /// @param pyNumerics whether to use Python (vs. 
C) semantics when folding FoldingPassGroup(const std::string &sideEffectsPass, const std::string &reachingDefPass, const std::string &globalVarPass, int repeat = 5, bool runGlobalDemotion = true, bool pyNumerics = false); bool shouldRepeat(int num) const override; }; } // namespace folding } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/folding/rule.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include "codon/cir/transform/pass.h" #include "codon/cir/transform/rewrite.h" #include "codon/cir/util/irtools.h" namespace codon { namespace ir { namespace transform { namespace folding { /// Commutative, binary rule that requires a single constant. template class SingleConstantCommutativeRule : public RewriteRule { public: using Calculator = std::function; enum Kind { LEFT, RIGHT, COMMUTATIVE }; private: /// the value being matched against ConstantType val; /// the type being matched types::Type *type; /// the magic method name std::string magic; /// the calculator Calculator calc; /// left, right or commutative Kind kind; public: /// Constructs a commutative rule. /// @param val the matched value /// @param newVal the result /// @param magic the magic name /// @param kind left, right, or commutative /// @param type the matched type SingleConstantCommutativeRule(ConstantType val, ConstantType newVal, std::string magic, Kind kind, types::Type *type) : val(val), type(type), magic(std::move(magic)), kind(kind) { calc = [=](Value *v) -> Value * { return v->getModule()->N>(v->getSrcInfo(), val, type); }; } /// Constructs a commutative rule. 
/// @param val the matched value /// @param newVal the result /// @param magic the magic name /// @param calc the calculator /// @param kind left, right, or commutative /// @param type the matched type SingleConstantCommutativeRule(ConstantType val, Calculator calc, std::string magic, Kind kind, types::Type *type) : val(val), type(type), magic(std::move(magic)), calc(std::move(calc)), kind(kind) {} virtual ~SingleConstantCommutativeRule() noexcept = default; void visit(CallInstr *v) override { if (!util::isCallOf(v, magic, {type, type}, type, /*method=*/true)) return; auto *left = v->front(); auto *right = v->back(); auto *leftConst = cast>(left); auto *rightConst = cast>(right); if ((kind == Kind::COMMUTATIVE || kind == Kind::LEFT) && leftConst && leftConst->getVal() == val) return setResult(calc(right)); if ((kind == Kind::COMMUTATIVE || kind == Kind::RIGHT) && rightConst && rightConst->getVal() == val) return setResult(calc(left)); } }; /// Binary rule that requires two constants. template class DoubleConstantBinaryRule : public RewriteRule { private: /// the calculator Func f; /// the input type types::Type *inputType; /// the output type types::Type *resultType; /// the magic method name std::string magic; public: /// Constructs a binary rule. 
/// @param f the calculator /// @param magic the magic method name /// @param inputType the input type /// @param resultType the output type DoubleConstantBinaryRule(Func f, std::string magic, types::Type *inputType, types::Type *resultType) : f(std::move(f)), inputType(inputType), resultType(resultType), magic(std::move(magic)) {} virtual ~DoubleConstantBinaryRule() noexcept = default; void visit(CallInstr *v) override { if (!util::isCallOf(v, magic, {inputType, inputType}, resultType, /*method=*/true)) return; auto *left = v->front(); auto *right = v->back(); auto *leftConst = cast>(left); auto *rightConst = cast>(right); if (leftConst && rightConst) return setResult(toValue(v, f(leftConst->getVal(), rightConst->getVal()))); } private: Value *toValue(Value *, Value *v) { return v; } Value *toValue(Value *og, OutputType v) { return og->getModule()->template N>(og->getSrcInfo(), v, resultType); } }; /// Unary rule that requires one constant. template class SingleConstantUnaryRule : public RewriteRule { private: /// the calculator Func f; /// the input type types::Type *inputType; /// the output type types::Type *resultType; /// the magic method name std::string magic; public: /// Constructs a unary rule. 
/// @param f the calculator /// @param magic the magic method name /// @param inputType the input type /// @param resultType the output type SingleConstantUnaryRule(Func f, std::string magic, types::Type *inputType, types::Type *resultType) : f(std::move(f)), inputType(inputType), resultType(resultType), magic(std::move(magic)) {} virtual ~SingleConstantUnaryRule() noexcept = default; void visit(CallInstr *v) override { if (!util::isCallOf(v, magic, {inputType}, resultType, /*method=*/true)) return; auto *arg = v->front(); auto *argConst = cast>(arg); if (argConst) return setResult(toValue(v, f(argConst->getVal()))); } private: Value *toValue(Value *, Value *v) { return v; } template Value *toValue(Value *og, NewType v) { return og->getModule()->template N>(og->getSrcInfo(), v, resultType); } }; /// Unary rule that requires no constant. template class UnaryRule : public RewriteRule { private: /// the calculator Func f; /// the input type types::Type *inputType; /// the magic method name std::string magic; public: /// Constructs a unary rule. /// @param f the calculator /// @param magic the magic method name /// @param inputType the input type UnaryRule(Func f, std::string magic, types::Type *inputType) : f(std::move(f)), inputType(inputType), magic(std::move(magic)) {} virtual ~UnaryRule() noexcept = default; void visit(CallInstr *v) override { if (!util::isCallOf(v, magic, {inputType}, inputType, /*method=*/true)) return; auto *arg = v->front(); return setResult(f(arg)); } }; /// Rule that eliminates an operation, like "+x". class NoOpRule : public RewriteRule { private: /// the input type types::Type *inputType; /// the magic method name std::string magic; public: /// Constructs a no-op rule. 
/// @param magic the magic method name /// @param inputType the input type NoOpRule(std::string magic, types::Type *inputType) : inputType(inputType), magic(std::move(magic)) {} virtual ~NoOpRule() noexcept = default; void visit(CallInstr *v) override { if (!util::isCallOf(v, magic, {inputType}, inputType, /*method=*/true)) return; auto *arg = v->front(); return setResult(arg); } }; /// Rule that eliminates a double-application of an operation, like "-(-x)". class DoubleApplicationNoOpRule : public RewriteRule { private: /// the input type types::Type *inputType; /// the magic method name std::string magic; public: /// Constructs a double-application no-op rule. /// @param magic the magic method name /// @param inputType the input type DoubleApplicationNoOpRule(std::string magic, types::Type *inputType) : inputType(inputType), magic(std::move(magic)) {} virtual ~DoubleApplicationNoOpRule() noexcept = default; void visit(CallInstr *v) override { if (!util::isCallOf(v, magic, {inputType}, inputType, /*method=*/true)) return; if (!util::isCallOf(v->front(), magic, {inputType}, inputType, /*method=*/true)) return; auto *arg = v->front(); return setResult(cast(arg)->front()); } }; } // namespace folding } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/lowering/async_for.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "async_for.h"

#include "codon/cir/util/cloning.h"
#include "codon/cir/util/irtools.h"

namespace codon {
namespace ir {
namespace transform {
namespace lowering {

const std::string AsyncForLowering::KEY = "core-async-for-lowering";

/// Rewrites an async for-loop into an explicit loop that resumes the iterated
/// coroutine and yields while the current task is waiting. Loops for which
/// isAsync() is false are left untouched.
///
/// NOTE(review): the template arguments of the `M->Nr(...)` / `cast(...)`
/// calls below appear to be missing from this listing; tokens are kept as
/// found -- confirm the node types against the upstream file.
/// @param v the for-loop being visited
void AsyncForLowering::handle(ForFlow *v) {
  if (!v->isAsync())
    return;

  auto *M = v->getModule();
  auto *coro = v->getIter();
  auto *coroType = coro->getType();
  util::CloneVisitor cv(M);

  // Helper functions looked up in "std.asyncio"; each one is required.
  auto *promise = M->getOrRealizeFunc("_promise", {coroType}, {}, "std.asyncio");
  seqassertn(promise, "promise func not found");

  auto *done = M->getOrRealizeFunc("_done", {coroType}, {}, "std.asyncio");
  seqassertn(done, "done func not found");

  auto *resume = M->getOrRealizeFunc("_resume", {coroType}, {}, "std.asyncio");
  seqassertn(resume, "resume func not found");

  auto *currTask = M->getOrRealizeFunc("_curr_task", {}, {}, "std.asyncio");
  seqassertn(currTask, "curr-task func not found");

  // `_is_waiting` is realized for whatever type `_curr_task` returns.
  auto *waiting = M->getOrRealizeFunc(
      "_is_waiting", {util::getReturnType(currTask)}, {}, "std.asyncio");
  seqassertn(waiting, "is-waiting func not found");

  // Construct the following:
  //   task = curr_task()
  //   if not coro.__done__():
  //     while True:
  //       coro.__resume__()
  //       if coro.__done__():
  //         break
  //       if is_waiting(task):
  //         yield
  //       else:
  //         i = coro.__promise__()
  //         <body>
  auto *series = M->Nr();
  auto *taskVar =
      util::makeVar(util::call(currTask, {}), series, cast(getParentFunc()));
  // The iterator expression is cloned and stored in a variable so it can be
  // referenced several times below.
  auto *coroVar = util::makeVar(cv.clone(coro), series, cast(getParentFunc()));

  series->push_back(M->Nr(
      ~*util::call(done, {M->Nr(coroVar)}),
      util::series(M->Nr(
          M->getBool(true),
          util::series(
              util::call(resume, {M->Nr(coroVar)}),
              M->Nr(util::call(done, {M->Nr(coroVar)}), util::series(M->Nr())),
              M->Nr(
                  util::call(waiting, {M->Nr(taskVar)}),
                  util::series(M->Nr()),
                  util::series(
                      M->Nr(
                          v->getVar(), util::call(promise, {M->Nr(coroVar)})),
                      cv.clone(v->getBody()))))))));

  v->replaceAll(series);
}

} // namespace lowering
} // namespace transform
} // namespace ir
} // namespace codon
================================================
FILE: codon/cir/transform/lowering/async_for.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include "codon/cir/transform/pass.h"

namespace codon {
namespace ir {
namespace transform {
namespace lowering {

/// Operator pass whose handle() rewrites async for-loops (see async_for.cpp).
class AsyncForLowering : public OperatorPass {
public:
  /// key under which this pass is registered
  static const std::string KEY;
  std::string getKey() const override { return KEY; }
  void handle(ForFlow *v) override;
};

} // namespace lowering
} // namespace transform
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/transform/lowering/await.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "await.h"

// NOTE(review): the header name of the next include (and the template
// arguments of `cast`, `M->Nr`, ... further down) appear to be missing from
// this listing; tokens are kept as found.
#include

#include "codon/cir/util/cloning.h"
#include "codon/cir/util/irtools.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

namespace codon {
namespace ir {
namespace transform {
namespace lowering {
namespace {
/// True if the type's name starts with "std.asyncio.Future.".
bool isFuture(const types::Type *type) {
  return type->getName().rfind("std.asyncio.Future.", 0) == 0;
}

/// True if the type's name starts with "std.asyncio.Task.".
bool isTask(const types::Type *type) {
  return type->getName().rfind("std.asyncio.Task.", 0) == 0;
}

/// Returns `type` cast to a generator type, or null if the cast fails.
const types::GeneratorType *isCoroutine(const types::Type *type) {
  return cast(type);
}
} // namespace

const std::string AwaitLowering::KEY = "core-await-lowering";

/// Replaces an await instruction with explicit IR, depending on the type of
/// the awaited value:
///  - Future/Task (decided by type-name prefix): wait on it, yielding if
///    `_wait_on` says so, then read its result;
///  - generator type: drive it with a for-loop that yields on each step, then
///    read its promise;
///  - anything else trips an assertion.
/// @param v the await instruction being visited
void AwaitLowering::handle(AwaitInstr *v) {
  auto *M = v->getModule();
  auto *value = v->getValue();
  auto *resultType = v->getType(); // NOTE(review): not used below
  auto *valueType = value->getType();
  util::CloneVisitor cv(M);

  if (isFuture(valueType) || isTask(valueType)) {
    auto *getResult = M->getOrRealizeMethod(valueType, "result", {valueType});
    seqassertn(getResult, "get-result method not found");

    auto *waitOn = M->getOrRealizeFunc("_wait_on", {valueType}, {}, "std.asyncio");
    seqassertn(waitOn, "wait-on function not found");

    auto *cancelCheck =
        M->getOrRealizeFunc("_cancel_checkpoint", {}, {}, "std.asyncio");
    seqassertn(cancelCheck, "cancel-checkpoint function not found");

    // Construct the following:
    //   cancel_checkpoint()
    //   if _wait_on(future):
    //     yield
    //     cancel_checkpoint()
    //   future.result()
    auto *series = M->Nr();
    auto *futureVar = util::makeVar(cv.clone(value), series, cast(getParentFunc()));
    series->push_back(util::call(cancelCheck, {}));
    series->push_back(
        M->Nr(util::call(waitOn, {M->Nr(futureVar)}),
              util::series(M->Nr(), util::call(cancelCheck, {}))));
    auto *replacement = M->Nr(series, util::call(getResult, {M->Nr(futureVar)}));
    v->replaceAll(replacement);
  } else if (auto *genType = isCoroutine(valueType)) {
    auto *promise = M->getOrRealizeFunc("_promise", {valueType}, {}, "std.asyncio");
    seqassertn(promise, "promise function not found");

    // Construct the following:
    //   for _ in coro:
    //     yield
    //   coro.__promise__()
    auto *series = M->Nr();
    auto *coroVar = util::makeVar(cv.clone(value), series, cast(getParentFunc()));
    // Loop variable of the generator's base type; it is registered with the
    // parent function but its value is never read here.
    auto *var = M->Nr(genType->getBase(), /*global=*/false);
    cast(getParentFunc())->push_back(var);

    SeriesFlow *body;
    if (v->isGenerator()) {
      // In this case the loop body calls `_requeue` before its second
      // statement.
      auto *requeue = M->getOrRealizeFunc("_requeue", {}, {}, "std.asyncio");
      seqassertn(requeue, "requeue function not found");
      body = util::series(util::call(requeue, {}), M->Nr());
    } else {
      body = util::series(M->Nr());
    }

    series->push_back(M->Nr(M->Nr(coroVar), body, var));
    auto *replacement = M->Nr(series, util::call(promise, {M->Nr(coroVar)}));
    v->replaceAll(replacement);
  } else {
    seqassertn(false, "unexpected value type '{}' in await instruction",
               valueType->getName());
  }
}

} // namespace lowering
} // namespace transform
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/transform/lowering/await.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace transform { namespace lowering { class AwaitLowering : public OperatorPass { public: static const std::string KEY; std::string getKey() const override { return KEY; } void handle(AwaitInstr *v) override; }; } // namespace lowering } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/lowering/imperative.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "imperative.h" #include #include "codon/cir/util/cloning.h" #include "codon/cir/util/irtools.h" #include "codon/cir/util/matching.h" namespace codon { namespace ir { namespace transform { namespace lowering { namespace { CallInstr *getRangeIter(Value *iter) { auto *M = iter->getModule(); auto *iterCall = cast(iter); if (!iterCall || iterCall->numArgs() != 1) return nullptr; auto *iterFunc = util::getFunc(iterCall->getCallee()); if (!iterFunc || iterFunc->getUnmangledName() != Module::ITER_MAGIC_NAME) return nullptr; auto *rangeCall = cast(iterCall->front()); if (!rangeCall) return nullptr; auto *newRangeFunc = util::getFunc(rangeCall->getCallee()); if (!newRangeFunc || newRangeFunc->getUnmangledName() != Module::NEW_MAGIC_NAME) return nullptr; auto *parentType = newRangeFunc->getParentType(); auto *rangeType = M->getOrRealizeType(ast::getMangledClass("std.internal.types.range", "range")); if (!parentType || !rangeType || parentType->getName() != rangeType->getName()) return nullptr; return rangeCall; } Value *getListIter(Value *iter) { auto *iterCall = cast(iter); if (!iterCall || iterCall->numArgs() != 1) return nullptr; auto *iterFunc = util::getFunc(iterCall->getCallee()); if (!iterFunc || iterFunc->getUnmangledName() != Module::ITER_MAGIC_NAME) return nullptr; auto *list = iterCall->front(); if (list->getType()->getName().rfind( ast::getMangledClass("std.internal.types.array", "List") + 
"[", 0) != 0) return nullptr; return list; } } // namespace const std::string ImperativeForFlowLowering::KEY = "core-imperative-for-lowering"; void ImperativeForFlowLowering::handle(ForFlow *v) { auto *M = v->getModule(); auto *iter = v->getIter(); std::unique_ptr sched; if (v->isParallel()) sched = std::make_unique(*v->getSchedule()); if (auto *rangeCall = getRangeIter(iter)) { auto it = rangeCall->begin(); auto argCount = std::distance(it, rangeCall->end()); util::CloneVisitor cv(M); IntConst *stepConst; Value *start; Value *end; int64_t step = 0; switch (argCount) { case 1: start = M->getInt(0); end = cv.clone(*it); step = 1; break; case 2: start = cv.clone(*it++); end = cv.clone(*it); step = 1; break; case 3: start = cv.clone(*it++); end = cv.clone(*it++); stepConst = cast(*it); if (!stepConst) return; step = stepConst->getVal(); break; default: seqassertn(false, "unknown range constructor"); } if (step == 0) return; v->replaceAll(M->N(v->getSrcInfo(), start, step, end, v->getBody(), v->getVar(), std::move(sched))); } else if (auto *list = getListIter(iter)) { // convert: // for a in list: // body // into: // v = list // n = v.len // p = v.arr.ptr // imp_for i in range(0, n, 1): // a = p[i] // body auto *parent = cast(getParentFunc()); auto *series = M->N(v->getSrcInfo()); auto *listVar = util::makeVar(list, series, parent); auto *lenVal = M->Nr(M->Nr(listVar), "len"); auto *lenVar = util::makeVar(lenVal, series, parent); auto *ptrVal = M->Nr( M->Nr(M->Nr(listVar), "arr"), "ptr"); auto *ptrVar = util::makeVar(ptrVal, series, parent); auto *body = cast(v->getBody()); seqassertn(body, "loop body is not a series flow [{}]", v->getSrcInfo()); auto *oldLoopVar = v->getVar(); auto *newLoopVar = M->Nr(M->getIntType()); parent->push_back(newLoopVar); auto *replacement = M->N(v->getSrcInfo(), M->getInt(0), 1, M->Nr(lenVar), body, newLoopVar, std::move(sched)); series->push_back(replacement); body->insert( body->begin(), M->Nr(oldLoopVar, 
(*M->Nr(ptrVar))[*M->Nr(newLoopVar)])); v->replaceAll(series); } } } // namespace lowering } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/lowering/imperative.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace transform { namespace lowering { class ImperativeForFlowLowering : public OperatorPass { public: static const std::string KEY; std::string getKey() const override { return KEY; } void handle(ForFlow *v) override; }; } // namespace lowering } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/lowering/pipeline.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "pipeline.h" #include #include "codon/cir/util/cloning.h" #include "codon/cir/util/irtools.h" #include "codon/cir/util/matching.h" namespace codon { namespace ir { namespace transform { namespace lowering { namespace { Value *callStage(Module *M, PipelineFlow::Stage *stage, Value *last) { std::vector args; for (auto *arg : *stage) { args.push_back(arg ? 
arg : last); } return M->N(stage->getCallee()->getSrcInfo(), stage->getCallee(), args); } Value *convertPipelineToForLoopsHelper(Module *M, BodiedFunc *parent, const std::vector &stages, unsigned idx = 0, Value *last = nullptr) { if (idx >= stages.size()) return last; auto *stage = stages[idx]; if (idx == 0) return convertPipelineToForLoopsHelper(M, parent, stages, idx + 1, stage->getCallee()); auto *prev = stages[idx - 1]; if (prev->isGenerator()) { auto *var = M->Nr(prev->getOutputElementType()); parent->push_back(var); auto *body = convertPipelineToForLoopsHelper( M, parent, stages, idx + 1, callStage(M, stage, M->Nr(var))); auto *loop = M->N(last->getSrcInfo(), last, util::series(body), var); if (stage->isParallel()) loop->setParallel(); return loop; } else { return convertPipelineToForLoopsHelper(M, parent, stages, idx + 1, callStage(M, stage, last)); } } Value *convertPipelineToForLoops(PipelineFlow *p, BodiedFunc *parent) { std::vector stages; for (auto &stage : *p) { stages.push_back(&stage); } return convertPipelineToForLoopsHelper(p->getModule(), parent, stages); } } // namespace const std::string PipelineLowering::KEY = "core-pipeline-lowering"; void PipelineLowering::handle(PipelineFlow *v) { v->replaceAll(convertPipelineToForLoops(v, cast(getParentFunc()))); } } // namespace lowering } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/lowering/pipeline.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace transform { namespace lowering { /// Converts pipelines to for-loops class PipelineLowering : public OperatorPass { public: static const std::string KEY; std::string getKey() const override { return KEY; } void handle(PipelineFlow *v) override; }; } // namespace lowering } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/manager.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "manager.h" #include #include "codon/cir/analyze/analysis.h" #include "codon/cir/analyze/dataflow/capture.h" #include "codon/cir/analyze/dataflow/cfg.h" #include "codon/cir/analyze/dataflow/dominator.h" #include "codon/cir/analyze/dataflow/reaching.h" #include "codon/cir/analyze/module/global_vars.h" #include "codon/cir/analyze/module/side_effect.h" #include "codon/cir/transform/folding/folding.h" #include "codon/cir/transform/lowering/async_for.h" #include "codon/cir/transform/lowering/await.h" #include "codon/cir/transform/lowering/imperative.h" #include "codon/cir/transform/lowering/pipeline.h" #include "codon/cir/transform/manager.h" #include "codon/cir/transform/numpy/indexing.h" #include "codon/cir/transform/numpy/numpy.h" #include "codon/cir/transform/parallel/openmp.h" #include "codon/cir/transform/pass.h" #include "codon/cir/transform/pythonic/dict.h" #include "codon/cir/transform/pythonic/generator.h" #include "codon/cir/transform/pythonic/io.h" #include "codon/cir/transform/pythonic/list.h" #include "codon/cir/transform/pythonic/str.h" #include "codon/util/common.h" namespace codon { namespace ir { namespace transform { std::string PassManager::KeyManager::getUniqueKey(const std::string &key) { // make sure we can't ever produce duplicate "unique'd" keys seqassertn(key.find(':') == std::string::npos, "pass key '{}' contains invalid character ':'", 
key); auto it = keys.find(key); if (it == keys.end()) { keys.emplace(key, 1); return key; } else { auto id = ++(it->second); return key + ":" + std::to_string(id); } } std::string PassManager::registerPass(std::unique_ptr pass, const std::string &insertBefore, std::vector reqs, std::vector invalidates) { std::string key = pass->getKey(); if (isDisabled(key)) return ""; key = km.getUniqueKey(key); for (const auto &req : reqs) { seqassertn(deps.find(req) != deps.end(), "required key '{}' not found", req); deps[req].push_back(key); } passes.insert(std::make_pair( key, PassMetadata(std::move(pass), std::move(reqs), std::move(invalidates)))); passes[key].pass->setManager(this); if (insertBefore.empty()) { executionOrder.push_back(key); } else { auto it = std::find(executionOrder.begin(), executionOrder.end(), insertBefore); seqassertn(it != executionOrder.end(), "pass with key '{}' not found in manager", insertBefore); executionOrder.insert(it, key); } return key; } std::string PassManager::registerAnalysis(std::unique_ptr analysis, std::vector reqs) { std::string key = analysis->getKey(); if (isDisabled(key)) return ""; key = km.getUniqueKey(key); for (const auto &req : reqs) { seqassertn(deps.find(req) != deps.end(), "required key '{}' not found", req); deps[req].push_back(key); } analyses.insert( std::make_pair(key, AnalysisMetadata(std::move(analysis), std::move(reqs)))); analyses[key].analysis->setManager(this); deps[key] = {}; return key; } void PassManager::run(Module *module) { for (auto &p : executionOrder) { runPass(module, p); } } void PassManager::runPass(Module *module, const std::string &name) { auto &meta = passes[name]; auto run = true; auto it = 0; while (run) { for (auto &dep : meta.reqs) { runAnalysis(module, dep); } Timer timer(" ir pass : " + meta.pass->getKey()); meta.pass->run(module); timer.log(); for (auto &inv : meta.invalidates) invalidate(inv); run = meta.pass->shouldRepeat(++it); } } void PassManager::runAnalysis(Module *module, const 
std::string &name) { if (results.find(name) != results.end()) return; auto &meta = analyses[name]; for (auto &dep : meta.reqs) { runAnalysis(module, dep); } Timer timer(" ir analysis: " + meta.analysis->getKey()); results[name] = meta.analysis->run(module); timer.log(); } void PassManager::invalidate(const std::string &key) { std::unordered_set open = {key}; while (!open.empty()) { std::unordered_set newOpen; for (const auto &k : open) { if (results.find(k) != results.end()) { results.erase(k); newOpen.insert(deps[k].begin(), deps[k].end()); } } open = std::move(newOpen); } } void PassManager::registerStandardPasses(PassManager::Init init) { switch (init) { case Init::EMPTY: break; case Init::DEBUG: { registerPass(std::make_unique()); registerPass(std::make_unique()); registerPass(std::make_unique()); registerPass(std::make_unique()); registerPass(std::make_unique()); break; } case Init::RELEASE: case Init::JIT: { // Pythonic registerPass(std::make_unique()); registerPass(std::make_unique()); registerPass(std::make_unique()); registerPass(std::make_unique()); registerPass(std::make_unique()); // lowering registerPass(std::make_unique()); registerPass(std::make_unique()); // folding auto cfgKey = registerAnalysis(std::make_unique()); auto rdKey = registerAnalysis( std::make_unique(cfgKey), {cfgKey}); auto domKey = registerAnalysis( std::make_unique(cfgKey), {cfgKey}); auto capKey = registerAnalysis( std::make_unique(rdKey, domKey), {rdKey, domKey}); auto globalKey = registerAnalysis(std::make_unique()); auto seKey1 = registerAnalysis(std::make_unique( capKey, /*globalAssignmentHasSideEffects=*/true), {capKey}); auto seKey2 = registerAnalysis(std::make_unique( capKey, /*globalAssignmentHasSideEffects=*/false), {capKey}); registerPass(std::make_unique( seKey1, rdKey, globalKey, /*repeat=*/5, /*runGlobalDemoton=*/false, pyNumerics), /*insertBefore=*/"", {seKey1, rdKey, globalKey}, {seKey1, rdKey, cfgKey, globalKey, capKey}); registerPass(std::make_unique(rdKey, 
seKey2), /*insertBefore=*/"", {rdKey, seKey2}, {seKey1, rdKey, cfgKey, globalKey, capKey}); registerPass(std::make_unique()); registerPass(std::make_unique(rdKey), /*insertBefore=*/"", {rdKey}, {seKey1, rdKey, cfgKey, globalKey, capKey}); // async & parallel registerPass(std::make_unique()); registerPass(std::make_unique()); registerPass(std::make_unique(), /*insertBefore=*/"", {}, {cfgKey, globalKey}); if (init != Init::JIT) { // Don't demote globals in JIT mode, since they might be used later // by another user input. registerPass(std::make_unique( seKey2, rdKey, globalKey, /*repeat=*/5, /*runGlobalDemoton=*/true, pyNumerics), /*insertBefore=*/"", {seKey2, rdKey, globalKey}, {seKey2, rdKey, cfgKey, globalKey}); } break; } default: seqassertn(false, "unknown PassManager init value"); } } } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/manager.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include #include #include #include "codon/cir/analyze/analysis.h" #include "codon/cir/module.h" #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace transform { /// Utility class to run a series of passes. class PassManager { private: /// Manager for keys of passes. class KeyManager { private: /// mapping of raw key to number of occurences std::unordered_map keys; public: KeyManager() = default; /// Returns a unique'd key for a given raw key. /// Does so by appending ":" if the key /// has been seen. /// @param key the raw key /// @return the unique'd key std::string getUniqueKey(const std::string &key); }; /// Container for pass metadata. 
struct PassMetadata { /// pointer to the pass instance std::unique_ptr pass; /// vector of required analyses std::vector reqs; /// vector of invalidated analyses std::vector invalidates; PassMetadata() = default; PassMetadata(std::unique_ptr pass, std::vector reqs, std::vector invalidates) : pass(std::move(pass)), reqs(std::move(reqs)), invalidates(std::move(invalidates)) {} PassMetadata(PassMetadata &&) = default; PassMetadata &operator=(PassMetadata &&) = default; }; /// Container for analysis metadata. struct AnalysisMetadata { /// pointer to the analysis instance std::unique_ptr analysis; /// vector of required analyses std::vector reqs; /// vector of invalidated analyses std::vector invalidates; AnalysisMetadata() = default; AnalysisMetadata(std::unique_ptr analysis, std::vector reqs) : analysis(std::move(analysis)), reqs(std::move(reqs)) {} AnalysisMetadata(AnalysisMetadata &&) = default; AnalysisMetadata &operator=(AnalysisMetadata &&) = default; }; /// key manager to handle duplicate keys (i.e. passes being added twice) KeyManager km; /// map of keys to passes std::unordered_map passes; /// map of keys to analyses std::unordered_map analyses; /// reverse dependency map std::unordered_map> deps; /// execution order of passes std::vector executionOrder; /// map of valid analysis results std::unordered_map> results; /// passes to avoid registering std::vector disabled; /// whether to use Python (vs. C) numeric semantics in passes bool pyNumerics; /// true if we are compiling as a Python extension bool pyExtension; public: /// PassManager initialization mode. 
enum Init { EMPTY, DEBUG, RELEASE, JIT, }; explicit PassManager(Init init, std::vector disabled = {}, bool pyNumerics = false, bool pyExtension = false) : km(), passes(), analyses(), executionOrder(), results(), disabled(std::move(disabled)), pyNumerics(pyNumerics), pyExtension(pyExtension) { registerStandardPasses(init); } explicit PassManager(bool debug = false, std::vector disabled = {}, bool pyNumerics = false, bool pyExtension = false) : PassManager(debug ? Init::DEBUG : Init::RELEASE, std::move(disabled), pyNumerics, pyExtension) {} /// Checks if the given pass is included in this manager. /// @param key the pass key /// @return true if manager has the given pass bool hasPass(const std::string &key) { for (auto &pair : passes) { if (pair.first == key) return true; } return false; } /// Checks if the given analysis is included in this manager. /// @param key the analysis key /// @return true if manager has the given analysis bool hasAnalysis(const std::string &key) { for (auto &pair : analyses) { if (pair.first == key) return true; } return false; } /// Registers a pass and appends it to the execution order. /// @param pass the pass /// @param insertBefore insert pass before the pass with this given key /// @param reqs keys of passes that must be run before the current one /// @param invalidates keys of passes that are invalidated by the current one /// @return unique'd key for the added pass, or empty string if not added std::string registerPass(std::unique_ptr pass, const std::string &insertBefore = "", std::vector reqs = {}, std::vector invalidates = {}); /// Registers an analysis. /// @param analysis the analysis /// @param reqs keys of analyses that must be run before the current one /// @return unique'd key for the added analysis, or empty string if not added std::string registerAnalysis(std::unique_ptr analysis, std::vector reqs = {}); /// Run all passes. /// @param module the module void run(Module *module); /// Gets the result of a given analysis. 
/// @param key the (unique'd) analysis key /// @return the result analyze::Result *getAnalysisResult(const std::string &key) { auto it = results.find(key); return it != results.end() ? it->second.get() : nullptr; } /// Returns whether a given pass or analysis is disabled. /// @param key the (unique'd) pass or analysis key /// @return true if the pass or analysis is disabled bool isDisabled(const std::string &key) { return std::find(disabled.begin(), disabled.end(), key) != disabled.end(); } private: void runPass(Module *module, const std::string &name); void registerStandardPasses(Init init); void runAnalysis(Module *module, const std::string &name); void invalidate(const std::string &key); }; } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/numpy/expr.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "numpy.h" #include "codon/cir/util/irtools.h" namespace codon { namespace ir { namespace transform { namespace numpy { namespace { types::Type *coerceScalarArray(NumPyType &scalar, NumPyType &array, NumPyPrimitiveTypes &T) { auto xtype = scalar.dtype; auto atype = array.dtype; bool aIsInt = false; bool xIsInt = false; bool aIsFloat = false; bool xIsFloat = false; bool aIsComplex = false; bool xIsComplex = false; switch (atype) { case NumPyType::NP_TYPE_ARR_BOOL: break; case NumPyType::NP_TYPE_ARR_I8: case NumPyType::NP_TYPE_ARR_U8: case NumPyType::NP_TYPE_ARR_I16: case NumPyType::NP_TYPE_ARR_U16: case NumPyType::NP_TYPE_ARR_I32: case NumPyType::NP_TYPE_ARR_U32: case NumPyType::NP_TYPE_ARR_I64: case NumPyType::NP_TYPE_ARR_U64: aIsInt = true; break; case NumPyType::NP_TYPE_ARR_F16: case NumPyType::NP_TYPE_ARR_F32: case NumPyType::NP_TYPE_ARR_F64: aIsFloat = true; break; case NumPyType::NP_TYPE_ARR_C64: case NumPyType::NP_TYPE_ARR_C128: aIsComplex = true; break; default: seqassertn(false, "unexpected type"); } xIsInt = (xtype == 
NumPyType::NP_TYPE_BOOL || xtype == NumPyType::NP_TYPE_I64); xIsFloat = (xtype == NumPyType::NP_TYPE_F64); xIsComplex = (xtype == NumPyType::NP_TYPE_C128); bool shouldCast = ((xIsInt && (aIsInt || aIsFloat || aIsComplex)) || (xIsFloat && (aIsFloat || aIsComplex)) || (xIsComplex && aIsComplex)); if ((atype == NumPyType::NP_TYPE_ARR_F16 || atype == NumPyType::NP_TYPE_ARR_F32) && xtype == NumPyType::NP_TYPE_C128) return T.c64; else if (shouldCast) return array.getIRBaseType(T); else return scalar.getIRBaseType(T); } bool isPythonScalar(NumPyType &t) { if (t.isArray()) return false; auto dt = t.dtype; return (dt == NumPyType::NP_TYPE_BOOL || dt == NumPyType::NP_TYPE_I64 || dt == NumPyType::NP_TYPE_F64 || dt == NumPyType::NP_TYPE_C128); } template types::Type *decideTypes(E *expr, NumPyType &lhs, NumPyType &rhs, NumPyPrimitiveTypes &T) { // Special case(s) if (expr->op == E::NP_OP_COPYSIGN) return expr->type.getIRBaseType(T); if (lhs.isArray() && isPythonScalar(rhs)) return coerceScalarArray(rhs, lhs, T); if (isPythonScalar(lhs) && rhs.isArray()) return coerceScalarArray(lhs, rhs, T); auto *t1 = lhs.getIRBaseType(T); auto *t2 = rhs.getIRBaseType(T); auto *M = t1->getModule(); auto *coerceFunc = M->getOrRealizeFunc("_coerce", {}, {t1, t2}, FUSION_MODULE); seqassertn(coerceFunc, "coerce func not found"); return util::getReturnType(coerceFunc); } } // namespace void NumPyExpr::replace(NumPyExpr &e) { type = e.type; val = e.val; op = e.op; lhs = std::move(e.lhs); rhs = std::move(e.rhs); freeable = e.freeable; e.type = {}; e.val = nullptr; e.op = NP_OP_NONE; e.lhs = {}; e.rhs = {}; e.freeable = false; } bool NumPyExpr::haveVectorizedLoop() const { if (lhs && !(lhs->type.dtype == NumPyType::NP_TYPE_ARR_F32 || lhs->type.dtype == NumPyType::NP_TYPE_ARR_F64)) return false; if (rhs && !(rhs->type.dtype == NumPyType::NP_TYPE_ARR_F32 || rhs->type.dtype == NumPyType::NP_TYPE_ARR_F64)) return false; if (lhs && rhs && lhs->type.dtype != rhs->type.dtype) return false; // These are the 
loops available in the runtime library. static const std::vector VecLoops = { "arccos", "arccosh", "arcsin", "arcsinh", "arctan", "arctanh", "arctan2", "cos", "exp", "exp2", "expm1", "log", "log10", "log1p", "log2", "sin", "sinh", "tanh", "hypot"}; return std::find(VecLoops.begin(), VecLoops.end(), opstring()) != VecLoops.end(); } int64_t NumPyExpr::opcost() const { switch (op) { case NP_OP_NONE: return 0; case NP_OP_POS: return 0; case NP_OP_NEG: return 0; case NP_OP_INVERT: return 0; case NP_OP_ABS: return 1; case NP_OP_TRANSPOSE: return 0; case NP_OP_ADD: return 1; case NP_OP_SUB: return 1; case NP_OP_MUL: return 1; case NP_OP_MATMUL: return 20; case NP_OP_TRUE_DIV: return 8; case NP_OP_FLOOR_DIV: return 8; case NP_OP_MOD: return 8; case NP_OP_FMOD: return 8; case NP_OP_POW: return 8; case NP_OP_LSHIFT: return 1; case NP_OP_RSHIFT: return 1; case NP_OP_AND: return 1; case NP_OP_OR: return 1; case NP_OP_XOR: return 1; case NP_OP_LOGICAL_AND: return 1; case NP_OP_LOGICAL_OR: return 1; case NP_OP_LOGICAL_XOR: return 1; case NP_OP_EQ: return 1; case NP_OP_NE: return 1; case NP_OP_LT: return 1; case NP_OP_LE: return 1; case NP_OP_GT: return 1; case NP_OP_GE: return 1; case NP_OP_MIN: return 3; case NP_OP_MAX: return 3; case NP_OP_FMIN: return 3; case NP_OP_FMAX: return 3; case NP_OP_SIN: return 10; case NP_OP_COS: return 10; case NP_OP_TAN: return 10; case NP_OP_ARCSIN: return 20; case NP_OP_ARCCOS: return 20; case NP_OP_ARCTAN: return 20; case NP_OP_ARCTAN2: return 35; case NP_OP_HYPOT: return 5; case NP_OP_SINH: return 10; case NP_OP_COSH: return 10; case NP_OP_TANH: return 10; case NP_OP_ARCSINH: return 10; case NP_OP_ARCCOSH: return 10; case NP_OP_ARCTANH: return 10; case NP_OP_CONJ: return 1; case NP_OP_EXP: return 5; case NP_OP_EXP2: return 5; case NP_OP_LOG: return 5; case NP_OP_LOG2: return 5; case NP_OP_LOG10: return 5; case NP_OP_EXPM1: return 5; case NP_OP_LOG1P: return 5; case NP_OP_SQRT: return 2; case NP_OP_SQUARE: return 1; case NP_OP_CBRT: return 5; 
case NP_OP_LOGADDEXP: return 10; case NP_OP_LOGADDEXP2: return 10; case NP_OP_RECIPROCAL: return 1; case NP_OP_RINT: return 1; case NP_OP_FLOOR: return 1; case NP_OP_CEIL: return 1; case NP_OP_TRUNC: return 1; case NP_OP_ISNAN: return 1; case NP_OP_ISINF: return 1; case NP_OP_ISFINITE: return 1; case NP_OP_SIGN: return 1; case NP_OP_SIGNBIT: return 1; case NP_OP_COPYSIGN: return 1; case NP_OP_SPACING: return 1; case NP_OP_NEXTAFTER: return 1; case NP_OP_DEG2RAD: return 2; case NP_OP_RAD2DEG: return 2; case NP_OP_HEAVISIDE: return 3; } } int64_t NumPyExpr::cost() const { auto c = opcost(); if (c == -1) return -1; // Account for the fact that the vectorized loops are much faster. if (haveVectorizedLoop()) { c *= 3; if (lhs->type.dtype == NumPyType::NP_TYPE_ARR_F32) c *= 2; } bool lhsIntConst = (lhs && lhs->isLeaf() && isA(lhs->val)); bool rhsIntConst = (rhs && rhs->isLeaf() && isA(rhs->val)); bool lhsFloatConst = (lhs && lhs->isLeaf() && isA(lhs->val)); bool rhsFloatConst = (rhs && rhs->isLeaf() && isA(rhs->val)); bool lhsConst = lhsIntConst || lhsFloatConst; bool rhsConst = rhsIntConst || rhsFloatConst; if (rhsConst || lhsConst) { switch (op) { case NP_OP_TRUE_DIV: case NP_OP_FLOOR_DIV: case NP_OP_MOD: case NP_OP_FMOD: c = 1; break; case NP_OP_POW: if (rhsIntConst) c = (cast(rhs->val)->getVal() == 2) ? 
1 : 5; break; default: break; } } if (lhs) { auto cl = lhs->cost(); if (cl == -1) return -1; c += cl; } if (rhs) { auto cr = rhs->cost(); if (cr == -1) return -1; c += cr; } return c; } std::string NumPyExpr::opstring() const { static const std::unordered_map m = { {NP_OP_NONE, "a"}, {NP_OP_POS, "pos"}, {NP_OP_NEG, "neg"}, {NP_OP_INVERT, "invert"}, {NP_OP_ABS, "abs"}, {NP_OP_TRANSPOSE, "transpose"}, {NP_OP_ADD, "add"}, {NP_OP_SUB, "sub"}, {NP_OP_MUL, "mul"}, {NP_OP_MATMUL, "matmul"}, {NP_OP_TRUE_DIV, "true_div"}, {NP_OP_FLOOR_DIV, "floor_div"}, {NP_OP_MOD, "mod"}, {NP_OP_FMOD, "fmod"}, {NP_OP_POW, "pow"}, {NP_OP_LSHIFT, "lshift"}, {NP_OP_RSHIFT, "rshift"}, {NP_OP_AND, "and"}, {NP_OP_OR, "or"}, {NP_OP_XOR, "xor"}, {NP_OP_LOGICAL_AND, "logical_and"}, {NP_OP_LOGICAL_OR, "logical_or"}, {NP_OP_LOGICAL_XOR, "logical_xor"}, {NP_OP_EQ, "eq"}, {NP_OP_NE, "ne"}, {NP_OP_LT, "lt"}, {NP_OP_LE, "le"}, {NP_OP_GT, "gt"}, {NP_OP_GE, "ge"}, {NP_OP_MIN, "minimum"}, {NP_OP_MAX, "maximum"}, {NP_OP_FMIN, "fmin"}, {NP_OP_FMAX, "fmax"}, {NP_OP_SIN, "sin"}, {NP_OP_COS, "cos"}, {NP_OP_TAN, "tan"}, {NP_OP_ARCSIN, "arcsin"}, {NP_OP_ARCCOS, "arccos"}, {NP_OP_ARCTAN, "arctan"}, {NP_OP_ARCTAN2, "arctan2"}, {NP_OP_HYPOT, "hypot"}, {NP_OP_SINH, "sinh"}, {NP_OP_COSH, "cosh"}, {NP_OP_TANH, "tanh"}, {NP_OP_ARCSINH, "arcsinh"}, {NP_OP_ARCCOSH, "arccosh"}, {NP_OP_ARCTANH, "arctanh"}, {NP_OP_CONJ, "conj"}, {NP_OP_EXP, "exp"}, {NP_OP_EXP2, "exp2"}, {NP_OP_LOG, "log"}, {NP_OP_LOG2, "log2"}, {NP_OP_LOG10, "log10"}, {NP_OP_EXPM1, "expm1"}, {NP_OP_LOG1P, "log1p"}, {NP_OP_SQRT, "sqrt"}, {NP_OP_SQUARE, "square"}, {NP_OP_CBRT, "cbrt"}, {NP_OP_LOGADDEXP, "logaddexp"}, {NP_OP_LOGADDEXP2, "logaddexp2"}, {NP_OP_RECIPROCAL, "reciprocal"}, {NP_OP_RINT, "rint"}, {NP_OP_FLOOR, "floor"}, {NP_OP_CEIL, "ceil"}, {NP_OP_TRUNC, "trunc"}, {NP_OP_ISNAN, "isnan"}, {NP_OP_ISINF, "isinf"}, {NP_OP_ISFINITE, "isfinite"}, {NP_OP_SIGN, "sign"}, {NP_OP_SIGNBIT, "signbit"}, {NP_OP_COPYSIGN, "copysign"}, {NP_OP_SPACING, "spacing"}, 
{NP_OP_NEXTAFTER, "nextafter"}, {NP_OP_DEG2RAD, "deg2rad"}, {NP_OP_RAD2DEG, "rad2deg"}, {NP_OP_HEAVISIDE, "heaviside"}, }; auto it = m.find(op); seqassertn(it != m.end(), "op not found"); return it->second; } void NumPyExpr::dump(std::ostream &os, int level, int &leafId) const { auto indent = [&]() { for (int i = 0; i < level; i++) os << " "; }; indent(); if (op == NP_OP_NONE) { os << "\033[1;36m" << opstring() << leafId; ++leafId; } else { os << "\033[1;33m" << opstring(); } os << "\033[0m <" << type << ">"; if (op != NP_OP_NONE) os << " \033[1;35m[cost=" << cost() << "]\033[0m"; os << "\n"; if (lhs) lhs->dump(os, level + 1, leafId); if (rhs) rhs->dump(os, level + 1, leafId); } std::ostream &operator<<(std::ostream &os, NumPyExpr const &expr) { int leafId = 0; expr.dump(os, 0, leafId); return os; } std::string NumPyExpr::str() const { std::stringstream buffer; buffer << *this; return buffer.str(); } void NumPyExpr::apply(std::function f) { f(*this); if (lhs) lhs->apply(f); if (rhs) rhs->apply(f); } Value *NumPyExpr::codegenBroadcasts(CodegenContext &C) { auto *M = C.M; auto &vars = C.vars; Value *targetShape = nullptr; Value *result = nullptr; apply([&](NumPyExpr &e) { if (e.isLeaf() && e.type.isArray()) { auto it = vars.find(&e); seqassertn(it != vars.end(), "NumPyExpr not found in vars map (codegen broadcasts)"); auto *var = it->second; auto *shape = M->getOrRealizeFunc("_shape", {var->getType()}, {}, FUSION_MODULE); seqassertn(shape, "shape function not found"); auto *leafShape = util::call(shape, {M->Nr(var)}); if (!targetShape) { targetShape = leafShape; } else { auto *diff = (*targetShape != *leafShape); if (result) { result = *result | *diff; } else { result = diff; } } } }); return result ? 
result : M->getBool(false); } Var *NumPyExpr::codegenFusedEval(CodegenContext &C) { auto *M = C.M; auto *series = C.series; auto *func = C.func; auto &vars = C.vars; auto &T = C.T; std::vector> leaves; apply([&](NumPyExpr &e) { if (e.isLeaf()) { auto it = vars.find(&e); seqassertn(it != vars.end(), "NumPyExpr not found in vars map (fused eval)"); auto *var = it->second; leaves.emplace_back(&e, var); } }); // Arrays for scalar expression function std::vector arrays; std::vector scalarFuncArgNames; std::vector scalarFuncArgTypes; std::unordered_map scalarFuncArgMap; // Scalars passed through 'extra' arg of ndarray._loop() std::vector extra; std::unordered_map extraMap; auto *baseType = type.getIRBaseType(T); scalarFuncArgNames.push_back("out"); scalarFuncArgTypes.push_back(M->getPointerType(baseType)); unsigned argIdx = 0; unsigned extraIdx = 0; for (auto &e : leaves) { if (e.first->type.isArray()) { arrays.push_back(M->Nr(e.second)); scalarFuncArgNames.push_back("in" + std::to_string(argIdx++)); scalarFuncArgTypes.push_back(M->getPointerType(e.first->type.getIRBaseType(T))); } else { extra.push_back(M->Nr(e.second)); extraMap.emplace(e.first, extraIdx++); } } auto *extraTuple = util::makeTuple(extra, M); scalarFuncArgNames.push_back("extra"); scalarFuncArgTypes.push_back(extraTuple->getType()); auto *scalarFuncType = M->getFuncType(M->getNoneType(), scalarFuncArgTypes); auto *scalarFunc = M->Nr("__numpy_fusion_scalar_fn"); scalarFunc->realize(scalarFuncType, scalarFuncArgNames); std::vector scalarFuncArgVars(scalarFunc->arg_begin(), scalarFunc->arg_end()); argIdx = 1; for (auto &e : leaves) { if (e.first->type.isArray()) { scalarFuncArgMap.emplace(e.first, scalarFuncArgVars[argIdx++]); } } auto *scalarExpr = codegenScalarExpr(C, scalarFuncArgMap, extraMap, scalarFuncArgVars.back()); auto *ptrsetFunc = M->getOrRealizeFunc("_ptrset", {scalarFuncArgTypes[0], baseType}, {}, FUSION_MODULE); seqassertn(ptrsetFunc, "ptrset func not found"); 
scalarFunc->setBody(util::series( util::call(ptrsetFunc, {M->Nr(scalarFuncArgVars[0]), scalarExpr}))); auto *arraysTuple = util::makeTuple(arrays); auto *loopFunc = M->getOrRealizeFunc( "_loop_alloc", {arraysTuple->getType(), scalarFunc->getType(), extraTuple->getType()}, {baseType}, FUSION_MODULE); seqassertn(loopFunc, "loop_alloc func not found"); auto *result = util::makeVar( util::call(loopFunc, {arraysTuple, M->Nr(scalarFunc), extraTuple}), series, func); // Free temporary arrays apply([&](NumPyExpr &e) { if (e.isLeaf() && e.freeable) { auto it = vars.find(&e); seqassertn(it != vars.end(), "NumPyExpr not found in vars map (fused eval)"); auto *var = it->second; auto *freeFunc = M->getOrRealizeFunc("_free", {var->getType()}, {}, FUSION_MODULE); seqassertn(freeFunc, "free func not found"); series->push_back(util::call(freeFunc, {M->Nr(var)})); } }); return result; } Var *NumPyExpr::codegenSequentialEval(CodegenContext &C) { auto *M = C.M; auto *series = C.series; auto *func = C.func; auto &vars = C.vars; auto &T = C.T; if (isLeaf()) { auto it = vars.find(this); seqassertn(it != vars.end(), "NumPyExpr not found in vars map (codegen sequential eval)"); return it->second; } Var *lv = lhs->codegenSequentialEval(C); Var *rv = rhs ? rhs->codegenSequentialEval(C) : nullptr; Var *like = nullptr; Value *outShapeVal = nullptr; if (rv) { // Can't do anything special with matmul here... 
if (op == NP_OP_MATMUL) { auto *matmul = M->getOrRealizeFunc("_matmul", {lv->getType(), rv->getType()}, {}, FUSION_MODULE); return util::makeVar( util::call(matmul, {M->Nr(lv), M->Nr(rv)}), series, func); } auto *lshape = M->getOrRealizeFunc("_shape", {lv->getType()}, {}, FUSION_MODULE); seqassertn(lshape, "shape func not found for left arg"); auto *rshape = M->getOrRealizeFunc("_shape", {rv->getType()}, {}, FUSION_MODULE); seqassertn(rshape, "shape func not found for right arg"); auto *leftShape = util::call(lshape, {M->Nr(lv)}); auto *rightShape = util::call(rshape, {M->Nr(rv)}); auto *shape = M->getOrRealizeFunc( "_broadcast", {leftShape->getType(), rightShape->getType()}, {}, FUSION_MODULE); seqassertn(shape, "output shape func not found"); like = rhs->type.ndim > lhs->type.ndim ? rv : lv; outShapeVal = util::call(shape, {leftShape, rightShape}); } else { auto *shape = M->getOrRealizeFunc("_shape", {lv->getType()}, {}, FUSION_MODULE); seqassertn(shape, "shape func not found"); like = lv; outShapeVal = util::call(shape, {M->Nr(lv)}); } auto *outShape = util::makeVar(outShapeVal, series, func); Var *result = nullptr; bool lfreeable = lhs && lhs->type.isArray() && (lhs->freeable || !lhs->isLeaf()); bool rfreeable = rhs && rhs->type.isArray() && (rhs->freeable || !rhs->isLeaf()); bool ltmp = lfreeable && lhs->type.dtype == type.dtype && lhs->type.ndim == type.ndim; bool rtmp = rfreeable && rhs->type.dtype == type.dtype && rhs->type.ndim == type.ndim; auto *t = type.getIRBaseType(T); auto newArray = [&]() { auto *create = M->getOrRealizeFunc( "_create", {like->getType(), outShape->getType()}, {t}, FUSION_MODULE); seqassertn(create, "create func not found"); return util::call(create, {M->Nr(like), M->Nr(outShape)}); }; bool freeLeftStatic = false; bool freeRightStatic = false; Var *lcond = nullptr; Var *rcond = nullptr; if (rv) { if (ltmp && rhs->type.ndim == 0) { // We are adding lhs temp array to const or 0-dim array, so reuse lhs array. 
result = lv; } else if (rtmp && lhs->type.ndim == 0) { // We are adding rhs temp array to const or 0-dim array, so reuse rhs array. result = rv; } else if (!ltmp && !rtmp) { // Neither operand is a temp array, so we must allocate a new array. result = util::makeVar(newArray(), series, func); freeLeftStatic = lfreeable; freeRightStatic = rfreeable; } else if (ltmp && rtmp) { // We won't know until runtime if we can reuse the temp array(s) since they // might broadcast. auto *lshape = M->getOrRealizeFunc("_shape", {lv->getType()}, {}, FUSION_MODULE); seqassertn(lshape, "shape function func not found for left arg"); auto *rshape = M->getOrRealizeFunc("_shape", {rv->getType()}, {}, FUSION_MODULE); seqassertn(rshape, "shape function func not found for right arg"); auto *leftShape = util::call(lshape, {M->Nr(lv)}); auto *rightShape = util::call(rshape, {M->Nr(rv)}); lcond = util::makeVar(*leftShape == *M->Nr(outShape), series, func); rcond = util::makeVar(*rightShape == *M->Nr(outShape), series, func); auto *arr = M->Nr( M->Nr(lcond), M->Nr(lv), M->Nr(M->Nr(rcond), M->Nr(rv), newArray())); result = util::makeVar(arr, series, func); } else if (ltmp && !rtmp) { // We won't know until runtime if we can reuse the temp array(s) since they // might broadcast. auto *lshape = M->getOrRealizeFunc("_shape", {lv->getType()}, {}, FUSION_MODULE); seqassertn(lshape, "shape function func not found for left arg"); auto *leftShape = util::call(lshape, {M->Nr(lv)}); lcond = util::makeVar(*leftShape == *M->Nr(outShape), series, func); auto *arr = M->Nr(M->Nr(lcond), M->Nr(lv), newArray()); result = util::makeVar(arr, series, func); freeRightStatic = rfreeable; } else if (!ltmp && rtmp) { // We won't know until runtime if we can reuse the temp array(s) since they // might broadcast. 
auto *rshape = M->getOrRealizeFunc("_shape", {rv->getType()}, {}, FUSION_MODULE); seqassertn(rshape, "shape function func not found for right arg"); auto *rightShape = util::call(rshape, {M->Nr(rv)}); rcond = util::makeVar(*rightShape == *M->Nr(outShape), series, func); auto *arr = M->Nr(M->Nr(rcond), M->Nr(rv), newArray()); result = util::makeVar(arr, series, func); freeLeftStatic = lfreeable; } } else { if (ltmp) { result = lv; } else { result = util::makeVar(newArray(), series, func); freeLeftStatic = lfreeable; } } auto opstr = opstring(); if (haveVectorizedLoop()) { // We have a vectorized loop available for this operations. if (rv) { auto *vecloop = M->getOrRealizeFunc( "_apply_vectorized_loop_binary", {lv->getType(), rv->getType(), result->getType()}, {opstr}, FUSION_MODULE); seqassertn(vecloop, "binary vec loop func not found ({})", opstr); series->push_back(util::call(vecloop, {M->Nr(lv), M->Nr(rv), M->Nr(result)})); } else { auto *vecloop = M->getOrRealizeFunc("_apply_vectorized_loop_unary", {lv->getType(), result->getType()}, {opstr}, FUSION_MODULE); seqassertn(vecloop, "unary vec loop func not found ({})", opstr); series->push_back( util::call(vecloop, {M->Nr(lv), M->Nr(result)})); } } else { // Arrays for scalar expression function std::vector arrays = {M->Nr(result)}; std::vector scalarFuncArgNames; std::vector scalarFuncArgTypes; std::unordered_map scalarFuncArgMap; // Scalars passed through 'extra' arg of ndarray._loop() std::vector extra; auto *baseType = type.getIRBaseType(T); scalarFuncArgNames.push_back("out"); scalarFuncArgTypes.push_back(M->getPointerType(baseType)); if (lhs->type.isArray()) { if (result != lv) { scalarFuncArgNames.push_back("in0"); scalarFuncArgTypes.push_back(M->getPointerType(lhs->type.getIRBaseType(T))); arrays.push_back(M->Nr(lv)); } } else { extra.push_back(M->Nr(lv)); } if (rv) { if (rhs->type.isArray()) { if (result != rv) { scalarFuncArgNames.push_back("in1"); 
scalarFuncArgTypes.push_back(M->getPointerType(rhs->type.getIRBaseType(T))); arrays.push_back(M->Nr(rv)); } } else { extra.push_back(M->Nr(rv)); } } auto *extraTuple = util::makeTuple(extra, M); scalarFuncArgNames.push_back("extra"); scalarFuncArgTypes.push_back(extraTuple->getType()); auto *scalarFuncType = M->getFuncType(M->getNoneType(), scalarFuncArgTypes); auto *scalarFunc = M->Nr("__numpy_fusion_scalar_fn"); scalarFunc->realize(scalarFuncType, scalarFuncArgNames); std::vector scalarFuncArgVars(scalarFunc->arg_begin(), scalarFunc->arg_end()); auto *body = M->Nr(); auto name = "_" + opstr; auto deref = [&](unsigned idx) { return (*M->Nr(scalarFuncArgVars[idx]))[*M->getInt(0)]; }; if (rv) { Value *litem = nullptr; Value *ritem = nullptr; if (lhs->type.isArray() && rhs->type.isArray()) { if (result == lv) { litem = deref(0); ritem = deref(1); } else if (result == rv) { litem = deref(1); ritem = deref(0); } else { litem = deref(1); ritem = deref(2); } } else if (lhs->type.isArray()) { if (result == lv) { litem = deref(0); } else { litem = deref(1); } ritem = util::tupleGet(M->Nr(scalarFuncArgVars.back()), 0); } else if (rhs->type.isArray()) { if (result == rv) { ritem = deref(0); } else { ritem = deref(1); } litem = util::tupleGet(M->Nr(scalarFuncArgVars.back()), 0); } else { seqassertn(false, "both lhs are rhs are scalars"); } auto *commonType = decideTypes(this, lhs->type, rhs->type, T); auto *lcast = M->getOrRealizeFunc("_cast", {litem->getType()}, {commonType}, FUSION_MODULE); seqassertn(lcast, "cast func not found for left arg"); litem = util::call(lcast, {litem}); auto *rcast = M->getOrRealizeFunc("_cast", {ritem->getType()}, {commonType}, FUSION_MODULE); seqassertn(rcast, "cast func not found for left arg"); ritem = util::call(rcast, {ritem}); auto *op = M->getOrRealizeFunc(name, {litem->getType(), ritem->getType()}, {}, FUSION_MODULE); seqassertn(op, "2-op func '{}' not found", name); auto *oitem = util::call(op, {litem, ritem}); auto *ptrsetFunc = 
M->getOrRealizeFunc( "_ptrset", {scalarFuncArgTypes[0], oitem->getType()}, {}, FUSION_MODULE); seqassertn(ptrsetFunc, "ptrset func not found"); body->push_back( util::call(ptrsetFunc, {M->Nr(scalarFuncArgVars[0]), oitem})); } else { auto *litem = deref(result == lv ? 0 : 1); auto *op = M->getOrRealizeFunc(name, {litem->getType()}, {}, FUSION_MODULE); seqassertn(op, "1-op func '{}' not found", name); auto *oitem = util::call(op, {litem}); auto *ptrsetFunc = M->getOrRealizeFunc( "_ptrset", {scalarFuncArgTypes[0], oitem->getType()}, {}, FUSION_MODULE); seqassertn(ptrsetFunc, "ptrset func not found"); body->push_back( util::call(ptrsetFunc, {M->Nr(scalarFuncArgVars[0]), oitem})); } scalarFunc->setBody(body); auto *arraysTuple = util::makeTuple(arrays); auto *loopFunc = M->getOrRealizeFunc( "_loop_basic", {arraysTuple->getType(), scalarFunc->getType(), extraTuple->getType()}, {}, FUSION_MODULE); seqassertn(loopFunc, "loop_basic func not found"); series->push_back( util::call(loopFunc, {arraysTuple, M->Nr(scalarFunc), extraTuple})); } auto freeArray = [&](Var *arr) { auto *freeFunc = M->getOrRealizeFunc("_free", {arr->getType()}, {}, FUSION_MODULE); seqassertn(freeFunc, "free func not found"); return util::call(freeFunc, {M->Nr(arr)}); }; seqassertn(!(freeLeftStatic && lcond), "unexpected free conditions for left arg"); seqassertn(!(freeRightStatic && rcond), "unexpected free conditions for right arg"); if (lcond && rcond) { series->push_back(M->Nr( M->Nr(lcond), util::series(freeArray(rv)), util::series(freeArray(lv), M->Nr(M->Nr(rcond), M->Nr(), util::series(freeArray(rv)))))); } else { if (freeLeftStatic) { series->push_back(freeArray(lv)); } else if (lcond) { series->push_back(M->Nr(M->Nr(lcond), M->Nr(), util::series(freeArray(lv)))); } if (freeRightStatic) { series->push_back(freeArray(rv)); } else if (rcond) { series->push_back(M->Nr(M->Nr(rcond), M->Nr(), util::series(freeArray(rv)))); } } return result; } BroadcastInfo NumPyExpr::getBroadcastInfo() { int64_t 
arrDim = -1; Var *varLeaf = nullptr; bool multipleLeafVars = false; int numNonVarLeafArrays = 0; bool definitelyBroadcasts = false; apply([&](NumPyExpr &e) { if (e.isLeaf() && e.type.isArray()) { if (arrDim == -1) { arrDim = e.type.ndim; } else if (arrDim != e.type.ndim) { definitelyBroadcasts = true; } if (auto *v = cast(e.val)) { if (varLeaf) { if (varLeaf != v->getVar()) multipleLeafVars = true; } else { varLeaf = v->getVar(); } } else { ++numNonVarLeafArrays; } } }); bool mightBroadcast = numNonVarLeafArrays > 1 || multipleLeafVars || (numNonVarLeafArrays == 1 && varLeaf); if (definitelyBroadcasts) { return BroadcastInfo::YES; } else if (mightBroadcast) { return BroadcastInfo::MAYBE; } else { return BroadcastInfo::NO; } } Value *NumPyExpr::codegenScalarExpr( CodegenContext &C, const std::unordered_map &args, const std::unordered_map &scalarMap, Var *scalars) { auto *M = C.M; auto &T = C.T; Value *lv = lhs ? lhs->codegenScalarExpr(C, args, scalarMap, scalars) : nullptr; Value *rv = rhs ? 
rhs->codegenScalarExpr(C, args, scalarMap, scalars) : nullptr; auto name = "_" + opstring(); if (lv && rv) { auto *t = type.getIRBaseType(T); auto *commonType = decideTypes(this, lhs->type, rhs->type, T); auto *cast1 = M->getOrRealizeFunc("_cast", {lv->getType()}, {commonType}, FUSION_MODULE); auto *cast2 = M->getOrRealizeFunc("_cast", {rv->getType()}, {commonType}, FUSION_MODULE); lv = util::call(cast1, {lv}); rv = util::call(cast2, {rv}); auto *f = M->getOrRealizeFunc(name, {lv->getType(), rv->getType()}, {}, FUSION_MODULE); seqassertn(f, "2-op func '{}' not found", name); return util::call(f, {lv, rv}); } else if (lv) { auto *t = type.getIRBaseType(T); auto *f = M->getOrRealizeFunc(name, {lv->getType()}, {}, FUSION_MODULE); seqassertn(f, "1-op func '{}' not found", name); return util::call(f, {lv}); } else { if (type.isArray()) { auto it = args.find(this); seqassertn(it != args.end(), "NumPyExpr not found in args map (codegen expr)"); auto *var = it->second; return (*M->Nr(var))[*M->getInt(0)]; } else { auto it = scalarMap.find(this); seqassertn(it != scalarMap.end(), "NumPyExpr not found in scalar map (codegen expr)"); auto idx = it->second; return util::tupleGet(M->Nr(scalars), idx); } } } } // namespace numpy } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/numpy/forward.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "numpy.h"

namespace codon {
namespace ir {
namespace transform {
namespace numpy {
namespace {
using CFG = analyze::dataflow::CFGraph;
using CFBlock = analyze::dataflow::CFBlock;
using RD = analyze::dataflow::RDInspector;
using SE = analyze::module::SideEffectResult;

// Collects into `vids` the IDs of all variables used by the visited nodes,
// excluding those for which isA<Func> holds.
struct GetVars : public util::Operator {
  std::unordered_set<id_t> &vids;

  explicit GetVars(std::unordered_set<id_t> &vids) : util::Operator(), vids(vids) {}

  void preHook(Node *v) override {
    for (auto *var : v->getUsedVariables()) {
      if (!isA<Func>(var))
        vids.insert(var->getId());
    }
  }
};

// Visitor that clears `ok` if a visited node assigns to one of the variables in
// `vids` or has a side effect. Values that were parsed into NumPy expressions are
// judged by their leaves rather than as a whole.
struct OkToForwardPast : public util::Operator {
  std::unordered_set<id_t> &vids;
  const std::unordered_map<id_t, NumPyExpr *> &parsedValues;
  SE *se;
  bool ok;

  OkToForwardPast(std::unordered_set<id_t> &vids,
                  const std::unordered_map<id_t, NumPyExpr *> &parsedValues, SE *se)
      : util::Operator(), vids(vids), parsedValues(parsedValues), se(se), ok(true) {}

  void preHook(Node *v) override {
    if (!ok) {
      return;
    } else if (auto *assign = cast<AssignInstr>(v)) {
      if (vids.count(assign->getLhs()->getId()))
        ok = false;
    } else if (auto *val = cast<Value>(v)) {
      auto it = parsedValues.find(val->getId());
      if (it != parsedValues.end()) {
        it->second->apply([&](NumPyExpr &e) {
          if (e.isLeaf() && se->hasSideEffect(e.val))
            ok = false;
        });
        // Skip children since we are processing them manually above.
        for (auto *used : val->getUsedValues())
          see(used);
      } else if (se->hasSideEffect(val)) {
        ok = false;
      }
    }
  }
};

// Collects into `uses` every visited value whose used variables include `var`.
struct GetAllUses : public util::Operator {
  Var *var;
  std::vector<Value *> &uses;

  GetAllUses(Var *var, std::vector<Value *> &uses)
      : util::Operator(), var(var), uses(uses) {}

  void preHook(Node *n) override {
    if (auto *v = cast<Value>(n)) {
      auto vars = v->getUsedVariables();
      if (std::find(vars.begin(), vars.end(), var) != vars.end())
        uses.push_back(v);
    }
  }
};

// Checks every value strictly between `source` and `destination` along `path` with
// OkToForwardPast; returns false as soon as one of them fails the check.
bool canForwardExpressionAlongPath(
    Value *source, Value *destination, std::unordered_set<id_t> &vids,
    const std::unordered_map<id_t, NumPyExpr *> &parsedValues, SE *se,
    const std::vector<CFBlock *> &path) {
  if (path.empty())
    return true;

  bool go = false;
  for (auto *block : path) {
    for (const auto *value : *block) {
      // Skip things before 'source' in first block
      if (!go && block == path.front() && value == source) {
        go = true;
        continue;
      }

      // Skip things after 'destination' in last block
      if (go && block == path.back() && value == destination) {
        go = false;
        break;
      }

      if (!go)
        continue;

      OkToForwardPast check(vids, parsedValues, se);
      const_cast<Value *>(value)->accept(check);
      if (!check.ok)
        return false;
    }
  }
  return true;
}

// Returns true if the expression in `expr` can be moved from its assignment to
// `target`: all of its leaves must be side-effect free, and every cycle-free CFG
// path from the assignment's block to the target's block must pass
// canForwardExpressionAlongPath().
bool canForwardExpression(NumPyOptimizationUnit *expr, Value *target,
                          const std::unordered_map<id_t, NumPyExpr *> &parsedValues,
                          CFG *cfg, SE *se) {
  std::unordered_set<id_t> vids;
  bool pure = true;

  expr->expr->apply([&](NumPyExpr &e) {
    if (e.isLeaf()) {
      if (se->hasSideEffect(e.val)) {
        pure = false;
      } else {
        GetVars gv(vids);
        e.val->accept(gv);
      }
    }
  });

  if (!pure)
    return false;

  auto *source = expr->assign;
  auto *start = cfg->getBlock(source);
  auto *end = cfg->getBlock(target);
  seqassertn(start, "start CFG block not found");
  seqassertn(end, "end CFG block not found");
  bool ok = true;

  std::function<void(CFBlock *, std::vector<CFBlock *> &)> dfs =
      [&](CFBlock *curr, std::vector<CFBlock *> &path) {
        path.push_back(curr);
        if (curr == end) {
          if (!canForwardExpressionAlongPath(source, target, vids, parsedValues, se,
                                             path))
            ok = false;
        } else {
          for (auto it = curr->successors_begin(); it != curr->successors_end();
               ++it) {
            // Only descend into blocks that are NOT already on the current path.
            // Descending into blocks that are on it would never reach new blocks
            // and would recurse forever around a CFG cycle.
            if (std::find(path.begin(), path.end(), *it) == path.end())
              dfs(*it, path);
          }
        }
        path.pop_back();
      };

  std::vector<CFBlock *> path;
  dfs(start, path);
  return ok;
}

// Returns true if the variable assigned by `assign` can be replaced at
// `destination`: `assign` must be the only definition reaching `destination`, and
// must reach no other use of the variable.
bool canForwardVariable(AssignInstr *assign, Value *destination, BodiedFunc *func,
                        RD *rd) {
  auto *var = assign->getLhs();

  // Check 1: Only the given assignment should reach the destination.
  auto reaching = rd->getReachingDefinitions(var, destination);
  if (reaching.size() != 1 || reaching[0].assignment->getId() != assign->getId())
    return false;

  // Check 2: There should be no other uses of the variable that the given assignment
  // reaches.
  std::vector<Value *> uses;
  GetAllUses gu(var, uses);
  func->accept(gu);
  for (auto *use : uses) {
    if (use != destination && use->getId() != assign->getId()) {
      auto defs = rd->getReachingDefinitions(var, use);
      for (auto &def : defs) {
        if (def.assignment->getId() == assign->getId())
          return false;
      }
    }
  }

  return true;
}

// Builds the forwarding graph over `exprs`: for each expression (dst) and each of
// its variable leaves, records a forwarding from every other expression (src) that
// assigns that variable and passes both canForwardVariable() and
// canForwardExpression().
ForwardingDAG buildForwardingDAG(BodiedFunc *func, RD *rd, CFG *cfg, SE *se,
                                 std::vector<NumPyOptimizationUnit> &exprs) {
  std::unordered_map<id_t, NumPyExpr *> parsedValues;
  for (auto &e : exprs) {
    e.expr->apply([&](NumPyExpr &e) {
      if (e.val)
        parsedValues.emplace(e.val->getId(), &e);
    });
  }

  ForwardingDAG dag;
  int64_t dstId = 0;
  for (auto &dst : exprs) {
    auto *target = dst.expr.get();
    auto &forwardingVec = dag[&dst];

    std::vector<std::pair<Var *, NumPyExpr *>> vars;
    target->apply([&](NumPyExpr &e) {
      if (e.isLeaf()) {
        if (auto *v = cast<VarValue>(e.val)) {
          vars.emplace_back(v->getVar(), &e);
        }
      }
    });

    for (auto &p : vars) {
      int64_t srcId = 0;
      for (auto &src : exprs) {
        if (srcId != dstId && src.assign && src.assign->getLhs() == p.first) {
          auto checkFwdVar = canForwardVariable(src.assign, p.second->val, func, rd);
          auto checkFwdExpr =
              canForwardExpression(&src, p.second->val, parsedValues, cfg, se);
          if (checkFwdVar && checkFwdExpr)
            forwardingVec.push_back({&dst, &src, p.first, p.second, dstId, srcId});
        }
        ++srcId;
      }
    }
    ++dstId;
  }

  return dag;
}

// Union-find over indices [0, n) with path compression and union by rank.
struct UnionFind {
  std::vector<int64_t> parent;
  std::vector<int64_t> rank;

  explicit UnionFind(int64_t n) : parent(n), rank(n) {
    for (auto i = 0; i < n; i++) {
      parent[i] = i;
      rank[i] = 0;
    }
  }

  int64_t find(int64_t u) {
    if (parent[u] != u)
      parent[u] = find(parent[u]);
    return parent[u];
  }

  void union_(int64_t u, int64_t v) {
    auto ru = find(u);
    auto rv = find(v);
    if (ru != rv) {
      if (rank[ru] > rank[rv]) {
        parent[rv] = ru;
      } else if (rank[ru] < rank[rv]) {
        parent[ru] = rv;
      } else {
        parent[rv] = ru;
        ++rank[ru];
      }
    }
  }
};

// Splits `dag` into one ForwardingDAG per connected component, treating each
// forwarding as an undirected edge between its destination and source.
std::vector<ForwardingDAG>
getForwardingDAGConnectedComponents(ForwardingDAG &dag,
                                    std::vector<NumPyOptimizationUnit> &exprs) {
  auto n = exprs.size();
  UnionFind uf(n);

  for (auto i = 0; i < n; i++) {
    for (auto &fwd : dag[&exprs[i]]) {
      uf.union_(i, fwd.srcId);
    }
  }

  std::vector<std::vector<NumPyOptimizationUnit *>> components(n);
  for (auto i = 0; i < n; i++) {
    auto root = uf.find(i);
    components[root].push_back(&exprs[i]);
  }

  std::vector<ForwardingDAG> result;
  for (auto &c : components) {
    if (c.empty())
      continue;

    ForwardingDAG d;
    for (auto *expr : c)
      d.emplace(expr, dag[expr]);
    result.push_back(d);
  }

  return result;
}

// Depth-first cycle search from expression index `v`; `recStack` marks the nodes
// on the current recursion path.
bool hasCycleHelper(int64_t v, ForwardingDAG &dag,
                    std::vector<NumPyOptimizationUnit> &exprs,
                    std::vector<bool> &visited, std::vector<bool> &recStack) {
  visited[v] = true;
  recStack[v] = true;

  for (auto &neighbor : dag[&exprs[v]]) {
    if (!visited[neighbor.srcId]) {
      if (hasCycleHelper(neighbor.srcId, dag, exprs, visited, recStack))
        return true;
    } else if (recStack[neighbor.srcId]) {
      return true;
    }
  }

  recStack[v] = false;
  return false;
}

// Returns true if the forwarding graph `dag` contains a cycle.
bool hasCycle(ForwardingDAG &dag, std::vector<NumPyOptimizationUnit> &exprs) {
  auto n = exprs.size();
  std::vector<bool> visited(n, false);
  std::vector<bool> recStack(n, false);

  for (auto i = 0; i < n; i++) {
    if (dag.find(&exprs[i]) != dag.end() && !visited[i] &&
        hasCycleHelper(i, dag, exprs, visited, recStack))
      return true;
  }
  return false;
}

// Applies the forwardings into `curr` after first recursively applying those into
// each of its sources. Each source expression replaces the corresponding leaf of
// the destination, and the source's assignment is queued in `assignsToDelete`.
void doForwardingHelper(ForwardingDAG &dag, NumPyOptimizationUnit *curr,
                        std::unordered_set<NumPyOptimizationUnit *> &done,
                        std::vector<AssignInstr *> &assignsToDelete) {
  if (done.count(curr))
    return;

  auto forwardings = dag[curr];
  for (auto &fwd : forwardings) {
    doForwardingHelper(dag, fwd.src, done, assignsToDelete);
    // Note that order of leaves here doesn't matter since they're guaranteed to have no
    // side effects based on forwarding checks.
    fwd.dst->leaves.insert(fwd.dst->leaves.end(), fwd.src->leaves.begin(),
                           fwd.src->leaves.end());
    fwd.dstLeaf->replace(*fwd.src->expr);
    assignsToDelete.push_back(fwd.src->assign);
  }

  done.insert(curr);
}
} // namespace

// Builds the forwarding graph for `exprs`, splits it into connected components and
// drops every component that contains a cycle.
std::vector<ForwardingDAG> getForwardingDAGs(BodiedFunc *func, RD *rd, CFG *cfg,
                                             SE *se,
                                             std::vector<NumPyOptimizationUnit> &exprs) {
  auto dag = buildForwardingDAG(func, rd, cfg, se, exprs);
  auto dags = getForwardingDAGConnectedComponents(dag, exprs);
  dags.erase(std::remove_if(dags.begin(), dags.end(),
                            [&](ForwardingDAG &dag) { return hasCycle(dag, exprs); }),
             dags.end());
  return dags;
}

// Applies all forwardings in `dag` and returns its root: the single unit that is
// never the source of a forwarding. Asserts if `dag` is empty or the number of
// non-root units is not exactly dag.size() - 1.
NumPyOptimizationUnit *doForwarding(ForwardingDAG &dag,
                                    std::vector<AssignInstr *> &assignsToDelete) {
  seqassertn(!dag.empty(), "empty forwarding DAG encountered");
  std::unordered_set<NumPyOptimizationUnit *> done;
  for (auto &e : dag) {
    doForwardingHelper(dag, e.first, done, assignsToDelete);
  }

  // Find the root
  std::unordered_set<NumPyOptimizationUnit *> notRoot;
  for (auto &e : dag) {
    for (auto &f : e.second) {
      notRoot.insert(f.src);
    }
  }
  seqassertn(notRoot.size() == dag.size() - 1,
             "multiple roots found in forwarding DAG");

  for (auto &e : dag) {
    if (notRoot.count(e.first) == 0)
      return e.first;
  }

  seqassertn(false, "could not find root in forwarding DAG");
  return nullptr;
}

} // namespace numpy
} // namespace transform
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/transform/numpy/indexing.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include "indexing.h"

#include "codon/cir/analyze/dataflow/reaching.h"
#include "codon/cir/util/cloning.h"
#include "codon/cir/util/irtools.h"

// NOTE(review): the header name was missing from this directive; <algorithm> is
// what std::remove_if below requires — confirm against the original file.
#include <algorithm>

namespace codon {
namespace ir {
namespace transform {
namespace numpy {
namespace {
const std::string FUSION_MODULE = "std.numpy.fusion";

// One term of a linear integer expression: an integer constant (INT), `val * var`
// (VAR) or `val * len(var)` (LEN).
struct Term {
  enum Kind { INT, VAR, LEN } kind;
  int64_t val;
  const VarValue *var;

  Term(Kind kind, int64_t val, const VarValue *var) : kind(kind), val(val), var(var) {}

  static Term valTerm(int64_t v) { return {Kind::INT, v, nullptr}; }

  static Term varTerm(VarValue *v) { return {Kind::VAR, 1, v}; }

  static Term lenTerm(VarValue *v) { return {Kind::LEN, 1, v}; }

  void negate() { val = -val; }

  void multiply(int64_t n) { val *= n; }

  // Adds `other` into this term if both have the same kind and (for non-INT terms)
  // refer to the same variable; returns whether the merge happened.
  bool combine(const Term &other) {
    if (kind == other.kind &&
        (kind == Kind::INT ||
         var->getVar()->getId() == other.var->getVar()->getId())) {
      val += other.val;
      return true;
    }
    return false;
  }

  bool zero() const { return val == 0; }

  std::string str() const {
    switch (kind) {
    case Kind::INT:
      return std::to_string(val);
    case Kind::VAR:
      return (val == 1 ? "" : std::to_string(val) + "*") + var->getVar()->getName();
    case Kind::LEN:
      return (val == 1 ? "" : std::to_string(val) + "*") + "len(" +
             var->getVar()->getName() + ")";
    }
    return "(?)";
  }
};

// Renders a term list as "[t0, t1, ...]".
std::string t2s(const std::vector<Term> &terms) {
  if (terms.empty()) {
    return "[]";
  }

  std::string s = "[" + terms[0].str();
  for (auto i = 1; i < terms.size(); i++)
    s += ", " + terms[i].str();
  s += "]";
  return s;
}

// Merges combinable terms in place and then drops terms whose value is zero.
void simplify(std::vector<Term> &terms) {
  // Presumably the number of terms will be small,
  // so we use a simple quadratic algorithm.
  for (auto it1 = terms.begin(); it1 != terms.end(); ++it1) {
    auto it2 = it1 + 1;
    while (it2 != terms.end()) {
      auto &t1 = *it1;
      auto &t2 = *it2;
      if (t1.combine(t2)) {
        it2 = terms.erase(it2);
      } else {
        ++it2;
      }
    }
  }

  terms.erase(std::remove_if(terms.begin(), terms.end(),
                             [](const Term &t) { return t.zero(); }),
              terms.end());
}

// Returns true only if the sum of `terms` can be shown to be > 0 (strict) or >= 0
// (non-strict). Any non-zero VAR term, or any LEN term with a negative
// coefficient, makes the sum unprovable and yields false.
bool checkTotal(const std::vector<Term> &terms, bool strict) {
  int64_t total = 0;
  for (const auto &term : terms) {
    switch (term.kind) {
    case Term::Kind::INT:
      total += term.val;
      break;
    case Term::Kind::VAR:
      if (term.val != 0)
        return false;
      break;
    case Term::Kind::LEN:
      // len is never negative, so a non-negative multiple of it can only increase
      // the sum and may be ignored. A negative multiple can make the sum
      // arbitrarily small, so nothing can be proven in that case.
      if (term.val < 0)
        return false;
      break;
    }
  }
  return strict ? (total > 0) : (total >= 0);
}

// Tries to prove sum(terms1) < sum(terms2), or <= when `strict` is false.
bool lessCheck(const std::vector<Term> &terms1, const std::vector<Term> &terms2,
               bool strict) {
  // Checking if: t1_0 + t1_1 + ... + t1_N < t2_0 + t2_1 + ... + t2_M
  // Same as: t2_0 + t2_1 + ... + t2_M - t1_0 - t1_1 - ... - t1_N > 0
  std::vector<Term> tmp(terms1.begin(), terms1.end());
  for (auto &t : tmp)
    t.negate();
  std::vector<Term> terms(terms2.begin(), terms2.end());
  terms.insert(terms.end(), tmp.begin(), tmp.end());
  simplify(terms);
  return checkTotal(terms, strict);
}

bool lessThan(const std::vector<Term> &terms1, const std::vector<Term> &terms2) {
  return lessCheck(terms1, terms2, /*strict=*/true);
}

bool lessThanOrEqual(const std::vector<Term> &terms1,
                     const std::vector<Term> &terms2) {
  return lessCheck(terms1, terms2, /*strict=*/false);
}

// Returns a copy of `terms` in which every VAR term on `loopVar` is replaced by
// `replacement` scaled by that term's coefficient.
std::vector<Term> replaceLoopVariable(const std::vector<Term> &terms, Var *loopVar,
                                      const std::vector<Term> &replacement) {
  std::vector<Term> ans;
  for (auto &term : terms) {
    if (term.kind == Term::Kind::VAR &&
        term.var->getVar()->getId() == loopVar->getId()) {
      for (auto &rep : replacement) {
        ans.push_back(rep);
        ans.back().multiply(term.val);
      }
    } else {
      ans.push_back(term);
    }
  }
  return ans;
}

// Returns true if `t` is an ndarray type (matched by mangled class-name prefix);
// with `dim1` set, additionally requires the static ndim generic to be 1.
bool isArrayType(types::Type *t, bool dim1 = false) {
  bool result =
      t && isA<types::RecordType>(t) &&
      t->getName().rfind(ast::getMangledClass("std.numpy.ndarray", "ndarray") + "[",
                         0) == 0;
  if (result && dim1) {
    auto generics = t->getGenerics();
    seqassertn(generics.size() == 2 && generics[0].isType() &&
                   generics[1].isStatic(),
               "unrecognized ndarray generics");
    auto ndim = generics[1].getStaticValue();
    result &= (ndim == 1);
  }
  return result;
}

// Returns true if `f` is a realization of the builtin len function.
bool isLen(Func *f) {
  return f->getName().rfind(ast::getMangledFunc("std.internal.builtin", "len") + "[",
                            0) == 0;
}

// Parses `x` into a sum of terms appended to `terms` (each negated if `negate`).
// Handles integer constants, variables, int add/sub, and size/len of an array
// variable; returns false for anything else.
bool parse(Value *x, std::vector<Term> &terms, bool negate = false) {
  auto push = [&](const Term &t) {
    terms.push_back(t);
    if (negate)
      terms.back().negate();
  };

  if (auto *v = cast<IntConst>(x)) {
    push(Term::valTerm(v->getVal()));
    return true;
  }

  if (auto *v = cast<VarValue>(x)) {
    push(Term::varTerm(v));
    return true;
  }

  auto *M = x->getModule();
  auto *intType = M->getIntType();
  if (auto *v = cast<CallInstr>(x)) {
    if (util::isCallOf(v, Module::ADD_MAGIC_NAME, {intType, intType}, intType,
                       /*method=*/true)) {
      return parse(v->front(), terms, negate) && parse(v->back(), terms, negate);
    }

    if (util::isCallOf(v, Module::SUB_MAGIC_NAME, {intType, intType}, intType,
                       /*method=*/true)) {
      return parse(v->front(), terms, negate) && parse(v->back(), terms, !negate);
    }

    if (v->numArgs() == 1 && isArrayType(v->front()->getType()) &&
        isA<VarValue>(v->front()) &&
        util::isCallOf(v, "size", {v->front()->getType()}, intType,
                       /*method=*/true)) {
      push(Term::lenTerm(cast<VarValue>(v->front())));
      return true;
    }

    if (v->numArgs() == 1 && isArrayType(v->front()->getType()) &&
        isA<VarValue>(v->front()) &&
        util::isCallOf(v, "len", {v->front()->getType()}, intType,
                       /*method=*/false) &&
        isLen(util::getFunc(v->getCallee()))) {
      push(Term::lenTerm(cast<VarValue>(v->front())));
      return true;
    }
  }

  return false;
}

// One array get/set call: the original call, the array variable, the index value
// and (for a set) the stored item; `item` is null for a get.
struct IndexInfo {
  CallInstr *orig;
  VarValue *arr;
  Value *idx;
  Value *item;

  IndexInfo(CallInstr *orig, VarValue *arr, Value *idx, Value *item)
      : orig(orig), arr(arr), idx(idx), item(item) {}
};

// Collects getitem/setitem calls with an int index whose first argument is a
// 1-dimensional ndarray variable.
struct FindArrayIndex : public util::Operator {
  std::vector<IndexInfo> indexes;

  FindArrayIndex() : util::Operator(/*childrenFirst=*/true), indexes() {}

  void handle(CallInstr *v) override {
    if (v->numArgs() < 1 || !isArrayType(v->front()->getType(), /*dim1=*/true) ||
        !isA<VarValue>(v->front()))
      return;

    auto *M = v->getModule();
    auto *arrType = v->front()->getType();
    auto *arrVar = cast<VarValue>(v->front());
    auto *intType = v->getModule()->getIntType();

    if (util::isCallOf(v, Module::GETITEM_MAGIC_NAME, {arrType, intType},
                       /*output=*/nullptr, /*method=*/true)) {
      indexes.emplace_back(v, arrVar, v->back(), nullptr);
    } else if (util::isCallOf(v, Module::SETITEM_MAGIC_NAME,
                              {arrType, intType, v->back()->getType()},
                              /*output=*/nullptr, /*method=*/true)) {
      indexes.emplace_back(v, arrVar, *(v->begin() + 1), v->back());
    }
  }
};

// Replaces the original get/set call with a call to the corresponding
// "_array1d_*_nocheck" function, cloning the index (and item) operands.
void elideBoundsCheck(IndexInfo &index) {
  auto *M = index.orig->getModule();
  util::CloneVisitor cv(M);
  if (index.item) {
    auto *setitem = M->getOrRealizeFunc(
        "_array1d_set_nocheck",
        {index.arr->getType(), M->getIntType(), index.item->getType()}, {},
        FUSION_MODULE);
    seqassertn(setitem, "setitem function not found");
    index.orig->replaceAll(
        util::call(setitem, {M->Nr<VarValue>(index.arr->getVar()),
                             cv.clone(index.idx), cv.clone(index.item)}));
  } else {
    auto *getitem =
        M->getOrRealizeFunc("_array1d_get_nocheck",
                            {index.arr->getType(), M->getIntType()}, {}, FUSION_MODULE);
    seqassertn(getitem, "getitem function not found");
    index.orig->replaceAll(util::call(
        getitem, {M->Nr<VarValue>(index.arr->getVar()), cv.clone(index.idx)}));
  }
}

// Returns true if the loop variable, as seen at `loc`, is defined only by `loop`
// itself (its initial assignment and its update).
bool isOriginalLoopVar(const Value *loc, ImperativeForFlow *loop,
                       analyze::dataflow::RDInspector *rd) {
  // The loop variable should have exactly two reaching definitions:
  //   - The initial assignment for the loop
  //   - The update assignment
  // Both are represented as `SyntheticAssignInstr` in the CFG.
  auto *loopVar = loop->getVar();
  auto defs = rd->getReachingDefinitions(loopVar, loc);
  if (defs.size() != 2)
    return false;

  using SAI = analyze::dataflow::SyntheticAssignInstr;
  auto *s1 = cast<SAI>(defs[0].assignment);
  auto *s2 = cast<SAI>(defs[1].assignment);

  if (!s1 || !s2)
    return false;

  // Normalize so that s1 is the KNOWN definition and s2 the ADD definition.
  if (s1->getKind() == SAI::Kind::ADD && s2->getKind() == SAI::Kind::KNOWN) {
    auto *tmp = s1;
    s1 = s2;
    s2 = tmp;
  } else if (!(s1->getKind() == SAI::Kind::KNOWN &&
               s2->getKind() == SAI::Kind::ADD)) {
    return false;
  }

  auto *loop1 = cast<ImperativeForFlow>(s1->getSource());
  auto *loop2 = cast<ImperativeForFlow>(s2->getSource());
  if (!loop1 || !loop2 || loop1->getId() != loop->getId() ||
      loop2->getId() != loop->getId())
    return false;

  return true;
}

// If `v` has a single known reaching definition that assigns it the loop variable
// (and the loop variable is unmodified at `v`), returns that loop-variable value;
// otherwise returns nullptr.
const VarValue *isAliasOfLoopVar(const VarValue *v, ImperativeForFlow *loop,
                                 analyze::dataflow::RDInspector *rd) {
  auto defs = rd->getReachingDefinitions(v->getVar(), v);
  auto *loopVar = loop->getVar();
  if (defs.size() != 1 || !defs[0].known() || !isA<VarValue>(defs[0].assignee) ||
      cast<VarValue>(defs[0].assignee)->getVar()->getId() != loopVar->getId() ||
      !isOriginalLoopVar(v, loop, rd))
    return nullptr;
  return cast<VarValue>(defs[0].assignee);
}

// NOTE(review): this function continues past the end of this chunk; the visible
// part is reproduced unchanged.
bool canElideBoundsCheck(ImperativeForFlow *loop, IndexInfo &index,
                         const std::vector<Term> &startTerms,
                         const std::vector<Term> &stopTerms,
                         analyze::dataflow::RDInspector *rd) {
  auto *loopVar = loop->getVar();
  std::vector<Term> idxTerms;
  if (!parse(index.idx, idxTerms))
    return false;

  // First, check that all involved variables refer to a consistent
  // value. We do this by making sure there is just one reaching def
  // for all VarValues referring to the same Var.
  std::unordered_map<id_t, id_t> reach; // "[var id] -> [reaching def id]" map
  auto check = [&](const VarValue *v) {
    auto id = v->getVar()->getId();
    if (id == loopVar->getId()) {
      return isOriginalLoopVar(v, loop, rd);
    } else {
      auto defs = rd->getReachingDefinitions(v->getVar(), v);
      if (defs.size() != 1)
        return false;
      auto rid = defs[0].getId();
      auto it = reach.find(id);
      if (it == reach.end()) {
        reach.emplace(id, rid);
        return true;
      } else {
        return it->second == rid;
      }
    }
  };

  if (!check(index.arr))
    return false;

  for (auto &term : startTerms) {
    if (term.kind != Term::Kind::INT && !check(term.var))
      return false;
  }

  for (auto &term : stopTerms) {
    if (term.kind != Term::Kind::INT && !check(term.var))
      return false;
  }

  for (auto &term : idxTerms) {
    if (term.kind != Term::Kind::INT && !check(term.var))
      return false;
  }

  // Update vars that are aliases of the loop var. This can
  // happen in e.g. compound assignments which need to create
  // a temporary copy of the index.
  for (auto &term : idxTerms) {
    if (term.kind != Term::Kind::VAR)
      continue;
    if (auto *loopVarValue = isAliasOfLoopVar(term.var, loop, rd))
      term.var = loopVarValue;
  }

  // Next, see if we can prove that indexes are in range.
std::vector limit = {Term::lenTerm(index.arr)}; auto terms1 = replaceLoopVariable(idxTerms, loopVar, startTerms); auto terms2 = replaceLoopVariable(idxTerms, loopVar, stopTerms); if (loop->getStep() > 0) { return lessThanOrEqual({Term::valTerm(0)}, terms1) && lessThanOrEqual(terms2, limit); } else { return lessThan(terms1, limit) && lessThanOrEqual({Term::valTerm(-1)}, terms2); } } } // namespace void NumPyBoundsCheckElisionPass::visit(ImperativeForFlow *f) { if (f->getStep() == 0 || f->getVar()->isGlobal()) return; std::vector startTerms; std::vector stopTerms; FindArrayIndex find; f->getBody()->accept(find); if (find.indexes.empty() || !parse(f->getStart(), startTerms) || !parse(f->getEnd(), stopTerms)) return; auto *r = getAnalysisResult(reachingDefKey); auto *c = r->cfgResult; auto it = r->results.find(getParentFunc()->getId()); if (it == r->results.end()) return; auto *rd = it->second.get(); for (auto &index : find.indexes) { if (canElideBoundsCheck(f, index, startTerms, stopTerms, rd)) { elideBoundsCheck(index); } } } const std::string NumPyBoundsCheckElisionPass::KEY = "core-numpy-bounds-check-elision"; } // namespace numpy } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/numpy/indexing.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "codon/cir/transform/pass.h" #include "codon/cir/types/types.h" namespace codon { namespace ir { namespace transform { namespace numpy { /// NumPy bounds check elision pass class NumPyBoundsCheckElisionPass : public OperatorPass { private: /// Key of the reaching definition analysis std::string reachingDefKey; public: static const std::string KEY; /// Constructs a NumPy bounds check elision pass. 
/// @param reachingDefKey the reaching definition analysis' key NumPyBoundsCheckElisionPass(const std::string &reachingDefKey) : OperatorPass(), reachingDefKey(reachingDefKey) {} std::string getKey() const override { return KEY; } void visit(ImperativeForFlow *f) override; }; } // namespace numpy } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/numpy/numpy.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "numpy.h" #include "codon/cir/analyze/dataflow/reaching.h" #include "codon/cir/analyze/module/global_vars.h" #include "codon/cir/analyze/module/side_effect.h" #include "codon/cir/util/cloning.h" #include "codon/cir/util/irtools.h" #include "llvm/Support/CommandLine.h" #include #include #include #include #define XLOG(c, ...) \ do { \ if (Verbose) \ LOG(c, ##__VA_ARGS__); \ } while (false) namespace codon { namespace ir { namespace transform { namespace numpy { namespace { llvm::cl::opt AlwaysFuseCostThreshold( "npfuse-always", llvm::cl::desc("Expression cost below which (<=) to always fuse"), llvm::cl::init(10)); llvm::cl::opt NeverFuseCostThreshold( "npfuse-never", llvm::cl::desc("Expression cost above which (>) to never fuse"), llvm::cl::init(50)); llvm::cl::opt Verbose("npfuse-verbose", llvm::cl::desc("Print information about fused expressions"), llvm::cl::init(false)); bool isArrayType(types::Type *t) { return t && isA(t) && t->getName().rfind(ast::getMangledClass("std.numpy.ndarray", "ndarray") + "[", 0) == 0; } bool isUFuncType(types::Type *t) { return t && (t->getName().rfind( ast::getMangledClass("std.numpy.ufunc", "UnaryUFunc") + "[", 0) == 0 || t->getName().rfind( ast::getMangledClass("std.numpy.ufunc", "BinaryUFunc") + "[", 0) == 0); } bool isNoneType(types::Type *t, NumPyPrimitiveTypes &T) { return t && (t->is(T.none) || t->is(T.optnone)); } } // namespace const std::string FUSION_MODULE = "std.numpy.fusion"; 
NumPyPrimitiveTypes::NumPyPrimitiveTypes(Module *M) : none(M->getNoneType()), optnone(M->getOptionalType(none)), bool_(M->getBoolType()), i8(M->getIntNType(8, true)), u8(M->getIntNType(8, false)), i16(M->getIntNType(16, true)), u16(M->getIntNType(16, false)), i32(M->getIntNType(32, true)), u32(M->getIntNType(32, false)), i64(M->getIntType()), u64(M->getIntNType(64, false)), f16(M->getFloat16Type()), f32(M->getFloat32Type()), f64(M->getFloatType()), c64(M->getType(ast::getMangledClass("std.internal.types.complex", "complex64"))), c128(M->getType(ast::getMangledClass("std.internal.types.complex", "complex"))) {} NumPyType::NumPyType(Type dtype, int64_t ndim) : dtype(dtype), ndim(ndim) { seqassertn(ndim >= 0, "ndim must be non-negative"); } NumPyType::NumPyType() : NumPyType(NP_TYPE_NONE) {} NumPyType NumPyType::get(types::Type *t, NumPyPrimitiveTypes &T) { if (t->is(T.bool_)) return {NumPyType::NP_TYPE_BOOL}; if (t->is(T.i8)) return {NumPyType::NP_TYPE_I8}; if (t->is(T.u8)) return {NumPyType::NP_TYPE_U8}; if (t->is(T.i16)) return {NumPyType::NP_TYPE_I16}; if (t->is(T.u16)) return {NumPyType::NP_TYPE_U16}; if (t->is(T.i32)) return {NumPyType::NP_TYPE_I32}; if (t->is(T.u32)) return {NumPyType::NP_TYPE_U32}; if (t->is(T.i64)) return {NumPyType::NP_TYPE_I64}; if (t->is(T.u64)) return {NumPyType::NP_TYPE_U64}; if (t->is(T.f16)) return {NumPyType::NP_TYPE_F16}; if (t->is(T.f32)) return {NumPyType::NP_TYPE_F32}; if (t->is(T.f64)) return {NumPyType::NP_TYPE_F64}; if (t->is(T.c64)) return {NumPyType::NP_TYPE_C64}; if (t->is(T.c128)) return {NumPyType::NP_TYPE_C128}; if (isArrayType(t)) { auto generics = t->getGenerics(); seqassertn(generics.size() == 2 && generics[0].isType() && generics[1].isStatic(), "unrecognized ndarray generics"); auto *dtype = generics[0].getTypeValue(); auto ndim = generics[1].getStaticValue(); if (dtype->is(T.bool_)) return {NumPyType::NP_TYPE_ARR_BOOL, ndim}; if (dtype->is(T.i8)) return {NumPyType::NP_TYPE_ARR_I8, ndim}; if (dtype->is(T.u8)) return 
{NumPyType::NP_TYPE_ARR_U8, ndim}; if (dtype->is(T.i16)) return {NumPyType::NP_TYPE_ARR_I16, ndim}; if (dtype->is(T.u16)) return {NumPyType::NP_TYPE_ARR_U16, ndim}; if (dtype->is(T.i32)) return {NumPyType::NP_TYPE_ARR_I32, ndim}; if (dtype->is(T.u32)) return {NumPyType::NP_TYPE_ARR_U32, ndim}; if (dtype->is(T.i64)) return {NumPyType::NP_TYPE_ARR_I64, ndim}; if (dtype->is(T.u64)) return {NumPyType::NP_TYPE_ARR_U64, ndim}; if (dtype->is(T.f16)) return {NumPyType::NP_TYPE_ARR_F16, ndim}; if (dtype->is(T.f32)) return {NumPyType::NP_TYPE_ARR_F32, ndim}; if (dtype->is(T.f64)) return {NumPyType::NP_TYPE_ARR_F64, ndim}; if (dtype->is(T.c64)) return {NumPyType::NP_TYPE_ARR_C64, ndim}; if (dtype->is(T.c128)) return {NumPyType::NP_TYPE_ARR_C128, ndim}; } return {}; } types::Type *NumPyType::getIRBaseType(NumPyPrimitiveTypes &T) const { switch (dtype) { case NP_TYPE_NONE: seqassertn(false, "unexpected type code (NONE)"); return nullptr; case NP_TYPE_BOOL: return T.bool_; case NP_TYPE_I8: return T.i8; case NP_TYPE_U8: return T.u8; case NP_TYPE_I16: return T.i16; case NP_TYPE_U16: return T.u16; case NP_TYPE_I32: return T.i32; case NP_TYPE_U32: return T.u32; case NP_TYPE_I64: return T.i64; case NP_TYPE_U64: return T.u64; case NP_TYPE_F16: return T.f16; case NP_TYPE_F32: return T.f32; case NP_TYPE_F64: return T.f64; case NP_TYPE_C64: return T.c64; case NP_TYPE_C128: return T.c128; case NP_TYPE_SCALAR_END: seqassertn(false, "unexpected type code (SCALAR_END)"); return nullptr; case NP_TYPE_ARR_BOOL: return T.bool_; case NP_TYPE_ARR_I8: return T.i8; case NP_TYPE_ARR_U8: return T.u8; case NP_TYPE_ARR_I16: return T.i16; case NP_TYPE_ARR_U16: return T.u16; case NP_TYPE_ARR_I32: return T.i32; case NP_TYPE_ARR_U32: return T.u32; case NP_TYPE_ARR_I64: return T.i64; case NP_TYPE_ARR_U64: return T.u64; case NP_TYPE_ARR_F16: return T.f16; case NP_TYPE_ARR_F32: return T.f32; case NP_TYPE_ARR_F64: return T.f64; case NP_TYPE_ARR_C64: return T.c64; case NP_TYPE_ARR_C128: return T.c128; default: 
seqassertn(false, "unexpected type code (?)"); return nullptr; } } std::ostream &operator<<(std::ostream &os, NumPyType const &type) { static const std::unordered_map typestrings = { {NumPyType::NP_TYPE_NONE, "none"}, {NumPyType::NP_TYPE_BOOL, "bool"}, {NumPyType::NP_TYPE_I8, "i8"}, {NumPyType::NP_TYPE_U8, "u8"}, {NumPyType::NP_TYPE_I16, "i16"}, {NumPyType::NP_TYPE_U16, "u16"}, {NumPyType::NP_TYPE_I32, "i32"}, {NumPyType::NP_TYPE_U32, "u32"}, {NumPyType::NP_TYPE_I64, "i64"}, {NumPyType::NP_TYPE_U64, "u64"}, {NumPyType::NP_TYPE_F16, "f16"}, {NumPyType::NP_TYPE_F32, "f32"}, {NumPyType::NP_TYPE_F64, "f64"}, {NumPyType::NP_TYPE_C64, "c64"}, {NumPyType::NP_TYPE_C128, "c128"}, {NumPyType::NP_TYPE_SCALAR_END, ""}, {NumPyType::NP_TYPE_ARR_BOOL, "bool"}, {NumPyType::NP_TYPE_ARR_I8, "i8"}, {NumPyType::NP_TYPE_ARR_U8, "u8"}, {NumPyType::NP_TYPE_ARR_I16, "i16"}, {NumPyType::NP_TYPE_ARR_U16, "u16"}, {NumPyType::NP_TYPE_ARR_I32, "i32"}, {NumPyType::NP_TYPE_ARR_U32, "u32"}, {NumPyType::NP_TYPE_ARR_I64, "i64"}, {NumPyType::NP_TYPE_ARR_U64, "u64"}, {NumPyType::NP_TYPE_ARR_F16, "f16"}, {NumPyType::NP_TYPE_ARR_F32, "f32"}, {NumPyType::NP_TYPE_ARR_F64, "f64"}, {NumPyType::NP_TYPE_ARR_C64, "c64"}, {NumPyType::NP_TYPE_ARR_C128, "c128"}, }; auto it = typestrings.find(type.dtype); seqassertn(it != typestrings.end(), "type not found"); auto s = it->second; if (type.isArray()) os << "array[" << s << ", " << type.ndim << "]"; else os << s; return os; } std::string NumPyType::str() const { std::stringstream buffer; buffer << *this; return buffer.str(); } CodegenContext::CodegenContext(Module *M, SeriesFlow *series, BodiedFunc *func, NumPyPrimitiveTypes &T) : M(M), series(series), func(func), vars(), T(T) {} std::unique_ptr parse(Value *v, std::vector> &leaves, NumPyPrimitiveTypes &T) { struct NumPyMagicMethod { std::string name; NumPyExpr::Op op; int args; bool right; }; struct NumPyUFunc { std::string name; NumPyExpr::Op op; int args; }; static std::vector magics = { {Module::POS_MAGIC_NAME, 
NumPyExpr::NP_OP_POS, 1, false}, {Module::NEG_MAGIC_NAME, NumPyExpr::NP_OP_NEG, 1, false}, {Module::INVERT_MAGIC_NAME, NumPyExpr::NP_OP_INVERT, 1, false}, {Module::ABS_MAGIC_NAME, NumPyExpr::NP_OP_ABS, 1, false}, {Module::ADD_MAGIC_NAME, NumPyExpr::NP_OP_ADD, 2, false}, {Module::SUB_MAGIC_NAME, NumPyExpr::NP_OP_SUB, 2, false}, {Module::MUL_MAGIC_NAME, NumPyExpr::NP_OP_MUL, 2, false}, {Module::MATMUL_MAGIC_NAME, NumPyExpr::NP_OP_MATMUL, 2, false}, {Module::TRUE_DIV_MAGIC_NAME, NumPyExpr::NP_OP_TRUE_DIV, 2, false}, {Module::FLOOR_DIV_MAGIC_NAME, NumPyExpr::NP_OP_FLOOR_DIV, 2, false}, {Module::MOD_MAGIC_NAME, NumPyExpr::NP_OP_MOD, 2, false}, {Module::POW_MAGIC_NAME, NumPyExpr::NP_OP_POW, 2, false}, {Module::LSHIFT_MAGIC_NAME, NumPyExpr::NP_OP_LSHIFT, 2, false}, {Module::RSHIFT_MAGIC_NAME, NumPyExpr::NP_OP_RSHIFT, 2, false}, {Module::AND_MAGIC_NAME, NumPyExpr::NP_OP_AND, 2, false}, {Module::OR_MAGIC_NAME, NumPyExpr::NP_OP_OR, 2, false}, {Module::XOR_MAGIC_NAME, NumPyExpr::NP_OP_XOR, 2, false}, {Module::RADD_MAGIC_NAME, NumPyExpr::NP_OP_ADD, 2, true}, {Module::RSUB_MAGIC_NAME, NumPyExpr::NP_OP_SUB, 2, true}, {Module::RMUL_MAGIC_NAME, NumPyExpr::NP_OP_MUL, 2, true}, {Module::RMATMUL_MAGIC_NAME, NumPyExpr::NP_OP_MATMUL, 2, true}, {Module::RTRUE_DIV_MAGIC_NAME, NumPyExpr::NP_OP_TRUE_DIV, 2, true}, {Module::RFLOOR_DIV_MAGIC_NAME, NumPyExpr::NP_OP_FLOOR_DIV, 2, true}, {Module::RMOD_MAGIC_NAME, NumPyExpr::NP_OP_MOD, 2, true}, {Module::RPOW_MAGIC_NAME, NumPyExpr::NP_OP_POW, 2, true}, {Module::RLSHIFT_MAGIC_NAME, NumPyExpr::NP_OP_LSHIFT, 2, true}, {Module::RRSHIFT_MAGIC_NAME, NumPyExpr::NP_OP_RSHIFT, 2, true}, {Module::RAND_MAGIC_NAME, NumPyExpr::NP_OP_AND, 2, true}, {Module::ROR_MAGIC_NAME, NumPyExpr::NP_OP_OR, 2, true}, {Module::RXOR_MAGIC_NAME, NumPyExpr::NP_OP_XOR, 2, true}, {Module::EQ_MAGIC_NAME, NumPyExpr::NP_OP_EQ, 2, false}, {Module::NE_MAGIC_NAME, NumPyExpr::NP_OP_NE, 2, false}, {Module::LT_MAGIC_NAME, NumPyExpr::NP_OP_LT, 2, false}, {Module::LE_MAGIC_NAME, 
NumPyExpr::NP_OP_LE, 2, false}, {Module::GT_MAGIC_NAME, NumPyExpr::NP_OP_GT, 2, false}, {Module::GE_MAGIC_NAME, NumPyExpr::NP_OP_GE, 2, false}, }; static std::vector ufuncs = { {"positive", NumPyExpr::NP_OP_POS, 1}, {"negative", NumPyExpr::NP_OP_NEG, 1}, {"invert", NumPyExpr::NP_OP_INVERT, 1}, {"abs", NumPyExpr::NP_OP_ABS, 1}, {"absolute", NumPyExpr::NP_OP_ABS, 1}, {"add", NumPyExpr::NP_OP_ADD, 2}, {"subtract", NumPyExpr::NP_OP_SUB, 2}, {"multiply", NumPyExpr::NP_OP_MUL, 2}, {"divide", NumPyExpr::NP_OP_TRUE_DIV, 2}, {"floor_divide", NumPyExpr::NP_OP_FLOOR_DIV, 2}, {"remainder", NumPyExpr::NP_OP_MOD, 2}, {"fmod", NumPyExpr::NP_OP_FMOD, 2}, {"power", NumPyExpr::NP_OP_POW, 2}, {"left_shift", NumPyExpr::NP_OP_LSHIFT, 2}, {"right_shift", NumPyExpr::NP_OP_RSHIFT, 2}, {"bitwise_and", NumPyExpr::NP_OP_AND, 2}, {"bitwise_or", NumPyExpr::NP_OP_OR, 2}, {"bitwise_xor", NumPyExpr::NP_OP_XOR, 2}, {"logical_and", NumPyExpr::NP_OP_LOGICAL_AND, 2}, {"logical_or", NumPyExpr::NP_OP_LOGICAL_OR, 2}, {"logical_xor", NumPyExpr::NP_OP_LOGICAL_XOR, 2}, {"equal", NumPyExpr::NP_OP_EQ, 2}, {"not_equal", NumPyExpr::NP_OP_NE, 2}, {"less", NumPyExpr::NP_OP_LT, 2}, {"less_equal", NumPyExpr::NP_OP_LE, 2}, {"greater", NumPyExpr::NP_OP_GT, 2}, {"greater_equal", NumPyExpr::NP_OP_GE, 2}, {"minimum", NumPyExpr::NP_OP_MIN, 2}, {"maximum", NumPyExpr::NP_OP_MAX, 2}, {"fmin", NumPyExpr::NP_OP_FMIN, 2}, {"fmax", NumPyExpr::NP_OP_FMAX, 2}, {"sin", NumPyExpr::NP_OP_SIN, 1}, {"cos", NumPyExpr::NP_OP_COS, 1}, {"tan", NumPyExpr::NP_OP_TAN, 1}, {"arcsin", NumPyExpr::NP_OP_ARCSIN, 1}, {"arccos", NumPyExpr::NP_OP_ARCCOS, 1}, {"arctan", NumPyExpr::NP_OP_ARCTAN, 1}, {"arctan2", NumPyExpr::NP_OP_ARCTAN2, 2}, {"hypot", NumPyExpr::NP_OP_HYPOT, 2}, {"sinh", NumPyExpr::NP_OP_SINH, 1}, {"cosh", NumPyExpr::NP_OP_COSH, 1}, {"tanh", NumPyExpr::NP_OP_TANH, 1}, {"arcsinh", NumPyExpr::NP_OP_ARCSINH, 1}, {"arccosh", NumPyExpr::NP_OP_ARCCOSH, 1}, {"arctanh", NumPyExpr::NP_OP_ARCTANH, 1}, {"conjugate", NumPyExpr::NP_OP_CONJ, 1}, 
{"exp", NumPyExpr::NP_OP_EXP, 1}, {"exp2", NumPyExpr::NP_OP_EXP2, 1}, {"log", NumPyExpr::NP_OP_LOG, 1}, {"log2", NumPyExpr::NP_OP_LOG2, 1}, {"log10", NumPyExpr::NP_OP_LOG10, 1}, {"expm1", NumPyExpr::NP_OP_EXPM1, 1}, {"log1p", NumPyExpr::NP_OP_LOG1P, 1}, {"sqrt", NumPyExpr::NP_OP_SQRT, 1}, {"square", NumPyExpr::NP_OP_SQUARE, 1}, {"cbrt", NumPyExpr::NP_OP_CBRT, 1}, {"logaddexp", NumPyExpr::NP_OP_LOGADDEXP, 2}, {"logaddexp2", NumPyExpr::NP_OP_LOGADDEXP2, 2}, {"reciprocal", NumPyExpr::NP_OP_RECIPROCAL, 1}, {"rint", NumPyExpr::NP_OP_RINT, 1}, {"floor", NumPyExpr::NP_OP_FLOOR, 1}, {"ceil", NumPyExpr::NP_OP_CEIL, 1}, {"trunc", NumPyExpr::NP_OP_TRUNC, 1}, {"isnan", NumPyExpr::NP_OP_ISNAN, 1}, {"isinf", NumPyExpr::NP_OP_ISINF, 1}, {"isfinite", NumPyExpr::NP_OP_ISFINITE, 1}, {"sign", NumPyExpr::NP_OP_SIGN, 1}, {"signbit", NumPyExpr::NP_OP_SIGNBIT, 1}, {"copysign", NumPyExpr::NP_OP_COPYSIGN, 2}, {"spacing", NumPyExpr::NP_OP_SPACING, 1}, {"nextafter", NumPyExpr::NP_OP_NEXTAFTER, 2}, {"deg2rad", NumPyExpr::NP_OP_DEG2RAD, 1}, {"radians", NumPyExpr::NP_OP_DEG2RAD, 1}, {"rad2deg", NumPyExpr::NP_OP_RAD2DEG, 1}, {"degrees", NumPyExpr::NP_OP_RAD2DEG, 1}, {"heaviside", NumPyExpr::NP_OP_HEAVISIDE, 2}, }; auto getNumPyExprType = [](types::Type *t, NumPyPrimitiveTypes &T) -> NumPyType { if (t->is(T.bool_)) return {NumPyType::NP_TYPE_BOOL}; if (t->is(T.i8)) return {NumPyType::NP_TYPE_I8}; if (t->is(T.u8)) return {NumPyType::NP_TYPE_U8}; if (t->is(T.i16)) return {NumPyType::NP_TYPE_I16}; if (t->is(T.u16)) return {NumPyType::NP_TYPE_U16}; if (t->is(T.i32)) return {NumPyType::NP_TYPE_I32}; if (t->is(T.u32)) return {NumPyType::NP_TYPE_U32}; if (t->is(T.i64)) return {NumPyType::NP_TYPE_I64}; if (t->is(T.u64)) return {NumPyType::NP_TYPE_U64}; if (t->is(T.f16)) return {NumPyType::NP_TYPE_F16}; if (t->is(T.f32)) return {NumPyType::NP_TYPE_F32}; if (t->is(T.f64)) return {NumPyType::NP_TYPE_F64}; if (t->is(T.c64)) return {NumPyType::NP_TYPE_C64}; if (t->is(T.c128)) return {NumPyType::NP_TYPE_C128}; 
if (isArrayType(t)) { auto generics = t->getGenerics(); seqassertn(generics.size() == 2 && generics[0].isType() && generics[1].isStatic(), "unrecognized ndarray generics"); auto *dtype = generics[0].getTypeValue(); auto ndim = generics[1].getStaticValue(); if (dtype->is(T.bool_)) return {NumPyType::NP_TYPE_ARR_BOOL, ndim}; if (dtype->is(T.i8)) return {NumPyType::NP_TYPE_ARR_I8, ndim}; if (dtype->is(T.u8)) return {NumPyType::NP_TYPE_ARR_U8, ndim}; if (dtype->is(T.i16)) return {NumPyType::NP_TYPE_ARR_I16, ndim}; if (dtype->is(T.u16)) return {NumPyType::NP_TYPE_ARR_U16, ndim}; if (dtype->is(T.i32)) return {NumPyType::NP_TYPE_ARR_I32, ndim}; if (dtype->is(T.u32)) return {NumPyType::NP_TYPE_ARR_U32, ndim}; if (dtype->is(T.i64)) return {NumPyType::NP_TYPE_ARR_I64, ndim}; if (dtype->is(T.u64)) return {NumPyType::NP_TYPE_ARR_U64, ndim}; if (dtype->is(T.f16)) return {NumPyType::NP_TYPE_ARR_F16, ndim}; if (dtype->is(T.f32)) return {NumPyType::NP_TYPE_ARR_F32, ndim}; if (dtype->is(T.f64)) return {NumPyType::NP_TYPE_ARR_F64, ndim}; if (dtype->is(T.c64)) return {NumPyType::NP_TYPE_ARR_C64, ndim}; if (dtype->is(T.c128)) return {NumPyType::NP_TYPE_ARR_C128, ndim}; } return {}; }; auto type = getNumPyExprType(v->getType(), T); if (!type) return {}; // Don't break up expressions that result in scalars or 0-dim arrays since those // should only be computed once if (type.ndim == 0) { auto res = std::make_unique(type, v); leaves.emplace_back(res.get(), v); return std::move(res); } if (auto *c = cast(v)) { auto *f = util::getFunc(c->getCallee()); // Check for matmul if (f && c->numArgs() == 3 && isNoneType(c->back()->getType(), T) && (f->getName().rfind(ast::getMangledFunc("std.numpy.linalg_sym", "matmul") + "[", 0) == 0 || (f->getName().rfind(ast::getMangledFunc("std.numpy.linalg_sym", "dot") + "[]", 0) == 0 && type.ndim == 2))) { std::vector args(c->begin(), c->end()); auto op = NumPyExpr::NP_OP_MATMUL; auto lhs = parse(args[0], leaves, T); if (!lhs) return {}; auto rhs = 
parse(args[1], leaves, T); if (!rhs) return {}; return std::make_unique(type, v, op, std::move(lhs), std::move(rhs)); } // Check for builtin abs() if (f && c->numArgs() == 1 && (f->getName().rfind(ast::getMangledFunc("std.internal.builtin", "abs") + "[", 0) == 0)) { auto op = NumPyExpr::NP_OP_ABS; auto lhs = parse(c->front(), leaves, T); if (!lhs) return {}; return std::make_unique(type, v, op, std::move(lhs)); } // Check for transpose if (f && isArrayType(f->getParentType()) && c->numArgs() == 1 && f->getUnmangledName() == "T") { auto op = NumPyExpr::NP_OP_TRANSPOSE; auto lhs = parse(c->front(), leaves, T); if (!lhs) return {}; return std::make_unique(type, v, op, std::move(lhs)); } // Check for ufunc (e.g. "np.exp()") call if (f && f->getUnmangledName() == Module::CALL_MAGIC_NAME && isUFuncType(f->getParentType())) { auto ufuncGenerics = f->getParentType()->getGenerics(); seqassertn(!ufuncGenerics.empty() && ufuncGenerics[0].isStaticStr(), "unrecognized ufunc class generics"); auto ufunc = ufuncGenerics[0].getStaticStringValue(); auto callGenerics = f->getType()->getGenerics(); seqassertn(!callGenerics.empty() && callGenerics[0].isType(), "unrecognized ufunc call generics"); auto *dtype = callGenerics[0].getTypeValue(); if (dtype->is(T.none)) { for (auto &u : ufuncs) { if (u.name == ufunc) { seqassertn(u.args == 1 || u.args == 2, "unexpected number of arguments (ufunc)"); // Argument order: // - ufunc self // - operand 1 // - (if binary) operand 2 // - 'out' // - 'where' std::vector args(c->begin(), c->end()); seqassertn(args.size() == u.args + 3, "unexpected call of {}", u.name); auto *where = args[args.size() - 1]; auto *out = args[args.size() - 2]; if (auto *whereConst = cast(where)) { if (!whereConst->getVal()) break; } else { break; } if (!isNoneType(out->getType(), T)) break; auto op = u.op; auto lhs = parse(args[1], leaves, T); if (!lhs) return {}; if (u.args == 1) return std::make_unique(type, v, op, std::move(lhs)); auto rhs = parse(args[2], leaves, T); 
if (!rhs) return {}; return std::make_unique(type, v, op, std::move(lhs), std::move(rhs)); } } } } // Check for magic method call if (f && isArrayType(f->getParentType())) { for (auto &m : magics) { if (f->getUnmangledName() == m.name && c->numArgs() == m.args) { seqassertn(m.args == 1 || m.args == 2, "unexpected number of arguments (magic)"); std::vector args(c->begin(), c->end()); auto op = m.op; auto lhs = parse(args[0], leaves, T); if (!lhs) return {}; if (m.args == 1) return std::make_unique(type, v, op, std::move(lhs)); auto rhs = parse(args[1], leaves, T); if (!rhs) return {}; return m.right ? std::make_unique(type, v, op, std::move(rhs), std::move(lhs)) : std::make_unique(type, v, op, std::move(lhs), std::move(rhs)); } } } } // Check for right-hand-side magic method call // Right-hand-side magics (e.g. __radd__) are compiled into FlowInstr: // + // becomes: // { v1 = ; v2 = ; return rhs_class.__radd__(v2, v1) } // So we need to check for this to detect r-magics. if (auto *flow = cast(v)) { auto *series = cast(flow->getFlow()); auto *value = cast(flow->getValue()); auto *f = value ? util::getFunc(value->getCallee()) : nullptr; if (series && f && value->numArgs() == 2) { std::vector assignments(series->begin(), series->end()); auto *arg1 = value->front(); auto *arg2 = value->back(); auto *vv1 = cast(arg1); auto *vv2 = cast(arg2); auto *arg1Var = vv1 ? vv1->getVar() : nullptr; auto *arg2Var = vv2 ? 
vv2->getVar() : nullptr; for (auto &m : magics) { if (f->getUnmangledName() == m.name && value->numArgs() == m.args && m.right) { auto op = m.op; if (assignments.size() == 0) { // Case 1: Degenerate flow instruction return parse(value, leaves, T); } else if (assignments.size() == 1) { // Case 2: One var -- check if it's either of the r-magic operands auto *a1 = cast(assignments.front()); if (a1 && a1->getLhs() == arg1Var) { auto rhs = parse(a1->getRhs(), leaves, T); if (!rhs) return {}; auto lhs = parse(arg2, leaves, T); if (!lhs) return {}; return std::make_unique(type, v, op, std::move(lhs), std::move(rhs)); } else if (a1 && a1->getLhs() == arg2Var) { auto lhs = parse(a1->getRhs(), leaves, T); if (!lhs) return {}; auto rhs = parse(arg1, leaves, T); if (!rhs) return {}; return std::make_unique(type, v, op, std::move(lhs), std::move(rhs)); } } else if (assignments.size() == 2) { // Case 2: Two vars -- check both permutations auto *a1 = cast(assignments.front()); auto *a2 = cast(assignments.back()); if (a1 && a2 && a1->getLhs() == arg1Var && a2->getLhs() == arg2Var) { auto rhs = parse(a1->getRhs(), leaves, T); if (!rhs) return {}; auto lhs = parse(a2->getRhs(), leaves, T); if (!lhs) return {}; return std::make_unique(type, v, op, std::move(lhs), std::move(rhs)); } else if (a1 && a2 && a2->getLhs() == arg1Var && a1->getLhs() == arg2Var) { auto lhs = parse(a1->getRhs(), leaves, T); if (!lhs) return {}; auto rhs = parse(a2->getRhs(), leaves, T); if (!rhs) return {}; return std::make_unique(type, v, op, std::move(lhs), std::move(rhs)); } } break; } } } } auto res = std::make_unique(type, v); leaves.emplace_back(res.get(), v); return std::move(res); } namespace { Var *optimizeHelper(NumPyOptimizationUnit &unit, NumPyExpr *expr, CodegenContext &C) { auto *M = unit.value->getModule(); auto *series = C.series; // Remove some operations that cannot be done element-wise easily by optimizing them // separately, recursively. 
expr->apply([&](NumPyExpr &e) { if (!e.type.isArray()) return; if (e.op == NumPyExpr::NP_OP_TRANSPOSE) { auto *lv = optimizeHelper(unit, e.lhs.get(), C); auto *transposeFunc = M->getOrRealizeFunc("_transpose", {lv->getType()}, {}, FUSION_MODULE); seqassertn(transposeFunc, "transpose func not found"); auto *var = util::makeVar(util::call(transposeFunc, {M->Nr(lv)}), C.series, C.func); C.vars[&e] = var; NumPyExpr replacement(e.type, M->Nr(var)); replacement.freeable = e.lhs->freeable; e.replace(replacement); } if (e.op == NumPyExpr::NP_OP_MATMUL) { auto *lv = optimizeHelper(unit, e.lhs.get(), C); auto *rv = optimizeHelper(unit, e.rhs.get(), C); auto *matmulFunc = M->getOrRealizeFunc("_matmul", {lv->getType(), rv->getType()}, {}, FUSION_MODULE); seqassertn(matmulFunc, "matmul func not found"); auto *var = util::makeVar( util::call(matmulFunc, {M->Nr(lv), M->Nr(rv)}), C.series, C.func); C.vars[&e] = var; NumPyExpr replacement(e.type, M->Nr(var)); replacement.freeable = true; e.replace(replacement); } }); // Optimize the given expression bool changed; do { changed = false; expr->apply([&](NumPyExpr &e) { if (e.depth() <= 2) return; auto cost = e.cost(); auto bcinfo = e.getBroadcastInfo(); Var *result = nullptr; if (cost <= AlwaysFuseCostThreshold || (cost <= NeverFuseCostThreshold && bcinfo == BroadcastInfo::NO)) { // Don't care about broadcasting; just fuse. XLOG("-> static fuse:\n{}", e.str()); result = e.codegenFusedEval(C); } else if (cost <= NeverFuseCostThreshold && bcinfo != BroadcastInfo::YES) { // Check at runtime if we're broadcasting and fuse conditionally. 
XLOG("-> conditional fuse:\n{}", e.str()); auto *broadcasts = e.codegenBroadcasts(C); auto *seqtSeries = M->Nr(); auto *fuseSeries = M->Nr(); auto *branch = M->Nr(broadcasts, seqtSeries, fuseSeries); C.series = seqtSeries; auto *seqtResult = e.codegenSequentialEval(C); C.series = fuseSeries; auto *fuseResult = e.codegenFusedEval(C); seqassertn(seqtResult->getType()->is(fuseResult->getType()), "types are not the same: {} {}", seqtResult->getType()->getName(), fuseResult->getType()->getName()); result = M->Nr(seqtResult->getType(), false); unit.func->push_back(result); seqtSeries->push_back(M->Nr(result, M->Nr(seqtResult))); fuseSeries->push_back(M->Nr(result, M->Nr(fuseResult))); C.series = series; series->push_back(branch); } if (result) { NumPyExpr tmp(e.type, M->Nr(result)); e.replace(tmp); e.freeable = true; C.vars[&e] = result; changed = true; } }); } while (changed); XLOG("-> sequential eval:\n{}", expr->str()); return expr->codegenSequentialEval(C); } } // namespace bool NumPyOptimizationUnit::optimize(NumPyPrimitiveTypes &T) { if (!expr->type.isArray() || expr->depth() <= 2) return false; XLOG("Optimizing expression at {}\n{}", value->getSrcInfo(), expr->str()); auto *M = value->getModule(); auto *series = M->Nr(); CodegenContext C(M, series, func, T); util::CloneVisitor cv(M); for (auto &p : leaves) { auto *var = util::makeVar(cv.clone(p.second), series, func); C.vars.emplace(p.first, var); } auto *result = optimizeHelper(*this, expr.get(), C); auto *replacement = M->Nr(C.series, M->Nr(result)); value->replaceAll(replacement); return true; } struct ExtractArrayExpressions : public util::Operator { BodiedFunc *func; NumPyPrimitiveTypes types; std::vector exprs; std::unordered_set extracted; explicit ExtractArrayExpressions(BodiedFunc *func) : util::Operator(), func(func), types(func->getModule()), exprs(), extracted() {} void extract(Value *v, AssignInstr *assign = nullptr) { if (extracted.count(v->getId())) return; std::vector> leaves; auto expr = parse(v, 
leaves, types); if (expr) { int64_t numArrayNodes = 0; expr->apply([&](NumPyExpr &e) { if (e.type.isArray()) ++numArrayNodes; extracted.emplace(e.val->getId()); }); if (numArrayNodes > 0 && expr->depth() > 1) { exprs.push_back({v, func, std::move(expr), std::move(leaves), assign}); } } } void preHook(Node *n) override { if (auto *v = cast(n)) { extract(v->getRhs(), v->getLhs()->isGlobal() ? nullptr : v); } else if (auto *v = cast(n)) { extract(v); } } }; const std::string NumPyFusionPass::KEY = "core-numpy-fusion"; void NumPyFusionPass::visit(BodiedFunc *func) { ExtractArrayExpressions extractor(func); func->accept(extractor); if (extractor.exprs.empty()) return; auto *rdres = getAnalysisResult(reachingDefKey); auto it = rdres->results.find(func->getId()); if (it == rdres->results.end()) return; auto *rd = it->second.get(); auto *se = getAnalysisResult(sideEffectsKey); auto *cfg = rdres->cfgResult->graphs.find(func->getId())->second.get(); auto fwd = getForwardingDAGs(func, rd, cfg, se, extractor.exprs); for (auto &dag : fwd) { std::vector assignsToDelete; auto *e = doForwarding(dag, assignsToDelete); if (e->optimize(extractor.types)) { for (auto *a : assignsToDelete) a->replaceAll(func->getModule()->Nr()); } } } } // namespace numpy } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/numpy/numpy.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/analyze/dataflow/reaching.h" #include "codon/cir/analyze/module/global_vars.h" #include "codon/cir/analyze/module/side_effect.h" #include "codon/cir/transform/pass.h" #include "codon/cir/types/types.h" #include #include #include namespace codon { namespace ir { namespace transform { namespace numpy { extern const std::string FUSION_MODULE; /// NumPy operator fusion pass. 
class NumPyFusionPass : public OperatorPass {
private:
  /// Key of the reaching definition analysis
  std::string reachingDefKey;
  /// Key of the side effect analysis
  std::string sideEffectsKey;

public:
  static const std::string KEY;

  /// Constructs a NumPy fusion pass.
  /// @param reachingDefKey the reaching definition analysis' key
  /// @param sideEffectsKey side effect analysis' key
  NumPyFusionPass(const std::string &reachingDefKey, const std::string &sideEffectsKey)
      : OperatorPass(), reachingDefKey(reachingDefKey), sideEffectsKey(sideEffectsKey) {
  }

  /// @return the pass key, KEY
  std::string getKey() const override { return KEY; }
  void visit(BodiedFunc *f) override;
};

/// Bundle of IR type pointers, one member per primitive kind; built from a
/// module by the constructor (defined elsewhere).
struct NumPyPrimitiveTypes {
  types::Type *none;
  types::Type *optnone;
  types::Type *bool_;
  types::Type *i8;
  types::Type *u8;
  types::Type *i16;
  types::Type *u16;
  types::Type *i32;
  types::Type *u32;
  types::Type *i64;
  types::Type *u64;
  types::Type *f16;
  types::Type *f32;
  types::Type *f64;
  types::Type *c64;
  types::Type *c128;

  explicit NumPyPrimitiveTypes(Module *M);
};

/// Classification of a value as a scalar or array of a given element kind.
struct NumPyType {
  /// Enumerators before NP_TYPE_SCALAR_END are scalar kinds and those after it
  /// are array kinds; isArray() relies on this ordering.
  enum Type {
    NP_TYPE_NONE = -1,
    NP_TYPE_BOOL,
    NP_TYPE_I8,
    NP_TYPE_U8,
    NP_TYPE_I16,
    NP_TYPE_U16,
    NP_TYPE_I32,
    NP_TYPE_U32,
    NP_TYPE_I64,
    NP_TYPE_U64,
    NP_TYPE_F16,
    NP_TYPE_F32,
    NP_TYPE_F64,
    NP_TYPE_C64,
    NP_TYPE_C128,
    NP_TYPE_SCALAR_END, // separator value
    NP_TYPE_ARR_BOOL,
    NP_TYPE_ARR_I8,
    NP_TYPE_ARR_U8,
    NP_TYPE_ARR_I16,
    NP_TYPE_ARR_U16,
    NP_TYPE_ARR_I32,
    NP_TYPE_ARR_U32,
    NP_TYPE_ARR_I64,
    NP_TYPE_ARR_U64,
    NP_TYPE_ARR_F16,
    NP_TYPE_ARR_F32,
    NP_TYPE_ARR_F64,
    NP_TYPE_ARR_C64,
    NP_TYPE_ARR_C128,
  } dtype;
  /// presumably the number of array dimensions -- TODO confirm against the
  /// definition of NumPyType::get
  int64_t ndim;

  NumPyType(Type dtype, int64_t ndim = 0);
  NumPyType();

  static NumPyType get(types::Type *t, NumPyPrimitiveTypes &T);

  types::Type *getIRBaseType(NumPyPrimitiveTypes &T) const;

  /// True for every dtype except NP_TYPE_NONE.
  operator bool() const { return dtype != NP_TYPE_NONE; }
  /// True when dtype is one of the NP_TYPE_ARR_* enumerators.
  bool isArray() const { return dtype > NP_TYPE_SCALAR_END; }

  friend std::ostream &operator<<(std::ostream &os, NumPyType const &type);

  std::string str() const;
};

struct NumPyExpr;

/// State threaded through code generation: the module, the series flow that
/// generated code is added to, the enclosing function, a map named `vars`
/// (indexed by expression address at its use sites), and the primitive types.
struct CodegenContext {
  Module *M;
  SeriesFlow
*series;
  BodiedFunc *func;
  std::unordered_map vars;
  NumPyPrimitiveTypes &T;

  CodegenContext(Module *M, SeriesFlow *series, BodiedFunc *func,
                 NumPyPrimitiveTypes &T);
};

/// Result type of NumPyExpr::getBroadcastInfo().
enum BroadcastInfo {
  UNKNOWN,
  YES,
  NO,
  MAYBE,
};

/// Node of an expression tree: a typed IR value plus an operation and up to
/// two child expressions.
struct NumPyExpr {
  NumPyType type;
  Value *val;
  enum Op {
    NP_OP_NONE,
    NP_OP_POS,
    NP_OP_NEG,
    NP_OP_INVERT,
    NP_OP_ABS,
    NP_OP_TRANSPOSE,
    NP_OP_ADD,
    NP_OP_SUB,
    NP_OP_MUL,
    NP_OP_MATMUL,
    NP_OP_TRUE_DIV,
    NP_OP_FLOOR_DIV,
    NP_OP_MOD,
    NP_OP_FMOD,
    NP_OP_POW,
    NP_OP_LSHIFT,
    NP_OP_RSHIFT,
    NP_OP_AND,
    NP_OP_OR,
    NP_OP_XOR,
    NP_OP_LOGICAL_AND,
    NP_OP_LOGICAL_OR,
    NP_OP_LOGICAL_XOR,
    NP_OP_EQ,
    NP_OP_NE,
    NP_OP_LT,
    NP_OP_LE,
    NP_OP_GT,
    NP_OP_GE,
    NP_OP_MIN,
    NP_OP_MAX,
    NP_OP_FMIN,
    NP_OP_FMAX,
    NP_OP_SIN,
    NP_OP_COS,
    NP_OP_TAN,
    NP_OP_ARCSIN,
    NP_OP_ARCCOS,
    NP_OP_ARCTAN,
    NP_OP_ARCTAN2,
    NP_OP_HYPOT,
    NP_OP_SINH,
    NP_OP_COSH,
    NP_OP_TANH,
    NP_OP_ARCSINH,
    NP_OP_ARCCOSH,
    NP_OP_ARCTANH,
    NP_OP_CONJ,
    NP_OP_EXP,
    NP_OP_EXP2,
    NP_OP_LOG,
    NP_OP_LOG2,
    NP_OP_LOG10,
    NP_OP_EXPM1,
    NP_OP_LOG1P,
    NP_OP_SQRT,
    NP_OP_SQUARE,
    NP_OP_CBRT,
    NP_OP_LOGADDEXP,
    NP_OP_LOGADDEXP2,
    NP_OP_RECIPROCAL,
    NP_OP_RINT,
    NP_OP_FLOOR,
    NP_OP_CEIL,
    NP_OP_TRUNC,
    NP_OP_ISNAN,
    NP_OP_ISINF,
    NP_OP_ISFINITE,
    NP_OP_SIGN,
    NP_OP_SIGNBIT,
    NP_OP_COPYSIGN,
    NP_OP_SPACING,
    NP_OP_NEXTAFTER,
    NP_OP_DEG2RAD,
    NP_OP_RAD2DEG,
    NP_OP_HEAVISIDE,
  } op;
  std::unique_ptr lhs;
  std::unique_ptr rhs;
  /// Initialized to false by every constructor.
  bool freeable;

  /// Leaf node: no operation (NP_OP_NONE) and no children.
  NumPyExpr(NumPyType type, Value *val)
      : type(std::move(type)), val(val), op(NP_OP_NONE), lhs(), rhs(), freeable(false) {
  }
  /// Node with a single child stored in `lhs`.
  NumPyExpr(NumPyType type, Value *val, NumPyExpr::Op op, std::unique_ptr lhs)
      : type(std::move(type)), val(val), op(op), lhs(std::move(lhs)), rhs(),
        freeable(false) {}
  /// Node with two children.
  NumPyExpr(NumPyType type, Value *val, NumPyExpr::Op op, std::unique_ptr lhs,
            std::unique_ptr rhs)
      : type(std::move(type)), val(val), op(op), lhs(std::move(lhs)),
        rhs(std::move(rhs)), freeable(false) {}

  static std::unique_ptr parse(Value *v, std::vector> &leaves,
                               NumPyPrimitiveTypes &T);

  void replace(NumPyExpr &e);
  bool haveVectorizedLoop() const;

  int64_t opcost() const;
  int64_t cost() const;

  std::string opstring() const;
  void dump(std::ostream &os, int level, int &leafId) const;
  friend std::ostream &operator<<(std::ostream &os, NumPyExpr const &expr);
  std::string str() const;

  /// True when the node has neither child.
  bool isLeaf() const { return !lhs && !rhs; }

  /// Height of the tree rooted here; a leaf has depth 1.
  int depth() const {
    return std::max(lhs ? lhs->depth() : 0, rhs ? rhs->depth() : 0) + 1;
  }

  /// Number of nodes in the tree rooted here, including this one.
  int nodes() const { return (lhs ? lhs->nodes() : 0) + (rhs ? rhs->nodes() : 0) + 1; }

  void apply(std::function f);

  Value *codegenBroadcasts(CodegenContext &C);

  Var *codegenFusedEval(CodegenContext &C);

  Var *codegenSequentialEval(CodegenContext &C);

  BroadcastInfo getBroadcastInfo();

  Value *codegenScalarExpr(CodegenContext &C, const std::unordered_map &args,
                           const std::unordered_map &scalarMap, Var *scalars);
};

std::unique_ptr parse(Value *v, std::vector> &leaves, NumPyPrimitiveTypes &T);

/// One candidate for optimization: an IR value together with its parsed
/// expression tree and leaves.
struct NumPyOptimizationUnit {
  /// Original IR value corresponding to the expression
  Value *value;
  /// Function in which the value exists
  BodiedFunc *func;
  /// Root expression
  std::unique_ptr expr;
  /// Leaves ordered by execution in original expression
  std::vector> leaves;
  /// AssignInstr in which RHS is represented by this expression, or null if none
  AssignInstr *assign;

  bool optimize(NumPyPrimitiveTypes &T);
};

/// Edge record linking a source unit to a destination unit.
/// NOTE(review): field meanings are not visible here; confirm against the
/// definitions of doForwarding / getForwardingDAGs.
struct Forwarding {
  NumPyOptimizationUnit *dst;
  NumPyOptimizationUnit *src;
  Var *var;
  NumPyExpr *dstLeaf;
  int64_t dstId;
  int64_t srcId;
};

using ForwardingDAG = std::unordered_map>;

NumPyOptimizationUnit *doForwarding(ForwardingDAG &dag, std::vector &assignsToDelete);

std::vector getForwardingDAGs(BodiedFunc *func, analyze::dataflow::RDInspector *rd,
                              analyze::dataflow::CFGraph *cfg,
                              analyze::module::SideEffectResult *se,
                              std::vector &exprs);

} // namespace numpy
} // namespace transform
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/transform/parallel/openmp.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include "openmp.h"

// NOTE(review): the four directives below carry no header name; the names
// appear to have been lost in extraction -- restore from the upstream file.
#include
#include
#include
#include

#include "codon/cir/transform/parallel/schedule.h"
#include "codon/cir/util/cloning.h"
#include "codon/cir/util/irtools.h"
#include "codon/cir/util/outlining.h"

namespace codon {
namespace ir {
namespace transform {
namespace parallel {
namespace {
// Names of the modules in which helper functions and types are looked up.
const std::string ompModule = "std.openmp";
const std::string gpuModule = "std.internal.gpu";
const std::string builtinModule = "std.internal.builtin";

/// Emits a compilation warning located at the source position of `v`.
void warn(const std::string &msg, const Value *v) {
  auto src = v->getSrcInfo();
  compilationWarning(msg, src.file, src.line, src.col);
}

/// IR types used by the OpenMP lowering, resolved once from a module.
struct OMPTypes {
  types::Type *i64 = nullptr;
  types::Type *i32 = nullptr;
  types::Type *i8ptr = nullptr;
  types::Type *i32ptr = nullptr;

  explicit OMPTypes(Module *M) {
    i64 = M->getIntType();
    i32 = M->getIntNType(32, /*sign=*/true);
    i8ptr = M->getPointerType(M->getByteType());
    i32ptr = M->getPointerType(i32);
  }
};

/// Returns the variable behind an outlined-call argument. Two value kinds are
/// accepted (the cast targets are not visible in this text); anything else
/// trips the assertion, after which nullptr is returned.
Var *getVarFromOutlinedArg(Value *arg) {
  if (auto *val = cast(arg)) {
    return val->getVar();
  } else if (auto *val = cast(arg)) {
    return val->getVar();
  } else {
    seqassertn(false, "unknown outline var");
  }
  return nullptr;
}

/// Builds a call to the `__raw__` method of the function's type, passing the
/// function itself; asserts that the method exists.
Value *ptrFromFunc(Func *func) {
  auto *M = func->getModule();
  auto *funcType = func->getType();
  auto *rawMethod = M->getOrRealizeMethod(funcType, "__raw__", {funcType});
  seqassertn(rawMethod, "cannot find function __raw__ method");
  return util::call(rawMethod, {M->Nr(func)});
}

// we create the locks lazily to avoid them when they're not needed
struct ReductionLocks {
  Var *mainLock =
      nullptr; // lock used in calls to _reduce_no_wait and _end_reduce_no_wait
  Var *critLock = nullptr; // lock used in reduction critical sections

  /// Creates a new global variable of the `Lock` type from ompModule, names it
  /// ".omp_lock.<n>" using a function-local static counter, and inserts its
  /// initialization at the start of the main function's body.
  Var *createLock(Module *M) {
    auto *lockType = M->getOrRealizeType(ast::getMangledClass(ompModule, "Lock"));
    seqassertn(lockType, "openmp.Lock type not found");
    auto *var = M->Nr(lockType, /*global=*/true);
    static int counter = 1;
    var->setName(".omp_lock."
+ std::to_string(counter++)); // add it to main function so it doesn't get demoted by IR pass auto *series = cast(cast(M->getMainFunc())->getBody()); auto *init = (*lockType)(); seqassertn(init, "could not initialize openmp.Lock"); series->insert(series->begin(), M->Nr(var, init)); return var; } Var *getMainLock(Module *M) { if (!mainLock) mainLock = createLock(M); return mainLock; } Var *getCritLock(Module *M) { if (!critLock) critLock = createLock(M); return critLock; } }; struct Reduction { enum Kind { NONE, ADD, MUL, AND, OR, XOR, MIN, MAX, }; Kind kind = Kind::NONE; Var *shared = nullptr; types::Type *getType() { auto *ptrType = cast(shared->getType()); seqassertn(ptrType, "expected shared var to be of pointer type"); return ptrType->getBase(); } Value *getInitial() { if (!*this) return nullptr; auto *M = shared->getModule(); auto *type = getType(); if (isA(type)) { switch (kind) { case Kind::ADD: return M->getInt(0); case Kind::MUL: return M->getInt(1); case Kind::AND: return M->getInt(~0); case Kind::OR: return M->getInt(0); case Kind::XOR: return M->getInt(0); case Kind::MIN: return M->getInt(std::numeric_limits::max()); case Kind::MAX: return M->getInt(std::numeric_limits::min()); default: return nullptr; } } else if (isA(type)) { switch (kind) { case Kind::ADD: return M->getFloat(0.); case Kind::MUL: return M->getFloat(1.); case Kind::MIN: return M->getFloat(std::numeric_limits::max()); case Kind::MAX: return M->getFloat(std::numeric_limits::min()); default: return nullptr; } } else if (isA(type)) { auto *f32 = M->getOrRealizeType("float32"); float value = 0.0; switch (kind) { case Kind::ADD: value = 0.0; break; case Kind::MUL: value = 1.0; break; case Kind::MIN: value = std::numeric_limits::max(); break; case Kind::MAX: value = std::numeric_limits::min(); break; default: return nullptr; } return (*f32)(*M->getFloat(value)); } auto *init = (*type)(); if (!init || !init->getType()->is(type)) return nullptr; return init; } Value 
*generateNonAtomicReduction(Value *ptr, Value *arg) {
    // Builds IR for `*ptr = *ptr <op> arg` without synchronization. ADD, MUL,
    // AND, OR and XOR use the corresponding operator; MIN and MAX call the
    // "min"/"max" function from builtinModule. Returns nullptr for any other
    // kind, otherwise the store of the combined value through `ptr`.
    auto *M = ptr->getModule();
    Value *lhs = util::ptrLoad(ptr);
    Value *result = nullptr;
    switch (kind) {
    case Kind::ADD:
      result = *lhs + *arg;
      break;
    case Kind::MUL:
      result = *lhs * *arg;
      break;
    case Kind::AND:
      result = *lhs & *arg;
      break;
    case Kind::OR:
      result = *lhs | *arg;
      break;
    case Kind::XOR:
      result = *lhs ^ *arg;
      break;
    case Kind::MIN:
    case Kind::MAX: {
      // signature is (tuple of args, key, default)
      auto name = (kind == Kind::MIN ? "min" : "max");
      auto *tup = util::makeTuple({lhs, arg});
      auto *none = (*M->getNoneType())();
      auto *fn = M->getOrRealizeFunc(
          name, {tup->getType(), none->getType(), none->getType()}, {}, builtinModule);
      seqassertn(fn, "{} function not found", name);
      result = util::call(fn, {tup, none, none});
      break;
    }
    default:
      return nullptr;
    }
    return util::ptrStore(ptr, result);
  }

  /// Builds IR that applies `arg` to the value behind `ptr` safely with respect
  /// to concurrent updates. Strategies are tried in order:
  ///   1. a named `_atomic_*` function from ompModule, chosen by type branch
  ///      and kind;
  ///   2. an `__atomic_*__` method on the type of `arg`;
  ///   3. the non-atomic update wrapped between `_critical_begin` and
  ///      `_critical_end` calls on the critical-section lock.
  /// @param loc, gtid only needed for strategy 3; asserted non-null there
  /// @param locks provides the lazily created critical-section lock
  /// NOTE(review): the three `isA(type)` tests read identically in this text;
  /// the function names suggest int / float / float32 -- confirm upstream.
  Value *generateAtomicReduction(Value *ptr, Value *arg, Var *loc, Var *gtid,
                                 ReductionLocks &locks) {
    auto *M = ptr->getModule();
    auto *type = getType();
    std::string func = "";

    if (isA(type)) {
      switch (kind) {
      case Kind::ADD:
        func = "_atomic_int_add";
        break;
      case Kind::MUL:
        func = "_atomic_int_mul";
        break;
      case Kind::AND:
        func = "_atomic_int_and";
        break;
      case Kind::OR:
        func = "_atomic_int_or";
        break;
      case Kind::XOR:
        func = "_atomic_int_xor";
        break;
      case Kind::MIN:
        func = "_atomic_int_min";
        break;
      case Kind::MAX:
        func = "_atomic_int_max";
        break;
      default:
        break;
      }
    } else if (isA(type)) {
      switch (kind) {
      case Kind::ADD:
        func = "_atomic_float_add";
        break;
      case Kind::MUL:
        func = "_atomic_float_mul";
        break;
      case Kind::MIN:
        func = "_atomic_float_min";
        break;
      case Kind::MAX:
        func = "_atomic_float_max";
        break;
      default:
        break;
      }
    } else if (isA(type)) {
      switch (kind) {
      case Kind::ADD:
        func = "_atomic_float32_add";
        break;
      case Kind::MUL:
        func = "_atomic_float32_mul";
        break;
      case Kind::MIN:
        func = "_atomic_float32_min";
        break;
      case Kind::MAX:
        func = "_atomic_float32_max";
        break;
      default:
        break;
      }
    }

    // Strategy 1: a named atomic helper was selected above.
    if (!func.empty()) {
      auto *atomicOp
= M->getOrRealizeFunc(func, {ptr->getType(), arg->getType()}, {}, ompModule);
      seqassertn(atomicOp, "atomic op '{}' not found", func);
      return util::call(atomicOp, {ptr, arg});
    }

    // Strategy 2: look for an atomic magic method on the argument's type.
    switch (kind) {
    case Kind::ADD:
      func = "__atomic_add__";
      break;
    case Kind::MUL:
      func = "__atomic_mul__";
      break;
    case Kind::AND:
      func = "__atomic_and__";
      break;
    case Kind::OR:
      func = "__atomic_or__";
      break;
    case Kind::XOR:
      func = "__atomic_xor__";
      break;
    case Kind::MIN:
      func = "__atomic_min__";
      break;
    case Kind::MAX:
      func = "__atomic_max__";
      break;
    default:
      break;
    }

    if (!func.empty()) {
      auto *atomicOp = M->getOrRealizeMethod(arg->getType(), func,
                                             {ptr->getType(), arg->getType()});
      if (atomicOp)
        return util::call(atomicOp, {ptr, arg});
    }

    // Strategy 3: guard the plain update with a critical section.
    seqassertn(loc && gtid, "loc and/or gtid are null");
    auto *lck = locks.getCritLock(M);
    auto *lckPtrType = M->getPointerType(lck->getType());
    auto *critBegin = M->getOrRealizeFunc("_critical_begin",
                                          {loc->getType(), gtid->getType(), lckPtrType},
                                          {}, ompModule);
    seqassertn(critBegin, "critical begin function not found");
    auto *critEnd = M->getOrRealizeFunc(
        "_critical_end", {loc->getType(), gtid->getType(), lckPtrType}, {}, ompModule);
    seqassertn(critEnd, "critical end function not found");

    auto *critEnter =
        util::call(critBegin, {M->Nr(loc), M->Nr(gtid), M->Nr(lck)});
    auto *operation = generateNonAtomicReduction(ptr, arg);
    auto *critExit = util::call(critEnd, {M->Nr(loc), M->Nr(gtid), M->Nr(lck)});
    // make sure the unlock is in a finally-block
    return util::series(critEnter,
                        M->Nr(util::series(operation), util::series(critExit)));
  }

  /// True unless the kind is NONE.
  operator bool() const { return kind != Kind::NONE; }
};

/// Table entry describing one recognizable reduction operation: the callee
/// name, the reduction kind it maps to, and whether it is matched as a method.
struct ReductionFunction {
  std::string name;
  Reduction::Kind kind;
  bool method;
};

/// Operator that scans calls for reduction patterns on a set of shared
/// variables and records, per shared variable ID, the reduction found.
struct ReductionIdentifier : public util::Operator {
  std::vector shareds;
  Var *loopVarArg;
  std::unordered_map reductions;

  ReductionIdentifier()
      : util::Operator(), shareds(), loopVarArg(nullptr), reductions() {}

  ReductionIdentifier(std::vector shareds, Var *loopVarArg)
      : util::Operator(),
shareds(std::move(shareds)), loopVarArg(loopVarArg), reductions() {}

  /// True when `shared` is one of the tracked shared variables (matched by ID).
  /// The loop variable argument, if set, is never treated as shared.
  bool isShared(Var *shared) {
    if (loopVarArg && shared->getId() == loopVarArg->getId())
      return false;
    for (auto *v : shareds) {
      if (shared->getId() == v->getId())
        return true;
    }
    return false;
  }

  /// True when `v` is a method call of Module::GETITEM_MAGIC_NAME whose first
  /// argument is the variable `shared` and whose last argument is the constant
  /// 0, i.e. a read through the shared pointer at index 0.
  bool isSharedDeref(Var *shared, Value *v) {
    auto *M = v->getModule();
    auto *ptrType = cast(shared->getType());
    seqassertn(ptrType, "expected shared var to be of pointer type");
    auto *type = ptrType->getBase();

    if (util::isCallOf(v, Module::GETITEM_MAGIC_NAME, {ptrType, M->getIntType()}, type,
                       /*method=*/true)) {
      auto *call = cast(v);
      auto *var = util::getVar(call->front());
      return util::isConst(call->back(), 0) && var && var->getId() == shared->getId();
    }

    return false;
  }

  /// Flattens a tree of nested `op` method calls into its operands. A value is
  /// descended into when it is a call of `op` returning `type` with `type` as
  /// either the first or the second argument type; every other value is
  /// appended to `result`.
  static void extractAssociativeOpChain(Value *v, const std::string &op,
                                        types::Type *type, std::vector &result) {
    if (util::isCallOf(v, op, {type, nullptr}, type, /*method=*/true) ||
        util::isCallOf(v, op, {nullptr, type}, type, /*method=*/true)) {
      auto *call = cast(v);
      extractAssociativeOpChain(call->front(), op, type, result);
      extractAssociativeOpChain(call->back(), op, type, result);
    } else {
      result.push_back(v);
    }
  }

  /// Recognizes a reduction of the form `shared[0] = <op involving shared[0]>`.
  /// @return the matching Reduction, or an empty one when `v` is not a
  ///         three-argument Module::SETITEM_MAGIC_NAME call on a tracked shared
  ///         variable at constant index 0, when no table entry matches, when
  ///         the matched operation never reads the shared value, or when the
  ///         reduction has no initial value
  Reduction getReductionFromCall(CallInstr *v) {
    auto *M = v->getModule();
    auto *func = util::getFunc(v->getCallee());

    if (v->numArgs() != 3 || !func ||
        func->getUnmangledName() != Module::SETITEM_MAGIC_NAME)
      return {};

    std::vector args(v->begin(), v->end());
    Value *self = args[0];
    Value *idx = args[1];
    Value *item = args[2];

    Var *shared = util::getVar(self);
    if (!shared || !isShared(shared) || !util::isConst(idx, 0))
      return {};

    auto *ptrType = cast(shared->getType());
    seqassertn(ptrType, "expected shared var to be of pointer type");
    auto *type = ptrType->getBase();
    auto *noneType = M->getOptionalType(M->getNoneType());

    // double-check the call
    if (!util::isCallOf(v, Module::SETITEM_MAGIC_NAME,
                        {self->getType(), idx->getType(), item->getType()},
                        M->getNoneType(), /*method=*/true))
      return {};

    const std::vector reductionFunctions =
{
        {Module::ADD_MAGIC_NAME, Reduction::Kind::ADD, true},
        {Module::MUL_MAGIC_NAME, Reduction::Kind::MUL, true},
        {Module::AND_MAGIC_NAME, Reduction::Kind::AND, true},
        {Module::OR_MAGIC_NAME, Reduction::Kind::OR, true},
        {Module::XOR_MAGIC_NAME, Reduction::Kind::XOR, true},
        {"min", Reduction::Kind::MIN, false},
        {"max", Reduction::Kind::MAX, false},
    };

    for (auto &rf : reductionFunctions) {
      // Skip table entries whose call shape does not match the stored item.
      if (rf.method) {
        if (!(util::isCallOf(item, rf.name, {type, nullptr}, type, /*method=*/true) ||
              util::isCallOf(item, rf.name, {nullptr, type}, type, /*method=*/true)))
          continue;
      } else {
        if (!util::isCallOf(item, rf.name,
                            {M->getTupleType({type, type}), noneType, noneType}, type,
                            /*method=*/false))
          continue;
      }

      auto *callRHS = cast(item);
      Value *deref = nullptr;

      // Look for an operand that reads the shared value.
      if (rf.method) {
        std::vector opChain;
        extractAssociativeOpChain(callRHS, rf.name, type, opChain);
        if (opChain.size() < 2)
          continue;

        for (auto *val : opChain) {
          if (isSharedDeref(shared, val)) {
            deref = val;
            break;
          }
        }
      } else {
        callRHS = cast(callRHS->front()); // this will be Tuple.__new__
        if (!callRHS)
          continue;

        for (auto *val : *callRHS) {
          if (isSharedDeref(shared, val)) {
            deref = val;
            break;
          }
        }
      }

      // The first entry that matches the call shape decides the outcome.
      if (!deref)
        return {};

      Reduction reduction = {rf.kind, shared};
      if (!reduction.getInitial())
        return {};

      return reduction;
    }

    return {};
  }

  /// Returns the reduction recorded for `shared`, or an empty Reduction if none.
  Reduction getReduction(Var *shared) {
    auto it = reductions.find(shared->getId());
    return (it != reductions.end()) ?
it->second : Reduction();
  }

  /// Records the reduction detected in call `v`, if any, keyed by the shared
  /// variable's ID.
  void handle(CallInstr *v) override {
    if (auto reduction = getReductionFromCall(v)) {
      auto it = reductions.find(reduction.shared->getId());
      // if we've seen the var before, make sure it's consistent
      // otherwise mark as invalid via an empty reduction
      if (it == reductions.end()) {
        reductions.emplace(reduction.shared->getId(), reduction);
      } else if (it->second && it->second.kind != reduction.kind) {
        it->second = {};
      }
    }
  }
};

/// Bookkeeping for one shared variable of a parallel loop.
struct SharedInfo {
  unsigned memb;       // member index in template's `extra` arg
  Var *local;          // the local var we create to store current value
  Reduction reduction; // the reduction we're performing, or empty if none
};

/// Base for operators that rewrite stub calls inside a loop template.
struct LoopTemplateReplacer : public util::Operator {
  /// Function being rewritten
  BodiedFunc *parent;
  /// Call whose callee and arguments are substituted for the body stub
  CallInstr *replacement;
  /// Loop variable; arguments referring to it are handled specially
  Var *loopVar;

  LoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement, Var *loopVar)
      : util::Operator(), parent(parent), replacement(replacement), loopVar(loopVar) {}
};

/// Replacer that additionally tracks shared variables and generates the
/// reduction code for them.
struct ParallelLoopTemplateReplacer : public LoopTemplateReplacer {
  ReductionIdentifier *reds;
  std::vector sharedInfo;
  ReductionLocks locks;
  // The three variables below are captured from the "_loop_loc_and_gtid" stub.
  Var *locRef;
  Var *reductionLocRef;
  Var *gtid;

  ParallelLoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement,
                               Var *loopVar, ReductionIdentifier *reds)
      : LoopTemplateReplacer(parent, replacement, loopVar), reds(reds), sharedInfo(),
        locks(), locRef(nullptr), reductionLocRef(nullptr), gtid(nullptr) {}

  /// Number of entries in `sharedInfo` that carry a non-empty reduction.
  unsigned numReductions() {
    unsigned num = 0;
    for (auto &info : sharedInfo) {
      if (info.reduction)
        num += 1;
    }
    return num;
  }

  /// Builds a tuple holding one element per reduction, each created from the
  /// reduction's local variable, in `sharedInfo` order.
  Value *getReductionTuple() {
    auto *M = parent->getModule();
    std::vector elements;
    for (auto &info : sharedInfo) {
      if (info.reduction)
        elements.push_back(M->Nr(info.local));
    }
    return util::makeTuple(elements, M);
  }

  /// Creates the "__omp_reducer" function, which takes two pointers to the
  /// reduction tuple ("lhs", "rhs") and combines each element of rhs into lhs.
  BodiedFunc *makeReductionFunc() {
    auto *M = parent->getModule();
    auto *tupleType = getReductionTuple()->getType();
    auto *argType = M->getPointerType(tupleType);
    auto *funcType = M->getFuncType(M->getNoneType(), {argType, argType});
    auto *reducer =
M->Nr("__omp_reducer");
    reducer->realize(funcType, {"lhs", "rhs"});

    auto *lhsVar = reducer->arg_front();
    auto *rhsVar = reducer->arg_back();
    auto *body = M->Nr();
    // `next` indexes the reduction tuple, which only has reduction elements.
    unsigned next = 0;
    for (auto &info : sharedInfo) {
      if (info.reduction) {
        auto *lhs = util::ptrLoad(M->Nr(lhsVar));
        auto *rhs = util::ptrLoad(M->Nr(rhsVar));
        auto *lhsElem = util::tupleGet(lhs, next);
        auto *rhsElem = util::tupleGet(rhs, next);
        body->push_back(
            info.reduction.generateNonAtomicReduction(lhsElem, util::ptrLoad(rhsElem)));
        ++next;
      }
    }
    reducer->setBody(body);
    return reducer;
  }

  /// Rewrites two template stubs:
  ///  - "_loop_loc_and_gtid": remembers the three variables passed to it;
  ///  - "_loop_reductions": replaced by code that calls `_reduce_nowait` and,
  ///    depending on the returned code, performs the reductions non-atomically
  ///    (code 1, followed by `_end_reduce_nowait`) or atomically (code 2).
  ///    Left untouched when there are no reductions.
  void handle(CallInstr *v) override {
    auto *M = v->getModule();
    auto *func = util::getFunc(v->getCallee());
    if (!func)
      return;
    auto name = func->getUnmangledName();

    if (name == "_loop_loc_and_gtid") {
      seqassertn(v->numArgs() == 3 &&
                     std::all_of(v->begin(), v->end(), [](auto x) { return isA(x); }),
                 "unexpected loop loc and gtid stub");
      std::vector args(v->begin(), v->end());
      locRef = util::getVar(args[0]);
      reductionLocRef = util::getVar(args[1]);
      gtid = util::getVar(args[2]);
    }

    if (name == "_loop_reductions") {
      // Relies on "_loop_loc_and_gtid" having been visited first.
      seqassertn(reductionLocRef && gtid, "bad visit order in template");
      // NOTE(review): this message mentions "shared updates" although the stub
      // checked here is "_loop_reductions" -- looks copy-pasted; confirm.
      seqassertn(v->numArgs() == 1 && isA(v->front()),
                 "unexpected shared updates stub");
      if (numReductions() == 0)
        return;

      // NOTE(review): this `M` shadows the one declared at function scope.
      auto *M = parent->getModule();
      auto *extras = util::getVar(v->front());
      auto *reductionTuple = getReductionTuple();
      auto *reducer = makeReductionFunc();
      auto *lck = locks.getMainLock(M);
      auto *rawReducer = ptrFromFunc(reducer);

      auto *lckPtrType = M->getPointerType(lck->getType());
      auto *reduceNoWait = M->getOrRealizeFunc(
          "_reduce_nowait",
          {reductionLocRef->getType(), gtid->getType(), reductionTuple->getType(),
           rawReducer->getType(), lckPtrType},
          {}, ompModule);
      seqassertn(reduceNoWait, "reduce nowait function not found");
      auto *reduceNoWaitEnd = M->getOrRealizeFunc(
          "_end_reduce_nowait",
          {reductionLocRef->getType(), gtid->getType(), lckPtrType}, {}, ompModule);
      seqassertn(reduceNoWaitEnd, "end reduce nowait function not found");

      auto
*series = M->Nr();
      auto *tupleVal = util::makeVar(reductionTuple, series, parent);
      auto *reduceCode =
          util::call(reduceNoWait, {M->Nr(reductionLocRef), M->Nr(gtid),
                                    M->Nr(tupleVal), rawReducer, M->Nr(lck)});
      auto *codeVar = util::makeVar(reduceCode, series, parent);
      seqassertn(codeVar->getType()->is(M->getIntType()), "wrong reduce code type");

      auto *sectionNonAtomic = M->Nr();
      auto *sectionAtomic = M->Nr();

      // Non-atomic section: combine each local into the original location,
      // then signal the end of the reduction.
      for (auto &info : sharedInfo) {
        if (info.reduction) {
          Value *ptr = util::tupleGet(M->Nr(extras), info.memb);
          Value *arg = M->Nr(info.local);
          sectionNonAtomic->push_back(
              info.reduction.generateNonAtomicReduction(ptr, arg));
        }
      }
      sectionNonAtomic->push_back(util::call(
          reduceNoWaitEnd, {M->Nr(reductionLocRef), M->Nr(gtid), M->Nr(lck)}));

      // Atomic section: same updates, generated through the atomic path.
      for (auto &info : sharedInfo) {
        if (info.reduction) {
          Value *ptr = util::tupleGet(M->Nr(extras), info.memb);
          Value *arg = M->Nr(info.local);
          sectionAtomic->push_back(
              info.reduction.generateAtomicReduction(ptr, arg, locRef, gtid, locks));
        }
      }

      // make: if code == 1 { sectionNonAtomic } elif code == 2 { sectionAtomic }
      auto *theSwitch = M->Nr(
          *M->Nr(codeVar) == *M->getInt(1), sectionNonAtomic,
          util::series(M->Nr(*M->Nr(codeVar) == *M->getInt(2), sectionAtomic)));
      series->push_back(theSwitch);
      v->replaceAll(series);
    }
  }
};

/// Replacer for the imperative (counted) loop template; in addition to the
/// parallel stubs it fills in the step, schedule, ordered flag, loop body and
/// shared-variable write-back stubs.
struct ImperativeLoopTemplateReplacer : public ParallelLoopTemplateReplacer {
  OMPSched *sched;
  int64_t step;

  ImperativeLoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement,
                                 Var *loopVar, ReductionIdentifier *reds,
                                 OMPSched *sched, int64_t step)
      : ParallelLoopTemplateReplacer(parent, replacement, loopVar, reds), sched(sched),
        step(step) {}

  void handle(CallInstr *v) override {
    // Let the base class process the loc/gtid and reduction stubs first.
    ParallelLoopTemplateReplacer::handle(v);
    auto *M = v->getModule();
    auto *func = util::getFunc(v->getCallee());
    if (!func)
      return;
    auto name = func->getUnmangledName();

    if (name == "_loop_step") {
      v->replaceAll(M->getInt(step));
    }

    if (name == "_loop_body_stub") {
      seqassertn(replacement, "unexpected double replacement");
seqassertn(v->numArgs() == 2 && isA(v->front()) && isA(v->back()),
                 "unexpected loop body stub");

      auto *outlinedFunc = util::getFunc(replacement->getCallee());

      // the template passes the new loop var and extra args
      // to the body stub for convenience
      auto *newLoopVar = util::getVar(v->front());
      auto *extras = util::getVar(v->back());

      std::vector newArgs;
      auto outlinedArgs = outlinedFunc->arg_begin(); // arg vars of *outlined func*
      unsigned next = 0; // next index in "extra" args tuple, passed to template
      // `arg` is an argument of the original outlined func call
      for (auto *arg : *replacement) {
        if (getVarFromOutlinedArg(arg)->getId() != loopVar->getId()) {
          Value *newArg = nullptr;

          // shared vars will be stored in a new var
          if (isA(arg)) {
            types::Type *base = cast(arg->getType())->getBase();

            // get extras again since we'll be inserting the new var before extras local
            Var *lastArg = parent->arg_back(); // ptr to {chunk, start, stop, extras}
            Value *val = util::tupleGet(util::ptrLoad(M->Nr(lastArg)), 3);
            Value *initVal = util::ptrLoad(util::tupleGet(val, next));

            // A reduction starts from its identity value instead of the
            // variable's current value.
            Reduction reduction = reds->getReduction(*outlinedArgs);
            if (reduction) {
              initVal = reduction.getInitial();
              seqassertn(initVal && initVal->getType()->is(base),
                         "unknown reduction init value");
            }

            auto *newVar = util::makeVar(initVal, cast(parent->getBody()), parent,
                                         /*prepend=*/true);
            sharedInfo.push_back({next, newVar, reduction});

            newArg = M->Nr(newVar);
            ++next;
          } else {
            newArg = util::tupleGet(M->Nr(extras), next++);
          }

          newArgs.push_back(newArg);
        } else {
          // The loop variable argument is replaced by the template's loop var,
          // using the same value kind as the original argument.
          if (isA(arg)) {
            newArgs.push_back(M->Nr(newLoopVar));
          } else if (isA(arg)) {
            newArgs.push_back(M->Nr(newLoopVar));
          } else {
            seqassertn(false, "unknown outline var");
          }
        }

        ++outlinedArgs;
      }

      v->replaceAll(util::call(outlinedFunc, newArgs));
      // Cleared so that a second body stub trips the assertion above.
      replacement = nullptr;
    }

    if (name == "_loop_shared_updates") {
      // for all non-reduction shareds, set the final values
      // this will be similar to OpenMP's "lastprivate"
      seqassertn(v->numArgs() == 1 && isA(v->front()),
                 "unexpected shared updates stub");
      auto *extras = util::getVar(v->front());
      auto *series = M->Nr();

      for (auto &info : sharedInfo) {
        if (info.reduction)
          continue;

        auto *finalValue = M->Nr(info.local);
        auto *val = M->Nr(extras);
        auto *origPtr = util::tupleGet(val, info.memb);
        series->push_back(util::ptrStore(origPtr, finalValue));
      }

      v->replaceAll(series);
    }

    if (name == "_loop_schedule") {
      v->replaceAll(M->getInt(sched->code));
    }

    if (name == "_loop_ordered") {
      v->replaceAll(M->getBool(sched->ordered));
    }
  }
};

/// Operator that redirects uses of reduction argument variables in a function
/// to freshly created local variables, which finalize() then initializes from
/// `_taskred_data`.
struct TaskLoopReductionVarReplacer : public util::Operator {
  std::vector reductionArgs;
  /// (original variable, replacement variable) pairs
  std::vector> reductionRemap;
  BodiedFunc *parent;

  /// Creates one non-global variable of matching type per reduction argument.
  void setupReductionRemap() {
    auto *M = parent->getModule();

    for (auto *var : reductionArgs) {
      auto *newVar = M->Nr(var->getType(), /*global=*/false);
      reductionRemap.emplace_back(var, newVar);
    }
  }

  TaskLoopReductionVarReplacer(std::vector reductionArgs, BodiedFunc *parent)
      : util::Operator(), reductionArgs(std::move(reductionArgs)), reductionRemap(),
        parent(parent) {
    setupReductionRemap();
  }

  /// Rewrites every visited node to use the replacement variables.
  void preHook(Node *v) override {
    for (auto &p : reductionRemap) {
      v->replaceUsedVariable(p.first->getId(), p.second);
    }
  }

  // need to do this as a separate step since otherwise the old variable
  // in the assignment will be replaced, which we don't want
  /// Inserts, at the start of the function body, an assignment of
  /// `_taskred_data(gtid, original)` to each replacement variable and
  /// registers the replacement variable with the function. `gtid` is taken to
  /// be the function's last argument.
  void finalize() {
    auto *M = parent->getModule();
    auto *body = cast(parent->getBody());
    auto *gtid = parent->arg_back();

    for (auto &p : reductionRemap) {
      auto *taskRedData = M->getOrRealizeFunc(
          "_taskred_data", {M->getIntType(), p.first->getType()}, {}, ompModule);
      seqassertn(taskRedData, "could not find '_taskred_data'");

      auto *assign = M->Nr(
          p.second, util::call(taskRedData, {M->Nr(gtid), M->Nr(p.first)}));
      body->insert(body->begin(), assign);
      parent->push_back(p.second);
    }
  }
};

/// Operator that replaces the "_task_loop_body_stub" call with a call to the
/// outlined function.
struct TaskLoopBodyStubReplacer : public util::Operator {
  CallInstr *replacement;
  /// One flag per argument of `replacement`: true if it is a reduction
  std::vector reduceArgs;

  TaskLoopBodyStubReplacer(CallInstr *replacement, std::vector reduceArgs)
      : util::Operator(),
replacement(replacement), reduceArgs(std::move(reduceArgs)) {}

  /// Replaces a "_task_loop_body_stub" call by a call to the outlined function
  /// whose arguments are read from the privates and shareds tuples. When any
  /// argument is a reduction, `gtid` is appended as an extra argument and the
  /// outlined function is rewritten to use task-reduction variables.
  void handle(CallInstr *v) override {
    auto *func = util::getFunc(v->getCallee());
    if (func && func->getUnmangledName() == "_task_loop_body_stub") {
      seqassertn(replacement, "unexpected double replacement");
      seqassertn(v->numArgs() == 3 && isA(v->front()) && isA(v->back()),
                 "unexpected loop body stub");

      // the template passes gtid, privs and shareds to the body stub for convenience
      std::vector args(v->begin(), v->end());
      auto *gtid = args[0];
      auto *privatesTuple = args[1];
      auto *sharedsTuple = args[2];
      unsigned privatesNext = 0;
      unsigned sharedsNext = 0;
      std::vector newArgs;
      bool hasReductions =
          std::any_of(reduceArgs.begin(), reduceArgs.end(), [](bool b) { return b; });

      // Arguments of the first kind come from the privates tuple, those of the
      // second kind from the shareds tuple, each consumed in order.
      for (auto *arg : *replacement) {
        if (isA(arg)) {
          newArgs.push_back(util::tupleGet(privatesTuple, privatesNext++));
        } else if (isA(arg)) {
          newArgs.push_back(util::tupleGet(sharedsTuple, sharedsNext++));
        } else {
          // make sure we're on the last arg, which should be gtid
          // in case of reductions
          seqassertn(hasReductions && arg == replacement->back(),
                     "unknown outline var");
        }
      }

      auto *outlinedFunc = cast(util::getFunc(replacement->getCallee()));

      if (hasReductions) {
        newArgs.push_back(gtid);
        std::vector reductionArgs;
        // Collect the outlined function's argument variables flagged as
        // reductions; `reduceArgs` is indexed in argument order.
        unsigned i = 0;
        for (auto it = outlinedFunc->arg_begin(); it != outlinedFunc->arg_end(); ++it) {
          if (reduceArgs[i++])
            reductionArgs.push_back(*it);
        }
        TaskLoopReductionVarReplacer redrep(reductionArgs, outlinedFunc);
        outlinedFunc->accept(redrep);
        redrep.finalize();
      }

      v->replaceAll(util::call(outlinedFunc, newArgs));
      // Cleared so that a second body stub trips the assertion above.
      replacement = nullptr;
    }
  }
};

/// Replacer for the task-loop template: sets up task reductions and swaps the
/// "_routine_stub" function for a specialized clone.
struct TaskLoopRoutineStubReplacer : public ParallelLoopTemplateReplacer {
  std::vector privates;
  std::vector shareds;
  Var *array;  // task reduction input array
  Var *tskgrp; // task group identifier

  /// Fills `sharedInfo` with one entry per shared value that has a non-empty
  /// reduction in `sharedRedux` (indexed like `shareds`); the loop variable is
  /// skipped. Each entry gets a new variable initialized to the reduction's
  /// initial value, prepended to the parent's body.
  void setupSharedInfo(std::vector &sharedRedux) {
    unsigned sharedsNext = 0;
    for (auto *val : shareds) {
      if (getVarFromOutlinedArg(val)->getId() != loopVar->getId()) {
        if (auto &reduction =
sharedRedux[sharedsNext]) {
          auto *newVar = util::makeVar(reduction.getInitial(),
                                       cast(parent->getBody()), parent,
                                       /*prepend=*/true);
          sharedInfo.push_back({sharedsNext, newVar, reduction});
        }
      }
      ++sharedsNext;
    }
  }

  TaskLoopRoutineStubReplacer(BodiedFunc *parent, CallInstr *replacement, Var *loopVar,
                              ReductionIdentifier *reds, std::vector privates,
                              std::vector shareds, std::vector sharedRedux)
      : ParallelLoopTemplateReplacer(parent, replacement, loopVar, reds),
        privates(std::move(privates)), shareds(std::move(shareds)), array(nullptr),
        tskgrp(nullptr) {
    setupSharedInfo(sharedRedux);
  }

  /// Creates the "__red_init" function for a task reduction. It takes two
  /// pointers to the reduction type ("lhs", "rhs") and stores the reduction's
  /// initial value through "lhs"; "rhs" is not used in the body.
  BodiedFunc *makeTaskRedInitFunc(Reduction *reduction) {
    auto *M = parent->getModule();
    auto *argType = M->getPointerType(reduction->getType());
    auto *funcType = M->getFuncType(M->getNoneType(), {argType, argType});
    auto *initializer = M->Nr("__red_init");
    initializer->realize(funcType, {"lhs", "rhs"});

    auto *lhsVar = initializer->arg_front();
    auto *body = M->Nr();
    auto *lhsPtr = M->Nr(lhsVar);
    body->push_back(util::ptrStore(lhsPtr, reduction->getInitial()));
    initializer->setBody(body);
    return initializer;
  }

  /// Creates the "__red_comb" function for a task reduction. It takes two
  /// pointers to the reduction type and combines the value behind "rhs" into
  /// the location behind "lhs" using the non-atomic reduction.
  BodiedFunc *makeTaskRedCombFunc(Reduction *reduction) {
    auto *M = parent->getModule();
    auto *argType = M->getPointerType(reduction->getType());
    auto *funcType = M->getFuncType(M->getNoneType(), {argType, argType});
    auto *reducer = M->Nr("__red_comb");
    reducer->realize(funcType, {"lhs", "rhs"});

    auto *lhsVar = reducer->arg_front();
    auto *rhsVar = reducer->arg_back();
    auto *body = M->Nr();
    auto *lhsPtr = M->Nr(lhsVar);
    auto *rhsPtr = M->Nr(rhsVar);
    body->push_back(
        reduction->generateNonAtomicReduction(lhsPtr, util::ptrLoad(rhsPtr)));
    reducer->setBody(body);
    return reducer;
  }

  /// Constructs a `TaskReductionInput` (type from ompModule) from the shared
  /// and original values, the size of the reduction type, and raw pointers to
  /// the generated init and combine functions.
  Value *makeTaskRedInput(Reduction *reduction, Value *shar, Value *orig) {
    auto *M = shar->getModule();
    auto *size = M->Nr(reduction->getType(), TypePropertyInstr::Property::SIZEOF);
    auto *init = ptrFromFunc(makeTaskRedInitFunc(reduction));
    auto *comb = ptrFromFunc(makeTaskRedCombFunc(reduction));

    auto
*taskRedInputType =
        M->getOrRealizeType(ast::getMangledClass(ompModule, "TaskReductionInput"));
    seqassertn(taskRedInputType, "could not find 'TaskReductionInput' type");
    auto *result = taskRedInputType->construct({shar, orig, size, init, comb});
    seqassertn(result, "bad construction of 'TaskReductionInput' type");
    return result;
  }

  /// When `v` refers to the "_routine_stub" function, points it at a forced
  /// clone of that function in which the body stub has been replaced. Also
  /// computes, per argument of `replacement`, whether it is a reduction.
  void handle(VarValue *v) override {
    auto *M = v->getModule();
    auto *func = util::getFunc(v);
    if (func && func->getUnmangledName() == "_routine_stub") {
      std::vector reduceArgs;
      unsigned sharedsNext = 0;
      unsigned infoNext = 0;

      for (auto *arg : *replacement) {
        if (isA(arg)) {
          reduceArgs.push_back(false);
        } else if (isA(arg)) {
          // `sharedInfo` is walked in step with the shared arguments: the
          // entry at `infoNext` applies only if its member index matches.
          if (infoNext < sharedInfo.size() &&
              sharedInfo[infoNext].memb == sharedsNext &&
              sharedInfo[infoNext].reduction) {
            reduceArgs.push_back(true);
            ++infoNext;
          } else {
            reduceArgs.push_back(false);
          }
          ++sharedsNext;
        } else {
          // make sure we're on the last arg, which should be gtid
          // in case of reductions
          seqassertn(numReductions() > 0 && arg == replacement->back(),
                     "unknown outline var");
          reduceArgs.push_back(false);
        }
      }

      util::CloneVisitor cv(M);
      auto *newRoutine = cv.forceClone(func);
      TaskLoopBodyStubReplacer rep(replacement, reduceArgs);
      newRoutine->accept(rep);
      v->setVar(newRoutine);
    }
  }

  void handle(CallInstr *v) override {
    ParallelLoopTemplateReplacer::handle(v);
    auto *M = v->getModule();
    auto *func = util::getFunc(v->getCallee());
    if (!func)
      return;
    auto name = func->getUnmangledName();

    if (name == "_taskred_setup") {
      seqassertn(reductionLocRef && gtid, "bad visit order in template");
      seqassertn(v->numArgs() == 1 && isA(v->front()),
                 "unexpected shared updates stub");
      unsigned numRed = numReductions();
      if (numRed == 0)
        return;

      auto *M = parent->getModule();
      auto *extras = util::getVar(v->front());

      // add task reduction inputs
      auto *taskRedInitSeries = M->Nr();
      auto *taskRedInputType =
          M->getOrRealizeType(ast::getMangledClass(ompModule, "TaskReductionInput"));
      seqassertn(taskRedInputType, "could not find 
'TaskReductionInput' type"); auto *irArrayType = M->getOrRealizeType( ast::getMangledClass(ompModule, "TaskReductionInputArray")); seqassertn(irArrayType, "could not find 'TaskReductionInputArray' type"); auto *taskRedInputsArray = util::makeVar( M->Nr(irArrayType, numRed), taskRedInitSeries, parent); array = taskRedInputsArray; auto *taskRedInputsArrayType = taskRedInputsArray->getType(); auto *taskRedSetItem = M->getOrRealizeMethod( taskRedInputsArrayType, Module::SETITEM_MAGIC_NAME, {taskRedInputsArrayType, M->getIntType(), taskRedInputType}); seqassertn(taskRedSetItem, "could not find 'TaskReductionInputArray.__setitem__' method"); int i = 0; for (auto &info : sharedInfo) { if (info.reduction) { Value *shar = M->Nr(info.local); Value *orig = util::tupleGet(M->Nr(extras), info.memb); auto *taskRedInput = makeTaskRedInput(&info.reduction, shar, orig); taskRedInitSeries->push_back(util::call( taskRedSetItem, {M->Nr(array), M->getInt(i++), taskRedInput})); } } auto *arrayPtr = M->Nr(M->Nr(array), "ptr"); auto *taskRedInitFunc = M->getOrRealizeFunc("_taskred_init", {reductionLocRef->getType(), gtid->getType(), M->getIntType(), arrayPtr->getType()}, {}, ompModule); seqassertn(taskRedInitFunc, "task red init function not found"); auto *taskRedInitResult = util::makeVar(util::call(taskRedInitFunc, {M->Nr(reductionLocRef), M->Nr(gtid), M->getInt(numRed), arrayPtr}), taskRedInitSeries, parent); tskgrp = taskRedInitResult; v->replaceAll(taskRedInitSeries); } if (name == "_fix_privates_and_shareds") { std::vector args(v->begin(), v->end()); seqassertn(args.size() == 3, "invalid _fix_privates_and_shareds call found"); unsigned numRed = numReductions(); auto *newLoopVar = args[0]; auto *privatesTuple = args[1]; auto *sharedsTuple = args[2]; unsigned privatesNext = 0; unsigned sharedsNext = 0; unsigned infoNext = 0; bool needNewPrivates = false; bool needNewShareds = false; std::vector newPrivates; std::vector newShareds; for (auto *val : privates) { if (numRed > 0 && val == 
privates.back()) { // i.e. task group identifier seqassertn(tskgrp, "tskgrp var not set"); newPrivates.push_back(M->Nr(tskgrp)); needNewPrivates = true; } else if (getVarFromOutlinedArg(val)->getId() != loopVar->getId()) { newPrivates.push_back(util::tupleGet(privatesTuple, privatesNext)); } else { newPrivates.push_back(newLoopVar); needNewPrivates = true; } ++privatesNext; } for (auto *val : shareds) { if (getVarFromOutlinedArg(val)->getId() != loopVar->getId()) { if (infoNext < sharedInfo.size() && sharedInfo[infoNext].memb == sharedsNext && sharedInfo[infoNext].reduction) { newShareds.push_back(M->Nr(sharedInfo[infoNext].local)); needNewShareds = true; ++infoNext; } else { newShareds.push_back(util::tupleGet(sharedsTuple, sharedsNext)); } } else { newShareds.push_back(M->Nr(util::getVar(newLoopVar))); needNewShareds = true; } ++sharedsNext; } privatesTuple = needNewPrivates ? util::makeTuple(newPrivates, M) : privatesTuple; sharedsTuple = needNewShareds ? util::makeTuple(newShareds, M) : sharedsTuple; Value *result = util::makeTuple({privatesTuple, sharedsTuple}, M); v->replaceAll(result); } if (name == "_taskred_finish") { seqassertn(reductionLocRef && gtid, "bad visit order in template"); if (numReductions() == 0) return; auto *taskRedFini = M->getOrRealizeFunc( "_taskred_fini", {reductionLocRef->getType(), gtid->getType()}, {}, ompModule); seqassertn(taskRedFini, "taskred finish function not found not found"); v->replaceAll(util::call( taskRedFini, {M->Nr(reductionLocRef), M->Nr(gtid)})); } } }; struct GPULoopBodyStubReplacer : public util::Operator { CallInstr *replacement; Var *loopVar; int64_t step; GPULoopBodyStubReplacer(CallInstr *replacement, Var *loopVar, int64_t step) : util::Operator(), replacement(replacement), loopVar(loopVar), step(step) {} void handle(CallInstr *v) override { auto *M = v->getModule(); auto *func = util::getFunc(v->getCallee()); if (!func) return; auto name = func->getUnmangledName(); if (name == "_gpu_loop_body_stub") { 
seqassertn(replacement, "unexpected double replacement"); seqassertn(v->numArgs() == 2, "unexpected loop body stub"); // the template passes gtid, privs and shareds to the body stub for convenience auto *idx = v->front(); auto *args = v->back(); unsigned next = 0; std::vector newArgs; for (auto *arg : *replacement) { if (getVarFromOutlinedArg(arg)->getId() == loopVar->getId()) { newArgs.push_back(idx); } else { newArgs.push_back(util::tupleGet(args, next++)); } } auto *outlinedFunc = cast(util::getFunc(replacement->getCallee())); v->replaceAll(util::call(outlinedFunc, newArgs)); replacement = nullptr; } if (name == "_loop_step") { v->replaceAll(M->getInt(step)); } } }; struct GPULoopTemplateReplacer : public LoopTemplateReplacer { int64_t step; GPULoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement, Var *loopVar, int64_t step) : LoopTemplateReplacer(parent, replacement, loopVar), step(step) {} void handle(CallInstr *v) override { auto *M = v->getModule(); auto *func = util::getFunc(v->getCallee()); if (!func) return; auto name = func->getUnmangledName(); if (name == "_loop_step") { v->replaceAll(M->getInt(step)); } } }; struct OpenMPTransformData { util::OutlineResult outline; std::vector sharedVars; ReductionIdentifier reds; }; template OpenMPTransformData unpar(T *v) { v->setParallel(false); return {{}, {}, {}}; } template OpenMPTransformData setupOpenMPTransform(T *v, BodiedFunc *parent, bool gpu) { if (!v->isParallel()) return unpar(v); auto *M = v->getModule(); auto *body = cast(v->getBody()); if (!parent || !body) return unpar(v); auto outline = util::outlineRegion(parent, body, /*allowOutflows=*/false, /*outlineGlobals=*/true, /*allByValue=*/gpu); if (!outline) return unpar(v); // set up args to pass fork_call Var *loopVar = v->getVar(); std::vector outlineCallArgs(outline.call->begin(), outline.call->end()); // shared argument vars std::vector sharedVars; Var *loopVarArg = nullptr; unsigned i = 0; for (auto it = outline.func->arg_begin(); it != 
outline.func->arg_end(); ++it) { // pick out loop variable to pass to reduction identifier, which will // ensure we don't reduce over it if (getVarFromOutlinedArg(outlineCallArgs[i])->getId() == loopVar->getId()) loopVarArg = *it; if (outline.argKinds[i] == util::OutlineResult::ArgKind::MODIFIED) sharedVars.push_back(*it); ++i; } ReductionIdentifier reds(sharedVars, loopVarArg); outline.func->accept(reds); return {outline, sharedVars, reds}; } struct ForkCallData { CallInstr *fork = nullptr; CallInstr *pushNumThreads = nullptr; }; ForkCallData createForkCall(Module *M, OMPTypes &types, Value *rawTemplateFunc, const std::vector &forkExtraArgs, transform::parallel::OMPSched *sched) { ForkCallData result; auto *forkExtra = util::makeTuple(forkExtraArgs, M); std::vector forkArgTypes = {types.i8ptr, forkExtra->getType()}; auto *forkFunc = M->getOrRealizeFunc("_fork_call", forkArgTypes, {}, ompModule); seqassertn(forkFunc, "fork call function not found"); result.fork = util::call(forkFunc, {rawTemplateFunc, forkExtra}); if (sched->threads && sched->threads->getType()->is(types.i64)) { auto *pushNumThreadsFunc = M->getOrRealizeFunc("_push_num_threads", {types.i64}, {}, ompModule); seqassertn(pushNumThreadsFunc, "push num threads func not found"); result.pushNumThreads = util::call(pushNumThreadsFunc, {sched->threads}); } return result; } struct CollapseResult { ImperativeForFlow *collapsed = nullptr; SeriesFlow *setup = nullptr; std::string error; operator bool() const { return collapsed != nullptr; } }; struct LoopRange { ImperativeForFlow *loop; Var *start; Var *stop; int64_t step; Var *len; }; CollapseResult collapseLoop(BodiedFunc *parent, ImperativeForFlow *v, int64_t levels) { auto fail = [](const std::string &error) { CollapseResult bad; bad.error = error; return bad; }; auto *M = v->getModule(); CollapseResult res; if (levels < 1) return fail("'collapse' must be at least 1"); std::vector loopNests = {v}; ImperativeForFlow *curr = v; for (auto i = 0; i < levels - 
1; i++) { auto *body = cast(curr->getBody()); seqassertn(body, "unexpected loop body"); if (std::distance(body->begin(), body->end()) != 1 || !isA(body->front())) return fail("loop nest not collapsible"); curr = cast(body->front()); loopNests.push_back(curr); } std::vector ranges; auto *setup = M->Nr(); auto *intType = M->getIntType(); auto *lenCalc = M->getOrRealizeFunc("_range_len", {intType, intType, intType}, {}, ompModule); seqassertn(lenCalc, "range length calculation function not found"); for (auto *loop : loopNests) { LoopRange range; range.loop = loop; range.start = util::makeVar(loop->getStart(), setup, parent); range.stop = util::makeVar(loop->getEnd(), setup, parent); range.step = loop->getStep(); range.len = util::makeVar( util::call(lenCalc, {M->Nr(range.start), M->Nr(range.stop), M->getInt(range.step)}), setup, parent); ranges.push_back(range); } auto *numIters = M->getInt(1); for (auto &range : ranges) { numIters = (*numIters) * (*M->Nr(range.len)); } auto *collapsedVar = M->Nr(M->getIntType(), /*global=*/false); parent->push_back(collapsedVar); auto *body = M->Nr(); auto sched = std::make_unique(*v->getSchedule()); sched->collapse = 0; auto *collapsed = M->Nr(M->getInt(0), 1, numIters, body, collapsedVar, std::move(sched)); // reconstruct indices by successive divmods Var *lastDiv = nullptr; for (auto it = ranges.rbegin(); it != ranges.rend(); ++it) { auto *k = lastDiv ? 
lastDiv : collapsedVar; auto *div = util::makeVar(*M->Nr(k) / *M->Nr(it->len), body, parent); auto *mod = util::makeVar(*M->Nr(k) % *M->Nr(it->len), body, parent); auto *i = *M->Nr(it->start) + *(*M->Nr(mod) * *M->getInt(it->step)); body->push_back(M->Nr(it->loop->getVar(), i)); lastDiv = div; } auto *oldBody = cast(loopNests.back()->getBody()); for (auto *x : *oldBody) { body->push_back(x); } res.collapsed = collapsed; res.setup = setup; return res; } } // namespace const std::string OpenMPPass::KEY = "core-parallel-openmp"; void OpenMPPass::handle(ForFlow *v) { auto data = setupOpenMPTransform(v, cast(getParentFunc()), /*gpu=*/false); if (!v->isParallel()) return; auto &outline = data.outline; auto &sharedVars = data.sharedVars; auto &reds = data.reds; auto *M = v->getModule(); auto *loopVar = v->getVar(); auto *sched = v->getSchedule(); OMPTypes types(M); // separate arguments into 'private' and 'shared' std::vector sharedRedux; // reductions corresponding to shared vars std::vector privates, shareds; unsigned i = 0; for (auto *arg : *outline.call) { if (isA(arg)) { privates.push_back(arg); } else { shareds.push_back(arg); sharedRedux.push_back(reds.getReduction(sharedVars[i++])); } } util::CloneVisitor cv(M); // We need to pass the task group identifier returned from // __kmpc_taskred_modifier_init to the task entry, so append // it to private data (initially as null void pointer). Also // we add an argument to the end of the outlined function for // the gtid. 
if (reds.reductions.size() > 0) { auto *nullPtr = types.i8ptr->construct({}); privates.push_back(nullPtr); auto *outlinedFuncType = cast(outline.func->getType()); std::vector argTypes(outlinedFuncType->begin(), outlinedFuncType->end()); argTypes.push_back(M->getIntType()); auto *retType = outlinedFuncType->getReturnType(); std::vector oldArgVars(outline.func->arg_begin(), outline.func->arg_end()); std::vector argNames; for (auto *var : oldArgVars) { argNames.push_back(var->getName()); } argNames.push_back("gtid"); auto *newOutlinedFunc = M->Nr("__outlined_new"); newOutlinedFunc->realize(M->getFuncType(retType, argTypes), argNames); std::vector newArgVars(newOutlinedFunc->arg_begin(), newOutlinedFunc->arg_end()); std::unordered_map remaps; for (unsigned i = 0; i < oldArgVars.size(); i++) { remaps.emplace(oldArgVars[i]->getId(), newArgVars[i]); } auto *newBody = cast(cv.clone(outline.func->getBody(), newOutlinedFunc, remaps)); newOutlinedFunc->setBody(newBody); // update outline struct outline.func = newOutlinedFunc; outline.call->setCallee(M->Nr(newOutlinedFunc)); outline.call->insert(outline.call->end(), M->getInt(0)); outline.argKinds.push_back(util::OutlineResult::ArgKind::CONSTANT); } auto *privatesTuple = util::makeTuple(privates, M); auto *sharedsTuple = util::makeTuple(shareds, M); // template call std::vector templateFuncArgs = { types.i32ptr, types.i32ptr, M->getPointerType( M->getTupleType({v->getIter()->getType(), privatesTuple->getType(), sharedsTuple->getType()}))}; auto *templateFunc = M->getOrRealizeFunc("_task_loop_outline_template", templateFuncArgs, {}, ompModule); seqassertn(templateFunc, "task loop outline template not found"); templateFunc = cv.forceClone(templateFunc); TaskLoopRoutineStubReplacer rep(cast(templateFunc), outline.call, loopVar, &reds, privates, shareds, sharedRedux); templateFunc->accept(rep); auto *rawTemplateFunc = ptrFromFunc(templateFunc); std::vector forkExtraArgs = {v->getIter(), privatesTuple, sharedsTuple}; // fork call 
auto forkData = createForkCall(M, types, rawTemplateFunc, forkExtraArgs, sched); if (forkData.pushNumThreads) insertBefore(forkData.pushNumThreads); v->replaceAll(forkData.fork); } void OpenMPPass::handle(ImperativeForFlow *v) { auto *parent = cast(getParentFunc()); if (v->isParallel() && v->getSchedule()->collapse != 0) { auto levels = v->getSchedule()->collapse; auto collapse = collapseLoop(parent, v, levels); if (collapse) { v->replaceAll(collapse.collapsed); v = collapse.collapsed; insertBefore(collapse.setup); } else if (!collapse.error.empty()) { warn("could not collapse loop: " + collapse.error, v); } } auto data = setupOpenMPTransform(v, parent, (v->isParallel() && v->getSchedule()->gpu)); if (!v->isParallel()) return; auto &outline = data.outline; auto &sharedVars = data.sharedVars; auto &reds = data.reds; auto *M = v->getModule(); auto *loopVar = v->getVar(); auto *sched = v->getSchedule(); OMPTypes types(M); // we disable shared vars for GPU loops seqassertn(!(sched->gpu && !sharedVars.empty()), "GPU-parallel loop had shared vars"); // gather extra arguments std::vector extraArgs; std::vector extraArgTypes; for (auto *arg : *outline.call) { if (getVarFromOutlinedArg(arg)->getId() != loopVar->getId()) { extraArgs.push_back(arg); extraArgTypes.push_back(arg->getType()); } } // template call std::string templateFuncName; if (sched->gpu) { templateFuncName = "_gpu_loop_outline_template"; } else if (sched->dynamic) { templateFuncName = "_dynamic_loop_outline_template"; } else if (sched->chunk) { templateFuncName = "_static_chunked_loop_outline_template"; } else { templateFuncName = "_static_loop_outline_template"; } if (sched->gpu) { std::unordered_set kernels; const std::string gpuAttr = ast::getMangledFunc(gpuModule, "kernel"); for (auto *var : *M) { if (auto *func = cast(var)) { if (util::hasAttribute(func, gpuAttr)) { kernels.insert(func->getId()); } } } std::vector templateFuncArgs = {types.i64, types.i64, M->getTupleType(extraArgTypes)}; static int64_t 
instance = 0; auto *templateFunc = M->getOrRealizeFunc(templateFuncName, templateFuncArgs, {instance++}, gpuModule); if (!templateFunc) { warn("loop not compilable for GPU; ignoring", v); v->setParallel(false); return; } BodiedFunc *kernel = nullptr; for (auto *var : *M) { if (auto *func = cast(var)) { if (util::hasAttribute(func, gpuAttr) && kernels.count(func->getId()) == 0) { seqassertn(!kernel, "multiple new kernels found after instantiation"); kernel = func; } } } seqassertn(kernel, "no new kernel found"); GPULoopBodyStubReplacer brep(outline.call, loopVar, v->getStep()); kernel->accept(brep); util::CloneVisitor cv(M); templateFunc = cast(cv.forceClone(templateFunc)); GPULoopTemplateReplacer rep(cast(templateFunc), outline.call, loopVar, v->getStep()); templateFunc->accept(rep); v->replaceAll(util::call( templateFunc, {v->getStart(), v->getEnd(), util::makeTuple(extraArgs, M)})); } else { std::vector templateFuncArgs = { types.i32ptr, types.i32ptr, M->getPointerType(M->getTupleType( {types.i64, types.i64, types.i64, M->getTupleType(extraArgTypes)}))}; auto *templateFunc = M->getOrRealizeFunc(templateFuncName, templateFuncArgs, {}, ompModule); seqassertn(templateFunc, "imperative loop outline template not found"); util::CloneVisitor cv(M); templateFunc = cast(cv.forceClone(templateFunc)); ImperativeLoopTemplateReplacer rep(cast(templateFunc), outline.call, loopVar, &reds, sched, v->getStep()); templateFunc->accept(rep); auto *rawTemplateFunc = ptrFromFunc(templateFunc); auto *chunk = (sched->chunk && sched->chunk->getType()->is(types.i64)) ? 
sched->chunk : M->getInt(1); std::vector forkExtraArgs = {chunk, v->getStart(), v->getEnd()}; for (auto *arg : extraArgs) { forkExtraArgs.push_back(arg); } // fork call auto forkData = createForkCall(M, types, rawTemplateFunc, forkExtraArgs, sched); if (forkData.pushNumThreads) insertBefore(forkData.pushNumThreads); v->replaceAll(forkData.fork); } } } // namespace parallel } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/parallel/openmp.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace transform { namespace parallel { class OpenMPPass : public OperatorPass { public: /// Constructs an OpenMP pass. OpenMPPass() : OperatorPass(/*childrenFirst=*/true) {} static const std::string KEY; std::string getKey() const override { return KEY; } void handle(ForFlow *) override; void handle(ImperativeForFlow *) override; }; } // namespace parallel } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/parallel/schedule.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "schedule.h" #include "codon/cir/cir.h" #include "codon/cir/util/irtools.h" #include #include namespace codon { namespace ir { namespace transform { namespace parallel { namespace { int getScheduleCode(const std::string &schedule = "static", bool chunked = false, bool ordered = false, bool monotonic = false) { // codes from "enum sched_type" at // https://github.com/llvm/llvm-project/blob/main/openmp/runtime/src/kmp.h int modifier = monotonic ? (1 << 29) : (1 << 30); if (schedule == "static") { if (chunked) { if (ordered) return 65; else return 33; } else { if (ordered) return 66; else return 34; } } else if (schedule == "dynamic") { return (ordered ? 
67 : 35) | modifier; } else if (schedule == "guided") { return (ordered ? 68 : 36) | modifier; } else if (schedule == "runtime") { return (ordered ? 69 : 37) | modifier; } else if (schedule == "auto") { return (ordered ? 70 : 38) | modifier; } return getScheduleCode(); // default } Value *nullIfNeg(Value *v) { if (v && util::isConst(v) && util::getConst(v) <= 0) return nullptr; return v; } } // namespace OMPSched::OMPSched(int code, bool dynamic, Value *threads, Value *chunk, bool ordered, int64_t collapse, bool gpu) : code(code), dynamic(dynamic), threads(nullIfNeg(threads)), chunk(nullIfNeg(chunk)), ordered(ordered), collapse(collapse), gpu(gpu) { if (code < 0) this->code = getScheduleCode(); } OMPSched::OMPSched(const std::string &schedule, Value *threads, Value *chunk, bool ordered, int64_t collapse, bool gpu) : OMPSched(getScheduleCode(schedule, nullIfNeg(chunk) != nullptr, ordered), (schedule != "static") || ordered, threads, chunk, ordered, collapse, gpu) {} std::vector OMPSched::getUsedValues() const { std::vector ret; if (threads) ret.push_back(threads); if (chunk) ret.push_back(chunk); return ret; } int OMPSched::replaceUsedValue(id_t id, Value *newValue) { auto count = 0; if (threads && threads->getId() == id) { threads = newValue; ++count; } if (chunk && chunk->getId() == id) { chunk = newValue; ++count; } return count; } } // namespace parallel } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/parallel/schedule.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once #include "codon/cir/value.h" namespace codon { namespace ir { class Value; namespace transform { namespace parallel { struct OMPSched { int code; bool dynamic; Value *threads; Value *chunk; bool ordered; int64_t collapse; bool gpu; explicit OMPSched(int code = -1, bool dynamic = false, Value *threads = nullptr, Value *chunk = nullptr, bool ordered = false, int64_t collapse = 0, bool gpu = false); explicit OMPSched(const std::string &code, Value *threads = nullptr, Value *chunk = nullptr, bool ordered = false, int64_t collapse = 0, bool gpu = false); OMPSched(const OMPSched &s) : code(s.code), dynamic(s.dynamic), threads(s.threads), chunk(s.chunk), ordered(s.ordered), collapse(s.collapse), gpu(s.gpu) {} std::vector getUsedValues() const; int replaceUsedValue(id_t id, Value *newValue); }; } // namespace parallel } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/pass.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "pass.h" #include "codon/cir/transform/manager.h" namespace codon { namespace ir { namespace transform { analyze::Result *Pass::doGetAnalysis(const std::string &key) { return manager ? manager->getAnalysisResult(key) : nullptr; } void PassGroup::run(Module *module) { for (auto &p : passes) p->run(module); } void PassGroup::setManager(PassManager *mng) { Pass::setManager(mng); for (auto &p : passes) p->setManager(mng); } } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/pass.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/module.h" #include "codon/cir/util/operator.h" namespace codon { namespace ir { namespace analyze { struct Result; } namespace transform { class PassManager; /// General pass base class. 
class Pass { private: PassManager *manager = nullptr; public: virtual ~Pass() noexcept = default; /// @return a unique key for this pass virtual std::string getKey() const = 0; /// Execute the pass. /// @param module the module virtual void run(Module *module) = 0; /// Determine if pass should repeat. /// @param num how many times this pass has already run /// @return true if pass should repeat virtual bool shouldRepeat(int num) const { return false; } /// Sets the manager. /// @param mng the new manager virtual void setManager(PassManager *mng) { manager = mng; } /// Returns the result of a given analysis. /// @param key the analysis key /// @return the analysis result template AnalysisType *getAnalysisResult(const std::string &key) { return static_cast(doGetAnalysis(key)); } private: analyze::Result *doGetAnalysis(const std::string &key); }; class PassGroup : public Pass { private: int repeat; std::vector> passes; public: explicit PassGroup(int repeat = 0, std::vector> passes = {}) : Pass(), repeat(repeat), passes(std::move(passes)) {} virtual ~PassGroup() noexcept = default; void push_back(std::unique_ptr p) { passes.push_back(std::move(p)); } /// @return default number of times pass should repeat int getRepeat() const { return repeat; } /// Sets the default number of times pass should repeat. /// @param r number of repeats void setRepeat(int r) { repeat = r; } bool shouldRepeat(int num) const override { return num < repeat; } void run(Module *module) override; void setManager(PassManager *mng) override; }; /// Pass that runs a single Operator. class OperatorPass : public Pass, public util::Operator { public: /// Constructs an operator pass. 
/// @param childrenFirst true if children should be iterated first explicit OperatorPass(bool childrenFirst = false) : util::Operator(childrenFirst) {} void run(Module *module) override { reset(); process(module); } }; } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/pythonic/dict.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "dict.h" #include #include "codon/cir/util/cloning.h" #include "codon/cir/util/irtools.h" #include "codon/cir/util/matching.h" namespace codon { namespace ir { namespace transform { namespace pythonic { namespace { /// get or __getitem__ call metadata struct GetCall { /// the function, nullptr if not a get call Func *func = nullptr; /// the dictionary, must not be a call Value *dict = nullptr; /// the key, must not be a call Value *key = nullptr; /// the default value, may be null Const *dflt = nullptr; }; /// Identify the call and return its metadata. 
/// @param call the call /// @return the metadata GetCall analyzeGet(CallInstr *call) { // extract the function auto *func = util::getFunc(call->getCallee()); if (!func) return {}; auto unmangled = func->getUnmangledName(); // canonical get/__getitem__ calls have at least two arguments auto it = call->begin(); auto dist = std::distance(it, call->end()); if (dist < 2) return {}; // extract the dictionary and keys auto *dict = *it++; auto *k = *it++; // dictionary and key must not be calls if (isA(dict) || isA(k)) return {}; // get calls have a default if (unmangled == "get" && std::distance(it, call->end()) == 1) { auto *dflt = cast(*it); return {func, dict, k, dflt}; } else if (unmangled == "__getitem__" && std::distance(it, call->end()) == 0) { return {func, dict, k, nullptr}; } // call is not correct return {}; } } // namespace const std::string DictArithmeticOptimization::KEY = "core-pythonic-dict-arithmetic-opt"; void DictArithmeticOptimization::handle(CallInstr *v) { auto *M = v->getModule(); // get and check the exterior function (should be a __setitem__ with 3 args) auto *setFunc = util::getFunc(v->getCallee()); if (setFunc && setFunc->getUnmangledName() == "__setitem__" && std::distance(v->begin(), v->end()) == 3) { auto it = v->begin(); // extract all the arguments to the function // the dictionary and key must not be calls, and the value must // be a call auto *dictValue = *it++; auto *keyValue = *it++; if (isA(dictValue) || isA(keyValue)) return; auto *opCall = cast(*it++); // the call must take exactly two arguments if (!dictValue || !opCall || std::distance(opCall->begin(), opCall->end()) != 2) return; // grab the function, which needs to be an int or float call for now auto *opFunc = util::getFunc(opCall->getCallee()); auto *getCall = cast(opCall->front()); if (!opFunc || !getCall) return; auto *intType = M->getIntType(); auto *floatType = M->getFloatType(); auto *parentType = opFunc->getParentType(); if (!parentType || !(parentType->is(intType) || 
parentType->is(floatType))) return; // check the first argument auto getAnalysis = analyzeGet(getCall); if (!getAnalysis.func) return; // second argument can be any non-null value auto *secondValue = opCall->back(); // verify that we are dealing with the same dictionary and key if (util::match(dictValue, getAnalysis.dict, false, true) && util::match(keyValue, getAnalysis.key, false, true)) { util::CloneVisitor cv(M); Func *replacementFunc; // call non-throwing version if we have a default if (getAnalysis.dflt) { replacementFunc = M->getOrRealizeMethod( dictValue->getType(), "__dict_do_op__", {dictValue->getType(), keyValue->getType(), secondValue->getType(), getAnalysis.dflt->getType(), opFunc->getType()}); } else { replacementFunc = M->getOrRealizeMethod(dictValue->getType(), "__dict_do_op_throws__", {dictValue->getType(), keyValue->getType(), secondValue->getType(), opFunc->getType()}); } if (replacementFunc) { std::vector args = {cv.clone(dictValue), cv.clone(keyValue), cv.clone(secondValue)}; if (getAnalysis.dflt) args.push_back(cv.clone(getAnalysis.dflt)); // sanity check to make sure function is inlined if (args.size() != std::distance(replacementFunc->arg_begin(), replacementFunc->arg_end())) args.push_back(M->N(v, opFunc)); v->replaceAll(util::call(replacementFunc, args)); } } } } } // namespace pythonic } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/pythonic/dict.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace transform { namespace pythonic { /// Pass to optimize calls of form d[x] = func(d[x], any). /// This will work on any dictionary-like object that implements _do_op and /// _do_op_throws as well as getters. 
class DictArithmeticOptimization : public OperatorPass { public: static const std::string KEY; std::string getKey() const override { return KEY; } void handle(CallInstr *v) override; }; } // namespace pythonic } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/pythonic/generator.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "generator.h" #include #include "codon/cir/util/cloning.h" #include "codon/cir/util/irtools.h" #include "codon/cir/util/matching.h" namespace codon { namespace ir { namespace transform { namespace pythonic { namespace { bool isSum(Func *f) { return f && f->getName().rfind(ast::getMangledFunc("std.internal.builtin", "sum"), 0) == 0; } bool isAny(Func *f) { return f && f->getName().rfind(ast::getMangledFunc("std.internal.builtin", "any"), 0) == 0; } bool isAll(Func *f) { return f && f->getName().rfind(ast::getMangledFunc("std.internal.builtin", "all"), 0) == 0; } // Replaces yields with updates to the accumulator variable. 
struct GeneratorSumTransformer : public util::Operator { Var *accumulator; bool valid; explicit GeneratorSumTransformer(Var *accumulator) : util::Operator(), accumulator(accumulator), valid(true) {} void handle(YieldInstr *v) override { auto *M = v->getModule(); auto *val = v->getValue(); if (!val) { valid = false; return; } Value *rhs = val; if (val->getType()->is(M->getBoolType())) { rhs = M->Nr(rhs, M->getInt(1), M->getInt(0)); } Value *add = *M->Nr(accumulator) + *rhs; if (!add || !add->getType()->is(accumulator->getType())) { valid = false; return; } auto *assign = M->Nr(accumulator, add); v->replaceAll(assign); } void handle(ReturnInstr *v) override { auto *M = v->getModule(); auto *newReturn = M->Nr(M->Nr(accumulator)); see(newReturn); if (v->getValue()) { v->replaceAll(util::series(v->getValue(), newReturn)); } else { v->replaceAll(newReturn); } } void handle(YieldInInstr *v) override { valid = false; } }; // Replaces yields with conditional returns of the any/all answer. struct GeneratorAnyAllTransformer : public util::Operator { bool any; // true=any, false=all bool valid; explicit GeneratorAnyAllTransformer(bool any) : util::Operator(), any(any), valid(true) {} void handle(YieldInstr *v) override { auto *M = v->getModule(); auto *val = v->getValue(); auto *valBool = val ? 
(*M->getBoolType())(*val) : nullptr; if (!valBool) { valid = false; return; } else if (!any) { valBool = M->Nr(valBool, M->getBool(false), M->getBool(true)); } auto *newReturn = M->Nr(M->getBool(any)); see(newReturn); auto *rep = M->Nr(valBool, util::series(newReturn)); v->replaceAll(rep); } void handle(ReturnInstr *v) override { if (saw(v)) return; auto *M = v->getModule(); auto *newReturn = M->Nr(M->getBool(!any)); see(newReturn); if (v->getValue()) { v->replaceAll(util::series(v->getValue(), newReturn)); } else { v->replaceAll(newReturn); } } void handle(YieldInInstr *v) override { valid = false; } }; Func *genToSum(BodiedFunc *gen, types::Type *startType, types::Type *outType) { if (!gen || !gen->isGenerator()) return nullptr; auto *M = gen->getModule(); auto *genType = cast(gen->getType()); if (!genType) return nullptr; auto *fn = M->Nr("__sum_wrapper"); std::vector argTypes(genType->begin(), genType->end()); argTypes.push_back(startType); std::vector names; for (auto it = gen->arg_begin(); it != gen->arg_end(); ++it) { names.push_back((*it)->getName()); } names.push_back("start"); auto *fnType = M->getFuncType(outType, argTypes); fn->realize(fnType, names); std::unordered_map argRemap; for (auto it1 = gen->arg_begin(), it2 = fn->arg_begin(); it1 != gen->arg_end() && it2 != fn->arg_end(); ++it1, ++it2) { argRemap.emplace((*it1)->getId(), *it2); } util::CloneVisitor cv(M); auto *body = cast(cv.clone(gen->getBody(), fn, argRemap)); fn->setBody(body); Value *init = M->Nr(fn->arg_back()); if (startType->is(M->getIntType()) && outType->is(M->getFloatType())) init = (*M->getFloatType())(*init); if (!init || !init->getType()->is(outType)) { M->remove(fn); return nullptr; } auto *accumulator = util::makeVar(init, body, fn, /*prepend=*/true); GeneratorSumTransformer xgen(accumulator); fn->accept(xgen); body->push_back(M->Nr(M->Nr(accumulator))); if (!xgen.valid) { M->remove(fn); return nullptr; } return fn; } Func *genToAnyAll(BodiedFunc *gen, bool any) { if (!gen || 
!gen->isGenerator()) return nullptr; auto *M = gen->getModule(); auto *fn = M->Nr(any ? "__any_wrapper" : "__all_wrapper"); auto *genType = cast(gen->getType()); std::vector argTypes(genType->begin(), genType->end()); std::vector names; for (auto it = gen->arg_begin(); it != gen->arg_end(); ++it) { names.push_back((*it)->getName()); } auto *fnType = M->getFuncType(M->getBoolType(), argTypes); fn->realize(fnType, names); std::unordered_map argRemap; for (auto it1 = gen->arg_begin(), it2 = fn->arg_begin(); it1 != gen->arg_end() && it2 != fn->arg_end(); ++it1, ++it2) { argRemap.emplace((*it1)->getId(), *it2); } util::CloneVisitor cv(M); auto *body = cast(cv.clone(gen->getBody(), fn, argRemap)); fn->setBody(body); GeneratorAnyAllTransformer xgen(any); fn->accept(xgen); body->push_back(M->Nr(M->getBool(!any))); if (!xgen.valid) { M->remove(fn); return nullptr; } return fn; } } // namespace const std::string GeneratorArgumentOptimization::KEY = "core-pythonic-generator-argument-opt"; void GeneratorArgumentOptimization::handle(CallInstr *v) { auto *M = v->getModule(); auto *func = util::getFunc(v->getCallee()); if (isSum(func) && v->numArgs() == 2) { auto *call = cast(v->front()); if (!call) return; auto *gen = util::getFunc(call->getCallee()); auto *start = v->back(); if (auto *fn = genToSum(cast(gen), start->getType(), v->getType())) { std::vector args(call->begin(), call->end()); args.push_back(start); v->replaceAll(util::call(fn, args)); } } else { bool any = isAny(func), all = isAll(func); if (!(any || all) || v->numArgs() != 1 || !v->getType()->is(M->getBoolType())) return; auto *call = cast(v->front()); if (!call) return; auto *gen = util::getFunc(call->getCallee()); if (auto *fn = genToAnyAll(cast(gen), any)) { std::vector args(call->begin(), call->end()); v->replaceAll(util::call(fn, args)); } } } } // namespace pythonic } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: 
codon/cir/transform/pythonic/generator.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace transform { namespace pythonic { /// Pass to optimize passing a generator to some built-in functions /// like sum(), any() or all(), which will be converted to regular /// for-loops. class GeneratorArgumentOptimization : public OperatorPass { public: static const std::string KEY; std::string getKey() const override { return KEY; } void handle(CallInstr *v) override; }; } // namespace pythonic } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/pythonic/io.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "io.h" #include #include "codon/cir/util/cloning.h" #include "codon/cir/util/irtools.h" namespace codon { namespace ir { namespace transform { namespace pythonic { namespace { void optimizePrint(CallInstr *v) { auto *M = v->getModule(); auto *inner = cast(v->front()); if (!inner) return; auto *innerFunc = util::getFunc(inner->getCallee()); if (!innerFunc || innerFunc->getUnmangledName() != "__new__" || std::distance(inner->begin(), inner->end()) != 1) return; auto *cat = cast(inner->front()); if (!cat) return; auto *catFunc = util::getFunc(cat->getCallee()); if (!catFunc || catFunc->getUnmangledName() != "cat") return; auto *realCat = M->getOrRealizeMethod(M->getStringType(), "cat", {cat->front()->getType()}); if (realCat->getId() != catFunc->getId()) return; util::CloneVisitor cv(M); std::vector args; std::vector types; for (auto *printArg : *v) { args.push_back(cv.clone(printArg)); types.push_back(printArg->getType()); } args[0] = cv.clone(cat->front()); types[0] = args[0]->getType(); args[1] = M->getString(""); auto *replacement = M->getOrRealizeFunc("print", types, {}, "std.internal.builtin"); if 
(!replacement) return; v->replaceAll(util::call(replacement, args)); } void optimizeWrite(CallInstr *v) { auto *M = v->getModule(); auto it = v->begin(); auto *file = *it++; auto *cat = cast(*it++); if (!cat) return; auto *catFunc = util::getFunc(cat->getCallee()); if (!catFunc || catFunc->getUnmangledName() != "cat") return; auto *realCat = M->getOrRealizeMethod(M->getStringType(), "cat", {cat->front()->getType()}); if (realCat->getId() != catFunc->getId()) return; util::CloneVisitor cv(M); auto *iter = cv.clone(cat->front())->iter(); if (!iter) return; std::vector args = {cv.clone(file), iter}; auto *replacement = M->getOrRealizeMethod(file->getType(), "__file_write_gen__", {args[0]->getType(), args[1]->getType()}); if (!replacement) return; v->replaceAll(util::call(replacement, args)); } } // namespace const std::string IOCatOptimization::KEY = "core-pythonic-io-cat-opt"; void IOCatOptimization::handle(CallInstr *v) { if (util::getStdlibFunc(v->getCallee(), "print")) { optimizePrint(v); } else if (auto *f = cast(util::getFunc(v->getCallee()))) { if (f->getUnmangledName() == "write") optimizeWrite(v); } } } // namespace pythonic } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/pythonic/io.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace transform { namespace pythonic { /// Pass to optimize print str.cat(...) or file.write(str.cat(...)). 
class IOCatOptimization : public OperatorPass { public: static const std::string KEY; std::string getKey() const override { return KEY; } void handle(CallInstr *v) override; }; } // namespace pythonic } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/pythonic/list.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "list.h" #include #include "codon/cir/util/cloning.h" #include "codon/cir/util/irtools.h" namespace codon { namespace ir { namespace transform { namespace pythonic { namespace { const std::string LIST = ast::getMangledClass("std.internal.types.array", "List"); const std::string SLICE = ast::getMangledClass("std.internal.types.slice", "Slice") + "[int,int,int]"; bool isList(Value *v) { return v->getType()->getName().rfind(LIST + "[", 0) == 0; } bool isSlice(Value *v) { return v->getType()->getName() == SLICE; } // The following "handlers" account for the possible sub-expressions we might // see when optimizing list1 + list2 + ... listN. Currently, we optimize: // - Slices: x[a:b:c] (avoid constructing the temporary sliced list) // - Literals: [a, b, c] (just append elements directly) // - Default: (append by iterating over the list) // It is easy to handle new sub-expression types by adding new handlers. // There are three stages in the optimized code: // - Setup: assign all the relevant expressions to variables, making // sure they're evaluated in the same order as before // - Count: figure out the total length of the resulting list // - Create: initialize a new list with the appropriate capacity and // append all the elements // The handlers have virtual functions to generate IR for each of these steps. 
struct ElementHandler { std::vector vars; ElementHandler() : vars() {} virtual ~ElementHandler() {} virtual void setup(SeriesFlow *block, BodiedFunc *parent) = 0; virtual Value *length(Module *M) = 0; virtual Value *append(Value *result) = 0; void doSetup(const std::vector &values, SeriesFlow *block, BodiedFunc *parent) { for (auto *v : values) { vars.push_back(util::makeVar(v, block, parent)); } } static std::unique_ptr get(Value *v, types::Type *ty); }; struct DefaultHandler : public ElementHandler { Value *element; DefaultHandler(Value *element) : ElementHandler(), element(element) {} void setup(SeriesFlow *block, BodiedFunc *parent) override { doSetup({element}, block, parent); } Value *length(Module *M) override { auto *e = M->Nr(vars[0]); auto *ty = element->getType(); auto *fn = M->getOrRealizeMethod(ty, "_list_add_opt_default_len", {ty}); seqassertn(fn, "could not find default list length helper"); return util::call(fn, {e}); } Value *append(Value *result) override { auto *M = result->getModule(); auto *e = M->Nr(vars[0]); auto *ty = result->getType(); auto *fn = M->getOrRealizeMethod(ty, "_list_add_opt_default_append", {ty, ty}); seqassertn(fn, "could not find default list append helper"); return util::call(fn, {result, e}); } static std::unique_ptr get(Value *v, types::Type *ty) { if (!v->getType()->is(ty)) return {}; return std::make_unique(v); } }; struct SliceHandler : public ElementHandler { Value *element; Value *slice; SliceHandler(Value *element, Value *slice) : ElementHandler(), element(element), slice(slice) {} void setup(SeriesFlow *block, BodiedFunc *parent) override { doSetup({element, slice}, block, parent); } Value *length(Module *M) override { auto *e = M->Nr(vars[0]); auto *s = M->Nr(vars[1]); auto *ty = element->getType(); auto *fn = M->getOrRealizeMethod(ty, "_list_add_opt_slice_len", {ty, slice->getType()}); seqassertn(fn, "could not find slice list length helper"); return util::call(fn, {e, s}); } Value *append(Value *result) override 
{ auto *M = result->getModule(); auto *e = M->Nr(vars[0]); auto *s = M->Nr(vars[1]); auto *ty = result->getType(); auto *fn = M->getOrRealizeMethod(ty, "_list_add_opt_slice_append", {ty, ty, slice->getType()}); seqassertn(fn, "could not find slice list append helper"); return util::call(fn, {result, e, s}); } static std::unique_ptr get(Value *v, types::Type *ty) { if (!v->getType()->is(ty)) return {}; if (auto *c = cast(v)) { auto *func = util::getFunc(c->getCallee()); if (func && func->getUnmangledName() == Module::GETITEM_MAGIC_NAME && std::distance(c->begin(), c->end()) == 2 && isList(c->front()) && isSlice(c->back())) { return std::make_unique(c->front(), c->back()); } } return {}; } }; struct LiteralHandler : public ElementHandler { std::vector elements; LiteralHandler(std::vector elements) : ElementHandler(), elements(std::move(elements)) {} void setup(SeriesFlow *block, BodiedFunc *parent) override { doSetup(elements, block, parent); } Value *length(Module *M) override { return M->getInt(elements.size()); } Value *append(Value *result) override { auto *M = result->getModule(); auto *ty = result->getType(); auto *block = M->Nr(); if (vars.empty()) return block; auto *fn = M->getOrRealizeMethod(ty, "_list_add_opt_literal_append", {ty, elements[0]->getType()}); seqassertn(fn, "could not find literal list append helper"); for (auto *var : vars) { block->push_back(util::call(fn, {result, M->Nr(var)})); } return block; } static std::unique_ptr get(Value *v, types::Type *ty) { if (!v->getType()->is(ty)) return {}; if (auto *attr = v->getAttribute()) { std::vector elements; for (auto &element : attr->elements) { if (element.star) return {}; elements.push_back(element.value); } return std::make_unique(std::move(elements)); } return {}; } }; std::unique_ptr ElementHandler::get(Value *v, types::Type *ty) { if (auto h = SliceHandler::get(v, ty)) return std::move(h); if (auto h = LiteralHandler::get(v, ty)) return std::move(h); return DefaultHandler::get(v, ty); } struct 
InspectionResult { bool valid = true; std::vector args; }; void inspect(Value *v, InspectionResult &r) { // check if add first then go from there if (isList(v)) { if (auto *c = cast(v)) { auto *func = util::getFunc(c->getCallee()); if (func && func->getUnmangledName() == Module::ADD_MAGIC_NAME && c->numArgs() == 2 && isList(c->front()) && isList(c->back())) { inspect(c->front(), r); inspect(c->back(), r); return; } } r.args.push_back(v); } else { r.valid = false; } } Value *optimize(BodiedFunc *parent, InspectionResult &r) { if (!r.valid || r.args.size() <= 1) return nullptr; auto *M = parent->getModule(); auto *ty = r.args[0]->getType(); util::CloneVisitor cv(M); std::vector> handlers; for (auto *v : r.args) { handlers.push_back(ElementHandler::get(cv.clone(v), ty)); } auto *opt = M->Nr(); auto *len = util::makeVar(M->getInt(0), opt, parent); for (auto &h : handlers) { h->setup(opt, parent); } for (auto &h : handlers) { opt->push_back(M->Nr(len, *M->Nr(len) + *h->length(M))); } auto *fn = M->getOrRealizeMethod(ty, "_list_add_opt_opt_new", {M->getIntType()}); seqassertn(fn, "could not find list new helper"); auto *result = util::makeVar(util::call(fn, {M->Nr(len)}), opt, parent); for (auto &h : handlers) { opt->push_back(h->append(M->Nr(result))); } return M->Nr(opt, M->Nr(result)); } } // namespace const std::string ListAdditionOptimization::KEY = "core-pythonic-list-addition-opt"; void ListAdditionOptimization::handle(CallInstr *v) { auto *M = v->getModule(); auto *f = util::getFunc(v->getCallee()); if (!f || f->getUnmangledName() != Module::ADD_MAGIC_NAME) return; InspectionResult r; inspect(v, r); auto *parent = cast(getParentFunc()); if (auto *opt = optimize(parent, r)) v->replaceAll(opt); } } // namespace pythonic } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/pythonic/list.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace transform { namespace pythonic { /// Pass to optimize list1 + list2 + ... /// Also handles list slices and list literals efficiently. class ListAdditionOptimization : public OperatorPass { public: static const std::string KEY; std::string getKey() const override { return KEY; } void handle(CallInstr *v) override; }; } // namespace pythonic } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/pythonic/str.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "str.h" #include #include "codon/cir/util/cloning.h" #include "codon/cir/util/irtools.h" namespace codon { namespace ir { namespace transform { namespace pythonic { namespace { struct InspectionResult { bool valid = true; std::vector args; }; bool isString(Value *v) { auto *M = v->getModule(); return v->getType()->is(M->getStringType()); } void inspect(Value *v, InspectionResult &r) { // check if add first then go from there if (isString(v)) { if (auto *c = cast(v)) { auto *func = util::getFunc(c->getCallee()); if (func && func->getUnmangledName() == Module::ADD_MAGIC_NAME && c->numArgs() == 2 && isString(c->front()) && isString(c->back())) { inspect(c->front(), r); inspect(c->back(), r); return; } } r.args.push_back(v); } else { r.valid = false; } } } // namespace const std::string StrAdditionOptimization::KEY = "core-pythonic-str-addition-opt"; void StrAdditionOptimization::handle(CallInstr *v) { auto *M = v->getModule(); auto *f = util::getFunc(v->getCallee()); if (!f || f->getUnmangledName() != Module::ADD_MAGIC_NAME) return; InspectionResult r; inspect(v, r); if (r.valid && r.args.size() > 2) { std::vector args; util::CloneVisitor cv(M); for (auto *arg : r.args) { args.push_back(cv.clone(arg)); } auto *arg = util::makeTuple(args, M); args = {arg}; auto *replacementFunc = 
M->getOrRealizeMethod(M->getStringType(), "cat", {arg->getType()}); seqassertn(replacementFunc, "could not find cat function [{}]", v->getSrcInfo()); v->replaceAll(util::call(replacementFunc, args)); } } } // namespace pythonic } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/pythonic/str.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/transform/pass.h" namespace codon { namespace ir { namespace transform { namespace pythonic { /// Pass to optimize str1 + str2 + ... class StrAdditionOptimization : public OperatorPass { public: static const std::string KEY; std::string getKey() const override { return KEY; } void handle(CallInstr *v) override; }; } // namespace pythonic } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/transform/rewrite.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/transform/pass.h" #include "codon/cir/util/visitor.h" namespace codon { namespace ir { namespace transform { /// Base for rewrite rules. class RewriteRule : public util::Visitor { private: Value *result = nullptr; protected: void defaultVisit(Node *) override {} void setResult(Value *r) { result = r; } void resetResult() { setResult(nullptr); } Value *getResult() const { return result; } public: virtual ~RewriteRule() noexcept = default; /// Apply the rule. /// @param v the value to rewrite /// @return nullptr if no rewrite, the replacement otherwise Value *apply(Value *v) { v->accept(*this); auto *replacement = getResult(); resetResult(); return replacement; } }; /// A collection of rewrite rules. class Rewriter { private: std::unordered_map> rules; int numReplacements = 0; public: /// Adds a given rewrite rule with the given key. 
/// @param key the rule's key /// @param rule the rewrite rule void registerRule(const std::string &key, std::unique_ptr rule) { rules.emplace(std::make_pair(key, std::move(rule))); } /// Applies all rewrite rules to the given node, and replaces the given /// node with the result of the rewrites. /// @param v the node to rewrite void rewrite(Value *v) { Value *result = v; for (auto &r : rules) { if (auto *rep = r.second->apply(result)) { ++numReplacements; result = rep; } } if (v != result) v->replaceAll(result); } /// @return the number of replacements int getNumReplacements() const { return numReplacements; } /// Sets the replacement count to zero. void reset() { numReplacements = 0; } }; } // namespace transform } // namespace ir } // namespace codon ================================================ FILE: codon/cir/types/types.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "types.h" #include #include #include #include "codon/cir/module.h" #include "codon/cir/util/irtools.h" #include "codon/cir/util/iterators.h" #include "codon/cir/util/visitor.h" #include "codon/cir/value.h" #include "codon/parser/cache.h" #include namespace codon { namespace ir { namespace types { namespace { std::vector extractTypes(const std::vector &gens) { std::vector ret; for (auto &g : gens) ret.push_back(g.type); return ret; } } // namespace const char Type::NodeId = 0; std::vector Type::doGetGenerics() const { if (!astType) return {}; std::vector ret; for (auto &g : astType->getClass()->generics) { if (auto ai = g.type->getIntStatic()) { ret.emplace_back(ai->value); } else if (auto ai = g.type->getBoolStatic()) { ret.emplace_back(int(ai->value)); } else if (auto as = g.type->getStrStatic()) { ret.emplace_back(as->value); } else if (auto ac = g.type->getClass()) { ret.emplace_back( getModule()->getCache()->realizeType(ac, extractTypes(ac->generics))); } else { seqassertn(false, "IR only supports int, bool or str statics [{}]", 
g.type->getSrcInfo()); } } return ret; } Value *Type::doConstruct(std::vector args) { auto *module = getModule(); std::vector argTypes; for (auto *a : args) argTypes.push_back(a->getType()); auto *fn = module->getOrRealizeMethod(this, Module::NEW_MAGIC_NAME, argTypes); if (!fn) return nullptr; return module->Nr(module->Nr(fn), args); } const char PrimitiveType::NodeId = 0; const char IntType::NodeId = 0; const char FloatType::NodeId = 0; const char Float32Type::NodeId = 0; const char Float16Type::NodeId = 0; const char BFloat16Type::NodeId = 0; const char Float128Type::NodeId = 0; const char BoolType::NodeId = 0; const char ByteType::NodeId = 0; const char VoidType::NodeId = 0; const char MemberedType::NodeId = 0; const char RecordType::NodeId = 0; RecordType::RecordType(std::string name, std::vector fieldTypes, std::vector fieldNames) : AcceptorExtend(std::move(name)) { for (auto i = 0; i < fieldTypes.size(); ++i) { fields.emplace_back(fieldNames[i], fieldTypes[i]); } } RecordType::RecordType(std::string name, std::vector mTypes) : AcceptorExtend(std::move(name)) { for (int i = 0; i < mTypes.size(); ++i) { fields.emplace_back(std::to_string(i + 1), mTypes[i]); } } std::vector RecordType::doGetUsedTypes() const { std::vector ret; for (auto &f : fields) ret.push_back(const_cast(f.getType())); return ret; } Type *RecordType::getMemberType(const std::string &n) const { auto it = std::find_if(fields.begin(), fields.end(), [n](auto &x) { return x.getName() == n; }); return (it != fields.end()) ? it->getType() : nullptr; } int RecordType::getMemberIndex(const std::string &n) const { auto it = std::find_if(fields.begin(), fields.end(), [n](auto &x) { return x.getName() == n; }); int index = std::distance(fields.begin(), it); return (index < fields.size()) ? 
index : -1; } void RecordType::realize(std::vector mTypes, std::vector mNames) { fields.clear(); for (auto i = 0; i < mTypes.size(); ++i) { fields.emplace_back(mNames[i], mTypes[i]); } } const char RefType::NodeId = 0; bool RefType::doIsContentAtomic() const { auto *contents = getContents(); return !std::any_of(contents->begin(), contents->end(), [](auto &field) { return field.getName().rfind(".__vtable__", 0) != 0 && !field.getType()->isAtomic(); }); } Value *RefType::doConstruct(std::vector args) { auto *module = getModule(); auto *argsTuple = util::makeTuple(args, module); auto *constructFn = module->getOrRealizeFunc("construct_ref", {argsTuple->getType()}, {this}, "std.internal.gc"); if (!constructFn) return nullptr; std::vector callArgs = {argsTuple}; return module->Nr(module->Nr(constructFn), callArgs); } const char FuncType::NodeId = 0; std::vector FuncType::doGetGenerics() const { auto t = getAstType(); if (!t) return {}; auto astType = t->getFunc(); if (!astType) return {}; std::vector ret; for (auto &g : astType->funcGenerics) { if (auto ai = g.type->getIntStatic()) { ret.emplace_back(ai->value); } else if (auto ai = g.type->getBoolStatic()) { ret.emplace_back(int(ai->value)); } else if (auto as = g.type->getStrStatic()) { ret.emplace_back(as->value); } else if (auto ac = g.type->getClass()) { ret.emplace_back( getModule()->getCache()->realizeType(ac, extractTypes(ac->generics))); } else { seqassertn(false, "IR only supports int, bool or str statics [{}]", g.type->getSrcInfo()); } } return ret; } std::vector FuncType::doGetUsedTypes() const { auto ret = argTypes; ret.push_back(rType); return ret; } const char DerivedType::NodeId = 0; const char PointerType::NodeId = 0; std::string PointerType::getInstanceName(Type *base) { return fmt::format(FMT_STRING("Pointer[{}]"), base->referenceString()); } const char OptionalType::NodeId = 0; std::string OptionalType::getInstanceName(Type *base) { return fmt::format(FMT_STRING("Optional[{}]"), 
base->referenceString()); } const char GeneratorType::NodeId = 0; std::string GeneratorType::getInstanceName(Type *base) { return fmt::format(FMT_STRING("Generator[{}]"), base->referenceString()); } const char IntNType::NodeId = 0; std::string IntNType::getInstanceName(unsigned int len, bool sign) { return fmt::format(FMT_STRING("{}Int{}"), sign ? "" : "U", len); } const char VectorType::NodeId = 0; std::string VectorType::getInstanceName(unsigned int count, PrimitiveType *base) { return fmt::format(FMT_STRING("Vector[{}, {}]"), count, base->referenceString()); } const char UnionType::NodeId = 0; std::string UnionType::getInstanceName(const std::vector &types) { std::vector names; for (auto *type : types) { names.push_back(type->referenceString()); } return fmt::format(FMT_STRING("Union[{}]"), fmt::join(names.begin(), names.end(), ", ")); } } // namespace types } // namespace ir } // namespace codon ================================================ FILE: codon/cir/types/types.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once #include #include #include #include #include #include "codon/cir/base.h" #include "codon/cir/util/packs.h" #include "codon/cir/util/visitor.h" #include "codon/parser/ast.h" #include #include namespace codon { namespace ir { class Value; namespace types { class Type; class Generic { private: union { int64_t staticValue; char *staticStringValue; types::Type *typeValue; } value; enum { STATIC, STATIC_STR, TYPE } tag; public: Generic(int64_t staticValue) : value(), tag(STATIC) { value.staticValue = staticValue; } Generic(const std::string &staticValue) : value(), tag(STATIC_STR) { value.staticStringValue = new char[staticValue.size() + 1]; strncpy(value.staticStringValue, staticValue.data(), staticValue.size()); value.staticStringValue[staticValue.size()] = 0; } Generic(types::Type *typeValue) : value(), tag(TYPE) { value.typeValue = typeValue; } Generic(const types::Generic &) = default; ~Generic() { // if (tag == STATIC_STR) // delete[] value.staticStringValue; } /// @return true if the generic is a type bool isType() const { return tag == TYPE; } /// @return true if the generic is static bool isStatic() const { return tag == STATIC; } /// @return true if the generic is static bool isStaticStr() const { return tag == STATIC_STR; } /// @return the static value int64_t getStaticValue() const { return value.staticValue; } /// @return the static string value std::string getStaticStringValue() const { return value.staticStringValue; } /// @return the type value types::Type *getTypeValue() const { return value.typeValue; } }; /// Type from which other CIR types derive. Generally types are immutable. 
class Type : public ReplaceableNodeBase {
  // NOTE(review): template argument lists (e.g. after `ReplaceableNodeBase`,
  // `AcceptorExtend`, `std::vector`, `cast`, `template`) and the targets of
  // bare `#include` directives appear to be missing throughout this copy —
  // presumably stripped during extraction; confirm against the upstream files.
private:
  /// the associated AST type
  ast::types::TypePtr astType;

public:
  static const char NodeId;

  using ReplaceableNodeBase::ReplaceableNodeBase;

  virtual ~Type() noexcept = default;

  /// @return the types used by the actual type (delegates to doGetUsedTypes())
  std::vector getUsedTypes() const final { return getActual()->doGetUsedTypes(); }
  /// Types do not support replacing used types: this always trips the
  /// "types not replaceable" assertion.
  /// @return -1 (only reached if the assertion does not abort)
  int replaceUsedType(const std::string &name, Type *newType) final {
    seqassertn(false, "types not replaceable");
    return -1;
  }
  using Node::replaceUsedType;

  /// Equality is decided by comparing the two types' names.
  /// @param other another type
  /// @return true if this type is equal to the argument type
  bool is(types::Type *other) const { return getName() == other->getName(); }

  /// A type is "atomic" iff it contains no pointers to dynamically
  /// allocated memory. Atomic types do not need to be scanned during
  /// garbage collection.
  /// @return true if the type is atomic
  bool isAtomic() const { return getActual()->doIsAtomic(); }

  /// Checks if the contents (i.e. within an allocated block of memory)
  /// of a type are atomic. Currently only meaningful for reference types.
  /// @return true if the type's content is atomic
  bool isContentAtomic() const { return getActual()->doIsContentAtomic(); }

  /// @return the ast type
  ast::types::TypePtr getAstType() const { return getActual()->astType; }
  /// Sets the ast type. Should not generally be used.
  /// @param t the new type
  void setAstType(ast::types::TypePtr t) { getActual()->astType = std::move(t); }

  /// @return the generics used in the type
  std::vector getGenerics() const { return getActual()->doGetGenerics(); }

  /// Constructs an instance of the type given the supplied args.
  /// @param args the arguments
  /// @return the new value
  Value *construct(std::vector args) {
    return getActual()->doConstruct(std::move(args));
  }
  /// Convenience form of construct(): gathers the arguments into a vector with
  /// util::stripPack and passes that vector to construct().
  /// @param args the arguments
  /// @return the new value
  template Value *operator()(Args &&...args) {
    std::vector dst;
    util::stripPack(dst, std::forward(args)...);
    return construct(dst);
  }

private:
  // Customization points behind the public accessors above; each public method
  // calls the matching hook on getActual().
  virtual std::vector doGetGenerics() const;

  /// Default: no used types.
  virtual std::vector doGetUsedTypes() const { return {}; }
  virtual bool doIsAtomic() const = 0;

  /// Default: contents are considered atomic.
  virtual bool doIsContentAtomic() const { return true; }

  virtual Value *doConstruct(std::vector args);
};

/// Type from which primitive atomic types derive.
class PrimitiveType : public AcceptorExtend {
public:
  static const char NodeId;

  using AcceptorExtend::AcceptorExtend;

private:
  /// Primitive types are always atomic.
  bool doIsAtomic() const final { return true; }
};

/// Int type (64-bit signed integer)
class IntType : public AcceptorExtend {
public:
  static const char NodeId;

  /// Constructs an int type.
  IntType() : AcceptorExtend("int") {}
};

/// Float type (64-bit double)
class FloatType : public AcceptorExtend {
public:
  static const char NodeId;

  /// Constructs a float type.
  FloatType() : AcceptorExtend("float") {}
};

/// Float32 type (32-bit float)
class Float32Type : public AcceptorExtend {
public:
  static const char NodeId;

  /// Constructs a float32 type.
  Float32Type() : AcceptorExtend("float32") {}
};

/// Float16 type (16-bit float)
class Float16Type : public AcceptorExtend {
public:
  static const char NodeId;

  /// Constructs a float16 type.
  Float16Type() : AcceptorExtend("float16") {}
};

/// BFloat16 type (16-bit brain float)
class BFloat16Type : public AcceptorExtend {
public:
  static const char NodeId;

  /// Constructs a bfloat16 type.
  BFloat16Type() : AcceptorExtend("bfloat16") {}
};

/// Float128 type (128-bit float)
class Float128Type : public AcceptorExtend {
public:
  static const char NodeId;

  /// Constructs a float128 type.
  Float128Type() : AcceptorExtend("float128") {}
};

/// Bool type (8-bit unsigned integer; either 0 or 1)
class BoolType : public AcceptorExtend {
public:
  static const char NodeId;

  /// Constructs a bool type.
  BoolType() : AcceptorExtend("bool") {}
};

/// Byte type (8-bit unsigned integer)
class ByteType : public AcceptorExtend {
public:
  static const char NodeId;

  /// Constructs a byte type.
  ByteType() : AcceptorExtend("byte") {}
};

/// Void type
class VoidType : public AcceptorExtend {
public:
  static const char NodeId;

  /// Constructs a void type.
  VoidType() : AcceptorExtend("void") {}
};

/// Type from which membered types derive.
class MemberedType : public AcceptorExtend {
public:
  static const char NodeId;

  /// Object that represents a field in a membered type.
  class Field {
  private:
    /// the field's name
    std::string name;
    /// the field's type
    Type *type;

  public:
    /// Constructs a field.
    /// @param name the field's name
    /// @param type the field's type
    Field(std::string name, Type *type) : name(std::move(name)), type(type) {}

    /// @return the field's name
    const std::string &getName() const { return name; }
    /// @return the field type
    Type *getType() const { return type; }
  };

  using const_iterator = std::vector::const_iterator;
  using const_reference = std::vector::const_reference;

  /// Constructs a membered type.
  /// @param name the type's name
  explicit MemberedType(std::string name) : AcceptorExtend(std::move(name)) {}

  /// Gets a field type by name.
  /// @param name the field's name
  /// @return the type if it exists
  virtual Type *getMemberType(const std::string &name) const = 0;

  /// Gets the index of a field by name.
  /// @param name the field's name
  /// @return 0-based field index, or -1 if not found
  virtual int getMemberIndex(const std::string &name) const = 0;

  /// @return iterator to the first field
  virtual const_iterator begin() const = 0;
  /// @return iterator beyond the last field
  virtual const_iterator end() const = 0;
  /// @return a reference to the first field
  virtual const_reference front() const = 0;
  /// @return a reference to the last field
  virtual const_reference back() const = 0;

  /// Changes the body of the membered type.
  /// @param mTypes the new body
  /// @param mNames the new names
  virtual void realize(std::vector mTypes, std::vector mNames) = 0;
};

/// Membered type equivalent to C structs/C++ PODs
class RecordType : public AcceptorExtend {
private:
  /// the record's fields
  std::vector fields;

public:
  static const char NodeId;

  /// Constructs a record type.
  /// @param name the type's name
  /// @param fieldTypes the member types
  /// @param fieldNames the member names
  RecordType(std::string name, std::vector fieldTypes, std::vector fieldNames);

  /// Constructs a record type. The field's names are "1", "2"...
  /// @param name the type's name
  /// @param mTypes a vector of member types
  RecordType(std::string name, std::vector mTypes);

  /// Constructs an empty record type.
  /// @param name the name
  explicit RecordType(std::string name) : AcceptorExtend(std::move(name)) {}

  Type *getMemberType(const std::string &n) const override;
  int getMemberIndex(const std::string &n) const override;

  const_iterator begin() const override { return fields.begin(); }
  const_iterator end() const override { return fields.end(); }
  const_reference front() const override { return fields.front(); }
  const_reference back() const override { return fields.back(); }

  void realize(std::vector mTypes, std::vector mNames) override;

private:
  std::vector doGetUsedTypes() const override;

  /// A record is atomic unless at least one field's type is not atomic
  /// (so a record without fields is atomic).
  bool doIsAtomic() const override {
    return !std::any_of(fields.begin(), fields.end(),
                        [](auto &field) { return !field.getType()->isAtomic(); });
  }
};

/// Membered type that is passed by reference. Similar to Python classes.
class RefType : public AcceptorExtend {
private:
  /// the internal contents of the type
  Type *contents;
  /// true if type is polymorphic and needs RTTI
  bool polymorphic;

public:
  static const char NodeId;

  /// Constructs a reference type.
  /// @param name the type's name
  /// @param contents the type's contents
  /// @param polymorphic true if type is polymorphic
  RefType(std::string name, RecordType *contents, bool polymorphic = false)
      : AcceptorExtend(std::move(name)), contents(contents),
        polymorphic(polymorphic) {}

  /// @return true if the type is polymorphic and needs RTTI
  bool isPolymorphic() const { return polymorphic; }
  /// Sets whether the type is polymorphic. Should not generally be used.
  /// @param p true if polymorphic
  void setPolymorphic(bool p = true) { polymorphic = p; }

  // Member lookup and iteration are forwarded to the contents record.
  Type *getMemberType(const std::string &n) const override {
    return getContents()->getMemberType(n);
  }
  int getMemberIndex(const std::string &n) const override {
    return getContents()->getMemberIndex(n);
  }

  const_iterator begin() const override { return getContents()->begin(); }
  const_iterator end() const override { return getContents()->end(); }
  const_reference front() const override { return getContents()->front(); }
  const_reference back() const override { return getContents()->back(); }

  /// @return the reference type's contents
  RecordType *getContents() const { return cast(contents); }
  /// Sets the reference type's contents. Should not generally be used.
  /// @param t the new contents
  void setContents(RecordType *t) { contents = t; }

  /// Realizes the contents record with the given member types and names.
  void realize(std::vector mTypes, std::vector mNames) override {
    getContents()->realize(std::move(mTypes), std::move(mNames));
  }

private:
  std::vector doGetUsedTypes() const override { return {contents}; }

  /// Reference types are never atomic.
  bool doIsAtomic() const override { return false; }

  bool doIsContentAtomic() const override;

  Value *doConstruct(std::vector args) override;
};

/// Type associated with a CIR function.
class FuncType : public AcceptorExtend {
public:
  using const_iterator = std::vector::const_iterator;
  using const_reference = std::vector::const_reference;

private:
  /// return type
  Type *rType;
  /// argument types
  std::vector argTypes;
  /// whether the function is variadic (e.g. "printf" in C)
  bool variadic;

public:
  static const char NodeId;

  /// Constructs a function type.
  /// @param name the type's name
  /// @param rType the function's return type
  /// @param argTypes the function's arg types
  /// @param variadic true if the function is variadic
  FuncType(std::string name, Type *rType, std::vector argTypes, bool variadic = false)
      : AcceptorExtend(std::move(name)), rType(rType), argTypes(std::move(argTypes)),
        variadic(variadic) {}

  /// @return the function's return type
  Type *getReturnType() const { return rType; }

  /// @return true if the function is variadic
  bool isVariadic() const { return variadic; }

  /// @return iterator to the first argument
  const_iterator begin() const { return argTypes.begin(); }
  /// @return iterator beyond the last argument
  const_iterator end() const { return argTypes.end(); }
  /// @return a reference to the first argument
  const_reference front() const { return argTypes.front(); }
  /// @return a reference to the last argument
  const_reference back() const { return argTypes.back(); }

private:
  std::vector doGetGenerics() const override;

  std::vector doGetUsedTypes() const override;

  /// Function types are never atomic.
  bool doIsAtomic() const override { return false; }
};

/// Base for simple derived types.
class DerivedType : public AcceptorExtend {
private:
  /// the base type
  Type *base;

public:
  static const char NodeId;

  /// Constructs a derived type.
  /// @param name the type's name
  /// @param base the type's base
  explicit DerivedType(std::string name, Type *base)
      : AcceptorExtend(std::move(name)), base(base) {}

  /// @return the type's base
  Type *getBase() const { return base; }

private:
  /// By default a derived type is atomic iff its base is.
  bool doIsAtomic() const override { return base->isAtomic(); }

  std::vector doGetUsedTypes() const override { return {base}; }
};

/// Type of a pointer to another CIR type
class PointerType : public AcceptorExtend {
public:
  static const char NodeId;

  /// Constructs a pointer type.
  /// @param base the type's base
  explicit PointerType(Type *base) : AcceptorExtend(getInstanceName(base), base) {}

  /// @param base the type's base
  /// @return the name used for the pointer type over that base
  static std::string getInstanceName(Type *base);

private:
  /// Pointer types are never atomic.
  bool doIsAtomic() const override { return false; }
};

/// Type of an optional containing another CIR type
class OptionalType : public AcceptorExtend {
public:
  static const char NodeId;

  /// Constructs an optional type.
  /// @param base the type's base
  explicit OptionalType(Type *base) : AcceptorExtend(getInstanceName(base), base) {}

  /// @param base the type's base
  /// @return the name used for the optional type over that base
  static std::string getInstanceName(Type *base);

private:
  /// An optional is atomic iff its base is.
  bool doIsAtomic() const override { return getBase()->isAtomic(); }
};

/// Type of a generator yielding another CIR type
class GeneratorType : public AcceptorExtend {
public:
  static const char NodeId;

  /// Constructs a generator type.
  /// @param base the type's base
  explicit GeneratorType(Type *base) : AcceptorExtend(getInstanceName(base), base) {}

  /// @param base the type's base
  /// @return the name used for the generator type over that base
  static std::string getInstanceName(Type *base);

private:
  /// Generator types are never atomic.
  bool doIsAtomic() const override { return false; }
};

/// Type of a variably sized integer
class IntNType : public AcceptorExtend {
private:
  /// length of the integer
  unsigned len;
  /// whether the variable is signed
  bool sign;

public:
  static const char NodeId;

  static const unsigned MAX_LEN = 2048;

  /// Constructs a variably sized integer type.
  /// @param len the length of the integer
  /// @param sign true if signed, false otherwise
  IntNType(unsigned len, bool sign)
      : AcceptorExtend(getInstanceName(len, sign)), len(len), sign(sign) {}

  /// @return the length of the integer
  unsigned getLen() const { return len; }
  /// @return true if signed
  bool isSigned() const { return sign; }

  /// @return the name of the opposite signed corresponding type
  std::string oppositeSignName() const { return getInstanceName(len, !sign); }

  /// @param len the length of the integer
  /// @param sign true if signed, false otherwise
  /// @return the name used for the integer type with that length and signedness
  static std::string getInstanceName(unsigned len, bool sign);
};

/// Type of a vector of primitives
class VectorType : public AcceptorExtend {
private:
  /// number of elements
  unsigned count;
  /// base type
  PrimitiveType *base;

public:
  static const char NodeId;

  /// Constructs a vector type.
  /// @param count the number of elements
  /// @param base the base type
  VectorType(unsigned count, PrimitiveType *base)
      : AcceptorExtend(getInstanceName(count, base)), count(count), base(base) {}

  /// @return the count of the vector
  unsigned getCount() const { return count; }
  /// @return the base type of the vector
  PrimitiveType *getBase() const { return base; }

  /// @param count the number of elements
  /// @param base the base type
  /// @return the name used for the vector type with that count and base
  static std::string getInstanceName(unsigned count, PrimitiveType *base);
};

/// Type made up of a list of alternative types.
class UnionType : public AcceptorExtend {
private:
  /// alternative types
  std::vector types;

public:
  static const char NodeId;

  using const_iterator = std::vector::const_iterator;
  using const_reference = std::vector::const_reference;

  /// Constructs a UnionType.
  /// @param types the alternative types (must be sorted by caller)
  explicit UnionType(std::vector types) : AcceptorExtend(), types(std::move(types)) {}

  /// @return iterator to the first alternative
  const_iterator begin() const { return types.begin(); }
  /// @return iterator beyond the last alternative
  const_iterator end() const { return types.end(); }
  /// @return a reference to the first alternative
  const_reference front() const { return types.front(); }
  /// @return a reference to the last alternative
  const_reference back() const { return types.back(); }

  /// @param types the alternative types
  /// @return the name used for the union over those types
  static std::string getInstanceName(const std::vector &types);

private:
  std::vector doGetUsedTypes() const override { return types; }

  /// A union is atomic unless at least one alternative is not atomic.
  bool doIsAtomic() const override {
    return !std::any_of(types.begin(), types.end(),
                        [](auto *type) { return !type->isAtomic(); });
  }
};

} // namespace types
} // namespace ir
} // namespace codon

// Formats through the stream insertion operator (fmt::ostream_formatter).
template <> struct fmt::formatter : fmt::ostream_formatter {};

================================================
FILE: codon/cir/util/cloning.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "cloning.h"

#include "codon/cir/util/operator.h"

namespace codon {
namespace ir {
namespace util {
namespace {
/// Operator that records, for every node it is hooked on, each used variable
/// that is not global. A variable may be recorded more than once.
struct GatherLocals : public util::Operator {
  /// the non-global variables seen so far
  std::vector locals;
  void preHook(Node *node) override {
    for (auto *v : node->getUsedVariables()) {
      if (!v->isGlobal())
        locals.push_back(v);
    }
  }
};
} // namespace

/// Clones a value, reusing the existing clone if `other` is already in ctx.
/// If `cloneTo` is given, every non-global variable used by `other` is either
/// remapped according to `remaps` (keyed by variable id) or re-created in
/// `cloneTo`; otherwise only the entries of `remaps` are applied.
Value *CloneVisitor::clone(const Value *other, BodiedFunc *cloneTo,
                           const std::unordered_map &remaps) {
  if (!other)
    return nullptr;

  if (cloneTo) {
    auto *M = cloneTo->getModule();
    GatherLocals gl;
    const_cast(other)->accept(gl);
    for (auto *v : gl.locals) {
      auto it = remaps.find(v->getId());
      if (it != remaps.end()) {
        // caller supplied a replacement for this local
        forceRemap(v, it->second);
      } else {
        // no replacement: create a matching variable owned by cloneTo
        auto *clonedVar = M->N(v, v->getType(), v->isGlobal(), v->isExternal(),
                               v->isThreadLocal(), v->getName());
        cloneTo->push_back(clonedVar);
        forceRemap(v, clonedVar);
      }
    }
  } else {
    // resolve remap keys (variable ids) through the module that owns `other`
    auto *M = other->getModule();
    for (const auto &e : remaps) {
      forceRemap(M->getVar(e.first), e.second);
    }
  }

  auto id = other->getId();
  if (ctx.find(id) == ctx.end()) {
    // first time this value is seen: visit it, then carry over its attributes
    other->accept(*this);
    ctx[id] = result;

    for (auto it = other->attributes_begin(); it != other->attributes_end(); ++it) {
      const auto *attr = other->getAttribute(*it);
      if (attr->needsClone()) {
        ctx[id]->setAttribute(attr->clone(*this), *it);
      }
    }
  }
  return cast(ctx[id]);
}

/// Returns the entry recorded in ctx for the variable's id, or the original
/// variable itself if there is none. Never creates a new variable.
Var *CloneVisitor::clone(const Var *other) {
  if (!other)
    return nullptr;
  auto id = other->getId();
  if (ctx.find(id) != ctx.end())
    return cast(ctx[id]);

  return const_cast(other);
}

/// Creates a new variable with the same type, flags and name as `v`.
void CloneVisitor::visit(const Var *v) {
  result = module->N(v, v->getType(), v->isGlobal(), v->isExternal(),
                     v->isThreadLocal(), v->getName());
}

/// Clones a bodied function: its variables, flags, signature and body.
void CloneVisitor::visit(const BodiedFunc *v) {
  auto *res = Nt(v);
  std::vector argNames;
  for (auto it = v->arg_begin(); it != v->arg_end(); ++it)
    argNames.push_back((*it)->getName());
  // force-clone each variable held by the function into the new function
  for (const auto *var : *v) {
    auto *newVar = forceClone(var);
    res->push_back(newVar);
  }
  res->setUnmangledName(v->getUnmangledName());
  res->setGenerator(v->isGenerator());
  res->setAsync(v->isAsync());
  res->realize(cast(v->getType()), argNames);

  // map each original argument to the argument at the same position in res
  auto argIt1 = v->arg_begin();
  auto argIt2 = res->arg_begin();
  while (argIt1 != v->arg_end()) {
    forceRemap(*argIt1, *argIt2);
    ++argIt1;
    ++argIt2;
  }

  // body might reference this!
  forceRemap(v, res);

  if (v->getBody())
    res->setBody(clone(v->getBody()));
  res->setJIT(v->isJIT());
  result = res;
}

/// Clones an external function: flags, signature and argument remapping.
void CloneVisitor::visit(const ExternalFunc *v) {
  auto *res = Nt(v);
  std::vector argNames;
  for (auto it = v->arg_begin(); it != v->arg_end(); ++it)
    argNames.push_back((*it)->getName());
  res->setUnmangledName(v->getUnmangledName());
  res->setGenerator(v->isGenerator());
  res->setAsync(v->isAsync());
  res->realize(cast(v->getType()), argNames);

  // map each original argument to the argument at the same position in res
  auto argIt1 = v->arg_begin();
  auto argIt2 = res->arg_begin();
  while (argIt1 != v->arg_end()) {
    forceRemap(*argIt1, *argIt2);
    ++argIt1;
    ++argIt2;
  }

  result = res;
}

/// Clones an internal function; additionally copies the parent type.
void CloneVisitor::visit(const InternalFunc *v) {
  auto *res = Nt(v);
  std::vector argNames;
  for (auto it = v->arg_begin(); it != v->arg_end(); ++it)
    argNames.push_back((*it)->getName());
  res->setUnmangledName(v->getUnmangledName());
  res->setGenerator(v->isGenerator());
  res->setAsync(v->isAsync());
  res->realize(cast(v->getType()), argNames);

  // map each original argument to the argument at the same position in res
  auto argIt1 = v->arg_begin();
  auto argIt2 = res->arg_begin();
  while (argIt1 != v->arg_end()) {
    forceRemap(*argIt1, *argIt2);
    ++argIt1;
    ++argIt2;
  }

  res->setParentType(v->getParentType());
  result = res;
}

/// Clones an LLVM function; additionally copies the LLVM body, declarations
/// and literals.
void CloneVisitor::visit(const LLVMFunc *v) {
  auto *res = Nt(v);
  std::vector argNames;
  for (auto it = v->arg_begin(); it != v->arg_end(); ++it)
    argNames.push_back((*it)->getName());
  res->setUnmangledName(v->getUnmangledName());
  res->setGenerator(v->isGenerator());
  res->setAsync(v->isAsync());
  res->realize(cast(v->getType()), argNames);

  // map each original argument to the argument at the same position in res
  auto argIt1 = v->arg_begin();
  auto argIt2 = res->arg_begin();
  while (argIt1 != v->arg_end()) {
    forceRemap(*argIt1, *argIt2);
    ++argIt1;
    ++argIt2;
  }

  res->setLLVMBody(v->getLLVMBody());
  res->setLLVMDeclarations(v->getLLVMDeclarations());
  res->setLLVMLiterals(std::vector(v->literal_begin(), v->literal_end()));
  result = res;
}

// For the value/flow/instruction visitors below: operands that are nodes are
// passed through clone(); plain data (types, fields, constants, flags) is
// passed to the new node unchanged.

void CloneVisitor::visit(const VarValue *v) { result = Nt(v, clone(v->getVar())); }

void CloneVisitor::visit(const PointerValue *v) {
  result = Nt(v, clone(v->getVar()), v->getFields());
}

void CloneVisitor::visit(const SeriesFlow *v) {
  auto *res = Nt(v);
  for (auto *c : *v)
    res->push_back(clone(c));
  result = res;
}

void CloneVisitor::visit(const IfFlow *v) {
  result = Nt(v, clone(v->getCond()), clone(v->getTrueBranch()),
              clone(v->getFalseBranch()));
}

void CloneVisitor::visit(const WhileFlow *v) {
  // The new loop is created empty and registered in ctx before its children
  // are cloned, so a child that refers back to the loop (see the BreakInstr /
  // ContinueInstr visitors) is resolved to the new loop.
  auto *loop = Nt(v, nullptr, nullptr);
  forceRemap(v, loop);
  loop->setCond(clone(v->getCond()));
  loop->setBody(clone(v->getBody()));
  result = loop;
}

void CloneVisitor::visit(const ForFlow *v) {
  // created empty and registered first; see visit(const WhileFlow *)
  auto *loop = Nt(v, nullptr, nullptr, nullptr, std::unique_ptr(), false);
  forceRemap(v, loop);
  loop->setIter(clone(v->getIter()));
  loop->setBody(clone(v->getBody()));
  loop->setVar(clone(v->getVar()));

  if (auto *sched = v->getSchedule()) {
    // copy the schedule, then point each value it uses at that value's clone
    auto schedCloned = std::make_unique(*sched);
    for (auto *val : sched->getUsedValues()) {
      schedCloned->replaceUsedValue(val->getId(), clone(val));
    }
    loop->setSchedule(std::move(schedCloned));
  }
  loop->setAsync(v->isAsync());
  result = loop;
}

void CloneVisitor::visit(const ImperativeForFlow *v) {
  // created with only the step set and registered first; see
  // visit(const WhileFlow *)
  auto *loop = Nt(v, nullptr, v->getStep(), nullptr, nullptr, nullptr,
                  std::unique_ptr());
  forceRemap(v, loop);
  loop->setStart(clone(v->getStart()));
  loop->setBody(clone(v->getBody()));
  loop->setVar(clone(v->getVar()));
  loop->setEnd(clone(v->getEnd()));

  if (auto *sched = v->getSchedule()) {
    // copy the schedule, then point each value it uses at that value's clone
    auto schedCloned = std::make_unique(*sched);
    for (auto *val : sched->getUsedValues()) {
      schedCloned->replaceUsedValue(val->getId(), clone(val));
    }
    loop->setSchedule(std::move(schedCloned));
  }
  result = loop;
}

void CloneVisitor::visit(const TryCatchFlow *v) {
  auto *res = Nt(v, clone(v->getBody()), clone(v->getFinally()), clone(v->getElse()));
  // re-create each catch clause with a cloned handler and variable
  for (auto &c : *v) {
    res->emplace_back(clone(c.getHandler()), c.getType(), clone(c.getVar()));
  }
  result = res;
}

void CloneVisitor::visit(const PipelineFlow *v) {
  std::vector cloned;
  for (const auto &s : *v) {
    cloned.push_back(clone(s));
  }
  result = Nt(v, std::move(cloned));
}

// Custom DSL nodes supply their own cloning through doClone().
void CloneVisitor::visit(const dsl::CustomFlow *v) { result = v->doClone(*this); }

void CloneVisitor::visit(const IntConst *v) {
  result = Nt(v, v->getVal(), v->getType());
}

void CloneVisitor::visit(const FloatConst *v) {
  result = Nt(v, v->getVal(), v->getType());
}

void CloneVisitor::visit(const BoolConst *v) {
  result = Nt(v, v->getVal(), v->getType());
}

void CloneVisitor::visit(const StringConst *v) {
  result = Nt(v, v->getVal(), v->getType());
}

void CloneVisitor::visit(const dsl::CustomConst *v) { result = v->doClone(*this); }

void CloneVisitor::visit(const AssignInstr *v) {
  result = Nt(v, clone(v->getLhs()), clone(v->getRhs()));
}

void CloneVisitor::visit(const ExtractInstr *v) {
  result = Nt(v, clone(v->getVal()), v->getField());
}

void CloneVisitor::visit(const InsertInstr *v) {
  result = Nt(v, clone(v->getLhs()), v->getField(), clone(v->getRhs()));
}

void CloneVisitor::visit(const CallInstr *v) {
  std::vector args;
  for (const auto *a : *v)
    args.push_back(clone(a));
  result = Nt(v, clone(v->getCallee()), std::move(args));
}

void CloneVisitor::visit(const StackAllocInstr *v) {
  result = Nt(v, v->getArrayType(), v->getCount());
}

void CloneVisitor::visit(const TypePropertyInstr *v) {
  result = Nt(v, v->getInspectType(), v->getProperty());
}

void CloneVisitor::visit(const YieldInInstr *v) {
  result = Nt(v, v->getType(), v->isSuspending());
}

void CloneVisitor::visit(const TernaryInstr *v) {
  result = Nt(v, clone(v->getCond()), clone(v->getTrueValue()),
              clone(v->getFalseValue()));
}

// Break/continue keep pointing at the original loop unless cloneLoop is set,
// in which case the loop is passed through clone().
void CloneVisitor::visit(const BreakInstr *v) {
  result = Nt(v, cloneLoop ? clone(v->getLoop()) : v->getLoop());
}

void CloneVisitor::visit(const ContinueInstr *v) {
  result = Nt(v, cloneLoop ? clone(v->getLoop()) : v->getLoop());
}

void CloneVisitor::visit(const ReturnInstr *v) {
  result = Nt(v, clone(v->getValue()));
}

void CloneVisitor::visit(const YieldInstr *v) {
  result = Nt(v, clone(v->getValue()), v->isFinal());
}

void CloneVisitor::visit(const AwaitInstr *v) {
  result = Nt(v, clone(v->getValue()), v->getType(), v->isGenerator());
}

void CloneVisitor::visit(const ThrowInstr *v) {
  result = Nt(v, clone(v->getValue()));
}

void CloneVisitor::visit(const FlowInstr *v) {
  result = Nt(v, clone(v->getFlow()), clone(v->getValue()));
}

void CloneVisitor::visit(const dsl::CustomInstr *v) { result = v->doClone(*this); }

} // namespace util
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/util/cloning.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include

#include "codon/cir/cir.h"
#include "codon/cir/util/visitor.h"

namespace codon {
namespace ir {
namespace util {

/// Visitor that clones CIR nodes. Clones are remembered in a context keyed by
/// the original node's id, so cloning the same node again yields the same
/// clone.
class CloneVisitor : public ConstVisitor {
private:
  /// the clone context
  std::unordered_map ctx;
  /// the result
  Node *result;
  /// the module
  Module *module;
  /// true if break/continue loops should be cloned
  bool cloneLoop;

public:
  /// Constructs a clone visitor.
  /// @param module the module
  /// @param cloneLoop true if break/continue loops should be cloned
  explicit CloneVisitor(Module *module, bool cloneLoop = true)
      : ctx(), result(nullptr), module(module), cloneLoop(cloneLoop) {}

  virtual ~CloneVisitor() noexcept = default;

  // Each visit() stores the newly created node in `result`.
  void visit(const Var *v) override;

  void visit(const BodiedFunc *v) override;
  void visit(const ExternalFunc *v) override;
  void visit(const InternalFunc *v) override;
  void visit(const LLVMFunc *v) override;

  void visit(const VarValue *v) override;
  void visit(const PointerValue *v) override;

  void visit(const SeriesFlow *v) override;
  void visit(const IfFlow *v) override;
  void visit(const WhileFlow *v) override;
  void visit(const ForFlow *v) override;
  void visit(const ImperativeForFlow *v) override;
  void visit(const TryCatchFlow *v) override;
  void visit(const PipelineFlow *v) override;
  void visit(const dsl::CustomFlow *v) override;

  void visit(const IntConst *v) override;
  void visit(const FloatConst *v) override;
  void visit(const BoolConst *v) override;
  void visit(const StringConst *v) override;
  void visit(const dsl::CustomConst *v) override;

  void visit(const AssignInstr *v) override;
  void visit(const ExtractInstr *v) override;
  void visit(const InsertInstr *v) override;
  void visit(const CallInstr *v) override;
  void visit(const StackAllocInstr *v) override;
  void visit(const TypePropertyInstr *v) override;
  void visit(const YieldInInstr *v) override;
  void visit(const TernaryInstr *v) override;
  void visit(const BreakInstr *v) override;
  void visit(const ContinueInstr *v) override;
  void visit(const ReturnInstr *v) override;
  void visit(const YieldInstr *v) override;
  void visit(const AwaitInstr *v) override;
  void visit(const ThrowInstr *v) override;
  void visit(const FlowInstr *v) override;
  void visit(const dsl::CustomInstr *v) override;

  /// Clones a value, returning the previous value if other has already been cloned.
  /// @param other the original
  /// @param cloneTo the function to clone locals to, or null if none
  /// @param remaps variable re-mappings
  /// @return the clone
  Value *clone(const Value *other, BodiedFunc *cloneTo = nullptr,
               const std::unordered_map &remaps = {});

  /// Returns the original unless the variable has been force cloned.
  /// @param other the original
  /// @return the original or the previous clone
  Var *clone(const Var *other);

  /// Clones a flow, returning the previous value if other has already been cloned.
  /// @param other the original
  /// @return the clone
  Flow *clone(const Flow *other) { return cast(clone(static_cast(other))); }

  /// Forces a clone. No difference for values but ensures that variables are actually
  /// cloned.
  /// @param other the original
  /// @return the clone
  template NodeType *forceClone(const NodeType *other) {
    if (!other)
      return nullptr;

    auto id = other->getId();
    if (ctx.find(id) == ctx.end()) {
      // not cloned yet: visit the node, then carry over its attributes
      other->accept(*this);
      ctx[id] = result;

      for (auto it = other->attributes_begin(); it != other->attributes_end(); ++it) {
        const auto *attr = other->getAttribute(*it);
        if (attr->needsClone()) {
          ctx[id]->setAttribute(attr->forceClone(*this), *it);
        }
      }
    }
    return cast(ctx[id]);
  }

  /// Remaps a clone. Any existing context entry for the original is overwritten.
  /// @param original the original
  /// @param newVal the clone
  template void forceRemap(const NodeType *original, const NodeType *newVal) {
    ctx[original->getId()] = const_cast(newVal);
  }

  /// Clones a pipeline stage: the callee and every argument are cloned, while
  /// the generator and parallel flags are copied.
  /// @param other the original stage
  /// @return the cloned stage
  PipelineFlow::Stage clone(const PipelineFlow::Stage &other) {
    std::vector args;
    for (const auto *a : other)
      args.push_back(clone(a));
    return {clone(other.getCallee()), std::move(args), other.isGenerator(),
            other.isParallel()};
  }

private:
  /// Creates a new node in the module from `source`, passing `args` followed by
  /// the source's name.
  /// @param source the node being cloned
  /// @param args the constructor arguments that precede the name
  /// @return the new node
  template NodeType *Nt(const NodeType *source, Args... args) {
    return module->N(source, std::forward(args)..., source->getName());
  }
};

} // namespace util
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/util/context.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include

namespace codon {
namespace ir {
namespace util {

/// Base for CIR visitor contexts. Keeps a stack of frames; the current frame
/// is the most recently added one.
template class CIRContext {
private:
  /// the frame stack; the back element is the current frame
  std::vector frames;

public:
  /// Emplaces a frame onto the stack.
  /// @param args a parameter pack of the arguments
  template void emplaceFrame(Args... args) { frames.emplace_back(args...); }
  /// Replaces a frame. The current frame is popped and the new one pushed, so
  /// at least one frame must already be present.
  /// @param newFrame the new frame
  void replaceFrame(Frame newFrame) {
    frames.pop_back();
    frames.push_back(newFrame);
  }
  /// @return all frames
  std::vector &getFrames() { return frames; }
  /// @return the current frame
  Frame &getFrame() { return frames.back(); }
  /// Pops a frame.
  void popFrame() { return frames.pop_back(); }
};

} // namespace util
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/util/format.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include #include #include #include #include #include "codon/cir/util/format.h" #include "codon/cir/util/visitor.h" namespace codon { namespace ir { namespace util { struct NodeFormatter { const types::Type *type = nullptr; const Value *value = nullptr; const Var *var = nullptr; bool canShowFull = false; std::unordered_set &seenNodes; std::unordered_set &seenTypes; NodeFormatter(const types::Type *type, std::unordered_set &seenNodes, std::unordered_set &seenTypes) : type(type), seenNodes(seenNodes), seenTypes(seenTypes) {} NodeFormatter(const Value *value, std::unordered_set &seenNodes, std::unordered_set &seenTypes) : value(value), seenNodes(seenNodes), seenTypes(seenTypes) {} NodeFormatter(const Var *var, std::unordered_set &seenNodes, std::unordered_set &seenTypes) : var(var), seenNodes(seenNodes), seenTypes(seenTypes) {} friend std::ostream &operator<<(std::ostream &os, const NodeFormatter &n); }; namespace { std::string escapeString(const std::string &str) { std::stringstream escaped; for (char c : str) { switch (c) { case '\a': escaped << "\\a"; break; case '\b': escaped << "\\b"; break; case '\f': escaped << "\\f"; break; case '\n': escaped << "\\n"; break; case '\r': escaped << "\\r"; break; case '\t': escaped << "\\t"; break; case '\v': escaped << "\\v"; break; case '\\': escaped << "\\\\"; break; case '\'': escaped << "\\'"; break; case '\"': escaped << "\\\""; break; case '\?': escaped << "\\\?"; break; default: escaped << c; } } return escaped.str(); } class FormatVisitor : util::ConstVisitor { private: std::ostream &os; std::unordered_set &seenNodes; std::unordered_set &seenTypes; public: FormatVisitor(std::ostream &os, std::unordered_set &seenNodes, std::unordered_set &seenTypes) : os(os), seenNodes(seenNodes), seenTypes(seenTypes) {} virtual ~FormatVisitor() noexcept = default; void visit(const Module *v) override { auto types = makeFormatters(v->types_begin(), v->types_end(), true); auto vars = makeFormatters(v->begin(), v->end(), true); 
fmt::print(os, FMT_STRING("(module\n(argv {})\n(types {})\n(vars {})\n{})"), makeFormatter(v->getArgVar(), true), fmt::join(types.begin(), types.end(), "\n"), fmt::join(vars.begin(), vars.end(), "\n"), makeFormatter(v->getMainFunc(), true)); } void defaultVisit(const Node *) override { os << "(unknown_node)"; } void visit(const Var *v) override { fmt::print( os, FMT_STRING("(var '\"{}\" {} (global {}) (external {}) (thread-local {}))"), v->referenceString(), makeFormatter(v->getType()), v->isGlobal(), v->isExternal(), v->isThreadLocal()); } void visit(const BodiedFunc *v) override { auto args = makeFormatters(v->arg_begin(), v->arg_end(), true); auto symbols = makeFormatters(v->begin(), v->end(), true); fmt::print(os, FMT_STRING("(bodied_func '\"{}\" {}\n(args {})\n(vars {})\n{})"), v->referenceString(), makeFormatter(v->getType()), fmt::join(args.begin(), args.end(), " "), fmt::join(symbols.begin(), symbols.end(), " "), makeFormatter(v->getBody())); } void visit(const ExternalFunc *v) override { fmt::print(os, FMT_STRING("(external_func '\"{}\" {})"), v->referenceString(), makeFormatter(v->getType())); } void visit(const InternalFunc *v) override { fmt::print(os, FMT_STRING("(internal_func '\"{}\" {})"), v->referenceString(), makeFormatter(v->getType())); } void visit(const LLVMFunc *v) override { std::vector literals; for (auto it = v->literal_begin(); it != v->literal_end(); ++it) { const auto &l = *it; if (l.isStatic()) { literals.push_back(fmt::format(FMT_STRING("(static {})"), l.getStaticValue())); } else if (l.isStaticStr()) { literals.push_back( fmt::format(FMT_STRING("(static \"{}\")"), l.getStaticStringValue())); } else { literals.push_back( fmt::format(FMT_STRING("(type {})"), makeFormatter(l.getTypeValue()))); } } fmt::print(os, FMT_STRING("(llvm_func '\"{}\" {}\n(decls \"{}\")\n" "\"{}\"\n(literals {}))"), v->referenceString(), makeFormatter(v->getType()), escapeString(v->getLLVMDeclarations()), escapeString(v->getLLVMBody()), 
fmt::join(literals.begin(), literals.end(), "\n")); } void visit(const VarValue *v) override { fmt::print(os, FMT_STRING("'\"{}\""), v->getVar()->referenceString()); } void visit(const PointerValue *v) override { fmt::print(os, FMT_STRING("(ptr '\"{}\" \"{}\")"), v->getVar()->referenceString(), fmt::join(v->getFields().begin(), v->getFields().end(), ".")); } void visit(const SeriesFlow *v) override { auto series = makeFormatters(v->begin(), v->end()); fmt::print(os, FMT_STRING("(series\n{}\n)"), fmt::join(series.begin(), series.end(), "\n")); } void visit(const IfFlow *v) override { fmt::print(os, FMT_STRING("(if {}\n{}\n{}\n)"), makeFormatter(v->getCond()), makeFormatter(v->getTrueBranch()), makeFormatter(v->getFalseBranch())); } void visit(const WhileFlow *v) override { fmt::print(os, FMT_STRING("(while {}\n{}\n)"), makeFormatter(v->getCond()), makeFormatter(v->getBody())); } void visit(const ForFlow *v) override { fmt::print(os, FMT_STRING("({}{}for {}\n{}\n{}\n)"), v->isParallel() ? "par_" : "", v->isAsync() ? "async_" : "", makeFormatter(v->getIter()), makeFormatter(v->getVar()), makeFormatter(v->getBody())); } void visit(const ImperativeForFlow *v) override { fmt::print(os, FMT_STRING("({}imp_for {}\n{}\n{}\n{}\n{}\n)"), v->isParallel() ? 
"par_" : "", makeFormatter(v->getStart()), v->getStep(), makeFormatter(v->getEnd()), makeFormatter(v->getVar()), makeFormatter(v->getBody())); } void visit(const TryCatchFlow *v) override { std::vector catches; for (auto &c : *v) { catches.push_back( fmt::format(FMT_STRING("(catch {} {}\n{}\n)"), makeFormatter(c.getType()), makeFormatter(c.getVar()), makeFormatter(c.getHandler()))); } fmt::print(os, FMT_STRING("(try {}\n{}\n(else\n{}\n)\n(finally\n{})\n)"), makeFormatter(v->getBody()), fmt::join(catches.begin(), catches.end(), "\n"), makeFormatter(v->getElse()), makeFormatter(v->getFinally())); } void visit(const PipelineFlow *v) override { std::vector stages; for (const auto &s : *v) { auto args = makeFormatters(s.begin(), s.end()); stages.push_back(fmt::format( FMT_STRING("(stage {} {}\n(generator {})\n(parallel {}))"), makeFormatter(s.getCallee()), fmt::join(args.begin(), args.end(), "\n"), s.isGenerator(), s.isParallel())); } fmt::print(os, FMT_STRING("(pipeline {})"), fmt::join(stages.begin(), stages.end(), "\n")); } void visit(const dsl::CustomFlow *v) override { v->doFormat(os); } void visit(const IntConst *v) override { fmt::print(os, FMT_STRING("{}"), v->getVal()); } void visit(const FloatConst *v) override { fmt::print(os, FMT_STRING("{}"), v->getVal()); } void visit(const BoolConst *v) override { fmt::print(os, FMT_STRING("{}"), v->getVal()); } void visit(const StringConst *v) override { fmt::print(os, FMT_STRING("\"{}\""), escapeString(v->getVal())); } void visit(const dsl::CustomConst *v) override { v->doFormat(os); } void visit(const AssignInstr *v) override { fmt::print(os, FMT_STRING("(assign {} {})"), makeFormatter(v->getLhs()), makeFormatter(v->getRhs())); } void visit(const ExtractInstr *v) override { fmt::print(os, FMT_STRING("(extract {} \"{}\")"), makeFormatter(v->getVal()), v->getField()); } void visit(const InsertInstr *v) override { fmt::print(os, FMT_STRING("(insert {} \"{}\" {})"), makeFormatter(v->getLhs()), v->getField(), 
makeFormatter(v->getRhs())); } void visit(const CallInstr *v) override { auto args = makeFormatters(v->begin(), v->end()); fmt::print(os, FMT_STRING("(call {}\n{}\n)"), makeFormatter(v->getCallee()), fmt::join(args.begin(), args.end(), "\n")); } void visit(const StackAllocInstr *v) override { fmt::print(os, FMT_STRING("(stack_alloc {} {})"), makeFormatter(v->getArrayType()), v->getCount()); } void visit(const TypePropertyInstr *v) override { std::string property; if (v->getProperty() == TypePropertyInstr::Property::IS_ATOMIC) { property = "atomic"; } else if (v->getProperty() == TypePropertyInstr::Property::SIZEOF) { property = "sizeof"; } else { property = "unknown"; } fmt::print(os, FMT_STRING("(property {} {})"), property, makeFormatter(v->getInspectType())); } void visit(const YieldInInstr *v) override { fmt::print(os, FMT_STRING("(yield_in {})"), makeFormatter(v->getType())); } void visit(const TernaryInstr *v) override { fmt::print(os, FMT_STRING("(select {}\n{}\n{}\n)"), makeFormatter(v->getCond()), makeFormatter(v->getTrueValue()), makeFormatter(v->getFalseValue())); } void visit(const BreakInstr *v) override { os << "(break " << (v->getLoop() ? v->getLoop()->getId() : -1) << ')'; } void visit(const ContinueInstr *v) override { os << "(continue " << (v->getLoop() ? 
v->getLoop()->getId() : -1) << ')'; } void visit(const ReturnInstr *v) override { fmt::print(os, FMT_STRING("(return {})"), makeFormatter(v->getValue())); } void visit(const YieldInstr *v) override { fmt::print(os, FMT_STRING("(yield {})"), makeFormatter(v->getValue())); } void visit(const AwaitInstr *v) override { fmt::print(os, FMT_STRING("(await {} {} {})"), makeFormatter(v->getType()), makeFormatter(v->getValue()), v->isGenerator()); } void visit(const ThrowInstr *v) override { fmt::print(os, FMT_STRING("(throw {})"), makeFormatter(v->getValue())); } void visit(const FlowInstr *v) override { fmt::print(os, FMT_STRING("(flow {} {})"), makeFormatter(v->getFlow()), makeFormatter(v->getValue())); } void visit(const dsl::CustomInstr *v) override { v->doFormat(os); } void visit(const types::IntType *v) override { fmt::print(os, FMT_STRING("(int '\"{}\")"), v->referenceString()); } void visit(const types::FloatType *v) override { fmt::print(os, FMT_STRING("(float '\"{}\")"), v->referenceString()); } void visit(const types::Float32Type *v) override { fmt::print(os, FMT_STRING("(float32 '\"{}\")"), v->referenceString()); } void visit(const types::Float16Type *v) override { fmt::print(os, FMT_STRING("(float16 '\"{}\")"), v->referenceString()); } void visit(const types::BFloat16Type *v) override { fmt::print(os, FMT_STRING("(bfloat16 '\"{}\")"), v->referenceString()); } void visit(const types::Float128Type *v) override { fmt::print(os, FMT_STRING("(float128 '\"{}\")"), v->referenceString()); } void visit(const types::BoolType *v) override { fmt::print(os, FMT_STRING("(bool '\"{}\")"), v->referenceString()); } void visit(const types::ByteType *v) override { fmt::print(os, FMT_STRING("(byte '\"{}\")"), v->referenceString()); } void visit(const types::VoidType *v) override { fmt::print(os, FMT_STRING("(void '\"{}\")"), v->referenceString()); } void visit(const types::RecordType *v) override { std::vector fields; std::vector formatters; for (const auto &m : *v) { 
fields.push_back(fmt::format(FMT_STRING("(\"{}\" {})"), m.getName(), makeFormatter(m.getType()))); } fmt::print(os, FMT_STRING("(record '\"{}\" {})"), v->referenceString(), fmt::join(fields.begin(), fields.end(), " ")); } void visit(const types::RefType *v) override { fmt::print(os, FMT_STRING("(ref '\"{}\" {})"), v->referenceString(), makeFormatter(v->getContents())); } void visit(const types::FuncType *v) override { auto args = makeFormatters(v->begin(), v->end()); fmt::print(os, FMT_STRING("(func '\"{}\" {}{} {})"), v->referenceString(), fmt::join(args.begin(), args.end(), " "), (v->isVariadic() ? " ..." : ""), makeFormatter(v->getReturnType())); } void visit(const types::OptionalType *v) override { fmt::print(os, FMT_STRING("(optional '\"{}\" {})"), v->referenceString(), makeFormatter(v->getBase())); } void visit(const types::PointerType *v) override { fmt::print(os, FMT_STRING("(pointer '\"{}\" {})"), v->referenceString(), makeFormatter(v->getBase())); } void visit(const types::GeneratorType *v) override { fmt::print(os, FMT_STRING("(generator '\"{}\" {})"), v->referenceString(), makeFormatter(v->getBase())); } void visit(const types::IntNType *v) override { fmt::print(os, FMT_STRING("(intn '\"{}\" {} (signed {}))"), v->referenceString(), v->getLen(), v->isSigned()); } void visit(const types::VectorType *v) override { fmt::print(os, FMT_STRING("(vector '\"{}\" {} (count {}))"), v->referenceString(), makeFormatter(v->getBase()), v->getCount()); } void visit(const types::UnionType *v) override { auto types = makeFormatters(v->begin(), v->end()); fmt::print(os, FMT_STRING("(union '\"{}\" {})"), v->referenceString(), fmt::join(types.begin(), types.end(), " ")); } void visit(const dsl::types::CustomType *v) override { v->doFormat(os); } void format(const Node *n) { if (n) n->accept(*this); else os << "(null)"; } void format(const types::Type *t, bool canShowFull = false) { if (t) { if (seenTypes.find(t->getName()) != seenTypes.end() || !canShowFull) fmt::print(os, 
FMT_STRING("(type '\"{}\")"), t->referenceString()); else { seenTypes.insert(t->getName()); t->accept(*this); } } else os << "(null)"; } void format(const Value *t) { if (t) { if (seenNodes.find(t->getId()) != seenNodes.end()) fmt::print(os, FMT_STRING("(value '\"{}\")"), t->referenceString()); else { seenNodes.insert(t->getId()); t->accept(*this); } } else os << "(null)"; } void format(const Var *t, bool canShowFull = false) { if (t) { if (seenNodes.find(t->getId()) != seenNodes.end() || !canShowFull) fmt::print(os, FMT_STRING("(var '\"{}\")"), t->referenceString()); else { seenNodes.insert(t->getId()); t->accept(*this); } } else os << "(null)"; } private: NodeFormatter makeFormatter(const types::Type *node, bool canShowFull = false) { auto ret = NodeFormatter(node, seenNodes, seenTypes); ret.canShowFull = canShowFull; return ret; } NodeFormatter makeFormatter(const Value *node) { return NodeFormatter(node, seenNodes, seenTypes); } NodeFormatter makeFormatter(const Var *node, bool canShowFull = false) { auto ret = NodeFormatter(node, seenNodes, seenTypes); ret.canShowFull = canShowFull; return ret; } template std::vector makeFormatters(It begin, It end) { std::vector ret; while (begin != end) { ret.push_back(makeFormatter(*begin)); ++begin; } return ret; } template std::vector makeFormatters(It begin, It end, bool canShowFull) { std::vector ret; while (begin != end) { ret.push_back(makeFormatter(*begin, canShowFull)); ++begin; } return ret; } }; } // namespace std::ostream &operator<<(std::ostream &os, const NodeFormatter &n) { FormatVisitor fv(os, n.seenNodes, n.seenTypes); if (n.type) fv.format(n.type, n.canShowFull); else if (n.value) fv.format(n.value); else fv.format(n.var, n.canShowFull); return os; } std::string format(const Node *node) { std::stringstream ss; format(ss, node); return ss.str(); } std::ostream &format(std::ostream &os, const Node *node) { std::unordered_set seenNodes; std::unordered_set seenTypes; FormatVisitor fv(os, seenNodes, seenTypes); 
fv.format(node); return os; } } // namespace util } // namespace ir } // namespace codon template <> struct fmt::formatter : ostream_formatter {}; ================================================ FILE: codon/cir/util/format.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include "codon/cir/cir.h" namespace codon { namespace ir { namespace util { /// Formats an IR node. /// @param node the node /// @return the formatted node std::string format(const Node *node); /// Formats an IR node to an IO stream. /// @param os the output stream /// @param node the node /// @return the resulting output stream std::ostream &format(std::ostream &os, const Node *node); } // namespace util } // namespace ir } // namespace codon ================================================ FILE: codon/cir/util/inlining.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "inlining.h" #include #include "codon/cir/util/cloning.h" #include "codon/cir/util/irtools.h" #include "codon/cir/util/operator.h" namespace codon { namespace ir { namespace util { namespace { class ReturnVerifier : public util::Operator { public: bool needLoop = false; void handle(ReturnInstr *v) { if (needLoop) { return; } auto it = parent_begin(); if (it == parent_end()) { needLoop = true; return; } SeriesFlow *prev = nullptr; while (it != parent_end()) { Value *v = cast(*it++); auto *cur = cast(v); if (!cur || (prev && prev->back()->getId() != cur->getId())) { needLoop = true; return; } prev = cur; } needLoop = prev->back()->getId() != v->getId(); } }; class ReturnReplacer : public util::Operator { private: Value *implicitLoop; Var *var; bool aggressive; util::CloneVisitor &cv; public: ReturnReplacer(Value *implicitLoop, Var *var, bool aggressive, util::CloneVisitor &cv) : implicitLoop(implicitLoop), var(var), aggressive(aggressive), cv(cv) {} void handle(ReturnInstr *v) { auto *M = v->getModule(); auto *rep = 
M->N(v); if (var) { rep->push_back(M->N(v, var, cv.clone(v->getValue()))); } if (aggressive) rep->push_back(M->N(v, implicitLoop)); v->replaceAll(rep); } }; } // namespace InlineResult inlineFunction(Func *func, std::vector args, bool aggressive, codon::SrcInfo info) { auto *bodied = cast(func); if (!bodied) return {nullptr, {}}; auto *fType = cast(bodied->getType()); if (!fType || args.size() != std::distance(bodied->arg_begin(), bodied->arg_end())) return {nullptr, {}}; auto *M = bodied->getModule(); util::CloneVisitor cv(M); auto *newFlow = M->N(info, bodied->getName() + "_inlined"); std::vector newVars; auto arg_it = bodied->arg_begin(); for (auto i = 0; i < args.size(); ++i) { newVars.push_back(cv.forceClone(*arg_it++)); newFlow->push_back(M->N(info, newVars.back(), cv.clone(args[i]))); } for (auto *v : *bodied) { newVars.push_back(cv.forceClone(v)); } Var *retVal = nullptr; if (!fType->getReturnType()->is(M->getVoidType()) && !fType->getReturnType()->is(M->getNoneType())) { retVal = M->N(info, fType->getReturnType()); newVars.push_back(retVal); } Flow *clonedBody = cv.clone(bodied->getBody()); ReturnVerifier rv; rv.process(clonedBody); if (!aggressive && rv.needLoop) return {nullptr, {}}; WhileFlow *implicit = nullptr; if (rv.needLoop) { auto *loopBody = M->N(info); implicit = M->N(info, M->getBool(true), loopBody); loopBody->push_back(clonedBody); if (!retVal) loopBody->push_back(M->N(info, implicit)); } ReturnReplacer rr(implicit, retVal, rv.needLoop, cv); rr.process(clonedBody); newFlow->push_back(implicit ? 
implicit : clonedBody); if (retVal) { return {M->N(info, newFlow, M->N(info, retVal)), std::move(newVars)}; } return {newFlow, std::move(newVars)}; } InlineResult inlineCall(CallInstr *v, bool aggressive) { return inlineFunction(util::getFunc(v->getCallee()), std::vector(v->begin(), v->end()), aggressive, v->getSrcInfo()); } } // namespace util } // namespace ir } // namespace codon ================================================ FILE: codon/cir/util/inlining.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/cir.h" namespace codon { namespace ir { namespace util { /// Result of an inlining operation. struct InlineResult { /// the result, either a SeriesFlow or FlowInstr Value *result; /// variables added by the inlining std::vector newVars; operator bool() const { return bool(result); } }; /// Inline the given function with the supplied arguments. /// @param func the function /// @param args the arguments /// @param callInfo the call information /// @param aggressive true if should inline complex functions /// @return the inlined result, nullptr if unsuccessful InlineResult inlineFunction(Func *func, std::vector args, bool aggressive = false, codon::SrcInfo callInfo = {}); /// Inline the given call. /// @param v the instruction /// @param aggressive true if should inline complex functions /// @return the inlined result, nullptr if unsuccessful InlineResult inlineCall(CallInstr *v, bool aggressive = false); } // namespace util } // namespace ir } // namespace codon ================================================ FILE: codon/cir/util/irtools.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "irtools.h"

#include <cstddef>
#include <iterator>
#include <string>

namespace codon {
namespace ir {
namespace util {

bool hasAttribute(const Func *func, const std::string &attribute) {
  if (auto *attr = func->getAttribute<KeyValueAttribute>()) {
    return attr->has(attribute);
  }
  return false;
}

bool isStdlibFunc(const Func *func, const std::string &submodule) {
  if (auto *attr = func->getAttribute<KeyValueAttribute>()) {
    std::string module = attr->get(".module");
    // rfind(..., 0) == 0 is a "starts with" test on the module name.
    return module.rfind("std::" + submodule, 0) == 0;
  }
  return false;
}

CallInstr *call(Func *func, const std::vector<Value *> &args) {
  auto *M = func->getModule();
  return M->Nr<CallInstr>(M->Nr<VarValue>(func), args);
}

bool isCallOf(const Value *value, const std::string &name,
              const std::vector<types::Type *> &inputs, types::Type *output,
              bool method) {
  if (auto *call = cast<CallInstr>(value)) {
    auto *fn = getFunc(call->getCallee());
    if (!fn || fn->getUnmangledName() != name || call->numArgs() != inputs.size())
      return false;

    // A null entry in inputs means "any type" for that argument.
    unsigned i = 0;
    for (auto *arg : *call) {
      if (inputs[i] && !arg->getType()->is(inputs[i]))
        return false;
      ++i;
    }

    if (output && !value->getType()->is(output))
      return false;

    if (method) {
      if (inputs.empty() || !fn->getParentType())
        return false;

      if (inputs[0] && !fn->getParentType()->is(inputs[0]))
        return false;
    }

    return true;
  }

  return false;
}

bool isCallOf(const Value *value, const std::string &name, int numArgs,
              types::Type *output, bool method) {
  if (auto *call = cast<CallInstr>(value)) {
    auto *fn = getFunc(call->getCallee());
    if (!fn || fn->getUnmangledName() != name ||
        (numArgs >= 0 && call->numArgs() != numArgs))
      return false;

    if (output && !value->getType()->is(output))
      return false;

    // A method call must pass an instance of the parent type as first argument.
    if (method && (!fn->getParentType() || call->numArgs() == 0 ||
                   !call->front()->getType()->is(fn->getParentType())))
      return false;

    return true;
  }

  return false;
}

bool isMagicMethodCall(const Value *value) {
  if (auto *call = cast<CallInstr>(value)) {
    auto *fn = getFunc(call->getCallee());
    if (!fn || !fn->getParentType() || call->numArgs() == 0 ||
        !call->front()->getType()->is(fn->getParentType()))
      return false;

    // Require at least one character between the leading and trailing "__".
    auto name = fn->getUnmangledName();
    auto size = name.size();
    if (size < 5 || !(name[0] == '_' && name[1] == '_' && name[size - 1] == '_' &&
                      name[size - 2] == '_'))
      return false;

    return true;
  }

  return false;
}

Value *makeTuple(const std::vector<Value *> &args, Module *M) {
  if (!M) {
    seqassertn(!args.empty(), "unknown module for empty tuple construction");
    M = args[0]->getModule();
  }

  std::vector<types::Type *> types;
  for (auto *arg : args) {
    types.push_back(arg->getType());
  }
  auto *tupleType = M->getTupleType(types);
  auto *newFunc = M->getOrRealizeMethod(tupleType, "__new__", types);
  seqassertn(newFunc, "could not realize {} new function", *tupleType);
  return M->Nr<CallInstr>(M->Nr<VarValue>(newFunc), args);
}

Var *makeVar(Value *x, SeriesFlow *flow, BodiedFunc *parent, bool prepend) {
  const bool global = (parent == nullptr);
  auto *M = x->getModule();
  auto *v = M->Nr<Var>(x->getType(), global);
  if (global) {
    // Globals get a generated name numbered by a process-wide counter.
    static int counter = 1;
    v->setName(".anon_global." + std::to_string(counter++));
  }
  auto *a = M->Nr<AssignInstr>(v, x);
  if (prepend) {
    flow->insert(flow->begin(), a);
  } else {
    flow->push_back(a);
  }
  if (!global) {
    parent->push_back(v);
  }
  return v;
}

Value *alloc(types::Type *type, Value *count) {
  auto *M = type->getModule();
  auto *ptrType = M->getPointerType(type);
  return (*ptrType)(*count);
}

Value *alloc(types::Type *type, int64_t count) {
  auto *M = type->getModule();
  return alloc(type, M->getInt(count));
}

Var *getVar(Value *x) {
  if (auto *v = cast<VarValue>(x)) {
    if (auto *var = cast<Var>(v->getVar())) {
      // Functions are not reported as plain variables.
      if (!isA<Func>(var)) {
        return var;
      }
    }
  }
  return nullptr;
}

const Var *getVar(const Value *x) {
  if (auto *v = cast<VarValue>(x)) {
    if (auto *var = cast<Var>(v->getVar())) {
      // Functions are not reported as plain variables.
      if (!isA<Func>(var)) {
        return var;
      }
    }
  }
  return nullptr;
}

Func *getFunc(Value *x) {
  if (auto *v = cast<VarValue>(x)) {
    if (auto *func = cast<Func>(v->getVar())) {
      return func;
    }
  }
  return nullptr;
}

const Func *getFunc(const Value *x) {
  if (auto *v = cast<VarValue>(x)) {
    if (auto *func = cast<Func>(v->getVar())) {
      return func;
    }
  }
  return nullptr;
}

Value *ptrLoad(Value *ptr) {
  auto *M = ptr->getModule();
  auto *deref = (*ptr)[*M->getInt(0)];
  seqassertn(deref, "pointer getitem not found [{}]", ptr->getSrcInfo());
  return deref;
}

Value *ptrStore(Value *ptr, Value *val) {
  auto *M = ptr->getModule();
  auto *setitem =
      M->getOrRealizeMethod(ptr->getType(), Module::SETITEM_MAGIC_NAME,
                            {ptr->getType(), M->getIntType(), val->getType()});
  seqassertn(setitem, "pointer setitem not found [{}]", ptr->getSrcInfo());
  return call(setitem, {ptr, M->getInt(0), val});
}

Value *tupleGet(Value *tuple, unsigned index) {
  auto *M = tuple->getModule();
  // Tuple fields are named "item1", "item2", ... (1-based).
  return M->Nr<ExtractInstr>(tuple, "item" + std::to_string(index + 1));
}

Value *tupleStore(Value *tuple, unsigned index, Value *val) {
  auto *M = tuple->getModule();
  auto *type = cast<types::RecordType>(tuple->getType());
  seqassertn(type, "argument is not a tuple [{}]", tuple->getSrcInfo());
  // Rebuild the tuple element by element, substituting val at the given index.
  const auto numFields =
      static_cast<std::size_t>(std::distance(type->begin(), type->end()));
  std::vector<Value *> newElements;
  newElements.reserve(numFields);
  for (unsigned i = 0; i < numFields; i++) {
    newElements.push_back(i == index ? val : tupleGet(tuple, i));
  }
  return makeTuple(newElements, M);
}

BodiedFunc *getStdlibFunc(Value *x, const std::string &name,
                          const std::string &submodule) {
  if (auto *f = getFunc(x)) {
    if (auto *g = cast<BodiedFunc>(f)) {
      if (isStdlibFunc(g, submodule) && g->getUnmangledName() == name) {
        return g;
      }
    }
  }
  return nullptr;
}

const BodiedFunc *getStdlibFunc(const Value *x, const std::string &name,
                                const std::string &submodule) {
  if (auto *f = getFunc(x)) {
    if (auto *g = cast<BodiedFunc>(f)) {
      if (isStdlibFunc(g, submodule) && g->getUnmangledName() == name) {
        return g;
      }
    }
  }
  return nullptr;
}

types::Type *getReturnType(const Func *func) {
  auto *t = cast<types::FuncType>(func->getType());
  // Same check as setReturnType: fail loudly instead of dereferencing null.
  seqassertn(t, "{} is not a function type [{}]", *func->getType(), func->getSrcInfo());
  return t->getReturnType();
}

void setReturnType(Func *func, types::Type *rType) {
  auto *M = func->getModule();
  auto *t = cast<types::FuncType>(func->getType());
  seqassertn(t, "{} is not a function type [{}]", *func->getType(), func->getSrcInfo());
  std::vector<types::Type *> argTypes(t->begin(), t->end());
  func->setType(M->getFuncType(rType, argTypes));
}

} // namespace util
} // namespace ir
} // namespace codon


================================================
FILE: codon/cir/util/irtools.h
================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/cir.h" namespace codon { namespace ir { namespace util { /// Checks whether a function has a given attribute. /// @param func the function /// @param attribute the attribute name /// @return true if the function has the given attribute bool hasAttribute(const Func *func, const std::string &attribute); /// Checks whether a function comes from the standard library, and /// optionally a specific module therein. /// @param func the function /// @param submodule module name (e.g. "std::bio"), or empty if /// no module check is required /// @return true if the function is from the standard library in /// the given module bool isStdlibFunc(const Func *func, const std::string &submodule = ""); /// Calls a function. /// @param func the function /// @param args vector of call arguments /// @return call instruction with the given function and arguments CallInstr *call(Func *func, const std::vector &args); /// Checks if a value represents a call of a particular function. /// @param value the value to check /// @param name the function's (unmangled) name /// @param inputs vector of input types /// @param output output type, null for no check /// @param method true to ensure this call is a method call /// @return true if value is a call matching all parameters above bool isCallOf(const Value *value, const std::string &name, const std::vector &inputs, types::Type *output = nullptr, bool method = false); /// Checks if a value represents a call of a particular function. 
/// @param value the value to check /// @param name the function's (unmangled) name /// @param numArgs argument count, negative for no check /// @param output output type, null for no check /// @param method true to ensure this call is a method call /// @return true if value is a call matching all parameters above bool isCallOf(const Value *value, const std::string &name, int numArgs = -1, types::Type *output = nullptr, bool method = false); /// Checks if a value represents a call to a magic method. /// Magic method names start and end in "__" (two underscores). /// @param value the value to check /// @return true if value is a magic method call bool isMagicMethodCall(const Value *value); /// Constructs a new tuple. /// @param args vector of tuple contents /// @param M the module; inferred from elements if null /// @return value represents a tuple with the given contents Value *makeTuple(const std::vector &args, Module *M = nullptr); /// Constructs and assigns a new variable. /// @param x the value to assign to the new variable /// @param flow series flow in which to assign the new variable /// @param parent function to add the new variable to, or null for global variable /// @param prepend true to insert assignment at start of block /// @return value containing the new variable Var *makeVar(Value *x, SeriesFlow *flow, BodiedFunc *parent, bool prepend = false); /// Dynamically allocates memory for the given type with the given /// number of elements. /// @param type the type /// @param count integer value representing the number of elements /// @return value representing a pointer to the allocated memory Value *alloc(types::Type *type, Value *count); /// Dynamically allocates memory for the given type with the given /// number of elements. 
/// @param type the type /// @param count the number of elements /// @return value representing a pointer to the allocated memory Value *alloc(types::Type *type, int64_t count); /// Builds a new series flow with the given contents. Returns /// null if no contents are provided. /// @param args contents of the series flow /// @return new series flow template SeriesFlow *series(Args... args) { std::vector vals = {args...}; if (vals.empty()) return nullptr; auto *series = vals[0]->getModule()->Nr(); for (auto *val : vals) { series->push_back(val); } return series; } /// Checks whether the given value is a constant of the given /// type. Note that standard "int" corresponds to the C type /// "int64_t", which should be used here. /// @param x the value to check /// @return true if the value is constant template bool isConst(const Value *x) { return isA>(x); } /// Checks whether the given value is a constant of the given /// type, and that is has a particular value. Note that standard /// "int" corresponds to the C type "int64_t", which should be used here. /// @param x the value to check /// @param value constant value to compare to /// @return true if the value is constant with the given value template bool isConst(const Value *x, const T &value) { if (auto *c = cast>(x)) { return c->getVal() == value; } return false; } /// Returns the constant represented by a given value. Raises an assertion /// error if the given value is not constant. Note that standard /// "int" corresponds to the C type "int64_t", which should be used here. /// @param x the (constant) value /// @return the constant represented by the given value template T getConst(const Value *x) { auto *c = cast>(x); seqassertn(c, "{} is not a constant [{}]", *x, x->getSrcInfo()); return c->getVal(); } /// Gets a variable from a value. /// @param x the value /// @return the variable represented by the given value, or null if none Var *getVar(Value *x); /// Gets a variable from a value. 
/// @param x the value /// @return the variable represented by the given value, or null if none const Var *getVar(const Value *x); /// Gets a function from a value. /// @param x the value /// @return the function represented by the given value, or null if none Func *getFunc(Value *x); /// Gets a function from a value. /// @param x the value /// @return the function represented by the given value, or null if none const Func *getFunc(const Value *x); /// Loads value from a pointer. /// @param ptr the pointer /// @return the value pointed to by the argument Value *ptrLoad(Value *ptr); /// Stores a value into a pointer. /// @param ptr the pointer /// @param val the value to store /// @return "__setitem__" call representing the store Value *ptrStore(Value *ptr, Value *val); /// Gets value from a tuple at the given index. /// @param tuple the tuple /// @param index the 0-based index /// @return tuple element at the given index Value *tupleGet(Value *tuple, unsigned index); /// Stores value in a tuple at the given index. Since tuples are immutable, /// a new instance is returned with the appropriate element replaced. /// @param tuple the tuple /// @param index the 0-based index /// @param val the value to store /// @return new tuple instance with the given value inserted Value *tupleStore(Value *tuple, unsigned index, Value *val); /// Gets a bodied standard library function from a value. /// @param x the value /// @param name name of the function /// @param submodule optional module to check /// @return the standard library function (with the given name, from the given /// submodule) represented by the given value, or null if none BodiedFunc *getStdlibFunc(Value *x, const std::string &name, const std::string &submodule = ""); /// Gets a bodied standard library function from a value. 
/// @param x the value /// @param name name of the function /// @param submodule optional module to check /// @return the standard library function (with the given name, from the given /// submodule) represented by the given value, or null if none const BodiedFunc *getStdlibFunc(const Value *x, const std::string &name, const std::string &submodule = ""); /// Gets the return type of a function. /// @param func the function /// @return the return type of the given function types::Type *getReturnType(const Func *func); /// Sets the return type of a function. Argument types remain unchanged. /// @param func the function /// @param rType the new return type void setReturnType(Func *func, types::Type *rType); } // namespace util } // namespace ir } // namespace codon ================================================ FILE: codon/cir/util/iterators.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include namespace codon { namespace ir { namespace util { /// Iterator wrapper that applies a function to the iterator. template struct function_iterator_adaptor { It internal; DereferenceFunc d; MemberFunc m; using iterator_category = std::input_iterator_tag; using value_type = typename std::remove_reference::type; using reference = void; using pointer = void; using difference_type = typename std::iterator_traits::difference_type; /// Constructs an adaptor. 
/// @param internal the internal iterator /// @param d the dereference function /// @param m the member access function function_iterator_adaptor(It internal, DereferenceFunc &&d, MemberFunc &&m) : internal(std::move(internal)), d(std::move(d)), m(std::move(m)) {} decltype(auto) operator*() { return d(*internal); } decltype(auto) operator->() { return m(*internal); } function_iterator_adaptor &operator++() { internal++; return *this; } function_iterator_adaptor operator++(int) { function_iterator_adaptor copy(*this); internal++; return copy; } template bool operator==(const function_iterator_adaptor &other) const { return other.internal == internal; } template bool operator!=(const function_iterator_adaptor &other) const { return other.internal != internal; } }; /// Creates an adaptor that dereferences values. /// @param it the internal iterator /// @return the adaptor template auto dereference_adaptor(It it) { auto f = [](const auto &v) -> auto & { return *v; }; auto m = [](const auto &v) -> auto { return v.get(); }; return function_iterator_adaptor(it, std::move(f), std::move(m)); } /// Creates an adaptor that gets the address of its values. /// @param it the internal iterator /// @return the adaptor template auto raw_ptr_adaptor(It it) { auto f = [](auto &v) -> auto * { return v.get(); }; auto m = [](auto &v) -> auto * { return v.get(); }; return function_iterator_adaptor(it, std::move(f), std::move(m)); } /// Creates an adaptor that gets the const address of its values. /// @param it the internal iterator /// @return the adaptor template auto const_raw_ptr_adaptor(It it) { auto f = [](auto &v) -> const auto * { return v.get(); }; auto m = [](auto &v) -> const auto * { return v.get(); }; return function_iterator_adaptor(it, std::move(f), std::move(m)); } /// Creates an adaptor that gets the keys of its values. 
/// @param it the internal iterator /// @return the adaptor template auto map_key_adaptor(It it) { auto f = [](auto &v) -> auto & { return v.first; }; auto m = [](auto &v) -> auto & { return v.first; }; return function_iterator_adaptor(it, std::move(f), std::move(m)); } /// Creates an adaptor that gets the const keys of its values. /// @param it the internal iterator /// @return the adaptor template auto const_map_key_adaptor(It it) { auto f = [](auto &v) -> const auto & { return v.first; }; auto m = [](auto &v) -> const auto & { return v.first; }; return function_iterator_adaptor(it, std::move(f), std::move(m)); } } // namespace util } // namespace ir } // namespace codon ================================================ FILE: codon/cir/util/matching.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "matching.h" #include #include "codon/cir/cir.h" #include "codon/cir/util/visitor.h" #define VISIT(x) \ void visit(const x *v) override { \ if (matchAny || dynamic_cast(v)) { \ result = true; \ matchAny = true; \ } else if (!nodeId) { \ nodeId = &x::NodeId; \ other = v; \ } else if (nodeId != &x::NodeId || \ (!checkName && v->getName() != other->getName())) \ result = false; \ else \ handle(v, static_cast(other)); \ } namespace codon { namespace ir { namespace util { namespace { class MatchVisitor : public util::ConstVisitor { private: bool matchAny = false; bool checkName; const char *nodeId = nullptr; bool result = false; const Node *other = nullptr; bool varIdMatch; public: explicit MatchVisitor(bool checkName = false, bool varIdMatch = false) : checkName(checkName), varIdMatch(varIdMatch) {} VISIT(Var); void handle(const Var *x, const Var *y) { result = compareVars(x, y); } VISIT(Func); void handle(const Func *x, const Func *y) {} VISIT(BodiedFunc); void handle(const BodiedFunc *x, const BodiedFunc *y) { result = compareFuncs(x, y) && std::equal(x->begin(), x->end(), y->begin(), y->end(), [this](auto *x, auto *y) { 
return process(x, y); }) && process(x->getBody(), y->getBody()) && x->isJIT() == y->isJIT(); } VISIT(ExternalFunc); void handle(const ExternalFunc *x, const ExternalFunc *y) { result = x->getUnmangledName() == y->getUnmangledName() && compareFuncs(x, y); } VISIT(InternalFunc); void handle(const InternalFunc *x, const InternalFunc *y) { result = x->getParentType() == y->getParentType() && compareFuncs(x, y); } VISIT(LLVMFunc); void handle(const LLVMFunc *x, const LLVMFunc *y) { result = std::equal(x->literal_begin(), x->literal_end(), y->literal_begin(), y->literal_end(), [this](auto &x, auto &y) { if (x.isStatic() && y.isStatic()) return x.getStaticValue() == y.getStaticValue(); else if (x.isStaticStr() && y.isStaticStr()) return x.getStaticStringValue() == y.getStaticStringValue(); else if (x.isType() && y.isType()) return process(x.getTypeValue(), y.getTypeValue()); return false; }) && x->getLLVMDeclarations() == y->getLLVMDeclarations() && x->getLLVMBody() == y->getLLVMBody() && compareFuncs(x, y); } VISIT(Value); void handle(const Value *x, const Value *y) {} VISIT(VarValue); void handle(const VarValue *x, const VarValue *y) { result = compareVars(x->getVar(), y->getVar()); } VISIT(PointerValue); void handle(const PointerValue *x, const PointerValue *y) { result = compareVars(x->getVar(), y->getVar()) && x->getFields() == y->getFields(); } VISIT(Flow); void handle(const Flow *x, const Flow *y) {} VISIT(SeriesFlow); void handle(const SeriesFlow *x, const SeriesFlow *y) { result = std::equal(x->begin(), x->end(), y->begin(), y->end(), [this](auto *x, auto *y) { return process(x, y); }); } VISIT(IfFlow); void handle(const IfFlow *x, const IfFlow *y) { result = process(x->getCond(), y->getCond()) && process(x->getTrueBranch(), y->getTrueBranch()) && process(x->getFalseBranch(), y->getFalseBranch()); } VISIT(WhileFlow); void handle(const WhileFlow *x, const WhileFlow *y) { result = process(x->getCond(), y->getCond()) && process(x->getBody(), y->getBody()); } 
VISIT(ForFlow); void handle(const ForFlow *x, const ForFlow *y) { result = (x->isAsync() == y->isAsync()) && process(x->getIter(), y->getIter()) && process(x->getBody(), y->getBody()) && process(x->getVar(), y->getVar()); } VISIT(ImperativeForFlow); void handle(const ImperativeForFlow *x, const ImperativeForFlow *y) { result = process(x->getVar(), y->getVar()) && process(x->getBody(), y->getBody()) && process(x->getStart(), y->getStart()) && x->getStep() == y->getStep() && process(x->getEnd(), y->getEnd()); } VISIT(TryCatchFlow); void handle(const TryCatchFlow *x, const TryCatchFlow *y) { result = result && process(x->getElse(), y->getElse()) && process(x->getFinally(), y->getFinally()) && process(x->getBody(), y->getBody()) && std::equal(x->begin(), x->end(), y->begin(), y->end(), [this](auto &x, auto &y) { return process(x.getHandler(), y.getHandler()) && process(x.getType(), y.getType()) && process(x.getVar(), y.getVar()); }); } VISIT(PipelineFlow); void handle(const PipelineFlow *x, const PipelineFlow *y) { result = std::equal( x->begin(), x->end(), y->begin(), y->end(), [this](auto &x, auto &y) { return process(x.getCallee(), y.getCallee()) && std::equal(x.begin(), x.end(), y.begin(), y.end(), [this](auto *x, auto *y) { return process(x, y); }) && x.isGenerator() == y.isGenerator() && x.isParallel() == y.isParallel(); }); } VISIT(dsl::CustomFlow); void handle(const dsl::CustomFlow *x, const dsl::CustomFlow *y) { result = x->match(y); } VISIT(IntConst); void handle(const IntConst *x, const IntConst *y) { result = process(x->getType(), y->getType()) && x->getVal() == y->getVal(); } VISIT(FloatConst); void handle(const FloatConst *x, const FloatConst *y) { result = process(x->getType(), y->getType()) && x->getVal() == y->getVal(); } VISIT(BoolConst); void handle(const BoolConst *x, const BoolConst *y) { result = process(x->getType(), y->getType()) && x->getVal() == y->getVal(); } VISIT(StringConst); void handle(const StringConst *x, const StringConst *y) { result 
= process(x->getType(), y->getType()) && x->getVal() == y->getVal(); } VISIT(dsl::CustomConst); void handle(const dsl::CustomConst *x, const dsl::CustomConst *y) { result = x->match(y); } VISIT(AssignInstr); void handle(const AssignInstr *x, const AssignInstr *y) { result = process(x->getLhs(), y->getLhs()) && process(x->getRhs(), y->getRhs()); } VISIT(ExtractInstr); void handle(const ExtractInstr *x, const ExtractInstr *y) { result = process(x->getVal(), y->getVal()) && x->getField() == y->getField(); } VISIT(InsertInstr); void handle(const InsertInstr *x, const InsertInstr *y) { result = process(x->getLhs(), y->getLhs()) && x->getField() == y->getField() && process(x->getRhs(), y->getRhs()); } VISIT(CallInstr); void handle(const CallInstr *x, const CallInstr *y) { result = process(x->getCallee(), y->getCallee()) && std::equal(x->begin(), x->end(), y->begin(), y->end(), [this](auto *x, auto *y) { return process(x, y); }); } VISIT(StackAllocInstr); void handle(const StackAllocInstr *x, const StackAllocInstr *y) { result = x->getCount() == y->getCount() && process(x->getType(), y->getType()); } VISIT(TypePropertyInstr); void handle(const TypePropertyInstr *x, const TypePropertyInstr *y) { result = x->getProperty() == y->getProperty() && process(x->getInspectType(), y->getInspectType()); } VISIT(YieldInInstr); void handle(const YieldInInstr *x, const YieldInInstr *y) { result = process(x->getType(), y->getType()); } VISIT(TernaryInstr); void handle(const TernaryInstr *x, const TernaryInstr *y) { result = process(x->getCond(), y->getCond()) && process(x->getTrueValue(), y->getTrueValue()) && process(x->getFalseValue(), y->getFalseValue()); } VISIT(BreakInstr); void handle(const BreakInstr *x, const BreakInstr *y) { result = process(x->getLoop(), y->getLoop()); } VISIT(ContinueInstr); void handle(const ContinueInstr *x, const ContinueInstr *y) { result = process(x->getLoop(), y->getLoop()); } VISIT(ReturnInstr); void handle(const ReturnInstr *x, const ReturnInstr *y) { 
result = process(x->getValue(), y->getValue()); } VISIT(YieldInstr); void handle(const YieldInstr *x, const YieldInstr *y) { result = process(x->getValue(), y->getValue()); } VISIT(AwaitInstr); void handle(const AwaitInstr *x, const AwaitInstr *y) { result = process(x->getType(), y->getType()) && process(x->getValue(), y->getValue()) && (x->isGenerator() == y->isGenerator()); } VISIT(ThrowInstr); void handle(const ThrowInstr *x, const ThrowInstr *y) { result = process(x->getValue(), y->getValue()); } VISIT(FlowInstr); void handle(const FlowInstr *x, const FlowInstr *y) { result = process(x->getFlow(), y->getFlow()) && process(x->getValue(), y->getValue()); } VISIT(dsl::CustomInstr); void handle(const dsl::CustomInstr *x, const dsl::CustomInstr *y) { result = x->match(y); } bool process(const Node *x, const Node *y) const { if (!x && !y) return true; else if ((!x && y) || (x && !y)) return false; auto *tx = cast(x); auto *ty = cast(y); if (tx || ty) return tx && ty && tx->is(const_cast(ty)); MatchVisitor v(checkName); x->accept(v); y->accept(v); return v.result; } private: bool compareVars(const Var *x, const Var *y) const { return process(x->getType(), y->getType()) && (!varIdMatch || x->getId() == y->getId()); } bool compareFuncs(const Func *x, const Func *y) const { if (!compareVars(x, y)) return false; if (!std::equal(x->arg_begin(), x->arg_end(), y->arg_begin(), y->arg_end(), [this](auto *x, auto *y) { return process(x, y); })) return false; return true; } }; } // namespace const char AnyValue::NodeId = 0; const char AnyFlow::NodeId = 0; const char AnyVar::NodeId = 0; const char AnyFunc::NodeId = 0; bool match(Node *a, Node *b, bool checkNames, bool varIdMatch) { return MatchVisitor(checkNames).process(a, b); } } // namespace util } // namespace ir } // namespace codon #undef VISIT ================================================ FILE: codon/cir/util/matching.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once #include "codon/cir/cir.h" namespace codon { namespace ir { namespace util { /// Base class for IR nodes that match anything. class Any {}; /// Any value. class AnyValue : public AcceptorExtend, public Any { public: static const char NodeId; using AcceptorExtend::AcceptorExtend; private: types::Type *doGetType() const override { return getModule()->getVoidType(); } }; /// Any flow. class AnyFlow : public AcceptorExtend, public Any { public: static const char NodeId; using AcceptorExtend::AcceptorExtend; }; /// Any variable. class AnyVar : public AcceptorExtend, public Any { public: static const char NodeId; using AcceptorExtend::AcceptorExtend; }; /// Any function. class AnyFunc : public AcceptorExtend, public Any { public: static const char NodeId; using AcceptorExtend::AcceptorExtend; AnyFunc() : AcceptorExtend() { setUnmangledName("any"); } }; /// Checks if IR nodes match. /// @param a the first IR node /// @param b the second IR node /// @param checkNames whether or not to check the node names /// @param varIdMatch whether or not variable ids must match /// @return true if the nodes are equal bool match(Node *a, Node *b, bool checkNames = false, bool varIdMatch = false); } // namespace util } // namespace ir } // namespace codon ================================================ FILE: codon/cir/util/operator.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include "codon/cir/cir.h" #include "codon/cir/util/visitor.h" #define LAMBDA_VISIT(x) \ virtual void handle(codon::ir::x *v) {} \ void visit(codon::ir::x *v) override { \ if (childrenFirst) \ processChildren(v); \ preHook(v); \ handle(v); \ postHook(v); \ if (!childrenFirst) \ processChildren(v); \ } namespace codon { namespace ir { namespace util { /// Pass that visits all values in a module. 
class Operator : public Visitor { private: /// IDs of previously visited nodes std::unordered_set seen; /// stack of IR nodes being visited std::vector nodeStack; /// stack of iterators std::vector itStack; /// true if should visit children first bool childrenFirst; protected: void defaultVisit(Node *) override {} public: /// Constructs an operator. /// @param childrenFirst true if children should be visited first explicit Operator(bool childrenFirst = false) : childrenFirst(childrenFirst) {} virtual ~Operator() noexcept = default; /// This function is applied to all nodes before handling the node /// itself. It provides a way to write one function that gets /// applied to every visited node. /// @param node the node virtual void preHook(Node *node) {} /// This function is applied to all nodes after handling the node /// itself. It provides a way to write one function that gets /// applied to every visited node. /// @param node the node virtual void postHook(Node *node) {} void visit(Module *m) override { nodeStack.push_back(m); nodeStack.push_back(m->getMainFunc()); process(m->getMainFunc()); nodeStack.pop_back(); for (auto *s : *m) { nodeStack.push_back(s); process(s); nodeStack.pop_back(); } nodeStack.pop_back(); } void visit(BodiedFunc *f) override { if (f->getBody()) { seen.insert(f->getBody()->getId()); process(f->getBody()); } } LAMBDA_VISIT(VarValue); LAMBDA_VISIT(PointerValue); void visit(codon::ir::SeriesFlow *v) override { if (childrenFirst) processSeriesFlowChildren(v); preHook(v); handle(v); postHook(v); if (!childrenFirst) processSeriesFlowChildren(v); } virtual void handle(codon::ir::SeriesFlow *v) {} LAMBDA_VISIT(IfFlow); LAMBDA_VISIT(WhileFlow); LAMBDA_VISIT(ForFlow); LAMBDA_VISIT(ImperativeForFlow); LAMBDA_VISIT(TryCatchFlow); LAMBDA_VISIT(PipelineFlow); LAMBDA_VISIT(dsl::CustomFlow); LAMBDA_VISIT(TemplatedConst); LAMBDA_VISIT(TemplatedConst); LAMBDA_VISIT(TemplatedConst); LAMBDA_VISIT(TemplatedConst); LAMBDA_VISIT(dsl::CustomConst); 
LAMBDA_VISIT(Instr); LAMBDA_VISIT(AssignInstr); LAMBDA_VISIT(ExtractInstr); LAMBDA_VISIT(InsertInstr); LAMBDA_VISIT(CallInstr); LAMBDA_VISIT(StackAllocInstr); LAMBDA_VISIT(TypePropertyInstr); LAMBDA_VISIT(YieldInInstr); LAMBDA_VISIT(TernaryInstr); LAMBDA_VISIT(BreakInstr); LAMBDA_VISIT(ContinueInstr); LAMBDA_VISIT(ReturnInstr); LAMBDA_VISIT(YieldInstr); LAMBDA_VISIT(AwaitInstr); LAMBDA_VISIT(ThrowInstr); LAMBDA_VISIT(FlowInstr); LAMBDA_VISIT(dsl::CustomInstr); template void process(Node *v) { v->accept(*this); } /// Return the parent of the current node. /// @param level the number of levels up from the current node template Desired *getParent(int level = 0) { return cast(nodeStack[nodeStack.size() - level - 1]); } /// @return current depth in the tree int depth() const { return nodeStack.size(); } /// @tparam Desired the desired type /// @return the last encountered example of the desired type template Desired *findLast() { for (auto it = nodeStack.rbegin(); it != nodeStack.rend(); ++it) { if (auto *v = cast(*it)) return v; } return nullptr; } /// @return the last encountered function Func *getParentFunc() { return findLast(); } /// @return an iterator to the first parent auto parent_begin() const { return nodeStack.begin(); } /// @return an iterator beyond the last parent auto parent_end() const { return nodeStack.end(); } /// @param v the value /// @return whether we have visited ("seen") the given value bool saw(const Value *v) const { return seen.find(v->getId()) != seen.end(); } /// Avoid visiting the given value in the future. /// @param v the value void see(const Value *v) { seen.insert(v->getId()); } /// Inserts the new value before the current position in the last seen SeriesFlow. /// @param v the new value auto insertBefore(Value *v) { return findLast()->insert(itStack.back(), v); } /// Inserts the new value after the current position in the last seen SeriesFlow. 
/// @param v the new value, which is marked seen auto insertAfter(Value *v) { auto newPos = itStack.back(); ++newPos; see(v); return findLast()->insert(newPos, v); } /// Resets the operator. void reset() { seen.clear(); nodeStack.clear(); itStack.clear(); } private: void processChildren(Value *v) { nodeStack.push_back(v); for (auto *c : v->getUsedValues()) { if (saw(c)) continue; see(c); process(c); } nodeStack.pop_back(); } void processSeriesFlowChildren(codon::ir::SeriesFlow *v) { nodeStack.push_back(v); for (auto it = v->begin(); it != v->end(); ++it) { itStack.push_back(it); process(*it); itStack.pop_back(); } nodeStack.pop_back(); } }; } // namespace util } // namespace ir } // namespace codon #undef LAMBDA_VISIT ================================================ FILE: codon/cir/util/outlining.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "outlining.h" #include #include #include #include "codon/cir/util/cloning.h" #include "codon/cir/util/irtools.h" #include "codon/cir/util/operator.h" namespace codon { namespace ir { namespace util { namespace { struct OutlineReplacer : public Operator { std::unordered_set &modVars; std::vector> &remap; std::vector &outFlows; CloneVisitor cv; OutlineReplacer(Module *M, std::unordered_set &modVars, std::vector> &remap, std::vector &outFlows) : Operator(), modVars(modVars), remap(remap), outFlows(outFlows), cv(M, false) {} // Replace all used vars based on remapping. void postHook(Node *node) override { for (auto &pair : remap) { node->replaceUsedVariable(std::get<0>(pair), std::get<1>(pair)); } } Var *mappedVar(Var *v) { for (auto &pair : remap) { if (std::get<0>(pair)->getId() == v->getId()) return std::get<1>(pair); } return nullptr; } // A return in the outlined func, or a break/continue that references a // non-outlined loop, will return a status code that tells the call site // what action to perform. 
template void replaceOutFlowWithReturn(InstrType *v) { auto *M = v->getModule(); for (unsigned i = 0; i < outFlows.size(); i++) { if (outFlows[i]->getId() == v->getId()) { auto *copy = cv.clone(v); v->replaceAll(M->template Nr(M->getInt(i + 1))); outFlows[i] = copy; break; } } } void handle(ReturnInstr *v) override { replaceOutFlowWithReturn(v); } void handle(BreakInstr *v) override { replaceOutFlowWithReturn(v); } void handle(ContinueInstr *v) override { replaceOutFlowWithReturn(v); } // If passed by pointer (i.e. a "mod var"), change variable reference to // a pointer dereference. void handle(VarValue *v) override { auto *M = v->getModule(); if (modVars.count(v->getVar()->getId()) > 0) { // var -> pointer dereference auto *deref = util::ptrLoad(M->Nr(mappedVar(v->getVar()))); saw(deref); v->replaceAll(deref); } } // If passed by pointer (i.e. a "mod var"), change pointer value to just // be the var itself. void handle(PointerValue *v) override { auto *M = v->getModule(); if (modVars.count(v->getVar()->getId()) > 0) { // pointer -> var auto *ref = M->Nr(mappedVar(v->getVar())); saw(ref); v->replaceAll(ref); } } // If passed by pointer (i.e. a "mod var"), change assignment to store // in the pointer. void handle(AssignInstr *v) override { auto *M = v->getModule(); if (modVars.count(v->getLhs()->getId()) > 0) { // store in pointer Var *newVar = mappedVar(v->getLhs()); auto *setitem = util::ptrStore(M->Nr(newVar), v->getRhs()); saw(setitem); v->replaceAll(setitem); } } }; struct Outliner : public Operator { BodiedFunc *parent; SeriesFlow *flowRegion; decltype(flowRegion->begin()) begin, end; bool outlineGlobals; // whether to outline globals that are modified bool allByValue; // outline all vars by value (can change semantics) bool inRegion; // are we in the outlined region? 
bool invalid; // if we can't outline for whatever reason std::unordered_set inVars; // vars used inside region std::unordered_set outVars; // vars used outside region std::unordered_set modifiedInVars; // vars modified (assigned or address'd) in region std::unordered_set globalsToOutline; // modified global vars to outline std::unordered_set inLoops; // loops contained in region std::vector outFlows; // control flows that need to be handled externally (e.g. return) Outliner(BodiedFunc *parent, SeriesFlow *flowRegion, decltype(flowRegion->begin()) begin, decltype(flowRegion->begin()) end, bool outlineGlobals, bool allByValue) : Operator(), parent(parent), flowRegion(flowRegion), begin(begin), end(end), outlineGlobals(outlineGlobals), allByValue(allByValue), inRegion(false), invalid(false), inVars(), outVars(), modifiedInVars(), globalsToOutline(), inLoops(), outFlows() {} bool isEnclosingLoopInRegion(id_t loopId = -1) { int d = depth(); for (int i = 0; i < d; i++) { Flow *v = getParent(i); if (!v) v = getParent(i); if (!v) v = getParent(i); if (v && (loopId == -1 || loopId == v->getId())) return inLoops.count(v->getId()) > 0; } return false; } void handle(WhileFlow *v) override { if (inRegion) inLoops.insert(v->getId()); } void handle(ForFlow *v) override { if (inRegion) inLoops.insert(v->getId()); } void handle(ImperativeForFlow *v) override { if (inRegion) inLoops.insert(v->getId()); } void handle(ReturnInstr *v) override { if (inRegion) outFlows.push_back(v); } void handle(BreakInstr *v) override { auto *loop = v->getLoop(); if (inRegion && !isEnclosingLoopInRegion(loop ? loop->getId() : -1)) outFlows.push_back(v); } void handle(ContinueInstr *v) override { auto *loop = v->getLoop(); if (inRegion && !isEnclosingLoopInRegion(loop ? 
loop->getId() : -1)) outFlows.push_back(v); } void handle(YieldInstr *v) override { if (inRegion) invalid = true; } void handle(AwaitInstr *v) override { if (inRegion) invalid = true; } void handle(YieldInInstr *v) override { if (inRegion) invalid = true; } void handle(StackAllocInstr *v) override { if (inRegion) invalid = true; } void handle(AssignInstr *v) override { if (inRegion) { auto *var = v->getLhs(); modifiedInVars.insert(var->getId()); if (outlineGlobals && var->isGlobal() && !var->isThreadLocal()) globalsToOutline.insert(var->getId()); } } void handle(PointerValue *v) override { if (inRegion) { auto *var = v->getVar(); modifiedInVars.insert(var->getId()); if (outlineGlobals && var->isGlobal() && !var->isThreadLocal()) globalsToOutline.insert(var->getId()); } } void visit(SeriesFlow *v) override { if (v->getId() != flowRegion->getId()) return Operator::visit(v); auto it = flowRegion->begin(); for (; it != begin; ++it) { (*it)->accept(*this); } inRegion = true; for (; it != end; ++it) { (*it)->accept(*this); } inRegion = false; for (; it != flowRegion->end(); ++it) { (*it)->accept(*this); } } void visit(BodiedFunc *v) override { for (auto it = v->arg_begin(); it != v->arg_end(); ++it) { outVars.insert((*it)->getId()); } Operator::visit(v); } void preHook(Node *node) override { auto vars = node->getUsedVariables(); auto &set = (inRegion ? 
inVars : outVars); for (auto *var : vars) { if (!var->isGlobal()) set.insert(var->getId()); else if (inRegion && allByValue && !isA(var)) globalsToOutline.insert(var->getId()); } } // private = used in region AND NOT used outside region std::unordered_set getPrivateVars() { std::unordered_set privateVars; for (auto id : inVars) { if (outVars.count(id) == 0) privateVars.insert(id); } return privateVars; } // shared = used in region AND used outside region std::unordered_set getSharedVars() { std::unordered_set sharedVars; for (auto id : inVars) { if (outVars.count(id) > 0) sharedVars.insert(id); } return sharedVars; } // mod = shared AND modified in region std::unordered_set getModVars() { if (allByValue) return {}; std::unordered_set modVars, shared = getSharedVars(); for (auto id : modifiedInVars) { if (globalsToOutline.count(id) > 0 || shared.count(id) > 0) modVars.insert(id); } return modVars; } OutlineResult outline(bool allowOutflows = true) { if (invalid) return {}; auto *M = flowRegion->getModule(); std::vector> remap; // mapping of old vars to new func vars std::vector argTypes; // arg types of new func std::vector argNames; // arg names of new func std::vector argKinds; // arg information given back to user // Figure out arguments and outlined function type: // - Private variables can be made local to the new function // - Shared variables will be passed as arguments // - Modified+shared variables will be passed as pointers unsigned idx = 0; auto shared = getSharedVars(); shared.insert(globalsToOutline.begin(), globalsToOutline.end()); auto mod = getModVars(); for (auto id : shared) { Var *var = M->getVar(id); seqassertn(var, "unknown var id [{}]", var->getSrcInfo()); remap.emplace_back(var, nullptr); const bool isMod = (mod.count(id) > 0); types::Type *type = isMod ? M->getPointerType(var->getType()) : var->getType(); argTypes.push_back(type); argNames.push_back(var->getName()); argKinds.push_back(isMod ? 
OutlineResult::ArgKind::MODIFIED : OutlineResult::ArgKind::CONSTANT); } // Check if we need to handle control flow externally. // If so, function will return an int code indicating control. const bool callIndicatesControl = !outFlows.empty(); if (callIndicatesControl && !allowOutflows) return {}; auto *funcType = M->getFuncType( callIndicatesControl ? M->getIntType() : M->getNoneType(), argTypes); auto *outlinedFunc = M->Nr("__outlined"); outlinedFunc->realize(funcType, argNames); // Insert function arguments in variable remappings. idx = 0; for (auto it = outlinedFunc->arg_begin(); it != outlinedFunc->arg_end(); ++it) { remap[idx] = {std::get<0>(remap[idx]), *it}; ++idx; } // Make private vars locals of the new function. for (auto id : getPrivateVars()) { Var *var = M->getVar(id); seqassertn(var, "unknown var id [{}]", var->getSrcInfo()); Var *newVar = M->N(var->getSrcInfo(), var->getType(), /*global=*/false, /*external=*/false, /*tls=*/false, var->getName()); remap.emplace_back(var, newVar); outlinedFunc->push_back(newVar); } // Delete outlined region from parent function and insert into outlined function. auto *body = M->N((*begin)->getSrcInfo()); auto it = begin; while (it != end) { body->push_back(*it); it = flowRegion->erase(it); } outlinedFunc->setBody(body); // Replace vars and externally-handled flows. OutlineReplacer outRep(M, mod, remap, outFlows); body->accept(outRep); // Determine arguments for call to outlined function. std::vector args; for (unsigned i = 0; i < shared.size(); i++) { Var *var = std::get<0>(remap[i]); Value *arg = (mod.count(var->getId()) > 0) ? static_cast(M->Nr(var)) : M->Nr(var); args.push_back(arg); } auto *outlinedCall = call(outlinedFunc, args); // Check if we need external control-flow handling. if (callIndicatesControl) { auto *codeVar = M->Nr(M->getIntType()); // result of outlined func call parent->push_back(codeVar); it = flowRegion->insert(it, M->Nr(codeVar, outlinedCall)); // Check each return code of the function. 
0 means normal return; do nothing.
      for (unsigned i = 0; i < outFlows.size(); i++) {
        // Generate "if (result == code) { action }".
        auto *codeVal = M->getInt(i + 1); // 1-based by convention
        auto *codeCheck = (*codeVal == *M->Nr(codeVar));
        auto *codeBody = series(outFlows[i]);
        auto *codeIf = M->Nr(codeCheck, codeBody);
        ++it;
        it = flowRegion->insert(it, codeIf);
      }
    } else {
      it = flowRegion->insert(it, outlinedCall);
    }

    return {outlinedFunc, outlinedCall, argKinds, static_cast(outFlows.size())};
  }
};

} // namespace

// Outlines the half-open range [begin, end) of `series`.
// An empty range yields a default-constructed OutlineResult (func == nullptr,
// so it converts to false). Otherwise an Outliner is built for the range,
// `parent` accepts it, and the result of Outliner::outline() is returned.
OutlineResult outlineRegion(BodiedFunc *parent, SeriesFlow *series,
                            decltype(series->begin()) begin,
                            decltype(series->end()) end, bool allowOutflows,
                            bool outlineGlobals, bool allByValue) {
  if (begin == end)
    return {};
  Outliner outliner(parent, series, begin, end, outlineGlobals, allByValue);
  parent->accept(outliner);
  return outliner.outline(allowOutflows);
}

// Convenience overload: outlines the entire series by forwarding
// series->begin() / series->end() to the range-based overload.
OutlineResult outlineRegion(BodiedFunc *parent, SeriesFlow *series, bool allowOutflows,
                            bool outlineGlobals, bool allByValue) {
  return outlineRegion(parent, series, series->begin(), series->end(), allowOutflows,
                       outlineGlobals, allByValue);
}

} // namespace util
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/util/outlining.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include "codon/cir/cir.h"

namespace codon {
namespace ir {
namespace util {

/// The result of an outlining operation.
struct OutlineResult {
  /// Information about an argument of an outlined function.
  enum ArgKind {
    CONSTANT, ///< Argument is not modified by outlined function
    MODIFIED, ///< Argument is modified and passed by pointer
  };

  /// The outlined function
  BodiedFunc *func = nullptr;

  /// The call to the outlined function
  CallInstr *call = nullptr;

  /// Information about each argument of the outlined function.
  /// "CONSTANT" arguments are passed by value; "MODIFIED"
  /// arguments are passed by pointer and written to by the
  /// outlined function. The size of this vector is the same
  /// as the number of arguments of the outlined function; each
  /// entry corresponds to one of those arguments.
  std::vector argKinds;

  /// Number of externally-handled control flows.
  /// For example, an outlined function that contains a "break"
  /// of a non-outlined loop will return an integer code that
  /// tells the caller to perform this break. A series of
  /// if-statements is added to the call site to check the
  /// returned code and perform the correct action. This value
  /// is the number of if-statements generated. If it is zero,
  /// the function returns void and no such checks are done.
  int numOutFlows = 0;

  /// @return true if an outlined function is present (`func` is non-null);
  ///         a default-constructed result converts to false
  operator bool() const { return bool(func); }
};

/// Outlines a region of IR delineated by begin and end iterators
/// on a particular series flow. The outlined code will be replaced
/// by a call to the outlined function, and possibly extra logic if
/// control flow needs to be handled.
/// @param parent the function containing the series flow
/// @param series the series flow on which outlining will happen
/// @param begin start of outlining
/// @param end end of outlining (non-inclusive like standard iterators)
/// @param allowOutflows allow outlining regions with "out-flows"
/// @param outlineGlobals outline globals as arguments to outlined function
/// @param allByValue pass all outlined vars by value (can change semantics)
/// @return the result of outlining; an empty (false) result if begin == end
OutlineResult outlineRegion(BodiedFunc *parent, SeriesFlow *series,
                            decltype(series->begin()) begin,
                            decltype(series->end()) end, bool allowOutflows = true,
                            bool outlineGlobals = false, bool allByValue = false);

/// Outlines a series flow from its parent function. The outlined code
/// will be replaced by a call to the outlined function, and possibly
/// extra logic if control flow needs to be handled.
/// @param parent the function containing the series flow
/// @param series the series flow on which outlining will happen
/// @param allowOutflows allow outlining regions with "out-flows"
/// @param outlineGlobals outline globals as arguments to outlined function
/// @param allByValue pass all outlined vars by value (can change semantics)
/// @return the result of outlining
OutlineResult outlineRegion(BodiedFunc *parent, SeriesFlow *series,
                            bool allowOutflows = true, bool outlineGlobals = false,
                            bool allByValue = false);

} // namespace util
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/util/packs.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

// NOTE(review): the operand of this #include and the template parameter lists
// below appear to be missing from this listing - confirm against the repository.
#include

namespace codon {
namespace ir {
namespace util {

/// Utility function to strip parameter packs.
/// Appends the address of `first` to `dst`.
/// @param dst the destination vector
/// @param first the value
template
void stripPack(std::vector &dst, Desired &first) {
  dst.push_back(&first);
}

/// Utility function to strip parameter packs.
/// Overload for an empty pack; leaves `dst` unchanged.
/// @param dst the destination vector
template
void stripPack(std::vector &dst) {}

/// Utility function to strip parameter packs.
/// Appends the address of `first` to `dst`, then recurses on the remaining
/// arguments so that addresses are appended in argument order.
/// @param dst the destination vector
/// @param first the value
/// @param args the argument pack
template
void stripPack(std::vector &dst, Desired &first, Args &&...args) {
  dst.push_back(&first);
  stripPack(dst, std::forward(args)...);
}

} // namespace util
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/util/side_effect.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include "side_effect.h"

#include "codon/parser/common.h"

namespace codon {
namespace ir {
namespace util {

// Attribute name constants. Each one is the string returned by
// ast::getMangledFunc() for the two literals shown.
// NOTE(review): PURE_ATTR and DERIVES_ATTR use "std.internal.core" while the
// other four use "std.internal.attributes" - confirm this split is intended.
const std::string NON_PURE_ATTR =
    ast::getMangledFunc("std.internal.attributes", "nonpure");
const std::string PURE_ATTR = ast::getMangledFunc("std.internal.core", "pure");
const std::string NO_SIDE_EFFECT_ATTR =
    ast::getMangledFunc("std.internal.attributes", "no_side_effect");
const std::string NO_CAPTURE_ATTR =
    ast::getMangledFunc("std.internal.attributes", "nocapture");
const std::string DERIVES_ATTR = ast::getMangledFunc("std.internal.core", "derives");
const std::string SELF_CAPTURES_ATTR =
    ast::getMangledFunc("std.internal.attributes", "self_captures");

} // namespace util
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/util/side_effect.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

// NOTE(review): the operand of this #include appears to be missing from this
// listing - confirm against the repository.
#include

namespace codon {
namespace ir {
namespace util {

/// Function side effect status. "Pure" functions by definition give the same
/// output for the same inputs and have no side effects. "No side effect"
/// functions have no side effects, but can give different outputs for the
/// same input (e.g. time() is one such function). "No capture" functions do
/// not capture any of their arguments; note that capturing an argument is
/// considered a side effect. Therefore, we have pure < no_side_effect <
/// no_capture < unknown, where "<" denotes subset. The enum values are also
/// ordered in this way, which is relied on by the implementation.
enum SideEffectStatus {
  PURE = 0,       ///< same output for the same inputs and no side effects
  NO_SIDE_EFFECT, ///< no side effects, but output may differ for the same input
  NO_CAPTURE,     ///< does not capture any of its arguments
  UNKNOWN,        ///< last and most general status in the ordering
};

// Attribute name constants; their definitions are in side_effect.cpp.
extern const std::string NON_PURE_ATTR;
extern const std::string PURE_ATTR;
extern const std::string NO_SIDE_EFFECT_ATTR;
extern const std::string NO_CAPTURE_ATTR;
extern const std::string DERIVES_ATTR;
extern const std::string SELF_CAPTURES_ATTR;

} // namespace util
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/util/visitor.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "visitor.h"

#include "codon/cir/cir.h"

namespace codon {
namespace ir {
namespace util {

// Every Visitor::visit overload below forwards its argument to defaultVisit().
// NOTE(review): the four TemplatedConst overloads look identical in this
// listing; their template arguments appear to be missing - confirm against
// the repository.
void Visitor::visit(Module *x) { defaultVisit(x); }
void Visitor::visit(Var *x) { defaultVisit(x); }
void Visitor::visit(Func *x) { defaultVisit(x); }
void Visitor::visit(BodiedFunc *x) { defaultVisit(x); }
void Visitor::visit(ExternalFunc *x) { defaultVisit(x); }
void Visitor::visit(InternalFunc *x) { defaultVisit(x); }
void Visitor::visit(LLVMFunc *x) { defaultVisit(x); }
void Visitor::visit(Value *x) { defaultVisit(x); }
void Visitor::visit(VarValue *x) { defaultVisit(x); }
void Visitor::visit(PointerValue *x) { defaultVisit(x); }
void Visitor::visit(Flow *x) { defaultVisit(x); }
void Visitor::visit(SeriesFlow *x) { defaultVisit(x); }
void Visitor::visit(IfFlow *x) { defaultVisit(x); }
void Visitor::visit(WhileFlow *x) { defaultVisit(x); }
void Visitor::visit(ForFlow *x) { defaultVisit(x); }
void Visitor::visit(ImperativeForFlow *x) { defaultVisit(x); }
void Visitor::visit(TryCatchFlow *x) { defaultVisit(x); }
void Visitor::visit(PipelineFlow *x) { defaultVisit(x); }
void Visitor::visit(dsl::CustomFlow *x) { defaultVisit(x); }
void Visitor::visit(Const *x) { defaultVisit(x); }
void Visitor::visit(TemplatedConst *x) { defaultVisit(x); }
void Visitor::visit(TemplatedConst *x) { defaultVisit(x); }
void Visitor::visit(TemplatedConst *x) { defaultVisit(x); }
void Visitor::visit(TemplatedConst *x) { defaultVisit(x); }
void Visitor::visit(dsl::CustomConst *x) { defaultVisit(x); }
void Visitor::visit(Instr *x) { defaultVisit(x); }
void Visitor::visit(AssignInstr *x) { defaultVisit(x); }
void Visitor::visit(ExtractInstr *x) { defaultVisit(x); }
void Visitor::visit(InsertInstr *x) { defaultVisit(x); }
void Visitor::visit(CallInstr *x) { defaultVisit(x); }
void Visitor::visit(StackAllocInstr *x) { defaultVisit(x); }
void Visitor::visit(YieldInInstr *x) { defaultVisit(x); }
void Visitor::visit(TernaryInstr *x) { defaultVisit(x); }
void Visitor::visit(BreakInstr *x) { defaultVisit(x); }
void Visitor::visit(ContinueInstr *x) { defaultVisit(x); }
void Visitor::visit(ReturnInstr *x) { defaultVisit(x); }
void Visitor::visit(TypePropertyInstr *x) { defaultVisit(x); }
void Visitor::visit(YieldInstr *x) { defaultVisit(x); }
void Visitor::visit(AwaitInstr *x) { defaultVisit(x); }
void Visitor::visit(ThrowInstr *x) { defaultVisit(x); }
void Visitor::visit(FlowInstr *x) { defaultVisit(x); }
void Visitor::visit(dsl::CustomInstr *x) { defaultVisit(x); }
void Visitor::visit(types::Type *x) { defaultVisit(x); }
void Visitor::visit(types::PrimitiveType *x) { defaultVisit(x); }
void Visitor::visit(types::IntType *x) { defaultVisit(x); }
void Visitor::visit(types::FloatType *x) { defaultVisit(x); }
void Visitor::visit(types::Float32Type *x) { defaultVisit(x); }
void Visitor::visit(types::Float16Type *x) { defaultVisit(x); }
void Visitor::visit(types::BFloat16Type *x) { defaultVisit(x); }
void Visitor::visit(types::Float128Type *x) { defaultVisit(x); }
void Visitor::visit(types::BoolType *x) { defaultVisit(x); }
void Visitor::visit(types::ByteType *x) { defaultVisit(x); }
void Visitor::visit(types::VoidType *x) { defaultVisit(x); }
void Visitor::visit(types::RecordType *x) { defaultVisit(x); }
void Visitor::visit(types::RefType *x) { defaultVisit(x); }
void Visitor::visit(types::FuncType *x) { defaultVisit(x); }
void Visitor::visit(types::OptionalType *x) { defaultVisit(x); }
void Visitor::visit(types::PointerType *x) { defaultVisit(x); }
void Visitor::visit(types::GeneratorType *x) { defaultVisit(x); }
void Visitor::visit(types::IntNType *x) { defaultVisit(x); }
void Visitor::visit(types::VectorType *x) { defaultVisit(x); }
void Visitor::visit(types::UnionType *x) { defaultVisit(x); }
void Visitor::visit(dsl::types::CustomType *x) { defaultVisit(x); }

// Const counterparts: every ConstVisitor::visit overload forwards to
// ConstVisitor's defaultVisit().
void ConstVisitor::visit(const Module *x) { defaultVisit(x); }
void ConstVisitor::visit(const Var *x) { defaultVisit(x); }
void ConstVisitor::visit(const Func *x) { defaultVisit(x); }
void ConstVisitor::visit(const BodiedFunc *x) { defaultVisit(x); }
void ConstVisitor::visit(const ExternalFunc *x) { defaultVisit(x); }
void ConstVisitor::visit(const InternalFunc *x) { defaultVisit(x); }
void ConstVisitor::visit(const LLVMFunc *x) { defaultVisit(x); }
void ConstVisitor::visit(const Value *x) { defaultVisit(x); }
void ConstVisitor::visit(const VarValue *x) { defaultVisit(x); }
void ConstVisitor::visit(const PointerValue *x) { defaultVisit(x); }
void ConstVisitor::visit(const Flow *x) { defaultVisit(x); }
void ConstVisitor::visit(const SeriesFlow *x) { defaultVisit(x); }
void ConstVisitor::visit(const IfFlow *x) { defaultVisit(x); }
void ConstVisitor::visit(const WhileFlow *x) { defaultVisit(x); }
void ConstVisitor::visit(const ForFlow *x) { defaultVisit(x); }
void ConstVisitor::visit(const ImperativeForFlow *x) { defaultVisit(x); }
void ConstVisitor::visit(const TryCatchFlow *x) { defaultVisit(x); }
void ConstVisitor::visit(const PipelineFlow *x) { defaultVisit(x); }
void ConstVisitor::visit(const dsl::CustomFlow *x) { defaultVisit(x); }
void ConstVisitor::visit(const Const *x) { defaultVisit(x); }
void ConstVisitor::visit(const TemplatedConst *x) { defaultVisit(x); }
void ConstVisitor::visit(const TemplatedConst *x) { defaultVisit(x); }
void ConstVisitor::visit(const TemplatedConst *x) { defaultVisit(x); }
void ConstVisitor::visit(const TemplatedConst *x) { defaultVisit(x); }
void ConstVisitor::visit(const dsl::CustomConst *x) { defaultVisit(x); }
void ConstVisitor::visit(const Instr *x) { defaultVisit(x); }
void ConstVisitor::visit(const AssignInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const ExtractInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const InsertInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const CallInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const StackAllocInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const YieldInInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const TernaryInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const BreakInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const ContinueInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const ReturnInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const TypePropertyInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const YieldInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const AwaitInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const ThrowInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const FlowInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const dsl::CustomInstr *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::PrimitiveType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::IntType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::FloatType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float32Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float16Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::BFloat16Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float128Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::BoolType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::ByteType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::VoidType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::RecordType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::RefType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::FuncType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::OptionalType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::PointerType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::GeneratorType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::IntNType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::VectorType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::UnionType *x) { defaultVisit(x); }
void ConstVisitor::visit(const dsl::types::CustomType *x) { defaultVisit(x); }

} // namespace util
} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/util/visitor.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once

// NOTE(review): the operands of these three #include directives appear to be
// missing from this listing - confirm against the repository.
#include
#include
#include

// Declare one virtual visit() overload for the named codon::ir type.
// Both macros are #undef'd at the end of this header.
#define VISIT(x) virtual void visit(codon::ir::x *)
#define CONST_VISIT(x) virtual void visit(const codon::ir::x *)

namespace codon {
namespace ir {
class Node;

// Forward declarations of the node types named by the visit() overloads.
namespace types {
class Type;
class PrimitiveType;
class IntType;
class FloatType;
class Float32Type;
class Float16Type;
class BFloat16Type;
class Float128Type;
class BoolType;
class ByteType;
class VoidType;
class RecordType;
class RefType;
class FuncType;
class OptionalType;
class PointerType;
class GeneratorType;
class IntNType;
class VectorType;
class UnionType;
} // namespace types

namespace dsl {

namespace types {
class CustomType;
}

class CustomConst;
class CustomFlow;
class CustomInstr;
} // namespace dsl

class Module;

class Var;

class Func;
class BodiedFunc;
class ExternalFunc;
class InternalFunc;
class LLVMFunc;

class Value;
class VarValue;
class PointerValue;

class Flow;
class SeriesFlow;
class IfFlow;
class WhileFlow;
class ForFlow;
class ImperativeForFlow;
class TryCatchFlow;
class PipelineFlow;

class Const;

// NOTE(review): the template parameter list appears to be missing here.
template class TemplatedConst;

class Instr;
class AssignInstr;
class ExtractInstr;
class InsertInstr;
class CallInstr;
class StackAllocInstr;
class TypePropertyInstr;
class YieldInInstr;
class TernaryInstr;
class BreakInstr;
class ContinueInstr;
class ReturnInstr;
class YieldInstr;
class AwaitInstr;
class ThrowInstr;
class FlowInstr;

namespace util {

/// Base for CIR visitors
class Visitor {
protected:
  /// Fallback invoked by the visit() overloads defined in visitor.cpp.
  /// Throws std::runtime_error unless overridden.
  virtual void defaultVisit(codon::ir::Node *) {
    throw std::runtime_error("cannot visit node");
  }

public:
  virtual ~Visitor() noexcept = default;

  VISIT(Module);

  VISIT(Var);

  VISIT(Func);
  VISIT(BodiedFunc);
  VISIT(ExternalFunc);
  VISIT(InternalFunc);
  VISIT(LLVMFunc);

  VISIT(Value);
  VISIT(VarValue);
  VISIT(PointerValue);

  VISIT(Flow);
  VISIT(SeriesFlow);
  VISIT(IfFlow);
  VISIT(WhileFlow);
  VISIT(ForFlow);
  VISIT(ImperativeForFlow);
  VISIT(TryCatchFlow);
  VISIT(PipelineFlow);
  VISIT(dsl::CustomFlow);

  VISIT(Const);
  // NOTE(review): these four entries look identical in this listing; their
  // template arguments appear to be missing - confirm against the repository.
  VISIT(TemplatedConst);
  VISIT(TemplatedConst);
  VISIT(TemplatedConst);
  VISIT(TemplatedConst);
  VISIT(dsl::CustomConst);

  VISIT(Instr);
  VISIT(AssignInstr);
  VISIT(ExtractInstr);
  VISIT(InsertInstr);
  VISIT(CallInstr);
  VISIT(StackAllocInstr);
  VISIT(TypePropertyInstr);
  VISIT(YieldInInstr);
  VISIT(TernaryInstr);
  VISIT(BreakInstr);
  VISIT(ContinueInstr);
  VISIT(ReturnInstr);
  VISIT(YieldInstr);
  VISIT(AwaitInstr);
  VISIT(ThrowInstr);
  VISIT(FlowInstr);
  VISIT(dsl::CustomInstr);

  VISIT(types::Type);
  VISIT(types::PrimitiveType);
  VISIT(types::IntType);
  VISIT(types::FloatType);
  VISIT(types::Float32Type);
  VISIT(types::Float16Type);
  VISIT(types::BFloat16Type);
  VISIT(types::Float128Type);
  VISIT(types::BoolType);
  VISIT(types::ByteType);
  VISIT(types::VoidType);
  VISIT(types::RecordType);
  VISIT(types::RefType);
  VISIT(types::FuncType);
  VISIT(types::OptionalType);
  VISIT(types::PointerType);
  VISIT(types::GeneratorType);
  VISIT(types::IntNType);
  VISIT(types::VectorType);
  VISIT(types::UnionType);
  VISIT(dsl::types::CustomType);
};

/// Base for CIR visitors that take nodes by pointer-to-const
class ConstVisitor {
protected:
  /// Fallback invoked by the visit() overloads defined in visitor.cpp.
  /// Throws std::runtime_error unless overridden.
  virtual void defaultVisit(const codon::ir::Node *) {
    throw std::runtime_error("cannot visit const node");
  }

public:
  virtual ~ConstVisitor() noexcept = default;

  CONST_VISIT(Module);

  CONST_VISIT(Var);

  CONST_VISIT(Func);
  CONST_VISIT(BodiedFunc);
  CONST_VISIT(ExternalFunc);
  CONST_VISIT(InternalFunc);
  CONST_VISIT(LLVMFunc);

  CONST_VISIT(Value);
  CONST_VISIT(VarValue);
  CONST_VISIT(PointerValue);

  CONST_VISIT(Flow);
  CONST_VISIT(SeriesFlow);
  CONST_VISIT(IfFlow);
  CONST_VISIT(WhileFlow);
  CONST_VISIT(ForFlow);
  CONST_VISIT(ImperativeForFlow);
  CONST_VISIT(TryCatchFlow);
  CONST_VISIT(PipelineFlow);
  CONST_VISIT(dsl::CustomFlow);

  CONST_VISIT(Const);
  CONST_VISIT(TemplatedConst);
  CONST_VISIT(TemplatedConst);
  CONST_VISIT(TemplatedConst);
  CONST_VISIT(TemplatedConst);
  CONST_VISIT(dsl::CustomConst);

  CONST_VISIT(Instr);
  CONST_VISIT(AssignInstr);
  CONST_VISIT(ExtractInstr);
  CONST_VISIT(InsertInstr);
  CONST_VISIT(CallInstr);
  CONST_VISIT(StackAllocInstr);
  CONST_VISIT(TypePropertyInstr);
  CONST_VISIT(YieldInInstr);
  CONST_VISIT(TernaryInstr);
  CONST_VISIT(BreakInstr);
  CONST_VISIT(ContinueInstr);
  CONST_VISIT(ReturnInstr);
  CONST_VISIT(YieldInstr);
  CONST_VISIT(AwaitInstr);
  CONST_VISIT(ThrowInstr);
  CONST_VISIT(FlowInstr);
  CONST_VISIT(dsl::CustomInstr);

  CONST_VISIT(types::Type);
  CONST_VISIT(types::PrimitiveType);
  CONST_VISIT(types::IntType);
  CONST_VISIT(types::FloatType);
  CONST_VISIT(types::Float32Type);
  CONST_VISIT(types::Float16Type);
  CONST_VISIT(types::BFloat16Type);
  CONST_VISIT(types::Float128Type);
  CONST_VISIT(types::BoolType);
  CONST_VISIT(types::ByteType);
  CONST_VISIT(types::VoidType);
  CONST_VISIT(types::RecordType);
  CONST_VISIT(types::RefType);
  CONST_VISIT(types::FuncType);
  CONST_VISIT(types::OptionalType);
  CONST_VISIT(types::PointerType);
  CONST_VISIT(types::GeneratorType);
  CONST_VISIT(types::IntNType);
  CONST_VISIT(types::VectorType);
  CONST_VISIT(types::UnionType);
  CONST_VISIT(dsl::types::CustomType);
};

} // namespace util
} // namespace ir
} // namespace codon

// The helper macros are local to this header.
#undef VISIT
#undef CONST_VISIT

================================================
FILE: codon/cir/value.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include "value.h"

#include "codon/cir/instr.h"
#include "codon/cir/module.h"

namespace codon {
namespace ir {

const char Value::NodeId = 0;

// Each comparison/arithmetic operator below delegates to doBinaryOp() or
// doUnaryOp() with the corresponding Module::*_MAGIC_NAME constant. Both
// helpers return nullptr when the method lookup yields no function.

Value *Value::operator==(Value &other) {
  return doBinaryOp(Module::EQ_MAGIC_NAME, other);
}

Value *Value::operator!=(Value &other) {
  return doBinaryOp(Module::NE_MAGIC_NAME, other);
}

Value *Value::operator<(Value &other) {
  return doBinaryOp(Module::LT_MAGIC_NAME, other);
}

Value *Value::operator>(Value &other) {
  return doBinaryOp(Module::GT_MAGIC_NAME, other);
}

Value *Value::operator<=(Value &other) {
  return doBinaryOp(Module::LE_MAGIC_NAME, other);
}

Value *Value::operator>=(Value &other) {
  return doBinaryOp(Module::GE_MAGIC_NAME, other);
}

Value *Value::operator+() { return doUnaryOp(Module::POS_MAGIC_NAME); }

Value *Value::operator-() { return doUnaryOp(Module::NEG_MAGIC_NAME); }

Value *Value::operator~() { return doUnaryOp(Module::INVERT_MAGIC_NAME); }

Value *Value::operator+(Value &other) {
  return doBinaryOp(Module::ADD_MAGIC_NAME, other);
}

Value *Value::operator-(Value &other) {
  return doBinaryOp(Module::SUB_MAGIC_NAME, other);
}

Value *Value::operator*(Value &other) {
  return doBinaryOp(Module::MUL_MAGIC_NAME, other);
}

Value *Value::matMul(Value &other) {
  return doBinaryOp(Module::MATMUL_MAGIC_NAME, other);
}

Value *Value::trueDiv(Value &other) {
  return doBinaryOp(Module::TRUE_DIV_MAGIC_NAME, other);
}

// Note: operator/ uses the floor-division magic name; trueDiv() above uses
// the true-division one.
Value *Value::operator/(Value &other) {
  return doBinaryOp(Module::FLOOR_DIV_MAGIC_NAME, other);
}

Value *Value::operator%(Value &other) {
  return doBinaryOp(Module::MOD_MAGIC_NAME, other);
}

Value *Value::pow(Value &other) { return doBinaryOp(Module::POW_MAGIC_NAME, other); }

Value *Value::operator<<(Value &other) {
  return doBinaryOp(Module::LSHIFT_MAGIC_NAME, other);
}

Value *Value::operator>>(Value &other) {
  return doBinaryOp(Module::RSHIFT_MAGIC_NAME, other);
}

Value *Value::operator&(Value &other) {
  return doBinaryOp(Module::AND_MAGIC_NAME, other);
}

Value *Value::operator|(Value &other) {
  return doBinaryOp(Module::OR_MAGIC_NAME, other);
}

Value *Value::operator^(Value &other) {
  return doBinaryOp(Module::XOR_MAGIC_NAME, other);
}

// Builds a three-operand node from (toBool(), true, other.toBool()).
// NOTE(review): toBool() can return nullptr (see doUnaryOp) and the results
// are passed on unchecked here - confirm the constructed node tolerates that.
Value *Value::operator||(Value &other) {
  auto *module = getModule();
  return module->Nr(toBool(), module->getBool(true), other.toBool());
}

// Builds a three-operand node from (toBool(), other.toBool(), false).
// NOTE(review): same unchecked-nullptr consideration as operator|| above.
Value *Value::operator&&(Value &other) {
  auto *module = getModule();
  return module->Nr(toBool(), other.toBool(), module->getBool(false));
}

Value *Value::operator[](Value &other) {
  return doBinaryOp(Module::GETITEM_MAGIC_NAME, other);
}

Value *Value::toInt() { return doUnaryOp(Module::INT_MAGIC_NAME); }

Value *Value::toFloat() { return doUnaryOp(Module::FLOAT_MAGIC_NAME); }

Value *Value::toBool() { return doUnaryOp(Module::BOOL_MAGIC_NAME); }

// NOTE(review): toStr() uses REPR_MAGIC_NAME rather than a str-specific
// magic name - confirm this is intended.
Value *Value::toStr() { return doUnaryOp(Module::REPR_MAGIC_NAME); }

Value *Value::len() { return doUnaryOp(Module::LEN_MAGIC_NAME); }

Value *Value::iter() { return doUnaryOp(Module::ITER_MAGIC_NAME); }

// Looks up method `name` on this value's type with this value's type as the
// only argument type, then builds a call to it with this value as argument.
// Returns nullptr if the lookup yields no function.
Value *Value::doUnaryOp(const std::string &name) {
  auto *module = getModule();
  auto *fn = module->getOrRealizeMethod(getType(), name, std::vector{getType()});

  if (!fn)
    return nullptr;

  auto *fnVal = module->Nr(fn);
  return (*fnVal)(*this);
}

// Looks up method `name` on this value's type with argument types
// (this type, other's type), then builds a call with (this, other).
// Returns nullptr if the lookup yields no function.
Value *Value::doBinaryOp(const std::string &name, Value &other) {
  auto *module = getModule();
  auto *fn = module->getOrRealizeMethod(getType(), name,
                                        std::vector{getType(), other.getType()});

  if (!fn)
    return nullptr;

  auto *fnVal = module->Nr(fn);
  return (*fnVal)(*this, other);
}

// Builds a node from this value (as the callee) and `args` via the module.
Value *Value::doCall(const std::vector &args) {
  auto *module = getModule();
  return module->Nr(this, args);
}

} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/value.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once

#include "codon/cir/base.h"
#include "codon/cir/types/types.h"
#include "codon/cir/util/packs.h"

namespace codon {
namespace ir {

class Func;

/// CIR value node. The public accessors below forward to the node returned by
/// getActual(); subclasses customize behavior through the private do*() hooks.
class Value : public ReplaceableNodeBase, public IdMixin {
public:
  static const char NodeId;

  /// Constructs a value.
  /// @param name the value's name
  explicit Value(std::string name = "") : ReplaceableNodeBase(std::move(name)) {}

  virtual ~Value() noexcept = default;

  /// @return the string "<name>.<id>"
  std::string referenceString() const final {
    return fmt::format(FMT_STRING("{}.{}"), getName(), getId());
  }

  std::vector getUsedValues() final { return getActual()->doGetUsedValues(); }
  std::vector getUsedValues() const final {
    auto ret = getActual()->doGetUsedValues();
    return std::vector(ret.begin(), ret.end());
  }
  int replaceUsedValue(id_t id, Value *newValue) final {
    return getActual()->doReplaceUsedValue(id, newValue);
  }
  using Node::replaceUsedValue;

  std::vector getUsedTypes() const final { return getActual()->doGetUsedTypes(); }
  int replaceUsedType(const std::string &name, types::Type *newType) final {
    return getActual()->doReplaceUsedType(name, newType);
  }
  using Node::replaceUsedType;

  std::vector getUsedVariables() final { return getActual()->doGetUsedVariables(); }
  std::vector getUsedVariables() const final {
    auto ret = getActual()->doGetUsedVariables();
    return std::vector(ret.begin(), ret.end());
  }
  int replaceUsedVariable(id_t id, Var *newVar) final {
    return getActual()->doReplaceUsedVariable(id, newVar);
  }
  using Node::replaceUsedVariable;

  /// @return the value's type
  types::Type *getType() const { return getActual()->doGetType(); }

  /// @return the id of the node returned by getActual()
  id_t getId() const override { return getActual()->id; }

  // Node-building operators: each returns a pointer to a new value rather
  // than a computed result. Per value.cpp, the magic-method based ones
  // return nullptr when the corresponding method cannot be found.
  Value *operator==(Value &other);
  Value *operator!=(Value &other);
  Value *operator<(Value &other);
  Value *operator>(Value &other);
  Value *operator<=(Value &other);
  Value *operator>=(Value &other);

  Value *operator+();
  Value *operator-();
  Value *operator~();

  Value *operator+(Value &other);
  Value *operator-(Value &other);
  Value *operator*(Value &other);
  Value *matMul(Value &other);
  Value *trueDiv(Value &other);
  /// Uses the floor-division magic method (see value.cpp); use trueDiv()
  /// for true division.
  Value *operator/(Value &other);
  Value *operator%(Value &other);
  Value *pow(Value &other);
  Value *operator<<(Value &other);
  Value *operator>>(Value &other);
  Value *operator&(Value &other);
  Value *operator|(Value &other);
  Value *operator^(Value &other);

  Value *operator||(Value &other);
  Value *operator&&(Value &other);

  /// Builds a call of this value: the addresses of the arguments are
  /// collected with util::stripPack() and passed to doCall().
  // NOTE(review): the template parameter list and the template arguments of
  // std::vector / std::forward appear to be missing from this listing.
  template
  Value *operator()(Args &&...args) {
    std::vector dst;
    util::stripPack(dst, std::forward(args)...);
    return doCall(dst);
  }
  Value *operator[](Value &other);

  Value *toInt();
  Value *toFloat();
  Value *toBool();
  Value *toStr();

  Value *len();
  Value *iter();

private:
  Value *doUnaryOp(const std::string &name);
  Value *doBinaryOp(const std::string &name, Value &other);

  Value *doCall(const std::vector &args);

  // Customization hooks. Apart from doGetType(), which subclasses must
  // implement, the defaults report nothing used and perform no replacement.
  virtual types::Type *doGetType() const = 0;

  virtual std::vector doGetUsedValues() const { return {}; }
  virtual int doReplaceUsedValue(id_t id, Value *newValue) { return 0; }

  virtual std::vector doGetUsedTypes() const { return {}; }
  virtual int doReplaceUsedType(const std::string &name, types::Type *newType) {
    return 0;
  }

  virtual std::vector doGetUsedVariables() const { return {}; }
  virtual int doReplaceUsedVariable(id_t id, Var *newVar) { return 0; }
};

} // namespace ir
} // namespace codon

// fmt formatter specialization that uses fmt::ostream_formatter.
// NOTE(review): parts of the template header/condition appear to be missing
// from this listing - confirm against the repository.
template
struct fmt::formatter<
    T, std::enable_if_t::value, char>>
    : fmt::ostream_formatter {};

================================================
FILE: codon/cir/var.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include "var.h"

#include "codon/cir/module.h"

namespace codon {
namespace ir {

const char Var::NodeId = 0;

Var::Var(types::Type *type, bool global, bool external, bool tls, std::string name)
    : ReplaceableNodeBase(std::move(name)), type(type), global(global),
      external(external), tls(tls) {}

// Replaces the variable's type with `newType` when the current type's name
// equals `name`. Returns the number of replacements made (1 or 0).
int Var::doReplaceUsedType(const std::string &name, types::Type *newType) {
  if (type->getName() == name) {
    type = newType;
    return 1;
  }
  return 0;
}

const char VarValue::NodeId = 0;

// Re-points this value at `newVar` when the referenced variable has id `id`.
// Returns the number of replacements made (1 or 0).
int VarValue::doReplaceUsedVariable(id_t id, Var *newVar) {
  if (val->getId() == id) {
    val = newVar;
    return 1;
  }
  return 0;
}

const char PointerValue::NodeId = 0;

// Returns the module's pointer type for the referenced variable's type.
// NOTE(review): `fields` is not consulted here, so the result is the same
// whether or not a field path is set - confirm this is intended.
types::Type *PointerValue::doGetType() const {
  return getModule()->getPointerType(val->getType());
}

// Re-points this value at `newVar` when the referenced variable has id `id`.
// Returns the number of replacements made (1 or 0).
int PointerValue::doReplaceUsedVariable(id_t id, Var *newVar) {
  if (val->getId() == id) {
    val = newVar;
    return 1;
  }
  return 0;
}

} // namespace ir
} // namespace codon

================================================
FILE: codon/cir/var.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

// NOTE(review): the operands of the bare #include directives below appear to
// be missing from this listing - confirm against the repository.
#include
#include
#include
#include

#include "codon/cir/types/types.h"
#include "codon/cir/value.h"
#include "codon/util/common.h"
#include
#include

namespace codon {
namespace ir {

class Func;
class Var;

/// CIR object representing a variable.
class Var : public ReplaceableNodeBase, public IdMixin {
private:
  /// the variable's type
  types::Type *type;
  /// true if the variable is global
  bool global;
  /// true if the variable is external
  bool external;
  /// true if the variable is thread-local
  bool tls;

public:
  static const char NodeId;

  /// Constructs a variable.
  /// @param type the variable's type
  /// @param global true if the variable is global
  /// @param external true if the variable is external
  /// @param tls true if the variable is thread-local
  /// @param name the variable's name
  explicit Var(types::Type *type, bool global = false, bool external = false,
               bool tls = false, std::string name = "");

  virtual ~Var() noexcept = default;

  std::vector getUsedValues() final { return getActual()->doGetUsedValues(); }
  std::vector getUsedValues() const final {
    auto ret = getActual()->doGetUsedValues();
    return std::vector(ret.begin(), ret.end());
  }
  // NOTE(review): unlike the neighbouring accessors, this one calls the hook
  // on `this` rather than on getActual() - confirm whether that is intended.
  int replaceUsedValue(id_t id, Value *newValue) final {
    return doReplaceUsedValue(id, newValue);
  }
  using Node::replaceUsedValue;

  std::vector getUsedTypes() const final { return getActual()->doGetUsedTypes(); }
  int replaceUsedType(const std::string &name, types::Type *newType) final {
    return getActual()->doReplaceUsedType(name, newType);
  }
  using Node::replaceUsedType;

  // NOTE(review): both getUsedVariables() overloads call the hook on `this`
  // rather than on getActual(), while replaceUsedVariable() below goes
  // through getActual() - confirm whether the difference is intended.
  std::vector getUsedVariables() final { return doGetUsedVariables(); }
  std::vector getUsedVariables() const final {
    auto ret = doGetUsedVariables();
    return std::vector(ret.begin(), ret.end());
  }
  int replaceUsedVariable(id_t id, Var *newVar) final {
    return getActual()->doReplaceUsedVariable(id, newVar);
  }
  using Node::replaceUsedVariable;

  /// @return the type
  types::Type *getType() const { return getActual()->type; }
  /// Sets the type.
  /// @param t the new type
  void setType(types::Type *t) { getActual()->type = t; }

  /// @return true if the variable is global
  bool isGlobal() const { return getActual()->global; }
  /// Sets the global flag.
  /// @param v the new value
  void setGlobal(bool v = true) { getActual()->global = v; }

  /// @return true if the variable is external
  bool isExternal() const { return getActual()->external; }
  /// Sets the external flag.
  /// @param v the new value
  void setExternal(bool v = true) { getActual()->external = v; }

  /// @return true if the variable is thread-local
  bool isThreadLocal() const { return getActual()->tls; }
  /// Sets the thread-local flag.
  /// @param v the new value
  void setThreadLocal(bool v = true) { getActual()->tls = v; }

  /// @return the string "<name>.<id>"
  std::string referenceString() const final {
    return fmt::format(FMT_STRING("{}.{}"), getName(), getId());
  }

  /// @return the id of the node returned by getActual()
  id_t getId() const override { return getActual()->id; }

protected:
  // Customization hooks. By default a variable reports no used values or
  // variables, and reports its own type as the only used type.
  virtual std::vector doGetUsedValues() const { return {}; }
  virtual int doReplaceUsedValue(id_t id, Value *newValue) { return 0; }

  virtual std::vector doGetUsedTypes() const { return {type}; }
  virtual int doReplaceUsedType(const std::string &name, types::Type *newType);

  virtual std::vector doGetUsedVariables() const { return {}; }
  virtual int doReplaceUsedVariable(id_t id, Var *newVar) { return 0; }
};

/// Value that contains an unowned variable reference.
class VarValue : public AcceptorExtend {
private:
  /// the referenced var
  Var *val;

public:
  static const char NodeId;

  /// Constructs a variable value.
  /// @param val the referenced value
  /// @param name the name
  explicit VarValue(Var *val, std::string name = "")
      : AcceptorExtend(std::move(name)), val(val) {}

  /// @return the variable
  Var *getVar() { return val; }
  /// @return the variable
  const Var *getVar() const { return val; }
  /// Sets the variable.
  /// @param v the new variable
  void setVar(Var *v) { val = v; }

private:
  /// The value's type is the referenced variable's type.
  types::Type *doGetType() const override { return val->getType(); }

  std::vector doGetUsedVariables() const override { return {val}; }
  int doReplaceUsedVariable(id_t id, Var *newVar) override;
};

/// Value that represents a pointer.
class PointerValue : public AcceptorExtend {
private:
  /// the referenced var
  Var *val;
  /// sequence of fields indicating where pointer should point
  std::vector fields;

public:
  static const char NodeId;

  /// Constructs a pointer value.
  /// @param val the referenced value
  /// @param fields the sequence of fields, or empty to get var pointer
  /// @param name the name
  explicit PointerValue(Var *val, std::vector fields, std::string name = "")
      : AcceptorExtend(std::move(name)), val(val), fields(std::move(fields)) {}

  /// Constructs a pointer value.
  /// @param val the referenced value
  /// @param name the name
  explicit PointerValue(Var *val, std::string name = "")
      : PointerValue(val, {}, name) {}

  /// @return the variable
  Var *getVar() { return val; }
  /// @return the variable
  const Var *getVar() const { return val; }
  /// Sets the variable.
  /// @param v the new variable
  void setVar(Var *v) { val = v; }

  /// @return the sequence of fields
  const std::vector &getFields() const { return fields; }
  /// Sets the sequence of fields
  /// @param f the new fields
  void setFields(std::vector f) { fields = std::move(f); }

private:
  types::Type *doGetType() const override;

  std::vector doGetUsedVariables() const override { return {val}; }
  int doReplaceUsedVariable(id_t id, Var *newVar) override;
};

} // namespace ir
} // namespace codon

// fmt formatter specialization that uses fmt::ostream_formatter.
// NOTE(review): parts of the template header/condition appear to be missing
// from this listing - confirm against the repository.
template
struct fmt::formatter::value, char>> : fmt::ostream_formatter {};

================================================
FILE: codon/compiler/compiler.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "compiler.h"

#include "codon/compiler/error.h"
#include "codon/parser/cache.h"
#include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/doc/doc.h"
#include "codon/parser/visitors/format/format.h"
#include "codon/parser/visitors/translate/translate.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

extern double totalPeg;

namespace codon {
namespace {
// Maps a compiler mode to the pass manager's initial configuration.
// DEBUG maps to RELEASE when `isTest` is set and to DEBUG otherwise;
// any mode not listed maps to EMPTY.
ir::transform::PassManager::Init getPassManagerInit(Compiler::Mode mode, bool isTest) {
  using ir::transform::PassManager;
  switch (mode) {
  case Compiler::Mode::DEBUG:
    return isTest ? PassManager::Init::RELEASE : PassManager::Init::DEBUG;
  case Compiler::Mode::RELEASE:
    return PassManager::Init::RELEASE;
  case Compiler::Mode::JIT:
    return PassManager::Init::JIT;
  default:
    return PassManager::Init::EMPTY;
  }
}
} // namespace

// Constructs the compiler's components (plugin manager, cache, module, pass
// manager, LLVM visitor) and then wires them together: the cache is given the
// module, the Python flags and a pointer back to this compiler; the module is
// given the cache; the LLVM visitor is given the debug flag and plugin manager.
Compiler::Compiler(const std::string &argv0, Compiler::Mode mode,
                   const std::vector &disabledPasses, bool isTest, bool pyNumerics,
                   bool pyExtension, const std::shared_ptr &fs)
    : argv0(argv0), debug(mode == Mode::DEBUG), pyNumerics(pyNumerics),
      pyExtension(pyExtension), input(), plm(std::make_unique(argv0)),
      cache(std::make_unique(argv0, fs)), module(std::make_unique()),
      pm(std::make_unique(getPassManagerInit(mode, isTest), disabledPasses,
                          pyNumerics, pyExtension)),
      llvisitor(std::make_unique()) {
  cache->module = module.get();
  cache->pythonExt = pyExtension;
  cache->pythonCompat = pyNumerics;
  cache->compiler = this;
  module->setCache(cache.get());
  llvisitor->setDebug(debug);
  llvisitor->setPluginManager(plm.get());
}

// Loads `plugin` through the plugin manager and returns its error on failure.
// On success: adds the plugin's stdlib path (if non-empty) to the cache's
// search paths, registers the plugin's expression and block keywords in the
// cache, adds its IR passes to the pass manager, and records `plugin` in
// loadedPlugins.
llvm::Error Compiler::load(const std::string &plugin) {
  auto result = plm->load(plugin);
  if (auto err = result.takeError())
    return err;

  auto *p = *result;
  if (!p->info.stdlibPath.empty()) {
    cache->fs->add_search_path(p->info.stdlibPath);
  }

  for (auto &kw : p->dsl->getExprKeywords()) {
    cache->customExprStmts[kw.keyword] = kw.callback;
  }
  for (auto &kw : p->dsl->getBlockKeywords()) {
    cache->customBlockStmts[kw.keyword] = {kw.hasExpr, kw.callback};
  }
  p->dsl->addIRPasses(pm.get(), debug);
  loadedPlugins.insert(plugin);
  return llvm::Error::success();
}

/// Checks if a plugin is already loaded.
/// The lookup uses `path` exactly as given; load() records the string it
/// was called with.
bool Compiler::isPluginLoaded(const std::string &path) const {
  return loadedPlugins.find(path) != loadedPlugins.end();
}

llvm::Error Compiler::parse(bool isCode, const std::string &file,
                            const std::string &code, int startLine, int testFlags,
                            const std::unordered_map &defines) {
  input = file;
  std::string abspath = (file != "-") ? std::string(cache->fs->canonical(file)) : file;
  try {
    auto nodeOrErr = isCode ?
ast::parseCode(cache.get(), abspath, code, startLine) : ast::parseFile(cache.get(), abspath); if (!nodeOrErr) throw exc::ParserException(nodeOrErr.takeError()); auto codeStmt = *nodeOrErr; cache->fs->set_module0(file); Timer t2("typecheck"); t2.logged = true; auto typechecked = ast::TypecheckVisitor::apply( cache.get(), codeStmt, abspath, defines, getEarlyDefines(), (testFlags > 1)); LOG_TIME("[T] parse = {:.1f}", totalPeg); LOG_TIME("[T] typecheck = {:.1f}", t2.elapsed() - totalPeg); if (codon::getLogger().flags & codon::Logger::FLAG_USER) { auto fo = fopen("_dump_typecheck.sexp", "w"); fmt::print(fo, "{}\n", typechecked->toString(0)); for (auto &f : cache->functions) for (auto &r : f.second.realizations) { fmt::print(fo, "{}\n", r.second->ast->toString(0)); } fclose(fo); fo = fopen("_dump_typecheck.htm", "w"); auto s = ast::FormatVisitor::apply(typechecked, cache.get(), true); fmt::print(fo, "{}\n", s); fclose(fo); } Timer t4("translate"); ast::TranslateVisitor::apply(cache.get(), std::move(typechecked)); t4.log(); } catch (const exc::ParserException &exc) { return llvm::make_error(exc.getErrors()); } module->setSrcInfo({abspath, 0, 0, 0}); if (codon::getLogger().flags & codon::Logger::FLAG_USER) { auto fo = fopen("_dump_ir.sexp", "w"); fmt::print(fo, "{}\n", *module); fclose(fo); } return llvm::Error::success(); } llvm::Error Compiler::parseFile(const std::string &file, int testFlags, const std::unordered_map &defines) { return parse(/*isCode=*/false, file, /*code=*/"", /*startLine=*/0, testFlags, defines); } llvm::Error Compiler::parseCode(const std::string &file, const std::string &code, int startLine, int testFlags, const std::unordered_map &defines) { return parse(/*isCode=*/true, file, code, startLine, testFlags, defines); } llvm::Error Compiler::compile() { pm->run(module.get()); if (codon::getLogger().flags & codon::Logger::FLAG_USER) { auto fo = fopen("_dump_ir_opt.sexp", "w"); fmt::print(fo, "{}\n", *module); fclose(fo); } 
llvisitor->visit(module.get()); if (codon::getLogger().flags & codon::Logger::FLAG_USER) { auto fo = fopen("_dump_llvm.ll", "w"); std::string str; llvm::raw_string_ostream os(str); os << *(llvisitor->getModule()); os.flush(); fmt::print(fo, "{}\n", str); fclose(fo); } return llvm::Error::success(); } llvm::Expected Compiler::docgen(const std::vector &files) { try { auto j = ast::DocVisitor::apply(argv0, files); return j->toString(); } catch (exc::ParserException &exc) { return llvm::make_error(exc.getErrors()); } } std::unordered_map Compiler::getEarlyDefines() { std::unordered_map earlyDefines; earlyDefines.emplace("__debug__", debug ? "1" : "0"); earlyDefines.emplace("__py_numerics__", pyNumerics ? "1" : "0"); earlyDefines.emplace("__py_extension__", pyExtension ? "1" : "0"); earlyDefines.emplace("__apple__", #if __APPLE__ "1" #else "0" #endif ); return earlyDefines; } } // namespace codon ================================================ FILE: codon/compiler/compiler.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once #include #include #include #include #include "codon/cir/llvm/llvisitor.h" #include "codon/cir/module.h" #include "codon/cir/transform/manager.h" #include "codon/compiler/error.h" #include "codon/dsl/plugins.h" #include "codon/parser/cache.h" namespace codon { class Compiler { public: enum Mode { DEBUG, RELEASE, JIT, }; private: std::string argv0; bool debug; bool pyNumerics; bool pyExtension; std::string input; std::unique_ptr plm; std::unique_ptr cache; std::unique_ptr module; std::unique_ptr pm; std::unique_ptr llvisitor; std::unordered_set loadedPlugins; llvm::Error parse(bool isCode, const std::string &file, const std::string &code, int startLine, int testFlags, const std::unordered_map &defines); public: Compiler(const std::string &argv0, Mode mode, const std::vector &disabledPasses = {}, bool isTest = false, bool pyNumerics = false, bool pyExtension = false, const std::shared_ptr &fs = nullptr); explicit Compiler(const std::string &argv0, bool debug = false, const std::vector &disabledPasses = {}, bool isTest = false, bool pyNumerics = false, bool pyExtension = false, const std::shared_ptr &fs = nullptr) : Compiler(argv0, debug ? 
Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest, pyNumerics, pyExtension, fs) {} std::string getArgv0() const { return argv0; } std::string getInput() const { return input; } PluginManager *getPluginManager() const { return plm.get(); } ast::Cache *getCache() const { return cache.get(); } ir::Module *getModule() const { return module.get(); } ir::transform::PassManager *getPassManager() const { return pm.get(); } ir::LLVMVisitor *getLLVMVisitor() const { return llvisitor.get(); } bool isPluginLoaded(const std::string &) const; llvm::Error load(const std::string &plugin); llvm::Error parseFile(const std::string &file, int testFlags = 0, const std::unordered_map &defines = {}); llvm::Error parseCode(const std::string &file, const std::string &code, int startLine = 0, int testFlags = 0, const std::unordered_map &defines = {}); llvm::Error compile(); llvm::Expected docgen(const std::vector &files); std::unordered_map getEarlyDefines(); }; } // namespace codon ================================================ FILE: codon/compiler/debug_listener.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "debug_listener.h" #include #include #include #include "codon/runtime/lib.h" namespace codon { namespace { std::string makeBacktrace(const std::vector &backtrace, std::function(uintptr_t)> backtraceCallback) { std::ostringstream buf; buf << "\033[1mBacktrace:\033[0m\n"; for (auto pc : backtrace) { auto line = backtraceCallback(pc); if (!line) break; if (!line->empty()) buf << " " << *line << "\n"; } return buf.str(); } } // namespace void DebugListener::notifyObjectLoaded(ObjectKey key, const llvm::object::ObjectFile &obj, const llvm::RuntimeDyld::LoadedObjectInfo &L) { uintptr_t start = 0, stop = 0; for (const auto &sec : obj.sections()) { if (sec.isText()) { start = L.getSectionLoadAddress(sec); stop = start + sec.getSize(); break; } } auto buf = llvm::MemoryBuffer::getMemBufferCopy(obj.getData(), obj.getFileName()); auto newObj = llvm::cantFail( llvm::object::ObjectFile::createObjectFile(buf->getMemBufferRef())); objects.emplace_back(key, std::move(newObj), std::move(buf), start, stop); } void DebugListener::notifyFreeingObject(ObjectKey key) { objects.erase( std::remove_if(objects.begin(), objects.end(), [key](const ObjectInfo &o) { return key == o.getKey(); }), objects.end()); } llvm::Expected DebugListener::symbolize(uintptr_t pc) { for (const auto &o : objects) { if (o.contains(pc)) { llvm::symbolize::LLVMSymbolizer sym; return sym.symbolizeCode( o.getObject(), {pc - o.getStart(), llvm::object::SectionedAddress::UndefSection}); } } return llvm::DILineInfo(); } llvm::Expected DebugListener::getPrettyBacktrace(uintptr_t pc) { auto invalid = [](const std::string &name) { return name == ""; }; auto src = symbolize(pc); if (auto err = src.takeError()) return std::move(err); if (invalid(src->FunctionName) || invalid(src->FileName)) return ""; return runtime::makeBacktraceFrameString(pc, src->FunctionName, src->FileName, src->Line, src->Column); } std::string DebugListener::getPrettyBacktrace(const std::vector &backtrace) { return makeBacktrace(backtrace, 
[&](uintptr_t pc) { return getPrettyBacktrace(pc); }); } void DebugPlugin::notifyMaterializing(llvm::orc::MaterializationResponsibility &mr, llvm::jitlink::LinkGraph &graph, llvm::jitlink::JITLinkContext &ctx, llvm::MemoryBufferRef inputObject) { auto newBuf = llvm::MemoryBuffer::getMemBufferCopy(inputObject.getBuffer(), graph.getName()); auto newObj = llvm::cantFail( llvm::object::ObjectFile::createObjectFile(newBuf->getMemBufferRef())); { std::lock_guard lock(pluginMutex); assert(pendingObjs.count(&mr) == 0); pendingObjs[&mr] = std::unique_ptr( new JITObjectInfo{std::move(newBuf), std::move(newObj), {}}); } } llvm::Error DebugPlugin::notifyEmitted(llvm::orc::MaterializationResponsibility &mr) { { std::lock_guard lock(pluginMutex); auto it = pendingObjs.find(&mr); if (it == pendingObjs.end()) return llvm::Error::success(); auto newInfo = pendingObjs[&mr].get(); auto getLoadAddress = [newInfo](const llvm::StringRef &name) -> uint64_t { auto result = newInfo->sectionLoadAddresses.find(name); if (result == newInfo->sectionLoadAddresses.end()) return 0; return result->second; }; // register(*newInfo->Object, getLoadAddress, nullptr) } llvm::cantFail(mr.withResourceKeyDo([&](llvm::orc::ResourceKey key) { std::lock_guard lock(pluginMutex); registeredObjs[key].push_back(std::move(pendingObjs[&mr])); pendingObjs.erase(&mr); })); return llvm::Error::success(); } llvm::Error DebugPlugin::notifyFailed(llvm::orc::MaterializationResponsibility &mr) { std::lock_guard lock(pluginMutex); pendingObjs.erase(&mr); return llvm::Error::success(); } llvm::Error DebugPlugin::notifyRemovingResources(llvm::orc::JITDylib &jd, llvm::orc::ResourceKey key) { std::lock_guard lock(pluginMutex); registeredObjs.erase(key); return llvm::Error::success(); } void DebugPlugin::notifyTransferringResources(llvm::orc::JITDylib &jd, llvm::orc::ResourceKey dstKey, llvm::orc::ResourceKey srcKey) { std::lock_guard lock(pluginMutex); auto it = registeredObjs.find(srcKey); if (it != registeredObjs.end()) { 
for (std::unique_ptr &info : it->second) registeredObjs[dstKey].push_back(std::move(info)); registeredObjs.erase(it); } } void DebugPlugin::modifyPassConfig(llvm::orc::MaterializationResponsibility &mr, llvm::jitlink::LinkGraph &graph, llvm::jitlink::PassConfiguration &config) { std::lock_guard lock(pluginMutex); auto it = pendingObjs.find(&mr); if (it == pendingObjs.end()) return; JITObjectInfo &info = *it->second; config.PostAllocationPasses.push_back( [&info, this](llvm::jitlink::LinkGraph &graph) -> llvm::Error { std::lock_guard lock(pluginMutex); for (const llvm::jitlink::Section &sec : graph.sections()) { #if defined(__APPLE__) && defined(__MACH__) size_t secPos = sec.getName().find(','); if (secPos >= 16 || (sec.getName().size() - (secPos + 1) > 16)) continue; auto secName = sec.getName().substr(secPos + 1); #else auto secName = sec.getName(); #endif info.sectionLoadAddresses[secName] = llvm::jitlink::SectionRange(sec).getStart().getValue(); } return llvm::Error::success(); }); } llvm::Expected DebugPlugin::symbolize(uintptr_t pc) { for (const auto &entry : registeredObjs) { for (const auto &info : entry.second) { const auto *o = info->object.get(); for (const auto &sec : o->sections()) { if (sec.isText()) { uintptr_t start = info->sectionLoadAddresses.lookup(llvm::cantFail(sec.getName())); uintptr_t stop = start + sec.getSize(); if (start <= pc && pc < stop) { llvm::symbolize::LLVMSymbolizer sym; return sym.symbolizeCode( *o, {pc - start, llvm::object::SectionedAddress::UndefSection}); } } } } } return llvm::DILineInfo(); } llvm::Expected DebugPlugin::getPrettyBacktrace(uintptr_t pc) { auto invalid = [](const std::string &name) { return name == ""; }; auto src = symbolize(pc); if (auto err = src.takeError()) return std::move(err); if (invalid(src->FunctionName) || invalid(src->FileName)) return ""; return runtime::makeBacktraceFrameString(pc, src->FunctionName, src->FileName, src->Line, src->Column); } std::string DebugPlugin::getPrettyBacktrace(const 
std::vector &backtrace) { return makeBacktrace(backtrace, [&](uintptr_t pc) { return getPrettyBacktrace(pc); }); } } // namespace codon ================================================ FILE: codon/compiler/debug_listener.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include #include "codon/cir/llvm/llvm.h" namespace codon { /// Debug info tracker for MCJIT. class DebugListener : public llvm::JITEventListener { public: class ObjectInfo { private: ObjectKey key; std::unique_ptr object; std::unique_ptr buffer; uintptr_t start; uintptr_t stop; public: ObjectInfo(ObjectKey key, std::unique_ptr object, std::unique_ptr buffer, uintptr_t start, uintptr_t stop) : key(key), object(std::move(object)), buffer(std::move(buffer)), start(start), stop(stop) {} ObjectKey getKey() const { return key; } const llvm::object::ObjectFile &getObject() const { return *object; } uintptr_t getStart() const { return start; } uintptr_t getStop() const { return stop; } bool contains(uintptr_t pc) const { return start <= pc && pc < stop; } }; private: std::vector objects; void notifyObjectLoaded(ObjectKey key, const llvm::object::ObjectFile &obj, const llvm::RuntimeDyld::LoadedObjectInfo &L) override; void notifyFreeingObject(ObjectKey key) override; public: DebugListener() : llvm::JITEventListener(), objects() {} llvm::Expected symbolize(uintptr_t pc); llvm::Expected getPrettyBacktrace(uintptr_t pc); std::string getPrettyBacktrace(const std::vector &backtrace); }; /// Debug info tracker for JITLink. 
Adapted from Julia's implementation: /// https://github.com/JuliaLang/julia/blob/master/src/jitlayers.cpp class DebugPlugin : public llvm::orc::ObjectLinkingLayer::Plugin { struct JITObjectInfo { std::unique_ptr backingBuffer; std::unique_ptr object; llvm::StringMap sectionLoadAddresses; }; std::mutex pluginMutex; std::map> pendingObjs; std::map>> registeredObjs; public: void notifyMaterializing(llvm::orc::MaterializationResponsibility &mr, llvm::jitlink::LinkGraph &graph, llvm::jitlink::JITLinkContext &ctx, llvm::MemoryBufferRef inputObject) override; llvm::Error notifyEmitted(llvm::orc::MaterializationResponsibility &mr) override; llvm::Error notifyFailed(llvm::orc::MaterializationResponsibility &mr) override; llvm::Error notifyRemovingResources(llvm::orc::JITDylib &jd, llvm::orc::ResourceKey key) override; void notifyTransferringResources(llvm::orc::JITDylib &jd, llvm::orc::ResourceKey dstKey, llvm::orc::ResourceKey srcKey) override; void modifyPassConfig(llvm::orc::MaterializationResponsibility &mr, llvm::jitlink::LinkGraph &, llvm::jitlink::PassConfiguration &config) override; llvm::Expected symbolize(uintptr_t pc); llvm::Expected getPrettyBacktrace(uintptr_t pc); std::string getPrettyBacktrace(const std::vector &backtrace); }; } // namespace codon ================================================ FILE: codon/compiler/engine.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "engine.h"

#include "codon/cir/llvm/optimize.h"
#include "codon/compiler/memory_manager.h"

namespace codon {
namespace jit {

/// Constructs the JIT engine.
///
/// Selects a target from the LLVM codegen command-line settings (march, CPU,
/// feature list), builds an LLJIT whose object linking layer uses
/// BoehmGCJITLinkMemoryManager, attaches the GDB registration plugin when it
/// is available plus a DebugPlugin (exposed through getDebugListener()), makes
/// symbols of the current process resolvable, and installs an IR transform
/// that runs ir::optimize() on every module before it is compiled.
Engine::Engine() : jit(), debug(nullptr) {
  auto eb = llvm::EngineBuilder();
  eb.setMArch(llvm::codegen::getMArch());
  eb.setMCPU(llvm::codegen::getCPUStr());
  eb.setMAttrs(llvm::codegen::getFeatureList());

  // selectTarget() hands back a heap-allocated TargetMachine that the caller
  // owns. Only the data layout and the triple are needed (both are copied
  // below), so hold it in a unique_ptr and release it when the constructor
  // finishes instead of leaking it.
  std::unique_ptr<llvm::TargetMachine> target(eb.selectTarget());
  if (!target)
    llvm::report_fatal_error("JIT engine: unable to select a target machine");
  auto layout = target->createDataLayout();
  // NOTE(review): epc is not referenced again in this function -- confirm
  // whether it was meant to be handed to the builder.
  auto epc = llvm::cantFail(llvm::orc::SelfExecutorProcessControl::Create(
      std::make_shared<llvm::orc::SymbolStringPool>()));

  llvm::orc::LLJITBuilder builder;
  builder.setDataLayout(layout);
  builder.setObjectLinkingLayerCreator(
      [this](llvm::orc::ExecutionSession &es, const llvm::Triple &triple)
          -> llvm::Expected<std::unique_ptr<llvm::orc::ObjectLayer>> {
        auto L = std::make_unique<llvm::orc::ObjectLinkingLayer>(
            es, llvm::cantFail(BoehmGCJITLinkMemoryManager::Create()));
        if (auto regOrErr = llvm::orc::createJITLoaderGDBRegistrar(es)) {
          L->addPlugin(std::make_unique<llvm::orc::DebugObjectManagerPlugin>(
              es, std::move(*regOrErr)));
        } else {
          // GDB registration is best-effort, but the failure value still has
          // to be consumed: an llvm::Expected destroyed while holding an
          // unhandled error aborts in builds with ABI-breaking checks enabled.
          llvm::consumeError(regOrErr.takeError());
        }
        // The layer takes ownership of the plugin; keep a non-owning pointer
        // so that getDebugListener() can reach it later.
        auto dbPlugin = std::make_unique<DebugPlugin>();
        this->debug = dbPlugin.get();
        L->addPlugin(std::move(dbPlugin));
        L->setAutoClaimResponsibilityForObjectSymbols(true);
        return L;
      });
  builder.setJITTargetMachineBuilder(
      llvm::orc::JITTargetMachineBuilder(target->getTargetTriple()));

  jit = llvm::cantFail(builder.create());
  // Let JIT'd code resolve symbols exported by the running process.
  jit->getMainJITDylib().addGenerator(
      llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
          layout.getGlobalPrefix())));

  // This callback is stored in the transform layer and outlives the
  // constructor, so it deliberately captures nothing.
  jit->getIRTransformLayer().setTransform(
      [](llvm::orc::ThreadSafeModule module,
         const llvm::orc::MaterializationResponsibility &R) {
        module.withModuleDo([](llvm::Module &module) {
          ir::optimize(&module, /*debug=*/false, /*jit=*/true);
        });
        return std::move(module);
      });
}

/// Adds an IR module to the JIT.
/// @param module the module to add
/// @param rt resource tracker that will own the module's resources; when
///           null, the main JITDylib's default tracker is used
/// @return the error reported by LLJIT::addIRModule, or success
llvm::Error Engine::addModule(llvm::orc::ThreadSafeModule module,
                              llvm::orc::ResourceTrackerSP rt) {
  if (!rt)
    rt = jit->getMainJITDylib().getDefaultResourceTracker();
  return jit->addIRModule(rt, std::move(module));
}

/// Looks up a symbol in the JIT by name.
/// @param name the symbol name
/// @return the result of LLJIT::lookup for that name, or an error
llvm::Expected<llvm::orc::ExecutorAddr> Engine::lookup(llvm::StringRef name) {
  return jit->lookup(name);
}

} // namespace jit
} // namespace codon
================================================ FILE: codon/compiler/engine.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include "codon/cir/llvm/llvm.h" #include "codon/compiler/debug_listener.h" namespace codon { namespace jit { class Engine { private: std::unique_ptr jit; DebugPlugin *debug; public: Engine(); const llvm::DataLayout &getDataLayout() const { return jit->getDataLayout(); } llvm::orc::JITDylib &getMainJITDylib() { return jit->getMainJITDylib(); } DebugPlugin *getDebugListener() const { return debug; } llvm::Error addModule(llvm::orc::ThreadSafeModule module, llvm::orc::ResourceTrackerSP rt = nullptr); llvm::Expected lookup(llvm::StringRef name); }; } // namespace jit } // namespace codon ================================================ FILE: codon/compiler/error.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "error.h" namespace codon { SrcInfo::SrcInfo(std::string file, int line, int col, int len) : file(std::move(file)), line(line), col(col), len(len), id(0) { if (this->file.empty() && line != 0) line++; static int nextId = 0; id = nextId++; }; SrcInfo::SrcInfo() : SrcInfo("", 0, 0, 0) {} bool SrcInfo::operator==(const SrcInfo &src) const { return id == src.id; } bool SrcInfo::operator<(const SrcInfo &src) const { return std::tie(file, line, col) < std::tie(src.file, src.line, src.col); } bool SrcInfo::operator<=(const SrcInfo &src) const { return std::tie(file, line, col) <= std::tie(src.file, src.line, src.col); } std::string ErrorMessage::toString() const { std::string s; if (!getFile().empty()) { s += getFile(); if (getLine() != 0) { s += fmt::format(":{}", getLine()); if (getColumn() != 0) s += fmt::format(":{}", getColumn()); } s += ": "; } s += getMessage(); return s; } namespace error { char ParserErrorInfo::ID = 0; char RuntimeErrorInfo::ID = 0; char PluginErrorInfo::ID = 0; char IOErrorInfo::ID = 0; void 
E(llvm::Error &&error) { throw exc::ParserException(std::move(error)); } } // namespace error namespace exc { ParserException::ParserException(llvm::Error &&e) noexcept : std::runtime_error("") { llvm::handleAllErrors(std::move(e), [this](const error::ParserErrorInfo &e) { errors = e.getErrors(); }); } } // namespace exc } // namespace codon ================================================ FILE: codon/compiler/error.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include "codon/parser/ast/error.h" #include "llvm/Support/Error.h" #include namespace codon { namespace error { enum Error { CALL_NAME_ORDER, CALL_NAME_STAR, CALL_ELLIPSIS, IMPORT_IDENTIFIER, IMPORT_FN, IMPORT_STAR, FN_LLVM, FN_LAST_KWARG, FN_MULTIPLE_ARGS, FN_DEFAULT_STARARG, FN_ARG_TWICE, FN_DEFAULT, FN_C_DEFAULT, FN_C_TYPE, FN_SINGLE_DECORATOR, CLASS_EXTENSION, CLASS_MISSING_TYPE, CLASS_ARG_TWICE, CLASS_BAD_DECORATOR, CLASS_MULTIPLE_DECORATORS, CLASS_SINGLE_DECORATOR, CLASS_CONFLICT_DECORATOR, CLASS_NONSTATIC_DECORATOR, CLASS_BAD_DECORATOR_ARG, ID_NOT_FOUND, ID_CANNOT_CAPTURE, ID_INVALID_BIND, UNION_TOO_BIG, COMPILER_NO_FILE, COMPILER_NO_STDLIB, ID_NONLOCAL, IMPORT_NO_MODULE, IMPORT_NO_NAME, DEL_NOT_ALLOWED, DEL_INVALID, ASSIGN_INVALID, ASSIGN_LOCAL_REFERENCE, ASSIGN_MULTI_STAR, INT_RANGE, FLOAT_RANGE, STR_FSTRING_BALANCE_EXTRA, STR_FSTRING_BALANCE_MISSING, CALL_NO_TYPE, CALL_TUPLE_COMPREHENSION, CALL_NAMEDTUPLE, CALL_PARTIAL, EXPECTED_TOPLEVEL, CLASS_ID_NOT_FOUND, CLASS_INVALID_BIND, CLASS_NO_INHERIT, CLASS_NO_EXTEND, CLASS_TUPLE_INHERIT, CLASS_BAD_MRO, CLASS_BAD_ATTR, MATCH_MULTI_ELLIPSIS, FN_OUTSIDE_ERROR, FN_GLOBAL_ASSIGNED, FN_GLOBAL_NOT_FOUND, FN_BAD_LLVM, FN_REALIZE_BUILTIN, EXPECTED_LOOP, LOOP_DECORATOR, BAD_STATIC_TYPE, EXPECTED_TYPE, UNEXPECTED_TYPE, DOT_NO_ATTR, DOT_NO_ATTR_ARGS, FN_NO_ATTR_ARGS, EXPECTED_STATIC, EXPECTED_STATIC_SPECIFIED, ASSIGN_UNEXPECTED_STATIC, ASSIGN_UNEXPECTED_FROZEN, CALL_BAD_UNPACK, 
CALL_BAD_ITER, CALL_BAD_KWUNPACK, CALL_REPEATED_NAME, CALL_RECURSIVE_DEFAULT, CALL_SUPERF, CALL_SUPER_PARENT, CALL_PTR_VAR, EXPECTED_TUPLE, CALL_REALIZED_FN, CALL_ARGS_MANY, CALL_ARGS_INVALID, CALL_ARGS_MISSING, GENERICS_MISMATCH, EXPECTED_GENERATOR, STATIC_RANGE_BOUNDS, TUPLE_RANGE_BOUNDS, STATIC_DIV_ZERO, SLICE_STEP_ZERO, OP_NO_MAGIC, INST_CALLABLE_STATIC, CATCH_EXCEPTION_TYPE, TYPE_CANNOT_REALIZE_ATTR, TYPE_UNIFY, TYPE_FAILED, MAX_REALIZATION, CUSTOM, __END__ }; class ParserErrorInfo : public llvm::ErrorInfo { ParserErrors errors; public: static char ID; explicit ParserErrorInfo(const ErrorMessage &msg) : errors(msg) {} explicit ParserErrorInfo(const std::vector &msgs) : errors(msgs) {} explicit ParserErrorInfo(const ParserErrors &errors) : errors(errors) {} template ParserErrorInfo(error::Error e, const codon::SrcInfo &o = codon::SrcInfo(), const TA &...args) { auto msg = Emsg(e, args...); errors = ParserErrors(ErrorMessage(msg, o, (int)e)); } const ParserErrors &getErrors() const { return errors; } ParserErrors &getErrors() { return errors; } void log(llvm::raw_ostream &out) const override { for (const auto &trace : errors) { for (const auto &msg : trace.getMessages()) { auto t = msg.toString(); out << t << "\n"; } } } std::error_code convertToErrorCode() const override { return llvm::inconvertibleErrorCode(); } }; class RuntimeErrorInfo : public llvm::ErrorInfo { private: std::string output; std::string type; ErrorMessage message; std::vector backtrace; public: RuntimeErrorInfo(const std::string &output, const std::string &type, const std::string &msg, const std::string &file = "", int line = 0, int col = 0, std::vector backtrace = {}) : output(output), type(type), message(msg, file, line, col), backtrace(std::move(backtrace)) {} std::string getOutput() const { return output; } std::string getType() const { return type; } std::string getMessage() const { return message.getMessage(); } std::string getFile() const { return message.getFile(); } int getLine() 
const { return message.getLine(); } int getColumn() const { return message.getColumn(); } std::vector getBacktrace() const { return backtrace; } void log(llvm::raw_ostream &out) const override { out << type << ": "; message.log(out); } std::error_code convertToErrorCode() const override { return llvm::inconvertibleErrorCode(); } static char ID; }; class PluginErrorInfo : public llvm::ErrorInfo { private: std::string message; public: explicit PluginErrorInfo(const std::string &message) : message(message) {} std::string getMessage() const { return message; } void log(llvm::raw_ostream &out) const override { out << message; } std::error_code convertToErrorCode() const override { return llvm::inconvertibleErrorCode(); } static char ID; }; class IOErrorInfo : public llvm::ErrorInfo { private: std::string message; public: explicit IOErrorInfo(const std::string &message) : message(message) {} std::string getMessage() const { return message; } void log(llvm::raw_ostream &out) const override { out << message; } std::error_code convertToErrorCode() const override { return llvm::inconvertibleErrorCode(); } static char ID; }; template std::string Eformat(const TA &...args) { return ""; } template std::string Eformat(const char *fmt, const TA &...args) { return fmt::format(fmt::runtime(fmt), args...); } template std::string Emsg(Error e, const TA &...args) { switch (e) { /// Validations case Error::CALL_NAME_ORDER: return fmt::format("positional argument follows keyword argument"); case Error::CALL_NAME_STAR: return fmt::format("cannot use starred expression here"); case Error::CALL_ELLIPSIS: return fmt::format("multiple ellipsis expressions"); case Error::IMPORT_IDENTIFIER: return fmt::format("expected identifier"); case Error::IMPORT_FN: return fmt::format( "function signatures only allowed when importing C or Python functions"); case Error::IMPORT_STAR: return fmt::format("import * only allowed at module level"); case Error::FN_LLVM: return fmt::format("return types required 
for LLVM and C functions"); case Error::FN_LAST_KWARG: return fmt::format("kwargs must be the last argument"); case Error::FN_MULTIPLE_ARGS: return fmt::format("multiple star arguments provided"); case Error::FN_DEFAULT_STARARG: return fmt::format("star arguments cannot have default values"); case Error::FN_ARG_TWICE: return fmt::format(fmt::runtime("duplicate argument '{}' in function definition"), args...); case Error::FN_DEFAULT: return fmt::format( fmt::runtime("non-default argument '{}' follows default argument"), args...); case Error::FN_C_DEFAULT: return fmt::format( fmt::runtime( "argument '{}' within C function definition cannot have default value"), args...); case Error::FN_C_TYPE: return fmt::format( fmt::runtime( "argument '{}' within C function definition requires type annotation"), args...); case Error::FN_SINGLE_DECORATOR: return fmt::format( fmt::runtime("cannot combine '@{}' with other attributes or decorators"), args...); case Error::CLASS_EXTENSION: return fmt::format("class extensions cannot define data attributes and generics or " "inherit other classes"); case Error::CLASS_MISSING_TYPE: return fmt::format(fmt::runtime("type required for data attribute '{}'"), args...); case Error::CLASS_ARG_TWICE: return fmt::format( fmt::runtime("duplicate data attribute '{}' in class definition"), args...); case Error::CLASS_BAD_DECORATOR: return fmt::format("unsupported class decorator"); case Error::CLASS_MULTIPLE_DECORATORS: return fmt::format(fmt::runtime("duplicate decorator '@{}' in class definition"), args...); case Error::CLASS_SINGLE_DECORATOR: return fmt::format( fmt::runtime("cannot combine '@{}' with other attributes or decorators"), args...); case Error::CLASS_CONFLICT_DECORATOR: return fmt::format(fmt::runtime("cannot combine '@{}' with '@{}'"), args...); case Error::CLASS_NONSTATIC_DECORATOR: return fmt::format("class decorator arguments must be compile-time static values"); case Error::CLASS_BAD_DECORATOR_ARG: return fmt::format("class 
decorator got unexpected argument"); /// Simplification case Error::ID_NOT_FOUND: return fmt::format(fmt::runtime("name '{}' is not defined"), args...); case Error::ID_CANNOT_CAPTURE: return fmt::format(fmt::runtime("name '{}' cannot be captured"), args...); case Error::ID_NONLOCAL: return fmt::format(fmt::runtime("no binding for nonlocal '{}' found"), args...); case Error::ID_INVALID_BIND: return fmt::format(fmt::runtime("cannot bind '{}' to global or nonlocal name"), args...); case Error::IMPORT_NO_MODULE: return fmt::format(fmt::runtime("no module named '{}'"), args...); case Error::IMPORT_NO_NAME: return fmt::format(fmt::runtime("cannot import name '{}' from '{}'"), args...); case Error::DEL_NOT_ALLOWED: return fmt::format(fmt::runtime("name '{}' cannot be deleted"), args...); case Error::DEL_INVALID: return fmt::format(fmt::runtime("cannot delete given expression"), args...); case Error::ASSIGN_INVALID: return fmt::format("cannot assign to given expression"); case Error::ASSIGN_LOCAL_REFERENCE: return fmt::format( fmt::runtime("local variable '{}' referenced before assignment at {}"), args...); case Error::ASSIGN_MULTI_STAR: return fmt::format("multiple starred expressions in assignment"); case Error::INT_RANGE: return fmt::format(fmt::runtime("integer '{}' cannot fit into 64-bit integer"), args...); case Error::FLOAT_RANGE: return fmt::format(fmt::runtime("float '{}' cannot fit into 64-bit float"), args...); case Error::STR_FSTRING_BALANCE_EXTRA: return fmt::format("expecting '}}' in f-string"); case Error::STR_FSTRING_BALANCE_MISSING: return fmt::format("single '}}' is not allowed in f-string"); case Error::CALL_NO_TYPE: return fmt::format(fmt::runtime("cannot use calls in type signatures"), args...); case Error::CALL_TUPLE_COMPREHENSION: return fmt::format( fmt::runtime( "tuple constructor does not accept nested or conditioned comprehensions"), args...); case Error::CALL_NAMEDTUPLE: return fmt::format(fmt::runtime("namedtuple() takes 2 static arguments"), 
args...); case Error::CALL_PARTIAL: return fmt::format(fmt::runtime("partial() takes 1 or more arguments"), args...); case Error::EXPECTED_TOPLEVEL: return fmt::format(fmt::runtime("{} must be a top-level statement"), args...); case Error::CLASS_ID_NOT_FOUND: // Note that type aliases are not valid class names return fmt::format(fmt::runtime("class name '{}' is not defined"), args...); case Error::CLASS_INVALID_BIND: return fmt::format(fmt::runtime("cannot bind '{}' to class or function"), args...); case Error::CLASS_NO_INHERIT: return fmt::format(fmt::runtime("{} classes cannot inherit {} classes"), args...); case Error::CLASS_NO_EXTEND: return fmt::format(fmt::runtime("'{}' cannot be extended"), args...); case Error::CLASS_TUPLE_INHERIT: return fmt::format("reference classes cannot inherit tuple classes"); case Error::CLASS_BAD_MRO: return fmt::format("inconsistent class hierarchy"); case Error::CLASS_BAD_ATTR: return fmt::format("unexpected expression in class definition"); case Error::MATCH_MULTI_ELLIPSIS: return fmt::format("multiple ellipses in a pattern"); case Error::FN_OUTSIDE_ERROR: return fmt::format(fmt::runtime("'{}' outside function"), args...); case Error::FN_GLOBAL_ASSIGNED: return fmt::format( fmt::runtime("name '{}' is assigned to before global declaration"), args...); case Error::FN_GLOBAL_NOT_FOUND: return fmt::format(fmt::runtime("no binding for {} '{}' found"), args...); case Error::FN_BAD_LLVM: return fmt::format("invalid LLVM code"); case Error::FN_REALIZE_BUILTIN: return fmt::format("builtin, exported and external functions cannot be generic"); case Error::EXPECTED_LOOP: return fmt::format(fmt::runtime("'{}' outside loop"), args...); case Error::LOOP_DECORATOR: return fmt::format("invalid loop decorator"); case Error::BAD_STATIC_TYPE: return fmt::format("expected 'int', 'bool' or 'str'"); case Error::EXPECTED_TYPE: return fmt::format(fmt::runtime("expected {} expression"), args...); case Error::UNEXPECTED_TYPE: return 
fmt::format(fmt::runtime("unexpected {} expression"), args...); /// Typechecking case Error::UNION_TOO_BIG: return fmt::format( fmt::runtime( "union exceeded its maximum capacity (contains more than {} types)"), args...); case Error::DOT_NO_ATTR: return fmt::format(fmt::runtime("'{}' object has no attribute '{}'"), args...); case Error::DOT_NO_ATTR_ARGS: return fmt::format(fmt::runtime("'{}' object has no method '{}' with arguments {}"), args...); case Error::FN_NO_ATTR_ARGS: return fmt::format(fmt::runtime("no function '{}' with arguments {}"), args...); case Error::EXPECTED_STATIC: return fmt::format("expected static expression"); case Error::EXPECTED_STATIC_SPECIFIED: return fmt::format(fmt::runtime("expected static {} expression"), args...); case Error::ASSIGN_UNEXPECTED_STATIC: return fmt::format("cannot modify static expressions"); case Error::ASSIGN_UNEXPECTED_FROZEN: return fmt::format("cannot modify tuple attributes"); case Error::CALL_BAD_UNPACK: return fmt::format(fmt::runtime("argument after * must be a tuple, not '{}'"), args...); case Error::CALL_BAD_ITER: return fmt::format(fmt::runtime("iterable must be a tuple, not '{}'"), args...); case Error::CALL_BAD_KWUNPACK: return fmt::format( fmt::runtime("argument after ** must be a named tuple, not '{}'"), args...); case Error::CALL_REPEATED_NAME: return fmt::format(fmt::runtime("keyword argument repeated: {}"), args...); case Error::CALL_RECURSIVE_DEFAULT: return fmt::format(fmt::runtime("argument '{}' has recursive default value"), args...); case Error::CALL_SUPERF: return fmt::format("no superf methods found"); case Error::CALL_SUPER_PARENT: return fmt::format("no super methods found"); case Error::CALL_PTR_VAR: return fmt::format("__ptr__() only takes identifiers or tuple fields as arguments"); case Error::EXPECTED_TUPLE: return fmt::format("expected tuple type"); case Error::CALL_REALIZED_FN: return fmt::format("static.realized() only takes functions as a first argument"); case Error::CALL_ARGS_MANY: 
return fmt::format(fmt::runtime("{}() takes {} arguments ({} given)"), args...); case Error::CALL_ARGS_INVALID: return fmt::format(fmt::runtime("'{}' is an invalid keyword argument for {}()"), args...); case Error::CALL_ARGS_MISSING: return fmt::format( fmt::runtime("{}() missing 1 required positional argument: '{}'"), args...); case Error::GENERICS_MISMATCH: return fmt::format(fmt::runtime("{} takes {} generics ({} given)"), args...); case Error::EXPECTED_GENERATOR: return fmt::format("expected iterable expression"); case Error::STATIC_RANGE_BOUNDS: return fmt::format( fmt::runtime("static.range too large (expected 0..{}, got instead {})"), args...); case Error::TUPLE_RANGE_BOUNDS: return fmt::format( fmt::runtime("tuple index out of range (expected 0..{}, got instead {})"), args...); case Error::STATIC_DIV_ZERO: return fmt::format("static division by zero"); case Error::SLICE_STEP_ZERO: return fmt::format("slice step cannot be zero"); case Error::OP_NO_MAGIC: return fmt::format( fmt::runtime("unsupported operand type(s) for {}: '{}' and '{}'"), args...); case Error::INST_CALLABLE_STATIC: return fmt::format("CallableTrait cannot take static types"); case Error::CATCH_EXCEPTION_TYPE: return fmt::format(fmt::runtime("'{}' does not inherit from BaseException"), args...); case Error::TYPE_CANNOT_REALIZE_ATTR: return fmt::format( fmt::runtime("type of attribute '{}' of object '{}' cannot be inferred"), args...); case Error::TYPE_UNIFY: return fmt::format(fmt::runtime("'{}' does not match expected type '{}'"), args...); case Error::TYPE_FAILED: return fmt::format( fmt::runtime( "cannot infer the complete type of an expression (inferred only '{}')"), args...); case Error::COMPILER_NO_FILE: return fmt::format(fmt::runtime("cannot open file '{}' for parsing"), args...); case Error::COMPILER_NO_STDLIB: return fmt::format("cannot locate standard library"); case Error::MAX_REALIZATION: return fmt::format( fmt::runtime( "maximum realization depth reached during the realization 
of '{}'"), args...); case Error::CUSTOM: return Eformat(args...); default: assert(false); } } template void E(Error e, const codon::SrcInfo &o = codon::SrcInfo(), const TA &...args) { auto msg = Emsg(e, args...); auto err = ParserErrors(ErrorMessage(msg, o, (int)e)); throw exc::ParserException(err); } void E(llvm::Error &&error); } // namespace error } // namespace codon ================================================ FILE: codon/compiler/jit.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "jit.h" #include #include "codon/parser/common.h" #include "codon/parser/peg/peg.h" #include "codon/parser/visitors/doc/doc.h" #include "codon/parser/visitors/format/format.h" #include "codon/parser/visitors/scoping/scoping.h" #include "codon/parser/visitors/translate/translate.h" #include "codon/parser/visitors/typecheck/typecheck.h" namespace codon { namespace jit { namespace { typedef int MainFunc(int, char **); typedef void InputFunc(); typedef void *PyWrapperFunc(void *); const std::string JIT_FILENAME = ""; } // namespace JIT::JIT(const std::string &argv0, const std::string &mode, const std::string &stdlibRoot) : compiler(std::make_unique(argv0, Compiler::Mode::JIT, /*disabledPasses=*/std::vector{}, /*isTest=*/false, /*pyNumerics=*/false, /*pyExtension=*/false)), engine(std::make_unique()), pydata(std::make_unique()), mode(mode), forgetful(false) { if (!stdlibRoot.empty()) compiler->getCache()->fs->add_search_path(stdlibRoot); compiler->getLLVMVisitor()->setJIT(true); } void collectExecutableStmts(ast::Stmt *s, ast::SuiteStmt *final) { if (cast(s) || cast(s) || cast(s)) return; if (auto ss = ast::cast(s)) { for (auto &si : *ss) collectExecutableStmts(si, final); } else if (s) { final->addStmt(ast::clean_clone(s)); } } llvm::Error JIT::init(bool forgetful) { if (forgetful) { this->forgetful = true; auto fs = std::make_shared(compiler->getArgv0(), "", /*allowExternal=*/false); compiler->getCache()->fs = fs; } auto *cache 
= compiler->getCache(); auto *module = compiler->getModule(); auto *pm = compiler->getPassManager(); auto *llvisitor = compiler->getLLVMVisitor(); cache->isJit = true; auto typechecked = ast::TypecheckVisitor::apply( cache, cache->N(), JIT_FILENAME, {}, compiler->getEarlyDefines()); cache->isJit = false; // we still need main(), so pause isJit first time during translation ast::TranslateVisitor::apply(cache, std::move(typechecked)); cache->isJit = true; module->setSrcInfo({JIT_FILENAME, 0, 0, 0}); pm->run(module); module->accept(*llvisitor); auto pair = llvisitor->takeModule(module); if (auto err = engine->addModule({std::move(pair.first), std::move(pair.second)})) return err; auto func = engine->lookup("main"); if (auto err = func.takeError()) return err; auto *main = func->toPtr(); (*main)(0, nullptr); return llvm::Error::success(); } llvm::Error JIT::compile(const ir::Func *input, llvm::orc::ResourceTrackerSP rt) { auto *module = compiler->getModule(); auto *pm = compiler->getPassManager(); auto *llvisitor = compiler->getLLVMVisitor(); Timer t1("jit/ir"); pm->run(module); t1.log(); Timer t2("jit/llvm"); auto pair = llvisitor->takeModule(module); t2.log(); Timer t3("jit/engine"); if (auto err = engine->addModule({std::move(pair.first), std::move(pair.second)}, rt)) return std::move(err); t3.log(); return llvm::Error::success(); } JITState::JITState(ast::Cache *cache, bool forgetful) : cache(cache), forgetful(forgetful), bCache(*cache), mainCtx(*(cache->imports[MAIN_IMPORT].ctx)), stdlibCtx(*(cache->imports[STDLIB_IMPORT].ctx)), typeCtx(*(cache->typeCtx)), translateCtx(*(cache->codegenCtx)) {} void JITState::undo() { if (!forgetful) undoUnusedIR(); *cache = bCache; *(cache->imports[MAIN_IMPORT].ctx) = mainCtx; *(cache->imports[STDLIB_IMPORT].ctx) = stdlibCtx; *(cache->typeCtx) = typeCtx; *(cache->codegenCtx) = translateCtx; if (forgetful) cleanUpRealizations(); } void JITState::undoUnusedIR() { // Clean-up unused IR nodes made before Typechecker raised an error 
for (auto &f : cache->functions) {
    for (auto &r : f.second.realizations) {
      // A realization missing from the saved cache was created after the
      // snapshot; if it already has an IR node, drop that node from the module.
      if (!(in(bCache.functions, f.first) &&
            in(bCache.functions[f.first].realizations, r.first)) &&
          r.second->ir) {
        cache->module->remove(r.second->ir);
      }
    }
  }
}

/// Hook called by undo() in forgetful mode; currently a no-op.
void JITState::cleanUpRealizations() {
  // Clean-up IR nodes after single JIT input
  // Nothing should be done here with a proper arena support.
}

/// Parses, scopes, typechecks and translates one chunk of source code.
///
/// If the last statement (found by descending into trailing suites) matches the
/// cast below, it is replaced by a call to `_jit_display` taking a clone of its
/// expression and the JIT's `mode`.
/// NOTE(review): the node types of the casts/`N(...)` calls are not visible in
/// this text; presumably an expression statement is wrapped -- confirm.
///
/// @param code source text to compile
/// @param file file name to report; JIT_FILENAME is used when empty
/// @param line starting line passed to the parser
/// @return the result of TranslateVisitor::apply on success; on
///         exc::ParserException the cache snapshot is restored and the
///         exception's errors are returned as an llvm error
llvm::Expected JIT::compile(const std::string &code, const std::string &file, int line) {
  auto *cache = compiler->getCache();
  auto preamble = cache->N();

  // Snapshot taken before any mutation so the catch block can roll back.
  JITState state(cache, forgetful);
  try {
    auto nodeOrErr = ast::parseCode(cache, file.empty() ? JIT_FILENAME : file, code,
                                    /*startLine=*/line);
    if (!nodeOrErr)
      throw exc::ParserException(nodeOrErr.takeError());
    auto *node = *nodeOrErr;

    // Walk down to the last statement of the innermost trailing suite.
    ast::Stmt **e = &node;
    while (auto se = ast::cast(*e)) {
      if (se->empty())
        break;
      e = &se->back();
    }
    if (e)
      if (auto ex = ast::cast(*e)) {
        *e = cache->N(cache->N(
            cache->N("_jit_display"), clone(ex->getExpr()), cache->N(mode)));
      }

    auto sctx = cache->imports[MAIN_IMPORT].ctx;
    if (auto err = ast::ScopingVisitor::apply(sctx->cache, node, &sctx->globalShadows))
      throw exc::ParserException(std::move(err));
    // NOTE(review): `tv` is not read again in this function.
    auto tv = ast::TypecheckVisitor::apply(sctx, node, JIT_FILENAME);

    auto typechecked = cache->N();
    for (auto &s : *preamble)
      typechecked->addStmt(s);
    typechecked->addStmt(node);
    // TODO: unroll on errors...

    // add newly realized functions
    std::vector v;
    // NOTE(review): `frs` is filled below but never read afterwards.
    std::vector frs;
    v.push_back(typechecked);
    for (auto &p : cache->pendingRealizations) {
      v.push_back(cache->functions[p.first].ast);
      frs.push_back(&cache->functions[p.first].realizations[p.second]->ir);
    }
    auto func = ast::TranslateVisitor::apply(cache, cache->N(v));
    cache->jitCell++;

    return func;
  } catch (const exc::ParserException &exc) {
    state.undo();
    return llvm::make_error(exc.getErrors());
  }
}

/// Compiles the pending module (see compile(const ir::Func *, rt)) and returns
/// the address the engine resolves for `input`'s LLVM symbol name.
llvm::Expected JIT::address(const ir::Func *input, llvm::orc::ResourceTrackerSP rt) {
  if (auto err = compile(input, rt))
    return std::move(err);

  const std::string name = ir::LLVMVisitor::getNameForFunction(input);
  auto func = engine->lookup(name);
  if (auto err = func.takeError())
    return std::move(err);

  return (void *)func->getValue();
}

/// Resolves `input` via address(), calls it as a no-argument function and
/// returns runtime::getCapturedOutput(). A runtime::JITError thrown by the
/// called code is converted with handleJITError() and returned as an error.
llvm::Expected JIT::run(const ir::Func *input, llvm::orc::ResourceTrackerSP rt) {
  auto result = address(input, rt);
  if (auto err = result.takeError())
    return std::move(err);

  auto *repl = (InputFunc *)result.get();
  try {
    (*repl)();
  } catch (const runtime::JITError &e) {
    return handleJITError(e);
  }
  return runtime::getCapturedOutput();
}

/// Compiles and runs one chunk of source code.
/// @param code  source text
/// @param file  file name forwarded to compile()
/// @param line  starting line forwarded to compile()
/// @param debug when true, echoes the code to stderr first
/// @param rt    resource tracker used for the explicit compile step
/// @return whatever run() returns, or the first compile error
///
/// In forgetful mode a JITState snapshot is taken up front and restored after
/// run() returns.
/// NOTE(review): the two early error returns leave without calling
/// state->undo(); confirm this is intended in forgetful mode.
/// NOTE(review): run() is called without `rt` and goes through address(),
/// which calls compile(const ir::Func *, rt) a second time -- confirm.
llvm::Expected JIT::execute(const std::string &code, const std::string &file, int line,
                            bool debug, llvm::orc::ResourceTrackerSP rt) {
  if (debug)
    fmt::print(stderr, "[codon::jit::execute] code:\n{}-----\n", code);

  std::unique_ptr state = nullptr;
  if (forgetful)
    state = std::make_unique(compiler->getCache(), forgetful);

  auto result = compile(code, file, line);
  if (auto err = result.takeError())
    return std::move(err);
  if (auto err = compile(result.get(), rt))
    return std::move(err);
  auto r = run(result.get());
  if (state)
    state->undo();

  return r;
}

/// Converts a runtime::JITError into an llvm error carrying the captured
/// output, type, message, file/line/column and a backtrace. Each program
/// counter of the exception's backtrace is passed to the engine's debug
/// listener; only non-empty results are kept.
llvm::Error JIT::handleJITError(const runtime::JITError &e) {
  std::vector backtrace;
  for (auto pc : e.getBacktrace()) {
    auto line = engine->getDebugListener()->getPrettyBacktrace(pc);
    if (line && !line->empty())
      backtrace.push_back(*line);
  }
  return llvm::make_error(e.getOutput(), e.getType(), e.what(), e.getFile(),
e.getLine(), e.getCol(), backtrace); } namespace { std::string buildKey(const std::string &name, const std::vector &types) { std::stringstream key; key << name; for (const auto &t : types) { key << "|" << t; } return key.str(); } std::string buildPythonWrapper(const std::string &name, const std::string &wrapname, const std::vector &types, const std::string &pyModule, const std::vector &pyVars) { std::stringstream wrap; wrap << "@export\n"; wrap << "def " << wrapname << "(args: cobj) -> cobj:\n"; for (unsigned i = 0; i < types.size(); i++) { wrap << " " << "a" << i << " = " << types[i] << ".__from_py__(PyTuple_GetItem(args, " << i << "))\n"; } for (unsigned i = 0; i < pyVars.size(); i++) { wrap << " " << "py" << i << " = pyobj._get_module(\"" << pyModule << "\")._getattr(\"" << pyVars[i] << "\")\n"; } wrap << " return " << name << "("; for (unsigned i = 0; i < types.size(); i++) { if (i > 0) wrap << ", "; wrap << "a" << i; } for (unsigned i = 0; i < pyVars.size(); i++) { if (i > 0 || types.size() > 0) wrap << ", "; wrap << "py" << i; } wrap << ").__to_py__()\n"; return wrap.str(); } } // namespace JIT::PythonData::PythonData() : cobj(nullptr), cache() {} ir::types::Type *JIT::PythonData::getCObjType(ir::Module *M) { if (cobj) return cobj; cobj = M->getPointerType(M->getByteType()); return cobj; } JIT::JITResult JIT::executeSafe(const std::string &code, const std::string &file, int line, bool debug) { auto result = execute(code, file, line, debug); if (auto err = result.takeError()) { auto errorInfo = llvm::toString(std::move(err)); return JITResult::error(errorInfo); } return JITResult::success(); } JIT::JITResult JIT::executePython(const std::string &name, const std::vector &types, const std::string &pyModule, const std::vector &pyVars, void *arg, bool debug) { auto key = buildKey(name, types); auto &cache = pydata->cache; auto it = cache.find(key); PyWrapperFunc *wrap; if (it != cache.end()) { auto *wrapper = it->second; const std::string name = 
ir::LLVMVisitor::getNameForFunction(wrapper); auto func = llvm::cantFail(engine->lookup(name)); wrap = func.toPtr(); } else { static int idx = 0; auto wrapname = "__codon_wrapped__" + name + "_" + std::to_string(idx++); auto wrapper = buildPythonWrapper(name, wrapname, types, pyModule, pyVars); if (debug) fmt::print(stderr, "[codon::jit::executePython] wrapper:\n{}-----\n", wrapper); if (auto err = compile(wrapper).takeError()) { auto errorInfo = llvm::toString(std::move(err)); return JITResult::error(errorInfo); } auto *M = compiler->getModule(); auto *func = M->getOrRealizeFunc(wrapname, {pydata->getCObjType(M)}); seqassertn(func, "could not access wrapper func '{}'", wrapname); cache.emplace(key, func); auto result = address(func); if (auto err = result.takeError()) { auto errorInfo = llvm::toString(std::move(err)); return JITResult::error(errorInfo); } wrap = (PyWrapperFunc *)result.get(); } try { auto *ans = (*wrap)(arg); return JITResult::success(ans); } catch (const runtime::JITError &e) { auto err = handleJITError(e); auto errorInfo = llvm::toString(std::move(err)); return JITResult::error(errorInfo); } } } // namespace jit } // namespace codon void *jit_init(char *name) { auto jit = new codon::jit::JIT(std::string(name)); llvm::cantFail(jit->init()); return jit; } void jit_exit(void *jit) { delete ((codon::jit::JIT *)jit); } CJITResult jit_execute_python(void *jit, char *name, char **types, size_t types_size, char *pyModule, char **py_vars, size_t py_vars_size, void *arg, uint8_t debug) { std::vector cppTypes; cppTypes.reserve(types_size); for (size_t i = 0; i < types_size; i++) cppTypes.emplace_back(types[i]); std::vector cppPyVars; cppPyVars.reserve(py_vars_size); for (size_t i = 0; i < py_vars_size; i++) cppPyVars.emplace_back(py_vars[i]); auto t = ((codon::jit::JIT *)jit) ->executePython(std::string(name), cppTypes, std::string(pyModule), cppPyVars, arg, bool(debug)); void *result = t.result; char *message = t.message.empty() ? 
nullptr : strndup(t.message.c_str(), t.message.size());
  return {result, message};
}

/// C entry point for JIT::executeSafe. `code` and `file` are copied into
/// std::string before the call. The returned message is nullptr on success,
/// otherwise a strndup'ed copy of the error text.
CJITResult jit_execute_safe(void *jit, char *code, char *file, int32_t line,
                            uint8_t debug) {
  auto t = ((codon::jit::JIT *)jit)
               ->executeSafe(std::string(code), std::string(file), line, bool(debug));
  void *result = t.result;
  char *message =
      t.message.empty() ? nullptr : strndup(t.message.c_str(), t.message.size());
  return {result, message};
}

/// Returns a strndup'ed copy of codon::ast::library_path().
char *get_jit_library() {
  auto t = codon::ast::library_path();
  return strndup(t.c_str(), t.size());
}

================================================
FILE: codon/compiler/jit.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include
#include
#include
#include

#include "codon/cir/llvm/llvisitor.h"
#include "codon/cir/transform/manager.h"
#include "codon/cir/var.h"
#include "codon/compiler/compiler.h"
#include "codon/compiler/engine.h"
#include "codon/compiler/error.h"
#include "codon/parser/cache.h"
#include "codon/parser/visitors/translate/translate.h"
#include "codon/parser/visitors/typecheck/typecheck.h"
#include "codon/runtime/lib.h"

#include "codon/compiler/jit_extern.h"

namespace codon {
namespace jit {

/// Snapshot of the parser cache and its contexts, used to roll the JIT back.
/// The constructor copies the current state; undo() copies it back.
class JITState {
  ast::Cache *cache;                 // live cache that undo() writes into
  bool forgetful;                    // selects which clean-up undo() performs
  ast::Cache bCache;                 // saved copy of *cache
  ast::TypeContext mainCtx;          // saved main-import context
  ast::TypeContext stdlibCtx;        // saved stdlib-import context
  ast::TypeContext typeCtx;          // saved cache->typeCtx
  ast::TranslateContext translateCtx; // saved cache->codegenCtx

public:
  explicit JITState(ast::Cache *cache, bool forgetful = false);

  /// Restores the saved cache and contexts.
  void undo();
  /// Removes IR of realizations that are absent from the saved cache.
  void undoUnusedIR();
  /// Forgetful-mode hook run at the end of undo().
  void cleanUpRealizations();
};

/// JIT driver: owns a Compiler and an execution Engine and exposes
/// compile/run entry points for source code and for IR functions.
class JIT {
public:
  /// State needed for calling Codon functions from Python.
  struct PythonData {
    ir::types::Type *cobj;    // lazily created by getCObjType()
    std::unordered_map cache; // wrapper cache, keyed as built in jit.cpp

    PythonData();
    ir::types::Type *getCObjType(ir::Module *M);
  };

  /// Result of a "safe" call: a pointer result plus an error message.
  /// An empty message means success.
  struct JITResult {
    void *result;
    std::string message;

    /// True when no error message is set.
    operator bool() const { return message.empty(); }
    static JITResult success(void *result = nullptr) { return {result, ""}; }
    static JITResult error(const std::string &message) { return {nullptr, message}; }
  };

private:
  std::unique_ptr compiler;
  std::unique_ptr engine;
  std::unique_ptr pydata;
  std::string mode;       // passed to `_jit_display` when echoing a result
  bool forgetful = false; // set by init(true)

public:
  explicit JIT(const std::string &argv0, const std::string &mode = "",
               const std::string &stdlibRoot = "");

  Compiler *getCompiler() const { return compiler.get(); }
  Engine *getEngine() const { return engine.get(); }

  // General
  /// Bootstraps the session; must precede the other calls (see jit_init()).
  llvm::Error init(bool forgetful = false);
  /// Lowers the compiler's current module and adds it to the engine.
  llvm::Error compile(const ir::Func *input, llvm::orc::ResourceTrackerSP rt = nullptr);
  /// Parses, typechecks and translates `code`.
  llvm::Expected compile(const std::string &code, const std::string &file = "",
                         int line = 0);
  /// Compiles and returns the address resolved for `input`.
  llvm::Expected address(const ir::Func *input,
                         llvm::orc::ResourceTrackerSP rt = nullptr);
  /// Calls `input` and returns the captured output.
  llvm::Expected run(const ir::Func *input, llvm::orc::ResourceTrackerSP rt = nullptr);
  /// Compiles and runs `code`.
  llvm::Expected execute(const std::string &code, const std::string &file = "",
                         int line = 0, bool debug = false,
                         llvm::orc::ResourceTrackerSP rt = nullptr);

  // Python
  // NOTE(review): no definition of runPythonWrapper or getWrapperFunc appears
  // in the jit.cpp text reviewed alongside this header; confirm they are still
  // needed.
  llvm::Expected runPythonWrapper(const ir::Func *wrapper, void *arg);
  llvm::Expected getWrapperFunc(const std::string &name,
                                const std::vector &types);
  /// Calls Codon function `name` through a generated, cached wrapper.
  JITResult executePython(const std::string &name, const std::vector &types,
                          const std::string &pyModule,
                          const std::vector &pyVars, void *arg, bool debug);
  /// execute() with the error converted to a message instead of an llvm error.
  JITResult executeSafe(const std::string &code, const std::string &file, int line,
                        bool debug);

  // Errors
  /// Converts a runtime exception into an llvm error.
  llvm::Error handleJITError(const runtime::JITError &e);
};

} // namespace jit
} // namespace codon

================================================
FILE: codon/compiler/jit_extern.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once #include #include #ifdef __cplusplus extern "C" { #endif struct CJITResult { void *result; char *error; }; void *jit_init(char *name); void jit_exit(void *jit); struct CJITResult jit_execute_python(void *jit, char *name, char **types, size_t types_size, char *pyModule, char **py_vars, size_t py_vars_size, void *arg, uint8_t debug); struct CJITResult jit_execute_safe(void *jit, char *code, char *file, int32_t line, uint8_t debug); char *get_jit_library(); #ifdef __cplusplus } #endif ================================================ FILE: codon/compiler/memory_manager.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "memory_manager.h" #include "codon/runtime/lib.h" namespace codon { class BoehmGCJITLinkMemoryManager::IPInFlightAlloc : public llvm::jitlink::JITLinkMemoryManager::InFlightAlloc { public: IPInFlightAlloc(BoehmGCJITLinkMemoryManager &MemMgr, llvm::jitlink::LinkGraph &G, llvm::jitlink::BasicLayout BL, llvm::sys::MemoryBlock StandardSegments, llvm::sys::MemoryBlock FinalizationSegments) : MemMgr(MemMgr), G(G), BL(std::move(BL)), StandardSegments(std::move(StandardSegments)), FinalizationSegments(std::move(FinalizationSegments)) {} void finalize(OnFinalizedFunction OnFinalized) override { // Apply memory protections to all segments. if (auto Err = applyProtections()) { OnFinalized(std::move(Err)); return; } // Run finalization actions. auto DeallocActions = runFinalizeActions(G.allocActions()); if (!DeallocActions) { OnFinalized(DeallocActions.takeError()); return; } // Release the finalize segments slab. if (auto EC = llvm::sys::Memory::releaseMappedMemory(FinalizationSegments)) { OnFinalized(llvm::errorCodeToError(EC)); return; } // Continue with finalized allocation. 
OnFinalized(MemMgr.createFinalizedAlloc(std::move(StandardSegments), std::move(*DeallocActions))); } void abandon(OnAbandonedFunction OnAbandoned) override { llvm::Error Err = llvm::Error::success(); if (auto EC = llvm::sys::Memory::releaseMappedMemory(FinalizationSegments)) Err = llvm::joinErrors(std::move(Err), llvm::errorCodeToError(EC)); if (auto EC = llvm::sys::Memory::releaseMappedMemory(StandardSegments)) Err = llvm::joinErrors(std::move(Err), llvm::errorCodeToError(EC)); OnAbandoned(std::move(Err)); } private: llvm::Error applyProtections() { for (auto &KV : BL.segments()) { const auto &AG = KV.first; auto &Seg = KV.second; auto Prot = toSysMemoryProtectionFlags(AG.getMemProt()); uint64_t SegSize = llvm::alignTo(Seg.ContentSize + Seg.ZeroFillSize, MemMgr.PageSize); llvm::sys::MemoryBlock MB(Seg.WorkingMem, SegSize); if (auto EC = llvm::sys::Memory::protectMappedMemory(MB, Prot)) return llvm::errorCodeToError(EC); if (Prot & llvm::sys::Memory::MF_EXEC) llvm::sys::Memory::InvalidateInstructionCache(MB.base(), MB.allocatedSize()); } return llvm::Error::success(); } BoehmGCJITLinkMemoryManager &MemMgr; llvm::jitlink::LinkGraph &G; llvm::jitlink::BasicLayout BL; llvm::sys::MemoryBlock StandardSegments; llvm::sys::MemoryBlock FinalizationSegments; }; llvm::Expected> BoehmGCJITLinkMemoryManager::Create() { if (auto PageSize = llvm::sys::Process::getPageSize()) { if (!llvm::isPowerOf2_64((uint64_t)*PageSize)) return llvm::make_error("Page size is not a power of 2", llvm::inconvertibleErrorCode()); return std::make_unique(*PageSize); } else { return PageSize.takeError(); } } void BoehmGCJITLinkMemoryManager::allocate(const llvm::jitlink::JITLinkDylib *JD, llvm::jitlink::LinkGraph &G, OnAllocatedFunction OnAllocated) { llvm::jitlink::BasicLayout BL(G); /// Scan the request and calculate the group and total sizes. /// Check that segment size is no larger than a page. 
auto SegsSizes = BL.getContiguousPageBasedLayoutSizes(PageSize); if (!SegsSizes) { OnAllocated(SegsSizes.takeError()); return; } /// Check that the total size requested (including zero fill) is not larger /// than a size_t. if (SegsSizes->total() > std::numeric_limits::max()) { OnAllocated(llvm::make_error( "Total requested size " + llvm::formatv("{0:x}", SegsSizes->total()) + " for graph " + G.getName() + " exceeds address space")); return; } // Allocate one slab for the whole thing (to make sure everything is // in-range), then partition into standard and finalization blocks. // // FIXME: Make two separate allocations in the future to reduce // fragmentation: finalization segments will usually be a single page, and // standard segments are likely to be more than one page. Where multiple // allocations are in-flight at once (likely) the current approach will leave // a lot of single-page holes. llvm::sys::MemoryBlock Slab; llvm::sys::MemoryBlock StandardSegsMem; llvm::sys::MemoryBlock FinalizeSegsMem; { const llvm::sys::Memory::ProtectionFlags ReadWrite = static_cast(llvm::sys::Memory::MF_READ | llvm::sys::Memory::MF_WRITE); std::error_code EC; Slab = llvm::sys::Memory::allocateMappedMemory(SegsSizes->total(), nullptr, ReadWrite, EC); if (EC) { OnAllocated(llvm::errorCodeToError(EC)); return; } // Zero-fill the whole slab up-front. memset(Slab.base(), 0, Slab.allocatedSize()); StandardSegsMem = {Slab.base(), static_cast(SegsSizes->StandardSegs)}; FinalizeSegsMem = {(void *)((char *)Slab.base() + SegsSizes->StandardSegs), static_cast(SegsSizes->FinalizeSegs)}; } auto NextStandardSegAddr = llvm::orc::ExecutorAddr::fromPtr(StandardSegsMem.base()); auto NextFinalizeSegAddr = llvm::orc::ExecutorAddr::fromPtr(FinalizeSegsMem.base()); // Build ProtMap, assign addresses. for (auto &KV : BL.segments()) { auto &AG = KV.first; auto &Seg = KV.second; auto &SegAddr = (AG.getMemLifetime() == llvm::orc::MemLifetime::Standard) ? 
NextStandardSegAddr : NextFinalizeSegAddr; Seg.WorkingMem = SegAddr.toPtr(); Seg.Addr = SegAddr; SegAddr += llvm::alignTo(Seg.ContentSize + Seg.ZeroFillSize, PageSize); if (static_cast(AG.getMemProt()) & static_cast(llvm::orc::MemProt::Write)) { seq_gc_add_roots((void *)Seg.Addr.getValue(), (void *)SegAddr.getValue()); } } if (auto Err = BL.apply()) { OnAllocated(std::move(Err)); return; } OnAllocated(std::make_unique( *this, G, std::move(BL), std::move(StandardSegsMem), std::move(FinalizeSegsMem))); } void BoehmGCJITLinkMemoryManager::deallocate(std::vector Allocs, OnDeallocatedFunction OnDeallocated) { std::vector StandardSegmentsList; std::vector> DeallocActionsList; { std::lock_guard Lock(FinalizedAllocsMutex); for (auto &Alloc : Allocs) { auto *FA = Alloc.release().toPtr(); StandardSegmentsList.push_back(std::move(FA->StandardSegments)); DeallocActionsList.push_back(std::move(FA->DeallocActions)); FA->~FinalizedAllocInfo(); FinalizedAllocInfos.Deallocate(FA); } } llvm::Error DeallocErr = llvm::Error::success(); while (!DeallocActionsList.empty()) { auto &DeallocActions = DeallocActionsList.back(); auto &StandardSegments = StandardSegmentsList.back(); /// Run any deallocate calls. while (!DeallocActions.empty()) { if (auto Err = DeallocActions.back().runWithSPSRetErrorMerged()) DeallocErr = llvm::joinErrors(std::move(DeallocErr), std::move(Err)); DeallocActions.pop_back(); } /// Release the standard segments slab. 
if (auto EC = llvm::sys::Memory::releaseMappedMemory(StandardSegments)) DeallocErr = llvm::joinErrors(std::move(DeallocErr), llvm::errorCodeToError(EC)); DeallocActionsList.pop_back(); StandardSegmentsList.pop_back(); } OnDeallocated(std::move(DeallocErr)); } llvm::jitlink::JITLinkMemoryManager::FinalizedAlloc BoehmGCJITLinkMemoryManager::createFinalizedAlloc( llvm::sys::MemoryBlock StandardSegments, std::vector DeallocActions) { std::lock_guard Lock(FinalizedAllocsMutex); auto *FA = FinalizedAllocInfos.Allocate(); new (FA) FinalizedAllocInfo({std::move(StandardSegments), std::move(DeallocActions)}); return FinalizedAlloc(llvm::orc::ExecutorAddr::fromPtr(FA)); } } // namespace codon ================================================ FILE: codon/compiler/memory_manager.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include "codon/cir/llvm/llvm.h" namespace codon { /// Basically a copy of LLVM's jitlink::InProcessMemoryManager that registers /// relevant allocated sections with the GC. TODO: Avoid copying this entire /// class if/when there's an API to perform the registration externally. class BoehmGCJITLinkMemoryManager : public llvm::jitlink::JITLinkMemoryManager { public: class IPInFlightAlloc; /// Attempts to auto-detect the host page size. static llvm::Expected> Create(); /// Create an instance using the given page size. BoehmGCJITLinkMemoryManager(uint64_t PageSize) : PageSize(PageSize) {} void allocate(const llvm::jitlink::JITLinkDylib *JD, llvm::jitlink::LinkGraph &G, OnAllocatedFunction OnAllocated) override; // Use overloads from base class. using llvm::jitlink::JITLinkMemoryManager::allocate; void deallocate(std::vector Alloc, OnDeallocatedFunction OnDeallocated) override; // Use overloads from base class. using llvm::jitlink::JITLinkMemoryManager::deallocate; private: // FIXME: Use an in-place array instead of a vector for DeallocActions. 
// There shouldn't need to be a heap alloc for this. struct FinalizedAllocInfo { llvm::sys::MemoryBlock StandardSegments; std::vector DeallocActions; }; FinalizedAlloc createFinalizedAlloc( llvm::sys::MemoryBlock StandardSegments, std::vector DeallocActions); uint64_t PageSize; std::mutex FinalizedAllocsMutex; llvm::RecyclingAllocator FinalizedAllocInfos; }; } // namespace codon ================================================ FILE: codon/config/.gitignore ================================================ * */ !.gitignore ================================================ FILE: codon/dsl/dsl.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/cir/cir.h" #include "codon/cir/transform/manager.h" #include "codon/cir/transform/pass.h" #include "codon/parser/cache.h" #include "llvm/Passes/PassBuilder.h" #include #include #include namespace codon { /// Base class for DSL plugins. Plugins will return an instance of /// a child of this class, which defines various characteristics of /// the DSL, like keywords and IR passes. class DSL { public: /// General information about this plugin. struct Info { /// Extension name std::string name; /// Extension description std::string description; /// Extension version std::string version; /// Extension URL std::string url; /// Supported Codon versions (semver range) std::string supported; /// Plugin stdlib path std::string stdlibPath; /// Plugin dynamic library path std::string dylibPath; /// Linker arguments (to replace "-l dylibPath" if present) std::vector linkArgs; }; using KeywordCallback = std::function; struct ExprKeyword { std::string keyword; KeywordCallback callback; }; struct BlockKeyword { std::string keyword; KeywordCallback callback; bool hasExpr; }; virtual ~DSL() noexcept = default; /// Registers this DSL's IR passes with the given pass manager. 
/// @param pm the pass manager to add the passes to /// @param debug true if compiling in debug mode virtual void addIRPasses(ir::transform::PassManager *pm, bool debug) {} /// Registers this DSL's LLVM passes with the given pass builder. /// @param pb the pass builder to add the passes to /// @param debug true if compiling in debug mode virtual void addLLVMPasses(llvm::PassBuilder *pb, bool debug) {} /// Returns a vector of "expression keywords", defined as keywords of /// the form "keyword ". /// @return this DSL's expression keywords virtual std::vector getExprKeywords() { return {}; } /// Returns a vector of "block keywords", defined as keywords of the /// form "keyword : ". /// @return this DSL's block keywords virtual std::vector getBlockKeywords() { return {}; } }; } // namespace codon ================================================ FILE: codon/dsl/plugins.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "plugins.h" #include #include #include #include "codon/parser/common.h" #include "codon/util/common.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/Path.h" namespace codon { namespace { llvm::Expected pluginError(const std::string &msg) { return llvm::make_error(msg); } typedef std::unique_ptr LoadFunc(); } // namespace llvm::Expected PluginManager::load(const std::string &path) { #if __APPLE__ const std::string libExt = "dylib"; #else const std::string libExt = "so"; #endif const std::string config = "plugin.toml"; llvm::SmallString<128> tomlPath(path); llvm::sys::path::append(tomlPath, config); if (!llvm::sys::fs::exists(tomlPath)) { // try default install path std::string s = ast::Filesystem::executable_path(argv0.c_str()); tomlPath = llvm::SmallString<128>(llvm::sys::path::parent_path(s)); llvm::sys::path::append(tomlPath, "../lib/codon/plugins", path, config); } toml::parse_result tml; try { tml = toml::parse_file(tomlPath.str()); } catch (const 
toml::parse_error &e) { return pluginError( fmt::format("[toml::parse_file(\"{}\")] {}", tomlPath.str(), e.what())); } auto about = tml["about"]; auto library = tml["library"]; std::string cppLib = library["cpp"].value_or(""); std::string dylibPath; if (!cppLib.empty()) { llvm::SmallString<128> p = llvm::sys::path::parent_path(tomlPath); llvm::sys::path::append(p, cppLib + "." + libExt); dylibPath = p.str(); } auto link = library["link"]; std::vector linkArgs; if (auto arr = link.as_array()) { arr->for_each([&linkArgs](auto &&el) { std::string l = el.value_or(""); if (!l.empty()) linkArgs.push_back(l); }); } else { std::string l = link.value_or(""); if (!l.empty()) linkArgs.push_back(l); } for (auto &l : linkArgs) l = fmt::format(fmt::runtime(l), fmt::arg("root", llvm::sys::path::parent_path(tomlPath))); std::string codonLib = library["codon"].value_or(""); std::string stdlibPath; if (!codonLib.empty()) { llvm::SmallString<128> p = llvm::sys::path::parent_path(tomlPath); llvm::sys::path::append(p, codonLib); stdlibPath = p.str(); } DSL::Info info = {about["name"].value_or(""), about["description"].value_or(""), about["version"].value_or(""), about["url"].value_or(""), about["supported"].value_or(""), stdlibPath, dylibPath, linkArgs}; bool versionOk = false; try { versionOk = semver::range::satisfies( semver::version(CODON_VERSION_MAJOR, CODON_VERSION_MINOR, CODON_VERSION_PATCH), info.supported); } catch (const std::invalid_argument &e) { return pluginError(fmt::format("[semver::range::satisfies(..., \"{}\")] {}", info.supported, e.what())); } if (!versionOk) return pluginError(fmt::format("unsupported version {} (supported: {})", CODON_VERSION, info.supported)); if (!dylibPath.empty()) { std::string libLoadErrorMsg; auto handle = llvm::sys::DynamicLibrary::getPermanentLibrary(dylibPath.c_str(), &libLoadErrorMsg); if (!handle.isValid()) return pluginError(fmt::format( "[llvm::sys::DynamicLibrary::getPermanentLibrary(\"{}\", ...)] {}", dylibPath, libLoadErrorMsg)); 
auto *entry = (LoadFunc *)handle.getAddressOfSymbol("load"); if (!entry) return pluginError( fmt::format("could not find 'load' in plugin shared library: {}", dylibPath)); auto dsl = (*entry)(); plugins.push_back(std::make_unique(std::move(dsl), info, handle)); } else { plugins.push_back(std::make_unique(std::make_unique(), info, llvm::sys::DynamicLibrary())); } return plugins.back().get(); } } // namespace codon ================================================ FILE: codon/dsl/plugins.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include #include "codon/cir/util/iterators.h" #include "codon/compiler/error.h" #include "codon/dsl/dsl.h" #include "llvm/Support/DynamicLibrary.h" namespace codon { /// Plugin metadata struct Plugin { /// the associated DSL std::unique_ptr dsl; /// plugin information DSL::Info info; /// library handle llvm::sys::DynamicLibrary handle; Plugin(std::unique_ptr dsl, DSL::Info info, llvm::sys::DynamicLibrary handle) : dsl(std::move(dsl)), info(std::move(info)), handle(std::move(handle)) {} }; /// Manager for loading, applying and unloading plugins. class PluginManager { private: /// Codon executable location std::string argv0; /// vector of loaded plugins std::vector> plugins; public: /// Constructs a plugin manager PluginManager(const std::string &argv0) : argv0(argv0) {} /// @return iterator to the first plugin auto begin() { return ir::util::raw_ptr_adaptor(plugins.begin()); } /// @return iterator beyond the last plugin auto end() { return ir::util::raw_ptr_adaptor(plugins.end()); } /// @return const iterator to the first plugin auto begin() const { return ir::util::const_raw_ptr_adaptor(plugins.begin()); } /// @return const iterator beyond the last plugin auto end() const { return ir::util::const_raw_ptr_adaptor(plugins.end()); } /// Loads the plugin at the given load path. 
/// @param path path to plugin directory containing "plugin.toml" file /// @return plugin pointer if successful, plugin error otherwise llvm::Expected load(const std::string &path); }; } // namespace codon ================================================ FILE: codon/parser/ast/attr.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "attr.h" namespace codon::ast {} // namespace codon::ast ================================================ FILE: codon/parser/ast/attr.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once namespace codon::ast { constexpr int INDENT_SIZE = 2; struct Attr { enum { Module = 201, ParentClass, Bindings, LLVM, Python, Atomic, Property, StaticMethod, Attribute, C, Internal, HiddenFromUser, ForceRealize, AllowPassThrough, ParentCallExpr, TupleCall, Validated, AutoGenerated, CVarArg, Method, Capture, HasSelf, IsGenerator, Extend, Tuple, Dataclass, ClassDeduce, ClassNoTuple, Test, Overload, Export, Inline, NoArgReorder, FunctionAttributes, NoExtend, ClassMagic, ExprSequenceItem, ExprStarSequenceItem, ExprList, ExprSet, ExprDict, ExprPartial, ExprDominated, ExprStarArgument, ExprKwStarArgument, ExprOrderedCall, ExprExternVar, ExprDominatedUndefCheck, ExprDominatedUsed, ExprTime, ExprDoNotRealize, ExprNoSpecial, TryPyVar }; }; } // namespace codon::ast ================================================ FILE: codon/parser/ast/error.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "llvm/Support/Error.h" #include #include #include #include /** * WARNING: do not include anything else in this file, especially format.h * peglib.h uses this file. However, it is not compatible with format.h * (and possibly some other includes). 
Their inclusion will result in a successful
 * compilation but extremely weird behaviour and hard-to-debug crashes (it seems that
 * some parts of peglib conflict with format.h in a weird way---further investigation
 * needed).
 */

namespace codon {

struct SrcInfo {
  std::string file;
  int line;
  int col;
  int len;
  int id; /// used to differentiate different instances

  SrcInfo();
  SrcInfo(std::string file, int line, int col, int len);
  bool operator==(const SrcInfo &src) const;
  bool operator<(const SrcInfo &src) const;
  bool operator<=(const SrcInfo &src) const;
};

/// A single diagnostic: a message, the source location it is attached to and
/// an integer error code (-1 when no code is supplied).
class ErrorMessage {
private:
  std::string msg;
  SrcInfo loc;
  int errorCode = -1;

public:
  /// Constructs a message attached to an existing source location.
  /// @param msg message text
  /// @param loc source location (default-constructed if omitted)
  /// @param errorCode error code reported by getErrorCode(); -1 if omitted
  explicit ErrorMessage(std::string msg, SrcInfo loc = SrcInfo(), int errorCode = -1)
      // Store the caller's code; previously the argument was dropped and the
      // member was always initialised to -1.
      : msg(std::move(msg)), loc(std::move(loc)), errorCode(errorCode) {}

  /// Constructs a message from the individual components of a source location.
  /// @param msg message text
  /// @param file, line, col, len forwarded to the SrcInfo constructor
  /// @param errorCode error code reported by getErrorCode(); -1 if omitted
  explicit ErrorMessage(std::string msg, const std::string &file = "", int line = 0,
                        int col = 0, int len = 0, int errorCode = -1)
      : msg(std::move(msg)), loc(file, line, col, len), errorCode(errorCode) {}

  std::string getMessage() const { return msg; }
  std::string getFile() const { return loc.file; }
  int getLine() const { return loc.line; }
  int getColumn() const { return loc.col; }
  int getLength() const { return loc.len; }
  int getErrorCode() const { return errorCode; }
  SrcInfo getSrcInfo() const { return loc; }
  void setSrcInfo(const SrcInfo &s) { loc = s; }

  /// Equality compares the message text and the location only; the error
  /// code does not take part in the comparison.
  bool operator==(const ErrorMessage &t) const { return msg == t.msg && loc == t.loc; }

  std::string toString() const;
  /// Writes toString() to the given stream.
  void log(llvm::raw_ostream &out) const { out << toString(); }
};

/// Collection of error backtraces.
struct ParserErrors {
  /// An ordered list of messages that together describe one error.
  struct Backtrace {
    std::vector trace;

    const std::vector &getMessages() const { return trace; }
    auto begin() const { return trace.begin(); }
    // front()/back() are declared with plain `auto`, so they return copies.
    auto front() const { return trace.front(); }
    auto front() { return trace.front(); }
    auto end() const { return trace.end(); }
    auto back() { return trace.back(); }
    auto back() const { return trace.back(); }
    auto size() const { return trace.size(); }
    void addMessage(const std::string &msg, const SrcInfo &info =
SrcInfo()) { trace.emplace_back(msg, info); } bool operator==(const Backtrace &t) const { return trace == t.trace; } }; std::vector errors; ParserErrors() = default; explicit ParserErrors(const ErrorMessage &msg) : errors{Backtrace{{msg}}} {} ParserErrors(const std::string &msg, const SrcInfo &info) : ParserErrors(ErrorMessage{msg, info}) {} explicit ParserErrors(const std::string &msg) : ParserErrors(ErrorMessage{msg, SrcInfo{}}) {} ParserErrors(const ParserErrors &e) = default; explicit ParserErrors(const std::vector &m) : ParserErrors() { for (auto &msg : m) errors.push_back(Backtrace{{msg}}); } auto begin() { return errors.begin(); } auto end() { return errors.end(); } auto begin() const { return errors.begin(); } auto end() const { return errors.end(); } auto empty() const { return errors.empty(); } auto size() const { return errors.size(); } auto &back() { return errors.back(); } const auto &back() const { return errors.back(); } void append(const ParserErrors &e) { for (auto &trace : e) addError(trace); } Backtrace getLast() { assert(!empty() && "empty error trace"); return errors.back(); } /// Add an error message to the current backtrace void addError(const Backtrace &trace) { if (!errors.empty() && errors.back() == trace) return; errors.push_back({trace}); } void addError(const std::vector &trace) { addError(Backtrace{trace}); } std::string getMessage() const { if (empty()) return ""; return errors.front().trace.front().getMessage(); } }; } // namespace codon namespace codon::exc { /** * Parser error exception. * Used for parsing, transformation and type-checking errors. */ class ParserException : public std::runtime_error { /// These vectors (stacks) store an error stack-trace. 
ParserErrors errors; public: ParserException() noexcept : std::runtime_error("") {} explicit ParserException(const ParserErrors &errors) noexcept : std::runtime_error(errors.getMessage()), errors(errors) {} explicit ParserException(llvm::Error &&e) noexcept; const ParserErrors &getErrors() const { return errors; } ParserErrors &getErrors() { return errors; } }; } // namespace codon::exc ================================================ FILE: codon/parser/ast/expr.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "expr.h" #include #include #include #include #include #include #include "codon/parser/ast.h" #include "codon/parser/cache.h" #include "codon/parser/match.h" #include "codon/parser/peg/peg.h" #include "codon/parser/visitors/visitor.h" #define FASTFLOAT_ALLOWS_LEADING_PLUS #define FASTFLOAT_SKIP_WHITE_SPACE #include "fast_float/fast_float.h" #define ACCEPT_IMPL(T, X) \ ASTNode *T::clone(bool c) const { return cache->N(*this, c); } \ void T::accept(X &visitor) { visitor.visit(this); } \ const char T::NodeId = 0; using namespace codon::error; using namespace codon::matcher; namespace codon::ast { Expr::Expr() : AcceptorExtend(), type(nullptr), done(false), origExpr(nullptr) {} Expr::Expr(const Expr &expr, bool clean) : Expr(expr) { if (clean) { type = nullptr; done = false; } } types::ClassType *Expr::getClassType() const { return type ? type->getClass() : nullptr; } std::string Expr::wrapType(const std::string &sexpr) const { auto is = sexpr; if (done) is.insert(findStar(is), "*"); return "(" + is + (type && !done ? fmt::format(" #:type \"{}\"", type->debugString(2)) : "") + ")"; } Param::Param(std::string name, Expr *type, Expr *defaultValue, int status) : name(std::move(name)), type(type), defaultValue(defaultValue) { if (status == 0 && (match(getType(), MOr(M(TYPE_TYPE), M(TRAIT_TYPE), M(M(TRAIT_TYPE), M_))) || getStaticGeneric(getType()))) { this->status = Generic; } else { this->status = (status == 0 ? 
Value : (status == 1 ? Generic : HiddenGeneric)); } } Param::Param(const SrcInfo &info, std::string name, Expr *type, Expr *defaultValue, int status) : Param(std::move(name), type, defaultValue, status) { setSrcInfo(info); } std::string Param::toString(int indent) const { return fmt::format("({}{}{}{})", name, type ? " #:type " + type->toString(indent) : "", defaultValue ? " #:default " + defaultValue->toString(indent) : "", !isValue() ? " #:generic" : ""); } Param Param::clone(bool clean) const { return Param(name, ast::clone(type, clean), ast::clone(defaultValue, clean), status); } std::pair Param::getNameWithStars() const { int stars = 0; for (; stars < name.size() && name[stars] == '*'; stars++) ; auto n = name.substr(stars); return {stars, n}; } NoneExpr::NoneExpr() : AcceptorExtend() {} NoneExpr::NoneExpr(const NoneExpr &expr, bool clean) : AcceptorExtend(expr, clean) {} std::string NoneExpr::toString(int) const { return wrapType("none"); } BoolExpr::BoolExpr(bool value) : AcceptorExtend(), value(value) {} BoolExpr::BoolExpr(const BoolExpr &expr, bool clean) : AcceptorExtend(expr, clean), value(expr.value) {} bool BoolExpr::getValue() const { return value; } std::string BoolExpr::toString(int) const { return wrapType(fmt::format("bool {}", static_cast(value))); } IntExpr::IntExpr(int64_t intValue) : AcceptorExtend(), value(std::to_string(intValue)), intValue(intValue) {} IntExpr::IntExpr(const std::string &value, std::string suffix) : AcceptorExtend(), value(), suffix(std::move(suffix)) { for (auto c : value) if (c != '_') this->value += c; try { if (startswith(this->value, "0b") || startswith(this->value, "0B")) intValue = std::stoull(this->value.substr(2), nullptr, 2); else intValue = std::stoull(this->value, nullptr, 0); } catch (std::out_of_range &) { } } IntExpr::IntExpr(const IntExpr &expr, bool clean) : AcceptorExtend(expr, clean), value(expr.value), suffix(expr.suffix), intValue(expr.intValue) {} std::pair IntExpr::getRawData() const { return {value, 
suffix}; } bool IntExpr::hasStoredValue() const { return intValue.has_value(); } int64_t IntExpr::getValue() const { seqassertn(hasStoredValue(), "value not set"); return intValue.value(); } std::string IntExpr::toString(int) const { return wrapType( fmt::format("int {}{}", value, suffix.empty() ? "" : fmt::format(" #:suffix \"{}\"", suffix))); } FloatExpr::FloatExpr(double floatValue) : AcceptorExtend(), value(fmt::format("{:g}", floatValue)), floatValue(floatValue) { } FloatExpr::FloatExpr(const std::string &value, std::string suffix) : AcceptorExtend(), value(), suffix(std::move(suffix)) { this->value.reserve(value.size()); std::ranges::copy_if(value.begin(), value.end(), std::back_inserter(this->value), [](char c) { return c != '_'; }); double result; auto r = fast_float::from_chars(this->value.data(), this->value.data() + this->value.size(), result); if (r.ec == std::errc() || r.ec == std::errc::result_out_of_range) floatValue = result; } FloatExpr::FloatExpr(const FloatExpr &expr, bool clean) : AcceptorExtend(expr, clean), value(expr.value), suffix(expr.suffix), floatValue(expr.floatValue) {} std::pair FloatExpr::getRawData() const { return {value, suffix}; } bool FloatExpr::hasStoredValue() const { return floatValue.has_value(); } double FloatExpr::getValue() const { seqassertn(hasStoredValue(), "value not set"); return floatValue.value(); } std::string FloatExpr::toString(int) const { return wrapType( fmt::format("float {}{}", value, suffix.empty() ? 
"" : fmt::format(" #:suffix \"{}\"", suffix))); } StringExpr::StringExpr(std::vector strings) : AcceptorExtend(), strings(std::move(strings)) {} StringExpr::StringExpr(std::string value, std::string prefix) : StringExpr(std::vector{ StringExpr::String{std::move(value), std::move(prefix)}}) {} StringExpr::StringExpr(const StringExpr &expr, bool clean) : AcceptorExtend(expr, clean), strings(expr.strings) { for (auto &s : strings) s.expr = ast::clone(s.expr); } std::string StringExpr::toString(int) const { std::vector s; for (auto &vp : strings) s.push_back(fmt::format( "\"{}\"{}", escape(vp.value), vp.prefix.empty() ? "" : fmt::format(" #:prefix \"{}\"", vp.prefix))); return wrapType(fmt::format("string ({})", join(s))); } std::string StringExpr::getValue() const { seqassert(isSimple(), "invalid StringExpr"); return strings[0].value; } bool StringExpr::isSimple() const { return strings.size() == 1 && strings[0].prefix.empty(); } IdExpr::IdExpr(std::string value) : AcceptorExtend(), value(std::move(value)) {} IdExpr::IdExpr(const IdExpr &expr, bool clean) : AcceptorExtend(expr, clean), value(expr.value) {} std::string IdExpr::toString(int) const { return !getType() ? 
fmt::format("'{}", value) : wrapType(fmt::format("'{}", value)); } StarExpr::StarExpr(Expr *expr) : AcceptorExtend(), expr(expr) {} StarExpr::StarExpr(const StarExpr &expr, bool clean) : AcceptorExtend(expr, clean), expr(ast::clone(expr.expr, clean)) {} std::string StarExpr::toString(int indent) const { return wrapType(fmt::format("star {}", expr->toString(indent))); } KeywordStarExpr::KeywordStarExpr(Expr *expr) : AcceptorExtend(), expr(expr) {} KeywordStarExpr::KeywordStarExpr(const KeywordStarExpr &expr, bool clean) : AcceptorExtend(expr, clean), expr(ast::clone(expr.expr, clean)) {} std::string KeywordStarExpr::toString(int indent) const { return wrapType(fmt::format("kwstar {}", expr->toString(indent))); } TupleExpr::TupleExpr(std::vector items) : AcceptorExtend(), Items(std::move(items)) {} TupleExpr::TupleExpr(const TupleExpr &expr, bool clean) : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)) {} std::string TupleExpr::toString(int) const { return wrapType(fmt::format("tuple {}", combine(items))); } ListExpr::ListExpr(std::vector items) : AcceptorExtend(), Items(std::move(items)) {} ListExpr::ListExpr(const ListExpr &expr, bool clean) : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)) {} std::string ListExpr::toString(int) const { return wrapType(!items.empty() ? fmt::format("list {}", combine(items)) : "list"); } SetExpr::SetExpr(std::vector items) : AcceptorExtend(), Items(std::move(items)) {} SetExpr::SetExpr(const SetExpr &expr, bool clean) : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)) {} std::string SetExpr::toString(int) const { return wrapType(!items.empty() ? 
fmt::format("set {}", combine(items)) : "set"); } DictExpr::DictExpr(std::vector items) : AcceptorExtend(), Items(std::move(items)) { for (auto *i : *this) { auto t = cast(i); seqassertn(t && t->size() == 2, "dictionary items are invalid"); } } DictExpr::DictExpr(const DictExpr &expr, bool clean) : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)) {} std::string DictExpr::toString(int) const { return wrapType(!items.empty() ? fmt::format("dict {}", combine(items)) : "set"); } GeneratorExpr::GeneratorExpr(Cache *cache, GeneratorExpr::GeneratorKind kind, Expr *expr, std::vector loops) : AcceptorExtend(), kind(kind), loops() { this->cache = cache; seqassert(!loops.empty() && cast(loops[0]), "bad generator constructor"); loops.push_back(cache->N(cache->N(expr))); formCompleteStmt(loops); } GeneratorExpr::GeneratorExpr(Cache *cache, Expr *key, Expr *expr, std::vector loops) : AcceptorExtend(), kind(GeneratorExpr::DictGenerator), loops() { this->cache = cache; seqassert(!loops.empty() && cast(loops[0]), "bad generator constructor"); Expr *t = cache->N(std::vector{key, expr}); loops.push_back(cache->N(cache->N(t))); formCompleteStmt(loops); } GeneratorExpr::GeneratorExpr(const GeneratorExpr &expr, bool clean) : AcceptorExtend(expr, clean), kind(expr.kind), loops(ast::clone(expr.loops, clean)) {} std::string GeneratorExpr::toString(int indent) const { auto pad = indent >= 0 ? ("\n" + std::string(indent + 2 * INDENT_SIZE, ' ')) : " "; std::string prefix; if (kind == GeneratorKind::ListGenerator) prefix = "list-"; if (kind == GeneratorKind::SetGenerator) prefix = "set-"; if (kind == GeneratorKind::DictGenerator) prefix = "dict-"; auto l = loops->toString(indent >= 0 ? 
indent + 2 * INDENT_SIZE : -1); return wrapType(fmt::format("{}gen {}", prefix, l)); } Expr *GeneratorExpr::getFinalExpr() { auto s = *(getFinalStmt()); if (cast(s)) return cast(s)->getExpr(); return nullptr; } int GeneratorExpr::loopCount() const { int cnt = 0; for (Stmt *i = loops;;) { if (auto sf = cast(i)) { i = sf->getSuite(); cnt++; } else if (auto si = cast(i)) { i = si->getIf(); cnt++; } else if (auto ss = cast(i)) { if (ss->empty()) break; i = ss->back(); } else break; } return cnt; } void GeneratorExpr::setFinalExpr(Expr *expr) { *(getFinalStmt()) = cache->N(expr); } void GeneratorExpr::setFinalStmt(Stmt *stmt) { *(getFinalStmt()) = stmt; } Stmt *GeneratorExpr::getFinalSuite() const { return loops; } Stmt **GeneratorExpr::getFinalStmt() { for (Stmt **i = &loops;;) { if (auto sf = cast(*i)) i = reinterpret_cast(&sf->suite); else if (auto si = cast(*i)) i = reinterpret_cast(&si->ifSuite); else if (auto ss = cast(*i)) { if (ss->empty()) return i; i = &(ss->back()); } else return i; } seqassert(false, "bad generator"); return nullptr; } void GeneratorExpr::formCompleteStmt(const std::vector &loops) { Stmt *final = nullptr; for (size_t i = loops.size(); i-- > 0;) { if (auto si = cast(loops[i])) si->ifSuite = SuiteStmt::wrap(final); else if (auto sf = cast(loops[i])) sf->suite = SuiteStmt::wrap(final); final = loops[i]; } this->loops = loops[0]; } IfExpr::IfExpr(Expr *cond, Expr *ifexpr, Expr *elsexpr) : AcceptorExtend(), cond(cond), ifexpr(ifexpr), elsexpr(elsexpr) {} IfExpr::IfExpr(const IfExpr &expr, bool clean) : AcceptorExtend(expr, clean), cond(ast::clone(expr.cond, clean)), ifexpr(ast::clone(expr.ifexpr, clean)), elsexpr(ast::clone(expr.elsexpr, clean)) { } std::string IfExpr::toString(int indent) const { return wrapType(fmt::format("if-expr {} {} {}", cond->toString(indent), ifexpr->toString(indent), elsexpr->toString(indent))); } UnaryExpr::UnaryExpr(std::string op, Expr *expr) : AcceptorExtend(), op(std::move(op)), expr(expr) {} 
/// Copy constructor used by clone(): the operand is cloned via ast::clone.
UnaryExpr::UnaryExpr(const UnaryExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), op(expr.op),
      expr(ast::clone(expr.expr, clean)) {}

/// S-expression form: (unary "<op>" <expr>), passed through wrapType.
std::string UnaryExpr::toString(int indent) const {
  return wrapType(fmt::format("unary \"{}\" {}", op, expr->toString(indent)));
}

/// Build a binary expression; `op` is moved into the node.
BinaryExpr::BinaryExpr(Expr *lexpr, std::string op, Expr *rexpr, bool inPlace)
    : AcceptorExtend(), op(std::move(op)), lexpr(lexpr), rexpr(rexpr),
      inPlace(inPlace) {}

/// Copy constructor used by clone(): both operands are cloned via ast::clone.
BinaryExpr::BinaryExpr(const BinaryExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), op(expr.op),
      lexpr(ast::clone(expr.lexpr, clean)),
      rexpr(ast::clone(expr.rexpr, clean)), inPlace(expr.inPlace) {}

/// S-expression form: (binary "<op>" <lhs> <rhs>[ #:in-place]).
std::string BinaryExpr::toString(int indent) const {
  return wrapType(fmt::format("binary \"{}\" {} {}{}", op,
                              lexpr->toString(indent), rexpr->toString(indent),
                              inPlace ? " #:in-place" : ""));
}

/// Build a chained comparison from (first, second) pairs; the vector is moved
/// into the node.
ChainBinaryExpr::ChainBinaryExpr(std::vector> exprs)
    : AcceptorExtend(), exprs(std::move(exprs)) {}

/// Copy constructor used by clone(): `first` is copied, `second` is cloned.
ChainBinaryExpr::ChainBinaryExpr(const ChainBinaryExpr &expr, bool clean)
    : AcceptorExtend(expr, clean) {
  for (auto &e : expr.exprs)
    exprs.emplace_back(e.first, ast::clone(e.second, clean));
}

/// S-expression form: (chain (<first> "<second>") ...).
/// NOTE(review): the quoted placeholder receives the operand here, whereas
/// PipeExpr::toString quotes the operator — confirm the order is intended.
std::string ChainBinaryExpr::toString(int indent) const {
  std::vector s;
  for (auto &i : exprs)
    s.push_back(fmt::format("({} \"{}\")", i.first, i.second->toString(indent)));
  return wrapType(fmt::format("chain {}", join(s, " ")));
}

/// Copy a pipe stage; the stage expression is cloned via ast::clone.
Pipe Pipe::clone(bool clean) const { return {op, ast::clone(expr, clean)}; }

/// Build a pipe expression and tag the ellipsis arguments of its stages.
/// For every stage whose expression casts to a call, each call argument that
/// casts to an ellipsis gets its mode set to EllipsisExpr::PIPE.
PipeExpr::PipeExpr(std::vector items)
    : AcceptorExtend(), Items(std::move(items)) {
  for (auto &i : *this) {
    if (auto call = cast(i.expr)) {
      for (auto &a : *call)
        if (auto el = cast(a.value))
          el->mode = EllipsisExpr::PIPE;
    }
  }
}

/// Copy constructor used by clone(): stages are cloned, inTypes is copied.
PipeExpr::PipeExpr(const PipeExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)),
      inTypes(expr.inTypes) {}

/// S-expression form: (pipe (<expr> "<op>") ...).
std::string PipeExpr::toString(int indent) const {
  std::vector s;
  for (auto &i : items)
    s.push_back(fmt::format("({} \"{}\")", i.expr->toString(indent), i.op));
  return wrapType(fmt::format("pipe {}", join(s, " ")));
}

/// Build an index expression (operands stored as given).
IndexExpr::IndexExpr(Expr *expr, Expr *index)
    : AcceptorExtend(), expr(expr), index(index) {}

/// Copy constructor used by clone(): both operands are cloned.
IndexExpr::IndexExpr(const IndexExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), expr(ast::clone(expr.expr, clean)),
      index(ast::clone(expr.index, clean)) {}

/// S-expression form: (index <expr> <index>).
std::string IndexExpr::toString(int indent) const {
  return wrapType(
      fmt::format("index {} {}", expr->toString(indent), index->toString(indent)));
}

/// Copy a call argument; the value is cloned via ast::clone.
CallArg CallArg::clone(bool clean) const {
  return CallArg{name, ast::clone(value, clean)};
}

/// Build a call argument with explicit source information.
CallArg::CallArg(const SrcInfo &info, std::string name, Expr *value)
    : name(std::move(name)), value(value) {
  setSrcInfo(info);
}

/// Build a call argument; source information is taken from `value` when it
/// is non-null.
CallArg::CallArg(std::string name, Expr *value)
    : name(std::move(name)), value(value) {
  if (value)
    setSrcInfo(value->getSrcInfo());
}

/// Build an unnamed call argument (empty name).
CallArg::CallArg(Expr *value) : CallArg("", value) {}

/// Copy constructor used by clone(): arguments and callee are cloned, the
/// `ordered` / `partial` flags are copied.
CallExpr::CallExpr(const CallExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)),
      expr(ast::clone(expr.expr, clean)), ordered(expr.ordered),
      partial(expr.partial) {}

/// Build a call from a callee and an argument list (moved into the node).
CallExpr::CallExpr(Expr *expr, std::vector args)
    : AcceptorExtend(), Items(std::move(args)), expr(expr), ordered(false),
      partial(false) {}

/// Build a call from a callee and a list of expressions; each non-null entry
/// becomes an unnamed argument, null entries are skipped.
CallExpr::CallExpr(Expr *expr, const std::vector &args)
    : AcceptorExtend(), Items({}), expr(expr), ordered(false), partial(false) {
  for (auto a : args)
    if (a)
      items.emplace_back("", a);
}

/// S-expression form: (call[-partial] <callee> [#:name '<name>] <arg> ...).
/// With indent >= 0 every argument starts on its own line, indented by
/// indent + 2 * INDENT_SIZE; otherwise arguments are space-separated.
std::string CallExpr::toString(int indent) const {
  std::vector s;
  auto pad = indent >= 0 ? ("\n" + std::string(indent + 2 * INDENT_SIZE, ' ')) : " ";
  for (auto &i : *this) {
    if (!i.name.empty())
      s.emplace_back(pad + fmt::format("#:name '{}", i.name));
    s.emplace_back(pad +
                   i.value->toString(indent >= 0 ? indent + 2 * INDENT_SIZE : -1));
  }
  return wrapType(fmt::format("call{} {}{}", partial ? "-partial" : "",
                              expr->toString(indent), join(s, "")));
}

/// Build a member access; `member` is moved into the node.
DotExpr::DotExpr(Expr *expr, std::string member)
    : AcceptorExtend(), expr(expr), member(std::move(member)) {}

/// Copy constructor used by clone(): the object expression is cloned.
DotExpr::DotExpr(const DotExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), expr(ast::clone(expr.expr, clean)),
      member(expr.member) {}

/// S-expression form: (dot <expr> '<member>).
std::string DotExpr::toString(int indent) const {
  return wrapType(fmt::format("dot {} '{}", expr->toString(indent), member));
}

/// Build a slice (bounds stored as given).
SliceExpr::SliceExpr(Expr *start, Expr *stop, Expr *step)
    : AcceptorExtend(), start(start), stop(stop), step(step) {}

/// Copy constructor used by clone(): all three bounds are cloned.
SliceExpr::SliceExpr(const SliceExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), start(ast::clone(expr.start, clean)),
      stop(ast::clone(expr.stop, clean)), step(ast::clone(expr.step, clean)) {}

/// S-expression form: (slice[ #:start s][ #:end e][ #:step t]); null bounds
/// are omitted.
std::string SliceExpr::toString(int indent) const {
  return wrapType(fmt::format(
      "slice{}{}{}", start ? fmt::format(" #:start {}", start->toString(indent)) : "",
      stop ? fmt::format(" #:end {}", stop->toString(indent)) : "",
      step ? fmt::format(" #:step {}", step->toString(indent)) : ""));
}

/// Build an ellipsis with the given mode.
EllipsisExpr::EllipsisExpr(EllipsisType mode) : AcceptorExtend(), mode(mode) {}

/// Copy constructor used by clone(): the mode is copied.
EllipsisExpr::EllipsisExpr(const EllipsisExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), mode(expr.mode) {}

/// S-expression form: (ellipsis[ #:pipe| #:partial]); any other mode prints
/// no suffix.
std::string EllipsisExpr::toString(int) const {
  return wrapType(fmt::format(
      "ellipsis{}", mode == PIPE ? " #:pipe" : (mode == PARTIAL ? " #:partial" : "")));
}

/// Build a lambda from its parameters (moved into the node) and body.
LambdaExpr::LambdaExpr(std::vector vars, Expr *expr)
    : AcceptorExtend(), Items(std::move(vars)), expr(expr) {}

/// Copy constructor used by clone(): parameters and body are cloned.
LambdaExpr::LambdaExpr(const LambdaExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)),
      expr(ast::clone(expr.expr, clean)) {}

/// S-expression form: (lambda (<param> ...) <body>).
std::string LambdaExpr::toString(int indent) const {
  std::vector as;
  for (auto &a : items)
    as.push_back(a.toString(indent));
  return wrapType(fmt::format("lambda ({}) {}", join(as, " "), expr->toString(indent)));
}

YieldExpr::YieldExpr() : AcceptorExtend() {}

YieldExpr::YieldExpr(const YieldExpr &expr, bool clean) : AcceptorExtend(expr, clean) {}

/// Fixed S-expression form; unlike the other nodes it is not passed through
/// wrapType.
std::string YieldExpr::toString(int) const { return "yield-expr"; }

/// Build an await expression; `transformed` starts out false.
AwaitExpr::AwaitExpr(Expr *expr)
    : AcceptorExtend(), expr(expr), transformed(false) {}

/// Copy constructor used by clone(): the operand is cloned, `transformed` is
/// copied.
AwaitExpr::AwaitExpr(const AwaitExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), expr(ast::clone(expr.expr, clean)),
      transformed(expr.transformed) {}

/// S-expression form: (await <expr>).
std::string AwaitExpr::toString(int indent) const {
  return wrapType(fmt::format("await {}", expr->toString(indent)));
}

/// Build an assignment expression (operands stored as given).
AssignExpr::AssignExpr(Expr *var, Expr *expr)
    : AcceptorExtend(), var(var), expr(expr) {}

/// Copy constructor used by clone(): both operands are cloned.
AssignExpr::AssignExpr(const AssignExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), var(ast::clone(expr.var, clean)),
      expr(ast::clone(expr.expr, clean)) {}

/// S-expression form: (assign-expr '<var> <expr>).
std::string AssignExpr::toString(int indent) const {
  return wrapType(
      fmt::format("assign-expr '{} {}", var->toString(indent), expr->toString(indent)));
}

/// Build a range expression (bounds stored as given).
RangeExpr::RangeExpr(Expr *start, Expr *stop)
    : AcceptorExtend(), start(start), stop(stop) {}

/// Copy constructor used by clone(): both bounds are cloned.
RangeExpr::RangeExpr(const RangeExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), start(ast::clone(expr.start, clean)),
      stop(ast::clone(expr.stop, clean)) {}

/// S-expression form: (range <start> <stop>).
std::string RangeExpr::toString(int indent) const {
  return wrapType(
      fmt::format("range {} {}", start->toString(indent), stop->toString(indent)));
}

/// Build a statement expression from a statement list (moved into the node).
StmtExpr::StmtExpr(std::vector stmts, Expr *expr)
    : AcceptorExtend(), Items(std::move(stmts)), expr(expr) {}

/// Build a statement expression with a single leading statement.
/// NOTE(review): `stmt` is pushed unconditionally, so a null `stmt` ends up
/// stored as an item — confirm callers never rely on toString() in that case.
StmtExpr::StmtExpr(Stmt *stmt, Expr *expr)
    : AcceptorExtend(), Items({}), expr(expr) {
  items.push_back(stmt);
}

/// Build a statement expression with two leading statements (both pushed
/// unconditionally, in order).
StmtExpr::StmtExpr(Stmt *stmt, Stmt *stmt2, Expr *expr)
    : AcceptorExtend(), Items({}), expr(expr) {
  items.push_back(stmt);
  items.push_back(stmt2);
}

/// Copy constructor used by clone(): statements and expression are cloned.
StmtExpr::StmtExpr(const StmtExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)),
      expr(ast::clone(expr.expr, clean)) {}

/// S-expression form: (stmt-expr <expr> (<stmt> ...)).
/// With indent >= 0 every statement starts on its own line, indented by
/// indent + 2 * INDENT_SIZE; otherwise statements are space-separated.
std::string StmtExpr::toString(int indent) const {
  auto pad = indent >= 0 ? ("\n" + std::string(indent + 2 * INDENT_SIZE, ' ')) : " ";
  std::vector s;
  s.reserve(items.size());
  for (auto &i : items)
    s.emplace_back(pad + i->toString(indent >= 0 ? indent + 2 * INDENT_SIZE : -1));
  return wrapType(
      fmt::format("stmt-expr {} ({})", expr->toString(indent), join(s, "")));
}

/// Build an instantiation from a base expression and its type parameters
/// (moved into the node).
InstantiateExpr::InstantiateExpr(Expr *expr, std::vector typeParams)
    : AcceptorExtend(), Items(std::move(typeParams)), expr(expr) {}

/// Build an instantiation with a single type parameter.
InstantiateExpr::InstantiateExpr(Expr *expr, Expr *typeParam)
    : AcceptorExtend(), Items({typeParam}), expr(expr) {}

/// Copy constructor used by clone(): parameters and base expression are cloned.
InstantiateExpr::InstantiateExpr(const InstantiateExpr &expr, bool clean)
    : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)),
      expr(ast::clone(expr.expr, clean)) {}

/// S-expression form: (instantiate <expr> <params>), with the parameters
/// rendered by combine().
std::string InstantiateExpr::toString(int indent) const {
  return wrapType(
      fmt::format("instantiate {} {}", expr->toString(indent), combine(items)));
}

/// @return true when `e` casts successfully and its getValue() equals `s`.
bool isId(Expr *e, const std::string &s) {
  auto ie = cast(e);
  return ie && ie->getValue() == s;
}

/// Classify a static-generic annotation.
/// When `e` matches the pattern built below (an "Static" / "Literal" head with
/// one identifier bound to `ie`), the kind is derived from that identifier via
/// types::Type::literalFromString; otherwise the result is LiteralKind::Runtime.
types::LiteralKind getStaticGeneric(Expr *e) {
  IdExpr *ie = nullptr;
  if (match(e, M(M(MOr("Static", "Literal")), MVar(ie)))) {
    return types::Type::literalFromString(ie->getValue());
  }
  return types::LiteralKind::Runtime;
}

// Storage for the per-class type tags; only the addresses are used
// (see nodeId()).
const char ASTNode::NodeId = 0;
const char Expr::NodeId = 0;

// ACCEPT_IMPL is defined earlier in this file (outside this section); one
// invocation per node class.
ACCEPT_IMPL(NoneExpr, ASTVisitor);
ACCEPT_IMPL(BoolExpr, ASTVisitor);
ACCEPT_IMPL(IntExpr, ASTVisitor);
ACCEPT_IMPL(FloatExpr, ASTVisitor);
ACCEPT_IMPL(StringExpr, ASTVisitor);
ACCEPT_IMPL(IdExpr, ASTVisitor);
ACCEPT_IMPL(StarExpr, ASTVisitor);
ACCEPT_IMPL(KeywordStarExpr, ASTVisitor); ACCEPT_IMPL(TupleExpr, ASTVisitor); ACCEPT_IMPL(ListExpr, ASTVisitor); ACCEPT_IMPL(SetExpr, ASTVisitor); ACCEPT_IMPL(DictExpr, ASTVisitor); ACCEPT_IMPL(GeneratorExpr, ASTVisitor); ACCEPT_IMPL(IfExpr, ASTVisitor); ACCEPT_IMPL(UnaryExpr, ASTVisitor); ACCEPT_IMPL(BinaryExpr, ASTVisitor); ACCEPT_IMPL(ChainBinaryExpr, ASTVisitor); ACCEPT_IMPL(PipeExpr, ASTVisitor); ACCEPT_IMPL(IndexExpr, ASTVisitor); ACCEPT_IMPL(CallExpr, ASTVisitor); ACCEPT_IMPL(DotExpr, ASTVisitor); ACCEPT_IMPL(SliceExpr, ASTVisitor); ACCEPT_IMPL(EllipsisExpr, ASTVisitor); ACCEPT_IMPL(LambdaExpr, ASTVisitor); ACCEPT_IMPL(YieldExpr, ASTVisitor); ACCEPT_IMPL(AwaitExpr, ASTVisitor); ACCEPT_IMPL(AssignExpr, ASTVisitor); ACCEPT_IMPL(RangeExpr, ASTVisitor); ACCEPT_IMPL(StmtExpr, ASTVisitor); ACCEPT_IMPL(InstantiateExpr, ASTVisitor); } // namespace codon::ast namespace tser { void operator<<(codon::ast::Expr *t, BinaryArchive &a) { using S = codon::PolymorphicSerializer; a.save(t != nullptr); if (t) { void *typ = const_cast(t->dynamicNodeId()); auto key = S::_serializers[typ]; a.save(key); S::save(key, t, a); } } void operator>>(codon::ast::Expr *&t, BinaryArchive &a) { using S = codon::PolymorphicSerializer; bool empty = a.load(); if (!empty) { const auto key = a.load(); S::load(key, t, a); } else { t = nullptr; } } } // namespace tser ================================================ FILE: codon/parser/ast/expr.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include #include #include "codon/parser/ast/node.h" #include "codon/parser/ast/types.h" #include "codon/parser/common.h" #include "codon/util/serialize.h" namespace codon::ast { #define ACCEPT(CLASS, VISITOR, ...) 
\ static const char NodeId; \ using AcceptorExtend::clone; \ using AcceptorExtend::accept; \ ASTNode *clone(bool c) const override; \ void accept(VISITOR &visitor) override; \ std::string toString(int) const override; \ friend class TypecheckVisitor; \ template friend struct CallbackASTVisitor; \ friend struct ReplacingCallbackASTVisitor; \ inline decltype(auto) match_members() const { return std::tie(__VA_ARGS__); } \ SERIALIZE(CLASS, BASE(Expr), ##__VA_ARGS__) // Forward declarations struct Stmt; /** * A Seq AST expression. * Each AST expression is intended to be instantiated as a shared_ptr. */ struct Expr : public AcceptorExtend { using base_type = Expr; Expr(); Expr(const Expr &) = default; Expr(const Expr &, bool); /// Get a node type. /// @return Type pointer or a nullptr if a type is not set. types::Type *getType() const { return type.get(); } void setType(const types::TypePtr &t) { type = t; } types::ClassType *getClassType() const; bool isDone() const { return done; } void setDone() { done = true; } Expr *getOrigExpr() const { return origExpr; } void setOrigExpr(Expr *orig) { origExpr = orig; } static const char NodeId; SERIALIZE(Expr, BASE(ASTNode), /*type,*/ done, origExpr); Expr *operator<<(types::Type *t); protected: /// Add a type to S-expression string. std::string wrapType(const std::string &sexpr) const; private: /// Type of the expression. nullptr by default. types::TypePtr type; /// Flag that indicates if all types in an expression are inferred (i.e. if a /// type-checking procedure was successful). bool done; /// Original (pre-transformation) expression Expr *origExpr; }; /// Function signature parameter helper node (name: type = defaultValue). 
struct Param : public codon::SrcObject {
  /// Parameter name.
  std::string name;
  /// Type annotation expression (nullptr by default).
  Expr *type;
  /// Default value expression (nullptr by default).
  Expr *defaultValue;
  /// Parameter category, queried through isValue() / isGeneric() /
  /// isHiddenGeneric().
  enum {
    Value,
    Generic,
    HiddenGeneric
  } status; // 1 for normal generic, 2 for hidden generic

  explicit Param(std::string name = "", Expr *type = nullptr,
                 Expr *defaultValue = nullptr, int generic = 0);
  /// Same as above, with explicit source information.
  explicit Param(const SrcInfo &info, std::string name = "", Expr *type = nullptr,
                 Expr *defaultValue = nullptr, int generic = 0);

  std::string getName() const { return name; }
  Expr *getType() const { return type; }
  Expr *getDefault() const { return defaultValue; }
  bool isValue() const { return status == Value; }
  bool isGeneric() const { return status == Generic; }
  bool isHiddenGeneric() const { return status == HiddenGeneric; }
  std::pair getNameWithStars() const;

  // NOTE(review): `status` is not part of the serialized member list —
  // confirm this is intended.
  SERIALIZE(Param, name, type, defaultValue);
  /// Copy this parameter; the flag is the `clean` flag used by the node
  /// clone() functions.
  Param clone(bool) const;
  /// S-expression form; the argument is the indentation level.
  std::string toString(int) const;
};

/// None expression.
/// @li None
struct NoneExpr : public AcceptorExtend {
  NoneExpr();
  NoneExpr(const NoneExpr &, bool);

  ACCEPT(NoneExpr, ASTVisitor);
};

/// Bool expression (value).
/// @li True
struct BoolExpr : public AcceptorExtend {
  explicit BoolExpr(bool value = false);
  BoolExpr(const BoolExpr &, bool);

  bool getValue() const;

  ACCEPT(BoolExpr, ASTVisitor, value);

private:
  /// The literal's value.
  bool value;
};

/// Int expression (value.suffix).
/// @li 12
/// @li 13u
/// @li 000_010b
struct IntExpr : public AcceptorExtend {
  /// Build from an already-parsed integer.
  explicit IntExpr(int64_t intValue = 0);
  /// Build from the literal's text and optional suffix.
  explicit IntExpr(const std::string &value, std::string suffix = "");
  IntExpr(const IntExpr &, bool);

  bool hasStoredValue() const;
  int64_t getValue() const;
  std::pair getRawData() const;

  ACCEPT(IntExpr, ASTVisitor, value, suffix, intValue);

private:
  /// Expression value is stored as a string that is parsed during typechecking.
  std::string value;
  /// Number suffix (e.g. "u" for "123u").
  std::string suffix;
  /// Parsed value and sign for "normal" 64-bit integers.
  std::optional intValue;
};

/// Float expression (value.suffix).
/// @li 12.1 /// @li 13.15z /// @li e-12 struct FloatExpr : public AcceptorExtend { explicit FloatExpr(double floatValue = 0.0); explicit FloatExpr(const std::string &value, std::string suffix = ""); FloatExpr(const FloatExpr &, bool); bool hasStoredValue() const; double getValue() const; std::pair getRawData() const; ACCEPT(FloatExpr, ASTVisitor, value, suffix, floatValue); private: /// Expression value is stored as a string that is parsed during typechecking. std::string value; /// Number suffix (e.g. "u" for "123u"). std::string suffix; /// Parsed value for 64-bit floats. std::optional floatValue; }; /// String expression (prefix"value"). /// @li s'ACGT' /// @li "fff" struct StringExpr : public AcceptorExtend { struct FormatSpec { std::string text; std::string conversion; std::string spec; SERIALIZE(FormatSpec, text, conversion, spec); }; // Vector of {value, prefix} strings. struct String : public SrcObject { std::string value; std::string prefix; Expr *expr; FormatSpec format; explicit String(std::string v, std::string p = "", Expr *e = nullptr) : value(std::move(v)), prefix(std::move(p)), expr(e), format() {} SERIALIZE(String, value, prefix, expr, format); }; explicit StringExpr(std::string value = "", std::string prefix = ""); explicit StringExpr(std::vector strings); StringExpr(const StringExpr &, bool); std::string getValue() const; bool isSimple() const; ACCEPT(StringExpr, ASTVisitor, strings); private: std::vector strings; auto begin() { return strings.begin(); } auto end() { return strings.end(); } friend class ScopingVisitor; }; /// Identifier expression (value). struct IdExpr : public AcceptorExtend { explicit IdExpr(std::string value = ""); IdExpr(const IdExpr &, bool); std::string getValue() const { return value; } ACCEPT(IdExpr, ASTVisitor, value); private: std::string value; void setValue(const std::string &s) { value = s; } friend class ScopingVisitor; }; /// Star (unpacking) expression (*what). 
/// @li *args
struct StarExpr : public AcceptorExtend {
  explicit StarExpr(Expr *what = nullptr);
  StarExpr(const StarExpr &, bool);

  /// @return The unpacked expression.
  Expr *getExpr() const { return expr; }

  ACCEPT(StarExpr, ASTVisitor, expr);

private:
  Expr *expr;
};

/// KeywordStar (unpacking) expression (**what).
/// @li **kwargs
struct KeywordStarExpr : public AcceptorExtend {
  explicit KeywordStarExpr(Expr *what = nullptr);
  KeywordStarExpr(const KeywordStarExpr &, bool);

  /// @return The unpacked expression.
  Expr *getExpr() const { return expr; }

  ACCEPT(KeywordStarExpr, ASTVisitor, expr);

private:
  Expr *expr;
};

/// Tuple expression ((items...)).
/// @li (1, a)
struct TupleExpr : public AcceptorExtend, Items {
  explicit TupleExpr(std::vector items = {});
  TupleExpr(const TupleExpr &, bool);

  ACCEPT(TupleExpr, ASTVisitor, items);
};

/// List expression ([items...]).
/// @li [1, 2]
struct ListExpr : public AcceptorExtend, Items {
  explicit ListExpr(std::vector items = {});
  ListExpr(const ListExpr &, bool);

  ACCEPT(ListExpr, ASTVisitor, items);
};

/// Set expression ({items...}).
/// @li {1, 2}
struct SetExpr : public AcceptorExtend, Items {
  explicit SetExpr(std::vector items = {});
  SetExpr(const SetExpr &, bool);

  ACCEPT(SetExpr, ASTVisitor, items);
};

/// Dictionary expression ({(key: value)...}).
/// Each (key, value) pair is stored as a TupleExpr.
/// @li {'s': 1, 't': 2}
struct DictExpr : public AcceptorExtend, Items {
  explicit DictExpr(std::vector items = {});
  DictExpr(const DictExpr &, bool);

  ACCEPT(DictExpr, ASTVisitor, items);
};

/// Generator or comprehension expression [(expr (loops...))].
/// @li [i for i in j]
/// @li (f + 1 for j in k if j for f in j)
struct GeneratorExpr : public AcceptorExtend {
  /// Generator kind: plain generator, or list / set / tuple / dict
  /// comprehension.
  enum GeneratorKind {
    Generator,
    ListGenerator,
    SetGenerator,
    TupleGenerator,
    DictGenerator
  };

  /// Empty generator: kind Generator, no loop statement.
  GeneratorExpr() : kind(Generator), loops(nullptr) {}
  GeneratorExpr(Cache *cache, GeneratorKind kind, Expr *expr, std::vector loops);
  /// Overload taking a separate key expression.
  GeneratorExpr(Cache *cache, Expr *key, Expr *expr, std::vector loops);
  GeneratorExpr(const GeneratorExpr &, bool);

  int loopCount() const;
  Stmt *getFinalSuite() const;
  Expr *getFinalExpr();

  ACCEPT(GeneratorExpr, ASTVisitor, kind, loops);

private:
  GeneratorKind kind;
  /// Root statement of the loop nest.
  Stmt *loops;

  Stmt **getFinalStmt();
  void setFinalExpr(Expr *);
  void setFinalStmt(Stmt *);
  void formCompleteStmt(const std::vector &);
  friend class TranslateVisitor;
};

/// Conditional expression [ifexpr if cond else elsexpr].
/// @li 1 if a else 2
struct IfExpr : public AcceptorExtend {
  explicit IfExpr(Expr *cond = nullptr, Expr *ifexpr = nullptr,
                  Expr *elsexpr = nullptr);
  IfExpr(const IfExpr &, bool);

  Expr *getCond() const { return cond; }
  Expr *getIf() const { return ifexpr; }
  Expr *getElse() const { return elsexpr; }

  ACCEPT(IfExpr, ASTVisitor, cond, ifexpr, elsexpr);

private:
  Expr *cond, *ifexpr, *elsexpr;
};

/// Unary expression [op expr].
/// @li -56
struct UnaryExpr : public AcceptorExtend {
  explicit UnaryExpr(std::string op = "", Expr *expr = nullptr);
  UnaryExpr(const UnaryExpr &, bool);

  /// @return The operator text (by value).
  std::string getOp() const { return op; }
  Expr *getExpr() const { return expr; }

  ACCEPT(UnaryExpr, ASTVisitor, op, expr);

private:
  std::string op;
  Expr *expr;
};

/// Binary expression [lexpr op rexpr].
/// @li 1 + 2 /// @li 3 or 4 struct BinaryExpr : public AcceptorExtend { explicit BinaryExpr(Expr *lexpr = nullptr, std::string op = "", Expr *rexpr = nullptr, bool inPlace = false); BinaryExpr(const BinaryExpr &, bool); std::string getOp() const { return op; } void setOp(const std::string &o) { op = o; } Expr *getLhs() const { return lexpr; } Expr *getRhs() const { return rexpr; } bool isInPlace() const { return inPlace; } ACCEPT(BinaryExpr, ASTVisitor, op, lexpr, rexpr); private: std::string op; Expr *lexpr, *rexpr; /// True if an expression modifies lhs in-place (e.g. a += b). bool inPlace; }; /// Chained binary expression. /// @li 1 <= x <= 2 struct ChainBinaryExpr : public AcceptorExtend { explicit ChainBinaryExpr(std::vector> exprs = {}); ChainBinaryExpr(const ChainBinaryExpr &, bool); ACCEPT(ChainBinaryExpr, ASTVisitor, exprs); private: std::vector> exprs; }; struct Pipe { std::string op; Expr *expr; SERIALIZE(Pipe, op, expr); Pipe clone(bool) const; }; /// Pipe expression [(op expr)...]. /// op is either "" (only the first item), "|>" or "||>". /// @li a |> b ||> c struct PipeExpr : public AcceptorExtend, Items { explicit PipeExpr(std::vector items = {}); PipeExpr(const PipeExpr &, bool); ACCEPT(PipeExpr, ASTVisitor, items); private: /// Output type of a "prefix" pipe ending at the index position. /// Example: for a |> b |> c, inTypes[1] is typeof(a |> b). std::vector inTypes; }; /// Index expression (expr[index]). 
/// @li a[5] struct IndexExpr : public AcceptorExtend { explicit IndexExpr(Expr *expr = nullptr, Expr *index = nullptr); IndexExpr(const IndexExpr &, bool); Expr *getExpr() const { return expr; } Expr *getIndex() const { return index; } ACCEPT(IndexExpr, ASTVisitor, expr, index); private: Expr *expr, *index; }; struct CallArg : public codon::SrcObject { std::string name; Expr *value; explicit CallArg(std::string name = "", Expr *value = nullptr); CallArg(const SrcInfo &info, std::string name, Expr *value); CallArg(Expr *value); std::string getName() const { return name; } Expr *getExpr() const { return value; } operator Expr *() const { return value; } SERIALIZE(CallArg, name, value); CallArg clone(bool) const; }; /// Call expression (expr((name=value)...)). /// @li a(1, b=2) struct CallExpr : public AcceptorExtend, Items { /// Each argument can have a name (e.g. foo(1, b=5)) explicit CallExpr(Expr *expr = nullptr, std::vector args = {}); /// Convenience constructors CallExpr(Expr *expr, const std::vector &args); template CallExpr(Expr *expr, Expr *arg, Ts... args) : CallExpr(expr, std::vector{arg, args...}) {} CallExpr(const CallExpr &, bool); Expr *getExpr() const { return expr; } bool isOrdered() const { return ordered; } bool isPartial() const { return partial; } ACCEPT(CallExpr, ASTVisitor, expr, items, ordered, partial); private: Expr *expr; /// True if type-checker has processed and re-ordered args. bool ordered; /// True if the call is partial bool partial = false; }; /// Dot (access) expression (expr.member). /// @li a.b struct DotExpr : public AcceptorExtend { DotExpr() : expr(nullptr), member() {} DotExpr(Expr *expr, std::string member); DotExpr(const DotExpr &, bool); Expr *getExpr() const { return expr; } std::string getMember() const { return member; } ACCEPT(DotExpr, ASTVisitor, expr, member); private: Expr *expr; std::string member; }; /// Slice expression (st:stop:step). 
/// @li 1:10:3 /// @li s::-1 /// @li ::: struct SliceExpr : public AcceptorExtend { explicit SliceExpr(Expr *start = nullptr, Expr *stop = nullptr, Expr *step = nullptr); SliceExpr(const SliceExpr &, bool); Expr *getStart() const { return start; } Expr *getStop() const { return stop; } Expr *getStep() const { return step; } ACCEPT(SliceExpr, ASTVisitor, start, stop, step); private: /// Any of these can be nullptr to account for partial slices. Expr *start, *stop, *step; }; /// Ellipsis expression. /// @li ... struct EllipsisExpr : public AcceptorExtend { /// True if this is a target partial argument within a PipeExpr. /// If true, this node will be handled differently during the type-checking stage. enum EllipsisType { PIPE, PARTIAL, STANDALONE }; explicit EllipsisExpr(EllipsisType mode = STANDALONE); EllipsisExpr(const EllipsisExpr &, bool); EllipsisType getMode() const { return mode; } bool isStandalone() const { return mode == STANDALONE; } bool isPipe() const { return mode == PIPE; } bool isPartial() const { return mode == PARTIAL; } ACCEPT(EllipsisExpr, ASTVisitor, mode); private: EllipsisType mode; friend struct PipeExpr; }; /// Lambda expression (lambda (vars)...: expr). /// @li lambda a, b: a + b struct LambdaExpr : public AcceptorExtend, Items { explicit LambdaExpr(std::vector vars = {}, Expr *expr = nullptr); LambdaExpr(const LambdaExpr &, bool); Expr *getExpr() const { return expr; } ACCEPT(LambdaExpr, ASTVisitor, expr, items); private: Expr *expr; }; /// Yield (send to generator) expression. /// @li (yield) struct YieldExpr : public AcceptorExtend { YieldExpr(); YieldExpr(const YieldExpr &, bool); ACCEPT(YieldExpr, ASTVisitor); }; /// Await expression (await expr). 
/// @li await a
struct AwaitExpr : public AcceptorExtend {
  explicit AwaitExpr(Expr *expr);
  AwaitExpr(const AwaitExpr &, bool);

  /// @return The awaited expression.
  Expr *getExpr() const { return expr; }

  ACCEPT(AwaitExpr, ASTVisitor, expr);

private:
  Expr *expr;
  // True if a statement was transformed during type-checking stage
  // (to avoid setting up __await__ multiple times).
  // Not part of the ACCEPT member list above.
  bool transformed;
};

/// Assignment (walrus) expression (var := expr).
/// @li a := 5 + 3
struct AssignExpr : public AcceptorExtend {
  explicit AssignExpr(Expr *var = nullptr, Expr *expr = nullptr);
  AssignExpr(const AssignExpr &, bool);

  Expr *getVar() const { return var; }
  Expr *getExpr() const { return expr; }

  ACCEPT(AssignExpr, ASTVisitor, var, expr);

private:
  Expr *var, *expr;
};

/// Range expression (start ... end).
/// Used only in match-case statements.
/// @li 1 ... 2
struct RangeExpr : public AcceptorExtend {
  explicit RangeExpr(Expr *start = nullptr, Expr *stop = nullptr);
  RangeExpr(const RangeExpr &, bool);

  Expr *getStart() const { return start; }
  Expr *getStop() const { return stop; }

  ACCEPT(RangeExpr, ASTVisitor, start, stop);

private:
  Expr *start, *stop;
};

/// The following nodes are created during typechecking.

/// Statement expression (stmts...; expr).
/// Statements are evaluated only if the expression is evaluated
/// (to support short-circuiting).
/// @li (a = 1; b = 2; a + b)
struct StmtExpr : public AcceptorExtend, Items {
  explicit StmtExpr(Stmt *stmt = nullptr, Expr *expr = nullptr);
  StmtExpr(std::vector stmts, Expr *expr);
  StmtExpr(Stmt *stmt, Stmt *stmt2, Expr *expr);
  StmtExpr(const StmtExpr &, bool);

  /// @return The trailing expression.
  Expr *getExpr() const { return expr; }

  ACCEPT(StmtExpr, ASTVisitor, expr, items);

private:
  Expr *expr;
};

/// Instantiation expression: a base expression plus its type parameters
/// (the parameters are the inherited items).
struct InstantiateExpr : public AcceptorExtend, Items {
  explicit InstantiateExpr(Expr *expr = nullptr, std::vector typeParams = {});
  /// Convenience constructor for a single type parameter.
  InstantiateExpr(Expr *expr, Expr *typeParam);
  InstantiateExpr(const InstantiateExpr &, bool);

  /// @return The expression being instantiated.
  Expr *getExpr() const { return expr; }

  ACCEPT(InstantiateExpr, ASTVisitor, expr, items);

private:
  Expr *expr;
};

#undef ACCEPT

/// Declared here, defined outside this header.
bool isId(Expr *e, const std::string &s);
types::LiteralKind getStaticGeneric(Expr *e);

} // namespace codon::ast

/// fmt support for CallArg: prints "(<name> = <value>)", or "(<value>)" when
/// the name is empty; a null value prints as "-".
template <> struct fmt::formatter : fmt::formatter {
  template
  auto format(const codon::ast::CallArg &p, FormatContext &ctx) const
      -> decltype(ctx.out()) {
    return fmt::format_to(ctx.out(), "({}{})",
                          p.name.empty() ? "" : fmt::format("{} = ", p.name),
                          p.value ? p.value->toString(0) : "-");
  }
};

/// fmt support for Param: prints p.toString(0).
template <> struct fmt::formatter : fmt::formatter {
  template
  auto format(const codon::ast::Param &p, FormatContext &ctx) const
      -> decltype(ctx.out()) {
    return fmt::format_to(ctx.out(), "{}", p.toString(0));
  }
};

namespace tser {
/// Archive save / load hooks for expression pointers.
void operator<<(codon::ast::Expr *t, BinaryArchive &a);
void operator>>(codon::ast::Expr *&t, BinaryArchive &a);
} // namespace tser

================================================
FILE: codon/parser/ast/node.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include
#include

#include "codon/cir/base.h"

namespace codon::ast {

using ir::cast;

// Forward declarations
struct Cache;
struct ASTVisitor;

/// Common base of all AST nodes; extends ir::Node with a cache pointer,
/// S-expression printing, cloning and visitor dispatch.
struct ASTNode : public ir::Node {
  static const char NodeId;
  using ir::Node::Node;

  /// See LLVM documentation.
  static const void *nodeId() { return &NodeId; }
  const void *dynamicNodeId() const override { return &NodeId; }
  /// See LLVM documentation.
  virtual bool isConvertible(const void *other) const override {
    return other == nodeId() || ir::Node::isConvertible(other);
  }

  /// Cache associated with this node; nullptr by default.
  Cache *cache = nullptr;

  ASTNode() = default;
  ASTNode(const ASTNode &) = default;
  virtual ~ASTNode() = default;

  /// Convert a node to an S-expression.
  virtual std::string toString(int) const = 0;
  /// Same as toString(-1).
  virtual std::string toString() const { return toString(-1); }

  /// Deep copy a node.
  virtual ASTNode *clone(bool clean) const = 0;
  /// Same as clone(false).
  ASTNode *clone() const { return clone(false); }

  using ir::Node::accept;
  /// Accept an AST visitor.
  virtual void accept(ASTVisitor &visitor) {}

  /// Allow pretty-printing to C++ streams.
  friend std::ostream &operator<<(std::ostream &out, const ASTNode &expr) {
    return out << expr.toString();
  }

  /// Store `value` under `key` in `attributes`, replacing any existing entry.
  void setAttribute(int key, std::unique_ptr value) {
    attributes[key] = std::move(value);
  }
  /// Overload that builds the stored attribute from a string.
  void setAttribute(int key, const std::string &value) {
    attributes[key] = std::make_unique(value);
  }
  /// Overload that builds the stored attribute from an integer.
  void setAttribute(int key, int64_t value) {
    attributes[key] = std::make_unique(value);
  }
  /// Overload that stores a default-constructed attribute.
  void setAttribute(int key) { attributes[key] = std::make_unique(); }

  // NOTE(review): the returned tuple refers to the local `a`, which no longer
  // exists once this function returns — confirm nothing reads or writes
  // through it.
  inline decltype(auto) members() {
    int a = 0;
    return std::tie(a);
  }
};

/// Forward to E(e, srcInfo, args...) using the source info of node `o`.
template
void E(error::Error e, ASTNode *o, const TA &...args) {
  E(e, o->getSrcInfo(), args...);
}
/// Reference overload of the above.
template
void E(error::Error e, const ASTNode &o, const TA &...args) {
  E(e, o.getSrcInfo(), args...);
}

/// Mixin that gives `Derived` its own type tag on top of `Parent`: nodeId()
/// and dynamicNodeId() return the address of Derived::NodeId, and
/// isConvertible() accepts that tag or anything Parent accepts.
template
class AcceptorExtend : public Parent {
public:
  using Parent::Parent;

  /// See LLVM documentation.
  static const void *nodeId() { return &Derived::NodeId; }
  const void *dynamicNodeId() const override { return &Derived::NodeId; }
  /// See LLVM documentation.
  virtual bool isConvertible(const void *other) const override {
    return other == nodeId() || Parent::isConvertible(other);
  }
};

/// Mixin that owns a vector of `T` and exposes container-style access to it
/// (indexing, iteration, size/empty, front/back).
template
struct Items {
  /// Takes ownership of `items` (moved in).
  explicit Items(std::vector items) : items(std::move(items)) {}

  // Unchecked element access.
  const T &operator[](size_t i) const { return items[i]; }
  T &operator[](size_t i) { return items[i]; }
  auto begin() { return items.begin(); }
  auto end() { return items.end(); }
  auto begin() const { return items.begin(); }
  auto end() const { return items.end(); }
  auto size() const { return items.size(); }
  bool empty() const { return items.empty(); }
  const T &front() const { return items.front(); }
  const T &back() const { return items.back(); }
  T &front() { return items.front(); }
  T &back() { return items.back(); }

protected:
  std::vector items;
};

} // namespace codon::ast

// Formatter specialization that derives from fmt::ostream_formatter.
template
struct fmt::formatter, char>> : fmt::ostream_formatter {};

================================================
FILE: codon/parser/ast/stmt.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include "stmt.h"

// NOTE(review): the standard-header names and all template arguments in this file
// were reconstructed from usage. Spots that cannot be inferred from the local context
// alone are marked NOTE(review) -- confirm them against the repository.
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "codon/parser/cache.h"
#include "codon/parser/match.h"
#include "codon/parser/visitors/visitor.h"

/// Emits the per-node boilerplate declared by ACCEPT in stmt.h: a clone() that goes
/// through the node's cache, the visitor entry point, and the NodeId storage whose
/// address identifies the node type.
#define ACCEPT_IMPL(T, X)                                                              \
  ASTNode *T::clone(bool clean) const { return cache->N<T>(*this, clean); }            \
  void T::accept(X &visitor) { visitor.visit(this); }                                  \
  const char T::NodeId = 0;

using namespace codon::error;
using namespace codon::matcher;

namespace codon::ast {

/* ------------------------------------------------------------------ Stmt */

Stmt::Stmt() : AcceptorExtend(), done(false) {}
Stmt::Stmt(const Stmt &stmt) : AcceptorExtend(stmt), done(stmt.done) {}
Stmt::Stmt(const codon::SrcInfo &s) : AcceptorExtend(), done(false) { setSrcInfo(s); }
/// Cloning constructor. A "clean" clone resets the done flag.
Stmt::Stmt(const Stmt &stmt, bool clean) : AcceptorExtend(stmt), done(stmt.done) {
  if (clean)
    done = false;
}
/// Hook applied to every toString() result; the base implementation is the identity.
std::string Stmt::wrapStmt(const std::string &s) const { return s; }

/* ------------------------------------------------------------- SuiteStmt */

SuiteStmt::SuiteStmt(std::vector<Stmt *> stmts)
    : AcceptorExtend(), Items(std::move(stmts)) {}
SuiteStmt::SuiteStmt(const SuiteStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)) {}
/// Renders the suite as an s-expression. indent == -1 yields an empty string; null
/// items are skipped and done items get a "*" marker.
std::string SuiteStmt::toString(int indent) const {
  if (indent == -1)
    return "";
  std::string pad = indent >= 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " ";
  std::string s;
  for (size_t i = 0; i < size(); i++)
    if (items[i]) {
      auto is = items[i]->toString(indent >= 0 ? indent + INDENT_SIZE : -1);
      if (items[i]->isDone())
        is.insert(findStar(is), "*");
      s += (i ? pad : "") + is;
    }
  return wrapStmt(fmt::format("({}suite{})", (isDone() ? "*" : ""),
                              (s.empty() ? s : " " + pad + s)));
}
/// Splices the statements of directly nested suites into this suite (one level) and
/// drops null entries.
void SuiteStmt::flatten() {
  std::vector<Stmt *> ns;
  for (auto *s : items) {
    if (!s)
      continue;
    if (auto *nested = cast<SuiteStmt>(s)) {
      for (auto *ss : *nested)
        ns.push_back(ss);
    } else {
      ns.push_back(s);
    }
  }
  items = std::move(ns);
}
/// Appends a non-null statement and clears the done flag.
void SuiteStmt::addStmt(Stmt *s) {
  if (s) {
    items.push_back(s);
    done = false;
  }
}
/// Returns @p s itself if it is null or already a suite; otherwise wraps it in a new
/// single-statement suite created through the statement's cache.
SuiteStmt *SuiteStmt::wrap(Stmt *s) {
  if (s && !cast<SuiteStmt>(s))
    return s->cache->NS<SuiteStmt>(s, s);
  return static_cast<SuiteStmt *>(s);
}

/* -------------------------------------------------- Break / Continue */

BreakStmt::BreakStmt(const BreakStmt &stmt, bool clean) : AcceptorExtend(stmt, clean) {}
std::string BreakStmt::toString(int indent) const { return wrapStmt("(break)"); }

ContinueStmt::ContinueStmt(const ContinueStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean) {}
std::string ContinueStmt::toString(int indent) const { return wrapStmt("(continue)"); }

/* -------------------------------------------------------------- ExprStmt */

ExprStmt::ExprStmt(Expr *expr) : AcceptorExtend(), expr(expr) {}
ExprStmt::ExprStmt(const ExprStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {}
std::string ExprStmt::toString(int indent) const {
  return wrapStmt(fmt::format("(expr {})", expr->toString(indent)));
}

/* ------------------------------------------------------------ AssignStmt */

AssignStmt::AssignStmt(Expr *lhs, Expr *rhs, Expr *type, UpdateMode update)
    : AcceptorExtend(), lhs(lhs), rhs(rhs), type(type), update(update) {}
AssignStmt::AssignStmt(const AssignStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), lhs(ast::clone(stmt.lhs, clean)),
      rhs(ast::clone(stmt.rhs, clean)), type(ast::clone(stmt.type, clean)),
      update(stmt.update) {}
std::string AssignStmt::toString(int indent) const {
  return wrapStmt(
      fmt::format("({} {}{}{})", update != Assign ? "update" : "assign",
                  lhs->toString(indent), rhs ? " " + rhs->toString(indent) : "",
                  type ? fmt::format(" #:type {}", type->toString(indent)) : ""));
}

/* --------------------------------------------------------------- DelStmt */

DelStmt::DelStmt(Expr *expr) : AcceptorExtend(), expr(expr) {}
DelStmt::DelStmt(const DelStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {}
std::string DelStmt::toString(int indent) const {
  return wrapStmt(fmt::format("(del {})", expr->toString(indent)));
}

/* ------------------------------------------------------------- PrintStmt */

PrintStmt::PrintStmt(std::vector<Expr *> items, bool noNewline)
    : AcceptorExtend(), Items(std::move(items)), noNewline(noNewline) {}
PrintStmt::PrintStmt(const PrintStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)),
      noNewline(stmt.noNewline) {}
std::string PrintStmt::toString(int indent) const {
  return wrapStmt(
      fmt::format("(print {}{})", noNewline ? "#:inline " : "", combine(items)));
}

/* ------------------------------------------------------ Return / Yield */

ReturnStmt::ReturnStmt(Expr *expr) : AcceptorExtend(), expr(expr) {}
ReturnStmt::ReturnStmt(const ReturnStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {}
std::string ReturnStmt::toString(int indent) const {
  return wrapStmt(expr ? fmt::format("(return {})", expr->toString(indent))
                       : "(return)");
}

YieldStmt::YieldStmt(Expr *expr) : AcceptorExtend(), expr(expr) {}
YieldStmt::YieldStmt(const YieldStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {}
std::string YieldStmt::toString(int indent) const {
  return wrapStmt(expr ? fmt::format("(yield {})", expr->toString(indent)) : "(yield)");
}

/* ------------------------------------------------------------ AssertStmt */

AssertStmt::AssertStmt(Expr *expr, Expr *message)
    : AcceptorExtend(), expr(expr), message(message) {}
AssertStmt::AssertStmt(const AssertStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)),
      message(ast::clone(stmt.message, clean)) {}
std::string AssertStmt::toString(int indent) const {
  // The message is separated from the condition by a space; previously the two
  // sub-expressions were printed glued together.
  return wrapStmt(fmt::format("(assert {}{})", expr->toString(indent),
                              message ? " " + message->toString(indent) : ""));
}

/* ------------------------------------------------------------- WhileStmt */

WhileStmt::WhileStmt(Expr *cond, Stmt *suite, Stmt *elseSuite)
    : AcceptorExtend(), cond(cond), suite(SuiteStmt::wrap(suite)),
      elseSuite(SuiteStmt::wrap(elseSuite)) {}
// NOTE(review): gotoVar is not carried over to the clone although it is a serialized
// member; left unchanged here -- confirm whether dropping it is intentional.
WhileStmt::WhileStmt(const WhileStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), cond(ast::clone(stmt.cond, clean)),
      suite(ast::clone(stmt.suite, clean)),
      elseSuite(ast::clone(stmt.elseSuite, clean)) {}
std::string WhileStmt::toString(int indent) const {
  if (indent == -1)
    return wrapStmt(fmt::format("(while {})", cond->toString(indent)));
  std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " ";
  if (elseSuite && elseSuite->firstInBlock()) {
    return wrapStmt(
        fmt::format("(while-else {}{}{}{}{})", cond->toString(indent), pad,
                    suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), pad,
                    elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)));
  } else {
    return wrapStmt(
        fmt::format("(while {}{}{})", cond->toString(indent), pad,
                    suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)));
  }
}

/* --------------------------------------------------------------- ForStmt */

ForStmt::ForStmt(Expr *var, Expr *iter, Stmt *suite, Stmt *elseSuite, Expr *decorator,
                 std::vector<CallArg> ompArgs, bool async)
    : AcceptorExtend(), var(var), iter(iter), suite(SuiteStmt::wrap(suite)),
      elseSuite(SuiteStmt::wrap(elseSuite)), decorator(decorator),
      ompArgs(std::move(ompArgs)), async(async), wrapped(false), flat(false) {}
ForStmt::ForStmt(const ForStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), var(ast::clone(stmt.var, clean)),
      iter(ast::clone(stmt.iter, clean)), suite(ast::clone(stmt.suite, clean)),
      elseSuite(ast::clone(stmt.elseSuite, clean)),
      decorator(ast::clone(stmt.decorator, clean)),
      ompArgs(ast::clone(stmt.ompArgs, clean)), async(stmt.async),
      wrapped(stmt.wrapped), flat(stmt.flat) {}
std::string ForStmt::toString(int indent) const {
  auto vs = var->toString(indent);
  if (indent == -1)
    return wrapStmt(fmt::format("(for {} {})", vs, iter->toString(indent)));

  std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " ";
  std::string attr;
  if (decorator)
    attr += " " + decorator->toString(indent);
  if (!attr.empty())
    attr = " #:attr" + attr;
  if (elseSuite && elseSuite->firstInBlock()) {
    return wrapStmt(
        fmt::format("(for-else {} {}{}{}{}{}{})", vs, iter->toString(indent), attr, pad,
                    suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), pad,
                    elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)));
  } else {
    return wrapStmt(
        fmt::format("(for {} {}{}{}{})", vs, iter->toString(indent), attr, pad,
                    suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)));
  }
}

/* ---------------------------------------------------------------- IfStmt */

IfStmt::IfStmt(Expr *cond, Stmt *ifSuite, Stmt *elseSuite)
    : AcceptorExtend(), cond(cond), ifSuite(SuiteStmt::wrap(ifSuite)),
      elseSuite(SuiteStmt::wrap(elseSuite)) {}
IfStmt::IfStmt(const IfStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), cond(ast::clone(stmt.cond, clean)),
      ifSuite(ast::clone(stmt.ifSuite, clean)),
      elseSuite(ast::clone(stmt.elseSuite, clean)) {}
std::string IfStmt::toString(int indent) const {
  if (indent == -1)
    return wrapStmt(fmt::format("(if {})", cond->toString(indent)));
  std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " ";
  return wrapStmt(fmt::format(
      "(if {}{}{}{})", cond->toString(indent), pad,
      ifSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1),
      elseSuite ? pad + elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)
                : ""));
}

/* ------------------------------------------------- MatchCase / MatchStmt */

MatchCase::MatchCase(Expr *pattern, Expr *guard, Stmt *suite)
    : pattern(pattern), guard(guard), suite(SuiteStmt::wrap(suite)) {}
MatchCase MatchCase::clone(bool clean) const {
  return {ast::clone(pattern, clean), ast::clone(guard, clean),
          ast::clone(suite, clean)};
}

MatchStmt::MatchStmt(Expr *expr, std::vector<MatchCase> cases)
    : AcceptorExtend(), Items(std::move(cases)), expr(expr) {}
MatchStmt::MatchStmt(const MatchStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)),
      expr(ast::clone(stmt.expr, clean)) {}
std::string MatchStmt::toString(int indent) const {
  if (indent == -1)
    return wrapStmt(fmt::format("(match {})", expr->toString(indent)));
  std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " ";
  std::string padExtra = indent > 0 ? std::string(INDENT_SIZE, ' ') : "";
  std::vector<std::string> s;
  for (auto &c : items)
    // NOTE(review): `-1 * 2` evaluates to -2 (reached only for indent < -1); kept as
    // written -- confirm the intended value.
    s.push_back(fmt::format(
        "(case {}{}{}{})", c.pattern->toString(indent),
        c.guard ? " #:guard " + c.guard->toString(indent) : "", pad + padExtra,
        c.suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1 * 2)));
  return wrapStmt(
      fmt::format("(match {}{}{})", expr->toString(indent), pad, join(s, pad)));
}

/* ------------------------------------------------------------ ImportStmt */

ImportStmt::ImportStmt(Expr *from, Expr *what, std::vector<Param> args, Expr *ret,
                       std::string as, size_t dots, bool isFunction)
    : AcceptorExtend(), from(from), what(what), as(std::move(as)), dots(dots),
      args(std::move(args)), ret(ret), isFunction(isFunction) {}
ImportStmt::ImportStmt(const ImportStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), from(ast::clone(stmt.from, clean)),
      what(ast::clone(stmt.what, clean)), as(stmt.as), dots(stmt.dots),
      args(ast::clone(stmt.args, clean)), ret(ast::clone(stmt.ret, clean)),
      isFunction(stmt.isFunction) {}
std::string ImportStmt::toString(int indent) const {
  std::vector<std::string> va;
  for (auto &a : args)
    va.push_back(a.toString(indent));
  return wrapStmt(
      fmt::format("(import {}{}{}{}{}{})", from ? from->toString(indent) : "",
                  as.empty() ? "" : fmt::format(" #:as '{}", as),
                  what ? fmt::format(" #:what {}", what->toString(indent)) : "",
                  dots ? fmt::format(" #:dots {}", dots) : "",
                  va.empty() ? "" : fmt::format(" #:args ({})", join(va)),
                  ret ? fmt::format(" #:ret {}", ret->toString(indent)) : ""));
}

/* ------------------------------------------------------------ ExceptStmt */

ExceptStmt::ExceptStmt(const std::string &var, Expr *exc, Stmt *suite)
    : var(var), exc(exc), suite(SuiteStmt::wrap(suite)) {}
// Forwards `clean` to the base like every other cloning constructor in this file, so
// that a clean clone resets the done flag (it used to be copied unconditionally).
ExceptStmt::ExceptStmt(const ExceptStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), var(stmt.var), exc(ast::clone(stmt.exc, clean)),
      suite(ast::clone(stmt.suite, clean)) {}
std::string ExceptStmt::toString(int indent) const {
  std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " ";
  std::string padExtra = indent > 0 ? std::string(INDENT_SIZE, ' ') : "";
  // NOTE(review): `-1 * 2` evaluates to -2; kept as written -- confirm intent.
  return wrapStmt(fmt::format(
      "(catch {}{}{}{})", !var.empty() ? fmt::format("#:var '{}", var) : "",
      exc ? fmt::format(" #:exc {}", exc->toString(indent)) : "", pad + padExtra,
      suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1 * 2)));
}

/* --------------------------------------------------------------- TryStmt */

TryStmt::TryStmt(Stmt *suite, std::vector<ExceptStmt *> excepts, Stmt *elseSuite,
                 Stmt *finally)
    : AcceptorExtend(), Items(std::move(excepts)), suite(SuiteStmt::wrap(suite)),
      elseSuite(SuiteStmt::wrap(elseSuite)), finally(SuiteStmt::wrap(finally)) {}
TryStmt::TryStmt(const TryStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)),
      suite(ast::clone(stmt.suite, clean)),
      elseSuite(ast::clone(stmt.elseSuite, clean)),
      finally(ast::clone(stmt.finally, clean)) {}
std::string TryStmt::toString(int indent) const {
  if (indent == -1)
    return wrapStmt(fmt::format("(try)"));
  std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " ";
  std::vector<std::string> s;
  for (auto &i : items)
    s.push_back(i->toString(indent));
  return wrapStmt(fmt::format(
      "(try{}{}{}{}{})", pad, suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1),
      pad, join(s, pad),
      elseSuite
          ? fmt::format("{}(else {})", pad,
                        elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1))
          : "",
      finally ? fmt::format("{}(finally {})", pad,
                            finally->toString(indent >= 0 ? indent + INDENT_SIZE : -1))
              : ""));
}

/* ------------------------------------------------------------- ThrowStmt */

ThrowStmt::ThrowStmt(Expr *expr, Expr *from, bool transformed)
    : AcceptorExtend(), expr(expr), from(from), transformed(transformed) {}
ThrowStmt::ThrowStmt(const ThrowStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)),
      from(ast::clone(stmt.from, clean)), transformed(stmt.transformed) {}
std::string ThrowStmt::toString(int indent) const {
  return wrapStmt(
      fmt::format("(throw{}{})", expr ? " " + expr->toString(indent) : "",
                  from ? fmt::format(" :from {}", from->toString(indent)) : ""));
}

/* ------------------------------------------------------------ GlobalStmt */

GlobalStmt::GlobalStmt(std::string var, bool nonLocal)
    : AcceptorExtend(), var(std::move(var)), nonLocal(nonLocal) {}
GlobalStmt::GlobalStmt(const GlobalStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), var(stmt.var), nonLocal(stmt.nonLocal) {}
std::string GlobalStmt::toString(int indent) const {
  return wrapStmt(fmt::format("({} '{})", nonLocal ? "nonlocal" : "global", var));
}

/* ---------------------------------------------------------- FunctionStmt */

FunctionStmt::FunctionStmt(std::string name, Expr *ret, std::vector<Param> args,
                           Stmt *suite, std::vector<Expr *> decorators, bool async)
    : AcceptorExtend(), Items(std::move(args)), name(std::move(name)), ret(ret),
      suite(SuiteStmt::wrap(suite)), decorators(std::move(decorators)), async(async) {}
FunctionStmt::FunctionStmt(const FunctionStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)),
      name(stmt.name), ret(ast::clone(stmt.ret, clean)),
      suite(ast::clone(stmt.suite, clean)),
      decorators(ast::clone(stmt.decorators, clean)), async(stmt.async) {}
std::string FunctionStmt::toString(int indent) const {
  std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " ";
  std::vector<std::string> as;
  for (auto &a : items)
    as.push_back(a.toString(indent));
  std::vector<std::string> dec;
  for (auto &a : decorators)
    if (a)
      dec.push_back(fmt::format("(dec {})", a->toString(indent)));
  if (indent == -1)
    return wrapStmt(fmt::format("(fn '{} ({}){})", name, join(as, " "),
                                ret ? " #:ret " + ret->toString(indent) : ""));
  return wrapStmt(fmt::format(
      "(fn '{} ({}){}{}{}{})", name, join(as, " "),
      ret ? " #:ret " + ret->toString(indent) : "",
      dec.empty() ? "" : fmt::format(" (dec {})", join(dec, " ")), pad,
      suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) : "(suite)"));
}
/// Lazily builds and caches a ":"-joined list of the parameter type expressions
/// ("-" stands for a parameter without a type).
std::string FunctionStmt::getSignature() {
  if (signature.empty()) {
    std::vector<std::string> s;
    for (auto &a : items)
      s.push_back(a.type ? a.type->toString() : "-");
    signature = join(s, ":");
  }
  return signature;
}
/// @return the index of the first parameter whose name starts with exactly one "*",
///         or the parameter count if there is none.
size_t FunctionStmt::getStarArgs() const {
  size_t i = 0;
  while (i < items.size()) {
    if (startswith(items[i].name, "*") && !startswith(items[i].name, "**"))
      break;
    i++;
  }
  return i;
}
/// @return the index of the first parameter whose name starts with "**", or the
///         parameter count if there is none.
size_t FunctionStmt::getKwStarArgs() const {
  size_t i = 0;
  while (i < items.size()) {
    if (startswith(items[i].name, "**"))
      break;
    i++;
  }
  return i;
}
/// @return the value of the leading string-literal expression statement of the body,
///         or an empty string if there is none (or no body at all).
std::string FunctionStmt::getDocstr() const {
  // suite may be null (toString() above handles that case too), so guard it.
  if (auto s = suite ? suite->firstInBlock() : nullptr) {
    if (auto e = cast<ExprStmt>(s)) {
      if (auto ss = cast<StringExpr>(e->getExpr()))
        return ss->getValue();
    }
  }
  return "";
}
bool FunctionStmt::hasFunctionAttribute(const std::string &attr) const {
  // NOTE(review): attribute type reconstructed from the `->attributes` member access
  // below -- confirm against the repository.
  if (auto f = getAttribute<ir::KeyValueAttribute>(Attr::FunctionAttributes)) {
    return in(f->attributes, attr) != nullptr;
  }
  return false;
}

/// Searches an expression/statement tree for an identifier with a given name.
class IdSearchVisitor : public CallbackASTVisitor<bool, bool> {
  std::string what;
  bool result;

public:
  IdSearchVisitor(std::string what) : what(std::move(what)), result(false) {}
  bool transform(Expr *expr) override {
    if (result)
      return result;
    IdSearchVisitor v(what);
    if (expr)
      expr->accept(v);
    return result = v.result;
  }
  bool transform(Stmt *stmt) override {
    if (result)
      return result;
    IdSearchVisitor v(what);
    if (stmt)
      stmt->accept(v);
    return result = v.result;
  }
  void visit(IdExpr *expr) override {
    if (expr->getValue() == what)
      result = true;
  }
};

/// Collects the names of generic parameters that have no default value and whose name
/// appears neither in any parameter's type expression nor in the return type
/// expression.
std::unordered_set<std::string> FunctionStmt::getNonInferrableGenerics() const {
  std::unordered_set<std::string> nonInferrableGenerics;
  for (const auto &a : items) {
    if (a.status == Param::Generic && !a.defaultValue) {
      bool inferrable = false;
      for (const auto &b : items)
        if (b.type && IdSearchVisitor(a.name).transform(b.type)) {
          inferrable = true;
          break;
        }
      if (ret && IdSearchVisitor(a.name).transform(ret))
        inferrable = true;
      if (!inferrable)
        nonInferrableGenerics.insert(a.name);
    }
  }
  return nonInferrableGenerics;
}

/* ------------------------------------------------------------- ClassStmt */

/// Base classes matching the index-expression pattern `Static[...]` are routed to
/// staticBaseClasses; all others go to baseClasses.
ClassStmt::ClassStmt(std::string name, std::vector<Param> args, Stmt *suite,
                     std::vector<Expr *> decorators,
                     const std::vector<Expr *> &baseClasses,
                     std::vector<Expr *> staticBaseClasses)
    : AcceptorExtend(), Items(std::move(args)), name(std::move(name)),
      suite(SuiteStmt::wrap(suite)), decorators(std::move(decorators)),
      staticBaseClasses(std::move(staticBaseClasses)) {
  for (auto &b : baseClasses) {
    Expr *e = nullptr;
    // NOTE(review): matcher template arguments reconstructed -- confirm.
    if (match(b, M<IndexExpr>(M<IdExpr>("Static"), MVar<Expr>(e)))) {
      this->staticBaseClasses.push_back(e);
    } else {
      this->baseClasses.push_back(b);
    }
  }
}
ClassStmt::ClassStmt(const ClassStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)),
      name(stmt.name), suite(ast::clone(stmt.suite, clean)),
      decorators(ast::clone(stmt.decorators, clean)),
      baseClasses(ast::clone(stmt.baseClasses, clean)),
      staticBaseClasses(ast::clone(stmt.staticBaseClasses, clean)) {}
std::string ClassStmt::toString(int indent) const {
  std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " ";
  std::vector<std::string> bases;
  for (auto &b : baseClasses)
    bases.push_back(b->toString(indent));
  for (auto &b : staticBaseClasses)
    bases.push_back(fmt::format("(static {})", b->toString(indent)));
  std::string as;
  for (size_t i = 0; i < items.size(); i++)
    as += (i ? pad : "") + items[i].toString(indent);
  std::vector<std::string> attr;
  for (auto &a : decorators)
    if (a) // same null guard as FunctionStmt::toString
      attr.push_back(fmt::format("(dec {})", a->toString(indent)));
  if (indent == -1)
    return wrapStmt(fmt::format("(class '{} ({}))", name, as));
  return wrapStmt(fmt::format(
      "(class '{}{}{}{}{}{})", name,
      bases.empty() ? "" : fmt::format(" (bases {})", join(bases, " ")),
      attr.empty() ? "" : fmt::format(" (attr {})", join(attr, " ")),
      as.empty() ? as : pad + as, pad,
      suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) : "(suite)"));
}
bool ClassStmt::isRecord() const { return hasAttribute(Attr::Tuple); }
/// A parameter counts as a class variable if it has no type expression or if its type
/// is an index expression over the identifier `ClassVar`.
bool ClassStmt::isClassVar(const Param &p) {
  if (!p.type)
    return true;
  if (auto i = cast<IndexExpr>(p.type))
    return isId(i->getExpr(), "ClassVar");
  return false;
}
/// @return the value of the leading string-literal expression statement of the body,
///         or an empty string if there is none (or no body at all).
std::string ClassStmt::getDocstr() const {
  // suite may be null (toString() above handles that case too), so guard it.
  if (auto s = suite ? suite->firstInBlock() : nullptr) {
    if (auto e = cast<ExprStmt>(s)) {
      if (auto ss = cast<StringExpr>(e->getExpr()))
        return ss->getValue();
    }
  }
  return "";
}

/* --------------------------------------------------------- YieldFromStmt */

YieldFromStmt::YieldFromStmt(Expr *expr) : AcceptorExtend(), expr(expr) {}
YieldFromStmt::YieldFromStmt(const YieldFromStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {}
std::string YieldFromStmt::toString(int indent) const {
  return wrapStmt(fmt::format("(yield-from {})", expr->toString(indent)));
}

/* -------------------------------------------------------------- WithStmt */

WithStmt::WithStmt(std::vector<Expr *> items, std::vector<std::string> vars,
                   Stmt *suite, bool isAsync)
    : AcceptorExtend(), Items(std::move(items)), vars(std::move(vars)),
      suite(SuiteStmt::wrap(suite)), async(isAsync) {
  seqassert(this->items.size() == this->vars.size(), "vector size mismatch");
}
/// Builds the parallel items/vars vectors from (item, variable) pairs; a variable that
/// is not a plain identifier is recorded as an empty name.
WithStmt::WithStmt(std::vector<std::pair<Expr *, Expr *>> itemVarPairs, Stmt *suite,
                   bool isAsync)
    : AcceptorExtend(), Items({}), suite(SuiteStmt::wrap(suite)), async(isAsync) {
  for (auto [i, j] : itemVarPairs) {
    items.push_back(i);
    if (auto je = cast<IdExpr>(j)) {
      vars.push_back(je->getValue());
    } else {
      vars.emplace_back();
    }
  }
}
WithStmt::WithStmt(const WithStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)),
      vars(stmt.vars), suite(ast::clone(stmt.suite, clean)), async(stmt.async) {}
std::string WithStmt::toString(int indent) const {
  std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " ";
  std::vector<std::string> as;
  as.reserve(items.size());
  for (size_t i = 0; i < items.size(); i++) {
    as.push_back(!vars[i].empty() ? fmt::format("({} #:var '{})",
                                                items[i]->toString(indent), vars[i])
                                  : items[i]->toString(indent));
  }
  if (indent == -1)
    return wrapStmt(fmt::format("(with ({}))", join(as, " ")));
  return wrapStmt(
      fmt::format("(with ({}){}{})", join(as, " "), pad,
                  suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)));
}

/* ------------------------------------------------------------ CustomStmt */

CustomStmt::CustomStmt(std::string keyword, Expr *expr, Stmt *suite)
    : AcceptorExtend(), keyword(std::move(keyword)), expr(expr),
      suite(SuiteStmt::wrap(suite)) {}
CustomStmt::CustomStmt(const CustomStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), keyword(stmt.keyword),
      expr(ast::clone(stmt.expr, clean)), suite(ast::clone(stmt.suite, clean)) {}
std::string CustomStmt::toString(int indent) const {
  std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " ";
  return wrapStmt(fmt::format(
      "(custom-{} {}{}{})", keyword,
      expr ? fmt::format(" #:expr {}", expr->toString(indent)) : "", pad,
      suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) : ""));
}

/* --------------------------------------------------------- DirectiveStmt */

DirectiveStmt::DirectiveStmt(std::string key, std::string value)
    : AcceptorExtend(), key(std::move(key)), value(std::move(value)) {}
DirectiveStmt::DirectiveStmt(const DirectiveStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), key(stmt.key), value(stmt.value) {}
std::string DirectiveStmt::toString(int indent) const {
  return wrapStmt(fmt::format("(directive {} '{}')", key, value));
}

/* ------------------------------------------------------ AssignMemberStmt */

AssignMemberStmt::AssignMemberStmt(Expr *lhs, std::string member, Expr *rhs,
                                   Expr *type)
    : AcceptorExtend(), lhs(lhs), member(std::move(member)), rhs(rhs), type(type) {}
AssignMemberStmt::AssignMemberStmt(const AssignMemberStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), lhs(ast::clone(stmt.lhs, clean)),
      member(stmt.member), rhs(ast::clone(stmt.rhs, clean)),
      type(ast::clone(stmt.type, clean)) {}
std::string AssignMemberStmt::toString(int indent) const {
  return wrapStmt(fmt::format("(assign-member {} {} {})", lhs->toString(indent), member,
                              rhs->toString(indent)));
}

/* ----------------------------------------------------------- CommentStmt */

CommentStmt::CommentStmt(std::string comment)
    : AcceptorExtend(), comment(std::move(comment)) {}
CommentStmt::CommentStmt(const CommentStmt &stmt, bool clean)
    : AcceptorExtend(stmt, clean), comment(stmt.comment) {}
std::string CommentStmt::toString(int indent) const {
  return wrapStmt(fmt::format("(comment \"{}\")", comment));
}

/* ------------------------------------------- clone / accept / NodeId glue */

const char Stmt::NodeId = 0;
ACCEPT_IMPL(SuiteStmt, ASTVisitor);
ACCEPT_IMPL(BreakStmt, ASTVisitor);
ACCEPT_IMPL(ContinueStmt, ASTVisitor);
ACCEPT_IMPL(ExprStmt, ASTVisitor);
ACCEPT_IMPL(AssignStmt, ASTVisitor);
ACCEPT_IMPL(DelStmt, ASTVisitor);
ACCEPT_IMPL(PrintStmt, ASTVisitor);
ACCEPT_IMPL(ReturnStmt, ASTVisitor);
ACCEPT_IMPL(YieldStmt, ASTVisitor);
ACCEPT_IMPL(AssertStmt, ASTVisitor);
ACCEPT_IMPL(WhileStmt, ASTVisitor);
ACCEPT_IMPL(ForStmt, ASTVisitor);
ACCEPT_IMPL(IfStmt, ASTVisitor);
ACCEPT_IMPL(MatchStmt, ASTVisitor);
ACCEPT_IMPL(ImportStmt, ASTVisitor);
ACCEPT_IMPL(ExceptStmt, ASTVisitor);
ACCEPT_IMPL(TryStmt, ASTVisitor);
ACCEPT_IMPL(ThrowStmt, ASTVisitor);
ACCEPT_IMPL(GlobalStmt, ASTVisitor);
ACCEPT_IMPL(FunctionStmt, ASTVisitor);
ACCEPT_IMPL(ClassStmt, ASTVisitor);
ACCEPT_IMPL(YieldFromStmt, ASTVisitor);
ACCEPT_IMPL(WithStmt, ASTVisitor);
ACCEPT_IMPL(CustomStmt, ASTVisitor);
ACCEPT_IMPL(DirectiveStmt, ASTVisitor);
ACCEPT_IMPL(AssignMemberStmt, ASTVisitor);
ACCEPT_IMPL(CommentStmt, ASTVisitor);

} // namespace codon::ast

namespace tser {
/// Serializes a possibly-null statement: a presence flag, then (if present) the
/// serializer key registered for the node's dynamic type followed by the payload.
void operator<<(codon::ast::Stmt *t, BinaryArchive &a) {
  using S = codon::PolymorphicSerializer<BinaryArchive, codon::ast::Stmt>;
  a.save(t != nullptr);
  if (t) {
    auto typ = t->dynamicNodeId();
    auto key = S::_serializers[const_cast<void *>(typ)];
    a.save(key);
    S::save(key, t, a);
  }
}
/// Mirror of operator<<. The leading flag is true when a node follows; it used to be
/// read with the opposite meaning, so non-null nodes came back as nullptr and null
/// nodes triggered a payload read.
void operator>>(codon::ast::Stmt *&t, BinaryArchive &a) {
  using S = codon::PolymorphicSerializer<BinaryArchive, codon::ast::Stmt>;
  bool present = a.load<bool>();
  if (present) {
    std::string key = a.load<std::string>();
    S::load(key, t, a);
  } else {
    t = nullptr;
  }
}
} // namespace tser

================================================
FILE: codon/parser/ast/stmt.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

// NOTE(review): standard-header names reconstructed from usage -- confirm.
#include <memory>
#include <ostream>
#include <string>
#include <vector>

#include "codon/parser/ast/expr.h"
#include "codon/parser/common.h"
#include "codon/util/serialize.h"

namespace codon::ast {

#define ACCEPT(CLASS, VISITOR, ...)                                                    \
  static const char NodeId;                                                            \
  using AcceptorExtend::clone;                                                         \
  using AcceptorExtend::accept;                                                        \
  ASTNode *clone(bool c) const override;                                               \
  void accept(VISITOR &visitor) override;                                              \
  std::string toString(int) const override;                                            \
  friend class TypecheckVisitor;                                                       \
  template <typename TE, typename TS> friend struct CallbackASTVisitor;                \
  friend struct ReplacingCallbackASTVisitor;                                           \
  inline decltype(auto) match_members() const { return std::tie(__VA_ARGS__); }        \
  SERIALIZE(CLASS, BASE(Stmt), ##__VA_ARGS__)

// Forward declarations
struct ASTVisitor;

/**
 * A Seq AST statement.
 * Each AST statement is intended to be instantiated as a shared_ptr.
*/ struct Stmt : public AcceptorExtend { using base_type = Stmt; Stmt(); Stmt(const Stmt &s); Stmt(const Stmt &, bool); explicit Stmt(const codon::SrcInfo &s); bool isDone() const { return done; } void setDone() { done = true; } /// @return the first statement in a suite; if a statement is not a suite, returns the /// statement itself virtual Stmt *firstInBlock() { return this; } static const char NodeId; SERIALIZE(Stmt, BASE(ASTNode), done); virtual std::string wrapStmt(const std::string &) const; protected: /// Flag that indicates if all types in a statement are inferred (i.e. if a /// type-checking procedure was successful). bool done; }; /// Suite (block of statements) statement (stmt...). /// @li a = 5; foo(1) struct SuiteStmt : public AcceptorExtend, Items { explicit SuiteStmt(std::vector stmts = {}); /// Convenience constructor template explicit SuiteStmt(Stmt *stmt, Ts... stmts) : Items({stmt, stmts...}) {} SuiteStmt(const SuiteStmt &, bool); Stmt *firstInBlock() override { return items.empty() ? nullptr : items[0]->firstInBlock(); } void flatten(); void addStmt(Stmt *s); static SuiteStmt *wrap(Stmt *); ACCEPT(SuiteStmt, ASTVisitor, items); }; /// Break statement. /// @li break struct BreakStmt : public AcceptorExtend { BreakStmt() = default; BreakStmt(const BreakStmt &, bool); ACCEPT(BreakStmt, ASTVisitor); }; /// Continue statement. /// @li continue struct ContinueStmt : public AcceptorExtend { ContinueStmt() = default; ContinueStmt(const ContinueStmt &, bool); ACCEPT(ContinueStmt, ASTVisitor); }; /// Expression statement (expr). /// @li 3 + foo() struct ExprStmt : public AcceptorExtend { explicit ExprStmt(Expr *expr = nullptr); ExprStmt(const ExprStmt &, bool); Expr *getExpr() const { return expr; } ACCEPT(ExprStmt, ASTVisitor, expr); private: Expr *expr; }; /// Assignment statement (lhs: type = rhs). 
/// @li a = 5 /// @li a: Optional[int] = 5 /// @li a, b, c = 5, *z struct AssignStmt : public AcceptorExtend { enum UpdateMode { Assign, Update, UpdateAtomic, ThreadLocalAssign }; AssignStmt() : lhs(nullptr), rhs(nullptr), type(nullptr), update(UpdateMode::Assign) {} AssignStmt(Expr *lhs, Expr *rhs, Expr *type = nullptr, UpdateMode update = UpdateMode::Assign); AssignStmt(const AssignStmt &, bool); Expr *getLhs() const { return lhs; } Expr *getRhs() const { return rhs; } Expr *getTypeExpr() const { return type; } void setLhs(Expr *expr) { lhs = expr; } void setRhs(Expr *expr) { rhs = expr; } bool isAssignment() const { return update == Assign; } bool isUpdate() const { return update == Update; } bool isAtomicUpdate() const { return update == UpdateAtomic; } bool isThreadLocal() { return update == ThreadLocalAssign; } void setUpdate() { update = Update; } void setAtomicUpdate() { update = UpdateAtomic; } void setThreadLocal() { update = ThreadLocalAssign; } ACCEPT(AssignStmt, ASTVisitor, lhs, rhs, type, update); private: Expr *lhs, *rhs, *type; UpdateMode update; }; /// Deletion statement (del expr). /// @li del a /// @li del a[5] struct DelStmt : public AcceptorExtend { explicit DelStmt(Expr *expr = nullptr); DelStmt(const DelStmt &, bool); Expr *getExpr() const { return expr; } ACCEPT(DelStmt, ASTVisitor, expr); private: Expr *expr; }; /// Print statement (print expr). /// @li print a, b struct PrintStmt : public AcceptorExtend, Items { explicit PrintStmt(std::vector items = {}, bool noNewline = false); PrintStmt(const PrintStmt &, bool); bool hasNewline() const { return !noNewline; } ACCEPT(PrintStmt, ASTVisitor, items, noNewline); private: /// True if there is a dangling comma after print: print a, bool noNewline; }; /// Return statement (return expr). 
/// @li return /// @li return a struct ReturnStmt : public AcceptorExtend { explicit ReturnStmt(Expr *expr = nullptr); ReturnStmt(const ReturnStmt &, bool); Expr *getExpr() const { return expr; } ACCEPT(ReturnStmt, ASTVisitor, expr); private: /// nullptr if this is an empty return/yield statements. Expr *expr; }; /// Yield statement (yield expr). /// @li yield /// @li yield a struct YieldStmt : public AcceptorExtend { explicit YieldStmt(Expr *expr = nullptr); YieldStmt(const YieldStmt &, bool); Expr *getExpr() const { return expr; } ACCEPT(YieldStmt, ASTVisitor, expr); private: /// nullptr if this is an empty return/yield statements. Expr *expr; }; /// Assert statement (assert expr). /// @li assert a /// @li assert a, "Message" struct AssertStmt : public AcceptorExtend { explicit AssertStmt(Expr *expr = nullptr, Expr *message = nullptr); AssertStmt(const AssertStmt &, bool); Expr *getExpr() const { return expr; } Expr *getMessage() const { return message; } ACCEPT(AssertStmt, ASTVisitor, expr, message); private: Expr *expr; /// nullptr if there is no message. Expr *message; }; /// While loop statement (while cond: suite; else: elseSuite). /// @li while True: print /// @li while True: break /// else: print struct WhileStmt : public AcceptorExtend { WhileStmt() : cond(nullptr), suite(nullptr), elseSuite(nullptr), gotoVar() {} WhileStmt(Expr *cond, Stmt *suite, Stmt *elseSuite = nullptr); WhileStmt(const WhileStmt &, bool); Expr *getCond() const { return cond; } SuiteStmt *getSuite() const { return suite; } SuiteStmt *getElse() const { return elseSuite; } ACCEPT(WhileStmt, ASTVisitor, cond, suite, elseSuite, gotoVar); private: Expr *cond; SuiteStmt *suite; /// nullptr if there is no else suite. SuiteStmt *elseSuite; /// Set if a while loop is used to emulate goto statement /// (as `while gotoVar: ...`). std::string gotoVar; }; /// For loop statement (for var in iter: suite; else elseSuite). 
/// @li for a, b in c: print
/// @li for i in j: break
///     else: print
struct ForStmt : public AcceptorExtend {
  /// Creates an empty loop: all pointers null, all flags false.
  ForStmt()
      : var(nullptr), iter(nullptr), suite(nullptr), elseSuite(nullptr),
        decorator(nullptr), ompArgs(), async(false), wrapped(false), flat(false) {}
  /// @param var       loop variable expression
  /// @param iter      expression being iterated
  /// @param suite     loop body
  /// @param elseSuite optional else body
  /// @param decorator optional loop decorator expression
  /// @param ompArgs   decorator arguments (presumably OpenMP settings, judging
  ///                  by the name — TODO confirm)
  /// @param async     whether this is an `async for`
  ForStmt(Expr *var, Expr *iter, Stmt *suite, Stmt *elseSuite = nullptr,
          Expr *decorator = nullptr, std::vector ompArgs = {}, bool async = false);
  ForStmt(const ForStmt &, bool);

  Expr *getVar() const { return var; }
  Expr *getIter() const { return iter; }
  SuiteStmt *getSuite() const { return suite; }
  /// @return the else suite; nullptr when default-constructed.
  SuiteStmt *getElse() const { return elseSuite; }
  Expr *getDecorator() const { return decorator; }
  void setDecorator(Expr *e) { decorator = e; }
  bool isAsync() const { return async; }
  void setAsync() { async = true; }
  /// `wrapped` and `flat` have no public setters; only this struct and its
  /// friends (GeneratorExpr, ScopingVisitor) can change them.
  bool isWrapped() const { return wrapped; }
  bool isFlat() const { return flat; }

  ACCEPT(ForStmt, ASTVisitor, var, iter, suite, elseSuite, decorator, ompArgs, async,
         wrapped, flat);

private:
  Expr *var;
  Expr *iter;
  SuiteStmt *suite;
  SuiteStmt *elseSuite;
  Expr *decorator;
  std::vector ompArgs;
  bool async;

  /// Indicates if iter was wrapped with __iter__() call.
  bool wrapped;
  /// True if there are no break/continue within the loop
  bool flat;

  friend struct GeneratorExpr;
  friend class ScopingVisitor;
};

/// If block statement (if cond: suite; (elif cond: suite)...).
/// @li if a: foo()
/// @li if a: foo()
///     elif b: bar()
/// @li if a: foo()
///     elif b: bar()
///     else: baz()
struct IfStmt : public AcceptorExtend {
  IfStmt(Expr *cond = nullptr, Stmt *ifSuite = nullptr, Stmt *elseSuite = nullptr);
  IfStmt(const IfStmt &, bool);

  Expr *getCond() const { return cond; }
  /// @return the suite executed when the condition holds.
  SuiteStmt *getIf() const { return ifSuite; }
  /// @return the else suite, or nullptr if there is none.
  SuiteStmt *getElse() const { return elseSuite; }

  ACCEPT(IfStmt, ASTVisitor, cond, ifSuite, elseSuite);

private:
  Expr *cond;
  /// elseSuite can be nullptr (if no else is found).
  SuiteStmt *ifSuite, *elseSuite;

  friend struct GeneratorExpr;
};

/// A single `case pattern (if guard): suite` arm of a match statement.
/// Unlike the statement nodes around it, this is a plain struct that uses
/// SERIALIZE instead of ACCEPT and exposes clone() directly.
struct MatchCase {
  MatchCase(Expr *pattern = nullptr, Expr *guard = nullptr, Stmt *suite = nullptr);

  Expr *getPattern() const { return pattern; }
  /// @return the guard expression; nullptr when default-constructed.
  Expr *getGuard() const { return guard; }
  SuiteStmt *getSuite() const { return suite; }

  /// The meaning of the bool argument is defined out of line.
  MatchCase clone(bool) const;
  SERIALIZE(MatchCase, pattern, guard, suite);

private:
  Expr *pattern;
  Expr *guard;
  SuiteStmt *suite;

  // These types access the private members directly.
  friend struct MatchStmt;
  friend class TypecheckVisitor;
  template friend struct CallbackASTVisitor;
  friend struct ReplacingCallbackASTVisitor;
};

/// Match statement (match what: (case pattern: case)...).
/// @li match a:
///       case 1: print
///       case _: pass
struct MatchStmt : public AcceptorExtend, Items {
  /// The constructor parameter is named `what`, but the getter and the member
  /// are named `expr`; presumably they are the same matched expression —
  /// confirm against the out-of-line definition.
  MatchStmt(Expr *what = nullptr, std::vector cases = {});
  MatchStmt(const MatchStmt &, bool);

  Expr *getExpr() const { return expr; }

  ACCEPT(MatchStmt, ASTVisitor, items, expr);

private:
  Expr *expr;
};

/// Import statement.
/// This node describes various kinds of import statements:
///  - from from import what (as as)
///  - import what (as as)
///  - from c import what(args...) (-> ret) (as as)
///  - from .(dots...)from import what (as as)
/// @li import a
/// @li from b import a
/// @li from ...b import a as ai
/// @li from c import foo(int) -> int as bar
/// @li from python.numpy import array
/// @li from python import numpy.array(int) -> int as na
struct ImportStmt : public AcceptorExtend {
  /// Note that `isFunction` defaults to true, so isCVar() is false unless a
  /// caller passes false explicitly.
  ImportStmt(Expr *from = nullptr, Expr *what = nullptr, std::vector args = {},
             Expr *ret = nullptr, std::string as = "", size_t dots = 0,
             bool isFunction = true);
  ImportStmt(const ImportStmt &, bool);

  Expr *getFrom() const { return from; }
  Expr *getWhat() const { return what; }
  /// @return the alias (a copy); empty by default.
  std::string getAs() const { return as; }
  size_t getDots() const { return dots; }
  Expr *getReturnType() const { return ret; }
  const std::vector &getArgs() const { return args; }
  /// @return true if this is a variable (not function) import.
  bool isCVar() const { return !isFunction; }

  ACCEPT(ImportStmt, ASTVisitor, from, what, as, dots, args, ret, isFunction);

private:
  Expr *from, *what;
  std::string as;
  /// Number of dots in a relative import (e.g. dots is 3 for "from ...foo").
  size_t dots;
  /// Function argument types for C imports.
  std::vector args;
  /// Function return type for C imports.
  Expr *ret;
  /// Set if this is a function C import (not variable import)
  bool isFunction;
};

/// A single except clause. The members hold the optional bound variable name,
/// the optional exception type and the handler suite.
struct ExceptStmt : public AcceptorExtend {
  ExceptStmt(const std::string &var = "", Expr *exc = nullptr, Stmt *suite = nullptr);
  ExceptStmt(const ExceptStmt &, bool);

  /// @return the bound variable name (a copy); empty if the clause is unnamed.
  std::string getVar() const { return var; }
  /// @return the exception type, or nullptr if none was given.
  Expr *getException() const { return exc; }
  SuiteStmt *getSuite() const { return suite; }

  ACCEPT(ExceptStmt, ASTVisitor, var, exc, suite);

private:
  /// empty string if an except is unnamed.
  std::string var;
  /// nullptr if there is no explicit exception type.
  Expr *exc;
  SuiteStmt *suite;

  friend class ScopingVisitor;
};

/// Try-except statement (try: suite; (except var (as exc): suite)...; finally:
/// finally).
/// @li: try: a /// except e: pass /// except e as Exc: pass /// except: pass /// finally: print struct TryStmt : public AcceptorExtend, Items { TryStmt(Stmt *suite = nullptr, std::vector catches = {}, Stmt *elseSuite = nullptr, Stmt *finally = nullptr); TryStmt(const TryStmt &, bool); SuiteStmt *getSuite() const { return suite; } SuiteStmt *getElse() const { return elseSuite; } SuiteStmt *getFinally() const { return finally; } ACCEPT(TryStmt, ASTVisitor, items, suite, elseSuite, finally); private: SuiteStmt *suite; /// nullptr if there is no else block. SuiteStmt *elseSuite; /// nullptr if there is no finally block. SuiteStmt *finally; }; /// Throw statement (raise expr). /// @li: raise a struct ThrowStmt : public AcceptorExtend { explicit ThrowStmt(Expr *expr = nullptr, Expr *from = nullptr, bool transformed = false); ThrowStmt(const ThrowStmt &, bool); Expr *getExpr() const { return expr; } Expr *getFrom() const { return from; } bool isTransformed() const { return transformed; } ACCEPT(ThrowStmt, ASTVisitor, expr, from, transformed); private: Expr *expr; Expr *from; // True if a statement was transformed during type-checking stage // (to avoid setting up ExcHeader multiple times). bool transformed; }; /// Global variable statement (global var). /// @li: global a struct GlobalStmt : public AcceptorExtend { explicit GlobalStmt(std::string var = "", bool nonLocal = false); GlobalStmt(const GlobalStmt &, bool); std::string getVar() const { return var; } bool isNonLocal() const { return nonLocal; } ACCEPT(GlobalStmt, ASTVisitor, var, nonLocal); private: std::string var; bool nonLocal; }; /// Function statement (@(attributes...) def name[funcs...](args...) -> ret: suite). 
/// @li: @decorator /// def foo[T=int, U: int](a, b: int = 0) -> list[T]: pass struct FunctionStmt : public AcceptorExtend, Items { FunctionStmt(std::string name = "", Expr *ret = nullptr, std::vector args = {}, Stmt *suite = nullptr, std::vector decorators = {}, bool async = false); FunctionStmt(const FunctionStmt &, bool); std::string getName() const { return name; } void setName(const std::string &n) { name = n; } Expr *getReturn() const { return ret; } SuiteStmt *getSuite() const { return suite; } void setSuite(SuiteStmt *s) { suite = s; } const std::vector &getDecorators() const { return decorators; } void setDecorators(const std::vector &d) { decorators = d; } bool isAsync() const { return async; } void setAsync() { async = true; } void addParam(const Param &p) { items.push_back(p); } /// @return a function signature that consists of generics and arguments in a /// S-expression form. /// @li (T U (int 0)) std::string getSignature(); size_t getStarArgs() const; size_t getKwStarArgs() const; std::string getDocstr() const; std::unordered_set getNonInferrableGenerics() const; bool hasFunctionAttribute(const std::string &attr) const; ACCEPT(FunctionStmt, ASTVisitor, name, items, ret, suite, decorators, async); private: std::string name; Expr *ret; /// nullptr if return type is not specified. SuiteStmt *suite; std::vector decorators; bool async; std::string signature; friend struct Cache; }; /// Class statement (@(attributes...) class name[generics...]: args... ; suite). /// @li: @type /// class F[T]: /// m: T /// def __new__() -> F[T]: ... 
struct ClassStmt : public AcceptorExtend, Items {
  /// @param name              class name
  /// @param args              class members/generics (presumably stored in the
  ///                          inherited `items` — TODO confirm)
  /// @param suite             class body
  /// @param decorators        decorator expressions
  /// @param baseClasses       base class expressions
  /// @param staticBaseClasses second list of base classes, kept separately
  ClassStmt(std::string name = "", std::vector args = {}, Stmt *suite = nullptr,
            std::vector decorators = {}, const std::vector &baseClasses = {},
            std::vector staticBaseClasses = {});
  ClassStmt(const ClassStmt &, bool);

  std::string getName() const { return name; }
  SuiteStmt *getSuite() const { return suite; }
  const std::vector &getDecorators() const { return decorators; }
  void setDecorators(const std::vector &d) { decorators = d; }
  const std::vector &getBaseClasses() const { return baseClasses; }
  const std::vector &getStaticBaseClasses() const { return staticBaseClasses; }

  /// @return true if a class is a tuple-like record (e.g. has a "@tuple" attribute)
  bool isRecord() const;
  std::string getDocstr() const;

  /// Static helper that classifies a single parameter; defined out of line.
  static bool isClassVar(const Param &p);

  ACCEPT(ClassStmt, ASTVisitor, name, suite, items, decorators, baseClasses,
         staticBaseClasses);

private:
  std::string name;
  SuiteStmt *suite;
  std::vector decorators;
  std::vector baseClasses;
  std::vector staticBaseClasses;
};

/// Yield-from statement (yield from expr).
/// @li: yield from it
struct YieldFromStmt : public AcceptorExtend {
  explicit YieldFromStmt(Expr *expr = nullptr);
  YieldFromStmt(const YieldFromStmt &, bool);

  /// @return the expression being delegated to.
  Expr *getExpr() const { return expr; }

  ACCEPT(YieldFromStmt, ASTVisitor, expr);

private:
  Expr *expr;
};

/// With statement (with (item as var)...: suite).
/// @li: with foo(), bar() as b: pass struct WithStmt : public AcceptorExtend, Items { WithStmt(std::vector items = {}, std::vector vars = {}, Stmt *suite = nullptr, bool isAsync = false); WithStmt(std::vector> items, Stmt *suite, bool isAsync); WithStmt(const WithStmt &, bool); const std::vector &getVars() const { return vars; } SuiteStmt *getSuite() const { return suite; } bool isAsync() const { return async; } void setAsync() { async = true; } ACCEPT(WithStmt, ASTVisitor, items, vars, suite); private: /// empty string if a corresponding item is unnamed std::vector vars; SuiteStmt *suite; bool async; }; /// Custom block statement (foo: ...). /// @li: pt_tree: pass struct CustomStmt : public AcceptorExtend { CustomStmt(std::string keyword = "", Expr *expr = nullptr, Stmt *suite = nullptr); CustomStmt(const CustomStmt &, bool); std::string getKeyword() const { return keyword; } Expr *getExpr() const { return expr; } SuiteStmt *getSuite() const { return suite; } ACCEPT(CustomStmt, ASTVisitor, keyword, expr, suite); private: std::string keyword; Expr *expr; SuiteStmt *suite; }; struct DirectiveStmt : public AcceptorExtend { DirectiveStmt(std::string key = "", std::string value = ""); DirectiveStmt(const DirectiveStmt &, bool); std::string getKey() const { return key; } std::string getValue() const { return value; } ACCEPT(DirectiveStmt, ASTVisitor, key, value); private: std::string key, value; }; /// The following nodes are created during typechecking. /// Member assignment statement (lhs.member = rhs). 
/// @li: a.x = b struct AssignMemberStmt : public AcceptorExtend { AssignMemberStmt(Expr *lhs = nullptr, std::string member = "", Expr *rhs = nullptr, Expr *type = nullptr); AssignMemberStmt(const AssignMemberStmt &, bool); Expr *getLhs() const { return lhs; } std::string getMember() const { return member; } Expr *getRhs() const { return rhs; } Expr *getTypeExpr() const { return type; } ACCEPT(AssignMemberStmt, ASTVisitor, lhs, member, rhs, type); private: Expr *lhs; std::string member; Expr *rhs; Expr *type; }; /// Comment statement (# comment). /// Currently used only for pretty-printing. struct CommentStmt : public AcceptorExtend { explicit CommentStmt(std::string comment = ""); CommentStmt(const CommentStmt &, bool); std::string getComment() const { return comment; } ACCEPT(CommentStmt, ASTVisitor, comment); private: std::string comment; }; #undef ACCEPT } // namespace codon::ast namespace tser { void operator<<(codon::ast::Stmt *t, BinaryArchive &a); void operator>>(codon::ast::Stmt *&t, BinaryArchive &a); } // namespace tser ================================================ FILE: codon/parser/ast/types/class.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include
#include
#include

#include "codon/parser/ast/types/class.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

namespace codon::ast::types {

/// Debug representation of a generic. For a non-static generic whose type is
/// static, the non-static equivalent is printed instead, except in mode 2.
// NOTE(review): `type` is dereferenced unconditionally here, while
// generalize()/instantiate() below tolerate a null `type`.
std::string ClassType::Generic::debugString(char mode) const {
  if (!staticKind && type->getStatic() && mode != 2)
    return type->getStatic()->getNonStaticType()->debugString(mode);
  return type->debugString(mode);
}

/// Realized name of a generic; same static-to-non-static substitution as
/// debugString(), and likewise assumes `type` is not null.
std::string ClassType::Generic::realizedName() const {
  if (!staticKind && type->getStatic())
    return type->getStatic()->getNonStaticType()->realizedName();
  return type->realizedName();
}

/// Returns a copy of this generic with its type generalized at `atLevel`.
/// A null type stays null; name, id and staticKind are copied unchanged.
ClassType::Generic ClassType::Generic::generalize(int atLevel) const {
  TypePtr t = nullptr;
  if (!staticKind && type && type->getStatic())
    t = type->getStatic()->getNonStaticType()->generalize(atLevel);
  else if (type)
    t = type->generalize(atLevel);
  return {name, t, id, staticKind};
}

/// Returns a copy of this generic with its type instantiated.
/// Mirrors generalize(): a null type stays null, the rest is copied.
ClassType::Generic
ClassType::Generic::instantiate(int atLevel, int *unboundCount,
                                std::unordered_map *cache) const {
  TypePtr t = nullptr;
  if (!staticKind && type && type->getStatic())
    t = type->getStatic()->getNonStaticType()->instantiate(atLevel, unboundCount,
                                                           cache);
  else if (type)
    t = type->instantiate(atLevel, unboundCount, cache);
  return {name, t, id, staticKind};
}

/// Creates a class type with the given name and (hidden) generics.
/// `isTuple` keeps its in-class default (false).
ClassType::ClassType(Cache *cache, std::string name, std::vector generics,
                     std::vector hiddenGenerics)
    : Type(cache), name(std::move(name)), generics(std::move(generics)),
      hiddenGenerics(std::move(hiddenGenerics)) {}

/// Copies name, generics, hidden generics and the tuple flag from `base`.
/// The cached realized name `_rn` is not copied.
ClassType::ClassType(const ClassType *base)
    : Type(*base), name(base->name), generics(base->generics),
      hiddenGenerics(base->hiddenGenerics), isTuple(base->isTuple) {}

/// Unifies this class type with `typ`.
/// @return -1 if the types cannot be unified. Otherwise a non-negative score:
///         a base of 3 plus the scores of all sub-unifications, or whatever a
///         delegated unify() call returns.
int ClassType::unify(Type *typ, Unification *us) {
  if (auto tc = typ->getClass()) {
    // "int" and "Int" are unified by matching the first generic of "Int"
    // against a static value of 64; the "int"-on-the-left case just swaps sides.
    if (name == "int" && tc->name == "Int")
      return tc->unify(this, us);
    if (tc->name == "int" && name == "Int") {
      auto t64 = std::make_shared(cache, 64);
      return generics[0].type->unify(t64.get(), us);
    }
    if (name == "unrealized_type" && tc->name == name) {
      // instantiate + unify!
      // Each side is instantiated with an emptied cache, so the two
      // instantiations do not share substitutions.
      std::unordered_map genericCache;
      auto l = generics[0].type->instantiate(0, &(cache->unboundCount), &genericCache);
      genericCache.clear();
      auto r =
          tc->generics[0].type->instantiate(0, &(cache->unboundCount), &genericCache);
      return l->unify(r.get(), us);
    }

    int s1 = 3, s = 0;
    if (name == "__NTuple__" && tc->name == name) {
      auto n1 = generics[0].getType()->getIntStatic();
      auto n2 = tc->generics[0].getType()->getIntStatic();
      if (n1 && n2) {
        // Both repeat counts are known: compare the expanded element lists,
        // cycling through each side's own element pattern.
        auto t1 = generics[1].getType()->getClass();
        auto t2 = tc->generics[1].getType()->getClass();
        seqassert(t1 && t2, "bad ntuples");
        if (n1->value * t1->generics.size() != n2->value * t2->generics.size())
          return -1;
        for (size_t i = 0; i < t1->generics.size() * n1->value; i++) {
          if ((s = t1->generics[i % t1->generics.size()].getType()->unify(
                   t2->generics[i % t2->generics.size()].getType(), us)) == -1)
            return -1;
          s1 += s;
        }
        return s1;
      }
      // Otherwise fall through to the generic-by-generic comparison below.
    } else if (tc->name == "__NTuple__") {
      return tc->unify(this, us);
    } else if (name == "__NTuple__" && tc->name == TYPE_TUPLE) {
      auto n1 = generics[0].getType()->getIntStatic();
      if (!n1) {
        auto n = tc->generics.size();
        auto tn = std::make_shared(cache, n);
        // If we are unifying NT[N, T] and T[X, X, ...], we assume that N is number of
        // X's
        if (generics[0].type->unify(tn.get(), us) == -1)
          return -1;

        auto tv = TypecheckVisitor(cache->typeCtx);
        TypePtr tt;
        if (n) {
          // Build a one-element tuple from the first element, then require
          // every remaining element to unify with it.
          tt = tv.instantiateType(tv.generateTuple(1), {tc->generics[0].getType()});
          for (size_t i = 1; i < tc->generics.size(); i++) {
            if ((s = tt->getClass()->generics[0].getType()->unify(
                     tc->generics[i].getType(), us)) == -1)
              return -1;
            s1 += s;
          }
        } else {
          tt = tv.instantiateType(tv.generateTuple(1));
          // tt = tv.instantiateType(tv.generateTuple(0));
        }
        if (generics[1].type->unify(tt.get(), us) == -1)
          return -1;
      } else {
        // Known repeat count: the expanded pattern must match the tuple
        // element by element.
        auto t1 = generics[1].getType()->getClass();
        seqassert(t1, "bad ntuples");
        if (n1->value * t1->generics.size() != tc->generics.size())
          return -1;
        for (size_t i = 0; i < t1->generics.size() * n1->value; i++) {
          if ((s = t1->generics[i % t1->generics.size()].getType()->unify(
                   tc->generics[i].getType(), us)) == -1)
            return -1;
          s1 += s;
        }
      }
      return s1;
    }

    // Check names.
    if (name != tc->name) {
      return -1;
    }
    // Check generics.
    if (generics.size() != tc->generics.size())
      return -1;
    for (int i = 0; i < generics.size(); i++) {
      if ((s = generics[i].type->unify(tc->generics[i].type.get(), us)) == -1) {
        return -1;
      }
      s1 += s;
    }
    // NOTE(review): only `generics` sizes are compared above; this loop indexes
    // tc->hiddenGenerics with this type's count without a size check — confirm
    // the two lists always have equal length for same-named classes.
    for (int i = 0; i < hiddenGenerics.size(); i++) {
      if ((s = hiddenGenerics[i].type->unify(tc->hiddenGenerics[i].type.get(), us)) ==
          -1) {
        return -1;
      }
      s1 += s;
    }
    return s1;
  } else if (auto tl = typ->getLink()) {
    // Let the link type drive the unification.
    return tl->unify(this, us);
  } else {
    return -1;
  }
}

/// Returns a new ClassType whose generics and hidden generics are generalized
/// at `atLevel`; the tuple flag and source info are carried over.
TypePtr ClassType::generalize(int atLevel) const {
  std::vector g, hg;
  for (auto &t : generics)
    g.push_back(t.generalize(atLevel));
  for (auto &t : hiddenGenerics)
    hg.push_back(t.generalize(atLevel));
  auto c = std::make_shared(cache, name, g, hg);
  c->isTuple = isTuple;
  c->setSrcInfo(getSrcInfo());
  return c;
}

/// Returns a new ClassType whose generics and hidden generics are instantiated.
/// Note that the `cache` parameter (the instantiation cache) shadows the
/// member of the same name, hence `this->cache` in the constructor call.
TypePtr ClassType::instantiate(int atLevel, int *unboundCount,
                               std::unordered_map *cache) const {
  std::vector g, hg;
  for (auto &t : generics)
    g.push_back(t.instantiate(atLevel, unboundCount, cache));
  for (auto &t : hiddenGenerics)
    hg.push_back(t.instantiate(atLevel, unboundCount, cache));
  auto c = std::make_shared(this->cache, name, g, hg);
  c->isTuple = isTuple;
  c->setSrcInfo(getSrcInfo());
  return c;
}

/// @return true if any generic or hidden generic has unbounds.
/// "unrealized_type" always reports false; null generic types are skipped.
bool ClassType::hasUnbounds(bool includeGenerics) const {
  if (name == "unrealized_type")
    return false;
  auto pred = [includeGenerics](const auto &t) {
    return t.type && t.type->hasUnbounds(includeGenerics);
  };
  return std::ranges::any_of(generics.begin(), generics.end(), pred) ||
         std::ranges::any_of(hiddenGenerics.begin(), hiddenGenerics.end(), pred);
}

/// Collects the unbounds of all generics and hidden generics.
/// Each sub-result is inserted at the front, so later generics come first in
/// the returned list. "unrealized_type" yields an empty list.
std::vector ClassType::getUnbounds(bool includeGenerics) const {
  std::vector u;
  if (name == "unrealized_type")
    return u;
  for (auto &t : generics)
    if (t.type) {
      auto tu = t.type->getUnbounds(includeGenerics);
      u.insert(u.begin(), tu.begin(), tu.end());
    }
  for (auto &t : hiddenGenerics)
    if (t.type) {
      auto tu = t.type->getUnbounds(includeGenerics);
      u.insert(u.begin(), tu.begin(), tu.end());
    }
  return u;
}

/// @return true if every generic and hidden generic is null or realizable.
/// Special cases: "type" is realizable as soon as hasUnbounds(false) is false,
/// and "unrealized_type" only requires its first generic to be a class.
bool ClassType::canRealize() const {
  if (name == "type") {
    if (!hasUnbounds(false))
      return true; // always true!
  }
  if (name == "unrealized_type")
    return generics[0].type->getClass() != nullptr;
  auto pred = [](auto &t) { return !t.type || t.type->canRealize(); };
  return std::ranges::all_of(generics.begin(), generics.end(), pred) &&
         std::ranges::all_of(hiddenGenerics.begin(), hiddenGenerics.end(), pred);
}

/// @return true if every generic and hidden generic is null or instantiated.
/// "unrealized_type" only requires its first generic to be a class.
bool ClassType::isInstantiated() const {
  if (name == "unrealized_type")
    return generics[0].type->getClass() != nullptr;
  auto pred = [](auto &t) { return !t.type || t.type->isInstantiated(); };
  return std::ranges::all_of(generics.begin(), generics.end(), pred) &&
         std::ranges::all_of(hiddenGenerics.begin(), hiddenGenerics.end(), pred);
}

/// Debug representation of the type.
/// Mode 0 maps names through cache->rev(); mode 2 additionally lists hidden
/// generics (prefixed with "-") and disables the special "Partial" format.
std::string ClassType::debugString(char mode) const {
  if (name == "NamedTuple") {
    if (auto ids = generics[0].type->getIntStatic()) {
      // The static id indexes the table of generated tuple field names.
      auto id = ids->value;
      seqassert(id >= 0 && id < cache->generatedTupleNames.size(), "bad id: {}", id);
      const auto &names = cache->generatedTupleNames[id];
      auto ts = generics[1].getType()->getClass();
      if (names.empty())
        return name;
      std::vector as;
      // NOTE(review): `ts` is used without a null check and is assumed to have
      // at least names.size() generics.
      for (size_t i = 0; i < names.size(); i++)
        as.push_back(names[i] + "=" + ts->generics[i].debugString(mode));
      return name + "[" + join(as, ",") + "]";
    } else {
      return name + "[" + generics[0].type->debugString(mode) + "]";
    }
  } else if (name == "Partial" && generics[3].type->getClass() && mode != 2) {
    // Name: function[full_args](instantiated_args...)
    std::vector as;
    auto known = getPartialMask();
    auto func = getPartialFunc();
    std::vector args;
    if (auto ta = generics[1].type->getClass())
      for (const auto &i : ta->generics)
        args.push_back(i.debugString(mode));
    // ai walks the stored arguments, gi the function generics.
    size_t ai = 0, gi = 0;
    for (size_t i = 0; i < known.size(); i++) {
      if ((*func->ast)[i].isValue()) {
        // Included arguments print as-is; others print as "..." (followed by
        // the argument text unless mode is 0).
        as.emplace_back(
            ai < args.size()
                ? (known[i] == ClassType::PartialFlag::Included
                       ? args[ai]
                       : ("..." + (mode == 0 ? "" : args[ai])))
                : "...");
        if (known[i] == ClassType::PartialFlag::Included)
          ai++;
      } else {
        auto s = func->funcGenerics[gi].debugString(mode);
        as.emplace_back((known[i] == ClassType::PartialFlag::Included
                             ? s
                             : ("..." + (mode == 0 ? "" : s))));
        gi++;
      }
    }
    if (!args.empty()) {
      if (args.back() != "Tuple") // unused *args (by default always 0 in mask)
        as.push_back(args.back());
    }
    auto ks = generics[2].type->debugString(mode);
    if (ks.size() > 10) {                 // if **kwargs is used
      ks = ks.substr(11, ks.size() - 12); // chop off NamedTuple[...]
      as.push_back(ks);
    }
    // Both branches of the else below assign the same value fnname already
    // holds; only mode 0 changes it.
    auto fnname = func->ast->getName();
    if (mode == 0) {
      fnname = cache->rev(func->ast->getName());
    } else {
      fnname = func->ast->getName();
    }
    return fnname + "(" + join(as, ",") + ")";
  }
  // Default format: name[generic,...]; unnamed generics are skipped.
  std::vector gs;
  for (auto &a : generics)
    if (!a.name.empty())
      gs.push_back(a.debugString(mode));
  if ((mode == 2) && !hiddenGenerics.empty()) {
    for (auto &a : hiddenGenerics)
      if (!a.name.empty())
        gs.push_back("-" + a.debugString(mode));
  }
  // Special formatting for Functions and Tuples
  auto n = mode == 0 ? cache->rev(name) : name;
  return n + (gs.empty() ? "" : ("[" + join(gs, ",") + "]"));
}

/// Realized name of the type.
/// The result is stored in `_rn` (through a const_cast) only once the type
/// can be realized; later calls return the stored value.
std::string ClassType::realizedName() const {
  if (!_rn.empty())
    return _rn;

  std::string s;
  if (name == "Partial") {
    s = debugString(1);
  } else {
    std::vector gs;
    if (name == "Union" && generics[0].type->getClass()) {
      // Member names go through a std::set, so they are de-duplicated and
      // sorted before being joined with " | ".
      std::set gss;
      for (auto &a : generics[0].type->getClass()->generics)
        gss.insert(a.realizedName());
      gs = {join(gss, " | ")};
    } else {
      for (auto &a : generics)
        if (!a.name.empty())
          gs.push_back(a.realizedName());
    }
    s = join(gs, ",");
    s = name + (s.empty() ? "" : ("[" + s + "]"));
  }

  if (canRealize())
    const_cast(this)->_rn = s;

  return s;
}

/// @return the function type held in the first generic of generics[3].
/// Asserts that this is a "Partial" and that the stored type is a function.
FuncType *ClassType::getPartialFunc() const {
  seqassert(name == "Partial", "not a partial");
  auto n = generics[3].type->getClass()->generics[0].type;
  seqassert(n->getFunc(), "not a partial func");
  return n->getFunc();
}

/// @return the partial mask: the string static stored in generics[0].
/// debugString() compares its characters against PartialFlag values.
std::string ClassType::getPartialMask() const {
  seqassert(name == "Partial", "not a partial");
  auto n = generics[0].type->getStrStatic()->value;
  return n;
}

/// @return true if generics[1] holds exactly one generic whose class has no
///         generics, and the class in generics[2]'s second generic has none.
/// Unlike the two accessors above, this does not assert name == "Partial".
bool ClassType::isPartialEmpty() const {
  auto a = generics[1].type->getClass();
  auto ka = generics[2].type->getClass();
  return a->generics.size() == 1 &&
         a->generics[0].type->getClass()->generics.empty() &&
         ka->generics[1].type->getClass()->generics.empty();
}

} // namespace codon::ast::types

================================================
FILE: codon/parser/ast/types/class.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include
#include
#include

#include "codon/parser/ast/types/type.h"

namespace codon::ast::types {

struct FuncType;

/**
 * A generic class reference type. All Seq types inherit from this class.
 */
struct ClassType : public Type {
  /**
   * A generic type declaration.
   * Each generic is defined by its unique ID.
   */
  struct Generic {
    // Generic name.
    std::string name;
    // Unique generic ID.
    int id;
    // Pointer to realized type (or generic LinkType).
    TypePtr type;
    // Set if this is a static generic
    LiteralKind staticKind;

    // Note the parameter order (name, type, id, staticKind) differs from the
    // member declaration order (name, id, type, staticKind).
    Generic(std::string name, TypePtr type, int id, LiteralKind staticKind)
        : name(std::move(name)), id(id), type(std::move(type)),
          staticKind(staticKind) { }

    // Non-owning pointer to the generic's type; may be null.
    types::Type *getType() const { return type.get(); }
    Generic generalize(int atLevel) const;
    Generic instantiate(int atLevel, int *unboundCount,
                        std::unordered_map *cache) const;
    std::string debugString(char mode) const;
    std::string realizedName() const;
  };

  /// Canonical type name.
  std::string name;
  /// List of generics, if present.
  std::vector generics;

  /// Additional generics that debugString() prints only in mode 2 and that
  /// realizedName() never includes.
  std::vector hiddenGenerics;

  /// Backs isRecord().
  bool isTuple = false;

  /// Cached realized name; filled by realizedName() once the type can be
  /// realized.
  std::string _rn;

  explicit ClassType(Cache *cache, std::string name, std::vector generics = {},
                     std::vector hiddenGenerics = {});
  explicit ClassType(const ClassType *base);

public:
  int unify(Type *typ, Unification *undo) override;
  TypePtr generalize(int atLevel) const override;
  TypePtr instantiate(int atLevel, int *unboundCount,
                      std::unordered_map *cache) const override;

public:
  bool hasUnbounds(bool) const override;
  std::vector getUnbounds(bool) const override;
  bool canRealize() const override;
  bool isInstantiated() const override;
  std::string debugString(char mode) const override;
  std::string realizedName() const override;
  ClassType *getClass() override { return this; }
  /// @return this type if its name is "Partial", nullptr otherwise.
  ClassType *getPartial() override { return name == "Partial" ? getClass() : nullptr; }
  bool isRecord() const { return isTuple; }
  /// Number of (non-hidden) generics.
  size_t size() const { return generics.size(); }
  /// Type of the i-th (non-hidden) generic; no bounds check.
  Type *operator[](size_t i) const { return generics[i].getType(); }

public:
  /// Character codes used in partial masks: Missing is '0' and the following
  /// enumerators take the next character values ('1' and '2').
  enum PartialFlag { Missing = '0', Included, Default };
  FuncType *getPartialFunc() const;
  std::string getPartialMask() const;
  bool isPartialEmpty() const;
};

} // namespace codon::ast::types

/// fmt formatter for ClassType::Generic: prints "(name = type)", or "(type)"
/// when the generic has no name.
template <> struct fmt::formatter : fmt::formatter {
  template
  auto format(const codon::ast::types::ClassType::Generic &p, FormatContext &ctx) const
      -> decltype(ctx.out()) {
    return fmt::format_to(ctx.out(), "({}{})",
                          p.name.empty() ? "" : fmt::format("{} = ", p.name), p.type);
  }
};

================================================
FILE: codon/parser/ast/types/function.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include
#include
#include

#include "codon/parser/ast/stmt.h"
#include "codon/parser/ast/types/function.h"
#include "codon/parser/cache.h"

namespace codon::ast::types {

/// Builds a function type on top of a copy of `baseType`'s class data.
FuncType::FuncType(const ClassType *baseType, FunctionStmt *ast,
                   std::vector funcGenerics, TypePtr funcParent)
    : ClassType(baseType), ast(ast), funcGenerics(std::move(funcGenerics)),
      funcParent(std::move(funcParent)) {}

/// Unifies this function type with `typ`.
/// @return 0 if `typ` is this very object, -1 on mismatch, otherwise a base
///         score of 2 plus the parent, function-generic and ClassType scores.
int FuncType::unify(Type *typ, Unification *us) {
  if (this == typ)
    return 0;
  int s1 = 2, s = 0;
  if (auto t = typ->getFunc()) {
    // Check if names and parents match.
    // The XOR rejects the case where exactly one side has a parent.
    if (ast->getName() != t->ast->getName() ||
        (static_cast(funcParent) ^ static_cast(t->funcParent)))
      return -1;
    if (funcParent && (s = funcParent->unify(t->funcParent.get(), us)) == -1) {
      return -1;
    }
    s1 += s;
    // Check if function generics match.
    seqassert(funcGenerics.size() == t->funcGenerics.size(),
              "generic size mismatch for {}", ast->getName());
    for (int i = 0; i < funcGenerics.size(); i++) {
      if ((s = funcGenerics[i].type->unify(t->funcGenerics[i].type.get(), us)) == -1)
        return -1;
      s1 += s;
    }
  }
  // The underlying class part is unified in every case, including when `typ`
  // is not a function type.
  s = this->ClassType::unify(typ, us);
  return s == -1 ? s : s1 + s;
}

/// Returns a new FuncType with function generics, parent and the underlying
/// class type generalized at `atLevel`; the AST pointer is shared.
TypePtr FuncType::generalize(int atLevel) const {
  std::vector fg;
  for (auto &t : funcGenerics)
    fg.push_back(t.generalize(atLevel));
  auto p = funcParent ? funcParent->generalize(atLevel) : nullptr;
  auto r = std::static_pointer_cast(this->ClassType::generalize(atLevel));
  auto t = std::make_shared(r->getClass(), ast, fg, p);
  return t;
}

/// Returns a new FuncType with function generics, parent and the underlying
/// class type instantiated; the AST pointer is shared.
TypePtr FuncType::instantiate(int atLevel, int *unboundCount,
                              std::unordered_map *cache) const {
  std::vector fg;
  for (auto &t : funcGenerics) {
    fg.push_back(t.instantiate(atLevel, unboundCount, cache));
    if (cache && fg.back().type) {
      // If the cache already has an entry for this generic's id, overwrite it
      // with the freshly instantiated type.
      if (auto c = in(*cache, t.id))
        *c = fg.back().type;
    }
  }
  auto p = funcParent ? funcParent->instantiate(atLevel, unboundCount, cache) : nullptr;
  auto r = std::static_pointer_cast(
      this->ClassType::instantiate(atLevel, unboundCount, cache));
  auto t = std::make_shared(r->getClass(), ast, fg, p);
  return t;
}

/// @return true if any function generic, the parent, any argument type or the
///         return type has unbounds.
bool FuncType::hasUnbounds(bool includeGenerics) const {
  for (auto &t : funcGenerics)
    if (t.type && t.type->hasUnbounds(includeGenerics))
      return true;
  if (funcParent && funcParent->hasUnbounds(includeGenerics))
    return true;
  for (const auto &a : *this)
    if (a.getType()->hasUnbounds(includeGenerics))
      return true;
  return getRetType()->hasUnbounds(includeGenerics);
}

/// Collects unbounds of function generics, the parent and the argument types.
/// Each sub-result is inserted at the front of the list. Unlike
/// hasUnbounds(), the return type is deliberately not inspected.
std::vector FuncType::getUnbounds(bool includeGenerics) const {
  std::vector u;
  for (auto &t : funcGenerics)
    if (t.type) {
      auto tu = t.type->getUnbounds(includeGenerics);
      u.insert(u.begin(), tu.begin(), tu.end());
    }
  if (funcParent) {
    auto tu = funcParent->getUnbounds(includeGenerics);
    u.insert(u.begin(), tu.begin(), tu.end());
  }
  // Important: return type unbounds are not important, so skip them.
  for (const auto &a : *this) {
    auto tu = a.getType()->getUnbounds(includeGenerics);
    u.insert(u.begin(), tu.begin(), tu.end());
  }
  return u;
}

/// @return true if the function can be realized.
/// Non-realizable arguments (or parent) are tolerated only when the AST has
/// the AllowPassThrough attribute and every unbound in them is a non-Generic
/// link marked passThrough. Function-typed arguments are never checked.
bool FuncType::canRealize() const {
  bool allowPassThrough = ast->hasAttribute(Attr::AllowPassThrough);
  // Important: return type does not have to be realized.
  for (int ai = 0; ai < size(); ai++)
    if (!(*this)[ai]->getFunc() && !(*this)[ai]->canRealize()) {
      if (!allowPassThrough)
        return false;
      for (auto &u : (*this)[ai]->getUnbounds(true))
        if (u->getLink()->kind == LinkType::Generic || !u->getLink()->passThrough)
          return false;
    }
  bool generics =
      std::ranges::all_of(funcGenerics.begin(), funcGenerics.end(),
                          [](auto &a) { return !a.type || a.type->canRealize(); });
  // The parent is only examined when all function generics are realizable.
  if (generics && funcParent && !funcParent->canRealize()) {
    if (!allowPassThrough)
      return false;
    for (auto &u : funcParent->getUnbounds(true)) {
      if (u->getLink()->kind == LinkType::Generic || !u->getLink()->passThrough)
        return false;
    }
  }
  return generics;
}

/// @return true if all function generics, the parent and the underlying class
///         type are instantiated.
bool FuncType::isInstantiated() const {
  // If the return type is a function whose parent is this object, detach that
  // parent for the duration of the check and restore it afterwards
  // (presumably to avoid recursing back into this function — TODO confirm).
  TypePtr removed = nullptr;
  auto retType = getRetType();
  if (retType->getFunc() && retType->getFunc()->funcParent.get() == this) {
    removed = retType->getFunc()->funcParent;
    retType->getFunc()->funcParent = nullptr;
  }
  auto res = std::ranges::all_of(
                 funcGenerics.begin(), funcGenerics.end(),
                 [](auto &a) { return !a.type || a.type->isInstantiated(); }) &&
             (!funcParent || funcParent->isInstantiated()) &&
             this->ClassType::isInstantiated();
  if (removed)
    retType->getFunc()->funcParent = removed;
  return res;
}

/// Debug representation: name[generics;args], with the brackets omitted when
/// both parts are empty. Mode 0 maps the name through cache->rev(); mode 2
/// labels generics and arguments, adds "RET=<return type>" and appends the
/// parent.
std::string FuncType::debugString(char mode) const {
  std::vector gs;
  for (auto &a : funcGenerics)
    if (!a.name.empty())
      gs.push_back(mode < 2 ? a.type->debugString(mode)
                            : (cache->rev(a.name) + "=" + a.type->debugString(mode)));
  std::string s = join(gs, ",");
  std::vector as;
  // Important: return type does not have to be realized.
  if (mode == 2)
    as.push_back("RET=" + getRetType()->debugString(mode));

  if (mode < 2 || !ast) {
    for (const auto &a : *this) {
      as.push_back(a.debugString(mode));
    }
  } else {
    // Mode 2: label each non-generic AST parameter with its name; si walks
    // the argument types in step.
    for (size_t i = 0, si = 0; i < ast->size(); i++) {
      if ((*ast)[i].isGeneric())
        continue;
      as.push_back(((*ast)[i].getName() + "=" + (*this)[si++]->debugString(mode)));
    }
  }
  std::string a = join(as, ",");
  s = s.empty() ? a : join(std::vector{s, a}, ";");

  // NOTE(review): the branch above tolerates a null `ast`, but this assertion
  // and the uses below require it to be non-null.
  seqassert(ast, "ast must not be null");
  auto fnname = ast->getName();
  if (mode == 0) {
    fnname = cache->rev(ast->getName());
  }
  if (mode == 2 && funcParent)
    s += ";" + funcParent->debugString(mode);
  return fnname + (s.empty() ? "" : ("[" + s + "]"));
}

/// Realized name: "<parent>:" (if any) + function name + "[args,generics]".
/// Note the order: argument names come before function-generic names here,
/// the opposite of debugString().
std::string FuncType::realizedName() const {
  std::vector gs;
  for (auto &a : funcGenerics)
    if (!a.name.empty())
      gs.push_back(a.realizedName());
  std::string s = join(gs, ",");
  std::vector as;
  // Important: return type does not have to be realized.
  for (const auto &a : *this)
    as.push_back(a.getType()->getFunc() ? a.getType()->getFunc()->realizedName()
                                        : a.realizedName());
  std::string a = join(as, ",");
  s = s.empty() ? a : join(std::vector{a, s}, ",");
  return (funcParent ? funcParent->realizedName() + ":" : "") + ast->getName() +
         (s.empty() ? "" : ("[" + s + "]"));
}

// Layout used by the accessors below: generics[0] is a class type whose
// generics are the argument types; generics[1] is the return type.

/// @return the return type (second class generic).
Type *FuncType::getRetType() const { return generics[1].type.get(); }

/// @return the name stored in the function's AST node.
std::string FuncType::getFuncName() const { return ast->getName(); }

/// @return the i-th argument type; no bounds check.
Type *FuncType::operator[](size_t i) const {
  return generics[0].type->getClass()->generics[i].getType();
}

/// Iteration over the argument generics.
std::vector::iterator FuncType::begin() const {
  return generics[0].type->getClass()->generics.begin();
}

std::vector::iterator FuncType::end() const {
  return generics[0].type->getClass()->generics.end();
}

/// Number of arguments.
size_t FuncType::size() const { return generics[0].type->getClass()->generics.size(); }

/// @return true if the function has no arguments.
bool FuncType::empty() const { return generics[0].type->getClass()->generics.empty(); }

} // namespace codon::ast::types

================================================
FILE: codon/parser/ast/types/function.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include
#include
#include
#include

#include "codon/parser/ast/types/class.h"
#include "codon/parser/ast/types/type.h"

namespace codon::ast {
struct FunctionStmt;
}

namespace codon::ast::types {

/**
 * A generic type that represents a Codon function instantiation.
 * It inherits ClassType that realizes Function[...].
 */
struct FuncType : public ClassType {
  /// Canonical AST node.
  FunctionStmt *ast;
  /// Function generics (e.g. T in def foo[T](...)).
  std::vector funcGenerics;
  /// Enclosing class or a function.
  TypePtr funcParent;

public:
  /// @param baseType     class type whose data is copied into this function type
  /// @param ast          the function's AST node (stored, not copied)
  /// @param funcGenerics function generics; empty by default
  /// @param funcParent   enclosing class or function; null by default
  FuncType(
      const ClassType *baseType, FunctionStmt *ast,
      std::vector funcGenerics = std::vector(),
      TypePtr funcParent = nullptr);

public:
  int unify(Type *typ, Unification *undo) override;
  TypePtr generalize(int atLevel) const override;
  TypePtr instantiate(int atLevel, int *unboundCount,
                      std::unordered_map *cache) const override;

public:
  bool hasUnbounds(bool) const override;
  std::vector getUnbounds(bool) const override;
  bool canRealize() const override;
  bool isInstantiated() const override;
  std::string debugString(char mode) const override;
  std::string realizedName() const override;
  FuncType *getFunc() override { return this; }

  Type *getRetType() const;
  /// @return non-owning pointer to the enclosing class/function type; may be
  ///         null.
  Type *getParentType() const { return funcParent.get(); }
  std::string getFuncName() const;

  /// Argument access. These declarations hide ClassType::size() and
  /// ClassType::operator[], which refer to the class generics instead.
  Type *operator[](size_t i) const;
  std::vector::iterator begin() const;
  std::vector::iterator end() const;
  size_t size() const;
  bool empty() const;
};

} // namespace codon::ast::types

================================================
FILE: codon/parser/ast/types/link.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
// NOTE(review): angle-bracket contents (template arguments, system include
// targets) appear to have been stripped from this dump; tokens are left as found.
#include
#include
#include

#include "codon/parser/ast/types/link.h"
#include "codon/parser/visitors/format/format.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

namespace codon::ast::types {

/// Full constructor. Asserts the state invariant: `type` is set if and only if
/// the kind is Link (Generic and Unbound must have a null `type`).
LinkType::LinkType(Cache *cache, Kind kind, int id, int level, TypePtr type,
                   LiteralKind staticKind, std::shared_ptr trait,
                   TypePtr defaultType, std::string genericName, bool passThrough)
    : Type(cache), kind(kind), id(id), level(level), type(std::move(type)),
      staticKind(staticKind), trait(std::move(trait)),
      genericName(std::move(genericName)), defaultType(std::move(defaultType)),
      passThrough(passThrough) {
  seqassert((this->type && kind == Link) || (!this->type && kind == Generic) ||
                (!this->type && kind == Unbound),
            "inconsistent link state");
}

/// Creates a Link that points to `type` (id 0, level 0, Runtime static kind).
/// Asserts that `type` is not null.
LinkType::LinkType(TypePtr type)
    : Type(type), kind(Link), id(0), level(0), type(std::move(type)),
      staticKind(Runtime), trait(nullptr), defaultType(nullptr), passThrough(false) {
  seqassert(this->type, "link to nullptr");
}

/// Unifies this link with `typ`.
/// @return -1 on failure; 1 for an identical unbound/generic (or a generic
///         compared non-destructively); 0 when this unbound can be (and, if
///         `undo` is non-null, has been) linked to `typ`; for a Link, whatever
///         the linked type's unify() returns.
/// All mutations are performed only when `undo` is non-null and are recorded in it.
int LinkType::unify(Type *typ, Unification *undo) {
  if (kind == Link) {
    // Case 1: Just follow the link
    return type->unify(typ, undo);
  } else {
    // Case 2: this is an unbound or a generic
    if (getStaticKind() != typ->getStaticKind()) {
      if (!getStaticKind()) {
        // other one is static; adopt its static kind (recorded in undo->statics)
        if (undo) {
          undo->statics.push_back(shared_from_this());
          staticKind = typ->getStaticKind();
        }
      } else {
        // Two different static kinds never unify.
        return -1;
      }
    }
    if (auto t = typ->getLink()) {
      if (t->kind == Link)
        return t->type->unify(this, undo);
      if (kind != t->kind)
        return -1;
      // Identical unbound types get a score of 1
      if (id == t->id)
        return 1;
      // Generics must have matching IDs unless we are doing non-destructive unification
      if (kind == Generic)
        return undo ? -1 : 1;
      // Always merge a newer type into the older type (e.g. keep the types with
      // lower IDs around).
      if (id < t->id)
        return t->unify(this, undo);
    } else if (kind == Generic) {
      return -1;
    }
    // Generics must be handled by now; only unbounds can be unified!
    seqassertn(kind == Unbound, "not an unbound");
    // Ensure that we do not have recursive unification! (e.g. unify ?1 with list[?1])
    if (occurs(typ, undo))
      return -1;
    // Handle traits
    if (trait && trait->unify(typ, undo) == -1)
      return -1;
    // ⚠️ Unification: destructive part.
    seqassert(!type, "type has been already unified or is in inconsistent state");
    if (undo) {
      LOG_TYPECHECK("[unify] {} := {}", id, typ->debugString(2));
      // Link current type to typ and ensure that this modification is recorded in undo.
      undo->linked.push_back(shared_from_this());
      this->kind = Link;
      seqassert(!typ->getLink() || typ->getLink()->kind != Unbound ||
                    typ->getLink()->id <= id,
                "type unification is not consistent");
      this->type = typ->follow()->shared_from_this();
      // If the target is a trait-less unbound, hand our trait over to it
      // (recorded in undo->traits).
      if (auto t = type->getLink();
          t && trait && t->kind == Unbound && !t->trait) {
        undo->traits.push_back(t->shared_from_this());
        t->trait = trait;
      }
    }
    return 0;
  }
}

/// Generalizes this link: a Generic is returned as is; an Unbound whose level
/// is >= atLevel becomes a new Generic with the same id (level 0) and
/// generalized trait/default type, otherwise it is returned as is; a Link
/// delegates to the linked type.
TypePtr LinkType::generalize(int atLevel) const {
  // We need to preserve the pointers as something else might be pointing
  // to this unbound, hence the const_casts because shared_from_this
  // needs it.
  if (kind == Generic) {
    return const_cast(this)->shared_from_this();
  } else if (kind == Unbound) {
    if (level >= atLevel) {
      return std::make_shared(
          cache, Generic, id, 0, nullptr, staticKind,
          trait ? std::static_pointer_cast(trait->generalize(atLevel))
                : nullptr,
          defaultType ? defaultType->generalize(atLevel) : nullptr, genericName,
          passThrough);
    } else {
      return const_cast(this)->shared_from_this();
    }
  } else {
    seqassert(type, "link is null");
    return type->generalize(atLevel);
  }
}

/// Instantiates this link: a Generic becomes a new Unbound at `atLevel`
/// (memoized in `cache` by the generic's id when a cache is supplied; the new
/// id is taken from `*unboundCount`, or the generic's own id is reused when no
/// counter is given); an Unbound is returned as is; a Link delegates.
TypePtr LinkType::instantiate(int atLevel, int *unboundCount,
                              std::unordered_map *cache) const {
  if (kind == Generic) {
    // Reuse the unbound already created for this generic, if any.
    if (TypePtr *res = nullptr; cache && ((res = in(*cache, id))))
      return *res;
    auto t = std::make_shared(
        this->cache, Unbound, unboundCount ? (*unboundCount)++ : id, atLevel, nullptr,
        staticKind,
        trait ? std::static_pointer_cast(
                    trait->instantiate(atLevel, unboundCount, cache))
              : nullptr,
        defaultType ? defaultType->instantiate(atLevel, unboundCount, cache)
                    : nullptr,
        genericName, passThrough);
    if (cache)
      (*cache)[id] = t;
    return t;
  } else if (kind == Unbound) {
    // We need to preserve the pointers as something else might be pointing
    // to this unbound, hence the const_casts because shared_from_this
    // needs it.
    return const_cast(this)->shared_from_this();
  } else {
    seqassert(type, "link is null");
    return type->instantiate(atLevel, unboundCount, cache);
  }
}

/// Follows Link chains; an Unbound or a Generic returns itself.
Type *LinkType::follow() {
  if (kind == Link)
    return type->follow();
  else
    return this;
}

/// Returns this type if it is an Unbound (or a Generic when `includeGenerics`
/// is set), the linked type's unbounds for a Link, and an empty list otherwise.
std::vector LinkType::getUnbounds(bool includeGenerics) const {
  if (kind == Unbound)
    return {(Type *)this};
  else if (kind == Link)
    return type->getUnbounds(includeGenerics);
  else if (includeGenerics)
    return {(Type *)this};
  return {};
}

/// True for an Unbound, for a Generic when `includeGenerics` is set, or when
/// the linked type has unbounds.
bool LinkType::hasUnbounds(bool includeGenerics) const {
  if (kind == Unbound)
    return true;
  if (includeGenerics && kind == Generic)
    return true;
  if (kind == Link)
    return type->hasUnbounds(includeGenerics);
  return false;
}

/// Only a Link can be realized, and only if its target can.
bool LinkType::canRealize() const {
  if (kind != Link)
    return false;
  else
    return type->canRealize();
}

/// Only a Link can be instantiated, and only if its target is.
bool LinkType::isInstantiated() const { return kind == Link && type->isInstantiated(); }

/// Debug string. For an Unbound/Generic:
///  - mode 2: `[<genericName>:]<?|#><id>[:<trait>][:S<staticKind>]`
///    ('?' for an unbound, '#' for a generic);
///  - otherwise the trait's string if a trait is set, else the generic name,
///    else "?" (mode != 0) or "" (mode 0).
/// A Link prints the linked type.
std::string LinkType::debugString(char mode) const {
  if (kind == Unbound || kind == Generic) {
    if (mode == 2) {
      return (genericName.empty() ? "" : genericName + ":") +
             (kind == Unbound ? "?" : "#") + fmt::format("{}", id) +
             (trait ? ":" + trait->debugString(mode) : "") +
             (staticKind ? fmt::format(":S{}", static_cast(staticKind)) : "");
    } else if (trait) {
      return trait->debugString(mode);
    }
    return (genericName.empty() ? (mode ? "?" : "") : genericName);
  }
  return type->debugString(mode);
}

/// Unbounds and generics are both named "#<genericName>"; a Link uses the
/// linked type's realized name.
std::string LinkType::realizedName() const {
  if (kind == Unbound)
    return ("#" + genericName);
  if (kind == Generic)
    return ("#" + genericName);
  seqassert(kind == Link, "unexpected generic link");
  return type->realizedName();
}

LinkType *LinkType::getLink() { return this; }

// The accessors below forward to the linked type and yield nullptr for an
// unbound or a generic.
FuncType *LinkType::getFunc() { return kind == Link ? type->getFunc() : nullptr; }

ClassType *LinkType::getPartial() {
  return kind == Link ? type->getPartial() : nullptr;
}

ClassType *LinkType::getClass() { return kind == Link ? type->getClass() : nullptr; }

StaticType *LinkType::getStatic() {
  return kind == Link ? type->getStatic() : nullptr;
}

IntStaticType *LinkType::getIntStatic() {
  return kind == Link ? type->getIntStatic() : nullptr;
}

StrStaticType *LinkType::getStrStatic() {
  return kind == Link ? type->getStrStatic() : nullptr;
}

BoolStaticType *LinkType::getBoolStatic() {
  return kind == Link ? type->getBoolStatic() : nullptr;
}

UnionType *LinkType::getUnion() { return kind == Link ? type->getUnion() : nullptr; }

/// Returns this type if it is an Unbound, the linked type's unbound for a Link,
/// and nullptr for a Generic.
LinkType *LinkType::getUnbound() {
  if (kind == Unbound)
    return this;
  if (kind == Link)
    return type->getUnbound();
  return nullptr;
}

/// Occurs check: true if an unbound with this type's id appears inside `typ`
/// (directly, through links, through an unbound's trait, or in class generics).
/// Side effect when `undo` is non-null: every other unbound visited whose
/// level is greater than this type's level is lowered to it, and its previous
/// level is recorded in undo->leveled.
bool LinkType::occurs(Type *typ, Type::Unification *undo) {
  if (auto tl = typ->getLink()) {
    if (tl->kind == Unbound) {
      if (tl->id == id)
        return true;
      if (tl->trait && occurs(tl->trait.get(), undo))
        return true;
      if (undo && tl->level > level) {
        undo->leveled.emplace_back(tl->shared_from_this(), tl->level);
        tl->level = level;
      }
      return false;
    } else if (tl->kind == Link) {
      return occurs(tl->type.get(), undo);
    } else {
      return false;
    }
  } else if (typ->getStatic()) {
    return false;
  }
  if (const auto tc = typ->getClass()) {
    return std::ranges::any_of(
        tc->generics.begin(), tc->generics.end(),
        [&](const auto &g) { return g.type && occurs(g.type.get(), undo); });
  } else {
    return false;
  }
}

} // namespace codon::ast::types

================================================
FILE: codon/parser/ast/types/link.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include
#include
#include
#include

#include "codon/parser/ast/types/traits.h"
#include "codon/parser/ast/types/type.h"

namespace codon::ast::types {

/// A type variable that is either unbound, generic, or linked to another type.
struct LinkType : public Type {
  /// Enumeration describing the current state.
  enum Kind { Unbound, Generic, Link } kind;
  /// The unique identifier of an unbound or generic type.
  int id;
  /// The type-checking level of an unbound type.
  int level;
  /// The type this LinkType points to. nullptr if unknown (unbound or generic).
  TypePtr type;
  /// >0 if a type is a static type (e.g. N in Int[N: int]); 0 otherwise.
  LiteralKind staticKind;
  /// Optional trait that unbound type requires prior to unification.
  std::shared_ptr trait;
  /// The generic name of a generic type, if applicable. Used for pretty-printing.
  std::string genericName;
  /// Type that will be used if an unbound is not resolved.
  TypePtr defaultType;
  /// Set if this type can be used unrealized as function argument during function
  /// realization.
  bool passThrough;

public:
  /// Full constructor; `type` must be non-null exactly when `kind` is Link.
  LinkType(Cache *cache, Kind kind, int id, int level = 0, TypePtr type = nullptr,
           LiteralKind staticKind = LiteralKind::Runtime,
           std::shared_ptr trait = nullptr, TypePtr defaultType = nullptr,
           std::string genericName = "", bool passThrough = false);
  /// Convenience constructor for linked types.
  explicit LinkType(TypePtr type);

public:
  int unify(Type *typ, Unification *undo) override;
  TypePtr generalize(int atLevel) const override;
  TypePtr instantiate(int atLevel, int *unboundCount,
                      std::unordered_map *cache) const override;

public:
  Type *follow() override;

  bool hasUnbounds(bool) const override;
  std::vector getUnbounds(bool) const override;
  bool canRealize() const override;
  bool isInstantiated() const override;

  std::string debugString(char mode) const override;
  std::string realizedName() const override;

  LinkType *getLink() override;
  FuncType *getFunc() override;
  ClassType *getPartial() override;
  ClassType *getClass() override;
  StaticType *getStatic() override;
  IntStaticType *getIntStatic() override;
  StrStaticType *getStrStatic() override;
  BoolStaticType *getBoolStatic() override;
  UnionType *getUnion() override;
  LinkType *getUnbound() override;

private:
  /// Checks if a current (unbound) type occurs within a given type.
  /// Needed to prevent a recursive unification (e.g. ?1 with list[?1]).
  bool occurs(Type *typ, Type::Unification *undo);
};

} // namespace codon::ast::types

================================================
FILE: codon/parser/ast/types/static.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
// NOTE(review): angle-bracket contents (template arguments, system include
// targets) appear to have been stripped from this dump; tokens are left as found.
#include
#include

#include "codon/parser/ast.h"
#include "codon/parser/ast/types/static.h"
#include "codon/parser/cache.h"
#include "codon/parser/common.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

namespace codon::ast::types {

/// Base of all static (literal) types; `typeName` is forwarded to ClassType.
StaticType::StaticType(Cache *cache, const std::string &typeName)
    : ClassType(cache, typeName) {}

/// Static types always report themselves as realizable.
bool StaticType::canRealize() const { return true; }

/// Static types always report themselves as instantiated.
bool StaticType::isInstantiated() const { return true; }

/// The realized name is the mode-0 debug string (i.e. the literal value).
std::string StaticType::realizedName() const { return debugString(0); }

/// Looks up the class registered in the cache under this type's name.
Type *StaticType::getNonStaticType() const { return cache->findClass(name); }

/*****************************************************************/

/// Static integer literal; the underlying class name is "int".
IntStaticType::IntStaticType(Cache *cache, int64_t i)
    : StaticType(cache, "int"), value(i) {}

/// Unifies with another type:
///  - another int static: 1 if the values are equal, -1 otherwise;
///  - any other class type: result of ClassType::unify;
///  - a link: delegated to the link's unify();
///  - anything else: -1.
int IntStaticType::unify(Type *typ, Unification *us) {
  if (auto t = typ->getIntStatic()) {
    return value == t->value ? 1 : -1;
  } else if (auto c = typ->getClass()) {
    return ClassType::unify(c, us);
  } else if (auto tl = typ->getLink()) {
    return tl->unify(this, us);
  } else {
    return -1;
  }
}

/// Returns a fresh copy carrying the same value and source info (`atLevel` is unused).
TypePtr IntStaticType::generalize(int atLevel) const {
  auto c = std::make_shared(cache, value);
  c->setSrcInfo(getSrcInfo());
  return c;
}

/// Returns a fresh copy carrying the same value and source info (all arguments
/// are unused; note that the `cache` parameter shadows the member).
TypePtr IntStaticType::instantiate(int atLevel, int *unboundCount,
                                   std::unordered_map *cache) const {
  auto c = std::make_shared(this->cache, value);
  c->setSrcInfo(getSrcInfo());
  return c;
}

/// Modes below 2 print the bare value; otherwise `Literal[<value>]`.
std::string IntStaticType::debugString(char mode) const {
  return mode < 2 ? fmt::format("{}", value) : fmt::format("Literal[{}]", value);
}

/// Creates an expression node for the value via the cache.
Expr *IntStaticType::getStaticExpr() const { return cache->N(value); }

/*****************************************************************/

/// Static string literal; the underlying class name is "str".
StrStaticType::StrStaticType(Cache *cache, std::string s)
    : StaticType(cache, "str"), value(std::move(s)) {}

/// Same scheme as IntStaticType::unify, comparing against other str statics.
int StrStaticType::unify(Type *typ, Unification *us) {
  if (auto t = typ->getStrStatic()) {
    return value == t->value ? 1 : -1;
  } else if (auto c = typ->getClass()) {
    return ClassType::unify(c, us);
  } else if (auto tl = typ->getLink()) {
    return tl->unify(this, us);
  } else {
    return -1;
  }
}

/// Returns a fresh copy carrying the same value and source info (`atLevel` is unused).
TypePtr StrStaticType::generalize(int atLevel) const {
  auto c = std::make_shared(cache, value);
  c->setSrcInfo(getSrcInfo());
  return c;
}

/// Returns a fresh copy carrying the same value and source info (arguments unused).
TypePtr StrStaticType::instantiate(int atLevel, int *unboundCount,
                                   std::unordered_map *cache) const {
  auto c = std::make_shared(this->cache, value);
  c->setSrcInfo(getSrcInfo());
  return c;
}

/// Modes below 2 print the escaped value in single quotes; otherwise
/// `Literal['<escaped value>']`.
std::string StrStaticType::debugString(char mode) const {
  return mode < 2 ? fmt::format("'{}'", escape(value))
                  : fmt::format("Literal['{}']", escape(value));
}

/// Creates an expression node for the value via the cache.
Expr *StrStaticType::getStaticExpr() const { return cache->N(value); }

/*****************************************************************/

/// Static boolean literal; the underlying class name is "bool".
BoolStaticType::BoolStaticType(Cache *cache, bool b)
    : StaticType(cache, "bool"), value(b) {}

/// Same scheme as IntStaticType::unify, comparing against other bool statics.
int BoolStaticType::unify(Type *typ, Unification *us) {
  if (auto t = typ->getBoolStatic()) {
    return value == t->value ? 1 : -1;
  } else if (auto c = typ->getClass()) {
    return ClassType::unify(c, us);
  } else if (auto tl = typ->getLink()) {
    return tl->unify(this, us);
  } else {
    return -1;
  }
}

/// Returns a fresh copy carrying the same value and source info (`atLevel` is unused).
TypePtr BoolStaticType::generalize(int atLevel) const {
  auto c = std::make_shared(cache, value);
  c->setSrcInfo(getSrcInfo());
  return c;
}

/// Returns a fresh copy carrying the same value and source info (arguments unused).
TypePtr BoolStaticType::instantiate(int atLevel, int *unboundCount,
                                    std::unordered_map *cache) const {
  auto c = std::make_shared(this->cache, value);
  c->setSrcInfo(getSrcInfo());
  return c;
}

/// Modes below 2 print "1"/"0"; otherwise `Literal[True]`/`Literal[False]`.
std::string BoolStaticType::debugString(char mode) const {
  return mode < 2 ? (value ? "1" : "0")
                  : fmt::format("Literal[{}]", value ? "True" : "False");
}

/// Creates an expression node for the value via the cache.
Expr *BoolStaticType::getStaticExpr() const { return cache->N(value); }

} // namespace codon::ast::types

================================================
FILE: codon/parser/ast/types/static.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once

// NOTE(review): angle-bracket contents (template arguments, system include
// targets) appear to have been stripped from this dump; tokens are left as found.
#include
#include
#include

#include "codon/parser/ast/types/class.h"

namespace codon::ast::types {

/// Common base of the static (literal) types.
struct StaticType : public ClassType {
  explicit StaticType(Cache *, const std::string &);

public:
  bool canRealize() const override;
  bool isInstantiated() const override;
  std::string realizedName() const override;
  /// Expression node that represents the literal value.
  virtual Expr *getStaticExpr() const = 0;
  /// Kind of the literal (Int, String or Bool in the subclasses below).
  virtual LiteralKind getStaticKind() = 0;
  /// Class type looked up in the cache under this type's name.
  virtual Type *getNonStaticType() const;
  StaticType *getStatic() override { return this; }
};

/// Static integer literal.
struct IntStaticType : public StaticType {
  int64_t value;

  explicit IntStaticType(Cache *cache, int64_t);

  int unify(Type *typ, Unification *undo) override;
  TypePtr generalize(int atLevel) const override;
  TypePtr instantiate(int atLevel, int *unboundCount,
                      std::unordered_map *cache) const override;

  std::string debugString(char mode) const override;
  Expr *getStaticExpr() const override;
  LiteralKind getStaticKind() override { return LiteralKind::Int; }

  IntStaticType *getIntStatic() override { return this; }
};

/// Static string literal.
struct StrStaticType : public StaticType {
  std::string value;

  explicit StrStaticType(Cache *cache, std::string);

  int unify(Type *typ, Unification *undo) override;
  TypePtr generalize(int atLevel) const override;
  TypePtr instantiate(int atLevel, int *unboundCount,
                      std::unordered_map *cache) const override;

  std::string debugString(char mode) const override;
  Expr *getStaticExpr() const override;
  LiteralKind getStaticKind() override { return LiteralKind::String; }

  StrStaticType *getStrStatic() override { return this; }
};

/// Static boolean literal.
struct BoolStaticType : public StaticType {
  bool value;

  explicit BoolStaticType(Cache *cache, bool);

  int unify(Type *typ, Unification *undo) override;
  TypePtr generalize(int atLevel) const override;
  TypePtr instantiate(int atLevel, int *unboundCount,
                      std::unordered_map *cache) const override;

  std::string debugString(char mode) const override;
  Expr *getStaticExpr() const override;
  LiteralKind getStaticKind() override { return LiteralKind::Bool; }

  BoolStaticType *getBoolStatic() override { return this; }
};

using StaticTypePtr = std::shared_ptr;
using IntStaticTypePtr = std::shared_ptr;
using StrStaticTypePtr = std::shared_ptr;

} // namespace codon::ast::types

================================================
FILE: codon/parser/ast/types/traits.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include
#include
#include

#include "codon/parser/ast.h"
#include "codon/parser/cache.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

namespace codon::ast::types {

/// Copies the cache and source info from `type` (via the Type constructor).
Trait::Trait(const std::shared_ptr &type) : Type(type) {}

Trait::Trait(Cache *cache) : Type(cache) {}

/// Traits are never realizable.
bool Trait::canRealize() const { return false; }

/// Traits are never instantiated.
bool Trait::isInstantiated() const { return false; }

/// Traits have no realized name.
std::string Trait::realizedName() const { return ""; }

/// @param args Trait arguments. unify() reads args[0] as the class type
///             holding the argument types and args[1] as the return type.
CallableTrait::CallableTrait(Cache *cache, std::vector args)
    : Trait(cache), args(std::move(args)) {}

/// Checks whether `typ` satisfies this callable trait.
/// @return 1 on success, -1 on failure. Links are followed; an unbound with a
///         trait must carry a CallableTrait with pairwise-unifiable arguments.
int CallableTrait::unify(Type *typ, Unification *us) {
  /// TODO: one day merge with the CallExpr's logic...
  if (auto tr = typ->getClass()) {
    TypePtr ft = nullptr;
    if (typ->is("TypeWrap")) {
      // For a TypeWrap, check its instantiated `__call_no_self__` method instead.
      // NOTE(review): .front() assumes findMethod() returned at least one
      // method here; confirm.
      TypecheckVisitor tv(cache->typeCtx);
      ft = tv.instantiateType(
          tv.findMethod(typ->getClass(), "__call_no_self__").front(), typ->getClass());
      tr = ft->getClass();
    }

    // NoneType is accepted unconditionally.
    if (tr->name == "NoneType")
      return 1;
    // Only functions and partials that are records can be callable.
    if (tr->name != "Function" && !tr->getPartial())
      return -1;
    if (!tr->isRecord())
      return -1;
    // A trait without arguments accepts any callable.
    if (args.empty())
      return 1;

    // Per-argument mask of the callable; compared against
    // ClassType::PartialFlag::Included below.
    std::string known;
    TypePtr func = nullptr; // trFun can point to it
    auto trFun = tr;
    if (auto pt = tr->getPartial()) {
      // Partial: instantiate the underlying function and unify the arguments
      // flagged Included with the types stored in the partial.
      int ic = 0;
      std::unordered_map c;
      func = pt->getPartialFunc()->instantiate(0, &ic, &c);
      trFun = func->getClass();
      known = pt->getPartialMask();

      auto knownArgTypes = pt->generics[1].type->getClass();
      // i: AST argument index; j: number of generic AST arguments seen so far;
      // k: index of the next stored argument type.
      for (size_t i = 0, j = 0, k = 0; i < known.size(); i++)
        if ((*func->getFunc()->ast)[i].isGeneric()) {
          j++;
        } else if (known[i] == ClassType::PartialFlag::Included) {
          if ((*func->getFunc())[i - j]->unify(knownArgTypes->generics[k].type.get(),
                                               us) == -1)
            return -1;
          k++;
        }
    } else {
      // Plain function: every argument is flagged Missing.
      known = std::string(tr->generics[0].type->getClass()->generics.size(),
                          ClassType::PartialFlag::Missing);
    }

    // inArgs: argument types requested by the trait;
    // trInArgs: argument types of the callable being checked.
    auto inArgs = args[0]->getClass();
    auto trInArgs = trFun->generics[0].type->getClass();
    auto trAst = trFun->getFunc() ? trFun->getFunc()->ast : nullptr;
    size_t star = 0, kwStar = trInArgs->generics.size();
    size_t total = 0;
    if (trAst) {
      star = trAst->getStarArgs();
      kwStar = trAst->getKwStarArgs();
      // Shift both indices down by one for every non-value AST argument that
      // precedes them.
      for (size_t fi = 0; fi < trAst->size(); fi++) {
        if (fi < star && !(*trAst)[fi].isValue())
          star--;
        if (fi < kwStar && !(*trAst)[fi].isValue())
          kwStar--;
      }
      if (kwStar < trAst->size() && star >= trInArgs->generics.size())
        star -= 1;
      // total: value arguments that are not the kwStar slot, not already
      // Included and whose name does not start with '$';
      // preStar: those of them located before `star`.
      size_t preStar = 0;
      for (size_t fi = 0; fi < trAst->size(); fi++) {
        if (fi != kwStar && known[fi] != ClassType::PartialFlag::Included &&
            (*trAst)[fi].isValue() && !startswith((*trAst)[fi].getName(), "$")) {
          total++;
          if (fi < star)
            preStar++;
        }
      }
      // If some counted argument sits at or after `star`, only a minimum
      // count is enforced; otherwise the count must match exactly.
      if (preStar < total) {
        if (inArgs->generics.size() < preStar)
          return -1;
      } else if (inArgs->generics.size() != total) {
        return -1;
      }
    } else {
      // No AST available: the argument counts must match exactly.
      total = star = trInArgs->generics.size();
      if (inArgs->generics.size() != total)
        return -1;
    }
    // Unify the arguments located before `star` one by one.
    size_t i = 0;
    for (size_t fi = 0; i < inArgs->generics.size() && fi < star; fi++) {
      if (known[fi] != ClassType::PartialFlag::Included && trAst &&
          (*trAst)[fi].isValue() && !startswith((*trAst)[fi].getName(), "$")) {
        if (inArgs->generics[i++].type->unify(trInArgs->generics[fi].type.get(), us) ==
            -1)
          return -1;
      }
    }

    auto tv = TypecheckVisitor(cache->typeCtx);
    if (auto pf = trFun->getFunc()) {
      // Make sure to set types of *args/**kwargs so that the function that
      // is being unified with CallableTrait[] can be realized
      if (star < trInArgs->generics.size() - (kwStar < trInArgs->generics.size())) {
        // Collect the types for the `star` slot: those stored in the partial
        // (if any) followed by the remaining requested argument types.
        std::vector starArgTypes;
        if (auto tp = tr->getPartial()) {
          auto ts = tp->generics[1].type->getClass();
          seqassert(ts && !ts->generics.empty() &&
                        ts->generics[ts->generics.size() - 1].type->getClass(),
                    "bad partial *args/**kwargs");
          for (auto &tt :
               ts->generics[ts->generics.size() - 1].type->getClass()->generics)
            starArgTypes.push_back(tt.getType());
        }
        for (; i < inArgs->generics.size(); i++)
          starArgTypes.push_back(inArgs->generics[i].getType());

        if ((*pf->ast)[star].getType()) {
          // if we have *args: type, use those types
          auto starTyp =
              tv.extractType(tv.transform(clone((*pf->ast)[star].getType())));
          for (auto &t : starArgTypes)
            t = starTyp;
        }

        auto tn =
            tv.instantiateType(tv.generateTuple(starArgTypes.size()), starArgTypes);
        if (tn->unify(trInArgs->generics[star].type.get(), us) == -1)
          return -1;
      }
      if (kwStar < trInArgs->generics.size()) {
        // Build the NamedTuple type for the kwStar slot: empty by default, or
        // taken from the partial's generics[2] when `tr` is a partial.
        auto tt = tv.generateTuple(0);
        size_t id = 0;
        if (auto tp = tr->getPartial()) {
          auto ts = tp->generics[2].type->getClass();
          seqassert(ts && ts->is("NamedTuple"), "bad partial *args/**kwargs");
          id = ts->generics[0].type->getIntStatic()->value;
          tt = ts->generics[1].getType()->getClass();
        }
        auto tid = std::make_shared(cache, id);
        auto kt = tv.instantiateType(tv.getStdLibType("NamedTuple"), {tid.get(), tt});
        if (kt->unify(trInArgs->generics[kwStar].type.get(), us) == -1)
          return -1;
      }

      if (us && pf->canRealize()) {
        // Realize if possible to allow deduction of return type
        auto rf = tv.realize(pf);
        pf->unify(rf, us);
      }
      // Finally, the return types must unify.
      if (args[1]->unify(pf->getRetType(), us) == -1)
        return -1;
    }
    return 1;
  } else if (auto tl = typ->getLink()) {
    if (tl->kind == LinkType::Link)
      return unify(tl->type.get(), us);
    if (tl->kind == LinkType::Unbound) {
      if (tl->trait) {
        // An unbound that already has a trait must carry a CallableTrait with
        // the same number of arguments, each of which must unify.
        auto tt = dynamic_cast(tl->trait.get());
        if (!tt || tt->args.size() != args.size())
          return -1;
        // NOTE(review): `int i` is compared against a size_t here
        // (signed/unsigned comparison).
        for (int i = 0; i < args.size(); i++)
          if (args[i]->unify(tt->args[i].get(), us) == -1)
            return -1;
      }
      return 1;
    }
  }
  return -1;
}

/// Returns a new CallableTrait whose non-null arguments are generalized.
TypePtr CallableTrait::generalize(int atLevel) const {
  auto g = args;
  for (auto &t : g)
    t = t ? t->generalize(atLevel) : nullptr;
  auto c = std::make_shared(cache, g);
  c->setSrcInfo(getSrcInfo());
  return c;
}

/// Returns a new CallableTrait whose non-null arguments are instantiated.
TypePtr CallableTrait::instantiate(int atLevel, int *unboundCount,
                                   std::unordered_map *cache) const {
  auto g = args;
  for (auto &t : g)
    t = t ? t->instantiate(atLevel, unboundCount, cache) : nullptr;
  auto c = std::make_shared(this->cache, g);
  c->setSrcInfo(getSrcInfo());
  return c;
}

/// Prints `CallableTrait[<args>,<ret>]`; a leading "Tuple" is dropped from the
/// argument string.
std::string CallableTrait::debugString(char mode) const {
  auto s = args[0]->debugString(mode);
  return fmt::format("CallableTrait[{},{}]", startswith(s, "Tuple") ? s.substr(5) : s,
                     args[1]->debugString(mode));
}

/// Trait that wraps a single type.
TypeTrait::TypeTrait(TypePtr typ) : Trait(typ), type(std::move(typ)) {}

/// A class type is unified with the wrapped type; an unbound is accepted with
/// a score of 0; anything else fails with -1.
int TypeTrait::unify(Type *typ, Unification *us) {
  if (typ->getClass()) {
    // does not make sense otherwise and results in infinite cycles
    return typ->unify(type.get(), us);
  }
  if (typ->getUnbound())
    return 0;
  return -1;
}

/// Returns a new TypeTrait around the generalized wrapped type.
TypePtr TypeTrait::generalize(int atLevel) const {
  auto c = std::make_shared(type->generalize(atLevel));
  c->setSrcInfo(getSrcInfo());
  return c;
}

/// Returns a new TypeTrait around the instantiated wrapped type.
TypePtr TypeTrait::instantiate(int atLevel, int *unboundCount,
                               std::unordered_map *cache) const {
  auto c = std::make_shared(type->instantiate(atLevel, unboundCount, cache));
  c->setSrcInfo(getSrcInfo());
  return c;
}

/// Prints `Trait[<class name>]`, or `Trait[-]` if the wrapped type is not a class.
std::string TypeTrait::debugString(char mode) const {
  return fmt::format("Trait[{}]", type->getClass() ? type->getClass()->name : "-");
}

} // namespace codon::ast::types

================================================
FILE: codon/parser/ast/types/traits.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once #include #include #include #include #include "codon/parser/ast/types/type.h" namespace codon::ast::types { struct Trait : public Type { bool canRealize() const override; bool isInstantiated() const override; std::string realizedName() const override; protected: explicit Trait(const std::shared_ptr &); explicit Trait(Cache *); }; struct CallableTrait : public Trait { std::vector args; // tuple with arg types, ret type public: explicit CallableTrait(Cache *cache, std::vector args); int unify(Type *typ, Unification *undo) override; TypePtr generalize(int atLevel) const override; TypePtr instantiate(int atLevel, int *unboundCount, std::unordered_map *cache) const override; std::string debugString(char mode) const override; }; struct TypeTrait : public Trait { TypePtr type; public: explicit TypeTrait(TypePtr type); int unify(Type *typ, Unification *undo) override; TypePtr generalize(int atLevel) const override; TypePtr instantiate(int atLevel, int *unboundCount, std::unordered_map *cache) const override; std::string debugString(char mode) const override; }; } // namespace codon::ast::types ================================================ FILE: codon/parser/ast/types/type.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include #include #include #include "codon/parser/ast/types/type.h" #include "codon/parser/visitors/format/format.h" #include "codon/parser/visitors/typecheck/typecheck.h" namespace codon::ast::types { /// Undo a destructive unification. 
/// Walks `linked` and `leveled` back to front: linked types become unbound
/// again (null target) and recorded levels are restored. Recorded traits are
/// reset to nullptr and recorded static kinds to LiteralKind::Runtime.
void Type::Unification::undo() {
  for (size_t i = linked.size(); i-- > 0;) {
    linked[i]->getLink()->kind = LinkType::Unbound;
    linked[i]->getLink()->type = nullptr;
  }
  for (size_t i = leveled.size(); i-- > 0;) {
    seqassertn(leveled[i].first->getLink()->kind == LinkType::Unbound,
               "not unbound [{}]", leveled[i].first->getSrcInfo());
    leveled[i].first->getLink()->level = leveled[i].second;
  }
  for (auto &t : traits)
    t->getLink()->trait = nullptr;
  for (auto &t : statics)
    t->getLink()->staticKind = LiteralKind::Runtime;
}

/// Copies the cache pointer and the source info from `typ`.
Type::Type(const std::shared_ptr &typ) : cache(typ->cache) {
  setSrcInfo(typ->getSrcInfo());
}

Type::Type(Cache *cache, const SrcInfo &info) : cache(cache) { setSrcInfo(info); }

/// Default: a type is its own final type.
Type *Type::follow() { return this; }

/// Default: no unbounds.
bool Type::hasUnbounds(bool) const { return false; }

/// Default: empty list.
std::vector Type::getUnbounds(bool) const { return {}; }

/// Debug string (mode 2).
std::string Type::toString() const { return debugString(2); }

/// Pretty string (mode 0).
std::string Type::prettyString() const { return debugString(0); }

/// True if this is a class type whose name equals `s`.
bool Type::is(const std::string &s) { return getClass() && getClass()->name == s; }

/// Static kind of this type: the static type's own kind if this is a static
/// type, the followed link's staticKind if it ends in a link, Runtime otherwise.
LiteralKind Type::getStaticKind() {
  if (auto s = getStatic())
    return s->getStaticKind();
  if (auto l = follow()->getLink())
    return l->staticKind;
  return LiteralKind::Runtime;
}

/// Maps "int"/"str"/"bool" to the matching literal kind; anything else is Runtime.
LiteralKind Type::literalFromString(const std::string &s) {
  if (s == "int")
    return LiteralKind::Int;
  if (s == "str")
    return LiteralKind::String;
  if (s == "bool")
    return LiteralKind::Bool;
  return LiteralKind::Runtime;
}

/// Inverse of literalFromString(); Runtime maps to the empty string.
std::string Type::stringFromLiteral(LiteralKind k) {
  if (k == LiteralKind::Int)
    return "int";
  if (k == LiteralKind::String)
    return "str";
  if (k == LiteralKind::Bool)
    return "bool";
  return "";
}

/// Attempts a destructive unification with `t`.
/// @return this on success; nullptr on failure, after rolling back every
///         change recorded during the attempt.
Type *Type::operator<<(Type *t) {
  seqassert(t, "rhs is nullptr");
  types::Type::Unification undo;
  if (unify(t, &undo) >= 0) {
    return this;
  } else {
    undo.undo();
    return nullptr;
  }
}

} // namespace codon::ast::types

================================================
FILE: codon/parser/ast/types/type.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

// NOTE(review): angle-bracket contents (template arguments, system include
// targets) appear to have been stripped from this dump; tokens are left as found.
#include
#include
#include
#include

#include "codon/parser/common.h"

namespace codon::ast {
struct Cache;
struct Expr;
struct TypeContext;
} // namespace codon::ast

namespace codon::ast::types {

/// Forward declarations
struct ClassType;
struct FuncType;
struct LinkType;
struct StaticType;
struct IntStaticType;
struct StrStaticType;
struct BoolStaticType;
struct UnionType;

/// Kind of a static (literal) type; Runtime (0) means "not static".
enum LiteralKind { Runtime, Int, String, Bool };

/**
 * An abstract type class that describes methods needed for the type inference.
 * (Hindley-Milner's Algorithm W inference; see
 * https://github.com/tomprimozic/type-systems).
 *
 * Type instances are mutable and each type is intended to be instantiated and
 * manipulated as a shared_ptr.
 */
struct Type : public codon::SrcObject, public std::enable_shared_from_this {
  /// A structure that keeps the list of unification steps that can be undone later.
  /// Needed because the unify() is destructive.
  struct Unification {
    /// List of unbound types that have been changed.
    std::vector> linked;
    /// List of unbound types whose level has been changed.
    std::vector, int>> leveled;
    /// List of assigned traits.
    std::vector> traits;
    /// List of unbound types whose static status has been changed.
    std::vector> statics;

  public:
    /// Undo the unification step.
    void undo();
  };

public:
  /// Unifies a given type with the current type.
  /// @param typ A given type.
  /// @param undo A pointer to the Unification structure that tracks the unification
  ///             steps and allows later undoing of the unification procedure.
  /// @return Unification score: -1 for failure, anything >= 0 for success.
  ///         Higher score translates to a "better" unification.
  /// ⚠️ Destructive operation if undo is not null!
  ///    (both the current and a given type are modified).
  virtual int unify(Type *typ, Unification *undo) = 0;
  /// Generalize unbound types with respect to the provided level
  /// (LinkType::generalize converts unbounds whose level is >= atLevel).
  /// This method replaces such unbound types with generic types (e.g. ?1 -> T1).
  /// Note that the generalized type keeps the unbound type's ID.
  virtual std::shared_ptr generalize(int atLevel) const = 0;
  /// Instantiate all generic types. Inverse of generalize(): it replaces all
  /// generic types with new unbound types (e.g. T1 -> ?1234).
  /// Note that the instantiated type has a distinct and unique ID.
  /// @param atLevel Level of the instantiation.
  /// @param unboundCount A pointer to the unbound counter to ensure that no two
  ///                     unbound types share the same ID.
  /// @param cache A pointer to a lookup table to ensure that all instances of a
  ///              generic point to the same unbound type (e.g. dict[T, list[T]] should
  ///              be instantiated as dict[?1, list[?1]]).
  virtual std::shared_ptr
  instantiate(int atLevel, int *unboundCount,
              std::unordered_map> *cache) const = 0;

public:
  /// Get the final type (follow through all LinkType links).
  /// For example, for (a->b->c->d) it returns d.
  virtual Type *follow();

  /// Check if type has unbound/generic types.
  virtual bool hasUnbounds(bool includeGenerics) const;

  /// Obtain the list of internal unbound types.
  virtual std::vector getUnbounds(bool includeGenerics) const;

  /// True if a type is realizable.
  virtual bool canRealize() const = 0;

  /// True if a type is completely instantiated (has no unbounds or generics).
  virtual bool isInstantiated() const = 0;

  /// Debug print facility (debugString with mode 2).
  std::string toString() const;

  /// Pretty-print facility (debugString with mode 0).
  std::string prettyString() const;

  /// Pretty-print facility. mode is [0: pretty, 1: llvm, 2: debug]
  virtual std::string debugString(char mode) const = 0;

  /// Print the realization string.
  /// Similar to toString, but does not print the data unnecessary for realization
  /// (e.g. the function return type).
  virtual std::string realizedName() const = 0;

  /// Static (literal) kind of this type; Runtime if it is not static.
  LiteralKind getStaticKind();

  /// Convenience virtual functions to avoid unnecessary casts.
  virtual FuncType *getFunc() { return nullptr; }
  virtual ClassType *getPartial() { return nullptr; }
  virtual ClassType *getClass() { return nullptr; }
  virtual LinkType *getLink() { return nullptr; }
  virtual LinkType *getUnbound() { return nullptr; }
  virtual StaticType *getStatic() { return nullptr; }
  virtual IntStaticType *getIntStatic() { return nullptr; }
  virtual StrStaticType *getStrStatic() { return nullptr; }
  virtual BoolStaticType *getBoolStatic() { return nullptr; }
  virtual UnionType *getUnion() { return nullptr; }

  /// True if this is a class type with the given name.
  virtual bool is(const std::string &s);

  /// Destructively unifies with `t`; returns this on success and nullptr
  /// (after undoing the attempt) on failure.
  Type *operator<<(Type *t);

  static LiteralKind literalFromString(const std::string &s);
  static std::string stringFromLiteral(LiteralKind k);

protected:
  Cache *cache;
  explicit Type(const std::shared_ptr &);
  explicit Type(Cache *, const SrcInfo & = SrcInfo());
};
using TypePtr = std::shared_ptr;

} // namespace codon::ast::types

/// fmt formatter for the type classes. The optional presentation character
/// selects the debugString() mode: 'a' -> 0, 'b' (the default) -> 1, 'c' -> 2.
template
struct fmt::formatter<
    T, std::enable_if_t, char>>
    : fmt::formatter {
  char presentation = 'b';

  /// Consumes a single optional presentation character ('a', 'b' or 'c').
  constexpr auto parse(const format_parse_context &ctx) -> decltype(ctx.begin()) {
    auto it = ctx.begin();
    if (const auto end = ctx.end();
        it != end && (*it == 'a' || *it == 'b' || *it == 'c'))
      presentation = *it++;
    return it;
  }

  /// Writes p.debugString(mode) for the mode selected by `presentation`.
  template
  auto format(const T &p, FormatContext &ctx) const -> decltype(ctx.out()) {
    if (presentation == 'a')
      return fmt::format_to(ctx.out(), "{}", p.debugString(0));
    else if (presentation == 'b')
      return fmt::format_to(ctx.out(), "{}", p.debugString(1));
    else
      return fmt::format_to(ctx.out(), "{}", p.debugString(2));
  }
};

================================================
FILE: codon/parser/ast/types/union.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include #include #include
#include "codon/parser/ast.h"
#include "codon/parser/cache.h"
#include "codon/parser/visitors/typecheck/typecheck.h"
#include

namespace codon::ast::types {

/// Construct an empty union: a ClassType named "Union" with MAX_UNION pending
/// slots, each a LinkType of kind Generic whose ID is the slot index.
UnionType::UnionType(Cache *cache) : ClassType(cache, "Union") {
  isTuple = true;
  for (size_t i = 0; i < MAX_UNION; i++)
    pendingTypes.emplace_back(
        std::make_shared(cache, LinkType::Generic, i, 0, nullptr));
}

/// Construct a union from already-built generics and pending slots
/// (used by generalize() and instantiate() below).
UnionType::UnionType(Cache *cache, const std::vector &generics,
                     const std::vector &pendingTypes)
    : ClassType(cache, "Union", generics), pendingTypes(pendingTypes) {
  isTuple = true;
}

/// Unify this union with another type.
/// @return -1 on failure; 0 if both unions are sealed but one cannot be realized
///         yet; otherwise 2 plus the sum of the member unification scores
///         (or whatever the delegated ClassType/LinkType unify returns).
int UnionType::unify(Type *typ, Unification *us) {
  if (typ->getUnion()) {
    auto tr = typ->getUnion();
    if (!isSealed() && !tr->isSealed()) {
      // Both unions are still open: unify the pending slots pairwise.
      for (size_t i = 0; i < pendingTypes.size(); i++)
        if (pendingTypes[i]->unify(tr->pendingTypes[i].get(), us) == -1)
          return -1;
      return ClassType::unify(typ, us);
    } else if (!isSealed()) {
      // Only the other side is sealed: let it drive (handled by the branch below).
      return tr->unify(this, us);
    } else if (!tr->isSealed()) {
      // This union is sealed, the other is open: only accepted when the other's
      // first pending slot is still an unbound link.
      if (tr->pendingTypes[0]->getLink() &&
          tr->pendingTypes[0]->getLink()->kind == LinkType::Unbound)
        return ClassType::unify(tr, us);
      return -1;
    }
    // Do not hard-unify if we have unbounds
    if (!canRealize() || !tr->canRealize())
      return 0;
    // Both sealed and realizable: compare the member lists element by element.
    auto u1 = getRealizationTypes();
    auto u2 = tr->getRealizationTypes();
    if (u1.size() != u2.size())
      return -1;
    int s1 = 2, s = 0;
    for (size_t i = 0; i < u1.size(); i++) {
      if ((s = u1[i]->unify(u2[i], us)) == -1)
        return -1;
      s1 += s;
    }
    return s1;
  } else if (auto tl = typ->getLink()) {
    // Defer to the link type's own unification.
    return tl->unify(this, us);
  }
  return -1;
}

/// Generalize the class part and every pending slot, and return a new UnionType
/// that carries the same source info.
TypePtr UnionType::generalize(int atLevel) const {
  auto r = ClassType::generalize(atLevel);
  auto p = pendingTypes;
  for (auto &t : p)
    t = t->generalize(atLevel);
  auto t = std::make_shared(cache, r->getClass()->generics, p);
  t->setSrcInfo(getSrcInfo());
  return t;
}

/// Instantiate the class part and every pending slot with the shared lookup table,
/// and return a new UnionType that carries the same source info.
/// Note: the `cache` parameter (instantiation lookup table) shadows the member
/// `cache`, hence the explicit `this->cache` below.
TypePtr UnionType::instantiate(int atLevel, int *unboundCount,
                               std::unordered_map *cache) const {
  auto r = ClassType::instantiate(atLevel, unboundCount, cache);
  auto p = pendingTypes;
  for (auto &t : p)
    t = t->instantiate(atLevel, unboundCount, cache);
  auto t = std::make_shared(this->cache, r->getClass()->generics, p);
  t->setSrcInfo(getSrcInfo());
  return t;
}

/// Print the union. Mode 2 and unsealed unions fall back to ClassType's printer;
/// otherwise the member strings are collected in a std::set and joined with " | ".
std::string UnionType::debugString(char mode) const {
  if (mode == 2)
    return this->ClassType::debugString(mode);
  if (!generics[0].type->getClass())
    return this->ClassType::debugString(mode);
  std::set gss;
  for (auto &a : generics[0].type->getClass()->generics)
    gss.insert(a.debugString(mode));
  std::string s = join(gss, " | ");
  return (name + (s.empty() ? "" : ("[" + s + "]")));
}

/// A union is realizable only once it is sealed and its class part is realizable.
bool UnionType::canRealize() const { return isSealed() && ClassType::canRealize(); }

/// Realization name; asserts that the union can be realized.
std::string UnionType::realizedName() const {
  seqassert(canRealize(), "cannot realize {}", debugString(2));
  return ClassType::realizedName();
}

/// Add a member type to an unsealed union (asserts if already sealed).
/// A union argument is flattened: its members are added one by one.
/// @return false if no unbound pending slot is left for the type.
bool UnionType::addType(Type *typ) {
  seqassert(!isSealed(), "union already sealed");
  if (this == typ)
    return true;
  if (auto tu = typ->getUnion()) {
    if (tu->isSealed()) {
      // Sealed union: its members live in the generics of generics[0].
      for (auto &t : tu->generics[0].type->getClass()->generics)
        if (!addType(t.type.get()))
          return false;
    } else {
      // Open union: take its pending slots up to the first unbound one.
      for (auto &t : tu->pendingTypes) {
        if (t->getLink() && t->getLink()->kind == LinkType::Unbound)
          break;
        else if (!addType(t.get()))
          return false;
      }
    }
    return true;
  } else {
    // Find first pending generic to which we can attach this!
    Unification us;
    for (auto &t : pendingTypes)
      if (auto l = t->getLink()) {
        if (l->kind == LinkType::Unbound) {
          // The result of unify() is not checked here; true is returned regardless.
          t->unify(typ, &us);
          return true;
        }
      }
    return false;
  }
}

/// A union is sealed once its first generic resolves to a class type.
bool UnionType::isSealed() const { return generics[0].type->getClass() != nullptr; }

/// Seal the union: take the pending slots that precede the first unbound link,
/// instantiate a tuple of that size over them, and unify generics[0] with it.
void UnionType::seal() {
  seqassert(!isSealed(), "union already sealed");
  auto tv = TypecheckVisitor(cache->typeCtx);

  // i ends up as the index of the first unbound slot (or pendingTypes.size()).
  size_t i;
  for (i = 0; i < pendingTypes.size(); i++)
    if (pendingTypes[i]->getLink() &&
        pendingTypes[i]->getLink()->kind == LinkType::Unbound)
      break;
  std::vector typeSet;
  typeSet.reserve(i);
  for (size_t j = 0; j < i; j++)
    typeSet.emplace_back(pendingTypes[j].get());
  auto t = tv.instantiateType(tv.generateTuple(typeSet.size()), typeSet);
  Unification us;
  generics[0].type->unify(t.get(), &us);
}

/// Return the member types of a realizable union (asserts otherwise).
/// Members are keyed by realizedName() in a std::map, so entries with the same
/// realized name collapse into one and the result follows key order.
std::vector UnionType::getRealizationTypes() const {
  seqassert(canRealize(), "cannot realize {}", debugString(2));
  std::map unionTypes;
  for (auto &u : generics[0].type->getClass()->generics)
    unionTypes[u.type->realizedName()] = u.type.get();
  std::vector r;
  r.reserve(unionTypes.size());
  for (auto &t : unionTypes | std::views::values)
    r.emplace_back(t);
  return r;
}

} // namespace codon::ast::types

================================================
FILE: codon/parser/ast/types/union.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once #include #include #include #include #include "codon/parser/ast/types/class.h" namespace codon::ast::types { struct UnionType : public ClassType { static constexpr int MAX_UNION = 256; std::vector pendingTypes; explicit UnionType(Cache *cache); UnionType(Cache *, const std::vector &, const std::vector &); public: int unify(Type *typ, Unification *undo) override; TypePtr generalize(int atLevel) const override; TypePtr instantiate(int atLevel, int *unboundCount, std::unordered_map *cache) const override; public: bool canRealize() const override; std::string debugString(char mode) const override; std::string realizedName() const override; bool isSealed() const; UnionType *getUnion() override { return this; } bool addType(Type *); void seal(); std::vector getRealizationTypes() const; }; } // namespace codon::ast::types ================================================ FILE: codon/parser/ast/types.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include "codon/parser/ast/types/class.h" #include "codon/parser/ast/types/function.h" #include "codon/parser/ast/types/link.h" #include "codon/parser/ast/types/static.h" #include "codon/parser/ast/types/traits.h" #include "codon/parser/ast/types/type.h" #include "codon/parser/ast/types/union.h" ================================================ FILE: codon/parser/ast.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include "codon/cir/attribute.h" #include "codon/cir/base.h" #include "codon/parser/ast/attr.h" #include "codon/parser/ast/error.h" #include "codon/parser/ast/expr.h" #include "codon/parser/ast/node.h" #include "codon/parser/ast/stmt.h" #include "codon/parser/ast/types.h" ================================================ FILE: codon/parser/cache.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "cache.h"

#include #include #include #include
#include "codon/cir/pyextension.h"
#include "codon/cir/util/irtools.h"
#include "codon/parser/common.h"
#include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/translate/translate.h"
#include "codon/parser/visitors/typecheck/ctx.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

namespace codon::ast {

const std::string VAR_ARGV = getMangledVar("", "__argv__");
const std::string FN_OPTIONAL_UNWRAP =
    getMangledFunc("std.internal.types.optional", "unwrap");

/// Create a cache. If no filesystem is supplied, one is created from argv0.
/// Also allocates the node storage and the root type context.
Cache::Cache(std::string argv0, const std::shared_ptr &fs) : fs(fs) {
  if (!this->fs) {
    this->fs = std::make_shared(argv0);
  }
  // NOTE(review): allocated with `new`; no matching delete is visible in this file.
  this->_nodes = new std::vector>();
  typeCtx = std::make_shared(this, ".root");
}

/// Return "{sigil}_{prefix}_{N}" (or "{prefix}_{N}" when sigil is 0), where N is
/// the pre-incremented varCount.
std::string Cache::getTemporaryVar(const std::string &prefix, char sigil) {
  auto n = fmt::format("{}{}_{}", sigil ? fmt::format("{}_", sigil) : "", prefix,
                       ++varCount);
  return n;
}

/// Look up the non-canonical name of a canonical identifier; asserts if missing.
std::string Cache::rev(const std::string &s) {
  auto i = reverseIdentifierLookup.find(s);
  if (i != reverseIdentifierLookup.end())
    return i->second;
  seqassertn(false, "'{}' has no non-canonical name", s);
  return "";
}

/// Produce a SrcInfo in FILE_GENERATED built from the running counter
/// generatedSrcInfoCount, which is incremented as a side effect.
SrcInfo Cache::generateSrcInfo() {
  return {FILE_GENERATED, generatedSrcInfoCount, generatedSrcInfoCount++, 0};
}

/// Return the source text that `info` points to, or "" if the file is not a known
/// import or the (1-based) line/column falls outside the stored content.
std::string Cache::getContent(const SrcInfo &info) {
  auto i = imports.find(info.file);
  if (i == imports.end())
    return "";
  int line = info.line - 1;
  if (line < 0 || line >= i->second.content.size())
    return "";
  auto s = i->second.content[line];
  int col = info.col - 1;
  if (col < 0 || col >= s.size())
    return "";
  int len = info.len;
  return s.substr(col, len);
}

/// Look up the Class record by the type's name via `in(classes, name)`.
Cache::Class *Cache::getClass(const types::ClassType *type) {
  auto name = type->name;
  return in(classes, name);
}

/// Return the methods-table entry for `member` in the class of `typ`;
/// asserts if the class or the method is not found.
std::string Cache::getMethod(types::ClassType *typ, const std::string &member) {
  if (auto cls = getClass(typ)) {
    if (auto t = in(cls->methods, member))
      return *t;
  }
  seqassertn(false, "cannot find '{}' in '{}'", member, typ->name);
  return "";
}

/// Find a class by name in the type context; returns nullptr unless the name
/// resolves to a type. The result is the class of the found type's first generic.
types::ClassType *Cache::findClass(const std::string &name) const {
  auto f = typeCtx->find(name);
  if (f && f->isType())
    return f->getType()->getClass()->generics[0].getType()->getClass();
  return nullptr;
}

/// Find a function by name in the type context, retrying with the ":0" suffix;
/// returns nullptr if neither lookup yields a function.
types::FuncType *Cache::findFunction(const std::string &name) const {
  auto f = typeCtx->find(name);
  if (f && f->type && f->isFunc())
    return f->type->getFunc();
  f = typeCtx->find(name + ":0");
  if (f && f->type && f->isFunc())
    return f->type->getFunc();
  return nullptr;
}

/// Delegate to TypecheckVisitor::findBestMethod.
types::FuncType *Cache::findMethod(types::ClassType *typ, const std::string &member,
                                   const std::vector &args) const {
  auto f = TypecheckVisitor(typeCtx).findBestMethod(typ, member, args);
  return f;
}

/// Instantiate `type` with `generics`, realize it and return the IR type stored in
/// the matching class realization; nullptr if realization fails.
ir::types::Type *Cache::realizeType(types::ClassType *type,
                                    const std::vector &generics) {
  auto tv = TypecheckVisitor(typeCtx);
  if (auto rtv = tv.realize(tv.instantiateType(type, castVectorPtr(generics)))) {
    return classes[rtv->getClass()->name]
        .realizations[rtv->getClass()->realizedName()]
        ->ir;
  }
  return nullptr;
}

/// Instantiate and realize a function. args[0] is unified with the return type and
/// args[1..] with the parameters; `generics`, if non-empty, must match the number
/// of function generics. Returns nullptr on any size mismatch or failed unification.
ir::Func *Cache::realizeFunction(types::FuncType *type, const std::vector &args,
                                 const std::vector &generics,
                                 types::ClassType *parentClass) {
  auto tv = TypecheckVisitor(typeCtx);
  auto t = tv.instantiateType(type, parentClass);
  if (args.size() != t->size() + 1)
    return nullptr;
  types::Type::Unification undo;
  if (t->getRetType()->unify(args[0].get(), &undo) < 0) {
    undo.undo();
    return nullptr;
  }
  // NOTE(review): `undo` is reset before every step below, so on failure only the
  // failing step is rolled back; earlier successful unifications are kept.
  // Confirm that this is intended.
  for (int gi = 1; gi < args.size(); gi++) {
    undo = types::Type::Unification();
    if ((*t)[gi - 1]->unify(args[gi].get(), &undo) < 0) {
      undo.undo();
      return nullptr;
    }
  }
  if (!generics.empty()) {
    if (generics.size() != t->funcGenerics.size())
      return nullptr;
    for (int gi = 0; gi < generics.size(); gi++) {
      undo = types::Type::Unification();
      if (t->funcGenerics[gi].type->unify(generics[gi].get(), &undo) < 0) {
        undo.undo();
        return nullptr;
      }
    }
  }
  ir::Func *f = nullptr;
  if (auto rtv = tv.realize(t.get())) {
    // Translate everything that is pending before fetching the IR function.
    auto pr = pendingRealizations; // copy it as it might be modified
    for (const auto &key : pr | std::views::keys)
      TranslateVisitor(codegenCtx).translateStmts(clone(functions[key].ast));
    f = functions[rtv->getFunc()->ast->getName()].realizations[rtv->realizedName()]->ir;
  }
  return f;
}

/// Realize a tuple type with the given element types.
ir::types::Type *Cache::makeTuple(const std::vector &types) {
  auto tv = TypecheckVisitor(typeCtx);
  auto t = tv.instantiateType(tv.generateTuple(types.size()), castVectorPtr(types));
  return realizeType(t->getClass(), types);
}

/// Realize a "Function" type. types[0] is the return type; the remaining entries
/// are packed into an argument tuple. Asserts that `types` is not empty.
ir::types::Type *Cache::makeFunction(const std::vector &types) {
  auto tv = TypecheckVisitor(typeCtx);
  seqassertn(!types.empty(), "types must have at least one argument");
  std::vector tt;
  for (size_t i = 1; i < types.size(); i++)
    tt.emplace_back(types[i].get());
  const auto &ret = types[0];
  auto argType = tv.instantiateType(tv.generateTuple(types.size() - 1), tt);
  auto ft = realizeType(tv.getStdLibType("Function")->getClass(), {argType, ret});
  return ft;
}

/// Realize a "Union" type whose single generic is a tuple of the given types.
ir::types::Type *Cache::makeUnion(const std::vector &types) {
  auto tv = TypecheckVisitor(typeCtx);
  auto argType =
      tv.instantiateType(tv.generateTuple(types.size()), castVectorPtr(types));
  return realizeType(tv.getStdLibType("Union")->getClass(), {argType});
}

/// Return the ID of the class realization of `type`.
size_t Cache::getRealizationId(types::ClassType *type) {
  auto cv = TypecheckVisitor(typeCtx).getClassRealization(type);
  return cv->id;
}

/// Return the realization IDs of all bases recorded for the realization of `type`.
std::vector Cache::getBaseRealizationIds(types::ClassType *type) {
  auto r = TypecheckVisitor(typeCtx).getClassRealization(type);
  std::vector baseIds;
  for (const auto &t : r->bases) {
    baseIds.push_back(getRealizationId(t.get()));
  }
  return baseIds;
}

/// Return the IDs of every class realization that lists `type` among its bases.
/// Scans all realizations of all classes.
std::vector Cache::getChildRealizationIds(types::ClassType *type) {
  auto cv = TypecheckVisitor(typeCtx).getClassRealization(type);
  auto parentId = cv->id;
  std::vector childIds;
  for (const auto &[_, c] : classes) {
    for (const auto &[_, r] : c.realizations) {
      for (const auto &t : r->bases) {
        if (getRealizationId(t.get()) == parentId) {
          childIds.push_back(r->id);
          break;
        }
      }
    }
  }
  return childIds;
}

/// Parse, typecheck and translate a code string in the main import's context.
/// The codegen series is cleared for the translation and restored afterwards; the
/// series produced by the translation is returned.
/// Throws exc::ParserException if parsing fails.
std::vector Cache::parseCode(const std::string &code) {
  auto nodeOrErr = ast::parseCode(this, "", code, /*startLine=*/0);
  if (!nodeOrErr)
    throw exc::ParserException(nodeOrErr.takeError());
  auto sctx = imports[MAIN_IMPORT].ctx;
  auto node = ast::TypecheckVisitor::apply(sctx, *nodeOrErr);
  auto old = codegenCtx->series;
  codegenCtx->series.clear();
  ast::TranslateVisitor(codegenCtx).initializeGlobals();
  ast::TranslateVisitor(codegenCtx).translateStmts(node);
  // After the swap `old` holds the new series and the context holds the saved one.
  std::swap(old, codegenCtx->series);
  return old;
}

/// Merge the given sequences into a single linearization (see the reference below).
/// `seqs` is consumed in the process. Returns an empty vector when no candidate
/// head can be selected.
std::vector> Cache::mergeC3(std::vector> &seqs) {
  // Reference: https://www.python.org/download/releases/2.3/mro/
  std::vector> result;
  for (size_t i = 0;; i++) {
    bool found = false;
    std::shared_ptr cand = nullptr;
    for (auto &seq : seqs) {
      if (seq.empty())
        continue;
      found = true;
      // A head is rejected if it appears in the tail of any sequence.
      bool nothead = false;
      for (auto &s : seqs)
        if (!s.empty()) {
          bool in = false;
          for (size_t j = 1; j < s.size(); j++) {
            if ((in |= (seq[0]->is(s[j]->getClass()->name))))
              break;
          }
          if (in) {
            nothead = true;
            break;
          }
        }
      if (!nothead) {
        cand = std::dynamic_pointer_cast(seq[0]);
        break;
      }
    }
    if (!found)
      return result; // every sequence is exhausted
    if (!cand)
      return {};
    result.push_back(cand);
    // Drop the chosen candidate from the head of every sequence.
    for (auto &s : seqs)
      if (!s.empty() && cand->is(s[0]->getClass()->name)) {
        s.erase(s.begin());
      }
  }
  return result;
}

/**
 * Generate Python bindings for Cython-like access.
 * Does nothing unless pythonExt is set. Fills pyModule (created on demand) with the
 * cythonized classes, iterator wrappers and functions, then translates all pending
 * realizations.
 */
void Cache::populatePythonModule() {
  using namespace ast;
  const std::string CYTHON_ITER = "_PyWrap.IterWrap";

  if (!pythonExt)
    return;
  if (!pyModule)
    pyModule = std::make_shared();

  LOG_USER("[py] ====== module generation =======");
  auto tv = TypecheckVisitor(typeCtx);
  auto clss = classes; // needs copy as below fns can mutate this
  for (const auto &cn : clss | std::views::keys) {
    auto py = tv.cythonizeClass(cn);
    if (!py.name.empty())
      pyModule->types.push_back(py);
  }

  // Handle __iternext__ wrappers
  for (const auto &cn : classes[CYTHON_ITER].realizations | std::views::keys) {
    auto py = tv.cythonizeIterator(cn);
    pyModule->types.push_back(py);
  }

  auto fns = functions; // needs copy as below fns can mutate this
  for (const auto &fn : fns | std::views::keys) {
    auto py = tv.cythonizeFunction(fn);
    if (!py.name.empty())
      pyModule->functions.push_back(py);
  }

  // Handle pending realizations!
  auto pr = pendingRealizations; // copy it as it might be modified
  for (const auto &key : pr | std::views::keys)
    TranslateVisitor(codegenCtx).translateStmts(clone(functions[key].ast));
}

} // namespace codon::ast

================================================
FILE: codon/parser/cache.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once

#include #include #include #include #include
#include "codon/cir/cir.h"
#include "codon/cir/pyextension.h"
#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/ctx.h"

#define FILE_GENERATED ""
#define MODULE_MAIN "__main__"
#define STDLIB_INTERNAL_MODULE "internal"
#define MAIN_IMPORT ""
#define STDLIB_IMPORT ":stdlib:"
#define TYPE_CALLABLE "Callable"
#define TYPE_FUNCTION "Function"
#define TYPE_OPTIONAL "Optional"
#define TYPE_TUPLE "Tuple"
#define TYPE_TYPE "type"
#define TRAIT_TYPE "TypeTrait"
#define TRAIT_CALLABLE "CallableTrait"
#define FN_DISPATCH_SUFFIX ":dispatch"
#define FN_SETTER_SUFFIX ":set_"
#define VAR_CLASS_TOPLEVEL ":toplevel"
#define VAR_USED_SUFFIX ":used"
#define MAX_ERRORS 5
#define MAX_TUPLE 2048
#define MAX_INT_WIDTH 10000
#define MAX_REALIZATION_DEPTH 200
#define MAX_STATIC_ITER 1024

namespace codon {
class Compiler;
}

namespace codon::ast {

extern const std::string VAR_ARGV;
extern const std::string FN_OPTIONAL_UNWRAP;

/// Forward declarations
struct TypeContext;
struct TranslateContext;

/**
 * Cache encapsulation that holds data structures shared across various transformation
 * stages (AST transformation, type checking etc.). The subsequent stages (e.g. type
 * checking) assume that previous stages populated this structure correctly.
 * Implemented to avoid a bunch of global objects.
 */
struct Cache {
  /// Filesystem object used for accessing files.
  std::shared_ptr fs;

  /// Stores a count for each identifier (name) seen in the code.
  /// Used to generate unique identifier for each name in the code (e.g. Foo -> Foo.2).
  std::unordered_map identifierCount;
  /// Maps a unique identifier back to the original name in the code
  /// (e.g. Foo.2 -> Foo).
  std::unordered_map reverseIdentifierLookup;
  /// Number of code-generated source code positions. Used to generate the next unique
  /// source-code position information.
  int generatedSrcInfoCount = 0;
  /// Number of unbound variables so far. Used to generate the next unique unbound
  /// identifier.
  int unboundCount = 256;
  /// Number of auto-generated variables so far. Used to generate the next unique
  /// variable name in getTemporaryVar() below.
  int varCount = 0;
  /// Scope counter. Each conditional block gets a new scope ID.
  int blockCount = 1;

  /// Holds module import data.
  struct Module {
    /// Relative module name (e.g., `foo.bar`)
    std::string name;
    /// Absolute filename of an import.
    std::string filename;
    /// Import typechecking context.
    std::shared_ptr ctx;
    /// Unique import variable for checking already loaded imports.
    std::string importVar;
    /// File content (line:col indexable)
    std::vector content;
    /// Set if loaded at toplevel
    bool loadedAtToplevel = true;

    /// Overwrite the name, filename and context of this module.
    void update(const std::string &name, const std::string &filename,
                const std::shared_ptr &ctx) {
      this->name = name;
      this->filename = filename;
      this->ctx = ctx;
    }
  };

  /// Compiler
  Compiler *compiler = nullptr;
  /// IR module.
  ir::Module *module = nullptr;

  /// Table of imported files that maps an absolute filename to an Import structure.
  /// By convention, the key of the Codon's standard library is ":stdlib:",
  /// and the main module is "".
  std::unordered_map imports;

  /// Set of unique (canonical) global identifiers for marking such variables as global
  /// in code-generation step and in JIT.
  std::map globals;

  /// Stores class data for each class (type) in the source code.
  struct Class {
    /// Module information
    std::string module;

    /// Generic (unrealized) class template AST.
    ClassStmt *ast = nullptr;
    /// Non-simplified AST. Used for base class instantiation.
    ClassStmt *originalAst = nullptr;

    /// Class method lookup table. Each non-canonical name points
    /// to a root function name of a corresponding method.
    std::unordered_map methods;

    /// A class field (member).
    struct ClassField {
      /// Field name.
      std::string name;
      /// A corresponding generic field type.
      types::TypePtr type;
      /// Base class name (if available)
      std::string baseClass;
      /// Type expression of the field (nullptr by default).
      Expr *typeExpr;

      ClassField(const std::string &name, const types::TypePtr &type,
                 const std::string &baseClass, Expr *typeExpr = nullptr)
          : name(name), type(type), baseClass(baseClass), typeExpr(typeExpr) {}
      types::Type *getType() const { return type.get(); }
    };
    /// A list of class' ClassField instances. List is needed (instead of map) because
    /// the order of the fields matters.
    std::vector fields;

    /// Dictionary of class variables: a name maps to a canonical name.
    std::unordered_map classVars;

    /// A class realization.
    struct ClassRealization {
      /// Realized class type.
      std::shared_ptr type;
      /// A list of field names and realization's realized field types.
      std::vector> fields;
      /// IR type pointer.
      codon::ir::types::Type *ir = nullptr;

      // Bases (in MRO order)
      std::vector> bases;

      /// Realization vtable (for each base class).
      /// Maps {base, function signature} to {thunk realization, thunk ID}.
      /// Base can be the realization itself.
      /// Order is important so map is used instead of unordered_map.
      std::map, std::shared_ptr> vtable;

      /// Realization ID
      size_t id = 0;

      types::ClassType *getType() const { return type.get(); }
    };
    /// Realization lookup table that maps a realized class name to the corresponding
    /// ClassRealization instance.
    std::unordered_map> realizations;
    /// Set if a class is polymorphic and has RTTI.
    bool rtti = false;
    /// List of virtual method names
    std::unordered_set virtuals;
    /// MRO
    std::vector> mro;

    /// List of statically inherited classes.
    std::vector staticParentClasses;

    int jitCell = 0;

    bool hasRTTI() const { return rtti; }
  };
  /// Class lookup table that maps a canonical class identifier to the corresponding
  /// Class instance.
  std::unordered_map classes;
  size_t classRealizationCnt = 0;

  /// Look up the Class record for a type by the type's name.
  Class *getClass(const types::ClassType *);

  std::map, size_t> thunkIds;

  /// Stores function data for each function in the source code.
  struct Function {
    /// Module information
    std::string module;
    std::string rootName;
    /// Generic (unrealized) function template AST.
    FunctionStmt *ast;
    /// Unrealized function type.
    std::shared_ptr type;

    /// Non-simplified AST.
    FunctionStmt *origAst = nullptr;
    bool isToplevel = false;

    /// A function realization.
    struct FunctionRealization {
      /// Realized function type.
      std::shared_ptr type;
      /// Realized function AST (stored here for later realization in code generations
      /// stage).
      FunctionStmt *ast;
      /// IR function pointer.
      ir::Func *ir;
      /// Resolved captures
      std::vector captures;

      types::FuncType *getType() const { return type.get(); }
    };
    /// Realization lookup table that maps a realized function name to the corresponding
    /// FunctionRealization instance.
    std::unordered_map> realizations = {};
    std::set captures = {};

    types::FuncType *getType() const { return type.get(); }
  };
  /// Function lookup table that maps a canonical function identifier to the
  /// corresponding Function instance.
  std::unordered_map functions;

  /// Maps a "root" name of each function to the list of names of the function
  /// overloads (canonical names).
  std::unordered_map> overloads;

  /// Pointer to the later contexts needed for IR API access.
  std::shared_ptr typeCtx = nullptr;
  std::shared_ptr codegenCtx = nullptr;
  /// Set of function realizations that are to be translated to IR.
  std::set> pendingRealizations;

  /// Custom operators
  std::unordered_map>> customBlockStmts;
  std::unordered_map> customExprStmts;

  /// Set if the Codon is running in JIT mode.
  bool isJit = false;
  int jitCell = 0;

  std::unordered_set generatedTuples;
  std::unordered_map generatedKwTuples;
  std::vector> generatedTupleNames = {{}};
  ParserErrors errors;

  /// Set if Codon operates in Python compatibility mode (e.g., with Python numerics)
  bool pythonCompat = false;
  /// Set if Codon operates in Python extension mode
  bool pythonExt = false;

public:
  explicit Cache(std::string argv0 = "", const std::shared_ptr &fs = nullptr);

  /// Return a uniquely named temporary variable of a format
  /// "{sigil}_{prefix}_{counter}" (the "{sigil}_" part is omitted when sigil is 0).
  /// A sigil should be a non-lexable symbol.
  std::string getTemporaryVar(const std::string &prefix = "", char sigil = '%');
  /// Get the non-canonical version of a canonical name.
  std::string rev(const std::string &s);

  /// Generate a unique SrcInfo for internally generated AST nodes.
  SrcInfo generateSrcInfo();
  /// Get file contents at the given location.
  std::string getContent(const SrcInfo &info);

  /// Realization API.

  /// Find a class with a given canonical name and return a matching types::Type pointer
  /// or a nullptr if a class is not found.
  /// Returns an _uninstantiated_ type.
  types::ClassType *findClass(const std::string &name) const;
  /// Find a function with a given canonical name and return a matching types::Type
  /// pointer or a nullptr if a function is not found.
  /// Returns an _uninstantiated_ type.
  types::FuncType *findFunction(const std::string &name) const;
  /// Find the canonical name of a class method.
  std::string getMethod(types::ClassType *typ, const std::string &member);
  /// Find the class method in a given class type that best matches the given arguments.
  /// Returns an _uninstantiated_ type.
  types::FuncType *findMethod(types::ClassType *typ, const std::string &member,
                              const std::vector &args) const;

  /// Given a class type and the matching generic vector, instantiate the type and
  /// realize it.
  ir::types::Type *realizeType(types::ClassType *type,
                               const std::vector &generics = {});
  /// Given a function type and function arguments, instantiate the type and
  /// realize it. The first argument is the function return type.
  /// You can also pass function generics if a function has one (e.g. T in def
  /// foo[T](...)). If a generic is used as an argument, it will be auto-deduced. Pass
  /// only if a generic cannot be deduced from the provided args.
  ir::Func *realizeFunction(types::FuncType *type, const std::vector &args,
                            const std::vector &generics = {},
                            types::ClassType *parentClass = nullptr);

  /// Realize a tuple type over the given element types.
  ir::types::Type *makeTuple(const std::vector &types);
  /// Realize a function type. The first entry is the return type.
  ir::types::Type *makeFunction(const std::vector &types);
  /// Realize a union type over the given member types.
  ir::types::Type *makeUnion(const std::vector &types);

  /// Realization ID of a class type.
  size_t getRealizationId(types::ClassType *type);
  /// Realization IDs of the bases of a class type.
  std::vector getBaseRealizationIds(types::ClassType *type);
  /// Realization IDs of the realizations that have the given class type as a base.
  std::vector getChildRealizationIds(types::ClassType *type);

  /// Parse, typecheck and translate a code string.
  std::vector parseCode(const std::string &code);

  /// Merge the given sequences into a single linearization. Modifies its argument.
  static std::vector> mergeC3(std::vector> &);

  std::shared_ptr pyModule = nullptr;
  void populatePythonModule();

private:
  /// Storage for every AST node created through N() and NS().
  std::vector> *_nodes;

public:
  /// Convenience method that constructs a node, stores it in the cache-owned node
  /// list and sets the node's cache pointer.
  template Tn *N(Ts &&...args) {
    _nodes->emplace_back(std::make_unique(std::forward(args)...));
    Tn *t = static_cast(_nodes->back().get());
    t->cache = this;
    return t;
  }
  /// Same as N(), but also copies the source location of `srcInfo` to the new node.
  template Tn *NS(const ASTNode *srcInfo, Ts &&...args) {
    _nodes->emplace_back(std::make_unique(std::forward(args)...));
    Tn *t = static_cast(_nodes->back().get());
    t->cache = this;
    t->setSrcInfo(srcInfo->getSrcInfo());
    return t;
  }

  /// Accumulated time per timer name (see CTimer).
  std::unordered_map _timings;

  /// Scoped timer: on destruction adds the elapsed time to c->_timings[name].
  struct CTimer {
    Cache *c;
    Timer t;
    std::string name;
    CTimer(Cache *c, std::string n) : c(c), t(Timer("")), name(std::move(n)) {}
    double elapsed() const { return t.elapsed(); }
    ~CTimer() {
      c->_timings[name] += t.elapsed();
      t.logged = true;
    }
  };

  /// Convert a vector of smart pointers to a vector of the raw pointers they hold.
  /// The input vector is taken by value.
  template std::vector castVectorPtr(std::vector> v) {
    std::vector r;
    r.reserve(v.size());
    for (const auto &i : v)
      r.emplace_back(i.get());
    return r;
  }
};

} // namespace codon::ast

================================================
FILE: codon/parser/common.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include "common.h"

#include #include #include #include #include
#include "codon/compiler/compiler.h"
#include "codon/parser/cache.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
#include #include

CMRC_DECLARE(codon);

namespace codon::ast {

/// Canonicalize a path via std::filesystem::weakly_canonical.
IFilesystem::path_t IFilesystem::canonical(const path_t &path) const {
  return std::filesystem::weakly_canonical(path);
}

/// Add a search path (canonicalized) if it exists; silently ignored otherwise.
void IFilesystem::add_search_path(const std::string &p) {
  auto path = path_t(p);
  if (exists(path)) {
    search_paths.emplace_back(canonical(path));
  }
}

/// Return the registered search paths.
std::vector IFilesystem::get_stdlib_paths() const { return search_paths; }

/// Work out which root (a search path or the directory of module0) a file lives
/// under and derive its dotted module name. Asserts if the path does not end in
/// ".codon" or ".py" or does not start with the selected root.
ImportFile IFilesystem::get_root(const path_t &sp) const {
  bool isStdLib = false;
  std::string s = sp;
  std::string root;
  for (auto &p : get_stdlib_paths())
    if (startswith(s, p)) {
      root = p;
      isStdLib = true;
      break;
    }
  auto module0 = get_module0().parent_path();
  if (!isStdLib && !module0.empty() && startswith(s, module0))
    root = module0;
  std::string ext = ".codon";
  if (!((root.empty() || startswith(s, root)) && endswith(s, ext)))
    ext = ".py";
  seqassertn((root.empty() || startswith(s, root)) && endswith(s, ext),
             "bad path substitution: {}, {}", s, root);
  // Strip "<root>/" from the front and the extension from the back.
  auto module = s.substr(root.size() + 1, s.size() - root.size() - ext.size() - 1);
  std::ranges::replace(module, '/', '.');
  return ImportFile{(!isStdLib && root == module0) ? ImportFile::PACKAGE
                                                   : ImportFile::STDLIB,
                    s, module};
}

/// Set up search paths from the CODON_PATH environment variable and from the
/// stdlib/plugin directories located relative to the executable.
Filesystem::Filesystem(const std::string &argv0, const std::string &module0)
    : argv0(argv0), module0(module0) {
  if (auto p = getenv("CODON_PATH")) {
    add_search_path(p);
  }
  if (!argv0.empty()) {
    auto root = executable_path(argv0.c_str()).parent_path();
    for (auto loci : {"../lib/codon/stdlib", "../stdlib", "stdlib"}) {
      add_search_path(root / loci);
    }
    for (auto loci : {"../lib/codon/plugins", "../plugins"}) {
      add_search_path(root / loci);
    }
  }
}

/// Read all lines of a file; the path "-" reads from standard input.
/// Reports COMPILER_NO_FILE if the file cannot be opened.
std::vector Filesystem::read_lines(const path_t &path) const {
  std::vector lines;
  if (path == "-") {
    for (std::string line; getline(std::cin, line);) {
      lines.push_back(line);
    }
  } else {
    std::ifstream fin(path);
    if (!fin)
      E(error::Error::COMPILER_NO_FILE, SrcInfo(), path);
    for (std::string line; getline(fin, line);) {
      lines.push_back(line);
    }
    fin.close();
  }
  return lines;
}

void Filesystem::set_module0(const std::string &s) { module0 = canonical(path_t(s)); }

IFilesystem::path_t Filesystem::get_module0() const {
  return module0.empty() ? IFilesystem::path_t() : canonical(module0);
}

/// Resolve the path of the running executable through LLVM.
IFilesystem::path_t Filesystem::executable_path(const char *argv0) {
  auto exc = llvm::sys::fs::getMainExecutable(
      argv0, reinterpret_cast(executable_path));
  return path_t(exc);
}

bool Filesystem::exists(const IFilesystem::path_t &path) const {
  return std::filesystem::exists(path);
}

/// Filesystem backed by the embedded (cmrc) resources; its only search path is
/// "/stdlib". Real files are consulted only when allowExternal is set.
ResourceFilesystem::ResourceFilesystem(const std::string &argv0,
                                       const std::string &module0, bool allowExternal)
    : Filesystem(argv0, module0), allowExternal(allowExternal) {
  search_paths = {"/stdlib"};
}

/// Read lines from the embedded resources, falling back to the real filesystem
/// when the resource is missing and allowExternal is set.
std::vector ResourceFilesystem::read_lines(const path_t &path) const {
  auto fs = cmrc::codon::get_filesystem();

  if (!fs.exists(path) && allowExternal)
    return Filesystem::read_lines(path);

  std::vector lines;
  if (path == "-") {
    E(error::Error::COMPILER_NO_FILE, SrcInfo(), "");
  } else {
    try {
      auto fd = fs.open(path);
      auto contents = std::string(fd.begin(), fd.end());
      lines = split(contents, '\n');
    } catch (std::system_error &) {
      E(error::Error::COMPILER_NO_FILE, SrcInfo(), path);
    }
  }
  return lines;
}

bool ResourceFilesystem::exists(const path_t &path) const {
  auto fs = cmrc::codon::get_filesystem();
  if (fs.exists(path))
    return true;
  if (allowExternal)
    return Filesystem::exists(path);
  return false;
}

/// String and collection utilities

/// Split a string on a delimiter using std::getline.
std::vector split(const std::string &s, char delim) {
  std::vector items;
  std::string item;
  std::istringstream iss(s);
  while (std::getline(iss, item, delim))
    items.push_back(item);
  return items;
}

// clang-format off
/// Escape control characters, the single quote and the backslash. Any other byte
/// below 32 or at/above 127 is written as "\xHH" with exactly two hex digits, so
/// that unescape() below (which reads up to two hex digits) cannot absorb a
/// following hex-looking character into the escape.
std::string escape(const std::string &str) {
  std::string r;
  r.reserve(str.size());
  for (unsigned char c : str) {
    switch (c) {
    case '\a': r += "\\a"; break;
    case '\b': r += "\\b"; break;
    case '\f': r += "\\f"; break;
    case '\n': r += "\\n"; break;
    case '\r': r += "\\r"; break;
    case '\t': r += "\\t"; break;
    case '\v': r += "\\v"; break;
    case '\'': r += "\\'"; break;
    case '\\': r += "\\\\"; break;
    default:
      if (c < 32 || c >= 127)
        r += fmt::format("\\x{:02x}", c);
      else
        r += c;
    }
  }
  return r;
}
/// Reverse of escape(): decodes the named escapes, "\xH"/"\xHH" and octal escapes
/// of up to three digits. An unknown escape keeps its backslash.
/// Throws std::invalid_argument for a "\x" with nothing after it; std::stoi throws
/// the same when the following characters are not hex digits.
std::string unescape(const std::string &str) {
  std::string r;
  r.reserve(str.size());
  for (int i = 0; i < str.size(); i++) {
    if (str[i] == '\\' && i + 1 < str.size())
      switch(str[i + 1]) {
        case 'a': r += '\a'; i++; break;
        case 'b': r += '\b'; i++; break;
        case 'f': r += '\f'; i++; break;
        case 'n': r += '\n'; i++; break;
        case 'r': r += '\r'; i++; break;
        case 't': r += '\t'; i++; break;
        case 'v': r += '\v'; i++; break;
        case '"': r += '\"'; i++; break;
        case '\'': r += '\''; i++; break;
        case '\\': r += '\\'; i++; break;
        case 'x': {
          if (i + 3 > str.size())
            throw std::invalid_argument("invalid \\x code");
          size_t pos = 0;
          auto code = std::stoi(str.substr(i + 2, 2), &pos, 16);
          r += static_cast(code);
          i += pos + 1; // skip the 'x' and the digits that were consumed
          break;
        }
        default:
          if (str[i + 1] >= '0' && str[i + 1] <= '7') {
            size_t pos = 0;
            auto code = std::stoi(str.substr(i + 1, 3), &pos, 8);
            r += static_cast(code);
            i += pos;
          } else {
            r += str[i];
          }
      }
    else
      r += str[i];
  }
  return r;
}
// clang-format on

/// Copy str[start, start + len) with every '{' and '}' doubled.
std::string escapeFStringBraces(const std::string &str, int start, int len) {
  std::string t;
  t.reserve(len);
  for (int i = start; i < start + len; i++)
    if (str[i] == '{')
      t += "{{";
    else if (str[i] == '}')
      t += "}}";
    else
      t += str[i];
  return t;
}

/// Skip leading whitespace: return the index just past a '(' if that is the first
/// non-space character, otherwise the index of the first non-space character
/// (or s.size() if there is none).
int findStar(const std::string &s) {
  int i = 0;
  for (; i < s.size(); i++) {
    if (s[i] == '(')
      return i + 1;
    if (!isspace(s[i]))
      return i;
  }
  return i;
}

/// True if the string contains the character.
bool in(const std::string &m, char item) {
  auto f = m.find(item);
  return f != std::string::npos;
}

/// True if the string contains the substring.
bool in(const std::string &m, const std::string &item) {
  auto f = m.find(item);
  return f != std::string::npos;
}

/// Return the prefix length if str starts with prefix, 0 otherwise.
/// An empty prefix yields 1 (`true`) so that the result still tests as true.
size_t startswith(const std::string &str, const std::string &prefix) {
  if (prefix.empty())
    return true;
  return (str.size() >= prefix.size() && str.substr(0, prefix.size()) == prefix)
             ? prefix.size()
             : 0;
}

/// Return the suffix length if str ends with suffix, 0 otherwise.
/// An empty suffix yields 1 (`true`) so that the result still tests as true.
size_t endswith(const std::string &str, const std::string &suffix) {
  if (suffix.empty())
    return true;
  return (str.size() >= suffix.size() &&
          str.substr(str.size() - suffix.size()) == suffix)
             ? suffix.size()
             : 0;
}

/// Remove leading whitespace in place.
void ltrim(std::string &str) {
  str.erase(str.begin(), std::ranges::find_if(
                             str, [](unsigned char ch) { return !std::isspace(ch); }));
}

/// Remove trailing whitespace in place.
void rtrim(std::string &str) {
  /// https://stackoverflow.com/questions/216823/whats-the-best-way-to-trim-stdstring
  str.erase(std::find_if(str.rbegin(), str.rend(),
                         [](unsigned char ch) { return !std::isspace(ch); })
                .base(),
            str.end());
}

/// True if every character is a decimal digit (also true for an empty string).
/// Characters are converted to unsigned char first, as in ltrim()/rtrim(): passing
/// a negative char to the <cctype> functions is undefined behavior.
bool isdigit(const std::string &str) {
  return std::ranges::all_of(str,
                             [](unsigned char ch) { return std::isdigit(ch) != 0; });
}

// Adapted from https://github.com/gpakosz/whereami/blob/master/src/whereami.c (MIT)
#ifdef __APPLE__
#include #include
#endif
/// Return the resolved path of the binary image that contains the caller's return
/// address, or an empty string if it cannot be determined.
std::string library_path() {
  std::string result;
#ifdef __APPLE__
  char buffer[PATH_MAX];
  for (;;) {
    Dl_info info;
    if (dladdr(__builtin_extract_return_addr(__builtin_return_address(0)), &info)) {
      char *resolved = realpath(info.dli_fname, buffer);
      if (!resolved)
        break;
      result = std::string(resolved);
    }
    break;
  }
#else
  // Scan /proc/self/maps for the mapping that covers the return address;
  // up to five attempts are made.
  for (int r = 0; r < 5; r++) {
    FILE *maps = fopen("/proc/self/maps", "r");
    if (!maps)
      break;

    for (;;) {
      char buffer[PATH_MAX < 1024 ? 1024 : PATH_MAX];
      uint64_t low, high;
      char perms[5];
      uint64_t offset;
      uint32_t major, minor;
      char path[PATH_MAX];
      uint32_t inode;

      if (!fgets(buffer, sizeof(buffer), maps))
        break;

      if (sscanf(buffer, "%" PRIx64 "-%" PRIx64 " %s %" PRIx64 " %x:%x %u %s\n", &low,
                 &high, perms, &offset, &major, &minor, &inode, path) == 8) {
        uint64_t addr =
            (uintptr_t)(__builtin_extract_return_addr(__builtin_return_address(0)));
        if (low <= addr && addr <= high) {
          char *resolved = realpath(path, buffer);
          if (resolved)
            result = std::string(resolved);
          break;
        }
      }
    }
    fclose(maps);
    if (!result.empty())
      break;
  }
#endif

  return result;
}

/// Resolve a path with realpath(); returns the input unchanged if that fails.
std::string Filesystem::get_absolute_path(const std::string &path) {
  char *c = realpath(path.c_str(), nullptr);
  if (!c)
    return path;
  std::string result(c);
  free(c);
  return result;
}

std::shared_ptr getImportFile(Cache *cache, const std::string &what,
                              const std::string &relativeTo, bool forceStdlib) {
  auto fs = cache->fs.get();
  std::vector paths;

  auto parentRelativeTo = IFilesystem::path_t(relativeTo).parent_path();
  if (what != "") {
    if (!forceStdlib) {
      auto path = parentRelativeTo / what;
      path.replace_extension("codon");
      if (fs->exists(path))
        paths.emplace_back(fs->canonical(path));

      path = parentRelativeTo / what / "__init__.codon";
      if (fs->exists(path))
        paths.emplace_back(fs->canonical(path));

      path = parentRelativeTo / what;
      path.replace_extension("py");
      if (fs->exists(path))
        paths.emplace_back(fs->canonical(path));

      path = parentRelativeTo / what / "__init__.py";
      if (fs->exists(path))
        paths.emplace_back(fs->canonical(path));
    }
  }

  auto checkPlugin = [&paths, &fs, &cache](const std::filesystem::path &path,
                                           const std::string &what) {
    if (fs->exists(path / what / "plugin.toml") &&
        fs->exists(path / what / "stdlib" / what / "__init__.codon")) {
      bool failed = false;
      if (cache->compiler && !cache->compiler->isPluginLoaded(path / what)) {
        LOG_REALIZE("Loading plugin {}", path / what);
        llvm::handleAllErrors(cache->compiler->load(path / what), [&failed](const
codon::error::PluginErrorInfo &e) { codon::compilationError(e.getMessage(), /*file=*/"", /*line=*/0, /*col=*/0, /*len=*/0, /*errorCode=*/-1, /*terminate=*/false); failed = true; }); } if (!failed) paths.emplace_back( fs->canonical(path / what / "stdlib" / what / "__init__.codon")); } }; if (paths.empty()) { // Load a plugin maybe checkPlugin(parentRelativeTo, what); } for (auto &p : fs->get_stdlib_paths()) { auto path = p / what; path.replace_extension("codon"); if (fs->exists(path)) paths.emplace_back(fs->canonical(path)); path = p / what / "__init__.codon"; if (fs->exists(path)) paths.emplace_back(fs->canonical(path)); // Load a plugin maybe checkPlugin(p, what); } if (paths.empty()) return nullptr; return std::make_shared(fs->get_root(paths[0])); } std::string getMangledClass(const std::string &module, const std::string &cls, size_t id) { if (module == "std.internal.core") return cls; std::string num; if (!in(cls, '.')) num = "." + std::to_string(id); return (module.empty() ? "" : (module + ".")) + cls + num; } std::string getMangledFunc(const std::string &module, const std::string &fn, size_t overload, size_t id) { if (module == "std.internal.core") return fn + ":" + std::to_string(overload); std::string num; if (!in(fn, '.')) num = "." + std::to_string(id); return (module.empty() ? "" : (module + ".")) + fn + num + ":" + std::to_string(overload); } std::string getMangledMethod(const std::string &module, const std::string &cls, const std::string &method, size_t overload, size_t id) { if (module == "std.internal.core") return cls + "." + method + ":" + std::to_string(overload); std::string num; if (!in(cls, '.')) num = "." + std::to_string(id); return (module.empty() ? "" : (module + ".")) + cls + num + "." + method + ":" + std::to_string(overload); } std::string getMangledVar(const std::string &module, const std::string &var, size_t id) { std::string num; if (!in(var, '.')) num = "." + std::to_string(id); return (module.empty() ? 
"" : (module + ".")) + var + num; } } // namespace codon::ast ================================================ FILE: codon/parser/common.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include #include #include #include #include #include #include "codon/util/common.h" namespace codon { namespace ir { struct Attribute; } namespace ast { struct Cache; /// String and collection utilities /// Split a delimiter-separated string into a vector of strings. std::vector split(const std::string &str, char delim); /// Escape a C string (replace \n with \\n etc.). std::string escape(const std::string &str); /// Unescape a C string (replace \\n with \n etc.). std::string unescape(const std::string &str); /// Escape an F-string braces (replace { and } with {{ and }}). std::string escapeFStringBraces(const std::string &str, int start, int len); int findStar(const std::string &s); /// True if a string str starts with a prefix. size_t startswith(const std::string &str, const std::string &prefix); /// True if a string str ends with a suffix. size_t endswith(const std::string &str, const std::string &suffix); /// Trims whitespace at the beginning of the string. void ltrim(std::string &str); /// Trims whitespace at the end of the string. void rtrim(std::string &str); /// True if a string only contains digits. bool isdigit(const std::string &str); /// Combine items separated by a delimiter into a string. /// Combine items separated by a delimiter into a string. template std::string join(const T &items, const std::string &delim = " ") { std::string s; bool first = true; for (const auto &i : items) { if (!first) s += delim; s += i; first = false; } return s; } template std::string join(const T &items, const std::string &delim, size_t start, size_t end = (1ull << 31)) { std::string s; if (end > items.size()) end = items.size(); for (int i = start; i < end; i++) s += (i > start ? 
delim : "") + items[i]; return s; } /// Combine items separated by a delimiter into a string. template std::string combine(const std::vector &items, const std::string &delim = " ", const int indent = -1) { std::string s; for (int i = 0; i < items.size(); i++) if (items[i]) s += (i ? delim : "") + items[i]->toString(indent); return s; } template std::string combine2(const std::vector &items, const std::string &delim = ",", int start = 0, int end = -1) { std::string s; if (end == -1) end = items.size(); for (int i = start; i < end; i++) s += (i ? delim : "") + fmt::format("{}", items[i]); return s; } /// @return True if an item is found in a vector vec. template const T *in(const std::vector &vec, const U &item, size_t start = 0) { auto f = std::find(vec.begin() + start, vec.end(), item); return f != vec.end() ? &(*f) : nullptr; } /// @return True if an item is found in a set s. template const T *in(const std::set &s, const U &item) { auto f = s.find(item); return f != s.end() ? &(*f) : nullptr; } /// @return True if an item is found in an unordered_set s. template const T *in(const std::unordered_set &s, const U &item) { auto f = s.find(item); return f != s.end() ? &(*f) : nullptr; } /// @return True if an item is found in a map m. template const V *in(const std::map &m, const U &item) { auto f = m.find(item); return f != m.end() ? &(f->second) : nullptr; } /// @return True if an item is found in an unordered_map m. template const V *in(const std::unordered_map &m, const U &item) { auto f = m.find(item); return f != m.end() ? &(f->second) : nullptr; } /// @return True if an item is found in an unordered_map m. template V *in(std::unordered_map &m, const U &item) { auto f = m.find(item); return f != m.end() ? &(f->second) : nullptr; } /// @return True if an item is found in an string m. 
bool in(const std::string &m, const std::string &item); bool in(const std::string &m, char item); /// AST utilities template T clone(const T &t, bool clean = false) { return t.clone(clean); } template std::remove_const_t *clone(T *t, bool clean = false) { return t ? static_cast *>(t->clone(clean)) : nullptr; } template std::remove_const_t *clean_clone(T *t) { return clone(t, true); } /// Clones a vector of cloneable pointer objects. template std::vector> clone(const std::vector &t, bool clean = false) { std::vector> v; for (auto &i : t) v.push_back(clone(i, clean)); return v; } /// Path utilities /// Detect an absolute path of the current libcodonc. /// @return Absolute executable path or argv0 if one cannot be found. std::string library_path(); struct ImportFile { enum Status { STDLIB, PACKAGE }; Status status; /// Absolute path of an import. std::string path; /// Module name (e.g. foo.bar.baz). std::string module; }; class IFilesystem { public: using path_t = std::filesystem::path; protected: std::vector search_paths; public: virtual ~IFilesystem() {}; virtual std::vector read_lines(const path_t &path) const = 0; virtual bool exists(const path_t &path) const = 0; virtual std::vector get_stdlib_paths() const; virtual path_t canonical(const path_t &path) const; ImportFile get_root(const path_t &s) const; virtual path_t get_module0() const { return ""; } virtual void set_module0(const std::string &) {} virtual void add_search_path(const std::string &p); }; class Filesystem : public IFilesystem { public: using IFilesystem::path_t; private: path_t argv0, module0; std::vector extraPaths; public: Filesystem(const std::string &argv0, const std::string &module0 = ""); std::vector read_lines(const path_t &path) const override; bool exists(const path_t &path) const override; path_t get_module0() const override; void set_module0(const std::string &s) override; public: /// Detect an absolute path of the current executable (whose argv0 is known). 
/// @return Absolute executable path or argv0 if one cannot be found. static path_t executable_path(const char *argv0); /// @return The absolute canonical path of a given path. static std::string get_absolute_path(const std::string &path); }; class ResourceFilesystem : public Filesystem { bool allowExternal; public: ResourceFilesystem(const std::string &argv0, const std::string &module0 = "", bool allowExternal = true); std::vector read_lines(const path_t &path) const override; bool exists(const path_t &path) const override; }; /// Find an import file what given an executable path (argv0) either in the standard /// library or relative to a file relativeTo. Set forceStdlib for searching only the /// standard library. std::shared_ptr getImportFile(Cache *cache, const std::string &what, const std::string &relativeTo, bool forceStdlib = false); template class SetInScope { T *t; T origVal; public: SetInScope(T *t, const T &val) : t(t), origVal(*t) { *t = val; } ~SetInScope() { *t = origVal; } }; std::string getMangledClass(const std::string &module, const std::string &cls, size_t id = 0); std::string getMangledFunc(const std::string &module, const std::string &fn, size_t overload = 0, size_t id = 0); std::string getMangledMethod(const std::string &module, const std::string &cls, const std::string &method, size_t overload = 0, size_t id = 0); std::string getMangledVar(const std::string &module, const std::string &var, size_t id = 0); } // namespace ast } // namespace codon ================================================ FILE: codon/parser/ctx.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include #include #include #include #include #include "codon/parser/ast.h" namespace codon::ast { /** * A variable table (transformation context). * Base class that holds a list of existing identifiers and their block hierarchy. * @tparam T Variable type. 
*/ template class Context : public std::enable_shared_from_this> { public: using Item = std::shared_ptr; protected: using Map = std::unordered_map>; /// Maps a identifier to a stack of objects that share the same identifier. /// Each object is represented by a nesting level and a pointer to that object. /// Top of the stack is the current block; the bottom is the outer-most block. /// Stack is represented as std::deque to allow iteration and access to the outer-most /// block. Map map; /// Stack of blocks and their corresponding identifiers. Top of the stack is the /// current block. std::deque> stack; private: /// Set of current context flags. std::unordered_set flags; /// The absolute path of the current module. std::string filename; /// SrcInfo stack used for obtaining source information of the current expression. std::vector nodeStack; public: explicit Context(std::string filename) : filename(std::move(filename)) { /// Add a top-level block to the stack. stack.push_front(std::list()); } virtual ~Context() = default; /// Add an object to the top of the stack. virtual void add(const std::string &name, const Item &var) { seqassertn(!name.empty(), "adding an empty identifier"); map[name].push_front(var); stack.front().push_back(name); } /// Remove the top-most object with a given identifier. void remove(const std::string &name) { removeFromMap(name); for (auto &s : stack) { auto i = std::ranges::find(s, name); if (i != s.end()) { s.erase(i); return; } } } /// Return a top-most object with a given identifier or nullptr if it does not exist. virtual Item find(const std::string &name) const { auto it = map.find(name); return it != map.end() ? it->second.front() : nullptr; } /// Return all objects that share a common identifier or nullptr if it does not exist. virtual std::list *find_all(const std::string &name) { auto it = map.find(name); return it != map.end() ? &(it->second) : nullptr; } /// Add a new block (i.e. adds a stack level). 
virtual void addBlock() { stack.push_front(std::list()); } /// Remove the top-most block and all variables it holds. virtual void popBlock() { for (auto &name : stack.front()) removeFromMap(name); stack.pop_front(); } void removeFromTopStack(const std::string &name) { auto it = std::ranges::find(stack.front(), name); if (it != stack.front().end()) stack.front().erase(it); } /// The absolute path of a current module. std::string getFilename() const { return filename; } /// Sets the absolute path of a current module. void setFilename(std::string file) { filename = std::move(file); } /// Convenience functions to allow range-based for loops over a context. typename Map::iterator begin() { return map.begin(); } typename Map::iterator end() { return map.end(); } /// Pretty-prints the current context state. virtual void dump() {} protected: /// Remove an identifier from the map only. virtual void removeFromMap(const std::string &name) { auto i = map.find(name); if (i == map.end()) return; seqassertn(i->second.size(), "identifier {} not found in the map", name); i->second.pop_front(); if (!i->second.size()) map.erase(name); } public: /* SrcInfo helpers */ void pushNode(ASTNode *n) { nodeStack.emplace_back(n); } void popNode() { nodeStack.pop_back(); } ASTNode *getLastNode() const { return nodeStack.back(); } ASTNode *getParentNode() const { assert(nodeStack.size() > 1); return nodeStack[nodeStack.size() - 2]; } SrcInfo getSrcInfo() const { return nodeStack.back()->getSrcInfo(); } size_t getStackSize() const { return stack.size(); } }; } // namespace codon::ast ================================================ FILE: codon/parser/match.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "codon/parser/match.h" namespace codon::matcher { match_zero_or_more_t MAny() { return match_zero_or_more_t(); } match_startswith_t MStarts(std::string s) { return match_startswith_t{std::move(s)}; } match_endswith_t MEnds(std::string s) { return match_endswith_t{std::move(s)}; } match_contains_t MContains(std::string s) { return match_contains_t{std::move(s)}; } template <> bool match(const char *c, const char *d) { return std::string(c) == std::string(d); } template <> bool match(const char *c, std::string d) { return std::string(c) == d; } template <> bool match(std::string c, const char *d) { return std::string(d) == c; } template <> bool match(double &a, double b) { return std::abs(a - b) < __FLT_EPSILON__; } template <> bool match(std::string s, match_startswith_t m) { return m.s.size() <= s.size() && s.substr(0, m.s.size()) == m.s; } template <> bool match(std::string s, match_endswith_t m) { return m.s.size() <= s.size() && s.substr(s.size() - m.s.size(), m.s.size()) == m.s; } template <> bool match(std::string s, match_contains_t m) { return s.find(m.s) != std::string::npos; } template <> bool match(const char *s, match_startswith_t m) { return match(std::string(s), m); } template <> bool match(const char *s, match_endswith_t m) { return match(std::string(s), m); } template <> bool match(const char *s, match_contains_t m) { return match(std::string(s), m); } } // namespace codon::matcher ================================================ FILE: codon/parser/match.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include "codon/cir/base.h" namespace codon::matcher { template struct match_t { std::tuple args; std::function fn; match_t(MA... args, std::function fn) : args(std::tuple(args...)), fn(fn) {} }; template struct match_or_t { std::tuple args; match_or_t(MA... 
args) : args(std::tuple(args...)) {} }; struct match_ignore_t {}; struct match_zero_or_more_t {}; struct match_startswith_t { std::string s; }; struct match_endswith_t { std::string s; }; struct match_contains_t { std::string s; }; template match_t M(TA... args) { return match_t(args..., nullptr); } template match_t MCall(TA... args, std::function fn) { return match_t(args..., fn); } template match_t MVar(TA... args, T &tp) { return match_t(args..., [&tp](T &t) { tp = t; }); } template match_t MVar(TA... args, T *&tp) { return match_t(args..., [&tp](T &t) { tp = &t; }); } template match_or_t MOr(TA... args) { return match_or_t(args...); } match_zero_or_more_t MAny(); match_startswith_t MStarts(std::string s); match_endswith_t MEnds(std::string s); match_contains_t MContains(std::string s); ////////////////////////////////////////////////////////////////////////////// template bool match(T t, M m) { if constexpr (std::is_same_v) return t == m; return false; } template bool match(T &t, match_ignore_t) { return true; } template bool match(T &t, match_zero_or_more_t) { return true; } template <> bool match(const char *c, const char *d); template <> bool match(const char *c, std::string d); template <> bool match(std::string c, const char *d); template <> bool match(double &a, double b); template <> bool match(std::string s, match_startswith_t m); template <> bool match(std::string s, match_endswith_t m); template <> bool match(std::string s, match_contains_t m); template <> bool match(const char *s, match_startswith_t m); template <> bool match(const char *s, match_endswith_t m); template <> bool match(const char *s, match_contains_t m); template bool match_help(T &t, TM m) { if constexpr (i == std::tuple_size_v) { return i == std::tuple_size_v; } else if constexpr (i < std::tuple_size_v) { if constexpr (std::is_same_v(m.args))>, match_zero_or_more_t>) { return true; } return match(std::get(t.match_members()), std::get(m.args)) && match_help(t, m); } else { return 
false; } } template bool match_or_help(T &t, match_or_t m) { if constexpr (i >= 0 && i < std::tuple_size_v) { return match(t, std::get(m.args)) || match_or_help(t, m); } else { return false; } } template bool match(TM &t, match_or_t m) { return match_or_help<0, TM, TA...>(t, m); } template bool match(TM *t, match_or_t m) { return match_or_help<0, TM *, TA...>(t, m); } template bool match(T &t, match_t m) { if constexpr (std::is_pointer_v) { TM *tm = ir::cast(t); if (!tm) return false; if constexpr (sizeof...(TA) == 0) { if (m.fn) m.fn(*tm); return true; } else { auto r = match_help<0>(*tm, m); if (r && m.fn) m.fn(*tm); return r; } } else { if constexpr (!std::is_same_v) return false; if constexpr (sizeof...(TA) == 0) { if (m.fn) m.fn(t); return true; } else { auto r = match_help<0>(t, m); if (r && m.fn) m.fn(t); return r; } } } template bool match(T *t, match_t m) { return match(t, m); } } // namespace codon::matcher #define M_ matcher::match_ignore_t() ================================================ FILE: codon/parser/peg/grammar.peg ================================================ # Copyright (C) 2022 Exaloop Inc. # Codon PEG grammar # Adopted from Python 3's PEG grammar (https://docs.python.org/3/reference/grammar.html) # TODO: nice docstrs PREAMBLE { #include "codon/parser/peg/rules.h" #include using namespace std; using namespace codon::ast; #define V0 VS[0] #define V1 VS[1] #define V2 VS[2] #define ac std::any_cast #define ac_expr std::any_cast #define ac_stmt std::any_cast #define SemVals peg::SemanticValues #define aste(T, s, ...) setSI(CTX.cache->N(__VA_ARGS__), s) #define asts(T, s, ...) setSI(CTX.cache->N(__VA_ARGS__), s) template T *setSI(ASTNode *n, const codon::SrcInfo &s) { n->setSrcInfo(s); return (T*)n; } /// @return vector c transformed by the function f. 
template auto vmap(const std::vector &c, F &&f) { using VT = std::decay_t(f)(*std::begin(c)))>; std::vector ret; std::transform(std::begin(c), std::end(c), std::inserter(ret, std::end(ret)), f); return ret; } template auto vmap(const peg::SemanticValues &c, F &&f) { return vmap(static_cast&>(c), f); } Expr *chain(const codon::ast::ParseContext &CTX, peg::SemanticValues &VS, const codon::SrcInfo &LOC) { Expr *b = ac_expr(V0); for (int i = 1; i < VS.size(); i++) b = aste(Binary, LOC, b, VS.token_to_string(i - 1), ac_expr(VS[i])); return b; } Expr *wrap_tuple(const codon::ast::ParseContext &CTX, peg::SemanticValues &VS, const codon::SrcInfo &LOC) { if (VS.size() == 1 && VS.tokens.empty()) return ac_expr(V0); return aste(Tuple, LOC, VS.transform()); } } program <- (statements (_ EOL)* / (_ EOL)*) !. { if (VS.empty()) return asts(Suite, LOC); return ac_stmt(V0); } fstring <- fstring_prefix _ fstring_tail? _ !. { StringExpr::FormatSpec fs {"", "", ""}; auto [e, t] = ac>(V0); fs.text = t; if (VS.size() > 1) { auto [c, s] = ac>(V1); fs.conversion = c; fs.spec = s; } return make_pair(e, fs); } fstring_prefix <- star_expressions _ '='? { auto text = VS.sv(); return make_pair(ac_expr(V0), string(!VS.sv().empty() && VS.sv().back() == '=' ? VS.sv() : "")); } fstring_tail <- / fstring_conversion fstring_spec? { return make_pair(ac(V0), VS.size() > 1 ? ac(V1) : ""); } / fstring_spec { return make_pair(std::string(), ac(V0)); } fstring_conversion <- "!" ("s" / "r" / "a") { return string(VS.sv().substr(1)); } fstring_spec <- ':' format_spec { return ac(V0); } # Macros list(c, e) <- e (_ c _ e)* tlist(c, e) <- e (_ c _ e)* (_ )? 
statements <- ((_ EOL)* statement)+ { auto s = asts(Suite, LOC, VS.transform()); cast(s)->flatten(); return s; } statement <- SAMEDENT compound_stmt / SAMEDENT simple_stmt simple_stmt <- tlist(';', small_stmt) _ EOL { auto s = asts(Suite, LOC, VS.transform()); cast(s)->flatten(); return s; } small_stmt <- / directive / assignment / 'pass' &(SPACE / ';' / EOL) { return any(asts(Suite, LOC)); } / 'break' &(SPACE / ';' / EOL) { return any(asts(Break, LOC)); } / 'continue' &(SPACE / ';' / EOL) { return any(asts(Continue, LOC)); } / global_stmt / nonlocal_stmt / yield_stmt &(SPACE / ';' / EOL) / assert_stmt / del_stmt / return_stmt &(SPACE / ';' / EOL) / raise_stmt &(SPACE / ';' / EOL) / print_stmt / import_stmt / expressions &(_ ';' / _ EOL) { return any(asts(Expr, LOC, ac_expr(V0))); } / custom_small_stmt assignment <- / id _ ':' _ expression (_ '=' _ star_expressions)? { return asts(Assign, LOC, ac_expr(V0), VS.size() > 2 ? ac_expr(V2) : nullptr, ac_expr(V1) ); } / primary _ ':' _ expression (_ '=' _ star_expressions)? { return asts(Assign, LOC, ac_expr(V0), VS.size() > 2 ? 
ac_expr(V2) : nullptr, ac_expr(V1) ); } / (star_targets _ (!'==' '=') _)+ star_expressions !(_ '=') { vector stmts; for (int i = int(VS.size()) - 2; i >= 0; i--) { auto a = asts(Assign, LOC, ac_expr(VS[i]), ac_expr(VS[i + 1])); stmts.push_back(a); } return asts(Suite, LOC, std::move(stmts)); } / star_expression _ augassign '=' ^ _ star_expressions { auto a = asts(Assign, LOC, ac_expr(V0), aste(Binary, LOC, clone(ac_expr(V0)), ac(V1), ac_expr(V2), true)); return a; } augassign <- < '+' / '-' / '**' / '*' / '@' / '//' / '/' / '%' / '&' / '|' / '^' / '<<' / '>>' > { return VS.token_to_string(); } global_stmt <- 'global' SPACE tlist(',', NAME) { return asts(Suite, LOC, vmap(VS, [&](const any &i) { return asts(Global, LOC, ac(i), false); }) ); } nonlocal_stmt <- 'nonlocal' SPACE tlist(',', NAME) { return asts(Suite, LOC, vmap(VS, [&](const any &i) { return asts(Global, LOC, ac(i), true); }) ); } yield_stmt <- / 'yield' SPACE 'from' SPACE expression { return asts(YieldFrom, LOC, ac_expr(V0)); } / 'yield' (SPACE expressions)? { return asts(Yield, LOC, !VS.empty() ? ac_expr(V0) : nullptr); } assert_stmt <- 'assert' SPACE expression (_ ',' _ expression)? { return asts(Assert, LOC, ac_expr(V0), VS.size() > 1 ? ac_expr(V1) : nullptr); } # TODO: do targets as in Python del_stmt <- 'del' SPACE tlist(',', expression) { return asts(Suite, LOC, vmap(VS, [&](const any &i) { return asts(Del, LOC, ac_expr(i)); }) ); } return_stmt <- 'return' (SPACE expressions)? { return asts(Return, LOC, !VS.empty() ? ac_expr(V0) : nullptr); } raise_stmt <- / 'raise' SPACE expression (SPACE 'from' SPACE expression)? { return asts(Throw, LOC, ac_expr(V0), VS.size() > 1 ? ac_expr(V1) : nullptr); } / 'raise' { return asts(Throw, LOC, nullptr); } print_stmt <- / 'print' SPACE star_expression (_ ',' _ star_expression)* (_ <','>)? 
{ return asts(Print, LOC, VS.transform(), !VS.tokens.empty()); } / 'print' _ &EOL { return asts(Print, LOC, vector{}, false); } import_stmt <- import_name / import_from import_name <- 'import' SPACE list(',', as_name) { return asts(Suite, LOC, vmap(VS.transform>(), [&](const pair &i) { return asts(Import, LOC, i.first, nullptr, vector{}, nullptr, i.second); }) ); } as_name <- dot_name (SPACE 'as' SPACE NAME)? { return pair(ac_expr(V0), VS.size() > 1 ? ac(V1) : ""); } import_from <- / 'from' SPACE (_ <'.'>)* (_ dot_name)? SPACE 'import' SPACE '*' { return asts(Import, LOC, VS.size() == 1 ? ac_expr(V0) : nullptr, aste(Id, LOC, "*"), vector{}, nullptr, "", int(VS.tokens.size()) ); } / 'from' SPACE (_ <'.'>)* (_ dot_name)? SPACE 'import' SPACE (from_as_parens / from_as_items) { auto f = VS.size() == 2 ? ac_expr(V0) : nullptr; return asts(Suite, LOC, vmap( ac(VS.size() == 2 ? V1 : V0), [&](const any &i) { auto p = ac>(i); auto t = ac, Expr*, bool>>(p.first); return asts(Import, LOC, f, get<0>(t), std::move(get<1>(t)), get<2>(t), p.second, int(VS.tokens.size()), get<3>(t) ); } ) ); } from_as_parens <- '(' _ tlist(',', from_as) _ ')' { return VS; } from_as_items <- list(',', from_as) { return VS; } from_as <- from_id (SPACE 'as' SPACE NAME)? { return pair(V0, VS.size() > 1 ? ac(V1) : ""); } from_id <- / dot_name _ ':' _ expression { return tuple(ac_expr(V0), vector(), ac_expr(V1), false); } / dot_name _ from_params (_ '->' _ expression)? { return tuple( ac_expr(V0), ac(V1).transform(), VS.size() > 2 ? ac_expr(V2) : aste(Id, LOC, "NoneType"), true ); } / dot_name { return tuple(ac_expr(V0), vector{}, (Expr*)nullptr, true); } dot_name <- id (_ '.' _ NAME)* { if (VS.size() == 1) return ac_expr(V0); auto dot = aste(Dot, LOC, ac_expr(V0), ac(V1)); for (int i = 2; i < VS.size(); i++) dot = aste(Dot, LOC, dot, ac(VS[i])); return dot; } from_params <- '(' _ tlist(',', from_param)? 
_ ')' { return VS; } from_param <- expression { return Param(LOC, "", ac_expr(V0), nullptr); } #TODO expand import logic / param { return ac(V0); } suite <- (simple_stmt / (_ EOL)+ &INDENT statements (_ EOL)* &DEDENT) { return ac_stmt(V0); } compound_stmt <- / function / if_stmt / class / with_stmt_async / for / try_stmt / while_stmt / match_stmt / custom_stmt if_stmt <- ('if' SPACE named_expression _ ':' _ suite) (SAMEDENT 'elif' SPACE named_expression _ ':' _ suite)* (SAMEDENT 'else' _ ':' _ suite)? { Stmt *lastElse = VS.size() % 2 == 0 ? nullptr : SuiteStmt::wrap(ac_stmt(VS.back())); for (size_t i = VS.size() - bool(lastElse); i-- > 0; ) { lastElse = asts(If, LOC, ac_expr(VS[i - 1]), SuiteStmt::wrap(ac_stmt(VS[i])), SuiteStmt::wrap(lastElse)); i--; } return lastElse; } while_stmt <- ('while' SPACE named_expression _ ':' _ suite) (SAMEDENT 'else' (SPACE 'not' SPACE 'break')* _ ':' _ suite)? { return asts(While, LOC, ac_expr(V0), ac_stmt(V1), VS.size() > 2 ? ac_stmt(V2) : nullptr ); } for <- decorator? for_stmt_async { if (VS.size() > 1) { auto s = (ForStmt*)(ac_stmt(V1)); s->setDecorator(ac_expr(V0)); return (Stmt*)s; } return ac_stmt(V0); } for_stmt_async <- / 'async' SPACE for_stmt { auto s = (ForStmt*)(ac_stmt(V0)); s->setAsync(); return make_any(s); } / for_stmt { return V0; } for_stmt <- ('for' SPACE star_targets) (SPACE 'in' SPACE star_expressions _ ':' _ suite) (SAMEDENT 'else' (SPACE 'not' SPACE 'break')* _ ':' _ suite)? { return asts(For, LOC, ac_expr(V0), ac_expr(V1), ac_stmt(V2), VS.size() > 3 ? 
ac_stmt(VS[3]) : nullptr ); } with_stmt_async <- / 'async' SPACE with_stmt { auto s = (WithStmt*)(ac_stmt(V0)); s->setAsync(); return make_any(s); } / with_stmt { return V0; } with_stmt <- 'with' SPACE (with_parens_item / with_item) _ ':' _ suite { return asts(With, LOC, ac(V0).transform>(), ac_stmt(V1), false ); } with_parens_item <- '(' _ tlist(',', as_item) _ ')' { return VS; } with_item <- list(',', as_item) { return VS; } as_item <- / expression SPACE 'as' SPACE id &(_ (',' / ')' / ':')) { return pair(ac_expr(V0), ac_expr(V1)); } / expression { return pair(ac_expr(V0), (Expr*)nullptr); } # TODO: else block? try_stmt <- / ('try' _ ':' _ suite) excepts else_finally? { std::pair ef {nullptr, nullptr}; if (VS.size() > 2) ef = ac>(V2); return asts(Try, LOC, ac_stmt(V0), ac(V1).transform(), ef.first, ef.second ); } / ('try' _ ':' _ suite) (SAMEDENT 'finally' _ ':' _ suite)? { return asts(Try, LOC, ac_stmt(V0), vector{}, nullptr, VS.size() > 1 ? ac_stmt(V1) : nullptr ); } else_finally <- / SAMEDENT 'else' _ ':' _ suite SAMEDENT 'finally' _ ':' _ suite { return std::pair(ac_stmt(V0), ac_stmt(V1)); } / SAMEDENT 'else' _ ':' _ suite { return std::pair(ac_stmt(V0), nullptr); } / SAMEDENT 'finally' _ ':' _ suite { return std::pair(nullptr, ac_stmt(V0)); } excepts <- (SAMEDENT except_block)+ { return VS; } except_block <- / 'except' SPACE expression (SPACE 'as' SPACE NAME)? _ ':' _ suite { if (VS.size() == 3) return setSI(CTX.cache->N(ac(V1), ac_expr(V0), ac_stmt(V2)), LOC); else return setSI(CTX.cache->N("", ac_expr(V0), ac_stmt(V1)), LOC); } / 'except' _ ':' _ suite { return setSI(CTX.cache->N("", nullptr, ac_stmt(V0)), LOC); } function <- / extern_decorators function_def_async (_ EOL)+ &INDENT extern (_ EOL)* &DEDENT { auto fn = (FunctionStmt*)(ac_stmt(V1)); fn->setDecorators(ac>(V0)); fn->setSuite(SuiteStmt::wrap(asts(Expr, LOC, aste(String, LOC, ac(V2))))); return (Stmt*)fn; } / decorators? function_def_async _ suite { auto fn = (FunctionStmt*)(ac_stmt(VS.size() > 2 ? 
V1 : V0)); if (VS.size() > 2) fn->setDecorators(ac>(V0)); fn->setSuite(SuiteStmt::wrap(ac_stmt(VS.size() > 2 ? V2 : V1))); return (Stmt*)fn; } extern <- (empty_line* EXTERNDENT (!EOL .)* EOL empty_line*)+ { return string(VS.sv()); } ~empty_line <- [ \t]* EOL function_def_async <- / 'async' SPACE function_def { auto fn = (FunctionStmt*)(ac_stmt(V0)); fn->setAsync(); return make_any(fn); } / function_def { return V0; } function_def <- / 'def' SPACE NAME _ generics _ params (_ '->' _ expression)? _ ':' { auto params = ac(V2).transform(); for (auto &p: ac>(V1)) params.push_back(p); return asts(Function, LOC, ac(V0), VS.size() == 4 ? ac_expr(VS[3]) : nullptr, params, nullptr ); } / 'def' SPACE NAME _ params (_ '->' _ expression)? _ ':' { return asts(Function, LOC, ac(V0), VS.size() == 3 ? ac_expr(VS[2]) : nullptr, ac(V1).transform(), nullptr ); } params <- '(' _ tlist(',', param)? _ ')' { return VS; } param <- / param_name _ ':' _ expression (_ '=' _ expression)? { return Param(LOC, ac(V0), ac_expr(V1), VS.size() > 2 ? ac_expr(V2) : nullptr); } / param_name (_ '=' _ expression)? { return Param(LOC, ac(V0), nullptr, VS.size() > 1 ? ac_expr(V1) : nullptr); } param_name <- <'**' / '*'>? _ NAME { return (!VS.tokens.empty() ? VS.token_to_string() : "") + ac(V0); } generics <- '[' _ tlist(',', param) _ ']' { vector params; for (auto &p: VS) { auto v = ac(p); v.status = Param::Generic; if (!v.type) v.type = aste(Id, LOC, "type"); params.push_back(v); } return params; } decorators <- decorator+ { return VS.transform(); } decorator <- ('@' _ !(('llvm' / 'python') _ EOL) named_expression _ EOL SAMEDENT) { return ac_expr(V0); } extern_decorators <- / decorators? ('@' _ <'llvm'/'python'> _ EOL SAMEDENT) decorators? { vector vs{aste(Id, LOC, VS.token_to_string())}; for (auto &v: VS) { auto nv = ac>(v); vs.insert(vs.end(), nv.begin(), nv.end()); } return vs; } class <- decorators? 
class_def {
    if (VS.size() == 2) {
      auto fn = ac_stmt(V1);
      cast(fn)->setDecorators(ac>(V0));
      return fn;
    }
    return ac_stmt(V0);
  }

# Optional comma-separated expressions wrapped in parentheses.
base_class_args <- '(' _ tlist(',', expression)? _ ')' {
  return VS.transform();
}

# Yields a (generics, base-class arguments) pair; whichever part is absent is
# returned as an empty vector.
class_args <-
  / generics _ base_class_args {
    return make_pair(ac>(V0), ac>(V1));
  }
  / generics {
    return make_pair(ac>(V0), vector{});
  }
  / base_class_args {
    return make_pair(vector{}, ac>(V0));
  }

# Suite statements that pass both casts below are turned into Param entries of
# the class; every other statement is kept in the class suite. Generic
# parameters are appended after those entries.
class_def <- 'class' SPACE NAME _ class_args? _ ':' _ suite {
  vector generics;
  vector baseClasses;
  if (VS.size() == 3)
    std::tie(generics, baseClasses) = ac, vector>>(V1);
  vector args;
  auto suite = (SuiteStmt*)(asts(Suite, LOC));
  auto s = cast(ac_stmt(VS.size() == 3 ? V2 : V1));
  seqassertn(s, "not a suite");
  for (auto *i: *s) {
    if (auto a = cast(i))
      if (auto ei = cast(a->getLhs())) {
        args.push_back(Param(a->getSrcInfo(), ei->getValue(), a->getTypeExpr(), a->getRhs()));
        continue;
      }
    suite->addStmt(i);
  }
  suite->flatten();
  for (auto &p: generics)
    args.push_back(p);
  return asts(Class, LOC,
    ac(V0), std::move(args), suite, vector{}, baseClasses
  );
}

# `match` header followed by an indented run of one or more `case` entries.
match_stmt <- 'match' SPACE expression _ ':' (_ EOL)+ &INDENT (SAMEDENT case)+ (_ EOL)* &DEDENT {
  return asts(Match, LOC, ac_expr(V0), VS.transform(1));
}

# A case with an optional `if` guard; the guard slot is nullptr when absent.
case <-
  / 'case' SPACE expression SPACE 'if' SPACE pipe _ ':' _ suite {
    return MatchCase{ac_expr(V0), ac_expr(V1), ac_stmt(V2)};
  }
  / 'case' SPACE expression _ ':' _ suite {
    return MatchCase{ac_expr(V0), nullptr, ac_stmt(V1)};
  }

# Block statement introduced by a keyword accepted by custom_stmt__PREDICATE:
# either `NAME expression: suite` or `NAME: suite`.
custom_stmt <-
  / NAME SPACE expression _ ':' _ suite {
    return asts(Custom, LOC, ac(V0), ac_expr(V1), ac_stmt(V2));
  }
  / NAME _ ':' _ suite {
    // Only NAME and suite yield semantic values in this alternative, so the
    // suite is V1 (V2 would index past the end of VS), exactly as in the
    // guard-less `case` alternative.
    return asts(Custom, LOC, ac(V0), nullptr, ac_stmt(V1));
  }
custom_stmt__PREDICATE {
  auto kwd = ac(V0);
  // Choice 0 is the form that carries an expression.
  return CTX.hasCustomStmtKeyword(kwd, VS.choice() == 0); // ignore it
}

# Single-line custom statement: keyword followed by an expression list.
custom_small_stmt <- NAME SPACE expressions {
  return any(asts(Custom, LOC, ac(V0), ac_expr(V1), nullptr));
}
custom_small_stmt__PREDICATE {
  auto kwd = ac(V0);
  return CTX.hasCustomExprStmt(kwd); // ignore it
}

directive <- '##' _ 'codon:' _ NAME _ '='
_ (INT / NAME) { return asts(Directive, LOC, ac(V0), ac(V1)); } ######################################################################################## # (2) Expressions ######################################################################################## expressions <- tlist(',', expression) { return wrap_tuple(CTX, VS, LOC); } expression <- / lambdef { return ac_expr(V0); } / disjunction SPACE? 'if' SPACE? disjunction SPACE? 'else' SPACE? expression { return aste(If, LOC, ac_expr(V1), ac_expr(V0), ac_expr(V2)); } / pipe { return ac_expr(V0); } lambdef <- / 'lambda' SPACE lparams _ ':' _ expression { return aste(Lambda, LOC, ac(V0).transform(), ac_expr(V1) ); } / 'lambda' _ ':' _ expression { return aste(Lambda, LOC, vector{}, ac_expr(V0)); } lparams <- tlist(',', lparam)? { return VS; } lparam <- param_name (_ '=' _ expression)? { return Param(LOC, ac(V0), nullptr, VS.size() > 1 ? ac_expr(V1) : nullptr); } pipe <- / disjunction (_ <'|>' / '||>'> _ disjunction)+ { vector v; for (int i = 0; i < VS.size(); i++) v.push_back(Pipe{i ? VS.token_to_string(i - 1) : "", ac_expr(VS[i])}); return aste(Pipe, LOC, std::move(v)); } / disjunction { return ac_expr(V0); } disjunction <- / conjunction (SPACE? 'or' SPACE? conjunction)+ { auto b = aste(Binary, LOC, ac_expr(V0), "||", ac_expr(V1)); for (int i = 2; i < VS.size(); i++) b = aste(Binary, LOC, b, "||", ac_expr(VS[i])); return b; } / conjunction { return ac_expr(V0); } conjunction <- / inversion (SPACE? 'and' SPACE? 
inversion)+ { auto b = aste(Binary, LOC, ac_expr(V0), "&&", ac_expr(V1)); for (int i = 2; i < VS.size(); i++) b = aste(Binary, LOC, b, "&&", ac_expr(VS[i])); return b; } / inversion { return ac_expr(V0); } inversion <- / 'not' SPACE inversion { return aste(Unary, LOC, "!", ac_expr(V0)); } / comparison { return ac_expr(V0); } comparison <- bitwise_or compare_op_bitwise_or* { if (VS.size() == 1) { return ac_expr(V0); } else if (VS.size() == 2) { auto p = ac>(V1); return aste(Binary, LOC, ac_expr(V0), p.first, p.second); } else { vector> v{pair(string(), ac_expr(V0))}; auto vp = VS.transform>(1); v.insert(v.end(), vp.begin(), vp.end()); return aste(ChainBinary, LOC, std::move(v)); } } compare_op_bitwise_or <- / SPACE 'not' SPACE 'in' SPACE bitwise_or { return pair(string("not in"), ac_expr(V0)); } / SPACE 'is' SPACE 'not' SPACE bitwise_or { return pair(string("is not"), ac_expr(V0)); } / SPACE <'in' / 'is'> SPACE bitwise_or { return pair(VS.token_to_string(), ac_expr(V0)); } / _ <'==' / '!=' / '<=' / '<' / '>=' / '>'> _ bitwise_or { return pair(VS.token_to_string(), ac_expr(V0)); } bitwise_or <- bitwise_xor (_ <'|'> _ bitwise_xor)* { return chain(CTX, VS, LOC); } bitwise_xor <- bitwise_and (_ <'^'> _ bitwise_and)* { return chain(CTX, VS, LOC); } bitwise_and <- shift_expr (_ <'&'> _ shift_expr )* { return chain(CTX, VS, LOC); } shift_expr <- sum (_ <'<<' / '>>'> _ sum )* { return chain(CTX, VS, LOC); } sum <- term (_ <'+' / '-'> _ term)* { return chain(CTX, VS, LOC); } term <- factor (_ <'*' / '//' / '/' / '%' / '@'> _ factor)* { return chain(CTX, VS, LOC); } factor <- / <'+' / '-' / '~'> _ factor { return aste(Unary, LOC, VS.token_to_string(), ac_expr(V0)); } / power { return ac_expr(V0); } power <- / await_primary _ <'**'> _ factor { return aste(Binary, LOC, ac_expr(V0), "**", ac_expr(V1)); } / await_primary { return ac_expr(V0); } await_primary <- / 'await' SPACE primary { return aste(Await, LOC, ac_expr(V0)); } / primary { return ac_expr(V0); } primary <- atom (_ 
primary_tail)* { auto e = ac(V0); for (int i = 1; i < VS.size(); i++) { auto p = ac>(VS[i]); if (p.first == 0) { e = aste(Dot, LOC, e, ac(p.second)); } else if (p.first == 1) { e = aste(Call, LOC, e, ac_expr(p.second)); } else if (p.first == 2) { e = aste(Call, LOC, e, ac>(p.second)); } else { e = aste(Index, LOC, e, ac_expr(p.second)); } } return e; } primary_tail <- / '.' _ NAME { return pair(0, V0); } / genexp { return pair(1, V0); } / arguments { return pair(2, VS.size() ? V0 : any(vector{})); } / slices { return pair(3, V0); } slices <- '[' _ tlist(',', slice) _ ']' { return wrap_tuple(CTX, VS, LOC); } slice <- / slice_part _ ':' _ slice_part (_ ':' _ slice_part)? { return aste(Slice, LOC, ac_expr(V0), ac_expr(V1), VS.size() > 2 ? ac_expr(V2) : nullptr ); } / expression { return ac_expr(V0); } slice_part <- expression? { return VS.size() ? V0 : make_any(nullptr); } atom <- / STRING (SPACE STRING)* { auto e = aste(String, LOC, VS.transform()); return e; } / id { return ac_expr(V0); } / 'True' { return aste(Bool, LOC, true); } / 'False' { return aste(Bool, LOC, false);} / 'None' { return aste(None, LOC); } / INT _ '...' _ INT { return aste(Range, LOC, aste(Int, LOC, ac(V0)), aste(Int, LOC, ac(V1)) ); } / FLOAT NAME? { return aste(Float, LOC, ac(V0), VS.size() > 1 ? ac(V1) : ""); } / INT NAME? { return aste(Int, LOC, ac(V0), VS.size() > 1 ? ac(V1) : ""); } / parentheses { return ac_expr(V0); } / '...' 
{ return aste(Ellipsis, LOC); } parentheses <- ( tuple / yield / named / genexp / listexpr / listcomp / dict / set / dictcomp / setcomp ) tuple <- / '(' _ ')' { return aste(Tuple, LOC, VS.transform()); } / '(' _ tlist(',', star_named_expression) _ ')' { return wrap_tuple(CTX, VS, LOC); } yield <- '(' _ 'yield' _ ')' { return aste(Yield, LOC); } named <- '(' _ named_expression _ ')' genexp <- '(' _ named_expression SPACE for_if_clauses _ ')' { return aste(Generator, LOC, CTX.cache, GeneratorExpr::Generator, ac_expr(V0), ac>(V1) ); } listexpr <- '[' _ tlist(',', star_named_expression)? _ ']' { return aste(List, LOC, VS.transform()); } listcomp <- '[' _ named_expression SPACE for_if_clauses _ ']' { return aste(Generator, LOC, CTX.cache, GeneratorExpr::ListGenerator, ac_expr(V0), ac>(V1) ); } set <- '{' _ tlist(',', star_named_expression) _ '}' { return aste(Set, LOC, VS.transform()); } setcomp <- '{' _ named_expression SPACE for_if_clauses _ '}' { return aste(Generator, LOC, CTX.cache, GeneratorExpr::SetGenerator, ac_expr(V0), ac>(V1) ); } dict <- '{' _ tlist(',', double_starred_kvpair)? 
_ '}' { return aste(Dict, LOC, VS.transform()); } dictcomp <- '{' _ kvpair SPACE for_if_clauses _ '}' { auto p = ac(V0); return aste(Generator, LOC, CTX.cache, (*cast(p))[0], (*cast(p))[1], ac>(V1) ); } double_starred_kvpair <- / '**' _ bitwise_or { return aste(KeywordStar, LOC, ac_expr(V0)); } / kvpair { return ac(V0); } kvpair <- expression _ ':' _ expression { return aste(Tuple, LOC, std::vector{ac_expr(V0), ac_expr(V1)}); } for_if_clauses <- for_if_clause_async (SPACE for_if_clause_async)* { std::vector v = ac>(V0); auto tail = VS.transform>(1); for (auto &t: tail) v.insert(v.end(), t.begin(), t.end()); return v; } for_if_clause_async <- / 'async' SPACE for_if_clause { auto s = ac>(V0); if (!s.empty() && cast(s.front())) cast(s.front())->setAsync(); return make_any>(s); } / for_if_clause { return V0; } for_if_clause <- / 'for' SPACE star_targets SPACE 'in' SPACE disjunction (SPACE? 'if' SPACE? disjunction)* { std::vector v{asts(For, LOC, ac_expr(V0), ac_expr(V1), nullptr)}; auto tail = VS.transform(2); for (auto &t: tail) v.push_back(asts(If, LOC, t, nullptr)); return v; } star_targets <- tlist(',', star_target) { return wrap_tuple(CTX, VS, LOC); } star_target <- / '*' _ !'*' star_target { return aste(Star, LOC, ac_expr(V0)); } / star_parens { return ac_expr(V0); } / primary { return ac_expr(V0); } star_parens <- / '(' _ tlist(',', star_target) _ ')' { return wrap_tuple(CTX, VS, LOC); } / '[' _ tlist(',', star_target) _ ']' { return wrap_tuple(CTX, VS, LOC); } star_expressions <- tlist(',', star_expression) { return wrap_tuple(CTX, VS, LOC); } star_expression <- / '*' _ bitwise_or { return aste(Star, LOC, ac_expr(V0)); } / expression { return ac_expr(V0); } star_named_expression <- / '*' _ bitwise_or { return aste(Star, LOC, ac_expr(V0)); } / named_expression { return ac_expr(V0); } named_expression <- / NAME _ ':=' _ ^ expression { return aste(Assign, LOC, aste(Id, LOC, ac(V0)), ac_expr(V1)); } / expression !(_ ':=') { return ac_expr(V0); } arguments <- '(' _ 
tlist(',', args)? _ ')' { vector result; for (auto &v: VS) for (auto &i: ac>(v)) result.push_back(i); if (EllipsisExpr *e = nullptr; !result.empty() && result.back().getName().empty() && ((e = cast(result.back().getExpr())))) { auto en = CTX.cache->N(EllipsisExpr::PARTIAL); en->setSrcInfo(e->getSrcInfo()); result.back() = CallArg{"", en}; } return result; } args <- (simple_args (_ ',' _ kwargs)? / kwargs) { auto args = ac>(V0); if (VS.size() > 1) { auto v = ac>(V1); args.insert(args.end(), v.begin(), v.end()); } return args; } simple_args <- list(',', (starred_expression / named_expression !(_ '='))) { return vmap(VS, [](auto &i) { return CallArg(ac_expr(i)); }); } starred_expression <- '*' _ expression { return aste(Star, LOC, ac_expr(V0)); } kwargs <- / list(',', kwarg_or_starred) _ ',' _ list(',', kwarg_or_double_starred) { return VS.transform(); } / list(',', kwarg_or_starred) { return VS.transform(); } / list(',', kwarg_or_double_starred) { return VS.transform(); } kwarg_or_starred <- / NAME _ '=' _ expression { return CallArg(LOC, ac(V0), ac_expr(V1)); } / starred_expression { return CallArg(ac_expr(V0)); } kwarg_or_double_starred <- / NAME _ '=' _ expression { return CallArg(LOC, ac(V0), ac_expr(V1)); } / '**' _ expression { return CallArg(aste(KeywordStar, LOC, ac_expr(V0))); } id <- NAME { return aste(Id, LOC, ac(V0)); } INT <- (BININT / HEXINT / DECINT) { return string(VS.sv()); } BININT <- <'0' [bB] [0-1] ('_'* [0-1])*> HEXINT <- <'0' [xX] [0-9a-fA-F] ('_'? [0-9a-fA-F])*> DECINT <- <[0-9] ('_'? [0-9])*> FLOAT <- (EXPFLOAT / PTFLOAT) { return string(VS.sv()); } PTFLOAT <- DECINT? '.' DECINT / DECINT '.' EXPFLOAT <- (PTFLOAT / DECINT) [eE] <'+' / '-'>? DECINT NAME <- / keyword [a-zA-Z_0-9]+ { return string(VS.sv()); } / !keyword <[a-zA-Z_] [a-zA-Z_0-9]*> { return VS.token_to_string(); } STRING <- { auto p = StringExpr::String( ac(VS.size() > 1 ? V1 : V0), VS.size() > 1 ? 
ac(V0) : "" ); if (p.prefix != "r" && p.prefix != "R") { p.value = unescape(p.value); } else { p.prefix = ""; } return p; } STRING__PREDICATE { auto p = StringExpr::String( ac(VS.size() > 1 ? V1 : V0), VS.size() > 1 ? ac(V0) : "" ); if (p.prefix != "r" && p.prefix != "R") try { p.value = unescape(p.value); } catch (std::invalid_argument &e) { MSG = "invalid code in a string"; return false; } catch (std::out_of_range &) { MSG = "invalid code in a string"; return false; } return true; } STR <- < '"""' (!'"""' CHAR)* '"""' / '\'\'\'' (!'\'\'\'' CHAR)* '\'\'\'' / '"' (!('"' / EOL) CHAR)* '"' / '\'' (!('\'' / EOL) CHAR)* '\'' > { string s; s.reserve(VS.size()); for (auto &v: VS) s.append(ac(v)); return s; } CHAR <- ('\\' . / .) { return string(VS.sv()); } ~COMMENT <- !directive <'#' (!EOL .)*> ~INDENT__NOPACKRAT <- <[ \t]*> { CTX.indent.push(VS.sv().size()); } INDENT__PREDICATE { if (!(CTX.indent.empty() && VS.sv().size()) && !(!CTX.indent.empty() && VS.sv().size() > CTX.indent.top())) { MSG = "unexpected indentation"; return false; } return true; } ~SAMEDENT__NOPACKRAT <- <[ \t]*> {} SAMEDENT__PREDICATE { return !(!CTX.indent.size() && VS.sv().size()) && !(CTX.indent.size() && VS.sv().size() != CTX.indent.top()); } ~DEDENT__NOPACKRAT <- <[ \t]*> { CTX.indent.pop(); } DEDENT__PREDICATE { if (!(CTX.indent.size() && VS.sv().size() < CTX.indent.top())) { MSG = "unexpected dedent"; return false; } return true; } ~EXTERNDENT__NOPACKRAT <- <[ \t]*> {} EXTERNDENT__PREDICATE { return !(!CTX.indent.size() && VS.sv().size()) && !(CTX.indent.size() && VS.sv().size() < CTX.indent.top()); } ~EOL <- <[\r][\n] / [\r\n]> ~SPACE <- ([ \t]+ / COMMENT / NLP EOL) SPACE? ~_ <- SPACE? 
~keyword <- < 'False' / 'else' / 'import' / 'pass' / 'None' / 'break' / 'except' / 'in' / 'raise' / 'True' / 'class' / 'finally' / 'is' / 'return' / 'and' / 'continue' / 'for' / 'as' / 'lambda' / 'try' / 'def' / 'from' / 'while' / 'assert' / 'del' / 'global' / 'not' / 'with' / 'elif' / 'if' / 'or' / 'yield' / 'async' / 'await' > # https://docs.python.org/3/library/string.html#formatspec format_spec <- ([^{}] [<>=^] / [<>=^])? [+- ]? 'z'? '#'? '0'? [0-9]* [_,]* ('.' [0-9]+)? [bcdeEfFgGnosxX%]? { return string(VS.sv()); } ================================================ FILE: codon/parser/peg/openmp.peg ================================================ # Copyright (C) 2022 Exaloop Inc. # OpenMP PEG grammar PREAMBLE { #include "codon/parser/peg/rules.h" #include using namespace std; using namespace codon::ast; #define V0 VS[0] #define V1 VS[1] #define ac std::any_cast template T *setSI(ASTNode *n, const codon::SrcInfo &s) { n->setSrcInfo(s); return (T*)n; } #define ast(T, s, ...) setSI(CTX.cache->N(__VA_ARGS__), s) } pragma <- "omp"? _ "parallel"? _ (clause _)* { std::vector v; for (auto &i: VS) { auto vi = ac>(i); v.insert(v.end(), vi.begin(), vi.end()); } return v; } clause <- / "schedule" _ "(" _ schedule_kind (_ "," _ int)? 
_ ")" { // CTX; std::vector v{CallArg{"schedule", ast(String, LOC, ac(V0))}}; if (VS.size() > 1) v.push_back(CallArg{"chunk_size", ast(Int, LOC, ac(V1))}); return v; } / "num_threads" _ "(" _ int _ ")" { return std::vector{CallArg{"num_threads", ast(Int, LOC, ac(V0))}}; } / "ordered" { return std::vector{CallArg{"ordered", ast(Bool, LOC, true)}}; } / "collapse" { return std::vector{CallArg{"collapse", ast(Int, LOC, ac(V0))}}; } / "gpu" { return std::vector{CallArg{"gpu", ast(Bool, LOC, true)}}; } schedule_kind <- ("static" / "dynamic" / "guided" / "auto" / "runtime") { return VS.token_to_string(); } int <- [1-9] [0-9]* { return stoi(VS.token_to_string()); } # ident <- [a-zA-Z_] [a-zA-Z_0-9]* { # return ast(VS.token_to_string()); # } ~SPACE <- [ \t]+ ~_ <- SPACE* ================================================ FILE: codon/parser/peg/peg.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "peg.h" #include #include #include #include #include #include "codon/parser/ast.h" #include "codon/parser/common.h" #include "codon/parser/peg/rules.h" #include "codon/parser/visitors/format/format.h" #include double totalPeg = 0.0; namespace codon::ast { static std::shared_ptr grammar(nullptr); static std::shared_ptr ompGrammar(nullptr); std::shared_ptr initParser() { auto g = std::make_shared(); init_codon_rules(*g); init_codon_actions(*g); ~(*g)["NLP"] <= peg::usr([](const char *s, size_t n, peg::SemanticValues &, std::any &dt) { auto e = (n >= 1 && s[0] == '\\' ? 
1 : -1); if (std::any_cast(dt).parens && e == -1) e = 0; return e; }); for (auto &val : *g | std::views::values) { auto v = peg::LinkReferences(*g, val.params); val.accept(v); } (*g)["program"].enablePackratParsing = true; (*g)["fstring"].enablePackratParsing = true; for (auto &rule : std::vector{ "arguments", "slices", "genexp", "parentheses", "star_parens", "generics", "with_parens_item", "params", "from_as_parens", "from_params"}) { (*g)[rule].enter = [](const peg::Context &, const char *, size_t, std::any &dt) { std::any_cast(dt).parens++; }; (*g)[rule.c_str()].leave = [](const peg::Context &, const char *, size_t, size_t, std::any &, std::any &dt) { std::any_cast(dt).parens--; }; } return g; } template llvm::Expected parseCode(Cache *cache, const std::string &file, const std::string &code, int line_offset, int col_offset, const std::string &rule) { Timer t(""); t.logged = true; // Initialize if (!grammar) grammar = initParser(); std::vector errors; auto log = [&](size_t line, size_t col, const std::string &msg, const std::string &) { size_t ed = msg.size(); if (startswith(msg, "syntax error, unexpected")) { auto i = msg.find(", expecting"); if (i != std::string::npos) ed = i; } errors.emplace_back(msg.substr(0, ed), file, line, col); }; T result; auto ctx = std::make_any(cache, 0, line_offset, col_offset); auto r = (*grammar)[rule].parse_and_get_value(code.c_str(), code.size(), ctx, result, file.c_str(), log); auto ret = r.ret && r.len == code.size(); if (!ret) r.error_info.output_log(log, code.c_str(), code.size()); totalPeg += t.elapsed(); if (!errors.empty()) return llvm::make_error(errors); return result; } llvm::Expected parseCode(Cache *cache, const std::string &file, const std::string &code, int line_offset) { return parseCode(cache, file, code + "\n", line_offset, 0, "program"); } llvm::Expected> parseExpr(Cache *cache, const std::string &code, const codon::SrcInfo &offset) { auto newCode = code; ltrim(newCode); rtrim(newCode); return parseCode>( 
cache, offset.file, newCode, offset.line, offset.col, "fstring"); } llvm::Expected parseFile(Cache *cache, const std::string &file) { auto lines = cache->fs->read_lines(file); cache->imports[file].content = lines; std::string code = join(lines, "\n"); auto result = parseCode(cache, file, code); // /* For debugging purposes: */ LOG("peg/{} := {}", file, result); return result; } std::shared_ptr initOpenMPParser() { auto g = std::make_shared(); init_omp_rules(*g); init_omp_actions(*g); for (auto &val : *g | std::views::values) { auto v = peg::LinkReferences(*g, val.params); val.accept(v); } (*g)["pragma"].enablePackratParsing = true; return g; } llvm::Expected> parseOpenMP(Cache *cache, const std::string &code, const codon::SrcInfo &loc) { if (!ompGrammar) ompGrammar = initOpenMPParser(); std::vector errors; auto log = [&](size_t line, size_t col, const std::string &msg, const std::string &) { errors.emplace_back(fmt::format("openmp: {}", msg), loc.file, loc.line, loc.col); }; std::vector result; auto ctx = std::make_any(cache, 0, 0, 0); auto r = (*ompGrammar)["pragma"].parse_and_get_value(code.c_str(), code.size(), ctx, result, "", log); auto ret = r.ret && r.len == code.size(); if (!ret) r.error_info.output_log(log, code.c_str(), code.size()); if (!errors.empty()) return llvm::make_error(errors); return result; } } // namespace codon::ast ================================================ FILE: codon/parser/peg/peg.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include "codon/parser/ast.h" #include "codon/util/common.h" namespace codon::ast { /// Parse a Seq code block with the appropriate file and position offsets. llvm::Expected parseCode(Cache *cache, const std::string &file, const std::string &code, int line_offset = 0); /// Parse a Seq code expression. /// @return pair of Expr * and a format specification /// (empty if not available). 
llvm::Expected> parseExpr(Cache *cache, const std::string &code, const codon::SrcInfo &offset); /// Parse a Seq file. llvm::Expected parseFile(Cache *cache, const std::string &file); /// Parse a OpenMP clause. llvm::Expected> parseOpenMP(Cache *cache, const std::string &code, const codon::SrcInfo &loc); } // namespace codon::ast ================================================ FILE: codon/parser/peg/rules.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include "codon/parser/cache.h" #include "codon/parser/common.h" namespace codon::ast { struct ParseContext { Cache *cache; std::stack indent; int parens; int line_offset, col_offset; ParseContext(Cache *cache, int parens = 0, int line_offset = 0, int col_offset = 0) : cache(cache), parens(parens), line_offset(line_offset), col_offset(col_offset) { } bool hasCustomStmtKeyword(const std::string &kwd, bool hasExpr) const { auto i = cache->customBlockStmts.find(kwd); if (i != cache->customBlockStmts.end()) return i->second.first == hasExpr; return false; } bool hasCustomExprStmt(const std::string &kwd) const { return in(cache->customExprStmts, kwd); } }; } // namespace codon::ast void init_codon_rules(peg::Grammar &); void init_codon_actions(peg::Grammar &); void init_omp_rules(peg::Grammar &); void init_omp_actions(peg::Grammar &); ================================================ FILE: codon/parser/visitors/doc/doc.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "doc.h"

#include
#include
#include
#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/match.h"
#include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/format/format.h"
#include "codon/parser/visitors/scoping/scoping.h"

namespace codon::ast {

using namespace error;
using namespace matcher;

// clang-format off
/// Escape a string so it can be placed between double quotes in JSON output.
/// Backslash, double quote and the control characters with short forms
/// (\b, \f, \n, \r, \t) use their two-character escapes. Any other byte below
/// 0x20 is written as \u00XX, because JSON does not allow raw control
/// characters inside a string. All remaining bytes are copied unchanged.
/// @param str text to escape
/// @return the escaped text (without surrounding quotes)
std::string json_escape(const std::string &str) {
  std::string r;
  r.reserve(str.size());
  for (unsigned char c : str) {
    switch (c) {
    case '\b': r += "\\b"; break;
    case '\f': r += "\\f"; break;
    case '\n': r += "\\n"; break;
    case '\r': r += "\\r"; break;
    case '\t': r += "\\t"; break;
    case '\\': r += "\\\\"; break;
    case '"': r += "\\\""; break;
    default:
      if (c < 0x20) {
        // Remaining control characters have no short escape.
        static const char hex[] = "0123456789abcdef";
        r += "\\u00";
        r += hex[c >> 4];
        r += hex[c & 0x0F];
      } else {
        r += static_cast<char>(c);
      }
    }
  }
  return r;
}
// clang-format on

/// Empty, non-list node.
json::json() : list(false) {}

/// Node holding a single key mapped to a null child; toString() renders such a
/// node as a quoted string.
json::json(const std::string &s) : list(false) { values[s] = nullptr; }

/// Node holding one key whose child is built from the string v.
json::json(const std::string &s, const std::string &v) : list(false) {
  values[s] = std::make_shared(v);
}

/// List node; elements are stored under their decimal index as key.
json::json(const std::vector> &vs) : list(true) {
  for (int i = 0; i < vs.size(); i++)
    values[std::to_string(i)] = vs[i];
}

/// List node built from plain strings; elements are stored under their decimal
/// index as key.
json::json(const std::vector &vs) : list(true) {
  for (int i = 0; i < vs.size(); i++)
    values[std::to_string(i)] = std::make_shared(vs[i]);
}

/// Non-list node with one child per map entry.
json::json(const std::unordered_map &vs) : list(false) {
  for (auto &v : vs)
    values[v.first] = std::make_shared(v.second);
}

/// Render this node as text.
/// A node without entries is rendered as "{}" (regardless of the list flag), a
/// single key with a null child as a quoted string, a list as "[ a, b ]" and
/// anything else as "{ "key": value, ... }".
std::string json::toString() {
  std::vector s;
  if (values.empty()) {
    return "{}";
  } else if (values.size() == 1 && !values.begin()->second) {
    return fmt::format("\"{}\"", json_escape(values.begin()->first));
  } else if (list) {
    // Walk the indices in numeric order rather than in container order.
    for (int i = 0; i < values.size(); i++)
      s.push_back(values[std::to_string(i)]->toString());
    return fmt::format("[ {} ]", join(s, ", "));
  } else {
    for (auto &v : values)
      s.push_back(
          fmt::format("\"{}\": {}", json_escape(v.first), v.second->toString()));
    return fmt::format("{{ {} }}", join(s, ", "));
  }
}

std::shared_ptr json::get(const std::string &s) {
  auto i = values.find(s);
  seqassertn(i != values.end(), "cannot 
find {}", s); return i->second; } std::shared_ptr json::set(const std::string &s, const std::string &value) { return values[s] = std::make_shared(value); } std::shared_ptr json::set(const std::string &s, const std::shared_ptr &value) { return values[s] = value; } std::shared_ptr DocVisitor::apply(const std::string &argv0, const std::vector &files) { auto shared = std::make_shared(); shared->argv0 = argv0; auto cache = std::make_unique(argv0); shared->cache = cache.get(); shared->modules[""] = std::make_shared(shared); shared->j = std::make_shared(); auto stdlib = getImportFile(cache.get(), STDLIB_INTERNAL_MODULE, "", true); auto astOrErr = ast::parseFile(shared->cache, stdlib->path); if (!astOrErr) throw exc::ParserException(astOrErr.takeError()); auto coreOrErr = ast::parseCode(shared->cache, stdlib->path, "from internal.core import *"); if (!coreOrErr) throw exc::ParserException(coreOrErr.takeError()); shared->modules[""]->setFilename(stdlib->path); shared->modules[""]->add("__py_numerics__", std::make_shared(shared->itemID++)); shared->modules[""]->add("__py_extension__", std::make_shared(shared->itemID++)); shared->modules[""]->add("__debug__", std::make_shared(shared->itemID++)); shared->modules[""]->add("__apple__", std::make_shared(shared->itemID++)); auto j = std::make_shared(std::unordered_map{ {"name", "type"}, {"kind", "class"}, {"type", "type"}}); j->set("generics", std::make_shared(std::vector{"T"})); shared->modules[""]->add("type", std::make_shared(shared->itemID)); shared->j->set(std::to_string(shared->itemID++), j); j = std::make_shared(std::unordered_map{ {"name", "Literal"}, {"kind", "class"}, {"type", "type"}}); j->set("generics", std::make_shared(std::vector{"T"})); shared->modules[""]->add("Literal", std::make_shared(shared->itemID)); shared->j->set(std::to_string(shared->itemID++), j); DocVisitor(shared->modules[""]).transformModule(*coreOrErr); DocVisitor(shared->modules[""]).transformModule(*astOrErr); auto ctx = std::make_shared(shared); 
for (auto &f : files) { auto path = std::string(cache->fs->canonical(f)); ctx->setFilename(path); // LOG("-> parsing {}", path); auto fAstOrErr = ast::parseFile(shared->cache, path); if (!fAstOrErr) throw exc::ParserException(fAstOrErr.takeError()); DocVisitor(ctx).transformModule(*fAstOrErr); } shared->cache = nullptr; return shared->j; } std::shared_ptr DocContext::find(const std::string &s) const { auto i = Context::find(s); if (!i && this != shared->modules[""].get()) return shared->modules[""]->find(s); return i; } std::string getDocstr(Stmt *s) { if (auto se = cast(s)) if (auto e = cast(se->getExpr())) return e->getValue(); return ""; } std::vector DocVisitor::flatten(Stmt *stmt, std::string *docstr, bool deep) { std::vector stmts; if (auto s = cast(stmt)) { for (int i = 0; i < (deep ? s->size() : 1); i++) { for (auto &x : flatten((*s)[i], i ? nullptr : docstr, deep)) stmts.push_back(std::move(x)); } } else { if (docstr) *docstr = getDocstr(stmt); stmts.push_back(std::move(stmt)); } return stmts; } std::shared_ptr DocVisitor::transform(Expr *expr) { if (!expr) return std::make_shared(); DocVisitor v(ctx); v.setSrcInfo(expr->getSrcInfo()); v.resultExpr = std::make_shared(); expr->accept(v); return v.resultExpr; } std::string DocVisitor::transform(Stmt *stmt) { if (!stmt) return ""; DocVisitor v(ctx); v.setSrcInfo(stmt->getSrcInfo()); stmt->accept(v); return v.resultStmt; } void DocVisitor::transformModule(Stmt *stmt) { if (auto err = ScopingVisitor::apply(ctx->shared->cache, stmt)) throw exc::ParserException(std::move(err)); std::vector children; std::string docstr; auto flat = flatten(std::move(stmt), &docstr); for (int i = 0; i < flat.size(); i++) { auto &s = flat[i]; auto id = transform(s); if (id.empty()) continue; if (i < (flat.size() - 1) && cast(s)) { auto ds = getDocstr(flat[i + 1]); if (!ds.empty()) ctx->shared->j->get(id)->set("doc", ds); } children.push_back(id); } auto id = std::to_string(ctx->shared->itemID++); auto ja = ctx->shared->j->set( id, 
std::make_shared(std::unordered_map{ {"kind", "module"}, {"path", ctx->getFilename()}})); ja->set("children", std::make_shared(children)); if (!docstr.empty()) ja->set("doc", docstr); } void DocVisitor::visit(IntExpr *expr) { auto [value, _] = expr->getRawData(); resultExpr = std::make_shared(value); } void DocVisitor::visit(IdExpr *expr) { auto i = ctx->find(expr->getValue()); if (!i) E(Error::CUSTOM, expr->getSrcInfo(), "unknown identifier {}", expr->getValue()); resultExpr = std::make_shared(*i ? std::to_string(*i) : expr->getValue()); } void DocVisitor::visit(IndexExpr *expr) { auto tr = [&](Expr *e) { if (match(e, MOr(M(), M(), M()))) return std::make_shared(FormatVisitor::apply(e)); else return transform(e); }; std::vector> v; v.push_back(transform(expr->getExpr())); if (auto tp = cast(expr->getIndex())) { if (auto l = cast((*tp)[0])) { for (auto *e : *l) v.push_back(tr(e)); v.push_back(tr((*tp)[1])); } else for (auto *e : *tp) { v.push_back(tr(e)); } } else { v.push_back(tr(expr->getIndex())); } resultExpr = std::make_shared(v); } bool isValidName(const std::string &s) { if (s.empty()) return false; if (s.size() > 4 && s.substr(0, 2) == "__" && s.substr(s.size() - 2) == "__") return true; return s[0] != '_'; } void DocVisitor::visit(FunctionStmt *stmt) { int id = ctx->shared->itemID++; ctx->add(stmt->getName(), std::make_shared(id)); auto j = std::make_shared(std::unordered_map{ {"kind", "function"}, {"name", stmt->getName()}}); j->set("pos", jsonify(stmt->getSrcInfo())); std::vector> args; std::vector generics; for (auto &a : *stmt) if (!a.isValue()) { ctx->add(a.name, std::make_shared(0)); generics.push_back(a.name); a.status = Param::Generic; } for (auto &a : *stmt) { auto jj = std::make_shared(); jj->set("name", a.name); if (a.type) { auto tt = transform(a.type); if (tt->values.empty()) { LOG("{}: warning: cannot resolve argument {}", a.type->getSrcInfo(), FormatVisitor::apply(a.type)); jj->set("type", FormatVisitor::apply(a.type)); } else { 
jj->set("type", tt); } } if (a.defaultValue) { jj->set("default", FormatVisitor::apply(a.defaultValue)); } args.push_back(jj); } j->set("generics", std::make_shared(generics)); bool isLLVM = false; std::vector attrs; for (const auto &d : stmt->getDecorators()) if (auto e = cast(d)) { attrs.push_back(e->getValue()); isLLVM |= (e->getValue() == "llvm"); } if (stmt->hasAttribute(Attr::Property)) attrs.push_back("property"); if (stmt->hasAttribute(Attr::Attribute)) attrs.push_back("__attribute__"); if (stmt->hasAttribute(Attr::Python)) attrs.push_back("python"); if (stmt->hasAttribute(Attr::LLVM)) attrs.push_back("llvm"); if (stmt->hasAttribute(Attr::Internal)) attrs.push_back("__internal__"); if (stmt->hasAttribute(Attr::HiddenFromUser)) attrs.push_back("__hidden__"); if (stmt->hasAttribute(Attr::Atomic)) attrs.push_back("atomic"); if (stmt->hasAttribute(Attr::StaticMethod)) attrs.push_back("staticmethod"); if (stmt->hasAttribute(Attr::C)) attrs.push_back("C"); if (!attrs.empty()) j->set("attrs", std::make_shared(attrs)); if (stmt->getReturn()) j->set("return", transform(stmt->getReturn())); j->set("args", std::make_shared(args)); std::string docstr; flatten(stmt->getSuite(), &docstr); for (auto &g : generics) ctx->remove(g); if (!docstr.empty() && !isLLVM) j->set("doc", docstr); ctx->shared->j->set(std::to_string(id), j); resultStmt = std::to_string(id); } void DocVisitor::visit(ClassStmt *stmt) { std::vector generics; bool isRecord = stmt->isRecord(); auto j = std::make_shared(std::unordered_map{ {"name", stmt->getName()}, {"kind", "class"}, {"type", isRecord ? 
"type" : "class"}}); int id = ctx->shared->itemID++; bool isExtend = false; for (auto &d : stmt->getDecorators()) if (auto e = cast(d)) isExtend |= (e->getValue() == "extend"); if (isExtend) { j->set("type", "extension"); auto i = ctx->find(stmt->getName()); j->set("parent", std::to_string(*i)); generics = ctx->shared->generics[*i]; } else { ctx->add(stmt->getName(), std::make_shared(id)); } std::vector> args; for (const auto &a : *stmt) if (!a.isValue()) { generics.push_back(a.name); } ctx->shared->generics[id] = generics; for (auto &g : generics) ctx->add(g, std::make_shared(0)); for (const auto &a : *stmt) { auto ja = std::make_shared(); ja->set("name", a.name); if (a.type) { auto tt = transform(a.type); if (tt->values.empty()) { LOG("{}: warning: cannot resolve argument {}", a.type->getSrcInfo(), FormatVisitor::apply(a.type)); ja->set("type", FormatVisitor::apply(a.type)); } else { ja->set("type", tt); } } if (a.defaultValue) ja->set("default", FormatVisitor::apply(a.defaultValue)); args.push_back(ja); } j->set("generics", std::make_shared(generics)); j->set("args", std::make_shared(args)); j->set("pos", jsonify(stmt->getSrcInfo())); std::string docstr; std::vector members; for (auto &f : flatten(stmt->getSuite(), &docstr)) { if (auto ff = cast(f)) { auto i = transform(f); if (i != "") members.push_back(i); if (isValidName(ff->getName())) ctx->remove(ff->getName()); } } for (auto &g : generics) ctx->remove(g); j->set("members", std::make_shared(members)); if (!docstr.empty()) j->set("doc", docstr); ctx->shared->j->set(std::to_string(id), j); resultStmt = std::to_string(id); } std::shared_ptr DocVisitor::jsonify(const codon::SrcInfo &s) { return std::make_shared( std::vector{std::to_string(s.line), std::to_string(s.len)}); } void DocVisitor::visit(ImportStmt *stmt) { if (match(stmt->getFrom(), M(MOr("C", "python")))) { int id = ctx->shared->itemID++; std::string name, lib; if (auto i = cast(stmt->getWhat())) name = i->getValue(); else if (auto d = 
cast(stmt->getWhat()))
      // Member-style target: the member is the symbol name and the formatted
      // receiver expression is kept as the library ("dylib").
      name = d->getMember(), lib = FormatVisitor::apply(d->getExpr());
    else
      seqassert(false, "invalid C import statement");
    // The symbol is registered under its original name; the JSON record uses
    // the alias when one is given.
    ctx->add(name, std::make_shared(id));
    name = stmt->getAs().empty() ? name : stmt->getAs();

    auto dict = std::unordered_map{{"name", name}, {"kind", "function"}};
    if (stmt->getFrom() && cast(stmt->getFrom()))
      dict["extern"] = cast(stmt->getFrom())->getValue();
    auto j = std::make_shared(dict);
    j->set("pos", jsonify(stmt->getSrcInfo()));
    std::vector> args;
    if (stmt->getReturnType())
      j->set("return", transform(stmt->getReturnType()));
    for (const auto &a : stmt->getArgs()) {
      auto ja = std::make_shared();
      ja->set("name", a.name);
      ja->set("type", transform(a.type));
      args.push_back(ja);
    }
    j->set("dylib", lib);
    j->set("args", std::make_shared(args));
    ctx->shared->j->set(std::to_string(id), j);
    resultStmt = std::to_string(id);
    return;
  }

  std::vector dirs; // Path components
  // Walk the source expression from the outermost member inwards; components
  // are therefore collected in reverse order.
  Expr *e = stmt->getFrom();
  while (cast(e)) {
    while (auto d = cast(e)) {
      dirs.push_back(d->getMember());
      e = d->getExpr();
    }
    if (!cast(e) || !stmt->getArgs().empty() || stmt->getReturnType() ||
        (stmt->getWhat() && !cast(stmt->getWhat())))
      E(Error::CUSTOM, stmt->getSrcInfo(), "invalid import statement");
  }
  if (auto ee = cast(e)) {
    if (!stmt->getArgs().empty() || stmt->getReturnType() ||
        (stmt->getWhat() && !cast(stmt->getWhat())))
      E(Error::CUSTOM, stmt->getSrcInfo(), "invalid import statement");
    // We have an empty stmt->from in "from .. import".
    if (!ee->getValue().empty())
      dirs.push_back(ee->getValue());
  }
  // Handle dots (e.g. .. in from ..m import x).
  for (size_t i = 1; i < stmt->getDots(); i++)
    dirs.emplace_back("..");
  // Join the components back to front with '/' separators.
  std::string path;
  for (int i = static_cast(dirs.size()) - 1; i >= 0; i--)
    path += dirs[i] + (i ? "/" : "");
  // Fetch the import!
  auto file = getImportFile(ctx->shared->cache, path, ctx->getFilename());
  if (!file)
    E(Error::CUSTOM, stmt->getSrcInfo(), "cannot locate import '{}'", path);

  // Each file gets one context, keyed by path in the shared module table. A
  // file seen for the first time is parsed and documented right away.
  auto ictx = ctx;
  auto it = ctx->shared->modules.find(file->path);
  if (it == ctx->shared->modules.end()) {
    ctx->shared->modules[file->path] = ictx = std::make_shared(ctx->shared);
    ictx->setFilename(file->path);
    // LOG("=> parsing {}", file->path);
    auto tmpOrErr = parseFile(ctx->shared->cache, file->path);
    if (!tmpOrErr)
      throw exc::ParserException(tmpOrErr.takeError());
    DocVisitor(ictx).transformModule(*tmpOrErr);
  } else {
    ictx = it->second;
  }

  if (!stmt->getWhat()) {
    // TODO: implement this corner case
    for (auto &i : dirs)
      if (!ctx->find(i))
        ctx->add(i, std::make_shared(ctx->shared->itemID++));
  } else if (isId(stmt->getWhat(), "*")) {
    // Star import: copy the front entry of every name in the imported context.
    for (auto &i : *ictx)
      ctx->add(i.first, i.second.front());
  } else {
    // Single symbol: it must exist in the imported context; it is added under
    // the alias when one is given.
    auto i = cast(stmt->getWhat());
    if (auto c = ictx->find(i->getValue()))
      ctx->add(stmt->getAs().empty() ? i->getValue() : stmt->getAs(), c);
    else
      E(Error::CUSTOM, stmt->getSrcInfo(), "symbol '{}' not found in {}",
        i->getValue(), file->path);
  }
}

/// Assignment: only targets that cast() accepts are documented. The record
/// holds the name, an optional transformed type, the formatted right-hand side
/// (when present) and the source position.
void DocVisitor::visit(AssignStmt *stmt) {
  auto e = cast(stmt->getLhs());
  if (!e)
    return;
  int id = ctx->shared->itemID++;
  ctx->add(e->getValue(), std::make_shared(id));
  auto j = std::make_shared(std::unordered_map{
      {"name", e->getValue()}, {"kind", "variable"}});
  if (stmt->getTypeExpr())
    j->set("type", transform(stmt->getTypeExpr()));
  if (stmt->getRhs())
    j->set("value", FormatVisitor::apply(stmt->getRhs()));
  j->set("pos", jsonify(stmt->getSrcInfo()));
  ctx->shared->j->set(std::to_string(id), j);
  resultStmt = std::to_string(id);
}

} // namespace codon::ast
================================================
FILE: codon/parser/visitors/doc/doc.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once #include #include #include #include #include #include "codon/parser/ast.h" #include "codon/parser/cache.h" #include "codon/parser/ctx.h" #include "codon/parser/visitors/visitor.h" namespace codon::ast { struct json { // values={str -> null} -> string value // values={i -> json} -> list (if list=true) // values={...} -> dictionary std::unordered_map> values; bool list; json(); json(const std::string &s); json(const std::string &s, const std::string &v); json(const std::vector> &vs); json(const std::vector &vs); json(const std::unordered_map &vs); std::string toString(); std::shared_ptr get(const std::string &s); std::shared_ptr set(const std::string &s, const std::string &value); std::shared_ptr set(const std::string &s, const std::shared_ptr &value); }; struct DocContext; struct DocShared { int itemID = 1; std::shared_ptr j; std::unordered_map> modules; std::string argv0; Cache *cache = nullptr; std::unordered_map> generics; DocShared() {} }; struct DocContext : public Context { std::shared_ptr shared; explicit DocContext(std::shared_ptr shared) : Context(""), shared(std::move(shared)) {} std::shared_ptr find(const std::string &s) const override; }; struct DocVisitor : public CallbackASTVisitor, std::string> { std::shared_ptr ctx; std::shared_ptr resultExpr; std::string resultStmt; public: explicit DocVisitor(std::shared_ptr ctx) : ctx(std::move(ctx)) {} static std::shared_ptr apply(const std::string &argv0, const std::vector &files); std::shared_ptr transform(Expr *e) override; std::string transform(Stmt *e) override; void transformModule(Stmt *stmt); static std::shared_ptr jsonify(const codon::SrcInfo &s); static std::vector flatten(Stmt *stmt, std::string *docstr = nullptr, bool deep = true); public: void visit(IntExpr *) override; void visit(IdExpr *) override; void visit(IndexExpr *) override; void visit(FunctionStmt *) override; void visit(ClassStmt *) override; void visit(AssignStmt *) override; void visit(ImportStmt *) override; }; } // 
namespace codon::ast ================================================ FILE: codon/parser/visitors/format/format.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include #include #include "codon/parser/visitors/format/format.h" namespace codon { namespace ast { std::string FormatVisitor::anchor_root(const std::string &s) { return fmt::format("{}", s, s); } std::string FormatVisitor::anchor(const std::string &s) { return fmt::format("{}", s, s); } FormatVisitor::FormatVisitor(bool html, Cache *cache) : renderType(false), renderHTML(html), indent(0), cache(cache) { if (renderHTML) { header = "\n"; header += "
\n"; footer = "\n
"; nl = "
"; typeStart = ""; typeEnd = ""; nodeStart = ""; nodeEnd = ""; stmtStart = ""; stmtEnd = ""; exprStart = ""; exprEnd = ""; commentStart = ""; commentEnd = ""; literalStart = ""; literalEnd = ""; keywordStart = ""; keywordEnd = ""; space = " "; renderType = true; } else { space = " "; } } std::string FormatVisitor::transform(Expr *expr) { FormatVisitor v(renderHTML, cache); if (expr) expr->accept(v); return v.result; } std::string FormatVisitor::transform(Stmt *stmt) { return transform(stmt, 0); } std::string FormatVisitor::transform(Stmt *stmt, int indent) { FormatVisitor v(renderHTML, cache); v.indent = this->indent + indent; if (stmt) stmt->accept(v); if (v.result.empty()) return ""; return fmt::format("{}{}{}{}{}", stmtStart, cast(stmt) ? "" : pad(indent), v.result, stmtEnd, newline()); } std::string FormatVisitor::pad(int indent) const { std::string s; for (int i = 0; i < (this->indent + indent) * 2; i++) s += space; return s; } std::string FormatVisitor::newline() const { return nl + "\n"; } std::string FormatVisitor::keyword(const std::string &s) const { return fmt::format("{}{}{}", keywordStart, s, keywordEnd); } std::string FormatVisitor::literal(const std::string &s) const { return fmt::format("{}{}{}", literalStart, s, literalEnd); } /*************************************************************************************/ void FormatVisitor::visit(NoneExpr *expr) { result = renderExpr(expr, "None"); } void FormatVisitor::visit(BoolExpr *expr) { result = renderExpr(expr, "{}", literal(expr->getValue() ? 
"True" : "False"));
}

/// Integer literal: raw text wrapped as a literal, followed by its suffix.
void FormatVisitor::visit(IntExpr *expr) {
  auto [value, suffix] = expr->getRawData();
  result = renderExpr(expr, "{}{}", literal(value), suffix);
}

/// Float literal: raw text wrapped as a literal, followed by its suffix.
void FormatVisitor::visit(FloatExpr *expr) {
  auto [value, suffix] = expr->getRawData();
  result = renderExpr(expr, "{}{}", literal(value), suffix);
}

/// String literal: the escaped value in double quotes, wrapped as a literal.
void FormatVisitor::visit(StringExpr *expr) {
  result =
      renderExpr(expr, "{}", literal(fmt::format("\"{}\"", escape(expr->getValue()))));
}

/// Identifier: passed through anchor() when its type reports getFunc(),
/// otherwise printed as-is.
void FormatVisitor::visit(IdExpr *expr) {
  result = renderExpr(expr, "{}",
                      expr->getType() && expr->getType()->getFunc()
                          ? anchor(expr->getValue())
                          : expr->getValue());
}

/// `*expr`
void FormatVisitor::visit(StarExpr *expr) {
  result = renderExpr(expr, "*{}", transform(expr->getExpr()));
}

/// `**expr`
void FormatVisitor::visit(KeywordStarExpr *expr) {
  result = renderExpr(expr, "**{}", transform(expr->getExpr()));
}

/// `(a, b, ...)`
void FormatVisitor::visit(TupleExpr *expr) {
  result = renderExpr(expr, "({})", transformItems(*expr));
}

/// `[a, b, ...]`
void FormatVisitor::visit(ListExpr *expr) {
  result = renderExpr(expr, "[{}]", transformItems(*expr));
}

/// `expr⟦a, b, ...⟧`
void FormatVisitor::visit(InstantiateExpr *expr) {
  result =
      renderExpr(expr, "{}⟦{}⟧", transform(expr->getExpr()), transformItems(*expr));
}

/// `{a, b, ...}` (the doubled braces in the format string are escapes).
void FormatVisitor::visit(SetExpr *expr) {
  result = renderExpr(expr, "{{{}}}", transformItems(*expr));
}

/// `{k: v, ...}`; each item is cast and its elements 0 and 1 are printed as
/// key and value.
void FormatVisitor::visit(DictExpr *expr) {
  std::vector s;
  for (auto *i : *expr) {
    // NOTE(review): the cast result is dereferenced without a null check.
    auto t = cast(i);
    s.push_back(fmt::format("{}: {}", transform((*t)[0]), transform((*t)[1])));
  }
  result = renderExpr(expr, "{{{}}}", join(s, ", "));
}

/// Generators are not formatted; a fixed placeholder string is emitted.
void FormatVisitor::visit(GeneratorExpr *expr) {
  // seqassert(false, "not implemented");
  result = "GENERATOR_IMPL";
}

/// `(a if cond else b)`
void FormatVisitor::visit(IfExpr *expr) {
  result = renderExpr(expr, "({} {} {} {} {})", transform(expr->getIf()),
                      keyword("if"), transform(expr->getCond()), keyword("else"),
                      transform(expr->getElse()));
}

/// Unary operator directly followed by its operand.
void FormatVisitor::visit(UnaryExpr *expr) {
  result = renderExpr(expr, "{}{}", expr->getOp(), transform(expr->getExpr()));
}

void
FormatVisitor::visit(BinaryExpr *expr) { result = renderExpr(expr, "({} {} {})", transform(expr->getLhs()), expr->getOp(), transform(expr->getRhs())); } void FormatVisitor::visit(PipeExpr *expr) { std::vector items; for (const auto &l : *expr) { if (!items.size()) items.push_back(transform(l.expr)); else items.push_back(l.op + " " + transform(l.expr)); } result = renderExpr(expr, "({})", join(items, " ")); } void FormatVisitor::visit(IndexExpr *expr) { result = renderExpr(expr, "{}[{}]", transform(expr->getExpr()), transform(expr->getIndex())); } void FormatVisitor::visit(CallExpr *expr) { std::vector args; for (auto &i : *expr) { if (i.name == "") args.push_back(transform(i.value)); else args.push_back(fmt::format("{}={}", i.name, transform(i.value))); } result = renderExpr(expr, "{}({})", transform(expr->getExpr()), join(args, ", ")); } void FormatVisitor::visit(DotExpr *expr) { result = renderExpr(expr, "{}.{}", transform(expr->getExpr()), expr->getMember()); } void FormatVisitor::visit(SliceExpr *expr) { std::string s; if (expr->getStart()) s += transform(expr->getStart()); s += ":"; if (expr->getStop()) s += transform(expr->getStop()); s += ":"; if (expr->getStep()) s += transform(expr->getStep()); result = renderExpr(expr, "{}", s); } void FormatVisitor::visit(EllipsisExpr *expr) { result = renderExpr(expr, "..."); } void FormatVisitor::visit(LambdaExpr *expr) { std::vector s; for (const auto &v : *expr) s.emplace_back(v.getName()); result = renderExpr(expr, "{} {}: {}", keyword("lambda"), join(s, ", "), transform(expr->getExpr())); } void FormatVisitor::visit(YieldExpr *expr) { result = renderExpr(expr, "{}", "(" + keyword("yield") + ")"); } void FormatVisitor::visit(AwaitExpr *expr) { result = fmt::format("{} {}", keyword("await"), transform(expr->getExpr())); } void FormatVisitor::visit(StmtExpr *expr) { std::string s; for (auto *i : *expr) s += fmt::format("{}{}", pad(2), transform(i, 2)); result = renderExpr(expr, "《{}{}{}{}{}》", newline(), s, newline(), 
pad(2), transform(expr->getExpr())); } void FormatVisitor::visit(AssignExpr *expr) { result = renderExpr(expr, "({} := {})", transform(expr->getVar()), transform(expr->getExpr())); } void FormatVisitor::visit(SuiteStmt *stmt) { for (auto *s : *stmt) result += transform(s); } void FormatVisitor::visit(BreakStmt *stmt) { result = keyword("break"); } void FormatVisitor::visit(ContinueStmt *stmt) { result = keyword("continue"); } void FormatVisitor::visit(ExprStmt *stmt) { result = transform(stmt->getExpr()); } void FormatVisitor::visit(AssignStmt *stmt) { if (stmt->getTypeExpr()) { result = fmt::format("{}: {} = {}", transform(stmt->getLhs()), transform(stmt->getTypeExpr()), transform(stmt->getRhs())); } else { result = fmt::format("{} = {}", transform(stmt->getLhs()), transform(stmt->getRhs())); } } void FormatVisitor::visit(AssignMemberStmt *stmt) { result = fmt::format("{}.{} = {}", transform(stmt->getLhs()), stmt->getMember(), transform(stmt->getRhs())); } void FormatVisitor::visit(DelStmt *stmt) { result = fmt::format("{} {}", keyword("del"), transform(stmt->getExpr())); } void FormatVisitor::visit(PrintStmt *stmt) { result = fmt::format("{} {}", keyword("print"), transformItems(*stmt)); } void FormatVisitor::visit(ReturnStmt *stmt) { result = fmt::format("{}{}", keyword("return"), stmt->getExpr() ? " " + transform(stmt->getExpr()) : ""); } void FormatVisitor::visit(YieldStmt *stmt) { result = fmt::format("{}{}", keyword("yield"), stmt->getExpr() ? 
" " + transform(stmt->getExpr()) : ""); } void FormatVisitor::visit(AssertStmt *stmt) { result = fmt::format("{} {}", keyword("assert"), transform(stmt->getExpr())); } void FormatVisitor::visit(WhileStmt *stmt) { result = fmt::format("{} {}:{}{}", keyword("while"), transform(stmt->getCond()), newline(), transform(stmt->getSuite(), 1)); } void FormatVisitor::visit(ForStmt *stmt) { result = fmt::format("{} {} {} {}:{}{}", keyword("for"), transform(stmt->getVar()), keyword("in"), transform(stmt->getIter()), newline(), transform(stmt->getSuite(), 1)); } void FormatVisitor::visit(IfStmt *stmt) { result = fmt::format("{} {}:{}{}{}", keyword("if"), transform(stmt->getCond()), newline(), transform(stmt->getIf(), 1), stmt->getElse() ? fmt::format("{}:{}{}", keyword("else"), newline(), transform(stmt->getElse(), 1)) : ""); } void FormatVisitor::visit(MatchStmt *stmt) { std::string s; for (const auto &c : *stmt) s += fmt::format( "{}{}{}{}:{}{}", pad(1), keyword("case"), transform(c.getPattern()), c.getGuard() ? " " + (keyword("case") + " " + transform(c.getGuard())) : "", newline(), transform(c.getSuite(), 2)); result = fmt::format("{} {}:{}{}", keyword("match"), transform(stmt->getExpr()), newline(), s); } void FormatVisitor::visit(ImportStmt *stmt) { auto as = stmt->getAs().empty() ? "" : fmt::format(" {} {} ", keyword("as"), stmt->getAs()); if (!stmt->getWhat()) result += fmt::format("{} {}{}", keyword("import"), transform(stmt->getFrom()), as); else result += fmt::format("{} {} {} {}{}", keyword("from"), transform(stmt->getFrom()), keyword("import"), transform(stmt->getWhat()), as); } void FormatVisitor::visit(TryStmt *stmt) { std::vector catches; for (auto *c : *stmt) { catches.push_back(fmt::format( "{} {}{}:{}{}", keyword("except"), transform(c->getException()), c->getVar() == "" ? 
"" : fmt::format("{} {}", keyword("as"), c->getVar()), newline(), transform(c->getSuite(), 1))); } result = fmt::format("{}:{}{}{}{}", keyword("try"), newline(), transform(stmt->getSuite(), 1), join(catches, ""), stmt->getFinally() ? fmt::format("{}:{}{}", keyword("finally"), newline(), transform(stmt->getFinally(), 1)) : ""); } void FormatVisitor::visit(GlobalStmt *stmt) { result = fmt::format("{} {}", keyword("global"), stmt->getVar()); } void FormatVisitor::visit(ThrowStmt *stmt) { result = fmt::format("{} {}{}", keyword("raise"), transform(stmt->getExpr()), stmt->getFrom() ? fmt::format(" {} {}", keyword("from"), transform(stmt->getFrom())) : ""); } void FormatVisitor::visit(FunctionStmt *fstmt) { if (cache) { if (in(cache->functions, fstmt->getName())) { if (!cache->functions[fstmt->getName()].realizations.empty()) { result += fmt::format("
# {}", fmt::format("{} {}", keyword("def"), fstmt->getName())); for (auto &val : cache->functions[fstmt->getName()].realizations | std::views::values) { auto fa = val->ast; auto ft = val->type; std::vector attrs; for (const auto &a : fa->getDecorators()) attrs.push_back(fmt::format("@{}", transform(a))); if (auto a = fa->getAttribute(Attr::Module)) if (!a->value.empty()) attrs.push_back(fmt::format("@module:{}", a->value)); if (auto a = fa->getAttribute(Attr::ParentClass)) if (!a->value.empty()) attrs.push_back(fmt::format("@parent:{}", a->value)); std::vector args; for (size_t i = 0, j = 0; i < fa->size(); i++) { auto &a = (*fa)[i]; if (a.isValue()) { args.push_back(fmt::format( "{}: {}{}", a.getName(), anchor((*ft)[j++]->realizedName()), a.getDefault() ? fmt::format("={}", transform(a.getDefault())) : "")); } } auto body = transform(fa->getSuite(), 1); auto name = fmt::format("{}", anchor_root(fa->getName())); result += fmt::format( "{}{}{}{} {}({}){}:{}{}", newline(), pad(), attrs.size() ? join(attrs, newline() + pad()) + newline() + pad() : "", keyword("def"), anchor_root(name), join(args, ", "), fmt::format(" -> {}", anchor(ft->getRetType()->realizedName())), newline(), body.empty() ? fmt::format("{}", keyword("pass")) : body); } result += "
"; } return; } } } void FormatVisitor::visit(ClassStmt *stmt) { if (cache) { if (auto cls = in(cache->classes, stmt->getName())) { if (!cls->realizations.empty()) { result = fmt::format( "
# {}", fmt::format("{} {} {}", keyword("class"), stmt->getName(), stmt->hasAttribute(Attr::Extend) ? " +@extend" : "")); for (auto &real : cls->realizations) { std::vector args; auto l = real.second->type->is(TYPE_TUPLE) ? real.second->type->generics.size() : real.second->fields.size(); for (size_t i = 0; i < l; i++) { const auto &[n, t] = real.second->fields[i]; auto name = fmt::format("{}{}: {}{}", exprStart, n, anchor(t->realizedName()), exprEnd); args.push_back(name); } result += fmt::format("{}{}{}{} {}", newline(), pad(), (stmt->hasAttribute(Attr::Tuple) ? fmt::format("@tuple{}{}", newline(), pad()) : ""), keyword("class"), anchor_root(real.first)); if (!args.empty()) result += fmt::format(":{}{}{}", newline(), pad(indent + 1), join(args, newline() + pad(indent + 1))); } result += "
"; } } } // if (stmt->suite) // result += transform(stmt->suite); } void FormatVisitor::visit(YieldFromStmt *stmt) { result = fmt::format("{} {}", keyword("yield from"), transform(stmt->getExpr())); } void FormatVisitor::visit(WithStmt *stmt) {} void FormatVisitor::visit(CommentStmt *stmt) { result = fmt::format("{}# {}{}", commentStart, stmt->getComment(), commentEnd); } void FormatVisitor::visit(DirectiveStmt *stmt) {} } // namespace ast } // namespace codon ================================================ FILE: codon/parser/visitors/format/format.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include "codon/parser/ast.h" #include "codon/parser/cache.h" #include "codon/parser/common.h" #include "codon/parser/visitors/visitor.h" namespace codon { namespace ast { class FormatVisitor : public CallbackASTVisitor { std::string result; std::string space; bool renderType, renderHTML; int indent; std::string header, footer, nl; std::string typeStart, typeEnd; std::string nodeStart, nodeEnd; std::string stmtStart, stmtEnd; std::string exprStart, exprEnd; std::string commentStart, commentEnd; std::string keywordStart, keywordEnd; std::string literalStart, literalEnd; Cache *cache; private: template std::string renderExpr(T &&t, const char *fmt, Ts &&...args) { std::string s = t->getType() ? 
fmt::format("{}{}{}", typeStart, anchor(t->getType()->realizedName()), typeEnd) : ""; return fmt::format("{}{}{}{}{}{}", exprStart, nodeStart, fmt::format(fmt::runtime(fmt), args...), nodeEnd, s, exprEnd); } template std::string renderComment(const char *fmt, Ts &&...args) { return fmt::format("{}{}{}", commentStart, fmt::format(fmt::runtime(fmt), args...), commentEnd); } std::string pad(int indent = 0) const; std::string newline() const; std::string keyword(const std::string &s) const; std::string literal(const std::string &s) const; static std::string anchor_root(const std::string &s); static std::string anchor(const std::string &s); public: FormatVisitor(bool html, Cache *cache = nullptr); std::string transform(Expr *e) override; std::string transform(Stmt *stmt) override; std::string transform(Stmt *stmt, int indent); template static std::string apply(const T &stmt, Cache *cache = nullptr, bool html = false, bool init = false) { auto t = FormatVisitor(html, cache); return fmt::format("{}{}{}", t.header, t.transform(stmt), t.footer); } void defaultVisit(Expr *e) override { seqassertn(false, "cannot format {}", *e); } void defaultVisit(Stmt *e) override { seqassertn(false, "cannot format {}", *e); } public: void visit(NoneExpr *) override; void visit(BoolExpr *) override; void visit(IntExpr *) override; void visit(FloatExpr *) override; void visit(StringExpr *) override; void visit(IdExpr *) override; void visit(StarExpr *) override; void visit(KeywordStarExpr *) override; void visit(TupleExpr *) override; void visit(ListExpr *) override; void visit(SetExpr *) override; void visit(DictExpr *) override; void visit(GeneratorExpr *) override; void visit(InstantiateExpr *expr) override; void visit(IfExpr *) override; void visit(UnaryExpr *) override; void visit(BinaryExpr *) override; void visit(PipeExpr *) override; void visit(IndexExpr *) override; void visit(CallExpr *) override; void visit(DotExpr *) override; void visit(SliceExpr *) override; void 
visit(EllipsisExpr *) override; void visit(LambdaExpr *) override; void visit(YieldExpr *) override; void visit(AwaitExpr *) override; void visit(StmtExpr *expr) override; void visit(AssignExpr *expr) override; void visit(SuiteStmt *) override; void visit(BreakStmt *) override; void visit(ContinueStmt *) override; void visit(ExprStmt *) override; void visit(AssignStmt *) override; void visit(AssignMemberStmt *) override; void visit(DelStmt *) override; void visit(PrintStmt *) override; void visit(ReturnStmt *) override; void visit(YieldStmt *) override; void visit(AssertStmt *) override; void visit(WhileStmt *) override; void visit(ForStmt *) override; void visit(IfStmt *) override; void visit(MatchStmt *) override; void visit(ImportStmt *) override; void visit(TryStmt *) override; void visit(GlobalStmt *) override; void visit(ThrowStmt *) override; void visit(FunctionStmt *) override; void visit(ClassStmt *) override; void visit(YieldFromStmt *) override; void visit(WithStmt *) override; void visit(CommentStmt *) override; void visit(DirectiveStmt *) override; public: friend std::ostream &operator<<(std::ostream &out, const FormatVisitor &c) { return out << c.result; } using CallbackASTVisitor::transform; template std::string transformItems(const T &ts) { std::vector r; for (auto &e : ts) r.push_back(transform(e)); return fmt::format("{}", join(r, ", ")); } }; } // namespace ast } // namespace codon template <> struct fmt::formatter : fmt::ostream_formatter {}; ================================================ FILE: codon/parser/visitors/scoping/scoping.cpp ================================================ // Copyright (C) 2022-2023 Exaloop Inc. #include #include #include #include "codon/parser/ast.h" #include "codon/parser/common.h" #include "codon/parser/match.h" #include "codon/parser/peg/peg.h" #include "codon/parser/visitors/scoping/scoping.h" #define CHECK(x) \ { \ if (!(x)) \ return; \ } #define STOP_ERROR(...) 
\ do { \ addError(__VA_ARGS__); \ return; \ } while (0)

using namespace codon::error;
using namespace codon::matcher;

namespace codon::ast {

/// Run the scoping pass over `s` with a fresh context. Returns an error built
/// from the collected errors when the traversal stops early or any error was
/// recorded, and success otherwise. When `globalShadows` is given, it receives
/// for every name the number of its non-ignored entries in the context map,
/// provided that number is greater than one.
llvm::Error
ScopingVisitor::apply(Cache *cache, Stmt *s,
                      std::unordered_map *globalShadows) {
  auto c = std::make_shared();
  c->cache = cache;
  c->functionScope = nullptr;
  ScopingVisitor v;
  v.ctx = c;

  ConditionalBlock cb(c.get(), s, 0);
  if (!v.transform(s))
    return llvm::make_error(v.errors);
  if (v.hasErrors())
    return llvm::make_error(v.errors);
  v.processChildCaptures();

  /// Count number of shadowed names to know which names change or not later on
  if (globalShadows) {
    for (auto &[u, v] : c->map) {
      size_t i = 0;
      for (auto &ii : v)
        if (!ii.ignore)
          i++;
      if (i > 1)
        (*globalShadows)[u] = i;
    }
  }

  // LOG("-> {}", s->toString(2));
  return llvm::Error::success();
}

/// Visit an expression with a copy of this visitor and merge the copy's errors
/// back. Returns false when canContinue() is false before or after the visit;
/// a null expression is accepted and yields true.
bool ScopingVisitor::transform(Expr *expr) {
  if (!canContinue())
    return false;

  ScopingVisitor v(*this);
  if (expr) {
    v.setSrcInfo(expr->getSrcInfo());
    expr->accept(v);
    if (v.hasErrors())
      errors.append(v.errors);
    if (!canContinue())
      return false;
  }
  return true;
}

/// Statement counterpart of transform(Expr *). In addition, each visited
/// statement is tagged with Attr::ExprTime set to the incremented ctx->time.
bool ScopingVisitor::transform(Stmt *stmt) {
  if (!canContinue())
    return false;

  ScopingVisitor v(*this);
  if (stmt) {
    v.setSrcInfo(stmt->getSrcInfo());
    stmt->setAttribute(Attr::ExprTime, ++ctx->time);
    stmt->accept(v);
    if (v.hasErrors())
      errors.append(v.errors);
    if (!canContinue())
      return false;
  }
  return true;
}

/// Transform a non-null expression inside its own ConditionalBlock; a null
/// expression yields true.
bool ScopingVisitor::transformScope(Expr *e) {
  if (e) {
    ConditionalBlock c(ctx.get(), nullptr);
    return transform(e);
  }
  return true;
}

/// Transform a non-null statement inside its own ConditionalBlock; a null
/// statement yields true.
bool ScopingVisitor::transformScope(Stmt *s) {
  if (s) {
    ConditionalBlock c(ctx.get(), s);
    return transform(s);
  }
  return true;
}

/// Transform an assignment target `e`; `root` is the node installed as
/// ctx->root while names are being added. Unsupported targets produce an
/// ASSIGN_INVALID error and a false result.
bool ScopingVisitor::transformAdding(Expr *e, ASTNode *root) {
  if (cast(e)) {
    return transform(e);
  } else if (auto de = cast(e)) {
    if (!transform(e))
      return false;
    // When the receiver matches the name in ctx->classDeduce.first, remember
    // the member name in ctx->classDeduce.second.
    if (!ctx->classDeduce.first.empty() &&
        match(de->getExpr(), M(ctx->classDeduce.first)))
      ctx->classDeduce.second.insert(de->getMember());
    return true;
  } else if (cast(e) || cast(e) || cast(e)) {
SetInScope s1(&(ctx->adding), true);
    SetInScope s2(&(ctx->root), root);
    return transform(e);
  } else {
    seqassert(e, "bad call to transformAdding");
    addError(Error::ASSIGN_INVALID, e);
    return false;
  }
}

/// Identifier: apply pending renames, then record the name through
/// visitName(); when that call returns true the expression is tagged with
/// Attr::ExprDominatedUndefCheck.
void ScopingVisitor::visit(IdExpr *expr) {
  if (ctx->adding)
    ctx->root = expr;
  // A name added while ctx->tempScope is set gets a fresh temporary name in
  // the innermost rename frame.
  if (ctx->adding && ctx->tempScope)
    ctx->renames.back()[expr->getValue()] =
        ctx->cache->getTemporaryVar(expr->getValue());
  // Rename frames are searched from the innermost one outwards; the first hit
  // wins.
  for (size_t i = ctx->renames.size(); i-- > 0;)
    if (auto v = in(ctx->renames[i], expr->getValue())) {
      expr->setValue(*v);
      break;
    }
  if (visitName(expr->getValue(), ctx->adding, ctx->root, expr->getSrcInfo()))
    expr->setAttribute(Attr::ExprDominatedUndefCheck);
}

/// Member access: children are visited with ctx->adding switched off.
void ScopingVisitor::visit(DotExpr *expr) {
  SetInScope s(&(ctx->adding), false); // to handle a.x, y = b
  CallbackASTVisitor::visit(expr);
}

/// Index access: children are visited with ctx->adding switched off.
void ScopingVisitor::visit(IndexExpr *expr) {
  SetInScope s(&(ctx->adding), false); // to handle a[x], y = b
  CallbackASTVisitor::visit(expr);
}

/// String literal: expand parts prefixed "f"/"F" through unpackFString() and
/// merge neighbouring parts that have no prefix into one part. Parts with any
/// other prefix are kept as separate entries.
void ScopingVisitor::visit(StringExpr *expr) {
  std::vector exprs;
  for (auto &p : *expr) {
    if (p.prefix == "f" || p.prefix == "F") {
      /// Transform an F-string
      auto fstr = unpackFString(p.value);
      // Give up (leaving the expression untouched) when unpacking made the
      // visitor unable to continue.
      if (!canContinue())
        return;
      for (auto pf : fstr) {
        if (pf.prefix.empty() && !exprs.empty() && exprs.back().prefix.empty()) {
          exprs.back().value += pf.value;
        } else {
          exprs.emplace_back(pf);
        }
      }
    } else if (!p.prefix.empty()) {
      exprs.emplace_back(p);
    } else if (!exprs.empty() && exprs.back().prefix.empty()) {
      exprs.back().value += p.value;
    } else {
      exprs.emplace_back(p);
    }
  }
  expr->strings = exprs;
}

/// Split a Python-like f-string into a list:
///   `f"foo {x+1} bar"` -> `("foo ", str(x+1), " bar")
/// Supports "{x=}" specifier (that prints the raw expression as well):
///   `f"{x+1=}"` -> `("x+1=", str(x+1))`
std::vector
ScopingVisitor::unpackFString(const std::string &value) {
  // Strings to be concatenated
  std::vector items;
  // braceCount is the current brace depth; braceStart is the index where the
  // pending piece (literal text or expression code) begins.
  int braceCount = 0, braceStart = 0;
  for (int i = 0; i < value.size(); i++) {
    if (value[i] == '{') {
      if (braceStart < i)
items.emplace_back(value.substr(braceStart, i - braceStart)); if (!braceCount) braceStart = i + 1; braceCount++; } else if (value[i] == '}') { braceCount--; if (!braceCount) { std::string code = value.substr(braceStart, i - braceStart); auto offset = getSrcInfo(); offset.col += i; items.emplace_back(code, "#f"); items.back().setSrcInfo(offset); auto val = parseExpr(ctx->cache, code, offset); if (!val) { addError(val.takeError()); } else { items.back().expr = val->first; if (!transform(items.back().expr)) return items; items.back().format = val->second; } } braceStart = i + 1; } } if (braceCount > 0) addError(Error::STR_FSTRING_BALANCE_EXTRA, getSrcInfo()); else if (braceCount < 0) addError(Error::STR_FSTRING_BALANCE_MISSING, getSrcInfo()); if (braceStart != value.size()) items.emplace_back(value.substr(braceStart, value.size() - braceStart)); return items; } void ScopingVisitor::visit(GeneratorExpr *expr) { SetInScope s(&(ctx->tempScope), true); ctx->renames.emplace_back(); CHECK(transform(expr->getFinalSuite())); ctx->renames.pop_back(); } void ScopingVisitor::visit(IfExpr *expr) { CHECK(transform(expr->getCond())); CHECK(transformScope(expr->getIf())); CHECK(transformScope(expr->getElse())); } void ScopingVisitor::visit(BinaryExpr *expr) { CHECK(transform(expr->getLhs())); if (expr->getOp() == "&&" || expr->getOp() == "||") { CHECK(transformScope(expr->getRhs())); } else { CHECK(transform(expr->getRhs())); } } void ScopingVisitor::visit(AssignExpr *expr) { seqassert(cast(expr->getVar()), "only simple assignment expression are supported"); SetInScope s(&(ctx->tempScope), false); CHECK(transform(expr->getExpr())); CHECK(transformAdding(expr->getVar(), expr)); } void ScopingVisitor::visit(LambdaExpr *expr) { auto c = std::make_shared(); c->cache = ctx->cache; FunctionStmt f("lambda", nullptr, {}, nullptr); c->functionScope = &f; c->renames = ctx->renames; ScopingVisitor v; c->scope.emplace_back(0, nullptr); v.ctx = c; for (const auto &a : *expr) { auto [_, n] = 
a.getNameWithStars();
    // The parameter name is declared through the lambda's own visitor, while
    // the default value is transformed through this (enclosing) visitor.
    v.visitName(n, true, expr, a.getSrcInfo());
    if (a.defaultValue)
      CHECK(transform(a.defaultValue));
  }
  c->scope.pop_back();

  // The body is analysed in a scope entry tied to a local placeholder suite.
  SuiteStmt s;
  c->scope.emplace_back(0, &s);
  v.transform(expr->getExpr());
  v.processChildCaptures();
  c->scope.pop_back();
  if (v.hasErrors())
    errors.append(v.errors);
  if (!canContinue())
    return;

  // Attach the result to the lambda: its captures (also propagated to the
  // enclosing context as child captures) and, per name, the number of entries
  // recorded in the lambda's context map.
  auto b = std::make_unique();
  b->captures = c->captures;
  for (const auto &n : c->captures)
    ctx->childCaptures.insert(n);
  for (auto &[u, v] : c->map)
    b->bindings[u] = {u, v.size()};
  expr->setAttribute(Attr::Bindings, std::move(b));
}

/// Assignment: right-hand side and type annotation are processed before the
/// target is added.
void ScopingVisitor::visit(AssignStmt *stmt) {
  CHECK(transform(stmt->getRhs()));
  CHECK(transform(stmt->getTypeExpr()));
  CHECK(transformAdding(stmt->getLhs(), stmt));
}

/// If statement: the condition is transformed in the current scope, each
/// branch in a scope of its own.
void ScopingVisitor::visit(IfStmt *stmt) {
  CHECK(transform(stmt->getCond()));
  CHECK(transformScope(stmt->getIf()));
  CHECK(transformScope(stmt->getElse()));
}

/// Match statement: pattern and guard of every case are transformed in the
/// current scope, the case body in a scope of its own.
void ScopingVisitor::visit(MatchStmt *stmt) {
  CHECK(transform(stmt->getExpr()));
  for (auto &m : *stmt) {
    CHECK(transform(m.getPattern()));
    CHECK(transform(m.getGuard()));
    CHECK(transformScope(m.getSuite()));
  }
}

/// While loop: condition and body share one conditional block that collects
/// the variables seen inside it; afterwards findDominatingBinding() is called
/// for each of them. The else branch gets a scope of its own.
void ScopingVisitor::visit(WhileStmt *stmt) {
  std::unordered_set seen;
  {
    ConditionalBlock c(ctx.get(), stmt->getSuite());
    ctx->scope.back().seenVars = std::make_unique>();
    CHECK(transform(stmt->getCond()));
    CHECK(transform(stmt->getSuite()));
    // Copy the set out before the block (and its scope entry) goes away.
    seen = *(ctx->scope.back().seenVars);
  }
  for (auto &var : seen) {
    findDominatingBinding(var);
  }

  CHECK(transformScope(stmt->getElse()));
}

/// For loop: iterator, decorator and OpenMP argument values are transformed
/// in the current scope; loop variable and body are handled inside one
/// conditional block, with the variables seen by each tracked separately.
void ScopingVisitor::visit(ForStmt *stmt) {
  CHECK(transform(stmt->getIter()));
  CHECK(transform(stmt->getDecorator()));
  for (auto &a : stmt->ompArgs)
    CHECK(transform(a.value));

  std::unordered_set seen, seenDef;
  {
    ConditionalBlock c(ctx.get(), stmt->getSuite());

    // Variables seen while adding the loop variable.
    ctx->scope.back().seenVars = std::make_unique>();
    CHECK(transformAdding(stmt->getVar(), stmt));
    seenDef = *(ctx->scope.back().seenVars);

    // Variables seen in the loop body, collected into a fresh set.
    ctx->scope.back().seenVars = std::make_unique>();
    CHECK(transform(stmt->getSuite()));
    seen =
*(ctx->scope.back().seenVars); } for (auto &var : seen) if (!in(seenDef, var)) findDominatingBinding(var); CHECK(transformScope(stmt->getElse())); } void ScopingVisitor::visit(ImportStmt *stmt) { // Validate if (stmt->getFrom()) { Expr *e = stmt->getFrom(); while (auto d = cast(e)) e = d->getExpr(); if (!isId(stmt->getFrom(), "C") && !isId(stmt->getFrom(), "python")) { if (!cast(e)) STOP_ERROR(Error::IMPORT_IDENTIFIER, e); if (!stmt->getArgs().empty()) STOP_ERROR(Error::IMPORT_FN, stmt->getArgs().front().getSrcInfo()); if (stmt->getReturnType()) STOP_ERROR(Error::IMPORT_FN, stmt->getReturnType()); if (stmt->getWhat() && !cast(stmt->getWhat())) STOP_ERROR(Error::IMPORT_IDENTIFIER, stmt->getWhat()); } if (stmt->isCVar() && !stmt->getArgs().empty()) STOP_ERROR(Error::IMPORT_FN, stmt->getArgs().front().getSrcInfo()); } if (ctx->functionScope && stmt->getWhat() && isId(stmt->getWhat(), "*")) STOP_ERROR(error::Error::IMPORT_STAR, stmt); // dylib C imports if (stmt->getFrom() && isId(stmt->getFrom(), "C") && cast(stmt->getWhat())) CHECK(transform(cast(stmt->getWhat())->getExpr())); if (stmt->getAs().empty()) { if (stmt->getWhat()) { if (!match(stmt->getWhat(), M("*"))) CHECK(transformAdding(stmt->getWhat(), stmt)); } else { CHECK(transformAdding(stmt->getFrom(), stmt)); } } else { visitName(stmt->getAs(), true, stmt, stmt->getSrcInfo()); } for (const auto &a : stmt->getArgs()) { CHECK(transform(a.type)); CHECK(transform(a.defaultValue)); } CHECK(transform(stmt->getReturnType())); } void ScopingVisitor::visit(TryStmt *stmt) { CHECK(transformScope(stmt->getSuite())); for (auto *a : *stmt) { CHECK(transform(a->getException())); ConditionalBlock c(ctx.get(), a->getSuite()); if (!a->getVar().empty()) { auto newName = ctx->cache->getTemporaryVar(a->getVar()); ctx->renames.push_back({{a->getVar(), newName}}); a->var = newName; visitName(a->getVar(), true, a, a->getException()->getSrcInfo()); } CHECK(transform(a->getSuite())); if (!a->getVar().empty()) ctx->renames.pop_back(); } 
CHECK(transformScope(stmt->getElse())); CHECK(transform(stmt->getFinally())); } void ScopingVisitor::visit(DelStmt *stmt) { CHECK(transform(stmt->getExpr())); } /// Process `global` statements. void ScopingVisitor::visit(GlobalStmt *stmt) { if (!ctx->functionScope) STOP_ERROR(Error::FN_OUTSIDE_ERROR, stmt, stmt->isNonLocal() ? "nonlocal" : "global"); if (in(ctx->map, stmt->getVar()) || in(ctx->captures, stmt->getVar())) STOP_ERROR(Error::FN_GLOBAL_ASSIGNED, stmt, stmt->getVar()); visitName(stmt->getVar(), true, stmt, stmt->getSrcInfo()); // No shadowing od global/nonlocal allowed findDominatingBinding(stmt->getVar(), /* alowShadow= */ false); ctx->captures[stmt->getVar()] = stmt->isNonLocal() ? BindingsAttribute::CaptureType::Nonlocal : BindingsAttribute::CaptureType::Global; } void ScopingVisitor::visit(YieldStmt *stmt) { if (ctx->functionScope) ctx->functionScope->setAttribute(Attr::IsGenerator); CHECK(transform(stmt->getExpr())); } void ScopingVisitor::visit(YieldExpr *expr) { if (ctx->functionScope) ctx->functionScope->setAttribute(Attr::IsGenerator); } void ScopingVisitor::visit(FunctionStmt *stmt) { // Validate std::vector newDecorators; for (auto &d : stmt->getDecorators()) { if (isId(d, "__attribute__")) { stmt->setAttribute(Attr::Attribute); } else if (isId(d, "llvm")) { stmt->setAttribute(Attr::LLVM); } else if (isId(d, "python")) { if (stmt->getDecorators().size() != 1) STOP_ERROR(Error::FN_SINGLE_DECORATOR, stmt->getDecorators()[1], "python"); stmt->setAttribute(Attr::Python); } else if (isId(d, "__internal__")) { stmt->setAttribute(Attr::Internal); } else if (isId(d, "__hidden__")) { stmt->setAttribute(Attr::HiddenFromUser); } else if (isId(d, "atomic")) { stmt->setAttribute(Attr::Atomic); } else if (isId(d, "property")) { stmt->setAttribute(Attr::Property); } else if (isId(d, "staticmethod")) { stmt->setAttribute(Attr::StaticMethod); } else if (isId(d, "__force__")) { stmt->setAttribute(Attr::ForceRealize); } else if (isId(d, "C")) { 
stmt->setAttribute(Attr::C); } else { newDecorators.emplace_back(d); } } if (stmt->hasAttribute(Attr::C)) { for (auto &a : *stmt) { if (a.getName().size() > 1 && a.getName()[0] == '*' && a.getName()[1] != '*') stmt->setAttribute(Attr::CVarArg); } } if (!stmt->empty() && !stmt->front().getType() && stmt->front().getName() == "self") { stmt->setAttribute(Attr::HasSelf); } stmt->setDecorators(newDecorators); if (!stmt->getReturn() && (stmt->hasAttribute(Attr::LLVM) || stmt->hasAttribute(Attr::C))) STOP_ERROR(Error::FN_LLVM, getSrcInfo()); // Set attributes std::unordered_set seenArgs; bool defaultsStarted = false, hasStarArg = false, hasKwArg = false; for (size_t ia = 0; ia < stmt->size(); ia++) { auto &a = (*stmt)[ia]; auto [stars, n] = a.getNameWithStars(); if (stars == 2) { if (hasKwArg) STOP_ERROR(Error::FN_MULTIPLE_ARGS, a.getSrcInfo()); if (a.getDefault()) STOP_ERROR(Error::FN_DEFAULT_STARARG, a.getDefault()); if (ia != stmt->size() - 1) STOP_ERROR(Error::FN_LAST_KWARG, a.getSrcInfo()); hasKwArg = true; } else if (stars == 1) { if (hasStarArg) STOP_ERROR(Error::FN_MULTIPLE_ARGS, a.getSrcInfo()); if (a.getDefault()) STOP_ERROR(Error::FN_DEFAULT_STARARG, a.getDefault()); hasStarArg = true; } if (in(seenArgs, n)) STOP_ERROR(Error::FN_ARG_TWICE, a.getSrcInfo(), n); seenArgs.insert(n); if (!a.getDefault() && defaultsStarted && !stars && a.isValue()) STOP_ERROR(Error::FN_DEFAULT, a.getSrcInfo(), n); defaultsStarted |= static_cast(a.getDefault()); if (stmt->hasAttribute(Attr::C)) { if (a.getDefault()) STOP_ERROR(Error::FN_C_DEFAULT, a.getDefault(), n); if (stars != 1 && !a.getType()) STOP_ERROR(Error::FN_C_TYPE, a.getSrcInfo(), n); } } bool isOverload = false; for (auto &d : stmt->getDecorators()) if (isId(d, "overload")) { isOverload = true; } if (!isOverload) visitName(stmt->getName(), true, stmt, stmt->getSrcInfo()); auto c = std::make_shared(); c->cache = ctx->cache; c->functionScope = stmt; if (ctx->inClass && !stmt->empty()) c->classDeduce = 
{stmt->front().getName(), {}}; c->renames = ctx->renames; ScopingVisitor v; c->scope.emplace_back(0); v.ctx = c; v.visitName(stmt->getName(), true, stmt, stmt->getSrcInfo()); for (const auto &a : *stmt) { auto [_, n] = a.getNameWithStars(); v.visitName(n, true, stmt, a.getSrcInfo()); if (a.defaultValue) CHECK(transform(a.defaultValue)); } c->scope.pop_back(); c->scope.emplace_back(0, stmt->getSuite()); v.transform(stmt->getSuite()); v.processChildCaptures(); c->scope.pop_back(); if (v.hasErrors()) errors.append(v.errors); if (!canContinue()) return; auto b = std::make_unique(); b->captures = c->captures; for (const auto &n : c->captures) { ctx->childCaptures.insert(n); } // Recursive capture if used if (c->map[stmt->getName()].size() == 1 && in(c->firstSeen, stmt->getName())) b->captures[stmt->getName()] = BindingsAttribute::Read; for (auto &[u, v] : c->map) { b->bindings[u] = {u, v.size()}; auto cp = in(c->childCaptures, u); if (!cp) cp = in(c->captures, u); if (cp && *cp == BindingsAttribute::Nonlocal) { b->bindings[u].isNonlocal = true; ctx->childCaptures[u] = BindingsAttribute::Nonlocal; } } stmt->setAttribute(Attr::Bindings, std::move(b)); if (!c->classDeduce.second.empty()) { auto s = std::make_unique(); for (const auto &n : c->classDeduce.second) s->values.push_back(n); stmt->setAttribute(Attr::ClassDeduce, std::move(s)); } } void ScopingVisitor::visit(WithStmt *stmt) { ConditionalBlock c(ctx.get(), stmt->getSuite()); for (size_t i = 0; i < stmt->size(); i++) { CHECK(transform((*stmt)[i])); if (!stmt->getVars()[i].empty()) visitName(stmt->getVars()[i], true, stmt, stmt->getSrcInfo()); } CHECK(transform(stmt->getSuite())); } void ScopingVisitor::visit(ClassStmt *stmt) { // @tuple(init=, repr=, eq=, order=, hash=, pickle=, container=, python=, add=, // internal=...) // @dataclass(...) 
// @extend std::map tupleMagics = { {"new", true}, {"repr", false}, {"hash", false}, {"eq", false}, {"ne", false}, {"lt", false}, {"le", false}, {"gt", false}, {"ge", false}, {"pickle", true}, {"unpickle", true}, {"to_py", false}, {"from_py", false}, {"iter", false}, {"getitem", false}, {"len", false}, {"to_gpu", false}, {"from_gpu", false}, {"from_gpu_new", false}, {"tuplesize", true}}; for (auto &d : stmt->getDecorators()) { if (isId(d, "__notuple__")) { stmt->setAttribute(Attr::ClassNoTuple); } else if (isId(d, "__noextend__")) { stmt->setAttribute(Attr::NoExtend); } else if (isId(d, "dataclass")) { stmt->setAttribute(Attr::Dataclass); } else if (auto c = cast(d)) { if (isId(c->getExpr(), "tuple")) { stmt->setAttribute(Attr::Tuple); for (auto &val : tupleMagics | std::views::values) val = true; } else if (!isId(c->getExpr(), "dataclass")) { STOP_ERROR(Error::CLASS_BAD_DECORATOR, c->getExpr()); } else if (stmt->hasAttribute(Attr::Tuple)) { STOP_ERROR(Error::CLASS_CONFLICT_DECORATOR, c, "dataclass", "tuple"); } for (const auto &a : *c) { auto b = cast(a); if (!b) STOP_ERROR(Error::CLASS_NONSTATIC_DECORATOR, a.getSrcInfo()); char val = static_cast(b->getValue()); if (a.getName() == "init") { tupleMagics["new"] = val; } else if (a.getName() == "repr") { tupleMagics["repr"] = val; } else if (a.getName() == "eq") { tupleMagics["eq"] = tupleMagics["ne"] = val; } else if (a.getName() == "order") { tupleMagics["lt"] = tupleMagics["le"] = tupleMagics["gt"] = tupleMagics["ge"] = val; } else if (a.getName() == "hash") { tupleMagics["hash"] = val; } else if (a.getName() == "pickle") { tupleMagics["pickle"] = tupleMagics["unpickle"] = val; } else if (a.getName() == "python") { tupleMagics["to_py"] = tupleMagics["from_py"] = val; } else if (a.getName() == "gpu") { tupleMagics["to_gpu"] = tupleMagics["from_gpu"] = tupleMagics["from_gpu_new"] = val; } else if (a.getName() == "container") { tupleMagics["iter"] = tupleMagics["getitem"] = val; } else { 
STOP_ERROR(Error::CLASS_BAD_DECORATOR_ARG, a.getSrcInfo()); } } } else if (isId(d, "tuple")) { if (stmt->hasAttribute(Attr::Tuple)) STOP_ERROR(Error::CLASS_MULTIPLE_DECORATORS, d, "tuple"); stmt->setAttribute(Attr::Tuple); for (auto &val : tupleMagics | std::views::values) { val = true; } } else if (isId(d, "extend")) { stmt->setAttribute(Attr::Extend); if (stmt->getDecorators().size() != 1) { STOP_ERROR( Error::CLASS_SINGLE_DECORATOR, stmt->getDecorators()[stmt->getDecorators().front() == d]->getSrcInfo(), "extend"); } } else if (isId(d, "__internal__")) { stmt->setAttribute(Attr::Internal); } else { STOP_ERROR(Error::CLASS_BAD_DECORATOR, d); } } if (!stmt->hasAttribute(Attr::Tuple) && !stmt->hasAttribute(Attr::Internal) && !stmt->hasAttribute(Attr::Dataclass) && stmt->getStaticBaseClasses().empty() && stmt->size() == 0) { stmt->setAttribute(Attr::ClassDeduce); } if (!stmt->hasAttribute(Attr::Tuple)) { tupleMagics["init"] = tupleMagics["new"]; tupleMagics["new"] = tupleMagics["raw"] = true; tupleMagics["len"] = false; tupleMagics["repr_default"] = true; } tupleMagics["dict"] = true; // Internal classes do not get any auto-generated members. std::vector magics; if (!stmt->hasAttribute(Attr::Internal)) { for (auto &m : tupleMagics) if (m.second) { if (m.first == "new") magics.insert(magics.begin(), m.first); else magics.push_back(m.first); } } stmt->setAttribute(Attr::ClassMagic, std::make_unique(magics)); std::unordered_set seen; if (stmt->hasAttribute(Attr::Extend) && !stmt->empty()) STOP_ERROR(Error::CLASS_EXTENSION, stmt->front().getSrcInfo()); if (stmt->hasAttribute(Attr::Extend) && !(stmt->getBaseClasses().empty() && stmt->getStaticBaseClasses().empty())) { STOP_ERROR(Error::CLASS_EXTENSION, stmt->getBaseClasses().empty() ? 
stmt->getStaticBaseClasses().front() : stmt->getBaseClasses().front()); } for (auto &a : *stmt) { if (!a.getType() && !a.getDefault()) STOP_ERROR(Error::CLASS_MISSING_TYPE, a.getSrcInfo(), a.getName()); if (in(seen, a.getName())) STOP_ERROR(Error::CLASS_ARG_TWICE, a.getSrcInfo(), a.getName()); seen.insert(a.getName()); } if (stmt->hasAttribute(Attr::Extend)) visitName(stmt->getName()); else visitName(stmt->getName(), true, stmt, stmt->getSrcInfo()); auto c = std::make_shared(); c->cache = ctx->cache; c->renames = ctx->renames; ScopingVisitor v; c->scope.emplace_back(0); c->inClass = true; v.ctx = c; for (const auto &a : *stmt) { v.transform(a.type); v.transform(a.defaultValue); } v.transform(stmt->getSuite()); c->scope.pop_back(); if (v.hasErrors()) errors.append(v.errors); if (!canContinue()) return; for (auto &d : stmt->getBaseClasses()) CHECK(transform(d)); for (auto &d : stmt->getStaticBaseClasses()) CHECK(transform(d)); } void ScopingVisitor::processChildCaptures() { for (auto &n : ctx->childCaptures) { if (auto i = in(ctx->map, n.first)) { if (i->back().binding && cast(i->back().binding)) continue; } if (!findDominatingBinding(n.first)) { ctx->captures.insert(n); // propagate! } } } void ScopingVisitor::switchToUpdate(ASTNode *binding, const std::string &name, bool gotUsedVar) { if (binding && binding->hasAttribute(Attr::Bindings)) { binding->getAttribute(Attr::Bindings)->bindings.erase(name); } if (binding) { if (!gotUsedVar && binding->hasAttribute(Attr::ExprDominatedUsed)) binding->eraseAttribute(Attr::ExprDominatedUsed); binding->setAttribute(gotUsedVar ? 
Attr::ExprDominatedUsed : Attr::ExprDominated); } if (cast(binding)) STOP_ERROR(error::Error::ID_INVALID_BIND, binding, name); else if (cast(binding)) STOP_ERROR(error::Error::ID_INVALID_BIND, binding, name); } bool ScopingVisitor::visitName(const std::string &name, bool adding, ASTNode *root, const SrcInfo &src) { if (adding && ctx->inClass) return false; if (adding) { if (auto p = in(ctx->captures, name)) { if (*p == BindingsAttribute::CaptureType::Read) { addError(error::Error::ASSIGN_LOCAL_REFERENCE, ctx->firstSeen[name], name, src); return false; } else if (root) { // global, nonlocal switchToUpdate(root, name, false); } } else { if (auto i = in(ctx->childCaptures, name)) { if (*i != BindingsAttribute::CaptureType::Global && ctx->functionScope) { auto newScope = std::vector{ctx->scope[0].id}; seqassert(ctx->scope.front().suite, "invalid suite"); if (!ctx->scope.front().suite->hasAttribute(Attr::Bindings)) ctx->scope.front().suite->setAttribute( Attr::Bindings, std::make_unique()); ctx->scope.front() .suite->getAttribute(Attr::Bindings) ->bindings[name] = {name, 0, *i == BindingsAttribute::CaptureType::Nonlocal}; auto newItem = ScopingVisitor::Context::Item(src, newScope, nullptr); ctx->map[name].push_back(newItem); } } ctx->map[name].emplace_front(src, ctx->getScope(), root); } } else { if (!in(ctx->firstSeen, name)) ctx->firstSeen[name] = src; if (!in(ctx->map, name)) { ctx->captures[name] = BindingsAttribute::CaptureType::Read; } } if (auto val = findDominatingBinding(name)) { // Track loop variables to dominate them later. 
Example: // x = 1 // while True: // if x > 10: break // x = x + 1 # x must be dominated after the loop to ensure that it gets updated auto scope = ctx->getScope(); for (size_t li = ctx->scope.size(); li-- > 0;) { if (ctx->scope[li].seenVars) { bool inside = val->scope.size() >= scope.size() && val->scope[scope.size() - 1] == scope.back(); if (!inside) ctx->scope[li].seenVars->insert(name); else break; } scope.pop_back(); } // Variable binding check for variables that are defined within conditional blocks if (!val->accessChecked.empty()) { bool checked = false; for (size_t ai = val->accessChecked.size(); ai-- > 0;) { auto &a = val->accessChecked[ai]; if (a.size() <= ctx->scope.size() && a[a.size() - 1] == ctx->scope[a.size() - 1].id) { checked = true; break; } } if (!checked) { seqassert(!adding, "bad branch"); if (!(val->binding && val->binding->hasAttribute(Attr::Bindings))) { // If the expression is not conditional, we can just do the check once val->accessChecked.push_back(ctx->getScope()); } return true; } } } return false; } /// Get an item from the context. Perform domination analysis for accessing items /// defined in the conditional blocks (i.e., Python scoping). ScopingVisitor::Context::Item * ScopingVisitor::findDominatingBinding(const std::string &name, bool allowShadow) { auto it = in(ctx->map, name); if (!it || it->empty()) return nullptr; auto lastGood = it->begin(); while (lastGood != it->end() && lastGood->ignore) ++lastGood; int commonScope = static_cast(ctx->scope.size()); // Iterate through all bindings with the given name and find the closest binding that // dominates the current scope. 
for (auto i = it->begin(); i != it->end(); ++i) { if (i->ignore) continue; bool completeDomination = i->scope.size() <= ctx->scope.size() && i->scope.back() == ctx->scope[i->scope.size() - 1].id; if (completeDomination) { commonScope = i->scope.size(); lastGood = i; break; } else { seqassert(i->scope[0] == 0, "bad scoping"); seqassert(ctx->scope[0].id == 0, "bad scoping"); // Find the longest block prefix between the binding and the current common scope. commonScope = std::min(commonScope, static_cast(i->scope.size())); while (commonScope > 0 && i->scope[commonScope - 1] != ctx->scope[commonScope - 1].id) commonScope--; // if (commonScope < int(ctx->scope.size()) && commonScope != p) // break; lastGood = i; } } seqassert(lastGood != it->end(), "corrupted scoping ({})", name); if (!allowShadow) { // go to the end lastGood = it->end(); --lastGood; int p = std::min(commonScope, static_cast(lastGood->scope.size())); while (p >= 0 && lastGood->scope[p - 1] != ctx->scope[p - 1].id) p--; commonScope = p; } bool gotUsedVar = false; if (lastGood->scope.size() != commonScope) { // The current scope is potentially reachable by multiple bindings that are // not dominated by a common binding. Create such binding in the scope that // dominates (covers) all of them. auto scope = ctx->getScope(); auto newScope = std::vector(scope.begin(), scope.begin() + commonScope); // Make sure to prepend a binding declaration: `var` and `var__used__ = False` // to the dominating scope. 
for (size_t si = commonScope; si-- > 0; si--) { if (!ctx->scope[si].suite) continue; if (!ctx->scope[si].suite->hasAttribute(Attr::Bindings)) ctx->scope[si].suite->setAttribute(Attr::Bindings, std::make_unique()); ctx->scope[si] .suite->getAttribute(Attr::Bindings) ->bindings[name] = {name, 1}; auto newItem = ScopingVisitor::Context::Item( getSrcInfo(), newScope, ctx->scope[si].suite, {lastGood->scope}); lastGood = it->insert(++lastGood, newItem); gotUsedVar = true; break; } } else if (lastGood->binding && lastGood->binding->hasAttribute(Attr::Bindings)) { gotUsedVar = lastGood->binding->getAttribute(Attr::Bindings) ->bindings[name] .count > 0; } // Remove all bindings after the dominant binding. for (auto i = it->begin(); i != it->end(); ++i) { if (i == lastGood) break; switchToUpdate(i->binding, name, gotUsedVar); i->scope = lastGood->scope; i->ignore = true; } if (!gotUsedVar && lastGood->binding && lastGood->binding->hasAttribute(Attr::Bindings)) lastGood->binding->getAttribute(Attr::Bindings) ->bindings[name] = {name, 0}; return &(*lastGood); } auto format_as(BindingsAttribute::CaptureType c) { return c == BindingsAttribute::Read ? "RD" : (c == BindingsAttribute::Global ? "GL" : "NL"); } } // namespace codon::ast ================================================ FILE: codon/parser/visitors/scoping/scoping.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once #include #include #include #include #include #include #include "codon/parser/ast.h" #include "codon/parser/visitors/typecheck/ctx.h" #include "codon/parser/visitors/visitor.h" namespace codon::ast { struct BindingsAttribute : public ir::Attribute { static const int AttributeID = 190; enum CaptureType { Read, Global, Nonlocal }; std::unordered_map captures; struct Binding { std::string name; size_t count; bool isNonlocal = false; }; std::unordered_map bindings; std::unordered_map localRenames; std::unique_ptr clone() const override { auto p = std::make_unique(); p->captures = captures; p->bindings = bindings; p->localRenames = localRenames; return p; } private: std::ostream &doFormat(std::ostream &os) const override { return os << "Bindings"; } }; auto format_as(BindingsAttribute::CaptureType c); class ScopingVisitor : public CallbackASTVisitor { struct Context { /// A pointer to the shared cache. Cache *cache; /// Holds the information about current scope. /// A scope is defined as a stack of conditional blocks /// (i.e., blocks that might not get executed during the runtime). /// Used mainly to support Python's variable scoping rules. struct ScopeBlock { int id; // Associated SuiteStmt Stmt *suite; /// List of variables "seen" before their assignment within a loop. /// Used to dominate variables that are updated within a loop. std::unique_ptr> seenVars = nullptr; ScopeBlock(int id, Stmt *s = nullptr) : id(id), suite(s), seenVars(nullptr) {} }; /// Current hierarchy of conditional blocks. 
std::vector scope; std::vector getScope() const { std::vector result; result.reserve(scope.size()); for (const auto &b : scope) result.emplace_back(b.id); return result; } struct Item : public codon::SrcObject { std::vector scope; ASTNode *binding = nullptr; bool ignore = false; /// List of scopes where the identifier is accessible /// without __used__ check std::vector> accessChecked; Item(const codon::SrcInfo &src, std::vector scope, ASTNode *binding = nullptr, std::vector> accessChecked = {}) : scope(std::move(scope)), binding(std::move(binding)), ignore(false), accessChecked(std::move(accessChecked)) { setSrcInfo(src); } }; std::unordered_map> map; std::unordered_map captures; std::unordered_map childCaptures; // for functions! std::map firstSeen; std::pair> classDeduce; bool adding = false; ASTNode *root = nullptr; FunctionStmt *functionScope = nullptr; bool inClass = false; // bool isConditional = false; std::vector> renames = {{}}; bool tempScope = false; // Time to track positions of assignments and references to them. int64_t time = 0; }; std::shared_ptr ctx = nullptr; struct ConditionalBlock { Context *ctx; ConditionalBlock(Context *ctx, Stmt *s, int id = -1) : ctx(ctx) { if (s) seqassertn(cast(s), "not a suite"); ctx->scope.emplace_back(id == -1 ? 
ctx->cache->blockCount++ : id, s); } ~ConditionalBlock() { seqassertn(!ctx->scope.empty() && (ctx->scope.back().id == 0 || ctx->scope.size() > 1), "empty scope"); ctx->scope.pop_back(); } }; public: ParserErrors errors; bool hasErrors() const { return !errors.empty(); } bool canContinue() const { return errors.size() <= MAX_ERRORS; } template void addError(error::Error e, const SrcInfo &o, const TA &...args) { auto msg = ErrorMessage(error::Emsg(e, args...), o.file, o.line, o.col, o.len, static_cast(e)); errors.addError({msg}); } template void addError(error::Error e, ASTNode *o, const TA &...args) { this->addError(e, o->getSrcInfo(), args...); } void addError(llvm::Error &&e) { llvm::handleAllErrors(std::move(e), [this](const error::ParserErrorInfo &e) { this->errors.append(e.getErrors()); }); } static llvm::Error apply(Cache *, Stmt *s, std::unordered_map * = nullptr); bool transform(Expr *expr) override; bool transform(Stmt *stmt) override; // Can error! bool visitName(const std::string &name, bool = false, ASTNode * = nullptr, const SrcInfo & = SrcInfo()); bool transformAdding(Expr *e, ASTNode *); bool transformScope(Expr *); bool transformScope(Stmt *); void visit(StringExpr *) override; void visit(IdExpr *) override; void visit(DotExpr *) override; void visit(IndexExpr *) override; void visit(GeneratorExpr *) override; void visit(IfExpr *) override; void visit(BinaryExpr *) override; void visit(LambdaExpr *) override; void visit(YieldExpr *) override; void visit(AssignExpr *) override; void visit(AssignStmt *) override; void visit(DelStmt *) override; void visit(YieldStmt *) override; void visit(WhileStmt *) override; void visit(ForStmt *) override; void visit(IfStmt *) override; void visit(MatchStmt *) override; void visit(ImportStmt *) override; void visit(TryStmt *) override; void visit(GlobalStmt *) override; void visit(FunctionStmt *) override; void visit(ClassStmt *) override; void visit(WithStmt *) override; Context::Item *findDominatingBinding(const 
std::string &, bool = true); void processChildCaptures(); void switchToUpdate(ASTNode *binding, const std::string &, bool); std::vector unpackFString(const std::string &value); template Tn *N(Ts &&...args) { Tn *t = ctx->cache->N(std::forward(args)...); t->setSrcInfo(getSrcInfo()); return t; } }; } // namespace codon::ast ================================================ FILE: codon/parser/visitors/translate/translate.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "translate.h" #include #include #include #include #include "codon/cir/transform/parallel/schedule.h" #include "codon/cir/util/cloning.h" #include "codon/parser/ast.h" #include "codon/parser/common.h" #include "codon/parser/visitors/translate/translate_ctx.h" #include "codon/parser/visitors/typecheck/typecheck.h" using codon::ir::cast; using codon::ir::transform::parallel::OMPSched; namespace codon::ast { TranslateVisitor::TranslateVisitor(std::shared_ptr ctx) : ctx(std::move(ctx)), result(nullptr) {} ir::Func *TranslateVisitor::apply(Cache *cache, Stmt *stmts) { ir::BodiedFunc *main = nullptr; if (cache->isJit) { auto fnName = fmt::format("_jit_{}", cache->jitCell); main = cache->module->Nr(fnName); main->setSrcInfo({"", 0, 0, 0}); main->setGlobal(); auto irType = cache->module->unsafeGetFuncType( fnName, cache->classes["NoneType"].realizations["NoneType"]->ir, {}, false); main->realize(irType, {}); main->setJIT(); } else { main = cast(cache->module->getMainFunc()); auto path = cache->fs->get_module0(); main->setSrcInfo({path, 0, 0, 0}); } auto block = cache->module->Nr("body"); main->setBody(block); if (!cache->codegenCtx) cache->codegenCtx = std::make_shared(cache); cache->codegenCtx->bases = {main}; cache->codegenCtx->series = {block}; cache->populatePythonModule(); TranslateVisitor(cache->codegenCtx).translateStmts(stmts); return main; } void TranslateVisitor::translateStmts(Stmt *stmts) const { initializeGlobals(); 
TranslateVisitor(ctx->cache->codegenCtx).transform(stmts); for (auto &f : ctx->cache->functions | std::views::values) TranslateVisitor(ctx->cache->codegenCtx).transform(f.ast); } void TranslateVisitor::initializeGlobals() const { for (auto &[name, ir] : ctx->cache->globals) if (!ir) { ir::types::Type *vt = nullptr; if (auto t = ctx->cache->typeCtx->forceFind(name)->getType()) { if (!t->isInstantiated() || (t->is(TYPE_TYPE)) || t->getFunc()) continue; vt = getType(t); } ir = name == VAR_ARGV ? ctx->cache->codegenCtx->getModule()->getArgVar() : ctx->cache->codegenCtx->getModule()->N( SrcInfo(), vt, true, false, false, name); ctx->cache->codegenCtx->add(TranslateItem::Var, name, ir); } } /************************************************************************************/ ir::Value *TranslateVisitor::transform(Expr *expr) { TranslateVisitor v(ctx); v.setSrcInfo(expr->getSrcInfo()); types::ClassType *p = nullptr; if (expr->hasAttribute(Attr::ExprList) || expr->hasAttribute(Attr::ExprSet) || expr->hasAttribute(Attr::ExprDict) || expr->hasAttribute(Attr::ExprPartial)) { ctx->seqItems.emplace_back(); } if (expr->hasAttribute(Attr::ExprPartial)) { p = expr->getType()->getPartial(); } expr->accept(v); ir::Value *ir = v.result; if (expr->hasAttribute(Attr::ExprList) || expr->hasAttribute(Attr::ExprSet)) { std::vector le; for (auto &pl : ctx->seqItems.back()) { seqassert(pl.first == Attr::ExprSequenceItem || pl.first == Attr::ExprStarSequenceItem, "invalid list/set element"); le.push_back( ir::LiteralElement{pl.second, pl.first == Attr::ExprStarSequenceItem}); } if (expr->hasAttribute(Attr::ExprList)) ir->setAttribute(std::make_unique(le)); else ir->setAttribute(std::make_unique(le)); ctx->seqItems.pop_back(); } if (expr->hasAttribute(Attr::ExprDict)) { std::vector dla; for (int pi = 0; pi < ctx->seqItems.back().size(); pi++) { auto &pl = ctx->seqItems.back()[pi]; if (pl.first == Attr::ExprStarSequenceItem) { dla.push_back({pl.second, nullptr}); } else { seqassert(pl.first 
== Attr::ExprSequenceItem && pi + 1 < ctx->seqItems.back().size() && ctx->seqItems.back()[pi + 1].first == Attr::ExprSequenceItem, "invalid dict element"); dla.push_back({pl.second, ctx->seqItems.back()[pi + 1].second}); pi++; } } ir->setAttribute(std::make_unique(dla)); ctx->seqItems.pop_back(); } if (expr->hasAttribute(Attr::ExprPartial)) { std::vector vals; seqassert(p, "invalid partial element"); int j = 0; auto known = p->getPartialMask(); auto func = p->getPartialFunc(); for (int i = 0; i < known.size(); i++) { if (known[i] == types::ClassType::PartialFlag::Included && (*func->ast)[i].isValue()) { seqassert(j < ctx->seqItems.back().size() && ctx->seqItems.back()[j].first == Attr::ExprSequenceItem, "invalid partial element"); vals.push_back(ctx->seqItems.back()[j++].second); } else if ((*func->ast)[i].isValue()) { vals.push_back({nullptr}); } } ir->setAttribute( std::make_unique(func->ast->getName(), vals)); ctx->seqItems.pop_back(); } if (expr->hasAttribute(Attr::ExprSequenceItem)) { ctx->seqItems.back().emplace_back(Attr::ExprSequenceItem, ir); } if (expr->hasAttribute(Attr::ExprStarSequenceItem)) { ctx->seqItems.back().emplace_back(Attr::ExprStarSequenceItem, ir); } return ir; } void TranslateVisitor::defaultVisit(Expr *n) { seqassert(false, "invalid node {}", n->toString()); } void TranslateVisitor::visit(NoneExpr *expr) { auto f = expr->getType()->realizedName() + ":" + getMangledMethod("std.internal.core", TYPE_OPTIONAL, "__new__"); auto val = ctx->find(f); seqassert(val, "cannot find '{}'", f); result = make(expr, make(expr, val->getFunc()), std::vector{}); } void TranslateVisitor::visit(BoolExpr *expr) { result = make(expr, expr->getValue(), getType(expr->getType())); } void TranslateVisitor::visit(IntExpr *expr) { result = make(expr, expr->getValue(), getType(expr->getType())); } void TranslateVisitor::visit(FloatExpr *expr) { result = make(expr, expr->getValue(), getType(expr->getType())); } void TranslateVisitor::visit(StringExpr *expr) { result = 
make(expr, expr->getValue(), getType(expr->getType())); } void TranslateVisitor::visit(IdExpr *expr) { auto val = ctx->find(expr->getValue()); if (!val) { // ctx->find(expr->getValue()); seqassert(val, "cannot find '{}'", expr->getValue()); } if (expr->getValue() == getMangledVar("", "__vtable_size__")) { // LOG("[] __vtable_size__={}", ctx->cache->classRealizationCnt + 2); result = make(expr, ctx->cache->classRealizationCnt + 2, getType(expr->getType())); } else if (auto *v = val->getVar()) { result = make(expr, v); } else if (auto *f = val->getFunc()) { result = make(expr, f); } else { // Just use NoneType which is {} (same as type) auto ntval = ctx->find(getMangledMethod("std.internal.core", "NoneType", "__new__")); seqassert(ntval, "cannot find '{}'", "NoneType.__new__"); result = make(expr, make(expr, ntval->getFunc()), std::vector{}); } } void TranslateVisitor::visit(IfExpr *expr) { auto cond = transform(expr->getCond()); auto ifexpr = transform(expr->getIf()); auto elsexpr = transform(expr->getElse()); result = make(expr, cond, ifexpr, elsexpr); } // Search expression tree for an identifier class IdVisitor : public CallbackASTVisitor { public: std::unordered_set ids; bool transform(Expr *expr) override { IdVisitor v; if (expr) expr->accept(v); ids.insert(v.ids.begin(), v.ids.end()); return true; } bool transform(Stmt *stmt) override { IdVisitor v; if (stmt) stmt->accept(v); ids.insert(v.ids.begin(), v.ids.end()); return true; } void visit(IdExpr *expr) override { ids.insert(expr->getValue()); } }; void TranslateVisitor::visit(GeneratorExpr *expr) { auto name = ctx->cache->imports[MAIN_IMPORT].ctx->generateCanonicalName("_generator"); ir::Func *fn = ctx->cache->module->Nr(name); fn->setGlobal(); fn->setGenerator(); std::vector names; std::vector types; std::vector items; IdVisitor v; expr->accept(v); for (auto &i : v.ids) { auto val = ctx->find(i); if (val && !val->getFunc() && !val->getType() && !val->getVar()->isGlobal()) { 
types.push_back(val->getVar()->getType()); names.push_back(i); items.emplace_back(make(expr, val->getVar())); } } auto irType = ctx->cache->module->unsafeGetFuncType( name, ctx->forceFind(expr->getType()->realizedName())->getType(), types, false); fn->realize(irType, names); ctx->addBlock(); for (auto &n : names) ctx->add(TranslateItem::Var, n, fn->getArgVar(n)); auto body = make(expr, "body"); ctx->bases.push_back(cast(fn)); ctx->addSeries(body); expr->setFinalStmt(ctx->cache->N(expr->getFinalExpr())); auto e = expr->getFinalSuite(); transform(e); ctx->popSeries(); ctx->bases.pop_back(); cast(fn)->setBody(body); ctx->popBlock(); result = make(expr, make(expr, fn), std::move(items)); } void TranslateVisitor::visit(CallExpr *expr) { auto ei = cast(expr->getExpr()); if (ei && ei->getValue() == getMangledFunc("std.internal.core", "__ptr__")) { auto head = expr->begin()->getExpr(); ir::FlowInstr *pre = cast(transform(head)); while (auto sexp = cast(head)) head = sexp->getExpr(); std::vector members; while (auto id = cast(head)) { members.emplace_back(id->getMember()); head = id->getExpr(); } std::ranges::reverse(members); auto id = cast(head); seqassert(id, "expected IdExpr, got {}", *((*expr)[0].value)); auto key = id->getValue(); auto val = ctx->find(key); seqassert(val && val->getVar(), "{} is not a variable", key); auto pv = make(expr, val->getVar(), members); if (pre) { pre->setValue(pv); result = pre; } else { result = pv; } return; } else if (ei && ei->getValue() == getMangledMethod("std.internal.core", "__array__", "__new__")) { auto fnt = expr->getExpr()->getType()->getFunc(); auto sz = fnt->funcGenerics[0].type->getIntStatic()->value; auto typ = fnt->funcParent->getClass()->generics[0].getType(); auto *arrayType = ctx->getModule()->unsafeGetArrayType(getType(typ)); arrayType->setAstType(expr->getType()->shared_from_this()); result = make(expr, arrayType, sz); return; } else if (ei && startswith(ei->getValue(), getMangledMethod("std.internal.core", 
"Generator", "_yield_in_no_suspend"))) { result = make(expr, getType(expr->getType()), false); return; } auto ft = expr->getExpr()->getType()->getFunc(); seqassert(ft, "not calling function"); auto callee = transform(expr->getExpr()); bool isVariadic = ft->ast->hasAttribute(Attr::CVarArg); std::vector items; size_t i = 0; for (auto &a : *expr) { seqassert(!cast(a.value), "ellipsis not elided"); if (i + 1 == expr->size() && isVariadic) { auto call = cast(a.value); seqassert(call, "expected *args tuple: '{}'", call->toString(0)); for (auto &arg : *call) items.emplace_back(transform(arg.value)); } else { items.emplace_back(transform(a.value)); } i++; } result = make(expr, callee, std::move(items)); } void TranslateVisitor::visit(DotExpr *expr) { if (expr->getMember() == "__atomic__" || expr->getMember() == "__elemsize__" || expr->getMember() == "__contents_atomic__") { auto ei = cast(expr->getExpr()); seqassert(ei, "expected IdExpr, got {}", *(expr->getExpr())); auto t = TypecheckVisitor(ctx->cache->typeCtx).extractType(ei->getType()); auto type = ctx->find(t->realizedName())->getType(); seqassert(type, "{} is not a type", ei->getValue()); result = make( expr, type, expr->getMember() == "__atomic__" ? ir::TypePropertyInstr::Property::IS_ATOMIC : (expr->getMember() == "__contents_atomic__" ? 
ir::TypePropertyInstr::Property::IS_CONTENT_ATOMIC : ir::TypePropertyInstr::Property::SIZEOF)); } else { result = make(expr, transform(expr->getExpr()), expr->getMember()); } } void TranslateVisitor::visit(YieldExpr *expr) { result = make(expr, getType(expr->getType())); } void TranslateVisitor::visit(PipeExpr *expr) { auto isGen = [](const ir::Value *v) -> bool { auto *type = v->getType(); if (ir::isA(type)) return true; else if (auto *fn = cast(type)) { return ir::isA(fn->getReturnType()); } return false; }; std::vector stages; auto *firstStage = transform((*expr)[0].expr); auto firstIsGen = isGen(firstStage); stages.emplace_back(firstStage, std::vector(), firstIsGen, false); // Pipeline without generators (just function call sugar) auto simplePipeline = !firstIsGen; for (auto i = 1; i < expr->size(); i++) { auto call = cast((*expr)[i].expr); seqassert(call, "{} is not a call", *((*expr)[i].expr)); auto fn = transform(call->getExpr()); if (i + 1 != expr->size()) simplePipeline &= !isGen(fn); std::vector args; args.reserve(call->size()); for (auto &a : *call) args.emplace_back(cast(a.value) ? nullptr : transform(a.value)); stages.emplace_back(fn, args, isGen(fn), false); } if (simplePipeline) { // Transform a |> b |> c to c(b(a)) ir::util::CloneVisitor cv(ctx->getModule()); result = cv.clone(stages[0].getCallee()); for (auto i = 1; i < stages.size(); ++i) { std::vector newArgs; for (auto arg : stages[i]) newArgs.push_back(arg ? cv.clone(arg) : result); result = make(expr, cv.clone(stages[i].getCallee()), newArgs); } } else { for (int i = 0; i < expr->size(); i++) if ((*expr)[i].op == "||>") stages[i].setParallel(); // This is a statement in IR. 
ctx->getSeries()->push_back(make(expr, stages)); } } void TranslateVisitor::visit(AwaitExpr *expr) { result = make(expr, transform(expr->getExpr()), getType(expr->getType()), expr->getExpr()->getType()->is( getMangledClass("std.internal.core", "Generator"))); } void TranslateVisitor::visit(StmtExpr *expr) { auto *bodySeries = make(expr, "body"); ctx->addSeries(bodySeries); for (auto &s : *expr) transform(s); ctx->popSeries(); result = make(expr, bodySeries, transform(expr->getExpr())); } /************************************************************************************/ ir::Value *TranslateVisitor::transform(Stmt *stmt) { TranslateVisitor v(ctx); v.setSrcInfo(stmt->getSrcInfo()); stmt->accept(v); if (v.result) ctx->getSeries()->push_back(v.result); return v.result; } void TranslateVisitor::defaultVisit(Stmt *n) { seqassert(false, "invalid node {}", n->toString()); } void TranslateVisitor::visit(SuiteStmt *stmt) { for (auto *s : *stmt) transform(s); } void TranslateVisitor::visit(BreakStmt *stmt) { result = make(stmt); } void TranslateVisitor::visit(ContinueStmt *stmt) { result = make(stmt); } void TranslateVisitor::visit(ExprStmt *stmt) { IdExpr *ei = nullptr; auto ce = cast(stmt->getExpr()); if (ce && ((ei = cast(ce->getExpr()))) && ei->getValue() == getMangledMethod("std.internal.core", "Generator", "_yield_final")) { result = make(stmt, transform((*ce)[0].value), true); ctx->getBase()->setGenerator(); } else { result = transform(stmt->getExpr()); } } void TranslateVisitor::visit(AssignStmt *stmt) { if (stmt->getLhs() && cast(stmt->getLhs()) && cast(stmt->getLhs())->getValue() == VAR_ARGV) return; auto lei = cast(stmt->getLhs()); seqassert(lei, "expected IdExpr, got {}", *stmt); auto var = lei->getValue(); auto isGlobal = in(ctx->cache->globals, var); ir::Var *v = nullptr; if (stmt->isUpdate()) { auto val = ctx->find(lei->getValue()); seqassert(val && val->getVar(), "{} is not a variable", lei->getValue()); v = val->getVar(); if (!v->getType()) { 
v->setSrcInfo(stmt->getSrcInfo()); v->setType(getType(stmt->getRhs()->getType())); } result = make(stmt, v, transform(stmt->getRhs())); return; } if (!stmt->getLhs()->getType()->isInstantiated() || (stmt->getLhs()->getType()->is(TYPE_TYPE)) || stmt->getLhs()->getType()->getFunc()) { if (!cast(stmt->getRhs())) { // Side effect result = transform(stmt->getRhs()); } return; // type aliases/fn aliases etc } if (isGlobal) { seqassert(ctx->find(var) && ctx->find(var)->getVar(), "cannot find global '{}'", var); v = ctx->find(var)->getVar(); v->setSrcInfo(stmt->getSrcInfo()); v->setType(getType((stmt->getRhs() ? stmt->getRhs() : stmt->getLhs())->getType())); } else { v = make( stmt, getType((stmt->getRhs() ? stmt->getRhs() : stmt->getLhs())->getType()), false, false, false, var); ctx->getBase()->push_back(v); ctx->add(TranslateItem::Var, var, v); } // Check if it is thread-local if (stmt->isThreadLocal()) { v->setThreadLocal(); v->setGlobal(); } // Check if it is a C variable if (stmt->getLhs()->hasAttribute(Attr::ExprExternVar)) { v->setExternal(); v->setName(ctx->cache->rev(var)); v->setGlobal(); return; } if (stmt->getRhs()) { result = make(stmt, v, transform(stmt->getRhs())); } } void TranslateVisitor::visit(AssignMemberStmt *stmt) { result = make(stmt, transform(stmt->getLhs()), stmt->getMember(), transform(stmt->getRhs())); } void TranslateVisitor::visit(ReturnStmt *stmt) { result = make(stmt, stmt->getExpr() ? transform(stmt->getExpr()) : nullptr); } void TranslateVisitor::visit(YieldStmt *stmt) { result = make(stmt, stmt->getExpr() ? 
transform(stmt->getExpr()) : nullptr); ctx->getBase()->setGenerator(); } void TranslateVisitor::visit(WhileStmt *stmt) { auto loop = make(stmt, transform(stmt->getCond()), make(stmt, "body")); ctx->addSeries(cast(loop->getBody())); transform(stmt->getSuite()); ctx->popSeries(); result = loop; } void TranslateVisitor::visit(ForStmt *stmt) { std::unique_ptr os = nullptr; if (stmt->getDecorator()) { auto c = cast(stmt->getDecorator()); seqassert(c, "for par is not a call: {}", *(stmt->getDecorator())); auto fc = c->getExpr()->getType()->getFunc(); seqassert(fc && fc->ast->getName() == getMangledFunc("std.openmp", "for_par"), "for par is not a function"); auto schedule = fc->funcGenerics[0].type->getStrStatic()->value; bool ordered = fc->funcGenerics[1].type->getBoolStatic()->value; auto threads = transform((*c)[0].value); auto chunk = transform((*c)[1].value); auto collapse = fc->funcGenerics[2].type->getIntStatic()->value; bool gpu = fc->funcGenerics[3].type->getBoolStatic()->value; os = std::make_unique(schedule, threads, chunk, ordered, collapse, gpu); } seqassert(cast(stmt->getVar()), "expected IdExpr, got {}", *(stmt->getVar())); auto varName = cast(stmt->getVar())->getValue(); ir::Var *var = nullptr; if (!ctx->find(varName) || !stmt->getVar()->hasAttribute(Attr::ExprDominated)) { var = make(stmt, getType(stmt->getVar()->getType()), false, false, false, varName); } else { var = ctx->find(varName)->getVar(); } ctx->getBase()->push_back(var); auto bodySeries = make(stmt, "body"); auto loop = make(stmt, transform(stmt->getIter()), bodySeries, var); if (stmt->isAsync()) loop->setAsync(); if (os) loop->setSchedule(std::move(os)); ctx->add(TranslateItem::Var, varName, var); ctx->addSeries(cast(loop->getBody())); transform(stmt->getSuite()); ctx->popSeries(); result = loop; } void TranslateVisitor::visit(IfStmt *stmt) { auto cond = transform(stmt->getCond()); auto trueSeries = make(stmt, "ifstmt_true"); ctx->addSeries(trueSeries); transform(stmt->getIf()); 
ctx->popSeries(); ir::SeriesFlow *falseSeries = nullptr; if (stmt->getElse()) { falseSeries = make(stmt, "ifstmt_false"); ctx->addSeries(falseSeries); transform(stmt->getElse()); ctx->popSeries(); } result = make(stmt, cond, trueSeries, falseSeries); } void TranslateVisitor::visit(TryStmt *stmt) { auto *bodySeries = make(stmt, "body"); ctx->addSeries(bodySeries); transform(stmt->getSuite()); ctx->popSeries(); ir::SeriesFlow *finallySeries = make(stmt, "finally"); if (stmt->getFinally()) { ctx->addSeries(finallySeries); transform(stmt->getFinally()); ctx->popSeries(); } ir::SeriesFlow *elseSeries = nullptr; if (stmt->getElse()) { elseSeries = make(stmt, "else"); ctx->addSeries(elseSeries); transform(stmt->getElse()); ctx->popSeries(); } auto *tc = make(stmt, bodySeries, finallySeries, elseSeries); for (auto *c : *stmt) { auto *catchBody = make(stmt, "catch"); auto *excType = c->getException() ? getType(TypecheckVisitor(ctx->cache->typeCtx) .extractType(c->getException()->getType())) : nullptr; ir::Var *catchVar = nullptr; if (!c->getVar().empty()) { if (!ctx->find(c->getVar()) || !c->hasAttribute(Attr::ExprDominated)) { catchVar = make(stmt, excType, false, false, false, c->getVar()); } else { catchVar = ctx->find(c->getVar())->getVar(); } ctx->add(TranslateItem::Var, c->getVar(), catchVar); ctx->getBase()->push_back(catchVar); } ctx->addSeries(catchBody); transform(c->getSuite()); ctx->popSeries(); tc->push_back(ir::TryCatchFlow::Catch(catchBody, excType, catchVar)); } result = tc; } void TranslateVisitor::visit(ThrowStmt *stmt) { result = make(stmt, stmt->getExpr() ? transform(stmt->getExpr()) : nullptr); } void TranslateVisitor::visit(FunctionStmt *stmt) { // Process all realizations. transformFunctionRealizations(stmt->getName(), stmt->hasAttribute(Attr::LLVM)); } void TranslateVisitor::visit(ClassStmt *stmt) { // Nothing to see here, as all type handles are already generated. // Methods will be handled by FunctionStmt visitor. 
} /************************************************************************************/ codon::ir::types::Type *TranslateVisitor::getType(types::Type *t) const { seqassert(t && t->getClass(), "not a class: {}", t ? t->debugString(2) : "-"); std::string name = t->getClass()->ClassType::realizedName(); auto i = ctx->find(name); seqassert(i, "type {} not realized: {}", t->debugString(2), name); seqassert(i->getType(), "type {} not IR-realized: {}", t->debugString(2), name); return i->getType(); } void TranslateVisitor::transformFunctionRealizations(const std::string &name, bool isLLVM) { for (auto &real : ctx->cache->functions[name].realizations) { if (!in(ctx->cache->pendingRealizations, make_pair(name, real.first))) continue; ctx->cache->pendingRealizations.erase(make_pair(name, real.first)); LOG_TYPECHECK("[translate] generating fn {}", real.first); real.second->ir->setSrcInfo(getSrcInfo()); const auto &ast = real.second->ast; seqassert(ast, "AST not set for {}", real.first); if (!isLLVM) transformFunction(real.second->type.get(), ast, real.second->ir); else transformLLVMFunction(real.second->type.get(), ast, real.second->ir); } } void TranslateVisitor::transformFunction(const types::FuncType *type, FunctionStmt *ast, ir::Func *func) { std::vector names; std::vector indices; for (int i = 0, j = 0; i < ast->size(); i++) if ((*ast)[i].isValue()) { if (!(*type)[j]->getFunc()) { names.push_back(ctx->cache->rev((*ast)[i].name)); indices.push_back(i); } j++; } if (ast->hasAttribute(Attr::CVarArg)) { names.pop_back(); indices.pop_back(); } // TODO: refactor IR attribute API std::unordered_map attr; if (ast->hasAttribute(Attr::FunctionAttributes)) attr = ast->getAttribute(Attr::FunctionAttributes)->attributes; attr[".module"] = ast->getAttribute(Attr::Module)->value; func->setAttribute(std::make_unique(attr)); for (int i = 0; i < names.size(); i++) func->getArgVar(names[i])->setSrcInfo((*ast)[indices[i]].getSrcInfo()); // 
func->setUnmangledName(ctx->cache->reverseIdentifierLookup[type->ast->name]); if (!ast->hasAttribute(Attr::C) && !ast->hasAttribute(Attr::Internal)) { ctx->addBlock(); for (auto i = 0; i < names.size(); i++) ctx->add(TranslateItem::Var, (*ast)[indices[i]].name, func->getArgVar(names[i])); auto body = make(ast, "body"); ctx->bases.push_back(cast(func)); ctx->addSeries(body); transform(ast->getSuite()); ctx->popSeries(); ctx->bases.pop_back(); cast(func)->setBody(body); ctx->popBlock(); } if (ast->isAsync()) func->setAsync(); } void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt *ast, ir::Func *func) const { std::vector names; std::vector indices; for (int i = 0, j = 1; i < ast->size(); i++) if ((*ast)[i].isValue()) { names.push_back(ctx->cache->reverseIdentifierLookup[(*ast)[i].name]); indices.push_back(i); j++; } auto f = cast(func); seqassert(f, "not a function"); std::unordered_map attr; if (ast->hasAttribute(Attr::FunctionAttributes)) attr = ast->getAttribute(Attr::FunctionAttributes)->attributes; attr[".module"] = ast->getAttribute(Attr::Module)->value; func->setAttribute(std::make_unique(attr)); for (int i = 0; i < names.size(); i++) func->getArgVar(names[i])->setSrcInfo((*ast)[indices[i]].getSrcInfo()); seqassert( ast->getSuite()->firstInBlock() && cast(ast->getSuite()->firstInBlock()) && cast(cast(ast->getSuite()->firstInBlock())->getExpr()), "LLVM function does not begin with a string"); std::istringstream sin( cast(cast(ast->getSuite()->firstInBlock())->getExpr()) ->getValue()); std::vector literals; auto ss = cast(ast->getSuite()); for (int i = 1; i < ss->size(); i++) { if (auto sti = cast((*ss)[i])->getExpr()->getType()->getIntStatic()) { literals.emplace_back(sti->value); } else if (auto sts = cast((*ss)[i])->getExpr()->getType()->getStrStatic()) { literals.emplace_back(sts->value); } else { seqassert(cast((*ss)[i])->getExpr()->getType(), "invalid LLVM type argument: {}", (*ss)[i]->toString(0)); literals.emplace_back( 
getType(TypecheckVisitor(ctx->cache->typeCtx) .extractType(cast((*ss)[i])->getExpr()->getType()))); } } bool isDeclare = true; std::string declare; std::vector lines; for (std::string l; getline(sin, l);) { std::string lp = l; ltrim(lp); rtrim(lp); // Extract declares and constants. if (isDeclare && !startswith(lp, "declare ") && !startswith(lp, "@")) { bool isConst = lp.find("private constant") != std::string::npos; if (!isConst) { isDeclare = false; if (!lp.empty() && lp.back() != ':') lines.emplace_back("entry:"); } } if (isDeclare) declare += lp + "\n"; else lines.emplace_back(l); } f->setLLVMBody(join(lines, "\n")); f->setLLVMDeclarations(declare); f->setLLVMLiterals(literals); // func->setUnmangledName(ctx->cache->reverseIdentifierLookup[type->ast->name]); } } // namespace codon::ast ================================================ FILE: codon/parser/visitors/translate/translate.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include "codon/cir/cir.h" #include "codon/parser/ast.h" #include "codon/parser/visitors/translate/translate_ctx.h" #include "codon/parser/visitors/visitor.h" namespace codon::ast { class TranslateVisitor : public CallbackASTVisitor { std::shared_ptr ctx; ir::Value *result; public: explicit TranslateVisitor(std::shared_ptr ctx); static codon::ir::Func *apply(Cache *cache, Stmt *stmts); void translateStmts(Stmt *stmts) const; ir::Value *transform(Expr *expr) override; ir::Value *transform(Stmt *stmt) override; void initializeGlobals() const; private: void defaultVisit(Expr *expr) override; void defaultVisit(Stmt *expr) override; public: void visit(NoneExpr *) override; void visit(BoolExpr *) override; void visit(IntExpr *) override; void visit(FloatExpr *) override; void visit(StringExpr *) override; void visit(IdExpr *) override; void visit(IfExpr *) override; void visit(GeneratorExpr *) override; void visit(CallExpr *) override; void visit(DotExpr *) override; 
void visit(YieldExpr *) override; void visit(StmtExpr *) override; void visit(PipeExpr *) override; void visit(AwaitExpr *) override; void visit(SuiteStmt *) override; void visit(BreakStmt *) override; void visit(ContinueStmt *) override; void visit(ExprStmt *) override; void visit(AssignStmt *) override; void visit(AssignMemberStmt *) override; void visit(ReturnStmt *) override; void visit(YieldStmt *) override; void visit(WhileStmt *) override; void visit(ForStmt *) override; void visit(IfStmt *) override; void visit(TryStmt *) override; void visit(ThrowStmt *) override; void visit(FunctionStmt *) override; void visit(ClassStmt *) override; void visit(CommentStmt *) override {} void visit(DirectiveStmt *) override {} private: ir::types::Type *getType(types::Type *t) const; void transformFunctionRealizations(const std::string &name, bool isLLVM); void transformFunction(const types::FuncType *type, FunctionStmt *ast, ir::Func *func); void transformLLVMFunction(types::FuncType *type, FunctionStmt *ast, ir::Func *func) const; template ValueType *make(Args &&...args) { auto *ret = ctx->getModule()->N(std::forward(args)...); return ret; } }; } // namespace codon::ast ================================================ FILE: codon/parser/visitors/translate/translate_ctx.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "translate_ctx.h"

#include
#include

#include "codon/parser/common.h"
#include "codon/parser/ctx.h"
#include "codon/parser/visitors/translate/translate.h"
#include "codon/parser/visitors/typecheck/ctx.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

namespace codon::ast {

TranslateContext::TranslateContext(Cache *cache) : Context(""), cache(cache) {}

/// Resolve `name` to a translation item.
/// Items registered in this context win. Otherwise the typechecking context is
/// consulted and a fresh item is built for a realizable type or function.
/// @return the item, or nullptr if `name` is unknown or its class realization is
///         missing from the cache.
std::shared_ptr TranslateContext::find(const std::string &name) const {
  if (auto local = Context::find(name))
    return local;

  auto tcItem = cache->typeCtx->find(name);
  if (!tcItem)
    return nullptr;

  if (tcItem->isType() && tcItem->type->canRealize()) {
    auto typ = tcItem->getType();
    if (name != typ->realizedName()) // type prefix
      typ = TypecheckVisitor(cache->typeCtx).extractType(typ);
    auto clsName = typ->getClass()->name;
    const bool realized = in(cache->classes, clsName) &&
                          in(cache->classes[clsName].realizations, name);
    if (!realized)
      return nullptr;
    auto item = std::make_shared(TranslateItem::Type, bases[0]);
    item->handle.type = cache->classes[clsName].realizations[name]->ir;
    return item;
  }

  if (tcItem->type->getFunc() && tcItem->type->canRealize()) {
    auto item = std::make_shared(TranslateItem::Func, bases[0]);
    seqassertn(
        in(cache->functions, tcItem->type->getFunc()->ast->getName()) &&
            in(cache->functions[tcItem->type->getFunc()->ast->getName()].realizations,
               name),
        "cannot find function realization {}", name);
    item->handle.func =
        cache->functions[tcItem->type->getFunc()->ast->getName()].realizations[name]->ir;
    return item;
  }

  return nullptr;
}

/// Same as find(), but asserts that the lookup succeeded.
std::shared_ptr TranslateContext::forceFind(const std::string &name) const {
  auto item = find(name);
  seqassertn(item, "cannot find '{}'", name);
  return item;
}

/// Register `type` under `name` as an item of the given kind, owned by the current
/// base function. The pointer is stored in the handle slot that matches `kind`.
std::shared_ptr TranslateContext::add(TranslateItem::Kind kind,
                                      const std::string &name, void *type) {
  auto item = std::make_shared(kind, getBase());
  switch (kind) {
  case TranslateItem::Var:
    item->handle.var = static_cast(type);
    break;
  case TranslateItem::Func:
    item->handle.func = static_cast(type);
    break;
  default:
    item->handle.type = static_cast(type);
    break;
  }
  add(name, item);
  return item;
}

void TranslateContext::addSeries(codon::ir::SeriesFlow *s) { series.push_back(s); }

void TranslateContext::popSeries() { series.pop_back(); }

/// The module is obtained from the outermost base function.
codon::ir::Module *TranslateContext::getModule() const {
  return dynamic_cast(bases[0]->getModule());
}

/// Innermost function currently being translated.
codon::ir::BodiedFunc *TranslateContext::getBase() const { return bases.back(); }

/// Innermost series currently being filled.
codon::ir::SeriesFlow *TranslateContext::getSeries() const { return series.back(); }

} // namespace codon::ast

================================================
FILE: codon/parser/visitors/translate/translate_ctx.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

#include
#include
#include
#include

#include "codon/cir/cir.h"
#include "codon/cir/types/types.h"
#include "codon/parser/cache.h"
#include "codon/parser/ctx.h"

namespace codon::ast {

/**
 * One entry of the translation context.
 * An entry names exactly one of: a function, a type, or a variable. `kind` tells
 * which member of `handle` is meaningful.
 */
struct TranslateItem {
  enum Kind { Func, Type, Var } kind;
  /// IR object the entry refers to; interpret according to `kind`.
  union {
    codon::ir::Var *var;
    codon::ir::Func *func;
    codon::ir::types::Type *type;
  } handle;
  /// Function the entry was registered in.
  codon::ir::BodiedFunc *base;

  TranslateItem(Kind k, codon::ir::BodiedFunc *base)
      : kind(k), handle{nullptr}, base(base) {}

  const codon::ir::BodiedFunc *getBase() const { return base; }

  /// @return the function handle, or nullptr if this entry is not a function.
  codon::ir::Func *getFunc() const {
    if (kind != Func)
      return nullptr;
    return handle.func;
  }

  /// @return the type handle, or nullptr if this entry is not a type.
  codon::ir::types::Type *getType() const {
    if (kind != Type)
      return nullptr;
    return handle.type;
  }

  /// @return the variable handle, or nullptr if this entry is not a variable.
  codon::ir::Var *getVar() const {
    if (kind != Var)
      return nullptr;
    return handle.var;
  }
};

/**
 * Symbol table used while translating the AST to IR.
 */
struct TranslateContext : public Context {
  /// Shared compilation cache (not owned).
  Cache *cache;
  /// Functions currently being translated, outermost first.
  std::vector bases;
  /// IR series currently being filled, outermost first.
  std::vector series;
  /// Stack of sequence items for attribute initialization.
  std::vector>> seqItems;

public:
  TranslateContext(Cache *cache);

  using Context::add;
  /// Register an IR object of the given kind under `name`.
  std::shared_ptr add(TranslateItem::Kind kind, const std::string &name, void *type);
  std::shared_ptr find(const std::string &name) const override;
  std::shared_ptr forceFind(const std::string &name) const;

  /// Push a series onto the series stack.
  void addSeries(codon::ir::SeriesFlow *s);
  void popSeries();

public:
  codon::ir::Module *getModule() const;
  codon::ir::BodiedFunc *getBase() const;
  codon::ir::SeriesFlow *getSeries() const;
};

} // namespace codon::ast

================================================
FILE: codon/parser/visitors/typecheck/access.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include
#include

#include "codon/parser/ast.h"
#include "codon/parser/cache.h"
#include "codon/parser/common.h"
#include "codon/parser/match.h"
#include "codon/parser/visitors/scoping/scoping.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

using namespace codon::error;
using namespace codon::matcher;

namespace codon::ast {

using namespace types;

/// Typecheck identifiers.
/// If an identifier is a static variable, evaluate it and replace it with its value
///   (e.g., change `N` to `IntExpr(16)`).
/// If the identifier of a generic is fully qualified, use its qualified name
///   (e.g., replace `Ptr` with `Ptr[byte]`).
void TypecheckVisitor::visit(IdExpr *expr) { auto val = ctx->find(expr->getValue(), getTime()); if (!val) { E(Error::ID_NOT_FOUND, expr, expr->getValue()); } // If this is an overload, use the dispatch function if (isUnbound(expr) && hasOverloads(val->getName())) { val = ctx->forceFind(getDispatch(val->getName())->getFuncName()); } // If we are accessing an outside variable, capture it or raise an error checkCapture(val); // Replace the variable with its canonical name expr->value = val->getName(); // Set up type unify(expr->getType(), instantiateType(val->getType())); if (auto f = expr->getType()->getFunc()) { expr->value = f->getFuncName(); // resolve overloads } // Realize a type or a function if possible and replace the identifier with // a qualified identifier or a static expression (e.g., `foo` -> `foo[int]`) if (expr->getType()->canRealize()) { if (auto s = expr->getType()->getStatic()) { resultExpr = transform(s->getStaticExpr()); return; } if (!val->isVar()) { if (!(expr->hasAttribute(Attr::ExprDoNotRealize) && expr->getType()->getFunc())) { if (auto r = realize(expr->getType())) { expr->value = r->realizedName(); expr->setDone(); } } } else { realize(expr->getType()); expr->setDone(); } } // If this identifier needs __used__ checks (see @c ScopeVisitor), add them if (expr->hasAttribute(Attr::ExprDominatedUndefCheck)) { auto controlVar = fmt::format("{}{}", getUnmangledName(val->canonicalName), VAR_USED_SUFFIX); if (ctx->find(controlVar, getTime())) { auto checkStmt = N(N( N(N("__internal__"), "undef"), N(controlVar), N(getUnmangledName(val->canonicalName)))); expr->eraseAttribute(Attr::ExprDominatedUndefCheck); resultExpr = transform(N(checkStmt, expr)); } } } /// Transform a dot expression. Select the best method overload if possible. 
/// @example /// `obj.__class__` -> `type(obj)` /// `cls.__name__` -> `"class"` (same for functions) /// `obj.method` -> `cls.method(obj, ...)` or /// `cls.method(obj)` if method has `@property` attribute /// `obj.member` -> see @c getClassMember /// @return nullptr if no transformation was made void TypecheckVisitor::visit(DotExpr *expr) { // Check if this is being called from CallExpr (e.g., foo.bar(...)) // and mark it as such (to inline useless partial expression) auto *parentCall = cast(ctx->getParentNode()); if (parentCall && !parentCall->hasAttribute(Attr::ParentCallExpr)) parentCall = nullptr; // Flatten imports: // `a.b.c` -> canonical name of `c` in `a.b` if `a.b` is an import // `a.B.c` -> canonical name of `c` in class `a.B` // `python.foo` -> internal.python._get_identifier("foo") std::vector chain; Expr *head = expr; for (; cast(head); head = cast(head)->getExpr()) chain.push_back(cast(head)->getMember()); Expr *final = expr; if (auto id = cast(head)) { // Case: a.bar.baz chain.push_back(id->getValue()); std::ranges::reverse(chain); auto [pos, val] = getImport(chain); if (!val) { // Python capture seqassert(ctx->getBase()->pyCaptures, "unexpected py capture"); ctx->getBase()->pyCaptures->insert(chain[0]); final = N(N("__pyenv__"), N(chain[0])); } else if (val->getModule() == "std.python" && ctx->getModule() != val->getModule()) { // Import from python (e.g., pyobj.foo) final = transform(N( N(N(N("internal"), "python"), "_get_identifier"), N(chain[pos++]))); } else if (val->getModule() == ctx->getModule() && pos == 1) { final = transform(N(chain[0]), true); } else { final = N(val->canonicalName); } while (pos < chain.size()) final = N(final, chain[pos++]); } if (auto dot = cast(final)) { expr->expr = dot->getExpr(); expr->member = dot->getMember(); } else { resultExpr = transform(final); return; } // Special case: obj.__class__ if (expr->getMember() == "__class__") { /// TODO: prevent cls.__class__ and type(cls) resultExpr = transform(N(N(TYPE_TYPE), 
expr->getExpr())); return; } expr->expr = transform(expr->getExpr()); bool hasSide = hasSideEffect(expr->getExpr()); auto wrapSide = [&](Expr *e) -> Expr * { if (hasSide) { return transform(N(N(expr->getExpr()), e)); } return transform(e); }; // Special case: fn.__name__ // Should go before cls.__name__ to allow printing generic functions // NOTE: NO SIDE EFFECTS! (because of generic function support!) if (extractType(expr->getExpr())->getFunc() && expr->getMember() == "__name__") { resultExpr = N(extractType(expr->getExpr())->prettyString()); return; } if (expr->getExpr()->getType()->getPartial() && expr->getMember() == "__name__") { resultExpr = wrapSide( N(expr->getExpr()->getType()->getPartial()->prettyString())); return; } // Special case: fn.__llvm_name__ or obj.__llvm_name__ if (expr->getMember() == "__llvm_name__") { if (realize(expr->getExpr()->getType())) resultExpr = wrapSide(N(expr->getExpr()->getType()->realizedName())); return; } // Special case: cls.__name__ if (isTypeExpr(expr->getExpr()) && expr->getMember() == "__name__") { if (realize(expr->getExpr()->getType())) resultExpr = wrapSide(N(extractType(expr->getExpr())->prettyString())); return; } // Special case: cls.__mro__ if (isTypeExpr(expr->getExpr()) && expr->getMember() == "__mro__") { if (realize(expr->getExpr()->getType())) { auto t = extractType(expr->getExpr())->getClass(); auto bases = getRTTISuperTypes(t); std::vector items; for (size_t i = 1; i < bases.size(); i++) items.push_back(N(bases[i]->realizedName())); resultExpr = wrapSide(N(items)); } return; } if (isTypeExpr(expr->getExpr()) && expr->getMember() == "__repr__") { resultExpr = transform(N( N(getMangledFunc("std.internal.types.type", "__type_repr__")), expr->getExpr(), N(EllipsisExpr::PARTIAL))); return; } // Special case: expr.__is_static__ if (expr->getMember() == "__is_static__") { if (expr->getExpr()->isDone()) resultExpr = wrapSide( N(static_cast(expr->getExpr()->getType()->getStatic()))); return; } // Special case: 
cls.__id__ if (isTypeExpr(expr->getExpr()) && expr->getMember() == "__id__") { if (auto c = realize(extractType(expr->getExpr()))) resultExpr = wrapSide(N(getClassRealization(c)->id)); return; } // Ensure that the type is known (otherwise wait until it becomes known) auto typ = extractClassType(expr->getExpr()); if (!typ) return; // Check if this is a method or member access while (true) { auto methods = findMethod(typ, expr->getMember()); if (methods.empty()) resultExpr = getClassMember(expr); // If the expression changed during the @c getClassMember (e.g., optional unwrap), // keep going further to be able to select the appropriate method or member auto oldExpr = expr->getExpr(); if (!resultExpr && expr->getExpr() != oldExpr) { typ = extractClassType(expr->getExpr()); if (!typ) return; // delay typechecking continue; } if (!methods.empty()) { // If a method is ambiguous use dispatch auto bestMethod = methods.size() > 1 ? getDispatch(getRootName(methods.front())) : methods.front(); Expr *e = N(bestMethod->getFuncName()); e->setType(instantiateType(bestMethod, typ)); if (isTypeExpr(expr->getExpr())) { // Static access: `cls.method` unify(expr->getType(), e->getType()); resultExpr = wrapSide(e); return; } auto vt = expr->getExpr()->getType(); if (vt->is("Super")) vt = extractClassGeneric(vt); auto cls = getClass(vt); bool isVirtual = cls && cls->rtti && static_cast( in(cls->virtuals, getUnmangledName(bestMethod->ast->getName()))) && !isDispatch(bestMethod) && !bestMethod->ast->hasAttribute(Attr::StaticMethod) && !bestMethod->ast->hasAttribute(Attr::Property); if (parentCall && !bestMethod->ast->hasAttribute(Attr::StaticMethod) && !bestMethod->ast->hasAttribute(Attr::Property) && !isVirtual) { // Instance access: `obj.method` from the call // Modify the call to push `self` to the front of the argument list. // Avoids creating partial functions. 
parentCall->items.insert(parentCall->items.begin(), expr->getExpr());
        unify(expr->getType(), e->getType());
        resultExpr = transform(e);
      } else {
        if (isVirtual) {
          Expr *id = nullptr, *slf = nullptr;
          if (expr->getExpr()->getType()->is("Super")) {
            id = N(N(expr->getExpr(), "__T__"), "__id__");
            slf = N(expr->getExpr(), "__obj__");
          } else {
            id = N(0);
            slf = expr->getExpr();
          }
          resultExpr = transform(
              N(N(getMangledMethod("std.internal.core", "RTTIType", "_thunk_dispatch")),
                std::vector{
                    CallArg{"slf", slf}, CallArg{"cls_id", id},
                    CallArg{"F", N(bestMethod->ast->name)},
                    CallArg{"", N(EllipsisExpr::PARTIAL)}}));
          return;
        }

        // Instance access: `obj.method`
        // Transform y.method to a partial call `type(y).method(y, ...)`
        std::vector methodArgs;
        // Do not add self if a method is marked with @staticmethod
        if (!bestMethod->ast->hasAttribute(Attr::StaticMethod)) {
          methodArgs.emplace_back(expr->getExpr());
        }
        // If a method is marked with @property, just call it directly
        if (!bestMethod->ast->hasAttribute(Attr::Property)) {
          methodArgs.emplace_back(N(EllipsisExpr::PARTIAL));
        }
        resultExpr = transform(N(e, methodArgs));
      }
    }
    break;
  }
}

/// Access identifiers outside of the current function/class scope.
/// Either use them as-is (globals), capture them if allowed (nonlocals),
/// or raise an error.
/// @param val context item that an identifier resolved to.
/// Returns silently when the access is allowed (or when a global gets registered);
/// otherwise reports ID_CANNOT_CAPTURE or ID_NONLOCAL through E().
void TypecheckVisitor::checkCapture(const TypeContext::Item &val) const {
  // Items that are not "outer" need no capture checks.
  if (!ctx->isOuter(val))
    return;
  // Non-generic types and functions are accepted as-is.
  if ((val->isType() && !val->isGeneric()) || val->isFunc())
    return;

  // Ensure that outer variables can be captured (i.e., do not cross no-capture
  // boundary). Example:
  // def foo():
  //   x = 1
  //   class T:      # <- boundary (classes cannot capture locals)
  //     t: int = x  # x cannot be accessed
  //     def bar():  # <- another boundary
  //                 # (class methods cannot capture locals except class generics)
  //       print(x)  # x cannot be accessed
  bool crossCaptureBoundary = false;
  // Generic that belongs to the current base itself.
  bool localGeneric = val->isGeneric() && val->getBaseName() == ctx->getBaseName();
  // Generic that belongs to the type base directly enclosing the current
  // (non-type) base.
  bool parentClassGeneric =
      val->isGeneric() && !ctx->getBase()->isType() &&
      (ctx->bases.size() > 1 && ctx->bases[ctx->bases.size() - 2].isType() &&
       ctx->bases[ctx->bases.size() - 2].name == val->getBaseName());
  // Walk the bases from the innermost outwards until the base that owns `val`;
  // every base passed on the way counts as a crossed boundary unless `val` is one
  // of the two generic cases above.
  auto i = ctx->bases.size();
  while (i-- > 0) {
    if (ctx->bases[i].name == val->getBaseName())
      break;
    if (!localGeneric && !parentClassGeneric)
      crossCaptureBoundary = true;
  }

  // Mark methods (class functions that access class generics)
  if (parentClassGeneric)
    ctx->getBase()->func->setAttribute(Attr::Method);

  // Ignore generics
  if (parentClassGeneric || localGeneric)
    return;

  // Case: a global variable that has not been marked with `global` statement
  if (val->isVar() && val->getBaseName().empty() && val->scope.size() == 1) {
    registerGlobal(val->getName());
    return;
  }

  // Not in module; probably some capture
  if (val->getModule() != ctx->getModule())
    return;

  // Check if a real variable (not a static) is defined outside the current scope
  if (crossCaptureBoundary)
    E(Error::ID_CANNOT_CAPTURE, getSrcInfo(), getUserFacingName(val->getName()));

  // Case: a nonlocal variable that has not been marked with `nonlocal` statement
  //       and capturing is *not* enabled
  E(Error::ID_NONLOCAL, getSrcInfo(), getUserFacingName(val->getName()));
}

/// Check if a chain (a.b.c.d...) contains an import or a class prefix.
/// @param chain the dot-separated components of the access (`a.b.c` -> {a, b, c}).
/// @return a pair of (number of leading chain components that were resolved,
///         the context item found for that prefix). The item is nullptr in the
///         `pyCaptures` fallback.
/// Reports IMPORT_NO_MODULE / IMPORT_NO_NAME through E() when nothing matches.
std::pair TypecheckVisitor::getImport(const std::vector &chain) {
  size_t importEnd = 0;
  std::string importName;

  // 1. Find the longest prefix that corresponds to the existing import
  //    (e.g., `a.b.c.d` -> `a.b.c` if there is `import a.b.c`)
  TypeContext::Item val = nullptr, importVal = nullptr;
  for (auto i = chain.size(); i-- > 0;) {
    // Import prefixes are looked up with "/" as the separator.
    auto name = join(chain, "/", 0, i + 1);
    val = ctx->find(name, getTime());
    if (val && val->type->is("Import") && startswith(val->getName(), "%_import_")) {
      importName = getStrLiteral(val->type.get());
      importEnd = i + 1;
      importVal = val;
      break;
    }
  }

  // Case: the whole chain is import itself
  if (importEnd == chain.size())
    return {importEnd, val};

  // Find the longest prefix that corresponds to the existing class
  // (e.g., `a.b.c` -> `a.b` if there is `class a: class b:`)
  std::string itemName;
  size_t itemEnd = 0;
  // Look the remainder up in the imported module's context if an import matched,
  // and in the current context otherwise.
  auto ictx = importName.empty() ? ctx : getImport(importName)->ctx;
  for (auto i = chain.size(); i-- > importEnd;) {
    if (ictx->getModule() == "std.python" && importEnd < chain.size()) {
      // Special case: importing from Python.
      // Fake TypecheckItem that indicates std.python access
      val = std::make_shared(
          "", "", ictx->getModule(), TypecheckVisitor(ictx).instantiateUnbound());
      return {importEnd, val};
    } else {
      auto key = join(chain, ".", importEnd, i + 1);
      // check only globals for imports!
      val = ictx->find(key, 0, importName.empty() ? nullptr : "");
      if (val && i + 1 != chain.size() && val->getType()->is("Import") &&
          startswith(val->getName(), "%_import_")) {
        // The prefix is itself an import (but not the whole chain): switch to that
        // module's context and restart the scan from the end of the chain.
        importName = getStrLiteral(val->getType());
        importEnd = i + 1;
        importVal = val;
        ictx = getImport(importName)->ctx;
        i = chain.size();
        continue;
      }
      bool isOverload = val && val->isFunc() && hasOverloads(val->canonicalName);
      if (isOverload && importEnd == i) { // top-level overload
        itemName = val->canonicalName, itemEnd = i + 1;
        break;
      }
      // Class member
      if (val && !isOverload &&
          (importName.empty() || val->isType() || !val->isConditional())) {
        itemName = val->canonicalName, itemEnd = i + 1;
        break;
      }
      // Resolve the identifier from the import
      if (auto imp = ctx->find("Import")) {
        auto t = extractClassType(imp->getType());
        if (findMember(t, key))
          return {importEnd, importVal};
        if (!findMethod(t, key).empty())
          return {importEnd, importVal};
      }
    }
  }
  if (itemName.empty() && importName.empty()) {
    // Nothing matched at all.
    if (ctx->getBase()->pyCaptures)
      return {1, nullptr};
    E(Error::IMPORT_NO_MODULE, getSrcInfo(), chain[importEnd]);
  } else if (itemName.empty()) {
    // An import matched, but the next component was not found inside it.
    auto import = getImport(importName);
    if (!ctx->isStdlibLoading && endswith(importName, "__init__.codon")) {
      // Special case: subimport is not yet loaded
      //   (e.g., import a; a.b.x where a.b is a module as well)
      if (auto file = getImportFile(ctx->cache, chain[importEnd], importName)) {
        // Auto-load support
        Stmt *s = N(N(N(chain[importEnd]), nullptr));
        if (auto err = ScopingVisitor::apply(ctx->cache, s))
          throw exc::ParserException(std::move(err));
        s = TypecheckVisitor(import->ctx, preamble).transform(s);
        prependStmts->push_back(s);
        // Retry the whole lookup now that the sub-import has been processed.
        return getImport(chain);
      }
    }
    E(Error::IMPORT_NO_NAME, getSrcInfo(), chain[importEnd], import->name);
  }
  importEnd = itemEnd;
  return {importEnd, val};
}

/// Find or generate an overload dispatch function for a given overload.
/// Dispatch functions ensure that a function call is being routed to the correct
/// overload even when dealing with partial functions and decorators.
/// @example /// This is how dispatch looks like: /// ```def foo:dispatch(*args, **kwargs): /// return foo(*args, **kwargs)``` types::FuncType *TypecheckVisitor::getDispatch(const std::string &fn) { auto &overloads = ctx->cache->overloads[fn]; // Single overload: just return it if (overloads.size() == 1) return ctx->forceFind(overloads.front())->type->getFunc(); // Check if dispatch exists for (auto &m : overloads) if (isDispatch(getFunction(m)->ast)) return getFunction(m)->getType(); // Dispatch does not exist. Generate it auto name = fmt::format("{}{}", fn, FN_DISPATCH_SUFFIX); Expr *root; // Root function name used for calling auto ofn = getFunction(overloads[0]); auto aa = ofn->ast->getAttribute(Attr::ParentClass); if (aa) root = N(N(aa->value), getUnmangledName(fn)); else root = N(fn); root = N(root, N(N("args")), N(N("kwargs"))); auto nar = ctx->generateCanonicalName("args"); auto nkw = ctx->generateCanonicalName("kwargs"); auto ast = N(name, nullptr, std::vector{Param("*" + nar), Param("**" + nkw)}, N(N(root))); ast->setAttribute(Attr::AutoGenerated); ast->setAttribute(Attr::Module, ctx->moduleName.path); if (aa) ast->setAttribute(Attr::ParentClass, aa->value); ctx->cache->reverseIdentifierLookup[name] = getUnmangledName(fn); auto baseType = getFuncTypeBase(2); auto typ = std::make_shared(baseType.get(), ast); /// Make sure that parent is set so that the parent type can be passed to the inner /// call /// (e.g., A[B].foo -> A.foo.dispatch() { A[B].foo() }) typ->funcParent = ofn->type->funcParent; typ = std::static_pointer_cast(typ->generalize(ctx->typecheckLevel - 1)); ctx->addFunc(name, name, typ); overloads.insert(overloads.begin(), name); ctx->cache->functions[name] = Cache::Function{"", fn, ast, typ}; ast->setDone(); return typ.get(); // stored in Cache::Function, hence not destroyed } /// Find a class member. 
/// @example
///   `obj.GENERIC`     -> `GENERIC` (IdExpr with generic/static value)
///   `optional.member` -> `unwrap(optional).member`
///   `pyobj.member`    -> `pyobj._getattr("member")`
/// @param expr the member access to resolve; its type (and, in the optional case,
///             its inner expression) may be updated in place.
/// @return the expression that replaces `expr`, or nullptr when `expr` itself is
///         kept (typed in place, rewritten in place, or delayed).
/// Reports DOT_NO_ATTR through E() when no case applies.
Expr *TypecheckVisitor::getClassMember(DotExpr *expr) {
  auto typ = extractClassType(expr->getExpr());
  seqassert(typ, "not a class");

  // Case: object member access (`obj.member`)
  if (!isTypeExpr(expr->getExpr())) {
    if (auto member = findMember(typ, expr->getMember())) {
      unify(expr->getType(), instantiateType(member->getType(), typ));
      // If the member type is still not realizable, re-typecheck a clean copy of the
      // member's type expression with the class generics in scope.
      if (!expr->getType()->canRealize() && member->typeExpr) {
        unify(expr->getType(), extractType(withClassGenerics(typ, [&]() {
                return transform(clean_clone(member->typeExpr));
              })));
      }
      if (expr->getExpr()->isDone() && realize(expr->getType()))
        expr->setDone();
      return nullptr;
    }
  }

  // Helper for the cases below that drop the `obj.` part: when the object
  // expression has side effects it is kept in front of the replacement `e`.
  bool hasSide = hasSideEffect(expr->getExpr());
  auto wrapSide = [&](Expr *e) -> Expr * {
    if (hasSide)
      return transform(N(N(expr->getExpr()), e));
    return transform(e);
  };

  // Case: class variable (`Cls.var`)
  if (auto cls = getClass(typ))
    if (auto var = in(cls->classVars, expr->getMember())) {
      return wrapSide(N(*var));
    }

  // Case: special members
  // (member name -> name of the standard library type it is given)
  std::unordered_map specialMembers{
      {"__elemsize__", "int"}, {"__atomic__", "bool"}, {"__contents_atomic__", "bool"}};
  if (auto mtyp = in(specialMembers, expr->getMember())) {
    unify(expr->getType(), getStdLibType(*mtyp));
    if (expr->getExpr()->isDone() && realize(expr->getType()))
      expr->setDone();
    return nullptr;
  }
  // `__name__` is only handled here for type expressions.
  if (expr->getMember() == "__name__" && isTypeExpr(expr->getExpr())) {
    unify(expr->getType(), getStdLibType("str"));
    if (expr->getExpr()->isDone() && realize(expr->getType()))
      expr->setDone();
    return nullptr;
  }

  // Case: object generic access (`obj.T`)
  ClassType::Generic *generic = nullptr;
  for (auto &g : typ->generics)
    if (expr->getMember() == getUnmangledName(g.name)) {
      generic = &g;
      break;
    }
  if (generic) {
    if (generic->staticKind) {
      // Static generic: replace the access with the static's expression once its
      // type is realized.
      unify(expr->getType(), generic->getType());
      if (realize(expr->getType())) {
        return wrapSide(generic->type->getStatic()->getStaticExpr());
      }
    } else {
      // Type generic: replace the access with the realized type name.
      unify(expr->getType(), instantiateTypeVar(generic->getType()));
      if (realize(expr->getType()))
        return wrapSide(N(generic->getType()->realizedName()));
    }
    return nullptr; // not realizable yet
  }

  // Case: transform `optional.member` to `unwrap(optional).member`
  if (typ->is(TYPE_OPTIONAL)) {
    // Only the inner expression is rewritten; the caller sees the changed `expr`.
    expr->expr = transform(N(N(FN_OPTIONAL_UNWRAP), expr->getExpr()));
    return nullptr;
  }

  // Case: transform `pyobj.member` to `pyobj._getattr("member")`
  if (typ->is("pyobj")) {
    return transform(
        N(N(expr->getExpr(), "_getattr"), N(expr->getMember())));
  }

  // Case: transform `union.m` to `Union._member(union, "m", ...)`
  if (typ->getUnion()) {
    if (!typ->canRealize())
      return nullptr; // delay!
    return transform(N(
        N(N("Union"), "_member"),
        std::vector{CallArg{"union", expr->getExpr()},
                    CallArg{"member", N(expr->getMember())}}));
  }

  // Case: __getattr__ support. Ensure that only Literal[str] arguments are accepted.
  auto u = instantiateUnbound();
  u->staticKind = LiteralKind::String;
  if (auto m = findBestMethod(typ, "__getattr__", {typ, u.get()})) {
    if (m->funcGenerics.size() == 1 &&
        extractFuncGeneric(m)->getStaticKind() == LiteralKind::String) {
      return transform(
          N(N(expr->getExpr(), "__getattr__"), N(expr->getMember())));
    }
  }

  // For debugging purposes:
  findMethod(typ, expr->getMember());
  E(Error::DOT_NO_ATTR, expr, typ->prettyString(), expr->getMember());
  return nullptr;
}

} // namespace codon::ast

================================================
FILE: codon/parser/visitors/typecheck/assign.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
// NOTE(review): the header names of the two bare `#include` directives below are
// missing in this copy of the file.
#include
#include

#include "codon/cir/attribute.h"
#include "codon/parser/ast.h"
#include "codon/parser/cache.h"
#include "codon/parser/common.h"
#include "codon/parser/match.h"
#include "codon/parser/visitors/scoping/scoping.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

using namespace codon::error;
using namespace codon::matcher;

namespace codon::ast {

using namespace types;

/// Transform walrus (assignment) expression.
/// @example
///   `(var := expr)` -> `var = expr; var`
void TypecheckVisitor::visit(AssignExpr *expr) {
  // The assignment targets a clone of the variable; the original variable
  // expression is reused as the value of the whole expression.
  auto a = N(clone(expr->getVar()), expr->getExpr());
  a->cloneAttributesFrom(expr);
  resultExpr = transform(N(a, expr->getVar()));
}

/// Transform assignments. Handle dominated assignments, forward declarations, static
/// assignments and type/function aliases.
/// See @c transformAssignment and @c unpackAssignment for more details.
/// See @c wrapExpr for more examples.
void TypecheckVisitor::visit(AssignStmt *stmt) {
  // Multi-target left-hand sides are first unpacked into simple assignments.
  if (cast(stmt->lhs) || cast(stmt->lhs)) {
    resultStmt = transform(unpackAssignment(stmt->lhs, stmt->rhs));
    return;
  }

  // The assignment must update an existing binding (instead of creating a new one)
  // when it is an explicit/atomic update or its target is marked as dominated.
  bool mustUpdate = stmt->isUpdate() || stmt->isAtomicUpdate();
  mustUpdate |= stmt->getLhs()->hasAttribute(Attr::ExprDominated);
  mustUpdate |= stmt->getLhs()->hasAttribute(Attr::ExprDominatedUsed);
  if (cast(stmt->getRhs()) && cast(stmt->getRhs())->isInPlace()) {
    // Update case: a += b
    seqassert(!stmt->getTypeExpr(), "invalid AssignStmt {}", stmt->toString(0));
    mustUpdate = true;
  }

  resultStmt = transformAssignment(stmt, mustUpdate);
  if (stmt->getLhs()->hasAttribute(Attr::ExprDominatedUsed)) {
    // If this is dominated, set __used__ if needed
    stmt->getLhs()->eraseAttribute(Attr::ExprDominatedUsed);
    auto e = cast(stmt->getLhs());
    seqassert(e, "dominated bad assignment");
    // Follow the assignment with an update that sets the companion
    // `<name><VAR_USED_SUFFIX>` variable to true.
    resultStmt = transform(N(
        resultStmt,
        N(N(fmt::format("{}{}", getUnmangledName(e->getValue()), VAR_USED_SUFFIX)),
          N(true), nullptr, AssignStmt::UpdateMode::Update)));
  }
}

/// Transform deletions.
/// @example
///   `del a`    -> `a = type(a)()` and remove `a` from the context
///   `del a[x]` -> `a.__delitem__(x)`
void TypecheckVisitor::visit(DelStmt *stmt) {
  if (auto idx = cast(stmt->getExpr())) {
    resultStmt = N(transform(
        N(N(idx->getExpr(), "__delitem__"), idx->getIndex())));
  } else if (auto ei = cast(stmt->getExpr())) {
    // Assign `a` to `type(a)()` to mark it for deletion
    resultStmt = transform(N(
        stmt->getExpr(),
        N(N(N(TYPE_TYPE), clone(stmt->getExpr()))), nullptr,
        AssignStmt::Update));

    // Allow deletion *only* if the binding is dominated
    auto val = ctx->find(ei->getValue());
    if (!val)
      E(Error::ID_NOT_FOUND, ei, ei->getValue());
    // The binding must live in exactly the current scope.
    if (ctx->getScope() != val->scope)
      E(Error::DEL_NOT_ALLOWED, ei, ei->getValue());
    // Remove the name both as given and in its unmangled form.
    ctx->remove(ei->getValue());
    ctx->remove(getUnmangledName(ei->getValue()));
  } else {
    E(Error::DEL_INVALID, stmt);
  }
}

/// Unpack an assignment expression `lhs = rhs` into a list of simple assignment
/// expressions (e.g., `a = b`, `a.x = b`, or `a[x] = b`).
/// Handle Python unpacking rules.
/// @example
///   `(a, b) = c`     -> `a = c[0]; b = c[1]`
///   `a, b = c`       -> `a = c[0]; b = c[1]`
///   `[a, *x, b] = c` -> `a = c[0]; x = c[1:-1]; b = c[-1]`.
/// Non-trivial right-hand expressions are first stored in a temporary variable.
/// @example
///   `a, b = c, d + foo()` -> `assign = (c, d + foo()); a = assign[0]; b = assign[1]`.
/// Each assignment is unpacked recursively to allow cases like `a, (b, c) = d`.
/// @return a single assignment if `lhs` is not a multi-target pattern; otherwise a
///         block with the unpacked assignments (not yet typechecked).
Stmt *TypecheckVisitor::unpackAssignment(Expr *lhs, Expr *rhs) {
  std::vector leftSide;
  if (auto et = cast(lhs)) {
    // Case: (a, b) = ...
    for (auto *i : *et)
      leftSide.push_back(i);
  } else if (auto el = cast(lhs)) {
    // Case: [a, b] = ...
    for (auto *i : *el)
      leftSide.push_back(i);
  } else {
    return N(lhs, rhs);
  }

  // Prepare the right-side expression
  // Nodes created below pick up the source location of `rhs`; the previous location
  // is restored before returning.
  auto oldSrcInfo = getSrcInfo();
  setSrcInfo(rhs->getSrcInfo());
  auto *block = N();
  if (!cast(rhs)) {
    // Store any non-trivial right-side expression into a variable
    auto var = getTemporaryVar("assign");
    auto newRhs = N(var);
    block->addStmt(N(newRhs, ast::clone(rhs)));
    rhs = newRhs;
  }

  // Process assignments until the first StarExpr (if any)
  size_t st = 0;
  for (; st < leftSide.size(); st++) {
    if (cast(leftSide[st]))
      break;
    // Transformation: `leftSide_st = rhs[st]` where `st` is static integer
    auto rightSide = N(ast::clone(rhs), N(st));
    // Recursively process the assignment because of cases like `(a, (b, c)) = d)`
    auto ns = unpackAssignment(leftSide[st], rightSide);
    block->addStmt(ns);
  }
  // Process StarExpr (if any) and the assignments that follow it
  if (st < leftSide.size() && cast(leftSide[st])) {
    // StarExpr becomes SliceExpr (e.g., `b` in `(a, *b, c) = d` becomes
    // `list(d[1:-1])`)
    // NOTE(review): `leftSide.size()` and `st` are both size_t, so the negative
    // upper bound below is formed with unsigned arithmetic — confirm the node
    // constructor receives it as a signed value.
    Expr *rightSide = N(
        N(getMangledMethod("std.internal.types.array", "List", "as_list")),
        N(
            ast::clone(rhs),
            N(N(st),
              // this slice is either [st:] or [st:-lhs_len + st + 1]
              leftSide.size() == st + 1 ? nullptr
                                        : N(-leftSide.size() + st + 1),
              nullptr)));
    auto ns = unpackAssignment(cast(leftSide[st])->getExpr(), rightSide);
    block->addStmt(ns);
    st += 1;
    // Process remaining assignments. They will use negative indices (-1, -2 etc.)
    // because we do not know how big is StarExpr
    for (; st < leftSide.size(); st++) {
      // Only one starred target is allowed.
      if (cast(leftSide[st]))
        E(Error::ASSIGN_MULTI_STAR, leftSide[st]->getSrcInfo());
      rightSide = N(ast::clone(rhs),
                    N(-static_cast(leftSide.size() - st)));
      auto next = unpackAssignment(leftSide[st], rightSide);
      block->addStmt(next);
    }
  }
  setSrcInfo(oldSrcInfo);
  return block;
}

/// Transform simple assignments.
/// @example
///   `a[x] = b`    -> `a.__setitem__(x, b)`
///   `a.x = b`     -> @c AssignMemberStmt
///   `a: type` = b -> @c AssignStmt
///   `a = b`       -> @c AssignStmt or @c UpdateStmt (see below)
/// @param stmt the assignment to transform.
/// @param mustExist if set, the target must already be bound and the statement is
///                  handled as an update of that binding.
Stmt *TypecheckVisitor::transformAssignment(AssignStmt *stmt, bool mustExist) {
  if (auto idx = cast(stmt->getLhs())) {
    // Case: a[x] = b
    seqassert(!stmt->type, "unexpected type annotation");
    if (auto b = cast(stmt->getRhs())) {
      // Case: a[x] += b (inplace operator)
      if (mustExist && b->isInPlace() && !cast(b->getRhs())) {
        // The index goes through a temporary that is then used for both the read
        // and the `__setitem__` call.
        auto var = getTemporaryVar("assign");
        return transform(N(
            N(N(var), idx->getIndex()),
            N(N(
                N(idx->getExpr(), "__setitem__"), N(var),
                N(N(clone(idx->getExpr()), N(var)), b->getOp(),
                  b->getRhs(), true)))));
      }
    }
    return transform(N(N(N(idx->getExpr(), "__setitem__"),
                           idx->getIndex(), stmt->getRhs())));
  }

  if (auto dot = cast(stmt->getLhs())) {
    // Case: a.x = b
    dot->expr = transform(dot->getExpr(), true);
    return transform(N(
        dot->getExpr(), dot->member, transform(stmt->getRhs()), stmt->getTypeExpr()));
  }

  // Case: a (: t) = b
  auto e = cast(stmt->getLhs());
  if (!e) {
    E(Error::ASSIGN_INVALID, stmt->getLhs());
    return nullptr;
  }

  // New bindings that the enclosing function's binding table marks as nonlocal get
  // their annotation wrapped in (or set to) `Capsule`.
  if (ctx->inFunction() && stmt->getRhs() && !mustExist) {
    if (auto b = ctx->getBase()->func->getAttribute(Attr::Bindings)) {
      const auto &bd = b->bindings[e->getValue()];
      if (bd.isNonlocal) {
        if (stmt->getTypeExpr())
          stmt->type = N(N("Capsule"), stmt->getTypeExpr());
        else
          stmt->type = N("Capsule");
      }
    }
  }

  // A `ThreadLocal` annotation marks the variable as thread-local; the variable's
  // declared type is then the index of the annotation, if it has one.
  bool isThreadLocal = false;
  auto typeExpr = transformType(stmt->getTypeExpr());
  if (typeExpr &&
      extractType(typeExpr)->is(getMangledClass("std.threading", "ThreadLocal"))) {
    isThreadLocal = true;
    if (auto ti = cast(stmt->getTypeExpr())) {
      typeExpr = transformType(ti->getIndex());
    } else {
      typeExpr = nullptr;
    }
  }

  // Make sure that existing values that cannot be shadowed are only updated
  // mustExist |= val && !ctx->isOuter(val);
  if (mustExist) {
    auto val = ctx->find(e->getValue(), getTime());
    if (!val)
      E(Error::ASSIGN_LOCAL_REFERENCE, e, e->getValue(), e->getSrcInfo());

    auto s = N(stmt->getLhs(), stmt->getRhs(), typeExpr);
    // Updates inside a function carrying the Atomic attribute become atomic updates.
    if (!ctx->getBase()->isType() && ctx->getBase()->func->hasAttribute(Attr::Atomic))
      s->setAtomicUpdate();
    else
      s->setUpdate();
    if (auto u = transformUpdate(s))
      return u;
    else
      return s; // delay
  }

  stmt->rhs = transform(stmt->getRhs(), true);
  stmt->type = typeExpr;

  // Generate new canonical variable name for this assignment and add it to the context
  auto canonical = ctx->generateCanonicalName(e->getValue());
  auto assign = N(N(canonical), stmt->getRhs(), stmt->getTypeExpr());
  assign->getLhs()->cloneAttributesFrom(stmt->getLhs());
  // Reuse the type already attached to the original target, if any; otherwise start
  // from a fresh unbound type.
  assign->getLhs()->setType(stmt->getLhs()->getType()
                                ? stmt->getLhs()->getType()->shared_from_this()
                                : instantiateUnbound(assign->getLhs()->getSrcInfo()));
  if (isThreadLocal)
    assign->setThreadLocal();
  if (!stmt->getRhs() && !stmt->getTypeExpr() && ctx->find("NoneType")) {
    // All declarations that are not handled are to be marked with NoneType later on
    // (useful for dangling declarations that are not initialized afterwards due to
    //  static check)
    assign->getLhs()->getType()->getLink()->defaultType =
        getStdLibType("NoneType")->shared_from_this();
    ctx->getBase()->pendingDefaults[1].insert(
        assign->getLhs()->getType()->shared_from_this());
  }
  if (stmt->getTypeExpr()) {
    auto t = extractType(stmt->getTypeExpr());
    unify(assign->getLhs()->getType(),
          instantiateType(stmt->getTypeExpr()->getSrcInfo(), t));
  }
  // Register the new binding under the user-visible name.
  auto val = std::make_shared(
      canonical, ctx->getBaseName(), ctx->getModule(),
      assign->getLhs()->getType()->shared_from_this(), ctx->getScope());
  val->time = getTime();
  val->setSrcInfo(getSrcInfo());
  ctx->add(e->getValue(), val);
  ctx->addAlwaysVisible(val);

  if (assign->getRhs()) { // not a declaration!
    // Check if we can wrap the expression (e.g., `a: float = 3` -> `a = float(3)`)
    if (wrapExpr(&assign->rhs, assign->getLhs()->getType())) {
      unify(assign->getLhs()->getType(), assign->getRhs()->getType());
    }

    // Generalize non-variable types. That way we can support cases like:
    // `a = foo(x, ...); a(1); a('s')`
    if (!val->isVar()) {
      val->type = val->type->generalize(ctx->typecheckLevel - 1);
      // See capture_function_partial_proper_realize test
      assign->getLhs()->setType(val->type);
      assign->getRhs()->setType(val->type);
    }
  }

  // Mark declarations or generalized type/functions as done
  if ((!assign->getRhs() || assign->getRhs()->isDone()) &&
      assign->getLhs()->getType()->canRealize()) {
    if (auto r = realize(assign->getLhs()->getType())) {
      // overwrite types to remove dangling unbounds with some partials...
      assign->getLhs()->setType(r->shared_from_this());
      if (assign->getRhs())
        assign->getRhs()->setType(r->shared_from_this());
      assign->setDone();
    }
  } else if (assign->getRhs() && !val->isVar() && !val->type->hasUnbounds(false)) {
    assign->setDone();
  }

  // Register all toplevel variables as global in JIT mode
  // OR if they are in imported module (not toplevel)
  bool isGlobal = (ctx->cache->isJit && val->isGlobal() && !val->isGeneric()) ||
                  (canonical == VAR_ARGV) ||
                  (val->isGlobal() && val->getModule() != "");
  if (isGlobal && val->isVar()) {
    registerGlobal(canonical);
    if (ctx->cache->isJit) {
      getImport(STDLIB_IMPORT)->ctx->addToplevel(getUnmangledName(val->getName()), val);
    }
  }

  return assign;
}

/// Transform binding updates. Special handling is done for atomic or in-place
/// statements (e.g., `a += b`).
/// See @c transformInplaceUpdate and @c wrapExpr for details.
/// @return the statement that replaces the update, or nullptr when `stmt` itself
///         is to be kept (it has been typechecked in place, or must be revisited).
Stmt *TypecheckVisitor::transformUpdate(AssignStmt *stmt) {
  stmt->lhs = transform(stmt->getLhs());

  // Check inplace updates
  auto [inPlace, inPlaceStmt] = transformInplaceUpdate(stmt);
  if (inPlace) {
    // May be nullptr: transformInplaceUpdate returns {true, nullptr} to delay.
    return inPlaceStmt;
  }

  stmt->rhs = transform(stmt->getRhs());
  stmt->type = transformType(stmt->getTypeExpr());
  if (stmt->getTypeExpr()) {
    unify(stmt->getLhs()->getType(),
          instantiateType(stmt->getTypeExpr()->getSrcInfo(),
                          extractType(stmt->getTypeExpr())));
  }

  // Case: wrap expressions if needed (e.g. floats or optionals)
  if (wrapExpr(&stmt->rhs, stmt->getLhs()->getType()))
    unify(stmt->getRhs()->getType(), stmt->getLhs()->getType());
  if (stmt->getRhs()->isDone() && realize(stmt->getLhs()->getType()))
    stmt->setDone();
  return nullptr;
}

/// Typecheck instance member assignments (e.g., `a.b = c`) and handle optional
/// instances. Disallow tuple updates.
/// @example
///   `opt.foo = bar` -> `unwrap(opt).foo = wrap(bar)`
/// See @c wrapExpr for more examples.
void TypecheckVisitor::visit(AssignMemberStmt *stmt) {
  stmt->lhs = transform(stmt->getLhs());

  // Nothing is done until the class of the left-hand side is known.
  if (auto lhsClass = extractClassType(stmt->getLhs())) {
    auto member = findMember(lhsClass, stmt->getMember());
    if (!member) {
      // Case: property setters
      auto setters = findMethod(
          lhsClass, fmt::format("{}{}", FN_SETTER_SUFFIX, stmt->getMember()));
      if (!setters.empty()) {
        // Only the first setter found is used.
        resultStmt =
            transform(N(N(N(setters.front()->getFuncName()),
                            stmt->getLhs(), stmt->getRhs())));
        return;
      }
      // Case: class variables
      if (auto cls = getClass(lhsClass))
        if (auto var = in(cls->classVars, stmt->getMember())) {
          auto a = N(N(*var), transform(stmt->getRhs()));
          a->setUpdate();
          resultStmt = transform(a);
          return;
        }
    }
    if (!member && lhsClass->is(TYPE_OPTIONAL)) {
      // Unwrap optional and look up there
      resultStmt = transform(N(
          N(N(FN_OPTIONAL_UNWRAP), stmt->getLhs()), stmt->getMember(),
          stmt->getRhs()));
      return;
    }

    // Case: __setattr__ support. Ensure that only Literal[str] arguments are accepted.
    // NOTE(review): getClassMember requires exactly one function generic for
    // `__getattr__`, while this accepts one or more — confirm that the difference
    // is intentional.
    if (!member) {
      auto u = instantiateUnbound(), v = instantiateUnbound();
      u->staticKind = LiteralKind::String;
      if (auto m =
              findBestMethod(lhsClass, "__setattr__", {lhsClass, u.get(), v.get()})) {
        if (m->funcGenerics.size() >= 1 &&
            extractFuncGeneric(m)->getStaticKind() == LiteralKind::String) {
          resultStmt = transform(N(
              N(N(stmt->getLhs(), "__setattr__"), N(stmt->getMember()),
                stmt->getRhs())));
          return;
        }
      }
    }

    if (!member) {
      E(Error::DOT_NO_ATTR, stmt->getLhs(), lhsClass->prettyString(),
        stmt->getMember());
      return;
    }
    if (lhsClass->isRecord()) // prevent tuple member assignment
      E(Error::ASSIGN_UNEXPECTED_FROZEN, stmt->getLhs());

    stmt->rhs = transform(stmt->getRhs());
    stmt->type = transformType(stmt->getTypeExpr());
    if (stmt->getTypeExpr()) {
      unify(stmt->getRhs()->getType(),
            instantiateType(stmt->getTypeExpr()->getSrcInfo(),
                            extractType(stmt->getTypeExpr())));
    }

    // Type of the member as seen through this particular class instantiation.
    auto ftyp =
        instantiateType(stmt->getLhs()->getSrcInfo(), member->getType(), lhsClass);
    // If it is still not realizable, re-typecheck a clean copy of the member's type
    // expression with the class generics in scope.
    if (!ftyp->canRealize() && member->typeExpr) {
      unify(ftyp.get(), extractType(withClassGenerics(lhsClass, [&]() {
              return transform(clean_clone(member->typeExpr));
            })));
    }
    // Nothing more is done in this pass when wrapExpr returns false.
    if (!wrapExpr(&stmt->rhs, ftyp.get()))
      return;
    unify(stmt->getRhs()->getType(), ftyp.get());
    if (stmt->getRhs()->isDone())
      stmt->setDone();
  }
}

/// Transform in-place and atomic updates.
/// @example
///   `a += b` -> `a.__iadd__(a, b)` if `__iadd__` exists
///   Capsule operations:
///   `a = b` -> a.val[0] = b
///   `a += b` -> a.val[0] += b
///   Atomic operations (when the needed magics are available):
///   `a = b`         -> `type(a).__atomic_xchg__(__ptr__(a), b)`
///   `a += b`        -> `type(a).__atomic_add__(__ptr__(a), b)`
///   `a = min(a, b)` -> `type(a).__atomic_min__(__ptr__(a), b)` (same for `max`)
/// @return a tuple indicating whether (1) the update statement can be replaced with an
///         expression, and (2) the replacement expression.
std::pair TypecheckVisitor::transformInplaceUpdate(AssignStmt *stmt) {
  // Result convention used below:
  //   {true, s}        -> replace the update with `s`
  //   {true, nullptr}  -> nothing to replace yet; revisit the statement later
  //   {false, nullptr} -> no special handling; the caller typechecks the update

  // Case: capsule operations
  if (stmt->getLhs()->getType()->is("Capsule")) {
    return {true,
            transform(N(
                N(N(N(getMangledMethod("std.internal.core", "Capsule", "_ptr")),
                    stmt->getLhs()),
                  N(0)),
                stmt->getRhs()))};
  }

  // Case: in-place updates (e.g., `a += b`).
  // They are stored as `Update(a, Binary(a + b, inPlace=true))`
  auto bin = cast(stmt->getRhs());
  if (bin && bin->isInPlace()) {
    bin->lexpr = transform(bin->getLhs());
    bin->rexpr = transform(bin->getRhs());

    if (!stmt->getRhs()->getType())
      stmt->getRhs()->setType(instantiateUnbound());
    if (bin->getLhs()->getClassType() && bin->getRhs()->getClassType()) {
      if (auto transformed = transformBinaryInplaceMagic(bin, stmt->isAtomicUpdate())) {
        unify(stmt->getRhs()->getType(), transformed->getType());
        return {true, transform(N(transformed))};
      } else if (!stmt->isAtomicUpdate()) { // If atomic, call normal magic and then use __atomic_xchg__ below
        return {false, nullptr};
      }
    } else { // Delay
      // Operand classes are not known yet.
      unify(stmt->getLhs()->getType(),
            unify(stmt->getRhs()->getType(), instantiateUnbound()));
      return {true, nullptr};
    }
  }

  // Case: atomic min/max operations.
  // Note: check only `a = min(a, b)`; does NOT check `a = min(b, a)`
  auto lhsClass = extractClassType(stmt->getLhs());
  auto call = cast(stmt->getRhs());
  auto lei = cast(stmt->getLhs());
  auto cei = call ? cast(call->getExpr()) : nullptr;
  if (stmt->isAtomicUpdate() && call && lei && cei &&
      (cei->getValue() == "min" || cei->getValue() == "max") && call->size() == 2) {
    call->front().value = transform(call->front());
    // The first argument must be the very variable that is being updated.
    if (cast(call->front()) &&
        cast(call->front())->getValue() == lei->getValue()) {
      // `type(a).__atomic_min__(__ptr__(a), b)`
      // NOTE(review): unlike the atomic-assignment case below, `lhsClass` is used
      // here without a null check — confirm it is always known at this point.
      auto ptrTyp = instantiateType(stmt->getLhs()->getSrcInfo(), getStdLibType("Ptr"),
                                    std::vector{lhsClass});
      (*call)[1].value = transform((*call)[1]);
      auto rhsTyp = extractClassType((*call)[1].value);
      if (auto method =
              findBestMethod(lhsClass, fmt::format("__atomic_{}__", cei->getValue()),
                             {ptrTyp.get(), rhsTyp})) {
        return {true, transform(N(N(
                          N(method->getFuncName()),
                          N(N("__ptr__"), stmt->getLhs()), (*call)[1])))};
      }
    }
  }

  // Case: atomic assignments
  if (stmt->isAtomicUpdate() && lhsClass) {
    // `type(a).__atomic_xchg__(__ptr__(a), b)`
    stmt->rhs = transform(stmt->getRhs());
    if (auto rhsClass = stmt->getRhs()->getClassType()) {
      auto ptrType = instantiateType(stmt->getLhs()->getSrcInfo(), getStdLibType("Ptr"),
                                     std::vector{lhsClass});
      if (auto m =
              findBestMethod(lhsClass, "__atomic_xchg__", {ptrType.get(), rhsClass})) {
        return {true, transform(N(
                          N(N(m->getFuncName()),
                            N(N("__ptr__"), stmt->getLhs()), stmt->getRhs())))};
      }
    }
  }

  return {false, nullptr};
}

} // namespace codon::ast

================================================
FILE: codon/parser/visitors/typecheck/basic.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include "codon/parser/ast.h"
#include "codon/parser/cache.h"
#include "codon/parser/common.h"
#include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

using namespace codon::error;

namespace codon::ast {

using namespace types;

/// Set type to `Optional[?]`
void TypecheckVisitor::visit(NoneExpr *expr) {
  unify(expr->getType(), instantiateType(getStdLibType(TYPE_OPTIONAL)));
  if (realize(expr->getType())) {
    // Realize the appropriate `Optional.__new__` for the translation stage
    auto f = ctx->forceFind(getMangledMethod("std.internal.core", TYPE_OPTIONAL,
                                             "__new__"))
                 ->getType();
    // NOTE(review): `t` is never read; the call appears to be kept only for the
    // realization it triggers.
    auto t = realize(instantiateType(f, extractClassType(expr)));
    expr->setDone();
  }
}

/// Set type to `bool`
void TypecheckVisitor::visit(BoolExpr *expr) {
  // The type is instantiated from the literal's value.
  unify(expr->getType(), instantiateStatic(expr->getValue()));
  expr->setDone();
}

/// Set type to `int`
void TypecheckVisitor::visit(IntExpr *expr) { resultExpr = transformInt(expr); }

/// Set type to `float`
void TypecheckVisitor::visit(FloatExpr *expr) { resultExpr = transformFloat(expr); }

/// Set type to `str`. Concatenate strings in list and apply appropriate transformations
/// (e.g., `str` wrap).
void TypecheckVisitor::visit(StringExpr *expr) {
  if (expr->isSimple()) {
    // Simple string: typed in place from its value.
    unify(expr->getType(), instantiateStatic(expr->getValue()));
    expr->setDone();
  } else {
    // Otherwise each part becomes an expression and the parts are joined with
    // `str.cat` (a single part is used directly).
    std::vector items;
    for (auto &p : *expr) {
      if (p.expr) {
        // Part with an embedded expression: apply, in this order, the conversion
        // (`r`, `s` or `a`), the format spec, the `str` wrap and the leading text.
        if (!p.format.conversion.empty()) {
          switch (p.format.conversion[0]) {
          case 'r':
            p.expr = N(N("repr"), p.expr);
            break;
          case 's':
            p.expr = N(N("str"), p.expr);
            break;
          case 'a':
            p.expr = N(N("ascii"), p.expr);
            break;
          default:
            // TODO: error?
            break;
          }
        }
        if (!p.format.spec.empty()) {
          p.expr = N(N(p.expr, "__format__"), N(p.format.spec));
        }
        p.expr = N(N("str"), p.expr);
        if (!p.format.text.empty()) {
          p.expr = N(N(N("str"), "cat"), N(p.format.text), p.expr);
        }
        items.emplace_back(p.expr);
      } else if (!p.prefix.empty()) {
        /// Custom prefix strings:
        /// call `str.__prefix_[prefix]__(str, [static length of str])`
        items.emplace_back(N(
            N(N("str"), fmt::format("__prefix_{}__", p.prefix)),
            N(p.value), N(p.value.size())));
      } else {
        items.emplace_back(N(p.value));
      }
    }
    if (items.size() == 1)
      resultExpr = transform(items.front());
    else
      resultExpr = transform(N(N(N("str"), "cat"), items));
  }
}

/// Parse various integer representations depending on the integer suffix.
/// @example
///   `123u`   -> `UInt[64](123)`
///   `123i56` -> `Int[56](123)`
///   `123pf`  -> `int.__suffix_pf__(123)`
/// @return the replacement expression, or nullptr when `expr` is kept and typed in
///         place.
Expr *TypecheckVisitor::transformInt(IntExpr *expr) {
  auto [value, suffix] = expr->getRawData();

  Expr *holder = nullptr;
  if (!expr->hasStoredValue()) {
    // No stored value: pass the raw text on. An absent suffix is treated as `i64`
    // here, so such literals take the fixed-width path below.
    holder = N(value);
    if (suffix.empty())
      suffix = "i64";
  } else {
    holder = N(expr->getValue());
  }

  /// Handle fixed-width integers: suffixValue is a pointer to NN if the suffix
  /// is `uNNN` or `iNNN`.
  std::unique_ptr suffixValue = nullptr;
  if (suffix.size() > 1 && (suffix[0] == 'u' || suffix[0] == 'i') &&
      isdigit(suffix.substr(1))) {
    // If std::stoi throws, the exception is swallowed and suffixValue stays null,
    // so the suffix ends up on the custom-suffix path below.
    try {
      suffixValue = std::make_unique(std::stoi(suffix.substr(1)));
    } catch (...) {
    }
    // Widths above MAX_INT_WIDTH are not treated as fixed-width either.
    if (suffixValue && *suffixValue > MAX_INT_WIDTH)
      suffixValue = nullptr;
  }

  if (suffix.empty()) {
    // A normal integer (int64_t)
    unify(expr->getType(), instantiateStatic(expr->getValue()));
    expr->setDone();
    return nullptr;
  } else if (suffix == "u") {
    // Unsigned integer: call `UInt[64](value)`
    return transform(
        N(N(N("UInt"), N(64)), holder));
  } else if (suffixValue) {
    // Fixed-width numbers (with `uNNN` and `iNNN` suffixes):
    // call `UInt[NNN](value)` or `Int[NNN](value)`
    return transform(
        N(N(N(suffix[0] == 'u' ? "UInt" : "Int"),
            N(*suffixValue)),
          holder));
  } else {
    // Custom suffix: call `int.__suffix_[suffix]__(value)`
    return transform(N(
        N(N("int"), fmt::format("__suffix_{}__", suffix)), holder));
  }
}

/// Parse various float representations depending on the suffix.
/// @example
///   `123.4pf` -> `float.__suffix_pf__(123.4)`
/// @return the replacement expression, or nullptr when `expr` is kept and typed in
///         place.
Expr *TypecheckVisitor::transformFloat(FloatExpr *expr) {
  auto [value, suffix] = expr->getRawData();

  Expr *holder = nullptr;
  if (!expr->hasStoredValue()) {
    holder = N(value);
  } else {
    holder = N(expr->getValue());
  }

  if (suffix.empty() && expr->hasStoredValue()) {
    // A normal float (double)
    unify(expr->getType(), getStdLibType("float"));
    expr->setDone();
    return nullptr;
  } else if (suffix.empty()) {
    // No suffix and no stored value: call `float.__new__` on the raw text.
    return transform(N(N(N("float"), "__new__"), holder));
  } else {
    // Custom suffix: call `float.__suffix_[suffix]__(value)`
    return transform(N(
        N(N("float"), fmt::format("__suffix_{}__", suffix)), holder));
  }
}

} // namespace codon::ast

================================================
FILE: codon/parser/visitors/typecheck/call.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

// NOTE(review): the header names of the two bare `#include` directives below are
// missing in this copy of the file.
#include
#include

#include "codon/parser/ast.h"
#include "codon/parser/cache.h"
#include "codon/parser/common.h"
#include "codon/parser/match.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

using namespace codon::error;

namespace codon::ast {

using namespace types;
using namespace matcher;

/// Transform print statement.
/// @example
///   `print a, b` -> `print(a, b)`
///   `print a, b,` -> `print(a, b, end=' ')`
void TypecheckVisitor::visit(PrintStmt *stmt) {
  std::vector args;
  args.reserve(stmt->size());
  for (auto &i : *stmt)
    args.emplace_back(i);
  // Without a newline, pass `end=" "` to `print`.
  if (!stmt->hasNewline())
    args.emplace_back("end", N(" "));
  resultStmt = transform(N(N(N("print"), args)));
}

/// Just ensure that this expression is not independent of CallExpr where it is handled.
void TypecheckVisitor::visit(StarExpr *expr) { E(Error::UNEXPECTED_TYPE, expr, "star"); } /// Just ensure that this expression is not independent of CallExpr where it is handled. void TypecheckVisitor::visit(KeywordStarExpr *expr) { E(Error::UNEXPECTED_TYPE, expr, "kwstar"); } /// Typechecks an ellipsis. Ellipses are typically replaced during the typechecking; the /// only remaining ellipses are those that belong to PipeExprs. void TypecheckVisitor::visit(EllipsisExpr *expr) { if (expr->isPipe() && realize(expr->getType())) { expr->setDone(); } else if (expr->isStandalone()) { resultExpr = transform(N(N("ellipsis"))); unify(expr->getType(), resultExpr->getType()); } } /// Typecheck a call expression. This is the most complex expression to typecheck. /// @example /// `fn(1, 2, x=3, y=4)` -> `func(a=1, x=3, args=(2,), kwargs=KwArgs(y=4), T=int)` /// `fn(arg1, ...)` -> `(_v = Partial.N10(arg1); _v)` /// See @c transformCallArgs , @c getCalleeFn , @c callReorderArguments , /// @c typecheckCallArgs , @c transformSpecialCall and @c wrapExpr for more details. 
void TypecheckVisitor::visit(CallExpr *expr) {
  // Calls are reported as errors while `ctx->simpleTypes` is set.
  if (ctx->simpleTypes)
    E(Error::CALL_NO_TYPE, expr);

  // Tag single-argument `tuple(...)` calls; they get dedicated handling below.
  if (match(expr->getExpr(), M("tuple")) && expr->size() == 1) {
    expr->setAttribute(Attr::TupleCall);
  }

  validateCall(expr);

  // Check if this call is partial call
  // (i.e., its last argument is an ellipsis flagged as partial)
  PartialCallData part;
  if (!expr->empty()) {
    if (auto el = cast(expr->back().getExpr()); el && el->isPartial()) {
      part.isPartial = true;
    }
  }

  // Do not allow realization here (function will be realized later);
  // used to prevent early realization of compile_error
  expr->setAttribute(Attr::ParentCallExpr);
  if (part.isPartial)
    expr->getExpr()->setAttribute(Attr::ExprDoNotRealize);
  expr->expr = transform(expr->getExpr());
  expr->eraseAttribute(Attr::ParentCallExpr);
  if (isUnbound(expr->getExpr()))
    return; // delay

  auto [calleeFn, newExpr] = getCalleeFn(expr, part);
  // Transform `tuple(i for i in tup)` into a GeneratorExpr
  // that will be handled during the type checking.
  if (!calleeFn && expr->hasAttribute(Attr::TupleCall)) {
    if (cast(expr->begin()->getExpr())) {
      auto g = cast(expr->begin()->getExpr());
      if (!g || g->kind != GeneratorExpr::Generator || g->loopCount() != 1)
        E(Error::CALL_TUPLE_COMPREHENSION, expr->begin()->getExpr());
      g->kind = GeneratorExpr::TupleGenerator;
      resultExpr = transform(g);
      return;
    } else {
      resultExpr = transformTupleFn(expr);
      return;
    }
  } else if ((resultExpr = newExpr)) {
    // getCalleeFn produced a replacement expression; use it as the result.
    return;
  } else if (!calleeFn) {
    // Neither a callee nor a replacement is available yet.
    return;
  }

  // Transform the arguments (including star expansions); stop if that is incomplete.
  if (!withClassGenerics(
          calleeFn.get(), [&]() { return transformCallArgs(expr); }, true, true))
    return;

  // Early dispatch modifier
  if (isDispatch(calleeFn.get())) {
    if (startswith(calleeFn->getFuncName(), "Tuple.__new__")) {
      generateTuple(expr->size());
    }
    // `m` stays null when no overload lookup is attempted.
    std::unique_ptr> m = nullptr;
    auto id = cast(getHeadExpr(expr->getExpr()));
    if (id && part.var.empty()) {
      // Case: function overloads (IdExpr)
      // Make sure to ignore partial constructs (they are also StmtExpr(IdExpr, ...))
      std::vector methods;
      auto key = id->getValue();
      // Strip the dispatch suffix to obtain the overload key.
      if (isDispatch(key))
        key = key.substr(0, key.size() - std::string(FN_DISPATCH_SUFFIX).size());
      for (auto &ovs : getOverloads(key)) {
        if (!isDispatch(ovs))
          methods.push_back(getFunction(ovs)->getType());
      }
      std::ranges::reverse(methods);
      m = std::make_unique>(findMatchingMethods(
          calleeFn->funcParent ? calleeFn->funcParent->getClass() : nullptr, methods,
          expr->items, expr->getExpr()->getType()->getPartial()));
    }
    // partials have dangling ellipsis that messes up with the unbound check below
    bool doDispatch = !m || m->empty() || part.isPartial;
    if (!doDispatch && m && m->size() > 1) {
      for (auto &a : *expr) {
        if (isUnbound(a.getExpr()))
          return; // typecheck this later once we know the argument
      }
    }
    if (!doDispatch) {
      // Bind the call directly to the first matching overload.
      calleeFn = instantiateType(m->front(), calleeFn->funcParent
                                                 ? calleeFn->funcParent->getClass()
                                                 : nullptr);
      auto e = N(calleeFn->getFuncName());
      e->setType(calleeFn);
      if (cast(expr->getExpr())) {
        expr->expr = e;
      } else if (cast(expr->getExpr())) {
        // Side effect...
        // Walk down the nested expressions and replace the innermost one.
        for (auto *se = cast(expr->getExpr());;) {
          if (auto ne = cast(se->getExpr())) {
            se = ne;
          } else {
            se->expr = e;
            break;
          }
        }
      } else {
        expr->expr = N(N(expr->getExpr()), e);
      }
      expr->getExpr()->setType(calleeFn);
    } else if (m && m->empty()) {
      // No overload matched: report the argument types in the error message.
      std::vector a;
      for (auto &t : *expr)
        a.emplace_back(fmt::format("{}", t.getExpr()->getType()->getStatic()
                                             ? t.getExpr()->getClassType()->name
                                             : t.getExpr()->getType()->prettyString()));
      auto argsNice = fmt::format("({})", join(a, ", "));
      auto name = getUnmangledName(calleeFn->getFuncName());
      if (auto a = calleeFn->ast->getAttribute(Attr::ParentClass))
        name = fmt::format("{}.{}", getUserFacingName(a->value), name);
      E(Error::FN_NO_ATTR_ARGS, expr, name, argsNice);
    }
  }

  // Handle named and default arguments
  if ((resultExpr = callReorderArguments(calleeFn.get(), expr, part)))
    return;

  // Handle special calls
  if (!part.isPartial) {
    auto [isSpecial, specialExpr] = transformSpecialCall(expr);
    if (isSpecial) {
      resultExpr = specialExpr;
      return;
    }
  }

  // Typecheck arguments with the function signature
  bool done = typecheckCallArgs(calleeFn.get(), expr->items, part);
  if (!part.isPartial && calleeFn->canRealize()) {
    // Previous unifications can qualify existing identifiers.
    // Transform again to get the full identifier
    expr->expr = transform(expr->expr);
  }
  done &= expr->expr->isDone();

  // Emit the final call
  if (part.isPartial) {
    // Case: partial call. `calleeFn(args...)` -> `Partial(args..., fn, mask)`
    std::vector newArgs;
    for (auto &r : *expr)
      if (!cast(r.getExpr())) {
        newArgs.push_back(r.getExpr());
        newArgs.back()->setAttribute(Attr::ExprSequenceItem);
      }
    newArgs.push_back(part.args);
    auto partialCall = generatePartialCall(part.known, calleeFn->getFunc(),
                                           N(newArgs), part.kwArgs);
    std::string var = getTemporaryVar("part");
    Expr *call = nullptr;
    if (!part.var.empty()) {
      // Callee is already a partial call
      // Append the new assignment to the statements already stored in the callee.
      auto stmts = cast(expr->expr)->items;
      stmts.push_back(N(N(var), partialCall));
      call = N(stmts, N(var));
    } else {
      // New partial call: `(part = Partial(stored_args...); part)`
      call = N(N(N(var), partialCall), N(var));
    }
    call->setAttribute(Attr::ExprPartial);
    resultExpr = transform(call);
  } else {
    // Case: normal function call
    unify(expr->getType(), calleeFn->getRetType());
    if (done)
      expr->setDone();
  }
}

/// Check the shape of a call's argument list and raise an error on violations.
/// The check runs once per call: a validated call is tagged with Attr::Validated.
/// NOTE(review): the node kinds tested by each `cast` are not visible in this copy;
/// the error codes suggest star / keyword-star / ellipsis arguments — confirm.
void TypecheckVisitor::validateCall(CallExpr *expr) {
  if (expr->hasAttribute(Attr::Validated))
    return;
  bool namesStarted = false, foundEllipsis = false;
  for (auto &a : *expr) {
    // An unnamed argument after a named one is an ordering error (with exemptions).
    if (a.name.empty() && namesStarted &&
        !(cast(a.value) || cast(a.value)))
      E(Error::CALL_NAME_ORDER, a.value);
    // Some argument kinds may not carry a name.
    if (!a.name.empty() && (cast(a.value) || cast(a.value)))
      E(Error::CALL_NAME_STAR, a.value);
    // A second occurrence of the tracked argument kind is an error.
    if (cast(a.value) && foundEllipsis)
      E(Error::CALL_ELLIPSIS, a.value);
    foundEllipsis |= static_cast(cast(a.value));
    namesStarted |= !a.name.empty();
  }
  expr->setAttribute(Attr::Validated);
}

/// Transform call arguments. Expand *args and **kwargs to the list of @c CallArg
/// objects.
/// @return false if expansion could not be completed; true otherwise
bool TypecheckVisitor::transformCallArgs(CallExpr *expr) {
  // `ai` is advanced manually because expansions insert into and erase from
  // expr->items while iterating.
  for (auto ai = 0; ai < expr->size();) {
    if (auto star = cast((*expr)[ai].getExpr())) {
      // Case: *args expansion
      star->expr = transform(star->getExpr());
      auto typ = star->getExpr()->getClassType();
      // Unwrap optionals until a non-optional (or still unknown) type remains.
      while (typ && typ->is(TYPE_OPTIONAL)) {
        star->expr = transform(N(N(FN_OPTIONAL_UNWRAP), star->getExpr()));
        typ = star->getExpr()->getClassType();
      }
      if (!typ) // Process later
        return false;
      if (!typ->isRecord())
        E(Error::CALL_BAD_UNPACK, (*expr)[ai], typ->prettyString());
      auto fields = getClassFields(typ);
      // If the unpacked expression has side effects, `lead` (built from a fresh
      // temporary and the expression) is used for the first field only and `head`
      // (the temporary) for the remaining ones.
      // NOTE(review): when `fields` is empty the loop never runs, so `lead` is never
      // emitted — confirm that dropping the side-effecting operand is intended.
      Expr *head = star->getExpr(), *lead = nullptr;
      if (hasSideEffect(head)) {
        auto var = getTemporaryVar("star");
        lead = N(N(var), head);
        head = N(var);
      }
      for (size_t i = 0; i < fields.size(); i++, ai++) {
        expr->items.insert(
            expr->items.begin() + ai,
            CallArg{"", transform(N(clone(lead && i == 0 ? lead : head),
                                    fields[i].name))});
      }
      // Remove the original star argument, which now sits after the inserted ones.
      expr->items.erase(expr->items.begin() + ai);
    } else if (const auto kwstar = cast((*expr)[ai].getExpr())) {
      // Case: **kwargs expansion
      kwstar->expr = transform(kwstar->getExpr());
      auto typ = kwstar->getExpr()->getClassType();
      while (typ && typ->is(TYPE_OPTIONAL)) {
        kwstar->expr = transform(N(N(FN_OPTIONAL_UNWRAP), kwstar->getExpr()));
        typ = kwstar->getExpr()->getClassType();
      }
      if (!typ)
        return false;
      // Same side-effect handling as in the *args case above.
      Expr *head = kwstar->getExpr(), *lead = nullptr;
      if (hasSideEffect(head)) {
        auto var = getTemporaryVar("star");
        lead = N(N(var), head);
        head = N(var);
      }
      if (typ->is("NamedTuple")) {
        // The argument names come from the cache entry selected by the type's
        // integer literal; values are read from `<expr>.args.item<N>` (1-based).
        auto id = getIntLiteral(typ);
        seqassert(id >= 0 && id < ctx->cache->generatedTupleNames.size(),
                  "bad id: {}", id);
        auto names = ctx->cache->generatedTupleNames[id];
        for (size_t i = 0; i < names.size(); i++, ai++) {
          expr->items.insert(
              expr->items.begin() + ai,
              CallArg{names[i],
                      transform(N(
                          N(clone(lead && i == 0 ? lead : head), "args"),
                          fmt::format("item{}", i + 1)))});
        }
        expr->items.erase(expr->items.begin() + ai);
      } else if (typ->isRecord()) {
        // Record types: each field becomes a named argument.
        auto fields = getClassFields(typ);
        for (size_t i = 0; i < fields.size(); i++, ai++) {
          expr->items.insert(
              expr->items.begin() + ai,
              CallArg{fields[i].name,
                      transform(N(clone(lead && i == 0 ? lead : head),
                                  fields[i].name))});
        }
        expr->items.erase(expr->items.begin() + ai);
      } else {
        E(Error::CALL_BAD_KWUNPACK, (*expr)[ai], typ->prettyString());
      }
    } else {
      // Case: normal argument (no expansion)
      (*expr)[ai].value = transform((*expr)[ai].getExpr());
      ai++;
    }
  }

  // Check if some argument names are reused after the expansion
  std::set seen;
  for (auto &a : *expr)
    if (!a.name.empty()) {
      if (in(seen, a.name))
        E(Error::CALL_REPEATED_NAME, a, a.name);
      seen.insert(a.name);
    }

  return true;
}

/// Extract the @c FuncType that represents the function to be called by the callee.
/// Also handle special callees: constructors and partial functions.
/// @return a pair with the callee's @c FuncType and the replacement expression
///         (when needed; otherwise nullptr).
std::pair, Expr *>
TypecheckVisitor::getCalleeFn(CallExpr *expr, PartialCallData &part) {
  auto callee = expr->getExpr()->getClassType();
  if (!callee) {
    // Case: unknown callee, wait until it becomes known
    return {nullptr, nullptr};
  }

  // Tuple-tagged calls on the tuple type (or on functions from
  // `std.internal.static.tuple.`) are left to the caller's tuple handling.
  if (expr->hasAttribute(Attr::TupleCall) &&
      (extractType(expr->getExpr())->is(TYPE_TUPLE) ||
       (callee->getFunc() &&
        startswith(callee->getFunc()->ast->name, "std.internal.static.tuple."))))
    return {nullptr, nullptr};

  if (isTypeExpr(expr->getExpr())) {
    auto typ = expr->getExpr()->getClassType();
    if (!isId(expr->getExpr(), TYPE_TYPE))
      typ = extractClassGeneric(typ)->getClass();
    if (!typ)
      return {nullptr, nullptr};
    // NOTE(review): `clsName` is not used in the visible code.
    auto clsName = typ->name;
    if (typ->isRecord()) {
      if (expr->hasAttribute(Attr::TupleCall)) {
        expr->eraseAttribute(Attr::TupleCall);
      }
      // Case: tuple constructor. Transform to: `T.__new__(args)`
      auto e = transform(N(N(expr->getExpr(), "__new__"), expr->items));
      return {nullptr, e};
    }

    // Case: reference type constructor. Transform to
    // `ctr = T.__new__(); ctr.__init__(args)`
    Expr *var = N(getTemporaryVar("ctr"));
    auto newInit = N(clone(var), N(N(expr->getExpr(), "__new__")));
    auto e = N(N(newInit), clone(var));
    auto init = N(N(N(clone(var), "__init__"), expr->items));
    e->items.emplace_back(init);
    return {nullptr, transform(e)};
  }

  if (auto partType = callee->getPartial()) {
    auto mask = partType->getPartialMask();
    auto genFn = partType->getPartialFunc()->generalize(0);
    auto calleeFn = std::static_pointer_cast(instantiateType(genFn.get()));

    if (!partType->isPartialEmpty() ||
        std::ranges::any_of(mask, [](char c) { return c != ClassType::PartialFlag::Missing; })) {
      // Case: calling partial object `p`. Transform roughly to
      // `part = callee; partial_fn(*part.args, args...)`
      // The temporary's name is recorded in part.var for later argument extraction.
      Expr *var = N(part.var = getTemporaryVar("partcall"));
      expr->expr = transform(N(N(clone(var), expr->getExpr()),
                               N(calleeFn->getFuncName())));
      part.known = mask;
    } else {
      // Nothing is stored in the partial: call the underlying function directly.
      expr->expr = transform(N(calleeFn->getFuncName()));
    }
    seqassert(expr->getExpr()->getType()->getFunc(), "not a function: {}",
              *(expr->getExpr()->getType()));
    unify(expr->getExpr()->getType(), calleeFn);

    // Unify partial generics with types known thus far
    // i: mask position; j: generic parameters seen so far; k: index into the
    // partial's stored argument types.
    auto knownArgTypes = extractClassGeneric(partType, 1)->getClass();
    for (size_t i = 0, j = 0, k = 0; i < mask.size(); i++)
      if ((*calleeFn->ast)[i].isGeneric()) {
        j++;
      } else if (mask[i] == ClassType::PartialFlag::Included) {
        unify(extractFuncArgType(calleeFn.get(), i - j),
              extractClassGeneric(knownArgTypes, k));
        k++;
      } else if (mask[i] == ClassType::PartialFlag::Default) {
        // Stored defaults occupy a slot but are not unified here.
        k++;
      }
    return {calleeFn, nullptr};
  } else if (!callee->getFunc()) {
    // Case: callee is not a function. Try __call__ method instead
    return {nullptr, transform(N(N(expr->getExpr(), "__call__"), expr->items))};
  } else {
    // Case: plain function callee.
    return {std::static_pointer_cast(
                callee->getFunc()->shared_from_this()),
            nullptr};
  }
}

/// Reorder the call arguments to match the signature order. Ensure that every @c
/// CallArg has a set name. Form *args/**kwargs tuples if needed, and use partial
/// and default values where needed.
/// @example
///   `foo(1, 2, baz=3, baf=4)` -> `foo(a=1, baz=2, args=(3, ), kwargs=KwArgs(baf=4))`
/// @return nullptr when the call was reordered in place (or reordering is disabled);
///         otherwise a replacement expression that first stores side-effecting
///         arguments in temporaries.
Expr *TypecheckVisitor::callReorderArguments(FuncType *calleeFn, CallExpr *expr,
                                             PartialCallData &part) {
  if (calleeFn->ast->hasAttribute(Attr::NoArgReorder))
    return nullptr;

  // `ordered` collects (call position, argument slot) pairs for side-effecting
  // arguments; `inOrder` turns false when they must be pre-evaluated.
  bool inOrder = true;
  std::vector> ordered;

  std::vector args;    // stores ordered and processed arguments
  std::vector typeArgs; // stores type and static arguments (e.g., `T: type`)
  int64_t starIdx = -1;     // for *args
  std::vector starArgs;
  int64_t kwStarIdx = -1; // for **kwargs
  std::vector kwStarNames;
  std::vector kwStarArgs;
  // One flag per AST parameter; starts as "Included" and is refined below.
  auto newMask = std::string(calleeFn->ast->size(), ClassType::PartialFlag::Included);

  // Extract pi-th partial argument from a partial object
  auto getPartialArg = [&](size_t pi) {
    auto id = transform(N(N(part.var), "args"));
    // Manually call @c transformStaticTupleIndex to avoid spurious InstantiateExpr
    auto ex = transformStaticTupleIndex(id->getClassType(), id, N(pi));
    seqassert(ex.first && ex.second, "partial indexing failed: {}", *(id->getType()));
    return ex.second;
  };

  // Record the i-th call argument if it has side effects.
  // @return true if the argument was recorded.
  auto addReordered = [&](size_t i) -> bool {
    Expr **e = &((*expr)[i].value);
    if (hasSideEffect(*e)) {
      // An argument recorded after one with a larger call position means the
      // evaluation order would change.
      if (!ordered.empty() && i < ordered.back().first)
        inOrder = false;
      ordered.emplace_back(i, e);
      return true;
    }
    return false;
  };

  // Handle reordered arguments (see @c reorderNamedArgs for details)
  bool partial = false;
  auto reorderFn = [&](int starArgIndex, int kwstarArgIndex,
                       const std::vector> &slots, bool _partial) {
    partial = _partial;
    return withClassGenerics(
        calleeFn,
        [&]() {
          // si: signature slot; pi: next stored partial argument; gi: generic index.
          for (size_t si = 0, pi = 0, gi = 0; si < slots.size(); si++) {
            // Get the argument name to be used later
            auto [_, rn] = (*calleeFn->ast)[si].getNameWithStars();
            auto realName = getUnmangledName(rn);

            if ((*calleeFn->ast)[si].isGeneric()) {
              // Case: generic arguments. Populate typeArgs
              if (startswith(realName, "$")) {
                if (slots[si].empty()) {
                  if (!part.known.empty() &&
                      part.known[si] == ClassType::PartialFlag::Included) {
                    auto t = N(realName);
                    t->setType(
                        calleeFn->funcGenerics[gi].getType()->shared_from_this());
                    typeArgs.emplace_back(t);
                  } else {
                    typeArgs.emplace_back(transform(N(realName.substr(1))));
                  }
                } else {
                  typeArgs.emplace_back((*expr)[slots[si][0]].getExpr());
                  if (addReordered(slots[si][0]))
                    inOrder = false; // type arguments always need preprocessing
                }
                newMask[si] = ClassType::PartialFlag::Included;
              } else if (slots[si].empty()) {
                typeArgs.push_back(nullptr);
                newMask[si] = ClassType::PartialFlag::Missing;
              } else {
                typeArgs.push_back((*expr)[slots[si][0]].getExpr());
                newMask[si] = ClassType::PartialFlag::Included;
                if (addReordered(slots[si][0]))
                  inOrder = false; // type arguments always need preprocessing
              }
              gi++;
            } else if (si == starArgIndex &&
                       !(slots[si].size() == 1 &&
                         (*expr)[slots[si][0]].getExpr()->hasAttribute(
                             Attr::ExprStarArgument))) {
              // Case: *args. Build the tuple that holds them all
              if (!part.known.empty()) {
                // presumably index -1 addresses the trailing element of the partial's
                // stored args — TODO confirm
                starArgs.push_back(N(getPartialArg(-1)));
              }
              for (auto &e : slots[si]) {
                starArgs.push_back((*expr)[e].getExpr());
                addReordered(e);
              }
              // The slot's value is filled in after reordering (see "Handle *args").
              starIdx = static_cast(args.size());
              args.emplace_back(realName, nullptr);
              if (partial)
                newMask[si] = ClassType::PartialFlag::Missing;
            } else if (si == kwstarArgIndex &&
                       !(slots[si].size() == 1 &&
                         (*expr)[slots[si][0]].getExpr()->hasAttribute(
                             Attr::ExprKwStarArgument))) {
              // Case: **kwargs. Build the named tuple that holds them all
              std::unordered_set newNames;
              for (auto &e : slots[si]) // kwargs names can be overridden later
                newNames.insert((*expr)[e].getName());
              if (!part.known.empty()) {
                // Carry over stored kwargs unless the call provides the same name.
                auto e = transform(N(N(part.var), "kwargs"));
                for (auto &[n, ne] : extractNamedTuple(e)) {
                  if (!in(newNames, n)) {
                    newNames.insert(n);
                    kwStarNames.emplace_back(n);
                    kwStarArgs.emplace_back(transform(ne));
                  }
                }
              }
              for (auto &e : slots[si]) {
                kwStarNames.emplace_back((*expr)[e].getName());
                kwStarArgs.emplace_back((*expr)[e].getExpr());
                addReordered(e);
              }

              // The slot's value is filled in after reordering (see "Handle **kwargs").
              kwStarIdx = static_cast(args.size());
              args.emplace_back(realName, nullptr);
              if (partial)
                newMask[si] = ClassType::PartialFlag::Missing;
            } else if (slots[si].empty()) {
              // Case: no arguments provided.
              if (!part.known.empty() &&
                  part.known[si] == ClassType::PartialFlag::Included) {
                // Case 1: Argument captured by partial
                args.emplace_back(realName, getPartialArg(pi));
                pi++;
              } else if (startswith(realName, "$")) {
                // Case 3: Local name capture
                bool added = false;
                if (partial) {
                  if (auto val = ctx->find(realName.substr(1))) {
                    if (val->isFunc() && val->getType()->getFunc()->ast->getName() ==
                                             calleeFn->ast->getName()) {
                      // Special case: fn(fn=fn)
                      // Delay this one.
                      args.emplace_back(
                          realName, transform(N(EllipsisExpr::PARTIAL)));
                      newMask[si] = ClassType::PartialFlag::Missing;
                      added = true;
                    }
                  }
                }
                if (!added)
                  args.emplace_back(realName, transform(N(realName.substr(1))));
              } else if ((*calleeFn->ast)[si].getDefault()) {
                // Case 4: default is present
                if (auto ai = cast((*calleeFn->ast)[si].getDefault())) {
                  // Case 4a: non-values (Ids / .default names)
                  if (!part.known.empty() &&
                      part.known[si] == ClassType::PartialFlag::Default) {
                    // Case 4a/1: Default already captured by partial.
                    args.emplace_back(realName, getPartialArg(pi));
                    pi++;
                  } else {
                    // TODO: check if the value is toplevel to avoid capturing it
                    auto e = transform(N(ai->getValue()));
                    seqassert(e->getType()->getLink(), "not a link type");
                    args.emplace_back(realName, e);
                  }
                  if (partial)
                    newMask[si] = ClassType::PartialFlag::Default;
                } else if (!partial) {
                  // Case 4b: values / non-Id defaults (None, etc.)
                  if (cast((*calleeFn->ast)[si].getDefault()) &&
                      !(*calleeFn->ast)[si].type) {
                    args.emplace_back(
                        realName, transform(N(N(
                                      N("Optional"), N("NoneType")))));
                  } else {
                    args.emplace_back(
                        realName,
                        transform(clean_clone((*calleeFn->ast)[si].getDefault())));
                  }
                } else {
                  // Partial call with a value default: leave the slot missing.
                  args.emplace_back(realName, transform(N(EllipsisExpr::PARTIAL)));
                  newMask[si] = ClassType::PartialFlag::Missing;
                }
              } else if (partial) {
                // Case 5: this is partial call. Just add ... for missing arguments
                args.emplace_back(realName, transform(N(EllipsisExpr::PARTIAL)));
                newMask[si] = ClassType::PartialFlag::Missing;
              } else {
                // NOTE(review): `expr` is already dereferenced above, so this assertion
                // looks like it can never fire — confirm whether an unconditional
                // failure was intended.
                seqassert(expr, "cannot happen");
              }
            } else {
              // Case: argument provided
              seqassert(slots[si].size() == 1, "call transformation failed");
              args.emplace_back(realName, (*expr)[slots[si][0]].getExpr());
              addReordered(slots[si][0]);
            }
          }
          return 0;
        },
        true);
  };

  // Reorder arguments if needed
  part.args = part.kwArgs = nullptr; // Stores partial *args/**kwargs expression
  if (expr->hasAttribute(Attr::ExprOrderedCall)) {
    // Already ordered by an earlier pass: reuse the arguments as they are.
    args = expr->items;
  } else {
    reorderNamedArgs(
        calleeFn, expr->items, reorderFn,
        [&](error::Error e, const SrcInfo &o, const std::string &errorMsg) {
          E(Error::CUSTOM, o, errorMsg.c_str());
          return -1;
        },
        part.known);
  }

  // Do reordering
  if (!inOrder) {
    // Store each recorded argument in a temporary (in call-position order), replace
    // the argument with that temporary, and re-transform the call behind the
    // prepended statements.
    std::vector prepends;
    std::ranges::sort(ordered,
                      [](const auto &a, const auto &b) { return a.first < b.first; });
    for (auto &eptr : ordered | std::views::values) {
      auto name = getTemporaryVar("call");
      auto front = transform(
          N(N(name), *eptr, getParamType((*eptr)->getType())));
      auto swap = transform(N(name));
      *eptr = swap;
      prepends.emplace_back(front);
    }
    return transform(N(prepends, expr));
  }

  // Handle *args
  if (starIdx != -1) {
    Expr *se = N(starArgs);
    se->setAttribute(Attr::ExprStarArgument);
    // The collected arguments are not transformed when the callee matches `hasattr`.
    if (!match(expr->getExpr(), M("hasattr")))
      se = transform(se);
    if (partial) {
      part.args = se;
      args[starIdx].value = transform(N(EllipsisExpr::PARTIAL));
    } else {
      args[starIdx].value = se;
    }
  }

  // Handle **kwargs
  if (kwStarIdx != -1) {
    auto kwid = generateKwId(kwStarNames);
    auto kwe = transform(N(N("NamedTuple"), N(kwStarArgs), N(kwid)));
    kwe->setAttribute(Attr::ExprKwStarArgument);
    // NOTE(review): unlike the *args branch above, these assignments replace the
    // whole CallArg rather than its `.value`; confirm the slot's name is not needed
    // afterwards.
    if (partial) {
      part.kwArgs = kwe;
      args[kwStarIdx] = transform(N(EllipsisExpr::PARTIAL));
    } else {
      args[kwStarIdx] = kwe;
    }
  }

  // Populate partial data
  if (part.args != nullptr)
    part.args->setAttribute(Attr::ExprSequenceItem);
  if (part.kwArgs != nullptr)
    part.kwArgs->setAttribute(Attr::ExprSequenceItem);
  if (part.isPartial) {
    // Drop the trailing argument of the partial call (the partial ellipsis).
    expr->items.pop_back();
    if (!part.args)
      part.args = transform(N()); // use ()
    if (!part.kwArgs)
      part.kwArgs = transform(N(N("NamedTuple"))); // use NamedTuple()
  }

  // Unify function type generics with the provided generics
  seqassert((expr->hasAttribute(Attr::ExprOrderedCall) && typeArgs.empty()) ||
                (!expr->hasAttribute(Attr::ExprOrderedCall) &&
                 typeArgs.size() == calleeFn->funcGenerics.size()),
            "bad vector sizes");
  if (!calleeFn->funcGenerics.empty()) {
    auto niGenerics = calleeFn->ast->getNonInferrableGenerics();
    // NOTE(review): `si` indexes funcGenerics here but is also used to index the AST
    // parameters (`(*calleeFn->ast)[si]`); confirm the two orderings coincide.
    for (size_t si = 0; !expr->hasAttribute(Attr::ExprOrderedCall) &&
                        si < calleeFn->funcGenerics.size();
         si++) {
      const auto &gen = calleeFn->funcGenerics[si];
      if (typeArgs[si]) {
        auto typ = extractType(typeArgs[si]);
        if (gen.staticKind && !typ->getStaticKind()) {
          E(Error::EXPECTED_STATIC, typeArgs[si]);
        }
        unify(typ, gen.getType());
      } else {
        // A generic that is unbound, has no default, and is listed as non-inferrable
        // must be given explicitly in a non-partial call.
        if (isUnbound(gen.getType()) && !(*calleeFn->ast)[si].getDefault() &&
            !partial && in(niGenerics, gen.name)) {
          E(Error::CUSTOM, getSrcInfo(), "generic '{}' not provided",
            getUnmangledName(gen.name));
        }
      }
    }
  }

  // Commit the reordered arguments and mark the call as ordered.
  expr->items = args;
  expr->setAttribute(Attr::ExprOrderedCall);
  part.known = newMask;
  return nullptr;
}

/// Unify the call arguments' types with the function declaration signatures.
/// Also apply argument transformations to ensure the type compatibility and handle
/// default generics.
/// @example
///   `foo(1, 2)` -> `foo(1, Optional(2), T=int)`
/// @return true if every argument expression is done after the final transform
bool TypecheckVisitor::typecheckCallArgs(FuncType *calleeFn, std::vector &args,
                                         const PartialCallData &partial) {
  bool wrappingDone = true;     // tracks whether all arguments are wrapped
  std::vector replacements; // list of replacement arguments

  withClassGenerics(
      calleeFn,
      [&]() {
        // i walks all AST parameters; si walks the non-generic ones, matching `args`.
        for (size_t i = 0, si = 0; i < calleeFn->ast->size(); i++) {
          if ((*calleeFn->ast)[i].isGeneric())
            continue;

          if (startswith((*calleeFn->ast)[i].getName(), "*") &&
              (*calleeFn->ast)[i].getType()) {
            // Special case: `*args: type` and `**kwargs: type`
            if (auto callExpr = cast(args[si].getExpr())) {
              auto typ = extractType(transform(clone((*calleeFn->ast)[i].getType())));
              // For `**kwargs`, the values are taken from the first argument of the
              // call expression.
              if (startswith((*calleeFn->ast)[i].getName(), "**"))
                callExpr = cast(callExpr->front().getExpr());
              for (auto &ca : *callExpr) {
                if (wrapExpr(&ca.value, typ, calleeFn)) {
                  unify(ca.getExpr()->getType(), typ);
                } else {
                  wrappingDone = false;
                }
              }
              // Rebuild the container from the (possibly wrapped) items.
              auto name = callExpr->getClassType()->name;
              auto tup = transform(N(N(name), callExpr->items));
              if (startswith((*calleeFn->ast)[i].getName(), "**")) {
                args[si].value = transform(N(
                    N(N("NamedTuple"), "__new__"), tup,
                    N(extractClassGeneric(args[si].getExpr()->getType())
                          ->getIntStatic()
                          ->value)));
              } else {
                args[si].value = tup;
              }
            }
            replacements.push_back(args[si].getExpr()->getType());
            // else this is empty and is a partial call; leave it for later
          } else if (partial.isPartial && !partial.known.empty() &&
                     partial.known[si] == ClassType::PartialFlag::Default) {
            // Defaults should not be unified (yet)!
            replacements.push_back(extractFuncArgType(calleeFn, si));
          } else {
            if (wrapExpr(&args[si].value, extractFuncArgType(calleeFn, si), calleeFn)) {
              unify(args[si].getExpr()->getType(), extractFuncArgType(calleeFn, si));
            } else {
              wrappingDone = false;
            }
            // Use the argument's type when the declared one has no class yet.
            replacements.push_back(!extractFuncArgType(calleeFn, si)->getClass()
                                       ? args[si].getExpr()->getType()
                                       : extractFuncArgType(calleeFn, si));
          }
          si++;
        }
        return true;
      },
      true);

  // Realize arguments
  bool done = true;
  for (auto &a : args) {
    // Previous unifications can qualify existing identifiers.
    // Transform again to get the full identifier
    if (realize(a.getExpr()->getType()))
      a.value = transform(a.getExpr());
    done &= a.getExpr()->isDone();
  }

  // Handle default generics
  // (skipped for partial calls and when some argument could not be wrapped)
  if (!partial.isPartial)
    for (size_t i = 0, j = 0; wrappingDone && i < calleeFn->ast->size(); i++)
      if ((*calleeFn->ast)[i].isGeneric()) {
        if ((*calleeFn->ast)[i].getDefault() &&
            isUnbound(extractFuncGeneric(calleeFn, j))) {
          auto def = extractType(withClassGenerics(
              calleeFn,
              [&]() {
                return transform(clean_clone((*calleeFn->ast)[i].getDefault()));
              },
              true));
          unify(extractFuncGeneric(calleeFn, j), def);
        }
        j++;
      }

  // Replace the arguments
  for (size_t si = 0; si < replacements.size(); si++) {
    if (replacements[si]) {
      extractClassGeneric(calleeFn)->getClass()->generics[si].type =
          replacements[si]->shared_from_this();
    }
  }
  // Clear the `_rn` strings after mutating the generics in place.
  // NOTE(review): presumably `_rn` caches a realized name — confirm.
  extractClassGeneric(calleeFn)->getClass()->_rn = "";
  calleeFn->getClass()->_rn = "";
  /// TODO: TERRIBLE!

  return done;
}

/// Transform and typecheck the following special call expressions:
///   `superf(fn)`
///   `super()`
///   `__ptr__(var)`
///   `__array__[int](sz)`
///   `isinstance(obj, type)`
///   `static.len(tup)`
///   `hasattr(obj, "attr")`
///   `getattr(obj, "attr")`
///   `type(obj)`
///   `compile_error("msg")`
/// See below for more details.
std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) {
  // Calls explicitly opted out of special handling are never rewritten here.
  if (expr->hasAttribute(Attr::ExprNoSpecial))
    return {false, nullptr};

  // Only calls whose callee casts successfully can name a special function.
  auto ei = cast(expr->getExpr());
  if (!ei)
    return {false, nullptr};

  // True when the callee is the mangled free function `module.cls` (no `name`)
  // or the mangled method `module.cls.name`.
  auto calleeIs = [ei](const std::string &module, const std::string &cls,
                       const std::string &name = "") {
    const bool freeFunction = name.empty();
    return freeFunction ? (ei->getValue() == getMangledFunc(module, cls))
                        : (ei->getValue() == getMangledMethod(module, cls, name));
  };

  const std::string core = "std.internal.core";
  const std::string stat = "std.internal.static";

  // The checks run in the same order as before; the first match wins.
  // Handlers marked "(static)" were annotated as static in the original chain.
  if (calleeIs(core, "superf"))
    return {true, transformSuperF(expr)};
  if (calleeIs(core, "super"))
    return {true, transformSuper()};
  if (calleeIs(core, "__ptr__"))
    return {true, transformPtr(expr)};
  if (calleeIs(core, "__array__", "__new__"))
    return {true, transformArray(expr)};
  if (calleeIs(core, "isinstance")) // (static)
    return {true, transformIsInstance(expr)};
  if (calleeIs(stat, "len")) // (static)
    return {true, transformStaticLen(expr)};
  if (calleeIs(core, "hasattr")) // (static)
    return {true, transformHasAttr(expr)};
  if (calleeIs(core, "getattr"))
    return {true, transformGetAttr(expr)};
  if (calleeIs(core, "setattr"))
    return {true, transformSetAttr(expr)};
  if (calleeIs(core, "type", "__new__"))
    return {true, transformTypeFn(expr)};
  if (calleeIs(core, "compile_error"))
    return {true, transformCompileError(expr)};
  // The only handler whose result is reported with a `false` flag.
  if (calleeIs(stat, "print"))
    return {false, transformStaticPrintFn(expr)};
  if (calleeIs("std.collections", "namedtuple"))
    return {true, transformNamedTuple(expr)};
  if (calleeIs("std.functools", "partial"))
    return {true, transformFunctoolsPartial(expr)};
  if (calleeIs(stat, "has_rtti")) // (static)
    return {true, transformHasRttiFn(expr)};
  if (calleeIs(stat, "function", "realized"))
    return {true, transformRealizedFn(expr)};
  if (calleeIs(stat, "function", "can_call")) // (static)
    return {true, transformStaticFnCanCall(expr)};
  if (calleeIs(stat, "function", "has_type")) // (static)
    return {true, transformStaticFnArgHasType(expr)};
  if (calleeIs(stat, "function", "get_type"))
    return {true, transformStaticFnArgGetType(expr)};
  if (calleeIs(stat, "function", "args"))
    return {true, transformStaticFnArgs(expr)};
  if (calleeIs(stat, "function", "has_default")) // (static)
    return {true, transformStaticFnHasDefault(expr)};
  if (calleeIs(stat, "function", "get_default"))
    return {true, transformStaticFnGetDefault(expr)};
  if (calleeIs(stat, "function", "wrap_args"))
    return {true, transformStaticFnWrapCallArgs(expr)};
  if (calleeIs(stat, "vars"))
    return {true, transformStaticVars(expr)};
  if (calleeIs(stat, "tuple_type"))
    return {true, transformStaticTupleType(expr)};
  if (calleeIs(stat, "format")) // (static)
    return {true, transformStaticFormat(expr)};
  if (calleeIs(stat, "int_to_string")) // (static)
    return {true, transformStaticIntToStr(expr)};

  // Not a special call.
  return {false, nullptr};
}

/// Get the list that describes the inheritance hierarchy of a given type.
/// The first type in the list is the most recently inherited type.
std::vector TypecheckVisitor::getStaticSuperTypes(ClassType *cls) {
  std::vector result;
  if (!cls)
    return result;

  // The class itself always comes first.
  result.push_back(cls->shared_from_this());
  auto c = getClass(cls);
  auto fields = getClassFields(cls);
  for (auto &name : c->staticParentClasses) {
    auto parentTyp = instantiateType(extractClassType(name), cls);
    auto parentFields = getClassFields(parentTyp->getClass());
    // Tie every field of `cls` to the parent field with the same name
    // (first match only) by unifying their instantiated types.
    for (auto &field : fields) {
      for (auto &parentField : parentFields)
        if (field.name == parentField.name) {
          auto t = instantiateType(field.getType(), cls);
          unify(t.get(),
                instantiateType(parentField.getType(), parentTyp->getClass()));
          break;
        }
    }
    // Append the parent's own hierarchy (the parent first, then its ancestors).
    for (auto &t : getStaticSuperTypes(parentTyp->getClass()))
      result.push_back(t);
  }
  return result;
}

/// Get the types listed in the method resolution order (`mro`) of a given class.
/// Each entry is instantiated against @c cls and realized before being returned,
/// in the same order as they are stored in `mro`.
/// @return an empty list if @c cls is null.
std::vector TypecheckVisitor::getRTTISuperTypes(ClassType *cls) {
  std::vector result;
  if (!cls)
    return result;

  auto c = getClass(cls);
  const auto &mro = c->mro;
  for (const auto &umt : mro) {
    auto mt = instantiateType(umt.get(), cls);
    realize(mt.get()); // ensure that parent types are realized
    result.push_back(mt);
  }
  return result;
}

/// Return a partial type call `Partial(args, kwargs, fn, mask)` for a given function
/// and a mask.
/// @param mask a 0-1 vector whose size matches the number of function arguments.
///             1 indicates that the argument has been provided and is cached within
///             the partial object.
Expr *TypecheckVisitor::generatePartialCall(const std::string &mask, types::FuncType *fn, Expr *args, Expr *kwargs) { if (!args) args = N(std::vector{N()}); if (!kwargs) kwargs = N(N("NamedTuple")); auto efn = N(fn->getFuncName()); efn->setType(instantiateType(getStdLibType("unrealized_type"), std::vector{fn->getFunc()})); efn->setDone(); Expr *call = N( N("Partial"), std::vector{CallArg{"args", args}, CallArg{"kwargs", kwargs}, CallArg{"M", N(mask)}, CallArg{"F", efn}}); return call; } } // namespace codon::ast ================================================ FILE: codon/parser/visitors/typecheck/class.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include #include #include "codon/cir/attribute.h" #include "codon/parser/ast.h" #include "codon/parser/cache.h" #include "codon/parser/common.h" #include "codon/parser/match.h" #include "codon/parser/visitors/format/format.h" #include "codon/parser/visitors/scoping/scoping.h" #include "codon/parser/visitors/typecheck/typecheck.h" using namespace codon::error; namespace codon::ast { using namespace types; using namespace matcher; /// Parse a class (type) declaration and add a (generic) type to the context. 
void TypecheckVisitor::visit(ClassStmt *stmt) { // Get root name std::string name = stmt->getName(); // Generate/find class' canonical name (unique ID) and AST std::string canonicalName; std::vector &argsToParse = stmt->items; // classItem will be added later when the scope is different auto classItem = std::make_shared("", "", ctx->getModule(), nullptr, ctx->getScope()); classItem->setSrcInfo(stmt->getSrcInfo()); std::shared_ptr timedItem = nullptr; types::ClassType *typ = nullptr; if (!stmt->hasAttribute(Attr::Extend)) { classItem->canonicalName = canonicalName = ctx->generateCanonicalName(name, !stmt->hasAttribute(Attr::Internal), /* noSuffix*/ stmt->hasAttribute(Attr::Internal)); if (canonicalName == "Union") classItem->type = std::make_shared(ctx->cache); else classItem->type = std::make_shared(ctx->cache, canonicalName); if (stmt->isRecord()) classItem->type->getClass()->isTuple = true; classItem->type->setSrcInfo(stmt->getSrcInfo()); typ = classItem->getType()->getClass(); if (canonicalName != TYPE_TYPE) classItem->type = instantiateTypeVar(classItem->getType()); timedItem = std::make_shared(*classItem); // timedItem->time = getTime(); // Reference types are added to the context here. // Tuple types are added after class contents are parsed to prevent // recursive record types (note: these are allowed for reference types) if (!stmt->hasAttribute(Attr::Tuple)) { ctx->add(name, timedItem); ctx->addAlwaysVisible(classItem); } } else { // Find the canonical name and AST of the class that is to be extended if (!ctx->isGlobal() || ctx->isConditional()) E(Error::EXPECTED_TOPLEVEL, getSrcInfo(), "class extension"); auto val = ctx->find(name, getTime()); if (!val || !val->isType()) E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), name); typ = val->getName() == TYPE_TYPE ? 
val->getType()->getClass() : extractClassType(val->getType()); if (getClass(typ)->ast->hasAttribute(Attr::NoExtend)) if (!ctx->isStdlibLoading && !stmt->hasAttribute(Attr::AutoGenerated)) E(Error::CLASS_NO_EXTEND, getSrcInfo(), name); canonicalName = typ->name; argsToParse = getClass(typ)->ast->items; } auto &cls = ctx->cache->classes[canonicalName]; std::vector clsStmts; // Will be filled later! std::vector varStmts; // Will be filled later! std::vector fnStmts; // Will be filled later! std::vector addLater; try { // Add the class base TypeContext::BaseGuard br(ctx.get(), canonicalName); ctx->getBase()->type = typ->shared_from_this(); // Parse and add class generics std::vector args; if (stmt->hasAttribute(Attr::Extend)) { for (auto &a : argsToParse) { if (!a.isGeneric()) continue; auto val = ctx->forceFind(a.name); val->type->getLink()->kind = LinkType::Unbound; ctx->add(getUnmangledName(val->canonicalName), val); args.emplace_back(val->canonicalName, nullptr, nullptr, a.status); } } else { if (stmt->hasAttribute(Attr::ClassDeduce)) { if (!autoDeduceMembers(stmt, argsToParse)) stmt->eraseAttribute(Attr::ClassDeduce); } // Add all generics before parent classes, fields and methods for (auto &a : argsToParse) { if (!a.isGeneric()) continue; auto varName = ctx->generateCanonicalName(a.getName()); auto generic = instantiateUnbound(); auto typId = generic->id; generic->getLink()->genericName = varName; auto defType = transformType(clone(a.getDefault())); if (defType) generic->defaultType = extractType(defType)->shared_from_this(); if (auto st = getStaticGeneric(a.getType())) { if (st == LiteralKind::Runtime) a.type = transform(a.getType()); // trigger error generic->staticKind = st; auto val = ctx->addVar(a.getName(), varName, generic); val->generic = true; } else { if (cast(a.getType())) { // Parse TraitVar a.type = transform(a.getType()); auto ti = cast(a.getType()); seqassert(ti && isId(ti->getExpr(), TRAIT_TYPE), "not a TypeTrait instantiation: {}", 
*(a.getType())); auto l = extractType(ti->front()); if (l->getLink() && l->getLink()->trait) generic->getLink()->trait = l->getLink()->trait; else generic->getLink()->trait = std::make_shared(l->shared_from_this()); } ctx->addType(a.getName(), varName, generic)->generic = true; } typ->generics.emplace_back(varName, generic->generalize(ctx->typecheckLevel), typId, generic->staticKind); args.emplace_back(varName, a.getType(), defType, a.status); } } // Form class type node (e.g. `Foo`, or `Foo[T, U]` for generic classes) Expr *transformedTypeAst = nullptr; if (!stmt->hasAttribute(Attr::Extend)) { transformedTypeAst = N(canonicalName); for (auto &a : args) { if (a.isGeneric()) { if (!cast(transformedTypeAst)) { transformedTypeAst = N(N(canonicalName), std::vector{}); } cast(transformedTypeAst) ->items.push_back(transform(N(a.getName()), true)); } } } // Collect classes (and their fields) that are to be statically inherited std::vector staticBaseASTs; if (!stmt->hasAttribute(Attr::Extend)) { // Handle static inheritance staticBaseASTs = parseBaseClasses(stmt->staticBaseClasses, args, stmt, canonicalName, nullptr, typ); // Handle RTTI inheritance parseBaseClasses(stmt->baseClasses, args, stmt, canonicalName, transformedTypeAst, typ); } // A ClassStmt will be separated into class variable assignments, method-free // ClassStmts (that include nested classes) and method FunctionStmts transformNestedClasses(stmt, clsStmts, varStmts, fnStmts); // Collect class fields for (auto &a : argsToParse) { if (a.isValue()) { if (ClassStmt::isClassVar(a)) { // Handle class variables. Transform them later to allow self-references auto varName = fmt::format("{}.{}", canonicalName, a.getName()); auto h = transform(N(N(varName), nullptr, nullptr)); preamble->addStmt(h); auto val = ctx->forceFind(varName); val->baseName = ""; val->scope = {0}; registerGlobal(val->canonicalName); if (a.getDefault()) { auto assign = N( N(varName), transform(a.getDefault()), a.getType() ? 
cast(a.getType())->getIndex() : nullptr); assign->setUpdate(); varStmts.push_back(assign); } cls.classVars[a.getName()] = varName; ctx->add(a.getName(), val); } else if (!stmt->hasAttribute(Attr::Extend)) { std::string varName = a.getName(); auto ta = transformType(clean_clone(a.getType()), true); args.emplace_back(varName, ta, transform(clone(a.getDefault()), true)); cls.fields.emplace_back(varName, nullptr, canonicalName); } } } // ASTs for member arguments to be used for populating magic methods std::vector memberArgs; for (auto &a : args) { if (a.isValue()) memberArgs.emplace_back(clone(a)); } // Handle class members if (!stmt->hasAttribute(Attr::Extend)) { ctx->typecheckLevel++; // to avoid unifying generics early if (canonicalName == TYPE_TUPLE) { // Special tuple handling! for (auto aj = 0; aj < MAX_TUPLE; aj++) { auto genName = fmt::format("T{}", aj + 1); auto genCanName = ctx->generateCanonicalName(genName); auto generic = instantiateUnbound(); generic->getLink()->genericName = genName; Expr *te = N(genCanName); cls.fields.emplace_back(fmt::format("item{}", aj + 1), generic->generalize(ctx->typecheckLevel), "", te); } } else { for (auto ai = 0, aj = 0; ai < args.size(); ai++) { if (args[ai].isValue() && !ClassStmt::isClassVar(args[ai])) { cls.fields[aj].typeExpr = clean_clone(args[ai].getType()); cls.fields[aj].type = extractType(args[ai].getType())->generalize(ctx->typecheckLevel - 1); cls.fields[aj].type->setSrcInfo(args[ai].getType()->getSrcInfo()); aj++; } } } ctx->typecheckLevel--; } // Parse class members (arguments) and methods if (!stmt->hasAttribute(Attr::Extend)) { // Now that we are done with arguments, add record type to the context if (stmt->hasAttribute(Attr::Tuple)) { ctx->add(name, timedItem); ctx->addAlwaysVisible(classItem); } // Create a cached AST. stmt->setAttribute(Attr::Module, ctx->moduleName.status == ImportFile::STDLIB ? 
STDLIB_IMPORT : ctx->moduleName.path); cls.ast = N(canonicalName, args, N()); cls.ast->cloneAttributesFrom(stmt); cls.ast->baseClasses = stmt->baseClasses; for (auto &b : staticBaseASTs) cls.staticParentClasses.emplace_back(b->getClass()->name); cls.module = ctx->moduleName.path; cls.jitCell = ctx->cache->jitCell; // Codegen default magic methods // __new__ must be the first if (auto aa = stmt->getAttribute(Attr::ClassMagic)) for (const auto &m : aa->values) { fnStmts.push_back(transform(codegenMagic(m, transformedTypeAst, memberArgs, stmt->hasAttribute(Attr::Tuple)))); } // Add inherited methods for (auto &base : staticBaseASTs) { for (auto &val : getClass(base->getClass())->methods | std::views::values) for (auto &mf : getOverloads(val)) { const auto &fp = getFunction(mf); auto f = fp->origAst; if (f && !f->hasAttribute(Attr::AutoGenerated)) { fnStmts.push_back( cast(withClassGenerics(base->getClass(), [&]() { // since functions can come from other modules // make sure to transform them in their respective module // however make sure to add/pop generics :/ if (!ctx->isStdlibLoading && fp->module != ctx->moduleName.path) { auto ictx = getImport(fp->module)->ctx; TypeContext::BaseGuard _(ictx.get(), canonicalName); ictx->getBase()->type = typ->shared_from_this(); auto tv = TypecheckVisitor(ictx); auto e = tv.withClassGenerics( typ, [&]() { return tv.transform(clean_clone(f)); }, false, false, /*instantiate*/ true); return e; } else { return transform(clean_clone(f)); } }))); } } } } // Add class methods for (const auto &sp : getClassMethods(stmt->getSuite())) { if (auto fp = cast(sp)) { for (auto *&dc : fp->decorators) { // Handle @setter setters if (match(dc, M(M(fp->getName()), "setter")) && fp->size() == 2) { fp->name = fmt::format("{}{}", FN_SETTER_SUFFIX, fp->getName()); dc = nullptr; break; } } // All tuple methods are marked with AutoGenerated for later convenience // (e.g. heterogenous tuple processing). 
if (canonicalName == TYPE_TUPLE) sp->setAttribute(Attr::AutoGenerated); fnStmts.emplace_back(transform(sp)); } } // After popping context block, record types and nested classes will disappear. // Store their references and re-add them to the context after popping addLater.reserve(clsStmts.size() + 1); for (auto &c : clsStmts) addLater.emplace_back(ctx->find(cast(c)->getName())); if (stmt->hasAttribute(Attr::Tuple)) addLater.emplace_back(ctx->forceFind(name)); // Mark functions as virtual: auto banned = std::set{"__init__", "__new__", "__raw__", "__tuplesize__", "__repr_default__"}; for (const auto &method : cls.methods | std::views::keys) { for (size_t mi = 1; mi < cls.mro.size(); mi++) { // ... in the current class auto b = cls.mro[mi]->name; if (in(getClass(b)->methods, method) && !in(banned, method)) { cls.virtuals.insert(method); } } for (auto &v : cls.virtuals) { for (size_t mi = 1; mi < cls.mro.size(); mi++) { // ... and in parent classes auto b = cls.mro[mi]->name; getClass(b)->virtuals.insert(v); } } } // Generalize generics and remove them from the context for (const auto &g : args) if (!g.isValue()) { auto generic = ctx->forceFind(g.name)->type; if (g.status == Param::Generic) { // Generalize generics. 
Hidden generics are linked to the class generics so // ignore them seqassert(generic && generic->getLink() && generic->getLink()->kind != types::LinkType::Link, "generic has been unified"); generic->getLink()->kind = LinkType::Generic; } ctx->remove(g.name); } // Debug information if (!startswith(canonicalName, "Tuple") && false) { LOG_REALIZE("[class] {} -> {:c} / {}", canonicalName, *typ, cls.fields.size()); for (auto &m : cls.fields) LOG_REALIZE(" - member: {}: {:c}", m.name, *(m.type)); for (auto &m : cls.methods) LOG_REALIZE(" - method: {}: {}", m.first, m.second); for (auto &m : cls.mro) LOG_REALIZE(" - mro: {:c}", *m); } } catch (const exc::ParserException &) { if (!stmt->hasAttribute(Attr::Tuple)) ctx->remove(name); ctx->cache->classes.erase(name); throw; } for (auto &i : addLater) ctx->add(getUnmangledName(i->canonicalName), i); // Extensions are not needed as the cache is already populated if (!stmt->hasAttribute(Attr::Extend)) { auto c = cls.ast; seqassert(c, "not a class AST for {}", canonicalName); c->setDone(); clsStmts.push_back(c); } clsStmts.insert(clsStmts.end(), fnStmts.begin(), fnStmts.end()); for (auto &a : varStmts) { // Transform class variables here to allow self-references clsStmts.push_back(transform(a)); } resultStmt = N(clsStmts); } /// Parse statically inherited classes. /// Returns a list of their ASTs. Also updates the class fields. /// @param args Class fields that are to be updated with base classes' fields. /// @param typeAst Transformed AST for base class type (e.g., `A[T]`). /// Only set when dealing with dynamic polymorphism. std::vector TypecheckVisitor::parseBaseClasses( std::vector &baseClasses, std::vector &args, const Stmt *attr, const std::string &canonicalName, const Expr *typeAst, types::ClassType *typ) { std::vector asts; // TODO)) fix MRO it to work with generic classes (maybe replacements? IDK...) 
std::vector> mro{{typ->shared_from_this()}}; for (auto &cls : baseClasses) { std::vector subs; // Get the base class and generic replacements (e.g., if there is Bar[T], // Bar in Foo(Bar[int]) will have `T = int`) cls = transformType(cls, true); if (!cls->getClassType()) E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), FormatVisitor::apply(cls)); auto clsTyp = extractClassType(cls); asts.push_back(clsTyp->shared_from_this()); auto cachedCls = getClass(clsTyp); if (!cachedCls->ast) E(Error::CLASS_NO_INHERIT, getSrcInfo(), "nested", "surrounding"); std::vector rootMro; for (auto &t : cachedCls->mro) rootMro.push_back(instantiateType(t.get(), clsTyp)); mro.push_back(rootMro); // Sanity checks if (attr->hasAttribute(Attr::Tuple) && typeAst) E(Error::CLASS_NO_INHERIT, getSrcInfo(), "tuple", "other"); if (!attr->hasAttribute(Attr::Tuple) && cachedCls->ast->hasAttribute(Attr::Tuple)) E(Error::CLASS_TUPLE_INHERIT, getSrcInfo()); if (cachedCls->ast->hasAttribute(Attr::Internal)) E(Error::CLASS_NO_INHERIT, getSrcInfo(), "internal", "other"); // Mark parent classes as polymorphic as well. if (typeAst && !cachedCls->hasRTTI()) { if (ctx->cache->isJit && cachedCls->jitCell != ctx->cache->jitCell) E(Error::CUSTOM, cls, "cannot inherit from a non-RTTI class defined in previous cell '{}' " "in JIT mode", getUnmangledName(clsTyp->name)); cachedCls->rtti = true; } // Add hidden generics addClassGenerics(clsTyp); for (auto &g : clsTyp->generics) { g.type = g.getType()->generalize(ctx->typecheckLevel); typ->hiddenGenerics.push_back(g); } for (auto &g : clsTyp->hiddenGenerics) { g.type = g.getType()->generalize(ctx->typecheckLevel); typ->hiddenGenerics.push_back(g); } // Add class variables for (auto &[varName, varCanonicalName] : cachedCls->classVars) { // Handle class variables. 
Transform them later to allow self-references auto newName = fmt::format("{}.{}", canonicalName, varName); auto newCanonicalName = ctx->generateCanonicalName(newName); getClass(typ)->classVars[varName] = varCanonicalName; ctx->add(newName, ctx->forceFind(varCanonicalName)); ctx->add(newCanonicalName, ctx->forceFind(varCanonicalName)); } } // Add normal fields auto cls = getClass(canonicalName); for (auto &clsTyp : asts) { withClassGenerics(clsTyp->getClass(), [&]() { int ai = 0; auto ast = getClass(clsTyp->getClass())->ast; for (auto &a : *ast) { auto acls = getClass(ast->name); if (a.isValue() && !ClassStmt::isClassVar(a)) { auto name = a.getName(); int i = 0; for (auto &aa : args) i += aa.getName() == a.getName() || startswith(aa.getName(), a.getName() + "#"); if (i) name = fmt::format("{}#{}", name, i); seqassert(acls->fields[ai].name == a.getName(), "bad class fields: {} vs {}", acls->fields[ai].name, a.getName()); args.emplace_back(name, transformType(clean_clone(a.getType()), true), transform(clean_clone(a.getDefault()))); cls->fields.emplace_back( name, extractType(args.back().getType())->shared_from_this(), acls->fields[ai].baseClass); ai++; } } return true; }); } if (typeAst) { if (!asts.empty() || typ->name == getMangledClass("std.internal.builtin", "object")) { mro.push_back(asts); cls->rtti = true; } cls->mro = Cache::mergeC3(mro); if (cls->mro.empty()) { E(Error::CLASS_BAD_MRO, getSrcInfo()); } } return asts; } /// Find the first __init__ with self parameter and use it to deduce class members. /// Each deduced member will be treated as generic. /// @example /// ```@deduce /// class Foo: /// def __init__(self): /// self.x, self.y = 1, 2``` /// will result in /// ```class Foo[T1, T2]: /// x: T1 /// y: T2``` /// @return the transformed init and the pointer to the original function. 
bool TypecheckVisitor::autoDeduceMembers(ClassStmt *stmt, std::vector &args) { std::set members; for (const auto &sp : getClassMethods(stmt->suite)) if (auto f = cast(sp)) { if (f->name == "__init__") if (const auto b = f->getAttribute(Attr::ClassDeduce)) { for (const auto &m : b->values) members.insert(m); } } if (!members.empty()) { // log("auto-deducing {}: {}", stmt->name, members); if (auto aa = stmt->getAttribute(Attr::ClassMagic)) std::erase(aa->values, "init"); for (auto m : members) { auto genericName = fmt::format("T_{}", m); args.emplace_back(genericName, N(TYPE_TYPE), N("NoneType"), Param::Generic); args.emplace_back(m, N(genericName)); } return true; } return false; } /// Return a list of all statements within a given class suite. /// Checks each suite recursively, and assumes that each statement is either /// a function, a class or a docstring. std::vector TypecheckVisitor::getClassMethods(Stmt *s) { std::vector v; if (!s) return v; if (auto sp = cast(s)) { for (auto *ss : *sp) for (auto *u : getClassMethods(ss)) v.push_back(u); } else if (cast(s) || cast(s)) { v.push_back(s); } else if (!match(s, M(M()))) { E(Error::CLASS_BAD_ATTR, s); } return v; } /// Extract nested classes and transform them before the main class. void TypecheckVisitor::transformNestedClasses(const ClassStmt *stmt, std::vector &clsStmts, std::vector &varStmts, std::vector &fnStmts) { for (const auto &sp : getClassMethods(stmt->suite)) if (auto cp = cast(sp)) { auto origName = cp->getName(); // If class B is nested within A, it's name is always A.B, never B itself. 
// Ensure that parent class name is appended auto parentName = stmt->getName(); cp->name = fmt::format("{}.{}", parentName, origName); auto tsp = transform(cp); if (auto tss = cast(tsp)) { std::string name; for (auto &s : *tss) if (auto c = cast(s)) { clsStmts.push_back(s); name = c->getName(); } else if (cast(s)) { varStmts.push_back(s); } else { fnStmts.push_back(s); } ctx->add(origName, ctx->forceFind(name)); } } } /// Generate a magic method `__op__` for each magic `op` /// described by @param typExpr and its arguments. /// Currently, generate: /// @li Constructors: __new__, __init__ /// @li Utilities: __raw__, __hash__, __repr__, __tuplesize__, __add__, __mul__, __len__ /// @li Iteration: __iter__, __getitem__, __len__, __contains__ /// @li Comparisons: __eq__, __ne__, __lt__, __le__, __gt__, __ge__ /// @li Pickling: __pickle__, __unpickle__ /// @li Python: __to_py__, __from_py__ /// @li GPU: __to_gpu__, __from_gpu__, __from_gpu_new__ Stmt *TypecheckVisitor::codegenMagic(const std::string &op, Expr *typExpr, const std::vector &allArgs, bool isRecord) { #define I(s) N(s) #define NS(x) N(N("__magic__"), (x)) seqassert(typExpr, "typExpr is null"); Expr *ret = nullptr; std::vector fargs; std::vector stmts; std::vector attrs{Attr::AutoGenerated}; std::vector args; args.reserve(allArgs.size()); for (auto &a : allArgs) args.push_back(clone(a)); if (op == "new") { ret = clone(typExpr); if (isRecord) { // Tuples: def __new__() -> T (internal) for (auto &a : args) fargs.emplace_back(a.getName(), clone(a.getType()), clone(a.getDefault())); attrs.push_back(Attr::Internal); } else { // Classes: def __new__() -> T stmts.emplace_back(N(N(NS(op), clone(typExpr)))); } } else if (op == "init") { // Classes: def __init__(self: T, a1: T1, ..., aN: TN) -> None: // self.aI = aI ... 
ret = I("NoneType"); fargs.emplace_back("self", clone(typExpr)); for (auto &a : args) { fargs.emplace_back(a.getName(), clean_clone(a.getType()), clone(a.getDefault())); stmts.push_back( N(N(I("self"), a.getName()), I(a.getName()))); } } else if (op == "raw" || op == "dict") { // Classes: def __raw__(self: T) fargs.emplace_back("self", clone(typExpr)); stmts.emplace_back(N(N(NS(op), I("self")))); } else if (op == "tuplesize") { // def __tuplesize__() -> int ret = I("int"); stmts.emplace_back(N(N(NS(op)))); } else if (op == "getitem") { // Tuples: def __getitem__(self: T, index: int) fargs.emplace_back("self", clone(typExpr)); fargs.emplace_back("index", I("int")); stmts.emplace_back(N(N(NS(op), I("self"), I("index")))); } else if (op == "iter") { // Tuples: def __iter__(self: T) fargs.emplace_back("self", clone(typExpr)); stmts.emplace_back(N(N(NS(op), I("self")))); } else if (op == "contains") { // Tuples: def __contains__(self: T, what) -> bool fargs.emplace_back("self", clone(typExpr)); fargs.emplace_back("what", nullptr); ret = I("bool"); stmts.emplace_back(N(N(NS(op), I("self"), I("what")))); } else if (op == "eq" || op == "ne" || op == "lt" || op == "le" || op == "gt" || op == "ge") { // def __op__(self: T, obj: T) -> bool fargs.emplace_back("self", clone(typExpr)); fargs.emplace_back("obj", clone(typExpr)); ret = I("bool"); stmts.emplace_back(N(N(NS(op), I("self"), I("obj")))); } else if (op == "hash" || op == "len") { // def __hash__(self: T) -> int fargs.emplace_back("self", clone(typExpr)); ret = I("int"); stmts.emplace_back(N(N(NS(op), I("self")))); } else if (op == "pickle") { // def __pickle__(self: T, dest: Ptr[byte]) fargs.emplace_back("self", clone(typExpr)); fargs.emplace_back("dest", N(I("Ptr"), I("byte"))); stmts.emplace_back(N(N(NS(op), I("self"), I("dest")))); } else if (op == "unpickle" || op == "from_py") { // def __unpickle__(src: Ptr[byte]) -> T fargs.emplace_back("src", N(I("Ptr"), I("byte"))); ret = clone(typExpr); 
stmts.emplace_back(N(N(NS(op), I("src"), clone(typExpr)))); } else if (op == "to_py") { // def __to_py__(self: T) -> Ptr[byte] fargs.emplace_back("self", clone(typExpr)); ret = N(I("Ptr"), I("byte")); stmts.emplace_back(N(N(NS(op), I("self")))); } else if (op == "to_gpu") { // def __to_gpu__(self: T, cache) -> T fargs.emplace_back("self", clone(typExpr)); fargs.emplace_back("cache"); ret = clone(typExpr); stmts.emplace_back(N(N(NS(op), I("self"), I("cache")))); } else if (op == "from_gpu") { // def __from_gpu__(self: T, other: T) fargs.emplace_back("self", clone(typExpr)); fargs.emplace_back("other", clone(typExpr)); stmts.emplace_back(N(N(NS(op), I("self"), I("other")))); } else if (op == "from_gpu_new") { // def __from_gpu_new__(other: T) -> T fargs.emplace_back("other", clone(typExpr)); ret = clone(typExpr); stmts.emplace_back(N(N(NS(op), I("other")))); } else if (op == "repr") { // def __repr__(self: T) -> str fargs.emplace_back("self", clone(typExpr)); ret = I("str"); stmts.emplace_back(N(N(NS(op), I("self")))); } else if (op == "repr_default") { // def __repr_default__(self: T) -> str fargs.emplace_back("self", clone(typExpr)); ret = I("str"); stmts.emplace_back(N(N(NS(op), I("self")))); } else if (op == "add") { // def __add__(self, obj) fargs.emplace_back("self", clone(typExpr)); fargs.emplace_back("obj", nullptr); stmts.emplace_back(N(N(NS(op), I("self"), I("obj")))); } else if (op == "mul") { // def __mul__(self, i: Literal[int]) fargs.emplace_back("self", clone(typExpr)); fargs.emplace_back("i", N(I("Literal"), I("int"))); stmts.emplace_back(N(N(NS(op), I("self"), I("i")))); } else { seqassert(false, "invalid magic {}", op); } #undef I #undef NS auto t = NC(fmt::format("__{}__", op), ret, fargs, NC(stmts)); for (auto &a : attrs) t->setAttribute(a); t->setSrcInfo(ctx->cache->generateSrcInfo()); return t; } int TypecheckVisitor::generateKwId(const std::vector &names) const { auto key = join(names, ";"); std::string suffix; if (!names.empty()) { // Each set 
of names generates different tuple (i.e., `KwArgs[foo, bar]` is not the // same as `KwArgs[bar, baz]`). Cache the names and use an integer for each name // combination. if (!in(ctx->cache->generatedKwTuples, key)) { ctx->cache->generatedTupleNames.push_back(names); ctx->cache->generatedKwTuples[key] = static_cast(ctx->cache->generatedKwTuples.size()) + 1; } return ctx->cache->generatedKwTuples[key]; } else { return 0; } } types::ClassType *TypecheckVisitor::generateTuple(size_t n, bool generateNew) { if (n > MAX_TUPLE) E(Error::CUSTOM, getSrcInfo(), "tuple too large ({})", n); auto key = fmt::format("{}.{}", TYPE_TUPLE, n); auto val = getImport(STDLIB_IMPORT)->ctx->find(key); if (!val) { auto t = std::make_shared(ctx->cache, TYPE_TUPLE); t->isTuple = true; auto cls = getClass(t.get()); seqassert(n <= cls->fields.size(), "tuple too large"); for (size_t i = 0; i < n; i++) { const auto &f = cls->fields[i]; auto gt = f.getType()->getLink(); t->generics.emplace_back(cast(f.typeExpr)->getValue(), f.type, gt->id, LiteralKind::Runtime); } val = getImport(STDLIB_IMPORT)->ctx->addType(key, key, t); } auto t = val->getType()->getClass(); if (generateNew && !in(ctx->cache->generatedTuples, n)) { ctx->cache->generatedTuples.insert(n); std::vector newFnArgs; std::vector typeArgs; for (size_t i = 0; i < n; i++) { newFnArgs.emplace_back(fmt::format("item{}", i + 1), N(fmt::format("T{}", i + 1))); typeArgs.emplace_back(N(fmt::format("T{}", i + 1))); } for (size_t i = 0; i < n; i++) { newFnArgs.emplace_back(fmt::format("T{}", i + 1), N(TYPE_TYPE)); } Stmt *fn = N( "__new__", N(N(TYPE_TUPLE), N(typeArgs)), newFnArgs, nullptr); fn->setAttribute(Attr::Internal); Stmt *ext = N(TYPE_TUPLE, std::vector{}, fn); ext->setAttribute(Attr::Extend); ext->setAttribute(Attr::AutoGenerated); ext = N(ext); llvm::cantFail(ScopingVisitor::apply(ctx->cache, ext)); auto rctx = getImport(STDLIB_IMPORT)->ctx; auto oldBases = rctx->bases; rctx->bases.clear(); rctx->bases.push_back(oldBases[0]); ext = 
TypecheckVisitor::apply(rctx, ext); rctx->bases = oldBases; preamble->addStmt(ext); } return t; } } // namespace codon::ast ================================================ FILE: codon/parser/visitors/typecheck/collections.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "codon/parser/ast.h" #include "codon/parser/cache.h" #include "codon/parser/common.h" #include "codon/parser/visitors/typecheck/typecheck.h" using namespace codon::error; namespace codon::ast { using namespace types; /// Transform tuples. /// @example /// `(a1, ..., aN)` -> `Tuple.__new__(a1, ..., aN)` void TypecheckVisitor::visit(TupleExpr *expr) { resultExpr = transform(N(N(N(TYPE_TUPLE), "__new__"), expr->items)); } /// Transform a list `[a1, ..., aN]` to the corresponding statement expression. /// See @c transformComprehension void TypecheckVisitor::visit(ListExpr *expr) { expr->setType(instantiateUnbound()); auto name = getStdLibType("List")->name; if ((resultExpr = transformComprehension(name, "append", expr->items))) { resultExpr->setAttribute(Attr::ExprList); } } /// Transform a set `{a1, ..., aN}` to the corresponding statement expression. /// See @c transformComprehension void TypecheckVisitor::visit(SetExpr *expr) { expr->setType(instantiateUnbound()); auto name = getStdLibType("Set")->name; if ((resultExpr = transformComprehension(name, "add", expr->items))) { resultExpr->setAttribute(Attr::ExprSet); } } /// Transform a dictionary `{k1: v1, ..., kN: vN}` to a corresponding statement /// expression. See @c transformComprehension void TypecheckVisitor::visit(DictExpr *expr) { expr->setType(instantiateUnbound()); auto name = getStdLibType("Dict")->name; if ((resultExpr = transformComprehension(name, "__setitem__", expr->items))) { resultExpr->setAttribute(Attr::ExprDict); } } /// Transform a tuple generator expression. 
/// @example /// `tuple(expr for i in tuple_generator)` -> `Tuple.N.__new__(expr...)` void TypecheckVisitor::visit(GeneratorExpr *expr) { // List comprehension optimization: // Use `iter.__len__()` when creating list if there is a single for loop // without any if conditions in the comprehension bool canOptimize = expr->kind == GeneratorExpr::ListGenerator && expr->loopCount() == 1; if (canOptimize) { auto iter = transform(clone(cast(expr->getFinalSuite())->getIter())); auto ce = cast(iter); if (IdExpr *id = nullptr; ce && ((id = cast(ce->getExpr())))) { // Turn off this optimization for static items canOptimize &= !startswith(id->getValue(), "std.internal.static"); } } Expr *var = N(getTemporaryVar("gen")); if (expr->kind == GeneratorExpr::ListGenerator) { // List comprehensions expr->setFinalExpr( N(N(clone(var), "append"), expr->getFinalExpr())); auto suite = expr->getFinalSuite(); auto noOptStmt = N(N(clone(var), N(N("List"))), suite); if (canOptimize) { auto optimizeVar = getTemporaryVar("i"); auto origIter = cast(expr->getFinalSuite())->getIter(); auto optStmt = clone(noOptStmt); cast((*cast(optStmt))[1])->iter = N(optimizeVar); optStmt = N( N(N(optimizeVar), clone(origIter)), N( clone(var), N(N("List"), N(N(N(optimizeVar), "__len__")))), (*cast(optStmt))[1]); resultExpr = N( N(N("hasattr"), clone(origIter), N("__len__")), N(optStmt, clone(var)), N(noOptStmt, var)); } else { resultExpr = N(noOptStmt, var); } resultExpr = transform(resultExpr); } else if (expr->kind == GeneratorExpr::SetGenerator) { // Set comprehensions auto head = N(clone(var), N(N("Set"))); expr->setFinalExpr( N(N(clone(var), "add"), expr->getFinalExpr())); auto suite = expr->getFinalSuite(); resultExpr = transform(N(N(head, suite), var)); } else if (expr->kind == GeneratorExpr::DictGenerator) { // Dictionary comprehensions auto head = N(clone(var), N(N("Dict"))); expr->setFinalExpr(N(N(clone(var), "__setitem__"), N(expr->getFinalExpr()))); auto suite = expr->getFinalSuite(); resultExpr = 
transform(N(N(head, suite), var)); } else if (expr->kind == GeneratorExpr::TupleGenerator) { seqassert(expr->loopCount() == 1, "invalid tuple generator"); auto gen = transform(cast(expr->getFinalSuite())->getIter()); if (!gen->getType()->canRealize()) return; // Wait until the iterator can be realized auto block = N(); // `tuple = tuple_generator` auto tupleVar = getTemporaryVar("tuple"); block->addStmt(N(N(tupleVar), gen)); auto forStmt = clone(cast(expr->getFinalSuite())); auto finalExpr = expr->getFinalExpr(); auto [ok, delay, preamble, staticItems] = transformStaticLoopCall( cast(expr->getFinalSuite())->getVar(), &forStmt->suite, gen, [&](Stmt *wrap) { return N(clone(wrap), clone(finalExpr)); }, true); if (!ok) E(Error::CALL_BAD_ITER, gen, gen->getType()->prettyString()); if (delay) return; std::vector tupleItems; for (auto &i : staticItems) tupleItems.push_back(cast(i)); if (preamble) block->addStmt(preamble); resultExpr = transform(N(block, N(tupleItems))); } else { expr->loops = transform(expr->getFinalSuite()); // assume: internal data will be changed if (!expr->getFinalExpr()) { // Case such as (0 for _ in static.range(2)) // TODO: make this better. E(Error::CUSTOM, expr, "generator cannot be compiled. If using static tuple generator, use tuple(...) " "instead."); } unify(expr->getType(), instantiateType(getStdLibType("Generator"), {expr->getFinalExpr()->getType()})); if (realize(expr->getType())) expr->setDone(); } } /// Transform a collection of type `type` to a statement expression: /// `[a1, ..., aN]` -> `cont = [type](); (cont.[fn](a1); ...); cont` /// Any star-expression within the collection will be expanded: /// `[a, *b]` -> `cont.[fn](a); for i in b: cont.[fn](i)`. 
/// @example
///   `[a, *b, c]`   -> ```cont = List(3)
///                        cont.append(a)
///                        for i in b: cont.append(i)
///                        cont.append(c)```
///   `{a, *b, c}`   -> ```cont = Set()
///                        cont.add(a)
///                        for i in b: cont.add(i)
///                        cont.add(c)```
///   `{a: 1, **d}`  -> ```cont = Dict()
///                        cont.__setitem__((a, 1))
///                        for i in d.items(): cont.__setitem__((i[0], i[1]))```
/// @param type  Name of the collection type; compared against the names of the
///              standard library `Dict` and `List` types and passed to
///              `getStdLibType`.
/// @param fn    Name of the member called on the container for every element.
/// @param items Collection items; they are transformed in place.
/// @return The transformed statement expression, or nullptr when the class type of
///         at least one item is not yet available.
Expr *TypecheckVisitor::transformComprehension(const std::string &type,
                                               const std::string &fn,
                                               std::vector &items) {
  // NOTE(review): template argument lists (e.g. on `N`, `cast`, `std::vector`) are
  // absent in this copy of the source -- presumably stripped during extraction;
  // confirm against the upstream file.

  // Deduce the super type of the collection--- in other words, the least common
  // ancestor of all types in the collection. For example, `type([1, 1.2]) == type([1.2,
  // 1]) == float` because float is an "ancestor" of int.
  // TODO: use wrapExpr...
  // Returns nullptr when none of the rules below applies.
  auto superTyp = [&](ClassType *collectionCls, ClassType *ti) -> TypePtr {
    if (!collectionCls)
      return ti->shared_from_this();
    if (collectionCls->is("int") && ti->is("float")) {
      // Rule: int derives from float
      return ti->shared_from_this();
    } else if (collectionCls->name != TYPE_OPTIONAL && ti->name == TYPE_OPTIONAL) {
      // Rule: T derives from Optional[T]
      return instantiateType(getStdLibType("Optional"),
                             std::vector{collectionCls});
    } else if (collectionCls->name == TYPE_OPTIONAL && ti->name != TYPE_OPTIONAL) {
      return instantiateType(getStdLibType("Optional"), std::vector{ti});
    } else if (!collectionCls->is("pyobj") && ti->is("pyobj")) {
      // Rule: anything derives from pyobj
      return ti->shared_from_this();
    } else if (collectionCls->name != ti->name) {
      // Rule: subclass derives from superclass
      const auto &mros = getClass(collectionCls)->mro;
      for (size_t i = 1; i < mros.size(); i++) {
        auto t = instantiateType(mros[i].get(), collectionCls);
        if (t->unify(ti, nullptr) >= 0) {
          return ti->shared_from_this();
        }
      }
    }
    return nullptr;
  };

  TypePtr collectionTyp = instantiateUnbound();
  bool done = true; // cleared as soon as one item has no class type yet
  bool isDict = type == getStdLibType("Dict")->name;
  for (auto &i : items) {
    ClassType *typ = nullptr;
    if (!isDict && cast(i)) {
      // Star item: replace its expression by a call to `__iter__` on it.
      auto star = cast(i);
      star->expr = transform(N(N(star->getExpr(), "__iter__")));
      if (star->getExpr()->getType()->is("Generator"))
        typ = extractClassGeneric(star->getExpr()->getType())->getClass();
    } else if (isDict && cast(i)) {
      // Keyword-star item: replace its expression by a call to `items` on it.
      auto star = cast(i);
      star->expr = transform(N(N(star->getExpr(), "items")));
      if (star->getExpr()->getType()->is("Generator"))
        typ = extractClassGeneric(star->getExpr()->getType())->getClass();
    } else {
      i = transform(i);
      typ = i->getClassType();
    }
    if (!typ) {
      done = false;
      continue;
    }
    if (!collectionTyp->getClass()) {
      unify(collectionTyp.get(), typ);
    } else if (!isDict) {
      if (auto t = superTyp(collectionTyp->getClass(), typ))
        collectionTyp = t;
    } else {
      // Dictionary items are two-element tuples; widen each of the two generics
      // separately.
      auto tt = unify(typ, instantiateType(generateTuple(2)))->getClass();
      seqassert(collectionTyp->getClass() &&
                    collectionTyp->getClass()->generics.size() == 2 &&
                    tt->generics.size() == 2,
                "bad dict");
      std::vector nt;
      for (int di = 0; di < 2; di++) {
        nt.push_back(extractClassGeneric(collectionTyp.get(), di)->shared_from_this());
        if (!nt[di]->getClass())
          unify(nt[di].get(), extractClassGeneric(tt, di));
        else if (auto dt = superTyp(nt[di]->getClass(),
                                    extractClassGeneric(tt, di)->getClass()))
          nt[di] = dt;
      }
      collectionTyp =
          instantiateType(generateTuple(nt.size()), ctx->cache->castVectorPtr(nt));
    }
  }
  if (!done)
    return nullptr;
  std::vector stmts;
  Expr *var = N(getTemporaryVar("cont"));

  std::vector constructorArgs{};
  if (type == getStdLibType("List")->name && !items.empty()) {
    // Optimization: pre-allocate the list with the exact number of elements
    constructorArgs.push_back(N(items.size()));
  }
  auto t = N(type);
  auto ta = instantiateType(getStdLibType(type));
  if (isDict && collectionTyp->getClass()) {
    seqassert(collectionTyp->getClass()->isRecord(), "bad dict");
    std::vector nt;
    for (auto &g : collectionTyp->getClass()->generics)
      nt.push_back(g.getType());
    ta = instantiateType(getStdLibType(type), nt);
  } else if (!isDict) {
    ta = instantiateType(getStdLibType(type), {collectionTyp.get()});
  }
  t->setType(instantiateTypeVar(ta.get()));
  stmts.push_back(N(clone(var), N(t, constructorArgs)));
  for (const auto &it : items) {
    if (!isDict && cast(it)) {
      // Unpack star-expression by iterating over it
      // `*star` -> `for i in star: cont.[fn](i)`
      auto star = cast(it);
      Expr *forVar = N(getTemporaryVar("i"));
      star->getExpr()->setAttribute(Attr::ExprStarSequenceItem);
      stmts.push_back(N(
          clone(forVar), star->getExpr(),
          N(N(N(clone(var), fn), clone(forVar)))));
    } else if (isDict && cast(it)) {
      // Expand kwstar-expression by iterating over it: see the example above
      auto star = cast(it);
      Expr *forVar = N(getTemporaryVar("it"));
      star->getExpr()->setAttribute(Attr::ExprStarSequenceItem);
      stmts.push_back(N(
          clone(forVar), star->getExpr(),
          N(N(N(clone(var), fn),
              N(clone(forVar), N(0)),
              N(clone(forVar), N(1))))));
    } else {
      it->setAttribute(Attr::ExprSequenceItem);
      if (isDict) {
        // The item is indexed twice (0 and 1); when it has side effects it is
        // bound to a temporary so that it is only evaluated once.
        Expr *head = it, *lead = nullptr;
        if (hasSideEffect(head)) {
          auto var = getTemporaryVar("star");
          lead = N(N(var), head);
          head = N(var);
        } else {
          lead = clone(head);
        }
        lead->setAttribute(Attr::ExprSequenceItem);
        head->setAttribute(Attr::ExprSequenceItem);
        stmts.push_back(N(N(N(clone(var), fn),
                            N(lead, N(0)),
                            N(head, N(1)))));
      } else {
        it->setAttribute(Attr::ExprSequenceItem);
        stmts.push_back(N(N(N(clone(var), fn), it)));
      }
    }
  }
  return transform(N(stmts, var));
}

} // namespace codon::ast

================================================
FILE: codon/parser/visitors/typecheck/cond.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

using namespace codon::error;

namespace codon::ast {

using namespace types;

/// Call `ready` and `notReady` depending whether the provided static expression can be
/// evaluated or not.
// NOTE(review): the template parameter list (for `TT` and `TF`) is absent in this
// copy of the source -- presumably stripped during extraction, like the template
// arguments of `N` below; confirm against the upstream file.
template auto evaluateStaticCondition(Expr *cond, TT ready, TF notReady) {
  seqassertn(cond->getType()->getStaticKind(), "not a static condition");
  if (cond->getType()->canRealize()) {
    // Truthiness of the static value: non-empty string, or the int/bool value itself.
    bool isTrue = false;
    if (auto as = cond->getType()->getStrStatic())
      isTrue = !as->value.empty();
    else if (auto ai = cond->getType()->getIntStatic())
      isTrue = ai->value;
    else if (auto ab = cond->getType()->getBoolStatic())
      isTrue = ab->value;
    return ready(isTrue);
  } else {
    return notReady();
  }
}

/// Only allowed in @c MatchStmt
void TypecheckVisitor::visit(RangeExpr *expr) {
  E(Error::UNEXPECTED_TYPE, expr, "range");
}

/// Typecheck if expressions. Evaluate static if blocks if possible.
/// Also wrap conditional expressions to match each other. See @c wrapExpr for more
/// details.
void TypecheckVisitor::visit(IfExpr *expr) {
  // Transform the condition with `bool` as the expected type, then restore the
  // previous expected type.
  auto oldExpectedType = getStdLibType("bool")->shared_from_this();
  std::swap(ctx->expectedType, oldExpectedType);
  expr->cond = transform(expr->getCond());
  std::swap(ctx->expectedType, oldExpectedType);

  // Static if evaluation
  if (expr->getCond()->getType()->getStaticKind()) {
    // Only the selected branch is transformed; nullptr means the condition cannot
    // be evaluated yet.
    resultExpr = evaluateStaticCondition(
        expr->getCond(),
        [&](bool isTrue) {
          LOG_TYPECHECK("[static::cond] {}: {}", getSrcInfo(), isTrue);
          return transform(isTrue ? expr->getIf() : expr->getElse());
        },
        [&]() -> Expr * { return nullptr; });
    if (resultExpr)
      unify(expr->getType(), resultExpr->getType());
    else if (expr->getType()->getUnbound())
      expr->getType()->getUnbound()->staticKind =
          LiteralKind::Int; // determine later!
    return;
  }

  expr->ifexpr = transform(expr->getIf());
  expr->elsexpr = transform(expr->getElse());
  wrapExpr(&expr->cond, getStdLibType("bool"));

  // Add wrappers and unify both sides
  // Static branch types are first replaced by their non-static counterparts.
  if (expr->getIf()->getType()->getStatic())
    expr->getIf()->setType(
        expr->getIf()->getType()->getStatic()->getNonStaticType()->shared_from_this());
  if (expr->getElse()->getType()->getStatic())
    expr->getElse()->setType(expr->getElse()
                                 ->getType()
                                 ->getStatic()
                                 ->getNonStaticType()
                                 ->shared_from_this());
  wrapExpr(&expr->elsexpr, expr->getIf()->getType(), nullptr, /*allowUnwrap*/ false);
  wrapExpr(&expr->ifexpr, expr->getElse()->getType(), nullptr, /*allowUnwrap*/ false);

  unify(expr->getType(), expr->getIf()->getType());
  unify(expr->getType(), expr->getElse()->getType());
  if (expr->getCond()->isDone() && expr->getIf()->isDone() && expr->getElse()->isDone())
    expr->setDone();
}

/// Typecheck if statements. Evaluate static if blocks if possible.
/// See @c wrapExpr for more details.
void TypecheckVisitor::visit(IfStmt *stmt) {
  // Transform the condition with `bool` as the expected type, then restore the
  // previous expected type.
  auto oldExpectedType = getStdLibType("bool")->shared_from_this();
  std::swap(ctx->expectedType, oldExpectedType);
  stmt->cond = transform(stmt->getCond());
  std::swap(ctx->expectedType, oldExpectedType);

  // Static if evaluation
  if (stmt->getCond()->getType()->getStaticKind()) {
    // Only the selected branch is transformed; a missing branch is replaced by a
    // freshly created (argument-less) node so that the result is never null once
    // the condition is known.
    resultStmt = evaluateStaticCondition(
        stmt->getCond(),
        [&](bool isTrue) {
          LOG_TYPECHECK("[static::cond] {}: {}", getSrcInfo(), isTrue);
          auto t = transform(isTrue ? stmt->getIf() : stmt->getElse());
          return t ? t : transform(N());
        },
        [&]() -> Stmt * { return nullptr; });
    return;
  }

  wrapExpr(&stmt->cond, getStdLibType("bool"));
  ctx->blockLevel++;
  stmt->ifSuite = SuiteStmt::wrap(transform(stmt->getIf()));
  stmt->elseSuite = SuiteStmt::wrap(transform(stmt->getElse()));
  ctx->blockLevel--;

  if (stmt->cond->isDone() && (!stmt->getIf() || stmt->getIf()->isDone()) &&
      (!stmt->getElse() || stmt->getElse()->isDone()))
    stmt->setDone();
}

/// Simplify match statement by transforming it into a series of conditional statements.
/// @example
///   ```match e:
///        case pattern1: ...
///        case pattern2 if guard: ...
///        ...``` ->
///   ```_match = e
///      while True:  # used to simulate goto statement with break
///        [pattern1 transformation]: (...; break)
///        [pattern2 transformation]: if guard: (...; break)
///        ...
///        break  # exit the loop no matter what```
/// The first pattern that matches the given expression will be used; other patterns
/// will not be used (i.e., there is no fall-through). See @c transformPattern for
/// pattern transformations
void TypecheckVisitor::visit(MatchStmt *stmt) {
  auto var = getTemporaryVar("match");
  auto result = N();
  result->addStmt(transform(N(N(var), clone(stmt->getExpr()))));
  for (auto &c : *stmt) {
    Stmt *suite = N(c.getSuite(), N());
    if (c.getGuard())
      suite = N(c.getGuard(), suite);
    result->addStmt(transformPattern(N(var), c.getPattern(), suite));
  }
  // Make sure to break even if there is no case _ to prevent infinite loop
  result->addStmt(N());
  resultStmt = transform(N(N(true), result));
}

/// Transform a match pattern into a series of if statements.
/// @example /// `case True` -> `if isinstance(var, "bool"): if var == True` /// `case 1` -> `if isinstance(var, "int"): if var == 1` /// `case 1...3` -> ```if isinstance(var, "int"): /// if var >= 1: if var <= 3``` /// `case (1, pat)` -> ```if isinstance(var, "Tuple"): if static.len(var) == 2: /// if match(var[0], 1): if match(var[1], pat)``` /// `case [1, ..., pat]` -> ```if isinstance(var, "List"): if len(var) >= 2: /// if match(var[0], 1): if match(var[-1], pat)``` /// `case 1 or pat` -> `if match(var, 1): if match(var, pat)` /// (note: pattern suite is cloned for each `or`) /// `case (x := pat)` -> `(x := var; if match(var, pat))` /// `case x` -> `(x := var)` /// (only when `x` is not '_') /// `case expr` -> `if hasattr(typeof(var), "__match__"): if /// var.__match__(foo())` /// (any expression that does not fit above patterns) Stmt *TypecheckVisitor::transformPattern(Expr *var, Expr *pattern, Stmt *suite) { // Convenience function to generate `isinstance(e, typ)` calls auto isinstance = [&](Expr *e, const std::string &typ) -> Expr * { return N(N("isinstance"), clone(e), N(typ)); }; // Convenience function to find the index of an ellipsis within a list pattern auto findEllipsis = [&](const std::vector &items) { size_t i = items.size(); for (auto it = 0; it < items.size(); it++) if (cast(items[it])) { if (i != items.size()) E(Error::MATCH_MULTI_ELLIPSIS, items[it], "multiple ellipses in pattern"); i = it; } return i; }; // See the above examples for transformation details if (cast(pattern) || cast(pattern)) { // Bool and int patterns return N(isinstance(var, cast(pattern) ? 
"bool" : "int"), N(N(var, "==", pattern), suite)); } else if (auto er = cast(pattern)) { // Range pattern return N( isinstance(var, "int"), N(N(var, ">=", er->start), N(N(clone(var), "<=", er->stop), suite))); } else if (auto et = cast(pattern)) { // Tuple pattern for (auto it = et->items.size(); it-- > 0;) { suite = transformPattern(N(clone(var), N(it)), (*et)[it], suite); } return N( isinstance(var, "Tuple"), N(N( N( N(getMangledFunc("std.internal.static", "len")), var), "==", N(et->size())), suite)); } else if (auto el = cast(pattern)) { // List pattern size_t ellipsis = findEllipsis(el->items), sz = el->size(); std::string op; if (ellipsis == el->size()) { op = "=="; } else { op = ">=", sz -= 1; } for (auto it = el->size(); it-- > ellipsis + 1;) { suite = transformPattern(N(clone(var), N(it - el->size())), (*el)[it], suite); } for (auto it = ellipsis; it-- > 0;) { suite = transformPattern(N(clone(var), N(it)), (*el)[it], suite); } return N( isinstance(var, "List"), N(N(N(N("len"), var), op, N(sz)), suite)); } else if (auto eb = cast(pattern)) { // Or pattern if (eb->op == "|" || eb->op == "||") { return N(transformPattern(clone(var), eb->lexpr, clone(suite)), transformPattern(var, eb->rexpr, suite)); } } else if (auto ei = cast(pattern)) { // Wildcard pattern if (ei->value != "_") { return N(N(pattern, var), suite); } else { return suite; } } else if (auto ea = cast(pattern)) { // Bound pattern seqassert(cast(ea->getVar()), "only simple assignment expressions are supported"); return N(N(ea->getVar(), clone(var)), transformPattern(var, ea->getExpr(), suite)); } pattern = transform(pattern); // transform to check for pattern errors if (cast(pattern)) pattern = N(N("ellipsis")); // Fallback (`__match__`) pattern auto p = N(N(N("hasattr"), clone(var), N("__match__"), clone(pattern)), N(N(N(var, "__match__"), pattern), suite)); return p; } } // namespace codon::ast ================================================ FILE: codon/parser/visitors/typecheck/ctx.cpp 
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "ctx.h"

// NOTE(review): the four directives below have no header name in this copy of the
// source -- presumably `<...>` names stripped during extraction (the same applies
// to template argument lists further down); confirm against the upstream file.
#include
#include
#include
#include

#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/visitors/format/format.h"
#include "codon/parser/visitors/scoping/scoping.h"

using namespace codon::error;

namespace codon::ast {

/// Construct an item; every argument is moved into the member of the same name.
TypecheckItem::TypecheckItem(std::string canonicalName, std::string baseName,
                             std::string moduleName, types::TypePtr type,
                             std::vector scope)
    : canonicalName(std::move(canonicalName)), baseName(std::move(baseName)),
      moduleName(std::move(moduleName)), type(std::move(type)),
      scope(std::move(scope)) {}

/// Create a context that starts with a single base and a single scope block (id 0).
TypeContext::TypeContext(Cache *cache, std::string filename)
    : Context(std::move(filename)), cache(cache) {
  bases.emplace_back();
  scope.emplace_back(0);
  auto e = cache->N();
  e->setSrcInfo(cache->generateSrcInfo());
  pushNode(e); // Always have srcInfo() around
}

/// Add `var` under `name`. Asserts that the scope of the item is not empty.
void TypeContext::add(const std::string &name, const TypeContext::Item &var) {
  seqassert(!var->scope.empty(), "bad scope for '{}'", name);
  Context::add(name, var);
}

/// Forwards to the base-class implementation.
void TypeContext::removeFromMap(const std::string &name) {
  Context::removeFromMap(name);
}

/// Create an item for `canonicalName` that records the current base name, module and
/// scope, stamp it with `srcInfo` and `time`, add it under `name` and register it via
/// @c addAlwaysVisible .
TypeContext::Item TypeContext::addVar(const std::string &name,
                                      const std::string &canonicalName,
                                      const types::TypePtr &type, int64_t time,
                                      const SrcInfo &srcInfo) {
  seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name);
  // seqassert(type->getLink(), "bad var");
  auto t = std::make_shared(canonicalName, getBaseName(), getModule(), type,
                            getScope());
  t->setSrcInfo(srcInfo);
  t->time = time;
  add(name, t);
  addAlwaysVisible(t);
  return t;
}

/// Same as @c addVar but without a time stamp (the item keeps its default `time`).
TypeContext::Item TypeContext::addType(const std::string &name,
                                       const std::string &canonicalName,
                                       const types::TypePtr &type,
                                       const SrcInfo &srcInfo) {
  seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name);
  // seqassert(type->getClass(), "bad type");
  auto t = std::make_shared(canonicalName, getBaseName(), getModule(), type,
                            getScope());
  t->setSrcInfo(srcInfo);
  add(name, t);
  addAlwaysVisible(t);
  return t;
}

/// Same as @c addType ; additionally asserts that `type` is a function type.
TypeContext::Item TypeContext::addFunc(const std::string &name,
                                       const std::string &canonicalName,
                                       const types::TypePtr &type,
                                       const SrcInfo &srcInfo) {
  seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name);
  seqassert(type->getFunc(), "bad func");
  auto t = std::make_shared(canonicalName, getBaseName(), getModule(), type,
                            getScope());
  t->setSrcInfo(srcInfo);
  add(name, t);
  addAlwaysVisible(t);
  return t;
}

/// Register `item` under its canonical name in this context and, unless it is
/// already found there, in `cache->typeCtx`. When `pop` is set, the last entry of
/// `stack.front()` is dropped after each registration (see the inline comments).
TypeContext::Item TypeContext::addAlwaysVisible(const TypeContext::Item &item,
                                                bool pop) {
  add(item->canonicalName, item);
  if (pop)
    stack.front().pop_back(); // do not remove it later!
  if (!cache->typeCtx->Context::find(item->canonicalName)) {
    cache->typeCtx->add(item->canonicalName, item);
    if (pop)
      cache->typeCtx->stack.front().pop_back(); // do not remove it later!

    // Realizations etc.
    if (!in(cache->reverseIdentifierLookup, item->canonicalName))
      cache->reverseIdentifierLookup[item->canonicalName] = item->canonicalName;
  }
  return item;
}

/// Look up `name`.
/// For names without a '.', entries whose base name is not a prefix of the lookup
/// base (`inBase` if given, the current base otherwise) are skipped. An entry that
/// lives in the lookup base and the current module is only returned when `time` is
/// zero or the entry's time is not later than `time`; any other candidate is
/// returned immediately. If nothing matches, the standard library context and then
/// `cache->typeCtx` are consulted. Returns nullptr if the name is not found.
TypeContext::Item TypeContext::find(const std::string &name, int64_t time,
                                    const char *inBase) const {
  auto it = map.find(name);
  bool isMangled = in(name, ".");
  std::string base = inBase ? inBase : getBaseName();
  if (it != map.end()) {
    for (auto &i : it->second) {
      if (!isMangled && !startswith(base, i->getBaseName())) {
        continue; // avoid middle realizations
      }
      if (isMangled || i->getBaseName() != base || !time ||
          i->getModule() != getModule()) {
        return i;
      } else {
        if (i->getTime() <= time)
          return i;
      }
    }
  }

  // Item is not found in the current module. Time to look in the standard library!
  // Note: the standard library items cannot be dominated.
  TypeContext::Item t = nullptr;
  auto stdlib = cache->imports[STDLIB_IMPORT].ctx;
  if (stdlib.get() != this)
    t = stdlib->Context::find(name);

  // Maybe we are looking for a canonical identifier?
  if (!t && cache->typeCtx.get() != this)
    t = cache->typeCtx->Context::find(name);

  return t;
}

/// Like @c find (with default arguments), but asserts that the item exists.
TypeContext::Item TypeContext::forceFind(const std::string &name) const {
  auto f = find(name);
  seqassert(f, "cannot find '{}'", name);
  return f;
}

/// Getters and setters

std::string TypeContext::getBaseName() const { return bases.back().name; }

/// Module name, prefixed with "std." for standard library modules. When
/// `startswith(base, "__main__")` yields a non-zero value, that many leading
/// characters are dropped.
std::string TypeContext::getModule() const {
  std::string base = moduleName.status == ImportFile::STDLIB ? "std." : "";
  base += moduleName.module;
  if (auto sz = startswith(base, "__main__"))
    base = base.substr(sz);
  return base;
}

std::string TypeContext::getModulePath() const { return moduleName.path; }

void TypeContext::dump() { dump(0); }

/// Generate a canonical name for `name`.
/// Names that already contain a '.' are returned unchanged. Otherwise, if
/// `includeBase` is set and the name does not start with '%', the name is prefixed
/// with the current base name (or the module name when the base name is empty).
/// Unless `noSuffix` is set, `.N` is appended, where N is the number of times the
/// (prefixed) name has been requested before. The mapping from the new name back to
/// `name` is stored in `cache->reverseIdentifierLookup`.
std::string TypeContext::generateCanonicalName(const std::string &name,
                                               bool includeBase,
                                               bool noSuffix) const {
  std::string newName = name;
  if (name.find('.') != std::string::npos)
    return name;
  includeBase &= !(!name.empty() && name[0] == '%');
  if (includeBase) {
    std::string base = getBaseName();
    if (base.empty())
      base = getModule();
    // Names from `std.internal.core` get neither a prefix nor a suffix.
    if (base == "std.internal.core") {
      noSuffix = true;
      base = "";
    }
    newName = (base.empty() ? "" : (base + ".")) + newName;
  }
  auto num = cache->identifierCount[newName]++;
  if (!noSuffix)
    newName = fmt::format("{}.{}", newName, num);
  if (name != newName)
    cache->identifierCount[newName]++;
  cache->reverseIdentifierLookup[newName] = name;
  return newName;
}

bool TypeContext::isGlobal() const { return bases.size() == 1; }

bool TypeContext::isConditional() const { return scope.size() > 1; }

TypeContext::Base *TypeContext::getBase() {
  return bases.empty() ? nullptr : &(bases.back());
}

bool TypeContext::inFunction() const { return !isGlobal() && !bases.back().isType(); }

bool TypeContext::inClass() const { return !isGlobal() && bases.back().isType(); }

bool TypeContext::isOuter(const Item &val) const {
  return getBaseName() != val->getBaseName() || getModule() != val->getModule();
}

/// Returns the second-to-last base if it is a type base, nullptr otherwise.
TypeContext::Base *TypeContext::getClassBase() {
  if (bases.size() >= 2 && bases[bases.size() - 2].isType())
    return &(bases[bases.size() - 2]);
  return nullptr;
}

size_t TypeContext::getRealizationDepth() const { return bases.size(); }

/// Joins the realized names of all bases that have a type with ':'.
std::string TypeContext::getRealizationStackName() const {
  if (bases.empty())
    return "";
  std::vector s;
  for (auto &b : bases)
    if (b.type)
      s.push_back(b.type->realizedName());
  return join(s, ":");
}

/// Log the current module and base, followed by the first item registered under
/// every name (names are visited in sorted order).
void TypeContext::dump(int pad) {
  auto ordered = std::map(map.begin(), map.end());
  LOG("current module: {} ({})", moduleName.module, moduleName.path);
  LOG("current base: {} / {}", getRealizationStackName(), getBase()->name);
  for (auto &i : ordered) {
    std::string s;
    auto t = i.second.front();
    LOG("{}{:.<25}", std::string(static_cast(pad) * 2, ' '), i.first);
    LOG(" ... kind: {}", t->isType() * 100 + t->isFunc() * 10 + t->isVar());
    LOG(" ... canonical: {}", t->canonicalName);
    LOG(" ... base: {}", t->baseName);
    LOG(" ... module: {}", t->moduleName);
    LOG(" ... type: {}", t->type ? t->type->debugString(2) : "");
    LOG(" ... scope: {}", t->scope);
    LOG(" ... gnrc/sttc: {} / {}", t->generic, static_cast(t->getStaticKind()));
  }
}

/// Formats the current base name, its typechecking iteration and the current
/// source location.
std::string TypeContext::debugInfo() {
  return fmt::format("[{}:i{}@{}]", getBase()->name, getBase()->iteration,
                     getSrcInfo());
}

} // namespace codon::ast

================================================
FILE: codon/parser/visitors/typecheck/ctx.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#pragma once

// NOTE(review): the six directives below have no header name in this copy of the
// source -- presumably `<...>` names stripped during extraction (the same applies
// to template argument lists further down); confirm against the upstream file.
#include
#include
#include
#include
#include
#include

#include "codon/parser/cache.h"
#include "codon/parser/common.h"
#include "codon/parser/ctx.h"

namespace codon::ast {

class TypecheckVisitor;

/**
 * Typecheck context identifier.
 * Can be either a function, a class (type), or a variable.
 */
struct TypecheckItem : public SrcObject {
  /// Unique identifier (canonical name)
  std::string canonicalName;
  /// Base name (e.g., foo.bar.baz)
  std::string baseName;
  /// Full module name
  std::string moduleName;
  /// Type
  types::TypePtr type = nullptr;

  /// Full base scope information
  std::vector scope = {0};

  /// Specifies at which time the name was added to the context.
  /// Used to prevent using later definitions early (can happen in
  /// advanced type checking iterations).
  int64_t time = 0;

  /// Set if an identifier is a class or a function generic
  bool generic = false;

  TypecheckItem(std::string, std::string, std::string, types::TypePtr,
                std::vector = {0});

  /* Convenience getters */
  std::string getBaseName() const { return baseName; }
  std::string getModule() const { return moduleName; }
  /// True if the item is neither a generic, a function nor a type.
  bool isVar() const { return !generic && !isFunc() && !isType(); }
  bool isFunc() const { return type->getFunc() != nullptr; }
  bool isType() const { return type->is(TYPE_TYPE); }
  /// True if the item has a single scope entry and an empty base name.
  bool isGlobal() const { return scope.size() == 1 && baseName.empty(); }
  /// True if an identifier is within a conditional block
  /// (i.e., a block that might not be executed during the runtime)
  bool isConditional() const { return scope.size() > 1; }
  bool isGeneric() const { return generic; }
  types::LiteralKind getStaticKind() const { return type->getStaticKind(); }
  types::Type *getType() const { return type.get(); }
  std::string getName() const { return canonicalName; }

  int64_t getTime() const { return time; }
};

/** Context class that tracks identifiers during the typechecking. **/
struct TypeContext : public Context {
  /// A pointer to the shared cache.
  Cache *cache;

  /// Holds the information about current scope.
  /// A scope is defined as a stack of conditional blocks
  /// (i.e., blocks that might not get executed during the runtime).
  /// Used mainly to support Python's variable scoping rules.
  struct ScopeBlock {
    int id;
    std::unordered_map> replacements;
    /// List of statements that are to be prepended to a block
    /// after its transformation.
    std::vector stmts;
    ScopeBlock(int id) : id(id) {}
  };
  /// Current hierarchy of conditional blocks.
  std::vector scope;
  /// Returns the ids of all blocks in `scope`, in order.
  std::vector getScope() const {
    std::vector result;
    result.reserve(scope.size());
    for (const auto &b : scope)
      result.emplace_back(b.id);
    return result;
  }

  /// Holds the information about current base.
  /// A base is defined as a function or a class block.
  struct Base {
    /// Canonical name of a function or a class that owns this base.
    std::string name;
    /// Function type
    types::TypePtr type;
    /// The return type of currently realized function
    types::TypePtr returnType;
    /// Typechecking iteration
    int iteration = 0;
    /// Only set for functions.
    FunctionStmt *func = nullptr;
    Stmt *suite = nullptr;
    /// Index of the parent base
    int parent = 0;

    struct {
      /// Set if the base is class base and if class is marked with @deduce.
      /// Stores the list of class fields in the order of traversal.
      std::shared_ptr> deducedMembers = nullptr;
      /// Canonical name of `self` parameter that is used to deduce class fields
      /// (e.g., self in self.foo).
      std::string selfName;
    } deduce;

    /// Map of captured identifiers (i.e., identifiers not defined in a function).
    /// Captured (canonical) identifiers are mapped to the new canonical names
    /// (representing the canonical function argument names that are appended to the
    /// function after processing) and their types (indicating if they are a type, a
    /// static or a variable).
    // std::unordered_set captures;

    /// Map of identifiers that are to be fetched from Python.
    std::unordered_set *pyCaptures = nullptr;

    // /// Scope that defines the base.
    // std::vector scope;

    /// A stack of nested loops enclosing the current statement used for transforming
    /// "break" statement in loop-else constructs. Each loop is defined by a "break"
    /// variable created while parsing a loop-else construct. If a loop has no else
    /// block, the corresponding loop variable is empty.
    struct Loop {
      std::string breakVar;
      /// False if a loop has continue/break statement. Used for flattening static
      /// loops.
      bool flat = true;
      Loop(const std::string &breakVar) : breakVar(breakVar) {}
    };
    std::vector loops;

    std::map> pendingDefaults;

  public:
    /// Innermost loop, or nullptr when `loops` is empty.
    Loop *getLoop() { return loops.empty() ? nullptr : &(loops.back()); }
    /// A base without a function statement is treated as a type base.
    bool isType() const { return func == nullptr; }
  };
  /// Current base stack (the last enclosing base is the last base in the stack).
  std::vector bases;

  /// Scope guard: pushes a base named `name` and adds a block on construction; pops
  /// the base and the block on destruction.
  struct BaseGuard {
    TypeContext *holder;
    BaseGuard(TypeContext *holder, const std::string &name) : holder(holder) {
      holder->bases.emplace_back();
      holder->bases.back().name = name;
      holder->addBlock();
    }
    ~BaseGuard() {
      holder->bases.pop_back();
      holder->popBlock();
    }
  };

  /// Current module. The default module is named `__main__`.
  ImportFile moduleName = {ImportFile::PACKAGE, "", ""};
  /// Set if the standard library is currently being loaded.
  bool isStdlibLoading = false;

  /// The current type-checking level (for type instantiation and generalization).
  int typecheckLevel = 0;
  int changedNodes = 0;

  /// Number of nested realizations. Used to prevent infinite instantiations.
  int realizationDepth = 0;

  /// Number of nested blocks (0 for toplevel)
  int blockLevel = 0;
  /// True if an early return is found (anything afterwards won't be typechecked)
  bool returnEarly = false;
  /// Stack of static loop control variables (used to emulate goto statements).
  std::vector staticLoops = {};

  /// Current statement time.
  int64_t time = 0;

  /// @brief Type to be expected upon completed typechecking.
  types::TypePtr expectedType = nullptr;

  bool autoPython = false;

  /// True if no side-effects allowed
  bool simpleTypes = false;

  std::unordered_map globalShadows;

  std::unordered_map _itime;

public:
  explicit TypeContext(Cache *cache, std::string filename = "");

  void add(const std::string &name, const Item &var) override;
  /// Convenience method for adding an object to the context.
  Item addVar(const std::string &name, const std::string &canonicalName,
              const types::TypePtr &type, int64_t time = 0,
              const SrcInfo &srcInfo = SrcInfo());
  Item addType(const std::string &name, const std::string &canonicalName,
               const types::TypePtr &type, const SrcInfo &srcInfo = SrcInfo());
  Item addFunc(const std::string &name, const std::string &canonicalName,
               const types::TypePtr &type, const SrcInfo &srcInfo = SrcInfo());
  /// Add the item to the standard library module, thus ensuring its visibility from all
  /// modules.
  Item addAlwaysVisible(const Item &item, bool = false);

  /// Get an item from the context, taking the given `time` into account (a zero
  /// `time` is the default). If the item does not exist, nullptr is returned.
  Item find(const std::string &name, int64_t time = 0,
            const char * = nullptr) const;
  /// Get an item that exists in the context. If the item does not exist, assertion is
  /// raised.
  Item forceFind(const std::string &name) const;

  /// Return a canonical name of the current base.
  /// An empty string represents the toplevel base.
  std::string getBaseName() const;
  /// Return the current module.
  std::string getModule() const;
  /// Return the current module path.
  std::string getModulePath() const;
  /// Pretty-print the current context state.
  void dump() override;

  /// Generate a unique identifier (name) for a given string.
  std::string generateCanonicalName(const std::string &name, bool includeBase = false,
                                    bool noSuffix = false) const;
  /// True if we are at the toplevel.
  bool isGlobal() const;
  /// True if we are within a conditional block.
  bool isConditional() const;
  /// Get the current base.
  Base *getBase();
  /// True if the current base is function.
  bool inFunction() const;
  /// True if the current base is class.
  bool inClass() const;
  /// True if an item is defined outside of the current base or a module.
  bool isOuter(const Item &val) const;
  /// Get the enclosing class base (or nullptr if such does not exist).
  Base *getClassBase();

  /// Insert `item` at the front of the items registered under `name` and return it.
  std::shared_ptr addToplevel(const std::string &name,
                              const std::shared_ptr &item) {
    map[name].push_front(item);
    return item;
  }

public:
  /// Get the current realization depth (i.e., the number of nested realizations).
  size_t getRealizationDepth() const;
  /// Get the name of the current realization stack (e.g., `fn1:fn2:...`).
  std::string getRealizationStackName() const;

private:
  /// Pretty-print the current context state.
  void dump(int pad);
  /// Pretty-print the current realization context.
  std::string debugInfo();

protected:
  void removeFromMap(const std::string &name) override;
};

} // namespace codon::ast

================================================
FILE: codon/parser/visitors/typecheck/error.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/match.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

namespace codon::ast {

using namespace types;
using namespace matcher;
using namespace error;

/// Transform asserts.
/// @example
///   `assert foo()` ->
///   `if not foo(): raise __internal__.seq_assert([file], [line], "")`
///   `assert foo(), msg` ->
///   `if not foo(): raise __internal__.seq_assert([file], [line], str(msg))`
/// Use `seq_assert_test` instead of `seq_assert` and do not raise anything during unit
/// testing (i.e., when the enclosing function is marked with `@test`).
void TypecheckVisitor::visit(AssertStmt *stmt) {
  // Message argument: empty by default, otherwise the user message passed through a
  // `str` call.
  Expr *msg = N("");
  if (stmt->getMessage())
    msg = N(N("str"), stmt->getMessage());
  // True when the enclosing function carries the `Attr::Test` attribute.
  auto test = ctx->inFunction() && ctx->getBase()->func->hasAttribute(Attr::Test);
  // `__internal__.seq_assert[_test]([file], [line], msg)`
  auto ex = N(
      N(N("__internal__"), test ? "seq_assert_test" : "seq_assert"),
      N(stmt->getSrcInfo().file), N(stmt->getSrcInfo().line), msg);
  // Negated assertion expression: the handler fires only when the assertion fails.
  auto cond = N("!", stmt->getExpr());
  // NOTE(review): the header comment says nothing is raised in test mode, so the two
  // branches are presumably meant to wrap `ex` in different statement kinds; confirm
  // before merging them.
  if (test) {
    resultStmt = transform(N(cond, N(ex)));
  } else {
    resultStmt = transform(N(cond, N(ex)));
  }
}

/// Typecheck try-except statements. Handle Python exceptions separately.
/// @example
///   ```try: ...
///      except python.Error as e: ...
///      except PyExc as f: ...
///      except ValueError as g: ...
///   ``` -> ```
///      try: ...
///      except ValueError as g: ...                   # ValueError
///      except PyExc as exc:
///        while True:
///          if isinstance(exc.pytype, python.Error):  # python.Error
///            e = exc.pytype; ...; break
///          f = exc; ...; break                       # PyExc
///          raise```
/// The statement is marked done only when the body, every kept handler, and the
/// optional else/finally suites are all done.
void TypecheckVisitor::visit(TryStmt *stmt) {
  // Every nested suite is typechecked with `blockLevel` raised by one.
  ctx->blockLevel++;
  stmt->suite = SuiteStmt::wrap(transform(stmt->getSuite()));
  ctx->blockLevel--;

  // Handlers that remain on the statement after the rewrite.
  std::vector catches;
  // Collects the bodies of all Python-related handlers; it becomes the body of one
  // synthesized handler after the loop (the `while True:` form in the example above).
  auto pyCatchStmt = N(N(true), N());

  auto done = stmt->getSuite()->isDone();
  for (auto &c : *stmt) {
    TypeContext::Item val = nullptr;
    if (!c->getVar().empty()) {
      if (!c->hasAttribute(Attr::ExprDominated) &&
          !c->hasAttribute(Attr::ExprDominatedUsed)) {
        // New handler variable with a fresh unbound type.
        val = ctx->addVar(getUnmangledName(c->getVar()),
                          ctx->generateCanonicalName(c->getVar()), instantiateUnbound(),
                          getTime());
      } else if (c->hasAttribute(Attr::ExprDominatedUsed)) {
        // Existing variable: switch the attribute to `ExprDominated` and prepend an
        // update of `<name><VAR_USED_SUFFIX>` to the handler body.
        val = ctx->forceFind(c->getVar());
        c->eraseAttribute(Attr::ExprDominatedUsed);
        c->setAttribute(Attr::ExprDominated);
        c->suite = N(
            N(N(fmt::format("{}{}", getUnmangledName(c->getVar()),
                            VAR_USED_SUFFIX)),
              N(true), nullptr, AssignStmt::UpdateMode::Update),
            c->getSuite());
      } else {
        val = ctx->forceFind(c->getVar());
      }
      c->var = val->canonicalName;
    }
    c->exc = transform(c->getException());
    if (c->getException() && extractClassType(c->getException())->is("pyobj")) {
      // Transform python.Error exceptions
      // The shared temporary that will hold the caught exception is created lazily and
      // remembered on the statement via `Attr::TryPyVar`.
      if (!stmt->hasAttribute(Attr::TryPyVar))
        stmt->setAttribute(Attr::TryPyVar, getTemporaryVar("pyexc"));
      auto pyVar = stmt->getAttribute(Attr::TryPyVar)->value;
      if (!c->var.empty()) {
        // `<var> = <pyVar>.pytype` in front of the handler body.
        c->suite = N(
            N(N(c->var), N(N(pyVar), "pytype")),
            c->getSuite());
      }
      // Guard the body with `isinstance(<pyVar>.pytype, <exception>)`.
      c->suite = SuiteStmt::wrap(N(
          N(N("isinstance"), N(N(pyVar), "pytype"), c->getException()),
          N(c->getSuite(), N()), nullptr));
      cast(pyCatchStmt->getSuite())->addStmt(c->getSuite());
    } else if (c->getException() &&
               extractClassType(c->getException())
                   ->is(getMangledClass("std.internal.python", "PyError"))) {
      // Transform PyExc exceptions
      if (!stmt->hasAttribute(Attr::TryPyVar))
        stmt->setAttribute(Attr::TryPyVar, getTemporaryVar("pyexc"));
      auto pyVar = stmt->getAttribute(Attr::TryPyVar)->value;
      if (!c->var.empty()) {
        // `<var> = <pyVar>` in front of the handler body.
        c->suite = N(N(N(c->var), N(pyVar)), c->getSuite());
      }
      c->suite = N(c->getSuite(), N());
      cast(pyCatchStmt->getSuite())->addStmt(c->getSuite());
    } else {
      // Handle all other exceptions
      c->exc = transformType(c->getException());
      if (c->getException()) {
        // The caught type must have `BaseException` among its RTTI super types.
        auto t = extractClassType(c->getException());
        bool exceptionOK = false;
        for (auto &p : getRTTISuperTypes(t))
          if (p->is(getMangledClass("std.internal.types.error", "BaseException"))) {
            exceptionOK = true;
            break;
          }
        if (!exceptionOK)
          E(Error::CATCH_EXCEPTION_TYPE, c->getException(), t->prettyString());
        if (val)
          unify(val->getType(), extractType(c->getException()));
      }
      ctx->blockLevel++;
      c->suite = SuiteStmt::wrap(transform(c->getSuite()));
      ctx->blockLevel--;
      done &= (!c->getException() || c->getException()->isDone()) &&
              c->getSuite()->isDone();
      catches.push_back(c);
    }
  }
  if (!cast(pyCatchStmt->getSuite())->empty()) {
    // Process PyError catches
    auto pyVar = stmt->getAttribute(Attr::TryPyVar)->value;
    auto exc = N(getMangledClass("std.internal.python", "PyError"));
    // Trailing statement of the synthesized body (the bare `raise` of the example).
    cast(pyCatchStmt->getSuite())->addStmt(N(nullptr));
    auto c = N(pyVar, transformType(exc), pyCatchStmt);
    // NOTE(review): `val` is not read afterwards; the call is kept for registering
    // `pyVar` in the context.
    auto val = ctx->addVar(
        pyVar, pyVar,
        extractType(c->getException())->shared_from_this(), getTime());
    ctx->blockLevel++;
    c->suite = SuiteStmt::wrap(transform(c->getSuite()));
    ctx->blockLevel--;
    done &= (!c->exc || c->exc->isDone()) && c->getSuite()->isDone();
    catches.push_back(c);
  }
  stmt->items = catches;
  if (stmt->getElse()) {
    ctx->blockLevel++;
    stmt->elseSuite = SuiteStmt::wrap(transform(stmt->getElse()));
    ctx->blockLevel--;
    done &= stmt->getElse()->isDone();
  }
  if (stmt->getFinally()) {
    ctx->blockLevel++;
    stmt->finally = SuiteStmt::wrap(transform(stmt->getFinally()));
    ctx->blockLevel--;
    done &= stmt->getFinally()->isDone();
  }

  if (done)
    stmt->setDone();
}

/// Transform `raise` statements.
/// @example
///   `raise exc` ->
///   ```raise BaseException._set_header(exc, "fn", "file", line, col, cause)```
/// `cause` is built from the `from` clause when one is present and from `NoneType`
/// otherwise. A bare `raise` is marked done without any rewriting, and an expression
/// that already matches a `_set_header` call is not wrapped again.
void TypecheckVisitor::visit(ThrowStmt *stmt) {
  if (!stmt->expr) {
    stmt->setDone();
    return;
  }

  stmt->expr = transform(stmt->getExpr());
  if (!match(stmt->getExpr(),
             M(M(getMangledMethod("std.internal.types.error", "BaseException",
                                  "_set_header")),
               M_))) {
    stmt->expr = transform(N(
        N(getMangledMethod("std.internal.types.error", "BaseException",
                           "_set_header")),
        stmt->getExpr(), N(ctx->getBase()->name),
        N(stmt->getSrcInfo().file), N(stmt->getSrcInfo().line),
        N(stmt->getSrcInfo().col),
        stmt->getFrom()
            ? N(N(N("Super"), "_super"), stmt->getFrom(),
                N(getMangledClass("std.internal.types.error", "BaseException")))
            : N(N("NoneType"))));
  }
  if (stmt->getExpr()->isDone())
    stmt->setDone();
}

/// Transform with statements.
/// @example
///   `with foo(), bar() as a: ...` ->
///   ```tmp = foo()
///      tmp.__enter__()
///      try:
///        a = bar()
///        a.__enter__()
///        try:
///          ...
///        finally:
///          a.__exit__()
///      finally:
///        tmp.__exit__()```
/// For `async with` the `__aenter__`/`__aexit__` methods are used instead.
void TypecheckVisitor::visit(WithStmt *stmt) {
  seqassert(!stmt->empty(), "stmt->items is empty");

  bool isAsync = stmt->isAsync();
  std::vector content;
  // Walk the items back to front so that the last item ends up innermost.
  for (auto i = stmt->items.size(); i-- > 0;) {
    // Items without an `as` name get a temporary variable.
    std::string var = stmt->vars[i].empty() ? getTemporaryVar("with") : stmt->vars[i];
    auto as = N(N(var), (*stmt)[i], nullptr,
                (*stmt)[i]->hasAttribute(Attr::ExprDominated)
                    ? AssignStmt::UpdateMode::Update
                    : AssignStmt::UpdateMode::Assign);
    Expr *enter = N(N(N(var), isAsync ? "__aenter__" : "__enter__"));
    Expr *exit = N(N(N(var), isAsync ? "__aexit__" : "__exit__"));
    if (isAsync) {
      // Both calls get one more wrapper node (presumably an await -- confirm).
      enter = N(enter);
      exit = N(exit);
    }
    // The previously built (inner) content, or a clone of the original body for the
    // innermost item, becomes the protected body; `exit` goes into the last argument.
    content = std::vector{
        as, N(enter),
        N(!content.empty() ? N(content) : clone(stmt->getSuite()), std::vector{},
          nullptr, N(N(exit)))};
  }
  resultStmt = transform(N(content));
}

} // namespace codon::ast

================================================
FILE: codon/parser/visitors/typecheck/function.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include
#include
#include

#include "codon/cir/attribute.h"
#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/match.h"
#include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/scoping/scoping.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

using namespace codon::error;

namespace codon::ast {

using namespace types;
using namespace matcher;

/// Typecheck lambda expressions by lowering them to a temporary named function.
/// The generated function statement is run through the scoping visitor, transformed
/// and pushed to `prependStmts`; the lambda expression itself is replaced by a
/// reference to the generated name.
void TypecheckVisitor::visit(LambdaExpr *expr) {
  std::vector params;
  std::string name = getTemporaryVar("lambda");
  params.reserve(expr->size());
  for (auto &s : *expr)
    params.emplace_back(s);
  Stmt *f = N(name, nullptr, params, N(N(expr->getExpr())));

  /// TODO: just copy BindingsAttribute from expr instead?
  if (auto err = ScopingVisitor::apply(ctx->cache, N(f)))
    throw exc::ParserException(std::move(err));
  f->setAttribute(Attr::ExprTime, getTime()); // to handle captures properly
  f = transform(f);
  // Carry the lambda's own bindings (if any) over to the transformed function.
  if (auto a = expr->getAttribute(Attr::Bindings))
    f->setAttribute(Attr::Bindings, a->clone());
  prependStmts->push_back(f);
  resultExpr = transform(N(name));
}

/// Unify the function return type with `Generator[?]`.
/// The unbound type will be deduced from return/yield statements. void TypecheckVisitor::visit(YieldExpr *expr) { if (!ctx->inFunction()) E(Error::FN_OUTSIDE_ERROR, expr, "yield"); unify(ctx->getBase()->returnType.get(), instantiateType(getStdLibType("Generator"), {expr->getType()})); if (realize(expr->getType())) expr->setDone(); } /// Typecheck await statements. void TypecheckVisitor::visit(AwaitExpr *expr) { if (!ctx->inFunction()) E(Error::FN_OUTSIDE_ERROR, expr, "await"); auto isAsync = ctx->getBase()->func->isAsync(); if (!isAsync) E(Error::FN_OUTSIDE_ERROR, expr, "await"); expr->expr = transform(expr->getExpr()); if (!expr->transformed) { if (auto c = expr->getExpr()->getType()->getClass()) { bool isCoroutine = c->is(getMangledClass("std.internal.core", "Coroutine")) || c->is(getMangledClass("std.asyncio", "Future")) || c->is(getMangledClass("std.asyncio", "Task")); if (!isCoroutine) { if (!findMethod(c, "__await__").empty()) { auto e = transform(N(N(expr->getExpr(), "__await__"))); isCoroutine = e->getType()->is(getMangledClass("std.internal.core", "Coroutine")) || e->getType()->is(getMangledClass("std.asyncio", "Future")) || e->getType()->is(getMangledClass("std.asyncio", "Task")) || e->getType()->is(getMangledClass("std.internal.core", "Generator")); if (!isCoroutine) { E(Error::EXPECTED_TYPE, expr, "awaitable"); } else { expr->expr = e; expr->transformed = true; } } else { E(Error::EXPECTED_TYPE, expr, "awaitable"); } } } } if (expr->getExpr()->getType()->getClass()) unify(expr->getType(), extractClassGeneric(expr->getExpr()->getType())); if (expr->getExpr()->isDone()) expr->setDone(); } /// Typecheck return statements. Empty return is transformed to `return NoneType()`. /// Also partialize functions if they are being returned. /// See @c wrapExpr for more details. 
void TypecheckVisitor::visit(ReturnStmt *stmt) {
  // Returns tagged `Attr::Internal` get a fixed `NoneType.__new__` based expression
  // and skip every check below.
  if (stmt->hasAttribute(Attr::Internal)) {
    stmt->expr = transform(N(
        N(getMangledMethod("std.internal.core", "NoneType", "__new__"))));
    stmt->setDone();
    return;
  }

  if (!ctx->inFunction())
    E(Error::FN_OUTSIDE_ERROR, stmt, "return");

  auto isAsync = ctx->getBase()->func->isAsync();

  if (!stmt->expr && ctx->getBase()->func->hasAttribute(Attr::IsGenerator)) {
    // A bare `return` inside a generator needs no typing.
    stmt->setDone();
  } else {
    if (ctx->getBase()->func->hasAttribute(Attr::IsGenerator))
      E(Error::CUSTOM, stmt, "returning values from generators not yet supported");
    if (!stmt->expr)
      stmt->expr = N(N("NoneType"));
    stmt->expr = transform(stmt->getExpr());

    // Wrap expression to match the return type
    if (!ctx->getBase()->returnType->getUnbound())
      if (!wrapExpr(&stmt->expr, ctx->getBase()->returnType.get())) {
        // Leave the statement as not done when wrapping does not succeed.
        return;
      }

    // Special case: partialize functions if we are returning them
    if (stmt->getExpr()->getType()->getFunc() &&
        !(ctx->getBase()->returnType->getClass() &&
          ctx->getBase()->returnType->is("Function"))) {
      stmt->expr = transform(
          N(N(stmt->getExpr()->getType()->getFunc()->ast->getName()),
            N(EllipsisExpr::PARTIAL)));
    }

    // A static expression returned from a function whose return type has no static
    // kind is retyped with its non-static type.
    if (!ctx->getBase()->returnType->getStaticKind() &&
        stmt->getExpr()->getType()->getStatic())
      stmt->getExpr()->setType(stmt->getExpr()
                                   ->getType()
                                   ->getStatic()
                                   ->getNonStaticType()
                                   ->shared_from_this());

    // Async functions return `Coroutine[T]`; all others return `T` directly.
    if (isAsync) {
      unify(ctx->getBase()->returnType.get(),
            instantiateType(getStdLibType("Coroutine"), {stmt->getExpr()->getType()}));
    } else {
      unify(ctx->getBase()->returnType.get(), stmt->getExpr()->getType());
    }
  }

  // If we are not within conditional block, ignore later statements in this function.
  // Useful with static if statements.
  if (!ctx->blockLevel)
    ctx->returnEarly = true;

  if (!stmt->getExpr() || stmt->getExpr()->isDone())
    stmt->setDone();
}

/// Typecheck yield statements. Empty yields assume `NoneType`.
void TypecheckVisitor::visit(YieldStmt *stmt) { if (!ctx->inFunction()) E(Error::FN_OUTSIDE_ERROR, stmt, "yield"); auto isAsync = ctx->getBase()->func->isAsync(); stmt->expr = transform(stmt->getExpr() ? stmt->getExpr() : N(N("NoneType"))); unify(ctx->getBase()->returnType.get(), instantiateType(getStdLibType(!isAsync ? "Generator" : "AsyncGenerator"), {stmt->getExpr()->getType()})); if (stmt->getExpr()->isDone()) stmt->setDone(); } /// Transform `yield from` statements. /// @example /// `yield from a` -> `for var in a: yield var` void TypecheckVisitor::visit(YieldFromStmt *stmt) { auto var = getTemporaryVar("yield"); resultStmt = transform( N(N(var), stmt->getExpr(), N(N(var)))); } /// Process `global` statements. Remove them upon completion. void TypecheckVisitor::visit(GlobalStmt *stmt) { resultStmt = N(); } /// Parse a function stub and create a corresponding generic function type. /// Also realize built-ins and extern C functions. void TypecheckVisitor::visit(FunctionStmt *stmt) { if (stmt->hasAttribute(Attr::Python)) { // Handle Python block resultStmt = transformPythonDefinition(stmt->getName(), stmt->items, stmt->getReturn(), stmt->getSuite()->firstInBlock()); return; } auto origStmt = clean_clone(stmt); // Parse attributes std::vector attributes; bool hasDecorators = false; for (auto i = stmt->decorators.size(); i-- > 0;) { if (!stmt->decorators[i]) continue; auto [isAttr, attrName, attrRealizedName] = getDecorator(stmt->decorators[i]); if (!attrName.empty()) { if (attrName == getMangledFunc("std.internal.attributes", "test")) stmt->setAttribute(Attr::Test); else if (attrName == getMangledFunc("std.internal.attributes", "export")) stmt->setAttribute(Attr::Export); else if (attrName == getMangledFunc("std.internal.attributes", "inline")) stmt->setAttribute(Attr::Inline); else if (attrName == getMangledFunc("std.internal.attributes", "no_arg_reorder")) stmt->setAttribute(Attr::NoArgReorder); else if (attrName == getMangledFunc("std.internal.core", 
"overload")) stmt->setAttribute(Attr::Overload); if (!stmt->hasAttribute(Attr::FunctionAttributes)) stmt->setAttribute(Attr::FunctionAttributes, std::make_unique()); std::string key = attrRealizedName; stmt->getAttribute(Attr::FunctionAttributes) ->attributes[attrName] = key; const auto &attrFn = getFunction(attrName); if (attrFn && attrFn->ast) { if (attrFn->ast->hasAttribute(Attr::Export)) stmt->setAttribute(Attr::Export); if (attrFn->ast->hasAttribute(Attr::Inline)) stmt->setAttribute(Attr::Inline); if (attrFn->ast->hasAttribute(Attr::NoArgReorder)) stmt->setAttribute(Attr::NoArgReorder); if (attrFn->ast->hasAttribute(Attr::ForceRealize)) stmt->setAttribute(Attr::ForceRealize); if (attrFn->ast->hasAttribute(Attr::FunctionAttributes)) { for (const auto &[k, v] : attrFn->ast ->getAttribute(Attr::FunctionAttributes) ->attributes) stmt->getAttribute(Attr::FunctionAttributes) ->attributes[k] = v; } } if (isAttr) stmt->decorators[i] = nullptr; // remove it from further consideration } if (!isAttr) { hasDecorators = true; } } bool isClassMember = ctx->inClass(); if (stmt->hasAttribute(Attr::ForceRealize) && (!ctx->isGlobal() || isClassMember)) E(Error::EXPECTED_TOPLEVEL, getSrcInfo(), "builtin function"); // All overloads share the same canonical name except for the number at the // end (e.g., `foo.1:0`, `foo.1:1` etc.) 
std::string rootName; if (isClassMember) { // Case 1: method overload if (auto n = in(getClass(ctx->getBase()->name)->methods, stmt->getName())) rootName = *n; // TODO: handle static inherits and auto-generated cases // if (!rootName.empty() && stmt->hasAttribute(Attr::Overload)) { // compilationWarning( // fmt::format("function '{}' should be marked with @overload", // stmt->getName()), getSrcInfo().file, getSrcInfo().line); // } if (rootName.empty() && stmt->hasAttribute(Attr::Overload)) { compilationWarning(fmt::format("function '{}' marked with unnecessary @overload", stmt->getName()), getSrcInfo().file, getSrcInfo().line); } } else if (stmt->hasAttribute(Attr::Overload)) { // Case 2: function overload if (auto c = ctx->find(stmt->getName(), getTime())) { if (c->isFunc() && c->getModule() == ctx->getModule() && c->getBaseName() == ctx->getBaseName()) { rootName = c->canonicalName; } } } if (rootName.empty()) rootName = ctx->generateCanonicalName(stmt->getName(), true, isClassMember); // Append overload number to the name auto canonicalName = rootName; if (!in(ctx->cache->overloads, rootName)) ctx->cache->overloads.insert({rootName, {}}); canonicalName += fmt::format(":{}", getOverloads(rootName).size()); ctx->cache->reverseIdentifierLookup[canonicalName] = stmt->getName(); if (isClassMember) { // Set the enclosing class name stmt->setAttribute(Attr::ParentClass, ctx->getBase()->name); // Add the method to the class' method list getClass(ctx->getBase()->name)->methods[stmt->getName()] = rootName; } // Handle captures. Add additional argument to the function for every capture. 
// Make sure to account for **kwargs if present if (auto b = stmt->getAttribute(Attr::Bindings)) { size_t insertSize = stmt->size(); if (!stmt->empty() && startswith(stmt->back().name, "**")) insertSize--; for (auto &[c, t] : b->captures) { std::string cc = "$" + c; if (auto v = ctx->find(c, getTime())) { if (t != BindingsAttribute::CaptureType::Global && !v->isGlobal()) { bool parentClassGeneric = ctx->bases.back().isType() && ctx->bases.back().name == v->getBaseName(); if (v->isGeneric() && parentClassGeneric) { stmt->setAttribute(Attr::Method); } if (!v->isGeneric() || (v->getStaticKind() && !parentClassGeneric)) { if (!v->isFunc()) { if (v->isType()) { stmt->items.insert(stmt->items.begin() + insertSize++, Param(cc, N(TYPE_TYPE))); } else if (auto si = v->getStaticKind()) { stmt->items.insert( stmt->items.begin() + insertSize++, Param(cc, N(N("Literal"), N(Type::stringFromLiteral(si))))); } else { stmt->items.insert(stmt->items.begin() + insertSize++, Param(cc)); } } else { // Local function is captured. Just note its canonical name and add it to // the context during realization. 
b->localRenames[c] = v->getName(); } } continue; } } if ((c == stmt->getName() && hasDecorators) /* decorated recursive fns */ || in(ctx->globalShadows, c)) { // log("-> {} / {}", stmt->getName(), c); stmt->items.insert(stmt->items.begin() + insertSize++, Param(cc)); } } } std::vector args; Stmt *suite = nullptr; Expr *ret = nullptr; std::vector explicits; std::shared_ptr baseType = nullptr; bool isGlobal = ctx->isGlobal(); { // Set up the base TypeContext::BaseGuard br(ctx.get(), canonicalName); ctx->getBase()->func = stmt; // Parse arguments and add them to the context for (auto &a : *stmt) { auto [stars, varName] = a.getNameWithStars(); auto name = ctx->generateCanonicalName(varName); // Mark as method if the first argument is self if (isClassMember && stmt->hasAttribute(Attr::HasSelf) && a.getName() == "self") stmt->setAttribute(Attr::Method); // Handle default values auto defaultValue = a.getDefault(); if (match(defaultValue, MOr(M(), M(), M(), M(), M(), M()))) { // Special case: all simple types and Nones are handled at call site // (as they are not mutable). if (match(defaultValue, M())) { if (match(a.getType(), M(MOr(TYPE_TYPE, TRAIT_TYPE)))) { // Special case: `arg: type = None` -> `arg: type = NoneType` defaultValue = N("NoneType"); } else { ; // Do nothing. NoneExpr will be handled later (we don't want it // to be converted to Optional call yet.) } } else { defaultValue = transform(defaultValue); } } else if (defaultValue) { if (!a.isValue()) { // Special case: generic defaults are evaluated as-is! defaultValue = transform(defaultValue); } else { auto defName = fmt::format(".default.{}.{}", canonicalName, a.getName()); auto nctx = std::make_shared(ctx->cache); *nctx = *ctx; nctx->bases.pop_back(); if (isClassMember) // class variable; go to the global context! nctx->bases.erase(nctx->bases.begin() + 1, nctx->bases.end()); auto tv = TypecheckVisitor(nctx); auto as = N(N(defName), defaultValue, a.isValue() ? 
nullptr : a.getType()); if (isClassMember) { preamble->addStmt( tv.transform(N(N(defName), nullptr, nullptr))); registerGlobal(defName); as->setUpdate(); } else if (isGlobal) { registerGlobal(defName); } auto das = tv.transform(as); prependStmts->push_back(das); // Default unbounds must be allowed to pass through // to support cases such as `a = []` auto f = ctx->forceFind(defName); for (auto &u : f->getType()->getUnbounds(false)) { // log("pass-through: {} / {}", stmt->getName(), u->debugString(2)); u->getUnbound()->passThrough = true; stmt->setAttribute(Attr::AllowPassThrough); } defaultValue = tv.transform(N(defName)); } } args.emplace_back(std::string(stars, '*') + name, a.getType(), defaultValue, a.status); // Add generics to the context if (!a.isValue()) { // Generic and static types auto generic = instantiateUnbound(); auto typId = generic->getLink()->id; generic->genericName = varName; auto defType = transform(clone(a.getDefault())); if (auto st = getStaticGeneric(a.getType())) { auto val = ctx->addVar(varName, name, generic); val->generic = true; generic->staticKind = st; if (defType) generic->defaultType = extractType(defType)->shared_from_this(); } else { if (match(a.getType(), M(M(TRAIT_TYPE), M_))) { // Parse TraitVar auto l = transformType(cast(a.getType())->front(), true) ->getType(); if (l->getLink() && l->getLink()->trait) generic->getLink()->trait = l->getLink()->trait; else generic->getLink()->trait = std::make_shared(l->shared_from_this()); } auto val = ctx->addType(varName, name, generic); val->generic = true; if (defType) generic->defaultType = extractType(defType)->shared_from_this(); } auto g = generic->generalize(ctx->typecheckLevel); if (startswith(varName, "$")) varName = varName.substr(1); explicits.emplace_back(name, g, typId, g->getStaticKind()); } } // Prepare list of all generic types ClassType *parentClass = nullptr; if (isClassMember && stmt->hasAttribute(Attr::Method)) { // Get class generics (e.g., T for `class Cls[T]: def foo:`) 
auto aa = stmt->getAttribute(Attr::ParentClass); parentClass = extractClassType(aa->value); } // Add function generics std::vector generics; generics.reserve(explicits.size()); for (const auto &i : explicits) generics.emplace_back(extractType(i.name)->shared_from_this()); // Handle function arguments // Base type: `Function[[args,...], ret]` baseType = getFuncTypeBase(stmt->size() - explicits.size()); ctx->typecheckLevel++; // Parse arguments to the context. Needs to be done after adding generics // to support cases like `foo(a: T, T: type)` for (auto &a : args) { a.type = transformType(a.getType(), true); } // Unify base type generics with argument types. Add non-generic arguments to the // context. Delayed to prevent cases like `def foo(a, b=a)` auto argType = extractClassGeneric(baseType.get())->getClass(); for (int ai = 0, aj = 0; ai < stmt->size(); ai++) { if (!(*stmt)[ai].isValue()) continue; auto [_, canName] = (*stmt)[ai].getNameWithStars(); if (!(*stmt)[ai].getType()) { if (parentClass && ai == 0 && (*stmt)[ai].getName() == "self") { // Special case: self in methods auto *st = unify(extractClassGeneric(argType, aj), parentClass); if (getClass(parentClass->name)->ast->hasAttribute(Attr::ClassDeduce) && stmt->hasAttribute(Attr::ClassDeduce) && stmt->getName() == "__init__") { for (auto &u : st->getUnbounds(true)) { stmt->setAttribute(Attr::AllowPassThrough); // log("pass-through: {}.__init__ / {}", parentClass->name, // u->debugString(2)); u->getLink()->passThrough = true; } } } else { generics.push_back(extractClassGeneric(argType, aj)->shared_from_this()); } } else if (startswith((*stmt)[ai].getName(), "*")) { // Special case: `*args: type` and `**kwargs: type`. 
Do not add this type to the // signature (as the real type is `Tuple[type, ...]`); it will be used during // call typechecking generics.push_back(extractClassGeneric(argType, aj)->shared_from_this()); } else { unify(extractClassGeneric(argType, aj), extractType(transformType((*stmt)[ai].getType(), true))); } aj++; } // Parse the return type ret = transformType(stmt->getReturn(), true); auto retType = extractClassGeneric(baseType.get(), 1); if (ret) { // Fix for functions returning Literal types if (auto st = getStaticGeneric(ret)) baseType->generics[1].staticKind = st; unify(retType, extractType(ret)); if (isId(ret, "Union")) extractClassGeneric(retType)->getUnbound()->kind = LinkType::Generic; } else { generics.push_back(unify(retType, instantiateUnbound())->shared_from_this()); } ctx->typecheckLevel--; // Generalize generics and remove them from the context for (const auto &g : generics) { for (auto &u : g->getUnbounds(false)) if (u->getUnbound()) { u->getUnbound()->kind = LinkType::Generic; } } // Parse function body if (!stmt->hasAttribute(Attr::Internal) && !stmt->hasAttribute(Attr::C)) { if (stmt->hasAttribute(Attr::LLVM)) { suite = transformLLVMDefinition(stmt->getSuite()->firstInBlock()); } else if (stmt->hasAttribute(Attr::C)) { // Do nothing } else { suite = clone(stmt->getSuite()); } } } stmt->setAttribute(Attr::Module, ctx->moduleName.path); // Make function AST and cache it for later realization auto f = N(canonicalName, ret, args, suite, std::vector{}, stmt->isAsync()); f->cloneAttributesFrom(stmt); auto &fn = ctx->cache->functions[canonicalName] = Cache::Function{ctx->getModulePath(), rootName, f, nullptr, origStmt, ctx->getModule().empty() && ctx->isGlobal()}; f->setDone(); auto aa = stmt->getAttribute(Attr::ParentClass); auto parentClass = aa ? 
extractClassType(aa->value) : nullptr; // Construct the type auto funcTyp = std::make_shared(baseType.get(), fn.ast, explicits); funcTyp->setSrcInfo(getSrcInfo()); if (isClassMember && stmt->hasAttribute(Attr::Method)) { funcTyp->funcParent = parentClass->shared_from_this(); } funcTyp = std::static_pointer_cast( funcTyp->generalize(ctx->typecheckLevel)); fn.type = funcTyp; auto &overloads = ctx->cache->overloads[rootName]; if (rootName == "Tuple.__new__") { overloads.insert(std::ranges::upper_bound( overloads, canonicalName, [&](const auto &a, const auto &b) { return getFunction(a)->getType()->funcGenerics.size() < getFunction(b)->getType()->funcGenerics.size(); }), canonicalName); } else { overloads.push_back(canonicalName); } auto val = ctx->addFunc(stmt->name, rootName, funcTyp); // val->time = getTime(); ctx->addFunc(canonicalName, canonicalName, funcTyp); if (stmt->hasAttribute(Attr::Overload) || isClassMember) { ctx->remove(stmt->name); // first overload will handle it! } // Special method handling if (isClassMember) { auto m = getClassMethod(parentClass, getUnmangledName(canonicalName)); bool found = false; for (auto &i : getOverloads(m)) if (i == canonicalName) { getFunction(i)->type = funcTyp; found = true; break; } seqassert(found, "cannot find matching class method for {}", canonicalName); } else { // Hack so that we can later use same helpers for class overloads getClass(VAR_CLASS_TOPLEVEL)->methods[stmt->getName()] = rootName; } // Ensure that functions with @C, @force_realize, and @export attributes can be // realized if (stmt->hasAttribute(Attr::ForceRealize) || stmt->hasAttribute(Attr::Export) || (stmt->hasAttribute(Attr::C) && !stmt->hasAttribute(Attr::CVarArg))) { if (!funcTyp->canRealize()) E(Error::FN_REALIZE_BUILTIN, stmt); } // Expression to be used if function binding is modified by captures or decorators Expr *finalExpr = nullptr; // Parse remaining decorators for (auto i = stmt->decorators.size(); i-- > 0;) { if (stmt->decorators[i]) { // 
Replace each decorator with `decorator(finalExpr)` in the reverse order finalExpr = N(stmt->decorators[i], finalExpr ? finalExpr : N(canonicalName)); } } if (finalExpr) { auto a = N(N(stmt->getName()), finalExpr); if (isClassMember) { // class method decorator auto nctx = std::make_shared(ctx->cache); *nctx = *ctx; nctx->bases.pop_back(); nctx->bases.erase(nctx->bases.begin() + 1, nctx->bases.end()); // global context auto tv = TypecheckVisitor(nctx); auto defName = ctx->generateCanonicalName(stmt->getName()); preamble->addStmt( tv.transform(N(N(defName), nullptr, nullptr))); registerGlobal(defName); a->setUpdate(); cast(a->getLhs())->value = defName; std::vector args; for (auto arg : *stmt) { if (startswith(arg.name, "**")) args.push_back(N(N(arg.name))); else if (startswith(arg.name, "*")) args.push_back(N(N(arg.name))); else args.push_back(N(arg.name)); } Stmt *newFunc = N( stmt->getName(), clone(stmt->getReturn()), clone(stmt->items), N(N(N(N(defName), args))), std::vector{}, stmt->isAsync()); newFunc = transform(newFunc); resultStmt = N(f, N(transform(a), newFunc)); } else { resultStmt = N(f, transform(a)); } } else { resultStmt = f; } } /// Transform Python code blocks. /// @example /// ```@python /// def foo(x: int, y) -> int: /// [code] /// ``` -> ``` /// pyobj._exec("def foo(x, y): [code]") /// from python import __main__.foo(int, _) -> int /// ``` Stmt *TypecheckVisitor::transformPythonDefinition(const std::string &name, const std::vector &args, Expr *ret, Stmt *codeStmt) { seqassert(codeStmt && cast(codeStmt) && cast(cast(codeStmt)->getExpr()), "invalid Python definition"); auto code = cast(cast(codeStmt)->getExpr())->getValue(); std::vector pyargs; pyargs.reserve(args.size()); for (const auto &a : args) pyargs.emplace_back(a.getName()); code = fmt::format("def {}({}):\n{}\n", name, join(pyargs, ", "), code); return transform(N( N( N(N(N("pyobj"), "_exec"), N(code))), N(N("python"), N(N("__main__"), name), clone(args), ret ? 
clone(ret) : N("pyobj")))); } /// Transform LLVM functions. /// @example /// ```@llvm /// def foo(x: int) -> float: /// [code] /// ``` -> ``` /// def foo(x: int) -> float: /// StringExpr("[code]") /// SuiteStmt(referenced_types) /// ``` /// As LLVM code can reference types and static expressions in `{=expr}` blocks, /// all block expression will be stored in the `referenced_types` suite. /// "[code]" is transformed accordingly: each `{=expr}` block will /// be replaced with `{}` so that @c fmt::format can fill the gaps. /// Note that any brace (`{` or `}`) that is not part of a block is /// escaped (e.g. `{` -> `{{` and `}` -> `}}`) so that @c fmt::format can process them. Stmt *TypecheckVisitor::transformLLVMDefinition(Stmt *codeStmt) { StringExpr *codeExpr; auto m = match(codeStmt, M(MVar(codeExpr))); seqassert(m, "invalid LLVM definition"); auto code = codeExpr->getValue(); /// Remove docstring (if any) size_t start = 0; while (start < code.size() && std::isspace(code[start])) start++; if (startswith(code.substr(start), "\"\"\"")) { start += 3; bool found = false; while (start < code.size() - 2) { if (code[start] == '"' && code[start + 1] == '"' && code[start + 2] == '"') { found = true; start += 3; break; } start++; } if (found) { code = code.substr(start); } } std::vector items; std::string finalCode; items.push_back(nullptr); // Parse LLVM code and look for expression blocks that start with `{=` int braceCount = 0, braceStart = 0; for (int i = 0; i < code.size(); i++) { if (i < code.size() - 1 && code[i] == '\\' && code[i + 1] == '\n') { code[i] = code[i + 1] = ' '; } if (i < code.size() - 1 && code[i] == '{' && code[i + 1] == '=') { if (braceStart <= i) finalCode += escapeFStringBraces(code, braceStart, i - braceStart) + '{'; if (!braceCount) { braceStart = i + 2; braceCount++; } else { E(Error::FN_BAD_LLVM, getSrcInfo()); } } else if (braceCount && code[i] == '}') { braceCount--; std::string exprCode = code.substr(braceStart, i - braceStart); auto offset = 
getSrcInfo(); offset.col += i; auto exprOrErr = parseExpr(ctx->cache, exprCode, offset); if (!exprOrErr) throw exc::ParserException(exprOrErr.takeError()); auto expr = exprOrErr->first; items.push_back(N(expr)); braceStart = i + 1; finalCode += '}'; } } if (braceCount) E(Error::FN_BAD_LLVM, getSrcInfo()); if (braceStart != code.size()) finalCode += escapeFStringBraces(code, braceStart, static_cast(code.size()) - braceStart); items[0] = N(N(finalCode)); return N(items); } /// Fetch a decorator canonical name. The first pair member indicates if a decorator is /// actually an attribute (a function with `@__attribute__`). std::tuple TypecheckVisitor::getDecorator(Expr *e) { auto dt = transform(clone(e)); dt = getHeadExpr(dt); if (auto id = cast(cast(dt) ? cast(dt)->getExpr() : dt)) { auto ci = ctx->find(id->getValue(), getTime()); if (ci && ci->isFunc()) { auto fn = ci->getType()->getFunc()->ast->getName(); auto f = getFunction(fn); if (!f) { if (auto o = in(ctx->cache->overloads, fn)) { if (o->size() == 1) f = getFunction(o->front()); } } // Special case: Id to Call if (f->ast->hasAttribute(Attr::Attribute) && cast(dt)) dt = transform(N(dt)); if (f) return {f->ast->hasAttribute(Attr::Attribute), fn, dt->isDone() ? id->getValue() : ""}; } } return {false, "", ""}; } /// Generate and return `Function[Tuple[args...], ret]` type std::shared_ptr TypecheckVisitor::getFuncTypeBase(size_t nargs) { auto baseType = instantiateType(getStdLibType("Function")); unify(extractClassGeneric(baseType->getClass()), instantiateType(generateTuple(nargs, false))); return std::static_pointer_cast(baseType); } } // namespace codon::ast ================================================ FILE: codon/parser/visitors/typecheck/import.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include
#include
#include
#include

#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/match.h"
#include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/format/format.h"
#include "codon/parser/visitors/scoping/scoping.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

using namespace codon::error;
using namespace codon::matcher;

namespace codon::ast {

/// Import and parse a new module into its own context.
/// Also handle special imports ( see @c transformSpecialImport ).
/// To simulate Python's dynamic import logic and import stuff only once,
/// each import statement is guarded as follows:
///   if not _import_N_done:
///     _import_N()
///     _import_N_done = True
/// See @c transformNewImport and below for more details.
void TypecheckVisitor::visit(ImportStmt *stmt) {
  seqassert(!ctx->inClass(), "imports within a class");
  // Special imports (`from C`, `from python`) are handled separately.
  if ((resultStmt = transformSpecialImport(stmt)))
    return;

  // Fetch the import
  auto components = getImportPath(stmt->getFrom(), stmt->getDots());
  auto path = combine2(components, "/");
  auto file = getImportFile(ctx->cache, path, ctx->getFilename());
  if (!file) {
    // No file found: for non-relative imports with autoPython enabled, fall back to
    // a `from python import ...` statement (with a compilation warning).
    if (stmt->getDots() == 0 && ctx->autoPython) {
      auto newStr = FormatVisitor::apply(stmt->getFrom());
      if (stmt->getWhat())
        newStr += "." + FormatVisitor::apply(stmt->getWhat());
      compilationWarning(fmt::format("importing '{}' from Python", newStr),
                         stmt->getSrcInfo().file, stmt->getSrcInfo().line,
                         stmt->getSrcInfo().col);
      auto exprOrErr = parseExpr(ctx->cache, newStr, stmt->getFrom()->getSrcInfo());
      if (!exprOrErr)
        throw exc::ParserException(exprOrErr.takeError());
      resultStmt = transform(N(N("python"), exprOrErr->first, stmt->getArgs(),
                               stmt->getReturnType(), stmt->getAs()));
      return;
    }
    // Rebuild a dotted module name for the error message: leading dots for the
    // relative level, followed by the path components joined with '.'.
    std::string s(stmt->getDots(), '.');
    for (auto &c : components) {
      if (c == "..") {
        continue;
      } else if (!s.empty() && s.back() != '.') {
        s += "." + c;
      } else {
        s += c;
      }
    }
    bool allDot = true;
    for (auto cp : s)
      if (cp != '.') {
        allDot = false;
        break;
      }
    // If the name consists only of dots, report the imported identifier instead.
    if (allDot && match(stmt->getWhat(), M()))
      s = cast(stmt->getWhat())->getValue();
    E(Error::IMPORT_NO_MODULE, stmt, s);
  }

  // If the file has not been seen before, load it into cache
  bool handled = true;
  if (!in(ctx->cache->imports, file->path)) {
    resultStmt = transformNewImport(*file);
    if (!resultStmt)
      handled = false; // we need an import
  }

  const auto &import = getImport(file->path);
  std::string importVar = import->importVar;
  if (!import->loadedAtToplevel)
    handled = false;

  // Construct `if _import_done.__invert__(): (_import(); _import_done = True)`.
  // Do not do this during the standard library loading (we assume that standard library
  // imports are "clean" and do not need guards). Note that the importVar is empty if
  // the import has been loaded during the standard library loading.
  if (!handled) {
    resultStmt = N(
        N(N(getMangledFunc("", fmt::format("{}_call", importVar)))));
    LOG_TYPECHECK("[import] loading {}", importVar);
  }

  // Import requested identifiers from the import's scope to the current scope
  if (!stmt->getWhat()) {
    // Case: import foo
    auto name = stmt->as.empty() ? path : stmt->getAs();
    auto e = ctx->forceFind(importVar);
    ctx->add(name, e);
  } else if (cast(stmt->getWhat()) && cast(stmt->getWhat())->getValue() == "*") {
    // Case: from foo import *
    seqassert(stmt->getAs().empty(), "renamed star-import");
    // Just copy all symbols from import's context here.
    for (auto &[i, ival] : *(import->ctx)) {
      if ((!startswith(i, "_") || (ctx->isStdlibLoading && startswith(i, "__")))) {
        // Ignore all identifiers that start with `_` but not those that start with
        // `__` while the standard library is being loaded
        auto c = ival.front();
        if (c->isConditional() && i.find('.') == std::string::npos)
          c = import->ctx->find(i);
        // Imports should ignore noShadow property
        ctx->add(i, c);
      }
    }
  } else {
    // Case 3: from foo import bar
    auto i = cast(stmt->getWhat());
    seqassert(i, "not a valid import what expression");
    auto c = import->ctx->find(i->getValue());
    // Make sure that we are importing an existing global symbol
    if (!c)
      E(Error::IMPORT_NO_NAME, i, i->getValue(), file->module);
    if (c->isConditional())
      c = import->ctx->find(i->getValue());
    // Imports should ignore noShadow property
    ctx->add(stmt->getAs().empty() ? i->getValue() : stmt->getAs(), c);
  }
  resultStmt = transform(!resultStmt ? N() : resultStmt); // erase it
}

/// Transform special `from C` and `from python` imports.
/// See @c transformCImport, @c transformCDLLImport and @c transformPythonImport
/// @return the transformed statement, or nullptr if the import is not a special import
Stmt *TypecheckVisitor::transformSpecialImport(const ImportStmt *stmt) {
  if (auto fi = cast(stmt->getFrom())) {
    if (fi->getValue() == "C") {
      auto wi = cast(stmt->getWhat());
      if (wi && !stmt->isCVar()) {
        // C function imports
        return transformCImport(wi->getValue(), stmt->getArgs(), stmt->getReturnType(),
                                stmt->getAs());
      } else if (wi) {
        // C variable imports
        return transformCVarImport(wi->getValue(), stmt->getReturnType(),
                                   stmt->getAs());
      } else if (auto de = cast(stmt->getWhat())) {
        // dylib C imports
        return transformCDLLImport(de->getExpr(), de->getMember(), stmt->getArgs(),
                                   stmt->getReturnType(), stmt->getAs(),
                                   !stmt->isCVar());
      }
    } else if (fi->getValue() == "python" && stmt->getWhat()) {
      // Python imports
      return transformPythonImport(stmt->getWhat(), stmt->getArgs(),
                                   stmt->getReturnType(), stmt->getAs());
    }
  }
  return nullptr;
}

/// Transform Dot(Dot(a, b), c...) into "{a, b, c, ...}".
/// Useful for getting import paths. std::vector TypecheckVisitor::getImportPath(Expr *from, size_t dots) const { std::vector components; // Path components if (from) { for (; cast(from); from = cast(from)->getExpr()) components.push_back(cast(from)->getMember()); seqassert(cast(from), "invalid import statement"); components.push_back(cast(from)->getValue()); } // Handle dots (i.e., `..` in `from ..m import x`) for (size_t i = 1; i < dots; i++) components.emplace_back(".."); std::ranges::reverse(components); return components; } /// Transform a C function import. /// @example /// `from C import foo(int) -> float as f` -> /// ```@.c /// def foo(a1: int) -> float: /// pass /// f = foo # if altName is provided``` /// No return type implies void return type. *args is treated as C VAR_ARGS. Stmt *TypecheckVisitor::transformCImport(const std::string &name, const std::vector &args, Expr *ret, const std::string &altName) { std::vector fnArgs; bool hasVarArgs = false; for (size_t ai = 0; ai < args.size(); ai++) { seqassert(args[ai].getName().empty(), "unexpected argument name"); seqassert(!args[ai].getDefault(), "unexpected default argument"); seqassert(args[ai].getType(), "missing type"); if (cast(args[ai].getType()) && ai + 1 == args.size()) { // C VAR_ARGS support hasVarArgs = true; fnArgs.emplace_back("*args", nullptr, nullptr); } else { fnArgs.emplace_back(args[ai].getName().empty() ? fmt::format("a{}", ai) : args[ai].getName(), clone(args[ai].getType()), nullptr); } } auto _ = ctx->generateCanonicalName(name); // avoid canonicalName == name Stmt *f = N(name, ret ? clone(ret) : N("NoneType"), fnArgs, nullptr); f->setAttribute(Attr::C); if (hasVarArgs) f->setAttribute(Attr::CVarArg); f = transform(f); // Already in the preamble if (!altName.empty()) { auto v = ctx->find(altName); auto val = ctx->forceFind(name); ctx->add(altName, val); ctx->remove(name); } return f; } /// Transform a C variable import. 
/// @example /// `from C import foo: int as f` -> /// ```f: int = "foo"``` Stmt *TypecheckVisitor::transformCVarImport(const std::string &name, Expr *type, const std::string &altName) { auto canonical = ctx->generateCanonicalName(name); auto typ = transformType(clone(type)); auto val = ctx->addVar( altName.empty() ? name : altName, canonical, std::make_shared(extractClassType(typ)->shared_from_this()), getTime()); auto s = N(N(canonical), nullptr, typ); s->lhs->setAttribute(Attr::ExprExternVar); s->lhs->setType(val->type); s->lhs->setDone(); s->setDone(); return s; } /// Transform a dynamic C import. /// @example /// `from C import lib.foo(int) -> float as f` -> /// `f = _dlsym(lib, "foo", Fn=Function[[int], float]); f` /// No return type implies void return type. Stmt *TypecheckVisitor::transformCDLLImport(Expr *dylib, const std::string &name, const std::vector &args, Expr *ret, const std::string &altName, bool isFunction) { Expr *type = nullptr; if (isFunction) { std::vector fnArgs{N(), ret ? clone(ret) : N("NoneType")}; for (const auto &a : args) { seqassert(a.getName().empty(), "unexpected argument name"); seqassert(!a.getDefault(), "unexpected default argument"); seqassert(a.getType(), "missing type"); cast(fnArgs[0])->items.emplace_back(clone(a.getType())); } type = N(N("Function"), N(fnArgs)); } else { type = clone(ret); } Expr *c = clone(dylib); return transform(N( N(altName.empty() ? name : altName), N(N("_dlsym"), std::vector{CallArg(c), CallArg(N(name)), CallArg{"Fn", type}}))); } /// Transform a Python module and function imports. /// @example /// `from python import module as f` -> `f = pyobj._import("module")` /// `from python import lib.foo(int) -> float as f` -> /// ```def f(a0: int) -> float: /// f = pyobj._import("lib")._getattr("foo") /// return float.__from_py__(f(a0))``` /// If a return type is nullptr, the function just returns f (raw pyobj). 
Stmt *TypecheckVisitor::transformPythonImport(Expr *what, const std::vector &args,
                                              Expr *ret, const std::string &altName) {
  // Get a module name (e.g., os.path)
  auto components = getImportPath(what);

  if (!ret && args.empty()) {
    // Simple import: `from python import foo.bar` -> `bar = pyobj._import("foo.bar")`
    return transform(
        N(N(altName.empty() ? components.back() : altName),
          N(N(N("pyobj"), "_import"), N(combine2(components, ".")))));
  }

  // Python function import:
  // `from python import foo.bar(int) -> float` ->
  // ```def bar(a0: int) -> float:
  //      f = pyobj._import("foo")._getattr("bar")
  //      return float.__from_py__(f(a0))```

  // f = pyobj._import("foo")._getattr("bar")
  auto call = N(
      N("f"),
      N(N(N(N(N("pyobj"), "_import"),
              N(combine2(components, ".", 0,
                         static_cast(components.size()) - 1))),
          "_getattr"),
        N(components.back())));
  // f(a0, ...)
  std::vector params;
  std::vector callArgs;
  // Parameters are named positionally (a0, a1, ...) and forwarded to `f` in order.
  for (size_t i = 0; i < args.size(); i++) {
    params.emplace_back(fmt::format("a{}", i), clone(args[i].getType()), nullptr);
    callArgs.emplace_back(N(fmt::format("a{}", i)));
  }
  // `return ret.__from_py__(f(a0, ...))`
  auto retType = (ret && !cast(ret)) ? clone(ret) : N("NoneType");
  auto retExpr = N(N(clone(retType), "__from_py__"),
                   N(N(N("f"), callArgs), "p"));
  auto retStmt = N(retExpr);
  // Create a function
  return transform(N(altName.empty() ? components.back() : altName, retType, params,
                     N(call, retStmt)));
}

/// Import a new file into its own context and wrap its top-level statements into a
/// function to support Python-like runtime import loading.
/// @example
///   ```_import_[I]_done = False
///      def _import_[I]():
///        global [imported global variables]...
///        __name__ = [I]
///        [imported top-level statements]```
/// @return the transformed import suite while the standard library is being loaded;
///         nullptr otherwise (the wrapper function is added to the preamble instead)
Stmt *TypecheckVisitor::transformNewImport(const ImportFile &file) {
  // Use a clean context to parse a new file
  auto moduleID = file.module;
  std::ranges::replace(moduleID, '.', '_');
  auto ictx = std::make_shared(ctx->cache, file.path);
  ictx->isStdlibLoading = ctx->isStdlibLoading;
  ictx->moduleName = file;
  auto &import = ctx->cache->imports[file.path];
  import.update(file.module, file.path, ictx);
  // An import counts as loaded at toplevel only if the importing module was itself
  // loaded at toplevel and this import statement is at the global, outermost scope
  // (or the standard library is being loaded).
  import.loadedAtToplevel =
      getImport(ctx->moduleName.path)->loadedAtToplevel &&
      (ctx->isStdlibLoading || (ctx->isGlobal() && ctx->scope.size() == 1));
  auto importVar = import.importVar =
      getTemporaryVar(fmt::format("import_{}", moduleID));
  LOG_REALIZE("[import] initializing {} (location: {}, toplevel: {})", importVar,
              file.path, import.loadedAtToplevel);

  // __name__ = [import name]
  Stmt *n = nullptr;
  if (file.module != "internal.core") {
    // str is not defined when loading internal.core; __name__ is not needed anyway
    n = N(
        N(N("__name__"), N(ictx->moduleName.module)),
        N(N("__file__"), N(ictx->moduleName.path)));
    // Declare the import variable in the preamble inside a temporary block, then
    // re-register it as a toplevel global of the standard library import context.
    ctx->addBlock();
    preamble->addStmt(transform(
        N(N(importVar),
          N(N("Import.__new__"), N(false), N(file.path), N(file.module)),
          N("Import"))));
    auto val = ctx->forceFind(importVar);
    ctx->popBlock();
    val->scope = {0};
    val->baseName = "";
    val->moduleName = MODULE_MAIN;
    val->time = 0;
    getImport(STDLIB_IMPORT)->ctx->addToplevel(importVar, val);
    registerGlobal(val->getName());
  }
  auto nodeOrErr = parseFile(ctx->cache, file.path);
  if (!nodeOrErr)
    throw exc::ParserException(nodeOrErr.takeError());
  n = N(n, *nodeOrErr);
  // The imported file is typechecked with its own context but the shared preamble.
  auto tv = TypecheckVisitor(ictx, preamble);
  if (auto err = ScopingVisitor::apply(ctx->cache, n, &ictx->globalShadows))
    throw exc::ParserException(std::move(err));

  if (!ctx->cache->errors.empty())
    throw exc::ParserException(ctx->cache->errors);
  // Add comment to the top of import for easier dump inspection
  auto comment = N(fmt::format("import: {} at {}", file.module, file.path));
  auto suite = N(comment, n);

  if (ctx->isStdlibLoading) {
    // When loading the standard library, imports are not wrapped.
    // We assume that the standard library has no recursive imports and that all
    // statements are executed before the user-provided code.
    return tv.transform(suite);
  } else {
    // Generate import identifier
    auto stmts = N();
    auto ret = N();
    ret->setAttribute(Attr::Internal); // do not trigger toplevel ReturnStmt error
    stmts->addStmt(N(N(N(importVar), "loaded"), ret));
    stmts->addStmt(N(
        N(N("Import._set_loaded"), N(N("__ptr__"), N(importVar)))));
    stmts->addStmt(suite);

    // Wrap all imported top-level statements into a function.
    auto fnName = fmt::format("{}_call", importVar);
    Stmt *fn = N(fnName, N("NoneType"), std::vector{}, stmts);
    fn = tv.transform(fn);
    tv.realize(ictx->forceFind(fnName)->getType());
    preamble->addStmt(fn);
    // LOG_USER("[import] done importing {}", file.module);
  }
  return nullptr;
}

} // namespace codon::ast

================================================
FILE: codon/parser/visitors/typecheck/infer.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include
#include
#include
#include
#include
#include

#include "codon/cir/attribute.h"
#include "codon/cir/types/types.h"
#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/visitors/scoping/scoping.h"
#include "codon/parser/visitors/translate/translate.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

using namespace codon::error;

// Upper bound on the number of inference iterations attempted by inferTypes before a
// "cannot typecheck in reasonable time" error is raised.
constexpr int MAX_TYPECHECK_ITER = 1000;

namespace codon::ast {

using namespace types;

/// Unify types a (passed by reference) and b.
/// Destructive operation as it modifies both a and b. If types cannot be unified, raise
/// an error.
/// @param a Type (by reference) /// @param b Type /// @return a Type *TypecheckVisitor::unify(Type *a, Type *b) const { seqassert(a, "lhs is nullptr"); if (!((*a) << b)) { types::Type::Unification undo; a->unify(b, &undo); // log("[unify] {} {}", a->debugString(2), b->debugString(2)); // log("[unify] {} {}", a->debugString(1), b->debugString(1)); E(Error::TYPE_UNIFY, getSrcInfo(), a->prettyString(), b->prettyString()); return nullptr; } return a; } /// Infer all types within a Stmt *. Implements the LTS-DI typechecking. /// @param isToplevel set if typechecking the program toplevel. Stmt *TypecheckVisitor::inferTypes(Stmt *result, bool isToplevel) { if (!result) return nullptr; for (ctx->getBase()->iteration = 1;; ctx->getBase()->iteration++) { LOG_TYPECHECK("[iter] {} :: {}", ctx->getBase()->name, ctx->getBase()->iteration); if (ctx->getBase()->iteration >= MAX_TYPECHECK_ITER) { // log("-> {}", result->toString(2)); ParserErrors errors; errors.addError( {ErrorMessage{fmt::format("cannot typecheck '{}' in reasonable time", ctx->getBase()->name.empty() ? 
"toplevel" : getUnmangledName(ctx->getBase()->name)), result->getSrcInfo()}}); for (auto &error : findTypecheckErrors(result)) errors.addError(error); throw exc::ParserException(errors); } // Keep iterating until: // (1) success: the statement is marked as done; or // (2) failure: no expression or statements were marked as done during an // iteration (i.e., changedNodes is zero) ctx->typecheckLevel++; auto changedNodes = ctx->changedNodes; ctx->changedNodes = 0; auto returnEarly = ctx->returnEarly; ctx->returnEarly = false; auto tv = TypecheckVisitor(ctx, preamble); // if (preamble) { // auto pt = tv.transform(preamble); // preamble = cast(pt); // } result = tv.transform(result); std::swap(ctx->changedNodes, changedNodes); std::swap(ctx->returnEarly, returnEarly); ctx->typecheckLevel--; if (ctx->getBase()->iteration == 1 && isToplevel) { // Realize all @force_realize functions for (auto &f : ctx->cache->functions) { auto ast = f.second.ast; if (f.second.type && f.second.realizations.empty() && (ast->hasAttribute(Attr::ForceRealize) || ast->hasAttribute(Attr::Export) || (ast->hasAttribute(Attr::C) && !ast->hasAttribute(Attr::CVarArg)))) { seqassert(f.second.type->canRealize(), "cannot realize {}", f.first); LOG_REALIZE("[force_realize] {}", f.second.getType()->debugString(2)); realize(instantiateType(f.second.getType())); seqassert(!f.second.realizations.empty(), "cannot realize {}", f.first); } } } if (result->isDone()) { // Special union case: if union cannot be inferred return type is Union[NoneType] if (auto tr = ctx->getBase()->returnType) { if (auto tu = tr->getUnion()) { if (!tu->isSealed()) { if (tu->pendingTypes[0]->getLink() && tu->pendingTypes[0]->getLink()->kind == LinkType::Unbound) { auto r = tu->addType(getStdLibType("NoneType")); seqassert(r, "cannot add type to union {}", tu->debugString(2)); tu->seal(); } } } } break; } else if (changedNodes) { continue; } else { // Special case: nothing was changed, however there are unbound types that have // 
default values (e.g., generics with default values). Unify those types with // their default values and then run another round to see if anything changed. bool anotherRound = false; // Special case: return type might have default as well (e.g., Union) if (auto t = ctx->getBase()->returnType) { ctx->getBase()->pendingDefaults[0].insert(t); } // First unify "explicit" generics (whose default type is explicit), // then "implicit" ones (whose default type is compiler generated, // e.g. compiler-generated variable placeholders with default NoneType) for (auto &unbounds : ctx->getBase()->pendingDefaults | std::views::values) { if (!unbounds.empty()) { for (const auto &unbound : unbounds) { if (auto tu = unbound->getUnion()) { // Seal all dynamic unions after the iteration is over if (!tu->isSealed()) { tu->seal(); anotherRound = true; } } else if (auto u = unbound->getLink()) { types::Type::Unification undo; if (u->defaultType) { if (u->defaultType->getClass()) { // type[...] if (u->unify(extractClassType(u->defaultType.get()), &undo) >= 0) { anotherRound = true; } } else { // generic if (u->unify(u->defaultType.get(), &undo) >= 0) { anotherRound = true; } } } } } unbounds.clear(); if (anotherRound) break; } } if (anotherRound) continue; // Nothing helps. Return nullptr. return nullptr; } } return result; } /// Realize a type and create IR type stub. If type is a function type, also realize the /// underlying function and generate IR function stub. /// @return realized type or nullptr if the type cannot be realized types::Type *TypecheckVisitor::realize(types::Type *typ) { if (!typ || !typ->canRealize()) { return nullptr; } try { if (auto f = typ->getFunc()) { // Cache::CTimer t(ctx->cache, f->realizedName()); if (auto ret = realizeFunc(f)) { // Realize Function[..] 
type as well auto t = std::make_shared(ret->getClass()); realizeType(t.get()); // Needed for return type unification unify(f->getRetType(), extractClassGeneric(ret, 1)); return ret; } } else if (auto c = typ->getClass()) { auto t = realizeType(c); return t; } } catch (exc::ParserException &exc) { seqassert(!exc.getErrors().empty(), "empty error trace"); auto &bt = exc.getErrors().back(); if (bt.front().getErrorCode() == Error::MAX_REALIZATION) throw; if (auto f = typ->getFunc()) { if (f->ast->hasAttribute(Attr::HiddenFromUser)) { bt.back().setSrcInfo(getSrcInfo()); } else { std::vector args; for (size_t i = 0, ai = 0, gi = 0; i < f->ast->size(); i++) { auto [ns, n] = (*f->ast)[i].getNameWithStars(); args.push_back(fmt::format( "{}{}: {}", std::string(ns, '*'), getUserFacingName(n), (*f->ast)[i].isGeneric() ? extractFuncGeneric(f, gi++)->prettyString() : extractFuncArgType(f, ai++)->prettyString())); } auto name = f->ast->name; std::string name_args; if (startswith(name, "%_import_")) { for (auto &i : ctx->cache->imports | std::views::values) if (getMangledFunc("", i.importVar + "_call") == name) { name = i.name; break; } name = fmt::format("", name); } else { name = getUserFacingName(f->ast->getName()); name_args = fmt::format("({})", join(args, ", ")); } bt.addMessage(fmt::format("during the realization of {}{}", name, name_args), getSrcInfo()); } } else { bt.addMessage(fmt::format("during the realization of {}", typ->prettyString()), getSrcInfo()); } throw; } return nullptr; } /// Realize a type and create IR type stub. /// @return realized type or nullptr if the type cannot be realized types::Type *TypecheckVisitor::realizeType(types::ClassType *type) { if (!type || !type->canRealize()) return nullptr; // Check if the type fields are all initialized // (sometimes that's not the case: e.g., `class X: x: List[X]`) // generalize generics to ensure that they do not get unified later! 
if (type->is("unrealized_type")) type->generics[0].type = extractClassGeneric(type)->generalize(0); if (type->is("__NTuple__")) { auto n = std::max(static_cast(0), getIntLiteral(type)); auto tt = extractClassGeneric(type, 1)->getClass(); std::vector generics; auto t = instantiateType(generateTuple(n * tt->generics.size())); for (size_t i = 0, j = 0; i < n; i++) for (const auto &ttg : tt->generics) { unify(t->generics[j].getType(), ttg.getType()); generics.push_back(t->generics[j]); j++; } type->name = TYPE_TUPLE; type->generics = generics; type->_rn = ""; } // Check if the type was already realized auto rn = type->ClassType::realizedName(); auto cls = getClass(type); if (auto r = in(cls->realizations, rn)) { return (*r)->type->getClass(); } auto realized = type->getClass(); auto fields = getClassFields(realized); if (!cls->ast) return nullptr; // not yet done! auto fTypes = getClassFieldTypes(realized); for (auto &field : fTypes) { if (!field) return nullptr; } if (auto s = type->getStatic()) realized = s->getNonStaticType()->getClass(); // do not cache static but its root type! 
// Realize generics if (!type->is("unrealized_type")) for (auto &e : realized->generics) { if (!realize(e.getType())) return nullptr; if (e.type->getFunc() && !e.type->getFunc()->getRetType()->canRealize()) return nullptr; } // Realizations should always be visible, so add them to the toplevel rn = type->ClassType::realizedName(); auto rt = std::static_pointer_cast(realized->generalize(0)); auto val = std::make_shared(rn, "", ctx->getModule(), rt); if (!val->type->is(TYPE_TYPE)) val->type = instantiateTypeVar(realized); ctx->addAlwaysVisible(val, true); auto realization = getClass(realized)->realizations[rn] = std::make_shared(); realization->type = rt; realization->id = ++ctx->cache->classRealizationCnt; const auto &mros = getClass(realized)->mro; for (size_t i = 1; i < mros.size(); i++) { auto mt = instantiateType(mros[i].get(), realized); seqassert(mt->canRealize(), "cannot realize {}", mt->debugString(2)); realization->bases.push_back(mt); } // Create LLVM stub auto lt = makeIRType(realized); // Realize fields std::vector typeArgs; // needed for IR std::vector names; // needed for IR std::map memberInfo; // needed for IR for (size_t i = 0; i < fTypes.size(); i++) { if (!realize(fTypes[i].get())) { // realize(fTypes[i].get()); E(Error::TYPE_CANNOT_REALIZE_ATTR, getSrcInfo(), fields[i].name, realized->prettyString()); } // LOG_REALIZE("- member: {} -> {}: {}", field.name, field.type, fTypes[i]); realization->fields.emplace_back(fields[i].name, fTypes[i]); names.emplace_back(fields[i].name); typeArgs.emplace_back(makeIRType(fTypes[i]->getClass())); memberInfo[fields[i].name] = fTypes[i]->getSrcInfo(); } // Set IR attributes if (!names.empty()) { if (auto *ir = cast(lt)) { ir->getContents()->realize(typeArgs, names); ir->setAttribute(std::make_unique(memberInfo)); ir->getContents()->setAttribute( std::make_unique(memberInfo)); } } return rt.get(); } types::Type *TypecheckVisitor::realizeFunc(types::FuncType *type, bool force) { auto module = 
type->ast->getAttribute(Attr::Module)->value; auto &realizations = getFunction(type->getFuncName())->realizations; auto imp = getImport(module); if (auto r = in(realizations, type->realizedName())) { if (!force) { return (*r)->getType(); } } // auto *_t = new Cache::CTimer(ctx->cache, ctx->getRealizationStackName() + ":" + // type->realizedName()); auto oldCtx = this->ctx; this->ctx = imp->ctx; if (ctx->getRealizationDepth() > MAX_REALIZATION_DEPTH) { E(Error::MAX_REALIZATION, getSrcInfo(), getUserFacingName(type->getFuncName())); } bool isImport = isImportFn(type->getFuncName()); if (!isImport) { getLogger().level++; ctx->addBlock(); ctx->typecheckLevel++; ctx->bases.push_back({type->getFuncName(), type->getFunc()->shared_from_this(), type->getRetType()->shared_from_this()}); for (size_t t = ctx->bases.size() - 1; t-- > 0;) { if (startswith(ctx->getBaseName(), ctx->bases[t].name)) { ctx->getBase()->parent = static_cast(t); break; } } // LOG("[realize] F {} -> {} : base {} ; depth = {} ; ctx-base: {}; ret = {}; " // "parent = {}", // type->getFuncName(), type->realizedName(), ctx->getRealizationStackName(), // ctx->getRealizationDepth(), ctx->getBaseName(), // ctx->getBase()->returnType->debugString(2), // ctx->bases[ctx->getBase()->parent].name); } // Types might change after realization, fix it for (auto &t : *type) realizeType(t.getType()->getClass()); // Clone the generic AST that is to be realized auto ast = clean_clone(type->ast); if (auto s = generateSpecialAst(type)) ast->suite = s; addClassGenerics(type, true); ctx->getBase()->func = ast; // Internal functions have no AST that can be realized bool hasAst = ast->getSuite() && !ast->hasAttribute(Attr::Internal); // Add function arguments if (auto b = ast->getAttribute(Attr::Bindings)) { for (auto &[c, t] : b->captures) { if (t == BindingsAttribute::CaptureType::Global) { auto cp = ctx->find(c); if (!cp) E(Error::ID_NOT_FOUND, getSrcInfo(), c); if (!cp->isGlobal()) E(Error::FN_GLOBAL_NOT_FOUND, getSrcInfo(), 
"global", c); } } for (const auto [name, canonical] : b->localRenames) { auto val = ctx->forceFind(canonical); ctx->add(name, val); } } for (size_t i = 0, j = 0, gi = 0; hasAst && i < ast->size(); i++) { auto [_, varName] = (*ast)[i].getNameWithStars(); auto un = getUnmangledName(varName); if ((*ast)[i].isValue()) { TypePtr at = extractFuncArgType(type, j++)->shared_from_this(); bool isStatic = ast && getStaticGeneric((*ast)[i].getType()); if (!isStatic && at && at->getStatic()) at = at->getStatic()->getNonStaticType()->shared_from_this(); if (startswith(un, "$")) un = un.substr(1); if (at->is("TypeWrap")) { ctx->addType(un, varName, instantiateTypeVar(extractClassGeneric(at.get()))); } else { ctx->addVar(un, varName, std::make_shared(at)); } } else { if (startswith(un, "$")) { un = un.substr(1); auto g = type->funcGenerics[gi]; auto t = g.type; if (!g.staticKind && !t->is(TYPE_TYPE)) t = instantiateTypeVar(t.get()); auto v = ctx->addType(un, varName, t); v->generic = true; } gi++; } } // Populate realization table in advance to support recursive realizations auto key = type->realizedName(); // note: the key might change later ir::Func *oldIR = nullptr; // Get it if it was already made (force mode) if (auto i = in(realizations, key)) oldIR = (*i)->ir; auto r = realizations[key] = std::make_shared(); r->type = std::static_pointer_cast(type->shared_from_this()); r->ir = oldIR; if (auto b = ast->getAttribute(Attr::Bindings)) for (const auto &c : b->captures | std::views::keys) { auto h = ctx->find(c); r->captures.push_back(h ? 
h->canonicalName : ""); } // Realizations should always be visible, so add them to the toplevel auto val = std::make_shared(key, "", ctx->getModule(), type->shared_from_this()); ctx->addAlwaysVisible(val, true); ctx->getBase()->suite = ast->getSuite(); if (hasAst) { auto oldBlockLevel = ctx->blockLevel; ctx->blockLevel = 0; auto ret = inferTypes(ctx->getBase()->suite); ctx->blockLevel = oldBlockLevel; if (!ret) { realizations.erase(key); ParserErrors errors; if (!startswith(ast->name, "%_lambda")) { // Lambda typecheck failures are "ignored" as they are treated as statements, // not functions. // TODO: generalize this further. errors = findTypecheckErrors(ctx->getBase()->suite); } if (!isImport) { ctx->bases.pop_back(); ctx->popBlock(); ctx->typecheckLevel--; getLogger().level--; } if (!errors.empty()) { throw exc::ParserException(errors); } this->ctx = oldCtx; return nullptr; // inference must be delayed } else { ctx->getBase()->suite = ret; } // Use NoneType as the return type when the return type is not specified and // function has no return statement if (!ast->getReturn() && isUnbound(type->getRetType())) { auto rt = getStdLibType("NoneType")->shared_from_this(); if (ast->isAsync()) rt = instantiateType(getStdLibType("Coroutine"), {rt.get()}); unify(type->getRetType(), rt.get()); } } // Realize the return type auto ret = realize(type->getRetType()); if (type->hasUnbounds(/*includeGenerics*/ false)) { // log("cannot realize {}; undoing...", type->debugString(2)); realizations.erase(key); ctx->bases.pop_back(); ctx->popBlock(); ctx->typecheckLevel--; getLogger().level--; return nullptr; } seqassert(ret, "cannot realize return type '{}'", *(type->getRetType())); std::vector args; for (auto &i : *ast) { auto [_, varName] = i.getNameWithStars(); args.emplace_back(varName, nullptr, nullptr, i.status); } r->ast = N(r->type->realizedName(), nullptr, args, ctx->getBase()->suite, std::vector{}, ast->isAsync()); r->ast->setSrcInfo(ast->getSrcInfo()); 
r->ast->cloneAttributesFrom(ast); auto newType = std::static_pointer_cast(type->generalize(0)); auto newKey = newType->realizedName(); if (!in(ctx->cache->pendingRealizations, make_pair(newType->getFuncName(), newKey))) { realizations[newKey] = r; } else { realizations[key] = realizations[newKey]; } if (force) realizations[newKey]->ast = r->ast; r->type = newType; if (!r->ir) r->ir = makeIRFunction(r); val = std::make_shared(newKey, "", ctx->getModule(), r->type); ctx->addAlwaysVisible(val, true); if (!isImport) { ctx->bases.pop_back(); ctx->popBlock(); ctx->typecheckLevel--; getLogger().level--; } this->ctx = oldCtx; LOG_REALIZE("[func] {}", r->getType()->debugString(2)); return r->getType(); } /// Make IR node for a realized type. ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) { // Realize if not, and return cached value if it exists auto realizedName = t->ClassType::realizedName(); auto cls = ctx->cache->getClass(t); if (!in(cls->realizations, realizedName)) { t = realize(t->getClass())->getClass(); realizedName = t->ClassType::realizedName(); cls = ctx->cache->getClass(t); } if (auto l = cls->realizations[realizedName]->ir) { if (cls->rtti) cast(l)->setPolymorphic(); return l; } auto forceFindIRType = [&](Type *tt) { auto ttc = tt->getClass(); auto rn = ttc->ClassType::realizedName(); auto ttcls = ctx->cache->getClass(ttc); seqassert(ttc && in(ttcls->realizations, rn), "{} not realized", *tt); auto l = ttcls->realizations[rn]->ir; seqassert(l, "no LLVM type for {}", *tt); return l; }; // Prepare generics and statics std::vector types; std::vector statics; if (t->is("unrealized_type")) types.push_back(nullptr); else for (auto &m : t->generics) { if (auto s = m.type->getStatic()) statics.push_back(s); else types.push_back(forceFindIRType(m.getType())); } // Get the IR type auto *module = ctx->cache->module; ir::types::Type *handle = nullptr; if (t->name == "bool") { handle = module->getBoolType(); } else if (t->name == "byte") { handle = 
module->getByteType(); } else if (t->name == "int") { handle = module->getIntType(); } else if (t->name == "float") { handle = module->getFloatType(); } else if (t->name == "float32") { handle = module->getFloat32Type(); } else if (t->name == "float16") { handle = module->getFloat16Type(); } else if (t->name == "bfloat16") { handle = module->getBFloat16Type(); } else if (t->name == "float128") { handle = module->getFloat128Type(); } else if (t->name == "str") { handle = module->getStringType(); } else if (t->name == "Int" || t->name == "UInt") { handle = module->Nr(getIntLiteral(statics[0]), t->name == "Int"); } else if (t->name == "Ptr") { seqassert(types.size() == 1, "bad generics/statics"); handle = module->unsafeGetPointerType(types[0]); } else if (t->name == "Generator" || t->name == "AsyncGenerator") { seqassert(types.size() == 1, "bad generics/statics"); handle = module->unsafeGetGeneratorType(types[0]); } else if (t->name == "Coroutine") { seqassert(types.size() == 1, "bad generics/statics"); handle = module->unsafeGetGeneratorType(types[0]); } else if (t->name == TYPE_OPTIONAL) { seqassert(types.size() == 1, "bad generics/statics"); handle = module->unsafeGetOptionalType(types[0]); } else if (t->name == "NoneType") { seqassert(types.empty() && statics.empty(), "bad generics/statics"); auto record = cast(module->unsafeGetMemberedType(realizedName)); record->realize({}, {}); handle = record; } else if (t->name == "Union") { seqassert(!types.empty(), "bad union"); auto unionTypes = t->getUnion()->getRealizationTypes(); std::vector unionVec; unionVec.reserve(unionTypes.size()); for (auto &u : unionTypes) unionVec.emplace_back(forceFindIRType(u)); handle = module->unsafeGetUnionType(unionVec); } else if (t->name == "Function") { types.clear(); for (auto &m : extractClassGeneric(t)->getClass()->generics) types.push_back(forceFindIRType(m.getType())); auto ret = forceFindIRType(extractClassGeneric(t, 1)); handle = module->unsafeGetFuncType(realizedName, ret, 
types);
  } else if (t->name == getMangledClass("std.simd", "Vec")) {
    seqassert(types.size() == 1 && !statics.empty(), "bad generics/statics");
    handle = module->unsafeGetVectorType(getIntLiteral(statics[0]), types[0]);
  } else {
    // Type arguments will be populated afterwards to avoid infinite loop with recursive
    // reference types (e.g., `class X: x: Optional[X]`)
    if (t->isRecord()) {
      std::vector typeArgs; // needed for IR
      std::vector names; // needed for IR
      std::map memberInfo; // needed for IR

      seqassert(!t->is("__NTuple__"), "ntuple not inlined");
      auto ft = getClassFieldTypes(t->getClass());
      const auto &fields = cls->fields;
      for (size_t i = 0; i < ft.size(); i++) {
        if (!realize(ft[i].get())) {
          E(Error::TYPE_CANNOT_REALIZE_ATTR, getSrcInfo(), fields[i].name,
            t->prettyString());
        }
        names.emplace_back(fields[i].name);
        typeArgs.emplace_back(makeIRType(ft[i]->getClass()));
        memberInfo[fields[i].name] = ft[i]->getSrcInfo();
      }
      auto record = cast(module->unsafeGetMemberedType(realizedName));
      record->realize(typeArgs, names);
      handle = record;
      handle->setAttribute(
          std::make_unique(std::move(memberInfo)));
    } else {
      handle = module->unsafeGetMemberedType(realizedName, !t->isRecord());
      if (cls->rtti)
        cast(handle)->setPolymorphic();
    }
  }
  handle->setSrcInfo(t->getSrcInfo());
  handle->setAstType(
      std::const_pointer_cast(t->shared_from_this()));
  return cls->realizations[realizedName]->ir = handle;
}

/// Make IR node for a realized function.
/// @param r Realization record. Its `ast` (attributes and arguments) and `type`
///          (names, parent, argument and return types) are read here.
/// @return The newly created IR function, realized with its IR function type and
///         argument names. The realization is also added to
///         `ctx->cache->pendingRealizations`.
ir::Func *TypecheckVisitor::makeIRFunction(
    const std::shared_ptr &r) {
  ir::Func *fn = nullptr;
  auto irm = ctx->cache->module;
  // Create and store a function IR node and a realized AST for IR passes
  // Internal functions are created under the AST name; all other kinds are created
  // under the realized name.
  if (r->ast->hasAttribute(Attr::Internal)) {
    // e.g., __new__, Ptr.__new__, etc.
    fn = irm->Nr(r->type->ast->name);
  } else if (r->ast->hasAttribute(Attr::LLVM)) {
    fn = irm->Nr(r->type->realizedName());
  } else if (r->ast->hasAttribute(Attr::C)) {
    fn = irm->Nr(r->type->realizedName());
  } else {
    fn = irm->Nr(r->type->realizedName());
  }
  fn->setUnmangledName(ctx->cache->reverseIdentifierLookup[r->type->ast->name]);

  // Resolve the parent type; a non-empty ParentClass attribute on a non-method
  // overrides the parent recorded in the function type.
  auto parent = r->type->funcParent;
  if (auto aa = r->ast->getAttribute(Attr::ParentClass)) {
    if (!aa->value.empty() && !r->ast->hasAttribute(Attr::Method)) {
      // Hack for non-generic methods
      parent = ctx->find(aa->value)->type;
    }
  }
  // The parent is attached only if it is instantiated and realizable
  if (parent && parent->isInstantiated() && parent->canRealize()) {
    parent = extractClassType(parent.get())->shared_from_this();
    realize(parent.get());
    fn->setParentType(makeIRType(parent->getClass()));
  }
  fn->setGlobal();
  // Mark this realization as pending (i.e., realized but not translated)
  ctx->cache->pendingRealizations.insert({r->type->ast->name, r->type->realizedName()});

  // NOTE(review): `r->type` has already been dereferenced above, so the `!r->type`
  // alternative of this assertion cannot be what makes it pass here.
  seqassert(!r->type || r->ast->size() == r->type->size() + r->type->funcGenerics.size(),
            "type/AST argument mismatch");

  // Populate the IR node
  std::vector names;
  std::vector types;
  // `i` walks all AST arguments while `j` counts only the value arguments.
  // Value arguments whose type is itself a function type get no IR parameter.
  for (size_t i = 0, j = 0; i < r->ast->size(); i++) {
    if ((*r->ast)[i].isValue()) {
      if (!extractFuncArgType(r->getType(), j)->getFunc()) {
        types.push_back(makeIRType(extractFuncArgType(r->getType(), j)->getClass()));
        names.push_back(ctx->cache->reverseIdentifierLookup[(*r->ast)[i].getName()]);
      }
      j++;
    }
  }
  // For C variadic functions the last collected parameter is dropped; the variadic
  // flag is passed to the IR function type below instead.
  if (r->ast->hasAttribute(Attr::CVarArg)) {
    types.pop_back();
    names.pop_back();
  }
  auto irType = irm->unsafeGetFuncType(r->type->realizedName(),
                                       makeIRType(r->type->getRetType()->getClass()),
                                       types, r->ast->hasAttribute(Attr::CVarArg));
  irType->setAstType(r->type->shared_from_this());
  fn->realize(irType, names);
  return fn;
}

/// Instantiate and realize @c fn with the given generics, translate all pending
/// realizations and return the IR function of the requested realization.
/// @param fn       Function type to instantiate.
/// @param generics Types unified positionally with the `funcGenerics` of the
///                 instantiated function. Their count is not checked against the
///                 number of function generics here.
/// @return The IR function of the realization, or nullptr if realization fails.
ir::Func *TypecheckVisitor::realizeIRFunc(types::FuncType *fn,
                                          const std::vector &generics) {
  // TODO: used by cythonization. Probably needs refactoring.
  auto fnType = instantiateType(fn);
  types::Type::Unification u;
  for (size_t i = 0; i < generics.size(); i++)
    fnType->getFunc()->funcGenerics[i].type->unify(generics[i].get(), &u);
  if (!realize(fnType.get()))
    return nullptr;

  auto pr = ctx->cache->pendingRealizations; // copy it as it might be modified
  // Translate the AST of every function that has a pending realization
  for (const auto &key : pr | std::views::keys)
    TranslateVisitor(ctx->cache->codegenCtx)
        .translateStmts(clone(getFunction(key)->ast));
  return getFunction(fn->ast->getName())->realizations[fnType->realizedName()]->ir;
}

} // namespace codon::ast

================================================
FILE: codon/parser/visitors/typecheck/loops.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

// NOTE(review): the two bare `#include` directives below have no header name in this
// copy of the file; restore them from the original source.
#include
#include

#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/match.h"
#include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

using namespace codon::error;
namespace codon::ast {

using namespace types;
using namespace matcher;

/// Ensure that `break` is in a loop.
/// Transform if a loop break variable is available
/// (e.g., a break within loop-else block).
/// @example
///   `break` -> `no_break = False; break`
void TypecheckVisitor::visit(BreakStmt *stmt) {
  if (!ctx->getBase()->getLoop())
    E(Error::EXPECTED_LOOP, stmt, "break");
  // A loop that contains a `break` is marked as not flat
  ctx->getBase()->getLoop()->flat = false;
  if (!ctx->getBase()->getLoop()->breakVar.empty()) {
    // Loop-else case: update the loop's break variable to `False`, then emit a
    // freshly constructed node after that assignment
    resultStmt = N(transform(N(
                       N(ctx->getBase()->getLoop()->breakVar), N(false), nullptr,
                       AssignStmt::UpdateMode::Update)),
                   N());
  } else {
    stmt->setDone();
    if (!ctx->staticLoops.back().empty()) {
      // The innermost `staticLoops` entry names a loop variable: set it to `False`
      // (as an update) before this `break`
      auto a = N(N(ctx->staticLoops.back()), N(false));
      a->setUpdate();
      resultStmt = transform(N(a, stmt));
    }
  }
}

/// Ensure that `continue` is in a loop
void TypecheckVisitor::visit(ContinueStmt *stmt) {
  if (!ctx->getBase()->getLoop())
    E(Error::EXPECTED_LOOP, stmt, "continue");
  // A loop that contains a `continue` is marked as not flat
  ctx->getBase()->getLoop()->flat = false;

  stmt->setDone();
  if (!ctx->staticLoops.back().empty()) {
    // The innermost `staticLoops` entry is non-empty: replace the statement with a
    // freshly constructed node that is already marked done.
    // NOTE(review): the node type built here is not visible in this copy of the file;
    // confirm against the original source.
    resultStmt = N();
    resultStmt->setDone();
  }
}

/// Transform a while loop.
/// @example
///   `while cond: ...`           ->  `while cond: ...`
///   `while cond: ... else: ...` -> ```no_break = True
///                                     while cond:
///                                       ...
///                                     if no_break: ...```
void TypecheckVisitor::visit(WhileStmt *stmt) {
  // Check for while-else clause
  std::string breakVar;
  if (stmt->getElse() && stmt->getElse()->firstInBlock()) {
    // no_break = True
    breakVar = getTemporaryVar("no_break");
    prependStmts->push_back(
        transform(N(N(breakVar), N(true))));
  }

  // Push this loop's goto variable (possibly empty) and its break variable.
  // NOTE(review): the conditional yields an empty string exactly when `gotoVar` is
  // empty, so it looks equivalent to pushing `stmt->gotoVar` directly.
  ctx->staticLoops.push_back(stmt->gotoVar.empty() ? "" : stmt->gotoVar);
  ctx->getBase()->loops.emplace_back(breakVar);
  // Transform the condition with `bool` as the expected type, then restore the
  // previous expected type
  auto oldExpectedType = getStdLibType("bool")->shared_from_this();
  std::swap(ctx->expectedType, oldExpectedType);
  stmt->cond = transform(stmt->getCond());
  std::swap(ctx->expectedType, oldExpectedType);
  wrapExpr(&stmt->cond, getStdLibType("bool"));

  ctx->blockLevel++;
  stmt->suite = SuiteStmt::wrap(transform(stmt->getSuite()));
  ctx->blockLevel--;
  ctx->staticLoops.pop_back();

  // Complete while-else clause
  if (stmt->getElse() && stmt->getElse()->firstInBlock()) {
    auto es = stmt->getElse();
    // Detach the else suite before re-transforming so this branch is not taken again
    stmt->elseSuite = nullptr;
    resultStmt = transform(N(stmt, N(N(breakVar), es)));
  }

  ctx->getBase()->loops.pop_back();

  if (stmt->getCond()->isDone() && stmt->getSuite()->isDone())
    stmt->setDone();
}

/// Typecheck for statements. Wrap the iterator expression with `__iter__` if needed.
/// See @c transformHeterogenousTupleFor for iterating heterogenous tuples.
void TypecheckVisitor::visit(ForStmt *stmt) {
  stmt->decorator = transformForDecorator(stmt->getDecorator());
  // If the decorator is a call to `std.openmp.for_par` whose function generic at
  // index 3 is a true static bool, prepend an import of `gpu` under a temporary name
  if (auto fc = cast(stmt->getDecorator())) {
    if (auto fi = cast(fc->getExpr());
        fi && fi->getType()->getFunc() &&
        fi->getType()->getFunc()->getFuncName() ==
            getMangledFunc("std.openmp", "for_par")) {
      if (auto n = extractFuncGeneric(fi->getType(), 3)->getBoolStatic();
          n && n->value) {
        prependStmts->push_back(
            transform(N(N("gpu"), nullptr, std::vector{}, nullptr,
                        getTemporaryVar("_"))));
      }
    }
  }

  std::string breakVar;
  // Needs in-advance transformation to prevent name clashes with the iterator variable
  stmt->getIter()->setAttribute(
      Attr::ExprNoSpecial); // do not expand special calls here,
                            // might be needed for static loops!
  stmt->iter = transform(stmt->getIter());

  // Check for for-else clause
  Stmt *assign = nullptr;
  if (stmt->getElse() && stmt->getElse()->firstInBlock()) {
    breakVar = getTemporaryVar("no_break");
    assign = transform(N(N(breakVar), N(true)));
  }

  // Extract the iterator type of the for
  auto iterType = extractClassType(stmt->getIter());
  if (!iterType)
    return; // wait until the iterator is known

  // Try the static-loop expansion first; it can ask to be retried later (`delay`)
  auto [delay, staticLoop] = transformStaticForLoop(stmt);
  if (delay)
    return;
  if (staticLoop) {
    resultStmt = staticLoop;
    return;
  }

  // Replace for (i, j) in ... { ... } with for tmp in ...: { i, j = tmp ; ... }
  if (!cast(stmt->getVar())) {
    auto var = N(ctx->cache->getTemporaryVar("for"));
    auto ns = unpackAssignment(stmt->getVar(), var);
    stmt->suite = N(ns, stmt->getSuite());
    stmt->var = var;
  }

  // Case: iterating a non-generator. Wrap with `__iter__`
  // (async loops expect `AsyncGenerator`, all others `Generator`)
  bool isGenerator =
      iterType->name == (stmt->isAsync() ? "AsyncGenerator" : "Generator");
  if (!isGenerator && !stmt->isWrapped()) {
    stmt->iter = transform(N(N(stmt->getIter(), "__iter__")));
    iterType = extractClassType(stmt->getIter());
    // Remember the wrapping so that it is not applied again on a later visit
    stmt->wrapped = true;
    if (!iterType)
      return;
    isGenerator = iterType->name == (stmt->isAsync() ? "AsyncGenerator" : "Generator");
  }

  ctx->getBase()->loops.emplace_back(breakVar);
  auto var = cast(stmt->getVar());
  seqassert(var, "corrupt for variable: {}", *(stmt->getVar()));

  if (!var->hasAttribute(Attr::ExprDominated) &&
      !var->hasAttribute(Attr::ExprDominatedUsed)) {
    // Neither dominated attribute is set: register the loop variable in the context,
    // reusing its type if it has one and a fresh unbound type otherwise
    ctx->addVar(
        getUnmangledName(var->getValue()), ctx->generateCanonicalName(var->getValue()),
        stmt->getVar()->getType() ? stmt->getVar()->getType()->shared_from_this()
                                  : instantiateUnbound(),
        getTime());
  } else if (var->hasAttribute(Attr::ExprDominatedUsed)) {
    // Downgrade the attribute and prepend an update of `<var><VAR_USED_SUFFIX>` to
    // `True` to the loop body
    var->eraseAttribute(Attr::ExprDominatedUsed);
    var->setAttribute(Attr::ExprDominated);
    stmt->suite = N(
        N(N(fmt::format("{}{}", var->getValue(), VAR_USED_SUFFIX)), N(true), nullptr,
          AssignStmt::UpdateMode::Update),
        stmt->getSuite());
  }
  stmt->var = transform(stmt->getVar());

  // Unify iterator variable and the iterator type
  if (iterType && !isGenerator)
    E(Error::EXPECTED_GENERATOR, stmt->getIter());
  if (iterType)
    unify(stmt->getVar()->getType(), extractClassGeneric(iterType));

  // The body is transformed with an empty `staticLoops` entry on the stack
  ctx->staticLoops.emplace_back();
  ctx->blockLevel++;
  stmt->suite = SuiteStmt::wrap(transform(stmt->getSuite()));
  ctx->blockLevel--;
  ctx->staticLoops.pop_back();

  // Propagate the flat flag from the loop record (cleared by break/continue)
  if (ctx->getBase()->getLoop()->flat)
    stmt->flat = true;

  // Complete for-else clause
  if (stmt->getElse() && stmt->getElse()->firstInBlock()) {
    auto es = stmt->getElse();
    stmt->elseSuite = nullptr;
    resultStmt = transform(N(assign, stmt, N(N(breakVar), es)));
    // NOTE(review): `elseSuite` was already cleared above; this second assignment
    // looks redundant unless the transform restores it.
    stmt->elseSuite = nullptr;
  }

  ctx->getBase()->loops.pop_back();

  if (stmt->getIter()->isDone() && stmt->getSuite()->isDone())
    stmt->setDone();
}

/// Transform and check for OpenMP decorator.
/// @example
///   `@par(num_threads=2, openmp="schedule(static)")` ->
///   `for_par(num_threads=2, schedule="static")`
/// @param decorator Loop decorator expression; may be null.
/// @return nullptr if there is no decorator; otherwise the transformed `for_par` call
///         built from the decorator arguments and the parsed OpenMP arguments.
Expr *TypecheckVisitor::transformForDecorator(Expr *decorator) {
  if (!decorator)
    return nullptr;
  // The callee is the decorator itself, or the called expression if it is a call
  Expr *callee = decorator;
  if (auto c = cast(callee))
    callee = c->getExpr();
  auto ci = cast(transform(callee));
  // Only names starting with the mangled `std.openmp.for_par` are accepted
  if (!ci || !startswith(ci->getValue(), getMangledFunc("std.openmp", "for_par"))) {
    E(Error::LOOP_DECORATOR, decorator);
  }

  std::vector args;
  // NOTE(review): `openmp` is never assigned in this function, so `openmp.empty()`
  // below is always true.
  std::string openmp;
  std::vector omp;
  if (auto c = cast(decorator))
    for (auto &a : *c) {
      // An argument named `openmp`, or an unnamed argument that passes the cast
      // below, is parsed as an OpenMP specification
      // NOTE(review): for a named `openmp` argument the cast result is used without
      // a null check; confirm the value is guaranteed to be a string here.
      if (a.getName() == "openmp" ||
          (a.getName().empty() && openmp.empty() && cast(a.getExpr()))) {
        auto ompOrErr = parseOpenMP(ctx->cache, cast(a.getExpr())->getValue(),
                                    a.value->getSrcInfo());
        if (!ompOrErr)
          throw exc::ParserException(ompOrErr.takeError());
        omp = *ompOrErr;
      } else {
        args.emplace_back(a.getName(), transform(a.getExpr()));
      }
    }
  // Parsed OpenMP arguments are appended after the explicit ones
  for (auto &a : omp)
    args.emplace_back(a.getName(), transform(a.getExpr()));
  return transform(N(transform(N("for_par")), args));
}

/// Handle static for constructs.
/// @example
///   `for i in statictuple(1, x): <suite>` ->
///   ```loop = True
///      while loop:
///        while loop:
///          i: Literal[int] = 1; <suite>; break
///        while loop:
///          i = x; <suite>; break
///        loop = False   # also set to False on break
/// If a loop is flat, while wrappers are removed.
/// A separate suite is generated for each static iteration.
/// @return `{false, nullptr}` if this is not a static loop, `{true, nullptr}` if the
///         decision must be delayed, and `{false, loop}` with the transformed
///         replacement otherwise.
std::pair
TypecheckVisitor::transformStaticForLoop(const ForStmt *stmt) {
  auto loopVar = getTemporaryVar("loop");
  auto suite = clean_clone(stmt->getSuite());
  // The callback wraps the assignments of one static iteration together with a
  // copy of the loop body
  auto [ok, delay, preamble, items] = transformStaticLoopCall(
      stmt->getVar(), &suite, stmt->getIter(), [&](Stmt *assigns) {
        Stmt *ret = nullptr;
        if (!stmt->flat) {
          auto brk = N();
          brk->setDone(); // Avoid transforming this one to skip extra checks
          // var [: Static] := expr; suite...
          auto loop = N(N(loopVar), N(assigns, clone(suite), brk));
          loop->gotoVar = loopVar;
          ret = loop;
        } else {
          // Flat loop: no wrapper, just the assignments followed by the body
          ret = N(assigns, clone(stmt->getSuite()));
        }
        return ret;
      });
  if (!ok)
    return {false, nullptr};
  if (delay)
    return {true, nullptr};

  // Close the loop
  auto block = N();
  block->addStmt(preamble);
  for (auto &i : items)
    block->addStmt(cast(i));
  Stmt *loop = nullptr;
  if (!stmt->flat) {
    ctx->blockLevel++;
    // End of the outer wrapper: set the loop variable to `False` (as an update)
    auto a = N(N(loopVar), N(false));
    a->setUpdate();
    block->addStmt(a);
    // `loopVar = True` followed by the outer wrapper conditioned on `loopVar`
    loop = transform(N(N(N(loopVar), N(true)), N(N(loopVar), block)));
    ctx->blockLevel--;
  } else {
    loop = transform(block);
  }
  return {false, loop};
}

/// Expand the iterations of a static loop into one statement per iteration.
/// @param varExpr  Loop variable expression: a single named variable or a collection
///                 whose items are all named variables.
/// @param varSuite Not referenced in this function body.
/// @param iter     Iterator expression; its head expression selects the expansion.
/// @param wrap     Callback applied to each generated per-iteration statement.
/// @param allowNonHeterogenous If false, tuples that are not heterogenous are not
///                 expanded.
/// @return `{ok, delay, preamble, items}` (names as destructured by
///         @c transformStaticForLoop ): `{true, true, ...}` asks to retry later,
///         `{false, false, ...}` means this is not a static loop, and
///         `{true, false, preamble, items}` carries the wrapped iterations.
std::tuple>
TypecheckVisitor::transformStaticLoopCall(Expr *varExpr, SuiteStmt **varSuite,
                                          Expr *iter,
                                          const std::function &wrap,
                                          bool allowNonHeterogenous) {
  // The iterator's class type is not known yet: delay
  if (!iter->getClassType())
    return {true, true, nullptr, {}};

  // Collect the loop variable names; anything else is not a static loop
  std::vector vars{};
  if (auto ei = cast(varExpr)) {
    vars.push_back(ei->getValue());
  } else {
    Items *list = nullptr;
    if (auto el = cast(varExpr))
      list = el;
    else if (auto et = cast(varExpr))
      list = et;
    if (list) {
      for (const auto &it : *list)
        if (auto eli = cast(it)) {
          vars.push_back(eli->getValue());
        } else {
          return {false, false, nullptr, {}};
        }
    } else {
      return {false, false, nullptr, {}};
    }
  }

  Stmt *preamble = nullptr;
  iter = getHeadExpr(iter);
  auto fn = cast(iter) ? cast(cast(iter)->getExpr()) : nullptr;
  std::vector block;
  // Dispatch on the called function by prefix match of its mangled name.
  // NOTE(review): `range` with overload 1 is tested before plain `range`, and `vars`
  // before `vars_types`. If the mangled name of `vars` is a textual prefix of that of
  // `vars_types`, the `vars_types` branch is unreachable; confirm against the format
  // produced by getMangledFunc.
  if (fn &&
      startswith(fn->getValue(), getMangledFunc("std.internal.static", "tuple"))) {
    block = populateStaticTupleLoop(iter, vars);
  } else if (fn && startswith(fn->getValue(),
                              getMangledFunc("std.internal.static", "range", 1))) {
    block = populateSimpleStaticRangeLoop(iter, vars);
  } else if (fn && startswith(fn->getValue(),
                              getMangledFunc("std.internal.static", "range"))) {
    block = populateStaticRangeLoop(iter, vars);
  } else if (fn &&
             startswith(fn->getValue(), getMangledMethod("std.internal.static",
                                                         "function", "overloads"))) {
    block = populateStaticFnOverloadsLoop(iter, vars);
  } else if (fn && startswith(fn->getValue(),
                              getMangledFunc("std.internal.static", "enumerate"))) {
    block = populateStaticEnumerateLoop(iter, vars);
  } else if (fn && startswith(fn->getValue(),
                              getMangledFunc("std.internal.static", "vars"))) {
    block = populateStaticVarsLoop(iter, vars);
  } else if (fn && startswith(fn->getValue(),
                              getMangledFunc("std.internal.static", "vars_types"))) {
    block = populateStaticVarTypesLoop(iter, vars);
  } else {
    if (iter->getType()->is(TYPE_TUPLE)) {
      // Maybe heterogenous?
      if (!iter->getType()->canRealize())
        return {true, true, nullptr, {}}; // wait until the tuple is fully realizable
      if (!isHeterogenous(iter->getClassType()) && !allowNonHeterogenous)
        return {false, false, nullptr, {}};
      block = populateStaticHeterogenousTupleLoop(iter, vars);
      // The last returned statement is the preamble, not an iteration
      preamble = block.back();
      block.pop_back();
    } else {
      return {false, false, nullptr, {}};
    }
  }
  // Wrap every per-iteration statement with the caller-supplied callback
  std::vector wrapBlock;
  wrapBlock.reserve(block.size());
  for (auto b : block) {
    wrapBlock.push_back(wrap(b));
  }
  return {true, false, preamble, wrapBlock};
}

} // namespace codon::ast

================================================
FILE: codon/parser/visitors/typecheck/op.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
// NOTE(review): the two bare `#include` directives below have no header name in this
// copy of the file; restore them from the original source.
#include
#include

#include "codon/parser/ast.h"
#include "codon/parser/cache.h"
#include "codon/parser/common.h"
#include "codon/parser/match.h"
#include "codon/parser/visitors/typecheck/typecheck.h"

using namespace codon::error;

namespace codon::ast {

using namespace types;
using namespace matcher;

/// Replace unary operators with the appropriate magic calls.
/// Also evaluate static expressions. See @c evaluateStaticUnary for details.
void TypecheckVisitor::visit(UnaryExpr *expr) {
  expr->expr = transform(expr->getExpr());

  if (cast(expr->getExpr()) && expr->getOp() == "-") {
    // Special case: make - INT(val) same as INT(-val) to simplify IR and everything
    resultExpr = transform(N(-cast(expr->getExpr())->getValue()));
    return;
  }

  StaticType *staticType = nullptr;
  // Operators that can be evaluated statically, per static kind of the operand
  static std::unordered_map> staticOps = {
      {LiteralKind::Int, {"-", "+", "!", "~"}},
      {LiteralKind::String, {"!"}},
      {LiteralKind::Bool, {"!"}}};
  // Handle static expressions
  if (auto s = expr->getExpr()->getType()->getStaticKind()) {
    if (in(staticOps[s], expr->getOp())) {
      if ((resultExpr = evaluateStaticUnary(expr))) {
        // Only the static type of the evaluated result is kept; `resultExpr` itself
        // is reassigned with the magic call below
        staticType = resultExpr->getType()->getStatic();
      } else {
        // Could not be evaluated yet
        return;
      }
    }
  } else if (isUnbound(expr->getExpr())) {
    // Operand type is still unbound: nothing more to do on this visit
    return;
  }

  if (expr->getOp() == "!") {
    // `not expr` -> `expr.__bool__().__invert__()`
    resultExpr = transform(N(N(
        N(N(expr->getExpr(), "__bool__")), "__invert__")));
  } else {
    // Map the operator to its magic method name
    std::string magic;
    if (expr->getOp() == "~")
      magic = "invert";
    else if (expr->getOp() == "+")
      magic = "pos";
    else if (expr->getOp() == "-")
      magic = "neg";
    else
      seqassert(false, "invalid unary operator '{}'", expr->getOp());
    resultExpr = transform(
        N(N(expr->getExpr(), fmt::format("__{}__", magic))));
  }

  // Re-apply the statically evaluated type to the final expression
  if (staticType)
    resultExpr->setType(staticType->shared_from_this());
}

/// Replace binary operators with the appropriate magic calls.
/// See @c transformBinarySimple , @c transformBinaryIs , @c transformBinaryMagic and
/// @c transformBinaryInplaceMagic for details.
/// Also evaluate static expressions.
/// See @c evaluateStaticBinary for details.
void TypecheckVisitor::visit(BinaryExpr *expr) {
  expr->lexpr = transform(expr->getLhs(), true);

  // Static short-circuit
  // For `&&` / `||` with a static left-hand side, the outcome is decided from the
  // static bool, string (empty or not) or int (zero or not) value of the left side.
  // If the left side decides the result, it is `False`/`True` when a `bool` is
  // expected and the left expression itself otherwise; if not, the result is built
  // from the left side together with the right-hand side.
  if (expr->getLhs()->getType()->getStaticKind() && expr->op == "&&") {
    if (auto tb = expr->getLhs()->getType()->getBoolStatic()) {
      if (!tb->value) {
        if (ctx->expectedType && ctx->expectedType->is("bool"))
          resultExpr = transform(N(false));
        else
          resultExpr = expr->getLhs();
      } else {
        resultExpr = transform(N(N(expr->getLhs()), expr->getRhs()));
      }
    } else if (auto ts = expr->getLhs()->getType()->getStrStatic()) {
      if (ts->value.empty()) {
        if (ctx->expectedType && ctx->expectedType->is("bool"))
          resultExpr = transform(N(false));
        else
          resultExpr = expr->getLhs();
      } else {
        resultExpr = transform(N(N(expr->getLhs()), expr->getRhs()));
      }
    } else if (auto ti = expr->getLhs()->getType()->getIntStatic()) {
      if (!ti->value) {
        if (ctx->expectedType && ctx->expectedType->is("bool"))
          resultExpr = transform(N(false));
        else
          resultExpr = expr->getLhs();
      } else {
        resultExpr = transform(N(N(expr->getLhs()), expr->getRhs()));
      }
    } else {
      // Static value not available yet: only mark the result kind as static bool
      expr->getType()->getUnbound()->staticKind = LiteralKind::Bool;
    }
    return;
  } else if (expr->getLhs()->getType()->getStaticKind() && expr->op == "||") {
    if (auto tb = expr->getLhs()->getType()->getBoolStatic()) {
      if (tb->value) {
        if (ctx->expectedType && ctx->expectedType->is("bool"))
          resultExpr = transform(N(true));
        else
          resultExpr = expr->getLhs();
      } else {
        resultExpr = transform(N(N(expr->getLhs()), expr->getRhs()));
      }
    } else if (auto ts = expr->getLhs()->getType()->getStrStatic()) {
      if (!ts->value.empty()) {
        if (ctx->expectedType && ctx->expectedType->is("bool"))
          resultExpr = transform(N(true));
        else
          resultExpr = expr->getLhs();
      } else {
        resultExpr = transform(N(N(expr->getLhs()), expr->getRhs()));
      }
    } else if (auto ti = expr->getLhs()->getType()->getIntStatic()) {
      if (ti->value) {
        if (ctx->expectedType && ctx->expectedType->is("bool"))
          resultExpr = transform(N(true));
        else
          resultExpr = expr->getLhs();
      } else {
        resultExpr = transform(N(N(expr->getLhs()), expr->getRhs()));
      }
    } else {
      // Static value not available yet: only mark the result kind as static bool
      expr->getType()->getUnbound()->staticKind = LiteralKind::Bool;
    }
    return;
  }

  expr->rexpr = transform(expr->getRhs(), true);

  StaticType *staticType = nullptr;
  // Operators that can be evaluated statically, per static kind of the operands
  static std::unordered_map> staticOps = {
      {LiteralKind::Int,
       {"<", "<=", ">", ">=", "==", "!=", "&&", "||", "+", "-", "*", "//", "%", "&",
        "|", "^", ">>", "<<"}},
      {LiteralKind::String, {"==", "!=", "+"}},
      {LiteralKind::Bool, {"<", "<=", ">", ">=", "==", "!=", "&&", "||"}}};
  if (auto l = expr->getLhs()->getType()->getStaticKind(),
      r = expr->getRhs()->getType()->getStaticKind();
      l && r) {
    // Both operands are static: same kind with a supported operator, or a mix of
    // int and bool with an operator supported for ints
    bool isStatic = l == r && in(staticOps[l], expr->getOp());
    if (!isStatic &&
        ((l == LiteralKind::Int && r == LiteralKind::Bool) ||
         (r == LiteralKind::Int && l == LiteralKind::Bool)) &&
        in(staticOps[LiteralKind::Int], expr->getOp()))
      isStatic = true;
    if (isStatic) {
      // Only the static type of the evaluated result is kept; `resultExpr` itself is
      // reassigned below
      if ((resultExpr = evaluateStaticBinary(expr)))
        staticType = resultExpr->getType()->getStatic();
      else
        return;
    }
  }

  if (isTypeExpr(expr->getLhs()) && isTypeExpr(expr->getRhs()) &&
      expr->getOp() == "|") {
    // Case: unions
    resultExpr = transform(N(
        N("Union"), std::vector{expr->getLhs(), expr->getRhs()}));
  } else if (auto e = transformBinarySimple(expr)) {
    // Case: simple binary expressions
    resultExpr = e;
  } else if (expr->getLhs()->getType()->getUnbound() ||
             (expr->getOp() != "is" && expr->getRhs()->getType()->getUnbound())) {
    // Case: types are unknown, so continue later
    return;
  } else if (expr->getOp() == "is") {
    // Case: is operator
    resultExpr = transformBinaryIs(expr);
  } else {
    if (auto ei = transformBinaryInplaceMagic(expr, false)) {
      // Case: in-place magic methods
      resultExpr = ei;
    } else if (auto em = transformBinaryMagic(expr)) {
      // Case: normal magic methods
      resultExpr = em;
    } else if (expr->getLhs()->getType()->is(TYPE_OPTIONAL)) {
      // Special case: handle optionals if everything else fails.
      // Assumes that optionals have no relevant magics (except for __eq__)
      resultExpr = transform(
          N(N(N(FN_OPTIONAL_UNWRAP), expr->getLhs()), expr->getOp(), expr->getRhs(),
            expr->isInPlace()));
    } else {
      // Nothing found: report an error
      E(Error::OP_NO_MAGIC, expr, expr->getOp(),
        expr->getLhs()->getType()->prettyString(),
        expr->getRhs()->getType()->prettyString());
    }
  }

  // Re-apply the statically evaluated type to the final expression
  if (staticType)
    resultExpr->setType(staticType->shared_from_this());
}

/// Transform chain binary expression.
/// @example
///   `a <= b <= c` -> `(a <= (chain := b)) and (chain <= c)`
/// The assignment above ensures that all expressions are executed only once.
void TypecheckVisitor::visit(ChainBinaryExpr *expr) {
  seqassert(expr->exprs.size() >= 2, "not enough expressions in ChainBinaryExpr");
  std::vector items;
  std::string prev;
  for (int i = 1; i < expr->exprs.size(); i++) {
    // Left operand: the first expression for the first pair, otherwise the chain
    // variable introduced by the previous pair
    auto l = prev.empty() ? clone(expr->exprs[i - 1].second) : N(prev);
    prev = ctx->generateCanonicalName("chain");
    // Right operand: the last expression is used as is; any other one is assigned to
    // a new chain variable that is then read back
    auto r = (i + 1 == expr->exprs.size())
                 ? clone(expr->exprs[i].second)
                 : N(N(N(prev), clone(expr->exprs[i].second)), N(prev));
    items.emplace_back(N(l, expr->exprs[i].first, r));
  }

  // Combine the pairwise comparisons with `&&`, nesting to the right
  Expr *final = items.back();
  for (auto i = items.size() - 1; i-- > 0;)
    final = N(items[i], "&&", final);

  // Transform with `bool` as the expected type, then restore the previous one
  auto oldExpectedType = getStdLibType("bool")->shared_from_this();
  std::swap(ctx->expectedType, oldExpectedType);
  resultExpr = transform(final);
  std::swap(ctx->expectedType, oldExpectedType);
}

/// Helper function that locates the pipe ellipsis within a collection of (possibly
/// nested) CallExprs.
/// @return  List of CallExprs and their locations within the parent CallExpr
///          needed to access the ellipsis.
/// @example
///   `foo(bar(1, baz(...)))` returns `[{0, baz}, {1, bar}, {0, foo}]`
std::vector>
TypecheckVisitor::findEllipsis(Expr *expr) {
  auto call = cast(expr);
  if (!call)
    return {};
  size_t ai = 0; // index of the current argument
  for (auto &a : *call) {
    if (auto el = cast(a)) {
      // Found the pipe ellipsis: this call is the innermost layer
      if (el->isPipe())
        return {{ai, expr}};
    } else if (cast(a)) {
      // Search the nested argument; on success append this call as the next layer
      auto v = findEllipsis(a);
      if (!v.empty()) {
        v.emplace_back(ai, expr);
        return v;
      }
    }
    ai++;
  }
  return {};
}

/// Typecheck pipe expressions.
/// Each stage call `foo(x)` without an ellipsis will be transformed to `foo(..., x)`.
/// Stages that are not in the form of CallExpr will be transformed to it (e.g., `foo`
/// -> `foo(...)`).
/// Special care is taken of stages that can expand to multiple stages (e.g., `a |> foo`
/// might become `a |> unwrap |> foo` to satisfy type constraints).
void TypecheckVisitor::visit(PipeExpr *expr) {
  bool hasGenerator = false;

  // Return T if t is of type `Generator[T]`; otherwise just `type(t)`
  // (also records in `hasGenerator` that a generator stage was seen)
  auto getIterableType = [&](Type *t) {
    if (t->is("Generator")) {
      hasGenerator = true;
      return extractClassGeneric(t);
    }
    return t;
  };

  // List of output types
  // (e.g., for `a|>b|>c` it is `[type(a), type(a|>b), type(a|>b|>c)]`).
  // Note: the generator types are completely preserved (i.e., not extracted)
  expr->inTypes.clear();

  // Process the pipeline head
  expr->front().expr = transform(expr->front().expr);
  auto inType = expr->front().expr->getType(); // input type to the next stage
  expr->inTypes.push_back(inType->shared_from_this());
  inType = getIterableType(inType);
  auto done = expr->front().expr->isDone();
  for (size_t pi = 1; pi < expr->size(); pi++) {
    int inTypePos = -1;                   // ellipsis position
    Expr **ec = &((*expr)[pi].expr);      // a pointer so that we can replace it
    while (auto se = cast(*ec)) // handle StmtExpr (e.g., in partial calls)
      ec = &(se->expr);

    if (auto call = cast(*ec)) {
      // Case: a call. Find the position of the pipe ellipsis within it
      for (size_t ia = 0; inTypePos == -1 && ia < call->size(); ia++)
        if (cast((*call)[ia].value))
          inTypePos = static_cast(ia);
      // No ellipses found? Prepend it as the first argument
      if (inTypePos == -1) {
        call->items.insert(call->items.begin(),
                           CallArg{"", N(EllipsisExpr::PARTIAL)});
        inTypePos = 0;
      }
    } else {
      // Case: not a call. Convert it to a call with a single ellipsis
      (*expr)[pi].expr = N((*expr)[pi].expr, N(EllipsisExpr::PARTIAL));
      ec = &(*expr)[pi].expr;
      inTypePos = 0;
    }

    // Set the ellipsis type
    auto el = cast((*cast(*ec))[inTypePos].value);
    el->mode = EllipsisExpr::PIPE;
    // Don't unify unbound inType yet (it might become a generator that needs to be
    // extracted)
    if (!el->getType())
      el->setType(instantiateUnbound());
    if (inType && !inType->getUnbound())
      unify(el->getType(), inType);

    // Transform the call. Because a transformation might wrap the ellipsis in layers,
    // make sure to extract these layers and move them to the pipeline.
    // Example: `foo(...)` that is transformed to `foo(unwrap(...))` will become
    // `unwrap(...) |> foo(...)`
    *ec = transform(*ec);
    auto layers = findEllipsis(*ec);
    seqassert(!layers.empty(), "can't find the ellipsis");
    if (layers.size() > 1) {
      // Prepend layers
      // Each layer gets a new pipe ellipsis at its recorded position and is inserted
      // as its own stage before the current one; `pi` advances past every insertion
      for (auto &[pos, prepend] : layers) {
        (*cast(prepend))[pos].value = N(EllipsisExpr::PIPE);
        expr->items.insert(expr->items.begin() + pi++, {"|>", prepend});
      }
      // Rewind the loop (yes, the current expression will get transformed again)
      /// TODO: avoid reevaluation
      // Drop the original stage, then step back so that the loop increment lands on
      // the first inserted layer
      expr->items.erase(expr->items.begin() + pi);
      pi = pi - layers.size() - 1;
      continue;
    }

    if ((*ec)->getType())
      unify((*expr)[pi].expr->getType(), (*ec)->getType());
    (*expr)[pi].expr = *ec;
    inType = (*expr)[pi].expr->getType();
    // A stage whose type cannot be realized yet keeps the whole pipe not done
    if (!realize(inType))
      done = false;
    expr->inTypes.push_back(inType->shared_from_this());

    // Do not extract the generator in the last stage of a pipeline
    if (pi + 1 < expr->items.size())
      inType = getIterableType(inType);
  }
  // If a generator was extracted anywhere the pipe has type `NoneType`; otherwise it
  // has the type of the last stage
  unify(expr->getType(), (hasGenerator ? getStdLibType("NoneType") : inType));
  if (done)
    expr->setDone();
}

/// Transform index expressions.
/// @example
///   `foo[T]`   -> Instantiate(foo, [T]) if `foo` is a type
///   `tup[1]`   -> `tup.item1` if `tup` is tuple
///   `foo[idx]` -> `foo.__getitem__(idx)`
///   expr.itemN or a sub-tuple if index is static (see transformStaticTupleIndex()),
void TypecheckVisitor::visit(IndexExpr *expr) {
  if (match(expr, M(M(MOr("Literal", "Static")), M(MOr("int", "str", "bool"))))) {
    // Special case: static types.
    // The expression gets a fresh unbound type carrying the requested static kind
    auto typ = instantiateUnbound();
    typ->staticKind = getStaticGeneric(expr);
    unify(expr->getType(), typ);
    expr->setDone();
    return;
  } else if (match(expr->expr, M(MOr("Literal", "Static")))) {
    // `Literal[...]` / `Static[...]` with any other index is an error
    E(Error::BAD_STATIC_TYPE, expr->getIndex());
  }
  // Lowercase `tuple` is renamed to TYPE_TUPLE before transformation
  if (match(expr->expr, M("tuple")))
    cast(expr->expr)->setValue(TYPE_TUPLE);
  expr->expr = transform(expr->expr, true);

  // IndexExpr[i1, ..., iN] is internally represented as
  // IndexExpr[TupleExpr[i1, ..., iN]] for N > 1
  std::vector items;
  bool isTuple = false;
  if (auto t = cast(expr->getIndex())) {
    items = t->items;
    isTuple = true;
  } else {
    items.push_back(expr->getIndex());
  }
  // Keep an untransformed copy of the index for the static tuple access below
  auto origIndex = clone(expr->getIndex());
  for (auto &i : items) {
    if (cast(i) && isTypeExpr(expr->getExpr())) {
      // Special case: `A[[A, B], C]` -> `A[Tuple[A, B], C]` (e.g., in
      // `Function[...]`)
      i = N(N(TYPE_TUPLE), cast(i)->items);
    }
    i = transform(i, true);
  }
  // Indexing a type expression is an instantiation
  if (isTypeExpr(expr->getExpr())) {
    resultExpr = transform(N(expr->getExpr(), items));
    return;
  }

  // Store the transformed index back: a single non-tuple item stays as is, anything
  // else is rebuilt from `items`
  expr->index = (!isTuple && items.size() == 1) ? items[0] : N(items);
  auto cls = expr->getExpr()->getClassType();
  if (!cls) {
    // Wait until the type becomes known
    return;
  }

  // Case: static tuple access
  // Note: needs untransformed origIndex to parse statics nicely
  auto [isStaticTuple, tupleExpr] =
      transformStaticTupleIndex(cls, expr->getExpr(), origIndex);
  if (isStaticTuple) {
    // `tupleExpr` may be null, in which case no result is set on this visit
    if (tupleExpr)
      resultExpr = tupleExpr;
  } else {
    // Case: normal __getitem__
    resultExpr = transform(
        N(N(expr->getExpr(), "__getitem__"), expr->getIndex()));
  }
}

/// Transform an instantiation to canonical realized name.
/// @example
///   Instantiate(foo, [bar]) -> Id("foo[bar]")
void TypecheckVisitor::visit(InstantiateExpr *expr) {
  expr->expr = transformType(expr->getExpr());

  TypePtr typ = nullptr;
  size_t typeParamsSize = expr->size();
  if (extractType(expr->expr)->is(TYPE_TUPLE)) {
    if (!expr->empty()) {
      expr->items.front() = transform(expr->front());
      // A static int as the first parameter: rewrite as an `__NTuple__`
      // instantiation of that parameter and a `Tuple` of the remaining ones
      if (expr->front()->getType()->getStaticKind() == LiteralKind::Int) {
        auto et = N(
            N("Tuple"), std::vector(expr->items.begin() + 1, expr->items.end()));
        resultExpr = transform(N(N("__NTuple__"), std::vector{(*expr)[0], et}));
        return;
      }
    }
    // Otherwise instantiate a tuple type with one generic per parameter
    auto t = generateTuple(typeParamsSize);
    typ = instantiateType(t);
  } else {
    typ = instantiateType(expr->getExpr()->getSrcInfo(), extractType(expr->getExpr()));
  }
  seqassert(typ->getClass(), "unknown type: {}", *(expr->getExpr()));

  auto &generics = typ->getClass()->generics;
  bool isUnion = typ->getUnion() != nullptr;
  // Except for unions, the parameter count must match the number of generics
  if (!isUnion && typeParamsSize != generics.size())
    E(Error::GENERICS_MISMATCH, expr, getUserFacingName(typ->getClass()->name),
      generics.size(), typeParamsSize);

  if (isId(expr->getExpr(), TRAIT_CALLABLE)) {
    // Case: CallableTrait[...] trait instantiation

    // CallableTrait error checking.
    std::vector types;
    for (auto &typeParam : *expr) {
      typeParam = transformType(typeParam);
      // Static parameters are rejected here
      if (typeParam->getType()->getStaticKind())
        E(Error::INST_CALLABLE_STATIC, typeParam);
      types.push_back(extractType(typeParam)->shared_from_this());
    }
    auto ub = instantiateUnbound();
    // Set up the CallableTrait
    ub->getLink()->trait = std::make_shared(ctx->cache, types);
    unify(expr->getType(), instantiateTypeVar(ub.get()));
  } else if (isId(expr->getExpr(), TRAIT_TYPE)) {
    // Case: TypeTrait[...] trait instantiation
    (*expr)[0] = transformType((*expr)[0]);
    auto ub = instantiateUnbound();
    ub->getLink()->trait =
        std::make_shared(extractType(expr->front())->shared_from_this());
    unify(expr->getType(), ub);
  } else {
    for (size_t i = 0; i < expr->size(); i++) {
      (*expr)[i] = transformType((*expr)[i]);
      auto t = instantiateType((*expr)[i]->getSrcInfo(), extractType((*expr)[i]));
      // For unions, or when the static kind of the parameter differs from that of
      // the generic, the parameter must be a type expression
      if (isUnion || (*expr)[i]->getType()->getStaticKind() !=
                         generics[i].getType()->getStaticKind()) {
        if (cast((*expr)[i])) // `None` -> `NoneType`
          (*expr)[i] = transformType((*expr)[i]);
        if (!isTypeExpr((*expr)[i]))
          E(Error::EXPECTED_TYPE, (*expr)[i], "type");
      }
      if (isUnion) {
        // Union members are added to the union; a rejected addition is reported as
        // UNION_TOO_BIG
        if (!typ->getUnion()->addType(t.get()))
          E(error::Error::UNION_TOO_BIG, (*expr)[i],
            typ->getUnion()->pendingTypes.size());
      } else {
        unify(t.get(), generics[i].getType());
      }
    }
    if (isUnion) {
      typ->getUnion()->seal();
    }

    unify(expr->getType(), instantiateTypeVar(typ.get()));
    // If the type is realizable, use the realized name instead of instantiation
    // (e.g. use Id("Ptr[byte]") instead of Instantiate(Ptr, {byte}))
    if (auto rt = realize(expr->getType())) {
      auto t = extractType(rt);
      resultExpr = N(t->realizedName());
      resultExpr->setType(rt->shared_from_this());
      resultExpr->setDone();
    }
  }

  // Handle side effects
  // Every parameter with a side effect is first assigned to a temporary, and the
  // parameter is replaced by a reference to that temporary. The assignments are
  // placed before the final expression.
  if (!ctx->simpleTypes) {
    std::vector prepends;
    for (auto &t : *expr) {
      if (hasSideEffect(t)) {
        auto name = getTemporaryVar("call");
        auto front = transform(N(N(name), t, getParamType(t->getType())));
        auto swap = transformType(N(name));
        t = swap;
        prepends.emplace_back(front);
      }
    }
    if (!prepends.empty()) {
      resultExpr = transform(N(prepends, resultExpr ? resultExpr : expr));
    }
  }
}

/// Transform a slice expression.
/// @example
///   `start::step` -> `Slice(start, Optional.__new__(), step)`
void TypecheckVisitor::visit(SliceExpr *expr) {
  // Placeholder for a missing bound: a call to `__new__` of the optional type.
  // It is cloned for each bound that needs it.
  Expr *none = N(N(N(TYPE_OPTIONAL), "__new__"));
  resultExpr = transform(N(N(getStdLibType("Slice")->name),
                           expr->getStart() ? expr->getStart() : clone(none),
                           expr->getStop() ? expr->getStop() : clone(none),
                           expr->getStep() ? expr->getStep() : clone(none)));
}

/// Evaluate a static unary expression and return the resulting static expression.
/// If the expression cannot be evaluated yet, return nullptr.
/// Supported operators: (strings) not (ints) not, -, +
Expr *TypecheckVisitor::evaluateStaticUnary(const UnaryExpr *expr) {
  // Case: static strings
  if (expr->getExpr()->getType()->getStaticKind() == LiteralKind::String) {
    if (expr->getOp() == "!") {
      if (expr->getExpr()->getType()->canRealize()) {
        bool value = getStrLiteral(expr->getExpr()->getType()).empty();
        LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
        return transform(N(value));
      } else {
        // Cannot be evaluated yet: just set the type
        // NOTE(review): the evaluated path above produces a `bool` value and the
        // static-bool case below uses LiteralKind::Bool here; confirm that
        // LiteralKind::Int is intended.
        expr->getType()->getUnbound()->staticKind = LiteralKind::Int;
      }
    }
    return nullptr;
  }

  // Case: static bools
  if (expr->getExpr()->getType()->getStaticKind() == LiteralKind::Bool) {
    if (expr->getOp() == "!") {
      if (expr->getExpr()->getType()->canRealize()) {
        bool value = getBoolLiteral(expr->getExpr()->getType());
        LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
        return transform(N(!value));
      } else {
        // Cannot be evaluated yet: just set the type
        expr->getType()->getUnbound()->staticKind = LiteralKind::Bool;
      }
    }
    return nullptr;
  }

  // Case: static integers
  if (expr->getOp() == "-" || expr->getOp() == "+" || expr->getOp() == "!" ||
      expr->getOp() == "~") {
    if (expr->getExpr()->getType()->canRealize()) {
      int64_t value = getIntLiteral(expr->getExpr()->getType());
      if (expr->getOp() == "+")
        ;
      else if (expr->getOp() == "-")
        value = -value;
      else if (expr->getOp() == "~")
        value = ~value;
      else
        value = !static_cast(value);
      LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
      if (expr->getOp() == "!")
        return transform(N(value));
      else
        return transform(N(value));
    } else {
      // Cannot be evaluated yet: just set the type
      expr->getType()->getUnbound()->staticKind = expr->getOp() == "!" ?
LiteralKind::Bool : LiteralKind::Int; } } return nullptr; } /// Division and modulus implementations. std::pair divMod(const std::shared_ptr &ctx, int64_t a, int64_t b) { if (!b) { E(Error::STATIC_DIV_ZERO, ctx->getSrcInfo()); return {0, 0}; } else if (ctx->cache->pythonCompat) { // Use Python implementation. int64_t d = a / b; int64_t m = a - d * b; if (m && ((b ^ m) < 0)) { m += b; d -= 1; } return {d, m}; } else { // Use C implementation. return {a / b, a % b}; } } /// Evaluate a static binary expression and return the resulting static expression. /// If the expression cannot be evaluated yet, return nullptr. /// Supported operators: (strings) +, ==, != /// (ints) <, <=, >, >=, ==, !=, and, or, +, -, *, //, %, ^, |, & Expr *TypecheckVisitor::evaluateStaticBinary(const BinaryExpr *expr) { // Case: static strings if (expr->getRhs()->getType()->getStaticKind() == LiteralKind::String) { if (expr->getOp() == "+") { // `"a" + "b"` -> `"ab"` if (expr->getLhs()->getType()->getStrStatic() && expr->getRhs()->getType()->getStrStatic()) { auto value = getStrLiteral(expr->getLhs()->getType()) + getStrLiteral(expr->getRhs()->getType()); LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value); return transform(N(value)); } else { // Cannot be evaluated yet: just set the type expr->getType()->getUnbound()->staticKind = LiteralKind::String; } } else { // `"a" == "b"` -> `False` (also handles `!=`) if (expr->getLhs()->getType()->getStrStatic() && expr->getRhs()->getType()->getStrStatic()) { bool eq = getStrLiteral(expr->getLhs()->getType()) == getStrLiteral(expr->getRhs()->getType()); bool value = expr->getOp() == "==" ? 
eq : !eq; LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value); return transform(N(value)); } else { // Cannot be evaluated yet: just set the type expr->getType()->getUnbound()->staticKind = LiteralKind::Bool; } } return nullptr; } // Case: static integers if (expr->getLhs()->getType()->getStatic() && expr->getRhs()->getType()->getStatic()) { int64_t lvalue = expr->getLhs()->getType()->getIntStatic() ? getIntLiteral(expr->getLhs()->getType()) : getBoolLiteral(expr->getLhs()->getType()); int64_t rvalue = expr->getRhs()->getType()->getIntStatic() ? getIntLiteral(expr->getRhs()->getType()) : getBoolLiteral(expr->getRhs()->getType()); if (expr->getOp() == "<") lvalue = lvalue < rvalue; else if (expr->getOp() == "<=") lvalue = lvalue <= rvalue; else if (expr->getOp() == ">") lvalue = lvalue > rvalue; else if (expr->getOp() == ">=") lvalue = lvalue >= rvalue; else if (expr->getOp() == "==") lvalue = lvalue == rvalue; else if (expr->getOp() == "!=") lvalue = lvalue != rvalue; else if (expr->getOp() == "&&") lvalue = lvalue ? rvalue : lvalue; else if (expr->getOp() == "||") lvalue = lvalue ? 
lvalue : rvalue; else if (expr->getOp() == "+") lvalue = lvalue + rvalue; else if (expr->getOp() == "-") lvalue = lvalue - rvalue; else if (expr->getOp() == "*") lvalue = lvalue * rvalue; else if (expr->getOp() == "^") lvalue = lvalue ^ rvalue; else if (expr->getOp() == "&") lvalue = lvalue & rvalue; else if (expr->getOp() == "|") lvalue = lvalue | rvalue; else if (expr->getOp() == ">>") lvalue = lvalue >> rvalue; else if (expr->getOp() == "<<") lvalue = lvalue << rvalue; else if (expr->getOp() == "//") lvalue = divMod(ctx, lvalue, rvalue).first; else if (expr->getOp() == "%") lvalue = divMod(ctx, lvalue, rvalue).second; else seqassert(false, "unknown static operator {}", expr->getOp()); LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), lvalue); if (in(std::set{"==", "!=", "<", "<=", ">", ">="}, expr->getOp())) { return transform(N(lvalue)); } else if ((expr->getOp() == "&&" || expr->getOp() == "||") && expr->getLhs()->getType()->getBoolStatic() && expr->getLhs()->getType()->getBoolStatic()) { return transform(N(lvalue)); } else { return transform(N(lvalue)); } } else { // Cannot be evaluated yet: just set the type if (in(std::set{"==", "!=", "<", "<=", ">", ">="}, expr->getOp())) { expr->getType()->getUnbound()->staticKind = LiteralKind::Bool; } else if ((expr->getOp() == "&&" || expr->getOp() == "||") && expr->getLhs()->getType()->getBoolStatic() && expr->getLhs()->getType()->getBoolStatic()) { expr->getType()->getUnbound()->staticKind = LiteralKind::Bool; } else { expr->getType()->getUnbound()->staticKind = LiteralKind::Int; } } return nullptr; } /// Transform a simple binary expression. 
/// @example /// `a and b` -> `b if a else False` /// `a or b` -> `True if a else b` /// `a in b` -> `a.__contains__(b)` /// `a not in b` -> `not (a in b)` /// `a is not b` -> `not (a is b)` Expr *TypecheckVisitor::transformBinarySimple(const BinaryExpr *expr) { // Case: simple transformations if (expr->getOp() == "&&") { if (ctx->expectedType && ctx->expectedType->is("bool")) { return transform(N(expr->getLhs(), N(N(expr->getRhs(), "__bool__")), N(false))); } else { auto lt = realize(expr->getLhs()->getType()); auto rt = realize(expr->getRhs()->getType()); if (!lt || !rt) { return const_cast(expr); // delay } else { auto vn = getTemporaryVar("cond"); auto ve = N(N(vn), expr->getLhs()); if (lt->realizedName() == rt->realizedName()) { return N(ve, expr->getRhs(), N(vn)); } else { auto T = N( N("Union"), std::vector{N(lt->realizedName()), N(rt->realizedName())}); return N(ve, N(T, expr->getRhs()), N(clone(T), N(vn))); } } } } else if (expr->getOp() == "||") { if (ctx->expectedType && ctx->expectedType->is("bool")) { return transform(N(expr->getLhs(), N(true), N(N(expr->getRhs(), "__bool__")))); } else { auto lt = realize(expr->getLhs()->getType()); auto rt = realize(expr->getRhs()->getType()); if (!lt || !rt) { return const_cast(expr); // delay } else { auto vn = getTemporaryVar("cond"); auto ve = N(N(vn), expr->getLhs()); if (lt->realizedName() == rt->realizedName()) { return N(ve, N(vn), expr->getRhs()); } else { auto T = N( N(getMangledClass("std.internal.core", "Union")), std::vector{N(lt->realizedName()), N(rt->realizedName())}); return N(ve, N(T, N(vn)), N(clone(T), expr->getRhs())); } } } } else if (expr->getOp() == "not in") { return transform(N(N( N(N(expr->getRhs(), "__contains__"), expr->getLhs()), "__invert__"))); } else if (expr->getOp() == "in") { return transform( N(N(expr->getRhs(), "__contains__"), expr->getLhs())); } else if (expr->getOp() == "is") { if (cast(expr->getLhs()) && cast(expr->getRhs())) return transform(N(true)); else if 
(cast(expr->getLhs())) return transform(N(expr->getRhs(), "is", expr->getLhs())); } else if (expr->getOp() == "is not") { return transform( N("!", N(expr->getLhs(), "is", expr->getRhs()))); } return nullptr; } /// Transform a binary `is` expression by checking for type equality. Handle special `is /// None` cаses as well. See inside for details. Expr *TypecheckVisitor::transformBinaryIs(const BinaryExpr *expr) { seqassert(expr->op == "is", "not an is binary expression"); // Case: `is None` expressions if (cast(expr->getRhs())) { if (extractClassType(expr->getLhs())->is("NoneType")) return transform(N(true)); if (!extractClassType(expr->getLhs())->is(TYPE_OPTIONAL)) { // lhs is not optional: `return False` return transform(N(false)); } else { // Special case: Optional[Optional[... Optional[NoneType]]...] == NoneType auto g = extractClassType(expr->getLhs()); for (; extractClassGeneric(g)->is("Optional"); g = extractClassGeneric(g)->getClass()) ; if (!extractClassGeneric(g)->getClass()) { auto typ = instantiateUnbound(); typ->staticKind = LiteralKind::Bool; unify(expr->getType(), typ); return nullptr; } if (extractClassGeneric(g)->is("NoneType")) return transform(N(true)); // lhs is optional: `return lhs.__has__().__invert__()` if (expr->getType()->getUnbound() && expr->getType()->getStaticKind()) expr->getType()->getUnbound()->staticKind = LiteralKind::Runtime; return transform(N(N( N(N(expr->getLhs(), "__has__")), "__invert__"))); } } // Check the type equality (operand types and __raw__ pointers must match). 
auto lc = realize(expr->getLhs()->getType()); auto rc = realize(expr->getRhs()->getType()); if (!lc || !rc) { // Types not known: return early unify(expr->getType(), getStdLibType("bool")); return nullptr; } if (isTypeExpr(expr->getLhs()) && isTypeExpr(expr->getRhs())) return transform(N(lc->realizedName() == rc->realizedName())); if (!lc->getClass()->isRecord() && !rc->getClass()->isRecord()) { // Both reference types: `return lhs.__raw__() == rhs.__raw__()` return transform( N(N(N(expr->getLhs(), "__raw__")), "==", N(N(expr->getRhs(), "__raw__")))); } if (lc->is(TYPE_OPTIONAL)) { // lhs is optional: `return lhs.__is_optional__(rhs)` return transform( N(N(expr->getLhs(), "__is_optional__"), expr->getRhs())); } if (rc->is(TYPE_OPTIONAL)) { // rhs is optional: `return rhs.__is_optional__(lhs)` return transform( N(N(expr->getRhs(), "__is_optional__"), expr->getLhs())); } if (lc->realizedName() != rc->realizedName()) { // tuple names do not match: `return False` return transform(N(false)); } // Same tuple types: `return lhs == rhs` return transform(N(expr->getLhs(), "==", expr->getRhs())); } /// Return a binary magic opcode for the provided operator. std::pair TypecheckVisitor::getMagic(const std::string &op) const { // Table of supported binary operations and the corresponding magic methods. static auto magics = std::unordered_map{ {"+", "add"}, {"-", "sub"}, {"*", "mul"}, {"**", "pow"}, {"/", "truediv"}, {"//", "floordiv"}, {"@", "matmul"}, {"%", "mod"}, {"<", "lt"}, {"<=", "le"}, {">", "gt"}, {">=", "ge"}, {"==", "eq"}, {"!=", "ne"}, {"<<", "lshift"}, {">>", "rshift"}, {"&", "and"}, {"|", "or"}, {"^", "xor"}, }; auto mi = magics.find(op); if (mi == magics.end()) seqassert(false, "invalid binary operator '{}'", op); static auto rightMagics = std::unordered_map{ {"<", "gt"}, {"<=", "ge"}, {">", "lt"}, {">=", "le"}, {"==", "eq"}, {"!=", "ne"}, }; auto rm = in(rightMagics, op); return {mi->second, rm ? 
*rm : "r" + mi->second}; } /// Transform an in-place binary expression. /// @example /// `a op= b` -> `a.__iopmagic__(b)` /// @param isAtomic if set, use atomic magics if available. Expr *TypecheckVisitor::transformBinaryInplaceMagic(BinaryExpr *expr, bool isAtomic) { auto [magic, _] = getMagic(expr->getOp()); auto lt = expr->getLhs()->getClassType(); seqassert(lt, "lhs type not known"); FuncType *method = nullptr; // Atomic operations: check if `lhs.__atomic_op__(Ptr[lhs], rhs)` exists if (isAtomic) { auto ptr = instantiateType(getStdLibType("Ptr"), std::vector{lt}); if ((method = findBestMethod(lt, fmt::format("__atomic_{}__", magic), {ptr.get(), expr->getRhs()->getType()}))) { expr->lexpr = N(N("__ptr__"), expr->getLhs()); } } // In-place operations: check if `lhs.__iop__(lhs, rhs)` exists if (!method && expr->isInPlace()) { method = findBestMethod(lt, fmt::format("__i{}__", magic), std::vector{expr->getLhs(), expr->getRhs()}); } if (method) return transform( N(N(method->getFuncName()), expr->getLhs(), expr->getRhs())); return nullptr; } /// Transform a magic binary expression. 
/// @example /// `a op b` -> `a.__opmagic__(b)` Expr *TypecheckVisitor::transformBinaryMagic(const BinaryExpr *expr) { auto [magic, rightMagic] = getMagic(expr->getOp()); auto lt = expr->getLhs()->getType(); auto rt = expr->getRhs()->getType(); if (!lt->is("pyobj") && rt->is("pyobj")) { // Special case: `obj op pyobj` -> `rhs.__rmagic__(lhs)` on lhs // Assumes that pyobj implements all left and right magics auto l = getTemporaryVar("l"); auto r = getTemporaryVar("r"); return transform(N( N(N(l), expr->getLhs()), N(N(r), expr->getRhs()), N(N(N(r), fmt::format("__{}__", rightMagic)), N(l)))); } if (lt->getUnion()) { // Special case: `union op obj` -> `union.__magic__(rhs)` return transform(N( N(expr->getLhs(), fmt::format("__{}__", magic)), expr->getRhs())); } // Normal operations: check if `lhs.__magic__(lhs, rhs)` exists if (auto method = findBestMethod(lt->getClass(), fmt::format("__{}__", magic), std::vector{expr->getLhs(), expr->getRhs()})) { // Normal case: `__magic__(lhs, rhs)` return transform( N(N(method->getFuncName()), expr->getLhs(), expr->getRhs())); } // Right-side magics: check if `rhs.__rmagic__(rhs, lhs)` exists if (auto method = findBestMethod(rt->getClass(), fmt::format("__{}__", rightMagic), std::vector{expr->getRhs(), expr->getLhs()})) { auto l = getTemporaryVar("l"); auto r = getTemporaryVar("r"); return transform(N( N(N(l), expr->getLhs()), N(N(r), expr->getRhs()), N(N(method->getFuncName()), N(r), N(l)))); } return nullptr; } /// Given a tuple type and the expression `expr[index]`, check if an `index` is static /// (integer or slice). If so, statically extract the specified tuple item or a /// sub-tuple (if the index is a slice). /// Works only on normal tuples and partial functions. 
std::pair TypecheckVisitor::transformStaticTupleIndex(ClassType *tuple, Expr *expr, Expr *index) { bool isStaticString = expr->getType()->getStaticKind() == LiteralKind::String; if (isStaticString && !expr->getType()->canRealize()) { return {true, nullptr}; } else if (!isStaticString) { if (!tuple->isRecord()) return {false, nullptr}; if (!tuple->is(TYPE_TUPLE)) { if (tuple->is(TYPE_OPTIONAL)) { if (auto newTuple = extractClassGeneric(tuple)->getClass()) { return transformStaticTupleIndex( newTuple, transform(N(N(FN_OPTIONAL_UNWRAP), expr)), index); } else { return {true, nullptr}; } } return {false, nullptr}; } } // Extract the static integer value from expression auto getInt = [&](int64_t *o, Expr *e) { if (!e) return true; auto ore = transform(clone(e)); if (auto s = ore->getType()->getIntStatic()) { *o = s->value; return true; } return false; }; std::string str = isStaticString ? getStrLiteral(expr->getType()) : ""; auto sz = static_cast(isStaticString ? str.size() : getClassFields(tuple).size()); int64_t start = 0, stop = sz, step = 1, multiple = 0; if (getInt(&start, index)) { // Case: `tuple[int]` auto i = translateIndex(start, stop); if (i < 0 || i >= stop) E(Error::TUPLE_RANGE_BOUNDS, index, stop - 1, i); start = i; } else if (auto slice = cast(index)) { // Case: `tuple[int:int:int]` if (!getInt(&start, slice->getStart()) || !getInt(&stop, slice->getStop()) || !getInt(&step, slice->getStep())) return {false, nullptr}; // Adjust slice indices (Python slicing rules) if (slice->getStep() && !slice->getStart()) start = step > 0 ? 0 : (sz - 1); if (slice->getStep() && !slice->getStop()) stop = step > 0 ? sz : -(sz + 1); sliceAdjustIndices(sz, &start, &stop, step); multiple = 1; } else { return {false, nullptr}; } if (isStaticString) { if (!multiple) { return {true, transform(N(str.substr(start, 1)))}; } else { std::string newStr; for (auto i = start; (step > 0) ? 
(i < stop) : (i > stop); i += step) newStr += str[i]; return {true, transform(N(newStr))}; } } else { auto classFields = getClassFields(tuple); if (!multiple) { return {true, transform(N(expr, classFields[start].name))}; } else { // Generate a sub-tuple auto var = N(getTemporaryVar("tup")); auto ass = N(var, expr); std::vector te; for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step) { if (i < 0 || i >= sz) E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i); te.push_back(N(clone(var), classFields[i].name)); } generateTuple(te.size()); Expr *e = transform(N(std::vector{ass}, N(N(TYPE_TUPLE), te))); return {true, e}; } } } /// Follow Python indexing rules for static tuple indices. /// Taken from https://github.com/python/cpython/blob/main/Objects/sliceobject.c. int64_t TypecheckVisitor::translateIndex(int64_t idx, int64_t len, bool clamp) const { if (idx < 0) idx += len; if (clamp) { if (idx < 0) idx = 0; if (idx > len) idx = len; } else if (idx < 0 || idx >= len) { E(Error::TUPLE_RANGE_BOUNDS, getSrcInfo(), len - 1, idx); } return idx; } /// Follow Python slice indexing rules for static tuple indices. /// Taken from https://github.com/python/cpython/blob/main/Objects/sliceobject.c. /// Quote (sliceobject.c:269): "this is harder to get right than you might think" int64_t TypecheckVisitor::sliceAdjustIndices(int64_t length, int64_t *start, int64_t *stop, int64_t step) const { if (step == 0) E(Error::SLICE_STEP_ZERO, getSrcInfo()); if (*start < 0) { *start += length; if (*start < 0) { *start = (step < 0) ? -1 : 0; } } else if (*start >= length) { *start = (step < 0) ? length - 1 : length; } if (*stop < 0) { *stop += length; if (*stop < 0) { *stop = (step < 0) ? -1 : 0; } } else if (*stop >= length) { *stop = (step < 0) ? 
length - 1 : length; } if (step < 0) { if (*stop < *start) { return (*start - *stop - 1) / (-step) + 1; } } else { if (*start < *stop) { return (*stop - *start - 1) / step + 1; } } return 0; } } // namespace codon::ast ================================================ FILE: codon/parser/visitors/typecheck/special.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include #include #include #include #include #include #include #include "codon/cir/attribute.h" #include "codon/cir/types/types.h" #include "codon/parser/ast.h" #include "codon/parser/common.h" #include "codon/parser/visitors/scoping/scoping.h" #include "codon/parser/visitors/typecheck/typecheck.h" using namespace codon::error; namespace codon::ast { using namespace types; /// Generate ASTs for all internal functions that deal with vtable generation. /// Intended to be called once the typechecking is done. /// TODO: add JIT compatibility. void TypecheckVisitor::prepareVTables() { // def RTTIType._get_thunk_id(F, T): // return VID auto fn = getFunction(getMangledMethod("std.internal.core", "RTTIType", "_get_thunk_id")); auto oldAst = fn->ast; // Keep iterating as thunks can generate more thunks. 
std::unordered_set cache; for (bool added = true; added;) { added = false; for (const auto &[rn, real] : fn->realizations) { if (in(cache, rn)) continue; cache.insert(rn); added = true; fn->ast->suite = generateGetThunkIDAst(real->getType()); real->type->ast = fn->ast; LOG_REALIZE("[poly] {} : {}", real->type->debugString(2), fn->ast->toString(2)); realizeFunc(real->type.get(), true); fn->ast = oldAst; } } fn = getFunction( getMangledMethod("std.internal.core", "RTTIType", "_populate_vtables")); fn->ast->suite = generateClassPopulateVTablesAST(); auto typ = fn->realizations.begin()->second->getType(); typ->ast = fn->ast; LOG_REALIZE("[poly] {} : {}", typ->debugString(2), fn->ast->toString(2)); realizeFunc(typ, true); // def RTTIType._dist(B, D): // return Tuple[].__elemsize__ fn = getFunction(getMangledMethod("std.internal.core", "RTTIType", "_dist")); oldAst = fn->ast; for (const auto &real : fn->realizations | std::views::values) { fn->ast->suite = generateBaseDerivedDistAST(real->getType()); real->type->ast = fn->ast; LOG_REALIZE("[poly] {} : {}", real->type->debugString(2), fn->ast->toString(2)); realizeFunc(real->type.get(), true); } fn->ast = oldAst; } SuiteStmt *TypecheckVisitor::generateClassPopulateVTablesAST() { auto suite = N(); for (const auto &cls : ctx->cache->classes | std::views::values) { for (const auto &[r, real] : cls.realizations) { if (real->vtable.empty()) continue; // RTTIType._init_vtable(size, real.type) suite->addStmt(N( N(N(N("RTTIType"), "_init_vtable"), N(ctx->cache->thunkIds.size() + 2), N(r)))); LOG_REALIZE("[poly] {} -> {}", r, real->id); for (const auto &[key, fn] : real->vtable) { auto id = in(ctx->cache->thunkIds, key); seqassert(id, "key {} not found in thunkIds", key); std::vector ids; for (const auto &t : *fn) ids.push_back(N(t.getType()->realizedName())); // p[real.ID].__setitem__(f.ID, Function[](f).__raw__()) LOG_REALIZE("[poly] vtable[{}!!{}][{}] = {}", real->getType()->realizedName(), real->id, *id, fn->realizedName()); 
Expr *fnCall = N( N( N("Function"), std::vector{N(N(TYPE_TUPLE), ids), N(fn->getRetType()->realizedName())}), N(fn->realizedName())); suite->addStmt(N( N(N(N("RTTIType"), "_set_vtable_fn"), N(real->id), N(int64_t(*id)), N(N(fnCall, "__raw__")), N(r)))); } } } return suite; } SuiteStmt *TypecheckVisitor::generateBaseDerivedDistAST(FuncType *f) { auto baseTyp = extractFuncGeneric(f, 0)->getClass(); size_t baseTypFields = 0; for (auto &fld : getClassFields(baseTyp)) { if (fld.baseClass == baseTyp->name) { baseTypFields++; } } std::unordered_set alreadyDerived; for (auto &m : getClass(baseTyp)->mro) alreadyDerived.insert(m->name); auto derivedTyp = extractFuncGeneric(f, 1)->getClass(); auto fields = getClassFields(derivedTyp); auto types = std::vector{}; auto found = false; for (auto &fld : fields) { if (in(alreadyDerived, fld.baseClass)) { found = true; break; } else { auto ft = realize(instantiateType(fld.getType(), derivedTyp)); types.push_back(N(ft->realizedName())); } } seqassert(found || !baseTypFields, "cannot find distance between {} and {}", derivedTyp->name, baseTyp->name); Stmt *suite = N( N(N(N(TYPE_TUPLE), types), "__elemsize__")); return SuiteStmt::wrap(suite); } FunctionStmt *TypecheckVisitor::generateThunkAST(const FuncType *fp, ClassType *base, const ClassType *derived) { auto ct = instantiateType(extractClassType(derived->name), base->getClass()); std::vector args; for (const auto &a : *fp) args.push_back(a.getType()); args[0] = ct.get(); auto m = findBestMethod(ct->getClass(), getUnmangledName(fp->getFuncName()), args); if (!m) { // Print a nice error message std::vector a; for (auto &t : args) a.emplace_back(fmt::format("{}", t->prettyString())); std::string argsNice = fmt::format("({})", join(a, ", ")); E(Error::DOT_NO_ATTR_ARGS, getSrcInfo(), ct->prettyString(), getUnmangledName(fp->getFuncName()), argsNice); } std::vector ns; for (auto &a : args) ns.push_back(a->realizedName()); auto thunkName = fmt::format("_thunk.{}.{}.{}", base->name, 
fp->getFuncName(), join(ns, ".")); if (getFunction(getMangledFunc("", thunkName))) return nullptr; // Thunk contents: // def _thunk...(self, ): // return ( // RTTIType._to_derived(self, , ), // ) std::vector fnArgs; fnArgs.emplace_back("self", N(base->realizedName()), nullptr); for (size_t i = 1; i < args.size(); i++) fnArgs.emplace_back(getUnmangledName((*fp->ast)[i].getName()), N(args[i]->realizedName()), nullptr); std::vector callArgs; callArgs.emplace_back(N(N(N("RTTIType"), "_to_derived"), N("self"), N(base->realizedName()), N(derived->realizedName()))); for (size_t i = 1; i < args.size(); i++) callArgs.emplace_back(N(getUnmangledName((*fp->ast)[i].getName()))); std::vector debugCallArgs{N(base->name), N(fp->getFuncName()), N(join(ns, "."))}; debugCallArgs.insert(debugCallArgs.end(), callArgs.begin(), callArgs.end()); auto thunkAst = N( thunkName, nullptr, fnArgs, N( // For debugging N(N(N(getMangledMethod( "std.internal.core", "RTTIType", "_thunk_debug")), debugCallArgs)), N(N(N(m->ast->getName()), callArgs)))); thunkAst->setAttribute(Attr::Inline); return cast(transform(thunkAst)); } /// Generate thunks in all derived classes for a given virtual function (must be fully /// realizable) and the corresponding base class. /// @return unique thunk ID. 
SuiteStmt *TypecheckVisitor::generateGetThunkIDAst(types::FuncType *f) { auto fp = extractType(extractFuncGeneric(f))->getFunc(); auto cp = extractType(extractFuncGeneric(f, 1))->getClass(); seqassert(cp && cp->canRealize() && fp && fp->canRealize() && fp->getRetType()->canRealize(), "bad {}", f->debugString(2)); // TODO: ugly, ugly; surely needs refactoring // Function signature for storing thunks auto sig = [&](const types::FuncType *ft) -> std::string { std::vector gs; for (const auto &a : *ft) gs.emplace_back(a.getType()->realizedName()); gs.emplace_back("|"); for (auto &a : ft->funcGenerics) if (!a.name.empty()) gs.push_back(a.type->realizedName()); return fmt::format("{}:{}", getUnmangledName(ft->getFuncName()), join(gs, ",")); }; // Set up the base class information auto baseCls = cp->name; auto fnSig = sig(fp); auto key = std::make_pair(baseCls, fnSig); // Add or extract thunk ID auto baseRealization = getClassRealization(cp); seqassert(!in(baseRealization->vtable, key), "thunk {}.{} already added", baseCls, fnSig); if (!in(ctx->cache->thunkIds, key)) ctx->cache->thunkIds[key] = 1 + ctx->cache->thunkIds.size(); auto vid = ctx->cache->thunkIds[key]; baseRealization->vtable[key] = std::static_pointer_cast(fp->shared_from_this()); // Iterate through all derived classes and instantiate the corresponding thunk for (const auto &[clsName, cls] : ctx->cache->classes) { bool inMro = false; for (auto &m : cls.mro) if (m && m->is(baseCls)) { inMro = true; break; } if (inMro && clsName != baseCls) { for (const auto &real : cls.realizations | std::views::values) { if (auto thunkAst = generateThunkAST(fp, cp, real->getType())) { auto thunkFn = getFunction(thunkAst->name); auto ti = std::static_pointer_cast(instantiateType(thunkFn->getType())); auto tm = realizeFunc(ti.get(), true); seqassert(tm, "bad thunk {}", thunkFn->type->debugString(2)); seqassert(!in(real->vtable, key), "thunk {}.{} already added to {}", baseCls, fnSig, real->getType()->realizedName()); 
real->vtable[key] = std::static_pointer_cast(tm->shared_from_this()); LOG_REALIZE("[thunk]: {}->{}@{} == {}", baseCls, real->getType()->realizedName(), key, vid); } } } } return N(N(N(vid))); } SuiteStmt *TypecheckVisitor::generateFunctionCallInternalAST(FuncType *type) { // Special case: Function.__call_internal__ /// TODO: move to IR one day std::vector items; items.push_back(nullptr); std::vector ll; std::vector lla; seqassert(extractFuncArgType(type, 1)->is(TYPE_TUPLE), "bad function base: {}", extractFuncArgType(type, 1)->debugString(2)); auto as = extractFuncArgType(type, 1)->getClass()->generics.size(); auto [_, ag] = (*type->ast)[1].getNameWithStars(); for (int i = 0; i < as; i++) { ll.push_back(fmt::format("%{} = extractvalue {{}} %args, {}", i, i)); items.push_back(N(N(ag))); } items.push_back(N(N("TR"))); for (int i = 0; i < as; i++) { items.push_back(N(N(N(ag), N(i)))); lla.push_back(fmt::format("{{}} %{}", i)); } items.push_back(N(N("TR"))); ll.push_back(fmt::format("%{} = call {{}} %self({})", as, combine2(lla))); ll.push_back(fmt::format("ret {{}} %{}", as)); items[0] = N(N(combine2(ll, "\n"))); return N(items); } SuiteStmt *TypecheckVisitor::generateUnionNewAST(const FuncType *type) { auto unionType = type->funcParent->getUnion(); seqassert(unionType, "expected union, got {}", *(type->funcParent)); Stmt *suite = N(N(N(N("Union"), "_new"), N(type->ast->begin()->name), N(unionType->realizedName()))); return SuiteStmt::wrap(suite); } SuiteStmt *TypecheckVisitor::generateUnionTagAST(FuncType *type) { // return Union._get_data(union, T0) auto tag = getIntLiteral(extractFuncGeneric(type)); auto unionType = extractFuncArgType(type)->getUnion(); auto unionTypes = unionType->getRealizationTypes(); if (tag < 0 || tag >= unionTypes.size()) E(Error::CUSTOM, getSrcInfo(), "bad union tag"); auto selfVar = type->ast->begin()->name; auto suite = N(N(N( N(getMangledMethod("std.internal.core", "Union", "_get_data")), N(selfVar), N(unionTypes[tag]->realizedName())))); 
return suite; // tail of the preceding definition (it starts above this chunk)
}

/// Generate the body used for `NamedTuple._namedkeys` (see generateSpecialAst).
/// The static integer generic of @p type selects an entry of
/// `ctx->cache->generatedTupleNames`; one node is created per stored name.
/// Raises a custom error if the index is outside that table.
SuiteStmt *TypecheckVisitor::generateNamedKeysAST(FuncType *type) {
  auto n = getIntLiteral(extractFuncGeneric(type));
  if (n < 0 || n >= ctx->cache->generatedTupleNames.size())
    E(Error::CUSTOM, getSrcInfo(), "bad namedkeys index");
  std::vector s;
  for (auto &k : ctx->cache->generatedTupleNames[n])
    s.push_back(N(k));
  auto suite = N(N(N(s)));
  return suite;
}

/// Generate the body used for `__magic__.mul` (see generateSpecialAst).
/// The static integer generic of @p type is the repeat count (negative values are
/// clamped to 0). Returns nullptr unless the first argument type is a TYPE_TUPLE
/// class; otherwise the element accesses `arg[0] .. arg[k-1]` are emitted `n` times,
/// where `k` is the number of generics of the tuple type.
SuiteStmt *TypecheckVisitor::generateTupleMulAST(FuncType *type) {
  auto n = std::max(static_cast(0), getIntLiteral(extractFuncGeneric(type)));
  auto t = extractFuncArgType(type)->getClass();
  if (!t || !t->is(TYPE_TUPLE))
    return nullptr;
  std::vector exprs;
  for (size_t i = 0; i < n; i++)
    for (size_t j = 0; j < t->generics.size(); j++)
      exprs.push_back(
          N(N(type->ast->front().getName()), N(j)));
  auto suite = N(N(N(exprs)));
  return suite;
}

/// Generate ASTs for dynamically generated functions.
/// Dispatches on the name of the function AST. Returns nullptr when the function is
/// not one of the special cases handled here.
SuiteStmt *TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
  // Clone the generic AST that is to be realized
  auto ast = type->ast;
  if (ast->hasAttribute(Attr::AutoGenerated) && endswith(ast->name, ".__iter__:0") &&
      isHeterogenous(extractFuncArgType(type, 0))) {
    // Special case: do not realize auto-generated heterogenous __iter__
    E(Error::EXPECTED_TYPE, getSrcInfo(), "iterable");
  } else if (ast->hasAttribute(Attr::AutoGenerated) &&
             endswith(ast->name, ".__getitem__:0") &&
             isHeterogenous(extractFuncArgType(type, 0))) {
    // Special case: do not realize auto-generated heterogenous __getitem__
    E(Error::EXPECTED_TYPE, getSrcInfo(), "iterable");
  } else if (startswith(ast->name, "Function.__call_internal__")) {
    return generateFunctionCallInternalAST(type);
  } else if (startswith(ast->name, "Union.__new__")) {
    return generateUnionNewAST(type);
  } else if (startswith(ast->name,
                        getMangledMethod("std.internal.core", "Union", "_tag"))) {
    return generateUnionTagAST(type);
  } else if (startswith(ast->name, getMangledMethod("std.internal.core", "NamedTuple",
                                                    "_namedkeys"))) {
    return generateNamedKeysAST(type);
  } else if (startswith(ast->name,
                        getMangledMethod("std.internal.core", "__magic__", "mul"))) {
    return generateTupleMulAST(type);
  }
  return nullptr;
}

/// Transform named tuples.
/// @example
///   `namedtuple("NT", ["a", ("b", int)])` -> ```@tuple
///                                               class NT[T1]:
///                                                 a: T1
///                                                 b: int```
/// The class name comes from the static string generic of the callee. Each plain
/// string item becomes a field with a fresh generic `T<k>`; each two-element item
/// whose first element is a string becomes a field with the given type. Any other
/// item (or a call without exactly one argument) raises CALL_NAMEDTUPLE.
Expr *TypecheckVisitor::transformNamedTuple(CallExpr *expr) {
  // Ensure that namedtuple call is valid
  auto name = getStrLiteral(extractFuncGeneric(expr->getExpr()->getType()));
  if (expr->size() != 1)
    E(Error::CALL_NAMEDTUPLE, expr);

  // Construct the class statement
  std::vector generics, params;
  auto orig = cast(expr->front().getExpr()->getOrigExpr());
  size_t ti = 1; // 1-based counter used to name the generated generics T1, T2, ...
  for (auto *i : *orig) {
    if (auto s = cast(i)) {
      generics.emplace_back(fmt::format("T{}", ti), N(TYPE_TYPE), nullptr, true);
      params.emplace_back(s->getValue(), N(fmt::format("T{}", ti++)), nullptr);
      continue;
    }
    auto t = cast(i);
    if (t && t->size() == 2 && cast((*t)[0])) {
      params.emplace_back(cast((*t)[0])->getValue(), transformType((*t)[1]), nullptr);
      continue;
    }
    E(Error::CALL_NAMEDTUPLE, i);
  }
  // Generic parameters are appended after the regular fields.
  for (auto &g : generics)
    params.push_back(g);
  auto cls = N(
      N(name, params, nullptr, std::vector{N("tuple")}));
  if (auto err = ast::ScopingVisitor::apply(ctx->cache, cls))
    throw exc::ParserException(std::move(err));
  // The class definition is emitted before the current statement.
  prependStmts->push_back(transform(cls));
  return transformType(N(name));
}

/// Transform partial calls (Python syntax).
/// @example
///   `partial(foo, 1, a=2)` -> `foo(1, a=2, ...)`
/// Raises CALL_PARTIAL if the call has no arguments.
Expr *TypecheckVisitor::transformFunctoolsPartial(CallExpr *expr) {
  if (expr->empty())
    E(Error::CALL_PARTIAL, getSrcInfo());
  // Everything after the first argument is forwarded, followed by a partial ellipsis.
  std::vector args(expr->items.begin() + 1, expr->items.end());
  args.emplace_back("", N(EllipsisExpr::PARTIAL));
  return transform(N(expr->begin()->value, args));
}

/// Typecheck superf method. This method provides the access to the previous matching
/// overload.
/// @example
///   ```class cls:
///        def foo(): print('foo 1')
///        def foo():
///          superf()  # access the previous foo
///          print('foo 2')
///      cls.foo()```
///   prints "foo 1" followed by "foo 2"
/// Raises CALL_SUPERF if there is no earlier overload or none of them matches the
/// supplied arguments.
Expr *TypecheckVisitor::transformSuperF(CallExpr *expr) {
  auto func = ctx->getBase()->type->getFunc();

  // Find list of matching superf methods
  std::vector supers;
  if (!isDispatch(func)) {
    if (auto a = func->ast->getAttribute(Attr::ParentClass)) {
      auto c = getClass(a->value);
      if (auto m = in(c->methods, getUnmangledName(func->getFuncName()))) {
        // Collect non-dispatch overloads that come before the current function.
        for (auto &overload : getOverloads(*m)) {
          if (isDispatch(overload))
            continue;
          if (overload == func->getFuncName())
            break;
          supers.emplace_back(getFunction(overload)->getType());
        }
      }
      std::ranges::reverse(supers);
    }
  }
  if (supers.empty())
    E(Error::CALL_SUPERF, expr);

  seqassert(expr->size() == 1 && cast(expr->begin()->getExpr()), "bad superf call");
  std::vector newArgs;
  for (const auto &a : *cast(expr->begin()->getExpr()))
    newArgs.emplace_back(a.getExpr());
  auto m = findMatchingMethods(
      func->funcParent ? func->funcParent->getClass() : nullptr, supers, newArgs);
  if (m.empty())
    E(Error::CALL_SUPERF, expr);
  // Call the first matching method with the original arguments.
  auto c = transform(N(N(m[0]->getFuncName()), newArgs));
  return c;
}

/// Typecheck and transform super method. Replace it with the current self object cast
/// to the first inherited type.
/// Raises CALL_SUPER_PARENT when not inside a method with at least one argument, or
/// when no parent class is available.
/// TODO: only an empty super() is currently supported.
Expr *TypecheckVisitor::transformSuper() {
  if (!ctx->getBase()->type)
    E(Error::CALL_SUPER_PARENT, getSrcInfo());
  auto funcTyp = ctx->getBase()->type->getFunc();
  if (!funcTyp || !funcTyp->ast->hasAttribute(Attr::Method))
    E(Error::CALL_SUPER_PARENT, getSrcInfo());
  if (funcTyp->empty())
    E(Error::CALL_SUPER_PARENT, getSrcInfo());

  ClassType *typ = extractFuncArgType(funcTyp)->getClass();
  auto cls = getClass(typ);
  auto cands = cls->staticParentClasses;
  if (cands.empty()) {
    // Dynamic inheritance: use MRO
    // TODO: maybe super() should be split into two separate functions...
    const auto &vCands = cls->mro;
    if (vCands.size() < 2)
      E(Error::CALL_SUPER_PARENT, getSrcInfo());

    // The second MRO entry is used as the parent type.
    auto superTyp = instantiateType(vCands[1].get(), typ);
    auto self = N(funcTyp->ast->begin()->name);
    self->setType(typ->shared_from_this());

    auto typExpr = N(superTyp->getClass()->name);
    typExpr->setType(instantiateTypeVar(superTyp->getClass()));
    return transform(N(N(N("Super"), "_super"), self, typExpr, N(1)));
  }

  const auto &name = cands.front(); // the first inherited type
  auto superTyp = instantiateType(extractClassType(name), typ);
  if (typ->isRecord()) {
    // Case: tuple types. Return `tuple(obj.args...)`
    std::vector members;
    for (auto &field : getClassFields(superTyp->getClass()))
      members.push_back(
          N(N(funcTyp->ast->begin()->getName()), field.name));
    Expr *e = transform(N(members));
    auto ft = getClassFieldTypes(superTyp->getClass());
    for (size_t i = 0; i < ft.size(); i++)
      unify(ft[i].get(), extractClassGeneric(e->getType(), i)); // see super_tuple test
    e->setType(superTyp->shared_from_this());
    return e;
  } else {
    // Case: reference types. Return `Super._super(self, T)`
    auto self = N(funcTyp->ast->begin()->name);
    self->setType(typ->shared_from_this());
    return castToSuperClass(self, superTyp->getClass());
  }
}

/// Typecheck __ptr__ method. This method creates a pointer to an object. Ensure that
/// the argument is a variable binding.
/// Returns nullptr in all non-error cases; the result type is set through unification
/// with `Ptr[T]`. Raises CALL_PTR_VAR if the argument is not rooted at a variable.
Expr *TypecheckVisitor::transformPtr(CallExpr *expr) {
  expr->begin()->value = transform(expr->begin()->getExpr());

  auto head = getHeadExpr(expr->begin()->getExpr());
  std::vector members;
  // Walk the member-access chain towards its root; `last` is true only for the
  // outermost expression, every other link must be a record type.
  for (bool last = true;; last = false) {
    auto t = extractClassType(head);
    if (!t)
      return nullptr; // type not known yet
    if (!last && !t->isRecord())
      E(Error::CALL_PTR_VAR, expr->begin()->getExpr());

    if (auto id = cast(head)) {
      auto val = id ? ctx->find(id->getValue(), getTime()) : nullptr;
      if (!val || !val->isVar())
        E(Error::CALL_PTR_VAR, expr->begin()->getExpr());
      break;
    } else if (auto dot = cast(head)) {
      head = dot->getExpr();
    } else {
      E(Error::CALL_PTR_VAR, expr->begin()->getExpr());
      break;
    }
  }

  unify(expr->getType(), instantiateType(getStdLibType("Ptr"),
                                         {expr->begin()->getExpr()->getType()}));
  if (expr->begin()->getExpr()->isDone())
    expr->setDone();
  return nullptr;
}

/// Typecheck __array__ method. This method creates a stack-allocated array via alloca.
/// Unifies the call type with `Array[T]`, where `T` is the generic of the callee's
/// parent type, and marks the call as done once that type can be realized.
Expr *TypecheckVisitor::transformArray(CallExpr *expr) {
  auto arrTyp = expr->expr->getType()->getFunc();
  unify(expr->getType(),
        instantiateType(getStdLibType("Array"),
                        {extractClassGeneric(arrTyp->getParentType())}));
  if (realize(expr->getType()))
    expr->setDone();
  return nullptr;
}

/// Transform isinstance method to a static boolean expression.
/// Special cases:
///   `isinstance(obj, ByVal)` is True if `type(obj)` is a tuple type
///   `isinstance(obj, ByRef)` is True if `type(obj)` is a reference type
/// Returns nullptr while the type of the first argument cannot be realized yet.
Expr *TypecheckVisitor::transformIsInstance(CallExpr *expr) {
  if (auto u = expr->getType()->getUnbound())
    u->staticKind = LiteralKind::Bool;
  expr->begin()->value = transform(expr->begin()->getExpr());
  auto typ = expr->begin()->getExpr()->getClassType();
  if (!typ || !typ->canRealize())
    return nullptr;

  expr->begin()->value = transform(expr->begin()->getExpr()); // again to realize it

  typ = extractClassType(typ);
  auto &typExpr = (*expr)[1].value;
  if (cast(typExpr)) {
    // Handle `isinstance(obj, (type1, type2, ...))`
    if (typExpr->getOrigExpr() && cast(typExpr->getOrigExpr())) {
      // Fold the alternatives into `False || isinstance(obj, t1) || ...`
      Expr *result = transform(N(false));
      for (auto *i : *cast(typExpr->getOrigExpr())) {
        result = transform(N(
            result, "||", N(N("isinstance"), expr->begin()->getExpr(), i)));
      }
      return result;
    }
  }

  auto tei = cast(typExpr);
  if (tei && tei->getValue() == "type") {
    return transform(N(isTypeExpr(expr->begin()->value)));
  } else if (tei && tei->getValue() == "type[Tuple]") {
    return transform(N(typ->is(TYPE_TUPLE)));
  } else if (tei && tei->getValue() == "type[ByVal]") {
    return transform(N(typ->isRecord()));
  } else if (tei && tei->getValue() == "type[ByRef]") {
    return transform(N(!typ->isRecord()));
  } else if (tei && tei->getValue() == "type[Union]") {
    return transform(N(typ->getUnion() != nullptr));
  } else if (!extractType(typExpr)->getUnion() && typ->getUnion()) {
    // Object is a union, target is not: compare the runtime tag with the index of
    // the first union member that unifies with the target.
    auto unionTypes = typ->getUnion()->getRealizationTypes();
    int tag = -1;
    for (size_t ui = 0; ui < unionTypes.size(); ui++) {
      if (extractType(typExpr)->unify(unionTypes[ui], nullptr) >= 0) {
        tag = static_cast(ui);
        break;
      }
    }
    if (tag == -1)
      return transform(N(false));
    return transform(
        N(N(N(N("Union"), "_get_tag"), expr->begin()->getExpr()),
          "==", N(tag)));
  } else if (typExpr->getType()->is("pyobj")) {
    if (typ->is("pyobj")) {
      return transform(
          N(N(getMangledFunc("std.internal.python", "_isinstance")),
            expr->begin()->getExpr(), (*expr)[1].getExpr()));
    } else {
      return transform(N(false));
    }
  }

  typExpr = transformType(typExpr);
  auto targetType = extractType(typExpr);
  // Check static super types (i.e., statically inherited) as well
  for (auto &tx : getStaticSuperTypes(typ->getClass())) {
    types::Type::Unification us;
    auto s = tx->unify(targetType, &us);
    us.undo(); // only probing: roll the unification back
    if (s >= 0)
      return transform(N(true));
  }
  // Check RTTI super types
  for (auto &tx : getRTTISuperTypes(typ->getClass())) {
    types::Type::Unification us;
    auto s = tx->unify(targetType, &us);
    us.undo();
    if (s >= 0)
      return transform(N(true));
  }
  // Check runtime RTTI info if needed
  for (auto &tx : getRTTISuperTypes(targetType->getClass())) {
    types::Type::Unification us;
    auto s = tx->unify(typ, &us);
    us.undo();
    if (s >= 0) {
      // check RTTI match
      return transform(N(
          N(getMangledMethod("std.internal.core", "RTTIType", "_isinstance")),
          expr->begin()->getExpr(), (*expr)[1].getExpr()));
    }
  }
  return transform(N(false));
}

/// Transform staticlen method to a static integer expression. This method supports only
/// static strings and tuple types.
Expr *TypecheckVisitor::transformStaticLen(CallExpr *expr) {
  if (auto u = expr->getType()->getUnbound())
    u->staticKind = LiteralKind::Int;

  expr->begin()->value = transform(expr->begin()->getExpr());
  auto typ = extractType(expr->begin()->getExpr());

  if (auto ss = typ->getStrStatic()) {
    // Case: staticlen on static strings
    return transform(N(ss->value.size()));
  }
  if (!typ->getClass())
    return nullptr; // type not known yet
  if (typ->getUnion()) {
    // Case: unions. The length is the number of realized member types.
    if (realize(typ))
      return transform(N(typ->getUnion()->getRealizationTypes().size()));
    return nullptr;
  }
  if (!typ->getClass()->isRecord())
    E(Error::EXPECTED_TUPLE, expr->begin()->getExpr());
  // Case: record types. The length is the number of class fields.
  return transform(N(getClassFields(typ->getClass()).size()));
}

/// Transform hasattr method to a static boolean expression.
/// This method also supports additional argument types that are used to check
/// for a matching overload (not available in Python).
/// Returns nullptr while a required type is still unknown or cannot be realized.
Expr *TypecheckVisitor::transformHasAttr(CallExpr *expr) {
  if (auto u = expr->getType()->getUnbound())
    u->staticKind = LiteralKind::Bool;

  auto typ = extractClassType((*expr)[0].getExpr());
  if (!typ)
    return nullptr;

  auto member = getStrLiteral(extractFuncGeneric(expr->getExpr()->getType()));
  // Argument list used for overload matching; the object type comes first.
  std::vector> args{{"", typ}};

  if (auto tup = cast((*expr)[1].getExpr())) {
    for (auto &a : *tup) {
      a.value = transform(a.getExpr());
      if (!a.getExpr()->getClassType())
        return nullptr;
      auto t = extractType(a);
      // `TypeWrap[T]` arguments stand for `T` itself.
      args.emplace_back("", t->is("TypeWrap") ? extractClassGeneric(t) : t);
    }
  }
  for (auto &[n, ne] : extractNamedTuple((*expr)[2].getExpr())) {
    ne = transform(ne);
    auto t = extractType(ne);
    args.emplace_back(n, t->is("TypeWrap") ? extractClassGeneric(t) : t);
  }

  if (typ->getUnion()) {
    // Unions: build `(isinstance(obj, T1) && hasattr(T1, member)) || ...`
    Expr *cond = nullptr;
    auto unionTypes = typ->getUnion()->getRealizationTypes();
    for (auto &unionType : unionTypes) {
      auto tu = realize(unionType);
      if (!tu)
        return nullptr;
      auto te = N(tu->getClass()->realizedName());
      auto e = N(
          N(N("isinstance"), (*expr)[0].getExpr(), te), "&&",
          N(N("hasattr"), te, N(member)));
      cond = !cond ? e : N(cond, "||", e);
    }
    if (!cond)
      return transform(N(false));
    return transform(cond);
  } else if (typ->is("NamedTuple")) {
    // Named tuples: look the member up in the generated key table.
    if (!typ->canRealize())
      return nullptr;
    auto id = getIntLiteral(typ);
    seqassert(id >= 0 && id < ctx->cache->generatedTupleNames.size(), "bad id: {}", id);
    const auto &names = ctx->cache->generatedTupleNames[id];
    return transform(N(in(names, member)));
  }

  bool exists = !findMethod(typ->getClass(), member).empty() ||
                findMember(typ->getClass(), member);
  if (exists && args.size() > 1) {
    // Extra argument types were supplied: require a matching overload as well.
    exists &= findBestMethod(typ, member, args) != nullptr;
  }
  return transform(N(exists));
}

/// Transform getattr method to a DotExpr.
/// For `NamedTuple` objects the attribute is resolved to an index into `args` via
/// the generated key table; an unknown key raises DOT_NO_ATTR.
Expr *TypecheckVisitor::transformGetAttr(CallExpr *expr) {
  auto name = getStrLiteral(extractFuncGeneric(expr->expr->getType()));

  // special handling for NamedTuple
  if (expr->begin()->getExpr()->getType() &&
      expr->begin()->getExpr()->getType()->is("NamedTuple")) {
    auto val = expr->begin()->getExpr()->getClassType();
    auto id = getIntLiteral(val);
    seqassert(id >= 0 && id < ctx->cache->generatedTupleNames.size(), "bad id: {}", id);
    // Bind by reference (as transformHasAttr does): the table is only scanned here,
    // and it is not touched again once transform() is invoked below.
    const auto &names = ctx->cache->generatedTupleNames[id];
    for (size_t i = 0; i < names.size(); i++)
      if (names[i] == name) {
        return transform(
            N(N(expr->begin()->getExpr(), "args"), N(i)));
      }
    E(Error::DOT_NO_ATTR, expr, val->prettyString(), name);
  }
  return transform(N(expr->begin()->getExpr(), name));
}

/// Transform setattr method to a AssignMemberStmt.
/// The attribute name is the static string generic of the callee; the first call
/// argument is the object and the second is the value.
Expr *TypecheckVisitor::transformSetAttr(CallExpr *expr) {
  auto attr = getStrLiteral(extractFuncGeneric(expr->expr->getType()));
  return transform(
      N(N((*expr)[0].getExpr(), attr, (*expr)[1].getExpr()),
        N(N("NoneType"))));
}

/// Raise a compiler error.
/// The message is the static string generic of the callee.
Expr *TypecheckVisitor::transformCompileError(CallExpr *expr) const {
  auto msg = getStrLiteral(extractFuncGeneric(expr->expr->getType()));
  E(Error::CUSTOM, expr, msg.c_str());
  return nullptr;
}

/// Convert a class to a tuple.
Expr *TypecheckVisitor::transformTupleFn(CallExpr *expr) {
  for (auto &a : *expr)
    a.value = transform(a.getExpr());
  auto cls = extractClassType(expr->begin()->getExpr()->getType());
  if (!cls)
    return nullptr; // argument type not known yet

  // tuple(ClassType) is a tuple type that corresponds to a class
  if (isTypeExpr(expr->begin()->getExpr())) {
    if (!realize(cls))
      return expr; // cannot be realized yet: hand the call back unchanged

    std::vector items;
    auto ft = getClassFieldTypes(cls);
    for (size_t i = 0; i < ft.size(); i++) {
      auto rt = realize(ft[i].get());
      seqassert(rt, "cannot realize '{}' in {}", getClass(cls)->fields[i].name,
                cls->debugString(2));
      items.push_back(N(rt->realizedName()));
    }
    auto e = transform(N(N(TYPE_TUPLE), items));
    return e;
  }

  // Value case: bind the argument to a temporary and collect `tmp.<field>` for
  // every class field.
  std::vector args;
  std::string var = getTemporaryVar("tup");
  for (auto &field : getClassFields(cls))
    args.emplace_back(N(N(var), field.name));

  return transform(N(N(N(var), expr->begin()->getExpr()),
                     N(args)));
}

/// Transform type function to a type IdExpr identifier.
/// Returns nullptr until the resulting type can be realized; otherwise returns a
/// finished identifier carrying the realized type name.
Expr *TypecheckVisitor::transformTypeFn(CallExpr *expr) {
  expr->begin()->value = transform(expr->begin()->getExpr());
  unify(expr->getType(), instantiateTypeVar(expr->begin()->getExpr()->getType()));
  if (!realize(expr->getType()))
    return nullptr;

  auto e = N(expr->getType()->realizedName());
  e->setType(expr->getType()->shared_from_this());
  e->setDone();
  return e;
}

/// Transform static.realized function to a fully realized type identifier.
Expr *TypecheckVisitor::transformRealizedFn(CallExpr *expr) {
  auto fn = extractType((*expr)[0].getExpr()->getType())->shared_from_this();
  auto pt = (*expr)[0].getExpr()->getType()->getPartial();
  if (!fn->getFunc() && pt && pt->isPartialEmpty()) {
    // An empty partial: work on a fresh instantiation of the underlying function.
    auto pft = pt->getPartialFunc()->generalize(0);
    fn = instantiateType(pft.get());
  }
  if (!fn->getFunc())
    E(Error::CALL_REALIZED_FN, (*expr)[0].getExpr());

  auto argt = (*expr)[1].getExpr()->getType()->getClass();
  if (!argt)
    return nullptr; // argument tuple type not known yet
  seqassert(argt->name == TYPE_TUPLE, "not a tuple");

  // Unify the leading function arguments with the supplied types;
  // `TypeWrap[T]` entries stand for `T` itself.
  for (size_t i = 0; i < std::min(argt->size(), fn->getFunc()->size()); i++) {
    auto at = (*argt)[i]->is("TypeWrap") ? extractClassGeneric((*argt)[i]) : (*argt)[i];
    unify((*fn->getFunc())[i], at);
  }

  if (auto f = realize(fn.get())) {
    auto e = N(f->getFunc()->realizedName());
    e->setType(f->shared_from_this());
    e->setDone();
    return e;
  }
  return nullptr;
}

/// Handle __static_print__: print debugging information about each argument to
/// stderr (source location, debug string of the type, realized type name, and a
/// " [static]" marker for static types). Arguments without a type are printed as "-".
/// Always returns nullptr.
Expr *TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) const {
  for (auto &a : *cast(expr->begin()->getExpr())) {
    // Fetch the type once. It may be null, so every use below is guarded;
    // previously the static check dereferenced it unconditionally.
    auto t = a.getExpr()->getType();
    fmt::print(stderr, "[print] {}: {} ({}){}\n", getSrcInfo(),
               t ? t->debugString(2) : "-", t ? t->realizedName() : "-",
               (t && t->getStatic()) ? " [static]" : "");
  }
  return nullptr;
}

/// Transform static.has_rtti to a static boolean that indicates RTTI status of a type.
Expr *TypecheckVisitor::transformHasRttiFn(const CallExpr *expr) {
  if (auto u = expr->getType()->getUnbound())
    u->staticKind = LiteralKind::Bool;
  auto t = extractFuncGeneric(expr->getExpr()->getType())->getClass();
  if (!t)
    return nullptr; // type not known yet
  return transform(N(getClass(t)->hasRTTI()));
}

// Transform internal.static calls

/// fn_can_call: static boolean telling whether the callable given as the first
/// argument accepts the positional types (second argument) and keyword types (third
/// argument). Non-callables produce a compilation warning and evaluate to false.
Expr *TypecheckVisitor::transformStaticFnCanCall(CallExpr *expr) {
  if (auto u = expr->getType()->getUnbound())
    u->staticKind = LiteralKind::Bool;
  auto typ = extractClassType((*expr)[0].getExpr());
  if (!typ)
    return nullptr;

  auto inargs = unpackTupleTypes((*expr)[1].getExpr());
  auto kwargs = unpackTupleTypes((*expr)[2].getExpr());
  seqassert(inargs && kwargs, "bad call to fn_can_call");

  // Only the argument types matter, so placeholder expressions carry them.
  std::vector callArgs;
  for (auto &[v, t] : *inargs) {
    callArgs.emplace_back(v, N()); // dummy expression
    callArgs.back().getExpr()->setType(t->shared_from_this());
  }
  for (auto &[v, t] : *kwargs) {
    callArgs.emplace_back(v, N()); // dummy expression
    callArgs.back().getExpr()->setType(t->shared_from_this());
  }
  if (auto fn = typ->getFunc()) {
    return transform(N(canCall(fn, callArgs) >= 0));
  } else if (auto pt = typ->getPartial()) {
    return transform(N(canCall(pt->getPartialFunc(), callArgs, pt) >= 0));
  } else {
    compilationWarning("cannot use fn_can_call on non-functions", getSrcInfo().file,
                       getSrcInfo().line, getSrcInfo().col);
    return transform(N(false));
  }
}

/// Static boolean: true if the index given by the static integer generic is a valid
/// argument position of the function and that argument's type can be realized.
/// Raises a custom error if the first argument is not a function.
Expr *TypecheckVisitor::transformStaticFnArgHasType(CallExpr *expr) {
  if (auto u = expr->getType()->getUnbound())
    u->staticKind = LiteralKind::Bool;
  auto fn = extractFunction(expr->begin()->getExpr()->getType());
  if (!fn)
    E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'",
      expr->begin()->getExpr()->getType()->prettyString());
  auto idx = extractFuncGeneric(expr->getExpr()->getType())->getIntStatic();
  seqassert(idx, "expected a static integer");
  return transform(N(idx->value >= 0 && idx->value < fn->size() &&
                     (*fn)[idx->value]->canRealize()));
}

/// Return the realized type name of the function argument selected by the static
/// integer generic. Raises a custom error if the first argument is not a function,
/// the index is out of range, or the argument type cannot be realized.
Expr *TypecheckVisitor::transformStaticFnArgGetType(CallExpr *expr) {
  auto fn = extractFunction(expr->begin()->getExpr()->getType());
  if (!fn)
    E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'",
      expr->begin()->getExpr()->getType()->prettyString());
  auto idx = extractFuncGeneric(expr->getExpr()->getType())->getIntStatic();
  seqassert(idx, "expected a static integer");
  if (idx->value < 0 || idx->value >= fn->size() || !(*fn)[idx->value]->canRealize())
    E(Error::CUSTOM, getSrcInfo(), "argument does not have type");
  return transform(N((*fn)[idx->value]->realizedName()));
}

/// Collect the unmangled names of all parameters in the function's AST.
/// Raises a custom error if the first argument is not a function.
Expr *TypecheckVisitor::transformStaticFnArgs(CallExpr *expr) {
  auto fn = extractFunction(expr->begin()->value->getType());
  if (!fn)
    E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'",
      expr->begin()->getExpr()->getType()->prettyString());
  std::vector v;
  v.reserve(fn->ast->size());
  for (const auto &a : *fn->ast) {
    // Only the name part is used; the first element of the pair is ignored.
    auto [_, n] = a.getNameWithStars();
    n = getUnmangledName(n);
    v.push_back(N(n));
  }
  return transform(N(v));
}

/// Static boolean: true if the parameter selected by the static integer generic has
/// a default value. Raises a custom error if the first argument is not a function
/// or the index is out of bounds.
Expr *TypecheckVisitor::transformStaticFnHasDefault(CallExpr *expr) {
  if (auto u = expr->getType()->getUnbound())
    u->staticKind = LiteralKind::Bool;
  auto fn = extractFunction(expr->begin()->getExpr()->getType());
  if (!fn)
    E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'",
      expr->begin()->getExpr()->getType()->prettyString());
  auto idx = extractFuncGeneric(expr->getExpr()->getType())->getIntStatic();
  seqassert(idx, "expected a static integer");
  if (idx->value < 0 || idx->value >= fn->ast->size())
    E(Error::CUSTOM, getSrcInfo(), "argument out of bounds");
  return transform(N((*fn->ast)[idx->value].getDefault() != nullptr));
}

/// Return the transformed default value of the parameter selected by the static
/// integer generic. Raises a custom error if the first argument is not a function
/// or the index is out of bounds.
Expr *TypecheckVisitor::transformStaticFnGetDefault(CallExpr *expr) {
  auto fn = extractFunction(expr->begin()->getExpr()->getType());
  if (!fn)
    E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'",
      expr->begin()->getExpr()->getType()->prettyString());
  auto idx = extractFuncGeneric(expr->getExpr()->getType())->getIntStatic();
  seqassert(idx, "expected a static integer");
  if (idx->value < 0 || idx->value >= fn->ast->size())
    E(Error::CUSTOM, getSrcInfo(), "argument out of bounds");
  return transform((*fn->ast)[idx->value].getDefault());
}

/// Build a temporary call to the given function from the supplied arguments,
/// typecheck it, and return the arguments of the finished call.
/// Returns nullptr while the function type is unknown or the temporary call is not
/// done yet. Raises a custom error if the first argument is not a function.
Expr *TypecheckVisitor::transformStaticFnWrapCallArgs(CallExpr *expr) {
  auto typ = expr->begin()->getExpr()->getClassType();
  if (!typ)
    return nullptr;

  auto fn = extractFunction(expr->begin()->getExpr()->getType());
  if (!fn)
    E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'",
      expr->begin()->getExpr()->getType()->prettyString());

  std::vector callArgs;
  // NOTE(review): both branches below inspect argument [1]; the cast targets are
  // not visible in this copy, so confirm which argument the second one should read.
  if (auto tup = cast((*expr)[1].getExpr()->getOrigExpr())) {
    for (auto *a : *tup) {
      callArgs.emplace_back("", a);
    }
  }
  if (auto kw = cast((*expr)[1].getExpr()->getOrigExpr())) {
    auto kwCls = getClass(expr->getClassType());
    seqassert(kwCls, "cannot find {}", expr->getClassType()->name);
    for (size_t i = 0; i < kw->size(); i++) {
      callArgs.emplace_back(kwCls->fields[i].name, (*kw)[i].getExpr());
    }
  }
  auto tempCall = transform(N(N(fn->getFuncName()), callArgs));
  if (!tempCall->isDone())
    return nullptr;

  std::vector tupArgs;
  for (auto &a : *cast(tempCall))
    tupArgs.push_back(a.getExpr());
  return transform(N(tupArgs));
}

/// Expand an object into one entry per class field: `(name, value)`, or
/// `(index, name, value)` when the static boolean generic of the callee is true.
/// Returns nullptr while the generic or the object's class type is unknown.
Expr *TypecheckVisitor::transformStaticVars(CallExpr *expr) {
  auto t = extractFuncGeneric(expr->getExpr()->getType());
  if (!t || !t->getClass())
    return nullptr;
  auto withIdx = getBoolLiteral(t);

  types::ClassType *typ = nullptr;
  std::vector tupleItems;
  auto e = transform(expr->begin()->getExpr());
  if (!((typ = e->getClassType())))
    return nullptr;

  size_t idx = 0;
  for (auto &f : getClassFields(typ)) {
    auto k = N(f.name);
    auto v = N(expr->begin()->value, f.name);
    if (withIdx) {
      auto i = N(idx);
      tupleItems.push_back(N(std::vector{i, k, v}));
    } else {
      tupleItems.push_back(N(std::vector{k, v}));
    }
    idx++;
  }
  return transform(N(tupleItems));
}

/// Return the realized type name of the field selected by the second (static
/// integer) generic within the class given by the first generic.
/// Returns nullptr while the class cannot be realized; raises a custom error if the
/// index is out of range.
Expr *TypecheckVisitor::transformStaticTupleType(const CallExpr *expr) {
  auto funcTyp = expr->getExpr()->getType()->getFunc();
  auto t = extractFuncGeneric(funcTyp)->getClass();
  if (!t || !realize(t))
    return nullptr;
  auto n = getIntLiteral(extractFuncGeneric(funcTyp, 1));
  types::TypePtr typ = nullptr; // not used below
  auto f = getClassFields(t);
  if (n < 0 || n >= f.size())
    E(Error::CUSTOM, getSrcInfo(), "invalid index");
  auto rt = realize(instantiateType(f[n].getType(), t));
  return transform(N(rt->realizedName()));
}

/// Transform a static format call into a static string expression.
/// Both the format and the argument are static string generics of the callee; every
/// occurrence of "%%" in the format is rewritten to "{}" and filled with the
/// argument via fmt::vformat.
Expr *TypecheckVisitor::transformStaticFormat(CallExpr *expr) {
  if (auto u = expr->getType()->getUnbound())
    u->staticKind = LiteralKind::String;
  auto funcTyp = expr->getExpr()->getType()->getFunc();
  auto fmt = getStrLiteral(extractFuncGeneric(funcTyp, 0));
  auto arg = getStrLiteral(extractFuncGeneric(funcTyp, 1));

  size_t start = 0;
  fmt::dynamic_format_arg_store store;
  while ((start = fmt.find("%%", start)) != std::string::npos) {
    fmt.replace(start, 2, "{}");
    store.push_back(arg); // one copy of the argument per placeholder
    start += 2;
  }
  return transform(N(fmt::vformat(fmt, store)));
}

/// Transform a static integer (the first generic of the callee) into a static string
/// expression holding its decimal representation.
Expr *TypecheckVisitor::transformStaticIntToStr(CallExpr *expr) { if (auto u = expr->getType()->getUnbound()) u->staticKind = LiteralKind::String; auto funcTyp = expr->getExpr()->getType()->getFunc(); auto val = getIntLiteral(extractFuncGeneric(funcTyp, 0)); return transform(N(std::to_string(val))); } std::vector TypecheckVisitor::populateStaticTupleLoop(Expr *iter, const std::vector &vars) { std::vector block; auto stmt = N(N(vars[0]), nullptr, nullptr); auto call = cast(cast(iter)->front()); if (vars.size() != 1) E(Error::CUSTOM, getSrcInfo(), "expected one item"); for (auto &a : *call) { stmt->rhs = transform(clean_clone(a.value)); if (auto st = stmt->rhs->getType()->getStatic()) { stmt->type = N(N("Literal"), N(st->name)); } else { stmt->type = nullptr; } block.push_back(clone(stmt)); } return block; } std::vector TypecheckVisitor::populateSimpleStaticRangeLoop(Expr *iter, const std::vector &vars) { if (vars.size() != 1) E(Error::CUSTOM, getSrcInfo(), "expected one item"); auto fn = cast(iter) ? cast(cast(iter)->getExpr()) : nullptr; auto stmt = N(N(vars[0]), nullptr, nullptr); std::vector block; auto ed = getIntLiteral(extractFuncGeneric(fn->getType())); if (ed > MAX_STATIC_ITER) E(Error::STATIC_RANGE_BOUNDS, fn, MAX_STATIC_ITER, ed); for (int64_t i = 0; i < ed; i++) { stmt->rhs = N(i); stmt->type = N(N("Literal"), N("int")); block.push_back(clone(stmt)); } return block; } std::vector TypecheckVisitor::populateStaticRangeLoop(Expr *iter, const std::vector &vars) { if (vars.size() != 1) E(Error::CUSTOM, getSrcInfo(), "expected one item"); auto fn = cast(iter) ? 
cast(cast(iter)->getExpr()) : nullptr; auto stmt = N(N(vars[0]), nullptr, nullptr); std::vector block; auto st = getIntLiteral(extractFuncGeneric(fn->getType(), 0)); auto ed = getIntLiteral(extractFuncGeneric(fn->getType(), 1)); auto step = getIntLiteral(extractFuncGeneric(fn->getType(), 2)); if (std::abs(st - ed) / std::abs(step) > MAX_STATIC_ITER) E(Error::STATIC_RANGE_BOUNDS, fn, MAX_STATIC_ITER, std::abs(st - ed) / std::abs(step)); for (int64_t i = st; step > 0 ? i < ed : i > ed; i += step) { stmt->rhs = N(i); stmt->type = N(N("Literal"), N("int")); block.push_back(clone(stmt)); } return block; } std::vector TypecheckVisitor::populateStaticFnOverloadsLoop(Expr *iter, const std::vector &vars) { if (vars.size() != 1) E(Error::CUSTOM, getSrcInfo(), "expected one item"); auto fn = cast(iter) ? cast(cast(iter)->getExpr()) : nullptr; auto stmt = N(N(vars[0]), nullptr, nullptr); std::vector block; auto typ = extractFuncGeneric(fn->getType(), 0)->getClass(); seqassert(extractFuncGeneric(fn->getType(), 1)->getStrStatic(), "bad static string"); auto name = getStrLiteral(extractFuncGeneric(fn->getType(), 1)); std::vector overloads; if (typ->is("NoneType")) { if (auto func = ctx->cache->typeCtx->find(name)) { auto root = getRootName(func->getType()->getFunc()); overloads = getOverloads(root); } } else { if (auto n = in(getClass(typ)->methods, name)) overloads = getOverloads(*n); } if (!overloads.empty()) { for (int mti = static_cast(overloads.size()) - 1; mti >= 0; mti--) { auto &method = overloads[mti]; auto cfn = getFunction(method); if (isDispatch(method) || !cfn->type) continue; if (isHeterogenous(typ)) { if (cfn->ast->hasAttribute(Attr::AutoGenerated) && (endswith(cfn->ast->name, ".__iter__:0") || endswith(cfn->ast->name, ".__getitem__:0"))) { // ignore __getitem__ and other heterogenuous methods continue; } } stmt->rhs = N(method); block.push_back(clone(stmt)); } } return block; } std::vector TypecheckVisitor::populateStaticEnumerateLoop(Expr *iter, const std::vector 
&vars) { if (vars.size() != 2) E(Error::CUSTOM, getSrcInfo(), "expected two items"); auto fn = cast(iter) ? cast(cast(iter)->getExpr()) : nullptr; std::vector block; auto typ = extractFuncArgType(fn->getType())->getClass(); if (typ && typ->isRecord()) { for (size_t i = 0; i < getClassFields(typ).size(); i++) { auto b = N(std::vector{ N(N(vars[0]), N(i), N(N("Literal"), N("int"))), N( N(vars[1]), N(clone((*cast(iter))[0].value), N(i)))}); block.push_back(b); } } else { E(Error::CUSTOM, getSrcInfo(), "static.enumerate needs a tuple"); } return block; } std::vector TypecheckVisitor::populateStaticVarsLoop(Expr *iter, const std::vector &vars) { auto fn = cast(iter) ? cast(cast(iter)->getExpr()) : nullptr; bool withIdx = getBoolLiteral(extractFuncGeneric(fn->getType())); if (!withIdx && vars.size() != 2) E(Error::CUSTOM, getSrcInfo(), "expected two items"); else if (withIdx && vars.size() != 3) E(Error::CUSTOM, getSrcInfo(), "expected three items"); std::vector block; auto typ = extractFuncArgType(fn->getType())->getClass(); size_t idx = 0; if (typ->is("TypeWrap")) { // type passed! for (auto &f : getClass(extractClassGeneric(typ))->classVars) { std::vector stmts; if (withIdx) { stmts.push_back( N(N(vars[0]), N(idx), N(N("Literal"), N("int")))); } stmts.push_back( N(N(vars[withIdx]), N(f.first), N(N("Literal"), N("str")))); stmts.push_back(N(N(vars[withIdx + 1]), N(f.second))); auto b = N(stmts); block.push_back(b); idx++; } } else { for (auto &f : getClassFields(typ)) { std::vector stmts; if (withIdx) { stmts.push_back( N(N(vars[0]), N(idx), N(N("Literal"), N("int")))); } stmts.push_back( N(N(vars[withIdx]), N(f.name), N(N("Literal"), N("str")))); stmts.push_back( N(N(vars[withIdx + 1]), N(clone((*cast(iter))[0].value), f.name))); auto b = N(stmts); block.push_back(b); idx++; } } return block; } std::vector TypecheckVisitor::populateStaticVarTypesLoop(Expr *iter, const std::vector &vars) { auto fn = cast(iter) ? 
cast(cast(iter)->getExpr()) : nullptr; auto typ = realize(extractFuncGeneric(fn->getType(), 0)->getClass()); bool withIdx = getBoolLiteral(extractFuncGeneric(fn->getType(), 1)); if (!withIdx && vars.size() != 1) E(Error::CUSTOM, getSrcInfo(), "expected one item"); else if (withIdx && vars.size() != 2) E(Error::CUSTOM, getSrcInfo(), "expected two items"); seqassert(typ, "vars_types expects a realizable type, got '{}' instead", *(extractFuncGeneric(fn->getType(), 0))); std::vector block; if (auto utyp = typ->getUnion()) { for (size_t i = 0; i < utyp->getRealizationTypes().size(); i++) { std::vector stmts; if (withIdx) { stmts.push_back( N(N(vars[0]), N(i), N(N("Literal"), N("int")))); } stmts.push_back( N(N(vars[1]), N(utyp->getRealizationTypes()[i]->realizedName()))); auto b = N(stmts); block.push_back(b); } } else { size_t idx = 0; for (auto &f : getClassFields(typ->getClass())) { auto ta = realize(instantiateType(f.type.get(), typ->getClass())); seqassert(ta, "cannot realize '{}'", f.type->debugString(2)); std::vector stmts; if (withIdx) { stmts.push_back( N(N(vars[0]), N(idx), N(N("Literal"), N("int")))); } stmts.push_back( N(N(vars[withIdx]), N(ta->realizedName()))); auto b = N(stmts); block.push_back(b); idx++; } } return block; } std::vector TypecheckVisitor::populateStaticHeterogenousTupleLoop( Expr *iter, const std::vector &vars) { std::vector block; std::string tupleVar; Stmt *preamble = nullptr; if (!cast(iter)) { tupleVar = getTemporaryVar("tuple"); preamble = N(N(tupleVar), iter); } else { tupleVar = cast(iter)->getValue(); } for (size_t i = 0; i < iter->getClassType()->generics.size(); i++) { auto s = N(); if (vars.size() > 1) { for (size_t j = 0; j < vars.size(); j++) { s->addStmt( N(N(vars[j]), N(N(N(tupleVar), N(i)), N(j)))); } } else { s->addStmt(N(N(vars[0]), N(N(tupleVar), N(i)))); } block.push_back(s); } block.push_back(preamble); return block; } } // namespace codon::ast ================================================ FILE: 
codon/parser/visitors/typecheck/typecheck.cpp ================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "typecheck.h"

// NOTE(review): the four header names below are missing in this copy of the file.
#include
#include
#include
#include

#include "codon/cir/pyextension.h"
#include "codon/cir/util/irtools.h"
#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/match.h"
#include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/scoping/scoping.h"
#include "codon/parser/visitors/typecheck/ctx.h"

using namespace codon::error;

namespace codon::ast {

using namespace types;
using namespace matcher;

/// Typecheck an AST node. Load standard library if needed.
/// @param cache        Pointer to the shared cache ( @c Cache )
/// @param node         Statement to be processed; it is placed after the generated
///                     define / `__name__` / `__file__` assignments
/// @param file         Filename to be used for error reporting
/// @param defines      User-defined static values (typically passed as
///                     `codon run -DX=Y`). Each value is passed as a string with an
///                     optional `str:`, `bool:` or `int:` prefix (no prefix = int).
/// @param earlyDefines Static values forwarded to the standard library loader
/// @param barebones    Use the bare-bones standard library for faster testing
/// @return A suite holding the preamble, the collected scope declarations and the
///         typechecked code.
/// @throws exc::ParserException on scoping errors, typecheck errors, or if the cache
///         holds errors at the end.
Stmt *TypecheckVisitor::apply(
    Cache *cache, Stmt *node, const std::string &file,
    const std::unordered_map &defines,
    const std::unordered_map &earlyDefines, bool barebones) {
  auto preamble = cache->N();
  seqassertn(cache->module, "cache's module is not set");

  // Load standard library if it has not been loaded
  if (!in(cache->imports, STDLIB_IMPORT))
    loadStdLibrary(cache, preamble, earlyDefines, barebones);

  // Set up the context and the cache
  auto ctx = std::make_shared(cache, file);
  cache->imports[file].update(MAIN_IMPORT, file, ctx);
  cache->imports[MAIN_IMPORT] = cache->imports[file];
  ctx->setFilename(file);
  ctx->moduleName = {ImportFile::PACKAGE, file, MODULE_MAIN};

  // Prepare the code
  auto tv = TypecheckVisitor(ctx, preamble);
  auto *suite = tv.N();
  auto &stmts = suite->items;

  // Load compile-time defines (e.g., codon run -DFOO=1 ...)
  for (auto &d : defines) {
    if (startswith(d.second, "str:")) {
      stmts.push_back(tv.N(
          tv.N(d.first), tv.N(d.second.substr(4)),
          tv.N(tv.N("Literal"), tv.N("str"))));
    } else if (startswith(d.second, "bool:")) {
      // Only the exact spelling "bool:True" yields true.
      stmts.push_back(tv.N(
          tv.N(d.first), tv.N(d.second == "bool:True" ? true : false),
          tv.N(tv.N("Literal"), tv.N("bool"))));
    } else {
      // Default: integer, with or without the "int:" prefix.
      stmts.push_back(tv.N(
          tv.N(d.first),
          tv.N(startswith(d.second, "int:") ? d.second.substr(4) : d.second),
          tv.N(tv.N("Literal"), tv.N("int"))));
    }
  }
  // Set up __name__
  stmts.push_back(
      tv.N(tv.N("__name__"), tv.N(MODULE_MAIN)));
  stmts.push_back(tv.N(tv.N("__file__"), tv.N(file)));
  stmts.push_back(node);

  if (auto err = ScopingVisitor::apply(cache, suite, &ctx->globalShadows))
    throw exc::ParserException(std::move(err));
  auto n = tv.inferTypes(suite, true);
  if (!n) {
    auto errors = tv.findTypecheckErrors(suite);
    throw exc::ParserException(errors);
  }

  // Assemble the result: preamble, scope-level declarations, then the code itself.
  suite = tv.N();
  suite->items.push_back(preamble);

  // Add dominated assignment declarations
  suite->items.insert(suite->items.end(), ctx->scope.back().stmts.begin(),
                      ctx->scope.back().stmts.end());
  suite->items.push_back(n);

  if (cast(n))
    tv.prepareVTables();

  if (!ctx->cache->errors.empty())
    throw exc::ParserException(ctx->cache->errors);

  return suite;
}

/// Load the standard library into @p cache and append its typechecked statements to
/// @p preamble. Three steps: the core definitions (`from internal.core import *`),
/// the early compile-time defines, and the standard library init file itself.
/// @param cache        Pointer to the shared cache
/// @param preamble     Suite that receives the loaded statements
/// @param earlyDefines Static values defined before the standard library is loaded;
///                     same `str:` / `bool:` / `int:` encoding as in apply()
/// @param barebones    Load `__init_test__.codon` instead of `__init__.codon`
/// @throws exc::ParserException on parsing or scoping errors; raises
///         COMPILER_NO_STDLIB if the standard library cannot be located.
void TypecheckVisitor::loadStdLibrary(
    Cache *cache, SuiteStmt *preamble,
    const std::unordered_map &earlyDefines, bool barebones) {
  // Load the internal.__init__
  auto stdlib = std::make_shared(cache, STDLIB_IMPORT);
  auto stdlibPath = getImportFile(cache, STDLIB_INTERNAL_MODULE, "", true);
  const std::string initFile = "__init__.codon";
  if (!stdlibPath || !endswith(stdlibPath->path, initFile))
    E(Error::COMPILER_NO_STDLIB);

  /// Use __init_test__ for faster testing (e.g., #%% name,barebones)
  /// TODO: get rid of it one day...
  if (barebones) {
    // Swap the trailing "__init__.codon" for "__init_test__.codon".
    stdlibPath->path =
        stdlibPath->path.substr(0, stdlibPath->path.size() - initFile.size()) +
        "__init_test__.codon";
  }
  stdlib->setFilename(stdlibPath->path);
  cache->imports[stdlibPath->path].update(STDLIB_IMPORT, stdlibPath->path, stdlib);
  cache->imports[STDLIB_IMPORT] = cache->imports[stdlibPath->path];

  // Load the standard library
  stdlib->isStdlibLoading = true;
  stdlib->moduleName = {ImportFile::STDLIB, stdlibPath->path, "__init__"};
  stdlib->setFilename(stdlibPath->path);

  // 1. Core definitions
  cache->classes[VAR_CLASS_TOPLEVEL] = Cache::Class();
  auto coreOrErr =
      parseCode(stdlib->cache, stdlibPath->path, "from internal.core import *");
  if (!coreOrErr)
    throw exc::ParserException(coreOrErr.takeError());
  auto *core = *coreOrErr;
  if (auto err = ScopingVisitor::apply(stdlib->cache, core))
    throw exc::ParserException(std::move(err));
  auto tv = TypecheckVisitor(stdlib, preamble);
  core = tv.inferTypes(core, true);
  preamble->addStmt(core);

  // 2. Load early compile-time defines (for standard library)
  for (auto &d : earlyDefines) {
    AssignStmt *s = nullptr;
    if (startswith(d.second, "str:")) {
      s = tv.N(
          tv.N(d.first), tv.N(d.second.substr(4)),
          tv.N(tv.N("Literal"), tv.N("str")));
    } else if (startswith(d.second, "bool:")) {
      s = tv.N(
          tv.N(d.first), tv.N(d.second == "bool:True" ? true : false),
          tv.N(tv.N("Literal"), tv.N("bool")));
    } else {
      s = tv.N(
          tv.N(d.first),
          tv.N(startswith(d.second, "int:") ? d.second.substr(4) : d.second),
          tv.N(tv.N("Literal"), tv.N("int")));
    }
    // Unlike apply(), each define is transformed right away.
    auto def = tv.transform(s);
    preamble->addStmt(def);
  }

  // 3. Load stdlib
  auto stdOrErr = parseFile(stdlib->cache, stdlibPath->path);
  if (!stdOrErr)
    throw exc::ParserException(stdOrErr.takeError());
  auto std = *stdOrErr;
  if (auto err = ScopingVisitor::apply(stdlib->cache, std, &stdlib->globalShadows))
    throw exc::ParserException(std::move(err));
  tv = TypecheckVisitor(stdlib, preamble);
  std = tv.inferTypes(std, true);
  preamble->addStmt(std);
  stdlib->isStdlibLoading = false;
}

/// Simplify an AST node.
/// Assumes that the standard library is loaded.
/// The context's filename is switched to `file` for the duration of type inference
/// and restored afterwards.
/// @param ctx Existing typechecking context to reuse.
/// @param node AST node to type-check.
/// @param file Filename to be used while processing the node.
/// @return A suite made of a fresh preamble followed by the type-checked node.
/// @throw @c ParserException if type inference returns null or if the cache has
///        accumulated errors.
Stmt *TypecheckVisitor::apply(const std::shared_ptr &ctx, Stmt *node,
                              const std::string &file) {
  auto oldFilename = ctx->getFilename();
  ctx->setFilename(file);
  auto preamble = ctx->cache->N();
  auto tv = TypecheckVisitor(ctx, preamble);
  auto n = tv.inferTypes(node, true);
  ctx->setFilename(oldFilename);
  if (!n) {
    auto errors = tv.findTypecheckErrors(node);
    throw exc::ParserException(errors);
  }
  if (!ctx->cache->errors.empty())
    throw exc::ParserException(ctx->cache->errors);

  auto suite = ctx->cache->N(preamble);
  suite->addStmt(n);
  return suite;
}

/**************************************************************************************/

/// Construct a visitor over the given context.
/// @param ctx Typechecking context (moved into the visitor).
/// @param pre Preamble suite; a new one is created when null.
/// @param stmts Shared list of statements to prepend; a new list is created when null.
TypecheckVisitor::TypecheckVisitor(std::shared_ptr ctx, SuiteStmt *pre,
                                   const std::shared_ptr> &stmts)
    : ctx(std::move(ctx)), resultExpr(nullptr), resultStmt(nullptr) {
  // `ctx` the parameter has been moved from; use the member from here on.
  preamble = pre ? pre : this->ctx->cache->N();
  prependStmts = stmts ? stmts : std::make_shared>();
}

/**************************************************************************************/

/// Transform an expression node, allowing type expressions.
Expr *TypecheckVisitor::transform(Expr *expr) { return transform(expr, true); }

/// Transform an expression node.
/// @param expr Expression to transform; null is returned unchanged.
/// @param allowTypes When false, a resulting type expression raises
///                   @c Error::UNEXPECTED_TYPE .
/// @return The transformed expression (possibly a replacement node).
Expr *TypecheckVisitor::transform(Expr *expr, bool allowTypes) {
  if (!expr)
    return nullptr;

  // Every expression gets at least an unbound type before visiting.
  if (!expr->getType())
    expr->setType(instantiateUnbound());

  if (!expr->isDone()) {
    // A child visitor shares this visitor's preamble and prepend list.
    TypecheckVisitor v(ctx, preamble, prependStmts);
    v.setSrcInfo(expr->getSrcInfo());
    ctx->pushNode(expr);
    expr->accept(v);
    ctx->popNode();
    if (v.resultExpr) {
      // Carry over attributes the replacement does not already have.
      for (auto it = expr->attributes_begin(); it != expr->attributes_end(); ++it) {
        const auto *attr = expr->getAttribute(*it);
        if (!v.resultExpr->hasAttribute(*it))
          v.resultExpr->setAttribute(*it, attr->clone());
      }
      // Keep pointing at the earliest original expression.
      v.resultExpr->setOrigExpr(expr->getOrigExpr() ? expr->getOrigExpr() : expr);
      expr = v.resultExpr;
      if (!expr->getType())
        expr->setType(instantiateUnbound());
    }
    if (!allowTypes && expr && isTypeExpr(expr))
      E(Error::UNEXPECTED_TYPE, expr, "type");
    if (expr->isDone())
      ctx->changedNodes++;
  }
  if (expr) {
    // Try to realize the type unless the expression opts out.
    if (!expr->hasAttribute(Attr::ExprDoNotRealize)) {
      if (auto p = realize(expr->getType())) {
        unify(expr->getType(), p);
      }
    }
    LOG_TYPECHECK("[expr] {}: {}{}", getSrcInfo(), *(expr),
                  expr->isDone() ? "[done]" : "");
  }
  return expr;
}

/// Transform a type expression node.
/// Special case: replace `None` with `NoneType`
/// @param simple Value that `ctx->simpleTypes` takes while the expression is being
///               transformed; the previous value is restored afterwards.
/// @throw @c ParserException if a node is not a type (use @c transform instead).
Expr *TypecheckVisitor::transformType(Expr *expr, bool simple) {
  if (cast(expr)) {
    auto ne = N("NoneType");
    ne->setSrcInfo(expr->getSrcInfo());
    expr = ne;
  }
  // Swap in the requested flag; `simple` then holds the old value for the swap back.
  if (simple != ctx->simpleTypes)
    std::swap(ctx->simpleTypes, simple);
  expr = transform(expr);
  if (simple != ctx->simpleTypes)
    std::swap(ctx->simpleTypes, simple);
  if (expr) {
    if (expr->getType()->getStaticKind()) {
      ;
    } else if (isTypeExpr(expr)) {
      expr->setType(instantiateType(expr->getType()));
    } else if (expr->getType()->getUnbound() &&
               !expr->getType()->getUnbound()->genericName.empty()) {
      // generic!
      expr->setType(instantiateType(expr->getType()));
    } else if (expr->getType()->getUnbound() && expr->getType()->getUnbound()->trait) {
      // generic (is type)!
      expr->setType(instantiateType(expr->getType()));
    } else {
      E(Error::EXPECTED_TYPE, expr, "type");
    }
  }
  return expr;
}

/// Fallback for expression nodes without a dedicated visitor; always asserts.
void TypecheckVisitor::defaultVisit(Expr *e) {
  seqassert(false, "unexpected AST node {}", e->toString());
}

/// Transform a statement node.
/// Null or already-done statements are returned unchanged. If visiting produced
/// statements to prepend, the result is a suite of those statements followed by the
/// transformed statement.
Stmt *TypecheckVisitor::transform(Stmt *stmt) {
  if (!stmt || stmt->isDone())
    return stmt;

  // Statements get a visitor with its own prepend list (no list is passed in).
  TypecheckVisitor v(ctx, preamble);
  v.setSrcInfo(stmt->getSrcInfo());
  ctx->pushNode(stmt);

  // Run the visit with ctx->time set from the statement's ExprTime attribute
  // (0 when absent), then restore the previous value.
  int64_t time = 0;
  if (auto a = stmt->getAttribute(Attr::ExprTime))
    time = a->value;
  auto oldTime = ctx->time;
  ctx->time = time;
  stmt->accept(v);
  ctx->time = oldTime;
  ctx->popNode();
  if (v.resultStmt)
    stmt = v.resultStmt;
  if (!v.prependStmts->empty()) {
    if (stmt)
      v.prependStmts->push_back(stmt);
    // The combined suite is done only if every member is done.
    bool done = true;
    for (auto &s : *(v.prependStmts))
      done &= s->isDone();
    stmt = N(*v.prependStmts);
    if (done)
      stmt->setDone();
  }
  if (stmt->isDone())
    ctx->changedNodes++;
  return stmt;
}

/// Fallback for statement nodes without a dedicated visitor; always asserts.
void TypecheckVisitor::defaultVisit(Stmt *s) {
  seqassert(false, "unexpected AST node {}", s->toString());
}

/**************************************************************************************/

/// Typecheck statement expressions.
/// Transforms each contained statement and the trailing expression; the node's type
/// is unified with the type of that expression.
void TypecheckVisitor::visit(StmtExpr *expr) {
  auto done = true;
  for (auto &s : *expr) {
    s = transform(s);
    done &= s->isDone();
  }
  expr->expr = transform(expr->getExpr());
  unify(expr->getType(), expr->getExpr()->getType());
  if (done && expr->getExpr()->isDone())
    expr->setDone();
}

/// Typecheck a list of statements.
/// Statements derived from the suite's @c Attr::Bindings attribute (if any) are
/// inserted at the front first. Null results are dropped and nested suites returned
/// by @c transform are flattened into this suite.
void TypecheckVisitor::visit(SuiteStmt *stmt) {
  std::vector stmts; // for filtering out nullptr statements
  auto done = true;

  std::vector prepend;
  if (auto b = stmt->getAttribute(Attr::Bindings)) {
    for (auto &[n, bd] : b->bindings) {
      prepend.push_back(N(N(n), nullptr));
      // Bindings with a positive count also get a `<name><VAR_USED_SUFFIX>` flag
      // initialized to false.
      if (bd.count > 0)
        prepend.push_back(N(
            N(fmt::format("{}{}", n, VAR_USED_SUFFIX)), N(false)));
    }
    // Erase the attribute so the declarations are not generated again.
    stmt->eraseAttribute(Attr::Bindings);
  }
  if (!prepend.empty())
    stmt->items.insert(stmt->items.begin(), prepend.begin(), prepend.end());
  for (auto *s : *stmt) {
    if (ctx->returnEarly) {
      // If returnEarly is set (e.g., in the function) ignore the rest
      break;
    }
    if ((s = transform(s))) {
      if (!cast(s)) {
        done &= s->isDone();
        stmts.push_back(s);
      } else {
        for (auto *ss : *cast(s)) {
          if (ss) {
            done &= ss->isDone();
            stmts.push_back(ss);
          }
        }
      }
    }
  }
  stmt->items = stmts;
  if (done)
    stmt->setDone();
}

/// Typecheck expression statements.
void TypecheckVisitor::visit(ExprStmt *stmt) {
  stmt->expr = transform(stmt->getExpr());
  if (stmt->getExpr()->isDone())
    stmt->setDone();
}

/// Dispatch a custom statement to the handler registered for its keyword in the
/// cache: block handlers when the statement has a suite, expression handlers
/// otherwise. The handler's result replaces the statement.
void TypecheckVisitor::visit(CustomStmt *stmt) {
  if (stmt->getSuite()) {
    auto fn = in(ctx->cache->customBlockStmts, stmt->getKeyword());
    seqassert(fn, "unknown keyword {}", stmt->getKeyword());
    resultStmt = fn->second(this, stmt);
  } else {
    auto fn = in(ctx->cache->customExprStmts, stmt->getKeyword());
    seqassert(fn, "unknown keyword {}", stmt->getKeyword());
    resultStmt = (*fn)(this, stmt);
  }
}

/// Comments need no processing; mark them done.
void TypecheckVisitor::visit(CommentStmt *stmt) { stmt->setDone(); }

/// Handle a directive. Only `auto_python` is recognized: it sets `ctx->autoPython`
/// to whether the value equals "1". A compilation warning is emitted in both the
/// recognized and the unknown case.
void TypecheckVisitor::visit(DirectiveStmt *stmt) {
  if (stmt->getKey() == "auto_python") {
    ctx->autoPython = stmt->getValue() == "1";
    compilationWarning(
        fmt::format("directive '{}' = {}", stmt->getKey(), ctx->autoPython),
        stmt->getSrcInfo().file, stmt->getSrcInfo().line, stmt->getSrcInfo().col);
  } else {
    compilationWarning(fmt::format("unknown directive '{}'", stmt->getKey()),
                       stmt->getSrcInfo().file, stmt->getSrcInfo().line,
                       stmt->getSrcInfo().col);
  }
  stmt->setDone();
}
/**************************************************************************************/ /// Select the best method indicated of an object that matches the given argument /// types. See @c findMatchingMethods for details. types::FuncType * TypecheckVisitor::findBestMethod(ClassType *typ, const std::string &member, const std::vector &args) { std::vector callArgs; for (auto &a : args) { callArgs.emplace_back("", N()); // dummy expression callArgs.back().value->setType(a->shared_from_this()); } auto methods = findMethod(typ, member, false); auto m = findMatchingMethods(typ, methods, callArgs); return m.empty() ? nullptr : m[0]; } /// Select the best method indicated of an object that matches the given argument /// types. See @c findMatchingMethods for details. types::FuncType *TypecheckVisitor::findBestMethod(ClassType *typ, const std::string &member, const std::vector &args) { std::vector callArgs; for (auto &a : args) callArgs.emplace_back("", a); auto methods = findMethod(typ, member, false); auto m = findMatchingMethods(typ, methods, callArgs); return m.empty() ? nullptr : m[0]; } /// Select the best method indicated of an object that matches the given argument /// types. See @c findMatchingMethods for details. types::FuncType *TypecheckVisitor::findBestMethod( ClassType *typ, const std::string &member, const std::vector> &args) { std::vector callArgs; for (auto &[n, a] : args) { callArgs.emplace_back(n, N()); // dummy expression callArgs.back().value->setType(a->shared_from_this()); } auto methods = findMethod(typ, member, false); auto m = findMatchingMethods(typ, methods, callArgs); return m.empty() ? nullptr : m[0]; } /// Check if a function can be called with the given arguments. /// See @c reorderNamedArgs for details. 
/// @param fn Candidate function.
/// @param args Call arguments.
/// @param part Optional partial type; arguments it already includes are taken into
///             account.
/// @return -1 if the call is not possible. Otherwise the score produced by
///         @c reorderNamedArgs , plus one when the number of explicitly supplied
///         generics equals the number of function generics.
int TypecheckVisitor::canCall(types::FuncType *fn, const std::vector &args,
                              types::ClassType *part) {
  // Collect the types of arguments the partial has already bound.
  std::vector partialArgs;
  if (part && part->getPartial()) {
    auto known = part->getPartialMask();
    auto knownArgTypes = extractClassGeneric(part, 1)->getClass();
    for (size_t i = 0, k = 0; i < known.size(); i++)
      if (known[i] == ClassType::PartialFlag::Included) {
        partialArgs.push_back(extractClassGeneric(knownArgTypes, static_cast(k)));
        k++;
      } else if (known[i] == ClassType::PartialFlag::Default) {
        k++;
      }
  }

  // One (type, argument index) entry per function slot; a null type means the slot
  // is skipped by the check below.
  std::vector> reordered;
  auto niGenerics = fn->ast->getNonInferrableGenerics();
  auto score = reorderNamedArgs(
      fn, args,
      [&](int s, int k, const std::vector> &slots, bool _) {
        for (int si = 0, gi = 0, pi = 0; si < slots.size(); si++) {
          if ((*fn->ast)[si].isGeneric()) {
            if (slots[si].empty()) {
              // is this "real" type?
              if (in(niGenerics, (*fn->ast)[si].getName()) &&
                  !(*fn->ast)[si].getDefault())
                return -1;
              reordered.emplace_back(nullptr, 0);
            } else {
              seqassert(gi < fn->funcGenerics.size(), "bad fn");
              if (!extractFuncGeneric(fn, gi)->getStaticKind() &&
                  !isTypeExpr(args[slots[si][0]]))
                return -1;
              reordered.emplace_back(args[slots[si][0]].getExpr()->getType(),
                                     slots[si][0]);
            }
            gi++;
          } else if (si == s || si == k || slots[si].size() != 1) {
            // Partials
            if (slots[si].empty() && part && part->getPartial() &&
                part->getPartialMask()[si] == ClassType::PartialFlag::Included) {
              reordered.emplace_back(partialArgs[pi++], 0);
            } else {
              // Ignore *args, *kwargs and default arguments
              reordered.emplace_back(nullptr, 0);
            }
          } else {
            reordered.emplace_back(args[slots[si][0]].getExpr()->getType(),
                                   slots[si][0]);
          }
        }
        return 0;
      },
      // Any reordering error simply means "cannot call".
      [](error::Error, const SrcInfo &, const std::string &) { return -1; },
      part && part->getPartial() ? part->getPartialMask() : "");

  // ai: slot index; mai: value-argument index; gi: generic index;
  // real_gi: number of generics that were actually supplied.
  int ai = 0, mai = 0, gi = 0, real_gi = 0;
  for (; score != -1 && ai < reordered.size(); ai++) {
    auto expectTyp = (*fn->ast)[ai].isValue() ? extractFuncArgType(fn, mai++)
                                              : extractFuncGeneric(fn, gi++);
    auto [argType, argTypeIdx] = reordered[ai];
    if (!argType)
      continue;
    real_gi += !(*fn->ast)[ai].isValue();
    if (!(*fn->ast)[ai].isValue()) {
      // Check if this is a good generic!
      if (expectTyp && expectTyp->getStaticKind()) {
        if (!args[argTypeIdx].getExpr()->getType()->getStaticKind()) {
          score = -1;
          break;
        } else {
          argType = args[argTypeIdx].getExpr()->getType();
        }
      } else {
        /// TODO: check if these are real types or if traits are satisfied
        continue;
      }
    }

    // Use the wrapped type when a wrap applies, otherwise the argument type as-is.
    auto [_, newArgTyp, _ignore] = canWrapExpr(argType, expectTyp, fn);
    if (!newArgTyp)
      newArgTyp = argType->shared_from_this();
    if (newArgTyp->unify(expectTyp, nullptr) < 0)
      score = -1;
  }
  if (score >= 0)
    score += (real_gi == fn->funcGenerics.size());
  return score;
}

/// Select the best method among the provided methods given the list of arguments.
/// See @c reorderNamedArgs for details.
/// @return Every method in `methods` (in the given order) for which @c canCall
///         does not return -1; null entries are skipped.
std::vector TypecheckVisitor::findMatchingMethods(
    types::ClassType *typ, const std::vector &methods,
    const std::vector &args, types::ClassType *part) {
  // Pick the last method that accepts the given arguments.
  std::vector results;
  for (const auto &mi : methods) {
    if (!mi)
      continue; // avoid overloads that have not been seen yet
    // Check against an instantiated copy; the original `mi` is what gets returned.
    auto method = instantiateType(mi, typ);
    int score = canCall(method->getFunc(), args, part);
    if (score != -1) {
      results.push_back(mi);
    }
  }
  return results;
}

/// Wrap an expression to coerce it to the expected type if the type of the expression
/// does not match it. Also unify types.
/// @example /// expected `Generator` -> `expr.__iter__()` /// expected `float`, got `int` -> `float(expr)` /// expected `Optional[T]`, got `T` -> `Optional(expr)` /// expected `T`, got `Optional[T]` -> `unwrap(expr)` /// expected `Function`, got a function -> partialize function /// expected `T`, got `Union[T...]` -> `Union._get(expr, T)` /// expected `Union[T...]`, got `T` -> `Union._new(expr, Union[T...])` /// expected base class, got derived -> downcast to base class /// @param allowUnwrap allow optional unwrapping. bool TypecheckVisitor::wrapExpr(Expr **expr, Type *expectedType, FuncType *callee, bool allowUnwrap) { auto [canWrap, newArgTyp, fn] = canWrapExpr((*expr)->getType(), expectedType, callee, allowUnwrap, cast(*expr)); // TODO: get rid of this line one day! if ((*expr)->getType()->getStaticKind() && (!expectedType || !expectedType->getStaticKind())) (*expr)->setType(getUnderlyingStaticType((*expr)->getType())->shared_from_this()); if (canWrap && fn) { *expr = transform(fn(*expr)); } return canWrap; } std::tuple> TypecheckVisitor::canWrapExpr(Type *exprType, Type *expectedType, FuncType *callee, bool allowUnwrap, bool isEllipsis) { auto expectedClass = expectedType->getClass(); auto exprClass = exprType->getClass(); auto doArgWrap = !callee || !callee->ast->hasFunctionAttribute(getMangledFunc( "std.internal.attributes", "no_argument_wrap")); if (!doArgWrap) return {true, expectedType ? 
expectedType->shared_from_this() : nullptr, nullptr}; TypePtr type = nullptr; std::function fn = nullptr; if (callee && exprType->is(TYPE_TYPE)) { auto c = extractClassType(exprType); if (!c) return {false, nullptr, nullptr}; if (!(expectedType && (expectedType->is(TYPE_TYPE)))) { type = instantiateType(getStdLibType("TypeWrap"), std::vector{c}); fn = [&](Expr *expr) -> Expr * { return N(N("TypeWrap"), expr); }; } return {true, type, fn}; } std::unordered_set hints = {"Generator", "float", TYPE_OPTIONAL, "pyobj"}; if (!expectedType || !expectedType->getStaticKind()) { if (exprType->getStaticKind()) { exprType = getUnderlyingStaticType(exprType); exprClass = exprType->getClass(); type = exprType->shared_from_this(); } } if (!exprClass && expectedClass && in(hints, expectedClass->name)) { return {false, nullptr, nullptr}; // argument type not yet known. } else if (expectedClass && !expectedClass->is("Capsule") && exprClass && exprClass->is("Capsule")) { type = extractClassGeneric(exprClass)->shared_from_this(); fn = [&](Expr *expr) -> Expr * { return N( N(getMangledMethod("std.internal.core", "Capsule", "_get")), expr); }; } else if (expectedClass && expectedClass->is("Capsule") && exprClass && !exprClass->is("Capsule")) { type = instantiateType(getStdLibType("Capsule"), std::vector{exprClass}); fn = [&](Expr *expr) -> Expr * { return N( N(getMangledMethod("std.internal.core", "Capsule", "make")), expr); }; } else if (expectedClass && !expectedClass->is("Any") && exprClass && exprClass->is("Any")) { type = expectedClass->shared_from_this(); fn = [this, type](Expr *expr) -> Expr * { auto r = realize(type.get()); seqassert(r, "not realizable"); return N(N("Any.unwrap"), expr, N(r->realizedName())); }; } else if (expectedClass && expectedClass->is("Any") && exprClass && !exprClass->is("Any")) { type = expectedClass->shared_from_this(); fn = [&](Expr *expr) -> Expr * { return N(N("Any"), expr); }; } else if (expectedClass && expectedClass->is("Generator") && 
!exprClass->is(expectedClass->name) && !isEllipsis) { if (findMethod(exprClass, "__iter__").empty()) return {false, nullptr, nullptr}; // Note: do not do this in pipelines (TODO: why?) type = instantiateType(expectedClass); fn = [&](Expr *expr) -> Expr * { return N(N(expr, "__iter__")); }; } else if (expectedClass && expectedClass->is("float") && exprClass->is("int")) { type = instantiateType(expectedClass); fn = [&](Expr *expr) -> Expr * { return N(N("float"), expr); }; } else if (!callee && expectedClass && expectedClass->is("bool") && exprClass && !exprClass->is("bool")) { // Do not do this in function calls---only use for if-else wrapping type = instantiateType(expectedClass); fn = [&](Expr *expr) -> Expr * { return N(N(expr, "__bool__")); }; } else if (expectedClass && expectedClass->is(TYPE_OPTIONAL) && exprClass && !exprClass->is(expectedClass->name)) { type = instantiateType(getStdLibType(TYPE_OPTIONAL), std::vector{exprClass}); fn = [&](Expr *expr) -> Expr * { return N(N(TYPE_OPTIONAL), expr); }; } else if (allowUnwrap && expectedClass && exprClass && exprClass->is(TYPE_OPTIONAL) && !exprClass->is(expectedClass->name)) { // unwrap optional type = instantiateType(extractClassGeneric(exprClass)); fn = [&](Expr *expr) -> Expr * { return N(N(FN_OPTIONAL_UNWRAP), expr); }; } else if (expectedClass && expectedClass->is("pyobj") && !exprClass->is(expectedClass->name)) { // wrap to pyobj if (findMethod(exprClass, "__to_py__").empty()) return {false, nullptr, nullptr}; type = instantiateType(expectedClass); fn = [&](Expr *expr) -> Expr * { return N(N("pyobj"), N(N(expr, "__to_py__"))); }; } else if (allowUnwrap && expectedClass && exprClass && exprClass->is("pyobj") && !exprClass->is(expectedClass->name)) { // unwrap pyobj if (findMethod(expectedClass, "__from_py__").empty()) return {false, nullptr, nullptr}; type = instantiateType(expectedClass); auto texpr = N(expectedClass->name); texpr->setType(expectedType->shared_from_this()); fn = [this, texpr](Expr *expr) 
-> Expr * { return N(N(texpr, "__from_py__"), N(expr, "p")); }; } else if (expectedClass && expectedClass->is(TYPE_CALLABLE) && exprClass && (exprClass->getPartial() || exprClass->getFunc() || exprClass->is(TYPE_FUNCTION))) { // Get list of arguments std::vector argTypes; Type *retType; std::shared_ptr fnType = nullptr; if (!exprClass->getPartial()) { auto targs = extractClassGeneric(exprClass)->getClass(); for (size_t i = 0; i < targs->size(); i++) argTypes.push_back((*targs)[i]); retType = extractClassGeneric(exprClass, 1); } else { fnType = instantiateType(exprClass->getPartial()->getPartialFunc()); std::vector argumentTypes; auto known = exprClass->getPartial()->getPartialMask(); for (size_t i = 0; i < known.size(); i++) { if (known[i] != ClassType::PartialFlag::Included) argTypes.push_back((*fnType)[i]); } retType = fnType->getRetType(); } auto expectedArgs = extractClassGeneric(expectedClass)->getClass(); if (argTypes.size() != expectedArgs->size()) return {false, nullptr, nullptr}; for (size_t i = 0; i < argTypes.size(); i++) { if (argTypes[i]->unify((*expectedArgs)[i], nullptr) < 0) return {false, nullptr, nullptr}; } if (retType->unify(extractClassGeneric(expectedClass, 1), nullptr) < 0) return {false, nullptr, nullptr}; type = expectedType->shared_from_this(); fn = [this, type](Expr *expr) -> Expr * { auto exprClass = expr->getType()->getClass(); auto expectedClass = type->getClass(); std::vector argTypes; Type *retType; std::shared_ptr fnType = nullptr; if (!exprClass->getPartial()) { auto targs = extractClassGeneric(exprClass)->getClass(); for (size_t i = 0; i < targs->size(); i++) argTypes.push_back((*targs)[i]); retType = extractClassGeneric(exprClass, 1); } else { fnType = instantiateType(exprClass->getPartial()->getPartialFunc()); std::vector argumentTypes; auto known = exprClass->getPartial()->getPartialMask(); for (size_t i = 0; i < known.size(); i++) { if (known[i] != ClassType::PartialFlag::Included) argTypes.push_back((*fnType)[i]); } retType = 
fnType->getRetType(); } auto expectedArgs = extractClassGeneric(expectedClass)->getClass(); for (size_t i = 0; i < argTypes.size(); i++) unify(argTypes[i], (*expectedArgs)[i]); unify(retType, extractClassGeneric(expectedClass, 1)); std::string fname; Expr *retFn = nullptr, *dataArg = nullptr, *dataType = nullptr; if (exprClass->getPartial()) { auto rf = realize(exprClass); fname = rf->realizedName(); seqassert(rf, "not realizable"); retFn = N( N(N(N("Ptr"), N(rf->realizedName())), N("data")), N(0)); dataType = N("cobj"); } else if (exprClass->getFunc()) { auto rf = realize(exprClass); seqassert(rf, "not realizable"); fname = rf->realizedName(); retFn = N(rf->getFunc()->realizedName()); dataArg = N(N("cobj")); dataType = N("cobj"); } else { seqassert(exprClass->is("Function"), "bad type: {}", exprClass->debugString(2)); auto rf = realize(exprClass); seqassert(rf, "not realizable"); fname = rf->realizedName(); retFn = N(N(rf->realizedName()), N("data")); dataType = N("cobj"); } fname = fmt::format(".proxy.{}", fname); if (!ctx->find(fname)) { // Create wrapper if needed auto f = N( fname, nullptr, std::vector{ Param{"data", dataType}, Param{"args", N(expectedArgs->realizedName())}}, // Tuple[...] N( N(N(retFn, N(N("args")))))); f = cast(transform(f)); } auto e = N(N(TYPE_CALLABLE), N(fname), dataArg ? dataArg : expr); return e; }; } else if (callee && exprClass && exprType->getFunc() && !(expectedClass && expectedClass->is("Function"))) { // Wrap raw Seq functions into Partial(...) call for easy realization. // Special case: Seq functions are embedded (via lambda!) 
if (expectedClass) type = instantiateType(expectedClass); auto fnName = exprType->getFunc()->ast->getName(); fn = [&, fnName](Expr *expr) -> Expr * { auto p = N(N(fnName), N(EllipsisExpr::PARTIAL)); if (auto se = cast(expr)) return N(se->items, p); return p; }; } else if (expectedClass && expectedClass->is("Function") && exprClass && exprClass->getPartial() && exprClass->getPartial()->isPartialEmpty()) { type = instantiateType(expectedClass); auto fnName = exprClass->getPartial()->getPartialFunc()->ast->name; auto t = instantiateType(ctx->forceFind(fnName)->getType()); if (type->unify(t.get(), nullptr) >= 0) fn = [&](Expr *expr) -> Expr * { return N(fnName); }; else type = nullptr; } else if (allowUnwrap && exprClass && exprType->getUnion() && expectedClass && !expectedClass->getUnion()) { // Extract union types via Union._get if (auto t = realize(expectedClass)) { auto e = realize(exprType); if (!e) return {false, nullptr, nullptr}; bool ok = false; for (auto &ut : e->getUnion()->getRealizationTypes()) { if (ut->unify(t, nullptr) >= 0) { ok = true; break; } } if (ok) { type = t->shared_from_this(); fn = [this, type](Expr *expr) -> Expr * { return N( N(getMangledMethod("std.internal.core", "Union", "_get")), expr, N(type->realizedName())); }; } } else { return {false, nullptr, nullptr}; } } else if (exprClass && expectedClass && expectedClass->getUnion()) { // Make union types via Union._new if (!expectedClass->getUnion()->isSealed()) { if (!expectedClass->getUnion()->addType(exprClass)) E(error::Error::UNION_TOO_BIG, expectedClass->getSrcInfo(), expectedClass->getUnion()->pendingTypes.size()); } if (auto t = realize(expectedClass)) { if (expectedClass->unify(exprClass, nullptr) == -1) { type = t->shared_from_this(); fn = [this, type](Expr *expr) -> Expr * { return N(N(N("Union"), "_new"), expr, N(type->realizedName())); }; } } else { return {false, nullptr, nullptr}; } } else if (exprClass && exprClass->is(TYPE_TYPE) && expectedClass && 
(expectedClass->is("TypeWrap"))) { type = instantiateType(getStdLibType("TypeWrap"), std::vector{exprClass}); fn = [this](Expr *expr) -> Expr * { return N(N("TypeWrap"), expr); }; } else if (exprClass && exprClass->is("Super") && expectedClass && !expectedClass->is("Super")) { // Super[T] to T type = extractClassGeneric(exprClass)->shared_from_this(); fn = [this](Expr *expr) -> Expr * { return N( N(getMangledMethod("std.internal.core", "Super", "_unwrap")), expr); }; } else if (exprClass && expectedClass && !exprClass->is(expectedClass->name)) { // Cast derived classes to base classes const auto &mros = ctx->cache->getClass(exprClass)->mro; for (size_t i = 1; i < mros.size(); i++) { auto t = instantiateType(mros[i].get(), exprClass); if (t->unify(expectedClass, nullptr) >= 0) { type = expectedClass->shared_from_this(); fn = [this, type](Expr *expr) -> Expr * { return castToSuperClass(expr, type->getClass(), true); }; break; } } } return {true, type, fn}; } /// Cast derived class to a base class. Expr *TypecheckVisitor::castToSuperClass(Expr *expr, ClassType *superTyp, bool isVirtual) { ClassType *typ = expr->getClassType(); for (auto &field : getClassFields(typ)) { for (auto &parentField : getClassFields(superTyp)) if (field.name == parentField.name) { auto t = instantiateType(field.getType(), typ); unify(t.get(), instantiateType(parentField.getType(), superTyp)); } } realize(superTyp); auto typExpr = N(superTyp->realizedName()); return transform( N(N(N("Super"), "_super"), expr, typExpr)); } /// Unpack a Tuple or KwTuple expression into (name, type) vector. /// Name is empty when handling Tuple; otherwise it matches names of KwTuple. 
/// @return The collected vector, or nullptr if the original expression is neither
///         of the two handled kinds or if some element type is not known yet.
std::shared_ptr>>
TypecheckVisitor::unpackTupleTypes(const Expr *expr) {
  auto ret = std::make_shared>>();
  if (auto tup = cast(expr->getOrigExpr())) {
    for (auto &a : *tup) {
      a = transform(a);
      if (!a->getClassType())
        return nullptr;
      ret->emplace_back("", a->getType());
    }
  } else if (cast(expr->getOrigExpr())) {
    auto val = extractClassType(expr->getType());
    if (!val || !val->is("NamedTuple") || !extractClassGeneric(val, 1)->getClass() ||
        !extractClassGeneric(val)->canRealize())
      return nullptr;
    // The integer literal indexes the cache's table of generated tuple names.
    auto id = getIntLiteral(val);
    seqassert(id >= 0 && id < ctx->cache->generatedTupleNames.size(), "bad id: {}",
              id);
    auto names = ctx->cache->generatedTupleNames[id];
    auto types = extractClassGeneric(val, 1)->getClass();
    seqassert(startswith(types->name, "Tuple"), "bad NamedTuple argument");
    for (size_t i = 0; i < types->generics.size(); i++) {
      if (!extractClassGeneric(types, i))
        return nullptr;
      ret->emplace_back(names[i], extractClassGeneric(types, i));
    }
  } else {
    return nullptr;
  }
  return ret;
}

/// Build (name, expression) pairs for a realizable `NamedTuple` expression; each
/// expression indexes into the tuple's `args` member.
std::vector>
TypecheckVisitor::extractNamedTuple(Expr *expr) {
  std::vector> ret;

  seqassert(expr->getType()->is("NamedTuple") &&
                extractClassGeneric(expr->getClassType())->canRealize(),
            "bad named tuple: {}", *expr);
  auto id = getIntLiteral(expr->getClassType());
  seqassert(id >= 0 && id < ctx->cache->generatedTupleNames.size(), "bad id: {}", id);
  auto names = ctx->cache->generatedTupleNames[id];
  for (size_t i = 0; i < names.size(); i++) {
    ret.emplace_back(names[i], N(N(expr, "args"), N(i)));
  }
  return ret;
}

/// Return a copy of the class fields. For tuples, only as many leading fields as
/// the type has generics are returned.
std::vector
TypecheckVisitor::getClassFields(types::ClassType *t) const {
  auto f = getClass(t->name)->fields;
  if (t->is(TYPE_TUPLE))
    f = std::vector(f.begin(), f.begin() + t->generics.size());
  return f;
}

/// Return the instantiated type of every class field. A field whose type cannot be
/// realized yet is unified with the type obtained by transforming a clean clone of
/// its type expression, when one exists.
std::vector
TypecheckVisitor::getClassFieldTypes(types::ClassType *cls) {
  return withClassGenerics(cls, [&]() {
    std::vector result;
    for (auto &field : getClassFields(cls)) {
      auto ftyp = instantiateType(field.getType(), cls);
      if (!ftyp->canRealize() && field.typeExpr) {
        auto t = extractType(transform(clean_clone(field.typeExpr)));
        unify(ftyp.get(), t);
      }
      result.push_back(ftyp);
    }
    return result;
  });
}

/// Strip every enclosing @c TYPE_TYPE layer and return the innermost type.
types::Type *TypecheckVisitor::extractType(types::Type *t) const {
  while (t && t->is(TYPE_TYPE))
    t = extractClassGeneric(t);
  return t;
}

/// Same as above for an expression's type, except that an expression naming
/// @c TYPE_TYPE itself (directly or as the indexed expression) keeps its type as-is.
types::Type *TypecheckVisitor::extractType(Expr *e) const {
  if (cast(e) && cast(e)->getValue() == TYPE_TYPE)
    return e->getType();
  if (auto i = cast(e))
    if (cast(i->getExpr()) && cast(i->getExpr())->getValue() == TYPE_TYPE)
      return e->getType();
  return extractType(e->getType());
}

/// Look the name up in the context (must exist) and extract its type; the name
/// @c TYPE_TYPE itself is returned without stripping.
types::Type *TypecheckVisitor::extractType(const std::string &s) const {
  auto c = ctx->forceFind(s);
  return s == TYPE_TYPE ? c->getType() : extractType(c->getType());
}

/// Class type of the extracted type of an expression.
types::ClassType *TypecheckVisitor::extractClassType(Expr *e) const {
  auto t = extractType(e);
  return t->getClass();
}

/// Class type of the extracted type.
types::ClassType *TypecheckVisitor::extractClassType(types::Type *t) const {
  return extractType(t)->getClass();
}

/// Class type of the extracted type of a named identifier.
types::ClassType *TypecheckVisitor::extractClassType(const std::string &s) const {
  return extractType(s)->getClass();
}

/// True if the type is an unbound type.
bool TypecheckVisitor::isUnbound(types::Type *t) { return t->getUnbound() != nullptr; }

/// True if the expression's type is an unbound type.
bool TypecheckVisitor::isUnbound(const Expr *e) { return isUnbound(e->getType()); }

/// True if more than one overload is registered under the given root name.
bool TypecheckVisitor::hasOverloads(const std::string &root) const {
  auto i = in(ctx->cache->overloads, root);
  return i && i->size() > 1;
}

/// Return a copy of the overloads registered under the root name (must exist).
std::vector TypecheckVisitor::getOverloads(const std::string &root) const {
  auto i = in(ctx->cache->overloads, root);
  seqassert(i, "bad root");
  return *i;
}

/// Map a name through the cache's reverse identifier lookup; names without an entry
/// are returned unchanged.
std::string TypecheckVisitor::getUnmangledName(const std::string &s) const {
  if (in(ctx->cache->reverseIdentifierLookup, s))
    return ctx->cache->rev(s);
  return s;
}

/// Unmangled name with a single leading `$` removed, if present.
std::string TypecheckVisitor::getUserFacingName(const std::string &s) const {
  auto n = getUnmangledName(s);
  if (startswith(n, "$"))
    n = n.substr(1);
  return n;
}

/// Look up a class record in the cache by name; returns the result of `in`.
Cache::Class *TypecheckVisitor::getClass(const std::string &t) const {
  auto i = in(ctx->cache->classes, t);
  return i;
}

/// Look up the class record of a type; asserts if the type is null or not a class.
Cache::Class *TypecheckVisitor::getClass(types::Type *t) const {
  if (t) {
    if (auto c = t->getClass())
      return getClass(c->name);
  }
  seqassert(false, "bad class");
  return nullptr;
}

/// Look up a function record in the cache by name; returns the result of `in`.
Cache::Function *TypecheckVisitor::getFunction(const std::string &n) const {
  auto i = in(ctx->cache->functions, n);
  return i;
}

/// Look up the function record of a function type.
Cache::Function *TypecheckVisitor::getFunction(types::Type *t) const {
  seqassert(t->getFunc(), "bad function");
  return getFunction(t->getFunc()->getFuncName());
}

/// Return the realization stored for a realizable class type, keyed by its realized
/// name; asserts if there is none.
Cache::Class::ClassRealization *
TypecheckVisitor::getClassRealization(types::Type *t) const {
  seqassert(t->canRealize(), "bad class");
  auto i = in(getClass(t)->realizations, t->getClass()->realizedName());
  seqassert(i, "bad class realization");
  return i->get();
}

/// Return the root name recorded for a function; asserts if missing or empty.
std::string TypecheckVisitor::getRootName(const types::FuncType *t) const {
  auto i = in(ctx->cache->functions, t->getFuncName());
  seqassert(i && !i->rootName.empty(), "bad function");
  return i->rootName;
}

/// True if the expression is non-null and its type is @c TYPE_TYPE .
bool TypecheckVisitor::isTypeExpr(const Expr *e) {
  return e && e->getType() && e->getType()->is(TYPE_TYPE);
}

/// Look up an import in the cache by name; asserts if it does not exist.
Cache::Module *TypecheckVisitor::getImport(const std::string &s) const {
  auto i = in(ctx->cache->imports, s);
  seqassert(i, "bad import");
  return i;
}

/// True if the name ends with @c FN_DISPATCH_SUFFIX .
bool TypecheckVisitor::isDispatch(const std::string &s) {
  return endswith(s, FN_DISPATCH_SUFFIX);
}

/// True if the function AST is non-null and its name is a dispatch name.
bool TypecheckVisitor::isDispatch(const FunctionStmt *ast) {
  return ast && isDispatch(ast->name);
}

/// True if the type is a function type whose AST is a dispatch function.
bool TypecheckVisitor::isDispatch(types::Type *f) {
  return f->getFunc() && isDispatch(f->getFunc()->ast);
}

/// True if the type is a record class with at least two members whose realized
/// names are not all equal. Tuples are checked through their generics, other
/// records through their field types.
bool TypecheckVisitor::isHeterogenous(types::Type *type) {
  if (!type->getClass() || !type->getClass()->isRecord())
    return false;
  std::vector fs;
  if (type->is(TYPE_TUPLE)) {
    for (auto &g : type->getClass()->generics)
      fs.push_back(g.getType()->shared_from_this());
  } else {
    fs = getClassFieldTypes(type->getClass());
  }
  if (fs.size() > 1) {
    std::string first = fs[0]->realizedName();
    for (int i = 1; i < fs.size(); i++)
      if (fs[i]->realizedName() != first)
        return true;
  }
  return false;
}

/// Add the generics of a class or function type to the context as generic types.
/// @param typ Class or function type whose generics are added.
/// @param func When true and `typ` is a function, add the generics of its parent
///             chain first and then the function's own generics; otherwise add the
///             hidden and regular generics of `typ`.
/// @param onlyMangled Register under the mangled name instead of the unmangled one.
/// @param instantiate Replace generic links with unbound copies before adding.
void TypecheckVisitor::addClassGenerics(types::ClassType *typ, bool func,
                                        bool onlyMangled, bool instantiate) {
  auto addGen = [&](const types::ClassType::Generic &g) {
    auto t = g.type;
    if (instantiate) {
      if (auto l = t->getLink())
        if (l->kind == types::LinkType::Generic) {
          // Work on a copy so the original link is left untouched.
          auto lx = std::make_shared(*l);
          lx->kind = types::LinkType::Unbound;
          t = lx;
        }
    }
    seqassert(!g.staticKind || t->getStaticKind(), "{} not a static: {}", g.name,
              *(g.type));
    // Non-static generics that are not already a type are wrapped via
    // instantiateTypeVar.
    if (!g.staticKind && !t->is(TYPE_TYPE))
      t = instantiateTypeVar(t.get());
    auto n = onlyMangled ? g.name : getUnmangledName(g.name);
    auto v = ctx->addType(n, g.name, t);
    v->generic = true;
  };

  if (func && typ->getFunc()) {
    auto tf = typ->getFunc();
    // Walk up through parent functions; stop at the first parent class.
    for (auto parent = tf->funcParent; parent;) {
      if (auto f = parent->getFunc()) {
        // Add parent function generics
        for (auto &g : f->funcGenerics)
          addGen(g);
        parent = f->funcParent;
      } else {
        // Add parent class generics
        seqassert(parent->getClass(), "not a class: {}", *parent);
        for (auto &g : parent->getClass()->hiddenGenerics)
          addGen(g);
        for (auto &g : parent->getClass()->generics)
          addGen(g);
        break;
      }
    }
    for (auto &g : tf->funcGenerics)
      addGen(g);
  } else {
    for (auto &g : typ->hiddenGenerics)
      addGen(g);
    for (auto &g : typ->generics)
      addGen(g);
  }
}

/// Instantiate @c TYPE_TYPE with the given type as its generic argument.
types::TypePtr TypecheckVisitor::instantiateTypeVar(types::Type *t) {
  return instantiateType(ctx->forceFind(TYPE_TYPE)->getType(), {t});
}

/// Ensure the cache's globals map has an entry for the name (null if newly added).
void TypecheckVisitor::registerGlobal(const std::string &name) const {
  if (!in(ctx->cache->globals, name)) {
    ctx->cache->globals[name] = nullptr;
  }
}

/// Look up a type by name in the standard library's context and return its class
/// type. @c TYPE_TYPE itself is returned without stripping.
types::ClassType *TypecheckVisitor::getStdLibType(const std::string &type) const {
  auto t = getImport(STDLIB_IMPORT)->ctx->forceFind(type)->getType();
  if (type == TYPE_TYPE)
    return t->getClass();
  return extractClassType(t);
}

/// Return the idx-th generic of a class type; asserts on a non-class or a bad index.
types::Type *TypecheckVisitor::extractClassGeneric(types::Type *t, size_t idx) const {
  seqassert(t->getClass() && idx < t->getClass()->generics.size(), "bad class");
  return t->getClass()->generics[idx].type.get();
}

/// Return the idx-th function generic; asserts on a non-function or a bad index.
types::Type *TypecheckVisitor::extractFuncGeneric(types::Type *t, size_t idx) const {
  seqassert(t->getFunc() && idx < t->getFunc()->funcGenerics.size(), "bad function");
  return t->getFunc()->funcGenerics[idx].type.get();
}

/// Return the idx-th argument type of a function: the idx-th generic of the
/// function type's first class generic.
types::Type *TypecheckVisitor::extractFuncArgType(types::Type *t, size_t idx) const {
  seqassert(t->getFunc(), "bad function");
  return extractClassGeneric(extractClassGeneric(t), idx);
}

/// Return the entry stored for `member` in the class's method table; asserts if the
/// class or member is missing.
std::string TypecheckVisitor::getClassMethod(types::Type *typ,
                                             const std::string &member) const {
  if (auto cls = getClass(typ)) {
    if (auto t = in(cls->methods, member))
      return *t;
  }
  seqassertn(false, "cannot find '{}' in '{}'", member, *typ);
  return "";
}

/// Forward to the cache's temporary-variable name generator.
std::string TypecheckVisitor::getTemporaryVar(const std::string &s) const {
  return ctx->cache->getTemporaryVar(s);
}

/// Return the string literal value of `t` itself if it is a string static,
/// otherwise of its generic at position `pos`.
std::string TypecheckVisitor::getStrLiteral(types::Type *t, size_t pos) const {
  seqassert(t && t->getClass(), "not a class");
  if (t->getStrStatic())
    return t->getStrStatic()->value;
  auto ct = extractClassGeneric(t, pos);
  seqassert(ct->canRealize() && ct->getStrStatic(), "not a string literal");
  return ct->getStrStatic()->value;
}

/// Return the integer literal value of `t` itself if it is an integer static,
/// otherwise of its generic at position `pos`.
int64_t TypecheckVisitor::getIntLiteral(types::Type *t, size_t pos) const {
  seqassert(t && t->getClass(), "not a class");
  if (t->getIntStatic())
    return t->getIntStatic()->value;
  auto ct = extractClassGeneric(t, pos);
  seqassert(ct->canRealize() && ct->getIntStatic(), "not a int literal");
  return ct->getIntStatic()->value;
}

/// Return the boolean literal value of `t` itself if it is a boolean static,
/// otherwise of its generic at position `pos`.
bool TypecheckVisitor::getBoolLiteral(types::Type *t, size_t pos) const {
  seqassert(t && t->getClass(), "not a class");
  if (t->getBoolStatic())
    return t->getBoolStatic()->value;
  auto ct = extractClassGeneric(t, pos);
  seqassert(ct->canRealize() && ct->getBoolStatic(), "not a bool literal");
  return ct->getBoolStatic()->value;
}

/// Build the parameter type expression for a type: a @c TYPE_TYPE identifier for
/// types, a `Literal[...]` index for static kinds, and nullptr for anything else
/// (including a null type).
Expr *TypecheckVisitor::getParamType(Type *t) {
  if (!t)
    return nullptr;
  if (t->is(TYPE_TYPE)) {
    return N(TYPE_TYPE);
  } else if (auto st = t->getStaticKind()) {
    return N(N("Literal"), N(Type::stringFromLiteral(st)));
  } else {
    return nullptr;
  }
}
// NOTE(review): in this copy of the file the template argument lists (e.g. after
// `cast`, `M`, `N`, `std::vector`, `std::function`, `static_cast`) are missing. The
// code below is reproduced token-for-token; restore the arguments from the
// upstream source before compiling.

/// Returns false only when `e` matches one of the alternatives listed in the
/// `MOr` pattern; every other expression is reported as having a side effect.
bool TypecheckVisitor::hasSideEffect(Expr *e) {
  return !match(e, MOr(M(), M(M()), M(), M(), M(), M(), M(), M(), M(), M()));
}

/// Repeatedly replaces `e` by `se->getExpr()` while the cast succeeds and returns
/// the first expression for which it fails.
Expr *TypecheckVisitor::getHeadExpr(Expr *e) {
  while (auto se = cast(e))
    e = se->getExpr();
  return e;
}

/// True when `s` starts with "%_import_".
bool TypecheckVisitor::isImportFn(const std::string &s) {
  return startswith(s, "%_import_");
}

/// Returns `ctx->time`.
int64_t TypecheckVisitor::getTime() const { return ctx->time; }

/// Maps a static type to its non-static counterpart:
///  - if `t->getStatic()` is set, its `getNonStaticType()`;
///  - else if `t` has a static kind, the standard-library type named by
///    `Type::stringFromLiteral`;
///  - otherwise `t` itself.
types::Type *TypecheckVisitor::getUnderlyingStaticType(types::Type *t) const {
  if (t->getStatic()) {
    return t->getStatic()->getNonStaticType();
  } else if (auto c = t->getStaticKind()) {
    return getStdLibType(Type::stringFromLiteral(c));
  }
  return t;
}

/// Creates a new Unbound-kind type at the given `level`, tagged with `srcInfo`.
/// The id is taken from `cache->unboundCount`, which is post-incremented.
std::shared_ptr TypecheckVisitor::instantiateUnbound(const SrcInfo &srcInfo,
                                                     int level) const {
  auto typ = std::make_shared(ctx->cache, types::LinkType::Unbound,
                              ctx->cache->unboundCount++, level, nullptr);
  typ->setSrcInfo(srcInfo);
  return typ;
}

/// Same as above, using `ctx->typecheckLevel` as the level.
std::shared_ptr TypecheckVisitor::instantiateUnbound(const SrcInfo &srcInfo) const {
  return instantiateUnbound(srcInfo, ctx->typecheckLevel);
}

/// Same as above, using `getSrcInfo()` and `ctx->typecheckLevel`.
std::shared_ptr TypecheckVisitor::instantiateUnbound() const {
  return instantiateUnbound(getSrcInfo(), ctx->typecheckLevel);
}

/// Instantiates `type` at `ctx->typecheckLevel`.
///
/// @param srcInfo  source location attached to link types found in the
///                 substitution cache after instantiation (and to an unsealed
///                 union result)
/// @param type     type to instantiate; asserted non-null
/// @param generics optional class type whose hidden and regular generics seed the
///                 substitution cache (keyed by generic id). Generics with a null
///                 type, or whose type is a Generic-kind link, are not used.
/// @return the instantiated type
///
/// Side effects: link types with a `defaultType`, and unsealed unions, are
/// inserted into `pendingDefaults[0]` of the current base.
types::TypePtr TypecheckVisitor::instantiateType(const SrcInfo &srcInfo,
                                                 types::Type *type,
                                                 types::ClassType *generics) const {
  seqassert(type, "type is null");
  std::unordered_map genericCache;
  if (generics) {
    for (auto &g : generics->hiddenGenerics)
      if (g.type &&
          !(g.type->getLink() && g.type->getLink()->kind == types::LinkType::Generic)) {
        genericCache[g.id] = g.type;
      }
    for (auto &g : generics->generics)
      if (g.type &&
          !(g.type->getLink() && g.type->getLink()->kind == types::LinkType::Generic)) {
        genericCache[g.id] = g.type;
      }
    // special case: __SELF__
    // (the first function generic, when named __SELF__, is bound to `generics`)
    if (type->getFunc() && !type->getFunc()->funcGenerics.empty() &&
        getUnmangledName(type->getFunc()->funcGenerics[0].name) == "__SELF__") {
      genericCache[type->getFunc()->funcGenerics[0].id] = generics->shared_from_this();
    }
  }
  auto t = type->instantiate(ctx->typecheckLevel, &(ctx->cache->unboundCount),
                             &genericCache);
  for (auto &val : genericCache | std::views::values) {
    if (auto l = val->getLink()) {
      val->setSrcInfo(srcInfo);
      if (l->defaultType) {
        ctx->getBase()->pendingDefaults[0].insert(val);
      }
    }
  }
  if (t->getUnion() && !t->getUnion()->isSealed()) {
    t->setSrcInfo(srcInfo);
    ctx->getBase()->pendingDefaults[0].insert(t);
  }
  return t;
}

/// Instantiates the class type `root` with an explicit list of generic arguments.
///
/// A temporary class type is filled with `generics`, paired positionally with the
/// ids and static kinds of `root`'s own generics, and passed to the overload
/// above. A static argument supplied for a non-static generic is replaced by its
/// non-static type. Reports Error::GENERICS_MISMATCH when the counts differ and
/// asserts that `root` is a class.
types::TypePtr
TypecheckVisitor::instantiateType(const SrcInfo &srcInfo, types::Type *root,
                                  const std::vector &generics) const {
  auto c = root->getClass();
  seqassert(c, "root class is null");
  // dummy generic type
  auto g = std::make_shared(ctx->cache, "");
  if (generics.size() != c->generics.size()) {
    E(Error::GENERICS_MISMATCH, srcInfo, getUserFacingName(c->name),
      c->generics.size(), generics.size());
  }
  for (int i = 0; i < c->generics.size(); i++) {
    auto t = generics[i];
    seqassert(c->generics[i].type, "generic is null");
    if (!c->generics[i].staticKind && t->getStatic())
      t = t->getStatic()->getNonStaticType();
    g->generics.emplace_back("", t->shared_from_this(), c->generics[i].id,
                             c->generics[i].staticKind);
  }
  return instantiateType(srcInfo, root, g.get());
}

/// Collects the function types of all methods named `method` reachable from
/// `type` through its MRO.
///
/// @param type         class to search; "Capsule" and "Super" are replaced by the
///                     class of their generic before searching
/// @param method       method name
/// @param hideShadowed when true, only the first function encountered for each
///                     distinct `getSignature()` value is kept
/// @return the collected function types; within one class, overloads are visited
///         from last to first, and dispatch functions and functions without a
///         type are skipped
///
/// Special case: for `__new__` on a non-empty TYPE_TUPLE, the tuple class is
/// generated and only the overload whose `size()` equals the number of generics is
/// returned (or an empty list when there is none).
std::vector TypecheckVisitor::findMethod(types::ClassType *type,
                                         const std::string &method,
                                         bool hideShadowed) {
  std::vector vv;
  std::unordered_set signatureLoci;

  auto populate = [&](const auto &cls) {
    auto t = in(cls.methods, method);
    if (!t)
      return;

    auto mt = getOverloads(*t);
    for (int mti = static_cast(mt.size()) - 1; mti >= 0; --mti) {
      const auto &exactMethod = mt[mti];
      auto f = getFunction(exactMethod);
      if (isDispatch(exactMethod) || !f->getType())
        continue;
      if (hideShadowed) {
        auto sig = f->ast->getSignature();
        if (!in(signatureLoci, sig)) {
          signatureLoci.insert(sig);
          vv.emplace_back(f->getType());
        }
      } else {
        vv.emplace_back(f->getType());
      }
    }
  };
  if (type->is("Capsule") || type->is("Super")) {
    type = extractClassGeneric(type)->getClass();
  }
  if (type && type->is(TYPE_TUPLE) && method == "__new__" && !type->generics.empty()) {
    generateTuple(type->generics.size());
    auto mc = getClass(TYPE_TUPLE);
    populate(*mc);
    for (auto f : vv)
      if (f->size() == type->generics.size())
        return {f};
    return {};
  }
  if (auto cls = getClass(type)) {
    for (const auto &pc : cls->mro) {
      // "__NTuple__" entries are looked up under TYPE_TUPLE.
      auto mc = getClass(pc->name == "__NTuple__" ? TYPE_TUPLE : pc->name);
      populate(*mc);
    }
  }
  return vv;
}

/// Finds the field named `member` among the fields of the classes in the MRO of
/// `type` ("Capsule" is replaced by the class of its generic first).
/// For MRO entries that are TYPE_TUPLE, only the first `type->generics.size()`
/// fields are considered.
/// @return pointer to the matching field stored in the class cache, or nullptr
Cache::Class::ClassField *
TypecheckVisitor::findMember(types::ClassType *type, const std::string &member) const {
  if (type->is("Capsule")) {
    type = extractClassGeneric(type)->getClass();
  }
  if (auto cls = getClass(type)) {
    for (const auto &pc : cls->mro) {
      auto mc = getClass(pc.get());
      for (auto &mm : mc->fields) {
        // `&mm - &(mc->fields[0])` is the position of `mm` within `fields`.
        if (pc->is(TYPE_TUPLE) && (&mm - &(mc->fields[0])) >= type->generics.size())
          break;
        if (mm.name == member)
          return &mm;
      }
    }
  }
  return nullptr;
}

/// Assigns call arguments to the parameter slots of `func` and scores the match.
///
/// @param func    callee; slots correspond to the entries of `func->ast`
/// @param args    call arguments (an empty `name` marks a positional argument)
/// @param onDone  invoked with (star index, kwstar index, slots, partial) once
///                every argument is placed; a result of -1 makes this function
///                return -1, any other result is added to the score
/// @param onError invoked for repeated names, too many positional arguments,
///                unknown keyword arguments and missing arguments; its result is
///                returned directly
/// @param known   either empty or one flag per parameter; parameters flagged
///                `ClassType::PartialFlag::Included` are skipped when placing
///                arguments
/// @return the final score, or whatever `onError` returned
int TypecheckVisitor::reorderNamedArgs(const types::FuncType *func,
                                       const std::vector &args,
                                       const ReorderDoneFn &onDone,
                                       const ReorderErrorFn &onError,
                                       const std::string &known) const {
  // See https://docs.python.org/3.6/reference/expressions.html#calls for details.
  // Final score:
  //  - +1 for each matched argument
  //  -  0 for *args/**kwargs/default arguments
  //  - -1 for failed match
  // NOTE(review): the arithmetic below works in steps of 2 (`2 * ...` for the
  // slots, `-= 2` for star and default slots), so the figures above look like a
  // description of the relative scale rather than the literal values - confirm.
  int score = 0;

  // 0. Find *args and **kwargs
  // True if there is a trailing ellipsis (full partial: fn(all_args, ...))
  bool partial = !args.empty() && args.back().name.empty() &&
                 cast(args.back().value) && cast(args.back().value)->isPartial();

  int starArgIndex = -1, kwstarArgIndex = -1;
  for (int i = 0; i < func->ast->size(); i++) {
    if (startswith((*func->ast)[i].name, "**"))
      kwstarArgIndex = i, score -= 2;
    else if (startswith((*func->ast)[i].name, "*"))
      starArgIndex = i, score -= 2;
  }

  // 1. Assign positional arguments to slots
  // Each slot contains a list of arg's indices
  std::vector> slots(func->ast->size());
  seqassert(known.empty() || func->ast->size() == known.size(), "bad 'known' string");
  std::vector extra;
  std::map namedArgs, extraNamedArgs; // keep the map---we need it sorted!
  // `partial` is true only for non-empty `args`, so `args.size() - partial` cannot
  // wrap around.
  for (int ai = 0, si = 0; ai < args.size() - partial; ai++) {
    if (args[ai].name.empty()) {
      // Skip slots that are already provided according to `known`.
      while (!known.empty() && si < slots.size() &&
             known[si] == ClassType::PartialFlag::Included)
        si++;
      if (si < slots.size() && (starArgIndex == -1 || si < starArgIndex))
        slots[si++] = {ai};
      else
        extra.emplace_back(ai);
    } else {
      namedArgs[args[ai].name] = ai;
    }
  }
  score += 2 * static_cast(slots.size() - func->funcGenerics.size());

  // A star slot that received a positional argument above is emptied again.
  // NOTE(review): `ai` here is a slot index, whereas `extra` otherwise holds
  // argument indices - confirm that inserting it is intended.
  for (auto ai : std::vector{std::max(starArgIndex, kwstarArgIndex),
                             std::min(starArgIndex, kwstarArgIndex)})
    if (ai != -1 && !slots[ai].empty()) {
      extra.insert(extra.begin(), ai);
      slots[ai].clear();
    }

  // 2. Assign named arguments to slots
  if (!namedArgs.empty()) {
    std::map slotNames;
    for (int i = 0; i < func->ast->size(); i++)
      if (known.empty() || known[i] != ClassType::PartialFlag::Included) {
        auto [_, n] = (*func->ast)[i].getNameWithStars();
        slotNames[getUnmangledName(n)] = i;
      }
    for (auto &n : namedArgs) {
      if (!in(slotNames, n.first))
        extraNamedArgs[n.first] = n.second;
      else if (slots[slotNames[n.first]].empty())
        slots[slotNames[n.first]].push_back(n.second);
      else
        return onError(Error::CALL_REPEATED_NAME, args[n.second].value->getSrcInfo(),
                       Emsg(Error::CALL_REPEATED_NAME, n.first));
    }
  }

  // 3. Fill in *args, if present
  if (!extra.empty() && starArgIndex == -1)
    return onError(Error::CALL_ARGS_MANY, getSrcInfo(),
                   Emsg(Error::CALL_ARGS_MANY, getUserFacingName(func->ast->getName()),
                        func->ast->size(), args.size() - partial));

  if (starArgIndex != -1)
    slots[starArgIndex] = extra;

  // 4. Fill in **kwargs, if present
  if (!extraNamedArgs.empty() && kwstarArgIndex == -1)
    return onError(Error::CALL_ARGS_INVALID,
                   args[extraNamedArgs.begin()->second].value->getSrcInfo(),
                   Emsg(Error::CALL_ARGS_INVALID, extraNamedArgs.begin()->first,
                        getUnmangledName(func->ast->getName())));
  if (kwstarArgIndex != -1)
    for (auto &val : extraNamedArgs | std::views::values)
      slots[kwstarArgIndex].push_back(val);

  // 5. Fill in the default arguments
  for (auto i = 0; i < func->ast->size(); i++)
    if (slots[i].empty() && i != starArgIndex && i != kwstarArgIndex) {
      // An empty slot is acceptable when it is a value parameter with a default
      // (or already provided according to `known`), or when its name starts
      // with "$".
      if (((*func->ast)[i].isValue() &&
           ((*func->ast)[i].getDefault() ||
            (!known.empty() && known[i] == ClassType::PartialFlag::Included))) ||
          startswith((*func->ast)[i].getName(), "$")) {
        score -= 2;
      } else if (!partial && (*func->ast)[i].isValue()) {
        auto [_, n] = (*func->ast)[i].getNameWithStars();
        return onError(Error::CALL_ARGS_MISSING, getSrcInfo(),
                       Emsg(Error::CALL_ARGS_MISSING,
                            getUnmangledName(func->ast->getName()),
                            getUnmangledName(n)));
      }
    }
  auto s = onDone(starArgIndex, kwstarArgIndex, slots, partial);
  return s != -1 ? score + s : -1;
}

/// True when `name` contains a '.' character.
bool TypecheckVisitor::isCanonicalName(const std::string &name) {
  return name.rfind('.') != std::string::npos;
}

/// Returns the function type behind `t`: `t` itself when it is a function, the
/// partial's underlying function when it is a partial, otherwise nullptr.
types::FuncType *TypecheckVisitor::extractFunction(types::Type *t) {
  if (auto f = t->getFunc())
    return f;
  if (auto p = t->getPartial())
    return p->getPartialFunc();
  return nullptr;
}

/// Visitor that collects expressions satisfying a predicate.
///
/// An expression for which `exprPredicate` holds is recorded in `result` and not
/// descended into; any other expression is traversed with a fresh visitor whose
/// results are appended. A statement is traversed only when `stmtPredicate`
/// holds for it.
class SearchVisitor : public CallbackASTVisitor {
  std::function exprPredicate;
  std::function stmtPredicate;

public:
  /// Matching expressions, in the order they were found.
  std::vector result;

public:
  SearchVisitor(const std::function &exprPredicate,
                const std::function &stmtPredicate)
      : exprPredicate(exprPredicate), stmtPredicate(stmtPredicate) {}
  void transform(Expr *expr) override {
    if (expr && exprPredicate(expr)) {
      result.push_back(expr);
    } else {
      SearchVisitor v(exprPredicate, stmtPredicate);
      if (expr)
        expr->accept(v);
      result.insert(result.end(), v.result.begin(), v.result.end());
    }
  }
  void transform(Stmt *stmt) override {
    if (stmt && stmtPredicate(stmt)) {
      SearchVisitor v(exprPredicate, stmtPredicate);
      stmt->accept(v);
      result.insert(result.end(), v.result.begin(), v.result.end());
    }
  }
};

/// Builds one error per expression under `n` that is not `isDone()`, searching
/// only through statements that are themselves not `isDone()`.
/// The message quotes the source text returned by `cache->getContent` when it is
/// non-empty and falls back to a generic message otherwise.
ParserErrors TypecheckVisitor::findTypecheckErrors(Stmt *n) const {
  SearchVisitor v([](const Expr *e) { return !e->isDone(); },
                  [](const Stmt *s) { return !s->isDone(); });
  v.transform(n);
  std::vector errors;
  for (auto e : v.result) {
    auto code = ctx->cache->getContent(e->getSrcInfo());
    if (!code.empty())
      errors.emplace_back(fmt::format("cannot typecheck '{}'", code), e->getSrcInfo());
    else
      errors.emplace_back(fmt::format("cannot typecheck the expression"),
                          e->getSrcInfo());
  }
  return ParserErrors(errors);
}

/***** Cython-like code generation *****/

// Names of the module, wrapper class and iterator wrapper class used by the
// functions below when building mangled names.
const std::string CYTHON_MODULE = "std.internal.python";
const std::string CYTHON_WRAP = "_PyWrap";
const std::string CYTHON_ITER = "_PyWrap.IterWrap";

/// Builds the `ir::PyType` description used to export class `name` to Python.
///
/// Returns an empty `{"", ""}` value when the import record of the class's module
/// has a non-empty name, or when the class lacks `__to_py__` or `__from_py__`.
/// Otherwise:
///  1. the bodies of `__to_py__` / `__from_py__` are replaced by calls to the
///     wrapper's `wrap_to_py` / `wrap_from_py`, their realizations are redone, and
///     any previously generated IR is redirected to the new realization;
///  2. every method is wrapped (`wrap_multiple`, `wrap_magic_*` or `wrap_get`) and
///     stored in the matching `ir::PyType` slot or in `py.methods`;
///  3. a getter (and, for non-tuple classes, a setter) is generated for every
///     field of the class's single realization.
/// Reports Error::CUSTOM when the class cannot be realized or does not have
/// exactly one realization.
/// NOTE(review): `c` is the unasserted result of `getClass(name)` and is
/// dereferenced immediately - presumably callers only pass registered classes;
/// verify.
ir::PyType TypecheckVisitor::cythonizeClass(const std::string &name) {
  auto c = getClass(name);
  auto ci = getImport(c->module);
  if (!ci->name.empty())
    return {"", ""};
  if (!in(c->methods, "__to_py__") || !in(c->methods, "__from_py__"))
    return {"", ""};

  LOG_USER("[py] Cythonizing {} ({})", getUserFacingName(name), name);
  ir::PyType py{getUnmangledName(name), c->ast->getDocstr()};

  auto tc = extractType(ctx->forceFind(name)->getType());
  if (!tc->canRealize())
    E(Error::CUSTOM, c->ast, "cannot realize '{}' for Python export",
      getUnmangledName(name));
  tc = realize(tc);
  seqassertn(tc, "cannot realize '{}'", name);

  // 1. Replace to_py / from_py with _PyWrap.wrap_to_py/from_py
  if (auto ofnn = in(c->methods, "__to_py__")) {
    auto fnn = getOverloads(*ofnn).front(); // default first overload!
    auto fna = getFunction(fnn)->ast;
    fna->suite = SuiteStmt::wrap(
        N(N(N(getMangledMethod(CYTHON_MODULE, CYTHON_WRAP, "wrap_to_py")),
            N(fna->begin()->name))));
  }
  if (auto ofnn = in(c->methods, "__from_py__")) {
    auto fnn = getOverloads(*ofnn).front(); // default first overload!
    auto fna = getFunction(fnn)->ast;
    fna->suite = SuiteStmt::wrap(
        N(N(N(getMangledMethod(CYTHON_MODULE, CYTHON_WRAP, "wrap_from_py")),
            N(fna->begin()->name), N(name))));
  }
  for (auto &n : std::vector{"__from_py__", "__to_py__"}) {
    auto fnn = getOverloads(*in(c->methods, n)).front();
    auto fn = getFunction(fnn);
    // Remember the IR of the first existing realization (if any), drop all
    // realizations, and realize the function again with its new body.
    ir::Func *oldIR = nullptr;
    if (!fn->realizations.empty())
      oldIR = fn->realizations.begin()->second->ir;
    fn->realizations.clear();
    auto tf = realize(fn->type);
    seqassertn(tf, "cannot re-realize '{}'", fnn);
    if (oldIR) {
      // Turn the old IR function into a forwarder to the new realization.
      std::vector args;
      for (auto it = oldIR->arg_begin(); it != oldIR->arg_end(); ++it) {
        args.push_back(ctx->cache->module->Nr(*it));
      }
      cast(oldIR)->setBody(
          ir::util::series(ir::util::call(fn->realizations.begin()->second->ir, args)));
    }
  }
  // Pick the `py_type` realization whose first function generic unifies with
  // the class.
  for (auto &r : getFunction(getMangledMethod(CYTHON_MODULE, CYTHON_WRAP, "py_type"))
                         ->realizations |
                     std::views::values) {
    if (r->type->funcGenerics[0].type->unify(tc, nullptr) >= 0) {
      py.typePtrHook = r->ir;
      break;
    }
  }

  // 2. Handle methods
  // The loop iterates over a copy of the method table.
  auto methods = c->methods;
  for (const auto &[n, ofnn] : methods) {
    auto canonicalName = getOverloads(ofnn).back();
    auto fn = getFunction(canonicalName);
    // Skip methods whose only overload is auto-generated.
    if (getOverloads(ofnn).size() == 1 && fn->ast->hasAttribute(Attr::AutoGenerated))
      continue;
    auto fna = fn->ast;
    bool isMethod = fna->hasAttribute(Attr::Method);
    bool isProperty = fna->hasAttribute(Attr::Property);

    std::string call = getMangledMethod(CYTHON_MODULE, CYTHON_WRAP, "wrap_multiple");
    bool isMagic = false;
    if (startswith(n, "__") && endswith(n, "__")) {
      // Strip the leading and trailing "__".
      auto m = n.substr(2, n.size() - 4);
      if (m == "new" && c->ast->hasAttribute(Attr::Tuple))
        m = "init";
      auto cls = getClass(getMangledClass(CYTHON_MODULE, CYTHON_WRAP));
      if (auto i = in(cls->methods, "wrap_magic_" + m)) {
        call = getMangledMethod(CYTHON_MODULE, CYTHON_WRAP, "wrap_magic_" + m);
        isMagic = true;
      }
    }
    if (isProperty)
      call = getMangledMethod(CYTHON_MODULE, CYTHON_WRAP, "wrap_get");

    // Generic arguments for the wrapper: the class, then either the property
    // name, or (for non-magic methods) the method name and the `isMethod` flag.
    auto generics = std::vector{tc->shared_from_this()};
    if (isProperty) {
      generics.push_back(instantiateStatic(getUnmangledName(canonicalName)));
    } else if (!isMagic) {
      generics.push_back(instantiateStatic(n));
      generics.push_back(instantiateStatic(static_cast(isMethod)));
    }
    auto f = realizeIRFunc(getFunction(call)->getType(), generics);
    if (!f)
      continue;

    LOG_USER("[py] {} -> {} (method={}; property={})", n, call, isMethod, isProperty);
    // Store the wrapper in the slot that corresponds to the method name; names
    // without a dedicated slot end up in `py.methods`.
    if (isProperty) {
      py.getset.push_back({getUnmangledName(canonicalName), "", f, nullptr});
    } else if (n == "__repr__") {
      py.repr = f;
    } else if (n == "__add__") {
      py.add = f;
    } else if (n == "__iadd__") {
      py.iadd = f;
    } else if (n == "__sub__") {
      py.sub = f;
    } else if (n == "__isub__") {
      py.isub = f;
    } else if (n == "__mul__") {
      py.mul = f;
    } else if (n == "__imul__") {
      py.imul = f;
    } else if (n == "__mod__") {
      py.mod = f;
    } else if (n == "__imod__") {
      py.imod = f;
    } else if (n == "__divmod__") {
      py.divmod = f;
    } else if (n == "__pow__") {
      py.pow = f;
    } else if (n == "__ipow__") {
      py.ipow = f;
    } else if (n == "__neg__") {
      py.neg = f;
    } else if (n == "__pos__") {
      py.pos = f;
    } else if (n == "__abs__") {
      py.abs = f;
    } else if (n == "__bool__") {
      py.bool_ = f;
    } else if (n == "__invert__") {
      py.invert = f;
    } else if (n == "__lshift__") {
      py.lshift = f;
    } else if (n == "__ilshift__") {
      py.ilshift = f;
    } else if (n == "__rshift__") {
      py.rshift = f;
    } else if (n == "__irshift__") {
      py.irshift = f;
    } else if (n == "__and__") {
      py.and_ = f;
    } else if (n == "__iand__") {
      py.iand = f;
    } else if (n == "__xor__") {
      py.xor_ = f;
    } else if (n == "__ixor__") {
      py.ixor = f;
    } else if (n == "__or__") {
      py.or_ = f;
    } else if (n == "__ior__") {
      py.ior = f;
    } else if (n == "__int__") {
      py.int_ = f;
    } else if (n == "__float__") {
      py.float_ = f;
    } else if (n == "__floordiv__") {
      py.floordiv = f;
    } else if (n == "__ifloordiv__") {
      py.ifloordiv = f;
    } else if (n == "__truediv__") {
      py.truediv = f;
    } else if (n == "__itruediv__") {
      py.itruediv = f;
    } else if (n == "__index__") {
      py.index = f;
    } else if (n == "__matmul__") {
      py.matmul = f;
    } else if (n == "__imatmul__") {
      py.imatmul = f;
    } else if (n == "__len__") {
      py.len = f;
    } else if (n == "__getitem__") {
      py.getitem = f;
    } else if (n == "__setitem__") {
      py.setitem = f;
    } else if (n == "__contains__") {
      py.contains = f;
    } else if (n == "__hash__") {
      py.hash = f;
    } else if (n == "__call__") {
      py.call = f;
    } else if (n == "__str__") {
      py.str = f;
    } else if (n == "__iter__") {
      py.iter = f;
    } else if (n == "__del__") {
      py.del = f;
    } else if (n == "__init__" ||
               (c->ast->hasAttribute(Attr::Tuple) && n == "__new__")) {
      py.init = f;
    } else {
      py.methods.push_back(ir::PyFunction{
          n, fna->getDocstr(), f,
          fna->hasAttribute(Attr::Method) ? ir::PyFunction::Type::METHOD
                                          : ir::PyFunction::Type::CLASS,
          // always use FASTCALL for now; works even for 0- or 1- arg methods
          2});
      py.methods.back().keywords = true;
    }
  }

  // A single comparison wrapper is generated as soon as any rich-comparison
  // method is present among the generic methods.
  for (auto &m : py.methods) {
    if (in(std::set{"__lt__", "__le__", "__eq__", "__ne__", "__gt__", "__ge__"},
           m.name)) {
      py.cmp = realizeIRFunc(
          ctx->forceFind(getMangledMethod(CYTHON_MODULE, CYTHON_WRAP, "wrap_cmp"))
              ->type->getFunc(),
          {tc->shared_from_this()});
      break;
    }
  }

  if (c->realizations.size() != 1)
    E(Error::CUSTOM, c->ast, "cannot pythonize generic class '{}'", name);
  auto r = c->realizations.begin()->second;
  py.type = r->ir;
  seqassertn(!r->type->is(TYPE_TUPLE), "tuples not yet done");
  for (auto &mn : r->fields | std::views::keys) {
    /// TODO: handle PyMember for tuples
    // Generate getters & setters
    auto generics = std::vector{tc->shared_from_this(), instantiateStatic(mn)};
    auto gf = realizeIRFunc(
        getFunction(getMangledMethod(CYTHON_MODULE, CYTHON_WRAP, "wrap_get"))
            ->getType(),
        generics);
    // Classes with the Tuple attribute get no setter.
    ir::Func *sf = nullptr;
    if (!c->ast->hasAttribute(Attr::Tuple))
      sf = realizeIRFunc(
          getFunction(getMangledMethod(CYTHON_MODULE, CYTHON_WRAP, "wrap_set"))
              ->getType(),
          generics);
    py.getset.push_back({mn, "", gf, sf});
    LOG_USER("[py] member -> {} . {}", name, mn);
  }
  return py;
}

/// Builds the `ir::PyType` description for the iterator-wrapper realization named
/// `name`: selects the matching `py_type` hook and realizes the wrapper's `_iter`
/// and `_iternext` methods into `py.iter` / `py.iternext`.
/// NOTE(review): `classes[CYTHON_ITER]`, `realizations[name]` and `methods[n]` are
/// accessed with `operator[]`; if these are standard map containers a missing key
/// is default-inserted and `cr` would be null when dereferenced - presumably
/// callers only pass existing realization names; verify.
ir::PyType TypecheckVisitor::cythonizeIterator(const std::string &name) {
  LOG_USER("[py] iterfn: {}", name);
  ir::PyType py{name, ""};
  auto cr = ctx->cache->classes[CYTHON_ITER].realizations[name];
  auto tc = cr->getType();
  for (auto &r : getFunction(getMangledMethod(CYTHON_MODULE, CYTHON_WRAP, "py_type"))
                         ->realizations |
                     std::views::values) {
    if (extractFuncGeneric(r->getType())->unify(tc, nullptr) >= 0) {
      py.typePtrHook = r->ir;
      break;
    }
  }

  for (auto &n : std::vector{"_iter", "_iternext"}) {
    auto fnn = getOverloads(getClass(CYTHON_ITER)->methods[n]).front();
    TypePtr t;
    if (n == "_iter") {
      t = instantiateType(getFunction(fnn)->getType(), tc->getClass());
    } else {
      // `_iternext`: its function generic is unified with the return type of the
      // wrapped class's `__iter__`; the slot is skipped when the wrapped class,
      // the method, or its realization is unavailable.
      auto ut = extractClassGeneric(tc)->getClass();
      if (!ut)
        continue;
      auto fn = findBestMethod(ut, "__iter__", std::vector{ut});
      if (!fn)
        continue;
      auto fnt = realize(instantiateType(fn, ut));
      if (!fnt)
        continue;
      t = instantiateType(getFunction(fnn)->getType(), tc->getClass());
      unify(extractFuncGeneric(t->getFunc()), fnt->getFunc()->getRetType());
    }
    if (auto rtv = realize(t.get())) {
      auto f =
          getFunction(rtv->getFunc()->getFuncName())->realizations[rtv->realizedName()];
      if (n == "_iter")
        py.iter = f->ir;
      else
        py.iternext = f->ir;
    }
  }
  py.type = cr->ir;
  return py;
}

/// Builds the `ir::PyFunction` description for the top-level function `name` by
/// realizing the wrapper's `wrap_multiple` with NoneType, the function name and a
/// static 0 as generic arguments.
/// Returns an empty `{"", ""}` value when the function is not top-level or the
/// wrapper cannot be realized.
/// NOTE(review): `f` is the unasserted result of `getFunction(name)` and is
/// dereferenced in the condition - presumably callers only pass registered
/// function names; verify.
ir::PyFunction TypecheckVisitor::cythonizeFunction(const std::string &name) {
  if (auto f = getFunction(name); f->isToplevel) {
    auto fnName = getMangledMethod(CYTHON_MODULE, CYTHON_WRAP, "wrap_multiple");
    auto generics = std::vector{getStdLibType("NoneType")->shared_from_this(),
                                instantiateStatic(f->ast->getName()),
                                instantiateStatic(static_cast(0))};
    if (auto ir = realizeIRFunc(getFunction(fnName)->getType(), generics)) {
      LOG_USER("[py] toplevel -> {} ({}): ({})", getUserFacingName(name), name,
               f->getType()->debugString(2));
      ir::PyFunction fn{getUnmangledName(name), f->ast->getDocstr(), ir,
                        ir::PyFunction::Type::TOPLEVEL,
                        static_cast(f->ast->size())};
      fn.keywords = true;
      return fn;
    }
  }
  return {"", ""};
}

} // namespace
codon::ast ================================================ FILE: codon/parser/visitors/typecheck/typecheck.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include #include #include #include "codon/parser/ast.h" #include "codon/parser/common.h" #include "codon/parser/visitors/typecheck/ctx.h" #include "codon/parser/visitors/visitor.h" namespace codon::ast { /** * Visitor that infers expression types and performs type-guided transformations. * * -> Note: this stage *modifies* the provided AST. Clone it before simplification * if you need it intact. */ class TypecheckVisitor : public ReplacingCallbackASTVisitor { /// Shared simplification context. std::shared_ptr ctx; /// Statements to prepend before the current statement. std::shared_ptr> prependStmts = nullptr; SuiteStmt *preamble = nullptr; /// Each new expression is stored here (as @c visit does not return anything) and /// later returned by a @c transform call. Expr *resultExpr = nullptr; /// Each new statement is stored here (as @c visit does not return anything) and /// later returned by a @c transform call. 
Stmt *resultStmt = nullptr; public: // static Stmt * apply(Cache *cache, const Stmt * &stmts); static Stmt * apply(Cache *cache, Stmt *node, const std::string &file, const std::unordered_map &defines = {}, const std::unordered_map &earlyDefines = {}, bool barebones = false); static Stmt *apply(const std::shared_ptr &cache, Stmt *node, const std::string &file = ""); private: static void loadStdLibrary(Cache *, SuiteStmt *, const std::unordered_map &, bool); public: explicit TypecheckVisitor( std::shared_ptr ctx, SuiteStmt *preamble = nullptr, const std::shared_ptr> &stmts = nullptr); public: // Convenience transformators Expr *transform(Expr *e) override; Expr *transform(Expr *expr, bool allowTypes); Stmt *transform(Stmt *s) override; Expr *transformType(Expr *expr, bool simple = false); private: void defaultVisit(Expr *e) override; void defaultVisit(Stmt *s) override; private: // Node typechecking rules /* Basic type expressions (basic.cpp) */ void visit(NoneExpr *) override; void visit(BoolExpr *) override; void visit(IntExpr *) override; Expr *transformInt(IntExpr *); void visit(FloatExpr *) override; Expr *transformFloat(FloatExpr *); void visit(StringExpr *) override; /* Identifier access expressions (access.cpp) */ void visit(IdExpr *) override; void checkCapture(const TypeContext::Item &) const; void visit(DotExpr *) override; std::pair getImport(const std::vector &); Expr *getClassMember(DotExpr *); types::FuncType *getDispatch(const std::string &); /* Collection and comprehension expressions (collections.cpp) */ void visit(TupleExpr *) override; void visit(ListExpr *) override; void visit(SetExpr *) override; void visit(DictExpr *) override; Expr *transformComprehension(const std::string &, const std::string &, std::vector &); void visit(GeneratorExpr *) override; /* Conditional expression and statements (cond.cpp) */ void visit(RangeExpr *) override; void visit(IfExpr *) override; void visit(IfStmt *) override; void visit(MatchStmt *) override; Stmt 
*transformPattern(Expr *, Expr *, Stmt *); /* Operators (op.cpp) */ void visit(UnaryExpr *) override; Expr *evaluateStaticUnary(const UnaryExpr *); void visit(BinaryExpr *) override; Expr *evaluateStaticBinary(const BinaryExpr *); Expr *transformBinarySimple(const BinaryExpr *); Expr *transformBinaryIs(const BinaryExpr *); std::pair getMagic(const std::string &) const; Expr *transformBinaryInplaceMagic(BinaryExpr *, bool); Expr *transformBinaryMagic(const BinaryExpr *); void visit(ChainBinaryExpr *) override; void visit(PipeExpr *) override; void visit(IndexExpr *) override; std::pair transformStaticTupleIndex(types::ClassType *, Expr *, Expr *); int64_t translateIndex(int64_t, int64_t, bool = false) const; int64_t sliceAdjustIndices(int64_t, int64_t *, int64_t *, int64_t) const; void visit(InstantiateExpr *) override; void visit(SliceExpr *) override; /* Calls (call.cpp) */ void visit(PrintStmt *) override; /// Holds partial call information for a CallExpr. struct PartialCallData { bool isPartial = false; // true if the call is partial std::string var; // set if calling a partial type itself std::string known; // mask of known arguments Expr *args = nullptr, *kwArgs = nullptr; // partial *args/**kwargs expressions }; void visit(StarExpr *) override; void visit(KeywordStarExpr *) override; void visit(EllipsisExpr *) override; void visit(CallExpr *) override; static void validateCall(CallExpr *expr); bool transformCallArgs(CallExpr *); std::pair, Expr *> getCalleeFn(CallExpr *, PartialCallData &); Expr *callReorderArguments(types::FuncType *, CallExpr *, PartialCallData &); bool typecheckCallArgs(types::FuncType *, std::vector &, const PartialCallData &); std::pair transformSpecialCall(CallExpr *); std::vector getStaticSuperTypes(types::ClassType *); std::vector getRTTISuperTypes(types::ClassType *); /* Assignments (assign.cpp) */ void visit(AssignExpr *) override; void visit(AssignStmt *) override; Stmt *unpackAssignment(Expr *lhs, Expr *rhs); Stmt 
*transformUpdate(AssignStmt *); Stmt *transformAssignment(AssignStmt *, bool = false); void visit(DelStmt *) override; void visit(AssignMemberStmt *) override; std::pair transformInplaceUpdate(AssignStmt *); /* Imports (import.cpp) */ void visit(ImportStmt *) override; Stmt *transformSpecialImport(const ImportStmt *); std::vector getImportPath(Expr *, size_t = 0) const; Stmt *transformCImport(const std::string &, const std::vector &, Expr *, const std::string &); Stmt *transformCVarImport(const std::string &, Expr *, const std::string &); Stmt *transformCDLLImport(Expr *, const std::string &, const std::vector &, Expr *, const std::string &, bool); Stmt *transformPythonImport(Expr *, const std::vector &, Expr *, const std::string &); Stmt *transformNewImport(const ImportFile &); /* Loops (loops.cpp) */ void visit(BreakStmt *) override; void visit(ContinueStmt *) override; void visit(WhileStmt *) override; void visit(ForStmt *) override; Expr *transformForDecorator(Expr *); std::pair transformStaticForLoop(const ForStmt *); /* Errors and exceptions (error.cpp) */ void visit(AssertStmt *) override; void visit(TryStmt *) override; void visit(ThrowStmt *) override; void visit(WithStmt *) override; /* Functions (function.cpp) */ void visit(YieldExpr *) override; void visit(AwaitExpr *) override; void visit(ReturnStmt *) override; void visit(YieldStmt *) override; void visit(YieldFromStmt *) override; void visit(LambdaExpr *) override; void visit(GlobalStmt *) override; void visit(FunctionStmt *) override; Stmt *transformPythonDefinition(const std::string &, const std::vector &, Expr *, Stmt *); Stmt *transformLLVMDefinition(Stmt *); std::tuple getDecorator(Expr *); std::shared_ptr getFuncTypeBase(size_t); private: /* Classes (class.cpp) */ void visit(ClassStmt *) override; std::vector parseBaseClasses(std::vector &, std::vector &, const Stmt *, const std::string &, const Expr *, types::ClassType *); bool autoDeduceMembers(ClassStmt *, std::vector &); static std::vector 
getClassMethods(Stmt *s); void transformNestedClasses(const ClassStmt *, std::vector &, std::vector &, std::vector &); Stmt *codegenMagic(const std::string &, Expr *, const std::vector &, bool); int generateKwId(const std::vector & = {}) const; public: types::ClassType *generateTuple(size_t n, bool = true); private: /* The rest (typecheck.cpp) */ void visit(SuiteStmt *) override; void visit(ExprStmt *) override; void visit(StmtExpr *) override; void visit(CommentStmt *stmt) override; void visit(CustomStmt *) override; void visit(DirectiveStmt *) override; public: /* Type inference (infer.cpp) */ types::Type *unify(types::Type *a, types::Type *b) const; types::Type *unify(types::Type *a, types::TypePtr &&b) { return unify(a, b.get()); } types::Type *realize(types::Type *); types::TypePtr &&realize(types::TypePtr &&t) { realize(t.get()); return std::move(t); } private: Stmt *inferTypes(Stmt *, bool isToplevel = false); types::Type *realizeFunc(types::FuncType *, bool = false); types::Type *realizeType(types::ClassType *); SuiteStmt *generateSpecialAst(types::FuncType *); codon::ir::types::Type *makeIRType(types::ClassType *); codon::ir::Func * makeIRFunction(const std::shared_ptr &); private: types::FuncType *findBestMethod(types::ClassType *typ, const std::string &member, const std::vector &args); types::FuncType *findBestMethod(types::ClassType *typ, const std::string &member, const std::vector &args); types::FuncType * findBestMethod(types::ClassType *typ, const std::string &member, const std::vector> &args); int canCall(types::FuncType *, const std::vector &, types::ClassType * = nullptr); std::vector findMatchingMethods( types::ClassType *typ, const std::vector &methods, const std::vector &args, types::ClassType *part = nullptr); Expr *castToSuperClass(Expr *expr, types::ClassType *superTyp, bool = false); void prepareVTables(); std::vector> extractNamedTuple(Expr *); std::vector getClassFieldTypes(types::ClassType *); static std::vector> findEllipsis(Expr *); 
public: bool wrapExpr(Expr **expr, types::Type *expectedType, types::FuncType *callee = nullptr, bool allowUnwrap = true); std::tuple> canWrapExpr(types::Type *exprType, types::Type *expectedType, types::FuncType *callee = nullptr, bool allowUnwrap = true, bool isEllipsis = false); std::vector getClassFields(types::ClassType *) const; std::shared_ptr getCtx() const { return ctx; } Expr *generatePartialCall(const std::string &, types::FuncType *, Expr * = nullptr, Expr * = nullptr); friend struct Cache; friend struct TypeContext; friend class types::CallableTrait; friend class types::UnionType; private: // Helpers std::shared_ptr>> unpackTupleTypes(const Expr *); std::tuple> transformStaticLoopCall(Expr *, SuiteStmt **, Expr *, const std::function &, bool = false); public: template Tn *N(Ts &&...args) { Tn *t = ctx->cache->N(std::forward(args)...); t->setSrcInfo(getSrcInfo()); if (cast(t) && getTime()) t->setAttribute(Attr::ExprTime, getTime()); return t; } template Tn *NC(Ts &&...args) { Tn *t = ctx->cache->N(std::forward(args)...); return t; } private: template void log(const std::string &prefix, Ts &&...args) { fmt::print(codon::getLogger().log, fmt::runtime("[{}] [{}${}]: " + prefix + "\n"), ctx->getSrcInfo(), ctx->getBaseName(), ctx->getBase()->iteration, std::forward(args)...); } template void logfile(const std::string &file, const std::string &prefix, Ts &&...args) { if (in(ctx->getSrcInfo().file, file)) fmt::print(codon::getLogger().log, fmt::runtime("[{}] [{}${}]: " + prefix + "\n"), ctx->getSrcInfo(), ctx->getBaseName(), ctx->getBase()->iteration, std::forward(args)...); } public: types::Type *extractType(types::Type *t) const; types::Type *extractType(Expr *e) const; types::Type *extractType(const std::string &) const; types::ClassType *extractClassType(Expr *e) const; types::ClassType *extractClassType(types::Type *t) const; types::ClassType *extractClassType(const std::string &s) const; static bool isUnbound(types::Type *t); static bool isUnbound(const 
Expr *e); bool hasOverloads(const std::string &root) const; std::vector getOverloads(const std::string &root) const; std::string getUnmangledName(const std::string &s) const; std::string getUserFacingName(const std::string &s) const; Cache::Class *getClass(const std::string &t) const; Cache::Class *getClass(types::Type *t) const; Cache::Function *getFunction(const std::string &n) const; Cache::Function *getFunction(types::Type *t) const; Cache::Class::ClassRealization *getClassRealization(types::Type *t) const; std::string getRootName(const types::FuncType *t) const; static bool isTypeExpr(const Expr *e); Cache::Module *getImport(const std::string &s) const; static bool isDispatch(const std::string &s); static bool isDispatch(const FunctionStmt *ast); static bool isDispatch(types::Type *f); bool isHeterogenous(types::Type *); void addClassGenerics(types::ClassType *typ, bool func = false, bool onlyMangled = false, bool instantiate = false); template auto withClassGenerics(types::ClassType *typ, F fn, bool func = false, bool onlyMangled = false, bool instantiate = false) { ctx->addBlock(); addClassGenerics(typ, func, onlyMangled, instantiate); auto t = fn(); ctx->popBlock(); return t; } types::TypePtr instantiateTypeVar(types::Type *t); void registerGlobal(const std::string &s) const; types::ClassType *getStdLibType(const std::string &type) const; types::Type *extractClassGeneric(types::Type *t, size_t idx = 0) const; types::Type *extractFuncGeneric(types::Type *t, size_t idx = 0) const; types::Type *extractFuncArgType(types::Type *t, size_t idx = 0) const; std::string getClassMethod(types::Type *typ, const std::string &member) const; std::string getTemporaryVar(const std::string &s) const; static bool isImportFn(const std::string &s); int64_t getTime() const; types::Type *getUnderlyingStaticType(types::Type *t) const; int64_t getIntLiteral(types::Type *t, size_t pos = 0) const; bool getBoolLiteral(types::Type *t, size_t pos = 0) const; std::string 
getStrLiteral(types::Type *t, size_t pos = 0) const; Expr *getParamType(types::Type *t); static bool hasSideEffect(Expr *); static Expr *getHeadExpr(Expr *e); Expr *transformNamedTuple(CallExpr *); Expr *transformFunctoolsPartial(CallExpr *); Expr *transformSuperF(CallExpr *); Expr *transformSuper(); Expr *transformPtr(CallExpr *); Expr *transformArray(CallExpr *); Expr *transformIsInstance(CallExpr *); Expr *transformStaticLen(CallExpr *); Expr *transformHasAttr(CallExpr *); Expr *transformGetAttr(CallExpr *); Expr *transformSetAttr(CallExpr *); Expr *transformCompileError(CallExpr *) const; Expr *transformTupleFn(CallExpr *); Expr *transformTypeFn(CallExpr *); Expr *transformRealizedFn(CallExpr *); Expr *transformStaticPrintFn(CallExpr *) const; Expr *transformHasRttiFn(const CallExpr *); Expr *transformStaticFnCanCall(CallExpr *); Expr *transformStaticFnArgHasType(CallExpr *); Expr *transformStaticFnArgGetType(CallExpr *); Expr *transformStaticFnArgs(CallExpr *); Expr *transformStaticFnHasDefault(CallExpr *); Expr *transformStaticFnGetDefault(CallExpr *); Expr *transformStaticFnWrapCallArgs(CallExpr *); Expr *transformStaticVars(CallExpr *); Expr *transformStaticTupleType(const CallExpr *); Expr *transformStaticFormat(CallExpr *); Expr *transformStaticIntToStr(CallExpr *); SuiteStmt *generateClassPopulateVTablesAST(); SuiteStmt *generateBaseDerivedDistAST(types::FuncType *); FunctionStmt *generateThunkAST(const types::FuncType *fp, types::ClassType *base, const types::ClassType *derived); SuiteStmt *generateGetThunkIDAst(types::FuncType *); SuiteStmt *generateFunctionCallInternalAST(types::FuncType *); SuiteStmt *generateUnionNewAST(const types::FuncType *); SuiteStmt *generateUnionTagAST(types::FuncType *); SuiteStmt *generateNamedKeysAST(types::FuncType *); SuiteStmt *generateTupleMulAST(types::FuncType *); std::vector populateStaticTupleLoop(Expr *, const std::vector &); std::vector populateSimpleStaticRangeLoop(Expr *, const std::vector &); std::vector 
populateStaticRangeLoop(Expr *, const std::vector &); std::vector populateStaticFnOverloadsLoop(Expr *, const std::vector &); std::vector populateStaticEnumerateLoop(Expr *, const std::vector &); std::vector populateStaticVarsLoop(Expr *, const std::vector &); std::vector populateStaticVarTypesLoop(Expr *, const std::vector &); std::vector populateStaticHeterogenousTupleLoop(Expr *, const std::vector &); ParserErrors findTypecheckErrors(Stmt *n) const; public: /// Create an unbound type with the provided typechecking level. std::shared_ptr instantiateUnbound(const SrcInfo &info, int level) const; std::shared_ptr instantiateUnbound(const SrcInfo &info) const; std::shared_ptr instantiateUnbound() const; /// Call `type->instantiate`. /// Prepare the generic instantiation table with the given a generic parameter. /// Example: when instantiating List[T].foo, generics=List[int].foo will ensure that /// T=int. types::TypePtr instantiateType(const SrcInfo &info, types::Type *type, types::ClassType *generics = nullptr) const; types::TypePtr instantiateType(const SrcInfo &info, types::Type *root, const std::vector &generics) const; template std::shared_ptr instantiateType(T *type, types::ClassType *generics = nullptr) { return std::static_pointer_cast( instantiateType(getSrcInfo(), std::move(type), generics)); } template std::shared_ptr instantiateType(T *root, const std::vector &generics) { return std::static_pointer_cast( instantiateType(getSrcInfo(), std::move(root), generics)); } std::shared_ptr instantiateStatic(int64_t i) const { return std::make_shared(ctx->cache, i); } std::shared_ptr instantiateStatic(const std::string &s) const { return std::make_shared(ctx->cache, s); } std::shared_ptr instantiateStatic(bool i) const { return std::make_shared(ctx->cache, i); } /// Returns the list of generic methods that correspond to typeName.method. 
std::vector findMethod(types::ClassType *type, const std::string &method, bool hideShadowed = true); /// Returns the generic type of typeName.member, if it exists (nullptr otherwise). /// Special cases: __elemsize__ and __atomic__. Cache::Class::ClassField *findMember(types::ClassType *, const std::string &) const; using ReorderDoneFn = std::function> &, bool)>; using ReorderErrorFn = std::function; /// Reorders a given vector or named arguments (consisting of names and the /// corresponding types) according to the signature of a given function. /// Returns the reordered vector and an associated reordering score (missing /// default arguments' score is half of the present arguments). /// Score is -1 if the given arguments cannot be reordered. /// @param known Bitmask that indicated if an argument is already provided /// (partial function) or not. int reorderNamedArgs(const types::FuncType *func, const std::vector &args, const ReorderDoneFn &onDone, const ReorderErrorFn &onError, const std::string &known = "") const; static bool isCanonicalName(const std::string &name); static types::FuncType *extractFunction(types::Type *t); ir::PyType cythonizeClass(const std::string &name); ir::PyType cythonizeIterator(const std::string &name); ir::PyFunction cythonizeFunction(const std::string &name); ir::Func *realizeIRFunc(types::FuncType *fn, const std::vector &generics = {}); // types::Type *getType(const std::string &); }; } // namespace codon::ast ================================================ FILE: codon/parser/visitors/visitor.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#include "visitor.h"

#include "codon/parser/ast.h"

namespace codon::ast {

// Fallback handlers. Both are intentionally empty: a visitor that overrides
// neither a specific visit() nor defaultVisit() simply ignores the node.
void ASTVisitor::defaultVisit(Expr *) {}
void ASTVisitor::defaultVisit(Stmt *) {}

// Expression nodes: every base-class visit() forwards to defaultVisit(Expr *).
void ASTVisitor::visit(NoneExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(BoolExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(IntExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(FloatExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(StringExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(IdExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(StarExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(KeywordStarExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(TupleExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(ListExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(SetExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(DictExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(GeneratorExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(IfExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(UnaryExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(BinaryExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(ChainBinaryExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(PipeExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(IndexExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(CallExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(DotExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(SliceExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(EllipsisExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(LambdaExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(YieldExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(AwaitExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(AssignExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(RangeExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(InstantiateExpr *node) { defaultVisit(node); }
void ASTVisitor::visit(StmtExpr *node) { defaultVisit(node); }

// Statement nodes: every base-class visit() forwards to defaultVisit(Stmt *).
void ASTVisitor::visit(SuiteStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(BreakStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(ContinueStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(ExprStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(AssignStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(AssignMemberStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(DelStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(PrintStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(ReturnStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(YieldStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(AssertStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(WhileStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(ForStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(IfStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(MatchStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(ImportStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(TryStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(ExceptStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(GlobalStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(ThrowStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(FunctionStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(ClassStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(YieldFromStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(WithStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(CustomStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(DirectiveStmt *node) { defaultVisit(node); }
void ASTVisitor::visit(CommentStmt *node) { defaultVisit(node); }

} // namespace codon::ast

================================================
FILE: codon/parser/visitors/visitor.h
================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include "codon/parser/ast.h" #include namespace codon::ast { /** * Base Seq AST visitor. * Each visit() by default calls an appropriate defaultVisit(). */ struct ASTVisitor { virtual ~ASTVisitor() {} protected: /// Default expression node visitor if a particular visitor is not overloaded. virtual void defaultVisit(Expr *expr); /// Default statement node visitor if a particular visitor is not overloaded. virtual void defaultVisit(Stmt *stmt); public: virtual void visit(NoneExpr *); virtual void visit(BoolExpr *); virtual void visit(IntExpr *); virtual void visit(FloatExpr *); virtual void visit(StringExpr *); virtual void visit(IdExpr *); virtual void visit(StarExpr *); virtual void visit(KeywordStarExpr *); virtual void visit(TupleExpr *); virtual void visit(ListExpr *); virtual void visit(SetExpr *); virtual void visit(DictExpr *); virtual void visit(GeneratorExpr *); virtual void visit(IfExpr *); virtual void visit(UnaryExpr *); virtual void visit(BinaryExpr *); virtual void visit(ChainBinaryExpr *); virtual void visit(PipeExpr *); virtual void visit(IndexExpr *); virtual void visit(CallExpr *); virtual void visit(DotExpr *); virtual void visit(SliceExpr *); virtual void visit(EllipsisExpr *); virtual void visit(LambdaExpr *); virtual void visit(YieldExpr *); virtual void visit(AwaitExpr *); virtual void visit(AssignExpr *); virtual void visit(RangeExpr *); virtual void visit(InstantiateExpr *); virtual void visit(StmtExpr *); virtual void visit(AssignMemberStmt *); virtual void visit(SuiteStmt *); virtual void visit(BreakStmt *); virtual void visit(ContinueStmt *); virtual void visit(ExprStmt *); virtual void visit(AssignStmt *); virtual void visit(DelStmt *); virtual void visit(PrintStmt *); virtual void visit(ReturnStmt *); virtual void visit(YieldStmt *); virtual void visit(AssertStmt *); virtual void visit(WhileStmt *); virtual void visit(ForStmt 
*); virtual void visit(IfStmt *); virtual void visit(MatchStmt *); virtual void visit(ImportStmt *); virtual void visit(TryStmt *); virtual void visit(ExceptStmt *); virtual void visit(GlobalStmt *); virtual void visit(ThrowStmt *); virtual void visit(FunctionStmt *); virtual void visit(ClassStmt *); virtual void visit(YieldFromStmt *); virtual void visit(WithStmt *); virtual void visit(CustomStmt *); virtual void visit(DirectiveStmt *); virtual void visit(CommentStmt *); }; /** * Callback AST visitor. * This visitor extends base ASTVisitor and stores node's source location (SrcObject). * Function transform() will visit a node and return the appropriate transformation. As * each node type (expression or statement) might return a different type, * this visitor is generic for each different return type. */ template struct CallbackASTVisitor : public ASTVisitor, public SrcObject { virtual TE transform(Expr *expr) = 0; virtual TS transform(Stmt *stmt) = 0; /// Convenience method that transforms a vector of nodes. 
template auto transform(const std::vector &ts) { std::vector r; for (auto &e : ts) r.push_back(transform(e)); return r; } public: void visit(NoneExpr *expr) override {} void visit(BoolExpr *expr) override {} void visit(IntExpr *expr) override {} void visit(FloatExpr *expr) override {} void visit(StringExpr *expr) override {} void visit(IdExpr *expr) override {} void visit(StarExpr *expr) override { transform(expr->expr); } void visit(KeywordStarExpr *expr) override { transform(expr->expr); } void visit(TupleExpr *expr) override { for (auto &i : expr->items) transform(i); } void visit(ListExpr *expr) override { for (auto &i : expr->items) transform(i); } void visit(SetExpr *expr) override { for (auto &i : expr->items) transform(i); } void visit(DictExpr *expr) override { for (auto &i : expr->items) transform(i); } void visit(GeneratorExpr *expr) override { transform(expr->loops); } void visit(IfExpr *expr) override { transform(expr->cond); transform(expr->ifexpr); transform(expr->elsexpr); } void visit(UnaryExpr *expr) override { transform(expr->expr); } void visit(BinaryExpr *expr) override { transform(expr->lexpr); transform(expr->rexpr); } void visit(ChainBinaryExpr *expr) override { for (auto &val : expr->exprs | std::views::values) transform(val); } void visit(PipeExpr *expr) override { for (auto &e : expr->items) transform(e.expr); } void visit(IndexExpr *expr) override { transform(expr->expr); transform(expr->index); } void visit(CallExpr *expr) override { transform(expr->expr); for (auto &a : expr->items) transform(a.value); } void visit(DotExpr *expr) override { transform(expr->expr); } void visit(SliceExpr *expr) override { transform(expr->start); transform(expr->stop); transform(expr->step); } void visit(EllipsisExpr *expr) override {} void visit(LambdaExpr *expr) override { for (auto &a : expr->items) { transform(a.type); transform(a.defaultValue); } transform(expr->expr); } void visit(YieldExpr *expr) override {} void visit(AwaitExpr *expr) override { 
transform(expr->expr); } void visit(AssignExpr *expr) override { transform(expr->var); transform(expr->expr); } void visit(RangeExpr *expr) override { transform(expr->start); transform(expr->stop); } void visit(InstantiateExpr *expr) override { transform(expr->expr); for (auto &e : expr->items) transform(e); } void visit(StmtExpr *expr) override { for (auto &s : expr->items) transform(s); transform(expr->expr); } void visit(SuiteStmt *stmt) override { for (auto &s : stmt->items) transform(s); } void visit(BreakStmt *stmt) override {} void visit(ContinueStmt *stmt) override {} void visit(ExprStmt *stmt) override { transform(stmt->expr); } void visit(AssignStmt *stmt) override { transform(stmt->lhs); transform(stmt->rhs); transform(stmt->type); } void visit(AssignMemberStmt *stmt) override { transform(stmt->lhs); transform(stmt->rhs); transform(stmt->type); } void visit(DelStmt *stmt) override { transform(stmt->expr); } void visit(PrintStmt *stmt) override { for (auto &e : stmt->items) transform(e); } void visit(ReturnStmt *stmt) override { transform(stmt->expr); } void visit(YieldStmt *stmt) override { transform(stmt->expr); } void visit(AssertStmt *stmt) override { transform(stmt->expr); transform(stmt->message); } void visit(WhileStmt *stmt) override { transform(stmt->cond); transform(stmt->suite); transform(stmt->elseSuite); } void visit(ForStmt *stmt) override { transform(stmt->var); transform(stmt->iter); transform(stmt->suite); transform(stmt->elseSuite); transform(stmt->decorator); for (auto &a : stmt->ompArgs) transform(a.value); } void visit(IfStmt *stmt) override { transform(stmt->cond); transform(stmt->ifSuite); transform(stmt->elseSuite); } void visit(MatchStmt *stmt) override { transform(stmt->expr); for (auto &m : stmt->items) { transform(m.pattern); transform(m.guard); transform(m.suite); } } void visit(ImportStmt *stmt) override { transform(stmt->from); transform(stmt->what); for (auto &a : stmt->args) { transform(a.type); transform(a.defaultValue); 
} transform(stmt->ret); } void visit(TryStmt *stmt) override { transform(stmt->suite); for (auto &a : stmt->items) transform(a); transform(stmt->elseSuite); transform(stmt->finally); } void visit(ExceptStmt *stmt) override { transform(stmt->exc); transform(stmt->suite); } void visit(GlobalStmt *stmt) override {} void visit(ThrowStmt *stmt) override { transform(stmt->expr); transform(stmt->from); } void visit(FunctionStmt *stmt) override { transform(stmt->ret); for (auto &a : stmt->items) { transform(a.type); transform(a.defaultValue); } transform(stmt->suite); for (auto &d : stmt->decorators) transform(d); } void visit(ClassStmt *stmt) override { for (auto &a : stmt->items) { transform(a.type); transform(a.defaultValue); } transform(stmt->suite); for (auto &d : stmt->decorators) transform(d); for (auto &d : stmt->baseClasses) transform(d); for (auto &d : stmt->staticBaseClasses) transform(d); } void visit(YieldFromStmt *stmt) override { transform(stmt->expr); } void visit(WithStmt *stmt) override { for (auto &a : stmt->items) transform(a); transform(stmt->suite); } void visit(CustomStmt *stmt) override { transform(stmt->expr); transform(stmt->suite); } }; /** * Callback AST visitor. * This visitor extends base ASTVisitor and stores node's source location (SrcObject). * Function transform() will visit a node and return the appropriate transformation. As * each node type (expression or statement) might return a different type, * this visitor is generic for each different return type. 
*/ struct ReplacingCallbackASTVisitor : public CallbackASTVisitor { public: void visit(StarExpr *expr) override { expr->expr = transform(expr->expr); } void visit(KeywordStarExpr *expr) override { expr->expr = transform(expr->expr); } void visit(TupleExpr *expr) override { for (auto &i : expr->items) i = transform(i); } void visit(ListExpr *expr) override { for (auto &i : expr->items) i = transform(i); } void visit(SetExpr *expr) override { for (auto &i : expr->items) i = transform(i); } void visit(DictExpr *expr) override { for (auto &i : expr->items) i = transform(i); } void visit(GeneratorExpr *expr) override { expr->loops = transform(expr->loops); } void visit(IfExpr *expr) override { expr->cond = transform(expr->cond); expr->ifexpr = transform(expr->ifexpr); expr->elsexpr = transform(expr->elsexpr); } void visit(UnaryExpr *expr) override { expr->expr = transform(expr->expr); } void visit(BinaryExpr *expr) override { expr->lexpr = transform(expr->lexpr); expr->rexpr = transform(expr->rexpr); } void visit(ChainBinaryExpr *expr) override { for (auto &val : expr->exprs | std::views::values) val = transform(val); } void visit(PipeExpr *expr) override { for (auto &e : expr->items) e.expr = transform(e.expr); } void visit(IndexExpr *expr) override { expr->expr = transform(expr->expr); expr->index = transform(expr->index); } void visit(CallExpr *expr) override { expr->expr = transform(expr->expr); for (auto &a : expr->items) a.value = transform(a.value); } void visit(DotExpr *expr) override { expr->expr = transform(expr->expr); } void visit(SliceExpr *expr) override { expr->start = transform(expr->start); expr->stop = transform(expr->stop); expr->step = transform(expr->step); } void visit(EllipsisExpr *expr) override {} void visit(LambdaExpr *expr) override { for (auto &a : expr->items) { a.type = transform(a.type); a.defaultValue = transform(a.defaultValue); } expr->expr = transform(expr->expr); } void visit(YieldExpr *expr) override {} void visit(AwaitExpr *expr) 
override { expr->expr = transform(expr->expr); } void visit(AssignExpr *expr) override { expr->var = transform(expr->var); expr->expr = transform(expr->expr); } void visit(RangeExpr *expr) override { expr->start = transform(expr->start); expr->stop = transform(expr->stop); } void visit(InstantiateExpr *expr) override { expr->expr = transform(expr->expr); for (auto &e : expr->items) e = transform(e); } void visit(StmtExpr *expr) override { for (auto &s : expr->items) s = transform(s); expr->expr = transform(expr->expr); } void visit(SuiteStmt *stmt) override { for (auto &s : stmt->items) s = transform(s); } void visit(ExprStmt *stmt) override { stmt->expr = transform(stmt->expr); } void visit(AssignStmt *stmt) override { stmt->lhs = transform(stmt->lhs); stmt->rhs = transform(stmt->rhs); stmt->type = transform(stmt->type); } void visit(AssignMemberStmt *stmt) override { stmt->lhs = transform(stmt->lhs); stmt->rhs = transform(stmt->rhs); stmt->type = transform(stmt->type); } void visit(DelStmt *stmt) override { stmt->expr = transform(stmt->expr); } void visit(PrintStmt *stmt) override { for (auto &e : stmt->items) e = transform(e); } void visit(ReturnStmt *stmt) override { stmt->expr = transform(stmt->expr); } void visit(YieldStmt *stmt) override { stmt->expr = transform(stmt->expr); } void visit(AssertStmt *stmt) override { stmt->expr = transform(stmt->expr); stmt->message = transform(stmt->message); } void visit(WhileStmt *stmt) override { stmt->cond = transform(stmt->cond); stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); stmt->elseSuite = SuiteStmt::wrap(transform(stmt->elseSuite)); } void visit(ForStmt *stmt) override { stmt->var = transform(stmt->var); stmt->iter = transform(stmt->iter); stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); stmt->elseSuite = SuiteStmt::wrap(transform(stmt->elseSuite)); stmt->decorator = transform(stmt->decorator); for (auto &a : stmt->ompArgs) a.value = transform(a.value); } void visit(IfStmt *stmt) override { stmt->cond 
= transform(stmt->cond); stmt->ifSuite = SuiteStmt::wrap(transform(stmt->ifSuite)); stmt->elseSuite = SuiteStmt::wrap(transform(stmt->elseSuite)); } void visit(MatchStmt *stmt) override { stmt->expr = transform(stmt->expr); for (auto &m : stmt->items) { m.pattern = transform(m.pattern); m.guard = transform(m.guard); m.suite = SuiteStmt::wrap(transform(m.suite)); } } void visit(ImportStmt *stmt) override { stmt->from = transform(stmt->from); stmt->what = transform(stmt->what); for (auto &a : stmt->args) { a.type = transform(a.type); a.defaultValue = transform(a.defaultValue); } stmt->ret = transform(stmt->ret); } void visit(TryStmt *stmt) override { stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); for (auto &a : stmt->items) a = static_cast(transform(a)); stmt->elseSuite = SuiteStmt::wrap(transform(stmt->elseSuite)); stmt->finally = SuiteStmt::wrap(transform(stmt->finally)); } void visit(ExceptStmt *stmt) override { stmt->exc = transform(stmt->exc); stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); } void visit(GlobalStmt *stmt) override {} void visit(ThrowStmt *stmt) override { stmt->expr = transform(stmt->expr); stmt->from = transform(stmt->from); } void visit(FunctionStmt *stmt) override { stmt->ret = transform(stmt->ret); for (auto &a : stmt->items) { a.type = transform(a.type); a.defaultValue = transform(a.defaultValue); } stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); for (auto &d : stmt->decorators) d = transform(d); } void visit(ClassStmt *stmt) override { for (auto &a : stmt->items) { a.type = transform(a.type); a.defaultValue = transform(a.defaultValue); } stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); for (auto &d : stmt->decorators) d = transform(d); for (auto &d : stmt->baseClasses) d = transform(d); for (auto &d : stmt->staticBaseClasses) d = transform(d); } void visit(YieldFromStmt *stmt) override { stmt->expr = transform(stmt->expr); } void visit(WithStmt *stmt) override { for (auto &a : stmt->items) a = transform(a); 
stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); } void visit(CustomStmt *stmt) override { stmt->expr = transform(stmt->expr); stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); } }; } // namespace codon::ast ================================================ FILE: codon/runtime/exc.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "codon/runtime/lib.h" #include "llvm/BinaryFormat/Dwarf.h" #include #include #include #include #include #include #include #include #include #include #include #if defined(__APPLE__) && (__arm64__ || __aarch64__) #define APPLE_SILICON #endif #ifdef APPLE_SILICON #include "llvm/BinaryFormat/MachO.h" #include // https://github.com/llvm/llvm-project/issues/49036 // Define a minimal mach header for JIT'd code. static llvm::MachO::mach_header_64 fake_mach_header = { .magic = llvm::MachO::MH_MAGIC_64, .cputype = llvm::MachO::CPU_TYPE_ARM64, .cpusubtype = llvm::MachO::CPU_SUBTYPE_ARM64_ALL, .filetype = llvm::MachO::MH_DYLIB, .ncmds = 0, .sizeofcmds = 0, .flags = 0, .reserved = 0}; // Declare libunwind SPI types and functions. struct unw_dynamic_unwind_sections { uintptr_t dso_base; uintptr_t dwarf_section; size_t dwarf_section_length; uintptr_t compact_unwind_section; size_t compact_unwind_section_length; }; int find_dynamic_unwind_sections(uintptr_t addr, unw_dynamic_unwind_sections *info) { info->dso_base = (uintptr_t)&fake_mach_header; info->dwarf_section = 0; info->dwarf_section_length = 0; info->compact_unwind_section = 0; info->compact_unwind_section_length = 0; return 1; } // Typedef for callback above. 
typedef int (*unw_find_dynamic_unwind_sections)( uintptr_t addr, struct unw_dynamic_unwind_sections *info); #endif struct BacktraceFrame { char *function; char *filename; uintptr_t pc; int32_t lineno; }; struct Backtrace { static const size_t LIMIT = 20; struct BacktraceFrame *frames; size_t count; void push_back(const char *function, const char *filename, uintptr_t pc, int32_t lineno) { if (count >= LIMIT || !function || !filename) { return; } else if (count == 0) { frames = (BacktraceFrame *)seq_alloc(LIMIT * sizeof(*frames)); } size_t functionLen = strlen(function) + 1; auto *functionDup = (char *)seq_alloc_atomic(functionLen); memcpy(functionDup, function, functionLen); size_t filenameLen = strlen(filename) + 1; auto *filenameDup = (char *)seq_alloc_atomic(filenameLen); memcpy(filenameDup, filename, filenameLen); frames[count++] = {functionDup, filenameDup, pc, lineno}; } void push_back(uintptr_t pc) { push_back("", "", pc, 0); } void free() { for (auto i = 0; i < count; i++) { auto *frame = &frames[i]; seq_free(frame->function); seq_free(frame->filename); } seq_free(frames); frames = nullptr; count = 0; } }; void seq_backtrace_error_callback(void *data, const char *msg, int errnum) { // printf("seq_backtrace_error_callback: %s (errnum = %d)\n", msg, errnum); } int seq_backtrace_full_callback(void *data, uintptr_t pc, const char *filename, int lineno, const char *function) { auto *bt = ((Backtrace *)data); bt->push_back(function, filename, pc, lineno); return (bt->count < Backtrace::LIMIT) ? 0 : 1; } int seq_backtrace_simple_callback(void *data, uintptr_t pc) { auto *bt = ((Backtrace *)data); bt->push_back(pc); return (bt->count < Backtrace::LIMIT) ? 
0 : 1; } /* * This is largely based on * llvm/examples/ExceptionDemo/ExceptionDemo.cpp */ namespace { template static uintptr_t ReadType(const uint8_t *&p) { Type_ value; memcpy(&value, p, sizeof(Type_)); p += sizeof(Type_); return static_cast(value); } } // namespace // Note: this should match Codon definition struct TypeInfo { seq_int_t id; seq_int_t *parent_ids; seq_int_t n_parent_ids; seq_str_t raw_name; // other fields do not need to be included }; struct RTTIObject { void *data; TypeInfo *type; }; struct CodonBaseExceptionType { int type; }; struct CodonBaseException { void *obj; Backtrace bt; _Unwind_Exception unwindException; }; struct CodonExceptionHeader { seq_str_t msg; seq_str_t func; seq_str_t file; seq_int_t line; seq_int_t col; void *python_type; void *cause; }; void seq_exc_init(int flags) { #ifdef APPLE_SILICON if (!(flags & SEQ_FLAG_STANDALONE)) { if (auto *unw_add_find_dynamic_unwind_sections = (int (*)(unw_find_dynamic_unwind_sections find_dynamic_unwind_sections)) dlsym(RTLD_DEFAULT, "__unw_add_find_dynamic_unwind_sections")) { unw_add_find_dynamic_unwind_sections(find_dynamic_unwind_sections); } } #endif } static void seq_delete_exc(_Unwind_Exception *expToDelete) { if (!expToDelete || expToDelete->exception_class != SEQ_EXCEPTION_CLASS) return; auto *exc = (CodonBaseException *)((char *)expToDelete + seq_exc_offset()); if (seq_flags & SEQ_FLAG_DEBUG) { exc->bt.free(); } seq_free(exc); } static void seq_delete_unwind_exc(_Unwind_Reason_Code reason, _Unwind_Exception *expToDelete) { seq_delete_exc(expToDelete); } static struct backtrace_state *state = nullptr; static std::mutex stateLock; SEQ_FUNC void *seq_alloc_exc(void *obj) { const size_t size = sizeof(CodonBaseException); auto *e = (CodonBaseException *)memset(seq_alloc(size), 0, size); assert(e); e->obj = obj; e->unwindException.exception_class = SEQ_EXCEPTION_CLASS; e->unwindException.exception_cleanup = seq_delete_unwind_exc; if (seq_flags & SEQ_FLAG_DEBUG) { e->bt.frames = nullptr; 
e->bt.count = 0; if (seq_flags & SEQ_FLAG_STANDALONE) { if (!state) { stateLock.lock(); if (!state) state = backtrace_create_state(/*filename=*/nullptr, /*threaded=*/1, seq_backtrace_error_callback, /*data=*/nullptr); stateLock.unlock(); } backtrace_full(state, /*skip=*/1, seq_backtrace_full_callback, seq_backtrace_error_callback, &e->bt); } else { backtrace_simple(/*state=*/nullptr, /*skip=*/1, seq_backtrace_simple_callback, seq_backtrace_error_callback, &e->bt); } } return &(e->unwindException); } static void print_from_last_dot(seq_str_t s, std::ostringstream &buf) { char *p = s.str; int64_t n = s.len; for (int64_t i = n - 1; i >= 0; i--) { if (p[i] == '.') { p += (i + 1); n -= (i + 1); break; } } buf.write(p, (size_t)n); } static std::function jitErrorCallback; SEQ_FUNC void seq_terminate(void *exc) { auto *base = (CodonBaseException *)((char *)exc + seq_exc_offset()); void *obj = base->obj; auto *hdr = *(CodonExceptionHeader **)obj; auto tname = ((RTTIObject *)obj)->type->raw_name; if (std::string(tname.str, tname.len) == "SystemExit") { seq_int_t status = *(seq_int_t *)(hdr + 1); exit((int)status); } std::ostringstream buf; if (seq_flags & SEQ_FLAG_CAPTURE_OUTPUT) buf << codon::runtime::getCapturedOutput(); buf << "\033[1m"; print_from_last_dot(tname, buf); if (hdr->msg.len > 0) { buf << ": \033[0m"; buf.write(hdr->msg.str, hdr->msg.len); } else { buf << "\033[0m"; } buf << "\n\n\033[1mRaised from:\033[0m \033[32m"; buf.write(hdr->func.str, hdr->func.len); buf << "\033[0m\n"; buf.write(hdr->file.str, hdr->file.len); if (hdr->line > 0) { buf << ":" << hdr->line; if (hdr->col > 0) buf << ":" << hdr->col; } buf << "\n"; if ((seq_flags & SEQ_FLAG_DEBUG) && (seq_flags & SEQ_FLAG_STANDALONE)) { auto *bt = &base->bt; if (bt->count > 0) { buf << "\n\033[1mBacktrace:\033[0m\n"; for (unsigned i = 0; i < bt->count; i++) { auto *frame = &bt->frames[i]; buf << " " << codon::runtime::makeBacktraceFrameString( frame->pc, std::string(frame->function), 
// Decodes an unsigned LEB128 value starting at `*data` and advances `*data`
// past the last byte consumed.
//
// Each byte contributes its low 7 bits, least-significant group first; the
// high bit marks continuation. The group is widened to uintptr_t before
// shifting (shifting the promoted `int` loses bits and is undefined once
// the shift reaches 32). Groups that would fall entirely beyond the width
// of uintptr_t are consumed but ignored.
static uintptr_t readULEB128(const uint8_t **data) {
  constexpr unsigned kBits = sizeof(uintptr_t) * 8;
  uintptr_t result = 0;
  unsigned shift = 0;
  unsigned char byte;
  const uint8_t *p = *data;

  do {
    byte = *p++;
    if (shift < kBits)
      result |= static_cast<uintptr_t>(byte & 0x7f) << shift;
    shift += 7;
  } while (byte & 0x80);

  *data = p;
  return result;
}

// Decodes a signed LEB128 value starting at `*data` and advances `*data`
// past the last byte consumed. The result is returned as the two's
// complement bit pattern in a uintptr_t.
//
// Decoding matches readULEB128; afterwards, if bit 6 of the final byte is
// set, the remaining high bits are filled with ones. The fill mask is built
// from an unsigned all-ones value, since left-shifting the negative int
// `~0` is undefined.
static uintptr_t readSLEB128(const uint8_t **data) {
  constexpr unsigned kBits = sizeof(uintptr_t) * 8;
  uintptr_t result = 0;
  unsigned shift = 0;
  unsigned char byte;
  const uint8_t *p = *data;

  do {
    byte = *p++;
    if (shift < kBits)
      result |= static_cast<uintptr_t>(byte & 0x7f) << shift;
    shift += 7;
  } while (byte & 0x80);

  *data = p;

  // Sign-extend when the value is negative and does not already fill the word.
  if ((byte & 0x40) && shift < kBits) {
    result |= ~static_cast<uintptr_t>(0) << shift;
  }
  return result;
}
readEncodedPointer(const uint8_t **data, uint8_t encoding) { uintptr_t result = 0; const uint8_t *p = *data; if (encoding == llvm::dwarf::DW_EH_PE_omit) return result; // first get value switch (encoding & 0x0F) { case llvm::dwarf::DW_EH_PE_absptr: result = ReadType(p); break; case llvm::dwarf::DW_EH_PE_uleb128: result = readULEB128(&p); break; // Note: This case has not been tested case llvm::dwarf::DW_EH_PE_sleb128: result = readSLEB128(&p); break; case llvm::dwarf::DW_EH_PE_udata2: result = ReadType(p); break; case llvm::dwarf::DW_EH_PE_udata4: result = ReadType(p); break; case llvm::dwarf::DW_EH_PE_udata8: result = ReadType(p); break; case llvm::dwarf::DW_EH_PE_sdata2: result = ReadType(p); break; case llvm::dwarf::DW_EH_PE_sdata4: result = ReadType(p); break; case llvm::dwarf::DW_EH_PE_sdata8: result = ReadType(p); break; default: // not supported abort(); } // then add relative offset switch (encoding & 0x70) { case llvm::dwarf::DW_EH_PE_absptr: // do nothing break; case llvm::dwarf::DW_EH_PE_pcrel: result += (uintptr_t)(*data); break; case llvm::dwarf::DW_EH_PE_textrel: case llvm::dwarf::DW_EH_PE_datarel: case llvm::dwarf::DW_EH_PE_funcrel: case llvm::dwarf::DW_EH_PE_aligned: default: // not supported abort(); } // then apply indirection if (encoding & llvm::dwarf::DW_EH_PE_indirect) { result = *((uintptr_t *)result); } *data = p; return result; } static bool isinstance(void *obj, seq_int_t type) { auto *info = ((RTTIObject *)obj)->type; if (info->id == type) return true; if (info->parent_ids) { auto *p = info->parent_ids; while (*p) { if (*p++ == type) { return true; } } } return false; } static bool handleActionValue(int64_t *resultAction, uint8_t TTypeEncoding, const uint8_t *ClassInfo, uintptr_t actionEntry, uint64_t exceptionClass, _Unwind_Exception *exceptionObject) { bool ret = false; if (!resultAction || !exceptionObject || (exceptionClass != SEQ_EXCEPTION_CLASS)) return ret; auto *excp = (struct CodonBaseException *)(((char *)exceptionObject) + 
// Core of the personality routine: walks the language-specific data area
// (LSDA) of the current frame and tells the unwinder what to do with it.
//
//   version          - personality ABI version (not used here)
//   lsda             - start of the frame's LSDA, or null if it has none
//   actions          - unwinder phase flags; only _UA_SEARCH_PHASE is tested
//   exceptionClass   - class of the in-flight exception
//   exceptionObject  - the in-flight exception
//   context          - unwinder context for the current frame
//
// Returns:
//   _URC_CONTINUE_UNWIND  - no LSDA, no call site with a landing pad covers
//                           the PC, or (search phase) no handler matched
//   _URC_HANDLER_FOUND    - search phase and a matching handler exists
//   _URC_INSTALL_CONTEXT  - any other phase: registers and IP were set so
//                           execution resumes at the landing pad
static _Unwind_Reason_Code handleLsda(int version, const uint8_t *lsda,
                                      _Unwind_Action actions,
                                      uint64_t exceptionClass,
                                      _Unwind_Exception *exceptionObject,
                                      _Unwind_Context *context) {
  _Unwind_Reason_Code ret = _URC_CONTINUE_UNWIND;

  if (!lsda)
    return ret;

  // Get the current instruction pointer and move it back by one so that it
  // lies before the next instruction in the frame which threw the exception.
  uintptr_t pc = _Unwind_GetIP(context) - 1;

  // Get the beginning of the current frame's code (as defined by the
  // emitted dwarf code)
  uintptr_t funcStart = _Unwind_GetRegionStart(context);
  uintptr_t pcOffset = pc - funcStart;
  const uint8_t *ClassInfo = nullptr;

  // Note: See JITDwarfEmitter::EmitExceptionTable(...) for corresponding
  // dwarf emission

  // Parse LSDA header.
  uint8_t lpStartEncoding = *lsda++;

  if (lpStartEncoding != llvm::dwarf::DW_EH_PE_omit) {
    // The landing-pad base is read only to skip over it; its value is unused.
    readEncodedPointer(&lsda, lpStartEncoding);
  }

  uint8_t ttypeEncoding = *lsda++;
  uintptr_t classInfoOffset;

  if (ttypeEncoding != llvm::dwarf::DW_EH_PE_omit) {
    // Calculate type info locations in emitted dwarf code which
    // were flagged by type info arguments to llvm.eh.selector
    // intrinsic
    classInfoOffset = readULEB128(&lsda);
    ClassInfo = lsda + classInfoOffset;
  }

  // Walk call-site table looking for range that
  // includes current PC.

  uint8_t callSiteEncoding = *lsda++;
  auto callSiteTableLength = (uint32_t)readULEB128(&lsda);
  const uint8_t *callSiteTableStart = lsda;
  const uint8_t *callSiteTableEnd = callSiteTableStart + callSiteTableLength;
  // The action table starts right where the call-site table ends.
  const uint8_t *actionTableStart = callSiteTableEnd;
  const uint8_t *callSitePtr = callSiteTableStart;

  while (callSitePtr < callSiteTableEnd) {
    // Each record: start offset, length, landing pad offset, action index.
    uintptr_t start = readEncodedPointer(&callSitePtr, callSiteEncoding);
    uintptr_t length = readEncodedPointer(&callSitePtr, callSiteEncoding);
    uintptr_t landingPad = readEncodedPointer(&callSitePtr, callSiteEncoding);

    // Note: Action value
    uintptr_t actionEntry = readULEB128(&callSitePtr);

    if (exceptionClass != SEQ_EXCEPTION_CLASS) {
      // We have been notified of a foreign exception being thrown,
      // and we therefore need to execute cleanup landing pads
      actionEntry = 0;
    }

    if (landingPad == 0) {
      continue; // no landing pad for this entry
    }

    if (actionEntry) {
      // Convert the 1-based action-table offset into an absolute address.
      actionEntry += (uintptr_t)actionTableStart - 1;
    }

    bool exceptionMatched = false;

    if ((start <= pcOffset) && (pcOffset < (start + length))) {
      int64_t actionValue = 0;

      if (actionEntry) {
        exceptionMatched =
            handleActionValue(&actionValue, ttypeEncoding, ClassInfo,
                              actionEntry, exceptionClass, exceptionObject);
      }

      if (!(actions & _UA_SEARCH_PHASE)) {
        // Found landing pad for the PC.
        // Set the instruction pointer so that we re-enter the function
        // at the landing pad. The landing pad is created by the
        // compiler to take two parameters in registers.
        _Unwind_SetGR(context, __builtin_eh_return_data_regno(0),
                      (uintptr_t)exceptionObject);

        // Note: this virtual register directly corresponds
        //       to the return of the llvm.eh.selector intrinsic
        if (!actionEntry || !exceptionMatched) {
          // We indicate cleanup only
          _Unwind_SetGR(context, __builtin_eh_return_data_regno(1), 0);
        } else {
          // Matched type info index of llvm.eh.selector intrinsic
          // passed here.
          _Unwind_SetGR(context, __builtin_eh_return_data_regno(1),
                        (uintptr_t)actionValue);
        }

        // To execute landing pad set here
        _Unwind_SetIP(context, funcStart + landingPad);
        ret = _URC_INSTALL_CONTEXT;
      } else if (exceptionMatched) {
        ret = _URC_HANDLER_FOUND;
      }

      // Only the first call-site record covering the PC is considered.
      break;
    }
  }

  return ret;
}
// Half -> single precision conversion.
//
// Delegates to the generic __extendXfYf2__ routine, which this translation
// unit instantiates with SRC_HALF and DST_SINGLE.
//
// Use a forwarding definition and noinline to implement a poor man's alias,
// as there isn't a good cross-platform way of defining one.
COMPILER_RT_ABI NOINLINE float __extendhfsf2(src_t a) {
  return __extendXfYf2__(a);
}
// Single -> double precision conversion.
//
// Delegates to the generic __extendXfYf2__ routine, which this translation
// unit instantiates with SRC_SINGLE and DST_DOUBLE.
COMPILER_RT_ABI double __extendsfdf2(float a) { return __extendXfYf2__(a); }
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #define QUAD_PRECISION #include "fp_lib.h" #if defined(CRT_HAS_TF_MODE) #define SRC_SINGLE #define DST_QUAD #include "fp_extend_impl.inc" COMPILER_RT_ABI fp_t __extendsftf2(float a) { return __extendXfYf2__(a); } #endif ================================================ FILE: codon/runtime/floatlib/fp_extend.h ================================================ //===-lib/fp_extend.h - low precision -> high precision conversion -*- C //-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Set source and destination setting // //===----------------------------------------------------------------------===// #ifndef FP_EXTEND_HEADER #define FP_EXTEND_HEADER #include "int_lib.h" #if defined SRC_SINGLE typedef float src_t; typedef uint32_t src_rep_t; #define SRC_REP_C UINT32_C static const int srcSigBits = 23; #define src_rep_t_clz clzsi #elif defined SRC_DOUBLE typedef double src_t; typedef uint64_t src_rep_t; #define SRC_REP_C UINT64_C static const int srcSigBits = 52; static __inline int src_rep_t_clz(src_rep_t a) { #if defined __LP64__ return __builtin_clzl(a); #else if (a & REP_C(0xffffffff00000000)) return clzsi(a >> 32); else return 32 + clzsi(a & REP_C(0xffffffff)); #endif } #elif defined SRC_HALF #ifdef COMPILER_RT_HAS_FLOAT16 typedef _Float16 src_t; #else typedef uint16_t src_t; #endif typedef uint16_t src_rep_t; #define SRC_REP_C UINT16_C static const int srcSigBits = 10; #define src_rep_t_clz __builtin_clz #else #error Source should be half, single, or double precision! 
// Returns the bit pattern of the source floating-point value `x` as an
// integer of type src_rep_t.
//
// The value is written to one union member and read back through the other,
// so the bits are reinterpreted rather than numerically converted.
static __inline src_rep_t srcToRep(src_t x) {
  const union {
    src_t f;
    src_rep_t i;
  } rep = {.f = x};
  return rep.i;
}
// Generic widening conversion from src_t to dst_t.
//
// The source value is split into sign and magnitude; the magnitude is
// converted according to its class (normal, NaN/infinity, denormal, zero)
// and the sign is re-attached at the destination's sign position. All
// format parameters come from the srcSigBits/dstSigBits constants and the
// src_t/dst_t typedefs selected by the includer.
static __inline dst_t __extendXfYf2__(src_t a) {
  // Various constants whose values follow from the type parameters.
  // Any reasonable optimizer will fold and propagate all of these.
  const int srcBits = sizeof(src_t) * CHAR_BIT;
  const int srcExpBits = srcBits - srcSigBits - 1;
  const int srcInfExp = (1 << srcExpBits) - 1;
  const int srcExpBias = srcInfExp >> 1;

  const src_rep_t srcMinNormal = SRC_REP_C(1) << srcSigBits;
  const src_rep_t srcInfinity = (src_rep_t)srcInfExp << srcSigBits;
  const src_rep_t srcSignMask = SRC_REP_C(1) << (srcSigBits + srcExpBits);
  const src_rep_t srcAbsMask = srcSignMask - 1;
  const src_rep_t srcQNaN = SRC_REP_C(1) << (srcSigBits - 1);
  const src_rep_t srcNaNCode = srcQNaN - 1;

  const int dstBits = sizeof(dst_t) * CHAR_BIT;
  const int dstExpBits = dstBits - dstSigBits - 1;
  const int dstInfExp = (1 << dstExpBits) - 1;
  const int dstExpBias = dstInfExp >> 1;

  const dst_rep_t dstMinNormal = DST_REP_C(1) << dstSigBits;

  // Break a into a sign and representation of the absolute value.
  const src_rep_t aRep = srcToRep(a);
  const src_rep_t aAbs = aRep & srcAbsMask;
  const src_rep_t sign = aRep & srcSignMask;
  dst_rep_t absResult;

  // If sizeof(src_rep_t) < sizeof(int), the subtraction result is promoted
  // to (signed) int. To avoid that, explicitly cast to src_rep_t.
  // The single unsigned comparison tests srcMinNormal <= aAbs < srcInfinity:
  // values below srcMinNormal wrap around to large numbers and fail it.
  if ((src_rep_t)(aAbs - srcMinNormal) < srcInfinity - srcMinNormal) {
    // a is a normal number.
    // Extend to the destination type by shifting the significand and
    // exponent into the proper position and rebiasing the exponent.
    absResult = (dst_rep_t)aAbs << (dstSigBits - srcSigBits);
    absResult += (dst_rep_t)(dstExpBias - srcExpBias) << dstSigBits;
  }

  else if (aAbs >= srcInfinity) {
    // a is NaN or infinity.
    // Conjure the result by beginning with infinity, then setting the qNaN
    // bit (if needed) and right-aligning the rest of the trailing NaN
    // payload field.
    absResult = (dst_rep_t)dstInfExp << dstSigBits;
    absResult |= (dst_rep_t)(aAbs & srcQNaN) << (dstSigBits - srcSigBits);
    absResult |= (dst_rep_t)(aAbs & srcNaNCode) << (dstSigBits - srcSigBits);
  }

  else if (aAbs) {
    // a is denormal.
    // renormalize the significand and clear the leading bit, then insert
    // the correct adjusted exponent in the destination type.
    // `scale` is the number of positions the leading set bit of aAbs lies
    // below the implicit-bit position of the source format.
    const int scale = src_rep_t_clz(aAbs) - src_rep_t_clz(srcMinNormal);
    absResult = (dst_rep_t)aAbs << (dstSigBits - srcSigBits + scale);
    absResult ^= dstMinNormal;
    const int resultExponent = dstExpBias - srcExpBias - scale + 1;
    absResult |= (dst_rep_t)resultExponent << dstSigBits;
  }

  else {
    // a is zero.
    absResult = 0;
  }

  // Apply the signbit to the absolute value.
  const dst_rep_t result = absResult | (dst_rep_t)sign << (dstBits - srcBits);
  return dstFromRep(result);
}
// See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file is a configuration header for soft-float routines in compiler-rt. // This file does not provide any part of the compiler-rt interface, but defines // many useful constants and utility routines that are used in the // implementation of the soft-float routines in compiler-rt. // // Assumes that float, double and long double correspond to the IEEE-754 // binary32, binary64 and binary 128 types, respectively, and that integer // endianness matches floating point endianness on the target platform. // //===----------------------------------------------------------------------===// #ifndef FP_LIB_HEADER #define FP_LIB_HEADER #include "int_lib.h" #include "int_math.h" #include "int_types.h" #include #include #include #if defined SINGLE_PRECISION typedef uint16_t half_rep_t; typedef uint32_t rep_t; typedef uint64_t twice_rep_t; typedef int32_t srep_t; typedef float fp_t; #define HALF_REP_C UINT16_C #define REP_C UINT32_C #define significandBits 23 static __inline int rep_clz(rep_t a) { return clzsi(a); } // 32x32 --> 64 bit multiply static __inline void wideMultiply(rep_t a, rep_t b, rep_t *hi, rep_t *lo) { const uint64_t product = (uint64_t)a * b; *hi = product >> 32; *lo = product; } COMPILER_RT_ABI fp_t __addsf3(fp_t a, fp_t b); #elif defined DOUBLE_PRECISION typedef uint32_t half_rep_t; typedef uint64_t rep_t; typedef int64_t srep_t; typedef double fp_t; #define HALF_REP_C UINT32_C #define REP_C UINT64_C #define significandBits 52 static __inline int rep_clz(rep_t a) { #if defined __LP64__ return __builtin_clzl(a); #else if (a & REP_C(0xffffffff00000000)) return clzsi(a >> 32); else return 32 + clzsi(a & REP_C(0xffffffff)); #endif } #define loWord(a) (a & 0xffffffffU) #define hiWord(a) (a >> 32) // 64x64 -> 128 wide multiply for platforms that don't have 
// Full 64x64 -> 128 bit multiply (DOUBLE_PRECISION variant).
//
// Stores the upper 64 bits of a * b in *hi and the lower 64 bits in *lo.
// Both operands are split into 32-bit halves with loWord/hiWord and the
// product is assembled from the four 32x32 -> 64 partial products.
static __inline void wideMultiply(rep_t a, rep_t b, rep_t *hi, rep_t *lo) {
  // Each of the component 32x32 -> 64 products
  const uint64_t plolo = loWord(a) * loWord(b);
  const uint64_t plohi = loWord(a) * hiWord(b);
  const uint64_t philo = hiWord(a) * loWord(b);
  const uint64_t phihi = hiWord(a) * hiWord(b);
  // Sum terms that contribute to lo in a way that allows us to get the carry
  // (r1 is the column of weight 2^32; anything above its low 32 bits is the
  // carry out of the low 64-bit word)
  const uint64_t r0 = loWord(plolo);
  const uint64_t r1 = hiWord(plolo) + loWord(plohi) + loWord(philo);
  *lo = r0 + (r1 << 32);
  // Sum terms contributing to hi with the carry from lo
  *hi = hiWord(plohi) + hiWord(philo) + hiWord(r1) + phihi;
}
// Full 128x128 -> 256 bit multiply (QUAD_PRECISION variant).
//
// Stores the upper 128 bits of a * b in *hi and the lower 128 bits in *lo.
// Each operand is split into four 32-bit words, Word_1 being the most
// significant and Word_4 the least. productIJ is Word_I(a) * Word_J(b);
// sumK collects the partial products of equal weight 2^(32*K).
static __inline void wideMultiply(rep_t a, rep_t b, rep_t *hi, rep_t *lo) {

  const uint64_t product11 = Word_1(a) * Word_1(b);
  const uint64_t product12 = Word_1(a) * Word_2(b);
  const uint64_t product13 = Word_1(a) * Word_3(b);
  const uint64_t product14 = Word_1(a) * Word_4(b);
  const uint64_t product21 = Word_2(a) * Word_1(b);
  const uint64_t product22 = Word_2(a) * Word_2(b);
  const uint64_t product23 = Word_2(a) * Word_3(b);
  const uint64_t product24 = Word_2(a) * Word_4(b);
  const uint64_t product31 = Word_3(a) * Word_1(b);
  const uint64_t product32 = Word_3(a) * Word_2(b);
  const uint64_t product33 = Word_3(a) * Word_3(b);
  const uint64_t product34 = Word_3(a) * Word_4(b);
  const uint64_t product41 = Word_4(a) * Word_1(b);
  const uint64_t product42 = Word_4(a) * Word_2(b);
  const uint64_t product43 = Word_4(a) * Word_3(b);
  const uint64_t product44 = Word_4(a) * Word_4(b);

  // Column sums, from weight 2^0 (sum0) up to weight 2^192 (sum6).
  const __uint128_t sum0 = (__uint128_t)product44;
  const __uint128_t sum1 = (__uint128_t)product34 + (__uint128_t)product43;
  const __uint128_t sum2 =
      (__uint128_t)product24 + (__uint128_t)product33 + (__uint128_t)product42;
  const __uint128_t sum3 = (__uint128_t)product14 + (__uint128_t)product23 +
                           (__uint128_t)product32 + (__uint128_t)product41;
  const __uint128_t sum4 =
      (__uint128_t)product13 + (__uint128_t)product22 + (__uint128_t)product31;
  const __uint128_t sum5 = (__uint128_t)product12 + (__uint128_t)product21;
  const __uint128_t sum6 = (__uint128_t)product11;

  // r0 gathers the contributions to bits 0..63 and r1 those to bits
  // 64..127; whatever r1 holds above its low 64 bits carries into *hi.
  const __uint128_t r0 = (sum0 & Word_FullMask) + ((sum1 & Word_LoMask) << 32);
  const __uint128_t r1 = (sum0 >> 64) + ((sum1 >> 32) & Word_FullMask) +
                         (sum2 & Word_FullMask) + ((sum3 << 32) & Word_HiMask);

  *lo = r0 + (r1 << 64);
  *hi = (r1 >> 64) + (sum1 >> 96) + (sum2 >> 64) + (sum3 >> 32) + sum4 +
        (sum5 << 32) + (sum6 << 64);
}
// Implements logb methods (logb, logbf, logbl) for IEEE-754. This avoids
// pulling in a libm dependency from compiler-rt, but is not meant to replace
// it (i.e. code calling logb() should get the one from libm, not this), hence
// the __compiler_rt prefix.
//
// Returns the unbiased exponent of `x` as a floating-point value. Special
// inputs: NaN and +inf are returned unchanged, -inf yields +inf, and zero
// yields -inf.
static __inline fp_t __compiler_rt_logbX(fp_t x) {
  rep_t rep = toRep(x);
  // Biased exponent field of x.
  int exp = (rep & exponentMask) >> significandBits;

  // Abnormal cases:
  // 1) +/- inf returns +inf; NaN returns NaN
  // 2) 0.0 returns -inf
  if (exp == maxExponent) {
    if (((rep & signBit) == 0) || (x != x)) {
      return x; // NaN or +inf: return x
    } else {
      return -x; // -inf: return -x
    }
  } else if (x == 0.0) {
    // 0.0: return -inf
    return fromRep(infRep | signBit);
  }

  if (exp != 0) {
    // Normal number
    return exp - exponentBias; // Unbias exponent
  } else {
    // Subnormal number; normalize and repeat
    // (the sign is cleared first so that only the significand is shifted;
    // `shift` is the number of bit positions normalize() moved it left)
    rep &= absMask;
    const int shift = 1 - normalize(&rep);
    exp = (rep & exponentMask) >> significandBits;
    return exp - exponentBias - shift; // Unbias exponent
  }
}
const rep_t sign = rep & signBit; if (exp >= maxExponent) { // Overflow, which could produce infinity or the largest-magnitude value, // depending on the rounding mode. return fromRep(sign | ((rep_t)(maxExponent - 1) << significandBits)) * 2.0f; } else if (exp <= 0) { // Subnormal or underflow. Use floating-point multiply to handle truncation // correctly. fp_t tmp = fromRep(sign | (REP_C(1) << significandBits) | sig); exp += exponentBias - 1; if (exp < 1) exp = 1; tmp *= fromRep((rep_t)exp << significandBits); return tmp; } else return fromRep(sign | ((rep_t)exp << significandBits) | sig); } #endif // !defined(QUAD_PRECISION) || defined(CRT_HAS_IEEE_TF) // Avoid using fmax from libm. static __inline fp_t __compiler_rt_fmaxX(fp_t x, fp_t y) { // If either argument is NaN, return the other argument. If both are NaN, // arbitrarily return the second one. Otherwise, if both arguments are +/-0, // arbitrarily return the first one. return (crt_isnan(x) || x < y) ? y : x; } #endif #if defined(SINGLE_PRECISION) static __inline fp_t __compiler_rt_logbf(fp_t x) { return __compiler_rt_logbX(x); } static __inline fp_t __compiler_rt_scalbnf(fp_t x, int y) { return __compiler_rt_scalbnX(x, y); } static __inline fp_t __compiler_rt_fmaxf(fp_t x, fp_t y) { #if defined(__aarch64__) // Use __builtin_fmaxf which turns into an fmaxnm instruction on AArch64. return __builtin_fmaxf(x, y); #else // __builtin_fmaxf frequently turns into a libm call, so inline the function. return __compiler_rt_fmaxX(x, y); #endif } #elif defined(DOUBLE_PRECISION) static __inline fp_t __compiler_rt_logb(fp_t x) { return __compiler_rt_logbX(x); } static __inline fp_t __compiler_rt_scalbn(fp_t x, int y) { return __compiler_rt_scalbnX(x, y); } static __inline fp_t __compiler_rt_fmax(fp_t x, fp_t y) { #if defined(__aarch64__) // Use __builtin_fmax which turns into an fmaxnm instruction on AArch64. return __builtin_fmax(x, y); #else // __builtin_fmax frequently turns into a libm call, so inline the function. 
return __compiler_rt_fmaxX(x, y); #endif } #elif defined(QUAD_PRECISION) && defined(CRT_HAS_TF_MODE) // The generic implementation only works for ieee754 floating point. For other // floating point types, continue to rely on the libm implementation for now. #if defined(CRT_HAS_IEEE_TF) static __inline tf_float __compiler_rt_logbtf(tf_float x) { return __compiler_rt_logbX(x); } static __inline tf_float __compiler_rt_scalbntf(tf_float x, int y) { return __compiler_rt_scalbnX(x, y); } static __inline tf_float __compiler_rt_fmaxtf(tf_float x, tf_float y) { return __compiler_rt_fmaxX(x, y); } #define __compiler_rt_logbl __compiler_rt_logbtf #define __compiler_rt_scalbnl __compiler_rt_scalbntf #define __compiler_rt_fmaxl __compiler_rt_fmaxtf #define crt_fabstf crt_fabsf128 #define crt_copysigntf crt_copysignf128 #elif defined(CRT_LDBL_128BIT) static __inline tf_float __compiler_rt_logbtf(tf_float x) { return crt_logbl(x); } static __inline tf_float __compiler_rt_scalbntf(tf_float x, int y) { return crt_scalbnl(x, y); } static __inline tf_float __compiler_rt_fmaxtf(tf_float x, tf_float y) { return crt_fmaxl(x, y); } #define __compiler_rt_logbl crt_logbl #define __compiler_rt_scalbnl crt_scalbnl #define __compiler_rt_fmaxl crt_fmaxl #define crt_fabstf crt_fabsl #define crt_copysigntf crt_copysignl #else #error Unsupported TF mode type #endif #endif // *_PRECISION #endif // FP_LIB_HEADER ================================================ FILE: codon/runtime/floatlib/fp_trunc.h ================================================ //=== lib/fp_trunc.h - high precision -> low precision conversion *- C -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Set source and destination precision setting // //===----------------------------------------------------------------------===// #ifndef FP_TRUNC_HEADER #define FP_TRUNC_HEADER #include "int_lib.h" #if defined SRC_SINGLE typedef float src_t; typedef uint32_t src_rep_t; #define SRC_REP_C UINT32_C static const int srcSigBits = 23; #elif defined SRC_DOUBLE typedef double src_t; typedef uint64_t src_rep_t; #define SRC_REP_C UINT64_C static const int srcSigBits = 52; #elif defined SRC_QUAD typedef long double src_t; typedef __uint128_t src_rep_t; #define SRC_REP_C (__uint128_t) static const int srcSigBits = 112; #else #error Source should be double precision or quad precision! #endif // end source precision #if defined DST_DOUBLE typedef double dst_t; typedef uint64_t dst_rep_t; #define DST_REP_C UINT64_C static const int dstSigBits = 52; #elif defined DST_SINGLE typedef float dst_t; typedef uint32_t dst_rep_t; #define DST_REP_C UINT32_C static const int dstSigBits = 23; #elif defined DST_HALF #ifdef COMPILER_RT_HAS_FLOAT16 typedef _Float16 dst_t; #else typedef uint16_t dst_t; #endif typedef uint16_t dst_rep_t; #define DST_REP_C UINT16_C static const int dstSigBits = 10; #elif defined DST_BFLOAT typedef __bf16 dst_t; typedef uint16_t dst_rep_t; #define DST_REP_C UINT16_C static const int dstSigBits = 7; #else #error Destination should be single precision or double precision! #endif // end destination precision // End of specialization parameters. Two helper routines for conversion to and // from the representation of floating-point data as integer values follow. 
static __inline src_rep_t srcToRep(src_t x) { const union { src_t f; src_rep_t i; } rep = {.f = x}; return rep.i; } static __inline dst_t dstFromRep(dst_rep_t x) { const union { dst_t f; dst_rep_t i; } rep = {.i = x}; return rep.f; } #endif // FP_TRUNC_HEADER ================================================ FILE: codon/runtime/floatlib/fp_trunc_impl.inc ================================================ //= lib/fp_trunc_impl.inc - high precision -> low precision conversion *-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a fairly generic conversion from a wider to a narrower // IEEE-754 floating-point type in the default (round to nearest, ties to even) // rounding mode. The constants and types defined following the includes below // parameterize the conversion. // // This routine can be trivially adapted to support conversions to // half-precision or from quad-precision. It does not support types that don't // use the usual IEEE-754 interchange formats; specifically, some work would be // needed to adapt it to (for example) the Intel 80-bit format or PowerPC // double-double format. // // Note please, however, that this implementation is only intended to support // *narrowing* operations; if you need to convert to a *wider* floating-point // type (e.g. float -> double), then this routine will not do what you want it // to. // // It also requires that integer types at least as large as both formats // are available on the target platform; this may pose a problem when trying // to add support for quad on some 32-bit systems, for example. // // Finally, the following assumptions are made: // // 1. Floating-point types and integer types have the same endianness on the // target platform. // // 2. 
Quiet NaNs, if supported, are indicated by the leading bit of the // significand field being set. // //===----------------------------------------------------------------------===// #include "fp_trunc.h" static __inline dst_t __truncXfYf2__(src_t a) { // Various constants whose values follow from the type parameters. // Any reasonable optimizer will fold and propagate all of these. const int srcBits = sizeof(src_t) * CHAR_BIT; const int srcExpBits = srcBits - srcSigBits - 1; const int srcInfExp = (1 << srcExpBits) - 1; const int srcExpBias = srcInfExp >> 1; const src_rep_t srcMinNormal = SRC_REP_C(1) << srcSigBits; const src_rep_t srcSignificandMask = srcMinNormal - 1; const src_rep_t srcInfinity = (src_rep_t)srcInfExp << srcSigBits; const src_rep_t srcSignMask = SRC_REP_C(1) << (srcSigBits + srcExpBits); const src_rep_t srcAbsMask = srcSignMask - 1; const src_rep_t roundMask = (SRC_REP_C(1) << (srcSigBits - dstSigBits)) - 1; const src_rep_t halfway = SRC_REP_C(1) << (srcSigBits - dstSigBits - 1); const src_rep_t srcQNaN = SRC_REP_C(1) << (srcSigBits - 1); const src_rep_t srcNaNCode = srcQNaN - 1; const int dstBits = sizeof(dst_t) * CHAR_BIT; const int dstExpBits = dstBits - dstSigBits - 1; const int dstInfExp = (1 << dstExpBits) - 1; const int dstExpBias = dstInfExp >> 1; const int underflowExponent = srcExpBias + 1 - dstExpBias; const int overflowExponent = srcExpBias + dstInfExp - dstExpBias; const src_rep_t underflow = (src_rep_t)underflowExponent << srcSigBits; const src_rep_t overflow = (src_rep_t)overflowExponent << srcSigBits; const dst_rep_t dstQNaN = DST_REP_C(1) << (dstSigBits - 1); const dst_rep_t dstNaNCode = dstQNaN - 1; // Break a into a sign and representation of the absolute value. 
const src_rep_t aRep = srcToRep(a); const src_rep_t aAbs = aRep & srcAbsMask; const src_rep_t sign = aRep & srcSignMask; dst_rep_t absResult; if (aAbs - underflow < aAbs - overflow) { // The exponent of a is within the range of normal numbers in the // destination format. We can convert by simply right-shifting with // rounding and adjusting the exponent. absResult = aAbs >> (srcSigBits - dstSigBits); absResult -= (dst_rep_t)(srcExpBias - dstExpBias) << dstSigBits; const src_rep_t roundBits = aAbs & roundMask; // Round to nearest. if (roundBits > halfway) absResult++; // Tie to even. else if (roundBits == halfway) absResult += absResult & 1; } else if (aAbs > srcInfinity) { // a is NaN. // Conjure the result by beginning with infinity, setting the qNaN // bit and inserting the (truncated) trailing NaN field. absResult = (dst_rep_t)dstInfExp << dstSigBits; absResult |= dstQNaN; absResult |= ((aAbs & srcNaNCode) >> (srcSigBits - dstSigBits)) & dstNaNCode; } else if (aAbs >= overflow) { // a overflows to infinity. absResult = (dst_rep_t)dstInfExp << dstSigBits; } else { // a underflows on conversion to the destination type or is an exact // zero. The result may be a denormal or zero. Extract the exponent // to get the shift amount for the denormalization. const int aExp = aAbs >> srcSigBits; const int shift = srcExpBias - dstExpBias - aExp + 1; const src_rep_t significand = (aRep & srcSignificandMask) | srcMinNormal; // Right shift by the denormalization amount with sticky. if (shift > srcSigBits) { absResult = 0; } else { const bool sticky = (significand << (srcBits - shift)) != 0; src_rep_t denormalizedSignificand = significand >> shift | sticky; absResult = denormalizedSignificand >> (srcSigBits - dstSigBits); const src_rep_t roundBits = denormalizedSignificand & roundMask; // Round to nearest if (roundBits > halfway) absResult++; // Ties to even else if (roundBits == halfway) absResult += absResult & 1; } } // Apply the signbit to the absolute value. 
const dst_rep_t result = absResult | sign >> (srcBits - dstBits); return dstFromRep(result); } ================================================ FILE: codon/runtime/floatlib/int_endianness.h ================================================ //===-- int_endianness.h - configuration header for compiler-rt -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file is a configuration header for compiler-rt. // This file is not part of the interface of this library. // //===----------------------------------------------------------------------===// #ifndef INT_ENDIANNESS_H #define INT_ENDIANNESS_H #if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \ defined(__ORDER_LITTLE_ENDIAN__) // Clang and GCC provide built-in endianness definitions. #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ #define _YUGA_LITTLE_ENDIAN 0 #define _YUGA_BIG_ENDIAN 1 #elif __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ #define _YUGA_LITTLE_ENDIAN 1 #define _YUGA_BIG_ENDIAN 0 #endif // __BYTE_ORDER__ #else // Compilers other than Clang or GCC. #if defined(__SVR4) && defined(__sun) #include #if defined(_BIG_ENDIAN) #define _YUGA_LITTLE_ENDIAN 0 #define _YUGA_BIG_ENDIAN 1 #elif defined(_LITTLE_ENDIAN) #define _YUGA_LITTLE_ENDIAN 1 #define _YUGA_BIG_ENDIAN 0 #else // !_LITTLE_ENDIAN #error "unknown endianness" #endif // !_LITTLE_ENDIAN #endif // Solaris // .. 
#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__DragonFly__) || \ defined(__minix) #include #if _BYTE_ORDER == _BIG_ENDIAN #define _YUGA_LITTLE_ENDIAN 0 #define _YUGA_BIG_ENDIAN 1 #elif _BYTE_ORDER == _LITTLE_ENDIAN #define _YUGA_LITTLE_ENDIAN 1 #define _YUGA_BIG_ENDIAN 0 #endif // _BYTE_ORDER #endif // *BSD #if defined(__OpenBSD__) #include #if _BYTE_ORDER == _BIG_ENDIAN #define _YUGA_LITTLE_ENDIAN 0 #define _YUGA_BIG_ENDIAN 1 #elif _BYTE_ORDER == _LITTLE_ENDIAN #define _YUGA_LITTLE_ENDIAN 1 #define _YUGA_BIG_ENDIAN 0 #endif // _BYTE_ORDER #endif // OpenBSD // .. // Mac OSX has __BIG_ENDIAN__ or __LITTLE_ENDIAN__ automatically set by the // compiler (at least with GCC) #if defined(__APPLE__) || defined(__ellcc__) #ifdef __BIG_ENDIAN__ #if __BIG_ENDIAN__ #define _YUGA_LITTLE_ENDIAN 0 #define _YUGA_BIG_ENDIAN 1 #endif #endif // __BIG_ENDIAN__ #ifdef __LITTLE_ENDIAN__ #if __LITTLE_ENDIAN__ #define _YUGA_LITTLE_ENDIAN 1 #define _YUGA_BIG_ENDIAN 0 #endif #endif // __LITTLE_ENDIAN__ #endif // Mac OSX // .. #if defined(_WIN32) #define _YUGA_LITTLE_ENDIAN 1 #define _YUGA_BIG_ENDIAN 0 #endif // Windows #endif // Clang or GCC. // . #if !defined(_YUGA_LITTLE_ENDIAN) || !defined(_YUGA_BIG_ENDIAN) #error Unable to determine endian #endif // Check we found an endianness correctly. #endif // INT_ENDIANNESS_H ================================================ FILE: codon/runtime/floatlib/int_lib.h ================================================ //===-- int_lib.h - configuration header for compiler-rt -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file is a configuration header for compiler-rt. // This file is not part of the interface of this library. 
// //===----------------------------------------------------------------------===// #ifndef INT_LIB_H #define INT_LIB_H // Assumption: Signed integral is 2's complement. // Assumption: Right shift of signed negative is arithmetic shift. // Assumption: Endianness is little or big (not mixed). // ABI macro definitions #if __ARM_EABI__ #ifdef COMPILER_RT_ARMHF_TARGET #define COMPILER_RT_ABI #else #define COMPILER_RT_ABI __attribute__((__pcs__("aapcs"))) #endif #else #define COMPILER_RT_ABI #endif #define AEABI_RTABI __attribute__((__pcs__("aapcs"))) #if defined(_MSC_VER) && !defined(__clang__) #define ALWAYS_INLINE __forceinline #define NOINLINE __declspec(noinline) #define NORETURN __declspec(noreturn) #define UNUSED #else #define ALWAYS_INLINE __attribute__((always_inline)) #define NOINLINE __attribute__((noinline)) #define NORETURN __attribute__((noreturn)) #define UNUSED __attribute__((unused)) #endif #define STR(a) #a #define XSTR(a) STR(a) #define SYMBOL_NAME(name) XSTR(__USER_LABEL_PREFIX__) #name #if defined(__ELF__) || defined(__MINGW32__) || defined(__wasm__) || defined(_AIX) #define COMPILER_RT_ALIAS(name, aliasname) \ COMPILER_RT_ABI __typeof(name) aliasname __attribute__((__alias__(#name))); #elif defined(__APPLE__) #if defined(VISIBILITY_HIDDEN) #define COMPILER_RT_ALIAS_VISIBILITY(name) \ __asm__(".private_extern " SYMBOL_NAME(name)); #else #define COMPILER_RT_ALIAS_VISIBILITY(name) #endif #define COMPILER_RT_ALIAS(name, aliasname) \ __asm__(".globl " SYMBOL_NAME(aliasname)); \ COMPILER_RT_ALIAS_VISIBILITY(aliasname) \ __asm__(SYMBOL_NAME(aliasname) " = " SYMBOL_NAME(name)); \ COMPILER_RT_ABI __typeof(name) aliasname; #elif defined(_WIN32) #define COMPILER_RT_ALIAS(name, aliasname) #else #error Unsupported target #endif #if (defined(__FreeBSD__) || defined(__NetBSD__)) && \ (defined(_KERNEL) || defined(_STANDALONE)) // // Kernel and boot environment can't use normal headers, // so use the equivalent system headers. 
// NB: FreeBSD (and OpenBSD) deprecate machine/limits.h in // favour of sys/limits.h, so prefer the former, but fall // back on the latter if not available since NetBSD only has // the latter. // #if defined(__has_include) && __has_include() #include #else #include #endif #include #include #else // Include the standard compiler builtin headers we use functionality from. #include #include #include #include #endif // Include the commonly used internal type definitions. #include "int_types.h" // Include internal utility function declarations. #include "int_util.h" COMPILER_RT_ABI int __paritysi2(si_int a); COMPILER_RT_ABI int __paritydi2(di_int a); COMPILER_RT_ABI di_int __divdi3(di_int a, di_int b); COMPILER_RT_ABI si_int __divsi3(si_int a, si_int b); COMPILER_RT_ABI su_int __udivsi3(su_int n, su_int d); COMPILER_RT_ABI su_int __udivmodsi4(su_int a, su_int b, su_int *rem); COMPILER_RT_ABI du_int __udivmoddi4(du_int a, du_int b, du_int *rem); #ifdef CRT_HAS_128BIT COMPILER_RT_ABI int __clzti2(ti_int a); COMPILER_RT_ABI tu_int __udivmodti4(tu_int a, tu_int b, tu_int *rem); #endif // Definitions for builtins unavailable on MSVC #if defined(_MSC_VER) && !defined(__clang__) #include int __inline __builtin_ctz(uint32_t value) { unsigned long trailing_zero = 0; if (_BitScanForward(&trailing_zero, value)) return trailing_zero; return 32; } int __inline __builtin_clz(uint32_t value) { unsigned long leading_zero = 0; if (_BitScanReverse(&leading_zero, value)) return 31 - leading_zero; return 32; } #if defined(_M_ARM) || defined(_M_X64) int __inline __builtin_clzll(uint64_t value) { unsigned long leading_zero = 0; if (_BitScanReverse64(&leading_zero, value)) return 63 - leading_zero; return 64; } #else int __inline __builtin_clzll(uint64_t value) { if (value == 0) return 64; uint32_t msh = (uint32_t)(value >> 32); uint32_t lsh = (uint32_t)(value & 0xFFFFFFFF); if (msh != 0) return __builtin_clz(msh); return 32 + __builtin_clz(lsh); } #endif #define __builtin_clzl 
__builtin_clzll bool __inline __builtin_sadd_overflow(int x, int y, int *result) { if ((x < 0) != (y < 0)) { *result = x + y; return false; } int tmp = (unsigned int)x + (unsigned int)y; if ((tmp < 0) != (x < 0)) return true; *result = tmp; return false; } #endif // defined(_MSC_VER) && !defined(__clang__) #endif // INT_LIB_H ================================================ FILE: codon/runtime/floatlib/int_math.h ================================================ //===-- int_math.h - internal math inlines --------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file is not part of the interface of this library. // // This file defines substitutes for the libm functions used in some of the // compiler-rt implementations, defined in such a way that there is not a direct // dependency on libm or math.h. Instead, we use the compiler builtin versions // where available. This reduces our dependencies on the system SDK by foisting // the responsibility onto the compiler. // //===----------------------------------------------------------------------===// #ifndef INT_MATH_H #define INT_MATH_H #ifndef __has_builtin #define __has_builtin(x) 0 #endif #if defined(_MSC_VER) && !defined(__clang__) #include #include #endif #if defined(_MSC_VER) && !defined(__clang__) #define CRT_INFINITY INFINITY #else #define CRT_INFINITY __builtin_huge_valf() #endif #if defined(_MSC_VER) && !defined(__clang__) #define crt_isfinite(x) _finite((x)) #define crt_isinf(x) !_finite((x)) #define crt_isnan(x) _isnan((x)) #else // Define crt_isfinite in terms of the builtin if available, otherwise provide // an alternate version in terms of our other functions. 
This supports some
// versions of GCC which didn't have __builtin_isfinite.
#if __has_builtin(__builtin_isfinite)
#define crt_isfinite(x) __builtin_isfinite((x))
#elif defined(__GNUC__)
// Statement-expression fallback: evaluates x once, then reports "finite"
// when the value is neither infinite nor NaN.
#define crt_isfinite(x)                                                        \
  __extension__(({                                                             \
    __typeof((x)) x_ = (x);                                                    \
    !crt_isinf(x_) && !crt_isnan(x_);                                          \
  }))
#else
#error "Do not know how to check for infinity"
#endif // __has_builtin(__builtin_isfinite)
#define crt_isinf(x) __builtin_isinf((x))
#define crt_isnan(x) __builtin_isnan((x))
#endif // _MSC_VER

// copysign for double / float / long double: the MSVC branch calls the C
// library functions, every other compiler uses the builtins.
#if defined(_MSC_VER) && !defined(__clang__)
#define crt_copysign(x, y) copysign((x), (y))
#define crt_copysignf(x, y) copysignf((x), (y))
#define crt_copysignl(x, y) copysignl((x), (y))
#else
#define crt_copysign(x, y) __builtin_copysign((x), (y))
#define crt_copysignf(x, y) __builtin_copysignf((x), (y))
#define crt_copysignl(x, y) __builtin_copysignl((x), (y))
#endif

// Absolute value for double / float / long double.
// NOTE(review): the MSVC branch maps crt_fabsl to fabs(), not fabsl();
// presumably fine where long double has the same width as double - confirm.
#if defined(_MSC_VER) && !defined(__clang__)
#define crt_fabs(x) fabs((x))
#define crt_fabsf(x) fabsf((x))
#define crt_fabsl(x) fabs((x))
#else
#define crt_fabs(x) __builtin_fabs((x))
#define crt_fabsf(x) __builtin_fabsf((x))
#define crt_fabsl(x) __builtin_fabsl((x))
#endif

// Maximum of two long doubles.
// NOTE(review): the MSVC branch uses the __max macro, whose NaN handling may
// differ from fmaxl - confirm this is acceptable for callers.
#if defined(_MSC_VER) && !defined(__clang__)
#define crt_fmaxl(x, y) __max((x), (y))
#else
#define crt_fmaxl(x, y) __builtin_fmaxl((x), (y))
#endif

// logb for long double.
#if defined(_MSC_VER) && !defined(__clang__)
#define crt_logbl(x) logbl((x))
#else
#define crt_logbl(x) __builtin_logbl((x))
#endif

// scalbn for long double.
#if defined(_MSC_VER) && !defined(__clang__)
#define crt_scalbnl(x, y) scalbnl((x), (y))
#else
#define crt_scalbnl(x, y) __builtin_scalbnl((x), (y))
#endif

#endif // INT_MATH_H

================================================
FILE: codon/runtime/floatlib/int_types.h
================================================
//===-- int_types.h - internal type definitions for compiler-rt ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file is not part of the interface of this library. // // This file defines various standard types, most importantly a number of unions // used to access parts of larger types. // //===----------------------------------------------------------------------===// #ifndef INT_TYPES_H #define INT_TYPES_H #include "int_endianness.h" // si_int is defined in Linux sysroot's asm-generic/siginfo.h #ifdef si_int #undef si_int #endif typedef int32_t si_int; typedef uint32_t su_int; #if UINT_MAX == 0xFFFFFFFF #define clzsi __builtin_clz #define ctzsi __builtin_ctz #elif ULONG_MAX == 0xFFFFFFFF #define clzsi __builtin_clzl #define ctzsi __builtin_ctzl #else #error could not determine appropriate clzsi macro for this system #endif typedef int64_t di_int; typedef uint64_t du_int; typedef union { di_int all; struct { #if _YUGA_LITTLE_ENDIAN su_int low; si_int high; #else si_int high; su_int low; #endif // _YUGA_LITTLE_ENDIAN } s; } dwords; typedef union { du_int all; struct { #if _YUGA_LITTLE_ENDIAN su_int low; su_int high; #else su_int high; su_int low; #endif // _YUGA_LITTLE_ENDIAN } s; } udwords; #if defined(__LP64__) || defined(__wasm__) || defined(__mips64) || \ defined(__SIZEOF_INT128__) || defined(_WIN64) #define CRT_HAS_128BIT #endif // MSVC doesn't have a working 128bit integer type. Users should really compile // compiler-rt with clang, but if they happen to be doing a standalone build for // asan or something else, disable the 128 bit parts so things sort of work. 
#if defined(_MSC_VER) && !defined(__clang__)
#undef CRT_HAS_128BIT
#endif

#ifdef CRT_HAS_128BIT
// 128-bit signed / unsigned integers, requested through the TI machine mode.
typedef int ti_int __attribute__((mode(TI)));
typedef unsigned tu_int __attribute__((mode(TI)));

// Signed 128-bit value viewed whole (`all`) or as two 64-bit halves laid out
// in target byte order; only `high` is signed.
typedef union {
  ti_int all;
  struct {
#if _YUGA_LITTLE_ENDIAN
    du_int low;
    di_int high;
#else
    di_int high;
    du_int low;
#endif // _YUGA_LITTLE_ENDIAN
  } s;
} twords;

// Unsigned counterpart of twords.
typedef union {
  tu_int all;
  struct {
#if _YUGA_LITTLE_ENDIAN
    du_int low;
    du_int high;
#else
    du_int high;
    du_int low;
#endif // _YUGA_LITTLE_ENDIAN
  } s;
} utwords;

// Builds a signed 128-bit integer from its high half `h` and low half `l`.
// `l` is stored into the unsigned `low` member, so its bits are used as-is.
static __inline ti_int make_ti(di_int h, di_int l) {
  twords r;
  r.s.high = h;
  r.s.low = l;
  return r.all;
}

// Builds an unsigned 128-bit integer from its high half `h` and low half `l`.
static __inline tu_int make_tu(du_int h, du_int l) {
  utwords r;
  r.s.high = h;
  r.s.low = l;
  return r.all;
}

#endif // CRT_HAS_128BIT

// FreeBSD's boot environment does not support using floating-point and poisons
// the float and double keywords.
#if defined(__FreeBSD__) && defined(_STANDALONE)
#define CRT_HAS_FLOATING_POINT 0
#else
#define CRT_HAS_FLOATING_POINT 1
#endif

#if CRT_HAS_FLOATING_POINT
// A float overlaid with a 32-bit unsigned integer.
typedef union {
  su_int u;
  float f;
} float_bits;

// A double overlaid with a 64-bit value split into two 32-bit words.
typedef union {
  udwords u;
  double f;
} double_bits;
#endif

// Two udwords in target byte order; `low` is the least significant one.
typedef struct {
#if _YUGA_LITTLE_ENDIAN
  udwords low;
  udwords high;
#else
  udwords high;
  udwords low;
#endif // _YUGA_LITTLE_ENDIAN
} uqwords;

// Check if the target supports 80 bit extended precision long doubles.
// Notably, on x86 Windows, MSVC only provides a 64-bit long double, but GCC
// still makes it 80 bits. Clang will match whatever compiler it is trying to
// be compatible with. On 32-bit x86 Android, long double is 64 bits, while on
// x86_64 Android, long double is 128 bits.
#if (defined(__i386__) || defined(__x86_64__)) &&                              \
    !(defined(_MSC_VER) || defined(__ANDROID__))
#define HAS_80_BIT_LONG_DOUBLE 1
#elif defined(__m68k__) || defined(__ia64__)
#define HAS_80_BIT_LONG_DOUBLE 1
#else
#define HAS_80_BIT_LONG_DOUBLE 0
#endif

#if CRT_HAS_FLOATING_POINT
// A long double overlaid with a uqwords.
typedef union {
  uqwords u;
  long double f;
} long_double_bits;

// Complex number types. With C99 or later the native _Complex types are
// used; otherwise plain structs with `real` and `imaginary` members stand
// in. COMPLEX_REAL / COMPLEX_IMAGINARY access the parts in either case.
#if __STDC_VERSION__ >= 199901L
typedef float _Complex Fcomplex;
typedef double _Complex Dcomplex;
typedef long double _Complex Lcomplex;

#define COMPLEX_REAL(x) __real__(x)
#define COMPLEX_IMAGINARY(x) __imag__(x)
#else
typedef struct {
  float real, imaginary;
} Fcomplex;

typedef struct {
  double real, imaginary;
} Dcomplex;

typedef struct {
  long double real, imaginary;
} Lcomplex;

#define COMPLEX_REAL(x) (x).real
#define COMPLEX_IMAGINARY(x) (x).imaginary
#endif
#endif
#endif // INT_TYPES_H

================================================
FILE: codon/runtime/floatlib/int_util.h
================================================
//===-- int_util.h - internal utility functions ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file is not part of the interface of this library.
//
// This file defines non-inline utilities which are available for use in the
// library. The function definitions themselves are all contained in int_util.c
// which will always be compiled into any compiler-rt library.
//
//===----------------------------------------------------------------------===//

#ifndef INT_UTIL_H
#define INT_UTIL_H

/// \brief Trigger a program abort (or panic for kernel code).
#define compilerrt_abort() __compilerrt_abort_impl(__FILE__, __LINE__, __func__)

// Implementation behind compilerrt_abort(); receives the file, line and
// function name of the call site and is declared as never returning.
NORETURN void __compilerrt_abort_impl(const char *file, int line,
                                      const char *function);

// COMPILE_TIME_ASSERT(expr) declares an (unused) array typedef whose size is
// 1 when expr is true and -1 when it is false, so a false expr fails to
// compile. __COUNTER__ gives every use a distinct typedef name; the two extra
// macro levels make sure it is expanded before being pasted.
#define COMPILE_TIME_ASSERT(expr) COMPILE_TIME_ASSERT1(expr, __COUNTER__)
#define COMPILE_TIME_ASSERT1(expr, cnt) COMPILE_TIME_ASSERT2(expr, cnt)
#define COMPILE_TIME_ASSERT2(expr, cnt)                                        \
  typedef char ct_assert_##cnt[(expr) ? 1 : -1] UNUSED

// Force unrolling the code specified to be repeated N times.
// Each REPEAT_k_TIMES expands to REPEAT_(k-1)_TIMES followed by one more copy
// of the code; only N from 0 to 4 is defined.
#define REPEAT_0_TIMES(code_to_repeat) /* do nothing */
#define REPEAT_1_TIMES(code_to_repeat) code_to_repeat
#define REPEAT_2_TIMES(code_to_repeat)                                         \
  REPEAT_1_TIMES(code_to_repeat)                                               \
  code_to_repeat
#define REPEAT_3_TIMES(code_to_repeat)                                         \
  REPEAT_2_TIMES(code_to_repeat)                                               \
  code_to_repeat
#define REPEAT_4_TIMES(code_to_repeat)                                         \
  REPEAT_3_TIMES(code_to_repeat)                                               \
  code_to_repeat

// The extra level lets N itself be a macro that is expanded before pasting.
#define REPEAT_N_TIMES_(N, code_to_repeat) REPEAT_##N##_TIMES(code_to_repeat)
#define REPEAT_N_TIMES(N, code_to_repeat) REPEAT_N_TIMES_(N, code_to_repeat)

#endif // INT_UTIL_H

================================================
FILE: codon/runtime/floatlib/truncdfbf2.c
================================================
//===-- lib/truncdfbf2.c - double -> bfloat conversion ------------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Instantiate the generic narrowing conversion for double -> bfloat16.
#define SRC_DOUBLE
#define DST_BFLOAT
#include "fp_trunc_impl.inc"

// Converts a double to bfloat16 via the generic __truncXfYf2__ routine.
COMPILER_RT_ABI dst_t __truncdfbf2(double a) { return __truncXfYf2__(a); }

================================================
FILE: codon/runtime/floatlib/truncdfhf2.c
================================================
//===-- lib/truncdfhf2.c - double -> half conversion --------------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #define SRC_DOUBLE #define DST_HALF #include "fp_trunc_impl.inc" COMPILER_RT_ABI dst_t __truncdfhf2(double a) { return __truncXfYf2__(a); } #if defined(__ARM_EABI__) #if defined(COMPILER_RT_ARMHF_TARGET) AEABI_RTABI dst_t __aeabi_d2h(double a) { return __truncdfhf2(a); } #else COMPILER_RT_ALIAS(__truncdfhf2, __aeabi_d2h) #endif #endif ================================================ FILE: codon/runtime/floatlib/truncdfsf2.c ================================================ //===-- lib/truncdfsf2.c - double -> single conversion ------------*- C -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #define SRC_DOUBLE #define DST_SINGLE #include "fp_trunc_impl.inc" COMPILER_RT_ABI float __truncdfsf2(double a) { return __truncXfYf2__(a); } #if defined(__ARM_EABI__) #if defined(COMPILER_RT_ARMHF_TARGET) AEABI_RTABI float __aeabi_d2f(double a) { return __truncdfsf2(a); } #else COMPILER_RT_ALIAS(__truncdfsf2, __aeabi_d2f) #endif #endif ================================================ FILE: codon/runtime/floatlib/truncsfbf2.c ================================================ //===-- lib/truncsfbf2.c - single -> bfloat conversion ------------*- C -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #define SRC_SINGLE #define DST_BFLOAT #include "fp_trunc_impl.inc" COMPILER_RT_ABI dst_t __truncsfbf2(float a) { return __truncXfYf2__(a); } ================================================ FILE: codon/runtime/floatlib/truncsfhf2.c ================================================ //===-- lib/truncsfhf2.c - single -> half conversion --------------*- C -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #define SRC_SINGLE #define DST_HALF #include "fp_trunc_impl.inc" // Use a forwarding definition and noinline to implement a poor man's alias, // as there isn't a good cross-platform way of defining one. COMPILER_RT_ABI NOINLINE dst_t __truncsfhf2(float a) { return __truncXfYf2__(a); } COMPILER_RT_ABI dst_t __gnu_f2h_ieee(float a) { return __truncsfhf2(a); } #if defined(__ARM_EABI__) #if defined(COMPILER_RT_ARMHF_TARGET) AEABI_RTABI dst_t __aeabi_f2h(float a) { return __truncsfhf2(a); } #else COMPILER_RT_ALIAS(__truncsfhf2, __aeabi_f2h) #endif #endif ================================================ FILE: codon/runtime/floatlib/trunctfdf2.c ================================================ //===-- lib/truncdfsf2.c - quad -> double conversion --------------*- C -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #define QUAD_PRECISION #include "fp_lib.h" /* Only compiled when CRT_HAS_TF_MODE is defined after including fp_lib.h. */ #if defined(CRT_HAS_TF_MODE) #define SRC_QUAD #define DST_DOUBLE #include "fp_trunc_impl.inc" /* long double -> double entry point: forwards to __truncXfYf2__ for the SRC_QUAD/DST_DOUBLE selection above. */ COMPILER_RT_ABI double __trunctfdf2(long double a) { return __truncXfYf2__(a); } #endif ================================================ FILE: codon/runtime/floatlib/trunctfhf2.c ================================================ //===-- lib/trunctfhf2.c - quad -> half conversion ----------------*- C -*-===// // // The LLVM Compiler Infrastructure // // This file is dual licensed under the MIT and the University of Illinois Open // Source Licenses. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// #define QUAD_PRECISION #include "fp_lib.h" /* Requires both CRT_HAS_TF_MODE and COMPILER_RT_HAS_FLOAT16; otherwise this file contributes no symbols. */ #if defined(CRT_HAS_TF_MODE) && defined(COMPILER_RT_HAS_FLOAT16) #define SRC_QUAD #define DST_HALF #include "fp_trunc_impl.inc" /* long double -> _Float16 entry point: forwards to __truncXfYf2__ for the SRC_QUAD/DST_HALF selection above. */ COMPILER_RT_ABI _Float16 __trunctfhf2(long double a) { return __truncXfYf2__(a); } #endif ================================================ FILE: codon/runtime/floatlib/trunctfsf2.c ================================================ //===-- lib/trunctfsf2.c - quad -> single conversion --------------*- C -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #define QUAD_PRECISION #include "fp_lib.h" #if defined(CRT_HAS_TF_MODE) #define SRC_QUAD #define DST_SINGLE #include "fp_trunc_impl.inc" /* long double -> float entry point: forwards to __truncXfYf2__ for the SRC_QUAD/DST_SINGLE selection above. */ COMPILER_RT_ABI float __trunctfsf2(long double a) { return __truncXfYf2__(a); } #endif ================================================ FILE: codon/runtime/lib.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define GC_THREADS #include "codon/runtime/lib.h" #include #include #define FASTFLOAT_ALLOWS_LEADING_PLUS #define FASTFLOAT_SKIP_WHITE_SPACE #include "fast_float/fast_float.h" /* * General */ /* When nonzero, the allocation wrappers in this file use malloc/realloc/free and the GC setup in seq_init is compiled out. */ #define USE_STANDARD_MALLOC 0 // OpenMP patch with GC callbacks typedef int (*gc_setup_callback)(GC_stack_base *); typedef void (*gc_roots_callback)(void *, void *); extern "C" void __kmpc_set_gc_callbacks(gc_setup_callback get_stack_base, gc_setup_callback register_thread, gc_roots_callback add_roots, gc_roots_callback del_roots); void seq_exc_init(int flags); int seq_flags; /* Runtime initialization. Unless USE_STANDARD_MALLOC is set: initializes the GC, installs GC_ignore_warn_proc, allows thread registration and passes the GC stack-base/register/add-roots/remove-roots functions to __kmpc_set_gc_callbacks. Then calls seq_exc_init(flags) and stores flags in the global seq_flags. */ SEQ_FUNC void seq_init(int flags) { #if !USE_STANDARD_MALLOC GC_INIT(); GC_set_warn_proc(GC_ignore_warn_proc); GC_allow_register_threads(); __kmpc_set_gc_callbacks(GC_get_stack_base, (gc_setup_callback)GC_register_my_thread, GC_add_roots, GC_remove_roots); #endif seq_exc_init(flags); seq_flags = flags; } /* Returns getpid() widened to seq_int_t. */ SEQ_FUNC seq_int_t seq_pid() { return (seq_int_t)getpid(); } /* Time since the system_clock epoch. The result is held in a local named nanos, so presumably nanoseconds - NOTE(review): confirm the duration_cast target type. */ SEQ_FUNC seq_int_t seq_time() { auto duration = std::chrono::system_clock::now().time_since_epoch(); seq_int_t nanos = std::chrono::duration_cast(duration).count(); return nanos; } /* Same computation using steady_clock. */ SEQ_FUNC seq_int_t seq_time_monotonic() { auto duration = std::chrono::steady_clock::now().time_since_epoch(); seq_int_t nanos = 
std::chrono::duration_cast(duration).count(); return nanos; } /* Same computation using high_resolution_clock. */ SEQ_FUNC seq_int_t seq_time_highres() { auto duration = std::chrono::high_resolution_clock::now().time_since_epoch(); seq_int_t nanos = std::chrono::duration_cast(duration).count(); return nanos; } /* Field-by-field copy of a struct tm into a seq_time_t (the destination fields are the narrower integer types declared in lib.h). */ static void copy_time_c_to_seq(struct tm *x, seq_time_t *output) { output->year = x->tm_year; output->yday = x->tm_yday; output->sec = x->tm_sec; output->min = x->tm_min; output->hour = x->tm_hour; output->mday = x->tm_mday; output->mon = x->tm_mon; output->wday = x->tm_wday; output->isdst = x->tm_isdst; } /* Inverse of copy_time_c_to_seq: fills the nine listed struct tm fields from a seq_time_t; any other members of struct tm are left untouched. */ static void copy_time_seq_to_c(seq_time_t *x, struct tm *output) { output->tm_year = x->year; output->tm_yday = x->yday; output->tm_sec = x->sec; output->tm_min = x->min; output->tm_hour = x->hour; output->tm_mday = x->mday; output->tm_mon = x->mon; output->tm_wday = x->wday; output->tm_isdst = x->isdst; } /* Writes the local broken-down time for secs into *output. A negative secs means "use time(nullptr)". Returns false when the time value is (time_t)-1 or localtime_r fails, true otherwise. */ SEQ_FUNC bool seq_localtime(seq_int_t secs, seq_time_t *output) { struct tm result; time_t now = (secs >= 0 ? secs : time(nullptr)); if (now == (time_t)-1 || !localtime_r(&now, &result)) return false; copy_time_c_to_seq(&result, output); return true; } /* Same contract as seq_localtime, but converts with gmtime_r. */ SEQ_FUNC bool seq_gmtime(seq_int_t secs, seq_time_t *output) { struct tm result; time_t now = (secs >= 0 ? 
secs : time(nullptr)); if (now == (time_t)-1 || !gmtime_r(&now, &result)) return false; copy_time_c_to_seq(&result, output); return true; } /* Copies *time into a struct tm and returns mktime() of it. */ SEQ_FUNC seq_int_t seq_mktime(seq_time_t *time) { struct tm result; copy_time_seq_to_c(time, &result); return mktime(&result); } /* Sleeps the calling thread for secs (a double). */ SEQ_FUNC void seq_sleep(double secs) { std::this_thread::sleep_for(std::chrono::duration>(secs)); } extern char **environ; /* Returns the process environment array. */ SEQ_FUNC char **seq_env() { return environ; } /* * GC */ /* Allocation wrappers: each one maps to malloc when USE_STANDARD_MALLOC is nonzero and to the correspondingly named GC_MALLOC* macro otherwise. */ SEQ_FUNC void *seq_alloc(size_t n) { #if USE_STANDARD_MALLOC return malloc(n); #else return GC_MALLOC(n); #endif } SEQ_FUNC void *seq_alloc_atomic(size_t n) { #if USE_STANDARD_MALLOC return malloc(n); #else return GC_MALLOC_ATOMIC(n); #endif } SEQ_FUNC void *seq_alloc_uncollectable(size_t n) { #if USE_STANDARD_MALLOC return malloc(n); #else return GC_MALLOC_UNCOLLECTABLE(n); #endif } SEQ_FUNC void *seq_alloc_atomic_uncollectable(size_t n) { #if USE_STANDARD_MALLOC return malloc(n); #else return GC_MALLOC_ATOMIC_UNCOLLECTABLE(n); #endif } /* Resizes p to newsize. oldsize is accepted but not used by either branch. */ SEQ_FUNC void *seq_realloc(void *p, size_t newsize, size_t oldsize) { #if USE_STANDARD_MALLOC return realloc(p, newsize); #else return GC_REALLOC(p, newsize); #endif } SEQ_FUNC void seq_free(void *p) { #if USE_STANDARD_MALLOC free(p); #else GC_FREE(p); #endif } /* The remaining GC hooks forward to the GC and compile to empty functions when USE_STANDARD_MALLOC is nonzero. The finalizer is registered with null client data and without retrieving the previous finalizer. */ SEQ_FUNC void seq_register_finalizer(void *p, void (*f)(void *obj, void *data)) { #if !USE_STANDARD_MALLOC GC_REGISTER_FINALIZER(p, f, nullptr, nullptr, nullptr); #endif } SEQ_FUNC void seq_gc_add_roots(void *start, void *end) { #if !USE_STANDARD_MALLOC GC_add_roots(start, end); #endif } SEQ_FUNC void seq_gc_remove_roots(void *start, void *end) { #if !USE_STANDARD_MALLOC GC_remove_roots(start, end); #endif } SEQ_FUNC void seq_gc_clear_roots() { #if !USE_STANDARD_MALLOC GC_clear_roots(); #endif } SEQ_FUNC void seq_gc_exclude_static_roots(void *start, void *end) { #if !USE_STANDARD_MALLOC GC_exclude_static_roots(start, end); #endif } /* * String conversion */ /* Copies the bytes of s into a buffer from seq_alloc_atomic and returns them as a {length, pointer} seq_str_t; no terminating NUL is written. */ static seq_str_t string_conv(const std::string &s) { auto n = s.size(); auto 
*p = (char *)seq_alloc_atomic(n); memcpy(p, s.data(), n); return {(seq_int_t)n, p}; } /* Default textual form used when no format spec is given: "{}" in general, "{:g}" for double. */ template std::string default_format(T n) { return fmt::format(FMT_STRING("{}"), n); } template <> std::string default_format(double n) { return fmt::format(FMT_STRING("{:g}"), n); } /* Formats n according to format. An empty format uses default_format(n); otherwise the spec text is wrapped as "{:SPEC}" and applied with the "en_US.UTF-8" locale, which is constructed on every call. *error is cleared on entry; if a std::runtime_error is thrown it is set to true and the exception message is returned as the string instead. */ template seq_str_t fmt_conv(T n, seq_str_t format, bool *error) { *error = false; try { if (format.len == 0) { return string_conv(default_format(n)); } else { auto locale = std::locale("en_US.UTF-8"); std::string fstr(format.str, format.len); return string_conv(fmt::format( locale, fmt::runtime(fmt::format(FMT_STRING("{{:{}}}"), fstr)), n)); } } catch (const std::runtime_error &f) { *error = true; return string_conv(f.what()); } } /* Exported per-type wrappers around fmt_conv. */ SEQ_FUNC seq_str_t seq_str_int(seq_int_t n, seq_str_t format, bool *error) { return fmt_conv(n, format, error); } /* NOTE(review): takes a seq_int_t like seq_str_int; whether the value is reinterpreted as unsigned cannot be seen from this text - confirm the fmt_conv instantiation. */ SEQ_FUNC seq_str_t seq_str_uint(seq_int_t n, seq_str_t format, bool *error) { return fmt_conv(n, format, error); } SEQ_FUNC seq_str_t seq_str_float(double f, seq_str_t format, bool *error) { return fmt_conv(f, format, error); } SEQ_FUNC seq_str_t seq_str_ptr(void *p, seq_str_t format, bool *error) { return fmt_conv(fmt::ptr(p), format, error); } /* Copies s into a std::string before formatting. */ SEQ_FUNC seq_str_t seq_str_str(seq_str_t s, seq_str_t format, bool *error) { std::string t(s.str, s.len); return fmt_conv(t, format, error); } /* Parses an integer in the given base from s. On success *e points just past the parsed text; on any error *e is set to s.str. result is returned as-is even when parsing failed. */ SEQ_FUNC seq_int_t seq_int_from_str(seq_str_t s, const char **e, int base) { seq_int_t result; auto r = fast_float::from_chars(s.str, s.str + s.len, result, base); *e = (r.ec == std::errc()) ? r.ptr : s.str; return result; } /* Parses a double from s. Unlike the integer variant, result_out_of_range is also treated as a successful parse when setting *e. */ SEQ_FUNC double seq_float_from_str(seq_str_t s, const char **e) { double result; auto r = fast_float::from_chars(s.str, s.str + s.len, result); *e = (r.ec == std::errc() || r.ec == std::errc::result_out_of_range) ? 
r.ptr : s.str; return result; } /* * General I/O */ SEQ_FUNC seq_str_t seq_check_errno() { if (errno) { std::string msg = strerror(errno); auto *buf = (char *)seq_alloc_atomic(msg.size()); memcpy(buf, msg.data(), msg.size()); return {(seq_int_t)msg.size(), buf}; } return {0, nullptr}; } SEQ_FUNC void seq_print(seq_str_t str) { seq_print_full(str, stdout); } static std::ostringstream capture; static std::mutex captureLock; SEQ_FUNC void seq_print_full(seq_str_t str, FILE *fo) { if ((seq_flags & SEQ_FLAG_CAPTURE_OUTPUT) && (fo == stdout || fo == stderr)) { captureLock.lock(); capture.write(str.str, str.len); captureLock.unlock(); } else { fwrite(str.str, 1, (size_t)str.len, fo); } } std::string codon::runtime::getCapturedOutput() { std::string result = capture.str(); capture.str(""); return result; } SEQ_FUNC void *seq_stdin() { return stdin; } SEQ_FUNC void *seq_stdout() { return stdout; } SEQ_FUNC void *seq_stderr() { return stderr; } /* * Threading */ SEQ_FUNC void *seq_lock_new() { return (void *)new (seq_alloc_atomic(sizeof(std::timed_mutex))) std::timed_mutex(); } SEQ_FUNC bool seq_lock_acquire(void *lock, bool block, double timeout) { auto *m = (std::timed_mutex *)lock; if (timeout < 0.0) { if (block) { m->lock(); return true; } else { return m->try_lock(); } } else { return m->try_lock_for(std::chrono::duration(timeout)); } } SEQ_FUNC void seq_lock_release(void *lock) { auto *m = (std::timed_mutex *)lock; m->unlock(); } SEQ_FUNC void *seq_rlock_new() { return (void *)new (seq_alloc_atomic(sizeof(std::recursive_timed_mutex))) std::recursive_timed_mutex(); } SEQ_FUNC bool seq_rlock_acquire(void *lock, bool block, double timeout) { auto *m = (std::recursive_timed_mutex *)lock; if (timeout < 0.0) { if (block) { m->lock(); return true; } else { return m->try_lock(); } } else { return m->try_lock_for(std::chrono::duration(timeout)); } } SEQ_FUNC void seq_rlock_release(void *lock) { auto *m = (std::recursive_timed_mutex *)lock; m->unlock(); } 
================================================ FILE: codon/runtime/lib.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #pragma once #include #include #include #include #include #include #include #include #include #define SEQ_FLAG_DEBUG (1 << 0) // compiled/running in debug mode #define SEQ_FLAG_CAPTURE_OUTPUT (1 << 1) // capture writes to stdout/stderr #define SEQ_FLAG_STANDALONE (1 << 2) // compiled as a standalone object/binary #define SEQ_EXCEPTION_CLASS 0x6f626a0073657100 /* All runtime entry points are declared with C linkage. */ #define SEQ_FUNC extern "C" typedef int64_t seq_int_t; /* String as passed across the runtime boundary: len bytes starting at str. */ struct seq_str_t { seq_int_t len; char *str; }; /* Compact broken-down time; field names follow struct tm without the tm_ prefix. */ struct seq_time_t { int16_t year; int16_t yday; int8_t sec; int8_t min; int8_t hour; int8_t mday; int8_t mon; int8_t wday; int8_t isdst; }; /* Flags last passed to seq_init (a combination of the SEQ_FLAG_* bits above). */ SEQ_FUNC int seq_flags; SEQ_FUNC void seq_init(int flags); SEQ_FUNC seq_int_t seq_pid(); SEQ_FUNC seq_int_t seq_time(); SEQ_FUNC seq_int_t seq_time_monotonic(); SEQ_FUNC seq_int_t seq_time_highres(); SEQ_FUNC bool seq_localtime(seq_int_t secs, seq_time_t *output); SEQ_FUNC bool seq_gmtime(seq_int_t secs, seq_time_t *output); SEQ_FUNC seq_int_t seq_mktime(seq_time_t *time); SEQ_FUNC void seq_sleep(double secs); SEQ_FUNC char **seq_env(); SEQ_FUNC void seq_assert_failed(seq_str_t file, seq_int_t line); /* Memory management. */ SEQ_FUNC void *seq_alloc(size_t n); SEQ_FUNC void *seq_alloc_atomic(size_t n); SEQ_FUNC void *seq_alloc_uncollectable(size_t n); SEQ_FUNC void *seq_alloc_atomic_uncollectable(size_t n); SEQ_FUNC void *seq_realloc(void *p, size_t newsize, size_t oldsize); SEQ_FUNC void seq_free(void *p); SEQ_FUNC void seq_register_finalizer(void *p, void (*f)(void *obj, void *data)); SEQ_FUNC void seq_gc_add_roots(void *start, void *end); SEQ_FUNC void seq_gc_remove_roots(void *start, void *end); SEQ_FUNC void seq_gc_clear_roots(); SEQ_FUNC void seq_gc_exclude_static_roots(void *start, void *end); /* Exception support. */ SEQ_FUNC void *seq_alloc_exc(void *obj); SEQ_FUNC void seq_throw(void *exc); SEQ_FUNC _Unwind_Reason_Code seq_personality(int version, 
_Unwind_Action actions, uint64_t exceptionClass, _Unwind_Exception *exceptionObject, _Unwind_Context *context); SEQ_FUNC int64_t seq_exc_offset(); /* String conversion: each returns the formatted text and reports failure through *error. */ SEQ_FUNC seq_str_t seq_str_int(seq_int_t n, seq_str_t format, bool *error); SEQ_FUNC seq_str_t seq_str_uint(seq_int_t n, seq_str_t format, bool *error); SEQ_FUNC seq_str_t seq_str_float(double f, seq_str_t format, bool *error); SEQ_FUNC seq_str_t seq_str_ptr(void *p, seq_str_t format, bool *error); SEQ_FUNC seq_str_t seq_str_str(seq_str_t s, seq_str_t format, bool *error); /* Standard streams (returned as opaque pointers) and printing. */ SEQ_FUNC void *seq_stdin(); SEQ_FUNC void *seq_stdout(); SEQ_FUNC void *seq_stderr(); SEQ_FUNC void seq_print(seq_str_t str); SEQ_FUNC void seq_print_full(seq_str_t str, FILE *fo); /* Locks: opaque handles; a negative timeout selects the blocking/non-blocking path chosen by block. */ SEQ_FUNC void *seq_lock_new(); SEQ_FUNC bool seq_lock_acquire(void *lock, bool block, double timeout); SEQ_FUNC void seq_lock_release(void *lock); SEQ_FUNC void *seq_rlock_new(); SEQ_FUNC bool seq_rlock_acquire(void *lock, bool block, double timeout); SEQ_FUNC void seq_rlock_release(void *lock); namespace codon { namespace runtime { /* Exception type carrying, in addition to the std::runtime_error message: an output string, an error type name, a source location (file, line, col) and an optional backtrace. All accessors return copies. */ class JITError : public std::runtime_error { private: std::string output; std::string type; std::string file; int line; int col; std::vector backtrace; public: JITError(const std::string &output, const std::string &what, const std::string &type, const std::string &file, int line, int col, std::vector backtrace = {}) : std::runtime_error(what), output(output), type(type), file(file), line(line), col(col), backtrace(std::move(backtrace)) {} std::string getOutput() const { return output; } std::string getType() const { return type; } std::string getFile() const { return file; } int getLine() const { return line; } int getCol() const { return col; } std::vector getBacktrace() const { return backtrace; } }; std::string makeBacktraceFrameString(uintptr_t pc, const std::string &func = "", const std::string &file = "", int line = 0, int col = 0); /* Returns the text accumulated while SEQ_FLAG_CAPTURE_OUTPUT was active and clears it. */ std::string getCapturedOutput(); void setJITErrorCallback(std::function callback); } // namespace runtime } 
// namespace codon ================================================ FILE: codon/runtime/numpy/loops.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "codon/runtime/lib.h" #if (defined(__aarch64__) || defined(__arm64__)) && !defined(__ARM_FEATURE_SVE) #define HWY_DISABLED_TARGETS HWY_ALL_SVE #endif // clang-format off #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "codon/runtime/numpy/loops.cpp" #include "hwy/foreach_target.h" #include "hwy/highway.h" #include "hwy/contrib/math/math-inl.h" // clang-format on #include #include #include HWY_BEFORE_NAMESPACE(); namespace { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; /* Each *Functor below pairs a SIMD implementation (vector) with scalar overloads (scalar) for double and float that call the matching C math function. */ struct AcosFunctor { template static inline auto vector(const hn::ScalableTag d, const V &v) { return Acos(d, v); } static inline double scalar(const double x) { return acos(x); } static inline float scalar(const float x) { return acosf(x); } }; struct AcoshFunctor { /* Lanes equal to +inf yield +inf and lanes below 1 yield NaN; only the remaining lanes take the Acosh result. */ template static inline auto vector(const hn::ScalableTag d, const V &v) { const auto nan = Set(d, std::numeric_limits::quiet_NaN()); const auto pinf = Set(d, std::numeric_limits::infinity()); const auto pone = Set(d, static_cast(1.0)); return IfThenElse(v == pinf, pinf, IfThenElse(v < pone, nan, Acosh(d, v))); } static inline double scalar(const double x) { return acosh(x); } static inline float scalar(const float x) { return acoshf(x); } }; struct AsinFunctor { template static inline auto vector(const hn::ScalableTag d, const V &v) { return Asin(d, v); } static inline double scalar(const double x) { return asin(x); } static inline float scalar(const float x) { return asinf(x); } }; struct AsinhFunctor { /* NaN and infinite lanes are passed through unchanged. NOTE(review): pinf, ninf and zero are declared but never used. */ template static inline auto vector(const hn::ScalableTag d, const V &v) { const auto pinf = Set(d, std::numeric_limits::infinity()); const auto ninf = Set(d, -std::numeric_limits::infinity()); const auto zero = Set(d, static_cast(0.0)); return IfThenElse(IsNaN(v), v, IfThenElse(IsInf(v), v, Asinh(d, v))); } static 
inline double scalar(const double x) { return asinh(x); } static inline float scalar(const float x) { return asinhf(x); } }; struct AtanFunctor { /* +inf maps to +pi/2 and -inf to -pi/2; other lanes use Atan. */ template static inline auto vector(const hn::ScalableTag d, const V &v) { const auto ppi2 = Set(d, static_cast(+3.14159265358979323846264 / 2)); const auto npi2 = Set(d, static_cast(-3.14159265358979323846264 / 2)); const auto pinf = Set(d, std::numeric_limits::infinity()); const auto ninf = Set(d, -std::numeric_limits::infinity()); return IfThenElse(v == pinf, ppi2, IfThenElse(v == ninf, npi2, Atan(d, v))); } static inline double scalar(const double x) { return atan(x); } static inline float scalar(const float x) { return atanf(x); } }; struct AtanhFunctor { /* +1 maps to +inf, -1 to -inf, and lanes with |v| > 1 to NaN; other lanes use Atanh. NOTE(review): nzero is never used and, despite its name, is set to +0.0. */ template static inline auto vector(const hn::ScalableTag d, const V &v) { const auto nan = Set(d, std::numeric_limits::quiet_NaN()); const auto pinf = Set(d, std::numeric_limits::infinity()); const auto ninf = Set(d, -std::numeric_limits::infinity()); const auto pone = Set(d, static_cast(1.0)); const auto none = Set(d, static_cast(-1.0)); const auto nzero = Set(d, static_cast(0.0)); return IfThenElse( v == pone, pinf, IfThenElse(v == none, ninf, IfThenElse(Abs(v) > pone, nan, Atanh(d, v)))); } static inline double scalar(const double x) { return atanh(x); } static inline float scalar(const float x) { return atanhf(x); } }; struct Atan2Functor { /* Signed zeros are detected by comparing raw bit patterns (pzero / nzero are the integer encodings of +0.0 and -0.0 for the lane width), since a floating-point compare cannot tell them apart. */ template static inline auto vector(const hn::ScalableTag d, const V &v1, const V &v2) { constexpr bool kIsF32 = (sizeof(T) == 4); using TI = hwy::MakeSigned; const hn::Rebind> di; auto pzero = Set(di, kIsF32 ? static_cast(0x00000000L) : static_cast(0x0000000000000000LL)); auto nzero = Set(di, kIsF32 ? 
static_cast(0x80000000L) : static_cast(0x8000000000000000LL)); /* negneg: both arguments are -0.0 (result forced to -pi). posneg: v1 is +0.0 and v2 is -0.0 (result forced to +pi). All other lanes take Atan2(d, v1, v2). */ auto negneg = And(BitCast(di, v1) == nzero, BitCast(di, v2) == nzero); auto posneg = And(BitCast(di, v1) == pzero, BitCast(di, v2) == nzero); const auto ppi = Set(d, static_cast(+3.14159265358979323846264)); const auto npi = Set(d, static_cast(-3.14159265358979323846264)); return BitCast(d, IfThenElse(negneg, BitCast(di, npi), BitCast(di, IfThenElse(posneg, BitCast(di, ppi), BitCast(di, Atan2(d, v1, v2)))))); } static inline auto scalar(const double x, const double y) { return atan2(x, y); } static inline auto scalar(const float x, const float y) { return atan2f(x, y); } }; struct CosFunctor { /* Largest magnitude for which the SIMD Cos is used: 3.37e9 for double, 2.63e7f for float, T{} for any other type. */ template static inline T limit() { if constexpr (std::is_same_v) { return 3.37e9; } else if constexpr (std::is_same_v) { return 2.63e7f; } else { return T{}; } } /* If any lane lies outside [-LIMIT, LIMIT], every lane of this vector is recomputed with the scalar function; otherwise the SIMD Cos is used. NOTE(review): i and j are deduced as int but compared against the size_t lane count L. */ template static inline auto vector(const hn::ScalableTag d, const V &v) { // Values outside of [-LIMIT, LIMIT] are not valid for SIMD version. const T LIMIT = limit(); HWY_LANES_CONSTEXPR size_t L = hn::Lanes(d); T tmp[L]; Store(v, d, tmp); for (auto i = 0; i < L; ++i) { const auto x = tmp[i]; if (x < -LIMIT || x > LIMIT) { // Just use scalar version in this case. 
for (auto j = 0; j < L; ++j) tmp[j] = scalar(tmp[j]); return Load(d, tmp); } } return Cos(d, v); } static inline double scalar(const double x) { return cos(x); } static inline float scalar(const float x) { return cosf(x); } }; struct ExpFunctor { /* Threshold at or above which the vector path returns +inf directly: 1000.0 for double, 128.0f for float. */ template static inline T limit() { if constexpr (std::is_same_v) { return 1000.0; } else if constexpr (std::is_same_v) { return 128.0f; } else { return T{}; } } /* NaN lanes pass through, lanes >= limit() become +inf, the rest use Exp. */ template static inline auto vector(const hn::ScalableTag d, const V &v) { const auto lim = Set(d, limit()); const auto pinf = Set(d, std::numeric_limits::infinity()); return IfThenElse(IsNaN(v), v, IfThenElse(v >= lim, pinf, Exp(d, v))); } static inline double scalar(const double x) { return exp(x); } static inline float scalar(const float x) { return expf(x); } }; struct Exp2Functor { /* Same scheme as ExpFunctor with thresholds 2048.0 (double) and 128.0f (float). */ template static inline T limit() { if constexpr (std::is_same_v) { return 2048.0; } else if constexpr (std::is_same_v) { return 128.0f; } else { return T{}; } } template static inline auto vector(const hn::ScalableTag d, const V &v) { const auto lim = Set(d, limit()); const auto pinf = Set(d, std::numeric_limits::infinity()); return IfThenElse(IsNaN(v), v, IfThenElse(v >= lim, pinf, Exp2(d, v))); } static inline double scalar(const double x) { return exp2(x); } static inline float scalar(const float x) { return exp2f(x); } }; struct Expm1Functor { /* Same scheme as ExpFunctor with thresholds 1000.0 (double) and 128.0f (float). */ template static inline T limit() { if constexpr (std::is_same_v) { return 1000.0; } else if constexpr (std::is_same_v) { return 128.0f; } else { return T{}; } } template static inline auto vector(const hn::ScalableTag d, const V &v) { const auto lim = Set(d, limit()); const auto pinf = Set(d, std::numeric_limits::infinity()); return IfThenElse(IsNaN(v), v, IfThenElse(v >= lim, pinf, Expm1(d, v))); } static inline double scalar(const double x) { return expm1(x); } static inline float scalar(const float x) { return expm1f(x); } }; struct LogFunctor { /* Explicit special cases, checked in this order: 0 -> -inf, negative -> NaN, +inf -> +inf, NaN -> NaN; everything else uses Log. */ template static inline auto vector(const hn::ScalableTag d, const V &v) { const auto nan = 
Set(d, std::numeric_limits::quiet_NaN()); const auto pinf = Set(d, std::numeric_limits::infinity()); const auto ninf = Set(d, -std::numeric_limits::infinity()); const auto zero = Set(d, static_cast(0.0)); return IfThenElse( v == zero, ninf, IfThenElse(v < zero, nan, IfThenElse(v == pinf, pinf, IfThenElse(IsNaN(v), v, Log(d, v))))); } static inline double scalar(const double x) { return log(x); } static inline float scalar(const float x) { return logf(x); } }; struct Log10Functor { /* Same special-case chain as LogFunctor (0 -> -inf, negative -> NaN, +inf -> +inf, NaN -> NaN), ending in Log10. */ template static inline auto vector(const hn::ScalableTag d, const V &v) { const auto nan = Set(d, std::numeric_limits::quiet_NaN()); const auto pinf = Set(d, std::numeric_limits::infinity()); const auto ninf = Set(d, -std::numeric_limits::infinity()); const auto zero = Set(d, static_cast(0.0)); return IfThenElse( v == zero, ninf, IfThenElse(v < zero, nan, IfThenElse(v == pinf, pinf, IfThenElse(IsNaN(v), v, Log10(d, v))))); } static inline double scalar(const double x) { return log10(x); } static inline float scalar(const float x) { return log10f(x); } }; struct Log1pFunctor { /* Same chain shifted to the log1p domain: -1 -> -inf, below -1 -> NaN, +inf -> +inf, NaN -> NaN; everything else uses Log1p. */ template static inline auto vector(const hn::ScalableTag d, const V &v) { const auto nan = Set(d, std::numeric_limits::quiet_NaN()); const auto pinf = Set(d, std::numeric_limits::infinity()); const auto ninf = Set(d, -std::numeric_limits::infinity()); const auto none = Set(d, static_cast(-1.0)); return IfThenElse( v == none, ninf, IfThenElse(v < none, nan, IfThenElse(v == pinf, pinf, IfThenElse(IsNaN(v), v, Log1p(d, v))))); } static inline double scalar(const double x) { return log1p(x); } static inline float scalar(const float x) { return log1pf(x); } }; struct Log2Functor { /* Same special-case chain as LogFunctor, ending in Log2. */ template static inline auto vector(const hn::ScalableTag d, const V &v) { const auto nan = Set(d, std::numeric_limits::quiet_NaN()); const auto pinf = Set(d, std::numeric_limits::infinity()); const auto ninf = Set(d, -std::numeric_limits::infinity()); const auto zero = Set(d, static_cast(0.0)); return IfThenElse( v == zero, ninf, IfThenElse(v < 
zero, nan, IfThenElse(v == pinf, pinf, IfThenElse(IsNaN(v), v, Log2(d, v))))); } static inline double scalar(const double x) { return log2(x); } static inline float scalar(const float x) { return log2f(x); } }; struct SinFunctor { /* Largest magnitude for which the SIMD Sin is used: 6.74e9 for double, 5.30e8f for float, T{} for any other type. */ template static inline T limit() { if constexpr (std::is_same_v) { return 6.74e9; } else if constexpr (std::is_same_v) { return 5.30e8f; } else { return T{}; } } /* If any lane lies outside [-LIMIT, LIMIT], every lane of this vector is recomputed with the scalar function; otherwise the SIMD Sin is used. NOTE(review): i and j are deduced as int but compared against the size_t lane count L. */ template static inline auto vector(const hn::ScalableTag d, const V &v) { // Values outside of [-LIMIT, LIMIT] are not valid for SIMD version. const T LIMIT = limit(); HWY_LANES_CONSTEXPR size_t L = hn::Lanes(d); T tmp[L]; Store(v, d, tmp); for (auto i = 0; i < L; ++i) { const auto x = tmp[i]; if (x < -LIMIT || x > LIMIT) { // Just use scalar version in this case. for (auto j = 0; j < L; ++j) tmp[j] = scalar(tmp[j]); return Load(d, tmp); } } return Sin(d, v); } static inline double scalar(const double x) { return sin(x); } static inline float scalar(const float x) { return sinf(x); } }; struct SinhFunctor { /* Largest magnitude for which the SIMD Sinh is used: 709.0 for double, 88.7228f for float. */ template static inline T limit() { if constexpr (std::is_same_v) { return 709.0; } else if constexpr (std::is_same_v) { return 88.7228f; } else { return T{}; } } /* Same whole-vector scalar fallback as SinFunctor::vector. */ template static inline auto vector(const hn::ScalableTag d, const V &v) { // Values outside of [-LIMIT, LIMIT] are not valid for SIMD version. const T LIMIT = limit(); HWY_LANES_CONSTEXPR size_t L = hn::Lanes(d); T tmp[L]; Store(v, d, tmp); for (auto i = 0; i < L; ++i) { const auto x = tmp[i]; if (x < -LIMIT || x > LIMIT) { // Just use scalar version in this case. 
for (auto j = 0; j < L; ++j) tmp[j] = scalar(tmp[j]); return Load(d, tmp); } } return Sinh(d, v); } static inline double scalar(const double x) { return sinh(x); } static inline float scalar(const float x) { return sinhf(x); } }; struct TanhFunctor { template static inline auto vector(const hn::ScalableTag d, const V &v) { return Tanh(d, v); } static inline double scalar(const double x) { return tanh(x); } static inline float scalar(const float x) { return tanhf(x); } }; struct HypotFunctor { template static inline auto vector(const hn::ScalableTag d, const V &v1, const V &v2) { return Hypot(d, v1, v2); } static inline auto scalar(const double x, const double y) { return hypot(x, y); } static inline auto scalar(const float x, const float y) { return hypotf(x, y); } }; /* Applies functor F to n elements read from in and written to out. is and os are strides in bytes (they are added to char pointers). When both strides equal sizeof(T) the data is processed L lanes at a time by copying through the local buffer tmp; otherwise each group of L elements is gathered and scattered one element at a time using the byte strides. In both cases the final n % L elements go through F::scalar. */ template void UnaryLoop(const T *in, size_t is, T *out, size_t os, size_t n) { const hn::ScalableTag d; HWY_LANES_CONSTEXPR size_t L = hn::Lanes(d); T tmp[L]; size_t i; if (is == sizeof(T) && os == sizeof(T)) { for (i = 0; i + L <= n; i += L) { memcpy(tmp, in + i, L * sizeof(T)); auto vec = hn::Load(d, tmp); Store(F::template vector(d, vec), d, tmp); memcpy(out + i, tmp, L * sizeof(T)); } for (; i < n; ++i) out[i] = F::scalar(in[i]); } else { for (i = 0; i + L <= n; i += L) { for (size_t j = 0; j < L; ++j) tmp[j] = *(T *)((char *)in + (i + j) * is); auto vec = hn::Load(d, tmp); Store(F::template vector(d, vec), d, tmp); for (size_t j = 0; j < L; ++j) *(T *)((char *)out + (i + j) * os) = tmp[j]; } for (; i < n; ++i) *(T *)((char *)out + i * os) = F::scalar(*(T *)((char *)in + i * is)); } } /* Two-input counterpart of UnaryLoop. Four cases are distinguished by the byte strides: all contiguous; in1 broadcast (is1 == 0) with in2/out contiguous; in2 broadcast (is2 == 0) with in1/out contiguous; and the general strided case. The result vector is stored back into tmp1 before being copied to out. */ template void BinaryLoop(const T *in1, size_t is1, const T *in2, size_t is2, T *out, size_t os, size_t n) { const hn::ScalableTag d; HWY_LANES_CONSTEXPR size_t L = hn::Lanes(d); T tmp1[L]; T tmp2[L]; size_t i; if (is1 == sizeof(T) && is2 == sizeof(T) && os == sizeof(T)) { for (i = 0; i + L <= n; i += L) { memcpy(tmp1, in1 + i, L * sizeof(T)); memcpy(tmp2, in2 + i, L * sizeof(T)); auto vec1 = hn::Load(d, tmp1); auto 
vec2 = hn::Load(d, tmp2); Store(F::template vector(d, vec1, vec2), d, tmp1); memcpy(out + i, tmp1, L * sizeof(T)); } for (; i < n; ++i) out[i] = F::scalar(in1[i], in2[i]); } else if (is1 == 0 && is2 == sizeof(T) && os == sizeof(T)) { /* in1 is a single broadcast value: fill tmp1 with in1[0]. tmp1 is overwritten by each Store below, so only the first vector iteration loads the broadcast value from it - NOTE(review): confirm this is intended. */ for (size_t j = 0; j < L; ++j) tmp1[j] = in1[0]; for (i = 0; i + L <= n; i += L) { memcpy(tmp2, in2 + i, L * sizeof(T)); auto vec1 = hn::Load(d, tmp1); auto vec2 = hn::Load(d, tmp2); Store(F::template vector(d, vec1, vec2), d, tmp1); memcpy(out + i, tmp1, L * sizeof(T)); } for (; i < n; ++i) out[i] = F::scalar(in1[0], in2[i]); } else if (is1 == sizeof(T) && is2 == 0 && os == sizeof(T)) { /* in2 is a single broadcast value held in tmp2; results go through tmp1, so tmp2 stays intact across iterations. */ for (size_t j = 0; j < L; ++j) tmp2[j] = in2[0]; for (i = 0; i + L <= n; i += L) { memcpy(tmp1, in1 + i, L * sizeof(T)); auto vec1 = hn::Load(d, tmp1); auto vec2 = hn::Load(d, tmp2); Store(F::template vector(d, vec1, vec2), d, tmp1); memcpy(out + i, tmp1, L * sizeof(T)); } for (; i < n; ++i) out[i] = F::scalar(in1[i], in2[0]); } else { /* General case: gather both inputs and scatter the output element by element using the byte strides. */ for (i = 0; i + L <= n; i += L) { for (size_t j = 0; j < L; ++j) { tmp1[j] = *(T *)((char *)in1 + (i + j) * is1); tmp2[j] = *(T *)((char *)in2 + (i + j) * is2); } auto vec1 = hn::Load(d, tmp1); auto vec2 = hn::Load(d, tmp2); Store(F::template vector(d, vec1, vec2), d, tmp1); for (size_t j = 0; j < L; ++j) *(T *)((char *)out + (i + j) * os) = tmp1[j]; } for (; i < n; ++i) *(T *)((char *)out + i * os) = F::scalar(*(T *)((char *)in1 + i * is1), *(T *)((char *)in2 + i * is2)); } } /* Per-target concrete entry points. Each LoopXxx32 / LoopXxx64 forwards its arguments unchanged to UnaryLoop (one input) or BinaryLoop (two inputs). */ void LoopAcos32(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopAcos64(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopAcosh32(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopAcosh64(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopAsin32(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, 
out, os, n); } /* Forwarding wrappers continue: one float (32) and one double (64) function per operation. */ void LoopAsin64(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopAsinh32(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopAsinh64(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopAtan32(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopAtan64(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopAtanh32(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopAtanh64(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } /* "Atan2" + "32"/"64": the two-input atan2 wrappers, forwarding to BinaryLoop. */ void LoopAtan232(const float *in1, size_t is1, const float *in2, size_t is2, float *out, size_t os, size_t n) { BinaryLoop(in1, is1, in2, is2, out, os, n); } void LoopAtan264(const double *in1, size_t is1, const double *in2, size_t is2, double *out, size_t os, size_t n) { BinaryLoop(in1, is1, in2, is2, out, os, n); } void LoopCos32(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopCos64(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopExp32(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopExp64(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopExp232(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopExp264(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopExpm132(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopExpm164(const double *in, size_t is, double *out, size_t os, size_t n) { 
UnaryLoop(in, is, out, os, n); } /* Logarithm, sine, hyperbolic and hypot wrappers; the two Hypot functions take two inputs and forward to BinaryLoop, all others forward to UnaryLoop. */ void LoopLog32(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopLog64(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopLog1032(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopLog1064(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopLog1p32(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopLog1p64(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopLog232(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopLog264(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopSin32(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopSin64(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopSinh32(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopSinh64(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopTanh32(const float *in, size_t is, float *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopTanh64(const double *in, size_t is, double *out, size_t os, size_t n) { UnaryLoop(in, is, out, os, n); } void LoopHypot32(const float *in1, size_t is1, const float *in2, size_t is2, float *out, size_t os, size_t n) { BinaryLoop(in1, is1, in2, is2, out, os, n); } void LoopHypot64(const double *in1, size_t is1, const double *in2, size_t is2, double *out, size_t os, size_t n) { BinaryLoop(in1, is1, in2, is2, out, os, n); } } // namespace HWY_NAMESPACE } // namespace 
HWY_AFTER_NAMESPACE();

#if HWY_ONCE

// Dispatch tables: one HWY_EXPORT per kernel defined in the per-target
// namespace above, so the exported wrappers below can look the kernel up at
// run time.
HWY_EXPORT(LoopAcos32);
HWY_EXPORT(LoopAcos64);
HWY_EXPORT(LoopAcosh32);
HWY_EXPORT(LoopAcosh64);
HWY_EXPORT(LoopAsin32);
HWY_EXPORT(LoopAsin64);
HWY_EXPORT(LoopAsinh32);
HWY_EXPORT(LoopAsinh64);
HWY_EXPORT(LoopAtan32);
HWY_EXPORT(LoopAtan64);
HWY_EXPORT(LoopAtanh32);
HWY_EXPORT(LoopAtanh64);
HWY_EXPORT(LoopAtan232);
HWY_EXPORT(LoopAtan264);
HWY_EXPORT(LoopCos32);
HWY_EXPORT(LoopCos64);
HWY_EXPORT(LoopExp32);
HWY_EXPORT(LoopExp64);
HWY_EXPORT(LoopExp232);
HWY_EXPORT(LoopExp264);
HWY_EXPORT(LoopExpm132);
HWY_EXPORT(LoopExpm164);
HWY_EXPORT(LoopLog32);
HWY_EXPORT(LoopLog64);
HWY_EXPORT(LoopLog1032);
HWY_EXPORT(LoopLog1064);
HWY_EXPORT(LoopLog1p32);
HWY_EXPORT(LoopLog1p64);
HWY_EXPORT(LoopLog232);
HWY_EXPORT(LoopLog264);
HWY_EXPORT(LoopSin32);
HWY_EXPORT(LoopSin64);
HWY_EXPORT(LoopSinh32);
HWY_EXPORT(LoopSinh64);
HWY_EXPORT(LoopTanh32);
HWY_EXPORT(LoopTanh64);
HWY_EXPORT(LoopHypot32);
HWY_EXPORT(LoopHypot64);

// Exported entry points. Every wrapper fetches the kernel pointer from the
// matching dispatch table and calls it with the caller's arguments untouched.
// Unary wrappers take (in, is, out, os, n); binary wrappers take
// (in1, is1, in2, is2, out, os, n).

SEQ_FUNC void cnp_acos_float32(const float *in, size_t is, float *out, size_t os,
                               size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAcos32);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_acos_float64(const double *in, size_t is, double *out,
                               size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAcos64);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_acosh_float32(const float *in, size_t is, float *out,
                                size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAcosh32);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_acosh_float64(const double *in, size_t is, double *out,
                                size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAcosh64);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_asin_float32(const float *in, size_t is, float *out, size_t os,
                               size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAsin32);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_asin_float64(const double *in, size_t is, double *out,
                               size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAsin64);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_asinh_float32(const float *in, size_t is, float *out,
                                size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAsinh32);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_asinh_float64(const double *in, size_t is, double *out,
                                size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAsinh64);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_atan_float32(const float *in, size_t is, float *out, size_t os,
                               size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAtan32);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_atan_float64(const double *in, size_t is, double *out,
                               size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAtan64);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_atanh_float32(const float *in, size_t is, float *out,
                                size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAtanh32);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_atanh_float64(const double *in, size_t is, double *out,
                                size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAtanh64);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_atan2_float32(const float *in1, size_t is1, const float *in2,
                                size_t is2, float *out, size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAtan232);
  kernel(in1, is1, in2, is2, out, os, n);
}

SEQ_FUNC void cnp_atan2_float64(const double *in1, size_t is1, const double *in2,
                                size_t is2, double *out, size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopAtan264);
  kernel(in1, is1, in2, is2, out, os, n);
}

SEQ_FUNC void cnp_cos_float32(const float *in, size_t is, float *out, size_t os,
                              size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopCos32);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_cos_float64(const double *in, size_t is, double *out,
                              size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopCos64);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_exp_float32(const float *in, size_t is, float *out, size_t os,
                              size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopExp32);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_exp_float64(const double *in, size_t is, double *out,
                              size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopExp64);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_exp2_float32(const float *in, size_t is, float *out, size_t os,
                               size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopExp232);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_exp2_float64(const double *in, size_t is, double *out,
                               size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopExp264);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_expm1_float32(const float *in, size_t is, float *out,
                                size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopExpm132);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_expm1_float64(const double *in, size_t is, double *out,
                                size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopExpm164);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_log_float32(const float *in, size_t is, float *out, size_t os,
                              size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopLog32);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_log_float64(const double *in, size_t is, double *out,
                              size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopLog64);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_log10_float32(const float *in, size_t is, float *out,
                                size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopLog1032);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_log10_float64(const double *in, size_t is, double *out,
                                size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopLog1064);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_log1p_float32(const float *in, size_t is, float *out,
                                size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopLog1p32);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_log1p_float64(const double *in, size_t is, double *out,
                                size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopLog1p64);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_log2_float32(const float *in, size_t is, float *out, size_t os,
                               size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopLog232);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_log2_float64(const double *in, size_t is, double *out,
                               size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopLog264);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_sin_float32(const float *in, size_t is, float *out, size_t os,
                              size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopSin32);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_sin_float64(const double *in, size_t is, double *out,
                              size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopSin64);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_sinh_float32(const float *in, size_t is, float *out, size_t os,
                               size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopSinh32);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_sinh_float64(const double *in, size_t is, double *out,
                               size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopSinh64);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_tanh_float32(const float *in, size_t is, float *out, size_t os,
                               size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopTanh32);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_tanh_float64(const double *in, size_t is, double *out,
                               size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopTanh64);
  kernel(in, is, out, os, n);
}

SEQ_FUNC void cnp_hypot_float32(const float *in1, size_t is1, const float *in2,
                                size_t is2, float *out, size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopHypot32);
  kernel(in1, is1, in2, is2, out, os, n);
}

SEQ_FUNC void cnp_hypot_float64(const double *in1, size_t is1, const double *in2,
                                size_t is2, double *out, size_t os, size_t n) {
  const auto kernel = HWY_DYNAMIC_POINTER(LoopHypot64);
  kernel(in1, is1, in2, is2, out, os, n);
}

#endif // HWY_ONCE
================================================ FILE: codon/runtime/numpy/sort.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. #include "codon/runtime/lib.h" #include "hwy/contrib/sort/vqsort-inl.h" SEQ_FUNC void cnp_sort_int16(int16_t *data, int64_t n) { hwy::VQSort(data, n, hwy::SortAscending()); } SEQ_FUNC void cnp_sort_uint16(uint16_t *data, int64_t n) { hwy::VQSort(data, n, hwy::SortAscending()); } SEQ_FUNC void cnp_sort_int32(int32_t *data, int64_t n) { hwy::VQSort(data, n, hwy::SortAscending()); } SEQ_FUNC void cnp_sort_uint32(uint32_t *data, int64_t n) { hwy::VQSort(data, n, hwy::SortAscending()); } SEQ_FUNC void cnp_sort_int64(int64_t *data, int64_t n) { hwy::VQSort(data, n, hwy::SortAscending()); } SEQ_FUNC void cnp_sort_uint64(uint64_t *data, int64_t n) { hwy::VQSort(data, n, hwy::SortAscending()); } SEQ_FUNC void cnp_sort_uint128(hwy::uint128_t *data, int64_t n) { hwy::VQSort(data, n, hwy::SortAscending()); } SEQ_FUNC void cnp_sort_float32(float *data, int64_t n) { hwy::VQSort(data, n, hwy::SortAscending()); } SEQ_FUNC void cnp_sort_float64(double *data, int64_t n) { hwy::VQSort(data, n, hwy::SortAscending()); } ================================================ FILE: codon/runtime/numpy/zmath.cpp ================================================ // Copyright (C) 2022-2026 Exaloop Inc. // This file resolves ABI issues with single-precision complex // math functions. 
#include "codon/runtime/lib.h"
// NOTE(review): the header name of the following include is missing in this
// copy of the file.
#include

// Every cnp_c*f function below treats (r, i) as the real and imaginary parts
// of one complex value, applies the std:: function named in its body, and
// writes the result's real part to z[0] and its imaginary part to z[1]. `z`
// must therefore point to at least two floats.
SEQ_FUNC void cnp_cexpf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::exp(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_clogf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::log(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_csqrtf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::sqrt(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_ccoshf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::cosh(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_csinhf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::sinh(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_ctanhf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::tanh(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_cacoshf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::acosh(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_casinhf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::asinh(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_catanhf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::atanh(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_ccosf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::cos(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_csinf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::sin(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_ctanf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::tan(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_cacosf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::acos(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_casinf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::asin(x);
  z[0] = y.real();
  z[1] = y.imag();
}

SEQ_FUNC void cnp_catanf(float r, float i, float *z) {
  std::complex x(r, i);
  auto y = std::atan(x);
  z[0] = y.real();
  z[1] = y.imag();
}

================================================
FILE: codon/runtime/re.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "codon/runtime/lib.h"
// NOTE(review): the header names of the following includes are missing in
// this copy of the file.
#include
#include
#include
#include
#include
#include
#include

using Regex = re2::RE2;
using re2::StringPiece;

/*
 * Flags -- (!) must match Codon's
 */

#define ASCII (1 << 0)
#define DEBUG (1 << 1)
#define IGNORECASE (1 << 2)
#define LOCALE (1 << 3)
#define MULTILINE (1 << 4)
#define DOTALL (1 << 5)
#define VERBOSE (1 << 6)

// Translates the flag bitmask above into RE2 options. Error logging is
// switched off and the encoding is set to Latin-1 unconditionally; only
// IGNORECASE, MULTILINE and DOTALL change an option, the remaining flags are
// accepted but have no effect here.
static inline Regex::Options flags2opt(seq_int_t flags) {
  Regex::Options opt;
  opt.set_log_errors(false);
  opt.set_encoding(Regex::Options::Encoding::EncodingLatin1);

  if (flags & ASCII) {
    // nothing
  }

  if (flags & DEBUG) {
    // nothing
  }

  if (flags & IGNORECASE) {
    opt.set_case_sensitive(false);
  }

  if (flags & LOCALE) {
    // nothing
  }

  if (flags & MULTILINE) {
    opt.set_one_line(false);
  }

  if (flags & DOTALL) {
    opt.set_dot_nl(true);
  }

  if (flags & VERBOSE) {
    // nothing
  }

  return opt;
}

/*
 * Internal helpers & utilities
 */

// Offsets of one match group relative to the start of the searched string;
// the matching functions below use {-1, -1} for "no match".
struct Span {
  seq_int_t start;
  seq_int_t end;
};

// Allocator that obtains storage from seq_alloc_uncollectable() and releases
// it with seq_free(); used for the pattern cache below.
// NOTE(review): the template parameter lists are missing in this copy.
template  struct GCMapAllocator : public std::allocator {
  GCMapAllocator() = default;
  GCMapAllocator(GCMapAllocator const &) = default;

  template  GCMapAllocator(const GCMapAllocator &) noexcept {}

  KV *allocate(std::size_t n) { return (KV *)seq_alloc_uncollectable(n * sizeof(KV)); }

  // The element count `n` is not needed by seq_free().
  void deallocate(KV *p, std::size_t n) { seq_free(p); }

  template  struct rebind {
    using other = GCMapAllocator;
  };
};

// Copies the bytes of `p` into a buffer from seq_alloc_atomic() and returns
// it as a {length, pointer} string. Exactly p.size() bytes are allocated and
// copied, so no terminating NUL is stored.
static inline seq_str_t convert(const std::string &p) {
  seq_int_t n = p.size();
  auto *s = (char *)seq_alloc_atomic(n);
  std::memcpy(s, p.data(), n);
  return {n, s};
}

// Wraps a runtime string as a StringPiece over the same bytes (no copy).
static inline StringPiece str2sp(const seq_str_t &s) {
  return StringPiece(s.str, s.len);
}

// Cache key; `first` is used as the pattern text and `second` as the flags
// (see KeyEqual / KeyHash / get()).
// NOTE(review): the pair's template arguments are missing in this copy.
using Key = std::pair;

// Keys are equal when the flags match and the pattern bytes compare equal.
struct KeyEqual {
  bool operator()(const Key &a, const Key &b) const {
    return a.second == b.second && str2sp(a.first) == str2sp(b.first);
  }
};

// Hash of the pattern bytes XOR-ed with the flags.
struct KeyHash {
  std::size_t operator()(const Key &k) const {
    using sv = std::string_view;
    return std::hash()(sv(k.first.str, k.first.len)) ^ k.second;
  }
};

// Per-thread cache of compiled patterns (thread_local: every thread has its
// own map).
// NOTE(review): the map's template arguments are missing in this copy.
static thread_local std::unordered_map>> cache;

// Returns the cached Regex for (pattern, flags), constructing it in place on
// first use. The returned pointer refers to the cache entry itself.
static inline Regex *get(const seq_str_t &p, seq_int_t flags) {
  auto key = std::make_pair(p, flags);
  auto it = cache.find(key);
  if (it == cache.end()) {
    auto result = cache.emplace(std::piecewise_construct, std::forward_as_tuple(key),
                                std::forward_as_tuple(str2sp(p), flags2opt(flags)));
    return &result.first->second;
  } else {
    return &it->second;
  }
}

/*
 * Matching
 */

// Runs `re` against s[pos, endpos) with the given anchor mode and returns an
// array of (capturing groups + 1) spans allocated with seq_alloc_atomic();
// entry 0 is the whole match. Groups that did not participate, and every
// group when the match fails, are reported as {-1, -1}. Offsets are relative
// to s.str.
SEQ_FUNC Span *seq_re_match(Regex *re, seq_int_t anchor, seq_str_t s, seq_int_t pos,
                            seq_int_t endpos) {
  const int num_groups = re->NumberOfCapturingGroups() + 1; // need $0
  std::vector groups;
  groups.resize(num_groups);

  if (!re->Match(str2sp(s), pos, endpos, static_cast(anchor), groups.data(),
                 groups.size())) {
    // Ensure that groups are null before converting to spans!
    for (auto &it : groups) {
      it = StringPiece();
    }
  }

  auto *spans = (Span *)seq_alloc_atomic(num_groups * sizeof(Span));
  unsigned i = 0;
  for (const auto &it : groups) {
    if (it.data() == nullptr) {
      spans[i++] = {-1, -1};
    } else {
      spans[i++] = {static_cast(it.data() - s.str),
                    static_cast(it.data() - s.str + it.size())};
    }
  }

  return spans;
}

// Same as seq_re_match() but only reports the whole-match span, returned by
// value; {-1, -1} when there is no match.
SEQ_FUNC Span seq_re_match_one(Regex *re, seq_int_t anchor, seq_str_t s,
                               seq_int_t pos, seq_int_t endpos) {
  StringPiece m;
  if (!re->Match(str2sp(s), pos, endpos, static_cast(anchor), &m, 1))
    return {-1, -1};
  else
    return {static_cast(m.data() - s.str),
            static_cast(m.data() - s.str + m.size())};
}

/*
 * General functions
 */

// Returns `p` with regex metacharacters quoted (Regex::QuoteMeta).
SEQ_FUNC seq_str_t seq_re_escape(seq_str_t p) {
  return convert(Regex::QuoteMeta(str2sp(p)));
}

// Returns the (possibly cached) compiled pattern for `p` and `flags`. The
// result may be a pattern that failed to compile; see seq_re_pattern_error().
SEQ_FUNC Regex *seq_re_compile(seq_str_t p, seq_int_t flags) { return get(p, flags); }

// Empties the calling thread's pattern cache. Pointers previously returned by
// seq_re_compile() on this thread refer to entries this call destroys.
SEQ_FUNC void seq_re_purge() { cache.clear(); }

/*
 * Pattern methods
 */

// Number of capturing groups in the pattern (not counting the whole match).
SEQ_FUNC seq_int_t seq_re_pattern_groups(Regex *pattern) {
  return pattern->NumberOfCapturingGroups();
}

// Index of the named group `name`, or -1 when the pattern has no such group.
SEQ_FUNC seq_int_t seq_re_group_name_to_index(Regex *pattern, seq_str_t name) {
  const auto &mapping = pattern->NamedCapturingGroups();
  auto it = mapping.find(std::string(name.str, name.len));
  return (it != mapping.end()) ? it->second : -1;
}

// Name of the group at `index`, or the empty string {0, nullptr} when that
// group has no name.
SEQ_FUNC seq_str_t seq_re_group_index_to_name(Regex *pattern, seq_int_t index) {
  const auto &mapping = pattern->CapturingGroupNames();
  auto it = mapping.find(index);
  seq_str_t empty = {0, nullptr};
  return (it != mapping.end()) ? convert(it->second) : empty;
}

// Validates `rewrite` against the pattern. On failure the explanation is
// stored in *error; on success *error is left untouched.
SEQ_FUNC bool seq_re_check_rewrite_string(Regex *pattern, seq_str_t rewrite,
                                          seq_str_t *error) {
  std::string e;
  bool ans = pattern->CheckRewriteString(str2sp(rewrite), &e);
  if (!ans)
    *error = convert(e);
  return ans;
}

// Returns the compile error of `pattern`, or {0, nullptr} when it is valid.
SEQ_FUNC seq_str_t seq_re_pattern_error(Regex *pattern) {
  if (pattern->ok())
    return {0, nullptr};
  return convert(pattern->error());
}

================================================
FILE: codon/util/common.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "common.h"

#include "llvm/Support/Path.h"
// NOTE(review): the header names of the following includes are missing in
// this copy of the file.
#include
#include
#include
#include

namespace codon {
namespace {
// Writes one diagnostic line to the current logger's error stream:
//   [group prefix]<file>[:line][ (col[-col+len])]: <header> <msg>[ (see URL)]
// Only the part of `file` after the last '/' is printed. `errorCode` -1 means
// "no error code" and suppresses the URL suffix; `pos` selects the tree-style
// prefix used for grouped messages.
void compilationMessage(const std::string &header, const std::string &msg,
                        const std::string &file, int line, int col, int len,
                        int errorCode, MessageGroupPos pos) {
  auto &out = getLogger().err;
  // A position without a file, or a column without a line, is a caller bug.
  seqassertn(!(file.empty() && (line > 0 || col > 0)),
             "empty filename with non-zero line/col: file={}, line={}, col={}", file,
             line, col);
  seqassertn(!(col > 0 && line <= 0), "col but no line: file={}, line={}, col={}",
             file, line, col);

  switch (pos) {
  case MessageGroupPos::NONE:
    break;
  case MessageGroupPos::HEAD:
    break;
  case MessageGroupPos::MID:
    fmt::print(out, "├─ ");
    break;
  case MessageGroupPos::LAST:
    fmt::print(out, "╰─ ");
    break;
  }

  fmt::print(out, "\033[1m");
  if (!file.empty()) {
    // rfind() returns npos when there is no '/', and npos + 1 wraps to 0, so
    // the whole name is kept in that case.
    auto f = file.substr(file.rfind('/') + 1);
    fmt::print(out, "{}", f == "-" ? "" : f);
  }
  if (line > 0)
    fmt::print(out, ":{}", line);
  if (col > 0) {
    fmt::print(out, " ({}", col);
    if (len > 0)
      fmt::print(out, "-{})", col + len);
    else
      fmt::print(out, ")");
  }
  if (!file.empty())
    fmt::print(out, ": ");
  fmt::print(out, "{}\033[1m {}\033[0m{}\n", header, msg,
             errorCode != -1
                 ? fmt::format(" (see https://exaloop.io/error/{:04d})", errorCode)
                 : "");
}

// Stack of loggers; getLogger() returns the top entry.
// NOTE(review): the vector's element type is missing in this copy.
std::vector loggers;
} // namespace

// Prints a source location as "<file name>:<line>:<col>", using only the
// file-name component of the path.
std::ostream &operator<<(std::ostream &out, const codon::SrcInfo &src) {
  out << llvm::sys::path::filename(src.file).str() << ":" << src.line << ":"
      << src.col;
  return out;
}

// Reports an error (red "error:" header) and, when `terminate` is set, exits
// the process with EXIT_FAILURE.
void compilationError(const std::string &msg, const std::string &file, int line,
                      int col, int len, int errorCode, bool terminate,
                      MessageGroupPos pos) {
  compilationMessage("\033[1;31merror:\033[0m", msg, file, line, col, len, errorCode,
                     pos);
  if (terminate)
    exit(EXIT_FAILURE);
}

// Reports a warning (yellow "warning:" header) and, when `terminate` is set,
// exits the process with EXIT_FAILURE.
void compilationWarning(const std::string &msg, const std::string &file, int line,
                        int col, int len, int errorCode, bool terminate,
                        MessageGroupPos pos) {
  compilationMessage("\033[1;33mwarning:\033[0m", msg, file, line, col, len,
                     errorCode, pos);
  if (terminate)
    exit(EXIT_FAILURE);
}

// Enables log categories from a string of flag letters: 't' time,
// 'r' realize, 'T' typecheck, 'i' IR, 'l' user. Flags are only ever added,
// never cleared.
void Logger::parse(const std::string &s) {
  flags |= s.find('t') != std::string::npos ? FLAG_TIME : 0;
  flags |= s.find('r') != std::string::npos ? FLAG_REALIZE : 0;
  flags |= s.find('T') != std::string::npos ? FLAG_TYPECHECK : 0;
  flags |= s.find('i') != std::string::npos ? FLAG_IR : 0;
  flags |= s.find('l') != std::string::npos ? FLAG_USER : 0;
}
} // namespace codon

// Returns the logger on top of the stack, creating a default one first when
// the stack is empty.
codon::Logger &codon::getLogger() {
  if (loggers.empty())
    loggers.emplace_back();
  return loggers.back();
}

// Pushes a fresh default-constructed logger.
void codon::pushLogger() { loggers.emplace_back(); }

// Pops the top logger; returns false when there was nothing to pop.
bool codon::popLogger() {
  if (loggers.empty())
    return false;
  loggers.pop_back();
  return true;
}

// Prints the failed assertion (message, expression text, source location) to
// the current logger's error stream and aborts the process.
void codon::assertionFailure(const char *expr_str, const char *file, int line,
                             const std::string &msg) {
  auto &out = getLogger().err;
  out << "Assert failed:\t" << msg << "\n"
      << "Expression:\t" << expr_str << "\n"
      << "Source:\t\t" << file << ":" << line << "\n";
  abort();
}

================================================
FILE: codon/util/common.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

// NOTE(review): the header names of the following includes are missing in
// this copy of the file.
#include
#include
#include
#include
#include
#include
#include

#include "codon/compiler/error.h"
#include "codon/config/config.h"
#include "codon/parser/ast/error.h"

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"

// Logging macros. `c` must be a string literal because it is concatenated
// with "\n". DBGI additionally indents by two spaces per logger level. The
// LOG_* variants print only when the matching Logger flag is set.
#define DBGI(c, ...)                                                                   \
  fmt::print(codon::getLogger().log, "{}" c "\n",                                      \
             std::string(2 * codon::getLogger().level, ' '), ##__VA_ARGS__)
#define DBG(c, ...) fmt::print(codon::getLogger().log, c "\n", ##__VA_ARGS__)
#define LOG(c, ...) DBG(c, ##__VA_ARGS__)
#define LOG_TIME(c, ...)                                                               \
  {                                                                                    \
    if (codon::getLogger().flags & codon::Logger::FLAG_TIME)                           \
      DBG(c, ##__VA_ARGS__);                                                           \
  }
#define LOG_REALIZE(c, ...)                                                            \
  {                                                                                    \
    if (codon::getLogger().flags & codon::Logger::FLAG_REALIZE)                        \
      DBG(c, ##__VA_ARGS__);                                                           \
  }
#define LOG_TYPECHECK(c, ...)                                                          \
  {                                                                                    \
    if (codon::getLogger().flags & codon::Logger::FLAG_TYPECHECK)                      \
      DBG(c, ##__VA_ARGS__);                                                           \
  }
#define LOG_IR(c, ...)                                                                 \
  {                                                                                    \
    if (codon::getLogger().flags & codon::Logger::FLAG_IR)                             \
      DBG(c, ##__VA_ARGS__);                                                           \
  }
#define LOG_USER(c, ...)                                                               \
  {                                                                                    \
    if (codon::getLogger().flags & codon::Logger::FLAG_USER)                           \
      DBG(c, ##__VA_ARGS__);                                                           \
  }
// Declares a scoped codon::Timer named `__timer`.
#define TIME(name) codon::Timer __timer(name)

// Assertions. seqassertn formats `msg` with the given arguments; seqassert
// additionally appends getSrcInfo(), so it can only be used where a
// getSrcInfo() is in scope. With NDEBUG defined both expand to a lone ';' and
// their arguments are not evaluated.
#ifndef NDEBUG
#define seqassertn(expr, msg, ...)                                                     \
  ((expr) ? (void)(0)                                                                  \
          : codon::assertionFailure(#expr, __FILE__, __LINE__,                         \
                                    fmt::format(msg, ##__VA_ARGS__)))
#define seqassert(expr, msg, ...)                                                      \
  ((expr) ? (void)(0)                                                                  \
          : codon::assertionFailure(                                                   \
                #expr, __FILE__, __LINE__,                                             \
                fmt::format(msg " [{}]", ##__VA_ARGS__, getSrcInfo())))
#else
#define seqassertn(expr, msg, ...) ;
#define seqassert(expr, msg, ...) ;
#endif
#pragma clang diagnostic pop

namespace codon {

// Reports a failed assertion and aborts; called by the seqassert* macros.
void assertionFailure(const char *expr_str, const char *file, int line,
                      const std::string &msg);

// Logging configuration: a bitmask of enabled categories, an indentation
// level (used by DBGI) and the three output streams, which default to
// std::cout, std::cerr and std::clog.
struct Logger {
  static constexpr int FLAG_TIME = (1 << 0);
  static constexpr int FLAG_REALIZE = (1 << 1);
  static constexpr int FLAG_TYPECHECK = (1 << 2);
  static constexpr int FLAG_IR = (1 << 3);
  static constexpr int FLAG_USER = (1 << 4);

  int flags;
  int level;
  std::ostream &out;
  std::ostream &err;
  std::ostream &log;

  Logger() : flags(0), level(0), out(std::cout), err(std::cerr), log(std::clog) {}

  // Adds the categories named by the flag letters in `logs` to `flags`.
  void parse(const std::string &logs);
};

// Access to the logger stack: current logger, push a fresh one, pop the top
// one (popLogger returns false when the stack was already empty).
Logger &getLogger();
void pushLogger();
bool popLogger();

// Scoped timer: records the construction time and logs the elapsed time via
// LOG_TIME at most once, either on an explicit log() call or on destruction.
class Timer {
private:
  using clock_type = std::chrono::high_resolution_clock;
  std::string name;
  // NOTE(review): the time_point template arguments are missing in this copy.
  std::chrono::time_point start, end;

public:
  bool logged;

public:
  void log() {
    if (!logged) {
      LOG_TIME("[T] {} = {:.3f}", name, elapsed());
      logged = true;
    }
  }

  // Time since construction up to `end` (defaults to "now"), as a tick count
  // divided by 1000.0.
  // NOTE(review): the duration type passed to duration_cast is missing in
  // this copy; presumably milliseconds, which would make the result seconds
  // -- confirm.
  double elapsed(std::chrono::time_point end = clock_type::now()) const {
    return std::chrono::duration_cast(end - start).count() / 1000.0;
  }

  Timer(std::string name) : name(std::move(name)), start(), end(), logged(false) {
    start = clock_type::now();
  }

  ~Timer() { log(); }
};

std::ostream &operator<<(std::ostream &out, const codon::SrcInfo &src);

// Base for objects that carry a source location. The copy constructor copies
// only the location; setSrcInfo() returns `this` so calls can be chained.
struct SrcObject {
private:
  SrcInfo info;

public:
  SrcObject() : info() {}
  SrcObject(const SrcObject &s) { setSrcInfo(s.getSrcInfo()); }
  virtual ~SrcObject() = default;

  SrcInfo getSrcInfo() const { return info; }

  SrcObject *setSrcInfo(SrcInfo info) {
    this->info = std::move(info);
    return this;
  }
};

// Convenience overloads of E() that take the source location from a
// SrcObject given by pointer, by reference or by shared_ptr, and forward to
// the E() overload that accepts a SrcInfo.
// NOTE(review): the template parameter lists (and the shared_ptr element
// type) are missing in this copy.
template
void E(error::Error e, codon::SrcObject *o, const TA &...args) {
  E(e, o->getSrcInfo(), args...);
}

template
void E(error::Error e, const codon::SrcObject &o, const TA &...args) {
  E(e, o.getSrcInfo(), args...);
}

template
void E(error::Error e, const std::shared_ptr &o, const TA &...args) {
  E(e, o->getSrcInfo(), args...);
}

// Position of a message inside a group of related diagnostics; MID and LAST
// select the tree-style prefix printed by compilationMessage().
enum MessageGroupPos {
  NONE = 0,
  HEAD,
  MID,
  LAST,
};

// Error reporting. By default compilationError terminates the process and
// compilationWarning does not; errorCode -1 means "no error code".
void compilationError(const std::string &msg, const std::string &file = "",
                      int line = 0, int col = 0, int len = 0, int errorCode = -1,
                      bool terminate = true, MessageGroupPos pos = NONE);

void compilationWarning(const std::string &msg, const std::string &file = "",
                        int line = 0, int col = 0, int len = 0, int errorCode = -1,
                        bool terminate = false, MessageGroupPos pos = NONE);

} // namespace codon

// Formats values through their operator<<.
// NOTE(review): the formatter's template argument is missing in this copy;
// presumably codon::SrcInfo -- confirm.
template <> struct fmt::formatter : fmt::ostream_formatter {};

================================================
FILE: codon/util/jupyter.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#include "codon/util/jupyter.h"
// NOTE(review): the header name of the following include is missing in this
// copy of the file.
#include

namespace codon {
// Stub used when Jupyter support is not built in: prints a hint to stderr
// and returns EXIT_FAILURE. All parameters are ignored.
int startJupyterKernel(const std::string &argv0, const std::vector &plugins,
                       const std::string &configPath) {
  fprintf(stderr,
          "Jupyter support not included. Please install Codon Jupyter plugin.\n");
  return EXIT_FAILURE;
}

} // namespace codon

================================================
FILE: codon/util/jupyter.h
================================================
// Copyright (C) 2022-2026 Exaloop Inc.

#pragma once

// NOTE(review): the header names of the following includes are missing in
// this copy of the file.
#include
#include

namespace codon {
// Starts the Jupyter kernel and returns a process exit status.
// NOTE(review): the element type of `plugins` is missing in this copy.
int startJupyterKernel(const std::string &argv0, const std::vector &plugins,
                       const std::string &configPath);

} // namespace codon

================================================
FILE: codon/util/peg2cpp.cpp
================================================
// Copyright (C) 2022-2026 Exaloop Inc.
#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define FMT_HEADER_ONLY #include using namespace std; string escape(const string &str) { string r; for (unsigned char c : str) { switch (c) { case '\n': r += "\\\\n"; break; case '\r': r += "\\\\r"; break; case '\t': r += "\\\\t"; break; case '\\': r += "\\\\"; break; case '"': r += "\\\""; break; default: if (c < 32 || c >= 127) r += fmt::format("\\\\x{:x}", c); else r += c; } } return r; } template string join(const T &items, const string &delim = " ", int start = 0, int end = -1) { string s; if (end == -1) end = items.size(); for (int i = start; i < end; i++) s += (i > start ? delim : "") + items[i]; return s; } // const string PREDICATE = ".predicate"; // bool is_predicate(const std::string &name) { // return (name.size() > PREDICATE.size() && name.substr(name.size() - // PREDICATE.size()) == PREDICATE); // } class PrintVisitor : public peg::Ope::Visitor { vector v; public: static string parse(const shared_ptr &op) { PrintVisitor v; op->accept(v); if (v.v.size()) { if (v.v[0].empty()) return fmt::format("P[\"{}\"]", v.v[1]); else return fmt::format("{}({})", v.v[0], join(v.v, ", ", 1)); } return "-"; }; private: void visit(peg::Sequence &s) override { v = {"seq"}; for (auto &o : s.opes_) v.push_back(parse(o)); } void visit(peg::PrioritizedChoice &s) override { v = {"cho"}; for (auto &o : s.opes_) v.push_back(parse(o)); } void visit(peg::Repetition &s) override { if (s.is_zom()) v = {"zom", parse(s.ope_)}; else if (s.min_ == 1 && s.max_ == std::numeric_limits::max()) v = {"oom", parse(s.ope_)}; else if (s.min_ == 0 && s.max_ == 1) v = {"opt", parse(s.ope_)}; else v = {"rep", parse(s.ope_), to_string(s.min_), to_string(s.max_)}; } void visit(peg::AndPredicate &s) override { v = {"apd", parse(s.ope_)}; } void visit(peg::NotPredicate &s) override { v = {"npd", parse(s.ope_)}; } 
void visit(peg::LiteralString &s) override { v = {s.ignore_case_ ? "liti" : "lit", fmt::format("\"{}\"", escape(s.lit_))}; } void visit(peg::CharacterClass &s) override { vector sv; for (auto &c : s.ranges_) sv.push_back(fmt::format("{{0x{:x}, 0x{:x}}}", (int)c.first, (int)c.second)); v = {s.negated_ ? "ncls" : "cls", "vc{" + join(sv, ",") + "}"}; } void visit(peg::Character &s) override { v = {"chr", fmt::format("'{}'", s.ch_)}; } void visit(peg::AnyCharacter &s) override { v = {"dot"}; } void visit(peg::Cut &s) override { v = {"cut"}; } void visit(peg::Reference &s) override { if (s.is_macro_) { vector vs; for (auto &o : s.args_) vs.push_back(parse(o)); v = {"ref", "P", fmt::format("\"{}\"", s.name_), "\"\"", "true", "{" + join(vs, ", ") + "}"}; } else { v = {"ref", "P", fmt::format("\"{}\"", s.name_)}; } } void visit(peg::TokenBoundary &s) override { v = {"tok", parse(s.ope_)}; } void visit(peg::Ignore &s) override { v = {"ign", parse(s.ope_)}; } void visit(peg::Recovery &s) override { v = {"rec", parse(s.ope_)}; } // infix TODO }; int main(int argc, char **argv) { peg::parser parser; fmt::print("Generating grammar from {}\n", argv[1]); ifstream ifs(argv[1]); string g((istreambuf_iterator(ifs)), istreambuf_iterator()); ifs.close(); string start; peg::Rules dummy = {}; if (string(argv[3]) == "codon") dummy["NLP"] = peg::usr([](const char *, size_t, peg::SemanticValues &, any &) -> size_t { return -1; }); bool enablePackratParsing; string preamble; peg::Log log = [](size_t line, size_t col, const string &msg, const string &rule) { cerr << line << ":" << col << ": " << msg << " (" << rule << ")\n"; }; auto grammar = peg::ParserGenerator::get_instance().perform_core( g.c_str(), g.size(), dummy, start, enablePackratParsing, preamble, log); assert(grammar); string rules, actions, actionFns; string action_preamble = " auto &CTX = any_cast(DT);\n"; string const_action_preamble = " const auto &CTX = any_cast(DT);\n"; string loc_preamble = " const auto &LI = 
VS.line_info();\n" " auto LOC = codon::SrcInfo(\n" " VS.path, LI.first + CTX.line_offset,\n" " LI.second + CTX.col_offset,\n" " VS.sv().size());\n"; for (auto &[name, def] : *grammar) { auto op = def.get_core_operator(); if (dummy.find(name) != dummy.end()) continue; rules += fmt::format(" {}P[\"{}\"] <= {};\n", def.ignoreSemanticValue ? "~" : "", name, PrintVisitor::parse(op)); rules += fmt::format(" P[\"{}\"].name = \"{}\";\n", name, escape(name)); if (def.is_macro) rules += fmt::format(" P[\"{}\"].is_macro = true;\n", name); if (!def.enable_memoize) rules += fmt::format(" P[\"{}\"].enable_memoize = false;\n", name); if (!def.params.empty()) { vector params; for (auto &p : def.params) params.push_back(fmt::format("\"{}\"", escape(p))); rules += fmt::format(" P[\"{}\"].params = {{{}}};\n", name, join(params, ", ")); } string code = op->code; if (code.empty()) { bool all_empty = true; if (auto ope = dynamic_cast(op.get())) { for (int i = 0; i < ope->opes_.size(); i++) if (!ope->opes_[i]->code.empty()) { code += fmt::format(" if (VS.choice() == {}) {}\n", i, ope->opes_[i]->code); all_empty = false; } else { code += fmt::format(" if (VS.choice() == {}) return V0;\n", i); } } if (all_empty) code = ""; if (!code.empty()) code = "{\n" + code + "}"; } if (!code.empty()) { code = code.substr(1, code.size() - 2); if (code.find("LOC") != std::string::npos) code = loc_preamble + code; if (code.find("CTX") != std::string::npos) code = action_preamble + code; actions += fmt::format("P[\"{}\"] = fn_{};\n", name, name); actionFns += fmt::format( "auto fn_{}(peg::SemanticValues &VS, any &DT) {{\n{}\n}};\n", name, code); } if (!(code = def.predicate_code).empty()) { code = code.substr(1, code.size() - 2); if (code.find("LOC") != std::string::npos) code = loc_preamble + code; if (code.find("CTX") != std::string::npos) code = const_action_preamble + code; actions += fmt::format("P[\"{}\"].predicate = pred_{};\n", name, name); actionFns += fmt::format("auto pred_{}(const 
peg::SemanticValues &VS, const any " "&DT, std::string &MSG) {{\n{}\n}};\n", name, code); } }; FILE *fout = fopen(argv[2], "w"); fmt::print(fout, "// clang-format off\n"); fmt::print(fout, "#pragma clang diagnostic push\n"); fmt::print(fout, "#pragma clang diagnostic ignored \"-Wreturn-type\"\n"); if (!preamble.empty()) fmt::print(fout, "{}\n", preamble.substr(1, preamble.size() - 2)); string rules_preamble = " using namespace peg;\n" " using peg::seq;\n" " using vc = vector>;\n"; fmt::print(fout, "void init_{}_rules(peg::Grammar &P) {{\n{}\n{}\n}}\n", argv[3], rules_preamble, rules); fmt::print(fout, "{}\n", actionFns); fmt::print(fout, "void init_{}_actions(peg::Grammar &P) {{\n {}\n}}\n", argv[3], actions); fmt::print(fout, "// clang-format on\n"); fmt::print(fout, "#pragma clang diagnostic pop\n"); fclose(fout); return 0; } ================================================ FILE: codon/util/serialize.h ================================================ // Copyright (C) 2022-2026 Exaloop Inc. 
#pragma once #include #include #include #include "codon/util/tser.h" namespace codon { template struct PolymorphicSerializer { struct Serializer { std::function save; std::function load; }; template static Serializer serializerFor() { return {[](Base *b, Archive &a) { a.save(*(static_cast(b))); }, [](Base *&b, Archive &a) { b = new Derived(); a.load(static_cast(*b)); }}; } static inline std::unordered_map _serializers; static inline std::unordered_map _factory; template static void register_types() { (_serializers.emplace((void *)(Derived::nodeId()), Derived::_typeName), ...); (_factory.emplace(std::string(Derived::_typeName), serializerFor()), ...); } static void save(const std::string &s, Base *b, Archive &a) { auto i = _factory.find(s); assert(i != _factory.end() && "bad op"); i->second.save(b, a); } static void load(const std::string &s, Base *&b, Archive &a) { auto i = _factory.find(s); assert(i != _factory.end() && "bad op"); i->second.load(b, a); } }; } // namespace codon #define SERIALIZE(Type, ...) 
\ inline decltype(auto) members() const { return std::tie(__VA_ARGS__); } \ inline decltype(auto) members() { return std::tie(__VA_ARGS__); } \ static constexpr std::array \ _memberNameData = []() { \ std::array chars{'\0'}; \ size_t _idx = 0; \ constexpr auto *ini(#__VA_ARGS__); \ for (char const *_c = ini; *_c; ++_c, ++_idx) \ if (*_c != ',' && *_c != ' ') \ chars[_idx] = *_c; \ return chars; \ }(); \ static constexpr const char *_typeName = #Type; \ static constexpr std::array \ _memberNames = []() { \ std::array out{}; \ for (size_t _i = 0, nArgs = 0; nArgs < tser::detail::n_args(#__VA_ARGS__); \ ++_i) { \ while (Type::_memberNameData[_i] == '\0') \ _i++; \ out[nArgs++] = &Type::_memberNameData[_i]; \ while (Type::_memberNameData[++_i] != '\0') \ ; \ } \ return out; \ }() #define BASE(T) tser::base(this) ================================================ FILE: codon/util/tser.h ================================================ // Licensed under the Boost License . // SPDX-License-Identifier: BSL-1.0 #pragma once #include #include #include #include #include #include #include // #include "tser/varint_encoding.hpp"// Licensed under the Boost License // . SPDX-License-Identifier: BSL-1.0 #include namespace tser { template size_t encode_varint(T value, char *output) { size_t i = 0; if constexpr (std::is_signed_v) value = static_cast(value << 1 ^ (value >> (sizeof(T) * 8 - 1))); for (; value > 127; ++i, value >>= 7) output[i] = static_cast(static_cast(value & 127) | 128); output[i++] = static_cast(value) & 127; return i; } template size_t decode_varint(T &value, const char *const input) { size_t i = 0; for (value = 0; i == 0 || (input[i - 1] & 128); i++) value |= static_cast(input[i] & 127) << (7 * i); if constexpr (std::is_signed_v) value = (value & 1) ? -static_cast((value + 1) >> 1) : (value + 1) >> 1; return i; } } // namespace tser // #include "tser/base64_encoding.hpp"// Licensed under the Boost License // . 
SPDX-License-Identifier: BSL-1.0 #include #include #include namespace tser { // tables for the base64 conversions static constexpr auto g_encodingTable = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; static constexpr auto g_decodingTable = []() { std::array decTable{}; for (unsigned char i = 0; i < 64u; ++i) decTable[static_cast(g_encodingTable[i])] = i; return decTable; }(); static std::string encode_base64(std::string_view in) { std::string out; unsigned val = 0; int valb = -6; for (char c : in) { val = (val << 8) + static_cast(c); valb += 8; while (valb >= 0) { out.push_back(g_encodingTable[(val >> valb) & 63u]); valb -= 6; } } if (valb > -6) out.push_back(g_encodingTable[((val << 8) >> (valb + 8)) & 0x3F]); return out; } static std::string decode_base64(std::string_view in) { std::string out; unsigned val = 0; int valb = -8; for (char c : in) { val = (val << 6) + g_decodingTable[static_cast(c)]; valb += 6; if (valb >= 0) { out.push_back(char((val >> valb) & 0xFF)); valb -= 8; } } return out; } } // namespace tser namespace tser { // implementation details for C++20 is_detected namespace detail { struct ns { ~ns() = delete; ns(ns const &) = delete; }; template class Op, class... Args> struct detector { using value_t = std::false_type; using type = Default; }; template class Op, class... Args> struct detector>, Op, Args...> { using value_t = std::true_type; using type = Op; }; template struct is_array : std::is_array {}; template