---
Raven is an ecosystem of OCaml libraries for numerical computing, machine learning, and data science. Everything you know from Python — NumPy, JAX, PyTorch, Matplotlib, Jupyter — rebuilt with type safety.
> Raven is **alpha**. APIs will change. [Feedback welcome.](https://github.com/raven-ml/raven/issues)
```ocaml
(* nx — n-dimensional arrays *)
let x = Nx.linspace float32 0. 10. 100
let y = Nx.sin x
(* rune — automatic differentiation *)
let grad_f = Rune.grad (fun x -> Rune.sum (Rune.mul x x)) x
(* brot — tokenization *)
let tokenizer = Brot.from_file "tokenizer.json" |> Result.get_ok
let ids = Brot.encode_ids tokenizer "The meaning of life is"
(* kaun — neural networks *)
let model = Kaun.Layer.sequential [
Kaun.Layer.linear ~in_features:768 ~out_features:128 ();
Kaun.Layer.relu ();
Kaun.Layer.linear ~in_features:128 ~out_features:10 ();
]
(* talon — dataframes *)
let df = Talon.create [
"name", Talon.Col.string_list [ "Alice"; "Bob"; "Charlie" ];
"score", Talon.Col.float64_list [ 85.5; 92.0; 78.5 ];
]
(* hugin — plotting *)
let () = Hugin.(figure () |> subplot |> Plotting.plot ~x ~y |> ignore; show ())
```
## Packages
| | Package | Like | What it does |
| --- | ------------------------------ | ----------------- | -------------------------------------------------------- |
| | [**nx**](packages/nx/) | NumPy | N-dimensional arrays with linear algebra operations |
| ᛏ | [**tolk**](packages/tolk/) | tinygrad | Minimal ML compiler for GPU tensor computation |
| ᚱ | [**rune**](packages/rune/) | JAX | Automatic differentiation and functional transformations |
| ᚲ | [**kaun**](packages/kaun/) | Flax | Neural networks and training |
| ᚹ | [**vega**](packages/vega/) | Optax | Composable gradient-based optimizers |
| ᚾ | [**norn**](packages/norn/) | BlackJAX | MCMC sampling with automatic gradients |
| ᚨ | [**brot**](packages/brot/) | HF Tokenizers | Fast, HuggingFace-compatible tokenization |
| ᛃ | [**talon**](packages/talon/) | Polars | Fast and elegant dataframes with type-safe operations |
| ᛞ | [**hugin**](packages/hugin/) | Matplotlib | Publication-quality plotting |
| ᛈ | [**quill**](packages/quill/) | Jupyter + IPython | Interactive REPL and markdown notebooks |
| ᚠ | [**fehu**](packages/fehu/) | Gymnasium | Reinforcement learning environments |
| ᛋ | [**sowilo**](packages/sowilo/) | OpenCV | Differentiable computer vision |
| ᛗ | [**munin**](packages/munin/) | W&B / MLFlow | Local experiment tracking with live TUI dashboard |
## Getting started
```bash
opam install raven
```
This installs the full ecosystem. You can also install only what you need — e.g. `opam install kaun` for neural networks, or `opam install nx` for just arrays.
Add to your `dune` file:
```dune
(executable
(name main)
(libraries raven))
```
See the [installation guide](https://raven-ml.dev/docs/installation/) for system dependencies and editor setup.
## Support
Building a scientific computing ecosystem takes sustained effort. Sponsorships help us ship JIT compilation, distributed training, better developer tooling, and production deployment through MirageOS.
**[Support Raven →](https://raven-ml.dev/docs/support-raven/)**
Thanks to our sponsors [Ahrefs](https://ahrefs.com) and [Tarides](https://tarides.com).
## License
[ISC](LICENSE)
================================================
FILE: TODO.md
================================================
# todo
## beta (jit)
goalpost: jit-compiled gpt2 matching pytorch performance
perf:
- close rune grad performance gap (within <2x of pytorch)
- close nx performance gaps (within <2x of numpy)
tolk:
- integrate tolk as rune jit transformation
- kernel fusion and optimization
- cpu, cuda, metal backends
## v1 (production)
goalpost: end-to-end train -> deploy as unikernel or static binary
training:
- gradient accumulation
- mixed precision (fp16/bf16 forward, fp32 master weights, loss scaling)
- gradient checkpointing (rune.checkpoint, recompute activations in backward)
- flash attention (tolk kernel and/or kaun.fn primitive)
- parallel data loading (ocaml 5 domains, background prefetch)
- layer completions: transposed conv, group norm, full conv2d stride/dilation/padding
- onnx import (onnx -> tolk ir adapter, cover resnet/bert/gpt2/llama/vit/whisper ops)
deployment:
- aot compilation: cpu (c via clang, musl static linking) and gpu (cuda/metal/opencl)
- mimir: kv cache, continuous batching, pagedattention
- mimir: http server (rest api, /health, /metrics, sigterm, structured logging)
- post-training quantization (int8/int4, tolk quantized kernels)
- mirageos unikernel deployment (raven-mirage package)
- no blas dep (tolk aot generates all compute)
- weight loading via network (mirage-http)
- verify ocaml 5 effects on mirageos runtime
- http server on mirageos network stack
docs/website:
- landing page rewrite with benchmarks
- deployment guide (aot, static binary, docker, mirageos, gpu)
- end-to-end examples (serving, onnx+deploy workflow)
================================================
FILE: dev/README.md
================================================
# dev
Development sandbox for experiments and prototypes that support the Raven ecosystem.
## Projects
| Name | Description |
| ---- | ----------- |
| [mimir](mimir/) | Experimental inference engine |
| [tolk](tolk/) | ML compiler inspired by tinygrad |
================================================
FILE: dev/mimir/README.md
================================================
# mimir
Experimental inference engine for raven.
The gap between "I can run a forward pass" and "I can serve a model in production" is large. mimir is where we figure out what the OCaml answer to that gap looks like.
## Current state
The sampling layer: composable logits processors (temperature, top-k, top-p, repetition penalty, n-gram blocking), stopping criteria, and the autoregressive generation loop operating on nx tensors.
This is the outermost piece of the inference puzzle — the part that turns model logits into actual token sequences. Everything below is open.
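In miniature, that loop looks like the following. This is a simplified pure-OCaml sketch using plain float arrays; the real sampler operates on nx tensors and also threads a prompt length and start time through processors and criteria:

```ocaml
(* Simplified sketch of the autoregressive loop. Processors rewrite logits,
   criteria decide when to stop; greedy argmax stands in for sampling.
   At least one stopping criterion must eventually fire. *)
type processor = int array -> float array -> float array
type criterion = int array -> bool

let generate ~(model : int array -> float array)
    ~(processors : processor list) ~(stop : criterion list)
    (prompt : int array) : int array =
  let argmax a =
    let best = ref 0 in
    Array.iteri (fun i x -> if x > a.(!best) then best := i) a;
    !best
  in
  let rec loop tokens =
    if List.exists (fun c -> c tokens) stop then tokens
    else
      (* Fold every processor over the raw logits, then pick a token. *)
      let logits =
        List.fold_left (fun l p -> p tokens l) (model tokens) processors
      in
      loop (Array.append tokens [| argmax logits |])
  in
  loop prompt
```

Everything interesting hides behind `model`: in a real engine that call is where KV caching, batching, and JIT capture live.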
## What we want to explore
**Memory management for KV cache.** The attention mechanism produces intermediate state (keys and values) that grows linearly with sequence length. Naive allocation wastes memory; the interesting question is whether we can apply OS-style virtual memory ideas — fixed-size blocks, deferred allocation, reference-counted sharing — to make long sequences and shared prefixes cheap. This is the core idea behind PagedAttention.
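The bookkeeping side of that idea fits in a few types. This is a hypothetical sketch, not mimir code: a pool of fixed-size blocks with a free list and reference counts, so sharing a prefix bumps a counter instead of copying keys and values:

```ocaml
(* Hypothetical paged KV cache block table: fixed-size blocks, a free list
   for deferred allocation, refcounts for prefix sharing. The actual tensor
   storage behind each block is elided. *)
type block = { id : int; mutable refcount : int }
type pool = { blocks : block array; mutable free : int list }

let make_pool n =
  {
    blocks = Array.init n (fun id -> { id; refcount = 0 });
    free = List.init n (fun i -> i);
  }

(* Allocation can fail: the scheduler must handle memory pressure. *)
let alloc pool =
  match pool.free with
  | [] -> None
  | id :: rest ->
      pool.free <- rest;
      let b = pool.blocks.(id) in
      b.refcount <- 1;
      Some b

(* Sharing a prefix bumps the refcount instead of copying keys/values. *)
let share b = b.refcount <- b.refcount + 1

let release pool b =
  b.refcount <- b.refcount - 1;
  if b.refcount = 0 then pool.free <- b.id :: pool.free
```

The `alloc` returning `None` is the point: running out of blocks is a normal event the scheduler reacts to (by queueing or preempting), not an error.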
**Request scheduling.** A single request is simple. Serving thousands of concurrent requests with different prompt lengths, generation limits, and priority levels is a scheduling problem. Batching amortizes GPU overhead but introduces latency trade-offs. Continuous batching (letting new requests join mid-batch as others finish) changes the calculus further. OCaml's algebraic types and pattern matching may give us a cleaner expression of scheduling policies than the typical mutable-state approach.
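As a hypothetical illustration of what "cleaner expression" could mean, a scheduling decision can be a variant that the engine pattern-matches on, with each policy a pure function from queue state to decision:

```ocaml
(* Hypothetical sketch: request lifecycle and a continuous-batching
   admission step expressed with variants. Names are illustrative. *)
type request = { id : int; prompt_len : int; generated : int; max_new : int }

type state = Waiting of request | Running of request | Finished of request

type decision = Admit of request | Preempt of request | Wait

(* Admit the oldest waiting request when the running batch has a free slot.
   A real policy would also weigh KV cache memory and priorities. *)
let schedule ~batch_size running waiting =
  match waiting with
  | [] -> Wait
  | r :: _ when List.length running < batch_size -> Admit r
  | _ -> Wait
```

Because the decision is data rather than a side effect, alternative policies (priority-aware, memory-aware) are just alternative functions returning the same `decision` type, and the engine loop that interprets them never changes.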
**Prefill/decode asymmetry.** The two phases of autoregressive generation have opposite performance characteristics — one is compute-bound, the other memory-bound. An engine that treats them identically leaves performance on the table.
**JIT compilation of decode steps.** The decode phase repeats the same computation graph with different inputs. If rune's JIT can capture and replay these graphs, we avoid per-step compilation overhead — similar in spirit to CUDA graph capture.
**Structured generation.** Constraining the sampling step so that output conforms to a grammar, regex, or JSON schema. This means masking logits at each step based on what the constraint automaton allows, which interacts with the sampling pipeline we already have.
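Mechanically, the constraint step is a mask. A hypothetical sketch, with the automaton reduced to its interface (the field names and the float-array logits are illustrative, not mimir's API):

```ocaml
(* Hypothetical sketch: a constraint automaton exposes which tokens are
   permitted from the current state; masking sets everything else to
   negative infinity before sampling. *)
type automaton = {
  start : int;
  step : int -> int -> int option; (* state -> token -> next state *)
  allowed : int -> int list;       (* state -> permitted token ids *)
}

let mask_logits automaton state (logits : float array) =
  let ok = automaton.allowed state in
  Array.mapi
    (fun tok score -> if List.mem tok ok then score else neg_infinity)
    logits
```

This shape composes naturally with the existing pipeline: a grammar constraint is just one more logits processor, applied after temperature and penalties.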
**Tensor parallelism.** Splitting a model across multiple devices. This is a rune-level concern more than a mimir concern, but the inference engine needs to coordinate it.
## References
- [Nano-vLLM](https://github.com/GeeeekExplworker/nano-vllm) — minimal (~1,200 lines) inference engine by a DeepSeek contributor, good for understanding the essential moving parts
- [vLLM: PagedAttention paper](https://arxiv.org/abs/2309.06180)
- [SGLang](https://github.com/sgl-project/sglang) — alternative engine with RadixAttention for prefix sharing
================================================
FILE: dev/mimir/dune-project
================================================
(lang dune 3.21)
(name mimir)
(package
(name mimir)
(synopsis "Experimental inference engine for Raven")
(description
"Mimir is an inference engine for the Raven ecosystem. It provides sampling, KV cache management, request scheduling, and structured generation for serving ML models.")
(depends
(ocaml
(>= 5.2))))
================================================
FILE: dev/mimir/lib/dune
================================================
(library
(name mimir)
(public_name mimir)
(libraries nx unix))
================================================
FILE: dev/mimir/lib/mimir.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
include Sampler
================================================
FILE: dev/mimir/lib/mimir.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Mimir - Text generation with composable logits processors.
Experimental inference/generation library for the Raven ML ecosystem.
Provides the autoregressive decode loop, composable logits processors,
stopping criteria, and generation configuration. *)
include module type of Sampler
================================================
FILE: dev/mimir/lib/sampler.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* ───── Core Types ───── *)
type logits = (float, Bigarray.float32_elt) Nx.t
type token_ids = int array
(* ───── Logits Processors ───── *)
type logits_processor = {
name : string;
process : prompt_length:int -> token_ids -> logits -> logits;
}
type logits_processor_list = logits_processor list
(* ───── Stopping Criteria ───── *)
type stopping_criterion = {
name : string;
should_stop : prompt_length:int -> start_time:float -> token_ids -> bool;
}
type stopping_criteria_list = stopping_criterion list
(* ───── Generation Configuration ───── *)
type generation_config = {
max_length : int;
max_new_tokens : int option;
min_length : int;
min_new_tokens : int;
do_sample : bool;
temperature : float;
top_k : int;
top_p : float;
repetition_penalty : float;
no_repeat_ngram_size : int;
bad_words_ids : int list list;
force_words_ids : int list list;
pad_token_id : int option;
bos_token_id : int option;
eos_token_id : int option;
eos_token_ids : int list;
}
let default =
{
max_length = 100;
max_new_tokens = None;
min_length = 0;
min_new_tokens = 0;
do_sample = false;
temperature = 1.0;
top_k = 0;
top_p = 1.0;
repetition_penalty = 1.0;
no_repeat_ngram_size = 0;
bad_words_ids = [];
force_words_ids = [];
pad_token_id = None;
bos_token_id = None;
eos_token_id = None;
eos_token_ids = [];
}
(* ───── Builder Pattern ───── *)
let with_temperature temperature config = { config with temperature }
let with_top_k top_k config = { config with top_k }
let with_top_p top_p config = { config with top_p }
let with_repetition_penalty repetition_penalty config =
{ config with repetition_penalty }
let with_max_length max_length config = { config with max_length }
let with_max_new_tokens max_new_tokens config =
{ config with max_new_tokens = Some max_new_tokens }
let with_min_length min_length config = { config with min_length }
let with_min_new_tokens min_new_tokens config = { config with min_new_tokens }
let with_no_repeat_ngram no_repeat_ngram_size config =
{ config with no_repeat_ngram_size }
let with_do_sample do_sample config = { config with do_sample }
(* ───── Preset Configurations ───── *)
let creative_writing =
default |> with_do_sample true |> with_temperature 0.8 |> with_top_p 0.9
|> with_repetition_penalty 1.2
|> with_no_repeat_ngram 3 |> with_max_new_tokens 512
let chat =
default |> with_do_sample true |> with_temperature 0.7 |> with_top_p 0.95
|> with_repetition_penalty 1.1
|> with_max_new_tokens 512
let code_generation =
default |> with_do_sample true |> with_temperature 0.2 |> with_top_k 5
|> with_repetition_penalty 1.0
|> with_max_new_tokens 1024
let factual =
default |> with_do_sample true |> with_temperature 0.3 |> with_top_k 10
|> with_repetition_penalty 1.1
|> with_max_new_tokens 256
let from_preset = function
| "creative_writing" -> creative_writing
| "chat" -> chat
| "code_generation" -> code_generation
| "factual" -> factual
| _ -> default
(* ───── Logits Processors ───── *)
let neg_infinity = Float.neg_infinity
let temperature_warper ~temperature =
{
name = Printf.sprintf "temperature(%.2f)" temperature;
process =
(fun ~prompt_length:_ _tokens logits ->
if temperature = 1.0 then logits else Nx.div_s logits temperature);
}
let top_k_warper ~k =
{
name = Printf.sprintf "top_k(%d)" k;
process =
(fun ~prompt_length:_ _tokens logits ->
if k <= 0 then logits
else
let sorted_values, _sorted_indices =
Nx.sort ~descending:true logits
in
let vocab_size = Nx.numel logits in
let cutoff_k = min k vocab_size in
let threshold = Nx.item [ cutoff_k - 1 ] sorted_values in
let mask = Nx.less_s logits threshold in
Nx.where mask (Nx.full_like logits neg_infinity) logits);
}
let top_p_warper ~p =
{
name = Printf.sprintf "top_p(%.2f)" p;
process =
(fun ~prompt_length:_ _tokens logits ->
if p >= 1.0 then logits
else
let probs = Nx.softmax logits in
let sorted_probs, sorted_indices = Nx.sort ~descending:true probs in
let cumulative = Nx.cumsum sorted_probs in
(* Find where cumulative exceeds p, keeping at least 1 token *)
let cutoff_mask = Nx.greater_s cumulative p in
(* Shift mask right by 1 so the token that crosses p is kept *)
let n = Nx.numel logits in
let shifted_arr = Nx.to_array cutoff_mask in
let shifted_mask = Array.make n false in
for i = 1 to n - 1 do
  shifted_mask.(i) <- shifted_arr.(i - 1)
done;
(* Map mask back to original token order *)
let result = Nx.copy logits in
let sorted_idx_arr = Nx.to_array sorted_indices in
for i = 0 to n - 1 do
  if shifted_mask.(i) then
    Nx.set_item
      [ Int32.to_int sorted_idx_arr.(i) ]
      neg_infinity result
done;
result);
}
let repetition_penalty ~penalty =
{
name = Printf.sprintf "repetition_penalty(%.2f)" penalty;
process =
(fun ~prompt_length:_ previous_tokens logits ->
if penalty = 1.0 then logits
else
let result = Nx.copy logits in
let vocab_size = Nx.numel result in
Array.iter
(fun token_id ->
if token_id < vocab_size then begin
let score = Nx.item [ token_id ] result in
let penalized =
if score < 0.0 then score *. penalty else score /. penalty
in
Nx.set_item [ token_id ] penalized result
end)
previous_tokens;
result);
}
let no_repeat_ngram ~ngram_size =
{
name = Printf.sprintf "no_repeat_ngram(%d)" ngram_size;
process =
(fun ~prompt_length:_ previous_tokens logits ->
let len = Array.length previous_tokens in
if ngram_size <= 0 || len < ngram_size - 1 then logits
else
let result = Nx.copy logits in
(* Get the last (ngram_size - 1) tokens as the current prefix *)
let prefix_start = len - (ngram_size - 1) in
let prefix =
Array.sub previous_tokens prefix_start (ngram_size - 1)
in
(* Scan history for matching prefixes *)
for i = 0 to len - ngram_size do
let matches = ref true in
for j = 0 to ngram_size - 2 do
if previous_tokens.(i + j) <> prefix.(j) then matches := false
done;
if !matches then begin
let blocked_token = previous_tokens.(i + ngram_size - 1) in
if blocked_token < Nx.numel result then
Nx.set_item [ blocked_token ] neg_infinity result
end
done;
result);
}
let min_length ~min_length ~eos_token_ids =
{
name = Printf.sprintf "min_length(%d)" min_length;
process =
(fun ~prompt_length:_ tokens logits ->
if Array.length tokens >= min_length then logits
else
let result = Nx.copy logits in
let vocab_size = Nx.numel result in
List.iter
(fun eos_id ->
if eos_id < vocab_size then
Nx.set_item [ eos_id ] neg_infinity result)
eos_token_ids;
result);
}
let min_new_tokens ~min_new_tokens ~eos_token_ids =
{
name = Printf.sprintf "min_new_tokens(%d)" min_new_tokens;
process =
(fun ~prompt_length tokens logits ->
let new_tokens = Array.length tokens - prompt_length in
if new_tokens >= min_new_tokens then logits
else
let result = Nx.copy logits in
let vocab_size = Nx.numel result in
List.iter
(fun eos_id ->
if eos_id < vocab_size then
Nx.set_item [ eos_id ] neg_infinity result)
eos_token_ids;
result);
}
let bad_words ~bad_words_ids =
{
name = "bad_words";
process =
(fun ~prompt_length:_ tokens logits ->
let result = Nx.copy logits in
let len = Array.length tokens in
let vocab_size = Nx.numel result in
List.iter
(fun bad_sequence ->
let seq_len = List.length bad_sequence in
if seq_len > 0 && len >= seq_len - 1 then (
let prefix_len = seq_len - 1 in
let matches = ref true in
let prefix = List.rev (List.tl (List.rev bad_sequence)) in
List.iteri
(fun i expected ->
if tokens.(len - prefix_len + i) <> expected then
matches := false)
prefix;
if !matches then begin
let bad_token = List.nth bad_sequence (seq_len - 1) in
if bad_token < vocab_size then
Nx.set_item [ bad_token ] neg_infinity result
end))
bad_words_ids;
result);
}
let force_words ~force_words_ids ~iteration =
{
name = "force_words";
process =
(fun ~prompt_length:_ _tokens logits ->
if iteration >= List.length force_words_ids then logits
else
let forced_tokens = List.nth force_words_ids iteration in
let result = Nx.full_like logits neg_infinity in
List.iter
(fun token_id ->
if token_id < Nx.numel result then
Nx.set_item [ token_id ] (Nx.item [ token_id ] logits) result)
forced_tokens;
result);
}
let custom ~name ~process = { name; process }
(* ───── Stopping Criteria ───── *)
let max_length_criteria ~max_length =
{
name = Printf.sprintf "max_length(%d)" max_length;
should_stop =
(fun ~prompt_length:_ ~start_time:_ tokens ->
Array.length tokens >= max_length);
}
let max_new_tokens_criteria ~max_new_tokens =
{
name = Printf.sprintf "max_new_tokens(%d)" max_new_tokens;
should_stop =
(fun ~prompt_length ~start_time:_ tokens ->
Array.length tokens - prompt_length >= max_new_tokens);
}
let eos_token_criteria ~eos_token_ids =
{
name = "eos_token";
should_stop =
(fun ~prompt_length:_ ~start_time:_ tokens ->
let len = Array.length tokens in
if len = 0 then false else List.mem tokens.(len - 1) eos_token_ids);
}
let max_time_criteria ~max_time =
{
name = Printf.sprintf "max_time(%.1fs)" max_time;
should_stop =
(fun ~prompt_length:_ ~start_time _tokens ->
Unix.gettimeofday () -. start_time > max_time);
}
let stop_strings_criteria ~stop_strings ~decoder =
{
name = "stop_strings";
should_stop =
(fun ~prompt_length:_ ~start_time:_ tokens ->
let text = decoder tokens in
List.exists
(fun stop_str -> String_util.contains_substring text stop_str)
stop_strings);
}
let custom_criteria ~name ~should_stop = { name; should_stop }
(* ───── Utilities ───── *)
let apply_processors ~processors ~prompt_length ~tokens ~logits =
List.fold_left
(fun acc processor -> processor.process ~prompt_length tokens acc)
logits processors
let check_stopping ~criteria ~prompt_length ~start_time ~tokens =
List.exists
(fun criterion -> criterion.should_stop ~prompt_length ~start_time tokens)
criteria
(* ───── Main Generation Functions ───── *)
type generation_output = {
sequences : int array list;
scores : float list list option;
}
let sample_from_logits logits =
  let probs = Nx.softmax logits in
  let probs_arr = Nx.to_array probs in
  let r = Random.float 1.0 in
  let cumsum = ref 0.0 in
  (* Default to the last token: float rounding can leave the cumulative
     sum just below [r] after adding the final probability. *)
  let result = ref (Array.length probs_arr - 1) in
  (try
     for i = 0 to Array.length probs_arr - 1 do
       cumsum := !cumsum +. probs_arr.(i);
       if !cumsum > r then begin
         result := i;
         raise_notrace Exit
       end
     done
   with Exit -> ());
  !result
let argmax logits = Int32.to_int (Nx.item [ 0 ] (Nx.argmax logits))
let generate ~model ?(input_ids = [||]) ?(generation_config = default)
?(logits_processor = []) ?(stopping_criteria = []) () =
let start_time = Unix.gettimeofday () in
let prompt_length = Array.length input_ids in
let processors =
let ps = [] in
let ps =
if generation_config.temperature <> 1.0 then
temperature_warper ~temperature:generation_config.temperature :: ps
else ps
in
let ps =
if generation_config.top_k > 0 then
top_k_warper ~k:generation_config.top_k :: ps
else ps
in
let ps =
if generation_config.top_p < 1.0 then
top_p_warper ~p:generation_config.top_p :: ps
else ps
in
let ps =
if generation_config.repetition_penalty <> 1.0 then
repetition_penalty ~penalty:generation_config.repetition_penalty :: ps
else ps
in
let ps =
if generation_config.no_repeat_ngram_size > 0 then
no_repeat_ngram ~ngram_size:generation_config.no_repeat_ngram_size :: ps
else ps
in
let eos_ids =
match generation_config.eos_token_id with
| Some id -> id :: generation_config.eos_token_ids
| None -> generation_config.eos_token_ids
in
let ps =
if generation_config.min_length > 0 then
min_length ~min_length:generation_config.min_length
~eos_token_ids:eos_ids
:: ps
else ps
in
let ps =
if generation_config.min_new_tokens > 0 then
min_new_tokens ~min_new_tokens:generation_config.min_new_tokens
~eos_token_ids:eos_ids
:: ps
else ps
in
ps @ logits_processor
in
let criteria =
let cs = [] in
let cs =
max_length_criteria ~max_length:generation_config.max_length :: cs
in
let cs =
match generation_config.max_new_tokens with
| Some max_new -> max_new_tokens_criteria ~max_new_tokens:max_new :: cs
| None -> cs
in
let eos_ids =
match generation_config.eos_token_id with
| Some id -> id :: generation_config.eos_token_ids
| None -> generation_config.eos_token_ids
in
let cs =
if eos_ids <> [] then eos_token_criteria ~eos_token_ids:eos_ids :: cs
else cs
in
cs @ stopping_criteria
in
let tokens_ref = ref (Array.copy input_ids) in
let rec generate_loop () =
let current_tokens = !tokens_ref in
if
Array.length current_tokens > prompt_length
&& check_stopping ~criteria ~prompt_length ~start_time
~tokens:current_tokens
then current_tokens
else begin
let raw_logits = model current_tokens in
let processed =
apply_processors ~processors ~prompt_length ~tokens:current_tokens
~logits:raw_logits
in
let next_token =
if generation_config.do_sample then sample_from_logits processed
else argmax processed
in
tokens_ref := Array.append current_tokens [| next_token |];
generate_loop ()
end
in
let sequences = generate_loop () in
{ sequences = [ sequences ]; scores = None }
let generate_text ~model ~tokenizer ~decoder ?(prompt = "")
?(generation_config = default) ?(logits_processor = [])
?(stopping_criteria = []) () =
let input_ids = tokenizer prompt in
let output =
generate ~model ~input_ids ~generation_config ~logits_processor
~stopping_criteria ()
in
match output.sequences with seq :: _ -> decoder seq | [] -> ""
================================================
FILE: dev/mimir/lib/sampler.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Text generation with composable logits processors.
Provides the autoregressive decode loop, composable logits processors,
stopping criteria, and generation configuration for language model
inference. Operates on nx tensors for logits. *)
(** {1 Core Types} *)
type logits = (float, Bigarray.float32_elt) Nx.t
(** 1D float32 tensor of unnormalized log-probabilities (logit scores), one
    per vocabulary entry. Length equals vocabulary size. *)
type token_ids = int array
(** Sequence of token IDs representing encoded text. *)
type logits_processor = {
name : string;
process : prompt_length:int -> token_ids -> logits -> logits;
}
(** Transforms logits before sampling. *)
type logits_processor_list = logits_processor list
type stopping_criterion = {
name : string;
should_stop : prompt_length:int -> start_time:float -> token_ids -> bool;
}
(** Determines when to end generation. *)
type stopping_criteria_list = stopping_criterion list
(** {1 Generation Configuration} *)
type generation_config = {
max_length : int;
max_new_tokens : int option;
min_length : int;
min_new_tokens : int;
do_sample : bool;
temperature : float;
top_k : int;
top_p : float;
repetition_penalty : float;
no_repeat_ngram_size : int;
bad_words_ids : int list list;
force_words_ids : int list list;
pad_token_id : int option;
bos_token_id : int option;
eos_token_id : int option;
eos_token_ids : int list;
}
val default : generation_config
(** {2 Builder Pattern} *)
val with_temperature : float -> generation_config -> generation_config
val with_top_k : int -> generation_config -> generation_config
val with_top_p : float -> generation_config -> generation_config
val with_repetition_penalty : float -> generation_config -> generation_config
val with_max_length : int -> generation_config -> generation_config
val with_max_new_tokens : int -> generation_config -> generation_config
val with_min_length : int -> generation_config -> generation_config
val with_min_new_tokens : int -> generation_config -> generation_config
val with_no_repeat_ngram : int -> generation_config -> generation_config
val with_do_sample : bool -> generation_config -> generation_config
(** {2 Presets} *)
val creative_writing : generation_config
val chat : generation_config
val code_generation : generation_config
val factual : generation_config
val from_preset : string -> generation_config
(** {1 Logits Processors} *)
val temperature_warper : temperature:float -> logits_processor
val top_k_warper : k:int -> logits_processor
val top_p_warper : p:float -> logits_processor
val repetition_penalty : penalty:float -> logits_processor
val no_repeat_ngram : ngram_size:int -> logits_processor
val min_length : min_length:int -> eos_token_ids:int list -> logits_processor
val min_new_tokens :
min_new_tokens:int -> eos_token_ids:int list -> logits_processor
val bad_words : bad_words_ids:int list list -> logits_processor
val force_words :
force_words_ids:int list list -> iteration:int -> logits_processor
val custom :
name:string ->
process:(prompt_length:int -> token_ids -> logits -> logits) ->
logits_processor
(** {1 Stopping Criteria} *)
val max_length_criteria : max_length:int -> stopping_criterion
val max_new_tokens_criteria : max_new_tokens:int -> stopping_criterion
val eos_token_criteria : eos_token_ids:int list -> stopping_criterion
val max_time_criteria : max_time:float -> stopping_criterion
val stop_strings_criteria :
stop_strings:string list ->
decoder:(token_ids -> string) ->
stopping_criterion
val custom_criteria :
name:string ->
should_stop:(prompt_length:int -> start_time:float -> token_ids -> bool) ->
stopping_criterion
(** {1 Generation} *)
type generation_output = {
sequences : int array list;
scores : float list list option;
}
val generate :
model:(token_ids -> logits) ->
?input_ids:token_ids ->
?generation_config:generation_config ->
?logits_processor:logits_processor_list ->
?stopping_criteria:stopping_criteria_list ->
unit ->
generation_output
val generate_text :
model:(token_ids -> logits) ->
tokenizer:(string -> token_ids) ->
decoder:(token_ids -> string) ->
?prompt:string ->
?generation_config:generation_config ->
?logits_processor:logits_processor_list ->
?stopping_criteria:stopping_criteria_list ->
unit ->
string
(** {1 Utilities} *)
val apply_processors :
processors:logits_processor_list ->
prompt_length:int ->
tokens:token_ids ->
logits:logits ->
logits
val check_stopping :
criteria:stopping_criteria_list ->
prompt_length:int ->
start_time:float ->
tokens:token_ids ->
bool
================================================
FILE: dev/mimir/lib/string_util.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let contains_substring s sub =
let len_s = String.length s in
let len_sub = String.length sub in
if len_sub = 0 then true
else if len_sub > len_s then false
else
let rec check i =
if i > len_s - len_sub then false
else if String.sub s i len_sub = sub then true
else check (i + 1)
in
check 0
================================================
FILE: dev/umbra/README.md
================================================
# Umbra
Computational astronomy for OCaml, powered by [Nx](../../packages/nx/) and [Rune](../../packages/rune/).
Umbra provides dimensionally-typed physical quantities, cosmological distances,
spectral energy distributions, dust extinction, synthetic photometry, coordinate
transforms, time scales, catalog cross-matching, and weak lensing survey science.
All computations operate on Nx tensors and are differentiable through Rune --
fit cosmological parameters, propagate uncertainties via Jacobians, or sample
posteriors with HMC, all from the same forward model.
## Quick Start
Compute the luminosity distance to a galaxy at redshift 0.5:
```ocaml
open Umbra
let () =
let f64 = Nx.float64 in
let z = Nx.scalar f64 0.5 in
let dl = Cosmo.luminosity_distance ~p:Cosmo.planck18 z in
Printf.printf "d_L(z=0.5) = %.1f Mpc\n"
(Nx.item [] (Unit.Length.in_mpc dl))
```
Fit stellar temperature from photometry with automatic derivatives:
```ocaml
let model params =
let temp = Unit.Temperature.of_kelvin (Nx.exp (Nx.slice [ I 0 ] params)) in
let av = Nx.reshape [||] (Nx.slice [ I 1 ] params) in
let rv = Nx.scalar Nx.float64 3.1 in
List.map (fun bp ->
let wave = Photometry.wavelength bp in
let sed =
Spectrum.blackbody ~temperature:temp ~wavelength:wave
|> Extinction.apply (Extinction.ccm89 ~rv) ~av
|> Spectrum.as_flux_density
in
Photometry.ab_mag bp sed) bands
|> Nx.stack ~axis:0
(* Rune differentiates through the entire pipeline *)
let loss, grad = Rune.value_and_grad chi2 params
```
## Features
- **Dimensional types**: `Unit.Length`, `Unit.Mass`, `Unit.Time`, `Unit.Angle`, etc. with compile-time safety
- **Physical constants**: CODATA 2022 and IAU 2015 via `Const`
- **Cosmology**: LCDM, wCDM, w0waCDM distances, growth factors, and matter power spectra via `Cosmo`
- **Spectra**: blackbody, power-law, and line profiles (Gaussian, Lorentzian, Voigt) via `Spectrum`
- **Extinction**: CCM89, Fitzpatrick99, O'Donnell94, Calzetti00 dust laws via `Extinction`
- **Photometry**: AB, ST, and Vega magnitudes through standard filter bandpasses via `Photometry`
- **Filters**: SDSS, Johnson-Cousins, 2MASS, Gaia DR3, Rubin/LSST, Euclid via `Filters`
- **Coordinates**: ICRS, Galactic, Ecliptic, Supergalactic frame transforms and kd-tree cross-matching via `Coord`
- **Time**: UTC, TAI, TT, TDB time scales with phantom-typed safety via `Time`
- **Observer geometry**: altitude-azimuth coordinates and airmass via `Altaz`
- **Survey science**: angular power spectra and Fisher forecasting via `Survey`
- **FITS I/O**: image and table read/write via `Umbra_fits`
- **Fully differentiable**: all forward models work with Rune's autodiff, Jacobians, and MCMC
## Examples
| Example | Concept |
|---------|---------|
| [`01-constants-and-units`](examples/01-constants-and-units/) | Type-safe physical quantities and conversions |
| [`02-cosmological-distances`](examples/02-cosmological-distances/) | LCDM distances and SN Ia fitting |
| [`03-blackbody-fitting`](examples/03-blackbody-fitting/) | Fit stellar temperature from photometry |
| [`04-extinction-and-magnitudes`](examples/04-extinction-and-magnitudes/) | Dust extinction, magnitude systems, K-corrections |
| [`05-sed-fitting`](examples/05-sed-fitting/) | Full SED pipeline: blackbody, extinction, photometry |
| [`06-coordinates-and-time`](examples/06-coordinates-and-time/) | Frame transforms, time scales, observer geometry |
| [`07-batch-photometry`](examples/07-batch-photometry/) | Batched operations over parameter grids |
| [`08-photometric-redshifts`](examples/08-photometric-redshifts/) | Two-stage photo-z: grid search + gradient refinement |
| [`09-gravitational-lensing`](examples/09-gravitational-lensing/) | Point-mass lens model parameter fitting |
| [`10-uncertainty-propagation`](examples/10-uncertainty-propagation/) | AD Jacobians for error propagation vs Monte Carlo |
| [`11-bayesian-sed`](examples/11-bayesian-sed/) | Fisher matrix + HMC posterior sampling |
| [`12-survey-optimization`](examples/12-survey-optimization/) | Differentiable Fisher forecasting for survey design |
## Papers
- [**Perlmutter et al. 1999**](papers/perlmutter1999/) -- Reproducing the Nobel Prize-winning discovery of cosmic acceleration using the Pantheon+ dataset
## Contributing
See the [Raven monorepo README](../../README.md) for guidelines.
## License
ISC License. See [LICENSE](../../LICENSE) for details.
================================================
FILE: dev/umbra/dune-project
================================================
(lang dune 3.21)
(name umbra)
(package
(name umbra)
(synopsis "Astronomy library for OCaml")
(description
"Physical units, celestial coordinates, FITS I/O, cosmological distances, and catalog cross-matching. Built on Nx and Talon.")
(depends
(ocaml
(>= 5.2.0))
dune
(nx
(>= 1.0.0~alpha3))
(talon
(>= 1.0.0~alpha3))
(windtrap :with-test)))
================================================
FILE: dev/umbra/examples/01-constants-and-units/README.md
================================================
# `01-constants-and-units`
Introduction to Umbra's type-safe unit system and physical constants. Creates
quantities in different units, converts between them, and demonstrates how
phantom types prevent mixing incompatible dimensions at compile time.
```bash
dune exec dev/umbra/examples/01-constants-and-units/main.exe
```
## What You'll Learn
- Creating quantities with scalar constructors (`Length.pc`, `Angle.deg`, `Mass.solar_mass`)
- Converting between units (`Length.in_ly`, `Angle.in_arcsec`)
- Adding quantities of the same dimension (`Unit.(+)`)
- Using physical constants (`Const.c`, `Const.h_si`, `Const.k_b_si`)
- Cross-dimension conversions (`parallax_to_distance`, `wavelength_to_frequency`)
- Batch operations on tensor-valued quantities
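The compile-time safety mentioned above rests on phantom type parameters. A toy model of the mechanism in plain OCaml (not Umbra's actual `Unit` implementation, just the idea it relies on):

```ocaml
(* Toy phantom-typed quantities: the dimension lives only in the type. *)
type length
type mass

type 'dim quantity = { si : float } (* value in SI base units *)

let metres v : length quantity = { si = v }
let kilograms v : mass quantity = { si = v }

(* Addition unifies the phantom parameter, so only matching dimensions combine. *)
let add (a : 'dim quantity) (b : 'dim quantity) : 'dim quantity =
  { si = a.si +. b.si }

let total = add (metres 10.0) (metres 500.0) (* ok: both lengths *)

(* add (metres 1.0) (kilograms 1.0)
   ^ rejected at compile time: length and mass do not unify *)

let () = Printf.printf "total = %.1f m\n" total.si
```

Umbra's real quantities wrap Nx tensors rather than bare floats, but the dimension check works the same way.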
## Key Functions
| Function | Purpose |
| --------------------------- | -------------------------------------------- |
| `Length.pc`, `Length.au` | Create length quantities in parsecs, AU |
| `Length.in_m`, `Length.in_ly` | Extract values in metres, light-years |
| `Angle.deg`, `Angle.arcsec` | Create angles in degrees, arcseconds |
| `Temperature.kelvin` | Create temperature quantities |
| `Mass.solar_mass` | Create mass in solar masses |
| `Const.c`, `Const.h_si` | Speed of light, Planck constant |
| `parallax_to_distance` | Convert stellar parallax to distance |
| `wavelength_to_frequency` | Convert wavelength to frequency via c/lambda |
## Try It
1. Compute the Schwarzschild radius of the Sun using `Const.g_si`, `Const.solar_mass`, and `Const.c`.
2. Add `Length.ly 4.246` (Proxima Centauri) and check it matches the parallax-derived distance.
3. Use `Unit.doppler_optical` to compute the observed wavelength of H-alpha at a radial velocity of 100 km/s.
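For the first exercise, the quantity in question is r_s = 2GM/c^2. A plain-float sketch with the constant values inlined (the exercise proper would pull them from `Const` instead):

```ocaml
(* Schwarzschild radius of the Sun: r_s = 2 G M / c^2. *)
let g_si = 6.67430e-11 (* gravitational constant, m^3 kg^-1 s^-2 *)
let m_sun = 1.989e30   (* solar mass, kg *)
let c = 2.99792458e8   (* speed of light, m/s (exact) *)

let r_s = 2.0 *. g_si *. m_sun /. (c *. c)
let () = Printf.printf "r_s(Sun) = %.0f m (~%.2f km)\n" r_s (r_s /. 1e3)
```

The result lands near 2.95 km, a handy sanity check for the `Const`-based version.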
## Next Steps
Continue to [02-cosmological-distances](../02-cosmological-distances/) to compute
distances and times in an expanding universe.
================================================
FILE: dev/umbra/examples/01-constants-and-units/dune
================================================
(executable
(name main)
(libraries nx umbra))
================================================
FILE: dev/umbra/examples/01-constants-and-units/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Type-safe units and physical constants.
Introduces Umbra's dimensional type system: quantities carry phantom types
that prevent mixing incompatible dimensions at compile time. Shows how to
create, convert, and combine quantities in different units, and how to use
physical and astronomical constants. *)
open Nx
open Umbra
let f64 = Nx.float64
let () =
Printf.printf "Type-safe units and physical constants\n";
Printf.printf "======================================\n\n";
(* --- Length: metres, parsecs, AU, light-years --- *)
Printf.printf "Length conversions\n";
Printf.printf "------------------\n";
let d_pc = Unit.Length.pc 1.0 in
Printf.printf " 1 parsec = %.4e m\n" (item [] (Unit.Length.in_m d_pc));
Printf.printf " 1 parsec = %.6f ly\n" (item [] (Unit.Length.in_ly d_pc));
Printf.printf " 1 parsec = %.0f AU\n" (item [] (Unit.Length.in_au d_pc));
let d_au = Unit.Length.au 1.0 in
Printf.printf " 1 AU = %.4e m\n" (item [] (Unit.Length.in_m d_au));
Printf.printf " 1 AU = %.4e pc\n\n" (item [] (Unit.Length.in_pc d_au));
(* Adding lengths of different units — the type system ensures consistency *)
let d_total = Unit.( + ) (Unit.Length.kpc 10.0) (Unit.Length.pc 500.0) in
Printf.printf " 10 kpc + 500 pc = %.3f kpc\n\n"
(item [] (Unit.Length.in_kpc d_total));
(* --- Angle: degrees, radians, arcseconds --- *)
Printf.printf "Angle conversions\n";
Printf.printf "-----------------\n";
let a_deg = Unit.Angle.deg 1.0 in
Printf.printf " 1 degree = %.6f rad\n" (item [] (Unit.Angle.in_rad a_deg));
Printf.printf " 1 degree = %.1f arcmin\n"
(item [] (Unit.Angle.in_arcmin a_deg));
Printf.printf " 1 degree = %.1f arcsec\n"
(item [] (Unit.Angle.in_arcsec a_deg));
let a_mas = Unit.Angle.mas 1.0 in
Printf.printf " 1 mas = %.4e arcsec\n\n"
(item [] (Unit.Angle.in_arcsec a_mas));
(* --- Temperature --- *)
Printf.printf "Temperature\n";
Printf.printf "-----------\n";
let sun_t = Unit.Temperature.kelvin 5778.0 in
Printf.printf " Sun surface: %.0f K\n"
(item [] (Unit.Temperature.in_kelvin sun_t));
let sirius_t = Unit.Temperature.kelvin 9940.0 in
Printf.printf " Sirius: %.0f K\n\n"
(item [] (Unit.Temperature.in_kelvin sirius_t));
(* --- Time durations --- *)
Printf.printf "Time durations\n";
Printf.printf "--------------\n";
let t_yr = Unit.Time.yr 1.0 in
Printf.printf " 1 Julian year = %.0f days\n"
(item [] (Unit.Time.in_day t_yr));
Printf.printf " 1 Julian year = %.2f s\n" (item [] (Unit.Time.in_s t_yr));
let t_gyr = Unit.Time.gyr 13.8 in
Printf.printf " Age of universe ~ %.2e yr\n\n"
(item [] (Unit.Time.in_yr t_gyr));
(* --- Mass: kg, solar masses, Earth masses --- *)
Printf.printf "Mass conversions\n";
Printf.printf "----------------\n";
let m_sun = Unit.Mass.solar_mass 1.0 in
Printf.printf " 1 solar mass = %.4e kg\n" (item [] (Unit.Mass.in_kg m_sun));
Printf.printf " 1 solar mass = %.0f Earth masses\n"
(item [] (Unit.Mass.in_earth_mass m_sun));
Printf.printf " 1 solar mass = %.1f Jupiter masses\n\n"
(item [] (Unit.Mass.in_jupiter_mass m_sun));
(* --- Physical constants --- *)
Printf.printf "Physical constants\n";
Printf.printf "------------------\n";
Printf.printf " c = %.0f m/s\n" (Unit.to_float Const.c);
Printf.printf " h = %.4e J s\n" Const.h_si;
Printf.printf " k_B = %.4e J/K\n" Const.k_b_si;
Printf.printf " G = %.4e m^3 kg^-1 s^-2\n" Const.g_si;
Printf.printf " sigma = %.4e W m^-2 K^-4\n\n" Const.sigma_sb_si;
(* --- Astronomical constants --- *)
Printf.printf "Astronomical constants\n";
Printf.printf "----------------------\n";
Printf.printf " L_sun = %.4e W\n"
(item [] (Unit.Power.in_w Const.solar_luminosity));
Printf.printf " R_sun = %.4e m\n"
(item [] (Unit.Length.in_m Const.solar_radius));
Printf.printf " M_sun = %.4e kg\n"
(item [] (Unit.Mass.in_kg Const.solar_mass));
Printf.printf " 1 AU = %.4e m\n" (item [] (Unit.Length.in_m Const.au));
Printf.printf " 1 pc = %.4e m\n\n" (item [] (Unit.Length.in_m Const.pc));
(* --- Cross-dimension: parallax to distance --- *)
Printf.printf "Parallax to distance\n";
Printf.printf "--------------------\n";
let parallax = Unit.Angle.arcsec 1.0 in
let dist = Unit.parallax_to_distance parallax in
Printf.printf " 1 arcsec parallax -> %.3f pc\n"
(item [] (Unit.Length.in_pc dist));
let proxima_parallax = Unit.Angle.mas 768.5 in
let proxima_dist = Unit.parallax_to_distance proxima_parallax in
Printf.printf " Proxima Cen (768.5 mas) -> %.3f pc\n"
(item [] (Unit.Length.in_pc proxima_dist));
(* --- Tensor operations: batch unit conversions --- *)
Printf.printf "\nBatch operations\n";
Printf.printf "----------------\n";
let wavelengths_nm =
create f64 [| 5 |] [| 380.0; 450.0; 550.0; 650.0; 750.0 |]
in
let wavelengths = Unit.Length.of_nm wavelengths_nm in
let wavelengths_angstrom = Unit.Length.in_angstrom wavelengths in
Printf.printf " Wavelengths (nm): %s\n"
(Nx.data_to_string wavelengths_nm);
Printf.printf " Wavelengths (angstrom): %s\n"
(Nx.data_to_string wavelengths_angstrom);
(* Convert wavelength to frequency *)
let freqs = Unit.wavelength_to_frequency wavelengths in
Printf.printf " Frequencies (Hz): %s\n"
(Nx.data_to_string (Unit.Frequency.in_hz freqs));
(* Wien's law: peak wavelength of a blackbody *)
Printf.printf "\nWien's displacement law\n";
Printf.printf "-----------------------\n";
let b_wien = Const.b_wien_si in
let sun_peak_m = b_wien /. item [] (Unit.Temperature.in_kelvin sun_t) in
Printf.printf " Sun (T=%.0f K): peak at %.0f nm\n"
(item [] (Unit.Temperature.in_kelvin sun_t))
(sun_peak_m *. 1e9);
let sirius_peak_m = b_wien /. item [] (Unit.Temperature.in_kelvin sirius_t) in
Printf.printf " Sirius (T=%.0f K): peak at %.0f nm\n"
(item [] (Unit.Temperature.in_kelvin sirius_t))
(sirius_peak_m *. 1e9)
================================================
FILE: dev/umbra/examples/02-cosmological-distances/README.md
================================================
# `02-cosmological-distances`
Cosmological distance calculations and parameter fitting. First prints a
distance table for the Planck 2018 cosmology, then fits H0 and Omega_m from
synthetic Type Ia supernova distance moduli using gradient descent.
```bash
dune exec dev/umbra/examples/02-cosmological-distances/main.exe
```
## What You'll Learn
- Using preset cosmologies (`Cosmo.planck18`)
- Computing distances (`comoving_distance`, `luminosity_distance`, `angular_diameter_distance`)
- Computing distance modulus and lookback time
- Building differentiable cosmological models with `create_flat_lcdm`
- Fitting cosmological parameters with Rune autodiff and Vega optimizers
## Key Functions
| Function | Purpose |
| ----------------------------- | --------------------------------------------- |
| `Cosmo.planck18` | Planck 2018 flat LCDM preset |
| `Cosmo.comoving_distance` | Line-of-sight comoving distance |
| `Cosmo.luminosity_distance` | Luminosity distance at redshift z |
| `Cosmo.distance_modulus` | Distance modulus mu = 5 log10(d_L/Mpc) + 25 |
| `Cosmo.lookback_time` | Time since light was emitted |
| `Cosmo.age` | Age of the universe at redshift z |
| `Cosmo.create_flat_lcdm` | Tensor-parameterized cosmology for autodiff |
| `Rune.value_and_grads` | Forward pass + gradient computation |
## How It Works
The distance modulus forward model uses `Cosmo.distance_modulus`, which
internally integrates E(z) via 16-point Gauss-Legendre quadrature. Since all
operations are Nx tensor ops, gradients flow through the entire pipeline
automatically via Rune.
The optimizer starts from H0=65, Omega_m=0.25 and converges toward the true
values (H0~73, Omega_m~0.3) that generated the synthetic data.
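The shape of that forward model can be sketched with plain floats (composite Simpson's rule stands in here for Umbra's 16-point Gauss-Legendre quadrature, and the H0/Omega_m values are Planck 2018-like):

```ocaml
(* Distance modulus for flat LCDM: mu = 5 log10(d_L / Mpc) + 25,
   E(z) = sqrt(Omega_m (1+z)^3 + Omega_L),
   d_C = (c / H0) * int_0^z dz' / E(z'),  d_L = (1+z) d_C. *)
let c_km_s = 299792.458

let e_of_z ~omega_m z =
  sqrt ((omega_m *. ((1. +. z) ** 3.)) +. (1. -. omega_m))

(* Composite Simpson's rule over [0, z] with n (even) intervals. *)
let integrate f z n =
  let h = z /. float_of_int n in
  let sum = ref (f 0. +. f z) in
  for i = 1 to n - 1 do
    let w = if i mod 2 = 1 then 4. else 2. in
    sum := !sum +. (w *. f (h *. float_of_int i))
  done;
  !sum *. h /. 3.

let distance_modulus ~h0 ~omega_m z =
  let d_c =
    c_km_s /. h0 *. integrate (fun z' -> 1. /. e_of_z ~omega_m z') z 100
  in
  let d_l = (1. +. z) *. d_c in (* Mpc *)
  (5. *. log10 d_l) +. 25.

let () =
  Printf.printf "mu(z=1) = %.2f\n"
    (distance_modulus ~h0:67.66 ~omega_m:0.3111 1.0)
```

With these parameters mu(z=1) comes out near 44.2, consistent with the distance table printed by the example. Swapping the float ops for Nx tensor ops is what makes the real model differentiable.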
## Try It
1. Change the preset to `Cosmo.wmap9` and compare the distance table.
2. Add `Omega_L` as a free parameter using `create_lcdm` for a non-flat model.
3. Use `Cosmo.z_at_value` to find the redshift where the lookback time is 10 Gyr.
## Next Steps
Continue to [03-blackbody-fitting](../03-blackbody-fitting/) to fit stellar
temperatures from photometry.
================================================
FILE: dev/umbra/examples/02-cosmological-distances/dune
================================================
(executable
(name main)
(libraries nx rune vega umbra))
================================================
FILE: dev/umbra/examples/02-cosmological-distances/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Differentiable cosmological parameter fitting from Type Ia supernova distance
moduli.
Fits H0 (Hubble constant) and Omega_m (matter density fraction) by gradient
descent on the distance modulus residuals. The forward model uses
Umbra.Cosmo.distance_modulus directly -- its Gauss-Legendre quadrature,
luminosity distance, and distance modulus are all Nx tensor operations,
making them natively differentiable through Rune's autodiff.
Also demonstrates basic cosmological distance queries: comoving distance,
luminosity distance, angular diameter distance, lookback time, and the age of
the universe at various redshifts. *)
open Nx
open Umbra
let f64 = Nx.float64
(* --- Part 1: Distance table for the Planck 2018 cosmology --- *)
let print_distance_table () =
Printf.printf "Cosmological distances (Planck 2018)\n";
Printf.printf "====================================\n\n";
let p = Cosmo.planck18 in
Printf.printf " H0 = %.2f km/s/Mpc\n" (item [] (Cosmo.h0 p));
Printf.printf " Omega_m = %.4f\n" (item [] (Cosmo.omega_m p));
Printf.printf " Omega_L = %.4f\n\n" (item [] (Cosmo.omega_l p));
Printf.printf "%6s %10s %10s %10s %8s %8s\n" "z" "d_C (Mpc)" "d_L (Mpc)"
"d_A (Mpc)" "mu" "t_lb (Gyr)";
Printf.printf "%6s %10s %10s %10s %8s %8s\n" "------" "----------"
"----------" "----------" "--------" "----------";
let redshifts = [| 0.01; 0.1; 0.3; 0.5; 1.0; 2.0; 3.0; 5.0 |] in
Array.iter
(fun z ->
let zv = scalar f64 z in
let d_c = item [] (Unit.Length.in_mpc (Cosmo.comoving_distance ~p zv)) in
let d_l =
item [] (Unit.Length.in_mpc (Cosmo.luminosity_distance ~p zv))
in
let d_a =
item [] (Unit.Length.in_mpc (Cosmo.angular_diameter_distance ~p zv))
in
let mu = item [] (Cosmo.distance_modulus ~p zv) in
let t_lb = item [] (Unit.Time.in_gyr (Cosmo.lookback_time ~p zv)) in
Printf.printf "%6.2f %10.1f %10.1f %10.1f %8.2f %8.2f\n" z d_c d_l
d_a mu t_lb)
redshifts;
Printf.printf "\n";
(* Age of the universe *)
let age_now = item [] (Unit.Time.in_gyr (Cosmo.age ~p (scalar f64 0.0))) in
Printf.printf " Age of the universe (z=0): %.2f Gyr\n\n" age_now
(* --- Part 2: Fit H0 and Omega_m from SN Ia data --- *)
(* Representative SN Ia data points (z, observed distance modulus). Based on
Pantheon+ compilation values for flat LCDM with H0 ~ 73, Omega_m ~ 0.3. *)
let z_arr = [| 0.01; 0.03; 0.08; 0.15; 0.25; 0.40; 0.55; 0.70; 0.85; 1.00 |]
let n_sn = Array.length z_arr
let mu_obs =
[| 33.07; 35.47; 37.62; 39.07; 40.24; 41.42; 42.23; 42.85; 43.34; 43.74 |]
(* Forward model: compute distance modulus for all SNe. The differentiable
parameters are H0 and Omega_m, which flow through Cosmo.distance_modulus via
Nx tensor operations. *)
let loss params =
match params with
| [ h0; omega_m ] ->
let p = Cosmo.create_flat_lcdm ~h0 ~omega_m in
let total = ref (scalar f64 0.0) in
for i = 0 to n_sn - 1 do
let z_i = scalar f64 z_arr.(i) in
let mu_pred = Cosmo.distance_modulus ~p z_i in
let mu_obs_i = scalar f64 mu_obs.(i) in
let residual = sub mu_pred mu_obs_i in
total := add !total (square residual)
done;
div_s !total (Float.of_int n_sn)
| _ -> failwith "expected [h0; omega_m]"
let fit_cosmology () =
Printf.printf "Fitting H0 and Omega_m from Type Ia supernovae\n";
Printf.printf "===============================================\n";
Printf.printf " Data: %d distance moduli (Pantheon+-like)\n" n_sn;
Printf.printf " Method: Adam optimizer, 300 steps\n";
Printf.printf " Model: flat LCDM via Cosmo.distance_modulus\n\n";
let algo = Vega.adam (Vega.Schedule.constant 0.5) in
let h0 = ref (scalar f64 65.0) in
let omega_m = ref (scalar f64 0.25) in
let states = [| Vega.init algo !h0; Vega.init algo !omega_m |] in
let steps = 300 in
Printf.printf "%5s %10s %8s %8s\n" "step" "loss" "H0" "Omega_m";
Printf.printf "%5s %10s %8s %8s\n" "-----" "----------" "--------"
"--------";
let refs = [| h0; omega_m |] in
for i = 0 to steps - 1 do
let loss_val, grads = Rune.value_and_grads loss [ !h0; !omega_m ] in
List.iteri
(fun j g ->
let p, s = Vega.step states.(j) ~grad:g ~param:!(refs.(j)) in
refs.(j) := p;
states.(j) <- s)
grads;
if i mod 50 = 0 || i = steps - 1 then
Printf.printf "%5d %10.6f %8.2f %8.4f\n" i (item [] loss_val)
(item [] !h0) (item [] !omega_m)
done;
Printf.printf "\nFitted parameters:\n";
Printf.printf " H0 = %.2f km/s/Mpc\n" (item [] !h0);
Printf.printf " Omega_m = %.4f\n" (item [] !omega_m)
let () =
print_distance_table ();
fit_cosmology ()
================================================
FILE: dev/umbra/examples/03-blackbody-fitting/README.md
================================================
# `03-blackbody-fitting`
Fits the effective temperature and luminosity normalization of a star from
synthetic UGRIZ broadband photometry using gradient descent on a blackbody
model.
```bash
dune exec dev/umbra/examples/03-blackbody-fitting/main.exe
```
## What You'll Learn
- Using physical constants (`Const.h_si`, `Const.k_b_si`, `Const.c`)
- Building a differentiable Planck function from Nx tensor operations
- Parameterizing in log-space for numerical stability
- Fitting chi-squared with Rune autodiff and Vega's Adam optimizer
## Key Functions
| Function | Purpose |
| --------------------- | -------------------------------------------------- |
| `Const.h_si` | Planck constant (J s) |
| `Const.k_b_si` | Boltzmann constant (J/K) |
| `Const.c` | Speed of light (typed velocity) |
| `Unit.to_float` | Extract scalar SI value from a typed constant |
| `Rune.value_and_grads`| Compute loss and gradients in one pass |
| `Vega.adam` | Adam optimizer |
| `Vega.step` | Apply one optimization step |
## How It Works
The Planck spectral radiance B(lambda, T) = 2hc^2 / lambda^5 / (exp(hc /
lambda k T) - 1) is implemented entirely with Nx tensor operations. Since Rune
can differentiate any Nx computation, gradients of chi-squared with respect to
log(T) and log(A) are computed automatically.
The optimizer starts from T=5000 K and converges toward the true temperature
of 5800 K (Sun-like star). Log-space parameterization ensures positivity and
improves gradient conditioning.
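The same Planck function in plain floats (constants inlined rather than taken from `Const`), with Wien's displacement law as a sanity check:

```ocaml
(* Planck spectral radiance:
   B(lambda, T) = 2 h c^2 / lambda^5 / (exp(h c / (lambda k T)) - 1). *)
let h = 6.62607015e-34 (* Planck constant, J s *)
let c = 2.99792458e8   (* speed of light, m/s *)
let k_b = 1.380649e-23 (* Boltzmann constant, J/K *)

let planck lam t =
  let x = h *. c /. (lam *. k_b *. t) in
  2.0 *. h *. c *. c /. (lam ** 5.) /. (exp x -. 1.0)

let () =
  (* Wien's law puts the T = 5800 K peak near b / T ~ 500 nm, so the
     radiance there should exceed that at 400 nm and 700 nm. *)
  let t = 5800.0 in
  Printf.printf "B(400nm)=%.3e B(500nm)=%.3e B(700nm)=%.3e\n"
    (planck 4.0e-7 t) (planck 5.0e-7 t) (planck 7.0e-7 t)
```

The example's `loss` function is this same expression written with `exp`, `div`, `sub`, and `mul` on tensors instead of floats.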
## Try It
1. Change the true temperature to 10000 K (A-type star) and observe how the
SED shape changes.
2. Add a third parameter for a dust extinction term.
3. Replace the central-wavelength approximation with proper filter integration
using `Photometry.ab_mag` (see example 05).
## Next Steps
Continue to [04-extinction-and-magnitudes](../04-extinction-and-magnitudes/) to
learn about dust extinction, K-corrections, and magnitude systems.
================================================
FILE: dev/umbra/examples/03-blackbody-fitting/dune
================================================
(executable
(name main)
(libraries nx rune vega umbra))
================================================
FILE: dev/umbra/examples/03-blackbody-fitting/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Differentiable blackbody SED fitting.
Given broadband photometric measurements in UGRIZ bands, fit the stellar
effective temperature and luminosity normalization by gradient descent on the
chi-squared statistic. The Planck function is evaluated as Nx tensor
operations, making it fully differentiable through Rune.
Uses Umbra.Const for physical constants. *)
open Nx
open Umbra
let f64 = Nx.float64
(* Central wavelengths of SDSS UGRIZ bands in meters *)
let lambda =
create f64 [| 5 |] [| 3.551e-7; 4.686e-7; 6.166e-7; 7.480e-7; 8.932e-7 |]
(* Physical constants from Umbra *)
let h_planck = Const.h_si
let c_light = Unit.to_float Const.c
let k_boltz = Const.k_b_si
(* Pre-computed constant tensors *)
let two_hc2 = scalar f64 (2.0 *. h_planck *. c_light *. c_light)
let hc_over_k = scalar f64 (h_planck *. c_light /. k_boltz)
let lam5 = pow_s lambda 5.0
(* Generate synthetic observations from a Sun-like star *)
let true_temp = 5800.0
let true_log_norm = -50.0
let planck_scalar lam_m temp =
let x = h_planck *. c_light /. (lam_m *. k_boltz *. temp) in
2.0 *. h_planck *. c_light *. c_light
/. (lam_m *. lam_m *. lam_m *. lam_m *. lam_m)
/. (Float.exp x -. 1.0)
let flux_obs =
let norm = Float.exp true_log_norm in
let fluxes =
Array.init 5 (fun i ->
let lam_m =
[| 3.551e-7; 4.686e-7; 6.166e-7; 7.480e-7; 8.932e-7 |].(i)
in
norm
*. planck_scalar lam_m true_temp
*. (1.0 +. (0.02 *. (Float.of_int i -. 2.0))))
in
create f64 [| 5 |] fluxes
(* Fractional errors: 5% photometry *)
let flux_err = mul_s flux_obs 0.05
let band_names = [| "U"; "G"; "R"; "I"; "Z" |]
(* Differentiable forward model: Planck function at 5 wavelengths. Parameterized
in log-space for positivity and gradient conditioning.
B(lambda, T) = 2hc^2 / lambda^5 / (exp(hc / (lambda * k * T)) - 1) *)
let loss params =
match params with
| [ log_temp; log_norm ] ->
let temp = exp log_temp in
let norm = exp log_norm in
let exponent = div hc_over_k (mul lambda temp) in
let planck =
div (div two_hc2 lam5) (sub (exp exponent) (scalar f64 1.0))
in
let flux_pred = mul norm planck in
let residual = div (sub flux_pred flux_obs) flux_err in
sum (square residual)
| _ -> failwith "expected [log_temp; log_norm]"
let () =
Printf.printf "Differentiable blackbody SED fitting\n";
Printf.printf "====================================\n";
Printf.printf "Fitting temperature and normalization to UGRIZ photometry\n\n";
Printf.printf "True parameters:\n";
Printf.printf " T = %.0f K\n" true_temp;
Printf.printf " logA = %.1f\n\n" true_log_norm;
Printf.printf "Synthetic observations (5%% errors):\n";
for i = 0 to 4 do
Printf.printf " %s: %.4e +/- %.4e\n" band_names.(i) (item [ i ] flux_obs)
(item [ i ] flux_err)
done;
Printf.printf "\n";
(* Start from a guess *)
let algo = Vega.adam (Vega.Schedule.constant 1e-2) in
let log_temp = ref (scalar f64 (Float.log 5000.0)) in
let log_norm = ref (scalar f64 (-52.0)) in
let states = [| Vega.init algo !log_temp; Vega.init algo !log_norm |] in
let steps = 500 in
Printf.printf "%5s %12s %8s %10s\n" "step" "chi2" "T (K)" "log_norm";
Printf.printf "%5s %12s %8s %10s\n" "-----" "------------" "--------"
"----------";
let refs = [| log_temp; log_norm |] in
for i = 0 to steps - 1 do
let loss_val, grads = Rune.value_and_grads loss [ !log_temp; !log_norm ] in
List.iteri
(fun j g ->
let p, s = Vega.step states.(j) ~grad:g ~param:!(refs.(j)) in
refs.(j) := p;
states.(j) <- s)
grads;
if i mod 100 = 0 || i = steps - 1 then
Printf.printf "%5d %12.4f %8.1f %10.3f\n" i (item [] loss_val)
(Float.exp (item [] !log_temp))
(item [] !log_norm)
done;
Printf.printf "\nFitted parameters:\n";
Printf.printf " T = %.1f K (true: %.0f K)\n"
(Float.exp (item [] !log_temp))
true_temp;
Printf.printf " logA = %.3f (true: %.1f)\n" (item [] !log_norm)
true_log_norm
================================================
FILE: dev/umbra/examples/04-extinction-and-magnitudes/README.md
================================================
# `04-extinction-and-magnitudes`
Explores three key photometric concepts: magnitude systems (AB, ST, Vega),
K-corrections from redshift, and interstellar dust extinction. Shows how to
compose `Spectrum`, `Extinction`, `Photometry`, and `Filters` modules.
```bash
dune exec dev/umbra/examples/04-extinction-and-magnitudes/main.exe
```
## What You'll Learn
- Computing AB, ST, and Vega magnitudes through real SDSS filters
- Understanding K-corrections from redshift-shifted SEDs
- Applying extinction laws (CCM89, Fitzpatrick99, O'Donnell94)
- Measuring colors and color excess from dust reddening
## Key Functions
| Function | Purpose |
| --------------------------- | ---------------------------------------------- |
| `Photometry.ab_mag` | AB magnitude through a bandpass |
| `Photometry.st_mag` | ST magnitude through a bandpass |
| `Photometry.vega_mag` | Vega magnitude through a bandpass |
| `Photometry.color` | Color index (mag difference between two bands) |
| `Spectrum.blackbody` | Planck spectral radiance |
| `Spectrum.redshift` | Apply cosmological redshift to an SED |
| `Spectrum.as_flux_density` | Cast spectrum to flux density kind |
| `Extinction.ccm89` | Cardelli, Clayton & Mathis (1989) dust law |
| `Extinction.fitzpatrick99` | Fitzpatrick (1999) dust law |
| `Extinction.apply` | Redden a spectrum by A_V magnitudes |
| `Filters.sdss_r` | Pre-built SDSS r-band bandpass |
## How It Works
**Magnitude systems** differ in their reference flux:
- AB: constant f_nu = 3631 Jy
- ST: constant f_lambda = 3.63e-9 erg/s/cm^2/A
- Vega: the spectrum of alpha Lyrae
**K-corrections** arise because redshift moves the SED across the bandpass,
changing the measured flux even without distance dimming. K(z) = m_obs - m_rest.
**Extinction** attenuates and reddens starlight. The extinction curve
A_lambda/A_V depends on wavelength and the dust grain properties (encoded in
R_V). Higher A_V means more dimming; bluer bands are affected more, producing
reddening.
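The reference-flux definitions translate directly into formulas. A plain-float sketch of the AB zero point (Umbra's `Photometry.ab_mag` additionally integrates f_nu against the filter throughput; this is just the band-averaged form):

```ocaml
(* AB magnitude from a band-averaged flux density in janskys:
   m_AB = -2.5 log10 (f_nu / 3631 Jy). *)
let ab_mag_of_jy f_nu_jy = -2.5 *. log10 (f_nu_jy /. 3631.0)

let () =
  Printf.printf "3631 Jy  -> m_AB = %.2f\n" (ab_mag_of_jy 3631.0); (* zero point *)
  Printf.printf "36.31 Jy -> m_AB = %.2f\n" (ab_mag_of_jy 36.31)   (* 100x fainter = +5 mag *)
```

The ST zero point works the same way with f_lambda = 3.63e-9 erg/s/cm^2/A in place of 3631 Jy; Vega magnitudes replace the constant reference with an integral over the alpha Lyrae spectrum.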
## Try It
1. Compare Galactic extinction (CCM89, R_V=3.1) with starburst attenuation
(`Extinction.calzetti00`).
2. Apply both redshift and extinction to see their combined effect on colors.
3. Use `Extinction.unredden` to recover the intrinsic SED from a reddened
observation.
## Next Steps
Continue to [05-sed-fitting](../05-sed-fitting/) to fit temperature, extinction,
and normalization simultaneously.
================================================
FILE: dev/umbra/examples/04-extinction-and-magnitudes/dune
================================================
(executable
(name main)
(libraries nx rune umbra))
================================================
FILE: dev/umbra/examples/04-extinction-and-magnitudes/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* K-corrections, extinction, and magnitude systems.
Demonstrates three key photometric concepts:
1. Magnitude systems: AB, ST, and Vega magnitudes through real SDSS filters.
2. K-corrections: the difference between observed and rest-frame magnitudes
due to redshift shifting the SED across the bandpass.
3. Extinction: how interstellar dust reddens and dims stellar light,
comparing CCM89 and Fitzpatrick99 extinction laws. *)
open Nx
open Umbra
let f64 = Nx.float64
let () =
Printf.printf "Extinction, K-corrections, and magnitude systems\n";
Printf.printf "=================================================\n\n";
(* --- Part 1: Magnitude systems --- *)
Printf.printf "Part 1: AB, ST, and Vega magnitudes\n";
Printf.printf "------------------------------------\n\n";
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 6000.0) in
let norm = Nx.scalar f64 (Float.exp (-50.0)) in
let bands =
[|
("SDSS u", Filters.sdss_u);
("SDSS g", Filters.sdss_g);
("SDSS r", Filters.sdss_r);
("SDSS i", Filters.sdss_i);
("SDSS z", Filters.sdss_z);
|]
in
Printf.printf " Source: T=6000 K blackbody\n\n";
Printf.printf "%8s %8s %8s %8s\n" "Band" "AB" "ST" "Vega";
Printf.printf "%8s %8s %8s %8s\n" "--------" "--------" "--------"
"--------";
Array.iter
(fun (name, bp) ->
let bp_wave = Photometry.wavelength bp in
let sed =
Spectrum.blackbody ~temperature:temp ~wavelength:bp_wave
|> Spectrum.scale norm |> Spectrum.as_flux_density
in
let m_ab = item [] (Photometry.ab_mag bp sed) in
let m_st = item [] (Photometry.st_mag bp sed) in
let m_vega = item [] (Photometry.vega_mag bp sed) in
Printf.printf "%8s %+8.3f %+8.3f %+8.3f\n" name m_ab m_st m_vega)
bands;
Printf.printf "\n Note: AB and ST systems are defined by reference flux\n";
Printf.printf " densities; Vega magnitudes use the alpha Lyr spectrum.\n\n";
(* --- Part 2: K-corrections --- *)
Printf.printf "Part 2: K-corrections\n";
Printf.printf "---------------------\n\n";
let bp = Filters.sdss_r in
let bp_wave = Photometry.wavelength bp in
let rest_sed =
Spectrum.blackbody ~temperature:temp ~wavelength:bp_wave
|> Spectrum.scale norm |> Spectrum.as_flux_density
in
let m_ab_rest = item [] (Photometry.ab_mag bp rest_sed) in
let m_st_rest = item [] (Photometry.st_mag bp rest_sed) in
let m_vega_rest = item [] (Photometry.vega_mag bp rest_sed) in
Printf.printf " Rest-frame SDSS r-band:\n";
Printf.printf " AB = %.3f\n" m_ab_rest;
Printf.printf " ST = %.3f\n" m_st_rest;
Printf.printf " Vega = %.3f\n\n" m_vega_rest;
Printf.printf " K-correction = m_obs(z) - m_rest\n\n";
Printf.printf "%6s %8s %8s %8s\n" "z" "K_AB" "K_ST" "K_Vega";
Printf.printf "%6s %8s %8s %8s\n" "------" "--------" "--------" "--------";
let redshifts = [| 0.1; 0.2; 0.3; 0.5; 0.7; 1.0 |] in
Array.iter
(fun z ->
let zv = Nx.scalar f64 z in
let obs_sed =
Spectrum.blackbody ~temperature:temp ~wavelength:bp_wave
|> Spectrum.scale norm |> Spectrum.as_flux_density
|> Spectrum.redshift ~z:zv
in
let k_ab = item [] (Photometry.ab_mag bp obs_sed) -. m_ab_rest in
let k_st = item [] (Photometry.st_mag bp obs_sed) -. m_st_rest in
let k_vega = item [] (Photometry.vega_mag bp obs_sed) -. m_vega_rest in
Printf.printf "%6.2f %+8.3f %+8.3f %+8.3f\n" z k_ab k_st k_vega)
redshifts;
Printf.printf "\n";
(* --- Part 3: Color evolution with redshift --- *)
Printf.printf "Part 3: Color evolution (u-r) with redshift\n";
Printf.printf "-------------------------------------------\n\n";
Printf.printf "%6s %8s\n" "z" "u-r (AB)";
Printf.printf "%6s %8s\n" "------" "--------";
Array.iter
(fun z ->
let zv = Nx.scalar f64 z in
let color =
item []
(Photometry.color Filters.sdss_u Filters.sdss_r
(Spectrum.blackbody ~temperature:temp ~wavelength:bp_wave
|> Spectrum.scale norm |> Spectrum.as_flux_density
|> Spectrum.redshift ~z:zv))
in
Printf.printf "%6.2f %+8.3f\n" z color)
redshifts;
Printf.printf "\n";
(* --- Part 4: Extinction --- *)
Printf.printf "Part 4: Dust extinction\n";
Printf.printf "-----------------------\n\n";
let rv = Nx.scalar f64 3.1 in
let av_values = [| 0.0; 0.5; 1.0; 2.0; 3.0 |] in
Printf.printf " CCM89 extinction law (R_V = 3.1)\n";
Printf.printf " Reddening a T=6000 K blackbody through SDSS r-band\n\n";
Printf.printf "%6s %8s %8s %8s\n" "A_V" "m_AB" "delta_m" "E(u-r)";
Printf.printf "%6s %8s %8s %8s\n" "------" "--------" "--------" "--------";
let unreddened_sed =
Spectrum.blackbody ~temperature:temp ~wavelength:bp_wave
|> Spectrum.scale norm |> Spectrum.as_flux_density
in
let m0 = item [] (Photometry.ab_mag bp unreddened_sed) in
let color0 =
item [] (Photometry.color Filters.sdss_u Filters.sdss_r unreddened_sed)
in
Array.iter
(fun av_f ->
let av = Nx.scalar f64 av_f in
let reddened =
Spectrum.blackbody ~temperature:temp ~wavelength:bp_wave
|> Spectrum.scale norm
|> Extinction.apply (Extinction.ccm89 ~rv) ~av
|> Spectrum.as_flux_density
in
let m = item [] (Photometry.ab_mag bp reddened) in
let color =
item [] (Photometry.color Filters.sdss_u Filters.sdss_r reddened)
in
Printf.printf "%6.1f %8.3f %+8.3f %+8.3f\n" av_f m (m -. m0)
(color -. color0))
av_values;
Printf.printf "\n";
(* Compare extinction laws *)
Printf.printf " Comparing extinction laws at A_V = 1.0:\n\n";
Printf.printf "%16s %8s %8s\n" "Law" "r-band" "E(u-r)";
Printf.printf "%16s %8s %8s\n" "----------------" "--------" "--------";
let av_one = Nx.scalar f64 1.0 in
let laws =
[|
("CCM89", Extinction.ccm89 ~rv);
("Fitzpatrick99", Extinction.fitzpatrick99 ~rv);
("O'Donnell94", Extinction.odonnell94 ~rv);
|]
in
Array.iter
(fun (name, law) ->
let reddened =
Spectrum.blackbody ~temperature:temp ~wavelength:bp_wave
|> Spectrum.scale norm
|> Extinction.apply law ~av:av_one
|> Spectrum.as_flux_density
in
let m = item [] (Photometry.ab_mag bp reddened) in
let color =
item [] (Photometry.color Filters.sdss_u Filters.sdss_r reddened)
in
Printf.printf "%16s %+8.3f %+8.3f\n" name (m -. m0) (color -. color0))
laws
================================================
FILE: dev/umbra/examples/05-sed-fitting/README.md
================================================
# `05-sed-fitting`
Full SED fitting pipeline: fits stellar temperature, dust extinction (A_V), and
flux normalization simultaneously from UGRIZ photometry. Demonstrates the
composable differentiable pipeline through Spectrum, Extinction, and Photometry.
```bash
dune exec dev/umbra/examples/05-sed-fitting/main.exe
```
## What You'll Learn
- Building a full astrophysical forward model from composable modules
- How the blackbody -> extinction -> photometry pipeline is end-to-end differentiable
- Creating custom bandpasses with `Photometry.tophat`
- Fitting multiple correlated parameters (T, A_V, normalization) simultaneously
## Key Functions
| Function | Purpose |
| ---------------------------- | --------------------------------------------- |
| `Spectrum.blackbody` | Planck spectral radiance at given wavelengths |
| `Spectrum.scale` | Scale spectrum values by a factor |
| `Spectrum.as_flux_density` | Cast to flux density kind for photometry |
| `Extinction.ccm89` | Create CCM89 extinction law with R_V |
| `Extinction.apply` | Apply dust reddening to a spectrum |
| `Photometry.tophat` | Create a rectangular bandpass |
| `Photometry.ab_mag` | Compute AB magnitude through a bandpass |
| `Rune.value_and_grads` | Autodiff through the entire pipeline |
## How It Works
The forward model constructs a synthetic SED at each optimization step:
1. **Spectrum.blackbody** generates the Planck function at temperature T
2. **Spectrum.scale** applies the flux normalization
3. **Extinction.apply** reddens the spectrum using CCM89 with extinction A_V
4. **Photometry.ab_mag** integrates through each bandpass to produce magnitudes
Since every step is built from Nx tensor operations, Rune computes gradients
of chi-squared with respect to all three parameters (log T, A_V, log norm) in
a single backward pass.
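The chi-squared objective being differentiated is the usual sum of squared, error-weighted residuals. A scalar sketch with plain floats (the tensor loss in `main.ml` has the same structure; this is illustration only, not library code):

```ocaml
(* Scalar sketch of the loss: sum of squared error-weighted residuals.
   The real loss in main.ml does the same thing with Nx tensors. *)
let chi2 ~pred ~obs ~errs =
  List.fold_left2
    (fun acc p (o, e) -> acc +. (((p -. o) /. e) ** 2.0))
    0.0 pred
    (List.combine obs errs)

let () =
  (* Predictions off by one sigma in each of three bands -> chi2 = 3 *)
  let v =
    chi2 ~pred:[ 1.05; 2.05; 3.05 ] ~obs:[ 1.0; 2.0; 3.0 ]
      ~errs:[ 0.05; 0.05; 0.05 ]
  in
  Printf.printf "chi2 = %.1f\n" v
```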
The temperature and normalization are parameterized in log-space for positivity
and better gradient conditioning. A_V is left in linear space since it can
meaningfully be zero or negative (de-reddening).
## Try It
1. Replace tophat filters with real SDSS filters from `Filters.sdss_u`, etc.
2. Add a redshift parameter to fit photometric redshifts.
3. Try `Extinction.fitzpatrick99` instead of `ccm89` and compare results.
4. Increase the photometric noise and observe how parameter uncertainties grow.
## Next Steps
Continue to [06-coordinates-and-time](../06-coordinates-and-time/) to work with
celestial coordinates, time scales, and observing conditions.
================================================
FILE: dev/umbra/examples/05-sed-fitting/dune
================================================
(executable
(name main)
(libraries nx rune vega umbra))
================================================
FILE: dev/umbra/examples/05-sed-fitting/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Differentiable SED fitting: temperature + extinction from photometry.
Demonstrates the composable differentiable pipeline: Spectrum.blackbody ->
Extinction.apply -> Photometry.ab_mag
All operations flow through Nx tensor ops, making the entire pipeline
differentiable via Rune's autodiff. We fit stellar temperature, dust
extinction, and flux normalization simultaneously by gradient descent on
photometric residuals. *)
open Nx
open Umbra
let f64 = Nx.float64
(* Define 5 broadband filters (UGRIZ-like tophats) *)
let n_bp = 100
let band_u =
Photometry.tophat ~lo:(Unit.Length.m 3.0e-7) ~hi:(Unit.Length.m 4.0e-7)
~n:n_bp
let band_g =
Photometry.tophat ~lo:(Unit.Length.m 4.0e-7) ~hi:(Unit.Length.m 5.5e-7)
~n:n_bp
let band_r =
Photometry.tophat ~lo:(Unit.Length.m 5.5e-7) ~hi:(Unit.Length.m 7.0e-7)
~n:n_bp
let band_i =
Photometry.tophat ~lo:(Unit.Length.m 7.0e-7) ~hi:(Unit.Length.m 8.5e-7)
~n:n_bp
let band_z =
Photometry.tophat ~lo:(Unit.Length.m 8.5e-7) ~hi:(Unit.Length.m 1.0e-6)
~n:n_bp
let bands = [ band_u; band_g; band_r; band_i; band_z ]
let band_names = [| "U"; "G"; "R"; "I"; "Z" |]
(* True parameters *)
let true_temp = 6500.0 (* K -- F-type star *)
let true_av = 0.5 (* moderate extinction *)
let true_log_norm = -50.0
(* Generate synthetic observations *)
let rv = Nx.scalar f64 3.1
let obs_mags =
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 true_temp) in
let av = Nx.scalar f64 true_av in
let norm = Nx.scalar f64 (Float.exp true_log_norm) in
let mags =
List.map
(fun bp ->
let bp_wave = Photometry.wavelength bp in
let sed =
Spectrum.blackbody ~temperature:temp ~wavelength:bp_wave
|> Spectrum.scale norm
|> Extinction.apply (Extinction.ccm89 ~rv) ~av
|> Spectrum.as_flux_density
in
Photometry.ab_mag bp sed)
bands
in
(* Add fixed magnitude offsets (0.03 mag is roughly 3% in flux) as mock noise *)
let noise = [| 0.03; -0.02; 0.01; -0.01; 0.02 |] in
List.mapi (fun i m -> Nx.add_s m noise.(i)) mags
let obs_errs = List.init 5 (fun _ -> Nx.scalar f64 0.05)
(* Forward model: generate magnitudes from parameters *)
let forward_model log_temp av log_norm =
let temp = Unit.Temperature.of_kelvin (exp log_temp) in
let norm = exp log_norm in
List.map
(fun bp ->
let bp_wave = Photometry.wavelength bp in
let sed =
Spectrum.blackbody ~temperature:temp ~wavelength:bp_wave
|> Spectrum.scale norm
|> Extinction.apply (Extinction.ccm89 ~rv) ~av
|> Spectrum.as_flux_density
in
Photometry.ab_mag bp sed)
bands
(* Loss function: chi-squared *)
let loss params =
match params with
| [ log_temp; av; log_norm ] ->
let pred = forward_model log_temp av log_norm in
List.fold_left2
(fun acc p (o, e) ->
let residual = div (sub p o) e in
add acc (square residual))
(scalar f64 0.0) pred
(List.combine obs_mags obs_errs)
| _ -> failwith "expected [log_temp; av; log_norm]"
let () =
Printf.printf "Differentiable SED Fitting\n";
Printf.printf "==========================\n";
Printf.printf
"Pipeline: Spectrum.blackbody -> Extinction.ccm89 -> Photometry.ab_mag\n\n";
Printf.printf "True parameters:\n";
Printf.printf " T = %.0f K\n" true_temp;
Printf.printf " A_V = %.2f mag\n" true_av;
Printf.printf " logN = %.1f\n\n" true_log_norm;
Printf.printf "Observed magnitudes (with noise):\n";
List.iteri
(fun i m ->
Printf.printf " %s = %.3f +/- %.3f\n" band_names.(i) (item [] m)
(item [] (List.nth obs_errs i)))
obs_mags;
Printf.printf "\n";
(* Initial guesses *)
let algo = Vega.adam (Vega.Schedule.constant 1e-3) in
let log_temp = ref (scalar f64 (Float.log 7000.0)) in
let av = ref (scalar f64 0.3) in
let log_norm = ref (scalar f64 (-50.5)) in
let states =
[| Vega.init algo !log_temp; Vega.init algo !av; Vega.init algo !log_norm |]
in
let steps = 1000 in
Printf.printf "%5s %10s %8s %8s %8s\n" "step" "chi2" "T (K)" "A_V"
"log_norm";
Printf.printf "%5s %10s %8s %8s %8s\n" "-----" "----------" "--------"
"--------" "--------";
let refs = [| log_temp; av; log_norm |] in
for i = 0 to steps - 1 do
let loss_val, grads =
Rune.value_and_grads loss [ !log_temp; !av; !log_norm ]
in
if i mod 200 = 0 || i = steps - 1 then
Printf.printf "%5d %10.4f %8.1f %8.3f %8.3f\n" i (item [] loss_val)
(Float.exp (item [] !log_temp))
(item [] !av) (item [] !log_norm);
List.iteri
(fun j g ->
let p, s = Vega.step states.(j) ~grad:g ~param:!(refs.(j)) in
refs.(j) := p;
states.(j) <- s)
grads
done;
Printf.printf "\nFitted parameters:\n";
Printf.printf " T = %.1f K (true: %.0f K)\n"
(Float.exp (item [] !log_temp))
true_temp;
Printf.printf " A_V = %.3f (true: %.2f)\n" (item [] !av) true_av;
Printf.printf " logN = %.3f (true: %.1f)\n" (item [] !log_norm)
true_log_norm;
(* Show fitted vs observed magnitudes *)
Printf.printf "\nFitted vs observed magnitudes:\n";
let fitted_mags = forward_model !log_temp !av !log_norm in
Printf.printf "%5s %8s %8s %8s\n" "Band" "Observed" "Fitted" "Residual";
Printf.printf "%5s %8s %8s %8s\n" "-----" "--------" "--------" "--------";
List.iteri
(fun i (obs, fit) ->
let o = item [] obs in
let f = item [] fit in
Printf.printf "%5s %8.3f %8.3f %+8.3f\n" band_names.(i) o f (f -. o))
(List.combine obs_mags fitted_mags)
================================================
FILE: dev/umbra/examples/06-coordinates-and-time/README.md
================================================
# `06-coordinates-and-time`
Celestial coordinates, astronomical time scales, and survey selection.
Demonstrates frame transforms (ICRS, Galactic), angular separation, time scale
conversions (UTC, TAI, TT, TDB), altitude-azimuth coordinates, airmass, and
a practical survey selection function.
```bash
dune exec dev/umbra/examples/06-coordinates-and-time/main.exe
```
## What You'll Learn
- Creating celestial coordinates in ICRS and converting to Galactic frame
- Computing angular separations between objects
- Parsing ISO 8601 dates and converting between time scales
- Computing horizontal coordinates for a ground-based observer
- Building a survey selection function from airmass, altitude, and magnitude cuts
## Key Functions
| Function | Purpose |
| ---------------------------- | --------------------------------------------- |
| `Coord.of_radec` | Create ICRS coordinates from RA/Dec |
| `Coord.galactic` | Convert to Galactic coordinates |
| `Coord.separation` | Angular separation between positions |
| `Time.of_iso` | Parse ISO 8601 date-time as UTC |
| `Time.utc_to_tai` | Convert UTC to TAI |
| `Time.tai_to_tt` | Convert TAI to Terrestrial Time |
| `Time.tt_to_tdb` | Convert TT to Barycentric Dynamical Time |
| `Time.to_jd`, `Time.to_mjd` | Extract Julian Date / Modified Julian Date |
| `Altaz.make_observer` | Create a ground-based observer location |
| `Altaz.of_coord` | Convert celestial to horizontal coordinates |
| `Altaz.alt`, `Altaz.az` | Extract altitude and azimuth |
| `Altaz.airmass` | Compute airmass at given altitude |
| `Filters.rubin_r` | Pre-built Rubin/LSST r-band filter |
## How It Works
**Coordinates**: Positions are stored as (longitude, latitude) pairs in typed
angle quantities. Frame transforms use 3x3 rotation matrices to convert
between ICRS, Galactic, Ecliptic, and Supergalactic systems. Angular separation
uses the Vincenty formula for numerical stability.
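The Vincenty special case for a sphere can be sketched with plain floats. This is a hypothetical scalar analogue of `Coord.separation`, not the library's tensor implementation:

```ocaml
(* Angular separation via the Vincenty formula (sphere special case),
   numerically stable at both tiny and near-antipodal separations.
   Inputs and output in degrees; scalar sketch only. *)
let separation_deg ~lon1 ~lat1 ~lon2 ~lat2 =
  let d = Float.pi /. 180.0 in
  let p1 = lat1 *. d and p2 = lat2 *. d and dl = (lon2 -. lon1) *. d in
  let num =
    Float.hypot
      (cos p2 *. sin dl)
      ((cos p1 *. sin p2) -. (sin p1 *. cos p2 *. cos dl))
  in
  let den = (sin p1 *. sin p2) +. (cos p1 *. cos p2 *. cos dl) in
  Float.atan2 num den /. d

(* Vega to Altair: about 34.2 degrees, as in the Summer Triangle output *)
let () =
  Printf.printf "%.2f\n"
    (separation_deg ~lon1:279.235 ~lat1:38.784 ~lon2:297.696 ~lat2:8.868)
```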
**Time**: Julian Dates carry phantom type tags (UTC, TAI, TT, TDB) that
enforce correct scale conversions at compile time. UTC-TAI uses the IERS
leap-second table; TT = TAI + 32.184s exactly; TDB-TT uses the Fairhead &
Bretagnon series.
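The fixed offsets are easy to check by hand. A minimal sketch, assuming the post-2017 leap-second count of 37 s (the table lookup `Time.utc_to_tai` performs is reduced to a constant here, and TDB's millisecond-level periodic terms are ignored):

```ocaml
(* Fixed time-scale offsets, valid for dates after 2017-01-01:
   TAI = UTC + 37 s (IERS leap-second table; a lookup in general),
   TT  = TAI + 32.184 s (exact by definition). *)
let tai_of_utc_jd jd = jd +. (37.0 /. 86400.0)
let tt_of_tai_jd jd = jd +. (32.184 /. 86400.0)

let () =
  let jd_utc = 2460482.666667 (* 2024-06-21 04:00 UTC *) in
  let jd_tt = tt_of_tai_jd (tai_of_utc_jd jd_utc) in
  (* TT - UTC = 37 + 32.184 = 69.184 s *)
  Printf.printf "TT - UTC = %.3f s\n" ((jd_tt -. jd_utc) *. 86400.0)
```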
**Altaz**: Converts ICRS to horizontal coordinates using IAU 2006 precession
and the Earth Rotation Angle. Airmass uses the Pickering (2002) formula.
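The Pickering (2002) relation is a closed-form function of apparent altitude, so it can be sketched directly. This is an assumed scalar analogue of `Altaz.airmass`, not the library source:

```ocaml
(* Pickering (2002) airmass: X = 1 / sin(h + 244 / (165 + 47 h^1.1)),
   with apparent altitude h in degrees (the sine argument is in degrees
   as well). Unlike sec(z), it stays finite down to the horizon. *)
let airmass h_deg =
  let arg = h_deg +. (244.0 /. (165.0 +. (47.0 *. (h_deg ** 1.1)))) in
  1.0 /. sin (arg *. Float.pi /. 180.0)

let () =
  (* Near the zenith X approaches 1; at 30 deg it is close to sec z = 2 *)
  Printf.printf "X(90) = %.3f  X(30) = %.3f\n" (airmass 90.0) (airmass 30.0)
```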
**Selection**: Combines altitude (above horizon), airmass (atmospheric
extinction), and magnitude limit into a boolean selection function -- a
building block for survey simulations.
## Try It
1. Add atmospheric refraction with `Altaz.of_coord ~refraction:true`.
2. Compute the position angle from Vega to Deneb with `Coord.position_angle`.
3. Use `Coord.of_galactic` to create coordinates in the Galactic plane and
convert to ICRS.
4. Change the observer location and time to see how visibility changes.
## Next Steps
Explore the other Umbra examples for more advanced topics: catalog
cross-matching with `Coord.nearest`, cosmological power spectra with
`Cosmo.linear_power`, and Fisher matrix forecasts.
================================================
FILE: dev/umbra/examples/06-coordinates-and-time/dune
================================================
(executable
(name main)
(libraries nx rune umbra))
================================================
FILE: dev/umbra/examples/06-coordinates-and-time/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Coordinates, time scales, and survey selection.

   Demonstrates Umbra's coordinate, time, and observing modules:
   - Coord: celestial coordinates with frame transforms (ICRS, Galactic,
     Ecliptic, Supergalactic) and angular separation.
   - Time: astronomical time with type-safe scale conversions (UTC, TAI, TT,
     TDB) and ISO 8601 parsing.
   - Altaz: horizontal coordinates, airmass, and atmospheric refraction.

   Combines these into a survey selection function that determines which
   targets are observable given an observer, time, and observing constraints. *)
open Nx
open Umbra
let f64 = Nx.float64
let () =
Printf.printf "Coordinates, time scales, and survey selection\n";
Printf.printf "===============================================\n\n";
(* --- Part 1: Coordinate frames --- *)
Printf.printf "Part 1: Coordinate frame transforms\n";
Printf.printf "------------------------------------\n\n";
let targets =
[|
("Galactic center", 266.417, -28.936);
("Vega", 279.235, 38.784);
("North Galactic Pole", 192.860, 27.128);
("LMC", 80.894, -69.756);
("M31 (Andromeda)", 10.685, 41.269);
|]
in
Printf.printf "%20s %8s %8s %8s %8s\n" "Object" "RA" "Dec" "l" "b";
Printf.printf "%20s %8s %8s %8s %8s\n" "--------------------" "--------"
"--------" "--------" "--------";
Array.iter
(fun (name, ra_deg, dec_deg) ->
let coord =
Coord.of_radec
~ra:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| ra_deg |]))
~dec:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| dec_deg |]))
in
let gal = Coord.galactic coord in
let l = item [ 0 ] (Unit.Angle.in_deg (Coord.lon gal)) in
let b = item [ 0 ] (Unit.Angle.in_deg (Coord.lat gal)) in
Printf.printf "%20s %8.2f %+8.2f %8.2f %+8.2f\n" name ra_deg dec_deg l
b)
targets;
Printf.printf "\n";
(* Angular separation *)
Printf.printf "Angular separations:\n";
let vega =
Coord.of_radec
~ra:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 279.235 |]))
~dec:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 38.784 |]))
in
let altair =
Coord.of_radec
~ra:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 297.696 |]))
~dec:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 8.868 |]))
in
let deneb =
Coord.of_radec
~ra:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 310.358 |]))
~dec:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 45.280 |]))
in
let sep_va = item [ 0 ] (Unit.Angle.in_deg (Coord.separation vega altair)) in
let sep_vd = item [ 0 ] (Unit.Angle.in_deg (Coord.separation vega deneb)) in
let sep_ad = item [ 0 ] (Unit.Angle.in_deg (Coord.separation altair deneb)) in
Printf.printf " Vega - Altair: %.2f deg\n" sep_va;
Printf.printf " Vega - Deneb: %.2f deg\n" sep_vd;
Printf.printf " Altair - Deneb: %.2f deg\n" sep_ad;
Printf.printf " (Summer Triangle)\n\n";
(* --- Part 2: Time scales --- *)
Printf.printf "Part 2: Astronomical time scales\n";
Printf.printf "--------------------------------\n\n";
let t_utc = Time.of_iso "2024-06-21T04:00:00" in
let t_tai = Time.utc_to_tai t_utc in
let t_tt = Time.tai_to_tt t_tai in
let t_tdb = Time.tt_to_tdb t_tt in
Printf.printf " UTC: %s\n" (Time.to_iso t_utc);
Printf.printf " JD (UTC): %.6f\n" (Time.to_jd t_utc);
Printf.printf " MJD (UTC): %.6f\n" (Time.to_mjd t_utc);
Printf.printf " JD (TAI): %.6f\n" (Time.to_jd t_tai);
Printf.printf " JD (TT): %.6f\n" (Time.to_jd t_tt);
Printf.printf " JD (TDB): %.6f\n" (Time.to_jd t_tdb);
let dt_tai_utc =
Unit.to_float (Time.diff t_tai (Time.unsafe_of_jd (Time.to_jd t_utc)))
in
Printf.printf "\n TAI - UTC = %.1f s (leap seconds)\n" (dt_tai_utc *. 86400.0);
let t_j2000 = Time.of_iso "2000-01-01T12:00:00" in
let dt_j2000 = Unit.to_float (Time.diff t_utc t_j2000) in
Printf.printf " Days since J2000.0: %.2f\n\n" dt_j2000;
(* --- Part 3: Horizontal coordinates and airmass --- *)
Printf.printf "Part 3: Altitude-azimuth and airmass\n";
Printf.printf "------------------------------------\n\n";
(* Observer at Cerro Pachon (Rubin site) *)
let obs =
Altaz.make_observer
~lat:(Unit.Angle.deg (-30.2444))
~lon:(Unit.Angle.deg (-70.7494))
~height:(Unit.Length.m 2663.0) ()
in
let obstime = Time.of_iso "2024-06-21T04:00:00" in
Printf.printf " Observer: Cerro Pachon (Rubin Observatory)\n";
Printf.printf " Lat: %.4f deg\n" (-30.2444);
Printf.printf " Lon: %.4f deg\n" (-70.7494);
Printf.printf " Elevation: %.0f m\n" 2663.0;
Printf.printf " Time: 2024-06-21 04:00 UTC\n\n";
let stars =
[|
("Vega", 279.235, 38.784);
("Sirius", 101.287, -16.716);
("Canopus", 95.988, -52.696);
("Alpha Cen", 219.902, -60.834);
("Fomalhaut", 344.413, -29.622);
|]
in
Printf.printf "%12s %7s %7s %8s\n" "Star" "Alt" "Az" "Airmass";
Printf.printf "%12s %7s %7s %8s\n" "------------" "-------" "-------"
"--------";
Array.iter
(fun (name, ra_deg, dec_deg) ->
let coord =
Coord.of_radec
~ra:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| ra_deg |]))
~dec:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| dec_deg |]))
in
let hz = Altaz.of_coord ~obstime ~observer:obs coord in
let alt_deg =
item [ 0 ] (Unit.Angle.to_tensor (Altaz.alt hz)) *. 180.0 /. Float.pi
in
let az_deg =
item [ 0 ] (Unit.Angle.to_tensor (Altaz.az hz)) *. 180.0 /. Float.pi
in
let am = item [ 0 ] (Altaz.airmass hz) in
Printf.printf "%12s %+7.1f %7.1f %8.2f\n" name alt_deg az_deg am)
stars;
Printf.printf "\n";
(* --- Part 4: Survey selection --- *)
Printf.printf "Part 4: Survey selection function\n";
Printf.printf "---------------------------------\n\n";
let mag_limit = 20.0 in
let airmass_cut = 2.0 in
Printf.printf " Selection criteria:\n";
Printf.printf " Magnitude limit: r < %.1f (AB)\n" mag_limit;
Printf.printf " Airmass cut: X < %.1f\n" airmass_cut;
Printf.printf " Above horizon: alt > 0 deg\n\n";
let bp = Filters.rubin_r in
let norm = Nx.scalar f64 (Float.exp (-49.0)) in
let star_data =
[|
("Vega", 279.235, 38.784, 5800.0);
("Sirius", 101.287, -16.716, 9940.0);
("Canopus", 95.988, -52.696, 7350.0);
("Alpha Cen", 219.902, -60.834, 5790.0);
("Fomalhaut", 344.413, -29.622, 8590.0);
|]
in
Printf.printf "%12s %7s %8s %6s %s\n" "Star" "Alt" "Airmass" "r_mag"
"Select?";
Printf.printf "%12s %7s %8s %6s %s\n" "------------" "-------" "--------"
"------" "-------";
Array.iter
(fun (name, ra_deg, dec_deg, temp_k) ->
let coord =
Coord.of_radec
~ra:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| ra_deg |]))
~dec:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| dec_deg |]))
in
let hz = Altaz.of_coord ~obstime ~observer:obs coord in
let alt_deg =
item [ 0 ] (Unit.Angle.to_tensor (Altaz.alt hz)) *. 180.0 /. Float.pi
in
let am = item [ 0 ] (Altaz.airmass hz) in
(* Synthetic magnitude through Rubin r-band *)
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 temp_k) in
let bp_wave = Photometry.wavelength bp in
let sed =
Spectrum.blackbody ~temperature:temp ~wavelength:bp_wave
|> Spectrum.scale norm |> Spectrum.as_flux_density
in
let r_mag = item [] (Photometry.ab_mag bp sed) in
let selected = alt_deg > 0.0 && am < airmass_cut && r_mag < mag_limit in
Printf.printf "%12s %+7.1f %8.2f %6.2f %s\n" name alt_deg am r_mag
(if selected then "YES" else "no"))
star_data;
Printf.printf "\n Observer height: %.0f m\n"
(item [] (Unit.Length.to_tensor (Altaz.observer_height obs)))
================================================
FILE: dev/umbra/examples/07-batch-photometry/README.md
================================================
# `07-batch-photometry`
Computes SDSS g-r colors for a grid of blackbody templates at different
temperatures and dust extinctions in a single pass using batch operations.
Instead of looping over individual spectra, the values tensor has a leading
batch dimension and all photometry operations broadcast over it.
```bash
cd dev/umbra
dune exec --root . examples/07-batch-photometry/main.exe
```
## What You'll Learn
- Constructing batched spectra by stacking blackbodies into a leading dimension
- Broadcasting extinction across a batch of SEDs with per-spectrum A_V
- Computing synthetic SDSS photometry with AB magnitudes
- Exploring color-temperature and color-extinction relations
## Key Functions
| Function | Purpose |
| -------------------------- | ------------------------------------------------ |
| `Spectrum.blackbody` | Generate Planck spectrum at a given temperature |
| `Spectrum.create` | Build a spectrum from wavelength and value arrays |
| `Spectrum.as_flux_density` | Cast to flux density kind for photometry |
| `Nx.stack` | Stack individual spectra into a batch dimension |
| `Extinction.ccm89` | Create CCM89 dust extinction law |
| `Extinction.apply` | Apply reddening with per-spectrum A_V broadcast |
| `Photometry.ab_mag` | Compute AB magnitude through a bandpass |
| `Filters.sdss_g` | SDSS g-band filter response |
## How It Works
The example first builds a grid of 20 blackbody spectra from 3000 K to 30000 K
by stacking individual `Spectrum.blackbody` outputs into a `[n_temp; 500]`
values tensor. When this batch spectrum is passed to `Photometry.ab_mag`, the
integration broadcasts over the leading dimension, producing one magnitude per
temperature in a single call.
The second half demonstrates per-spectrum extinction. A T=6000 K blackbody is
replicated into 10 copies, and `Extinction.apply` is called with an A_V tensor
of shape `[n_av; 1]` that broadcasts against the `[n_av; 500]` flux values.
This yields reddened g-r colors across a range of dust columns without any
explicit loop.
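What the `[n_av; 1]` against `[n_av; 500]` broadcast computes can be spelled out with plain arrays. This is a hand-rolled sketch of the broadcasting rule, not Nx code:

```ocaml
(* Broadcasting a column of per-spectrum factors over a batch of flux rows:
   shape [n; 1] against shape [n; m] multiplies row i of the batch by
   factor i -- the same elementwise pattern the per-spectrum A_V broadcast
   relies on. *)
let broadcast_mul (col : float array) (batch : float array array) =
  Array.mapi (fun i row -> Array.map (fun v -> col.(i) *. v) row) batch

let () =
  let scaled =
    broadcast_mul [| 1.0; 0.5 |] [| [| 2.0; 4.0 |]; [| 2.0; 4.0 |] |]
  in
  Array.iter
    (fun row ->
      Array.iter (Printf.printf "%.1f ") row;
      print_newline ())
    scaled
```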
## Try It
1. Increase the temperature grid to 100 points and plot the g-r color curve to
see where the blue turnover occurs.
2. Add a third band (sdss_i) and compute the g-r vs r-i color-color diagram.
3. Replace the blackbody with a power-law spectrum and observe how the color
trends differ.
## Next Steps
Continue to [08-photometric-redshifts](../08-photometric-redshifts/) to learn
how to estimate galaxy redshifts by combining grid search with gradient-based
refinement through the differentiable photometry pipeline.
================================================
FILE: dev/umbra/examples/07-batch-photometry/dune
================================================
(executable
(name main)
(libraries nx rune umbra))
================================================
FILE: dev/umbra/examples/07-batch-photometry/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Batch template photometry.
Computes SDSS g-r colors for a grid of blackbody templates at different
temperatures and dust extinctions in a single pass, demonstrating batched
spectra. Instead of looping over individual spectra, the values tensor has a
leading batch dimension and all photometry operations broadcast over it. *)
open Nx
open Umbra
let f64 = Nx.float64
let () =
Printf.printf "Batch Template Photometry\n";
Printf.printf "=========================\n\n";
(* Temperature grid: 20 blackbodies from 3000K to 30000K *)
let n_temp = 20 in
let temps =
Array.init n_temp (fun i ->
3000.0
+. (Float.of_int i *. (30000.0 -. 3000.0) /. Float.of_int (n_temp - 1)))
in
(* Shared wavelength grid covering SDSS g and r *)
let wavelength = Unit.Length.of_m (Nx.linspace f64 3e-7 1.1e-6 500) in
(* Build batch spectrum: stack individual blackbodies into [n_temp; 500] *)
let values =
Nx.stack
(List.init n_temp (fun i ->
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 temps.(i)) in
Spectrum.values (Spectrum.blackbody ~temperature:temp ~wavelength)))
in
let batch = Spectrum.create ~wavelength ~values |> Spectrum.as_flux_density in
(* AB magnitudes in g and r — returns shape [n_temp] each *)
let g_mag = Photometry.ab_mag Filters.sdss_g batch in
let r_mag = Photometry.ab_mag Filters.sdss_r batch in
let g_r = Nx.sub g_mag r_mag in
Printf.printf "Unreddened blackbody colors (SDSS g-r):\n";
Printf.printf "%8s %8s %8s %8s\n" "T (K)" "g" "r" "g-r";
Printf.printf "%8s %8s %8s %8s\n" "--------" "--------" "--------"
"--------";
Array.iteri
(fun i t ->
if i mod 4 = 0 || i = n_temp - 1 then
Printf.printf "%8.0f %+8.3f %+8.3f %+8.3f\n" t (item [ i ] g_mag)
(item [ i ] r_mag) (item [ i ] g_r))
temps;
(* Now apply per-spectrum extinction: A_V from 0.0 to 2.0 *)
Printf.printf "\nReddening a T=6000K blackbody (SDSS g-r vs A_V):\n";
let n_av = 10 in
let av_values = Nx.linspace f64 0.0 2.0 n_av in
(* Single-temperature spectrum, batched over A_V *)
let temp_6k = Unit.Temperature.of_kelvin (Nx.scalar f64 6000.0) in
let sed_1d =
Spectrum.blackbody ~temperature:temp_6k ~wavelength
|> Spectrum.as_flux_density
in
(* Replicate into [n_av; 500] *)
let sed_values =
Nx.stack (List.init n_av (fun _ -> Spectrum.values sed_1d))
in
let sed_batch =
Spectrum.create ~wavelength ~values:sed_values |> Spectrum.as_flux_density
in
(* Per-spectrum A_V: reshape to [n_av; 1] to broadcast with [n_av; 500] *)
let rv = Nx.scalar f64 3.1 in
let av_col = Nx.reshape [| n_av; 1 |] av_values in
let reddened = Extinction.apply (Extinction.ccm89 ~rv) ~av:av_col sed_batch in
let g_red = Photometry.ab_mag Filters.sdss_g reddened in
let r_red = Photometry.ab_mag Filters.sdss_r reddened in
let g_r_red = Nx.sub g_red r_red in
Printf.printf "%8s %8s\n" "A_V" "g-r";
Printf.printf "%8s %8s\n" "--------" "--------";
for i = 0 to n_av - 1 do
Printf.printf "%8.2f %+8.3f\n" (item [ i ] av_values) (item [ i ] g_r_red)
done
================================================
FILE: dev/umbra/examples/08-photometric-redshifts/README.md
================================================
# `08-photometric-redshifts`
Two-stage photometric redshift estimation: coarse grid search followed by
gradient-based refinement using Adam. The full pipeline (blackbody -> extinction
-> redshift -> photometry) is differentiable through Rune, enabling gradient
descent on redshift and normalization parameters against synthetic SDSS ugriz
observations.
```bash
cd dev/umbra
dune exec --root . examples/08-photometric-redshifts/main.exe
```
## What You'll Learn
- Building an end-to-end differentiable photometric pipeline through SDSS ugriz filters
- Composing spectrum redshifting, dust extinction, and synthetic photometry
- Combining grid search initialization with autodiff gradient refinement
- Using multi-parameter gradients to jointly fit redshift and normalization
## Key Functions
| Function | Purpose |
| -------------------------- | ---------------------------------------------------- |
| `Spectrum.blackbody` | Generate a template SED at given temperature |
| `Spectrum.redshift` | Apply cosmological redshift to a spectrum |
| `Spectrum.scale` | Scale spectrum by a normalization factor |
| `Extinction.apply` | Apply dust reddening with CCM89 law |
| `Photometry.ab_mag` | Compute AB magnitude through a bandpass |
| `Photometry.wavelength` | Extract the wavelength grid of a bandpass filter |
| `Rune.value_and_grads` | Compute loss and parameter gradients in one pass |
| `Vega.adam` | Adam optimizer for gradient refinement |
## How It Works
The example generates synthetic observed magnitudes for a galaxy at z=0.3 with
T=5500 K and A_V=0.2 by pushing a blackbody through the full pipeline:
`blackbody -> scale -> extinction -> redshift -> ab_mag` in each of the five
SDSS bands. These serve as the "data" to fit against.
Stage 1 performs a coarse grid search over 30 redshift values from 0.01 to
0.88 in steps of 0.03, computing chi-squared at each point with a fixed
template. This identifies a rough minimum without requiring gradients.
Stage 2 takes the best grid redshift and refines it with 500 Adam optimizer
steps. The loss function (sum of squared magnitude residuals) flows through
`Spectrum.redshift` and `Photometry.ab_mag`, so Rune provides exact gradients
with respect to log(1+z) and log(normalization). The parameterization in
log-space ensures positivity and improves conditioning.
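The two-stage strategy is independent of the astrophysics. Here it is on a toy one-dimensional objective with a hand-written gradient (pure illustration; no Umbra or Rune calls):

```ocaml
(* Toy two-stage fit: coarse grid scan, then gradient refinement.
   f(z) = (z - 0.3)^2 stands in for chi-squared; its gradient is written
   by hand where the real example uses Rune.value_and_grads. *)
let f z = (z -. 0.3) ** 2.0
let f' z = 2.0 *. (z -. 0.3)

(* Stage 1: same grid as the example, z = 0.01, 0.04, ..., 0.88 *)
let grid_best =
  let best = ref (nan, infinity) in
  for i = 0 to 29 do
    let z = 0.01 +. (0.03 *. float_of_int i) in
    if f z < snd !best then best := (z, f z)
  done;
  fst !best

(* Stage 2: plain gradient steps (the example uses Adam via Vega) *)
let refined =
  let z = ref grid_best in
  for _ = 1 to 200 do
    z := !z -. (0.1 *. f' !z)
  done;
  !z

let () = Printf.printf "grid: %.2f  refined: %.4f\n" grid_best refined
```

Stage 1 lands on the nearest grid point (0.31 here); stage 2 then converges to 0.3000, mirroring how the grid seeds the Adam refinement in `main.ml`.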
## Try It
1. Change the true redshift to z=0.7 and observe how the grid search coarseness
affects the initial estimate.
2. Add temperature as a third free parameter in the refinement stage.
3. Replace the single blackbody template with a composite SED that includes an
emission line.
## Next Steps
Continue to [09-gravitational-lensing](../09-gravitational-lensing/) to see how
Rune's autodiff can fit physical parameters of a gravitational lens model from
observed image positions.
================================================
FILE: dev/umbra/examples/08-photometric-redshifts/dune
================================================
(executable
(name main)
(libraries nx rune vega umbra))
================================================
FILE: dev/umbra/examples/08-photometric-redshifts/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Photometric redshift estimation via template fitting.
Demonstrates composing Spectrum.redshift -> Extinction.apply ->
Photometry.ab_mag through real SDSS filters, with gradient refinement via
Rune's autodiff. Auto-resampling makes the pipeline seamless.
   Stage 1: grid search over redshift for a coarse estimate.
   Stage 2: Adam optimizer refines z and normalization using AD gradients. *)
open Nx
open Umbra
let f64 = Nx.float64
let bands =
[
Filters.sdss_u;
Filters.sdss_g;
Filters.sdss_r;
Filters.sdss_i;
Filters.sdss_z;
]
let band_names = [| "u"; "g"; "r"; "i"; "z" |]
(* True parameters for synthetic galaxy *)
let true_z = 0.3
let true_temp = 5500.0
let true_av = 0.2
let true_log_norm = -50.0
let rv = Nx.scalar f64 3.1
(* Synthetic observed magnitudes *)
let obs_mags =
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 true_temp) in
let z = Nx.scalar f64 true_z in
let av = Nx.scalar f64 true_av in
let norm = Nx.scalar f64 (Float.exp true_log_norm) in
List.map
(fun bp ->
let bp_wave = Photometry.wavelength bp in
let sed =
Spectrum.blackbody ~temperature:temp ~wavelength:bp_wave
|> Spectrum.scale norm
|> Extinction.apply (Extinction.ccm89 ~rv) ~av
|> Spectrum.as_flux_density |> Spectrum.redshift ~z
in
Photometry.ab_mag bp sed)
bands
(* Grid search: coarse scan over z *)
let grid_search () =
let best_z = ref 0.0 in
let best_chi2 = ref Float.infinity in
let n_z = 30 in
for iz = 0 to n_z - 1 do
let z = Nx.scalar f64 (0.01 +. (Float.of_int iz *. 0.03)) in
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 5000.0) in
let norm = Nx.scalar f64 (Float.exp (-50.0)) in
let pred =
List.map
(fun bp ->
let bp_wave = Photometry.wavelength bp in
let sed =
Spectrum.blackbody ~temperature:temp ~wavelength:bp_wave
|> Spectrum.scale norm |> Spectrum.as_flux_density
|> Spectrum.redshift ~z
in
Photometry.ab_mag bp sed)
bands
in
(* Chi-squared over magnitudes relative to the observed set *)
let chi2 =
List.fold_left2
(fun acc p o -> add acc (square (sub p o)))
(scalar f64 0.0) pred obs_mags
in
let chi2_v = item [] chi2 in
if chi2_v < !best_chi2 then begin
best_chi2 := chi2_v;
best_z := item [] z
end
done;
!best_z
(* Gradient refinement around grid minimum *)
let refine z0 =
let loss params =
match params with
| [ log_z1; log_norm ] ->
let z = sub (exp log_z1) (scalar f64 1.0) in
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 5500.0) in
let norm = exp log_norm in
let pred =
List.map
(fun bp ->
let bp_wave = Photometry.wavelength bp in
let sed =
Spectrum.blackbody ~temperature:temp ~wavelength:bp_wave
|> Spectrum.scale norm |> Spectrum.as_flux_density
|> Spectrum.redshift ~z
in
Photometry.ab_mag bp sed)
bands
in
List.fold_left2
(fun acc p o -> add acc (square (sub p o)))
(scalar f64 0.0) pred obs_mags
| _ -> failwith "expected [log_z1; log_norm]"
in
let algo = Vega.adam (Vega.Schedule.constant 5e-4) in
let log_z1 = ref (scalar f64 (Float.log (1.0 +. z0))) in
let log_norm = ref (scalar f64 (-50.0)) in
let states = [| Vega.init algo !log_z1; Vega.init algo !log_norm |] in
let refs = [| log_z1; log_norm |] in
for _ = 0 to 499 do
let _loss_val, grads = Rune.value_and_grads loss [ !log_z1; !log_norm ] in
List.iteri
(fun j g ->
let p, s = Vega.step states.(j) ~grad:g ~param:!(refs.(j)) in
refs.(j) := p;
states.(j) <- s)
grads
done;
Float.exp (item [] !log_z1) -. 1.0
let () =
Printf.printf "Photometric Redshift Estimation\n";
Printf.printf "===============================\n";
Printf.printf
"Pipeline: blackbody -> extinction -> redshift -> ab_mag (SDSS)\n\n";
Printf.printf "True: z=%.3f T=%.0fK A_V=%.2f\n\n" true_z true_temp true_av;
Printf.printf "Observed magnitudes:\n";
List.iteri
(fun i m -> Printf.printf " %s = %.3f\n" band_names.(i) (item [] m))
obs_mags;
Printf.printf "\nStep 1: Grid search (z = 0.01 to 0.88, step 0.03)...\n";
let z_grid = grid_search () in
Printf.printf " Best grid z = %.3f\n" z_grid;
Printf.printf "\nStep 2: Gradient refinement (500 Adam steps)...\n";
let z_fit = refine z_grid in
Printf.printf " Refined z = %.4f (true: %.3f)\n" z_fit true_z;
Printf.printf " Error = %.4f\n" (Float.abs (z_fit -. true_z))
================================================
FILE: dev/umbra/examples/09-gravitational-lensing/README.md
================================================
# `09-gravitational-lensing`
Fits gravitational lens parameters (lens center and Einstein radius) from
observed image positions of a quadruply-imaged quasar. The point-mass lens
equation is expressed as Nx tensor operations, making the model fully
differentiable through Rune for gradient-based fitting with Adam.
```bash
cd dev/umbra
dune exec --root . examples/09-gravitational-lensing/main.exe
```
## What You'll Learn
- Expressing the gravitational lens equation as differentiable tensor operations
- Minimizing source-plane variance to fit lens parameters
- Fitting physical parameters (lens position, Einstein radius) via Adam optimizer
- Using autodiff gradients with a physics-based loss function
## Key Functions
| Function | Purpose |
| ----------------------- | ------------------------------------------------------ |
| `Nx.square` | Squared distances for radial computation |
| `Nx.sqrt` | Radial distance from lens center |
| `Nx.mean` | Mean source position across images |
| `Rune.value_and_grads` | Compute loss and gradients for all lens parameters |
| `Vega.adam` | Adam optimizer for parameter fitting |
| `Vega.step` | Apply one optimization update |
## How It Works
A point-mass gravitational lens deflects light according to the lens equation:
beta = theta - theta_E^2 * theta_hat / |theta|, where beta is the true source
position, theta is the observed image position, and theta_E is the Einstein
radius. If the lens model is correct, all observed images should map back to the
same source position in the source plane.
The example generates synthetic image positions for a quadruply-imaged quasar
with known lens parameters (x_L=0.1, y_L=-0.05, theta_E=1.0) plus small noise.
The loss function maps each image back to the source plane using the current
lens parameters and computes the variance of the inferred source positions. A
correct lens model yields zero variance.
Starting from an initial guess of (x_L=0, y_L=0, theta_E=0.5), the Adam
optimizer runs for 500 steps. Rune differentiates through the entire lens
equation -- including the division by |theta|, the square root, and the
mean/variance -- to provide exact gradients that drive convergence to the true
parameters.
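The back-mapping and variance loss can be sketched with plain floats. This is a standalone illustration of the formula above, not the Nx/Rune code in `main.ml`; the function name is ours:

```ocaml
(* Map each image (tx, ty) back to the source plane with
   beta = theta - (theta_E^2 / |theta|) * theta_hat, where theta is the
   offset from the lens center, then return the variance of the
   inferred source positions (zero for a perfect lens model). *)
let source_plane_loss ~x_l ~y_l ~theta_e images =
  let betas =
    List.map
      (fun (tx, ty) ->
        let dx = tx -. x_l and dy = ty -. y_l in
        let r = sqrt ((dx *. dx) +. (dy *. dy)) in
        let alpha = theta_e *. theta_e /. r in
        (tx -. (alpha *. dx /. r), ty -. (alpha *. dy /. r)))
      images
  in
  let n = float_of_int (List.length betas) in
  let mean f = List.fold_left (fun acc b -> acc +. f b) 0.0 betas /. n in
  let mx = mean fst and my = mean snd in
  mean (fun (bx, by) -> ((bx -. mx) ** 2.) +. ((by -. my) ** 2.))
```

With noiseless images of a single point source, the loss vanishes at the true lens parameters and grows as theta_E is perturbed, which is exactly the signal the Adam optimizer follows.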
## Try It
1. Increase the noise level from 0.005 to 0.05 and observe how far the fitted
parameters drift from the true values.
2. Add a shear term (gamma_1, gamma_2) to the lens model for external
tidal perturbation.
3. Replace the point-mass with a singular isothermal sphere (SIS) profile
where the deflection is constant: alpha = theta_E * theta_hat.
## Next Steps
Continue to [10-uncertainty-propagation](../10-uncertainty-propagation/) to
learn how to automatically propagate parameter uncertainties through
cosmological distance calculations using exact AD Jacobians.
================================================
FILE: dev/umbra/examples/09-gravitational-lensing/dune
================================================
(executable
(name main)
(libraries nx rune vega umbra))
================================================
FILE: dev/umbra/examples/09-gravitational-lensing/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Differentiable gravitational lens parameter fitting.
A point-mass gravitational lens produces multiple images of a background
source. Given the observed image positions, we fit the lens center and
Einstein radius by requiring that all images map back to the same source
position. The lens equation and source-plane mapping are expressed as Nx
tensor operations, making the entire model differentiable through Rune. *)
open Nx
let f64 = Nx.float64
(* True lens parameters (for generating synthetic data) *)
let true_x_l = 0.1
let true_y_l = -0.05
let true_theta_e = 1.0
(* Generate synthetic image positions for a quadruply-imaged quasar. Source at
(0.15, 0.08), lens at (true_x_l, true_y_l). *)
let source_x = 0.15
let source_y = 0.08
let () = Printf.printf "Differentiable gravitational lens modeling\n"
let () = Printf.printf "==========================================\n\n"
(* Solve lens equation: beta = theta - theta_E^2 * theta / |theta|^2 for point
mass (where theta is relative to lens center). We generate 4 image positions
by solving analytically + adding noise. *)
let img_x, img_y =
(* For a point mass, images lie along the source-lens axis. Place 4 images at
realistic positions around the lens. *)
let dx = source_x -. true_x_l in
let dy = source_y -. true_y_l in
let beta = Float.sqrt ((dx *. dx) +. (dy *. dy)) in
let cos_a = dx /. beta and sin_a = dy /. beta in
(* Two images along the axis *)
let theta_p =
(beta
+. Float.sqrt ((beta *. beta) +. (4.0 *. true_theta_e *. true_theta_e)))
/. 2.0
in
let theta_m =
(beta
-. Float.sqrt ((beta *. beta) +. (4.0 *. true_theta_e *. true_theta_e)))
/. 2.0
in
(* Image positions in 2D (along and perpendicular to axis, with noise) *)
let noise = 0.005 in
let x1 = true_x_l +. (theta_p *. cos_a) +. (noise *. 0.3) in
let y1 = true_y_l +. (theta_p *. sin_a) -. (noise *. 0.2) in
let x2 = true_x_l +. (theta_m *. cos_a) -. (noise *. 0.5) in
let y2 = true_y_l +. (theta_m *. sin_a) +. (noise *. 0.4) in
(* Add two more images from slight perturbation (simulating extended
source) *)
let x3 = true_x_l +. (theta_p *. 0.7 *. cos_a) +. (theta_p *. 0.3 *. sin_a) in
let y3 = true_y_l +. (theta_p *. 0.7 *. sin_a) -. (theta_p *. 0.3 *. cos_a) in
let x4 = true_x_l -. (theta_p *. 0.5 *. cos_a) +. (theta_p *. 0.4 *. sin_a) in
let y4 = true_y_l -. (theta_p *. 0.5 *. sin_a) -. (theta_p *. 0.4 *. cos_a) in
( create f64 [| 4 |] [| x1; x2; x3; x4 |],
create f64 [| 4 |] [| y1; y2; y3; y4 |] )
(* Loss: given lens params, map each image back to the source plane. All images
should map to the same source -> minimize variance of inferred source
positions. *)
let loss params =
match params with
| [ x_l; y_l; theta_e ] ->
(* Displacement from lens center *)
let dx = sub img_x x_l in
let dy = sub img_y y_l in
(* Distance from lens center *)
let r_sq = add (square dx) (square dy) in
let r = sqrt r_sq in
(* Point-mass deflection: alpha = theta_E^2 / r *)
let alpha = div (square theta_e) r in
(* Source position for each image: beta = theta - alpha * hat(theta) *)
let beta_x = sub img_x (mul alpha (div dx r)) in
let beta_y = sub img_y (mul alpha (div dy r)) in
(* Variance of source positions (should be ~0 if lens model is correct) *)
let mean_bx = mean beta_x in
let mean_by = mean beta_y in
let var_x = mean (square (sub beta_x mean_bx)) in
let var_y = mean (square (sub beta_y mean_by)) in
add var_x var_y
| _ -> failwith "expected [x_l; y_l; theta_e]"
let () =
Printf.printf "True parameters:\n";
Printf.printf " x_L = %.3f arcsec\n" true_x_l;
Printf.printf " y_L = %.3f arcsec\n" true_y_l;
Printf.printf " theta_E = %.3f arcsec\n\n" true_theta_e;
let algo = Vega.adam (Vega.Schedule.constant 1e-2) in
let x_l = ref (scalar f64 0.0) in
let y_l = ref (scalar f64 0.0) in
let theta_e = ref (scalar f64 0.5) in
let states =
[| Vega.init algo !x_l; Vega.init algo !y_l; Vega.init algo !theta_e |]
in
let steps = 500 in
Printf.printf "%5s %12s %8s %8s %8s\n" "step" "loss" "x_L" "y_L" "theta_E";
Printf.printf "%5s %12s %8s %8s %8s\n" "-----" "------------" "--------"
"--------" "--------";
let refs = [| x_l; y_l; theta_e |] in
for i = 0 to steps - 1 do
let loss_val, grads = Rune.value_and_grads loss [ !x_l; !y_l; !theta_e ] in
List.iteri
(fun j g ->
let p, s = Vega.step states.(j) ~grad:g ~param:!(refs.(j)) in
refs.(j) := p;
states.(j) <- s)
grads;
if i mod 100 = 0 || i = steps - 1 then
Printf.printf "%5d %12.8f %8.4f %8.4f %8.4f\n" i (item [] loss_val)
(item [] !x_l) (item [] !y_l) (item [] !theta_e)
done;
Printf.printf "\nFitted parameters:\n";
Printf.printf " x_L = %.4f (true: %.4f)\n" (item [] !x_l) true_x_l;
Printf.printf " y_L = %.4f (true: %.4f)\n" (item [] !y_l) true_y_l;
Printf.printf " theta_E = %.4f (true: %.4f)\n" (item [] !theta_e)
true_theta_e
================================================
FILE: dev/umbra/examples/10-uncertainty-propagation/README.md
================================================
# `10-uncertainty-propagation`
Automatic uncertainty propagation through cosmological distance calculations.
Propagates H0 and Omega_m uncertainties through distance modulus using exact
AD Jacobians via forward-mode differentiation. The linear error propagation
formula (Sigma_out = J Sigma_in J^T) is validated against Monte Carlo sampling
with 50,000 draws.
```bash
cd dev/umbra
dune exec --root . examples/10-uncertainty-propagation/main.exe
```
## What You'll Learn
- Computing exact Jacobians automatically with forward-mode AD (`Rune.jacfwd`)
- Applying linear error propagation via the Jacobian covariance formula
- Validating analytical uncertainty estimates with Monte Carlo sampling
- Propagating scalar uncertainties through cosmological models with JVP
## Key Functions
| Function | Purpose |
| ---------------------------- | ------------------------------------------------ |
| `Cosmo.create_flat_lcdm` | Create a flat Lambda-CDM cosmology |
| `Cosmo.distance_modulus` | Compute distance modulus at a given redshift |
| `Rune.jacfwd` | Forward-mode Jacobian of a function |
| `Rune.jvp` | Jacobian-vector product for scalar propagation |
| `Nx.cholesky` | Cholesky decomposition for MC sampling |
| `Nx.matmul` | Matrix multiply for J Sigma J^T |
| `Nx.diag` | Build diagonal covariance from variances |
## How It Works
Given input parameters with uncertainties (H0 = 70 +/- 1 km/s/Mpc, Omega_m =
0.30 +/- 0.01), the example propagates these through `Cosmo.distance_modulus`
at five redshifts (z = 0.1 to 1.0). The propagation uses the standard linear
formula: Sigma_out = J Sigma_in J^T, where J is the Jacobian of the distance
modulus with respect to [H0, Omega_m]. Rather than deriving J analytically,
`Rune.jacfwd` computes it automatically with just two JVP evaluations (one per
input parameter).
For validation, the example draws 50,000 Monte Carlo samples from the input
covariance via Cholesky decomposition, evaluates the model at each sample, and
computes empirical output statistics. Agreement below 1% between AD and MC
confirms that linear propagation is accurate for these parameter ranges.
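As a minimal sketch of the propagation formula, here is the scalar-output, two-input case in plain OCaml floats, with a central finite-difference Jacobian standing in for `Rune.jacfwd` (the function name is illustrative, not part of the Umbra API):

```ocaml
(* Sigma_out = J Sigma_in J^T for one output and two uncorrelated
   inputs, i.e. Sigma_in = diag(s1^2, s2^2). Returns (mean, sigma). *)
let propagate_scalar f (m1, m2) (s1, s2) =
  let eps = 1e-6 in
  (* central differences approximate the two Jacobian entries *)
  let j1 = (f (m1 +. eps) m2 -. f (m1 -. eps) m2) /. (2. *. eps) in
  let j2 = (f m1 (m2 +. eps) -. f m1 (m2 -. eps)) /. (2. *. eps) in
  (f m1 m2, sqrt ((j1 *. j1 *. s1 *. s1) +. (j2 *. j2 *. s2 *. s2)))
```

For a linear model the result is exact; for `distance_modulus` the same formula holds to first order, which is why the sub-1% AD/MC agreement indicates the linear regime.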
A scalar API demo shows the simpler case: propagating redshift uncertainty
(z = 0.5 +/- 0.01) through a single `jvp` call, which returns both the output
value and its sensitivity to the input perturbation.
## Try It
1. Add correlation between H0 and Omega_m by putting off-diagonal terms in the
input covariance matrix.
2. Increase the uncertainties to see where linear propagation breaks down and
MC diverges from AD.
3. Propagate uncertainties through `Cosmo.luminosity_distance` instead of
distance modulus and compare the relative errors.
## Next Steps
Continue to [11-bayesian-sed](../11-bayesian-sed/) to see how Fisher information
and Hamiltonian Monte Carlo provide both theoretical bounds and full Bayesian
posteriors for SED parameter estimation.
================================================
FILE: dev/umbra/examples/10-uncertainty-propagation/dune
================================================
(executable
(name main)
(libraries nx rune umbra))
================================================
FILE: dev/umbra/examples/10-uncertainty-propagation/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Automatic uncertainty propagation through cosmological distances.
Demonstrates propagating H0 and Omega_m uncertainties through
Umbra.Cosmo.distance_modulus using exact AD Jacobians. The linear error
propagation formula (Sigma_out = J Sigma_in J^T) is computed automatically
via Rune.jacfwd. Results are validated against Monte Carlo sampling.
Fisher, propagation, and Monte Carlo are all trivial given Rune's jacfwd --
no dedicated library needed. *)
open Nx
open Umbra
let f64 = Nx.float64
(* Redshifts to evaluate *)
let redshifts = [| 0.1; 0.3; 0.5; 0.7; 1.0 |]
(* Forward model: given [H0; Omega_m], compute distance modulus at z *)
let distance_modulus_at z p =
let h0 = Nx.reshape [||] (Nx.slice [ I 0 ] p) in
let om = Nx.reshape [||] (Nx.slice [ I 1 ] p) in
let cosmo = Cosmo.create_flat_lcdm ~h0 ~omega_m:om in
Cosmo.distance_modulus ~p:cosmo (Nx.scalar f64 z)
(* Linear error propagation: Sigma_out = J Sigma_in J^T *)
let propagate f ~mean ~cov =
let j = Rune.jacfwd f mean in
let mean_out = f mean in
let cov_out = Nx.matmul (Nx.matmul j cov) (Nx.matrix_transpose j) in
let cov_out = Nx.div_s (Nx.add cov_out (Nx.matrix_transpose cov_out)) 2.0 in
(mean_out, cov_out)
(* Monte Carlo validation *)
let monte_carlo ?(n_samples = 50_000) f ~mean ~cov =
let n = Nx.numel mean in
let l = Nx.cholesky cov in
let z = Nx.randn f64 [| n_samples; n |] in
let samples = Nx.add (Nx.matmul z (Nx.matrix_transpose l)) mean in
let y0 = f (Nx.slice [ I 0 ] samples) in
let m = Nx.numel y0 in
let outputs = Nx.zeros f64 [| n_samples; m |] in
Nx.set_slice [ I 0 ] outputs y0;
for i = 1 to n_samples - 1 do
Nx.set_slice [ I i ] outputs (f (Nx.slice [ I i ] samples))
done;
let mean_out = Nx.mean ~axes:[ 0 ] outputs in
let centered = Nx.sub outputs mean_out in
let cov_out =
Nx.div_s
(Nx.matmul (Nx.matrix_transpose centered) centered)
(Float.of_int (n_samples - 1))
in
(mean_out, cov_out)
let () =
Printf.printf "Automatic Uncertainty Propagation through Cosmology\n";
Printf.printf "====================================================\n\n";
(* Parameters with uncertainties *)
let h0_mean = 70.0 and h0_std = 1.0 in
let om_mean = 0.30 and om_std = 0.01 in
Printf.printf "Input parameters:\n";
Printf.printf " H0 = %.1f +/- %.1f km/s/Mpc\n" h0_mean h0_std;
Printf.printf " Omega_m = %.2f +/- %.2f\n\n" om_mean om_std;
let mean = Nx.create f64 [| 2 |] [| h0_mean; om_mean |] in
let std = Nx.create f64 [| 2 |] [| h0_std; om_std |] in
let cov = Nx.diag (Nx.square std) in
Printf.printf "%5s %10s %10s %10s %10s\n" "z" "mu (AD)" "sigma (AD)"
"sigma (MC)" "agreement";
Printf.printf "%5s %10s %10s %10s %10s\n" "-----" "----------"
"----------" "----------" "----------";
Array.iter
(fun z ->
(* AD-based propagation *)
let f p = Nx.reshape [| 1 |] (distance_modulus_at z p) in
let mean_ad, cov_ad = propagate f ~mean ~cov in
let mu_ad = item [ 0 ] mean_ad in
let std_ad = Float.sqrt (item [ 0; 0 ] cov_ad) in
(* Monte Carlo validation *)
let _, cov_mc = monte_carlo f ~mean ~cov in
let std_mc = Float.sqrt (item [ 0; 0 ] cov_mc) in
let agreement = Float.abs (std_ad -. std_mc) /. std_mc *. 100.0 in
Printf.printf "%5.1f %10.4f %10.4f %10.4f %9.1f%%\n" z mu_ad std_ad
std_mc agreement)
redshifts;
Printf.printf "\n";
Printf.printf "AD uses exact Jacobians (2 JVP calls for 2 parameters).\n";
Printf.printf "MC uses 50,000 samples for validation.\n";
Printf.printf "Agreement < 1%% confirms linear propagation is accurate.\n";
(* Also demonstrate the simple scalar API *)
Printf.printf "\n--- Scalar API demo ---\n\n";
Printf.printf "Propagating z = 0.5 +/- 0.01 through distance_modulus:\n";
let x = Nx.scalar f64 0.5 in
let y, dy =
Rune.jvp (fun z -> Cosmo.distance_modulus z) x (Nx.scalar f64 1.0)
in
let mu_mean = Nx.item [] y in
let mu_std = Float.abs (Nx.item [] dy) *. 0.01 in
Printf.printf " mu = %.4f +/- %.4f\n" mu_mean mu_std
================================================
FILE: dev/umbra/examples/11-bayesian-sed/README.md
================================================
# `11-bayesian-sed`
Fisher information matrix analysis and Hamiltonian Monte Carlo sampling for
Bayesian SED parameter estimation. Computes Cramer-Rao bounds (theoretical
minimum uncertainties) from the Fisher matrix, then samples the full posterior
via HMC through the differentiable spectrum -> extinction -> photometry pipeline.
```bash
cd dev/umbra
dune exec --root . examples/11-bayesian-sed/main.exe
```
## What You'll Learn
- Computing the Fisher information matrix via reverse-mode Jacobians
- Deriving Cramer-Rao bounds on SED parameters (temperature, extinction)
- Sampling full Bayesian posteriors with Hamiltonian Monte Carlo
- Comparing Fisher-predicted vs HMC-sampled uncertainties
- Building differentiable forward models through tophat bandpasses
## Key Functions
| Function | Purpose |
| -------------------------- | --------------------------------------------------- |
| `Rune.jacrev` | Reverse-mode Jacobian for Fisher matrix computation |
| `Nx.inv` | Matrix inverse for Fisher -> covariance |
| `Nx.diagonal` | Extract diagonal (marginal variances) |
| `Spectrum.blackbody` | Generate Planck SED at given temperature |
| `Extinction.apply` | Apply CCM89 dust reddening |
| `Photometry.tophat` | Create rectangular bandpass filters |
| `Photometry.ab_mag` | Compute AB magnitude through a bandpass |
| `Norn.hmc` | Hamiltonian Monte Carlo posterior sampling |
## How It Works
The forward model maps two parameters -- log(T) and A_V -- to five broadband
magnitudes through the pipeline: `blackbody -> extinction -> ab_mag`. Synthetic
observations are generated at T=6500 K, A_V=0.5 with realistic photometric
errors (0.03-0.05 mag).
The Fisher information matrix F = J^T C^-1 J is computed from the Jacobian of
the model (via `Rune.jacrev`) and the observational covariance C. Inverting F
gives the Cramer-Rao lower bound -- the best achievable 1-sigma uncertainty on
each parameter for a given dataset, regardless of estimation method.
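A toy version of the Fisher computation, with the Jacobian passed in as a plain float matrix instead of coming from `Rune.jacrev` (function name ours; the diagonal observational covariance matches the example's independent errors):

```ocaml
(* F = J^T C^-1 J with C = diag(err_i^2), for a 2-parameter model.
   Cramer-Rao bounds are sigma_k = sqrt((F^-1)_kk), using the
   closed-form 2x2 inverse. *)
let cramer_rao jacobian errs =
  let f = Array.make_matrix 2 2 0.0 in
  Array.iteri
    (fun i row ->
      let w = 1.0 /. (errs.(i) *. errs.(i)) in
      for a = 0 to 1 do
        for b = 0 to 1 do
          f.(a).(b) <- f.(a).(b) +. (w *. row.(a) *. row.(b))
        done
      done)
    jacobian;
  let det = (f.(0).(0) *. f.(1).(1)) -. (f.(0).(1) *. f.(1).(0)) in
  (sqrt (f.(1).(1) /. det), sqrt (f.(0).(0) /. det))
```

The off-diagonal Fisher terms encode parameter degeneracies: a strong log_T/A_V correlation inflates both marginal bounds relative to the uncorrelated case.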
The example then samples the actual Bayesian posterior using `Norn.hmc`. The
log-posterior is a Gaussian likelihood with flat priors, and HMC uses Rune's
gradients to efficiently explore the parameter space with 500 post-warmup
samples. Comparing the HMC posterior width to the Fisher prediction validates
that the model is well-behaved: when they agree, the posterior is approximately
Gaussian and the Fisher bound is tight.
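The HMC target has a simple closed form. A plain-float sketch of the quantity `main.ml` builds with Nx ops (up to an additive constant, flat priors contribute nothing):

```ocaml
(* Gaussian log-likelihood: log p = -1/2 * sum_i ((pred_i - obs_i) / err_i)^2 *)
let log_posterior pred obs errs =
  let acc = ref 0.0 in
  Array.iteri
    (fun i p ->
      let r = (p -. obs.(i)) /. errs.(i) in
      acc := !acc +. (r *. r))
    pred;
  (-0.5) *. !acc
```

Because this is smooth in the parameters (through `pred`), Rune can differentiate it, which is what lets HMC take informed leapfrog steps instead of random-walk proposals.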
## Try It
1. Reduce the photometric errors to 0.01 mag and observe how both Fisher bounds
and HMC posteriors tighten.
2. Add a third parameter (redshift) and examine the resulting parameter
degeneracies in the Fisher matrix.
3. Replace the flat prior with an informative Gaussian prior on A_V and see
how the posterior shifts.
## Next Steps
Continue to [12-survey-optimization](../12-survey-optimization/) to see how
differentiable Fisher forecasting enables gradient-based optimization of survey
design parameters for weak gravitational lensing.
================================================
FILE: dev/umbra/examples/11-bayesian-sed/dune
================================================
(executable
(name main)
(libraries nx rune vega norn umbra))
================================================
FILE: dev/umbra/examples/11-bayesian-sed/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Fisher information and HMC sampling for SED parameter estimation.
Demonstrates two capabilities:
1. Fisher matrix: compute the Cramer-Rao bounds on temperature and extinction
-- "how well CAN I constrain these parameters from UGRIZ photometry?" --
before taking any data. Computed inline from Rune.jacrev + linear algebra.
2. HMC sampling: full Bayesian posterior through the differentiable Spectrum
-> Extinction -> Photometry pipeline, via Norn.hmc. *)
open Nx
open Umbra
let f64 = Nx.float64
(* Bandpasses *)
let n_bp = 20
let bands =
[
Photometry.tophat ~lo:(Unit.Length.m 3.0e-7) ~hi:(Unit.Length.m 4.0e-7)
~n:n_bp;
Photometry.tophat ~lo:(Unit.Length.m 4.0e-7) ~hi:(Unit.Length.m 5.5e-7)
~n:n_bp;
Photometry.tophat ~lo:(Unit.Length.m 5.5e-7) ~hi:(Unit.Length.m 7.0e-7)
~n:n_bp;
Photometry.tophat ~lo:(Unit.Length.m 7.0e-7) ~hi:(Unit.Length.m 8.5e-7)
~n:n_bp;
Photometry.tophat ~lo:(Unit.Length.m 8.5e-7) ~hi:(Unit.Length.m 1.0e-6)
~n:n_bp;
]
let band_names = [| "U"; "G"; "R"; "I"; "Z" |]
let rv = Nx.scalar f64 3.1
(* Forward model: [log_T, A_V] -> 5 magnitudes *)
let model params =
let log_temp = Nx.reshape [||] (Nx.slice [ I 0 ] params) in
let av = Nx.reshape [||] (Nx.slice [ I 1 ] params) in
let temp = Unit.Temperature.of_kelvin (Nx.exp log_temp) in
let mags =
List.map
(fun bp ->
let wave = Photometry.wavelength bp in
let sed =
Spectrum.blackbody ~temperature:temp ~wavelength:wave
|> Extinction.apply (Extinction.ccm89 ~rv) ~av
|> Spectrum.as_flux_density
in
Photometry.ab_mag bp sed)
bands
in
Nx.stack ~axis:0 mags
(* True parameters *)
let true_log_temp = Float.log 6500.0
let true_av = 0.5
let true_params = Nx.create f64 [| 2 |] [| true_log_temp; true_av |]
(* Synthetic observations *)
let obs_errs = Nx.create f64 [| 5 |] [| 0.05; 0.03; 0.03; 0.04; 0.05 |]
let obs_mags =
let true_mags = model true_params in
let noise = Nx.create f64 [| 5 |] [| 0.03; -0.02; 0.01; -0.01; 0.02 |] in
Nx.add true_mags noise
(* Fisher information: F = J^T C^-1 J *)
let fisher f ~params ~obs_cov =
let j = Rune.jacrev f params in
let jt = Nx.matrix_transpose j in
Nx.matmul (Nx.matmul jt (Nx.inv obs_cov)) j
(* Cramer-Rao bounds: sigma = sqrt(diag(F^-1)) *)
let marginal_sigma f = Nx.sqrt (Nx.diagonal (Nx.inv f))
let () =
Printf.printf "Fisher Information & HMC for SED Fitting\n";
Printf.printf "=========================================\n\n";
Printf.printf "True parameters:\n";
Printf.printf " T = %.0f K (log_T = %.4f)\n" (Float.exp true_log_temp)
true_log_temp;
Printf.printf " A_V = %.2f\n\n" true_av;
Printf.printf "Observed magnitudes:\n";
Array.iteri
(fun i name ->
Printf.printf " %s = %.3f +/- %.3f\n" name (item [ i ] obs_mags)
(item [ i ] obs_errs))
band_names;
Printf.printf "\n";
(* --- Fisher Information --- *)
Printf.printf "=== Fisher Information ===\n\n";
let obs_cov = Nx.diag (Nx.square obs_errs) in
let f = fisher model ~params:true_params ~obs_cov in
let sigma = marginal_sigma f in
Printf.printf "Fisher matrix:\n";
Printf.printf " F = [[ %10.2f %10.2f ]\n"
(item [ 0; 0 ] f)
(item [ 0; 1 ] f);
Printf.printf " [ %10.2f %10.2f ]]\n\n"
(item [ 1; 0 ] f)
(item [ 1; 1 ] f);
Printf.printf "Cramer-Rao bounds (best achievable 1-sigma):\n";
let sigma_log_t = item [ 0 ] sigma in
let sigma_av = item [ 1 ] sigma in
Printf.printf " sigma(log_T) = %.4f -> sigma(T) ~ %.0f K\n" sigma_log_t
(sigma_log_t *. Float.exp true_log_temp);
Printf.printf " sigma(A_V) = %.4f\n\n" sigma_av;
(* --- HMC Sampling --- *)
Printf.printf "=== HMC Posterior Sampling ===\n\n";
(* Log-posterior: Gaussian likelihood, flat prior *)
let log_posterior params =
let pred = model params in
let residuals = Nx.div (Nx.sub pred obs_mags) obs_errs in
Nx.mul_s (Nx.sum (Nx.square residuals)) (-0.5)
in
let init = Nx.create f64 [| 2 |] [| Float.log 7000.0; 0.3 |] in
let result =
Norn.hmc ~step_size:0.001 ~num_leapfrog:10 ~num_warmup:200 ~n:500
log_posterior init
in
Printf.printf "HMC diagnostics:\n";
Printf.printf " Accept rate: %.1f%%\n\n" (result.stats.accept_rate *. 100.);
(* Sample statistics *)
let sample_mean = Nx.mean ~axes:[ 0 ] result.samples in
let centered = Nx.sub result.samples sample_mean in
let sample_cov =
Nx.div_s
(Nx.matmul (Nx.matrix_transpose centered) centered)
(Float.of_int 499) (* n_samples - 1 for the 500 post-warmup draws *)
in
let sample_std = Nx.sqrt (Nx.diagonal sample_cov) in
let hmc_log_t = item [ 0 ] sample_mean in
let hmc_av = item [ 1 ] sample_mean in
let hmc_sigma_log_t = item [ 0 ] sample_std in
let hmc_sigma_av = item [ 1 ] sample_std in
Printf.printf "Posterior (HMC):\n";
Printf.printf " log_T = %.4f +/- %.4f -> T ~ %.0f K\n" hmc_log_t
hmc_sigma_log_t (Float.exp hmc_log_t);
Printf.printf " A_V = %.4f +/- %.4f\n\n" hmc_av hmc_sigma_av;
(* --- Comparison --- *)
Printf.printf "=== Fisher vs HMC Comparison ===\n\n";
Printf.printf " %12s %10s %10s\n" "" "Fisher s" "HMC s";
Printf.printf " %12s %10s %10s\n" "------------" "----------" "----------";
Printf.printf " %12s %10.4f %10.4f\n" "s(log_T)" sigma_log_t
hmc_sigma_log_t;
Printf.printf " %12s %10.4f %10.4f\n\n" "s(A_V)" sigma_av hmc_sigma_av;
Printf.printf "Fisher gives the theoretical minimum uncertainty.\n";
Printf.printf "HMC gives the actual posterior width.\n";
Printf.printf "Agreement confirms the model is well-behaved (near-linear).\n"
================================================
FILE: dev/umbra/examples/12-survey-optimization/README.md
================================================
# `12-survey-optimization`
Differentiable survey optimization for a Stage IV weak lensing survey. Uses
exact autodiff gradients to optimize survey parameters that minimize the
uncertainty on S8 = sigma8 * sqrt(Omega_m / 0.3), replacing traditional grid
search with gradient-based Fisher forecasting. Demonstrates both a single-bin
area/depth tradeoff and joint optimization of sky fraction with tomographic bin
edges.
```bash
cd dev/umbra
dune exec --root . examples/12-survey-optimization/main.exe
```
## What You'll Learn
- Computing differentiable Fisher information matrices for survey forecasting
- Optimizing the area/depth tradeoff for sky coverage vs galaxy density
- Jointly optimizing sky fraction and tomographic bin edges with gradient descent
- Using sigmoid-windowed bins for smooth gradient flow through discrete boundaries
- Comparing gradient-based optimization against brute-force grid search
## Key Functions
| Function | Purpose |
| --------------------------- | ---------------------------------------------------- |
| `Survey.angular_cl` | Compute angular power spectra for tracer pairs |
| `Survey.weak_lensing` | Create a weak lensing tracer from n(z) |
| `Survey.smail` | Smail redshift distribution for source galaxies |
| `Cosmo.planck18` | Planck 2018 fiducial cosmology |
| `Cosmo.linear_power` | Linear matter power spectrum P(k, z) |
| `Cosmo.comoving_distance` | Comoving distance for lensing kernel computation |
| `Rune.value_and_grad` | Loss and gradient for survey parameter optimization |
| `Vega.adam` | Adam optimizer for continuous parameter search |
## How It Works
Part 1 tackles the area/depth tradeoff for a single tomographic bin. A fixed
galaxy budget (n_gal * f_sky = constant) means wider surveys are shallower. The
Fisher matrix for [Omega_m, sigma8] is computed from Limber-integrated angular
power spectra, with shape noise that depends on galaxy density. The objective
function -- sigma(S8) derived from the 2x2 Fisher inverse -- is fully
differentiable through f_sky via sigmoid parameterization. Adam finds the
optimal sky fraction in 300 steps with exact gradients, verified by a
finite-difference check.
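The budget constraint can be made concrete with a small sketch (names are illustrative, not the `main.ml` API): at fixed total budget, widening the survey lowers the galaxy density and raises the shape noise, while sample variance falls as 1/f_sky.

```ocaml
(* Shape noise per tomographic bin under a fixed galaxy budget:
   n_gal = budget / f_sky, so N_ell = sigma_e^2 / n_gal grows with
   sky fraction -- the tension the optimizer resolves. *)
let shape_noise ~sigma_e ~budget ~f_sky =
  let n_gal = budget /. f_sky in
  (sigma_e *. sigma_e) /. n_gal
```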
Part 2 extends to joint optimization of sky fraction and two tomographic bin
edges that divide galaxies into three redshift bins. The bin boundaries use
sigmoid window functions (with width delta=0.03) so that gradients flow smoothly
through the discrete bin assignment. Narrower bins concentrate signal but
increase shot noise; the optimizer balances this tradeoff automatically. The
Limber integral uses precomputed cosmological grids (comoving distances, Hubble
rates, power spectra) evaluated at five cosmology perturbations for numerical
derivatives of C_l with respect to Omega_m and sigma8, while gradients with
respect to survey parameters (f_sky, z1, z2) flow through Rune's autodiff.
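A minimal sketch of the sigmoid window described above, in plain floats (`main.ml` builds the same shape with Nx ops so gradients with respect to z1 and z2 flow through the bin assignment):

```ocaml
(* A galaxy at redshift z enters the bin [z1, z2] with weight
   sigmoid((z - z1)/delta) * sigmoid((z2 - z)/delta); delta = 0.03
   matches the example. The weight is ~1 deep inside the bin, 0.5 at
   an edge, and ~0 outside, with smooth transitions of width ~delta. *)
let sigmoid x = 1.0 /. (1.0 +. exp (-.x))

let bin_window ~z1 ~z2 ~delta z =
  sigmoid ((z -. z1) /. delta) *. sigmoid ((z2 -. z) /. delta)
```

Shrinking delta sharpens the bins toward a hard cut but also shrinks the region where the edge gradients are non-negligible, so delta trades bin purity against optimizer signal.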
A brute-force grid search over 12 x 15 x 15 = 2700 parameter combinations
validates the gradient result, demonstrating that 500 Adam steps achieve equal
or better precision with orders of magnitude fewer function evaluations.
## Try It
1. Increase the galaxy budget from 10 to 50 gal/arcmin2 and observe how the
optimal sky fraction shifts toward wider coverage.
2. Add a fourth tomographic bin and compare the improvement in sigma(S8).
3. Replace the Smail n(z) with a sharper distribution and see how the optimal
bin edges respond.
## Next Steps
This is the final example in the Umbra series. For earlier topics, revisit
[01-constants-and-units](../01-constants-and-units/) for physical constants and
unit handling, or [05-sed-fitting](../05-sed-fitting/) for the foundations of
differentiable spectral energy distribution fitting that this example builds on.
================================================
FILE: dev/umbra/examples/12-survey-optimization/dune
================================================
(executable
(name main)
(libraries nx rune vega umbra))
================================================
FILE: dev/umbra/examples/12-survey-optimization/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Differentiable survey optimization via autodiff gradients through the Fisher
information matrix.
Traditional survey optimization uses grid search over discrete Fisher
forecasts. Umbra's fully differentiable cosmology pipeline enables
gradient-based continuous optimization: compute Fisher(survey_params) and
minimize sigma(S8) with respect to survey parameters using exact autodiff
gradients from Rune.
Part 1: Area/depth tradeoff -- optimize f_sky with fixed n(z) shape. Part 2:
Joint area + bin edge optimization -- optimize f_sky and tomographic bin
edges simultaneously, with gradients flowing through the lensing kernel
computation via differentiable n(z) windowing. *)
open Nx
open Umbra
let f64 = Nx.float64
let sigma_e = 0.26
let steradian_to_arcmin2 = 11818102.86004228
let c_km_s = 299792.458
let h0_ref = 100.0
(* Fiducial cosmology *)
let p_fid = Cosmo.planck18
let omega_m_fid = Nx.item [] (Cosmo.omega_m p_fid)
let sigma8_fid = Nx.item [] (Cosmo.sigma8 p_fid)
(* S8 = sigma8 * sqrt(omega_m / 0.3) -- derivatives at fiducial *)
let ds8_dom = sigma8_fid /. (2.0 *. Float.sqrt (0.3 *. omega_m_fid))
let ds8_ds8 = Float.sqrt (omega_m_fid /. 0.3)
(* ell weights: (2*ell+1) * dell / 2 *)
let ell_weights ell =
let n_ell = (Nx.shape ell).(0) in
let dell =
Array.init n_ell (fun l ->
if l = 0 then Nx.item [ 1 ] ell -. Nx.item [ 0 ] ell
else if l = n_ell - 1 then Nx.item [ l ] ell -. Nx.item [ l - 1 ] ell
else 0.5 *. (Nx.item [ l + 1 ] ell -. Nx.item [ l - 1 ] ell))
in
Nx.create f64 [| n_ell |]
(Array.init n_ell (fun l ->
((2.0 *. Nx.item [ l ] ell) +. 1.0) *. dell.(l) /. 2.0))
(* Compute dCl/d(theta) via central finite differences *)
let finite_diff_cl ~ell ~tracers ~param_name ~set_param ~fid_val ~eps =
let p_plus = set_param (scalar f64 (fid_val +. eps)) p_fid in
let p_minus = set_param (scalar f64 (fid_val -. eps)) p_fid in
let cl_p =
Survey.Cls.to_tensor
(Survey.angular_cl ~p:p_plus ~power:Survey.linear ~ell tracers)
in
let cl_m =
Survey.Cls.to_tensor
(Survey.angular_cl ~p:p_minus ~power:Survey.linear ~ell tracers)
in
let dcl = Nx.div_s (Nx.sub cl_p cl_m) (2.0 *. eps) in
Printf.printf " dCl/d(%-8s): max=%.3e\n" param_name (Nx.item [] (Nx.max dcl));
dcl
(* 2x2 analytical Fisher inverse -> sigma(S8) -- all differentiable *)
let sigma_s8_from_fisher f11 f12 f22 =
let det = Nx.sub (Nx.mul f11 f22) (Nx.mul f12 f12) in
let a = scalar f64 ds8_dom and b = scalar f64 ds8_ds8 in
let sigma_sq =
Nx.div
(Nx.add
(Nx.sub
(Nx.mul f22 (Nx.mul a a))
(Nx.mul_s (Nx.mul f12 (Nx.mul a b)) 2.0))
(Nx.mul f11 (Nx.mul b b)))
det
in
Nx.sqrt sigma_sq
(* ===================================================================== *)
(* Part 1: Area/depth tradeoff (single bin) *)
(* ===================================================================== *)
let part1 () =
Printf.printf "--- Part 1: Area/Depth Tradeoff (1 bin) ---\n\n";
let budget = 10.0 in
let ell = Nx.logspace f64 1.0 3.0 30 in
let w_ell = ell_weights ell in
let nz = Survey.smail ~a:2.0 ~b:1.5 ~z0:0.3 () in
let wl = Survey.weak_lensing nz in
Printf.printf "Precomputing signal derivatives...\n";
let cl_fid =
Survey.Cls.to_tensor
(Survey.angular_cl ~p:p_fid ~power:Survey.linear ~ell [ wl ])
in
let cl_fid_flat = Nx.flatten cl_fid in
let eps = 1e-4 in
let dcl_dom =
Nx.flatten
(finite_diff_cl ~ell ~tracers:[ wl ] ~param_name:"omega_m"
~set_param:(fun v p -> Cosmo.set_t ~omega_m:v p)
~fid_val:omega_m_fid ~eps)
in
let dcl_ds8 =
Nx.flatten
(finite_diff_cl ~ell ~tracers:[ wl ] ~param_name:"sigma8"
~set_param:(fun v p -> Cosmo.set_t ~sigma8:v p)
~fid_val:sigma8_fid ~eps)
in
Printf.printf "\n";
let objective log_f_sky =
let f_sky = Nx.sigmoid log_f_sky in
let n_gal = Nx.div (scalar f64 budget) f_sky in
let noise =
Nx.div
(scalar f64 (sigma_e *. sigma_e))
(Nx.mul_s n_gal steradian_to_arcmin2)
in
let cl_obs = Nx.add cl_fid_flat noise in
let cl_obs_sq = Nx.mul cl_obs cl_obs in
let weighted_dom =
Nx.div (Nx.mul w_ell (Nx.mul dcl_dom dcl_dom)) cl_obs_sq
in
let weighted_ds8 =
Nx.div (Nx.mul w_ell (Nx.mul dcl_ds8 dcl_ds8)) cl_obs_sq
in
let weighted_x = Nx.div (Nx.mul w_ell (Nx.mul dcl_dom dcl_ds8)) cl_obs_sq in
let f11 = Nx.mul f_sky (Nx.sum weighted_dom) in
let f12 = Nx.mul f_sky (Nx.sum weighted_x) in
let f22 = Nx.mul f_sky (Nx.sum weighted_ds8) in
sigma_s8_from_fisher f11 f12 f22
in
(* Gradient check *)
let log_f_sky_init = scalar f64 0.0 in
let v0, g0 = Rune.value_and_grad objective log_f_sky_init in
let fd_eps = 1e-5 in
let vp = item [] (objective (scalar f64 fd_eps)) in
let vm = item [] (objective (scalar f64 (-.fd_eps))) in
let fd = (vp -. vm) /. (2.0 *. fd_eps) in
Printf.printf "Gradient check: AD=%.6e FD=%.6e rel=%.2e\n\n" (item [] g0) fd
(Float.abs (item [] g0 -. fd) /. Float.abs fd);
let f_sky_0 = 1.0 /. (1.0 +. Float.exp (-0.0)) in
Printf.printf "Initial: f_sky=%.3f n_gal=%.1f sigma(S8)=%.6f\n" f_sky_0
(budget /. f_sky_0) (item [] v0);
let algo = Vega.adam (Vega.Schedule.constant 0.01) in
let log_f_sky = ref log_f_sky_init in
let state = ref (Vega.init algo !log_f_sky) in
let best_sigma = ref (item [] v0) in
let best_f_sky = ref f_sky_0 in
Printf.printf "\n%5s %8s %8s %10s\n" "step" "f_sky" "n_gal" "sigma(S8)";
Printf.printf "%5s %8s %8s %10s\n" "-----" "--------" "--------"
"----------";
let steps = 300 in
for i = 0 to steps - 1 do
let sigma_val, grad = Rune.value_and_grad objective !log_f_sky in
let p, s = Vega.step !state ~grad ~param:!log_f_sky in
log_f_sky := p;
state := s;
let f_sky_cur = 1.0 /. (1.0 +. Float.exp (-.item [] !log_f_sky)) in
let sigma_cur = item [] sigma_val in
if sigma_cur < !best_sigma then begin
best_sigma := sigma_cur;
best_f_sky := f_sky_cur
end;
if i mod 50 = 0 || i = steps - 1 then
Printf.printf "%5d %8.4f %8.1f %10.6f\n" i f_sky_cur
(budget /. f_sky_cur) sigma_cur
done;
Printf.printf "\nOptimal: f_sky=%.4f n_gal=%.1f gal/arcmin2\n" !best_f_sky
(budget /. !best_f_sky);
Printf.printf "Improvement: sigma(S8) reduced by %.1f%% vs initial\n\n"
((1.0 -. (!best_sigma /. item [] v0)) *. 100.0)
(* ===================================================================== *)
(* Part 2: Joint area + bin edge optimization (3 bins) *)
(* ===================================================================== *)
(* Precomputed cosmological grids -- expensive, done once per cosmology. *)
type cosmo_grid = {
n_z : int;
dz : float;
z_arr : float array;
z_vec : Nx.float64_t;
chi_safe : Nx.float64_t;
omega_m_t : Nx.float64_t;
integ_weight : Nx.float64_t;
w_pk : Nx.float64_t;
ell_factor_sq : Nx.float64_t;
}
let precompute_grid ~p ~ell =
let zmax = 3.0 in
let n_z = 50 in
let dz = zmax /. Float.of_int (n_z - 1) in
let z_arr = Array.init n_z (fun i -> Float.of_int i *. dz) in
z_arr.(0) <- 1e-6;
let z_vec = Nx.create f64 [| n_z |] z_arr in
let sw =
Array.init n_z (fun i ->
if i = 0 || i = n_z - 1 then 1.0 else if i mod 2 = 1 then 4.0 else 2.0)
in
let simpson_w = Nx.mul_s (Nx.create f64 [| n_z |] sw) (dz /. 3.0) in
let h_t = Nx.item [] (Nx.div (Cosmo.h0 p) (Nx.scalar f64 h0_ref)) in
let chi_vec =
Nx.create f64 [| n_z |]
(Array.init n_z (fun j ->
let z_t = Nx.scalar f64 z_arr.(j) in
let chi =
Nx.item [] (Unit.Length.in_mpc (Cosmo.comoving_distance ~p z_t))
in
chi *. h_t))
in
let chi_safe = Nx.clamp ~min:1e-10 chi_vec in
let h_vec_f =
Array.init n_z (fun j ->
Nx.item [] (Cosmo.hubble ~p (Nx.scalar f64 z_arr.(j))))
in
let dchi_dz_vec =
Nx.create f64 [| n_z |]
(Array.init n_z (fun j -> h_t *. c_km_s /. h_vec_f.(j)))
in
let omega_m_t = Nx.scalar f64 (Nx.item [] (Cosmo.omega_m p)) in
let integ_weight =
Nx.create f64 [| n_z |]
(Array.init n_z (fun j ->
let sw_j = Nx.item [ j ] simpson_w in
let dchi_j = Nx.item [ j ] dchi_dz_vec in
let chi_j = Nx.item [ j ] chi_safe in
sw_j *. dchi_j /. (chi_j *. chi_j) /. (c_km_s *. c_km_s)))
in
let pk_grid =
Nx.stack
(List.init n_z (fun j ->
let z_t = Nx.scalar f64 z_arr.(j) in
let chi_j = Nx.item [ j ] chi_safe in
let k_vec = Nx.div_s (Nx.add_s ell 0.5) chi_j in
Cosmo.linear_power ~p k_vec z_t))
in
let w_pk =
Nx.mul
(Nx.reshape [| n_z; 1 |]
(Nx.create f64 [| n_z |]
(Array.init n_z (fun j -> Nx.item [ j ] integ_weight))))
pk_grid
in
let l = ell in
let num =
Nx.mul
(Nx.mul (Nx.sub_s l 1.0) l)
(Nx.mul (Nx.add_s l 1.0) (Nx.add_s l 2.0))
in
let den = Nx.mul (Nx.add_s l 0.5) (Nx.add_s l 0.5) in
let ell_factor = Nx.div (Nx.sqrt (Nx.abs num)) den in
let ell_factor_sq = Nx.mul ell_factor ell_factor in
{
n_z;
dz;
z_arr;
z_vec;
chi_safe;
omega_m_t;
integ_weight;
w_pk;
ell_factor_sq;
}
(* Reverse cumulative trapezoidal sum *)
let rev_cumtrapz f_vec n dz =
let left = Nx.slice [ R (0, n - 1) ] f_vec in
let right = Nx.slice [ R (1, n) ] f_vec in
let mid = Nx.mul_s (Nx.add left right) (0.5 *. dz) in
let partial = Nx.flip (Nx.cumsum ~axis:0 (Nx.flip mid)) in
Nx.concatenate [ partial; Nx.zeros f64 [| 1 |] ]
(* Fast WL-only angular Cl from precomputed cosmo grid + pre-evaluated n(z)
tensors. nz_tensors are [n_z] tensors, one per bin, evaluated on the z grid.
Differentiable through the n(z) values. *)
let fast_wl_cl grid nz_tensors =
let n_z = grid.n_z and dz = grid.dz in
let n_bins = Array.length nz_tensors in
(* Build WL kernels *)
let prefactor =
Nx.mul_s grid.omega_m_t (3.0 *. h0_ref *. h0_ref /. (2.0 *. c_km_s))
in
let one_plus_z = Nx.add_s grid.z_vec 1.0 in
let kernels =
Array.init n_bins (fun b ->
let nz_t = nz_tensors.(b) in
let a_vec = rev_cumtrapz nz_t n_z dz in
let nz_over_chi = Nx.div nz_t grid.chi_safe in
let b_vec = rev_cumtrapz nz_over_chi n_z dz in
let g_vec = Nx.sub a_vec (Nx.mul grid.chi_safe b_vec) in
Nx.mul prefactor (Nx.mul one_plus_z (Nx.mul grid.chi_safe g_vec)))
in
(* Limber integration for all pairs *)
let pairs = ref [] in
for i = 0 to n_bins - 1 do
for j = i to n_bins - 1 do
pairs := (i, j) :: !pairs
done
done;
let pairs = List.rev !pairs in
Nx.stack
(List.map
(fun (i, j) ->
let ki = Nx.reshape [| n_z; 1 |] kernels.(i) in
let kj = Nx.reshape [| n_z; 1 |] kernels.(j) in
let integrand = Nx.mul (Nx.mul ki kj) grid.w_pk in
Nx.mul grid.ell_factor_sq (Nx.sum ~axes:[ 0 ] integrand))
pairs)
(* Parent n(z): Smail distribution, evaluated as float *)
let parent_nz =
let a = 2.0 and b = 1.5 and z0 = 0.3 in
let raw z_f = (z_f ** a) *. Float.exp (-.((z_f /. z0) ** b)) in
let norm =
let n = 256 in
let h = 3.0 /. Float.of_int n in
let s = ref (raw 1e-6 +. raw 3.0) in
for i = 1 to n - 1 do
let x = Float.of_int i *. h in
let w = if i mod 2 = 1 then 4.0 else 2.0 in
s := !s +. (w *. raw x)
done;
!s *. h /. 3.0
in
fun z_f -> raw z_f /. norm
(* Build a differentiable bin n(z) with smooth sigmoid edges *)
let make_bin_eval z_lo z_hi delta z =
let parent_val = parent_nz (Nx.item [] z) in
if parent_val < 1e-30 then scalar f64 0.0
else
let lo_gate = Nx.sigmoid (Nx.div_s (Nx.sub z z_lo) delta) in
let hi_gate = Nx.sigmoid (Nx.div_s (Nx.sub z_hi z) delta) in
Nx.mul_s (Nx.mul lo_gate hi_gate) parent_val
let part2 () =
Printf.printf "--- Part 2: Joint Area + Bin Edges (3 bins) ---\n\n";
let budget = 10.0 in
let ell = Nx.logspace f64 1.0 3.0 20 in
let w_ell = ell_weights ell in
let eps = 1e-4 in
let delta = 0.03 in
Printf.printf "Precomputing cosmo grids (fiducial + 4 perturbations)...\n";
let grid_fid = precompute_grid ~p:p_fid ~ell in
let grid_p_om =
precompute_grid
~p:(Cosmo.set_t ~omega_m:(scalar f64 (omega_m_fid +. eps)) p_fid)
~ell
in
let grid_m_om =
precompute_grid
~p:(Cosmo.set_t ~omega_m:(scalar f64 (omega_m_fid -. eps)) p_fid)
~ell
in
let grid_p_s8 =
precompute_grid
~p:(Cosmo.set_t ~sigma8:(scalar f64 (sigma8_fid +. eps)) p_fid)
~ell
in
let grid_m_s8 =
precompute_grid
~p:(Cosmo.set_t ~sigma8:(scalar f64 (sigma8_fid -. eps)) p_fid)
~ell
in
Printf.printf "Done.\n\n";
let n_z = grid_fid.n_z in
let z_arr = grid_fid.z_arr in
let dz = grid_fid.dz in
let objective params =
let log_f_sky = Nx.get [ 0 ] params in
let z1 = Nx.get [ 1 ] params in
let z2 = Nx.get [ 2 ] params in
let f_sky = Nx.sigmoid log_f_sky in
let n_gal = Nx.div (scalar f64 budget) f_sky in
(* Differentiable n(z) bin functions *)
let nz_funs =
[|
make_bin_eval (scalar f64 0.0) z1 delta;
make_bin_eval z1 z2 delta;
make_bin_eval z2 (scalar f64 3.0) delta;
|]
in
(* Evaluate n(z) on z grid -- differentiable through bin edges *)
let nz_tensors =
Array.init 3 (fun b ->
Nx.stack
(List.init n_z (fun j -> nz_funs.(b) (Nx.scalar f64 z_arr.(j)))))
in
(* Galaxy fraction per bin: integral of window_i(z) n(z) dz. Parent n(z) is
normalized so this gives the fraction of total galaxies in each bin.
Differentiable through bin edges -- narrow bins get fewer galaxies. *)
let gal_fracs =
Array.init 3 (fun b ->
let nz_t = nz_tensors.(b) in
let left = Nx.slice [ R (0, n_z - 2) ] nz_t in
let right = Nx.slice [ R (1, n_z - 1) ] nz_t in
Nx.mul_s (Nx.sum (Nx.add left right)) (0.5 *. dz))
in
(* Per-bin noise: sigma_e^2 / (n_gal_bin * ster) where n_gal_bin = n_gal *
f_i. Bins with fewer galaxies have higher shot noise. *)
let noise_per_bin =
Array.init 3 (fun b ->
Nx.div
(scalar f64 (sigma_e *. sigma_e))
(Nx.mul_s (Nx.mul n_gal gal_fracs.(b)) steradian_to_arcmin2))
in
(* Fast Cl from precomputed grids -- only n(z) -> kernel is traced *)
let cl_fid = fast_wl_cl grid_fid nz_tensors in
let cl_p_om = fast_wl_cl grid_p_om nz_tensors in
let cl_m_om = fast_wl_cl grid_m_om nz_tensors in
let cl_p_s8 = fast_wl_cl grid_p_s8 nz_tensors in
let cl_m_s8 = fast_wl_cl grid_m_s8 nz_tensors in
let dcl_dom = Nx.div_s (Nx.sub cl_p_om cl_m_om) (2.0 *. eps) in
let dcl_ds8 = Nx.div_s (Nx.sub cl_p_s8 cl_m_s8) (2.0 *. eps) in
(* Full Fisher via Tr[C^-1 dC/dtheta_i C^-1 dC/dtheta_j] with analytical 3x3
inverse. Vectorized over ell: each matrix element is a [n_ell] tensor. *)
let n_bins = 3 in
(* Pair index: (i,j) -> spectrum row in cl arrays. Ordering: (0,0)=0,
(0,1)=1, (0,2)=2, (1,1)=3, (1,2)=4, (2,2)=5 *)
let pidx i j =
let a, b = if i <= j then (i, j) else (j, i) in
(a * ((2 * n_bins) - a - 1) / 2) + b
in
(* Build 3x3 C(ell) = Cl + N, stored as flat [9] of [n_ell] tensors *)
let c =
Array.init 9 (fun idx ->
let i = idx / 3 and j = idx mod 3 in
let cl_ij = Nx.slice [ I (pidx i j) ] cl_fid in
if i = j then Nx.add cl_ij noise_per_bin.(i) else cl_ij)
in
(* 3x3 inverse via cofactors / determinant *)
let det =
Nx.add
(Nx.sub
(Nx.mul c.(0) (Nx.sub (Nx.mul c.(4) c.(8)) (Nx.mul c.(5) c.(7))))
(Nx.mul c.(1) (Nx.sub (Nx.mul c.(3) c.(8)) (Nx.mul c.(5) c.(6)))))
(Nx.mul c.(2) (Nx.sub (Nx.mul c.(3) c.(7)) (Nx.mul c.(4) c.(6))))
in
let ci = Array.make 9 (scalar f64 0.0) in
ci.(0) <- Nx.div (Nx.sub (Nx.mul c.(4) c.(8)) (Nx.mul c.(5) c.(7))) det;
ci.(1) <- Nx.div (Nx.sub (Nx.mul c.(2) c.(7)) (Nx.mul c.(1) c.(8))) det;
ci.(2) <- Nx.div (Nx.sub (Nx.mul c.(1) c.(5)) (Nx.mul c.(2) c.(4))) det;
ci.(3) <- Nx.div (Nx.sub (Nx.mul c.(5) c.(6)) (Nx.mul c.(3) c.(8))) det;
ci.(4) <- Nx.div (Nx.sub (Nx.mul c.(0) c.(8)) (Nx.mul c.(2) c.(6))) det;
ci.(5) <- Nx.div (Nx.sub (Nx.mul c.(2) c.(3)) (Nx.mul c.(0) c.(5))) det;
ci.(6) <- Nx.div (Nx.sub (Nx.mul c.(3) c.(7)) (Nx.mul c.(4) c.(6))) det;
ci.(7) <- Nx.div (Nx.sub (Nx.mul c.(1) c.(6)) (Nx.mul c.(0) c.(7))) det;
ci.(8) <- Nx.div (Nx.sub (Nx.mul c.(0) c.(4)) (Nx.mul c.(1) c.(3))) det;
(* Build dC/dtheta matrices: symmetric, no noise term *)
let dc_om =
Array.init 9 (fun idx ->
Nx.slice [ I (pidx (idx / 3) (idx mod 3)) ] dcl_dom)
in
let dc_s8 =
Array.init 9 (fun idx ->
Nx.slice [ I (pidx (idx / 3) (idx mod 3)) ] dcl_ds8)
in
(* 3x3 matmul: (AB)_ij = sum_k A_ik B_kj, vectorized over ell *)
let mm a b =
Array.init 9 (fun idx ->
let i = idx / 3 and j = idx mod 3 in
Nx.add
(Nx.add (Nx.mul a.(i * 3) b.(j)) (Nx.mul a.((i * 3) + 1) b.(3 + j)))
(Nx.mul a.((i * 3) + 2) b.(6 + j)))
in
(* Tr[AB] = sum_ij A_ij B_ji, returns [n_ell] tensor *)
let tr a b =
let t = ref (Nx.mul a.(0) b.(0)) in
for i = 0 to 2 do
for j = 0 to 2 do
if i > 0 || j > 0 then
t := Nx.add !t (Nx.mul a.((i * 3) + j) b.((j * 3) + i))
done
done;
!t
in
(* D1 = C^-1 dC/d(Omega_m), D2 = C^-1 dC/d(sigma8) *)
let d1 = mm ci dc_om in
let d2 = mm ci dc_s8 in
(* F_ij = f_sky * sum_ell w_ell * Tr[D_i D_j] *)
let f11 = Nx.mul f_sky (Nx.sum (Nx.mul w_ell (tr d1 d1))) in
let f12 = Nx.mul f_sky (Nx.sum (Nx.mul w_ell (tr d1 d2))) in
let f22 = Nx.mul f_sky (Nx.sum (Nx.mul w_ell (tr d2 d2))) in
sigma_s8_from_fisher f11 f12 f22
in
let params = Nx.create f64 [| 3 |] [| -1.1; 0.5; 1.0 |] in
Printf.printf "Computing initial sigma(S8)...\n";
let v0 = item [] (objective params) in
let f_sky_0 = 1.0 /. (1.0 +. Float.exp 1.1) in
Printf.printf
"Initial: f_sky=%.3f bins=[0, 0.50, 1.00, 3.0] sigma(S8)=%.6f\n\n" f_sky_0
v0;
let algo = Vega.adam (Vega.Schedule.constant 0.03) in
let params = ref params in
let state = ref (Vega.init algo !params) in
let best_sigma = ref v0 in
let best_params = ref !params in
Printf.printf "%5s %8s %8s %8s %10s\n" "step" "f_sky" "z1" "z2"
"sigma(S8)";
Printf.printf "%5s %8s %8s %8s %10s\n" "-----" "--------" "--------"
"--------" "----------";
let steps = 500 in
for i = 0 to steps - 1 do
let sigma_val, grad = Rune.value_and_grad objective !params in
let p, s = Vega.step !state ~grad ~param:!params in
let z1 = Float.max 0.1 (Float.min 2.8 (item [ 1 ] p)) in
let z2 = Float.max (z1 +. 0.1) (Float.min 2.9 (item [ 2 ] p)) in
params := Nx.create f64 [| 3 |] [| item [ 0 ] p; z1; z2 |];
state := s;
let sigma_cur = item [] sigma_val in
if sigma_cur < !best_sigma then begin
best_sigma := sigma_cur;
best_params := !params
end;
if i mod 50 = 0 || i = steps - 1 then begin
let f_sky = 1.0 /. (1.0 +. Float.exp (-.item [ 0 ] !params)) in
Printf.printf "%5d %8.4f %8.3f %8.3f %10.6f\n" i f_sky
(item [ 1 ] !params) (item [ 2 ] !params) sigma_cur
end
done;
let f_sky_opt = 1.0 /. (1.0 +. Float.exp (-.item [ 0 ] !best_params)) in
Printf.printf
"\nGrad optimal: f_sky=%.4f bins=[0, %.2f, %.2f, 3.0] sigma(S8)=%.6f\n"
f_sky_opt (item [ 1 ] !best_params) (item [ 2 ] !best_params) !best_sigma;
(* Grid search validation *)
let grid_best_sigma = ref infinity in
let grid_best_fs = ref 0.0 in
let grid_best_z1 = ref 0.0 in
let grid_best_z2 = ref 0.0 in
let n_fs = 12 and n_z1 = 15 and n_z2 = 15 in
let n_grid_evals = ref 0 in
Printf.printf "\nGrid search (%d*%d*%d)...\n%!" n_fs n_z1 n_z2;
for fi = 0 to n_fs - 1 do
let fs = 0.1 +. (Float.of_int fi *. 0.88 /. Float.of_int (n_fs - 1)) in
let log_fs = Float.log (fs /. (1.0 -. fs)) in
for z1i = 0 to n_z1 - 1 do
let z1_v = 0.2 +. (Float.of_int z1i *. 2.4 /. Float.of_int (n_z1 - 1)) in
for z2i = 0 to n_z2 - 1 do
let z2_v =
z1_v +. 0.15
+. (Float.of_int z2i *. (2.7 -. z1_v) /. Float.of_int (n_z2 - 1))
in
if z2_v > z1_v +. 0.1 && z2_v < 2.9 then begin
incr n_grid_evals;
let p = Nx.create f64 [| 3 |] [| log_fs; z1_v; z2_v |] in
let s = item [] (objective p) in
if s < !grid_best_sigma then begin
grid_best_sigma := s;
grid_best_fs := fs;
grid_best_z1 := z1_v;
grid_best_z2 := z2_v
end
end
done
done
done;
Printf.printf
"Grid optimal: f_sky=%.4f bins=[0, %.2f, %.2f, 3.0] sigma(S8)=%.6f (%d \
evals)\n"
!grid_best_fs !grid_best_z1 !grid_best_z2 !grid_best_sigma !n_grid_evals;
Printf.printf "\nComparison:\n";
Printf.printf " Gradient: sigma(S8)=%.6f (%d evals)\n" !best_sigma steps;
Printf.printf " Grid: sigma(S8)=%.6f (%d evals)\n" !grid_best_sigma
!n_grid_evals;
let rel = (1.0 -. (!best_sigma /. !grid_best_sigma)) *. 100.0 in
if rel >= 0.0 then
Printf.printf " Gradient %.1f%% better with %.0f* fewer evaluations\n" rel
(Float.of_int !n_grid_evals /. Float.of_int steps)
else
Printf.printf
" Gradient within %.1f%% of grid with %.0f* fewer evaluations\n"
(Float.abs rel)
(Float.of_int !n_grid_evals /. Float.of_int steps)
let () =
Printf.printf "=== Differentiable Survey Optimization ===\n";
Printf.printf "Stage IV Weak Lensing Survey\n\n";
part1 ();
part2 ()
================================================
FILE: dev/umbra/examples/README.md
================================================
# Umbra Examples
Learn Umbra through progressively more complex examples. Start with
`01-constants-and-units` and work through the numbered examples in order.
## Examples
| Example | Concept | Key Functions |
|---------|---------|---------------|
| [`01-constants-and-units`](./01-constants-and-units/) | Type-safe physical quantities, conversions, constants | `Unit.Length.of_m`, `Const.c`, `Unit.Angle.deg` |
| [`02-cosmological-distances`](./02-cosmological-distances/) | LCDM distances, SN Ia fitting | `Cosmo.luminosity_distance`, `Cosmo.distance_modulus` |
| [`03-blackbody-fitting`](./03-blackbody-fitting/) | Fit stellar temperature from photometry | `Spectrum.blackbody`, `Photometry.ab_mag` |
| [`04-extinction-and-magnitudes`](./04-extinction-and-magnitudes/) | Dust extinction, magnitude systems, K-corrections | `Extinction.ccm89`, `Photometry.vega_mag`, `Photometry.color` |
| [`05-sed-fitting`](./05-sed-fitting/) | Full SED pipeline: blackbody, extinction, photometry | `Spectrum.blackbody`, `Extinction.apply`, `Photometry.ab_mag` |
| [`06-coordinates-and-time`](./06-coordinates-and-time/) | Frame transforms, time scales, observer geometry | `Coord.galactic_of_icrs`, `Time.of_iso`, `Altaz.airmass` |
| [`07-batch-photometry`](./07-batch-photometry/) | Batched operations over temperature and extinction grids | `Spectrum.blackbody`, `Extinction.apply`, `Photometry.ab_mag` |
| [`08-photometric-redshifts`](./08-photometric-redshifts/) | Two-stage photo-z: grid search + gradient refinement | `Spectrum.redshift`, `Photometry.ab_mag`, `Rune.value_and_grad` |
| [`09-gravitational-lensing`](./09-gravitational-lensing/) | Point-mass lens model parameter fitting | `Rune.value_and_grad`, `Vega.adam` |
| [`10-uncertainty-propagation`](./10-uncertainty-propagation/) | AD Jacobians for error propagation vs Monte Carlo | `Rune.jacfwd`, `Cosmo.distance_modulus` |
| [`11-bayesian-sed`](./11-bayesian-sed/) | Fisher matrix + HMC posterior sampling | `Rune.jacrev`, `Norn.hmc` |
| [`12-survey-optimization`](./12-survey-optimization/) | Differentiable Fisher forecasting for survey design | `Survey.angular_cl`, `Cosmo.linear_power` |
## Running Examples
All examples can be run with:
```bash
cd dev/umbra
dune exec --root . examples/<example>/main.exe
```
For example:
```bash
cd dev/umbra
dune exec --root . examples/01-constants-and-units/main.exe
```
## Quick Reference
### Cosmological Distances
```ocaml
open Umbra
let cosmo = Cosmo.planck18 in
let z = Nx.scalar Nx.float64 0.5 in
let dl = Cosmo.luminosity_distance cosmo z
```
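The result is a typed length. To read it out as a plain float in megaparsecs, continuing the snippet above (a sketch, assuming the `Unit.Length.in_mpc` accessor used elsewhere in the library):

```ocaml
(* Convert the typed length to a float64 tensor in Mpc, then extract it *)
let dl_mpc = Nx.item [] (Unit.Length.in_mpc dl)
let () = Printf.printf "d_L(z = 0.5) = %.1f Mpc\n" dl_mpc
```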
### Synthetic Photometry
```ocaml
let sed =
Spectrum.blackbody
~temperature:(Unit.Temperature.of_kelvin (Nx.scalar f64 5800.0))
~wavelength:wave
|> Extinction.apply (Extinction.ccm89 ~rv) ~av
|> Spectrum.as_flux_density
in
let mag = Photometry.ab_mag (Filters.sdss_r ()) sed
```
### Coordinate Transforms
```ocaml
let ra = Unit.Angle.deg 83.633 in
let dec = Unit.Angle.deg (-5.550) in
let l, b = Coord.galactic_of_icrs ra dec
```
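### Horizon Coordinates

Umbra can also convert celestial coordinates to local horizon coordinates. This sketch mirrors the example in `lib/altaz.mli` and adds an airmass lookup; the `Nx.item [ 0 ]` indexing on the airmass tensor is an assumption (it returns one value per input coordinate):

```ocaml
open Umbra

let obs =
  Altaz.make_observer
    ~lat:(Unit.Angle.deg 28.7624)
    ~lon:(Unit.Angle.deg (-17.8792))
    ()

let t = Time.of_iso "2024-06-21T22:00:00"

(* Vega's ICRS coordinates *)
let vega =
  Coord.of_radec ~ra:(Unit.Angle.deg 279.2347) ~dec:(Unit.Angle.deg 38.7837)

let hz = Altaz.of_coord ~obstime:t ~observer:obs vega
let alt_deg = Nx.item [] (Unit.Angle.in_deg (Altaz.alt hz))
let airmass = Nx.item [ 0 ] (Altaz.airmass hz)
```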
================================================
FILE: dev/umbra/lib/altaz.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let pi = Float.pi
type observer = { lat : float; lon : float; height : float }
let make_observer ~lat ~lon ?(height = Unit.Length.m 0.0) () =
let lat = Nx.item [] (Unit.Angle.to_tensor lat) in
let lon = Nx.item [] (Unit.Angle.to_tensor lon) in
let height = Nx.item [] (Unit.Length.to_tensor height) in
{ lat; lon; height }
let observer_height obs =
Unit.Length.of_tensor (Nx.scalar Nx.float64 obs.height)
type t = { az : Nx.float64_t; alt : Nx.float64_t }
let alt t = Unit.Angle.of_tensor t.alt
let az t = Unit.Angle.of_tensor t.az
(* Earth Rotation Angle from UT1 Julian Date. ERA = 2π(0.7790572732640 +
1.00273781191135448 * Du) where Du = JD_UT1 - 2451545.0 *)
let era jd_ut1 =
let du = jd_ut1 -. 2_451_545.0 in
let theta =
2.0 *. pi *. (0.779_057_273_264_0 +. (1.002_737_811_911_354_48 *. du))
in
Float.rem theta (2.0 *. pi)
(* IAU 2006 precession angles (Capitaine et al. 2003). T = Julian centuries from
J2000.0 TT. Returns (zeta_A, z_A, theta_A) in radians. *)
let precession_angles t_cy =
let arcsec_to_rad x = x *. pi /. 648_000.0 in
let t2 = t_cy *. t_cy in
let t3 = t2 *. t_cy in
(* zeta_A = 2.5976176'' + 2306.0809506''T + 1.0109032''T² + 0.0182337''T³ *)
let zeta_a =
arcsec_to_rad
(2.597_617_6 +. (2306.080_950_6 *. t_cy) +. (1.010_903_2 *. t2)
+. (0.018_233_7 *. t3))
in
(* z_A = -2.5976176'' + 2306.0803226''T + 1.0947790''T² + 0.0182273''T³ *)
let z_a =
arcsec_to_rad
(~-.2.597_617_6 +. (2306.080_322_6 *. t_cy) +. (1.094_779_0 *. t2)
+. (0.018_227_3 *. t3))
in
(* theta_A = 2004.1917476''T - 0.4269353''T² - 0.0418251''T³ *)
let theta_a =
arcsec_to_rad
((2004.191_747_6 *. t_cy) -. (0.426_935_3 *. t2) -. (0.041_825_1 *. t3))
in
(zeta_a, z_a, theta_a)
(* Apply IAU 2006 precession matrix to ICRS (RA, Dec) → mean (RA, Dec) of date.
R = Rz(-z_A) · Ry(theta_A) · Rz(-zeta_A) *)
let precess_to_date ra dec t_cy =
let zeta_a, z_a, theta_a = precession_angles t_cy in
let sz = Float.sin zeta_a and cz = Float.cos zeta_a in
let sa = Float.sin z_a and ca = Float.cos z_a in
let st = Float.sin theta_a and ct = Float.cos theta_a in
(* Rotation matrix elements *)
let r11 = (ca *. ct *. cz) -. (sa *. sz) in
let r12 = ~-.((ca *. ct *. sz) +. (sa *. cz)) in
let r13 = ~-.(ca *. st) in
let r21 = (sa *. ct *. cz) +. (ca *. sz) in
let r22 = ~-.((sa *. ct *. sz) -. (ca *. cz)) in
let r23 = ~-.(sa *. st) in
let r31 = st *. cz in
let r32 = ~-.(st *. sz) in
let r33 = ct in
let n = Nx.numel ra in
let ra_out = Nx.zeros Nx.float64 [| n |] in
let dec_out = Nx.zeros Nx.float64 [| n |] in
for i = 0 to n - 1 do
let r = Nx.item [ i ] ra in
let d = Nx.item [ i ] dec in
let cd = Float.cos d in
let x = cd *. Float.cos r in
let y = cd *. Float.sin r in
let z = Float.sin d in
let x' = (r11 *. x) +. (r12 *. y) +. (r13 *. z) in
let y' = (r21 *. x) +. (r22 *. y) +. (r23 *. z) in
let z' = (r31 *. x) +. (r32 *. y) +. (r33 *. z) in
Nx.set_item [ i ] (Float.atan2 y' x') ra_out;
Nx.set_item [ i ] (Float.asin (Float.max ~-.1.0 (Float.min 1.0 z'))) dec_out
done;
(ra_out, dec_out)
let airmass hz =
let n = Nx.numel hz.alt in
let out = Nx.zeros Nx.float64 [| n |] in
let to_deg = 180.0 /. pi in
for i = 0 to n - 1 do
let alt_deg = Nx.item [ i ] hz.alt *. to_deg in
(* Pickering (2002): X = 1 / sin(h + 244/(165 + 47h^1.1)) where h in deg *)
let arg =
alt_deg
+. (244.0 /. (165.0 +. (47.0 *. Float.pow (Float.abs alt_deg) 1.1)))
in
let x = 1.0 /. Float.sin (arg *. pi /. 180.0) in
Nx.set_item [ i ] (Float.max 1.0 x) out
done;
out
(* Bennett (1982) atmospheric refraction for geometric altitude. R (arcmin) =
cot(h + 7.31/(h + 4.4)) where h in degrees. Returns refraction in radians.
Clamps to 0 below -1°. *)
let refraction_correction alt_rad =
let h = alt_rad *. 180.0 /. pi in
if h < -1.0 then 0.0
else
let arg = (h +. (7.31 /. (h +. 4.4))) *. pi /. 180.0 in
let r_arcmin = 1.0 /. Float.tan arg in
r_arcmin *. pi /. (180.0 *. 60.0)
let refraction hz =
let n = Nx.numel hz.alt in
let out = Nx.zeros Nx.float64 [| n |] in
for i = 0 to n - 1 do
Nx.set_item [ i ] (refraction_correction (Nx.item [ i ] hz.alt)) out
done;
Unit.Angle.of_tensor out
let of_coord ?(refraction = false) ~obstime ~observer c =
let icrs = Coord.icrs c in
let ra_rad = Unit.Angle.to_tensor (Coord.lon icrs) in
let dec_rad = Unit.Angle.to_tensor (Coord.lat icrs) in
(* Convert UTC → UT1 (ignoring DUT1 < 1s) then to TT for precession *)
let jd_utc = Time.to_jd obstime in
let jd_ut1 = jd_utc in
let jd_tt = Time.to_jd (Time.tai_to_tt (Time.utc_to_tai obstime)) in
let t_cy = (jd_tt -. 2_451_545.0) /. 36_525.0 in
(* Precess ICRS to mean RA/Dec of date *)
let ra_date, dec_date = precess_to_date ra_rad dec_rad t_cy in
(* Hour angle: HA = ERA + observer_lon - RA_date *)
let era_val = era jd_ut1 in
let n = Nx.numel ra_rad in
let alt_out = Nx.zeros Nx.float64 [| n |] in
let az_out = Nx.zeros Nx.float64 [| n |] in
let slat = Float.sin observer.lat and clat = Float.cos observer.lat in
for i = 0 to n - 1 do
let ha = era_val +. observer.lon -. Nx.item [ i ] ra_date in
let dec = Nx.item [ i ] dec_date in
let sdec = Float.sin dec and cdec = Float.cos dec in
let sha = Float.sin ha and cha = Float.cos ha in
(* alt = asin(sin(lat)sin(dec) + cos(lat)cos(dec)cos(ha)) *)
let sin_alt = (slat *. sdec) +. (clat *. cdec *. cha) in
let alt = Float.asin (Float.max ~-.1.0 (Float.min 1.0 sin_alt)) in
(* az = atan2(-cos(dec)sin(ha), cos(lat)sin(dec) -
sin(lat)cos(dec)cos(ha)) *)
let num = ~-.(cdec *. sha) in
let den = (clat *. sdec) -. (slat *. cdec *. cha) in
let az = Float.atan2 num den in
let az = if az < 0.0 then az +. (2.0 *. pi) else az in
let alt = if refraction then alt +. refraction_correction alt else alt in
Nx.set_item [ i ] alt alt_out;
Nx.set_item [ i ] az az_out
done;
{ alt = alt_out; az = az_out }
let to_coord ~obstime ~observer t =
let jd_utc = Time.to_jd obstime in
let jd_ut1 = jd_utc in
let jd_tt = Time.to_jd (Time.tai_to_tt (Time.utc_to_tai obstime)) in
let t_cy = (jd_tt -. 2_451_545.0) /. 36_525.0 in
let era_val = era jd_ut1 in
let slat = Float.sin observer.lat and clat = Float.cos observer.lat in
let zeta_a, z_a, theta_a = precession_angles t_cy in
(* Inverse precession matrix = transpose of forward *)
let sz = Float.sin zeta_a and cz = Float.cos zeta_a in
let sa = Float.sin z_a and ca = Float.cos z_a in
let st = Float.sin theta_a and ct = Float.cos theta_a in
let r11 = (ca *. ct *. cz) -. (sa *. sz) in
let r12 = ~-.((ca *. ct *. sz) +. (sa *. cz)) in
let r13 = ~-.(ca *. st) in
let r21 = (sa *. ct *. cz) +. (ca *. sz) in
let r22 = ~-.((sa *. ct *. sz) -. (ca *. cz)) in
let r23 = ~-.(sa *. st) in
let r31 = st *. cz in
let r32 = ~-.(st *. sz) in
let r33 = ct in
(* Transpose for inverse *)
let ri11 = r11 and ri12 = r21 and ri13 = r31 in
let ri21 = r12 and ri22 = r22 and ri23 = r32 in
let ri31 = r13 and ri32 = r23 and ri33 = r33 in
let n = Nx.numel t.alt in
let ra_out = Nx.zeros Nx.float64 [| n |] in
let dec_out = Nx.zeros Nx.float64 [| n |] in
for i = 0 to n - 1 do
let alt = Nx.item [ i ] t.alt in
let az = Nx.item [ i ] t.az in
let salt = Float.sin alt and calt = Float.cos alt in
let saz = Float.sin az and caz = Float.cos az in
(* (Alt, Az) → (HA, Dec) *)
let sin_dec = (slat *. salt) +. (clat *. calt *. caz) in
let dec = Float.asin (Float.max ~-.1.0 (Float.min 1.0 sin_dec)) in
let num = ~-.(calt *. saz) in
let den = (clat *. salt) -. (slat *. calt *. caz) in
let ha = Float.atan2 num den in
(* RA_date = ERA + observer_lon - HA *)
let ra_date = era_val +. observer.lon -. ha in
(* Deprecess: mean of date → ICRS *)
let cd = Float.cos dec in
let x = cd *. Float.cos ra_date in
let y = cd *. Float.sin ra_date in
let z = Float.sin dec in
let x' = (ri11 *. x) +. (ri12 *. y) +. (ri13 *. z) in
let y' = (ri21 *. x) +. (ri22 *. y) +. (ri23 *. z) in
let z' = (ri31 *. x) +. (ri32 *. y) +. (ri33 *. z) in
let ra = Float.atan2 y' x' in
let ra = if ra < 0.0 then ra +. (2.0 *. pi) else ra in
let dec = Float.asin (Float.max ~-.1.0 (Float.min 1.0 z')) in
Nx.set_item [ i ] ra ra_out;
Nx.set_item [ i ] dec dec_out
done;
Coord.of_radec
~ra:(Unit.Angle.of_tensor ra_out)
~dec:(Unit.Angle.of_tensor dec_out)
================================================
FILE: dev/umbra/lib/altaz.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Altitude-azimuth (horizontal) coordinates.
Converts celestial coordinates to local horizon coordinates for a given
observer location and time. Uses IAU 2006 precession (Capitaine et al. 2003)
and the Earth Rotation Angle.
{b Warning.} Nutation and polar motion are omitted. Atmospheric refraction
can be applied via {!refraction} or the [~refraction] parameter of
{!of_coord}. Accuracy is ~1 arcminute for dates within a few centuries of
J2000.0.
{[
let obs = Altaz.make_observer ~lat:(Unit.Angle.deg 28.7624) ~lon:(Unit.Angle.deg (-17.8792)) () in
let t = Time.of_iso "2024-06-21T22:00:00" in
let vega =
Coord.of_radec
~ra:(Unit.Angle.deg 279.2347)
~dec:(Unit.Angle.deg 38.7837)
in
let hz = Altaz.of_coord ~obstime:t ~observer:obs vega in
let alt_deg = Nx.item [] (Unit.Angle.in_deg (Altaz.alt hz))
]} *)
(** {1:observer Observer} *)
type observer
(** The type for a ground-based observer location. *)
val make_observer :
lat:Unit.angle Unit.t ->
lon:Unit.angle Unit.t ->
?height:Unit.length Unit.t ->
unit ->
observer
(** [make_observer ~lat ~lon ?height ()] is an observer at geodetic latitude
[lat], longitude [lon], and elevation [height] above the reference
ellipsoid. [lon] is positive East. [height] defaults to sea level.
[height] is stored for forward compatibility but does not yet affect
coordinate transforms. *)
val observer_height : observer -> Unit.length Unit.t
(** [observer_height obs] is the observer's elevation above the reference
ellipsoid. *)
(** {1:coords Horizontal coordinates} *)
type t
(** The type for altitude-azimuth coordinates. Azimuth is measured from North
through East. *)
val alt : t -> Unit.angle Unit.t
(** [alt t] is the altitude (elevation above the horizon). *)
val az : t -> Unit.angle Unit.t
(** [az t] is the azimuth (North = 0, East = 90 deg). *)
(** {1:derived Derived quantities} *)
val airmass : t -> Nx.float64_t
(** [airmass hz] is the airmass at the altitude of [hz] using the Pickering
(2002) formula. Well-behaved from zenith to horizon. Not differentiable
(operates on float-level altitude values). *)
(** {1:refraction Atmospheric refraction} *)
val refraction : t -> Unit.angle Unit.t
(** [refraction hz] is the atmospheric refraction correction at the geometric
altitude of [hz], using the Bennett (1982) formula. The correction is
positive (objects appear higher than their geometric position). Returns zero
for altitudes below -1°.
Not differentiable (scalar-level trigonometry). *)
(** {1:converting Converting} *)
val of_coord :
?refraction:bool ->
obstime:Time.utc Time.t ->
observer:observer ->
Coord.t ->
t
(** [of_coord ~obstime ~observer c] converts celestial coordinates [c] to
altitude-azimuth for [observer] at [obstime]. Applies IAU 2006 precession to
move from ICRS to the mean equator of date.
When [refraction] is [true], the Bennett (1982) atmospheric refraction
correction is applied to the computed altitude. [refraction] defaults to
[false].
Not differentiable (scalar-level rotation matrices). *)
val to_coord : obstime:Time.utc Time.t -> observer:observer -> t -> Coord.t
(** [to_coord ~obstime ~observer t] converts altitude-azimuth coordinates [t]
back to ICRS celestial coordinates for [observer] at [obstime]. Not
differentiable (scalar-level rotation matrices). *)
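(** Example (a sketch using the API above; [t], [obs], and [target] are
hypothetical bindings): convert a target to the horizontal frame with
refraction applied, inspect its airmass, and map back to ICRS:
{[
let hz = Altaz.of_coord ~refraction:true ~obstime:t ~observer:obs target in
let x = Altaz.airmass hz in
let icrs = Altaz.to_coord ~obstime:t ~observer:obs hz
]}
Refraction is applied only in [of_coord], so for an exact round trip leave
[refraction] at its default [false]. *)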
================================================
FILE: dev/umbra/lib/const.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Fundamental constants (CODATA 2022) *)
let c = Unit.Velocity.m_s 299_792_458.0
let m_e = Unit.Mass.kg 9.109_383_713_9e-31
let m_p = Unit.Mass.kg 1.672_621_923_69e-27
let m_n = Unit.Mass.kg 1.674_927_498_04e-27
let u = Unit.Mass.kg 1.660_539_066_60e-27
(* Astronomical constants (IAU 2015) *)
let au = Unit.Length.au 1.0
let pc = Unit.Length.pc 1.0
let solar_mass = Unit.Mass.solar_mass 1.0
let solar_radius = Unit.Length.solar_radius 1.0
let solar_luminosity = Unit.Power.solar_luminosity 1.0
let earth_mass = Unit.Mass.earth_mass 1.0
let earth_radius = Unit.Length.earth_radius 1.0
let jupiter_mass = Unit.Mass.jupiter_mass 1.0
let jupiter_radius = Unit.Length.jupiter_radius 1.0
(* Raw SI floats for compound dimensions (CODATA 2022) *)
let h_si = 6.626_070_15e-34
let hbar_si = 1.054_571_817e-34
let g_si = 6.674_30e-11
let k_b_si = 1.380_649e-23
let sigma_sb_si = 5.670_374_419e-8
let n_a = 6.022_140_76e23
let sigma_t_si = 6.652_458_705_1e-29
let b_wien_si = 2.897_771_955e-3
let alpha = 7.297_352_5643e-3
let a_0 = Unit.Length.m 5.291_772_105_44e-11
let gm_sun_si = 1.327_124_4e20
let gm_earth_si = 3.986_004e14
let gm_jup_si = 1.266_865_3e17
let l_bol0 = Unit.Power.w 3.0128e28
================================================
FILE: dev/umbra/lib/const.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Physical and astronomical constants.
Typed constants use {!Unit.t} with the appropriate phantom dimension. Raw SI
floats are provided for compound dimensions that do not map to a single
{!Unit} dimension type.
Fundamental constants follow
{{:https://physics.nist.gov/cuu/Constants/}CODATA 2022}. Astronomical
constants follow IAU 2015. *)
(** {1:fundamental Fundamental constants} *)
val c : Unit.velocity Unit.t
(** [c] is the speed of light in vacuum (299 792 458 m/s, exact). *)
(** {1:particle Particle masses} *)
val m_e : Unit.mass Unit.t
(** [m_e] is the electron mass (9.109 383 7139e-31 kg). *)
val m_p : Unit.mass Unit.t
(** [m_p] is the proton mass (1.672 621 923 69e-27 kg). *)
val m_n : Unit.mass Unit.t
(** [m_n] is the neutron mass (1.674 927 498 04e-27 kg). *)
val u : Unit.mass Unit.t
(** [u] is the atomic mass unit (1.660 539 066 60e-27 kg). *)
(** {1:astro Astronomical constants} *)
val au : Unit.length Unit.t
(** [au] is one astronomical unit. *)
val pc : Unit.length Unit.t
(** [pc] is one parsec. *)
val solar_mass : Unit.mass Unit.t
(** [solar_mass] is one solar mass. *)
val solar_radius : Unit.length Unit.t
(** [solar_radius] is one solar radius. *)
val solar_luminosity : Unit.power Unit.t
(** [solar_luminosity] is one solar luminosity. *)
val earth_mass : Unit.mass Unit.t
(** [earth_mass] is one Earth mass. *)
val earth_radius : Unit.length Unit.t
(** [earth_radius] is one Earth radius. *)
val jupiter_mass : Unit.mass Unit.t
(** [jupiter_mass] is one Jupiter mass. *)
val jupiter_radius : Unit.length Unit.t
(** [jupiter_radius] is one Jupiter radius. *)
(** {1:si Raw SI constants}
Constants with compound dimensions that do not map to a single {!Unit}
dimension type. CODATA 2022 values. *)
val h_si : float
(** [h_si] is the Planck constant (6.626 070 15e-34 J s, exact). *)
val hbar_si : float
(** [hbar_si] is the reduced Planck constant (1.054 571 817e-34 J s). *)
val g_si : float
(** [g_si] is the gravitational constant (6.674 30e-11 m{^ 3} kg{^ -1} s{^ -2}).
*)
val k_b_si : float
(** [k_b_si] is the Boltzmann constant (1.380 649e-23 J K{^ -1}, exact). *)
val sigma_sb_si : float
(** [sigma_sb_si] is the Stefan-Boltzmann constant (5.670 374 419e-8 W m{^ -2}
K{^ -4}). *)
val n_a : float
(** [n_a] is the Avogadro constant (6.022 140 76e23 mol{^ -1}, exact). *)
val sigma_t_si : float
(** [sigma_t_si] is the Thomson scattering cross-section (6.652 458 705 1e-29
m{^ 2}). *)
val b_wien_si : float
(** [b_wien_si] is the Wien displacement law constant (2.897 771 955e-3 m K). *)
val alpha : float
(** [alpha] is the fine-structure constant (7.297 352 5643e-3). *)
val a_0 : Unit.length Unit.t
(** [a_0] is the Bohr radius (5.291 772 105 44e-11 m). *)
val gm_sun_si : float
(** [gm_sun_si] is the solar mass parameter (1.327 124 4e20 m{^ 3} s{^ -2}).
More precise than [g_si * solar_mass] for orbital mechanics. *)
val gm_earth_si : float
(** [gm_earth_si] is the Earth mass parameter (3.986 004e14 m{^ 3} s{^ -2}). *)
val gm_jup_si : float
(** [gm_jup_si] is the Jupiter mass parameter (1.266 865 3e17 m{^ 3} s{^ -2}).
*)
val l_bol0 : Unit.power Unit.t
(** [l_bol0] is the IAU 2015 zero-point bolometric luminosity (3.0128e28 W). *)
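(** Example (a sketch, not part of this interface, assuming
[Unit.Length.to_tensor] yields SI metres): the Stefan-Boltzmann luminosity
L = 4 π R² σ T⁴, mixing a typed constant with a raw SI float; [t_eff] is a
hypothetical effective temperature in kelvin:
{[
let r_m = Nx.item [] (Unit.Length.to_tensor Const.solar_radius) in
let l_w =
4.0 *. Float.pi *. r_m *. r_m *. Const.sigma_sb_si *. (t_eff ** 4.0)
]} *)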
================================================
FILE: dev/umbra/lib/coord.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let pi = Float.pi
let deg_to_rad = pi /. 180.0
let two_pi = Nx.scalar Nx.float64 (2.0 *. pi)
type frame = ICRS | Galactic | Ecliptic_j2000 | Supergalactic
(* Internally stores lon/lat in radians *)
type t = { frame : frame; lon : Nx.float64_t; lat : Nx.float64_t }
(* IAU rotation matrices *)
let ra_gp = 192.85948 *. deg_to_rad
let dec_gp = 27.12825 *. deg_to_rad
let l_ncp = 122.93192 *. deg_to_rad
let icrs_to_gal =
let sd = Float.sin dec_gp and cd = Float.cos dec_gp in
let sa = Float.sin ra_gp and ca = Float.cos ra_gp in
let sl = Float.sin l_ncp and cl = Float.cos l_ncp in
[|
[|
(~-.sl *. sa) -. (cl *. ca *. sd);
(sl *. ca) -. (cl *. sa *. sd);
cl *. cd;
|];
[|
(cl *. sa) -. (sl *. ca *. sd);
(~-.cl *. ca) -. (sl *. sa *. sd);
sl *. cd;
|];
[| ca *. cd; sa *. cd; sd |];
|]
let transpose_3x3 m =
[|
[| m.(0).(0); m.(1).(0); m.(2).(0) |];
[| m.(0).(1); m.(1).(1); m.(2).(1) |];
[| m.(0).(2); m.(1).(2); m.(2).(2) |];
|]
let gal_to_icrs = transpose_3x3 icrs_to_gal
(* Fixed J2000.0 mean obliquity: 23.4392911 degrees *)
let obliquity = 23.4392911 *. deg_to_rad
let icrs_to_ecl =
let se = Float.sin obliquity and ce = Float.cos obliquity in
[| [| 1.0; 0.0; 0.0 |]; [| 0.0; ce; se |]; [| 0.0; ~-.se; ce |] |]
let ecl_to_icrs = transpose_3x3 icrs_to_ecl
(* Supergalactic: defined relative to Galactic. SGP at (l=47.37, b=6.32), SGL
origin at l=137.37 *)
let sgl_l0 = 137.37 *. deg_to_rad
let sgp_l = 47.37 *. deg_to_rad
let sgp_b = 6.32 *. deg_to_rad
let gal_to_sgal =
let sb = Float.sin sgp_b and cb = Float.cos sgp_b in
let sl = Float.sin sgp_l and cl = Float.cos sgp_l in
let sl0 = Float.sin sgl_l0 and cl0 = Float.cos sgl_l0 in
let r00 = (~-.sl0 *. sl) -. (cl0 *. cl *. sb) in
let r01 = (sl0 *. cl) -. (cl0 *. sl *. sb) in
let r02 = cl0 *. cb in
let r10 = (cl0 *. sl) -. (sl0 *. cl *. sb) in
let r11 = (~-.cl0 *. cl) -. (sl0 *. sl *. sb) in
let r12 = sl0 *. cb in
let r20 = cl *. cb in
let r21 = sl *. cb in
let r22 = sb in
[| [| r00; r01; r02 |]; [| r10; r11; r12 |]; [| r20; r21; r22 |] |]
let sgal_to_gal = transpose_3x3 gal_to_sgal
let rotate mat lon_rad lat_rad =
let cl = Nx.cos lat_rad and sl = Nx.sin lat_rad in
let ca = Nx.cos lon_rad and sa = Nx.sin lon_rad in
let x = Nx.mul cl ca and y = Nx.mul cl sa in
let x' =
Nx.add
(Nx.add (Nx.mul_s x mat.(0).(0)) (Nx.mul_s y mat.(0).(1)))
(Nx.mul_s sl mat.(0).(2))
in
let y' =
Nx.add
(Nx.add (Nx.mul_s x mat.(1).(0)) (Nx.mul_s y mat.(1).(1)))
(Nx.mul_s sl mat.(1).(2))
in
let z' =
Nx.add
(Nx.add (Nx.mul_s x mat.(2).(0)) (Nx.mul_s y mat.(2).(1)))
(Nx.mul_s sl mat.(2).(2))
in
let z_clamped = Nx.clamp ~min:(-1.0) ~max:1.0 z' in
let lat' = Nx.asin z_clamped in
let lon' = Nx.atan2 y' x' in
let mask = Nx.less_s lon' 0.0 in
let lon' = Nx.where mask (Nx.add lon' two_pi) lon' in
(lon', lat')
let ensure_1d t = if Nx.ndim t = 0 then Nx.reshape [| 1 |] t else t
let make frame ~lon ~lat =
let lon_rad = ensure_1d (Unit.Angle.to_tensor lon) in
let lat_rad = ensure_1d (Unit.Angle.to_tensor lat) in
if Nx.ndim lon_rad <> 1 || Nx.ndim lat_rad <> 1 then
invalid_arg "Coord: lon and lat must be scalar or 1-D tensors";
if Nx.numel lon_rad <> Nx.numel lat_rad then
invalid_arg "Coord: lon and lat must have the same length";
{ frame; lon = lon_rad; lat = lat_rad }
let of_radec ~ra ~dec = make ICRS ~lon:ra ~lat:dec
let of_galactic ~l ~b = make Galactic ~lon:l ~lat:b
let of_ecliptic_j2000 ~lon ~lat = make Ecliptic_j2000 ~lon ~lat
let of_supergalactic ~sgl ~sgb = make Supergalactic ~lon:sgl ~lat:sgb
let frame c = c.frame
let size c = Nx.numel c.lon
let lon c = Unit.Angle.of_tensor c.lon
let lat c = Unit.Angle.of_tensor c.lat
let to_icrs c =
match c.frame with
| ICRS -> c
| Galactic ->
let lon', lat' = rotate gal_to_icrs c.lon c.lat in
{ frame = ICRS; lon = lon'; lat = lat' }
| Ecliptic_j2000 ->
let lon', lat' = rotate ecl_to_icrs c.lon c.lat in
{ frame = ICRS; lon = lon'; lat = lat' }
| Supergalactic ->
let gal_lon, gal_lat = rotate sgal_to_gal c.lon c.lat in
let icrs_lon, icrs_lat = rotate gal_to_icrs gal_lon gal_lat in
{ frame = ICRS; lon = icrs_lon; lat = icrs_lat }
let ra c = lon (to_icrs c)
let dec c = lat (to_icrs c)
let to_frame target c =
if c.frame = target then c
else
let icrs = to_icrs c in
match target with
| ICRS -> icrs
| Galactic ->
let lon', lat' = rotate icrs_to_gal icrs.lon icrs.lat in
{ frame = Galactic; lon = lon'; lat = lat' }
| Ecliptic_j2000 ->
let lon', lat' = rotate icrs_to_ecl icrs.lon icrs.lat in
{ frame = Ecliptic_j2000; lon = lon'; lat = lat' }
| Supergalactic ->
let gal_lon, gal_lat = rotate icrs_to_gal icrs.lon icrs.lat in
let sg_lon, sg_lat = rotate gal_to_sgal gal_lon gal_lat in
{ frame = Supergalactic; lon = sg_lon; lat = sg_lat }
let icrs c = to_frame ICRS c
let galactic c = to_frame Galactic c
let ecliptic_j2000 c = to_frame Ecliptic_j2000 c
let supergalactic c = to_frame Supergalactic c
let trig_of a b =
let a = to_icrs a and b = to_icrs b in
let dlon = Nx.sub b.lon a.lon in
let cos_lat1 = Nx.cos a.lat and sin_lat1 = Nx.sin a.lat in
let cos_lat2 = Nx.cos b.lat and sin_lat2 = Nx.sin b.lat in
let cos_dlon = Nx.cos dlon and sin_dlon = Nx.sin dlon in
(dlon, cos_lat1, sin_lat1, cos_lat2, sin_lat2, cos_dlon, sin_dlon)
let separation a b =
if size a <> size b then
invalid_arg "Coord.separation: arrays must have the same length";
let _, cos_lat1, sin_lat1, cos_lat2, sin_lat2, cos_dlon, sin_dlon =
trig_of a b
in
(* Vincenty formula *)
let a1 = Nx.mul cos_lat2 sin_dlon in
let a2 =
Nx.sub (Nx.mul cos_lat1 sin_lat2)
(Nx.mul (Nx.mul sin_lat1 cos_lat2) cos_dlon)
in
let num = Nx.sqrt (Nx.add (Nx.square a1) (Nx.square a2)) in
let den =
Nx.add (Nx.mul sin_lat1 sin_lat2)
(Nx.mul (Nx.mul cos_lat1 cos_lat2) cos_dlon)
in
let sep = Nx.atan2 num den in
Unit.Angle.of_tensor (Nx.abs sep)
let position_angle a b =
if size a <> size b then
invalid_arg "Coord.position_angle: arrays must have the same length";
let _, cos_lat1, sin_lat1, cos_lat2, sin_lat2, cos_dlon, sin_dlon =
trig_of a b
in
let num = Nx.mul cos_lat2 sin_dlon in
let den =
Nx.sub (Nx.mul cos_lat1 sin_lat2)
(Nx.mul (Nx.mul sin_lat1 cos_lat2) cos_dlon)
in
let pa = Nx.atan2 num den in
let mask = Nx.less_s pa 0.0 in
Unit.Angle.of_tensor (Nx.where mask (Nx.add pa two_pi) pa)
(* --- Offset operations --- *)
let offset_by ~position_angle ~separation c =
let pa = Unit.Angle.to_tensor position_angle in
let sep = Unit.Angle.to_tensor separation in
let cos_sep = Nx.cos sep and sin_sep = Nx.sin sep in
let cos_pa = Nx.cos pa and sin_pa = Nx.sin pa in
let sin_lat = Nx.sin c.lat and cos_lat = Nx.cos c.lat in
(* lat2 = asin(sin(lat1)*cos(sep) + cos(lat1)*sin(sep)*cos(pa)) *)
let sin_lat2 =
Nx.add (Nx.mul sin_lat cos_sep) (Nx.mul (Nx.mul cos_lat sin_sep) cos_pa)
in
let lat2 = Nx.asin (Nx.clamp ~min:(-1.0) ~max:1.0 sin_lat2) in
(* lon2 = lon1 + atan2(sin(pa)*sin(sep), cos(lat1)*cos(sep) -
sin(lat1)*sin(sep)*cos(pa)) *)
let num = Nx.mul sin_pa sin_sep in
let den =
Nx.sub (Nx.mul cos_lat cos_sep) (Nx.mul (Nx.mul sin_lat sin_sep) cos_pa)
in
let dlon = Nx.atan2 num den in
let lon2 = Nx.add c.lon dlon in
let lon2 = Nx.where (Nx.less_s lon2 0.0) (Nx.add lon2 two_pi) lon2 in
let lon2 =
Nx.where (Nx.greater_equal lon2 two_pi) (Nx.sub lon2 two_pi) lon2
in
{ frame = c.frame; lon = lon2; lat = lat2 }
let spherical_offsets_to a b =
if size a <> size b then
invalid_arg "Coord.spherical_offsets_to: arrays must have the same length";
if a.frame <> b.frame then
invalid_arg
"Coord.spherical_offsets_to: coordinates must be in the same frame";
(* Δlon = (lon_b - lon_a) * cos(lat_a), Δlat = lat_b - lat_a *)
let dlon = Nx.mul (Nx.sub b.lon a.lon) (Nx.cos a.lat) in
let dlat = Nx.sub b.lat a.lat in
(Unit.Angle.of_tensor dlon, Unit.Angle.of_tensor dlat)
(* --- Catalog cross-matching --- *)
type coord = t
type result = { indices : Nx.int32_t; separations : Unit.angle Unit.t }
type within_result = {
indices_a : Nx.int32_t;
indices_b : Nx.int32_t;
separations : Unit.angle Unit.t;
}
let to_xyz c =
let icrs = to_icrs c in
let n = size c in
let xs = Array.make n 0.0 in
let ys = Array.make n 0.0 in
let zs = Array.make n 0.0 in
for i = 0 to n - 1 do
let r = Nx.item [ i ] icrs.lon in
let d = Nx.item [ i ] icrs.lat in
let cd = Float.cos d in
xs.(i) <- cd *. Float.cos r;
ys.(i) <- cd *. Float.sin r;
zs.(i) <- Float.sin d
done;
(xs, ys, zs)
let chord_to_rad chord_sq =
let chord = Float.sqrt (Float.max 0.0 chord_sq) in
let half_chord = Float.min 1.0 (chord /. 2.0) in
2.0 *. Float.asin half_chord
module Index = struct
type t = { tree : Kdtree.t }
let of_coord c =
let xs, ys, zs = to_xyz c in
let tree = Kdtree.build xs ys zs in
{ tree }
let nearest idx query =
let qx, qy, qz = to_xyz query in
let n = Array.length qx in
let indices = Nx.zeros Nx.int32 [| n |] in
let seps = Nx.zeros Nx.float64 [| n |] in
for i = 0 to n - 1 do
let j, dist_sq = Kdtree.nearest idx.tree qx.(i) qy.(i) qz.(i) in
Nx.set_item [ i ] (Int32.of_int j) indices;
Nx.set_item [ i ] (chord_to_rad dist_sq) seps
done;
{ indices; separations = Unit.Angle.of_tensor seps }
let within idx query ~max_sep =
let max_sep_rad = Nx.item [] (Unit.Angle.to_tensor max_sep) in
let half_angle = max_sep_rad /. 2.0 in
let chord = 2.0 *. Float.sin half_angle in
let max_dist_sq = chord *. chord in
let qx, qy, qz = to_xyz query in
let na = Array.length qx in
let acc = ref [] and count = ref 0 in
for i = 0 to na - 1 do
let matches = Kdtree.within idx.tree qx.(i) qy.(i) qz.(i) max_dist_sq in
List.iter
(fun (j, dist_sq) ->
acc := (i, j, chord_to_rad dist_sq) :: !acc;
incr count)
matches
done;
let n = !count in
let out_a = Nx.zeros Nx.int32 [| n |] in
let out_b = Nx.zeros Nx.int32 [| n |] in
let out_s = Nx.zeros Nx.float64 [| n |] in
let k = ref (n - 1) in
List.iter
(fun (i, j, sep) ->
let k' = !k in
Nx.set_item [ k' ] (Int32.of_int i) out_a;
Nx.set_item [ k' ] (Int32.of_int j) out_b;
Nx.set_item [ k' ] sep out_s;
decr k)
!acc;
{
indices_a = out_a;
indices_b = out_b;
separations = Unit.Angle.of_tensor out_s;
}
end
let nearest query catalog =
if size catalog = 0 then invalid_arg "Coord.nearest: catalog is empty";
let idx = Index.of_coord catalog in
Index.nearest idx query
let within a b ~max_sep =
let idx = Index.of_coord b in
Index.within idx a ~max_sep
================================================
FILE: dev/umbra/lib/coord.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Celestial coordinates with frame transforms and catalog matching.
Positions are stored as longitude/latitude pairs in 1D {!Unit.angle}
quantities and can be converted between {!ICRS}, {!Galactic},
{!Ecliptic_j2000}, and {!Supergalactic} frames via 3x3 rotation matrices.
{[
let c = Coord.of_radec ~ra:(Unit.Angle.deg ra) ~dec:(Unit.Angle.deg dec) in
let gal = Coord.galactic c
]} *)
(** {1:types Types} *)
(** The type for celestial reference frames. *)
type frame =
| ICRS (** International Celestial Reference System. *)
| Galactic (** IAU Galactic coordinates. *)
| Ecliptic_j2000 (** Ecliptic coordinates at J2000.0 epoch. *)
| Supergalactic (** Supergalactic coordinates. *)
type t
(** The type for celestial coordinates. A pair of 1D angle quantities
(longitude, latitude), tagged with a {!frame}. *)
(** {1:constructors Constructors}
All constructors accept scalar or 1-D angle quantities of equal length.
They raise [Invalid_argument] if an argument is neither scalar nor 1-D, or if
the lengths differ. *)
val of_radec : ra:Unit.angle Unit.t -> dec:Unit.angle Unit.t -> t
(** [of_radec ~ra ~dec] is a coordinate in the ICRS frame. [ra] and [dec] must
be scalar or 1-D angle quantities with matching sizes. *)
val of_galactic : l:Unit.angle Unit.t -> b:Unit.angle Unit.t -> t
(** [of_galactic ~l ~b] is a coordinate in the Galactic frame. [l] and [b] must
be scalar or 1-D angle quantities with matching sizes. *)
val of_ecliptic_j2000 : lon:Unit.angle Unit.t -> lat:Unit.angle Unit.t -> t
(** [of_ecliptic_j2000 ~lon ~lat] is a coordinate in the ecliptic frame at the
J2000.0 mean obliquity (23.4392911 degrees). [lon] and [lat] must be scalar
or 1-D angle quantities with matching sizes. *)
val of_supergalactic : sgl:Unit.angle Unit.t -> sgb:Unit.angle Unit.t -> t
(** [of_supergalactic ~sgl ~sgb] is a coordinate in the Supergalactic frame.
[sgl] and [sgb] must be scalar or 1-D angle quantities with matching sizes.
*)
(** {1:accessors Accessors} *)
val frame : t -> frame
(** [frame c] is the reference frame of [c]. *)
val size : t -> int
(** [size c] is the number of positions in [c]. *)
val lon : t -> Unit.angle Unit.t
(** [lon c] is the longitude component of [c]. *)
val lat : t -> Unit.angle Unit.t
(** [lat c] is the latitude component of [c]. *)
val ra : t -> Unit.angle Unit.t
(** [ra c] is the ICRS right ascension of [c]. Converts to ICRS first if [c] is
in another frame. *)
val dec : t -> Unit.angle Unit.t
(** [dec c] is the ICRS declination of [c]. Converts to ICRS first if [c] is in
another frame. *)
(** {1:transforms Frame transforms} *)
val to_frame : frame -> t -> t
(** [to_frame f c] is [c] converted to frame [f]. Returns [c] unchanged if [c]
is already in [f]. All conversions go through ICRS as the pivot frame. Not
differentiable (scalar-level rotation matrices). *)
val icrs : t -> t
(** [icrs c] is [to_frame ICRS c]. *)
val galactic : t -> t
(** [galactic c] is [to_frame Galactic c]. *)
val ecliptic_j2000 : t -> t
(** [ecliptic_j2000 c] is [to_frame Ecliptic_j2000 c]. *)
val supergalactic : t -> t
(** [supergalactic c] is [to_frame Supergalactic c]. *)
(** {1:separation Angular separation} *)
val separation : t -> t -> Unit.angle Unit.t
(** [separation a b] is the angular separation between corresponding positions
of [a] and [b], computed with the Vincenty formula. Both coordinates are
converted to ICRS before computation. Not differentiable (scalar-level
trigonometry).
Raises [Invalid_argument] if [a] and [b] differ in {!size}. *)
val position_angle : t -> t -> Unit.angle Unit.t
(** [position_angle a b] is the position angle from [a] to [b], measured North
through East, in \[0, 2{e pi}). Both coordinates are converted to ICRS
before computation. Not differentiable (scalar-level trigonometry).
Raises [Invalid_argument] if [a] and [b] differ in {!size}. *)
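(** Example (a sketch): angular separation and bearing between two
single-position coordinates:
{[
let a = Coord.of_radec ~ra:(Unit.Angle.deg 10.0) ~dec:(Unit.Angle.deg 20.0) in
let b = Coord.of_radec ~ra:(Unit.Angle.deg 11.0) ~dec:(Unit.Angle.deg 21.0) in
let sep = Coord.separation a b in
let pa = Coord.position_angle a b
]} *)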
(** {1:offsets Offset operations} *)
val offset_by :
position_angle:Unit.angle Unit.t -> separation:Unit.angle Unit.t -> t -> t
(** [offset_by ~position_angle ~separation c] is the coordinate obtained by
moving each position in [c] along bearing [position_angle] (North through
East) by angular distance [separation]. The result is in the same frame as
[c]. Not differentiable (scalar-level trigonometry). *)
val spherical_offsets_to : t -> t -> Unit.angle Unit.t * Unit.angle Unit.t
(** [spherical_offsets_to a b] is [(dlon, dlat)] where
[dlon = (lon_b - lon_a) * cos(lat_a)] and [dlat = lat_b - lat_a]. Both
coordinates must be in the same frame. Not differentiable (scalar-level
trigonometry).
Raises [Invalid_argument] if [a] and [b] differ in {!size} or {!frame}. *)
(** {1:matching Catalog cross-matching}
Matches positions between catalogs using a 3D kd-tree built from unit-sphere
Cartesian coordinates. All indices in results are 0-based.
{b Warning.} Cross-matching is not differentiable: it produces integer
indices and uses discrete tree search. *)
type coord = t
(** Alias for {!t}, used inside {!Index} to avoid shadowing. *)
type result = {
indices : Nx.int32_t; (** 0-based indices into the catalog. *)
separations : Unit.angle Unit.t; (** Angular distances. *)
}
(** The type for nearest-match results. For each query position, {!indices}
gives the index of the nearest catalog entry and {!separations} gives the
angular distance to it. Both have the same length as the query. *)
type within_result = {
indices_a : Nx.int32_t; (** 0-based indices into the query. *)
indices_b : Nx.int32_t; (** 0-based indices into the catalog. *)
separations : Unit.angle Unit.t; (** Angular distances. *)
}
(** The type for within-radius match results. Each entry represents one matched
pair. The three fields have equal length. *)
(** {2:index Reusable index}
Build a kd-tree once and query it many times. *)
module Index : sig
type t
(** The type for a prebuilt spatial index over a catalog. *)
val of_coord : coord -> t
(** [of_coord c] builds a kd-tree index from the positions in [c]. Coordinates
are converted to ICRS internally. *)
val nearest : t -> coord -> result
(** [nearest idx query] finds, for each position in [query], the nearest
position in the indexed catalog. *)
val within : t -> coord -> max_sep:Unit.angle Unit.t -> within_result
(** [within idx query ~max_sep] finds all pairs where a position in [query] is
within [max_sep] of a position in the indexed catalog. *)
end
val nearest : t -> t -> result
(** [nearest query catalog] finds, for each position in [query], the nearest
position in [catalog].
Raises [Invalid_argument] if [catalog] is empty. *)
val within : t -> t -> max_sep:Unit.angle Unit.t -> within_result
(** [within a b ~max_sep] finds all pairs of positions where the separation is
at most [max_sep]. Builds a kd-tree on [b]. *)
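(** Example (a sketch): cross-match a query list against a catalog, reusing
the index across calls; [catalog] and [query] are hypothetical coordinates:
{[
let idx = Coord.Index.of_coord catalog in
(* [r.indices] holds the 0-based catalog row nearest to each query row *)
let r = Coord.Index.nearest idx query in
let arcsec = Unit.Angle.deg (1.0 /. 3600.0) in
let pairs = Coord.Index.within idx query ~max_sep:arcsec
]} *)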
================================================
FILE: dev/umbra/lib/cosmo.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Cosmological distance calculations for ΛCDM, wCDM, and w0waCDM universes.
w0waCDM subsumes all models:
- flat ΛCDM: omega_k = 0, w0 = -1, wa = 0
- non-flat ΛCDM: w0 = -1, wa = 0
- wCDM: wa = 0
- w0waCDM: general case
All computations use Nx tensor ops, making them natively differentiable
through Rune's autodiff. GL quadrature is vectorized as tensor operations. *)
let f64 = Nx.float64
let c_km_s = Nx.scalar f64 299792.458
let _mpc_m = 3.085_677_581_491_367_3e22
type params = {
h0 : Nx.float64_t;
omega_m : Nx.float64_t;
omega_l : Nx.float64_t;
omega_k : Nx.float64_t;
w0 : Nx.float64_t;
wa : Nx.float64_t;
omega_b : Nx.float64_t option;
n_s : Nx.float64_t option;
sigma8 : Nx.float64_t option;
}
let err_missing name =
invalid_arg
("Cosmo: " ^ name ^ " not set (use Cosmo.set or a preset like planck18)")
(* --- Constructors --- *)
let flat_lcdm ~h0 ~omega_m =
if h0 <= 0.0 then invalid_arg "Cosmo.flat_lcdm: h0 must be positive";
if omega_m < 0.0 then
invalid_arg "Cosmo.flat_lcdm: omega_m must be non-negative";
{
h0 = Nx.scalar f64 h0;
omega_m = Nx.scalar f64 omega_m;
omega_l = Nx.scalar f64 (1.0 -. omega_m);
omega_k = Nx.scalar f64 0.0;
w0 = Nx.scalar f64 (-1.0);
wa = Nx.scalar f64 0.0;
omega_b = None;
n_s = None;
sigma8 = None;
}
let lcdm ~h0 ~omega_m ~omega_l =
if h0 <= 0.0 then invalid_arg "Cosmo.lcdm: h0 must be positive";
{
h0 = Nx.scalar f64 h0;
omega_m = Nx.scalar f64 omega_m;
omega_l = Nx.scalar f64 omega_l;
omega_k = Nx.scalar f64 (1.0 -. omega_m -. omega_l);
w0 = Nx.scalar f64 (-1.0);
wa = Nx.scalar f64 0.0;
omega_b = None;
n_s = None;
sigma8 = None;
}
let wcdm ~h0 ~omega_m ?omega_l ~w0 () =
if h0 <= 0.0 then invalid_arg "Cosmo.wcdm: h0 must be positive";
let omega_l = match omega_l with Some v -> v | None -> 1.0 -. omega_m in
{
h0 = Nx.scalar f64 h0;
omega_m = Nx.scalar f64 omega_m;
omega_l = Nx.scalar f64 omega_l;
omega_k = Nx.scalar f64 (1.0 -. omega_m -. omega_l);
w0 = Nx.scalar f64 w0;
wa = Nx.scalar f64 0.0;
omega_b = None;
n_s = None;
sigma8 = None;
}
let w0wacdm ~h0 ~omega_m ?omega_l ~w0 ~wa () =
if h0 <= 0.0 then invalid_arg "Cosmo.w0wacdm: h0 must be positive";
let omega_l = match omega_l with Some v -> v | None -> 1.0 -. omega_m in
{
h0 = Nx.scalar f64 h0;
omega_m = Nx.scalar f64 omega_m;
omega_l = Nx.scalar f64 omega_l;
omega_k = Nx.scalar f64 (1.0 -. omega_m -. omega_l);
w0 = Nx.scalar f64 w0;
wa = Nx.scalar f64 wa;
omega_b = None;
n_s = None;
sigma8 = None;
}
(* Tensor constructors for differentiable construction *)
let create_flat_lcdm ~h0 ~omega_m =
{
h0;
omega_m;
omega_l = Nx.sub (Nx.scalar f64 1.0) omega_m;
omega_k = Nx.scalar f64 0.0;
w0 = Nx.scalar f64 (-1.0);
wa = Nx.scalar f64 0.0;
omega_b = None;
n_s = None;
sigma8 = None;
}
let create_lcdm ~h0 ~omega_m ~omega_l =
{
h0;
omega_m;
omega_l;
omega_k = Nx.sub (Nx.scalar f64 1.0) (Nx.add omega_m omega_l);
w0 = Nx.scalar f64 (-1.0);
wa = Nx.scalar f64 0.0;
omega_b = None;
n_s = None;
sigma8 = None;
}
let create_wcdm ~h0 ~omega_m ?omega_l ~w0 () =
let omega_l =
match omega_l with
| Some v -> v
| None -> Nx.sub (Nx.scalar f64 1.0) omega_m
in
{
h0;
omega_m;
omega_l;
omega_k = Nx.sub (Nx.scalar f64 1.0) (Nx.add omega_m omega_l);
w0;
wa = Nx.scalar f64 0.0;
omega_b = None;
n_s = None;
sigma8 = None;
}
let create_w0wacdm ~h0 ~omega_m ?omega_l ~w0 ~wa () =
let omega_l =
match omega_l with
| Some v -> v
| None -> Nx.sub (Nx.scalar f64 1.0) omega_m
in
{
h0;
omega_m;
omega_l;
omega_k = Nx.sub (Nx.scalar f64 1.0) (Nx.add omega_m omega_l);
w0;
wa;
omega_b = None;
n_s = None;
sigma8 = None;
}
(* Accessors *)
let h0 p = p.h0
let omega_m p = p.omega_m
let omega_l p = p.omega_l
let omega_k p = p.omega_k
let w0 p = p.w0
let wa p = p.wa
let omega_b p =
match p.omega_b with Some v -> v | None -> err_missing "omega_b"
let n_s p = match p.n_s with Some v -> v | None -> err_missing "n_s"
let sigma8 p = match p.sigma8 with Some v -> v | None -> err_missing "sigma8"
let set ?omega_b ?n_s ?sigma8 p =
let omega_b =
match omega_b with Some v -> Some (Nx.scalar f64 v) | None -> p.omega_b
in
let n_s = match n_s with Some v -> Some (Nx.scalar f64 v) | None -> p.n_s in
let sigma8 =
match sigma8 with Some v -> Some (Nx.scalar f64 v) | None -> p.sigma8
in
{ p with omega_b; n_s; sigma8 }
let set_t ?h0 ?omega_m ?omega_l ?omega_b ?n_s ?sigma8 p =
let h0 = match h0 with Some v -> v | None -> p.h0 in
let omega_m = match omega_m with Some v -> v | None -> p.omega_m in
let omega_l = match omega_l with Some v -> v | None -> p.omega_l in
let omega_k = Nx.sub (Nx.scalar f64 1.0) (Nx.add omega_m omega_l) in
let omega_b = match omega_b with Some v -> Some v | None -> p.omega_b in
let n_s = match n_s with Some v -> Some v | None -> p.n_s in
let sigma8 = match sigma8 with Some v -> Some v | None -> p.sigma8 in
{ p with h0; omega_m; omega_l; omega_k; omega_b; n_s; sigma8 }
(* Presets *)
let default = flat_lcdm ~h0:70.0 ~omega_m:0.3
let planck18 =
flat_lcdm ~h0:67.66 ~omega_m:0.3111
|> set ~omega_b:0.0490 ~n_s:0.9665 ~sigma8:0.8102
let planck15 =
flat_lcdm ~h0:67.74 ~omega_m:0.3075
|> set ~omega_b:0.0486 ~n_s:0.9667 ~sigma8:0.8159
let wmap9 =
flat_lcdm ~h0:69.32 ~omega_m:0.2865
|> set ~omega_b:0.0463 ~n_s:0.9608 ~sigma8:0.820
(* --- E(z) computation ---
E(z) = H(z)/H0 = sqrt(Ω_m(1+z)³ + Ω_k(1+z)² + Ω_de(z))
where Ω_de(z) = Ω_Λ * (1+z)^(3(1+w0+wa)) * exp(-3*wa*z/(1+z)).
For ΛCDM (w0 = -1, wa = 0): Ω_de(z) = Ω_Λ (constant).
For wCDM (wa = 0): Ω_de(z) = Ω_Λ * (1+z)^(3(1+w0)). *)
let e_of p z =
let one_plus_z = Nx.add_s z 1.0 in
let cubed = Nx.mul one_plus_z (Nx.mul one_plus_z one_plus_z) in
let matter = Nx.mul p.omega_m cubed in
let curvature = Nx.mul p.omega_k (Nx.mul one_plus_z one_plus_z) in
(* Dark energy: Ω_Λ * (1+z)^(3(1+w0+wa)) * exp(-3*wa*z/(1+z)) *)
let w_eff = Nx.add_s (Nx.add p.w0 p.wa) 1.0 in
let de_power = Nx.pow one_plus_z (Nx.mul_s w_eff 3.0) in
let wa_arg = Nx.mul (Nx.mul_s p.wa (-3.0)) (Nx.div z one_plus_z) in
let de = Nx.mul p.omega_l (Nx.mul de_power (Nx.exp wa_arg)) in
Nx.sqrt (Nx.add matter (Nx.add curvature de))
(* 16-point Gauss-Legendre nodes and weights on [-1, 1] as Nx tensors *)
let gl_nodes =
Nx.create f64 [| 16 |]
[|
-0.9894009349916499;
-0.9445750230732326;
-0.8656312023878318;
-0.7554044083550030;
-0.6178762444026438;
-0.4580167776572274;
-0.2816035507792589;
-0.0950125098376374;
0.0950125098376374;
0.2816035507792589;
0.4580167776572274;
0.6178762444026438;
0.7554044083550030;
0.8656312023878318;
0.9445750230732326;
0.9894009349916499;
|]
let gl_weights =
Nx.create f64 [| 16 |]
[|
0.0271524594117541;
0.0622535239386479;
0.0951585116824928;
0.1246289712555339;
0.1495959888165767;
0.1691565193950025;
0.1826034150449236;
0.1894506104550685;
0.1894506104550685;
0.1826034150449236;
0.1691565193950025;
0.1495959888165767;
0.1246289712555339;
0.0951585116824928;
0.0622535239386479;
0.0271524594117541;
|]
(* GL quadrature in scale-factor space.
All cosmological integrals ∫₀ᶻ g(z') dz' are evaluated via the substitution a
= 1/(1+z), which maps [0, z] → [1/(1+z), 1]. This bounded range is
well-resolved by 16-point GL even at z = 1089 (CMB). Direct quadrature over
[0, z] in redshift space under-resolves the integrand at large z. *)
let gl_quad_a p z f =
let a_lo = Nx.recip (Nx.add_s z 1.0) in
let one = Nx.scalar f64 1.0 in
let half = Nx.div_s (Nx.sub one a_lo) 2.0 in
let mid = Nx.div_s (Nx.add one a_lo) 2.0 in
let a = Nx.add (Nx.mul half gl_nodes) mid in
let e_z = e_of p (Nx.sub_s (Nx.recip a) 1.0) in
Nx.mul half (Nx.sum (Nx.mul (f a e_z) gl_weights))
(* ∫₀ᶻ dz'/E(z') = ∫_{a_lo}^1 da/(a² E(a)) *)
let integrate_inv_ez p z =
gl_quad_a p z (fun a e -> Nx.recip (Nx.mul (Nx.mul a a) e))
(* ∫₀ᶻ dz'/((1+z') E(z')) = ∫_{a_lo}^1 da/(a E(a)) *)
let integrate_inv_z1_ez p z = gl_quad_a p z (fun a e -> Nx.recip (Nx.mul a e))
(* --- Derived quantities --- *)
let hubble ?(p = default) z = Nx.mul p.h0 (e_of p z)
let critical_density ?(p = default) z =
let h_z = hubble ~p z in
let h_si = Nx.div_s (Nx.mul_s h_z 1e3) _mpc_m in
Nx.div_s (Nx.mul_s (Nx.mul h_si h_si) 3.0) (8.0 *. Float.pi *. 6.674_30e-11)
(* --- Distances ---
Line-of-sight comoving distance: χ = d_H ∫₀ᶻ dz'/E(z')
Transverse comoving distance (curvature-corrected):
- Ω_k > 0 (open): d_M = d_H/√Ω_k · sinh(√Ω_k · χ/d_H)
- Ω_k = 0 (flat): d_M = χ
- Ω_k < 0 (closed): d_M = d_H/√|Ω_k| · sin(√|Ω_k| · χ/d_H) *)
let comoving_distance_mpc p z =
let d_h = Nx.div c_km_s p.h0 in
Nx.mul d_h (integrate_inv_ez p z)
let transverse_comoving_mpc p z =
let d_h = Nx.div c_km_s p.h0 in
let chi = Nx.mul d_h (integrate_inv_ez p z) in
let ok_f = Nx.item [] p.omega_k in
if Float.abs ok_f < 1e-10 then chi (* flat *)
else
let sqrt_ok = Nx.sqrt (Nx.abs p.omega_k) in
let arg = Nx.div (Nx.mul sqrt_ok chi) d_h in
if ok_f > 0.0 then Nx.div (Nx.mul d_h (Nx.sinh arg)) sqrt_ok
else Nx.div (Nx.mul d_h (Nx.sin arg)) sqrt_ok
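The open-universe branch reduces smoothly to the flat case as Ω_k → 0, since sinh(x)/x → 1; the 1e-10 threshold only guards the division. A hypothetical plain-float sketch (illustrative values, not the module's API):

```ocaml
(* d_M = d_H/√Ω_k · sinh(√Ω_k · χ/d_H); for small Ω_k this is χ(1 + O(Ω_k)) *)
let d_m ~omega_k ~d_h chi =
  if abs_float omega_k < 1e-10 then chi
  else
    let sq = sqrt (abs_float omega_k) in
    let arg = sq *. chi /. d_h in
    if omega_k > 0. then d_h /. sq *. sinh arg else d_h /. sq *. sin arg
```

With χ = 3000 Mpc, d_H ≈ 4285 Mpc, and Ω_k = 1e-4, the open-branch result differs from the flat χ by only a few parts in 10⁶.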
let comoving_distance ?(p = default) z =
Unit.Length.of_tensor (Nx.mul_s (comoving_distance_mpc p z) _mpc_m)
let luminosity_distance ?(p = default) z =
let dm_mpc = transverse_comoving_mpc p z in
Unit.Length.of_tensor (Nx.mul_s (Nx.mul (Nx.add_s z 1.0) dm_mpc) _mpc_m)
let angular_diameter_distance ?(p = default) z =
let dm_mpc = transverse_comoving_mpc p z in
Unit.Length.of_tensor (Nx.mul_s (Nx.div dm_mpc (Nx.add_s z 1.0)) _mpc_m)
let distance_modulus ?(p = default) z =
let dl_mpc = Nx.mul (Nx.add_s z 1.0) (transverse_comoving_mpc p z) in
(* mu = 5 * log10(dL_Mpc) + 25 = 5/ln10 * ln(dL_Mpc) + 25 *)
let five_over_ln10 = 5.0 /. Float.log 10.0 in
Nx.add_s (Nx.mul_s (Nx.log dl_mpc) five_over_ln10) 25.0
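The zero point of the formula is worth checking: a source at 10 pc (10⁻⁵ Mpc) has μ = 0 by definition, which is where the +25 comes from. A plain-float sketch:

```ocaml
(* mu = 5 log10(d_L / Mpc) + 25; the 25 is -5 log10(10 pc expressed in Mpc) *)
let distance_modulus_mpc dl_mpc = (5. *. log10 dl_mpc) +. 25.
```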
(* --- Angular scale --- *)
let angular_size ?(p = default) ~z phys =
let da = angular_diameter_distance ~p z in
Unit.Angle.of_tensor
(Nx.div (Unit.Length.to_tensor phys) (Unit.Length.to_tensor da))
let physical_size ?(p = default) ~z ang =
let da = angular_diameter_distance ~p z in
Unit.Length.of_tensor
(Nx.mul (Unit.Angle.to_tensor ang) (Unit.Length.to_tensor da))
(* --- Cosmic times --- *)
(* 1/H0 in seconds: (km/s/Mpc)^{-1} = Mpc/km · s *)
let _hubble_time_s p = Nx.mul_s (Nx.recip p.h0) 3.0856776e19
let lookback_time ?(p = default) z =
Unit.Time.of_tensor (Nx.mul (_hubble_time_s p) (integrate_inv_z1_ez p z))
let age ?(p = default) z =
  (* age(z) = t_H ∫₀^{1/(1+z)} da/(a E(a)). We evaluate the total integral via
     integrate_inv_z1_ez at z_max = 1000 (where a_lo ≈ 0), then subtract the
     lookback integral from 0 to z. *)
let t_h_s = _hubble_time_s p in
let total = integrate_inv_z1_ez p (Nx.scalar f64 1000.0) in
let lb = integrate_inv_z1_ez p z in
Unit.Time.of_tensor (Nx.mul t_h_s (Nx.sub total lb))
(* --- z_at_value: inverse lookup via Brent's method ---
Given a monotonic cosmological function f and a target value, find the
redshift z such that f(z) ≈ target. Not differentiable. *)
let z_at_value ?(p = default) ?(zmin = 1e-8) ?(zmax = 1000.0) ?(xtol = 1e-8) f
target =
let target_v = Nx.item [] target in
let eval z = Nx.item [] (f ~p (Nx.scalar f64 z)) -. target_v in
(* Brent's method *)
let a = ref zmin and b = ref zmax in
let fa = ref (eval !a) and fb = ref (eval !b) in
if !fa *. !fb > 0.0 then
invalid_arg "Cosmo.z_at_value: target outside [f(zmin), f(zmax)]";
if Float.abs !fa < Float.abs !fb then begin
let tmp = !a in
a := !b;
b := tmp;
let tmp = !fa in
fa := !fb;
fb := tmp
end;
let c = ref !a and fc = ref !fa in
let d = ref (!b -. !a) in
let mflag = ref true in
let max_iter = 100 in
let i = ref 0 in
while Float.abs !fb > xtol && !i < max_iter do
let s =
if Float.abs (!fa -. !fc) > 1e-30 && Float.abs (!fb -. !fc) > 1e-30 then
(* Inverse quadratic interpolation *)
let s1 = !a *. !fb *. !fc /. ((!fa -. !fb) *. (!fa -. !fc)) in
let s2 = !b *. !fa *. !fc /. ((!fb -. !fa) *. (!fb -. !fc)) in
let s3 = !c *. !fa *. !fb /. ((!fc -. !fa) *. (!fc -. !fb)) in
s1 +. s2 +. s3
else
(* Secant method *)
!b -. (!fb *. (!b -. !a) /. (!fb -. !fa))
in
let cond1 =
let lo = ((3.0 *. !a) +. !b) /. 4.0 in
not (if lo < !b then lo <= s && s <= !b else !b <= s && s <= lo)
in
let cond2 = !mflag && Float.abs (s -. !b) >= Float.abs (!b -. !c) /. 2.0 in
let cond3 =
(not !mflag) && Float.abs (s -. !b) >= Float.abs (!c -. !d) /. 2.0
in
let cond4 = !mflag && Float.abs (!b -. !c) < xtol in
let cond5 = (not !mflag) && Float.abs (!c -. !d) < xtol in
let s =
if cond1 || cond2 || cond3 || cond4 || cond5 then begin
mflag := true;
(!a +. !b) /. 2.0
end
else begin
mflag := false;
s
end
in
let fs = eval s in
d := !c;
c := !b;
fc := !fb;
if !fa *. fs < 0.0 then begin
b := s;
fb := fs
end
else begin
a := s;
fa := fs
end;
if Float.abs !fa < Float.abs !fb then begin
let tmp = !a in
a := !b;
b := tmp;
let tmp = !fa in
fa := !fb;
fb := tmp
end;
incr i
done;
Nx.scalar f64 !b
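The same inverse-lookup idea can be sketched with plain bisection, a simpler cousin of the Brent iteration above (hypothetical stdlib-only sketch; the toy function f(z) = log(1+z) stands in for a monotone cosmological quantity):

```ocaml
(* Find z in [zmin, zmax] with f z = target, assuming f is monotone and the
   target is bracketed. Bisection halves the interval until it is below xtol. *)
let bisect_root ?(zmin = 1e-8) ?(zmax = 1000.) ?(xtol = 1e-10) f target =
  let a = ref zmin and b = ref zmax in
  if (f !a -. target) *. (f !b -. target) > 0. then
    invalid_arg "bisect_root: target outside [f zmin, f zmax]";
  while !b -. !a > xtol do
    let m = 0.5 *. (!a +. !b) in
    if (f m -. target) *. (f !a -. target) <= 0. then b := m else a := m
  done;
  0.5 *. (!a +. !b)
```

Brent's method converges superlinearly where bisection is only linear, which is why the implementation above prefers it; both share the bracketing precondition.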
(* Growth factor and growth rate *)
(* E(a) from scale factor: a = 1/(1+z), so z = 1/a - 1 *)
let e_at_a p a = e_of p (Nx.sub_s (Nx.recip a) 1.0)
(* GL quadrature of f(a') from 0 to a. Transforms [-1,1] to [0,a]. *)
let gl_integrate_a p a f =
let half = Nx.div_s a 2.0 in
let a_prime = Nx.add (Nx.mul half gl_nodes) half in
let e_a = e_at_a p a_prime in
Nx.mul half (Nx.sum (Nx.mul (f a_prime e_a) gl_weights))
(* Growth integral: J(a) = ∫₀ᵃ da' / (a'³ E³(a')). As a' → 0 the integrand
   behaves like a'^(3/2)/Ω_m^(3/2) → 0, so it is well-behaved at the lower
   limit. *)
let growth_integral p a =
gl_integrate_a p a (fun a_prime e_a ->
let a3 = Nx.mul a_prime (Nx.mul a_prime a_prime) in
let e3 = Nx.mul e_a (Nx.mul e_a e_a) in
Nx.recip (Nx.mul a3 e3))
(* Unnormalized growth factor: D(a) ∝ E(a) × J(a) *)
let growth_unnorm p a = Nx.mul (e_at_a p a) (growth_integral p a)
let growth_factor ?(p = default) z =
let a = Nx.recip (Nx.add_s z 1.0) in
let d_a = growth_unnorm p a in
let d_1 = growth_unnorm p (Nx.scalar f64 1.0) in
Nx.div d_a d_1
(* Growth rate: f(a) = dlnD/dlna.
   Since D(a) ∝ E(a) J(a),
     f = dlnE/dlna + (dJ/dlna)/J = dlnE/dlna + 1/(a² E³(a) J(a))
   where
     dlnE/dlna = a/(2E²) · dE²/da
     dE²/da = -3 Ωm a⁻⁴ - 2 Ωk a⁻³ + ΩΛ exp(f_de) · (-3(1+w0+wa)/a + 3 wa) *)
let growth_rate ?(p = default) z =
let a = Nx.recip (Nx.add_s z 1.0) in
let e_a = e_at_a p a in
let e2 = Nx.mul e_a e_a in
let j_a = growth_integral p a in
(* dE²/da *)
let a2 = Nx.mul a a in
let a3 = Nx.mul a2 a in
let a4 = Nx.mul a3 a in
let dm = Nx.mul_s (Nx.div p.omega_m a4) (-3.0) in
let dk = Nx.mul_s (Nx.div p.omega_k a3) (-2.0) in
(* Dark energy contribution: need f_de(a) and f_de'(a) *)
let f_de =
Nx.add
(Nx.mul (Nx.mul_s (Nx.add_s (Nx.add p.w0 p.wa) 1.0) (-3.0)) (Nx.log a))
(Nx.mul p.wa (Nx.mul_s (Nx.sub_s a 1.0) 3.0))
in
let f_de_prime =
Nx.add
(Nx.div (Nx.mul_s (Nx.add_s (Nx.add p.w0 p.wa) 1.0) (-3.0)) a)
(Nx.mul_s p.wa 3.0)
in
let dde = Nx.mul (Nx.mul p.omega_l (Nx.exp f_de)) f_de_prime in
let de2_da = Nx.add dm (Nx.add dk dde) in
(* dlnE/dlna = a/(2E²) × dE²/da *)
let dln_e = Nx.div (Nx.mul a de2_da) (Nx.mul_s e2 2.0) in
(* 1/(a² E³ J) *)
let e3 = Nx.mul e_a e2 in
let term2 = Nx.recip (Nx.mul a2 (Nx.mul e3 j_a)) in
Nx.add dln_e term2
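In an Einstein-de Sitter universe (Ωm = 1, everything else zero) the growth factor is D(a) ∝ a and f = 1 exactly, which makes a handy closed-form check of the integral formulation. A hypothetical plain-float sketch (Simpson quadrature, no Nx/Rune):

```ocaml
(* EdS: E(a) = a^(-3/2), so J(a) = ∫₀ᵃ a'^(3/2) da' = (2/5) a^(5/2) and
   D(a) = E(a) J(a) = (2/5) a, hence f = dlnD/dlna = 1. *)
let growth_eds a =
  (* integrand 1/(a'³ E³(a')) = a'^(3/2); composite Simpson, n even *)
  let n = 1000 in
  let h = a /. float_of_int n in
  let g x = x ** 1.5 in
  let acc = ref (g 0. +. g a) in
  for i = 1 to n - 1 do
    let w = if i mod 2 = 1 then 4. else 2. in
    acc := !acc +. (w *. g (float_of_int i *. h))
  done;
  (a ** (-1.5)) *. (!acc *. h /. 3.)

(* Central log-difference approximation of f = dlnD/dlna *)
let growth_rate_eds a =
  let eps = 1e-4 in
  let lnd x = log (growth_eds x) in
  (lnd (a *. exp eps) -. lnd (a *. exp (-.eps))) /. (2. *. eps)
```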
(* Eisenstein-Hu transfer function (1998) *)
let t_cmb = 2.7255
(* Eisenstein & Hu (1998) transfer function with baryon oscillations. Scalar
   cosmological quantities are computed in float arithmetic (the transfer
   function is a fitting formula). The wavenumber k may be a tensor of
   arbitrary shape; the result has the same shape. Differentiable through
   cosmological parameters via Rune. *)
let eisenstein_hu p k =
let s = Nx.scalar f64 in
let om = p.omega_m in
let ob = omega_b p in
let h = Nx.div_s p.h0 100.0 in
let h2 = Nx.mul h h in
let w_m = Nx.mul om h2 in
let w_b = Nx.mul ob h2 in
let fb = Nx.div ob om in
let fc = Nx.sub (s 1.0) fb in
let t27sq = (t_cmb /. 2.7) ** 2.0 in
let t27_4 = t27sq *. t27sq in
(* Eq. 2,3: equality epoch *)
let z_eq = Nx.div_s (Nx.mul_s w_m 2.50e4) t27_4 in
let k_eq = Nx.div (Nx.div_s (Nx.mul_s w_m 7.46e-2) t27sq) h in
(* Eq. 4: drag epoch *)
let b1 =
Nx.mul
(Nx.pow w_m (s (-0.419)))
(Nx.add_s (Nx.mul_s (Nx.pow w_m (s 0.674)) 0.607) 1.0)
|> fun x -> Nx.mul_s x 0.313
in
let b2 = Nx.mul_s (Nx.pow w_m (s 0.223)) 0.238 in
let z_d =
Nx.mul
(Nx.div
(Nx.mul_s (Nx.pow w_m (s 0.251)) 1291.0)
(Nx.add_s (Nx.mul_s (Nx.pow w_m (s 0.828)) 0.659) 1.0))
(Nx.add_s (Nx.mul b1 (Nx.pow w_b b2)) 1.0)
in
(* Eq. 5: baryon/photon momentum ratios *)
let r_d = Nx.mul (Nx.div_s (Nx.mul_s w_b 31.5) t27_4) (Nx.div (s 1e3) z_d) in
let r_eq =
Nx.mul (Nx.div_s (Nx.mul_s w_b 31.5) t27_4) (Nx.div (s 1e3) z_eq)
in
(* Eq. 6: sound horizon *)
let sh_d =
Nx.mul
(Nx.mul
(Nx.div (s 2.0) (Nx.mul_s k_eq 3.0))
(Nx.sqrt (Nx.div (s 6.0) r_eq)))
(Nx.log
(Nx.div
(Nx.add (Nx.sqrt (Nx.add_s r_d 1.0)) (Nx.sqrt (Nx.add r_eq r_d)))
(Nx.add_s (Nx.sqrt r_eq) 1.0)))
in
(* Eq. 7: Silk damping *)
let k_silk =
Nx.div
(Nx.mul
(Nx.mul (Nx.mul_s (Nx.pow w_b (s 0.52)) 1.6) (Nx.pow w_m (s 0.73)))
(Nx.add_s (Nx.pow (Nx.mul_s w_m 10.4) (s (-0.95))) 1.0))
h
in
(* CDM transfer function (Eqs. 11, 12, 17, 18) *)
let a1 =
Nx.mul
(Nx.pow (Nx.mul_s w_m 46.9) (s 0.670))
(Nx.add_s (Nx.pow (Nx.mul_s w_m 32.1) (s (-0.532))) 1.0)
in
let a2 =
Nx.mul
(Nx.pow (Nx.mul_s w_m 12.0) (s 0.424))
(Nx.add_s (Nx.pow (Nx.mul_s w_m 45.0) (s (-0.582))) 1.0)
in
let alpha_c =
Nx.mul
(Nx.pow a1 (Nx.neg fb))
(Nx.pow a2 (Nx.neg (Nx.mul fb (Nx.mul fb fb))))
in
let b1c =
Nx.div (s 0.944) (Nx.add_s (Nx.pow (Nx.mul_s w_m 458.0) (s (-0.708))) 1.0)
in
let b2c = Nx.pow (Nx.mul_s w_m 0.395) (s (-0.0266)) in
let beta_c =
Nx.recip (Nx.add_s (Nx.mul b1c (Nx.sub (Nx.pow fc b2c) (s 1.0))) 1.0)
in
(* T_tilde: Eq. 10, 19. Operates on k tensor. alpha, beta are scalar
tensors. *)
let t_tilde k1 alpha beta =
let q = Nx.div k1 (Nx.mul_s k_eq 13.41) in
let l = Nx.log (Nx.add_s (Nx.mul q (Nx.mul_s beta 1.8)) (Float.exp 1.0)) in
let c =
Nx.add
(Nx.div (s 386.0) (Nx.add_s (Nx.mul_s (Nx.pow q (s 1.08)) 69.9) 1.0))
(Nx.div (s 14.2) alpha)
in
Nx.div l (Nx.add l (Nx.mul c (Nx.mul q q)))
in
let ksh = Nx.mul k sh_d in
(* Eq. 17, 18 *)
let f_ =
let x = Nx.div_s ksh 5.4 in
let x2 = Nx.mul x x in
Nx.recip (Nx.add_s (Nx.mul x2 x2) 1.0)
in
let tc =
Nx.add
(Nx.mul f_ (t_tilde k (s 1.0) beta_c))
(Nx.mul (Nx.sub (s 1.0) f_) (t_tilde k alpha_c beta_c))
in
(* Baryon transfer function (Eqs. 14, 19, 21) *)
let y = Nx.div (Nx.add_s z_eq 1.0) (Nx.add_s z_d 1.0) in
let x_ = Nx.sqrt (Nx.add_s y 1.0) in
let g_eh =
Nx.mul y
(Nx.add (Nx.mul_s x_ (-6.0))
(Nx.mul
(Nx.add_s (Nx.mul_s y 3.0) 2.0)
(Nx.log (Nx.div (Nx.add_s x_ 1.0) (Nx.sub_s x_ 1.0)))))
in
let alpha_b =
Nx.mul_s
(Nx.mul (Nx.mul k_eq sh_d)
(Nx.mul (Nx.pow (Nx.add_s r_d 1.0) (s (-0.75))) g_eh))
2.07
in
let beta_node = Nx.mul_s (Nx.pow w_m (s 0.435)) 8.41 in
let beta_b =
Nx.add (Nx.add_s fb 0.5)
(Nx.mul
(Nx.sub_s (Nx.mul_s fb 2.0) 3.0)
(Nx.neg
(Nx.sqrt
(Nx.add_s (Nx.mul (Nx.mul_s w_m 17.2) (Nx.mul_s w_m 17.2)) 1.0))))
in
(* Eq. 22: tilde_s per-k *)
let tilde_s =
let bns = Nx.div beta_node ksh in
let bns3 = Nx.mul bns (Nx.mul bns bns) in
Nx.div sh_d (Nx.pow (Nx.add_s bns3 1.0) (s (1.0 /. 3.0)))
in
let tb =
let term1 =
Nx.div
(t_tilde k (s 1.0) (s 1.0))
(Nx.add_s
(let x = Nx.div_s ksh 5.2 in
Nx.mul x x)
1.0)
in
let bbks = Nx.div beta_b ksh in
let bbks3 = Nx.mul bbks (Nx.mul bbks bbks) in
let term2 =
Nx.mul
(Nx.div alpha_b (Nx.add_s bbks3 1.0))
(Nx.exp (Nx.neg (Nx.pow (Nx.div k k_silk) (s 1.4))))
in
let sinc_arg = Nx.mul k tilde_s in
Nx.mul (Nx.add term1 term2) (Nx.div (Nx.sin sinc_arg) sinc_arg)
in
(* Total: fb * Tb + fc * Tc *)
Nx.add (Nx.mul tb fb) (Nx.mul tc fc)
(* Matter power spectrum *)
(* Simpson's rule integration on a uniform grid of n+1 points from a to b; n
   must be even. f is applied once to the [n+1] grid tensor and must return a
   tensor of the same shape. *)
let simps_integrate f a b n =
let h = (b -. a) /. Float.of_int n in
let xs =
Nx.create f64
[| n + 1 |]
(Array.init (n + 1) (fun i -> a +. (Float.of_int i *. h)))
in
let ys = f xs in
(* Simpson weights: 1, 4, 2, 4, 2, ..., 4, 1 *)
let w =
Array.init (n + 1) (fun i ->
if i = 0 || i = n then 1.0 else if i mod 2 = 1 then 4.0 else 2.0)
in
let weights = Nx.create f64 [| n + 1 |] w in
Nx.mul_s (Nx.sum (Nx.mul ys weights)) (h /. 3.0)
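A minimal stdlib check of the same 1, 4, 2, …, 4, 1 weighting (hypothetical plain-float sketch, scalar f instead of a tensor-valued one): ∫₀^π sin x dx = 2.

```ocaml
(* Composite Simpson on n+1 uniform points; n must be even. *)
let simps f a b n =
  let h = (b -. a) /. float_of_int n in
  let acc = ref (f a +. f b) in
  for i = 1 to n - 1 do
    let w = if i mod 2 = 1 then 4. else 2. in
    acc := !acc +. (w *. f (a +. (float_of_int i *. h)))
  done;
  !acc *. h /. 3.
```

With n = 512 the error for this smooth integrand is already below 1e-10, which is why 512 points are ample for the σ²(R) integral below.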
(* σ²(R) = 1/(2π²) ∫ k³ P_unnorm(k) W²(kR) d(ln k), where
   P_unnorm = k^n_s × T²(k) and W is the top-hat window. Integrating in
   ln(k) space, the measure d(ln k) = dk/k cancels one power of k, so this is
   equivalent to 1/(2π²) ∫ k² P W² dk. *)
let sigma_sq p r =
let ns = n_s p in
simps_integrate
(fun lnk ->
let k = Nx.exp lnk in
let x = Nx.mul_s k r in
(* Top-hat window: W(x) = 3(sin x - x cos x)/x³ *)
let x2 = Nx.mul x x in
let x3 = Nx.mul x2 x in
let w =
Nx.div (Nx.mul_s (Nx.sub (Nx.sin x) (Nx.mul x (Nx.cos x))) 3.0) x3
in
let t = eisenstein_hu p k in
let pk = Nx.mul (Nx.pow k ns) (Nx.mul t t) in
let k3 = Nx.mul k (Nx.mul k k) in
Nx.mul k3 (Nx.mul (Nx.mul w w) pk))
(Float.log 1e-4) (Float.log 1e4) 512
|> fun integral -> Nx.div_s integral (2.0 *. Float.pi *. Float.pi)
let linear_power ?(p = default) k z =
let s8 = sigma8 p in
let g = growth_factor ~p z in
let t = eisenstein_hu p k in
let ns = n_s p in
let pk_unnorm = Nx.mul (Nx.pow k ns) (Nx.mul t t) in
(* Normalization: A = σ8² / σ²_unnorm(R=8) *)
let s2 = sigma_sq p 8.0 in
let norm = Nx.div (Nx.mul s8 s8) s2 in
Nx.mul norm (Nx.mul pk_unnorm (Nx.mul g g))
(* Halofit (Takahashi et al. 2012) *)
(* Ω_m(a) = Ω_m a⁻³ / E²(a) *)
let omega_m_a p a =
let e2 =
let e = e_at_a p a in
Nx.mul e e
in
let a3 = Nx.mul a (Nx.mul a a) in
Nx.div (Nx.div p.omega_m a3) e2
(* Ω_de(a) = Ω_Λ exp(f_de(a)) / E²(a) *)
let omega_de_a p a =
let e2 =
let e = e_at_a p a in
Nx.mul e e
in
let f_de =
Nx.add
(Nx.mul (Nx.mul_s (Nx.add_s (Nx.add p.w0 p.wa) 1.0) (-3.0)) (Nx.log a))
(Nx.mul p.wa (Nx.mul_s (Nx.sub_s a 1.0) 3.0))
in
Nx.div (Nx.mul p.omega_l (Nx.exp f_de)) e2
(* w(a) = w0 + wa(1-a) *)
let w_of p a = Nx.add p.w0 (Nx.mul p.wa (Nx.sub (Nx.scalar f64 1.0) a))
(* σ²(R, z) using the linear P(k) at z = 0, scaled by D²(z) (growth_factor is
   normalized so D(0) = 1). For Halofit we need σ(R) at various R to find
   k_nl, plus derivatives. *)
let sigma_sq_at_z p r z =
let g = growth_factor ~p z in
Nx.mul (sigma_sq p r) (Nx.mul g g)
(* Find k_nl where σ(1/k_nl, z) = 1, plus n_eff and C at the nonlinear scale. We
compute σ²(R) on a grid, interpolate to find R_nl, then compute spectral
index and curvature from Gaussian-filtered integrals. *)
let halofit_params p z =
let g = growth_factor ~p z in
let g2 = Nx.mul g g in
let ns = n_s p in
let s8 = sigma8 p in
let s2_8 = sigma_sq p 8.0 in
let pknorm = Nx.div (Nx.mul s8 s8) s2_8 in
let n_r = 256 in
let logr =
Nx.create f64 [| n_r |]
(Array.init n_r (fun i ->
Float.log 1e-4
+. Float.of_int i
*. (Float.log 1e1 -. Float.log 1e-4)
/. Float.of_int (n_r - 1)))
in
(* Compute σ²(R) for each R using Gaussian filter exp(-(kR)²) *)
let n_k = 512 in
let lnk_min = Float.log 1e-4 in
let lnk_max = Float.log 1e4 in
let dlnk = (lnk_max -. lnk_min) /. Float.of_int (n_k - 1) in
let lnk =
Nx.create f64 [| n_k |]
(Array.init n_k (fun i -> lnk_min +. (Float.of_int i *. dlnk)))
in
let k = Nx.exp lnk in
let t = eisenstein_hu p k in
let pk_base = Nx.mul pknorm (Nx.mul (Nx.pow k ns) (Nx.mul t t)) in
let pk_at_z = Nx.mul pk_base g2 in
(* k³ P(k) / (2π²) *)
let k3pk =
Nx.div
(Nx.mul (Nx.mul k (Nx.mul k k)) pk_at_z)
(Nx.scalar f64 (2.0 *. Float.pi *. Float.pi))
in
(* Trapezoidal weights [n_k] *)
let trap_w =
Nx.create f64 [| n_k |]
(Array.init n_k (fun j -> if j = 0 || j = n_k - 1 then 0.5 else 1.0))
in
(* Float-level σ²(R) grid for root-finding *)
let sigma2_arr = Array.make n_r 0.0 in
for i = 0 to n_r - 1 do
let r = Float.exp (Nx.item [ i ] logr) in
let kr = Nx.mul_s k r in
let y2 = Nx.mul kr kr in
let gauss = Nx.exp (Nx.neg y2) in
let integrand = Nx.mul k3pk gauss in
sigma2_arr.(i) <-
Nx.item [] (Nx.mul_s (Nx.sum (Nx.mul trap_w integrand)) dlnk)
done;
(* Find R_nl where σ² = 1 by linear interpolation in log space *)
let r_nl = ref (Float.exp (Nx.item [ 0 ] logr)) in
(let found = ref false in
for i = 0 to n_r - 2 do
if (not !found) && sigma2_arr.(i) >= 1.0 && sigma2_arr.(i + 1) <= 1.0 then begin
let ls0 = Float.log sigma2_arr.(i) in
let ls1 = Float.log sigma2_arr.(i + 1) in
let lr0 = Nx.item [ i ] logr in
let lr1 = Nx.item [ i + 1 ] logr in
let frac = (0.0 -. ls0) /. (ls1 -. ls0) in
r_nl := Float.exp (lr0 +. (frac *. (lr1 -. lr0)));
found := true
end
done);
let r_nl_f = !r_nl in
(* Differentiable Newton refinement for R_nl. Compute σ² at the float root,
then one Newton step: R' = R + R*(σ²-1)/dn where dσ²/dR = -dn/R.
Numerically R' ≈ R, but the gradient dR'/dp = -(∂σ²/∂p)/(∂σ²/∂R) is exact
via the implicit function theorem. *)
let kr0 = Nx.mul_s k r_nl_f in
let y2_0 = Nx.mul kr0 kr0 in
let gauss0 = Nx.exp (Nx.neg y2_0) in
let integrand0 = Nx.mul k3pk gauss0 in
let trap_sum f = Nx.mul_s (Nx.sum (Nx.mul trap_w f)) dlnk in
let s2_0 = trap_sum integrand0 in
let dn_0 = trap_sum (Nx.mul_s (Nx.mul integrand0 y2_0) 2.0) in
let r_nl_t =
Nx.add_s (Nx.mul_s (Nx.div (Nx.sub_s s2_0 1.0) dn_0) r_nl_f) r_nl_f
in
let k_nl = Nx.recip r_nl_t in
(* Recompute n_eff and C at the tensor R_nl for full differentiability. *)
let kr = Nx.mul k r_nl_t in
let y2 = Nx.mul kr kr in
let gauss = Nx.exp (Nx.neg y2) in
let integrand = Nx.mul k3pk gauss in
let s2 = trap_sum integrand in
let dn = trap_sum (Nx.mul_s (Nx.mul integrand y2) 2.0) in
let dc =
trap_sum (Nx.mul (Nx.mul_s integrand 4.0) (Nx.sub y2 (Nx.mul y2 y2)))
in
let n_eff = Nx.sub_s dn 3.0 in
let c_curv = Nx.add (Nx.mul dn dn) (Nx.div dc s2) in
(k_nl, n_eff, c_curv)
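The refinement trick above — cheap float-level root, then a single Newton step through the differentiable graph — can be illustrated on the toy equation g(x) = x² − 2 (hypothetical plain-float sketch): a root known only to ~10⁻² jumps to ~10⁻⁵ accuracy after one step, and under AD the step also carries the implicit-function-theorem gradient even though the coarse root does not.

```ocaml
(* One Newton step x' = x - g(x)/g'(x) for g(x) = x² - 2 *)
let newton_step x = x -. (((x *. x) -. 2.) /. (2. *. x))

(* Coarse root, as a grid search might produce it *)
let refined = newton_step 1.41
```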
let nonlinear_power ?(p = default) k z =
let s = Nx.scalar f64 in
let pk_lin = linear_power ~p k z in
let k_nl, n, c = halofit_params p z in
let n2 = Nx.mul n n in
let n3 = Nx.mul n2 n in
let n4 = Nx.mul n3 n in
let a = Nx.recip (Nx.add_s z 1.0) in
let om_m = omega_m_a p a in
let om_de = omega_de_a p a in
let w = w_of p a in
let odew1 = Nx.mul om_de (Nx.add_s w 1.0) in
(* Takahashi et al. 2012 coefficients — all tensor *)
let a_n =
Nx.pow (s 10.0)
(Nx.add
(Nx.add
(Nx.add
(Nx.add
(Nx.add
(Nx.add_s (Nx.mul_s n 2.8553) 1.5222)
(Nx.mul_s n2 2.3706))
(Nx.mul_s n3 0.9903))
(Nx.mul_s n4 0.2250))
(Nx.mul_s c (-0.6038)))
(Nx.mul_s odew1 0.1749))
in
let b_n =
Nx.pow (s 10.0)
(Nx.add
(Nx.add
(Nx.add
(Nx.add_s (Nx.mul_s n 0.5864) (-0.5642))
(Nx.mul_s n2 0.5716))
(Nx.mul_s c (-1.5474)))
(Nx.mul_s odew1 0.2279))
in
let c_n =
Nx.pow (s 10.0)
(Nx.add
(Nx.add (Nx.add_s (Nx.mul_s n 2.0404) 0.3698) (Nx.mul_s n2 0.8161))
(Nx.mul_s c 0.5869))
in
let gamma_n =
Nx.add (Nx.add_s (Nx.mul_s n (-0.0843)) 0.1971) (Nx.mul_s c 0.8460)
in
let alpha_n =
Nx.abs
(Nx.add
(Nx.add (Nx.add_s (Nx.mul_s n 1.3373) 6.0835) (Nx.mul_s n2 (-0.1959)))
(Nx.mul_s c (-5.5274)))
in
let beta_n =
Nx.add
(Nx.add
(Nx.add
(Nx.add
(Nx.add_s (Nx.mul_s n (-0.7354)) 2.0379)
(Nx.mul_s n2 0.3157))
(Nx.mul_s n3 1.2490))
(Nx.mul_s n4 0.3980))
(Nx.mul_s c (-0.1682))
in
let nu_n = Nx.pow (s 10.0) (Nx.add_s (Nx.mul_s n 3.6902) 5.2105) in
let f1 = Nx.pow om_m (s (-0.0307)) in
let f2 = Nx.pow om_m (s (-0.0585)) in
let f3 = Nx.pow om_m (s 0.0743) in
let y = Nx.div k k_nl in
(* Δ²_L = k³ P_lin / (2π²) *)
let d2l =
Nx.div
(Nx.mul (Nx.mul k (Nx.mul k k)) pk_lin)
(s (2.0 *. Float.pi *. Float.pi))
in
(* f(y) = y/4 + y²/8 *)
let fy = Nx.add (Nx.div_s y 4.0) (Nx.div_s (Nx.mul y y) 8.0) in
(* Quasi-linear term: Δ²_Q *)
let d2q =
Nx.mul d2l
(Nx.mul
(Nx.div
(Nx.pow (Nx.add_s d2l 1.0) beta_n)
(Nx.add_s (Nx.mul d2l alpha_n) 1.0))
(Nx.exp (Nx.neg fy)))
in
(* Halo term: Δ²_H *)
let three_f1 = Nx.mul_s f1 3.0 in
let d2h_prime =
Nx.div
(Nx.mul a_n (Nx.pow y three_f1))
(Nx.add_s
(Nx.add
(Nx.mul b_n (Nx.pow y f2))
(Nx.pow (Nx.mul (Nx.mul c_n f3) y) (Nx.sub_s gamma_n 3.0)))
1.0)
in
let d2h = Nx.div d2h_prime (Nx.add_s (Nx.div nu_n (Nx.mul y y)) 1.0) in
let d2nl = Nx.add d2q d2h in
Nx.div (Nx.mul_s d2nl (2.0 *. Float.pi *. Float.pi)) (Nx.mul k (Nx.mul k k))
(* BAO distance measures *)
let dh ?(p = default) z =
Unit.Length.of_tensor (Nx.mul_s (Nx.div c_km_s (hubble ~p z)) _mpc_m)
let dm ?(p = default) z =
Unit.Length.of_tensor (Nx.mul_s (transverse_comoving_mpc p z) _mpc_m)
let dv ?(p = default) z =
let dh_mpc = Nx.div c_km_s (hubble ~p z) in
let dm_mpc = transverse_comoving_mpc p z in
let cube = Nx.mul z (Nx.mul dh_mpc (Nx.mul dm_mpc dm_mpc)) in
Unit.Length.of_tensor (Nx.mul_s (Nx.pow_s cube (1.0 /. 3.0)) _mpc_m)
let sound_horizon ?(p = default) () =
let ob = omega_b p in
let h = Nx.div_s p.h0 100.0 in
let h2 = Nx.mul h h in
let w_m = Nx.mul p.omega_m h2 in
let w_b = Nx.mul ob h2 in
(* Eisenstein & Hu (1998) Eq. 2–6: sound horizon at drag epoch in Mpc/h *)
let t27sq = (t_cmb /. 2.7) ** 2.0 in
let t27_4 = t27sq *. t27sq in
let z_eq = Nx.div_s (Nx.mul_s w_m 2.50e4) t27_4 in
let k_eq = Nx.div (Nx.div_s (Nx.mul_s w_m 7.46e-2) t27sq) h in
let b1_z =
Nx.mul
(Nx.pow w_m (Nx.scalar f64 (-0.419)))
(Nx.add_s (Nx.mul_s (Nx.pow w_m (Nx.scalar f64 0.674)) 0.607) 1.0)
|> fun x -> Nx.mul_s x 0.313
in
let b2_z = Nx.mul_s (Nx.pow w_m (Nx.scalar f64 0.223)) 0.238 in
let z_d =
Nx.mul
(Nx.div
(Nx.mul_s (Nx.pow w_m (Nx.scalar f64 0.251)) 1291.0)
(Nx.add_s (Nx.mul_s (Nx.pow w_m (Nx.scalar f64 0.828)) 0.659) 1.0))
(Nx.add_s (Nx.mul b1_z (Nx.pow w_b b2_z)) 1.0)
in
let r_d =
Nx.mul (Nx.div_s (Nx.mul_s w_b 31.5) t27_4) (Nx.div (Nx.scalar f64 1e3) z_d)
in
let r_eq =
Nx.mul
(Nx.div_s (Nx.mul_s w_b 31.5) t27_4)
(Nx.div (Nx.scalar f64 1e3) z_eq)
in
(* Eq. 6 from Eisenstein & Hu: sound horizon in Mpc/h *)
let sh_d =
Nx.mul
(Nx.mul
(Nx.div (Nx.scalar f64 2.0) (Nx.mul_s k_eq 3.0))
(Nx.sqrt (Nx.div (Nx.scalar f64 6.0) r_eq)))
(Nx.log
(Nx.div
(Nx.add (Nx.sqrt (Nx.add_s r_d 1.0)) (Nx.sqrt (Nx.add r_eq r_d)))
(Nx.add_s (Nx.sqrt r_eq) 1.0)))
in
(* sh_d is in Mpc/h, convert to Mpc then to metres *)
let rs_mpc = Nx.div sh_d h in
Unit.Length.of_tensor (Nx.mul_s rs_mpc _mpc_m)
================================================
FILE: dev/umbra/lib/cosmo.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Cosmology for {e Λ}CDM, wCDM, and w0waCDM universes.
    Computes distances, growth factors, and matter power spectra. Flat and
    non-flat variants of each model are expressed through a single parameter
    type. All functions are differentiable through Rune.
{[
let z = Nx.scalar Nx.float64 0.5 in
let dl = Cosmo.luminosity_distance z in
let dl_mpc = Unit.Length.in_mpc dl
]}
Power spectrum functions require [omega_b], [n_s], and [sigma8] to be set
via {!set} or by using a preset like {!planck18}. *)
(** {1:params Parameters} *)
type params
(** The type for cosmological parameters. Subsumes flat {e Λ}CDM, non-flat
{e Λ}CDM, wCDM, and w0waCDM. *)
(** {2:float_constructors Float constructors}
Create parameters from plain floats. *)
val flat_lcdm : h0:float -> omega_m:float -> params
(** [flat_lcdm ~h0 ~omega_m] is flat {e Λ}CDM with {e Ω}{_ L}[ = 1 - omega_m].
Raises [Invalid_argument] if [h0 <= 0] or [omega_m < 0]. *)
val lcdm : h0:float -> omega_m:float -> omega_l:float -> params
(** [lcdm ~h0 ~omega_m ~omega_l] is {e Λ}CDM with curvature
{e Ω}{_ k}[ = 1 - omega_m - omega_l].
Raises [Invalid_argument] if [h0 <= 0]. *)
val wcdm :
h0:float -> omega_m:float -> ?omega_l:float -> w0:float -> unit -> params
(** [wcdm ~h0 ~omega_m ~w0 ()] is wCDM with constant dark energy equation of
state [w0]. [omega_l] defaults to [1 - omega_m] (flat). *)
val w0wacdm :
h0:float ->
omega_m:float ->
?omega_l:float ->
w0:float ->
wa:float ->
unit ->
params
(** [w0wacdm ~h0 ~omega_m ~w0 ~wa ()] is the CPL parameterization
[w(z) = w0 + wa * z/(1+z)]. [omega_l] defaults to [1 - omega_m] (flat). *)
(** {2:tensor_constructors Tensor constructors}
Create parameters from Nx scalar tensors for differentiable construction. *)
val create_flat_lcdm : h0:Nx.float64_t -> omega_m:Nx.float64_t -> params
val create_lcdm :
h0:Nx.float64_t -> omega_m:Nx.float64_t -> omega_l:Nx.float64_t -> params
val create_wcdm :
h0:Nx.float64_t ->
omega_m:Nx.float64_t ->
?omega_l:Nx.float64_t ->
w0:Nx.float64_t ->
unit ->
params
val create_w0wacdm :
h0:Nx.float64_t ->
omega_m:Nx.float64_t ->
?omega_l:Nx.float64_t ->
w0:Nx.float64_t ->
wa:Nx.float64_t ->
unit ->
params
(** {2:accessors Accessors} *)
val h0 : params -> Nx.float64_t
(** [h0 p] is the Hubble constant H{_ 0} in km s{^ -1} Mpc{^ -1}. *)
val omega_m : params -> Nx.float64_t
(** [omega_m p] is the matter density parameter {e Ω}{_ m}. *)
val omega_l : params -> Nx.float64_t
(** [omega_l p] is the dark energy density parameter {e Ω}{_ Λ}. *)
val omega_k : params -> Nx.float64_t
(** [omega_k p] is the curvature density parameter {e Ω}{_ k}[ = 1 - Ω_m - Ω_Λ].
*)
val w0 : params -> Nx.float64_t
(** [w0 p] is the dark energy equation of state parameter w{_ 0}. *)
val wa : params -> Nx.float64_t
(** [wa p] is the CPL time-varying dark energy parameter w{_ a}. *)
val omega_b : params -> Nx.float64_t
(** [omega_b p] is the baryon density parameter {e Ω}{_ b}.
Raises [Invalid_argument] if not set. *)
val n_s : params -> Nx.float64_t
(** [n_s p] is the primordial spectral index n{_ s}.
Raises [Invalid_argument] if not set. *)
val sigma8 : params -> Nx.float64_t
(** [sigma8 p] is the amplitude of matter fluctuations {e σ}{_ 8}.
Raises [Invalid_argument] if not set. *)
(** {2:set Setting power spectrum parameters} *)
val set : ?omega_b:float -> ?n_s:float -> ?sigma8:float -> params -> params
(** [set ~omega_b ~n_s ~sigma8 p] is [p] with the given power spectrum
parameters set. Unspecified parameters retain their previous value. *)
val set_t :
?h0:Nx.float64_t ->
?omega_m:Nx.float64_t ->
?omega_l:Nx.float64_t ->
?omega_b:Nx.float64_t ->
?n_s:Nx.float64_t ->
?sigma8:Nx.float64_t ->
params ->
params
(** [set_t] is like {!set} but takes Nx scalar tensors for differentiable
construction. Recomputes {e Ω}{_ k} when [omega_m] or [omega_l] changes. *)
(** {2:presets Presets} *)
val default : params
(** [default] is flat {e Λ}CDM with [h0 = 70], [omega_m = 0.3]. *)
val planck18 : params
(** [planck18] is Planck 2018 flat {e Λ}CDM: [h0 = 67.66], [omega_m = 0.3111],
[omega_b = 0.0490], [n_s = 0.9665], [sigma8 = 0.8102]. *)
val planck15 : params
(** [planck15] is Planck 2015 flat {e Λ}CDM: [h0 = 67.74], [omega_m = 0.3075],
[omega_b = 0.0486], [n_s = 0.9667], [sigma8 = 0.8159]. *)
val wmap9 : params
(** [wmap9] is WMAP9 flat {e Λ}CDM: [h0 = 69.32], [omega_m = 0.2865],
[omega_b = 0.0463], [n_s = 0.9608], [sigma8 = 0.820]. *)
(** {1:e_z Hubble parameter} *)
val e_of : params -> Nx.float64_t -> Nx.float64_t
(** [e_of p z] is E(z) = H(z)/H{_ 0} at redshift [z]. Fully differentiable
through Rune. *)
val hubble : ?p:params -> Nx.float64_t -> Nx.float64_t
(** [hubble z] is H(z) in km s{^ -1} Mpc{^ -1}. [p] defaults to {!default}. *)
val critical_density : ?p:params -> Nx.float64_t -> Nx.float64_t
(** [critical_density z] is the critical density {e rho}{_ c}(z) in kg m{^ -3}.
[p] defaults to {!default}. *)
(** {1:distances Distances} *)
val comoving_distance : ?p:params -> Nx.float64_t -> Unit.length Unit.t
(** [comoving_distance z] is the line-of-sight comoving distance at redshift
[z]. [p] defaults to {!default}. *)
val luminosity_distance : ?p:params -> Nx.float64_t -> Unit.length Unit.t
(** [luminosity_distance z] is the luminosity distance at redshift [z]. For
non-flat models, applies the curvature correction via the transverse
comoving distance. [p] defaults to {!default}. *)
val angular_diameter_distance : ?p:params -> Nx.float64_t -> Unit.length Unit.t
(** [angular_diameter_distance z] is the angular diameter distance at redshift
[z]. [p] defaults to {!default}. *)
val distance_modulus : ?p:params -> Nx.float64_t -> Nx.float64_t
(** [distance_modulus z] is the distance modulus
{e mu}[ = 5 log10(d_L / Mpc) + 25]. [p] defaults to {!default}. *)
(** {1:angular Angular scale} *)
val angular_size :
?p:params -> z:Nx.float64_t -> Unit.length Unit.t -> Unit.angle Unit.t
(** [angular_size ~z length] is the angular size of [length] at redshift [z]
under the small-angle approximation [{e theta} = l / d_A]. [p] defaults to
{!default}. *)
val physical_size :
?p:params -> z:Nx.float64_t -> Unit.angle Unit.t -> Unit.length Unit.t
(** [physical_size ~z angle] is the physical size subtended by [angle] at
redshift [z] under the small-angle approximation [l = {e theta} * d_A]. [p]
defaults to {!default}. *)
(** {1:times Cosmic times} *)
val lookback_time : ?p:params -> Nx.float64_t -> Unit.time Unit.t
(** [lookback_time z] is the lookback time to redshift [z]. [p] defaults to
{!default}. *)
val age : ?p:params -> Nx.float64_t -> Unit.time Unit.t
(** [age z] is the age of the universe at redshift [z].
Integrates from [z] to [z = 1000]. This approximation is accurate to ~0.1%
for late-time cosmology ([z < 10]) but omits the radiation era and is not
suitable for CMB-epoch calculations. [p] defaults to {!default}. *)
(** {1:inverse Inverse lookup} *)
val z_at_value :
?p:params ->
?zmin:float ->
?zmax:float ->
?xtol:float ->
(p:params -> Nx.float64_t -> Nx.float64_t) ->
Nx.float64_t ->
Nx.float64_t
(** [z_at_value f target] finds the redshift [z] where [f ~p z = target] using
Brent's method. [f] must be a monotonic function of redshift.
For distance functions, unwrap the unit first:
{[
z_at_value
(fun ~p z -> Unit.Length.in_mpc (Cosmo.comoving_distance ~p z))
target
]}
[zmin] defaults to [1e-8]. [zmax] defaults to [1000.0]. [xtol] defaults to
[1e-8].
{b Warning.} Not differentiable (iterative root-finding).
Raises [Invalid_argument] if [target] is outside [[f(zmin), f(zmax)]]. *)
(** {1:bao BAO distance measures} *)
val dh : ?p:params -> Nx.float64_t -> Unit.length Unit.t
(** [dh z] is the Hubble distance D{_ H}(z) = c / H(z). [p] defaults to
{!default}. *)
val dm : ?p:params -> Nx.float64_t -> Unit.length Unit.t
(** [dm z] is the comoving transverse distance D{_ M}(z). Equal to
{!comoving_distance} for flat cosmologies; includes curvature correction
otherwise. [p] defaults to {!default}. *)
val dv : ?p:params -> Nx.float64_t -> Unit.length Unit.t
(** [dv z] is the volume-averaged BAO distance D{_ V}(z) = (z D{_ H}(z)
D{_ M}{^ 2}(z)){^ 1/3}. [p] defaults to {!default}. *)
val sound_horizon : ?p:params -> unit -> Unit.length Unit.t
(** [sound_horizon ()] is the comoving sound horizon at the drag epoch
r{_ s}(z{_ drag}), using the Eisenstein & Hu (1998) fitting formulae for
z{_ drag} and the sound horizon integral.
Raises [Invalid_argument] if [omega_b] is not set in [p]. [p] defaults to
{!default}. *)
(** {1:growth Structure growth} *)
val growth_factor : ?p:params -> Nx.float64_t -> Nx.float64_t
(** [growth_factor z] is the linear growth factor D(z), normalized to D(0) = 1.
Computed via the integral form D(a) {e ∝} E(a) {e ∫}{_ 0}{^ a} da' /
(a'{^ 3} E{^ 3}(a')).
Does not require [omega_b], [n_s], or [sigma8]. [p] defaults to {!default}.
*)
val growth_rate : ?p:params -> Nx.float64_t -> Nx.float64_t
(** [growth_rate z] is the linear growth rate f(z) = d ln D / d ln a, computed
from the exact derivative of the integral-form growth factor.
[p] defaults to {!default}. *)
(** {1:power Matter power spectrum}
All power spectrum functions require [omega_b], [n_s], and [sigma8] to be
set in the parameters. Use {!set} or a preset like {!planck18}.
Wavenumbers [k] are in h/Mpc. Power spectra are in (Mpc/h){^ 3}. *)
val linear_power : ?p:params -> Nx.float64_t -> Nx.float64_t -> Nx.float64_t
(** [linear_power ~p k z] is the linear matter power spectrum P(k, z). Uses the
Eisenstein & Hu (1998) transfer function with baryon oscillations and
{e σ}{_ 8} normalization.
Raises [Invalid_argument] if [omega_b], [n_s], or [sigma8] are not set. *)
val nonlinear_power : ?p:params -> Nx.float64_t -> Nx.float64_t -> Nx.float64_t
(** [nonlinear_power ~p k z] is the nonlinear matter power spectrum via the
Halofit fitting formula (Takahashi et al. 2012).
{b Warning.} The nonlinear scale k{_ nl} is found by float-level
root-finding; gradients do not flow through it. The mapping from k{_ nl} to
P{_ nl}(k) is differentiable.
Raises [Invalid_argument] if [omega_b], [n_s], or [sigma8] are not set. *)
================================================
FILE: dev/umbra/lib/dune
================================================
(library
(name umbra)
(public_name umbra)
(private_modules kdtree filter_data vega_data)
(libraries nx unix))
================================================
FILE: dev/umbra/lib/extinction.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let f64 = Nx.float64
(* Extinction law: wavelength in metres → A_λ/A_V *)
type law = wavelength:Nx.float64_t -> Nx.float64_t
(* Horner evaluation: c0 + y*(c1 + y*(c2 + ...)) *)
let horner y coeffs =
let n = Array.length coeffs in
let acc = ref (Nx.scalar f64 coeffs.(n - 1)) in
for i = n - 2 downto 0 do
acc := Nx.add_s (Nx.mul y !acc) coeffs.(i)
done;
!acc
(* Shared CCM89/O'Donnell94 implementation parameterized by R_V. Only the
optical a/b polynomial coefficients differ between the two laws; IR and UV
regions are identical. Uses Nx.where for differentiable piecewise selection.
Valid for 0.125–3.5 μm (x = 0.3–8.0 μm⁻¹). *)
let ccm89_impl a_opt b_opt ~rv ~wavelength =
(* Convert wavelength (m) to inverse microns *)
let x = Nx.div (Nx.scalar f64 1e-6) wavelength in
(* Infrared: 0.3 ≤ x < 1.1 *)
let a_ir = Nx.mul_s (Nx.pow_s x 1.61) 0.574 in
let b_ir = Nx.mul_s (Nx.pow_s x 1.61) (-0.527) in
(* Optical/NIR: 1.1 ≤ x ≤ 3.3, polynomial in (x - 1.82) *)
let y = Nx.sub_s x 1.82 in
let a_o = horner y a_opt in
let b_o = horner y b_opt in
(* UV: 3.3 < x ≤ 8.0 *)
let fa =
Nx.where (Nx.greater_equal_s x 5.9)
(Nx.add
(Nx.mul_s (Nx.square (Nx.sub_s x 5.9)) (-0.04473))
(Nx.mul_s (Nx.pow_s (Nx.sub_s x 5.9) 3.0) (-0.009779)))
(Nx.scalar f64 0.0)
in
let fb =
Nx.where (Nx.greater_equal_s x 5.9)
(Nx.add
(Nx.mul_s (Nx.square (Nx.sub_s x 5.9)) 0.2130)
(Nx.mul_s (Nx.pow_s (Nx.sub_s x 5.9) 3.0) 0.1207))
(Nx.scalar f64 0.0)
in
(* a(x) = 1.752 - 0.316*x - 0.104/((x-4.67)² + 0.341) + F_a *)
let a_uv_base = Nx.add_s (Nx.mul_s x (-0.316)) 1.752 in
let bump_a =
Nx.div (Nx.scalar f64 (-0.104))
(Nx.add (Nx.square (Nx.sub_s x 4.67)) (Nx.scalar f64 0.341))
in
let a_uv = Nx.add (Nx.add a_uv_base bump_a) fa in
(* b(x) = -3.090 + 1.825*x + 1.206/((x-4.62)² + 0.263) + F_b *)
let b_uv_base = Nx.add_s (Nx.mul_s x 1.825) (-3.090) in
let bump_b =
Nx.div (Nx.scalar f64 1.206)
(Nx.add (Nx.square (Nx.sub_s x 4.62)) (Nx.scalar f64 0.263))
in
let b_uv = Nx.add (Nx.add b_uv_base bump_b) fb in
(* Piecewise selection using Nx.where *)
let ir_mask = Nx.less_s x 1.1 in
let uv_mask = Nx.greater_s x 3.3 in
let a = Nx.where ir_mask a_ir (Nx.where uv_mask a_uv a_o) in
let b = Nx.where ir_mask b_ir (Nx.where uv_mask b_uv b_o) in
(* A_λ/A_V = a(x) + b(x)/R_V *)
Nx.add a (Nx.div b rv)
(* CCM89: Cardelli, Clayton & Mathis 1989, ApJ 345, 245 — optical
coefficients *)
let ccm89_a =
[| 1.0; 0.17699; -0.50447; -0.02427; 0.72085; 0.01979; -0.77530; 0.32999 |]
let ccm89_b =
[| 0.0; 1.41338; 2.28305; 1.07233; -5.38434; -0.62251; 5.30260; -2.09002 |]
let ccm89 ~rv = fun ~wavelength -> ccm89_impl ccm89_a ccm89_b ~rv ~wavelength
(* O'Donnell 1994, ApJ 422, 158 — revised optical coefficients *)
let od94_a = [| 1.0; 0.104; -0.609; 0.701; -1.221; 0.700; -0.048; -0.091 |]
let od94_b = [| 0.0; 1.952; 2.908; -3.989; 7.985; -5.002; -0.478; 1.149 |]
let odonnell94 ~rv = fun ~wavelength -> ccm89_impl od94_a od94_b ~rv ~wavelength
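A plain-float cross-check of the optical branch shared by both laws: at x = 1.82 μm⁻¹ (roughly the V band) the polynomial variable y vanishes, so Horner returns the constant terms a = 1 and b = 0, giving A_λ/A_V = 1 for any R_V, which is the normalisation both laws are defined with.

```ocaml
(* Plain-float restatement of [horner] and the CCM89 optical coefficients
   above, checking the V-band normalisation A_V/A_V = 1. *)
let horner_f y coeffs =
  Array.fold_right (fun c acc -> c +. (y *. acc)) coeffs 0.0

let a_coeffs =
  [| 1.0; 0.17699; -0.50447; -0.02427; 0.72085; 0.01979; -0.77530; 0.32999 |]

let b_coeffs =
  [| 0.0; 1.41338; 2.28305; 1.07233; -5.38434; -0.62251; 5.30260; -2.09002 |]

let a_over_av ~rv x =
  let y = x -. 1.82 in
  horner_f y a_coeffs +. (horner_f y b_coeffs /. rv)

(* y = 0 at x = 1.82, so only the constant terms survive *)
let () = assert (abs_float (a_over_av ~rv:3.1 1.82 -. 1.0) < 1e-12)
```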
(* Calzetti 2000: Calzetti et al. 2000, ApJ 533, 682. Starburst attenuation law.
Fixed R_V = 4.05. Valid 0.12–2.2 μm. *)
let calzetti00 =
fun ~wavelength ->
let lam_um = Nx.mul_s wavelength 1e6 in
let rv = 4.05 in
(* Blue (0.12–0.63 μm):
   k'(λ) = 2.659 * (-2.156 + 1.509/λ - 0.198/λ² + 0.011/λ³) + R_V *)
let k_blue =
Nx.add_s
(Nx.mul_s
(Nx.add_s
(Nx.add
(Nx.mul_s (Nx.recip lam_um) 1.509)
(Nx.add
(Nx.mul_s (Nx.pow_s lam_um (-2.0)) (-0.198))
(Nx.mul_s (Nx.pow_s lam_um (-3.0)) 0.011)))
(-2.156))
2.659)
rv
in
(* Red (0.63–2.2 μm): k'(λ) = 2.659 * (-1.857 + 1.040/λ) + R_V *)
let k_red =
Nx.add_s
(Nx.mul_s (Nx.add_s (Nx.mul_s (Nx.recip lam_um) 1.040) (-1.857)) 2.659)
rv
in
let blue_mask = Nx.less_s lam_um 0.63 in
let k = Nx.where blue_mask k_blue k_red in
(* A_λ/A_V = k'(λ) / R_V *)
Nx.div_s k rv
(* Fitzpatrick 1999: Fitzpatrick 1999, PASP 111, 63. R_V-dependent extinction
using cubic spline for optical/NIR and Fitzpatrick & Massa parameterization
for UV. Valid 0.1–3.5 μm. *)
(* FM UV parameters (fixed) *)
let f99_x0_sq = 4.596 *. 4.596
let f99_gamma_sq = 0.99 *. 0.99
let f99_c3 = 3.23
let f99_c4 = 0.41
let f99_c5 = 5.9
(* Spline anchor x-values (inverse microns) *)
let f99_xk =
[|
0.;
1e4 /. 26500.;
1e4 /. 12200.;
1e4 /. 6000.;
1e4 /. 5470.;
1e4 /. 4670.;
1e4 /. 4110.;
1e4 /. 2700.;
1e4 /. 2600.;
|]
let f99_hk = Array.init 8 (fun i -> f99_xk.(i + 1) -. f99_xk.(i))
(* Drude profile at a fixed x-value *)
let f99_drude x =
let x2 = x *. x in
let y = x2 -. f99_x0_sq in
x2 /. ((y *. y) +. (x2 *. f99_gamma_sq))
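A quick plain-float check of the Drude profile above: at x = x₀ the quadratic term vanishes and the profile peaks at exactly 1/γ².

```ocaml
(* Same constants and profile as [f99_drude] above, restated for checking *)
let x0_sq = 4.596 *. 4.596
let gamma_sq = 0.99 *. 0.99

let drude x =
  let x2 = x *. x in
  let y = x2 -. x0_sq in
  x2 /. ((y *. y) +. (x2 *. gamma_sq))

(* At the bump centre x0 = 4.596, y = 0 and drude x0 = 1/gamma^2 ≈ 1.0203 *)
let () = assert (abs_float (drude 4.596 -. (1.0 /. gamma_sq)) < 1e-12)
```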
(* Precompute spline basis matrix M (7×9): maps 9 anchor y-values to 7 interior
second derivatives. Natural boundary conditions: m[0] = m[8] = 0.
The tridiagonal system Am = Dy is solved offline; M = A⁻¹D is stored. At
runtime m[j] = Σ M[j][i] y[i] — a weighted sum of Nx scalars, fully
differentiable through Rune. *)
let f99_basis =
let n = 7 in
let h = f99_hk in
(* Right-hand side matrix D (7×9) *)
let d_mat =
Array.init n (fun j ->
Array.init 9 (fun i ->
if i = j then 6.0 /. h.(j)
else if i = j + 1 then ~-.((6.0 /. h.(j + 1)) +. (6.0 /. h.(j)))
else if i = j + 2 then 6.0 /. h.(j + 1)
else 0.0))
in
(* Tridiagonal A: diag, sub, sup *)
let diag = Array.init n (fun j -> 2.0 *. (h.(j) +. h.(j + 1))) in
let sub j = h.(j) in
let sup j = h.(j + 1) in
(* Solve A X_col = D_col for each of 9 columns via Thomas algorithm *)
let m = Array.init n (fun _ -> Array.make 9 0.0) in
for col = 0 to 8 do
let b = Array.init n (fun j -> d_mat.(j).(col)) in
let c = Array.make n 0.0 in
let d = Array.make n 0.0 in
c.(0) <- sup 0 /. diag.(0);
d.(0) <- b.(0) /. diag.(0);
for i = 1 to n - 1 do
let w = diag.(i) -. (sub i *. c.(i - 1)) in
c.(i) <- (if i < n - 1 then sup i /. w else 0.0);
d.(i) <- (b.(i) -. (sub i *. d.(i - 1))) /. w
done;
m.(n - 1).(col) <- d.(n - 1);
for i = n - 2 downto 0 do
m.(i).(col) <- d.(i) -. (c.(i) *. m.(i + 1).(col))
done
done;
m
(* Evaluate a cubic spline piece on [xk, xk1] at tensor x. mk and mk1 are second
derivatives (Nx scalars); yk, yk1 are y-values. *)
let f99_eval_piece hk xk yk yk1 mk mk1 x =
let a = yk in
let c = Nx.mul_s mk 0.5 in
let d = Nx.div_s (Nx.sub mk1 mk) (6.0 *. hk) in
let b =
Nx.sub
(Nx.div_s (Nx.sub yk1 yk) hk)
(Nx.mul_s (Nx.add (Nx.mul_s mk 2.0) mk1) (hk /. 6.0))
in
let t = Nx.sub_s x xk in
Nx.add a (Nx.mul t (Nx.add b (Nx.mul t (Nx.add c (Nx.mul t d)))))
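The coefficients in [f99_eval_piece] are the standard natural-cubic-spline form, so each piece interpolates its two anchors regardless of the second derivatives; a plain-float restatement makes that easy to verify:

```ocaml
(* Plain-float version of the piece evaluation above *)
let eval_piece hk xk yk yk1 mk mk1 x =
  let a = yk in
  let c = mk /. 2.0 in
  let d = (mk1 -. mk) /. (6.0 *. hk) in
  let b = ((yk1 -. yk) /. hk) -. ((2.0 *. mk +. mk1) *. hk /. 6.0) in
  let t = x -. xk in
  a +. (t *. (b +. (t *. (c +. (t *. d)))))

(* Endpoint interpolation holds for arbitrary second derivatives mk, mk1 *)
let () =
  assert (abs_float (eval_piece 2.0 1.0 3.0 7.0 0.4 (-0.9) 1.0 -. 3.0) < 1e-12);
  assert (abs_float (eval_piece 2.0 1.0 3.0 7.0 0.4 (-0.9) 3.0 -. 7.0) < 1e-12)
```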
let fitzpatrick99 ~rv =
let rv2 = Nx.mul rv rv in
let rv3 = Nx.mul rv2 rv in
let rv4 = Nx.mul rv2 rv2 in
(* FM UV c1, c2 — computed once, used for anchor y-values and the closure *)
let c2_uv = Nx.add_s (Nx.mul_s (Nx.recip rv) 4.717) (-0.824) in
let c1_uv = Nx.sub (Nx.scalar f64 2.030) (Nx.mul_s c2_uv 3.007) in
let uv_anchor xk =
Nx.add c1_uv (Nx.add_s (Nx.mul_s c2_uv xk) (f99_c3 *. f99_drude xk))
in
(* 9 anchor E(λ-V)/E(B-V) values *)
let y =
[|
Nx.neg rv;
Nx.sub (Nx.mul_s rv (0.26469 /. 3.1)) rv;
Nx.sub (Nx.mul_s rv (0.82925 /. 3.1)) rv;
Nx.sub
(Nx.add
(Nx.add_s (Nx.mul_s rv 1.00270) (-0.422809))
(Nx.mul_s rv2 2.13572e-04))
rv;
Nx.sub
(Nx.add
(Nx.add_s (Nx.mul_s rv 1.00216) (-5.13540e-02))
(Nx.mul_s rv2 (-7.35778e-05)))
rv;
Nx.sub
(Nx.add
(Nx.add_s (Nx.mul_s rv 1.00184) 0.700127)
(Nx.mul_s rv2 (-3.32598e-05)))
rv;
Nx.sub
(Nx.add
(Nx.add
(Nx.add
(Nx.add_s (Nx.mul_s rv 1.01707) 1.19456)
(Nx.mul_s rv2 (-5.46959e-03)))
(Nx.mul_s rv3 7.97809e-04))
(Nx.mul_s rv4 (-4.45636e-05)))
rv;
(* UV anchors from FM parameterization *)
uv_anchor f99_xk.(7);
uv_anchor f99_xk.(8);
|]
in
(* Second derivatives m[0..8]: m[0] = m[8] = 0, m[1..7] from basis matrix *)
let zero = Nx.scalar f64 0.0 in
let m2 = Array.make 9 zero in
for j = 0 to 6 do
let acc = ref zero in
for i = 0 to 8 do
acc := Nx.add !acc (Nx.mul_s y.(i) f99_basis.(j).(i))
done;
m2.(j + 1) <- !acc
done;
(* Precompute spline piece coefficients for intervals 0..6 *)
let h = f99_hk in
let pieces =
Array.init 7 (fun k ->
let hk = h.(k) in
let yk = y.(k) in
let yk1 = y.(k + 1) in
let mk = m2.(k) in
let mk1 = m2.(k + 1) in
(hk, f99_xk.(k), yk, yk1, mk, mk1))
in
fun ~wavelength ->
(* Convert wavelength (m) to inverse microns *)
let x = Nx.div (Nx.scalar f64 1e-6) wavelength in
(* Evaluate spline for each interval *)
let eval k =
let hk, xk, yk, yk1, mk, mk1 = pieces.(k) in
f99_eval_piece hk xk yk yk1 mk mk1 x
in
let s0 = eval 0 in
let s1 = eval 1 in
let s2 = eval 2 in
let s3 = eval 3 in
let s4 = eval 4 in
let s5 = eval 5 in
let s6 = eval 6 in
let opt_nir =
Nx.where
(Nx.less_s x f99_xk.(1))
s0
(Nx.where
(Nx.less_s x f99_xk.(2))
s1
(Nx.where
(Nx.less_s x f99_xk.(3))
s2
(Nx.where
(Nx.less_s x f99_xk.(4))
s3
(Nx.where
(Nx.less_s x f99_xk.(5))
s4
(Nx.where (Nx.less_s x f99_xk.(6)) s5 s6)))))
in
(* UV: FM parameterization for x ≥ 1e4/2700 *)
let x2 = Nx.square x in
let y_bump = Nx.sub x2 (Nx.scalar f64 f99_x0_sq) in
let drude =
Nx.div x2 (Nx.add (Nx.mul y_bump y_bump) (Nx.mul_s x2 f99_gamma_sq))
in
let fuv =
Nx.where
(Nx.greater_equal_s x f99_c5)
(let dx = Nx.sub_s x f99_c5 in
let dx2 = Nx.square dx in
Nx.add (Nx.mul_s dx2 0.5392) (Nx.mul_s (Nx.mul dx2 dx) 0.05644))
(Nx.scalar f64 0.0)
in
let k_uv =
Nx.add c1_uv
(Nx.add (Nx.mul c2_uv x)
(Nx.add (Nx.mul_s drude f99_c3) (Nx.mul_s fuv f99_c4)))
in
(* Select optical/NIR vs UV *)
let e_over_ebv = Nx.where (Nx.less_s x f99_xk.(7)) opt_nir k_uv in
(* A(λ)/A(V) = E(λ-V)/E(B-V) / R_V + 1 *)
Nx.add_s (Nx.div e_over_ebv rv) 1.0
let curve law ~wavelength = law ~wavelength:(Unit.Length.to_tensor wavelength)
let ln10_over_2_5 = Float.log 10.0 *. 0.4
let scale_flux sign law ~av spectrum =
let wave_m = Unit.Length.to_tensor (Spectrum.wavelength spectrum) in
let a_lambda = Nx.mul (law ~wavelength:wave_m) av in
let factor = Nx.exp (Nx.mul_s a_lambda (sign *. ln10_over_2_5)) in
Spectrum.scale factor spectrum
let apply law ~av spectrum = scale_flux (-1.0) law ~av spectrum
let unredden law ~av spectrum = scale_flux 1.0 law ~av spectrum
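The factor computed in [scale_flux] is the usual magnitude-to-flux conversion, exp(∓A_λ · ln 10 / 2.5) = 10^(∓0.4 A_λ); a plain-float check that 2.5 magnitudes of extinction dims flux by exactly a factor of ten:

```ocaml
(* Same constant as above: ln(10)/2.5 = 0.4 * ln(10) *)
let ln10_over_2_5 = Float.log 10.0 *. 0.4

(* Flux attenuation for a_mag magnitudes of extinction *)
let dim a_mag = exp (-. a_mag *. ln10_over_2_5)

let () = assert (abs_float (dim 2.5 -. 0.1) < 1e-12)
```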
================================================
FILE: dev/umbra/lib/extinction.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Dust extinction laws.
Extinction laws describe how interstellar dust attenuates and reddens light
as a function of wavelength. A {!law} maps wavelength to the normalised
extinction curve A{_ lambda} / A{_ V}.
{!apply} and {!unredden} are differentiable through Rune with respect to
[av]. The extinction curve evaluation itself (law constructors and {!curve})
is not differentiable (scalar-level polynomial and spline evaluation). *)
(** {1:types Types} *)
type law
(** The type for extinction laws. *)
(** {1:laws Standard laws} *)
val ccm89 : rv:Nx.float64_t -> law
(** [ccm89 ~rv] is the
{{:https://ui.adsabs.harvard.edu/abs/1989ApJ...345..245C}Cardelli, Clayton &
Mathis (1989)} Milky Way extinction law. [rv] is the total-to-selective
extinction ratio R{_ V} (typically 3.1).
Valid for 0.125--3.5 {e mu}m (0.3--8.0 {e mu}m{^ -1}). Values outside this
range are extrapolations. *)
val fitzpatrick99 : rv:Nx.float64_t -> law
(** [fitzpatrick99 ~rv] is the
{{:https://ui.adsabs.harvard.edu/abs/1999PASP..111...63F}Fitzpatrick (1999)}
R{_ V}-dependent Milky Way extinction law. Uses a cubic spline for
optical/NIR and the Fitzpatrick & Massa UV parameterization.
Valid for 0.1--3.5 {e mu}m (0.3--10.0 {e mu}m{^ -1}). *)
val odonnell94 : rv:Nx.float64_t -> law
(** [odonnell94 ~rv] is the
{{:https://ui.adsabs.harvard.edu/abs/1994ApJ...422..158O}O'Donnell (1994)}
Milky Way extinction law. Identical to {!ccm89} except for revised optical
coefficients (1.1--3.3 {e mu}m{^ -1}).
Valid for 0.125--3.5 {e mu}m. *)
val calzetti00 : law
(** [calzetti00] is the
{{:https://ui.adsabs.harvard.edu/abs/2000ApJ...533..682C}Calzetti et al.
(2000)} starburst attenuation law with fixed R{_ V} = 4.05.
Valid for 0.12--2.2 {e mu}m. Values outside this range are extrapolations.
*)
(** {1:evaluation Evaluation} *)
val curve : law -> wavelength:Unit.length Unit.t -> Nx.float64_t
(** [curve law ~wavelength] is A{_ lambda} / A{_ V} at the given wavelengths.
Not differentiable. *)
(** {1:application Application} *)
val apply : law -> av:Nx.float64_t -> 'a Spectrum.t -> 'a Spectrum.t
(** [apply law ~av spectrum] reddens [spectrum] by applying [av] magnitudes of
V-band extinction. The spectral kind is preserved. Differentiable through
Rune with respect to [av] and the spectrum values. *)
val unredden : law -> av:Nx.float64_t -> 'a Spectrum.t -> 'a Spectrum.t
(** [unredden law ~av spectrum] de-reddens [spectrum] by removing [av]
magnitudes of V-band extinction. The spectral kind is preserved.
Differentiable through Rune with respect to [av] and the spectrum values. *)
================================================
FILE: dev/umbra/lib/filter_data.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
[@@@ocamlformat "disable"]
(* Filter transmission curves from the SVO Filter Profile Service.
http://svo2.cab.inta-csic.es/theory/fps/
Wavelengths in Angstroms, throughput dimensionless. *)
(* SDSS u *)
let sdss_u_wave =
[|
2980.0; 3005.0; 3030.0; 3055.0; 3080.0; 3105.0;
3130.0; 3155.0; 3180.0; 3205.0; 3230.0; 3255.0;
3280.0; 3305.0; 3330.0; 3355.0; 3380.0; 3405.0;
3430.0; 3455.0; 3480.0; 3505.0; 3530.0; 3555.0;
3580.0; 3605.0; 3630.0; 3655.0; 3680.0; 3705.0;
3730.0; 3755.0; 3780.0; 3805.0; 3830.0; 3855.0;
3880.0; 3905.0; 3930.0; 3955.0; 3980.0; 4005.0;
4030.0; 4055.0; 4080.0; 4105.0; 4130.0
|]
let sdss_u_thru =
[|
0.0; 0.0001; 0.0005; 0.0013; 0.0026; 0.0052;
0.0093; 0.0161; 0.024; 0.0323; 0.0405; 0.0485;
0.0561; 0.0634; 0.07; 0.0756; 0.0803; 0.0848;
0.0883; 0.0917; 0.0959; 0.1001; 0.1029; 0.1044;
0.1053; 0.1063; 0.1075; 0.1085; 0.1084; 0.1064;
0.1024; 0.0966; 0.0887; 0.0787; 0.0672; 0.0549;
0.0413; 0.0268; 0.0145; 0.0075; 0.0042; 0.0022;
0.001; 0.0006; 0.0004; 0.0002; 0.0
|]
(* SDSS g *)
let sdss_g_wave =
[|
3630.0; 3655.0; 3680.0; 3705.0; 3730.0; 3755.0;
3780.0; 3805.0; 3830.0; 3855.0; 3880.0; 3905.0;
3930.0; 3955.0; 3980.0; 4005.0; 4030.0; 4055.0;
4080.0; 4105.0; 4130.0; 4155.0; 4180.0; 4205.0;
4230.0; 4255.0; 4280.0; 4305.0; 4330.0; 4355.0;
4380.0; 4405.0; 4430.0; 4455.0; 4480.0; 4505.0;
4530.0; 4555.0; 4580.0; 4605.0; 4630.0; 4655.0;
4680.0; 4705.0; 4730.0; 4755.0; 4780.0; 4805.0;
4830.0; 4855.0; 4880.0; 4905.0; 4930.0; 4955.0;
4980.0; 5005.0; 5030.0; 5055.0; 5080.0; 5105.0;
5130.0; 5155.0; 5180.0; 5205.0; 5230.0; 5255.0;
5280.0; 5305.0; 5330.0; 5355.0; 5380.0; 5405.0;
5430.0; 5455.0; 5480.0; 5505.0; 5530.0; 5555.0;
5580.0; 5605.0; 5630.0; 5655.0; 5680.0; 5705.0;
5730.0; 5755.0; 5780.0; 5805.0; 5830.0
|]
let sdss_g_thru =
[|
0.0; 0.0003; 0.0008; 0.0013; 0.0019; 0.0024;
0.0034; 0.0055; 0.0103; 0.0194; 0.0326; 0.0492;
0.0686; 0.09; 0.1123; 0.1342; 0.1545; 0.1722;
0.1873; 0.2003; 0.2116; 0.2214; 0.2301; 0.2378;
0.2448; 0.2513; 0.2574; 0.2633; 0.2691; 0.2747;
0.2801; 0.2852; 0.2899; 0.294; 0.2979; 0.3016;
0.3055; 0.3097; 0.3141; 0.3184; 0.3224; 0.3257;
0.3284; 0.3307; 0.3327; 0.3346; 0.3364; 0.3383;
0.3403; 0.3425; 0.3448; 0.3472; 0.3495; 0.3519;
0.3541; 0.3562; 0.3581; 0.3597; 0.3609; 0.3613;
0.3609; 0.3595; 0.3581; 0.3558; 0.3452; 0.3194;
0.2807; 0.2339; 0.1839; 0.1352; 0.0911; 0.0548;
0.0295; 0.0166; 0.0112; 0.0077; 0.005; 0.0032;
0.0021; 0.0015; 0.0012; 0.001; 0.0009; 0.0008;
0.0006; 0.0005; 0.0003; 0.0001; 0.0
|]
(* SDSS r *)
let sdss_r_wave =
[|
5380.0; 5405.0; 5430.0; 5455.0; 5480.0; 5505.0;
5530.0; 5555.0; 5580.0; 5605.0; 5630.0; 5655.0;
5680.0; 5705.0; 5730.0; 5755.0; 5780.0; 5805.0;
5830.0; 5855.0; 5880.0; 5905.0; 5930.0; 5955.0;
5980.0; 6005.0; 6030.0; 6055.0; 6080.0; 6105.0;
6130.0; 6155.0; 6180.0; 6205.0; 6230.0; 6255.0;
6280.0; 6305.0; 6330.0; 6355.0; 6380.0; 6405.0;
6430.0; 6455.0; 6480.0; 6505.0; 6530.0; 6555.0;
6580.0; 6605.0; 6630.0; 6655.0; 6680.0; 6705.0;
6730.0; 6755.0; 6780.0; 6805.0; 6830.0; 6855.0;
6880.0; 6905.0; 6930.0; 6955.0; 6980.0; 7005.0;
7030.0; 7055.0; 7080.0; 7105.0; 7130.0; 7155.0;
7180.0; 7205.0; 7230.0
|]
let sdss_r_thru =
[|
0.0; 0.0014; 0.0099; 0.0259; 0.0497; 0.0807;
0.1186; 0.1625; 0.2093; 0.2555; 0.2975; 0.3326;
0.3609; 0.3834; 0.401; 0.4147; 0.4253; 0.4333;
0.4395; 0.4446; 0.4489; 0.4527; 0.4563; 0.4599;
0.4634; 0.4665; 0.4689; 0.4703; 0.4711; 0.4717;
0.4727; 0.4744; 0.4767; 0.4792; 0.4819; 0.4844;
0.4867; 0.4887; 0.4902; 0.4909; 0.4912; 0.4912;
0.4912; 0.4914; 0.4915; 0.4912; 0.4901; 0.4878;
0.4852; 0.4818; 0.4697; 0.4421; 0.4009; 0.3499;
0.2924; 0.2318; 0.1715; 0.1152; 0.0687; 0.038;
0.0212; 0.0134; 0.0099; 0.0076; 0.0055; 0.0039;
0.0027; 0.002; 0.0015; 0.0012; 0.001; 0.0007;
0.0004; 0.0002; 0.0
|]
(* SDSS i *)
let sdss_i_wave =
[|
6430.0; 6455.0; 6480.0; 6505.0; 6530.0; 6555.0;
6580.0; 6605.0; 6630.0; 6655.0; 6680.0; 6705.0;
6730.0; 6755.0; 6780.0; 6805.0; 6830.0; 6855.0;
6880.0; 6905.0; 6930.0; 6955.0; 6980.0; 7005.0;
7030.0; 7055.0; 7080.0; 7105.0; 7130.0; 7155.0;
7180.0; 7205.0; 7230.0; 7255.0; 7280.0; 7305.0;
7330.0; 7355.0; 7380.0; 7405.0; 7430.0; 7455.0;
7480.0; 7505.0; 7530.0; 7555.0; 7580.0; 7605.0;
7630.0; 7655.0; 7680.0; 7705.0; 7730.0; 7755.0;
7780.0; 7805.0; 7830.0; 7855.0; 7880.0; 7905.0;
7930.0; 7955.0; 7980.0; 8005.0; 8030.0; 8055.0;
8080.0; 8105.0; 8130.0; 8155.0; 8180.0; 8205.0;
8230.0; 8255.0; 8280.0; 8305.0; 8330.0; 8355.0;
8380.0; 8405.0; 8430.0; 8455.0; 8480.0; 8505.0;
8530.0; 8555.0; 8580.0; 8605.0; 8630.0
|]
let sdss_i_thru =
[|
0.0; 0.0001; 0.0003; 0.0004; 0.0004; 0.0003;
0.0003; 0.0004; 0.0009; 0.0019; 0.0034; 0.0056;
0.0103; 0.0194; 0.0344; 0.0561; 0.0839; 0.1164;
0.1528; 0.1948; 0.2408; 0.2857; 0.3233; 0.3503;
0.3759; 0.399; 0.4162; 0.4233; 0.4165; 0.3943;
0.376; 0.3823; 0.3918; 0.3892; 0.3828; 0.382;
0.3884; 0.3872; 0.3821; 0.3787; 0.3759; 0.3727;
0.3681; 0.3618; 0.3565; 0.3554; 0.3478; 0.1473;
0.2096; 0.2648; 0.33; 0.3256; 0.3223; 0.3179;
0.3129; 0.3077; 0.3026; 0.298; 0.2944; 0.2921;
0.2916; 0.2921; 0.2927; 0.2923; 0.2896; 0.284;
0.2758; 0.2642; 0.2427; 0.2091; 0.1689; 0.1276;
0.0901; 0.0603; 0.0378; 0.0218; 0.0117; 0.0068;
0.0048; 0.0033; 0.002; 0.0013; 0.001; 0.0009;
0.0009; 0.0008; 0.0005; 0.0002; 0.0
|]
(* SDSS z *)
let sdss_z_wave =
[|
7730.0; 7755.0; 7780.0; 7805.0; 7830.0; 7855.0;
7880.0; 7905.0; 7930.0; 7955.0; 7980.0; 8005.0;
8030.0; 8055.0; 8080.0; 8105.0; 8130.0; 8155.0;
8180.0; 8205.0; 8230.0; 8255.0; 8280.0; 8305.0;
8330.0; 8355.0; 8380.0; 8405.0; 8430.0; 8455.0;
8480.0; 8505.0; 8530.0; 8555.0; 8580.0; 8605.0;
8630.0; 8655.0; 8680.0; 8705.0; 8730.0; 8755.0;
8780.0; 8805.0; 8830.0; 8855.0; 8880.0; 8905.0;
8930.0; 8955.0; 8980.0; 9005.0; 9030.0; 9055.0;
9080.0; 9105.0; 9130.0; 9155.0; 9180.0; 9205.0;
9230.0; 9255.0; 9280.0; 9305.0; 9330.0; 9355.0;
9380.0; 9405.0; 9430.0; 9455.0; 9480.0; 9505.0;
9530.0; 9555.0; 9580.0; 9605.0; 9630.0; 9655.0;
9680.0; 9705.0; 9730.0; 9755.0; 9780.0; 9805.0;
9830.0; 9855.0; 9880.0; 9905.0; 9930.0; 9955.0;
9980.0; 10005.0; 10030.0; 10055.0; 10080.0; 10105.0;
10130.0; 10155.0; 10180.0; 10205.0; 10230.0; 10255.0;
10280.0; 10305.0; 10330.0; 10355.0; 10380.0; 10405.0;
10430.0; 10455.0; 10480.0; 10505.0; 10530.0; 10555.0;
10580.0; 10605.0; 10630.0; 10655.0; 10680.0; 10705.0;
10730.0; 10755.0; 10780.0; 10805.0; 10830.0; 10855.0;
10880.0; 10905.0; 10930.0; 10955.0; 10980.0; 11005.0;
11030.0; 11055.0; 11080.0; 11105.0; 11130.0; 11155.0;
11180.0; 11205.0; 11230.0
|]
let sdss_z_thru =
[|
0.0; 0.0; 0.0001; 0.0001; 0.0001; 0.0002;
0.0002; 0.0003; 0.0005; 0.0007; 0.0011; 0.0017;
0.0027; 0.004; 0.0057; 0.0079; 0.0106; 0.0139;
0.0178; 0.0222; 0.0271; 0.0324; 0.0382; 0.0446;
0.0511; 0.0564; 0.0603; 0.0637; 0.0667; 0.0694;
0.0717; 0.0736; 0.0752; 0.0765; 0.0775; 0.0782;
0.0786; 0.0787; 0.0785; 0.078; 0.0772; 0.0763;
0.0751; 0.0738; 0.0723; 0.0708; 0.0693; 0.0674;
0.0632; 0.0581; 0.0543; 0.0526; 0.0523; 0.0522;
0.0512; 0.0496; 0.0481; 0.0473; 0.0476; 0.0482;
0.0476; 0.0447; 0.0391; 0.0329; 0.0283; 0.0264;
0.0271; 0.0283; 0.0275; 0.0254; 0.0252; 0.0256;
0.0246; 0.0244; 0.0252; 0.0258; 0.0265; 0.0274;
0.0279; 0.0271; 0.0252; 0.0236; 0.0227; 0.0222;
0.0216; 0.0208; 0.0196; 0.0183; 0.0171; 0.016;
0.0149; 0.0138; 0.0128; 0.0118; 0.0108; 0.0099;
0.0091; 0.0083; 0.0075; 0.0068; 0.0061; 0.0055;
0.005; 0.0045; 0.0041; 0.0037; 0.0033; 0.003;
0.0027; 0.0025; 0.0023; 0.0021; 0.0019; 0.0018;
0.0017; 0.0016; 0.0015; 0.0014; 0.0013; 0.0012;
0.0011; 0.001; 0.0009; 0.0008; 0.0008; 0.0007;
0.0006; 0.0006; 0.0006; 0.0005; 0.0005; 0.0004;
0.0004; 0.0003; 0.0003; 0.0002; 0.0002; 0.0001;
0.0001; 0.0; 0.0
|]
(* Johnson-Cousins U *)
let johnson_u_wave =
[|
3000.0; 3100.0; 3200.0; 3300.0; 3400.0; 3500.0;
3600.0; 3700.0; 3800.0; 3900.0; 4000.0; 4100.0;
4200.0
|]
let johnson_u_thru =
[|
0.0; 0.1; 0.61; 0.84; 0.93; 0.97;
1.0; 0.97; 0.73; 0.36; 0.05; 0.01;
0.0
|]
(* Johnson-Cousins B *)
let johnson_b_wave =
[|
3700.0; 3800.0; 4000.0; 4200.0; 4400.0; 4600.0;
4800.0; 5000.0; 5200.0; 5400.0; 5600.0
|]
let johnson_b_thru =
[|
0.0; 0.11; 0.92; 1.0; 0.94; 0.79;
0.58; 0.36; 0.15; 0.04; 0.0
|]
(* Johnson-Cousins V *)
let johnson_v_wave =
[|
4600.0; 4800.0; 5000.0; 5200.0; 5400.0; 5600.0;
5800.0; 6000.0; 6200.0; 6400.0; 6600.0; 6800.0;
7000.0; 7200.0; 7400.0
|]
let johnson_v_thru =
[|
0.0; 0.02; 0.38; 0.91; 0.98; 0.72;
0.62; 0.4; 0.2; 0.08; 0.02; 0.01;
0.01; 0.01; 0.0
|]
(* Johnson-Cousins R *)
let cousins_r_wave =
[|
5400.0; 5450.0; 5500.0; 5550.0; 5600.0; 5650.0;
5700.0; 5750.0; 5800.0; 5850.0; 5900.0; 5950.0;
6000.0; 6050.0; 6100.0; 6150.0; 6200.0; 6250.0;
6300.0; 6350.0; 6400.0; 6450.0; 6500.0; 6550.0;
6600.0; 6650.0; 6700.0; 6750.0; 6800.0; 6850.0;
6900.0; 6950.0; 7000.0; 7050.0; 7100.0; 7150.0;
7200.0; 7250.0; 7300.0; 7350.0; 7400.0; 7450.0;
7500.0; 7550.0; 7600.0; 7650.0; 7700.0; 7750.0;
7800.0; 7850.0; 7900.0; 7950.0; 8000.0
|]
let cousins_r_thru =
[|
0.0; 0.002; 0.01; 0.03; 0.07; 0.18;
0.4; 0.77; 0.89; 0.96; 0.99; 0.999;
1.0; 0.997; 0.99; 0.976; 0.96; 0.946;
0.93; 0.912; 0.895; 0.88; 0.86; 0.845;
0.825; 0.806; 0.788; 0.765; 0.742; 0.72;
0.7; 0.676; 0.65; 0.626; 0.6; 0.568;
0.53; 0.48; 0.395; 0.3; 0.215; 0.155;
0.12; 0.1; 0.085; 0.075; 0.06; 0.05;
0.04; 0.029; 0.02; 0.01; 0.0
|]
(* Johnson-Cousins I *)
let cousins_i_wave =
[|
7000.0; 7050.0; 7100.0; 7150.0; 7200.0; 7250.0;
7300.0; 7350.0; 7400.0; 7450.0; 7500.0; 7550.0;
7600.0; 7650.0; 7700.0; 7750.0; 7800.0; 7850.0;
7900.0; 7950.0; 8000.0; 8050.0; 8100.0; 8150.0;
8200.0; 8250.0; 8300.0; 8350.0; 8400.0; 8450.0;
8500.0; 8550.0; 8600.0; 8650.0; 8700.0; 8750.0;
8800.0; 8850.0; 8900.0; 8950.0; 9000.0; 9050.0;
9100.0
|]
let cousins_i_thru =
[|
0.0; 0.005; 0.02; 0.05; 0.1; 0.17;
0.33; 0.7; 0.82; 0.9; 0.95; 0.98;
0.99; 0.994; 0.98; 0.95; 0.913; 0.87;
0.83; 0.79; 0.75; 0.71; 0.673; 0.65;
0.63; 0.61; 0.58; 0.55; 0.51; 0.47;
0.405; 0.33; 0.25; 0.18; 0.14; 0.11;
0.08; 0.06; 0.035; 0.02; 0.01; 0.005;
0.0
|]
(* 2MASS J *)
let twomass_j_wave =
[|
10620.0; 10660.0; 10700.0; 10750.0; 10780.0; 10820.0;
10840.0; 10870.0; 10890.0; 10930.0; 10960.0; 11020.0;
11050.0; 11070.0; 11090.0; 11120.0; 11160.0; 11170.0;
11200.0; 11230.0; 11280.0; 11290.0; 11320.0; 11340.0;
11380.0; 11400.0; 11430.0; 11470.0; 11540.0; 11590.0;
11640.0; 11670.0; 11700.0; 11730.0; 11750.0; 11790.0;
11820.0; 11860.0; 11880.0; 11920.0; 11950.0; 11990.0;
12020.0; 12090.0; 12160.0; 12210.0; 12270.0; 12310.0;
12360.0; 12400.0; 12440.0; 12470.0; 12530.0; 12550.0;
12580.0; 12600.0; 12650.0; 12700.0; 12750.0; 12790.0;
12860.0; 12920.0; 12970.0; 13020.0; 13050.0; 13070.0;
13100.0; 13130.0; 13160.0; 13190.0; 13230.0; 13260.0;
13300.0; 13330.0; 13340.0; 13360.0; 13390.0; 13430.0;
13460.0; 13490.0; 13530.0; 13550.0; 13600.0; 13630.0;
13700.0; 13730.0; 13770.0; 13830.0; 13880.0; 13920.0;
13950.0; 13960.0; 13970.0; 13980.0; 14000.0; 14010.0;
14020.0; 14040.0; 14060.0; 14070.0; 14100.0; 14120.0;
14160.0; 14210.0; 14260.0; 14420.0; 14500.0
|]
let twomass_j_thru =
[|
0.0; 0.0004; 0.0015; 0.0027; 0.0055; 0.0123;
0.0203; 0.0306; 0.0405; 0.0515; 0.0564; 0.0718;
0.2736; 0.341; 0.3584; 0.3801; 0.3307; 0.2395;
0.2501; 0.2833; 0.2582; 0.2515; 0.5381; 0.2232;
0.5369; 0.1102; 0.5292; 0.2619; 0.3202; 0.1743;
0.607; 0.6179; 0.6763; 0.7279; 0.7465; 0.8304;
0.7903; 0.8096; 0.8369; 0.836; 0.7499; 0.708;
0.6988; 0.7049; 0.7004; 0.7328; 0.7057; 0.8424;
0.9219; 0.9525; 0.9676; 0.9595; 0.9227; 0.893;
0.8529; 0.8023; 0.7501; 0.6781; 0.6524; 0.6388;
0.6424; 0.6486; 0.6824; 0.7529; 0.7759; 0.8118;
0.777; 0.721; 0.9525; 0.8551; 0.8414; 1.0;
0.8947; 0.8549; 0.5379; 0.2799; 0.9065; 0.6893;
0.5533; 0.2432; 0.0144; 0.0002; 0.0401; 0.0045;
0.0003; 0.0372; 0.0005; 0.0; 0.0001; 0.0033;
0.0003; 0.0085; 0.0254; 0.1184; 0.0001; 0.0001;
0.0521; 0.0104; 0.0478; 0.0004; 0.0024; 0.0053;
0.0086; 0.0007; 0.0003; 0.0004; 0.0
|]
(* 2MASS H *)
let twomass_h_wave =
[|
12890.0; 13150.0; 13410.0; 13680.0; 13970.0; 14180.0;
14400.0; 14620.0; 14780.0; 14860.0; 14930.0; 15040.0;
15150.0; 15280.0; 15390.0; 15460.0; 15510.0; 15560.0;
15650.0; 15720.0; 15770.0; 15830.0; 15920.0; 15970.0;
16020.0; 16130.0; 16190.0; 16280.0; 16330.0; 16420.0;
16480.0; 16570.0; 16590.0; 16710.0; 16840.0; 17010.0;
17150.0; 17270.0; 17390.0; 17460.0; 17510.0; 17530.0;
17560.0; 17640.0; 17750.0; 17850.0; 17900.0; 17960.0;
18030.0; 18100.0; 18130.0; 18180.0; 18280.0; 18350.0;
18500.0; 18710.0; 18930.0; 19140.0
|]
let twomass_h_thru =
[|
0.0; 0.0; 0.0; 0.0; 0.0; 0.0;
0.0005; 0.028; 0.081; 0.287; 0.871; 0.201;
0.438; 0.686; 0.818; 0.882; 0.912; 0.927;
0.929; 0.873; 0.857; 0.883; 0.918; 0.927;
0.908; 0.926; 0.92; 0.924; 0.924; 0.942;
0.949; 0.981; 0.994; 1.0; 0.956; 0.924;
0.982; 0.992; 0.989; 0.979; 0.968; 0.937;
0.919; 0.842; 0.667; 0.269; 0.452; 0.173;
0.108; 0.071; 0.005; 0.02; 0.0004; 0.0;
0.0001; 0.0; 0.0; 0.0
|]
(* 2MASS Ks *)
let twomass_ks_wave =
[|
19000.0; 19150.0; 19270.0; 19340.0; 19390.0; 19480.0;
19570.0; 19620.0; 19690.0; 19760.0; 19810.0; 19890.0;
19900.0; 19980.0; 20080.0; 20140.0; 20190.0; 20280.0;
20370.0; 20450.0; 20610.0; 20720.0; 20750.0; 20820.0;
20890.0; 20990.0; 21060.0; 21130.0; 21200.0; 21240.0;
21380.0; 21450.0; 21550.0; 21690.0; 21760.0; 21850.0;
21970.0; 22080.0; 22130.0; 22180.0; 22320.0; 22370.0;
22480.0; 22560.0; 22600.0; 22630.0; 22650.0; 22700.0;
22720.0; 22760.0; 22770.0; 22810.0; 22840.0; 22860.0;
22910.0; 22930.0; 22950.0; 22970.0; 22990.0; 23060.0;
23110.0; 23160.0; 23200.0; 23250.0; 23280.0; 23350.0;
23390.0; 23440.0; 23460.0; 23520.0; 23610.0; 23630.0;
23700.0; 23750.0; 23840.0; 23990.0
|]
let twomass_ks_thru =
[|
0.0; 0.0; 0.0; 0.0002; 0.0005; 0.0054;
0.0119; 0.0197; 0.0422; 0.0873; 0.1528; 0.2482;
0.1902; 0.2339; 0.2946; 0.3982; 0.3366; 0.6207;
0.765; 0.7464; 0.6251; 0.7255; 0.6895; 0.7879;
0.8181; 0.8228; 0.8633; 0.8778; 0.8549; 0.8953;
0.9189; 0.9268; 0.9267; 0.9009; 0.9228; 0.8428;
0.9459; 0.9804; 0.9879; 0.9848; 0.9647; 0.9816;
0.9834; 0.9613; 0.9792; 1.0; 0.9632; 0.9812;
0.9681; 0.9109; 0.9821; 0.8896; 0.8918; 0.9424;
0.8404; 0.8042; 0.7077; 0.6576; 0.5607; 0.4437;
0.3482; 0.2302; 0.1626; 0.136; 0.0921; 0.0624;
0.0431; 0.034; 0.031; 0.0118; 0.0068; 0.0007;
0.003; 0.0021; 0.0004; 0.0
|]
(* Gaia DR3 G *)
let gaia_g_wave =
[|
3200.0; 3300.0; 3400.0; 3500.0; 3600.0; 3700.0;
3800.0; 3900.0; 4000.0; 4100.0; 4200.0; 4300.0;
4400.0; 4500.0; 4600.0; 4700.0; 4800.0; 4900.0;
5000.0; 5100.0; 5200.0; 5300.0; 5400.0; 5500.0;
5600.0; 5700.0; 5800.0; 5900.0; 6000.0; 6100.0;
6200.0; 6300.0; 6400.0; 6500.0; 6600.0; 6700.0;
6800.0; 6900.0; 7000.0; 7100.0; 7200.0; 7300.0;
7400.0; 7500.0; 7600.0; 7700.0; 7800.0; 7900.0;
8000.0; 8100.0; 8200.0; 8300.0; 8400.0; 8500.0;
8600.0; 8700.0; 8800.0; 8900.0; 9000.0; 9100.0;
9200.0; 9300.0; 9400.0; 9500.0; 9600.0; 9700.0;
9800.0; 9900.0; 10000.0; 10100.0; 10200.0; 10300.0;
10400.0; 10500.0
|]
let gaia_g_thru =
[|
2.37366962e-08; 0.00976875472; 0.0868837415; 0.125910068; 0.121442511; 0.109349045;
0.116293195; 0.204618287; 0.34084777; 0.433235889; 0.492915186; 0.532506055;
0.560121042; 0.58187167; 0.598921356; 0.612743401; 0.624456273; 0.634592054;
0.642876868; 0.651384274; 0.659234285; 0.665180853; 0.672624175; 0.677892686;
0.68337283; 0.688218588; 0.692909244; 0.698360314; 0.701281364; 0.705926392;
0.709945761; 0.712286557; 0.714900215; 0.716852196; 0.718062023; 0.717424017;
0.716404699; 0.713025742; 0.709495858; 0.702344476; 0.694885081; 0.682863231;
0.670880823; 0.654375536; 0.636105955; 0.615501457; 0.592399976; 0.567402553;
0.539583616; 0.510092228; 0.4791254; 0.447393833; 0.414784905; 0.38035191;
0.347263747; 0.313995072; 0.280491684; 0.249470941; 0.218314877; 0.189578109;
0.162072087; 0.137119296; 0.113758622; 0.0931891382; 0.074983285; 0.058819497;
0.0451523186; 0.0338677803; 0.0245381883; 0.0171045299; 0.0113958923; 0.00725157056;
0.00436700622; 0.00241251048
|]
(* Gaia DR3 BP *)
let gaia_bp_wave =
[|
3250.0; 3300.0; 3350.0; 3400.0; 3450.0; 3500.0;
3550.0; 3600.0; 3650.0; 3700.0; 3750.0; 3800.0;
3850.0; 3900.0; 3950.0; 4000.0; 4050.0; 4100.0;
4150.0; 4200.0; 4250.0; 4300.0; 4350.0; 4400.0;
4450.0; 4500.0; 4550.0; 4600.0; 4650.0; 4700.0;
4750.0; 4800.0; 4850.0; 4900.0; 4950.0; 5000.0;
5050.0; 5100.0; 5150.0; 5200.0; 5250.0; 5300.0;
5350.0; 5400.0; 5450.0; 5500.0; 5550.0; 5600.0;
5650.0; 5700.0; 5750.0; 5800.0; 5850.0; 5900.0;
5950.0; 6000.0; 6050.0; 6100.0; 6150.0; 6200.0;
6250.0; 6300.0; 6350.0; 6400.0; 6450.0; 6500.0;
6550.0; 6600.0; 6650.0; 6700.0; 6750.0; 6800.0;
6850.0; 6900.0; 6950.0; 7000.0; 7050.0; 7100.0;
7150.0; 7200.0; 7250.0; 7300.0; 7350.0; 7400.0;
7450.0; 7500.0
|]
let gaia_bp_thru =
[|
3.87054116e-05; 0.0109458069; 0.0960352312; 0.209777042; 0.24623711; 0.184648618;
0.196988564; 0.235262373; 0.223965129; 0.204351616; 0.178318209; 0.162143918;
0.184059158; 0.254352193; 0.34761739; 0.432175816; 0.492469156; 0.533465853;
0.560569552; 0.578302699; 0.589904162; 0.59902208; 0.607578555; 0.615301491;
0.623213524; 0.626992584; 0.627863884; 0.627028071; 0.627574894; 0.629195435;
0.632206645; 0.634636782; 0.635341726; 0.635285854; 0.634064731; 0.631462795;
0.63078819; 0.630124067; 0.630179832; 0.630007723; 0.627664462; 0.623347947;
0.621221392; 0.620168751; 0.619955482; 0.622688637; 0.622951427; 0.619327372;
0.614279326; 0.608587274; 0.60526859; 0.613287749; 0.63076648; 0.643202692;
0.641217847; 0.627856079; 0.613625406; 0.613169161; 0.625579651; 0.649530382;
0.666534979; 0.666866457; 0.650929127; 0.620667611; 0.578699833; 0.528381533;
0.46236659; 0.341204564; 0.158484372; 0.0351559339; 0.00370522417; 0.000728910264;
0.000524017362; 0.000284946553; 0.000113610422; 2.0255692e-05; 6.08899493e-06; 4.25533546e-06;
2.88792024e-06; 9.89680713e-07; 4.24900727e-07; 1.27016225e-07; 1.16831386e-07; 7.93884518e-09;
1.27555036e-07; 2.17167412e-08
|]
(* Gaia DR3 RP *)
let gaia_rp_wave =
[|
6100.0; 6150.0; 6200.0; 6250.0; 6300.0; 6350.0;
6400.0; 6450.0; 6500.0; 6550.0; 6600.0; 6650.0;
6700.0; 6750.0; 6800.0; 6850.0; 6900.0; 6950.0;
7000.0; 7050.0; 7100.0; 7150.0; 7200.0; 7250.0;
7300.0; 7350.0; 7400.0; 7450.0; 7500.0; 7550.0;
7600.0; 7650.0; 7700.0; 7750.0; 7800.0; 7850.0;
7900.0; 7950.0; 8000.0; 8050.0; 8100.0; 8150.0;
8200.0; 8250.0; 8300.0; 8350.0; 8400.0; 8450.0;
8500.0; 8550.0; 8600.0; 8650.0; 8700.0; 8750.0;
8800.0; 8850.0; 8900.0; 8950.0; 9000.0; 9050.0;
9100.0; 9150.0; 9200.0; 9250.0; 9300.0; 9350.0;
9400.0; 9450.0; 9500.0; 9550.0; 9600.0; 9650.0;
9700.0; 9750.0; 9800.0; 9850.0; 9900.0; 9950.0;
10000.0; 10050.0; 10100.0; 10150.0; 10200.0; 10250.0;
10300.0; 10350.0; 10400.0; 10450.0; 10500.0; 10550.0;
10600.0; 10650.0; 10700.0; 10750.0; 10800.0
|]
let gaia_rp_thru =
[|
0.0001067; 0.000705; 0.0089591; 0.0894186; 0.3945348; 0.6832151;
0.7284574; 0.6783742; 0.6932457; 0.6991653; 0.7068345; 0.7168661;
0.7258579; 0.7314582; 0.7317729; 0.729553; 0.7311262; 0.7341997;
0.7375911; 0.7377587; 0.7351913; 0.7317705; 0.7322348; 0.7341152;
0.7395558; 0.7439523; 0.7434368; 0.7401882; 0.7383857; 0.7400737;
0.7391916; 0.7378262; 0.7299905; 0.7234387; 0.7148353; 0.7081058;
0.7045418; 0.7029044; 0.703763; 0.7037788; 0.7012269; 0.698329;
0.6904644; 0.6830179; 0.6750185; 0.6668831; 0.6552453; 0.6437497;
0.6278626; 0.6142203; 0.5984866; 0.5817457; 0.5664293; 0.5505743;
0.5320554; 0.5156898; 0.4998404; 0.4817145; 0.4631831; 0.443315;
0.4236545; 0.4041978; 0.3837304; 0.3611222; 0.3409582; 0.320113;
0.2991975; 0.278412; 0.2555403; 0.2372075; 0.2165387; 0.1977315;
0.1787908; 0.1634732; 0.1453902; 0.1318572; 0.1142639; 0.0987333;
0.0815165; 0.0661173; 0.0521649; 0.0400458; 0.030169; 0.0228553;
0.0165918; 0.0122218; 0.0086189; 0.006114; 0.0042268; 0.0028113;
0.001905; 0.0012324; 0.0007693; 0.0004905; 0.0003028
|]
(* LSST/LSST.u — 60 points *)
let rubin_u_wave =
[|
3.200000e+03; 3.215000e+03; 3.230000e+03; 3.245000e+03; 3.260000e+03; 3.275000e+03; 3.290000e+03; 3.305000e+03;
3.320000e+03; 3.335000e+03; 3.350000e+03; 3.365000e+03; 3.380000e+03; 3.395000e+03; 3.410000e+03; 3.425000e+03;
3.440000e+03; 3.455000e+03; 3.470000e+03; 3.485000e+03; 3.500000e+03; 3.515000e+03; 3.530000e+03; 3.545000e+03;
3.560000e+03; 3.575000e+03; 3.590000e+03; 3.605000e+03; 3.620000e+03; 3.635000e+03; 3.650000e+03; 3.665000e+03;
3.680000e+03; 3.695000e+03; 3.710000e+03; 3.725000e+03; 3.740000e+03; 3.755000e+03; 3.770000e+03; 3.785000e+03;
3.800000e+03; 3.815000e+03; 3.830000e+03; 3.845000e+03; 3.860000e+03; 3.875000e+03; 3.890000e+03; 3.905000e+03;
3.920000e+03; 3.935000e+03; 3.950000e+03; 3.965000e+03; 3.980000e+03; 3.995000e+03; 4.010000e+03; 4.025000e+03;
4.040000e+03; 4.055000e+03; 4.070000e+03; 4.085000e+03
|]
let rubin_u_thru =
[|
1.429550e-14; 5.824880e-03; 9.177360e-03; 1.413040e-02; 2.023590e-02; 2.751190e-02; 3.708220e-02; 4.640890e-02;
5.690710e-02; 6.560040e-02; 7.538320e-02; 8.192530e-02; 8.826960e-02; 9.514300e-02; 1.009060e-01; 1.072220e-01;
1.120190e-01; 1.179670e-01; 1.231450e-01; 1.283730e-01; 1.337000e-01; 1.381080e-01; 1.432610e-01; 1.478000e-01;
1.527230e-01; 1.573360e-01; 1.620670e-01; 1.666840e-01; 1.716940e-01; 1.764620e-01; 1.811790e-01; 1.858970e-01;
1.906280e-01; 1.950920e-01; 1.996840e-01; 2.041020e-01; 2.082430e-01; 2.126950e-01; 2.169940e-01; 2.214680e-01;
2.221790e-01; 2.194800e-01; 2.155940e-01; 2.047200e-01; 1.938620e-01; 1.822830e-01; 1.701760e-01; 1.575460e-01;
1.442870e-01; 1.308010e-01; 1.168670e-01; 1.024170e-01; 8.739400e-02; 7.196930e-02; 5.599790e-02; 3.968800e-02;
2.581460e-02; 1.741130e-02; 8.801300e-03; 2.931970e-04
|]
(* LSST/LSST.g — 60 points *)
let rubin_g_wave =
[|
3.864000e+03; 3.894000e+03; 3.925000e+03; 3.955000e+03; 3.986000e+03; 4.016000e+03; 4.047000e+03; 4.078000e+03;
4.108000e+03; 4.139000e+03; 4.169000e+03; 4.200000e+03; 4.231000e+03; 4.261000e+03; 4.292000e+03; 4.322000e+03;
4.353000e+03; 4.384000e+03; 4.414000e+03; 4.445000e+03; 4.475000e+03; 4.506000e+03; 4.537000e+03; 4.567000e+03;
4.598000e+03; 4.628000e+03; 4.659000e+03; 4.690000e+03; 4.720000e+03; 4.751000e+03; 4.781000e+03; 4.812000e+03;
4.842000e+03; 4.873000e+03; 4.904000e+03; 4.934000e+03; 4.965000e+03; 4.995000e+03; 5.026000e+03; 5.057000e+03;
5.087000e+03; 5.118000e+03; 5.148000e+03; 5.179000e+03; 5.210000e+03; 5.240000e+03; 5.271000e+03; 5.301000e+03;
5.332000e+03; 5.363000e+03; 5.393000e+03; 5.424000e+03; 5.454000e+03; 5.485000e+03; 5.516000e+03; 5.546000e+03;
5.577000e+03; 5.607000e+03; 5.638000e+03; 5.669000e+03
|]
let rubin_g_thru =
[|
4.995720e-14; 1.504200e-02; 3.744180e-02; 7.157070e-02; 1.087810e-01; 1.464570e-01; 1.868340e-01; 2.288970e-01;
2.711150e-01; 3.057660e-01; 3.235230e-01; 3.295650e-01; 3.352140e-01; 3.406260e-01; 3.460290e-01; 3.509490e-01;
3.552980e-01; 3.595040e-01; 3.634800e-01; 3.669990e-01; 3.708900e-01; 3.741200e-01; 3.769130e-01; 3.795000e-01;
3.822150e-01; 3.843350e-01; 3.866750e-01; 3.889150e-01; 3.912650e-01; 3.926710e-01; 3.941770e-01; 3.949940e-01;
3.969120e-01; 3.983260e-01; 3.989640e-01; 3.998360e-01; 4.005050e-01; 4.012760e-01; 4.004640e-01; 4.001320e-01;
4.011100e-01; 4.024610e-01; 4.033200e-01; 4.036730e-01; 4.041100e-01; 4.038960e-01; 4.034070e-01; 4.035910e-01;
4.037410e-01; 4.053230e-01; 3.897960e-01; 3.553770e-01; 3.076690e-01; 2.585420e-01; 2.090320e-01; 1.607920e-01;
1.108210e-01; 6.215070e-02; 2.627740e-02; 8.320810e-04
|]
(* LSST/LSST.r — 60 points *)
let rubin_r_wave =
[|
5.370000e+03; 5.398000e+03; 5.427000e+03; 5.455000e+03; 5.484000e+03; 5.513000e+03; 5.541000e+03; 5.570000e+03;
5.599000e+03; 5.627000e+03; 5.656000e+03; 5.684000e+03; 5.713000e+03; 5.742000e+03; 5.770000e+03; 5.799000e+03;
5.828000e+03; 5.856000e+03; 5.885000e+03; 5.913000e+03; 5.942000e+03; 5.971000e+03; 5.999000e+03; 6.028000e+03;
6.057000e+03; 6.085000e+03; 6.114000e+03; 6.142000e+03; 6.171000e+03; 6.200000e+03; 6.228000e+03; 6.257000e+03;
6.286000e+03; 6.314000e+03; 6.343000e+03; 6.371000e+03; 6.400000e+03; 6.429000e+03; 6.457000e+03; 6.486000e+03;
6.515000e+03; 6.543000e+03; 6.572000e+03; 6.600000e+03; 6.629000e+03; 6.658000e+03; 6.686000e+03; 6.715000e+03;
6.744000e+03; 6.772000e+03; 6.801000e+03; 6.829000e+03; 6.858000e+03; 6.887000e+03; 6.915000e+03; 6.944000e+03;
6.973000e+03; 7.001000e+03; 7.030000e+03; 7.059000e+03
|]
let rubin_r_thru =
[|
4.419110e-13; 2.309990e-02; 5.277260e-02; 9.905770e-02; 1.473500e-01; 1.958210e-01; 2.426460e-01; 2.913780e-01;
3.398970e-01; 3.865220e-01; 4.118140e-01; 4.177100e-01; 4.186640e-01; 4.200420e-01; 4.218920e-01; 4.241900e-01;
4.259640e-01; 4.276140e-01; 4.258520e-01; 4.267580e-01; 4.263350e-01; 4.286710e-01; 4.309470e-01; 4.325390e-01;
4.342740e-01; 4.365000e-01; 4.394080e-01; 4.418750e-01; 4.444110e-01; 4.462740e-01; 4.484890e-01; 4.503210e-01;
4.432760e-01; 4.520930e-01; 4.562140e-01; 4.584970e-01; 4.602860e-01; 4.625800e-01; 4.638100e-01; 4.629030e-01;
4.637690e-01; 4.655010e-01; 4.663710e-01; 4.690410e-01; 4.694020e-01; 4.697590e-01; 4.700940e-01; 4.709560e-01;
4.711620e-01; 4.671030e-01; 4.403810e-01; 3.889860e-01; 3.325260e-01; 2.460020e-01; 2.127620e-01; 1.675560e-01;
1.164150e-01; 6.332250e-02; 2.799150e-02; 9.340470e-04
|]
(* LSST/LSST.i — 60 points *)
let rubin_i_wave =
[|
6.760000e+03; 6.786000e+03; 6.813000e+03; 6.839000e+03; 6.866000e+03; 6.892000e+03; 6.919000e+03; 6.946000e+03;
6.972000e+03; 6.999000e+03; 7.025000e+03; 7.052000e+03; 7.079000e+03; 7.105000e+03; 7.132000e+03; 7.158000e+03;
7.185000e+03; 7.212000e+03; 7.238000e+03; 7.265000e+03; 7.291000e+03; 7.318000e+03; 7.345000e+03; 7.371000e+03;
7.398000e+03; 7.424000e+03; 7.451000e+03; 7.478000e+03; 7.504000e+03; 7.531000e+03; 7.557000e+03; 7.584000e+03;
7.610000e+03; 7.637000e+03; 7.664000e+03; 7.690000e+03; 7.717000e+03; 7.743000e+03; 7.770000e+03; 7.797000e+03;
7.823000e+03; 7.850000e+03; 7.876000e+03; 7.903000e+03; 7.930000e+03; 7.956000e+03; 7.983000e+03; 8.009000e+03;
8.036000e+03; 8.063000e+03; 8.089000e+03; 8.116000e+03; 8.142000e+03; 8.169000e+03; 8.196000e+03; 8.222000e+03;
8.249000e+03; 8.275000e+03; 8.302000e+03; 8.329000e+03
|]
let rubin_i_thru =
[|
8.017680e-13; 2.428840e-02; 5.232930e-02; 1.008480e-01; 1.393260e-01; 1.741710e-01; 2.376070e-01; 2.935610e-01;
3.473580e-01; 3.942050e-01; 4.356630e-01; 4.633490e-01; 4.648150e-01; 4.656090e-01; 4.657820e-01; 4.634080e-01;
4.285660e-01; 4.468650e-01; 4.358040e-01; 4.434830e-01; 4.449560e-01; 4.513710e-01; 4.624340e-01; 4.619650e-01;
4.645030e-01; 4.657510e-01; 4.657690e-01; 4.658270e-01; 4.653330e-01; 4.654000e-01; 4.648360e-01; 4.636220e-01;
1.706640e-01; 2.698000e-01; 4.015270e-01; 4.520990e-01; 4.617890e-01; 4.617680e-01; 4.610620e-01; 4.606790e-01;
4.597540e-01; 4.585490e-01; 4.570360e-01; 4.535390e-01; 4.537330e-01; 4.537350e-01; 4.537880e-01; 4.513960e-01;
4.520760e-01; 4.306080e-01; 3.909260e-01; 3.398770e-01; 2.831320e-01; 2.292450e-01; 1.858430e-01; 1.434810e-01;
9.907340e-02; 5.212160e-02; 2.456530e-02; 8.945460e-04
|]
(* LSST/LSST.z — 60 points *)
let rubin_z_wave =
[|
8.030000e+03; 8.052000e+03; 8.075000e+03; 8.098000e+03; 8.121000e+03; 8.144000e+03; 8.167000e+03; 8.190000e+03;
8.213000e+03; 8.236000e+03; 8.259000e+03; 8.282000e+03; 8.305000e+03; 8.328000e+03; 8.351000e+03; 8.374000e+03;
8.397000e+03; 8.420000e+03; 8.443000e+03; 8.466000e+03; 8.489000e+03; 8.512000e+03; 8.535000e+03; 8.558000e+03;
8.581000e+03; 8.604000e+03; 8.627000e+03; 8.650000e+03; 8.673000e+03; 8.696000e+03; 8.718000e+03; 8.741000e+03;
8.764000e+03; 8.787000e+03; 8.810000e+03; 8.833000e+03; 8.856000e+03; 8.879000e+03; 8.902000e+03; 8.925000e+03;
8.948000e+03; 8.971000e+03; 8.994000e+03; 9.017000e+03; 9.040000e+03; 9.063000e+03; 9.086000e+03; 9.109000e+03;
9.132000e+03; 9.155000e+03; 9.178000e+03; 9.201000e+03; 9.224000e+03; 9.247000e+03; 9.270000e+03; 9.293000e+03;
9.316000e+03; 9.339000e+03; 9.362000e+03; 9.385000e+03
|]
let rubin_z_thru =
[|
1.039400e-12; 1.983520e-02; 4.060140e-02; 7.732710e-02; 1.178440e-01; 1.535950e-01; 1.866070e-01; 2.287220e-01;
2.798410e-01; 2.967030e-01; 3.561920e-01; 3.877460e-01; 4.214340e-01; 4.411140e-01; 4.470510e-01; 4.475610e-01;
4.480250e-01; 4.471650e-01; 4.471040e-01; 4.466870e-01; 4.464830e-01; 4.458890e-01; 4.466020e-01; 4.471110e-01;
4.474290e-01; 4.476080e-01; 4.474960e-01; 4.474640e-01; 4.472600e-01; 4.469260e-01; 4.456110e-01; 4.441420e-01;
4.428510e-01; 4.412100e-01; 4.406520e-01; 4.402080e-01; 4.389510e-01; 4.380910e-01; 4.371550e-01; 4.326220e-01;
4.198210e-01; 3.961630e-01; 3.715530e-01; 3.863600e-01; 4.102820e-01; 3.960420e-01; 3.749220e-01; 3.652600e-01;
3.317280e-01; 2.897530e-01; 2.623570e-01; 2.428380e-01; 2.063130e-01; 1.666310e-01; 1.330530e-01; 8.828810e-02;
4.608740e-02; 1.844720e-02; 9.801210e-03; 2.278930e-04
|]
(* LSST/LSST.y — 60 points *)
let rubin_y_wave =
[|
9.084000e+03; 9.116000e+03; 9.148000e+03; 9.180000e+03; 9.213000e+03; 9.245000e+03; 9.277000e+03; 9.310000e+03;
9.342000e+03; 9.374000e+03; 9.406000e+03; 9.439000e+03; 9.471000e+03; 9.503000e+03; 9.536000e+03; 9.568000e+03;
9.600000e+03; 9.632000e+03; 9.665000e+03; 9.697000e+03; 9.729000e+03; 9.762000e+03; 9.794000e+03; 9.826000e+03;
9.858000e+03; 9.891000e+03; 9.923000e+03; 9.955000e+03; 9.988000e+03; 1.002000e+04; 1.005200e+04; 1.008400e+04;
1.011700e+04; 1.014900e+04; 1.018100e+04; 1.021400e+04; 1.024600e+04; 1.027800e+04; 1.031000e+04; 1.034300e+04;
1.037500e+04; 1.040700e+04; 1.044000e+04; 1.047200e+04; 1.050400e+04; 1.053600e+04; 1.056900e+04; 1.060100e+04;
1.063300e+04; 1.066600e+04; 1.069800e+04; 1.073000e+04; 1.076200e+04; 1.079500e+04; 1.082700e+04; 1.085900e+04;
1.089200e+04; 1.092400e+04; 1.095600e+04; 1.098900e+04
|]
let rubin_y_thru =
[|
4.969710e-13; 2.294700e-02; 5.618380e-02; 9.902450e-02; 1.553400e-01; 1.976760e-01; 2.337680e-01; 2.243520e-01;
1.759080e-01; 2.175850e-01; 2.643880e-01; 2.261180e-01; 2.467280e-01; 2.298940e-01; 2.452360e-01; 2.340200e-01;
2.363950e-01; 2.468170e-01; 2.376270e-01; 2.601730e-01; 2.440820e-01; 2.274890e-01; 2.249930e-01; 2.233450e-01;
2.183580e-01; 2.090180e-01; 1.985730e-01; 1.887810e-01; 1.784510e-01; 1.678880e-01; 1.585850e-01; 1.488040e-01;
1.392830e-01; 1.302160e-01; 1.212720e-01; 1.124510e-01; 1.041190e-01; 9.340160e-02; 8.306840e-02; 7.337710e-02;
6.467130e-02; 5.662540e-02; 4.896100e-02; 4.205990e-02; 3.581300e-02; 3.014650e-02; 2.491370e-02; 2.058860e-02;
1.681940e-02; 1.360580e-02; 1.116100e-02; 9.274090e-03; 7.773550e-03; 6.345540e-03; 5.132820e-03; 3.961830e-03;
3.044690e-03; 2.228830e-03; 1.600220e-03; 1.225340e-03
|]
(* Euclid/VIS.vis — 60 points *)
let euclid_vis_wave =
[|
4.369190e+03; 4.459140e+03; 4.549090e+03; 4.639040e+03; 4.738980e+03; 4.828920e+03; 4.918870e+03; 5.018810e+03;
5.108760e+03; 5.198710e+03; 5.298650e+03; 5.388590e+03; 5.478540e+03; 5.578480e+03; 5.668430e+03; 5.758380e+03;
5.858320e+03; 5.948270e+03; 6.038210e+03; 6.138150e+03; 6.228100e+03; 6.318050e+03; 6.417990e+03; 6.507940e+03;
6.597880e+03; 6.697820e+03; 6.787770e+03; 6.877720e+03; 6.977660e+03; 7.067610e+03; 7.157550e+03; 7.247500e+03;
7.347440e+03; 7.437390e+03; 7.527340e+03; 7.627280e+03; 7.717230e+03; 7.807170e+03; 7.907110e+03; 7.997060e+03;
8.087010e+03; 8.186950e+03; 8.276900e+03; 8.366840e+03; 8.466780e+03; 8.556730e+03; 8.646680e+03; 8.746620e+03;
8.836570e+03; 8.926510e+03; 9.026460e+03; 9.116400e+03; 9.206350e+03; 9.306290e+03; 9.396240e+03; 9.486180e+03;
9.586130e+03; 9.676070e+03; 9.766020e+03; 9.865960e+03
|]
let euclid_vis_thru =
[|
5.667901e-04; 1.630730e-03; 4.172531e-03; 1.124922e-03; 2.177489e-03; 3.386911e-03; 3.641207e-03; 1.371951e-02;
9.284691e-03; 1.350449e-02; 2.210145e-02; 5.507258e-02; 7.012943e-01; 7.157499e-01; 7.257763e-01; 7.345625e-01;
7.426705e-01; 7.485474e-01; 7.527962e-01; 7.552509e-01; 7.566956e-01; 7.574142e-01; 7.585293e-01; 7.588235e-01;
7.574724e-01; 7.558434e-01; 7.571670e-01; 7.570556e-01; 7.567355e-01; 7.559193e-01; 7.545203e-01; 7.533978e-01;
7.508411e-01; 7.461720e-01; 7.414610e-01; 7.350350e-01; 7.301607e-01; 7.229322e-01; 7.122474e-01; 7.009171e-01;
6.870897e-01; 6.690291e-01; 6.520249e-01; 6.316748e-01; 6.043242e-01; 5.787996e-01; 5.515213e-01; 5.182027e-01;
4.851838e-01; 4.521600e-01; 4.136761e-01; 3.779477e-01; 3.006216e-01; 1.259014e-02; 1.804853e-03; 2.027216e-03;
1.289551e-03; 6.535986e-04; 7.194299e-04; 4.038814e-04
|]
(* Euclid/NISP.Y — 60 points *)
let euclid_y_wave =
[|
9.330000e+03; 9.380000e+03; 9.430000e+03; 9.480000e+03; 9.540000e+03; 9.590000e+03; 9.640000e+03; 9.700000e+03;
9.750000e+03; 9.800000e+03; 9.850000e+03; 9.910000e+03; 9.960000e+03; 1.001000e+04; 1.007000e+04; 1.012000e+04;
1.017000e+04; 1.022000e+04; 1.028000e+04; 1.033000e+04; 1.038000e+04; 1.044000e+04; 1.049000e+04; 1.054000e+04;
1.059000e+04; 1.065000e+04; 1.070000e+04; 1.075000e+04; 1.081000e+04; 1.086000e+04; 1.091000e+04; 1.096000e+04;
1.102000e+04; 1.107000e+04; 1.112000e+04; 1.118000e+04; 1.123000e+04; 1.128000e+04; 1.133000e+04; 1.139000e+04;
1.144000e+04; 1.149000e+04; 1.155000e+04; 1.160000e+04; 1.165000e+04; 1.170000e+04; 1.176000e+04; 1.181000e+04;
1.186000e+04; 1.192000e+04; 1.197000e+04; 1.202000e+04; 1.207000e+04; 1.213000e+04; 1.218000e+04; 1.223000e+04;
1.229000e+04; 1.234000e+04; 1.239000e+04; 1.245000e+04
|]
let euclid_y_thru =
[|
1.401100e-04; 9.044250e-04; 2.786910e-02; 1.932270e-01; 7.417070e-01; 7.539270e-01; 7.683890e-01; 7.725430e-01;
7.736840e-01; 7.748860e-01; 7.762310e-01; 7.789970e-01; 7.770510e-01; 7.763010e-01; 7.815180e-01; 7.831650e-01;
7.782630e-01; 7.784520e-01; 7.745790e-01; 7.736720e-01; 7.806700e-01; 7.784520e-01; 7.792640e-01; 7.782380e-01;
7.713900e-01; 7.704340e-01; 7.682960e-01; 7.640030e-01; 7.705670e-01; 7.696290e-01; 7.648760e-01; 7.663780e-01;
7.621110e-01; 7.587620e-01; 7.627900e-01; 7.664810e-01; 7.633970e-01; 7.646770e-01; 7.645850e-01; 7.692010e-01;
7.718380e-01; 7.719360e-01; 7.718390e-01; 7.717390e-01; 7.688890e-01; 7.725960e-01; 7.771790e-01; 7.775880e-01;
7.796710e-01; 7.805190e-01; 7.824850e-01; 7.843440e-01; 7.759730e-01; 3.002440e-01; 7.116640e-02; 1.522080e-02;
3.885880e-03; 1.834220e-03; 1.115430e-03; 6.624430e-04
|]
(* Euclid/NISP.J — 60 points *)
let euclid_j_wave =
[|
1.141000e+04; 1.148000e+04; 1.156000e+04; 1.164000e+04; 1.172000e+04; 1.180000e+04; 1.188000e+04; 1.196000e+04;
1.204000e+04; 1.212000e+04; 1.220000e+04; 1.228000e+04; 1.236000e+04; 1.244000e+04; 1.252000e+04; 1.260000e+04;
1.268000e+04; 1.276000e+04; 1.284000e+04; 1.292000e+04; 1.299000e+04; 1.307000e+04; 1.315000e+04; 1.323000e+04;
1.331000e+04; 1.339000e+04; 1.347000e+04; 1.355000e+04; 1.363000e+04; 1.371000e+04; 1.379000e+04; 1.387000e+04;
1.395000e+04; 1.403000e+04; 1.411000e+04; 1.419000e+04; 1.427000e+04; 1.435000e+04; 1.443000e+04; 1.451000e+04;
1.458000e+04; 1.466000e+04; 1.474000e+04; 1.482000e+04; 1.490000e+04; 1.498000e+04; 1.506000e+04; 1.514000e+04;
1.522000e+04; 1.530000e+04; 1.538000e+04; 1.546000e+04; 1.554000e+04; 1.562000e+04; 1.570000e+04; 1.578000e+04;
1.586000e+04; 1.594000e+04; 1.602000e+04; 1.610000e+04
|]
let euclid_j_thru =
[|
1.576900e-04; 4.015150e-04; 3.417080e-03; 1.226570e-01; 7.426110e-01; 7.817110e-01; 7.813510e-01; 7.840630e-01;
7.888400e-01; 7.907170e-01; 7.833700e-01; 7.884310e-01; 7.896350e-01; 7.852690e-01; 7.966270e-01; 7.958310e-01;
7.988340e-01; 7.953290e-01; 7.964360e-01; 7.932720e-01; 7.885410e-01; 7.955600e-01; 7.943190e-01; 7.956280e-01;
8.027170e-01; 8.039270e-01; 8.032210e-01; 7.995800e-01; 8.013920e-01; 8.024890e-01; 7.976110e-01; 7.968730e-01;
7.954540e-01; 7.861820e-01; 7.882250e-01; 7.912090e-01; 7.856070e-01; 7.868450e-01; 7.890830e-01; 7.843430e-01;
7.847770e-01; 7.870230e-01; 7.847020e-01; 7.808650e-01; 7.816630e-01; 7.821060e-01; 7.840170e-01; 7.832910e-01;
7.825870e-01; 7.872660e-01; 7.816920e-01; 7.772310e-01; 7.810290e-01; 6.745540e-01; 2.009210e-01; 2.169140e-02;
3.511730e-03; 9.199010e-04; 3.143400e-04; 1.453270e-04
|]
(* Euclid/NISP.H — 60 points *)
let euclid_h_wave =
[|
1.480000e+04; 1.489000e+04; 1.499000e+04; 1.509000e+04; 1.519000e+04; 1.529000e+04; 1.539000e+04; 1.549000e+04;
1.559000e+04; 1.569000e+04; 1.579000e+04; 1.589000e+04; 1.599000e+04; 1.609000e+04; 1.619000e+04; 1.629000e+04;
1.639000e+04; 1.649000e+04; 1.659000e+04; 1.669000e+04; 1.678000e+04; 1.688000e+04; 1.698000e+04; 1.708000e+04;
1.718000e+04; 1.728000e+04; 1.738000e+04; 1.748000e+04; 1.758000e+04; 1.768000e+04; 1.778000e+04; 1.788000e+04;
1.798000e+04; 1.808000e+04; 1.818000e+04; 1.828000e+04; 1.838000e+04; 1.848000e+04; 1.858000e+04; 1.868000e+04;
1.877000e+04; 1.887000e+04; 1.897000e+04; 1.907000e+04; 1.917000e+04; 1.927000e+04; 1.937000e+04; 1.947000e+04;
1.957000e+04; 1.967000e+04; 1.977000e+04; 1.987000e+04; 1.997000e+04; 2.007000e+04; 2.017000e+04; 2.027000e+04;
2.037000e+04; 2.047000e+04; 2.057000e+04; 2.067000e+04
|]
let euclid_h_thru =
[|
1.433800e-04; 3.416300e-04; 1.371980e-03; 1.165150e-02; 2.166120e-01; 7.653910e-01; 7.770660e-01; 7.766830e-01;
7.792960e-01; 7.733530e-01; 7.817380e-01; 7.820990e-01; 7.830000e-01; 7.815810e-01; 7.808620e-01; 7.824440e-01;
7.788240e-01; 7.785320e-01; 7.810690e-01; 7.777990e-01; 7.773880e-01; 7.825950e-01; 7.841750e-01; 7.836200e-01;
7.841940e-01; 7.853450e-01; 7.824440e-01; 7.815110e-01; 7.833050e-01; 7.838540e-01; 7.843000e-01; 7.839500e-01;
7.850730e-01; 7.855760e-01; 7.873750e-01; 7.901270e-01; 7.888930e-01; 7.901500e-01; 7.918380e-01; 7.927560e-01;
7.897410e-01; 7.867750e-01; 7.847490e-01; 7.831360e-01; 7.785180e-01; 7.761400e-01; 7.759030e-01; 7.723980e-01;
7.665570e-01; 7.651970e-01; 7.639360e-01; 7.611570e-01; 7.567070e-01; 7.462570e-01; 5.895400e-01; 1.360190e-01;
1.785060e-02; 3.091850e-03; 7.171190e-04; 1.507920e-04
|]
================================================
FILE: dev/umbra/lib/filters.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let f64 = Nx.float64
let angstrom_to_m = 1e-10
(* [make wave thru] builds a bandpass from an SVO table: wavelengths in
Angstrom (converted to metres here) paired with dimensionless throughput. *)
let make wave_a thru_a =
let n = Array.length wave_a in
let w = Nx.create f64 [| n |] wave_a in
let w = Nx.mul_s w angstrom_to_m in
let t = Nx.create f64 [| n |] thru_a in
Photometry.bandpass ~wavelength:(Unit.Length.of_tensor w) ~throughput:t
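(* A minimal sketch of what [make] produces, with made-up values:
[make [| 5480.0; 5500.0 |] [| 0.5; 0.6 |]] is a bandpass sampled at
5.48e-7 m and 5.50e-7 m with throughputs 0.5 and 0.6. *)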
(* SDSS *)
let sdss_u = make Filter_data.sdss_u_wave Filter_data.sdss_u_thru
let sdss_g = make Filter_data.sdss_g_wave Filter_data.sdss_g_thru
let sdss_r = make Filter_data.sdss_r_wave Filter_data.sdss_r_thru
let sdss_i = make Filter_data.sdss_i_wave Filter_data.sdss_i_thru
let sdss_z = make Filter_data.sdss_z_wave Filter_data.sdss_z_thru
(* Johnson-Cousins *)
let johnson_u = make Filter_data.johnson_u_wave Filter_data.johnson_u_thru
let johnson_b = make Filter_data.johnson_b_wave Filter_data.johnson_b_thru
let johnson_v = make Filter_data.johnson_v_wave Filter_data.johnson_v_thru
let cousins_r = make Filter_data.cousins_r_wave Filter_data.cousins_r_thru
let cousins_i = make Filter_data.cousins_i_wave Filter_data.cousins_i_thru
(* 2MASS *)
let twomass_j = make Filter_data.twomass_j_wave Filter_data.twomass_j_thru
let twomass_h = make Filter_data.twomass_h_wave Filter_data.twomass_h_thru
let twomass_ks = make Filter_data.twomass_ks_wave Filter_data.twomass_ks_thru
(* Gaia DR3 *)
let gaia_g = make Filter_data.gaia_g_wave Filter_data.gaia_g_thru
let gaia_bp = make Filter_data.gaia_bp_wave Filter_data.gaia_bp_thru
let gaia_rp = make Filter_data.gaia_rp_wave Filter_data.gaia_rp_thru
(* Rubin/LSST *)
let rubin_u = make Filter_data.rubin_u_wave Filter_data.rubin_u_thru
let rubin_g = make Filter_data.rubin_g_wave Filter_data.rubin_g_thru
let rubin_r = make Filter_data.rubin_r_wave Filter_data.rubin_r_thru
let rubin_i = make Filter_data.rubin_i_wave Filter_data.rubin_i_thru
let rubin_z = make Filter_data.rubin_z_wave Filter_data.rubin_z_thru
let rubin_y = make Filter_data.rubin_y_wave Filter_data.rubin_y_thru
(* Euclid *)
let euclid_vis = make Filter_data.euclid_vis_wave Filter_data.euclid_vis_thru
let euclid_y = make Filter_data.euclid_y_wave Filter_data.euclid_y_thru
let euclid_j = make Filter_data.euclid_j_wave Filter_data.euclid_j_thru
let euclid_h = make Filter_data.euclid_h_wave Filter_data.euclid_h_thru
================================================
FILE: dev/umbra/lib/filters.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Standard astronomical filter bandpasses.
Tabulated transmission curves from the
{{:https://svo2.cab.inta-csic.es/theory/fps/} SVO Filter Profile Service}.
Each value is a pre-built {!Photometry.bandpass}.
{[
let mag = Photometry.ab_mag Filters.sdss_r sed
]} *)
(** {1:sdss SDSS ugriz} *)
val sdss_u : Photometry.bandpass
(** [sdss_u] is the SDSS u-band (298--413 nm, 47 points). *)
val sdss_g : Photometry.bandpass
(** [sdss_g] is the SDSS g-band (363--583 nm, 89 points). *)
val sdss_r : Photometry.bandpass
(** [sdss_r] is the SDSS r-band (538--723 nm, 75 points). *)
val sdss_i : Photometry.bandpass
(** [sdss_i] is the SDSS i-band (643--863 nm, 89 points). *)
val sdss_z : Photometry.bandpass
(** [sdss_z] is the SDSS z-band (773--1123 nm, 141 points). *)
(** {1:johnson Johnson-Cousins UBVRI} *)
val johnson_u : Photometry.bandpass
(** [johnson_u] is the Johnson U-band (300--420 nm, 13 points). *)
val johnson_b : Photometry.bandpass
(** [johnson_b] is the Johnson B-band (370--560 nm, 11 points). *)
val johnson_v : Photometry.bandpass
(** [johnson_v] is the Johnson V-band (460--740 nm, 15 points). *)
val cousins_r : Photometry.bandpass
(** [cousins_r] is the Cousins R-band (540--800 nm, 53 points). *)
val cousins_i : Photometry.bandpass
(** [cousins_i] is the Cousins I-band (700--910 nm, 43 points). *)
(** {1:twomass 2MASS JHKs} *)
val twomass_j : Photometry.bandpass
(** [twomass_j] is the 2MASS J-band (1062--1450 nm, 107 points). *)
val twomass_h : Photometry.bandpass
(** [twomass_h] is the 2MASS H-band (1289--1914 nm, 58 points). *)
val twomass_ks : Photometry.bandpass
(** [twomass_ks] is the 2MASS Ks-band (1900--2399 nm, 76 points). *)
(** {1:gaia Gaia DR3} *)
val gaia_g : Photometry.bandpass
(** [gaia_g] is the Gaia DR3 G-band (330--1040 nm, 74 points). *)
val gaia_bp : Photometry.bandpass
(** [gaia_bp] is the Gaia DR3 BP-band (328--748 nm, 86 points). *)
val gaia_rp : Photometry.bandpass
(** [gaia_rp] is the Gaia DR3 RP-band (610--1080 nm, 95 points). *)
(** {1:rubin Rubin/LSST ugrizy} *)
val rubin_u : Photometry.bandpass
(** [rubin_u] is the Rubin/LSST u-band (320--409 nm, 60 points). *)
val rubin_g : Photometry.bandpass
(** [rubin_g] is the Rubin/LSST g-band (386--567 nm, 60 points). *)
val rubin_r : Photometry.bandpass
(** [rubin_r] is the Rubin/LSST r-band (537--706 nm, 60 points). *)
val rubin_i : Photometry.bandpass
(** [rubin_i] is the Rubin/LSST i-band (676--833 nm, 60 points). *)
val rubin_z : Photometry.bandpass
(** [rubin_z] is the Rubin/LSST z-band (803--939 nm, 60 points). *)
val rubin_y : Photometry.bandpass
(** [rubin_y] is the Rubin/LSST y-band (908--1099 nm, 60 points). *)
(** {1:euclid Euclid} *)
val euclid_vis : Photometry.bandpass
(** [euclid_vis] is the Euclid VIS-band (437--987 nm, 60 points). *)
val euclid_y : Photometry.bandpass
(** [euclid_y] is the Euclid NISP Y-band (933--1245 nm, 60 points). *)
val euclid_j : Photometry.bandpass
(** [euclid_j] is the Euclid NISP J-band (1141--1610 nm, 60 points). *)
val euclid_h : Photometry.bandpass
(** [euclid_h] is the Euclid NISP H-band (1480--2067 nm, 60 points). *)
================================================
FILE: dev/umbra/lib/fits/dune
================================================
(library
(name umbra_fits)
(public_name umbra.fits)
(private_modules fits_parser)
(libraries nx nx.io talon unix))
================================================
FILE: dev/umbra/lib/fits/fits_parser.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let err_truncated = "Fits: unexpected end of file"
let err_no_simple = "Fits: missing SIMPLE keyword in primary HDU"
let err_bad_tform msg = "Fits: unsupported TFORM: " ^ msg
let block_size = 2880
type keyword = { key : string; value : string; comment : string }
type header = {
keywords : keyword list;
xtension : string;
bitpix : int;
naxis : int array;
data_bytes : int;
}
type col_desc = {
name : string;
tform : char;
repeat : int;
width : int;
tnull : int64 option;
tscal : float;
tzero : float;
}
(* In-place byte-order reversal; FITS stores multi-byte binary values
big-endian, so they must be swapped before native little-endian reads. *)
let swap16 buf pos =
let b0 = Bytes.get_uint8 buf pos in
let b1 = Bytes.get_uint8 buf (pos + 1) in
Bytes.set_uint8 buf pos b1;
Bytes.set_uint8 buf (pos + 1) b0
let swap32 buf pos =
let b0 = Bytes.get_uint8 buf pos in
let b1 = Bytes.get_uint8 buf (pos + 1) in
let b2 = Bytes.get_uint8 buf (pos + 2) in
let b3 = Bytes.get_uint8 buf (pos + 3) in
Bytes.set_uint8 buf pos b3;
Bytes.set_uint8 buf (pos + 1) b2;
Bytes.set_uint8 buf (pos + 2) b1;
Bytes.set_uint8 buf (pos + 3) b0
let swap64 buf pos =
let b0 = Bytes.get_uint8 buf pos in
let b1 = Bytes.get_uint8 buf (pos + 1) in
let b2 = Bytes.get_uint8 buf (pos + 2) in
let b3 = Bytes.get_uint8 buf (pos + 3) in
let b4 = Bytes.get_uint8 buf (pos + 4) in
let b5 = Bytes.get_uint8 buf (pos + 5) in
let b6 = Bytes.get_uint8 buf (pos + 6) in
let b7 = Bytes.get_uint8 buf (pos + 7) in
Bytes.set_uint8 buf pos b7;
Bytes.set_uint8 buf (pos + 1) b6;
Bytes.set_uint8 buf (pos + 2) b5;
Bytes.set_uint8 buf (pos + 3) b4;
Bytes.set_uint8 buf (pos + 4) b3;
Bytes.set_uint8 buf (pos + 5) b2;
Bytes.set_uint8 buf (pos + 6) b1;
Bytes.set_uint8 buf (pos + 7) b0
let trim_right s =
let len = String.length s in
let i = ref (len - 1) in
while !i >= 0 && s.[!i] = ' ' do
decr i
done;
if !i = len - 1 then s else String.sub s 0 (!i + 1)
let parse_card card =
let key = trim_right (String.sub card 0 8) in
if key = "COMMENT" || key = "HISTORY" then
let content =
if String.length card > 8 then
trim_right (String.sub card 8 (String.length card - 8))
else ""
in
{ key; value = content; comment = "" }
else if String.length card < 10 || card.[8] <> '=' || card.[9] <> ' ' then
{ key; value = ""; comment = "" }
else
let rest = String.sub card 10 (String.length card - 10) in
let rest = String.trim rest in
if String.length rest > 0 && rest.[0] = '\'' then begin
let len = String.length rest in
let i = ref 1 in
let buf = Buffer.create 68 in
while !i < len do
if rest.[!i] = '\'' then
begin if !i + 1 < len && rest.[!i + 1] = '\'' then begin
Buffer.add_char buf '\'';
i := !i + 2
end
else i := len
end
else begin
Buffer.add_char buf rest.[!i];
i := !i + 1
end
done;
{ key; value = trim_right (Buffer.contents buf); comment = "" }
end
else
begin match String.index_opt rest '/' with
| Some i ->
let value = trim_right (String.sub rest 0 i) in
let comment =
String.trim (String.sub rest (i + 1) (String.length rest - i - 1))
in
{ key; value; comment }
| None -> { key; value = trim_right rest; comment = "" }
end
(* Read 2880-byte header blocks, collecting 80-character cards until the
END card, and return the parsed keywords in file order. *)
let read_one_header ic =
let keywords = ref [] in
let found_end = ref false in
let card_buf = Bytes.create 80 in
while not !found_end do
let block = Bytes.create block_size in
(match In_channel.really_input ic block 0 block_size with
| None -> failwith err_truncated
| Some () -> ());
for card_i = 0 to 35 do
if not !found_end then begin
Bytes.blit block (card_i * 80) card_buf 0 80;
let card = Bytes.to_string card_buf in
let key = trim_right (String.sub card 0 8) in
if key = "END" then found_end := true
else if key <> "" then keywords := parse_card card :: !keywords
end
done
done;
List.rev !keywords
let find_keyword keywords key =
match List.find_opt (fun kw -> kw.key = key) keywords with
| Some kw -> Some kw.value
| None -> None
let find_keyword_int keywords key =
match find_keyword keywords key with
| Some v -> Some (int_of_string (String.trim v))
| None -> None
let find_keyword_exn keywords key =
match find_keyword keywords key with
| Some v -> v
| None -> failwith ("Fits: missing required keyword " ^ key)
let find_keyword_int_exn keywords key =
int_of_string (String.trim (find_keyword_exn keywords key))
let compute_data_bytes keywords =
let bitpix = find_keyword_int_exn keywords "BITPIX" in
let naxis_n = find_keyword_int_exn keywords "NAXIS" in
if naxis_n = 0 then 0
else begin
let total = ref (abs bitpix / 8) in
for i = 1 to naxis_n do
let key = Printf.sprintf "NAXIS%d" i in
total := !total * find_keyword_int_exn keywords key
done;
let pcount =
match find_keyword_int keywords "PCOUNT" with Some v -> v | None -> 0
in
let gcount =
match find_keyword_int keywords "GCOUNT" with Some v -> v | None -> 1
in
(* Per the FITS standard, the data size in bytes is
|BITPIX|/8 * GCOUNT * (PCOUNT + NAXIS1 * ... * NAXISn);
PCOUNT is counted in |BITPIX|-bit units, so scale it as well. *)
(!total + (abs bitpix / 8 * pcount)) * gcount
end
let build_header keywords =
let bitpix = find_keyword_int_exn keywords "BITPIX" in
let naxis_n = find_keyword_int_exn keywords "NAXIS" in
let naxis =
Array.init naxis_n (fun i ->
find_keyword_int_exn keywords (Printf.sprintf "NAXIS%d" (i + 1)))
in
let xtension =
match find_keyword keywords "XTENSION" with Some v -> v | None -> ""
in
let data_bytes = compute_data_bytes keywords in
{ keywords; xtension; bitpix; naxis; data_bytes }
(* Scan the whole file and return one header per HDU; each data section
is skipped in whole 2880-byte blocks. *)
let read_headers ic =
In_channel.seek ic 0L;
let headers = ref [] in
let first = ref true in
let continue = ref true in
while !continue do
let keywords = try Some (read_one_header ic) with Failure _ -> None in
match keywords with
| None -> continue := false
| Some keywords ->
if !first then begin
first := false;
match find_keyword keywords "SIMPLE" with
| Some _ -> ()
| None -> failwith err_no_simple
end;
let hdr = build_header keywords in
headers := hdr :: !headers;
let data_blocks =
if hdr.data_bytes = 0 then 0
else (hdr.data_bytes + block_size - 1) / block_size
in
In_channel.seek ic
(Int64.add (In_channel.pos ic)
(Int64.of_int (data_blocks * block_size)))
done;
List.rev !headers
(* Position [ic] at the start of the data section of HDU [hdu] and return
that section's size in bytes. *)
let seek_to_data ic headers hdu =
if hdu < 0 || hdu >= List.length headers then
failwith
(Printf.sprintf "Fits: HDU %d out of range (file has %d)" hdu
(List.length headers));
In_channel.seek ic 0L;
for i = 0 to hdu do
let h = List.nth headers i in
let found_end = ref false in
while not !found_end do
let block = Bytes.create block_size in
(match In_channel.really_input ic block 0 block_size with
| None -> failwith err_truncated
| Some () -> ());
for card_i = 0 to 35 do
if not !found_end then begin
let key = trim_right (Bytes.sub_string block (card_i * 80) 8) in
if key = "END" then found_end := true
end
done
done;
if i < hdu then begin
let data_blocks =
if h.data_bytes = 0 then 0
else (h.data_bytes + block_size - 1) / block_size
in
In_channel.seek ic
(Int64.add (In_channel.pos ic)
(Int64.of_int (data_blocks * block_size)))
end
done;
let h = List.nth headers hdu in
h.data_bytes
(* Parse a binary-table TFORMn value of the form [rT]: an optional repeat
count followed by a single type code. Returns (code, repeat, width),
where width is the element size in bytes. *)
let parse_tform s =
let s = String.trim s in
let len = String.length s in
if len = 0 then failwith (err_bad_tform "empty");
let i = ref 0 in
while !i < len && s.[!i] >= '0' && s.[!i] <= '9' do
incr i
done;
let repeat = if !i = 0 then 1 else int_of_string (String.sub s 0 !i) in
if !i >= len then failwith (err_bad_tform s);
let code = s.[!i] in
let width =
match code with
| 'L' -> 1
| 'B' -> 1
| 'I' -> 2
| 'J' -> 4
| 'K' -> 8
| 'E' -> 4
| 'D' -> 8
| 'A' -> 1
| c -> failwith (err_bad_tform (String.make 1 c))
in
(code, repeat, width)
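(* Illustrative TFORM parses (example values, not from a real header):
parse_tform "12A" = ('A', 12, 1) (* 12-character string field *)
parse_tform "E" = ('E', 1, 4) (* one 32-bit float *)
parse_tform "2D" = ('D', 2, 8) (* two 64-bit floats *) *)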
(* Build one [col_desc] per TFIELDS column from the TTYPEn, TFORMn,
TNULLn, TSCALn, and TZEROn keywords, with FITS-standard defaults. *)
let parse_bintable_cols hdr =
let keywords = hdr.keywords in
let tfields = find_keyword_int_exn keywords "TFIELDS" in
List.init tfields (fun i ->
let col = i + 1 in
let name =
match find_keyword keywords (Printf.sprintf "TTYPE%d" col) with
| Some v -> v
| None -> Printf.sprintf "col%d" col
in
let tform_s = find_keyword_exn keywords (Printf.sprintf "TFORM%d" col) in
let tform, repeat, width = parse_tform tform_s in
let tnull =
match find_keyword keywords (Printf.sprintf "TNULL%d" col) with
| Some v -> Some (Int64.of_string (String.trim v))
| None -> None
in
let tscal =
match find_keyword keywords (Printf.sprintf "TSCAL%d" col) with
| Some v -> float_of_string (String.trim v)
| None -> 1.0
in
let tzero =
match find_keyword keywords (Printf.sprintf "TZERO%d" col) with
| Some v -> float_of_string (String.trim v)
| None -> 0.0
in
{ name; tform; repeat; width; tnull; tscal; tzero })
================================================
FILE: dev/umbra/lib/fits/fits_parser.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(**/**)
(** Internal FITS parser. *)
type keyword = { key : string; value : string; comment : string }
type header = {
keywords : keyword list;
xtension : string;
bitpix : int;
naxis : int array;
data_bytes : int;
}
type col_desc = {
name : string;
tform : char;
repeat : int;
width : int;
tnull : int64 option;
tscal : float;
tzero : float;
}
val read_headers : In_channel.t -> header list
val seek_to_data : In_channel.t -> header list -> int -> int
val parse_bintable_cols : header -> col_desc list
val find_keyword : keyword list -> string -> string option
val find_keyword_int : keyword list -> string -> int option
val trim_right : string -> string
val block_size : int
val swap16 : bytes -> int -> unit
val swap32 : bytes -> int -> unit
val swap64 : bytes -> int -> unit
(**/**)
================================================
FILE: dev/umbra/lib/fits/umbra_fits.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let err_not_bintable = "Fits.read_table: HDU is not a BINTABLE"
let err_not_image = "Fits.read_image: HDU is not an image"
let err_unsupported_bitpix n = Printf.sprintf "Fits: unsupported BITPIX %d" n
let err_truncated_data = "Fits: unexpected end of file in data"
type header_card = { key : string; value : string; comment : string }
type hdu_type = Primary | Image | Bintable | Ascii_table
type hdu_info = {
index : int;
hdu_type : hdu_type;
dimensions : int array;
num_rows : int option;
num_cols : int option;
}
let hdu_type_of_header i (hdr : Fits_parser.header) =
match hdr.xtension with
| "" -> if i = 0 then Primary else Image
| "BINTABLE" -> Bintable
| "TABLE" -> Ascii_table
| "IMAGE" -> Image
| _ -> Image
let read_input ic buf n =
match In_channel.really_input ic buf 0 n with
| None -> failwith err_truncated_data
| Some () -> ()
let info path =
let ic = In_channel.open_bin path in
Fun.protect
~finally:(fun () -> In_channel.close ic)
(fun () ->
let headers = Fits_parser.read_headers ic in
List.mapi
(fun i (hdr : Fits_parser.header) ->
let ht = hdu_type_of_header i hdr in
let num_rows, num_cols =
match ht with
| Bintable | Ascii_table ->
let nrows =
if Array.length hdr.naxis >= 2 then Some hdr.naxis.(1)
else None
in
let ncols =
Fits_parser.find_keyword_int hdr.keywords "TFIELDS"
in
(nrows, ncols)
| _ -> (None, None)
in
{
index = i;
hdu_type = ht;
dimensions = hdr.naxis;
num_rows;
num_cols;
})
headers)
let header ?(hdu = 0) path =
let ic = In_channel.open_bin path in
Fun.protect
~finally:(fun () -> In_channel.close ic)
(fun () ->
let headers = Fits_parser.read_headers ic in
if hdu < 0 || hdu >= List.length headers then
failwith (Printf.sprintf "Fits.header: HDU %d out of range" hdu);
let h = List.nth headers hdu in
List.map
(fun (kw : Fits_parser.keyword) ->
{ key = kw.key; value = kw.value; comment = kw.comment })
h.keywords)
let read_table ?(hdu = 1) path =
let ic = In_channel.open_bin path in
Fun.protect
~finally:(fun () -> In_channel.close ic)
(fun () ->
let headers = Fits_parser.read_headers ic in
if hdu < 0 || hdu >= List.length headers then
failwith (Printf.sprintf "Fits.read_table: HDU %d out of range" hdu);
let h = List.nth headers hdu in
(match hdu_type_of_header hdu h with
| Bintable -> ()
| _ -> failwith err_not_bintable);
let cols = Fits_parser.parse_bintable_cols h in
let nrows = if Array.length h.naxis >= 2 then h.naxis.(1) else 0 in
let row_bytes = h.naxis.(0) in
let (_ : int) = Fits_parser.seek_to_data ic headers hdu in
let row_buf = Bytes.create row_bytes in
let col_info =
List.map
(fun (cd : Fits_parser.col_desc) ->
let elem_bytes = cd.repeat * cd.width in
(cd, Bytes.create (nrows * elem_bytes), elem_bytes))
cols
in
let col_offsets =
let off = ref 0 in
List.map
(fun (cd : Fits_parser.col_desc) ->
let o = !off in
off := !off + (cd.repeat * cd.width);
o)
cols
in
for row = 0 to nrows - 1 do
read_input ic row_buf row_bytes;
List.iter2
(fun offset (_cd, buf, elem_bytes) ->
Bytes.blit row_buf offset buf (row * elem_bytes) elem_bytes)
col_offsets col_info
done;
let err_vector name repeat =
failwith
(Printf.sprintf "Fits: vector column '%s' (repeat=%d) not supported"
name repeat)
in
let talon_cols =
List.map
(fun (cd, buf, _) ->
let col =
match cd.Fits_parser.tform with
| 'E' ->
if cd.repeat <> 1 then err_vector cd.name cd.repeat;
Talon.Col.float32
(Array.init nrows (fun i ->
let pos = i * 4 in
Fits_parser.swap32 buf pos;
let v =
Int32.float_of_bits (Bytes.get_int32_le buf pos)
in
if cd.tzero = 0.0 && cd.tscal = 1.0 then v
else (v *. cd.tscal) +. cd.tzero))
| 'D' ->
if cd.repeat <> 1 then err_vector cd.name cd.repeat;
Talon.Col.float64
(Array.init nrows (fun i ->
let pos = i * 8 in
Fits_parser.swap64 buf pos;
let v =
Int64.float_of_bits (Bytes.get_int64_le buf pos)
in
if cd.tzero = 0.0 && cd.tscal = 1.0 then v
else (v *. cd.tscal) +. cd.tzero))
| 'J' ->
if cd.repeat <> 1 then err_vector cd.name cd.repeat;
Talon.Col.int32
(Array.init nrows (fun i ->
let pos = i * 4 in
Fits_parser.swap32 buf pos;
let v = Bytes.get_int32_le buf pos in
if cd.tzero = 0.0 && cd.tscal = 1.0 then v
else
Int32.of_float
((Int32.to_float v *. cd.tscal) +. cd.tzero)))
| 'K' ->
if cd.repeat <> 1 then err_vector cd.name cd.repeat;
Talon.Col.int64
(Array.init nrows (fun i ->
let pos = i * 8 in
Fits_parser.swap64 buf pos;
let v = Bytes.get_int64_le buf pos in
if cd.tzero = 0.0 && cd.tscal = 1.0 then v
else
Int64.of_float
((Int64.to_float v *. cd.tscal) +. cd.tzero)))
| 'I' ->
if cd.repeat <> 1 then err_vector cd.name cd.repeat;
Talon.Col.int32
(Array.init nrows (fun i ->
let pos = i * 2 in
Fits_parser.swap16 buf pos;
let v = Bytes.get_int16_le buf pos in
if cd.tzero = 0.0 && cd.tscal = 1.0 then Int32.of_int v
else
Int32.of_float
((Float.of_int v *. cd.tscal) +. cd.tzero)))
| 'B' ->
if cd.repeat <> 1 then err_vector cd.name cd.repeat;
Talon.Col.int32
(Array.init nrows (fun i ->
let v = Bytes.get_uint8 buf i in
if cd.tzero = 0.0 && cd.tscal = 1.0 then Int32.of_int v
else
Int32.of_float
((Float.of_int v *. cd.tscal) +. cd.tzero)))
| 'L' ->
if cd.repeat <> 1 then err_vector cd.name cd.repeat;
Talon.Col.bool
(Array.init nrows (fun i ->
let c = Bytes.get buf i in
c = 'T' || c = '\x01'))
| 'A' ->
Talon.Col.string
(Array.init nrows (fun i ->
Fits_parser.trim_right
(Bytes.sub_string buf (i * cd.repeat) cd.repeat)))
| c -> failwith (Printf.sprintf "Fits: unsupported TFORM '%c'" c)
in
(cd.name, col))
col_info
in
Talon.create talon_cols)
let find_keyword_float keywords key =
match Fits_parser.find_keyword keywords key with
| Some v -> Some (float_of_string (String.trim v))
| None -> None
let read_image ?(hdu = 0) path =
let ic = In_channel.open_bin path in
Fun.protect
~finally:(fun () -> In_channel.close ic)
(fun () ->
let headers = Fits_parser.read_headers ic in
if hdu < 0 || hdu >= List.length headers then
failwith (Printf.sprintf "Fits.read_image: HDU %d out of range" hdu);
let h = List.nth headers hdu in
(match hdu_type_of_header hdu h with
| Primary | Image -> ()
| _ -> failwith err_not_image);
let bscale =
match find_keyword_float h.keywords "BSCALE" with
| Some v -> v
| None -> 1.0
in
let bzero =
match find_keyword_float h.keywords "BZERO" with
| Some v -> v
| None -> 0.0
in
let has_scaling = bscale <> 1.0 || bzero <> 0.0 in
let (_ : int) = Fits_parser.seek_to_data ic headers hdu in
let shape = Array.to_list h.naxis |> List.rev |> Array.of_list in
let total = Array.fold_left ( * ) 1 shape in
let apply_scaling raw =
Nx.add_s (Nx.mul_s (Nx.astype Nx.float64 raw) bscale) bzero
in
match h.bitpix with
| 8 ->
let buf = Bytes.create total in
read_input ic buf total;
let raw =
Nx.create Nx.uint8 shape
(Array.init total (fun i -> Bytes.get_uint8 buf i))
in
if has_scaling then Nx_io.P (apply_scaling raw) else Nx_io.P raw
| 16 ->
let buf = Bytes.create (total * 2) in
read_input ic buf (total * 2);
let raw =
Nx.create Nx.int16 shape
(Array.init total (fun i ->
let pos = i * 2 in
Fits_parser.swap16 buf pos;
Bytes.get_int16_le buf pos))
in
if has_scaling then Nx_io.P (apply_scaling raw) else Nx_io.P raw
| 32 ->
let buf = Bytes.create (total * 4) in
read_input ic buf (total * 4);
let raw =
Nx.create Nx.int32 shape
(Array.init total (fun i ->
let pos = i * 4 in
Fits_parser.swap32 buf pos;
Bytes.get_int32_le buf pos))
in
if has_scaling then Nx_io.P (apply_scaling raw) else Nx_io.P raw
| 64 ->
let buf = Bytes.create (total * 8) in
read_input ic buf (total * 8);
let raw =
Nx.create Nx.int64 shape
(Array.init total (fun i ->
let pos = i * 8 in
Fits_parser.swap64 buf pos;
Bytes.get_int64_le buf pos))
in
if has_scaling then Nx_io.P (apply_scaling raw) else Nx_io.P raw
| -32 ->
let buf = Bytes.create (total * 4) in
read_input ic buf (total * 4);
let raw =
Nx.create Nx.float32 shape
(Array.init total (fun i ->
let pos = i * 4 in
Fits_parser.swap32 buf pos;
Int32.float_of_bits (Bytes.get_int32_le buf pos)))
in
if has_scaling then Nx_io.P (apply_scaling raw) else Nx_io.P raw
| -64 ->
let buf = Bytes.create (total * 8) in
read_input ic buf (total * 8);
let raw =
Nx.create Nx.float64 shape
(Array.init total (fun i ->
let pos = i * 8 in
Fits_parser.swap64 buf pos;
Int64.float_of_bits (Bytes.get_int64_le buf pos)))
in
if has_scaling then Nx_io.P (apply_scaling raw) else Nx_io.P raw
| n -> failwith (err_unsupported_bitpix n))
let pad_to_block oc written =
let rem = written mod Fits_parser.block_size in
if rem > 0 then
output_string oc (String.make (Fits_parser.block_size - rem) '\x00')
let write_card oc key value =
let card = Bytes.make 80 ' ' in
Bytes.blit_string key 0 card 0 (Int.min 8 (String.length key));
Bytes.set card 8 '=';
Bytes.set card 9 ' ';
(* Keep the caller's padding: fixed-format cards right-justify numeric values
   so they end in byte 30, which [String.trim] would undo. *)
Bytes.blit_string value 0 card 10 (Int.min 70 (String.length value));
output_bytes oc card
let write_card_str oc key value =
write_card oc key (Printf.sprintf "'%-8s'" value)
let write_card_int oc key value =
write_card oc key (Printf.sprintf "%20d" value)
let write_end oc cards_written =
let card = Bytes.make 80 ' ' in
Bytes.blit_string "END" 0 card 0 3;
output_bytes oc card;
let total_cards = cards_written + 1 in
let rem = total_cards * 80 mod Fits_parser.block_size in
if rem > 0 then
output_string oc (String.make (Fits_parser.block_size - rem) ' ')
let write_empty_primary oc =
(* Fixed-format logical value: right-justified so "T" lands in byte 30, as the
   FITS standard requires for mandatory keywords. *)
write_card oc "SIMPLE" (Printf.sprintf "%20s" "T");
write_card_int oc "BITPIX" 8;
write_card_int oc "NAXIS" 0;
write_end oc 3
let write_image_typed (type a b) ?(overwrite = true) path (tensor : (a, b) Nx.t)
=
if (not overwrite) && Sys.file_exists path then
failwith ("Fits.write_image: file exists: " ^ path);
let oc = Out_channel.open_bin path in
Fun.protect
~finally:(fun () -> Out_channel.close oc)
(fun () ->
let shape = Nx.shape tensor in
let ndim = Array.length shape in
let fits_shape = Array.init ndim (fun i -> shape.(ndim - 1 - i)) in
let total = Nx.numel tensor in
let dt = Nx.dtype_to_string (Nx.dtype tensor) in
let bitpix, elem_bytes =
match dt with
| "uint8" -> (8, 1)
| "int16" -> (16, 2)
| "int32" -> (32, 4)
| "int64" -> (64, 8)
| "float32" -> (-32, 4)
| "float64" -> (-64, 8)
| s -> failwith ("Fits.write_image: unsupported dtype " ^ s)
in
write_card oc "SIMPLE" (Printf.sprintf "%20s" "T");
write_card_int oc "BITPIX" bitpix;
write_card_int oc "NAXIS" ndim;
for i = 0 to ndim - 1 do
write_card_int oc (Printf.sprintf "NAXIS%d" (i + 1)) fits_shape.(i)
done;
write_end oc (3 + ndim);
let flat = Nx.reshape [| total |] tensor in
let arr = Nx.to_array flat in
let data_bytes = total * elem_bytes in
let buf = Bytes.create data_bytes in
(* [dt] was derived from the tensor's dtype, so each [Obj.magic] below merely
   recovers the concrete element type the matched branch guarantees. *)
(match dt with
| "uint8" ->
Array.iteri
(fun i (v : a) -> Bytes.set_uint8 buf i (Obj.magic v : int))
arr
| "int16" ->
Array.iteri
(fun i (v : a) ->
let pos = i * 2 in
Bytes.set_int16_le buf pos (Obj.magic v : int);
Fits_parser.swap16 buf pos)
arr
| "int32" ->
Array.iteri
(fun i (v : a) ->
let pos = i * 4 in
Bytes.set_int32_le buf pos (Obj.magic v : int32);
Fits_parser.swap32 buf pos)
arr
| "int64" ->
Array.iteri
(fun i (v : a) ->
let pos = i * 8 in
Bytes.set_int64_le buf pos (Obj.magic v : int64);
Fits_parser.swap64 buf pos)
arr
| "float32" ->
Array.iteri
(fun i (v : a) ->
let pos = i * 4 in
Bytes.set_int32_le buf pos
(Int32.bits_of_float (Obj.magic v : float));
Fits_parser.swap32 buf pos)
arr
| "float64" ->
Array.iteri
(fun i (v : a) ->
let pos = i * 8 in
Bytes.set_int64_le buf pos
(Int64.bits_of_float (Obj.magic v : float));
Fits_parser.swap64 buf pos)
arr
| _ -> assert false);
output_bytes oc buf;
pad_to_block oc data_bytes)
let write_image ?overwrite path tensor =
write_image_typed ?overwrite path tensor
let write_table ?(overwrite = true) path df =
if (not overwrite) && Sys.file_exists path then
failwith ("Fits.write_table: file exists: " ^ path);
let oc = Out_channel.open_bin path in
Fun.protect
~finally:(fun () -> Out_channel.close oc)
(fun () ->
write_empty_primary oc;
let col_names = Talon.column_names df in
let nrows = Talon.num_rows df in
let ncols = List.length col_names in
let col_info =
List.map
(fun name ->
let col = Talon.get_column_exn df name in
match Talon.Col.dtype col with
| `Float32 -> (name, col, "1E", 4)
| `Float64 -> (name, col, "1D", 8)
| `Int32 -> (name, col, "1J", 4)
| `Int64 -> (name, col, "1K", 8)
| `String -> (
match Talon.to_string_array df name with
| Some arr ->
let maxlen =
Array.fold_left
(fun acc v ->
match v with
| Some s -> max acc (String.length s)
| None -> acc)
1 arr
in
(name, col, Printf.sprintf "%dA" maxlen, maxlen)
| None -> failwith "Fits.write_table: string column missing")
| `Bool -> (name, col, "1L", 1)
| `Other -> failwith "Fits.write_table: unsupported dtype")
col_names
in
let row_bytes =
List.fold_left (fun acc (_, _, _, eb) -> acc + eb) 0 col_info
in
write_card_str oc "XTENSION" "BINTABLE";
write_card_int oc "BITPIX" 8;
write_card_int oc "NAXIS" 2;
write_card_int oc "NAXIS1" row_bytes;
write_card_int oc "NAXIS2" nrows;
write_card_int oc "PCOUNT" 0;
write_card_int oc "GCOUNT" 1;
write_card_int oc "TFIELDS" ncols;
let cards = ref 8 in
List.iteri
(fun i (name, _col, tform, _eb) ->
let n = i + 1 in
write_card_str oc (Printf.sprintf "TTYPE%d" n) name;
write_card_str oc (Printf.sprintf "TFORM%d" n) tform;
cards := !cards + 2)
col_info;
write_end oc !cards;
let col_arrays =
List.map
(fun (name, col, _tform, _eb) ->
match Talon.Col.dtype col with
| `Float32 -> (
match Talon.to_array Nx.float32 df name with
| Some a -> `F32 a
| None -> assert false)
| `Float64 -> (
match Talon.to_array Nx.float64 df name with
| Some a -> `F64 a
| None -> assert false)
| `Int32 -> (
match Talon.to_array Nx.int32 df name with
| Some a -> `I32 a
| None -> assert false)
| `Int64 -> (
match Talon.to_array Nx.int64 df name with
| Some a -> `I64 a
| None -> assert false)
| `String -> (
match Talon.to_string_array df name with
| Some a -> `Str a
| None -> assert false)
| `Bool -> (
match Talon.to_bool_array df name with
| Some a -> `Bool a
| None -> assert false)
| `Other -> failwith "Fits.write_table: unsupported dtype")
col_info
in
let row_buf = Bytes.create row_bytes in
for row = 0 to nrows - 1 do
let off = ref 0 in
List.iter2
(fun (_, _, _, eb) col_arr ->
(match col_arr with
| `F32 arr ->
Bytes.set_int32_le row_buf !off (Int32.bits_of_float arr.(row));
Fits_parser.swap32 row_buf !off
| `F64 arr ->
Bytes.set_int64_le row_buf !off (Int64.bits_of_float arr.(row));
Fits_parser.swap64 row_buf !off
| `I32 arr ->
Bytes.set_int32_le row_buf !off arr.(row);
Fits_parser.swap32 row_buf !off
| `I64 arr ->
Bytes.set_int64_le row_buf !off arr.(row);
Fits_parser.swap64 row_buf !off
| `Str arr -> (
Bytes.fill row_buf !off eb ' ';
match arr.(row) with
| Some s ->
let len = Int.min eb (String.length s) in
Bytes.blit_string s 0 row_buf !off len
| None -> ())
| `Bool arr ->
let v = match arr.(row) with Some true -> 'T' | _ -> 'F' in
Bytes.set row_buf !off v);
off := !off + eb)
col_info col_arrays;
output_bytes oc row_buf
done;
pad_to_block oc (nrows * row_bytes))
================================================
FILE: dev/umbra/lib/fits/umbra_fits.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** FITS file I/O.
Reads and writes {{:https://fits.gsfc.nasa.gov/fits_standard.html}FITS}
files. Binary tables are loaded into {!Talon.t} dataframes and images into
{!Nx.t} tensors. All data is converted from FITS big-endian on read and
written as big-endian on write. *)
(** {1:inspect Inspection} *)
(** The type for FITS header data unit kinds. *)
type hdu_type =
| Primary (** Primary HDU. *)
| Image (** Image extension. *)
| Bintable (** Binary table extension. *)
| Ascii_table (** ASCII table extension. *)
type hdu_info = {
index : int; (** Zero-based HDU index. *)
hdu_type : hdu_type; (** Kind of HDU. *)
dimensions : int array; (** NAXIS values. *)
num_rows : int option; (** Row count for table HDUs. *)
num_cols : int option; (** Column count for table HDUs. *)
}
(** The type for HDU summary information. *)
type header_card = {
key : string; (** Keyword name (up to 8 characters). *)
value : string; (** Parsed value string. *)
comment : string; (** Inline comment, if any. *)
}
(** The type for FITS header cards. *)
val info : string -> hdu_info list
(** [info path] is the summary information for every HDU in the FITS file at
[path].
Raises [Failure] if [path] cannot be read or is not a valid FITS file. *)
val header : ?hdu:int -> string -> header_card list
(** [header path] is the header cards for HDU [hdu] in the FITS file at [path],
including COMMENT and HISTORY cards.
[hdu] defaults to [0] (primary HDU).
Raises [Failure] if [hdu] is out of range. *)
(** {1:reading Reading} *)
val read_table : ?hdu:int -> string -> Talon.t
(** [read_table path] reads a BINTABLE extension into a dataframe.
[hdu] defaults to [1] (first extension).
Supported TFORM types: [E] (float32), [D] (float64), [J] (int32), [K]
(int64), [I] (int16), [B] (uint8), [L] (logical), [A] (string). Vector
columns (repeat > 1) are not supported except for strings. TSCAL and TZERO
are applied when present.
Raises [Failure] if the HDU is not a BINTABLE, [hdu] is out of range, or a
column has an unsupported TFORM type. *)
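(* Usage sketch (the file name is hypothetical; the module may be exposed
   under a different path, e.g. [Fits]):
     let df = Umbra_fits.read_table "catalog.fits" in
     Talon.num_rows df *)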
val read_image : ?hdu:int -> string -> Nx_io.packed
(** [read_image path] reads an image HDU into a packed {!Nx.t} tensor.
[hdu] defaults to [0] (primary HDU).
Supported BITPIX values: [8], [16], [32], [64], [-32], [-64].
When BSCALE or BZERO header cards are present with non-trivial values
(BSCALE != 1.0 or BZERO != 0.0), the physical values [BZERO + BSCALE * raw]
are computed and the result is returned as float64 regardless of the
original BITPIX. When neither card is present or both have default values,
the raw data type is preserved.
Raises [Failure] if the HDU is not an image, [hdu] is out of range, or
BITPIX is unsupported. *)
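(* Usage sketch: the result is packed over the element type, so unpack it in a
   match before use (file name hypothetical):
     match Umbra_fits.read_image "image.fits" with
     | Nx_io.P img -> Nx.shape img *)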
(** {1:writing Writing} *)
val write_table : ?overwrite:bool -> string -> Talon.t -> unit
(** [write_table path df] writes [df] as a single BINTABLE extension preceded by
an empty primary HDU.
[overwrite] defaults to [true].
Raises [Failure] if [overwrite] is [false] and [path] already exists. *)
val write_image : ?overwrite:bool -> string -> ('a, 'b) Nx.t -> unit
(** [write_image path tensor] writes [tensor] as a primary image HDU.
[overwrite] defaults to [true].
Supported dtypes: uint8, int16, int32, int64, float32, float64.
Raises [Failure] if [overwrite] is [false] and [path] already exists, or if
the dtype is unsupported. *)
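(* Round-trip sketch (the path is hypothetical):
     let img = Nx.zeros Nx.float32 [| 64; 64 |] in
     Umbra_fits.write_image "/tmp/obs.fits" img;
     let cards = Umbra_fits.header "/tmp/obs.fits" in
     List.length cards *)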
================================================
FILE: dev/umbra/lib/galactocentric.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let pi = Float.pi
let f64 = Nx.float64
let galcen_distance_default = Unit.Length.of_kpc (Nx.scalar f64 8.122)
let z_sun_default = Unit.Length.of_kpc (Nx.scalar f64 0.0208)
type t = { x : Nx.float64_t; y : Nx.float64_t; z : Nx.float64_t }
let x t = Unit.Length.of_kpc t.x
let y t = Unit.Length.of_kpc t.y
let z t = Unit.Length.of_kpc t.z
(* Convert via Galactic coordinates. In Galactic (l, b) the GC is at l=0, b=0,
   so the heliocentric Galactic Cartesian coordinates are:
     x_h = d cos(b) cos(l)   (toward the GC)
     y_h = d cos(b) sin(l)   (toward Galactic rotation)
     z_h = d sin(b)          (toward the NGP)
   Galactocentric is heliocentric shifted by the Sun's position:
     x_gc = x_h - galcen_distance
     y_gc = y_h
     z_gc = z_h + z_sun *)
let of_coord ?(galcen_distance = galcen_distance_default)
?(z_sun = z_sun_default) ~distance c =
let galcen_distance = Nx.item [] (Unit.Length.in_kpc galcen_distance) in
let z_sun = Nx.item [] (Unit.Length.in_kpc z_sun) in
let gal = Coord.galactic c in
let l_rad = Unit.Angle.to_tensor (Coord.lon gal) in
let b_rad = Unit.Angle.to_tensor (Coord.lat gal) in
let d_kpc = Unit.Length.in_kpc distance in
let n = Nx.numel l_rad in
let x_out = Nx.zeros Nx.float64 [| n |] in
let y_out = Nx.zeros Nx.float64 [| n |] in
let z_out = Nx.zeros Nx.float64 [| n |] in
for i = 0 to n - 1 do
let l = Nx.item [ i ] l_rad in
let b = Nx.item [ i ] b_rad in
let d = Nx.item [ i ] d_kpc in
let cb = Float.cos b in
let xh = d *. cb *. Float.cos l in
let yh = d *. cb *. Float.sin l in
let zh = d *. Float.sin b in
Nx.set_item [ i ] (xh -. galcen_distance) x_out;
Nx.set_item [ i ] yh y_out;
Nx.set_item [ i ] (zh +. z_sun) z_out
done;
{ x = x_out; y = y_out; z = z_out }
let to_coord ?(galcen_distance = galcen_distance_default)
?(z_sun = z_sun_default) t =
let galcen_distance = Nx.item [] (Unit.Length.in_kpc galcen_distance) in
let z_sun = Nx.item [] (Unit.Length.in_kpc z_sun) in
let n = Nx.numel t.x in
let l_out = Nx.zeros Nx.float64 [| n |] in
let b_out = Nx.zeros Nx.float64 [| n |] in
let d_out = Nx.zeros Nx.float64 [| n |] in
for i = 0 to n - 1 do
let xg = Nx.item [ i ] t.x in
let yg = Nx.item [ i ] t.y in
let zg = Nx.item [ i ] t.z in
let xh = xg +. galcen_distance in
let yh = yg in
let zh = zg -. z_sun in
let d = Float.sqrt ((xh *. xh) +. (yh *. yh) +. (zh *. zh)) in
let b = Float.asin (Float.max ~-.1.0 (Float.min 1.0 (zh /. d))) in
let l = Float.atan2 yh xh in
let l = if l < 0.0 then l +. (2.0 *. pi) else l in
Nx.set_item [ i ] l l_out;
Nx.set_item [ i ] b b_out;
Nx.set_item [ i ] d d_out
done;
let coord =
Coord.of_galactic
~l:(Unit.Angle.of_tensor l_out)
~b:(Unit.Angle.of_tensor b_out)
in
let distance = Unit.Length.of_kpc d_out in
(coord, distance)
================================================
FILE: dev/umbra/lib/galactocentric.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Galactocentric Cartesian coordinates.
Converts celestial positions with distances to a right-handed Cartesian
frame centered on the Galactic center. The x-axis points from the Sun toward
the Galactic center (l=0, b=0), y in the direction of Galactic rotation, z
toward the North Galactic Pole.
Coordinates go through the Galactic frame (ICRS {e ->} Galactic {e ->}
heliocentric Cartesian {e ->} Galactocentric). The Galactic center position
is defined by the IAU Galactic coordinate system (l=0, b=0).
Default parameters follow
{{:https://ui.adsabs.harvard.edu/abs/2018A%26A...615L..15G}GRAVITY
Collaboration (2018)} for the Galactic center distance.
{[
let star =
Coord.of_radec
~ra:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 266.0 |]))
~dec:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| -29.0 |]))
in
let gc =
Galactocentric.of_coord
~distance:(Unit.Length.of_kpc (Nx.create f64 [| 1 |] [| 8.0 |]))
star
in
let x_kpc = Nx.item [ 0 ] (Unit.Length.in_kpc (Galactocentric.x gc))
]} *)
(** {1:coords Coordinates} *)
type t
(** The type for Galactocentric Cartesian positions. *)
val x : t -> Unit.length Unit.t
(** [x t] is the x coordinate (toward the Galactic center). *)
val y : t -> Unit.length Unit.t
(** [y t] is the y coordinate (direction of Galactic rotation). *)
val z : t -> Unit.length Unit.t
(** [z t] is the z coordinate (toward the North Galactic Pole). *)
(** {1:converting Converting} *)
val of_coord :
?galcen_distance:Unit.length Unit.t ->
?z_sun:Unit.length Unit.t ->
distance:Unit.length Unit.t ->
Coord.t ->
t
(** [of_coord ~distance c] converts celestial coordinates [c] with [distance] to
Galactocentric Cartesian. Not differentiable (scalar-level trigonometry).
[galcen_distance] is the Sun-GC distance (defaults to 8.122 kpc, GRAVITY
Collaboration 2018). [z_sun] is the Sun's height above the Galactic midplane
(defaults to 0.0208 kpc). *)
val to_coord :
?galcen_distance:Unit.length Unit.t ->
?z_sun:Unit.length Unit.t ->
t ->
Coord.t * Unit.length Unit.t
(** [to_coord t] converts Galactocentric Cartesian coordinates [t] back to ICRS
celestial coordinates and a distance. Not differentiable (scalar-level
trigonometry).
[galcen_distance] defaults to 8.122 kpc. [z_sun] defaults to 0.0208 kpc. *)
================================================
FILE: dev/umbra/lib/kdtree.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
type node =
| Leaf
| Node of {
idx : int;
x : float;
y : float;
z : float;
split : int;
left : node;
right : node;
}
type t = { root : node; size : int }
let coord split x y z = match split with 0 -> x | 1 -> y | _ -> z
let build xs ys zs =
let n = Array.length xs in
if n <> Array.length ys || n <> Array.length zs then
invalid_arg "Kdtree.build: arrays must have the same length";
let indices = Array.init n Fun.id in
let rec build_rec start len depth =
if len = 0 then Leaf
else if len = 1 then
let i = indices.(start) in
Node
{
idx = i;
x = xs.(i);
y = ys.(i);
z = zs.(i);
split = depth mod 3;
left = Leaf;
right = Leaf;
}
else begin
let split = depth mod 3 in
let sub = Array.sub indices start len in
Array.sort
(fun a b ->
Float.compare
(coord split xs.(a) ys.(a) zs.(a))
(coord split xs.(b) ys.(b) zs.(b)))
sub;
Array.blit sub 0 indices start len;
let mid = len / 2 in
let mi = indices.(start + mid) in
let left = build_rec start mid (depth + 1) in
let right = build_rec (start + mid + 1) (len - mid - 1) (depth + 1) in
Node
{ idx = mi; x = xs.(mi); y = ys.(mi); z = zs.(mi); split; left; right }
end
in
{ root = build_rec 0 n 0; size = n }
let sq_dist px py pz qx qy qz =
let dx = px -. qx and dy = py -. qy and dz = pz -. qz in
(dx *. dx) +. (dy *. dy) +. (dz *. dz)
let nearest tree qx qy qz =
if tree.size = 0 then invalid_arg "Kdtree.nearest: empty tree";
let best_idx = ref 0 in
let best_dist = ref Float.infinity in
let rec search node =
match node with
| Leaf -> ()
| Node { idx; x; y; z; split; left; right } ->
let d = sq_dist x y z qx qy qz in
if d < !best_dist then begin
best_dist := d;
best_idx := idx
end;
let q_split = coord split qx qy qz in
let p_split = coord split x y z in
let diff = q_split -. p_split in
let near, far = if diff < 0.0 then (left, right) else (right, left) in
search near;
if diff *. diff < !best_dist then search far
in
search tree.root;
(!best_idx, !best_dist)
let within tree qx qy qz max_dist_sq =
let results = ref [] in
let rec search node =
match node with
| Leaf -> ()
| Node { idx; x; y; z; split; left; right } ->
let d = sq_dist x y z qx qy qz in
if d <= max_dist_sq then results := (idx, d) :: !results;
let q_split = coord split qx qy qz in
let p_split = coord split x y z in
let diff = q_split -. p_split in
let near, far = if diff < 0.0 then (left, right) else (right, left) in
search near;
if diff *. diff <= max_dist_sq then search far
in
search tree.root;
!results
================================================
FILE: dev/umbra/lib/kdtree.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** 3D kd-tree for nearest-neighbor queries.
{b Note.} Private module. *)
type t
(** The type for a 3D kd-tree. *)
val build : float array -> float array -> float array -> t
(** [build xs ys zs] is a kd-tree over the points [(xs.(i), ys.(i), zs.(i))].
The three arrays must have equal length.
Raises [Invalid_argument] if the arrays differ in length. *)
val nearest : t -> float -> float -> float -> int * float
(** [nearest tree qx qy qz] is [(i, d2)] where [i] is the index of the nearest
point to [(qx, qy, qz)] and [d2] is the squared Euclidean distance.
Raises [Invalid_argument] if the tree is empty. *)
val within : t -> float -> float -> float -> float -> (int * float) list
(** [within tree qx qy qz max_d2] is the list of [(i, d2)] pairs for all points
within squared Euclidean distance [max_d2] of [(qx, qy, qz)]. *)
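(* Usage sketch: two points on the x-axis; the query at x = 0.9 is nearest to
   the second point, at a squared distance of about 0.01:
     let t = Kdtree.build [| 0.; 1. |] [| 0.; 0. |] [| 0.; 0. |] in
     let i, d2 = Kdtree.nearest t 0.9 0. 0. in
     (i, d2) *)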
================================================
FILE: dev/umbra/lib/photometry.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let f64 = Nx.float64
(* Speed of light for f_lambda to f_nu conversion *)
let _c = 299_792_458.0
(* AB magnitude zero-point: 3631 Jy = 3631e-26 W/m²/Hz *)
let _ab_zp = 3631.0e-26
(* Wavelength stored internally in metres (SI base unit) *)
type bandpass = { wavelength : Nx.float64_t; throughput : Nx.float64_t }
type detector = Energy | Photon
let bandpass ~wavelength ~throughput =
let wavelength = Unit.Length.to_tensor wavelength in
if Nx.ndim wavelength <> 1 then
invalid_arg "Photometry.bandpass: wavelength must be a 1-D tensor";
if Nx.ndim throughput <> 1 then
invalid_arg "Photometry.bandpass: throughput must be a 1-D tensor";
if Nx.numel wavelength <> Nx.numel throughput then
invalid_arg
"Photometry.bandpass: wavelength and throughput must have the same length";
{ wavelength; throughput }
let tophat ~lo ~hi ~n =
let lo_m = Nx.item [] (Unit.Length.to_tensor lo) in
let hi_m = Nx.item [] (Unit.Length.to_tensor hi) in
let wavelength = Nx.linspace f64 lo_m hi_m n in
let throughput = Nx.ones f64 [| n |] in
{ wavelength; throughput }
let wavelength bp = Unit.Length.of_tensor bp.wavelength
let throughput bp = bp.throughput
(* Differentiable trapezoidal integration along the last axis of y. x is always
   1-D (the wavelength grid). When y has leading batch dimensions the result
   preserves them. Implemented entirely with Nx ops, so it remains
   differentiable through Rune. *)
let trapz y x =
let m = Nx.numel x in
let x0 = Nx.slice [ R (0, m - 1) ] x in
let x1 = Nx.slice [ R (1, m) ] x in
let dx = Nx.sub x1 x0 in
let y_shape = Nx.shape y in
let ndim = Array.length y_shape in
if ndim <= 1 then begin
let y0 = Nx.slice [ R (0, m - 1) ] y in
let y1 = Nx.slice [ R (1, m) ] y in
let y_avg = Nx.div_s (Nx.add y0 y1) 2.0 in
Nx.sum (Nx.mul y_avg dx)
end
else begin
let y2d = Nx.reshape [| -1; m |] y in
let y0 = Nx.slice [ A; R (0, m - 1) ] y2d in
let y1 = Nx.slice [ A; R (1, m) ] y2d in
let y_avg = Nx.div_s (Nx.add y0 y1) 2.0 in
let result = Nx.sum ~axes:[ 1 ] (Nx.mul y_avg dx) in
let batch_shape = Array.sub y_shape 0 (ndim - 1) in
Nx.reshape batch_shape result
end
let pivot_wavelength bp =
let lam = bp.wavelength in
let t = bp.throughput in
(* lambda_p = sqrt(integral T lambda d lambda / integral T/lambda d lambda) *)
let num = trapz (Nx.mul t lam) lam in
let den = trapz (Nx.div t lam) lam in
Unit.Length.of_tensor (Nx.sqrt (Nx.div num den))
(* Detector weight: 1 for energy-counting, lambda for photon-counting *)
let detector_weight detector lam throughput =
match detector with Energy -> throughput | Photon -> Nx.mul throughput lam
(* ST magnitude zero-point: -2.5 log10(f_lambda / 3.63e-9 erg/s/cm²/Å) In SI:
3.63e-9 erg/s/cm²/Å = 3.63e-9 * 1e-7 * 1e4 * 1e10 W/m²/m = 3.63e-2 W/m²/m *)
let _st_zp = 3.63e-2
let align_spectrum bp spectrum =
let lam_bp = bp.wavelength in
let lam_sp = Unit.Length.to_tensor (Spectrum.wavelength spectrum) in
let same =
Nx.numel lam_bp = Nx.numel lam_sp
&& Nx.item [] (Nx.max (Nx.abs (Nx.sub lam_bp lam_sp))) = 0.0
in
if same then spectrum
else Spectrum.resample ~wavelength:(Unit.Length.of_tensor lam_bp) spectrum
let flux_density ?(detector = Energy) bp spectrum =
let spectrum = align_spectrum bp spectrum in
let lam = bp.wavelength in
let f = Spectrum.values spectrum in
let w = detector_weight detector lam bp.throughput in
Nx.div (trapz (Nx.mul f w) lam) (trapz w lam)
let ab_mag ?(detector = Energy) bp spectrum =
let spectrum = align_spectrum bp spectrum in
let lam = bp.wavelength in
let f_lambda = Spectrum.values spectrum in
let f_nu = Nx.div (Nx.mul f_lambda (Nx.square lam)) (Nx.scalar f64 _c) in
let w = detector_weight detector lam bp.throughput in
let mean_fnu = Nx.div (trapz (Nx.mul f_nu w) lam) (trapz w lam) in
Nx.mul_s
(Nx.log (Nx.div mean_fnu (Nx.scalar f64 _ab_zp)))
(-2.5 /. Float.log 10.0)
let st_mag ?(detector = Energy) bp spectrum =
let spectrum = align_spectrum bp spectrum in
let lam = bp.wavelength in
let f_lambda = Spectrum.values spectrum in
let w = detector_weight detector lam bp.throughput in
let mean_flam = Nx.div (trapz (Nx.mul f_lambda w) lam) (trapz w lam) in
Nx.mul_s
(Nx.log (Nx.div mean_flam (Nx.scalar f64 _st_zp)))
(-2.5 /. Float.log 10.0)
let _vega_spectrum =
let n = Array.length Vega_data.wave in
let w = Nx.create f64 [| n |] Vega_data.wave in
let w = Nx.mul_s w 1e-10 in
let f = Nx.create f64 [| n |] Vega_data.flux in
Spectrum.create ~wavelength:(Unit.Length.of_tensor w) ~values:f
|> Spectrum.as_flux_density
let vega_mag ?(detector = Energy) bp spectrum =
let f_src = flux_density ~detector bp spectrum in
let f_vega = flux_density ~detector bp _vega_spectrum in
Nx.mul_s (Nx.log (Nx.div f_src f_vega)) (-2.5 /. Float.log 10.0)
let color ?detector bp1 bp2 spectrum =
Nx.sub (ab_mag ?detector bp1 spectrum) (ab_mag ?detector bp2 spectrum)
let effective_wavelength ?(detector = Energy) bp spectrum =
let spectrum = align_spectrum bp spectrum in
let lam = bp.wavelength in
let f = Spectrum.values spectrum in
let w = detector_weight detector lam bp.throughput in
let fw = Nx.mul f w in
let num = trapz (Nx.mul fw (Nx.square lam)) lam in
let den = trapz (Nx.mul fw lam) lam in
Unit.Length.of_tensor (Nx.div num den)
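The tensor pipeline above reduces to simple quadrature and logarithm arithmetic. As a plain-float sketch (standalone, not part of the library, and not using Nx), the trapezoidal rule and the AB zero-point behave like this:

```ocaml
(* Plain-float sketch of the arithmetic that [trapz] and [ab_mag] run
   through Nx tensors above. Not part of the library. *)
let trapz_f y x =
  let acc = ref 0.0 in
  for i = 0 to Array.length x - 2 do
    acc := !acc +. (0.5 *. (y.(i) +. y.(i + 1)) *. (x.(i + 1) -. x.(i)))
  done;
  !acc

(* m_AB = -2.5 log10(<f_nu> / 3631e-26 W m^-2 Hz^-1); a source whose mean
   f_nu equals the zero-point has magnitude 0 by construction. *)
let ab_mag_of_mean_fnu mean_fnu = -2.5 *. log10 (mean_fnu /. 3631.0e-26)

let () =
  let x = Array.init 11 (fun i -> float_of_int i /. 10.0) in
  (* the trapezoid rule is exact for linear integrands: integral of x on [0, 1] = 1/2 *)
  assert (abs_float (trapz_f x x -. 0.5) < 1e-12);
  assert (abs_float (ab_mag_of_mean_fnu 3631.0e-26) < 1e-12)
```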
================================================
FILE: dev/umbra/lib/photometry.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Synthetic photometry.
Computes broadband fluxes and magnitudes by integrating spectra through
filter bandpasses using trapezoidal quadrature.
{[
let bp = Photometry.tophat
~lo:(Unit.Length.nm 400.0)
~hi:(Unit.Length.nm 500.0) ~n:100 in
let mag = Photometry.ab_mag bp sed
]}
All photometry functions accept batched spectra (values with leading batch
dimensions). When a spectrum has shape [[batch; n_lambda]], the result has
shape [[batch]]. *)
(** {1:types Types} *)
type bandpass
(** The type for filter transmission curves. *)
type detector =
| Energy
| Photon
(** The detector convention.
- {!Energy}: counts incident energy (default). The bandpass-weighted
mean is [<f_nu> = integral f_nu T d lambda / integral T d lambda].
- {!Photon}: counts photons. Weights both numerator and denominator by
[lambda]:
[<f_nu> = integral f_nu T lambda d lambda / integral T lambda d lambda]. *)
(** {1:constructors Constructors} *)
val bandpass :
wavelength:Unit.length Unit.t -> throughput:Nx.float64_t -> bandpass
(** [bandpass ~wavelength ~throughput] is a filter from 1-D arrays. [throughput]
is dimensionless (typically in \[0, 1\]).
Raises [Invalid_argument] if tensors are not 1-D or have different lengths.
*)
val tophat : lo:Unit.length Unit.t -> hi:Unit.length Unit.t -> n:int -> bandpass
(** [tophat ~lo ~hi ~n] is a rectangular bandpass from [lo] to [hi] with [n]
wavelength points and unit throughput. *)
(** {1:accessors Accessors} *)
val wavelength : bandpass -> Unit.length Unit.t
(** [wavelength bp] is the wavelength grid. *)
val throughput : bandpass -> Nx.float64_t
(** [throughput bp] is the throughput curve. *)
val pivot_wavelength : bandpass -> Unit.length Unit.t
(** [pivot_wavelength bp] is the pivot wavelength
{e lambda}{_ p}[ = sqrt(integral T lambda d lambda / integral T/lambda d
lambda)]. *)
(** {1:photometry Synthetic photometry} *)
val flux_density :
?detector:detector ->
bandpass ->
Spectrum.flux_density Spectrum.t ->
Nx.float64_t
(** [flux_density ?detector bp spectrum] is the bandpass-weighted mean flux
density [<f> = integral f T w d lambda / integral T w d lambda] where [w] is
[1] for {!Energy} and [lambda] for {!Photon}. [detector] defaults to
{!Energy}.
The spectrum is resampled to the bandpass wavelength grid via linear
interpolation if they differ. Differentiable through Rune. *)
val ab_mag :
?detector:detector ->
bandpass ->
Spectrum.flux_density Spectrum.t ->
Nx.float64_t
(** [ab_mag ?detector bp spectrum] is the AB magnitude of [spectrum] through
[bp].
Computes the mean spectral flux density in f{_ nu}:
[<f_nu> = integral (f_lambda lambda{^2}/c) T w d lambda / integral T w d
lambda], where [w] is [1] for {!Energy} and [lambda] for {!Photon}, then
[m_AB = -2.5 log10(<f_nu> / 3631 Jy)]. [detector] defaults to {!Energy}.
The spectrum is resampled to the bandpass wavelength grid via linear
interpolation if they differ. Differentiable through Rune. *)
val st_mag :
?detector:detector ->
bandpass ->
Spectrum.flux_density Spectrum.t ->
Nx.float64_t
(** [st_mag ?detector bp spectrum] is the ST magnitude of [spectrum] through
[bp].
Computes the bandpass-weighted mean f{_ lambda}, then
[m_ST = -2.5 log10(<f_lambda> / 3.63e-9 erg s{^-1} cm{^-2} A{^-1})].
[detector] defaults to {!Energy}.
The spectrum is resampled to the bandpass wavelength grid via linear
interpolation if they differ. Differentiable through Rune. *)
val vega_mag :
?detector:detector ->
bandpass ->
Spectrum.flux_density Spectrum.t ->
Nx.float64_t
(** [vega_mag ?detector bp spectrum] is the Vega magnitude of [spectrum] through
[bp].
Computes [-2.5 log10(<f_src> / <f_vega>)] where [<f_src>] and [<f_vega>] are
the bandpass-weighted mean flux densities of the source and of the Vega
reference spectrum, from CALSPEC alpha_lyr_stis_011.fits (Bohlin 2014).
[detector] defaults to {!Energy}.
The spectrum is resampled to the bandpass wavelength grid via linear
interpolation if they differ. Differentiable through Rune. *)
val color :
?detector:detector ->
bandpass ->
bandpass ->
Spectrum.flux_density Spectrum.t ->
Nx.float64_t
(** [color ?detector bp1 bp2 spectrum] is
[ab_mag ?detector bp1 spectrum - ab_mag ?detector bp2 spectrum].
Differentiable through Rune. *)
val effective_wavelength :
?detector:detector ->
bandpass ->
Spectrum.flux_density Spectrum.t ->
Unit.length Unit.t
(** [effective_wavelength ?detector bp spectrum] is the source-dependent
effective wavelength
{e lambda}{_ eff}[ = integral f T w lambda{^2} d lambda / integral f T w
lambda d lambda].
Unlike {!pivot_wavelength}, this depends on the source spectrum. The
spectrum is resampled if grids differ. Differentiable through Rune. *)
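A small numeric aside that follows from the two zero-points documented above: since f{_ nu} = f{_ lambda} lambda{^2}/c, the AB and ST systems agree at exactly one wavelength, near 548 nm. A plain-float check (standalone, not part of the library):

```ocaml
(* The AB zero-point (3631e-26 W m^-2 Hz^-1) and ST zero-point
   (3.63e-2 W m^-3) pin down where m_AB = m_ST: with f_nu = f_lambda
   lambda^2 / c, equality holds at lambda = sqrt(3631e-26 * c / 3.63e-2). *)
let c = 299_792_458.0
let lambda_eq = sqrt (3631.0e-26 *. c /. 3.63e-2)

(* lambda_eq is about 547.6 nm, the classic AB = ST crossover *)
let () = assert (abs_float ((lambda_eq *. 1e9) -. 548.0) < 1.0)
```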
================================================
FILE: dev/umbra/lib/spectrum.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let f64 = Nx.float64
(* Physical constants (SI) *)
let _h = 6.626_070_15e-34
let _c = 299_792_458.0
let _k_b = 1.380_649e-23
let _two_hc2 = 2.0 *. _h *. _c *. _c
let _hc_over_k = _h *. _c /. _k_b
(* Spectral kinds — phantom, no runtime representation *)
type flux_density
type radiance
type sampled
(* Wavelength stored internally in metres (SI base unit) *)
type 'a t = { wavelength : Nx.float64_t; values : Nx.float64_t }
let validate_increasing name wl =
let n = Nx.numel wl in
if n > 1 then
for i = 1 to n - 1 do
if Nx.item [ i ] wl <= Nx.item [ i - 1 ] wl then
invalid_arg (name ^ ": wavelength must be strictly increasing")
done
let create ~wavelength ~values =
let wavelength = Unit.Length.to_tensor wavelength in
if Nx.ndim wavelength <> 1 then
invalid_arg "Spectrum.create: wavelength must be a 1-D tensor";
let v_shape = Nx.shape values in
let v_ndim = Array.length v_shape in
if v_ndim = 0 then invalid_arg "Spectrum.create: values must be at least 1-D";
if v_shape.(v_ndim - 1) <> Nx.numel wavelength then
invalid_arg
"Spectrum.create: last dimension of values must match wavelength length";
validate_increasing "Spectrum.create" wavelength;
{ wavelength; values }
let wavelength t = Unit.Length.of_tensor t.wavelength
let values t = t.values
let as_flux_density t = { wavelength = t.wavelength; values = t.values }
let as_sampled t = { wavelength = t.wavelength; values = t.values }
let blackbody ~temperature ~wavelength =
let wavelength = Unit.Length.to_tensor wavelength in
let temp = Unit.Temperature.to_tensor temperature in
let two_hc2 = Nx.scalar f64 _two_hc2 in
let hc_k = Nx.scalar f64 _hc_over_k in
let lam5 = Nx.pow_s wavelength 5.0 in
let exponent = Nx.div hc_k (Nx.mul wavelength temp) in
let values =
Nx.div (Nx.div two_hc2 lam5) (Nx.sub (Nx.exp exponent) (Nx.scalar f64 1.0))
in
{ wavelength; values }
let power_law ~amplitude ~index ~pivot ~wavelength =
let wavelength = Unit.Length.to_tensor wavelength in
let pivot = Unit.Length.to_tensor pivot in
let ratio = Nx.div wavelength pivot in
let values = Nx.mul amplitude (Nx.pow ratio index) in
{ wavelength; values }
let redshift ~z t =
let one_plus_z = Nx.add_s z 1.0 in
{
wavelength = Nx.mul t.wavelength one_plus_z;
values = Nx.div t.values one_plus_z;
}
let scale factor t = { t with values = Nx.mul factor t.values }
let mul a b =
if Nx.numel a.wavelength <> Nx.numel b.wavelength then
invalid_arg "Spectrum.mul: spectra must have the same wavelength grid";
let max_diff =
Nx.item [] (Nx.max (Nx.abs (Nx.sub a.wavelength b.wavelength)))
in
if max_diff > 0.0 then
invalid_arg "Spectrum.mul: spectra must have the same wavelength grid";
{ wavelength = a.wavelength; values = Nx.mul a.values b.values }
let div a b =
if Nx.numel a.wavelength <> Nx.numel b.wavelength then
invalid_arg "Spectrum.div: spectra must have the same wavelength grid";
let max_diff =
Nx.item [] (Nx.max (Nx.abs (Nx.sub a.wavelength b.wavelength)))
in
if max_diff > 0.0 then
invalid_arg "Spectrum.div: spectra must have the same wavelength grid";
{ wavelength = a.wavelength; values = Nx.div a.values b.values }
let add a b =
if Nx.numel a.wavelength <> Nx.numel b.wavelength then
invalid_arg "Spectrum.add: spectra must have the same wavelength grid";
let max_diff =
Nx.item [] (Nx.max (Nx.abs (Nx.sub a.wavelength b.wavelength)))
in
if max_diff > 0.0 then
invalid_arg "Spectrum.add: spectra must have the same wavelength grid";
{ wavelength = a.wavelength; values = Nx.add a.values b.values }
let resample ~wavelength t =
let new_wave = Unit.Length.to_tensor wavelength in
if Nx.ndim new_wave <> 1 then
invalid_arg "Spectrum.resample: wavelength must be a 1-D tensor";
validate_increasing "Spectrum.resample" new_wave;
let old_wave = t.wavelength in
let old_values = t.values in
let n_old = Nx.numel old_wave and n_new = Nx.numel new_wave in
(* Find lower bracket index for each target wavelength (non-differentiable) *)
let lo_arr =
Array.init n_new (fun j ->
let x = Nx.item [ j ] new_wave in
let lo = ref 0 and hi = ref (n_old - 1) in
while !hi - !lo > 1 do
let mid = (!lo + !hi) / 2 in
if Nx.item [ mid ] old_wave <= x then lo := mid else hi := mid
done;
!lo)
in
let hi_arr =
Array.init n_new (fun j -> Int32.of_int (min (lo_arr.(j) + 1) (n_old - 1)))
in
let lo_arr = Array.map Int32.of_int lo_arr in
let lo_t = Nx.create Nx.int32 [| n_new |] lo_arr in
let hi_t = Nx.create Nx.int32 [| n_new |] hi_arr in
(* Gather source wavelengths and values at bracket endpoints. Nx.take uses
B.gather, which Rune differentiates through. *)
let x0 = Nx.take lo_t old_wave in
let x1 = Nx.take hi_t old_wave in
let y0 = Nx.take ~axis:(-1) lo_t old_values in
let y1 = Nx.take ~axis:(-1) hi_t old_values in
(* Linear interpolation — differentiable through Rune *)
let dx = Nx.clamp ~min:1e-30 (Nx.sub x1 x0) in
let frac = Nx.div (Nx.sub new_wave x0) dx in
let values = Nx.add y0 (Nx.mul frac (Nx.sub y1 y0)) in
{ wavelength = new_wave; values }
let gaussian ~amplitude ~center ~stddev ~wavelength =
let wavelength = Unit.Length.to_tensor wavelength in
let center = Unit.Length.to_tensor center in
let stddev = Unit.Length.to_tensor stddev in
let x = Nx.sub wavelength center in
let z = Nx.div x stddev in
let values = Nx.mul amplitude (Nx.exp (Nx.mul_s (Nx.mul z z) (-0.5))) in
{ wavelength; values }
let lorentzian ~amplitude ~center ~fwhm ~wavelength =
let wavelength = Unit.Length.to_tensor wavelength in
let center = Unit.Length.to_tensor center in
let half_gamma = Nx.div_s (Unit.Length.to_tensor fwhm) 2.0 in
let x = Nx.sub wavelength center in
let hg2 = Nx.mul half_gamma half_gamma in
let values = Nx.mul amplitude (Nx.div hg2 (Nx.add (Nx.mul x x) hg2)) in
{ wavelength; values }
let voigt ~amplitude ~center ~sigma ~gamma ~wavelength =
let wavelength = Unit.Length.to_tensor wavelength in
let center = Unit.Length.to_tensor center in
let sigma = Unit.Length.to_tensor sigma in
let gamma = Unit.Length.to_tensor gamma in
(* Pseudo-Voigt mixing via Thompson, Cox & Hastings (1987). *)
let sqrt_2ln2 = Float.sqrt (2.0 *. Float.log 2.0) in
let fg = Nx.mul_s sigma (2.0 *. sqrt_2ln2) in
let fl = Nx.mul_s gamma 2.0 in
let fg2 = Nx.mul fg fg in
let fg3 = Nx.mul fg2 fg in
let fg4 = Nx.mul fg3 fg in
let fg5 = Nx.mul fg4 fg in
let fl2 = Nx.mul fl fl in
let fl3 = Nx.mul fl2 fl in
let fl4 = Nx.mul fl3 fl in
let fl5 = Nx.mul fl4 fl in
let f =
Nx.pow_s
(Nx.add fg5
(Nx.add
(Nx.mul_s (Nx.mul fg4 fl) 2.69269)
(Nx.add
(Nx.mul_s (Nx.mul fg3 fl2) 2.42843)
(Nx.add
(Nx.mul_s (Nx.mul fg2 fl3) 4.47163)
(Nx.add (Nx.mul_s (Nx.mul fg fl4) 0.07842) fl5)))))
0.2
in
let ratio = Nx.div fl f in
let ratio2 = Nx.mul ratio ratio in
let ratio3 = Nx.mul ratio2 ratio in
let eta =
Nx.add (Nx.mul_s ratio 1.36603)
(Nx.add (Nx.mul_s ratio2 (-0.47719)) (Nx.mul_s ratio3 0.11116))
in
(* Gaussian component (unit height at center) *)
let x = Nx.sub wavelength center in
let sig_eff = Nx.div_s f (2.0 *. sqrt_2ln2) in
let z_g = Nx.div x sig_eff in
let gauss = Nx.exp (Nx.mul_s (Nx.mul z_g z_g) (-0.5)) in
(* Lorentzian component (unit height at center) *)
let hf = Nx.div_s f 2.0 in
let hf2 = Nx.mul hf hf in
let lorentz = Nx.div hf2 (Nx.add (Nx.mul x x) hf2) in
let values =
Nx.mul amplitude
(Nx.add (Nx.mul eta lorentz)
(Nx.mul (Nx.sub (Nx.scalar f64 1.0) eta) gauss))
in
{ wavelength; values }
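The pseudo-Voigt mixing polynomial above has two easy-to-check limits. A plain-float restatement (standalone, not part of the library) shows that with no Gaussian width the mixing parameter eta is exactly 1 (pure Lorentzian), and with no Lorentzian width it is 0 (pure Gaussian):

```ocaml
(* Limits of the Thompson, Cox & Hastings (1987) mixing used by [voigt],
   restated in plain floats for illustration. *)
let pv_eta fg fl =
  let f5 =
    (fg ** 5.) +. (2.69269 *. (fg ** 4.) *. fl)
    +. (2.42843 *. (fg ** 3.) *. (fl ** 2.))
    +. (4.47163 *. (fg ** 2.) *. (fl ** 3.))
    +. (0.07842 *. fg *. (fl ** 4.)) +. (fl ** 5.)
  in
  let f = f5 ** 0.2 in
  let r = fl /. f in
  (1.36603 *. r) -. (0.47719 *. r *. r) +. (0.11116 *. r *. r *. r)

let () =
  (* 1.36603 - 0.47719 + 0.11116 = 1, by design of the polynomial *)
  assert (abs_float (pv_eta 0.0 1.0 -. 1.0) < 1e-10);
  assert (abs_float (pv_eta 1.0 0.0) < 1e-10)
```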
================================================
FILE: dev/umbra/lib/spectrum.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Sampled spectral quantities on a wavelength grid.
A {!'a t} pairs a wavelength grid with spectral values parameterised by a
phantom {e kind} that tracks the physical meaning of the values:
- {!flux_density}: spectral flux density f{_ lambda} (W m{^ -2} m{^ -1}).
- {!radiance}: spectral radiance B{_ lambda} (W m{^ -2} m{^ -1} sr{^ -1}).
- {!sampled}: arbitrary values with no physical assumption.
Operations that depend on the physical interpretation of the values (e.g.,
{!redshift}, {!val-Photometry.ab_mag}) require a specific kind, preventing
accidental misuse at compile time. Use {!as_flux_density} to explicitly
reinterpret values when the physical meaning is known to the caller.
{[
let wave = Unit.Length.of_m (Nx.linspace Nx.float64 1e-7 1e-5 1000) in
let sed =
Spectrum.blackbody
~temperature:(Unit.Temperature.of_kelvin (Nx.scalar Nx.float64 5800.0))
~wavelength:wave
|> Spectrum.as_flux_density
in
let reddened =
Extinction.apply (Extinction.ccm89 ~rv) ~av sed
]}
{2:batch Batched spectra}
Values may have leading batch dimensions: a spectrum with wavelength
[[n_lambda]] and values [[batch; n_lambda]] represents [batch] spectra
sharing a wavelength grid. All operations ({!resample}, {!scale}, {!add},
{!val-Photometry.ab_mag}, {!val-Extinction.apply}, etc.) broadcast over
leading dimensions via Nx:
{[
let values = Nx.stack (List.map Spectrum.values templates) in
let batch =
Spectrum.create ~wavelength ~values |> Spectrum.as_flux_density
in
let mags = Photometry.ab_mag bp batch (* shape [batch] *)
]}
{b Note.} {!redshift} with a per-spectrum [z] does not broadcast — it
changes the wavelength grid, breaking the shared-grid invariant. Use
[List.map] or [Rune.vmap] for per-spectrum redshifts. *)
(** {1:kinds Spectral kinds} *)
type flux_density
(** Phantom type for spectral flux density f{_ lambda} (W m{^ -2} m{^ -1}). *)
type radiance
(** Phantom type for spectral radiance B{_ lambda} (W m{^ -2} m{^ -1} sr{^ -1}).
*)
type sampled
(** Phantom type for arbitrary sampled spectral values. *)
(** {1:types Types} *)
type 'a t
(** The type for spectra parameterised by spectral kind ['a]. *)
(** {1:constructors Constructors} *)
val create : wavelength:Unit.length Unit.t -> values:Nx.float64_t -> sampled t
(** [create ~wavelength ~values] is a tabulated spectrum. [wavelength] must be
1-D. [values] must be at least 1-D with its last dimension matching
[wavelength]; leading dimensions are preserved as batch dimensions.
Raises [Invalid_argument] if [wavelength] is not 1-D, the last dimension of
[values] does not match, or [wavelength] is not strictly increasing. *)
(** {1:accessors Accessors} *)
val wavelength : 'a t -> Unit.length Unit.t
(** [wavelength s] is the wavelength grid. *)
val values : 'a t -> Nx.float64_t
(** [values s] is the spectral values. *)
(** {1:casts Kind casts} *)
val as_flux_density : _ t -> flux_density t
(** [as_flux_density s] reinterprets [s] as spectral flux density. The caller is
responsible for ensuring the values represent f{_ lambda}. Use this when
working with external data or when only relative values matter (e.g.,
fitting colours from a blackbody model). *)
val as_sampled : _ t -> sampled t
(** [as_sampled s] forgets the spectral kind. *)
(** {1:models Parametric models} *)
val blackbody :
temperature:Unit.temperature Unit.t ->
wavelength:Unit.length Unit.t ->
radiance t
(** [blackbody ~temperature ~wavelength] is the Planck spectral radiance
B{_ lambda}(T) in W m{^ -2} m{^ -1} sr{^ -1} at the given wavelengths. This
is a per-steradian quantity; multiply by a solid angle to obtain spectral
irradiance. Differentiable through Rune. *)
val power_law :
amplitude:Nx.float64_t ->
index:Nx.float64_t ->
pivot:Unit.length Unit.t ->
wavelength:Unit.length Unit.t ->
sampled t
(** [power_law ~amplitude ~index ~pivot ~wavelength] is the spectrum
[amplitude * (wavelength / pivot){^index}]. Differentiable through Rune. *)
(** {1:operations Operations} *)
val redshift : z:Nx.float64_t -> flux_density t -> flux_density t
(** [redshift ~z s] shifts [s] to redshift [z]. Wavelengths are multiplied by
[(1+z)] and values are divided by [(1+z)].
Restricted to {!flux_density} spectra because the [(1+z){^ -1}] dimming
factor is specific to spectral flux density. Differentiable through Rune. *)
val scale : Nx.float64_t -> 'a t -> 'a t
(** [scale factor s] is [s] with values multiplied element-wise by [factor].
[factor] may be a scalar or a tensor that broadcasts with the values.
Differentiable through Rune. *)
val mul : 'a t -> sampled t -> 'a t
(** [mul a b] multiplies values element-wise. [a]'s spectral kind is preserved;
[b] is treated as a dimensionless modifier (transmission curve, efficiency
function, etc.). Both must share the same wavelength grid. Differentiable
through Rune.
Raises [Invalid_argument] if wavelength grids have different lengths. *)
val div : 'a t -> sampled t -> 'a t
(** [div a b] divides values element-wise. [a]'s spectral kind is preserved; [b]
is treated as a dimensionless modifier. Both must share the same wavelength
grid. Differentiable through Rune.
Raises [Invalid_argument] if wavelength grids have different lengths. *)
val add : 'a t -> 'a t -> 'a t
(** [add a b] is the element-wise sum of two spectra. Both must share the same
wavelength grid. Differentiable through Rune.
Raises [Invalid_argument] if wavelength grids have different lengths. *)
val resample : wavelength:Unit.length Unit.t -> 'a t -> 'a t
(** [resample ~wavelength s] resamples [s] onto a new wavelength grid using
linear interpolation. Leading batch dimensions are preserved. Differentiable
through Rune with respect to the spectrum values (index computation is not
differentiable, but the interpolation weights and gather operations are).
Raises [Invalid_argument] if [wavelength] is not 1-D or not strictly
increasing. *)
(** {1:lines Line profiles} *)
val gaussian :
amplitude:Nx.float64_t ->
center:Unit.length Unit.t ->
stddev:Unit.length Unit.t ->
wavelength:Unit.length Unit.t ->
sampled t
(** [gaussian ~amplitude ~center ~stddev ~wavelength] is the Gaussian profile
[amplitude * exp(-0.5 * ((lambda - center) / stddev){^2})].
[amplitude], [center], and [stddev] may be scalar tensors; they broadcast
against [wavelength]. Differentiable through Rune. *)
val lorentzian :
amplitude:Nx.float64_t ->
center:Unit.length Unit.t ->
fwhm:Unit.length Unit.t ->
wavelength:Unit.length Unit.t ->
sampled t
(** [lorentzian ~amplitude ~center ~fwhm ~wavelength] is the Lorentzian profile
[amplitude * (gamma/2){^2} / ((lambda - center){^2} + (gamma/2){^2})] where
[gamma = fwhm]. Unit height at [center]. Differentiable through Rune. *)
val voigt :
amplitude:Nx.float64_t ->
center:Unit.length Unit.t ->
sigma:Unit.length Unit.t ->
gamma:Unit.length Unit.t ->
wavelength:Unit.length Unit.t ->
sampled t
(** [voigt ~amplitude ~center ~sigma ~gamma ~wavelength] is the pseudo-Voigt
approximation of the Voigt profile (Thompson, Cox & Hastings 1987). [sigma]
is the Gaussian standard deviation and [gamma] is the Lorentzian half-width
at half-maximum. Accurate to <1% of the exact Faddeeva-based Voigt.
Differentiable through Rune. *)
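As a sanity check on the Planck formula behind {!blackbody}, a plain-float evaluation (standalone, not part of the library) recovers Wien's displacement law: at T = 5800 K the peak of B{_ lambda} lands near lambda{_ max} = b/T with b approximately 2.898e-3 m K, i.e. about 500 nm.

```ocaml
(* Plain-float Planck curve; same constants as the implementation. *)
let h = 6.626_070_15e-34
let c = 299_792_458.0
let k_b = 1.380_649e-23

let planck t lam =
  2.0 *. h *. c *. c /. (lam ** 5.0)
  /. (exp (h *. c /. (lam *. k_b *. t)) -. 1.0)

let () =
  let t = 5800.0 in
  let best = ref 0.0 and best_lam = ref 0.0 in
  (* scan 1..2000 nm in 1 nm steps for the peak of B_lambda *)
  for i = 1 to 2000 do
    let lam = float_of_int i *. 1e-9 in
    let b = planck t lam in
    if b > !best then begin best := b; best_lam := lam end
  done;
  (* Wien's law: lambda_max = 2.898e-3 / T, about 500 nm at 5800 K *)
  assert (abs_float (!best_lam -. (2.898e-3 /. t)) < 2e-9)
```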
================================================
FILE: dev/umbra/lib/survey.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let f64 = Nx.float64
let c_km_s = 299792.458
let h0_ref = 100.0
let steradian_to_arcmin2 = 11818102.86004228
let c1_rho_crit = 0.0134
(* Redshift distributions *)
type nz = { eval : Nx.float64_t -> Nx.float64_t; zmax : float }
let simps_float f a b n =
let h = (b -. a) /. Float.of_int n in
let sum = ref (f a +. f b) in
for i = 1 to n - 1 do
let x = a +. (Float.of_int i *. h) in
let w = if i mod 2 = 1 then 4.0 else 2.0 in
sum := !sum +. (w *. f x)
done;
!sum *. h /. 3.0
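Note that the 1-4-2-...-4-1 weight pattern above assumes an even panel count [n]; the callers in this file use [n = 256]. Restated standalone for illustration, composite Simpson is exact for cubics, which makes a quick check easy:

```ocaml
(* [simps_float] restated in isolation. Composite Simpson assumes even n;
   it integrates cubics exactly, so int_0^1 x^3 dx = 1/4 at machine
   precision. *)
let simps_float f a b n =
  let h = (b -. a) /. float_of_int n in
  let sum = ref (f a +. f b) in
  for i = 1 to n - 1 do
    let x = a +. (float_of_int i *. h) in
    let w = if i mod 2 = 1 then 4.0 else 2.0 in
    sum := !sum +. (w *. f x)
  done;
  !sum *. h /. 3.0

let () =
  assert
    (abs_float (simps_float (fun x -> x *. x *. x) 0.0 1.0 16 -. 0.25) < 1e-14)
```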
let smail ?(zmax = 10.0) ~a ~b ~z0 () =
let raw z_f = (z_f ** a) *. Float.exp (-.((z_f /. z0) ** b)) in
let norm = simps_float raw 0.0 zmax 256 in
let eval z =
let z_f = Nx.item [] z in
Nx.scalar f64 (raw z_f /. norm)
in
{ eval; zmax }
let tabulated ~z ~pz () =
let n = (Nx.shape z).(0) in
let zmax = Nx.item [ n - 1 ] z in
let norm = ref 0.0 in
for i = 0 to n - 2 do
let dz = Nx.item [ i + 1 ] z -. Nx.item [ i ] z in
norm := !norm +. (0.5 *. (Nx.item [ i ] pz +. Nx.item [ i + 1 ] pz) *. dz)
done;
let eval zq =
let zq_f = Nx.item [] zq in
if zq_f <= Nx.item [ 0 ] z || zq_f >= zmax then Nx.scalar f64 0.0
else begin
let idx = ref 0 in
for i = 0 to n - 2 do
if Nx.item [ i ] z <= zq_f then idx := i
done;
let i = !idx in
let z0 = Nx.item [ i ] z and z1 = Nx.item [ i + 1 ] z in
let p0 = Nx.item [ i ] pz and p1 = Nx.item [ i + 1 ] pz in
let frac = (zq_f -. z0) /. (z1 -. z0) in
Nx.scalar f64 ((p0 +. (frac *. (p1 -. p0))) /. !norm)
end
in
{ eval; zmax }
let custom_nz ?(zmax = 10.0) eval = { eval; zmax }
let eval_nz nz z = nz.eval z
let nz_zmax nz = nz.zmax
(* Galaxy bias *)
type bias = Cosmo.params -> Nx.float64_t -> Nx.float64_t
let constant_bias b _p _z = Nx.scalar f64 b
let inverse_growth_bias b0 p z =
let d = Cosmo.growth_factor ~p z in
Nx.div (Nx.scalar f64 b0) d
(* Power spectrum backends *)
type power = Cosmo.params -> Nx.float64_t -> Nx.float64_t -> Nx.float64_t
let linear p k z = Cosmo.linear_power ~p k z
let nonlinear p k z = Cosmo.nonlinear_power ~p k z
let baryonic_feedback ?(a_bary = 0.0) ?(log10_k_star = 1.0) ?(sigma = 0.55)
base_power =
fun p k z ->
let pk = base_power p k z in
if a_bary = 0.0 then pk
else
let inv_sigma2 = -1.0 /. (sigma *. sigma) in
let log10_k = Nx.div_s (Nx.log k) (Float.log 10.0) in
let delta = Nx.sub_s log10_k log10_k_star in
let gauss = Nx.exp (Nx.mul_s (Nx.mul delta delta) inv_sigma2) in
Nx.sub pk (Nx.mul_s (Nx.mul gauss pk) a_bary)
(* Tracers *)
type tracer_kind =
| Weak_lensing of { ia_bias : bias option; sigma_e : float; m_bias : float }
| Number_counts of { bias : bias }
| Custom of {
kernel :
p:Cosmo.params -> z:Nx.float64_t -> chi:Nx.float64_t -> Nx.float64_t;
}
type tracer = {
nz : nz option;
n_gal : float;
noise : float;
kind : tracer_kind;
zmax : float;
}
let weak_lensing ?ia_bias ?(sigma_e = 0.26) ?(m_bias = 0.0) ?(n_gal = 1.0) nz =
let noise = sigma_e *. sigma_e /. (n_gal *. steradian_to_arcmin2) in
{
nz = Some nz;
n_gal;
noise;
kind = Weak_lensing { ia_bias; sigma_e; m_bias };
zmax = nz.zmax;
}
let number_counts ~bias ?(n_gal = 1.0) nz =
let noise = 1.0 /. (n_gal *. steradian_to_arcmin2) in
{ nz = Some nz; n_gal; noise; kind = Number_counts { bias }; zmax = nz.zmax }
let tracer ?(noise = 0.0) ?(zmax = 3.0) kernel =
{ nz = None; n_gal = 0.0; noise; kind = Custom { kernel }; zmax }
(* Cls result type *)
type cls = {
ell : Nx.float64_t;
tracers : tracer array;
spectra : Nx.float64_t;
}
(* Cl index ordering: upper triangle *)
let pair_index nt i j =
let a, b = if i <= j then (i, j) else (j, i) in
(a * ((2 * nt) - a - 1) / 2) + b
let cl_pairs nt =
let pairs = ref [] in
for i = 0 to nt - 1 do
for j = i to nt - 1 do
pairs := (i, j) :: !pairs
done
done;
List.rev !pairs
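The closed form in [pair_index] maps a symmetric pair (i, j) to its row-major position in the upper triangle, matching the enumeration order of [cl_pairs]. Restated standalone for illustration:

```ocaml
(* [pair_index] restated in isolation to show the ordering it encodes. *)
let pair_index nt i j =
  let a, b = if i <= j then (i, j) else (j, i) in
  (a * ((2 * nt) - a - 1) / 2) + b

let () =
  (* for nt = 3 the upper-triangle pairs enumerate as
     (0,0) (0,1) (0,2) (1,1) (1,2) (2,2) *)
  let expected = [ (0, 0); (0, 1); (0, 2); (1, 1); (1, 2); (2, 2) ] in
  List.iteri (fun k (i, j) -> assert (pair_index 3 i j = k)) expected;
  (* symmetric access: (j, i) maps to the same slot as (i, j) *)
  assert (pair_index 3 2 1 = pair_index 3 1 2)
```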
(* Evaluate n(z) for one bin on the z grid. Returns tensor [n_z]. Uses Nx.stack
so gradients flow through custom_nz eval functions. *)
let eval_nz_grid nz z_arr n_z =
Nx.stack (List.init n_z (fun j -> nz.eval (Nx.scalar f64 z_arr.(j))))
(* Reverse cumulative trapezoidal sum of [f_vec] (length [n]) with spacing dz.
result[j] = ∫_{x_j}^{x_{n-1}} f(x) dx via the trapezoidal rule. *)
let rev_cumtrapz f_vec n dz =
let left = Nx.slice [ R (0, n - 1) ] f_vec in
let right = Nx.slice [ R (1, n) ] f_vec in
let mid = Nx.mul_s (Nx.add left right) (0.5 *. dz) in
let partial = Nx.flip (Nx.cumsum ~axis:0 (Nx.flip mid)) in
Nx.concatenate [ partial; Nx.zeros f64 [| 1 |] ]
(* Angular power spectra *)
let angular_cl ?(p = Cosmo.planck18) ?(power = nonlinear) ~ell tracers =
let tracers_arr = Array.of_list tracers in
let nt = Array.length tracers_arr in
let pairs = cl_pairs nt in
let pairs_arr = Array.of_list pairs in
let zmax =
Array.fold_left (fun acc t -> Float.max acc t.zmax) 0.0 tracers_arr
in
let n_z = 100 in
let dz = zmax /. Float.of_int (n_z - 1) in
let z_arr = Array.init n_z (fun i -> Float.of_int i *. dz) in
z_arr.(0) <- 1e-6;
let z_vec = Nx.create f64 [| n_z |] z_arr in
(* Simpson weights: tensor [n_z] *)
let sw =
Array.init n_z (fun i ->
if i = 0 || i = n_z - 1 then 1.0 else if i mod 2 = 1 then 4.0 else 2.0)
in
let simpson_w = Nx.mul_s (Nx.create f64 [| n_z |] sw) (dz /. 3.0) in
(* Precompute z-dependent quantities as tensors — differentiable through p.
comoving_distance and growth_factor use GL quadrature internally and cannot
accept vector z, so we loop over scalar z values. *)
let h_t = Nx.div (Cosmo.h0 p) (Nx.scalar f64 h0_ref) in
let chi_vec =
Nx.stack
(List.init n_z (fun j ->
let z_t = Nx.scalar f64 z_arr.(j) in
Nx.mul (Unit.Length.in_mpc (Cosmo.comoving_distance ~p z_t)) h_t))
in
let chi_safe = Nx.clamp ~min:1e-10 chi_vec in
let h_vec = Cosmo.hubble ~p z_vec in
let dchi_dz_vec = Nx.div (Nx.mul_s h_t c_km_s) h_vec in
let growth_vec =
Nx.stack
(List.init n_z (fun j -> Cosmo.growth_factor ~p (Nx.scalar f64 z_arr.(j))))
in
let omega_m_t = Cosmo.omega_m p in
(* n(z) values per tracer: tensors [n_z], differentiable through custom_nz *)
let nz_arrs = Array.make nt (Nx.zeros f64 [| n_z |]) in
Array.iteri
(fun idx t ->
match t.nz with
| Some nz -> nz_arrs.(idx) <- eval_nz_grid nz z_arr n_z
| None -> ())
tracers_arr;
(* Kernel base vectors per tracer: tensor [n_z], without ell_factor for WL *)
let kernel_bases = Array.make nt (Nx.zeros f64 [| n_z |]) in
let kernel_has_ell_factor = Array.make nt false in
Array.iteri
(fun idx t ->
match t.kind with
| Weak_lensing { ia_bias; sigma_e = _; m_bias } ->
let nz_tensor = nz_arrs.(idx) in
(* A(z_j) = ∫_{z_j}^{zmax} n(z') dz' *)
let a_vec = rev_cumtrapz nz_tensor n_z dz in
(* B(z_j) = ∫_{z_j}^{zmax} n(z')/χ(z') dz' — tensor, through chi *)
let nz_over_chi = Nx.div nz_tensor chi_safe in
let b_vec = rev_cumtrapz nz_over_chi n_z dz in
(* g = A - chi * B *)
let g_vec = Nx.sub a_vec (Nx.mul chi_vec b_vec) in
(* WL kernel base: (3 H0² Ωm / 2c) × (1+z) × χ × g *)
let prefactor =
Nx.mul_s omega_m_t (3.0 *. h0_ref *. h0_ref /. (2.0 *. c_km_s))
in
let one_plus_z = Nx.add_s z_vec 1.0 in
let k_base =
Nx.mul prefactor (Nx.mul one_plus_z (Nx.mul chi_vec g_vec))
in
(* Add NLA intrinsic alignment if present *)
let k_base =
match ia_bias with
| None -> k_base
| Some ia_b ->
let ia_tensor =
Nx.stack
(List.init n_z (fun j -> ia_b p (Nx.scalar f64 z_arr.(j))))
in
(* K_IA = -(C₁ ρ_crit Ωm / D(z)) × n(z) × b_IA(z) × H(z) *)
let ia_kernel =
Nx.mul
(Nx.mul_s omega_m_t (-.c1_rho_crit))
(Nx.mul
(Nx.div nz_tensor growth_vec)
(Nx.mul ia_tensor h_vec))
in
Nx.add k_base ia_kernel
in
(* Shear multiplicative bias: W_obs = (1+m) W_true *)
let k_base =
if m_bias = 0.0 then k_base else Nx.mul_s k_base (1.0 +. m_bias)
in
kernel_bases.(idx) <- k_base;
kernel_has_ell_factor.(idx) <- true
| Number_counts { bias } ->
let nz_tensor = nz_arrs.(idx) in
let bias_tensor =
Nx.stack (List.init n_z (fun j -> bias p (Nx.scalar f64 z_arr.(j))))
in
(* NC kernel: n(z) × b(z) × H(z) — no ell factor *)
kernel_bases.(idx) <- Nx.mul nz_tensor (Nx.mul bias_tensor h_vec);
kernel_has_ell_factor.(idx) <- false
| Custom { kernel } ->
(* Custom kernel: user provides the full W(z) *)
kernel_bases.(idx) <-
Nx.stack
(List.init n_z (fun j ->
let z_t = Nx.scalar f64 z_arr.(j) in
let chi_t = Nx.get [ j ] chi_safe in
kernel ~p ~z:z_t ~chi:chi_t));
kernel_has_ell_factor.(idx) <- false)
tracers_arr;
(* Common integration weight: (dχ/dz) / χ² / c² × Simpson weights *)
let integ_weight =
Nx.mul simpson_w
(Nx.div_s
(Nx.div dchi_dz_vec (Nx.mul chi_safe chi_safe))
(c_km_s *. c_km_s))
in
(* Power spectrum grid [n_z, n_ell]: loop over z (scalar), vectorized over k.
Both linear_power and nonlinear_power accept vector k but scalar z. *)
let pk_grid =
Nx.stack
(List.init n_z (fun z_idx ->
let z_t = Nx.scalar f64 z_arr.(z_idx) in
let chi_z = Nx.get [ z_idx ] chi_safe in
let k_vec = Nx.div (Nx.add_s ell 0.5) chi_z in
power p k_vec z_t))
in
(* ell_factor vector [n_ell]: sqrt((ℓ-1)ℓ(ℓ+1)(ℓ+2)) / (ℓ+0.5)² *)
let ell_factor_vec =
let l = ell in
let num =
Nx.mul
(Nx.mul (Nx.sub_s l 1.0) l)
(Nx.mul (Nx.add_s l 1.0) (Nx.add_s l 2.0))
in
let den = Nx.mul (Nx.add_s l 0.5) (Nx.add_s l 0.5) in
Nx.div (Nx.sqrt (Nx.abs num)) den
in
(* Limber integration: functional, no in-place mutation. integ_weight is
[n_z], pk_grid is [n_z, n_ell]. For each pair (i,j):
C_ℓ = Σ_z K_i(z) K_j(z) P(k,z) w(z).
kernel_bases are [n_z], broadcast with pk_grid [n_z, n_ell]. *)
let w_pk = Nx.mul (Nx.reshape [| n_z; 1 |] integ_weight) pk_grid in
let spectra =
Nx.stack
(List.map
(fun (i, j) ->
let ki = Nx.reshape [| n_z; 1 |] kernel_bases.(i) in
let kj = Nx.reshape [| n_z; 1 |] kernel_bases.(j) in
let integrand = Nx.mul (Nx.mul ki kj) w_pk in
let cl_row = Nx.sum ~axes:[ 0 ] integrand in
let ell_power =
(if kernel_has_ell_factor.(i) then 1 else 0)
+ if kernel_has_ell_factor.(j) then 1 else 0
in
if ell_power = 0 then cl_row
else if ell_power = 1 then Nx.mul ell_factor_vec cl_row
else Nx.mul (Nx.mul ell_factor_vec ell_factor_vec) cl_row)
(Array.to_list pairs_arr))
in
{ ell; tracers = tracers_arr; spectra }
(* Cls submodule *)
module Cls = struct
let get cls ~i ~j =
let n = Array.length cls.tracers in
if i < 0 || i >= n || j < 0 || j >= n then
invalid_arg "Survey.Cls.get: index out of range";
Nx.slice [ I (pair_index n i j) ] cls.spectra
let ell cls = cls.ell
let n_tracers cls = Array.length cls.tracers
let to_tensor cls = cls.spectra
let noise cls =
let n_ell = (Nx.shape cls.ell).(0) in
let nt = Array.length cls.tracers in
let pairs = cl_pairs nt in
let n_cls = List.length pairs in
let result = Nx.zeros f64 [| n_cls; n_ell |] in
let pair_idx = ref 0 in
List.iter
(fun (i, j) ->
if i = j then begin
let noise_val = cls.tracers.(i).noise in
for l = 0 to n_ell - 1 do
Nx.set_item [ !pair_idx; l ] noise_val result
done
end;
incr pair_idx)
pairs;
result
let gaussian_covariance ?(f_sky = 0.25) cls =
let ell = cls.ell in
let n_ell = (Nx.shape ell).(0) in
let nt = Array.length cls.tracers in
let pairs = cl_pairs nt in
let n_cls = List.length pairs in
let n = n_cls * n_ell in
let cov = Nx.zeros f64 [| n; n |] in
let cl_noise = noise cls in
let cl_obs = Nx.add cls.spectra cl_noise in
let pairs_arr = Array.of_list pairs in
let find_pair a b = pair_index nt a b in
(* Δℓ via finite differences *)
let dell =
Array.init n_ell (fun l ->
if l = 0 then Nx.item [ 1 ] ell -. Nx.item [ 0 ] ell
else if l = n_ell - 1 then Nx.item [ l ] ell -. Nx.item [ l - 1 ] ell
else 0.5 *. (Nx.item [ l + 1 ] ell -. Nx.item [ l - 1 ] ell))
in
for p1 = 0 to n_cls - 1 do
let i, j = pairs_arr.(p1) in
for p2 = p1 to n_cls - 1 do
let m, nn = pairs_arr.(p2) in
let im = find_pair i m and jn = find_pair j nn in
let in_ = find_pair i nn and jm = find_pair j m in
for l = 0 to n_ell - 1 do
let ell_l = Nx.item [ l ] ell in
let norm = ((2.0 *. ell_l) +. 1.0) *. dell.(l) *. f_sky in
let c_im = Nx.get [ im; l ] cl_obs in
let c_jn = Nx.get [ jn; l ] cl_obs in
let c_in = Nx.get [ in_; l ] cl_obs in
let c_jm = Nx.get [ jm; l ] cl_obs in
let val_ =
Nx.div_s (Nx.add (Nx.mul c_im c_jn) (Nx.mul c_in c_jm)) norm
in
let row = (p1 * n_ell) + l in
let col = (p2 * n_ell) + l in
Nx.set [ row; col ] cov val_;
if p1 <> p2 then Nx.set [ col; row ] cov val_
done
done
done;
cov
end
================================================
FILE: dev/umbra/lib/survey.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Angular power spectra and survey science.
The central type is {!tracer}: one tracer per tomographic bin. {!angular_cl}
cross-correlates a list of tracers and returns a structured {!cls} value
with typed accessors.
{!angular_cl} and {!inverse_growth_bias} are differentiable through Rune.
{!Cls.noise} and {!Cls.gaussian_covariance} are not (they use in-place
mutation); compute them once at a fiducial cosmology.
{[
let nz1 = Survey.smail ~a:2.0 ~b:1.5 ~z0:0.3 () in
let nz2 = Survey.smail ~a:2.0 ~b:1.5 ~z0:0.7 () in
let wl1 = Survey.weak_lensing ~n_gal:26.0 nz1 in
let wl2 = Survey.weak_lensing ~n_gal:26.0 nz2 in
let ell = Nx.logspace Nx.float64 1.0 3.0 50 in
let cls = Survey.angular_cl ~p:Cosmo.planck18 ~ell [ wl1; wl2 ] in
let cl_auto = Survey.Cls.get cls ~i:0 ~j:0 in
let cl_cross = Survey.Cls.get cls ~i:0 ~j:1
]} *)
(** {1:nz Redshift distributions} *)
type nz
(** A normalized redshift probability density n(z) with a maximum redshift. *)
val smail : ?zmax:float -> a:float -> b:float -> z0:float -> unit -> nz
(** [smail ~a ~b ~z0 ()] is n(z) {e ∝} z{^ a} exp(-(z/z0){^ b}). Auto-normalized
via Simpson's rule. [zmax] defaults to [10.0]. *)
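(** Illustrative usage sketch (binding names are hypothetical): build a
Smail-type bin and evaluate its normalized density at one redshift.
{[
let nz = Survey.smail ~a:2.0 ~b:1.5 ~z0:0.3 () in
let p_half = Survey.eval_nz nz (Nx.scalar Nx.float64 0.5)
]} *)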
val tabulated : z:Nx.float64_t -> pz:Nx.float64_t -> unit -> nz
(** [tabulated ~z ~pz ()] is n(z) linearly interpolated from sampled points.
Auto-normalized. [zmax] is inferred from the last element of [z]. *)
val custom_nz : ?zmax:float -> (Nx.float64_t -> Nx.float64_t) -> nz
(** [custom_nz f] is a redshift distribution with evaluation function [f]. [f z]
maps a scalar tensor [z] to n(z). For differentiable survey optimization,
[f] should use tensor operations so gradients flow through Rune. [zmax]
defaults to [10.0]. *)
val eval_nz : nz -> Nx.float64_t -> Nx.float64_t
(** [eval_nz nz z] evaluates the normalized n(z) at [z]. *)
val nz_zmax : nz -> float
(** [nz_zmax nz] is the maximum redshift of the distribution. *)
(** {1:bias Galaxy bias} *)
type bias = Cosmo.params -> Nx.float64_t -> Nx.float64_t
(** A galaxy bias function. [bias p z] is b(z) under cosmology [p]. *)
val constant_bias : float -> bias
(** [constant_bias b] is a redshift-independent linear bias. Not differentiable
(constant value). *)
val inverse_growth_bias : float -> bias
(** [inverse_growth_bias b0] is [b0 / D(z)], where D is the linear growth
factor. Differentiable through Rune. *)
(** {1:power Power spectrum backends} *)
type power = Cosmo.params -> Nx.float64_t -> Nx.float64_t -> Nx.float64_t
(** [power p k z] is the matter power spectrum P(k, z) in (Mpc/h){^ 3}. [k] is a
1-D tensor of wavenumbers in h/Mpc, [z] is a scalar tensor. *)
val linear : power
(** [linear] is the linear matter power spectrum via Eisenstein & Hu (1998).
Differentiable through Rune. *)
val nonlinear : power
(** [nonlinear] is the nonlinear power spectrum via Halofit (Takahashi et al.
2012). Differentiable through Rune (except the nonlinear scale k{_ nl} which
is found by float-level root-finding). *)
val baryonic_feedback :
?a_bary:float -> ?log10_k_star:float -> ?sigma:float -> power -> power
(** [baryonic_feedback base_power] wraps [base_power] with a Gaussian
suppression in log{_ 10}(k) that models baryonic feedback on the matter
power spectrum:
P{_ bary}(k, z) = P(k, z) {e ×} (1 - a{_ bary} {e ×} exp(-(log{_ 10}(k) -
log{_ 10}(k{_ star})){^ 2} / {e σ}{^ 2})).
[a_bary] is the suppression amplitude (default [0.0] = no effect).
[log10_k_star] is the log{_ 10} of the peak suppression wavenumber in h/Mpc
(default [1.0], i.e. k{_ star} = 10 h/Mpc). [sigma] is the Gaussian width in
log{_ 10}(k) (default [0.55]).
Differentiable through Rune. *)
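(** Illustrative usage sketch: a 10% suppression around k ≈ 10 h/Mpc on top of
Halofit, fed to {!angular_cl} ([ell] and [tracers] are assumed to be in
scope).
{[
let power = Survey.baryonic_feedback ~a_bary:0.1 Survey.nonlinear in
let cls = Survey.angular_cl ~power ~ell tracers
]} *)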
(** {1:tracers Tracers} *)
type tracer
(** The type for a single tomographic tracer. One tracer = one redshift bin with
its physics (lensing kernel, galaxy bias, etc.) and noise properties.
{!angular_cl} cross-correlates a list of tracers. *)
val weak_lensing :
?ia_bias:bias ->
?sigma_e:float ->
?m_bias:float ->
?n_gal:float ->
nz ->
tracer
(** [weak_lensing nz] is a weak gravitational lensing tracer with redshift
distribution [nz]. [sigma_e] is the intrinsic ellipticity dispersion
(default [0.26]). [n_gal] is the galaxy number density in
galaxies/arcmin{^ 2} (default [1.0]). [ia_bias], if provided, adds NLA
intrinsic alignment.
[m_bias] is the shear multiplicative bias (default [0.0]). The lensing
kernel is scaled by [(1 + m_bias)], so auto-spectra scale as [(1 + m){^ 2}]
and cross-spectra as [(1 + m{_ i})(1 + m{_ j})]. Differentiable through Rune
when used with {!angular_cl}. *)
val number_counts : bias:bias -> ?n_gal:float -> nz -> tracer
(** [number_counts ~bias nz] is a galaxy number counts tracer with redshift
distribution [nz] and galaxy bias model [bias]. [n_gal] is the galaxy number
density in galaxies/arcmin{^ 2} (default [1.0]). *)
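(** Illustrative usage sketch: a clustering bin with b(z) = 1.5 / D(z) and
3 galaxies/arcmin² ([nz] is assumed to be in scope).
{[
let nc =
Survey.number_counts ~bias:(Survey.inverse_growth_bias 1.5) ~n_gal:3.0 nz
]} *)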
val tracer :
?noise:float ->
?zmax:float ->
(p:Cosmo.params -> z:Nx.float64_t -> chi:Nx.float64_t -> Nx.float64_t) ->
tracer
(** [tracer kernel] is a custom tracer with kernel function [kernel].
[kernel ~p ~z ~chi] returns the full projection kernel W(z) at scalar
redshift [z] and comoving distance [chi] (Mpc/h) under cosmology [p].
[noise] is the constant noise power N{_ ℓ} for auto-correlations (default
[0.0]). [zmax] defaults to [3.0]. *)
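(** Illustrative usage sketch: a toy custom tracer whose kernel is simply
W(z) = χ(z). This is not a physical kernel; it only shows the expected
signature.
{[
let toy = Survey.tracer (fun ~p:_ ~z:_ ~chi -> chi)
]} *)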
(** {1:cls Angular power spectra} *)
type cls
(** The type for a set of angular power spectra. Stores all auto- and
cross-correlations for a list of tracers, along with the ell values and
tracer metadata needed for noise and covariance computation. *)
val angular_cl :
?p:Cosmo.params -> ?power:power -> ell:Nx.float64_t -> tracer list -> cls
(** [angular_cl ~ell tracers] computes angular power spectra C{_ ℓ} for all
auto- and cross-correlations via the Limber approximation. Differentiable
through Rune.
[power] defaults to {!nonlinear}. [p] defaults to {!Cosmo.planck18}.
Raises [Invalid_argument] if [omega_b], [n_s], or [sigma8] are not set in
[p]. *)
(** {2:cls_access Structured access} *)
module Cls : sig
val get : cls -> i:int -> j:int -> Nx.float64_t
(** [get cls ~i ~j] is the angular power spectrum C{_ ℓ}{^ ij} between tracers
[i] and [j]. Returns a 1-D tensor of shape [[n_ell]]. [get cls ~i ~j] and
[get cls ~j ~i] return the same spectrum.
Raises [Invalid_argument] if [i] or [j] is out of range. *)
val ell : cls -> Nx.float64_t
(** [ell cls] is the multipole values, shape [[n_ell]]. *)
val n_tracers : cls -> int
(** [n_tracers cls] is the number of tracers. *)
val to_tensor : cls -> Nx.float64_t
(** [to_tensor cls] is all spectra packed as a tensor of shape
[[n_cls; n_ell]] where [n_cls = n * (n + 1) / 2], ordered as (0,0), (0,1),
..., (1,1), .... *)
val noise : cls -> Nx.float64_t
(** [noise cls] is the shot noise power spectra. Weak lensing:
{e σ}{_ e}{^ 2}/n{_ gal}. Number counts: 1/n{_ gal}. Custom: the [noise]
value. Cross-spectra are zero. Shape [[n_cls; n_ell]].
Not differentiable. *)
val gaussian_covariance : ?f_sky:float -> cls -> Nx.float64_t
(** [gaussian_covariance cls] is the Gaussian covariance matrix. [f_sky]
defaults to [0.25]. Returns dense matrix of shape [[n; n]] where
[n = n_cls * n_ell].
Not differentiable. *)
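(** Illustrative usage sketch, at a fixed fiducial cosmology ([cls] is assumed
to be in scope): observed spectra and their Gaussian covariance.
{[
let c_obs = Nx.add (Survey.Cls.to_tensor cls) (Survey.Cls.noise cls) in
let cov = Survey.Cls.gaussian_covariance ~f_sky:0.25 cls
]} *)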
end
================================================
FILE: dev/umbra/lib/time.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Astronomical time with phantom-typed time scales.
Internal representation: Julian Date (float) in the tagged scale. MJD = JD -
2400000.5 Unix epoch (1970-01-01T00:00:00 UTC) = JD 2440587.5 *)
type 'a t = float
type utc
type tai
type tt
type tdb
(* Constructors *)
let unsafe_of_jd jd = jd
let unsafe_of_mjd mjd = mjd +. 2_400_000.5
let of_unix u = (u /. 86_400.0) +. 2_440_587.5
let now () = of_unix (Unix.gettimeofday ())
(* Comparison *)
let compare (a : float) (b : float) = Float.compare a b
let equal (a : float) (b : float) = Float.equal a b
(* Eliminators *)
let to_jd t = t
let to_mjd t = t -. 2_400_000.5
let to_unix t = (t -. 2_440_587.5) *. 86_400.0
(* Duration *)
let diff a b = Unit.Time.day (a -. b)
let add t dt = t +. Nx.item [] (Unit.Time.in_day dt)
(* Leap second table: (JD of 00:00 UTC from which the cumulative TAI-UTC
offset applies, cumulative TAI-UTC in seconds). Source: IERS Bulletin C. *)
let leap_seconds =
[|
(2441317.5, 10.0);
(* 1972-01-01 *)
(2441499.5, 11.0);
(* 1972-07-01 *)
(2441683.5, 12.0);
(* 1973-01-01 *)
(2442048.5, 13.0);
(* 1974-01-01 *)
(2442413.5, 14.0);
(* 1975-01-01 *)
(2442778.5, 15.0);
(* 1976-01-01 *)
(2443144.5, 16.0);
(* 1977-01-01 *)
(2443509.5, 17.0);
(* 1978-01-01 *)
(2443874.5, 18.0);
(* 1979-01-01 *)
(2444239.5, 19.0);
(* 1980-01-01 *)
(2444786.5, 20.0);
(* 1981-07-01 *)
(2445151.5, 21.0);
(* 1982-07-01 *)
(2445516.5, 22.0);
(* 1983-07-01 *)
(2446247.5, 23.0);
(* 1985-07-01 *)
(2447161.5, 24.0);
(* 1988-01-01 *)
(2447892.5, 25.0);
(* 1990-01-01 *)
(2448257.5, 26.0);
(* 1991-01-01 *)
(2448804.5, 27.0);
(* 1992-07-01 *)
(2449169.5, 28.0);
(* 1993-07-01 *)
(2449534.5, 29.0);
(* 1994-07-01 *)
(2450083.5, 30.0);
(* 1996-01-01 *)
(2450630.5, 31.0);
(* 1997-07-01 *)
(2451179.5, 32.0);
(* 1999-01-01 *)
(2453736.5, 33.0);
(* 2006-01-01 *)
(2454832.5, 34.0);
(* 2009-01-01 *)
(2456109.5, 35.0);
(* 2012-07-01 *)
(2457204.5, 36.0);
(* 2015-07-01 *)
(2457754.5, 37.0);
(* 2017-01-01 *)
|]
let tai_minus_utc jd_utc =
let n = Array.length leap_seconds in
let rec search i =
if i < 0 then 10.0
else
let jd, dt = leap_seconds.(i) in
if jd_utc >= jd then dt else search (i - 1)
in
search (n - 1)
(* UTC <-> TAI *)
let utc_to_tai utc_jd =
let dt = tai_minus_utc utc_jd in
utc_jd +. (dt /. 86_400.0)
let tai_to_utc tai_jd =
(* Approximate inverse: estimate UTC using the most recent offset (37 s),
look up the offset valid at that estimate, and apply it. Can be off by one
second only within ~37 s of a leap-second boundary. *)
let approx_utc = tai_jd -. (37.0 /. 86_400.0) in
let dt = tai_minus_utc approx_utc in
tai_jd -. (dt /. 86_400.0)
(* TAI <-> TT: TT = TAI + 32.184s (exact by definition) *)
let tt_offset = 32.184 /. 86_400.0
let tai_to_tt tai_jd = tai_jd +. tt_offset
let tt_to_tai tt_jd = tt_jd -. tt_offset
(* TT <-> TDB: Fairhead & Bretagnon 1990 series (first 10 terms). Accuracy
~1 μs for dates within a few centuries of J2000.0.
T = (JD_TT - 2451545.0) / 36525.0 (Julian centuries from J2000.0 TT).
TDB - TT ≈ Σ Aᵢ sin(ωᵢ T + φᵢ), in seconds. *)
let fb_terms =
[|
(* amplitude (s), frequency (rad/century), phase (rad) *)
(1.656_674_564e-3, 6_283.075_849_991, 6.240_054_195);
(2.227_2e-5, 5_753.384_884_897, 4.296_977_442);
(1.3886e-5, 12_566.151_699_983, 6.196_904_410);
(3.150e-6, 529.690_965_095, 0.444_401_603);
(1.575e-6, 6_069.776_754_553, 4.021_195_093);
(1.020_5e-5, 213.299_095_438, 5.543_113_262);
(3.978e-6, 77_713.771_467_920, 5.198_467_090);
(4.354e-6, 7_860.419_392_439, 5.988_822_341);
(1.456e-6, 11_506.769_769_794, 2.457_236_222);
(1.126e-6, 3_930.209_696_220, 5.316_024_159);
|]
let tt_to_tdb tt_jd =
let t = (tt_jd -. 2_451_545.0) /. 36_525.0 in
let sum = ref 0.0 in
for i = 0 to Array.length fb_terms - 1 do
let amp, freq, phase = fb_terms.(i) in
sum := !sum +. (amp *. Float.sin ((freq *. t) +. phase))
done;
tt_jd +. (!sum /. 86_400.0)
let tdb_to_tt tdb_jd =
(* One fixed-point iteration: take TT ≈ TDB, compute the residual against
tt_to_tdb, and apply it. TDB-TT is tiny and slowly varying, so one step
reaches ~1 μs. *)
let tt_approx = tdb_jd in
let tdb_from_approx = tt_to_tdb tt_approx in
let correction = tdb_jd -. tdb_from_approx in
tt_approx +. correction
(* ISO 8601 parsing and formatting for UTC *)
(* Calendar date to JD (valid for dates after 1582-10-15, Gregorian calendar) *)
let cal_to_jd y m d =
let y, m = if m <= 2 then (y - 1, m + 12) else (y, m) in
let a = y / 100 in
let b = 2 - a + (a / 4) in
Float.floor (365.25 *. Float.of_int (y + 4716))
+. Float.floor (30.6001 *. Float.of_int (m + 1))
+. d +. Float.of_int b -. 1524.5
(* JD to calendar date *)
let jd_to_cal jd =
let jd = jd +. 0.5 in
let z = Float.to_int (Float.floor jd) in
let f = jd -. Float.of_int z in
let a =
if z < 2299161 then z
else
let alpha =
Float.to_int (Float.floor ((Float.of_int z -. 1867216.25) /. 36524.25))
in
z + 1 + alpha - (alpha / 4)
in
let b = a + 1524 in
let c = Float.to_int (Float.floor ((Float.of_int b -. 122.1) /. 365.25)) in
let d = Float.to_int (Float.floor (365.25 *. Float.of_int c)) in
let e = Float.to_int (Float.floor (Float.of_int (b - d) /. 30.6001)) in
let day_frac =
Float.of_int (b - d) -. Float.floor (30.6001 *. Float.of_int e) +. f
in
let month = if e < 14 then e - 1 else e - 13 in
let year = if month > 2 then c - 4716 else c - 4715 in
(year, month, day_frac)
let of_iso s =
let s =
let len = String.length s in
if len > 0 && s.[len - 1] = 'Z' then String.sub s 0 (len - 1) else s
in
match
Scanf.sscanf s "%d-%d-%dT%d:%d:%f" (fun y mo d h mi s ->
(y, mo, d, h, mi, s))
with
| y, mo, d, h, mi, sec ->
let day =
Float.of_int d
+. (Float.of_int h /. 24.0)
+. (Float.of_int mi /. 1440.0)
+. (sec /. 86_400.0)
in
cal_to_jd y mo day
| exception _ -> (
match Scanf.sscanf s "%d-%d-%d" (fun y mo d -> (y, mo, d)) with
| y, mo, d -> cal_to_jd y mo (Float.of_int d)
| exception _ -> invalid_arg ("Time.of_iso: cannot parse " ^ s))
let to_iso t =
let y, m, day_frac = jd_to_cal t in
let d = Float.to_int (Float.floor day_frac) in
let frac = day_frac -. Float.of_int d in
(* Round to the nearest millisecond first, so e.g. 59.9996 s carries into
the next minute instead of printing as an invalid "60.000" *)
let total_ms = Float.to_int (Float.round (frac *. 86_400_000.0)) in
let ms = total_ms mod 1000 in
let total_s = total_ms / 1000 in
let h = total_s / 3600 in
let mi = total_s / 60 mod 60 in
let sec = total_s mod 60 in
if ms = 0 then Printf.sprintf "%04d-%02d-%02dT%02d:%02d:%02dZ" y m d h mi sec
else Printf.sprintf "%04d-%02d-%02dT%02d:%02d:%02d.%03dZ" y m d h mi sec ms
================================================
FILE: dev/umbra/lib/time.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Astronomical time with phantom-typed time scales.
Times are stored internally as Julian Dates (float). Scale conversions are
type-safe: {!utc_to_tai} accepts a [utc t] and returns a [tai t].
{[
let t = Time.of_iso "2024-01-01T00:00:00" in
let tai = Time.utc_to_tai t in
let tt = Time.tai_to_tt tai in
let jd = Time.to_jd tt
]} *)
(** {1:types Types} *)
type 'scale t
(** The type for a Julian Date tagged with time scale ['scale]. *)
type utc
(** Coordinated Universal Time. *)
type tai
(** International Atomic Time. *)
type tt
(** Terrestrial Time. *)
type tdb
(** Barycentric Dynamical Time. *)
(** {1:constructors Constructors} *)
val unsafe_of_jd : float -> 'a t
(** [unsafe_of_jd jd] is a time from the Julian Date [jd]. The caller must
ensure [jd] is in the intended time scale. *)
val unsafe_of_mjd : float -> 'a t
(** [unsafe_of_mjd mjd] is a time from the Modified Julian Date [mjd] (MJD = JD
\- 2400000.5). The caller must ensure [mjd] is in the intended time scale.
*)
val of_iso : string -> utc t
(** [of_iso s] parses an ISO 8601 date-time string as UTC. Accepted formats:
["YYYY-MM-DD"], ["YYYY-MM-DDThh:mm:ss"], and ["YYYY-MM-DDThh:mm:ssZ"].
{b Warning.} Uses the Gregorian calendar; dates before 1582-10-15 produce
incorrect Julian Dates. Leap seconds (e.g. [23:59:60]) cannot be represented
and are parsed as the following second.
Raises [Invalid_argument] if [s] cannot be parsed. *)
val of_unix : float -> utc t
(** [of_unix u] is the UTC time corresponding to the Unix timestamp [u] (seconds
since 1970-01-01T00:00:00 UTC). *)
val now : unit -> utc t
(** [now ()] is the current UTC time from the system clock. *)
(** {1:comparison Comparison} *)
val compare : 'a t -> 'a t -> int
(** [compare a b] orders times by their Julian Date values. *)
val equal : 'a t -> 'a t -> bool
(** [equal a b] is [true] iff [a] and [b] have the same Julian Date value. *)
(** {1:eliminators Eliminators} *)
val to_jd : 'a t -> float
(** [to_jd t] is the Julian Date of [t]. *)
val to_mjd : 'a t -> float
(** [to_mjd t] is the Modified Julian Date of [t] (MJD = JD - 2400000.5). *)
val to_iso : utc t -> string
(** [to_iso t] formats [t] as an ISO 8601 string with trailing [Z]. Output is
["YYYY-MM-DDThh:mm:ssZ"] when the fractional seconds are below 0.5 ms, or
["YYYY-MM-DDThh:mm:ss.sssZ"] otherwise.
{b Warning.} Leap-second labels like [23:59:60] cannot be produced; times
within a leap second round to [00:00:00] of the following day. *)
val to_unix : utc t -> float
(** [to_unix t] is the Unix timestamp of [t] (seconds since 1970-01-01T00:00:00
UTC). *)
(** {1:scales Scale conversions}
UTC/TAI conversions use the IERS leap-second table (Bulletin C), currently
covering 1972-01-01 through 2017-01-01 (TAI-UTC = 37 s). Dates before
1972-01-01 use TAI-UTC = 10 s.
TT = TAI + 32.184 s (exact by definition).
TDB-TT uses the first 10 terms of the Fairhead & Bretagnon (1990) series,
accurate to ~1 μs within a few centuries of J2000.0. *)
val utc_to_tai : utc t -> tai t
(** [utc_to_tai t] converts [t] from UTC to TAI. *)
val tai_to_utc : tai t -> utc t
(** [tai_to_utc t] converts [t] from TAI to UTC. *)
val tai_to_tt : tai t -> tt t
(** [tai_to_tt t] converts [t] from TAI to TT. *)
val tt_to_tai : tt t -> tai t
(** [tt_to_tai t] converts [t] from TT to TAI. *)
val tt_to_tdb : tt t -> tdb t
(** [tt_to_tdb t] converts [t] from TT to TDB. *)
val tdb_to_tt : tdb t -> tt t
(** [tdb_to_tt t] converts [t] from TDB to TT. Uses a single fixed-point
iteration; accurate to ~1 μs. *)
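(** Illustrative usage sketch: the full UTC → TDB chain for the current
instant.
{[
let tdb = Time.(now () |> utc_to_tai |> tai_to_tt |> tt_to_tdb)
]} *)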
(** {1:duration Duration} *)
val diff : 'a t -> 'a t -> Unit.time Unit.t
(** [diff a b] is the duration [a - b] as a {!Unit.time} quantity. *)
val add : 'a t -> Unit.time Unit.t -> 'a t
(** [add t dt] is [t] offset by the duration [dt]. *)
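(** Illustrative usage sketch: offset an epoch by one week and recover the
duration ([t0] is assumed to be in scope).
{[
let t1 = Time.add t0 (Unit.Time.day 7.0) in
let dt = Time.diff t1 t0
]} *)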
================================================
FILE: dev/umbra/lib/umbra.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
module Unit = Unit
module Const = Const
module Time = Time
module Coord = Coord
module Altaz = Altaz
module Galactocentric = Galactocentric
module Cosmo = Cosmo
module Spectrum = Spectrum
module Extinction = Extinction
module Photometry = Photometry
module Filters = Filters
module Survey = Survey
================================================
FILE: dev/umbra/lib/umbra.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Computational astronomy for OCaml.
Umbra provides dimensionally-typed physical quantities, astronomical
constants, cosmological distance calculations, spectral energy
distributions, dust extinction, synthetic photometry, coordinate transforms,
time scales, and catalog cross-matching.
All computations operate on {!Nx} tensors and are differentiable through
{!Rune} by default.
{[
open Umbra
let z = Nx.scalar Nx.float64 0.5 in
let dl = Cosmo.luminosity_distance z in
let dl_mpc = Unit.Length.in_mpc dl in
let rv = Nx.scalar Nx.float64 3.1 in
let av = Nx.scalar Nx.float64 0.5 in
let wave = Unit.Length.of_m (Nx.linspace Nx.float64 3e-7 1e-6 1000) in
let bp = Photometry.tophat
~lo:(Unit.Length.nm 400.0) ~hi:(Unit.Length.nm 700.0) ~n:1000 in
let sed =
Spectrum.blackbody
~temperature:(Unit.Temperature.of_kelvin (Nx.scalar Nx.float64 5800.0))
~wavelength:wave
|> Extinction.apply (Extinction.ccm89 ~rv) ~av
|> Spectrum.as_flux_density
in
let mag = Photometry.ab_mag bp sed
]} *)
(** {1:units Units and constants} *)
module Unit = Unit
(** Physical quantities with compile-time dimensional safety. *)
module Const = Const
(** Physical and astronomical constants (CODATA 2022, IAU 2015). *)
(** {1:astro Astronomy} *)
module Time = Time
(** Astronomical time with phantom-typed time scales. *)
module Coord = Coord
(** Celestial coordinates with frame transforms and catalog cross-matching. *)
module Altaz = Altaz
(** Altitude-azimuth (horizontal) coordinates. *)
module Galactocentric = Galactocentric
(** Galactocentric Cartesian coordinates. *)
module Cosmo = Cosmo
(** Cosmological distances for {e Λ}CDM, wCDM, and w0waCDM universes. *)
module Spectrum = Spectrum
(** Sampled spectral values on a wavelength grid. *)
module Extinction = Extinction
(** Dust extinction laws. *)
module Photometry = Photometry
(** Synthetic photometry over filter bandpasses. *)
module Filters = Filters
(** Standard astronomical filter bandpasses (SDSS, Johnson-Cousins, 2MASS, Gaia
DR3). *)
(** {1:survey Survey science} *)
module Survey = Survey
(** Angular power spectra, probes, and survey likelihood. *)
================================================
FILE: dev/umbra/lib/unit.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let f64 = Nx.float64
type 'a t = Nx.float64_t
type length
type mass
type time
type angle
type velocity
type power
type temperature
type energy
type frequency
type dimensionless
(* Arithmetic — all Nx ops, fully traced by Rune *)
let ( + ) a b = Nx.add a b
let ( - ) a b = Nx.sub a b
let neg x = Nx.neg x
let abs x = Nx.abs x
let scale s x = Nx.mul_s x s
let scale_t s x = Nx.mul s x
let ratio a b = Nx.div a b
let zero = Nx.scalar f64 0.0
let compare a b = Float.compare (Nx.item [] a) (Nx.item [] b)
let equal a b = Float.equal (Nx.item [] a) (Nx.item [] b)
let pp fmt x = Format.fprintf fmt "%g" (Nx.item [] x)
let to_float x = Nx.item [] x
(* Physical constants used in cross-dimension combinators *)
let c_m_s = Nx.scalar f64 299_792_458.0
let h_si = Nx.scalar f64 6.626_070_15e-34
let hc_si = Nx.scalar f64 (6.626_070_15e-34 *. 299_792_458.0)
let one = Nx.scalar f64 1.0
let au_m_t = Nx.scalar f64 1.495_978_707e11
(* Cross-dimension combinators — all Nx ops *)
let length_per_time d t = Nx.div d t
let velocity_times_time v t = Nx.mul v t
let length_per_velocity d v = Nx.div d v
let wavelength_to_frequency lam = Nx.div c_m_s lam
let frequency_to_wavelength nu = Nx.div c_m_s nu
let frequency_to_energy nu = Nx.mul h_si nu
let energy_to_frequency e = Nx.div e h_si
let energy_to_wavelength e = Nx.div hc_si e
(* Parallax: 1 arcsec ↔ 1 parsec. parallax(rad) = 1 AU / distance(m), so
distance(m) = 1 AU / parallax(rad). Uses [au_m_t] defined above. *)
let parallax_to_distance p = Nx.div au_m_t p
let distance_to_parallax d = Nx.div au_m_t d
(* Spectral density: f_ν = f_λ · λ²/c, where f_λ is per-metre and f_ν is
per-hertz. *)
let flam_to_fnu ~wavelength flam =
Nx.div (Nx.mul flam (Nx.square wavelength)) c_m_s
let fnu_to_flam ~wavelength fnu =
Nx.div (Nx.mul fnu c_m_s) (Nx.square wavelength)
(* Doppler conventions: velocity ↔ observed wavelength given rest wavelength.
All three conventions agree at v << c. *)
(* Optical: z = v/c, λ_obs = λ_rest * (1 + v/c) *)
let doppler_optical ~rest v = Nx.mul rest (Nx.add_s (Nx.div v c_m_s) 1.0)
let doppler_optical_inv ~rest obs =
Nx.mul c_m_s (Nx.sub_s (Nx.div obs rest) 1.0)
(* Radio: v = c*(1 - λ_rest/λ_obs), λ_obs = λ_rest / (1 - v/c) *)
let doppler_radio ~rest v = Nx.div rest (Nx.sub one (Nx.div v c_m_s))
let doppler_radio_inv ~rest obs = Nx.mul c_m_s (Nx.sub one (Nx.div rest obs))
(* Relativistic: λ_obs = λ_rest * sqrt((1+β)/(1-β)), β = v/c *)
let doppler_relativistic ~rest v =
let beta = Nx.div v c_m_s in
Nx.mul rest (Nx.sqrt (Nx.div (Nx.add_s beta 1.0) (Nx.sub one beta)))
let doppler_relativistic_inv ~rest obs =
let r2 = Nx.square (Nx.div obs rest) in
Nx.mul c_m_s (Nx.div (Nx.sub_s r2 1.0) (Nx.add_s r2 1.0))
(* Scale factors to SI base unit *)
let pc_m = 3.085_677_581_491_367_3e16
let au_m = 1.495_978_707e11
let ly_m = 9.460_730_472_580_8e15
let solar_radius_m = 6.957e8
let earth_radius_m = 6.371e6
let jupiter_radius_m = 7.1492e7
let solar_mass_kg = 1.988_4e30
let earth_mass_kg = 5.972_2e24
let jupiter_mass_kg = 1.898_2e27
let solar_luminosity_w = 3.828e26
let julian_year_s = 365.25 *. 86_400.0
let ev_j = 1.602_176_634e-19
module Length = struct
let of_tensor x = x
let to_tensor x = x
let m x = Nx.scalar f64 x
let km x = Nx.scalar f64 (x *. 1e3)
let cm x = Nx.scalar f64 (x *. 1e-2)
let mm x = Nx.scalar f64 (x *. 1e-3)
let um x = Nx.scalar f64 (x *. 1e-6)
let nm x = Nx.scalar f64 (x *. 1e-9)
let angstrom x = Nx.scalar f64 (x *. 1e-10)
let au x = Nx.scalar f64 (x *. au_m)
let pc x = Nx.scalar f64 (x *. pc_m)
let kpc x = Nx.scalar f64 (x *. pc_m *. 1e3)
let mpc x = Nx.scalar f64 (x *. pc_m *. 1e6)
let gpc x = Nx.scalar f64 (x *. pc_m *. 1e9)
let ly x = Nx.scalar f64 (x *. ly_m)
let solar_radius x = Nx.scalar f64 (x *. solar_radius_m)
let earth_radius x = Nx.scalar f64 (x *. earth_radius_m)
let jupiter_radius x = Nx.scalar f64 (x *. jupiter_radius_m)
let of_m x = x
let of_km x = Nx.mul_s x 1e3
let of_cm x = Nx.mul_s x 1e-2
let of_mm x = Nx.mul_s x 1e-3
let of_um x = Nx.mul_s x 1e-6
let of_nm x = Nx.mul_s x 1e-9
let of_angstrom x = Nx.mul_s x 1e-10
let of_au x = Nx.mul_s x au_m
let of_pc x = Nx.mul_s x pc_m
let of_kpc x = Nx.mul_s x (pc_m *. 1e3)
let of_mpc x = Nx.mul_s x (pc_m *. 1e6)
let of_gpc x = Nx.mul_s x (pc_m *. 1e9)
let of_ly x = Nx.mul_s x ly_m
let of_solar_radius x = Nx.mul_s x solar_radius_m
let of_earth_radius x = Nx.mul_s x earth_radius_m
let of_jupiter_radius x = Nx.mul_s x jupiter_radius_m
let in_m x = x
let in_km x = Nx.div_s x 1e3
let in_cm x = Nx.mul_s x (1.0 /. 1e-2)
let in_mm x = Nx.mul_s x (1.0 /. 1e-3)
let in_um x = Nx.mul_s x (1.0 /. 1e-6)
let in_nm x = Nx.mul_s x (1.0 /. 1e-9)
let in_angstrom x = Nx.mul_s x (1.0 /. 1e-10)
let in_au x = Nx.div_s x au_m
let in_pc x = Nx.div_s x pc_m
let in_kpc x = Nx.div_s x (pc_m *. 1e3)
let in_mpc x = Nx.div_s x (pc_m *. 1e6)
let in_gpc x = Nx.div_s x (pc_m *. 1e9)
let in_ly x = Nx.div_s x ly_m
let in_solar_radius x = Nx.div_s x solar_radius_m
let in_earth_radius x = Nx.div_s x earth_radius_m
let in_jupiter_radius x = Nx.div_s x jupiter_radius_m
end
module Mass = struct
let of_tensor x = x
let to_tensor x = x
let kg x = Nx.scalar f64 x
let g x = Nx.scalar f64 (x *. 1e-3)
let mg x = Nx.scalar f64 (x *. 1e-6)
let solar_mass x = Nx.scalar f64 (x *. solar_mass_kg)
let earth_mass x = Nx.scalar f64 (x *. earth_mass_kg)
let jupiter_mass x = Nx.scalar f64 (x *. jupiter_mass_kg)
let of_kg x = x
let of_g x = Nx.mul_s x 1e-3
let of_mg x = Nx.mul_s x 1e-6
let of_solar_mass x = Nx.mul_s x solar_mass_kg
let of_earth_mass x = Nx.mul_s x earth_mass_kg
let of_jupiter_mass x = Nx.mul_s x jupiter_mass_kg
let in_kg x = x
let in_g x = Nx.mul_s x (1.0 /. 1e-3)
let in_mg x = Nx.mul_s x (1.0 /. 1e-6)
let in_solar_mass x = Nx.div_s x solar_mass_kg
let in_earth_mass x = Nx.div_s x earth_mass_kg
let in_jupiter_mass x = Nx.div_s x jupiter_mass_kg
end
module Time = struct
let of_tensor x = x
let to_tensor x = x
let s x = Nx.scalar f64 x
let ms x = Nx.scalar f64 (x *. 1e-3)
let us x = Nx.scalar f64 (x *. 1e-6)
let min x = Nx.scalar f64 (x *. 60.0)
let hr x = Nx.scalar f64 (x *. 3600.0)
let day x = Nx.scalar f64 (x *. 86_400.0)
let yr x = Nx.scalar f64 (x *. julian_year_s)
let myr x = Nx.scalar f64 (x *. julian_year_s *. 1e6)
let gyr x = Nx.scalar f64 (x *. julian_year_s *. 1e9)
let of_s x = x
let of_ms x = Nx.mul_s x 1e-3
let of_us x = Nx.mul_s x 1e-6
let of_min x = Nx.mul_s x 60.0
let of_hr x = Nx.mul_s x 3600.0
let of_day x = Nx.mul_s x 86_400.0
let of_yr x = Nx.mul_s x julian_year_s
let of_myr x = Nx.mul_s x (julian_year_s *. 1e6)
let of_gyr x = Nx.mul_s x (julian_year_s *. 1e9)
let in_s x = x
let in_ms x = Nx.mul_s x (1.0 /. 1e-3)
let in_us x = Nx.mul_s x (1.0 /. 1e-6)
let in_min x = Nx.div_s x 60.0
let in_hr x = Nx.div_s x 3600.0
let in_day x = Nx.div_s x 86_400.0
let in_yr x = Nx.div_s x julian_year_s
let in_myr x = Nx.div_s x (julian_year_s *. 1e6)
let in_gyr x = Nx.div_s x (julian_year_s *. 1e9)
end
module Angle = struct
let deg_rad = Float.pi /. 180.0
let of_tensor x = x
let to_tensor x = x
let rad x = Nx.scalar f64 x
let deg x = Nx.scalar f64 (x *. deg_rad)
let arcmin x = Nx.scalar f64 (x *. deg_rad /. 60.0)
let arcsec x = Nx.scalar f64 (x *. deg_rad /. 3600.0)
let mas x = Nx.scalar f64 (x *. deg_rad /. 3_600_000.0)
let hour_angle x = Nx.scalar f64 (x *. Float.pi /. 12.0)
let of_rad x = x
let of_deg x = Nx.mul_s x deg_rad
let of_arcmin x = Nx.mul_s x (deg_rad /. 60.0)
let of_arcsec x = Nx.mul_s x (deg_rad /. 3600.0)
let of_mas x = Nx.mul_s x (deg_rad /. 3_600_000.0)
let of_hour_angle x = Nx.mul_s x (Float.pi /. 12.0)
let in_rad x = x
let in_deg x = Nx.div_s x deg_rad
let in_arcmin x = Nx.mul_s (Nx.div_s x deg_rad) 60.0
let in_arcsec x = Nx.mul_s (Nx.div_s x deg_rad) 3600.0
let in_mas x = Nx.mul_s (Nx.div_s x deg_rad) 3_600_000.0
let in_hour_angle x = Nx.mul_s x (12.0 /. Float.pi)
let sin x = Nx.sin x
let cos x = Nx.cos x
let tan x = Nx.tan x
let asin x = Nx.asin x
let acos x = Nx.acos x
let atan2 ~y ~x = Nx.atan2 y x
let wrap_360 x =
let d = in_deg x in
let d = Nx.sub d (Nx.mul_s (Nx.floor (Nx.div_s d 360.0)) 360.0) in
of_deg d
let wrap_180 x =
let d = Nx.add_s (in_deg x) 180.0 in
let d = Nx.sub d (Nx.mul_s (Nx.floor (Nx.div_s d 360.0)) 360.0) in
of_deg (Nx.sub_s d 180.0)
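  (* Illustrative (values up to floating-point rounding): the floor-based
     modulo above maps any angle into the principal range, e.g.
       in_deg (wrap_360 (deg 370.0))   (* 0-d tensor holding 10.0 *)
       in_deg (wrap_180 (deg 270.0))   (* 0-d tensor holding -90.0 *)
     [wrap_180] shifts by +180 deg, reduces mod 360, then shifts back, so the
     half-open range is [-180, 180). *)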
end
module Velocity = struct
let of_tensor x = x
let to_tensor x = x
let m_s x = Nx.scalar f64 x
let km_s x = Nx.scalar f64 (x *. 1e3)
let km_hr x = Nx.scalar f64 (x *. (1e3 /. 3600.0))
let of_m_s x = x
let of_km_s x = Nx.mul_s x 1e3
let of_km_hr x = Nx.mul_s x (1e3 /. 3600.0)
let in_m_s x = x
let in_km_s x = Nx.div_s x 1e3
let in_km_hr x = Nx.div_s x (1e3 /. 3600.0)
end
module Power = struct
let of_tensor x = x
let to_tensor x = x
let w x = Nx.scalar f64 x
let kw x = Nx.scalar f64 (x *. 1e3)
let solar_luminosity x = Nx.scalar f64 (x *. solar_luminosity_w)
let erg_s x = Nx.scalar f64 (x *. 1e-7)
let of_w x = x
let of_kw x = Nx.mul_s x 1e3
let of_solar_luminosity x = Nx.mul_s x solar_luminosity_w
let of_erg_s x = Nx.mul_s x 1e-7
let in_w x = x
let in_kw x = Nx.div_s x 1e3
let in_solar_luminosity x = Nx.div_s x solar_luminosity_w
let in_erg_s x = Nx.mul_s x (1.0 /. 1e-7)
end
module Temperature = struct
let of_tensor x = x
let to_tensor x = x
let kelvin x = Nx.scalar f64 x
let of_kelvin x = x
let in_kelvin x = x
end
module Energy = struct
let of_tensor x = x
let to_tensor x = x
let j x = Nx.scalar f64 x
let erg x = Nx.scalar f64 (x *. 1e-7)
let ev x = Nx.scalar f64 (x *. ev_j)
let kev x = Nx.scalar f64 (x *. ev_j *. 1e3)
let mev x = Nx.scalar f64 (x *. ev_j *. 1e6)
let of_j x = x
let of_erg x = Nx.mul_s x 1e-7
let of_ev x = Nx.mul_s x ev_j
let of_kev x = Nx.mul_s x (ev_j *. 1e3)
let of_mev x = Nx.mul_s x (ev_j *. 1e6)
let in_j x = x
let in_erg x = Nx.mul_s x (1.0 /. 1e-7)
let in_ev x = Nx.div_s x ev_j
let in_kev x = Nx.div_s x (ev_j *. 1e3)
let in_mev x = Nx.div_s x (ev_j *. 1e6)
end
module Frequency = struct
let of_tensor x = x
let to_tensor x = x
let hz x = Nx.scalar f64 x
let khz x = Nx.scalar f64 (x *. 1e3)
let mhz x = Nx.scalar f64 (x *. 1e6)
let ghz x = Nx.scalar f64 (x *. 1e9)
let of_hz x = x
let of_khz x = Nx.mul_s x 1e3
let of_mhz x = Nx.mul_s x 1e6
let of_ghz x = Nx.mul_s x 1e9
let in_hz x = x
let in_khz x = Nx.div_s x 1e3
let in_mhz x = Nx.div_s x 1e6
let in_ghz x = Nx.div_s x 1e9
end
module Dimensionless = struct
let of_tensor x = x
let to_tensor x = x
let v x = Nx.scalar f64 x
let to_float x = Nx.item [] x
end
================================================
FILE: dev/umbra/lib/unit.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Physical quantities with compile-time dimensional safety.
A {e quantity} is an {!Nx.float64_t} tensor of arbitrary shape tagged with a
phantom dimension type. Arithmetic requires matching dimensions:
[length t + length t] typechecks, [length t + mass t] does not. Values are
stored in SI base units internally.
Each dimension module provides three families of functions:
- {e Scalar constructors} ([Length.km], [Mass.kg], ...) create a 0-d
quantity from a [float].
- {e Tensor constructors} ([Length.of_km], [Mass.of_kg], ...) wrap an
arbitrary-shape {!Nx.float64_t}.
- {e Extractors} ([Length.in_km], [Mass.in_kg], ...) return the numeric
value in a given unit as an {!Nx.float64_t}.
{[
open Unit
let d = Length.(kpc 10.0 + pc 500.0)
let d_mpc = Length.in_mpc d
]} *)
(** {1:types Types} *)
type 'a t
(** The type for a physical quantity with dimension ['a]. Internally an
{!Nx.float64_t} in SI base units. *)
type length
(** Phantom type for length (SI: metres). *)
type mass
(** Phantom type for mass (SI: kilograms). *)
type time
(** Phantom type for time duration (SI: seconds). *)
type angle
(** Phantom type for angles (SI: radians). *)
type velocity
(** Phantom type for velocity (SI: m/s). *)
type power
(** Phantom type for power / luminosity (SI: watts). *)
type temperature
(** Phantom type for temperature (SI: kelvin). *)
type energy
(** Phantom type for energy (SI: joules). *)
type frequency
(** Phantom type for frequency (SI: hertz). *)
type dimensionless
(** Phantom type for dimensionless quantities. *)
(** {1:arithmetic Arithmetic}
All operations require matching dimensions. *)
val ( + ) : 'a t -> 'a t -> 'a t
(** [a + b] is the element-wise sum of [a] and [b]. *)
val ( - ) : 'a t -> 'a t -> 'a t
(** [a - b] is the element-wise difference of [a] and [b]. *)
val neg : 'a t -> 'a t
(** [neg x] is the element-wise negation of [x]. *)
val abs : 'a t -> 'a t
(** [abs x] is the element-wise absolute value of [x]. *)
val scale : float -> 'a t -> 'a t
(** [scale s x] multiplies every element of [x] by [s]. *)
val scale_t : Nx.float64_t -> 'a t -> 'a t
(** [scale_t s x] multiplies every element of [x] by the tensor [s]. This keeps
    the result in the typed world when the scale factor is itself a tensor,
    such as a fitted parameter. *)
val ratio : 'a t -> 'a t -> dimensionless t
(** [ratio a b] is the element-wise division [a / b], yielding a dimensionless
quantity. *)
val zero : 'a t
(** [zero] is the scalar quantity [0.0]. *)
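(** As an illustrative sketch (assuming [open Unit]; values exact up to
    floating-point rounding):
    {[
      let d = Length.km 1.0 + Length.m 500.0   (* both lengths: typechecks *)
      let half = scale 0.5 d
      let r = ratio half d                     (* dimensionless, 0.5 *)
      (* Length.m 1.0 + Time.s 1.0 *)          (* rejected by the typechecker *)
    ]} *)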
(** {1:predicates Predicates, comparisons, and conversion}
These functions extract scalar values and are intended for 0-d tensors. *)
val compare : 'a t -> 'a t -> int
(** [compare a b] orders [a] and [b] by their scalar SI values. *)
val equal : 'a t -> 'a t -> bool
(** [equal a b] is [true] iff [a] and [b] have the same scalar SI value. *)
val pp : Format.formatter -> 'a t -> unit
(** [pp] formats the scalar SI value of a quantity. *)
val to_float : 'a t -> float
(** [to_float x] is the scalar value of [x] in SI base units. *)
(** {1:cross Cross-dimension combinators}
Functions that relate quantities of different dimensions. *)
val length_per_time : length t -> time t -> velocity t
(** [length_per_time d t] is [d / t] as a velocity. *)
val velocity_times_time : velocity t -> time t -> length t
(** [velocity_times_time v t] is [v * t] as a length. *)
val length_per_velocity : length t -> velocity t -> time t
(** [length_per_velocity d v] is [d / v] as a time. *)
val wavelength_to_frequency : length t -> frequency t
(** [wavelength_to_frequency lam] is [c / lam]. *)
val frequency_to_wavelength : frequency t -> length t
(** [frequency_to_wavelength nu] is [c / nu]. *)
val frequency_to_energy : frequency t -> energy t
(** [frequency_to_energy nu] is [h * nu]. *)
val energy_to_frequency : energy t -> frequency t
(** [energy_to_frequency e] is [e / h]. *)
val energy_to_wavelength : energy t -> length t
(** [energy_to_wavelength e] is [h * c / e]. *)
val parallax_to_distance : angle t -> length t
(** [parallax_to_distance p] is the distance corresponding to parallax [p]. Uses
[d = 1 AU / p]. One arcsecond of parallax gives one parsec. *)
val distance_to_parallax : length t -> angle t
(** [distance_to_parallax d] is the parallax corresponding to distance [d]. Uses
[p = 1 AU / d]. *)
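(** For example, a parallax of 10 milliarcseconds corresponds to a distance of
    100 parsecs, and the round trip recovers the parallax (values up to
    floating-point rounding):
    {[
      let d = parallax_to_distance (Angle.mas 10.0) in
      Length.in_pc d                            (* 0-d tensor, approx. 100.0 *)
    ]} *)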
val flam_to_fnu : wavelength:length t -> Nx.float64_t -> Nx.float64_t
(** [flam_to_fnu ~wavelength flam] converts spectral flux density from
per-wavelength (F{_ {e lambda}}, W m{^ -2} m{^ -1}) to per-frequency
(F{_ {e nu}}, W m{^ -2} Hz{^ -1}) at the given wavelengths. Uses
[f_nu = f_lambda * lambda{^ 2} / c]. *)
val fnu_to_flam : wavelength:length t -> Nx.float64_t -> Nx.float64_t
(** [fnu_to_flam ~wavelength fnu] converts spectral flux density from
per-frequency (F{_ {e nu}}) to per-wavelength (F{_ {e lambda}}) at the given
wavelengths. Uses [f_lambda = f_nu * c / lambda{^ 2}]. *)
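(** As a sketch of the two conversions (assuming a tensor [flam] of flux
    densities is already in hand; the round trip is the identity up to
    floating-point rounding):
    {[
      (* flam : Nx.float64_t, flux density in W m^-2 m^-1 at 550 nm *)
      let lam = Length.nm 550.0
      let fnu = flam_to_fnu ~wavelength:lam flam    (* W m^-2 Hz^-1 *)
      let flam' = fnu_to_flam ~wavelength:lam fnu   (* back to W m^-2 m^-1 *)
    ]} *)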
(** {2:doppler Doppler conventions}
    Three conventions for converting between radial velocity and observed
    wavelength, given a rest wavelength. All three agree to first order in
    [v/c]. *)
val doppler_optical : rest:length t -> velocity t -> length t
(** [doppler_optical ~rest v] is the observed wavelength under the optical (cz)
convention: [lambda_obs = lambda_rest * (1 + v/c)]. *)
val doppler_optical_inv : rest:length t -> length t -> velocity t
(** [doppler_optical_inv ~rest obs] is the velocity under the optical
convention: [v = c * (lambda_obs/lambda_rest - 1)]. *)
val doppler_radio : rest:length t -> velocity t -> length t
(** [doppler_radio ~rest v] is the observed wavelength under the radio
convention: [lambda_obs = lambda_rest / (1 - v/c)]. *)
val doppler_radio_inv : rest:length t -> length t -> velocity t
(** [doppler_radio_inv ~rest obs] is the velocity under the radio convention:
[v = c * (1 - lambda_rest/lambda_obs)]. *)
val doppler_relativistic : rest:length t -> velocity t -> length t
(** [doppler_relativistic ~rest v] is the observed wavelength under the full
relativistic Doppler formula:
[lambda_obs = lambda_rest * sqrt((1 + v/c) / (1 - v/c))]. *)
val doppler_relativistic_inv : rest:length t -> length t -> velocity t
(** [doppler_relativistic_inv ~rest obs] is the velocity under the relativistic
formula: [v = c * (r{^ 2} - 1) / (r{^ 2} + 1)] where
[r = lambda_obs/lambda_rest]. *)
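(** As an illustrative round trip: for H-alpha at a rest wavelength of
    656.28 nm and a recession velocity of 300 km/s ([v/c] of about [1e-3]),
    the optical convention shifts the line redward by roughly 0.66 nm, and
    inverting recovers the input velocity up to floating-point rounding:
    {[
      let rest = Length.nm 656.28 in
      let obs = doppler_optical ~rest (Velocity.km_s 300.0) in
      Velocity.in_km_s (doppler_optical_inv ~rest obs)  (* approx. 300.0 *)
    ]} *)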
(** {1:length Length} *)
module Length : sig
val of_tensor : Nx.float64_t -> length t
(** [of_tensor x] wraps [x] as a length. [x] must be in metres. *)
val to_tensor : length t -> Nx.float64_t
(** [to_tensor x] is the underlying tensor in metres. *)
(** {2:scalar Scalar constructors}
Each function creates a 0-d length quantity from a [float] value in the
named unit. *)
val m : float -> length t
(** [m x] is [x] metres. *)
val km : float -> length t
(** [km x] is [x] kilometres. *)
val cm : float -> length t
(** [cm x] is [x] centimetres. *)
val mm : float -> length t
(** [mm x] is [x] millimetres. *)
val um : float -> length t
(** [um x] is [x] micrometres. *)
val nm : float -> length t
(** [nm x] is [x] nanometres. *)
val angstrom : float -> length t
(** [angstrom x] is [x] angstroms. *)
val au : float -> length t
(** [au x] is [x] astronomical units. *)
val pc : float -> length t
(** [pc x] is [x] parsecs. *)
val kpc : float -> length t
(** [kpc x] is [x] kiloparsecs. *)
val mpc : float -> length t
(** [mpc x] is [x] megaparsecs. *)
val gpc : float -> length t
(** [gpc x] is [x] gigaparsecs. *)
val ly : float -> length t
(** [ly x] is [x] light-years. *)
val solar_radius : float -> length t
(** [solar_radius x] is [x] solar radii. *)
val earth_radius : float -> length t
(** [earth_radius x] is [x] Earth equatorial radii. *)
val jupiter_radius : float -> length t
(** [jupiter_radius x] is [x] Jupiter equatorial radii. *)
(** {2:tensor Tensor constructors}
Each function wraps an arbitrary-shape {!Nx.float64_t} (in the named unit)
as a length quantity. *)
val of_m : Nx.float64_t -> length t
val of_km : Nx.float64_t -> length t
val of_cm : Nx.float64_t -> length t
val of_mm : Nx.float64_t -> length t
val of_um : Nx.float64_t -> length t
val of_nm : Nx.float64_t -> length t
val of_angstrom : Nx.float64_t -> length t
val of_au : Nx.float64_t -> length t
val of_pc : Nx.float64_t -> length t
val of_kpc : Nx.float64_t -> length t
val of_mpc : Nx.float64_t -> length t
val of_gpc : Nx.float64_t -> length t
val of_ly : Nx.float64_t -> length t
val of_solar_radius : Nx.float64_t -> length t
val of_earth_radius : Nx.float64_t -> length t
val of_jupiter_radius : Nx.float64_t -> length t
(** {2:extract Extracting}
Each function returns the numeric value in the named unit as an
{!Nx.float64_t}. *)
val in_m : length t -> Nx.float64_t
val in_km : length t -> Nx.float64_t
val in_cm : length t -> Nx.float64_t
val in_mm : length t -> Nx.float64_t
val in_um : length t -> Nx.float64_t
val in_nm : length t -> Nx.float64_t
val in_angstrom : length t -> Nx.float64_t
val in_au : length t -> Nx.float64_t
val in_pc : length t -> Nx.float64_t
val in_kpc : length t -> Nx.float64_t
val in_mpc : length t -> Nx.float64_t
val in_gpc : length t -> Nx.float64_t
val in_ly : length t -> Nx.float64_t
val in_solar_radius : length t -> Nx.float64_t
val in_earth_radius : length t -> Nx.float64_t
val in_jupiter_radius : length t -> Nx.float64_t
end
(** {1:mass Mass} *)
module Mass : sig
val of_tensor : Nx.float64_t -> mass t
(** [of_tensor x] wraps [x] as a mass. [x] must be in kilograms. *)
val to_tensor : mass t -> Nx.float64_t
(** [to_tensor x] is the underlying tensor in kilograms. *)
(** {2:scalar Scalar constructors} *)
val kg : float -> mass t
(** [kg x] is [x] kilograms. *)
val g : float -> mass t
(** [g x] is [x] grams. *)
val mg : float -> mass t
(** [mg x] is [x] milligrams. *)
val solar_mass : float -> mass t
(** [solar_mass x] is [x] solar masses. *)
val earth_mass : float -> mass t
(** [earth_mass x] is [x] Earth masses. *)
val jupiter_mass : float -> mass t
(** [jupiter_mass x] is [x] Jupiter masses. *)
(** {2:tensor Tensor constructors} *)
val of_kg : Nx.float64_t -> mass t
val of_g : Nx.float64_t -> mass t
val of_mg : Nx.float64_t -> mass t
val of_solar_mass : Nx.float64_t -> mass t
val of_earth_mass : Nx.float64_t -> mass t
val of_jupiter_mass : Nx.float64_t -> mass t
(** {2:extract Extracting} *)
val in_kg : mass t -> Nx.float64_t
val in_g : mass t -> Nx.float64_t
val in_mg : mass t -> Nx.float64_t
val in_solar_mass : mass t -> Nx.float64_t
val in_earth_mass : mass t -> Nx.float64_t
val in_jupiter_mass : mass t -> Nx.float64_t
end
(** {1:time Time duration} *)
module Time : sig
val of_tensor : Nx.float64_t -> time t
(** [of_tensor x] wraps [x] as a time duration. [x] must be in seconds. *)
val to_tensor : time t -> Nx.float64_t
(** [to_tensor x] is the underlying tensor in seconds. *)
(** {2:scalar Scalar constructors} *)
val s : float -> time t
(** [s x] is [x] seconds. *)
val ms : float -> time t
(** [ms x] is [x] milliseconds. *)
val us : float -> time t
(** [us x] is [x] microseconds. *)
val min : float -> time t
(** [min x] is [x] minutes. *)
val hr : float -> time t
(** [hr x] is [x] hours. *)
val day : float -> time t
(** [day x] is [x] days (86 400 s). *)
val yr : float -> time t
(** [yr x] is [x] Julian years (365.25 days). *)
val myr : float -> time t
(** [myr x] is [x] megayears. *)
val gyr : float -> time t
(** [gyr x] is [x] gigayears. *)
(** {2:tensor Tensor constructors} *)
val of_s : Nx.float64_t -> time t
val of_ms : Nx.float64_t -> time t
val of_us : Nx.float64_t -> time t
val of_min : Nx.float64_t -> time t
val of_hr : Nx.float64_t -> time t
val of_day : Nx.float64_t -> time t
val of_yr : Nx.float64_t -> time t
val of_myr : Nx.float64_t -> time t
val of_gyr : Nx.float64_t -> time t
(** {2:extract Extracting} *)
val in_s : time t -> Nx.float64_t
val in_ms : time t -> Nx.float64_t
val in_us : time t -> Nx.float64_t
val in_min : time t -> Nx.float64_t
val in_hr : time t -> Nx.float64_t
val in_day : time t -> Nx.float64_t
val in_yr : time t -> Nx.float64_t
val in_myr : time t -> Nx.float64_t
val in_gyr : time t -> Nx.float64_t
end
(** {1:angle Angle} *)
module Angle : sig
val of_tensor : Nx.float64_t -> angle t
(** [of_tensor x] wraps [x] as an angle. [x] must be in radians. *)
val to_tensor : angle t -> Nx.float64_t
(** [to_tensor x] is the underlying tensor in radians. *)
(** {2:scalar Scalar constructors} *)
val rad : float -> angle t
(** [rad x] is [x] radians. *)
val deg : float -> angle t
(** [deg x] is [x] degrees. *)
val arcmin : float -> angle t
(** [arcmin x] is [x] arcminutes. *)
val arcsec : float -> angle t
(** [arcsec x] is [x] arcseconds. *)
val mas : float -> angle t
(** [mas x] is [x] milliarcseconds. *)
val hour_angle : float -> angle t
(** [hour_angle x] is [x] hour angles (1 h = 15 deg). *)
(** {2:tensor Tensor constructors} *)
val of_rad : Nx.float64_t -> angle t
val of_deg : Nx.float64_t -> angle t
val of_arcmin : Nx.float64_t -> angle t
val of_arcsec : Nx.float64_t -> angle t
val of_mas : Nx.float64_t -> angle t
val of_hour_angle : Nx.float64_t -> angle t
(** {2:extract Extracting} *)
val in_rad : angle t -> Nx.float64_t
val in_deg : angle t -> Nx.float64_t
val in_arcmin : angle t -> Nx.float64_t
val in_arcsec : angle t -> Nx.float64_t
val in_mas : angle t -> Nx.float64_t
val in_hour_angle : angle t -> Nx.float64_t
(** {2:trig Trigonometric functions} *)
val sin : angle t -> Nx.float64_t
(** [sin a] is the sine of [a]. *)
val cos : angle t -> Nx.float64_t
(** [cos a] is the cosine of [a]. *)
val tan : angle t -> Nx.float64_t
(** [tan a] is the tangent of [a]. *)
val asin : Nx.float64_t -> angle t
(** [asin x] is the arc sine of [x]. *)
val acos : Nx.float64_t -> angle t
(** [acos x] is the arc cosine of [x]. *)
val atan2 : y:Nx.float64_t -> x:Nx.float64_t -> angle t
(** [atan2 ~y ~x] is the two-argument arc tangent of [y] and [x]. *)
(** {2:wrap Wrapping} *)
val wrap_360 : angle t -> angle t
(** [wrap_360 a] normalizes [a] into \[0, 360) degrees. *)
val wrap_180 : angle t -> angle t
(** [wrap_180 a] normalizes [a] into \[-180, 180) degrees. *)
end
(** {1:velocity Velocity} *)
module Velocity : sig
val of_tensor : Nx.float64_t -> velocity t
(** [of_tensor x] wraps [x] as a velocity. [x] must be in m/s. *)
val to_tensor : velocity t -> Nx.float64_t
(** [to_tensor x] is the underlying tensor in m/s. *)
(** {2:scalar Scalar constructors} *)
val m_s : float -> velocity t
(** [m_s x] is [x] m/s. *)
val km_s : float -> velocity t
(** [km_s x] is [x] km/s. *)
val km_hr : float -> velocity t
(** [km_hr x] is [x] km/h. *)
(** {2:tensor Tensor constructors} *)
val of_m_s : Nx.float64_t -> velocity t
val of_km_s : Nx.float64_t -> velocity t
val of_km_hr : Nx.float64_t -> velocity t
(** {2:extract Extracting} *)
val in_m_s : velocity t -> Nx.float64_t
val in_km_s : velocity t -> Nx.float64_t
val in_km_hr : velocity t -> Nx.float64_t
end
(** {1:power Power / Luminosity} *)
module Power : sig
val of_tensor : Nx.float64_t -> power t
(** [of_tensor x] wraps [x] as a power. [x] must be in watts. *)
val to_tensor : power t -> Nx.float64_t
(** [to_tensor x] is the underlying tensor in watts. *)
(** {2:scalar Scalar constructors} *)
val w : float -> power t
(** [w x] is [x] watts. *)
val kw : float -> power t
(** [kw x] is [x] kilowatts. *)
val solar_luminosity : float -> power t
(** [solar_luminosity x] is [x] solar luminosities. *)
val erg_s : float -> power t
(** [erg_s x] is [x] erg/s. *)
(** {2:tensor Tensor constructors} *)
val of_w : Nx.float64_t -> power t
val of_kw : Nx.float64_t -> power t
val of_solar_luminosity : Nx.float64_t -> power t
val of_erg_s : Nx.float64_t -> power t
(** {2:extract Extracting} *)
val in_w : power t -> Nx.float64_t
val in_kw : power t -> Nx.float64_t
val in_solar_luminosity : power t -> Nx.float64_t
val in_erg_s : power t -> Nx.float64_t
end
(** {1:temperature Temperature} *)
module Temperature : sig
val of_tensor : Nx.float64_t -> temperature t
(** [of_tensor x] wraps [x] as a temperature. [x] must be in kelvin. *)
val to_tensor : temperature t -> Nx.float64_t
(** [to_tensor x] is the underlying tensor in kelvin. *)
(** {2:scalar Scalar constructors} *)
val kelvin : float -> temperature t
(** [kelvin x] is [x] kelvin. *)
(** {2:tensor Tensor constructors} *)
val of_kelvin : Nx.float64_t -> temperature t
(** {2:extract Extracting} *)
val in_kelvin : temperature t -> Nx.float64_t
end
(** {1:energy Energy} *)
module Energy : sig
val of_tensor : Nx.float64_t -> energy t
(** [of_tensor x] wraps [x] as an energy. [x] must be in joules. *)
val to_tensor : energy t -> Nx.float64_t
(** [to_tensor x] is the underlying tensor in joules. *)
(** {2:scalar Scalar constructors} *)
val j : float -> energy t
(** [j x] is [x] joules. *)
val erg : float -> energy t
(** [erg x] is [x] ergs. *)
val ev : float -> energy t
(** [ev x] is [x] electronvolts. *)
val kev : float -> energy t
(** [kev x] is [x] kiloelectronvolts. *)
val mev : float -> energy t
(** [mev x] is [x] megaelectronvolts. *)
(** {2:tensor Tensor constructors} *)
val of_j : Nx.float64_t -> energy t
val of_erg : Nx.float64_t -> energy t
val of_ev : Nx.float64_t -> energy t
val of_kev : Nx.float64_t -> energy t
val of_mev : Nx.float64_t -> energy t
(** {2:extract Extracting} *)
val in_j : energy t -> Nx.float64_t
val in_erg : energy t -> Nx.float64_t
val in_ev : energy t -> Nx.float64_t
val in_kev : energy t -> Nx.float64_t
val in_mev : energy t -> Nx.float64_t
end
(** {1:frequency Frequency} *)
module Frequency : sig
val of_tensor : Nx.float64_t -> frequency t
(** [of_tensor x] wraps [x] as a frequency. [x] must be in hertz. *)
val to_tensor : frequency t -> Nx.float64_t
(** [to_tensor x] is the underlying tensor in hertz. *)
(** {2:scalar Scalar constructors} *)
val hz : float -> frequency t
(** [hz x] is [x] hertz. *)
val khz : float -> frequency t
(** [khz x] is [x] kilohertz. *)
val mhz : float -> frequency t
(** [mhz x] is [x] megahertz. *)
val ghz : float -> frequency t
(** [ghz x] is [x] gigahertz. *)
(** {2:tensor Tensor constructors} *)
val of_hz : Nx.float64_t -> frequency t
val of_khz : Nx.float64_t -> frequency t
val of_mhz : Nx.float64_t -> frequency t
val of_ghz : Nx.float64_t -> frequency t
(** {2:extract Extracting} *)
val in_hz : frequency t -> Nx.float64_t
val in_khz : frequency t -> Nx.float64_t
val in_mhz : frequency t -> Nx.float64_t
val in_ghz : frequency t -> Nx.float64_t
end
(** {1:dimensionless Dimensionless} *)
module Dimensionless : sig
val of_tensor : Nx.float64_t -> dimensionless t
(** [of_tensor x] wraps [x] as a dimensionless quantity. *)
val to_tensor : dimensionless t -> Nx.float64_t
(** [to_tensor x] is the underlying tensor. *)
val v : float -> dimensionless t
(** [v x] is the scalar dimensionless quantity [x]. *)
val to_float : dimensionless t -> float
(** [to_float x] is the scalar value of [x]. Intended for 0-d tensors. *)
end
================================================
FILE: dev/umbra/lib/vega_data.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Vega (alpha Lyrae) reference spectrum from CALSPEC: alpha_lyr_stis_011.fits
(Bohlin 2014). Subsampled to 300 points from 900-250000 Angstroms.
wave: wavelength in Angstroms. flux: spectral flux density f_lambda in
W/m^2/m (SI). *)
[@@@ocamlformat "disable"]
let wave =
[|
900.452; 917.759; 935.398; 953.377; 971.701;
989.386; 1008.402; 1027.784; 1047.539; 1067.673;
1087.104; 1107.998; 1129.294; 1151.0; 1172.199951171875;
1198.199951171875; 1230.0; 1240.699951171875; 1264.300048828125; 1287.9000244140625;
1312.5999755859375; 1337.4000244140625; 1363.4000244140625; 1388.0999755859375; 1415.300048828125;
1442.4000244140625; 1469.5; 1496.699951171875; 1526.199951171875; 1554.5;
1584.0; 1614.5999755859375; 1645.300048828125; 1676.115356; 1707.652832;
1740.569458; 1773.494019; 1807.798096; 1842.109619; 1876.428101;
1912.126099; 1947.830688; 1984.914551; 2023.37793; 2061.84668;
2100.320312; 2140.172119; 2181.401855; 2222.634521; 2265.243652;
2307.85376; 2351.838867; 2395.823242; 2441.179932; 2487.907959;
2536.005615; 2584.09668; 2632.18042; 2683.001953; 2733.811768;
2785.980957; 2838.134277; 2891.641846; 2946.500488; 3002.706055;
3060.254883; 3119.17334; 3176.848877; 3237.280029; 3300.467529;
3363.663574; 3426.86792; 3492.827881; 3558.794678; 3624.767822;
3693.495117; 3764.976562; 3836.461914; 3907.949951; 3982.189453;
4059.178955; 4136.16748; 4213.153809; 4292.885742; 4375.36084;
4457.828613; 4543.035156; 4630.977539; 4718.905273; 4806.816406;
4897.454102; 4990.81543; 5086.894531; 5182.942871; 5281.700684;
5380.419922; 5485.38916; 5587.699219; 5694.90625; 5802.137207;
5914.267578; 6026.421387; 6138.598145; 6255.674805; 6377.652344;
6494.770996; 6621.67041; 6743.70752; 6875.525879; 7002.478027;
7139.210449; 7271.072266; 7412.712402; 7549.477051; 7696.016602;
7842.55957; 7989.103027; 8140.52832; 8296.831055; 8453.124023;
8614.286133; 8775.428711; 8946.311523; 9112.282227; 9287.978516;
9463.631836; 9644.114258; 9824.539062; 10014.648438; 10205.378729443204;
10401.529770846228; 10601.450905715283; 10794.393204668508; 11001.865315614596; 11213.325115099247;
11428.849247658822; 11648.515831526549; 11860.514306414294; 12088.477647568725; 12320.82252599989;
12557.633156314389; 12786.177214329819; 13031.93212871661; 13282.410540749495; 13537.703237775104;
13797.902752017966; 14049.019236198084; 14319.046427392743; 14594.263638185937; 14874.770622496817;
15145.48568465375; 15436.587354440591; 15733.284102879634; 16035.6834692452; 16343.895059966944;
16641.34761761389; 16961.200290683963; 17287.200646834564; 17619.466846796713; 17940.134316746688;
18284.950136177056; 18636.393439411142; 18994.591609034607; 19359.674476061784; 19712.013050644106;
20090.8850151999; 20477.039034852016; 20870.61507330569; 21250.45218142002; 21658.893498090372;
22075.185203484638; 22499.47818484361; 22931.926229505145; 23349.278401987125; 23798.05991182052;
24255.46716331803; 24721.665946277586; 25171.590688206736; 25655.39769412964; 26148.503644305452;
26651.08726772658; 27163.33072884503; 27657.69282371489; 28189.283604893368; 28731.09175155295;
29283.313645250695; 29816.258607862546; 30389.33779785106; 30973.431775616467; 31568.752249243276;
32175.51499603123; 32761.096902670255; 33390.77694148892; 34032.559656653226; 34686.6776658844;
35317.961760150065; 35996.78565914454; 36688.65679708345; 37393.82594651755; 38112.548700060266;
38806.182319759384; 39552.05106970492; 40312.25568451587; 41087.071705009694; 41834.84049088929;
42638.92113772262; 43458.45650302316; 44293.74363220345; 45145.08527967651; 45966.708340478566;
46850.20497002188; 47750.6827218135; 48668.46797921292; 49603.893398540575; 50506.66503950412;
51477.42126605539; 52466.83577548603; 53475.2671869996; 54448.49633089426; 55495.01596936514;
56561.650090809424; 57648.78530287201; 58756.81564377781; 59826.16692337332; 60976.04732038082;
62148.02883123579; 63342.53624731101; 64495.34586512943; 65734.96955693385; 66998.41926097606;
68286.15292170878; 69598.63728556872; 70865.30551837588; 72227.3621205334; 73615.59793944974;
75030.51614901233; 76396.04247255616; 77864.4018237676; 79360.98356865476; 80886.33015165699;
82440.99444348396; 83941.38860077565; 85554.77221228057; 87199.16563569565; 88875.16489081086;
90492.65845739808; 92231.95982648393; 94004.69119201355; 95811.49509052176; 97653.02640827911;
99430.27365358408; 101341.3591939848; 103289.17648427008; 105274.43152183376; 107190.38538724542;
109250.62305986807; 111350.45923995743; 113490.65502485144; 115671.98614086409; 117777.17147181589;
120040.89097212761; 122348.11997351186; 124699.69474434588; 126969.18087410457; 129409.57409375694;
131896.8725467459; 134431.9777685813; 137015.80862256046; 139509.44325303897; 142190.86481564713;
144923.8242628955; 147709.31217146275; 150397.56442345414; 153288.25958223912; 156234.51493917976;
159237.39838144454; 162297.9983210655; 165251.7590085775; 168427.95711765194; 171665.20289426477;
174964.66969630358; 178148.95889525744; 181573.045814387; 185062.94491283095; 188619.92112426818;
192245.26369497634; 195744.05300937625; 199506.32395143431; 203340.90709103053; 207249.19229470106;
211021.04729625938; 215076.94755071797; 219210.80366346712; 223424.11397425184; 227718.40562110202;
231862.792374509; 236319.27844847148; 240861.41978391458; 245490.86270582714; 249958.7012923323
|]
let flux =
[|
1.2837948759614193e-10; 1.1943805588998657e-07; 1.173052282865683e-06; 2.1574317088379757e-06; 1.3420374216366326e-06;
4.100760634173639e-06; 6.2335830079973675e-06; 4.519122285273625e-06; 1.093481569114374e-05; 1.3420373761618976e-05;
1.8457114492775872e-05; 4.486309626372531e-05; 0.0006154832663014531; 0.0017448125872761011; 0.00016285567835438997;
0.00012107902148272842; 0.0002881045511458069; 0.0008428706787526608; 0.00555039057508111; 0.021335331723093987;
0.032872095704078674; 0.04942339286208153; 0.05074312165379524; 0.05717964470386505; 0.062469806522130966;
0.0662260353565216; 0.06470223516225815; 0.0783146470785141; 0.06845968961715698; 0.07111997902393341;
0.0734374150633812; 0.0725170373916626; 0.07288476824760437; 0.06900300085544586; 0.0641390010714531;
0.06533699482679367; 0.06514099985361099; 0.06312000006437302; 0.058240000158548355; 0.060812000185251236;
0.059772998094558716; 0.060645997524261475; 0.05867400020360947; 0.0558370016515255; 0.05407499894499779;
0.05430099740624428; 0.04962899908423424; 0.05051000043749809; 0.047245997935533524; 0.044964998960494995;
0.04547500237822533; 0.04311799630522728; 0.03780499845743179; 0.041788000613451004; 0.04115099832415581;
0.03668300062417984; 0.03907399997115135; 0.03461499884724617; 0.040741000324487686; 0.03929299861192703;
0.03749600052833557; 0.037234000861644745; 0.03612099960446358; 0.03686100244522095; 0.03653800114989281;
0.036056000739336014; 0.03495499864220619; 0.03484500199556351; 0.033263999968767166; 0.03422100096940994;
0.032947998493909836; 0.032896000891923904; 0.03175799921154976; 0.031553998589515686; 0.03090999834239483;
0.0309200007468462; 0.045221999287605286; 0.033542998135089874; 0.07968499511480331; 0.06394299864768982;
0.08404000103473663; 0.07898300141096115; 0.07822800427675247; 0.0719740018248558; 0.0667010024189949;
0.06643600016832352; 0.06298799812793732; 0.05953500047326088; 0.05647499859333038; 0.05225900188088417;
0.047912996262311935; 0.04777200147509575; 0.04543299973011017; 0.04206399992108345; 0.04022299870848656;
0.03820599988102913; 0.03601999953389168; 0.03411700204014778; 0.032175999134778976; 0.03042599931359291;
0.028788000345230103; 0.027058999985456467; 0.02562600001692772; 0.024166999384760857; 0.02274700067937374;
0.021289000287652016; 0.019968999549746513; 0.019201001152396202; 0.01802999898791313; 0.017007999122142792;
0.01603100076317787; 0.015209999866783619; 0.014260999858379364; 0.013499000109732151; 0.012649999931454659;
0.011848000809550285; 0.011309999972581863; 0.010559000074863434; 0.009970699436962605; 0.009328600019216537;
0.0090791005641222; 0.008882399648427963; 0.009110400453209877; 0.008642599917948246; 0.008087899535894394;
0.007705099880695343; 0.00721339974552393; 0.006843299604952335; 0.006151599809527397; 0.006024217698723078;
0.005648050922900438; 0.0052790273912250996; 0.004945714958012104; 0.004534630570560694; 0.0043441662564873695;
0.004064818844199181; 0.0037846784107387066; 0.0035608832258731127; 0.0033259775955229998; 0.0031156735494732857;
0.0029069569427520037; 0.002545868745073676; 0.002553011057898402; 0.0023879422806203365; 0.0022300160489976406;
0.0020839935168623924; 0.001953843282535672; 0.0018244864186272025; 0.0017038591904565692; 0.0015895807882770896;
0.0014951423509046435; 0.0013784831389784813; 0.0012856320245191455; 0.0012396031524986029; 0.0011531008640304208;
0.0011078655952587724; 0.0010332672391086817; 0.0009435904212296009; 0.0008983552106656134; 0.0008412160095758736;
0.0007756645791232586; 0.0007241599960252643; 0.0006796390516683459; 0.0006226585246622562; 0.0005917081143707037;
0.0005520281265489757; 0.0005140146822668612; 0.0004784614429809153; 0.00044632062781602144; 0.00036132606328465044;
0.0003871974186040461; 0.0003606118552852422; 0.0003356928064022213; 0.0003139481705147773; 0.00029244160396046937;
0.0002722047793213278; 0.0002522854192648083; 0.0002364928077440709; 0.00021379583631642163; 0.00019514624727889895;
0.00018959103908855468; 0.00017943295824807137; 0.0001626086450414732; 0.00015602176426909864; 0.00013110271538607776;
0.0001348326331935823; 0.0001258649572264403; 0.0001058662383002229; 0.00010872319398913532; 0.00010110463335877284;
9.396224049851298e-05; 8.72959935804829e-05; 8.126463944790885e-05; 7.560627273051068e-05; 7.031296263448894e-05;
6.562278576893732e-05; 6.097228470025584e-05; 5.6718592531979084e-05; 4.850483310292475e-05; 4.869529584539123e-05;
4.574310514726676e-05; 4.2584575567161664e-05; 3.8981630495982245e-05; 3.6783359973924235e-05; 3.387084507266991e-05;
3.1902720365906134e-05; 2.955366471724119e-05; 2.7506175683811307e-05; 2.5553918021614663e-05; 2.3807999241398647e-05;
2.1720832592109218e-05; 2.0554240109049715e-05; 1.9094017261522822e-05; 1.7729023966239765e-05; 1.6522752048331313e-05;
1.520537625765428e-05; 1.422924833605066e-05; 1.3229311662144028e-05; 1.2340479770500679e-05; 1.143577628681669e-05;
1.063423951563891e-05; 9.872383998299483e-06; 9.134336323768366e-06; 8.507391612511128e-06; 7.931238542369101e-06;
7.359846222243505e-06; 6.820991984568536e-06; 6.3662591855973005e-06; 5.908352250116877e-06; 5.479807896335842e-06;
5.0814210226235446e-06; 4.719538992503658e-06; 4.393369636090938e-06; 4.080691269336967e-06; 3.7862655517528765e-06;
3.291852635811665e-06; 3.273599986641784e-06; 3.0180608519003727e-06; 2.8196607217978453e-06; 2.615705398056889e-06;
2.4268288143503014e-06; 2.2609663119510515e-06; 2.0966911051800707e-06; 1.937177557920222e-06; 1.8046463310383842e-06;
1.6800512412373791e-06; 1.5578368675051024e-06; 1.4118143099040026e-06; 1.3411839745458565e-06; 1.2443647392501589e-06;
1.1586560049181571e-06; 1.0753279866548837e-06; 9.967616279027425e-07; 9.150207347374817e-07; 8.610560371380416e-07;
7.983616114870529e-07; 7.404287885037775e-07; 6.840038508926227e-07; 6.342451115415315e-07; 5.927398092353542e-07;
5.49964795482083e-07; 5.098086148791481e-07; 4.7227135269167775e-07; 4.399718420700083e-07; 4.0822783375915606e-07;
3.7775359373881656e-07; 3.5092992334284645e-07; 3.2537599281567964e-07; 3.019648033841804e-07; 2.7887102760359994e-07;
2.6038017608698283e-07; 2.414131188288593e-07; 2.2458880266640335e-07; 2.0831998881476466e-07; 1.930828688045949e-07;
1.790361494613535e-07; 1.6260864299511013e-07; 1.5435520595019625e-07; 1.4181631513565662e-07; 1.32769287120027e-07;
1.230079931247019e-07; 1.1451648163074424e-07; 1.0618367696224595e-07; 9.83270425081173e-08; 9.118463850654734e-08;
8.451839761391966e-08; 7.86695650845104e-08; 7.293977688505038e-08; 6.760678417094823e-08; 6.264678376055599e-08;
5.8289920445986354e-08; 5.4044157593580167e-08; 5.0052349820361997e-08; 4.447334234214395e-08; 4.3013120176738084e-08;
4.002918529977251e-08; 3.6950016379933004e-08; 3.4370817303397416e-08; 3.1863038429946755e-08; 2.9625088160400992e-08
|]
================================================
FILE: dev/umbra/papers/perlmutter1999/.gitignore
================================================
/data/
================================================
FILE: dev/umbra/papers/perlmutter1999/download_data.sh
================================================
#!/usr/bin/env bash
# Download Pantheon+ Type Ia supernova data (Scolnic et al. 2022, Brout et al. 2022).
#
# The Pantheon+ compilation contains 1701 light curves of 1550 unique SNe Ia
# spanning 0.001 < z < 2.26, extending the original 42 high-z supernovae from
# Perlmutter et al. (1999) that first demonstrated cosmic acceleration.
#
# Source: https://github.com/PantheonPlusSH0ES/DataRelease
# Papers: arXiv:2112.03863 (data), arXiv:2202.04077 (cosmology)
set -euo pipefail
DIR="$(cd "$(dirname "$0")" && pwd)"
DATA_DIR="${DIR}/data"
mkdir -p "${DATA_DIR}"
BASE_URL="https://raw.githubusercontent.com/PantheonPlusSH0ES/DataRelease/main/Pantheon%2B_Data/4_DISTANCES_AND_COVAR"
echo "Downloading Pantheon+ SN Ia distance data..."
curl -fSL "${BASE_URL}/Pantheon%2BSH0ES.dat" -o "${DATA_DIR}/Pantheon+SH0ES.dat"
echo " -> ${DATA_DIR}/Pantheon+SH0ES.dat ($(wc -l < "${DATA_DIR}/Pantheon+SH0ES.dat") lines)"
echo "Downloading paper PDF (Perlmutter et al. 1999, arXiv:astro-ph/9812133)..."
curl -fSL "https://arxiv.org/pdf/astro-ph/9812133" -o "${DATA_DIR}/perlmutter1999.pdf"
echo " -> ${DATA_DIR}/perlmutter1999.pdf"
echo "Done."
================================================
FILE: dev/umbra/papers/perlmutter1999/perlmutter1999.md
================================================
# The Accelerating Universe
Reproducing the key result of Perlmutter et al. (1999), "Measurements of
$\Omega$ and $\Lambda$ from 42 High-Redshift Supernovae" (ApJ 517, 565) --
the Nobel Prize-winning discovery that the expansion of the universe is
accelerating.
We use the modern Pantheon+ dataset (Scolnic et al. 2022, 1701 SNe Ia spanning
$0.001 < z < 2.26$) which extends the original 42 supernovae and confirms the
result with far greater precision.
## Background
Type Ia supernovae (SNe Ia) are "standardizable candles": after correcting for
the correlation between peak luminosity and light-curve width, they have
remarkably uniform absolute magnitudes. This lets us measure their distances
through the **distance modulus**:
$$\mu = m - M = 5 \log_{10}\!\left(\frac{d_L}{\text{Mpc}}\right) + 25$$
where $d_L$ is the luminosity distance, which depends on the cosmological
parameters $\Omega_M$ (matter density) and $\Omega_\Lambda$ (dark energy
density). In 1998--1999, two independent teams (the Supernova Cosmology Project
and the High-z Supernova Search Team) found that distant SNe Ia are **fainter
than expected** in a decelerating universe -- implying that the expansion is
accelerating, driven by a cosmological constant or dark energy.
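As a quick sanity check of the formula, here is a plain-OCaml sketch, self-contained and independent of the libraries used below (`distance_modulus` is an illustrative helper, not part of Umbra):

```ocaml
(* A minimal sketch of the distance modulus, assuming d_L is in Mpc.
   At d_L = 10 pc = 1e-5 Mpc, mu = 0 by the definition of absolute magnitude. *)
let distance_modulus ~d_l_mpc = 5.0 *. Float.log10 d_l_mpc +. 25.0

let () =
  assert (abs_float (distance_modulus ~d_l_mpc:1e-5) < 1e-9);
  (* +5 mag for every factor of 10 in distance: *)
  assert (abs_float (distance_modulus ~d_l_mpc:100.0
                     -. distance_modulus ~d_l_mpc:10.0 -. 5.0) < 1e-9)
```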
We reproduce three key results:
1. The **Hubble diagram** ($\mu$ vs $z$) with cosmological model curves
2. **Residuals** relative to an empty universe, showing the acceleration signal
3. **Confidence contours** in the $\Omega_M$--$\Omega_\Lambda$ plane
## Setup
```ocaml
#require "umbra";;
open Nx
open Umbra
let f64 = Nx.float64
let f32 = Nx.float32
```
val f64 : (float, Nx.float64_elt) Nx.dtype = Nx.Float64
val f32 : (float, Nx.float32_elt) Nx.dtype = Nx.Float32
## Loading the Pantheon+ data
The Pantheon+ compilation (Scolnic et al. 2022) provides standardized distance
moduli for 1701 SN Ia light curves from 18 surveys. We load the data file
(downloaded by `download_data.sh`) and extract redshift, distance modulus, and
the diagonal error (for plotting; full cosmological fits require the covariance
matrix).
```ocaml
let df = Talon_csv.read ~sep:' ' "data/Pantheon+SH0ES.dat"

let () =
  Printf.printf "Loaded %d light curves, %d columns\n"
    (Talon.num_rows df) (List.length (Talon.column_names df))
```
Loaded 1701 light curves, 47 columns
val df : Talon.t =
  CID            IDSURVEY  zHD      zHDERR   zCMB     zCMBERR  zHEL     zHELERR  m_b_corr  m_b_corr_err_DIAG  …
  2011fe         51        0.00122  0.00084  0.00122  2e-05    0.00082  2e-05    9.74571   1.51621            …
  2011fe         56        0.00122  0.00084  0.00122  2e-05    0.00082  2e-05    9.80286   1.51723            …
  2012cg         51        0.00256  0.00084  0.00256  2e-05    0.00144  2e-05    11.4703   0.781906           …
  2012cg         56        0.00256  0.00084  0.00256  2e-05    0.00144  2e-05    11.4919   0.798612           …
  1994DRichmond  50        0.00299  0.00084  0.00299  4e-05    0.00187  4e-05    11.5227   0.880798           …
  1981B          50        0.00317  0.00084  0.0035   1e-05    0.00236  1e-05    11.5416   0.613941           …
  2013aa         56        0.00331  0.00085  0.00478  0.00015  0.00411  0.00015  11.2074   0.59407            …
  2013aa         5         0.00331  0.00085  0.00478  0.00015  0.00411  0.00015  11.2998   0.579622           …
  2017cbv        5         0.00331  0.00085  0.00478  0.00015  0.00411  0.00015  11.1483   0.577815           …
  2017cbv        18        0.00331  0.00085  0.00478  0.00015  0.00411  0.00015  11.2577   0.577916           …
  2001el         50        0.00333  0.00084  0.00357  1e-05    0.00379  1e-05    12.2481   0.590389           …
  2011by         51        0.00349  0.00084  0.00369  2e-05    0.00313  2e-05    12.5403   0.55206            …
  1998aq         50        0.00349  0.00084  0.00369  1e-05    0.00313  1e-05    12.2437   0.544824           …
  1990N          50        0.00359  0.00084  0.00462  2e-05    0.00355  2e-05    12.4439   0.550332           …
  2021pit        56        0.00384  0.00084  0.00366  1e-05    0.00388  1e-05    11.7469   0.565861           …
  2005df         50        0.00407  0.00084  0.00435  1e-05    0.00435  1e-05    12.1403   0.475638           …
  2005df_ANU     50        0.00407  0.00084  0.00435  1e-05    0.00435  1e-05    12.1249   0.478515           …
  2013dy         51        0.00432  0.00084  0.00293  0.00012  0.00394  0.00012  12.246    0.513549           …
  2013dy         56        0.00432  0.00084  0.00293  0.00012  0.00394  0.00012  12.3081   0.530151           …
  2012ht         56        0.00465  0.00084  0.00465  2e-05    0.00352  2e-05    12.6779   0.441191           …
  …              …         …        …        …        …        …        …        …         …                  …
  1701 rows × 47 columns
```ocaml
let col name =
  Talon.get_column_exn df name |> Talon.Col.to_tensor f64 |> Option.get

let sn_z = col "zHD"
let sn_mu = col "MU_SH0ES"
let sn_mu_err = col "MU_SH0ES_ERR_DIAG"

let () =
  let n = (Nx.shape sn_z).(0) in
  Printf.printf "%d SNe Ia, z in [%.4f, %.3f]\n" n
    (Nx.item [] (Nx.min sn_z)) (Nx.item [] (Nx.max sn_z))
```
1701 SNe Ia, z in [0.0012, 2.261]
val col : string -> (float, Nx.float64_elt) Nx.t =
val sn_z : (float, Nx.float64_elt) Nx.t = float64 [1701]
[0.00122, 0.00122, ..., 1.91165, 2.26137]
val sn_mu : (float, Nx.float64_elt) Nx.t = float64 [1701]
[28.9987, 29.0559, ..., 45.4233, 46.1828]
val sn_mu_err : (float, Nx.float64_elt) Nx.t = float64 [1701]
[1.51645, 1.51747, ..., 0.358642, 0.281309]
## Cosmological models
We compute the theoretical distance modulus $\mu(z)$ for several cosmologies
to compare with the data. The key insight from Perlmutter et al. is that the
data prefer $\Omega_\Lambda > 0$ (accelerating expansion) over
$\Omega_\Lambda = 0$ (decelerating expansion).
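Concretely, each theory curve follows from the standard FLRW luminosity distance (sketched here; for brevity we omit the $\sinh$/$\sin$ curvature factor in the comoving distance, a small correction for the nearly flat models considered):

$$d_L(z) = (1+z)\,\frac{c}{H_0}\int_0^z \frac{dz'}{E(z')},
\qquad
E(z) = \sqrt{\Omega_M (1+z)^3 + \Omega_k (1+z)^2 + \Omega_\Lambda},
\qquad
\Omega_k = 1 - \Omega_M - \Omega_\Lambda,$$

and $\mu(z)$ is then $5 \log_{10}(d_L/\text{Mpc}) + 25$ as above. More matter decelerates the expansion and shrinks $d_L$ at fixed $z$; a cosmological constant does the opposite.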
The models we compare:
- **Best-fit $\Lambda$CDM**: $(\Omega_M, \Omega_\Lambda) = (0.3, 0.7)$, $H_0 = 70$
- **Einstein--de Sitter**: $(\Omega_M, \Omega_\Lambda) = (1, 0)$ -- matter-only, decelerating
- **Empty (Milne)**: $(\Omega_M, \Omega_\Lambda) = (0, 0)$ -- coasting, no gravity
- **Open CDM**: $(\Omega_M, \Omega_\Lambda) = (0.3, 0)$ -- matter only, curved
```ocaml
let h0 = 70.0

let p_lcdm = Cosmo.lcdm ~h0 ~omega_m:0.3 ~omega_l:0.7
let p_edsit = Cosmo.lcdm ~h0 ~omega_m:1.0 ~omega_l:0.0
(* Ω_M = 0.01 stands in for the exactly-empty (0, 0) Milne case. *)
let p_empty = Cosmo.lcdm ~h0 ~omega_m:0.01 ~omega_l:0.0
let p_open = Cosmo.lcdm ~h0 ~omega_m:0.3 ~omega_l:0.0

let n_grid = 200
let z_grid = Nx.logspace f64 (-2.5) 0.4 n_grid (* z ~ 0.003 to 2.5 *)

let mu_of_model p =
  Nx.init f64 [| n_grid |] (fun idx ->
    let z = Nx.scalar f64 (Nx.item [idx.(0)] z_grid) in
    Nx.item [] (Cosmo.distance_modulus ~p z))

let mu_lcdm = mu_of_model p_lcdm
let mu_edsit = mu_of_model p_edsit
let mu_empty = mu_of_model p_empty
let mu_open = mu_of_model p_open

let () = Printf.printf "Theory curves computed for %d redshift points\n" n_grid
```
Theory curves computed for 200 redshift points
val h0 : float = 70.
val p_lcdm : Umbra.Cosmo.params =
val p_edsit : Umbra.Cosmo.params =
val p_empty : Umbra.Cosmo.params =
val p_open : Umbra.Cosmo.params =
val n_grid : int = 200
val z_grid : (float, Nx.float64_elt) Nx.t = float64 [200]
[0.00316228, 0.00327019, ..., 2.429, 2.51189]
val mu_of_model : Umbra.Cosmo.params -> (float, Nx.float64_elt) Nx.t =
val mu_lcdm : (float, Nx.float64_elt) Nx.t = float64 [200]
[30.6639, 30.737, ..., 46.4718, 46.5603]
val mu_edsit : (float, Nx.float64_elt) Nx.t = float64 [200]
[30.6603, 30.7332, ..., 45.6533, 45.7352]
val mu_empty : (float, Nx.float64_elt) Nx.t = float64 [200]
[30.662, 30.735, ..., 46.7919, 46.9042]
val mu_open : (float, Nx.float64_elt) Nx.t = float64 [200]
[30.6615, 30.7345, ..., 46.3264, 46.4235]
## The Hubble diagram
The Hubble diagram plots distance modulus $\mu$ against redshift $z$. Distant
supernovae ($z > 0.3$) are systematically fainter than predicted by
decelerating models (Einstein--de Sitter, Open CDM), showing that the expansion
has been **accelerating**.
```ocaml
let to32 t = Nx.astype f32 t

let _fig =
  Hugin.layers [
    Hugin.point ~x:(to32 sn_z) ~y:(to32 sn_mu)
      ~color:(Hugin.Color.with_alpha 0.3 Hugin.Color.blue)
      ~size:2.0 ~marker:Hugin.Circle () ;
    Hugin.line ~x:(to32 z_grid) ~y:(to32 mu_lcdm)
      ~color:Hugin.Color.vermillion ~line_width:2.5
      ~label:"ΛCDM (0.3, 0.7)" () ;
    Hugin.line ~x:(to32 z_grid) ~y:(to32 mu_edsit)
      ~color:Hugin.Color.sky_blue ~line_width:2.0
      ~line_style:`Dashed ~label:"EdS (1, 0)" () ;
    Hugin.line ~x:(to32 z_grid) ~y:(to32 mu_empty)
      ~color:Hugin.Color.green ~line_width:2.0
      ~line_style:`Dotted ~label:"Empty (0, 0)" () ;
    Hugin.line ~x:(to32 z_grid) ~y:(to32 mu_open)
      ~color:Hugin.Color.orange ~line_width:2.0
      ~line_style:`Dash_dot ~label:"Open (0.3, 0)" () ;
  ]
  |> Hugin.xscale `Log
  |> Hugin.xlim 0.01 2.5
  |> Hugin.xlabel "Redshift z"
  |> Hugin.ylabel "Distance modulus μ (mag)"
  |> Hugin.title "SN Ia Hubble Diagram (Pantheon+, 1701 light curves)"
  |> Hugin.legend ~loc:Hugin.Lower_right
  |> Hugin.grid_lines true
```
val to32 : ('a, 'b) Nx.t -> (float, Nx.float32_elt) Nx.t =
val _fig : Hugin.t =
## Residuals: the acceleration signal
Residuals $\Delta\mu = \mu_\text{obs} - \mu_\text{empty}(z)$ relative to an
empty (coasting) universe isolate the acceleration signal. Positive residuals
at high redshift mean supernovae are **fainter than expected** -- i.e. farther
away than in a coasting universe. This is the direct evidence for cosmic
acceleration.
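For a sense of scale (an illustrative back-of-the-envelope, not part of the pipeline): a residual of $+0.2$ mag corresponds to a luminosity-distance ratio of $10^{0.2/5} \approx 1.10$, i.e. roughly 10% farther than a coasting universe predicts.

```ocaml
(* Convert a magnitude residual into a luminosity-distance ratio:
   d1/d2 = 10^(dmu/5). *)
let distance_ratio dmu = 10.0 ** (dmu /. 5.0)

let () = assert (abs_float (distance_ratio 0.2 -. 1.0965) < 1e-3)
```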
We bin the data in redshift to show the trend clearly.
```ocaml
let sn_mu_empty =
  let n = (Nx.shape sn_z).(0) in
  Nx.init f64 [| n |] (fun idx ->
    let z = Nx.scalar f64 (Nx.item [idx.(0)] sn_z) in
    Nx.item [] (Cosmo.distance_modulus ~p:p_empty z))

let sn_residual = Nx.sub sn_mu sn_mu_empty

(* Model residuals on the grid *)
let res_lcdm = Nx.sub mu_lcdm mu_empty
let res_edsit = Nx.sub mu_edsit mu_empty
let res_open = Nx.sub mu_open mu_empty

(* Bin the residuals using Talon grouping *)
let n_bins = 25
let log_z_min = Float.log10 0.01
let log_z_max = Float.log10 2.3
let bin_width = (log_z_max -. log_z_min) /. Float.of_int n_bins

let bin_df =
  let df = Talon.create [
    "z", Talon.Col.of_tensor sn_z;
    "res", Talon.Col.of_tensor sn_residual;
  ] in
  let df = Talon.filter_by df Talon.Row.(map (number "z") ~f:(fun z -> z > 0.01)) in
  Talon.with_column df "bin" f64 Talon.Row.(
    map (number "z") ~f:(fun z ->
      let b = int_of_float ((Float.log10 z -. log_z_min) /. bin_width) in
      Float.of_int (Int.max 0 (Int.min (n_bins - 1) b))))

let groups =
  Talon.group_by bin_df Talon.Row.(map (number "bin") ~f:int_of_float)
  |> List.filter (fun (_, g) -> Talon.num_rows g > 2)
  |> List.sort (fun (a, _) (b, _) -> Int.compare a b)

let n_groups = List.length groups

let bz = Nx.create f32 [| n_groups |]
  (Array.of_list (List.map (fun (_, g) -> Talon.Agg.mean g "z") groups))
let bmu = Nx.create f32 [| n_groups |]
  (Array.of_list (List.map (fun (_, g) -> Talon.Agg.mean g "res") groups))
let berr = Nx.create f32 [| n_groups |]
  (Array.of_list (List.map (fun (_, g) ->
    Talon.Agg.std g "res"
    /. Float.sqrt (Float.of_int (Talon.num_rows g - 1))) groups))
```
val sn_mu_empty : (float, Nx.float64_elt) Nx.t = float64 [1701]
[28.5917, 28.5917, ..., 46.007, 46.5544]
val sn_residual : (float, Nx.float64_elt) Nx.t = float64 [1701]
[0.40697, 0.46417, ..., -0.583671, -0.371643]
val res_lcdm : (float, Nx.float64_elt) Nx.t = float64 [200]
[0.00189584, 0.00196019, ..., -0.320084, -0.343972]
val res_edsit : (float, Nx.float64_elt) Nx.t = float64 [200]
[-0.00170018, -0.00175822, ..., -1.13865, -1.16906]
val res_open : (float, Nx.float64_elt) Nx.t = float64 [200]
[-0.000498446, -0.000515476, ..., -0.465507, -0.480768]
val n_bins : int = 25
val log_z_min : float = -2.
val log_z_max : float = 0.361727836017592841
val bin_width : float = 0.094469113440703717
val bin_df : Talon.t =
  z        res                 bin
  0.01016  -0.424629577679     0.
  0.01017  -0.287976550182     0.
  0.01017  -0.238776550182     0.
  0.01026  0.112394684341      0.
  0.01026  -0.000105315659297  0.
  0.01028  0.0491444208064     0.
  0.01042  -0.135579053456     0.
  0.01044  -0.17646444435      0.
  0.01061  0.0819784530003     0.
  0.01061  -0.0055215469997    0.
  0.01073  -0.215072175983     0.
  0.01079  -0.267045255969     0.
  0.01079  -0.154445255969     0.
  0.01096  0.201426552503      0.
  0.01114  0.328659998161      0.
  0.01114  0.231859998161      0.
  0.01122  0.414035730767      0.
  0.01122  0.0129357307671     0.
  0.01122  0.0740357307671     0.
  0.01155  0.0993356408497     0.
  …        …                   …
  1590 rows × 3 columns
val groups : (int * Talon.t) list =
  [(0,
    z  res  bin
    0.01016  -0.424629577679  0.
    0.01017  -0.287976550182  0.
    0.01017  -0.238776550182  0.
    0.01026  0.112394684341  0.
    0.01026  -0.000105315659297  0.
    0.01028  0.0491444208064  0.
    0.01042  -0.135579053456  0.
    0.01044  -0.17646444435  0.
    0.01061  0.0819784530003  0.
    0.01061  -0.0055215469997  0.
    0.01073  -0.215072175983  0.
    0.01079  -0.267045255969  0.
    0.01079  -0.154445255969  0.
    0.01096  0.201426552503  0.
    0.01114  0.328659998161  0.
    0.01114  0.231859998161  0.
    0.01122  0.414035730767  0.
    0.01122  0.0129357307671  0.
    0.01122  0.0740357307671  0.
    0.01155  0.0993356408497  0.
    …  …  …
    24 rows × 3 columns);
   (1,
    z  res  bin
    0.01246  0.178978224041  1.
    0.01258  0.0374364117569  1.
    0.01258  0.0540364117569  1.
    0.01259  -0.084899767747  1.
    0.01259  -0.154499767747  1.
    0.01279  0.0424614815497  1.
    0.01283  -0.0770620111033  1.
    0.01283  -0.210762011103  1.
    0.01303  -0.0118654606603  1.
    0.01303  0.0133345393397  1.
    0.01304  -0.173842071152  1.
    0.01304  -0.181742071152  1.
    0.01312  -0.242409144024  1.
    0.01312  -0.290709144024  1.
    0.01325  0.184241133384  1.
    0.01325  0.200641133384  1.
    0.01325  0.140741133384  1.
    0.01375  -0.0679294440597  1.
    0.01375  -0.0966294440597  1.
    0.01375  -0.0929294440597  1.
    …  …  …
    52 rows × 3 columns);
   (2,
    z  res  bin
    0.01546  -0.0881971344432  2.
    0.01549  0.14766106801  2.
    0.0155  0.131148947175  2.
    0.0155  0.0842489471753  2.
    0.0155  -0.127551052825  2.
    0.0155  -0.241751052825  2.
    0.01557  -0.346610654758  2.
    0.01557  -0.329610654758  2.
    0.01562  0.0293736691777  2.
    0.01562  -0.0136263308223  2.
    0.01565  0.183674953401  2.
    0.01576  0.00234770294004  2.
    0.01578  -0.16752766025  2.
    0.01578  -0.0964276602495  2.
    0.01581  -0.177684167024  2.
    0.01581  -0.269884167024  2.
    0.01587  0.213026247358  2.
    0.01588  -0.151252326051  2.
    0.0159  -0.0164068904934  2.
    0.0159  -0.0504068904934  2.
    …  …  …
    72 rows × 3 columns);
   (3,
    z  res  bin
    0.01947  0.148326584074  3.
    0.01947  0.182126584074  3.
    0.01975  0.0993213367396  3.
    0.01975  0.0817213367396  3.
    0.01976  0.0277114394619  3.
    0.01995  -0.31497156999  3.
    0.02001  -0.000956680793522  3.
    0.02006  -0.17972935267  3.
    0.02019  -0.198795323932  3.
    0.02019  -0.00399532393241  3.
    0.02023  -0.110835916616  3.
    0.02023  -0.209135916616  3.
    0.02023  -0.134335916616  3.
    0.02023  -0.0621359166155  3.
    0.02024  0.287380263147  3.
    0.02034  0.0942711314822  3.
    0.02034  -0.147328868518  3.
    0.02035  -0.152406886052  3.
    0.02035  -0.157706886052  3.
    0.02035  -0.230206886052  3.
    …  …  …
    92 rows × 3 columns);
   (4,
    z  res  bin
    0.02388  -0.0795275874192  4.
    0.0239  -0.187266827175  4.
    0.0239  -0.0822668271749  4.
    0.02391  -0.177885876584  4.
    0.02401  0.0154444709488  4.
    0.02411  0.00941249193803  4.
    0.02411  -0.624387508062  4.
    0.02412  -0.218198645962  4.
    0.02417  -0.158848742089  4.
    0.02417  -0.0169487420891  4.
    0.02428  0.0693737074972  4.
    0.02429  -0.0347311260491  4.
    0.02432  -0.127443419316  4.
    0.02432  -0.161543419316  4.
    0.02432  0.117156580684  4.
    0.02434  -0.106149778375  4.
    0.02453  -0.125437393885  4.
    0.02453  -0.0588373938845  4.
    0.02453  -0.184137393885  4.
    0.02457  0.0413818842716  4.
    …  …  …
    112 rows × 3 columns);
   (5,
    z  res  bin
    0.02969  -0.275399367382  5.
    0.02978  -0.058667628548  5.
    0.02978  -0.048567628548  5.
    0.0299  -0.133627805804  5.
    0.02996  0.287555242197  5.
    0.03012  -0.0656808035682  5.
    0.03012  -0.101380803568  5.
    0.03012  -0.109080803568  5.
    0.03023  -0.0747137430286  5.
    0.03031  -0.0448378058181  5.
    0.03036  -0.122370156394  5.
    0.03047  -0.0761406185614  5.
    0.03059  -0.260703391037  5.
    0.03075  -0.278601806314  5.
    0.03076  -0.0231184984297  5.
    0.03083  -0.171728924107  5.
    0.03086  -0.222272818751  5.
    0.03091  -0.149641417105  5.
    0.03096  0.00229566781329  5.
    0.03108  0.26806774982  5.
    …  …  …
    99 rows × 3 columns);
   (6,
    z  res  bin
    0.03697  0.037370558993  6.
    0.03702  -0.234817279977  6.
    0.03702  -0.111717279977  6.
    0.03702  0.0290827200227  6.
    0.03702  0.0395827200227  6.
    0.03705  -0.0785080806863  6.
    0.03707  0.00409884352897  6.
    0.03725  -0.047410467157  6.
    0.03725  -0.061410467157  6.
    0.0373  -0.187276253133  6.
    0.0374  -0.0840961258395  6.
    0.03753  -0.121668755977  6.
    0.03756  0.260864344984  6.
    0.03756  0.146064344984  6.
    0.03787  -0.0843128672071  6.
    0.0379  -0.0299641891915  6.
    0.03796  -0.00696275219823  6.
    0.03818  -0.214844515711  6.
    0.03818  -0.246944515711  6.
    0.03828  -0.00033051522886  6.
    …  …  …
    61 rows × 3 columns);
   (7,
    z  res  bin
    0.0459  -0.135594053043  7.
    0.04625  -0.269958777154  7.
    0.04631  0.00296267259202  7.
    0.04643  -0.217483496434  7.
    0.04656  -0.0402921381081  7.
    0.04664  -0.0809044161309  7.
    0.04682  -0.0821587037548  7.
    0.04682  -0.0750587037548  7.
    0.04691  0.193176209883  7.
    0.04738  -0.0229677846992  7.
    0.0476  0.00344066051645  7.
    0.04777  -0.010780094188  7.
    0.04777  -0.042680094188  7.
    0.04819  -0.127231460348  7.
    0.04837  -0.259617069447  7.
    0.0486  -0.313360481081  7.
    0.04865  -0.00494607201814  7.
    0.04934  -0.0775548977512  7.
    0.0494  -0.21085715033  7.
    0.04944  -0.0989568708902  7.
    …  …  …
    47 rows × 3 columns);
   (8,
    z  res  bin
    0.05708  -0.0115206686667  8.
    0.05728  0.0590741402491  8.
    0.05824  -0.064625197056  8.
    0.05824  -0.114325197056  8.
    0.0583  -0.0945240955029  8.
    0.05886  -0.0901701116884  8.
    0.05886  0.0756298883116  8.
    0.05974  -0.803417802658  8.
    0.06092  -0.115228066137  8.
    0.06099  0.0133048886582  8.
    0.06099  0.0427048886582  8.
    0.06121  -0.0195443596127  8.
    0.06137  -0.202880711986  8.
    0.06137  -0.194580711986  8.
    0.06153  -0.0948022912486  8.
    0.06372  0.103960473042  8.
    0.06384  -0.232250654126  8.
    0.06446  -0.194986434258  8.
    0.06533  -0.218508110773  8.
    0.06627  0.0260876633256  8.
    …  …  …
    32 rows × 3 columns);
   (9,
    z  res  bin
    0.07089  -0.0324754668871  9.
    0.0709  0.242007811196  9.
    0.07091  -0.116208867472  9.
    0.07116  -0.140111813817  9.
    0.07158  -0.259128451072  9.
    0.07167  -0.0958508194342  9.
    0.07193  0.0423148998275  9.
    0.07222  0.103375550607  9.
    0.07252  0.0846613871215  9.
    0.07393  -0.0276218037538  9.
    0.0744  -0.16077227092  9.
    0.07446  -0.147085210939  9.
    0.0752  -0.143129423208  9.
    0.0752  -0.0538294232081  9.
    0.0756  0.182834610902  9.
    0.07575  -0.0897256482245  9.
    0.07588  -0.00548430816918  9.
    0.07845  -0.143184159753  9.
    0.07859  0.112798691001  9.
    0.07875  0.292116111885  9.
    …  …  …
    33 rows × 3 columns);
   (10,
    z  res  bin
    0.0887  -0.214759965462  10.
    0.09039  -0.114490123038  10.
    0.09089  -0.0202850975671  10.
    0.09205  0.0842789146355  10.
    0.09293  0.125110163521  10.
    0.0995  0.185107403999  10.
    0.10165  -0.00302391994575  10.
    0.10221  0.0947708510874  10.
    0.10246  -0.0681907017381  10.
    0.10294  -0.148432610935  10.
    0.10361  -0.0109079016729  10.
    0.10374  -0.139664168122  10.
    0.10507  0.0529089143211  10.
    0.10661  -0.148665966028  10.
    0.10707  0.0428133675802  10.
    0.10711  0.00906130066566  10.
    0.10713  0.359235381084  10.
    0.10774  0.064581151469  10.
    0.10794  -0.0363509057905  10.
    0.10908  -0.0604317214155  10.
    20 rows × 3 columns);
   (11,
    z  res  bin
    0.11001  0.0886813554461  11.
    0.11259  0.0870049833038  11.
    0.11388  -0.172851036977  11.
    0.1165  -0.0717173699591  11.
    0.11653  0.0170929257421  11.
    0.1176  -0.0613460214423  11.
    0.11792  -0.031472958122  11.
    0.11818  -0.229620528577  11.
    0.11901  -0.0984636003687  11.
    0.12014  -0.0872353293955  11.
    0.12058  0.101178454287  11.
    0.12086  -0.0619431118695  11.
    0.12207  -0.125106114742  11.
    0.12231  -0.112215348847  11.
    0.12278  -0.0476216623284  11.
    0.12316  -0.143318307083  11.
    0.12357  -0.13125195377  11.
    0.12377  0.0495330304242  11.
    0.12383  0.0717196359588  11.
    0.12393  -0.103934884938  11.
    …  …  …
    38 rows × 3 columns);
   (12,
    z  res  bin
    0.13614  0.00634512405369  12.
    0.13658  -0.130706237119  12.
    0.137  0.0242022119563  12.
    0.13713  -0.0258886328085  12.
    0.13745  0.35822685965  12.
    0.13822  -0.12008128119  12.
    0.13826  -0.0561499783706  12.
    0.1384  0.0141110181508  12.
    0.13851  0.0294747952992  12.
    0.13875  -0.0147267381425  12.
    0.1388  0.0312404310394  12.
    0.13955  -0.0066181867595  12.
    0.14082  0.128628511143  12.
    0.14104  0.231016921037  12.
    0.14123  -0.0148979084376  12.
    0.14134  0.0663005741243  12.
    0.14325  -0.0414714699277  12.
    0.14345  0.423697521  12.
    0.14359  0.0195383381869  12.
    0.14404  -0.0256092965459  12.
    …  …  …
    64 rows × 3 columns);
   (13,
    z  res  bin
    0.16924  0.0389733472685  13.
    0.16971  0.189083761363  13.
    0.17042  -0.0185879074946  13.
    0.17124  0.0383736768909  13.
    0.17169  -0.0498724221099  13.
    0.17256  -0.0199123832614  13.
    0.1727  -0.25631246117  13.
    0.17297  0.19872715637  13.
    0.17331  -0.145074641558  13.
    0.17374  0.00231747862402  13.
    0.17378  -0.189222107681  13.
    0.17392  0.0581902517402  13.
    0.17417  0.0296229858146  13.
    0.1742  -0.212480783249  13.
    0.17438  -0.0413020372001  13.
    0.17443  0.00732580573651  13.
    0.17444  -0.062508604123  13.
    0.17498  0.188743907993  13.
    0.17666  -0.0606712840265  13.
    0.17713  -0.0738066505178  13.
    …  …  …
    117 rows × 3 columns);
   (14,
    z  res  bin
    0.21037  -0.53515754654  14.
    0.21084  -0.031562219736  14.
    0.21095  -0.190002164661  14.
    0.21114  -0.0429424845705  14.
    0.21134  0.14110646419  14.
    0.21174  -0.0372897541268  14.
    0.212  -0.0942081001471  14.
    0.21225  -0.0732110933607  14.
    0.2135  -0.00278059242427  14.
    0.21365  -0.0178518660398  14.
    0.21398  -0.142524867022  14.
    0.2144  -0.0166920577869  14.
    0.21507  -0.159519939026  14.
    0.21521  0.420530657906  14.
    0.21578  0.175431938785  14.
    0.2165  -0.232502482431  14.
    0.21689  0.00510984041876  14.
    0.21692  0.0307803130448  14.
    0.21742  0.0715943550243  14.
    0.21794  -0.0793987415813  14.
    …  …  …
    142 rows × 3 columns);
   (15,
    z  res  bin
    0.26141  -0.063389305081  15.
    0.26162  -0.0941332792955  15.
    0.26172  0.174141517223  15.
    0.26173  0.16764901455  15.
    0.26175  -0.121535981157  15.
    0.26184  0.179131697137  15.
    0.262  0.205752656004  15.
    0.26303  0.738750935002  15.
    0.26323  0.111609861946  15.
    0.2636  -0.130892774852  15.
    0.26393  0.0251761002073  15.
    0.26397  -0.230491074865  15.
    0.26408  0.0991994542632  15.
    0.26419  -0.0286096346745  15.
    0.2646  -0.0244674248267  15.
    0.26463  0.131857822634  15.
    0.26582  0.0577820584371  15.
    0.26583  -0.226609147044  15.
    0.2664  -0.266102716605  15.
    0.267  -0.0153587433182  15.
    …  …  …
    150 rows × 3 columns);
   (16,
    z  res  bin
    0.32548  0.180663773592  16.
    0.3256  0.214952098223  16.
    0.3258  0.114833307567  16.
    0.32581  -0.00134261005185  16.
    0.32632  -0.00261164522524  16.
    0.32804  -0.0658203680861  16.
    0.32842  0.23901384648  16.
    0.32848  0.281661625296  16.
    0.32851  -0.135564457581  16.
    0.32868  0.000754754949696  16.
    0.32868  0.0291547549497  16.
    0.32871  -0.143471204838  16.
    0.32907  -0.017181284091  16.
    0.32941  0.334561631049  16.
    0.32952  0.161934844394  16.
    0.32968  0.0247326862566  16.
    0.32995  -0.235994772669  16.
    0.33047  0.146304669334  16.
    0.33056  0.082430130099  16.
    0.33063  0.0834056020207  16.
    …  …  …
    130 rows × 3 columns);
   (17,
    z  res  bin
    0.40368  0.0305227159842  17.
    0.40463  0.0832671563379  17.
    0.40483  -0.255385075024  17.
    0.4055  0.183823918891  17.
    0.40646  0.0885295187297  17.
    0.40895  0.0818394810777  17.
    0.4092  0.120588849089  17.
    0.40935  0.0852588703231  17.
    0.40949  0.0566911608651  17.
    0.41004  0.0584848294338  17.
    0.41123  0.0295285142735  17.
    0.4114  0.199979142478  17.
    0.41161  0.178283386575  17.
    0.41266  0.148013322853  17.
    0.41657  0.0642467660022  17.
    0.41857  -0.0600359425703  17.
    0.41936  0.00406598556371  17.
    0.41939  -0.222016063038  17.
    0.4196  0.182509917202  17.
    0.41965  0.114106661782  17.
    …  …  …
    97 rows × 3 columns);
   (18,
    z  res  bin
    0.50282  0.0115974630251  18.
    0.50285  -0.0792578980026  18.
    0.50306  -0.00224520003424  18.
    0.50316  -0.00676282447449  18.
    0.50593  -0.0864656605981  18.
    0.50615  -0.0622987115987  18.
    0.50725  0.0635424328218  18.
    0.50739  0.0689229785269  18.
    0.50825  -0.15809275328  18.
    0.51016  -0.0926766588196  18.
    0.51095  -0.132914123298  18.
    0.51169  0.193508853341  18.
    0.51387  0.0506093976964  18.
    0.51437  0.0172694047828  18.
    0.51469  0.11564493179  18.
    0.51726  -0.130069979609  18.
    0.51883  -0.0362931904153  18.
    0.51885  -0.0488939890317  18.
    0.51941  0.00968501476334  18.
    0.51968  0.0550258324352  18.
    …  …  …
    96 rows × 3 columns);
   (19,
    z  res  bin
    0.62525  -0.0676927812368  19.
    0.62725  -0.148965950221  19.
    0.63077  -0.0857981161031  19.
    0.63183  -0.442410797453  19.
    0.63225  0.094002948885  19.
    0.63399  -0.165486457732  19.
    0.63777  -0.248879776508  19.
    0.63794  0.0843028521275  19.
    0.63824  -0.121262699059  19.
    0.63873  -0.00232867361532  19.
    0.63934  0.00470128976596  19.
    0.64185  0.0517482101598  19.
    0.64311  -0.440236071775  19.
    0.64371  -0.00444929008427  19.
    0.64371  -0.000249290084263  19.
    0.6477  -0.00131150843364  19.
    0.64852  0.215975033203  19.
    0.6487  -0.0909737694835  19.
    0.64962  -0.0477982164029  19.
    0.66213  0.0994509027894  19.
    …  …  …
    76 rows × 3 columns);
   (20,
    z  res  bin
    0.77929  -0.286728045808  20.
    0.78807  -0.0531352325948  20.
    0.78907  0.0228403967112  20.
    0.78928  -0.162099242024  20.
    0.79662  -0.484747596675  20.
    0.79863  0.039436351798  20.
    0.83981  -0.174927172808  20.
    0.83981  -0.333727172808  20.
    0.85482  0.0320794309533  20.
    0.93585  -0.308388644025  20.
    10 rows × 3 columns);
   (21,
    z  res  bin
    0.97423  0.154450661289  21.
    1.01242  0.0501694564944  21.
    1.01988  0.186919560833  21.
    1.02088  -0.121019067405  21.
    1.02789  0.42494712992  21.
    1.04817  0.254697417559  21.
    1.12092  0.0759845634931  21.
    7 rows × 3 columns);
   (22,
    z  res  bin
    1.23225  -0.324587237938  22.
    1.23597  0.250701908425  22.
    1.29911  0.0602023436219  22.
    1.3041  -0.0948606170644  22.
    1.30611  0.199592161811  22.
    1.31317  -0.0496838731239  22.
    1.3291  0.00165723209277  22.
    1.34101  -0.145764204634  22.
    1.35136  -0.373984554953  22.
    1.35608  -0.124370205399  22.
    1.39103  -0.216813113651  22.
    1.41633  -0.608368865186  22.
    12 rows × 3 columns);
   (23,
    z  res  bin
    1.5429  -0.239796061037  23.
    1.54901  -0.0492647551793  23.
    1.61505  -0.312860625322  23.
    1.69706  -0.341579608953  23.
    1.80119  -0.33004923418  23.
    5 rows × 3 columns)]
val n_groups : int = 24
val bz : (float, Nx.float32_elt) Nx.t = float32 [24]
[0.0109479, 0.0140023, ..., 1.32297, 1.64104]
val bmu : (float, Nx.float32_elt) Nx.t = float32 [24]
[0.00578864, -0.0499531, ..., -0.118857, -0.25471]
val berr : (float, Nx.float32_elt) Nx.t = float32 [24]
[0.043205, 0.0225362, ..., 0.070028, 0.0543295]
```ocaml
let _fig =
  Hugin.layers [
    Hugin.errorbar ~x:bz ~y:bmu ~yerr:(`Symmetric berr)
      ~color:Hugin.Color.black ~cap_size:4.0 ~line_width:1.5 () ;
    Hugin.point ~x:bz ~y:bmu
      ~color:Hugin.Color.black ~size:5.0 ~marker:Hugin.Circle () ;
    Hugin.line ~x:(to32 z_grid) ~y:(to32 res_lcdm)
      ~color:Hugin.Color.vermillion ~line_width:2.5
      ~label:"ΛCDM (0.3, 0.7)" () ;
    Hugin.line ~x:(to32 z_grid) ~y:(to32 res_edsit)
      ~color:Hugin.Color.sky_blue ~line_width:2.0
      ~line_style:`Dashed ~label:"EdS (1, 0)" () ;
    Hugin.line ~x:(to32 z_grid) ~y:(to32 res_open)
      ~color:Hugin.Color.orange ~line_width:2.0
      ~line_style:`Dash_dot ~label:"Open (0.3, 0)" () ;
    Hugin.hline ~y:0.0 ~line_style:`Dotted ~color:Hugin.Color.gray () ;
  ]
  |> Hugin.xscale `Log
  |> Hugin.xlim 0.01 2.5
  |> Hugin.xlabel "Redshift z"
  |> Hugin.ylabel "Δμ (mag, relative to empty universe)"
  |> Hugin.title "Hubble Residuals: The Acceleration Signal"
  |> Hugin.legend ~loc:Hugin.Upper_left
  |> Hugin.grid_lines true
```
val _fig : Hugin.t =
## Confidence contours in the $\Omega_M$--$\Omega_\Lambda$ plane
Following Perlmutter et al. (1999, Fig. 7), we scan a grid of
$(\Omega_M, \Omega_\Lambda)$ values and compute $\chi^2$ at each point. The
confidence contours are drawn at $\Delta\chi^2 = 2.30, 6.17, 11.8$ (68.3%,
95.4%, 99.7% for 2 parameters). We use only the Hubble-flow SNe ($z > 0.01$)
and the diagonal errors (sufficient for this visualization).
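These thresholds follow from the chi-square distribution with 2 degrees of freedom, whose CDF has the closed form $1 - e^{-\Delta\chi^2/2}$. A quick self-contained check in plain OCaml (a sketch, not part of the pipeline):

```ocaml
(* CDF of the chi-square distribution with 2 degrees of freedom. *)
let confidence delta_chi2 = 1.0 -. exp (-. delta_chi2 /. 2.0)

let () =
  assert (abs_float (confidence 2.30 -. 0.683) < 1e-3);
  assert (abs_float (confidence 6.17 -. 0.954) < 1e-3);
  assert (abs_float (confidence 11.8 -. 0.997) < 1e-3)
```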
```ocaml
(* Filter Hubble-flow SNe using Talon *)
let hf =
  Talon.filter_by df Talon.Row.(
    map2 (number "zHD") (number "MU_SH0ES_ERR_DIAG")
      ~f:(fun z err -> z > 0.01 && err > 0.0 && err < 10.0))

let hf_col name =
  Talon.get_column_exn hf name |> Talon.Col.to_tensor f64 |> Option.get

let hf_z = hf_col "zHD"
let hf_mu = hf_col "MU_SH0ES"
let hf_w = Nx.recip (Nx.square (hf_col "MU_SH0ES_ERR_DIAG"))
let n_hf = (Nx.shape hf_z).(0)

let () = Printf.printf "Using %d Hubble-flow SNe for chi-squared grid\n" n_hf

(* Chi-squared for a given (omega_m, omega_l) with M marginalized analytically.
     chi2 = sum w_i (mu_i - mu_th(z_i) - M)^2
   Minimizing over M:  M* = sum(w_i * (mu_i - mu_th_i)) / sum(w_i)
     chi2_min = sum(w_i * d_i^2) - (sum(w_i * d_i))^2 / sum(w_i) *)
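(* Toy check of the analytic marginalization over M (an added sketch): with
   equal weights and offsets d = [0.3; -0.1], M* = 0.1, the residuals are
   +/-0.2, and chi2_min = 2 * 0.2^2 = 0.08. *)
let () =
  let w = [| 1.0; 1.0 |] and d = [| 0.3; -0.1 |] in
  let sw = w.(0) +. w.(1) in
  let swd = w.(0) *. d.(0) +. w.(1) *. d.(1) in
  let swdd = w.(0) *. d.(0) *. d.(0) +. w.(1) *. d.(1) *. d.(1) in
  assert (abs_float ((swdd -. swd *. swd /. sw) -. 0.08) < 1e-12)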
let hf_z_arr = Array.init n_hf (fun i -> Nx.item [i] hf_z)
let hf_mu_arr = Array.init n_hf (fun i -> Nx.item [i] hf_mu)
let hf_w_arr = Array.init n_hf (fun i -> Nx.item [i] hf_w)
let sum_w = Array.fold_left ( +. ) 0.0 hf_w_arr

(* Pure-float distance modulus via 16-point Gauss-Legendre quadrature.
   Avoids all tensor allocation in the chi2 hot loop. *)
let gl_n = [| -0.9894009349916499; -0.9445750230732326; -0.8656312023878318;
              -0.7554044083550030; -0.6178762444026438; -0.4580167776572274;
              -0.2816035507792589; -0.0950125098376374; 0.0950125098376374;
              0.2816035507792589; 0.4580167776572274; 0.6178762444026438;
              0.7554044083550030; 0.8656312023878318; 0.9445750230732326;
              0.9894009349916499 |]
let gl_wt = [| 0.0271524594117541; 0.0622535239386479; 0.0951585116824928;
               0.1246289712555339; 0.1495959888165767; 0.1691565193950025;
               0.1826034150449236; 0.1894506104550685; 0.1894506104550685;
               0.1826034150449236; 0.1691565193950025; 0.1495959888165767;
               0.1246289712555339; 0.0951585116824928; 0.0622535239386479;
               0.0271524594117541 |]

let dist_mod_f omega_m omega_l z =
  let c_over_h0 = 299792.458 /. 70.0 in  (* c/H0 in Mpc, with H0 = 70 km/s/Mpc *)
  let omega_k = 1.0 -. omega_m -. omega_l in
  let half_z = z *. 0.5 in
  let integral = ref 0.0 in
  for k = 0 to 15 do
    let zp = half_z *. gl_n.(k) +. half_z in
    let opz = 1.0 +. zp in
    let ez = Float.sqrt (omega_m *. opz *. opz *. opz
                         +. omega_k *. opz *. opz +. omega_l) in
    integral := !integral +. gl_wt.(k) /. ez
  done;
  let chi = c_over_h0 *. half_z *. !integral in
  let dl = (1.0 +. z) *. chi in
  5.0 /. Float.log 10.0 *. Float.log dl +. 25.0

let chi2_at omega_m omega_l =
  let sum_wd = ref 0.0 in
  let sum_wdd = ref 0.0 in
  let ok = ref true in
  for i = 0 to n_hf - 1 do
    let mu_th_i = dist_mod_f omega_m omega_l hf_z_arr.(i) in
    if Float.is_nan mu_th_i then ok := false
    else begin
      let d = hf_mu_arr.(i) -. mu_th_i in
      let w = hf_w_arr.(i) in
      sum_wd := !sum_wd +. w *. d;
      sum_wdd := !sum_wdd +. w *. d *. d
    end
  done;
  if not !ok then infinity
  else !sum_wdd -. (!sum_wd *. !sum_wd /. sum_w)

(* Scan the grid -- axis range matches Perlmutter 1999 Figure 7 *)
let n_om = 100
let n_ol = 100
let om_min = 0.0 and om_max = 3.0
let ol_min = -1.0 and ol_max = 3.0
let () = Printf.printf "Computing chi-squared on %dx%d grid...\n%!" n_om n_ol
let chi2_grid =
Nx.init f64 [| n_ol; n_om |] (fun idx ->
let j = idx.(0) and i = idx.(1) in
let omega_m = om_min +. (Float.of_int i +. 0.5) *. (om_max -. om_min) /. Float.of_int n_om in
let omega_l = ol_min +. (Float.of_int j +. 0.5) *. (ol_max -. ol_min) /. Float.of_int n_ol in
if omega_m < 0.001 then 1e10
else chi2_at omega_m omega_l)
let chi2_min = Nx.item [] (Nx.min chi2_grid)
let delta_chi2 = Nx.sub_s chi2_grid chi2_min
let () =
let flat_idx = Int32.to_int (Nx.item [] (Nx.argmin chi2_grid)) in
let best_i = flat_idx mod n_om in
let best_j = flat_idx / n_om in
let best_om = om_min +. (Float.of_int best_i +. 0.5) *. (om_max -. om_min) /. Float.of_int n_om in
let best_ol = ol_min +. (Float.of_int best_j +. 0.5) *. (ol_max -. ol_min) /. Float.of_int n_ol in
Printf.printf "Best fit: Omega_M = %.2f, Omega_Lambda = %.2f (chi2 = %.1f, dof ~ %d)\n"
best_om best_ol chi2_min (n_hf - 1)
```
Using 1590 Hubble-flow SNe for chi-squared grid
Computing chi-squared on 100x100 grid...
Best fit: Omega_M = 0.23, Omega_Lambda = 0.54 (chi2 = 684.2, dof ~ 1589)
val hf : Talon.t =
    CID          IDSURVEY  zHD      zHDERR   zCMB     zCMBERR  zHEL     zHELERR  m_b_corr  m_b_corr_err_DIAG  …
    2013E        56        0.01016  0.00085  0.01042  8e-05    0.00936  8e-05    13.5264   0.3475             …
    1999ac       57        0.01017  0.00084  0.00979  2e-05    0.00947  2e-05    13.6652   0.364224           …
    1999ac       62        0.01017  0.00084  0.00979  2e-05    0.00947  2e-05    13.7144   0.34081            …
    2009an       51        0.01026  0.00084  0.00921  1e-05    0.00887  1e-05    14.0848   0.305101           …
    2009an       65        0.01026  0.00084  0.00921  1e-05    0.00887  1e-05    13.9723   0.297865           …
    2006bh       5         0.01028  0.00086  0.01042  0.00015  0.01077  0.00015  14.0258   0.246478           …
    2004S        57        0.01042  0.00084  0.0098   2e-05    0.0093   2e-05    13.8706   0.316076           …
    2021hpr      57        0.01044  0.00084  0.00958  2e-05    0.00938  2e-05    13.8339   0.342855           …
    2002dp       63        0.01061  0.00084  0.01049  1e-05    0.01169  1e-05    14.1276   0.307827           …
    2002dp       57        0.01061  0.00084  0.01049  1e-05    0.01169  1e-05    14.0401   0.273239           …
    1997do       62        0.01073  0.00084  0.01048  2e-05    0.01012  2e-05    13.8551   0.363667           …
    1997bq       62        0.01079  0.00084  0.00993  2e-05    0.00973  2e-05    13.8153   0.322889           …
    2008fv_comb  50        0.01079  0.00084  0.00993  2e-05    0.00973  2e-05    13.9279   0.377003           …
    ASASSN-16jf  150       0.01096  0.00084  0.0104   1e-05    0.01144  1e-05    14.3179   0.313698           …
    iPTF13ebh    56        0.01114  0.00085  0.01238  5e-05    0.01317  5e-05    14.4807   0.341421           …
    iPTF13ebh    5         0.01114  0.00085  0.01238  5e-05    0.01317  5e-05    14.3839   0.293983           …
    2010ko       56        0.01122  0.00084  0.01096  2e-05    0.01082  2e-05    14.5817   0.352878           …
    2013ex       51        0.01122  0.00084  0.01096  2e-05    0.01082  2e-05    14.1806   0.282135           …
    2013ex       56        0.01122  0.00084  0.01096  2e-05    0.01082  2e-05    14.2417   0.32405            …
    2009ab       5         0.01155  0.00085  0.01189  8e-05    0.01219  8e-05    14.3303   0.27987            …
    …            …         …        …        …        …        …        …        …         …                  …
    1590 rows × 47 columns
val hf_col : string -> (float, Nx.float64_elt) Nx.t =
val hf_z : (float, Nx.float64_elt) Nx.t = float64 [1590]
[0.01016, 0.01017, ..., 1.91165, 2.26137]
val hf_mu : (float, Nx.float64_elt) Nx.t = float64 [1590]
[32.7794, 32.9182, ..., 45.4233, 46.1828]
val hf_w : (float, Nx.float64_elt) Nx.t = float64 [1590]
[8.23147, 7.49694, ..., 7.77459, 12.6367]
val n_hf : int = 1590
val hf_z_arr : float array =
  [|0.01016; 0.01017; 0.01017; 0.01026; 0.01026; 0.01028; 0.01042; 0.01044;
    0.01061; 0.01061; 0.01073; 0.01079; 0.01079; 0.01096; 0.01114; 0.01114; ...|]
val hf_mu_arr : float array =
  [|32.7794; 32.9182; 32.9674; 33.3378; 33.2253; 33.2788; 33.1236; 33.0869;
    33.3806; 33.2931; 33.1081; 33.0683; 33.1809; 33.5709; 33.7337; 33.6369; ...|]
val hf_w_arr : float array =
  [|8.23146814613716593; 7.4969352680999215; 8.55574221326686413;
    10.6591609759993187; 11.1791254824205311; 16.2654007737080022;
    9.93710117748064903; 8.45464229847009463; ...|]
val sum_w : float = 38213.713851628665
val gl_n : float array =
[|-0.989400934991649939; -0.9445750230732326; -0.865631202387831755;
-0.755404408355003; -0.617876244402643771; -0.458016777657227425;
-0.281603550779258915; -0.0950125098376374; 0.0950125098376374;
0.281603550779258915; 0.458016777657227425; 0.617876244402643771;
0.755404408355003; 0.865631202387831755; 0.9445750230732326;
0.989400934991649939|]
val gl_wt : float array =
[|0.0271524594117541; 0.0622535239386479; 0.0951585116824928;
0.124628971255533905; 0.149595988816576708; 0.169156519395002508;
0.182603415044923612; 0.189450610455068502; 0.189450610455068502;
0.182603415044923612; 0.169156519395002508; 0.149595988816576708;
0.124628971255533905; 0.0951585116824928; 0.0622535239386479;
0.0271524594117541|]
val dist_mod_f : float -> float -> float -> float =
val chi2_at : float -> float -> float =
val n_om : int = 100
val n_ol : int = 100
val om_min : float = 0.
val om_max : float = 3.
val ol_min : float = -1.
val ol_max : float = 3.
val chi2_grid : (float, Nx.float64_elt) Nx.t = float64 [100; 100]
[[1522.01, 1543.98, ..., 3820.7, 3844.02],
[1484.88, 1506.67, ..., 3779.71, 3803.06],
...
[inf, inf, ..., 877.663, 867.429],
[inf, inf, ..., 902.83, 890.429]]
val chi2_min : float = 684.197413463096268
val delta_chi2 : (float, Nx.float64_elt) Nx.t = float64 [100; 100]
[[837.808, 859.782, ..., 3136.5, 3159.83],
[800.684, 822.469, ..., 3095.52, 3118.86],
...
[inf, inf, ..., 193.466, 183.232],
[inf, inf, ..., 218.632, 206.232]]
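As a standalone sanity check of the quadrature above, the same 16-point Gauss-Legendre nodes and weights can be exercised against integrals with known closed forms. This sketch uses only the OCaml standard library; the `gl16` helper is hypothetical, not part of any Raven package:

```ocaml
(* 16-point Gauss-Legendre nodes and weights on [-1, 1], as in the text above *)
let gl_n = [| -0.9894009349916499; -0.9445750230732326; -0.8656312023878318;
              -0.7554044083550030; -0.6178762444026438; -0.4580167776572274;
              -0.2816035507792589; -0.0950125098376374;  0.0950125098376374;
               0.2816035507792589;  0.4580167776572274;  0.6178762444026438;
               0.7554044083550030;  0.8656312023878318;  0.9445750230732326;
               0.9894009349916499 |]

let gl_wt = [| 0.0271524594117541; 0.0622535239386479; 0.0951585116824928;
               0.1246289712555339; 0.1495959888165767; 0.1691565193950025;
               0.1826034150449236; 0.1894506104550685; 0.1894506104550685;
               0.1826034150449236; 0.1691565193950025; 0.1495959888165767;
               0.1246289712555339; 0.0951585116824928; 0.0622535239386479;
               0.0271524594117541 |]

(* Integrate f over [a, b] via the affine map x = (b-a)/2 * t + (a+b)/2 *)
let gl16 f a b =
  let half = 0.5 *. (b -. a) and mid = 0.5 *. (a +. b) in
  let s = ref 0.0 in
  for k = 0 to 15 do
    s := !s +. (gl_wt.(k) *. f ((half *. gl_n.(k)) +. mid))
  done;
  half *. !s

let () =
  (* Exact for polynomials of degree <= 31: integral of x^3 over [0,1] is 1/4 *)
  Printf.printf "x^3 on [0,1]: %.15f (exact 0.25)\n"
    (gl16 (fun x -> x *. x *. x) 0.0 1.0);
  (* Smooth non-polynomial: integral of exp over [0,1] is e - 1 *)
  Printf.printf "exp on [0,1]: %.15f (exact %.15f)\n"
    (gl16 exp 0.0 1.0) (Float.exp 1.0 -. 1.0)
```

Mapping the chi-squared hot loop onto a fixed-order rule like this trades adaptivity for zero allocation per evaluation, which is what makes the 100x100 grid scan fast.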
## Contour plot
Reproducing Perlmutter et al. (1999) Figure 7. The contour levels correspond
to 68%, 90%, 95%, and 99% confidence regions for two parameters
($\Delta\chi^2 = 2.30, 4.61, 5.99, 9.21$). The diagonal solid line marks
**flat** universes ($\Omega_M + \Omega_\Lambda = 1$). The nearly horizontal
dashed line separates eternally expanding universes from those that eventually
recollapse. The upper-left gray region has no Big Bang.
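These thresholds follow directly from the $\chi^2$ distribution with two degrees of freedom, whose CDF is closed-form:

$$P(\Delta\chi^2 \le x \mid k=2) = 1 - e^{-x/2} \quad\Longrightarrow\quad x_p = -2\ln(1-p),$$

so $x_{0.90} = -2\ln 0.10 \approx 4.61$, $x_{0.95} \approx 5.99$, $x_{0.99} \approx 9.21$, and $x_{0.683} \approx 2.30$.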
```ocaml
(* Confidence levels for 2 parameters:
68% -> delta_chi2 = 2.30, 90% -> 4.61, 95% -> 5.99, 99% -> 9.21 *)
let confidence_levels = [| 2.30; 4.61; 5.99; 9.21 |]
(* "No Big Bang" boundary: upper-left region where the universe has no
initial singularity. Approximate as OmegaL > 4*OmegaM*(cosh(...))^3
for plotting purposes; simplified to a polygon here. *)
let no_bb_x = Nx.create f32 [| 5 |] [| 0.0; 0.0; 1.0; 2.0; 3.0 |]
let no_bb_y1 = Nx.create f32 [| 5 |] [| 3.0; 1.0; 2.2; 2.8; 3.0 |]
let no_bb_y2 = Nx.full f32 [| 5 |] 3.0
let _fig =
Hugin.layers [
(* "No Big Bang" shaded region *)
Hugin.fill_between ~x:no_bb_x ~y1:no_bb_y1 ~y2:no_bb_y2
~color:(Hugin.Color.with_alpha 0.15 Hugin.Color.gray) () ;
(* Filled confidence contours -- blue/teal like Figure 7 *)
Hugin.contour ~data:(to32 delta_chi2)
~x0:om_min ~x1:om_max ~y0:ol_min ~y1:ol_max
~levels:(`Values confidence_levels)
~filled:true
~cmap:(Hugin.Cmap.of_colors [|
Hugin.Color.with_alpha 0.8 (Hugin.Color.hex "#1a5276");
Hugin.Color.with_alpha 0.6 (Hugin.Color.hex "#2e86c1");
Hugin.Color.with_alpha 0.4 (Hugin.Color.hex "#85c1e9");
Hugin.Color.with_alpha 0.2 (Hugin.Color.hex "#d4e6f1");
Hugin.Color.with_alpha 0.0 Hugin.Color.white;
|]) () ;
(* Contour outlines *)
Hugin.contour ~data:(to32 delta_chi2)
~x0:om_min ~x1:om_max ~y0:ol_min ~y1:ol_max
~levels:(`Values confidence_levels)
~color:(Hugin.Color.hex "#2e86c1") ~line_width:1.0 () ;
(* Flat universe line: OmegaM + OmegaL = 1 *)
Hugin.line
~x:(Nx.create f32 [| 2 |] [| 0.0; 3.0 |])
~y:(Nx.create f32 [| 2 |] [| 1.0; -2.0 |])
~color:Hugin.Color.black ~line_width:1.5
~label:"Flat" () ;
(* No-deceleration line: q0 = 0, i.e. OmegaL = OmegaM / 2 *)
Hugin.line
~x:(Nx.create f32 [| 2 |] [| 0.0; 3.0 |])
~y:(Nx.create f32 [| 2 |] [| 0.0; 1.5 |])
~color:Hugin.Color.gray ~line_style:`Dashed ~line_width:1.0
~label:"Accelerating/decelerating" () ;
(* Lambda = 0 line *)
Hugin.hline ~y:0.0 ~color:Hugin.Color.gray ~line_style:`Dotted
~line_width:0.5 () ;
]
|> Hugin.xlim 0.0 3.0
|> Hugin.ylim (-1.0) 3.0
|> Hugin.xlabel "Omega_M"
|> Hugin.ylabel "Omega_Lambda"
|> Hugin.title "Confidence Contours in the Omega_M - Omega_Lambda Plane"
|> Hugin.legend ~loc:Hugin.Upper_right
```
val confidence_levels : float array = [|2.3; 4.61; 5.99; 9.21|]
val no_bb_x : (float, Nx.float32_elt) Nx.t = float32 [5] [0, 0, ..., 2, 3]
val no_bb_y1 : (float, Nx.float32_elt) Nx.t = float32 [5] [3, 1, ..., 2.8, 3]
val no_bb_y2 : (float, Nx.float32_elt) Nx.t = float32 [5] [3, 3, ..., 3, 3]
val _fig : Hugin.t =
## Best-fit flat $\Lambda$CDM
Restricting to flat universes ($\Omega_M + \Omega_\Lambda = 1$), we find the
best-fit $\Omega_M$ by scanning along the flatness constraint and use Umbra's
`Cosmo.flat_lcdm` to compute the corresponding distances.
```ocaml
let n_flat = 200
let flat_om = Nx.linspace f64 0.01 0.99 n_flat
let flat_chi2 = Nx.init f64 [| n_flat |] (fun i ->
let om = Nx.item [i.(0)] flat_om in
chi2_at om (1.0 -. om))
let best_flat_i = Int32.to_int (Nx.item [] (Nx.argmin flat_chi2))
let omega_m_best = Nx.item [best_flat_i] flat_om
let omega_l_best = 1.0 -. omega_m_best
let () =
Printf.printf "\n=== Flat ΛCDM best fit ===\n";
Printf.printf " Omega_M = %.3f\n" omega_m_best;
Printf.printf " Omega_L = %.3f\n" omega_l_best;
Printf.printf " chi2 = %.1f (dof ~ %d)\n"
(Nx.item [best_flat_i] flat_chi2) (n_hf - 1);
Printf.printf "\nPerlmutter et al. (1999) found Omega_M ~ 0.28, Omega_L ~ 0.72\n";
Printf.printf "Planck 2018 finds Omega_M = 0.315, Omega_L = 0.685\n"
```
=== Flat ΛCDM best fit ===
Omega_M = 0.350
Omega_L = 0.650
chi2 = 684.6 (dof ~ 1589)
Perlmutter et al. (1999) found Omega_M ~ 0.28, Omega_L ~ 0.72
Planck 2018 finds Omega_M = 0.315, Omega_L = 0.685
val n_flat : int = 200
val flat_om : (float, Nx.float64_elt) Nx.t = float64 [200]
[0.01, 0.0149246, ..., 0.985075, 0.99]
val flat_chi2 : (float, Nx.float64_elt) Nx.t = float64 [200]
[1265.72, 1240.76, ..., 1313.88, 1321.44]
val best_flat_i : int = 69
val omega_m_best : float = 0.349798994974874378
val omega_l_best : float = 0.650201005025125678
### $\chi^2$ profile along the flat-universe constraint
```ocaml
let chi2_min_flat = Nx.item [best_flat_i] flat_chi2
let delta_flat = Nx.sub_s flat_chi2 chi2_min_flat
let _fig =
Hugin.layers [
Hugin.line ~x:(to32 flat_om) ~y:(to32 delta_flat)
~color:Hugin.Color.vermillion ~line_width:2.5 () ;
Hugin.hline ~y:1.0 ~line_style:`Dashed ~color:Hugin.Color.gray
~label:"Δχ² = 1 (1σ)" () ;
Hugin.hline ~y:4.0 ~line_style:`Dotted ~color:Hugin.Color.gray
~label:"Δχ² = 4 (2σ)" () ;
Hugin.vline ~x:omega_m_best ~line_style:`Dashed
~color:Hugin.Color.sky_blue
~label:(Printf.sprintf "Best fit: Ω_M = %.3f" omega_m_best) () ;
]
|> Hugin.xlim 0.0 1.0
|> Hugin.ylim 0.0 20.0
|> Hugin.xlabel "Ω_M (flat universe)"
|> Hugin.ylabel "Δχ²"
|> Hugin.title "χ² Profile: Flat ΛCDM"
|> Hugin.legend ~loc:Hugin.Upper_right
|> Hugin.grid_lines true
```
val chi2_min_flat : float = 684.600426847840708
val delta_flat : (float, Nx.float64_elt) Nx.t = float64 [200]
[581.117, 556.156, ..., 629.283, 636.844]
val _fig : Hugin.t =
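The $1\sigma$ interval on $\Omega_M$ can be read off where the profile crosses $\Delta\chi^2 = 1$. A minimal sketch of that crossing search in plain OCaml, with a synthetic parabolic profile standing in for `delta_flat` (the `crossings` helper is hypothetical, not a Hugin or Nx function):

```ocaml
(* Find the x-positions where a sampled curve crosses [level], by linear
   interpolation between adjacent grid points with opposite-signed residuals. *)
let crossings xs ys level =
  let acc = ref [] in
  for i = 0 to Array.length xs - 2 do
    let y0 = ys.(i) -. level and y1 = ys.(i + 1) -. level in
    if y0 *. y1 < 0.0 then begin
      let t = y0 /. (y0 -. y1) in
      acc := (xs.(i) +. (t *. (xs.(i + 1) -. xs.(i)))) :: !acc
    end
  done;
  List.rev !acc

let () =
  (* Synthetic profile: delta_chi2 = ((om - 0.35) / 0.05)^2, so the
     delta_chi2 = 1 crossings sit near om = 0.30 and om = 0.40 *)
  let n = 400 in
  let xs = Array.init n (fun i -> float_of_int i /. float_of_int (n - 1)) in
  let ys = Array.map (fun om -> ((om -. 0.35) /. 0.05) ** 2.0) xs in
  match crossings xs ys 1.0 with
  | [ lo; hi ] -> Printf.printf "1-sigma interval: [%.3f, %.3f]\n" lo hi
  | _ -> print_endline "unexpected number of crossings"
```

On the real `flat_om` / `delta_flat` arrays the same search would bracket the best fit printed above; near the minimum a quadratic profile is a good approximation, which is why $\Delta\chi^2 = 1$ gives the usual Gaussian $1\sigma$.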
## Cosmological implications
With the best-fit flat $\Lambda$CDM parameters, we compute some fundamental
properties of the universe using Umbra's cosmology module.
```ocaml
let p_best = Cosmo.flat_lcdm ~h0:70.0 ~omega_m:omega_m_best
let () =
let z0 = Nx.scalar f64 0.0 in
let z1 = Nx.scalar f64 1.0 in
let z_star = Nx.scalar f64 1089.0 in
Printf.printf "\n=== Universe properties (H₀ = 70 km/s/Mpc, Ω_M = %.3f) ===\n\n" omega_m_best;
Printf.printf " Age of the universe = %.2f Gyr\n"
(Nx.item [] (Unit.Time.in_gyr (Cosmo.age ~p:p_best z0)));
Printf.printf " Lookback time to z=1 = %.2f Gyr\n"
(Nx.item [] (Unit.Time.in_gyr (Cosmo.lookback_time ~p:p_best z1)));
Printf.printf " Comoving distance to z=1 = %.0f Mpc\n"
(Nx.item [] (Unit.Length.in_mpc (Cosmo.comoving_distance ~p:p_best z1)));
Printf.printf " Luminosity distance to z=1 = %.0f Mpc\n"
(Nx.item [] (Unit.Length.in_mpc (Cosmo.luminosity_distance ~p:p_best z1)));
Printf.printf " Ang. diameter distance to z=1 = %.0f Mpc\n"
(Nx.item [] (Unit.Length.in_mpc (Cosmo.angular_diameter_distance ~p:p_best z1)));
Printf.printf " Comoving distance to CMB = %.0f Mpc\n"
(Nx.item [] (Unit.Length.in_mpc (Cosmo.comoving_distance ~p:p_best z_star)))
```
=== Universe properties (H₀ = 70 km/s/Mpc, Ω_M = 0.350) ===
Age of the universe = 12.89 Gyr
Lookback time to z=1 = 7.52 Gyr
Comoving distance to z=1 = 3212 Mpc
Luminosity distance to z=1 = 6423 Mpc
Ang. diameter distance to z=1 = 1606 Mpc
Comoving distance to CMB = 12758 Mpc
val p_best : Umbra.Cosmo.params =
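The printed distances obey the standard consistency relations for a flat universe:

$$D_L = (1+z)\,D_C, \qquad D_A = \frac{D_C}{1+z}, \qquad \frac{D_L}{D_A} = (1+z)^2,$$

so at $z = 1$ the luminosity-to-angular-diameter ratio must be exactly 4, matching $6423 \approx 4 \times 1606$ Mpc above. The age is the integral $t_0 = \int_0^\infty \frac{dz}{(1+z)\,H(z)}$, which for $\Omega_M = 0.35$ comes out below the Planck value of 13.8 Gyr, as expected for a larger matter fraction.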
## Conclusion
We have reproduced the central result of Perlmutter et al. (1999) using the
modern Pantheon+ dataset and Umbra's cosmology module:
1. The **Hubble diagram** shows that distant SNe Ia are fainter than predicted
by decelerating models, confirming cosmic acceleration.
2. **Residuals** relative to an empty universe clearly show the acceleration
signal at $z > 0.2$.
3. **Confidence contours** in the $\Omega_M$--$\Omega_\Lambda$ plane strongly
exclude $\Omega_\Lambda = 0$ and are consistent with a flat universe with
$\Omega_M \approx 0.3$, $\Omega_\Lambda \approx 0.7$.
The analysis required only Umbra's `Cosmo.lcdm`, `Cosmo.flat_lcdm`, and
`Cosmo.distance_modulus` functions -- the entire theoretical framework for
SN Ia cosmology in a few lines of OCaml.
### References
- Perlmutter, S. et al. 1999, ApJ, 517, 565 (arXiv:astro-ph/9812133)
- Riess, A.G. et al. 1998, AJ, 116, 1009 (arXiv:astro-ph/9805201)
- Scolnic, D.M. et al. 2022, ApJ, 938, 113 (arXiv:2112.03863)
- Brout, D. et al. 2022, ApJ, 938, 110 (arXiv:2202.04077)
================================================
FILE: dev/umbra/test/dune
================================================
(test
(name test_umbra)
(libraries umbra umbra.fits nx nx.io talon windtrap))
================================================
FILE: dev/umbra/test/test_umbra.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
open Windtrap
open Umbra
let eps = 1e-6
let f64 = Nx.float64
let v x = Nx.item [] x
(* Unit tests *)
let test_length_conversion () =
let d = Unit.Length.kpc 10.0 in
let mpc = v (Unit.Length.in_mpc d) in
is_true ~msg:"10 kpc = 0.01 Mpc" (Float.abs (mpc -. 0.01) < eps);
let m = v (Unit.Length.in_m d) in
let back = v (Unit.Length.in_kpc (Unit.Length.m m)) in
is_true ~msg:"kpc -> m -> kpc roundtrip" (Float.abs (back -. 10.0) < eps)
let test_length_arithmetic () =
let open Unit in
let d = Length.kpc 10.0 + Length.pc 500.0 in
let kpc = v (Length.in_kpc d) in
is_true ~msg:"10 kpc + 500 pc = 10.5 kpc" (Float.abs (kpc -. 10.5) < eps)
let test_mass_conversion () =
let m = Unit.Mass.solar_mass 1.0 in
let kg = v (Unit.Mass.in_kg m) in
is_true ~msg:"1 Msun ~ 1.988e30 kg"
(Float.abs (kg -. 1.9884e30) /. 1.9884e30 < 1e-4)
let test_velocity_cross_dim () =
let d = Unit.Length.km 100.0 in
let t = Unit.Time.s 10.0 in
let vel = Unit.length_per_time d t in
let km_s = v (Unit.Velocity.in_km_s vel) in
is_true ~msg:"100 km / 10 s = 10 km/s" (Float.abs (km_s -. 10.0) < eps)
let test_angle_trig () =
let a = Unit.Angle.deg 90.0 in
is_true ~msg:"sin(90°) = 1"
(Float.abs (Nx.item [] (Unit.Angle.sin a) -. 1.0) < eps);
is_true ~msg:"cos(90°) = 0" (Float.abs (Nx.item [] (Unit.Angle.cos a)) < eps)
let test_wavelength_frequency () =
let lam = Unit.Length.nm 500.0 in
let nu = Unit.wavelength_to_frequency lam in
let lam2 = Unit.frequency_to_wavelength nu in
let nm2 = v (Unit.Length.in_nm lam2) in
is_true ~msg:"wavelength -> freq -> wavelength roundtrip"
(Float.abs (nm2 -. 500.0) < eps)
let test_phantom_type_safety () =
(* This is a compile-time test: the following should NOT typecheck: let _ =
Unit.(Length.m 1.0 + Mass.kg 1.0) The fact that this module compiles proves
type safety. *)
let _d = Unit.(Length.m 1.0 + Length.km 1.0) in
let _m = Unit.(Mass.kg 1.0 + Mass.g 500.0) in
()
(* Const tests *)
let test_const_c () =
let c_km_s = v (Unit.Velocity.in_km_s Const.c) in
is_true ~msg:"c ~ 299792 km/s" (Float.abs (c_km_s -. 299792.458) < 1.0)
(* Coord tests *)
let deg_eps = 1e-6
let test_coord_roundtrip () =
let ra =
Unit.Angle.of_deg (Nx.create f64 [| 4 |] [| 180.0; 0.0; 90.0; 266.405 |])
in
let dec =
Unit.Angle.of_deg (Nx.create f64 [| 4 |] [| 45.0; -30.0; 0.0; -28.936 |])
in
let c = Coord.of_radec ~ra ~dec in
let gal = Coord.galactic c in
let back = Coord.icrs gal in
let ra' = Unit.Angle.in_deg (Coord.ra back) in
let dec' = Unit.Angle.in_deg (Coord.dec back) in
let ra_orig = Unit.Angle.in_deg ra in
let dec_orig = Unit.Angle.in_deg dec in
for i = 0 to 3 do
is_true
~msg:(Printf.sprintf "RA roundtrip [%d]" i)
(Float.abs (Nx.item [ i ] ra_orig -. Nx.item [ i ] ra') < deg_eps);
is_true
~msg:(Printf.sprintf "Dec roundtrip [%d]" i)
(Float.abs (Nx.item [ i ] dec_orig -. Nx.item [ i ] dec') < deg_eps)
done
let test_separation_poles () =
let c1 =
Coord.of_radec
~ra:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 0.0 |]))
~dec:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 90.0 |]))
in
let c2 =
Coord.of_radec
~ra:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 0.0 |]))
~dec:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| -90.0 |]))
in
let sep = Coord.separation c1 c2 in
is_true ~msg:"Pole separation = 180"
(Float.abs (Nx.item [ 0 ] (Unit.Angle.in_deg sep) -. 180.0) < deg_eps)
(* Cosmo tests *)
let test_cosmo_distances () =
let z = Nx.scalar f64 0.1 in
let dc = v (Unit.Length.in_mpc (Cosmo.comoving_distance z)) in
is_true
~msg:(Printf.sprintf "comoving(0.1) ~ 421 Mpc, got %.1f" dc)
(Float.abs (dc -. 421.0) < 5.0);
let dl = v (Unit.Length.in_mpc (Cosmo.luminosity_distance z)) in
is_true
~msg:(Printf.sprintf "luminosity(0.1) ~ 463 Mpc, got %.1f" dl)
(Float.abs (dl -. 463.0) < 5.0)
let test_cosmo_lookback () =
let z = Nx.scalar f64 1.0 in
let t = v (Unit.Time.in_gyr (Cosmo.lookback_time z)) in
is_true
~msg:(Printf.sprintf "lookback(1.0) ~ 7.7 Gyr, got %.1f" t)
(Float.abs (t -. 7.7) < 0.3)
let test_cosmo_angular_scale () =
let phys = Unit.Length.kpc 1.0 in
let z = Nx.scalar f64 0.022 in
let ang = Cosmo.angular_size ~z phys in
let arcsec = v (Unit.Angle.in_arcsec ang) in
is_true
~msg:(Printf.sprintf "1 kpc at z=0.022 ~ 2.3 arcsec, got %.2f" arcsec)
(Float.abs (arcsec -. 2.3) < 0.2)
(* Cosmo: high-z regression tests. These catch quadrature under-resolution at
large z. *)
let test_cosmo_age_planck18 () =
let p = Cosmo.planck18 in
let t = v (Unit.Time.in_gyr (Cosmo.age ~p (Nx.scalar f64 0.0))) in
is_true
~msg:(Printf.sprintf "age(Planck18, z=0) ~ 13.8 Gyr, got %.1f" t)
(Float.abs (t -. 13.8) < 0.3)
let test_cosmo_age_at_z1 () =
let p = Cosmo.planck18 in
let age_0 = v (Unit.Time.in_gyr (Cosmo.age ~p (Nx.scalar f64 0.0))) in
let age_1 = v (Unit.Time.in_gyr (Cosmo.age ~p (Nx.scalar f64 1.0))) in
let lb_1 =
v (Unit.Time.in_gyr (Cosmo.lookback_time ~p (Nx.scalar f64 1.0)))
in
is_true
~msg:
(Printf.sprintf
"age(z=0) - age(z=1) = lookback(z=1): %.2f - %.2f = %.2f vs %.2f" age_0
age_1 (age_0 -. age_1) lb_1)
(Float.abs (age_0 -. age_1 -. lb_1) < 0.05)
let test_cosmo_comoving_cmb () =
let p = Cosmo.planck18 in
let dc =
v (Unit.Length.in_mpc (Cosmo.comoving_distance ~p (Nx.scalar f64 1089.0)))
in
is_true
~msg:(Printf.sprintf "comoving(z=1089) ~ 14000 Mpc, got %.0f" dc)
(Float.abs (dc -. 14000.0) < 500.0)
let test_cosmo_comoving_high_z () =
let p = Cosmo.planck18 in
let dc_2 =
v (Unit.Length.in_mpc (Cosmo.comoving_distance ~p (Nx.scalar f64 2.0)))
in
let dc_5 =
v (Unit.Length.in_mpc (Cosmo.comoving_distance ~p (Nx.scalar f64 5.0)))
in
let dc_10 =
v (Unit.Length.in_mpc (Cosmo.comoving_distance ~p (Nx.scalar f64 10.0)))
in
is_true ~msg:"comoving distances monotonically increase"
(dc_2 < dc_5 && dc_5 < dc_10);
is_true
~msg:(Printf.sprintf "comoving(z=10) ~ 9700 Mpc, got %.0f" dc_10)
(Float.abs (dc_10 -. 9700.0) < 300.0)
let test_cosmo_lookback_high_z () =
let p = Cosmo.planck18 in
let lb_5 =
v (Unit.Time.in_gyr (Cosmo.lookback_time ~p (Nx.scalar f64 5.0)))
in
is_true
~msg:(Printf.sprintf "lookback(z=5) ~ 12.5 Gyr, got %.1f" lb_5)
(Float.abs (lb_5 -. 12.5) < 0.3)
(* FITS tests *)
let test_fits_image_roundtrip () =
let path = "_test_image.fits" in
Fun.protect
~finally:(fun () -> if Sys.file_exists path then Sys.remove path)
(fun () ->
let data =
Nx.create Nx.float32 [| 2; 3 |] [| 1.0; 2.0; 3.0; 4.0; 5.0; 6.0 |]
in
Umbra_fits.write_image path data;
let packed = Umbra_fits.read_image ~hdu:0 path in
let result = Nx_io.to_typed Nx.float32 packed in
is_true ~msg:"Image shape" (Nx.shape result = [| 2; 3 |]);
for i = 0 to 5 do
let row = i / 3 and col = i mod 3 in
is_true
~msg:(Printf.sprintf "Image value [%d,%d]" row col)
(Float.abs (Nx.item [ row; col ] data -. Nx.item [ row; col ] result)
< 1e-6)
done)
let test_fits_table_roundtrip () =
let path = "_test_table.fits" in
Fun.protect
~finally:(fun () -> if Sys.file_exists path then Sys.remove path)
(fun () ->
let df =
Talon.create
[
("ra", Talon.Col.float64 [| 10.0; 20.0; 30.0 |]);
("dec", Talon.Col.float64 [| -10.0; 0.0; 10.0 |]);
]
in
Umbra_fits.write_table path df;
let df2 = Umbra_fits.read_table ~hdu:1 path in
is_true ~msg:"Table rows" (Talon.num_rows df2 = 3);
match Talon.to_array Nx.float64 df2 "ra" with
| Some arr -> is_true ~msg:"ra[0]" (Float.abs (arr.(0) -. 10.0) < 1e-10)
| None -> fail "ra column missing")
(* Coord cross-matching tests *)
let test_match_nearest_self () =
let ra = Unit.Angle.of_deg (Nx.create f64 [| 3 |] [| 10.0; 20.0; 30.0 |]) in
let dec = Unit.Angle.of_deg (Nx.create f64 [| 3 |] [| -10.0; 0.0; 10.0 |]) in
let c = Coord.of_radec ~ra ~dec in
let { Coord.indices; separations } = Coord.nearest c c in
for i = 0 to 2 do
is_true
~msg:(Printf.sprintf "Self-match index[%d]" i)
(Int32.to_int (Nx.item [ i ] indices) = i);
is_true
~msg:(Printf.sprintf "Self-match separation[%d]" i)
(Nx.item [ i ] (Unit.Angle.in_rad separations) < 1e-10)
done
(* Time tests *)
let test_time_jd_mjd () =
let t = Time.unsafe_of_jd 2451545.0 in
is_true ~msg:"J2000.0 JD" (Float.abs (Time.to_jd t -. 2451545.0) < 1e-10);
is_true ~msg:"J2000.0 MJD" (Float.abs (Time.to_mjd t -. 51544.5) < 1e-10);
let t2 = Time.unsafe_of_mjd 51544.5 in
is_true ~msg:"MJD roundtrip" (Float.abs (Time.to_jd t2 -. 2451545.0) < 1e-10)
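(* Standalone sanity check (pure floats, no library calls): the MJD epoch is
   defined as JD 2400000.5, so MJD = JD - 2400000.5 by construction. *)
let test_time_mjd_offset_identity () =
  let mjd = 2451545.0 -. 2400000.5 in
  is_true ~msg:"MJD = JD - 2400000.5 at J2000.0"
    (Float.abs (mjd -. 51544.5) < 1e-9)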
let test_time_iso () =
let t = Time.of_iso "2000-01-01T12:00:00" in
is_true ~msg:"J2000.0 from ISO" (Float.abs (Time.to_jd t -. 2451545.0) < 1e-6);
let s = Time.to_iso t in
is_true ~msg:"ISO roundtrip" (s = "2000-01-01T12:00:00Z")
let test_time_utc_tai_tt () =
let utc = Time.unsafe_of_jd 2451545.0 in
let tai = Time.utc_to_tai utc in
let dt_s = (Time.to_jd tai -. Time.to_jd utc) *. 86400.0 in
is_true
~msg:(Printf.sprintf "TAI-UTC at J2000 = 32s, got %.1f" dt_s)
(Float.abs (dt_s -. 32.0) < 0.1);
let tt = Time.tai_to_tt tai in
let dt_tt = (Time.to_jd tt -. Time.to_jd tai) *. 86400.0 in
is_true
~msg:(Printf.sprintf "TT-TAI = 32.184s, got %.6f" dt_tt)
(Float.abs (dt_tt -. 32.184) < 1e-3);
let tai' = Time.tt_to_tai tt in
is_true ~msg:"TT->TAI roundtrip"
(Float.abs (Time.to_jd tai' -. Time.to_jd tai) < 1e-12);
let utc' = Time.tai_to_utc tai in
is_true ~msg:"TAI->UTC roundtrip"
(Float.abs (Time.to_jd utc' -. Time.to_jd utc) < 1e-10)
let test_time_tdb () =
let tt = Time.unsafe_of_jd 2451545.0 in
let tdb = Time.tt_to_tdb tt in
let dt_ms = (Time.to_jd tdb -. Time.to_jd tt) *. 86400.0 *. 1000.0 in
is_true
~msg:(Printf.sprintf "TDB-TT < 2ms, got %.3f ms" dt_ms)
(Float.abs dt_ms < 2.0);
let tt' = Time.tdb_to_tt tdb in
is_true ~msg:"TDB->TT roundtrip"
(Float.abs (Time.to_jd tt' -. Time.to_jd tt) < 1e-10)
let test_time_unix () =
let t = Time.of_unix 0.0 in
is_true ~msg:"Unix epoch JD" (Float.abs (Time.to_jd t -. 2440587.5) < 1e-10);
let u = Time.to_unix t in
is_true ~msg:"Unix roundtrip" (Float.abs u < 1e-6)
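(* Standalone sanity check (pure floats): the Unix epoch 1970-01-01T00:00:00
   is JD 2440587.5 and Unix time counts seconds, so jd = unix /. 86400.0
   +. 2440587.5 (ignoring leap seconds). *)
let test_time_unix_jd_formula () =
  let jd_of_unix u = (u /. 86400.0) +. 2440587.5 in
  is_true ~msg:"unix 0 -> JD 2440587.5"
    (Float.abs (jd_of_unix 0.0 -. 2440587.5) < 1e-9);
  is_true ~msg:"unix 86400 -> JD 2440588.5"
    (Float.abs (jd_of_unix 86400.0 -. 2440588.5) < 1e-9)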
let test_time_diff_add () =
let t1 = Time.unsafe_of_jd 2451545.0 in
let t2 = Time.unsafe_of_jd 2451546.0 in
let dt = Time.diff t2 t1 in
is_true ~msg:"diff = 1 day"
(Float.abs (v (Unit.Time.in_day dt) -. 1.0) < 1e-10);
let t3 = Time.add t1 (Unit.Time.day 1.0) in
is_true ~msg:"add 1 day" (Float.abs (Time.to_jd t3 -. 2451546.0) < 1e-10)
(* Cosmo preset tests *)
let test_cosmo_planck18 () =
let z = Nx.scalar f64 0.5 in
let dc =
v (Unit.Length.in_mpc (Cosmo.comoving_distance ~p:Cosmo.planck18 z))
in
is_true
~msg:(Printf.sprintf "Planck18 comoving(0.5) ~ 1960 Mpc, got %.0f" dc)
(Float.abs (dc -. 1960.0) < 30.0)
let test_cosmo_hubble () =
let z = Nx.scalar f64 0.0 in
let h0 = Nx.item [] (Cosmo.hubble z) in
is_true
~msg:(Printf.sprintf "H(0) = H0 = 70, got %.1f" h0)
(Float.abs (h0 -. 70.0) < 1e-6)
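(* Standalone sanity check (pure floats) of the flat-LCDM Friedmann factor
   the distance tests rely on: E(z) = sqrt (Omega_m (1+z)^3 + Omega_l), with
   H(z) = H0 E(z). When Omega_m +. Omega_l = 1, E(0) = 1 exactly. *)
let test_cosmo_e_formula_sanity () =
  let e_of ~omega_m z =
    Float.sqrt ((omega_m *. ((1.0 +. z) ** 3.0)) +. (1.0 -. omega_m))
  in
  is_true ~msg:"E(0) = 1 for flat LCDM"
    (Float.abs (e_of ~omega_m:0.3 0.0 -. 1.0) < 1e-12);
  is_true ~msg:"E(z) grows with z" (e_of ~omega_m:0.3 1.0 > 1.0)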
(* Coord FK5/Supergalactic tests *)
let test_coord_ecliptic_roundtrip () =
let ra = Unit.Angle.of_deg (Nx.create f64 [| 2 |] [| 180.0; 45.0 |]) in
let dec = Unit.Angle.of_deg (Nx.create f64 [| 2 |] [| 45.0; -30.0 |]) in
let c = Coord.of_radec ~ra ~dec in
let ecl = Coord.ecliptic_j2000 c in
let back = Coord.icrs ecl in
let ra' = Unit.Angle.in_deg (Coord.ra back) in
let dec' = Unit.Angle.in_deg (Coord.dec back) in
let ra_orig = Unit.Angle.in_deg ra in
let dec_orig = Unit.Angle.in_deg dec in
for i = 0 to 1 do
is_true
~msg:(Printf.sprintf "Ecliptic RA roundtrip [%d]" i)
(Float.abs (Nx.item [ i ] ra_orig -. Nx.item [ i ] ra') < deg_eps);
is_true
~msg:(Printf.sprintf "Ecliptic Dec roundtrip [%d]" i)
(Float.abs (Nx.item [ i ] dec_orig -. Nx.item [ i ] dec') < deg_eps)
done
let test_coord_supergalactic_roundtrip () =
let ra = Unit.Angle.of_deg (Nx.create f64 [| 2 |] [| 180.0; 45.0 |]) in
let dec = Unit.Angle.of_deg (Nx.create f64 [| 2 |] [| 45.0; -30.0 |]) in
let c = Coord.of_radec ~ra ~dec in
let sg = Coord.supergalactic c in
let back = Coord.icrs sg in
let ra' = Unit.Angle.in_deg (Coord.ra back) in
let dec' = Unit.Angle.in_deg (Coord.dec back) in
let ra_orig = Unit.Angle.in_deg ra in
let dec_orig = Unit.Angle.in_deg dec in
for i = 0 to 1 do
is_true
~msg:(Printf.sprintf "Supergalactic RA roundtrip [%d]" i)
(Float.abs (Nx.item [ i ] ra_orig -. Nx.item [ i ] ra') < 1e-4);
is_true
~msg:(Printf.sprintf "Supergalactic Dec roundtrip [%d]" i)
(Float.abs (Nx.item [ i ] dec_orig -. Nx.item [ i ] dec') < 1e-4)
done
(* Unit energy-wavelength-frequency tests *)
let test_energy_wavelength_frequency () =
let e = Unit.Energy.ev 2.0 in
let nu = Unit.energy_to_frequency e in
let e2 = Unit.frequency_to_energy nu in
is_true ~msg:"energy->freq->energy roundtrip"
(Float.abs (v (Unit.Energy.in_ev e2) -. 2.0) < 1e-6);
let lam = Unit.energy_to_wavelength e in
let nu2 = Unit.wavelength_to_frequency lam in
let e3 = Unit.frequency_to_energy nu2 in
is_true ~msg:"energy->wavelength->freq->energy roundtrip"
(Float.abs (v (Unit.Energy.in_ev e3) -. 2.0) < 1e-6)
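(* Standalone sanity check (pure floats, CODATA constants): E = h nu and
   lambda = c / nu, so a 2 eV photon sits near 620 nm — the scale the
   roundtrips above work at. *)
let test_energy_conversion_sanity () =
  let h = 6.62607015e-34 and c = 299_792_458.0 and ev = 1.602176634e-19 in
  let nu = 2.0 *. ev /. h in
  let lam_nm = c /. nu *. 1e9 in
  is_true
    ~msg:(Printf.sprintf "2 eV ~ 620 nm, got %.1f" lam_nm)
    (Float.abs (lam_nm -. 620.0) < 1.0)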
(* Spectrum tests *)
let test_spectrum_blackbody_wien () =
(* Wien's displacement law: λ_max * T = 2.898e-3 m·K *)
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 5778.0) in
let wave = Unit.Length.of_m (Nx.linspace f64 1e-7 3e-6 1000) in
let spec = Spectrum.blackbody ~temperature:temp ~wavelength:wave in
let vals = Spectrum.values spec in
(* Find index of max value *)
let peak_idx = ref 0 in
let peak_val = ref (Nx.item [ 0 ] vals) in
for i = 1 to 999 do
let v = Nx.item [ i ] vals in
if v > !peak_val then begin
peak_val := v;
peak_idx := i
end
done;
let wave_m = Unit.Length.in_m (Spectrum.wavelength spec) in
let peak_lam = Nx.item [ !peak_idx ] wave_m in
let wien = peak_lam *. 5778.0 in
is_true
~msg:(Printf.sprintf "Wien's law: λ_max*T ~ 2.898e-3, got %.4e" wien)
(Float.abs (wien -. 2.898e-3) /. 2.898e-3 < 0.01)
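(* Standalone sanity check (pure floats): Wien's law gives the peak directly,
   lambda_max = b / T with b = 2.898e-3 m·K, so the solar-temperature
   blackbody above should peak near 501.6 nm. *)
let test_wien_formula_sanity () =
  let lam_max_nm = 2.898e-3 /. 5778.0 *. 1e9 in
  is_true
    ~msg:(Printf.sprintf "lambda_max(5778 K) ~ 501.6 nm, got %.1f" lam_max_nm)
    (Float.abs (lam_max_nm -. 501.6) < 1.0)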
let test_spectrum_redshift () =
let wave = Unit.Length.of_m (Nx.linspace f64 1e-7 1e-6 100) in
let values = Nx.ones f64 [| 100 |] in
let spec =
Spectrum.create ~wavelength:wave ~values |> Spectrum.as_flux_density
in
let z = Nx.scalar f64 1.0 in
let shifted = Spectrum.redshift ~z spec in
(* Wavelengths should double at z=1 *)
let orig_wave = Unit.Length.in_m (Spectrum.wavelength spec) in
let shifted_wave = Unit.Length.in_m (Spectrum.wavelength shifted) in
let ratio = Nx.item [ 50 ] shifted_wave /. Nx.item [ 50 ] orig_wave in
is_true ~msg:"Redshift z=1 doubles wavelength"
(Float.abs (ratio -. 2.0) < 1e-10);
(* Values should halve at z=1 *)
let val_ratio =
Nx.item [ 50 ] (Spectrum.values shifted) /. Nx.item [ 50 ] values
in
is_true ~msg:"Redshift z=1 halves values"
(Float.abs (val_ratio -. 0.5) < 1e-10)
let test_spectrum_scale () =
let wave = Unit.Length.of_m (Nx.linspace f64 1e-7 1e-6 10) in
let values = Nx.ones f64 [| 10 |] in
let spec = Spectrum.create ~wavelength:wave ~values in
let scaled = Spectrum.scale (Nx.scalar f64 3.0) spec in
is_true ~msg:"Scale by 3"
(Float.abs (Nx.item [ 0 ] (Spectrum.values scaled) -. 3.0) < 1e-10)
(* Extinction tests *)
let test_extinction_ccm89_v_band () =
(* At V-band (550nm), A_λ/A_V should be ~1.0 for R_V=3.1 *)
let rv = Nx.scalar f64 3.1 in
let wave_v = Unit.Length.of_m (Nx.create f64 [| 1 |] [| 5.5e-7 |]) in
let alav = Extinction.curve (Extinction.ccm89 ~rv) ~wavelength:wave_v in
let val_v = Nx.item [ 0 ] alav in
is_true
~msg:(Printf.sprintf "CCM89 A_V/A_V ~ 1.0 at 550nm, got %.3f" val_v)
(Float.abs (val_v -. 1.0) < 0.1)
let test_extinction_apply_unredden () =
(* apply then unredden should recover original spectrum *)
let rv = Nx.scalar f64 3.1 in
let law = Extinction.ccm89 ~rv in
let wave = Unit.Length.of_m (Nx.linspace f64 3e-7 1e-6 50) in
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 6000.0) in
let spec = Spectrum.blackbody ~temperature:temp ~wavelength:wave in
let av = Nx.scalar f64 1.0 in
let reddened = Extinction.apply law ~av spec in
let recovered = Extinction.unredden law ~av reddened in
(* Compare values *)
let orig_val = Nx.item [ 25 ] (Spectrum.values spec) in
let rec_val = Nx.item [ 25 ] (Spectrum.values recovered) in
is_true ~msg:"apply + unredden roundtrip"
(Float.abs (rec_val -. orig_val) /. orig_val < 1e-10)
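(* Standalone sanity check (pure floats): extinction by A magnitudes
   attenuates flux by 10^(-0.4 A); multiplying back by 10^(0.4 A) recovers
   the input. This is the algebra behind the apply/unredden roundtrip. *)
let test_extinction_mag_algebra () =
  let f0 = 1.234e-15 and a = 1.0 in
  let reddened = f0 *. (10.0 ** (-0.4 *. a)) in
  let recovered = reddened *. (10.0 ** (0.4 *. a)) in
  is_true ~msg:"10^(-0.4A) then 10^(0.4A) is identity"
    (Float.abs (recovered -. f0) /. f0 < 1e-12)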
let test_extinction_ccm89_monotonic () =
(* Extinction should increase toward blue wavelengths (for optical) *)
let rv = Nx.scalar f64 3.1 in
let wave =
Unit.Length.of_m (Nx.create f64 [| 3 |] [| 4e-7; 5.5e-7; 8e-7 |])
in
let alav = Extinction.curve (Extinction.ccm89 ~rv) ~wavelength:wave in
let blue = Nx.item [ 0 ] alav in
let green = Nx.item [ 1 ] alav in
let red = Nx.item [ 2 ] alav in
is_true ~msg:"CCM89: A_blue > A_green" (blue > green);
is_true ~msg:"CCM89: A_green > A_red" (green > red)
(* Photometry tests *)
let test_photometry_ab_mag_flat () =
(* A flat f_nu spectrum at 3631 Jy should give m_AB = 0 in any band. f_nu =
3631e-26 W/m²/Hz, so f_lambda = f_nu * c / lambda² *)
let n = 100 in
let bp =
Photometry.tophat ~lo:(Unit.Length.m 4e-7) ~hi:(Unit.Length.m 7e-7) ~n
in
let wave_m = Unit.Length.to_tensor (Photometry.wavelength bp) in
let c = 299_792_458.0 in
let ab_zp = 3631.0e-26 in
(* f_lambda = f_nu * c / lambda^2 *)
let f_lambda = Nx.mul_s (Nx.recip (Nx.square wave_m)) (ab_zp *. c) in
let spec =
Spectrum.create ~wavelength:(Photometry.wavelength bp) ~values:f_lambda
|> Spectrum.as_flux_density
in
let mag = Nx.item [] (Photometry.ab_mag bp spec) in
is_true
~msg:(Printf.sprintf "Flat f_nu=3631Jy gives m_AB ~ 0, got %.3f" mag)
(Float.abs mag < 0.05)
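(* Standalone sanity check (pure floats) of the AB zero point used above:
   m_AB = -2.5 log10 (f_nu / 3631 Jy), so f_nu = 3631 Jy gives m_AB = 0 and
   each factor of 10 in flux is 2.5 magnitudes. *)
let test_ab_mag_zero_point () =
  let ab_mag f_nu_jy = -2.5 *. Float.log10 (f_nu_jy /. 3631.0) in
  is_true ~msg:"3631 Jy -> m_AB = 0" (Float.abs (ab_mag 3631.0) < 1e-12);
  is_true ~msg:"363.1 Jy -> m_AB = 2.5"
    (Float.abs (ab_mag 363.1 -. 2.5) < 1e-9)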
let test_photometry_color_same_band () =
(* Color between same band should be 0 *)
let bp =
Photometry.tophat ~lo:(Unit.Length.m 4e-7) ~hi:(Unit.Length.m 5.5e-7) ~n:50
in
let spec =
Spectrum.blackbody
~temperature:(Unit.Temperature.of_kelvin (Nx.scalar f64 5000.0))
~wavelength:(Photometry.wavelength bp)
|> Spectrum.as_flux_density
in
let col = Nx.item [] (Photometry.color bp bp spec) in
is_true ~msg:"Same-band color = 0" (Float.abs col < 1e-10)
let test_photometry_blue_star_color () =
(* A hot star should be brighter (lower mag) in blue than red *)
let n = 100 in
let bp_b =
Photometry.tophat ~lo:(Unit.Length.m 4e-7) ~hi:(Unit.Length.m 5e-7) ~n
in
let bp_r =
Photometry.tophat ~lo:(Unit.Length.m 6e-7) ~hi:(Unit.Length.m 7e-7) ~n
in
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 20000.0) in
let spec_b =
Spectrum.blackbody ~temperature:temp
~wavelength:(Photometry.wavelength bp_b)
|> Spectrum.as_flux_density
in
let spec_r =
Spectrum.blackbody ~temperature:temp
~wavelength:(Photometry.wavelength bp_r)
|> Spectrum.as_flux_density
in
let mag_b = Nx.item [] (Photometry.ab_mag bp_b spec_b) in
let mag_r = Nx.item [] (Photometry.ab_mag bp_r spec_r) in
is_true ~msg:"Hot star: blue mag < red mag (brighter in blue)" (mag_b < mag_r)
(* Cosmo: extended models *)
let test_cosmo_flat_lcdm_same_as_default () =
let p = Cosmo.flat_lcdm ~h0:70.0 ~omega_m:0.3 in
let z = Nx.scalar f64 0.5 in
let dc_default = v (Unit.Length.in_mpc (Cosmo.comoving_distance z)) in
let dc_flat = v (Unit.Length.in_mpc (Cosmo.comoving_distance ~p z)) in
is_true ~msg:"flat_lcdm(70,0.3) = default"
(Float.abs (dc_default -. dc_flat) < 1e-6)
let test_cosmo_nonflat_lcdm () =
(* Open universe: omega_m=0.3, omega_l=0.5 → omega_k=0.2. Result should differ
from flat LCDM. *)
let p_flat = Cosmo.flat_lcdm ~h0:70.0 ~omega_m:0.3 in
let p_open = Cosmo.lcdm ~h0:70.0 ~omega_m:0.3 ~omega_l:0.5 in
let z = Nx.scalar f64 1.0 in
let dl_flat =
v (Unit.Length.in_mpc (Cosmo.luminosity_distance ~p:p_flat z))
in
let dl_open =
v (Unit.Length.in_mpc (Cosmo.luminosity_distance ~p:p_open z))
in
is_true
~msg:
(Printf.sprintf "Non-flat LCDM differs from flat: %.0f vs %.0f" dl_open
dl_flat)
(Float.abs (dl_open -. dl_flat) > 10.0)
let test_cosmo_wcdm () =
(* w0 = -1 should be identical to ΛCDM *)
let p_lcdm = Cosmo.flat_lcdm ~h0:70.0 ~omega_m:0.3 in
let p_wcdm = Cosmo.wcdm ~h0:70.0 ~omega_m:0.3 ~w0:(-1.0) () in
let z = Nx.scalar f64 0.5 in
let dc_lcdm = v (Unit.Length.in_mpc (Cosmo.comoving_distance ~p:p_lcdm z)) in
let dc_wcdm = v (Unit.Length.in_mpc (Cosmo.comoving_distance ~p:p_wcdm z)) in
is_true
~msg:(Printf.sprintf "wCDM(w0=-1) = LCDM: %.1f vs %.1f" dc_wcdm dc_lcdm)
(Float.abs (dc_wcdm -. dc_lcdm) < 1.0)
let test_cosmo_w0wacdm () =
(* w0=-1, wa=0 should reduce to ΛCDM *)
let p_lcdm = Cosmo.flat_lcdm ~h0:70.0 ~omega_m:0.3 in
let p_cpl = Cosmo.w0wacdm ~h0:70.0 ~omega_m:0.3 ~w0:(-1.0) ~wa:0.0 () in
let z = Nx.scalar f64 1.0 in
let dl_lcdm =
v (Unit.Length.in_mpc (Cosmo.luminosity_distance ~p:p_lcdm z))
in
let dl_cpl = v (Unit.Length.in_mpc (Cosmo.luminosity_distance ~p:p_cpl z)) in
is_true
~msg:(Printf.sprintf "w0waCDM(-1,0) = LCDM: %.1f vs %.1f" dl_cpl dl_lcdm)
(Float.abs (dl_cpl -. dl_lcdm) < 1.0)
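(* Standalone sanity check (pure floats) of the CPL dark-energy density
   factor behind w0waCDM:
   rho_de(z)/rho_de0 = (1+z)^(3(1+w0+wa)) * exp (-3 wa z / (1+z)).
   At (w0, wa) = (-1, 0) this is identically 1 (a cosmological constant),
   which is why the distances above reduce to LCDM. *)
let test_cpl_density_factor () =
  let de_factor ~w0 ~wa z =
    ((1.0 +. z) ** (3.0 *. (1.0 +. w0 +. wa)))
    *. Float.exp (-3.0 *. wa *. z /. (1.0 +. z))
  in
  is_true ~msg:"CPL(-1, 0) is constant"
    (Float.abs (de_factor ~w0:(-1.0) ~wa:0.0 1.0 -. 1.0) < 1e-12);
  is_true ~msg:"CPL(-0.9, 0) evolves"
    (Float.abs (de_factor ~w0:(-0.9) ~wa:0.0 1.0 -. 1.0) > 0.1)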
let test_cosmo_e_of () =
(* E(z=0) = 1 for any cosmology *)
let p = Cosmo.planck18 in
let z = Nx.scalar f64 0.0 in
let e0 = v (Cosmo.e_of p z) in
is_true
~msg:(Printf.sprintf "E(z=0) = 1, got %.6f" e0)
(Float.abs (e0 -. 1.0) < 1e-6)
let test_cosmo_z_at_value () =
(* Roundtrip: compute dl at z=0.5, then find z back *)
let p = Cosmo.default in
let z0 = 0.5 in
let dl = Cosmo.luminosity_distance ~p (Nx.scalar f64 z0) in
let z_found =
v
(Cosmo.z_at_value ~p
(fun ~p z -> Unit.Length.to_tensor (Cosmo.luminosity_distance ~p z))
(Unit.Length.to_tensor dl))
in
is_true
~msg:(Printf.sprintf "z_at_value roundtrip: expected 0.5, got %.6f" z_found)
(Float.abs (z_found -. z0) < 1e-6)
(* AltAz tests *)
let test_altaz_zenith () =
(* A star at the observer's zenith should have alt ~ 90° *)
let obs =
Altaz.make_observer ~lat:(Unit.Angle.deg 0.0) ~lon:(Unit.Angle.deg 0.0) ()
in
(* Placing a target exactly at the zenith requires the local sidereal time
   (at J2000.0 the Earth rotation angle is ~280.46°, so from lon=0 a target
   at RA ~ 280.46° would be near transit). Rather than hard-coding that,
   verify the transform by roundtrip: ICRS -> AltAz -> ICRS should recover
   the input RA/Dec. *)
let t = Time.of_iso "2024-01-01T00:00:00" in
let ra = Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 180.0 |]) in
let dec = Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 45.0 |]) in
let c = Coord.of_radec ~ra ~dec in
let hz = Altaz.of_coord ~obstime:t ~observer:obs c in
let back = Altaz.to_coord ~obstime:t ~observer:obs hz in
let ra' = Nx.item [ 0 ] (Unit.Angle.in_deg (Coord.ra back)) in
let dec' = Nx.item [ 0 ] (Unit.Angle.in_deg (Coord.dec back)) in
is_true
~msg:(Printf.sprintf "AltAz RA roundtrip: 180 vs %.4f" ra')
(Float.abs (ra' -. 180.0) < 0.1);
is_true
~msg:(Printf.sprintf "AltAz Dec roundtrip: 45 vs %.4f" dec')
(Float.abs (dec' -. 45.0) < 0.1)
let test_altaz_north_pole () =
(* Polaris (dec ~ 90) should always be near alt = observer lat *)
let obs =
Altaz.make_observer ~lat:(Unit.Angle.deg 45.0) ~lon:(Unit.Angle.deg 0.0) ()
in
let t = Time.of_iso "2024-06-15T12:00:00" in
let ra = Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 37.95 |]) in
let dec = Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 89.264 |]) in
let c = Coord.of_radec ~ra ~dec in
let hz = Altaz.of_coord ~obstime:t ~observer:obs c in
let alt_deg = Nx.item [ 0 ] (Unit.Angle.in_deg (Altaz.alt hz)) in
is_true
~msg:(Printf.sprintf "Polaris alt ~ 45° from lat=45°, got %.1f" alt_deg)
(Float.abs (alt_deg -. 45.0) < 2.0)
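(* Standalone sanity check (pure spherical trig): sin alt = sin lat sin dec
   + cos lat cos dec cos H, so at dec = 90° the hour-angle term vanishes and
   alt = lat at any time — the property the Polaris test above relies on. *)
let test_pole_altitude_identity () =
  let deg = Float.pi /. 180.0 in
  let alt_of ~lat ~dec ~ha =
    Float.asin
      ((Float.sin lat *. Float.sin dec)
      +. (Float.cos lat *. Float.cos dec *. Float.cos ha))
  in
  let lat = 45.0 *. deg in
  let a1 = alt_of ~lat ~dec:(90.0 *. deg) ~ha:0.0 in
  let a2 = alt_of ~lat ~dec:(90.0 *. deg) ~ha:(123.0 *. deg) in
  is_true ~msg:"alt(celestial pole) = lat for any hour angle"
    (Float.abs (a1 -. lat) < 1e-9 && Float.abs (a2 -. lat) < 1e-9)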
(* Galactocentric tests *)
let test_galactocentric_gc_position () =
(* A point at l=0, b=0, d=galcen_distance should map to near (0, 0, z_sun) in
Galactocentric. *)
let c =
Coord.of_galactic
~l:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 0.0 |]))
~b:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 0.0 |]))
in
let gc =
Galactocentric.of_coord
~distance:(Unit.Length.of_kpc (Nx.create f64 [| 1 |] [| 8.122 |]))
c
in
let xv = Nx.item [ 0 ] (Unit.Length.in_kpc (Galactocentric.x gc)) in
let yv = Nx.item [ 0 ] (Unit.Length.in_kpc (Galactocentric.y gc)) in
let zv = Nx.item [ 0 ] (Unit.Length.in_kpc (Galactocentric.z gc)) in
is_true
~msg:(Printf.sprintf "GC x ~ 0 kpc, got %.6f" xv)
(Float.abs xv < 1e-10);
is_true
~msg:(Printf.sprintf "GC y ~ 0 kpc, got %.6f" yv)
(Float.abs yv < 1e-10);
is_true
~msg:(Printf.sprintf "GC z ~ z_sun=0.0208 kpc, got %.4f" zv)
(Float.abs (zv -. 0.0208) < 1e-10)
let test_galactocentric_roundtrip () =
let ra = Unit.Angle.of_deg (Nx.create f64 [| 2 |] [| 180.0; 45.0 |]) in
let dec = Unit.Angle.of_deg (Nx.create f64 [| 2 |] [| 30.0; -15.0 |]) in
let c = Coord.of_radec ~ra ~dec in
let d = Unit.Length.of_kpc (Nx.create f64 [| 2 |] [| 5.0; 12.0 |]) in
let gc = Galactocentric.of_coord ~distance:d c in
let c', d' = Galactocentric.to_coord gc in
let ra' = Unit.Angle.in_deg (Coord.ra c') in
let dec' = Unit.Angle.in_deg (Coord.dec c') in
let d_kpc' = Unit.Length.in_kpc d' in
let ra_orig = Unit.Angle.in_deg ra in
let dec_orig = Unit.Angle.in_deg dec in
let d_orig = Unit.Length.in_kpc d in
for i = 0 to 1 do
is_true
~msg:(Printf.sprintf "Galactocentric RA roundtrip [%d]" i)
(Float.abs (Nx.item [ i ] ra' -. Nx.item [ i ] ra_orig) < 0.01);
is_true
~msg:(Printf.sprintf "Galactocentric Dec roundtrip [%d]" i)
(Float.abs (Nx.item [ i ] dec' -. Nx.item [ i ] dec_orig) < 0.01);
is_true
~msg:(Printf.sprintf "Galactocentric distance roundtrip [%d]" i)
(Float.abs (Nx.item [ i ] d_kpc' -. Nx.item [ i ] d_orig) < 0.01)
done
(* Cosmo: growth and power spectrum *)
let test_cosmo_growth_factor_z0 () =
let g = v (Cosmo.growth_factor ~p:Cosmo.planck18 (Nx.scalar f64 0.0)) in
is_true
~msg:(Printf.sprintf "D(z=0) = 1.0, got %.6f" g)
(Float.abs (g -. 1.0) < 1e-4)
let test_cosmo_growth_factor_z1 () =
let g = v (Cosmo.growth_factor ~p:Cosmo.planck18 (Nx.scalar f64 1.0)) in
is_true
~msg:(Printf.sprintf "D(z=1) ~ 0.61, got %.4f" g)
(Float.abs (g -. 0.61) < 0.02)
let test_cosmo_growth_rate_z0 () =
let f = v (Cosmo.growth_rate ~p:Cosmo.planck18 (Nx.scalar f64 0.0)) in
(* f(z=0) ~ Ω_m^0.55 ~ 0.524 for Planck18 *)
is_true
~msg:(Printf.sprintf "f(z=0) ~ 0.52, got %.4f" f)
(Float.abs (f -. 0.52) < 0.02)
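(* Standalone sanity check (pure floats) of the approximation quoted above:
   f(z) ~ Omega_m(z)^0.55, so at z = 0 with Planck18's Omega_m ~ 0.315 the
   growth rate is ~0.53. *)
let test_growth_rate_approx () =
  let f = 0.315 ** 0.55 in
  is_true
    ~msg:(Printf.sprintf "0.315^0.55 ~ 0.53, got %.4f" f)
    (Float.abs (f -. 0.53) < 0.01)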
let test_cosmo_growth_monotonic () =
let p = Cosmo.planck18 in
let d0 = v (Cosmo.growth_factor ~p (Nx.scalar f64 0.0)) in
let d1 = v (Cosmo.growth_factor ~p (Nx.scalar f64 0.5)) in
let d2 = v (Cosmo.growth_factor ~p (Nx.scalar f64 1.0)) in
is_true ~msg:"D(0) > D(0.5) > D(1)" (d0 > d1 && d1 > d2)
let test_cosmo_linear_power () =
let p = Cosmo.planck18 in
let k = Nx.scalar f64 0.1 in
let pk = v (Cosmo.linear_power ~p k (Nx.scalar f64 0.0)) in
is_true ~msg:(Printf.sprintf "P_lin(k=0.1, z=0) > 0, got %.1f" pk) (pk > 0.0);
(* P(k, z=1) should be less than P(k, z=0) *)
let pk1 = v (Cosmo.linear_power ~p k (Nx.scalar f64 1.0)) in
is_true ~msg:"P_lin(z=1) < P_lin(z=0)" (pk1 < pk)
let test_cosmo_nonlinear_power () =
let p = Cosmo.planck18 in
let k = Nx.scalar f64 1.0 in
let pk_nl = v (Cosmo.nonlinear_power ~p k (Nx.scalar f64 0.0)) in
let pk_lin = v (Cosmo.linear_power ~p k (Nx.scalar f64 0.0)) in
is_true ~msg:(Printf.sprintf "P_nl(k=1) > 0, got %.1f" pk_nl) (pk_nl > 0.0);
(* At k=1 h/Mpc, nonlinear should exceed linear *)
is_true
~msg:(Printf.sprintf "P_nl(k=1) > P_lin(k=1): %.1f > %.1f" pk_nl pk_lin)
(pk_nl > pk_lin)
let test_cosmo_params_accessors () =
let p = Cosmo.planck18 in
let ob = v (Cosmo.omega_b p) in
let ns = v (Cosmo.n_s p) in
let s8 = v (Cosmo.sigma8 p) in
is_true ~msg:"Planck18 omega_b = 0.049" (Float.abs (ob -. 0.049) < 1e-6);
is_true ~msg:"Planck18 n_s = 0.9665" (Float.abs (ns -. 0.9665) < 1e-6);
is_true ~msg:"Planck18 sigma8 = 0.8102" (Float.abs (s8 -. 0.8102) < 1e-6)
(* Survey tests *)
let test_survey_smail_normalized () =
let nz = Survey.smail ~a:2.0 ~b:1.5 ~z0:0.3 () in
let n = 1000 in
let zmax = Survey.nz_zmax nz in
let dz = zmax /. Float.of_int n in
let sum = ref 0.0 in
for i = 0 to n do
let z = Float.of_int i *. dz in
let nz_val = v (Survey.eval_nz nz (Nx.scalar f64 z)) in
let w = if i = 0 || i = n then 0.5 else 1.0 in
sum := !sum +. (w *. nz_val *. dz)
done;
is_true
~msg:(Printf.sprintf "smail integrates to 1.0, got %.6f" !sum)
(Float.abs (!sum -. 1.0) < 1e-3)
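(* Standalone sanity check of the trapezoid pattern used above (pure
   floats): integrate sin over [0, pi] with the same half-weight endpoints;
   the exact value is 2. *)
let test_trapezoid_rule_sanity () =
  let n = 1000 in
  let dx = Float.pi /. Float.of_int n in
  let sum = ref 0.0 in
  for i = 0 to n do
    let w = if i = 0 || i = n then 0.5 else 1.0 in
    sum := !sum +. (w *. Float.sin (Float.of_int i *. dx) *. dx)
  done;
  is_true
    ~msg:(Printf.sprintf "trapezoid sin on [0, pi] ~ 2, got %.6f" !sum)
    (Float.abs (!sum -. 2.0) < 1e-5)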
let test_survey_tabulated () =
let z = Nx.create f64 [| 5 |] [| 0.0; 0.25; 0.5; 0.75; 1.0 |] in
let pz = Nx.create f64 [| 5 |] [| 0.0; 1.0; 2.0; 1.0; 0.0 |] in
let nz = Survey.tabulated ~z ~pz () in
let mid = v (Survey.eval_nz nz (Nx.scalar f64 0.5)) in
is_true ~msg:"tabulated mid > 0" (mid > 0.0);
let out = v (Survey.eval_nz nz (Nx.scalar f64 1.5)) in
is_true ~msg:"tabulated outside = 0" (Float.abs out < eps)
let test_survey_cl_shape () =
let p = Cosmo.planck18 in
let nz1 = Survey.smail ~a:2.0 ~b:1.5 ~z0:0.3 () in
let nz2 = Survey.smail ~a:2.0 ~b:1.5 ~z0:0.7 () in
let wl1 = Survey.weak_lensing ~n_gal:26.0 nz1 in
let wl2 = Survey.weak_lensing ~n_gal:26.0 nz2 in
let ell = Nx.create f64 [| 3 |] [| 100.0; 300.0; 1000.0 |] in
let cls = Survey.angular_cl ~p ~power:Survey.linear ~ell [ wl1; wl2 ] in
let shape = Nx.shape (Survey.Cls.to_tensor cls) in
is_true
~msg:(Printf.sprintf "C_l shape = [3; 3], got [%d; %d]" shape.(0) shape.(1))
(shape.(0) = 3 && shape.(1) = 3);
is_true
~msg:(Printf.sprintf "n_tracers = 2, got %d" (Survey.Cls.n_tracers cls))
(Survey.Cls.n_tracers cls = 2)
let test_survey_cl_positive () =
let p = Cosmo.planck18 in
let nz1 = Survey.smail ~a:2.0 ~b:1.5 ~z0:0.5 () in
let wl = Survey.weak_lensing ~n_gal:26.0 nz1 in
let ell = Nx.create f64 [| 3 |] [| 100.0; 500.0; 1000.0 |] in
let cls = Survey.angular_cl ~p ~power:Survey.linear ~ell [ wl ] in
let cl_auto = Survey.Cls.get cls ~i:0 ~j:0 in
for l = 0 to 2 do
let cl_val = Nx.item [ l ] cl_auto in
is_true ~msg:(Printf.sprintf "C_l[%d] = %.2e > 0" l cl_val) (cl_val > 0.0)
done
let test_survey_noise_wl () =
let sigma_e = 0.26 in
let n_gal = 30.0 in
let nz1 = Survey.smail ~a:2.0 ~b:1.5 ~z0:0.3 () in
let wl = Survey.weak_lensing ~sigma_e ~n_gal nz1 in
let ell = Nx.create f64 [| 3 |] [| 100.0; 500.0; 1000.0 |] in
let cls = Survey.angular_cl ~ell [ wl ] in
let nl = Survey.Cls.noise cls in
let n0 = Nx.item [ 0; 0 ] nl in
let n1 = Nx.item [ 0; 1 ] nl in
let n2 = Nx.item [ 0; 2 ] nl in
is_true ~msg:"WL noise > 0" (n0 > 0.0);
is_true
~msg:(Printf.sprintf "WL noise constant in ℓ: %.2e vs %.2e" n0 n1)
(Float.abs (n0 -. n1) < 1e-20);
is_true ~msg:"WL noise constant in ℓ (2)" (Float.abs (n1 -. n2) < 1e-20)
(* Spectrum: mul/div *)
let test_spectrum_mul () =
let wave = Unit.Length.of_m (Nx.linspace f64 3e-7 1e-6 10) in
let values =
Nx.create f64 [| 10 |]
[| 1.0; 2.0; 3.0; 4.0; 5.0; 6.0; 7.0; 8.0; 9.0; 10.0 |]
in
let a =
Spectrum.create ~wavelength:wave ~values |> Spectrum.as_flux_density
in
let trans =
Nx.create f64 [| 10 |]
[| 0.5; 0.5; 0.5; 0.5; 0.5; 0.5; 0.5; 0.5; 0.5; 0.5 |]
in
let b = Spectrum.create ~wavelength:wave ~values:trans in
let result = Spectrum.mul a b in
is_true ~msg:"mul: 2.0 * 0.5 = 1.0"
(Float.abs (Nx.item [ 1 ] (Spectrum.values result) -. 1.0) < eps);
is_true ~msg:"mul: 10.0 * 0.5 = 5.0"
(Float.abs (Nx.item [ 9 ] (Spectrum.values result) -. 5.0) < eps)
let test_spectrum_div () =
let wave = Unit.Length.of_m (Nx.linspace f64 3e-7 1e-6 10) in
let values =
Nx.create f64 [| 10 |]
[| 1.0; 2.0; 3.0; 4.0; 5.0; 6.0; 7.0; 8.0; 9.0; 10.0 |]
in
let a =
Spectrum.create ~wavelength:wave ~values |> Spectrum.as_flux_density
in
let flat =
Nx.create f64 [| 10 |]
[| 2.0; 2.0; 2.0; 2.0; 2.0; 2.0; 2.0; 2.0; 2.0; 2.0 |]
in
let b = Spectrum.create ~wavelength:wave ~values:flat in
let result = Spectrum.div a b in
is_true ~msg:"div: 4.0 / 2.0 = 2.0"
(Float.abs (Nx.item [ 3 ] (Spectrum.values result) -. 2.0) < eps);
is_true ~msg:"div: 10.0 / 2.0 = 5.0"
(Float.abs (Nx.item [ 9 ] (Spectrum.values result) -. 5.0) < eps)
let test_spectrum_mul_div_roundtrip () =
let wave = Unit.Length.of_m (Nx.linspace f64 3e-7 1e-6 50) in
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 6000.0) in
let spec =
Spectrum.blackbody ~temperature:temp ~wavelength:wave
|> Spectrum.as_flux_density
in
let trans_vals =
Nx.create f64 [| 50 |]
(Array.init 50 (fun i ->
0.5 +. (0.3 *. Float.sin (Float.of_int i *. 0.2))))
in
let trans = Spectrum.create ~wavelength:wave ~values:trans_vals in
let mulled = Spectrum.mul spec trans in
let recovered = Spectrum.div mulled trans in
let orig_val = Nx.item [ 25 ] (Spectrum.values spec) in
let rec_val = Nx.item [ 25 ] (Spectrum.values recovered) in
is_true ~msg:"mul then div roundtrip"
(Float.abs (rec_val -. orig_val) /. orig_val < 1e-10)
(* Spectrum: line profiles *)
let test_spectrum_gaussian_peak () =
let wave = Unit.Length.of_m (Nx.linspace f64 6.4e-7 6.7e-7 1000) in
let center = Unit.Length.nm 656.3 in
let stddev = Unit.Length.nm 1.0 in
let amplitude = Nx.scalar f64 1.0 in
let g = Spectrum.gaussian ~amplitude ~center ~stddev ~wavelength:wave in
let vals = Spectrum.values g in
let peak_idx = ref 0 in
let peak_val = ref (Nx.item [ 0 ] vals) in
for i = 1 to 999 do
let vi = Nx.item [ i ] vals in
if vi > !peak_val then begin
peak_val := vi;
peak_idx := i
end
done;
let wave_m = Unit.Length.in_m (Spectrum.wavelength g) in
let peak_lam_nm = Nx.item [ !peak_idx ] wave_m *. 1e9 in
is_true
~msg:(Printf.sprintf "Gaussian peak near 656.3 nm, got %.1f" peak_lam_nm)
(Float.abs (peak_lam_nm -. 656.3) < 0.5);
is_true
~msg:(Printf.sprintf "Gaussian peak amplitude ~ 1.0, got %.4f" !peak_val)
(Float.abs (!peak_val -. 1.0) < 0.01)
let test_spectrum_lorentzian_peak () =
let wave = Unit.Length.of_m (Nx.linspace f64 4.8e-7 5.2e-7 1000) in
let center = Unit.Length.nm 500.0 in
let fwhm = Unit.Length.nm 2.0 in
let amplitude = Nx.scalar f64 3.0 in
let l = Spectrum.lorentzian ~amplitude ~center ~fwhm ~wavelength:wave in
let vals = Spectrum.values l in
let peak_idx = ref 0 in
let peak_val = ref (Nx.item [ 0 ] vals) in
for i = 1 to 999 do
let vi = Nx.item [ i ] vals in
if vi > !peak_val then begin
peak_val := vi;
peak_idx := i
end
done;
let wave_m = Unit.Length.in_m (Spectrum.wavelength l) in
let peak_lam_nm = Nx.item [ !peak_idx ] wave_m *. 1e9 in
is_true
~msg:(Printf.sprintf "Lorentzian peak near 500 nm, got %.1f" peak_lam_nm)
(Float.abs (peak_lam_nm -. 500.0) < 0.5);
is_true
~msg:(Printf.sprintf "Lorentzian peak ~ 3.0, got %.4f" !peak_val)
(Float.abs (!peak_val -. 3.0) < 0.05)
let test_spectrum_voigt_limits () =
let wave = Unit.Length.of_m (Nx.linspace f64 4.8e-7 5.2e-7 1000) in
let center = Unit.Length.nm 500.0 in
let amplitude = Nx.scalar f64 1.0 in
(* Gaussian limit: sigma >> gamma *)
let sigma_big = Unit.Length.nm 2.0 in
let gamma_tiny = Unit.Length.nm 0.001 in
let voigt_g =
Spectrum.voigt ~amplitude ~center ~sigma:sigma_big ~gamma:gamma_tiny
~wavelength:wave
in
let gauss =
Spectrum.gaussian ~amplitude ~center ~stddev:sigma_big ~wavelength:wave
in
let vg_peak = ref 0.0 in
let g_peak = ref 0.0 in
for i = 0 to 999 do
let vv = Nx.item [ i ] (Spectrum.values voigt_g) in
let gv = Nx.item [ i ] (Spectrum.values gauss) in
if vv > !vg_peak then vg_peak := vv;
if gv > !g_peak then g_peak := gv
done;
is_true
~msg:
(Printf.sprintf "Voigt(sigma>>gamma) peak ~ Gaussian peak: %.4f vs %.4f"
!vg_peak !g_peak)
(Float.abs (!vg_peak -. !g_peak) /. !g_peak < 0.05)
let test_spectrum_line_composability () =
let wave = Unit.Length.of_m (Nx.linspace f64 6e-7 7e-7 500) in
let continuum =
Spectrum.power_law ~amplitude:(Nx.scalar f64 1e-15)
~index:(Nx.scalar f64 (-2.0)) ~pivot:(Unit.Length.nm 650.0)
~wavelength:wave
in
let ha =
Spectrum.gaussian ~amplitude:(Nx.scalar f64 1e-15)
~center:(Unit.Length.nm 656.3) ~stddev:(Unit.Length.nm 0.5)
~wavelength:wave
in
let composite = Spectrum.add continuum ha in
let cont_val = Nx.item [ 0 ] (Spectrum.values continuum) in
let comp_val = Nx.item [ 0 ] (Spectrum.values composite) in
is_true ~msg:"Composite spectrum at wing ~ continuum"
(Float.abs (comp_val -. cont_val) /. cont_val < 0.01)
(* Altaz: airmass *)
let test_altaz_airmass_zenith () =
let hz =
Altaz.of_coord
~obstime:(Time.of_iso "2024-06-21T12:00:00")
~observer:
(Altaz.make_observer ~lat:(Unit.Angle.deg 45.0)
~lon:(Unit.Angle.deg 0.0) ())
(Coord.of_radec
~ra:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 0.0 |]))
~dec:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 89.0 |])))
in
let x = Altaz.airmass hz in
let x0 = Nx.item [ 0 ] x in
is_true ~msg:(Printf.sprintf "Airmass >= 1.0, got %.4f" x0) (x0 >= 1.0)
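(* Standalone sanity check (pure floats) of the plane-parallel airmass
   model: X = sec z = 1 / cos (zenith angle), so X = 1 at the zenith and
   X = 2 at alt = 30°. (Altaz.airmass may refine this near the horizon;
   this covers only the leading-order behaviour.) *)
let test_airmass_sec_z () =
  let deg = Float.pi /. 180.0 in
  let airmass_pp alt = 1.0 /. Float.cos ((90.0 -. alt) *. deg) in
  is_true ~msg:"X(zenith) = 1" (Float.abs (airmass_pp 90.0 -. 1.0) < 1e-12);
  is_true ~msg:"X(alt = 30 deg) = 2" (Float.abs (airmass_pp 30.0 -. 2.0) < 1e-9)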
let test_altaz_airmass_low_alt () =
let obs =
Altaz.make_observer ~lat:(Unit.Angle.deg 30.0) ~lon:(Unit.Angle.deg 0.0) ()
in
let t = Time.of_iso "2024-06-21T22:00:00" in
let star_a =
Coord.of_radec
~ra:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 0.0 |]))
~dec:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 80.0 |]))
in
let star_b =
Coord.of_radec
~ra:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 180.0 |]))
~dec:(Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 10.0 |]))
in
let hz_a = Altaz.of_coord ~obstime:t ~observer:obs star_a in
let hz_b = Altaz.of_coord ~obstime:t ~observer:obs star_b in
let x_a = Nx.item [ 0 ] (Altaz.airmass hz_a) in
let x_b = Nx.item [ 0 ] (Altaz.airmass hz_b) in
is_true
~msg:(Printf.sprintf "Both airmasses >= 1: %.2f, %.2f" x_a x_b)
(x_a >= 1.0 && x_b >= 1.0);
is_true
~msg:(Printf.sprintf "Different airmasses: %.2f vs %.2f" x_a x_b)
(Float.abs (x_a -. x_b) > 0.01)
(* Cosmo: BAO distances *)
let test_cosmo_dh () =
let p = Cosmo.planck18 in
let z = Nx.scalar f64 0.0 in
let dh0 = v (Unit.Length.in_mpc (Cosmo.dh ~p z)) in
let h0 = Nx.item [] (Cosmo.h0 p) in
let expected = 299792.458 /. h0 in
is_true
~msg:(Printf.sprintf "D_H(0) = c/H0 ~ %.1f Mpc, got %.1f" expected dh0)
(Float.abs (dh0 -. expected) /. expected < 1e-4)
let test_cosmo_dm_flat () =
let p = Cosmo.planck18 in
let z = Nx.scalar f64 0.5 in
let dm_val = v (Unit.Length.in_mpc (Cosmo.dm ~p z)) in
let dc_val = v (Unit.Length.in_mpc (Cosmo.comoving_distance ~p z)) in
is_true
~msg:(Printf.sprintf "D_M = D_C for flat: %.1f vs %.1f" dm_val dc_val)
(Float.abs (dm_val -. dc_val) /. dc_val < 1e-4)
let test_cosmo_dv () =
let p = Cosmo.planck18 in
let z = Nx.scalar f64 0.5 in
let dv_val = v (Unit.Length.in_mpc (Cosmo.dv ~p z)) in
is_true ~msg:(Printf.sprintf "D_V(0.5) > 0, got %.1f" dv_val) (dv_val > 0.0);
let dh_val = v (Unit.Length.in_mpc (Cosmo.dh ~p z)) in
let dm_val = v (Unit.Length.in_mpc (Cosmo.dm ~p z)) in
let z_f = 0.5 in
let expected = (z_f *. dh_val *. dm_val *. dm_val) ** (1.0 /. 3.0) in
is_true
~msg:
(Printf.sprintf "D_V = (z D_H D_M^2)^{1/3}: %.1f vs %.1f" dv_val expected)
(Float.abs (dv_val -. expected) /. expected < 1e-3)
let test_cosmo_sound_horizon () =
let p = Cosmo.planck18 in
let rs = v (Unit.Length.in_mpc (Cosmo.sound_horizon ~p ())) in
is_true
~msg:(Printf.sprintf "r_s(Planck18) ~ 147 Mpc, got %.1f" rs)
(Float.abs (rs -. 147.0) < 5.0)
(* Filters *)
let test_filters_sdss_pivot () =
let bp = Filters.sdss_r in
let lam_p = v (Unit.Length.in_nm (Photometry.pivot_wavelength bp)) in
is_true
~msg:(Printf.sprintf "SDSS r pivot ~ 620 nm, got %.0f" lam_p)
(Float.abs (lam_p -. 620.0) < 30.0)
let test_filters_johnson_v_pivot () =
let bp = Filters.johnson_v in
let lam_p = v (Unit.Length.in_nm (Photometry.pivot_wavelength bp)) in
is_true
~msg:(Printf.sprintf "Johnson V pivot ~ 551 nm, got %.0f" lam_p)
(Float.abs (lam_p -. 551.0) < 20.0)
let test_filters_twomass_j_pivot () =
let bp = Filters.twomass_j in
let lam_p = v (Unit.Length.in_nm (Photometry.pivot_wavelength bp)) in
is_true
~msg:(Printf.sprintf "2MASS J pivot ~ 1235 nm, got %.0f" lam_p)
(Float.abs (lam_p -. 1235.0) < 30.0)
let test_filters_gaia_ordering () =
let bp_p =
v (Unit.Length.in_nm (Photometry.pivot_wavelength Filters.gaia_bp))
in
let g_p =
v (Unit.Length.in_nm (Photometry.pivot_wavelength Filters.gaia_g))
in
let rp_p =
v (Unit.Length.in_nm (Photometry.pivot_wavelength Filters.gaia_rp))
in
is_true
~msg:(Printf.sprintf "Gaia: BP < G < RP: %.0f < %.0f < %.0f" bp_p g_p rp_p)
(bp_p < g_p && g_p < rp_p)
let test_filters_photometry () =
let bp = Filters.sdss_g in
let wave = Photometry.wavelength bp in
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 5800.0) in
let sed =
Spectrum.blackbody ~temperature:temp ~wavelength:wave
|> Spectrum.as_flux_density
in
let mag = Nx.item [] (Photometry.ab_mag bp sed) in
is_true
~msg:(Printf.sprintf "BB(5800K) through SDSS g is finite, got %.2f" mag)
(Float.is_finite mag)
(* Photometry: auto-resample *)
let test_photometry_auto_resample () =
let bp = Filters.sdss_g in
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 5800.0) in
let wave_fine = Unit.Length.of_m (Nx.linspace f64 3e-7 1.1e-6 1000) in
let sed =
Spectrum.blackbody ~temperature:temp ~wavelength:wave_fine
|> Spectrum.as_flux_density
in
let mag = Nx.item [] (Photometry.ab_mag bp sed) in
is_true
~msg:
(Printf.sprintf "Auto-resample: BB(5800K) through SDSS g finite, got %.2f"
mag)
(Float.is_finite mag);
let manual = Spectrum.resample ~wavelength:(Photometry.wavelength bp) sed in
let mag_manual = Nx.item [] (Photometry.ab_mag bp manual) in
is_true
~msg:
(Printf.sprintf "Auto-resample matches manual: %.4f vs %.4f" mag
mag_manual)
(Float.abs (mag -. mag_manual) < 1e-10)
(* Photometry: ST magnitude *)
let test_photometry_st_mag () =
let bp =
Photometry.tophat ~lo:(Unit.Length.nm 400.0) ~hi:(Unit.Length.nm 700.0)
~n:100
in
let wave = Photometry.wavelength bp in
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 5800.0) in
let sed =
Spectrum.blackbody ~temperature:temp ~wavelength:wave
|> Spectrum.as_flux_density
in
let st = Nx.item [] (Photometry.st_mag bp sed) in
let ab = Nx.item [] (Photometry.ab_mag bp sed) in
is_true ~msg:(Printf.sprintf "ST mag is finite: %.2f" st) (Float.is_finite st);
is_true
~msg:(Printf.sprintf "ST and AB differ: ST=%.2f AB=%.2f" st ab)
(Float.abs (st -. ab) > 0.01)
(* Photometry: Vega magnitude *)
let test_photometry_vega_mag () =
let bp = Filters.johnson_v in
let wave = Photometry.wavelength bp in
let temp = Unit.Temperature.of_kelvin (Nx.scalar f64 9600.0) in
let sed =
Spectrum.blackbody ~temperature:temp ~wavelength:wave
|> Spectrum.as_flux_density
in
let vm = Nx.item [] (Photometry.vega_mag bp sed) in
is_true
~msg:(Printf.sprintf "Vega mag of hot BB through V is finite: %.2f" vm)
(Float.is_finite vm);
let ab = Nx.item [] (Photometry.ab_mag bp sed) in
is_true
~msg:(Printf.sprintf "Vega and AB differ: V=%.2f AB=%.2f" vm ab)
(Float.abs (vm -. ab) > 0.001)
(* Photometry: effective wavelength *)
let test_photometry_effective_wavelength () =
let bp =
Photometry.tophat ~lo:(Unit.Length.nm 400.0) ~hi:(Unit.Length.nm 700.0)
~n:100
in
let wave = Photometry.wavelength bp in
let flat_vals = Nx.ones f64 [| 100 |] in
let flat =
Spectrum.create ~wavelength:wave ~values:flat_vals
|> Spectrum.as_flux_density
in
let lam_eff =
v (Unit.Length.in_nm (Photometry.effective_wavelength bp flat))
in
let lam_pivot = v (Unit.Length.in_nm (Photometry.pivot_wavelength bp)) in
is_true
~msg:
(Printf.sprintf "Flat spectrum: eff_wavelength in range: %.1f nm" lam_eff)
(lam_eff > 500.0 && lam_eff < 600.0);
is_true
~msg:
(Printf.sprintf "eff_wavelength >= pivot for flat/tophat: %.1f vs %.1f"
lam_eff lam_pivot)
(lam_eff >= lam_pivot)
(* Altaz: atmospheric refraction *)
let test_altaz_refraction () =
let obs =
Altaz.make_observer ~lat:(Unit.Angle.deg 45.0) ~lon:(Unit.Angle.deg 0.0) ()
in
let t = Time.of_iso "2024-06-15T12:00:00" in
let ra = Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 37.95 |]) in
let dec = Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 89.264 |]) in
let c = Coord.of_radec ~ra ~dec in
let hz_no = Altaz.of_coord ~refraction:false ~obstime:t ~observer:obs c in
let hz_yes = Altaz.of_coord ~refraction:true ~obstime:t ~observer:obs c in
let alt_no = Nx.item [ 0 ] (Unit.Angle.in_deg (Altaz.alt hz_no)) in
let alt_yes = Nx.item [ 0 ] (Unit.Angle.in_deg (Altaz.alt hz_yes)) in
(* Refraction makes objects appear higher *)
is_true
~msg:
(Printf.sprintf "Refraction raises altitude: %.4f > %.4f" alt_yes alt_no)
(alt_yes > alt_no);
(* At ~45° alt, refraction is ~1 arcmin = 0.017° *)
let diff = alt_yes -. alt_no in
is_true
~msg:(Printf.sprintf "Refraction at ~45° is small (< 0.1°): %.4f" diff)
(diff > 0.0 && diff < 0.1)
let test_altaz_refraction_standalone () =
let obs =
Altaz.make_observer ~lat:(Unit.Angle.deg 45.0) ~lon:(Unit.Angle.deg 0.0) ()
in
let t = Time.of_iso "2024-06-15T12:00:00" in
let ra = Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 37.95 |]) in
let dec = Unit.Angle.of_deg (Nx.create f64 [| 1 |] [| 89.264 |]) in
let c = Coord.of_radec ~ra ~dec in
let hz = Altaz.of_coord ~obstime:t ~observer:obs c in
let r = Altaz.refraction hz in
let r_arcmin = Nx.item [ 0 ] (Unit.Angle.in_deg r) *. 60.0 in
is_true
~msg:(Printf.sprintf "Refraction > 0 arcmin: %.2f" r_arcmin)
(r_arcmin > 0.0);
is_true
~msg:(Printf.sprintf "Refraction < 2 arcmin at high alt: %.2f" r_arcmin)
(r_arcmin < 2.0)
(* Survey: shear multiplicative bias *)
let test_survey_m_bias () =
let nz = Survey.smail ~a:2.0 ~b:1.5 ~z0:0.5 () in
let ell = Nx.logspace f64 1.0 3.0 20 in
let wl_no_bias = Survey.weak_lensing ~n_gal:26.0 nz in
let wl_with_bias = Survey.weak_lensing ~m_bias:0.02 ~n_gal:26.0 nz in
let cls_no = Survey.angular_cl ~ell [ wl_no_bias ] in
let cls_yes = Survey.angular_cl ~ell [ wl_with_bias ] in
let cl_no = Survey.Cls.get cls_no ~i:0 ~j:0 in
let cl_yes = Survey.Cls.get cls_yes ~i:0 ~j:0 in
(* Auto-spectrum scales as (1+m)^2 = 1.0404 *)
let ratio = Nx.item [ 10 ] (Nx.div cl_yes cl_no) in
let expected = 1.02 *. 1.02 in
is_true
~msg:
(Printf.sprintf
"m_bias=0.02 scales auto-Cl by (1+m)^2: ratio=%.4f vs %.4f" ratio
expected)
(Float.abs (ratio -. expected) < 1e-4);
(* m_bias=0.0 gives same result as no bias *)
let wl_zero_bias = Survey.weak_lensing ~m_bias:0.0 ~n_gal:26.0 nz in
let cls_zero = Survey.angular_cl ~ell [ wl_zero_bias ] in
let cl_zero = Survey.Cls.get cls_zero ~i:0 ~j:0 in
let diff = Nx.item [] (Nx.max (Nx.abs (Nx.sub cl_zero cl_no))) in
is_true
~msg:(Printf.sprintf "m_bias=0.0 matches no bias: max_diff=%.2e" diff)
(diff < 1e-30)
(* Spectrum: differentiable resample *)
let test_spectrum_resample_values () =
let wave = Unit.Length.of_m (Nx.linspace f64 1e-7 1e-5 100) in
let sed =
Spectrum.blackbody
~temperature:(Unit.Temperature.of_kelvin (Nx.scalar f64 5800.0))
~wavelength:wave
in
let new_wave = Unit.Length.of_m (Nx.linspace f64 2e-7 9e-6 50) in
let resampled = Spectrum.resample ~wavelength:new_wave sed in
let vals = Spectrum.values resampled in
let n = Nx.numel vals in
is_true ~msg:(Printf.sprintf "Resampled has %d points" n) (n = 50);
let v0 = Nx.item [ 0 ] vals in
is_true ~msg:(Printf.sprintf "Resampled values positive: %.2e" v0) (v0 > 0.0);
let vmax = Nx.item [] (Nx.max vals) in
is_true
~msg:(Printf.sprintf "Resampled max is finite: %.2e" vmax)
(Float.is_finite vmax)
(* Survey: baryonic feedback *)
let test_survey_baryonic_feedback () =
let nz = Survey.smail ~a:2.0 ~b:1.5 ~z0:0.5 () in
let ell = Nx.logspace f64 1.0 3.0 20 in
let wl = Survey.weak_lensing ~n_gal:26.0 nz in
let cls_dm = Survey.angular_cl ~power:Survey.nonlinear ~ell [ wl ] in
let power_bary = Survey.baryonic_feedback ~a_bary:0.2 Survey.nonlinear in
let cls_bary = Survey.angular_cl ~power:power_bary ~ell [ wl ] in
let cl_dm = Survey.Cls.get cls_dm ~i:0 ~j:0 in
let cl_bary = Survey.Cls.get cls_bary ~i:0 ~j:0 in
(* Baryonic feedback suppresses small-scale (high-ell) power *)
let ratio_high = Nx.item [ 19 ] (Nx.div cl_bary cl_dm) in
is_true
~msg:
(Printf.sprintf "Baryonic suppression at high ell: ratio=%.4f < 1"
ratio_high)
(ratio_high < 1.0);
(* a_bary=0 gives same result as no feedback *)
let power_zero = Survey.baryonic_feedback ~a_bary:0.0 Survey.nonlinear in
let cls_zero = Survey.angular_cl ~power:power_zero ~ell [ wl ] in
let cl_zero = Survey.Cls.get cls_zero ~i:0 ~j:0 in
let diff = Nx.item [] (Nx.max (Nx.abs (Nx.sub cl_zero cl_dm))) in
is_true
~msg:(Printf.sprintf "a_bary=0 matches DM-only: max_diff=%.2e" diff)
(diff < 1e-30)
(* Batched spectra *)
let test_batch_create () =
let wave = Unit.Length.of_m (Nx.linspace f64 1e-7 1e-6 100) in
let v1 = Nx.ones f64 [| 100 |] in
let v2 = Nx.full f64 [| 100 |] 2.0 in
let values = Nx.stack [ v1; v2 ] in
let s =
Spectrum.create ~wavelength:wave ~values |> Spectrum.as_flux_density
in
let sh = Nx.shape (Spectrum.values s) in
is_true ~msg:"batch values shape [2; 100]"
(Array.length sh = 2 && sh.(0) = 2 && sh.(1) = 100)
let test_batch_resample () =
let wave = Unit.Length.of_m (Nx.linspace f64 1e-7 1e-6 100) in
let wave2 = Unit.Length.of_m (Nx.linspace f64 2e-7 8e-7 50) in
let bb1 =
Spectrum.blackbody
~temperature:(Unit.Temperature.of_kelvin (Nx.scalar f64 5000.0))
~wavelength:wave
|> Spectrum.as_sampled
in
let bb2 =
Spectrum.blackbody
~temperature:(Unit.Temperature.of_kelvin (Nx.scalar f64 8000.0))
~wavelength:wave
|> Spectrum.as_sampled
in
let values = Nx.stack [ Spectrum.values bb1; Spectrum.values bb2 ] in
let batch = Spectrum.create ~wavelength:wave ~values |> Spectrum.as_sampled in
let resampled = Spectrum.resample ~wavelength:wave2 batch in
let r_shape = Nx.shape (Spectrum.values resampled) in
is_true ~msg:"batch resample shape [2; 50]"
(Array.length r_shape = 2 && r_shape.(0) = 2 && r_shape.(1) = 50);
let r1 = Spectrum.resample ~wavelength:wave2 bb1 in
let r2 = Spectrum.resample ~wavelength:wave2 bb2 in
let expected = Nx.stack [ Spectrum.values r1; Spectrum.values r2 ] in
let diff =
Nx.item [] (Nx.max (Nx.abs (Nx.sub (Spectrum.values resampled) expected)))
in
is_true
~msg:
(Printf.sprintf "batch resample matches individual: max_diff=%.2e" diff)
(diff < 1e-20)
let test_batch_ab_mag () =
let wave = Unit.Length.of_m (Nx.linspace f64 3e-7 8e-7 200) in
let bp =
Photometry.tophat ~lo:(Unit.Length.nm 400.0) ~hi:(Unit.Length.nm 600.0)
~n:100
in
let bb1 =
Spectrum.blackbody
~temperature:(Unit.Temperature.of_kelvin (Nx.scalar f64 5000.0))
~wavelength:wave
|> Spectrum.as_flux_density
in
let bb2 =
Spectrum.blackbody
~temperature:(Unit.Temperature.of_kelvin (Nx.scalar f64 8000.0))
~wavelength:wave
|> Spectrum.as_flux_density
in
let values = Nx.stack [ Spectrum.values bb1; Spectrum.values bb2 ] in
let batch =
Spectrum.create ~wavelength:wave ~values |> Spectrum.as_flux_density
in
let mags_batch = Photometry.ab_mag bp batch in
let mag1 = Photometry.ab_mag bp bb1 in
let mag2 = Photometry.ab_mag bp bb2 in
let expected = Nx.stack [ mag1; mag2 ] in
let diff = Nx.item [] (Nx.max (Nx.abs (Nx.sub mags_batch expected))) in
is_true
~msg:(Printf.sprintf "batch ab_mag matches individual: max_diff=%.2e" diff)
(diff < 1e-10)
let test_batch_extinction () =
let wave = Unit.Length.of_m (Nx.linspace f64 3e-7 8e-7 200) in
let rv = Nx.scalar f64 3.1 in
let av = Nx.scalar f64 0.5 in
let bb1 =
Spectrum.blackbody
~temperature:(Unit.Temperature.of_kelvin (Nx.scalar f64 5000.0))
~wavelength:wave
|> Spectrum.as_flux_density
in
let bb2 =
Spectrum.blackbody
~temperature:(Unit.Temperature.of_kelvin (Nx.scalar f64 8000.0))
~wavelength:wave
|> Spectrum.as_flux_density
in
let values = Nx.stack [ Spectrum.values bb1; Spectrum.values bb2 ] in
let batch =
Spectrum.create ~wavelength:wave ~values |> Spectrum.as_flux_density
in
let reddened = Extinction.apply (Extinction.ccm89 ~rv) ~av batch in
let r1 = Extinction.apply (Extinction.ccm89 ~rv) ~av bb1 in
let r2 = Extinction.apply (Extinction.ccm89 ~rv) ~av bb2 in
let expected = Nx.stack [ Spectrum.values r1; Spectrum.values r2 ] in
let diff =
Nx.item [] (Nx.max (Nx.abs (Nx.sub (Spectrum.values reddened) expected)))
in
is_true
~msg:
(Printf.sprintf "batch extinction matches individual: max_diff=%.2e" diff)
(diff < 1e-25)
let test_batch_scale () =
let wave = Unit.Length.of_m (Nx.linspace f64 1e-7 1e-6 100) in
let v1 = Nx.ones f64 [| 100 |] in
let v2 = Nx.full f64 [| 100 |] 2.0 in
let values = Nx.stack [ v1; v2 ] in
let batch = Spectrum.create ~wavelength:wave ~values in
let scaled = Spectrum.scale (Nx.scalar f64 3.0) batch in
let sv = Spectrum.values scaled in
let expected =
Nx.stack [ Nx.full f64 [| 100 |] 3.0; Nx.full f64 [| 100 |] 6.0 ]
in
let diff = Nx.item [] (Nx.max (Nx.abs (Nx.sub sv expected))) in
is_true ~msg:"batch scalar scale" (diff < 1e-15)
let test_batch_redshift_scalar () =
let wave = Unit.Length.of_m (Nx.linspace f64 1e-7 1e-6 100) in
let bb1 =
Spectrum.blackbody
~temperature:(Unit.Temperature.of_kelvin (Nx.scalar f64 5000.0))
~wavelength:wave
|> Spectrum.as_flux_density
in
let bb2 =
Spectrum.blackbody
~temperature:(Unit.Temperature.of_kelvin (Nx.scalar f64 8000.0))
~wavelength:wave
|> Spectrum.as_flux_density
in
let values = Nx.stack [ Spectrum.values bb1; Spectrum.values bb2 ] in
let batch =
Spectrum.create ~wavelength:wave ~values |> Spectrum.as_flux_density
in
let z = Nx.scalar f64 0.5 in
let shifted = Spectrum.redshift ~z batch in
let s1 = Spectrum.redshift ~z bb1 in
let s2 = Spectrum.redshift ~z bb2 in
let expected = Nx.stack [ Spectrum.values s1; Spectrum.values s2 ] in
let diff =
Nx.item [] (Nx.max (Nx.abs (Nx.sub (Spectrum.values shifted) expected)))
in
is_true
~msg:
(Printf.sprintf "batch redshift matches individual: max_diff=%.2e" diff)
(diff < 1e-20)
let test_batch_create_mismatch () =
let wave = Unit.Length.of_m (Nx.linspace f64 1e-7 1e-6 100) in
let values = Nx.ones f64 [| 3; 50 |] in
let raised =
try
ignore (Spectrum.create ~wavelength:wave ~values);
false
with Invalid_argument _ -> true
in
is_true ~msg:"mismatched last dim raises" raised
let test_batch_roundtrip () =
let wave = Unit.Length.of_m (Nx.linspace f64 1e-7 1e-6 100) in
let v1 = Nx.ones f64 [| 100 |] in
let v2 = Nx.full f64 [| 100 |] 2.0 in
let v3 = Nx.full f64 [| 100 |] 3.0 in
let values = Nx.stack [ v1; v2; v3 ] in
let batch = Spectrum.create ~wavelength:wave ~values in
let extracted = Nx.get [ 1 ] (Spectrum.values batch) in
let diff = Nx.item [] (Nx.max (Nx.abs (Nx.sub extracted v2))) in
is_true ~msg:"extract second spectrum from batch" (diff < 1e-15)
let () =
run "Umbra"
[
group "Unit"
[
test "10 kpc converts to 0.01 Mpc" test_length_conversion;
test "10 kpc + 500 pc = 10.5 kpc" test_length_arithmetic;
test "1 solar mass is ~1.988e30 kg" test_mass_conversion;
test "100 km / 10 s = 10 km/s" test_velocity_cross_dim;
test "sin(90) = 1 and cos(90) = 0" test_angle_trig;
test "wavelength to frequency roundtrips" test_wavelength_frequency;
test "phantom types prevent adding length and mass"
test_phantom_type_safety;
test "2 eV survives energy-wavelength-frequency roundtrip"
test_energy_wavelength_frequency;
];
group "Const" [ test "speed of light is ~299792 km/s" test_const_c ];
group "Time"
[
test "J2000.0 JD and MJD values are correct" test_time_jd_mjd;
test "ISO 8601 parse and format roundtrip" test_time_iso;
test "UTC to TAI offset is 32s at J2000" test_time_utc_tai_tt;
test "TDB-TT difference is less than 2 ms" test_time_tdb;
test "Unix epoch maps to JD 2440587.5" test_time_unix;
test "diff and add with 1-day offset" test_time_diff_add;
];
group "Coord"
[
test "ICRS to Galactic and back preserves RA/Dec" test_coord_roundtrip;
test "ICRS to Ecliptic and back preserves RA/Dec"
test_coord_ecliptic_roundtrip;
test "ICRS to Supergalactic and back preserves RA/Dec"
test_coord_supergalactic_roundtrip;
test "north pole to south pole separation is 180 deg"
test_separation_poles;
test "nearest self-match returns identity indices"
test_match_nearest_self;
];
group "Cosmo"
[
test "H(0) equals H0 = 70 km/s/Mpc" test_cosmo_hubble;
test "E(z=0) = 1 for any cosmology" test_cosmo_e_of;
test "comoving(0.1) ~ 421 Mpc and luminosity(0.1) ~ 463 Mpc"
test_cosmo_distances;
test "lookback time at z=1 is ~7.7 Gyr" test_cosmo_lookback;
test "1 kpc at z=0.022 subtends ~2.3 arcsec" test_cosmo_angular_scale;
test "Planck18 comoving(0.5) ~ 1960 Mpc" test_cosmo_planck18;
test "flat_lcdm(70, 0.3) matches default cosmology"
test_cosmo_flat_lcdm_same_as_default;
test "non-flat LCDM differs from flat" test_cosmo_nonflat_lcdm;
test "wCDM with w0=-1 reduces to LCDM" test_cosmo_wcdm;
test "w0waCDM with w0=-1 wa=0 reduces to LCDM" test_cosmo_w0wacdm;
test "z_at_value roundtrips luminosity distance" test_cosmo_z_at_value;
test "growth factor D(z=0) = 1" test_cosmo_growth_factor_z0;
test "growth factor D(z=1) ~ 0.61" test_cosmo_growth_factor_z1;
test "growth rate f(z=0) ~ 0.52" test_cosmo_growth_rate_z0;
test "growth factor decreases with redshift"
test_cosmo_growth_monotonic;
test "linear power spectrum is positive and decreases with z"
test_cosmo_linear_power;
test "nonlinear power exceeds linear at k=1 h/Mpc"
test_cosmo_nonlinear_power;
test "Planck18 omega_b, n_s, and sigma8 accessors"
test_cosmo_params_accessors;
test "D_H(0) = c/H0" test_cosmo_dh;
test "D_M equals D_C for flat geometry" test_cosmo_dm_flat;
test "D_V = (z * D_H * D_M^2)^(1/3)" test_cosmo_dv;
test "sound horizon r_s ~ 147 Mpc for Planck18"
test_cosmo_sound_horizon;
test "age of universe ~ 13.8 Gyr for Planck18" test_cosmo_age_planck18;
test "age(z=0) - age(z=1) = lookback(z=1)" test_cosmo_age_at_z1;
test "comoving distance to CMB ~ 14000 Mpc" test_cosmo_comoving_cmb;
test "comoving distances increase at z = 2, 5, 10"
test_cosmo_comoving_high_z;
test "lookback time at z=5 ~ 12.5 Gyr" test_cosmo_lookback_high_z;
];
group "Altaz"
[
test "ICRS to AltAz and back preserves RA/Dec" test_altaz_zenith;
test "Polaris altitude ~ observer latitude from lat=45"
test_altaz_north_pole;
test "airmass is >= 1.0 near zenith" test_altaz_airmass_zenith;
test "airmass differs for high vs low altitude stars"
test_altaz_airmass_low_alt;
test "refraction raises apparent altitude" test_altaz_refraction;
test "standalone refraction is between 0 and 2 arcmin at high alt"
test_altaz_refraction_standalone;
];
group "Galactocentric"
[
test "l=0 b=0 at galcen_distance maps to origin"
test_galactocentric_gc_position;
test "Galactocentric to ICRS roundtrips RA/Dec/distance"
test_galactocentric_roundtrip;
];
group "Spectrum"
[
test "scale by 3 multiplies all values" test_spectrum_scale;
test "multiply spectrum by transmission" test_spectrum_mul;
test "divide spectrum by flat transmission" test_spectrum_div;
test "mul then div roundtrips to original"
test_spectrum_mul_div_roundtrip;
test "blackbody peak obeys Wien's displacement law"
test_spectrum_blackbody_wien;
test "redshift z=1 doubles wavelength and halves flux"
test_spectrum_redshift;
test "resample preserves positivity and finiteness"
test_spectrum_resample_values;
test "Gaussian line peaks at 656.3 nm with unit amplitude"
test_spectrum_gaussian_peak;
test "Lorentzian line peaks at 500 nm with amplitude 3"
test_spectrum_lorentzian_peak;
test "Voigt with sigma >> gamma matches Gaussian"
test_spectrum_voigt_limits;
test "power-law continuum plus Gaussian line composes cleanly"
test_spectrum_line_composability;
];
group "Extinction"
[
test "CCM89 A_V/A_V ~ 1.0 at V-band 550 nm"
test_extinction_ccm89_v_band;
test "CCM89 extinction increases toward blue"
test_extinction_ccm89_monotonic;
test "apply then unredden recovers original spectrum"
test_extinction_apply_unredden;
];
group "Photometry"
[
test "flat f_nu = 3631 Jy gives m_AB ~ 0" test_photometry_ab_mag_flat;
test "same-band color is zero" test_photometry_color_same_band;
test "hot star is brighter in blue than red"
test_photometry_blue_star_color;
test "auto-resample matches manual resample"
test_photometry_auto_resample;
test "ST and AB magnitudes differ for a blackbody"
test_photometry_st_mag;
test "Vega and AB magnitudes differ through Johnson V"
test_photometry_vega_mag;
test "effective wavelength is in range for flat tophat spectrum"
test_photometry_effective_wavelength;
];
group "Filters"
[
test "SDSS r pivot wavelength ~ 620 nm" test_filters_sdss_pivot;
test "Johnson V pivot wavelength ~ 551 nm"
test_filters_johnson_v_pivot;
test "2MASS J pivot wavelength ~ 1235 nm" test_filters_twomass_j_pivot;
test "Gaia BP < G < RP pivot ordering" test_filters_gaia_ordering;
test "5800 K blackbody through SDSS g yields finite magnitude"
test_filters_photometry;
];
group "Survey"
[
test "Smail n(z) integrates to 1.0" test_survey_smail_normalized;
test "tabulated n(z) is positive at midpoint and zero outside"
test_survey_tabulated;
test "C_l matrix has correct shape for 2 tracers" test_survey_cl_shape;
test "auto C_l is positive at all ell" test_survey_cl_positive;
test "weak lensing noise is constant in ell" test_survey_noise_wl;
test "shear m_bias=0.02 scales auto C_l by (1+m)^2" test_survey_m_bias;
test "baryonic feedback suppresses high-ell power"
test_survey_baryonic_feedback;
];
group "FITS"
[
test "2x3 float32 image writes and reads back"
test_fits_image_roundtrip;
test "3-row table with ra/dec writes and reads back"
test_fits_table_roundtrip;
];
group "Batch"
[
test "batch of 2 spectra has shape [2; 100]" test_batch_create;
test "mismatched wavelength and values dims raises"
test_batch_create_mismatch;
test "extract second spectrum from batch" test_batch_roundtrip;
test "scalar scale applies to all spectra in batch" test_batch_scale;
test "batch resample matches per-spectrum resample"
test_batch_resample;
test "batch AB magnitudes match per-spectrum magnitudes"
test_batch_ab_mag;
test "batch extinction matches per-spectrum extinction"
test_batch_extinction;
test "batch redshift matches per-spectrum redshift"
test_batch_redshift_scalar;
];
]
================================================
FILE: doc/coming-from-python.md
================================================
# Coming from Python
This page maps Python scientific computing concepts to their Raven equivalents. It assumes you already know OCaml basics.
## Library Mapping
| Python | Raven | Notes |
|--------|-------|-------|
| NumPy | [Nx](/docs/nx/) | N-dimensional arrays, broadcasting, linear algebra, FFT |
| JAX | [Rune](/docs/rune/) | Functional transformations: `grad`, `jvp`, `vmap` |
| PyTorch / Flax | [Kaun](/docs/kaun/) | Layers, optimizers, training loops |
| HuggingFace Tokenizers | [Brot](/docs/brot/) | BPE, WordPiece, Unigram; HF-compatible |
| pandas / Polars | [Talon](/docs/talon/) | Type-safe DataFrames |
| Matplotlib | [Hugin](/docs/hugin/) | 2D/3D plotting with Cairo |
| Gymnasium | [Fehu](/docs/fehu/) | RL environments and training utilities |
| OpenCV | [Sowilo](/docs/sowilo/) | Differentiable image processing |
| Jupyter + IPython | [Quill](/docs/quill/) | Interactive REPL and markdown notebooks |
## Key Differences
### Explicit Types
NumPy casts types silently. Nx does not.
```python
# Python: silently upcasts int + float -> float
import numpy as np

a = np.array([1, 2, 3])
b = a + 1.5  # works
```
```ocaml
(* OCaml: types must match *)
let a = Nx.create Nx.Int32 [|3|] [|1l; 2l; 3l|]
(* Nx.add a (Nx.scalar Nx.Float32 1.5) -- type error *)
(* Cast explicitly *)
let a_f = Nx.astype Nx.Float32 a
let b = Nx.add a_f (Nx.scalar Nx.Float32 1.5)
```
### Array Literals
NumPy uses Python lists. Nx uses OCaml arrays with `[| |]` syntax.
```python
x = np.array([[1, 2], [3, 4]])
```
```ocaml
let x = Nx.create Nx.Float32 [|2; 2|] [|1.; 2.; 3.; 4.|]
```
### Slicing
NumPy uses `[]` with `:`. Nx uses the `slice` function with index constructors.
```python
x[0:2, :] # first two rows
x[:, 1] # second column
x[::2] # every other element
```
```ocaml
Nx.slice [R (0, 2); A] x (* first two rows *)
Nx.slice [A; I 1] x (* second column *)
Nx.slice [S (0, -1, 2)] x (* every other element *)
```
### No Separate Tensor Type
In PyTorch, `torch.Tensor` is different from `numpy.ndarray`. In Raven, Rune operates directly on `Nx.t` values. There is no wrapper type.
```python
# PyTorch: convert between types
x_np = np.array([1.0, 2.0])
x_torch = torch.from_numpy(x_np)
x_torch.requires_grad_(True)
```
```ocaml
(* Raven: just use Nx tensors directly *)
let x = Nx.create Nx.Float32 [|2|] [|1.0; 2.0|]
let gradient = Rune.grad (fun x -> Nx.sum (Nx.mul x x)) x
```
### Functional Transformations
JAX users will find Rune familiar. PyTorch users: think of `grad` as a function transformer, not a method on tensors.
```python
# JAX style
grad_fn = jax.grad(loss_fn)
grads = grad_fn(params)
# PyTorch style
loss = loss_fn(params)
loss.backward()
grads = params.grad
```
```ocaml
(* Rune: JAX-style functional transforms *)
let grad_fn = Rune.grad loss_fn
let grads = grad_fn params
(* Or compute value and gradient together *)
let loss, grads = Rune.value_and_grad loss_fn params
```
### Module-Based Layers
Kaun layers are records with `init` and `apply`, not classes with `forward`.
```python
# PyTorch
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(784, 10)

    def forward(self, x):
        return self.linear(x)

model = Model()
```
```ocaml
(* Kaun: compose layer records *)
let model =
  Kaun.Layer.sequential
    [ Kaun.Layer.linear ~in_features:784 ~out_features:10 () ]
let vars = Kaun.Layer.init model ~dtype:Nx.Float32
```
Parameters are plain data (`Ptree.t` — a tree of Nx tensors), not hidden inside objects.
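Because the parameter tree is plain data, a generic traversal can touch every tensor at once. A hedged sketch, assuming a `Kaun.Ptree.map` that applies a function to each leaf tensor (the real name and signature may differ; check the `Ptree` API):

```ocaml
(* Hypothetical sketch: halve every parameter tensor in the tree.
   Kaun.Ptree.map is assumed here, not confirmed by the docs. *)
let vars = Kaun.Layer.init model ~dtype:Nx.Float32
let halved =
  Kaun.Ptree.map (fun t -> Nx.mul t (Nx.scalar Nx.Float32 0.5)) vars
```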
### DataFrames
pandas uses string-based column access. Talon provides type-safe row operations via an applicative.
```python
# pandas
df['bmi'] = df['weight'] / df['height'] ** 2
```
```ocaml
(* Talon: type-safe row computation *)
let df =
  Talon.with_column df "bmi" Nx.Float64
    Talon.Row.(map2 (number "weight") (number "height")
        ~f:(fun w h -> w /. (h *. h)))
```
## Detailed Comparisons
Each library has a dedicated comparison page with side-by-side code examples:
- [Nx vs NumPy](/docs/nx/numpy-comparison/)
- [Rune vs JAX](/docs/rune/jax-comparison/)
- [Kaun vs PyTorch/Flax](/docs/kaun/pytorch-comparison/)
- [Brot vs HuggingFace Tokenizers](/docs/brot/hf-tokenizers-comparison/)
- [Talon vs pandas](/docs/talon/pandas-comparison/)
- [Hugin vs Matplotlib](/docs/hugin/matplotlib-comparison/)
- [Sowilo vs OpenCV](/docs/sowilo/opencv-comparison/)
- [Fehu vs Gymnasium](/docs/fehu/gymnasium-comparison/)
================================================
FILE: doc/ecosystem-overview.md
================================================
# The Raven Ecosystem
Raven is nine libraries that share one data type: `Nx.t`, the
n-dimensional array. Each library does one thing, and they compose
through tensors.
## How the Libraries Fit Together
```
                        ┌───────────┐
                        │   Kaun    │  neural networks
                        │  (Flax)   │
                        └─────┬─────┘
                              │
 ┌───────────┐          ┌─────┴─────┐          ┌───────────┐
 │  Sowilo   │          │   Rune    │          │   Fehu    │
 │ (OpenCV)  ├──────────┤   (JAX)   ├──────────┤(Gymnasium)│
 └─────┬─────┘          └─────┬─────┘          └─────┬─────┘
       │                      │                      │
 ┌─────┴──────────────────────┴──────────────────────┴─────┐
 │                            Nx                           │
 │                         (NumPy)                         │
 └──┬──────────────┬──────────────┬──────────────┬─────────┘
    │              │              │              │
┌───┴────┐    ┌────┴───┐     ┌────┴───┐    ┌─────┴────┐
│ Talon  │    │  Brot  │     │ Hugin  │    │  Quill   │
│(Polars)│    │(HF Tok)│     │ (Mpl)  │    │(Jupyter) │
└────────┘    └────────┘     └────────┘    └──────────┘
```
**Nx** is the foundation — every library operates on `Nx.t` tensors.
**Rune** adds functional transformations on top of Nx: `grad`, `jvp`,
`vmap`. Your Nx code becomes differentiable without changes.
**Kaun** builds on Rune to provide layers, optimizers, training loops,
and HuggingFace Hub integration.
**Sowilo**, **Fehu**, **Talon**, **Brot**, **Hugin**, and **Quill** each
use Nx directly for their domain. Sowilo and Fehu operations are
compatible with Rune's `grad` and `vmap` since they are plain Nx
operations under the hood.
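Because Sowilo transforms are ordinary Nx computations, they can appear inside a function handed to `Rune.grad`. A hedged sketch (the `Sowilo.gaussian_blur` name and its `~sigma` argument are assumptions for illustration; substitute the actual operation):

```ocaml
(* Hypothetical: differentiate a scalar loss through an image op.
   gaussian_blur's exact signature is assumed, not confirmed. *)
let img = Nx.ones Nx.Float32 [| 64; 64 |]
let loss img = Nx.sum (Sowilo.gaussian_blur ~sigma:1.0 img)
let d_img = Rune.grad loss img
```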
## Which Library Do I Need?
| I want to... | Use |
|---|---|
| Work with numerical arrays | [Nx](/docs/nx/) |
| Compute gradients | [Rune](/docs/rune/) |
| Train neural networks | [Kaun](/docs/kaun/) |
| Tokenize text for language models | [Brot](/docs/brot/) |
| Manipulate tabular data | [Talon](/docs/talon/) |
| Process and transform images | [Sowilo](/docs/sowilo/) |
| Build RL environments and agents | [Fehu](/docs/fehu/) |
| Create plots and visualizations | [Hugin](/docs/hugin/) |
| Run code interactively (REPL or notebooks) | [Quill](/docs/quill/) |
---
## Nx: N-Dimensional Arrays
Nx provides the numerical foundation for the entire ecosystem.
NumPy-like operations on n-dimensional arrays with 19 data types
(float16 through complex128), broadcasting, slicing, linear algebra,
FFT, and I/O.
```ocaml
open Nx
let x = linspace Float32 0. 10. 100
let y = sin x
let mean_y = mean y
```
[Nx documentation →](/docs/nx/)
## Rune: Automatic Differentiation
Functional transformations for Nx tensors: reverse-mode AD (grad,
vjp), forward-mode AD (jvp), and vectorising maps (vmap). Operates on
`Nx.t` values directly using OCaml 5 effect handlers — no special
tensor type needed.
```ocaml
open Nx
open Rune
let f x = add (mul x x) (sin x)
let f' = grad f
let f'' = grad f'
```
[Rune documentation →](/docs/rune/)
## Kaun: Neural Networks
Composable layers, optimizers with learning-rate schedules, training
loops, data pipelines, and HuggingFace Hub integration. Model
parameters are `Ptree.t` — trees of Nx tensors you can inspect, map,
and serialize.
```ocaml
open Kaun
let model =
  Layer.sequential
    [
      Layer.linear ~in_features:784 ~out_features:128 ();
      Layer.relu ();
      Layer.linear ~in_features:128 ~out_features:10 ();
    ]

let trainer =
  Train.make ~model
    ~optimizer:(Optim.adam ~lr:(Optim.Schedule.constant 0.001) ())
```
[Kaun documentation →](/docs/kaun/)
## Brot: Tokenization
Fast, HuggingFace-compatible tokenization supporting BPE, WordPiece,
Unigram, word-level, and character-level algorithms. Composable
pipeline (normalizer → pre-tokenizer → model → post-processor →
decoder) with training from scratch.
```ocaml
open Brot
let tokenizer = from_file "tokenizer.json" |> Result.get_ok
let encoding = encode tokenizer "Hello, world!"
let ids = Encoding.ids encoding
```
[Brot documentation →](/docs/brot/)
## Talon: DataFrames
Type-safe tabular data with heterogeneous columns, an applicative Row
system for row-wise operations, and vectorized aggregations backed by
Nx.
```ocaml
open Talon
let df =
  create
    [
      "name", Col.string_list [ "Alice"; "Bob"; "Charlie" ];
      "score", Col.float64_list [ 85.5; 92.0; 78.5 ];
    ]
let () = print df
```
[Talon documentation →](/docs/talon/)
## Sowilo: Computer Vision
Differentiable image processing: geometric transforms (resize, crop,
flip), spatial filters (Gaussian blur, Sobel, Canny), color space
conversions, and morphological operations. All operations are plain Nx
computations, so they compose with `Rune.grad` and `Rune.vmap`.
```ocaml
open Sowilo
let processed =
  img
  |> to_float
  |> resize ~height:224 ~width:224 ~mode:Bilinear
  |> normalize ~mean:[| 0.485; 0.456; 0.406 |] ~std:[| 0.229; 0.224; 0.225 |]
```
[Sowilo documentation →](/docs/sowilo/)
## Fehu: Reinforcement Learning
RL environments (CartPole, MountainCar, GridWorld), type-safe
observation/action spaces, vectorized environments, trajectory
collection, replay buffers, and generalized advantage estimation.
```ocaml
open Fehu
let env = Fehu_envs.cartpole () in
let obs, _info = Env.reset env in
let obs, reward, terminated, truncated, _info =
Env.step env (Space.sample (Env.action_space env))
```
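Generalized advantage estimation itself is the textbook recurrence A(t) = delta(t) + gamma * lambda * A(t+1), with delta(t) = r(t) + gamma * V(s(t+1)) - V(s(t)). A plain-OCaml sketch of that recurrence (not Fehu's API):

```ocaml
(* GAE over one trajectory: sweep backwards, accumulating discounted
   TD residuals. [last_value] bootstraps V for the state after the
   final step. *)
let gae ~gamma ~lambda ~rewards ~values ~last_value =
  let n = Array.length rewards in
  let adv = Array.make n 0.0 in
  let next = ref 0.0 in
  for t = n - 1 downto 0 do
    let v_next = if t = n - 1 then last_value else values.(t + 1) in
    let delta = rewards.(t) +. (gamma *. v_next) -. values.(t) in
    next := delta +. (gamma *. lambda *. !next);
    adv.(t) <- !next
  done;
  adv

let () =
  let adv =
    gae ~gamma:0.99 ~lambda:0.95 ~rewards:[| 1.; 1.; 1. |]
      ~values:[| 0.5; 0.5; 0.5 |] ~last_value:0.0
  in
  Array.iter (Printf.printf "%.3f ") adv;
  print_newline ()
```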
[Fehu documentation →](/docs/fehu/)
## Hugin: Visualization
Publication-quality 2D and 3D plots using Cairo rendering. Takes Nx
tensors as input. Line plots, scatter, bar charts, contour plots,
image display.
```ocaml
open Hugin
let x = Nx.linspace Nx.float32 0. 10. 100
let y = Nx.sin x
let () =
  let fig = figure () in
  let ax = subplot fig in
  let _ = Plotting.plot ax ~x ~y ~label:"sin(x)" in
  show fig
```
[Hugin documentation →](/docs/hugin/)
## Quill: Interactive Computing
Interactive REPL and markdown notebooks. Launch `quill` for a toplevel
with syntax highlighting, completion, and history, or open a markdown
file for a full notebook experience. Terminal UI, web frontend, and
batch mode with all Raven libraries pre-loaded.
```bash
quill # interactive REPL
quill notebook.md # notebook TUI
quill serve notebook.md # web frontend
quill run notebook.md # batch evaluation
```
[Quill documentation →](/docs/quill/)
## Getting Started
1. **New to Raven?** Start with the [Quickstart](/docs/quickstart/)
2. **Coming from Python?** Read [Coming from Python](/docs/coming-from-python/)
3. **Want a specific library?** Use the table above to find the right docs
================================================
FILE: doc/index.md
================================================
# Documentation
Welcome to Raven's documentation. Raven is an ecosystem of OCaml libraries for numerical computing, machine learning, and data science.
## Start Here
- **[Quickstart](/docs/quickstart/)** — zero to gradient in 5 minutes
- **[Coming from Python](/docs/coming-from-python/)** — map NumPy, PyTorch, pandas concepts to Raven
- **[Ecosystem Overview](/docs/ecosystem-overview/)** — how the libraries relate and which to use
## Libraries
| | Library | Like | What it does |
|-|---------|------|-------------|
| | [**nx**](/docs/nx/) | NumPy | N-dimensional arrays with pluggable backends |
| ᚱ | [**rune**](/docs/rune/) | JAX | Automatic differentiation and functional transformations |
| ᚲ | [**kaun**](/docs/kaun/) | PyTorch / Flax | Neural networks and training |
| ᚨ | [**brot**](/docs/brot/) | HF Tokenizers | Fast tokenization for language models |
| ᛃ | [**talon**](/docs/talon/) | Pandas / Polars | DataFrames with type-safe columns |
| ᛞ | [**hugin**](/docs/hugin/) | Matplotlib | Data visualization and plotting |
| ᛈ | [**quill**](/docs/quill/) | Jupyter + IPython | Interactive REPL and markdown notebooks |
| ᚠ | [**fehu**](/docs/fehu/) | Gymnasium | Reinforcement learning environments |
| ᛋ | [**sowilo**](/docs/sowilo/) | OpenCV | Differentiable computer vision |
## Project
- [Installation](/docs/installation/) — system dependencies, opam setup, building from source
- [Roadmap](/docs/roadmap/) — what works today and what's coming
- [Introduction](/docs/introduction/) — vision and philosophy
- [Support Raven](/docs/support-raven/) — sponsorship and contributing
================================================
FILE: doc/installation.md
================================================
# Installation
## Prerequisites
Raven requires **OCaml 5.2** or later and **opam**.
If you don't have opam installed, follow the [official instructions](https://opam.ocaml.org/doc/Install.html). Then create a switch:
```bash
opam switch create raven 5.2.0
eval $(opam env)
```
## Installing from opam
Install the entire ecosystem:
```bash
opam install raven
```
Or install individual libraries:
```bash
opam install nx # just arrays
opam install rune # arrays + autodiff
opam install kaun # arrays + autodiff + neural networks
opam install brot # tokenization
opam install talon # dataframes
```
## Building from Source
```bash
git clone https://github.com/raven-ml/raven
cd raven
dune pkg lock && dune build
```
To build a specific library:
```bash
dune build packages/nx # just nx
dune build packages/kaun # kaun + its dependencies
```
## System Dependencies
Most Raven libraries have no system dependencies beyond OCaml. The exceptions:
| Library | Requires | macOS | Ubuntu/Debian |
|---------|----------|-------|---------------|
| **hugin** | Cairo, SDL2 | `brew install cairo sdl2` | `apt install libcairo2-dev libsdl2-dev` |
## Using Raven in Your Project
Add libraries to your `dune-project`:
```dune
(lang dune 3.0)
(package
(name my_project)
(depends
ocaml
dune
nx
rune))
```
And your `dune` file:
```dune
(executable
(name main)
(libraries nx rune))
```
## Verify Your Installation
Create a file `main.ml`:
```ocaml
let () =
let open Nx in
let x = linspace Float32 0. 1. 5 in
print_data x
```
Build and run:
```bash
dune exec ./main.exe
```
You should see five evenly-spaced values printed.
## Editor Setup
For the best development experience, use an editor with OCaml LSP support:
- **VS Code**: Install the [OCaml Platform extension](https://marketplace.visualstudio.com/items?itemName=ocamllabs.ocaml-platform)
- **Emacs**: Use [ocaml-eglot](https://github.com/tarides/ocaml-eglot)
- **Vim/Neovim**: Use [ocaml-lsp](https://github.com/ocaml/ocaml-lsp) with your LSP client
## Troubleshooting
**Missing system libraries**: If Hugin fails to build, ensure Cairo and SDL2 development headers are installed.
**Opam switch issues**: Run `eval $(opam env)` after creating or switching opam switches.
**Build failures**: Check your OCaml version with `ocaml --version`. Raven requires 5.2.0 or later.
**Getting help**: Report issues at [github.com/raven-ml/raven/issues](https://github.com/raven-ml/raven/issues).
================================================
FILE: doc/introduction.md
================================================
# Introduction
Raven is a project to bring modern scientific computing to the OCaml programming language. We're building a comprehensive ecosystem, from low-level numerical libraries and automatic differentiation to high-level machine learning frameworks and interactive notebooks.
Our ambition is to make scientific computing in OCaml feel as natural as it does in Python. This means not just matching Python's capabilities, but delivering the same level of ergonomics, performance, and developer experience that has made Python the de facto standard for scientific computing.
If successful, Raven would establish OCaml as a genuine alternative in the scientific computing landscape. It's an ambitious undertaking, but one we believe is both necessary and achievable.
## Why Not Just Use Python?
Today, Python has an effective monopoly on scientific computing. Unlike web development, where we can choose between multiple mature ecosystems, numerical computing offers essentially one realistic option. This lack of choice is unfortunate.
What's more problematic is that Python, while excellent for quick experimentation, doesn't particularly shine for building robust production systems. Its interpreted nature, dynamic typing, and limited multicore support create real challenges when you need to deploy and maintain large-scale applications.
If you've worked in this space, you've likely experienced this firsthand: rapid prototypes that become production nightmares, debugging sessions where type errors only surface at runtime, or performance bottlenecks that force you to drop down to C extensions.
Often, this mismatch forces a wasteful pattern: researchers prototype in Python, then teams reimplement everything for production in other languages. This induces all kinds of second-order effects on organization structures, team dynamics, development velocity, and workload.
The scientific community deserves better options than being forced into one language, and we believe OCaml occupies a unique sweet spot between rapid experimentation and building production-grade systems. It just needs the scientific ecosystem to match its technical strengths. This is the gap that Raven aims to fill.
In the AI era, we believe OCaml has an important role to play. If you're generating 80% of your code with AI assistance, wouldn't you prefer a language that catches errors at compile time rather than runtime? The productivity gains from AI coding are amplified when you have a type system that gives you stronger guarantees about your generated code. Raven is our contribution to putting OCaml in the spotlight for scientific computing in this new era.
## What Does Success Look Like?
Our goal isn't just to build OCaml versions of Python libraries: it's to create a compelling alternative for busy developers who just want the best tool for the job.
Success means two things. First, **OCaml developers shouldn't have to switch to Python for numerical computing**. Whether you're analyzing data, training models, or building computational systems, you should be able to stay in the OCaml ecosystem with the same productivity you'd expect from Python.
Second, **Raven should break into the mainstream scientific computing conversation**. It shouldn't just serve existing OCaml developers: we're building for teams who need to ship reliable systems, not just an OCaml curiosity for language enthusiasts.
We measure success across five key dimensions:
- **Capability parity**: Everything you can do in Python, you should be able to do with Raven
- **Development productivity**: Getting from idea to working prototype is as fast as it would be in Python
- **Developer experience**: Developers get the kind of documentation, tooling, and APIs they wish every project had
- **Production performance**: Match or exceed NumPy/PyTorch performance on the fast path
- **Production readiness**: Teams can ship robust, maintainable Raven-built applications that perform well under real-world conditions
We believe this is achievable through focused execution and strategic choices. We're prioritizing the 80% that matter most, focusing on one blessed workflow per use-case, and building modular components that encourage ecosystem growth, rather than trying to match Python everywhere from day one.
## Why Not Just Use Owl?
Owl deserves credit for the amount of work and love that has been poured into it. It demonstrated that serious numerical computing in OCaml was possible, spanning everything from statistics and signal processing to basic linear algebra and neural networks, and more.
However, Owl can't compete with NumPy or PyTorch on performance, and performance parity isn't optional if we want teams to seriously consider OCaml over Python.
The reality is that we can't realistically match NumPy and PyTorch's performance through traditional optimization. These projects have hundreds of developers working on hand-optimized kernels. With our small team, JIT compilation is our only viable path to competitive performance.
This creates a fundamental constraint. Building for JIT-first changes everything about your design: API choices, memory layouts, operator fusion strategies, even how you structure the development experience. Rather than retrofitting these assumptions onto existing work, we decided a clean slate would be more effective.
There's also the ecosystem question. Despite Owl's technical achievements, it hasn't generated the kind of flourishing community we need. We suspect this is partly due to its lack of modularity: without libraries designed as composable building blocks, it's challenging to build a broader ecosystem around the foundation.
Raven is designed from the ground up to (1) compete with Python's scientific computing stack on performance and (2) build the flourishing ecosystem that OCaml's scientific computing community deserves.
## What We're Building
Raven is a comprehensive ecosystem that spans the entire scientific computing stack. Here's what we're building:
**Foundation**
- **Nx**: N-dimensional arrays with pluggable backends (NumPy equivalent)
- **Brot**: Fast, HuggingFace-compatible tokenization (HF Tokenizers equivalent)
- **Talon**: Type-safe DataFrames (pandas/Polars equivalent)
**Differentiable Computing**
- **Rune**: Automatic differentiation using OCaml's effect system (JAX equivalent)
**Domain Frameworks**
- **Kaun**: Neural networks and training (PyTorch/Flax equivalent)
- **Sowilo**: Differentiable computer vision (OpenCV equivalent)
- **Fehu**: Reinforcement learning environments and algorithms (Gymnasium equivalent)
**Tooling**
- **Hugin**: Publication-quality plotting (Matplotlib equivalent)
- **Quill**: Interactive notebooks as markdown files (Jupyter equivalent)
Nine libraries spanning the full scientific computing stack, all designed to work together seamlessly.
**Key Innovations**
While we aim to feel familiar to Python users, Raven brings genuine innovations to scientific computing:
**Nx** uses pluggable backends inspired by Tinygrad's minimalist approach, giving us flexibility to optimize for different hardware without monolithic complexity.
**Rune** implements automatic differentiation using OCaml's effect system. As far as we know, it is the first project of this scale to use effect handlers for autodiff: building on recent research, it realizes JAX's vision of functional numerical computation on a truly functional foundation.
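The mechanism can be illustrated in isolation with plain OCaml 5: an effect handler intercepts each arithmetic operation and records it as the computation runs. This is a deliberately minimal sketch of the tracing idea, not Rune's implementation:

```ocaml
(* An effect for multiplication; performing it yields control to a handler. *)
type _ Effect.t += Mul : float * float -> float Effect.t

let mul a b = Effect.perform (Mul (a, b))

(* Run [f x] under a handler that logs every multiplication, then resumes
   the computation with the actual product. *)
let trace f x =
  let open Effect.Deep in
  let log = ref [] in
  let result =
    match_with f x
      {
        retc = (fun v -> v);
        exnc = raise;
        effc =
          (fun (type a) (eff : a Effect.t) ->
            match eff with
            | Mul (u, v) ->
                Some
                  (fun (k : (a, _) continuation) ->
                    log := (u, v) :: !log;
                    continue k (u *. v))
            | _ -> None);
      }
  in
  (result, List.rev !log)

let () =
  let y, ops = trace (fun x -> mul x (mul x x)) 2.0 in
  Printf.printf "y = %.1f with %d recorded muls\n" y (List.length ops)
```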
**Quill** rethinks notebooks. Notebooks are plain markdown files — git-friendly, readable without special tooling, and editable in any text editor. Quill runs them as a TUI in the terminal or as a web frontend in the browser, with all Raven packages pre-loaded and zero setup.
**Deployment** is where Raven's story diverges most from Python. AOT compilation generates all compute kernels at compile time, producing binaries with no BLAS or CUDA runtime dependency. This makes it possible to deploy models as MirageOS unikernels — minimal attack surface, millisecond boot, deterministic behavior — or as static binaries with no Python runtime, no dependency hell.
**Current Focus**
The alpha milestone is complete — we've trained GPT-2 end-to-end on CPU using the full Raven stack. We're now focused on integrating tolk as a JIT transformation in Rune, with the goal of matching PyTorch performance. After that, V1 brings production-ready training and deployment: AOT compilation, inference serving, ONNX import, and MirageOS unikernel deployment. See the [roadmap](/docs/roadmap/) for details.
================================================
FILE: doc/quickstart.md
================================================
# Quickstart
This gets you from zero to computing gradients and training a model in five minutes.
## Setup
```bash
opam install raven
```
Create a `dune-project` and `dune` file:
```dune
; dune-project
(lang dune 3.20)
```
```dune
; dune
(executable
(name main)
(libraries kaun))
```
Installing `kaun` pulls in `nx` and `rune` automatically.
## Step 1: Arrays with Nx
Nx provides n-dimensional arrays. Every value has a data type and a shape.
```ocaml
open Nx
let () =
(* Create arrays *)
let a = create Float32 [|2; 3|] [|1.; 2.; 3.; 4.; 5.; 6.|] in
let b = ones Float32 [|2; 3|] in
(* Element-wise operations *)
let c = add a b in
print_data c;
(* Reductions *)
Printf.printf "sum = %.1f\n" (item [] (sum a));
Printf.printf "mean = %.1f\n" (item [] (mean a));
(* Matrix multiplication *)
let x = rand Float32 [|3; 4|] in
let y = rand Float32 [|4; 2|] in
let z = matmul x y in
Printf.printf "matmul shape: %s\n"
(Array.to_list (shape z) |> List.map string_of_int |> String.concat "x")
```
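Element-wise operations follow NumPy-style broadcasting: shapes are aligned from the right, and each pair of dimensions must be equal or contain a 1. The rule itself fits in a few lines of plain OCaml (a sketch of the rule, not Nx's internals):

```ocaml
(* Compute the broadcast shape of two shapes, NumPy-style: pad the shorter
   shape with leading 1s, then take the larger of each dimension pair,
   failing if they disagree and neither is 1. *)
let broadcast_shape a b =
  let n = max (Array.length a) (Array.length b) in
  let get s i =
    let off = n - Array.length s in
    if i < off then 1 else s.(i - off)
  in
  Array.init n (fun i ->
      match (get a i, get b i) with
      | x, y when x = y -> x
      | 1, y -> y
      | x, 1 -> x
      | x, y -> invalid_arg (Printf.sprintf "cannot broadcast %d vs %d" x y))

let () =
  let s = broadcast_shape [| 2; 3 |] [| 3 |] in
  Array.iter (Printf.printf "%d ") s;
  print_newline ()
```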
## Step 2: Gradients with Rune
Rune computes derivatives of Nx functions automatically. Write a function using Nx operations, then use `grad` to differentiate it.
```ocaml
open Nx
open Rune
let () =
(* f(x) = x² + sin(x) *)
let f x = add (mul x x) (sin x) in
(* grad returns the derivative function *)
let f' = grad f in
let x = scalar Float32 2.0 in
Printf.printf "f(2) = %.4f\n" (item [] (f x));
Printf.printf "f'(2) = %.4f\n" (item [] (f' x));
(* Higher-order: second derivative *)
let f'' = grad f' in
Printf.printf "f''(2) = %.4f\n" (item [] (f'' x))
```
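Since f(x) = x² + sin(x) has the closed-form derivative f'(x) = 2x + cos(x), you can sanity-check what `grad` should return against a central finite difference, in plain OCaml independent of Rune:

```ocaml
(* Compare the analytic derivative of f(x) = x^2 + sin x with a central
   finite difference; the two should agree to several decimal places. *)
let f x = (x *. x) +. sin x
let f' x = (2.0 *. x) +. cos x

let finite_diff f x =
  let eps = 1e-5 in
  (f (x +. eps) -. f (x -. eps)) /. (2.0 *. eps)

let () =
  let x = 2.0 in
  Printf.printf "analytic f'(2) = %.6f\n" (f' x);
  Printf.printf "numeric  f'(2) = %.6f\n" (finite_diff f x)
```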
## Step 3: Training with Kaun
Kaun provides layers, optimizers, and training loops built on Rune.
```ocaml
open Kaun
let () =
Nx.Rng.run ~seed:42 @@ fun () ->
(* XOR dataset *)
let x = Nx.create Nx.Float32 [|4; 2|]
[|0.; 0.; 0.; 1.; 1.; 0.; 1.; 1.|] in
let y = Nx.create Nx.Float32 [|4; 1|]
[|0.; 1.; 1.; 0.|] in
(* Define model *)
let model = Layer.sequential [
Layer.linear ~in_features:2 ~out_features:8 ();
Layer.tanh ();
Layer.linear ~in_features:8 ~out_features:1 ();
] in
(* Create trainer and initialize *)
let trainer = Train.make ~model
~optimizer:(Optim.adam ~lr:(Optim.Schedule.constant 0.01) ()) in
let st = Train.init trainer ~dtype:Nx.Float32 in
(* Train *)
let st = Train.fit trainer st
~report:(fun ~step ~loss _st ->
if step mod 250 = 0 then
Printf.printf "step %4d loss %.6f\n" step loss)
(Data.repeat 1000 (x, fun pred -> Loss.binary_cross_entropy pred y))
in
(* Predict *)
let pred = Train.predict trainer st x |> Nx.sigmoid in
Printf.printf "\npredictions (expected 0 1 1 0):\n";
for i = 0 to 3 do
Printf.printf " [%.0f, %.0f] -> %.3f\n"
(Nx.item [i; 0] x) (Nx.item [i; 1] x) (Nx.item [i; 0] pred)
done
```
## Next Steps
- **[Nx](/docs/nx/getting-started/)** — full guide to arrays, slicing, broadcasting, linear algebra
- **[Rune](/docs/rune/getting-started/)** — all transformations: grad, jvp, vmap, and more
- **[Kaun](/docs/kaun/getting-started/)** — layers, optimizers, training loops, pretrained models
- **[Ecosystem Overview](/docs/ecosystem-overview/)** — how all 9 libraries fit together
================================================
FILE: doc/roadmap.md
================================================
# Roadmap
## Current Status
Raven is in **alpha**. The core stack (Nx → Rune → Kaun) works end-to-end: we have successfully trained GPT-2 on CPU using the full Raven stack.
| Library | Status | What works |
| ---------- | ------ | ------------------------------------------------------------------------- |
| **nx** | Alpha | Full NumPy-like API, linear algebra, FFT, I/O (npy, images) |
| **rune** | Alpha | Reverse and forward-mode AD, vmap, gradient checking |
| **kaun** | Alpha | Layers, optimizers, training loops, HuggingFace Hub, MNIST/GPT-2 examples |
| **brot** | Alpha | All 5 algorithms, full pipeline, HF tokenizer.json compat, training |
| **talon** | Alpha | DataFrames, row operations, aggregations, CSV I/O |
| **hugin** | Alpha | 2D/3D plots, scatter, bar, contour, images |
| **fehu** | Alpha | Environments (CartPole, GridWorld, MountainCar), vectorized envs, GAE |
| **sowilo** | Alpha | Geometric transforms, filters, edge detection, morphological ops |
| **quill** | Alpha | Interactive REPL, notebook TUI and web frontend, batch eval, watch mode |
APIs will change. Bug reports and feedback are welcome.
## Beta: JIT Compilation & Performance
The beta cycle focuses on **JIT compilation with performance close to PyTorch**.
- Integrate tolk (an OCaml port of tinygrad) as a JIT transformation in Rune
- Target CPU, CUDA, Metal, OpenCL, and HIP
- Kernel fusion and optimization
- Benchmark against PyTorch on standard workloads
## V1: Production-Ready Training & Deployment
V1 makes Raven **production-ready**: train models, deploy them as unikernels or static binaries.
**Training**:
- Gradient accumulation, mixed precision, gradient checkpointing
- Flash attention for efficient transformer training
- ONNX import for PyTorch model portability
- Parallel data loading, layer completions
**Deployment**:
- AOT compilation to standalone binaries (CPU and GPU)
- Inference engine with KV cache, continuous batching, and PagedAttention
- Post-training quantization (INT8/INT4)
- MirageOS unikernel deployment: tolk's AOT compilation generates all compute kernels at compile time with no BLAS dependency, which is what makes unikernel deployment possible
================================================
FILE: doc/support-raven.md
================================================
# Support Raven
## Raven in One Minute
Python's monopoly on scientific computing forces an impossible choice: ship everything in Python (endure runtime crashes, the GIL's multicore ceiling, and gigabyte containers), or prototype in Python then rewrite for production (doubling the work and creating siloed teams).
**We think there's a better way.** OCaml lets you prototype as quickly as Python and scale the same code to production. You keep the same expressiveness, strong typing catches bugs before they crash your ML pipeline, and JIT compilation targets NumPy/PyTorch-level performance. One language from research to production — it just needs a production-grade ML stack.
**Raven brings that stack to OCaml:** Nx (NumPy), Rune (JAX with effects-based autodiff), Kaun (Flax), Brot (tokenization), Hugin (Matplotlib), and Quill (notebooks done right). Train models with automatic differentiation and JIT compilation, then deploy as a MirageOS unikernel or a static binary — no Python, no CUDA dependency hell, no 5 GB Docker images. We built Raven for teams that want both development speed and reliable systems.
_Learn more: [Introduction](/docs/introduction)_
_We're in alpha with the full stack working end-to-end (we've trained GPT-2 on CPU). Next milestone: JIT compilation via tolk with performance close to PyTorch._
## Roadmap & Funding Goals
_See the [full roadmap](/docs/roadmap) for our complete vision and timeline._
### Beta — JIT Compilation & Performance
- Integrate tolk (tinygrad-based compiler) as a JIT transformation in Rune
- Target CPU, CUDA, Metal, OpenCL, and HIP
- Kernel fusion and optimization
- Performance within 2x of PyTorch on standard workloads
### V1 — Production-Ready Training & Deployment
- Production training: gradient accumulation, mixed precision, gradient checkpointing, flash attention
- ONNX import for PyTorch model portability
- AOT compilation to standalone binaries (CPU and GPU)
- Inference engine with KV cache, continuous batching, and PagedAttention
- MirageOS unikernel deployment
- Post-training quantization (INT8/INT4)
We're also open to discussing custom sponsorship packages based on your needs.
## Ways to Support
### For Developers
- **Try it out**: Test Raven with your workflows and [report issues](https://github.com/raven-ml/raven/issues)
- **Contribute code**: See our [contributing guide](https://github.com/raven-ml/raven/blob/main/CONTRIBUTING.md) for areas where we need help
- **Share feedback**: What would make you switch from Python? [Tell us](mailto:thibaut.mattio@gmail.com)
- **Spread the word**: Star the repo, share with your team, write about your experience
### For Companies
- **Use Raven**: Reach out if you're interested in using it—we're keen to prioritize development based on real-world needs
- **Sponsor development**: Email [thibaut.mattio@gmail.com](mailto:thibaut.mattio@gmail.com) for sponsorship packages
### For Individuals
- **GitHub Sponsors**: [Support the project with monthly contributions](https://github.com/sponsors/tmattio)
- **One-time donations**: Every contribution helps us reach the next milestone
- **Write tutorials**: Help others learn Raven and grow the community
## Current Sponsors
We're grateful for the support of our sponsors:
### Corporate Sponsors
- [**Ahrefs**](https://ahrefs.com) - Building tools to help you grow your search traffic
- [**Tarides**](https://tarides.com) - Secure-by-design infrastructure and tooling for a better digital world
### Individual Sponsors
Thank you to all our individual sponsors for their support!
## Get in Touch
**For sponsorship inquiries**: [thibaut.mattio@gmail.com](mailto:thibaut.mattio@gmail.com)
**For feature requests or bug reports**: [GitHub Issues](https://github.com/raven-ml/raven/issues)
---
_Raven is built by [Thibaut Mattio](https://github.com/tmattio) and contributors. We believe OCaml deserves a world-class scientific computing ecosystem, and we're committed to building it._
================================================
FILE: dune-project
================================================
(lang dune 3.21)
(name raven)
(source
(github raven-ml/raven))
(authors "Thibaut Mattio ")
(maintainers "Thibaut Mattio ")
(license ISC)
(documentation "https://raven-ml.dev/docs/")
(bug_reports "https://github.com/raven-ml/raven/issues")
(using mdx 0.4)
(using directory-targets 0.1)
(version 1.0.0~alpha3)
(implicit_transitive_deps false)
(generate_opam_files true)
(opam_file_location inside_opam_directory)
(pin
(url "git+https://github.com/invariant-hq/thumper.git")
(package
(name thumper)))
(package
(name nx)
(dir packages/nx)
(synopsis "N-dimensional arrays for OCaml")
(description
"Nx provides n-dimensional arrays with NumPy-like semantics and OCaml's type safety. 19 data types, broadcasting, slicing, linear algebra, FFT, and I/O. The numerical foundation for the Raven ecosystem.")
(depends
(ocaml
(>= 5.2.0))
dune
(dune-configurator :build)
(conf-pkg-config :build)
; camlzip
(conf-zlib :build)
logs
; tests
(windtrap :with-test)
(mdx :with-test)
(thumper :with-test))
(tags
(numerical-computation tensor-library machine-learning)))
(package
(name brot)
(dir packages/brot)
(synopsis "Tokenization for OCaml")
(description
"Fast, HuggingFace-compatible tokenization for language models. BPE, WordPiece, Unigram, word-level, and character-level algorithms with composable pipelines and training from scratch.")
(depends
(ocaml
(>= 5.2.0))
dune
re
jsont
bytesrw
(uunf
(>= 15.1.0))
uucp
(windtrap :with-test)
(mdx :with-test)
(thumper :with-test))
(tags
(tokenization bpe wordpiece subword-tokenization language-models)))
(package
(name talon)
(dir packages/talon)
(synopsis "Dataframes for OCaml")
(description
"Fast and elegant dataframes with type-safe operations. Heterogeneous columns, applicative row operations, vectorized aggregations, and CSV I/O, built on Nx.")
(depends
(ocaml
(>= 5.2.0))
dune
(nx
(= :version))
(windtrap :with-test)
(mdx :with-test)
(thumper :with-test))
(tags
(dataframe data-manipulation data-science tabular-data)))
(package
(name rune)
(dir packages/rune)
(synopsis "Functional transformations for Nx arrays")
(description
"Automatic differentiation and vectorizing maps for Nx tensors. Reverse-mode AD (grad, vjp), forward-mode AD (jvp), vmap, and gradient checking, built on OCaml 5 effect handlers.")
(depends
(ocaml
(>= 5.2.0))
dune
(dune-configurator :build)
(nx
(= :version))
(tolk
(= :version))
(windtrap :with-test)
(mdx :with-test)
(thumper :with-test))
(tags
(automatic-differentiation machine-learning deep-learning optimization)))
(package
(name tolk)
(dir packages/tolk)
(synopsis "A minimal ML compiler for GPU tensor computation")
(description
"Tolk is a minimal, readable ML compiler for GPU tensor computation in the Raven ecosystem.")
(depends
(ocaml
(>= 5.2))
dune
(windtrap :with-test)
(thumper :with-test))
(tags
(compiler gpu tensor-computation)))
(package
(name norn)
(dir packages/norn)
(synopsis "MCMC sampling for OCaml")
(description
"Markov chain Monte Carlo samplers with automatic gradients via Rune. Hamiltonian Monte Carlo with dual-averaging step-size adaptation.")
(depends
(ocaml
(>= 5.2.0))
dune
(nx
(= :version))
(rune
(= :version))
(windtrap :with-test)
(thumper :with-test))
(tags
(mcmc sampling bayesian machine-learning)))
(package
(name vega)
(dir packages/vega)
(synopsis "Per-parameter gradient-based optimizers for OCaml")
(description
"Typed, per-parameter optimizer primitives: Adam, AdamW, SGD, RMSprop, Adagrad, and learning-rate schedules. Built on Nx with no autodiff dependency.")
(depends
(ocaml
(>= 5.2.0))
dune
(nx
(= :version))
(windtrap :with-test)
(thumper :with-test))
(tags
(optimization machine-learning gradient-descent)))
(package
(name kaun)
(dir packages/kaun)
(synopsis "Neural networks for OCaml")
(description
"Composable layers, parameter trees, optimizers, training loops, data pipelines, and HuggingFace Hub integration. Built on Rune.")
(depends
(ocaml
(>= 5.2.0))
dune
(rune
(= :version))
(vega
(= :version))
(nx
(= :version))
jsont
bytesrw
(windtrap :with-test)
(mdx :with-test)
(thumper :with-test))
(tags
(neural-networks machine-learning deep-learning)))
(package
(name munin)
(dir packages/munin)
(synopsis "Local experiment tracking for Raven")
(description
"Local-first experiment tracking with append-only event logs, versioned artifacts, a terminal dashboard, and a CLI. The core library (munin) provides Session, Run, Store, and Artifact modules. The TUI library (munin.tui) provides a Mosaic-based dashboard.")
(depends
(ocaml
(>= 5.2.0))
dune
jsont
bytesrw
sha
cmdliner
mosaic
(dune-configurator :build)
matrix
(windtrap :with-test))
(tags
(experiment-tracking machine-learning monitoring)))
(package
(name sowilo)
(dir packages/sowilo)
(synopsis "Differentiable computer vision for OCaml")
(description
"Image processing operations expressed as Nx tensor computations. Geometric transforms, spatial filters, edge detection, morphological operations, and color space conversions, all compatible with Rune.grad and Rune.vmap.")
(depends
(ocaml
(>= 5.2.0))
dune
(nx
(= :version))
(windtrap :with-test)
(mdx :with-test)
(thumper :with-test))
(tags
(computer-vision image-processing feature-detection machine-learning)))
(package
(name fehu)
(dir packages/fehu)
(synopsis "Reinforcement learning for OCaml")
(description
"Type-safe RL environments, observation/action spaces, vectorized environments, trajectory collection, replay buffers, and generalized advantage estimation. Built on Nx.")
(depends
(ocaml
(>= 5.2.0))
dune
(nx
(= :version))
(windtrap :with-test)
(mdx :with-test)
(thumper :with-test))
(tags
(reinforcement-learning machine-learning environments)))
(package
(name hugin)
(dir packages/hugin)
(synopsis "Declarative plotting and visualization for OCaml")
(description "Composable, beautiful-by-default plotting built on Nx.")
(depends
(ocaml
(>= 5.2.0))
dune
(dune-configurator :build)
(conf-sdl2 :build)
(conf-cairo :build)
(nx
(= :version))
(windtrap :with-test)
(mdx :with-test))
(tags
(visualization plotting charts data-science graphics)))
(package
(name quill)
(dir packages/quill)
(synopsis "Interactive REPL and markdown notebooks")
(description
"Quill is a REPL and notebook environment for OCaml. Interactive toplevel with syntax highlighting, completion, and history. Markdown notebooks with a terminal UI, web frontend, batch evaluation, and watch mode.")
(depends
(ocaml
(>= 5.2.0))
dune
cmarkit
cmdliner
bytesrw
jsont
mosaic
(windtrap :with-test)
(mdx :with-test))
(tags
(repl toplevel notebooks interactive-computing literate-programming)))
(package
(name raven)
(allow_empty)
(dir packages/raven)
(synopsis "Modern scientific computing for OCaml")
(description
"Raven is an ecosystem of composable libraries for numerical computing in OCaml. Tensors, automatic differentiation, neural networks, dataframes, plotting, tokenization, computer vision, reinforcement learning, and interactive notebooks.")
(depends
(nx
(= :version))
(tolk
(= :version))
(brot
(= :version))
(talon
(= :version))
(rune
(= :version))
(vega
(= :version))
(kaun
(= :version))
(munin
(= :version))
(sowilo
(= :version))
(fehu
(= :version))
(hugin
(= :version))
(quill
(= :version)))
(tags
(machine-learning data-science numerical-computation)))
================================================
FILE: dune-workspace.tsan
================================================
(lang dune 3.21)
(lock_dir
(path dune.lock))
; Pin ocaml-variants to the 5.4 branch which includes the
; __tsan_func_exit signature fix (ocaml/ocaml#14082).
; Remove this pin once OCaml 5.4.2 is released.
(pin
(name ocaml-variants)
(url "git+https://github.com/ocaml/ocaml#5.4")
(package
(name ocaml-variants)
(version 5.4.2+trunk)))
(lock_dir
(path dune-tsan.lock)
(pins ocaml-variants)
(depopts ocaml-option-tsan))
(context default)
(context
(default
(name tsan)
(lock_dir dune-tsan.lock)))
================================================
FILE: opam/brot.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "Tokenization for OCaml"
description:
"Fast, HuggingFace-compatible tokenization for language models. BPE, WordPiece, Unigram, word-level, and character-level algorithms with composable pipelines and training from scratch."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: [
"tokenization" "bpe" "wordpiece" "subword-tokenization" "language-models"
]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2.0"}
"dune" {>= "3.21"}
"re"
"jsont"
"bytesrw"
"uunf" {>= "15.1.0"}
"uucp"
"windtrap" {with-test}
"mdx" {with-test}
"thumper" {with-test}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: opam/fehu.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "Reinforcement learning for OCaml"
description:
"Type-safe RL environments, observation/action spaces, vectorized environments, trajectory collection, replay buffers, and generalized advantage estimation. Built on Nx."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: ["reinforcement-learning" "machine-learning" "environments"]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2.0"}
"dune" {>= "3.21"}
"nx" {= version}
"windtrap" {with-test}
"mdx" {with-test}
"thumper" {with-test}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: opam/hugin.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "Declarative plotting and visualization for OCaml"
description: "Composable, beautiful-by-default plotting built on Nx."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: ["visualization" "plotting" "charts" "data-science" "graphics"]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2.0"}
"dune" {>= "3.21"}
"dune-configurator" {build}
"conf-sdl2" {build}
"conf-cairo" {build}
"nx" {= version}
"windtrap" {with-test}
"mdx" {with-test}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: opam/kaun-board.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "Training dashboard and logging for Raven"
description:
"Lightweight training logger and terminal dashboard for monitoring runs. The core library (kaun-board) provides a Log API for writing JSONL events and a reader for consuming them. The TUI library (kaun-board.tui) provides a Mosaic-based dashboard."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: ["training-dashboard" "monitoring" "logging" "machine-learning"]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2.0"}
"dune" {>= "3.21"}
"jsont"
"bytesrw"
"cmdliner"
"mosaic"
"dune-configurator" {build}
"matrix"
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: opam/kaun.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "Neural networks for OCaml"
description:
"Composable layers, parameter trees, optimizers, training loops, data pipelines, and HuggingFace Hub integration. Built on Rune."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: ["neural-networks" "machine-learning" "deep-learning"]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2.0"}
"dune" {>= "3.21"}
"rune" {= version}
"vega" {= version}
"nx" {= version}
"jsont"
"bytesrw"
"windtrap" {with-test}
"mdx" {with-test}
"thumper" {with-test}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: opam/munin.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "Local experiment tracking for Raven"
description:
"Local-first experiment tracking with append-only event logs, versioned artifacts, a terminal dashboard, and a CLI. The core library (munin) provides Session, Run, Store, and Artifact modules. The TUI library (munin.tui) provides a Mosaic-based dashboard."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: ["experiment-tracking" "machine-learning" "monitoring"]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2.0"}
"dune" {>= "3.21"}
"jsont"
"bytesrw"
"sha"
"cmdliner"
"mosaic"
"dune-configurator" {build}
"matrix"
"windtrap" {with-test}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: opam/norn.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "MCMC sampling for OCaml"
description:
"Markov chain Monte Carlo samplers with automatic gradients via Rune. Hamiltonian Monte Carlo with dual-averaging step-size adaptation."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: ["mcmc" "sampling" "bayesian" "machine-learning"]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2.0"}
"dune" {>= "3.21"}
"nx" {= version}
"rune" {= version}
"windtrap" {with-test}
"thumper" {with-test}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: opam/nx.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "N-dimensional arrays for OCaml"
description:
"Nx provides n-dimensional arrays with NumPy-like semantics and OCaml's type safety. 19 data types, broadcasting, slicing, linear algebra, FFT, and I/O. The numerical foundation for the Raven ecosystem."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: ["numerical-computation" "tensor-library" "machine-learning"]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2.0"}
"dune" {>= "3.21"}
"dune-configurator" {build}
"conf-pkg-config" {build}
"conf-zlib" {build}
"logs"
"windtrap" {with-test}
"mdx" {with-test}
"thumper" {with-test}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
depexts: [
["libc-dev" "openblas-dev" "lapack-dev"] {os-distribution = "alpine"}
["epel-release" "openblas-devel"] {os-distribution = "centos"}
["libopenblas-dev" "liblapacke-dev"] {os-family = "debian"}
["libopenblas-dev" "liblapacke-dev"] {os-family = "ubuntu"}
["openblas-devel"] {os-family = "fedora"}
["libopenblas_openmp-devel"] {os-family = "suse" | os-family = "opensuse"}
["openblas" "lapacke" "cblas"] {os-distribution = "arch"}
["openblas"] {os = "macos" & os-distribution = "homebrew"}
["openblas" "lapacke"] {os = "freebsd"}
["mingw64-x86_64-cblas" "mingw64-x86_64-lapack"] {os = "cygwin"}
]
x-ci-accept-failures: [
"oraclelinux-7"
"oraclelinux-8"
"oraclelinux-9"
]
================================================
FILE: opam/nx.opam.template
================================================
depexts: [
["libc-dev" "openblas-dev" "lapack-dev"] {os-distribution = "alpine"}
["epel-release" "openblas-devel"] {os-distribution = "centos"}
["libopenblas-dev" "liblapacke-dev"] {os-family = "debian"}
["libopenblas-dev" "liblapacke-dev"] {os-family = "ubuntu"}
["openblas-devel"] {os-family = "fedora"}
["libopenblas_openmp-devel"] {os-family = "suse" | os-family = "opensuse"}
["openblas" "lapacke" "cblas"] {os-distribution = "arch"}
["openblas"] {os = "macos" & os-distribution = "homebrew"}
["openblas" "lapacke"] {os = "freebsd"}
["mingw64-x86_64-cblas" "mingw64-x86_64-lapack"] {os = "cygwin"}
]
x-ci-accept-failures: [
"oraclelinux-7"
"oraclelinux-8"
"oraclelinux-9"
]
================================================
FILE: opam/quill.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "Interactive REPL and markdown notebooks"
description:
"Quill is a REPL and notebook environment for OCaml. Interactive toplevel with syntax highlighting, completion, and history. Markdown notebooks with a terminal UI, web frontend, batch evaluation, and watch mode."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: [
"repl"
"toplevel"
"notebooks"
"interactive-computing"
"literate-programming"
]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2.0"}
"dune" {>= "3.21"}
"cmarkit"
"cmdliner"
"bytesrw"
"jsont"
"mosaic"
"windtrap" {with-test}
"mdx" {with-test}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: opam/raven.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "Modern scientific computing for OCaml"
description:
"Raven is an ecosystem of composable libraries for numerical computing in OCaml. Tensors, automatic differentiation, neural networks, dataframes, plotting, tokenization, computer vision, reinforcement learning, and interactive notebooks."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: ["machine-learning" "data-science" "numerical-computation"]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"dune" {>= "3.21"}
"nx" {= version}
"tolk" {= version}
"brot" {= version}
"talon" {= version}
"rune" {= version}
"vega" {= version}
"kaun" {= version}
"munin" {= version}
"sowilo" {= version}
"fehu" {= version}
"hugin" {= version}
"quill" {= version}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: opam/rune.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "Functional transformations for Nx arrays"
description:
"Automatic differentiation and vectorizing maps for Nx tensors. Reverse-mode AD (grad, vjp), forward-mode AD (jvp), vmap, and gradient checking, built on OCaml 5 effect handlers."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: [
"automatic-differentiation"
"machine-learning"
"deep-learning"
"optimization"
]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2.0"}
"dune" {>= "3.21"}
"dune-configurator" {build}
"nx" {= version}
"tolk" {= version}
"windtrap" {with-test}
"mdx" {with-test}
"thumper" {with-test}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: opam/sowilo.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "Differentiable computer vision for OCaml"
description:
"Image processing operations expressed as Nx tensor computations. Geometric transforms, spatial filters, edge detection, morphological operations, and color space conversions, all compatible with Rune.grad and Rune.vmap."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: [
"computer-vision" "image-processing" "feature-detection" "machine-learning"
]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2.0"}
"dune" {>= "3.21"}
"nx" {= version}
"windtrap" {with-test}
"mdx" {with-test}
"thumper" {with-test}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: opam/talon.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "Dataframes for OCaml"
description:
"Fast and elegant dataframes with type-safe operations. Heterogeneous columns, applicative row operations, vectorized aggregations, and CSV I/O, built on Nx."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: ["dataframe" "data-manipulation" "data-science" "tabular-data"]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2.0"}
"dune" {>= "3.21"}
"nx" {= version}
"windtrap" {with-test}
"mdx" {with-test}
"thumper" {with-test}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: opam/tolk.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "A minimal ML compiler for GPU tensor computation"
description:
"Tolk is a minimal, readable ML compiler for GPU tensor computation in the Raven ecosystem."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: ["compiler" "gpu" "tensor-computation"]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2"}
"dune" {>= "3.21"}
"windtrap" {with-test}
"thumper" {with-test}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: opam/vega.opam
================================================
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "1.0.0~alpha3"
synopsis: "Per-parameter gradient-based optimizers for OCaml"
description:
"Typed, per-parameter optimizer primitives: Adam, AdamW, SGD, RMSprop, Adagrad, and learning-rate schedules. Built on Nx with no autodiff dependency."
maintainer: ["Thibaut Mattio "]
authors: ["Thibaut Mattio "]
license: "ISC"
tags: ["optimization" "machine-learning" "gradient-descent"]
homepage: "https://github.com/raven-ml/raven"
doc: "https://raven-ml.dev/docs/"
bug-reports: "https://github.com/raven-ml/raven/issues"
depends: [
"ocaml" {>= "5.2.0"}
"dune" {>= "3.21"}
"nx" {= version}
"windtrap" {with-test}
"thumper" {with-test}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
]
dev-repo: "git+https://github.com/raven-ml/raven.git"
x-maintenance-intent: ["(latest)"]
================================================
FILE: packages/brot/README.md
================================================
# Brot
Fast tokenization library for OCaml.
Brot tokenizes text into token IDs for language models and reverses the
process. It is part of the Raven ecosystem. It loads and saves HuggingFace
`tokenizer.json` files, supports BPE, WordPiece, Unigram, word-level, and
character-level algorithms, and is 1.3-6x faster than HuggingFace
tokenizers on most benchmarks.
## Features
- Tokenization algorithms: BPE, WordPiece, Unigram, word-level, character-level
- HuggingFace compatible: load and save `tokenizer.json` files, load
vocab/merges model files
- Composable pipeline: normalizer, pre-tokenizer, post-processor, decoder
— each stage independently configurable
- Rich encoding output: token IDs, string tokens, byte offsets, attention
masks, type IDs, word IDs, special token masks
- Training: train BPE, WordPiece, Unigram, and word-level tokenizers from
scratch
- Performance: 1.3-6x faster than HuggingFace tokenizers (Rust native) on
most benchmarks — see [bench/](bench/) for details
## Quick Start
```ocaml
open Brot
let () =
(* Load a pretrained HuggingFace tokenizer *)
let tokenizer = from_file "tokenizer.json" |> Result.get_ok in
(* Encode text to token IDs *)
let encoding = encode tokenizer "Hello world!" in
let ids = Encoding.ids encoding in
Printf.printf "Token IDs: ";
Array.iter (fun id -> Printf.printf "%d " id) ids;
print_newline ();
(* Decode back to text *)
let text = decode tokenizer ids in
Printf.printf "Decoded: %s\n" text
```
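Batch encoding works the same way; `encode_batch` (also exercised by the benchmark suite under [bench/](bench/)) processes several inputs in one call. A minimal sketch, assuming `encode_batch` takes a list of strings and returns one encoding per input:

```ocaml
open Brot

let () =
  (* Load a pretrained tokenizer, as in the Quick Start above. *)
  let tokenizer = from_file "tokenizer.json" |> Result.get_ok in
  (* Encode several inputs at once. *)
  let encodings =
    encode_batch tokenizer [ "Hello world!"; "Goodbye world!" ]
  in
  (* Report the token count for each input. *)
  List.iter
    (fun enc -> Printf.printf "%d tokens\n" (Array.length (Encoding.ids enc)))
    encodings
```

Batch encoding is also where Brot parallelizes across inputs, which is reflected in the `Encode/batch_32` benchmark results.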
## Contributing
See the [Raven monorepo README](../README.md) for contribution guidelines.
## License
ISC License. See [LICENSE](../LICENSE) for details.
================================================
FILE: packages/brot/bench/README.md
================================================
# Brot Benchmarks
This directory contains micro-benchmarks for the `brot` library.
The suite mirrors HuggingFace's `tokenizers` so we can compare wall-clock
throughput for realistic workloads and catch regressions.
## Fixtures
Benchmark inputs live in `./data/`:
- `news_1k.txt`, `wiki_64k.txt`, `code_excerpt.txt` — sample corpora used for
encoding workloads.
- `gpt2.json` — OpenAI GPT-2 (BPE, 50K vocab, 50K merges)
- `bert_base.json` — Google BERT-base-uncased (WordPiece, 30K vocab)
- `llama.json` — Meta LLaMA (BPE, 32K vocab, 61K merges, no pre-tokenizer)
Download the tokenizer model files:
```bash
brot/bench/download_data.sh
```
## Running the Benchmarks
### Brot (OCaml)
```bash
dune exec brot/bench/bench_brot.exe -- --gc
```
### tokenizers — Rust native
```bash
cd brot/bench/bench_rust && cargo run --release
```
### tokenizers — Python (Rust FFI)
```bash
uv run --with tokenizers brot/bench/bench_tokenizers.py
```
## Comparison
Wall-clock time per run. Lower is better. Apple M3 Pro, macOS.
| Benchmark | Brot (OCaml) | Rust native | Python (Rust FFI) | Brot vs Rust |
| ------------------------------------ | ------------ | ----------- | ----------------- | ------------ |
| **GPT-2** (BPE, 50K vocab) | | | | |
| Encode/short (1KB) | 46μs | 209μs | 250μs | **4.5x** |
| Encode/long (64KB) | 5.26ms | 10.25ms | 13.27ms | **1.9x** |
| Encode/batch_32 | 1.38ms | 3.05ms | 3.91ms | **2.2x** |
| Decode/long | 1.19ms | 1.50ms | 1.58ms | **1.3x** |
| **BERT-base** (WordPiece, 30K vocab) | | | | |
| Encode/short (1KB) | 137μs | 278μs | 325μs | **2.0x** |
| Encode/long (64KB) | 10.87ms | 13.95ms | 16.64ms | **1.3x** |
| Encode/batch_32 | 2.06ms | 2.31ms | 2.66ms | **1.1x** |
| Decode/long | 1.25ms | 7.63ms | 7.76ms | **6.1x** |
| **LLaMA** (BPE, 32K vocab) | | | | |
| Encode/short (1KB) | 51μs | 207μs | 247μs | **4.1x** |
| Encode/long (64KB) | 20.15ms | 13.41ms | 16.23ms | 1.5x slower |
| Encode/batch_32 | 1.43ms | 1.56ms | 1.51ms | ~on par |
| Decode/long | 1.12ms | 5.02ms | 5.03ms | **4.5x** |
Notes:
- The "Rust native" column calls the `tokenizers` crate directly, no Python FFI.
Source: `bench_rust/main.rs`.
- Both brot and HF tokenizers use multi-threading for batch encoding (wall < CPU).
- LLaMA has no pre-tokenizer, so the entire text goes through BPE as a single
sequence — this is where brot's BPE is slower on long inputs.
================================================
FILE: packages/brot/bench/bench_brot.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Benchmark suite for Brot tokenizers using realistic fixtures. *)
open Brot
module Fixtures = struct
let data_dir = Filename.concat (Sys.getcwd ()) "packages/brot/bench/data"
let read_file name =
let path = Filename.concat data_dir name in
let ic = open_in_bin path in
let len = in_channel_length ic in
let content = really_input_string ic len in
close_in ic;
content
let load_tokenizer name =
let path = Filename.concat data_dir name in
match from_file path with
| Ok tok -> tok
| Error msg ->
failwith (Printf.sprintf "Failed to load tokenizer %s: %s" path msg)
let short_text = read_file "news_1k.txt"
let long_text = read_file "wiki_64k.txt"
let batch_32 =
let rec loop acc remaining =
if remaining = 0 then List.rev acc
else loop (short_text :: acc) (remaining - 1)
in
loop [] 32
end
let encode_single tok text = encode tok text
let encode_batch tok texts = encode_batch tok texts
let decode_ids tok ids = decode tok ids
let make_suite ~label ~tokenizer =
let open Fixtures in
let decode_input =
let encoding = encode tokenizer long_text in
Array.copy (Encoding.ids encoding)
in
let benches =
[
Thumper.bench "Encode/single_short" (fun () ->
encode_single tokenizer short_text);
Thumper.bench "Encode/single_long" (fun () ->
encode_single tokenizer long_text);
Thumper.bench "Encode/batch_32" (fun () ->
encode_batch tokenizer batch_32);
Thumper.bench "Decode/long" (fun () -> decode_ids tokenizer decode_input);
]
in
Thumper.group label benches
let all_benchmarks =
let open Fixtures in
let gpt2 =
make_suite ~label:"GPT-2" ~tokenizer:(load_tokenizer "gpt2.json")
in
let bert =
make_suite ~label:"BERT-base" ~tokenizer:(load_tokenizer "bert_base.json")
in
let llama =
make_suite ~label:"LLaMA" ~tokenizer:(load_tokenizer "llama.json")
in
[ gpt2; bert; llama ]
let () = Thumper.run "brot" all_benchmarks
================================================
FILE: packages/brot/bench/bench_rust/.gitignore
================================================
/target
================================================
FILE: packages/brot/bench/bench_rust/Cargo.toml
================================================
[package]
name = "bench_tokenizers_rust"
edition = "2021"
[[bin]]
name = "bench_tokenizers_rust"
path = "main.rs"
[dependencies]
tokenizers = "0.22"
================================================
FILE: packages/brot/bench/bench_rust/main.rs
================================================
use std::fs;
use std::path::Path;
use std::time::{Duration, Instant};
use tokenizers::Tokenizer;
const WARMUP: usize = 4;
const TIME_QUOTA: Duration = Duration::from_millis(300);
const MIN_MEASUREMENTS: usize = 3;
struct BenchResult {
name: String,
wall_per_run: Duration,
runs: usize,
}
fn bench<F: FnMut()>(name: &str, mut f: F) -> BenchResult {
// Warmup
for _ in 0..WARMUP {
f();
}
// Adaptive batching: start with batch_size=1, scale up until each batch
// takes at least 2ms of wall time, then collect measurements for ~0.3s.
let mut batch_size: usize = 1;
let mut measurements: Vec<Duration> = Vec::new();
let bench_start = Instant::now();
loop {
let start = Instant::now();
for _ in 0..batch_size {
f();
}
let elapsed = start.elapsed();
if elapsed.as_secs_f64() < 0.002 {
// Batch too fast, scale up
batch_size = (batch_size as f64 * 1.3).ceil().max((batch_size + 1) as f64) as usize;
continue;
}
let per_run = elapsed / batch_size as u32;
measurements.push(per_run);
let total_elapsed = bench_start.elapsed();
if measurements.len() >= MIN_MEASUREMENTS && total_elapsed >= TIME_QUOTA {
break;
}
batch_size = (batch_size as f64 * 1.3).ceil().max((batch_size + 1) as f64) as usize;
}
// Compute average
let total: Duration = measurements.iter().sum();
let avg = total / measurements.len() as u32;
BenchResult {
name: name.to_string(),
wall_per_run: avg,
runs: measurements.len(),
}
}
fn format_duration(d: Duration) -> String {
let nanos = d.as_nanos() as f64;
if nanos < 1_000.0 {
format!("{:.2}ns", nanos)
} else if nanos < 1_000_000.0 {
format!("{:.2}μs", nanos / 1_000.0)
} else if nanos < 1_000_000_000.0 {
format!("{:.2}ms", nanos / 1_000_000.0)
} else {
format!("{:.2}s", nanos / 1_000_000_000.0)
}
}
fn run_suite(label: &str, tokenizer: &Tokenizer, short_text: &str, long_text: &str) {
let batch_32: Vec<&str> = vec![short_text; 32];
// Pre-compute decode input
let encoding = tokenizer
.encode(long_text, false)
.expect("encode for decode input");
let decode_ids: Vec<u32> = encoding.get_ids().to_vec();
let results = vec![
bench(&format!("{}/Encode/single_short", label), || {
tokenizer.encode(short_text, false).unwrap();
}),
bench(&format!("{}/Encode/single_long", label), || {
tokenizer.encode(long_text, false).unwrap();
}),
bench(&format!("{}/Encode/batch_32", label), || {
tokenizer
.encode_batch(batch_32.clone(), false)
.unwrap();
}),
bench(&format!("{}/Decode/long", label), || {
tokenizer.decode(decode_ids.as_slice(), true).unwrap();
}),
];
for r in &results {
println!(
" {:<35} {:>10} ({} samples)",
r.name,
format_duration(r.wall_per_run),
r.runs
);
}
}
fn main() {
let data_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("../data");
let short_text =
fs::read_to_string(data_dir.join("news_1k.txt")).expect("read news_1k.txt");
let long_text =
fs::read_to_string(data_dir.join("wiki_64k.txt")).expect("read wiki_64k.txt");
println!("Rust-native HuggingFace tokenizers benchmark");
println!("=============================================\n");
let tokenizers = [
("GPT-2", "gpt2.json"),
("BERT-base", "bert_base.json"),
("LLaMA", "llama.json"),
];
for (label, filename) in &tokenizers {
let path = data_dir.join(filename);
let tokenizer =
Tokenizer::from_file(&path).unwrap_or_else(|e| {
panic!("Failed to load {}: {}", path.display(), e)
});
println!("{}:", label);
run_suite(label, &tokenizer, &short_text, &long_text);
println!();
}
}
================================================
FILE: packages/brot/bench/bench_tokenizers.py
================================================
from __future__ import annotations
from pathlib import Path
from typing import Any, Callable, List
from tokenizers import Tokenizer
_ROOT = Path(__file__).resolve().parent
_DATA_DIR = _ROOT / "data"
import sys
_SCRIPTS_DIR = _ROOT
while not (_SCRIPTS_DIR / "dune-project").exists():
_SCRIPTS_DIR = _SCRIPTS_DIR.parent
_SCRIPTS_DIR = _SCRIPTS_DIR / "scripts"
if str(_SCRIPTS_DIR) not in sys.path:
sys.path.insert(0, str(_SCRIPTS_DIR))
import ubench # type: ignore
SHORT_TEXT = (_DATA_DIR / "news_1k.txt").read_text(encoding="utf-8")
LONG_TEXT = (_DATA_DIR / "wiki_64k.txt").read_text(encoding="utf-8")
BATCH_32 = [SHORT_TEXT] * 32
def load_tokenizer(filename: str) -> Tokenizer:
path = _DATA_DIR / filename
return Tokenizer.from_file(str(path))
def make_suite(label: str, tokenizer: Tokenizer) -> Any:
decode_ids = tokenizer.encode(LONG_TEXT).ids
benches: List[Any] = [
ubench.bench("Encode/single_short", lambda: tokenizer.encode(SHORT_TEXT)),
ubench.bench("Encode/single_long", lambda: tokenizer.encode(LONG_TEXT)),
ubench.bench("Encode/batch_32", lambda: tokenizer.encode_batch(BATCH_32)),
ubench.bench("Decode/long", lambda: tokenizer.decode(decode_ids)),
]
return ubench.group(label, benches)
def build_benchmarks() -> List[Any]:
return [
make_suite("GPT-2", load_tokenizer("gpt2.json")),
make_suite("BERT-base", load_tokenizer("bert_base.json")),
make_suite("LLaMA", load_tokenizer("llama.json")),
]
def default_config() -> ubench.Config:
return ubench.Config.default().build()
def main() -> None:
benchmarks = build_benchmarks()
config = default_config()
ubench.run(benchmarks, config=config, output_format="pretty", verbose=False)
if __name__ == "__main__":
main()
================================================
FILE: packages/brot/bench/brot.thumper
================================================
# thumper baseline
# version: 1
# suite_name: brot
# host: 1480401c3b76ed18
# cpu: Apple M1 Max
# ocaml: 5.4.1
# git: 31747323
# dirty: true
# command: /Users/tmattio/Workspace/raven/_build/default/packages/brot/bench/bench_brot.exe --bless --quick
bert-base/decode_long alloc_words 4.445400e+05 4.445400e+05 4.445400e+05 0.000000e+00 9 0
bert-base/decode_long cpu_time 1.424685e-03 1.370913e-03 1.470915e-03 3.509645e-02 9 1
bert-base/decode_long wall_time 1.425388e-03 1.371378e-03 1.475911e-03 3.666840e-02 9 1
bert-base/encode_batch_32 alloc_words 1.089250e+05 1.089250e+05 1.089250e+05 0.000000e+00 31 1
bert-base/encode_batch_32 cpu_time 9.091263e-03 8.686646e-03 9.591658e-03 4.977371e-02 31 2
bert-base/encode_batch_32 wall_time 2.121498e-03 1.993592e-03 2.253395e-03 6.123112e-02 31 1
bert-base/encode_single_long alloc_words 1.350547e+06 1.350547e+06 1.350547e+06 0.000000e+00 9 1
bert-base/encode_single_long cpu_time 9.506189e-03 9.354674e-03 9.692541e-03 1.777085e-02 9 1
bert-base/encode_single_long wall_time 9.509793e-03 9.372449e-03 9.680291e-03 1.618554e-02 9 1
bert-base/encode_single_short alloc_words 2.699600e+04 2.699600e+04 2.699600e+04 0.000000e+00 9 0
bert-base/encode_single_short cpu_time 1.392726e-04 1.345864e-04 1.448241e-04 3.675399e-02 9 0
bert-base/encode_single_short wall_time 1.393180e-04 1.345034e-04 1.440607e-04 3.430033e-02 9 0
gpt-2/decode_long alloc_words 3.417770e+05 3.417770e+05 3.417770e+05 0.000000e+00 7 0
gpt-2/decode_long cpu_time 1.305595e-03 1.261472e-03 1.338931e-03 2.966443e-02 7 0
gpt-2/decode_long wall_time 1.305703e-03 1.262279e-03 1.346462e-03 3.223653e-02 7 0
gpt-2/encode_batch_32 alloc_words 5.518900e+04 5.518900e+04 5.518900e+04 0.000000e+00 6 0
gpt-2/encode_batch_32 cpu_time 3.952923e-03 3.848310e-03 4.113309e-03 3.351934e-02 6 1
gpt-2/encode_batch_32 wall_time 1.386324e-03 1.328412e-03 1.431061e-03 3.702195e-02 6 0
gpt-2/encode_single_long alloc_words 6.731690e+05 6.731690e+05 6.731690e+05 0.000000e+00 19 2
gpt-2/encode_single_long cpu_time 3.852835e-03 3.758677e-03 3.927922e-03 2.196376e-02 19 1
gpt-2/encode_single_long wall_time 3.856248e-03 3.756035e-03 3.930090e-03 2.256793e-02 19 1
gpt-2/encode_single_short alloc_words 1.356200e+04 1.356200e+04 1.356200e+04 0.000000e+00 13 0
gpt-2/encode_single_short cpu_time 5.279107e-05 5.090596e-05 5.553182e-05 4.381283e-02 13 0
gpt-2/encode_single_short wall_time 5.282309e-05 5.106692e-05 5.499492e-05 3.718073e-02 13 0
llama/decode_long alloc_words 6.844460e+05 6.844460e+05 6.844460e+05 0.000000e+00 12 0
llama/decode_long cpu_time 1.149214e-03 1.094901e-03 1.194180e-03 4.319437e-02 12 0
llama/decode_long wall_time 1.149682e-03 1.095042e-03 1.189211e-03 4.095449e-02 12 0
llama/encode_batch_32 alloc_words 9.471700e+04 9.471700e+04 9.471700e+04 0.000000e+00 5 0
llama/encode_batch_32 cpu_time 4.421498e-03 4.320283e-03 4.534447e-03 2.421853e-02 5 0
llama/encode_batch_32 wall_time 1.467702e-03 1.366193e-03 1.593832e-03 7.754944e-02 5 2
llama/encode_single_long alloc_words 1.261210e+06 1.261210e+06 1.261210e+06 0.000000e+00 9 1
llama/encode_single_long cpu_time 1.817278e-02 1.788729e-02 1.860684e-02 1.979736e-02 9 2
llama/encode_single_long wall_time 1.819150e-02 1.794107e-02 1.863329e-02 1.902594e-02 9 2
llama/encode_single_short alloc_words 2.344400e+04 2.344400e+04 2.344400e+04 0.000000e+00 42 0
llama/encode_single_short cpu_time 6.126139e-05 6.069695e-05 6.174501e-05 8.553960e-03 42 6
llama/encode_single_short wall_time 6.130598e-05 6.079485e-05 6.183071e-05 8.448240e-03 42 6
================================================
FILE: packages/brot/bench/data/.gitignore
================================================
gpt2.json
bert_base.json
llama.json
================================================
FILE: packages/brot/bench/data/news_1k.txt
================================================
City officials confirmed on Tuesday that the riverside park will reopen this summer after a two-year renovation.
Crews installed 175 energy-efficient lights, replanted native wildflowers, and added a playground designed by local artists.
The project ran $1.8 million under budget, according to Deputy Mayor Alicia Gómez — a welcome surprise for residents concerned about rising taxes.
"It's not just a facelift; it's a commitment to public space," said Gómez. Cyclists tested the new bike lanes, while children chased bubbles during the ribbon-cutting ceremony.
The park will host weekly night markets featuring Afghan bolani, Jamaican patties, and vegan empanadas, with vendors selected through a community ballot.
Public transit advocates noted that the expanded bus schedule, combined with real-time arrival boards, should alleviate weekend congestion.
Sustainability officers also unveiled a solar-powered irrigation system and a pollinator habitat that includes milkweed, lavender, and rare prairie clover.
Early visitor surveys show 92% satisfaction, with many praising the accessible design, tactile maps, and multilingual audio tours available in English, Spanish, Mandarin, and American Sign Language.
The city plans to share open-source blueprints and a detailed maintenance playbook with other municipalities considering similar upgrades.
================================================
FILE: packages/brot/bench/data/wiki_64k.txt
================================================
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
== Early History ==
The settlement traces its roots to a trading village documented in the 12th-century annals of the Seljuk chronicler al-Biruni. Archaeological digs in 1989 uncovered kiln-fired ceramics, copper ingots, and terraced irrigation canals that reshaped historians' understanding of Central Asian trade routes.
== Linguistics ==
Modern dialect surveys reveal a blend of Chuvash, Khazar, and Oghur loanwords; linguists have mapped palatalized consonants appearing near river valleys, likely a relic of seasonal migration.
== Technological Renaissance ==
By 1893 the town hosted one of the earliest wireless telegraph stations in the region. Engineer Lidiya Petrovna retrofitted surplus naval equipment to send meteorological data to Moscow every sunset. Her notebooks — digitized in 2017 — contain meticulous diagrams of spark-gap transmitters, annotations in French, and the occasional doodle of a cat wearing goggles.
== Cultural Revival ==
Annual festivals now feature Tuvan throat singing workshops, VR reconstructions of vanished monasteries, and fermentation labs explaining the chemistry behind kumis. UNESCO added the town's accordion workshops to its intangible heritage list, citing their adaptive use of recycled polymers for reeds.
== Contemporary Research ==
In 2022 a consortium of botanists, data journalists, and Indigenous seed keepers launched the Steppe Observatory, using open satellite data, LoRaWAN sensors, and community weather diaries to forecast dust storms.
== Notable Figures ==
Historian Salome Okafor popularized the settlement after translating 400 folktales into Yoruba, English, and Esperanto, each annotated with QR codes linking to oral history recordings.
== Gastronomy ==
Local chefs pair fermented camel-milk cheese with candied sea buckthorn, while food trucks experiment with kelp-laden naan tacos, reflecting the town's fishing diaspora.
== Climate Adaptation ==
Flood mitigation now involves mycelium-reinforced levees, willow microforests, and AI-optimized sluice gates governed by a civic algorithm crafted in nightly town halls.
== Digital Archives ==
Volunteer coders maintain a mirrored archive stored on solar-powered Raspberry Pi clusters. The archive syncs monthly via a community-owned satellite uplink leased during lunar downtimes.
== Everyday Life ==
Schoolchildren log phenology observations, while retired tram conductors teach visitor orientation classes in a repurposed depot, complete with time-travel escape room puzzles chronicling the town's evolution.
================================================
FILE: packages/brot/bench/download_data.sh
================================================
#!/usr/bin/env bash
set -euo pipefail
DATA_DIR="$(dirname "$0")/data"
mkdir -p "$DATA_DIR"
DATA_DIR="$(cd "$DATA_DIR" && pwd)"
echo "Downloading real-world tokenizer models to $DATA_DIR..."
curl -sL -o "$DATA_DIR/gpt2.json" \
"https://huggingface.co/openai-community/gpt2/resolve/main/tokenizer.json"
echo " GPT-2 (BPE, 50K vocab)"
curl -sL -o "$DATA_DIR/bert_base.json" \
"https://huggingface.co/google-bert/bert-base-uncased/resolve/main/tokenizer.json"
echo " BERT-base (WordPiece, 30K vocab)"
curl -sL -o "$DATA_DIR/llama.json" \
"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.json"
echo " LLaMA (BPE, 32K vocab)"
echo "Done."
================================================
FILE: packages/brot/bench/dune
================================================
(data_only_dirs data)
(executable
(name bench_brot)
(libraries brot thumper unix))
(rule
(alias runtest)
(action
(progn
(run %{exe:bench_brot.exe} -q)
(diff? brot.thumper brot.thumper.corrected))))
================================================
FILE: packages/brot/doc/01-getting-started.md
================================================
# Getting Started
This guide covers the basics: encoding text to token IDs, decoding back
to text, configuring the pipeline, and training tokenizers from scratch.
## Installation
```bash
opam install brot
```
Or build from source:
```bash
git clone https://github.com/raven-ml/raven
cd raven && dune build brot
```
## Encoding and Decoding
A tokenizer converts text to token IDs and back. Build one from a
vocabulary and merge rules, then encode and decode:
```ocaml
open Brot
let tokenizer =
bpe
~vocab:
[ ("h", 0); ("e", 1); ("l", 2); ("o", 3); (" ", 4); ("w", 5);
("r", 6); ("d", 7); ("he", 8); ("ll", 9); ("llo", 10);
("hello", 11); ("wo", 12); ("rl", 13); ("rld", 14); ("world", 15) ]
~merges:
[ ("h", "e"); ("l", "l"); ("ll", "o"); ("he", "llo");
("w", "o"); ("r", "l"); ("rl", "d"); ("wo", "rld") ]
()
(* Encode text to an Encoding *)
let encoding = encode tokenizer "hello world"
let ids = Encoding.ids encoding (* [| 11; 4; 15 |] *)
let tokens = Encoding.tokens encoding (* [| "hello"; " "; "world" |] *)
(* Decode back to text *)
let text = decode tokenizer ids (* "hello world" *)
```
`encode` returns an `Encoding.t`. For just the IDs, use `encode_ids`:
```ocaml
open Brot
let tokenizer =
bpe
~vocab:
[ ("h", 0); ("e", 1); ("l", 2); ("o", 3); (" ", 4); ("w", 5);
("r", 6); ("d", 7); ("he", 8); ("ll", 9); ("llo", 10);
("hello", 11); ("wo", 12); ("rl", 13); ("rld", 14); ("world", 15) ]
~merges:
[ ("h", "e"); ("l", "l"); ("ll", "o"); ("he", "llo");
("w", "o"); ("r", "l"); ("rl", "d"); ("wo", "rld") ]
()
let ids = encode_ids tokenizer "hello world" (* [| 11; 4; 15 |] *)
```
## Encoding Output
An `Encoding.t` carries more than just token IDs. Every field is a
parallel array of the same length:
- `ids` — integer token IDs for model input
- `tokens` — string representation of each token
- `offsets` — `(start, end)` byte positions in the original text
- `type_ids` — segment IDs (0 for first sentence, 1 for second in pair tasks)
- `attention_mask` — 1 for real tokens, 0 for padding
- `special_tokens_mask` — 1 for special tokens (`[CLS]`, `[SEP]`, padding), 0 for content
- `word_ids` — maps each token to its source word index, or `None` for special tokens
```ocaml
open Brot
let tokenizer =
wordpiece
~vocab:
[ ("[UNK]", 0); ("[CLS]", 1); ("[SEP]", 2);
("the", 3); ("cat", 4); ("play", 5); ("##ing", 6) ]
~specials:(List.map special [ "[UNK]"; "[CLS]"; "[SEP]" ])
~post:(Post_processor.bert ~cls:("[CLS]", 1) ~sep:("[SEP]", 2) ())
~decoder:(Decoder.wordpiece ())
~pre:(Pre_tokenizer.whitespace ())
~unk_token:"[UNK]" ()
let enc = encode tokenizer "the cat playing"
(* tokens: [| "[CLS]"; "the"; "cat"; "play"; "##ing"; "[SEP]" |] *)
let ids = Encoding.ids enc
let type_ids = Encoding.type_ids enc
let attention_mask = Encoding.attention_mask enc
let special_tokens_mask = Encoding.special_tokens_mask enc
let offsets = Encoding.offsets enc
let word_ids = Encoding.word_ids enc
```
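Since `offsets` are byte positions into the original text, any token can be mapped back to the exact source bytes it came from. A minimal sketch of that lookup in plain OCaml (not a Brot API):

```ocaml
(* Slice a token's source text out of the original string using its
   (start, end) byte offsets, as described in the field list above. *)
let slice text (start_, end_) = String.sub text start_ (end_ - start_)

let () =
  let text = "the cat playing" in
  assert (slice text (0, 3) = "the");
  assert (slice text (8, 15) = "playing")
```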
See [Batch Processing](04-batch-processing/) for a deeper look at encoding
metadata, sentence pairs, padding, and truncation.
## The Pipeline
Tokenization proceeds through up to 5 configurable stages:
1. **Normalizer** — text cleanup (lowercase, accent removal, Unicode normalization)
2. **Pre-tokenizer** — split text into pieces with byte offsets
3. **Algorithm** — apply vocabulary-based encoding (BPE, WordPiece, Unigram, etc.)
4. **Post-processor** — add special tokens and set type IDs
5. **Decoder** — reverse the encoding back to text
Each stage is optional. Here is a complete BERT-style pipeline:
```ocaml
open Brot
let tokenizer =
wordpiece
~normalizer:(Normalizer.bert ~lowercase:true ())
~pre:(Pre_tokenizer.bert ())
~post:(Post_processor.bert ~cls:("[CLS]", 1) ~sep:("[SEP]", 2) ())
~decoder:(Decoder.wordpiece ())
~vocab:
[ ("[UNK]", 0); ("[CLS]", 1); ("[SEP]", 2); ("[PAD]", 3);
("the", 4); ("cat", 5); ("sat", 6); ("on", 7);
("play", 8); ("##ing", 9); ("##ed", 10) ]
~specials:(List.map special [ "[UNK]"; "[CLS]"; "[SEP]"; "[PAD]" ])
~unk_token:"[UNK]" ~pad_token:"[PAD]" ()
(* The normalizer lowercases "The Cat" before tokenization *)
let enc = encode tokenizer "The Cat Sat"
let tokens = Encoding.tokens enc
(* [| "[CLS]"; "the"; "cat"; "sat"; "[SEP]" |] *)
(* Decode, skipping special tokens *)
let text = decode tokenizer ~skip_special_tokens:true (Encoding.ids enc)
(* "the cat sat" *)
```
See [The Tokenization Pipeline](02-pipeline/) for a detailed guide to each
stage.
## Training
Train a tokenizer from a text corpus. Brot supports training BPE,
WordPiece, Unigram, and word-level tokenizers:
```ocaml
open Brot
let tokenizer =
train_bpe ~vocab_size:80 ~show_progress:false
(`Seq (List.to_seq
[ "The quick brown fox jumps over the lazy dog";
"The dog barked loudly at the brown fox";
"Quick brown foxes are jumping over lazy dogs";
"The lazy dog slept while the fox jumped" ]))
let size = vocab_size tokenizer
let enc = encode tokenizer "The quick fox"
```
See [Choosing an Algorithm](05-algorithms/) for guidance on which algorithm
to use and how to configure training.
## Loading Pretrained Tokenizers
Load a HuggingFace `tokenizer.json` file:
```ocaml
open Brot
let tokenizer = from_file "tokenizer.json" |> Result.get_ok
let encoding = encode tokenizer "Hello world!"
```
Load from separate vocabulary and merges files:
```ocaml
open Brot
let tokenizer =
from_model_file ~vocab:"vocab.json" ~merges:"merges.txt"
~pre:(Pre_tokenizer.byte_level ~add_prefix_space:false ())
~decoder:(Decoder.byte_level ())
()
```
See [Pretrained Tokenizers](03-pretrained/) for complete pipeline
configurations for BERT, GPT-2, and SentencePiece-style models.
## Batch Processing
Encode multiple texts at once with padding to uniform length:
```ocaml
open Brot
let tokenizer =
train_bpe ~vocab_size:80 ~show_progress:false
~specials:(List.map special [ "[PAD]" ])
~pad_token:"[PAD]"
(`Seq (List.to_seq
[ "The quick brown fox jumps over the lazy dog";
"The dog barked loudly at the brown fox";
"Quick brown foxes are jumping over lazy dogs" ]))
let encodings =
encode_batch tokenizer
~padding:(padding `Batch_longest)
[ "The quick fox"; "The lazy dog barked" ]
(* All encodings now have the same length *)
let lengths = List.map Encoding.length encodings
```
See [Batch Processing](04-batch-processing/) for padding strategies,
truncation, sentence pairs, and offset alignment.
## Next Steps
- [The Tokenization Pipeline](02-pipeline/) — how the 5 pipeline stages work
- [Pretrained Tokenizers](03-pretrained/) — loading, saving, and building known model pipelines
- [Batch Processing](04-batch-processing/) — padding, truncation, encoding metadata
- [Choosing an Algorithm](05-algorithms/) — BPE vs WordPiece vs Unigram and when to use each
================================================
FILE: packages/brot/doc/02-pipeline.md
================================================
# The Tokenization Pipeline
Brot processes text through up to 5 stages, each optional and independently
configurable:
```
text
│
├─ 1. Normalizer — clean and transform text
├─ 2. Pre-tokenizer — split into pieces with byte offsets
├─ 3. Algorithm — map pieces to token IDs (BPE, WordPiece, …)
├─ 4. Post-processor — add special tokens, set type IDs
└─ 5. Decoder — reverse the encoding back to text
│
▼
Encoding.t (ids, tokens, offsets, masks, …)
```
Each stage is set when constructing the tokenizer. Omit any stage and it
is skipped.
## Normalization
Normalizers transform text before tokenization. They handle lowercasing,
accent removal, Unicode normalization, whitespace cleanup, and
model-specific preprocessing.
Available normalizers:
- **Unicode**: `nfc`, `nfd`, `nfkc`, `nfkd`
- **Text transforms**: `lowercase`, `strip_accents`, `strip`, `replace`, `prepend`
- **Byte-level**: `byte_level` (GPT-2 style byte-to-Unicode mapping)
- **Model-specific**: `bert` (clean text, CJK padding, optional lowercasing and accent stripping)
Compose normalizers with `sequence`:
```ocaml
open Brot
let n =
Normalizer.sequence
[ Normalizer.nfd; Normalizer.strip_accents; Normalizer.lowercase ]
let r1 = Normalizer.apply n "Café Résumé" (* "cafe resume" *)
let r2 = Normalizer.apply n "HELLO" (* "hello" *)
```
The BERT normalizer combines several transforms:
```ocaml
open Brot
let n = Normalizer.bert ~lowercase:true ()
(* Lowercases, cleans control characters, pads CJK *)
let r1 = Normalizer.apply n "Hello World" (* "hello world" *)
let r2 = Normalizer.apply n "Café" (* "cafe" *)
```
## Pre-tokenization
Pre-tokenizers split text into pieces before the algorithm runs. Each
piece carries byte offsets into the original text. The algorithm then
tokenizes each piece independently.
Available pre-tokenizers:
| Pre-tokenizer | Description |
| --------------------- | --------------------------------------------------------------- |
| `whitespace ()` | Split on `\w+\|[^\w\s]+` (word chars grouped, non-word grouped) |
| `whitespace_split ()` | Split on whitespace (simplest) |
| `bert ()` | BERT-style: whitespace + punctuation isolation + CJK separation |
| `byte_level ()` | GPT-2 style byte-level encoding with regex splitting |
| `punctuation ()` | Separate punctuation from alphanumeric content |
| `split ~pattern ()` | Split on a literal string pattern |
| `char_delimiter c` | Split on a single character |
| `digits ()` | Split on digit boundaries |
| `metaspace ()` | Replace whitespace with a visible marker (SentencePiece) |
| `unicode_scripts ()` | Split on Unicode script boundaries |
| `fixed_length n` | Fixed-size character chunks |
Use `pre_tokenize` to inspect how a pre-tokenizer splits text. It returns
a list of `(piece, (start_offset, end_offset))` pairs:
```ocaml
open Brot
let text = "Hello, world! How's it going?"
let whitespace_pieces =
Pre_tokenizer.pre_tokenize (Pre_tokenizer.whitespace ()) text
(* [("Hello", (0,5)); (",", (5,6)); ("world", (7,12)); ("!", (12,13)); ...] *)
let bert_pieces =
Pre_tokenizer.pre_tokenize (Pre_tokenizer.bert ()) text
let punct_pieces =
Pre_tokenizer.pre_tokenize (Pre_tokenizer.punctuation ()) text
```
Compose pre-tokenizers with `sequence`. Each pre-tokenizer in the chain
processes the pieces from the previous one:
```ocaml
open Brot
let pre =
Pre_tokenizer.sequence
[ Pre_tokenizer.whitespace_split (); Pre_tokenizer.digits () ]
let pieces = Pre_tokenizer.pre_tokenize pre "order 42 shipped"
(* [("order", _); ("4", _); ("2", _); ("shipped", _)] *)
```
## Tokenization Algorithms
The algorithm maps pre-tokenized pieces to token IDs using the vocabulary.
Brot supports 5 algorithms:
| Algorithm | How it splits | Notable models |
| --------------- | ------------------------------------------- | ------------------------------ |
| BPE | Iterative merge of most frequent pairs | GPT-2, GPT-3/4, RoBERTa, LLaMA |
| WordPiece | Greedy longest-match with `##` prefix | BERT, DistilBERT, Electra |
| Unigram | Probabilistic segmentation (max likelihood) | T5, ALBERT, mBART, XLNet |
| Word-level | Whole words, no subword splitting | Simple models, prototyping |
| Character-level | Each byte is a token | Byte-level fallback |
See [Choosing an Algorithm](05-algorithms/) for details on each algorithm,
when to use it, and how to configure training.
## Post-processing
Post-processors add special tokens and set type IDs after tokenization.
They handle model-specific requirements like `[CLS]`/`[SEP]` for BERT or
`<s>`/`</s>` for RoBERTa.
Available post-processors:
- `bert ~sep ~cls ()` — `[CLS] A [SEP]` or `[CLS] A [SEP] B [SEP]`, type IDs 0/1
- `roberta ~sep ~cls ()` — `<s> A </s>` or `<s> A </s> </s> B </s>`, all type IDs 0
- `byte_level ()` — adjust offsets for byte-level encoding
- `template ~single ()` — custom template with `$A`, `$B`, and literal token placeholders
- `sequence processors` — chain multiple post-processors
```ocaml
open Brot
let tokenizer =
wordpiece
~vocab:
[ ("[UNK]", 0); ("[CLS]", 1); ("[SEP]", 2);
("the", 3); ("cat", 4); ("sat", 5); ("how", 6); ("are", 7); ("you", 8) ]
~specials:(List.map special [ "[UNK]"; "[CLS]"; "[SEP]" ])
~pre:(Pre_tokenizer.whitespace ())
~post:(Post_processor.bert ~cls:("[CLS]", 1) ~sep:("[SEP]", 2) ())
~decoder:(Decoder.wordpiece ())
~unk_token:"[UNK]" ()
(* Single sentence: [CLS] the cat sat [SEP] *)
let single = encode tokenizer "the cat sat"
(* Sentence pair: [CLS] the cat sat [SEP] how are you [SEP] *)
let pair = encode tokenizer ~pair:"how are you" "the cat sat"
(* type_ids: 0 for first sentence + [CLS]/[SEP], 1 for second + [SEP] *)
let type_ids = Encoding.type_ids pair
```
The `template` post-processor gives full control over the format. Use `$A`
and `$B` as sequence placeholders, and literal token names in brackets.
Append `:N` to set type IDs:
```ocaml
open Brot
let tokenizer =
word_level
~vocab:
[ ("[BOS]", 0); ("[EOS]", 1); ("hello", 2); ("world", 3) ]
~specials:(List.map special [ "[BOS]"; "[EOS]" ])
~pre:(Pre_tokenizer.whitespace ())
~post:
(Post_processor.template
~single:"[BOS]:0 $A:0 [EOS]:0"
~pair:"[BOS]:0 $A:0 [EOS]:0 $B:1 [EOS]:1"
~special_tokens:[ ("[BOS]", 0); ("[EOS]", 1) ]
())
~unk_token:"[UNK]" ()
let enc = encode tokenizer "hello world"
let tokens = Encoding.tokens enc (* [| "[BOS]"; "hello"; "world"; "[EOS]" |] *)
let type_ids = Encoding.type_ids enc (* [| 0; 0; 0; 0 |] *)
```
## Decoding
Decoders reverse encoding-specific transformations to produce natural text
from token strings. They operate on token *strings* (looked up from the
vocabulary), not IDs.
Decoders fall into two categories:
- **Per-token** — transform each token independently: `bpe`, `byte_fallback`, `metaspace`
- **Collapsing** — process the entire token list as a whole: `byte_level`, `wordpiece`, `replace`, `strip`, `fuse`
This distinction matters when composing with `sequence`: per-token decoders
pass a list of transformed tokens to the next decoder, while collapsing
decoders produce a single result.
Available decoders:
| Decoder | Type | Description |
| ------------------------- | ---------- | ------------------------------------------------ |
| `bpe ()` | Per-token | Strip end-of-word suffix, insert spaces |
| `byte_fallback ()` | Per-token | Convert `<0x41>` hex tokens to bytes |
| `metaspace ()` | Per-token | Convert metaspace markers to spaces |
| `byte_level ()` | Collapsing | Reverse GPT-2 byte-to-Unicode encoding |
| `wordpiece ()` | Collapsing | Strip `##` prefix, join subwords |
| `replace ~pattern ~by ()` | Collapsing | Replace literal pattern in joined text |
| `strip ()` | Collapsing | Remove leading/trailing characters |
| `fuse ()` | Collapsing | Concatenate all tokens with no delimiter |
| `ctc ()` | Per-token | CTC output decoding (deduplication, pad removal) |
```ocaml
open Brot
(* WordPiece decoder: strips ## prefix and joins subwords *)
let wp = Decoder.wordpiece ()
let text = Decoder.decode wp [ "[CLS]"; "play"; "##ing"; "cat"; "##s"; "[SEP]" ]
(* "[CLS] playing cats [SEP]" *)
(* Sequence of decoders *)
let seq = Decoder.sequence [ Decoder.fuse (); Decoder.replace ~pattern:"_" ~by:" " () ]
let text2 = Decoder.decode seq [ "_Hello"; "_world" ]
(* " Hello world" *)
```
When using `Brot.decode`, the tokenizer looks up token strings from the
vocabulary and then applies the configured decoder automatically.
## Complete Example
Here is a complete BERT-style tokenizer using all 5 pipeline stages:
```ocaml
open Brot
let tokenizer =
wordpiece
(* 1. Normalizer: lowercase and clean text *)
~normalizer:(Normalizer.bert ~lowercase:true ())
(* 2. Pre-tokenizer: BERT-style splitting *)
~pre:(Pre_tokenizer.bert ())
(* 3. Algorithm: WordPiece with ## prefix *)
~vocab:
[ ("[PAD]", 0); ("[UNK]", 1); ("[CLS]", 2); ("[SEP]", 3);
("the", 4); ("cat", 5); ("sat", 6); ("on", 7); ("mat", 8);
("play", 9); ("##ing", 10); ("##ed", 11); ("a", 12) ]
~specials:(List.map special [ "[PAD]"; "[UNK]"; "[CLS]"; "[SEP]" ])
~unk_token:"[UNK]" ~pad_token:"[PAD]"
(* 4. Post-processor: add [CLS] and [SEP] *)
~post:(Post_processor.bert ~cls:("[CLS]", 2) ~sep:("[SEP]", 3) ())
(* 5. Decoder: strip ## and join *)
~decoder:(Decoder.wordpiece ())
()
(* "The Cat" is normalized to "the cat" before tokenization *)
let enc = encode tokenizer "The Cat Played On A Mat"
let tokens = Encoding.tokens enc
(* [| "[CLS]"; "the"; "cat"; "play"; "##ed"; "on"; "a"; "mat"; "[SEP]" |] *)
(* Decode back, skipping special tokens *)
let text = decode tokenizer ~skip_special_tokens:true (Encoding.ids enc)
(* "the cat played on a mat" *)
```
================================================
FILE: packages/brot/doc/03-pretrained.md
================================================
# Pretrained Tokenizers
Most users start by loading an existing tokenizer rather than building one
from scratch. Brot reads and writes HuggingFace `tokenizer.json` files and
separate vocabulary/merges model files.
## Loading from tokenizer.json
HuggingFace models ship a `tokenizer.json` that contains the algorithm,
vocabulary, merge rules, and full pipeline configuration. Load it with
`from_file`:
```ocaml
open Brot
let tokenizer = from_file "path/to/tokenizer.json" |> Result.get_ok
let encoding = encode tokenizer "Hello world!"
let ids = Encoding.ids encoding
```
`from_file` returns `(t, string) result`. Handle errors explicitly when
the file may be missing or malformed:
```ocaml
let tokenizer =
match Brot.from_file "tokenizer.json" with
| Ok t -> t
| Error msg -> failwith msg
```
## Loading from Model Files
Older models ship separate `vocab.json` and `merges.txt` files instead
of a single `tokenizer.json`. Use `from_model_file`:
```ocaml
open Brot
(* BPE: provide both vocab and merges *)
let tokenizer =
from_model_file ~vocab:"vocab.json" ~merges:"merges.txt"
~pre:(Pre_tokenizer.byte_level ~add_prefix_space:false ())
~decoder:(Decoder.byte_level ())
()
(* WordPiece: vocab only, no merges *)
let tokenizer =
from_model_file ~vocab:"vocab.txt"
~pre:(Pre_tokenizer.bert ())
~decoder:(Decoder.wordpiece ())
()
```
When `merges` is provided, a BPE tokenizer is created. Without it,
WordPiece is used. The pipeline stages (normalizer, pre-tokenizer,
post-processor, decoder) must be configured explicitly since model files
do not include them.
## Building Known Pipelines
When you need full control over the pipeline or want to understand what
each stage does, build the tokenizer from scratch with an inline
vocabulary. The following examples show the standard configurations for
well-known models.
### BERT (uncased)
BERT uses WordPiece with `##` continuation prefix, BERT normalization
(lowercase, clean text, CJK padding), BERT pre-tokenization (whitespace +
punctuation), and `[CLS]`/`[SEP]` post-processing:
```ocaml
open Brot
let tokenizer =
wordpiece
~vocab:
[ ("[PAD]", 0); ("[UNK]", 1); ("[CLS]", 2); ("[SEP]", 3);
("the", 4); ("cat", 5); ("sat", 6); ("on", 7); ("mat", 8);
("play", 9); ("##ing", 10); ("##ed", 11); ("a", 12);
("is", 13); ("good", 14) ]
~normalizer:(Normalizer.bert ~lowercase:true ())
~pre:(Pre_tokenizer.bert ())
~post:(Post_processor.bert ~cls:("[CLS]", 2) ~sep:("[SEP]", 3) ())
~decoder:(Decoder.wordpiece ())
~specials:(List.map special [ "[PAD]"; "[UNK]"; "[CLS]"; "[SEP]" ])
~unk_token:"[UNK]" ~pad_token:"[PAD]" ()
let enc = encode tokenizer "The Cat Is Playing"
let tokens = Encoding.tokens enc
(* [| "[CLS]"; "the"; "cat"; "is"; "play"; "##ing"; "[SEP]" |] *)
let decoded = decode tokenizer ~skip_special_tokens:true (Encoding.ids enc)
(* "the cat is playing" *)
```
### GPT-2
GPT-2 uses BPE with byte-level pre-tokenization (no information loss,
handles any Unicode input) and byte-level decoding:
```ocaml
open Brot
let tokenizer =
bpe
~vocab:
[ ("H", 0); ("e", 1); ("l", 2); ("o", 3); ("Ġ", 4); ("w", 5);
("r", 6); ("d", 7); ("He", 8); ("ll", 9); ("llo", 10);
("Hello", 11); ("Ġw", 12); ("or", 13); ("ld", 14);
("orld", 15); ("Ġworld", 16) ]
~merges:
[ ("H", "e"); ("l", "l"); ("ll", "o"); ("He", "llo");
("Ġ", "w"); ("o", "r"); ("l", "d"); ("or", "ld");
("Ġw", "orld") ]
~pre:(Pre_tokenizer.byte_level ~add_prefix_space:false ())
~decoder:(Decoder.byte_level ())
()
let enc = encode tokenizer "Hello world"
let tokens = Encoding.tokens enc (* [| "Hello"; "Ġworld" |] *)
let decoded = decode tokenizer (Encoding.ids enc) (* "Hello world" *)
```
### SentencePiece-style (T5, ALBERT)
SentencePiece models use Unigram with metaspace pre-tokenization (spaces
replaced by a visible marker) and metaspace decoding:
```ocaml
open Brot
let tokenizer =
unigram
~vocab:
[ ("<unk>", -1.0); ("\xe2\x96\x81", -2.0);
("\xe2\x96\x81the", -1.5); ("\xe2\x96\x81cat", -1.8);
("\xe2\x96\x81is", -1.6); ("\xe2\x96\x81play", -2.0);
("ing", -2.5); ("\xe2\x96\x81a", -1.4); ("\xe2\x96\x81good", -2.1) ]
~pre:(Pre_tokenizer.metaspace ~replacement:'\xe2' ())
~decoder:(Decoder.metaspace ~replacement:'\xe2' ())
~unk_token:"<unk>" ()
let enc = encode tokenizer "the cat is playing"
```
## Saving Tokenizers
Save a tokenizer in HuggingFace format for later use or sharing:
```ocaml
(* Save as tokenizer.json (full pipeline) *)
Brot.save_pretrained tokenizer ~path:"./my_tokenizer"
(* Save just the vocabulary and merges files *)
let files = Brot.save_model_files tokenizer ~folder:"./model" ()
(* Export BPE merges in tiktoken format *)
Brot.export_tiktoken tokenizer
~merges_path:"./tiktoken_merges.txt"
~vocab_path:"./tiktoken_vocab.txt"
```
## Training from Scratch
Train a tokenizer from a text corpus. Configure the full pipeline
alongside the training parameters:
```ocaml
open Brot
let tokenizer =
train_bpe
~vocab_size:120
~min_frequency:1
~show_progress:false
~pre:(Pre_tokenizer.whitespace ())
~specials:(List.map special [ "[PAD]"; "[UNK]" ])
~unk_token:"[UNK]" ~pad_token:"[PAD]"
(`Seq (List.to_seq
[ "The quick brown fox jumps over the lazy dog.";
"Machine learning models need good tokenizers.";
"Subword tokenization handles unknown words gracefully.";
"The fox jumped over the lazy dog again.";
"Tokenizers convert text to numerical representations." ]))
let size = vocab_size tokenizer
let enc = encode tokenizer "The quick fox"
```
See [Choosing an Algorithm](05-algorithms/) for guidance on which algorithm
to train and how to tune parameters like `vocab_size`, `min_frequency`,
and algorithm-specific options.
================================================
FILE: packages/brot/doc/04-batch-processing.md
================================================
# Batch Processing
Real-world usage requires encoding multiple texts into uniform-length
sequences for model input. This guide covers encoding metadata, sentence
pairs, batch encoding, padding, truncation, and offset alignment.
## Encoding Metadata
`Encoding.t` carries parallel arrays that all share the same length. Each
field serves a specific purpose in model input preparation:
| Field | Type | Description |
| --------------------- | ------------------- | ----------------------------------------------- |
| `ids` | `int array` | Token IDs for model input |
| `tokens` | `string array` | String representation of each token |
| `offsets` | `(int * int) array` | `(start, end)` byte positions in source text |
| `type_ids` | `int array` | Segment IDs: 0 for sentence A, 1 for sentence B |
| `attention_mask` | `int array` | 1 for real tokens, 0 for padding |
| `special_tokens_mask` | `int array` | 1 for special tokens, 0 for content |
| `word_ids` | `int option array` | Source word index, or `None` for special tokens |
```ocaml
open Brot
let tokenizer =
wordpiece
~vocab:
[ ("[UNK]", 0); ("[CLS]", 1); ("[SEP]", 2);
("the", 3); ("cat", 4); ("play", 5); ("##ing", 6) ]
~specials:(List.map special [ "[UNK]"; "[CLS]"; "[SEP]" ])
~pre:(Pre_tokenizer.whitespace ())
~post:(Post_processor.bert ~cls:("[CLS]", 1) ~sep:("[SEP]", 2) ())
~decoder:(Decoder.wordpiece ())
~unk_token:"[UNK]" ()
let enc = encode tokenizer "the cat playing"
(* tokens: [| "[CLS]"; "the"; "cat"; "play"; "##ing"; "[SEP]" |] *)
let ids = Encoding.ids enc
let type_ids = Encoding.type_ids enc
let attention_mask = Encoding.attention_mask enc
let special_tokens_mask = Encoding.special_tokens_mask enc
let offsets = Encoding.offsets enc
let word_ids = Encoding.word_ids enc
(* word_ids: [| None; Some 0; Some 1; Some 2; Some 2; None |]
"play" and "##ing" share word index 2 *)
```
## Sentence Pairs
Many NLP tasks (question answering, natural language inference, sentence
similarity) operate on pairs of sentences. Use `encode ~pair` to encode
both sequences together:
```ocaml
open Brot
let tokenizer =
wordpiece
~vocab:
[ ("[UNK]", 0); ("[CLS]", 1); ("[SEP]", 2);
("the", 3); ("cat", 4); ("sat", 5); ("how", 6);
("are", 7); ("you", 8) ]
~specials:(List.map special [ "[UNK]"; "[CLS]"; "[SEP]" ])
~pre:(Pre_tokenizer.whitespace ())
~post:(Post_processor.bert ~cls:("[CLS]", 1) ~sep:("[SEP]", 2) ())
~decoder:(Decoder.wordpiece ())
~unk_token:"[UNK]" ()
let enc = encode tokenizer ~pair:"how are you" "the cat sat"
(* tokens: [| "[CLS]"; "the"; "cat"; "sat"; "[SEP]"; "how"; "are"; "you"; "[SEP]" |] *)
let type_ids = Encoding.type_ids enc
(* [| 0; 0; 0; 0; 0; 1; 1; 1; 1 |] *)
```
Type IDs distinguish the two sentences: 0 for the first sequence
(including `[CLS]` and first `[SEP]`), 1 for the second (including
final `[SEP]`).
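The segment-to-type-ID mapping can be sketched as a pure function (illustrative only; Brot computes this inside the post-processor):

```ocaml
(* Build BERT-style type IDs from the two segment lengths, where each
   length already includes its surrounding special tokens. *)
let type_ids ~len_a ~len_b =
  Array.init (len_a + len_b) (fun i -> if i < len_a then 0 else 1)

let () =
  (* [CLS] the cat sat [SEP] = 5 tokens; how are you [SEP] = 4 tokens *)
  assert (type_ids ~len_a:5 ~len_b:4 = [| 0; 0; 0; 0; 0; 1; 1; 1; 1 |])
```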
## Batch Encoding
Encode multiple texts at once with `encode_batch`, or multiple sentence
pairs with `encode_pairs_batch`:
```ocaml
open Brot
let tokenizer =
wordpiece
~vocab:
[ ("[UNK]", 0); ("[CLS]", 1); ("[SEP]", 2);
("the", 3); ("cat", 4); ("sat", 5);
("how", 6); ("are", 7); ("you", 8); ("good", 9) ]
~specials:(List.map special [ "[UNK]"; "[CLS]"; "[SEP]" ])
~pre:(Pre_tokenizer.whitespace ())
~post:(Post_processor.bert ~cls:("[CLS]", 1) ~sep:("[SEP]", 2) ())
~decoder:(Decoder.wordpiece ())
~unk_token:"[UNK]" ()
(* Batch of single sentences *)
let encodings =
encode_batch tokenizer [ "the cat"; "the cat sat"; "good" ]
let lengths = List.map Encoding.length encodings
(* [4; 5; 3] — each includes [CLS] and [SEP] *)
(* Batch of sentence pairs *)
let pairs =
encode_pairs_batch tokenizer
[ ("the cat sat", "how are you"); ("good", "the cat") ]
```
## Padding
Models require uniform sequence lengths within a batch. Padding extends
shorter sequences with padding tokens. Three strategies are available:
- **`Batch_longest`** — pad to the longest sequence in the batch
- **`Fixed n`** — pad every sequence to exactly `n` tokens
- **`To_multiple n`** — pad to the smallest multiple of `n` that fits
Padding tokens have `attention_mask = 0` and `special_tokens_mask = 1`.
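The `To_multiple` target is the sequence length rounded up to the nearest multiple of `n`; a sketch of that arithmetic (not Brot's internal code):

```ocaml
(* Round a sequence length up to the nearest multiple of n, as the
   `To_multiple` strategy does when choosing the padded length. *)
let to_multiple n len = ((len + n - 1) / n) * n

let () =
  assert (to_multiple 4 5 = 8);  (* 5 tokens pad up to 8 *)
  assert (to_multiple 4 8 = 8);  (* already a multiple: unchanged *)
  assert (to_multiple 4 1 = 4)
```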
```ocaml
open Brot
let tokenizer =
word_level
~vocab:
[ ("[PAD]", 0); ("[UNK]", 1); ("the", 2); ("cat", 3);
("sat", 4); ("on", 5); ("a", 6); ("mat", 7) ]
~specials:(List.map special [ "[PAD]"; "[UNK]" ])
~pre:(Pre_tokenizer.whitespace ())
~unk_token:"[UNK]" ~pad_token:"[PAD]" ()
let texts = [ "the cat"; "the cat sat on a mat"; "cat" ]
(* Pad to longest in batch — all encodings have length 6 *)
let batch1 =
encode_batch tokenizer ~padding:(padding `Batch_longest) texts
(* Pad to fixed length — all encodings have length 8 *)
let batch2 =
encode_batch tokenizer ~padding:(padding (`Fixed 8)) texts
(* Pad to multiple of 4 — lengths rounded up to nearest multiple *)
let batch3 =
encode_batch tokenizer ~padding:(padding (`To_multiple 4)) texts
```
By default, padding is applied to the right. Use `` ~direction:`Left ``
for left-padding, which is common for autoregressive generation:
```ocaml
open Brot
let tokenizer =
word_level
~vocab:
[ ("[PAD]", 0); ("[UNK]", 1); ("the", 2); ("cat", 3); ("sat", 4) ]
~specials:(List.map special [ "[PAD]"; "[UNK]" ])
~pre:(Pre_tokenizer.whitespace ())
~unk_token:"[UNK]" ~pad_token:"[PAD]" ()
let encodings =
encode_batch tokenizer
~padding:(padding ~direction:`Left (`Fixed 5))
[ "the cat"; "the cat sat" ]
(* tokens: [| "[PAD]"; "[PAD]"; "[PAD]"; "the"; "cat" |]
[| "[PAD]"; "[PAD]"; "the"; "cat"; "sat" |] *)
```
## Truncation
Truncation limits sequences to a maximum length. Excess tokens are
trimmed from the specified direction:
```ocaml
open Brot
let tokenizer =
word_level
~vocab:
[ ("[UNK]", 0); ("the", 1); ("quick", 2); ("brown", 3);
("fox", 4); ("jumps", 5); ("over", 6) ]
~specials:(List.map special [ "[UNK]" ])
~pre:(Pre_tokenizer.whitespace ())
~unk_token:"[UNK]" ()
let text = "the quick brown fox jumps over"
(* Truncate from the right (default) *)
let enc_right = encode tokenizer ~truncation:(truncation 4) text
let tokens_right = Encoding.tokens enc_right
(* [| "the"; "quick"; "brown"; "fox" |] *)
(* Truncate from the left *)
let enc_left =
encode tokenizer ~truncation:(truncation ~direction:`Left 4) text
let tokens_left = Encoding.tokens enc_left
(* [| "brown"; "fox"; "jumps"; "over" |] *)
```
When using a post-processor that adds special tokens, account for the
tokens it adds. Use `Post_processor.added_tokens` to calculate the
budget:
```ocaml
open Brot
let post = Post_processor.bert ~cls:("[CLS]", 1) ~sep:("[SEP]", 2) ()
let added_single = Post_processor.added_tokens post ~is_pair:false (* 2 *)
let added_pair = Post_processor.added_tokens post ~is_pair:true (* 3 *)
```
## Padding and Truncation Together
The common pattern for model input: truncate long sequences and pad short
ones to a uniform length:
```ocaml
open Brot
let tokenizer =
word_level
~vocab:
[ ("[PAD]", 0); ("[UNK]", 1); ("the", 2); ("cat", 3);
("sat", 4); ("on", 5); ("a", 6); ("mat", 7);
("dog", 8); ("ran", 9); ("fast", 10) ]
~specials:(List.map special [ "[PAD]"; "[UNK]" ])
~pre:(Pre_tokenizer.whitespace ())
~unk_token:"[UNK]" ~pad_token:"[PAD]" ()
let encodings =
encode_batch tokenizer
~truncation:(truncation 4)
~padding:(padding (`Fixed 4))
[ "the cat sat on a mat"; "the dog ran"; "cat" ]
(* All encodings have exactly 4 tokens.
Long sequences are truncated, short ones are padded.
attention_mask distinguishes real tokens (1) from padding (0). *)
let masks = List.map Encoding.attention_mask encodings
```
## Offsets and Alignment
`Encoding.offsets` maps each token back to its `(start, end)` byte span
in the original text. This is useful for tasks like named entity
recognition where you need to extract the source text for each token:
```ocaml
open Brot
let tokenizer =
wordpiece
~vocab:
[ ("[UNK]", 0); ("hello", 1); ("world", 2);
("play", 3); ("##ing", 4) ]
~pre:(Pre_tokenizer.whitespace ())
~decoder:(Decoder.wordpiece ())
~unk_token:"[UNK]" ()
let text = "hello playing world"
let enc = encode tokenizer text
let offsets = Encoding.offsets enc
(* offsets.(0) = (0, 5)   -> "hello"
offsets.(1) = (6, 13)  -> "playing" (whole-word span for "play")
offsets.(2) = (6, 13)  -> "playing" (whole-word span for "##ing")
offsets.(3) = (14, 19) -> "world" *)
(* Extract source span for a token *)
let start, end_ = offsets.(0)
let source = String.sub text start (end_ - start) (* "hello" *)
```
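Applied across the whole encoding, the same recipe recovers every token's source span. A self-contained sketch, reusing the offsets from the example above as literal data:

```ocaml
(* Recover the source text of each token from its byte span.
   The offsets array mirrors the example above. *)
let text = "hello playing world"
let offsets = [| (0, 5); (6, 13); (6, 13); (14, 19) |]
let spans = Array.map (fun (s, e) -> String.sub text s (e - s)) offsets
(* [| "hello"; "playing"; "playing"; "world" |] *)
```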
`Encoding.word_ids` groups subword tokens back to their source word.
Tokens that belong to the same word share the same word index:
```ocaml
open Brot
let tokenizer =
wordpiece
~vocab:
[ ("[UNK]", 0); ("the", 1); ("cat", 2);
("play", 3); ("##ing", 4); ("##s", 5) ]
~pre:(Pre_tokenizer.whitespace ())
~decoder:(Decoder.wordpiece ())
~unk_token:"[UNK]" ()
let enc = encode tokenizer "the cat playing"
let word_ids = Encoding.word_ids enc
(* [| Some 0; Some 1; Some 2; Some 2 |]
"play" and "##ing" share word index 2,
indicating they come from the same source word *)
```
================================================
FILE: packages/brot/doc/05-algorithms.md
================================================
# Choosing a Tokenization Algorithm
Brot supports five tokenization algorithms. The three subword algorithms
(BPE, WordPiece, Unigram) handle an open vocabulary by splitting rare
words into smaller pieces. Word-level and character-level are simpler
alternatives.
## BPE (Byte Pair Encoding)
BPE starts with individual characters and iteratively merges the most
frequent adjacent pairs. The merge rules, learned during training, define
how text is split. Used by GPT-2, GPT-3/4, RoBERTa, and LLaMA.
Constructor: `Brot.bpe`. Trainer: `Brot.train_bpe`.
Key parameters:
- `vocab_size` — target vocabulary size (default: 30000)
- `min_frequency` — minimum pair frequency for merging (default: 0)
- `dropout` — probability of skipping merges for data augmentation
- `byte_fallback` — use `<0x00>` byte tokens instead of unknown token
- `continuing_subword_prefix` — prefix for non-initial subwords
- `end_of_word_suffix` — suffix marking word boundaries (e.g., `</w>`)
```ocaml
open Brot
let tokenizer =
bpe
~vocab:
[ ("h", 0); ("e", 1); ("l", 2); ("o", 3); (" ", 4); ("w", 5);
("r", 6); ("d", 7); ("he", 8); ("ll", 9); ("llo", 10);
("hello", 11); ("wo", 12); ("rl", 13); ("rld", 14); ("world", 15) ]
~merges:
[ ("h", "e"); ("l", "l"); ("ll", "o"); ("he", "llo");
("w", "o"); ("r", "l"); ("rl", "d"); ("wo", "rld") ]
()
let enc = encode tokenizer "hello world"
let tokens = Encoding.tokens enc (* [| "hello"; " "; "world" |] *)
```
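The merge mechanism itself can be illustrated without Brot. This sketch applies each merge rule in order, one left-to-right pass per rule; it is a simplification of real BPE, which repeatedly applies the highest-priority merge that still matches:

```ocaml
(* Toy BPE: split a word into characters, then apply each merge rule
   in order with one left-to-right pass per rule.
   A simplified sketch of the mechanism, not Brot's implementation. *)
let apply_merges merges word =
  let symbols =
    List.init (String.length word) (fun i -> String.make 1 word.[i])
  in
  let merge_pair (a, b) syms =
    let rec go = function
      | x :: y :: rest when x = a && y = b -> (a ^ b) :: go rest
      | x :: rest -> x :: go rest
      | [] -> []
    in
    go syms
  in
  List.fold_left (fun syms rule -> merge_pair rule syms) symbols merges

let tokens =
  apply_merges
    [ ("h", "e"); ("l", "l"); ("ll", "o"); ("he", "llo") ]
    "hello"
(* tokens = [ "hello" ] *)
```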
Training BPE:
```ocaml
open Brot
let tokenizer =
train_bpe ~vocab_size:80 ~min_frequency:1 ~show_progress:false
(`Seq (List.to_seq
[ "The quick brown fox jumps over the lazy dog";
"The dog barked at the brown fox";
"Quick brown foxes are rare and beautiful" ]))
let size = vocab_size tokenizer
let enc = encode tokenizer "The brown fox"
```
## WordPiece
WordPiece uses a greedy longest-match-first algorithm. For each word, it
finds the longest prefix in the vocabulary, then continues with the
remainder prefixed by a continuation marker (default: `##`). Used by BERT,
DistilBERT, and Electra.
Constructor: `Brot.wordpiece`. Trainer: `Brot.train_wordpiece`.
Key parameters:
- `vocab_size` — target vocabulary size (default: 30000)
- `continuing_subword_prefix` — prefix for non-initial subwords (default: `##`)
- `max_input_chars_per_word` — words longer than this become unknown (default: 100)
```ocaml
open Brot
let tokenizer =
wordpiece
~vocab:
[ ("[UNK]", 0); ("the", 1); ("cat", 2); ("play", 3);
("##ing", 4); ("##ed", 5); ("##s", 6); ("un", 7);
("##happy", 8); ("##ly", 9) ]
~pre:(Pre_tokenizer.whitespace ())
~decoder:(Decoder.wordpiece ())
~unk_token:"[UNK]" ()
let enc = encode tokenizer "the cat playing unhappily"
let tokens = Encoding.tokens enc
(* [| "the"; "cat"; "play"; "##ing"; "un"; "##happy"; "##ly" |] *)
let decoded = decode tokenizer (Encoding.ids enc)
(* "the cat playing unhappily" *)
```
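The greedy longest-match rule is simple enough to sketch in plain OCaml. This is a toy illustration of the algorithm described above, not Brot's implementation:

```ocaml
(* Toy greedy longest-match: repeatedly take the longest vocabulary
   entry matching at the current position; non-initial pieces carry
   the "##" prefix. If no piece matches, the whole word is unknown. *)
let wordpiece_split vocab word =
  let n = String.length word in
  let rec go start acc =
    if start >= n then Some (List.rev acc)
    else
      let rec longest len =
        if len = 0 then None
        else
          let piece = String.sub word start len in
          let tok = if start = 0 then piece else "##" ^ piece in
          if List.mem tok vocab then Some (tok, len) else longest (len - 1)
      in
      match longest (n - start) with
      | None -> None (* no match: the whole word becomes [UNK] *)
      | Some (tok, len) -> go (start + len) (tok :: acc)
  in
  match go 0 [] with None -> [ "[UNK]" ] | Some toks -> toks

let tokens = wordpiece_split [ "play"; "##ing"; "##ed" ] "playing"
(* tokens = [ "play"; "##ing" ] *)
```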
Training WordPiece:
```ocaml
open Brot
let tokenizer =
train_wordpiece ~vocab_size:80 ~show_progress:false
(`Seq (List.to_seq
[ "The quick brown fox jumps over the lazy dog";
"The dog barked at the brown fox";
"Quick brown foxes are rare and beautiful" ]))
let size = vocab_size tokenizer
let enc = encode tokenizer "The brown fox"
```
## Unigram
Unigram uses probabilistic segmentation: given a vocabulary of subwords
with log-probabilities, it finds the segmentation that maximizes the
total likelihood. Training uses the EM algorithm to iteratively prune the
vocabulary. Used by T5, ALBERT, mBART, and XLNet.
Constructor: `Brot.unigram`. Trainer: `Brot.train_unigram`.
Key parameters:
- `vocab_size` — target vocabulary size (default: 8000)
- `shrinking_factor` — fraction of vocabulary to retain per pruning round (default: 0.75)
- `max_piece_length` — maximum subword length (default: 16)
- `n_sub_iterations` — EM sub-iterations per pruning round (default: 2)
Vocabulary entries are `(token, score)` pairs where scores are negative
log probabilities:
```ocaml
open Brot
let tokenizer =
unigram
~vocab:
[ ("<unk>", 0.0); ("the", -1.0); ("cat", -1.5);
("th", -2.0); ("e", -2.5); ("c", -3.0); ("a", -3.0);
("t", -3.0); ("at", -2.0); ("he", -2.0);
("sat", -1.8); ("on", -1.5) ]
~unk_token:"<unk>" ()
let enc = encode tokenizer "the cat sat on"
```
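The maximum-likelihood segmentation is a standard dynamic program. A plain-OCaml sketch over a toy `(piece, log-prob)` vocabulary, illustrating the idea rather than Brot's implementation:

```ocaml
(* Toy max-likelihood segmentation: best.(i) holds the best
   (score, reversed tokens) covering text[0..i). Higher score
   (less negative total log-prob) wins. *)
let best_segmentation vocab text =
  let n = String.length text in
  let best = Array.make (n + 1) None in
  best.(0) <- Some (0.0, []);
  for i = 0 to n - 1 do
    match best.(i) with
    | None -> ()
    | Some (score, toks) ->
        List.iter
          (fun (piece, logp) ->
            let len = String.length piece in
            if i + len <= n && String.sub text i len = piece then
              match best.(i + len) with
              | Some (s, _) when s >= score +. logp -> ()
              | _ -> best.(i + len) <- Some (score +. logp, piece :: toks))
          vocab
  done;
  Option.map (fun (s, toks) -> (s, List.rev toks)) best.(n)

let result =
  best_segmentation
    [ ("the", -1.0); ("cat", -1.5); ("t", -3.0); ("h", -3.0);
      ("e", -3.0); ("c", -3.0); ("a", -3.0) ]
    "thecat"
(* result = Some (-2.5, [ "the"; "cat" ]) — the whole-word pieces
   beat any character-level segmentation *)
```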
Training Unigram:
```ocaml
open Brot
let tokenizer =
train_unigram ~vocab_size:60 ~show_progress:false
(`Seq (List.to_seq
[ "The quick brown fox jumps over the lazy dog";
"The dog barked at the brown fox";
"Quick brown foxes are rare and beautiful" ]))
let size = vocab_size tokenizer
let enc = encode tokenizer "The brown fox"
```
## Word-level
Word-level tokenization maps each word directly to a token ID. No
subword splitting is performed — words not in the vocabulary are replaced
by the unknown token.
Constructor: `Brot.word_level`. Trainer: `Brot.train_wordlevel`.
Best suited for small, controlled vocabularies and prototyping. For
production use with an open vocabulary, prefer a subword algorithm.
When no pre-tokenizer is specified, `word_level` defaults to
`Pre_tokenizer.whitespace`.
```ocaml
open Brot
let tokenizer =
word_level
~vocab:
[ ("[UNK]", 0); ("the", 1); ("cat", 2); ("sat", 3);
("on", 4); ("a", 5); ("mat", 6) ]
~unk_token:"[UNK]" ()
(* Known words get their IDs, unknown words become [UNK] *)
let enc = encode tokenizer "the cat sat on a rug"
let tokens = Encoding.tokens enc
(* [| "the"; "cat"; "sat"; "on"; "a"; "[UNK]" |] *)
let ids = Encoding.ids enc
(* [| 1; 2; 3; 4; 5; 0 |] *)
```
## Character-level
Character-level tokenization maps each byte to a token with ID equal to
its ordinal value. No vocabulary or training is needed.
Constructor: `Brot.chars`.
Useful as a byte-level fallback or for models that operate directly on
characters:
```ocaml
open Brot
let tokenizer = chars ()
let enc = encode tokenizer "Hi!"
let tokens = Encoding.tokens enc (* [| "H"; "i"; "!" |] *)
let ids = Encoding.ids enc (* [| 72; 105; 33 |] *)
```
## Quick Reference
| Algorithm | Splitting strategy | Typical vocab | Notable models | Constructor | Trainer |
| --------------- | ----------------------------------------- | ------------- | ------------------------- | ------------ | ----------------- |
| BPE | Iterative merge of frequent pairs | 30K-50K | GPT-2, RoBERTa, LLaMA | `bpe` | `train_bpe` |
| WordPiece | Greedy longest-match with `##` prefix | 30K | BERT, DistilBERT, Electra | `wordpiece` | `train_wordpiece` |
| Unigram | Probabilistic max-likelihood segmentation | 8K-32K | T5, ALBERT, mBART, XLNet | `unigram` | `train_unigram` |
| Word-level | Whole words, no splitting | Varies | Simple models | `word_level` | `train_wordlevel` |
| Character-level | Each byte is a token | 256 | Byte-level models | `chars` | — |
================================================
FILE: packages/brot/doc/06-hf-tokenizers-comparison.md
================================================
# Brot vs. HuggingFace Tokenizers: A Practical Comparison
This guide explains how Brot relates to Python's [HuggingFace Tokenizers](https://github.com/huggingface/tokenizers), focusing on:
* How core concepts map (tokenizer types, pipeline stages, encoding results)
* Where the APIs feel similar vs. deliberately different
* How to translate common HuggingFace patterns into Brot
If you already use HuggingFace Tokenizers, this should be enough to become productive in Brot quickly.
---
## 1. Big-Picture Differences
| Aspect | HuggingFace Tokenizers (Python) | Brot (OCaml) |
| ------------------ | ---------------------------------------------------- | ----------------------------------------------------------------------------- |
| Language | Python bindings over Rust | Native OCaml |
| Core type | `tokenizers.Tokenizer` | `Brot.t` |
| Encoding result | `tokenizers.Encoding` | `Encoding.t` |
| Algorithms | `BPE`, `WordPiece`, `Unigram`, `WordLevel` | `Brot.bpe`, `Brot.wordpiece`, `Brot.unigram`, `Brot.word_level`, `Brot.chars` |
| Pipeline stages | Mutable properties on `Tokenizer` object | Immutable `~normalizer`, `~pre`, `~post`, `~decoder` args |
| Mutability | Tokenizer is mutable (set properties after creation) | Tokenizer is immutable after creation |
| HuggingFace compat | Native format | Full `tokenizer.json` read/write via `from_file`/`save_pretrained` |
| Training | `Trainer` objects passed to `tokenizer.train()` | `Brot.train_bpe`, `Brot.train_wordpiece`, etc. |
| Padding config | `tokenizer.enable_padding()` | `~padding` arg on `encode`/`encode_batch` |
| Truncation config | `tokenizer.enable_truncation()` | `~truncation` arg on `encode`/`encode_batch` |
**Brot semantics to know (read once):**
- Tokenizers are immutable. Pipeline components are set at construction time, not mutated after.
- `from_file` returns `(t, string) result`. Handle errors explicitly.
- Padding and truncation are per-call parameters, not global tokenizer state.
- Special tokens use a record type (`Brot.special`) with explicit control over stripping and normalization.
- `encode` returns `Encoding.t`; use `encode_ids` when you only need the ID array.
---
## 2. Loading Pretrained Tokenizers
### 2.1 From a tokenizer.json file
**HuggingFace**
```python
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_file("tokenizer.json")
```
**Brot**
```ocaml
let tokenizer = Brot.from_file "tokenizer.json" |> Result.get_ok
```
Both read the same `tokenizer.json` format. Brot's `from_file` returns a `result` instead of raising an exception.
### 2.2 From vocabulary and merges files
**HuggingFace**
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
tokenizer = Tokenizer(BPE.from_file("vocab.json", "merges.txt"))
```
**Brot**
```ocaml
let tokenizer =
Brot.from_model_file
~vocab:"vocab.json"
~merges:"merges.txt"
()
```
When `~merges` is omitted, Brot infers WordPiece instead of BPE.
### 2.3 Saving
**HuggingFace**
```python
tokenizer.save("tokenizer.json")
```
**Brot**
```ocaml
Brot.save_pretrained tokenizer ~path:"./my_tokenizer"
```
`save_pretrained` creates `path/tokenizer.json` in HuggingFace format. Use `to_json` when you need the JSON value directly.
---
## 3. Encoding Text
### 3.1 Basic encoding
**HuggingFace**
```python
output = tokenizer.encode("Hello world!")
output.ids # [101, 7592, 2088, 999, 102]
output.tokens # ['[CLS]', 'hello', 'world', '!', '[SEP]']
output.offsets # [(0, 0), (0, 5), (6, 11), (11, 12), (0, 0)]
output.type_ids # [0, 0, 0, 0, 0]
output.attention_mask # [1, 1, 1, 1, 1]
```
**Brot**
```ocaml
let enc = Brot.encode tokenizer "Hello world!"
let ids = Encoding.ids enc (* int array *)
let toks = Encoding.tokens enc (* string array *)
let offs = Encoding.offsets enc (* (int * int) array *)
let types = Encoding.type_ids enc (* int array *)
let mask = Encoding.attention_mask enc (* int array *)
```
### 3.2 IDs only
**HuggingFace**
```python
ids = tokenizer.encode("Hello world!").ids
```
**Brot**
```ocaml
let ids = Brot.encode_ids tokenizer "Hello world!"
```
`encode_ids` is a shortcut that avoids constructing the full `Encoding.t` when you only need token IDs.
### 3.3 Without special tokens
**HuggingFace**
```python
output = tokenizer.encode("Hello world!", add_special_tokens=False)
```
**Brot**
```ocaml
let enc = Brot.encode tokenizer ~add_special_tokens:false "Hello world!"
```
---
## 4. Decoding
### 4.1 Basic decoding
**HuggingFace**
```python
text = tokenizer.decode([101, 7592, 2088, 999, 102])
text_clean = tokenizer.decode([101, 7592, 2088, 999, 102], skip_special_tokens=True)
```
**Brot**
```ocaml
let text = Brot.decode tokenizer [| 101; 7592; 2088; 999; 102 |]
let text_clean =
Brot.decode tokenizer ~skip_special_tokens:true
[| 101; 7592; 2088; 999; 102 |]
```
### 4.2 Batch decoding
**HuggingFace**
```python
texts = tokenizer.decode_batch([[101, 7592, 102], [101, 2088, 102]])
```
**Brot**
```ocaml
let texts =
Brot.decode_batch tokenizer
[ [| 101; 7592; 102 |]; [| 101; 2088; 102 |] ]
```
---
## 5. Batch Encoding
**HuggingFace**
```python
outputs = tokenizer.encode_batch(["Hello world!", "How are you?"])
# outputs is a list of Encoding objects
for enc in outputs:
print(enc.ids)
```
**Brot**
```ocaml
let encodings =
Brot.encode_batch tokenizer
[ "Hello world!"; "How are you?" ]
let () =
List.iter
(fun enc ->
let ids = Encoding.ids enc in
Array.iter (Printf.printf "%d ") ids;
print_newline ())
encodings
```
Both return a list of encoding objects, one per input.
---
## 6. Padding and Truncation
### 6.1 Padding
In HuggingFace, padding is global state on the tokenizer. In Brot, it is a per-call parameter.
**HuggingFace**
```python
tokenizer.enable_padding(
direction="right",
pad_id=0,
pad_token="[PAD]",
length=128, # fixed length
)
output = tokenizer.encode("Hello")
# output.attention_mask shows 0s for padding positions
```
**Brot**
```ocaml
let pad = Brot.padding ~pad_id:0 ~pad_token:"[PAD]" (`Fixed 128)
let enc = Brot.encode tokenizer ~padding:pad "Hello"
(* Encoding.attention_mask enc has 0s for padding positions *)
```
Padding strategies:
| HuggingFace | Brot |
| --------------------------------------- | ----------------------------- |
| `length=None` (pad to longest in batch) | `` `Batch_longest `` |
| `length=128` (fixed) | `` `Fixed 128 `` |
| `pad_to_multiple_of=8` | `` `To_multiple 8 `` |
| `direction="left"`                      | `` ~direction:`Left ``        |
| `direction="right"` (default)           | `` ~direction:`Right `` (default) |
### 6.2 Truncation
**HuggingFace**
```python
tokenizer.enable_truncation(max_length=512, direction="right")
output = tokenizer.encode("Very long text ...")
```
**Brot**
```ocaml
let trunc = Brot.truncation 512
let enc = Brot.encode tokenizer ~truncation:trunc "Very long text ..."
```
Truncation direction defaults to `` `Right `` in both libraries.
### 6.3 Combined padding and truncation
**HuggingFace**
```python
tokenizer.enable_padding(length=512, pad_token="[PAD]", pad_id=0)
tokenizer.enable_truncation(max_length=512)
outputs = tokenizer.encode_batch(texts)
```
**Brot**
```ocaml
let pad = Brot.padding ~pad_token:"[PAD]" ~pad_id:0 (`Fixed 512)
let trunc = Brot.truncation 512
let encodings =
Brot.encode_batch tokenizer ~padding:pad ~truncation:trunc texts
```
The key difference: Brot passes these as arguments, so different calls can use different settings without mutating the tokenizer.
---
## 7. Sentence Pairs
**HuggingFace**
```python
# Single pair
output = tokenizer.encode("premise", "hypothesis")
output.type_ids # [0, 0, 0, 0, 1, 1, 1] (with BERT post-processor)
# Batch of pairs
outputs = tokenizer.encode_batch([("premise1", "hyp1"), ("premise2", "hyp2")])
```
**Brot**
```ocaml
(* Single pair *)
let enc = Brot.encode tokenizer ~pair:"hypothesis" "premise"
let type_ids = Encoding.type_ids enc (* 0s for first, 1s for second *)
(* Batch of pairs *)
let encodings =
Brot.encode_pairs_batch tokenizer
[ ("premise1", "hyp1"); ("premise2", "hyp2") ]
```
Brot uses the `~pair` optional argument on `encode` for single pairs and a dedicated `encode_pairs_batch` for batches, instead of overloading the same function with tuples.
---
## 8. Special Tokens
### 8.1 Defining special tokens
**HuggingFace**
```python
from tokenizers import AddedToken
tokenizer.add_special_tokens([
AddedToken("[CLS]", single_word=False, lstrip=False, rstrip=False),
AddedToken("[SEP]", single_word=False, lstrip=False, rstrip=False),
AddedToken("[PAD]", single_word=False, lstrip=False, rstrip=False),
])
```
**Brot**
```ocaml
let tokenizer =
Brot.bpe
~specials:[
Brot.special "[CLS]";
Brot.special "[SEP]";
Brot.special "[PAD]";
]
~pad_token:"[PAD]"
~bos_token:"[CLS]"
~eos_token:"[SEP]"
()
```
In HuggingFace, special tokens are added after construction. In Brot, they are part of construction since tokenizers are immutable. The `special` function accepts optional `~single_word`, `~lstrip`, `~rstrip`, and `~normalized` parameters matching `AddedToken`.
### 8.2 Role tokens
**HuggingFace**
```python
tokenizer.pad_token # "[PAD]"
tokenizer.cls_token # "[CLS]"
tokenizer.sep_token # "[SEP]"
tokenizer.unk_token # "[UNK]"
```
**Brot**
```ocaml
let pad = Brot.pad_token tokenizer (* string option *)
let bos = Brot.bos_token tokenizer (* string option *)
let eos = Brot.eos_token tokenizer (* string option *)
let unk = Brot.unk_token tokenizer (* string option *)
```
Brot uses `bos_token`/`eos_token` instead of `cls_token`/`sep_token` since these are model-agnostic roles. They return `option` instead of raising on missing tokens.
### 8.3 Special tokens mask
Both libraries provide a mask distinguishing special tokens from content tokens in the encoding:
**HuggingFace**
```python
output.special_tokens_mask # [1, 0, 0, 0, 1]
```
**Brot**
```ocaml
let mask = Encoding.special_tokens_mask enc (* int array: 1 for special, 0 for content *)
```
---
## 9. Pipeline Components
Both libraries use the same four-stage pipeline: normalizer, pre-tokenizer, post-processor, decoder. The difference is how they are configured.
### 9.1 Normalizer
**HuggingFace**
```python
from tokenizers import normalizers
tokenizer.normalizer = normalizers.Sequence([
normalizers.NFD(),
normalizers.StripAccents(),
normalizers.Lowercase(),
])
```
**Brot**
```ocaml
let norm =
Normalizer.sequence
[ Normalizer.nfd; Normalizer.strip_accents; Normalizer.lowercase ]
let tokenizer = Brot.bpe ~normalizer:norm ()
```
Common normalizers:
| HuggingFace | Brot |
| ----------------------------------- | ------------------------------------------ |
| `normalizers.NFC()` | `Normalizer.nfc` |
| `normalizers.NFD()` | `Normalizer.nfd` |
| `normalizers.NFKC()` | `Normalizer.nfkc` |
| `normalizers.NFKD()` | `Normalizer.nfkd` |
| `normalizers.Lowercase()` | `Normalizer.lowercase` |
| `normalizers.StripAccents()` | `Normalizer.strip_accents` |
| `normalizers.Strip()` | `Normalizer.strip ()` |
| `normalizers.Replace(pattern, rep)` | `Normalizer.replace ~pattern ~replacement` |
| `normalizers.Prepend(s)` | `Normalizer.prepend s` |
| `normalizers.BertNormalizer()` | `Normalizer.bert ()` |
| `normalizers.ByteLevel()` | `Normalizer.byte_level ()` |
| `normalizers.Sequence([...])` | `Normalizer.sequence [...]` |
### 9.2 Pre-tokenizer
**HuggingFace**
```python
from tokenizers import pre_tokenizers
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
pre_tokenizers.WhitespaceSplit(),
pre_tokenizers.Punctuation(),
])
```
**Brot**
```ocaml
let pre =
Pre_tokenizer.sequence
[ Pre_tokenizer.whitespace_split ();
Pre_tokenizer.punctuation () ]
let tokenizer = Brot.bpe ~pre ()
```
Common pre-tokenizers:
| HuggingFace | Brot |
| -------------------------------------- | ----------------------------------- |
| `pre_tokenizers.Whitespace()` | `Pre_tokenizer.whitespace ()` |
| `pre_tokenizers.WhitespaceSplit()` | `Pre_tokenizer.whitespace_split ()` |
| `pre_tokenizers.BertPreTokenizer()` | `Pre_tokenizer.bert ()` |
| `pre_tokenizers.ByteLevel()` | `Pre_tokenizer.byte_level ()` |
| `pre_tokenizers.Punctuation()` | `Pre_tokenizer.punctuation ()` |
| `pre_tokenizers.Digits()` | `Pre_tokenizer.digits ()` |
| `pre_tokenizers.Metaspace()` | `Pre_tokenizer.metaspace ()` |
| `pre_tokenizers.UnicodeScripts()` | `Pre_tokenizer.unicode_scripts ()` |
| `pre_tokenizers.CharDelimiterSplit(c)` | `Pre_tokenizer.char_delimiter c` |
| `pre_tokenizers.Split(pattern, ...)` | `Pre_tokenizer.split ~pattern ()` |
| `pre_tokenizers.Sequence([...])` | `Pre_tokenizer.sequence [...]` |
### 9.3 Post-processor
**HuggingFace**
```python
from tokenizers import processors
tokenizer.post_processor = processors.BertProcessing(
sep=("[SEP]", 102),
cls=("[CLS]", 101),
)
```
**Brot**
```ocaml
let post =
Post_processor.bert
~sep:("[SEP]", 102)
~cls:("[CLS]", 101)
()
let tokenizer = Brot.bpe ~post ()
```
Common post-processors:
| HuggingFace | Brot |
| ------------------------------------------------------------- | ---------------------------------------------------------- |
| `processors.BertProcessing(sep, cls)` | `Post_processor.bert ~sep ~cls ()` |
| `processors.RobertaProcessing(sep, cls)` | `Post_processor.roberta ~sep ~cls ()` |
| `processors.ByteLevel()` | `Post_processor.byte_level ()` |
| `processors.TemplateProcessing(single, pair, special_tokens)` | `Post_processor.template ~single ?pair ~special_tokens ()` |
| `processors.Sequence([...])` | `Post_processor.sequence [...]` |
### 9.4 Decoder
**HuggingFace**
```python
from tokenizers import decoders
tokenizer.decoder = decoders.WordPiece(prefix="##")
```
**Brot**
```ocaml
let dec = Decoder.wordpiece ~prefix:"##" ()
let tokenizer = Brot.wordpiece ~decoder:dec ()
```
Common decoders:
| HuggingFace | Brot |
| ------------------------------- | --------------------------------- |
| `decoders.BPEDecoder(suffix)` | `Decoder.bpe ~suffix ()` |
| `decoders.ByteLevel()` | `Decoder.byte_level ()` |
| `decoders.ByteFallback()` | `Decoder.byte_fallback ()` |
| `decoders.WordPiece(prefix)` | `Decoder.wordpiece ~prefix ()` |
| `decoders.Metaspace()` | `Decoder.metaspace ()` |
| `decoders.CTC()` | `Decoder.ctc ()` |
| `decoders.Replace(pattern, by)` | `Decoder.replace ~pattern ~by ()` |
| `decoders.Strip()` | `Decoder.strip ()` |
| `decoders.Fuse()` | `Decoder.fuse ()` |
| `decoders.Sequence([...])` | `Decoder.sequence [...]` |
### 9.5 Inspecting the pipeline
**HuggingFace**
```python
tokenizer.normalizer
tokenizer.pre_tokenizer
tokenizer.post_processor
tokenizer.decoder
```
**Brot**
```ocaml
let norm = Brot.normalizer tokenizer (* Normalizer.t option *)
let pre = Brot.pre_tokenizer tokenizer (* Pre_tokenizer.t option *)
let post = Brot.post_processor tokenizer (* Post_processor.t option *)
let dec = Brot.decoder tokenizer (* Decoder.t option *)
```
Brot returns `option` for each stage, since any stage can be absent.
---
## 10. Training Tokenizers
### 10.1 BPE training
**HuggingFace**
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(
vocab_size=30000,
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]"],
)
tokenizer.train(["corpus.txt"], trainer)
```
**Brot**
```ocaml
let tokenizer =
Brot.train_bpe
(`Files [ "corpus.txt" ])
~vocab_size:30000
~min_frequency:2
~specials:[
Brot.special "[UNK]";
Brot.special "[CLS]";
Brot.special "[SEP]";
Brot.special "[PAD]";
]
~unk_token:"[UNK]"
~pad_token:"[PAD]"
```
Brot combines the `Tokenizer` + `Trainer` pattern into a single function call. Training data is passed as `` `Files `` (file paths) or `` `Seq `` (string sequence).
### 10.2 WordPiece training
**HuggingFace**
```python
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
trainer = WordPieceTrainer(vocab_size=30000, special_tokens=["[UNK]", "[PAD]"])
tokenizer.train(["corpus.txt"], trainer)
```
**Brot**
```ocaml
let tokenizer =
Brot.train_wordpiece
(`Files [ "corpus.txt" ])
~vocab_size:30000
~unk_token:"[UNK]"
~specials:[ Brot.special "[UNK]"; Brot.special "[PAD]" ]
~pad_token:"[PAD]"
```
### 10.3 Unigram training
**HuggingFace**
```python
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
tokenizer = Tokenizer(Unigram())
trainer = UnigramTrainer(vocab_size=8000, special_tokens=["<unk>", "<pad>"])
tokenizer.train(["corpus.txt"], trainer)
```
**Brot**
```ocaml
let tokenizer =
Brot.train_unigram
(`Files [ "corpus.txt" ])
~vocab_size:8000
~unk_token:"<unk>"
~specials:[ Brot.special "<unk>"; Brot.special "<pad>" ]
~pad_token:"<pad>"
```
### 10.4 Training from in-memory data
**HuggingFace**
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=1000)
tokenizer.train_from_iterator(
["Hello world", "How are you?", "Hello again"],
trainer,
)
```
**Brot**
```ocaml
let texts = [ "Hello world"; "How are you?"; "Hello again" ]
let tokenizer =
Brot.train_bpe (`Seq (List.to_seq texts)) ~vocab_size:1000
```
### 10.5 Extending an existing tokenizer
**HuggingFace**
```python
# Load, then retrain with more data
tokenizer = Tokenizer.from_file("tokenizer.json")
trainer = BpeTrainer(vocab_size=50000)
tokenizer.train(["more_data.txt"], trainer)
```
**Brot**
```ocaml
let base = Brot.from_file "tokenizer.json" |> Result.get_ok
let tokenizer =
Brot.train_bpe ~init:base (`Files [ "more_data.txt" ]) ~vocab_size:50000
```
The `~init` parameter on training functions lets you extend an existing tokenizer with additional data.
---
## 11. Vocabulary Inspection
**HuggingFace**
```python
tokenizer.get_vocab() # dict: token -> id
tokenizer.get_vocab_size() # int
tokenizer.token_to_id("[CLS]") # int or None
tokenizer.id_to_token(101) # str or None
```
**Brot**
```ocaml
let v = Brot.vocab tokenizer (* (string * int) list *)
let size = Brot.vocab_size tokenizer (* int *)
let id = Brot.token_to_id tokenizer "[CLS]" (* int option *)
let token = Brot.id_to_token tokenizer 101 (* string option *)
```
`vocab` returns an association list instead of a dictionary. `token_to_id` and `id_to_token` return `option` instead of nullable values.
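If you want dictionary-style lookup like `get_vocab()` provides, converting the association list to a `Hashtbl` is straightforward. A sketch with a toy vocabulary standing in for the output of `Brot.vocab`:

```ocaml
(* Toy stand-in for the association list returned by Brot.vocab. *)
let vocab = [ ("[CLS]", 101); ("[SEP]", 102); ("hello", 7592) ]

(* Build a hash table for O(1) token -> id lookup. *)
let table = Hashtbl.create (List.length vocab)
let () = List.iter (fun (tok, id) -> Hashtbl.add table tok id) vocab

let id = Hashtbl.find_opt table "[CLS]" (* Some 101 *)
```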
---
## 12. Quick Cheat Sheet
| Task | HuggingFace Tokenizers | Brot |
| ------------------- | ----------------------------------------------------------- | ---------------------------------------------------------------- |
| Load from file | `Tokenizer.from_file("tokenizer.json")` | `Brot.from_file "tokenizer.json"` |
| Save to file | `tokenizer.save("tokenizer.json")` | `Brot.save_pretrained tokenizer ~path:"./out"` |
| Encode text | `tokenizer.encode("Hello")` | `Brot.encode tokenizer "Hello"` |
| Encode IDs only | `tokenizer.encode("Hello").ids` | `Brot.encode_ids tokenizer "Hello"` |
| Encode batch | `tokenizer.encode_batch(["a", "b"])` | `Brot.encode_batch tokenizer ["a"; "b"]` |
| Encode pair | `tokenizer.encode("a", "b")` | `Brot.encode tokenizer ~pair:"b" "a"` |
| Encode pairs batch | `tokenizer.encode_batch([("a","b"), ...])` | `Brot.encode_pairs_batch tokenizer [("a","b"); ...]` |
| Decode | `tokenizer.decode(ids)` | `Brot.decode tokenizer ids` |
| Decode batch | `tokenizer.decode_batch([ids1, ids2])` | `Brot.decode_batch tokenizer [ids1; ids2]` |
| Get token IDs | `output.ids` | `Encoding.ids enc` |
| Get tokens | `output.tokens` | `Encoding.tokens enc` |
| Get attention mask | `output.attention_mask` | `Encoding.attention_mask enc` |
| Get type IDs | `output.type_ids` | `Encoding.type_ids enc` |
| Get offsets | `output.offsets` | `Encoding.offsets enc` |
| Padding             | `tokenizer.enable_padding(length=128)`                       | ``Brot.encode tokenizer ~padding:(Brot.padding (`Fixed 128)) ...`` |
| Truncation | `tokenizer.enable_truncation(max_length=512)` | `Brot.encode tokenizer ~truncation:(Brot.truncation 512) ...` |
| Vocab size | `tokenizer.get_vocab_size()` | `Brot.vocab_size tokenizer` |
| Token to ID | `tokenizer.token_to_id("[CLS]")` | `Brot.token_to_id tokenizer "[CLS]"` |
| ID to token | `tokenizer.id_to_token(101)` | `Brot.id_to_token tokenizer 101` |
| Train BPE           | `tokenizer.train(files, BpeTrainer(...))`                    | ``Brot.train_bpe (`Files files) ~vocab_size:30000``              |
| Train WordPiece     | `tokenizer.train(files, WordPieceTrainer(...))`              | ``Brot.train_wordpiece (`Files files) ~vocab_size:30000``        |
| Train Unigram       | `tokenizer.train(files, UnigramTrainer(...))`                | ``Brot.train_unigram (`Files files) ~vocab_size:8000``           |
| Train from iterator | `tokenizer.train_from_iterator(iter, trainer)`               | ``Brot.train_bpe (`Seq seq) ~vocab_size:1000``                   |
| Set normalizer | `tokenizer.normalizer = normalizers.Lowercase()` | `Brot.bpe ~normalizer:Normalizer.lowercase ()` |
| Set pre-tokenizer | `tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()` | `Brot.bpe ~pre:(Pre_tokenizer.byte_level ()) ()` |
| Set post-processor | `tokenizer.post_processor = processors.BertProcessing(...)` | `Brot.bpe ~post:(Post_processor.bert ~sep ~cls ()) ()` |
| Set decoder | `tokenizer.decoder = decoders.WordPiece()` | `Brot.bpe ~decoder:(Decoder.wordpiece ()) ()` |
| Add special tokens | `tokenizer.add_special_tokens([AddedToken(...)])` | Pass `~specials:[Brot.special "..."; ...]` at construction |
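Several of the rows above compose naturally. A minimal sketch (assuming Brot is installed and a `tokenizer.json` exists in the working directory) that loads a pretrained tokenizer, encodes with fixed padding, and decodes:

```ocaml
open Brot

let () =
  match from_file "tokenizer.json" with
  | Error _ -> prerr_endline "could not load tokenizer.json"
  | Ok tok ->
      let enc = encode tok ~padding:(padding (`Fixed 16)) "Hello world" in
      print_endline (decode tok (Encoding.ids enc))
```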
================================================
FILE: packages/brot/doc/dune
================================================
(mdx
(files *.md)
(package brot)
(libraries brot))
================================================
FILE: packages/brot/doc/index.md
================================================
# Brot
Brot tokenizes text into token IDs for language models and reverses the
process. It supports BPE, WordPiece, Unigram, word-level, and
character-level algorithms, loads and saves HuggingFace `tokenizer.json`
files, and is 1.3-6x faster than HuggingFace tokenizers on most
benchmarks.
## Features
- **Tokenization algorithms**: BPE, WordPiece, Unigram, word-level, character-level
- **HuggingFace compatible**: load and save `tokenizer.json`, load vocab/merges model files
- **Composable pipeline**: normalizer, pre-tokenizer, post-processor, decoder — each stage independently configurable
- **Rich encoding output**: token IDs, string tokens, byte offsets, attention masks, type IDs, word IDs, special token masks
- **Training**: train BPE, WordPiece, Unigram, and word-level tokenizers from scratch
- **Performance**: 1.3-6x faster than HuggingFace's Rust-native tokenizers on most benchmarks
## Quick Start
Build a BPE tokenizer from a vocabulary and merge rules, encode text,
and decode it back:
```ocaml
open Brot
let tokenizer =
bpe
~vocab:
[ ("h", 0); ("e", 1); ("l", 2); ("o", 3); (" ", 4); ("w", 5);
("r", 6); ("d", 7); ("he", 8); ("ll", 9); ("llo", 10);
("hello", 11); ("wo", 12); ("rl", 13); ("rld", 14); ("world", 15) ]
~merges:
[ ("h", "e"); ("l", "l"); ("ll", "o"); ("he", "llo");
("w", "o"); ("r", "l"); ("rl", "d"); ("wo", "rld") ]
()
let encoding = encode tokenizer "hello world"
let ids = Encoding.ids encoding (* [| 11; 4; 15 |] *)
let tokens = Encoding.tokens encoding (* [| "hello"; " "; "world" |] *)
let decoded = decode tokenizer ids (* "hello world" *)
```
Load a pretrained tokenizer from a HuggingFace `tokenizer.json` file:
```ocaml
open Brot
let tokenizer = from_file "tokenizer.json" |> Result.get_ok
let encoding = encode tokenizer "Hello world!"
let ids = Encoding.ids encoding
```
Train a tokenizer from a text corpus:
```ocaml
open Brot
let tokenizer =
train_bpe ~vocab_size:100 ~show_progress:false
(`Seq (List.to_seq
[ "The quick brown fox jumps over the lazy dog";
"The dog barked at the fox";
"Quick brown foxes are rare" ]))
let size = vocab_size tokenizer
let ids = encode_ids tokenizer "The quick fox"
```
## Next Steps
- [Getting Started](01-getting-started/) — encode, decode, pipeline basics, training
- [The Tokenization Pipeline](02-pipeline/) — how the 5 pipeline stages work
- [Pretrained Tokenizers](03-pretrained/) — loading, saving, and building known model pipelines
- [Batch Processing](04-batch-processing/) — padding, truncation, encoding metadata
- [Choosing an Algorithm](05-algorithms/) — BPE vs WordPiece vs Unigram and when to use each
================================================
FILE: packages/brot/examples/01-encode-decode/README.md
================================================
# `01-encode-decode`
Your first tokenizer. This example shows the minimal steps to encode text into
token IDs and decode back.
```bash
dune exec brot/examples/01-encode-decode/main.exe
```
## What You'll Learn
- Creating a BPE tokenizer with `Brot.bpe`
- Encoding text with `Brot.encode`
- Inspecting token strings and IDs with `Encoding.tokens` and `Encoding.ids`
- Decoding token IDs back to text with `Brot.decode`
## Key Functions
| Function | Purpose |
| ----------------- | ------------------------------------------------------ |
| `bpe` | Create a BPE tokenizer from vocabulary and merge rules |
| `encode` | Encode text into an `Encoding.t` |
| `Encoding.ids` | Get the integer token IDs |
| `Encoding.tokens` | Get the string token representations |
| `decode` | Convert token IDs back to text |
## How BPE Works
BPE (Byte Pair Encoding) iteratively merges the most frequent character pairs.
Given the text `"hello"` and merge rules like `("h","e")`, `("l","l")`,
`("ll","o")`, `("he","llo")`, BPE applies merges in priority order until no
merge applies, producing `"hello"` as a single token.
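The merge loop above can be sketched in plain OCaml, independent of Brot's internals. The assumption here is that merges are listed in priority order (earlier in the list = applied first), which is how the example's merge rules are written:

```ocaml
(* Naive BPE merge application: repeatedly merge the adjacent pair with
   the highest-priority (lowest-rank) merge rule until none applies. *)
let apply_merges merges chars =
  (* rank of a pair in the merge list, if any *)
  let rank a b =
    let rec find i = function
      | [] -> None
      | (x, y) :: _ when x = a && y = b -> Some i
      | _ :: rest -> find (i + 1) rest
    in
    find 0 merges
  in
  let rec step tokens =
    (* locate the adjacent pair with the best (lowest) merge rank *)
    let rec best i acc = function
      | a :: (b :: _ as rest) ->
          let acc =
            match (rank a b, acc) with
            | Some r, None -> Some (i, r)
            | Some r, Some (_, r') when r < r' -> Some (i, r)
            | _ -> acc
          in
          best (i + 1) acc rest
      | _ -> acc
    in
    match best 0 None tokens with
    | None -> tokens (* no merge applies: done *)
    | Some (at, _) ->
        let rec merge j = function
          | a :: b :: rest when j = at -> (a ^ b) :: rest
          | t :: rest -> t :: merge (j + 1) rest
          | [] -> []
        in
        step (merge 0 tokens)
  in
  step chars

let () =
  let merges = [ ("h", "e"); ("l", "l"); ("ll", "o"); ("he", "llo") ] in
  apply_merges merges [ "h"; "e"; "l"; "l"; "o" ]
  |> String.concat " " |> print_endline
  (* prints "hello" *)
```

Real BPE implementations use priority queues instead of this quadratic scan, but the merge semantics are the same.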
## Try It
1. Remove some merge rules and run again to see how the text gets split into
smaller subword pieces.
2. Add a new word like `"held"` to the vocabulary and encode `"hello held"`.
## Next Steps
Continue to [02-encoding-fields](../02-encoding-fields/) to learn about all the
metadata in an encoding.
================================================
FILE: packages/brot/examples/01-encode-decode/dune
================================================
(executable
(name main)
(libraries brot))
================================================
FILE: packages/brot/examples/01-encode-decode/main.ml
================================================
(* Encode and decode.
The simplest possible tokenization: convert text to token IDs and back.
Demonstrates creating a BPE tokenizer from an inline vocabulary and merge
rules, encoding text, inspecting tokens and IDs, and decoding. *)
open Brot
let () =
(* Build a small BPE tokenizer. The vocabulary maps token strings to IDs.
Merge rules define which character pairs to combine, in priority order. *)
let vocab =
[
("h", 0);
("e", 1);
("l", 2);
("o", 3);
(" ", 4);
("w", 5);
("r", 6);
("d", 7);
("he", 8);
("ll", 9);
("llo", 10);
("hello", 11);
("wo", 12);
("rl", 13);
("rld", 14);
("world", 15);
]
in
let merges =
[
("h", "e");
("l", "l");
("ll", "o");
("he", "llo");
("w", "o");
("r", "l");
("rl", "d");
("wo", "rld");
]
in
let tokenizer = bpe ~vocab ~merges () in
(* Encode text into an Encoding *)
let text = "hello world" in
let encoding = encode tokenizer text in
let ids = Encoding.ids encoding in
let tokens = Encoding.tokens encoding in
Printf.printf "Text: %S\n" text;
Printf.printf "Tokens: [%s]\n"
(String.concat "; "
(List.map (fun s -> Printf.sprintf "%S" s) (Array.to_list tokens)));
Printf.printf "IDs: [%s]\n"
(String.concat "; " (Array.to_list (Array.map string_of_int ids)));
(* Decode token IDs back to text *)
let decoded = decode tokenizer ids in
Printf.printf "Decoded: %S\n\n" decoded;
Printf.printf "Round-trip matches: %b\n\n" (String.equal text decoded);
  (* Encode a single word -- the merges collapse "hello" into one token *)
let text2 = "hello" in
let enc2 = encode tokenizer text2 in
Printf.printf "Text: %S\n" text2;
Printf.printf "Tokens: [%s]\n"
(String.concat "; "
(List.map
(fun s -> Printf.sprintf "%S" s)
(Array.to_list (Encoding.tokens enc2))));
Printf.printf "IDs: [%s]\n"
(String.concat "; "
(Array.to_list (Array.map string_of_int (Encoding.ids enc2))))
================================================
FILE: packages/brot/examples/02-encoding-fields/README.md
================================================
# `02-encoding-fields`
Understanding encodings. An `Encoding.t` bundles token IDs with alignment
metadata: byte offsets, word indices, type IDs, attention masks, and
special-token flags.
```bash
dune exec brot/examples/02-encoding-fields/main.exe
```
## What You'll Learn
- All parallel arrays in an `Encoding.t` and how they align
- Byte offsets that map each token back to the original text
- Word indices that group subword tokens by source word
- Attention mask (1 = real token, 0 = padding)
- Special tokens mask (1 = special, 0 = content)
## Key Functions
| Function | Purpose |
| ------------------------------ | ------------------------------------------------- |
| `Encoding.ids` | Token ID array for model input |
| `Encoding.tokens` | String representation of each token |
| `Encoding.offsets` | `(start, end)` byte spans in the original text |
| `Encoding.word_ids` | Source word index per token (`None` for specials) |
| `Encoding.type_ids` | Segment IDs (0 or 1 for sentence pairs) |
| `Encoding.attention_mask` | 1 for real tokens, 0 for padding |
| `Encoding.special_tokens_mask` | 1 for special tokens, 0 for content |
| `Encoding.length` | Number of tokens |
## Offsets
Offsets are byte positions `(start, stop)` into the original text. You can
extract the original substring with `String.sub text start (stop - start)`.
This is essential for highlighting, named entity recognition, and other tasks
that need to map tokens back to source text.
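The extraction above is a one-liner. A self-contained sketch, where the offsets are a hand-written stand-in for what `Encoding.offsets` returns (a BPE tokenizer splitting `"hello world"` into `"hello"`, `" "`, `"world"`):

```ocaml
(* Recover each token's source span from its byte offsets. *)
let () =
  let text = "hello world" in
  let offsets = [ (0, 5); (5, 6); (6, 11) ] in
  List.iter
    (fun (start, stop) ->
      Printf.printf "(%d, %d) -> %S\n" start stop
        (String.sub text start (stop - start)))
    offsets
```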
## Try It
1. Add more words to the vocabulary and encode a longer sentence.
2. Encode a text with unknown words and observe the `[UNK]` token.
## Next Steps
Continue to [03-normalizers](../03-normalizers/) to learn how text is cleaned
before tokenization.
================================================
FILE: packages/brot/examples/02-encoding-fields/dune
================================================
(executable
(name main)
(libraries brot))
================================================
FILE: packages/brot/examples/02-encoding-fields/main.ml
================================================
(* Understanding encodings.
An Encoding bundles token IDs with alignment metadata: byte offsets, word
indices, segment type IDs, attention masks, and special-token flags. All
arrays share the same length. *)
open Brot
let print_encoding enc =
let ids = Encoding.ids enc in
let tokens = Encoding.tokens enc in
let offsets = Encoding.offsets enc in
let word_ids = Encoding.word_ids enc in
let type_ids = Encoding.type_ids enc in
let attn = Encoding.attention_mask enc in
let special = Encoding.special_tokens_mask enc in
Printf.printf "%-6s %-10s %-4s %-12s %-8s %-8s %-6s %-8s\n" "Index" "Token"
"ID" "Offsets" "Word_ID" "Type_ID" "Attn" "Special";
Printf.printf "%s\n" (String.make 66 '-');
for i = 0 to Encoding.length enc - 1 do
let s, e = offsets.(i) in
let word =
match word_ids.(i) with Some w -> string_of_int w | None -> "-"
in
Printf.printf "%-6d %-10s %-4d (%2d, %2d) %-8s %-8d %-6d %-8d\n" i
tokens.(i) ids.(i) s e word type_ids.(i) attn.(i) special.(i)
done
let () =
(* Word-level tokenizer: each word maps to one token *)
let vocab =
[
("[UNK]", 0);
("hello", 1);
("world", 2);
("the", 3);
("is", 4);
("great", 5);
]
in
let tokenizer =
word_level ~vocab ~unk_token:"[UNK]" ~pre:(Pre_tokenizer.whitespace ()) ()
in
let text = "hello world is great" in
Printf.printf "Text: %S\n" text;
Printf.printf "Length: %d tokens\n\n"
(Encoding.length (encode tokenizer text));
print_encoding (encode tokenizer text);
(* Show what happens with unknown words *)
Printf.printf "\n--- Unknown words ---\n\n";
let text2 = "hello universe" in
Printf.printf "Text: %S\n" text2;
Printf.printf "Length: %d tokens\n\n"
(Encoding.length (encode tokenizer text2));
print_encoding (encode tokenizer text2);
(* WordPiece: subword tokens have word_ids linking to the source word *)
Printf.printf "\n--- Subword tokens (WordPiece) ---\n\n";
let wp_vocab =
[
("[UNK]", 0);
("play", 1);
("##ing", 2);
("##ed", 3);
("un", 4);
("##happy", 5);
]
in
let wp = wordpiece ~vocab:wp_vocab ~unk_token:"[UNK]" () in
let text3 = "playing" in
Printf.printf "Text: %S\n" text3;
Printf.printf "Length: %d tokens\n\n" (Encoding.length (encode wp text3));
print_encoding (encode wp text3)
================================================
FILE: packages/brot/examples/03-normalizers/README.md
================================================
# `03-normalizers`
Text normalization before tokenization. Normalizers clean and standardize text
so that surface variations (case, accents, whitespace) don't prevent vocabulary
matches.
```bash
dune exec brot/examples/03-normalizers/main.exe
```
## What You'll Learn
- Unicode normalization: `nfc`, `nfkc`
- Text transforms: `lowercase`, `strip_accents`, `strip`, `replace`, `prepend`
- Model-specific normalization: `bert`
- Composing normalizers with `sequence`
- Applying normalizers directly with `Normalizer.apply`
- How normalization affects tokenization results
## Key Functions
| Function | Purpose |
| -------------------------- | ---------------------------------- |
| `Normalizer.nfc` / `nfkc` | Unicode normalization forms |
| `Normalizer.lowercase` | Unicode case folding |
| `Normalizer.strip_accents` | Remove combining marks |
| `Normalizer.strip` | Strip boundary whitespace |
| `Normalizer.replace` | Regex-based replacement |
| `Normalizer.prepend` | Prepend a string to non-empty text |
| `Normalizer.bert` | BERT-specific normalizer |
| `Normalizer.sequence` | Compose normalizers left-to-right |
| `Normalizer.apply` | Apply a normalizer to a string |
## Why Normalize?
Without normalization, `"Hello"`, `"hello"`, and `"HELLO"` are three different
tokens. Normalization maps them all to `"hello"` so a single vocabulary entry
covers all cases. Similarly, `"caf\u{00E9}"` and `"cafe"` can be unified by
stripping accents.
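A minimal sketch of the unification (assuming Brot is installed): one lowercasing normalizer maps all three surface forms to the same string before vocabulary lookup.

```ocaml
open Brot

let () =
  (* All three print "hello" after normalization *)
  List.iter
    (fun s -> print_endline (Normalizer.apply Normalizer.lowercase s))
    [ "Hello"; "hello"; "HELLO" ]
```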
## Try It
1. Add `Normalizer.nfkd` and see how it differs from `nfd`.
2. Create a normalizer that replaces email addresses with a placeholder token.
3. Try the BERT normalizer with Chinese characters.
## Next Steps
Continue to [04-pre-tokenizers](../04-pre-tokenizers/) to learn how text is
split into fragments before vocabulary lookup.
================================================
FILE: packages/brot/examples/03-normalizers/dune
================================================
(executable
(name main)
(libraries brot))
================================================
FILE: packages/brot/examples/03-normalizers/main.ml
================================================
(* Text normalization.
Normalizers transform text before tokenization: lowercasing, accent removal,
Unicode normalization, whitespace cleanup, and model-specific preprocessing.
They are the first stage in the tokenization pipeline. *)
open Brot
let show name norm text =
let result = Normalizer.apply norm text in
Printf.printf " %-20s %S -> %S\n" name text result
let () =
Printf.printf "=== Unicode Normalization ===\n\n";
show "nfc" Normalizer.nfc "caf\xc3\xa9";
show "nfkc" Normalizer.nfkc "\xef\xac\x81";
(* fi ligature -> fi *)
Printf.printf "\n=== Text Transforms ===\n\n";
show "lowercase" Normalizer.lowercase "Hello WORLD";
show "strip_accents" Normalizer.strip_accents
"caf\xc3\xa9 r\xc3\xa9sum\xc3\xa9";
show "strip" (Normalizer.strip ()) " hello ";
show "replace"
(Normalizer.replace ~pattern:"\\d+" ~replacement:"")
"I have 42 apples and 3 oranges";
show "prepend" (Normalizer.prepend ">> ") "hello";
Printf.printf "\n=== Model-specific ===\n\n";
show "bert (default)" (Normalizer.bert ()) "Hello WORLD!";
show "bert (no lower)" (Normalizer.bert ~lowercase:false ()) "Hello WORLD!";
Printf.printf "\n=== Composition ===\n\n";
let composed =
Normalizer.sequence
[ Normalizer.nfd; Normalizer.strip_accents; Normalizer.lowercase ]
in
show "nfd+strip+lower" composed "Caf\xc3\xa9 R\xc3\xa9sum\xc3\xa9";
show "nfd+strip+lower" composed "HELLO";
Printf.printf "\n=== Effect on Tokenization ===\n\n";
let vocab =
[ ("hello", 0); ("world", 1); ("cafe", 2); ("resume", 3); ("<unk>", 4) ]
in
let no_norm =
word_level ~vocab ~unk_token:"<unk>" ~pre:(Pre_tokenizer.whitespace ()) ()
in
let with_norm =
word_level ~vocab ~unk_token:"<unk>"
~pre:(Pre_tokenizer.whitespace ())
~normalizer:composed ()
in
let text = "HELLO Caf\xc3\xa9" in
let enc1 = encode no_norm text in
let enc2 = encode with_norm text in
Printf.printf " Text: %S\n" text;
Printf.printf " Without normalizer: [%s]\n"
(String.concat "; "
(List.map
(fun s -> Printf.sprintf "%S" s)
(Array.to_list (Encoding.tokens enc1))));
Printf.printf " With normalizer: [%s]\n"
(String.concat "; "
(List.map
(fun s -> Printf.sprintf "%S" s)
(Array.to_list (Encoding.tokens enc2))))
================================================
FILE: packages/brot/examples/04-pre-tokenizers/README.md
================================================
# `04-pre-tokenizers`
Pre-tokenization: splitting text into fragments before vocabulary lookup. Each
fragment carries byte offsets into the original text.
```bash
dune exec brot/examples/04-pre-tokenizers/main.exe
```
## What You'll Learn
- Common pre-tokenizers: `whitespace`, `whitespace_split`, `bert`
- Punctuation and digit handling
- Delimiter-based splitting: `char_delimiter`, `split`, `fixed_length`
- SentencePiece-style `metaspace`
- Composing pre-tokenizers with `sequence`
- Using `Pre_tokenizer.pre_tokenize` to see fragments and offsets
## Key Functions
| Function | Purpose |
| -------------------------------- | ------------------------------------------ |
| `Pre_tokenizer.whitespace` | Pattern-based: `\w+` and `[^\w\s]+` groups |
| `Pre_tokenizer.whitespace_split` | Simple whitespace splitting |
| `Pre_tokenizer.bert` | BERT-style: whitespace + punctuation + CJK |
| `Pre_tokenizer.punctuation` | Isolate punctuation from words |
| `Pre_tokenizer.digits` | Split on digit boundaries |
| `Pre_tokenizer.char_delimiter` | Split on a single character |
| `Pre_tokenizer.split` | Split on a literal string pattern |
| `Pre_tokenizer.fixed_length` | Fixed-length character chunks |
| `Pre_tokenizer.metaspace` | Replace spaces with visible markers |
| `Pre_tokenizer.sequence` | Chain pre-tokenizers left-to-right |
| `Pre_tokenizer.pre_tokenize` | Apply and get `(fragment, offsets)` list |
## Pre-tokenizer vs Tokenizer
Pre-tokenization happens *before* the vocabulary-based algorithm (BPE,
WordPiece, etc.). It determines the boundaries within which subword splitting
operates. For example, with whitespace pre-tokenization, BPE will never merge
tokens across word boundaries.
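You can observe the boundaries directly (assuming Brot is installed): `pre_tokenize` returns the fragments the subword algorithm will see, each with its byte span.

```ocaml
open Brot

let () =
  (* Subword merging will happen only inside these fragments *)
  Pre_tokenizer.pre_tokenize (Pre_tokenizer.whitespace ()) "new york"
  |> List.iter (fun (fragment, (s, e)) ->
         Printf.printf "%S spans (%d, %d)\n" fragment s e)
```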
## Try It
1. Try `unicode_scripts` on text mixing Latin and CJK characters.
2. Change the punctuation `behavior` to `` `Merged_with_previous `` or
`` `Removed ``.
3. Create a pre-tokenizer that splits on hyphens.
## Next Steps
Continue to [05-algorithms](../05-algorithms/) to see how different
tokenization algorithms split the same text.
================================================
FILE: packages/brot/examples/04-pre-tokenizers/dune
================================================
(executable
(name main)
(libraries brot))
================================================
FILE: packages/brot/examples/04-pre-tokenizers/main.ml
================================================
(* Pre-tokenization.
Pre-tokenizers split text into fragments before vocabulary-based
tokenization. Each fragment carries byte offsets into the original text.
Different strategies produce different splits, affecting how subword
algorithms see the input. *)
open Brot
let show name pre text =
let result = Pre_tokenizer.pre_tokenize pre text in
Printf.printf " %-24s %S\n" name text;
List.iter
(fun (fragment, (s, e)) -> Printf.printf " %S (%d, %d)\n" fragment s e)
result;
print_newline ()
let () =
let text = "Hello, world! It's 2026." in
Printf.printf "=== Common Pre-tokenizers ===\n\n";
Printf.printf "Text: %S\n\n" text;
show "whitespace" (Pre_tokenizer.whitespace ()) text;
show "whitespace_split" (Pre_tokenizer.whitespace_split ()) text;
show "bert" (Pre_tokenizer.bert ()) text;
show "punctuation" (Pre_tokenizer.punctuation ()) text;
show "digits (individual)"
(Pre_tokenizer.digits ~individual_digits:true ())
text;
show "digits (grouped)"
(Pre_tokenizer.digits ~individual_digits:false ())
text;
Printf.printf "=== Delimiter-based ===\n\n";
show "char_delimiter ','" (Pre_tokenizer.char_delimiter ',') "a,b,c";
show "split on '::'" (Pre_tokenizer.split ~pattern:"::" ()) "mod::func::arg";
show "fixed_length 3" (Pre_tokenizer.fixed_length 3) "abcdefgh";
show "metaspace" (Pre_tokenizer.metaspace ()) "Hello world today";
Printf.printf "=== Composition ===\n\n";
let composed =
Pre_tokenizer.sequence
[
Pre_tokenizer.whitespace_split ();
Pre_tokenizer.punctuation ~behavior:`Isolated ();
]
in
show "whitespace + punctuation" composed text
================================================
FILE: packages/brot/examples/05-algorithms/README.md
================================================
# `05-algorithms`
Five tokenization algorithms compared side-by-side. Each algorithm splits text
differently based on its strategy.
```bash
dune exec brot/examples/05-algorithms/main.exe
```
## What You'll Learn
- **BPE** (Byte Pair Encoding): merge-based subwords (GPT-2, RoBERTa)
- **WordPiece**: greedy longest-match with `##` prefix (BERT)
- **Unigram**: probabilistic segmentation (T5, mBART)
- **Word-level**: one token per word, no subword splitting
- **Character-level**: one token per byte, no vocabulary needed
## Key Functions
| Function | Purpose |
| ----------------- | -------------------------------------- |
| `Brot.bpe` | BPE tokenizer from vocab + merge rules |
| `Brot.wordpiece` | WordPiece tokenizer from vocab |
| `Brot.unigram` | Unigram tokenizer from vocab + scores |
| `Brot.word_level` | Word-level tokenizer from vocab |
| `Brot.chars` | Character-level tokenizer (no vocab) |
| `Brot.vocab_size` | Number of vocabulary entries |
## Algorithm Comparison
| Algorithm | Subwords? | Unknown handling | Vocabulary |
| ---------- | ------------------- | ------------------------ | ----------------------- |
| BPE | Yes (merges) | Falls back to characters | `(string * int) list` |
| WordPiece | Yes (`##` prefix) | `[UNK]` token | `(string * int) list` |
| Unigram | Yes (probabilistic) | Lowest-score fallback | `(string * float) list` |
| Word-level | No | `<unk>` token | `(string * int) list` |
| Chars | No | N/A (all bytes valid) | None needed |
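The vocabulary column matters in practice: Unigram takes log-probability scores where the others take integer IDs. A minimal sketch of the two shapes (assuming Brot is installed; single-entry vocabularies are for illustration only):

```ocaml
open Brot

(* Integer IDs for BPE (plus merge rules) ... *)
let _bpe = bpe ~vocab:[ ("a", 0) ] ~merges:[] ()

(* ... versus float log-probability scores for Unigram *)
let _uni = unigram ~vocab:[ ("a", -1.0) ] ()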
## Try It
1. Add more merge rules to the BPE tokenizer and see how it affects splitting.
2. Try encoding a word not in the WordPiece vocabulary.
3. Change the Unigram scores and observe how probabilities affect splitting.
## Next Steps
Continue to [06-special-tokens](../06-special-tokens/) to learn about special
tokens and post-processing.
================================================
FILE: packages/brot/examples/05-algorithms/dune
================================================
(executable
(name main)
(libraries brot))
================================================
FILE: packages/brot/examples/05-algorithms/main.ml
================================================
(* Tokenization algorithms.
Five algorithms compared side-by-side: BPE (merge-based), WordPiece (greedy
longest-match), Unigram (probabilistic), word-level (whole words), and
character-level (per-byte). Each splits text differently. *)
open Brot
let show name tokenizer text =
let encoding = encode tokenizer text in
let tokens = Encoding.tokens encoding in
let ids = Encoding.ids encoding in
Printf.printf " %-12s tokens=[%s] ids=[%s]\n" name
(String.concat ", "
(List.map (fun s -> Printf.sprintf "%S" s) (Array.to_list tokens)))
(String.concat ", " (Array.to_list (Array.map string_of_int ids)))
let () =
(* --- BPE: iterative merge-based subwords --- *)
let bpe_tok =
bpe
~vocab:
[
("p", 0);
("l", 1);
("a", 2);
("y", 3);
("i", 4);
("n", 5);
("g", 6);
("pl", 7);
("ay", 8);
("in", 9);
("ng", 10);
("play", 11);
("ing", 12);
("playing", 13);
]
~merges:
[
("p", "l");
("a", "y");
("i", "n");
("n", "g");
("pl", "ay");
("in", "g");
("play", "ing");
]
()
in
(* --- WordPiece: greedy longest-match with ## prefix --- *)
let wp_tok =
wordpiece
~vocab:
[
("[UNK]", 0);
("play", 1);
("##ing", 2);
("##ed", 3);
("run", 4);
("##ning", 5);
("un", 6);
("##known", 7);
]
~unk_token:"[UNK]" ()
in
(* --- Unigram: probabilistic segmentation --- *)
let uni_tok =
unigram
~vocab:
[
("playing", -0.5);
("play", -1.0);
("ing", -1.5);
("p", -3.0);
("l", -3.0);
("a", -3.0);
("y", -3.0);
("i", -3.0);
("n", -3.0);
("g", -3.0);
]
()
in
(* --- Word-level: whole words only --- *)
let wl_tok =
word_level
~vocab:[ ("playing", 0); ("hello", 1); ("<unk>", 2) ]
~unk_token:"<unk>"
~pre:(Pre_tokenizer.whitespace ())
()
in
(* --- Character-level: one byte per token --- *)
let char_tok = chars () in
Printf.printf "=== Encoding %S ===\n\n" "playing";
show "BPE" bpe_tok "playing";
show "WordPiece" wp_tok "playing";
show "Unigram" uni_tok "playing";
show "Word-level" wl_tok "playing";
show "Chars" char_tok "playing";
Printf.printf "\n=== Encoding %S ===\n\n" "running";
show "WordPiece" wp_tok "running";
show "Chars" char_tok "running";
Printf.printf "\n=== Encoding %S (unknown word) ===\n\n" "unknown";
show "WordPiece" wp_tok "unknown";
show "Word-level" wl_tok "unknown";
show "Chars" char_tok "unknown";
Printf.printf "\n=== Vocabulary sizes ===\n\n";
Printf.printf " BPE: %d\n" (vocab_size bpe_tok);
Printf.printf " WordPiece: %d\n" (vocab_size wp_tok);
Printf.printf " Unigram: %d\n" (vocab_size uni_tok);
Printf.printf " Word-level: %d\n" (vocab_size wl_tok);
Printf.printf " Chars: %d (byte range 0-255)\n" (vocab_size char_tok)
================================================
FILE: packages/brot/examples/06-special-tokens/README.md
================================================
# `06-special-tokens`
Special tokens and post-processing. Post-processors insert tokens like `[CLS]`
and `[SEP]` after tokenization, and assign type IDs for sentence-pair tasks.
```bash
dune exec brot/examples/06-special-tokens/main.exe
```
## What You'll Learn
- Defining special tokens with `Brot.special`
- BERT-style post-processing: `[CLS] A [SEP]` and `[CLS] A [SEP] B [SEP]`
- Sentence-pair encoding with `encode ~pair`
- Type IDs: 0 for first sequence, 1 for second
- Template-based post-processing for custom formats
- Skipping special tokens with `~add_special_tokens:false`
## Key Functions
| Function | Purpose |
| ------------------------------ | ------------------------------------------- |
| `Brot.special` | Define a special token configuration |
| `Post_processor.bert` | BERT-style `[CLS] A [SEP] B [SEP]` |
| `Post_processor.template` | Template-based with `$A`, `$B` placeholders |
| `Brot.encode ~pair` | Encode a sentence pair |
| `Encoding.type_ids` | Segment type IDs (0 or 1) |
| `Encoding.special_tokens_mask` | 1 for special tokens, 0 for content |
## BERT Post-processing
For a single sentence: `[CLS] tokens [SEP]`
For a sentence pair: `[CLS] A_tokens [SEP] B_tokens [SEP]`
Type IDs distinguish the two sequences:
- First sequence (including `[CLS]` and first `[SEP]`): type_id = 0
- Second sequence (including final `[SEP]`): type_id = 1
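The BERT layout above can also be written as an explicit template, which makes the type-ID assignment visible (a sketch assuming Brot is installed; the token IDs 1 and 2 must match your vocabulary):

```ocaml
open Brot

(* `$B:1` and `[SEP]:1` give the second segment type_id = 1 *)
let post =
  Post_processor.template ~single:"[CLS] $A [SEP]"
    ~pair:"[CLS] $A [SEP] $B:1 [SEP]:1"
    ~special_tokens:[ ("[CLS]", 1); ("[SEP]", 2) ]
    ()
```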
## Try It
1. Try the `roberta` post-processor with `<s>` and `</s>` tokens.
2. Create a custom template with different special tokens.
3. Encode a pair and check that `type_ids` correctly separates the segments.
## Next Steps
Continue to [07-padding-truncation](../07-padding-truncation/) to learn about
preparing batches with uniform sequence lengths.
================================================
FILE: packages/brot/examples/06-special-tokens/dune
================================================
(executable
(name main)
(libraries brot))
================================================
FILE: packages/brot/examples/06-special-tokens/main.ml
================================================
(* Special tokens and post-processing.
Special tokens like [CLS] and [SEP] are inserted by post-processors after
tokenization. They mark sequence boundaries and provide structure for model
input. Sentence-pair encoding assigns different type IDs to each sequence. *)
open Brot
let print_encoding enc =
let ids = Encoding.ids enc in
let tokens = Encoding.tokens enc in
let type_ids = Encoding.type_ids enc in
let special = Encoding.special_tokens_mask enc in
Printf.printf " %-8s %-4s %-8s %-8s\n" "Token" "ID" "Type_ID" "Special";
Printf.printf " %s\n" (String.make 32 '-');
for i = 0 to Encoding.length enc - 1 do
Printf.printf " %-8s %-4d %-8d %-8d\n" tokens.(i) ids.(i) type_ids.(i)
special.(i)
done
let () =
let vocab =
[
("[UNK]", 0);
("[CLS]", 1);
("[SEP]", 2);
("hello", 3);
("world", 4);
("how", 5);
("are", 6);
("you", 7);
]
in
let specials = List.map special [ "[CLS]"; "[SEP]"; "[UNK]" ] in
let post = Post_processor.bert ~cls:("[CLS]", 1) ~sep:("[SEP]", 2) () in
let tokenizer =
word_level ~vocab ~unk_token:"[UNK]" ~specials ~post
~pre:(Pre_tokenizer.whitespace ())
()
in
(* Single sentence: [CLS] A [SEP] *)
Printf.printf "=== Single Sentence ===\n";
Printf.printf "Text: \"hello world\"\n\n";
print_encoding (encode tokenizer "hello world");
(* Sentence pair: [CLS] A [SEP] B [SEP] *)
Printf.printf "\n=== Sentence Pair ===\n";
Printf.printf "A: \"hello world\", B: \"how are you\"\n\n";
print_encoding (encode tokenizer ~pair:"how are you" "hello world");
(* Without special tokens *)
Printf.printf "\n=== Without Special Tokens ===\n";
Printf.printf "Text: \"hello world\" (add_special_tokens=false)\n\n";
print_encoding (encode tokenizer ~add_special_tokens:false "hello world");
(* Template-based post-processor *)
Printf.printf "\n=== Template Post-processor ===\n";
let template_post =
Post_processor.template ~single:"[CLS] $A [SEP]"
~pair:"[CLS] $A [SEP] $B:1 [SEP]:1"
~special_tokens:[ ("[CLS]", 1); ("[SEP]", 2) ]
()
in
let tok2 =
word_level ~vocab ~unk_token:"[UNK]" ~specials ~post:template_post
~pre:(Pre_tokenizer.whitespace ())
()
in
Printf.printf "Template: \"[CLS] $A [SEP] $B:1 [SEP]:1\"\n";
Printf.printf "A: \"hello\", B: \"world\"\n\n";
print_encoding (encode tok2 ~pair:"world" "hello")
================================================
FILE: packages/brot/examples/07-padding-truncation/README.md
================================================
# `07-padding-truncation`
Padding and truncation for batch processing. Models require uniform sequence
lengths. Padding adds filler tokens; truncation trims long sequences.
```bash
dune exec brot/examples/07-padding-truncation/main.exe
```
## What You'll Learn
- Fixed-length padding with ``padding (`Fixed n)``
- Batch-longest padding with ``padding `Batch_longest``
- Left vs right padding direction
- Truncation with `truncation max_length`
- Combining padding and truncation
- Using `Encoding.attention_mask` to distinguish real tokens from padding
## Key Functions
| Function | Purpose |
| ------------------------- | --------------------------------- |
| `Brot.padding` | Create a padding configuration |
| `Brot.truncation` | Create a truncation configuration |
| `Brot.encode_batch` | Encode multiple texts at once |
| `Encoding.attention_mask` | 1 for real tokens, 0 for padding |
## Padding Strategies
| Strategy | Behavior |
| -------------------- | ------------------------------------------------------ |
| `` `Fixed n `` | Every sequence padded to exactly `n` tokens |
| `` `Batch_longest `` | All sequences padded to match the longest in the batch |
| `` `To_multiple n `` | Pad to smallest multiple of `n` >= sequence length |
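The three strategies are just values passed to `padding` (a sketch assuming Brot is installed; the lengths are arbitrary):

```ocaml
open Brot

let _fixed = padding (`Fixed 128) (* always exactly 128 tokens *)
let _longest = padding `Batch_longest (* match longest in batch *)
let _multiple = padding (`To_multiple 8) (* round up to a multiple of 8 *)
```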
## Try It
1. Change the padding direction to `` `Left `` and observe where pad tokens appear.
2. Try ``padding (`To_multiple 4)`` and see how lengths round up.
3. Truncate from the left with ``truncation ~direction:`Left 3``.
## Next Steps
Continue to [08-decoders](../08-decoders/) to learn how tokens are converted
back to text.
================================================
FILE: packages/brot/examples/07-padding-truncation/dune
================================================
(executable
(name main)
(libraries brot))
================================================
FILE: packages/brot/examples/07-padding-truncation/main.ml
================================================
(* Padding and truncation.
Batch processing requires uniform sequence lengths. Padding extends short
sequences with pad tokens; truncation trims long ones. The attention mask
distinguishes real tokens from padding. *)
open Brot
let print_batch label encodings =
Printf.printf "%s\n" label;
List.iteri
(fun i enc ->
let ids = Encoding.ids enc in
let attn = Encoding.attention_mask enc in
Printf.printf " [%d] ids=[%s] attn=[%s]\n" i
(String.concat ", " (Array.to_list (Array.map string_of_int ids)))
(String.concat ", " (Array.to_list (Array.map string_of_int attn))))
encodings;
print_newline ()
let () =
let vocab =
[
("[PAD]", 0);
("<unk>", 1);
("hello", 2);
("world", 3);
("how", 4);
("are", 5);
("you", 6);
("doing", 7);
("today", 8);
]
in
let tokenizer =
word_level ~vocab ~unk_token:"<unk>"
~specials:[ special "[PAD]" ]
~pad_token:"[PAD]"
~pre:(Pre_tokenizer.whitespace ())
()
in
let texts = [ "hello"; "hello world"; "how are you doing today" ] in
Printf.printf "Texts:\n";
List.iteri (fun i t -> Printf.printf " [%d] %S\n" i t) texts;
print_newline ();
(* No padding *)
print_batch "=== No Padding ===" (encode_batch tokenizer texts);
(* Fixed-length padding *)
print_batch "=== Fixed Padding (length=6) ==="
(encode_batch tokenizer ~padding:(padding (`Fixed 6)) texts);
(* Batch-longest padding *)
print_batch "=== Batch Longest Padding ==="
(encode_batch tokenizer ~padding:(padding `Batch_longest) texts);
(* Left padding *)
print_batch "=== Left Padding (length=6) ==="
(encode_batch tokenizer
~padding:(padding ~direction:`Left (`Fixed 6))
texts);
(* Truncation *)
print_batch "=== Truncation (max_length=3) ==="
(encode_batch tokenizer ~truncation:(truncation 3) texts);
(* Padding + Truncation *)
print_batch "=== Padding + Truncation (pad=4, trunc=4) ==="
(encode_batch tokenizer
~padding:(padding (`Fixed 4))
~truncation:(truncation 4) texts)
================================================
FILE: packages/brot/examples/08-decoders/README.md
================================================
# `08-decoders`
Decoders convert token strings back to natural text. Different tokenization
schemes require different decoding strategies to produce clean output.
```bash
dune exec brot/examples/08-decoders/main.exe
```
## What You'll Learn
- Per-token decoders: `wordpiece`, `bpe`, `metaspace`, `byte_fallback`
- Collapsing decoders: `fuse`, `replace`
- Composing decoders with `sequence`
- Integrating a decoder with a tokenizer
- Skipping special tokens during decoding
## Key Functions
| Function | Purpose |
| ----------------------- | ------------------------------------ |
| `Decoder.wordpiece` | Strip `##` prefix, join subwords |
| `Decoder.bpe` | Strip word-end suffix, insert spaces |
| `Decoder.metaspace` | Convert markers back to spaces |
| `Decoder.byte_fallback` | Convert `<0xFF>` back to bytes |
| `Decoder.fuse` | Concatenate all tokens |
| `Decoder.replace` | String replacement |
| `Decoder.sequence` | Chain decoders |
| `Decoder.decode` | Apply decoder to token list |
| `Brot.decode` | Full decode through tokenizer |
## Per-token vs Collapsing
Some decoders transform each token independently (per-token: `wordpiece`, `bpe`,
`metaspace`, `byte_fallback`), while others combine the entire token list into
a single result (collapsing: `fuse`, `replace`). This matters
when composing with `sequence`.
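The composed decoder from `main.ml` can be sketched directly against
`Decoder.decode`:

```ocaml
open Brot

(* Chain two decoders: first strip ## subword prefixes, then rewrite
   characters. Mirrors the composed decoder in main.ml. *)
let composed =
  Decoder.sequence
    [ Decoder.wordpiece (); Decoder.replace ~pattern:"_" ~by:" " () ]

let () =
  (* Apply the chain to a token list. *)
  print_endline (Decoder.decode composed [ "play"; "##ing" ])
```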
## Try It
1. Try `Decoder.ctc` for speech recognition CTC output.
2. Compose `byte_fallback` with `fuse` and decode byte tokens.
3. Use `Decoder.strip` to remove leading/trailing characters.
## Next Steps
Continue to [09-training](../09-training/) to learn how to train tokenizers
from scratch.
================================================
FILE: packages/brot/examples/08-decoders/dune
================================================
(executable
(name main)
(libraries brot))
================================================
FILE: packages/brot/examples/08-decoders/main.ml
================================================
(* Decoders.
Decoders convert token strings back to natural text by reversing
encoding-specific transformations: prefix/suffix removal, space insertion,
byte-level decoding, and marker replacement. *)
open Brot
let show name decoder tokens =
let result = Decoder.decode decoder tokens in
Printf.printf " %-22s [%s] -> %S\n" name
(String.concat "; " (List.map (fun s -> Printf.sprintf "%S" s) tokens))
result
let () =
Printf.printf "=== Per-token Decoders ===\n\n";
show "wordpiece" (Decoder.wordpiece ()) [ "play"; "##ing"; "un"; "##happy" ];
show "bpe (suffix=</w>)"
(Decoder.bpe ~suffix:"</w>" ())
[ "hel"; "lo</w>"; "wor"; "ld</w>" ];
show "metaspace" (Decoder.metaspace ())
[ "\xe2\x96\x81Hello"; "\xe2\x96\x81world" ];
show "byte_fallback" (Decoder.byte_fallback ()) [ "hello"; "<0x21>" ];
Printf.printf "\n=== Collapsing Decoders ===\n\n";
show "fuse" (Decoder.fuse ()) [ "h"; "e"; "l"; "l"; "o" ];
show "replace ('_' -> ' ')"
(Decoder.replace ~pattern:"_" ~by:" " ())
[ "hello_world" ];
Printf.printf "\n=== Composed Decoder ===\n\n";
let composed =
Decoder.sequence
[ Decoder.wordpiece (); Decoder.replace ~pattern:" " ~by:" " () ]
in
show "wordpiece + replace" composed [ "play"; "##ing"; "is"; "great" ];
Printf.printf "\n=== Integrated with Tokenizer ===\n\n";
let vocab =
[
("[UNK]", 0);
("[CLS]", 1);
("[SEP]", 2);
("play", 3);
("##ing", 4);
("##ed", 5);
("great", 6);
]
in
let tokenizer =
wordpiece ~vocab ~unk_token:"[UNK]"
~specials:[ special "[CLS]"; special "[SEP]" ]
~post:(Post_processor.bert ~cls:("[CLS]", 1) ~sep:("[SEP]", 2) ())
~decoder:(Decoder.wordpiece ()) ()
in
let text = "playing" in
let encoding = encode tokenizer text in
let ids = Encoding.ids encoding in
Printf.printf " Text: %S\n" text;
Printf.printf " Tokens: [%s]\n"
(String.concat "; "
(List.map
(fun s -> Printf.sprintf "%S" s)
(Array.to_list (Encoding.tokens encoding))));
Printf.printf " IDs: [%s]\n"
(String.concat "; " (Array.to_list (Array.map string_of_int ids)));
Printf.printf " Decoded: %S\n" (decode tokenizer ids);
Printf.printf " Decoded (skip specials): %S\n"
(decode tokenizer ~skip_special_tokens:true ids)
================================================
FILE: packages/brot/examples/09-training/README.md
================================================
# `09-training`
Training tokenizers from scratch. Given a text corpus, each algorithm learns a
vocabulary tailored to the data.
```bash
dune exec brot/examples/09-training/main.exe
```
## What You'll Learn
- Training BPE, WordPiece, word-level, and Unigram tokenizers
- Controlling vocabulary size with `~vocab_size`
- Adding special tokens during training
- Inspecting the learned vocabulary
## Key Functions
| Function | Purpose |
| ---------------------- | ------------------------------------------------ |
| `Brot.train_bpe` | Train a BPE tokenizer (learns merge rules) |
| `Brot.train_wordpiece` | Train a WordPiece tokenizer (learns subwords) |
| `Brot.train_wordlevel` | Train a word-level tokenizer (collects words) |
| `Brot.train_unigram` | Train a Unigram tokenizer (learns probabilities) |
| `Brot.vocab_size` | Check learned vocabulary size |
| `Brot.token_to_id` | Look up a token's ID |
## Training Data
Training data is provided as `` `Seq (List.to_seq texts) `` for in-memory text
or `` `Files ["path1"; "path2"] `` for files (one sentence per line).
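Both variants can be sketched with `train_bpe` (the file paths here are
placeholders):

```ocaml
open Brot

(* In-memory corpus wrapped as a Seq. *)
let from_memory =
  train_bpe
    (`Seq (List.to_seq [ "the cat sat"; "the dog sat" ]))
    ~vocab_size:100 ~show_progress:false
    ~pre:(Pre_tokenizer.whitespace ())

(* Same, but reading one sentence per line from files on disk.
   "corpus_a.txt" and "corpus_b.txt" are hypothetical paths. *)
let from_files =
  train_bpe
    (`Files [ "corpus_a.txt"; "corpus_b.txt" ])
    ~vocab_size:100 ~show_progress:false
    ~pre:(Pre_tokenizer.whitespace ())
```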
## Try It
1. Add more sentences to the corpus and see how the vocabulary changes.
2. Train with a smaller `~vocab_size` and observe more subword splitting.
3. Use `~min_frequency:2` to exclude rare words.
## Next Steps
Continue to [10-bert-pipeline](../10-bert-pipeline/) to assemble a complete
BERT-style tokenizer pipeline.
================================================
FILE: packages/brot/examples/09-training/dune
================================================
(executable
(name main)
(libraries brot))
================================================
FILE: packages/brot/examples/09-training/main.ml
================================================
(* Training tokenizers.
Train new tokenizers from a text corpus. Each algorithm learns a different
vocabulary: BPE learns merge rules, WordPiece learns subword prefixes,
word-level collects unique words, and Unigram learns token probabilities. *)
open Brot
let corpus =
[
"the cat sat on the mat";
"the dog sat on the log";
"the cat and the dog are friends";
"cats and dogs play together";
"the cat plays with the dog";
"playing in the park is fun";
"the park has many cats and dogs";
"friends play in the park together";
]
let show_trained name tokenizer test_texts =
Printf.printf "--- %s (vocab_size=%d) ---\n" name (vocab_size tokenizer);
List.iter
(fun text ->
let enc = encode tokenizer text in
Printf.printf " %S -> [%s]\n" text
(String.concat ", "
(List.map
(fun s -> Printf.sprintf "%S" s)
(Array.to_list (Encoding.tokens enc)))))
test_texts;
print_newline ()
let () =
let data = `Seq (List.to_seq corpus) in
let test_texts = [ "the cat plays"; "dogs are friends" ] in
Printf.printf "Training corpus: %d sentences\n\n" (List.length corpus);
(* Train BPE: learns merge rules by iteratively combining frequent pairs *)
let bpe_tok =
train_bpe data ~vocab_size:100 ~show_progress:false
~pre:(Pre_tokenizer.whitespace ())
in
show_trained "BPE" bpe_tok test_texts;
(* Train WordPiece: learns subword prefixes (## for continuation tokens) *)
let wp_tok =
train_wordpiece data ~vocab_size:100 ~show_progress:false
~pre:(Pre_tokenizer.whitespace ())
in
show_trained "WordPiece" wp_tok test_texts;
(* Train word-level: each unique word is a token *)
let wl_tok =
train_wordlevel data ~vocab_size:50 ~show_progress:false
~pre:(Pre_tokenizer.whitespace ())
in
show_trained "Word-level" wl_tok test_texts;
(* Train Unigram: probabilistic subword segmentation *)
let uni_tok = train_unigram data ~vocab_size:100 ~show_progress:false in
show_trained "Unigram" uni_tok test_texts;
(* Training with special tokens *)
Printf.printf "=== Training with Special Tokens ===\n\n";
let wp_with_specials =
train_wordpiece data ~vocab_size:100 ~show_progress:false
~pre:(Pre_tokenizer.whitespace ())
~specials:[ special "[CLS]"; special "[SEP]"; special "[PAD]" ]
~pad_token:"[PAD]"
in
Printf.printf "WordPiece with specials (vocab=%d):\n"
(vocab_size wp_with_specials);
let show_id tok name =
Printf.printf " %s id = %s\n" name
(match token_to_id tok name with
| Some id -> string_of_int id
| None -> "N/A")
in
show_id wp_with_specials "[CLS]";
show_id wp_with_specials "[SEP]";
show_id wp_with_specials "[PAD]";
(* Add a post-processor to insert special tokens during encoding *)
Printf.printf "\n Encoding with post-processor:\n";
let wp_full =
train_wordpiece data ~vocab_size:100 ~show_progress:false
~pre:(Pre_tokenizer.whitespace ())
~post:
(Post_processor.bert
~cls:("[CLS]", Option.get (token_to_id wp_with_specials "[CLS]"))
~sep:("[SEP]", Option.get (token_to_id wp_with_specials "[SEP]"))
())
~specials:[ special "[CLS]"; special "[SEP]"; special "[PAD]" ]
~pad_token:"[PAD]"
in
let enc = encode wp_full "the cat plays" in
Printf.printf " %S -> [%s]\n" "the cat plays"
(String.concat ", "
(List.map
(fun s -> Printf.sprintf "%S" s)
(Array.to_list (Encoding.tokens enc))))
================================================
FILE: packages/brot/examples/10-bert-pipeline/README.md
================================================
# `10-bert-pipeline`
Complete BERT-style tokenizer pipeline. Assembles all stages: normalizer,
pre-tokenizer, WordPiece algorithm, post-processor, decoder, special tokens,
padding, and truncation.
```bash
dune exec brot/examples/10-bert-pipeline/main.exe
```
## What You'll Learn
- Assembling a full tokenization pipeline
- How all stages work together end-to-end
- Single sentence and sentence-pair encoding
- Batch encoding with padding
- Sentence-pair batch encoding with `encode_pairs_batch`
- Decoding with and without special tokens
- Inspecting tokenizer configuration with `Brot.pp`
## Key Functions
| Function | Purpose |
| ---------------------------------- | --------------------------------------------- |
| `Brot.wordpiece` | Full pipeline constructor |
| `Normalizer.bert` | BERT normalizer (lowercase, clean, CJK) |
| `Pre_tokenizer.bert` | BERT pre-tokenizer (whitespace + punctuation) |
| `Post_processor.bert` | Insert `[CLS]` and `[SEP]` tokens |
| `Decoder.wordpiece` | Reverse `##` prefix joining |
| `Brot.encode ~pair` | Encode a sentence pair |
| `Brot.encode_pairs_batch` | Batch-encode sentence pairs |
| `Brot.decode ~skip_special_tokens` | Decode without `[CLS]`/`[SEP]` |
| `Brot.pp` | Pretty-print tokenizer configuration |
## The Full Pipeline
```
Input text
|
v
Normalizer.bert -- lowercase, clean control chars, pad CJK
|
v
Pre_tokenizer.bert -- split on whitespace, isolate punctuation
|
v
WordPiece model -- greedy longest-match subword splitting
|
v
Post_processor.bert -- insert [CLS] and [SEP], set type_ids
|
v
Encoding.t -- ids, tokens, offsets, type_ids, attention_mask
```
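Each stage in the diagram corresponds to one labeled argument of the
`wordpiece` constructor, as assembled in `main.ml` below (the vocabulary is
abbreviated here):

```ocaml
open Brot

(* One constructor call wires every pipeline stage. *)
let tokenizer =
  wordpiece
    ~vocab:[ ("[UNK]", 1); ("[CLS]", 2); ("[SEP]", 3); ("the", 4); ("cat", 5) ]
    ~unk_token:"[UNK]"
    ~normalizer:(Normalizer.bert ~lowercase:true ()) (* normalize *)
    ~pre:(Pre_tokenizer.bert ()) (* pre-tokenize *)
    ~post:(Post_processor.bert ~cls:("[CLS]", 2) ~sep:("[SEP]", 3) ())
    ~decoder:(Decoder.wordpiece ())
    ()
```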
## Try It
1. Encode text with accented characters and see the normalizer at work.
2. Change `Post_processor.bert` to `Post_processor.roberta` with `<s>` and
`</s>` tokens for a RoBERTa-style pipeline.
3. Use `save_pretrained` to export the tokenizer and reload it with
`from_file`.
## Further Reading
- [gpt2_tokenizer](../x-gpt2-tokenizer/) -- loading a real GPT-2 tokenizer
from HuggingFace model files
================================================
FILE: packages/brot/examples/10-bert-pipeline/dune
================================================
(executable
(name main)
(libraries brot))
================================================
FILE: packages/brot/examples/10-bert-pipeline/main.ml
================================================
(* BERT-style pipeline.
Assembles all pipeline stages into a complete BERT-style tokenizer:
normalizer, pre-tokenizer, WordPiece algorithm, post-processor, decoder,
special tokens, padding, and truncation. *)
open Brot
let print_encoding label enc =
let tokens = Encoding.tokens enc in
let ids = Encoding.ids enc in
let type_ids = Encoding.type_ids enc in
let attn = Encoding.attention_mask enc in
Printf.printf "%s\n" label;
Printf.printf " tokens: [%s]\n"
(String.concat ", "
(List.map (fun s -> Printf.sprintf "%S" s) (Array.to_list tokens)));
Printf.printf " ids: [%s]\n"
(String.concat ", " (Array.to_list (Array.map string_of_int ids)));
Printf.printf " type_ids: [%s]\n"
(String.concat ", " (Array.to_list (Array.map string_of_int type_ids)));
Printf.printf " attn_mask: [%s]\n"
(String.concat ", " (Array.to_list (Array.map string_of_int attn)));
print_newline ()
let () =
(* Build a BERT-style vocabulary *)
let vocab =
[
("[PAD]", 0);
("[UNK]", 1);
("[CLS]", 2);
("[SEP]", 3);
("the", 4);
("cat", 5);
("sat", 6);
("on", 7);
("mat", 8);
("dog", 9);
("play", 10);
("##ing", 11);
("##ed", 12);
("is", 13);
("a", 14);
("good", 15);
("great", 16);
("un", 17);
("##happy", 18);
("friend", 19);
("##s", 20);
("how", 21);
("are", 22);
("you", 23);
]
in
let specials = List.map special [ "[PAD]"; "[UNK]"; "[CLS]"; "[SEP]" ] in
(* Assemble the full pipeline *)
let tokenizer =
wordpiece ~vocab ~unk_token:"[UNK]"
~normalizer:(Normalizer.bert ~lowercase:true ())
~pre:(Pre_tokenizer.bert ())
~post:(Post_processor.bert ~cls:("[CLS]", 2) ~sep:("[SEP]", 3) ())
~decoder:(Decoder.wordpiece ()) ~specials ~pad_token:"[PAD]" ()
in
(* Inspect the tokenizer *)
Printf.printf "=== Tokenizer Configuration ===\n";
Format.printf "%a@.@." pp tokenizer;
(* Single sentence *)
Printf.printf "=== Single Sentence ===\n\n";
print_encoding "\"The Cat is Playing\""
(encode tokenizer "The Cat is Playing");
(* Sentence pair *)
Printf.printf "=== Sentence Pair ===\n\n";
print_encoding "A: \"the cat sat\", B: \"how are you\""
(encode tokenizer ~pair:"how are you" "the cat sat");
(* Batch with padding *)
Printf.printf "=== Padded Batch ===\n\n";
let batch =
encode_batch tokenizer ~padding:(padding `Batch_longest)
[ "the cat"; "the cat sat on a mat"; "good" ]
in
List.iteri (fun i enc -> print_encoding (Printf.sprintf "[%d]" i) enc) batch;
(* Sentence pairs batch with padding and truncation *)
Printf.printf "=== Sentence Pairs (pad=12, trunc=12) ===\n\n";
let pairs =
encode_pairs_batch tokenizer
~padding:(padding (`Fixed 12))
~truncation:(truncation 12)
[ ("the cat sat", "how are you"); ("good dog", "is a friend") ]
in
List.iteri
(fun i enc -> print_encoding (Printf.sprintf "pair[%d]" i) enc)
pairs;
(* Decoding *)
Printf.printf "=== Decoding ===\n\n";
let enc = encode tokenizer ~pair:"how are you" "the cat sat" in
let ids = Encoding.ids enc in
Printf.printf " Full decode: %S\n" (decode tokenizer ids);
Printf.printf " Skip specials: %S\n"
(decode tokenizer ~skip_special_tokens:true ids)
================================================
FILE: packages/brot/examples/README.md
================================================
# Brot Examples
Learn Brot through progressively complex examples. Start with `01-encode-decode`
and work through the numbered examples in order.
## Examples
| Example | Concept | Key Functions |
|---------|---------|---------------|
| [`01-encode-decode`](./01-encode-decode/) | Text to IDs and back | `bpe`, `encode`, `decode` |
| [`02-encoding-fields`](./02-encoding-fields/) | Encoding metadata | `Encoding.ids`, `.tokens`, `.offsets` |
| [`03-normalizers`](./03-normalizers/) | Text normalization | `Normalizer.lowercase`, `.bert`, `.sequence` |
| [`04-pre-tokenizers`](./04-pre-tokenizers/) | Splitting before vocab | `Pre_tokenizer.whitespace`, `.bert`, `.sequence` |
| [`05-algorithms`](./05-algorithms/) | Algorithm comparison | `bpe`, `wordpiece`, `unigram`, `word_level`, `chars` |
| [`06-special-tokens`](./06-special-tokens/) | Special tokens and post-processing | `Post_processor.bert`, `.template`, `encode ~pair` |
| [`07-padding-truncation`](./07-padding-truncation/) | Batch preparation | `padding`, `truncation`, `encode_batch` |
| [`08-decoders`](./08-decoders/) | Tokens back to text | `Decoder.wordpiece`, `.bpe`, `.fuse`, `.sequence` |
| [`09-training`](./09-training/) | Train from scratch | `train_bpe`, `train_wordpiece`, `train_unigram` |
| [`10-bert-pipeline`](./10-bert-pipeline/) | Full BERT pipeline | All stages assembled end-to-end |
Advanced:
- [**x-gpt2-tokenizer**](./x-gpt2-tokenizer/): Loading a real GPT-2 tokenizer
from HuggingFace model files
## Running Examples
All examples can be run with:
```bash
dune exec brot/examples/<example>/main.exe
```
For example:
```bash
dune exec brot/examples/01-encode-decode/main.exe
```
## Quick Reference
### Encode and Decode
```ocaml
open Brot
let tokenizer = bpe ~vocab:[("hello", 0); ...] ~merges:[...] () in
let encoding = encode tokenizer "hello world" in
let ids = Encoding.ids encoding in
let text = decode tokenizer ids
```
### Full Pipeline
```ocaml
let tokenizer =
wordpiece ~vocab
~normalizer:(Normalizer.bert ~lowercase:true ())
~pre:(Pre_tokenizer.bert ())
~post:(Post_processor.bert ~cls:("[CLS]", 2) ~sep:("[SEP]", 3) ())
~decoder:(Decoder.wordpiece ())
~specials:(List.map special [ "[CLS]"; "[SEP]"; "[PAD]" ])
~pad_token:"[PAD]" ()
```
### Train from Text
```ocaml
let tokenizer =
train_bpe (`Seq (List.to_seq texts)) ~vocab_size:1000
```
================================================
FILE: packages/brot/examples/x-gpt2-tokenizer/README.md
================================================
# `x-gpt2-tokenizer`
Loading a real GPT-2 tokenizer from HuggingFace model files. This example
downloads GPT-2's vocabulary and merges, builds the full byte-level BPE
pipeline, and demonstrates encoding, decoding, and subword inspection.
```bash
dune exec brot/examples/x-gpt2-tokenizer/main.exe
```
## What You'll Learn
- Loading a pre-trained tokenizer from vocabulary and merge files
- Building a byte-level BPE pipeline with `from_model_file`
- Encoding text and inspecting tokens, IDs, and offsets
- Decoding token IDs back to text
- Subword splitting on real vocabulary
- Batch encoding multiple texts
## Key Functions
| Function | Purpose |
| -------------------------- | ----------------------------------------------- |
| `Brot.from_model_file` | Load tokenizer from vocab.json and merges.txt |
| `Pre_tokenizer.byte_level` | GPT-2 style byte-level pre-tokenizer |
| `Decoder.byte_level` | Corresponding byte-level decoder |
| `Brot.encode` | Encode text to an `Encoding.t` |
| `Brot.decode` | Decode token IDs back to text |
| `Brot.encode_batch` | Encode multiple texts at once |
| `Encoding.tokens` | Token strings from an encoding |
| `Encoding.ids` | Token IDs from an encoding |
| `Encoding.offsets` | Byte offset pairs mapping tokens to source text |
## Prerequisites
This example downloads GPT-2 model files from HuggingFace on first run
(~1 MB total). Files are cached in `/tmp/brot_gpt2/`.
## Output Walkthrough
```
Vocabulary: 50257 tokens
Text: "Hello world! GPT-2 is amazing."
Tokens: ["Hello"; " world"; "!"; " GPT"; "-"; "2"; " is"; " amazing"; "."]
IDs: [15496; 995; 0; 402; 12; 17; 318; 4998; 13]
Decoded: "Hello world! GPT-2 is amazing."
Round-trip: true
=== Subword Splitting ===
"tokenization" -> 2 tokens: ["token", "ization"]
"transformer" -> 1 tokens: ["transformer"]
...
=== Batch Encoding ===
"The quick brown fox" -> 4 tokens
"jumps over the lazy dog" -> 5 tokens
"Machine learning is fun" -> 4 tokens
=== Token Offsets ===
Text: "Hello, world!"
Hello offsets=(0, 5) source="Hello"
, offsets=(5, 6) source=","
...
```
## Try It
1. Change the input text and see how GPT-2 tokenizes different sentences.
2. Try words with unusual spellings to see subword splitting in action.
3. Compare the token count for English text vs other languages.
## See Also
- [01-encode-decode](../01-encode-decode/) for basic encoding and decoding
- [05-algorithms](../05-algorithms/) for comparing tokenization algorithms
- [08-decoders](../08-decoders/) for decoder options
================================================
FILE: packages/brot/examples/x-gpt2-tokenizer/dune
================================================
(executable
(name main)
(libraries brot nx unix))
================================================
FILE: packages/brot/examples/x-gpt2-tokenizer/main.ml
================================================
(* Loading a real GPT-2 tokenizer.
Downloads GPT-2's vocabulary and merge files from HuggingFace, builds the
full byte-level BPE pipeline, and demonstrates encoding, decoding, and
subword inspection on real-world text. *)
open Brot
let download url dest =
if not (Sys.file_exists dest) then (
Printf.printf "Downloading %s...\n%!" (Filename.basename dest);
let cmd =
Printf.sprintf "curl -L --fail -s -o %s %s" (Filename.quote dest)
(Filename.quote url)
in
match Unix.system cmd with
| Unix.WEXITED 0 -> ()
| _ -> failwith (Printf.sprintf "Failed to download %s" url))
let () =
(* Download GPT-2 model files *)
let cache = "/tmp/brot_gpt2" in
if not (Sys.file_exists cache) then Sys.mkdir cache 0o755;
let vocab_file = Filename.concat cache "vocab.json" in
let merges_file = Filename.concat cache "merges.txt" in
download "https://huggingface.co/gpt2/raw/main/vocab.json" vocab_file;
download "https://huggingface.co/gpt2/raw/main/merges.txt" merges_file;
(* Build the GPT-2 tokenizer: BPE with byte-level pre-tokenizer *)
let tokenizer =
from_model_file ~vocab:vocab_file ~merges:merges_file
~pre:(Pre_tokenizer.byte_level ~add_prefix_space:false ())
~decoder:(Decoder.byte_level ()) ()
in
Printf.printf "\nVocabulary: %d tokens\n\n" (vocab_size tokenizer);
(* Encode text *)
let text = "Hello world! GPT-2 is amazing." in
let enc = encode tokenizer text in
Printf.printf "Text: %S\n" text;
Printf.printf "Tokens: [%s]\n"
(String.concat "; "
(List.map
(fun s -> Printf.sprintf "%S" s)
(Array.to_list (Encoding.tokens enc))));
Printf.printf "IDs: [%s]\n"
(String.concat "; "
(Array.to_list (Array.map string_of_int (Encoding.ids enc))));
(* Decode back *)
let decoded = decode tokenizer (Encoding.ids enc) in
Printf.printf "Decoded: %S\n" decoded;
Printf.printf "Round-trip: %b\n\n" (String.equal text decoded);
(* Subword splitting: see how a long word is broken down *)
Printf.printf "=== Subword Splitting ===\n\n";
List.iter
(fun word ->
let e = encode tokenizer word in
let tokens = Encoding.tokens e in
Printf.printf " %-20s -> %d tokens: [%s]\n" (Printf.sprintf "%S" word)
(Array.length tokens)
(String.concat ", "
(List.map (fun s -> Printf.sprintf "%S" s) (Array.to_list tokens))))
[ "tokenization"; "transformer"; "GPT"; "Hello"; "supercalifragilistic" ];
(* Batch encoding *)
Printf.printf "\n=== Batch Encoding ===\n\n";
let texts =
[
"The quick brown fox";
"jumps over the lazy dog";
"Machine learning is fun";
]
in
let batch = encode_batch tokenizer texts in
List.iter2
(fun text enc ->
Printf.printf " %-30s -> %d tokens\n" (Printf.sprintf "%S" text)
(Encoding.length enc))
texts batch;
(* Offsets: map tokens back to source text *)
Printf.printf "\n=== Token Offsets ===\n\n";
let text2 = "Hello, world!" in
let enc2 = encode tokenizer text2 in
Printf.printf "Text: %S\n" text2;
let tokens = Encoding.tokens enc2 in
let offsets = Encoding.offsets enc2 in
for i = 0 to Encoding.length enc2 - 1 do
let s, e = offsets.(i) in
Printf.printf " %-8s offsets=(%d, %d) source=%S\n" tokens.(i) s e
(String.sub text2 s (e - s))
done
================================================
FILE: packages/brot/lib/bpe.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let list_drop n l =
let rec aux i = function
| _ :: l when i < n -> aux (i + 1) l
| rest -> rest
in
if n <= 0 then l else aux 0 l
type vocab = (string, int) Hashtbl.t
type merges = (string * string) list
(* Open-addressing hash table for merge lookups. Returns int directly (no option
allocation). -1 = not found. *)
module Merge_map = struct
type t = { keys : int array; values : int array; mask : int }
let[@inline] hash key =
let h = key * 0x1B873593 in
h lxor (h lsr 16)
let create entries =
let n = List.length entries in
let cap = ref 16 in
while !cap < n * 4 do
cap := !cap * 2
done;
let mask = !cap - 1 in
let keys = Array.make !cap (-1) in
let values = Array.make !cap 0 in
List.iter
(fun (key, value) ->
let h = ref (hash key land mask) in
while Array.unsafe_get keys !h >= 0 do
h := (!h + 1) land mask
done;
Array.unsafe_set keys !h key;
Array.unsafe_set values !h value)
entries;
{ keys; values; mask }
let[@inline] find t key =
let mask = t.mask in
let keys = t.keys in
let h = ref (hash key land mask) in
let k = ref (Array.unsafe_get keys !h) in
while !k <> key && !k >= 0 do
h := (!h + 1) land mask;
k := Array.unsafe_get keys !h
done;
if !k = key then Array.unsafe_get t.values !h else -1
let fold f t acc =
let keys = t.keys in
let values = t.values in
let len = Array.length keys in
let acc = ref acc in
for i = 0 to len - 1 do
let k = Array.unsafe_get keys i in
if k >= 0 then acc := f k (Array.unsafe_get values i) !acc
done;
!acc
end
let[@inline] merge_key a b = (a lsl 21) lor b
let[@inline] pack_merge rank new_id = (rank lsl 21) lor new_id
let[@inline] merge_rank v = v lsr 21
let[@inline] merge_new_id v = v land 0x1FFFFF
type word = {
sym_c : int array;
sym_prev : int array;
sym_next : int array;
sym_len : int array;
mutable size : int;
}
(* Specialized min-heap for BPE merges using parallel arrays (no tuple
allocation). Entries are ordered by the packed comparison key
(rank lsl 21) lor pos: lower rank first, then lower position. A single int
comparison drives the sift operations. *)
module Merge_queue = struct
type t = {
mutable keys : int array;
mutable new_ids : int array;
mutable size : int;
mutable pop_key : int;
mutable pop_new_id : int;
mutable skip_keys : int array;
mutable skip_new_ids : int array;
mutable skip_size : int;
}
let create cap =
let cap = max 16 cap in
{
keys = Array.make cap 0;
new_ids = Array.make cap 0;
size = 0;
pop_key = 0;
pop_new_id = 0;
skip_keys = [||];
skip_new_ids = [||];
skip_size = 0;
}
let[@inline] pack_key rank pos = (rank lsl 21) lor pos
let sift_up t idx =
let keys = t.keys in
let new_ids = t.new_ids in
let key = Array.unsafe_get keys idx in
let nid = Array.unsafe_get new_ids idx in
let i = ref idx in
let cont = ref (!i > 0) in
while !cont do
let p = (!i - 1) asr 1 in
if key < Array.unsafe_get keys p then (
Array.unsafe_set keys !i (Array.unsafe_get keys p);
Array.unsafe_set new_ids !i (Array.unsafe_get new_ids p);
i := p;
cont := !i > 0)
else cont := false
done;
Array.unsafe_set keys !i key;
Array.unsafe_set new_ids !i nid
let sift_down t idx =
let keys = t.keys in
let new_ids = t.new_ids in
let size = t.size in
let key = Array.unsafe_get keys idx in
let nid = Array.unsafe_get new_ids idx in
let i = ref idx in
let continue_ = ref true in
while !continue_ do
let l = (2 * !i) + 1 in
if l >= size then continue_ := false
else begin
let r = l + 1 in
let smallest =
if r < size && Array.unsafe_get keys r < Array.unsafe_get keys l then
r
else l
in
if Array.unsafe_get keys smallest < key then (
Array.unsafe_set keys !i (Array.unsafe_get keys smallest);
Array.unsafe_set new_ids !i (Array.unsafe_get new_ids smallest);
i := smallest)
else continue_ := false
end
done;
Array.unsafe_set keys !i key;
Array.unsafe_set new_ids !i nid
let push t rank pos new_id =
let s = t.size in
if s = Array.length t.keys then begin
let new_cap = max 16 (s * 2) in
let grow a =
let b = Array.make new_cap 0 in
Array.blit a 0 b 0 s;
b
in
t.keys <- grow t.keys;
t.new_ids <- grow t.new_ids
end;
Array.unsafe_set t.keys s (pack_key rank pos);
Array.unsafe_set t.new_ids s new_id;
t.size <- s + 1;
sift_up t s
let pop t =
if t.size = 0 then false
else begin
t.pop_key <- Array.unsafe_get t.keys 0;
t.pop_new_id <- Array.unsafe_get t.new_ids 0;
t.size <- t.size - 1;
if t.size > 0 then begin
Array.unsafe_set t.keys 0 (Array.unsafe_get t.keys t.size);
Array.unsafe_set t.new_ids 0 (Array.unsafe_get t.new_ids t.size);
sift_down t 0
end;
true
end
end
type token = { id : int; value : string; offsets : int * int }
(* Direct-mapped bounded cache: hash key to slot, newest entry wins. Fixed
memory, no eviction logic, no unbounded growth. *)
type cache = {
cache_keys : string array;
cache_vals : word array;
cache_mask : int;
}
let empty_word =
{ sym_c = [||]; sym_prev = [||]; sym_next = [||]; sym_len = [||]; size = 0 }
let create_cache capacity =
(* Round up to power of 2 *)
let cap = ref 16 in
while !cap < capacity do
cap := !cap * 2
done;
{
cache_keys = Array.make !cap "";
cache_vals = Array.make !cap empty_word;
cache_mask = !cap - 1;
}
let[@inline] cache_find c key =
let h = Hashtbl.hash key land c.cache_mask in
if String.equal (Array.unsafe_get c.cache_keys h) key then
Array.unsafe_get c.cache_vals h
else empty_word
let[@inline] cache_add c key value =
let h = Hashtbl.hash key land c.cache_mask in
Array.unsafe_set c.cache_keys h key;
Array.unsafe_set c.cache_vals h value
type t = {
vocab : vocab;
vocab_r : string array;
merges : Merge_map.t;
cache : cache option;
dropout : float option;
unk_token : string option;
continuing_subword_prefix : string option;
end_of_word_suffix : string option;
fuse_unk : bool;
byte_fallback : bool;
ignore_merges : bool;
ascii_to_id : int array;
byte_fallback_ids : int array;
char_to_id : Merge_map.t;
prefixed_ascii_to_id : int array;
prefixed_char_to_id : Merge_map.t;
unk_id : int;
mutable work_word : word;
mutable work_queue : Merge_queue.t;
work_in_use : bool Atomic.t;
}
let create_word capacity =
let cap = max 16 capacity in
{
sym_c = Array.make cap 0;
sym_prev = Array.make cap 0;
sym_next = Array.make cap 0;
sym_len = Array.make cap 0;
size = 0;
}
let ensure_word_capacity word capacity =
if Array.length word.sym_c >= capacity then begin
word.size <- 0;
word
end
else create_word capacity
let ensure_queue_capacity queue capacity =
let cap = max 16 capacity in
if Array.length queue.Merge_queue.keys >= cap then begin
queue.Merge_queue.size <- 0;
queue
end
else Merge_queue.create cap
let[@inline] add_symbol word c byte_len =
let s = word.size in
let prev = if s > 0 then s - 1 else -1 in
Array.unsafe_set word.sym_c s c;
Array.unsafe_set word.sym_prev s prev;
Array.unsafe_set word.sym_next s (-1);
Array.unsafe_set word.sym_len s byte_len;
if prev >= 0 then Array.unsafe_set word.sym_next prev s;
word.size <- s + 1
let apply_merges model dropout word queue =
let p = match dropout with Some p -> p | None -> 0.0 in
let use_dropout = p > 0.0 in
let merges = model.merges in
let sym_c = word.sym_c in
let sym_prev = word.sym_prev in
let sym_next = word.sym_next in
let sym_len = word.sym_len in
for i = 0 to word.size - 2 do
let key =
merge_key (Array.unsafe_get sym_c i) (Array.unsafe_get sym_c (i + 1))
in
let packed = Merge_map.find merges key in
if packed >= 0 then
Merge_queue.push queue (merge_rank packed) i (merge_new_id packed)
done;
queue.skip_size <- 0;
while Merge_queue.pop queue do
let pkey = queue.pop_key in
let pos = pkey land 0x1FFFFF in
let new_id = queue.pop_new_id in
if Array.unsafe_get sym_len pos > 0 then begin
let next_pos = Array.unsafe_get sym_next pos in
if next_pos >= 0 then begin
let key =
merge_key
(Array.unsafe_get sym_c pos)
(Array.unsafe_get sym_c next_pos)
in
let packed = Merge_map.find merges key in
if packed >= 0 && merge_new_id packed = new_id then
begin if use_dropout && Random.float 1.0 < p then begin
let s = queue.skip_size in
if s = Array.length queue.skip_keys then begin
let new_cap = max 8 (s * 2) in
let grow old =
let a = Array.make new_cap 0 in
if s > 0 then Array.blit old 0 a 0 s;
a
in
queue.skip_keys <- grow queue.skip_keys;
queue.skip_new_ids <- grow queue.skip_new_ids
end;
Array.unsafe_set queue.skip_keys s pkey;
Array.unsafe_set queue.skip_new_ids s new_id;
queue.skip_size <- s + 1
end
else begin
for i = 0 to queue.skip_size - 1 do
Merge_queue.push queue
(Array.unsafe_get queue.skip_keys i lsr 21)
(Array.unsafe_get queue.skip_keys i land 0x1FFFFF)
(Array.unsafe_get queue.skip_new_ids i)
done;
queue.skip_size <- 0;
Array.unsafe_set sym_c pos new_id;
Array.unsafe_set sym_len pos
(Array.unsafe_get sym_len pos + Array.unsafe_get sym_len next_pos);
Array.unsafe_set sym_next pos (Array.unsafe_get sym_next next_pos);
Array.unsafe_set sym_len next_pos 0;
let new_next = Array.unsafe_get sym_next pos in
if new_next >= 0 then Array.unsafe_set sym_prev new_next pos;
let prev = Array.unsafe_get sym_prev pos in
if prev >= 0 then begin
let k =
merge_key
(Array.unsafe_get sym_c prev)
(Array.unsafe_get sym_c pos)
in
let v = Merge_map.find merges k in
if v >= 0 then
Merge_queue.push queue (merge_rank v) prev (merge_new_id v)
end;
let next = Array.unsafe_get sym_next pos in
if next >= 0 then begin
let k =
merge_key
(Array.unsafe_get sym_c pos)
(Array.unsafe_get sym_c next)
in
let v = Merge_map.find merges k in
if v >= 0 then
Merge_queue.push queue (merge_rank v) pos (merge_new_id v)
end
end
end
end
end
done;
(* Compact using linked-list traversal: O(N_final) instead of O(N_original) *)
let j = ref 0 in
let cur = ref 0 in
while !cur >= 0 do
if !j <> !cur then begin
Array.unsafe_set sym_c !j (Array.unsafe_get sym_c !cur);
Array.unsafe_set sym_len !j (Array.unsafe_get sym_len !cur)
end;
incr j;
cur := Array.unsafe_get sym_next !cur
done;
word.size <- !j
let utf8_byte_len_table =
Array.init 256 (fun b ->
if b land 0x80 = 0 then 1
else if b land 0xE0 = 0xC0 then 2
else if b land 0xF0 = 0xE0 then 3
else if b land 0xF8 = 0xF0 then 4
else 1)
let[@inline] utf8_byte_len b = Array.unsafe_get utf8_byte_len_table b
let[@inline] pack_char_key text pos byte_len =
let b0 = Char.code (String.unsafe_get text pos) in
match byte_len with
| 1 -> b0
| 2 -> (b0 lsl 8) lor Char.code (String.unsafe_get text (pos + 1))
| 3 ->
(b0 lsl 16)
lor (Char.code (String.unsafe_get text (pos + 1)) lsl 8)
lor Char.code (String.unsafe_get text (pos + 2))
| _ ->
(b0 lsl 24)
lor (Char.code (String.unsafe_get text (pos + 1)) lsl 16)
lor (Char.code (String.unsafe_get text (pos + 2)) lsl 8)
lor Char.code (String.unsafe_get text (pos + 3))
(* Try emitting byte fallback tokens for [byte_len] bytes starting at [src]
offset [offset]. Returns true if all bytes had fallback IDs. *)
let try_byte_fallback model word flush_unk src offset byte_len =
let all_found = ref true in
for i = 0 to byte_len - 1 do
if
Array.unsafe_get model.byte_fallback_ids
(Char.code (String.unsafe_get src (offset + i)))
< 0
then all_found := false
done;
if !all_found then begin
flush_unk ();
for i = 0 to byte_len - 1 do
add_symbol word
(Array.unsafe_get model.byte_fallback_ids
(Char.code (String.unsafe_get src (offset + i))))
1
done;
true
end
else false
(* No prefix/suffix — avoids all per-character string allocation for ASCII via
pre-computed lookup tables. *)
let init_word_fast model word text text_len =
let pos = ref 0 in
let pending_unk_id = ref (-1) in
let pending_unk_len = ref 0 in
let flush_unk () =
if !pending_unk_id >= 0 then begin
add_symbol word !pending_unk_id !pending_unk_len;
pending_unk_id := -1;
pending_unk_len := 0
end
in
let handle_unk byte_len =
if model.unk_id >= 0 then
begin if model.fuse_unk then
begin if !pending_unk_id >= 0 then
pending_unk_len := !pending_unk_len + byte_len
else begin
pending_unk_id := model.unk_id;
pending_unk_len := byte_len
end
end
else begin
flush_unk ();
add_symbol word model.unk_id byte_len
end
end
in
while !pos < text_len do
let b = Char.code (String.unsafe_get text !pos) in
if b < 128 then begin
let id = Array.unsafe_get model.ascii_to_id b in
if id >= 0 then begin
flush_unk ();
add_symbol word id 1
end
else if model.byte_fallback then begin
let fbid = Array.unsafe_get model.byte_fallback_ids b in
if fbid >= 0 then begin
flush_unk ();
add_symbol word fbid 1
end
else handle_unk 1
end
else handle_unk 1;
incr pos
end
else begin
let byte_len = utf8_byte_len b in
let key = pack_char_key text !pos byte_len in
let id = Merge_map.find model.char_to_id key in
if id >= 0 then begin
flush_unk ();
add_symbol word id byte_len
end
else if model.byte_fallback then
begin if not (try_byte_fallback model word flush_unk text !pos byte_len)
then handle_unk byte_len
end
else handle_unk byte_len;
pos := !pos + byte_len
end
done;
flush_unk ()
(* Models with continuing_subword_prefix or end_of_word_suffix *)
let init_word_slow model word text text_len =
let pos = ref 0 in
let pending_unk_id = ref (-1) in
let pending_unk_len = ref 0 in
let flush_unk () =
if !pending_unk_id >= 0 then begin
add_symbol word !pending_unk_id !pending_unk_len;
pending_unk_id := -1;
pending_unk_len := 0
end
in
let handle_unk byte_len =
if model.unk_id >= 0 then
begin if model.fuse_unk then
begin if !pending_unk_id >= 0 then
pending_unk_len := !pending_unk_len + byte_len
else begin
pending_unk_id := model.unk_id;
pending_unk_len := byte_len
end
end
else begin
flush_unk ();
add_symbol word model.unk_id byte_len
end
end
in
let has_prefix = model.continuing_subword_prefix <> None in
let has_suffix = model.end_of_word_suffix <> None in
while !pos < text_len do
let b = Char.code (String.unsafe_get text !pos) in
let byte_len = utf8_byte_len b in
if b land 0xC0 = 0x80 then pos := !pos + 1
else begin
let start = !pos in
let is_first = start = 0 in
let is_last = !pos + byte_len >= text_len in
pos := !pos + byte_len;
(* Suffix only applies at word boundaries (first-not-last or
last-not-first), never to middle chars and never to single-char
words *)
let needs_string = has_suffix && is_first <> is_last in
if needs_string then begin
(* Slow path: suffix involved, at most 2x per word *)
let char_str = String.sub text start byte_len in
let token_str =
match
( is_first,
is_last,
model.continuing_subword_prefix,
model.end_of_word_suffix )
with
| true, false, _, Some suffix -> char_str ^ suffix
| true, false, _, None -> char_str
| false, true, Some prefix, Some suffix -> prefix ^ char_str ^ suffix
| false, true, Some prefix, None -> prefix ^ char_str
| false, true, None, Some suffix -> char_str ^ suffix
| false, true, None, None -> char_str
| _, _, _, _ -> char_str
in
match Hashtbl.find_opt model.vocab token_str with
| Some id ->
flush_unk ();
add_symbol word id byte_len
| None ->
if model.byte_fallback then
begin if
not (try_byte_fallback model word flush_unk text start byte_len)
then handle_unk byte_len
end
else handle_unk byte_len
end
else begin
(* Fast path: no suffix, use packed-int lookup (zero allocation) *)
let needs_prefix = has_prefix && not is_first in
let id =
if needs_prefix then
if b < 128 then Array.unsafe_get model.prefixed_ascii_to_id b
else
Merge_map.find model.prefixed_char_to_id
(pack_char_key text start byte_len)
else if b < 128 then Array.unsafe_get model.ascii_to_id b
else
Merge_map.find model.char_to_id (pack_char_key text start byte_len)
in
if id >= 0 then begin
flush_unk ();
add_symbol word id byte_len
end
else if model.byte_fallback then
begin if
not (try_byte_fallback model word flush_unk text start byte_len)
then handle_unk byte_len
end
else handle_unk byte_len
end
end
done;
flush_unk ()
let merge_word model text =
let text_len = String.length text in
let owned = Atomic.compare_and_set model.work_in_use false true in
let word, queue =
if owned then begin
let w = ensure_word_capacity model.work_word text_len in
model.work_word <- w;
let q = ensure_queue_capacity model.work_queue text_len in
model.work_queue <- q;
(w, q)
end
else (create_word text_len, Merge_queue.create text_len)
in
if model.continuing_subword_prefix = None && model.end_of_word_suffix = None
then init_word_fast model word text text_len
else init_word_slow model word text text_len;
apply_merges model model.dropout word queue;
if owned then begin
let n = word.size in
let sym_c = Array.make n 0 in
let sym_len = Array.make n 0 in
Array.blit word.sym_c 0 sym_c 0 n;
Array.blit word.sym_len 0 sym_len 0 n;
Atomic.set model.work_in_use false;
{ sym_c; sym_prev = [||]; sym_next = [||]; sym_len; size = n }
end
else word
let word_to_tokens model word =
let offset = ref 0 in
List.init word.size (fun i ->
let id = Array.unsafe_get word.sym_c i in
let vr = model.vocab_r in
let value =
if id >= 0 && id < Array.length vr then Array.unsafe_get vr id
else ""
in
let start = !offset in
let end_ = start + Array.unsafe_get word.sym_len i in
offset := end_;
{ id; value; offsets = (start, end_) })
let word_to_ids word =
Array.init word.size (fun i -> Array.unsafe_get word.sym_c i)
let word_to_encoding model word ~type_id =
let n = word.size in
let ids = Array.make n 0 in
let tokens = Array.make n "" in
let offsets = Array.make n (0, 0) in
let offset = ref 0 in
for i = 0 to n - 1 do
let id = Array.unsafe_get word.sym_c i in
Array.unsafe_set ids i id;
let vr = model.vocab_r in
Array.unsafe_set tokens i
(if id >= 0 && id < Array.length vr then Array.unsafe_get vr id
else "");
let start = !offset in
let end_ = start + Array.unsafe_get word.sym_len i in
Array.unsafe_set offsets i (start, end_);
offset := end_
done;
Encoding.create ~ids ~type_ids:(Array.make n type_id) ~tokens
~words:(Array.make n None) ~offsets ~special_tokens_mask:(Array.make n 0)
~attention_mask:(Array.make n 1) ()
let get_word model text =
if model.ignore_merges then merge_word model text
else
match model.cache with
| Some cache when String.length text < 4096 ->
let cached = cache_find cache text in
if cached.size > 0 then cached
else
let word = merge_word model text in
cache_add cache text word;
word
| _ -> merge_word model text
let tokenize model text =
if String.length text = 0 then []
else
match Hashtbl.find_opt model.vocab text with
| Some id -> [ { id; value = text; offsets = (0, String.length text) } ]
| None -> word_to_tokens model (get_word model text)
let tokenize_ids model text =
if String.length text = 0 then [||]
else
match Hashtbl.find_opt model.vocab text with
| Some id -> [| id |]
| None -> word_to_ids (get_word model text)
let tokenize_encoding model text ~type_id =
if String.length text = 0 then Encoding.empty
else
match Hashtbl.find_opt model.vocab text with
| Some id ->
Encoding.token ~id ~token:text
~offset:(0, String.length text)
~type_id ~special:false
| None -> word_to_encoding model (get_word model text) ~type_id
let token_to_id model token = Hashtbl.find_opt model.vocab token
let id_to_token model id =
if id >= 0 && id < Array.length model.vocab_r then
Some (Array.unsafe_get model.vocab_r id)
else None
let get_vocab model = Hashtbl.fold (fun k v acc -> (k, v) :: acc) model.vocab []
let get_vocab_size model = Hashtbl.length model.vocab
let get_unk_token model = model.unk_token
let get_continuing_subword_prefix model = model.continuing_subword_prefix
let get_end_of_word_suffix model = model.end_of_word_suffix
let get_merges model =
Merge_map.fold
(fun key packed acc ->
let a_id = key lsr 21 in
let b_id = key land 0x1FFFFF in
let rank = merge_rank packed in
let vr = model.vocab_r in
let vr_len = Array.length vr in
if a_id >= 0 && a_id < vr_len && b_id >= 0 && b_id < vr_len then
(rank, (Array.unsafe_get vr a_id, Array.unsafe_get vr b_id)) :: acc
else acc)
model.merges []
|> List.sort (fun (r1, _) (r2, _) -> Int.compare r1 r2)
|> List.map snd
let convert_merges_to_merge_map vocab merges continuing_subword_prefix =
let csp_str =
match continuing_subword_prefix with Some p -> p | None -> ""
in
let csp_len = String.length csp_str in
List.mapi
(fun rank (a, b) ->
match (Hashtbl.find_opt vocab a, Hashtbl.find_opt vocab b) with
| Some a_id, Some b_id -> (
let alen = String.length a in
let blen = String.length b in
let new_token =
if
csp_len > 0 && blen > csp_len
&& String.starts_with ~prefix:csp_str b
then (
let brest = blen - csp_len in
let s = Bytes.create (alen + brest) in
Bytes.blit_string a 0 s 0 alen;
Bytes.blit_string b csp_len s alen brest;
Bytes.unsafe_to_string s)
else
let s = Bytes.create (alen + blen) in
Bytes.blit_string a 0 s 0 alen;
Bytes.blit_string b 0 s alen blen;
Bytes.unsafe_to_string s
in
match Hashtbl.find_opt vocab new_token with
| Some new_id -> Some ((a_id, b_id), pack_merge rank new_id)
| None ->
failwith
(Printf.sprintf "Merge token '%s' not in vocabulary" new_token))
| _ ->
failwith
(Printf.sprintf "Merge tokens ('%s', '%s') not in vocabulary" a b))
merges
|> List.filter_map Fun.id
|> fun entries ->
Merge_map.create
(List.map
(fun ((a_id, b_id), packed) -> (merge_key a_id b_id, packed))
entries)
let create ~vocab ~merges ?(cache_capacity = 10000) ?dropout ?unk_token
?continuing_subword_prefix ?end_of_word_suffix ?(fuse_unk = false)
?(byte_fallback = false) ?(ignore_merges = false) () : t =
let max_id = Hashtbl.fold (fun _ id acc -> max id acc) vocab (-1) in
let vocab_r = Array.make (max_id + 1) "" in
Hashtbl.iter (fun k v -> Array.unsafe_set vocab_r v k) vocab;
let cache =
if cache_capacity = 0 then None else Some (create_cache cache_capacity)
in
let merges =
convert_merges_to_merge_map vocab merges continuing_subword_prefix
in
let ascii_to_id = Array.make 128 (-1) in
for i = 0 to 127 do
let s = String.make 1 (Char.chr i) in
match Hashtbl.find_opt vocab s with
| Some id -> ascii_to_id.(i) <- id
| None -> ()
done;
let byte_fallback_ids = Array.make 256 (-1) in
for i = 0 to 255 do
let hex = Printf.sprintf "<0x%02X>" i in
match Hashtbl.find_opt vocab hex with
| Some id -> byte_fallback_ids.(i) <- id
| None -> ()
done;
(* Build packed-int char lookup table for zero-allocation multi-byte lookup *)
let char_entries = ref [] in
Hashtbl.iter
(fun key id ->
let len = String.length key in
if len >= 1 && len <= 4 then begin
let b0 = Char.code (String.unsafe_get key 0) in
let expected_len = utf8_byte_len b0 in
if expected_len = len then
let packed =
match len with
| 1 -> b0
| 2 -> (b0 lsl 8) lor Char.code (String.unsafe_get key 1)
| 3 ->
(b0 lsl 16)
lor (Char.code (String.unsafe_get key 1) lsl 8)
lor Char.code (String.unsafe_get key 2)
| _ ->
(b0 lsl 24)
lor (Char.code (String.unsafe_get key 1) lsl 16)
lor (Char.code (String.unsafe_get key 2) lsl 8)
lor Char.code (String.unsafe_get key 3)
in
char_entries := (packed, id) :: !char_entries
end)
vocab;
let char_to_id = Merge_map.create !char_entries in
(* Build prefixed char lookup tables for zero-allocation init_word_slow *)
let prefixed_ascii_to_id = Array.make 128 (-1) in
let prefixed_char_entries = ref [] in
(match continuing_subword_prefix with
| Some prefix ->
for i = 0 to 127 do
let s = prefix ^ String.make 1 (Char.chr i) in
match Hashtbl.find_opt vocab s with
| Some id -> prefixed_ascii_to_id.(i) <- id
| None -> ()
done;
Hashtbl.iter
(fun key id ->
let plen = String.length prefix in
let klen = String.length key in
if klen > plen && String.sub key 0 plen = prefix then begin
let rest_len = klen - plen in
if rest_len >= 2 && rest_len <= 4 then begin
let b0 = Char.code (String.unsafe_get key plen) in
let expected = utf8_byte_len b0 in
if expected = rest_len then
let packed = pack_char_key key plen rest_len in
prefixed_char_entries := (packed, id) :: !prefixed_char_entries
end
end)
vocab
| None -> ());
let prefixed_char_to_id = Merge_map.create !prefixed_char_entries in
let unk_id =
match unk_token with
| Some unk -> (
match Hashtbl.find_opt vocab unk with Some id -> id | None -> -1)
| None -> -1
in
{
vocab;
vocab_r;
merges;
cache;
dropout;
unk_token;
continuing_subword_prefix;
end_of_word_suffix;
fuse_unk;
byte_fallback;
ignore_merges;
ascii_to_id;
byte_fallback_ids;
char_to_id;
prefixed_ascii_to_id;
prefixed_char_to_id;
unk_id;
work_word = create_word 16;
work_queue = Merge_queue.create 16;
work_in_use = Atomic.make false;
}
let json_of_string s =
match Jsont_bytesrw.decode_string Jsont.json s with
| Ok v -> v
| Error e -> failwith e
let json_to_string j =
match Jsont_bytesrw.encode_string ~format:Jsont.Minify Jsont.json j with
| Ok s -> s
| Error e -> failwith e
let read_files ~vocab_file ~merges_file =
let vocab_json =
let ic = open_in vocab_file in
let content =
Fun.protect
~finally:(fun () -> close_in ic)
(fun () -> really_input_string ic (in_channel_length ic))
in
json_of_string content
in
let vocab = Hashtbl.create 1024 in
(match vocab_json with
| Jsont.Object (mems, _) ->
List.iter
(fun ((k, _), v) ->
match v with
| Jsont.Number (f, _) -> Hashtbl.add vocab k (int_of_float f)
| _ -> failwith "Invalid vocab format")
mems
| _ -> failwith "Invalid vocab.json format");
let merges =
let ic = open_in merges_file in
Fun.protect
~finally:(fun () -> close_in ic)
(fun () ->
let merges = ref [] in
(try
while true do
let line = input_line ic in
(* Skip empty lines and comment lines that start with #version *)
if
String.length line > 0
&& not (String.starts_with ~prefix:"#version" line)
then
match String.split_on_char ' ' line with
| [ a; b ] -> merges := (a, b) :: !merges
| _ -> failwith (Printf.sprintf "Invalid merge line: %s" line)
done
with End_of_file -> ());
List.rev !merges)
in
(vocab, merges)
let from_files ~vocab_file ~merges_file =
let vocab, merges = read_files ~vocab_file ~merges_file in
create ~vocab ~merges ()
let save model ~path ?name () =
let vocab_file =
match name with
| Some n -> Filename.concat path (Printf.sprintf "%s-vocab.json" n)
| None -> Filename.concat path "vocab.json"
in
let merges_file =
match name with
| Some n -> Filename.concat path (Printf.sprintf "%s-merges.txt" n)
| None -> Filename.concat path "merges.txt"
in
let vocab_items =
Hashtbl.fold (fun k v acc -> (k, v) :: acc) model.vocab []
|> List.sort (fun (_, a) (_, b) -> compare a b)
in
let vocab_json =
Jsont.Json.object'
(List.map
(fun (k, v) -> (Jsont.Json.name k, Jsont.Json.int v))
vocab_items)
in
let oc = open_out vocab_file in
Fun.protect
~finally:(fun () -> close_out oc)
(fun () -> output_string oc (json_to_string vocab_json));
let oc = open_out merges_file in
Fun.protect
~finally:(fun () -> close_out oc)
(fun () ->
output_string oc "#version: 0.2\n";
let merges_list =
Merge_map.fold
(fun key packed acc ->
let a_id = key lsr 21 in
let b_id = key land 0x1FFFFF in
let rank = merge_rank packed in
let vr = model.vocab_r in
let vr_len = Array.length vr in
if a_id >= 0 && a_id < vr_len && b_id >= 0 && b_id < vr_len then
(rank, Array.unsafe_get vr a_id, Array.unsafe_get vr b_id) :: acc
else acc)
model.merges []
|> List.sort (fun (r1, _, _) (r2, _, _) -> compare r1 r2)
in
List.iter (fun (_, a, b) -> Printf.fprintf oc "%s %s\n" a b) merges_list)
let train ~min_frequency ~vocab_size ~show_progress ~special_tokens
~limit_alphabet ~initial_alphabet ~continuing_subword_prefix
~end_of_word_suffix ~max_token_length texts existing =
let _ = (show_progress, existing) in
(* Count words from texts *)
let word_counts = Hashtbl.create 10000 in
List.iter
(fun text ->
let words = String.split_on_char ' ' text in
List.iter
(fun word ->
if String.length word > 0 then
Hashtbl.replace word_counts word
(1 + try Hashtbl.find word_counts word with Not_found -> 0))
words)
texts;
let compute_pair_counts words_copy =
let pair_counts = Hashtbl.create 10000 in
Hashtbl.iter
(fun word count ->
let chars = String.split_on_char ' ' word in
for i = 0 to List.length chars - 2 do
let a = List.nth chars i in
let b = List.nth chars (i + 1) in
let pair = (a, b) in
Hashtbl.replace pair_counts pair
(count + try Hashtbl.find pair_counts pair with Not_found -> 0)
done)
words_copy;
pair_counts
in
(* Build vocabulary *)
let vocab = Hashtbl.create 10000 in
let vocab_size_ref = ref 0 in
List.iter
(fun token ->
if not (Hashtbl.mem vocab token) then (
Hashtbl.add vocab token !vocab_size_ref;
incr vocab_size_ref))
special_tokens;
(* Build alphabet *)
let alphabet = Hashtbl.create 10000 in
Hashtbl.iter
(fun word count ->
let len = String.length word in
let buf = Buffer.create 4 in
let rec loop i =
if i >= len then ()
else
let d = String.get_utf_8_uchar word i in
let n = Uchar.utf_decode_length d in
if Uchar.utf_decode_is_valid d then (
let u = Uchar.utf_decode_uchar d in
Buffer.clear buf;
Buffer.add_utf_8_uchar buf u;
let char_str = Buffer.contents buf in
Hashtbl.replace alphabet char_str
(count + try Hashtbl.find alphabet char_str with Not_found -> 0));
loop (i + n)
in
loop 0)
word_counts;
List.iter
(fun c ->
let char_str = String.make 1 c in
Hashtbl.replace alphabet char_str max_int)
initial_alphabet;
let kept = Hashtbl.fold (fun k v acc -> (k, v) :: acc) alphabet [] in
let kept = List.sort (fun (_, v1) (_, v2) -> compare v1 v2) kept in
let to_remove =
match limit_alphabet with
| Some limit -> max 0 (List.length kept - limit)
| None -> 0
in
let kept = list_drop to_remove kept in
let kept = List.sort (fun (k1, _) (k2, _) -> compare k1 k2) kept in
let csp_str =
match continuing_subword_prefix with Some p -> p | None -> ""
in
let csp_len = String.length csp_str in
List.iter
(fun (c, _) ->
if not (Hashtbl.mem vocab c) then (
Hashtbl.add vocab c !vocab_size_ref;
incr vocab_size_ref);
if csp_len > 0 then (
let clen = String.length c in
let s = Bytes.create (csp_len + clen) in
Bytes.blit_string csp_str 0 s 0 csp_len;
Bytes.blit_string c 0 s csp_len clen;
let prefixed = Bytes.unsafe_to_string s in
if not (Hashtbl.mem vocab prefixed) then (
Hashtbl.add vocab prefixed !vocab_size_ref;
incr vocab_size_ref)))
kept;
(* Learn merges *)
let merges = ref [] in
let words_copy = ref (Hashtbl.create (Hashtbl.length word_counts)) in
Hashtbl.iter
(fun word count ->
let len = String.length word in
let chars = ref [] in
let buf = Buffer.create 8 in
let is_first = ref true in
let rec loop i =
if i >= len then ()
else
let d = String.get_utf_8_uchar word i in
let n = Uchar.utf_decode_length d in
if Uchar.utf_decode_is_valid d then (
let u = Uchar.utf_decode_uchar d in
Buffer.clear buf;
if csp_len > 0 && not !is_first then Buffer.add_string buf csp_str;
Buffer.add_utf_8_uchar buf u;
is_first := false;
chars := Buffer.contents buf :: !chars);
loop (i + n)
in
loop 0;
let separated = String.concat " " (List.rev !chars) in
Hashtbl.add !words_copy separated count)
word_counts;
while !vocab_size_ref < vocab_size do
let pair_counts = compute_pair_counts !words_copy in
let best_pair = ref None in
let best_count = ref (-1) in
let best_pair_tie = ref ("", "") in
Hashtbl.iter
(fun pair count ->
if count > !best_count then (
best_count := count;
best_pair := Some pair;
best_pair_tie := pair)
else if count = !best_count && compare pair !best_pair_tie < 0 then (
(* Deterministic tie-break: prefer the lexicographically smallest pair,
regardless of hash-table iteration order. *)
best_pair := Some pair;
best_pair_tie := pair))
pair_counts;
match !best_pair with
| None -> vocab_size_ref := vocab_size
| Some (a, b) ->
if !best_count < min_frequency then vocab_size_ref := vocab_size
else
let blen = String.length b in
let new_token =
if
csp_len > 0 && blen > csp_len
&& String.starts_with ~prefix:csp_str b
then (
let alen = String.length a in
let brest = blen - csp_len in
let s = Bytes.create (alen + brest) in
Bytes.blit_string a 0 s 0 alen;
Bytes.blit_string b csp_len s alen brest;
Bytes.unsafe_to_string s)
else a ^ b
in
let skip =
match max_token_length with
| Some l when String.length new_token > l -> true
| _ -> false
in
if not skip then (
if not (Hashtbl.mem vocab new_token) then (
Hashtbl.add vocab new_token !vocab_size_ref;
incr vocab_size_ref);
merges := (a, b) :: !merges;
let new_words = Hashtbl.create (Hashtbl.length !words_copy) in
let pat = a ^ " " ^ b in
let pat_len = String.length pat in
Hashtbl.iter
(fun word count ->
let wlen = String.length word in
if wlen < pat_len then Hashtbl.add new_words word count
else
let buf = Buffer.create wlen in
let pos = ref 0 in
let changed = ref false in
while !pos <= wlen - pat_len do
let at_boundary =
(!pos = 0
|| Char.equal (String.unsafe_get word (!pos - 1)) ' ')
&& (!pos + pat_len = wlen
|| Char.equal
(String.unsafe_get word (!pos + pat_len))
' ')
in
if at_boundary && String.sub word !pos pat_len = pat then (
Buffer.add_string buf new_token;
pos := !pos + pat_len;
changed := true)
else (
Buffer.add_char buf (String.unsafe_get word !pos);
incr pos)
done;
if !changed then (
Buffer.add_substring buf word !pos (wlen - !pos);
Hashtbl.add new_words (Buffer.contents buf) count)
else Hashtbl.add new_words word count)
!words_copy;
words_copy := new_words)
done;
let trained_model =
create ~vocab ~merges:(List.rev !merges) ?continuing_subword_prefix
?end_of_word_suffix ()
in
(trained_model, special_tokens)
================================================
FILE: packages/brot/lib/bpe.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** BPE (Byte Pair Encoding) tokenization model.
{b Internal module.} Iteratively merges the most frequent adjacent character
pairs to build a subword vocabulary. Used by GPT-2, GPT-3, and RoBERTa.
A word is first split into characters, then merge rules are applied in
priority order (earlier rules have higher priority). Merging continues until
no more rules apply.
Tokenized words are cached in a direct-mapped bounded cache for amortized
performance. *)
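The merge procedure described above can be sketched independently of this module's packed-array representation. The following is a minimal, illustrative model of the algorithm only — the list-based symbols and helper names are invented for the sketch and are not part of this interface:

```ocaml
(* Sketch: repeatedly apply the highest-priority merge rule that matches
   an adjacent pair, until no rule applies. *)
let apply_merges merges word =
  let try_rule syms (a, b) =
    let rec scan = function
      | x :: y :: rest when x = a && y = b -> Some ((a ^ b) :: rest)
      | x :: rest -> Option.map (fun r -> x :: r) (scan rest)
      | [] -> None
    in
    scan syms
  in
  let rec step syms =
    (* [merges] is in priority order, so [find_map] picks the
       highest-priority applicable rule each round. *)
    match List.find_map (try_rule syms) merges with
    | Some syms' -> step syms'
    | None -> syms
  in
  step word

let () =
  let merges = [ ("l", "o"); ("lo", "w"); ("e", "r") ] in
  assert (apply_merges merges [ "l"; "o"; "w"; "e"; "r" ] = [ "low"; "er" ])
```

The real implementation reaches the same result with a priority queue over packed merge keys instead of rescanning the rule list.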
type t
(** The type for BPE models. Internally mutable due to the merge cache. *)
type vocab = (string, int) Hashtbl.t
(** The type for vocabularies mapping token strings to IDs. *)
type merges = (string * string) list
(** The type for merge rules in priority order (earlier rules have higher
priority). *)
(** {1:creation Creation} *)
val create :
vocab:vocab ->
merges:merges ->
?cache_capacity:int ->
?dropout:float ->
?unk_token:string ->
?continuing_subword_prefix:string ->
?end_of_word_suffix:string ->
?fuse_unk:bool ->
?byte_fallback:bool ->
?ignore_merges:bool ->
unit ->
t
(** [create ~vocab ~merges ()] is a BPE model.
- [cache_capacity] is the number of slots in the direct-mapped word cache.
Defaults to [10000]. Set to [0] to disable caching. Words of 4096 bytes or
more bypass the cache.
- [dropout] is the probability of randomly skipping a merge during
tokenization (BPE-dropout regularization). Defaults to [0.] (no dropout).
- [unk_token] is emitted for characters not in [vocab] (when
{!byte_fallback} is off). No default.
- [continuing_subword_prefix] is prepended to non-initial subwords. No
default.
- [end_of_word_suffix] is appended to the final subword of each word. No
default.
- [fuse_unk], when [true], merges consecutive unknown characters into a
single [unk_token] instead of emitting one per character. Defaults to
[false].
- [byte_fallback], when [true], falls back to byte-level tokens (e.g.
["<0xFF>"]) for characters not in [vocab] instead of emitting [unk_token].
Defaults to [false].
- [ignore_merges], when [true], bypasses the tokenized-word cache, so each
word is tokenized afresh on every call. Defaults to [false]. *)
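The ["<0xFF>"]-style spelling mentioned for [byte_fallback] follows the usual HuggingFace convention of one token per byte value; as a self-contained sketch (the helper name is illustrative, not part of this interface):

```ocaml
(* Illustrative: the "<0xHH>" spelling of byte-fallback tokens used when
   byte_fallback is enabled — one token per possible byte value. *)
let byte_fallback_token b = Printf.sprintf "<0x%02X>" b

let () =
  assert (byte_fallback_token 0xFF = "<0xFF>");
  assert (byte_fallback_token 10 = "<0x0A>")
```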
val from_files : vocab_file:string -> merges_file:string -> t
(** [from_files ~vocab_file ~merges_file] loads a BPE model from
HuggingFace-format files.
- [vocab_file] is a JSON object mapping token strings to integer IDs.
- [merges_file] is a text file with one space-separated merge pair per line.
An optional [#version:] header line is skipped. *)
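The [merges_file] format can be sketched as follows; this is an illustrative parser over in-memory lines (the helper name is invented), mirroring the behavior described above rather than the loader's actual I/O code:

```ocaml
(* Sketch: parse merges.txt lines — one "A B" pair per line, skipping
   blank lines and an optional "#version:" header. *)
let parse_merges lines =
  List.filter_map
    (fun line ->
      if String.length line = 0 || String.starts_with ~prefix:"#version" line
      then None
      else
        match String.split_on_char ' ' line with
        | [ a; b ] -> Some (a, b)
        | _ -> invalid_arg ("bad merge line: " ^ line))
    lines

let () =
  assert (
    parse_merges [ "#version: 0.2"; ""; "l o"; "lo w" ]
    = [ ("l", "o"); ("lo", "w") ])
```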
(** {1:tokenization Tokenization} *)
type token = { id : int; value : string; offsets : int * int }
(** The type for tokens. [id] is the vocabulary index, [value] the string
content, and [offsets] the [(start, stop)] byte span in the source text. *)
val tokenize : t -> string -> token list
(** [tokenize t s] is the BPE tokenization of [s]. *)
val tokenize_ids : t -> string -> int array
(** [tokenize_ids t s] is like {!tokenize} but returns only token IDs. *)
val tokenize_encoding : t -> string -> type_id:int -> Encoding.t
(** [tokenize_encoding t s ~type_id] tokenizes [s] and builds an {!Encoding.t}
directly, avoiding intermediate list allocation. *)
(** {1:vocabulary Vocabulary} *)
val token_to_id : t -> string -> int option
(** [token_to_id t tok] is the ID of [tok] in the vocabulary. *)
val id_to_token : t -> int -> string option
(** [id_to_token t id] is the token string for [id]. *)
val get_vocab : t -> (string * int) list
(** [get_vocab t] is the vocabulary as [(token, id)] pairs. *)
val get_vocab_size : t -> int
(** [get_vocab_size t] is the number of tokens in the vocabulary. *)
val get_unk_token : t -> string option
(** [get_unk_token t] is the unknown token, if configured. *)
val get_continuing_subword_prefix : t -> string option
(** [get_continuing_subword_prefix t] is the subword prefix, if configured (e.g.
["##"]). *)
val get_end_of_word_suffix : t -> string option
(** [get_end_of_word_suffix t] is the word-end suffix, if configured (e.g.
["</w>"]). *)
val get_merges : t -> (string * string) list
(** [get_merges t] is the merge rules in priority order. *)
(** {1:serialization Serialization} *)
val save : t -> path:string -> ?name:string -> unit -> unit
(** [save t ~path ()] writes the model to [path] as two files (each prefixed
with [name-] when [name] is given):
- [vocab.json]: a JSON object mapping token strings to IDs.
- [merges.txt]: merge pairs, one per line, with a [#version: 0.2] header. *)
(** {1:training Training} *)
val train :
min_frequency:int ->
vocab_size:int ->
show_progress:bool ->
special_tokens:string list ->
limit_alphabet:int option ->
initial_alphabet:char list ->
continuing_subword_prefix:string option ->
end_of_word_suffix:string option ->
max_token_length:int option ->
string list ->
t option ->
t * string list
(** [train ~min_frequency ~vocab_size ~show_progress ~special_tokens
~limit_alphabet ~initial_alphabet ~continuing_subword_prefix
~end_of_word_suffix ~max_token_length texts init] learns BPE merges from
[texts].
The algorithm counts word frequencies, builds an initial character alphabet,
then iteratively finds and merges the highest-frequency adjacent pair until
[vocab_size] is reached or pair frequency drops below [min_frequency].
- [min_frequency] is the minimum pair frequency to merge.
- [vocab_size] is the target vocabulary size.
- [show_progress] enables progress output on [stderr].
- [special_tokens] are added to the vocabulary first.
- [limit_alphabet] caps the number of distinct initial characters kept.
- [initial_alphabet] seeds the character set.
- [continuing_subword_prefix] is set on the resulting model.
- [end_of_word_suffix] is set on the resulting model.
- [max_token_length] limits the byte length of merged tokens.
- [init] is accepted for interface compatibility but is currently ignored.
Returns [(model, special_tokens)]. *)
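One step of the training loop described above — count adjacent pairs weighted by word frequency, then select the most frequent — can be sketched as follows. The representation (symbol lists with counts) and the helper name are invented for the sketch; ties are broken toward the lexicographically smallest pair to keep the result deterministic:

```ocaml
(* Sketch of one BPE training step: count adjacent symbol pairs across the
   corpus, weighted by word frequency, and pick the most frequent pair. *)
let best_pair word_counts =
  let pair_counts = Hashtbl.create 16 in
  List.iter
    (fun (syms, count) ->
      let rec go = function
        | a :: (b :: _ as rest) ->
            let cur =
              try Hashtbl.find pair_counts (a, b) with Not_found -> 0
            in
            Hashtbl.replace pair_counts (a, b) (cur + count);
            go rest
        | _ -> ()
      in
      go syms)
    word_counts;
  Hashtbl.fold
    (fun pair count best ->
      match best with
      | Some (bc, bp) when count < bc || (count = bc && compare pair bp >= 0)
        ->
          best
      | _ -> Some (count, pair))
    pair_counts None

let () =
  (* "low" twice plus "lower" once: ("l","o") and ("o","w") both occur 3
     times; the lexicographically smaller pair ("l","o") wins the tie. *)
  let counts = [ ([ "l"; "o"; "w" ], 2); ([ "l"; "o"; "w"; "e"; "r" ], 1) ] in
  assert (best_pair counts = Some (3, ("l", "o")))
```

The real trainer then rewrites every word containing the winning pair and repeats until [vocab_size] is reached or the best count drops below [min_frequency].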
================================================
FILE: packages/brot/lib/brot.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
module Normalizer = Normalizer
module Pre_tokenizer = Pre_tokenizer
module Post_processor = Post_processor
module Decoder = Decoder
module Encoding = Encoding
let strf = Printf.sprintf
(* Error messages *)
let err_pair_no_post = "pair sequences require a configured post-processor"
let err_no_pad_token = "padding requested but no pad token configured"
let err_pad_not_in_vocab tok = strf "pad token '%s' not in vocabulary" tok
let err_add_tokens = "only supported for word-level tokenizers"
let err_export_tiktoken = "only supported for BPE models"
let err_infer_type = "unable to infer model type from JSON"
(* Types *)
type direction = [ `Left | `Right ]
type special = {
token : string;
single_word : bool;
lstrip : bool;
rstrip : bool;
normalized : bool;
}
type pad_length = [ `Batch_longest | `Fixed of int | `To_multiple of int ]
type padding = {
length : pad_length;
direction : direction;
pad_id : int option;
pad_type_id : int option;
pad_token : string option;
}
type truncation = { max_length : int; direction : direction }
type data = [ `Files of string list | `Seq of string Seq.t ]
type sequence = { text : string; pair : string option }
type algorithm =
| Alg_bpe of Bpe.t
| Alg_wordpiece of Wordpiece.t
| Alg_wordlevel of Word_level.t
| Alg_unigram of Unigram.t
| Alg_chars of Chars.t
type t = {
algorithm : algorithm;
normalizer : Normalizer.t option;
pre_tokenizer : Pre_tokenizer.t option;
post_processor : Post_processor.t option;
decoder : Decoder.t option;
specials : special list;
special_lookup : (string, unit) Hashtbl.t;
bos_token : string option;
eos_token : string option;
pad_token : string option;
pad_id : int option;
pad_type_id : int;
unk_token : string option;
}
let special ?(single_word = false) ?(lstrip = false) ?(rstrip = false)
?(normalized = false) token =
{ token; single_word; lstrip; rstrip; normalized }
let padding ?(direction = `Right) ?pad_id ?pad_type_id ?pad_token length =
{ length; direction; pad_id; pad_type_id; pad_token }
let truncation ?(direction = `Right) max_length = { max_length; direction }
(* Algorithm dispatch *)
let alg_add_tokens algorithm tokens =
match algorithm with
| Alg_wordlevel model ->
ignore (Word_level.add_tokens model tokens);
algorithm
| Alg_bpe _ | Alg_wordpiece _ | Alg_unigram _ | Alg_chars _ -> algorithm
let alg_token_to_id algorithm token =
match algorithm with
| Alg_bpe m -> Bpe.token_to_id m token
| Alg_wordpiece m -> Wordpiece.token_to_id m token
| Alg_wordlevel m -> Word_level.token_to_id m token
| Alg_unigram m -> Unigram.token_to_id m token
| Alg_chars m -> Chars.token_to_id m token
let alg_id_to_token algorithm id =
match algorithm with
| Alg_bpe m -> Bpe.id_to_token m id
| Alg_wordpiece m -> Wordpiece.id_to_token m id
| Alg_wordlevel m -> Word_level.id_to_token m id
| Alg_unigram m -> Unigram.id_to_token m id
| Alg_chars m -> Chars.id_to_token m id
let alg_vocab algorithm =
match algorithm with
| Alg_bpe m -> Bpe.get_vocab m
| Alg_wordpiece m -> Wordpiece.get_vocab m
| Alg_wordlevel m -> Word_level.get_vocab m
| Alg_unigram m ->
Unigram.get_vocab m |> List.mapi (fun i (token, _) -> (token, i))
| Alg_chars m -> Chars.get_vocab m
let alg_vocab_size algorithm =
match algorithm with
| Alg_bpe m -> Bpe.get_vocab_size m
| Alg_wordpiece m -> Wordpiece.get_vocab_size m
| Alg_wordlevel m -> Word_level.get_vocab_size m
| Alg_unigram m -> Unigram.get_vocab_size m
| Alg_chars m -> Chars.get_vocab_size m
let alg_save algorithm ~folder ?prefix () =
match algorithm with
| Alg_bpe m ->
Bpe.save m ~path:folder ?name:prefix ();
let name base ext =
match prefix with
| Some n -> Filename.concat folder (strf "%s-%s.%s" n base ext)
| None -> Filename.concat folder (strf "%s.%s" base ext)
in
[ name "vocab" "json"; name "merges" "txt" ]
| Alg_wordpiece m -> [ Wordpiece.save m ~path:folder ?name:prefix () ]
| Alg_wordlevel m -> Word_level.save m ~folder ()
| Alg_unigram m -> Unigram.save m ~folder ()
| Alg_chars m -> Chars.save m ~folder ()
let alg_tokenize algorithm text =
match algorithm with
| Alg_bpe m ->
Bpe.tokenize m text
|> List.map (fun (tok : Bpe.token) -> (tok.id, tok.value, tok.offsets))
| Alg_wordpiece m ->
Wordpiece.tokenize m text
|> List.map (fun (tok : Wordpiece.token) ->
(tok.id, tok.value, tok.offsets))
| Alg_wordlevel m -> Word_level.tokenize m text
| Alg_unigram m -> Unigram.tokenize m text
| Alg_chars m -> Chars.tokenize m text
let alg_tokenize_ids algorithm text =
match algorithm with
| Alg_bpe m -> Bpe.tokenize_ids m text
| Alg_wordpiece m -> Wordpiece.tokenize_ids m text
| Alg_wordlevel m -> Word_level.tokenize_ids m text
| Alg_unigram m ->
Unigram.tokenize m text
|> List.map (fun (id, _, _) -> id)
|> Array.of_list
| Alg_chars m ->
Chars.tokenize m text |> List.map (fun (id, _, _) -> id) |> Array.of_list
let alg_name = function
| Alg_bpe _ -> "BPE"
| Alg_wordpiece _ -> "WordPiece"
| Alg_wordlevel _ -> "WordLevel"
| Alg_unigram _ -> "Unigram"
| Alg_chars _ -> "Chars"
let vocab_to_hashtbl vocab =
let tbl = Hashtbl.create (List.length vocab) in
List.iter (fun (token, id) -> Hashtbl.add tbl token id) vocab;
tbl
(* Special tokens *)
let dedup_by key items =
let seen = Hashtbl.create 16 in
let acc = ref [] in
List.iter
(fun item ->
let k = key item in
if not (Hashtbl.mem seen k) then (
Hashtbl.replace seen k ();
acc := item :: !acc))
items;
List.rev !acc
let collect_unique_tokens specials ~bos_token ~eos_token ~pad_token ~unk_token =
let items =
List.map (fun (s : special) -> s.token) specials
@ List.filter_map Fun.id [ bos_token; eos_token; pad_token; unk_token ]
in
dedup_by Fun.id items
let build_special_lookup specials ~bos_token ~eos_token ~pad_token ~unk_token =
let tokens =
collect_unique_tokens specials ~bos_token ~eos_token ~pad_token ~unk_token
in
let table = Hashtbl.create (List.length tokens) in
List.iter (fun t -> Hashtbl.replace table t ()) tokens;
table
(* Construction *)
let create ?normalizer ?pre ?post ?decoder ?(specials = []) ?bos_token
?eos_token ?pad_token ?unk_token algorithm =
let all_tokens =
collect_unique_tokens specials ~bos_token ~eos_token ~pad_token ~unk_token
in
let algorithm = alg_add_tokens algorithm all_tokens in
let special_lookup =
build_special_lookup specials ~bos_token ~eos_token ~pad_token ~unk_token
in
let pad_id = Option.bind pad_token (alg_token_to_id algorithm) in
{
algorithm;
normalizer;
pre_tokenizer = pre;
post_processor = post;
decoder;
specials;
special_lookup;
bos_token;
eos_token;
pad_token;
pad_id;
pad_type_id = 0;
unk_token;
}
(* Accessors *)
let normalizer t = t.normalizer
let pre_tokenizer t = t.pre_tokenizer
let post_processor t = t.post_processor
let decoder t = t.decoder
let specials t = t.specials
let bos_token t = t.bos_token
let eos_token t = t.eos_token
let pad_token t = t.pad_token
let unk_token t = t.unk_token
let vocab t = alg_vocab t.algorithm
let vocab_size t = alg_vocab_size t.algorithm
let token_to_id t token = alg_token_to_id t.algorithm token
let id_to_token t id = alg_id_to_token t.algorithm id
(* Algorithm constructors *)
let bpe ?normalizer ?pre ?post ?decoder ?specials ?bos_token ?eos_token
?pad_token ?unk_token ?vocab ?merges ?cache_capacity ?dropout
?continuing_subword_prefix ?end_of_word_suffix ?fuse_unk ?byte_fallback
?ignore_merges () =
let vocab_tbl =
match vocab with None -> Hashtbl.create 100 | Some v -> vocab_to_hashtbl v
in
let algorithm =
Alg_bpe
(Bpe.create ~vocab:vocab_tbl
~merges:(Option.value merges ~default:[])
?cache_capacity ?dropout ?unk_token ?continuing_subword_prefix
?end_of_word_suffix ?fuse_unk ?byte_fallback ?ignore_merges ())
in
create ?normalizer ?pre ?post ?decoder ?specials ?bos_token ?eos_token
?pad_token ?unk_token algorithm
let wordpiece ?normalizer ?pre ?post ?decoder ?specials ?bos_token ?eos_token
?pad_token ?unk_token ?vocab ?continuing_subword_prefix
?max_input_chars_per_word () =
let vocab_tbl =
match vocab with None -> Hashtbl.create 100 | Some v -> vocab_to_hashtbl v
in
let algorithm =
Alg_wordpiece
(Wordpiece.create ~vocab:vocab_tbl ?unk_token ?continuing_subword_prefix
?max_input_chars_per_word ())
in
create ?normalizer ?pre ?post ?decoder ?specials ?bos_token ?eos_token
?pad_token ?unk_token algorithm
let word_level ?normalizer ?pre ?post ?decoder ?specials ?bos_token ?eos_token
?pad_token ?unk_token ?vocab () =
let pre =
match pre with Some _ -> pre | None -> Some (Pre_tokenizer.whitespace ())
in
let algorithm = Alg_wordlevel (Word_level.create ?vocab ?unk_token ()) in
create ?normalizer ?pre ?post ?decoder ?specials ?bos_token ?eos_token
?pad_token ?unk_token algorithm
let unigram ?normalizer ?pre ?post ?decoder ?specials ?bos_token ?eos_token
?pad_token ?unk_token ?vocab () =
let algorithm =
Alg_unigram (Unigram.create (Option.value vocab ~default:[]))
in
create ?normalizer ?pre ?post ?decoder ?specials ?bos_token ?eos_token
?pad_token ?unk_token algorithm
let chars ?normalizer ?pre ?post ?decoder ?specials ?bos_token ?eos_token
?pad_token ?unk_token () =
let algorithm = Alg_chars (Chars.create ()) in
create ?normalizer ?pre ?post ?decoder ?specials ?bos_token ?eos_token
?pad_token ?unk_token algorithm
let from_model_file ~vocab ?merges ?normalizer ?pre ?post ?decoder ?specials
?bos_token ?eos_token ?pad_token ?unk_token () =
let algorithm =
match merges with
| Some merges_file ->
Alg_bpe (Bpe.from_files ~vocab_file:vocab ~merges_file)
| None -> Alg_wordpiece (Wordpiece.from_file ~vocab_file:vocab)
in
create ?normalizer ?pre ?post ?decoder ?specials ?bos_token ?eos_token
?pad_token ?unk_token algorithm
let add_tokens t tokens =
match t.algorithm with
| Alg_wordlevel model ->
let vocab = Word_level.get_vocab model in
let new_model = Word_level.create ~vocab ?unk_token:t.unk_token () in
ignore (Word_level.add_tokens new_model tokens);
{ t with algorithm = Alg_wordlevel new_model }
| Alg_bpe _ | Alg_wordpiece _ | Alg_unigram _ | Alg_chars _ ->
invalid_arg err_add_tokens
(* Encoding *)
let encode_text t text =
let normalized =
match t.normalizer with Some n -> Normalizer.apply n text | None -> text
in
let pre_tokens =
match t.pre_tokenizer with
| Some pre -> Pre_tokenizer.pre_tokenize pre normalized
| None -> [ (normalized, (0, String.length normalized)) ]
in
match (t.algorithm, pre_tokens) with
| Alg_bpe m, [ (fragment, _) ] -> Bpe.tokenize_encoding m fragment ~type_id:0
| Alg_wordpiece m, _ ->
Wordpiece.tokenize_spans_encoding m pre_tokens ~type_id:0
| _ ->
pre_tokens
|> List.concat_map (fun (fragment, _) ->
alg_tokenize t.algorithm fragment)
|> Encoding.from_tokens ~type_id:0
let post_process t ~add_special primary pair =
match t.post_processor with
| None ->
if Option.is_some pair then invalid_arg err_pair_no_post else primary
| Some processor ->
Post_processor.process processor ?pair primary
~add_special_tokens:add_special
let encode_single t ~add_special_tokens ~truncation seq =
let primary = encode_text t seq.text in
let pair = Option.map (encode_text t) seq.pair in
let processed = post_process t ~add_special:add_special_tokens primary pair in
match truncation with
| None -> processed
| Some { max_length; direction } ->
Encoding.truncate processed ~max_length ~stride:0 ~direction
(* Padding *)
let resolve_pad t (cfg : padding) =
let token =
match cfg.pad_token with Some _ as v -> v | None -> t.pad_token
in
let token =
match token with
| Some token -> token
| None -> invalid_arg err_no_pad_token
in
let id = match cfg.pad_id with Some _ as v -> v | None -> t.pad_id in
let id =
match id with
| Some id -> id
| None -> (
match alg_token_to_id t.algorithm token with
| Some id -> id
| None -> invalid_arg (err_pad_not_in_vocab token))
in
let type_id = Option.value cfg.pad_type_id ~default:t.pad_type_id in
(token, id, type_id)
let round_up_to_multiple n m = if n mod m = 0 then n else (n + m - 1) / m * m
let apply_padding t encodings = function
| None -> encodings
| Some cfg -> (
let pad_token, pad_id, pad_type_id = resolve_pad t cfg in
let direction = cfg.direction in
let pad enc target =
if Encoding.length enc >= target then enc
else
Encoding.pad enc ~target_length:target ~pad_id ~pad_type_id ~pad_token
~direction
in
match cfg.length with
| `Fixed n -> List.map (fun enc -> pad enc n) encodings
| `Batch_longest ->
let max_len =
List.fold_left
(fun acc enc -> max acc (Encoding.length enc))
0 encodings
in
List.map (fun enc -> pad enc max_len) encodings
| `To_multiple m ->
if m <= 0 then encodings
else
List.map
(fun enc ->
pad enc (round_up_to_multiple (Encoding.length enc) m))
encodings)
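The strategies handled above can be exercised through the public [padding] constructor; a hedged sketch (the token lengths are hypothetical):

```ocaml
(* Round each encoding in a batch up to the next multiple of 8 tokens,
   padding on the right with the tokenizer's configured pad token. *)
let cfg = padding ~direction:`Right (`To_multiple 8)
(* Under this cfg, encodings of length 5, 9 and 16 are padded to
   8, 16 and 16 respectively (see round_up_to_multiple above). *)
```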
(* Parallel batch encoding *)
let encode_parallel t sequences ~add_special_tokens ~truncation =
  let arr = Array.of_list sequences in
  let n = Array.length arr in
  (* Encode the first sequence eagerly: it both fills slot 0 and provides a
     placeholder value for the remaining slots, which are overwritten below. *)
  let results =
    Array.make n (encode_single t ~add_special_tokens ~truncation arr.(0))
  in
let num_domains = min n (Domain.recommended_domain_count ()) in
if num_domains <= 1 then
for i = 1 to n - 1 do
results.(i) <- encode_single t ~add_special_tokens ~truncation arr.(i)
done
else begin
let chunk_size = n / num_domains in
let remainder = n mod num_domains in
let domains =
Array.init (num_domains - 1) (fun d ->
let start = ((d + 1) * chunk_size) + min (d + 1) remainder in
let len = chunk_size + if d + 1 < remainder then 1 else 0 in
Domain.spawn (fun () ->
for i = start to start + len - 1 do
results.(i) <-
encode_single t ~add_special_tokens ~truncation arr.(i)
done))
in
let main_len = chunk_size + if 0 < remainder then 1 else 0 in
for i = 1 to main_len - 1 do
results.(i) <- encode_single t ~add_special_tokens ~truncation arr.(i)
done;
Array.iter Domain.join domains
end;
Array.to_list results
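The work split above gives each chunk [n / k] items plus one extra while the chunk index is below [n mod k]; the start offsets follow from the same arithmetic. A self-contained sketch of the partitioning (not part of the library's API):

```ocaml
(* Item counts and start offsets for n items over k chunks, matching the
   start/len computation used by encode_parallel. *)
let chunks n k =
  let size = n / k and rem = n mod k in
  List.init k (fun d ->
      let start = (d * size) + min d rem in
      let len = size + if d < rem then 1 else 0 in
      (start, len))
(* chunks 10 4 = [ (0, 3); (3, 3); (6, 2); (8, 2) ] *)
```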
let encode_sequences t sequences ~add_special_tokens ~padding ~truncation =
let n = List.length sequences in
let raw =
if n >= 4 then encode_parallel t sequences ~add_special_tokens ~truncation
else List.map (encode_single t ~add_special_tokens ~truncation) sequences
in
apply_padding t raw padding
let encode t ?pair ?(add_special_tokens = true) ?padding ?truncation text =
match
encode_sequences t
[ { text; pair } ]
~add_special_tokens ~padding ~truncation
with
| [ encoding ] -> encoding
| _ -> assert false
let encode_batch t ?(add_special_tokens = true) ?padding ?truncation = function
| [] -> []
| texts ->
let sequences = List.map (fun text -> { text; pair = None }) texts in
encode_sequences t sequences ~add_special_tokens ~padding ~truncation
let encode_pairs_batch t ?(add_special_tokens = true) ?padding ?truncation =
function
| [] -> []
| pairs ->
let sequences =
List.map (fun (text, pair) -> { text; pair = Some pair }) pairs
in
encode_sequences t sequences ~add_special_tokens ~padding ~truncation
let encode_ids t ?pair ?add_special_tokens ?padding ?truncation text =
let use_fast_path =
Option.is_none pair
&& (add_special_tokens = None || add_special_tokens = Some false)
&& Option.is_none padding && Option.is_none truncation
&& Option.is_none t.post_processor
in
if not use_fast_path then
Encoding.ids (encode t ?pair ?add_special_tokens ?padding ?truncation text)
else
let normalized =
match t.normalizer with Some n -> Normalizer.apply n text | None -> text
in
let pre_tokens =
match t.pre_tokenizer with
| Some pre -> Pre_tokenizer.pre_tokenize pre normalized
| None -> [ (normalized, (0, String.length normalized)) ]
in
let id_arrays =
List.map
(fun (fragment, _) -> alg_tokenize_ids t.algorithm fragment)
pre_tokens
in
let total_len =
List.fold_left (fun acc a -> acc + Array.length a) 0 id_arrays
in
let result = Array.make total_len 0 in
let pos = ref 0 in
List.iter
(fun a ->
let len = Array.length a in
Array.blit a 0 result !pos len;
pos := !pos + len)
id_arrays;
result
(* Decoding *)
let decode t ?(skip_special_tokens = false) ids =
let tokens =
Array.to_list ids
|> List.filter_map (fun id ->
match alg_id_to_token t.algorithm id with
| None -> None
| Some token
when skip_special_tokens && Hashtbl.mem t.special_lookup token ->
None
| Some token -> Some token)
in
match t.decoder with
| Some decoder -> Decoder.decode decoder tokens
| None -> (
match t.algorithm with
| Alg_wordlevel _ -> String.concat " " tokens
| _ -> String.concat "" tokens)
let decode_batch t ?(skip_special_tokens = false) id_lists =
List.map (decode t ~skip_special_tokens) id_lists
(* Training *)
let special_tokens_for_training init specials =
let items =
(match specials with
| Some sl -> List.map (fun (s : special) -> s.token) sl
| None -> [])
@
match init with
| Some tok -> List.map (fun (s : special) -> s.token) tok.specials
| None -> []
in
dedup_by Fun.id items
let merge_specials_from_training ~user_specials ~trained_tokens =
let items =
(match user_specials with Some sl -> sl | None -> [])
@ List.map special trained_tokens
in
dedup_by (fun (s : special) -> s.token) items
let data_to_strings = function
| `Files files ->
let lines = ref [] in
List.iter
(fun file ->
let ic = open_in file in
(try
while true do
lines := input_line ic :: !lines
done
with End_of_file -> ());
close_in ic)
files;
List.rev !lines
| `Seq seq -> List.of_seq seq
(* Keep the first byte of each user-supplied alphabet entry; empty entries
   degrade to a space. *)
let initial_alphabet_of strs =
  List.map (fun s -> if String.length s > 0 then s.[0] else ' ') strs
let train_bpe ?init ?normalizer ?pre ?post ?decoder ?specials ?bos_token
?eos_token ?pad_token ?unk_token ?(vocab_size = 30000) ?(min_frequency = 0)
?limit_alphabet ?initial_alphabet ?continuing_subword_prefix
?end_of_word_suffix ?(show_progress = true) ?max_token_length data =
let special_tokens = special_tokens_for_training init specials in
let initial_alphabet =
Option.value initial_alphabet ~default:[] |> initial_alphabet_of
in
let limit_alphabet = Some (Option.value limit_alphabet ~default:1000) in
let texts = data_to_strings data in
let existing_bpe =
Option.bind init (fun t ->
match t.algorithm with Alg_bpe m -> Some m | _ -> None)
in
let trained_model, result_specials =
Bpe.train ~min_frequency ~vocab_size ~show_progress ~special_tokens
~limit_alphabet ~initial_alphabet ~continuing_subword_prefix
~end_of_word_suffix ~max_token_length texts existing_bpe
in
let all_specials =
merge_specials_from_training ~user_specials:specials
~trained_tokens:result_specials
in
create ?normalizer ?pre ?post ?decoder ~specials:all_specials ?bos_token
?eos_token ?pad_token ?unk_token (Alg_bpe trained_model)
let train_wordpiece ?init ?normalizer ?pre ?post ?decoder ?specials ?bos_token
?eos_token ?pad_token ?unk_token ?(vocab_size = 30000) ?(min_frequency = 0)
?limit_alphabet ?initial_alphabet ?(continuing_subword_prefix = "##")
?end_of_word_suffix ?(show_progress = true) data =
let special_tokens = special_tokens_for_training init specials in
let initial_alphabet =
Option.value initial_alphabet ~default:[] |> initial_alphabet_of
in
let limit_alphabet = Some (Option.value limit_alphabet ~default:1000) in
let texts = data_to_strings data in
let existing_wp =
Option.bind init (fun t ->
match t.algorithm with Alg_wordpiece m -> Some m | _ -> None)
in
let trained_model, result_specials =
Wordpiece.train ~min_frequency ~vocab_size ~show_progress ~special_tokens
~limit_alphabet ~initial_alphabet ~continuing_subword_prefix
~end_of_word_suffix texts existing_wp
in
let all_specials =
merge_specials_from_training ~user_specials:specials
~trained_tokens:result_specials
in
create ?normalizer ?pre ?post ?decoder ~specials:all_specials ?bos_token
?eos_token ?pad_token ?unk_token (Alg_wordpiece trained_model)
let train_wordlevel ?init ?normalizer ?pre ?post ?decoder ?specials ?bos_token
?eos_token ?pad_token ?unk_token ?(vocab_size = 30000) ?(min_frequency = 0)
?(show_progress = true) data =
let special_tokens = special_tokens_for_training init specials in
let texts = data_to_strings data in
let existing_wl =
Option.bind init (fun t ->
match t.algorithm with Alg_wordlevel m -> Some m | _ -> None)
in
let trained_model, result_specials =
Word_level.train ~vocab_size ~min_frequency ~show_progress ~special_tokens
texts existing_wl
in
let all_specials =
merge_specials_from_training ~user_specials:specials
~trained_tokens:result_specials
in
create ?normalizer ?pre ?post ?decoder ~specials:all_specials ?bos_token
?eos_token ?pad_token ?unk_token (Alg_wordlevel trained_model)
let train_unigram ?init ?normalizer ?pre ?post ?decoder ?specials ?bos_token
?eos_token ?pad_token ?unk_token ?(vocab_size = 8000)
?(show_progress = true) ?(shrinking_factor = 0.75) ?(max_piece_length = 16)
?(n_sub_iterations = 2) data =
let special_tokens = special_tokens_for_training init specials in
let texts = data_to_strings data in
let existing_ug =
Option.bind init (fun t ->
match t.algorithm with Alg_unigram m -> Some m | _ -> None)
in
let trained_model, result_specials =
Unigram.train ~vocab_size ~show_progress ~special_tokens ~shrinking_factor
~unk_token ~max_piece_length ~n_sub_iterations texts existing_ug
in
let all_specials =
merge_specials_from_training ~user_specials:specials
~trained_tokens:result_specials
in
create ?normalizer ?pre ?post ?decoder ~specials:all_specials ?bos_token
?eos_token ?pad_token ?unk_token (Alg_unigram trained_model)
(* JSON serialization *)
let json_obj pairs =
Jsont.Json.object' (List.map (fun (k, v) -> (Jsont.Json.name k, v)) pairs)
let json_mem name = function
| Jsont.Object (mems, _) -> (
match Jsont.Json.find_mem name mems with
| Some (_, v) -> v
| None -> Jsont.Null ((), Jsont.Meta.none))
| _ -> Jsont.Null ((), Jsont.Meta.none)
let json_string_or_null = function Jsont.String (s, _) -> Some s | _ -> None
let json_option_of f = function None -> Jsont.Json.null () | Some v -> f v
let special_of_json json =
let mem name = json_mem name json in
let to_bool = function Jsont.Bool (b, _) -> b | _ -> false in
let to_str = function
| Jsont.String (s, _) -> s
| _ -> failwith "expected string"
in
{
token = to_str (mem "content");
single_word = to_bool (mem "single_word");
lstrip = to_bool (mem "lstrip");
rstrip = to_bool (mem "rstrip");
normalized = to_bool (mem "normalized");
}
let added_token_to_json ~id (s : special) =
json_obj
[
("id", Jsont.Json.int id);
("content", Jsont.Json.string s.token);
("single_word", Jsont.Json.bool s.single_word);
("lstrip", Jsont.Json.bool s.lstrip);
("rstrip", Jsont.Json.bool s.rstrip);
("normalized", Jsont.Json.bool s.normalized);
("special", Jsont.Json.bool true);
]
let vocab_to_json vocab =
json_obj (List.map (fun (token, id) -> (token, Jsont.Json.int id)) vocab)
let alg_to_json = function
| Alg_bpe bpe ->
let vocab_json = vocab_to_json (Bpe.get_vocab bpe) in
let merges_json =
Bpe.get_merges bpe
|> List.map (fun (a, b) ->
Jsont.Json.list [ Jsont.Json.string a; Jsont.Json.string b ])
|> Jsont.Json.list
in
json_obj
[
("type", Jsont.Json.string "BPE");
("dropout", Jsont.Json.null ());
("unk_token", json_option_of Jsont.Json.string (Bpe.get_unk_token bpe));
( "continuing_subword_prefix",
json_option_of Jsont.Json.string
(Bpe.get_continuing_subword_prefix bpe) );
( "end_of_word_suffix",
json_option_of Jsont.Json.string (Bpe.get_end_of_word_suffix bpe) );
("fuse_unk", Jsont.Json.bool false);
("byte_fallback", Jsont.Json.bool false);
("ignore_merges", Jsont.Json.bool false);
("vocab", vocab_json);
("merges", merges_json);
]
| Alg_wordpiece wp ->
json_obj
[
("type", Jsont.Json.string "WordPiece");
("unk_token", Jsont.Json.string (Wordpiece.get_unk_token wp));
( "continuing_subword_prefix",
Jsont.Json.string (Wordpiece.get_continuing_subword_prefix wp) );
("max_input_chars_per_word", Jsont.Json.int 100);
("vocab", vocab_to_json (Wordpiece.get_vocab wp));
]
| Alg_wordlevel wl ->
json_obj
[
("type", Jsont.Json.string "WordLevel");
("unk_token", Jsont.Json.string "[UNK]");
("vocab", vocab_to_json (Word_level.get_vocab wl));
]
| Alg_unigram ug ->
let vocab_json =
Unigram.get_vocab ug
|> List.map (fun (token, score) ->
Jsont.Json.list [ Jsont.Json.string token; Jsont.Json.number score ])
|> Jsont.Json.list
in
json_obj
[
("type", Jsont.Json.string "Unigram");
("unk_id", Jsont.Json.null ());
("vocab", vocab_json);
]
| Alg_chars _ ->
json_obj [ ("type", Jsont.Json.string "Chars"); ("vocab", json_obj []) ]
let to_json (t : t) =
let vocab_list = alg_vocab t.algorithm in
let added_tokens =
t.specials
|> List.filter_map (fun spec ->
List.find_opt (fun (token, _) -> token = spec.token) vocab_list
|> Option.map (fun (_, id) -> added_token_to_json ~id spec))
in
json_obj
[
("version", Jsont.Json.string "1.0");
("truncation", Jsont.Json.null ());
("padding", Jsont.Json.null ());
("added_tokens", Jsont.Json.list added_tokens);
("normalizer", json_option_of Normalizer.to_json t.normalizer);
("pre_tokenizer", json_option_of Pre_tokenizer.to_json t.pre_tokenizer);
("post_processor", json_option_of Post_processor.to_json t.post_processor);
("decoder", json_option_of Decoder.to_json t.decoder);
("model", alg_to_json t.algorithm);
]
(* JSON deserialization helpers *)
let json_to_assoc = function
| Jsont.Object (mems, _) ->
List.map
(fun ((k, _), v) ->
match v with
| Jsont.Number (f, _) -> (k, int_of_float f)
| _ -> failwith ("Expected number for vocab entry: " ^ k))
mems
| _ -> failwith "Expected object for vocab"
let json_to_list = function
| Jsont.Array (l, _) -> l
| _ -> failwith "Expected array"
let json_to_string = function
| Jsont.String (s, _) -> s
| _ -> failwith "Expected string"
let json_to_float = function
| Jsont.Number (f, _) -> f
| _ -> failwith "Expected number"
let json_has_field name j =
match json_mem name j with Jsont.Null _ -> false | _ -> true
let json_result_to_option of_json = function
| Jsont.Null _ -> None
| j -> ( match of_json j with Ok v -> Some v | Error msg -> failwith msg)
let infer_model_type mj =
match json_string_or_null (json_mem "type" mj) with
| Some s -> s
| None ->
if json_has_field "merges" mj then "BPE"
else if json_has_field "unk_id" mj then "Unigram"
else if
json_has_field "continuing_subword_prefix" mj
|| json_has_field "max_input_chars_per_word" mj
then "WordPiece"
else if json_has_field "vocab" mj then "WordLevel"
else failwith err_infer_type
let parse_merge = function
| Jsont.Array ([ a; b ], _) -> (json_to_string a, json_to_string b)
| Jsont.String (s, _) -> (
match String.split_on_char ' ' s with
| [ a; b ] -> (a, b)
| _ -> failwith "Invalid merge string format")
| _ -> failwith "Invalid merge entry"
let alg_of_json mj =
let mem name = json_mem name mj in
let str name = json_string_or_null (mem name) in
match infer_model_type mj with
| "BPE" ->
let vocab_list = json_to_assoc (mem "vocab") in
let merges = json_to_list (mem "merges") |> List.map parse_merge in
Alg_bpe
(Bpe.create
~vocab:(vocab_to_hashtbl vocab_list)
~merges ?unk_token:(str "unk_token")
?continuing_subword_prefix:(str "continuing_subword_prefix")
?end_of_word_suffix:(str "end_of_word_suffix") ())
| "WordPiece" ->
let vocab_list = json_to_assoc (mem "vocab") in
let unk_token = str "unk_token" |> Option.value ~default:"[UNK]" in
let continuing_subword_prefix =
str "continuing_subword_prefix" |> Option.value ~default:"##"
in
let max_input_chars_per_word =
match mem "max_input_chars_per_word" with
| Jsont.Number (f, _) -> int_of_float f
| _ -> 100
in
Alg_wordpiece
(Wordpiece.create
~vocab:(vocab_to_hashtbl vocab_list)
~unk_token ~continuing_subword_prefix ~max_input_chars_per_word ())
| "WordLevel" ->
let vocab_list = json_to_assoc (mem "vocab") in
let unk_token = str "unk_token" |> Option.value ~default:"[UNK]" in
Alg_wordlevel (Word_level.create ~vocab:vocab_list ~unk_token ())
| "Unigram" ->
let vocab =
json_to_list (mem "vocab")
|> List.map (fun arr ->
match json_to_list arr with
| [ token; score ] -> (json_to_string token, json_to_float score)
| _ -> failwith "Invalid unigram vocab format")
in
Alg_unigram (Unigram.create vocab)
| "Chars" -> Alg_chars (Chars.create ())
| s -> failwith (strf "Unsupported model type: %s" s)
let from_json json =
try
let mem name = json_mem name json in
let normalizer =
json_result_to_option Normalizer.of_json (mem "normalizer")
in
let pre =
json_result_to_option Pre_tokenizer.of_json (mem "pre_tokenizer")
in
let post =
json_result_to_option Post_processor.of_json (mem "post_processor")
in
let decoder = json_result_to_option Decoder.of_json (mem "decoder") in
let algorithm = alg_of_json (mem "model") in
let added_tokens =
match mem "added_tokens" with
| Jsont.Array (l, _) -> List.map special_of_json l
| _ -> []
in
Ok (create ?normalizer ?pre ?post ?decoder ~specials:added_tokens algorithm)
with
| Failure msg -> Error msg
| exn -> Error (Printexc.to_string exn)
(* File I/O *)
let write_string_to_file path s =
let oc = open_out path in
Fun.protect ~finally:(fun () -> close_out oc) (fun () -> output_string oc s)
let from_file path =
try
let ic = open_in path in
let s =
Fun.protect
~finally:(fun () -> close_in ic)
(fun () -> really_input_string ic (in_channel_length ic))
in
match Jsont_bytesrw.decode_string Jsont.json s with
| Ok json -> from_json json
| Error e -> Error e
with
| Sys_error msg -> Error ("File error: " ^ msg)
| exn -> Error (Printexc.to_string exn)
let save_pretrained t ~path =
(try Sys.mkdir path 0o755 with Sys_error _ -> ());
let json_str =
match
Jsont_bytesrw.encode_string ~format:Jsont.Minify Jsont.json (to_json t)
with
| Ok s -> s
| Error e -> failwith ("save_pretrained: failed to encode JSON: " ^ e)
in
write_string_to_file (Filename.concat path "tokenizer.json") json_str
let export_tiktoken t ~merges_path ~vocab_path =
match t.algorithm with
| Alg_bpe bpe ->
let vocab =
alg_vocab t.algorithm
|> List.sort (fun (_, id1) (_, id2) -> Int.compare id1 id2)
in
let json_str =
match
Jsont_bytesrw.encode_string ~format:Jsont.Minify Jsont.json
(vocab_to_json vocab)
with
| Ok s -> s
| Error e -> failwith ("export_tiktoken: failed to encode vocab: " ^ e)
in
write_string_to_file vocab_path json_str;
let oc = open_out merges_path in
Fun.protect
~finally:(fun () -> close_out oc)
(fun () ->
output_string oc "#version: 0.2\n";
List.iter
(fun (a, b) -> Printf.fprintf oc "%s %s\n" a b)
(Bpe.get_merges bpe))
| _ -> invalid_arg err_export_tiktoken
let save_model_files t ~folder ?prefix () =
alg_save t.algorithm ~folder ?prefix ()
(* Formatting *)
let pp ppf t =
  let yes_no = function Some _ -> "yes" | None -> "no" in
  Format.fprintf ppf
    "@[<1>tokenizer(%s)@ vocab_size=%d@ normalizer=%s@ pre_tokenizer=%s@ \
     post_processor=%s@ decoder=%s@]"
    (alg_name t.algorithm)
    (alg_vocab_size t.algorithm)
    (yes_no t.normalizer) (yes_no t.pre_tokenizer) (yes_no t.post_processor)
    (yes_no t.decoder)
================================================
FILE: packages/brot/lib/brot.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Tokenization for OCaml.
Brot tokenizes text into token IDs for language models and reverses the
process. Tokenization proceeds through configurable stages:
+ {e Normalization}: clean and normalize text (lowercase, accent removal,
Unicode normalization). See {!Normalizer}.
+ {e Pre-tokenization}: split text into words or sub-words. See
{!Pre_tokenizer}.
+ {e Tokenization}: apply vocabulary-based encoding (BPE, WordPiece,
Unigram, word-level, or character-level).
+ {e Post-processing}: add special tokens and set type IDs. See
{!Post_processor}.
+ {e Padding/Truncation}: adjust sequence lengths for batching.
Each stage is optional and configurable. Open the module to use it, it
defines only modules in your scope.
{1:quick_start Quick start}
Load a pretrained tokenizer:
{[
let tokenizer = Brot.from_file "tokenizer.json" |> Result.get_ok in
let encoding = Brot.encode tokenizer "Hello world!" in
let _ids = Brot.Encoding.ids encoding
]}
Create a BPE tokenizer from scratch:
{[
let tokenizer =
Brot.bpe
~vocab:[("hello", 0); ("world", 1); ("[PAD]", 2)]
~merges:[]
()
in
let encoding = Brot.encode tokenizer "hello world" in
let _text = Brot.decode tokenizer (Brot.Encoding.ids encoding)
]}
Train a new tokenizer:
{[
let texts = [ "Hello world"; "How are you?"; "Hello again" ] in
let tokenizer =
Brot.train_bpe (`Seq (List.to_seq texts)) ~vocab_size:1000
in
Brot.save_pretrained tokenizer ~path:"./my_tokenizer"
]}
{!modules:Encoding Normalizer Pre_tokenizer Post_processor Decoder} *)
module Normalizer = Normalizer
(** Text normalization. *)
module Pre_tokenizer = Pre_tokenizer
(** Pre-tokenization. *)
module Post_processor = Post_processor
(** Post-processing. *)
module Decoder = Decoder
(** Token decoding. *)
module Encoding = Encoding
(** Tokenization encodings. *)
(** {1:types Types} *)
type t
(** The type for tokenizers. Immutable after creation. *)
type direction = [ `Left | `Right ]
(** The type for padding and truncation directions. [`Left] operates at the
beginning of the sequence, [`Right] at the end. *)
type special = {
token : string; (** The token text (e.g., ["<s>"], ["</s>"]). *)
single_word : bool; (** Whether this token must match whole words only. *)
lstrip : bool; (** Whether to strip whitespace on the left. *)
rstrip : bool; (** Whether to strip whitespace on the right. *)
normalized : bool; (** Whether to apply normalization to this token. *)
}
(** The type for special token configurations.
Special tokens are never split during tokenization and can be skipped during
decoding. Token IDs are assigned automatically when added to the vocabulary.
The semantic role (pad, unk, bos, etc.) is contextual, not encoded in the
type. *)
type pad_length = [ `Batch_longest | `Fixed of int | `To_multiple of int ]
(** The type for padding length strategies.
- [`Batch_longest]: pad to the longest sequence in the batch.
- [`Fixed n]: pad every sequence to exactly [n] tokens.
- [`To_multiple n]: pad to the smallest multiple of [n] that is at least the
sequence length. *)
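The three strategies can be sketched as a pure function from sequence length to padded length (illustrative only; [padded_len] is a hypothetical helper, not part of the API):

```ocaml
(* Illustrative sketch, not part of the API: how each strategy maps a
   sequence length to a padded length. [batch_lengths] is consulted
   only by [`Batch_longest]. *)
let padded_len strategy ~batch_lengths len =
  match strategy with
  | `Batch_longest -> List.fold_left max len batch_lengths
  | `Fixed n -> n
  | `To_multiple n ->
      (* smallest multiple of [n] that is at least [len] *)
      if len mod n = 0 then len else (len / n + 1) * n
```

For example, [`To_multiple 8] pads a 10-token sequence to 16.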
type padding = {
length : pad_length;
direction : direction;
pad_id : int option;
pad_type_id : int option;
pad_token : string option;
}
(** The type for padding configurations.
When [pad_id], [pad_type_id], or [pad_token] are [None], the tokenizer's
configured padding token is used. Raises [Invalid_argument] at padding time
if no padding token is configured and these fields are [None]. *)
type truncation = { max_length : int; direction : direction }
(** The type for truncation configurations. Sequences exceeding [max_length]
tokens are trimmed from the given [direction]. *)
type data = [ `Files of string list | `Seq of string Seq.t ]
(** The type for training data sources.
- [`Files paths]: read training text from files, one line per example.
- [`Seq seq]: use a sequence of strings. *)
val special :
?single_word:bool ->
?lstrip:bool ->
?rstrip:bool ->
?normalized:bool ->
string ->
special
(** [special token] is a special token configuration for [token].
[single_word] defaults to [false]. [lstrip] and [rstrip] default to [false].
[normalized] defaults to [false]. *)
val padding :
?direction:direction ->
?pad_id:int ->
?pad_type_id:int ->
?pad_token:string ->
pad_length ->
padding
(** [padding length] is a padding configuration for the given [length] strategy.
[direction] defaults to [`Right]. Other fields default to [None] (falls back
to the tokenizer's configured padding token). *)
val truncation : ?direction:direction -> int -> truncation
(** [truncation max_length] is a truncation configuration limiting sequences to
[max_length] tokens. [direction] defaults to [`Right]. *)
(** {1:constructors Constructors} *)
val bpe :
?normalizer:Normalizer.t ->
?pre:Pre_tokenizer.t ->
?post:Post_processor.t ->
?decoder:Decoder.t ->
?specials:special list ->
?bos_token:string ->
?eos_token:string ->
?pad_token:string ->
?unk_token:string ->
?vocab:(string * int) list ->
?merges:(string * string) list ->
?cache_capacity:int ->
?dropout:float ->
?continuing_subword_prefix:string ->
?end_of_word_suffix:string ->
?fuse_unk:bool ->
?byte_fallback:bool ->
?ignore_merges:bool ->
unit ->
t
(** [bpe ()] is a BPE (Byte Pair Encoding) tokenizer. Used by GPT-2, GPT-3,
RoBERTa.
- [normalizer]: text normalization. Default: none.
- [pre]: pre-tokenization strategy. Default: none.
- [post]: post-processor for special tokens. Default: none.
- [decoder]: decoding strategy. Default: none.
- [specials]: special tokens to add to vocabulary. Default: [[]].
- [bos_token], [eos_token], [pad_token]: role markers; added to vocabulary
if not already present. Default: none.
- [unk_token]: token for unknown characters. Configures both the role and
the BPE model's unknown handling. Default: none.
- [vocab]: initial vocabulary as [(token, id)] pairs. Default: [[]].
- [merges]: merge rules as [(left, right)] pairs learned during training.
Default: [[]].
- [cache_capacity]: LRU cache size for tokenization results. Default:
[10000].
- [dropout]: probability \[[0]; [1]\] of skipping merges (data
augmentation). Default: none (no dropout).
- [continuing_subword_prefix]: prefix for non-initial subwords (e.g.,
["##"]). Default: none.
- [end_of_word_suffix]: suffix marking word boundaries (e.g., ["</w>"]).
Default: none.
- [fuse_unk]: merge consecutive unknown tokens. Default: [false].
- [byte_fallback]: use byte-level fallback (["<0x00>"]) instead of unknown
token. Default: [false].
- [ignore_merges]: skip merge application (character-level output). Default:
[false]. *)
val wordpiece :
?normalizer:Normalizer.t ->
?pre:Pre_tokenizer.t ->
?post:Post_processor.t ->
?decoder:Decoder.t ->
?specials:special list ->
?bos_token:string ->
?eos_token:string ->
?pad_token:string ->
?unk_token:string ->
?vocab:(string * int) list ->
?continuing_subword_prefix:string ->
?max_input_chars_per_word:int ->
unit ->
t
(** [wordpiece ()] is a WordPiece tokenizer. Used by BERT, DistilBERT, Electra.
WordPiece uses a greedy longest-match-first algorithm to split words into
subword pieces prefixed with a continuation marker (e.g., ["running"]
becomes [["run"; "##ning"]]).
- [vocab]: initial vocabulary as [(token, id)] pairs. Default: [[]].
- [unk_token]: token for out-of-vocabulary words. Default: ["[UNK]"].
- [continuing_subword_prefix]: prefix for non-initial subwords. Default:
["##"].
- [max_input_chars_per_word]: words longer than this are replaced with
[unk_token]. Default: [100].
Pipeline parameters ([normalizer], [pre], [post], [decoder], [specials],
[bos_token], [eos_token], [pad_token]) are as in {!bpe}. *)
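The greedy longest-match-first split described above can be sketched in a few lines (a hypothetical [wordpiece_split], not the library's internals; [in_vocab] stands in for a vocabulary lookup):

```ocaml
(* Illustrative sketch of WordPiece's greedy longest-match-first split.
   Repeatedly take the longest prefix of the remaining word that is in
   the vocabulary, prefixing non-initial pieces with "##". Returns
   [None] when no piece matches (a real tokenizer would substitute
   [unk_token]). *)
let wordpiece_split ~in_vocab word =
  let len = String.length word in
  let rec pieces start acc =
    if start >= len then Some (List.rev acc)
    else
      (* try the longest candidate first, shrinking until a match *)
      let rec longest stop =
        if stop <= start then None
        else
          let piece = String.sub word start (stop - start) in
          let piece = if start > 0 then "##" ^ piece else piece in
          if in_vocab piece then Some (piece, stop) else longest (stop - 1)
      in
      match longest len with
      | None -> None
      | Some (piece, stop) -> pieces stop (piece :: acc)
  in
  pieces 0 []
```

With a vocabulary containing ["run"] and ["##ning"], ["running"] splits into [["run"; "##ning"]] as described above.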
val word_level :
?normalizer:Normalizer.t ->
?pre:Pre_tokenizer.t ->
?post:Post_processor.t ->
?decoder:Decoder.t ->
?specials:special list ->
?bos_token:string ->
?eos_token:string ->
?pad_token:string ->
?unk_token:string ->
?vocab:(string * int) list ->
unit ->
t
(** [word_level ()] is a word-level tokenizer.
Maps each word directly to a token ID. No subword splitting is performed.
Words not in vocabulary map to [unk_token].
{b Note.} When [pre] is not provided, {!Pre_tokenizer.whitespace} is used by
default.
- [vocab]: initial vocabulary as [(word, id)] pairs. Default: [[]].
- [unk_token]: token for out-of-vocabulary words. Default: ["<unk>"].
Pipeline parameters ([normalizer], [pre], [post], [decoder], [specials],
[bos_token], [eos_token], [pad_token]) are as in {!bpe}. *)
val unigram :
?normalizer:Normalizer.t ->
?pre:Pre_tokenizer.t ->
?post:Post_processor.t ->
?decoder:Decoder.t ->
?specials:special list ->
?bos_token:string ->
?eos_token:string ->
?pad_token:string ->
?unk_token:string ->
?vocab:(string * float) list ->
unit ->
t
(** [unigram ()] is a Unigram tokenizer. Used by ALBERT, T5, mBART.
Unigram uses probabilistic segmentation to find optimal subword splits based
on token log-probabilities.
- [vocab]: initial vocabulary as [(token, score)] pairs where scores are
negative log probabilities. Default: [[]].
- [unk_token]: token for unknown characters. Default: none.
Pipeline parameters ([normalizer], [pre], [post], [decoder], [specials],
[bos_token], [eos_token], [pad_token]) are as in {!bpe}. *)
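The probabilistic segmentation can be illustrated with a small Viterbi-style dynamic program (a sketch only, not the library's algorithm; [score] stands in for a vocabulary log-probability lookup):

```ocaml
(* Illustrative sketch: pick the segmentation of [text] maximizing the
   sum of token log-probabilities. [score piece] is [Some logprob] for
   in-vocabulary pieces and [None] otherwise. *)
let best_segmentation ~score text =
  let n = String.length text in
  (* best.(i) = Some (total score, reversed pieces) for text[0..i) *)
  let best = Array.make (n + 1) None in
  best.(0) <- Some (0., []);
  for stop = 1 to n do
    for start = 0 to stop - 1 do
      match best.(start) with
      | None -> ()
      | Some (s, pieces) -> (
          let piece = String.sub text start (stop - start) in
          match score piece with
          | None -> ()
          | Some p ->
              let cand = s +. p in
              let better =
                match best.(stop) with
                | None -> true
                | Some (s', _) -> cand > s'
              in
              if better then best.(stop) <- Some (cand, piece :: pieces))
    done
  done;
  Option.map (fun (_, pieces) -> List.rev pieces) best.(n)
```

Because scores are log probabilities, keeping ["hello"] whole at score [-1.0] beats splitting into ["hell"] and ["o"] at [-2.0 +. -1.0].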
val chars :
?normalizer:Normalizer.t ->
?pre:Pre_tokenizer.t ->
?post:Post_processor.t ->
?decoder:Decoder.t ->
?specials:special list ->
?bos_token:string ->
?eos_token:string ->
?pad_token:string ->
?unk_token:string ->
unit ->
t
(** [chars ()] is a character-level tokenizer.
Each byte in the input becomes a separate token with ID equal to its ordinal
value. No vocabulary is required.
Pipeline parameters ([normalizer], [pre], [post], [decoder], [specials],
[bos_token], [eos_token], [pad_token]) are as in {!bpe}. *)
val from_model_file :
vocab:string ->
?merges:string ->
?normalizer:Normalizer.t ->
?pre:Pre_tokenizer.t ->
?post:Post_processor.t ->
?decoder:Decoder.t ->
?specials:special list ->
?bos_token:string ->
?eos_token:string ->
?pad_token:string ->
?unk_token:string ->
unit ->
t
(** [from_model_file ~vocab ()] loads a tokenizer from HuggingFace model files.
The model type is inferred from the arguments: if [merges] is provided, a
BPE tokenizer is created; otherwise WordPiece.
- [vocab]: path to vocabulary file ([vocab.json]). Expected format: JSON
object mapping tokens to IDs ([{"hello": 0, "world": 1}]).
- [merges]: path to merges file ([merges.txt]). One merge per line as
space-separated token pairs. Lines starting with ["#version"] are skipped.
Raises [Sys_error] if a file cannot be read.
Pipeline parameters ([normalizer], [pre], [post], [decoder], [specials],
[bos_token], [eos_token], [pad_token], [unk_token]) are as in {!bpe}. *)
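The merges file format described above can be parsed with a few lines of plain OCaml (a hypothetical sketch, not the library's loader):

```ocaml
(* Illustrative sketch: parse merges.txt contents. One merge per line
   as a space-separated pair; "#version" lines and blanks are skipped. *)
let parse_merges contents =
  String.split_on_char '\n' contents
  |> List.filter_map (fun line ->
         let line = String.trim line in
         if line = "" || String.starts_with ~prefix:"#version" line then None
         else
           match String.index_opt line ' ' with
           | Some i ->
               Some
                 ( String.sub line 0 i,
                   String.sub line (i + 1) (String.length line - i - 1) )
           | None -> None)
```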
val add_tokens : t -> string list -> t
(** [add_tokens t tokens] is [t] with [tokens] added to the vocabulary. Only
supported for word-level tokenizers.
Raises [Invalid_argument] if the tokenizer does not support dynamic
vocabulary extension. *)
(** {1:accessors Accessors} *)
val normalizer : t -> Normalizer.t option
(** [normalizer t] is [t]'s normalizer, if any. *)
val pre_tokenizer : t -> Pre_tokenizer.t option
(** [pre_tokenizer t] is [t]'s pre-tokenizer, if any. *)
val post_processor : t -> Post_processor.t option
(** [post_processor t] is [t]'s post-processor, if any. *)
val decoder : t -> Decoder.t option
(** [decoder t] is [t]'s decoder, if any. *)
val specials : t -> special list
(** [specials t] is [t]'s special tokens. *)
val bos_token : t -> string option
(** [bos_token t] is [t]'s beginning-of-sequence token, if any. *)
val eos_token : t -> string option
(** [eos_token t] is [t]'s end-of-sequence token, if any. *)
val pad_token : t -> string option
(** [pad_token t] is [t]'s padding token, if any. *)
val unk_token : t -> string option
(** [unk_token t] is [t]'s unknown token, if any. *)
(** {1:vocab Vocabulary} *)
val vocab : t -> (string * int) list
(** [vocab t] is [t]'s vocabulary as [(token, id)] pairs. *)
val vocab_size : t -> int
(** [vocab_size t] is the number of tokens in [t]'s vocabulary. *)
val token_to_id : t -> string -> int option
(** [token_to_id t token] is the ID of [token] in [t], if any. *)
val id_to_token : t -> int -> string option
(** [id_to_token t id] is the token string for [id] in [t], if any. *)
(** {1:encoding Encoding and decoding} *)
val encode :
t ->
?pair:string ->
?add_special_tokens:bool ->
?padding:padding ->
?truncation:truncation ->
string ->
Encoding.t
(** [encode t text] is the encoding of [text] by [t].
- [pair]: a second sentence for sentence-pair tasks. The post-processor
merges both sequences with appropriate type IDs. Default: none.
- [add_special_tokens]: whether to insert special tokens via the
post-processor. Default: [true].
- [padding]: padding configuration. Default: none (no padding).
- [truncation]: truncation configuration. Default: none (no truncation). *)
val encode_batch :
t ->
?add_special_tokens:bool ->
?padding:padding ->
?truncation:truncation ->
string list ->
Encoding.t list
(** [encode_batch t texts] is the encoding of each text in [texts].
Optional parameters are as in {!encode}. For sentence-pair tasks, use
{!encode_pairs_batch}. *)
val encode_pairs_batch :
t ->
?add_special_tokens:bool ->
?padding:padding ->
?truncation:truncation ->
(string * string) list ->
Encoding.t list
(** [encode_pairs_batch t pairs] encodes a batch of sentence pairs. Each element
is [(primary, secondary)].
Optional parameters are as in {!encode}. *)
val encode_ids :
t ->
?pair:string ->
?add_special_tokens:bool ->
?padding:padding ->
?truncation:truncation ->
string ->
int array
(** [encode_ids t text] is [Encoding.ids (encode t text)].
Optional parameters are as in {!encode}. *)
val decode : t -> ?skip_special_tokens:bool -> int array -> string
(** [decode t ids] is the text obtained by decoding [ids] through [t]'s
vocabulary and decoder.
[skip_special_tokens] defaults to [false]. *)
val decode_batch :
t -> ?skip_special_tokens:bool -> int array list -> string list
(** [decode_batch t ids_list] decodes each element of [ids_list].
[skip_special_tokens] defaults to [false]. *)
(** {1:training Training} *)
val train_bpe :
?init:t ->
?normalizer:Normalizer.t ->
?pre:Pre_tokenizer.t ->
?post:Post_processor.t ->
?decoder:Decoder.t ->
?specials:special list ->
?bos_token:string ->
?eos_token:string ->
?pad_token:string ->
?unk_token:string ->
?vocab_size:int ->
?min_frequency:int ->
?limit_alphabet:int ->
?initial_alphabet:string list ->
?continuing_subword_prefix:string ->
?end_of_word_suffix:string ->
?show_progress:bool ->
?max_token_length:int ->
data ->
t
(** [train_bpe data] trains a BPE tokenizer from [data].
Learns merge rules by iteratively merging the most frequent adjacent pairs
until reaching the target vocabulary size.
- [init]: existing tokenizer to extend. Default: create new.
- [vocab_size]: target vocabulary size including special tokens. Default:
[30000].
- [min_frequency]: minimum pair frequency to be merged. Default: [0].
- [limit_alphabet]: maximum number of initial characters to keep. Default:
none (keep all).
- [initial_alphabet]: characters to include regardless of frequency.
Default: [[]].
- [continuing_subword_prefix]: prefix for non-initial subwords. Default:
none.
- [end_of_word_suffix]: suffix marking word boundaries. Default: none.
- [show_progress]: display progress bar. Default: [true].
- [max_token_length]: maximum token length. Default: none.
Pipeline parameters ([normalizer], [pre], [post], [decoder], [specials],
[bos_token], [eos_token], [pad_token], [unk_token]) are as in {!bpe}. *)
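One step of the merge loop described above, sketched in plain OCaml (illustrative, not the trainer itself): count adjacent symbol pairs weighted by word frequency and pick the most frequent.

```ocaml
(* Illustrative sketch: find the most frequent adjacent symbol pair.
   [words] is a list of (symbol list, word frequency); the result is
   the winning pair with its total count, or [None] for empty input. *)
let most_frequent_pair words =
  let tbl = Hashtbl.create 16 in
  List.iter
    (fun (symbols, freq) ->
      let rec count = function
        | a :: (b :: _ as rest) ->
            let n = Option.value (Hashtbl.find_opt tbl (a, b)) ~default:0 in
            Hashtbl.replace tbl (a, b) (n + freq);
            count rest
        | _ -> ()
      in
      count symbols)
    words;
  Hashtbl.fold
    (fun pair n best ->
      match best with
      | Some (_, m) when m >= n -> best
      | _ -> Some (pair, n))
    tbl None
```

A real trainer would merge the winning pair in every word and repeat until the target vocabulary size is reached.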
val train_wordpiece :
?init:t ->
?normalizer:Normalizer.t ->
?pre:Pre_tokenizer.t ->
?post:Post_processor.t ->
?decoder:Decoder.t ->
?specials:special list ->
?bos_token:string ->
?eos_token:string ->
?pad_token:string ->
?unk_token:string ->
?vocab_size:int ->
?min_frequency:int ->
?limit_alphabet:int ->
?initial_alphabet:string list ->
?continuing_subword_prefix:string ->
?end_of_word_suffix:string ->
?show_progress:bool ->
data ->
t
(** [train_wordpiece data] trains a WordPiece tokenizer from [data].
Learns subword vocabulary by maximizing language model likelihood.
- [init]: existing tokenizer to extend. Default: create new.
- [vocab_size]: target vocabulary size including special tokens. Default:
[30000].
- [min_frequency]: minimum frequency for a subword to be included. Default:
[0].
- [limit_alphabet]: maximum number of initial characters to keep. Default:
none (keep all).
- [initial_alphabet]: characters to include regardless of frequency.
Default: [[]].
- [continuing_subword_prefix]: prefix for non-initial subwords. Default:
["##"].
- [end_of_word_suffix]: suffix marking word boundaries. Default: none.
- [show_progress]: display progress bar. Default: [true].
Pipeline parameters ([normalizer], [pre], [post], [decoder], [specials],
[bos_token], [eos_token], [pad_token], [unk_token]) are as in {!bpe}. *)
val train_wordlevel :
?init:t ->
?normalizer:Normalizer.t ->
?pre:Pre_tokenizer.t ->
?post:Post_processor.t ->
?decoder:Decoder.t ->
?specials:special list ->
?bos_token:string ->
?eos_token:string ->
?pad_token:string ->
?unk_token:string ->
?vocab_size:int ->
?min_frequency:int ->
?show_progress:bool ->
data ->
t
(** [train_wordlevel data] trains a word-level tokenizer from [data].
Builds vocabulary by collecting unique words, optionally filtering by
frequency. No subword splitting.
- [init]: existing tokenizer to extend. Default: create new.
- [vocab_size]: target vocabulary size including special tokens. Default:
[30000].
- [min_frequency]: minimum frequency for a word to be included. Default:
[0].
- [show_progress]: display progress bar. Default: [true].
Pipeline parameters ([normalizer], [pre], [post], [decoder], [specials],
[bos_token], [eos_token], [pad_token], [unk_token]) are as in {!bpe}. *)
val train_unigram :
?init:t ->
?normalizer:Normalizer.t ->
?pre:Pre_tokenizer.t ->
?post:Post_processor.t ->
?decoder:Decoder.t ->
?specials:special list ->
?bos_token:string ->
?eos_token:string ->
?pad_token:string ->
?unk_token:string ->
?vocab_size:int ->
?show_progress:bool ->
?shrinking_factor:float ->
?max_piece_length:int ->
?n_sub_iterations:int ->
data ->
t
(** [train_unigram data] trains a Unigram tokenizer from [data].
Learns probabilistic subword vocabulary using EM algorithm.
- [init]: existing tokenizer to extend. Default: create new.
- [vocab_size]: target vocabulary size including special tokens. Default:
[8000].
- [show_progress]: display progress bar. Default: [true].
- [shrinking_factor]: fraction of vocabulary to retain in each pruning
iteration. Default: [0.75].
- [max_piece_length]: maximum subword length. Default: [16].
- [n_sub_iterations]: number of EM sub-iterations per pruning round.
Default: [2].
Pipeline parameters ([normalizer], [pre], [post], [decoder], [specials],
[bos_token], [eos_token], [pad_token], [unk_token]) are as in {!bpe}. *)
(** {1:model_files Model files} *)
val export_tiktoken : t -> merges_path:string -> vocab_path:string -> unit
(** [export_tiktoken t ~merges_path ~vocab_path] exports [t]'s BPE merges and
vocabulary in tiktoken-compatible format.
{b Warning.} Only BPE tokenizers are supported. Raises [Failure] for other
model types. *)
val save_model_files :
t -> folder:string -> ?prefix:string -> unit -> string list
(** [save_model_files t ~folder ?prefix ()] saves [t]'s underlying model files
(vocabulary and merges) to [folder] and returns the list of created file
paths.
[prefix] defaults to [""]. *)
(** {1:huggingface HuggingFace compatibility} *)
val from_file : string -> (t, string) result
(** [from_file path] is a tokenizer loaded from a HuggingFace [tokenizer.json]
file. Errors if the file cannot be read or has invalid format. *)
val from_json : Jsont.json -> (t, string) result
(** [from_json json] is a tokenizer deserialized from HuggingFace JSON format.
Errors if [json] has a missing or unknown model type, or invalid parameters.
*)
val to_json : t -> Jsont.json
(** [to_json t] is [t] serialized to HuggingFace JSON format. *)
val save_pretrained : t -> path:string -> unit
(** [save_pretrained t ~path] saves [t] to [path] in HuggingFace format. Creates
[path/tokenizer.json].
Raises [Sys_error] if [path] cannot be written. *)
(** {1:fmt Formatting} *)
val pp : Format.formatter -> t -> unit
(** [pp] formats a tokenizer for inspection. Shows algorithm type, vocabulary
size, and configured pipeline stages. *)
================================================
FILE: packages/brot/lib/chars.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
type t = unit
let create () = ()
let tokenize () text =
if String.length text = 0 then []
else
let chars = ref [] in
let offset = ref 0 in
String.iter
(fun c ->
let char_str = String.make 1 c in
let id = Char.code c in
chars := (id, char_str, (!offset, !offset + 1)) :: !chars;
incr offset)
text;
List.rev !chars
let token_to_id () token =
if String.length token = 1 then Some (Char.code token.[0]) else None
let id_to_token () id =
if id >= 0 && id <= 255 then Some (String.make 1 (Char.chr id)) else None
let get_vocab () = []
let get_vocab_size () = 256 (* all 256 byte values *)
let save () ~folder:_ () = []
================================================
FILE: packages/brot/lib/chars.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Character-level tokenization model.
{b Internal module.} Each byte maps to its ordinal value as token ID.
Stateless: no vocabulary storage, no training. *)
type t
(** The type for character-level models. *)
(** {1:creation Creation} *)
val create : unit -> t
(** [create ()] is a character-level tokenizer. *)
(** {1:tokenization Tokenization} *)
val tokenize : t -> string -> (int * string * (int * int)) list
(** [tokenize t s] is the tokenization of [s] as
[(byte_value, char_string, (start, stop))] triples, one per byte. *)
(** {1:vocabulary Vocabulary} *)
val token_to_id : t -> string -> int option
(** [token_to_id t s] is the byte value of [s] when [s] is a single byte. *)
val id_to_token : t -> int -> string option
(** [id_to_token t b] is the single-byte string for byte value [b]. *)
val get_vocab : t -> (string * int) list
(** [get_vocab t] is [[]] (no explicit vocabulary). *)
val get_vocab_size : t -> int
(** [get_vocab_size t] is [256] (all byte values). *)
(** {1:serialization Serialization} *)
val save : t -> folder:string -> unit -> string list
(** [save t ~folder ()] is [[]] (no files to write). *)
================================================
FILE: packages/brot/lib/decoder.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
type t =
| BPE of { suffix : string }
| Byte_level
| Byte_fallback
| Word_piece of { prefix : string; cleanup : bool }
| Metaspace of { replacement : char; add_prefix_space : bool }
| CTC of { pad_token : string; word_delimiter_token : string; cleanup : bool }
| Sequence of t list
| Replace of { pattern : string; replacement : string }
| Strip of { left : bool; right : bool; content : char }
| Fuse
(* Errors *)
let strf = Printf.sprintf
let err_replace_missing_pattern = "missing pattern in Replace decoder"
let err_seq_missing_decoders =
"invalid Sequence decoder: missing decoders array"
let err_unknown_type typ = strf "unknown decoder type: %s" typ
let err_expected_object = "invalid decoder JSON: expected object"
(* Decoding *)
let whitespace_re = Re.compile (Re.rep1 (Re.char ' '))
(* Literal string replacement without regex overhead. Returns [s] unchanged when
[pattern] does not occur—no allocation on the fast path. *)
let replace_all ~pattern ~by s =
let plen = String.length pattern in
let slen = String.length s in
if plen = 0 || plen > slen then s
else
let match_at i =
let rec check j =
j >= plen
|| String.unsafe_get s (i + j) = String.unsafe_get pattern j
&& check (j + 1)
in
check 0
in
let rec find_first i =
if i > slen - plen then -1
else if match_at i then i
else find_first (i + 1)
in
let pos = find_first 0 in
if pos < 0 then s
else
let buf = Buffer.create slen in
Buffer.add_substring buf s 0 pos;
Buffer.add_string buf by;
let i = ref (pos + plen) in
while !i <= slen - plen do
if match_at !i then (
Buffer.add_string buf by;
i := !i + plen)
else (
Buffer.add_char buf (String.unsafe_get s !i);
incr i)
done;
if !i < slen then Buffer.add_substring buf s !i (slen - !i);
Buffer.contents buf
(* HuggingFace semantics: the suffix marks the end of a word, so it is
   replaced by a space (by nothing on the final token) rather than
   stripped with a space inserted between every pair of tokens. *)
let decode_bpe ~suffix tokens =
  let last = List.length tokens - 1 in
  List.mapi
    (fun i token ->
      let by = if i = last then "" else " " in
      replace_all ~pattern:suffix ~by token)
    tokens
let decode_byte_level tokens =
let buf = Buffer.create 128 in
List.iter
(fun token -> Buffer.add_string buf (Pre_tokenizer.byte_level_decode token))
tokens;
Buffer.contents buf
let decode_byte_fallback tokens =
let flush acc = function
| [] -> acc
| byte_acc ->
let bytes = List.rev byte_acc in
let s = Bytes.create (List.length bytes) in
List.iteri (fun i b -> Bytes.unsafe_set s i (Char.chr b)) bytes;
Bytes.unsafe_to_string s :: acc
in
let is_byte_token token =
String.length token = 6
&& String.starts_with ~prefix:"<0x" token
&& String.ends_with ~suffix:">" token
in
let rec loop acc byte_acc = function
| [] -> List.rev (flush acc byte_acc)
| token :: rest when is_byte_token token -> (
let hex = String.sub token 3 2 in
match int_of_string_opt ("0x" ^ hex) with
| Some b when b >= 0 && b <= 255 -> loop acc (b :: byte_acc) rest
| _ -> loop (token :: flush acc byte_acc) [] rest)
| token :: rest -> loop (token :: flush acc byte_acc) [] rest
in
loop [] [] tokens
let decode_wordpiece ~prefix ~cleanup tokens =
let plen = String.length prefix in
let buf = Buffer.create 128 in
List.iteri
(fun i token ->
if i > 0 && String.starts_with ~prefix token then
Buffer.add_substring buf token plen (String.length token - plen)
else begin
if i > 0 then Buffer.add_char buf ' ';
Buffer.add_string buf token
end)
tokens;
let s = Buffer.contents buf in
if cleanup then String.trim (Re.replace_string whitespace_re ~by:" " s) else s
let decode_metaspace ~replacement ~add_prefix_space tokens =
List.mapi
(fun i token ->
let s = String.map (fun c -> if c = replacement then ' ' else c) token in
if add_prefix_space && i = 0 && String.length s > 0 && s.[0] = ' ' then
String.sub s 1 (String.length s - 1)
else s)
tokens
let decode_ctc ~pad_token ~word_delimiter_token ~cleanup tokens =
let rec dedup acc = function
| [] -> List.rev acc
| [ x ] -> List.rev (x :: acc)
| x :: (y :: _ as rest) ->
if String.equal x y then dedup acc rest else dedup (x :: acc) rest
in
let re =
if cleanup then Some (Re.compile (Re.str word_delimiter_token)) else None
in
dedup [] tokens
|> List.filter_map (fun token ->
if String.equal token pad_token then None
else
let s =
match re with
| Some re -> Re.replace_string re ~by:" " token
| None -> token
in
if String.equal s "" then None else Some s)
let decode_replace ~pattern ~replacement tokens =
[ replace_all ~pattern ~by:replacement (String.concat "" tokens) ]
let strip_token ~left ~right content token =
let len = String.length token in
let start =
if left then
let rec find i =
if i < len && Char.equal token.[i] content then find (i + 1) else i
in
find 0
else 0
in
let stop =
if right then
let rec find i =
if i >= 0 && Char.equal token.[i] content then find (i - 1) else i + 1
in
find (len - 1)
else len
in
if start < stop then String.sub token start (stop - start) else ""
let rec decode_chain decoder tokens =
match decoder with
| BPE { suffix } -> decode_bpe ~suffix tokens
| Byte_level -> [ decode_byte_level tokens ]
| Byte_fallback -> decode_byte_fallback tokens
| Word_piece { prefix; cleanup } ->
[ decode_wordpiece ~prefix ~cleanup tokens ]
| Metaspace { replacement; add_prefix_space } ->
decode_metaspace ~replacement ~add_prefix_space tokens
| CTC { pad_token; word_delimiter_token; cleanup } ->
decode_ctc ~pad_token ~word_delimiter_token ~cleanup tokens
| Replace { pattern; replacement } ->
decode_replace ~pattern ~replacement tokens
| Strip { left; right; content } ->
[ strip_token ~left ~right content (String.concat "" tokens) ]
| Fuse -> [ String.concat "" tokens ]
| Sequence decoders ->
List.fold_left (fun toks dec -> decode_chain dec toks) tokens decoders
let decode decoder tokens = String.concat "" (decode_chain decoder tokens)
(* Constructors *)
let bpe ?(suffix = "") () = BPE { suffix }
let byte_level () = Byte_level
let byte_fallback () = Byte_fallback
let wordpiece ?(prefix = "##") ?(cleanup = true) () =
Word_piece { prefix; cleanup }
let metaspace ?(replacement = '_') ?(add_prefix_space = true) () =
Metaspace { replacement; add_prefix_space }
let ctc ?(pad_token = "<pad>") ?(word_delimiter_token = "|") ?(cleanup = true)
() =
CTC { pad_token; word_delimiter_token; cleanup }
let sequence decoders = Sequence decoders
let replace ~pattern ~by () = Replace { pattern; replacement = by }
let strip ?(left = false) ?(right = false) ?(content = ' ') () =
Strip { left; right; content }
let fuse () = Fuse
(* Formatting *)
let rec pp ppf = function
| BPE { suffix } ->
if suffix <> "" then Format.fprintf ppf "bpe ~suffix:%S" suffix
else Format.fprintf ppf "bpe"
| Byte_level -> Format.fprintf ppf "byte_level"
| Byte_fallback -> Format.fprintf ppf "byte_fallback"
| Word_piece { prefix; cleanup } ->
Format.fprintf ppf "wordpiece ~prefix:%S ~cleanup:%b" prefix cleanup
| Metaspace { replacement; add_prefix_space } ->
Format.fprintf ppf "metaspace ~replacement:%C ~add_prefix_space:%b"
replacement add_prefix_space
| CTC { pad_token; word_delimiter_token; cleanup } ->
Format.fprintf ppf
"ctc ~pad_token:%S ~word_delimiter_token:%S ~cleanup:%b" pad_token
word_delimiter_token cleanup
| Replace { pattern; replacement } ->
Format.fprintf ppf "replace ~pattern:%S ~by:%S" pattern replacement
| Strip { left; right; content } ->
Format.fprintf ppf "strip ~left:%b ~right:%b ~content:%C" left right
content
| Fuse -> Format.fprintf ppf "fuse"
| Sequence decoders ->
Format.fprintf ppf "@[sequence [%a]@]"
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf ";@ ")
pp)
decoders
(* Serialization *)
let json_obj pairs =
Jsont.Json.object' (List.map (fun (k, v) -> (Jsont.Json.name k, v)) pairs)
let rec to_json = function
| BPE { suffix } ->
json_obj
[
("type", Jsont.Json.string "BPEDecoder");
("suffix", Jsont.Json.string suffix);
]
| Byte_level -> json_obj [ ("type", Jsont.Json.string "Byte_level") ]
| Byte_fallback -> json_obj [ ("type", Jsont.Json.string "Byte_fallback") ]
| Word_piece { prefix; cleanup } ->
json_obj
[
("type", Jsont.Json.string "Word_piece");
("prefix", Jsont.Json.string prefix);
("cleanup", Jsont.Json.bool cleanup);
]
| Metaspace { replacement; add_prefix_space } ->
json_obj
[
("type", Jsont.Json.string "Metaspace");
("replacement", Jsont.Json.string (String.make 1 replacement));
("add_prefix_space", Jsont.Json.bool add_prefix_space);
]
| CTC { pad_token; word_delimiter_token; cleanup } ->
json_obj
[
("type", Jsont.Json.string "CTC");
("pad_token", Jsont.Json.string pad_token);
("word_delimiter_token", Jsont.Json.string word_delimiter_token);
("cleanup", Jsont.Json.bool cleanup);
]
| Replace { pattern; replacement } ->
json_obj
[
("type", Jsont.Json.string "Replace");
("pattern", Jsont.Json.string pattern);
("content", Jsont.Json.string replacement);
]
| Strip { left; right; content } ->
json_obj
[
("type", Jsont.Json.string "Strip");
("strip_left", Jsont.Json.bool left);
("strip_right", Jsont.Json.bool right);
("content", Jsont.Json.string (String.make 1 content));
]
| Fuse -> json_obj [ ("type", Jsont.Json.string "Fuse") ]
| Sequence decoders ->
json_obj
[
("type", Jsont.Json.string "Sequence");
("decoders", Jsont.Json.list (List.map to_json decoders));
]
let find_field fields name = Option.map snd (Jsont.Json.find_mem name fields)
let string_field fields name ~default =
match find_field fields name with
| Some (Jsont.String (s, _)) -> s
| _ -> default
let bool_field fields name ~default =
match find_field fields name with
| Some (Jsont.Bool (b, _)) -> b
| _ -> default
let char_field fields name ~default =
match find_field fields name with
| Some (Jsont.String (s, _)) when String.length s > 0 -> s.[0]
| _ -> default
let rec of_json = function
| Jsont.Object (fields, _) -> (
let ( let* ) = Result.bind in
match find_field fields "type" with
| Some (Jsont.String ("BPEDecoder", _)) ->
Ok (BPE { suffix = string_field fields "suffix" ~default:"" })
| Some (Jsont.String (("Byte_level" | "ByteLevel"), _)) -> Ok Byte_level
| Some (Jsont.String (("Byte_fallback" | "ByteFallback"), _)) ->
Ok Byte_fallback
| Some (Jsont.String (("Word_piece" | "WordPiece"), _)) ->
Ok
(Word_piece
{
prefix = string_field fields "prefix" ~default:"##";
cleanup = bool_field fields "cleanup" ~default:true;
})
| Some (Jsont.String ("Metaspace", _)) ->
Ok
(Metaspace
{
replacement = char_field fields "replacement" ~default:'_';
add_prefix_space =
bool_field fields "add_prefix_space" ~default:true;
})
| Some (Jsont.String ("CTC", _)) ->
Ok
(CTC
{
pad_token = string_field fields "pad_token" ~default:"<pad>";
word_delimiter_token =
string_field fields "word_delimiter_token" ~default:"|";
cleanup = bool_field fields "cleanup" ~default:true;
})
| Some (Jsont.String ("Replace", _)) ->
let* pattern =
match find_field fields "pattern" with
| Some (Jsont.String (s, _)) -> Ok s
| Some (Jsont.Object (pattern_fields, _)) -> (
match Jsont.Json.find_mem "String" pattern_fields with
| Some (_, Jsont.String (p, _)) -> Ok p
| _ -> Error err_replace_missing_pattern)
| _ -> Error err_replace_missing_pattern
in
Ok
(Replace
{
pattern;
replacement = string_field fields "content" ~default:"";
})
| Some (Jsont.String ("Strip", _)) ->
Ok
(Strip
{
left = bool_field fields "strip_left" ~default:false;
right = bool_field fields "strip_right" ~default:false;
content = char_field fields "content" ~default:' ';
})
| Some (Jsont.String ("Fuse", _)) -> Ok Fuse
| Some (Jsont.String ("Sequence", _)) -> (
match find_field fields "decoders" with
| Some (Jsont.Array (decs, _)) ->
let* decoders =
List.fold_left
(fun acc j ->
let* acc = acc in
let* d = of_json j in
Ok (d :: acc))
(Ok []) decs
in
Ok (Sequence (List.rev decoders))
| _ -> Error err_seq_missing_decoders)
| Some (Jsont.String (typ, _)) -> Error (err_unknown_type typ)
| _ -> Error "missing or invalid decoder type field")
| _ -> Error err_expected_object
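(* Usage sketch (illustrative, not part of the API): [to_json] and [of_json]
   above are intended to round-trip through the HuggingFace JSON shape, e.g.

   {[
     let d = Sequence [ Byte_fallback; Fuse; BPE { suffix = "</w>" } ] in
     match of_json (to_json d) with
     | Ok _ -> ()
     | Error msg -> failwith msg
   ]} *)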
================================================
FILE: packages/brot/lib/decoder.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Decoding tokens back to text.
Decoders convert token strings back into natural text by reversing
encoding-specific transformations (prefix/suffix removal, byte-level
decoding, whitespace normalization, etc.).
Decoders operate on token {e strings}, not IDs. Convert IDs to strings via
vocabulary first, then apply {!decode}.
Some decoders transform each token independently ({e per-token}: {!bpe},
{!metaspace}, {!byte_fallback}), while others collapse the entire token
list into a single result ({e collapsing}: {!byte_level}, {!wordpiece},
{!replace}, {!strip}, {!fuse}). This distinction matters when composing
decoders with {!sequence}. *)
type t
(** The type for decoders. *)
(** {1:constructors Constructors} *)
val bpe : ?suffix:string -> unit -> t
(** [bpe ~suffix ()] is a per-token decoder for BPE-encoded tokens. Strips
[suffix] from end-of-word tokens and inserts spaces between words. [suffix]
defaults to [""]. *)
val byte_level : unit -> t
(** [byte_level ()] is a collapsing decoder that reverses GPT-2 style
byte-to-Unicode encoding back to original bytes. *)
val byte_fallback : unit -> t
(** [byte_fallback ()] is a per-token decoder for byte fallback tokens. Converts
hex byte tokens (e.g. ["<0x41>"]) back to their byte values, accumulating
consecutive byte tokens into strings. Non-byte tokens pass through
unchanged. *)
val wordpiece : ?prefix:string -> ?cleanup:bool -> unit -> t
(** [wordpiece ~prefix ~cleanup ()] is a collapsing decoder for WordPiece
tokens. Strips continuation [prefix] (default ["##"]) from non-initial
subwords and joins tokens into words. When [cleanup] is [true] (default),
normalizes whitespace in the result. *)
val metaspace : ?replacement:char -> ?add_prefix_space:bool -> unit -> t
(** [metaspace ~replacement ~add_prefix_space ()] is a per-token decoder that
converts metaspace markers back to regular spaces. [replacement] defaults to
['_']. When [add_prefix_space] is [true] (default), the leading replacement
character on the first token is stripped. *)
val ctc :
?pad_token:string ->
?word_delimiter_token:string ->
?cleanup:bool ->
unit ->
t
(** [ctc ~pad_token ~word_delimiter_token ~cleanup ()] is a decoder for
{{:https://distill.pub/2017/ctc/}CTC (Connectionist Temporal
Classification)} output. Deduplicates consecutive tokens, removes
[pad_token] (default [""]), and when [cleanup] is [true] (default),
replaces [word_delimiter_token] (default ["|"]) with spaces. *)
val sequence : t list -> t
(** [sequence decoders] chains [decoders] left-to-right. Each decoder's output
token list feeds into the next. *)
val replace : pattern:string -> by:string -> unit -> t
(** [replace ~pattern ~by ()] is a collapsing decoder that joins the token list,
replaces all literal occurrences of [pattern] with [by] in the result, and
returns a single-element list. *)
val strip : ?left:bool -> ?right:bool -> ?content:char -> unit -> t
(** [strip ~left ~right ~content ()] is a collapsing decoder that joins the
token list and removes leading (when [left] is [true]) and/or trailing (when
[right] is [true]) occurrences of [content] from the result. [left] and
[right] default to [false]; [content] defaults to [' ']. *)
val fuse : unit -> t
(** [fuse ()] is a collapsing decoder that concatenates all tokens into a single
string with no delimiter. *)
(** {1:ops Operations} *)
val decode : t -> string list -> string
(** [decode decoder tokens] applies [decoder] to [tokens] and returns the
decoded text. *)
(** {1:fmt Formatting} *)
val pp : Format.formatter -> t -> unit
(** [pp ppf decoder] formats [decoder] for debugging. *)
(** {1:serialization Serialization} *)
val to_json : t -> Jsont.json
(** [to_json decoder] serializes [decoder] to HuggingFace JSON format. *)
val of_json : Jsont.json -> (t, string) result
(** [of_json json] is a decoder from HuggingFace JSON format. Errors if [json]
is not an object, has a missing or unknown ["type"] field, or has invalid
parameters. *)
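(** Usage sketch (the token strings below are illustrative; in a real pipeline
    they come from a vocabulary lookup on token IDs):
    {[
      let d = wordpiece ~prefix:"##" () in
      print_endline (decode d [ "play"; "##ing"; "outside" ])
    ]} *)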
================================================
FILE: packages/brot/lib/dune
================================================
(library
(name brot)
(public_name brot)
(private_modules bpe wordpiece word_level unigram chars)
(libraries re jsont jsont.bytesrw uucp uunf))
================================================
FILE: packages/brot/lib/encoding.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
type t = {
ids : int array;
type_ids : int array;
tokens : string array;
words : int option array;
offsets : (int * int) array;
special_tokens_mask : int array;
attention_mask : int array;
mutable overflowing : t list;
sequence_ranges : (int, int * int) Hashtbl.t;
}
(* Constructors *)
(* Shared empty table: treated as immutable wherever it is reused, so sharing
   one instance across encodings is safe. *)
let empty_ranges : (int, int * int) Hashtbl.t = Hashtbl.create 0
let empty =
{
ids = [||];
type_ids = [||];
tokens = [||];
words = [||];
offsets = [||];
special_tokens_mask = [||];
attention_mask = [||];
overflowing = [];
sequence_ranges = empty_ranges;
}
let create ~ids ~type_ids ~tokens ~words ~offsets ~special_tokens_mask
~attention_mask ?(overflowing = []) () =
{
ids;
type_ids;
tokens;
words;
offsets;
special_tokens_mask;
attention_mask;
overflowing;
sequence_ranges = empty_ranges;
}
let token ~id ~token ~offset ~type_id ~special =
{
ids = [| id |];
type_ids = [| type_id |];
tokens = [| token |];
words = [| None |];
offsets = [| offset |];
special_tokens_mask = [| (if special then 1 else 0) |];
attention_mask = [| 1 |];
overflowing = [];
sequence_ranges = empty_ranges;
}
let from_tokens tokens ~type_id =
let n = List.length tokens in
let ids = Array.make n 0 in
let token_strs = Array.make n "" in
let offsets = Array.make n (0, 0) in
List.iteri
(fun i (id, tok, off) ->
ids.(i) <- id;
token_strs.(i) <- tok;
offsets.(i) <- off)
tokens;
{
ids;
tokens = token_strs;
offsets;
words = Array.make n None;
type_ids = Array.make n type_id;
attention_mask = Array.make n 1;
special_tokens_mask = Array.make n 0;
overflowing = [];
sequence_ranges = empty_ranges;
}
let concat a b =
{
ids = Array.append a.ids b.ids;
type_ids = Array.append a.type_ids b.type_ids;
tokens = Array.append a.tokens b.tokens;
words = Array.append a.words b.words;
offsets = Array.append a.offsets b.offsets;
special_tokens_mask =
Array.append a.special_tokens_mask b.special_tokens_mask;
attention_mask = Array.append a.attention_mask b.attention_mask;
overflowing = a.overflowing;
sequence_ranges = a.sequence_ranges;
}
let concat_list encodings =
match encodings with
| [] -> empty
| [ single ] -> single
| first :: _ ->
let total =
List.fold_left (fun acc t -> acc + Array.length t.ids) 0 encodings
in
let ids = Array.make total 0 in
let type_ids = Array.make total 0 in
let tokens = Array.make total "" in
let words = Array.make total None in
let offsets = Array.make total (0, 0) in
let special_tokens_mask = Array.make total 0 in
let attention_mask = Array.make total 0 in
let pos = ref 0 in
List.iter
(fun t ->
let n = Array.length t.ids in
Array.blit t.ids 0 ids !pos n;
Array.blit t.type_ids 0 type_ids !pos n;
Array.blit t.tokens 0 tokens !pos n;
Array.blit t.words 0 words !pos n;
Array.blit t.offsets 0 offsets !pos n;
Array.blit t.special_tokens_mask 0 special_tokens_mask !pos n;
Array.blit t.attention_mask 0 attention_mask !pos n;
pos := !pos + n)
encodings;
{
ids;
type_ids;
tokens;
words;
offsets;
special_tokens_mask;
attention_mask;
overflowing = first.overflowing;
sequence_ranges = first.sequence_ranges;
}
(* Accessors *)
let is_empty t = Array.length t.ids = 0
let length t = Array.length t.ids
let ids t = t.ids
let type_ids t = t.type_ids
let tokens t = t.tokens
let word_ids t = t.words
let offsets t = t.offsets
let special_tokens_mask t = t.special_tokens_mask
let attention_mask t = t.attention_mask
let overflowing t = t.overflowing
(* Truncation *)
let slice t start len =
{
ids = Array.sub t.ids start len;
type_ids = Array.sub t.type_ids start len;
tokens = Array.sub t.tokens start len;
words = Array.sub t.words start len;
offsets = Array.sub t.offsets start len;
special_tokens_mask = Array.sub t.special_tokens_mask start len;
attention_mask = Array.sub t.attention_mask start len;
overflowing = [];
sequence_ranges = empty_ranges;
}
let truncate t ~max_length ~stride ~direction =
let encoding_len = length t in
if max_length >= encoding_len then t
else if max_length = 0 then { empty with overflowing = [ t ] }
else begin
assert (stride < max_length);
let step = max_length - stride in
let ranges =
match direction with
| `Right ->
let rec loop start acc =
if start >= encoding_len then List.rev acc
else
let stop = min (start + max_length) encoding_len in
loop (start + step) ((start, stop) :: acc)
in
loop 0 []
| `Left ->
let rec loop stop acc =
if stop <= 0 then acc
else
let start = max 0 (stop - max_length) in
loop (stop - step) ((start, stop) :: acc)
in
loop encoding_len []
in
match ranges with
| [] -> empty
| (start, stop) :: rest ->
let enc = slice t start (stop - start) in
enc.overflowing <-
List.map (fun (start, stop) -> slice t start (stop - start)) rest;
enc
end
(* Pad *)
let pad_array src n fill direction =
let src_len = Array.length src in
let dst = Array.make (src_len + n) fill in
let off = match direction with `Left -> n | `Right -> 0 in
Array.blit src 0 dst off src_len;
dst
let rec pad t ~target_length ~pad_id ~pad_type_id ~pad_token ~direction =
let overflowing =
List.map
(fun e -> pad e ~target_length ~pad_id ~pad_type_id ~pad_token ~direction)
t.overflowing
in
let current_len = length t in
if current_len >= target_length then { t with overflowing }
else
let n = target_length - current_len in
let pad_a arr fill = pad_array arr n fill direction in
let sequence_ranges =
match direction with
| `Right -> t.sequence_ranges
| `Left ->
if Hashtbl.length t.sequence_ranges = 0 then empty_ranges
else begin
let tbl = Hashtbl.create (Hashtbl.length t.sequence_ranges) in
Hashtbl.iter
(fun seq_id (start, stop) ->
Hashtbl.add tbl seq_id (start + n, stop + n))
t.sequence_ranges;
tbl
end
in
{
ids = pad_a t.ids pad_id;
type_ids = pad_a t.type_ids pad_type_id;
tokens = pad_a t.tokens pad_token;
words = pad_a t.words None;
offsets = pad_a t.offsets (0, 0);
special_tokens_mask = pad_a t.special_tokens_mask 1;
attention_mask = pad_a t.attention_mask 0;
overflowing;
sequence_ranges;
}
================================================
FILE: packages/brot/lib/encoding.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Tokenization encodings.
An encoding bundles token IDs for model input with alignment metadata: byte
offsets, word indices, segment type IDs, attention masks, and special-token
flags.
Encodings are produced by {!Brot.encode} and post-processed with
{!val-truncate} and {!val-pad}. All parallel arrays ({!val-ids},
{!val-type_ids}, {!val-tokens}, {!val-word_ids}, {!val-offsets},
{!val-special_tokens_mask}, {!val-attention_mask}) share the same length,
equal to {!val-length}. *)
type t
(** The type for tokenization encodings. *)
(** {1:construct Construction} *)
val empty : t
(** [empty] is the encoding with no tokens. *)
val create :
ids:int array ->
type_ids:int array ->
tokens:string array ->
words:int option array ->
offsets:(int * int) array ->
special_tokens_mask:int array ->
attention_mask:int array ->
?overflowing:t list ->
unit ->
t
(** [create ~ids ~type_ids ~tokens ~words ~offsets ~special_tokens_mask
~attention_mask ()] is an encoding from the given arrays.
All arrays must have the same length; no validation is performed.
[overflowing] defaults to [[]]. *)
val token :
id:int -> token:string -> offset:int * int -> type_id:int -> special:bool -> t
(** [token ~id ~token ~offset ~type_id ~special] is a single-token encoding.
    {!val-special_tokens_mask} is [1] when [special] is [true] and [0]
    otherwise. {!val-word_ids} is always [None] and {!val-attention_mask} is
    always [1]. *)
val from_tokens : (int * string * (int * int)) list -> type_id:int -> t
(** [from_tokens tokens ~type_id] is an encoding from a list of
[(id, token_string, (start, end_offset))] triples. Every token gets the
given [type_id], {!val-attention_mask} [1], {!val-special_tokens_mask} [0]
and {!val-word_ids} [None]. *)
val concat : t -> t -> t
(** [concat a b] is the encoding with [a]'s tokens followed by [b]'s.
{!val-overflowing} and sequence ranges are taken from [a]. *)
val concat_list : t list -> t
(** [concat_list encs] is the concatenation of [encs] in order.
{!val-overflowing} and sequence ranges are taken from the first element.
Allocates once rather than creating intermediate arrays per pair. *)
(** {1:access Accessors} *)
val ids : t -> int array
(** [ids enc] is the token ID array. *)
val type_ids : t -> int array
(** [type_ids enc] is the segment ID array. Typically [0] for the first sequence
and [1] for the second in sentence-pair tasks. *)
val tokens : t -> string array
(** [tokens enc] is the string representation of each token. *)
val word_ids : t -> int option array
(** [word_ids enc] maps each token to its source word index, or [None] for
special tokens. *)
val offsets : t -> (int * int) array
(** [offsets enc] is the [(start, end_)] byte offset spans into the original
text for each token. *)
val special_tokens_mask : t -> int array
(** [special_tokens_mask enc] is [1] for special tokens ([CLS], [SEP], padding)
and [0] for content tokens. *)
val attention_mask : t -> int array
(** [attention_mask enc] is [1] for real tokens and [0] for padding tokens. *)
val overflowing : t -> t list
(** [overflowing enc] is the list of overflow encodings produced by
{!val-truncate} when the input exceeds [max_length]. Each element is a
sliding window over the excess tokens. *)
val is_empty : t -> bool
(** [is_empty enc] is [true] iff [enc] has no tokens. *)
val length : t -> int
(** [length enc] is the number of tokens in [enc]. *)
(** {1:ops Operations} *)
val truncate :
t -> max_length:int -> stride:int -> direction:[ `Left | `Right ] -> t
(** [truncate enc ~max_length ~stride ~direction] limits [enc] to at most
[max_length] tokens.
Excess tokens are split into sliding windows of size [max_length] with
overlap [stride] and stored in {!val-overflowing}. If
[length enc <= max_length], [enc] is returned unchanged.
[stride] must be strictly less than [max_length]. When [max_length] is [0],
all tokens move to {!val-overflowing}: the result has no tokens and its
{!val-overflowing} is [[enc]]. *)
val pad :
t ->
target_length:int ->
pad_id:int ->
pad_type_id:int ->
pad_token:string ->
direction:[ `Left | `Right ] ->
t
(** [pad enc ~target_length ~pad_id ~pad_type_id ~pad_token ~direction] extends
[enc] to exactly [target_length] tokens.
Padding tokens have {!val-attention_mask} [0] and {!val-special_tokens_mask}
[1]. If [length enc >= target_length], [enc] is returned unchanged. Padding
is applied recursively to {!val-overflowing} encodings. When [direction] is
[`Left], sequence ranges are shifted to account for the prepended padding;
padding positions get {!val-offsets} of [(0, 0)]. *)
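(** Truncation sketch: with ten tokens, [max_length = 4] and [stride = 2],
    windows advance by [max_length - stride = 2] tokens, so the main encoding
    keeps tokens 0-3 and {!val-overflowing} holds the windows 2-5, 4-7, 6-9
    and 8-9:
    {[
      let enc =
        from_tokens
          (List.init 10 (fun i -> (i, Printf.sprintf "t%d" i, (i, i + 1))))
          ~type_id:0
      in
      let enc = truncate enc ~max_length:4 ~stride:2 ~direction:`Right in
      assert (length enc = 4);
      assert (List.length (overflowing enc) = 4)
    ]} *)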
================================================
FILE: packages/brot/lib/normalizer.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Errors *)
let err_expected_object = "expected JSON object"
let err_missing_type = "missing type field"
let err_replace_invalid_pattern = "invalid pattern"
let err_replace_missing_pattern = "missing pattern"
let err_replace_missing_content = "missing content"
let err_prepend_missing = "missing prepend field"
let err_sequence_missing = "missing normalizers"
let strf = Printf.sprintf
(* Type *)
type t =
| Bert of {
clean_text : bool;
handle_chinese_chars : bool;
strip_accents : bool option;
lowercase : bool;
}
| Strip of { left : bool; right : bool }
| Strip_accents
| NFC
| NFD
| NFKC
| NFKD
| Lowercase
| Replace of { pattern : string; replacement : string; compiled : Re.re }
| Prepend of string
| Byte_level of { add_prefix_space : bool; use_regex : bool }
| Sequence of t list
(* Unicode text transforms *)
let normalize_utf8 nf text =
let len = String.length text in
if len = 0 then text
else
let rec all_ascii i =
i >= len
|| (Char.code (String.unsafe_get text i) < 0x80 && all_ascii (i + 1))
in
if all_ascii 0 then text else Uunf_string.normalize_utf_8 nf text
let case_fold text =
let len = String.length text in
let rec needs_fold i =
if i >= len then false
else
let byte = Char.code (String.unsafe_get text i) in
if byte >= 0x41 && byte <= 0x5A then true
else if byte >= 128 then true
else needs_fold (i + 1)
in
if not (needs_fold 0) then text
else
let b = Buffer.create len in
let i = ref 0 in
while !i < len do
let byte = Char.code (String.unsafe_get text !i) in
if byte < 128 then (
let c = if byte >= 0x41 && byte <= 0x5A then byte + 32 else byte in
Buffer.add_char b (Char.unsafe_chr c);
incr i)
else
let d = String.get_utf_8_uchar text !i in
let n = Uchar.utf_decode_length d in
(if Uchar.utf_decode_is_valid d then
let u = Uchar.utf_decode_uchar d in
match Uucp.Case.Fold.fold u with
| `Self -> Buffer.add_utf_8_uchar b u
| `Uchars us -> List.iter (fun u -> Buffer.add_utf_8_uchar b u) us);
i := !i + n
done;
Buffer.contents b
let strip_accents_text text =
let len = String.length text in
let rec has_non_ascii i =
if i >= len then false
else if Char.code (String.unsafe_get text i) >= 128 then true
else has_non_ascii (i + 1)
in
if not (has_non_ascii 0) then text
else
let b = Buffer.create len in
let i = ref 0 in
while !i < len do
let byte = Char.code (String.unsafe_get text !i) in
if byte < 128 then (
Buffer.add_char b (Char.unsafe_chr byte);
incr i)
else
let d = String.get_utf_8_uchar text !i in
let n = Uchar.utf_decode_length d in
(if Uchar.utf_decode_is_valid d then
let u = Uchar.utf_decode_uchar d in
match Uucp.Gc.general_category u with
| `Mn | `Mc | `Me -> ()
| _ -> Buffer.add_utf_8_uchar b u);
i := !i + n
done;
Buffer.contents b
(* UTF-8 helpers *)
(* Returns (codepoint lsl 3) lor byte_length — zero allocation. *)
let[@inline] utf8_next s i =
let d = String.get_utf_8_uchar s i in
(Uchar.to_int (Uchar.utf_decode_uchar d) lsl 3) lor Uchar.utf_decode_length d
(* Character classification *)
let[@inline] is_whitespace code =
code = 0x09 || code = 0x0A || code = 0x0D || code = 0x20
|| Uucp.White.is_white_space (Uchar.of_int code)
let[@inline] is_control code =
if code = 0x09 || code = 0x0A || code = 0x0D then false
else
match Uucp.Gc.general_category (Uchar.of_int code) with
| `Cc | `Cf | `Cn | `Co -> true
| _ -> false
let[@inline] is_chinese_char code =
(code >= 0x4E00 && code <= 0x9FFF)
|| (code >= 0x3400 && code <= 0x4DBF)
|| (code >= 0x20000 && code <= 0x2A6DF)
|| (code >= 0x2A700 && code <= 0x2B73F)
|| (code >= 0x2B740 && code <= 0x2B81F)
|| (code >= 0x2B820 && code <= 0x2CEAF)
|| (code >= 0xF900 && code <= 0xFAFF)
|| (code >= 0x2F800 && code <= 0x2FA1F)
(* Operations *)
let clean_text s =
let len = String.length s in
let buf = Buffer.create len in
let i = ref 0 in
while !i < len do
let b0 = Char.code (String.unsafe_get s !i) in
if b0 < 128 then begin
if b0 = 9 || b0 = 10 || b0 = 13 || b0 = 32 then Buffer.add_char buf ' '
else if b0 >= 33 && b0 < 127 then Buffer.add_char buf (Char.unsafe_chr b0);
incr i
end
else begin
let p = utf8_next s !i in
let code = p lsr 3 and clen = p land 7 in
if code <> 0xFFFD && not (is_control code) then
if is_whitespace code then Buffer.add_char buf ' '
else Buffer.add_substring buf s !i clen;
i := !i + clen
end
done;
Buffer.contents buf
let handle_chinese_chars s =
let len = String.length s in
let rec has_non_ascii i =
i < len
&& (Char.code (String.unsafe_get s i) >= 128 || has_non_ascii (i + 1))
in
if not (has_non_ascii 0) then s
else
let buf = Buffer.create (len + (len / 4)) in
let i = ref 0 in
while !i < len do
let b0 = Char.code (String.unsafe_get s !i) in
if b0 < 128 then begin
Buffer.add_char buf (Char.unsafe_chr b0);
incr i
end
else begin
let p = utf8_next s !i in
let code = p lsr 3 and clen = p land 7 in
if is_chinese_char code then (
Buffer.add_char buf ' ';
Buffer.add_substring buf s !i clen;
Buffer.add_char buf ' ')
else Buffer.add_substring buf s !i clen;
i := !i + clen
end
done;
Buffer.contents buf
let do_strip_accents s = strip_accents_text (normalize_utf8 `NFD s)
let do_lowercase s = case_fold s
let strip_whitespace s ~left ~right =
let len = String.length s in
let start =
if left then
let rec loop i =
if i >= len then len
else
let p = utf8_next s i in
let code = p lsr 3 and clen = p land 7 in
if is_whitespace code then loop (i + clen) else i
in
loop 0
else 0
in
let stop =
if right then
let rec loop i last =
if i >= len then last
else
let p = utf8_next s i in
let code = p lsr 3 and clen = p land 7 in
let next = i + clen in
if is_whitespace code then loop next last else loop next next
in
loop start start
else len
in
if start = 0 && stop = len then s else String.sub s start (stop - start)
(* Byte-level encoding *)
let byte_to_unicode =
let is_direct b =
(b >= 33 && b <= 126) || (b >= 161 && b <= 172) || b >= 174
in
let tbl = Array.make 256 0 in
let n = ref 0 in
for b = 0 to 255 do
if is_direct b then tbl.(b) <- b
else (
tbl.(b) <- 256 + !n;
incr n)
done;
tbl
let apply_byte_level s ~add_prefix_space ~use_regex:_ =
let s =
if add_prefix_space && String.length s > 0 then
let code = utf8_next s 0 lsr 3 in
if is_whitespace code then s else " " ^ s
else s
in
let len = String.length s in
let buf = Buffer.create (len * 2) in
for i = 0 to len - 1 do
let b = Char.code (String.unsafe_get s i) in
Buffer.add_utf_8_uchar buf (Uchar.of_int byte_to_unicode.(b))
done;
Buffer.contents buf
(* Constructors *)
let nfc = NFC
let nfd = NFD
let nfkc = NFKC
let nfkd = NFKD
let lowercase = Lowercase
let strip_accents = Strip_accents
let strip ?(left = true) ?(right = true) () = Strip { left; right }
let replace ~pattern ~replacement =
Replace { pattern; replacement; compiled = Re.compile (Re.Pcre.re pattern) }
let prepend s = Prepend s
let byte_level ?(add_prefix_space = false) () =
Byte_level { add_prefix_space; use_regex = false }
let bert ?(clean_text = true) ?(handle_chinese_chars = true)
?(strip_accents = None) ?(lowercase = true) () =
Bert { clean_text; handle_chinese_chars; strip_accents; lowercase }
let sequence ns = Sequence ns
(* Apply *)
let rec apply t s =
match t with
| NFC -> normalize_utf8 `NFC s
| NFD -> normalize_utf8 `NFD s
| NFKC -> normalize_utf8 `NFKC s
| NFKD -> normalize_utf8 `NFKD s
| Lowercase -> do_lowercase s
| Strip_accents -> do_strip_accents s
| Strip { left; right } -> strip_whitespace s ~left ~right
| Replace { compiled; replacement; _ } ->
Re.replace_string compiled ~by:replacement s
| Prepend prefix -> if String.length s = 0 then s else prefix ^ s
| Byte_level { add_prefix_space; use_regex } ->
apply_byte_level s ~add_prefix_space ~use_regex
| Bert
{
clean_text = ct;
handle_chinese_chars = hcc;
strip_accents = sa;
lowercase = lc;
} ->
let s = if ct then clean_text s else s in
let s = if hcc then handle_chinese_chars s else s in
let do_strip = match sa with Some v -> v | None -> lc in
let s = if do_strip then do_strip_accents s else s in
if lc then do_lowercase s else s
| Sequence ns -> List.fold_left (fun s n -> apply n s) s ns
(* Formatting *)
let pp_bool_opt ppf = function
| None -> Format.pp_print_string ppf "None"
| Some b -> Format.fprintf ppf "Some(%b)" b
let rec pp ppf = function
| NFC -> Format.pp_print_string ppf "NFC"
| NFD -> Format.pp_print_string ppf "NFD"
| NFKC -> Format.pp_print_string ppf "NFKC"
| NFKD -> Format.pp_print_string ppf "NFKD"
| Lowercase -> Format.pp_print_string ppf "Lowercase"
| Strip_accents -> Format.pp_print_string ppf "StripAccents"
| Strip { left; right } ->
Format.fprintf ppf "@[<1>Strip(left=%b,@ right=%b)@]" left right
| Replace { pattern; replacement; _ } ->
Format.fprintf ppf "@[<1>Replace(%S,@ %S)@]" pattern replacement
| Prepend s -> Format.fprintf ppf "Prepend(%S)" s
| Byte_level { add_prefix_space; use_regex } ->
Format.fprintf ppf "@[<1>ByteLevel(add_prefix_space=%b,@ use_regex=%b)@]"
add_prefix_space use_regex
| Bert { clean_text; handle_chinese_chars; strip_accents; lowercase } ->
Format.fprintf ppf
"@[<1>Bert(clean_text=%b,@ handle_chinese_chars=%b,@ \
strip_accents=%a,@ lowercase=%b)@]"
clean_text handle_chinese_chars pp_bool_opt strip_accents lowercase
| Sequence ns ->
Format.fprintf ppf "@[<1>Sequence[%a]@]"
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf ",@ ")
pp)
ns
(*---------------------------------------------------------------------------
Serialization
---------------------------------------------------------------------------*)
let json_obj pairs =
Jsont.Json.object' (List.map (fun (k, v) -> (Jsont.Json.name k, v)) pairs)
let typed name = json_obj [ ("type", Jsont.Json.string name) ]
let typed_with name pairs = json_obj (("type", Jsont.Json.string name) :: pairs)
let rec to_json = function
| Bert { clean_text; handle_chinese_chars; strip_accents; lowercase } ->
typed_with "Bert"
[
("clean_text", Jsont.Json.bool clean_text);
("handle_chinese_chars", Jsont.Json.bool handle_chinese_chars);
( "strip_accents",
match strip_accents with
| None -> Jsont.Json.null ()
| Some b -> Jsont.Json.bool b );
("lowercase", Jsont.Json.bool lowercase);
]
| Strip { left; right } ->
typed_with "Strip"
[
("strip_left", Jsont.Json.bool left);
("strip_right", Jsont.Json.bool right);
]
| Strip_accents -> typed "StripAccents"
| NFC -> typed "NFC"
| NFD -> typed "NFD"
| NFKC -> typed "NFKC"
| NFKD -> typed "NFKD"
| Lowercase -> typed "Lowercase"
| Replace { pattern; replacement; _ } ->
typed_with "Replace"
[
("pattern", json_obj [ ("String", Jsont.Json.string pattern) ]);
("content", Jsont.Json.string replacement);
]
| Prepend prefix ->
typed_with "Prepend" [ ("prepend", Jsont.Json.string prefix) ]
| Byte_level { add_prefix_space; use_regex } ->
typed_with "ByteLevel"
[
("add_prefix_space", Jsont.Json.bool add_prefix_space);
("use_regex", Jsont.Json.bool use_regex);
]
| Sequence ns ->
typed_with "Sequence"
[ ("normalizers", Jsont.Json.list (List.map to_json ns)) ]
let rec of_json = function
| Jsont.Object (fields, _) -> (
let find name = Option.map snd (Jsont.Json.find_mem name fields) in
let get_bool name default =
match find name with Some (Jsont.Bool (b, _)) -> b | _ -> default
in
match find "type" with
| Some (Jsont.String (("Bert" | "BertNormalizer"), _)) ->
let strip_accents =
match find "strip_accents" with
| Some (Jsont.Bool (b, _)) -> Some b
| _ -> None
in
Ok
(Bert
{
clean_text = get_bool "clean_text" true;
handle_chinese_chars = get_bool "handle_chinese_chars" true;
strip_accents;
lowercase = get_bool "lowercase" true;
})
| Some (Jsont.String ("Strip", _)) ->
Ok
(Strip
{
left = get_bool "strip_left" false;
right = get_bool "strip_right" true;
})
| Some (Jsont.String ("StripAccents", _)) -> Ok Strip_accents
| Some (Jsont.String ("NFC", _)) -> Ok NFC
| Some (Jsont.String ("NFD", _)) -> Ok NFD
| Some (Jsont.String ("NFKC", _)) -> Ok NFKC
| Some (Jsont.String ("NFKD", _)) -> Ok NFKD
| Some (Jsont.String ("Lowercase", _)) -> Ok Lowercase
| Some (Jsont.String ("Replace", _)) ->
let pattern =
match find "pattern" with
| Some (Jsont.Object (pf, _)) -> (
match Jsont.Json.find_mem "String" pf with
| Some (_, Jsont.String (p, _)) -> Ok p
| _ -> Error err_replace_invalid_pattern)
| _ -> Error err_replace_missing_pattern
in
let replacement =
match find "content" with
| Some (Jsont.String (r, _)) -> Ok r
| _ -> Error err_replace_missing_content
in
Result.bind pattern (fun p ->
Result.map
(fun r -> replace ~pattern:p ~replacement:r)
replacement)
| Some (Jsont.String ("Prepend", _)) -> (
match find "prepend" with
| Some (Jsont.String (p, _)) -> Ok (Prepend p)
| _ -> Error err_prepend_missing)
| Some (Jsont.String ("ByteLevel", _)) ->
Ok
(Byte_level
{
add_prefix_space = get_bool "add_prefix_space" false;
use_regex = get_bool "use_regex" false;
})
| Some (Jsont.String ("Sequence", _)) -> (
match find "normalizers" with
| Some (Jsont.Array (l, _)) ->
let rec build acc = function
| [] -> Ok (Sequence (List.rev acc))
| item :: rest ->
Result.bind (of_json item) (fun n -> build (n :: acc) rest)
in
build [] l
| _ -> Error err_sequence_missing)
| Some (Jsont.String (other, _)) ->
Error (strf "Unknown normalizer type: %s" other)
| _ -> Error err_missing_type)
| _ -> Error err_expected_object
================================================
FILE: packages/brot/lib/normalizer.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Text normalization.
Normalizers transform text before tokenization: lowercasing, accent removal,
Unicode normalization, whitespace cleanup, and model-specific preprocessing.
They are the first stage in the tokenization pipeline, applied before
{!Pre_tokenizer} and vocabulary-based encoding.
Compose normalizers with {!val-sequence}:
{[
let n =
Normalizer.sequence
[ Normalizer.nfd; Normalizer.strip_accents; Normalizer.lowercase ]
in
Normalizer.apply n "Caf\u{00E9}"
(* "cafe" *)
]}
See {!Brot} for the full tokenization pipeline. *)
type t
(** The type for normalizers. *)
(** {1:normalizers Normalizers} *)
(** {2:unicode Unicode normalization} *)
val nfc : t
(** [nfc] is Unicode NFC normalization (canonical composition). *)
val nfd : t
(** [nfd] is Unicode NFD normalization (canonical decomposition). *)
val nfkc : t
(** [nfkc] is Unicode NFKC normalization (compatibility composition). *)
val nfkd : t
(** [nfkd] is Unicode NFKD normalization (compatibility decomposition). *)
(** {2:text Text transforms} *)
val lowercase : t
(** [lowercase] is Unicode case folding to lowercase. *)
val strip_accents : t
(** [strip_accents] removes combining marks after NFD decomposition. Applies
{!val-nfd} before stripping. *)
val strip : ?left:bool -> ?right:bool -> unit -> t
(** [strip ?left ?right ()] is a normalizer that strips Unicode whitespace from
text boundaries. [left] and [right] default to [true]. *)
val replace : pattern:string -> replacement:string -> t
(** [replace ~pattern ~replacement] is a normalizer that replaces all [pattern]
matches with [replacement]. [pattern] is a PCRE regular expression, compiled
once at construction time.
Raises [Re.Pcre.Parse_error] if [pattern] is not valid PCRE. *)
val prepend : string -> t
(** [prepend s] is a normalizer that prepends [s] to non-empty text. Empty text
is returned unchanged. *)
(** {2:byte_level Byte-level encoding} *)
val byte_level : ?add_prefix_space:bool -> unit -> t
(** [byte_level ?add_prefix_space ()] is GPT-2 style byte-level encoding. Each
byte is mapped to a printable Unicode codepoint using the GPT-2
byte-to-unicode table.
- [add_prefix_space] adds a space prefix when the text does not start with
whitespace. Defaults to [false]. *)
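(* Illustrative example of the mapping: the space byte [0x20] is not
   printable, so it is remapped to [U+0120]:
   {[
     Normalizer.apply (Normalizer.byte_level ()) "a b"
     (* "a\u{0120}b" *)
   ]} *)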
(** {2:model Model-specific} *)
val bert :
?clean_text:bool ->
?handle_chinese_chars:bool ->
?strip_accents:bool option ->
?lowercase:bool ->
unit ->
t
(** [bert ()] is a BERT normalizer.
- [clean_text]: remove control characters and normalize whitespace. Default:
[true].
- [handle_chinese_chars]: pad CJK ideographs with spaces. Default: [true].
- [strip_accents]: strip accents after NFD decomposition. When [None],
accents are stripped iff [lowercase] is [true]. Default: [None].
- [lowercase]: lowercase text via Unicode case folding. Default: [true]. *)
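(* With the defaults above, the BERT normalizer cleans control characters,
   lowercases, and strips accents (illustrative sketch):
   {[
     Normalizer.apply (Normalizer.bert ()) "H\u{00E9}llo World"
     (* "hello world" *)
   ]} *)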
(** {2:composition Composition} *)
val sequence : t list -> t
(** [sequence ns] is the composition of normalizers [ns], applied left to right.
*)
(** {1:applying Applying} *)
val apply : t -> string -> string
(** [apply n s] is [s] normalized by [n]. *)
(** {1:formatting Formatting} *)
val pp : Format.formatter -> t -> unit
(** [pp ppf n] formats [n] for inspection. *)
(** {1:serialization Serialization} *)
val to_json : t -> Jsont.json
(** [to_json n] is [n] serialized to HuggingFace-compatible JSON. *)
val of_json : Jsont.json -> (t, string) result
(** [of_json json] is a normalizer deserialized from HuggingFace JSON. Errors if
[json] is not an object, has a missing or unknown ["type"] field, or has
invalid parameters. *)
================================================
FILE: packages/brot/lib/post_processor.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let strf = Printf.sprintf
let err_type_id tok = strf "expected integer type id after ':' in '%s'" tok
let err_piece tok = strf "expected 'id' or 'id:type_id', got '%s'" tok
let err_unknown_special tok = strf "unknown special token '%s'" tok
let err_mismatch tok = strf "ids and tokens differ in length for '%s'" tok
let err_expected what v = strf "expected %s, got %s" what v
let err_seq_id = "sequence id must be \"A\", \"B\", 0 or 1"
let err_type_id_field = "expected number for 'type_id'"
let err_missing_sequence = "template references a sequence not provided"
let err_pair_required = "pair template required when two sequences are provided"
let err_pair_must_ref_both = "pair template must reference both $A and $B"
let err_template_def = "expected string, array or null for template"
let err_unsupported_piece = "expected Sequence or SpecialToken object"
let err_special_missing_id = "missing 'id' in SpecialToken"
let err_special_missing_ids = "missing 'ids' in special token"
let err_special_entry = "expected object for special token entry"
(* Types *)
type sequence_id = Sequence_a | Sequence_b
type template_piece =
| Piece_sequence of { id : sequence_id; type_id : int }
| Piece_special of { key : string; type_id : int }
type template = template_piece list
type special_token = {
key : string;
value_ids : int list;
value_tokens : string list;
}
type token = string * int
type t =
| Bert of { sep : token; cls : token }
| Roberta of {
sep : token;
cls : token;
pad : token;
trim_offsets : bool;
add_prefix_space : bool;
}
| ByteLevel of { trim_offsets : bool }
| Template of {
single : template;
pair : template option;
special_tokens : special_token list;
}
| Sequence of t list
(* Helpers *)
let special_token ~id ~token ~type_id =
Encoding.token ~id ~token ~offset:(0, 0) ~type_id ~special:true
let with_type_id enc type_id =
Encoding.create ~ids:(Encoding.ids enc)
~type_ids:(Array.make (Encoding.length enc) type_id)
~tokens:(Encoding.tokens enc) ~words:(Encoding.word_ids enc)
~offsets:(Encoding.offsets enc)
~special_tokens_mask:(Encoding.special_tokens_mask enc)
~attention_mask:(Encoding.attention_mask enc)
()
let is_ws = function
| ' ' | '\t' | '\n' | '\r' | '\x0b' | '\x0c' -> true
| _ -> false
let build_special_lookup special_tokens =
let tbl = Hashtbl.create (List.length special_tokens + 1) in
List.iter (fun tok -> Hashtbl.replace tbl tok.key tok) special_tokens;
tbl
let string_is_int s =
let len = String.length s in
let rec loop i =
if i >= len then true
else match s.[i] with '0' .. '9' -> loop (i + 1) | _ -> false
in
len > 0 && loop 0
let sequence_id_to_label = function Sequence_a -> "A" | Sequence_b -> "B"
let sequence_id_to_index = function Sequence_a -> 0 | Sequence_b -> 1
(* JSON helpers *)
let json_obj pairs =
Jsont.Json.object' (List.map (fun (k, v) -> (Jsont.Json.name k, v)) pairs)
let json_find name fields =
match Jsont.Json.find_mem name fields with
| Some (_, v) -> Some v
| None -> None
let json_bool_field fields name ~default =
match json_find name fields with
| Some (Jsont.Bool (b, _)) -> b
| _ -> default
let json_str_int_pair fields name ~default =
match json_find name fields with
| Some (Jsont.Array ([ Jsont.String (s, _); Jsont.Number (f, _) ], _)) ->
(s, int_of_float f)
| _ -> default
(* Processors *)
let process_bert ~sep ~cls encodings ~add_special_tokens =
if not add_special_tokens then encodings
else
let cls_str, cls_id = cls in
let sep_str, sep_id = sep in
let cls_tok tid = special_token ~id:cls_id ~token:cls_str ~type_id:tid in
let sep_tok tid = special_token ~id:sep_id ~token:sep_str ~type_id:tid in
match encodings with
| [] -> []
| [ encoding ] ->
[
Encoding.concat_list [ cls_tok 0; with_type_id encoding 0; sep_tok 0 ];
]
| [ enc1; enc2 ] ->
[
Encoding.concat_list
[
cls_tok 0;
with_type_id enc1 0;
sep_tok 0;
with_type_id enc2 1;
sep_tok 1;
];
]
| _ -> encodings
let process_roberta ~sep ~cls ~pad:_ ~trim_offsets:_ ~add_prefix_space:_
encodings ~add_special_tokens =
if not add_special_tokens then encodings
else
let cls_str, cls_id = cls in
let sep_str, sep_id = sep in
let cls_tok = special_token ~id:cls_id ~token:cls_str ~type_id:0 in
let sep_tok = special_token ~id:sep_id ~token:sep_str ~type_id:0 in
match encodings with
| [] -> []
| [ encoding ] ->
[ Encoding.concat_list [ cls_tok; with_type_id encoding 0; sep_tok ] ]
| [ enc1; enc2 ] ->
[
Encoding.concat_list
[
cls_tok;
with_type_id enc1 0;
sep_tok;
sep_tok;
with_type_id enc2 0;
sep_tok;
];
]
| _ -> encodings
let trim_offset enc_tokens idx (start, stop) =
if start >= stop then (start, stop)
else
let token =
if idx < Array.length enc_tokens then enc_tokens.(idx) else ""
in
let decoded = Pre_tokenizer.byte_level_decode token in
let len = String.length decoded in
let rec leading i =
if i >= len then len else if is_ws decoded.[i] then leading (i + 1) else i
in
let rec trailing i =
if i <= 0 then len
else if is_ws decoded.[i - 1] then trailing (i - 1)
else i
in
let lead = leading 0 in
let trail = trailing len in
let trimmed_lead = min (stop - start) lead in
let trimmed_trail = min (stop - start - trimmed_lead) (len - trail) in
let new_start = start + trimmed_lead in
let new_stop = max new_start (stop - trimmed_trail) in
(new_start, new_stop)
let process_byte_level ~trim_offsets encodings ~add_special_tokens:_ =
if not trim_offsets then encodings
else
List.map
(fun encoding ->
let enc_tokens = Encoding.tokens encoding in
let new_offsets =
Array.mapi (trim_offset enc_tokens) (Encoding.offsets encoding)
in
Encoding.create ~ids:(Encoding.ids encoding)
~type_ids:(Encoding.type_ids encoding)
~tokens:enc_tokens
~words:(Encoding.word_ids encoding)
~offsets:new_offsets
~special_tokens_mask:(Encoding.special_tokens_mask encoding)
~attention_mask:(Encoding.attention_mask encoding)
~overflowing:(Encoding.overflowing encoding)
())
encodings
(* Template parsing *)
let split_template_string str =
let len = String.length str in
let rec skip_ws i =
if i >= len then len
else match str.[i] with ' ' | '\t' -> skip_ws (i + 1) | _ -> i
in
let rec find_end i =
if i >= len then len
else match str.[i] with ' ' | '\t' -> i | _ -> find_end (i + 1)
in
let rec loop i acc =
let i = skip_ws i in
if i >= len then List.rev acc
else
let j = find_end i in
loop j (String.sub str i (j - i) :: acc)
in
loop 0 []
let parse_sequence_base base =
let lower = String.lowercase_ascii base in
if lower = "$" || lower = "$a" then Some (Sequence_a, 0)
else if lower = "$b" then Some (Sequence_b, 0)
else if String.length base > 0 && base.[0] = '$' then
let rest = String.sub base 1 (String.length base - 1) in
if string_is_int rest then Some (Sequence_a, int_of_string rest) else None
else None
let parse_template_piece_from_string ~special_lookup token =
let parts = String.split_on_char ':' token in
let base, explicit_type =
match parts with
| [ id; type_part ] when string_is_int type_part ->
(id, Some (int_of_string type_part))
| [ _; _ ] -> invalid_arg (err_type_id token)
| [ id ] -> (id, None)
| _ -> invalid_arg (err_piece token)
in
match parse_sequence_base base with
| Some (seq_id, default_type) ->
let type_id = Option.value ~default:default_type explicit_type in
Piece_sequence { id = seq_id; type_id }
| None ->
if Hashtbl.mem special_lookup base then
let type_id = Option.value ~default:0 explicit_type in
Piece_special { key = base; type_id }
else invalid_arg (err_unknown_special token)
let parse_template_string ~special_lookup str =
List.map
(parse_template_piece_from_string ~special_lookup)
(split_template_string str)
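(* For example, assuming "[CLS]" and "[SEP]" are in [special_lookup],
   parsing "[CLS] $A:0 [SEP]:1" yields:
   [ Piece_special { key = "[CLS]"; type_id = 0 };
     Piece_sequence { id = Sequence_a; type_id = 0 };
     Piece_special { key = "[SEP]"; type_id = 1 } ] *)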
let parse_sequence_id_json fields =
match json_find "id" fields with
| Some (Jsont.String (s, _)) -> (
match String.lowercase_ascii s with
| "a" -> Sequence_a
| "b" -> Sequence_b
| _ -> invalid_arg err_seq_id)
| Some (Jsont.Number (v, _)) -> (
match int_of_float v with
| 0 -> Sequence_a
| 1 -> Sequence_b
| _ -> invalid_arg err_seq_id)
| None -> Sequence_a
| _ -> invalid_arg err_seq_id
let json_type_id fields =
match json_find "type_id" fields with
| Some (Jsont.Number (v, _)) -> int_of_float v
| None -> 0
| _ -> invalid_arg err_type_id_field
let parse_template_piece_from_json ~special_lookup json =
match json with
| Jsont.Object (outer_fields, _) -> (
match json_find "Sequence" outer_fields with
| Some (Jsont.Object (fields, _)) ->
let id = parse_sequence_id_json fields in
let type_id = json_type_id fields in
Piece_sequence { id; type_id }
| _ -> (
match json_find "SpecialToken" outer_fields with
| Some (Jsont.Object (fields, _)) ->
let key =
match json_find "id" fields with
| Some (Jsont.String (s, _)) -> s
| _ -> invalid_arg err_special_missing_id
in
if not (Hashtbl.mem special_lookup key) then
invalid_arg (err_unknown_special key);
let type_id = json_type_id fields in
Piece_special { key; type_id }
| _ -> invalid_arg err_unsupported_piece))
| _ -> invalid_arg err_unsupported_piece
let parse_template_definition ~special_lookup = function
| Jsont.String (s, _) -> parse_template_string ~special_lookup s
| Jsont.Array (l, _) ->
List.map (parse_template_piece_from_json ~special_lookup) l
| Jsont.Null _ -> []
| _ -> invalid_arg err_template_def
(* Template encoding *)
let build_encoding_from_pieces pieces source_encodings special_lookup =
let ids_rev = ref [] in
let type_ids_rev = ref [] in
let tokens_rev = ref [] in
let words_rev = ref [] in
let offsets_rev = ref [] in
let special_mask_rev = ref [] in
let attention_rev = ref [] in
let append ~id ~token ~word ~type_id ~offset ~special ~attention =
ids_rev := id :: !ids_rev;
type_ids_rev := type_id :: !type_ids_rev;
tokens_rev := token :: !tokens_rev;
words_rev := word :: !words_rev;
offsets_rev := offset :: !offsets_rev;
special_mask_rev := special :: !special_mask_rev;
attention_rev := attention :: !attention_rev
in
let append_sequence seq_id type_id =
let index = sequence_id_to_index seq_id in
if index >= Array.length source_encodings then
invalid_arg err_missing_sequence;
let src = source_encodings.(index) in
let src_ids = Encoding.ids src in
let src_tokens = Encoding.tokens src in
let src_words = Encoding.word_ids src in
let src_offsets = Encoding.offsets src in
let src_special = Encoding.special_tokens_mask src in
let src_attention = Encoding.attention_mask src in
let len = Array.length src_ids in
for i = 0 to len - 1 do
let token = if i < Array.length src_tokens then src_tokens.(i) else "" in
let word = if i < Array.length src_words then src_words.(i) else None in
let offset =
if i < Array.length src_offsets then src_offsets.(i) else (0, 0)
in
let special =
if i < Array.length src_special && src_special.(i) <> 0 then 1 else 0
in
let attention =
if i < Array.length src_attention && src_attention.(i) <> 0 then 1
else 0
in
append ~id:src_ids.(i) ~token ~word ~type_id ~offset ~special ~attention
done
in
let append_special key type_id =
match Hashtbl.find_opt special_lookup key with
| None -> invalid_arg (err_unknown_special key)
| Some special ->
let rec loop ids tokens =
match (ids, tokens) with
| id :: rest_ids, token :: rest_tokens ->
append ~id ~token ~word:None ~type_id ~offset:(0, 0) ~special:1
~attention:1;
loop rest_ids rest_tokens
| [], [] -> ()
| _ -> invalid_arg (err_mismatch key)
in
loop special.value_ids special.value_tokens
in
List.iter
(function
| Piece_sequence { id; type_id } -> append_sequence id type_id
| Piece_special { key; type_id } -> append_special key type_id)
pieces;
let to_array r = Array.of_list (List.rev !r) in
Encoding.create ~ids:(to_array ids_rev) ~type_ids:(to_array type_ids_rev)
~tokens:(to_array tokens_rev) ~words:(to_array words_rev)
~offsets:(to_array offsets_rev)
~special_tokens_mask:(to_array special_mask_rev)
~attention_mask:(to_array attention_rev) ()
let process_template ~single ~pair ~special_tokens encodings ~add_special_tokens
=
if not add_special_tokens then encodings
else
let special_lookup = build_special_lookup special_tokens in
let source = Array.of_list encodings in
match Array.length source with
| 0 -> []
| 1 -> [ build_encoding_from_pieces single source special_lookup ]
| 2 ->
let pair =
match pair with Some p -> p | None -> invalid_arg err_pair_required
in
[ build_encoding_from_pieces pair source special_lookup ]
| _ -> encodings
(* Processing *)
let rec process_list processor encodings ~add_special_tokens =
match processor with
| Bert { sep; cls } -> process_bert ~sep ~cls encodings ~add_special_tokens
| Roberta { sep; cls; pad; trim_offsets; add_prefix_space } ->
process_roberta ~sep ~cls ~pad ~trim_offsets ~add_prefix_space encodings
~add_special_tokens
| ByteLevel { trim_offsets } ->
process_byte_level ~trim_offsets encodings ~add_special_tokens
| Template { single; pair; special_tokens } ->
process_template ~single ~pair ~special_tokens encodings
~add_special_tokens
| Sequence processors ->
List.fold_left
(fun encs proc -> process_list proc encs ~add_special_tokens)
encodings processors
let process processor ?pair enc ~add_special_tokens =
let encodings = match pair with None -> [ enc ] | Some p -> [ enc; p ] in
match process_list processor encodings ~add_special_tokens with
| [ r ] -> r
| r :: _ -> r
| [] -> enc
let rec added_tokens processor ~is_pair =
match processor with
| Bert _ -> if is_pair then 3 else 2
| Roberta _ -> if is_pair then 4 else 2
| ByteLevel _ -> 0
| Template { single; pair; special_tokens } ->
let lookup = build_special_lookup special_tokens in
let count_special pieces =
List.fold_left
(fun acc piece ->
match piece with
| Piece_special { key; _ } -> (
match Hashtbl.find_opt lookup key with
| Some tok -> acc + List.length tok.value_ids
| None -> acc)
| _ -> acc)
0 pieces
in
if is_pair then
match pair with
| Some p -> count_special p
| None -> count_special single
else count_special single
| Sequence processors ->
List.fold_left
(fun acc proc -> acc + added_tokens proc ~is_pair)
0 processors
(* Constructors *)
let bert ~sep ~cls () = Bert { sep; cls }
let roberta ~sep ~cls ?(trim_offsets = true) ?(add_prefix_space = true) () =
let pad = ("", 1) in
Roberta { sep; cls; pad; trim_offsets; add_prefix_space }
let byte_level ?(trim_offsets = true) () = ByteLevel { trim_offsets }
let template ~single ?pair ?(special_tokens = []) () =
let specials =
List.map
(fun (token, id) ->
{ key = token; value_ids = [ id ]; value_tokens = [ token ] })
special_tokens
in
let lookup = build_special_lookup specials in
let single = parse_template_string ~special_lookup:lookup single in
let has_sequence pieces seq =
List.exists
(function Piece_sequence { id; _ } when id = seq -> true | _ -> false)
pieces
in
let pair =
match pair with
| None -> None
| Some p ->
let tpl = parse_template_string ~special_lookup:lookup p in
if not (has_sequence tpl Sequence_a && has_sequence tpl Sequence_b) then
invalid_arg err_pair_must_ref_both;
Some tpl
in
Template { single; pair; special_tokens = specials }
let sequence processors = Sequence processors
(* Formatting *)
let rec pp ppf = function
| Bert { sep = sep_s, _; cls = cls_s, _ } ->
Format.fprintf ppf "@[<2>Bert@ ~cls:%S@ ~sep:%S@]" cls_s sep_s
| Roberta { sep = sep_s, _; cls = cls_s, _; _ } ->
Format.fprintf ppf "@[<2>Roberta@ ~cls:%S@ ~sep:%S@]" cls_s sep_s
| ByteLevel { trim_offsets } ->
Format.fprintf ppf "@[<2>ByteLevel@ ~trim_offsets:%b@]" trim_offsets
| Template _ -> Format.fprintf ppf "Template"
| Sequence processors ->
Format.fprintf ppf "@[<2>Sequence[@,%a]@]"
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf ",@ ")
pp)
processors
(* Serialization *)
let token_pair_to_json (s, id) =
Jsont.Json.list [ Jsont.Json.string s; Jsont.Json.int id ]
let template_to_json pieces =
let piece_json tag id type_id =
json_obj
[ (tag, json_obj [ ("id", id); ("type_id", Jsont.Json.int type_id) ]) ]
in
Jsont.Json.list
(List.map
(function
| Piece_sequence { id; type_id } ->
piece_json "Sequence"
(Jsont.Json.string (sequence_id_to_label id))
type_id
| Piece_special { key; type_id } ->
piece_json "SpecialToken" (Jsont.Json.string key) type_id)
pieces)
let rec to_json = function
| Bert { sep; cls } ->
json_obj
[
("type", Jsont.Json.string "BertProcessing");
("sep", token_pair_to_json sep);
("cls", token_pair_to_json cls);
]
| Roberta { sep; cls; pad; trim_offsets; add_prefix_space } ->
json_obj
[
("type", Jsont.Json.string "RobertaProcessing");
("sep", token_pair_to_json sep);
("cls", token_pair_to_json cls);
("pad", token_pair_to_json pad);
("trim_offsets", Jsont.Json.bool trim_offsets);
("add_prefix_space", Jsont.Json.bool add_prefix_space);
]
| ByteLevel { trim_offsets } ->
json_obj
[
("type", Jsont.Json.string "ByteLevel");
("trim_offsets", Jsont.Json.bool trim_offsets);
]
| Template { single; pair; special_tokens } ->
let pair_json =
match pair with
| None -> Jsont.Json.null ()
| Some p -> template_to_json p
in
let special_token_json tok =
let ids = Jsont.Json.list (List.map Jsont.Json.int tok.value_ids) in
let tokens =
Jsont.Json.list (List.map Jsont.Json.string tok.value_tokens)
in
( Jsont.Json.name tok.key,
json_obj
[
("id", Jsont.Json.string tok.key); ("ids", ids); ("tokens", tokens);
] )
in
let special_json =
Jsont.Json.object' (List.map special_token_json special_tokens)
in
json_obj
[
("type", Jsont.Json.string "TemplateProcessing");
("single", template_to_json single);
("pair", pair_json);
("special_tokens", special_json);
]
| Sequence processors ->
json_obj
[
("type", Jsont.Json.string "Sequence");
("processors", Jsont.Json.list (List.map to_json processors));
]
(* Deserialization *)
let parse_special_token_json fields alias =
let key =
match json_find "id" fields with
| Some (Jsont.String (s, _)) -> s
| _ -> alias
in
let value_ids =
match json_find "ids" fields with
| Some (Jsont.Array (lst, _)) ->
List.map
(function
| Jsont.Number (f, _) -> int_of_float f
| v ->
invalid_arg
(err_expected "number" (Format.asprintf "%a" Jsont.pp_json v)))
lst
| _ -> invalid_arg err_special_missing_ids
in
let value_tokens =
match json_find "tokens" fields with
| Some (Jsont.Array (lst, _)) ->
List.map
(function
| Jsont.String (s, _) -> s
| v ->
invalid_arg
(err_expected "string" (Format.asprintf "%a" Jsont.pp_json v)))
lst
| _ -> [ key ]
in
if List.length value_ids <> List.length value_tokens then
invalid_arg (err_mismatch key);
{ key; value_ids; value_tokens }
let parse_special_tokens_json fields =
match json_find "special_tokens" fields with
| Some (Jsont.Object (tokens, _)) ->
List.map
(fun ((alias, _), value) ->
match value with
| Jsont.Object (token_fields, _) ->
parse_special_token_json token_fields alias
| _ -> invalid_arg err_special_entry)
tokens
| Some v ->
invalid_arg
(err_expected "object for 'special_tokens'"
(Format.asprintf "%a" Jsont.pp_json v))
| None -> []
let rec of_json_exn json =
match json with
| Jsont.Object (fields, _) -> (
match json_find "type" fields with
| Some (Jsont.String ("BertProcessing", _)) ->
let sep = json_str_int_pair fields "sep" ~default:("[SEP]", 102) in
let cls = json_str_int_pair fields "cls" ~default:("[CLS]", 101) in
Bert { sep; cls }
| Some (Jsont.String ("RobertaProcessing", _)) ->
let sep = json_str_int_pair fields "sep" ~default:("", 2) in
let cls = json_str_int_pair fields "cls" ~default:("", 0) in
let pad = json_str_int_pair fields "pad" ~default:("", 1) in
let trim_offsets =
json_bool_field fields "trim_offsets" ~default:true
in
let add_prefix_space =
json_bool_field fields "add_prefix_space" ~default:true
in
Roberta { sep; cls; pad; trim_offsets; add_prefix_space }
| Some (Jsont.String ("ByteLevel", _)) ->
let trim_offsets =
json_bool_field fields "trim_offsets" ~default:true
in
ByteLevel { trim_offsets }
| Some (Jsont.String ("TemplateProcessing", _)) ->
let special_tokens = parse_special_tokens_json fields in
let lookup = build_special_lookup special_tokens in
let single =
match json_find "single" fields with
| Some json -> parse_template_definition ~special_lookup:lookup json
| None -> parse_template_string ~special_lookup:lookup "$A"
in
let pair =
match json_find "pair" fields with
| Some (Jsont.Null _) | None -> None
| Some json ->
Some (parse_template_definition ~special_lookup:lookup json)
in
Template { single; pair; special_tokens }
| Some (Jsont.String ("Sequence", _)) -> (
match json_find "processors" fields with
| Some (Jsont.Array (procs, _)) ->
Sequence (List.map of_json_exn procs)
| _ -> failwith "expected array for 'processors'")
| _ -> failwith "unsupported processor type")
| _ -> failwith "expected JSON object"
let of_json json =
try Ok (of_json_exn json) with
| Failure msg -> Error msg
| Invalid_argument msg -> Error msg
================================================
FILE: packages/brot/lib/post_processor.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Post-processing tokenization output with special tokens.
Post-processors add special tokens and type IDs to tokenized sequences after
core tokenization. They handle model-specific requirements like [[CLS]] and
[[SEP]] for BERT, sentence pair formatting, and byte-level offset
adjustments. *)
type t
(** The type for post-processors. *)
type token = string * int
(** A special token as [(text, id)]. *)
(** {1:constructors Constructors} *)
val bert : sep:token -> cls:token -> unit -> t
(** [bert ~sep ~cls ()] is a BERT-style post-processor.
Single: [[CLS] A [SEP]]. Pair: [[CLS] A [SEP] B [SEP]]. Type IDs: [0] for
the first sequence, [1] for the second. *)
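(* Sketch of typical usage, given some [enc : Encoding.t] (the token ids
   below are the conventional BERT vocabulary values, assumed here for
   illustration):
   {[
     let p = Post_processor.bert ~sep:("[SEP]", 102) ~cls:("[CLS]", 101) () in
     Post_processor.process p enc ~add_special_tokens:true
     (* tokens: [CLS] ... [SEP], all type IDs 0 *)
   ]} *)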
val roberta :
sep:token ->
cls:token ->
?trim_offsets:bool ->
?add_prefix_space:bool ->
unit ->
t
(** [roberta ~sep ~cls ()] is a RoBERTa-style post-processor.
Single: [cls A sep]. Pair: [cls A sep sep B sep]. All type IDs are [0].
[trim_offsets] defaults to [true]. [add_prefix_space] defaults to [true]. *)
val byte_level : ?trim_offsets:bool -> unit -> t
(** [byte_level ()] is a byte-level post-processor that adjusts character
offsets for byte-level encoding.
[trim_offsets] removes leading and trailing whitespace from offsets.
Defaults to [true]. *)
val template :
single:string -> ?pair:string -> ?special_tokens:token list -> unit -> t
(** [template ~single ()] is a template-based post-processor.
Templates use [$A] and [$B] as sequence placeholders and literal special
token names (e.g. [[CLS]]). Type IDs can be specified with a colon suffix:
[$A:0], [[SEP]:1].
[special_tokens] defaults to [[]]. *)
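(* A BERT-style pairing expressed as a template (token ids assumed for
   illustration):
   {[
     Post_processor.template
       ~single:"[CLS] $A [SEP]"
       ~pair:"[CLS] $A [SEP] $B:1 [SEP]:1"
       ~special_tokens:[ ("[CLS]", 101); ("[SEP]", 102) ]
       ()
   ]} *)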
val sequence : t list -> t
(** [sequence processors] chains [processors] left-to-right. *)
(** {1:processing Processing} *)
val process :
t -> ?pair:Encoding.t -> Encoding.t -> add_special_tokens:bool -> Encoding.t
(** [process t enc ~add_special_tokens] adds special tokens and sets type IDs on
[enc].
When [~pair] is provided, both sequences are merged into a single encoding
with appropriate type IDs. When [~add_special_tokens] is [false], special
token insertion is skipped but byte-level offset trimming still applies. *)
val added_tokens : t -> is_pair:bool -> int
(** [added_tokens t ~is_pair] is the number of special tokens [t] adds. Useful
for calculating the truncation budget. *)
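(* For instance, a BERT processor adds [[CLS]] and [[SEP]] around a single
   sequence:
   {[
     Post_processor.added_tokens
       (Post_processor.bert ~sep:("[SEP]", 102) ~cls:("[CLS]", 101) ())
       ~is_pair:false
     (* 2 *)
   ]} *)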
(** {1:fmt Formatting} *)
val pp : Format.formatter -> t -> unit
(** [pp] formats a post-processor for inspection. *)
(** {1:serialization Serialization} *)
val of_json : Jsont.json -> (t, string) result
(** [of_json json] is a post-processor from HuggingFace [tokenizer.json] format.
Errors if [json] is not an object, has a missing or unknown ["type"] field,
or has invalid parameters. *)
val to_json : t -> Jsont.json
(** [to_json t] is [t] serialized to HuggingFace [tokenizer.json] format. *)
================================================
FILE: packages/brot/lib/pre_tokenizer.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Types *)
type behavior =
[ `Isolated
| `Removed
| `Merged_with_previous
| `Merged_with_next
| `Contiguous ]
type prepend_scheme = [ `First | `Never | `Always ]
type t =
| Byte_level of {
add_prefix_space : bool;
use_regex : bool;
trim_offsets : bool;
}
| Bert
| Whitespace
| Whitespace_split
| Punctuation of { behavior : behavior }
| Split of { pattern : string; behavior : behavior; invert : bool }
| Char_delimiter of char
| Digits of { individual : bool }
| Metaspace of {
replacement : char;
prepend_scheme : prepend_scheme;
split : bool;
}
| Sequence of t list
| Fixed_length of { length : int }
| Unicode_scripts
(* Errors *)
let strf = Printf.sprintf
let err_unknown_behavior s = strf "unknown punctuation behavior '%s'" s
let err_unknown_scheme s = strf "unknown prepend_scheme '%s'" s
let err_unsupported_type s = strf "unsupported pre-tokenizer type '%s'" s
let err_expected_char name = strf "expected single character for '%s'" name
let err_missing_type = "missing 'type' field"
let err_expected_object = "expected JSON object"
let err_missing_behavior = "missing 'behavior' field"
let err_split_missing = "requires 'pattern' and 'behavior'"
let err_char_delim_missing = "requires 'delimiter'"
let err_metaspace_missing = "requires 'replacement' and 'prepend_scheme'"
let err_sequence_missing = "requires 'pretokenizers' list"
let err_fixed_length = "requires positive length"
(* Character classification *)
(* ASCII property table: packed flags for O(1) classification. bit 0:
whitespace, bit 1: alphabetic, bit 2: numeric, bit 3: punctuation *)
let ascii_props =
let t = Array.make 128 0 in
for i = 9 to 13 do
t.(i) <- t.(i) lor 1
done;
t.(32) <- t.(32) lor 1;
for i = 65 to 90 do
t.(i) <- t.(i) lor 2
done;
for i = 97 to 122 do
t.(i) <- t.(i) lor 2
done;
for i = 48 to 57 do
t.(i) <- t.(i) lor 4
done;
  List.iter
    (fun i -> t.(i) <- t.(i) lor 8)
    [
      33; 34; 35; 37; 38; 39; 40; 41; 42; 44; 45; 46; 47; 58; 59; 63; 64; 91;
      92; 93; 95; 123; 125;
    ];
t
let[@inline] is_whitespace code =
if code < 128 then Array.unsafe_get ascii_props code land 1 <> 0
else Uucp.White.is_white_space (Uchar.of_int code)
let[@inline] is_alphabetic code =
if code < 128 then Array.unsafe_get ascii_props code land 2 <> 0
else Uucp.Alpha.is_alphabetic (Uchar.of_int code)
let[@inline] is_numeric code =
if code < 128 then Array.unsafe_get ascii_props code land 4 <> 0
else
match Uucp.Gc.general_category (Uchar.of_int code) with
| `Nd | `Nl | `No -> true
| _ -> false
let[@inline] is_punctuation code =
if code < 128 then Array.unsafe_get ascii_props code land 8 <> 0
else
match Uucp.Gc.general_category (Uchar.of_int code) with
| `Pc | `Pd | `Pe | `Pf | `Pi | `Po | `Ps -> true
| _ -> false
(* Returns (codepoint lsl 3) lor byte_length — zero allocation. *)
let[@inline] utf8_next s i =
let c = Char.code (String.unsafe_get s i) in
if c < 0x80 then (c lsl 3) lor 1
else if c < 0xE0 then
(((c land 0x1F) lsl 6)
lor (Char.code (String.unsafe_get s (i + 1)) land 0x3F))
lsl 3
lor 2
else if c < 0xF0 then
(((c land 0x0F) lsl 12)
lor ((Char.code (String.unsafe_get s (i + 1)) land 0x3F) lsl 6)
lor (Char.code (String.unsafe_get s (i + 2)) land 0x3F))
lsl 3
lor 3
else
(((c land 0x07) lsl 18)
lor ((Char.code (String.unsafe_get s (i + 1)) land 0x3F) lsl 12)
lor ((Char.code (String.unsafe_get s (i + 2)) land 0x3F) lsl 6)
lor (Char.code (String.unsafe_get s (i + 3)) land 0x3F))
lsl 3
lor 4
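(* Example of the packing: "\u{00E9}" is the two-byte sequence 0xC3 0xA9, so
   [utf8_next "\u{00E9}" 0] is [(0xE9 lsl 3) lor 2]. *)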
(* Pre-computed byte ↔ unicode mappings for byte-level encode/decode *)
let byte_to_unicode, unicode_to_byte =
let is_direct = Array.make 256 false in
for i = 33 to 126 do
is_direct.(i) <- true
done;
for i = 161 to 172 do
is_direct.(i) <- true
done;
for i = 174 to 255 do
is_direct.(i) <- true
done;
let byte_to_unicode = Array.make 256 0 in
let next_code = ref 0 in
let max_code = ref 0 in
for b = 0 to 255 do
let code =
if is_direct.(b) then b
else
let code = 256 + !next_code in
incr next_code;
code
in
byte_to_unicode.(b) <- code;
if code > !max_code then max_code := code
done;
let unicode_to_byte = Array.make (!max_code + 1) (-1) in
for b = 0 to 255 do
let code = byte_to_unicode.(b) in
if code < Array.length unicode_to_byte then unicode_to_byte.(code) <- b
done;
(byte_to_unicode, unicode_to_byte)
let byte_level_encode text =
let len = String.length text in
(* Worst case: every byte remaps to a 2-byte UTF-8 sequence *)
let result = Bytes.create (len * 2) in
let j = ref 0 in
for i = 0 to len - 1 do
let u =
Array.unsafe_get byte_to_unicode (Char.code (String.unsafe_get text i))
in
if u < 128 then begin
Bytes.unsafe_set result !j (Char.unsafe_chr u);
incr j
end
else begin
Bytes.unsafe_set result !j (Char.unsafe_chr (0xC0 lor (u lsr 6)));
Bytes.unsafe_set result (!j + 1)
(Char.unsafe_chr (0x80 lor (u land 0x3F)));
j := !j + 2
end
done;
Bytes.sub_string result 0 !j
let byte_level_encode_range text ~start ~len =
let result = Bytes.create (len * 2) in
let j = ref 0 in
for i = start to start + len - 1 do
let u =
Array.unsafe_get byte_to_unicode (Char.code (String.unsafe_get text i))
in
if u < 128 then begin
Bytes.unsafe_set result !j (Char.unsafe_chr u);
incr j
end
else begin
Bytes.unsafe_set result !j (Char.unsafe_chr (0xC0 lor (u lsr 6)));
Bytes.unsafe_set result (!j + 1)
(Char.unsafe_chr (0x80 lor (u land 0x3F)));
j := !j + 2
end
done;
Bytes.sub_string result 0 !j
let byte_level_decode text =
let len = String.length text in
let result = Buffer.create len in
let i = ref 0 in
while !i < len do
let b0 = Char.code (String.unsafe_get text !i) in
if b0 < 128 then begin
(* ASCII: direct lookup, no utf8_next needed *)
let byte = Array.unsafe_get unicode_to_byte b0 in
Buffer.add_char result
(if byte >= 0 then Char.chr byte else Char.unsafe_chr b0);
incr i
end
else begin
let p = utf8_next text !i in
let code = p lsr 3 and clen = p land 7 in
let byte =
if code < Array.length unicode_to_byte then unicode_to_byte.(code)
else -1
in
if byte >= 0 then Buffer.add_char result (Char.chr byte)
else
for j = !i to !i + clen - 1 do
Buffer.add_char result (String.unsafe_get text j)
done;
i := !i + clen
end
done;
Buffer.contents result
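(* Since the table maps every byte to a distinct codepoint, decoding inverts
   encoding: [byte_level_decode (byte_level_encode s) = s] for any byte
   string [s]. *)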
let[@inline] is_other code =
(not (is_whitespace code))
&& (not (is_alphabetic code))
&& not (is_numeric code)
let split_gpt2_pattern text =
let len = String.length text in
if len = 0 then []
else
let spans = ref [] in
let pos = ref 0 in
(* Try: optional leading space + run of chars matching a class.
[ascii_mask]: bitmask into ascii_props for the ASCII fast path. [invert]:
when true, match chars where (props land mask) = 0. [classify]: predicate
for non-ASCII codepoints (slow path only). *)
let try_space_run ~ascii_mask ~invert ~classify () =
let start = !pos in
let b0 = Char.code (String.unsafe_get text !pos) in
(* GPT-2's pattern only allows a literal ASCII space (0x20) as the
optional prefix; other whitespace falls through to the whitespace rule. *)
let has_space = b0 = 0x20 in
let run_start = if has_space then start + 1 else start in
if run_start < len then
let b = Char.code (String.unsafe_get text run_start) in
let ok, clen =
if b < 128 then
let v = Array.unsafe_get ascii_props b land ascii_mask in
((if invert then v = 0 else v <> 0), 1)
else
let p = utf8_next text run_start in
let code = p lsr 3 and cl = p land 7 in
(classify code, cl)
in
if ok then (
let j = ref (run_start + clen) in
let continue = ref true in
while !j < len && !continue do
let b = Char.code (String.unsafe_get text !j) in
if b < 128 then
let v = Array.unsafe_get ascii_props b land ascii_mask in
let matches = if invert then v = 0 else v <> 0 in
if matches then j := !j + 1 else continue := false
else
let p = utf8_next text !j in
if classify (p lsr 3) then j := !j + (p land 7)
else continue := false
done;
spans := (start, !j - start) :: !spans;
pos := !j;
true)
else false
else false
in
let[@inline] next_is_alnum next_pos =
if next_pos >= len then false
else
let nb = Char.code (String.unsafe_get text next_pos) in
if nb < 128 then Array.unsafe_get ascii_props nb land 6 <> 0
else
let nc = utf8_next text next_pos lsr 3 in
is_alphabetic nc || is_numeric nc
in
let rec loop () =
if !pos >= len then ()
else begin
(* 1. Contractions: 's 't 'm 'd 're 've 'll *)
let matched_contraction =
text.[!pos] = '\''
&&
let remaining = len - !pos in
remaining >= 2
&&
let c1 = String.unsafe_get text (!pos + 1) in
if c1 = 's' || c1 = 't' || c1 = 'm' || c1 = 'd' then (
spans := (!pos, 2) :: !spans;
pos := !pos + 2;
true)
else
remaining >= 3
&&
let c2 = String.unsafe_get text (!pos + 2) in
if
(c1 = 'r' && c2 = 'e')
|| (c1 = 'v' && c2 = 'e')
|| (c1 = 'l' && c2 = 'l')
then (
spans := (!pos, 3) :: !spans;
pos := !pos + 3;
true)
else false
in
if matched_contraction then ()
else if
try_space_run ~ascii_mask:2 ~invert:false ~classify:is_alphabetic ()
then ()
else if
try_space_run ~ascii_mask:4 ~invert:false ~classify:is_numeric ()
then ()
else if try_space_run ~ascii_mask:7 ~invert:true ~classify:is_other ()
then ()
(* 5 & 6. Whitespace run *)
else begin
let b0 = Char.code (String.unsafe_get text !pos) in
let is_ws, clen =
if b0 < 128 then (Array.unsafe_get ascii_props b0 land 1 <> 0, 1)
else
let p = utf8_next text !pos in
let code = p lsr 3 and cl = p land 7 in
(is_whitespace code, cl)
in
if is_ws then begin
let j = ref (!pos + clen) in
let continue = ref true in
while !j < len && !continue do
let b = Char.code (String.unsafe_get text !j) in
if b < 128 then
if Array.unsafe_get ascii_props b land 1 <> 0 then
if next_is_alnum (!j + 1) && b = 0x20 then continue := false
else j := !j + 1
else continue := false
else
let p = utf8_next text !j in
let code = p lsr 3 and cl = p land 7 in
if is_whitespace code then
if next_is_alnum (!j + cl) && code = 0x20 then
continue := false
else j := !j + cl
else continue := false
done;
spans := (!pos, !j - !pos) :: !spans;
pos := !j
end
else begin
(* Fallback: single character *)
spans := (!pos, clen) :: !spans;
pos := !pos + clen
end
end;
loop ()
end
in
loop ();
List.rev !spans
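The spans returned above reproduce GPT-2's classic splits: contractions stay attached to the apostrophe, and a single leading space is folded into the following word. A sketch, assuming `split_gpt2_pattern` above is in scope:

```ocaml
let _gpt2_split_example =
  let text = "Hello's world" in
  (* Spans are (start, length) pairs into [text]. *)
  let pieces =
    List.map
      (fun (start, len) -> String.sub text start len)
      (split_gpt2_pattern text)
  in
  assert (pieces = [ "Hello"; "'s"; " world" ])
```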
(* Pre-tokenize implementations *)
let pre_tokenize_whitespace_split text =
let pieces = ref [] in
let start = ref (-1) in
let i = ref 0 in
let len = String.length text in
let flush () =
if !start >= 0 then begin
pieces := (String.sub text !start (!i - !start), (!start, !i)) :: !pieces;
start := -1
end
in
while !i < len do
let b = Char.code (String.unsafe_get text !i) in
if b < 128 then
if Array.unsafe_get ascii_props b land 1 <> 0 then (
flush ();
i := !i + 1)
else (
if !start < 0 then start := !i;
i := !i + 1)
else
let p = utf8_next text !i in
let code = p lsr 3 and l = p land 7 in
if is_whitespace code then (
flush ();
i := !i + l)
else (
if !start < 0 then start := !i;
i := !i + l)
done;
flush ();
List.rev !pieces
let pre_tokenize_whitespace text =
let pieces = ref [] in
let start = ref (-1) in
let i = ref 0 in
let len = String.length text in
let in_word = ref false in
let in_punct = ref false in
let flush () =
if !start >= 0 then begin
pieces := (String.sub text !start (!i - !start), (!start, !i)) :: !pieces;
start := -1
end
in
while !i < len do
let b = Char.code (String.unsafe_get text !i) in
if b < 128 then
let p = Array.unsafe_get ascii_props b in
if p land 6 <> 0 || b = 95 then (
if !in_punct then flush ();
if !start < 0 then start := !i;
in_word := true;
in_punct := false;
i := !i + 1)
else if p land 1 <> 0 then (
flush ();
in_word := false;
in_punct := false;
i := !i + 1)
else (
if !in_word then flush ();
if !start < 0 then start := !i;
in_word := false;
in_punct := true;
i := !i + 1)
else
let p = utf8_next text !i in
let code = p lsr 3 and l = p land 7 in
if is_alphabetic code || is_numeric code then (
if !in_punct then flush ();
if !start < 0 then start := !i;
in_word := true;
in_punct := false;
i := !i + l)
else if is_whitespace code then (
flush ();
in_word := false;
in_punct := false;
i := !i + l)
else (
if !in_word then flush ();
if !start < 0 then start := !i;
in_word := false;
in_punct := true;
i := !i + l)
done;
flush ();
List.rev !pieces
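This splitter groups word characters (letters, digits, underscore) together and groups punctuation together, dropping whitespace, mirroring the `\w+|[^\w\s]+` pattern documented in the interface. A worked example, assuming the function above is in scope:

```ocaml
let _whitespace_example =
  (* Offsets are byte positions into the original string. *)
  assert (
    pre_tokenize_whitespace "Hello, world!"
    = [ ("Hello", (0, 5)); (",", (5, 6)); ("world", (7, 12)); ("!", (12, 13)) ])
```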
let pre_tokenize_byte_level ~add_prefix_space ~use_regex ~trim_offsets:_ text =
let orig_len = String.length text in
let text, prefix_added =
if
add_prefix_space && orig_len > 0
&& not (is_whitespace (Char.code text.[0]))
then (" " ^ text, true)
else (text, false)
in
if use_regex then
let spans = split_gpt2_pattern text in
List.map
(fun (start, plen) ->
let o_start =
if prefix_added then if start = 0 then 0 else start - 1 else start
in
let o_end =
min orig_len (if prefix_added then start + plen - 1 else start + plen)
in
(byte_level_encode_range text ~start ~len:plen, (max 0 o_start, o_end)))
spans
else [ (byte_level_encode text, (0, orig_len)) ]
let pre_tokenize_bert text =
let pieces = ref [] in
let start = ref (-1) in
let i = ref 0 in
let len = String.length text in
let flush () =
if !start >= 0 then begin
pieces := (String.sub text !start (!i - !start), (!start, !i)) :: !pieces;
start := -1
end
in
while !i < len do
let b = Char.code (String.unsafe_get text !i) in
if b < 128 then
let p = Array.unsafe_get ascii_props b in
if p land 1 <> 0 then (
flush ();
i := !i + 1)
else if p land 8 <> 0 then (
flush ();
pieces := (String.sub text !i 1, (!i, !i + 1)) :: !pieces;
i := !i + 1)
else (
if !start < 0 then start := !i;
i := !i + 1)
else
let p = utf8_next text !i in
let code = p lsr 3 and l = p land 7 in
if is_whitespace code then (
flush ();
i := !i + l)
else if is_punctuation code then (
flush ();
pieces := (String.sub text !i l, (!i, !i + l)) :: !pieces;
i := !i + l)
else (
if !start < 0 then start := !i;
i := !i + l)
done;
flush ();
List.rev !pieces
let pre_tokenize_punctuation ~behavior text =
let pieces = ref [] in
let start = ref (-1) in
let i = ref 0 in
let len = String.length text in
let last_was_punc = ref false in
let flush () =
if !start >= 0 then begin
pieces := (String.sub text !start (!i - !start), (!start, !i)) :: !pieces;
start := -1
end
in
let handle_char is_p l =
if is_p then (
(match behavior with
| `Isolated ->
flush ();
pieces := (String.sub text !i l, (!i, !i + l)) :: !pieces
| `Removed -> flush ()
| `Merged_with_previous -> if !start < 0 then start := !i
| `Merged_with_next ->
flush ();
start := !i
| `Contiguous ->
if not (!start >= 0 && !last_was_punc) then begin
flush ();
start := !i
end);
last_was_punc := true;
i := !i + l)
else (
if behavior = `Contiguous && !start >= 0 && !last_was_punc then flush ();
if !start < 0 then start := !i;
i := !i + l;
last_was_punc := false)
in
while !i < len do
let b = Char.code (String.unsafe_get text !i) in
if b < 128 then handle_char (Array.unsafe_get ascii_props b land 8 <> 0) 1
else
let p = utf8_next text !i in
let code = p lsr 3 and l = p land 7 in
handle_char (is_punctuation code) l
done;
flush ();
List.rev !pieces
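With the default `Isolated behavior, each punctuation codepoint becomes its own piece while the surrounding runs stay intact. A sketch, assuming the function above is in scope:

```ocaml
let _punctuation_example =
  assert (
    pre_tokenize_punctuation ~behavior:`Isolated "hi!bye"
    = [ ("hi", (0, 2)); ("!", (2, 3)); ("bye", (3, 6)) ])
```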
let pre_tokenize_split ~pattern ~behavior ~invert text =
let plen = String.length pattern in
if plen = 0 then [ (text, (0, String.length text)) ]
else
let pieces = ref [] in
let current = Buffer.create 16 in
let current_start = ref 0 in
let i = ref 0 in
let flush_current () =
if Buffer.length current > 0 then (
pieces :=
( Buffer.contents current,
(!current_start, !current_start + Buffer.length current) )
:: !pieces;
Buffer.clear current)
in
while !i < String.length text do
let is_match =
!i + plen <= String.length text && String.sub text !i plen = pattern
in
let is_delim = if invert then not is_match else is_match in
(* Inverted delimiters are consumed one byte at a time. *)
let delim_len = if is_delim && not invert then plen else 1 in
if is_delim then (
(match behavior with
| `Removed -> flush_current ()
| `Isolated ->
flush_current ();
let delim_str = String.sub text !i delim_len in
pieces := (delim_str, (!i, !i + delim_len)) :: !pieces
| `Merged_with_previous ->
Buffer.add_string current (String.sub text !i delim_len);
flush_current ()
| `Merged_with_next ->
flush_current ();
current_start := !i;
Buffer.add_string current (String.sub text !i delim_len)
| `Contiguous ->
if Buffer.length current > 0 then
Buffer.add_string current (String.sub text !i delim_len)
else (
flush_current ();
Buffer.add_string current (String.sub text !i delim_len)));
i := !i + delim_len)
else (
if Buffer.length current = 0 then current_start := !i;
Buffer.add_string current (String.sub text !i 1);
i := !i + 1)
done;
flush_current ();
List.rev !pieces
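Literal-pattern splitting with `Isolated keeps every delimiter occurrence as its own piece, including consecutive delimiters. A sketch, assuming the function above is in scope:

```ocaml
let _split_example =
  assert (
    pre_tokenize_split ~pattern:"," ~behavior:`Isolated ~invert:false "a,b,,c"
    = [ ("a", (0, 1)); (",", (1, 2)); ("b", (2, 3));
        (",", (3, 4)); (",", (4, 5)); ("c", (5, 6)) ])
```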
let pre_tokenize_digits ~individual text =
let pieces = ref [] in
let start = ref (-1) in
let i = ref 0 in
let len = String.length text in
let in_digits = ref false in
let flush () =
if !start >= 0 then begin
pieces := (String.sub text !start (!i - !start), (!start, !i)) :: !pieces;
start := -1
end
in
let handle_char is_d l =
if individual && is_d then (
flush ();
pieces := (String.sub text !i l, (!i, !i + l)) :: !pieces;
i := !i + l)
else (
if is_d <> !in_digits then (
flush ();
in_digits := is_d);
if !start < 0 then start := !i;
i := !i + l)
in
while !i < len do
let b = Char.code (String.unsafe_get text !i) in
if b < 128 then handle_char (Array.unsafe_get ascii_props b land 4 <> 0) 1
else
let p = utf8_next text !i in
let code = p lsr 3 and l = p land 7 in
handle_char (is_numeric code) l
done;
flush ();
List.rev !pieces
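With `~individual:false`, pieces break exactly where the digit/non-digit classification flips. A sketch, assuming the function above is in scope:

```ocaml
let _digits_example =
  assert (
    pre_tokenize_digits ~individual:false "abc123def"
    = [ ("abc", (0, 3)); ("123", (3, 6)); ("def", (6, 9)) ])
```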
let pre_tokenize_metaspace ~replacement ~prepend_scheme ~split text =
let repl = String.make 1 replacement in
let text =
match prepend_scheme with
| (`Always | `First) when String.length text > 0 && text.[0] <> ' ' ->
" " ^ text
| _ -> text
in
let len = String.length text in
let buf = Buffer.create len in
let i = ref 0 in
while !i < len do
if text.[!i] = ' ' then (
Buffer.add_string buf repl;
incr i)
else
let l = utf8_next text !i land 7 in
Buffer.add_substring buf text !i l;
i := !i + l
done;
let transformed = Buffer.contents buf in
if split then (
let tlen = String.length transformed in
let rlen = String.length repl in
let splits = ref [] in
let start = ref 0 in
let pos = ref 0 in
while !pos < tlen do
if !pos + rlen <= tlen && String.sub transformed !pos rlen = repl then (
if !pos > !start then
splits :=
(String.sub transformed !start (!pos - !start), (!start, !pos))
:: !splits;
start := !pos;
pos := !pos + rlen)
else incr pos
done;
if !pos > !start then
splits :=
(String.sub transformed !start (!pos - !start), (!start, !pos))
:: !splits;
List.rev !splits)
else [ (transformed, (0, len)) ]
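Note that the offsets produced here index into the transformed string (after the optional prepend and space replacement), not the original input. A sketch with the defaults, assuming the function above is in scope:

```ocaml
let _metaspace_example =
  (* " Hello world" is prepended and rewritten to "_Hello_world" before
     splitting, and offsets refer to that transformed string. *)
  assert (
    pre_tokenize_metaspace ~replacement:'_' ~prepend_scheme:`Always ~split:true
      "Hello world"
    = [ ("_Hello", (0, 6)); ("_world", (6, 12)) ])
```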
let pre_tokenize_fixed_length ~length text =
if length <= 0 || String.length text = 0 then []
else
let pieces = ref [] in
let len = String.length text in
let i = ref 0 in
while !i < len do
let start = !i in
let count = ref 0 in
while !i < len && !count < length do
let l = utf8_next text !i land 7 in
i := !i + l;
incr count
done;
pieces := (String.sub text start (!i - start), (start, !i)) :: !pieces
done;
List.rev !pieces
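Chunks are counted in codepoints rather than bytes, and the final chunk may be shorter than requested. A sketch, assuming the function above is in scope:

```ocaml
let _fixed_length_example =
  assert (
    pre_tokenize_fixed_length ~length:3 "abcdefgh"
    = [ ("abc", (0, 3)); ("def", (3, 6)); ("gh", (6, 8)) ])
```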
type script = [ `Any | Uucp.Script.t ]
let fixed_script code : script =
if code = 0x30FC then (`Hani :> script)
else if is_whitespace code then `Any
else
match Uucp.Script.script (Uchar.of_int code) with
| `Hira | `Kana -> (`Hani :> script)
| s -> (s :> script)
let pre_tokenize_unicode_scripts text =
let pieces = ref [] in
let start = ref (-1) in
let len = String.length text in
let i = ref 0 in
let last_script = ref None in
let flush () =
if !start >= 0 then begin
pieces := (String.sub text !start (!i - !start), (!start, !i)) :: !pieces;
start := -1
end
in
let emit (script : script) l =
if
script <> `Any && !last_script <> Some `Any && !last_script <> Some script
then flush ();
if !start < 0 then start := !i;
i := !i + l;
if script <> `Any then last_script := Some script
in
while !i < len do
let b = Char.code (String.unsafe_get text !i) in
if b < 128 then
let p = Array.unsafe_get ascii_props b in
let script : script =
if p land 1 <> 0 then `Any else if p land 2 <> 0 then `Latn else `Zyyy
in
emit script 1
else
let p = utf8_next text !i in
let code = p lsr 3 and l = p land 7 in
emit (fixed_script code) l
done;
flush ();
List.rev !pieces
(* Constructors *)
let whitespace () = Whitespace
let whitespace_split () = Whitespace_split
let bert () = Bert
let byte_level ?(add_prefix_space = true) ?(use_regex = true)
?(trim_offsets = true) () =
Byte_level { add_prefix_space; use_regex; trim_offsets }
let punctuation ?(behavior = `Isolated) () = Punctuation { behavior }
let split ~pattern ?(behavior = `Removed) ?(invert = false) () =
Split { pattern; behavior; invert }
let char_delimiter c = Char_delimiter c
let digits ?(individual_digits = false) () =
Digits { individual = individual_digits }
let metaspace ?(replacement = '_') ?(prepend_scheme = `Always) ?(split = true)
() =
Metaspace { replacement; prepend_scheme; split }
let unicode_scripts () = Unicode_scripts
let fixed_length n = Fixed_length { length = n }
let sequence ts = Sequence ts
(* Dispatch *)
let rec pre_tokenize t text =
match t with
| Whitespace -> pre_tokenize_whitespace text
| Whitespace_split -> pre_tokenize_whitespace_split text
| Bert -> pre_tokenize_bert text
| Byte_level { add_prefix_space; use_regex; trim_offsets } ->
pre_tokenize_byte_level ~add_prefix_space ~use_regex ~trim_offsets text
| Punctuation { behavior } -> pre_tokenize_punctuation ~behavior text
| Split { pattern; behavior; invert } ->
pre_tokenize_split ~pattern ~behavior ~invert text
| Char_delimiter c ->
pre_tokenize_split ~pattern:(String.make 1 c) ~behavior:`Removed
~invert:false text
| Digits { individual } -> pre_tokenize_digits ~individual text
| Metaspace { replacement; prepend_scheme; split } ->
pre_tokenize_metaspace ~replacement ~prepend_scheme ~split text
| Unicode_scripts -> pre_tokenize_unicode_scripts text
| Fixed_length { length } -> pre_tokenize_fixed_length ~length text
| Sequence ts -> pre_tokenize_sequence ts text
and pre_tokenize_sequence ts text =
let initial = [ (text, (0, String.length text)) ] in
List.fold_left
(fun pieces t ->
List.concat_map
(fun (s, (o_start, _)) ->
let sub_pieces = pre_tokenize t s in
List.map
(fun (p, (p_start, p_end)) ->
(p, (o_start + p_start, o_start + p_end)))
sub_pieces)
pieces)
initial ts
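Chaining composes the splits and rebases each sub-piece's offsets onto the original text. A sketch using the constructors above:

```ocaml
let _sequence_example =
  (* whitespace_split yields "hi," and "there"; punctuation then isolates
     the comma, with offsets rebased to the original string. *)
  assert (
    pre_tokenize (sequence [ whitespace_split (); punctuation () ]) "hi, there"
    = [ ("hi", (0, 2)); (",", (2, 3)); ("there", (4, 9)) ])
```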
(* Serialization *)
let json_obj pairs =
Jsont.Json.object' (List.map (fun (k, v) -> (Jsont.Json.name k, v)) pairs)
let behavior_to_string = function
| `Isolated -> "Isolated"
| `Removed -> "Removed"
| `Merged_with_previous -> "MergedWithPrevious"
| `Merged_with_next -> "MergedWithNext"
| `Contiguous -> "Contiguous"
let behavior_of_string = function
| "Isolated" -> Ok `Isolated
| "Removed" -> Ok `Removed
| "MergedWithPrevious" -> Ok `Merged_with_previous
| "MergedWithNext" -> Ok `Merged_with_next
| "Contiguous" -> Ok `Contiguous
| other -> Error (err_unknown_behavior other)
let scheme_to_string = function
| `First -> "First"
| `Never -> "Never"
| `Always -> "Always"
let scheme_of_string = function
| "First" -> Ok `First
| "Never" -> Ok `Never
| "Always" -> Ok `Always
| other -> Error (err_unknown_scheme other)
(* Formatting *)
let rec pp ppf = function
| Byte_level { add_prefix_space; use_regex; trim_offsets } ->
Format.fprintf ppf
"@[<1>ByteLevel(add_prefix_space=%b,@ use_regex=%b,@ trim_offsets=%b)@]"
add_prefix_space use_regex trim_offsets
| Bert -> Format.pp_print_string ppf "Bert"
| Whitespace -> Format.pp_print_string ppf "Whitespace"
| Whitespace_split -> Format.pp_print_string ppf "WhitespaceSplit"
| Punctuation { behavior } ->
Format.fprintf ppf "@[<1>Punctuation(%s)@]" (behavior_to_string behavior)
| Split { pattern; behavior; invert } ->
Format.fprintf ppf "@[<1>Split(%S,@ %s,@ invert=%b)@]" pattern
(behavior_to_string behavior)
invert
| Char_delimiter c -> Format.fprintf ppf "CharDelimiter(%C)" c
| Digits { individual } ->
Format.fprintf ppf "Digits(individual=%b)" individual
| Metaspace { replacement; prepend_scheme; split } ->
Format.fprintf ppf "@[<1>Metaspace(%C,@ %s,@ split=%b)@]" replacement
(scheme_to_string prepend_scheme)
split
| Sequence ts ->
Format.fprintf ppf "@[<1>Sequence[%a]@]"
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf ",@ ")
pp)
ts
| Fixed_length { length } -> Format.fprintf ppf "FixedLength(%d)" length
| Unicode_scripts -> Format.pp_print_string ppf "UnicodeScripts"
let rec to_json = function
| Byte_level { add_prefix_space; use_regex; trim_offsets } ->
json_obj
[
("type", Jsont.Json.string "ByteLevel");
("add_prefix_space", Jsont.Json.bool add_prefix_space);
("use_regex", Jsont.Json.bool use_regex);
("trim_offsets", Jsont.Json.bool trim_offsets);
]
| Bert -> json_obj [ ("type", Jsont.Json.string "BertPreTokenizer") ]
| Whitespace -> json_obj [ ("type", Jsont.Json.string "Whitespace") ]
| Whitespace_split ->
json_obj [ ("type", Jsont.Json.string "WhitespaceSplit") ]
| Punctuation { behavior } ->
json_obj
[
("type", Jsont.Json.string "Punctuation");
("behavior", Jsont.Json.string (behavior_to_string behavior));
]
| Split { pattern; behavior; invert } ->
json_obj
[
("type", Jsont.Json.string "Split");
("pattern", Jsont.Json.string pattern);
("behavior", Jsont.Json.string (behavior_to_string behavior));
("invert", Jsont.Json.bool invert);
]
| Char_delimiter delimiter ->
json_obj
[
("type", Jsont.Json.string "CharDelimiterSplit");
("delimiter", Jsont.Json.string (String.make 1 delimiter));
]
| Digits { individual } ->
json_obj
[
("type", Jsont.Json.string "Digits");
("individual_digits", Jsont.Json.bool individual);
]
| Metaspace { replacement; prepend_scheme; split } ->
json_obj
[
("type", Jsont.Json.string "Metaspace");
("replacement", Jsont.Json.string (String.make 1 replacement));
("prepend_scheme", Jsont.Json.string (scheme_to_string prepend_scheme));
("split", Jsont.Json.bool split);
]
| Sequence ts ->
json_obj
[
("type", Jsont.Json.string "Sequence");
("pretokenizers", Jsont.Json.list (List.map to_json ts));
]
| Fixed_length { length } ->
json_obj
[
("type", Jsont.Json.string "FixedLength");
("length", Jsont.Json.int length);
]
| Unicode_scripts -> json_obj [ ("type", Jsont.Json.string "UnicodeScripts") ]
let find_field name fields = Option.map snd (Jsont.Json.find_mem name fields)
let bool_field name default fields =
match find_field name fields with
| Some (Jsont.Bool (b, _)) -> b
| Some (Jsont.Number (f, _)) -> int_of_float f <> 0
| Some (Jsont.String (s, _)) -> (
match String.lowercase_ascii s with
| "true" | "1" -> true
| "false" | "0" -> false
| _ -> default)
| _ -> default
let int_field name default fields =
match find_field name fields with
| Some (Jsont.Number (f, _)) -> int_of_float f
| Some (Jsont.String (s, _)) -> (
match int_of_string_opt s with Some v -> v | None -> default)
| _ -> default
let char_of_field name = function
| Jsont.String (s, _) when String.length s = 1 -> Ok s.[0]
| _ -> Error (err_expected_char name)
let rec of_json = function
| Jsont.Object (fields, _) -> (
match find_field "type" fields with
| Some (Jsont.String ("ByteLevel", _)) ->
let add_prefix_space = bool_field "add_prefix_space" true fields in
let use_regex = bool_field "use_regex" true fields in
let trim_offsets = bool_field "trim_offsets" true fields in
Ok (Byte_level { add_prefix_space; use_regex; trim_offsets })
| Some (Jsont.String ("BertPreTokenizer", _)) -> Ok Bert
| Some (Jsont.String ("Whitespace", _)) -> Ok Whitespace
| Some (Jsont.String ("WhitespaceSplit", _)) -> Ok Whitespace_split
| Some (Jsont.String ("Punctuation", _)) -> (
match find_field "behavior" fields with
| Some (Jsont.String (s, _)) ->
Result.map
(fun b -> Punctuation { behavior = b })
(behavior_of_string s)
| _ -> Error err_missing_behavior)
| Some (Jsont.String ("Split", _)) -> (
match (find_field "pattern" fields, find_field "behavior" fields) with
| ( Some (Jsont.String (pattern, _)),
Some (Jsont.String (behavior_str, _)) ) ->
Result.map
(fun behavior ->
let invert = bool_field "invert" false fields in
Split { pattern; behavior; invert })
(behavior_of_string behavior_str)
| _ -> Error err_split_missing)
| Some (Jsont.String ("CharDelimiterSplit", _)) -> (
match find_field "delimiter" fields with
| Some v ->
Result.map
(fun c -> Char_delimiter c)
(char_of_field "delimiter" v)
| None -> Error err_char_delim_missing)
| Some (Jsont.String ("Digits", _)) ->
let individual = bool_field "individual_digits" false fields in
Ok (Digits { individual })
| Some (Jsont.String ("Metaspace", _)) -> (
match
(find_field "replacement" fields, find_field "prepend_scheme" fields)
with
| Some (Jsont.String (repl, _)), Some (Jsont.String (scheme, _))
when String.length repl = 1 ->
Result.map
(fun prepend_scheme ->
let split = bool_field "split" true fields in
Metaspace { replacement = repl.[0]; prepend_scheme; split })
(scheme_of_string scheme)
| _ -> Error err_metaspace_missing)
| Some (Jsont.String ("Sequence", _)) -> (
match find_field "pretokenizers" fields with
| Some (Jsont.Array (elements, _)) ->
let rec build acc = function
| [] -> Ok (Sequence (List.rev acc))
| item :: rest -> (
match of_json item with
| Ok t -> build (t :: acc) rest
| Error _ as e -> e)
in
build [] elements
| _ -> Error err_sequence_missing)
| Some (Jsont.String ("FixedLength", _)) ->
let length = int_field "length" 0 fields in
if length <= 0 then Error err_fixed_length
else Ok (Fixed_length { length })
| Some (Jsont.String ("UnicodeScripts", _)) -> Ok Unicode_scripts
| Some (Jsont.String (other, _)) -> Error (err_unsupported_type other)
| _ -> Error err_missing_type)
| _ -> Error err_expected_object
================================================
FILE: packages/brot/lib/pre_tokenizer.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Pre-tokenization.
Pre-tokenizers split raw text into pieces before vocabulary-based
tokenization (BPE, WordPiece, etc.) is applied. Each piece carries byte
offsets into the original text.
See {!Brot} for the full tokenization pipeline. *)
type t
(** The type for pre-tokenizers. *)
(** {1:constructors Constructors} *)
val whitespace : unit -> t
(** [whitespace ()] splits on whitespace using pattern [\w+|[^\w\s]+].
Groups word characters (letters, digits, underscore) together and groups
non-word, non-space characters together. Whitespace is used as delimiter but
not included in output. *)
val whitespace_split : unit -> t
(** [whitespace_split ()] splits on any whitespace characters.
Removes whitespace from output. Simplest and fastest pre-tokenizer. *)
val bert : unit -> t
(** [bert ()] applies BERT-style pre-tokenization.
Splits on whitespace, isolates punctuation, and separates CJK characters
individually. *)
val byte_level :
?add_prefix_space:bool -> ?use_regex:bool -> ?trim_offsets:bool -> unit -> t
(** [byte_level ()] is a byte-level pre-tokenizer. Used by GPT-2, GPT-3,
RoBERTa.
Converts text to byte representation and applies GPT-2's regex pattern for
splitting.
- [add_prefix_space]: add space at beginning if text does not start with
whitespace. Default: [true].
- [use_regex]: use GPT-2's regex pattern for splitting. Default: [true].
- [trim_offsets]: adjust offsets for byte-level encoding. Default: [true].
*)
type behavior =
[ `Isolated (** Keep delimiter as separate piece *)
| `Removed (** Remove delimiter *)
| `Merged_with_previous (** Merge delimiter with previous piece *)
| `Merged_with_next (** Merge delimiter with next piece *)
| `Contiguous (** Group consecutive delimiters together *) ]
(** Delimiter handling behavior for splitting operations. *)
val punctuation : ?behavior:behavior -> unit -> t
(** [punctuation ()] separates punctuation from alphanumeric content.
[behavior] defaults to [`Isolated]. *)
val split : pattern:string -> ?behavior:behavior -> ?invert:bool -> unit -> t
(** [split ~pattern ()] splits on a literal string [pattern].
[behavior] defaults to [`Removed]. When [invert] is [true], splits on
everything {e except} the pattern; defaults to [false]. *)
val char_delimiter : char -> t
(** [char_delimiter c] splits on character [c], removing it from output.
Equivalent to [split ~pattern:(String.make 1 c) ~behavior:`Removed ()]. *)
val digits : ?individual_digits:bool -> unit -> t
(** [digits ()] splits on digit boundaries.
When [individual_digits] is [true], each digit is a separate piece; when
[false] (default), consecutive digits are grouped. *)
type prepend_scheme =
[ `First (** Only prepend to first piece *)
| `Never (** Never prepend *)
| `Always (** Always prepend if not starting with space *) ]
(** Controls when metaspace prepends the replacement character. *)
val metaspace :
?replacement:char ->
?prepend_scheme:prepend_scheme ->
?split:bool ->
unit ->
t
(** [metaspace ()] replaces whitespace with a visible marker. Used by
SentencePiece models.
- [replacement]: character to replace spaces with. Default: ['_'].
- [prepend_scheme]: when to prepend the replacement character. Default:
[`Always].
- [split]: whether to split on the replacement character. Default: [true].
*)
val unicode_scripts : unit -> t
(** [unicode_scripts ()] splits on Unicode script boundaries.
Separates text when the writing system changes (e.g., Latin to Cyrillic,
Latin to Han). *)
val fixed_length : int -> t
(** [fixed_length n] splits into fixed-length character chunks.
The last chunk may be shorter than [n]. *)
val sequence : t list -> t
(** [sequence ts] chains multiple pre-tokenizers left-to-right.
Each pre-tokenizer processes the pieces from the previous one. Offsets are
composed correctly through the chain. *)
(** {1 Operations} *)
val pre_tokenize : t -> string -> (string * (int * int)) list
(** [pre_tokenize t text] splits [text] into pieces with byte offsets.
Returns a list of [(piece, (start, end_))] where [start] and [end_] are byte
positions in the original [text]. Offsets are non-overlapping and in
ascending order. *)
(** {1 Formatting} *)
val pp : Format.formatter -> t -> unit
(** [pp ppf t] formats [t] for inspection. *)
(** {1:byte_level_decode Byte-level decoding} *)
val byte_level_decode : string -> string
(** [byte_level_decode s] reverses byte-level encoding by converting the special
Unicode codepoints back to original byte values. *)
(** {1 Serialization} *)
val to_json : t -> Jsont.json
(** [to_json t] serializes [t] to HuggingFace JSON format. *)
val of_json : Jsont.json -> (t, string) result
(** [of_json json] is a pre-tokenizer from HuggingFace JSON format. Errors if
[json] is not an object, has a missing or unknown ["type"] field, or has
invalid parameters. *)
================================================
FILE: packages/brot/lib/unigram.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Compact trie for longest-prefix matching *)
type trie = {
trie_ids : int array;
child_starts : int array;
edge_bytes : bytes;
edge_targets : int array;
}
let build_trie token_to_ids =
if Hashtbl.length token_to_ids = 0 then
{
trie_ids = [||];
child_starts = [| 0 |];
edge_bytes = Bytes.empty;
edge_targets = [||];
}
else
let cap = ref 256 in
let ids = ref (Array.make !cap (-1)) in
let ch = ref (Array.init !cap (fun _ -> Hashtbl.create 0)) in
let n = ref 1 in
!ch.(0) <- Hashtbl.create 64;
let grow () =
let new_cap = !cap * 2 in
let new_ids = Array.make new_cap (-1) in
Array.blit !ids 0 new_ids 0 !n;
ids := new_ids;
let new_ch =
Array.init new_cap (fun i ->
if i < !n then !ch.(i) else Hashtbl.create 0)
in
ch := new_ch;
cap := new_cap
in
Hashtbl.iter
(fun key id ->
let cur = ref 0 in
for i = 0 to String.length key - 1 do
let byte = Char.code (String.unsafe_get key i) in
let child =
match Hashtbl.find_opt !ch.(!cur) byte with
| Some c -> c
| None ->
if !n >= !cap then grow ();
let c = !n in
incr n;
!ch.(c) <- Hashtbl.create 4;
Hashtbl.add !ch.(!cur) byte c;
c
in
cur := child
done;
!ids.(!cur) <- id)
token_to_ids;
let node_count = !n in
let trie_ids = Array.init node_count (fun i -> !ids.(i)) in
let child_starts = Array.make (node_count + 1) 0 in
let total = ref 0 in
for i = 0 to node_count - 1 do
child_starts.(i) <- !total;
total := !total + Hashtbl.length !ch.(i)
done;
child_starts.(node_count) <- !total;
let edge_bytes = Bytes.create !total in
let edge_targets = Array.make !total 0 in
let pos = ref 0 in
for i = 0 to node_count - 1 do
Hashtbl.iter
(fun byte child ->
Bytes.unsafe_set edge_bytes !pos (Char.unsafe_chr byte);
edge_targets.(!pos) <- child;
incr pos)
!ch.(i)
done;
(* Sort each node's edge list by byte value (insertion sort) so that
trie_step can binary-search it. *)
for i = 0 to node_count - 1 do
let start = child_starts.(i) in
let stop = child_starts.(i + 1) in
for j = start + 1 to stop - 1 do
let kb = Bytes.unsafe_get edge_bytes j in
let kt = edge_targets.(j) in
let k = ref (j - 1) in
while !k >= start && Bytes.unsafe_get edge_bytes !k > kb do
Bytes.unsafe_set edge_bytes (!k + 1) (Bytes.unsafe_get edge_bytes !k);
edge_targets.(!k + 1) <- edge_targets.(!k);
decr k
done;
Bytes.unsafe_set edge_bytes (!k + 1) kb;
edge_targets.(!k + 1) <- kt
done
done;
{ trie_ids; child_starts; edge_bytes; edge_targets }
let[@inline] trie_step trie node byte =
let lo = ref (Array.unsafe_get trie.child_starts node) in
let hi = ref (Array.unsafe_get trie.child_starts (node + 1) - 1) in
let result = ref (-1) in
while !lo <= !hi do
let mid = !lo + ((!hi - !lo) asr 1) in
let mid_byte = Char.code (Bytes.unsafe_get trie.edge_bytes mid) in
if mid_byte = byte then (
result := Array.unsafe_get trie.edge_targets mid;
lo := !hi + 1)
else if mid_byte < byte then lo := mid + 1
else hi := mid - 1
done;
!result
let trie_longest_match trie text ~start =
if Array.length trie.trie_ids = 0 then None
else
let text_len = String.length text in
let last_id = ref (-1) in
let last_end = ref start in
let current = ref 0 in
let stopped = ref false in
let j = ref start in
while !j < text_len && not !stopped do
let child =
trie_step trie !current (Char.code (String.unsafe_get text !j))
in
if child < 0 then stopped := true
else (
current := child;
incr j;
let tid = Array.unsafe_get trie.trie_ids child in
if tid >= 0 then (
last_id := tid;
last_end := !j))
done;
if !last_id >= 0 then Some (!last_id, !last_end) else None
(* Model type *)
type t = {
vocab : (string * float) array;
token_to_ids : (string, int) Hashtbl.t;
trie : trie;
}
let create vocab_list =
let vocab = Array.of_list vocab_list in
let token_to_ids = Hashtbl.create (Array.length vocab) in
Array.iteri
(fun idx (token, _) -> Hashtbl.replace token_to_ids token idx)
vocab;
let trie = build_trie token_to_ids in
{ vocab; token_to_ids; trie }
let token_to_id model token = Hashtbl.find_opt model.token_to_ids token
let id_to_token model id =
if id >= 0 && id < Array.length model.vocab then
let token, _ = model.vocab.(id) in
Some token
else None
let get_vocab model = Array.to_list model.vocab
let get_vocab_size model = Array.length model.vocab
let tokenize model text =
let len = String.length text in
let rec consume pos acc =
if pos >= len then List.rev acc
else if
text.[pos] = ' '
|| text.[pos] = '\n'
|| text.[pos] = '\t'
|| text.[pos] = '\r'
then consume (pos + 1) acc
else
match trie_longest_match model.trie text ~start:pos with
| Some (id, end_pos) ->
let s = String.sub text pos (end_pos - pos) in
consume end_pos ((id, s, (pos, end_pos)) :: acc)
| None ->
let s = String.sub text pos 1 in
let id = match token_to_id model s with Some id -> id | None -> 0 in
consume (pos + 1) ((id, s, (pos, pos + 1)) :: acc)
in
consume 0 []
let json_obj pairs =
Jsont.Json.object' (List.map (fun (k, v) -> (Jsont.Json.name k, v)) pairs)
let json_to_string j =
match Jsont_bytesrw.encode_string ~format:Jsont.Minify Jsont.json j with
| Ok s -> s
| Error e -> failwith e
let save model ~folder () =
let json_vocab =
Array.to_list model.vocab
|> List.mapi (fun id (token, prob) ->
json_obj
[
("id", Jsont.Json.int id);
("token", Jsont.Json.string token);
("prob", Jsont.Json.number prob);
])
in
let json =
json_obj
[
("type", Jsont.Json.string "Unigram");
("vocab", Jsont.Json.list json_vocab);
]
in
let path = Filename.concat folder "unigram.json" in
let oc = open_out path in
Fun.protect
~finally:(fun () -> close_out oc)
(fun () -> output_string oc (json_to_string json));
[ "unigram.json" ]
let train ~vocab_size ~show_progress ~special_tokens ~shrinking_factor
~unk_token ~max_piece_length ~n_sub_iterations texts existing =
let _ =
( show_progress,
shrinking_factor,
unk_token,
max_piece_length,
n_sub_iterations,
existing )
in
let counts = Hashtbl.create 10000 in
List.iter
(fun line ->
let words = Re.split (Re.compile (Re.rep1 (Re.set " \t\n\r"))) line in
List.iter
(fun word ->
if word <> "" then
Hashtbl.replace counts word
(1 + Option.value ~default:0 (Hashtbl.find_opt counts word)))
words)
texts;
let total =
Hashtbl.fold (fun _ count acc -> acc + count) counts 0 |> float_of_int
in
let sorted =
Hashtbl.fold (fun token count acc -> (token, count) :: acc) counts []
|> List.sort (fun (_, c1) (_, c2) -> compare c2 c1)
in
let take_first n lst =
let rec aux i = function
| [] -> []
| _ when i = 0 -> []
| x :: xs -> x :: aux (i - 1) xs
in
aux n lst
in
let selected = take_first vocab_size sorted in
let vocab_with_probs =
special_tokens
|> List.map (fun token -> (token, 1.0 /. float_of_int (vocab_size + 1)))
|> fun specials ->
specials
@ List.map
(fun (token, count) ->
let prob = if total = 0. then 0. else float_of_int count /. total in
(token, prob))
selected
in
let model = create vocab_with_probs in
(model, special_tokens)
================================================
FILE: packages/brot/lib/unigram.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Unigram language model tokenization.
{b Internal module.} Probabilistic subword tokenization using token
log-probabilities. Used by SentencePiece, ALBERT, T5, and mBART.
Tokenization uses greedy longest-prefix matching via a compact trie with
sorted edges and binary-search dispatch. At each byte position the longest
vocabulary match is consumed. Unknown single characters default to ID [0].
*)
type t
(** The type for unigram models. *)
(** {1:creation Creation} *)
val create : (string * float) list -> t
(** [create vocab] is a unigram model from [(token, log_probability)] pairs. The
trie is built at creation time. *)
(** {1:tokenization Tokenization} *)
val tokenize : t -> string -> (int * string * (int * int)) list
(** [tokenize t s] is the tokenization of [s] as [(id, token, (start, stop))]
triples. Offsets are byte positions in [s]. *)
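(** For illustration, a hypothetical sketch with an invented three-token
    vocabulary. Greedy matching consumes the longest prefix ["hello"] rather
    than stopping at ["hell"], and the space between words is skipped:
    {[
      let m = create [ ("hello", -1.0); ("hell", -0.5); ("o", -2.0) ] in
      tokenize m "hello o"
      (* = [ (0, "hello", (0, 5)); (2, "o", (6, 7)) ] *)
    ]} *)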
(** {1:vocabulary Vocabulary} *)
val token_to_id : t -> string -> int option
(** [token_to_id t tok] is the ID of [tok] in the vocabulary. *)
val id_to_token : t -> int -> string option
(** [id_to_token t id] is the token string for [id]. *)
val get_vocab : t -> (string * float) list
(** [get_vocab t] is the vocabulary as [(token, score)] pairs. *)
val get_vocab_size : t -> int
(** [get_vocab_size t] is the number of tokens in the vocabulary. *)
(** {1:serialization Serialization} *)
val save : t -> folder:string -> unit -> string list
(** [save t ~folder ()] writes [unigram.json] to [folder]. The file contains
each token with its ID and log-probability in JSON format. Returns the list
of created filenames. *)
(** {1:training Training} *)
val train :
vocab_size:int ->
show_progress:bool ->
special_tokens:string list ->
shrinking_factor:float ->
unk_token:string option ->
max_piece_length:int ->
n_sub_iterations:int ->
string list ->
t option ->
t * string list
(** [train ~vocab_size ~show_progress ~special_tokens ~shrinking_factor
~unk_token ~max_piece_length ~n_sub_iterations texts init] learns a unigram
model from [texts].
- [vocab_size] is the target vocabulary size.
- [show_progress] enables progress output on [stderr].
- [special_tokens] are added to the vocabulary first.
- [shrinking_factor] controls vocabulary pruning rate.
- [unk_token] is the unknown token, if any.
- [max_piece_length] limits the byte length of vocabulary pieces.
- [n_sub_iterations] is the number of EM sub-iterations.
- [init], when provided, seeds the vocabulary from an existing model.
Returns [(model, special_tokens)]. *)
================================================
FILE: packages/brot/lib/word_level.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
type t = {
vocab : (string, int) Hashtbl.t;
vocab_r : (int, string) Hashtbl.t;
unk_token : string;
}
let create ?(vocab = []) ?(unk_token = "[UNK]") () =
let size = max 1 (List.length vocab) in
let vocab_tbl = Hashtbl.create size in
let vocab_r_tbl = Hashtbl.create size in
List.iter
(fun (token, id) ->
Hashtbl.replace vocab_tbl token id;
Hashtbl.replace vocab_r_tbl id token)
vocab;
{ vocab = vocab_tbl; vocab_r = vocab_r_tbl; unk_token }
let add_token vocab vocab_r token id =
Hashtbl.replace vocab token id;
Hashtbl.replace vocab_r id token
let tokenize model text =
if String.length text = 0 then []
else
(* Match HuggingFace tokenizers semantics exactly: 1. Try to find token in
vocab 2. Fall back to UNK token if available 3. Return empty list if
neither exists (error case) *)
match Hashtbl.find_opt model.vocab text with
| Some id -> [ (id, text, (0, String.length text)) ]
| None -> (
match Hashtbl.find_opt model.vocab model.unk_token with
| Some unk_id -> [ (unk_id, model.unk_token, (0, String.length text)) ]
| None -> [] (* Token not found and no UNK token - return empty *))
let tokenize_ids model text =
if String.length text = 0 then [||]
else
match Hashtbl.find_opt model.vocab text with
| Some id -> [| id |]
| None -> (
match Hashtbl.find_opt model.vocab model.unk_token with
| Some unk_id -> [| unk_id |]
| None -> [||])
let token_to_id model token = Hashtbl.find_opt model.vocab token
let id_to_token model id = Hashtbl.find_opt model.vocab_r id
let get_vocab model =
Hashtbl.fold (fun token id acc -> (token, id) :: acc) model.vocab []
let get_vocab_size model = Hashtbl.length model.vocab
let add_tokens model tokens =
  let start_id = Hashtbl.length model.vocab in
  let count = ref 0 in
  List.iter
    (fun token ->
      if not (Hashtbl.mem model.vocab token) then (
        (* Use [count], not the list index, so skipped duplicates leave no
           gaps in the assigned IDs *)
        add_token model.vocab model.vocab_r token (start_id + !count);
        incr count))
    tokens;
  !count
let json_obj pairs =
Jsont.Json.object' (List.map (fun (k, v) -> (Jsont.Json.name k, v)) pairs)
let json_to_string j =
match Jsont_bytesrw.encode_string ~format:Jsont.Minify Jsont.json j with
| Ok s -> s
| Error e -> failwith e
let save model ~folder () =
let vocab_items =
get_vocab model
|> List.sort (fun (_, id1) (_, id2) -> compare id1 id2)
|> List.map (fun (token, id) ->
json_obj
[ ("token", Jsont.Json.string token); ("id", Jsont.Json.int id) ])
in
let json =
json_obj
[
("type", Jsont.Json.string "WordLevel");
("unk_token", Jsont.Json.string model.unk_token);
("vocab", Jsont.Json.list vocab_items);
]
in
let path = Filename.concat folder "wordlevel.json" in
let oc = open_out path in
Fun.protect
~finally:(fun () -> close_out oc)
(fun () -> output_string oc (json_to_string json));
[ "wordlevel.json" ]
let train ~vocab_size ~min_frequency ~show_progress ~special_tokens texts
existing =
let _ = show_progress in
let counts = Hashtbl.create 10000 in
List.iter
(fun line ->
let words = Re.split (Re.compile (Re.rep1 (Re.set " \t\n\r"))) line in
List.iter
(fun word ->
if word <> "" then
Hashtbl.replace counts word
(1 + Option.value ~default:0 (Hashtbl.find_opt counts word)))
words)
texts;
let items =
Hashtbl.fold
(fun word count acc ->
if count >= min_frequency then (word, count) :: acc else acc)
counts []
|> List.sort (fun (_, c1) (_, c2) -> compare c2 c1)
in
  let n_special = List.length special_tokens in
  (* Special tokens occupy IDs [0 .. n_special - 1]; learned words start after
     them so the two ID ranges cannot collide. *)
  let specials = List.mapi (fun i token -> (token, i)) special_tokens in
  let vocab_items = ref [] in
  let idx = ref n_special in
  List.iter
    (fun (token, _) ->
      if !idx < n_special + vocab_size then (
        vocab_items := (token, !idx) :: !vocab_items;
        incr idx))
    items;
  let vocab = specials @ List.rev !vocab_items in
let model =
match existing with
| Some model ->
model.vocab |> Hashtbl.clear;
model.vocab_r |> Hashtbl.clear;
List.iter
(fun (token, id) -> add_token model.vocab model.vocab_r token id)
vocab;
model
| None -> create ~vocab ()
in
(model, special_tokens)
================================================
FILE: packages/brot/lib/word_level.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Word-level tokenization model.
{b Internal module.} Direct vocabulary lookup with no subword splitting.
Each input word is mapped to a single token ID via exact string match. Words
not in the vocabulary are replaced by [unk_token]. *)
type t
(** The type for word-level models. *)
(** {1:creation Creation} *)
val create : ?vocab:(string * int) list -> ?unk_token:string -> unit -> t
(** [create ?vocab ?unk_token ()] is a word-level model.
- [vocab] is the initial vocabulary as [(token, id)] pairs. Defaults to
[[]].
- [unk_token] is the token emitted for unknown words. Defaults to ["[UNK]"].
*)
(** {1:tokenization Tokenization} *)
val tokenize : t -> string -> (int * string * (int * int)) list
(** [tokenize t s] is [[(id, token, (start, stop))]] for [s]. If [s] is not in
the vocabulary, [unk_token] is used. If [unk_token] itself is not in the
vocabulary, the empty list is returned. *)
val tokenize_ids : t -> string -> int array
(** [tokenize_ids t s] is like {!tokenize} but returns only token IDs. *)
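(** For illustration, a hypothetical sketch of {!tokenize} with an invented
    two-entry vocabulary, showing both the exact-match path and the
    [unk_token] fallback:
    {[
      let m =
        create ~vocab:[ ("hello", 0); ("[UNK]", 1) ] ~unk_token:"[UNK]" ()

      let _ = tokenize m "hello" (* = [ (0, "hello", (0, 5)) ] *)
      let _ = tokenize m "bye" (* = [ (1, "[UNK]", (0, 3)) ] *)
    ]} *)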
(** {1:vocabulary Vocabulary} *)
val token_to_id : t -> string -> int option
(** [token_to_id t tok] is the ID of [tok] in the vocabulary. *)
val id_to_token : t -> int -> string option
(** [id_to_token t id] is the token string for [id]. *)
val get_vocab : t -> (string * int) list
(** [get_vocab t] is the vocabulary as [(token, id)] pairs. *)
val get_vocab_size : t -> int
(** [get_vocab_size t] is the number of tokens in the vocabulary. *)
val add_tokens : t -> string list -> int
(** [add_tokens t toks] adds [toks] to the vocabulary, assigning consecutive IDs
starting after the current maximum. Returns the number of new tokens
actually added (duplicates are skipped). Mutates [t]. *)
(** {1:serialization Serialization} *)
val save : t -> folder:string -> unit -> string list
(** [save t ~folder ()] writes [wordlevel.json] to [folder]. The file contains
the vocabulary and [unk_token] in JSON format. Returns the list of created
filenames. *)
(** {1:training Training} *)
val train :
vocab_size:int ->
min_frequency:int ->
show_progress:bool ->
special_tokens:string list ->
string list ->
t option ->
t * string list
(** [train ~vocab_size ~min_frequency ~show_progress ~special_tokens texts init]
learns a vocabulary from [texts] by counting word frequencies.
- [vocab_size] is the target vocabulary size.
- [min_frequency] is the minimum word frequency to include.
- [show_progress] enables progress output on [stderr].
- [special_tokens] are added to the vocabulary first.
- [init], when provided, seeds the vocabulary from an existing model.
Returns [(model, special_tokens)]. *)
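(** For illustration, a hypothetical training call on two invented lines of
    text. Words are whitespace-split and counted; only words occurring at
    least [min_frequency] times enter the vocabulary:
    {[
      let model, _specials =
        train ~vocab_size:10 ~min_frequency:2 ~show_progress:false
          ~special_tokens:[ "[UNK]" ]
          [ "the cat sat"; "the dog sat" ]
          None

      let _ = get_vocab_size model
    ]} *)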
================================================
FILE: packages/brot/lib/wordpiece.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
type token = { id : int; value : string; offsets : int * int }
(* Compact trie for zero-allocation longest-prefix matching *)
type trie = {
trie_ids : int array;
child_starts : int array;
edge_bytes : bytes;
edge_targets : int array;
(* Flat 256-element arrays for dense nodes (>16 children) — O(1) lookup *)
flat_nodes : int array array;
}
let build_trie vocab =
if Hashtbl.length vocab = 0 then
{
trie_ids = [||];
child_starts = [| 0 |];
edge_bytes = Bytes.empty;
edge_targets = [||];
flat_nodes = [||];
}
else
let cap = ref 256 in
let ids = ref (Array.make !cap (-1)) in
let ch = ref (Array.init !cap (fun _ -> Hashtbl.create 0)) in
let n = ref 1 in
!ch.(0) <- Hashtbl.create 64;
let grow () =
let new_cap = !cap * 2 in
let new_ids = Array.make new_cap (-1) in
Array.blit !ids 0 new_ids 0 !n;
ids := new_ids;
let new_ch =
Array.init new_cap (fun i ->
if i < !n then !ch.(i) else Hashtbl.create 0)
in
ch := new_ch;
cap := new_cap
in
Hashtbl.iter
(fun key id ->
let cur = ref 0 in
for i = 0 to String.length key - 1 do
let byte = Char.code (String.unsafe_get key i) in
let child =
match Hashtbl.find_opt !ch.(!cur) byte with
| Some c -> c
| None ->
if !n >= !cap then grow ();
let c = !n in
incr n;
!ch.(c) <- Hashtbl.create 4;
Hashtbl.add !ch.(!cur) byte c;
c
in
cur := child
done;
!ids.(!cur) <- id)
vocab;
let node_count = !n in
let trie_ids = Array.init node_count (fun i -> !ids.(i)) in
let child_starts = Array.make (node_count + 1) 0 in
let total = ref 0 in
for i = 0 to node_count - 1 do
child_starts.(i) <- !total;
total := !total + Hashtbl.length !ch.(i)
done;
child_starts.(node_count) <- !total;
let edge_bytes = Bytes.create !total in
let edge_targets = Array.make !total 0 in
let pos = ref 0 in
for i = 0 to node_count - 1 do
Hashtbl.iter
(fun byte child ->
Bytes.unsafe_set edge_bytes !pos (Char.unsafe_chr byte);
edge_targets.(!pos) <- child;
incr pos)
!ch.(i)
done;
(* Sort each node's children by byte value for binary search *)
for i = 0 to node_count - 1 do
let start = child_starts.(i) in
let stop = child_starts.(i + 1) in
for j = start + 1 to stop - 1 do
let kb = Bytes.unsafe_get edge_bytes j in
let kt = edge_targets.(j) in
let k = ref (j - 1) in
while !k >= start && Bytes.unsafe_get edge_bytes !k > kb do
Bytes.unsafe_set edge_bytes (!k + 1) (Bytes.unsafe_get edge_bytes !k);
edge_targets.(!k + 1) <- edge_targets.(!k);
decr k
done;
Bytes.unsafe_set edge_bytes (!k + 1) kb;
edge_targets.(!k + 1) <- kt
done
done;
(* Build flat 256-element arrays for dense nodes (>16 children) *)
let flat_nodes = Array.make node_count [||] in
for i = 0 to node_count - 1 do
let start = child_starts.(i) in
let count = child_starts.(i + 1) - start in
if count > 16 then begin
let flat = Array.make 256 (-1) in
for j = start to start + count - 1 do
let b = Char.code (Bytes.unsafe_get edge_bytes j) in
flat.(b) <- Array.unsafe_get edge_targets j
done;
flat_nodes.(i) <- flat
end
done;
{ trie_ids; child_starts; edge_bytes; edge_targets; flat_nodes }
let[@inline] trie_step trie node byte =
let flat = Array.unsafe_get trie.flat_nodes node in
if Array.length flat > 0 then Array.unsafe_get flat byte
else
let lo = ref (Array.unsafe_get trie.child_starts node) in
let hi = ref (Array.unsafe_get trie.child_starts (node + 1) - 1) in
let result = ref (-1) in
while !lo <= !hi do
let mid = !lo + ((!hi - !lo) asr 1) in
let mid_byte = Char.code (Bytes.unsafe_get trie.edge_bytes mid) in
if mid_byte = byte then (
result := Array.unsafe_get trie.edge_targets mid;
lo := !hi + 1)
else if mid_byte < byte then lo := mid + 1
else hi := mid - 1
done;
!result
let trie_longest_match trie sequence ~start ~prefix ~prefix_len =
if Array.length trie.trie_ids = 0 then None
else
let seq_len = String.length sequence in
let last_id = ref (-1) in
let last_end = ref start in
let current = ref 0 in
let stopped = ref false in
let i = ref 0 in
while !i < prefix_len && not !stopped do
let child =
trie_step trie !current (Char.code (String.unsafe_get prefix !i))
in
if child < 0 then stopped := true
else (
current := child;
incr i)
done;
(if not !stopped then
let j = ref start in
while !j < seq_len && not !stopped do
let child =
trie_step trie !current (Char.code (String.unsafe_get sequence !j))
in
if child < 0 then stopped := true
else (
current := child;
incr j;
let tid = Array.unsafe_get trie.trie_ids child in
if tid >= 0 then (
last_id := tid;
last_end := !j))
done);
if !last_id >= 0 then Some (!last_id, !last_end) else None
(* Model type *)
type t = {
vocab : (string, int) Hashtbl.t;
vocab_r : string array;
trie : trie;
unk_token : string;
continuing_subword_prefix : string;
max_input_chars_per_word : int;
}
let create ~vocab ?(unk_token = "[UNK]") ?(continuing_subword_prefix = "##")
?(max_input_chars_per_word = 100) () =
let max_id = Hashtbl.fold (fun _ id acc -> max id acc) vocab (-1) in
let vocab_r = Array.make (max_id + 1) "" in
Hashtbl.iter (fun k v -> Array.unsafe_set vocab_r v k) vocab;
if Hashtbl.length vocab > 0 && not (Hashtbl.mem vocab unk_token) then
invalid_arg "Wordpiece.create: unk_token not in vocab";
let trie = build_trie vocab in
{
vocab;
vocab_r;
trie;
unk_token;
continuing_subword_prefix;
max_input_chars_per_word;
}
let read_file ~vocab_file =
let vocab = Hashtbl.create 10000 in
let ic = open_in vocab_file in
Fun.protect
~finally:(fun () -> close_in ic)
(fun () ->
let index = ref 0 in
(try
while true do
let line = input_line ic in
let token = String.trim line in
if token <> "" then (
Hashtbl.add vocab token !index;
incr index)
done
with End_of_file -> ());
vocab)
let from_file ~vocab_file =
let vocab = read_file ~vocab_file in
create ~vocab ()
let count_chars s =
let len = String.length s in
let n = ref 0 in
for i = 0 to len - 1 do
if Char.code (String.unsafe_get s i) land 0xC0 <> 0x80 then incr n
done;
!n
let tokenize model sequence =
if Hashtbl.length model.vocab = 0 then []
else
let seq_len = String.length sequence in
if count_chars sequence > model.max_input_chars_per_word then
let id = Hashtbl.find model.vocab model.unk_token in
[ { id; value = model.unk_token; offsets = (0, seq_len) } ]
else
let prefix = model.continuing_subword_prefix in
let prefix_len = String.length prefix in
let rec greedy start acc =
if start >= seq_len then List.rev acc
else
let p = if start > 0 then prefix else "" in
let pl = if start > 0 then prefix_len else 0 in
match
trie_longest_match model.trie sequence ~start ~prefix:p
~prefix_len:pl
with
| Some (id, end_byte) ->
let value = Array.unsafe_get model.vocab_r id in
greedy end_byte ({ id; value; offsets = (start, end_byte) } :: acc)
| None ->
let id = Hashtbl.find model.vocab model.unk_token in
[ { id; value = model.unk_token; offsets = (0, seq_len) } ]
in
greedy 0 []
let tokenize_ids model sequence =
if Hashtbl.length model.vocab = 0 then [||]
else
let seq_len = String.length sequence in
if count_chars sequence > model.max_input_chars_per_word then
let id = Hashtbl.find model.vocab model.unk_token in
[| id |]
else
let prefix = model.continuing_subword_prefix in
let prefix_len = String.length prefix in
let ids = ref [] in
let n = ref 0 in
let rec greedy start =
if start >= seq_len then ()
else
let p = if start > 0 then prefix else "" in
let pl = if start > 0 then prefix_len else 0 in
match
trie_longest_match model.trie sequence ~start ~prefix:p
~prefix_len:pl
with
| Some (id, end_byte) ->
ids := id :: !ids;
incr n;
greedy end_byte
| None ->
let unk_id = Hashtbl.find model.vocab model.unk_token in
ids := [ unk_id ];
n := 1
in
greedy 0;
let result = Array.make !n 0 in
List.iteri (fun i id -> result.(!n - 1 - i) <- id) !ids;
result
let tokenize_spans_encoding model pre_tokens ~type_id =
if Hashtbl.length model.vocab = 0 then Encoding.empty
else
let trie = model.trie in
let prefix = model.continuing_subword_prefix in
let prefix_len = String.length prefix in
let unk_id = Hashtbl.find model.vocab model.unk_token in
let max_chars = model.max_input_chars_per_word in
let vocab_r = model.vocab_r in
let unk_token_str = model.unk_token in
(* Single pass: convert pre_tokens to array for direct access (no closure),
tokenize all fragments and fill growable output arrays directly. *)
let pre_arr = Array.of_list pre_tokens in
let n_pre = Array.length pre_arr in
let cap = ref (max 16 (n_pre * 2)) in
let ids = ref (Array.make !cap 0) in
let token_strs = ref (Array.make !cap "") in
let offsets_arr = ref (Array.make !cap (0, 0)) in
let n = ref 0 in
let grow () =
let new_cap = !cap * 2 in
let new_ids = Array.make new_cap 0 in
Array.blit !ids 0 new_ids 0 !n;
ids := new_ids;
let new_strs = Array.make new_cap "" in
Array.blit !token_strs 0 new_strs 0 !n;
token_strs := new_strs;
let new_off = Array.make new_cap (0, 0) in
Array.blit !offsets_arr 0 new_off 0 !n;
offsets_arr := new_off;
cap := new_cap
in
(* Hoisted mutable state for trie matching — allocated once *)
let current = ref 0 in
let stopped = ref false in
let last_id = ref (-1) in
let last_end = ref 0 in
let pos = ref 0 in
let is_unk = ref false in
let char_count = ref 0 in
let i_ref = ref 0 in
let j_ref = ref 0 in
for frag_idx = 0 to n_pre - 1 do
let fragment, _ = Array.unsafe_get pre_arr frag_idx in
let seq_len = String.length fragment in
char_count := 0;
for k = 0 to seq_len - 1 do
if Char.code (String.unsafe_get fragment k) land 0xC0 <> 0x80 then
incr char_count
done;
if !char_count > max_chars then begin
if !n >= !cap then grow ();
Array.unsafe_set !ids !n unk_id;
Array.unsafe_set !token_strs !n unk_token_str;
Array.unsafe_set !offsets_arr !n (0, seq_len);
incr n
end
else begin
pos := 0;
is_unk := false;
let start_n = !n in
while !pos < seq_len && not !is_unk do
let match_start = !pos in
current := 0;
stopped := false;
last_id := -1;
last_end := !pos;
if !pos > 0 then begin
i_ref := 0;
while !i_ref < prefix_len && not !stopped do
let child =
trie_step trie !current
(Char.code (String.unsafe_get prefix !i_ref))
in
if child < 0 then stopped := true
else begin
current := child;
incr i_ref
end
done
end;
if not !stopped then begin
j_ref := !pos;
while !j_ref < seq_len && not !stopped do
let child =
trie_step trie !current
(Char.code (String.unsafe_get fragment !j_ref))
in
if child < 0 then stopped := true
else begin
current := child;
incr j_ref;
let tid = Array.unsafe_get trie.trie_ids child in
if tid >= 0 then begin
last_id := tid;
last_end := !j_ref
end
end
done
end;
if !last_id >= 0 then begin
if !n >= !cap then grow ();
Array.unsafe_set !ids !n !last_id;
Array.unsafe_set !token_strs !n (Array.unsafe_get vocab_r !last_id);
Array.unsafe_set !offsets_arr !n (match_start, !last_end);
incr n;
pos := !last_end
end
else is_unk := true
done;
if !is_unk then begin
n := start_n;
if !n >= !cap then grow ();
Array.unsafe_set !ids !n unk_id;
Array.unsafe_set !token_strs !n unk_token_str;
Array.unsafe_set !offsets_arr !n (0, seq_len);
n := start_n + 1
end
end
done;
let total = !n in
if total = 0 then Encoding.empty
else
let final_ids = if total = !cap then !ids else Array.sub !ids 0 total in
let final_strs =
if total = !cap then !token_strs else Array.sub !token_strs 0 total
in
let final_off =
if total = !cap then !offsets_arr else Array.sub !offsets_arr 0 total
in
Encoding.create ~ids:final_ids ~type_ids:(Array.make total type_id)
~tokens:final_strs ~words:(Array.make total None) ~offsets:final_off
~special_tokens_mask:(Array.make total 0)
~attention_mask:(Array.make total 1) ()
let token_to_id model token = Hashtbl.find_opt model.vocab token
let id_to_token model id =
if id >= 0 && id < Array.length model.vocab_r then
Some (Array.unsafe_get model.vocab_r id)
else None
let get_vocab model = Hashtbl.fold (fun k v acc -> (k, v) :: acc) model.vocab []
let get_vocab_size model = Hashtbl.length model.vocab
let get_unk_token model = model.unk_token
let get_continuing_subword_prefix model = model.continuing_subword_prefix
let save model ~path ?name () =
let vocab_file =
match name with
| Some n -> Filename.concat path (n ^ "-vocab.txt")
| None -> Filename.concat path "vocab.txt"
in
let vocab_list =
Hashtbl.fold (fun k v acc -> (v, k) :: acc) model.vocab []
|> List.sort compare
|> List.map (fun (_, k) -> k)
in
let oc = open_out vocab_file in
Fun.protect
~finally:(fun () -> close_out oc)
(fun () ->
List.iter
(fun token ->
output_string oc token;
output_char oc '\n')
vocab_list);
vocab_file
let from_bpe bpe =
let vocab = Hashtbl.create (Bpe.get_vocab_size bpe) in
List.iter (fun (k, id) -> Hashtbl.add vocab k id) (Bpe.get_vocab bpe);
let unk_token =
match Bpe.get_unk_token bpe with Some u -> u | None -> "[UNK]"
in
if not (Hashtbl.mem vocab unk_token) then begin
let max_id = Hashtbl.fold (fun _ id acc -> max id acc) vocab (-1) in
Hashtbl.add vocab unk_token (max_id + 1)
end;
let continuing_subword_prefix =
match Bpe.get_continuing_subword_prefix bpe with
| Some p -> p
| None -> "##"
in
create ~vocab ~unk_token ~continuing_subword_prefix ()
(* Trainer *)
let train ~min_frequency ~vocab_size ~show_progress ~special_tokens
~limit_alphabet ~initial_alphabet ~continuing_subword_prefix
~end_of_word_suffix texts existing =
let _ = existing in
(* WordPiece training uses BPE algorithm internally *)
let bpe_trained, result_tokens =
Bpe.train ~min_frequency ~vocab_size ~show_progress ~special_tokens
~limit_alphabet ~initial_alphabet
~continuing_subword_prefix:(Some continuing_subword_prefix)
~end_of_word_suffix ~max_token_length:None texts None
in
let wordpiece_model = from_bpe bpe_trained in
(wordpiece_model, result_tokens)
================================================
FILE: packages/brot/lib/wordpiece.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** WordPiece tokenization model.
{b Internal module.} Greedy longest-match-first subword decomposition
against a fixed vocabulary. Used by BERT, DistilBERT, and ELECTRA.
A word is decomposed left-to-right: at each position the longest vocabulary
match is consumed. Continuation pieces are prefixed with
{!get_continuing_subword_prefix} (typically ["##"]). If no subword is found
at any position the {e entire} word falls back to {!get_unk_token}.
Vocabulary lookup uses a hybrid trie: dense nodes (more than 16 children)
use a 256-element flat array for O(1) byte dispatch, sparse nodes use binary
search on sorted edges. *)
type t
(** The type for WordPiece models. *)
(** {1:creation Creation} *)
val create :
vocab:(string, int) Hashtbl.t ->
?unk_token:string ->
?continuing_subword_prefix:string ->
?max_input_chars_per_word:int ->
unit ->
t
(** [create ~vocab ()] is a WordPiece model backed by [vocab].
- [unk_token] is the token emitted for words that cannot be decomposed.
Defaults to ["[UNK]"].
- [continuing_subword_prefix] is prepended to non-initial subwords. Defaults
to ["##"].
- [max_input_chars_per_word] is the UTF-8 character count above which a word
is replaced by [unk_token] without attempting decomposition. Defaults to
[100].
Raises [Invalid_argument] if [unk_token] is not in [vocab]. *)
val from_file : vocab_file:string -> t
(** [from_file ~vocab_file] loads a model from a BERT-style [vocab.txt] file
(one token per line, ID equals line number). Uses BERT defaults:
[unk_token = "[UNK]"], [continuing_subword_prefix = "##"],
[max_input_chars_per_word = 100]. *)
(** {1:tokenization Tokenization} *)
type token = { id : int; value : string; offsets : int * int }
(** The type for tokens. [id] is the vocabulary index, [value] the string
content, and [offsets] the [(start, stop)] byte span in the source text. *)
val tokenize : t -> string -> token list
(** [tokenize t s] is the WordPiece decomposition of [s].
If [s] exceeds {!create}'s [max_input_chars_per_word] (in UTF-8 characters),
a single [unk_token] token spanning the whole input is returned. If
decomposition fails at any position, the result is likewise a single
[unk_token]. *)
val tokenize_ids : t -> string -> int array
(** [tokenize_ids t s] is like {!tokenize} but returns only token IDs. *)
val tokenize_spans_encoding :
t -> (string * (int * int)) list -> type_id:int -> Encoding.t
(** [tokenize_spans_encoding t spans ~type_id] tokenizes all [spans] and builds
an {!Encoding.t} directly. Each element of [spans] is
[(fragment, (start, stop))] where offsets are byte positions in the original
text.
This is a single-pass variant that avoids intermediate list and record
allocation: mutable refs are hoisted, growable arrays are filled in place,
and trie matching is inlined. *)
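(** For illustration, a hypothetical sketch of {!tokenize} with an invented
    four-token vocabulary. The word is decomposed greedily left-to-right, and
    non-initial pieces carry the ["##"] continuation prefix:
    {[
      let vocab = Hashtbl.create 8

      let () =
        List.iter
          (fun (t, i) -> Hashtbl.add vocab t i)
          [ ("[UNK]", 0); ("un", 1); ("##aff", 2); ("##able", 3) ]

      let m = create ~vocab ()

      let _ = tokenize m "unaffable"
      (* = [ { id = 1; value = "un"; offsets = (0, 2) };
             { id = 2; value = "##aff"; offsets = (2, 5) };
             { id = 3; value = "##able"; offsets = (5, 9) } ] *)
    ]} *)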
(** {1:vocabulary Vocabulary} *)
val token_to_id : t -> string -> int option
(** [token_to_id t tok] is the ID of [tok] in the vocabulary. *)
val id_to_token : t -> int -> string option
(** [id_to_token t id] is the token string for [id]. *)
val get_vocab : t -> (string * int) list
(** [get_vocab t] is the vocabulary as [(token, id)] pairs. *)
val get_vocab_size : t -> int
(** [get_vocab_size t] is the number of tokens in the vocabulary. *)
val get_unk_token : t -> string
(** [get_unk_token t] is the unknown token string. *)
val get_continuing_subword_prefix : t -> string
(** [get_continuing_subword_prefix t] is the subword continuation prefix (e.g.
["##"]). *)
(** {1:serialization Serialization} *)
val save : t -> path:string -> ?name:string -> unit -> string
(** [save t ~path ()] writes the vocabulary as a plain-text [vocab.txt] file
(one token per line) to [path]. If [name] is given the file is named
[{name}-vocab.txt]. Returns the filepath written. *)
(** {1:training Training} *)
val train :
min_frequency:int ->
vocab_size:int ->
show_progress:bool ->
special_tokens:string list ->
limit_alphabet:int option ->
initial_alphabet:char list ->
continuing_subword_prefix:string ->
end_of_word_suffix:string option ->
string list ->
t option ->
t * string list
(** [train ~min_frequency ~vocab_size ~show_progress ~special_tokens
~limit_alphabet ~initial_alphabet ~continuing_subword_prefix
~end_of_word_suffix texts init] learns a WordPiece vocabulary from [texts]
using BPE merge training internally.
- [min_frequency] is the minimum pair frequency to merge.
- [vocab_size] is the target vocabulary size.
- [show_progress] enables progress output on [stderr].
- [special_tokens] are added to the vocabulary first.
- [limit_alphabet] caps the number of distinct initial characters kept.
- [initial_alphabet] seeds the character set.
- [continuing_subword_prefix] is set on the resulting model.
- [end_of_word_suffix] appended to final subwords if given.
- [init], when provided, seeds the vocabulary from an existing model.
Returns [(model, special_tokens)]. *)
================================================
FILE: packages/brot/test/dune
================================================
(data_only_dirs fixtures scripts)
(tests
(names
test_tokenization
test_vocab
test_encoding
test_unicode
test_bpe
test_wordpiece
test_hf_tokenizers
test_processors
test_pretokenizers)
(package brot)
(libraries brot windtrap unix jsont))
================================================
FILE: packages/brot/test/fixtures/.gitignore
================================================
hf/
================================================
FILE: packages/brot/test/scripts/download_hf_tokenizers.py
================================================
#!/usr/bin/env python3
"""
Download selected HuggingFace tokenizer JSON files into brot/test/fixtures/hf.
Run this script whenever you need to refresh the fixtures:
python3 brot/test/scripts/download_hf_tokenizers.py
The files are ignored by git, so each developer/machine maintains its own cache.
"""
from __future__ import annotations
import hashlib
import json
import sys
import urllib.request
from pathlib import Path
from typing import Iterable, Tuple
FIXTURES: Iterable[Tuple[str, str]] = (
(
"bert-base-uncased",
"https://huggingface.co/bert-base-uncased/resolve/main/tokenizer.json?download=1",
),
(
"gpt2",
"https://huggingface.co/gpt2/resolve/main/tokenizer.json?download=1",
),
(
"roberta-base",
"https://huggingface.co/roberta-base/resolve/main/tokenizer.json?download=1",
),
)
def download(url: str, dest: Path) -> None:
dest.parent.mkdir(parents=True, exist_ok=True)
tmp_path = dest.with_suffix(".tmp")
print(f"→ downloading {url}…")
with urllib.request.urlopen(url) as response, open(tmp_path, "wb") as out:
while True:
chunk = response.read(1024 * 64)
if not chunk:
break
out.write(chunk)
tmp_path.replace(dest)
def sha256(path: Path) -> str:
h = hashlib.sha256()
with path.open("rb") as fh:
for chunk in iter(lambda: fh.read(1024 * 64), b""):
h.update(chunk)
return h.hexdigest()
def summarize(path: Path) -> None:
try:
with path.open("r", encoding="utf-8") as fh:
metadata = json.load(fh)
model_type = metadata.get("model", {}).get("type", "")
size = path.stat().st_size
digest = sha256(path)[:12]
print(f" saved {path} ({size} bytes, model={model_type}, sha256={digest})")
except Exception as exc: # pylint: disable=broad-except
print(f" warning: failed to inspect {path}: {exc}")
def main() -> int:
test_root = Path(__file__).resolve().parents[1]
fixtures_dir = test_root / "fixtures" / "hf"
fixtures_dir.mkdir(parents=True, exist_ok=True)
for model, url in FIXTURES:
target = fixtures_dir / model / "tokenizer.json"
if target.exists():
print(f"✓ {model} already present at {target}")
continue
print(f"Downloading {model} tokenizer…")
try:
download(url, target)
summarize(target)
except Exception as exc: # pylint: disable=broad-except
print(f" failed to download {model}: {exc}", file=sys.stderr)
if target.exists():
target.unlink()
return 1
print("All fixtures downloaded.")
return 0
if __name__ == "__main__":
sys.exit(main())
================================================
FILE: packages/brot/test/test_bpe.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
open Windtrap
open Brot
let test_bpe_basic () =
(* Create a simple vocabulary and merges *)
let vocab =
[
("h", 0);
("e", 1);
("l", 2);
("o", 3);
("ll", 4);
("he", 5);
("llo", 6);
("hello", 7);
]
in
let merges =
[
("l", "l");
(* rank 0: Merge 'l' + 'l' -> 'll' *)
("ll", "o");
(* rank 1: Merge 'll' + 'o' -> 'llo' *)
("he", "llo");
(* rank 2: Merge 'he' + 'llo' -> 'hello' *)
]
in
  let tokenizer = bpe ~vocab ~merges ~unk_token:"<unk>" () in
let encoding = encode tokenizer "hello" in
let tokens = Encoding.tokens encoding |> Array.to_list in
Printf.printf "Tokenized 'hello': ";
List.iter (Printf.printf "%s ") tokens;
Printf.printf "\n";
equal ~msg:"vocabulary size" int 8 (vocab_size tokenizer)
let test_bpe_builder () =
let vocab = [ ("a", 0); ("b", 1); ("ab", 2) ] in
let merges = [ ("a", "b") ] in
let tokenizer = bpe ~vocab ~merges ~cache_capacity:50 () in
let encoding = encode tokenizer "ab" in
let tokens = Encoding.tokens encoding in
equal ~msg:"single token for 'ab'" int 1 (Array.length tokens)
let test_bpe_save_load () =
let vocab = [ ("t", 0); ("e", 1); ("s", 2); ("test", 3) ] in
let merges = [] in
(* No merges for simplicity *)
let tokenizer = bpe ~vocab ~merges () in
(* Save the model *)
let temp_dir = Filename.temp_dir "bpe_test" "" in
let files = save_model_files tokenizer ~folder:temp_dir () in
(* Load the model *)
let vocab_file = List.find (fun f -> Filename.check_suffix f ".json") files in
let merges_file = List.find (fun f -> Filename.check_suffix f ".txt") files in
let loaded_tokenizer =
from_model_file ~vocab:vocab_file ~merges:merges_file ()
in
(* Test that loaded tokenizer works the same *)
let original_tokens = encode tokenizer "test" |> Encoding.tokens in
let loaded_tokens = encode loaded_tokenizer "test" |> Encoding.tokens in
equal ~msg:"same number of tokens" int
(Array.length original_tokens)
(Array.length loaded_tokens);
(* Clean up *)
List.iter Sys.remove files;
Unix.rmdir temp_dir
let test_tokenizer_integration () =
(* Create a BPE tokenizer using the high-level API *)
let vocab =
[
("h", 0); ("e", 1); ("l", 2); ("o", 3); ("he", 4); ("llo", 5); ("hello", 6);
]
in
let merges = [ ("h", "e"); ("he", "llo") ] in
let tokenizer = bpe ~vocab ~merges () in
(* Test encoding *)
let tokens = encode tokenizer "hello" |> Encoding.tokens |> Array.to_list in
Printf.printf "bpe result: ";
List.iter (Printf.printf "%s ") tokens;
Printf.printf "\n";
equal ~msg:"tokenizer produces output" bool true (List.length tokens > 0)
let () =
run "BPE tests"
[
group "basic"
[
test "basic tokenization" test_bpe_basic;
test "builder pattern" test_bpe_builder;
test "save and load" test_bpe_save_load;
test "tokenizer integration" test_tokenizer_integration;
];
]
================================================
FILE: packages/brot/test/test_encoding.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
open Windtrap
open Brot
let make_word_tokenizer ?(specials = []) () =
word_level ~pre:(Pre_tokenizer.whitespace ()) ~specials ()
let test_encode_simple () =
let tokenizer = add_tokens (make_word_tokenizer ()) [ "hello"; "world" ] in
let ids = encode tokenizer "hello world hello" |> Encoding.ids in
equal ~msg:"encoded length" int 3 (Array.length ids);
equal ~msg:"repeated token same id" bool true (ids.(0) = ids.(2))
let test_encode_with_vocab () =
let tokenizer = add_tokens (make_word_tokenizer ()) [ "hello"; "world" ] in
let ids = encode tokenizer "hello world" |> Encoding.ids |> Array.to_list in
equal ~msg:"encoded with vocab" (list int) [ 0; 1 ] ids
let test_encode_unknown_tokens () =
let tokenizer =
add_tokens
      (make_word_tokenizer ~specials:[ special "<unk>" ] ())
[ "hello" ]
in
let ids =
encode tokenizer "hello unknown world" |> Encoding.ids |> Array.to_list
in
equal ~msg:"encoded something" bool true (List.length ids > 0)
let test_encode_empty () =
let tokenizer = make_word_tokenizer () in
let ids = encode tokenizer "" |> Encoding.ids |> Array.to_list in
equal ~msg:"encode empty" (list int) [] ids
let test_encode_batch_simple () =
let tokenizer =
add_tokens (make_word_tokenizer ()) [ "hello"; "world"; "hi"; "there" ]
in
let encodings = encode_batch tokenizer [ "hello world"; "hi there" ] in
equal ~msg:"batch size" int 2 (List.length encodings);
let first = List.hd encodings in
equal ~msg:"first encoding has ids" bool true
(Array.length (Encoding.ids first) > 0)
let test_encode_batch_with_padding () =
  let tokenizer =
    add_tokens
      (make_word_tokenizer ~specials:[ special "<pad>" ] ())
      [ "hello"; "world"; "hi"; "there" ]
  in
  let padding =
    {
      length = `Fixed 5;
      direction = `Right;
      pad_id = None;
      pad_type_id = None;
      pad_token = Some "<pad>";
    }
  in
let encodings = encode_batch tokenizer ~padding [ "hello"; "hi there" ] in
let first = Encoding.ids (List.nth encodings 0) in
let second = Encoding.ids (List.nth encodings 1) in
equal ~msg:"first padded length" int 5 (Array.length first);
equal ~msg:"second padded length" int 5 (Array.length second)
let test_encode_batch_empty () =
let tokenizer = make_word_tokenizer () in
let encodings = encode_batch tokenizer [] in
equal ~msg:"empty batch" int 0 (List.length encodings)
let test_decode_simple () =
let tokenizer = add_tokens (make_word_tokenizer ()) [ "hello"; "world" ] in
let decoded = decode tokenizer [| 0; 1 |] in
equal ~msg:"decoded text" string "hello world" decoded
let test_decode_with_special () =
  let tokenizer =
    add_tokens
      (make_word_tokenizer ~specials:[ special "<bos>"; special "<eos>" ] ())
      [ "hello" ]
  in
  (* <bos>=0, <eos>=1, hello=2 *)
  let decoded = decode tokenizer [| 0; 2; 1 |] in
  equal ~msg:"decoded with special" string "<bos> hello <eos>" decoded
let test_decode_skip_special () =
  let tokenizer =
    add_tokens
      (make_word_tokenizer ~specials:[ special "<bos>"; special "<eos>" ] ())
      [ "hello" ]
  in
  let decoded = decode ~skip_special_tokens:true tokenizer [| 0; 2; 1 |] in
  equal ~msg:"decoded without special" string "hello" decoded
let test_decode_batch () =
let tokenizer =
add_tokens (make_word_tokenizer ()) [ "hello"; "world"; "hi"; "there" ]
in
let decoded = decode_batch tokenizer [ [| 0; 1 |]; [| 2; 3 |] ] in
equal ~msg:"decoded count" int 2 (List.length decoded);
equal ~msg:"first decoded" string "hello world" (List.nth decoded 0);
equal ~msg:"second decoded" string "hi there" (List.nth decoded 1)
let test_chars_model () =
let tokenizer = chars () in
let ids = encode tokenizer "abc" |> Encoding.ids |> Array.to_list in
equal ~msg:"char ids" (list int) [ 97; 98; 99 ] ids
let suite =
[
test "encode simple" test_encode_simple;
test "encode with vocab" test_encode_with_vocab;
test "encode unknown tokens" test_encode_unknown_tokens;
test "encode empty" test_encode_empty;
test "batch simple" test_encode_batch_simple;
test "batch with padding" test_encode_batch_with_padding;
test "batch empty request" test_encode_batch_empty;
test "decode simple" test_decode_simple;
test "decode with special" test_decode_with_special;
test "decode skip special" test_decode_skip_special;
test "decode batch" test_decode_batch;
test "chars model" test_chars_model;
]
let () = run "Encoding tests" [ group "encoding" suite ]
================================================
FILE: packages/brot/test/test_hf_tokenizers.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
open Brot
open Windtrap
let candidate_roots () =
match Sys.getenv_opt "DUNE_SOURCEROOT" with
| Some root -> [ root; Sys.getcwd () ]
| None -> [ Sys.getcwd () ]
let locate_fixture model =
let relative =
Filename.concat "brot/test/fixtures/hf"
(Filename.concat model "tokenizer.json")
in
let rec search = function
| [] -> None
| root :: rest ->
let path = Filename.concat root relative in
if Sys.file_exists path then Some path else search rest
in
search (candidate_roots ())
let with_hf_tokenizer model f =
match locate_fixture model with
| None -> skip ()
| Some path -> (
match from_file path with
| Ok tok -> f tok
| Error msg -> failf "Failed to load tokenizer %s: %s" model msg)
let test_bert_base_uncased () =
with_hf_tokenizer "bert-base-uncased" (fun tok ->
let encoding = encode tok "Hello world!" in
let tokens = Encoding.tokens encoding |> Array.to_list in
equal ~msg:"token sequence" (list string)
[ "[CLS]"; "hello"; "world"; "!"; "[SEP]" ]
tokens;
let type_ids = Encoding.type_ids encoding |> Array.to_list in
equal ~msg:"type ids" (list int) [ 0; 0; 0; 0; 0 ] type_ids;
equal ~msg:"has [MASK]" bool true
(Option.is_some (token_to_id tok "[MASK]")))
let test_gpt2_small () =
with_hf_tokenizer "gpt2" (fun tok ->
let encoding = encode tok "Hello world" in
let ids = Encoding.ids encoding |> Array.to_list in
equal ~msg:"ids" (list int) [ 15496; 995 ] ids;
let roundtrip =
decode tok (Array.of_list ids) ~skip_special_tokens:true
in
equal ~msg:"decode" string "Hello world" roundtrip)
let test_roberta_base () =
with_hf_tokenizer "roberta-base" (fun tok ->
let encoding = encode tok "A quick test" in
let tokens = Encoding.tokens encoding |> Array.to_list in
equal ~msg:"tokens" (list string)
        [ "<s>"; "A"; "Ġquick"; "Ġtest"; "</s>" ]
tokens;
let attention = Encoding.attention_mask encoding |> Array.to_list in
equal ~msg:"attention mask" (list int) [ 1; 1; 1; 1; 1 ] attention)
let () =
run "HF tokenizers"
[
group "bert-base-uncased" [ test "encode" test_bert_base_uncased ];
group "gpt2" [ test "encode" test_gpt2_small ];
group "roberta-base" [ test "encode" test_roberta_base ];
]
================================================
FILE: packages/brot/test/test_pretokenizers.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
open Windtrap
module Pre = Brot.Pre_tokenizer
let check_tokenization name input expected =
equal ~msg:name (list (pair string (pair int int))) expected input
let check_strings name input expected =
equal ~msg:name (list string) expected (List.map fst input)
let test_byte_level_basic () =
let tokenizer = Pre.byte_level ~add_prefix_space:false ~use_regex:true () in
(* Test basic tokenization *)
let test_case text expected_pieces expected_offsets =
let result = Pre.pre_tokenize tokenizer text in
let offsets = List.map snd result in
check_strings
(Printf.sprintf "ByteLevel pieces for %S" text)
result expected_pieces;
equal
~msg:(Printf.sprintf "ByteLevel offsets for %S" text)
(list (pair int int))
expected_offsets offsets
in
(* Basic words *)
test_case "Hello" [ "Hello" ] [ (0, 5) ];
test_case "hello" [ "hello" ] [ (0, 5) ];
test_case "HELLO" [ "HELLO" ] [ (0, 5) ];
(* Words with spaces - space becomes Ġ (0xC4 0xA0) *)
test_case "Hello world" [ "Hello"; "\196\160world" ] [ (0, 5); (5, 11) ];
test_case "Hello world"
[ "Hello"; "\196\160"; "\196\160world" ]
[ (0, 5); (5, 6); (6, 12) ];
(* Leading/trailing spaces *)
test_case " hello" [ "\196\160hello" ] [ (0, 6) ];
test_case "hello " [ "hello"; "\196\160" ] [ (0, 5); (5, 6) ];
(* Note: Python produces ['Ġ', 'Ġhello', 'ĠĠ'] for " hello " *)
test_case " hello "
[ "\196\160"; "\196\160hello"; "\196\160\196\160" ]
[ (0, 1); (1, 7); (7, 9) ];
(* Contractions - should be kept as separate pieces *)
test_case "'s" [ "'s" ] [ (0, 2) ];
test_case "'t" [ "'t" ] [ (0, 2) ];
test_case "'re" [ "'re" ] [ (0, 3) ];
test_case "'ve" [ "'ve" ] [ (0, 3) ];
test_case "'m" [ "'m" ] [ (0, 2) ];
test_case "'ll" [ "'ll" ] [ (0, 3) ];
test_case "'d" [ "'d" ] [ (0, 2) ];
(* Words with contractions *)
test_case "don't" [ "don"; "'t" ] [ (0, 3); (3, 5) ];
test_case "it's" [ "it"; "'s" ] [ (0, 2); (2, 4) ];
test_case "we're" [ "we"; "'re" ] [ (0, 2); (2, 5) ];
test_case "I'll" [ "I"; "'ll" ] [ (0, 1); (1, 4) ];
test_case "OpenAI's" [ "OpenAI"; "'s" ] [ (0, 6); (6, 8) ]
let test_byte_level_prefix_space () =
(* Test with add_prefix_space=true *)
let tokenizer = Pre.byte_level ~add_prefix_space:true ~use_regex:true () in
let test_case text expected_pieces =
let result = Pre.pre_tokenize tokenizer text in
check_strings
(Printf.sprintf "ByteLevel with prefix for %S" text)
result expected_pieces
in
(* Should add space prefix when text doesn't start with space *)
test_case "hello" [ "\196\160hello" ];
test_case "Hello world" [ "\196\160Hello"; "\196\160world" ];
(* Should NOT add extra space when text already starts with space *)
test_case " hello" [ "\196\160hello" ];
test_case " hello" [ "\196\160"; "\196\160hello" ]
let test_byte_level_special_chars () =
let tokenizer = Pre.byte_level ~add_prefix_space:false ~use_regex:true () in
let test_case text desc =
let result = Pre.pre_tokenize tokenizer text in
let pieces = List.map fst result in
(* Just verify it doesn't crash and produces something *)
equal
~msg:(Printf.sprintf "ByteLevel handles %s" desc)
bool true
(List.length pieces > 0)
in
(* Punctuation *)
test_case "." "period";
test_case "!" "exclamation";
test_case "?" "question";
test_case "," "comma";
test_case ";" "semicolon";
test_case ":" "colon";
(* Special characters *)
test_case "@" "at sign";
test_case "#" "hash";
test_case "$" "dollar";
test_case "%" "percent";
test_case "^" "caret";
test_case "&" "ampersand";
test_case "*" "asterisk";
(* Brackets and quotes *)
test_case "()" "parentheses";
test_case "[]" "brackets";
test_case "{}" "braces";
test_case "\"\"" "quotes";
test_case "''" "single quotes";
(* Numbers *)
test_case "123" "numbers";
test_case "3.14" "decimal";
test_case "1,000" "number with comma";
(* Mixed *)
test_case "Hello, world!" "punctuated sentence";
test_case "@user #hashtag" "social media";
test_case "test@example.com" "email";
test_case "https://example.com" "URL";
test_case "function()" "function call";
test_case "a+b=c" "math expression"
let test_byte_level_unicode () =
let tokenizer = Pre.byte_level ~add_prefix_space:false ~use_regex:true () in
let test_case text desc =
let result = Pre.pre_tokenize tokenizer text in
let pieces = List.map fst result in
(* Byte-level encoding should handle any Unicode by encoding bytes *)
equal
~msg:(Printf.sprintf "ByteLevel handles %s" desc)
bool true
(List.length pieces > 0);
(* Check that we can reconstruct something (even if not identical due to
encoding) *)
let concatenated = String.concat "" pieces in
equal
~msg:(Printf.sprintf "ByteLevel produces non-empty output for %s" desc)
bool true
(String.length concatenated > 0)
in
(* Common accented characters *)
test_case "café" "accented e";
test_case "naïve" "diaeresis";
test_case "résumé" "French accents";
(* Other languages *)
test_case "你好" "Chinese";
test_case "こんにちは" "Japanese";
test_case "안녕하세요" "Korean";
test_case "Привет" "Russian";
test_case "مرحبا" "Arabic";
(* Emojis *)
test_case "😀" "emoji";
test_case "👍🏻" "emoji with skin tone";
test_case "Hello 👋 World" "text with emoji"
let test_byte_level_edge_cases () =
let tokenizer = Pre.byte_level ~add_prefix_space:false ~use_regex:true () in
(* Empty string *)
let result = Pre.pre_tokenize tokenizer "" in
equal ~msg:"Empty string" (list string) [] (List.map fst result);
(* Single character *)
let result = Pre.pre_tokenize tokenizer "a" in
check_strings "Single char" result [ "a" ];
(* Only spaces - Python produces ['ĠĠĠ'] all together *)
let result = Pre.pre_tokenize tokenizer " " in
check_strings "Only spaces" result [ "\196\160\196\160\196\160" ];
(* Only punctuation - Python keeps '...' together *)
let result = Pre.pre_tokenize tokenizer "..." in
check_strings "Only punctuation" result [ "..." ];
(* Very long word *)
let long_word = String.make 100 'a' in
let result = Pre.pre_tokenize tokenizer long_word in
equal ~msg:"Long word produces single token" int 1 (List.length result);
(* Mixed whitespace *)
let result = Pre.pre_tokenize tokenizer "hello\tworld\nfoo\rbar" in
equal ~msg:"Handles tabs and newlines" bool true (List.length result > 0)
let test_bert_pretokenizer () =
let test_case text expected =
let result = Pre.pre_tokenize (Pre.bert ()) text in
check_tokenization
(Printf.sprintf "BERT tokenization of %S" text)
result expected
in
(* Basic tokenization *)
test_case "Hello world" [ ("Hello", (0, 5)); ("world", (6, 11)) ];
test_case "Hello, world!"
[ ("Hello", (0, 5)); (",", (5, 6)); ("world", (7, 12)); ("!", (12, 13)) ];
(* Punctuation handling *)
test_case "test." [ ("test", (0, 4)); (".", (4, 5)) ];
test_case "a-b" [ ("a", (0, 1)); ("-", (1, 2)); ("b", (2, 3)) ];
test_case "it's" [ ("it", (0, 2)); ("'", (2, 3)); ("s", (3, 4)) ];
(* Multiple spaces *)
test_case "hello world" [ ("hello", (0, 5)); ("world", (7, 12)) ];
(* Unicode *)
test_case "café" [ ("café", (0, 5)) ];
  (* Note: é is 2 bytes in UTF-8 *)
(* Empty and whitespace *)
test_case "" [];
test_case " " []
let test_whitespace_pretokenizer () =
let test_case text expected =
let result = Pre.pre_tokenize (Pre.whitespace ()) text in
check_tokenization
(Printf.sprintf "Whitespace tokenization of %S" text)
result expected
in
(* Pattern is \w+|[^\w\s]+ *)
test_case "Hello world" [ ("Hello", (0, 5)); ("world", (6, 11)) ];
test_case "Hello, world!"
[ ("Hello", (0, 5)); (",", (5, 6)); ("world", (7, 12)); ("!", (12, 13)) ];
test_case "test_var" [ ("test_var", (0, 8)) ];
(* underscore is part of \w *)
test_case "123abc" [ ("123abc", (0, 6)) ];
(* numbers are part of \w *)
test_case "a+b=c"
[
("a", (0, 1)); ("+", (1, 2)); ("b", (2, 3)); ("=", (3, 4)); ("c", (4, 5));
]
let test_whitespace_split () =
let test_case text expected =
let result = Pre.pre_tokenize (Pre.whitespace_split ()) text in
check_tokenization
(Printf.sprintf "WhitespaceSplit of %S" text)
result expected
in
(* Simple split on whitespace *)
test_case "Hello world" [ ("Hello", (0, 5)); ("world", (6, 11)) ];
test_case " Hello world " [ ("Hello", (2, 7)); ("world", (9, 14)) ];
test_case "one\ttwo\nthree"
[ ("one", (0, 3)); ("two", (4, 7)); ("three", (8, 13)) ];
test_case "" [];
test_case " " []
let test_punctuation_pretokenizer () =
(* Test different behaviors *)
let test_isolated text expected =
let tokenizer = Pre.punctuation ~behavior:`Isolated () in
let result = Pre.pre_tokenize tokenizer text in
check_tokenization
(Printf.sprintf "Punctuation Isolated %S" text)
result expected
in
let test_removed text expected =
let tokenizer = Pre.punctuation ~behavior:`Removed () in
let result = Pre.pre_tokenize tokenizer text in
check_tokenization
(Printf.sprintf "Punctuation Removed %S" text)
result expected
in
(* Isolated behavior *)
test_isolated "Hello, world!"
[ ("Hello", (0, 5)); (",", (5, 6)); (" world", (6, 12)); ("!", (12, 13)) ];
(* Removed behavior *)
test_removed "Hello, world!" [ ("Hello", (0, 5)); (" world", (6, 12)) ];
(* Multiple punctuation *)
test_isolated "test...end"
[
("test", (0, 4));
(".", (4, 5));
(".", (5, 6));
(".", (6, 7));
("end", (7, 10));
];
(* Unicode punctuation *)
test_isolated "test—end" [ ("test", (0, 4)); ("—", (4, 7)); ("end", (7, 10)) ]
(* em dash is 3 bytes *)
let test_digits_pretokenizer () =
let test_individual text expected =
let tokenizer = Pre.digits ~individual_digits:true () in
let result = Pre.pre_tokenize tokenizer text in
check_tokenization
(Printf.sprintf "Digits individual %S" text)
result expected
in
let test_grouped text expected =
let tokenizer = Pre.digits ~individual_digits:false () in
let result = Pre.pre_tokenize tokenizer text in
check_tokenization (Printf.sprintf "Digits grouped %S" text) result expected
in
(* Individual digits *)
test_individual "123" [ ("1", (0, 1)); ("2", (1, 2)); ("3", (2, 3)) ];
test_individual "a1b2"
[ ("a", (0, 1)); ("1", (1, 2)); ("b", (2, 3)); ("2", (3, 4)) ];
(* Grouped digits *)
test_grouped "123" [ ("123", (0, 3)) ];
test_grouped "a123b456"
[ ("a", (0, 1)); ("123", (1, 4)); ("b", (4, 5)); ("456", (5, 8)) ];
test_grouped "3.14" [ ("3", (0, 1)); (".", (1, 2)); ("14", (2, 4)) ]
let test_split_pretokenizer () =
let test_case pattern behavior text expected =
let tokenizer = Pre.split ~pattern ~behavior () in
let result = Pre.pre_tokenize tokenizer text in
check_tokenization
(Printf.sprintf "Split pattern=%S behavior=%s text=%S" pattern
(match behavior with
| `Isolated -> "Isolated"
| `Removed -> "Removed"
| `Merged_with_previous -> "MergedPrev"
| `Merged_with_next -> "MergedNext"
| `Contiguous -> "Contiguous")
text)
result expected
in
(* Test different behaviors *)
test_case "," `Isolated "a,b,c"
[
("a", (0, 1)); (",", (1, 2)); ("b", (2, 3)); (",", (3, 4)); ("c", (4, 5));
];
test_case "," `Removed "a,b,c" [ ("a", (0, 1)); ("b", (2, 3)); ("c", (4, 5)) ];
test_case "," `Merged_with_previous "a,b,c"
[ ("a,", (0, 2)); ("b,", (2, 4)); ("c", (4, 5)) ];
test_case "," `Merged_with_next "a,b,c"
[ ("a", (0, 1)); (",b", (1, 3)); (",c", (3, 5)) ];
(* Test with longer pattern *)
test_case "::" `Isolated "a::b::c"
[
("a", (0, 1)); ("::", (1, 3)); ("b", (3, 4)); ("::", (4, 6)); ("c", (6, 7));
]
let test_char_delimiter_split () =
let test_case delim text expected =
let result = Pre.pre_tokenize (Pre.char_delimiter delim) text in
check_tokenization
(Printf.sprintf "CharDelimiterSplit delim='%c' text=%S" delim text)
result expected
in
test_case ',' "a,b,c" [ ("a", (0, 1)); ("b", (2, 3)); ("c", (4, 5)) ];
test_case ' ' "hello world" [ ("hello", (0, 5)); ("world", (6, 11)) ];
test_case '|' "one|two|three"
[ ("one", (0, 3)); ("two", (4, 7)); ("three", (8, 13)) ];
test_case ',' "" [];
test_case ',' "," []
let test_sequence_pretokenizer () =
(* Combine whitespace split then punctuation isolation *)
let tokenizers =
[ Pre.whitespace_split (); Pre.punctuation ~behavior:`Isolated () ]
in
let tokenizer = Pre.sequence tokenizers in
let test_case text expected =
let result = Pre.pre_tokenize tokenizer text in
check_tokenization (Printf.sprintf "Sequence %S" text) result expected
in
(* First splits on whitespace, then isolates punctuation in each piece *)
test_case "Hello, world!"
[ ("Hello", (0, 5)); (",", (5, 6)); ("world", (7, 12)); ("!", (12, 13)) ];
(* Multiple words and punctuation *)
test_case "test. another, example!"
[
("test", (0, 4));
(".", (4, 5));
("another", (6, 13));
(",", (13, 14));
("example", (15, 22));
("!", (22, 23));
]
let test_fixed_length () =
let test_case length text expected =
let result = Pre.pre_tokenize (Pre.fixed_length length) text in
check_tokenization
(Printf.sprintf "FixedLength %d %S" length text)
result expected
in
test_case 3 "abcdefghi" [ ("abc", (0, 3)); ("def", (3, 6)); ("ghi", (6, 9)) ];
test_case 2 "abcde" [ ("ab", (0, 2)); ("cd", (2, 4)); ("e", (4, 5)) ];
test_case 5 "hello" [ ("hello", (0, 5)) ];
test_case 0 "test" [];
test_case 3 "" [];
(* With UTF-8 - counts characters not bytes *)
test_case 2 "café" [ ("ca", (0, 2)); ("fé", (2, 5)) ]
(* é is 2 bytes *)
let test_unicode_scripts () =
let test_case text desc =
let tokenizer = Pre.unicode_scripts () in
let result = Pre.pre_tokenize tokenizer text in
(* Just verify it runs without crashing and produces something reasonable *)
equal
~msg:(Printf.sprintf "UnicodeScripts %s" desc)
bool true
(List.length result >= 0)
in
test_case "Hello world" "Latin text";
test_case "Hello世界" "Mixed Latin and Chinese";
test_case "Привет мир" "Cyrillic";
test_case "مرحبا بالعالم" "Arabic";
test_case "こんにちは世界" "Japanese";
test_case "" "Empty string"
let test_metaspace_basic () =
let test_case text expected =
let result =
Pre.pre_tokenize
(Pre.metaspace ~replacement:'_' ~prepend_scheme:`Always ~split:true ())
text
in
check_strings (Printf.sprintf "Metaspace %S" text) result expected
in
test_case "Hello world" [ "_Hello"; "_world" ];
test_case " starts with space" [ "_starts"; "_with"; "_space" ];
test_case "" []
let () =
run "Pre-tokenizers Test Suite"
[
group "byte_level"
[
test "ByteLevel basic" test_byte_level_basic;
test "ByteLevel prefix space" test_byte_level_prefix_space;
test "ByteLevel special chars" test_byte_level_special_chars;
test "ByteLevel unicode" test_byte_level_unicode;
test "ByteLevel edge cases" test_byte_level_edge_cases;
];
group "bert" [ test "BERT tokenization" test_bert_pretokenizer ];
group "whitespace"
[
test "Whitespace tokenization" test_whitespace_pretokenizer;
test "WhitespaceSplit" test_whitespace_split;
];
group "punctuation"
[ test "Punctuation behaviors" test_punctuation_pretokenizer ];
group "digits" [ test "Digits tokenization" test_digits_pretokenizer ];
group "split"
[
test "Split with patterns" test_split_pretokenizer;
test "CharDelimiterSplit" test_char_delimiter_split;
];
group "sequence"
[ test "Sequence of tokenizers" test_sequence_pretokenizer ];
group "fixed_length" [ test "FixedLength chunks" test_fixed_length ];
group "unicode_scripts" [ test "UnicodeScripts" test_unicode_scripts ];
group "metaspace" [ test "Metaspace basic" test_metaspace_basic ];
]
================================================
FILE: packages/brot/test/test_processors.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
open Windtrap
open Brot
let make_encoding ~ids ~tokens ~type_id =
let len = Array.length ids in
Encoding.create ~ids:(Array.copy ids) ~type_ids:(Array.make len type_id)
~tokens:(Array.copy tokens) ~words:(Array.make len None)
~offsets:(Array.make len (0, 0))
~special_tokens_mask:(Array.make len 0) ~attention_mask:(Array.make len 1)
()
let json_obj pairs =
Jsont.Json.object' (List.map (fun (k, v) -> (Jsont.Json.name k, v)) pairs)
let test_template_multi_special () =
let processor =
Result.get_ok
(Post_processor.of_json
(json_obj
[
("type", Jsont.Json.string "TemplateProcessing");
( "single",
Jsont.Json.list
[
json_obj
[
( "SpecialToken",
json_obj
[
("id", Jsont.Json.string "");
("type_id", Jsont.Json.int 2);
] );
];
json_obj
[
( "Sequence",
json_obj
[
("id", Jsont.Json.string "A");
("type_id", Jsont.Json.int 0);
] );
];
] );
("pair", Jsont.Json.null ());
( "special_tokens",
json_obj
[
( "",
json_obj
[
("id", Jsont.Json.string "");
( "ids",
Jsont.Json.list
[ Jsont.Json.int 100; Jsont.Json.int 101 ] );
( "tokens",
Jsont.Json.list
[
Jsont.Json.string "";
Jsont.Json.string "";
] );
] );
] );
]))
in
let base = make_encoding ~ids:[| 10 |] ~tokens:[| "hello" |] ~type_id:0 in
let encoding =
Post_processor.process processor base ~add_special_tokens:true
in
equal ~msg:"ids" (array int) [| 100; 101; 10 |] (Encoding.ids encoding);
equal ~msg:"tokens" (array string)
[| ""; ""; "hello" |]
(Encoding.tokens encoding);
equal ~msg:"type ids" (array int) [| 2; 2; 0 |] (Encoding.type_ids encoding);
equal ~msg:"special mask" (array int) [| 1; 1; 0 |]
(Encoding.special_tokens_mask encoding);
equal ~msg:"attention mask" (array int) [| 1; 1; 1 |]
(Encoding.attention_mask encoding);
equal ~msg:"added tokens single" int 2
(Post_processor.added_tokens processor ~is_pair:false)
let test_template_pair_type_ids () =
let processor =
Post_processor.template ~single:"$A [SEP]"
~pair:"[CLS]:0 $A:0 [SEP]:0 $B:3 [SEP]:3"
~special_tokens:[ ("[CLS]", 101); ("[SEP]", 102) ]
()
in
let seq_a =
make_encoding ~ids:[| 10; 11 |] ~tokens:[| "hello"; "world" |] ~type_id:0
in
let seq_b = make_encoding ~ids:[| 20 |] ~tokens:[| "pair" |] ~type_id:1 in
let encoding =
Post_processor.process processor ~pair:seq_b seq_a ~add_special_tokens:true
in
equal ~msg:"pair ids" (array int)
[| 101; 10; 11; 102; 20; 102 |]
(Encoding.ids encoding);
equal ~msg:"pair tokens" (array string)
[| "[CLS]"; "hello"; "world"; "[SEP]"; "pair"; "[SEP]" |]
(Encoding.tokens encoding);
equal ~msg:"pair type ids" (array int) [| 0; 0; 0; 0; 3; 3 |]
(Encoding.type_ids encoding);
equal ~msg:"pair special mask" (array int) [| 1; 0; 0; 1; 0; 1 |]
(Encoding.special_tokens_mask encoding);
equal ~msg:"added tokens pair" int 3
(Post_processor.added_tokens processor ~is_pair:true);
let no_special =
Post_processor.process processor ~pair:seq_b seq_a ~add_special_tokens:false
in
equal ~msg:"no-special ids" (array int) (Encoding.ids seq_a)
(Encoding.ids no_special)
let () =
run "Processors"
[
group "template"
[
test "multi-id special expansion" test_template_multi_special;
test "pair template semantics" test_template_pair_type_ids;
];
]
================================================
FILE: packages/brot/test/test_tokenization.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Tokenization tests for brot *)
open Windtrap
open Brot
(* Helper function to tokenize text *)
let tokenize_text text =
(* Pre-tokenize to get all unique tokens *)
let pre_tokens =
Pre_tokenizer.pre_tokenize (Pre_tokenizer.whitespace ()) text
in
let unique_tokens =
List.fold_left
(fun acc (tok, _) -> if List.mem tok acc then acc else tok :: acc)
[] pre_tokens
|> List.rev
in
(* Build vocabulary with all tokens from the text plus extras *)
let all_tokens =
unique_tokens
@
(* Add numbered words for long text test *)
List.init 1000 (fun i -> Printf.sprintf "word%d" i)
in
let vocab = List.mapi (fun i token -> (token, i)) all_tokens in
(* Create WordLevel tokenizer with the vocabulary *)
let tokenizer =
word_level ~vocab ~unk_token:"<unk>" ~pre:(Pre_tokenizer.whitespace ()) ()
in
encode tokenizer text |> Encoding.tokens |> Array.to_list
(* Basic Tokenization Tests *)
let test_tokenize_words_simple () =
let tokens = tokenize_text "Hello world!" in
equal ~msg:"simple words" (list string) [ "Hello"; "world"; "!" ] tokens
let test_tokenize_words_punctuation () =
let tokens = tokenize_text "don't stop, it's fun!" in
equal ~msg:"words with punctuation" (list string)
[ "don"; "'"; "t"; "stop"; ","; "it"; "'"; "s"; "fun"; "!" ]
tokens
let test_tokenize_words_numbers () =
let tokens = tokenize_text "I have 42 apples and 3.14 pies" in
equal ~msg:"words with numbers" (list string)
[ "I"; "have"; "42"; "apples"; "and"; "3"; "."; "14"; "pies" ]
tokens
let test_tokenize_words_empty () =
let tokens = tokenize_text "" in
equal ~msg:"empty string" (list string) [] tokens
let test_tokenize_words_whitespace_only () =
let tokens = tokenize_text " \t\n " in
equal ~msg:"whitespace only" (list string) [] tokens
let test_tokenize_words_special_chars () =
let tokens = tokenize_text "hello@world.com #ml $100 C++" in
equal ~msg:"special characters" (list string)
[ "hello"; "@"; "world"; "."; "com"; "#"; "ml"; "$"; "100"; "C"; "++" ]
tokens
(* Character Tokenization Tests *)
let tokenize_chars text =
let chars = ref [] in
String.iter (fun c -> chars := String.make 1 c :: !chars) text;
List.rev !chars
let test_tokenize_chars_ascii () =
let tokens = tokenize_chars "Hi!" in
equal ~msg:"ASCII chars" (list string) [ "H"; "i"; "!" ] tokens
let test_tokenize_chars_unicode () =
let tokens = tokenize_chars "Hello 👋 世界" in
(* Note: UTF-8 encoding means multi-byte chars may appear differently *)
equal ~msg:"has tokens" bool true (List.length tokens > 0)
let test_tokenize_chars_empty () =
let tokens = tokenize_chars "" in
equal ~msg:"empty string chars" (list string) [] tokens
(* Pre-tokenizer Pattern Tests *)
let test_tokenize_regex_words () =
(* Use the helper that sets up vocabulary properly *)
let tokens = tokenize_text "hello-world test_123" in
equal ~msg:"regex words" (list string)
[ "hello"; "-"; "world"; "test_123" ]
tokens
let test_tokenize_regex_custom () =
(* Test with punctuation pre-tokenizer *)
let text = "don't stop!" in
let pre_tokens =
Pre_tokenizer.pre_tokenize (Pre_tokenizer.punctuation ()) text
in
let vocab = List.mapi (fun i (tok, _) -> (tok, i)) pre_tokens in
let tokenizer =
word_level ~vocab ~unk_token:"<unk>" ~pre:(Pre_tokenizer.punctuation ()) ()
in
let tokens = encode tokenizer text |> Encoding.tokens |> Array.to_list in
equal ~msg:"has tokens" bool true (List.length tokens > 0)
let test_tokenize_regex_no_match () =
let tokenizer = word_level () in
let tokens =
encode tokenizer "no numbers here" |> Encoding.tokens |> Array.to_list
in
equal ~msg:"regex no match" (list string) [] tokens
(* Unigram Model Tests *)
(* Round-trip lookups *)
let test_unigram_roundtrip () =
let tokens = [ "hello"; "world"; "test" ] in
let vocab = List.map (fun token -> (token, 0.0)) tokens in
let tokenizer = unigram ~vocab () in
List.iteri
(fun expected_id token ->
equal
~msg:(Printf.sprintf "token_to_id '%s'" token)
(option int) (Some expected_id)
(token_to_id tokenizer token);
equal
~msg:(Printf.sprintf "id_to_token %d" expected_id)
(option string) (Some token)
(id_to_token tokenizer expected_id))
tokens
(* token_to_id - out of vocab *)
let test_unigram_token_to_id_oov () =
let tokenizer = unigram ~vocab:[ ("hello", 0.0); ("world", 0.0) ] () in
equal ~msg:"token_to_id out-of-vocab" (option int) None
(token_to_id tokenizer "missing")
(* id_to_token - out of bounds *)
let test_unigram_id_to_token_oob () =
let tokenizer = unigram ~vocab:[ ("hello", 0.0); ("world", 0.0) ] () in
equal ~msg:"id_to_token negative" (option string) None
(id_to_token tokenizer (-1));
equal ~msg:"id_to_token out of bounds" (option string) None
(id_to_token tokenizer 10)
(* Test empty vocabulary *)
let test_unigram_empty_vocab () =
let tokenizer = unigram ~vocab:[] () in
equal ~msg:"empty vocab token_to_id" (option int) None
(token_to_id tokenizer "test");
equal ~msg:"empty vocab id_to_token" (option string) None
(id_to_token tokenizer 0)
(* Test special characters and unicode *)
let test_unigram_special_tokens () =
let tokenizer =
unigram
~vocab:
[
("<unk>", 0.0);
("<s>", 0.0);
("</s>", 0.0);
("▁hello", 0.0);
("世界", 0.0);
]
()
in
equal ~msg:"special <unk>" (option int) (Some 0)
(token_to_id tokenizer "<unk>");
equal ~msg:"special <s>" (option int) (Some 1) (token_to_id tokenizer "<s>");
equal ~msg:"sentencepiece token" (option int) (Some 3)
(token_to_id tokenizer "▁hello");
equal ~msg:"unicode token" (option int) (Some 4) (token_to_id tokenizer "世界");
equal ~msg:"id to unicode" (option string) (Some "世界")
(id_to_token tokenizer 4)
let test_unigram_encode_sequence () =
let tokenizer = unigram ~vocab:[ ("hello", 0.0); ("world", 0.0) ] () in
let encoding = encode tokenizer "hello world" in
let tokens = Encoding.tokens encoding |> Array.to_list in
equal ~msg:"unigram encode tokens" (list string) [ "hello"; "world" ] tokens
let test_pad_token_set_at_construction () =
let vocab = [ ("hello", 0); ("world", 1); ("<unk>", 2); ("[PAD]", 3) ] in
let tokenizer =
word_level ~vocab ~unk_token:"<unk>"
~pre:(Pre_tokenizer.whitespace ())
~specials:[ special "[PAD]" ]
~pad_token:"[PAD]" ()
in
equal ~msg:"pad token set" (option string) (Some "[PAD]")
(pad_token tokenizer);
let pad_id =
match token_to_id tokenizer "[PAD]" with
| Some id -> id
| None -> failwith "missing pad id"
in
let encoding =
encode tokenizer "hello"
~padding:
{
length = `Fixed 3;
direction = `Right;
pad_id = None;
pad_type_id = None;
pad_token = None;
}
in
let ids = Encoding.ids encoding |> Array.to_list in
let pad_ids = List.tl ids in
equal ~msg:"pad id matches configured token" (list int) [ pad_id; pad_id ]
pad_ids
(* Edge Cases *)
let test_tokenize_long_text () =
let text =
String.concat " " (List.init 1000 (fun i -> Printf.sprintf "word%d" i))
in
let tokens = tokenize_text text in
equal ~msg:"long text token count" int 1000 (List.length tokens)
let test_tokenize_repeated_punctuation () =
let tokens = tokenize_text "wow!!! really???" in
equal ~msg:"repeated punctuation" (list string)
[ "wow"; "!!!"; "really"; "???" ]
tokens
let test_tokenize_mixed_whitespace () =
let tokens = tokenize_text "hello\tworld\nthere\r\nfriend" in
equal ~msg:"mixed whitespace" (list string)
[ "hello"; "world"; "there"; "friend" ]
tokens
(* Test Suite *)
let tokenization_tests =
[
(* Words tokenization *)
test "tokenize words simple" test_tokenize_words_simple;
test "tokenize words punctuation" test_tokenize_words_punctuation;
test "tokenize words numbers" test_tokenize_words_numbers;
test "tokenize words empty" test_tokenize_words_empty;
test "tokenize words whitespace only" test_tokenize_words_whitespace_only;
test "tokenize words special chars" test_tokenize_words_special_chars;
(* Character tokenization *)
test "tokenize chars ASCII" test_tokenize_chars_ascii;
test "tokenize chars unicode" test_tokenize_chars_unicode;
test "tokenize chars empty" test_tokenize_chars_empty;
(* Regex tokenization *)
test "tokenize regex words" test_tokenize_regex_words;
test "tokenize regex custom" test_tokenize_regex_custom;
test "tokenize regex no match" test_tokenize_regex_no_match;
(* Edge cases *)
test "tokenize long text" test_tokenize_long_text;
test "tokenize repeated punctuation" test_tokenize_repeated_punctuation;
test "tokenize mixed whitespace" test_tokenize_mixed_whitespace;
(* Unigram model tests *)
test "unigram roundtrip" test_unigram_roundtrip;
test "unigram token_to_id out-of-vocab" test_unigram_token_to_id_oov;
test "unigram id_to_token out-of-bounds" test_unigram_id_to_token_oob;
test "unigram empty vocab" test_unigram_empty_vocab;
test "unigram special tokens" test_unigram_special_tokens;
test "unigram encode sequence" test_unigram_encode_sequence;
test "pad token set at construction" test_pad_token_set_at_construction;
]
let () = run "brot tokenization" [ group "tokenization" tokenization_tests ]
================================================
FILE: packages/brot/test/test_unicode.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Unicode processing tests for brot *)
open Windtrap
open Brot
(* Normalization via public API *)
let test_lowercase_normalization () =
let text = "HELLO WORLD" in
let normalizer = Normalizer.lowercase in
let result = Normalizer.apply normalizer text in
equal ~msg:"lowercase" string "hello world" result
let test_strip_accents_normalization () =
let text = "caf\xC3\xA9 na\xC3\xAFve r\xC3\xA9sum\xC3\xA9" in
let normalizer =
Normalizer.sequence [ Normalizer.nfd; Normalizer.strip_accents ]
in
let result = Normalizer.apply normalizer text in
equal ~msg:"strip accents" string "cafe naive resume" result
let test_normalization_sequence () =
let text = " HELLO World " in
let normalizer =
Normalizer.sequence
[
Normalizer.lowercase;
Normalizer.strip ();
Normalizer.replace ~pattern:"\\s+" ~replacement:" ";
]
in
let result = Normalizer.apply normalizer text in
equal ~msg:"sequence" string "hello world" result
(* Integration with Tokenizer *)
let test_tokenize_with_normalization () =
let text = "HELLO WORLD!" in
let normalizer =
Normalizer.sequence
[
Normalizer.lowercase;
Normalizer.replace ~pattern:"\\s+" ~replacement:" ";
]
in
let tokenizer =
word_level ~normalizer ~pre:(Pre_tokenizer.whitespace ()) ()
in
let tokenizer = add_tokens tokenizer [ "hello"; "world"; "!" ] in
let tokens = encode tokenizer text |> Encoding.tokens |> Array.to_list in
equal ~msg:"normalized tokenization" (list string) [ "hello"; "world"; "!" ]
tokens
let test_tokenize_unicode_words () =
let text = "café résumé naïve" in
let tokenizer = word_level ~pre:(Pre_tokenizer.whitespace ()) () in
let tokenizer = add_tokens tokenizer [ "café"; "résumé"; "naïve" ] in
let tokens = encode tokenizer text |> Encoding.tokens |> Array.to_list in
equal ~msg:"tokenized unicode" bool true (List.length tokens > 0)
let test_malformed_unicode () =
let text = "Hello" ^ String.make 1 '\xFF' ^ String.make 1 '\xFE' ^ "World" in
let tokenizer = chars () in
let tokens = encode tokenizer text |> Encoding.tokens |> Array.to_list in
equal ~msg:"handled malformed" bool true (List.length tokens > 0)
(* Test Suite *)
let unicode_tests =
[
(* Normalization *)
test "lowercase normalization" test_lowercase_normalization;
test "strip accents normalization" test_strip_accents_normalization;
test "normalization sequence" test_normalization_sequence;
(* Integration *)
test "tokenize with normalization" test_tokenize_with_normalization;
test "tokenize unicode words" test_tokenize_unicode_words;
(* Error handling *)
test "malformed unicode" test_malformed_unicode;
]
let () = run "brot unicode" [ group "unicode" unicode_tests ]
================================================
FILE: packages/brot/test/test_vocab.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
open Windtrap
open Brot
let test_vocab_create_empty () =
let tokenizer = word_level () in
let vocab = vocab tokenizer in
equal ~msg:"empty vocab size" int 0 (List.length vocab)
let test_vocab_with_tokenizer () =
let tokenizer = word_level () in
let vocab = vocab tokenizer in
equal ~msg:"initial vocab size" int 0 (List.length vocab)
let test_vocab_add_tokens () =
let tokenizer =
add_tokens
(word_level ~specials:[ special "<s>"; special "</s>" ] ())
[ "hello"; "world" ]
in
let vocab_size = vocab_size tokenizer in
equal ~msg:"vocab size increased" bool true (vocab_size >= 2)
let test_vocab_encode_decode () =
let tokenizer =
add_tokens
(word_level ~pre:(Pre_tokenizer.whitespace ()) ())
[ "hello"; "world" ]
in
let ids = encode tokenizer "hello world" |> Encoding.ids in
equal ~msg:"encoded ids" bool true (Array.length ids > 0);
let decoded = decode tokenizer ids in
equal ~msg:"decoded text" string "hello world" decoded
let test_vocab_batch_encode () =
let tokenizer = add_tokens (Brot.word_level ()) [ "hello"; "world" ] in
let encodings = encode_batch tokenizer [ "hello"; "world" ] in
equal ~msg:"batch size" int 2 (List.length encodings)
let test_vocab_special_tokens () =
let tokenizer =
add_tokens
(word_level ~specials:[ special "[CLS]"; special "[SEP]" ] ())
[ "test" ]
in
let tokens =
encode ~add_special_tokens:true tokenizer "test" |> Encoding.tokens
in
equal ~msg:"tokens emitted" bool true (Array.length tokens > 0)
let test_vocab_save_load () =
let tokenizer =
add_tokens (Brot.word_level ()) [ "hello"; "world"; "test" ]
in
let json = to_json tokenizer in
match from_json json with
| Error msg -> failf "failed to round-trip tokenizer: %s" msg
| Ok reloaded ->
let original_vocab = vocab tokenizer in
let loaded_vocab = vocab reloaded in
equal ~msg:"vocab size matches" int
(List.length original_vocab)
(List.length loaded_vocab);
List.iter
(fun (token, _) ->
equal
~msg:(Printf.sprintf "token %s preserved" token)
bool true
(Option.is_some (token_to_id reloaded token)))
original_vocab
let suite =
[
test "create empty" test_vocab_create_empty;
test "with tokenizer" test_vocab_with_tokenizer;
test "add tokens" test_vocab_add_tokens;
test "encode decode" test_vocab_encode_decode;
test "batch encode" test_vocab_batch_encode;
test "special tokens" test_vocab_special_tokens;
test "save load" test_vocab_save_load;
]
let () = run "Vocabulary tests" [ group "vocab" suite ]
================================================
FILE: packages/brot/test/test_wordpiece.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
open Windtrap
open Brot
let test_wordpiece_basic () =
(* Create a simple vocabulary *)
let vocab =
[
("[UNK]", 0);
("hello", 1);
("world", 2);
("##llo", 3);
("##rld", 4);
("he", 5);
("wo", 6);
]
in
let tokenizer =
wordpiece ~vocab ~unk_token:"[UNK]" ~continuing_subword_prefix:"##" ()
in
(* Test tokenizing a known word *)
let encoding = encode tokenizer "hello" in
let tokens = Encoding.tokens encoding in
equal ~msg:"single token for 'hello'" int 1 (Array.length tokens);
equal ~msg:"token value" string "hello" tokens.(0);
Printf.printf "Tokenized 'hello': ";
Array.iter (Printf.printf "%s ") tokens;
Printf.printf "\n";
equal ~msg:"vocabulary size" int 7 (vocab_size tokenizer)
let test_wordpiece_subwords () =
(* Create vocabulary with subword pieces *)
let vocab =
[
("[UNK]", 0);
("un", 1);
("##able", 2);
("##happy", 3);
("play", 4);
("##ing", 5);
("##ed", 6);
]
in
let tokenizer = wordpiece ~vocab ~unk_token:"[UNK]" () in
(* Test word that can be split into subwords *)
let encoding = encode tokenizer "playing" in
let tokens = Encoding.tokens encoding in
Printf.printf "Tokenized 'playing': ";
Array.iter (Printf.printf "%s ") tokens;
Printf.printf "\n";
equal ~msg:"should split into subwords" int 2 (Array.length tokens);
equal ~msg:"first token" string "play" tokens.(0);
equal ~msg:"second token" string "##ing" tokens.(1)
let test_wordpiece_unknown () =
(* Create minimal vocabulary *)
let vocab = [ ("[UNK]", 0); ("hello", 1) ] in
let tokenizer = wordpiece ~vocab ~unk_token:"[UNK]" () in
(* Test unknown word *)
let encoding = encode tokenizer "goodbye" in
let tokens = Encoding.tokens encoding in
equal ~msg:"unknown word becomes single token" int 1 (Array.length tokens);
equal ~msg:"unknown token" string "[UNK]" tokens.(0)
let test_wordpiece_max_chars () =
(* Create vocabulary *)
let vocab = [ ("[UNK]", 0); ("test", 1) ] in
let tokenizer =
wordpiece ~vocab ~unk_token:"[UNK]" ~max_input_chars_per_word:5 ()
in
(* Test word exceeding max chars *)
let long_word = String.make 10 'a' in
let encoding = encode tokenizer long_word in
let tokens = Encoding.tokens encoding in
equal ~msg:"long word becomes unknown" int 1 (Array.length tokens);
equal ~msg:"unknown token" string "[UNK]" tokens.(0)
let test_wordpiece_save_load () =
(* Create vocabulary *)
let vocab =
[
("[PAD]", 0);
("[UNK]", 1);
("[CLS]", 2);
("[SEP]", 3);
("hello", 4);
("world", 5);
]
in
let tokenizer = wordpiece ~vocab ~unk_token:"[UNK]" () in
(* Save the model *)
let temp_dir = Filename.temp_dir "wordpiece_test" "" in
let files = save_model_files tokenizer ~folder:temp_dir () in
(* Load the model *)
let vocab_file = List.find (fun f -> Filename.check_suffix f ".txt") files in
let loaded_tokenizer = from_model_file ~vocab:vocab_file () in
(* Test that loaded tokenizer works the same *)
let original_tokens = encode tokenizer "hello" |> Encoding.tokens in
let loaded_tokens = encode loaded_tokenizer "hello" |> Encoding.tokens in
equal ~msg:"same number of tokens" int
(Array.length original_tokens)
(Array.length loaded_tokens);
(* Clean up *)
List.iter Sys.remove files;
Unix.rmdir temp_dir
let test_tokenizer_integration () =
(* Create a WordPiece tokenizer using the high-level API *)
let vocab =
[
("[PAD]", 0);
("[UNK]", 1);
("[CLS]", 2);
("[SEP]", 3);
("hello", 4);
("world", 5);
("##ing", 6);
]
in
let tokenizer = wordpiece ~vocab ~unk_token:"[UNK]" () in
(* Test encoding *)
let tokens = encode tokenizer "hello" |> Encoding.tokens |> Array.to_list in
Printf.printf "wordpiece result: ";
List.iter (Printf.printf "%s ") tokens;
Printf.printf "\n";
equal ~msg:"tokenizer produces output" bool true (List.length tokens > 0)
let test_wordpiece_greedy_matching () =
(* Test the greedy longest-match-first algorithm *)
let vocab =
[
("[UNK]", 0);
("un", 1);
("able", 2);
("unable", 3);
(* Longer match should be preferred *)
("##able", 4);
]
in
let tokenizer = wordpiece ~vocab ~unk_token:"[UNK]" () in
(* Should match "unable" as a single token, not "un" + "##able" *)
let encoding = encode tokenizer "unable" in
let tokens = Encoding.tokens encoding in
equal ~msg:"greedy match finds longest token" int 1 (Array.length tokens);
equal ~msg:"matched full word" string "unable" tokens.(0)
let () =
run "WordPiece tests"
[
group "basic"
[
test "basic tokenization" test_wordpiece_basic;
test "subword tokenization" test_wordpiece_subwords;
test "unknown tokens" test_wordpiece_unknown;
test "max input chars" test_wordpiece_max_chars;
test "save and load" test_wordpiece_save_load;
test "tokenizer integration" test_tokenizer_integration;
test "greedy matching" test_wordpiece_greedy_matching;
];
]
================================================
FILE: packages/dune
================================================
(dirs :standard \ nx-oxcaml)
================================================
FILE: packages/fehu/README.md
================================================
# Fehu
Reinforcement learning environment toolkit for OCaml, built on [Rune](../rune/).
Fehu provides type-safe environments, composable wrappers, trajectory
collection, replay buffers, GAE computation, policy evaluation, and
vectorized environments. It follows the Gymnasium interface pattern:
environments expose `reset` and `step` with typed observation and action
spaces.
## Quick Start
Create an environment, run a random policy, and evaluate:
```ocaml
open Fehu
let () =
let rng = Rune.Rng.key 42 in
let env = Fehu_envs.Cartpole.make ~rng () in
(* Run one episode *)
let _obs, _info = Env.reset env () in
let done_ = ref false in
let total_reward = ref 0.0 in
while not !done_ do
let act, _ = Space.sample (Env.action_space env)
~rng:(Env.take_rng env) in
let s = Env.step env act in
total_reward := !total_reward +. s.reward;
done_ := s.terminated || s.truncated
done;
Printf.printf "Episode reward: %.0f\n" !total_reward;
(* Evaluate over 10 episodes *)
let stats = Eval.run env
~policy:(fun _obs ->
let act, _ = Space.sample (Env.action_space env)
~rng:(Env.take_rng env) in act)
~n_episodes:10 ()
in
Printf.printf "Mean reward: %.1f (std: %.1f)\n"
stats.mean_reward stats.std_reward
```
## Features
- **Environments**: typed `('obs, 'act, 'render) Env.t` with lifecycle enforcement (reset before step, auto-guard on terminal states)
- **Spaces**: Discrete, Box, Multi_binary, Multi_discrete, Tuple, Dict, Sequence, Text with sampling, validation, and serialization
- **Wrappers**: `map_observation`, `map_action`, `map_reward`, `clip_action`, `clip_observation`, `time_limit`, and custom wrappers via `Env.wrap`
- **Trajectory collection**: `Collect.rollout` and `Collect.episodes` in structure-of-arrays form with automatic episode resets
- **Replay buffers**: fixed-capacity circular buffer with uniform random sampling (`Buffer.sample`, `Buffer.sample_arrays`)
- **GAE**: generalized advantage estimation with proper terminated/truncated handling (`Gae.compute`, `Gae.returns`)
- **Evaluation**: `Eval.run` computes mean/std reward and episode length over multiple episodes
- **Vectorized environments**: `Vec_env.create` runs multiple environments with batched step and auto-reset
- **Rendering**: `Render.image` and `Render.rollout` for frame capture, `Env.on_render` for recording
- **Built-in environments**: CartPole-v1, MountainCar-v0, GridWorld, RandomWalk
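For instance, advantage estimation and replay storage compose as sketched below (call shapes taken from this repo's benchmark code; the literal numbers are illustrative only):

```ocaml
module Gae = Fehu.Gae
module Buffer = Fehu.Buffer

(* GAE over a three-step rollout that ends in termination. *)
let _advantages =
  Gae.compute
    ~rewards:[| 1.0; 1.0; 1.0 |]
    ~values:[| 0.5; 0.6; 0.7 |]
    ~terminated:[| false; false; true |]
    ~truncated:[| false; false; false |]
    ~next_values:[| 0.6; 0.7; 0.0 |]
    ~gamma:0.99 ~lambda:0.95

(* A small replay buffer of float observations and actions. *)
let () =
  let buf : (float, float) Buffer.t = Buffer.create ~capacity:100 in
  Buffer.add buf
    {
      Buffer.observation = 0.0;
      action = 1.0;
      reward = 1.0;
      next_observation = 0.1;
      terminated = false;
      truncated = false;
    }
```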
## Libraries
| Library | opam package | Description |
|---------|-------------|-------------|
| `fehu` | `fehu` | Core: environments, spaces, wrappers, collection, buffers, GAE, evaluation |
| `fehu-envs` | `fehu.envs` | Built-in environments (CartPole, MountainCar, GridWorld, RandomWalk) |
## Built-in Environments
| Environment | Observation | Actions | Reward | Termination |
|-------------|------------|---------|--------|-------------|
| CartPole | Box [4] (x, v, θ, ω) | Discrete 2 | +1.0 per step | Pole angle beyond ±12° or cart position beyond ±2.4, truncated at 500 |
| MountainCar | Box [2] (position, velocity) | Discrete 3 | −1.0 per step | Position ≥ 0.5 with v ≥ 0, truncated at 200 |
| GridWorld | Multi_discrete [5; 5] | Discrete 4 | +10 at goal, −1 otherwise | Reach (4,4), truncated at 200 |
| RandomWalk | Box [1] | Discrete 2 | −|position| | None, truncated at 200 |
## Contributing
See the [Raven monorepo README](../README.md) for guidelines.
## License
ISC License. See [LICENSE](../LICENSE) for details.
================================================
FILE: packages/fehu/bench/bench_fehu.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
module Gae = Fehu.Gae
module Buffer = Fehu.Buffer
let gamma = 0.99
let lambda = 0.95
let make_arrays n =
let rewards = Array.init n (fun i -> Float.of_int (i mod 5)) in
let values = Array.init n (fun i -> Float.of_int (i mod 10) *. 0.1) in
let terminated = Array.init n (fun i -> i mod 50 = 49) in
let truncated = Array.init n (fun i -> i mod 100 = 99) in
let next_values =
Array.init n (fun i -> Float.of_int ((i + 1) mod 10) *. 0.1)
in
(rewards, values, terminated, truncated, next_values)
let gae_benchmarks () =
let sizes = [ ("256", 256); ("1024", 1024); ("4096", 4096) ] in
let benches = ref [] in
List.iter
(fun (label, n) ->
let rewards, values, terminated, truncated, next_values = make_arrays n in
benches :=
Thumper.bench (Printf.sprintf "compute n=%s" label) (fun () ->
Gae.compute ~rewards ~values ~terminated ~truncated ~next_values
~gamma ~lambda)
:: !benches)
sizes;
List.rev !benches
let gae_from_values_benchmarks () =
let sizes = [ ("256", 256); ("1024", 1024); ("4096", 4096) ] in
let benches = ref [] in
List.iter
(fun (label, n) ->
let rewards, values, terminated, truncated, _ = make_arrays n in
benches :=
Thumper.bench (Printf.sprintf "compute_from_values n=%s" label)
(fun () ->
Gae.compute_from_values ~rewards ~values ~terminated ~truncated
~last_value:0.0 ~gamma ~lambda)
:: !benches)
sizes;
List.rev !benches
let returns_benchmarks () =
let sizes = [ ("256", 256); ("1024", 1024); ("4096", 4096) ] in
let benches = ref [] in
List.iter
(fun (label, n) ->
let rewards, _, terminated, truncated, _ = make_arrays n in
benches :=
Thumper.bench (Printf.sprintf "returns n=%s" label) (fun () ->
Gae.returns ~rewards ~terminated ~truncated ~gamma)
:: !benches)
sizes;
List.rev !benches
let normalize_benchmarks () =
let sizes = [ ("256", 256); ("1024", 1024); ("4096", 4096) ] in
let benches = ref [] in
List.iter
(fun (label, n) ->
let arr = Array.init n (fun i -> Float.of_int i *. 0.01) in
benches :=
Thumper.bench (Printf.sprintf "normalize n=%s" label) (fun () ->
Gae.normalize arr)
:: !benches)
sizes;
List.rev !benches
let fill_buffer capacity =
let buf : (float, float) Buffer.t = Buffer.create ~capacity in
for i = 0 to capacity - 1 do
Buffer.add buf
{
Buffer.observation = Float.of_int i;
action = Float.of_int (i mod 4);
reward = Float.of_int (i mod 10) *. 0.1;
next_observation = Float.of_int (i + 1);
terminated = i mod 50 = 49;
truncated = i mod 100 = 99;
}
done;
buf
let buffer_add_benchmarks () =
let capacity = 10000 in
let buf = fill_buffer capacity in
let tr =
{
Buffer.observation = 0.0;
action = 0.0;
reward = 1.0;
next_observation = 1.0;
terminated = false;
truncated = false;
}
in
[ Thumper.bench "add (full buffer, cap=10000)" (fun () -> Buffer.add buf tr) ]
let buffer_create_benchmarks () =
let sizes = [ ("100", 100); ("1000", 1000); ("10000", 10000) ] in
List.map
(fun (label, n) ->
Thumper.bench (Printf.sprintf "create+fill cap=%s" label) (fun () ->
fill_buffer n))
sizes
let build_benchmarks () =
[
Thumper.group "GAE" (gae_benchmarks ());
Thumper.group "GAE from values" (gae_from_values_benchmarks ());
Thumper.group "Returns" (returns_benchmarks ());
Thumper.group "Normalize" (normalize_benchmarks ());
Thumper.group "Buffer add" (buffer_add_benchmarks ());
Thumper.group "Buffer create" (buffer_create_benchmarks ());
]
let () =
let benchmarks = build_benchmarks () in
Thumper.run "fehu" benchmarks
================================================
FILE: packages/fehu/bench/dune
================================================
(executable
(name bench_fehu)
(libraries nx fehu thumper))
(rule
(alias runtest)
(action
(progn
(run %{exe:bench_fehu.exe} -q)
(diff? fehu.thumper fehu.thumper.corrected))))
================================================
FILE: packages/fehu/bench/fehu.thumper
================================================
# thumper baseline
# version: 1
# suite_name: fehu
# host: 1480401c3b76ed18
# cpu: Apple M1 Max
# ocaml: 5.4.1
# git: 31747323
# dirty: true
# command: /Users/tmattio/Workspace/raven/_build/default/packages/fehu/bench/bench_fehu.exe --bless --quick
buffer_add/add__full_buffer__cap_10000_ alloc_words 0.000000e+00 0.000000e+00 0.000000e+00 inf 5 0
buffer_add/add__full_buffer__cap_10000_ cpu_time 1.226649e-08 1.223096e-08 1.229882e-08 2.766164e-03 5 0
buffer_add/add__full_buffer__cap_10000_ wall_time 1.227909e-08 1.224626e-08 1.230893e-08 2.552074e-03 5 0
buffer_create/create_fill_cap_100 alloc_words 2.116000e+03 2.116000e+03 2.116000e+03 0.000000e+00 5 0
buffer_create/create_fill_cap_100 cpu_time 1.430951e-06 1.419749e-06 1.441342e-06 7.544736e-03 5 0
buffer_create/create_fill_cap_100 wall_time 1.433376e-06 1.418645e-06 1.445267e-06 9.286558e-03 5 1
buffer_create/create_fill_cap_1000 alloc_words 2.101600e+04 2.101600e+04 2.101600e+04 0.000000e+00 5 0
buffer_create/create_fill_cap_1000 cpu_time 1.815813e-05 1.805756e-05 1.829634e-05 6.575071e-03 5 1
buffer_create/create_fill_cap_1000 wall_time 1.817186e-05 1.806923e-05 1.829684e-05 6.262546e-03 5 1
buffer_create/create_fill_cap_10000 alloc_words 2.100160e+05 2.100160e+05 2.100160e+05 0.000000e+00 5 0
buffer_create/create_fill_cap_10000 cpu_time 1.670111e-04 1.660089e-04 1.677894e-04 5.330490e-03 5 0
buffer_create/create_fill_cap_10000 wall_time 1.673292e-04 1.663936e-04 1.679940e-04 4.782197e-03 5 1
gae/compute_n_1024 alloc_words 2.053000e+03 2.053000e+03 2.053000e+03 0.000000e+00 5 0
gae/compute_n_1024 cpu_time 4.432903e-06 4.388681e-06 4.472044e-06 9.402727e-03 5 0
gae/compute_n_1024 wall_time 4.434699e-06 4.389286e-06 4.487822e-06 1.110972e-02 5 0
gae/compute_n_256 alloc_words 5.170000e+02 5.170000e+02 5.170000e+02 0.000000e+00 5 0
gae/compute_n_256 cpu_time 8.899175e-07 8.873495e-07 8.928650e-07 3.098850e-03 5 1
gae/compute_n_256 wall_time 8.917142e-07 8.886773e-07 8.955820e-07 3.871563e-03 5 1
gae/compute_n_4096 alloc_words 8.197000e+03 8.197000e+03 8.197000e+03 0.000000e+00 5 0
gae/compute_n_4096 cpu_time 1.953902e-05 1.938466e-05 1.977415e-05 9.967171e-03 5 0
gae/compute_n_4096 wall_time 1.954775e-05 1.938586e-05 1.974381e-05 9.155682e-03 5 0
gae_from_values/compute_from_values_n_1024 alloc_words 5.130000e+03 5.130000e+03 5.130000e+03 0.000000e+00 5 0
gae_from_values/compute_from_values_n_1024 cpu_time 7.572262e-06 7.451256e-06 7.687885e-06 1.562475e-02 5 0
gae_from_values/compute_from_values_n_1024 wall_time 7.577999e-06 7.456862e-06 7.698184e-06 1.592253e-02 5 0
gae_from_values/compute_from_values_n_256 alloc_words 1.290000e+03 1.290000e+03 1.290000e+03 0.000000e+00 5 0
gae_from_values/compute_from_values_n_256 cpu_time 1.573065e-06 1.567452e-06 1.579053e-06 3.687355e-03 5 0
gae_from_values/compute_from_values_n_256 wall_time 1.573980e-06 1.567490e-06 1.582763e-06 4.851611e-03 5 0
gae_from_values/compute_from_values_n_4096 alloc_words 2.049000e+04 2.049000e+04 2.049000e+04 0.000000e+00 5 0
gae_from_values/compute_from_values_n_4096 cpu_time 3.301526e-05 3.259093e-05 3.336488e-05 1.172117e-02 5 0
gae_from_values/compute_from_values_n_4096 wall_time 3.310795e-05 3.264492e-05 3.349818e-05 1.288613e-02 5 0
normalize/normalize_n_1024 alloc_words 3.083000e+03 3.083000e+03 3.083000e+03 0.000000e+00 5 0
normalize/normalize_n_1024 cpu_time 8.499536e-06 8.421020e-06 8.595588e-06 1.026924e-02 5 2
normalize/normalize_n_1024 wall_time 8.507517e-06 8.425891e-06 8.596021e-06 9.998763e-03 5 2
normalize/normalize_n_256 alloc_words 7.790000e+02 7.790000e+02 7.790000e+02 0.000000e+00 5 0
normalize/normalize_n_256 cpu_time 1.998406e-06 1.990551e-06 2.005793e-06 3.813604e-03 5 0
normalize/normalize_n_256 wall_time 1.999451e-06 1.990430e-06 2.006996e-06 4.142583e-03 5 0
normalize/normalize_n_4096 alloc_words 1.229900e+04 1.229900e+04 1.229900e+04 0.000000e+00 5 0
normalize/normalize_n_4096 cpu_time 3.403318e-05 3.376822e-05 3.423778e-05 6.898544e-03 5 0
normalize/normalize_n_4096 wall_time 3.405293e-05 3.381521e-05 3.430319e-05 7.164939e-03 5 0
returns/returns_n_1024 alloc_words 1.025000e+03 1.025000e+03 1.025000e+03 0.000000e+00 5 0
returns/returns_n_1024 cpu_time 2.979521e-06 2.946996e-06 3.001828e-06 9.201606e-03 5 0
returns/returns_n_1024 wall_time 2.994462e-06 2.963207e-06 3.022511e-06 9.902417e-03 5 0
returns/returns_n_256 alloc_words 2.570000e+02 2.570000e+02 2.570000e+02 0.000000e+00 5 0
returns/returns_n_256 cpu_time 6.113050e-07 6.027021e-07 6.163104e-07 1.113047e-02 5 1
returns/returns_n_256 wall_time 6.123541e-07 6.072009e-07 6.175059e-07 8.414242e-03 5 1
returns/returns_n_4096 alloc_words 4.097000e+03 4.097000e+03 4.097000e+03 0.000000e+00 5 0
returns/returns_n_4096 cpu_time 1.194814e-05 1.181334e-05 1.212174e-05 1.290600e-02 5 0
returns/returns_n_4096 wall_time 1.196891e-05 1.180767e-05 1.211994e-05 1.304521e-02 5 0
================================================
FILE: packages/fehu/doc/01-getting-started.md
================================================
# Getting Started
This guide covers the basics: creating environments, running the step loop,
understanding spaces, and using the built-in environments.
## Installation
```bash
opam install fehu
```
Or build from source:
```bash
git clone https://github.com/raven-ml/raven
cd raven && dune build fehu
```
## Creating an Environment
Environments are created via factory functions in `Fehu_envs`. Randomness is
provided by the implicit RNG scope from `Nx.Rng.run`:
```ocaml
open Fehu
let () = Nx.Rng.run ~seed:42 @@ fun () ->
let env = Fehu_envs.Cartpole.make () in
ignore env
```
The seed controls all randomness in the scope. Use the same seed to get
the same episode sequence.
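A small sketch of this determinism, using only the space API covered later in this guide: running the same seeded scope twice produces the same samples.

```ocaml
open Fehu

(* Two runs of the same seeded scope draw identical samples *)
let sample_once seed =
  Nx.Rng.run ~seed @@ fun () ->
  Space.Discrete.to_int (Space.sample (Space.Discrete.create 4))

let () = assert (sample_once 42 = sample_once 42)
```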
## The Step Loop
An environment follows a strict lifecycle: `reset` must be called before the
first `step`, and again after any terminal step (terminated or truncated).
```ocaml
open Fehu
let () = Nx.Rng.run ~seed:42 @@ fun () ->
let env = Fehu_envs.Cartpole.make () in
(* Reset returns the initial observation and info *)
let _obs, _info = Env.reset env () in
(* Step returns observation, reward, terminated, truncated, info *)
let s = Env.step env (Space.Discrete.of_int 0) in
Printf.printf "reward: %.1f, terminated: %b, truncated: %b\n"
s.reward s.terminated s.truncated
```
A complete episode loop:
```ocaml
open Fehu
let run_episode env =
let _obs, _info = Env.reset env () in
let done_ = ref false in
let total_reward = ref 0.0 in
while not !done_ do
let act = Space.sample (Env.action_space env) in
let s = Env.step env act in
total_reward := !total_reward +. s.reward;
done_ := s.terminated || s.truncated
done;
!total_reward
let () = Nx.Rng.run ~seed:42 @@ fun () ->
let env = Fehu_envs.Cartpole.make () in
let _reward = run_episode env in ()
```
## Spaces
Spaces define the valid observations and actions for an environment. They
provide sampling, validation, and serialization.
### Discrete
Integer choices. Used for environments with a finite number of actions (e.g.,
left/right).
```ocaml
open Fehu
let space = Space.Discrete.create 4 (* actions 0, 1, 2, 3 *)
let _n = Space.Discrete.n space (* 4 *)
(* Sample a random action (requires an Nx.Rng scope) *)
let _act = Nx.Rng.run ~seed:0 @@ fun () ->
Space.sample space
(* Convert between int and discrete element *)
let act = Space.Discrete.of_int 2
let _i = Space.Discrete.to_int act (* 2 *)
(* Check membership *)
let _valid = Space.contains space act (* true *)
```
### Box
Continuous vectors with per-dimension bounds. Used for continuous observations
(e.g., position, velocity) and continuous actions.
```ocaml
open Fehu
let space = Space.Box.create
~low:[| -1.0; -2.0 |]
~high:[| 1.0; 2.0 |]
let _low, _high = Space.Box.bounds space
let _obs = Nx.Rng.run ~seed:0 @@ fun () -> Space.sample space
```
### Other Space Types
- **Multi_binary**: binary vectors of fixed length (multi-label scenarios)
- **Multi_discrete**: multiple discrete axes with independent cardinalities
- **Tuple**: fixed-length heterogeneous sequences
- **Dict**: named fields with different space types
- **Sequence**: variable-length homogeneous sequences
- **Text**: character strings from a fixed alphabet
All spaces support `contains`, `sample`, `pack`/`unpack` (to/from the
universal `Value.t` type), and `boundary_values`.
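For instance, a sample can be round-tripped through `pack`/`unpack` (a sketch; `unpack` returns a `result`):

```ocaml
open Fehu

let () = Nx.Rng.run ~seed:0 @@ fun () ->
  let space = Space.Discrete.create 4 in
  let sample = Space.sample space in
  (* pack converts to the universal Value.t; unpack converts back *)
  let packed = Space.pack space sample in
  match Space.unpack space packed with
  | Ok restored -> assert (Space.contains space restored)
  | Error msg -> failwith msg
```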
## Available Environments
### CartPole
Classic cart-pole balancing. Push a cart left or right to keep a pole upright.
Reward is +1.0 per step. Terminates when the pole angle exceeds +/-12 degrees or
the cart position leaves the +/-2.4 range. Truncates at 500 steps.
- **Observation**: Box [4] -- x, x_dot, theta, theta_dot
- **Actions**: Discrete 2 -- 0 = push left, 1 = push right
```ocaml
let _env = Nx.Rng.run ~seed:42 @@ fun () -> Fehu_envs.Cartpole.make ()
```
### MountainCar
A car in a valley must build momentum to climb a hill. Reward is -1.0 per
step. Terminates when position >= 0.5 with non-negative velocity. Truncates at
200 steps.
- **Observation**: Box [2] -- position, velocity
- **Actions**: Discrete 3 -- 0 = push left, 1 = coast, 2 = push right
```ocaml
let _env = Nx.Rng.run ~seed:42 @@ fun () -> Fehu_envs.Mountain_car.make ()
```
### GridWorld
5x5 grid navigation with an obstacle. Agent starts at (0,0), goal at (4,4),
obstacle at (2,2). Reward is +10.0 at goal, -1.0 otherwise. Truncates at 200
steps.
- **Observation**: Multi_discrete [5; 5] -- (row, col)
- **Actions**: Discrete 4 -- 0 = up, 1 = down, 2 = left, 3 = right
```ocaml
let _env = Nx.Rng.run ~seed:42 @@ fun () -> Fehu_envs.Grid_world.make ()
```
### RandomWalk
One-dimensional random walk on [-10, 10]. Reward is -|position|. Terminates at
boundaries or after 200 steps.
- **Observation**: Box [1] in [-10.0, 10.0]
- **Actions**: Discrete 2 -- 0 = left, 1 = right
```ocaml
let _env = Nx.Rng.run ~seed:42 @@ fun () -> Fehu_envs.Random_walk.make ()
```
## Render Modes
Environments can optionally render their state. Pass `~render_mode` when
creating the environment:
```ocaml
open Fehu
let () = Nx.Rng.run ~seed:42 @@ fun () ->
let env = Fehu_envs.Cartpole.make
~render_mode:`Ansi () in
let _obs, _info = Env.reset env () in
let _s = Env.step env (Space.Discrete.of_int 0) in
(* Render after reset or step *)
match Env.render env with
| Some text -> print_endline text
| None -> ()
```
Supported render modes vary by environment: `Ansi` for text output,
`Rgb_array` for pixel frames, `Human` for interactive display.
## Next Steps
- [Environments and Wrappers](../02-environments/) -- custom environments, wrappers, rendering, vectorized environments
- [Collection and Evaluation](../03-collection-and-evaluation/) -- trajectory collection, replay buffers, GAE, evaluation
================================================
FILE: packages/fehu/doc/02-environments.md
================================================
# Environments and Wrappers
This guide covers creating custom environments, composing wrappers, rendering,
and running vectorized environments.
## The Env.t Type
An environment `('obs, 'act, 'render) Env.t` is parameterized by its
observation type, action type, and render type. The type system ensures that
policies, wrappers, and collection utilities all agree on these types.
The lifecycle is strict:
1. Call `Env.reset` to get the initial observation
2. Call `Env.step` with an action to advance one timestep
3. When `terminated` or `truncated` is true, call `Env.reset` again
4. Call `Env.close` when done (optional, releases resources)
Calling `step` before `reset`, or after a terminal step without resetting,
raises `Invalid_argument`.
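For example, stepping a freshly created environment without resetting first raises (a sketch):

```ocaml
open Fehu

let () = Nx.Rng.run ~seed:0 @@ fun () ->
  let env = Fehu_envs.Cartpole.make () in
  (* No Env.reset yet, so this step violates the lifecycle *)
  match Env.step env (Space.Discrete.of_int 0) with
  | _ -> assert false
  | exception Invalid_argument _ -> ()
```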
## Creating Custom Environments
Use `Env.create` to build an environment from `reset` and `step` functions.
Both receive the environment handle as their first argument, which provides
access to spaces and lifecycle state. Random keys are drawn from the implicit
RNG scope (see below).
```ocaml
open Fehu
(* A simple counting environment: agent must choose action 1 *)
let make_counter () =
let count = ref 0 in
Env.create
~id:"Counter-v0"
~observation_space:(Space.Discrete.create 100)
~action_space:(Space.Discrete.create 2)
~reset:(fun _env ?options:_ () ->
count := 0;
Space.Discrete.of_int 0, Info.empty)
~step:(fun _env action ->
let a = Space.Discrete.to_int action in
if a = 1 then incr count else count := 0;
let obs = Space.Discrete.of_int !count in
let terminated = !count >= 10 in
Env.step_result ~observation:obs
~reward:(if a = 1 then 1.0 else -1.0)
~terminated ())
()
```
### RNG Management
Environments draw random keys from the implicit RNG scope established by
`Nx.Rng.run`. Any call to `Space.sample` or other random operations inside
`reset` and `step` callbacks will use this scope automatically:
```ocaml
let make_noisy_env () =
Env.create
~observation_space:(Space.Box.create
~low:[| 0.0 |] ~high:[| 1.0 |])
~action_space:(Space.Discrete.create 2)
~reset:(fun env ?options:_ () ->
let obs = Space.sample (Env.observation_space env) in
obs, Info.empty)
~step:(fun env _action ->
let obs = Space.sample (Env.observation_space env) in
Env.step_result ~observation:obs
~reward:1.0 ())
()
```
## Wrappers
Wrappers transform an environment's observations, actions, or rewards without
modifying the inner environment. They compose: wrap a wrapper to stack
transformations.
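For example, a stack of wrappers can be built by chaining (a sketch, assuming `base_env` is an existing environment):

```ocaml
open Fehu

(* Each wrapper wraps the environment produced by the previous one *)
let env =
  base_env
  |> Env.time_limit ~max_episode_steps:500
  |> Env.map_reward ~f:(fun ~reward ~info -> (reward *. 0.01, info))
```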
### map_observation
Transform observations from reset and step:
```ocaml
open Fehu
(* Normalize observations to [0, 1] *)
let env = Env.map_observation
~observation_space:(Space.Box.create
~low:[| 0.0; 0.0; 0.0; 0.0 |]
~high:[| 1.0; 1.0; 1.0; 1.0 |])
~f:(fun obs info ->
(* obs is a float32 tensor, transform it *)
let normalized = normalize_fn obs in
normalized, info)
env
```
The function `f` receives both the observation and the info dictionary,
returning both. This allows wrappers to pass metadata through info.
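As a sketch of passing metadata through info, a wrapper can attach a step counter while leaving the observation and its space unchanged (using `Info.set`/`Info.float`):

```ocaml
(* Count observations seen and record the count in info *)
let with_step_count env =
  let count = ref 0 in
  Env.map_observation
    ~observation_space:(Env.observation_space env)
    ~f:(fun obs info ->
      incr count;
      (obs, Info.set "step_count" (Info.float (float_of_int !count)) info))
    env
```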
### map_action
Transform actions before they reach the inner environment:
```ocaml
(* Remap discrete actions *)
let env = Env.map_action
~action_space:(Space.Discrete.create 3)
~f:(fun act ->
(* Map from 3-action to 2-action space *)
let i = Space.Discrete.to_int act in
Space.Discrete.of_int (if i >= 2 then 1 else i))
env
```
### map_reward
Transform rewards after each step:
```ocaml
(* Scale rewards *)
let env = Env.map_reward
~f:(fun ~reward ~info -> reward *. 0.01, info)
env
```
### clip_action
Clamp continuous actions to the action space bounds. The wrapper relaxes the
action space to accept any float values, then clips before forwarding:
```ocaml
(* Works with Box action spaces *)
let env = Env.clip_action env
```
### clip_observation
Clamp observations to specified bounds:
```ocaml
let env = Env.clip_observation
~low:[| -1.0; -1.0 |]
~high:[| 1.0; 1.0 |]
env
```
### time_limit
Enforce a maximum episode length. When the limit is reached, the step's
`truncated` flag is set to true:
```ocaml
let env = Env.time_limit ~max_episode_steps:200 env
```
### Custom Wrappers with Env.wrap
For transformations that need full control over reset and step, use `Env.wrap`.
The wrapper shares the inner environment's lifecycle (RNG, closed flag, reset
flag):
```ocaml
open Fehu
(* A wrapper that tracks episode reward *)
let with_episode_reward env =
let episode_reward = ref 0.0 in
Env.wrap
~observation_space:(Env.observation_space env)
~action_space:(Env.action_space env)
~reset:(fun inner ?options () ->
episode_reward := 0.0;
Env.reset inner ?options ())
~step:(fun inner action ->
let s = Env.step inner action in
episode_reward := !episode_reward +. s.reward;
let info =
if s.terminated || s.truncated then
Info.set "episode_reward"
(Info.float !episode_reward) s.info
else s.info
in
{ s with info })
env
```
## Rendering
Environments support optional rendering via render modes. Pass
`~render_mode` at creation time:
```ocaml
let env = Fehu_envs.Grid_world.make
~render_mode:`Ansi ()
let _obs, _info = Env.reset env ()
let () =
  match Env.render env with
  | Some (Text s) -> print_endline s
  | _ -> ()

```
### Render Rollout
`Render.rollout` runs a policy and feeds rendered frames to a sink function:
```ocaml
open Fehu
(* Collect rendered frames from a policy rollout *)
let frames = ref []
let () =
  Render.rollout env
    ~policy:(fun _obs -> Space.sample (Env.action_space env))
    ~steps:100
    ~sink:(fun img -> frames := img :: !frames)
    ()
```
### Recording with on_render
`Render.on_render` wraps an environment so that every frame after reset and
step is passed to a sink:
```ocaml
let env = Render.on_render
~sink:(fun img -> save_frame img)
env
```
## Vectorized Environments
`Vec_env` runs multiple environment instances with batched inputs and outputs.
Terminated or truncated episodes are automatically reset.
```ocaml
open Fehu
let () = Nx.Rng.run ~seed:42 @@ fun () ->
(* Create 4 parallel environments *)
let envs = List.init 4 (fun _ -> Fehu_envs.Cartpole.make ()) in
let vec = Vec_env.create envs in
let n = Vec_env.num_envs vec in (* 4 *)
(* Reset all environments *)
let _observations, _infos = Vec_env.reset vec () in
(* Step all environments with an array of actions *)
let actions = Array.init n (fun _ -> Space.Discrete.of_int 0) in
let _s = Vec_env.step vec actions in
(* _s.observations, _s.rewards, _s.terminated, _s.truncated *)
(* Clean up *)
Vec_env.close vec
```
All environments must have structurally identical observation and action spaces
(checked via `Space.equal_spec`). On terminal steps, the original terminal
observation is stored in the step info under `"final_observation"` as a packed
`Value.t`, and the terminal info under `"final_info"`.
## Next Steps
- [Getting Started](../01-getting-started/) -- installation, environments, spaces, step loop
- [Collection and Evaluation](../03-collection-and-evaluation/) -- trajectory collection, replay buffers, GAE, evaluation
================================================
FILE: packages/fehu/doc/03-collection-and-evaluation.md
================================================
# Collection, Buffers, and Evaluation
This guide covers trajectory collection, replay buffers, generalized advantage
estimation, and policy evaluation.
## Trajectory Collection
`Collect` gathers agent-environment interactions into structure-of-arrays form
for batch processing.
### Rollout
`Collect.rollout` collects a fixed number of transitions. It resets the
environment at the start and automatically on episode boundaries:
```ocaml
open Fehu
let () = Nx.Rng.run ~seed:42 @@ fun () ->
let env = Fehu_envs.Cartpole.make () in
(* The policy receives an observation and returns
(action, log_prob option, value_estimate option) *)
let policy _obs =
let act = Space.sample (Env.action_space env) in
(act, None, None)
in
let _trajectory = Collect.rollout env ~policy ~n_steps:1024 in ()
```
The returned trajectory contains parallel arrays:
```ocaml
let n = Collect.length trajectory (* 1024 *)
let obs = trajectory.observations (* 'obs array *)
let acts = trajectory.actions (* 'act array *)
let rews = trajectory.rewards (* float array *)
let next_obs = trajectory.next_observations (* 'obs array *)
let terms = trajectory.terminated (* bool array *)
let truncs = trajectory.truncated (* bool array *)
let infos = trajectory.infos (* Info.t array *)
let log_ps = trajectory.log_probs (* float array option *)
let vals = trajectory.values (* float array option *)
```
When the policy returns `Some log_prob` or `Some value`, those are collected
into `log_probs` and `values`. If the policy ever returns `None` for one of
them, the corresponding field is `None` for the entire trajectory.
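A defensive pattern for consuming these optional fields (continuing with the `trajectory` from above):

```ocaml
let () =
  match trajectory.values with
  | Some values ->
      Printf.printf "collected %d value estimates\n" (Array.length values)
  | None ->
      print_endline "policy did not supply value estimates"
```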
### Policy Signature
The policy function has the signature:
```
'obs -> 'act * float option * float option
```
The three components are:
1. **action**: the action to take
2. **log_prob** (optional): the log-probability of the action under the current policy, used for importance sampling in PPO
3. **value** (optional): the estimated value of the current state, used for GAE computation
For a simple random policy, return `None` for both:
```ocaml
let random_policy _obs =
let act = Space.sample (Env.action_space env) in
(act, None, None)
```
For a neural network policy with value head:
```ocaml
let nn_policy obs =
let logits, value = forward_pass model obs in
let act = sample_from_logits logits in
let log_prob = log_prob_of logits act in
(act, Some log_prob, Some value)
```
### Episodes
`Collect.episodes` collects complete episodes, one trajectory per episode:
```ocaml
let episodes = Collect.episodes env
~policy ~n_episodes:10
~max_steps:500 ()
(* episodes is a ('obs, 'act) Collect.t list *)
let total_rewards = List.map (fun traj ->
Array.fold_left (+.) 0.0 traj.rewards) episodes
```
Each episode runs until termination, truncation, or `max_steps` (default
1000).
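For example, a summary statistic over the collected episodes (a small sketch using `Collect.length`):

```ocaml
(* Mean episode length across a list of collected episodes *)
let mean_length episodes =
  let total =
    List.fold_left (fun acc traj -> acc + Collect.length traj) 0 episodes
  in
  float_of_int total /. float_of_int (List.length episodes)
```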
### Concatenating Trajectories
`Collect.concat` merges multiple trajectories into one:
```ocaml
let combined = Collect.concat [traj1; traj2; traj3]
```
Optional fields (`log_probs`, `values`) are kept only if present in all inputs.
## Replay Buffers
`Buffer` provides a fixed-capacity circular buffer for off-policy experience
storage. It stores individual transitions and supports uniform random sampling.
### Creating and Filling
```ocaml
open Fehu
let buf = Buffer.create ~capacity:10_000
(* Add transitions one at a time *)
let () =
  Buffer.add buf {
    observation = obs;
    action = act;
    reward = 1.0;
    next_observation = next_obs;
    terminated = false;
    truncated = false;
  }
let n = Buffer.size buf (* number of stored transitions *)
let full = Buffer.is_full buf (* true when at capacity *)
let cap = Buffer.capacity buf (* 10000 *)
```
When the buffer is full, new transitions overwrite the oldest ones.
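A sketch of the overwrite behavior, using int observations and actions for brevity (any types work; the transition record fields are as shown above):

```ocaml
open Fehu

let () =
  let buf = Buffer.create ~capacity:2 in
  for i = 0 to 3 do
    Buffer.add buf {
      observation = i; action = 0; reward = float_of_int i;
      next_observation = i + 1; terminated = false; truncated = false;
    }
  done;
  (* Only the two most recent transitions remain *)
  assert (Buffer.size buf = 2);
  assert (Buffer.is_full buf)
```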
### Sampling
Draw a batch of transitions uniformly at random (with replacement):
```ocaml
let batch = Nx.Rng.run ~seed:0 @@ fun () ->
Buffer.sample buf ~batch_size:64
(* batch is a transition array *)
let _obs_0 = batch.(0).observation
let _rew_0 = batch.(0).reward
```
For structure-of-arrays form (more convenient for training):
```ocaml
let (observations, actions, rewards,
next_observations, terminated, truncated) =
Nx.Rng.run ~seed:0 @@ fun () ->
Buffer.sample_arrays buf ~batch_size:64
```
### Clearing
```ocaml
let () = Buffer.clear buf (* removes all transitions, keeps storage allocated *)
```
## Generalized Advantage Estimation
`Gae` computes advantages and returns for policy gradient methods. It correctly
handles the distinction between terminated and truncated episodes:
- **Terminated**: the episode ended naturally (e.g., pole fell). Bootstrap
value is zero.
- **Truncated**: the episode was cut short (e.g., time limit). Bootstrap value
comes from `next_values`.
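For intuition, the recurrence behind GAE can be sketched directly. This is a plain reference implementation of the standard backward recursion, not Fehu's internals: `terminated` zeroes the bootstrap value, while both `terminated` and `truncated` stop advantage propagation across the episode boundary.

```ocaml
(* delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
   A_t     = delta_t + gamma * lambda * A_{t+1}, cut at episode boundaries *)
let gae_reference ~rewards ~values ~next_values ~terminated ~truncated
    ~gamma ~lambda =
  let n = Array.length rewards in
  let advantages = Array.make n 0.0 in
  let running = ref 0.0 in
  for t = n - 1 downto 0 do
    (* Terminated: episode ended naturally, so bootstrap with zero *)
    let bootstrap = if terminated.(t) then 0.0 else next_values.(t) in
    let delta = rewards.(t) +. (gamma *. bootstrap) -. values.(t) in
    (* Either kind of episode end stops propagation from step t+1 *)
    let carry =
      if terminated.(t) || truncated.(t) then 0.0 else !running
    in
    running := delta +. (gamma *. lambda *. carry);
    advantages.(t) <- !running
  done;
  advantages
```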
### Computing Advantages
```ocaml
open Fehu
(* From a trajectory with value estimates *)
let advantages, returns = Gae.compute
~rewards:trajectory.rewards
~values:(Option.get trajectory.values)
~terminated:trajectory.terminated
~truncated:trajectory.truncated
~next_values (* V(s_{t+1}) for each step *)
~gamma:0.99 (* discount factor *)
~lambda:0.95 (* GAE smoothing parameter *)
```
When you have values from a value network and the last value estimate,
`compute_from_values` builds `next_values` for you:
```ocaml
let advantages, returns = Gae.compute_from_values
~rewards:trajectory.rewards
~values:(Option.get trajectory.values)
~terminated:trajectory.terminated
~truncated:trajectory.truncated
~last_value:0.0 (* V(s_T) for the final state *)
~gamma:0.99
~lambda:0.95
```
### Monte Carlo Returns
For simpler algorithms that do not need advantages:
```ocaml
let rets = Gae.returns
~rewards:trajectory.rewards
~terminated:trajectory.terminated
~truncated:trajectory.truncated
~gamma:0.99
```
### Normalizing Advantages
Normalize to zero mean and unit variance for training stability:
```ocaml
let normalized = Gae.normalize advantages
(* or with custom epsilon *)
let normalized = Gae.normalize ~eps:1e-6 advantages
```
## Policy Evaluation
`Eval.run` runs a deterministic or stochastic policy over multiple episodes
and reports summary statistics:
```ocaml
open Fehu
let () = Nx.Rng.run ~seed:42 @@ fun () ->
let env = Fehu_envs.Cartpole.make () in
(* Evaluate a random policy *)
let stats = Eval.run env
~policy:(fun _obs -> Space.sample (Env.action_space env))
~n_episodes:100
~max_steps:500
()
in
Printf.printf
"Episodes: %d, Mean reward: %.1f +/- %.1f, Mean length: %.0f\n"
stats.n_episodes
stats.mean_reward
stats.std_reward
stats.mean_length
```
The evaluation policy has a simpler signature than the collection policy: it
only returns an action, not log-probs or value estimates:
```
'obs -> 'act
```
`Eval.run` resets the environment between episodes. Default `n_episodes` is 10
and default `max_steps` is 1000.
## Putting It Together
A typical PPO-style training iteration using these utilities:
```ocaml
open Fehu
(* 1. Collect rollout *)
let trajectory = Collect.rollout env
~policy:(fun obs ->
let act, log_prob, value = nn_policy obs in
(act, Some log_prob, Some value))
~n_steps:2048
(* 2. Compute advantages *)
let last_value = estimate_value model last_obs in
let advantages, returns = Gae.compute_from_values
~rewards:trajectory.rewards
~values:(Option.get trajectory.values)
~terminated:trajectory.terminated
~truncated:trajectory.truncated
~last_value
~gamma:0.99 ~lambda:0.95
let advantages = Gae.normalize advantages
(* 3. Update policy using trajectory data + advantages *)
(* ... your PPO update here ... *)
(* 4. Evaluate *)
let stats = Eval.run env
~policy:(fun obs -> greedy_action model obs)
~n_episodes:10 ()
```
## Next Steps
- [Getting Started](../01-getting-started/) -- installation, environments, spaces, step loop
- [Environments and Wrappers](../02-environments/) -- custom environments, wrappers, rendering, vectorized environments
================================================
FILE: packages/fehu/doc/04-gymnasium-comparison.md
================================================
# Fehu vs. Gymnasium -- A Practical Comparison
This guide explains how Fehu's reinforcement learning API relates to Python's [Gymnasium](https://gymnasium.farama.org/) (and [Stable Baselines3](https://stable-baselines3.readthedocs.io/) for collection/buffer/GAE), focusing on:
* How core concepts map (Env, Space, step loop, wrappers)
* Where the APIs feel similar vs. deliberately different
* How to translate common Gymnasium patterns into Fehu
If you already use Gymnasium, this should be enough to become productive in Fehu quickly.
---
## 1. Big-Picture Differences
| Aspect | Gymnasium (Python) | Fehu (OCaml) |
| --------------------- | -------------------------------------------------- | -------------------------------------------------------------------- |
| Language | Dynamic, interpreted | Statically typed, compiled |
| Environment type | `gymnasium.Env` | `('obs, 'act, 'render) Env.t` |
| Observation/action | Untyped (`np.ndarray`, `int`, etc.) | Parametric: `'obs` and `'act` tracked in the type |
| Spaces | `gymnasium.spaces.*` | `'a Space.t` with typed modules (`Space.Discrete`, `Space.Box`, ...) |
| Step result | Tuple `(obs, reward, terminated, truncated, info)` | Record `Env.step` with named fields |
| Wrappers | Subclassing `gymnasium.Wrapper` | `Env.wrap` or composable combinators (`map_observation`, etc.) |
| Vectorized envs | `gymnasium.vector.SyncVectorEnv` | `Vec_env.create` |
| Trajectory collection | External (Stable Baselines3, TorchRL) | Built-in: `Collect.rollout`, `Collect.episodes` |
| Replay buffers | External (Stable Baselines3, TorchRL) | Built-in: `Buffer.create`, `Buffer.add`, `Buffer.sample` |
| GAE | External (Stable Baselines3) | Built-in: `Gae.compute`, `Gae.returns`, `Gae.normalize` |
| Policy evaluation | Manual loop or SB3 `evaluate_policy` | Built-in: `Eval.run` |
| RNG | `np.random` / seed passed to `env.reset(seed=...)` | Implicit scope via `Nx.Rng.run ~seed` |
| Rendering | String mode `"human"`, `"rgb_array"` | Polymorphic variants `` `Human ``, `` `Rgb_array ``, etc. |
| Mutability | Environments are mutable objects | Environments are immutable handles; state is internal |
**Fehu semantics to know (read once):**
- `Env.reset` must be called before `Env.step`. After a terminal step, another `reset` is required.
- Spaces validate observations and actions automatically -- `Env.step` raises if an action is outside the action space.
- RNG is scoped: wrap your code in `Nx.Rng.run ~seed:42 (fun () -> ...)` instead of passing seeds to individual calls.
- Trajectory collection, replay buffers, GAE, and evaluation are built into Fehu, not external libraries.
---
## 2. Spaces
### 2.1 Discrete
**Gymnasium**
```python
import gymnasium as gym
space = gym.spaces.Discrete(5) # {0, 1, 2, 3, 4}
space = gym.spaces.Discrete(5, start=1) # {1, 2, 3, 4, 5}
sample = space.sample()
assert space.contains(sample)
```
**Fehu**
```ocaml
open Fehu
let space = Space.Discrete.create 5 (* {0, 1, 2, 3, 4} *)
let space = Space.Discrete.create ~start:1 5 (* {1, 2, 3, 4, 5} *)
let sample = Space.sample space
let valid = Space.contains space sample
let n = Space.Discrete.n space (* 5 *)
let start = Space.Discrete.start space (* 1 *)
(* Convert between discrete elements and ints *)
let action = Space.Discrete.of_int 3
let value = Space.Discrete.to_int action
```
Discrete elements are `(int32, Nx.int32_elt) Nx.t` scalars, not bare OCaml ints.
### 2.2 Box (continuous)
**Gymnasium**
```python
import numpy as np
space = gym.spaces.Box(
low=np.array([-1.0, -2.0]),
high=np.array([1.0, 2.0]),
dtype=np.float32,
)
sample = space.sample()
```
**Fehu**
```ocaml
let space =
Space.Box.create
~low:[| -1.0; -2.0 |]
~high:[| 1.0; 2.0 |]
let sample = Space.sample space
let (low, high) = Space.Box.bounds space
```
Box elements are `(float, Nx.float32_elt) Nx.t` tensors. Infinite bounds are allowed; sampling falls back to uniform draws in `[-1e6, 1e6]` clamped to bounds.
### 2.3 Multi_binary
**Gymnasium**
```python
space = gym.spaces.MultiBinary(4) # {0,1}^4
```
**Fehu**
```ocaml
let space = Space.Multi_binary.create 4
```
Elements are `(int32, Nx.int32_elt) Nx.t` vectors with values 0 or 1.
### 2.4 Multi_discrete
**Gymnasium**
```python
space = gym.spaces.MultiDiscrete([3, 5, 2]) # 3 axes: {0..2}, {0..4}, {0..1}
```
**Fehu**
```ocaml
let space = Space.Multi_discrete.create [| 3; 5; 2 |]
```
### 2.5 Composite spaces
**Gymnasium**
```python
space = gym.spaces.Tuple((
gym.spaces.Discrete(3),
gym.spaces.Box(low=0.0, high=1.0, shape=(2,)),
))
space = gym.spaces.Dict({
"position": gym.spaces.Box(low=-10.0, high=10.0, shape=(3,)),
"velocity": gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)),
})
```
**Fehu**
```ocaml
let space =
Space.Tuple.create [
Space.Pack (Space.Discrete.create 3);
Space.Pack (Space.Box.create ~low:[| 0.0; 0.0 |] ~high:[| 1.0; 1.0 |]);
]
let space =
Space.Dict.create [
("position", Space.Pack (Space.Box.create ~low:[| -10.; -10.; -10. |] ~high:[| 10.; 10.; 10. |]));
("velocity", Space.Pack (Space.Box.create ~low:[| -1.; -1.; -1. |] ~high:[| 1.; 1.; 1. |]));
]
```
Composite space elements use `Value.t` for heterogeneous data: `Tuple.element = Value.t list`, `Dict.element = (string * Value.t) list`.
### 2.6 Sequence and Text
**Gymnasium**
```python
space = gym.spaces.Sequence(gym.spaces.Discrete(5), seed=42)
space = gym.spaces.Text(max_length=32, charset="abcdef")
```
**Fehu**
```ocaml
let space = Space.Sequence.create ~max_length:10 (Space.Discrete.create 5)
let space = Space.Text.create ~charset:"abcdef" ~max_length:32 ()
```
### 2.7 Common operations
All space types share the same interface:
```ocaml
let sample = Space.sample space (* random element *)
let valid = Space.contains space sample (* membership test *)
let spec = Space.spec space (* structural description *)
let shape = Space.shape space (* dimensionality, if defined *)
(* Serialization via Value.t *)
let packed = Space.pack space sample
let unpacked = Space.unpack space packed (* (element, string) result *)
(* Edge cases for testing *)
let edges = Space.boundary_values space
```
---
## 3. Creating Environments
### 3.1 From a registry
**Gymnasium**
```python
env = gym.make("CartPole-v1", render_mode="human")
```
**Fehu** does not have a global registry. Environments are constructed directly:
```ocaml
let env =
Env.create
~id:"CartPole-v1"
~observation_space:(Space.Box.create
~low:[| -4.8; Float.neg_infinity; -0.418; Float.neg_infinity |]
~high:[| 4.8; Float.infinity; 0.418; Float.infinity |])
~action_space:(Space.Discrete.create 2)
~render_mode:`Human
~render_modes:["human"; "rgb_array"]
~reset:(fun _env ?options:_ () ->
let obs = (* initial state *) in
(obs, Info.empty))
~step:(fun _env action ->
let obs = (* next state *) in
Env.step_result ~observation:obs ~reward:1.0 ())
()
```
`Env.create` takes the observation space, action space, and two callbacks: `reset` and `step`. Optional `render` and `close` callbacks handle visualization and cleanup.
### 3.2 Step result construction
**Gymnasium** returns a flat tuple from `env.step()`:
```python
obs, reward, terminated, truncated, info = env.step(action)
```
**Fehu** uses a record with named fields, and provides a convenience constructor with defaults:
```ocaml
(* Inside a step callback *)
Env.step_result
~observation:obs
~reward:1.0
~terminated:false
~truncated:false
~info:Info.empty
()
(* Defaults: reward=0., terminated=false, truncated=false, info=Info.empty *)
Env.step_result ~observation:obs ()
```
---
## 4. Step Loop
### 4.1 Basic episode
**Gymnasium**
```python
env = gym.make("CartPole-v1")
obs, info = env.reset(seed=42)
total_reward = 0.0
while True:
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
total_reward += reward
if terminated or truncated:
break
env.close()
```
**Fehu**
```ocaml
let () =
Nx.Rng.run ~seed:42 (fun () ->
let env = (* create environment *) in
let (obs, _info) = Env.reset env () in
let obs = ref obs in
let total_reward = ref 0.0 in
let done_ = ref false in
while not !done_ do
let action = Space.sample (Env.action_space env) in
let step = Env.step env action in
obs := step.observation;
total_reward := !total_reward +. step.reward;
done_ := step.terminated || step.truncated
done;
Env.close env)
```
Key differences:
- RNG is scoped with `Nx.Rng.run ~seed:42` rather than passed to `reset`.
- Step results are accessed by field name (`step.observation`, `step.reward`).
- `Env.step` raises `Invalid_argument` if called without a prior `reset` or after a terminal step without resetting.
### 4.2 Multiple episodes
**Gymnasium**
```python
for episode in range(10):
obs, info = env.reset()
done = False
while not done:
action = policy(obs)
obs, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
```
**Fehu** -- manual loop or use `Collect.episodes`:
```ocaml
(* Manual *)
let () =
Nx.Rng.run ~seed:0 (fun () ->
let env = (* create environment *) in
for _ep = 0 to 9 do
let (obs, _info) = Env.reset env () in
let obs = ref obs in
let done_ = ref false in
while not !done_ do
let action = policy !obs in
let step = Env.step env action in
obs := step.observation;
done_ := step.terminated || step.truncated
done
done;
Env.close env)
(* Or use Collect.episodes directly *)
let trajs =
Nx.Rng.run ~seed:0 (fun () ->
let env = (* create environment *) in
Collect.episodes env
~policy:(fun obs -> (policy obs, None, None))
~n_episodes:10 ())
```
---
## 5. Wrappers
### 5.1 Gymnasium approach: subclassing
**Gymnasium**
```python
class NormalizeObservation(gym.ObservationWrapper):
    def __init__(self, env, mean, std):
        super().__init__(env)
        self.mean = mean
        self.std = std
    def observation(self, obs):
        return (obs - self.mean) / self.std
env = NormalizeObservation(env, mean=0.0, std=1.0)
```
### 5.2 Fehu approach: composable functions
**Fehu** provides `Env.wrap` for full control and specialized combinators for common patterns.
**`map_observation`** -- transform observations:
```ocaml
let normalized_env =
Env.map_observation
~observation_space:obs_space
~f:(fun obs _info ->
let normalized = (* normalize obs *) in
(normalized, Info.empty))
env
```
**`map_action`** -- transform actions before passing to the inner env:
```ocaml
let remapped_env =
Env.map_action
~action_space:new_action_space
~f:(fun new_action -> (* convert to inner action *))
env
```
**`map_reward`** -- transform rewards:
```ocaml
let scaled_env =
Env.map_reward
~f:(fun ~reward ~info -> (reward *. 0.1, info))
env
```
**`clip_action`** -- clamp continuous actions to bounds:
```ocaml
(* Gymnasium *)
(* from gymnasium.wrappers import ClipAction *)
(* env = ClipAction(env) *)
(* Fehu *)
let clipped_env = Env.clip_action env
```
**`clip_observation`** -- clamp observations:
```ocaml
let clipped_env =
Env.clip_observation
~low:[| -5.0; -5.0 |]
~high:[| 5.0; 5.0 |]
env
```
**`time_limit`** -- enforce maximum episode length:
```ocaml
(* Gymnasium *)
(* from gymnasium.wrappers import TimeLimit *)
(* env = TimeLimit(env, max_episode_steps=200) *)
(* Fehu *)
let limited_env = Env.time_limit ~max_episode_steps:200 env
```
### 5.3 Full custom wrapper with `Env.wrap`
When the combinators are not enough, use `Env.wrap` directly:
```ocaml
let custom_env =
Env.wrap
~observation_space:new_obs_space
~action_space:new_act_space
~reset:(fun inner ?options () ->
let (obs, info) = Env.reset inner ?options () in
(transform_obs obs, info))
~step:(fun inner action ->
let step = Env.step inner (transform_action action) in
{ step with observation = transform_obs step.observation })
env
```
`Env.wrap` receives the inner environment as the first argument to `reset`, `step`, `render`, and `close`. Guards (closed check, needs-reset check, space validation) are enforced automatically.
### 5.4 Composing wrappers
Wrappers compose by chaining:
```ocaml
let env =
base_env
|> Env.time_limit ~max_episode_steps:500
|> Env.clip_action
|> Env.map_reward ~f:(fun ~reward ~info -> (reward *. 0.01, info))
```
---
## 6. Vectorized Environments
### 6.1 Synchronous vectorization
**Gymnasium**
```python
envs = gym.vector.SyncVectorEnv([
lambda: gym.make("CartPole-v1") for _ in range(4)
])
obs, infos = envs.reset()
actions = envs.action_space.sample() # batch of 4 actions
obs, rewards, terminated, truncated, infos = envs.step(actions)
envs.close()
```
**Fehu**
```ocaml
let venv =
Vec_env.create [env1; env2; env3; env4]
let n = Vec_env.num_envs venv (* 4 *)
let (observations, infos) = Vec_env.reset venv ()
let actions = Array.init n (fun _ -> Space.sample (Vec_env.action_space venv))
let step = Vec_env.step venv actions
(* step.observations : 'obs array -- one per env *)
(* step.rewards : float array -- one per env *)
(* step.terminated : bool array -- one per env *)
(* step.truncated : bool array -- one per env *)
(* step.infos : Info.t array -- one per env *)
Vec_env.close venv
```
Key differences:
- `Vec_env.create` takes a list of already-constructed environments. All must have structurally identical spaces.
- Terminated or truncated environments are automatically reset. The terminal observation is stored in the step's info under `"final_observation"` (as a packed `Value.t`), and the terminal info under `"final_info"`.
- The step result is a record with named arrays, not a tuple.
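Because auto-reset replaces the terminal observation with the first observation of the next episode, algorithms that need the true terminal state must read it back out of the per-env info. A hedged sketch of that pattern -- how to unpack the stored `Value.t` depends on your observation type and is left as a placeholder here:
```ocaml
let step = Vec_env.step venv actions in
Array.iteri
  (fun i term ->
    if term || step.truncated.(i) then
      match Info.find "final_observation" step.infos.(i) with
      | Some _packed -> (* unpack the Value.t to recover the terminal obs *) ()
      | None -> ())
  step.terminated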
---
## 7. Trajectory Collection
### 7.1 Fixed-step rollout
**Gymnasium + Stable Baselines3**
```python
from stable_baselines3.common.buffers import RolloutBuffer
# Manual loop or SB3 internals
obs, _ = env.reset()
for step in range(2048):
action, log_prob, value = policy(obs)
obs, reward, terminated, truncated, info = env.step(action)
buffer.add(obs, action, reward, ...)
if terminated or truncated:
obs, _ = env.reset()
```
**Fehu** -- built-in:
```ocaml
let trajectory =
Collect.rollout env
~policy:(fun obs ->
let action = (* select action *) in
let log_prob = (* optional log probability *) in
let value = (* optional value estimate *) in
(action, Some log_prob, Some value))
~n_steps:2048
```
`Collect.rollout` handles resets on episode boundaries automatically and returns a `Collect.t` record:
```ocaml
(* Collect.t fields: *)
trajectory.observations (* 'obs array *)
trajectory.actions (* 'act array *)
trajectory.rewards (* float array *)
trajectory.next_observations (* 'obs array *)
trajectory.terminated (* bool array *)
trajectory.truncated (* bool array *)
trajectory.infos (* Info.t array *)
trajectory.log_probs (* float array option *)
trajectory.values (* float array option *)
let n = Collect.length trajectory
```
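Because the trajectory is stored in flat structure-of-arrays form, per-episode quantities are recovered by splitting on the terminal flags. A minimal sketch using only the fields listed above (`episode_returns` is a hypothetical helper, not part of Fehu):
```ocaml
(* Total reward per episode: accumulate rewards, cut at terminal flags. *)
let episode_returns trajectory =
  let acc = ref 0.0 and out = ref [] in
  Array.iteri
    (fun t r ->
      acc := !acc +. r;
      if trajectory.terminated.(t) || trajectory.truncated.(t) then begin
        out := !acc :: !out;
        acc := 0.0
      end)
    trajectory.rewards;
  List.rev !out
```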
### 7.2 Complete episodes
**Gymnasium + manual**
```python
episodes = []
for _ in range(10):
obs, _ = env.reset()
episode = []
done = False
while not done:
action = policy(obs)
next_obs, reward, terminated, truncated, info = env.step(action)
episode.append((obs, action, reward, next_obs, terminated, truncated))
obs = next_obs
done = terminated or truncated
episodes.append(episode)
```
**Fehu** -- built-in:
```ocaml
let episode_list =
Collect.episodes env
~policy:(fun obs -> (policy obs, None, None))
~n_episodes:10
~max_steps:1000
()
(* episode_list : ('obs, 'act) Collect.t list *)
```
Each element is one episode as a `Collect.t`. Concatenate them with `Collect.concat`:
```ocaml
let all_transitions = Collect.concat episode_list
```
---
## 8. Replay Buffers
### 8.1 Standard replay buffer
**Stable Baselines3**
```python
from stable_baselines3.common.buffers import ReplayBuffer
buffer = ReplayBuffer(buffer_size=100_000, observation_space=..., action_space=...)
buffer.add(obs, next_obs, action, reward, done, infos)
batch = buffer.sample(batch_size=256)
```
**Fehu** -- built-in:
```ocaml
let buf = Buffer.create ~capacity:100_000
let () =
Buffer.add buf {
Buffer.observation = obs;
action;
reward = 1.0;
next_observation = next_obs;
terminated = false;
truncated = false;
}
(* Uniform random sampling *)
let batch = Buffer.sample buf ~batch_size:256
(* batch : ('obs, 'act) Buffer.transition array *)
(* Structure-of-arrays form for training loops *)
let (observations, actions, rewards, next_observations, terminated, truncated) =
Buffer.sample_arrays buf ~batch_size:256
```
### 8.2 Buffer queries
```ocaml
let n = Buffer.size buf (* current number of stored transitions *)
let cap = Buffer.capacity buf (* maximum capacity *)
let full = Buffer.is_full buf (* true when size = capacity *)
let () = Buffer.clear buf (* remove all transitions, keep storage *)
```
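These queries are what a typical off-policy loop gates on: wait until the buffer holds enough transitions before sampling. A brief sketch -- the `1_000` warm-up threshold is an arbitrary choice for illustration:
```ocaml
(* Sample only once the buffer has passed its warm-up threshold. *)
let maybe_sample buf =
  if Buffer.size buf >= 1_000 then
    Some (Buffer.sample buf ~batch_size:256)
  else None
```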
---
## 9. GAE and Returns
### 9.1 Generalized Advantage Estimation
**Stable Baselines3** (internal)
```python
# SB3 computes GAE internally in on-policy algorithms
# or manually:
import numpy as np
def compute_gae(rewards, values, dones, next_values, gamma=0.99, lam=0.95):
advantages = np.zeros_like(rewards)
last_gae = 0
for t in reversed(range(len(rewards))):
delta = rewards[t] + gamma * next_values[t] * (1 - dones[t]) - values[t]
advantages[t] = last_gae = delta + gamma * lam * (1 - dones[t]) * last_gae
returns = advantages + values
return advantages, returns
```
**Fehu** -- built-in, with correct terminated/truncated handling:
```ocaml
let (advantages, returns) =
Gae.compute
~rewards:trajectory.rewards
~values:(Option.get trajectory.values)
~terminated:trajectory.terminated
~truncated:trajectory.truncated
~next_values (* float array: V(s_{t+1}) for each t *)
~gamma:0.99
~lambda:0.95
```
When you have values from a rollout and a final bootstrap value:
```ocaml
let (advantages, returns) =
Gae.compute_from_values
~rewards:trajectory.rewards
~values:(Option.get trajectory.values)
~terminated:trajectory.terminated
~truncated:trajectory.truncated
~last_value:0.0
~gamma:0.99
~lambda:0.95
```
`compute_from_values` builds `next_values` from `values` and `last_value` automatically: `next_values.(t) = values.(t+1)` for `t < n-1`, and `next_values.(n-1) = last_value`.
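The shift-and-bootstrap rule is simple enough to state in plain OCaml. This is a self-contained illustration of the construction, not Fehu's internal code:
```ocaml
(* next_values.(t) = values.(t+1) for t < n-1; last_value at the end. *)
let build_next_values values last_value =
  let n = Array.length values in
  Array.init n (fun t -> if t < n - 1 then values.(t + 1) else last_value)
```
For example, `build_next_values [| 1.0; 2.0; 3.0 |] 0.5` yields `[| 2.0; 3.0; 0.5 |]`.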
### 9.2 Monte Carlo returns
**Manual Python**
```python
def discounted_returns(rewards, dones, gamma=0.99):
returns = np.zeros_like(rewards)
running = 0.0
for t in reversed(range(len(rewards))):
running = rewards[t] + gamma * running * (1 - dones[t])
returns[t] = running
return returns
```
**Fehu**
```ocaml
let mc_returns =
Gae.returns
~rewards:trajectory.rewards
~terminated:trajectory.terminated
~truncated:trajectory.truncated
~gamma:0.99
```
### 9.3 Normalization
```ocaml
let normalized_advantages = Gae.normalize advantages
let normalized_custom = Gae.normalize ~eps:1e-5 advantages
```
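Standard advantage normalization subtracts the mean and divides by the standard deviation, with `eps` guarding against a zero std. A plain-OCaml reference of that computation -- the default `eps` here is arbitrary, and whether `Gae.normalize` adds `eps` to the std or to the variance is an assumption:
```ocaml
(* (a - mean) / (std + eps), the usual advantage standardization. *)
let normalize ?(eps = 1e-8) a =
  let n = Float.of_int (Array.length a) in
  let mean = Array.fold_left ( +. ) 0.0 a /. n in
  let var =
    Array.fold_left (fun acc x -> acc +. ((x -. mean) ** 2.)) 0.0 a /. n
  in
  Array.map (fun x -> (x -. mean) /. (sqrt var +. eps)) a
```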
---
## 10. Policy Evaluation
**Gymnasium + Stable Baselines3**
```python
from stable_baselines3.common.evaluation import evaluate_policy
mean_reward, std_reward = evaluate_policy(
model, env, n_eval_episodes=10, deterministic=True
)
```
**Fehu** -- built-in:
```ocaml
let stats =
Eval.run env
~policy:(fun obs -> (* deterministic action *))
~n_episodes:10
~max_steps:1000
()
(* stats.mean_reward : float *)
(* stats.std_reward : float *)
(* stats.mean_length : float *)
(* stats.n_episodes : int *)
```
`Eval.run` resets the environment between episodes and aggregates the total reward and episode length of every episode into the returned statistics.
---
## 11. Rendering
### 11.1 Render modes
**Gymnasium**
```python
env = gym.make("CartPole-v1", render_mode="human")
env.reset()
env.step(action)
frame = env.render() # None for "human", np.ndarray for "rgb_array"
```
**Fehu**
```ocaml
let env =
Env.create
~render_mode:`Human
~render_modes:["human"; "rgb_array"]
~render:(fun () -> (* return 'render option *))
(* ... *)
()
let frame = Env.render env (* 'render option *)
```
Render modes are polymorphic variants: `` `Human ``, `` `Rgb_array ``, `` `Ansi ``, `` `Svg ``, `` `Custom of string ``.
### 11.2 Frame type
For `` `Rgb_array `` environments, Fehu uses `Render.image`:
```ocaml
(* Render.image fields: *)
(* width : int *)
(* height : int *)
(* pixel_format : Render.Pixel.format (Rgb|Rgba|Gray) *)
(* data : uint8 bigarray *)
```
### 11.3 Recording rendered rollouts
**Gymnasium**
```python
from gymnasium.wrappers import RecordVideo
env = RecordVideo(env, video_folder="./videos")
```
**Fehu** -- use `Render.rollout` or `Render.on_render`:
```ocaml
(* Run a policy and feed frames to a sink *)
Render.rollout env
~policy:(fun obs -> (* action *))
~steps:500
~sink:(fun frame -> (* save or display frame *))
()
(* Or wrap the env to capture every rendered frame *)
let recording_env =
Render.on_render
~sink:(fun frame -> (* process frame *))
env
```
---
## 12. Info Dictionaries
**Gymnasium** uses plain Python dicts for info:
```python
obs, info = env.reset()
print(info.get("elapsed_steps", 0))
```
**Fehu** uses typed `Info.t` dictionaries with `Value.t` values:
```ocaml
let info = Info.of_list [
("elapsed_steps", Info.int 42);
("success", Info.bool true);
]
let steps = Info.find "elapsed_steps" info (* Value.t option *)
let steps = Info.find_exn "elapsed_steps" info (* Value.t, raises on missing *)
let info' = Info.set "custom_key" (Info.float 3.14) info
let info' = Info.merge info1 info2 (* info2 wins on conflicts *)
let is_empty = Info.is_empty info
```
---
## 13. Quick Cheat Sheet
| Task | Gymnasium / SB3 | Fehu |
| -------------------- | ------------------------------------------------- | --------------------------------------------------------------------------------- |
| Create env | `gym.make("CartPole-v1")` | `Env.create ~observation_space ~action_space ~reset ~step ()` |
| Reset | `obs, info = env.reset(seed=42)` | `let (obs, info) = Env.reset env ()` |
| Step | `obs, r, term, trunc, info = env.step(a)` | `let s = Env.step env a` (record fields) |
| Close | `env.close()` | `Env.close env` |
| Discrete space | `gym.spaces.Discrete(5)` | `Space.Discrete.create 5` |
| Box space | `gym.spaces.Box(low, high)` | `Space.Box.create ~low ~high` |
| Sample from space | `space.sample()` | `Space.sample space` |
| Contains check | `space.contains(x)` | `Space.contains space x` |
| Observation wrapper | `class W(gym.ObservationWrapper)` | `Env.map_observation ~observation_space ~f env` |
| Action wrapper | `class W(gym.ActionWrapper)` | `Env.map_action ~action_space ~f env` |
| Reward wrapper | `class W(gym.RewardWrapper)` | `Env.map_reward ~f env` |
| Clip actions | `ClipAction(env)` | `Env.clip_action env` |
| Time limit | `TimeLimit(env, max_episode_steps=N)` | `Env.time_limit ~max_episode_steps:N env` |
| Vectorize | `gym.vector.SyncVectorEnv([...])` | `Vec_env.create [env1; env2; ...]` |
| Rollout N steps | Manual loop / SB3 internal | `Collect.rollout env ~policy ~n_steps` |
| Collect N episodes | Manual loop | `Collect.episodes env ~policy ~n_episodes ()` |
| Replay buffer | `ReplayBuffer(buffer_size=N, ...)` | `Buffer.create ~capacity:N` |
| Add to buffer | `buffer.add(obs, next_obs, ...)` | `Buffer.add buf transition` |
| Sample from buffer | `buffer.sample(batch_size=B)` | `Buffer.sample buf ~batch_size:B` |
| GAE | SB3 internal / manual | `Gae.compute ~rewards ~values ~terminated ~truncated ~next_values ~gamma ~lambda` |
| Discounted returns | Manual loop | `Gae.returns ~rewards ~terminated ~truncated ~gamma` |
| Normalize advantages | `(adv - mean) / std` | `Gae.normalize advantages` |
| Evaluate policy | `evaluate_policy(model, env, n_eval_episodes=10)` | `Eval.run env ~policy ~n_episodes:10 ()` |
| Render | `env.render()` | `Env.render env` |
| Record frames | `RecordVideo(env, ...)` | `Render.on_render ~sink env` |
| Seed RNG | `env.reset(seed=42)` | `Nx.Rng.run ~seed:42 (fun () -> ...)` |
================================================
FILE: packages/fehu/doc/dune
================================================
(mdx
(files *.md)
(package fehu)
(libraries fehu fehu.envs nx))
================================================
FILE: packages/fehu/doc/index.md
================================================
# Fehu
Fehu is a reinforcement learning environment toolkit for OCaml. It provides
type-safe environments, composable wrappers, trajectory collection, replay
buffers, GAE computation, policy evaluation, and vectorized environments.
Fehu follows the Gymnasium interface pattern: environments expose `reset` and
`step` with typed observation and action spaces. Wrappers compose freely.
Collection and evaluation utilities handle the plumbing between environments
and training loops.
## Features
- **Type-safe environments**: observation and action spaces are encoded in the type system
- **Rich space types**: Discrete, Box, Multi_binary, Multi_discrete, Tuple, Dict, Sequence, Text
- **Composable wrappers**: map_observation, map_action, map_reward, clip_action, clip_observation, time_limit
- **Trajectory collection**: rollout and episode collection in structure-of-arrays form
- **Replay buffers**: fixed-capacity circular buffer with uniform random sampling
- **GAE**: generalized advantage estimation with proper terminated/truncated handling
- **Policy evaluation**: run a policy over episodes and get mean/std reward statistics
- **Vectorized environments**: run multiple environments with batched step and auto-reset
- **Built-in environments**: CartPole, MountainCar, GridWorld, RandomWalk
## Quick Start
Create an environment, run a random agent, and evaluate:
```ocaml
open Fehu
let () = Nx.Rng.run ~seed:42 @@ fun () ->
let env = Fehu_envs.Cartpole.make () in
(* Run one episode *)
let _obs, _info = Env.reset env () in
let done_ = ref false in
let total_reward = ref 0.0 in
while not !done_ do
let act = Space.sample (Env.action_space env) in
let s = Env.step env act in
total_reward := !total_reward +. s.reward;
done_ := s.terminated || s.truncated
done;
(* Evaluate over 10 episodes *)
let _stats = Eval.run env
~policy:(fun _obs -> Space.sample (Env.action_space env))
~n_episodes:10 ()
in ()
```
## Next Steps
- [Getting Started](01-getting-started/) -- installation, environments, spaces, step loop
- [Environments and Wrappers](02-environments/) -- custom environments, wrappers, rendering, vectorized environments
- [Collection and Evaluation](03-collection-and-evaluation/) -- trajectory collection, replay buffers, GAE, evaluation
================================================
FILE: packages/fehu/examples/01-random-agent/dune
================================================
(executable
(name main)
(libraries nx rune fehu fehu.envs))
================================================
FILE: packages/fehu/examples/01-random-agent/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* A random agent on CartPole-v1.
Demonstrates the Env lifecycle: create, reset, step, render, close. Then uses
Eval.run for batch evaluation. *)
open Fehu
let () =
Nx.Rng.run ~seed:42 @@ fun () ->
Printf.printf "Random Agent on CartPole-v1\n";
Printf.printf "===========================\n\n";
let env = Fehu_envs.Cartpole.make ~render_mode:`Ansi () in
(* -- Manual episode loop ------------------------------------------------ *)
Printf.printf "Running 5 episodes with random actions...\n\n";
for episode = 1 to 5 do
let obs = ref (fst (Env.reset env ())) in
let total_reward = ref 0.0 in
let steps = ref 0 in
let done_ = ref false in
while not !done_ do
(* Show the first step of episode 1 *)
(if episode = 1 && !steps = 0 then
match Env.render env with
| Some text -> Printf.printf "%s\n" text
| None -> ());
let action = Space.sample (Env.action_space env) in
let s = Env.step env action in
total_reward := !total_reward +. s.reward;
incr steps;
obs := s.observation;
done_ := s.terminated || s.truncated
done;
Printf.printf " Episode %d: reward = %5.1f length = %3d\n" episode
!total_reward !steps
done;
(* -- Batch evaluation with Eval.run ------------------------------------ *)
Printf.printf "\nEvaluating over 100 episodes...\n\n";
let random_policy _obs = Space.sample (Env.action_space env) in
let stats = Eval.run env ~policy:random_policy ~n_episodes:100 () in
Printf.printf " mean reward: %6.2f +/- %.2f\n" stats.mean_reward
stats.std_reward;
Printf.printf " mean length: %6.1f\n" stats.mean_length;
Env.close env;
Printf.printf "\nDone.\n"
================================================
FILE: packages/fehu/examples/02-q-learning/dune
================================================
(executable
(name main)
(libraries nx rune fehu fehu.envs))
================================================
FILE: packages/fehu/examples/02-q-learning/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Tabular Q-learning on CartPole-v1.
Discretizes the continuous 4D observation into bins, learns a Q-table with
epsilon-greedy exploration and temporal difference updates. Uses Eval.run for
periodic evaluation. *)
open Fehu
(* Hyperparameters *)
let n_bins = 12
let n_actions = 2
let alpha = 0.1
let gamma = 0.99
let epsilon_start = 1.0
let epsilon_end = 0.01
let epsilon_decay = 2000.0
let n_episodes = 10_000
let eval_interval = 500
(* Sparkline *)
let sparkline values =
let blocks =
[|
"\xe2\x96\x81";
"\xe2\x96\x82";
"\xe2\x96\x83";
"\xe2\x96\x84";
"\xe2\x96\x85";
"\xe2\x96\x86";
"\xe2\x96\x87";
"\xe2\x96\x88";
|]
in
let lo = Array.fold_left Float.min Float.infinity values in
let hi = Array.fold_left Float.max Float.neg_infinity values in
let range = hi -. lo in
if range < 1e-9 then
String.concat "" (Array.to_list (Array.map (fun _ -> blocks.(4)) values))
else
String.concat ""
(Array.to_list
(Array.map
(fun v ->
let idx = Float.to_int ((v -. lo) /. range *. 7.0) in
blocks.(max 0 (min 7 idx)))
values))
(* Q-table *)
let n_states = n_bins * n_bins * n_bins * n_bins
let q = Array.make (n_states * n_actions) 0.0
let q_get s a = q.((s * n_actions) + a)
let q_set s a v = q.((s * n_actions) + a) <- v
(* Discretize: clip each of the 4 obs dimensions and map it to a bin index.
   CartPole obs: [x, x_dot, theta, theta_dot]. We use generous clip ranges
   that cover typical CartPole trajectories. *)
let clip_ranges = [| (-2.4, 2.4); (-3.0, 3.0); (-0.21, 0.21); (-3.0, 3.0) |]
let discretize obs =
let arr = (Nx.to_array obs : float array) in
let bin i =
let lo, hi = clip_ranges.(i) in
let v = Float.max lo (Float.min hi arr.(i)) in
let normalized = (v -. lo) /. (hi -. lo) in
Float.to_int (normalized *. Float.of_int (n_bins - 1))
|> max 0
|> min (n_bins - 1)
in
let b0 = bin 0 in
let b1 = bin 1 in
let b2 = bin 2 in
let b3 = bin 3 in
(b0 * n_bins * n_bins * n_bins) + (b1 * n_bins * n_bins) + (b2 * n_bins) + b3
let best_action s = if q_get s 0 >= q_get s 1 then 0 else 1
(* Training *)
let () =
Printf.printf "Q-Learning on CartPole-v1\n";
Printf.printf "==========================\n\n";
Printf.printf "States: %d bins/dim (%d total), Actions: left/right\n" n_bins
n_states;
Printf.printf "alpha = %.2f, gamma = %.2f, episodes = %d\n\n" alpha gamma
n_episodes;
Nx.Rng.run ~seed:42 @@ fun () ->
let sample_uniform () =
let t = Nx.rand Nx.float32 [| 1 |] in
(Nx.to_array t : float array).(0)
in
let sample_random_action () =
let t = Nx.randint Nx.int32 ~high:n_actions [| 1 |] 0 in
Int32.to_int (Nx.to_array t : Int32.t array).(0)
in
let env = Fehu_envs.Cartpole.make () in
let n_evals = n_episodes / eval_interval in
let reward_history = Array.make n_evals 0.0 in
let eval_idx = ref 0 in
Printf.printf "Training...\n\n";
for episode = 1 to n_episodes do
let epsilon =
epsilon_end
+. (epsilon_start -. epsilon_end)
*. exp (-.Float.of_int episode /. epsilon_decay)
in
let obs, _info = Env.reset env () in
let state = ref (discretize obs) in
let done_ = ref false in
while not !done_ do
let a =
if sample_uniform () < epsilon then sample_random_action ()
else best_action !state
in
let s = Env.step env (Space.Discrete.of_int a) in
let next_state = discretize s.observation in
let done_flag = s.terminated || s.truncated in
let bootstrap =
if done_flag then 0.0
else Float.max (q_get next_state 0) (q_get next_state 1)
in
let target = s.reward +. (gamma *. bootstrap) in
let old_q = q_get !state a in
q_set !state a (old_q +. (alpha *. (target -. old_q)));
state := next_state;
done_ := done_flag
done;
if episode mod eval_interval = 0 then begin
let greedy_policy obs =
Space.Discrete.of_int (best_action (discretize obs))
in
let stats = Eval.run env ~policy:greedy_policy ~n_episodes:20 () in
Printf.printf
" episode %5d eps = %.2f eval: reward = %5.1f +/- %4.1f\n%!" episode
epsilon stats.mean_reward stats.std_reward;
reward_history.(!eval_idx) <- stats.mean_reward;
incr eval_idx
end
done;
Printf.printf "\n reward: %s\n" (sparkline reward_history);
(* Final evaluation *)
Printf.printf "\nFinal evaluation (100 episodes):\n";
let greedy_policy obs =
Space.Discrete.of_int (best_action (discretize obs))
in
let stats = Eval.run env ~policy:greedy_policy ~n_episodes:100 () in
Printf.printf " mean reward: %5.1f +/- %.1f\n" stats.mean_reward
stats.std_reward;
Printf.printf " mean length: %5.1f\n" stats.mean_length;
if stats.mean_reward >= 195.0 then
Printf.printf "\nSolved! (mean reward >= 195)\n"
else Printf.printf "\nNot solved yet (mean reward < 195).\n";
Env.close env
================================================
FILE: packages/fehu/examples/03-reinforce/dune
================================================
(executable
(name main)
(libraries nx rune kaun vega fehu fehu.envs))
================================================
FILE: packages/fehu/examples/03-reinforce/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* REINFORCE on CartPole-v1.
Policy gradient with a small neural network. Collects rollouts, computes
discounted returns, and updates the policy by gradient ascent on the
action log-probabilities weighted by those returns. *)
open Fehu
open Kaun
(* Hyperparameters *)
let gamma = 0.99
let lr = 1e-3
let n_steps = 2048
let n_updates = 250
let eval_interval = 10
let eval_episodes = 20
(* Sparkline *)
let sparkline values =
let blocks =
[|
"\xe2\x96\x81";
"\xe2\x96\x82";
"\xe2\x96\x83";
"\xe2\x96\x84";
"\xe2\x96\x85";
"\xe2\x96\x86";
"\xe2\x96\x87";
"\xe2\x96\x88";
|]
in
let lo = Array.fold_left Float.min Float.infinity values in
let hi = Array.fold_left Float.max Float.neg_infinity values in
let range = hi -. lo in
if range < 1e-9 then
String.concat "" (Array.to_list (Array.map (fun _ -> blocks.(4)) values))
else
String.concat ""
(Array.to_list
(Array.map
(fun v ->
let idx = Float.to_int ((v -. lo) /. range *. 7.0) in
blocks.(max 0 (min 7 idx)))
values))
(* Network *)
let network =
Layer.sequential
[
Layer.linear ~in_features:4 ~out_features:64 ();
Layer.relu ();
Layer.linear ~in_features:64 ~out_features:2 ();
]
(* Forward pass: obs [batch; 4] -> logits [batch; 2] *)
let forward params net_state obs =
let vars = Layer.make_vars ~params ~state:net_state ~dtype:Nx.float32 in
fst (Layer.apply network vars ~training:false obs)
(* Main *)
let () =
Printf.printf "REINFORCE on CartPole-v1\n";
Printf.printf "=========================\n\n";
Printf.printf "Network: Linear(4 -> 64) -> ReLU -> Linear(64 -> 2)\n";
Printf.printf "Rollout: %d steps/update, gamma = %.2f, lr = %.4f\n\n" n_steps
gamma lr;
Nx.Rng.run ~seed:42 @@ fun () ->
let env = Fehu_envs.Cartpole.make () in
(* Initialize network *)
let vars = Layer.init network ~dtype:Nx.float32 in
let params = ref (Layer.params vars) in
let net_state = Layer.state vars in
Printf.printf "Parameters: %d\n\n" (Ptree.count_parameters !params);
(* Optimizer *)
let algo = Vega.adam (Vega.Schedule.constant lr) in
let opt_state = ref (Optim.init algo !params) in
let policy obs =
let obs_batch = Nx.reshape [| 1; 4 |] obs in
let logits = Rune.no_grad (fun () -> forward !params net_state obs_batch) in
let action_idx = Nx.categorical logits in
let action = Nx.reshape [||] action_idx in
let log_probs = Nx.log_softmax logits in
let action_1 = Nx.reshape [| 1; 1 |] action_idx in
let log_prob = Nx.take_along_axis ~axis:1 action_1 log_probs in
let lp = Nx.item [ 0; 0 ] log_prob in
(action, Some lp, None)
in
(* Greedy policy for evaluation *)
let greedy_policy obs =
let obs_batch = Nx.reshape [| 1; 4 |] obs in
let logits = Rune.no_grad (fun () -> forward !params net_state obs_batch) in
let action_idx =
Nx.argmax logits ~axis:(-1) ~keepdims:false |> Nx.cast Nx.int32
in
Nx.reshape [||] action_idx
in
(* Training loop *)
Printf.printf "Training...\n\n";
let n_evals = n_updates / eval_interval in
let reward_history = Array.make n_evals 0.0 in
let eval_idx = ref 0 in
for update = 1 to n_updates do
(* Collect rollout *)
let traj = Collect.rollout env ~policy ~n_steps in
let n = Collect.length traj in
(* Compute discounted returns and normalize *)
let returns =
Gae.returns ~rewards:traj.rewards ~terminated:traj.terminated
~truncated:traj.truncated ~gamma
in
let returns = Gae.normalize returns in
(* Stack observations and actions into batch tensors *)
let obs_batch = Nx.stack (Array.to_list traj.observations) in
let actions_batch =
Nx.stack
(Array.to_list (Array.map (fun a -> Nx.reshape [| 1 |] a) traj.actions))
in
let returns_t = Nx.create Nx.float32 [| n |] returns in
(* Policy gradient loss *)
let loss_fn p =
let logits = forward p net_state obs_batch in
let log_probs = Nx.log_softmax logits in
let action_log_probs =
Nx.take_along_axis ~axis:1 actions_batch log_probs
in
let action_log_probs = Nx.reshape [| n |] action_log_probs in
let weighted = Nx.mul action_log_probs returns_t in
Nx.neg (Nx.mean weighted)
in
let loss, grads = Grad.value_and_grad loss_fn !params in
let new_params, new_opt_state = Optim.update !opt_state !params grads in
params := new_params;
opt_state := new_opt_state;
(* Evaluate periodically *)
if update mod eval_interval = 0 then begin
let stats =
Eval.run env ~policy:greedy_policy ~n_episodes:eval_episodes ()
in
Printf.printf
" update %3d loss = %6.3f eval: reward = %5.1f +/- %4.1f\n%!" update
(Nx.item [] loss) stats.mean_reward stats.std_reward;
reward_history.(!eval_idx) <- stats.mean_reward;
incr eval_idx
end
done;
Printf.printf "\n reward: %s\n" (sparkline reward_history);
(* Final evaluation *)
Printf.printf "\nFinal evaluation (50 episodes):\n";
let stats = Eval.run env ~policy:greedy_policy ~n_episodes:50 () in
Printf.printf " mean reward: %5.1f +/- %.1f\n" stats.mean_reward
stats.std_reward;
Printf.printf " mean length: %5.1f\n" stats.mean_length;
if stats.mean_reward >= 475.0 then
Printf.printf "\nSolved! (mean reward >= 475)\n"
else Printf.printf "\nNot solved yet (mean reward < 475).\n";
Env.close env
================================================
FILE: packages/fehu/examples/04-dqn/dune
================================================
(executable
(name main)
(libraries nx rune kaun vega fehu fehu.envs))
================================================
FILE: packages/fehu/examples/04-dqn/main.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* DQN on CartPole-v1.
Deep Q-Network with experience replay and a target network. Epsilon-greedy
exploration decays linearly. The target network is hard-copied every
target_update_interval steps. *)
open Fehu
open Kaun
(* Hyperparameters *)
let buffer_capacity = 50_000
let batch_size = 64
let gamma = 0.99
let lr = 5e-4
let epsilon_start = 1.0
let epsilon_end = 0.05
let epsilon_decay_steps = 10_000
let target_update_interval = 250
let learning_starts = 1000
let n_steps = 50_000
let eval_interval = 2000
let eval_episodes = 20
(* Sparkline *)
let sparkline values =
let blocks =
[|
"\xe2\x96\x81";
"\xe2\x96\x82";
"\xe2\x96\x83";
"\xe2\x96\x84";
"\xe2\x96\x85";
"\xe2\x96\x86";
"\xe2\x96\x87";
"\xe2\x96\x88";
|]
in
let lo = Array.fold_left Float.min Float.infinity values in
let hi = Array.fold_left Float.max Float.neg_infinity values in
let range = hi -. lo in
if range < 1e-9 then
String.concat "" (Array.to_list (Array.map (fun _ -> blocks.(4)) values))
else
String.concat ""
(Array.to_list
(Array.map
(fun v ->
let idx = Float.to_int ((v -. lo) /. range *. 7.0) in
blocks.(max 0 (min 7 idx)))
values))
(* Network *)
let q_network =
Layer.sequential
[
Layer.linear ~in_features:4 ~out_features:128 ();
Layer.relu ();
Layer.linear ~in_features:128 ~out_features:128 ();
Layer.relu ();
Layer.linear ~in_features:128 ~out_features:2 ();
]
(* Forward pass: obs [batch; 4] -> q_values [batch; 2] *)
let forward params net_state obs =
let vars = Layer.make_vars ~params ~state:net_state ~dtype:Nx.float32 in
fst (Layer.apply q_network vars ~training:false obs)
(* Epsilon schedule: linear decay *)
let epsilon step =
let t =
Float.min 1.0 (Float.of_int step /. Float.of_int epsilon_decay_steps)
in
epsilon_start +. (t *. (epsilon_end -. epsilon_start))
(* Copy parameters for the target network *)
let copy_params params = Ptree.map { run = (fun t -> Nx.copy t) } params
(* Main *)
let () =
Printf.printf "DQN on CartPole-v1\n";
Printf.printf "===================\n\n";
Printf.printf
"Network: Linear(4 -> 128) -> ReLU -> Linear(128 -> 128) -> ReLU -> \
Linear(128 -> 2)\n";
Printf.printf "Buffer: %d, batch: %d, gamma = %.2f, lr = %.4f\n"
buffer_capacity batch_size gamma lr;
Printf.printf
"Epsilon: %.2f -> %.2f over %d steps, target update every %d steps\n\n"
epsilon_start epsilon_end epsilon_decay_steps target_update_interval;
Nx.Rng.run ~seed:42 @@ fun () ->
let env = Fehu_envs.Cartpole.make () in
(* Initialize network *)
let vars = Layer.init q_network ~dtype:Nx.float32 in
let params = ref (Layer.params vars) in
let net_state = Layer.state vars in
let target_params = ref (copy_params !params) in
Printf.printf "Parameters: %d\n\n" (Ptree.count_parameters !params);
(* Optimizer *)
let algo = Vega.adam (Vega.Schedule.constant lr) in
let opt_state = ref (Optim.init algo !params) in
(* Replay buffer *)
let buffer = Buffer.create ~capacity:buffer_capacity in
let sample_uniform () =
let t = Nx.rand Nx.float32 [| 1 |] in
(Nx.to_array t : float array).(0)
in
(* Epsilon-greedy action selection *)
let select_action obs eps =
if sample_uniform () < eps then Space.sample (Env.action_space env)
else begin
let obs_batch = Nx.reshape [| 1; 4 |] obs in
let q_values =
Rune.no_grad (fun () -> forward !params net_state obs_batch)
in
let action_idx =
Nx.argmax q_values ~axis:(-1) ~keepdims:false |> Nx.cast Nx.int32
in
Nx.reshape [||] action_idx
end
in
(* Greedy policy for evaluation *)
let greedy_policy obs =
let obs_batch = Nx.reshape [| 1; 4 |] obs in
let q_values =
Rune.no_grad (fun () -> forward !params net_state obs_batch)
in
let action_idx =
Nx.argmax q_values ~axis:(-1) ~keepdims:false |> Nx.cast Nx.int32
in
Nx.reshape [||] action_idx
in
(* Training step *)
let train_step () =
let obs_arr, act_arr, rew_arr, next_obs_arr, term_arr, trunc_arr =
Buffer.sample_arrays buffer ~batch_size
in
let n = Array.length obs_arr in
(* Stack into batch tensors *)
let obs_batch = Nx.stack (Array.to_list obs_arr) in
let next_obs_batch = Nx.stack (Array.to_list next_obs_arr) in
let actions_batch =
Nx.stack
(Array.to_list (Array.map (fun a -> Nx.reshape [| 1 |] a) act_arr))
in
let rewards_t = Nx.create Nx.float32 [| n |] rew_arr in
(* Done mask: 1.0 if not done, 0.0 if done. Truncation is treated like
termination here — a common simplification that slightly biases the
bootstrap target for time-limited episodes. *)
let done_mask =
Array.init n (fun i -> if term_arr.(i) || trunc_arr.(i) then 0.0 else 1.0)
in
let done_mask_t = Nx.create Nx.float32 [| n |] done_mask in
(* Compute TD target with the target network (no gradient):
y = r + gamma * (1 - done) * max_a' Q_target(s', a') *)
let td_target =
Rune.no_grad (fun () ->
let target_q = forward !target_params net_state next_obs_batch in
let max_q = Nx.max target_q ~axes:[ 1 ] ~keepdims:false in
Nx.add rewards_t
(Nx.mul (Nx.scalar Nx.float32 gamma) (Nx.mul max_q done_mask_t)))
in
let td_target = Rune.detach td_target in
(* Loss: MSE between predicted Q and TD target *)
let loss_fn p =
let q_values = forward p net_state obs_batch in
let q_selected = Nx.take_along_axis ~axis:1 actions_batch q_values in
let q_selected = Nx.reshape [| n |] q_selected in
let diff = Nx.sub q_selected td_target in
Nx.mean (Nx.mul diff diff)
in
let loss, grads = Grad.value_and_grad loss_fn !params in
let new_params, new_opt_state = Optim.update !opt_state !params grads in
params := new_params;
opt_state := new_opt_state;
Nx.item [] loss
in
(* Main training loop *)
Printf.printf "Filling buffer (%d steps)...\n\n" learning_starts;
let obs = ref (fst (Env.reset env ())) in
let last_loss = ref 0.0 in
let n_evals = n_steps / eval_interval in
let reward_history = Array.make n_evals 0.0 in
let eval_idx = ref 0 in
Printf.printf "Training...\n\n";
for step = 1 to n_steps do
let eps = epsilon step in
let action = select_action !obs eps in
let s = Env.step env action in
Buffer.add buffer
{
observation = !obs;
action;
reward = s.reward;
next_observation = s.observation;
terminated = s.terminated;
truncated = s.truncated;
};
if s.terminated || s.truncated then obs := fst (Env.reset env ())
else obs := s.observation;
(* Train *)
if step >= learning_starts then begin
last_loss := train_step ();
(* Update target network *)
if step mod target_update_interval = 0 then
target_params := copy_params !params
end;
(* Evaluate periodically *)
if step mod eval_interval = 0 then begin
let stats =
Eval.run env ~policy:greedy_policy ~n_episodes:eval_episodes ()
in
Printf.printf
" step %5d epsilon = %.2f loss = %6.4f eval: reward = %5.1f +/- \
%4.1f\n\
%!"
step eps !last_loss stats.mean_reward stats.std_reward;
reward_history.(!eval_idx) <- stats.mean_reward;
incr eval_idx;
(* Eval.run leaves the env in a done state; reset for training *)
obs := fst (Env.reset env ())
end
done;
Printf.printf "\n reward: %s\n" (sparkline reward_history);
(* Final evaluation *)
Printf.printf "\nFinal evaluation (%d episodes):\n" 50;
let stats = Eval.run env ~policy:greedy_policy ~n_episodes:50 () in
Printf.printf " mean reward: %5.1f +/- %.1f\n" stats.mean_reward
stats.std_reward;
Printf.printf " mean length: %5.1f\n" stats.mean_length;
if stats.mean_reward >= 475.0 then
Printf.printf "\nSolved! (mean reward >= 475)\n"
else Printf.printf "\nNot solved yet (mean reward < 475).\n";
Env.close env
================================================
FILE: packages/fehu/lib/buffer.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let err_capacity = "Buffer.create: capacity must be positive"
let err_empty = "Buffer.sample: buffer is empty"
let err_batch_size = "Buffer.sample: batch_size must be positive"
type ('obs, 'act) transition = {
observation : 'obs;
action : 'act;
reward : float;
next_observation : 'obs;
terminated : bool;
truncated : bool;
}
type ('obs, 'act) t = {
capacity : int;
mutable size : int;
mutable pos : int;
mutable observations : 'obs array;
mutable actions : 'act array;
rewards : float array;
mutable next_observations : 'obs array;
terminateds : bool array;
truncateds : bool array;
}
(* Constructor *)
let create ~capacity =
if capacity <= 0 then invalid_arg err_capacity;
{
capacity;
size = 0;
pos = 0;
observations = [||];
actions = [||];
rewards = Array.make capacity 0.0;
next_observations = [||];
terminateds = Array.make capacity false;
truncateds = Array.make capacity false;
}
(* Mutating *)
let ensure_init buf (tr : _ transition) =
if Array.length buf.observations = 0 then begin
buf.observations <- Array.make buf.capacity tr.observation;
buf.actions <- Array.make buf.capacity tr.action;
buf.next_observations <- Array.make buf.capacity tr.next_observation
end
let add buf tr =
ensure_init buf tr;
buf.observations.(buf.pos) <- tr.observation;
buf.actions.(buf.pos) <- tr.action;
buf.rewards.(buf.pos) <- tr.reward;
buf.next_observations.(buf.pos) <- tr.next_observation;
buf.terminateds.(buf.pos) <- tr.terminated;
buf.truncateds.(buf.pos) <- tr.truncated;
buf.pos <- (buf.pos + 1) mod buf.capacity;
if buf.size < buf.capacity then buf.size <- buf.size + 1
let clear buf =
buf.size <- 0;
buf.pos <- 0
(* Sampling *)
let sample_indices buf ~batch_size =
if buf.size = 0 then invalid_arg err_empty;
if batch_size <= 0 then invalid_arg err_batch_size;
let n = min batch_size buf.size in
let raw = Nx.randint Nx.int32 ~high:buf.size [| n |] 0 in
let idx : Int32.t array = Nx.to_array raw in
(idx, n)
let sample buf ~batch_size =
let idx, n = sample_indices buf ~batch_size in
Array.init n (fun i ->
let j = Int32.to_int idx.(i) in
{
observation = buf.observations.(j);
action = buf.actions.(j);
reward = buf.rewards.(j);
next_observation = buf.next_observations.(j);
terminated = buf.terminateds.(j);
truncated = buf.truncateds.(j);
})
let sample_arrays buf ~batch_size =
let idx, n = sample_indices buf ~batch_size in
let get arr i = arr.(Int32.to_int idx.(i)) in
let observations = Array.init n (get buf.observations) in
let actions = Array.init n (get buf.actions) in
let rewards = Array.init n (get buf.rewards) in
let next_observations = Array.init n (get buf.next_observations) in
let terminated = Array.init n (get buf.terminateds) in
let truncated = Array.init n (get buf.truncateds) in
(observations, actions, rewards, next_observations, terminated, truncated)
(* Queries *)
let size buf = buf.size
let is_full buf = buf.size = buf.capacity
let capacity buf = buf.capacity
================================================
FILE: packages/fehu/lib/buffer.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Replay buffer for off-policy experience storage.
A fixed-capacity circular buffer that stores transitions and supports
uniform random sampling. Observation and action arrays are lazily
initialized on the first {!add}. *)
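(* A usage sketch (illustrative; [obs], [obs'] and [act] are placeholders
for an observation pair and an action, not part of the API):
{[
let buf = Buffer.create ~capacity:1000 in
Buffer.add buf
{ observation = obs; action = act; reward = 1.0;
next_observation = obs'; terminated = false; truncated = false };
(* Sampling is capped at [size buf], so this returns one transition. *)
let batch = Buffer.sample buf ~batch_size:32 in
assert (Array.length batch = 1)
]} *)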
(** {1:types Types} *)
type ('obs, 'act) transition = {
observation : 'obs; (** State before the action. *)
action : 'act; (** Action taken. *)
reward : float; (** Scalar reward received. *)
next_observation : 'obs; (** State after the action. *)
terminated : bool; (** Natural episode ending. *)
truncated : bool; (** Forced episode ending. *)
}
(** The type for transitions. *)
type ('obs, 'act) t
(** A replay buffer of transitions. *)
(** {1:constructors Constructors} *)
val create : capacity:int -> ('obs, 'act) t
(** [create ~capacity] is an empty buffer that holds at most [capacity]
transitions.
Raises [Invalid_argument] if [capacity <= 0]. *)
(** {1:mutating Mutating} *)
val add : ('obs, 'act) t -> ('obs, 'act) transition -> unit
(** [add buf tr] appends [tr], overwriting the oldest transition when at
capacity. *)
val clear : ('obs, 'act) t -> unit
(** [clear buf] removes all transitions, keeping storage allocated. *)
(** {1:sampling Sampling} *)
val sample : ('obs, 'act) t -> batch_size:int -> ('obs, 'act) transition array
(** [sample buf ~batch_size] draws [min batch_size (size buf)] transitions
uniformly at random, with replacement.
Random keys are drawn from the implicit RNG scope.
Raises [Invalid_argument] if [buf] is empty or [batch_size <= 0]. *)
val sample_arrays :
('obs, 'act) t ->
batch_size:int ->
'obs array * 'act array * float array * 'obs array * bool array * bool array
(** [sample_arrays buf ~batch_size] is like {!sample} but returns
structure-of-arrays
[(observations, actions, rewards, next_observations, terminated, truncated)]
for direct use in training loops. *)
(** {1:queries Queries} *)
val size : ('obs, 'act) t -> int
(** [size buf] is the number of stored transitions. *)
val is_full : ('obs, 'act) t -> bool
(** [is_full buf] is [true] iff [size buf = capacity]. *)
val capacity : ('obs, 'act) t -> int
(** [capacity buf] is the maximum number of transitions [buf] can hold. *)
================================================
FILE: packages/fehu/lib/collect.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let err_concat_empty = "Collect.concat: empty list"
type ('obs, 'act) t = {
observations : 'obs array;
actions : 'act array;
rewards : float array;
next_observations : 'obs array;
terminated : bool array;
truncated : bool array;
infos : Info.t array;
log_probs : float array option;
values : float array option;
}
let length t = Array.length t.observations
(* Concatenation *)
let concat_opt_field ts get =
if List.for_all (fun t -> Option.is_some (get t)) ts then
Some (Array.concat (List.map (fun t -> Option.get (get t)) ts))
else None
let concat = function
| [] -> invalid_arg err_concat_empty
| [ t ] -> t
| ts ->
{
observations = Array.concat (List.map (fun t -> t.observations) ts);
actions = Array.concat (List.map (fun t -> t.actions) ts);
rewards = Array.concat (List.map (fun t -> t.rewards) ts);
next_observations =
Array.concat (List.map (fun t -> t.next_observations) ts);
terminated = Array.concat (List.map (fun t -> t.terminated) ts);
truncated = Array.concat (List.map (fun t -> t.truncated) ts);
infos = Array.concat (List.map (fun t -> t.infos) ts);
log_probs = concat_opt_field ts (fun t -> t.log_probs);
values = concat_opt_field ts (fun t -> t.values);
}
(* Accumulator for building trajectories *)
type ('obs, 'act) acc = {
mutable obs : 'obs list;
mutable acts : 'act list;
mutable rews : float list;
mutable next_obs : 'obs list;
mutable terms : bool list;
mutable truncs : bool list;
mutable infos_acc : Info.t list;
mutable lps : float list;
mutable vals : float list;
mutable count : int;
}
let create_acc () =
{
obs = [];
acts = [];
rews = [];
next_obs = [];
terms = [];
truncs = [];
infos_acc = [];
lps = [];
vals = [];
count = 0;
}
let acc_step acc ~current_obs ~action ~lp_opt ~v_opt (s : _ Env.step) =
acc.obs <- current_obs :: acc.obs;
acc.acts <- action :: acc.acts;
acc.rews <- s.reward :: acc.rews;
acc.next_obs <- s.observation :: acc.next_obs;
acc.terms <- s.terminated :: acc.terms;
acc.truncs <- s.truncated :: acc.truncs;
acc.infos_acc <- s.info :: acc.infos_acc;
(match lp_opt with Some lp -> acc.lps <- lp :: acc.lps | None -> ());
(match v_opt with Some v -> acc.vals <- v :: acc.vals | None -> ());
acc.count <- acc.count + 1
let acc_to_trajectory acc =
let n = acc.count in
let log_probs =
if List.length acc.lps = n then Some (Array.of_list (List.rev acc.lps))
else None
in
let values =
if List.length acc.vals = n then Some (Array.of_list (List.rev acc.vals))
else None
in
{
observations = Array.of_list (List.rev acc.obs);
actions = Array.of_list (List.rev acc.acts);
rewards = Array.of_list (List.rev acc.rews);
next_observations = Array.of_list (List.rev acc.next_obs);
terminated = Array.of_list (List.rev acc.terms);
truncated = Array.of_list (List.rev acc.truncs);
infos = Array.of_list (List.rev acc.infos_acc);
log_probs;
values;
}
(* Collecting *)
let rollout env ~policy ~n_steps =
let acc = create_acc () in
let obs, _info = Env.reset env () in
let current_obs = ref obs in
while acc.count < n_steps do
let action, lp_opt, v_opt = policy !current_obs in
let s = Env.step env action in
acc_step acc ~current_obs:!current_obs ~action ~lp_opt ~v_opt s;
current_obs := s.observation;
if s.terminated || s.truncated then begin
let obs, _info = Env.reset env () in
current_obs := obs
end
done;
acc_to_trajectory acc
let episodes env ~policy ~n_episodes ?(max_steps = 1000) () =
let eps = ref [] in
for _ = 1 to n_episodes do
let acc = create_acc () in
let obs, _info = Env.reset env () in
let current_obs = ref obs in
let done_flag = ref false in
while acc.count < max_steps && not !done_flag do
let action, lp_opt, v_opt = policy !current_obs in
let s = Env.step env action in
acc_step acc ~current_obs:!current_obs ~action ~lp_opt ~v_opt s;
current_obs := s.observation;
done_flag := s.terminated || s.truncated
done;
eps := acc_to_trajectory acc :: !eps
done;
List.rev !eps
================================================
FILE: packages/fehu/lib/collect.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Trajectory collection from environments.
Collects sequential agent-environment interactions into structure-of-arrays
form for batch processing. Handles automatic resets on episode boundaries
and records both the current and next observation at each timestep. *)
(** {1:types Types} *)
type ('obs, 'act) t = {
observations : 'obs array; (** States before each action. *)
actions : 'act array; (** Actions taken. *)
rewards : float array; (** Scalar rewards received. *)
next_observations : 'obs array; (** States after each action. *)
terminated : bool array; (** Natural episode endings. *)
truncated : bool array; (** Forced episode endings. *)
infos : Info.t array; (** Per-step metadata. *)
log_probs : float array option; (** Policy log-probabilities. *)
values : float array option; (** Value estimates. *)
}
(** The type for trajectories. All arrays have the same length. Optional fields
are [None] when the policy does not provide them. *)
(** {1:accessors Accessors} *)
val length : ('obs, 'act) t -> int
(** [length traj] is the number of transitions in [traj]. *)
(** {1:combining Combining} *)
val concat : ('obs, 'act) t list -> ('obs, 'act) t
(** [concat trajs] concatenates [trajs] into a single trajectory. Optional
fields are kept only if present in all inputs.
Raises [Invalid_argument] if [trajs] is empty. *)
(** {1:collecting Collecting} *)
val rollout :
('obs, 'act, 'render) Env.t ->
policy:('obs -> 'act * float option * float option) ->
n_steps:int ->
('obs, 'act) t
(** [rollout env ~policy ~n_steps] collects [n_steps] transitions.
Resets [env] at the start and automatically on episode boundaries
(terminated or truncated). The [policy] receives the current observation and
returns [(action, log_prob_opt, value_opt)]. *)
val episodes :
('obs, 'act, 'render) Env.t ->
policy:('obs -> 'act * float option * float option) ->
n_episodes:int ->
?max_steps:int ->
unit ->
('obs, 'act) t list
(** [episodes env ~policy ~n_episodes ()] collects complete episodes, one
trajectory per episode. Each episode runs until termination, truncation, or
[max_steps] (default [1000]). *)
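(* A usage sketch (illustrative; [env] and [choose_action] are placeholders
for an environment and an action-selection function):
{[
(* A policy that reports neither log-probabilities nor values. *)
let policy obs = (choose_action obs, None, None) in
let traj = Collect.rollout env ~policy ~n_steps:2048 in
assert (Collect.length traj = 2048);
(* [log_probs] and [values] are [None] since the policy omits them. *)
assert (traj.Collect.log_probs = None)
]} *)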
================================================
FILE: packages/fehu/lib/dune
================================================
(library
(name fehu)
(public_name fehu)
(libraries nx))
================================================
FILE: packages/fehu/lib/env.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let strf = Printf.sprintf
(* Error messages *)
let err_closed op = strf "Env: operation '%s' on a closed environment" op
let err_needs_reset op =
strf "Env: operation '%s' requires calling reset first" op
let err_render_mode mode modes =
strf "Env.create: render mode '%s' not in render_modes [%s]" mode
(String.concat "; " modes)
let err_obs_reset value =
strf "Env.reset: observation outside observation_space (value=%s)" value
let err_obs_step value =
strf "Env.step: observation outside observation_space (value=%s)" value
let err_action value =
strf "Env.step: action outside action_space (value=%s)" value
(* Step result *)
type 'obs step = {
observation : 'obs;
reward : float;
terminated : bool;
truncated : bool;
info : Info.t;
}
let step_result ~observation ?(reward = 0.) ?(terminated = false)
?(truncated = false) ?(info = Info.empty) () =
{ observation; reward; terminated; truncated; info }
(* Render mode *)
type render_mode = [ `Human | `Rgb_array | `Ansi | `Svg | `Custom of string ]
let render_mode_to_string = function
| `Human -> "human"
| `Rgb_array -> "rgb_array"
| `Ansi -> "ansi"
| `Svg -> "svg"
| `Custom name -> name
(* Shared mutable state *)
type shared = { mutable closed : bool; mutable needs_reset : bool }
(* Environment *)
type ('obs, 'act, 'render) t = {
id : string option;
observation_space : 'obs Space.t;
action_space : 'act Space.t;
render_mode : render_mode option;
render_modes : string list;
shared : shared;
reset_fn : ?options:Info.t -> unit -> 'obs * Info.t;
step_fn : 'act -> 'obs step;
render_fn : unit -> 'render option;
close_fn : unit -> unit;
}
(* Lifecycle guards *)
let ensure_open shared op = if shared.closed then invalid_arg (err_closed op)
let ensure_reset shared op =
if shared.needs_reset then invalid_arg (err_needs_reset op)
(* Constructor *)
let create ?id ~observation_space ~action_space ?render_mode
?(render_modes = []) ~reset ~step ?render ?close () =
(match render_mode with
| None -> ()
| Some mode ->
let mode_s = render_mode_to_string mode in
if not (List.mem mode_s render_modes) then
invalid_arg (err_render_mode mode_s render_modes));
let shared = { closed = false; needs_reset = true } in
let render_fn = Option.value render ~default:(fun () -> None) in
let close_fn = Option.value close ~default:(fun () -> ()) in
let rec env =
{
id;
observation_space;
action_space;
render_mode;
render_modes;
shared;
reset_fn = (fun ?options () -> reset env ?options ());
step_fn = (fun action -> step env action);
render_fn;
close_fn;
}
in
env
(* Wrap *)
let wrap ?id ~observation_space ~action_space ?render_mode ~reset ~step ?render
?close inner =
let render_mode =
match render_mode with Some _ -> render_mode | None -> inner.render_mode
in
let render_fn =
match render with
| Some f -> fun () -> f inner
| None -> fun () -> inner.render_fn ()
in
let close_fn =
match close with
| Some f -> fun () -> f inner
| None -> fun () -> inner.close_fn ()
in
{
id;
observation_space;
action_space;
render_mode;
render_modes = inner.render_modes;
shared = inner.shared;
reset_fn = (fun ?options () -> reset inner ?options ());
step_fn = (fun action -> step inner action);
render_fn;
close_fn;
}
(* Accessors *)
let id env = env.id
let observation_space env = env.observation_space
let action_space env = env.action_space
let render_mode env = env.render_mode
(* Human render helper *)
let maybe_human_render env =
match env.render_mode with
| Some `Human -> ignore (env.render_fn ())
| _ -> ()
(* Lifecycle — all guards live here *)
let closed env = env.shared.closed
let reset env ?options () =
ensure_open env.shared "reset";
let observation, info = env.reset_fn ?options () in
if not (Space.contains env.observation_space observation) then
invalid_arg
(err_obs_reset
(Space.pack env.observation_space observation |> Value.to_string));
env.shared.needs_reset <- false;
maybe_human_render env;
(observation, info)
let step env action =
ensure_open env.shared "step";
ensure_reset env.shared "step";
if not (Space.contains env.action_space action) then
invalid_arg
(err_action (Space.pack env.action_space action |> Value.to_string));
let result = env.step_fn action in
if not (Space.contains env.observation_space result.observation) then
invalid_arg
(err_obs_step
(Space.pack env.observation_space result.observation |> Value.to_string));
if result.terminated || result.truncated then env.shared.needs_reset <- true;
maybe_human_render env;
result
let render env =
ensure_open env.shared "render";
env.render_fn ()
let close env =
if not env.shared.closed then begin
env.close_fn ();
env.shared.closed <- true;
env.shared.needs_reset <- true
end
(* Wrapper helpers *)
let err_clip_bounds = "Env.clip_action: mismatched low/high bounds"
let err_clip_obs_bounds = "Env.clip_observation: mismatched low/high bounds"
let err_time_limit = "Env.time_limit: max_episode_steps must be positive"
let derive_id env suffix =
match env.id with None -> None | Some id -> Some (id ^ suffix)
let clamp_tensor ~low ~high tensor =
let data = Nx.to_array tensor in
let clipped = Array.copy data in
let upper = Array.length clipped - 1 in
for idx = 0 to upper do
let lo = low.(idx) in
let hi = high.(idx) in
let v = clipped.(idx) in
if v < lo then clipped.(idx) <- lo else if v > hi then clipped.(idx) <- hi
done;
Nx.create Nx.float32 (Nx.shape tensor) clipped
(* Wrappers *)
let map_observation ~observation_space ~f env =
wrap
?id:(derive_id env "/ObservationWrapper")
~observation_space ~action_space:env.action_space
~reset:(fun inner ?options () ->
let obs, info = reset inner ?options () in
f obs info)
~step:(fun inner action ->
let s = step inner action in
let obs, info = f s.observation s.info in
{ s with observation = obs; info })
env
let map_action ~action_space ~f env =
wrap
?id:(derive_id env "/ActionWrapper")
~observation_space:env.observation_space ~action_space
~reset:(fun inner ?options () -> reset inner ?options ())
~step:(fun inner action -> step inner (f action))
env
let map_reward ~f env =
wrap
?id:(derive_id env "/RewardWrapper")
~observation_space:env.observation_space ~action_space:env.action_space
~reset:(fun inner ?options () -> reset inner ?options ())
~step:(fun inner action ->
let s = step inner action in
let reward, info = f ~reward:s.reward ~info:s.info in
{ s with reward; info })
env
(* Clipping *)
let clip_action env =
let low, high = Space.Box.bounds env.action_space in
let element_count = Array.length low in
if Array.length high <> element_count then invalid_arg err_clip_bounds;
let relaxed_low =
Array.init element_count (fun i ->
if Float.equal low.(i) high.(i) then low.(i) else Float.neg_infinity)
in
let relaxed_high =
Array.init element_count (fun i ->
if Float.equal low.(i) high.(i) then high.(i) else Float.infinity)
in
let relaxed_space = Space.Box.create ~low:relaxed_low ~high:relaxed_high in
map_action ~action_space:relaxed_space
~f:(fun action -> clamp_tensor ~low ~high action)
env
let clip_observation ~low ~high env =
let inner_low, inner_high = Space.Box.bounds env.observation_space in
let n = Array.length low in
if Array.length high <> n then invalid_arg err_clip_obs_bounds;
if Array.length inner_low <> n then invalid_arg err_clip_obs_bounds;
let clamp_low = Array.init n (fun i -> Float.max low.(i) inner_low.(i)) in
let clamp_high = Array.init n (fun i -> Float.min high.(i) inner_high.(i)) in
let observation_space = Space.Box.create ~low:clamp_low ~high:clamp_high in
map_observation ~observation_space
~f:(fun obs info ->
(clamp_tensor ~low:clamp_low ~high:clamp_high obs, info))
env
(* Limits *)
let time_limit ~max_episode_steps env =
if max_episode_steps <= 0 then invalid_arg err_time_limit;
let steps = ref 0 in
let add_info info elapsed =
info
|> Info.set "time_limit.truncated" (Info.bool true)
|> Info.set "time_limit.elapsed_steps" (Info.int elapsed)
in
wrap
?id:(derive_id env "/TimeLimit")
~observation_space:env.observation_space ~action_space:env.action_space
~reset:(fun inner ?options () ->
steps := 0;
reset inner ?options ())
~step:(fun inner action ->
incr steps;
let s = step inner action in
if s.terminated || s.truncated then begin
steps := 0;
s
end
else if !steps >= max_episode_steps then begin
let info = add_info s.info !steps in
steps := 0;
{ s with truncated = true; info }
end
else s)
env
================================================
FILE: packages/fehu/lib/env.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Reinforcement learning environments.
An environment defines an interactive loop: the agent observes, acts, and
receives a reward. The environment enforces a lifecycle: {!reset} must be
called before {!step}, and a terminated or truncated episode requires
another {!reset}. *)
(** {1:step Step results} *)
type 'obs step = {
observation : 'obs; (** The observation after the action. *)
reward : float; (** Scalar reward for the transition. *)
terminated : bool; (** [true] when the episode ends naturally. *)
truncated : bool; (** [true] when the episode is cut short. *)
info : Info.t; (** Auxiliary metadata. *)
}
(** The type for step results. *)
val step_result :
observation:'obs ->
?reward:float ->
?terminated:bool ->
?truncated:bool ->
?info:Info.t ->
unit ->
'obs step
(** [step_result ~observation ()] constructs a step result. [reward] defaults to
[0.], [terminated] and [truncated] default to [false], [info] defaults to
{!Info.empty}. *)
(** {1:render Render modes} *)
type render_mode = [ `Human | `Rgb_array | `Ansi | `Svg | `Custom of string ]
(** Rendering modes supported by environments. *)
val render_mode_to_string : render_mode -> string
(** [render_mode_to_string m] is the string representation of [m]. *)
(** {1:env Environments} *)
type ('obs, 'act, 'render) t
(** Environment handle. Use {!create} or {!wrap} to construct. *)
val create :
?id:string ->
observation_space:'obs Space.t ->
action_space:'act Space.t ->
?render_mode:render_mode ->
?render_modes:string list ->
reset:(('obs, 'act, 'render) t -> ?options:Info.t -> unit -> 'obs * Info.t) ->
step:(('obs, 'act, 'render) t -> 'act -> 'obs step) ->
?render:(unit -> 'render option) ->
?close:(unit -> unit) ->
unit ->
('obs, 'act, 'render) t
(** [create ~observation_space ~action_space ~reset ~step ()] makes a new
environment.
[reset] and [step] receive the environment handle as first argument. Random
keys for stochastic behavior are drawn from the implicit RNG scope.
[render_modes] lists the supported render mode strings. When [render_mode]
is provided, it must appear in [render_modes].
Raises [Invalid_argument] if [render_mode] is not in [render_modes]. *)
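(* A construction sketch (illustrative): a trivial one-step environment with
a single float observation and two discrete actions.
{[
let zero_obs = Nx.create Nx.float32 [| 1 |] [| 0. |] in
let env =
Env.create ~id:"Trivial-v0"
~observation_space:(Space.Box.create ~low:[| 0. |] ~high:[| 1. |])
~action_space:(Space.Discrete.create 2)
~reset:(fun _env ?options:_ () -> (zero_obs, Info.empty))
~step:(fun _env _action ->
Env.step_result ~observation:zero_obs ~reward:1.0 ~terminated:true ())
()
]} *)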
val wrap :
?id:string ->
observation_space:'obs2 Space.t ->
action_space:'act2 Space.t ->
?render_mode:render_mode ->
reset:(('obs1, 'act1, 'render) t -> ?options:Info.t -> unit -> 'obs2 * Info.t) ->
step:(('obs1, 'act1, 'render) t -> 'act2 -> 'obs2 step) ->
?render:(('obs1, 'act1, 'render) t -> 'render option) ->
?close:(('obs1, 'act1, 'render) t -> unit) ->
('obs1, 'act1, 'render) t ->
('obs2, 'act2, 'render) t
(** [wrap ~observation_space ~action_space ~reset ~step inner] builds a new
environment that wraps [inner]. The wrapper shares [inner]'s lifecycle state
(the closed and needs-reset flags). All guards (closed, needs-reset, space
validation) are enforced by {!reset}/{!step}, so wrappers get them
automatically.
The render type is preserved from [inner]. [render_mode] defaults to
[inner]'s. *)
(** {1:accessors Accessors} *)
val id : ('obs, 'act, 'render) t -> string option
(** [id env] is the environment's identifier, if any. *)
val observation_space : ('obs, 'act, 'render) t -> 'obs Space.t
(** [observation_space env] is the space of valid observations. *)
val action_space : ('obs, 'act, 'render) t -> 'act Space.t
(** [action_space env] is the space of valid actions. *)
val render_mode : ('obs, 'act, 'render) t -> render_mode option
(** [render_mode env] is the render mode chosen at construction, if any. *)
(** {1:lifecycle Lifecycle} *)
val closed : ('obs, 'act, 'render) t -> bool
(** [closed env] is [true] iff the environment has been closed. *)
val reset : ('obs, 'act, 'render) t -> ?options:Info.t -> unit -> 'obs * Info.t
(** [reset env ()] resets the environment to an initial state.
Raises [Invalid_argument] if [env] is closed, or if the reset function
produces an observation outside {!observation_space}. *)
val step : ('obs, 'act, 'render) t -> 'act -> 'obs step
(** [step env action] advances the environment by one timestep.
Raises [Invalid_argument] if [env] is closed, if no {!reset} has been called
since the last terminal step, if [action] is outside {!action_space}, or if
the step function produces an observation outside {!observation_space}. *)
val render : ('obs, 'act, 'render) t -> 'render option
(** [render env] produces a visualization of the current state.
Raises [Invalid_argument] if [env] is closed. *)
val close : ('obs, 'act, 'render) t -> unit
(** [close env] releases resources held by the environment. Subsequent calls are
no-ops. *)
(** {1:wrappers Wrappers} *)
val map_observation :
observation_space:'obs2 Space.t ->
f:('obs1 -> Info.t -> 'obs2 * Info.t) ->
('obs1, 'act, 'render) t ->
('obs2, 'act, 'render) t
(** [map_observation ~observation_space ~f env] transforms observations. Every
observation from {!reset} and {!step} is passed through [f] together with
the info dictionary. *)
val map_action :
action_space:'act2 Space.t ->
f:('act2 -> 'act1) ->
('obs, 'act1, 'render) t ->
('obs, 'act2, 'render) t
(** [map_action ~action_space ~f env] transforms actions before passing them to
the inner environment. *)
val map_reward :
f:(reward:float -> info:Info.t -> float * Info.t) ->
('obs, 'act, 'render) t ->
('obs, 'act, 'render) t
(** [map_reward ~f env] transforms rewards after each step. *)
(** {1:clip Clipping} *)
val clip_action :
('obs, Space.Box.element, 'render) t -> ('obs, Space.Box.element, 'render) t
(** [clip_action env] clamps continuous actions to the bounds of the inner
environment's {!Space.Box} action space. The wrapper exposes a relaxed space
that accepts any float values, then clips before forwarding. *)
val clip_observation :
low:float array ->
high:float array ->
(Space.Box.element, 'act, 'render) t ->
(Space.Box.element, 'act, 'render) t
(** [clip_observation ~low ~high env] clamps observations to \[[low]; [high]\].
The wrapper's observation space is the intersection of the provided bounds
and the inner space's bounds.
Raises [Invalid_argument] if [low] and [high] differ in length or do not
match the inner space's dimensionality. *)
(** {1:limits Limits} *)
val time_limit :
max_episode_steps:int -> ('obs, 'act, 'render) t -> ('obs, 'act, 'render) t
(** [time_limit ~max_episode_steps env] enforces a maximum episode length. When
the limit is reached the step's [truncated] flag is set to [true]. The
counter resets on {!reset}.
Raises [Invalid_argument] if [max_episode_steps <= 0]. *)
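(* A wrapper-composition sketch (illustrative): scale rewards and cap the
episode length of a CartPole environment.
{[
let env =
Fehu_envs.Cartpole.make ()
|> Env.map_reward ~f:(fun ~reward ~info -> (reward *. 0.1, info))
|> Env.time_limit ~max_episode_steps:200
]} *)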
================================================
FILE: packages/fehu/lib/envs/cartpole.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
open Fehu
type obs = (float, Nx.float32_elt) Nx.t
type act = (int32, Nx.int32_elt) Nx.t
type render = string
(* Physics constants matching Gymnasium CartPole-v1 *)
let gravity = 9.8
let masscart = 1.0
let masspole = 0.1
let total_mass = masscart +. masspole
let half_pole_length = 0.5
let polemass_length = masspole *. half_pole_length
let force_mag = 10.0
let tau = 0.02
(* Termination thresholds *)
let theta_threshold = 12. *. Float.pi /. 180.
let x_threshold = 2.4
let max_steps = 500
(* Float32-representable large bound for "unbounded" dimensions *)
let f32_max = 3.4028235e38
let observation_space =
Space.Box.create
~low:[| -4.8; -.f32_max; -.theta_threshold *. 2.; -.f32_max |]
~high:[| 4.8; f32_max; theta_threshold *. 2.; f32_max |]
let action_space = Space.Discrete.create 2
let make_obs x x_dot theta theta_dot =
Nx.create Nx.float32 [| 4 |] [| x; x_dot; theta; theta_dot |]
let make ?render_mode () =
let x = ref 0.0 in
let x_dot = ref 0.0 in
let theta = ref 0.0 in
let theta_dot = ref 0.0 in
let steps = ref 0 in
let reset _env ?options:_ () =
let random_state () =
let r = Nx.rand Nx.float32 [| 1 |] in
let v = (Nx.to_array r).(0) in
(v -. 0.5) *. 0.1
in
x := random_state ();
x_dot := random_state ();
theta := random_state ();
theta_dot := random_state ();
steps := 0;
(make_obs !x !x_dot !theta !theta_dot, Info.empty)
in
let step _env action =
let force =
if Space.Discrete.to_int action = 1 then force_mag else -.force_mag
in
let costheta = cos !theta in
let sintheta = sin !theta in
let temp =
(force +. (polemass_length *. !theta_dot *. !theta_dot *. sintheta))
/. total_mass
in
let thetaacc =
((gravity *. sintheta) -. (costheta *. temp))
/. (half_pole_length
*. ((4.0 /. 3.0) -. (masspole *. costheta *. costheta /. total_mass)))
in
let xacc =
temp -. (polemass_length *. thetaacc *. costheta /. total_mass)
in
x := !x +. (tau *. !x_dot);
x_dot := !x_dot +. (tau *. xacc);
theta := !theta +. (tau *. !theta_dot);
theta_dot := !theta_dot +. (tau *. thetaacc);
incr steps;
let terminated =
!x < -.x_threshold || !x > x_threshold || !theta < -.theta_threshold
|| !theta > theta_threshold
in
let truncated = (not terminated) && !steps >= max_steps in
let reward = if terminated then 0.0 else 1.0 in
let info = Info.set "steps" (Info.int !steps) Info.empty in
Env.step_result
~observation:(make_obs !x !x_dot !theta !theta_dot)
~reward ~terminated ~truncated ~info ()
in
let render () =
Some
(Printf.sprintf
"CartPole: x=%.3f, x_dot=%.3f, theta=%.3f\xc2\xb0, theta_dot=%.3f, \
steps=%d"
!x !x_dot
(!theta *. 180. /. Float.pi)
!theta_dot !steps)
in
Env.create ?render_mode ~render_modes:[ "ansi" ] ~id:"CartPole-v1"
~observation_space ~action_space ~reset ~step ~render ()
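As a sketch of how the environment above is driven, the loop below runs one episode with a naive lean-following policy. [Space.Discrete.of_int] is an assumed constructor mirroring [to_int] (only [to_int] appears in this file); action construction may differ in the real API:

```ocaml
(* Sketch: one CartPole episode with a heuristic that pushes in the
   direction the pole leans. [Space.Discrete.of_int] is assumed. *)
open Fehu

let () =
  let env = Fehu_envs.Cartpole.make () in
  let obs, _info = Env.reset env () in
  let rec loop obs total =
    (* observation layout: [| x; x_dot; theta; theta_dot |] *)
    let theta = (Nx.to_array obs).(2) in
    let action = Space.Discrete.of_int (if theta > 0.0 then 1 else 0) in
    let s = Env.step env action in
    let total = total +. s.Env.reward in
    if s.Env.terminated || s.Env.truncated then total
    else loop s.Env.observation total
  in
  Printf.printf "episode return: %.1f\n" (loop obs 0.0)
```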
================================================
FILE: packages/fehu/lib/envs/dune
================================================
(library
(name fehu_envs)
(public_name fehu.envs)
(libraries fehu nx))
================================================
FILE: packages/fehu/lib/envs/fehu_envs.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
module Random_walk = Random_walk
module Cartpole = Cartpole
module Grid_world = Grid_world
module Mountain_car = Mountain_car
================================================
FILE: packages/fehu/lib/envs/fehu_envs.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Built-in environments for testing and learning.
Four environments covering the standard RL benchmarks: a simple 1D walk, the
classic cart-pole, a grid navigation problem, and the sparse-reward mountain
car. All follow the {!Fehu.Env} interface. *)
(** {1:envs Environments} *)
module Random_walk : sig
(** One-dimensional random walk.
The agent moves left or right on a line bounded by \[[-10]; [10]\]. Reward
is [- |position|]. Episodes terminate when the agent reaches a boundary and
truncate after 200 steps.
{b Observation}: {!Fehu.Space.Box} of shape [[1]] in \[[-10.0]; [10.0]\].
{b Actions}: {!Fehu.Space.Discrete} 2 -- 0 = left, 1 = right.
{b Render modes}: [ansi]. *)
type obs = (float, Nx.float32_elt) Nx.t
type act = (int32, Nx.int32_elt) Nx.t
type render = string
val make :
?render_mode:Fehu.Env.render_mode -> unit -> (obs, act, render) Fehu.Env.t
(** [make ()] is a random walk environment. *)
end
module Cartpole : sig
(** Classic cart-pole balancing (CartPole-v1).
A pole is attached to a cart on a frictionless track. The agent pushes the
cart left or right to keep the pole upright. Reward is [+1.0] per step
while the pole stays up. The episode terminates when the pole tilts more
than +/-12 degrees or the cart leaves the +/-2.4 position range, and
truncates at 500 steps.
{b Observation}: {!Fehu.Space.Box} of shape [[4]] -- [x], [x_dot],
[theta], [theta_dot].
{b Actions}: {!Fehu.Space.Discrete} 2 -- 0 = push left, 1 = push right.
{b Render modes}: [ansi]. *)
type obs = (float, Nx.float32_elt) Nx.t
type act = (int32, Nx.int32_elt) Nx.t
type render = string
val make :
?render_mode:Fehu.Env.render_mode -> unit -> (obs, act, render) Fehu.Env.t
(** [make ()] is a cart-pole environment. *)
end
module Grid_world : sig
(** 5x5 grid navigation with obstacle.
The agent starts at [(0, 0)] and must reach the goal at [(4, 4)]. An
obstacle at [(2, 2)] blocks movement. Reward is [+10.0] on reaching the
goal, [-1.0] otherwise. Truncates at 200 steps.
{b Observation}: {!Fehu.Space.Multi_discrete} [[5; 5]] -- [(row, col)].
{b Actions}: {!Fehu.Space.Discrete} 4 -- 0 = up, 1 = down, 2 = left, 3 =
right.
{b Render modes}: [ansi], [rgb_array]. *)
type obs = (int32, Nx.int32_elt) Nx.t
type act = (int32, Nx.int32_elt) Nx.t
type render = Text of string | Image of Fehu.Render.image
val make :
?render_mode:Fehu.Env.render_mode -> unit -> (obs, act, render) Fehu.Env.t
(** [make ()] is a grid world environment. *)
end
module Mountain_car : sig
(** Mountain car with sparse reward (MountainCar-v0).
A car sits in a valley between two hills. The engine is too weak to climb
the right hill directly; the agent must build momentum by rocking back and
forth. Reward is [-1.0] per step. The episode terminates when the car
reaches position >= 0.5 with non-negative velocity, and truncates at 200
steps.
{b Observation}: {!Fehu.Space.Box} of shape [[2]] -- [position],
[velocity].
{b Actions}: {!Fehu.Space.Discrete} 3 -- 0 = push left, 1 = coast, 2 =
push right.
{b Render modes}: [ansi]. *)
type obs = (float, Nx.float32_elt) Nx.t
type act = (int32, Nx.int32_elt) Nx.t
type render = string
val make :
?render_mode:Fehu.Env.render_mode -> unit -> (obs, act, render) Fehu.Env.t
(** [make ()] is a mountain car environment. *)
end
================================================
FILE: packages/fehu/lib/envs/grid_world.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
open Fehu
type obs = (int32, Nx.int32_elt) Nx.t
type act = (int32, Nx.int32_elt) Nx.t
type render = Text of string | Image of Render.image
let grid_size = 5
let max_steps = 200
let observation_space = Space.Multi_discrete.create [| grid_size; grid_size |]
let action_space = Space.Discrete.create 4
let is_goal row col = row = grid_size - 1 && col = grid_size - 1
let is_obstacle row col = row = 2 && col = 2
let is_valid row col =
row >= 0 && row < grid_size && col >= 0 && col < grid_size
&& not (is_obstacle row col)
let make_obs row col =
Nx.create Nx.int32 [| 2 |] [| Int32.of_int row; Int32.of_int col |]
(* ANSI rendering *)
let render_text row col =
let buffer = Bytes.make (grid_size * grid_size) '.' in
Bytes.set buffer ((row * grid_size) + col) 'A';
Bytes.set buffer (((grid_size - 1) * grid_size) + (grid_size - 1)) 'G';
Bytes.set buffer ((2 * grid_size) + 2) '#';
let rows =
List.init grid_size (fun r ->
Bytes.sub_string buffer (r * grid_size) grid_size)
in
Format.asprintf "Position: (%d, %d)@.%a" row col
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@.")
Format.pp_print_string)
rows
(* RGB rendering *)
let cell_size = 32
let frame_width = grid_size * cell_size
let frame_height = grid_size * cell_size
let fill_rect data ~x0 ~y0 ~w ~h ~r ~g ~b =
for dy = 0 to h - 1 do
let row_offset = (y0 + dy) * frame_width * 3 in
for dx = 0 to w - 1 do
let base = row_offset + ((x0 + dx) * 3) in
Bigarray.Array1.unsafe_set data base r;
Bigarray.Array1.unsafe_set data (base + 1) g;
Bigarray.Array1.unsafe_set data (base + 2) b
done
done
let render_image row col =
let len = frame_width * frame_height * 3 in
let data =
Bigarray.Array1.create Bigarray.int8_unsigned Bigarray.c_layout len
in
fill_rect data ~x0:0 ~y0:0 ~w:frame_width ~h:frame_height ~r:30 ~g:33 ~b:36;
for gr = 0 to grid_size - 1 do
for gc = 0 to grid_size - 1 do
let x0 = gc * cell_size in
let y0 = gr * cell_size in
fill_rect data ~x0 ~y0 ~w:cell_size ~h:cell_size ~r:44 ~g:48 ~b:52;
fill_rect data ~x0:(x0 + 1) ~y0:(y0 + 1) ~w:(cell_size - 2)
~h:(cell_size - 2) ~r:54 ~g:60 ~b:65
done
done;
let draw_cell cr cc ~r ~g ~b =
fill_rect data
~x0:((cc * cell_size) + 2)
~y0:((cr * cell_size) + 2)
~w:(cell_size - 4) ~h:(cell_size - 4) ~r ~g ~b
in
draw_cell row col ~r:78 ~g:162 ~b:196;
draw_cell (grid_size - 1) (grid_size - 1) ~r:76 ~g:175 ~b:80;
draw_cell 2 2 ~r:200 ~g:80 ~b:80;
Render.image ~width:frame_width ~height:frame_height data
let make ?render_mode () =
let row = ref 0 in
let col = ref 0 in
let steps = ref 0 in
let reset _env ?options:_ () =
row := 0;
col := 0;
steps := 0;
(make_obs 0 0, Info.empty)
in
let step _env action =
let r, c = (!row, !col) in
let nr, nc =
match Space.Discrete.to_int action with
| 0 -> (r - 1, c)
| 1 -> (r + 1, c)
| 2 -> (r, c - 1)
| 3 -> (r, c + 1)
| _ -> (r, c)
in
let nr, nc = if is_valid nr nc then (nr, nc) else (r, c) in
row := nr;
col := nc;
incr steps;
let terminated = is_goal nr nc in
let truncated = (not terminated) && !steps >= max_steps in
let reward = if terminated then 10.0 else -1.0 in
let info = Info.set "steps" (Info.int !steps) Info.empty in
Env.step_result ~observation:(make_obs nr nc) ~reward ~terminated ~truncated
~info ()
in
let render () =
match render_mode with
| Some `Rgb_array -> Some (Image (render_image !row !col))
| _ -> Some (Text (render_text !row !col))
in
in
Env.create ?render_mode ~render_modes:[ "ansi"; "rgb_array" ]
~id:"GridWorld-v0" ~observation_space ~action_space ~reset ~step ~render ()
================================================
FILE: packages/fehu/lib/envs/mountain_car.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
open Fehu
type obs = (float, Nx.float32_elt) Nx.t
type act = (int32, Nx.int32_elt) Nx.t
type render = string
(* Physics constants matching Gymnasium MountainCar-v0 *)
let min_position = -1.2
let max_position = 0.6
let max_speed = 0.07
let goal_position = 0.5
let goal_velocity = 0.0
let force = 0.001
let gravity = 0.0025
let max_steps = 200
let observation_space =
Space.Box.create
~low:[| min_position; -.max_speed |]
~high:[| max_position; max_speed |]
let action_space = Space.Discrete.create 3
let make_obs position velocity =
Nx.create Nx.float32 [| 2 |] [| position; velocity |]
let make ?render_mode () =
let position = ref 0.0 in
let velocity = ref 0.0 in
let steps = ref 0 in
let reset _env ?options:_ () =
let r = Nx.rand Nx.float32 [| 1 |] in
let v = (Nx.to_array r).(0) in
position := -0.6 +. (v *. 0.2);
velocity := 0.0;
steps := 0;
(make_obs !position !velocity, Info.empty)
in
let step _env action =
let force_direction = float_of_int (Space.Discrete.to_int action - 1) in
let vel =
!velocity +. (force_direction *. force)
-. (gravity *. cos (3.0 *. !position))
in
let vel = Float.max (-.max_speed) (Float.min vel max_speed) in
let pos = !position +. vel in
let pos = Float.max min_position (Float.min pos max_position) in
let vel = if pos = min_position && vel < 0.0 then 0.0 else vel in
position := pos;
velocity := vel;
incr steps;
let terminated = pos >= goal_position && vel >= goal_velocity in
let truncated = (not terminated) && !steps >= max_steps in
let reward = -1.0 in
let info = Info.set "steps" (Info.int !steps) Info.empty in
Env.step_result ~observation:(make_obs pos vel) ~reward ~terminated
~truncated ~info ()
in
let render () =
let normalized_pos =
(!position -. min_position) /. (max_position -. min_position)
in
let car_pos = int_of_float (normalized_pos *. 40.0) in
let goal_pos =
int_of_float
((goal_position -. min_position)
/. (max_position -. min_position)
*. 40.0)
in
let track = Bytes.make 41 '-' in
Bytes.set track goal_pos 'G';
Bytes.set track (max 0 (min 40 car_pos)) 'C';
Some
(Printf.sprintf "MountainCar: [%s] pos=%.3f, vel=%.3f, steps=%d"
(Bytes.to_string track) !position !velocity !steps)
in
Env.create ?render_mode ~render_modes:[ "ansi" ] ~id:"MountainCar-v0"
~observation_space ~action_space ~reset ~step ~render ()
================================================
FILE: packages/fehu/lib/envs/random_walk.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
open Fehu
type obs = (float, Nx.float32_elt) Nx.t
type act = (int32, Nx.int32_elt) Nx.t
type render = string
let step_size = 1.0
let max_position = 10.0
let max_steps = 200
let observation_space =
Space.Box.create ~low:[| -.max_position |] ~high:[| max_position |]
let action_space = Space.Discrete.create 2
let make_obs position = Nx.create Nx.float32 [| 1 |] [| position |]
let render_ansi position =
let offset = int_of_float (position +. max_position) in
let offset = max 0 (min 20 offset) in
let buffer = Bytes.make 21 '.' in
Bytes.set buffer offset 'o';
Printf.sprintf "Position: %+.2f\n|%s|" position (Bytes.to_string buffer)
let make ?render_mode () =
let position = ref 0.0 in
let steps = ref 0 in
let reset _env ?options:_ () =
position := 0.0;
steps := 0;
(make_obs 0.0, Info.empty)
in
let step _env action =
let direction =
if Space.Discrete.to_int action = 0 then -.step_size else step_size
in
let updated = !position +. direction in
let clamped = Float.min max_position (Float.max (-.max_position) updated) in
position := clamped;
incr steps;
let terminated = Float.abs clamped >= max_position in
let truncated = (not terminated) && !steps >= max_steps in
let reward = -.Float.abs clamped in
let info = Info.set "steps" (Info.int !steps) Info.empty in
Env.step_result ~observation:(make_obs clamped) ~reward ~terminated
~truncated ~info ()
in
let render () = Some (render_ansi !position) in
Env.create ?render_mode ~render_modes:[ "ansi" ] ~id:"RandomWalk-v0"
~observation_space ~action_space ~reset ~step ~render ()
================================================
FILE: packages/fehu/lib/eval.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
type stats = {
mean_reward : float;
std_reward : float;
mean_length : float;
n_episodes : int;
}
let run env ~policy ?(n_episodes = 10) ?(max_steps = 1000) () =
let ep_rewards = Array.make n_episodes 0.0 in
let ep_lengths = Array.make n_episodes 0.0 in
for ep = 0 to n_episodes - 1 do
let obs, _info = Env.reset env () in
let current_obs = ref obs in
let total_reward = ref 0.0 in
let steps = ref 0 in
let done_flag = ref false in
while !steps < max_steps && not !done_flag do
let action = policy !current_obs in
let s = Env.step env action in
total_reward := !total_reward +. s.reward;
steps := !steps + 1;
current_obs := s.observation;
done_flag := s.terminated || s.truncated
done;
ep_rewards.(ep) <- !total_reward;
ep_lengths.(ep) <- Float.of_int !steps
done;
let n = Float.of_int n_episodes in
let mean_reward = ref 0.0 in
let mean_length = ref 0.0 in
for i = 0 to n_episodes - 1 do
mean_reward := !mean_reward +. ep_rewards.(i);
mean_length := !mean_length +. ep_lengths.(i)
done;
mean_reward := !mean_reward /. n;
mean_length := !mean_length /. n;
let var_sum = ref 0.0 in
for i = 0 to n_episodes - 1 do
let d = ep_rewards.(i) -. !mean_reward in
var_sum := !var_sum +. (d *. d)
done;
let std_reward = sqrt (!var_sum /. n) in
{
mean_reward = !mean_reward;
std_reward;
mean_length = !mean_length;
n_episodes;
}
================================================
FILE: packages/fehu/lib/eval.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Policy evaluation.
Runs a deterministic or stochastic policy over multiple episodes and reports
summary statistics. *)
(** {1:types Types} *)
type stats = {
mean_reward : float; (** Mean total reward across episodes. *)
std_reward : float; (** Standard deviation of total rewards. *)
mean_length : float; (** Mean episode length in steps. *)
n_episodes : int; (** Number of episodes evaluated. *)
}
(** The type for evaluation statistics. *)
(** {1:running Running} *)
val run :
('obs, 'act, 'render) Env.t ->
policy:('obs -> 'act) ->
?n_episodes:int ->
?max_steps:int ->
unit ->
stats
(** [run env ~policy ()] evaluates [policy] over [n_episodes] (default [10])
episodes of at most [max_steps] (default [1000]) steps each. The environment
is reset between episodes. *)
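For illustration, a minimal evaluation run might look like the sketch below. It uses an alternating left/right policy so no sampling API is needed; [Space.Discrete.of_int] is an assumed constructor mirroring [to_int]:

```ocaml
(* Sketch: evaluate a trivial alternating policy on the random walk.
   [Space.Discrete.of_int] is an assumed constructor. *)
let () =
  let open Fehu in
  let env = Fehu_envs.Random_walk.make () in
  let flip = ref false in
  let policy _obs =
    flip := not !flip;
    Space.Discrete.of_int (if !flip then 1 else 0)
  in
  let stats = Eval.run env ~policy ~n_episodes:5 () in
  Printf.printf "mean reward %.2f +/- %.2f over %d episodes\n"
    stats.Eval.mean_reward stats.Eval.std_reward stats.Eval.n_episodes
```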
================================================
FILE: packages/fehu/lib/fehu.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
module Value = Value
module Info = Info
module Space = Space
module Env = Env
module Vec_env = Vec_env
module Collect = Collect
module Buffer = Buffer
module Gae = Gae
module Eval = Eval
module Render = Render
================================================
FILE: packages/fehu/lib/fehu.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** {0 Fehu} Reinforcement learning environments and utilities.
{1 Core}
{!modules: Value Info Space Env}
{1 Collection and training}
{!modules: Collect Buffer Gae Eval}
{1 Composition}
{!modules: Vec_env Render} *)
module Value = Value
module Info = Info
module Space = Space
module Env = Env
module Vec_env = Vec_env
module Collect = Collect
module Buffer = Buffer
module Gae = Gae
module Eval = Eval
module Render = Render
================================================
FILE: packages/fehu/lib/gae.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let err_lengths = "Gae: all arrays must have the same length"
let err_returns_lengths =
"Gae.returns: rewards, terminated, and truncated must have the same length"
let err_cfv_lengths =
"Gae.compute_from_values: all arrays must have the same length"
let compute ~rewards ~values ~terminated ~truncated ~next_values ~gamma ~lambda
=
let n = Array.length rewards in
if
n <> Array.length values
|| n <> Array.length terminated
|| n <> Array.length truncated
|| n <> Array.length next_values
then invalid_arg err_lengths;
let advantages = Array.make n 0.0 in
let returns = Array.make n 0.0 in
let last_gae = ref 0.0 in
for t = n - 1 downto 0 do
let next_val, continuation =
if terminated.(t) then (0.0, 0.0)
else if truncated.(t) then (next_values.(t), 0.0)
else begin
let v = if t = n - 1 then next_values.(t) else values.(t + 1) in
(v, 1.0)
end
in
let delta = rewards.(t) +. (gamma *. next_val) -. values.(t) in
last_gae := delta +. (gamma *. lambda *. continuation *. !last_gae);
advantages.(t) <- !last_gae;
returns.(t) <- !last_gae +. values.(t)
done;
(advantages, returns)
let compute_from_values ~rewards ~values ~terminated ~truncated ~last_value
~gamma ~lambda =
let n = Array.length rewards in
if
n <> Array.length values
|| n <> Array.length terminated
|| n <> Array.length truncated
then invalid_arg err_cfv_lengths;
let next_values =
Array.init n (fun t -> if t = n - 1 then last_value else values.(t + 1))
in
compute ~rewards ~values ~terminated ~truncated ~next_values ~gamma ~lambda
let returns ~rewards ~terminated ~truncated ~gamma =
let n = Array.length rewards in
if n <> Array.length terminated || n <> Array.length truncated then
invalid_arg err_returns_lengths;
let ret = Array.make n 0.0 in
let acc = ref 0.0 in
for t = n - 1 downto 0 do
let cont = if terminated.(t) || truncated.(t) then 0.0 else 1.0 in
acc := rewards.(t) +. (gamma *. cont *. !acc);
ret.(t) <- !acc
done;
ret
let normalize ?(eps = 1e-8) arr =
let n = Array.length arr in
if n = 0 then arr
else begin
let mean = ref 0.0 in
let m2 = ref 0.0 in
for i = 0 to n - 1 do
let k = Float.of_int (i + 1) in
let x = arr.(i) in
let delta = x -. !mean in
mean := !mean +. (delta /. k);
let delta2 = x -. !mean in
m2 := !m2 +. (delta *. delta2)
done;
let std = sqrt (!m2 /. Float.of_int n) +. eps in
let mu = !mean in
Array.init n (fun i -> (arr.(i) -. mu) /. std)
end
================================================
FILE: packages/fehu/lib/gae.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Generalized Advantage Estimation.
Correctly handles the distinction between terminated and truncated episodes.
On termination, the bootstrap value is zero. On truncation, the bootstrap
value comes from [next_values]. *)
(** {1:gae GAE} *)
val compute :
rewards:float array ->
values:float array ->
terminated:bool array ->
truncated:bool array ->
next_values:float array ->
gamma:float ->
lambda:float ->
float array * float array
(** [compute ~rewards ~values ~terminated ~truncated ~next_values
~gamma ~lambda] is [(advantages, returns)].
[next_values.(t)] is the value estimate for the state at step [t+1]. When [terminated.(t)] is
[true], the bootstrap value is zero and the GAE trace resets.
When [truncated.(t)] is [true], the bootstrap value is
[next_values.(t)] and the trace resets for the new episode.
Otherwise, continuation uses the next step's value.
Raises [Invalid_argument] if array lengths differ. *)
val compute_from_values :
rewards:float array ->
values:float array ->
terminated:bool array ->
truncated:bool array ->
last_value:float ->
gamma:float ->
lambda:float ->
float array * float array
(** [compute_from_values ~rewards ~values ~terminated ~truncated ~last_value
~gamma ~lambda] is [(advantages, returns)].
Convenience wrapper around {!compute} that builds [next_values] from
[values] and [last_value]: [next_values.(t) = values.(t+1)] for [t < n-1],
and [next_values.(n-1) = last_value].
Raises [Invalid_argument] if array lengths differ. *)
(** {1:returns Monte Carlo returns} *)
val returns :
rewards:float array ->
terminated:bool array ->
truncated:bool array ->
gamma:float ->
float array
(** [returns ~rewards ~terminated ~truncated ~gamma] computes discounted
cumulative returns. The accumulation resets at terminal or truncated states.
*)
(** {1:normalize Normalization} *)
val normalize : ?eps:float -> float array -> float array
(** [normalize arr] is a copy of [arr] shifted to zero mean and scaled to unit
variance. [eps] (default [1e-8]) is added to the standard deviation to
prevent division by zero, so the result's variance is only approximately one. *)
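To make the reset behaviour of [returns] concrete, here is a small worked call. With [gamma = 0.5] the arithmetic is exact in binary floating point, so the values in the comment can be checked by hand against the recurrence (the accumulator resets at the terminal step [t = 2]):

```ocaml
(* Worked example: three steps, episode terminates at t = 2.
     ret.(2) = 1.0
     ret.(1) = 1.0 +. 0.5 *. 1.0 = 1.5
     ret.(0) = 1.0 +. 0.5 *. 1.5 = 1.75 *)
let rets =
  Fehu.Gae.returns
    ~rewards:[| 1.0; 1.0; 1.0 |]
    ~terminated:[| false; false; true |]
    ~truncated:[| false; false; false |]
    ~gamma:0.5
(* rets = [| 1.75; 1.5; 1.0 |] *)
```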
================================================
FILE: packages/fehu/lib/info.ml
================================================
module String_map = Map.Make (String)
type t = Value.t String_map.t
let empty = String_map.empty
let is_empty = String_map.is_empty
let set key value info = String_map.add key value info
let find key info = String_map.find_opt key info
let find_exn key info =
match String_map.find_opt key info with
| Some v -> v
| None -> invalid_arg (Printf.sprintf "Info.find_exn: key %S not present" key)
let remove key info = String_map.remove key info
let merge a b = String_map.union (fun _key _left right -> Some right) a b
let to_list info = String_map.bindings info
let of_list kvs =
List.fold_left (fun acc (k, v) -> String_map.add k v acc) String_map.empty kvs
let to_value info = Value.Dict (String_map.bindings info)
let pp ppf t =
let bindings = String_map.bindings t in
Format.fprintf ppf "{";
List.iteri
(fun i (k, v) ->
if i > 0 then Format.fprintf ppf "; ";
Format.fprintf ppf "%s: %a" k Value.pp v)
bindings;
Format.fprintf ppf "}"
(* Convenience constructors *)
let null = Value.Null
let bool b = Value.Bool b
let int i = Value.Int i
let float f = Value.Float f
let string s = Value.String s
let int_array arr = Value.Int_array (Array.copy arr)
let float_array arr = Value.Float_array (Array.copy arr)
let bool_array arr = Value.Bool_array (Array.copy arr)
================================================
FILE: packages/fehu/lib/info.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Step metadata dictionaries.
Info dictionaries carry auxiliary data returned by {!Env.reset} and
{!Env.step}. Keys are strings and values are {!Value.t}. *)
(** {1:types Types} *)
type t
(** The type for info dictionaries. *)
(** {1:constructors Constructors} *)
val empty : t
(** [empty] is the empty dictionary. *)
val of_list : (string * Value.t) list -> t
(** [of_list kvs] is a dictionary from the given key-value pairs. *)
(** {1:predicates Predicates} *)
val is_empty : t -> bool
(** [is_empty t] is [true] iff [t] has no bindings. *)
(** {1:ops Operations} *)
val set : string -> Value.t -> t -> t
(** [set k v t] is [t] with [k] bound to [v]. *)
val find : string -> t -> Value.t option
(** [find k t] is the value bound to [k] in [t], if any. *)
val find_exn : string -> t -> Value.t
(** [find_exn k t] is the value bound to [k] in [t].
Raises [Invalid_argument] if [k] is not present. *)
val remove : string -> t -> t
(** [remove k t] is [t] without the binding for [k]. *)
val merge : t -> t -> t
(** [merge a b] is the union of [a] and [b]. When both have a binding for the
same key, the value from [b] wins. *)
(** {1:converting Converting} *)
val to_list : t -> (string * Value.t) list
(** [to_list t] is the bindings of [t] in key order. *)
val to_value : t -> Value.t
(** [to_value t] is [t] as a {!Value.Dict}. *)
(** {1:fmt Formatting} *)
val pp : Format.formatter -> t -> unit
(** [pp] formats an info dictionary for debugging. *)
(** {1:convenience Convenience value constructors} *)
val null : Value.t
(** [null] is {!Value.Null}. *)
val bool : bool -> Value.t
(** [bool b] is [Value.Bool b]. *)
val int : int -> Value.t
(** [int i] is [Value.Int i]. *)
val float : float -> Value.t
(** [float f] is [Value.Float f]. *)
val string : string -> Value.t
(** [string s] is [Value.String s]. *)
val int_array : int array -> Value.t
(** [int_array arr] is [Value.Int_array (Array.copy arr)]. *)
val float_array : float array -> Value.t
(** [float_array arr] is [Value.Float_array (Array.copy arr)]. *)
val bool_array : bool array -> Value.t
(** [bool_array arr] is [Value.Bool_array (Array.copy arr)]. *)
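A short usage sketch of the API above, building a dictionary with the convenience constructors and pattern-matching on a looked-up value:

```ocaml
(* Build an info dictionary and read a value back out. *)
let info =
  Fehu.Info.empty
  |> Fehu.Info.set "steps" (Fehu.Info.int 42)
  |> Fehu.Info.set "success" (Fehu.Info.bool true)

let () =
  match Fehu.Info.find "steps" info with
  | Some (Fehu.Value.Int n) -> Printf.printf "steps = %d\n" n
  | _ -> ()
```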
================================================
FILE: packages/fehu/lib/render.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
module Pixel = struct
type format = Rgb | Rgba | Gray
let channels = function Rgb -> 3 | Rgba -> 4 | Gray -> 1
end
type image = {
width : int;
height : int;
pixel_format : Pixel.format;
data : (int, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t;
}
let err_data_length ~expected ~got =
Printf.sprintf
"Render.image: data length %d does not match width * height * channels = %d"
got expected
let image ~width ~height ?(pixel_format = Pixel.Rgb) data =
let expected = width * height * Pixel.channels pixel_format in
let got = Bigarray.Array1.dim data in
if got <> expected then invalid_arg (err_data_length ~expected ~got);
{ width; height; pixel_format; data }
let rollout env ~policy ~steps ~sink () =
let obs, _info = Env.reset env () in
let current_obs = ref obs in
for _ = 1 to steps do
let action = policy !current_obs in
let step = Env.step env action in
(match Env.render env with Some frame -> sink frame | None -> ());
current_obs := step.Env.observation;
if step.Env.terminated || step.Env.truncated then begin
let obs, _info = Env.reset env () in
current_obs := obs
end
done
let derive_id env suffix =
match Env.id env with None -> None | Some id -> Some (id ^ suffix)
let on_render ~sink env =
let maybe_record inner =
match Env.render inner with Some frame -> sink frame | None -> ()
in
Env.wrap
?id:(derive_id env "/OnRender")
~observation_space:(Env.observation_space env)
~action_space:(Env.action_space env)
~reset:(fun inner ?options () ->
let result = Env.reset inner ?options () in
maybe_record inner;
result)
~step:(fun inner action ->
let s = Env.step inner action in
maybe_record inner;
s)
~render:(fun inner -> Env.render inner)
env
================================================
FILE: packages/fehu/lib/render.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Visualization primitives.
{!image} is the standard frame type for RGB-rendered environments.
{!rollout} runs a policy and feeds rendered frames to a user-provided sink.
*)
(** {1:pixel Pixel formats} *)
module Pixel : sig
(** The type for pixel formats. *)
type format =
| Rgb (** 3 channels. *)
| Rgba (** 4 channels. *)
| Gray (** 1 channel. *)
val channels : format -> int
(** [channels fmt] is the number of channels for [fmt]. *)
end
(** {1:image Images} *)
type image = {
width : int; (** Width in pixels. *)
height : int; (** Height in pixels. *)
pixel_format : Pixel.format; (** Pixel layout. *)
data : (int, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t;
(** Raw pixel data of length [width * height * channels]. *)
}
(** The type for rendered frames. *)
val image :
width:int ->
height:int ->
?pixel_format:Pixel.format ->
(int, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t ->
image
(** [image ~width ~height data] constructs a frame. [pixel_format] defaults to
[Rgb].
Raises [Invalid_argument] if [Bigarray.Array1.dim data] does not equal
[width * height * channels]. *)
(** {1:rollout Rollout} *)
val rollout :
('obs, 'act, image) Env.t ->
policy:('obs -> 'act) ->
steps:int ->
sink:(image -> unit) ->
unit ->
unit
(** [rollout env ~policy ~steps ~sink ()] runs [policy] in [env] for up to
[steps] steps. Each rendered frame is passed to [sink]. The environment is
reset at the start and on episode boundaries. *)
(** {1:recording Recording} *)
val on_render :
sink:(image -> unit) -> ('obs, 'act, image) Env.t -> ('obs, 'act, image) Env.t
(** [on_render ~sink env] wraps [env] so that every rendered frame after
{!Env.reset} and {!Env.step} is passed to [sink]. The wrapper is
transparent: observations, actions, rewards, and termination signals pass
through unchanged. *)
================================================
FILE: packages/fehu/lib/space.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Error messages *)
let err_discrete_n = "Space.Discrete.create: n must be strictly positive"
let err_discrete_not = "Space.Discrete: not a discrete space"
let err_box_empty = "Space.Box.create: low cannot be empty"
let err_box_shape = "Space.Box.create: low and high must have identical lengths"
let err_box_not = "Space.Box: not a box space"
let err_mb_n = "Space.Multi_binary.create: n must be strictly positive"
let err_md_empty = "Space.Multi_discrete.create: nvec must not be empty"
let err_seq_min = "Space.Sequence.create: min_length must be non-negative"
let err_seq_max = "Space.Sequence.create: max_length must be >= min_length"
let err_text_max = "Space.Text.create: max_length must be positive"
let err_text_charset = "Space.Text.create: charset must not be empty"
let strf = Printf.sprintf
let errorf fmt = Format.kasprintf (fun msg -> Error msg) fmt
(* Spec *)
type spec =
| Discrete of { start : int; n : int }
| Box of { low : float array; high : float array }
| Multi_binary of { n : int }
| Multi_discrete of { nvec : int array }
| Tuple of spec list
| Dict of (string * spec) list
| Sequence of { min_length : int; max_length : int option; base : spec }
| Text of { charset : string; max_length : int }
let rec equal_spec a b =
match (a, b) with
| Discrete a, Discrete b -> a.start = b.start && a.n = b.n
| Box a, Box b -> a.low = b.low && a.high = b.high
| Multi_binary a, Multi_binary b -> a.n = b.n
| Multi_discrete a, Multi_discrete b -> a.nvec = b.nvec
| Tuple a, Tuple b ->
List.length a = List.length b && List.for_all2 equal_spec a b
| Dict a, Dict b ->
List.length a = List.length b
&& List.for_all2
(fun (ka, sa) (kb, sb) -> String.equal ka kb && equal_spec sa sb)
a b
| Sequence a, Sequence b ->
a.min_length = b.min_length
&& a.max_length = b.max_length
&& equal_spec a.base b.base
| Text a, Text b ->
String.equal a.charset b.charset && a.max_length = b.max_length
| ( ( Discrete _ | Box _ | Multi_binary _ | Multi_discrete _ | Tuple _
| Dict _ | Sequence _ | Text _ ),
_ ) ->
false
(* Space type *)
type 'a t = {
spec : spec;
shape : int array option;
contains : 'a -> bool;
sample : unit -> 'a;
pack : 'a -> Value.t;
unpack : Value.t -> ('a, string) result;
boundaries : Value.t list;
box_bounds : (float array * float array) option;
discrete_info : (int * int) option;
}
type packed = Pack : 'a t -> packed
let spec s = s.spec
let shape s = s.shape
let contains s v = s.contains v
let sample s = s.sample ()
let pack s v = s.pack v
let unpack s v = s.unpack v
let boundary_values s = s.boundaries
(* Discrete *)
module Discrete = struct
type element = (int32, Nx.int32_elt) Nx.t
let to_int tensor =
let reshaped = Nx.reshape [| 1 |] tensor in
let arr : Int32.t array = Nx.to_array reshaped in
Int32.to_int arr.(0)
let of_int v = Nx.scalar Nx.int32 (Int32.of_int v)
let create ?(start = 0) n =
if n <= 0 then invalid_arg err_discrete_n;
let hi = start + n in
let contains tensor =
let reshaped = Nx.reshape [| 1 |] tensor in
let arr : Int32.t array = Nx.to_array reshaped in
Array.length arr = 1
&&
let v = Int32.to_int arr.(0) in
v >= start && v < hi
in
let sample () =
let tensor = Nx.randint Nx.int32 ~high:hi [| 1 |] start in
let arr : Int32.t array = Nx.to_array tensor in
Nx.scalar Nx.int32 arr.(0)
in
let pack tensor =
let arr : Int32.t array = Nx.to_array (Nx.reshape [| 1 |] tensor) in
Value.Int (Int32.to_int arr.(0))
in
let unpack = function
| Value.Int v when v >= start && v < hi ->
Ok (Nx.scalar Nx.int32 (Int32.of_int v))
| Value.Int v -> errorf "Discrete value %d outside [%d, %d)" v start hi
| other -> errorf "Discrete expects Int, got %s" (Value.to_string other)
in
let boundaries =
if n = 1 then [ Value.Int start ]
else [ Value.Int start; Value.Int (hi - 1) ]
in
{
spec = Discrete { start; n };
shape = None;
contains;
sample;
pack;
unpack;
boundaries;
box_bounds = None;
discrete_info = Some (start, n);
}
let n s =
match s.discrete_info with
| Some (_, n) -> n
| None -> invalid_arg err_discrete_not
let start s =
match s.discrete_info with
| Some (start, _) -> start
| None -> invalid_arg err_discrete_not
end
(* Box *)
module Box = struct
type element = (float, Nx.float32_elt) Nx.t
let create ~low ~high =
let arity = Array.length low in
if arity = 0 then invalid_arg err_box_empty;
if arity <> Array.length high then invalid_arg err_box_shape;
Array.iteri
(fun i lo ->
if lo > high.(i) then
invalid_arg
(strf "Space.Box.create: low[%d]=%g > high[%d]=%g" i lo i high.(i)))
low;
let low = Array.copy low in
let high = Array.copy high in
let contains tensor =
let sh = Nx.shape tensor in
Array.length sh = 1
&& sh.(0) = arity
&&
let values = Nx.to_array tensor in
let rec loop i =
if i = arity then true
else
let v = values.(i) in
v >= low.(i) && v <= high.(i) && loop (i + 1)
in
loop 0
in
let sample () =
let uniform = Nx.rand Nx.float32 [| arity |] in
let draws = Nx.to_array uniform in
let values =
Array.init arity (fun i ->
let lo = low.(i) in
let hi = high.(i) in
if Float.equal lo hi then lo
else
let range = hi -. lo in
if Float.is_finite range then lo +. (draws.(i) *. range)
else
let v = -1e6 +. (draws.(i) *. 2e6) in
Float.max lo (Float.min hi v))
in
Nx.create Nx.float32 [| arity |] values
in
let pack tensor = Value.Float_array (Array.copy (Nx.to_array tensor)) in
let unpack = function
| Value.Float_array arr when Array.length arr = arity ->
let tensor = Nx.create Nx.float32 [| arity |] arr in
if contains tensor then Ok tensor
else
errorf "Box value outside bounds: %s"
(Value.to_string (Value.Float_array arr))
| Value.Float_array arr ->
errorf "Box expects vector of size %d, got size %d" arity
(Array.length arr)
| other ->
errorf "Box expects Float_array, got %s" (Value.to_string other)
in
let identical =
let same = ref true in
let i = ref 0 in
while !same && !i < arity do
if not (Float.equal low.(!i) high.(!i)) then same := false;
incr i
done;
!same
in
let boundaries =
let lo_v = Value.Float_array (Array.copy low) in
let hi_v = Value.Float_array (Array.copy high) in
if identical then [ lo_v ] else [ lo_v; hi_v ]
in
let box_bounds = Some (Array.copy low, Array.copy high) in
{
spec = Box { low = Array.copy low; high = Array.copy high };
shape = Some [| arity |];
contains;
sample;
pack;
unpack;
boundaries;
box_bounds;
discrete_info = None;
}
let bounds s =
match s.box_bounds with
| Some (low, high) -> (Array.copy low, Array.copy high)
| None -> invalid_arg err_box_not
end
(* Multi_binary *)
module Multi_binary = struct
type element = (int32, Nx.int32_elt) Nx.t
let create n =
if n <= 0 then invalid_arg err_mb_n;
let contains tensor =
let sh = Nx.shape tensor in
Array.length sh = 1
&& sh.(0) = n
&&
let arr : Int32.t array = Nx.to_array tensor in
Array.for_all (fun v -> v = Int32.zero || v = Int32.one) arr
in
let sample () = Nx.randint Nx.int32 ~high:2 [| n |] 0 in
let pack tensor =
let arr : Int32.t array = Nx.to_array tensor in
Value.Bool_array
(Array.init n (fun i -> not (Int32.equal arr.(i) Int32.zero)))
in
let unpack = function
| Value.Bool_array arr when Array.length arr = n ->
let data =
Array.map (fun b -> if b then Int32.one else Int32.zero) arr
in
Ok (Nx.create Nx.int32 [| n |] data)
| Value.Bool_array arr ->
errorf "Multi_binary expects vector of size %d, got size %d" n
(Array.length arr)
| other ->
errorf "Multi_binary expects Bool_array, got %s"
(Value.to_string other)
in
let boundaries =
[
Value.Bool_array (Array.make n false);
Value.Bool_array (Array.make n true);
]
in
{
spec = Multi_binary { n };
shape = Some [| n |];
contains;
sample;
pack;
unpack;
boundaries;
box_bounds = None;
discrete_info = None;
}
end
(* Multi_discrete *)
module Multi_discrete = struct
type element = (int32, Nx.int32_elt) Nx.t
let create nvec =
let arity = Array.length nvec in
if arity = 0 then invalid_arg err_md_empty;
let nvec = Array.copy nvec in
Array.iteri
(fun i bound ->
if bound <= 0 then
invalid_arg
(strf "Space.Multi_discrete.create: nvec[%d] must be > 0" i))
nvec;
let contains tensor =
let sh = Nx.shape tensor in
Array.length sh = 1
&& sh.(0) = arity
&&
let arr : Int32.t array = Nx.to_array tensor in
let rec loop i =
if i = arity then true
else
let v = Int32.to_int arr.(i) in
v >= 0 && v < nvec.(i) && loop (i + 1)
in
loop 0
in
let sample () =
let data =
Array.init arity (fun i ->
let tensor = Nx.randint Nx.int32 ~high:nvec.(i) [| 1 |] 0 in
let arr = Nx.to_array tensor in
arr.(0))
in
Nx.create Nx.int32 [| arity |] data
in
let pack tensor =
let arr : Int32.t array = Nx.to_array tensor in
Value.Int_array (Array.map Int32.to_int arr)
in
let unpack = function
| Value.Int_array arr when Array.length arr = arity ->
let data = Array.map Int32.of_int arr in
let tensor = Nx.create Nx.int32 [| arity |] data in
if contains tensor then Ok tensor
else
errorf "Multi_discrete value outside bounds: %s"
(Value.to_string (Value.Int_array arr))
| Value.Int_array arr ->
errorf "Multi_discrete expects vector of size %d, got size %d" arity
(Array.length arr)
| other ->
errorf "Multi_discrete expects Int_array, got %s"
(Value.to_string other)
in
let boundaries =
[
Value.Int_array (Array.make arity 0);
Value.Int_array (Array.init arity (fun i -> nvec.(i) - 1));
]
in
{
spec = Multi_discrete { nvec = Array.copy nvec };
shape = Some [| arity |];
contains;
sample;
pack;
unpack;
boundaries;
box_bounds = None;
discrete_info = None;
}
end
(* Tuple *)
module Tuple = struct
type element = Value.t list
let create spaces =
let spaces = Array.of_list spaces in
let len = Array.length spaces in
let contains values =
let rec loop i = function
| [] -> i = len
| v :: rest -> (
if i >= len then false
else
let (Pack s) = spaces.(i) in
match s.unpack v with
| Ok _ -> loop (i + 1) rest
| Error _ -> false)
in
loop 0 values
in
let sample () =
let values =
Array.to_list
(Array.init len (fun i ->
let (Pack s) = spaces.(i) in
let v = s.sample () in
s.pack v))
in
values
in
let pack values = Value.List values in
let unpack = function
| Value.List values ->
if List.length values <> len then
errorf "Tuple expects %d elements, got %d" len (List.length values)
else
let rec loop i = function
| [] -> Ok values
| v :: rest -> (
let (Pack s) = spaces.(i) in
match s.unpack v with
| Ok _ -> loop (i + 1) rest
| Error msg -> errorf "Tuple element %d: %s" i msg)
in
loop 0 values
| other -> errorf "Tuple expects List, got %s" (Value.to_string other)
in
let sub_specs = Array.to_list (Array.map (fun (Pack s) -> s.spec) spaces) in
{
spec = Tuple sub_specs;
shape = None;
contains;
sample;
pack;
unpack;
boundaries = [];
box_bounds = None;
discrete_info = None;
}
end
(* Dict *)
module Dict = struct
type element = (string * Value.t) list
module String_map = Map.Make (String)
let create entries =
let map =
List.fold_left
(fun acc (key, space) ->
if String_map.mem key acc then
invalid_arg (strf "Space.Dict.create: duplicate key '%s'" key);
String_map.add key space acc)
String_map.empty entries
in
let contains values =
let rec loop remaining m =
match remaining with
| [] -> String_map.is_empty m
| (key, value) :: rest -> (
match String_map.find_opt key m with
| None -> false
| Some (Pack s) -> (
match s.unpack value with
| Ok _ -> loop rest (String_map.remove key m)
| Error _ -> false))
in
loop values map
in
let sample () =
if String_map.is_empty map then []
else
let acc =
String_map.fold
(fun key (Pack s) acc ->
let v = s.sample () in
(key, s.pack v) :: acc)
map []
in
List.rev acc
in
let pack values = Value.Dict values in
let unpack = function
| Value.Dict values ->
if contains values then Ok values
else errorf "Dict contains unexpected keys or values"
| other -> errorf "Dict expects Dict, got %s" (Value.to_string other)
in
let sub_specs =
List.rev
(String_map.fold (fun key (Pack s) acc -> (key, s.spec) :: acc) map [])
in
{
spec = Dict sub_specs;
shape = None;
contains;
sample;
pack;
unpack;
boundaries = [];
box_bounds = None;
discrete_info = None;
}
end
(* Sequence *)
module Sequence = struct
type 'a element = 'a list
let create ?(min_length = 0) ?max_length base =
if min_length < 0 then invalid_arg err_seq_min;
let max_length =
match max_length with
| None -> None
| Some m when m < min_length -> invalid_arg err_seq_max
| Some _ as m -> m
in
let contains values =
let len = List.length values in
len >= min_length
&& (match max_length with None -> true | Some m -> len <= m)
&& List.for_all (fun v -> base.contains v) values
in
let sample () =
let length =
match max_length with
| None -> min_length
| Some max_len ->
if max_len = min_length then min_length
else
let tensor =
Nx.randint Nx.int32 ~high:(max_len + 1) [| 1 |] min_length
in
let arr = Nx.to_array tensor in
Int32.to_int arr.(0)
in
if length = 0 then []
else
let rec build i acc =
if i = length then List.rev acc
else
let v = base.sample () in
build (i + 1) (v :: acc)
in
build 0 []
in
let pack values = Value.List (List.map (fun v -> base.pack v) values) in
let unpack = function
| Value.List values ->
let len = List.length values in
let exceeds =
match max_length with None -> false | Some m -> len > m
in
if len < min_length || exceeds then
match max_length with
| None ->
errorf "Sequence length %d shorter than minimum %d" len
min_length
| Some m ->
errorf "Sequence length %d outside [%d, %d]" len min_length m
else
let rec loop acc = function
| [] -> Ok (List.rev acc)
| v :: rest -> (
match base.unpack v with
| Ok x -> loop (x :: acc) rest
| Error _ as err -> err)
in
loop [] values
| other -> errorf "Sequence expects List, got %s" (Value.to_string other)
in
{
spec = Sequence { min_length; max_length; base = base.spec };
shape = None;
contains;
sample;
pack;
unpack;
boundaries = [];
box_bounds = None;
discrete_info = None;
}
end
(* Text *)
module Text = struct
type element = string
let default_charset =
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 "
let create ?(charset = default_charset) ?(max_length = 64) () =
if max_length <= 0 then invalid_arg err_text_max;
let charset_len = String.length charset in
if charset_len = 0 then invalid_arg err_text_charset;
let contains value =
let len = String.length value in
len <= max_length
&&
let rec loop i =
if i = len then true
else String.contains charset value.[i] && loop (i + 1)
in
loop 0
in
let sample () =
let length =
if max_length = 1 then 1
else
let tensor = Nx.randint Nx.int32 ~high:(max_length + 1) [| 1 |] 1 in
let arr = Nx.to_array tensor in
Int32.to_int arr.(0)
in
if length = 0 then ""
else
let idxs = Nx.randint Nx.int32 ~high:charset_len [| length |] 0 in
let arr = Nx.to_array idxs in
Bytes.init length (fun i -> charset.[Int32.to_int arr.(i)])
|> Bytes.to_string
in
let pack value = Value.String value in
let unpack = function
| Value.String s when contains s -> Ok s
| Value.String s -> errorf "Text value '%s' violates constraints" s
| other -> errorf "Text expects String, got %s" (Value.to_string other)
in
let example = if charset_len = 0 then "" else String.make 1 charset.[0] in
let boundaries = [ Value.String ""; Value.String example ] in
{
spec = Text { charset; max_length };
shape = None;
contains;
sample;
pack;
unpack;
boundaries;
box_bounds = None;
discrete_info = None;
}
end
================================================
FILE: packages/fehu/lib/space.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Observation and action spaces.
Spaces define valid observations and actions for reinforcement learning
environments. They specify shapes, constraints, and provide methods to
validate, sample, and serialize values.
Each space type corresponds to a common RL scenario: discrete choices,
continuous vectors, binary indicators, composite structures, and
variable-length sequences. *)
(** {1:spec Structural description} *)
(** Structural description of a space. Two spaces are compatible when their
specs are equal. *)
type spec =
| Discrete of { start : int; n : int }
(** Integer choices in \[[start]; [start + n - 1]\]. *)
| Box of { low : float array; high : float array }
(** Continuous vector bounded per dimension. *)
| Multi_binary of { n : int } (** Binary vector of length [n]. *)
| Multi_discrete of { nvec : int array }
(** Multiple discrete axes with per-axis cardinalities. *)
| Tuple of spec list (** Fixed-length heterogeneous sequence. *)
| Dict of (string * spec) list (** Named fields with different types. *)
| Sequence of { min_length : int; max_length : int option; base : spec }
(** Variable-length homogeneous sequence. *)
| Text of { charset : string; max_length : int }
(** Character strings from a fixed alphabet. *)
val equal_spec : spec -> spec -> bool
(** [equal_spec a b] is [true] iff [a] and [b] describe structurally identical
spaces. *)
(** {1:spaces Spaces} *)
type 'a t
(** The type for spaces over values of type ['a]. A space is self-contained: all
bounds, constraints, and serialization logic are stored in the value itself.
*)
type packed =
| Pack : 'a t -> packed
(** Type-erased space for heterogeneous collections. *)
(** {1:ops Operations} *)
val spec : 'a t -> spec
(** [spec s] is the structural description of [s]. *)
val shape : 'a t -> int array option
(** [shape s] is the dimensionality of [s], if defined. [None] for scalar or
variable-length spaces. *)
val contains : 'a t -> 'a -> bool
(** [contains s v] is [true] iff [v] is valid in [s]. *)
val sample : 'a t -> 'a
(** [sample s] is a uniformly sampled value from [s].
Random draws come from the implicit RNG scope. *)
val pack : 'a t -> 'a -> Value.t
(** [pack s v] is [v] converted to the universal {!Value.t} representation. *)
val unpack : 'a t -> Value.t -> ('a, string) result
(** [unpack s v] is [Ok x] if [v] can be converted to a valid element of [s], or
[Error msg] otherwise. *)
val boundary_values : 'a t -> Value.t list
(** [boundary_values s] is a list of representative edge-case values for [s].
It includes lower/upper bounds or canonical sentinels when known, and is the
empty list when no boundary values apply. *)
(** {1:space_types Space types} *)
module Discrete : sig
type element = (int32, Nx.int32_elt) Nx.t
(** Discrete action represented as a scalar int32 tensor. *)
val create : ?start:int -> int -> element t
(** [create ?start n] is a discrete space with [n] choices in the range
\[[start]; [start + n - 1]\]. [start] defaults to [0].
Raises [Invalid_argument] if [n <= 0]. *)
val n : element t -> int
(** [n s] is the number of choices in [s].
Raises [Invalid_argument] if [s] is not a discrete space. *)
val start : element t -> int
(** [start s] is the starting value of [s].
Raises [Invalid_argument] if [s] is not a discrete space. *)
val to_int : element -> int
(** [to_int e] is the integer value of the discrete element [e]. *)
val of_int : int -> element
(** [of_int v] is a discrete element with value [v]. *)
end
module Box : sig
type element = (float, Nx.float32_elt) Nx.t
(** Continuous vector represented as a float32 tensor. *)
val create : low:float array -> high:float array -> element t
(** [create ~low ~high] is a continuous space where element [i] satisfies
[low.(i) <= x.(i) <= high.(i)]. Both arrays must have the same positive
length.
When the range of a dimension is not finite (e.g. bounds set to
[Float.max_float]), sampling falls back to a uniform draw in \[[-1e6];
[1e6]\] clamped to bounds.
Raises [Invalid_argument] if [low] is empty, if [low] and [high] differ in
length, or if any [low.(i) > high.(i)]. *)
val bounds : element t -> float array * float array
(** [bounds s] is [(low, high)] copies of the bound vectors.
Raises [Invalid_argument] if [s] is not a box space. *)
end
module Multi_binary : sig
type element = (int32, Nx.int32_elt) Nx.t
(** Binary vector for multi-label scenarios. *)
val create : int -> element t
(** [create n] is a binary vector space of length [n]. Valid values are int32
tensors with [n] elements, each 0 or 1.
Raises [Invalid_argument] if [n <= 0]. *)
end
module Multi_discrete : sig
type element = (int32, Nx.int32_elt) Nx.t
(** Multiple discrete choices with independent cardinalities. *)
val create : int array -> element t
(** [create nvec] is a multi-discrete space where element [i] is in \[[0];
[nvec.(i) - 1]\].
Raises [Invalid_argument] if [nvec] is empty or any [nvec.(i) <= 0]. *)
end
module Tuple : sig
type element = Value.t list
(** Fixed-length heterogeneous sequence in {!Value.t} form. *)
val create : packed list -> element t
(** [create spaces] is a tuple space. Valid values are lists where element [i]
belongs to [spaces.(i)]. {!unpack} validates each element against its
subspace. *)
end
module Dict : sig
type element = (string * Value.t) list
(** Named fields with different space types. *)
val create : (string * packed) list -> element t
(** [create fields] is a dictionary space with named fields. Valid values are
association lists matching the keys and subspaces of [fields].
Raises [Invalid_argument] if [fields] contains duplicate keys. *)
end
module Sequence : sig
type 'a element = 'a list
(** Variable-length homogeneous sequence. *)
val create : ?min_length:int -> ?max_length:int -> 'a t -> 'a element t
(** [create ?min_length ?max_length s] is a sequence space over [s].
[min_length] defaults to [0]. When [max_length] is provided, sampling
draws a uniform length in \[[min_length]; [max_length]\]; otherwise the
sampler returns sequences of length [min_length].
Raises [Invalid_argument] if [min_length < 0] or
[max_length < min_length]. *)
end
module Text : sig
type element = string
(** String space for textual observations or actions. *)
val create : ?charset:string -> ?max_length:int -> unit -> element t
(** [create ?charset ?max_length ()] is a text space. [charset] defaults to
alphanumeric plus space. [max_length] defaults to [64]. Valid strings
contain only characters from [charset] and have length at most
[max_length].
Raises [Invalid_argument] if [max_length <= 0] or [charset] is empty. *)
end
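A sketch of typical space usage, based only on the signatures above and assuming [open Fehu]:

```ocaml
(* A four-way discrete action space and a 2-D continuous observation box. *)
let action_space = Space.Discrete.create 4
let obs_space = Space.Box.create ~low:[| -1.; -1. |] ~high:[| 1.; 1. |]

let () =
  (* Sampled values always satisfy [contains]. *)
  let a = Space.sample action_space in
  assert (Space.contains action_space a);
  (* Round-trip an observation through the universal Value.t form. *)
  let v = Space.pack obs_space (Space.sample obs_space) in
  match Space.unpack obs_space v with
  | Ok _ -> ()
  | Error msg -> failwith msg
```

The pack/unpack pair is what lets type-erased ([packed]) spaces and info dictionaries carry heterogeneous values while still validating them against their originating space.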
================================================
FILE: packages/fehu/lib/value.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
type t =
| Null
| Bool of bool
| Int of int
| Float of float
| String of string
| Int_array of int array
| Float_array of float array
| Bool_array of bool array
| List of t list
| Dict of (string * t) list
(* Equality *)
let rec equal a b =
match (a, b) with
| Null, Null -> true
| Bool a, Bool b -> Bool.equal a b
| Int a, Int b -> Int.equal a b
| Float a, Float b -> Float.equal a b
| String a, String b -> String.equal a b
| Int_array a, Int_array b -> a = b
| Float_array a, Float_array b -> a = b
| Bool_array a, Bool_array b -> a = b
| List a, List b -> equal_list a b
| Dict a, Dict b -> equal_dict a b
| ( ( Null | Bool _ | Int _ | Float _ | String _ | Int_array _ | Float_array _
| Bool_array _ | List _ | Dict _ ),
_ ) ->
false
and equal_list a b =
match (a, b) with
| [], [] -> true
| x :: xs, y :: ys -> equal x y && equal_list xs ys
| _ -> false
and equal_dict a b =
match (a, b) with
| [], [] -> true
| (ka, va) :: rest_a, (kb, vb) :: rest_b ->
String.equal ka kb && equal va vb && equal_dict rest_a rest_b
| _ -> false
(* Formatting *)
let pp_array pp_elt ppf a =
Format.fprintf ppf "[|";
for i = 0 to Array.length a - 1 do
if i > 0 then Format.fprintf ppf "; ";
pp_elt ppf a.(i)
done;
Format.fprintf ppf "|]"
let rec pp ppf = function
| Null -> Format.fprintf ppf "null"
| Bool b -> Format.fprintf ppf "%b" b
| Int i -> Format.fprintf ppf "%d" i
| Float f -> Format.fprintf ppf "%g" f
| String s -> Format.fprintf ppf "%S" s
| Int_array a -> pp_array (fun ppf v -> Format.fprintf ppf "%d" v) ppf a
| Float_array a -> pp_array (fun ppf v -> Format.fprintf ppf "%g" v) ppf a
| Bool_array a -> pp_array (fun ppf v -> Format.fprintf ppf "%b" v) ppf a
| List items ->
Format.fprintf ppf "[";
List.iteri
(fun i v ->
if i > 0 then Format.fprintf ppf "; ";
pp ppf v)
items;
Format.fprintf ppf "]"
| Dict fields ->
Format.fprintf ppf "{";
List.iteri
(fun i (k, v) ->
if i > 0 then Format.fprintf ppf "; ";
Format.fprintf ppf "%s: %a" k pp v)
fields;
Format.fprintf ppf "}"
let to_string v = Format.asprintf "%a" pp v
================================================
FILE: packages/fehu/lib/value.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Universal value type.
Values represent heterogeneous data flowing through spaces and info
dictionaries. Each variant wraps one kind of scalar, array, or composite
datum. *)
(** {1:types Types} *)
(** The type for universal values. *)
type t =
| Null (** No value. *)
| Bool of bool (** A boolean. *)
| Int of int (** An integer. *)
| Float of float (** A float. *)
| String of string (** A string. *)
| Int_array of int array (** An integer array. *)
| Float_array of float array (** A float array. *)
| Bool_array of bool array (** A boolean array. *)
| List of t list (** A heterogeneous list. *)
| Dict of (string * t) list (** A string-keyed association list. *)
(** {1:predicates Predicates} *)
val equal : t -> t -> bool
(** [equal a b] is [true] iff [a] and [b] are structurally equal. *)
(** {1:fmt Formatting} *)
val pp : Format.formatter -> t -> unit
(** [pp] formats a value for debugging. *)
val to_string : t -> string
(** [to_string v] is [v] formatted as a string via {!pp}. *)
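A minimal illustration of the value type, assuming [open Fehu]:

```ocaml
(* Build a composite value and inspect it. *)
let v = Value.Dict [ ("reward", Value.Float 1.5); ("done", Value.Bool false) ]

let () =
  (* Structural equality, including nested lists and dicts. *)
  assert (Value.equal v v);
  (* [to_string] renders via [pp], e.g. {reward: 1.5; done: false}. *)
  print_endline (Value.to_string v)
```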
================================================
FILE: packages/fehu/lib/vec_env.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
let strf = Printf.sprintf
(* Error messages *)
let err_empty = "Vec_env.create: env list must not be empty"
let err_action_len n m = strf "Vec_env.step: expected %d actions, got %d" n m
let err_space kind =
strf "Vec_env.create: all environments must have the same %s space" kind
(* Types *)
type 'obs step = {
observations : 'obs array;
rewards : float array;
terminated : bool array;
truncated : bool array;
infos : Info.t array;
}
type ('obs, 'act, 'render) t = {
envs : ('obs, 'act, 'render) Env.t array;
observation_space : 'obs Space.t;
action_space : 'act Space.t;
}
(* Space compatibility *)
let ensure_compatible envs =
let first = envs.(0) in
let obs_spec = Space.spec (Env.observation_space first) in
let act_spec = Space.spec (Env.action_space first) in
for i = 1 to Array.length envs - 1 do
let env = envs.(i) in
if not (Space.equal_spec obs_spec (Space.spec (Env.observation_space env)))
then invalid_arg (err_space "observation");
if not (Space.equal_spec act_spec (Space.spec (Env.action_space env))) then
invalid_arg (err_space "action")
done
(* Constructor *)
let create envs =
match envs with
| [] -> invalid_arg err_empty
| first :: _ ->
let envs = Array.of_list envs in
ensure_compatible envs;
{
envs;
observation_space = Env.observation_space first;
action_space = Env.action_space first;
}
(* Accessors *)
let num_envs t = Array.length t.envs
let observation_space t = t.observation_space
let action_space t = t.action_space
(* Reset *)
let reset t () =
let n = Array.length t.envs in
let results = Array.init n (fun i -> Env.reset t.envs.(i) ()) in
let observations = Array.map fst results in
let infos = Array.map snd results in
(observations, infos)
(* Step *)
let step t actions =
let n = Array.length t.envs in
if Array.length actions <> n then
invalid_arg (err_action_len n (Array.length actions));
let results = Array.init n (fun i -> Env.step t.envs.(i) actions.(i)) in
let observations = Array.make n results.(0).observation in
let rewards = Array.make n 0. in
let terminated = Array.make n false in
let truncated = Array.make n false in
let infos = Array.make n Info.empty in
for i = 0 to n - 1 do
let result = results.(i) in
rewards.(i) <- result.reward;
terminated.(i) <- result.terminated;
truncated.(i) <- result.truncated;
if result.terminated || result.truncated then begin
let final_obs = Space.pack t.observation_space result.observation in
let info = Info.set "final_observation" final_obs result.info in
let info = Info.set "final_info" (Info.to_value result.info) info in
let obs, reset_info = Env.reset t.envs.(i) () in
observations.(i) <- obs;
infos.(i) <- Info.merge info reset_info
end
else begin
observations.(i) <- result.observation;
infos.(i) <- result.info
end
done;
{ observations; rewards; terminated; truncated; infos }
(* Close *)
let close t = Array.iter Env.close t.envs
================================================
FILE: packages/fehu/lib/vec_env.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Vectorized environments.
Runs multiple environment instances and batches their outputs. All
environments must have compatible observation and action spaces. Terminated
or truncated episodes are automatically reset. *)
(** {1:types Types} *)
type ('obs, 'act, 'render) t
(** The type for vectorized environments. *)
type 'obs step = {
observations : 'obs array; (** One observation per environment. *)
rewards : float array; (** One reward per environment. *)
terminated : bool array; (** Per-environment termination flags. *)
truncated : bool array; (** Per-environment truncation flags. *)
infos : Info.t array; (** Per-environment info dictionaries. *)
}
(** The type for batched step results. All arrays have length {!num_envs}. *)
(** {1:constructors Constructors} *)
val create : ('obs, 'act, 'render) Env.t list -> ('obs, 'act, 'render) t
(** [create envs] creates a vectorized environment.
All environments must have structurally identical spaces (checked via
{!Space.spec} and {!Space.equal_spec}). Raises [Invalid_argument] if [envs]
is empty or spaces differ. *)
(** {1:accessors Accessors} *)
val num_envs : ('obs, 'act, 'render) t -> int
(** [num_envs t] is the number of environments. *)
val observation_space : ('obs, 'act, 'render) t -> 'obs Space.t
(** [observation_space t] is the shared observation space. *)
val action_space : ('obs, 'act, 'render) t -> 'act Space.t
(** [action_space t] is the shared action space. *)
(** {1:lifecycle Lifecycle} *)
val reset : ('obs, 'act, 'render) t -> unit -> 'obs array * Info.t array
(** [reset t ()] resets all environments. *)
val step : ('obs, 'act, 'render) t -> 'act array -> 'obs step
(** [step t actions] steps all environments with the given actions.
[actions] must have length [num_envs t]. Terminated or truncated
environments are automatically reset. The terminal observation is stored in
the step's info under the key ["final_observation"] as a packed {!Value.t}.
The terminal info is stored under ["final_info"].
Raises [Invalid_argument] if [Array.length actions <> num_envs t]. *)
val close : ('obs, 'act, 'render) t -> unit
(** [close t] closes all environments. *)
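
The interface above composes into a short batched loop. A minimal sketch, assuming the module is exposed as `Fehu.Vec_env`, that `make_test_env` builds a `Discrete`-action `Env.t` as in the test files below, and that actions are `int32` tensors as those tests use:

```ocaml
open Fehu

(* Hedged sketch: run one batched step over two copies of an environment.
   [make_test_env] is a stand-in for any Env.t constructor. *)
let () =
  let vec = Vec_env.create [ make_test_env (); make_test_env () ] in
  let _observations, _infos = Vec_env.reset vec () in
  (* One action per environment; here the same discrete action for all. *)
  let action = Nx.create Nx.int32 [| 1 |] [| 1l |] in
  let actions = Array.make (Vec_env.num_envs vec) action in
  let step = Vec_env.step vec actions in
  (* Every result array has length num_envs t. *)
  assert (Array.length step.rewards = Vec_env.num_envs vec);
  (* Finished episodes were auto-reset; their last observation survives
     in the per-environment info under "final_observation". *)
  Array.iter2
    (fun terminated info ->
      if terminated then
        assert (Info.find "final_observation" info <> None))
    step.terminated step.infos;
  Vec_env.close vec
```

Because terminated environments are reset inside `step`, the `observations` field already holds fresh post-reset observations; only the info dictionary retains the terminal state.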
================================================
FILE: packages/fehu/test/dune
================================================
(tests
(names
test_value
test_info
test_space
test_env
test_env_wrappers
test_collect
test_buffer
test_gae
test_eval
test_vec_env
test_render
test_envs)
(package fehu)
(libraries fehu fehu.envs nx windtrap))
================================================
FILE: packages/fehu/test/test_buffer.ml
================================================
open Fehu
open Windtrap
let make_transition obs act rew next_obs term trunc =
Buffer.
{
observation = obs;
action = act;
reward = rew;
next_observation = next_obs;
terminated = term;
truncated = trunc;
}
(* Creation *)
let test_create_empty () =
let buf = Buffer.create ~capacity:10 in
equal ~msg:"size = 0" int 0 (Buffer.size buf);
is_false ~msg:"not full" (Buffer.is_full buf)
let test_capacity () =
let buf = Buffer.create ~capacity:10 in
equal ~msg:"capacity = 10" int 10 (Buffer.capacity buf)
let test_create_zero_capacity () =
raises_invalid_arg "Buffer.create: capacity must be positive" (fun () ->
Buffer.create ~capacity:0)
let test_create_negative_capacity () =
raises_invalid_arg "Buffer.create: capacity must be positive" (fun () ->
Buffer.create ~capacity:(-1))
(* Add/Size *)
let test_add_increments_size () =
let buf = Buffer.create ~capacity:10 in
Buffer.add buf (make_transition 1 0 1.0 2 false false);
equal ~msg:"size = 1" int 1 (Buffer.size buf);
Buffer.add buf (make_transition 2 1 2.0 3 false false);
equal ~msg:"size = 2" int 2 (Buffer.size buf)
let test_size_capped_at_capacity () =
let buf = Buffer.create ~capacity:3 in
for i = 1 to 5 do
Buffer.add buf (make_transition i 0 1.0 (i + 1) false false)
done;
equal ~msg:"size capped at 3" int 3 (Buffer.size buf)
let test_is_full () =
let buf = Buffer.create ~capacity:2 in
Buffer.add buf (make_transition 1 0 1.0 2 false false);
is_false ~msg:"not yet full" (Buffer.is_full buf);
Buffer.add buf (make_transition 2 1 2.0 3 false false);
is_true ~msg:"full" (Buffer.is_full buf)
(* Sample *)
let test_sample_batch_size () =
let buf = Buffer.create ~capacity:10 in
for i = 1 to 5 do
Buffer.add buf (make_transition i 0 1.0 (i + 1) false false)
done;
let batch = Buffer.sample buf ~batch_size:3 in
equal ~msg:"batch length" int 3 (Array.length batch)
let test_sample_empty_raises () =
let buf = Buffer.create ~capacity:10 in
raises_invalid_arg "Buffer.sample: buffer is empty" (fun () ->
Buffer.sample buf ~batch_size:1)
let test_sample_zero_batch_raises () =
let buf = Buffer.create ~capacity:10 in
Buffer.add buf (make_transition 1 0 1.0 2 false false);
raises_invalid_arg "Buffer.sample: batch_size must be positive" (fun () ->
Buffer.sample buf ~batch_size:0)
let test_sample_arrays_lengths () =
let buf = Buffer.create ~capacity:10 in
for i = 1 to 5 do
Buffer.add buf (make_transition i 0 1.0 (i + 1) false false)
done;
let obs, acts, rews, next_obs, terms, truncs =
Buffer.sample_arrays buf ~batch_size:3
in
equal ~msg:"obs length" int 3 (Array.length obs);
equal ~msg:"acts length" int 3 (Array.length acts);
equal ~msg:"rews length" int 3 (Array.length rews);
equal ~msg:"next_obs length" int 3 (Array.length next_obs);
equal ~msg:"terms length" int 3 (Array.length terms);
equal ~msg:"truncs length" int 3 (Array.length truncs)
let test_sample_arrays_empty_raises () =
let buf = Buffer.create ~capacity:10 in
raises_invalid_arg "Buffer.sample: buffer is empty" (fun () ->
Buffer.sample_arrays buf ~batch_size:1)
(* Clear *)
let test_clear_resets () =
let buf = Buffer.create ~capacity:10 in
Buffer.add buf (make_transition 1 0 1.0 2 false false);
Buffer.add buf (make_transition 2 1 2.0 3 false false);
Buffer.clear buf;
equal ~msg:"size = 0 after clear" int 0 (Buffer.size buf)
let test_add_after_clear () =
let buf = Buffer.create ~capacity:10 in
Buffer.add buf (make_transition 1 0 1.0 2 false false);
Buffer.clear buf;
Buffer.add buf (make_transition 3 1 3.0 4 false false);
equal ~msg:"size = 1 after re-add" int 1 (Buffer.size buf)
let () =
Nx.Rng.run ~seed:42 @@ fun () ->
run "Fehu.Buffer"
[
group "creation"
[
test "empty" test_create_empty;
test "capacity" test_capacity;
test "zero capacity raises" test_create_zero_capacity;
test "negative capacity raises" test_create_negative_capacity;
];
group "add/size"
[
test "add increments size" test_add_increments_size;
test "size capped at capacity" test_size_capped_at_capacity;
test "is_full" test_is_full;
];
group "sample"
[
test "batch size" test_sample_batch_size;
test "empty raises" test_sample_empty_raises;
test "zero batch raises" test_sample_zero_batch_raises;
test "sample_arrays lengths" test_sample_arrays_lengths;
test "sample_arrays empty raises" test_sample_arrays_empty_raises;
];
group "clear"
[
test "resets size" test_clear_resets;
test "add after clear" test_add_after_clear;
];
]
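
For reference, the replay pattern these tests exercise amounts to the following sketch. The transition fields mirror `make_transition` above; the RNG wrapper matches the test harness, and the placeholder integer observations are illustrative only:

```ocaml
open Fehu

(* Hedged sketch: fill a small replay buffer, then draw a random batch.
   Sampling requires a non-empty buffer and a positive batch_size. *)
let () =
  Nx.Rng.run ~seed:0 @@ fun () ->
  let buf = Buffer.create ~capacity:100 in
  for i = 1 to 50 do
    Buffer.add buf
      Buffer.
        {
          observation = i;
          action = 0;
          reward = 1.0;
          next_observation = i + 1;
          terminated = false;
          truncated = false;
        }
  done;
  let batch = Buffer.sample buf ~batch_size:8 in
  assert (Array.length batch = 8);
  (* sample_arrays yields the same batch shape as parallel arrays. *)
  let obs, _acts, _rews, _next, _terms, _truncs =
    Buffer.sample_arrays buf ~batch_size:8
  in
  assert (Array.length obs = 8)
```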
================================================
FILE: packages/fehu/test/test_collect.ml
================================================
open Fehu
open Windtrap
let make_test_env ?(max_steps = 100) () =
let obs_space = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
let act_space = Space.Discrete.create 2 in
let state = ref 5.0 in
let steps = ref 0 in
let reset _env ?options:_ () =
state := 5.0;
steps := 0;
(Nx.create Nx.float32 [| 1 |] [| !state |], Info.empty)
in
let step _env action =
let a : Int32.t array = Nx.to_array (Nx.reshape [| 1 |] action) in
state := !state +. if Int32.to_int a.(0) = 0 then -1.0 else 1.0;
incr steps;
let terminated = !state <= 0.0 || !state >= 10.0 in
let truncated = (not terminated) && !steps >= max_steps in
Env.step_result
~observation:(Nx.create Nx.float32 [| 1 |] [| !state |])
~reward:1.0 ~terminated ~truncated ()
in
Env.create ~id:"Test-v0" ~observation_space:obs_space ~action_space:act_space
~reset ~step ()
(* Rollout *)
let test_rollout_length () =
let env = make_test_env () in
let policy _obs = (Nx.create Nx.int32 [| 1 |] [| 1l |], None, None) in
let traj = Collect.rollout env ~policy ~n_steps:5 in
equal ~msg:"length = 5" int 5 (Collect.length traj)
let test_rollout_arrays_length () =
let env = make_test_env () in
let policy _obs = (Nx.create Nx.int32 [| 1 |] [| 1l |], None, None) in
let traj = Collect.rollout env ~policy ~n_steps:5 in
equal ~msg:"observations" int 5 (Array.length traj.observations);
equal ~msg:"actions" int 5 (Array.length traj.actions);
equal ~msg:"rewards" int 5 (Array.length traj.rewards);
equal ~msg:"next_observations" int 5 (Array.length traj.next_observations);
equal ~msg:"terminated" int 5 (Array.length traj.terminated);
equal ~msg:"truncated" int 5 (Array.length traj.truncated);
equal ~msg:"infos" int 5 (Array.length traj.infos)
let test_rollout_next_obs_populated () =
let env = make_test_env () in
let policy _obs = (Nx.create Nx.int32 [| 1 |] [| 1l |], None, None) in
let traj = Collect.rollout env ~policy ~n_steps:3 in
for i = 0 to 2 do
let arr : float array =
Nx.to_array (Nx.reshape [| 1 |] traj.next_observations.(i))
in
is_true ~msg:"next_obs is finite" (Float.is_finite arr.(0))
done
let test_rollout_no_log_probs () =
let env = make_test_env () in
let policy _obs = (Nx.create Nx.int32 [| 1 |] [| 1l |], None, None) in
let traj = Collect.rollout env ~policy ~n_steps:3 in
is_none ~msg:"log_probs" traj.log_probs;
is_none ~msg:"values" traj.values
let test_rollout_with_log_probs () =
let env = make_test_env () in
let policy _obs =
(Nx.create Nx.int32 [| 1 |] [| 1l |], Some (-0.5), Some 1.0)
in
let traj = Collect.rollout env ~policy ~n_steps:4 in
is_some ~msg:"log_probs present" traj.log_probs;
is_some ~msg:"values present" traj.values;
equal ~msg:"log_probs length" int 4 (Array.length (Option.get traj.log_probs));
equal ~msg:"values length" int 4 (Array.length (Option.get traj.values))
(* Episodes *)
let test_episodes_count () =
let env = make_test_env ~max_steps:10 () in
let policy _obs = (Nx.create Nx.int32 [| 1 |] [| 1l |], None, None) in
let eps = Collect.episodes env ~policy ~n_episodes:2 ~max_steps:10 () in
equal ~msg:"2 episodes" int 2 (List.length eps)
let test_episodes_positive_length () =
let env = make_test_env ~max_steps:10 () in
let policy _obs = (Nx.create Nx.int32 [| 1 |] [| 1l |], None, None) in
let eps = Collect.episodes env ~policy ~n_episodes:2 ~max_steps:10 () in
List.iter
(fun ep ->
is_true ~msg:"episode has positive length" (Collect.length ep > 0))
eps
(* Concat *)
let test_concat_two () =
let env = make_test_env () in
let policy _obs = (Nx.create Nx.int32 [| 1 |] [| 1l |], None, None) in
let t1 = Collect.rollout env ~policy ~n_steps:3 in
let t2 = Collect.rollout env ~policy ~n_steps:4 in
let t = Collect.concat [ t1; t2 ] in
equal ~msg:"total length" int 7 (Collect.length t)
let test_concat_empty_raises () =
raises_invalid_arg "Collect.concat: empty list" (fun () -> Collect.concat [])
let test_concat_singleton () =
let env = make_test_env () in
let policy _obs = (Nx.create Nx.int32 [| 1 |] [| 1l |], None, None) in
let t1 = Collect.rollout env ~policy ~n_steps:5 in
let t = Collect.concat [ t1 ] in
equal ~msg:"same length" int 5 (Collect.length t)
let () =
Nx.Rng.run ~seed:42 @@ fun () ->
run "Fehu.Collect"
[
group "rollout"
[
test "length" test_rollout_length;
test "arrays length" test_rollout_arrays_length;
test "next_observations populated" test_rollout_next_obs_populated;
test "no log_probs/values" test_rollout_no_log_probs;
test "with log_probs/values" test_rollout_with_log_probs;
];
group "episodes"
[
test "count" test_episodes_count;
test "positive length" test_episodes_positive_length;
];
group "concat"
[
test "two trajectories" test_concat_two;
test "empty raises" test_concat_empty_raises;
test "singleton" test_concat_singleton;
];
]
================================================
FILE: packages/fehu/test/test_env.ml
================================================
open Fehu
open Windtrap
let make_test_env ?(max_steps = 100) () =
let obs_space = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
let act_space = Space.Discrete.create 2 in
let state = ref 5.0 in
let steps = ref 0 in
let reset _env ?options:_ () =
state := 5.0;
steps := 0;
(Nx.create Nx.float32 [| 1 |] [| !state |], Info.empty)
in
let step _env action =
let a : Int32.t array = Nx.to_array (Nx.reshape [| 1 |] action) in
state := !state +. if Int32.to_int a.(0) = 0 then -1.0 else 1.0;
incr steps;
let terminated = !state <= 0.0 || !state >= 10.0 in
let truncated = (not terminated) && !steps >= max_steps in
Env.step_result
~observation:(Nx.create Nx.float32 [| 1 |] [| !state |])
~reward:1.0 ~terminated ~truncated ()
in
Env.create ~id:"Test-v0" ~observation_space:obs_space ~action_space:act_space
~reset ~step ()
let action_left = Nx.create Nx.int32 [| 1 |] [| 0l |]
let action_right = Nx.create Nx.int32 [| 1 |] [| 1l |]
let read_obs obs =
let arr : float array = Nx.to_array (Nx.reshape [| 1 |] obs) in
arr.(0)
(* Creation *)
let test_id () =
let env = make_test_env () in
equal ~msg:"id is Some Test-v0" (option string) (Some "Test-v0") (Env.id env)
let test_observation_space () =
let env = make_test_env () in
let low, high = Space.Box.bounds (Env.observation_space env) in
equal ~msg:"obs low" (array (float 0.0)) [| 0.0 |] low;
equal ~msg:"obs high" (array (float 0.0)) [| 10.0 |] high
let test_action_space () =
let env = make_test_env () in
equal ~msg:"act n" int 2 (Space.Discrete.n (Env.action_space env))
let test_render_mode_default () =
let env = make_test_env () in
is_none ~msg:"render_mode default is None" (Env.render_mode env)
let test_render_mode_invalid () =
raises_invalid_arg ~msg:"render_mode not in render_modes"
"Env.create: render mode 'human' not in render_modes []" (fun () ->
let obs_space = Space.Box.create ~low:[| 0.0 |] ~high:[| 1.0 |] in
let act_space = Space.Discrete.create 2 in
Env.create ~observation_space:obs_space ~action_space:act_space
~render_mode:`Human ~render_modes:[]
~reset:(fun _env ?options:_ () -> assert false)
~step:(fun _env _ -> assert false)
())
(* Lifecycle *)
let test_reset_obs () =
let env = make_test_env () in
let obs, _info = Env.reset env () in
equal ~msg:"reset obs shape" (array int) [| 1 |] (Nx.shape obs);
equal ~msg:"reset obs value" (float 0.0) 5.0 (read_obs obs)
let test_step_after_reset () =
let env = make_test_env () in
let _obs, _info = Env.reset env () in
let step = Env.step env action_right in
equal ~msg:"reward" (float 0.0) 1.0 step.reward;
is_false ~msg:"not terminated" step.terminated;
is_false ~msg:"not truncated" step.truncated
let test_step_before_reset () =
let env = make_test_env () in
raises_invalid_arg ~msg:"step before reset"
"Env: operation 'step' requires calling reset first" (fun () ->
Env.step env action_left)
let test_step_after_terminal () =
let env = make_test_env () in
let _obs, _info = Env.reset env () in
(* Move left 5 times: 5 -> 4 -> 3 -> 2 -> 1 -> 0, terminated *)
for _ = 1 to 4 do
ignore (Env.step env action_left)
done;
let step = Env.step env action_left in
is_true ~msg:"terminated" step.terminated;
raises_invalid_arg ~msg:"step after terminal"
"Env: operation 'step' requires calling reset first" (fun () ->
Env.step env action_left)
let test_reset_after_terminal () =
let env = make_test_env () in
let _obs, _info = Env.reset env () in
for _ = 1 to 5 do
ignore (Env.step env action_left)
done;
let obs, _info = Env.reset env () in
equal ~msg:"reset clears terminal" (float 0.0) 5.0 (read_obs obs)
let test_close () =
let env = make_test_env () in
Env.close env;
is_true ~msg:"closed" (Env.closed env)
let test_step_on_closed () =
let env = make_test_env () in
let _obs, _info = Env.reset env () in
Env.close env;
raises_invalid_arg ~msg:"step on closed"
"Env: operation 'step' on a closed environment" (fun () ->
Env.step env action_left)
let test_reset_on_closed () =
let env = make_test_env () in
Env.close env;
raises_invalid_arg ~msg:"reset on closed"
"Env: operation 'reset' on a closed environment" (fun () ->
Env.reset env ())
let test_render_on_closed () =
let env = make_test_env () in
Env.close env;
raises_invalid_arg ~msg:"render on closed"
"Env: operation 'render' on a closed environment" (fun () -> Env.render env)
let test_close_idempotent () =
let env = make_test_env () in
Env.close env;
Env.close env;
is_true ~msg:"still closed" (Env.closed env)
(* step_result *)
let test_step_result_defaults () =
let obs = Nx.create Nx.float32 [| 1 |] [| 0.0 |] in
let s = Env.step_result ~observation:obs () in
equal ~msg:"default reward" (float 0.0) 0.0 s.reward;
is_false ~msg:"default terminated" s.terminated;
is_false ~msg:"default truncated" s.truncated;
is_true ~msg:"default info empty" (Info.is_empty s.info)
let test_step_result_custom () =
let obs = Nx.create Nx.float32 [| 1 |] [| 0.0 |] in
let info = Info.set "k" (Info.int 1) Info.empty in
let s =
Env.step_result ~observation:obs ~reward:5.0 ~terminated:true
~truncated:false ~info ()
in
equal ~msg:"custom reward" (float 0.0) 5.0 s.reward;
is_true ~msg:"custom terminated" s.terminated;
is_false ~msg:"custom truncated" s.truncated;
is_some ~msg:"custom info has key" (Info.find "k" s.info)
(* time_limit lifecycle enforcement *)
let test_time_limit_needs_reset () =
let env = make_test_env () in
let wrapped = Env.time_limit ~max_episode_steps:3 env in
let _obs, _info = Env.reset wrapped () in
for _ = 1 to 2 do
ignore (Env.step wrapped action_right)
done;
let s3 = Env.step wrapped action_right in
is_true ~msg:"step 3 truncated" s3.truncated;
raises_invalid_arg ~msg:"step after time_limit truncation requires reset"
"Env: operation 'step' requires calling reset first" (fun () ->
Env.step wrapped action_right)
let () =
Nx.Rng.run ~seed:42 @@ fun () ->
run "Fehu.Env"
[
group "creation"
[
test "id" test_id;
test "observation_space" test_observation_space;
test "action_space" test_action_space;
test "render_mode default" test_render_mode_default;
test "render_mode invalid" test_render_mode_invalid;
];
group "lifecycle"
[
test "reset returns valid obs" test_reset_obs;
test "step after reset" test_step_after_reset;
test "step before reset" test_step_before_reset;
test "step after terminal" test_step_after_terminal;
test "reset after terminal" test_reset_after_terminal;
test "close" test_close;
test "step on closed" test_step_on_closed;
test "reset on closed" test_reset_on_closed;
test "render on closed" test_render_on_closed;
test "close idempotent" test_close_idempotent;
test "time_limit needs reset after truncation"
test_time_limit_needs_reset;
];
group "step_result"
[
test "defaults" test_step_result_defaults;
test "custom values" test_step_result_custom;
];
]
================================================
FILE: packages/fehu/test/test_env_wrappers.ml
================================================
open Fehu
open Windtrap
let make_test_env ?(max_steps = 100) () =
let obs_space = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
let act_space = Space.Discrete.create 2 in
let state = ref 5.0 in
let steps = ref 0 in
let reset _env ?options:_ () =
state := 5.0;
steps := 0;
(Nx.create Nx.float32 [| 1 |] [| !state |], Info.empty)
in
let step _env action =
let a : Int32.t array = Nx.to_array (Nx.reshape [| 1 |] action) in
state := !state +. if Int32.to_int a.(0) = 0 then -1.0 else 1.0;
incr steps;
let terminated = !state <= 0.0 || !state >= 10.0 in
let truncated = (not terminated) && !steps >= max_steps in
Env.step_result
~observation:(Nx.create Nx.float32 [| 1 |] [| !state |])
~reward:1.0 ~terminated ~truncated ()
in
Env.create ~id:"Test-v0" ~observation_space:obs_space ~action_space:act_space
~reset ~step ()
let action_left = Nx.create Nx.int32 [| 1 |] [| 0l |]
let action_right = Nx.create Nx.int32 [| 1 |] [| 1l |]
let read_obs obs =
let arr : float array = Nx.to_array (Nx.reshape [| 1 |] obs) in
arr.(0)
let value = testable ~pp:Value.pp ~equal:Value.equal ()
(* State sharing *)
let test_close_wrapper_closes_inner () =
let env = make_test_env () in
let wrapped =
Env.map_observation
~observation_space:(Env.observation_space env)
~f:(fun obs info -> (obs, info))
env
in
Env.close wrapped;
is_true ~msg:"inner closed" (Env.closed env)
let test_close_inner_closes_wrapper () =
let env = make_test_env () in
let wrapped =
Env.map_observation
~observation_space:(Env.observation_space env)
~f:(fun obs info -> (obs, info))
env
in
Env.close env;
is_true ~msg:"wrapper closed" (Env.closed wrapped)
let test_reset_wrapper_clears_inner () =
let env = make_test_env () in
let wrapped =
Env.map_observation
~observation_space:(Env.observation_space env)
~f:(fun obs info -> (obs, info))
env
in
let _obs, _info = Env.reset wrapped () in
let step = Env.step env action_left in
equal ~msg:"inner step works" (float 0.0) 1.0 step.reward
(* map_observation *)
let test_map_observation_reset () =
let env = make_test_env () in
let double_space = Space.Box.create ~low:[| 0.0 |] ~high:[| 20.0 |] in
let wrapped =
Env.map_observation ~observation_space:double_space
~f:(fun obs info ->
let v = read_obs obs in
(Nx.create Nx.float32 [| 1 |] [| v *. 2.0 |], info))
env
in
let obs, _info = Env.reset wrapped () in
equal ~msg:"doubled reset obs" (float 0.0) 10.0 (read_obs obs)
let test_map_observation_step () =
let env = make_test_env () in
let double_space = Space.Box.create ~low:[| 0.0 |] ~high:[| 20.0 |] in
let wrapped =
Env.map_observation ~observation_space:double_space
~f:(fun obs info ->
let v = read_obs obs in
(Nx.create Nx.float32 [| 1 |] [| v *. 2.0 |], info))
env
in
let _obs, _info = Env.reset wrapped () in
let step = Env.step wrapped action_right in
(* Inner: 5 + 1 = 6, doubled: 12 *)
equal ~msg:"doubled step obs" (float 0.0) 12.0 (read_obs step.observation)
let test_map_observation_id () =
let env = make_test_env () in
let wrapped =
Env.map_observation
~observation_space:(Env.observation_space env)
~f:(fun obs info -> (obs, info))
env
in
equal ~msg:"id suffix" (option string) (Some "Test-v0/ObservationWrapper")
(Env.id wrapped)
(* map_action *)
let test_map_action_flip () =
let env = make_test_env () in
let wrapped =
Env.map_action ~action_space:(Env.action_space env)
~f:(fun action ->
let a : Int32.t array = Nx.to_array (Nx.reshape [| 1 |] action) in
let flipped = if Int32.to_int a.(0) = 0 then 1l else 0l in
Nx.create Nx.int32 [| 1 |] [| flipped |])
env
in
let _obs, _info = Env.reset wrapped () in
(* Send left (0) to wrapper; inner sees right (1): 5 -> 6 *)
let step = Env.step wrapped action_left in
equal ~msg:"flipped: left becomes right" (float 0.0) 6.0
(read_obs step.observation)
let test_map_action_id () =
let env = make_test_env () in
let wrapped =
Env.map_action ~action_space:(Env.action_space env) ~f:Fun.id env
in
equal ~msg:"id suffix" (option string) (Some "Test-v0/ActionWrapper")
(Env.id wrapped)
(* map_reward *)
let test_map_reward () =
let env = make_test_env () in
let wrapped =
Env.map_reward ~f:(fun ~reward ~info -> (reward *. 2.0, info)) env
in
let _obs, _info = Env.reset wrapped () in
let step = Env.step wrapped action_right in
equal ~msg:"doubled reward" (float 0.0) 2.0 step.reward
let test_map_reward_id () =
let env = make_test_env () in
let wrapped = Env.map_reward ~f:(fun ~reward ~info -> (reward, info)) env in
equal ~msg:"id suffix" (option string) (Some "Test-v0/RewardWrapper")
(Env.id wrapped)
(* clip_action *)
let make_box_action_env () =
let obs_space = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
let act_space = Space.Box.create ~low:[| 0.0 |] ~high:[| 1.0 |] in
let last_action = ref 0.0 in
let reset _env ?options:_ () =
last_action := 0.0;
(Nx.create Nx.float32 [| 1 |] [| 5.0 |], Info.empty)
in
let step _env action =
let a : float array = Nx.to_array (Nx.reshape [| 1 |] action) in
last_action := a.(0);
Env.step_result
~observation:(Nx.create Nx.float32 [| 1 |] [| a.(0) |])
~reward:1.0 ()
in
let env =
Env.create ~id:"BoxAct-v0" ~observation_space:obs_space
~action_space:act_space ~reset ~step ()
in
(env, last_action)
let test_clip_action () =
let env, last_action = make_box_action_env () in
let wrapped = Env.clip_action env in
let _obs, _info = Env.reset wrapped () in
let _step = Env.step wrapped (Nx.create Nx.float32 [| 1 |] [| 2.0 |]) in
equal ~msg:"clamped to upper" (float 0.0) 1.0 !last_action;
let _step = Env.step wrapped (Nx.create Nx.float32 [| 1 |] [| -0.5 |]) in
equal ~msg:"clamped to lower" (float 0.0) 0.0 !last_action
(* clip_observation *)
let make_box_obs_env () =
let obs_space = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
let act_space = Space.Discrete.create 2 in
let obs_val = ref 5.0 in
let reset _env ?options:_ () =
obs_val := 5.0;
(Nx.create Nx.float32 [| 1 |] [| !obs_val |], Info.empty)
in
let step _env action =
let a : Int32.t array = Nx.to_array (Nx.reshape [| 1 |] action) in
obs_val := !obs_val +. if Int32.to_int a.(0) = 0 then -3.0 else 3.0;
Env.step_result
~observation:(Nx.create Nx.float32 [| 1 |] [| !obs_val |])
~reward:1.0 ()
in
Env.create ~id:"BoxObs-v0" ~observation_space:obs_space
~action_space:act_space ~reset ~step ()
let test_clip_observation () =
let env = make_box_obs_env () in
let wrapped = Env.clip_observation ~low:[| 2.0 |] ~high:[| 8.0 |] env in
let _obs, _info = Env.reset wrapped () in
(* Step right: inner obs = 8.0, clipped to 8.0 *)
let s1 = Env.step wrapped action_right in
let arr1 : float array = Nx.to_array (Nx.reshape [| 1 |] s1.observation) in
equal ~msg:"clipped to upper" (float 0.0) 8.0 arr1.(0);
let _obs, _info = Env.reset wrapped () in
(* Step left: inner obs = 2.0, within bounds *)
let s2 = Env.step wrapped action_left in
let arr2 : float array = Nx.to_array (Nx.reshape [| 1 |] s2.observation) in
equal ~msg:"within bounds" (float 0.0) 2.0 arr2.(0)
let test_clip_observation_space () =
let env = make_box_obs_env () in
let wrapped = Env.clip_observation ~low:[| 2.0 |] ~high:[| 8.0 |] env in
let low, high = Space.Box.bounds (Env.observation_space wrapped) in
equal ~msg:"clipped low" (array (float 0.0)) [| 2.0 |] low;
equal ~msg:"clipped high" (array (float 0.0)) [| 8.0 |] high
(* time_limit *)
let test_time_limit_truncation () =
let env = make_test_env () in
let wrapped = Env.time_limit ~max_episode_steps:3 env in
let _obs, _info = Env.reset wrapped () in
let s1 = Env.step wrapped action_right in
is_false ~msg:"step 1 not truncated" s1.truncated;
let s2 = Env.step wrapped action_right in
is_false ~msg:"step 2 not truncated" s2.truncated;
let s3 = Env.step wrapped action_right in
is_true ~msg:"step 3 truncated" s3.truncated
let test_time_limit_info () =
let env = make_test_env () in
let wrapped = Env.time_limit ~max_episode_steps:2 env in
let _obs, _info = Env.reset wrapped () in
let _s1 = Env.step wrapped action_right in
let s2 = Env.step wrapped action_right in
is_some ~msg:"time_limit.truncated present"
(Info.find "time_limit.truncated" s2.info);
is_some ~msg:"time_limit.elapsed_steps present"
(Info.find "time_limit.elapsed_steps" s2.info)
let test_time_limit_info_values () =
let env = make_test_env () in
let wrapped = Env.time_limit ~max_episode_steps:2 env in
let _obs, _info = Env.reset wrapped () in
let _s1 = Env.step wrapped action_right in
let s2 = Env.step wrapped action_right in
let tl_truncated = Info.find_exn "time_limit.truncated" s2.info in
equal ~msg:"truncated is Bool true" value (Value.Bool true) tl_truncated;
let tl_steps = Info.find_exn "time_limit.elapsed_steps" s2.info in
equal ~msg:"elapsed_steps is Int 2" value (Value.Int 2) tl_steps
let test_time_limit_counter_resets () =
let env = make_test_env () in
let wrapped = Env.time_limit ~max_episode_steps:3 env in
let _obs, _info = Env.reset wrapped () in
for _ = 1 to 3 do
ignore (Env.step wrapped action_right)
done;
let _obs, _info = Env.reset wrapped () in
let s1 = Env.step wrapped action_right in
is_false ~msg:"counter reset after new episode" s1.truncated
let test_time_limit_nonpositive () =
let env = make_test_env () in
raises_invalid_arg ~msg:"max_episode_steps=0"
"Env.time_limit: max_episode_steps must be positive" (fun () ->
Env.time_limit ~max_episode_steps:0 env);
raises_invalid_arg ~msg:"max_episode_steps=-1"
"Env.time_limit: max_episode_steps must be positive" (fun () ->
Env.time_limit ~max_episode_steps:(-1) env)
let test_time_limit_needs_reset () =
let env = make_test_env () in
let wrapped = Env.time_limit ~max_episode_steps:2 env in
let _obs, _info = Env.reset wrapped () in
let _s1 = Env.step wrapped action_right in
let s2 = Env.step wrapped action_right in
is_true ~msg:"truncated at limit" s2.truncated;
raises_invalid_arg ~msg:"step after time_limit truncation"
"Env: operation 'step' requires calling reset first" (fun () ->
Env.step wrapped action_right)
let () =
Nx.Rng.run ~seed:42 @@ fun () ->
run "Fehu.Env (wrappers)"
[
group "state sharing"
[
test "close wrapper closes inner" test_close_wrapper_closes_inner;
test "close inner closes wrapper" test_close_inner_closes_wrapper;
test "reset wrapper clears inner" test_reset_wrapper_clears_inner;
];
group "map_observation"
[
test "doubles reset obs" test_map_observation_reset;
test "doubles step obs" test_map_observation_step;
test "id suffix" test_map_observation_id;
];
group "map_action"
[
test "flip reverses direction" test_map_action_flip;
test "id suffix" test_map_action_id;
];
group "map_reward"
[
test "doubles reward" test_map_reward;
test "id suffix" test_map_reward_id;
];
group "clip_action" [ test "clamps out-of-bounds" test_clip_action ];
group "clip_observation"
[
test "clamps to explicit bounds" test_clip_observation;
test "observation space reflects bounds" test_clip_observation_space;
];
group "time_limit"
[
test "truncation at limit" test_time_limit_truncation;
test "info keys present" test_time_limit_info;
test "info values correct" test_time_limit_info_values;
test "counter resets on new episode" test_time_limit_counter_resets;
test "nonpositive raises" test_time_limit_nonpositive;
test "needs reset after truncation" test_time_limit_needs_reset;
];
]
================================================
FILE: packages/fehu/test/test_envs.ml
================================================
open Fehu
open Fehu_envs
open Windtrap
let read_float obs =
let arr : float array = Nx.to_array (Nx.reshape [| 1 |] obs) in
arr.(0)
let read_int32_array obs n =
let arr : Int32.t array = Nx.to_array (Nx.reshape [| n |] obs) in
Array.map Int32.to_int arr
let discrete action = Nx.create Nx.int32 [| 1 |] [| Int32.of_int action |]
(* Random_walk *)
let test_rw_creation () =
let env = Random_walk.make () in
match Env.id env with
| Some id ->
is_true ~msg:"id starts with RandomWalk"
(String.length id >= 10 && String.sub id 0 10 = "RandomWalk")
| None -> fail "expected an id"
let test_rw_reset_obs () =
let env = Random_walk.make () in
let obs, _info = Env.reset env () in
equal ~msg:"reset obs is 0.0" (float 1e-6) 0.0 (read_float obs)
let test_rw_step_left () =
let env = Random_walk.make () in
let _obs, _info = Env.reset env () in
let s = Env.step env (discrete 0) in
equal ~msg:"step left to -1.0" (float 1e-6) (-1.0) (read_float s.observation)
let test_rw_step_right () =
let env = Random_walk.make () in
let _obs, _info = Env.reset env () in
let s = Env.step env (discrete 1) in
equal ~msg:"step right to 1.0" (float 1e-6) 1.0 (read_float s.observation)
let test_rw_termination () =
let env = Random_walk.make () in
let _obs, _info = Env.reset env () in
let terminated = ref false in
for _ = 1 to 20 do
if not !terminated then begin
let s = Env.step env (discrete 1) in
if s.terminated then terminated := true
else if s.truncated then begin
let _obs, _info = Env.reset env () in
()
end
end
done;
is_true ~msg:"terminated at boundary" !terminated
let test_rw_ansi_render () =
let env = Random_walk.make ~render_mode:`Ansi () in
let _obs, _info = Env.reset env () in
match Env.render env with
| Some s -> is_true ~msg:"non-empty render" (String.length s > 0)
| None -> fail "expected Some render"
(* Cartpole *)
let test_cp_creation () =
let env = Cartpole.make () in
match Env.id env with
| Some id ->
is_true ~msg:"id starts with CartPole"
(String.length id >= 8 && String.sub id 0 8 = "CartPole")
| None -> fail "expected an id"
let test_cp_reset_shape () =
let env = Cartpole.make () in
let obs, _info = Env.reset env () in
let shape = Nx.shape obs in
equal ~msg:"obs shape [4]" (array int) [| 4 |] shape
let test_cp_step_reward () =
let env = Cartpole.make () in
let _obs, _info = Env.reset env () in
let s = Env.step env (discrete 1) in
is_false ~msg:"not terminated on first step" s.terminated;
equal ~msg:"reward 1.0" (float 1e-6) 1.0 s.reward
let test_cp_termination () =
let env = Cartpole.make () in
let _obs, _info = Env.reset env () in
let done_flag = ref false in
for _ = 1 to 600 do
if not !done_flag then begin
let s = Env.step env (discrete 0) in
if s.terminated || s.truncated then done_flag := true
end
done;
is_true ~msg:"episode ends" !done_flag
(* Grid_world *)
let test_gw_creation () =
let env = Grid_world.make () in
match Env.id env with
| Some id ->
is_true ~msg:"id starts with GridWorld"
(String.length id >= 9 && String.sub id 0 9 = "GridWorld")
| None -> fail "expected an id"
let test_gw_reset_obs () =
let env = Grid_world.make () in
let obs, _info = Env.reset env () in
let pos = read_int32_array obs 2 in
equal ~msg:"row = 0" int 0 pos.(0);
equal ~msg:"col = 0" int 0 pos.(1)
let test_gw_move_down () =
let env = Grid_world.make () in
let _obs, _info = Env.reset env () in
let s = Env.step env (discrete 1) in
let pos = read_int32_array s.observation 2 in
equal ~msg:"row = 1 after down" int 1 pos.(0)
let test_gw_move_right () =
let env = Grid_world.make () in
let _obs, _info = Env.reset env () in
let s = Env.step env (discrete 3) in
let pos = read_int32_array s.observation 2 in
equal ~msg:"col = 1 after right" int 1 pos.(1)
let test_gw_obstacle () =
let env = Grid_world.make () in
let _obs, _info = Env.reset env () in
(* Navigate to (1, 2): down, right, right *)
let _s = Env.step env (discrete 1) in
let _s = Env.step env (discrete 3) in
let s = Env.step env (discrete 3) in
let pos = read_int32_array s.observation 2 in
equal ~msg:"at (1,2)" int 1 pos.(0);
equal ~msg:"at (1,2)" int 2 pos.(1);
(* Try to move down into obstacle at (2,2) *)
let s = Env.step env (discrete 1) in
let pos = read_int32_array s.observation 2 in
equal ~msg:"blocked row still 1" int 1 pos.(0);
equal ~msg:"blocked col still 2" int 2 pos.(1)
let test_gw_reach_goal () =
let env = Grid_world.make () in
let _obs, _info = Env.reset env () in
(* Path to (4,4) avoiding obstacle at (2,2): down 4 times to row 4, then right
4 times to col 4 *)
for _ = 1 to 4 do
ignore (Env.step env (discrete 1))
done;
let s_right1 = Env.step env (discrete 3) in
is_false ~msg:"not done yet" s_right1.terminated;
let _s = Env.step env (discrete 3) in
let _s = Env.step env (discrete 3) in
let s = Env.step env (discrete 3) in
is_true ~msg:"terminated at goal" s.terminated;
equal ~msg:"reward 10.0" (float 1e-6) 10.0 s.reward
let test_gw_ansi_render () =
let env = Grid_world.make ~render_mode:`Ansi () in
let _obs, _info = Env.reset env () in
match Env.render env with
| Some (Grid_world.Text s) ->
is_true ~msg:"non-empty render" (String.length s > 0)
| Some (Grid_world.Image _) -> fail "expected Text render"
| None -> fail "expected Some render"
(* Mountain_car *)
let test_mc_creation () =
let env = Mountain_car.make () in
match Env.id env with
| Some id ->
is_true ~msg:"id starts with MountainCar"
(String.length id >= 11 && String.sub id 0 11 = "MountainCar")
| None -> fail "expected an id"
let test_mc_reset_shape () =
let env = Mountain_car.make () in
let obs, _info = Env.reset env () in
let shape = Nx.shape obs in
equal ~msg:"obs shape [2]" (array int) [| 2 |] shape
let test_mc_step_coast () =
let env = Mountain_car.make () in
let _obs, _info = Env.reset env () in
let s = Env.step env (discrete 1) in
let shape = Nx.shape s.observation in
equal ~msg:"obs shape after step" (array int) [| 2 |] shape;
is_false ~msg:"not terminated" s.terminated
let test_mc_reward () =
let env = Mountain_car.make () in
let _obs, _info = Env.reset env () in
let s = Env.step env (discrete 1) in
equal ~msg:"reward -1.0" (float 1e-6) (-1.0) s.reward
let () =
Nx.Rng.run ~seed:42 @@ fun () ->
run "Fehu_envs"
[
group "RandomWalk"
[
test "creation" test_rw_creation;
test "reset observation" test_rw_reset_obs;
test "step left" test_rw_step_left;
test "step right" test_rw_step_right;
test "termination at boundary" test_rw_termination;
test "ansi render" test_rw_ansi_render;
];
group "CartPole"
[
test "creation" test_cp_creation;
test "reset shape" test_cp_reset_shape;
test "step reward" test_cp_step_reward;
test "termination" test_cp_termination;
];
group "GridWorld"
[
test "creation" test_gw_creation;
test "reset observation" test_gw_reset_obs;
test "move down" test_gw_move_down;
test "move right" test_gw_move_right;
test "obstacle blocks movement" test_gw_obstacle;
test "reach goal" test_gw_reach_goal;
test "ansi render" test_gw_ansi_render;
];
group "MountainCar"
[
test "creation" test_mc_creation;
test "reset shape" test_mc_reset_shape;
test "step coast" test_mc_step_coast;
test "reward" test_mc_reward;
];
]
================================================
FILE: packages/fehu/test/test_eval.ml
================================================
open Fehu
open Windtrap
let make_test_env ?(max_steps = 100) () =
let obs_space = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
let act_space = Space.Discrete.create 2 in
let state = ref 5.0 in
let steps = ref 0 in
let reset _env ?options:_ () =
state := 5.0;
steps := 0;
(Nx.create Nx.float32 [| 1 |] [| !state |], Info.empty)
in
let step _env action =
let a : Int32.t array = Nx.to_array (Nx.reshape [| 1 |] action) in
state := !state +. if Int32.to_int a.(0) = 0 then -1.0 else 1.0;
incr steps;
let terminated = !state <= 0.0 || !state >= 10.0 in
let truncated = (not terminated) && !steps >= max_steps in
Env.step_result
~observation:(Nx.create Nx.float32 [| 1 |] [| !state |])
~reward:1.0 ~terminated ~truncated ()
in
Env.create ~id:"Test-v0" ~observation_space:obs_space ~action_space:act_space
~reset ~step ()
(* Run *)
let test_constant_reward_stats () =
let env = make_test_env ~max_steps:5 () in
let policy _obs = Nx.create Nx.int32 [| 1 |] [| 1l |] in
let stats = Eval.run env ~policy ~n_episodes:3 ~max_steps:5 () in
equal ~msg:"mean_reward" (float 1e-6) 5.0 stats.mean_reward;
equal ~msg:"std_reward" (float 1e-6) 0.0 stats.std_reward;
equal ~msg:"mean_length" (float 1e-6) 5.0 stats.mean_length;
equal ~msg:"n_episodes" int 3 stats.n_episodes
let test_n_episodes_matches () =
let env = make_test_env ~max_steps:5 () in
let policy _obs = Nx.create Nx.int32 [| 1 |] [| 1l |] in
let stats = Eval.run env ~policy ~n_episodes:7 ~max_steps:5 () in
equal ~msg:"n_episodes matches" int 7 stats.n_episodes
let () =
Nx.Rng.run ~seed:42 @@ fun () ->
run "Fehu.Eval"
[
group "run"
[
test "constant reward statistics" test_constant_reward_stats;
test "n_episodes matches stats" test_n_episodes_matches;
];
]
================================================
FILE: packages/fehu/test/test_gae.ml
================================================
open Fehu
open Windtrap
let f = float 1e-6
(* Compute *)
let test_compute_simple () =
let rewards = [| 1.0; 1.0; 1.0 |] in
let values = [| 0.5; 0.5; 0.5 |] in
let terminated = [| false; false; false |] in
let truncated = [| false; false; false |] in
let next_values = [| 0.5; 0.5; 0.5 |] in
let advantages, returns =
Gae.compute ~rewards ~values ~terminated ~truncated ~next_values ~gamma:0.99
~lambda:0.95
in
equal ~msg:"lengths match" int 3 (Array.length advantages);
for i = 0 to 2 do
equal ~msg:"returns = advantages + values" f returns.(i)
(advantages.(i) +. values.(i))
done
let test_compute_termination () =
let rewards = [| 1.0; 1.0; 1.0 |] in
let values = [| 0.5; 0.5; 0.5 |] in
let terminated = [| false; true; false |] in
let truncated = [| false; false; false |] in
let next_values = [| 0.5; 0.5; 0.5 |] in
let advantages, _returns =
Gae.compute ~rewards ~values ~terminated ~truncated ~next_values ~gamma:0.99
~lambda:0.95
in
(* At step 1 (terminated), bootstrap is 0: delta = 1.0 + 0.99*0 - 0.5 = 0.5 *)
equal ~msg:"terminal advantage" f 0.5 advantages.(1)
let test_compute_truncation () =
let rewards = [| 1.0; 1.0; 1.0 |] in
let values = [| 0.5; 0.5; 0.5 |] in
let terminated = [| false; false; false |] in
let truncated = [| false; true; false |] in
let next_values = [| 0.5; 2.0; 0.5 |] in
let advantages_trunc, _returns =
Gae.compute ~rewards ~values ~terminated ~truncated ~next_values ~gamma:0.99
~lambda:0.95
in
  (* At step 1 (truncated), the bootstrap uses next_values.(1) = 2.0:
     delta = 1.0 + 0.99*2.0 - 0.5 = 2.48 *)
let terminated_fake = [| false; true; false |] in
let advantages_term, _returns_term =
Gae.compute ~rewards ~values ~terminated:terminated_fake
~truncated:[| false; false; false |] ~next_values ~gamma:0.99 ~lambda:0.95
in
  (* With termination instead, the bootstrap would be 0:
     delta = 1.0 + 0.99*0 - 0.5 = 0.5. These must differ because truncation
     uses next_values. *)
not_equal ~msg:"truncation differs from termination" f advantages_trunc.(1)
advantages_term.(1)
let test_compute_length_mismatch () =
raises_invalid_arg "Gae: all arrays must have the same length" (fun () ->
Gae.compute ~rewards:[| 1.0; 1.0 |] ~values:[| 0.5 |]
~terminated:[| false; false |] ~truncated:[| false; false |]
~next_values:[| 0.5; 0.5 |] ~gamma:0.99 ~lambda:0.95)
let test_compute_empty () =
let advantages, returns =
Gae.compute ~rewards:[||] ~values:[||] ~terminated:[||] ~truncated:[||]
~next_values:[||] ~gamma:0.99 ~lambda:0.95
in
equal ~msg:"empty advantages" int 0 (Array.length advantages);
equal ~msg:"empty returns" int 0 (Array.length returns)
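The tests above pin down the GAE recursion from specific deltas. The following is a reference sketch of that recursion as the tests assume it (an illustration, not Fehu's implementation; the name `_gae_reference` is hypothetical): `delta_t = r_t + gamma * next_v_t * (1 - term_t) - v_t`, then `adv_t = delta_t + gamma * lambda * (1 - done_t) * adv_{t+1}`, where termination zeroes the bootstrap but truncation keeps it.

```ocaml
(* Reference sketch of the GAE recursion the tests above assume; an
   illustration, not Fehu's implementation. Termination zeroes the value
   bootstrap; truncation only stops advantage accumulation. *)
let _gae_reference ~rewards ~values ~terminated ~truncated ~next_values ~gamma
    ~lambda =
  let n = Array.length rewards in
  let adv = Array.make n 0.0 in
  let acc = ref 0.0 in
  for t = n - 1 downto 0 do
    (* terminated: no bootstrap from the next state *)
    let mask_term = if terminated.(t) then 0.0 else 1.0 in
    (* terminated or truncated: stop accumulating future advantages *)
    let mask_done = if terminated.(t) || truncated.(t) then 0.0 else 1.0 in
    let delta =
      rewards.(t) +. (gamma *. next_values.(t) *. mask_term) -. values.(t)
    in
    acc := delta +. (gamma *. lambda *. mask_done *. !acc);
    adv.(t) <- !acc
  done;
  adv
```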
(* Returns *)
let test_returns_simple () =
let ret =
Gae.returns ~rewards:[| 1.0; 1.0; 1.0 |]
~terminated:[| false; false; false |] ~truncated:[| false; false; false |]
~gamma:1.0
in
equal ~msg:"ret[0]" f 3.0 ret.(0);
equal ~msg:"ret[1]" f 2.0 ret.(1);
equal ~msg:"ret[2]" f 1.0 ret.(2)
let test_returns_gamma_zero () =
let ret =
Gae.returns ~rewards:[| 1.0; 2.0; 3.0 |]
~terminated:[| false; false; false |] ~truncated:[| false; false; false |]
~gamma:0.0
in
equal ~msg:"ret[0]" f 1.0 ret.(0);
equal ~msg:"ret[1]" f 2.0 ret.(1);
equal ~msg:"ret[2]" f 3.0 ret.(2)
let test_returns_terminated () =
let ret =
Gae.returns ~rewards:[| 1.0; 1.0; 1.0 |]
~terminated:[| false; true; false |] ~truncated:[| false; false; false |]
~gamma:1.0
in
  (* Backwards: step 2: acc = 1.0; step 1 is terminated, so the tail is
     zeroed: acc = 1.0 + 1.0*0.0 = 1.0; step 0: acc = 1.0 + 1.0*1.0 = 2.0 *)
equal ~msg:"ret[0]" f 2.0 ret.(0);
equal ~msg:"ret[1]" f 1.0 ret.(1);
equal ~msg:"ret[2]" f 1.0 ret.(2)
let test_returns_truncated () =
let ret =
Gae.returns ~rewards:[| 1.0; 1.0; 1.0 |]
~terminated:[| false; false; false |] ~truncated:[| false; true; false |]
~gamma:1.0
in
(* Truncation at step 1 resets accumulation, same as terminated *)
equal ~msg:"ret[0]" f 2.0 ret.(0);
equal ~msg:"ret[1]" f 1.0 ret.(1);
equal ~msg:"ret[2]" f 1.0 ret.(2)
let test_returns_length_mismatch () =
raises_invalid_arg
"Gae.returns: rewards, terminated, and truncated must have the same length"
(fun () ->
Gae.returns ~rewards:[| 1.0; 1.0 |] ~terminated:[| false |]
~truncated:[| false; false |] ~gamma:0.99)
(* Compute from values *)
let test_compute_from_values_simple () =
let rewards = [| 1.0; 1.0; 1.0 |] in
let values = [| 0.5; 0.5; 0.5 |] in
let terminated = [| false; false; false |] in
let truncated = [| false; false; false |] in
let last_value = 0.5 in
let advantages, returns =
Gae.compute_from_values ~rewards ~values ~terminated ~truncated ~last_value
~gamma:0.99 ~lambda:0.95
in
  (* next_values should be [| 0.5; 0.5; 0.5 |]: values shifted left by one,
     with last_value appended *)
let advantages2, returns2 =
Gae.compute ~rewards ~values ~terminated ~truncated
~next_values:[| 0.5; 0.5; 0.5 |] ~gamma:0.99 ~lambda:0.95
in
for i = 0 to 2 do
equal ~msg:"advantages match" f advantages2.(i) advantages.(i);
equal ~msg:"returns match" f returns2.(i) returns.(i)
done
let test_compute_from_values_shifted () =
let rewards = [| 1.0; 1.0; 1.0 |] in
let values = [| 1.0; 2.0; 3.0 |] in
let terminated = [| false; false; false |] in
let truncated = [| false; false; false |] in
let last_value = 4.0 in
let advantages, _returns =
Gae.compute_from_values ~rewards ~values ~terminated ~truncated ~last_value
~gamma:0.99 ~lambda:0.95
in
(* next_values = [| 2.0; 3.0; 4.0 |] *)
let advantages2, _returns2 =
Gae.compute ~rewards ~values ~terminated ~truncated
~next_values:[| 2.0; 3.0; 4.0 |] ~gamma:0.99 ~lambda:0.95
in
for i = 0 to 2 do
equal ~msg:"advantages match" f advantages2.(i) advantages.(i)
done
(* Normalize *)
let test_normalize_mean_zero () =
let arr = [| 1.0; 2.0; 3.0; 4.0; 5.0 |] in
let normed = Gae.normalize arr in
let mean = ref 0.0 in
Array.iter (fun x -> mean := !mean +. x) normed;
mean := !mean /. Float.of_int (Array.length normed);
equal ~msg:"mean near 0" f 0.0 !mean
let test_normalize_std_one () =
let arr = [| 1.0; 2.0; 3.0; 4.0; 5.0 |] in
let normed = Gae.normalize arr in
let n = Array.length normed in
let mean = ref 0.0 in
Array.iter (fun x -> mean := !mean +. x) normed;
mean := !mean /. Float.of_int n;
let var = ref 0.0 in
Array.iter
(fun x ->
let d = x -. !mean in
var := !var +. (d *. d))
normed;
var := !var /. Float.of_int n;
let std = sqrt !var in
is_true ~msg:"std near 1" (Float.abs (std -. 1.0) < 0.01)
let test_normalize_empty () =
let normed = Gae.normalize [||] in
equal ~msg:"empty" int 0 (Array.length normed)
let test_normalize_single () =
let normed = Gae.normalize [| 42.0 |] in
equal ~msg:"single normalizes to 0" f 0.0 normed.(0)
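The normalize tests (mean near 0, std near 1, single element maps to 0) are consistent with the standard `(x - mean) / (std + eps)` normalization. A sketch under that assumption (the name `_normalize_reference` and the `eps` default are hypothetical, not Fehu's API):

```ocaml
(* Sketch of the normalization behaviour the tests above check; an
   illustration, not Fehu's implementation. The eps term keeps a
   constant array (std = 0) mapping to zeros instead of dividing by 0. *)
let _normalize_reference ?(eps = 1e-8) arr =
  let n = Array.length arr in
  if n = 0 then [||]
  else
    let mean = Array.fold_left ( +. ) 0.0 arr /. float_of_int n in
    let var =
      Array.fold_left (fun acc x -> acc +. ((x -. mean) ** 2.0)) 0.0 arr
      /. float_of_int n
    in
    let std = sqrt var in
    Array.map (fun x -> (x -. mean) /. (std +. eps)) arr
```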
let () =
run "Fehu.Gae"
[
group "compute"
[
test "simple" test_compute_simple;
test "termination" test_compute_termination;
test "truncation" test_compute_truncation;
test "length mismatch" test_compute_length_mismatch;
test "empty" test_compute_empty;
];
group "returns"
[
test "simple" test_returns_simple;
test "gamma zero" test_returns_gamma_zero;
test "terminated resets" test_returns_terminated;
test "truncated resets" test_returns_truncated;
test "length mismatch" test_returns_length_mismatch;
];
group "compute_from_values"
[
test "matches compute" test_compute_from_values_simple;
test "shifted values" test_compute_from_values_shifted;
];
group "normalize"
[
test "mean near zero" test_normalize_mean_zero;
test "std near one" test_normalize_std_one;
test "empty" test_normalize_empty;
test "single element" test_normalize_single;
];
]
================================================
FILE: packages/fehu/test/test_info.ml
================================================
open Fehu
open Windtrap
let value = testable ~pp:Value.pp ~equal:Value.equal ()
(* Operations *)
let test_empty_is_empty () =
equal ~msg:"empty is_empty" bool true (Info.is_empty Info.empty)
let test_set_then_find () =
let info = Info.set "k" (Value.Int 1) Info.empty in
equal ~msg:"find after set" (option value) (Some (Value.Int 1))
(Info.find "k" info)
let test_find_missing () =
let info = Info.set "k" (Value.Int 1) Info.empty in
equal ~msg:"find missing" (option value) None (Info.find "other" info)
let test_find_exn_existing () =
let info = Info.set "k" (Value.Int 42) Info.empty in
equal ~msg:"find_exn existing" value (Value.Int 42) (Info.find_exn "k" info)
let test_remove () =
let info = Info.set "k" (Value.Int 1) Info.empty in
let info = Info.remove "k" info in
equal ~msg:"find after remove" (option value) None (Info.find "k" info)
let test_merge_right_biased () =
let a = Info.set "k" (Value.Int 1) Info.empty in
let b = Info.set "k" (Value.Int 2) Info.empty in
let merged = Info.merge a b in
equal ~msg:"merge right wins" value (Value.Int 2) (Info.find_exn "k" merged)
let test_of_list_to_list_round_trip () =
let kvs = [ ("a", Value.Int 1); ("c", Value.Int 3); ("b", Value.Int 2) ] in
let info = Info.of_list kvs in
let result = Info.to_list info in
equal ~msg:"round-trip keys sorted" (list string) [ "a"; "b"; "c" ]
(List.map fst result);
equal ~msg:"round-trip values" (list value)
[ Value.Int 1; Value.Int 2; Value.Int 3 ]
(List.map snd result)
(* Errors *)
let test_find_exn_missing () =
raises_invalid_arg "Info.find_exn: key \"missing\" not present" (fun () ->
ignore (Info.find_exn "missing" Info.empty))
(* Convenience *)
let test_int_convenience () =
equal ~msg:"Info.int" value (Value.Int 42) (Info.int 42)
let test_float_convenience () =
equal ~msg:"Info.float" value (Value.Float 1.0) (Info.float 1.0)
let test_bool_convenience () =
equal ~msg:"Info.bool" value (Value.Bool true) (Info.bool true)
let test_string_convenience () =
equal ~msg:"Info.string" value (Value.String "hi") (Info.string "hi")
let test_null_convenience () = equal ~msg:"Info.null" value Value.Null Info.null
let () =
run "Fehu.Info"
[
group "operations"
[
test "empty is_empty" test_empty_is_empty;
test "set then find" test_set_then_find;
test "find missing key" test_find_missing;
test "find_exn existing" test_find_exn_existing;
test "remove" test_remove;
test "merge right-biased" test_merge_right_biased;
test "of_list/to_list round-trip" test_of_list_to_list_round_trip;
];
group "errors" [ test "find_exn missing raises" test_find_exn_missing ];
group "convenience"
[
test "int" test_int_convenience;
test "float" test_float_convenience;
test "bool" test_bool_convenience;
test "string" test_string_convenience;
test "null" test_null_convenience;
];
]
================================================
FILE: packages/fehu/test/test_render.ml
================================================
open Fehu
open Windtrap
let make_data n =
Bigarray.Array1.create Bigarray.int8_unsigned Bigarray.c_layout n
(* Image *)
let test_valid_rgb_image () =
let data = make_data 12 in
let img = Render.image ~width:2 ~height:2 data in
equal ~msg:"width" int 2 img.width;
equal ~msg:"height" int 2 img.height
let test_wrong_data_length_raises () =
let data = make_data 10 in
raises_invalid_arg
"Render.image: data length 10 does not match width * height * channels = 12"
(fun () -> ignore (Render.image ~width:2 ~height:2 data))
let test_rgba_channels () =
let data = make_data 16 in
let img =
Render.image ~width:2 ~height:2 ~pixel_format:Render.Pixel.Rgba data
in
equal ~msg:"width" int 2 img.width;
equal ~msg:"height" int 2 img.height
let test_gray_channels () =
let data = make_data 4 in
let img =
Render.image ~width:2 ~height:2 ~pixel_format:Render.Pixel.Gray data
in
equal ~msg:"width" int 2 img.width;
equal ~msg:"height" int 2 img.height
let test_pixel_format_default_rgb () =
let data = make_data 3 in
let img = Render.image ~width:1 ~height:1 data in
equal ~msg:"default is Rgb" int 3 (Render.Pixel.channels img.pixel_format)
(* Rollout *)
let make_renderable_env () =
let obs_space = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
let act_space = Space.Discrete.create 2 in
let state = ref 5.0 in
let reset _env ?options:_ () =
state := 5.0;
(Nx.create Nx.float32 [| 1 |] [| !state |], Info.empty)
in
let step _env action =
let a : Int32.t array = Nx.to_array (Nx.reshape [| 1 |] action) in
state := !state +. if Int32.to_int a.(0) = 0 then -1.0 else 1.0;
let terminated = !state <= 0.0 || !state >= 10.0 in
Env.step_result
~observation:(Nx.create Nx.float32 [| 1 |] [| !state |])
~reward:1.0 ~terminated ()
in
let render () =
let data = make_data 3 in
Some (Render.image ~width:1 ~height:1 data)
in
Env.create ~id:"Renderable-v0" ~observation_space:obs_space
~action_space:act_space ~reset ~step ~render ()
let test_rollout_sink_called () =
let env = make_renderable_env () in
let count = ref 0 in
let policy _obs = Nx.create Nx.int32 [| 1 |] [| 1l |] in
let sink _frame = incr count in
Render.rollout env ~policy ~steps:3 ~sink ();
equal ~msg:"sink called 3 times" int 3 !count
(* on_render *)
let action_right = Nx.create Nx.int32 [| 1 |] [| 1l |]
let test_on_render_frame_count () =
let env = make_renderable_env () in
let count = ref 0 in
let wrapped = Render.on_render ~sink:(fun _ -> incr count) env in
let _obs, _info = Env.reset wrapped () in
let _s1 = Env.step wrapped action_right in
let _s2 = Env.step wrapped action_right in
let _s3 = Env.step wrapped action_right in
(* 1 frame from reset + 3 frames from steps = 4 *)
equal ~msg:"frame count" int 4 !count
let test_on_render_passthrough () =
let env = make_renderable_env () in
let wrapped = Render.on_render ~sink:(fun _ -> ()) env in
let _obs, _info = Env.reset wrapped () in
let step = Env.step wrapped action_right in
  equal ~msg:"reward unchanged" (float 1e-6) 1.0 step.reward;
is_false ~msg:"not terminated" step.terminated;
is_false ~msg:"not truncated" step.truncated
let test_on_render_id () =
let env = make_renderable_env () in
let wrapped = Render.on_render ~sink:(fun _ -> ()) env in
equal ~msg:"id suffix" (option string) (Some "Renderable-v0/OnRender")
(Env.id wrapped)
let () =
Nx.Rng.run ~seed:42 @@ fun () ->
run "Fehu.Render"
[
group "image"
[
test "valid RGB 2x2" test_valid_rgb_image;
test "wrong data length raises" test_wrong_data_length_raises;
test "RGBA 4 channels" test_rgba_channels;
test "Gray 1 channel" test_gray_channels;
test "default pixel_format is Rgb" test_pixel_format_default_rgb;
];
group "rollout"
[ test "sink called for each step" test_rollout_sink_called ];
group "on_render"
[
test "frame count" test_on_render_frame_count;
test "passthrough" test_on_render_passthrough;
test "id suffix" test_on_render_id;
];
]
================================================
FILE: packages/fehu/test/test_space.ml
================================================
open Fehu
open Windtrap
let value = testable ~pp:Value.pp ~equal:Value.equal ()
(* Helpers *)
let int32_scalar v = Nx.scalar Nx.int32 (Int32.of_int v)
let int32_vec arr =
Nx.create Nx.int32 [| Array.length arr |] (Array.map Int32.of_int arr)
let float32_vec arr = Nx.create Nx.float32 [| Array.length arr |] arr
let read_float32_vec t =
let n = (Nx.shape t).(0) in
let arr : float array = Nx.to_array (Nx.reshape [| n |] t) in
arr
(* Discrete *)
let test_discrete_default () =
let s = Space.Discrete.create 3 in
equal ~msg:"n is 3" int 3 (Space.Discrete.n s);
equal ~msg:"start is 0" int 0 (Space.Discrete.start s)
let test_discrete_custom_start () =
let s = Space.Discrete.create ~start:5 3 in
equal ~msg:"start is 5" int 5 (Space.Discrete.start s);
equal ~msg:"n is 3" int 3 (Space.Discrete.n s)
let test_discrete_contains () =
let s = Space.Discrete.create 3 in
is_true ~msg:"contains 0" (Space.contains s (int32_scalar 0));
is_true ~msg:"contains 1" (Space.contains s (int32_scalar 1));
is_true ~msg:"contains 2" (Space.contains s (int32_scalar 2));
is_false ~msg:"not contains 3" (Space.contains s (int32_scalar 3));
is_false ~msg:"not contains -1" (Space.contains s (int32_scalar (-1)))
let test_discrete_contains_with_start () =
let s = Space.Discrete.create ~start:5 3 in
is_true ~msg:"contains 5" (Space.contains s (int32_scalar 5));
is_true ~msg:"contains 7" (Space.contains s (int32_scalar 7));
is_false ~msg:"not contains 4" (Space.contains s (int32_scalar 4));
is_false ~msg:"not contains 8" (Space.contains s (int32_scalar 8))
let test_discrete_sample () =
let s = Space.Discrete.create 3 in
let v = Space.sample s in
is_true ~msg:"sample is valid" (Space.contains s v)
let test_discrete_pack_unpack () =
let s = Space.Discrete.create 3 in
let v = int32_scalar 2 in
let packed = Space.pack s v in
equal ~msg:"pack produces Int 2" value (Value.Int 2) packed;
let unpacked = Space.unpack s packed in
is_ok ~msg:"unpack succeeds" unpacked
let test_discrete_unpack_invalid () =
let s = Space.Discrete.create 3 in
is_error ~msg:"unpack out of range" (Space.unpack s (Value.Int 5));
is_error ~msg:"unpack wrong type" (Space.unpack s (Value.String "x"))
let test_discrete_boundary_values () =
let s = Space.Discrete.create 3 in
let bvs = Space.boundary_values s in
equal ~msg:"2 boundary values" int 2 (List.length bvs);
equal ~msg:"first boundary" value (Value.Int 0) (List.hd bvs);
equal ~msg:"last boundary" value (Value.Int 2) (List.nth bvs 1)
let test_discrete_boundary_single () =
let s = Space.Discrete.create 1 in
let bvs = Space.boundary_values s in
equal ~msg:"1 boundary for n=1" int 1 (List.length bvs)
let test_discrete_shape () =
let s = Space.Discrete.create 3 in
is_none ~msg:"discrete shape is None" (Space.shape s)
let test_discrete_error_zero () =
raises_invalid_arg "Space.Discrete.create: n must be strictly positive"
(fun () -> Space.Discrete.create 0)
let test_discrete_error_negative () =
raises_invalid_arg "Space.Discrete.create: n must be strictly positive"
(fun () -> Space.Discrete.create (-1))
(* Box *)
let test_box_1d () =
let s = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
let low, high = Space.Box.bounds s in
equal ~msg:"low" (array (float 0.)) [| 0.0 |] low;
equal ~msg:"high" (array (float 0.)) [| 10.0 |] high
let test_box_contains () =
let s = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
is_true ~msg:"mid value" (Space.contains s (float32_vec [| 5.0 |]));
is_true ~msg:"low bound" (Space.contains s (float32_vec [| 0.0 |]));
is_true ~msg:"high bound" (Space.contains s (float32_vec [| 10.0 |]));
is_false ~msg:"below low" (Space.contains s (float32_vec [| -0.1 |]));
is_false ~msg:"above high" (Space.contains s (float32_vec [| 10.1 |]))
let test_box_sample () =
let s = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
let v = Space.sample s in
is_true ~msg:"sample in bounds" (Space.contains s v)
let test_box_pack_unpack () =
let s = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
let v = float32_vec [| 5.0 |] in
let packed = Space.pack s v in
let unpacked = Space.unpack s packed in
is_ok ~msg:"round-trip succeeds" unpacked;
match unpacked with
| Ok t ->
let arr = read_float32_vec t in
equal ~msg:"value preserved" (float 0.01) 5.0 arr.(0)
| Error _ -> fail "unreachable"
let test_box_boundary_values () =
let s = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
let bvs = Space.boundary_values s in
equal ~msg:"2 boundaries" int 2 (List.length bvs)
let test_box_boundary_identical () =
let s = Space.Box.create ~low:[| 5.0 |] ~high:[| 5.0 |] in
let bvs = Space.boundary_values s in
equal ~msg:"1 boundary when identical" int 1 (List.length bvs)
let test_box_shape_1d () =
let s = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
is_some ~msg:"shape is Some" (Space.shape s);
equal ~msg:"shape [|1|]" (array int) [| 1 |] (Option.get (Space.shape s))
let test_box_2d () =
let s = Space.Box.create ~low:[| 0.0; -1.0 |] ~high:[| 1.0; 1.0 |] in
equal ~msg:"shape [|2|]" (array int) [| 2 |] (Option.get (Space.shape s));
is_true ~msg:"2d in bounds" (Space.contains s (float32_vec [| 0.5; 0.0 |]));
is_false ~msg:"2d out of bounds"
(Space.contains s (float32_vec [| 0.5; 2.0 |]))
let test_box_error_empty () =
raises_invalid_arg "Space.Box.create: low cannot be empty" (fun () ->
Space.Box.create ~low:[||] ~high:[||])
let test_box_error_mismatch () =
raises_invalid_arg
"Space.Box.create: low and high must have identical lengths" (fun () ->
Space.Box.create ~low:[| 0.0 |] ~high:[| 1.0; 2.0 |])
let test_box_error_low_gt_high () =
raises_match ~msg:"low > high raises"
(fun exn -> match exn with Invalid_argument _ -> true | _ -> false)
(fun () -> Space.Box.create ~low:[| 5.0 |] ~high:[| 1.0 |])
(* Multi_binary *)
let test_mb_contains () =
let s = Space.Multi_binary.create 3 in
is_true ~msg:"all zeros" (Space.contains s (int32_vec [| 0; 0; 0 |]));
is_true ~msg:"all ones" (Space.contains s (int32_vec [| 1; 1; 1 |]));
is_true ~msg:"mixed" (Space.contains s (int32_vec [| 0; 1; 0 |]));
is_false ~msg:"value 2 invalid" (Space.contains s (int32_vec [| 0; 2; 0 |]));
is_false ~msg:"wrong length" (Space.contains s (int32_vec [| 0; 1 |]))
let test_mb_sample () =
let s = Space.Multi_binary.create 3 in
let v = Space.sample s in
is_true ~msg:"sample valid" (Space.contains s v)
let test_mb_boundary_values () =
let s = Space.Multi_binary.create 3 in
let bvs = Space.boundary_values s in
equal ~msg:"2 boundaries" int 2 (List.length bvs)
let test_mb_shape () =
let s = Space.Multi_binary.create 3 in
equal ~msg:"shape [|3|]" (option (array int)) (Some [| 3 |]) (Space.shape s)
let test_mb_error () =
raises_invalid_arg "Space.Multi_binary.create: n must be strictly positive"
(fun () -> Space.Multi_binary.create 0)
(* Multi_discrete *)
let test_md_contains () =
let s = Space.Multi_discrete.create [| 3; 4 |] in
is_true ~msg:"valid" (Space.contains s (int32_vec [| 0; 0 |]));
is_true ~msg:"upper valid" (Space.contains s (int32_vec [| 2; 3 |]));
is_false ~msg:"first oob" (Space.contains s (int32_vec [| 3; 0 |]));
is_false ~msg:"second oob" (Space.contains s (int32_vec [| 0; 4 |]));
is_false ~msg:"negative" (Space.contains s (int32_vec [| -1; 0 |]))
let test_md_sample () =
let s = Space.Multi_discrete.create [| 3; 4 |] in
let v = Space.sample s in
is_true ~msg:"sample valid" (Space.contains s v)
let test_md_shape () =
let s = Space.Multi_discrete.create [| 3; 4 |] in
equal ~msg:"shape [|2|]" (option (array int)) (Some [| 2 |]) (Space.shape s)
let test_md_error_empty () =
raises_invalid_arg "Space.Multi_discrete.create: nvec must not be empty"
(fun () -> Space.Multi_discrete.create [||])
let test_md_error_zero_element () =
raises_match ~msg:"nvec element <= 0 raises"
(fun exn -> match exn with Invalid_argument _ -> true | _ -> false)
(fun () -> Space.Multi_discrete.create [| 3; 0 |])
(* Tuple *)
let test_tuple_contains () =
let ds = Space.Discrete.create 3 in
let bs = Space.Box.create ~low:[| 0.0 |] ~high:[| 1.0 |] in
let s = Space.Tuple.create [ Pack ds; Pack bs ] in
let valid = [ Value.Int 1; Value.Float_array [| 0.5 |] ] in
is_true ~msg:"valid tuple" (Space.contains s valid);
let bad_length = [ Value.Int 1 ] in
is_false ~msg:"wrong length" (Space.contains s bad_length);
let bad_value = [ Value.Int 5; Value.Float_array [| 0.5 |] ] in
is_false ~msg:"invalid element" (Space.contains s bad_value)
let test_tuple_sample () =
let ds = Space.Discrete.create 3 in
let bs = Space.Box.create ~low:[| 0.0 |] ~high:[| 1.0 |] in
let s = Space.Tuple.create [ Pack ds; Pack bs ] in
let v = Space.sample s in
is_true ~msg:"sample valid" (Space.contains s v)
let test_tuple_pack_unpack () =
let ds = Space.Discrete.create 3 in
let bs = Space.Box.create ~low:[| 0.0 |] ~high:[| 1.0 |] in
let s = Space.Tuple.create [ Pack ds; Pack bs ] in
let v = [ Value.Int 1; Value.Float_array [| 0.5 |] ] in
let packed = Space.pack s v in
let unpacked = Space.unpack s packed in
is_ok ~msg:"round-trip succeeds" unpacked
let test_tuple_empty () =
let s = Space.Tuple.create [] in
is_true ~msg:"empty tuple valid" (Space.contains s []);
is_false ~msg:"non-empty invalid" (Space.contains s [ Value.Int 0 ])
(* Dict *)
let test_dict_contains () =
let ds = Space.Discrete.create 3 in
let bs = Space.Box.create ~low:[| 0.0 |] ~high:[| 1.0 |] in
let s = Space.Dict.create [ ("action", Pack ds); ("obs", Pack bs) ] in
let valid =
[ ("action", Value.Int 1); ("obs", Value.Float_array [| 0.5 |]) ]
in
is_true ~msg:"valid dict" (Space.contains s valid);
let missing_key = [ ("action", Value.Int 1) ] in
is_false ~msg:"missing key" (Space.contains s missing_key);
let extra_key =
[
("action", Value.Int 1);
("obs", Value.Float_array [| 0.5 |]);
("extra", Value.Int 0);
]
in
is_false ~msg:"extra key" (Space.contains s extra_key)
let test_dict_sample () =
let ds = Space.Discrete.create 3 in
let s = Space.Dict.create [ ("a", Pack ds) ] in
let v = Space.sample s in
is_true ~msg:"sample valid" (Space.contains s v)
let test_dict_error_duplicate () =
let ds = Space.Discrete.create 3 in
raises_match ~msg:"duplicate key raises"
(fun exn -> match exn with Invalid_argument _ -> true | _ -> false)
(fun () -> Space.Dict.create [ ("a", Pack ds); ("a", Pack ds) ])
(* Text *)
let test_text_contains () =
let s = Space.Text.create () in
is_true ~msg:"alpha string" (Space.contains s "hello");
is_true ~msg:"empty string" (Space.contains s "");
is_true ~msg:"with digits" (Space.contains s "abc123");
is_true ~msg:"with space" (Space.contains s "hello world")
let test_text_contains_invalid () =
let s = Space.Text.create ~charset:"abc" () in
is_false ~msg:"char outside charset" (Space.contains s "abcd")
let test_text_contains_too_long () =
let s = Space.Text.create ~max_length:3 () in
is_false ~msg:"exceeds max_length" (Space.contains s "abcd");
is_true ~msg:"at max_length" (Space.contains s "abc")
let test_text_sample () =
let s = Space.Text.create () in
let v = Space.sample s in
is_true ~msg:"sample valid" (Space.contains s v);
is_true ~msg:"sample non-empty" (String.length v > 0)
let test_text_boundary_values () =
let s = Space.Text.create () in
let bvs = Space.boundary_values s in
equal ~msg:"2 boundaries" int 2 (List.length bvs)
let test_text_error_max_length () =
raises_invalid_arg "Space.Text.create: max_length must be positive" (fun () ->
Space.Text.create ~max_length:0 ())
let test_text_error_charset () =
raises_invalid_arg "Space.Text.create: charset must not be empty" (fun () ->
Space.Text.create ~charset:"" ())
(* Sequence *)
let test_seq_contains () =
let ds = Space.Discrete.create 3 in
let s = Space.Sequence.create ~min_length:1 ~max_length:3 ds in
let v1 = int32_scalar 0 in
let v2 = int32_scalar 2 in
is_true ~msg:"length 1" (Space.contains s [ v1 ]);
is_true ~msg:"length 3" (Space.contains s [ v1; v2; v1 ]);
is_false ~msg:"empty" (Space.contains s []);
is_false ~msg:"too long" (Space.contains s [ v1; v2; v1; v2 ])
let test_seq_contains_unbounded () =
let ds = Space.Discrete.create 3 in
let s = Space.Sequence.create ~min_length:0 ds in
is_true ~msg:"empty is valid" (Space.contains s []);
is_true ~msg:"long is valid"
(Space.contains s (List.init 100 (fun _ -> int32_scalar 0)))
let test_seq_sample () =
let ds = Space.Discrete.create 3 in
let s = Space.Sequence.create ~min_length:1 ~max_length:5 ds in
let v = Space.sample s in
is_true ~msg:"sample valid" (Space.contains s v)
let test_seq_sample_fixed () =
let ds = Space.Discrete.create 3 in
let s = Space.Sequence.create ~min_length:2 ds in
let v = Space.sample s in
equal ~msg:"fixed length 2" int 2 (List.length v)
let test_seq_pack_unpack () =
let ds = Space.Discrete.create 3 in
let s = Space.Sequence.create ~min_length:1 ~max_length:3 ds in
let v = [ int32_scalar 0; int32_scalar 1 ] in
let packed = Space.pack s v in
let unpacked = Space.unpack s packed in
is_ok ~msg:"round-trip succeeds" unpacked
let test_seq_error_min_negative () =
let ds = Space.Discrete.create 3 in
raises_invalid_arg "Space.Sequence.create: min_length must be non-negative"
(fun () -> Space.Sequence.create ~min_length:(-1) ds)
let test_seq_error_max_lt_min () =
let ds = Space.Discrete.create 3 in
raises_invalid_arg "Space.Sequence.create: max_length must be >= min_length"
(fun () -> Space.Sequence.create ~min_length:5 ~max_length:2 ds)
(* Discrete helpers *)
let test_discrete_to_int () =
let v = Space.Discrete.of_int 5 in
equal ~msg:"to_int round-trip" int 5 (Space.Discrete.to_int v)
let test_discrete_of_int () =
let v = Space.Discrete.of_int 3 in
let s = Space.Discrete.create 5 in
is_true ~msg:"of_int creates valid element" (Space.contains s v)
(* Spec *)
let test_spec_discrete () =
let s = Space.Discrete.create ~start:2 4 in
let sp = Space.spec s in
equal ~msg:"discrete spec" bool true
(Space.equal_spec sp (Space.Discrete { start = 2; n = 4 }))
let test_spec_box () =
let s = Space.Box.create ~low:[| 0.0 |] ~high:[| 1.0 |] in
let sp = Space.spec s in
equal ~msg:"box spec" bool true
(Space.equal_spec sp (Space.Box { low = [| 0.0 |]; high = [| 1.0 |] }))
let test_spec_equal_same () =
let s1 = Space.Discrete.create 3 in
let s2 = Space.Discrete.create 3 in
is_true ~msg:"same spaces equal spec"
(Space.equal_spec (Space.spec s1) (Space.spec s2))
let test_spec_not_equal_different () =
let s1 = Space.Discrete.create 3 in
let s2 = Space.Discrete.create 4 in
is_false ~msg:"different spaces not equal spec"
(Space.equal_spec (Space.spec s1) (Space.spec s2))
let test_spec_not_equal_kinds () =
let s1 = Space.Discrete.create 3 in
let s2 = Space.Box.create ~low:[| 0.0 |] ~high:[| 1.0 |] in
is_false ~msg:"different kinds not equal spec"
(Space.equal_spec (Space.spec s1) (Space.spec s2))
let test_spec_tuple () =
let ds = Space.Discrete.create 3 in
let bs = Space.Box.create ~low:[| 0.0 |] ~high:[| 1.0 |] in
let s = Space.Tuple.create [ Pack ds; Pack bs ] in
let sp = Space.spec s in
let expected =
Space.Tuple
[
Space.Discrete { start = 0; n = 3 };
Space.Box { low = [| 0.0 |]; high = [| 1.0 |] };
]
in
is_true ~msg:"tuple spec" (Space.equal_spec sp expected)
let test_spec_dict () =
let ds = Space.Discrete.create 3 in
let s = Space.Dict.create [ ("a", Pack ds) ] in
let sp = Space.spec s in
let expected = Space.Dict [ ("a", Space.Discrete { start = 0; n = 3 }) ] in
is_true ~msg:"dict spec" (Space.equal_spec sp expected)
(* Tuple.unpack validation *)
let test_tuple_unpack_validates_elements () =
let ds = Space.Discrete.create 3 in
let s = Space.Tuple.create [ Pack ds ] in
(* Value.Int 5 is out of range for Discrete(n=3, start=0) *)
let bad = Value.List [ Value.Int 5 ] in
is_error ~msg:"unpack rejects invalid element" (Space.unpack s bad)
let test_tuple_unpack_valid () =
let ds = Space.Discrete.create 3 in
let s = Space.Tuple.create [ Pack ds ] in
let good = Value.List [ Value.Int 1 ] in
is_ok ~msg:"unpack accepts valid element" (Space.unpack s good)
(* Entry point *)
let () =
Nx.Rng.run ~seed:42 @@ fun () ->
run "Fehu.Space"
[
group "Discrete"
[
test "default start" test_discrete_default;
test "custom start" test_discrete_custom_start;
test "contains valid/invalid" test_discrete_contains;
test "contains with start" test_discrete_contains_with_start;
test "sample" test_discrete_sample;
test "pack/unpack" test_discrete_pack_unpack;
test "unpack invalid" test_discrete_unpack_invalid;
test "boundary values" test_discrete_boundary_values;
test "boundary single" test_discrete_boundary_single;
test "shape" test_discrete_shape;
test "error n=0" test_discrete_error_zero;
test "error n<0" test_discrete_error_negative;
test "to_int round-trip" test_discrete_to_int;
test "of_int valid" test_discrete_of_int;
];
group "Box"
[
test "1d create and bounds" test_box_1d;
test "contains" test_box_contains;
test "sample" test_box_sample;
test "pack/unpack" test_box_pack_unpack;
test "boundary values" test_box_boundary_values;
test "boundary identical" test_box_boundary_identical;
test "shape 1d" test_box_shape_1d;
test "2d" test_box_2d;
test "error empty" test_box_error_empty;
test "error mismatched lengths" test_box_error_mismatch;
test "error low > high" test_box_error_low_gt_high;
];
group "Multi_binary"
[
test "contains" test_mb_contains;
test "sample" test_mb_sample;
test "boundary values" test_mb_boundary_values;
test "shape" test_mb_shape;
test "error n=0" test_mb_error;
];
group "Multi_discrete"
[
test "contains" test_md_contains;
test "sample" test_md_sample;
test "shape" test_md_shape;
test "error empty" test_md_error_empty;
test "error element <= 0" test_md_error_zero_element;
];
group "Tuple"
[
test "contains" test_tuple_contains;
test "sample" test_tuple_sample;
test "pack/unpack" test_tuple_pack_unpack;
test "empty tuple" test_tuple_empty;
test "unpack validates elements" test_tuple_unpack_validates_elements;
test "unpack valid" test_tuple_unpack_valid;
];
group "Dict"
[
test "contains" test_dict_contains;
test "sample" test_dict_sample;
test "error duplicate keys" test_dict_error_duplicate;
];
group "Text"
[
test "contains" test_text_contains;
test "contains invalid charset" test_text_contains_invalid;
test "contains too long" test_text_contains_too_long;
test "sample" test_text_sample;
test "boundary values" test_text_boundary_values;
test "error max_length=0" test_text_error_max_length;
test "error empty charset" test_text_error_charset;
];
group "Sequence"
[
test "contains bounded" test_seq_contains;
test "contains unbounded" test_seq_contains_unbounded;
test "sample" test_seq_sample;
test "sample fixed length" test_seq_sample_fixed;
test "pack/unpack" test_seq_pack_unpack;
test "error min < 0" test_seq_error_min_negative;
test "error max < min" test_seq_error_max_lt_min;
];
group "spec"
[
test "discrete" test_spec_discrete;
test "box" test_spec_box;
test "equal same" test_spec_equal_same;
test "not equal different" test_spec_not_equal_different;
test "not equal kinds" test_spec_not_equal_kinds;
test "tuple" test_spec_tuple;
test "dict" test_spec_dict;
];
]
================================================
FILE: packages/fehu/test/test_value.ml
================================================
open Fehu
open Windtrap
let value = testable ~pp:Value.pp ~equal:Value.equal ()
(* Equality *)
let test_null_equal () = equal ~msg:"null = null" value Value.Null Value.Null
let test_bool_equal () =
equal ~msg:"true = true" value (Bool true) (Bool true);
not_equal ~msg:"true <> false" value (Bool true) (Bool false)
let test_int_equal () =
equal ~msg:"1 = 1" value (Int 1) (Int 1);
not_equal ~msg:"1 <> 2" value (Int 1) (Int 2)
let test_float_equal () = equal ~msg:"1.0 = 1.0" value (Float 1.0) (Float 1.0)
let test_string_equal () =
equal ~msg:"a = a" value (String "a") (String "a");
not_equal ~msg:"a <> b" value (String "a") (String "b")
let test_int_array_equal () =
equal ~msg:"[|1;2|] = [|1;2|]" value
(Int_array [| 1; 2 |])
(Int_array [| 1; 2 |])
let test_float_array_equal () =
equal ~msg:"[|1.0|] = [|1.0|]" value (Float_array [| 1.0 |])
(Float_array [| 1.0 |])
let test_bool_array_equal () =
equal ~msg:"[|true|] = [|true|]" value (Bool_array [| true |])
(Bool_array [| true |])
let test_list_equal () =
equal ~msg:"[Int 1] = [Int 1]" value (List [ Int 1 ]) (List [ Int 1 ])
let test_dict_equal () =
equal ~msg:"dict equal" value (Dict [ ("k", Int 1) ]) (Dict [ ("k", Int 1) ])
let test_cross_type_inequality () =
not_equal ~msg:"Int 1 <> Float 1.0" value (Int 1) (Float 1.0);
not_equal ~msg:"Null <> Int 0" value Null (Int 0)
(* Formatting *)
let test_to_string_null () =
equal ~msg:"null" string "null" (Value.to_string Null)
let test_to_string_bool () =
equal ~msg:"bool true" string "true" (Value.to_string (Bool true))
let test_to_string_int () =
equal ~msg:"int 42" string "42" (Value.to_string (Int 42))
let test_to_string_float () =
let s = Value.to_string (Float 3.14) in
is_true ~msg:"float non-empty" (String.length s > 0)
let test_to_string_string () =
let s = Value.to_string (String "hello") in
is_true ~msg:"string non-empty" (String.length s > 0)
let test_to_string_arrays () =
is_true ~msg:"int_array non-empty"
(String.length (Value.to_string (Int_array [| 1 |])) > 0);
is_true ~msg:"float_array non-empty"
(String.length (Value.to_string (Float_array [| 1.0 |])) > 0);
is_true ~msg:"bool_array non-empty"
(String.length (Value.to_string (Bool_array [| true |])) > 0)
let test_to_string_list () =
let s = Value.to_string (List [ Int 1; Int 2 ]) in
is_true ~msg:"list non-empty" (String.length s > 0)
let test_to_string_dict () =
let s = Value.to_string (Dict [ ("k", Int 1) ]) in
is_true ~msg:"dict non-empty" (String.length s > 0)
let () =
run "Fehu.Value"
[
group "equality"
[
test "null" test_null_equal;
test "bool" test_bool_equal;
test "int" test_int_equal;
test "float" test_float_equal;
test "string" test_string_equal;
test "int_array" test_int_array_equal;
test "float_array" test_float_array_equal;
test "bool_array" test_bool_array_equal;
test "list" test_list_equal;
test "dict" test_dict_equal;
test "cross-type inequality" test_cross_type_inequality;
];
group "formatting"
[
test "to_string null" test_to_string_null;
test "to_string bool" test_to_string_bool;
test "to_string int" test_to_string_int;
test "to_string float" test_to_string_float;
test "to_string string" test_to_string_string;
test "to_string arrays" test_to_string_arrays;
test "to_string list" test_to_string_list;
test "to_string dict" test_to_string_dict;
];
]
================================================
FILE: packages/fehu/test/test_vec_env.ml
================================================
open Fehu
open Windtrap
let make_test_env ?(max_steps = 100) () =
let obs_space = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
let act_space = Space.Discrete.create 2 in
let state = ref 5.0 in
let steps = ref 0 in
let reset _env ?options:_ () =
state := 5.0;
steps := 0;
(Nx.create Nx.float32 [| 1 |] [| !state |], Info.empty)
in
let step _env action =
let a : Int32.t array = Nx.to_array (Nx.reshape [| 1 |] action) in
state := !state +. (if Int32.to_int a.(0) = 0 then -1.0 else 1.0);
incr steps;
let terminated = !state <= 0.0 || !state >= 10.0 in
let truncated = (not terminated) && !steps >= max_steps in
Env.step_result
~observation:(Nx.create Nx.float32 [| 1 |] [| !state |])
~reward:1.0 ~terminated ~truncated ()
in
Env.create ~id:"Test-v0" ~observation_space:obs_space ~action_space:act_space
~reset ~step ()
let make_envs n = List.init n (fun _ -> make_test_env ())
(* Creation *)
let test_create_num_envs () =
let venv = Vec_env.create (make_envs 3) in
equal ~msg:"num_envs" int 3 (Vec_env.num_envs venv)
let test_spaces_match_first_env () =
let envs = make_envs 3 in
let venv = Vec_env.create envs in
let obs_shape = Space.shape (Vec_env.observation_space venv) in
let act_shape = Space.shape (Vec_env.action_space venv) in
let first_obs = Space.shape (Env.observation_space (List.hd envs)) in
let first_act = Space.shape (Env.action_space (List.hd envs)) in
equal ~msg:"obs space shape" (option (array int)) first_obs obs_shape;
equal ~msg:"act space shape" (option (array int)) first_act act_shape
let test_empty_list_raises () =
raises_invalid_arg "Vec_env.create: env list must not be empty" (fun () ->
ignore (Vec_env.create []))
let test_incompatible_spaces_raises () =
let obs1 = Space.Box.create ~low:[| 0.0 |] ~high:[| 10.0 |] in
let act = Space.Discrete.create 2 in
let obs2 = Space.Box.create ~low:[| 0.0 |] ~high:[| 5.0 |] in
let make_env obs =
let reset _env ?options:_ () =
(Nx.create Nx.float32 [| 1 |] [| 0.0 |], Info.empty)
in
let step _env _action =
Env.step_result ~observation:(Nx.create Nx.float32 [| 1 |] [| 0.0 |]) ()
in
Env.create ~observation_space:obs ~action_space:act ~reset ~step ()
in
let e1 = make_env obs1 in
let e2 = make_env obs2 in
raises_match ~msg:"incompatible spaces raises"
(fun exn -> match exn with Invalid_argument _ -> true | _ -> false)
(fun () -> ignore (Vec_env.create [ e1; e2 ]))
(* Reset *)
let test_reset_obs_length () =
let venv = Vec_env.create (make_envs 3) in
let obs, _infos = Vec_env.reset venv () in
equal ~msg:"obs array length" int 3 (Array.length obs)
let test_reset_infos_length () =
let venv = Vec_env.create (make_envs 3) in
let _obs, infos = Vec_env.reset venv () in
equal ~msg:"infos array length" int 3 (Array.length infos)
(* Step *)
let test_step_result_lengths () =
let venv = Vec_env.create (make_envs 3) in
let _obs, _infos = Vec_env.reset venv () in
let action = Nx.create Nx.int32 [| 1 |] [| 1l |] in
let actions = Array.make 3 action in
let s = Vec_env.step venv actions in
equal ~msg:"observations length" int 3 (Array.length s.observations);
equal ~msg:"rewards length" int 3 (Array.length s.rewards);
equal ~msg:"terminated length" int 3 (Array.length s.terminated);
equal ~msg:"truncated length" int 3 (Array.length s.truncated);
equal ~msg:"infos length" int 3 (Array.length s.infos)
let test_wrong_action_count_raises () =
let venv = Vec_env.create (make_envs 3) in
let _obs, _infos = Vec_env.reset venv () in
let action = Nx.create Nx.int32 [| 1 |] [| 1l |] in
let actions = Array.make 2 action in
raises_invalid_arg "Vec_env.step: expected 3 actions, got 2" (fun () ->
ignore (Vec_env.step venv actions))
let test_autoreset_final_observation () =
let env = make_test_env ~max_steps:3 () in
let venv = Vec_env.create [ env ] in
let _obs, _infos = Vec_env.reset venv () in
let right = Nx.create Nx.int32 [| 1 |] [| 1l |] in
let actions = [| right |] in
(* Step until truncated at max_steps=3 *)
let s1 = Vec_env.step venv actions in
is_false ~msg:"not done after step 1" s1.truncated.(0);
let s2 = Vec_env.step venv actions in
is_false ~msg:"not done after step 2" s2.truncated.(0);
let s3 = Vec_env.step venv actions in
is_true ~msg:"truncated after step 3" s3.truncated.(0);
(* After autoreset, info should have final_observation *)
is_some ~msg:"final_observation key present"
(Info.find "final_observation" s3.infos.(0));
(* Observation should be from reset (5.0), not terminal *)
let arr : float array =
Nx.to_array (Nx.reshape [| 1 |] s3.observations.(0))
in
equal ~msg:"obs is from reset" (float 1e-6) 5.0 arr.(0)
(* Close *)
let test_close_all_envs () =
let envs = make_envs 3 in
let venv = Vec_env.create envs in
Vec_env.close venv;
List.iter (fun env -> is_true ~msg:"env is closed" (Env.closed env)) envs
let () =
Nx.Rng.run ~seed:42 @@ fun () ->
run "Fehu.Vec_env"
[
group "creation"
[
test "num_envs" test_create_num_envs;
test "spaces match first env" test_spaces_match_first_env;
test "empty list raises" test_empty_list_raises;
test "incompatible spaces raises" test_incompatible_spaces_raises;
];
group "reset"
[
test "observations length" test_reset_obs_length;
test "infos length" test_reset_infos_length;
];
group "step"
[
test "result array lengths" test_step_result_lengths;
test "wrong action count raises" test_wrong_action_count_raises;
test "autoreset with final_observation"
test_autoreset_final_observation;
];
group "close" [ test "closes all inner envs" test_close_all_envs ];
]
================================================
FILE: packages/hugin/README.md
================================================
# Hugin
Declarative plotting and visualization library for OCaml.
Hugin is part of the Raven ecosystem, providing a functional API to create
publication-quality charts and figures from Nx arrays. You build immutable
plot specifications with mark constructors, compose them with `|>` pipelines,
and render to PNG, SVG, PDF, or an interactive SDL window.
## Features
- Line, scatter, bar, histogram, error bar, fill-between, hline/vline, hspan/vspan
- Heatmap, colormapped image display (`imshow`), contour plots
- Multi-panel layouts with `Layout.grid`, `Layout.hstack`, `Layout.vstack`
- Perceptually uniform OKLCH colors with colorblind-friendly Okabe-Ito palette
- Predefined colormaps: viridis, plasma, inferno, magma, cividis, coolwarm
- Themes with context scaling (paper, notebook, talk, poster)
- Axis scales: linear, log, sqrt, asinh, symlog
- Cairo rendering (PNG, PDF), pure-OCaml SVG backend, interactive SDL display
- Format printer for Quill notebooks (`#install_printer`)
## Quick Start
```ocaml
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. 6.28 100 in
let y = Nx.sin x in
line ~x ~y () |> title "Sine wave" |> render_png "sine.png"
```
## Contributing
See the [Raven monorepo README](../../README.md) for guidelines.
## License
ISC License. See [LICENSE](../../LICENSE) for details.
================================================
FILE: packages/hugin/doc/01-getting-started.md
================================================
# Getting Started
This guide covers installation, your first plot, and the key concepts behind Hugin.
## Installation
Install system dependencies:
```bash
# macOS
brew install cairo sdl2
# Ubuntu/Debian
apt install libcairo2-dev libsdl2-dev
```
Then install hugin:
```bash
opam install hugin
```
Or build from source:
```bash
git clone https://github.com/raven-ml/raven
cd raven && dune build packages/hugin
```
Add to your `dune` file:
```dune
(executable
(name main)
(libraries hugin))
```
## Your First Plot
```ocaml
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. (2. *. Float.pi) 100 in
let y = Nx.sin x in
line ~x ~y () |> title "Sine wave" |> render_png "sine.png"
```
This creates a 1-D array of 100 points spanning one full period, computes the element-wise sine, builds a line specification, adds a title, and writes a PNG file.
## Key Concepts
### Marks
A mark constructor (`line`, `point`, `bar`, `hist`, `heatmap`, etc.) takes data arrays and optional visual properties and returns an immutable plot specification of type `t`. A mark is already a complete spec — you can render it directly:
```ocaml
line ~x ~y () |> render_png "plot.png"
```
### Decorations
Decoration functions add metadata to a spec. They are designed for the `|>` pipeline:
```ocaml
line ~x ~y ()
|> title "My Plot"
|> xlabel "Time (s)"
|> ylabel "Amplitude"
|> xlim 0. 10.
|> grid_lines true
```
Decorations include `title`, `xlabel`, `ylabel`, `xlim`, `ylim`, `xscale`, `yscale`, `grid_lines`, `legend`, `xticks`, `yticks`, `xinvert`, `yinvert`, `with_theme`, and tick formatting.
### Composition
`layers` overlays multiple marks on shared axes:
```ocaml
layers [
line ~x ~y:(Nx.sin x) ~label:"sin" ();
line ~x ~y:(Nx.cos x) ~label:"cos" ~line_style:`Dashed ();
]
|> legend |> render_png "overlay.png"
```
You can mix mark types freely. A `line` with `point` markers, a `bar` chart with `hline` reference lines — anything goes.
### Layout
`Layout.grid` arranges specs in rows and columns:
```ocaml
let p1 = line ~x ~y:(Nx.sin x) () |> title "sin" in
let p2 = line ~x ~y:(Nx.cos x) () |> title "cos" in
Layout.grid [ [ p1; p2 ] ] |> render_png "grid.png"
```
`Layout.hstack` and `Layout.vstack` are shorthands for single-row and single-column grids.
### Rendering
Four output modes:
| Function | Output |
|----------|--------|
| `render_png "file.png" t` | PNG image file |
| `render_svg "file.svg" t` | SVG document file |
| `render_pdf "file.pdf" t` | PDF document file |
| `show t` | Interactive SDL window (resize, Esc to close) |
All renderers accept optional `~width` and `~height` (default 1600×1200) and `~theme`.
`render_svg_to_string` and `render_to_buffer` return the output as a string instead of writing a file.
## Common Marks
### Line
```ocaml
line ~x ~y ()
line ~x ~y ~color:Color.blue ~line_style:`Dashed ~line_width:2.0 ()
line ~x ~y ~step:`Post () (* staircase plot *)
```
### Scatter
```ocaml
point ~x ~y ()
point ~x ~y ~color_by:values ~size:8. ~marker:Star ()
point ~x ~y ~size_by:weights () (* variable marker size *)
```
### Bar Chart
```ocaml
bar ~x:categories ~height:values ()
bar ~x:categories ~height:values ~width:0.5 ~color:Color.orange ()
```
### Histogram
```ocaml
hist ~x:data ()
hist ~x:data ~bins:(`Num 30) ~density:true ~color:Color.green ()
```
### Heatmap
```ocaml
(* data has shape [|rows; cols|] *)
heatmap ~data ()
heatmap ~data ~annotate:true ~cmap:Cmap.viridis ()
```
### Fill Between
```ocaml
fill_between ~x ~y1:(Nx.sub y err) ~y2:(Nx.add y err) ~alpha:0.3 ()
```
### Error Bars
```ocaml
errorbar ~x ~y ~yerr:(`Symmetric err) ()
errorbar ~x ~y ~yerr:(`Asymmetric (lo, hi)) ~xerr:(`Symmetric xerr) ()
```
## Next Steps
- [Marks and Styling](/docs/hugin/marks-and-styling/) — full mark catalog and visual properties
- [Layout and Decorations](/docs/hugin/layout-and-decorations/) — axes, scales, themes, multi-panel
- [Colors and Colormaps](/docs/hugin/colors-and-colormaps/) — OKLCH colors and colormap reference
================================================
FILE: packages/hugin/doc/02-marks-and-styling.md
================================================
# Marks and Styling
Every visualization in Hugin starts with one or more marks. A mark constructor takes data arrays and optional visual properties and returns an immutable plot specification.
## Mark Catalog
### Line Plots
`line ~x ~y ()` connects points `(x.(i), y.(i))` with straight segments.
| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `~x` | `Nx.float32_t` | required | X coordinates |
| `~y` | `Nx.float32_t` | required | Y coordinates |
| `~color` | `Color.t` | theme palette | Line color |
| `~line_width` | `float` | theme line width | Stroke width |
| `~line_style` | `` `Solid \| `Dashed \| `Dotted \| `Dash_dot `` | `` `Solid `` | Dash pattern |
| `~step` | `` `Pre \| `Post \| `Mid `` | none | Staircase interpolation |
| `~marker` | `marker` | none | Marker at each point |
| `~label` | `string` | none | Legend entry |
| `~alpha` | `float` | 1.0 | Opacity |
Step modes: `` `Post `` holds each value until the next x-point; `` `Pre `` steps at the current x-point; `` `Mid `` steps at the midpoint.
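As a quick sketch combining several of these options (the data is an arbitrary sine ramp, and the filename is a placeholder):

```ocaml
let x = Nx.linspace Nx.float32 0. 10. 11 in
let y = Nx.sin x in
(* A dash-dot staircase line with circle markers and a legend entry. *)
line ~x ~y ~line_style:`Dash_dot ~marker:Circle ~step:`Post ~label:"samples" ()
|> legend
|> render_png "styled-line.png"
```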
### Scatter Plots
`point ~x ~y ()` places individual markers at data coordinates.
| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `~x` | `Nx.float32_t` | required | X coordinates |
| `~y` | `Nx.float32_t` | required | Y coordinates |
| `~color` | `Color.t` | theme palette | Uniform color |
| `~color_by` | `Nx.float32_t` | none | Per-point values mapped through sequential colormap |
| `~size` | `float` | theme marker size | Uniform marker size |
| `~size_by` | `Nx.float32_t` | none | Per-point values for variable marker area |
| `~marker` | `marker` | `Circle` | Marker shape |
| `~label` | `string` | none | Legend entry |
| `~alpha` | `float` | 1.0 | Opacity |
When `~color_by` is set, a colorbar is displayed showing the value-to-color mapping.
### Bar Charts
`bar ~x ~height ()` draws vertical bars centered on `x` values.
| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `~x` | `Nx.float32_t` | required | Bar center positions |
| `~height` | `Nx.float32_t` | required | Bar heights |
| `~width` | `float` | 0.8 | Bar width |
| `~bottom` | `float` | 0.0 | Baseline y-value |
| `~color` | `Color.t` | theme palette | Fill color |
| `~label` | `string` | none | Legend entry |
| `~alpha` | `float` | 1.0 | Opacity |
### Histograms
`hist ~x ()` bins the values in `x` and draws a bar chart.
| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `~x` | `Nx.float32_t` | required | Data values |
| `~bins` | `` `Num of int \| `Edges of float array `` | `` `Num 10 `` | Number of bins or explicit edges |
| `~density` | `bool` | false | Normalize so total area equals 1.0 |
| `~color` | `Color.t` | theme palette | Fill color |
| `~label` | `string` | none | Legend entry |
### Reference Lines and Spans
`hline ~y ()` draws a horizontal line across the full plot width. `vline ~x ()` draws a vertical line across the full height. Both accept `~color`, `~line_width`, `~line_style`, `~label`, and `~alpha`.
`hspan ~y0 ~y1 ()` shades a horizontal band. `vspan ~x0 ~x1 ()` shades a vertical band. Both accept `~color`, `~alpha` (default 0.2), and `~label`.
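For example, a curve with a dashed threshold line and a shaded window, overlaid with `layers` (the specific values are illustrative):

```ocaml
let x = Nx.linspace Nx.float32 0. 10. 100 in
let y = Nx.sin x in
layers [
  line ~x ~y ();
  hline ~y:0.5 ~line_style:`Dashed ~label:"threshold" ();
  vspan ~x0:2. ~x1:4. ~alpha:0.2 ~label:"window" ();
]
|> legend
|> render_png "reference.png"
```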
### Fill Between
`fill_between ~x ~y1 ~y2 ()` fills the area between two curves. `~alpha` defaults to 0.3.
### Error Bars
`errorbar ~x ~y ~yerr ()` draws error bars at each point.
- `~yerr`: `` `Symmetric e `` draws y ± e, `` `Asymmetric (lo, hi) `` draws [y - lo, y + hi]
- `~xerr`: optional horizontal error bars with the same format
- `~cap_size`: cap width (defaults to half the theme marker size)
### Text
`text ~x ~y "label" ()` places a string at data coordinates `(x, y)`. Accepts `~color` and `~font_size`.
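A minimal sketch annotating a curve at a hand-chosen coordinate:

```ocaml
let x = Nx.linspace Nx.float32 0. 6.28 100 in
let y = Nx.sin x in
layers [
  line ~x ~y ();
  (* Label placed in data coordinates near the first peak. *)
  text ~x:1.57 ~y:1.0 "maximum" ();
]
|> render_png "annotated.png"
```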
### Image
`image data` displays an Nx uint8 array directly as pixels, without a colormap. `data` must have shape `[|h; w; 3|]` (RGB) or `[|h; w; 4|]` (RGBA).
### Colormapped Image
`imshow ~data ()` displays a 2-D float array through a colormap.
| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `~data` | `Nx.float32_t` | required | 2-D array of shape `[|rows; cols|]` |
| `~stretch` | `` `Linear \| `Log \| `Sqrt \| `Asinh \| `Power of float `` | `` `Linear `` | Transfer function before colormap lookup |
| `~cmap` | `Cmap.t` | theme sequential | Colormap |
| `~vmin` | `float` | data min | Lower bound of color range |
| `~vmax` | `float` | data max | Upper bound of color range |
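A sketch with a square-root stretch and an explicit color range (random data here just to have something to display):

```ocaml
let data = Nx.rand Nx.float32 [| 64; 64 |] in
imshow ~data ~stretch:`Sqrt ~cmap:Cmap.viridis ~vmin:0. ~vmax:1. ()
|> title "Sqrt stretch"
|> render_png "imshow.png"
```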
### Heatmap
`heatmap ~data ()` displays a 2-D array as a grid of colored cells. Row 0 appears at the top.
| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `~data` | `Nx.float32_t` | required | 2-D array of shape `[|rows; cols|]` |
| `~annotate` | `bool` | false | Show cell values |
| `~cmap` | `Cmap.t` | theme sequential | Colormap |
| `~vmin` | `float` | data min | Lower bound |
| `~vmax` | `float` | data max | Upper bound |
| `~fmt` | `float -> string` | `Printf.sprintf "%.2g"` | Cell value formatter (when annotate is true) |
### Contour
`contour ~data ~x0 ~x1 ~y0 ~y1 ()` draws iso-level contour lines through a 2-D grid.
| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `~data` | `Nx.float32_t` | required | 2-D array of shape `[|rows; cols|]` |
| `~x0`, `~x1`, `~y0`, `~y1` | `float` | required | Data-space rectangle |
| `~levels` | `` `Num of int \| `Values of float array `` | `` `Num 8 `` | Number of levels or explicit values |
| `~filled` | `bool` | false | Fill regions between levels |
| `~cmap` | `Cmap.t` | theme sequential | Per-level colormap |
| `~color` | `Color.t` | none | Single stroke color (unfilled contours) |
| `~line_width` | `float` | theme line width | Stroke width |
| `~label` | `string` | none | Legend entry |
| `~alpha` | `float` | 1.0 | Opacity |
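A sketch of filled contours over the unit square (the grid is random noise purely for illustration):

```ocaml
let data = Nx.rand Nx.float32 [| 40; 40 |] in
contour ~data ~x0:0. ~x1:1. ~y0:0. ~y1:1.
  ~levels:(`Num 6) ~filled:true ~cmap:Cmap.viridis ()
|> render_png "contour.png"
```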
## Marker Shapes
Five built-in shapes:
| Marker | Description |
|--------|-------------|
| `Circle` | Filled circle |
| `Square` | Filled square |
| `Triangle` | Filled triangle |
| `Plus` | Plus sign (+) |
| `Star` | Five-pointed star |
Use with `line ~marker:Triangle` or `point ~marker:Star`.
## Auto-Coloring
When you omit `~color`, marks are colored automatically from the theme's categorical palette. The first mark in a spec gets `palette.(0)`, the second gets `palette.(1)`, and so on. Explicitly setting `~color` takes precedence.
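Auto-coloring means a multi-series overlay needs no explicit colors; each mark picks the next palette entry:

```ocaml
let x = Nx.linspace Nx.float32 0. 6.28 100 in
layers [
  line ~x ~y:(Nx.sin x) ~label:"sin" ();  (* gets palette.(0) *)
  line ~x ~y:(Nx.cos x) ~label:"cos" ();  (* gets palette.(1) *)
]
|> legend
|> render_png "auto-colors.png"
```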
## Next Steps
- [Layout and Decorations](/docs/hugin/layout-and-decorations/) — axes, scales, themes, multi-panel layouts
- [Colors and Colormaps](/docs/hugin/colors-and-colormaps/) — OKLCH color space, palettes, and colormap reference
================================================
FILE: packages/hugin/doc/03-layout-and-decorations.md
================================================
# Layout and Decorations
Decorations add metadata and styling to a plot specification. Layout functions arrange multiple specs into multi-panel figures.
## Decorations
All decoration functions take a `t` and return a new `t`, designed for the `|>` pipeline:
```ocaml
line ~x ~y ()
|> title "Frequency Response"
|> xlabel "Frequency (Hz)"
|> ylabel "Magnitude (dB)"
|> xscale `Log
|> ylim (-60.) 0.
|> grid_lines true
```
### Titles and Labels
| Function | Description |
|----------|-------------|
| `title s t` | Plot title |
| `xlabel s t` | X-axis label |
| `ylabel s t` | Y-axis label |
### Axis Limits
| Function | Description |
|----------|-------------|
| `xlim lo hi t` | Fix x-axis range to [lo, hi] |
| `ylim lo hi t` | Fix y-axis range to [lo, hi] |
When omitted, axis ranges are computed automatically from the data with 5% padding.
### Axis Scales
| Function | Description |
|----------|-------------|
| `xscale s t` | Set x-axis scale |
| `yscale s t` | Set y-axis scale |
Available scales:
| Scale | When to use |
|-------|-------------|
| `` `Linear `` | Default. Uniform spacing. |
| `` `Log `` | Data spanning multiple orders of magnitude. All values must be positive. |
| `` `Sqrt `` | Moderate compression of large values. Handles zero. |
| `` `Asinh `` | Like log but handles zero and negative values. Transitions smoothly from linear near zero to logarithmic at large magnitudes. |
| `` `Symlog linthresh `` | Linear within [-linthresh, linthresh], logarithmic outside. Good for data with both small and large values centered around zero. |
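For instance, a symlog y-axis suits a cubic, which is small near zero and large at the ends (this sketch assumes an element-wise `Nx.mul`; the threshold value is arbitrary):

```ocaml
let x = Nx.linspace Nx.float32 (-5.) 5. 200 in
let y = Nx.mul x (Nx.mul x x) in  (* y = x^3 *)
line ~x ~y ()
|> yscale (`Symlog 1.0)
|> grid_lines true
|> render_png "symlog.png"
```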
### Axis Direction
| Function | Description |
|----------|-------------|
| `xinvert t` | X-axis values increase right-to-left |
| `yinvert t` | Y-axis values increase top-to-bottom |
Useful for conventions like right ascension in sky charts (`xinvert`) or magnitude axes in HR diagrams (`yinvert`).
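As a sketch, a schematic HR diagram: brighter stars have smaller magnitudes, so the y-axis runs top-to-bottom (the data arrays are random placeholders):

```ocaml
let color_index = Nx.rand Nx.float32 [| 200 |] in
let magnitude = Nx.rand Nx.float32 [| 200 |] in
point ~x:color_index ~y:magnitude ()
|> yinvert
|> xlabel "B - V"
|> ylabel "Magnitude"
|> render_png "hr.png"
```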
### Ticks
| Function | Description |
|----------|-------------|
| `xticks ticks t` | Explicit tick positions and labels as `(float * string) list` |
| `yticks ticks t` | Same for y-axis |
| `xtick_format fmt t` | Custom tick label formatter (preserves auto-generated positions) |
| `ytick_format fmt t` | Same for y-axis |
Example with explicit ticks:
```ocaml
line ~x ~y ()
|> xticks [ (0., "Jan"); (1., "Feb"); (2., "Mar"); (3., "Apr") ]
```
Example with custom formatting:
```ocaml
line ~x ~y ()
|> xtick_format (Printf.sprintf "%.1f%%")
```
### Grid and Legend
| Function | Description |
|----------|-------------|
| `grid_lines visible t` | Show or hide grid lines |
| `legend ?loc t` | Show legend at `loc` (default `Upper_right`) |
Legend locations: `Upper_right`, `Upper_left`, `Lower_right`, `Lower_left`, `Center`.
The legend is populated from marks that have a `~label`. Marks without labels are excluded.
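Placing the legend explicitly at one of the named locations:

```ocaml
let x = Nx.linspace Nx.float32 0. 6.28 100 in
layers [
  line ~x ~y:(Nx.sin x) ~label:"sin" ();
  line ~x ~y:(Nx.cos x) ~label:"cos" ();
]
|> legend ~loc:Lower_left
|> grid_lines true
|> render_png "legend.png"
```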
### Theme Override
`with_theme theme t` renders with `theme` instead of the default.
## Layout
### Grid
`Layout.grid rows` arranges specs in a grid where each inner list is a row:
```ocaml
let p1 = line ~x ~y:(Nx.sin x) () |> title "sin" in
let p2 = line ~x ~y:(Nx.cos x) () |> title "cos" in
let p3 = line ~x ~y:(Nx.tan x) () |> title "tan" |> ylim (-5.) 5. in
let p4 = hist ~x:(Nx.rand Nx.float32 [|500|]) () |> title "random" in
Layout.grid [ [ p1; p2 ]; [ p3; p4 ] ] |> render_png "grid.png"
```
`~gap` controls spacing between panels as a fraction of total size (default 0.05).
### Stack
| Function | Description |
|----------|-------------|
| `Layout.hstack specs` | Single row of panels |
| `Layout.vstack specs` | Single column of panels |
Both accept `~gap`.
## Themes
A theme controls every non-data visual element: background, typography, axes, grid, spacing, and data palettes.
### Predefined Themes
| Theme | Description |
|-------|-------------|
| `Theme.default` | Light background, subtle grid, Okabe-Ito palette |
| `Theme.dark` | Dark background |
| `Theme.minimal` | No grid, thin axes |
```ocaml
line ~x ~y () |> with_theme Theme.dark |> render_png "dark.png"
```
### Context Scaling
Context functions scale all visual elements (fonts, line widths, spacing) for different output media:
| Function | Scale factor | Use case |
|----------|-------------|----------|
| `Theme.paper` | 1.0 | Journal figures |
| `Theme.notebook` | 1.3 | Quill notebooks |
| `Theme.talk` | 1.6 | Slides and presentations |
| `Theme.poster` | 2.0 | Conference posters |
```ocaml
let theme = Theme.dark |> Theme.talk in
line ~x ~y () |> with_theme theme |> render_png "slide.png"
```
### Theme Fields
The `Theme.t` record is fully public. You can create custom themes by modifying fields:
| Field | Type | Description |
|-------|------|-------------|
| `background` | `Color.t` | Background color |
| `palette` | `Color.t array` | Categorical color palette |
| `sequential` | `Cmap.t` | Default sequential colormap |
| `diverging` | `Cmap.t` | Default diverging colormap |
| `font_title` | `Theme.font` | Title font |
| `font_label` | `Theme.font` | Axis label font |
| `font_tick` | `Theme.font` | Tick label font |
| `axis` | `Theme.line` | Axis line style |
| `grid` | `Theme.line option` | Grid line style (None to hide) |
| `tick_length` | `float` | Tick mark length |
| `padding` | `float` | Plot area padding |
| `title_gap` | `float` | Gap below title |
| `label_gap` | `float` | Gap between label and axis |
| `scale_factor` | `float` | Global size multiplier |
| `line_width` | `float` | Default line width |
| `marker_size` | `float` | Default marker size |
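Because the record is public, a custom theme is ordinary OCaml record update. A sketch using the field names from the table above (the `Theme.line` payload for `grid` is not documented here, so `None` is used to hide the grid rather than assuming its shape):

```ocaml
(* Start from the default theme and override a few fields. *)
let my_theme = { Theme.default with
background = Color.hex "#FAFAF8";
grid = None; (* hide grid lines entirely *)
line_width = 2.0;
} in
line ~x ~y () |> with_theme my_theme |> render_png "custom.png"
```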
## Next Steps
- [Colors and Colormaps](/docs/hugin/colors-and-colormaps/) — OKLCH color space, operations, and colormap reference
- [Matplotlib Comparison](/docs/hugin/matplotlib-comparison/) — side-by-side with Python
================================================
FILE: packages/hugin/doc/04-colors-and-colormaps.md
================================================
# Colors and Colormaps
Hugin uses the OKLCH color space for perceptually uniform color operations and ships with colorblind-friendly palettes and scientific colormaps.
## Colors
### OKLCH Color Space
Colors are represented internally in [OKLCH](https://bottosson.github.io/posts/oklab/), a perceptually uniform color space. Operations like `lighten`, `darken`, and `mix` produce visually consistent results: equal numerical steps yield equal perceived differences.
OKLCH components:
| Component | Range | Description |
|-----------|-------|-------------|
| Lightness (L) | [0, 1] | Black to white |
| Chroma (C) | [0, ~0.4] | Gray to saturated |
| Hue (H) | [0, 360) | Color wheel angle |
| Alpha (A) | [0, 1] | Transparency |
### Constructors
```ocaml
(* From OKLCH components *)
Color.oklch ~l:0.7 ~c:0.15 ~h:145. ()
Color.oklcha ~l:0.7 ~c:0.15 ~h:145. ~a:0.5 ()
(* From sRGB [0, 1] *)
Color.rgb ~r:0.2 ~g:0.6 ~b:0.8 ()
Color.rgba ~r:0.2 ~g:0.6 ~b:0.8 ~a:0.5 ()
(* From hex string *)
Color.hex "#3399CC"
Color.hex "#3399CCAA" (* with alpha *)
```
All constructors convert to OKLCH on creation. The reverse conversion (`to_rgba`) is called at render time.
### Accessors
```ocaml
Color.lightness c (* OKLCH lightness *)
Color.chroma c (* OKLCH chroma *)
Color.hue c (* OKLCH hue in degrees *)
Color.alpha c (* alpha channel *)
Color.to_rgba c (* sRGB (r, g, b, a) tuple, clamped to gamut *)
```
### Operations
```ocaml
Color.lighten 0.1 c (* increase lightness by 0.1, clamped to [0, 1] *)
Color.darken 0.1 c (* decrease lightness by 0.1, clamped to [0, 1] *)
Color.with_alpha 0.5 c (* set alpha *)
Color.mix 0.5 a b (* blend a and b: 0.0 = a, 1.0 = b *)
```
`mix` interpolates all OKLCH components. Hue follows the shortest arc on the color wheel.
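Because mixing is perceptually uniform, evenly spaced mix ratios give an evenly spaced tint ramp. A sketch using the operations above:

```ocaml
(* Five tints from blue to white in equal perceptual steps. *)
let tints =
Array.init 5 (fun i ->
Color.mix (float_of_int i /. 4.) Color.blue Color.white)
```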
### Named Colors
The default named colors follow the [Okabe-Ito palette](https://jfly.uni-koeln.de/color/), designed to be distinguishable under all forms of color-vision deficiency:
| Color | Value |
|-------|-------|
| `Color.orange` | Okabe-Ito orange |
| `Color.sky_blue` | Okabe-Ito sky blue |
| `Color.green` | Okabe-Ito bluish green |
| `Color.yellow` | Okabe-Ito yellow |
| `Color.blue` | Okabe-Ito blue |
| `Color.vermillion` | Okabe-Ito vermillion |
| `Color.purple` | Okabe-Ito reddish purple |
| `Color.black` | Black |
| `Color.white` | White |
| `Color.gray` | Neutral gray |
### Formatting
`Color.pp` formats as `oklch(L C H / A)` for debugging.
## Colormaps
A colormap is a continuous mapping from [0, 1] to `Color.t`, stored internally as a 256-entry lookup table with OKLCH interpolation.
### Evaluation
```ocaml
let c = Cmap.eval Cmap.viridis 0.5 (* color at midpoint *)
```
Values are clamped to [0, 1].
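Since evaluation is just a function of [0, 1], a colormap can be sampled into a discrete palette, e.g. to color a fixed number of series. A sketch (`sample_palette` is a hypothetical helper, not part of the API):

```ocaml
(* Sample a colormap at n evenly spaced points. *)
let sample_palette cmap n =
Array.init n (fun i ->
Cmap.eval cmap (float_of_int i /. float_of_int (max 1 (n - 1))))
let five_colors = sample_palette Cmap.viridis 5
```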
### Predefined Colormaps
Perceptually uniform sequential colormaps from the [viridis family](https://bids.github.io/colormap/):
| Colormap | Description |
|----------|-------------|
| `Cmap.viridis` | Purple-teal-yellow (default) |
| `Cmap.plasma` | Purple-orange-yellow |
| `Cmap.inferno` | Black-purple-orange-yellow |
| `Cmap.magma` | Black-purple-pink-yellow |
| `Cmap.cividis` | Optimized for color-vision deficiency |
Other colormaps:
| Colormap | Description |
|----------|-------------|
| `Cmap.coolwarm` | Blue-white-red diverging |
| `Cmap.gray` | Black to white |
| `Cmap.gray_r` | White to black (standard for astronomy) |
| `Cmap.hot` | Black-red-yellow-white |
### Custom Colormaps
`Cmap.of_colors` creates a colormap by interpolating linearly through an array of color stops in OKLCH space:
```ocaml
let my_cmap = Cmap.of_colors [|
Color.hex "#000080";
Color.hex "#FFFFFF";
Color.hex "#800000";
|]
```
Stops are evenly spaced from 0 to 1. Requires at least 2 colors.
## Using Colors with Marks
### Uniform Color
Set `~color` on any mark:
```ocaml
line ~x ~y ~color:Color.vermillion ()
bar ~x ~height ~color:(Color.hex "#336699") ()
```
### Data-Driven Color
`point` supports `~color_by` to map per-point values through the theme's sequential colormap:
```ocaml
point ~x ~y ~color_by:temperature ~marker:Circle ()
```
A colorbar is displayed automatically.
### Colormaps on 2-D Data
`heatmap`, `imshow`, and `contour` accept `~cmap` to override the default:
```ocaml
heatmap ~data ~cmap:Cmap.coolwarm ()
imshow ~data ~cmap:Cmap.inferno ~stretch:`Log ()
contour ~data ~x0 ~x1 ~y0 ~y1 ~filled:true ~cmap:Cmap.plasma ()
```
## Next Steps
- [Matplotlib Comparison](/docs/hugin/matplotlib-comparison/) — side-by-side with Python
- [Marks and Styling](/docs/hugin/marks-and-styling/) — full mark catalog
================================================
FILE: packages/hugin/doc/05-matplotlib-comparison.md
================================================
# Hugin vs Matplotlib
Side-by-side examples comparing Hugin (OCaml) with Matplotlib (Python). Hugin uses a declarative, pipeline-oriented API while Matplotlib uses an imperative, object-oriented approach.
## Key Differences
| | Hugin | Matplotlib |
|---|---|---|
| Style | Declarative, immutable specs | Imperative, mutable state |
| Composition | `\|>` pipeline | Method calls on axes |
| State | No global state | `plt` global state |
| Colors | OKLCH color space | sRGB strings |
| Output | `render_png`, `render_svg`, `show` | `plt.savefig`, `plt.show` |
## Line Plot
**Hugin:**
```ocaml
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. (2. *. Float.pi) 100 in
layers [
line ~x ~y:(Nx.sin x) ~label:"sin(x)" ~color:Color.blue ();
line ~x ~y:(Nx.cos x) ~label:"cos(x)" ~color:Color.vermillion
~line_style:`Dashed ();
]
|> title "Trigonometric Functions"
|> xlabel "Angle (radians)"
|> ylabel "Value"
|> ylim (-1.2) 1.2
|> grid_lines true
|> legend
|> render_png "trig.png"
```
**Matplotlib:**
```python
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 2 * np.pi, 100)
plt.figure()
plt.plot(x, np.sin(x), label="sin(x)", color="blue")
plt.plot(x, np.cos(x), label="cos(x)", color="red", linestyle="--")
plt.title("Trigonometric Functions")
plt.xlabel("Angle (radians)")
plt.ylabel("Value")
plt.ylim(-1.2, 1.2)
plt.grid(True)
plt.legend()
plt.savefig("trig.png")
```
## Scatter Plot
**Hugin:**
```ocaml
open Hugin
let () =
let x = Nx.rand Nx.float32 [| 200 |] in
let y = Nx.rand Nx.float32 [| 200 |] in
let c = Nx.add x y in
point ~x ~y ~color_by:c ~size:8. ~marker:Circle ()
|> title "Random Scatter"
|> render_png "scatter.png"
```
**Matplotlib:**
```python
import numpy as np
import matplotlib.pyplot as plt
x = np.random.rand(200)
y = np.random.rand(200)
c = x + y
plt.figure()
plt.scatter(x, y, c=c, s=64, marker="o")
plt.title("Random Scatter")
plt.colorbar()
plt.savefig("scatter.png")
```
## Bar Chart
**Hugin:**
```ocaml
open Hugin
let () =
let x = Nx.create Nx.float32 [| 4 |] [| 1.; 2.; 3.; 4. |] in
let h = Nx.create Nx.float32 [| 4 |] [| 3.; 7.; 2.; 5. |] in
bar ~x ~height:h ~color:Color.orange ()
|> title "Quarterly Revenue"
|> xticks [ (1., "Q1"); (2., "Q2"); (3., "Q3"); (4., "Q4") ]
|> ylabel "Revenue ($M)"
|> render_png "bar.png"
```
**Matplotlib:**
```python
import matplotlib.pyplot as plt
x = [1, 2, 3, 4]
h = [3, 7, 2, 5]
plt.figure()
plt.bar(x, h, color="orange")
plt.title("Quarterly Revenue")
plt.xticks(x, ["Q1", "Q2", "Q3", "Q4"])
plt.ylabel("Revenue ($M)")
plt.savefig("bar.png")
```
## Histogram
**Hugin:**
```ocaml
open Hugin
let () =
let data = Nx.randn Nx.float32 [| 1000 |] in
hist ~x:data ~bins:(`Num 30) ~density:true ~color:Color.sky_blue ()
|> title "Normal Distribution"
|> xlabel "Value"
|> ylabel "Density"
|> render_png "hist.png"
```
**Matplotlib:**
```python
import numpy as np
import matplotlib.pyplot as plt
data = np.random.randn(1000)
plt.figure()
plt.hist(data, bins=30, density=True, color="skyblue")
plt.title("Normal Distribution")
plt.xlabel("Value")
plt.ylabel("Density")
plt.savefig("hist.png")
```
## Multi-Panel Layout
**Hugin:**
```ocaml
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. (2. *. Float.pi) 100 in
let p1 = line ~x ~y:(Nx.sin x) () |> title "sin" in
let p2 = line ~x ~y:(Nx.cos x) () |> title "cos" in
let p3 = line ~x ~y:(Nx.tan x) () |> title "tan" |> ylim (-5.) 5. in
let p4 = hist ~x:(Nx.rand Nx.float32 [| 500 |]) () |> title "random" in
Layout.grid [ [ p1; p2 ]; [ p3; p4 ] ]
|> render_png "grid.png"
```
**Matplotlib:**
```python
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 2 * np.pi, 100)
fig, axes = plt.subplots(2, 2)
axes[0, 0].plot(x, np.sin(x)); axes[0, 0].set_title("sin")
axes[0, 1].plot(x, np.cos(x)); axes[0, 1].set_title("cos")
axes[1, 0].plot(x, np.tan(x)); axes[1, 0].set_title("tan")
axes[1, 0].set_ylim(-5, 5)
axes[1, 1].hist(np.random.rand(500)); axes[1, 1].set_title("random")
plt.tight_layout()
plt.savefig("grid.png")
```
## Heatmap
**Hugin:**
```ocaml
open Hugin
let () =
let data = Nx.init Nx.float32 [| 8; 10 |] (fun idx ->
let i = Float.of_int idx.(0) and j = Float.of_int idx.(1) in
Float.sin (i *. 0.5) *. Float.cos (j *. 0.4))
in
heatmap ~data ~annotate:true ~cmap:Cmap.viridis ()
|> title "Heatmap"
|> render_png "heatmap.png"
```
**Matplotlib:**
```python
import numpy as np
import matplotlib.pyplot as plt
data = np.fromfunction(
lambda i, j: np.sin(i * 0.5) * np.cos(j * 0.4), (8, 10)
)
fig, ax = plt.subplots()
im = ax.imshow(data, cmap="viridis")
for i in range(8):
for j in range(10):
ax.text(j, i, f"{data[i, j]:.2g}", ha="center", va="center")
ax.set_title("Heatmap")
plt.colorbar(im)
plt.savefig("heatmap.png")
```
## Styling
**Hugin:**
```ocaml
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. (2. *. Float.pi) 50 in
line ~x ~y:(Nx.sin x)
~color:Color.vermillion
~line_style:`Dashed
~line_width:2.5
~marker:Triangle
~alpha:0.7
()
|> render_png "styled.png"
```
**Matplotlib:**
```python
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 2 * np.pi, 50)
plt.figure()
plt.plot(x, np.sin(x), color="red", linestyle="--",
linewidth=2.5, marker="^", alpha=0.7)
plt.savefig("styled.png")
```
## Themes
Hugin provides built-in themes with context scaling. Matplotlib uses style sheets.
**Hugin:**
```ocaml
(* Dark theme scaled for a presentation *)
let theme = Theme.dark |> Theme.talk in
line ~x ~y () |> with_theme theme |> render_png "slide.png"
```
**Matplotlib:**
```python
plt.style.use("dark_background")
plt.rcParams.update({"font.size": 14})
plt.plot(x, y)
plt.savefig("slide.png")
```
## Save and Export
**Hugin:**
```ocaml
let spec = line ~x ~y () |> title "My Plot" in
spec |> render_png "plot.png";
spec |> render_svg "plot.svg";
spec |> render_pdf "plot.pdf";
spec |> show (* interactive SDL window *)
```
**Matplotlib:**
```python
plt.plot(x, y)
plt.title("My Plot")
plt.savefig("plot.png")
plt.savefig("plot.svg")
plt.savefig("plot.pdf")
plt.show()
```
In Hugin, the spec is an immutable value: you can render the same spec to multiple formats without rebuilding it. In Matplotlib, the figure is mutable global state; `savefig` captures whatever has been drawn so far, and `show` typically discards the figure once the window closes.
## Interactive Display
**Hugin:**
```ocaml
show ~width:1600. ~height:1200. spec
```
The SDL window is resizable. The plot re-renders at the new dimensions. Press Escape or Q to close.
**Matplotlib:**
```python
plt.show()
```
================================================
FILE: packages/hugin/doc/index.md
================================================
# Hugin
Hugin creates publication-quality plots from Nx arrays using a declarative, pipeline-oriented API.
## What Hugin Does
Hugin turns immutable plot specifications into rendered output. You build a specification from mark constructors (`line`, `point`, `bar`, `hist`), decorate it with `title`, `xlabel`, and axis controls via the `|>` pipeline, and render with `render_png`, `render_svg`, or `show`.
Internally, rendering proceeds in three stages: the user-facing spec is compiled to a prepared tree (histograms binned, data bounds computed, marks auto-colored), then resolved to device-pixel coordinates, then drawn by a backend. Data compilation happens once; layout resolution is cheap and repeatable at different sizes.
## System Requirements
Hugin needs Cairo and SDL2 for rendering:
```bash
# macOS
brew install cairo sdl2
# Ubuntu/Debian
apt install libcairo2-dev libsdl2-dev
```
## Quick Start
```ocaml
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. (2. *. Float.pi) 100 in
let y = Nx.sin x in
line ~x ~y () |> title "Sine wave" |> render_png "sine.png"
```
Two marks on shared axes:
```ocaml
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. (2. *. Float.pi) 100 in
layers [
line ~x ~y:(Nx.sin x) ~label:"sin" ();
line ~x ~y:(Nx.cos x) ~label:"cos" ~line_style:`Dashed ();
]
|> legend |> render_png "trig.png"
```
## Next Steps
- [Getting Started](/docs/hugin/getting-started/) — installation, first plot, key concepts
- [Marks and Styling](/docs/hugin/marks-and-styling/) — mark catalog, visual properties
- [Layout and Decorations](/docs/hugin/layout-and-decorations/) — axes, scales, themes, multi-panel
- [Colors and Colormaps](/docs/hugin/colors-and-colormaps/) — OKLCH colors, palettes, colormaps
- [Matplotlib Comparison](/docs/hugin/matplotlib-comparison/) — side-by-side with Python
================================================
FILE: packages/hugin/examples/01-line-plot/README.md
================================================
# Line Plot
Create data with Nx, build a line plot, and render to PNG in three lines.

================================================
FILE: packages/hugin/examples/01-line-plot/dune
================================================
(executable
(name main)
(libraries hugin nx))
(rule
(targets line_plot.png)
(deps main.exe)
(action
(run ./main.exe))
(mode
(promote (until-clean))))
================================================
FILE: packages/hugin/examples/01-line-plot/main.ml
================================================
(* Your first plot.
The simplest possible visualization: create data with Nx, build a line plot,
and render to PNG. *)
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. (2. *. Float.pi) 100 in
let y = Nx.sin x in
line ~x ~y () |> render_png "line_plot.png"
================================================
FILE: packages/hugin/examples/02-styling/README.md
================================================
# Styling
Mark constructors accept optional visual properties: `~color`, `~line_style`, `~line_width`, `~marker`, and `~alpha`.

================================================
FILE: packages/hugin/examples/02-styling/dune
================================================
(executable
(name main)
(libraries hugin nx))
(rule
(targets styling.png)
(deps main.exe)
(action
(run ./main.exe))
(mode
(promote (until-clean))))
================================================
FILE: packages/hugin/examples/02-styling/main.ml
================================================
(* Styling.
Every mark constructor accepts optional visual properties as labeled
arguments. This example shows how to set color, line style, width, and marker
shape. *)
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. (2. *. Float.pi) 50 in
let y = Nx.sin x in
line ~x ~y ~color:Color.vermillion ~line_style:`Dashed ~line_width:2.5
~marker:Triangle ~alpha:0.7 ()
|> render_png "styling.png"
================================================
FILE: packages/hugin/examples/03-scatter/README.md
================================================
# Scatter Plot
The `point` mark places markers at data coordinates. Pass `~color_by` to map a third variable through the theme's sequential colormap.

================================================
FILE: packages/hugin/examples/03-scatter/dune
================================================
(executable
(name main)
(libraries hugin nx))
(rule
(targets scatter.png)
(deps main.exe)
(action
(run ./main.exe))
(mode
(promote (until-clean))))
================================================
FILE: packages/hugin/examples/03-scatter/main.ml
================================================
(* Scatter plots.
Point marks place individual markers at data coordinates. Use color_by to map
a third variable through the theme's sequential colormap. *)
open Hugin
let () =
let x = Nx.rand Nx.float32 [| 200 |] in
let y = Nx.rand Nx.float32 [| 200 |] in
let c = Nx.add x y in
point ~x ~y ~color_by:c ~size:8. ~marker:Circle ()
|> title "Random Scatter" |> render_png "scatter.png"
================================================
FILE: packages/hugin/examples/04-bar-chart/README.md
================================================
# Bar Chart
Bar marks draw vertical bars centered at x positions. Use `~xticks` to label the x-axis with category names.

================================================
FILE: packages/hugin/examples/04-bar-chart/dune
================================================
(executable
(name main)
(libraries hugin nx))
(rule
(targets bar_chart.png)
(deps main.exe)
(action
(run ./main.exe))
(mode
(promote (until-clean))))
================================================
FILE: packages/hugin/examples/04-bar-chart/main.ml
================================================
(* Bar charts.
Bar marks draw vertical bars centered at x positions. Height is measured from
bottom (default 0). *)
open Hugin
let () =
let x = Nx.create Nx.float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
let h = Nx.create Nx.float32 [| 5 |] [| 4.2; 7.1; 3.8; 9.0; 5.5 |] in
bar ~x ~height:h ~color:Color.sky_blue ()
|> title "Quarterly Revenue" |> xlabel "Quarter" |> ylabel "Revenue ($M)"
|> xticks [ (1., "Q1"); (2., "Q2"); (3., "Q3"); (4., "Q4"); (5., "Q5") ]
|> render_png "bar_chart.png"
================================================
FILE: packages/hugin/examples/05-histogram/README.md
================================================
# Histogram
Histogram marks bin continuous data into evenly spaced intervals. Use `~density:true` to normalize the total area to 1.

================================================
FILE: packages/hugin/examples/05-histogram/dune
================================================
(executable
(name main)
(libraries hugin nx))
(rule
(targets histogram.png)
(deps main.exe)
(action
(run ./main.exe))
(mode
(promote (until-clean))))
================================================
FILE: packages/hugin/examples/05-histogram/main.ml
================================================
(* Histograms.
Histogram marks bin continuous data. Use ~density:true to normalize so the
total area equals 1. *)
open Hugin
let () =
let samples = Nx.rand Nx.float32 [| 1000 |] in
hist ~x:samples ~bins:(`Num 25) ~density:true ~color:Color.green ()
|> title "Distribution" |> xlabel "Value" |> render_png "histogram.png"
================================================
FILE: packages/hugin/examples/06-layers/README.md
================================================
# Layers
`layers` overlays multiple marks on shared axes. Any mark with a `~label` automatically appears in the legend.

================================================
FILE: packages/hugin/examples/06-layers/dune
================================================
(executable
(name main)
(libraries hugin nx))
(rule
(targets layers.png)
(deps main.exe)
(action
(run ./main.exe))
(mode
(promote (until-clean))))
================================================
FILE: packages/hugin/examples/06-layers/main.ml
================================================
(* Layers and legends.
Use layers to overlay different mark types on shared axes. Any mark with a
~label automatically appears in the legend. *)
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. 10. 100 in
let y = Nx.sin x in
layers
[
fill_between ~x ~y1:(Nx.sub_s y 0.3) ~y2:(Nx.add_s y 0.3) ~label:"± 0.3"
();
line ~x ~y ~label:"sin(x)" ();
hline ~y:0. ~line_style:`Dashed ~color:Color.gray ~label:"baseline" ();
]
|> title "Sine with Confidence Band"
|> xlabel "x" |> ylabel "y" |> legend |> render_png "layers.png"
================================================
FILE: packages/hugin/examples/07-decorations/README.md
================================================
# Decorations
Decoration functions (`xscale`, `xlim`, `ylim`, `grid_lines`, `xtick_format`) control axis behavior and compose with `|>`.

================================================
FILE: packages/hugin/examples/07-decorations/dune
================================================
(executable
(name main)
(libraries hugin nx))
(rule
(targets decorations.png)
(deps main.exe)
(action
(run ./main.exe))
(mode
(promote (until-clean))))
================================================
FILE: packages/hugin/examples/07-decorations/main.ml
================================================
(* Axis decorations.
Decorations control axes limits, scales, and grid visibility. They compose
naturally with the |> pipeline. *)
open Hugin
let () =
let x = Nx.linspace Nx.float32 1. 1000. 100 in
let y = Nx.log x in
line ~x ~y () |> title "Logarithmic Scale" |> xlabel "x" |> ylabel "ln(x)"
|> xscale `Log |> xlim 1. 1000. |> ylim 0. 8.
|> xtick_format (Printf.sprintf "%.0f")
|> grid_lines true
|> render_png "decorations.png"
================================================
FILE: packages/hugin/examples/08-grid-layout/README.md
================================================
# Grid Layout
`grid` arranges independent plots in rows and columns. Each cell has its own axes and decorations.

================================================
FILE: packages/hugin/examples/08-grid-layout/dune
================================================
(executable
(name main)
(libraries hugin nx))
(rule
(targets grid_layout.png)
(deps main.exe)
(action
(run ./main.exe))
(mode
(promote (until-clean))))
================================================
FILE: packages/hugin/examples/08-grid-layout/main.ml
================================================
(* Grid layout.
Arrange independent plots in a grid. Each cell is a standalone specification
with its own axes and decorations. *)
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. (2. *. Float.pi) 100 in
let p1 = line ~x ~y:(Nx.sin x) () |> title "sin" in
let p2 = line ~x ~y:(Nx.cos x) () |> title "cos" in
let p3 = line ~x ~y:(Nx.tan (Nx.mul_s x 0.3)) () |> title "tan(0.3x)" in
let p4 =
point ~x ~y:(Nx.sin x) ~color:Color.vermillion ~marker:Plus ()
|> title "sin (scatter)"
in
grid [ [ p1; p2 ]; [ p3; p4 ] ] |> render_png "grid_layout.png"
================================================
FILE: packages/hugin/examples/09-themes/README.md
================================================
# Themes
Themes control visual appearance: background, fonts, axes, grid, and data colors. Context functions like `Theme.talk` scale everything up for presentations.

================================================
FILE: packages/hugin/examples/09-themes/dune
================================================
(executable
(name main)
(libraries hugin nx))
(rule
(targets themes.png)
(deps main.exe)
(action
(run ./main.exe))
(mode
(promote (until-clean))))
================================================
FILE: packages/hugin/examples/09-themes/main.ml
================================================
(* Themes and context scaling.
Themes control the entire visual appearance: colors, fonts, line widths.
Context functions like Theme.talk scale everything up for presentations. *)
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. 10. 80 in
let base =
layers
[
line ~x ~y:(Nx.sin x) ~label:"sin" ();
line ~x ~y:(Nx.cos x) ~label:"cos" ();
]
|> legend
in
grid
[
[
base |> with_theme Theme.default |> title "Default";
base |> with_theme Theme.dark |> title "Dark";
];
[
base |> with_theme Theme.minimal |> title "Minimal";
base |> with_theme (Theme.talk Theme.default) |> title "Talk";
];
]
|> render_png "themes.png"
================================================
FILE: packages/hugin/examples/10-showcase/README.md
================================================
# Showcase
Combines multiple mark types, layouts, and output formats in a single visualization. Renders to both PNG and SVG.

================================================
FILE: packages/hugin/examples/10-showcase/dune
================================================
(executable
(name main)
(libraries hugin nx))
(rule
(targets showcase.png showcase.svg)
(deps main.exe)
(action
(run ./main.exe))
(mode
(promote (until-clean))))
================================================
FILE: packages/hugin/examples/10-showcase/main.ml
================================================
(* Full showcase.
Demonstrates multiple mark types, layouts, themes, and output formats in a
single example. *)
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. 10. 100 in
let p1 =
layers
[
line ~x ~y:(Nx.sin x) ~label:"sin" ~color:Color.blue ();
point
~x:(Nx.mul_s (Nx.rand Nx.float32 [| 30 |]) 10.)
~y:(Nx.sub_s (Nx.mul_s (Nx.rand Nx.float32 [| 30 |]) 2.) 1.)
~color:Color.vermillion ~marker:Star ~label:"noise" ();
]
|> title "Lines & Scatter" |> legend
in
let p2 =
let xb = Nx.create Nx.float32 [| 4 |] [| 1.; 2.; 3.; 4. |] in
let h = Nx.create Nx.float32 [| 4 |] [| 3.; 7.; 2.; 5. |] in
bar ~x:xb ~height:h ~color:Color.orange () |> title "Bar Chart"
in
let p3 =
hist ~x:(Nx.rand Nx.float32 [| 500 |]) ~bins:(`Num 20) ~color:Color.green ()
|> title "Histogram"
in
let p4 =
let xs = Nx.rand Nx.float32 [| 50 |] in
let ys = Nx.rand Nx.float32 [| 50 |] in
let cb = Nx.mul_s xs 100. in
let sb = Nx.mul_s ys 40. in
point ~x:xs ~y:ys ~color_by:cb ~size_by:sb ~marker:Circle ()
|> title "color_by + size_by" |> xlabel "x" |> ylabel "y"
in
let p5 =
let xl = Nx.linspace Nx.float32 1. 100. 50 in
line ~x:xl ~y:(Nx.mul xl xl) ~color:Color.purple ()
|> title "Quadratic (log y)" |> yscale `Log
in
let p6 =
let data =
Nx.init Nx.float32 [| 8; 10 |] (fun idx ->
let i = Float.of_int idx.(0) and j = Float.of_int idx.(1) in
Float.sin (i *. 0.5) *. Float.cos (j *. 0.4))
in
heatmap ~data ~cmap:Cmap.viridis () |> title "Heatmap"
in
let spec = grid [ [ p1; p2 ]; [ p3; p4 ]; [ p5; p6 ] ] in
spec |> render_png "showcase.png";
spec |> render_svg "showcase.svg"
================================================
FILE: packages/hugin/examples/11-errorbar/README.md
================================================
# Error Bars
`errorbar` shows measurement uncertainty. Use `` `Symmetric `` for equal +/- errors or `` `Asymmetric `` for independent lower and upper bounds.

================================================
FILE: packages/hugin/examples/11-errorbar/dune
================================================
(executable
(name main)
(libraries hugin nx))
(rule
(targets errorbar.png)
(deps main.exe)
(action
(run ./main.exe))
(mode
(promote (until-clean))))
================================================
FILE: packages/hugin/examples/11-errorbar/main.ml
================================================
(* Error bars.
Errorbar marks show measurement uncertainty. Use `Symmetric for equal +/-
errors or `Asymmetric for independent lower and upper bounds. *)
open Hugin
let () =
let x = Nx.create Nx.float32 [| 6 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
let y = Nx.create Nx.float32 [| 6 |] [| 2.1; 3.8; 3.2; 5.1; 4.5; 6.3 |] in
let err = Nx.create Nx.float32 [| 6 |] [| 0.3; 0.5; 0.2; 0.6; 0.4; 0.3 |] in
errorbar ~x ~y ~yerr:(`Symmetric err) ~cap_size:6. ~color:Color.blue ()
|> title "Measurements" |> xlabel "Trial" |> ylabel "Value"
|> render_png "errorbar.png"
================================================
FILE: packages/hugin/examples/README.md
================================================
# Hugin Examples
Learn Hugin through progressively complex examples. Start with `01-line-plot`
and work through the numbered examples in order.
## Examples
| Example | Concept | Key Functions |
|---------|---------|---------------|
| [`01-line-plot`](./01-line-plot/) | Your first plot | `line`, `render_png` |
| [`02-styling`](./02-styling/) | Colors, line styles, markers | `~color`, `~line_style`, `~marker`, `~alpha` |
| [`03-scatter`](./03-scatter/) | Scatter plots and color mapping | `point`, `~color_by` |
| [`04-bar-chart`](./04-bar-chart/) | Bar charts with categorical axes | `bar`, `xlabel`, `ylabel`, `xticks` |
| [`05-histogram`](./05-histogram/) | Histograms and density | `hist`, `~bins`, `~density` |
| [`06-layers`](./06-layers/) | Overlaying marks and legends | `layers`, `fill_between`, `hline`, `legend` |
| [`07-decorations`](./07-decorations/) | Axis control and grid lines | `xscale`, `xlim`, `ylim`, `xtick_format`, `grid_lines` |
| [`08-grid-layout`](./08-grid-layout/) | Multi-panel layouts | `Layout.grid` |
| [`09-themes`](./09-themes/) | Themes and context scaling | `Theme.default`, `Theme.dark`, `Theme.talk` |
| [`10-showcase`](./10-showcase/) | Full showcase with multiple outputs | All mark types, `heatmap`, `render_svg` |
| [`11-errorbar`](./11-errorbar/) | Measurement uncertainty | `errorbar`, `~yerr`, `~cap_size` |
## Running Examples
All examples can be run with:
```bash
dune exec dev/hugin/examples/<example>/main.exe
```
For example:
```bash
dune exec dev/hugin/examples/01-line-plot/main.exe
```
## Quick Reference
### Single Plot
```ocaml
open Hugin
let () =
let x = Nx.linspace Nx.float32 0. 6.28 100 in
let y = Nx.sin x in
line ~x ~y () |> title "Sine" |> render_png "plot.png"
```
### Multiple Marks on Shared Axes
```ocaml
layers
[
line ~x ~y:(Nx.sin x) ~label:"sin" ();
line ~x ~y:(Nx.cos x) ~label:"cos" ~line_style:`Dashed ();
]
|> legend |> render_png "plot.png"
```
### Grid Layout
```ocaml
let p1 = line ~x ~y:(Nx.sin x) () |> title "sin" in
let p2 = line ~x ~y:(Nx.cos x) () |> title "cos" in
Layout.grid [ [ p1; p2 ] ] |> render_png "grid.png"
```
================================================
FILE: packages/hugin/lib/axis.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Per-axis configuration and resolution.
config holds the user-set options (all optional, from decorations). t holds
the resolved axis with defaults applied and data bounds merged. *)
type config = {
label : string option;
lim : (float * float) option;
scale : Spec.scale option;
invert : bool;
ticks : (float * string) list option;
tick_format : (float -> string) option;
}
let empty_config =
{
label = None;
lim = None;
scale = None;
invert = false;
ticks = None;
tick_format = None;
}
type t = {
scale : Spec.scale;
invert : bool;
lo : float;
hi : float;
label : string option;
ticks : (float * string) list option;
tick_format : (float -> string) option;
}
let resolve ~data_lo ~data_hi (c : config) =
let scale = Option.value ~default:`Linear c.scale in
let lo, hi = Option.value ~default:(data_lo, data_hi) c.lim in
{
scale;
invert = c.invert;
lo;
hi;
label = c.label;
ticks = c.ticks;
tick_format = c.tick_format;
}
let make_scale_and_ticks (a : t) =
let s = Scale.make ~invert:a.invert a.scale ~lo:a.lo ~hi:a.hi () in
let ticks =
match a.ticks with
| Some t -> t
| None -> Ticks.generate a.scale ~lo:a.lo ~hi:a.hi ()
in
let ticks =
match a.tick_format with
| None -> ticks
| Some f -> List.map (fun (v, _) -> (v, f v)) ticks
in
(s, ticks)
================================================
FILE: packages/hugin/lib/axis.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Per-axis configuration and resolution.
{b Internal module.} Consolidates per-axis state into two stages: {!config}
holds user-set options from decorations (all optional), and {!t} holds the
resolved axis with defaults applied and data bounds merged. Used by
{!Prepared} and {!Resolve}. *)
(** {1:config Configuration} *)
type config = {
label : string option;
lim : (float * float) option;
scale : Spec.scale option;
invert : bool;
ticks : (float * string) list option;
tick_format : (float -> string) option;
}
(** The type for per-axis user options collected from decorations. *)
val empty_config : config
(** [empty_config] is the default configuration: no label, no limits, [invert]
is [false], scale/ticks/format are [None]. *)
(** {1:resolved Resolved axis} *)
type t = {
scale : Spec.scale;
invert : bool;
lo : float;
hi : float;
label : string option;
ticks : (float * string) list option;
tick_format : (float -> string) option;
}
(** The type for resolved axes. [scale] defaults to [`Linear], [lo] and [hi]
come from data bounds unless overridden by {!config.lim}. *)
val resolve : data_lo:float -> data_hi:float -> config -> t
(** [resolve ~data_lo ~data_hi c] is a resolved axis from [c]. Uses [data_lo]
and [data_hi] when [c.lim] is [None], and [`Linear] when [c.scale] is
[None]. *)
val make_scale_and_ticks : t -> Scale.t * (float * string) list
(** [make_scale_and_ticks a] is [(scale, ticks)] for [a]. Generates ticks via
{!Ticks.generate} when [a.ticks] is [None], then applies [a.tick_format] if
set. *)
================================================
FILE: packages/hugin/lib/cairo_backend.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Color helpers *)
let set_color cr c =
let r, g, b, a = Color.to_rgba c in
Ucairo.set_source_rgba cr r g b a
let set_font cr (font : Theme.font) =
let weight =
match font.weight with `Normal -> Ucairo.Normal | `Bold -> Ucairo.Bold
in
Ucairo.select_font_face cr font.family weight;
Ucairo.set_font_size cr font.size
(* Text measurer *)
let text_measurer cr ~font s =
set_font cr font;
let ext = Ucairo.text_extents cr s in
(ext.width, ext.height)
(* Marker rendering *)
let draw_marker cr shape size (px, py) =
let hs = size /. 2. in
match shape with
| Spec.Circle ->
Ucairo.arc cr px py ~r:hs ~a1:0. ~a2:(2. *. Float.pi);
Ucairo.Path.close cr
| Spec.Square -> Ucairo.rectangle cr (px -. hs) (py -. hs) ~w:size ~h:size
| Spec.Triangle ->
Ucairo.move_to cr px (py -. hs);
Ucairo.line_to cr (px +. hs) (py +. hs);
Ucairo.line_to cr (px -. hs) (py +. hs);
Ucairo.Path.close cr
| Spec.Plus ->
Ucairo.move_to cr (px -. hs) py;
Ucairo.line_to cr (px +. hs) py;
Ucairo.move_to cr px (py -. hs);
Ucairo.line_to cr px (py +. hs)
| Spec.Star ->
Ucairo.move_to cr (px -. hs) py;
Ucairo.line_to cr (px +. hs) py;
Ucairo.move_to cr px (py -. hs);
Ucairo.line_to cr px (py +. hs);
let d = hs *. 0.707 in
Ucairo.move_to cr (px -. d) (py -. d);
Ucairo.line_to cr (px +. d) (py +. d);
Ucairo.move_to cr (px +. d) (py -. d);
Ucairo.line_to cr (px -. d) (py +. d)
(* Primitive rendering *)
let rec render_primitive cr = function
| Scene.Path { points; close; fill; stroke; line_width; dash } ->
if Array.length points < 2 then ()
else begin
let x0, y0 = points.(0) in
Ucairo.move_to cr x0 y0;
for i = 1 to Array.length points - 1 do
let x, y = points.(i) in
Ucairo.line_to cr x y
done;
if close then Ucairo.Path.close cr;
begin match fill with
| Some c ->
set_color cr c;
if stroke <> None then Ucairo.fill_preserve cr else Ucairo.fill cr
| None -> ()
end;
begin match stroke with
| Some c ->
set_color cr c;
Ucairo.set_line_width cr line_width;
(match dash with
| [] -> Ucairo.set_dash cr [||]
| ds -> Ucairo.set_dash cr (Array.of_list ds));
Ucairo.stroke cr
| None -> ()
end
end
| Scene.Markers { points; shape; size; sizes; fill; fills; stroke } ->
let stroke_only =
match shape with Spec.Plus | Spec.Star -> true | _ -> false
in
Array.iteri
(fun i pt ->
let s = match sizes with Some ss -> ss.(i) | None -> size in
let f = match fills with Some fs -> Some fs.(i) | None -> fill in
Ucairo.Path.clear cr;
draw_marker cr shape s pt;
if stroke_only then begin
let c =
match f with
| Some c -> c
| None -> ( match stroke with Some c -> c | None -> Color.black)
in
set_color cr c;
Ucairo.set_line_width cr (Float.max 1. (s *. 0.15));
Ucairo.stroke cr
end
else begin
begin match f with
| Some c ->
set_color cr c;
if stroke <> None then Ucairo.fill_preserve cr
else Ucairo.fill cr
| None -> ()
end;
begin match stroke with
| Some c ->
set_color cr c;
Ucairo.set_line_width cr (Float.max 1. (s *. 0.15));
Ucairo.stroke cr
| None -> ()
end
end)
points
| Scene.Text { x; y; content; font; color; anchor; baseline; angle } ->
set_font cr font;
set_color cr color;
let ext = Ucairo.text_extents cr content in
let dx =
match anchor with
| `Start -> -.ext.x_bearing
| `Middle -> -.(ext.x_bearing +. (ext.width /. 2.))
| `End -> -.(ext.x_bearing +. ext.width)
in
let dy =
match baseline with
| `Top -> -.ext.y_bearing
| `Middle -> -.(ext.y_bearing +. (ext.height /. 2.))
| `Bottom -> -.(ext.y_bearing +. ext.height)
in
Ucairo.save cr;
Ucairo.translate cr x y;
if angle <> 0. then Ucairo.rotate cr angle;
Ucairo.move_to cr dx dy;
Ucairo.show_text cr content;
Ucairo.restore cr
| Scene.Image { x; y; w; h; data } ->
let img_surface = Image_util.nx_to_cairo_surface data in
let img_w = (Nx.shape data).(1) and img_h = (Nx.shape data).(0) in
Ucairo.save cr;
Ucairo.translate cr x y;
Ucairo.scale cr (w /. float img_w) (h /. float img_h);
Ucairo.set_source_surface cr img_surface ~x:0. ~y:0.;
Ucairo.paint cr;
Ucairo.restore cr;
Ucairo.Surface.finish img_surface
| Scene.Clip { x; y; w; h; children } ->
Ucairo.save cr;
Ucairo.rectangle cr x y ~w ~h;
Ucairo.clip cr;
List.iter (render_primitive cr) children;
Ucairo.restore cr
| Scene.Group children -> List.iter (render_primitive cr) children
(* Scene rendering *)
let render_scene cr (scene : Scene.t) =
Ucairo.set_antialias cr Ucairo.Antialias_default;
Ucairo.set_line_cap cr Ucairo.Round;
Ucairo.set_line_join cr Ucairo.Join_round;
List.iter (render_primitive cr) scene.primitives
(* Entry points *)
let render_to_png filename ~width ~height (scene : Scene.t) =
let w = int_of_float width and h = int_of_float height in
let surface = Ucairo.Image.create ~w ~h in
let cr = Ucairo.create surface in
render_scene cr scene;
Ucairo.Png.write surface filename;
Ucairo.Surface.finish surface
let render_to_pdf filename ~width ~height (scene : Scene.t) =
let surface = Ucairo.Pdf.create filename ~w:width ~h:height in
let cr = Ucairo.create surface in
render_scene cr scene;
Ucairo.Surface.finish surface
let render_to_buffer ~width ~height (scene : Scene.t) =
let w = int_of_float width and h = int_of_float height in
let surface = Ucairo.Image.create ~w ~h in
let cr = Ucairo.create surface in
render_scene cr scene;
let buf = Buffer.create 4096 in
Ucairo.Png.write_to_stream surface (Buffer.add_string buf);
Ucairo.Surface.finish surface;
Buffer.contents buf
let show_interactive ~theme ~width ~height prepared =
let w = int_of_float width and h = int_of_float height in
let csdl = Cairo_sdl.create ~width:w ~height:h ~title:"Hugin" in
let render_current () =
let cr = Cairo_sdl.context csdl in
let cw = float (Cairo_sdl.width csdl) in
let ch = float (Cairo_sdl.height csdl) in
let tm = text_measurer cr in
let scene =
Resolve.resolve_prepared ~text_measurer:tm ~theme ~width:cw ~height:ch
prepared
in
render_scene cr scene;
Cairo_sdl.present csdl
in
render_current ();
let ev = Usdl.Event.create () in
let quit = ref false in
while not !quit do
if not (Usdl.Event.wait ev) then quit := true
else
begin match Usdl.Event.typ ev with
| `Quit -> quit := true
| `Window_event ->
begin match Usdl.Event.window_event_id ev with
| `Resized | `Size_changed ->
Cairo_sdl.resize csdl;
render_current ()
| `Exposed -> render_current ()
| `Close -> quit := true
| _ -> ()
end
| `Key_down ->
let keycode = Usdl.Event.keycode ev in
if keycode = Usdl.Keycode.escape || keycode = Usdl.Keycode.q then
quit := true
| _ -> ()
end
done;
Cairo_sdl.destroy csdl
================================================
FILE: packages/hugin/lib/cairo_backend.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Cairo rendering backend.
{b Internal module.} Renders {!Scene.t} to PNG, PDF, or an interactive SDL
window via Cairo. *)
(** {1:measurer Text measurement} *)
val text_measurer : Ucairo.t -> Resolve.text_measurer
(** [text_measurer cr] is a text measurer backed by {!Ucairo.text_extents}. *)
(** {1:rendering Rendering} *)
val render_scene : Ucairo.t -> Scene.t -> unit
(** [render_scene cr scene] draws [scene] onto [cr]. *)
val render_to_png : string -> width:float -> height:float -> Scene.t -> unit
(** [render_to_png filename ~width ~height scene] writes [scene] as a PNG image.
*)
val render_to_pdf : string -> width:float -> height:float -> Scene.t -> unit
(** [render_to_pdf filename ~width ~height scene] writes [scene] as a
single-page PDF. *)
val render_to_buffer : width:float -> height:float -> Scene.t -> string
(** [render_to_buffer ~width ~height scene] is the PNG-encoded contents of
[scene] as a string. *)
(** {1:interactive Interactive display} *)
val show_interactive :
theme:Theme.t -> width:float -> height:float -> Prepared.t -> unit
(** [show_interactive ~theme ~width ~height prepared] opens an SDL window and
renders [prepared]. Compiles data once; only re-resolves layout on resize.
Exits on Escape, Q, or window close. *)
================================================
FILE: packages/hugin/lib/cairo_sdl.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Cairo-SDL integration: shared ARGB8888 surface *)
type t = {
window : Usdl.Window.t;
renderer : Usdl.Renderer.t;
mutable surface : Usdl.Surface.t;
mutable cairo_surface : Ucairo.surface;
mutable context : Ucairo.t;
mutable width : int;
mutable height : int;
}
let make_cairo_context surface =
let pixels = Usdl.Surface.pixels surface in
let stride = Usdl.Surface.pitch surface in
let total = Bigarray.Array1.dim pixels in
let h = total / stride in
let w = stride / 4 in
let cs = Ucairo.Image.create_for_data8 pixels ~w ~h ~stride in
let cr = Ucairo.create cs in
(cs, cr, w, h)
let create ~width ~height ~title =
Usdl.init ();
let window = Usdl.Window.create ~title ~w:width ~h:height in
let renderer = Usdl.Renderer.create window in
let ow, oh = Usdl.Renderer.output_size renderer in
let surface = Usdl.Surface.create_argb8888 ~w:ow ~h:oh in
let cairo_surface, context, w, h = make_cairo_context surface in
{ window; renderer; surface; cairo_surface; context; width = w; height = h }
let context t = t.context
let width t = t.width
let height t = t.height
let present t =
Ucairo.Surface.flush t.cairo_surface;
let tex = Usdl.Texture.of_surface t.renderer t.surface in
Usdl.Renderer.clear t.renderer;
Usdl.Renderer.copy t.renderer tex;
Usdl.Renderer.present t.renderer;
Usdl.Texture.destroy tex
let resize t =
let nw, nh = Usdl.Renderer.output_size t.renderer in
if (nw <> t.width || nh <> t.height) && nw > 0 && nh > 0 then begin
Ucairo.Surface.finish t.cairo_surface;
Usdl.Surface.destroy t.surface;
let surface = Usdl.Surface.create_argb8888 ~w:nw ~h:nh in
let cairo_surface, context, w, h = make_cairo_context surface in
t.surface <- surface;
t.cairo_surface <- cairo_surface;
t.context <- context;
t.width <- w;
t.height <- h
end
let destroy t =
Ucairo.Surface.finish t.cairo_surface;
Usdl.Surface.destroy t.surface;
Usdl.Renderer.destroy t.renderer;
Usdl.Window.destroy t.window;
Usdl.quit ()
================================================
FILE: packages/hugin/lib/cairo_sdl.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Cairo-SDL integration.
{b Internal module.} Manages a shared ARGB8888 surface between Cairo and SDL
for interactive rendering. *)
type t
(** The type for Cairo-SDL contexts. *)
val create : width:int -> height:int -> title:string -> t
(** [create ~width ~height ~title] initializes SDL, creates a resizable window,
and sets up a shared Cairo surface.
Raises [Failure] if SDL initialization fails. *)
val context : t -> Ucairo.t
(** [context t] is the current Cairo drawing context. Valid until the next
{!present} or {!resize}. *)
val width : t -> int
(** [width t] is the current surface width in pixels. *)
val height : t -> int
(** [height t] is the current surface height in pixels. *)
val present : t -> unit
(** [present t] flushes the Cairo surface and copies it to the SDL window. *)
val resize : t -> unit
(** [resize t] updates the surface dimensions to match the renderer output size.
No-op if the size has not changed. *)
val destroy : t -> unit
(** [destroy t] frees all SDL and Cairo resources and calls {!Usdl.quit}. *)
================================================
FILE: packages/hugin/lib/cmap.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
type t = Color.t array
let eval t v =
let v = Float.max 0. (Float.min 1. v) in
let i = int_of_float (v *. 255.) in
t.(min i 255)
let of_colors stops =
let n = Array.length stops in
if n < 2 then invalid_arg "Cmap.of_colors: need at least 2 stops";
Array.init 256 (fun i ->
let v = float i /. 255. in
let scaled = v *. float (n - 1) in
let idx = int_of_float scaled in
let idx = min idx (n - 2) in
let frac = scaled -. float idx in
Color.mix frac stops.(idx) stops.(idx + 1))
(* Decode a canonical 256-entry hex-encoded colormap *)
let hex_digit c =
match c with
| '0' .. '9' -> Char.code c - Char.code '0'
| 'a' .. 'f' -> 10 + Char.code c - Char.code 'a'
| 'A' .. 'F' -> 10 + Char.code c - Char.code 'A'
| _ -> invalid_arg (Printf.sprintf "Cmap.hex_digit: invalid hex digit %C" c)
let decode_hex_cmap hex =
Array.init 256 (fun i ->
let off = i * 6 in
let byte j =
let h = hex_digit (String.unsafe_get hex (off + (j * 2))) in
let l = hex_digit (String.unsafe_get hex (off + (j * 2) + 1)) in
float ((h lsl 4) lor l) /. 255.
in
Color.rgb ~r:(byte 0) ~g:(byte 1) ~b:(byte 2) ())
let viridis = decode_hex_cmap Cmap_data.viridis_hex
let plasma = decode_hex_cmap Cmap_data.plasma_hex
let inferno = decode_hex_cmap Cmap_data.inferno_hex
let magma = decode_hex_cmap Cmap_data.magma_hex
let cividis = decode_hex_cmap Cmap_data.cividis_hex
let coolwarm = decode_hex_cmap Cmap_data.coolwarm_hex
let gray =
of_colors [| Color.rgb ~r:0. ~g:0. ~b:0. (); Color.rgb ~r:1. ~g:1. ~b:1. () |]
let gray_r =
of_colors [| Color.rgb ~r:1. ~g:1. ~b:1. (); Color.rgb ~r:0. ~g:0. ~b:0. () |]
let hot =
of_colors
[|
Color.rgb ~r:0. ~g:0. ~b:0. ();
Color.rgb ~r:0.7 ~g:0. ~b:0. ();
Color.rgb ~r:1. ~g:0.6 ~b:0. ();
Color.rgb ~r:1. ~g:1. ~b:1. ();
|]
================================================
FILE: packages/hugin/lib/cmap.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Colormaps.
A colormap is a continuous mapping from \[[0];[1]\] to {!Color.t}.
Internally stored as a 256-entry lookup table with OKLCH interpolation, so
{!eval} is a single array access. *)
(** {1:types Types} *)
type t
(** The type for colormaps. *)
(** {1:eval Evaluation} *)
val eval : t -> float -> Color.t
(** [eval cmap v] is the color at position [v], clamped to \[[0];[1]\]. *)
(** {1:constructors Constructors} *)
val of_colors : Color.t array -> t
(** [of_colors stops] is a colormap interpolating linearly through [stops] in
OKLCH space. The stops are evenly spaced from [0] to [1].
Raises [Invalid_argument] if [stops] has fewer than 2 elements. *)
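(** For illustration, a minimal sketch of building and sampling a custom
two-stop colormap (the stop colors here are arbitrary examples, not part
of the library):
{[
let red_blue =
Cmap.of_colors [| Color.hex "#D55E00"; Color.hex "#0072B2" |]
in
let mid = Cmap.eval red_blue 0.5
]}
Because interpolation happens in OKLCH space, [mid] sits perceptually
halfway between the two stops. *)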
(** {1:predefined Predefined colormaps}
Perceptually uniform sequential colormaps from the
{{:https://bids.github.io/colormap/}viridis family}, plus a diverging
colormap. *)
val viridis : t
val plasma : t
val inferno : t
val magma : t
val cividis : t
val coolwarm : t
val gray : t
(** Linear grayscale (black to white). *)
val gray_r : t
(** Reversed grayscale (white to black). The conventional default for
astronomical image display. *)
val hot : t
(** Black-red-yellow-white. Common in X-ray astronomy. *)
================================================
FILE: packages/hugin/lib/cmap_data.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Canonical 256-entry colormap data encoded as hex strings. Each string is 1536
characters: 256 entries of 6 hex chars (RRGGBB). *)
let viridis_hex =
"44015444025645045745055946075a46085c460a5d460b5e470d60470e6147106347116447136548146748166848176948186a481a6c481b6d481c6e481d6f481f70482071482173482374482475482576482677482878482979472a7a472c7a472d7b472e7c472f7d46307e46327e46337f463480453581453781453882443983443a83443b84433d84433e85423f854240864241864142874144874045884046883f47883f48893e49893e4a893e4c8a3d4d8a3d4e8a3c4f8a3c508b3b518b3b528b3a538b3a548c39558c39568c38588c38598c375a8c375b8d365c8d365d8d355e8d355f8d34608d34618d33628d33638d32648e32658e31668e31678e31688e30698e306a8e2f6b8e2f6c8e2e6d8e2e6e8e2e6f8e2d708e2d718e2c718e2c728e2c738e2b748e2b758e2a768e2a778e2a788e29798e297a8e297b8e287c8e287d8e277e8e277f8e27808e26818e26828e26828e25838e25848e25858e24868e24878e23888e23898e238a8d228b8d228c8d228d8d218e8d218f8d21908d21918c20928c20928c20938c1f948c1f958b1f968b1f978b1f988b1f998a1f9a8a1e9b8a1e9c891e9d891f9e891f9f881fa0881fa1881fa1871fa28720a38620a48621a58521a68522a78522a88423a98324aa8325ab8225ac8226ad8127ad8128ae8029af7f2ab07f2cb17e2db27d2eb37c2fb47c31b57b32b67a34b67935b77937b87838b9773aba763bbb753dbc743fbc7340bd7242be7144bf7046c06f48c16e4ac16d4cc26c4ec36b50c46a52c56954c56856c66758c7655ac8645cc8635ec96260ca6063cb5f65cb5e67cc5c69cd5b6ccd5a6ece5870cf5773d05675d05477d1537ad1517cd2507fd34e81d34d84d44b86d54989d5488bd6468ed64590d74393d74195d84098d83e9bd93c9dd93ba0da39a2da37a5db36a8db34aadc32addc30b0dd2fb2dd2db5de2bb8de29bade28bddf26c0df25c2df23c5e021c8e020cae11fcde11dd0e11cd2e21bd5e21ad8e219dae319dde318dfe318e2e418e5e419e7e419eae51aece51befe51cf1e51df4e61ef6e620f8e621fbe723fde725"
let plasma_hex =
"0d088710078813078916078a19068c1b068d1d068e20068f2206902406912605912805922a05932c05942e05952f059631059733059735049837049938049a3a049a3c049b3e049c3f049c41049d43039e44039e46039f48039f4903a04b03a14c02a14e02a25002a25102a35302a35502a45601a45801a45901a55b01a55c01a65e01a66001a66100a76300a76400a76600a76700a86900a86a00a86c00a86e00a86f00a87100a87201a87401a87501a87701a87801a87a02a87b02a87d03a87e03a88004a88104a78305a78405a78606a68707a68808a68a09a58b0aa58d0ba58e0ca48f0da4910ea3920fa39410a29511a19613a19814a099159f9a169f9c179e9d189d9e199da01a9ca11b9ba21d9aa31e9aa51f99a62098a72197a82296aa2395ab2494ac2694ad2793ae2892b02991b12a90b22b8fb32c8eb42e8db52f8cb6308bb7318ab83289ba3388bb3488bc3587bd3786be3885bf3984c03a83c13b82c23c81c33d80c43e7fc5407ec6417dc7427cc8437bc9447aca457acb4679cc4778cc4977cd4a76ce4b75cf4c74d04d73d14e72d24f71d35171d45270d5536fd5546ed6556dd7566cd8576bd9586ada5a6ada5b69db5c68dc5d67dd5e66de5f65de6164df6263e06363e16462e26561e26660e3685fe4695ee56a5de56b5de66c5ce76e5be76f5ae87059e97158e97257ea7457eb7556eb7655ec7754ed7953ed7a52ee7b51ef7c51ef7e50f07f4ff0804ef1814df1834cf2844bf3854bf3874af48849f48948f58b47f58c46f68d45f68f44f79044f79143f79342f89441f89540f9973ff9983ef99a3efa9b3dfa9c3cfa9e3bfb9f3afba139fba238fca338fca537fca636fca835fca934fdab33fdac33fdae32fdaf31fdb130fdb22ffdb42ffdb52efeb72dfeb82cfeba2cfebb2bfebd2afebe2afec029fdc229fdc328fdc527fdc627fdc827fdca26fdcb26fccd25fcce25fcd025fcd225fbd324fbd524fbd724fad824fada24f9dc24f9dd25f8df25f8e125f7e225f7e425f6e626f6e826f5e926f5eb27f4ed27f3ee27f3f027f2f227f1f426f1f525f0f724f0f921"
let inferno_hex =
"00000401000501010601010802010a02020c02020e03021004031204031405041706041907051b08051d09061f0a07220b07240c08260d08290e092b10092d110a30120a32140b34150b37160b39180c3c190c3e1b0c411c0c431e0c451f0c48210c4a230c4c240c4f260c51280b53290b552b0b572d0b592f0a5b310a5c320a5e340a5f3609613809623909633b09643d09653e0966400a67420a68440a68450a69470b6a490b6a4a0c6b4c0c6b4d0d6c4f0d6c510e6c520e6d540f6d550f6d57106e59106e5a116e5c126e5d126e5f136e61136e62146e64156e65156e67166e69166e6a176e6c186e6d186e6f196e71196e721a6e741a6e751b6e771c6d781c6d7a1d6d7c1d6d7d1e6d7f1e6c801f6c82206c84206b85216b87216b88226a8a226a8c23698d23698f24699025689225689326679526679727669827669a28659b29649d29649f2a63a02a63a22b62a32c61a52c60a62d60a82e5fa92e5eab2f5ead305dae305cb0315bb1325ab3325ab43359b63458b73557b93556ba3655bc3754bd3853bf3952c03a51c13a50c33b4fc43c4ec63d4dc73e4cc83f4bca404acb4149cc4248ce4347cf4446d04545d24644d34743d44842d54a41d74b3fd84c3ed94d3dda4e3cdb503bdd513ade5238df5337e05536e15635e25734e35933e45a31e55c30e65d2fe75e2ee8602de9612bea632aeb6429eb6628ec6726ed6925ee6a24ef6c23ef6e21f06f20f1711ff1731df2741cf3761bf37819f47918f57b17f57d15f67e14f68013f78212f78410f8850ff8870ef8890cf98b0bf98c0af98e09fa9008fa9207fa9407fb9606fb9706fb9906fb9b06fb9d07fc9f07fca108fca309fca50afca60cfca80dfcaa0ffcac11fcae12fcb014fcb216fcb418fbb61afbb81dfbba1ffbbc21fbbe23fac026fac228fac42afac62df9c72ff9c932f9cb35f8cd37f8cf3af7d13df7d340f6d543f6d746f5d949f5db4cf4dd4ff4df53f4e156f3e35af3e55df2e661f2e865f2ea69f1ec6df1ed71f1ef75f1f179f2f27df2f482f3f586f3f68af4f88ef5f992f6fa96f8fb9af9fc9dfafda1fcffa4"
let magma_hex =
"00000401000501010601010802010902020b02020d03030f03031204041405041606051806051a07061c08071e0907200a08220b09240c09260d0a290e0b2b100b2d110c2f120d31130d34140e36150e38160f3b180f3d19103f1a10421c10441d11471e114920114b21114e22115024125325125527125829115a2a115c2c115f2d11612f116331116533106734106936106b38106c390f6e3b0f703d0f713f0f72400f74420f75440f764510774710784910784a10794c117a4e117b4f127b51127c52137c54137d56147d57157e59157e5a167e5c167f5d177f5f187f601880621980641a80651a80671b80681c816a1c816b1d816d1d816e1e81701f81721f817320817521817621817822817922827b23827c23827e24828025828125818326818426818627818827818928818b29818c29818e2a81902a81912b81932b80942c80962c80982d80992d809b2e7f9c2e7f9e2f7fa02f7fa1307ea3307ea5317ea6317da8327daa337dab337cad347cae347bb0357bb2357bb3367ab5367ab73779b83779ba3878bc3978bd3977bf3a77c03a76c23b75c43c75c53c74c73d73c83e73ca3e72cc3f71cd4071cf4070d0416fd2426fd3436ed5446dd6456cd8456cd9466bdb476adc4869de4968df4a68e04c67e24d66e34e65e44f64e55064e75263e85362e95462ea5661eb5760ec5860ed5a5fee5b5eef5d5ef05f5ef1605df2625df2645cf3655cf4675cf4695cf56b5cf66c5cf66e5cf7705cf7725cf8745cf8765cf9785df9795df97b5dfa7d5efa7f5efa815ffb835ffb8560fb8761fc8961fc8a62fc8c63fc8e64fc9065fd9266fd9467fd9668fd9869fd9a6afd9b6bfe9d6cfe9f6dfea16efea36ffea571fea772fea973feaa74feac76feae77feb078feb27afeb47bfeb67cfeb77efeb97ffebb81febd82febf84fec185fec287fec488fec68afec88cfeca8dfecc8ffecd90fecf92fed194fed395fed597fed799fed89afdda9cfddc9efddea0fde0a1fde2a3fde3a5fde5a7fde7a9fde9aafdebacfcecaefceeb0fcf0b2fcf2b4fcf4b6fcf6b8fcf7b9fcf9bbfcfbbdfcfdbf"
let cividis_hex =
"00224e00234f00245100255300255400265600275800285900285b00295d002a5f002a61002b62002c64002c66002d68002e6a002e6c002f6d00306f0030700031700031710132710533710833700c34700f357012357014367016377018376f1a386f1c396f1e3a6f203a6f213b6e233c6e243c6e263d6e273e6e293f6e2a3f6d2b406d2d416d2e416d2f426d31436d32436d33446d34456c35456c36466c38476c39486c3a486c3b496c3c4a6c3d4a6c3e4b6c3f4c6c404c6c414d6c424e6c434e6c444f6c45506c46516c47516c48526c49536c4a536c4b546c4c556c4d556c4e566c4f576c50576c51586d52596d535a6d545a6d555b6d555c6d565c6d575d6d585e6d595e6e5a5f6e5b606e5c616e5d616e5e626e5e636f5f636f60646f61656f62656f636670646770656870656870666970676a71686a71696b716a6c716b6d726c6d726c6e726d6f726e6f736f70737071737172747272747273747374757474757575757676767777767777777878777979777a7a787b7a787c7b787d7c787e7c787e7d787f7e78807f78817f788280798381798482798582798683798784788885788985788a86788b87788c88788d88788e89788f8a78908b78918b78928c78928d78938e78948e77958f779690779791779892779992779a93769b94769c95769d95769e96769f9775a09875a19975a29975a39a74a49b74a59c74a69c74a79d73a89e73a99f73aaa073aba072aca172ada272aea371afa471b0a571b1a570b3a670b4a76fb5a86fb6a96fb7a96eb8aa6eb9ab6dbaac6dbbad6dbcae6cbdae6cbeaf6bbfb06bc0b16ac1b26ac2b369c3b369c4b468c5b568c6b667c7b767c8b866c9b965cbb965ccba64cdbb63cebc63cfbd62d0be62d1bf61d2c060d3c05fd4c15fd5c25ed6c35dd7c45cd9c55cdac65bdbc75adcc859ddc858dec958dfca57e0cb56e1cc55e2cd54e4ce53e5cf52e6d051e7d150e8d24fe9d34eead34cebd44bedd54aeed649efd748f0d846f1d945f2da44f3db42f5dc41f6dd3ff7de3ef8df3cf9e03afbe138fce236fde334fee434fee535fee636fee838"
let coolwarm_hex =
"3b4cc03c4ec23d50c33e51c53f53c64055c84257c94358cb445acc455cce465ecf485fd14961d24a63d34b64d54c66d64e68d84f69d9506bda516ddb536edd5470de5572df5673e05875e15977e35a78e45b7ae55d7ce65e7de75f7fe86180e96282ea6384eb6485ec6687ed6788ee688aef6a8bef6b8df06c8ff16e90f26f92f37093f37295f47396f57597f67699f6779af7799cf87a9df87b9ff97da0f97ea1fa80a3fa81a4fb82a6fb84a7fc85a8fc86a9fc88abfd89acfd8badfd8caffe8db0fe8fb1fe90b2fe92b4fe93b5fe94b6ff96b7ff97b8ff98b9ff9abbff9bbcff9dbdff9ebeff9fbfffa1c0ffa2c1ffa3c2fea5c3fea6c4fea7c5fea9c6fdaac7fdabc8fdadc9fdaec9fcafcafcb1cbfcb2ccfbb3cdfbb5cdfab6cefab7cff9b9d0f9bad0f8bbd1f8bcd2f7bed2f6bfd3f6c0d4f5c1d4f4c3d5f4c4d5f3c5d6f2c6d6f1c7d7f0c9d7f0cad8efcbd8eeccd9edcdd9eccedaebcfdaead1dae9d2dbe8d3dbe7d4dbe6d5dbe5d6dce4d7dce3d8dce2d9dce1dadce0dbdcdedcdddddddcdcdedcdbdfdbd9e0dbd8e1dad6e2dad5e3d9d3e4d9d2e5d8d1e6d7cfe7d7cee8d6cce9d5cbead5c9ead4c8ebd3c6ecd3c5edd2c3edd1c2eed0c0efcfbfefcebdf0cdbbf1cdbaf1ccb8f2cbb7f2cab5f2c9b4f3c8b2f3c7b1f4c6aff4c5adf5c4acf5c2aaf5c1a9f5c0a7f6bfa6f6bea4f6bda2f7bca1f7ba9ff7b99ef7b89cf7b79bf7b599f7b497f7b396f7b194f7b093f7af91f7ad90f7ac8ef7aa8cf7a98bf7a889f7a688f6a586f6a385f6a283f5a081f59f80f59d7ef59c7df49a7bf4987af39778f39577f39475f29274f29072f18f71f18d6ff08b6ef08a6cef886bee8669ee8468ed8366ec8165ec7f63eb7d62ea7b60e97a5fe9785de8765ce7745be67259e57058e46e56e36c55e36b54e26952e16751e0654fdf634ede614ddd5f4bdc5d4ada5a49d95847d85646d75445d65244d55042d44e41d24b40d1493fd0473dcf453ccd423bcc403acb3e38ca3b37c83836c73635c53334c43032c32e31c12b30c0282fbe242ebd1f2dbb1b2cba162bb8122ab70d28b50927b40426"
================================================
FILE: packages/hugin/lib/color.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
type t = { l : float; c : float; h : float; a : float }
(* Constructors *)
let oklch ~l ~c ~h () = { l; c; h; a = 1. }
let oklcha ~l ~c ~h ~a () = { l; c; h; a }
(* sRGB <-> linear RGB *)
let srgb_to_linear c =
if c <= 0.04045 then c /. 12.92 else Float.pow ((c +. 0.055) /. 1.055) 2.4
let linear_to_srgb c =
if c <= 0.0031308 then 12.92 *. c
else (1.055 *. Float.pow c (1. /. 2.4)) -. 0.055
(* Linear RGB -> OKLab *)
let linear_rgb_to_oklab r g b =
let l = (0.4122214708 *. r) +. (0.5363325363 *. g) +. (0.0514459929 *. b) in
let m = (0.2119034982 *. r) +. (0.6806995451 *. g) +. (0.1073969566 *. b) in
let s = (0.0883024619 *. r) +. (0.2164557896 *. g) +. (0.6898418685 *. b) in
let l = Float.cbrt l and m = Float.cbrt m and s = Float.cbrt s in
let lab_l =
(0.2104542553 *. l) +. (0.7936177850 *. m) -. (0.0040720468 *. s)
in
let lab_a =
(1.9779984951 *. l) -. (2.4285922050 *. m) +. (0.4505937099 *. s)
in
let lab_b =
(0.0259040371 *. l) +. (0.7827717662 *. m) -. (0.8086757660 *. s)
in
(lab_l, lab_a, lab_b)
(* OKLab -> linear RGB *)
let oklab_to_linear_rgb lab_l lab_a lab_b =
let l = lab_l +. (0.3963377774 *. lab_a) +. (0.2158037573 *. lab_b) in
let m = lab_l -. (0.1055613458 *. lab_a) -. (0.0638541728 *. lab_b) in
let s = lab_l -. (0.0894841775 *. lab_a) -. (1.2914855480 *. lab_b) in
let l = l *. l *. l and m = m *. m *. m and s = s *. s *. s in
let r = (4.0767416621 *. l) -. (3.3077115913 *. m) +. (0.2309699292 *. s) in
let g = (-1.2684380046 *. l) +. (2.6097574011 *. m) -. (0.3413193965 *. s) in
let b = (-0.0041960863 *. l) -. (0.7034186147 *. m) +. (1.7076147010 *. s) in
(r, g, b)
(* OKLab <-> OKLCH *)
let oklab_to_oklch lab_l lab_a lab_b =
let c = Float.sqrt ((lab_a *. lab_a) +. (lab_b *. lab_b)) in
let h = Float.atan2 lab_b lab_a *. 180. /. Float.pi in
let h = if h < 0. then h +. 360. else h in
(lab_l, c, h)
let oklch_to_oklab l c h =
let h_rad = h *. Float.pi /. 180. in
(l, c *. Float.cos h_rad, c *. Float.sin h_rad)
(* sRGB -> OKLCH *)
let of_srgb r g b =
let lr = srgb_to_linear r
and lg = srgb_to_linear g
and lb = srgb_to_linear b in
let lab_l, lab_a, lab_b = linear_rgb_to_oklab lr lg lb in
oklab_to_oklch lab_l lab_a lab_b
(* OKLCH -> sRGB *)
let to_srgb l c h =
let lab_l, lab_a, lab_b = oklch_to_oklab l c h in
let lr, lg, lb = oklab_to_linear_rgb lab_l lab_a lab_b in
let clamp v = Float.max 0. (Float.min 1. v) in
( linear_to_srgb (clamp lr),
linear_to_srgb (clamp lg),
linear_to_srgb (clamp lb) )
let rgb ~r ~g ~b () =
let l, c, h = of_srgb r g b in
{ l; c; h; a = 1. }
let rgba ~r ~g ~b ~a () =
let l, c, h = of_srgb r g b in
{ l; c; h; a }
let hex_digit c =
match c with
| '0' .. '9' -> Char.code c - Char.code '0'
| 'a' .. 'f' -> Char.code c - Char.code 'a' + 10
| 'A' .. 'F' -> Char.code c - Char.code 'A' + 10
| _ -> invalid_arg (Printf.sprintf "Color.hex: invalid hex digit %C" c)
let hex_byte s i =
let hi = hex_digit (String.get s i) in
let lo = hex_digit (String.get s (i + 1)) in
float ((hi * 16) + lo) /. 255.
let hex s =
let n = String.length s in
let off = if n > 0 && String.get s 0 = '#' then 1 else 0 in
let len = n - off in
match len with
| 6 ->
let r = hex_byte s off
and g = hex_byte s (off + 2)
and b = hex_byte s (off + 4) in
rgb ~r ~g ~b ()
| 8 ->
let r = hex_byte s off and g = hex_byte s (off + 2) in
let b = hex_byte s (off + 4) and a = hex_byte s (off + 6) in
rgba ~r ~g ~b ~a ()
| _ ->
invalid_arg
(Printf.sprintf "Color.hex: expected 6 or 8 hex digits, got %d" len)
(* Accessors *)
let lightness t = t.l
let chroma t = t.c
let hue t = t.h
let alpha t = t.a
(* Converting *)
let to_rgba t =
let r, g, b = to_srgb t.l t.c t.h in
(r, g, b, t.a)
(* Operations *)
let with_alpha a t = { t with a }
let lighten amount t = { t with l = Float.min 1. (t.l +. amount) }
let darken amount t = { t with l = Float.max 0. (t.l -. amount) }
let interpolate_hue ratio h1 h2 =
let diff = h2 -. h1 in
let diff =
if diff > 180. then diff -. 360.
else if diff < -180. then diff +. 360.
else diff
in
let h = h1 +. (ratio *. diff) in
if h < 0. then h +. 360. else if h >= 360. then h -. 360. else h
let mix ratio a b =
{
l = a.l +. (ratio *. (b.l -. a.l));
c = a.c +. (ratio *. (b.c -. a.c));
h = interpolate_hue ratio a.h b.h;
a = a.a +. (ratio *. (b.a -. a.a));
}
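(* Illustrative sketch (not part of the library): [mix] interpolates hue
   along the shortest arc, so blending across the 0/360 seam stays near red
   instead of sweeping through cyan:
   {[
     let a = oklch ~l:0.6 ~c:0.15 ~h:350. () in
     let b = oklch ~l:0.6 ~c:0.15 ~h:10. () in
     hue (mix 0.5 a b) (* 0., not 180. *)
   ]} *)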
(* Named colors — Okabe-Ito *)
let orange = hex "#E69F00"
let sky_blue = hex "#56B4E9"
let green = hex "#009E73"
let yellow = hex "#F0E442"
let blue = hex "#0072B2"
let vermillion = hex "#D55E00"
let purple = hex "#CC79A7"
let black = { l = 0.; c = 0.; h = 0.; a = 1. }
let white = { l = 1.; c = 0.; h = 0.; a = 1. }
let gray = oklch ~l:0.5 ~c:0. ~h:0. ()
(* Formatting *)
let pp fmt t = Format.fprintf fmt "oklch(%g %g %g / %g)" t.l t.c t.h t.a
================================================
FILE: packages/hugin/lib/color.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Perceptually uniform colors.
Colors are represented internally in the
{{:https://bottosson.github.io/posts/oklab/}OKLCH} color space. All
operations ({!lighten}, {!darken}, {!mix}) produce perceptually uniform
results: equal numerical steps yield equal perceived differences.
Constructors accept common input formats (sRGB, hex) and convert to OKLCH on
creation. The reverse conversion {!to_rgba} is called only at render time.
*)
(** {1:types Types} *)
type t
(** The type for colors in OKLCH space. Components are lightness \[0, 1\],
chroma \[0, ~0.4\], hue \[0, 360), and alpha \[0, 1\]. *)
(** {1:constructors Constructors} *)
val oklch : l:float -> c:float -> h:float -> unit -> t
(** [oklch ~l ~c ~h ()] is the fully opaque OKLCH color with lightness [l],
chroma [c], and hue [h] (in degrees). *)
val oklcha : l:float -> c:float -> h:float -> a:float -> unit -> t
(** [oklcha ~l ~c ~h ~a ()] is like {!oklch} with alpha [a]. *)
val rgb : r:float -> g:float -> b:float -> unit -> t
(** [rgb ~r ~g ~b ()] is the fully opaque color with sRGB components [r], [g],
[b] in \[0, 1\], converted to OKLCH. *)
val rgba : r:float -> g:float -> b:float -> a:float -> unit -> t
(** [rgba ~r ~g ~b ~a ()] is like {!rgb} with alpha [a]. *)
val hex : string -> t
(** [hex s] is the color parsed from the hex string [s]. Accepts ["#RRGGBB"] and
["#RRGGBBAA"] formats.
Raises [Invalid_argument] if [s] is not a valid hex color. *)
(** {1:accessors Accessors} *)
val lightness : t -> float
(** [lightness c] is the OKLCH lightness of [c] in \[0, 1\]. *)
val chroma : t -> float
(** [chroma c] is the OKLCH chroma of [c] in \[0, ~0.4\]. *)
val hue : t -> float
(** [hue c] is the OKLCH hue of [c] in degrees \[0, 360). *)
val alpha : t -> float
(** [alpha c] is the alpha of [c] in \[0, 1\]. *)
(** {1:converting Converting} *)
val to_rgba : t -> float * float * float * float
(** [to_rgba c] is [(r, g, b, a)] with sRGB components in \[0, 1\]. Values are
clamped to the sRGB gamut. *)
(** {1:operations Operations} *)
val with_alpha : float -> t -> t
(** [with_alpha a c] is [c] with alpha set to [a]. *)
val lighten : float -> t -> t
(** [lighten amount c] is [c] with lightness increased by [amount], clamped to
\[0, 1\]. *)
val darken : float -> t -> t
(** [darken amount c] is [c] with lightness decreased by [amount], clamped to
\[0, 1\]. *)
val mix : float -> t -> t -> t
(** [mix ratio a b] is the perceptual blend of [a] and [b]. [ratio] is the
interpolation factor: [0.0] gives [a], [1.0] gives [b]. Hue is interpolated
along the shortest arc. *)
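(* Usage sketch (binding names are illustrative):
   {[
     (* halfway perceptual blend of two named palette colors *)
     let blend = Color.mix 0.5 Color.blue Color.orange
   ]} *)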
(** {1:named Named colors}
The default named colors follow the
{{:https://jfly.uni-koeln.de/color/}Okabe-Ito} palette, designed to be
distinguishable under all forms of color-vision deficiency. *)
val orange : t
val sky_blue : t
val green : t
val yellow : t
val blue : t
val vermillion : t
val purple : t
val black : t
val white : t
val gray : t
(** {1:fmt Formatting} *)
val pp : Format.formatter -> t -> unit
(** [pp] formats the color as [oklch(L C H / A)]. *)
================================================
FILE: packages/hugin/lib/dune
================================================
(library
(name hugin)
(public_name hugin)
(private_modules
axis
spec
scale
ticks
scene
prepared
resolve
image_util
cairo_sdl
cairo_backend
svg_backend
cmap_data)
(libraries nx nx.buffer ucairo usdl))
================================================
FILE: packages/hugin/lib/hugin.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
module Color = Color
module Cmap = Cmap
module Theme = Theme
type t = Spec.t
type marker = Spec.marker = Circle | Square | Triangle | Plus | Star
type legend_loc = Spec.legend_loc =
| Upper_right
| Upper_left
| Lower_right
| Lower_left
| Center
| Right
| Upper_center
| Lower_center
type line_style = Spec.line_style
type scale = Spec.scale
type stretch = Spec.stretch
(* Mark constructors *)
let line = Spec.line
let point = Spec.point
let bar = Spec.bar
let hist = Spec.hist
let image = Spec.image
let text = Spec.text
let hline = Spec.hline
let vline = Spec.vline
let abline = Spec.abline
let fill_between = Spec.fill_between
let hspan = Spec.hspan
let vspan = Spec.vspan
let errorbar = Spec.errorbar
let heatmap = Spec.heatmap
let imshow = Spec.imshow
let contour = Spec.contour
(* Composition *)
let layers = Spec.layers
(* Decorations *)
let title = Spec.title
let xlabel = Spec.xlabel
let ylabel = Spec.ylabel
let xlim = Spec.xlim
let ylim = Spec.ylim
let xscale = Spec.xscale
let yscale = Spec.yscale
let xinvert = Spec.xinvert
let yinvert = Spec.yinvert
let grid_lines = Spec.grid_lines
let legend = Spec.legend
let xticks = Spec.xticks
let yticks = Spec.yticks
let with_theme = Spec.with_theme
let xtick_format = Spec.xtick_format
let ytick_format = Spec.ytick_format
let frame = Spec.frame
let no_axes = Spec.no_axes
(* Layout *)
let grid = Spec.grid_layout
let hstack ?gap specs = Spec.grid_layout ?gap [ specs ]
let vstack ?gap specs = Spec.grid_layout ?gap (List.map (fun s -> [ s ]) specs)
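(* E.g. [hstack [a; b; c]] is the one-row grid [[ [a; b; c] ]], and
   [vstack [a; b]] is the one-column grid [[ [a]; [b] ]]. *)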
(* Rendering *)
let default_width = 1600.
let default_height = 1200.
(* Use Cairo text measurement for all backends for consistent layout *)
let resolve_with_cairo ~theme ~width ~height spec =
let surface = Ucairo.Image.create ~w:1 ~h:1 in
let cr = Ucairo.create surface in
let tm = Cairo_backend.text_measurer cr in
let scene = Resolve.resolve ~text_measurer:tm ~theme ~width ~height spec in
Ucairo.Surface.finish surface;
scene
let show ?(theme = Theme.default) ?(width = default_width)
?(height = default_height) spec =
let prepared = Prepared.compile ~theme spec in
Cairo_backend.show_interactive ~theme ~width ~height prepared
let render_png ?(theme = Theme.default) ?(width = default_width)
?(height = default_height) filename spec =
let scene = resolve_with_cairo ~theme ~width ~height spec in
Cairo_backend.render_to_png filename ~width ~height scene
let render_pdf ?(theme = Theme.default) ?(width = default_width)
?(height = default_height) filename spec =
let scene = resolve_with_cairo ~theme ~width ~height spec in
Cairo_backend.render_to_pdf filename ~width ~height scene
let render_svg ?(theme = Theme.default) ?(width = default_width)
?(height = default_height) filename spec =
let scene = resolve_with_cairo ~theme ~width ~height spec in
Svg_backend.render_to_file filename scene
let render_svg_to_string ?(theme = Theme.default) ?(width = default_width)
?(height = default_height) spec =
let scene = resolve_with_cairo ~theme ~width ~height spec in
Svg_backend.render scene
let render_to_buffer ?(theme = Theme.default) ?(width = default_width)
?(height = default_height) spec =
let scene = resolve_with_cairo ~theme ~width ~height spec in
Cairo_backend.render_to_buffer ~width ~height scene
let infer_dimensions spec =
let rec grid_shape = function
| Spec.Grid { rows; _ } ->
let nrows = List.length rows in
let ncols =
List.fold_left (fun acc row -> max acc (List.length row)) 0 rows
in
Some (nrows, ncols)
| Spec.Decorated { inner; _ } -> grid_shape inner
| _ -> None
in
match grid_shape spec with
| Some (nrows, ncols) when ncols > 0 ->
let cell_w = default_width /. float ncols in
(default_width, cell_w *. float nrows)
| _ -> (default_width, default_height)
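(* Example: a 2-row x 4-column [Spec.Grid] yields (1600., 800.): cell width
   is 1600. /. 4. = 400., and height 400. *. 2. keeps each cell square. *)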
let pp fmt spec =
let width, height = infer_dimensions spec in
let buf = render_to_buffer ~width ~height spec in
let b64 = Image_util.base64_encode buf in
Format.fprintf fmt "![](data:image/png;base64,%s)" b64
================================================
FILE: packages/hugin/lib/hugin.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Declarative plotting and visualization.
Hugin turns immutable plot specifications into rendered output. A plot is a
value of type {!t} built from mark constructors ({!line}, {!point}, {!bar},
{!hist}), composed with {!layers}, decorated with {!title}, {!xlabel}, etc.
via the [|>] pipeline, and rendered with {!show}, {!render_png}, or
{!render_svg}.
{[
let x = Nx.linspace float32 0. 6.28 100 in
let y = Nx.sin x in
Hugin.line ~x ~y () |> Hugin.title "Sine wave"
|> Hugin.render_png "sine.png"
]} *)
(** {1:sub Sub-modules} *)
module Color = Color
(** Perceptually uniform OKLCH colors. *)
module Cmap = Cmap
(** Colormaps. *)
module Theme = Theme
(** Visual themes. *)
(** {1:types Types} *)
type t
(** The type for plot specifications. Immutable and composable. *)
type marker =
| Circle
| Square
| Triangle
| Plus
| Star (** The type for point marker shapes. *)
type legend_loc =
| Upper_right
| Upper_left
| Lower_right
| Lower_left
| Center
| Right
| Upper_center
| Lower_center (** The type for legend placement. *)
type line_style = [ `Solid | `Dashed | `Dotted | `Dash_dot ]
(** The type for line dash patterns. *)
type scale = [ `Linear | `Log | `Sqrt | `Asinh | `Symlog of float ]
(** The type for axis scales. [`Sqrt] and [`Asinh] handle zero gracefully.
[`Symlog linthresh] is linear within \[[-linthresh];[linthresh]\] and
logarithmic outside. *)
type stretch = [ `Linear | `Log | `Sqrt | `Asinh | `Power of float ]
(** The type for image stretch functions. [`Power a] raises normalized values to
the power [a]. *)
(** {1:marks Mark constructors}
Each constructor builds a single-layer specification from data arrays and
optional visual properties. A mark is already a valid {!t} that can be
rendered directly. *)
val line :
x:Nx.float32_t ->
y:Nx.float32_t ->
?color:Color.t ->
?line_width:float ->
?line_style:line_style ->
?step:[ `Pre | `Post | `Mid ] ->
?marker:marker ->
?label:string ->
?alpha:float ->
unit ->
t
(** [line ~x ~y ()] is a line plot connecting the points [(x.(i), y.(i))].
[color] defaults to the next color in the theme palette. [line_width]
defaults to the theme line width. [line_style] defaults to [`Solid]. [step]
draws a staircase line: [`Post] holds each value until the next x-point,
[`Pre] steps to the new value at the current x-point, [`Mid] steps at the
midpoint between consecutive x-points. *)
val point :
x:Nx.float32_t ->
y:Nx.float32_t ->
?color:Color.t ->
?color_by:Nx.float32_t ->
?size:float ->
?size_by:Nx.float32_t ->
?marker:marker ->
?label:string ->
?alpha:float ->
unit ->
t
(** [point ~x ~y ()] is a scatter plot of discrete markers at [(x.(i), y.(i))].
[color_by] maps per-point values through the theme's sequential colormap.
[size_by] scales marker area per point. [marker] defaults to {!Circle}. *)
val bar :
x:Nx.float32_t ->
height:Nx.float32_t ->
?width:float ->
?bottom:float ->
?color:Color.t ->
?label:string ->
?alpha:float ->
unit ->
t
(** [bar ~x ~height ()] is a bar chart with bars centered on [x] values,
extending from [bottom] (default [0.0]) to [bottom + height]. [width]
defaults to [0.8]. *)
val hist :
x:Nx.float32_t ->
?bins:[ `Num of int | `Edges of float array ] ->
?density:bool ->
?color:Color.t ->
?label:string ->
unit ->
t
(** [hist ~x ()] is a histogram of the values in [x].
[bins] defaults to [`Num 10]. When [density] is [true], the histogram is
normalized so the total area equals [1.0]. *)
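(* Usage sketch, where [samples] stands for any [Nx.float32_t] of draws:
    {[
      Hugin.hist ~x:samples ~bins:(`Num 30) ~density:true ()
      |> Hugin.render_png "hist.png"
    ]} *)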
val image : ?extent:float * float * float * float -> Nx.uint8_t -> t
(** [image ?extent data] displays [data] as an image. [data] has shape
[[|h; w; 3|]] (RGB) or [[|h; w; 4|]] (RGBA).
When [extent] is [(xmin, xmax, ymin, ymax)], the image is placed in data
coordinates. Without [extent], the image is centered in the plot area
preserving aspect ratio. *)
val text :
x:float ->
y:float ->
string ->
?color:Color.t ->
?font_size:float ->
unit ->
t
(** [text ~x ~y s ()] places the string [s] at data coordinates [(x, y)]. *)
val hline :
y:float ->
?color:Color.t ->
?line_width:float ->
?line_style:line_style ->
?label:string ->
?alpha:float ->
unit ->
t
(** [hline ~y ()] draws a horizontal reference line at [y] spanning the full
plot width. *)
val vline :
x:float ->
?color:Color.t ->
?line_width:float ->
?line_style:line_style ->
?label:string ->
?alpha:float ->
unit ->
t
(** [vline ~x ()] draws a vertical reference line at [x] spanning the full plot
height. *)
val abline :
slope:float ->
intercept:float ->
?color:Color.t ->
?line_width:float ->
?line_style:line_style ->
?label:string ->
?alpha:float ->
unit ->
t
(** [abline ~slope ~intercept ()] draws a diagonal line
[y = slope * x + intercept] spanning the full plot area. Useful for
regression lines and [y = x] references. *)
val fill_between :
x:Nx.float32_t ->
y1:Nx.float32_t ->
y2:Nx.float32_t ->
?where:Nx.float32_t ->
?color:Color.t ->
?alpha:float ->
?label:string ->
unit ->
t
(** [fill_between ~x ~y1 ~y2 ()] fills the area between curves [y1] and [y2]
over the shared [x] axis. [alpha] defaults to [0.3].
[where] is an optional mask array of the same length as [x]: the fill is
only drawn where [where.(i) > 0.], producing separate filled regions. *)
val hspan :
y0:float ->
y1:float ->
?color:Color.t ->
?alpha:float ->
?label:string ->
unit ->
t
(** [hspan ~y0 ~y1 ()] is a horizontal shaded band between [y0] and [y1],
spanning the full plot width. [alpha] defaults to [0.2]. *)
val vspan :
x0:float ->
x1:float ->
?color:Color.t ->
?alpha:float ->
?label:string ->
unit ->
t
(** [vspan ~x0 ~x1 ()] is a vertical shaded band between [x0] and [x1], spanning
the full plot height. [alpha] defaults to [0.2]. *)
val errorbar :
x:Nx.float32_t ->
y:Nx.float32_t ->
yerr:
[ `Symmetric of Nx.float32_t | `Asymmetric of Nx.float32_t * Nx.float32_t ] ->
?xerr:
[ `Symmetric of Nx.float32_t | `Asymmetric of Nx.float32_t * Nx.float32_t ] ->
?color:Color.t ->
?line_width:float ->
?cap_size:float ->
?label:string ->
?alpha:float ->
unit ->
t
(** [errorbar ~x ~y ~yerr ()] draws error bars at [(x.(i), y.(i))].
[yerr] specifies vertical error: [`Symmetric e] draws [y +/- e],
[`Asymmetric (lo, hi)] draws [[y - lo, y + hi]]. [xerr] adds horizontal
error bars. [cap_size] defaults to half the theme marker size. *)
val heatmap :
data:Nx.float32_t ->
?annotate:bool ->
?cmap:Cmap.t ->
?vmin:float ->
?vmax:float ->
?fmt:(float -> string) ->
unit ->
t
(** [heatmap ~data ()] displays a 2D array as a grid of colored cells. [data]
has shape [[|rows; cols|]]. Row 0 appears at the top.
[cmap] defaults to the theme's sequential colormap. [vmin] and [vmax]
override the automatic value range. When [annotate] is [true], each cell
shows its value formatted by [fmt] (default [Printf.sprintf "%.2g"]). *)
val imshow :
data:Nx.float32_t ->
?stretch:stretch ->
?cmap:Cmap.t ->
?vmin:float ->
?vmax:float ->
unit ->
t
(** [imshow ~data ()] displays a 2D float array as a colormapped image. [data]
has shape [[|rows; cols|]].
[stretch] controls the transfer function applied before colormap lookup:
[`Linear] (default), [`Log], [`Sqrt], [`Asinh], or [`Power a]. [cmap]
defaults to the theme's sequential colormap. [vmin] and [vmax] override the
automatic value range. *)
val contour :
data:Nx.float32_t ->
x0:float ->
x1:float ->
y0:float ->
y1:float ->
?levels:[ `Num of int | `Values of float array ] ->
?filled:bool ->
?cmap:Cmap.t ->
?color:Color.t ->
?line_width:float ->
?label:string ->
?alpha:float ->
unit ->
t
(** [contour ~data ~x0 ~x1 ~y0 ~y1 ()] draws iso-level contour lines through the
2D grid [data] of shape [[|rows; cols|]], mapped to the data-space rectangle
\[[x0];[x1]\] x \[[y0];[y1]\].
[levels] defaults to [`Num 8]. When [filled] is [true], regions between
adjacent levels are filled. [color] sets a single stroke color for unfilled
contours; [cmap] assigns per-level colors from the theme's sequential
colormap. *)
(** {1:composition Composition} *)
val layers : t list -> t
(** [layers marks] overlays [marks] on shared axes. A single mark is already a
valid {!t}; [layers] is only needed to combine multiple marks into one plot.
*)
(** {1:decorations Decorations}
Decoration functions add metadata to a specification. They are designed for
the [|>] pipeline:
{[
line ~x ~y () |> title "My Plot" |> xlabel "Time"
]} *)
val title : string -> t -> t
(** [title s t] is [t] with plot title [s]. *)
val xlabel : string -> t -> t
(** [xlabel s t] is [t] with x-axis label [s]. *)
val ylabel : string -> t -> t
(** [ylabel s t] is [t] with y-axis label [s]. *)
val xlim : float -> float -> t -> t
(** [xlim lo hi t] is [t] with x-axis range fixed to \[[lo];[hi]\]. *)
val ylim : float -> float -> t -> t
(** [ylim lo hi t] is [t] with y-axis range fixed to \[[lo];[hi]\]. *)
val xscale : scale -> t -> t
(** [xscale s t] is [t] with x-axis scale [s]. Defaults to [`Linear].
[`Sqrt] and [`Asinh] handle zero gracefully. [`Symlog linthresh] is linear
within \[[-linthresh];[linthresh]\] and logarithmic outside. *)
val yscale : scale -> t -> t
(** [yscale s t] is [t] with y-axis scale [s]. Defaults to [`Linear]. *)
val xinvert : t -> t
(** [xinvert t] is [t] with the x-axis inverted (values increase right-to-left).
Useful for right ascension in sky charts. *)
val yinvert : t -> t
(** [yinvert t] is [t] with the y-axis inverted (values increase top-to-bottom).
Useful for magnitude axes in HR diagrams. *)
val grid_lines : bool -> t -> t
(** [grid_lines visible t] is [t] with grid lines shown or hidden. *)
val legend : ?loc:legend_loc -> ?ncol:int -> t -> t
(** [legend ?loc ?ncol t] is [t] with the legend shown at [loc]. [loc] defaults
to {!Upper_right}. [ncol] defaults to [1]; set higher for multi-column
layouts with many series. The legend is automatically visible when any mark
has a [~label]. *)
val xticks : (float * string) list -> t -> t
(** [xticks ticks t] is [t] with explicit x-axis tick positions and labels.
Overrides auto-generated ticks. *)
val yticks : (float * string) list -> t -> t
(** [yticks ticks t] is [t] with explicit y-axis tick positions and labels.
Overrides auto-generated ticks. *)
val with_theme : Theme.t -> t -> t
(** [with_theme th t] is [t] rendered with theme [th] instead of the default. *)
val xtick_format : (float -> string) -> t -> t
(** [xtick_format fmt t] is [t] with x-axis tick labels formatted by [fmt].
Overrides auto-generated labels while preserving tick positions. *)
val ytick_format : (float -> string) -> t -> t
(** [ytick_format fmt t] is [t] with y-axis tick labels formatted by [fmt].
Overrides auto-generated labels while preserving tick positions. *)
val frame : bool -> t -> t
(** [frame visible t] is [t] with the axis border rectangle shown or hidden.
[visible] defaults to [true]. *)
val no_axes : t -> t
(** [no_axes t] hides the axis frame, ticks, and tick labels. Title is
preserved. The full panel area is used for marks. Useful for image grids:
{[
List.init 10 (fun i ->
Hugin.imshow ~data:digits.(i) ~cmap:Cmap.gray ()
|> Hugin.title (string_of_int labels.(i))
|> Hugin.no_axes)
|> Hugin.hstack
]} *)
(** {1:layout Layout} *)
val grid : ?gap:float -> t list list -> t
(** [grid rows] arranges specifications in a grid. Each inner list is a row of
panels. [gap] defaults to [0.05] (fraction of total size). *)
val hstack : ?gap:float -> t list -> t
(** [hstack specs] arranges [specs] in a single row. *)
val vstack : ?gap:float -> t list -> t
(** [vstack specs] arranges [specs] in a single column. *)
(** {1:rendering Rendering} *)
val show : ?theme:Theme.t -> ?width:float -> ?height:float -> t -> unit
(** [show t] displays [t] in an interactive SDL window.
[width] defaults to [1600.0]. [height] defaults to [1200.0]. The window
supports resize (re-resolves at new dimensions) and closes on Escape or Q.
*)
val render_png :
?theme:Theme.t -> ?width:float -> ?height:float -> string -> t -> unit
(** [render_png filename t] writes [t] as a PNG image to [filename].
[width] defaults to [1600.0]. [height] defaults to [1200.0]. *)
val render_pdf :
?theme:Theme.t -> ?width:float -> ?height:float -> string -> t -> unit
(** [render_pdf filename t] writes [t] as a PDF document to [filename].
[width] defaults to [1600.0]. [height] defaults to [1200.0]. *)
val render_svg :
?theme:Theme.t -> ?width:float -> ?height:float -> string -> t -> unit
(** [render_svg filename t] writes [t] as an SVG document to [filename].
[width] defaults to [1600.0]. [height] defaults to [1200.0]. *)
val render_svg_to_string :
?theme:Theme.t -> ?width:float -> ?height:float -> t -> string
(** [render_svg_to_string t] is [t] rendered as an SVG document string.
[width] defaults to [1600.0]. [height] defaults to [1200.0]. *)
val render_to_buffer :
?theme:Theme.t -> ?width:float -> ?height:float -> t -> string
(** [render_to_buffer t] is [t] rendered as a PNG image, returned as a string of
bytes. *)
(** {1:fmt Formatting} *)
val pp : Format.formatter -> t -> unit
(** [pp] renders the specification as a PNG data URI. Intended for use with
[#install_printer] in the toplevel and Quill.
Output format: [![](data:image/png;base64,...)] *)
================================================
FILE: packages/hugin/lib/image_util.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Shared image encoding utilities *)
(* Base64 *)
let base64_alphabet =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
let base64_encode input =
let len = String.length input in
let out_len = (len + 2) / 3 * 4 in
let out = Bytes.create out_len in
let rec loop i j =
if i < len then begin
let b0 = Char.code (String.unsafe_get input i) in
let b1 =
if i + 1 < len then Char.code (String.unsafe_get input (i + 1)) else 0
in
let b2 =
if i + 2 < len then Char.code (String.unsafe_get input (i + 2)) else 0
in
Bytes.unsafe_set out j (String.unsafe_get base64_alphabet (b0 lsr 2));
Bytes.unsafe_set out (j + 1)
(String.unsafe_get base64_alphabet (((b0 land 3) lsl 4) lor (b1 lsr 4)));
Bytes.unsafe_set out (j + 2)
(if i + 1 < len then
String.unsafe_get base64_alphabet
(((b1 land 0xf) lsl 2) lor (b2 lsr 6))
else '=');
Bytes.unsafe_set out (j + 3)
(if i + 2 < len then String.unsafe_get base64_alphabet (b2 land 0x3f)
else '=');
loop (i + 3) (j + 4)
end
in
loop 0 0;
Bytes.unsafe_to_string out
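(* Standard RFC 4648 base64 with '=' padding, e.g.:
   base64_encode "Man" = "TWFu"
   base64_encode "Ma" = "TWE=" (one pad char)
   base64_encode "M" = "TQ==" (two pad chars) *)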
(* Nx uint8 image -> Cairo ARGB32 surface *)
let nx_to_cairo_surface (data : Nx.uint8_t) =
let shape = Nx.shape data in
let img_h = shape.(0) and img_w = shape.(1) in
let channels = if Array.length shape > 2 then shape.(2) else 1 in
let stride = Ucairo.Image.stride_for_width img_w in
let data_arr =
Bigarray.Array1.create Bigarray.int8_unsigned Bigarray.c_layout
(stride * img_h)
in
let buf = Nx.data data in
let base = Nx.offset data in
let strides = Nx.strides data in
(* uint8: byte strides = element strides *)
let s0 = strides.(0) and s1 = strides.(1) in
let s2 = if Array.length strides > 2 then strides.(2) else 0 in
for row = 0 to img_h - 1 do
let row_base = base + (row * s0) in
for col = 0 to img_w - 1 do
let off = (row * stride) + (col * 4) in
let idx = row_base + (col * s1) in
let r = Nx_buffer.unsafe_get buf idx in
let g = Nx_buffer.unsafe_get buf (idx + s2) in
let b = Nx_buffer.unsafe_get buf (idx + (2 * s2)) in
let a =
if channels >= 4 then Nx_buffer.unsafe_get buf (idx + (3 * s2)) else 255
in
(* Cairo ARGB32: premultiplied BGRA in memory on little-endian *)
let premul c a = c * a / 255 in
Bigarray.Array1.unsafe_set data_arr off (premul b a);
Bigarray.Array1.unsafe_set data_arr (off + 1) (premul g a);
Bigarray.Array1.unsafe_set data_arr (off + 2) (premul r a);
Bigarray.Array1.unsafe_set data_arr (off + 3) a
done
done;
Ucairo.Image.create_for_data8 data_arr ~w:img_w ~h:img_h ~stride
let nx_to_png_base64 data =
let surface = nx_to_cairo_surface data in
let png_buf = Buffer.create 4096 in
Ucairo.Png.write_to_stream surface (Buffer.add_string png_buf);
Ucairo.Surface.finish surface;
base64_encode (Buffer.contents png_buf)
================================================
FILE: packages/hugin/lib/image_util.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Image encoding utilities.
{b Internal module.} Shared base64 encoding and Nx-to-Cairo-surface
conversion used by both rendering backends. *)
(** {1:base64 Base64} *)
val base64_encode : string -> string
(** [base64_encode s] is the base64 encoding of [s]. *)
(** {1:surface Cairo surface conversion} *)
val nx_to_cairo_surface : Nx.uint8_t -> Ucairo.surface
(** [nx_to_cairo_surface data] is a Cairo ARGB32 image surface from [data].
[data] has shape [[|h; w; 3|]] (RGB) or [[|h; w; 4|]] (RGBA). *)
val nx_to_png_base64 : Nx.uint8_t -> string
(** [nx_to_png_base64 data] is the base64-encoded PNG of [data]. *)
================================================
FILE: packages/hugin/lib/prepared.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Data-only compilation: Spec.t -> Prepared.t
Compiles once per dataset. Separates data-dependent work (collecting
decorations, histogram binning, auto-coloring, data bounds) from
layout-dependent work (pixel coordinates, text measurement) which lives in
Resolve. *)
(* Data bounds *)
let nx_finite_range (arr : Nx.float32_t) =
let n = (Nx.shape arr).(0) in
let lo = ref Float.infinity and hi = ref Float.neg_infinity in
for i = 0 to n - 1 do
let v = Nx.item [ i ] arr in
if Float.is_finite v then begin
if v < !lo then lo := v;
if v > !hi then hi := v
end
done;
(!lo, !hi)
let expand_range scale lo hi =
if lo = hi then (lo -. 1., hi +. 1.)
else
match scale with
| `Log ->
let lo_log = Float.log10 (Float.max 1e-10 lo) in
let hi_log = Float.log10 (Float.max 1e-10 hi) in
let pad = (hi_log -. lo_log) *. 0.05 in
(Float.pow 10. (lo_log -. pad), Float.pow 10. (hi_log +. pad))
| `Sqrt ->
let lo = Float.max 0. lo in
let pad = (hi -. lo) *. 0.05 in
(Float.max 0. (lo -. pad), hi +. pad)
| `Asinh | `Symlog _ | `Linear ->
let pad = (hi -. lo) *. 0.05 in
(lo -. pad, hi +. pad)
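(* E.g. a linear range [0, 10] expands to (-0.5, 10.5) (5% pad per side),
   while a degenerate range [v, v] becomes (v -. 1., v +. 1.) regardless of
   scale, since that case is handled before the scale match. *)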
let mark_x_range = function
| Spec.Line { x; _ } | Spec.Point { x; _ } -> Some (nx_finite_range x)
| Spec.Bar { x; width; _ } ->
let lo, hi = nx_finite_range x in
let w = (match width with Some w -> w | None -> 0.8) /. 2. in
Some (lo -. w, hi +. w)
| Spec.Hist { x; _ } -> Some (nx_finite_range x)
| Spec.Image { extent = Some (xmin, xmax, _, _); _ } ->
Some (Float.min xmin xmax, Float.max xmin xmax)
| Spec.Image _ -> None
| Spec.Text_mark { x; _ } -> Some (x, x)
| Spec.Hline _ -> None
| Spec.Vline { x; _ } -> Some (x, x)
| Spec.Abline _ -> None
| Spec.Fill_between { x; _ } -> Some (nx_finite_range x)
| Spec.Errorbar { x; xerr; _ } ->
let lo, hi = nx_finite_range x in
let lo, hi =
match xerr with
| Some (`Symmetric e) ->
let _, emax = nx_finite_range e in
(lo -. emax, hi +. emax)
| Some (`Asymmetric (elo, ehi)) ->
let _, emlo = nx_finite_range elo in
let _, emhi = nx_finite_range ehi in
(lo -. emlo, hi +. emhi)
| None -> (lo, hi)
in
Some (lo, hi)
| Spec.Hspan _ -> None
| Spec.Vspan { x0; x1; _ } -> Some (Float.min x0 x1, Float.max x0 x1)
| Spec.Heatmap { data; _ } ->
let shape = Nx.shape data in
let cols = float shape.(1) in
Some (0., cols)
| Spec.Imshow _ -> None
| Spec.Contour { x0; x1; _ } -> Some (Float.min x0 x1, Float.max x0 x1)
let mark_y_range = function
| Spec.Line { y; _ } | Spec.Point { y; _ } -> Some (nx_finite_range y)
| Spec.Bar { height; bottom; _ } ->
let lo, hi = nx_finite_range height in
Some (Float.min bottom (bottom +. lo), Float.max bottom (bottom +. hi))
| Spec.Hist _ -> None
| Spec.Image { extent = Some (_, _, ymin, ymax); _ } ->
Some (Float.min ymin ymax, Float.max ymin ymax)
| Spec.Image _ -> None
| Spec.Text_mark { y; _ } -> Some (y, y)
| Spec.Hline { y; _ } -> Some (y, y)
| Spec.Vline _ -> None
| Spec.Abline _ -> None
| Spec.Fill_between { y1; y2; _ } ->
let lo1, hi1 = nx_finite_range y1 in
let lo2, hi2 = nx_finite_range y2 in
Some (Float.min lo1 lo2, Float.max hi1 hi2)
| Spec.Errorbar { y; yerr; _ } ->
let lo, hi = nx_finite_range y in
let lo, hi =
match yerr with
| `Symmetric e ->
let _, emax = nx_finite_range e in
(lo -. emax, hi +. emax)
| `Asymmetric (elo, ehi) ->
let _, emlo = nx_finite_range elo in
let _, emhi = nx_finite_range ehi in
(lo -. emlo, hi +. emhi)
in
Some (lo, hi)
| Spec.Hspan { y0; y1; _ } -> Some (Float.min y0 y1, Float.max y0 y1)
| Spec.Vspan _ -> None
| Spec.Heatmap { data; _ } ->
let shape = Nx.shape data in
let rows = float shape.(0) in
Some (0., rows)
| Spec.Imshow _ -> None
| Spec.Contour { y0; y1; _ } -> Some (Float.min y0 y1, Float.max y0 y1)
let union_range a b =
match (a, b) with
| None, x | x, None -> x
| Some (a0, a1), Some (b0, b1) -> Some (Float.min a0 b0, Float.max a1 b1)
let compute_data_bounds ~xscale ~yscale marks =
let xr =
List.fold_left (fun acc m -> union_range acc (mark_x_range m)) None marks
in
let yr =
List.fold_left (fun acc m -> union_range acc (mark_y_range m)) None marks
in
let xlo, xhi =
match xr with Some (a, b) -> expand_range xscale a b | None -> (0., 1.)
in
let ylo, yhi =
match yr with Some (a, b) -> expand_range yscale a b | None -> (0., 1.)
in
(xlo, xhi, ylo, yhi)
(* Collect decorations from spec tree *)
type collected = {
marks : Spec.mark list;
x : Axis.config;
y : Axis.config;
title : string option;
grid_visible : bool option;
frame_visible : bool option;
legend_loc : Spec.legend_loc option;
legend_ncol : int;
theme_override : Theme.t option;
}
let empty_collected =
{
marks = [];
x = Axis.empty_config;
y = Axis.empty_config;
title = None;
grid_visible = None;
frame_visible = None;
legend_loc = None;
legend_ncol = 1;
theme_override = None;
}
let rec collect c = function
| Spec.Mark m -> { c with marks = m :: c.marks }
| Spec.Layers ts -> List.fold_left collect c ts
| Spec.Decorated { inner; decorations } ->
let c = collect c inner in
List.fold_left apply_decoration c decorations
| Spec.Grid _ -> c
and apply_decoration c = function
| Spec.Title s when c.title = None -> { c with title = Some s }
| Spec.Xlabel s when c.x.label = None ->
{ c with x = { c.x with label = Some s } }
| Spec.Ylabel s when c.y.label = None ->
{ c with y = { c.y with label = Some s } }
| Spec.Xlim (lo, hi) when c.x.lim = None ->
{ c with x = { c.x with lim = Some (lo, hi) } }
| Spec.Ylim (lo, hi) when c.y.lim = None ->
{ c with y = { c.y with lim = Some (lo, hi) } }
| Spec.Xscale s when c.x.scale = None ->
{ c with x = { c.x with scale = Some s } }
| Spec.Yscale s when c.y.scale = None ->
{ c with y = { c.y with scale = Some s } }
| Spec.Xinvert -> { c with x = { c.x with invert = true } }
| Spec.Yinvert -> { c with y = { c.y with invert = true } }
| Spec.Grid_visible v when c.grid_visible = None ->
{ c with grid_visible = Some v }
| Spec.Legend (loc, ncol) when c.legend_loc = None ->
{ c with legend_loc = Some loc; legend_ncol = ncol }
| Spec.Xticks t when c.x.ticks = None ->
{ c with x = { c.x with ticks = Some t } }
| Spec.Yticks t when c.y.ticks = None ->
{ c with y = { c.y with ticks = Some t } }
| Spec.With_theme t when c.theme_override = None ->
{ c with theme_override = Some t }
| Spec.Xtick_format f when c.x.tick_format = None ->
{ c with x = { c.x with tick_format = Some f } }
| Spec.Ytick_format f when c.y.tick_format = None ->
{ c with y = { c.y with tick_format = Some f } }
| Spec.Frame v when c.frame_visible = None ->
{ c with frame_visible = Some v }
| _ -> c
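(* Decoration collection is first-wins: [apply_decoration] is folded left, and
   each guarded case fires only while the field is still [None]. For example
   (hypothetical decorations), [ Title "a"; Title "b" ] collects the title
   "a", since the second [Title] is ignored once [c.title] is set.
   [Xinvert]/[Yinvert] have no guard and always apply (setting [invert] is
   idempotent). *)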
(* Auto-coloring *)
let mark_color = function
| Spec.Line { color; _ }
| Spec.Point { color; _ }
| Spec.Bar { color; _ }
| Spec.Hist { color; _ }
| Spec.Text_mark { color; _ }
| Spec.Hline { color; _ }
| Spec.Vline { color; _ }
| Spec.Abline { color; _ }
| Spec.Fill_between { color; _ }
| Spec.Hspan { color; _ }
| Spec.Vspan { color; _ }
| Spec.Errorbar { color; _ }
| Spec.Contour { color; _ } ->
color
| Spec.Image _ | Spec.Heatmap _ | Spec.Imshow _ -> None
let auto_color (theme : Theme.t) marks =
let n_palette = Array.length theme.palette in
List.mapi
(fun i m ->
match mark_color m with
| Some _ -> m
| None -> (
let c = theme.palette.(i mod n_palette) in
match m with
| Spec.Line r -> Spec.Line { r with color = Some c }
| Spec.Point r -> Spec.Point { r with color = Some c }
| Spec.Bar r -> Spec.Bar { r with color = Some c }
| Spec.Hist r -> Spec.Hist { r with color = Some c }
| Spec.Hline r -> Spec.Hline { r with color = Some c }
| Spec.Vline r -> Spec.Vline { r with color = Some c }
| Spec.Abline r -> Spec.Abline { r with color = Some c }
| Spec.Fill_between r -> Spec.Fill_between { r with color = Some c }
| Spec.Hspan r -> Spec.Hspan { r with color = Some c }
| Spec.Vspan r -> Spec.Vspan { r with color = Some c }
| Spec.Errorbar r -> Spec.Errorbar { r with color = Some c }
| Spec.Contour r -> Spec.Contour { r with color = Some c }
| m -> m))
marks
(* Histogram normalization — convert Hist to Bar *)
let normalize_hist marks =
List.map
(fun m ->
match m with
| Spec.Hist { x; bins; density; color; label } ->
let xmin, xmax = nx_finite_range x in
let edges =
match bins with
| `Num num_bins ->
Array.init (num_bins + 1) (fun i ->
xmin +. ((xmax -. xmin) *. float i /. float num_bins))
| `Edges e -> e
in
let num_bins = Array.length edges - 1 in
let n = (Nx.shape x).(0) in
let counts = Array.make num_bins 0. in
let binned = ref 0 in
for i = 0 to n - 1 do
let v = Nx.item [ i ] x in
if Float.is_finite v && v >= edges.(0) && v <= edges.(num_bins) then begin
incr binned;
let bin = ref 0 in
while !bin < num_bins - 1 && v >= edges.(!bin + 1) do
incr bin
done;
counts.(!bin) <- counts.(!bin) +. 1.
end
done;
if density then begin
let total =
let b = float !binned in
if b = 0. then 1. else b
in
for i = 0 to num_bins - 1 do
let w = edges.(i + 1) -. edges.(i) in
counts.(i) <- counts.(i) /. (total *. w)
done
end;
let bar_x =
Nx.init Float32 [| num_bins |] (fun idx ->
let i = idx.(0) in
(edges.(i) +. edges.(i + 1)) /. 2.)
in
let bar_h =
Nx.init Float32 [| num_bins |] (fun idx -> counts.(idx.(0)))
in
(* Bar width taken from the first bin; assumes uniform bin widths. *)
let w = if num_bins > 0 then edges.(1) -. edges.(0) else 1. in
Spec.Bar
{
x = bar_x;
height = bar_h;
width = Some w;
bottom = 0.;
color;
label;
alpha = None;
}
| m -> m)
marks
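(* Binning sketch (hypothetical input): with x = [| 0.; 0.5; 1. |] and
   bins = `Num 2, the edges are [| 0.; 0.5; 1. |]. A value equal to an
   interior edge (here 0.5) lands in the bin to its right, and the top edge
   is inclusive, so the resulting Bar mark has centers [| 0.25; 0.75 |] and
   heights [| 1.; 2. |]. *)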
(* Guide ranges *)
let color_by_range marks =
List.fold_left
(fun acc m ->
match m with
| Spec.Point { color_by = Some cb; _ } ->
let lo, hi = nx_finite_range cb in
union_range acc (Some (lo, hi))
| _ -> acc)
None marks
let size_by_range marks =
List.fold_left
(fun acc m ->
match m with
| Spec.Point { size_by = Some sb; _ } ->
let lo, hi = nx_finite_range sb in
union_range acc (Some (lo, hi))
| _ -> acc)
None marks
(* Collect marks from all panels in a spec tree *)
let rec collect_all_marks = function
| Spec.Mark m -> [ m ]
| Spec.Layers ts -> List.concat_map collect_all_marks ts
| Spec.Decorated { inner; _ } -> collect_all_marks inner
| Spec.Grid { rows; _ } ->
List.concat_map (List.concat_map collect_all_marks) rows
(* Grid-level decorations *)
type grid_decorations = {
gd_title : string option;
gd_xlabel : string option;
gd_ylabel : string option;
gd_legend_loc : Spec.legend_loc option;
gd_legend_ncol : int;
gd_theme_override : Theme.t option;
}
let extract_grid_decorations decorations =
let d =
{
gd_title = None;
gd_xlabel = None;
gd_ylabel = None;
gd_legend_loc = None;
gd_legend_ncol = 1;
gd_theme_override = None;
}
in
List.fold_left
(fun d dec ->
match dec with
| Spec.Title s when d.gd_title = None -> { d with gd_title = Some s }
| Spec.Xlabel s when d.gd_xlabel = None -> { d with gd_xlabel = Some s }
| Spec.Ylabel s when d.gd_ylabel = None -> { d with gd_ylabel = Some s }
| Spec.Legend (loc, ncol) when d.gd_legend_loc = None ->
{ d with gd_legend_loc = Some loc; gd_legend_ncol = ncol }
| Spec.With_theme t when d.gd_theme_override = None ->
{ d with gd_theme_override = Some t }
| _ -> d)
d decorations
(* Prepared panel — all data-only work done *)
type panel = {
marks : Spec.mark list;
x : Axis.t;
y : Axis.t;
title : string option;
legend_loc : Spec.legend_loc option;
legend_ncol : int;
grid_visible : bool option;
frame_visible : bool option;
theme_override : Theme.t option;
colorbar_range : (float * float) option;
size_by_range : (float * float) option;
}
type t =
| Panel of panel
| Grid of { rows : t list list; gap : float }
| Decorated_grid of {
decorations : grid_decorations;
inner : t;
all_marks : Spec.mark list;
}
(* Imshow: rasterize float32 data to uint8 RGB via stretch + colormap *)
let apply_stretch stretch v =
match stretch with
| `Linear -> v
| `Log -> Float.log10 (1. +. (9. *. v))
| `Sqrt -> Float.sqrt (Float.max 0. v)
| `Asinh ->
let a = 10. in
Float.asinh (a *. v) /. Float.asinh a
| `Power a -> Float.pow (Float.max 0. v) a
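(* Contract sketch: each stretch expects [v] already clamped to [0, 1] and
   maps it back into [0, 1]; for example, [apply_stretch `Sqrt 0.25 = 0.5].
   [`Log] and [`Asinh] likewise fix 0 and 1 and boost low values,
   compressing the upper end of the range. *)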
let rasterize_imshow ~stretch ~cmap ~vmin ~vmax (data : Nx.float32_t) =
let shape = Nx.shape data in
let rows = shape.(0) and cols = shape.(1) in
let lo = ref Float.infinity and hi = ref Float.neg_infinity in
for r = 0 to rows - 1 do
for c = 0 to cols - 1 do
let v = Nx.item [ r; c ] data in
if Float.is_finite v then begin
if v < !lo then lo := v;
if v > !hi then hi := v
end
done
done;
let vlo = match vmin with Some v -> v | None -> !lo in
let vhi = match vmax with Some v -> v | None -> !hi in
let vrange = if vhi = vlo then 1. else vhi -. vlo in
let rgb = Nx.zeros Nx.uint8 [| rows; cols; 3 |] in
for r = 0 to rows - 1 do
for c = 0 to cols - 1 do
let v = Nx.item [ r; c ] data in
let t = Float.max 0. (Float.min 1. ((v -. vlo) /. vrange)) in
let t = apply_stretch stretch t in
let t = Float.max 0. (Float.min 1. t) in
let color = Cmap.eval cmap t in
let cr, cg, cb, _ = Color.to_rgba color in
Nx.set_item [ r; c; 0 ] (int_of_float (cr *. 255.)) rgb;
Nx.set_item [ r; c; 1 ] (int_of_float (cg *. 255.)) rgb;
Nx.set_item [ r; c; 2 ] (int_of_float (cb *. 255.)) rgb
done
done;
rgb
let normalize_imshow (theme : Theme.t) marks =
List.map
(fun m ->
match m with
| Spec.Imshow { data; stretch; cmap; vmin; vmax } ->
let cmap = match cmap with Some c -> c | None -> theme.sequential in
let rgb = rasterize_imshow ~stretch ~cmap ~vmin ~vmax data in
Spec.Image { data = rgb; extent = None }
| m -> m)
marks
(* Contour tracing via marching squares *)
type contour_paths = { level : float; paths : (float * float) array list }
(* Join 2-point segments that share endpoints into connected polylines. Marching
squares produces one segment per cell edge crossing. Segments from adjacent
cells share exact floating-point endpoints (deterministic lerp), so we chain
them with exact equality via a hashtable. *)
let join_segments segments =
let n = List.length segments in
if n = 0 then []
else
let segs = Array.of_list segments in
let visited = Array.make n false in
let adj = Hashtbl.create (2 * n) in
Array.iteri
(fun i (a, b) ->
let add pt =
let cur = try Hashtbl.find adj pt with Not_found -> [] in
Hashtbl.replace adj pt (i :: cur)
in
add a;
add b)
segs;
let find_unvisited_neighbor pt =
match Hashtbl.find adj pt with
| exception Not_found -> None
| neighbors ->
let rec scan = function
| [] -> None
| j :: rest -> if visited.(j) then scan rest else Some j
in
scan neighbors
in
let chains = ref [] in
for start = 0 to n - 1 do
if not visited.(start) then begin
visited.(start) <- true;
let a0, b0 = segs.(start) in
(* front: backward extensions (cons'd, so in chain order). back: forward
extensions (cons'd, so reversed). *)
let front = ref [ a0 ] in
let back = ref [ b0 ] in
(* Extend forward from b0 *)
let cur = ref b0 in
let go = ref true in
while !go do
match find_unvisited_neighbor !cur with
| None -> go := false
| Some j ->
visited.(j) <- true;
let a, b = segs.(j) in
let next = if a = !cur then b else a in
back := next :: !back;
cur := next
done;
(* Extend backward from a0 *)
cur := a0;
go := true;
while !go do
match find_unvisited_neighbor !cur with
| None -> go := false
| Some j ->
visited.(j) <- true;
let a, b = segs.(j) in
let next = if a = !cur then b else a in
front := next :: !front;
cur := next
done;
(* front is in chain order; back is reversed *)
chains := Array.of_list (!front @ List.rev !back) :: !chains
end
done;
List.rev !chains
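(* Example (hypothetical segments): two segments sharing the endpoint
   (1., 0.) chain into one polyline:
     join_segments [ ((0., 0.), (1., 0.)); ((1., 0.), (1., 1.)) ]
     = [ [| (0., 0.); (1., 0.); (1., 1.) |] ]
   A closed loop of segments yields a single chain whose first and last
   points coincide. *)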
let trace_contours ~rows ~cols (data : Nx.float32_t) levels =
let get r c =
if r >= 0 && r < rows && c >= 0 && c < cols then Nx.item [ r; c ] data
else 0.
in
List.map
(fun level ->
let segments = ref [] in
for r = 0 to rows - 2 do
for c = 0 to cols - 2 do
let v00 = get r c in
let v10 = get r (c + 1) in
let v11 = get (r + 1) (c + 1) in
let v01 = get (r + 1) c in
let b0 = if v00 >= level then 1 else 0 in
let b1 = if v10 >= level then 1 else 0 in
let b2 = if v11 >= level then 1 else 0 in
let b3 = if v01 >= level then 1 else 0 in
let case = b0 lor (b1 lsl 1) lor (b2 lsl 2) lor (b3 lsl 3) in
let lerp va vb =
let d = vb -. va in
if Float.abs d < 1e-30 then 0.5 else (level -. va) /. d
in
let fc = float c and fr = float r in
let top = (fc +. lerp v00 v10, fr) in
let right = (fc +. 1., fr +. lerp v10 v11) in
let bottom = (fc +. lerp v01 v11, fr +. 1.) in
let left = (fc, fr +. lerp v00 v01) in
let add a b = segments := (a, b) :: !segments in
begin match case with
| 0 | 15 -> ()
| 1 | 14 -> add top left
| 2 | 13 -> add top right
| 3 | 12 -> add left right
| 4 | 11 -> add right bottom
| 5 ->
let center = (v00 +. v10 +. v11 +. v01) /. 4. in
if center >= level then begin
add top right;
add bottom left
end
else begin
add top left;
add bottom right
end
| 6 | 9 -> add top bottom
| 7 | 8 -> add bottom left
| 10 ->
let center = (v00 +. v10 +. v11 +. v01) /. 4. in
if center >= level then begin
add top left;
add bottom right
end
else begin
add top right;
add bottom left
end
| _ -> ()
end
done
done;
let paths = join_segments !segments in
{ level; paths })
levels
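(* Example (hypothetical 2x2 cell): with corner values
     v00 = 0.  v10 = 0.
     v01 = 0.  v11 = 1.
   and level = 0.5, only v11 is above the level (case 4), so the cell
   produces one segment from the right edge (1., 0.5) to the bottom edge
   (0.5, 1.), in grid coordinates. *)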
let prepare_contour ~x0 ~x1 ~y0 ~y1 ~data ~levels =
let shape = Nx.shape data in
let rows = shape.(0) and cols = shape.(1) in
let lo = ref Float.infinity and hi = ref Float.neg_infinity in
for r = 0 to rows - 1 do
for c = 0 to cols - 1 do
let v = Nx.item [ r; c ] data in
if Float.is_finite v then begin
if v < !lo then lo := v;
if v > !hi then hi := v
end
done
done;
let vlo = !lo and vhi = !hi in
let level_values =
match levels with
| `Values a -> Array.to_list a
| `Num n ->
let range = vhi -. vlo in
if range = 0. then [ vlo ]
else
List.init n (fun i ->
vlo +. (range *. (float (i + 1) /. float (n + 1))))
in
let contours = trace_contours ~rows ~cols data level_values in
(* Map grid coords to data coords *)
let xscale = (x1 -. x0) /. float (cols - 1) in
let yscale = (y1 -. y0) /. float (rows - 1) in
List.map
(fun cp ->
let paths =
List.map
(fun seg ->
Array.map
(fun (gc, gr) -> (x0 +. (gc *. xscale), y0 +. (gr *. yscale)))
seg)
cp.paths
in
{ cp with paths })
contours
(* Compile a spec tree into a prepared tree *)
let compile_panel theme spec =
let c = collect empty_collected spec in
let c = { c with marks = List.rev c.marks } in
let theme = Option.value ~default:theme c.theme_override in
let marks =
normalize_hist (normalize_imshow theme (auto_color theme c.marks))
in
let xscale = Option.value ~default:`Linear c.x.scale in
let yscale = Option.value ~default:`Linear c.y.scale in
let xlo, xhi, ylo, yhi = compute_data_bounds ~xscale ~yscale marks in
let x = Axis.resolve ~data_lo:xlo ~data_hi:xhi c.x in
let y = Axis.resolve ~data_lo:ylo ~data_hi:yhi c.y in
Panel
{
marks;
x;
y;
title = c.title;
legend_loc = c.legend_loc;
legend_ncol = c.legend_ncol;
grid_visible = c.grid_visible;
frame_visible = c.frame_visible;
theme_override = c.theme_override;
colorbar_range = color_by_range marks;
size_by_range = size_by_range marks;
}
let rec compile ~theme spec =
match spec with
| Spec.Grid { rows; gap } ->
let rows = List.map (List.map (compile ~theme)) rows in
Grid { rows; gap }
| Spec.Decorated { inner = Spec.Grid _ as g; decorations } ->
let gd = extract_grid_decorations decorations in
let theme = Option.value ~default:theme gd.gd_theme_override in
let all_marks = auto_color theme (collect_all_marks g) in
let inner = compile ~theme g in
Decorated_grid { decorations = gd; inner; all_marks }
| spec -> compile_panel theme spec
================================================
FILE: packages/hugin/lib/prepared.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Data-only compilation stage.
{b Internal module.} Compiles a {!Spec.t} tree into a {!t} tree with all
data-dependent work done: decoration collection, histogram binning,
auto-coloring, data-bound computation, imshow rasterization, contour
tracing, and guide-range detection.
The result is independent of output dimensions and can be resolved
repeatedly at different sizes by {!Resolve.resolve_prepared}. *)
(** {1:bounds Data bounds} *)
val nx_finite_range : Nx.float32_t -> float * float
(** [nx_finite_range arr] is [(lo, hi)] of the finite values in [arr]. *)
(** {1:marks Mark introspection} *)
val mark_color : Spec.mark -> Color.t option
(** [mark_color m] is the color of [m], if set. *)
(** {1:contour Contour tracing} *)
type contour_paths = { level : float; paths : (float * float) array list }
(** The type for traced contour paths at a single iso-level. Coordinates are in
data space. *)
val prepare_contour :
x0:float ->
x1:float ->
y0:float ->
y1:float ->
data:Nx.float32_t ->
levels:[ `Num of int | `Values of float array ] ->
contour_paths list
(** [prepare_contour ~x0 ~x1 ~y0 ~y1 ~data ~levels] traces contour paths through
[data] and maps grid coordinates to the data-space rectangle \[[x0];[x1]\] x
\[[y0];[y1]\]. *)
(** {1:panel Prepared panel} *)
type panel = {
marks : Spec.mark list;
x : Axis.t;
y : Axis.t;
title : string option;
legend_loc : Spec.legend_loc option;
legend_ncol : int;
grid_visible : bool option;
frame_visible : bool option;
theme_override : Theme.t option;
colorbar_range : (float * float) option;
size_by_range : (float * float) option;
}
(** The type for prepared panels. All data-only work is done: marks are
auto-colored and histograms normalized to bars, data bounds are computed,
and guide ranges are detected. *)
(** {1:grid Grid decorations} *)
type grid_decorations = {
gd_title : string option;
gd_xlabel : string option;
gd_ylabel : string option;
gd_legend_loc : Spec.legend_loc option;
gd_legend_ncol : int;
gd_theme_override : Theme.t option;
}
(** The type for grid-level decorations extracted from a decorated grid spec. *)
(** {1:tree Prepared tree} *)
type t =
| Panel of panel
| Grid of { rows : t list list; gap : float }
| Decorated_grid of {
decorations : grid_decorations;
inner : t;
all_marks : Spec.mark list;
}
(** The type for prepared spec trees. Mirrors {!Spec.t} structure with all
data-only work pre-computed. *)
(** {1:compile Compilation} *)
val compile : theme:Theme.t -> Spec.t -> t
(** [compile ~theme spec] is the prepared tree for [spec]. Collects decorations,
normalizes histograms, auto-colors marks, computes data bounds, and detects
colorbar/size-guide ranges. *)
================================================
FILE: packages/hugin/lib/resolve.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Prepared.t + Theme.t -> Scene.t resolution
Layout and pixel-coordinate work. All data-only processing (histogram
binning, auto-coloring, bounds) is done in Prepared. *)
type text_measurer = font:Theme.font -> string -> float * float
(* Region in device pixels *)
type region = { rx : float; ry : float; rw : float; rh : float }
(* Scale-aware coord transform *)
let data_to_pixel_x sx region v =
let u = sx.Scale.to_unit v in
region.rx +. (u *. region.rw)
let data_to_pixel_y sy region v =
let u = sy.Scale.to_unit v in
region.ry +. region.rh -. (u *. region.rh)
(* Dash pattern *)
let dash_of_style = function
| `Solid -> []
| `Dashed -> [ 6.; 4. ]
| `Dotted -> [ 2.; 3. ]
| `Dash_dot -> [ 6.; 3.; 2.; 3. ]
(* Resolution helpers *)
let resolve_color ?default_alpha color alpha =
let c = Option.value ~default:Color.black color in
match (alpha, default_alpha) with
| Some a, _ | None, Some a -> Color.with_alpha a c
| None, None -> c
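(* An explicit [alpha] takes priority over [default_alpha]: for example,
   [resolve_color ~default_alpha:0.3 (Some c) (Some 1.)] keeps the mark
   fully opaque, while [resolve_color ~default_alpha:0.3 (Some c) None]
   falls back to alpha 0.3. *)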
let resolve_line_width sf (theme : Theme.t) line_width =
Option.value ~default:theme.line_width line_width *. sf
let resolve_dash sf line_style =
let dash = match line_style with Some s -> dash_of_style s | None -> [] in
List.map (fun d -> d *. sf) dash
(* Emit mark primitives *)
let step_transform step points n =
match step with
| None -> points
| Some `Post ->
if n < 2 then points
else begin
let out = Array.make ((2 * n) - 1) (0., 0.) in
let k = ref 0 in
for i = 0 to n - 2 do
let px, py = points.(i) in
let px_next, _ = points.(i + 1) in
out.(!k) <- (px, py);
incr k;
out.(!k) <- (px_next, py);
incr k
done;
out.(!k) <- points.(n - 1);
Array.sub out 0 (!k + 1)
end
| Some `Pre ->
if n < 2 then points
else begin
let out = Array.make ((2 * n) - 1) (0., 0.) in
let k = ref 0 in
out.(!k) <- points.(0);
incr k;
for i = 1 to n - 1 do
let px_prev, _ = points.(i - 1) in
let px, py = points.(i) in
out.(!k) <- (px_prev, py);
incr k;
out.(!k) <- (px, py);
incr k
done;
Array.sub out 0 !k
end
| Some `Mid ->
if n < 2 then points
else begin
let out = Array.make ((3 * n) - 2) (0., 0.) in
let k = ref 0 in
for i = 0 to n - 2 do
let px, py = points.(i) in
let px_next, py_next = points.(i + 1) in
let mx = (px +. px_next) /. 2. in
out.(!k) <- (px, py);
incr k;
out.(!k) <- (mx, py);
incr k;
out.(!k) <- (mx, py_next);
incr k
done;
out.(!k) <- points.(n - 1);
Array.sub out 0 (!k + 1)
end
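(* Example (hypothetical points): a [`Post] step holds each y until the next
   x, so [| (0., 0.); (1., 1.) |] becomes
   [| (0., 0.); (1., 0.); (1., 1.) |]; [`Mid] instead switches y at the
   midpoint between consecutive x values. *)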
let emit_line_mark sx sy plot_area theme ~x ~y ~color ~line_width ~line_style
~step ~marker ~alpha =
let n = (Nx.shape x).(0) in
let color = resolve_color color alpha in
let sf = theme.Theme.scale_factor in
let lw = resolve_line_width sf theme line_width in
let scaled_dash = resolve_dash sf line_style in
(* Split line into finite-value segments *)
let segments = ref [] in
let current = ref [] in
let all_finite_points = ref [] in
for i = 0 to n - 1 do
let xv = Nx.item [ i ] x in
let yv = Nx.item [ i ] y in
if Float.is_finite xv && Float.is_finite yv then begin
let px = data_to_pixel_x sx plot_area xv in
let py = data_to_pixel_y sy plot_area yv in
let pt = (px, py) in
current := pt :: !current;
all_finite_points := pt :: !all_finite_points
end
else
match !current with
| [] -> ()
| _ ->
segments := Array.of_list (List.rev !current) :: !segments;
current := []
done;
(match !current with
| [] -> ()
| _ -> segments := Array.of_list (List.rev !current) :: !segments);
let segments = List.rev !segments in
let paths =
List.map
(fun points ->
let n_pts = Array.length points in
let points = step_transform step points n_pts in
Scene.Path
{
points;
close = false;
fill = None;
stroke = Some color;
line_width = lw;
dash = scaled_dash;
})
segments
in
match marker with
| Some shape ->
let finite_points = Array.of_list (List.rev !all_finite_points) in
let ms = theme.marker_size *. sf in
let markers =
Scene.Markers
{
points = finite_points;
shape;
size = ms;
sizes = None;
fill = Some color;
fills = None;
stroke = None;
}
in
paths @ [ markers ]
| None -> paths
let emit_point_mark sx sy plot_area theme ~x ~y ~color ~color_by ~size ~size_by
~marker ~alpha =
let n = (Nx.shape x).(0) in
let color = resolve_color color alpha in
let shape = Option.value ~default:Spec.Circle marker in
let ms =
(match size with Some s -> s | None -> theme.Theme.marker_size)
*. theme.scale_factor
in
(* Collect only finite points *)
let valid = Array.make n false in
let num_valid = ref 0 in
for i = 0 to n - 1 do
let xv = Nx.item [ i ] x in
let yv = Nx.item [ i ] y in
if Float.is_finite xv && Float.is_finite yv then begin
valid.(i) <- true;
incr num_valid
end
done;
let nv = !num_valid in
let points = Array.make nv (0., 0.) in
let vi = ref 0 in
for i = 0 to n - 1 do
if valid.(i) then begin
let px = data_to_pixel_x sx plot_area (Nx.item [ i ] x) in
let py = data_to_pixel_y sy plot_area (Nx.item [ i ] y) in
points.(!vi) <- (px, py);
incr vi
end
done;
let sizes =
match size_by with
| Some sb ->
let sb_lo, sb_hi = Prepared.nx_finite_range sb in
let sb_range = if sb_hi = sb_lo then 1. else sb_hi -. sb_lo in
let arr = Array.make nv ms in
let vi = ref 0 in
for i = 0 to n - 1 do
if valid.(i) then begin
let sv = (Nx.item [ i ] sb -. sb_lo) /. sb_range in
arr.(!vi) <- (ms *. 0.5) +. (ms *. Float.sqrt sv);
incr vi
end
done;
Some arr
| None -> None
in
let fills =
match color_by with
| Some cb ->
let cb_lo, cb_hi = Prepared.nx_finite_range cb in
let cb_range = if cb_hi = cb_lo then 1. else cb_hi -. cb_lo in
let arr = Array.make nv Color.black in
let vi = ref 0 in
for i = 0 to n - 1 do
if valid.(i) then begin
let cv = (Nx.item [ i ] cb -. cb_lo) /. cb_range in
let c = Cmap.eval theme.sequential cv in
arr.(!vi) <-
(match alpha with Some a -> Color.with_alpha a c | None -> c);
incr vi
end
done;
Some arr
| None -> None
in
let fill = if fills <> None then None else Some color in
let stroke = Some color in
[ Scene.Markers { points; shape; size = ms; sizes; fill; fills; stroke } ]
let emit_bar_mark sx sy plot_area theme ~x ~height ~width ~bottom ~color ~alpha
=
let n = (Nx.shape x).(0) in
let color = resolve_color color alpha in
let w = Option.value ~default:0.8 width in
let prims = ref [] in
for i = 0 to n - 1 do
let xi = Nx.item [ i ] x in
let hi = Nx.item [ i ] height in
if Float.is_finite xi && Float.is_finite hi then
let x0 = data_to_pixel_x sx plot_area (xi -. (w /. 2.)) in
let x1 = data_to_pixel_x sx plot_area (xi +. (w /. 2.)) in
let y0 = data_to_pixel_y sy plot_area bottom in
let y1 = data_to_pixel_y sy plot_area (bottom +. hi) in
let lx = Float.min x0 x1 and rx = Float.max x0 x1 in
let ty = Float.min y0 y1 and by = Float.max y0 y1 in
prims :=
Scene.Path
{
points = [| (lx, ty); (rx, ty); (rx, by); (lx, by) |];
close = true;
fill = Some color;
stroke = None;
line_width = 0.;
dash = [];
}
:: !prims
done;
List.rev !prims
let emit_image_mark sx sy plot_area ~data ~extent =
match extent with
| Some (xmin, xmax, ymin, ymax) ->
let px0 = data_to_pixel_x sx plot_area xmin in
let px1 = data_to_pixel_x sx plot_area xmax in
let py0 = data_to_pixel_y sy plot_area ymax in
let py1 = data_to_pixel_y sy plot_area ymin in
let x = Float.min px0 px1 in
let y = Float.min py0 py1 in
let w = Float.abs (px1 -. px0) in
let h = Float.abs (py1 -. py0) in
[ Scene.Image { x; y; w; h; data } ]
| None ->
[
Scene.Image
{
x = plot_area.rx;
y = plot_area.ry;
w = plot_area.rw;
h = plot_area.rh;
data;
};
]
let emit_text_mark sx sy plot_area theme ~x ~y ~content ~color ~font_size =
let color = Option.value ~default:Color.black color in
let size = Option.value ~default:theme.Theme.font_label.size font_size in
let px = data_to_pixel_x sx plot_area x in
let py = data_to_pixel_y sy plot_area y in
[
Scene.Text
{
x = px;
y = py;
content;
font =
{
family = theme.font_label.family;
size = size *. theme.scale_factor;
weight = `Normal;
};
color;
anchor = `Start;
baseline = `Bottom;
angle = 0.;
};
]
let emit_hline_mark sy plot_area theme ~y:yv ~color ~line_width ~line_style
~alpha =
let color = resolve_color color alpha in
let sf = theme.Theme.scale_factor in
let lw = resolve_line_width sf theme line_width in
let dash = resolve_dash sf line_style in
let py = data_to_pixel_y sy plot_area yv in
[
Scene.Path
{
points = [| (plot_area.rx, py); (plot_area.rx +. plot_area.rw, py) |];
close = false;
fill = None;
stroke = Some color;
line_width = lw;
dash;
};
]
let emit_vline_mark sx plot_area theme ~x:xv ~color ~line_width ~line_style
~alpha =
let color = resolve_color color alpha in
let sf = theme.Theme.scale_factor in
let lw = resolve_line_width sf theme line_width in
let dash = resolve_dash sf line_style in
let px = data_to_pixel_x sx plot_area xv in
[
Scene.Path
{
points = [| (px, plot_area.ry); (px, plot_area.ry +. plot_area.rh) |];
close = false;
fill = None;
stroke = Some color;
line_width = lw;
dash;
};
]
let emit_abline_mark sx sy plot_area theme ~slope ~intercept ~color ~line_width
~line_style ~alpha =
let color = resolve_color color alpha in
let sf = theme.Theme.scale_factor in
let lw = resolve_line_width sf theme line_width in
let dash = resolve_dash sf line_style in
let x0 = sx.Scale.lo and x1 = sx.Scale.hi in
let y0v = (slope *. x0) +. intercept in
let y1v = (slope *. x1) +. intercept in
let px0 = data_to_pixel_x sx plot_area x0 in
let py0 = data_to_pixel_y sy plot_area y0v in
let px1 = data_to_pixel_x sx plot_area x1 in
let py1 = data_to_pixel_y sy plot_area y1v in
[
Scene.Path
{
points = [| (px0, py0); (px1, py1) |];
close = false;
fill = None;
stroke = Some color;
line_width = lw;
dash;
};
]
let emit_fill_between_segment sx sy plot_area color indices x y1 y2 =
let n_seg = List.length indices in
if n_seg = 0 then []
else
let points = Array.make (2 * n_seg) (0., 0.) in
let k = ref 0 in
List.iter
(fun i ->
let xv = Nx.item [ i ] x in
let yv = Nx.item [ i ] y1 in
if Float.is_finite xv && Float.is_finite yv then begin
points.(!k) <-
(data_to_pixel_x sx plot_area xv, data_to_pixel_y sy plot_area yv);
incr k
end)
indices;
let forward_count = !k in
List.iter
(fun i ->
let xv = Nx.item [ i ] x in
let yv = Nx.item [ i ] y2 in
if Float.is_finite xv && Float.is_finite yv then begin
points.(!k) <-
(data_to_pixel_x sx plot_area xv, data_to_pixel_y sy plot_area yv);
incr k
end)
(List.rev indices);
let total = !k in
if total < 3 || forward_count = 0 then []
else
[
Scene.Path
{
points = Array.sub points 0 total;
close = true;
fill = Some color;
stroke = None;
line_width = 0.;
dash = [];
};
]
let emit_fill_between_mark sx sy plot_area ~x ~y1 ~y2 ~where ~color ~alpha =
let n = (Nx.shape x).(0) in
let color = resolve_color ~default_alpha:0.3 color alpha in
match where with
| None ->
let indices = List.init n Fun.id in
emit_fill_between_segment sx sy plot_area color indices x y1 y2
| Some mask ->
(* Split into contiguous runs where mask > 0 *)
let segments = ref [] in
let current = ref [] in
for i = 0 to n - 1 do
if Nx.item [ i ] mask > 0. then current := i :: !current
else
match !current with
| [] -> ()
| seg ->
segments := List.rev seg :: !segments;
current := []
done;
(match !current with
| [] -> ()
| seg -> segments := List.rev seg :: !segments);
List.concat_map
(fun seg -> emit_fill_between_segment sx sy plot_area color seg x y1 y2)
(List.rev !segments)
let emit_hspan_mark sy plot_area ~y0 ~y1 ~color ~alpha =
let color = resolve_color ~default_alpha:0.2 color alpha in
let py0 = data_to_pixel_y sy plot_area y0 in
let py1 = data_to_pixel_y sy plot_area y1 in
let top = Float.min py0 py1 and bot = Float.max py0 py1 in
[
Scene.Path
{
points =
[|
(plot_area.rx, top);
(plot_area.rx +. plot_area.rw, top);
(plot_area.rx +. plot_area.rw, bot);
(plot_area.rx, bot);
|];
close = true;
fill = Some color;
stroke = None;
line_width = 0.;
dash = [];
};
]
let emit_vspan_mark sx plot_area ~x0 ~x1 ~color ~alpha =
let color = resolve_color ~default_alpha:0.2 color alpha in
let px0 = data_to_pixel_x sx plot_area x0 in
let px1 = data_to_pixel_x sx plot_area x1 in
let left = Float.min px0 px1 and right = Float.max px0 px1 in
[
Scene.Path
{
points =
[|
(left, plot_area.ry);
(right, plot_area.ry);
(right, plot_area.ry +. plot_area.rh);
(left, plot_area.ry +. plot_area.rh);
|];
close = true;
fill = Some color;
stroke = None;
line_width = 0.;
dash = [];
};
]
let emit_errorbar_mark sx sy plot_area theme ~x ~y ~yerr ~xerr ~color
~line_width ~cap_size ~alpha =
let n = (Nx.shape x).(0) in
let color = resolve_color color alpha in
let sf = theme.Theme.scale_factor in
let lw = resolve_line_width sf theme line_width in
let cap =
(match cap_size with Some s -> s | None -> theme.marker_size *. 0.5) *. sf
in
let prims = ref [] in
let make_path pts =
Scene.Path
{
points = pts;
close = false;
fill = None;
stroke = Some color;
line_width = lw;
dash = [];
}
in
for i = 0 to n - 1 do
let xv = Nx.item [ i ] x in
let yv = Nx.item [ i ] y in
if Float.is_finite xv && Float.is_finite yv then begin
let px = data_to_pixel_x sx plot_area xv in
let py = data_to_pixel_y sy plot_area yv in
let y_lo, y_hi =
match yerr with
| `Symmetric e ->
let ev = Nx.item [ i ] e in
(yv -. ev, yv +. ev)
| `Asymmetric (elo, ehi) ->
(yv -. Nx.item [ i ] elo, yv +. Nx.item [ i ] ehi)
in
let py_lo = data_to_pixel_y sy plot_area y_lo in
let py_hi = data_to_pixel_y sy plot_area y_hi in
prims := make_path [| (px, py_lo); (px, py_hi) |] :: !prims;
prims := make_path [| (px -. cap, py_hi); (px +. cap, py_hi) |] :: !prims;
prims := make_path [| (px -. cap, py_lo); (px +. cap, py_lo) |] :: !prims;
begin match xerr with
| Some xerr_val ->
let x_lo, x_hi =
match xerr_val with
| `Symmetric e ->
let ev = Nx.item [ i ] e in
(xv -. ev, xv +. ev)
| `Asymmetric (elo, ehi) ->
(xv -. Nx.item [ i ] elo, xv +. Nx.item [ i ] ehi)
in
let px_lo = data_to_pixel_x sx plot_area x_lo in
let px_hi = data_to_pixel_x sx plot_area x_hi in
prims := make_path [| (px_lo, py); (px_hi, py) |] :: !prims;
prims :=
make_path [| (px_lo, py -. cap); (px_lo, py +. cap) |] :: !prims;
prims :=
make_path [| (px_hi, py -. cap); (px_hi, py +. cap) |] :: !prims
| None -> ()
end
end
done;
List.rev !prims
let emit_heatmap_mark sx sy plot_area theme ~data ~cmap ~annotate ~vmin ~vmax
~fmt =
let shape = Nx.shape data in
let rows = shape.(0) and cols = shape.(1) in
let frows = float rows in
let lo = ref Float.infinity and hi = ref Float.neg_infinity in
for r = 0 to rows - 1 do
for c = 0 to cols - 1 do
let v = Nx.item [ r; c ] data in
if Float.is_finite v then begin
if v < !lo then lo := v;
if v > !hi then hi := v
end
done
done;
let vlo = Option.value ~default:!lo vmin in
let vhi = Option.value ~default:!hi vmax in
let vrange = if vhi = vlo then 1. else vhi -. vlo in
let cmap = Option.value ~default:theme.Theme.sequential cmap in
let sf = theme.Theme.scale_factor in
let prims = ref [] in
for r = 0 to rows - 1 do
for c = 0 to cols - 1 do
let v = Nx.item [ r; c ] data in
let t = Float.max 0. (Float.min 1. ((v -. vlo) /. vrange)) in
let cell_color = Cmap.eval cmap t in
let x0 = data_to_pixel_x sx plot_area (float c) in
let x1 = data_to_pixel_x sx plot_area (float (c + 1)) in
let y0 = data_to_pixel_y sy plot_area (frows -. float r) in
let y1 = data_to_pixel_y sy plot_area (frows -. float (r + 1)) in
let lx = Float.min x0 x1 and rx = Float.max x0 x1 in
let ty = Float.min y0 y1 and by = Float.max y0 y1 in
prims :=
Scene.Path
{
points = [| (lx, ty); (rx, ty); (rx, by); (lx, by) |];
close = true;
fill = Some cell_color;
stroke = None;
line_width = 0.;
dash = [];
}
:: !prims;
if annotate then begin
let text =
match fmt with Some f -> f v | None -> Printf.sprintf "%.2g" v
in
let text_color =
if Color.lightness cell_color > 0.65 then Color.black else Color.white
in
let cx = (lx +. rx) /. 2. in
let cy = (ty +. by) /. 2. in
let font_size =
Float.max (8. *. sf)
(Float.min
(Float.abs (rx -. lx) *. 0.4)
(Float.abs (by -. ty) *. 0.4))
in
prims :=
Scene.Text
{
x = cx;
y = cy;
content = text;
font =
{
family = theme.font_tick.family;
size = font_size;
weight = `Normal;
};
color = text_color;
anchor = `Middle;
baseline = `Middle;
angle = 0.;
}
:: !prims
end
done
done;
List.rev !prims
let emit_contour_mark sx sy plot_area theme ~data ~x0 ~x1 ~y0 ~y1 ~levels
~filled ~cmap ~color ~line_width ~alpha =
let sf = theme.Theme.scale_factor in
let contours = Prepared.prepare_contour ~x0 ~x1 ~y0 ~y1 ~data ~levels in
let n_levels = List.length contours in
let prims = ref [] in
List.iteri
(fun i cp ->
let t = if n_levels <= 1 then 0.5 else float i /. float (n_levels - 1) in
let c =
match color with
| Some c -> c
| None ->
let cmap = Option.value ~default:theme.Theme.sequential cmap in
Cmap.eval cmap t
in
let c = match alpha with Some a -> Color.with_alpha a c | None -> c in
let lw = resolve_line_width sf theme line_width in
List.iter
(fun seg ->
let points =
Array.map
(fun (dx, dy) ->
( data_to_pixel_x sx plot_area dx,
data_to_pixel_y sy plot_area dy ))
seg
in
if filled then
prims :=
Scene.Path
{
points;
close = true;
fill = Some c;
stroke = None;
line_width = 0.;
dash = [];
}
:: !prims
else
prims :=
Scene.Path
{
points;
close = false;
fill = None;
stroke = Some c;
line_width = lw;
dash = [];
}
:: !prims)
cp.Prepared.paths)
contours;
List.rev !prims
let emit_mark sx sy plot_area theme = function
| Spec.Line
{ x; y; color; line_width; line_style; step; marker; label = _; alpha } ->
emit_line_mark sx sy plot_area theme ~x ~y ~color ~line_width ~line_style
~step ~marker ~alpha
| Spec.Point
{ x; y; color; color_by; size; size_by; marker; label = _; alpha } ->
emit_point_mark sx sy plot_area theme ~x ~y ~color ~color_by ~size
~size_by ~marker ~alpha
| Spec.Bar { x; height; width; bottom; color; label = _; alpha } ->
emit_bar_mark sx sy plot_area theme ~x ~height ~width ~bottom ~color
~alpha
| Spec.Hist _ ->
failwith
"resolve: Spec.Hist reached emit_mark; should have been normalized to \
Bar by Prepared.compile"
| Spec.Image { data; extent } -> emit_image_mark sx sy plot_area ~data ~extent
| Spec.Text_mark { x; y; content; color; font_size } ->
emit_text_mark sx sy plot_area theme ~x ~y ~content ~color ~font_size
| Spec.Hline { y; color; line_width; line_style; label = _; alpha } ->
emit_hline_mark sy plot_area theme ~y ~color ~line_width ~line_style
~alpha
| Spec.Vline { x; color; line_width; line_style; label = _; alpha } ->
emit_vline_mark sx plot_area theme ~x ~color ~line_width ~line_style
~alpha
| Spec.Abline
{ slope; intercept; color; line_width; line_style; label = _; alpha } ->
emit_abline_mark sx sy plot_area theme ~slope ~intercept ~color
~line_width ~line_style ~alpha
| Spec.Fill_between { x; y1; y2; where; color; alpha; label = _ } ->
emit_fill_between_mark sx sy plot_area ~x ~y1 ~y2 ~where ~color ~alpha
| Spec.Hspan { y0; y1; color; alpha; label = _ } ->
emit_hspan_mark sy plot_area ~y0 ~y1 ~color ~alpha
| Spec.Vspan { x0; x1; color; alpha; label = _ } ->
emit_vspan_mark sx plot_area ~x0 ~x1 ~color ~alpha
| Spec.Errorbar
{ x; y; yerr; xerr; color; line_width; cap_size; label = _; alpha } ->
emit_errorbar_mark sx sy plot_area theme ~x ~y ~yerr ~xerr ~color
~line_width ~cap_size ~alpha
| Spec.Heatmap { data; cmap; annotate; vmin; vmax; fmt } ->
emit_heatmap_mark sx sy plot_area theme ~data ~cmap ~annotate ~vmin ~vmax
~fmt
| Spec.Imshow _ ->
failwith
"resolve: Spec.Imshow reached emit_mark; should have been normalized \
to Image by Prepared.compile"
| Spec.Contour
{
data;
x0;
x1;
y0;
y1;
levels;
filled;
cmap;
color;
line_width;
label = _;
alpha;
} ->
emit_contour_mark sx sy plot_area theme ~data ~x0 ~x1 ~y0 ~y1 ~levels
~filled ~cmap ~color ~line_width ~alpha
(* Axis primitives *)
let scaled_font (theme : Theme.t) (f : Theme.font) =
{ f with size = f.size *. theme.scale_factor }
let emit_axes ~text_measurer sx sy plot_area (theme : Theme.t) ~xticks ~yticks
(pp : Prepared.panel) =
let sf = theme.scale_factor in
let prims = ref [] in
let axis_color = theme.axis.color in
let lw = theme.axis.width *. sf in
let tl = theme.tick_length *. sf in
List.iter
(fun (v, label) ->
let px = data_to_pixel_x sx plot_area v in
let by = plot_area.ry +. plot_area.rh in
prims :=
Scene.Path
{
points = [| (px, by); (px, by +. tl) |];
close = false;
fill = None;
stroke = Some axis_color;
line_width = lw;
dash = [];
}
:: !prims;
let font = scaled_font theme theme.font_tick in
prims :=
Scene.Text
{
x = px;
y = by +. tl +. (8. *. sf);
content = label;
font;
color = axis_color;
anchor = `Middle;
baseline = `Top;
angle = 0.;
}
:: !prims)
xticks;
(* Y ticks *)
List.iter
(fun (v, label) ->
let py = data_to_pixel_y sy plot_area v in
let lx = plot_area.rx in
prims :=
Scene.Path
{
points = [| (lx -. tl, py); (lx, py) |];
close = false;
fill = None;
stroke = Some axis_color;
line_width = lw;
dash = [];
}
:: !prims;
let font = scaled_font theme theme.font_tick in
prims :=
Scene.Text
{
x = lx -. tl -. (8. *. sf);
y = py;
content = label;
font;
color = axis_color;
anchor = `End;
baseline = `Middle;
angle = 0.;
}
:: !prims)
yticks;
(* Grid *)
let show_grid = Option.value ~default:(theme.grid <> None) pp.grid_visible in
begin match theme.grid with
| Some grid_line when show_grid ->
List.iter
(fun (v, _) ->
let px = data_to_pixel_x sx plot_area v in
prims :=
Scene.Path
{
points =
[| (px, plot_area.ry); (px, plot_area.ry +. plot_area.rh) |];
close = false;
fill = None;
stroke = Some grid_line.color;
line_width = grid_line.width *. sf;
dash = grid_line.dash;
}
:: !prims)
xticks;
List.iter
(fun (v, _) ->
let py = data_to_pixel_y sy plot_area v in
prims :=
Scene.Path
{
points =
[| (plot_area.rx, py); (plot_area.rx +. plot_area.rw, py) |];
close = false;
fill = None;
stroke = Some grid_line.color;
line_width = grid_line.width *. sf;
dash = grid_line.dash;
}
:: !prims)
yticks
| _ -> ()
end;
(* Axis border *)
let show_frame = Option.value ~default:true pp.frame_visible in
if show_frame then begin
let lx = plot_area.rx and ty = plot_area.ry in
let rx = lx +. plot_area.rw and by = ty +. plot_area.rh in
prims :=
Scene.Path
{
points = [| (lx, ty); (rx, ty); (rx, by); (lx, by) |];
close = true;
fill = None;
stroke = Some axis_color;
line_width = lw;
dash = [];
}
:: !prims
end;
(* Title *)
begin match pp.title with
| Some s ->
let font = scaled_font theme theme.font_title in
let cx = plot_area.rx +. (plot_area.rw /. 2.) in
prims :=
Scene.Text
{
x = cx;
y = plot_area.ry -. (theme.title_gap *. sf);
content = s;
font;
color = axis_color;
anchor = `Middle;
baseline = `Bottom;
angle = 0.;
}
:: !prims
| None -> ()
end;
(* X label *)
begin match pp.x.label with
| Some s ->
let font = scaled_font theme theme.font_label in
let cx = plot_area.rx +. (plot_area.rw /. 2.) in
let _, tick_h =
text_measurer ~font:(scaled_font theme theme.font_tick) "0"
in
let y =
plot_area.ry +. plot_area.rh +. tl +. tick_h +. (theme.label_gap *. sf)
in
prims :=
Scene.Text
{
x = cx;
y;
content = s;
font;
color = axis_color;
anchor = `Middle;
baseline = `Top;
angle = 0.;
}
:: !prims
| None -> ()
end;
(* Y label *)
begin match pp.y.label with
| Some s ->
let font = scaled_font theme theme.font_label in
let tick_font = scaled_font theme theme.font_tick in
let max_ytick_w =
List.fold_left
(fun acc (_, label) ->
let w, _ = text_measurer ~font:tick_font label in
Float.max acc w)
0. yticks
in
let _, label_h = text_measurer ~font s in
let x =
plot_area.rx -. tl -. max_ytick_w -. (8. *. sf)
-. (theme.label_gap *. sf) -. (label_h /. 2.)
in
let y = plot_area.ry +. (plot_area.rh /. 2.) in
prims :=
Scene.Text
{
x;
y;
content = s;
font;
color = axis_color;
anchor = `Middle;
baseline = `Middle;
angle = Float.pi /. 2.;
}
:: !prims
| None -> ()
end;
List.rev !prims
(* Legend *)
type legend_kind =
| Legend_line of Spec.line_style option * Spec.marker option
| Legend_point of Spec.marker
| Legend_bar
| Legend_ref_line of Spec.line_style option
let mark_label = function
| Spec.Line { label; _ }
| Spec.Point { label; _ }
| Spec.Bar { label; _ }
| Spec.Hist { label; _ }
| Spec.Hline { label; _ }
| Spec.Vline { label; _ }
| Spec.Abline { label; _ }
| Spec.Fill_between { label; _ }
| Spec.Hspan { label; _ }
| Spec.Vspan { label; _ }
| Spec.Errorbar { label; _ }
| Spec.Contour { label; _ } ->
label
| Spec.Image _ | Spec.Text_mark _ | Spec.Heatmap _ | Spec.Imshow _ -> None
let mark_legend_kind = function
| Spec.Line { line_style; marker; _ } -> Legend_line (line_style, marker)
| Spec.Point { marker; _ } ->
Legend_point (Option.value ~default:Spec.Circle marker)
| Spec.Bar _ | Spec.Hist _ | Spec.Fill_between _ | Spec.Hspan _ | Spec.Vspan _
->
Legend_bar
| Spec.Hline { line_style; _ }
| Spec.Vline { line_style; _ }
| Spec.Abline { line_style; _ } ->
Legend_ref_line line_style
| Spec.Errorbar _ -> Legend_ref_line None
| Spec.Contour { filled = true; _ } -> Legend_bar
| Spec.Contour _ -> Legend_ref_line None
| _ -> Legend_bar
let emit_legend ~text_measurer ~loc ~ncol plot_area theme marks =
let sf = theme.Theme.scale_factor in
let entries =
List.filter_map
(fun m ->
match mark_label m with
| Some label ->
let color =
match Prepared.mark_color m with
| Some c -> c
| None -> Color.black
in
Some (label, color, mark_legend_kind m)
| None -> None)
marks
in
if entries = [] then []
else begin
let font = scaled_font theme theme.font_tick in
let swatch_size = 10. *. sf in
let gap = 4. *. sf in
let line_h = Float.max swatch_size (font.size *. 1.2) in
let margin = 8. *. sf in
let ncol = max 1 ncol in
let n_entries = List.length entries in
let nrows = (n_entries + ncol - 1) / ncol in
(* Compute per-column max label width *)
let col_widths = Array.make ncol 0. in
List.iteri
(fun i (label, _, _) ->
let col = i mod ncol in
let w, _ = text_measurer ~font label in
col_widths.(col) <- Float.max col_widths.(col) w)
entries;
let col_gap = 12. *. sf in
let col_w i = swatch_size +. gap +. col_widths.(i) in
let legend_w =
let total = ref 0. in
for i = 0 to ncol - 1 do
total := !total +. col_w i
done;
!total +. (col_gap *. float (ncol - 1))
in
let legend_h = (float nrows *. (line_h +. gap)) -. gap in
let loc = Option.value ~default:Spec.Upper_right loc in
(* x0 is the right edge of the legend area *)
let x0, y0 =
match loc with
| Spec.Upper_right ->
(plot_area.rx +. plot_area.rw -. margin, plot_area.ry +. margin)
| Spec.Upper_left ->
(plot_area.rx +. margin +. legend_w, plot_area.ry +. margin)
| Spec.Lower_right ->
( plot_area.rx +. plot_area.rw -. margin,
plot_area.ry +. plot_area.rh -. margin -. legend_h )
| Spec.Lower_left ->
( plot_area.rx +. margin +. legend_w,
plot_area.ry +. plot_area.rh -. margin -. legend_h )
| Spec.Center ->
( plot_area.rx +. ((plot_area.rw +. legend_w) /. 2.),
plot_area.ry +. ((plot_area.rh -. legend_h) /. 2.) )
| Spec.Right ->
( plot_area.rx +. plot_area.rw -. margin,
plot_area.ry +. ((plot_area.rh -. legend_h) /. 2.) )
| Spec.Upper_center ->
( plot_area.rx +. ((plot_area.rw +. legend_w) /. 2.),
plot_area.ry +. margin )
| Spec.Lower_center ->
( plot_area.rx +. ((plot_area.rw +. legend_w) /. 2.),
plot_area.ry +. plot_area.rh -. margin -. legend_h )
in
(* Background box *)
let inner_pad = 6. *. sf in
let bg_x = x0 -. legend_w -. inner_pad in
let bg_y = y0 -. inner_pad in
let bg_w = legend_w +. (2. *. inner_pad) in
let bg_h = legend_h +. (2. *. inner_pad) in
let bg =
Scene.Path
{
points =
[|
(bg_x, bg_y);
(bg_x +. bg_w, bg_y);
(bg_x +. bg_w, bg_y +. bg_h);
(bg_x, bg_y +. bg_h);
|];
close = true;
fill = Some (Color.with_alpha 0.85 theme.background);
stroke = Some (Color.with_alpha 0.3 theme.axis.color);
line_width = 1. *. sf;
dash = [];
}
in
(* Compute column x-offsets (from right edge of legend) *)
let col_offsets = Array.make ncol 0. in
let acc = ref 0. in
for c = ncol - 1 downto 0 do
col_offsets.(c) <- !acc;
acc := !acc +. col_w c +. col_gap
done;
let prims = ref [ bg ] in
List.iteri
(fun i (label, color, kind) ->
let row = i / ncol in
let col = i mod ncol in
let y = y0 +. (float row *. (line_h +. gap)) in
let y_mid = y +. (swatch_size /. 2.) in
let cx0 = x0 -. col_offsets.(col) in
begin match kind with
| Legend_line (line_style, marker) ->
prims :=
Scene.Path
{
points = [| (cx0 -. swatch_size, y_mid); (cx0, y_mid) |];
close = false;
fill = None;
stroke = Some color;
line_width = theme.line_width *. sf;
dash = resolve_dash sf line_style;
}
:: !prims;
begin match marker with
| Some shape ->
let ms = 6. *. sf in
prims :=
Scene.Markers
{
points = [| (cx0 -. (swatch_size /. 2.), y_mid) |];
shape;
size = ms;
sizes = None;
fill = Some color;
fills = None;
stroke = None;
}
:: !prims
| None -> ()
end
| Legend_point marker ->
let ms = 8. *. sf in
prims :=
Scene.Markers
{
points = [| (cx0 -. (swatch_size /. 2.), y_mid) |];
shape = marker;
size = ms;
sizes = None;
fill = Some color;
fills = None;
stroke = None;
}
:: !prims
| Legend_bar ->
prims :=
Scene.Path
{
points =
[|
(cx0 -. swatch_size, y);
(cx0, y);
(cx0, y +. swatch_size);
(cx0 -. swatch_size, y +. swatch_size);
|];
close = true;
fill = Some color;
stroke = None;
line_width = 0.;
dash = [];
}
:: !prims
| Legend_ref_line line_style ->
prims :=
Scene.Path
{
points = [| (cx0 -. swatch_size, y_mid); (cx0, y_mid) |];
close = false;
fill = None;
stroke = Some color;
line_width = theme.line_width *. sf;
dash = resolve_dash sf line_style;
}
:: !prims
end;
prims :=
Scene.Text
{
x = cx0 -. swatch_size -. gap;
y = y_mid;
content = label;
font;
color = theme.axis.color;
anchor = `End;
baseline = `Middle;
angle = 0.;
}
:: !prims)
entries;
List.rev !prims
end
(* Colorbar for color_by *)
let emit_colorbar plot_area (theme : Theme.t) ~height_frac (lo, hi) =
let sf = theme.scale_factor in
let font = scaled_font theme theme.font_tick in
let bar_w = 16. *. sf in
let bar_gap = 12. *. sf in
let bar_x = plot_area.rx +. plot_area.rw +. bar_gap in
let bar_y = plot_area.ry in
let bar_h = plot_area.rh *. height_frac in
(* Vertical gradient: series of thin horizontal strips *)
let n_strips = 64 in
let strip_h = bar_h /. float n_strips in
let strips =
List.init n_strips (fun i ->
let t = 1. -. (float i /. float (n_strips - 1)) in
let c = Cmap.eval theme.sequential t in
let sy = bar_y +. (float i *. strip_h) in
Scene.Path
{
points =
[|
(bar_x, sy);
(bar_x +. bar_w, sy);
(bar_x +. bar_w, sy +. strip_h +. 1.);
(bar_x, sy +. strip_h +. 1.);
|];
close = true;
fill = Some c;
stroke = None;
line_width = 0.;
dash = [];
})
in
(* Border around colorbar *)
let border =
Scene.Path
{
points =
[|
(bar_x, bar_y);
(bar_x +. bar_w, bar_y);
(bar_x +. bar_w, bar_y +. bar_h);
(bar_x, bar_y +. bar_h);
|];
close = true;
fill = None;
stroke = Some theme.axis.color;
line_width = theme.axis.width *. sf;
dash = [];
}
in
(* Tick labels along the right edge *)
let ticks = Ticks.generate `Linear ~lo ~hi () in
let range = hi -. lo in
let range = if range = 0. then 1. else range in
let label_x = bar_x +. bar_w +. (6. *. sf) in
let tick_prims =
List.filter_map
(fun (v, label) ->
let t = (v -. lo) /. range in
if t < -0.01 || t > 1.01 then None
else
let py = bar_y +. bar_h -. (t *. bar_h) in
Some
(Scene.Text
{
x = label_x;
y = py;
content = label;
font;
color = theme.axis.color;
anchor = `Start;
baseline = `Middle;
angle = 0.;
}))
ticks
in
strips @ [ border ] @ tick_prims
(* Size guide for size_by *)
let emit_size_guide plot_area (theme : Theme.t) ~y_offset (lo, hi) =
let sf = theme.scale_factor in
let font = scaled_font theme theme.font_tick in
let guide_gap = 12. *. sf in
let guide_x = plot_area.rx +. plot_area.rw +. guide_gap in
let max_r = 12. *. sf in
(* Three representative sizes: max, mid, min *)
let values = [| hi; (lo +. hi) /. 2.; lo |] in
let range = hi -. lo in
let range = if range = 0. then 1. else range in
let prims = ref [] in
let cy = ref (plot_area.ry +. y_offset +. max_r +. (4. *. sf)) in
Array.iter
(fun v ->
let t = (v -. lo) /. range in
let size = ((max_r *. 0.3) +. (max_r *. 0.7 *. Float.sqrt t)) *. 2. in
let cx = guide_x +. max_r in
prims :=
Scene.Markers
{
points = [| (cx, !cy) |];
shape = Spec.Circle;
size;
sizes = None;
fill = Some (Color.with_alpha 0.2 theme.axis.color);
fills = None;
stroke = Some theme.axis.color;
}
:: !prims;
let label = Printf.sprintf "%.4g" v in
let label_x = cx +. max_r +. (6. *. sf) in
prims :=
Scene.Text
{
x = label_x;
y = !cy;
content = label;
font;
color = theme.axis.color;
anchor = `Start;
baseline = `Middle;
angle = 0.;
}
:: !prims;
cy := !cy +. (max_r *. 2.) +. (8. *. sf))
values;
List.rev !prims
(* Compute layout padding *)
let compute_layout ~text_measurer (theme : Theme.t) (pp : Prepared.panel) xticks
yticks =
let sf = theme.scale_factor in
let tick_font = scaled_font theme theme.font_tick in
let label_font = scaled_font theme theme.font_label in
let title_font = scaled_font theme theme.font_title in
let tl = theme.tick_length *. sf in
(* Left padding: y-tick labels + gap + optional ylabel *)
let left =
let base = theme.padding *. sf in
match yticks with
| [] -> base
| _ ->
let max_ytick_w =
List.fold_left
(fun acc (_, label) ->
let w, _ = text_measurer ~font:tick_font label in
Float.max acc w)
0. yticks
in
base +. max_ytick_w +. tl +. (8. *. sf)
in
let left =
match pp.y.label with
| Some s ->
let _, h = text_measurer ~font:label_font s in
left +. h +. (theme.label_gap *. sf)
| None -> left
in
(* Bottom padding: x-tick labels + gap + optional xlabel *)
let bottom =
let base = theme.padding *. sf in
match xticks with
| [] -> base
| _ ->
let _, tick_h = text_measurer ~font:tick_font "0" in
base +. tick_h +. tl +. (8. *. sf)
in
let bottom =
match pp.x.label with
| Some s ->
let _, h = text_measurer ~font:label_font s in
bottom +. h +. (theme.label_gap *. sf)
| None -> bottom
in
(* Top padding: title *)
let top = theme.padding *. sf in
let top =
match pp.title with
| Some s ->
let _, h = text_measurer ~font:title_font s in
top +. h +. (theme.title_gap *. sf)
| None -> top
in
(* Right padding — extra space for colorbar / size guide *)
let right =
let base = theme.padding *. sf in
let colorbar_w =
match pp.colorbar_range with
| Some (lo, hi) ->
let bar_w = 16. *. sf in
let bar_gap = 12. *. sf in
let ticks = Ticks.generate `Linear ~lo ~hi () in
let max_label_w =
List.fold_left
(fun acc (_, label) ->
let w, _ = text_measurer ~font:tick_font label in
Float.max acc w)
0. ticks
in
bar_gap +. bar_w +. (6. *. sf) +. max_label_w +. (4. *. sf)
| None -> 0.
in
let size_guide_w =
match pp.size_by_range with
| Some (lo, hi) ->
let guide_gap = 12. *. sf in
let max_r = 12. *. sf in
let mid = (lo +. hi) /. 2. in
let max_label_w =
List.fold_left
(fun acc v ->
let w, _ =
text_measurer ~font:tick_font (Printf.sprintf "%.4g" v)
in
Float.max acc w)
0. [ lo; mid; hi ]
in
guide_gap +. (max_r *. 2.) +. (6. *. sf) +. max_label_w +. (4. *. sf)
| None -> 0.
in
base +. Float.max colorbar_w size_guide_w
in
(left, top, right, bottom)
(* Resolve a single prepared panel *)
let resolve_panel ~text_measurer theme region (pp : Prepared.panel) =
let theme = Option.value ~default:theme pp.theme_override in
let sx, xticks = Axis.make_scale_and_ticks pp.x in
let sy, yticks = Axis.make_scale_and_ticks pp.y in
let left, top, right, bottom =
compute_layout ~text_measurer theme pp xticks yticks
in
let plot_area =
{
rx = region.rx +. left;
ry = region.ry +. top;
rw = Float.max 1. (region.rw -. left -. right);
rh = Float.max 1. (region.rh -. top -. bottom);
}
in
(* Background *)
let bg =
Scene.Path
{
points =
[|
(region.rx, region.ry);
(region.rx +. region.rw, region.ry);
(region.rx +. region.rw, region.ry +. region.rh);
(region.rx, region.ry +. region.rh);
|];
close = true;
fill = Some theme.background;
stroke = None;
line_width = 0.;
dash = [];
}
in
(* Axes decorations *)
let axes_prims =
emit_axes ~text_measurer sx sy plot_area theme ~xticks ~yticks pp
in
(* Data marks inside clip region *)
let data_prims = List.concat_map (emit_mark sx sy plot_area theme) pp.marks in
let clipped_data =
Scene.Clip
{
x = plot_area.rx;
y = plot_area.ry;
w = plot_area.rw;
h = plot_area.rh;
children = data_prims;
}
in
(* Legend *)
let legend_prims =
emit_legend ~text_measurer ~loc:pp.legend_loc ~ncol:pp.legend_ncol plot_area
theme pp.marks
in
(* Colorbar for color_by *)
let has_both = pp.colorbar_range <> None && pp.size_by_range <> None in
let colorbar_prims =
match pp.colorbar_range with
| Some range ->
let height_frac = if has_both then 0.55 else 1. in
emit_colorbar plot_area theme ~height_frac range
| None -> []
in
(* Size guide for size_by *)
let size_guide_prims =
match pp.size_by_range with
| Some range ->
let y_offset = if has_both then plot_area.rh *. 0.6 else 0. in
emit_size_guide plot_area theme ~y_offset range
| None -> []
in
[ bg; clipped_data ] @ axes_prims @ legend_prims @ colorbar_prims
@ size_guide_prims
(* Resolve a prepared grid layout *)
let resolve_grid ~resolve_prepared ~text_measurer theme region rows gap =
let nrows = List.length rows in
let ncols =
List.fold_left (fun acc row -> max acc (List.length row)) 0 rows
in
if nrows = 0 || ncols = 0 then []
else begin
let cell_w = (region.rw -. (gap *. float (ncols - 1))) /. float ncols in
let cell_h = (region.rh -. (gap *. float (nrows - 1))) /. float nrows in
let prims = ref [] in
List.iteri
(fun ri row ->
List.iteri
(fun ci prepared ->
let cell_region =
{
rx = region.rx +. (float ci *. (cell_w +. gap));
ry = region.ry +. (float ri *. (cell_h +. gap));
rw = cell_w;
rh = cell_h;
}
in
let p =
resolve_prepared ~text_measurer theme cell_region prepared
in
prims := List.rev_append p !prims)
row)
rows;
List.rev !prims
end
(* Grid-level decorations *)
let emit_grid_decorations ~text_measurer theme region
(gd : Prepared.grid_decorations) all_marks =
let sf = theme.Theme.scale_factor in
let color = theme.axis.color in
let prims = ref [] in
let r = ref region in
(* Title: above grid *)
begin match gd.gd_title with
| Some s ->
let font = scaled_font theme theme.font_title in
let _, title_h = text_measurer ~font s in
let title_gap = theme.title_gap *. sf in
prims :=
Scene.Text
{
x = !r.rx +. (!r.rw /. 2.);
y = !r.ry +. title_h;
content = s;
font;
color;
anchor = `Middle;
baseline = `Bottom;
angle = 0.;
}
:: !prims;
let used = title_h +. title_gap in
r := { !r with ry = !r.ry +. used; rh = !r.rh -. used }
| None -> ()
end;
(* Xlabel: below grid *)
begin match gd.gd_xlabel with
| Some s ->
let font = scaled_font theme theme.font_label in
let _, label_h = text_measurer ~font s in
let label_gap = theme.label_gap *. sf in
let used = label_h +. label_gap in
prims :=
Scene.Text
{
x = !r.rx +. (!r.rw /. 2.);
y = !r.ry +. !r.rh -. label_gap;
content = s;
font;
color;
anchor = `Middle;
baseline = `Bottom;
angle = 0.;
}
:: !prims;
r := { !r with rh = !r.rh -. used }
| None -> ()
end;
(* Ylabel: left of grid, rotated *)
begin match gd.gd_ylabel with
| Some s ->
let font = scaled_font theme theme.font_label in
let _, label_h = text_measurer ~font s in
let label_gap = theme.label_gap *. sf in
let used = label_h +. label_gap in
prims :=
Scene.Text
{
x = !r.rx +. (label_h /. 2.);
y = !r.ry +. (!r.rh /. 2.);
content = s;
font;
color;
anchor = `Middle;
baseline = `Middle;
angle = Float.pi /. 2.;
}
:: !prims;
r := { !r with rx = !r.rx +. used; rw = !r.rw -. used }
| None -> ()
end;
(* Shared legend *)
let legend_prims =
match gd.gd_legend_loc with
| Some loc ->
emit_legend ~text_measurer ~loc:(Some loc)
~ncol:gd.Prepared.gd_legend_ncol !r theme all_marks
| None -> []
in
(List.rev !prims, legend_prims, !r)
(* Top-level resolve from Prepared.t *)
let rec resolve_tree ~text_measurer theme region = function
| Prepared.Panel pp -> resolve_panel ~text_measurer theme region pp
| Prepared.Grid { rows; gap } ->
let gap_px = gap *. Float.min region.rw region.rh in
resolve_grid ~resolve_prepared:resolve_tree ~text_measurer theme region
rows gap_px
| Prepared.Decorated_grid { decorations; inner; all_marks } ->
let theme = Option.value ~default:theme decorations.gd_theme_override in
let dec_prims, legend_prims, grid_region =
emit_grid_decorations ~text_measurer theme region decorations all_marks
in
dec_prims
@ resolve_tree ~text_measurer theme grid_region inner
@ legend_prims
let resolve_prepared ~text_measurer ~theme ~width ~height prepared =
let region = { rx = 0.; ry = 0.; rw = width; rh = height } in
let primitives = resolve_tree ~text_measurer theme region prepared in
{ Scene.width; height; primitives }
(* Convenience: compile + resolve in one step *)
let resolve ~text_measurer ~theme ~width ~height spec =
let prepared = Prepared.compile ~theme spec in
resolve_prepared ~text_measurer ~theme ~width ~height prepared
================================================
FILE: packages/hugin/lib/resolve.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Specification to scene resolution.
{b Internal module.} Walks a {!Spec.t} tree, computes data bounds and
layout, and emits a {!Scene.t} with all coordinates in device pixels. *)
type text_measurer = font:Theme.font -> string -> float * float
(** The type for text measurers. Returns [(width, height)] for a string rendered
in the given font. *)
val resolve_prepared :
text_measurer:text_measurer ->
theme:Theme.t ->
width:float ->
height:float ->
Prepared.t ->
Scene.t
(** [resolve_prepared ~text_measurer ~theme ~width ~height prepared] is the
resolved scene for [prepared] at the given dimensions. Layout-only work
(pixel coordinates, text measurement) is done here; data work is already
done in {!Prepared.compile}. *)
val resolve :
text_measurer:text_measurer ->
theme:Theme.t ->
width:float ->
height:float ->
Spec.t ->
Scene.t
(** [resolve ~text_measurer ~theme ~width ~height spec] is the resolved scene
for [spec] at the given dimensions. Convenience wrapper that calls
{!Prepared.compile} then {!resolve_prepared}. *)
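(* Illustrative sketch, not part of the interface: any function of the
[text_measurer] shape works. Here width is crudely approximated from the
glyph count and font size, which is enough for headless layout experiments;
[theme] and [spec] are assumed to be in scope, and [approx] is a
hypothetical name.
{[
let approx : text_measurer =
fun ~font s ->
(0.6 *. font.Theme.size *. float (String.length s), font.Theme.size)
let scene = resolve ~text_measurer:approx ~theme ~width:640. ~height:480. spec
]} *)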
================================================
FILE: packages/hugin/lib/scale.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Data-to-unit mapping functions *)
type t = {
to_unit : float -> float;
from_unit : float -> float;
lo : float;
hi : float;
}
let maybe_invert invert to_unit from_unit =
if invert then ((fun v -> 1. -. to_unit v), fun u -> from_unit (1. -. u))
else (to_unit, from_unit)
let linear ?(invert = false) ~lo ~hi () =
let range = hi -. lo in
let range = if range = 0. then 1. else range in
let to_unit, from_unit =
maybe_invert invert
(fun v -> (v -. lo) /. range)
(fun u -> lo +. (u *. range))
in
{ to_unit; from_unit; lo; hi }
let log ?(invert = false) ~lo ~hi () =
let lo_log = Float.log10 (Float.max 1e-300 lo) in
let hi_log = Float.log10 (Float.max 1e-300 hi) in
let range = hi_log -. lo_log in
let range = if range = 0. then 1. else range in
let to_unit, from_unit =
maybe_invert invert
(fun v ->
if v <= 0. then Float.nan else (Float.log10 v -. lo_log) /. range)
(fun u -> Float.pow 10. (lo_log +. (u *. range)))
in
{ to_unit; from_unit; lo; hi }
let sqrt ?(invert = false) ~lo ~hi () =
let lo_s = Float.sqrt (Float.max 0. lo) in
let hi_s = Float.sqrt (Float.max 0. hi) in
let range = hi_s -. lo_s in
let range = if range = 0. then 1. else range in
let to_unit, from_unit =
maybe_invert invert
(fun v -> (Float.sqrt (Float.max 0. v) -. lo_s) /. range)
(fun u ->
let s = lo_s +. (u *. range) in
s *. s)
in
{ to_unit; from_unit; lo; hi }
let asinh ?(invert = false) ~lo ~hi () =
let lo_a = Float.asinh lo in
let hi_a = Float.asinh hi in
let range = hi_a -. lo_a in
let range = if range = 0. then 1. else range in
let to_unit, from_unit =
maybe_invert invert
(fun v -> (Float.asinh v -. lo_a) /. range)
(fun u ->
let a = lo_a +. (u *. range) in
Float.sinh a)
in
{ to_unit; from_unit; lo; hi }
let symlog ?(invert = false) ~linthresh ~lo ~hi () =
let transform v =
if Float.abs v <= linthresh then v /. linthresh
else Float.copy_sign (1. +. Float.log10 (Float.abs v /. linthresh)) v
in
let inv_transform v =
if Float.abs v <= 1. then v *. linthresh
else Float.copy_sign (linthresh *. Float.pow 10. (Float.abs v -. 1.)) v
in
let lo_t = transform lo in
let hi_t = transform hi in
let range = hi_t -. lo_t in
let range = if range = 0. then 1. else range in
let to_unit, from_unit =
maybe_invert invert
(fun v -> (transform v -. lo_t) /. range)
(fun u -> inv_transform (lo_t +. (u *. range)))
in
{ to_unit; from_unit; lo; hi }
let make ?(invert = false) kind ~lo ~hi () =
match kind with
| `Linear -> linear ~invert ~lo ~hi ()
| `Log -> log ~invert ~lo ~hi ()
| `Sqrt -> sqrt ~invert ~lo ~hi ()
| `Asinh -> asinh ~invert ~lo ~hi ()
| `Symlog linthresh -> symlog ~invert ~linthresh ~lo ~hi ()
================================================
FILE: packages/hugin/lib/scale.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Data-to-unit mapping functions.
{b Internal module.} Maps data-space values to the unit interval [[0, 1]]
for linear, logarithmic, square-root, inverse-hyperbolic-sine, and
symmetric-log scales. When [~invert] is [true], the mapping is reversed
([lo] maps to [1] and [hi] to [0]). *)
type t = {
to_unit : float -> float; (** [to_unit v] maps data value [v] to [[0, 1]]. *)
from_unit : float -> float;
(** [from_unit u] maps unit value [u] back to data space. *)
lo : float; (** Lower bound in data space. *)
hi : float; (** Upper bound in data space. *)
}
(** The type for scales. *)
val linear : ?invert:bool -> lo:float -> hi:float -> unit -> t
(** [linear ~lo ~hi ()] is a linear scale over [[lo, hi]]. *)
val log : ?invert:bool -> lo:float -> hi:float -> unit -> t
(** [log ~lo ~hi ()] is a base-10 logarithmic scale over [[lo, hi]]. *)
val sqrt : ?invert:bool -> lo:float -> hi:float -> unit -> t
(** [sqrt ~lo ~hi ()] is a square-root scale over [[lo, hi]]. Values below zero
are clamped. *)
val asinh : ?invert:bool -> lo:float -> hi:float -> unit -> t
(** [asinh ~lo ~hi ()] is an inverse-hyperbolic-sine scale over [[lo, hi]].
Transitions smoothly from linear near zero to logarithmic at large absolute
values. Handles negative values. *)
val symlog :
?invert:bool -> linthresh:float -> lo:float -> hi:float -> unit -> t
(** [symlog ~linthresh ~lo ~hi ()] is a symmetric logarithmic scale: linear
within [[-linthresh, linthresh]], logarithmic outside. *)
val make :
?invert:bool ->
[ `Linear | `Log | `Sqrt | `Asinh | `Symlog of float ] ->
lo:float ->
hi:float ->
unit ->
t
(** [make kind ~lo ~hi ()] is a scale of the given [kind]. *)
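(* Illustrative examples, not part of the interface: [to_unit] and
[from_unit] are inverses over [[lo, hi]]; with symmetric bounds, zero maps
to the midpoint under the symlog scale.
{[
let s = make `Linear ~lo:0. ~hi:10. () in
s.to_unit 5. (* 0.5 *);
s.from_unit 0.25 (* 2.5 *);
let sym = make (`Symlog 1.) ~lo:(-100.) ~hi:100. () in
sym.to_unit 0. (* 0.5 *)
]} *)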
================================================
FILE: packages/hugin/lib/scene.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* Scene IR — resolved primitives in device pixels *)
type primitive =
| Path of {
points : (float * float) array;
close : bool;
fill : Color.t option;
stroke : Color.t option;
line_width : float;
dash : float list;
}
| Markers of {
points : (float * float) array;
shape : Spec.marker;
size : float;
sizes : float array option;
fill : Color.t option;
fills : Color.t array option;
stroke : Color.t option;
}
| Text of {
x : float;
y : float;
content : string;
font : Theme.font;
color : Color.t;
anchor : [ `Start | `Middle | `End ];
baseline : [ `Top | `Middle | `Bottom ];
angle : float;
}
| Image of { x : float; y : float; w : float; h : float; data : Nx.uint8_t }
| Clip of {
x : float;
y : float;
w : float;
h : float;
children : primitive list;
}
| Group of primitive list
type t = { width : float; height : float; primitives : primitive list }
let rec fold_primitive f acc = function
| Group children | Clip { children; _ } ->
List.fold_left (fold_primitive f) acc children
| p -> f acc p
let fold f scene acc = List.fold_left (fold_primitive f) acc scene.primitives
================================================
FILE: packages/hugin/lib/scene.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Scene intermediate representation.
{b Internal module.} Resolved drawing primitives in device-pixel
coordinates. All data-space concepts are gone; backends fold over these
primitives to produce output. *)
(** {1:types Types} *)
type primitive =
| Path of {
points : (float * float) array;
close : bool;
fill : Color.t option;
stroke : Color.t option;
line_width : float;
dash : float list;
}
| Markers of {
points : (float * float) array;
shape : Spec.marker;
size : float;
sizes : float array option;
fill : Color.t option;
fills : Color.t array option;
stroke : Color.t option;
}
| Text of {
x : float;
y : float;
content : string;
font : Theme.font;
color : Color.t;
anchor : [ `Start | `Middle | `End ];
baseline : [ `Top | `Middle | `Bottom ];
angle : float;
}
| Image of { x : float; y : float; w : float; h : float; data : Nx.uint8_t }
| Clip of {
x : float;
y : float;
w : float;
h : float;
children : primitive list;
}
| Group of primitive list (** The type for drawing primitives. *)
type t = { width : float; height : float; primitives : primitive list }
(** The type for resolved scenes. *)
(** {1:traversal Traversal} *)
val fold : ('a -> primitive -> 'a) -> t -> 'a -> 'a
(** [fold f scene acc] folds [f] over every leaf primitive in [scene],
descending into {!Clip} and {!Group} nodes. *)
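(* Example (sketch): backends render by folding over the scene; e.g.
   counting the leaf primitives of a resolved scene [scene]:
   {[
     let n_leaves = fold (fun n _ -> n + 1) scene 0
   ]} *)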
================================================
FILE: packages/hugin/lib/spec.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
type line_style = [ `Solid | `Dashed | `Dotted | `Dash_dot ]
type marker = Circle | Square | Triangle | Plus | Star
type legend_loc =
| Upper_right
| Upper_left
| Lower_right
| Lower_left
| Center
| Right
| Upper_center
| Lower_center
type scale = [ `Linear | `Log | `Sqrt | `Asinh | `Symlog of float ]
type stretch = [ `Linear | `Log | `Sqrt | `Asinh | `Power of float ]
type mark =
| Line of {
x : Nx.float32_t;
y : Nx.float32_t;
color : Color.t option;
line_width : float option;
line_style : line_style option;
step : [ `Pre | `Post | `Mid ] option;
marker : marker option;
label : string option;
alpha : float option;
}
| Point of {
x : Nx.float32_t;
y : Nx.float32_t;
color : Color.t option;
color_by : Nx.float32_t option;
size : float option;
size_by : Nx.float32_t option;
marker : marker option;
label : string option;
alpha : float option;
}
| Bar of {
x : Nx.float32_t;
height : Nx.float32_t;
width : float option;
bottom : float;
color : Color.t option;
label : string option;
alpha : float option;
}
| Hist of {
x : Nx.float32_t;
bins : [ `Num of int | `Edges of float array ];
density : bool;
color : Color.t option;
label : string option;
}
| Image of {
data : Nx.uint8_t;
extent : (float * float * float * float) option;
}
| Text_mark of {
x : float;
y : float;
content : string;
color : Color.t option;
font_size : float option;
}
| Hline of {
y : float;
color : Color.t option;
line_width : float option;
line_style : line_style option;
label : string option;
alpha : float option;
}
| Vline of {
x : float;
color : Color.t option;
line_width : float option;
line_style : line_style option;
label : string option;
alpha : float option;
}
| Abline of {
slope : float;
intercept : float;
color : Color.t option;
line_width : float option;
line_style : line_style option;
label : string option;
alpha : float option;
}
| Fill_between of {
x : Nx.float32_t;
y1 : Nx.float32_t;
y2 : Nx.float32_t;
where : Nx.float32_t option;
color : Color.t option;
alpha : float option;
label : string option;
}
| Hspan of {
y0 : float;
y1 : float;
color : Color.t option;
alpha : float option;
label : string option;
}
| Vspan of {
x0 : float;
x1 : float;
color : Color.t option;
alpha : float option;
label : string option;
}
| Errorbar of {
x : Nx.float32_t;
y : Nx.float32_t;
yerr :
[ `Symmetric of Nx.float32_t
| `Asymmetric of Nx.float32_t * Nx.float32_t ];
xerr :
[ `Symmetric of Nx.float32_t
| `Asymmetric of Nx.float32_t * Nx.float32_t ]
option;
color : Color.t option;
line_width : float option;
cap_size : float option;
label : string option;
alpha : float option;
}
| Heatmap of {
data : Nx.float32_t;
cmap : Cmap.t option;
annotate : bool;
vmin : float option;
vmax : float option;
fmt : (float -> string) option;
}
| Imshow of {
data : Nx.float32_t;
stretch : stretch;
cmap : Cmap.t option;
vmin : float option;
vmax : float option;
}
| Contour of {
data : Nx.float32_t;
x0 : float;
x1 : float;
y0 : float;
y1 : float;
levels : [ `Num of int | `Values of float array ];
filled : bool;
cmap : Cmap.t option;
color : Color.t option;
line_width : float option;
label : string option;
alpha : float option;
}
type decoration =
| Title of string
| Xlabel of string
| Ylabel of string
| Xlim of float * float
| Ylim of float * float
| Xscale of scale
| Yscale of scale
| Xinvert
| Yinvert
| Grid_visible of bool
| Legend of legend_loc * int
| Xticks of (float * string) list
| Yticks of (float * string) list
| With_theme of Theme.t
| Xtick_format of (float -> string)
| Ytick_format of (float -> string)
| Frame of bool
type t =
| Mark of mark
| Layers of t list
| Decorated of { inner : t; decorations : decoration list }
| Grid of { rows : t list list; gap : float }
(* Mark constructors *)
let line ~x ~y ?color ?line_width ?line_style ?step ?marker ?label ?alpha () =
Mark
(Line { x; y; color; line_width; line_style; step; marker; label; alpha })
let point ~x ~y ?color ?color_by ?size ?size_by ?marker ?label ?alpha () =
Mark (Point { x; y; color; color_by; size; size_by; marker; label; alpha })
let bar ~x ~height ?width ?(bottom = 0.) ?color ?label ?alpha () =
Mark (Bar { x; height; width; bottom; color; label; alpha })
let hist ~x ?(bins = `Num 10) ?(density = false) ?color ?label () =
Mark (Hist { x; bins; density; color; label })
let image ?extent data = Mark (Image { data; extent })
let text ~x ~y s ?color ?font_size () =
Mark (Text_mark { x; y; content = s; color; font_size })
let hline ~y ?color ?line_width ?line_style ?label ?alpha () =
Mark (Hline { y; color; line_width; line_style; label; alpha })
let vline ~x ?color ?line_width ?line_style ?label ?alpha () =
Mark (Vline { x; color; line_width; line_style; label; alpha })
let abline ~slope ~intercept ?color ?line_width ?line_style ?label ?alpha () =
Mark
(Abline { slope; intercept; color; line_width; line_style; label; alpha })
let fill_between ~x ~y1 ~y2 ?where ?color ?alpha ?label () =
Mark (Fill_between { x; y1; y2; where; color; alpha; label })
let hspan ~y0 ~y1 ?color ?alpha ?label () =
Mark (Hspan { y0; y1; color; alpha; label })
let vspan ~x0 ~x1 ?color ?alpha ?label () =
Mark (Vspan { x0; x1; color; alpha; label })
let errorbar ~x ~y ~yerr ?xerr ?color ?line_width ?cap_size ?label ?alpha () =
Mark
(Errorbar { x; y; yerr; xerr; color; line_width; cap_size; label; alpha })
let heatmap ~data ?(annotate = false) ?cmap ?vmin ?vmax ?fmt () =
Mark (Heatmap { data; cmap; annotate; vmin; vmax; fmt })
let imshow ~data ?(stretch = `Linear) ?cmap ?vmin ?vmax () =
Mark (Imshow { data; stretch; cmap; vmin; vmax })
let contour ~data ~x0 ~x1 ~y0 ~y1 ?(levels = `Num 8) ?(filled = false) ?cmap
?color ?line_width ?label ?alpha () =
Mark
(Contour
{
data;
x0;
x1;
y0;
y1;
levels;
filled;
cmap;
color;
line_width;
label;
alpha;
})
(* Composition *)
let layers ts = Layers ts
(* Decorations *)
let decorate d = function
| Decorated r -> Decorated { r with decorations = d :: r.decorations }
| t -> Decorated { inner = t; decorations = [ d ] }
let title s t = decorate (Title s) t
let xlabel s t = decorate (Xlabel s) t
let ylabel s t = decorate (Ylabel s) t
let xlim lo hi t = decorate (Xlim (lo, hi)) t
let ylim lo hi t = decorate (Ylim (lo, hi)) t
let xscale s t = decorate (Xscale s) t
let yscale s t = decorate (Yscale s) t
let xinvert t = decorate Xinvert t
let yinvert t = decorate Yinvert t
let grid_lines visible t = decorate (Grid_visible visible) t
let legend ?(loc = Upper_right) ?(ncol = 1) t = decorate (Legend (loc, ncol)) t
let xticks ticks t = decorate (Xticks ticks) t
let yticks ticks t = decorate (Yticks ticks) t
let with_theme th t = decorate (With_theme th) t
let xtick_format fmt t = decorate (Xtick_format fmt) t
let ytick_format fmt t = decorate (Ytick_format fmt) t
let frame v t = decorate (Frame v) t
let no_axes t =
t |> decorate (Frame false) |> decorate (Xticks []) |> decorate (Yticks [])
|> decorate (Grid_visible false)
(* Layout *)
let grid_layout ?(gap = 0.05) rows = Grid { rows; gap }
================================================
FILE: packages/hugin/lib/spec.mli
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(** Immutable plot specifications.
{b Internal module.} The specification tree is the user-facing
representation of a plot. {!Prepared.compile} resolves data-dependent work;
{!Resolve} turns the result into a {!Scene.t}. *)
(** {1:types Types} *)
type line_style = [ `Solid | `Dashed | `Dotted | `Dash_dot ]
(** The type for line dash patterns. *)
type marker =
| Circle
| Square
| Triangle
| Plus
| Star (** The type for point marker shapes. *)
type legend_loc =
| Upper_right
| Upper_left
| Lower_right
| Lower_left
| Center
| Right
| Upper_center
| Lower_center (** The type for legend placement. *)
type scale = [ `Linear | `Log | `Sqrt | `Asinh | `Symlog of float ]
(** The type for axis scales. [`Sqrt] and [`Asinh] handle zero gracefully.
[`Symlog linthresh] is linear within \[[-linthresh];[linthresh]\] and
logarithmic outside. *)
type stretch = [ `Linear | `Log | `Sqrt | `Asinh | `Power of float ]
(** The type for image stretch functions. [`Power a] raises normalized values to
the power [a]. *)
type mark =
| Line of {
x : Nx.float32_t;
y : Nx.float32_t;
color : Color.t option;
line_width : float option;
line_style : line_style option;
step : [ `Pre | `Post | `Mid ] option;
marker : marker option;
label : string option;
alpha : float option;
}
| Point of {
x : Nx.float32_t;
y : Nx.float32_t;
color : Color.t option;
color_by : Nx.float32_t option;
size : float option;
size_by : Nx.float32_t option;
marker : marker option;
label : string option;
alpha : float option;
}
| Bar of {
x : Nx.float32_t;
height : Nx.float32_t;
width : float option;
bottom : float;
color : Color.t option;
label : string option;
alpha : float option;
}
| Hist of {
x : Nx.float32_t;
bins : [ `Num of int | `Edges of float array ];
density : bool;
color : Color.t option;
label : string option;
}
| Image of {
data : Nx.uint8_t;
extent : (float * float * float * float) option;
}
| Text_mark of {
x : float;
y : float;
content : string;
color : Color.t option;
font_size : float option;
}
| Hline of {
y : float;
color : Color.t option;
line_width : float option;
line_style : line_style option;
label : string option;
alpha : float option;
}
| Vline of {
x : float;
color : Color.t option;
line_width : float option;
line_style : line_style option;
label : string option;
alpha : float option;
}
| Abline of {
slope : float;
intercept : float;
color : Color.t option;
line_width : float option;
line_style : line_style option;
label : string option;
alpha : float option;
}
| Fill_between of {
x : Nx.float32_t;
y1 : Nx.float32_t;
y2 : Nx.float32_t;
where : Nx.float32_t option;
color : Color.t option;
alpha : float option;
label : string option;
}
| Hspan of {
y0 : float;
y1 : float;
color : Color.t option;
alpha : float option;
label : string option;
}
| Vspan of {
x0 : float;
x1 : float;
color : Color.t option;
alpha : float option;
label : string option;
}
| Errorbar of {
x : Nx.float32_t;
y : Nx.float32_t;
yerr :
[ `Symmetric of Nx.float32_t
| `Asymmetric of Nx.float32_t * Nx.float32_t ];
xerr :
[ `Symmetric of Nx.float32_t
| `Asymmetric of Nx.float32_t * Nx.float32_t ]
option;
color : Color.t option;
line_width : float option;
cap_size : float option;
label : string option;
alpha : float option;
}
| Heatmap of {
data : Nx.float32_t;
cmap : Cmap.t option;
annotate : bool;
vmin : float option;
vmax : float option;
fmt : (float -> string) option;
}
| Imshow of {
data : Nx.float32_t;
stretch : stretch;
cmap : Cmap.t option;
vmin : float option;
vmax : float option;
}
| Contour of {
data : Nx.float32_t;
x0 : float;
x1 : float;
y0 : float;
y1 : float;
levels : [ `Num of int | `Values of float array ];
filled : bool;
cmap : Cmap.t option;
color : Color.t option;
line_width : float option;
label : string option;
alpha : float option;
}
(** The type for visual marks. Each constructor carries the data arrays
and visual properties for one layer. *)
type decoration =
| Title of string
| Xlabel of string
| Ylabel of string
| Xlim of float * float
| Ylim of float * float
| Xscale of scale
| Yscale of scale
| Xinvert
| Yinvert
| Grid_visible of bool
| Legend of legend_loc * int
| Xticks of (float * string) list
| Yticks of (float * string) list
| With_theme of Theme.t
| Xtick_format of (float -> string)
| Ytick_format of (float -> string)
| Frame of bool
(** The type for plot decorations. Applied via {!Decorated} nodes. *)
type t =
| Mark of mark
| Layers of t list
| Decorated of { inner : t; decorations : decoration list }
| Grid of { rows : t list list; gap : float }
(** The type for plot specifications. An immutable tree composed via mark
constructors, {!layers}, decoration functions, and {!grid_layout}. *)
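(* Example (sketch): a spec is built by composing the constructors below;
   with [x] and [y] of type [Nx.float32_t]:
   {[
     let spec =
       layers [ line ~x ~y ~label:"signal" (); hline ~y:0. () ]
       |> title "Signal" |> xlabel "t" |> legend
   ]} *)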
(** {1:marks Mark constructors} *)
val line :
x:Nx.float32_t ->
y:Nx.float32_t ->
?color:Color.t ->
?line_width:float ->
?line_style:line_style ->
?step:[ `Pre | `Post | `Mid ] ->
?marker:marker ->
?label:string ->
?alpha:float ->
unit ->
t
(** [line ~x ~y ()] is a line mark. *)
val point :
x:Nx.float32_t ->
y:Nx.float32_t ->
?color:Color.t ->
?color_by:Nx.float32_t ->
?size:float ->
?size_by:Nx.float32_t ->
?marker:marker ->
?label:string ->
?alpha:float ->
unit ->
t
(** [point ~x ~y ()] is a scatter mark. *)
val bar :
x:Nx.float32_t ->
height:Nx.float32_t ->
?width:float ->
?bottom:float ->
?color:Color.t ->
?label:string ->
?alpha:float ->
unit ->
t
(** [bar ~x ~height ()] is a bar mark. [bottom] defaults to [0.]. *)
val hist :
x:Nx.float32_t ->
?bins:[ `Num of int | `Edges of float array ] ->
?density:bool ->
?color:Color.t ->
?label:string ->
unit ->
t
(** [hist ~x ()] is a histogram mark. [bins] defaults to [`Num 10]. *)
val image : ?extent:float * float * float * float -> Nx.uint8_t -> t
(** [image ?extent data] is an image mark. When [extent] is
[(xmin, xmax, ymin, ymax)], the image is placed in data coordinates. *)
val text :
x:float ->
y:float ->
string ->
?color:Color.t ->
?font_size:float ->
unit ->
t
(** [text ~x ~y s ()] is a text mark at [(x, y)]. *)
val hline :
y:float ->
?color:Color.t ->
?line_width:float ->
?line_style:line_style ->
?label:string ->
?alpha:float ->
unit ->
t
(** [hline ~y ()] is a horizontal reference line. *)
val vline :
x:float ->
?color:Color.t ->
?line_width:float ->
?line_style:line_style ->
?label:string ->
?alpha:float ->
unit ->
t
(** [vline ~x ()] is a vertical reference line. *)
val abline :
slope:float ->
intercept:float ->
?color:Color.t ->
?line_width:float ->
?line_style:line_style ->
?label:string ->
?alpha:float ->
unit ->
t
(** [abline ~slope ~intercept ()] is a diagonal line [y = slope * x + intercept]
spanning the full plot area. *)
val fill_between :
x:Nx.float32_t ->
y1:Nx.float32_t ->
y2:Nx.float32_t ->
?where:Nx.float32_t ->
?color:Color.t ->
?alpha:float ->
?label:string ->
unit ->
t
(** [fill_between ~x ~y1 ~y2 ()] is a filled area between two curves. [where] is
a mask array: only fill where [where.(i) > 0.]. *)
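(* Example (sketch): shade a confidence band only inside a region of
   interest; [x], [lower], [upper] and the 0/1 mask [mask] are precomputed
   [Nx.float32_t] values:
   {[
     let band = fill_between ~x ~y1:lower ~y2:upper ~where:mask ~alpha:0.3 ()
   ]} *)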
val hspan :
y0:float ->
y1:float ->
?color:Color.t ->
?alpha:float ->
?label:string ->
unit ->
t
(** [hspan ~y0 ~y1 ()] is a horizontal shaded band. *)
val vspan :
x0:float ->
x1:float ->
?color:Color.t ->
?alpha:float ->
?label:string ->
unit ->
t
(** [vspan ~x0 ~x1 ()] is a vertical shaded band. *)
val errorbar :
x:Nx.float32_t ->
y:Nx.float32_t ->
yerr:
[ `Symmetric of Nx.float32_t | `Asymmetric of Nx.float32_t * Nx.float32_t ] ->
?xerr:
[ `Symmetric of Nx.float32_t | `Asymmetric of Nx.float32_t * Nx.float32_t ] ->
?color:Color.t ->
?line_width:float ->
?cap_size:float ->
?label:string ->
?alpha:float ->
unit ->
t
(** [errorbar ~x ~y ~yerr ()] is an error bar mark. *)
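(* Example (sketch): symmetric errors in y, asymmetric in x; [dy],
   [dx_lo] and [dx_hi] are [Nx.float32_t] arrays matching [x] and [y]:
   {[
     let eb =
       errorbar ~x ~y ~yerr:(`Symmetric dy)
         ~xerr:(`Asymmetric (dx_lo, dx_hi)) ~cap_size:3. ()
   ]} *)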
val heatmap :
data:Nx.float32_t ->
?annotate:bool ->
?cmap:Cmap.t ->
?vmin:float ->
?vmax:float ->
?fmt:(float -> string) ->
unit ->
t
(** [heatmap ~data ()] is a heatmap mark. [data] has shape [[|rows; cols|]]. *)
val imshow :
data:Nx.float32_t ->
?stretch:stretch ->
?cmap:Cmap.t ->
?vmin:float ->
?vmax:float ->
unit ->
t
(** [imshow ~data ()] is a colormapped image mark. [stretch] defaults to
[`Linear]. *)
val contour :
data:Nx.float32_t ->
x0:float ->
x1:float ->
y0:float ->
y1:float ->
?levels:[ `Num of int | `Values of float array ] ->
?filled:bool ->
?cmap:Cmap.t ->
?color:Color.t ->
?line_width:float ->
?label:string ->
?alpha:float ->
unit ->
t
(** [contour ~data ~x0 ~x1 ~y0 ~y1 ()] is a contour mark. [levels] defaults to
[`Num 8]. [filled] defaults to [false]. *)
(** {1:composition Composition} *)
val layers : t list -> t
(** [layers marks] overlays [marks] on shared axes. *)
(** {1:decorations Decorations} *)
val title : string -> t -> t
(** [title s t] adds plot title [s]. *)
val xlabel : string -> t -> t
(** [xlabel s t] adds x-axis label [s]. *)
val ylabel : string -> t -> t
(** [ylabel s t] adds y-axis label [s]. *)
val xlim : float -> float -> t -> t
(** [xlim lo hi t] fixes the x-axis range. *)
val ylim : float -> float -> t -> t
(** [ylim lo hi t] fixes the y-axis range. *)
val xscale : scale -> t -> t
(** [xscale s t] sets the x-axis scale. *)
val yscale : scale -> t -> t
(** [yscale s t] sets the y-axis scale. *)
val xinvert : t -> t
(** [xinvert t] inverts the x-axis direction (values increase right-to-left). *)
val yinvert : t -> t
(** [yinvert t] inverts the y-axis direction (values increase top-to-bottom). *)
val grid_lines : bool -> t -> t
(** [grid_lines visible t] shows or hides grid lines. *)
val legend : ?loc:legend_loc -> ?ncol:int -> t -> t
(** [legend t] shows the legend. [loc] defaults to {!Upper_right}. [ncol]
defaults to [1]; set higher for multi-column layouts. *)
val xticks : (float * string) list -> t -> t
(** [xticks ticks t] sets explicit x-axis tick positions and labels. *)
val yticks : (float * string) list -> t -> t
(** [yticks ticks t] sets explicit y-axis tick positions and labels. *)
val with_theme : Theme.t -> t -> t
(** [with_theme th t] overrides the rendering theme. *)
val xtick_format : (float -> string) -> t -> t
(** [xtick_format fmt t] formats x-axis tick labels with [fmt]. *)
val ytick_format : (float -> string) -> t -> t
(** [ytick_format fmt t] formats y-axis tick labels with [fmt]. *)
val frame : bool -> t -> t
(** [frame visible t] shows or hides the axis border rectangle. *)
val no_axes : t -> t
(** [no_axes t] hides the axis frame, ticks, and tick labels. The full panel
area is used for marks. Title is preserved. Useful for image grids. *)
(** {1:layout Layout} *)
val grid_layout : ?gap:float -> t list list -> t
(** [grid_layout rows] arranges specs in a grid. [gap] defaults to [0.05]. *)
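(* Example (sketch): a two-row page stacking independent panels [top] and
   [bottom], each of type [t]:
   {[
     let page = grid_layout ~gap:0.08 [ [ top ]; [ bottom ] ]
   ]} *)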
================================================
FILE: packages/hugin/lib/svg_backend.ml
================================================
(*---------------------------------------------------------------------------
Copyright (c) 2026 The Raven authors. All rights reserved.
SPDX-License-Identifier: ISC
---------------------------------------------------------------------------*)
(* SVG backend *)
(* Text measurer: approximate, assumes an average glyph advance of 0.6 em *)
let text_measurer ~(font : Theme.font) s =
let w = float (String.length s) *. font.size *. 0.6 in
let h = font.size in
(w, h)
(* Helpers *)
let color_to_rgb_string c =
let r, g, b, _ = Color.to_rgba c in
Printf.sprintf "rgb(%d,%d,%d)"
(Float.to_int (r *. 255.))
(Float.to_int (g *. 255.))
(Float.to_int (b *. 255.))
let add_fill buf = function
| None -> Buffer.add_string buf " fill=\"none\""
| Some c ->
Printf.bprintf buf " fill=\"%s\"" (color_to_rgb_string c);
let a = Color.alpha c in
if a < 1. then Printf.bprintf buf " fill-opacity=\"%.3g\"" a
let add_stroke buf = function
| None -> Buffer.add_string buf " stroke=\"none\""
| Some c ->
Printf.bprintf buf " stroke=\"%s\"" (color_to_rgb_string c);
let a = Color.alpha c in
if a < 1. then Printf.bprintf buf " stroke-opacity=\"%.3g\"" a
let text_anchor_string = function
| `Start -> "start"
| `Middle -> "middle"
| `End -> "end"
let dominant_baseline_string = function
| `Top -> "text-before-edge"
| `Middle -> "central"
| `Bottom -> "text-after-edge"
let escape_xml s =
let buf = Buffer.create (String.length s) in
String.iter
(function
| '<' -> Buffer.add_string buf "&lt;"
| '>' -> Buffer.add_string buf "&gt;"
| '&' -> Buffer.add_string buf "&amp;"
| '"' -> Buffer.add_string buf "&quot;"
| c -> Buffer.add_char buf c)
s;
Buffer.contents buf
(* Marker shapes *)
let marker_path shape size =
let hs = size /. 2. in
match shape with
| Spec.Circle ->
Printf.sprintf "M %g 0 A %g %g 0 1 1 %g 0 A %g %g 0 1 1 %g 0 Z" (-.hs) hs
hs hs hs hs (-.hs)
| Spec.Square ->
Printf.sprintf "M %g %g L %g %g L %g %g L %g %g Z" (-.hs) (-.hs) hs (-.hs)
hs hs (-.hs) hs
| Spec.Triangle ->
Printf.sprintf "M 0 %g L %g %g L %g %g Z" (-.hs) hs hs (-.hs) hs
| Spec.Plus ->
Printf.sprintf "M %g 0 L %g 0 M 0 %g L 0 %g" (-.hs) hs (-.hs) hs
| Spec.Star ->
let d = hs *. 0.707 in
Printf.sprintf
"M %g 0 L %g 0 M 0 %g L 0 %g M %g %g L %g %g M %g %g L %g %g" (-.hs) hs
(-.hs) hs (-.d) (-.d) d d d (-.d) (-.d) d
(* Primitive rendering — ids threaded through to avoid global state *)
type ids = { mutable clip_id : int; mutable marker_id : int }
let fresh_clip ids =
ids.clip_id <- ids.clip_id + 1;
Printf.sprintf "clip-%d" ids.clip_id
let fresh_marker ids =
ids.marker_id <- ids.marker_id + 1;
Printf.sprintf "marker-%d" ids.marker_id
let rec render_primitive ids buf = function
| Scene.Path { points; close; fill; stroke; line_width; dash } ->
if Array.length points < 2 then ()
else begin
Buffer.add_string buf "<path d=\"";
Array.iteri
(fun i (x, y) ->
if i = 0 then Printf.bprintf buf "M %g %g" x y
else Printf.bprintf buf " L %g %g" x y)
points;
if close then Buffer.add_string buf " Z";
Buffer.add_char buf '"';
add_fill buf fill;
add_stroke buf stroke;
if line_width > 0. then
Printf.bprintf buf " stroke-width=\"%g\"" line_width;
begin match dash with
| [] -> ()
| ds ->
Buffer.add_string buf " stroke-dasharray=\"";
List.iteri
(fun i d ->
if i > 0 then Buffer.add_char buf ',';
Printf.bprintf buf "%g" d)
ds;
Buffer.add_char buf '"'
end;
Buffer.add_string buf "/>\n"
end
| Scene.Markers { points; shape; size; sizes; fill; fills; stroke } ->
let stroke_only =
match shape with Spec.Plus | Spec.Star -> true | _ -> false
in
begin match (fills, sizes) with
| None, None ->
let id = fresh_marker ids in
let d = marker_path shape size in
Printf.bprintf buf "<defs><path id=\"%s\" d=\"%s\"" id d;
if stroke_only then begin
let stroke_c = match fill with Some _ -> fill | None -> stroke in
add_stroke buf stroke_c;
Printf.bprintf buf " stroke-width=\"%g\""
(Float.max 1. (size *. 0.15))
end
else begin
add_fill buf fill;
add_stroke buf stroke;
if stroke <> None then Buffer.add_string buf " stroke-width=\"1\""
end;
Buffer.add_string buf "/></defs>\n";
Array.iter
(fun (x, y) ->
Printf.bprintf buf "<use href=\"#%s\" x=\"%g\" y=\"%g\"/>\n" id x y)