Repository: lazear/types-and-programming-languages Branch: master Commit: 0787493713b4 Files: 97 Total size: 427.6 KB Directory structure: gitextract_94czkswh/ ├── .gitattributes ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ └── feature_request.md │ └── workflows/ │ └── rust.yml ├── .gitignore ├── .rustfmt.toml ├── .travis.yml ├── 01_arith/ │ ├── Cargo.toml │ └── src/ │ ├── lexer.rs │ ├── main.rs │ └── parser.rs ├── 02_lambda/ │ ├── Cargo.toml │ └── src/ │ ├── context.rs │ ├── lexer.rs │ ├── main.rs │ └── parser.rs ├── 03_typedarith/ │ ├── Cargo.toml │ └── src/ │ ├── ast.rs │ ├── lexer.rs │ ├── main.rs │ └── parser.rs ├── 04_stlc/ │ ├── .gitignore │ ├── Cargo.toml │ └── src/ │ ├── eval.rs │ ├── lexer.rs │ ├── main.rs │ ├── parser.rs │ ├── term.rs │ ├── typing.rs │ └── visitor.rs ├── 05_recon/ │ ├── Cargo.toml │ └── src/ │ ├── disjoint.rs │ ├── main.rs │ ├── mutation/ │ │ ├── mod.rs │ │ └── write_once.rs │ ├── naive.rs │ ├── parser.rs │ └── types.rs ├── 06_system_f/ │ ├── Cargo.toml │ ├── README.md │ ├── src/ │ │ ├── diagnostics.rs │ │ ├── eval.rs │ │ ├── macros.rs │ │ ├── main.rs │ │ ├── patterns/ │ │ │ └── mod.rs │ │ ├── syntax/ │ │ │ ├── lexer.rs │ │ │ ├── mod.rs │ │ │ └── parser.rs │ │ ├── terms/ │ │ │ ├── mod.rs │ │ │ └── visit.rs │ │ ├── types/ │ │ │ ├── mod.rs │ │ │ ├── patterns.rs │ │ │ └── visit.rs │ │ └── visit.rs │ └── test.sf ├── 07_system_fw/ │ ├── Cargo.toml │ ├── README.md │ ├── src/ │ │ ├── diagnostics.rs │ │ ├── elaborate.rs │ │ ├── functor.rs │ │ ├── hir/ │ │ │ ├── bidir.rs │ │ │ └── mod.rs │ │ ├── macros.rs │ │ ├── main.rs │ │ ├── stack.rs │ │ ├── syntax/ │ │ │ ├── ast.rs │ │ │ ├── lexer.rs │ │ │ ├── mod.rs │ │ │ ├── parser/ │ │ │ │ ├── README.md │ │ │ │ ├── decls.rs │ │ │ │ ├── exprs.rs │ │ │ │ ├── infix.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── patterns.rs │ │ │ │ └── types.rs │ │ │ ├── tokens.rs │ │ │ └── visit/ │ │ │ ├── mod.rs │ │ │ └── types.rs │ │ ├── terms.rs │ │ ├── typecheck.rs │ │ └── types.rs │ └── test.fw ├── Cargo.toml ├── 
LICENSE ├── README.md ├── util/ │ ├── .gitignore │ ├── Cargo.toml │ └── src/ │ ├── arena.rs │ ├── diagnostic.rs │ ├── lib.rs │ ├── span.rs │ └── unsafe_arena.rs ├── x1_bidir/ │ ├── Cargo.toml │ └── src/ │ ├── helpers.rs │ └── main.rs └── x2_dependent/ ├── Cargo.toml └── src/ └── main.rs ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitattributes ================================================ * text=auto ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: bug assignees: '' --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: 1. Go to '...' 2. Click on '....' 3. Scroll down to '....' 4. See error **Expected behavior** A clear and concise description of what you expected to happen. **Screenshots** If applicable, add screenshots to help explain your problem. **Additional context** Add any other context about the problem here. ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: '' labels: enhancement assignees: '' --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. 
================================================ FILE: .github/workflows/rust.yml ================================================ name: Rust on: [push, pull_request] jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - name: Build run: cargo build --verbose - name: Run tests run: | cargo test --verbose cargo run --bin system_f ./06_system_f/test.sf ================================================ FILE: .gitignore ================================================ /target **/*.rs.bk .vscode/ ================================================ FILE: .rustfmt.toml ================================================ wrap_comments = true max_width = 120 ================================================ FILE: .travis.yml ================================================ language: rust rust: - stable - nightly matrix: allow_failures: - rust: nightly script: - cargo build --verbose --all - cargo test --lib notifications: email: false ================================================ FILE: 01_arith/Cargo.toml ================================================ [package] name = "arith" version = "0.1.0" authors = ["Michael Lazear "] edition = "2018" [dependencies] util = { path = "../util" } ================================================ FILE: 01_arith/src/lexer.rs ================================================ use util::span::{Location, Span, Spanned}; use std::char; use std::iter::Peekable; use std::str::Chars; #[derive(Copy, Clone, Debug, PartialEq, PartialOrd)] pub enum Token { Int(u32), Succ, Pred, If, Then, Else, True, False, IsZero, Semicolon, LParen, RParen, Invalid, } #[derive(Copy, Clone, Debug, PartialEq, PartialOrd)] pub struct TokenSpan { pub kind: Token, pub span: Span, } impl std::ops::Deref for TokenSpan { type Target = Token; fn deref(&self) -> &Self::Target { &self.kind } } #[derive(Clone)] pub struct Lexer<'s> { input: Peekable>, current: Location, } impl<'s> Lexer<'s> { pub fn new(input: Chars<'s>) -> Lexer<'s> { Lexer { input: 
input.peekable(), current: Location { line: 0, col: 0, abs: 0, }, } } fn peek(&mut self) -> Option { self.input.peek().cloned() } /// Consume the next [`char`] and advance internal source position fn consume(&mut self) -> Option { match self.input.next() { Some('\n') => { self.current.line += 1; self.current.col = 0; self.current.abs += 1; Some('\n') } Some(ch) => { self.current.col += 1; self.current.abs += 1; Some(ch) } None => None, } } fn consume_while bool>(&mut self, pred: F) -> Spanned { let mut s = String::new(); let start = self.current; while let Some(n) = self.peek() { if pred(n) { match self.consume() { Some(ch) => s.push(ch), None => break, } } else { break; } } Spanned::new(Span::new(start, self.current), s) } /// Eat whitespace fn consume_delimiter(&mut self) { let _ = self.consume_while(char::is_whitespace); } fn number(&mut self) -> Option { let Spanned { data, span } = self.consume_while(char::is_numeric); let kind = Token::Int(data.parse::().expect("only numeric chars")); Some(TokenSpan { kind, span }) } fn keyword(&mut self) -> Option { let Spanned { data, span } = self.consume_while(|ch| ch.is_ascii_alphanumeric()); let kind = match data.as_ref() { "if" => Token::If, "then" => Token::Then, "else" => Token::Else, "true" => Token::True, "false" => Token::False, "succ" => Token::Succ, "pred" => Token::Pred, "iszero" => Token::IsZero, "zero" => Token::Int(0), _ => Token::Invalid, }; Some(TokenSpan { kind, span }) } fn eat(&mut self, ch: char, token: Token) -> Option { let loc = self.current; let n = self.consume()?; let kind = if n == ch { token } else { Token::Invalid }; Some(TokenSpan { span: Span::new(loc, self.current), kind, }) } fn lex(&mut self) -> Option { self.consume_delimiter(); match self.peek()? 
{ x if x.is_ascii_alphabetic() => self.keyword(), x if x.is_numeric() => self.number(), '(' => self.eat('(', Token::LParen), ')' => self.eat(')', Token::RParen), ';' => self.eat(';', Token::Semicolon), _ => self.eat(' ', Token::Invalid), } } } impl<'s> Iterator for Lexer<'s> { type Item = TokenSpan; fn next(&mut self) -> Option { self.lex() } } #[cfg(test)] mod test { use super::*; use Token::*; #[test] fn valid() { let input = "succ(succ(succ(0)))"; let expected = vec![Succ, LParen, Succ, LParen, Succ, LParen, Int(0), RParen, RParen, RParen]; let output = Lexer::new(input.chars()) .into_iter() .map(|t| t.kind) .collect::>(); assert_eq!(expected, output); } #[test] fn invalid() { let input = "succ(succ(succ(xyz)))"; let expected = vec![ Succ, LParen, Succ, LParen, Succ, LParen, Invalid, RParen, RParen, RParen, ]; let output = Lexer::new(input.chars()) .into_iter() .map(|t| t.kind) .collect::>(); assert_eq!(expected, output); } } ================================================ FILE: 01_arith/src/main.rs ================================================ mod lexer; mod parser; use parser::{Parser, Term}; #[derive(Copy, Clone, Debug, PartialEq, PartialOrd)] pub enum RuntimeError { NoRuleApplies, } impl Term { pub fn is_numeric(&self) -> bool { match self { Term::TmZero => true, Term::TmSucc(t) => t.is_numeric(), _ => false, } } pub fn is_normal(&self) -> bool { match self { Term::TmZero | Term::TmTrue | Term::TmFalse => true, _ => false, } } } pub fn eval1(t: Term) -> Result { use Term::*; let res = match t { TmIf(cond, csq, alt) => match *cond { TmFalse => *alt, TmTrue => *csq, _ => TmIf(Box::new(eval1(*cond)?), csq, alt), }, TmSucc(term) => TmSucc(Box::new(eval1(*term)?)), TmPred(term) => match *term { TmZero => TmZero, TmSucc(nv) => { if nv.is_numeric() { *nv } else { return Err(RuntimeError::NoRuleApplies); } } _ => TmPred(Box::new(eval1(*term)?)), }, TmIsZero(term) => match *term { TmZero => TmTrue, TmSucc(nv) => { if nv.is_numeric() { TmFalse } else { return 
Err(RuntimeError::NoRuleApplies); } } _ => TmIsZero(Box::new(eval1(*term)?)), }, _ => return Err(RuntimeError::NoRuleApplies), }; Ok(res) } pub fn eval(t: Term) -> Term { let mut r = t; while let Ok(tprime) = eval1(r.clone()) { r = tprime; if r.is_normal() { break; } } r } fn main() { println!("λ"); let input = "if iszero(succ(zero)) then false else succ(4)"; let mut p = Parser::new(input); while let Some(tm) = p.parse_term() { print!("{:?} ==> ", tm); println!("{:?}", eval(tm)); } let diag = p.diagnostic(); if diag.error_count() > 0 { println!("\n{} error(s) detected while parsing!", diag.error_count()); println!("{}", diag.emit()); } } ================================================ FILE: 01_arith/src/parser.rs ================================================ use crate::lexer::{Lexer, Token}; use std::iter::Peekable; use util::diagnostic::Diagnostic; use util::span::Span; #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum Term { TmTrue, TmFalse, TmIf(Box, Box, Box), TmZero, TmSucc(Box), TmPred(Box), TmIsZero(Box), } pub struct Parser<'s> { diagnostic: Diagnostic<'s>, /// [`Lexer`] impls [`Iterator`] over [`TokenSpan`], /// so we can just directly wrap it in a [`Peekable`] lexer: Peekable>, span: Span, } impl<'s> Parser<'s> { /// Create a new [`Parser`] for the input `&str` pub fn new(input: &'s str) -> Parser<'s> { Parser { diagnostic: Diagnostic::new(input), lexer: Lexer::new(input.chars()).peekable(), span: Span::default(), } } fn consume(&mut self) -> Option { let ts = self.lexer.next()?; self.span = ts.span; Some(ts.kind) } fn expect(&mut self, token: Token) -> Option { match self.consume()? 
{ t if t == token => Some(t), _ => None, } } fn parse_paren(&mut self) -> Option { let e = self.parse_term(); self.expect(Token::RParen); e } fn parse_if(&mut self) -> Option { let cond = self.parse_term()?; let _ = self.expect(Token::Then)?; let csq = self.parse_term()?; let _ = self.expect(Token::Else)?; let alt = self.parse_term()?; Some(Term::TmIf(Box::new(cond), Box::new(csq), Box::new(alt))) } pub fn parse_term(&mut self) -> Option { let kind = match self.consume()? { Token::False => Term::TmFalse, Token::True => Term::TmTrue, Token::Succ => Term::TmSucc(Box::new(self.parse_term()?)), Token::Pred => Term::TmPred(Box::new(self.parse_term()?)), Token::IsZero => Term::TmIsZero(Box::new(self.parse_term()?)), Token::If => return self.parse_if(), Token::LParen => return self.parse_paren(), Token::Semicolon => return self.parse_term(), Token::Int(x) => baptize(x), Token::Then | Token::Else | Token::RParen => { self.diagnostic.push("Out of place token", self.span); return self.parse_term(); } Token::Invalid => { self.diagnostic.push("Invalid token", self.span); return self.parse_term(); } }; Some(kind) } pub fn diagnostic(self) -> Diagnostic<'s> { self.diagnostic } } /// Convert from natural number to church encoding fn baptize(int: u32) -> Term { let mut num = Term::TmZero; for _ in 0..int { num = Term::TmSucc(Box::new(num)); } num } ================================================ FILE: 02_lambda/Cargo.toml ================================================ [package] name = "lambda" version = "0.1.0" authors = ["Michael Lazear "] edition = "2018" [dependencies] util = { path = "../util" } ================================================ FILE: 02_lambda/src/context.rs ================================================ use std::collections::VecDeque; #[derive(Clone, Debug, Default)] pub struct Context { inner: VecDeque, } impl Context { pub fn bind(&mut self, hint: String) -> (Context, usize) { if self.inner.contains(&hint) { self.bind(format!("{}'", hint)) } else { let 
mut ctx = self.clone(); let idx = ctx.size(); ctx.inner.push_front(hint); (ctx, idx) } } pub fn lookup(&self, key: String) -> Option { for (idx, s) in self.inner.iter().enumerate() { if key == *s { return Some(idx); } } None } pub fn size(&self) -> usize { self.inner.len() } } ================================================ FILE: 02_lambda/src/lexer.rs ================================================ use util::span::{Location, Span, Spanned}; use std::char; use std::iter::Peekable; use std::str::Chars; #[derive(Copy, Clone, Debug, PartialEq, PartialOrd)] pub enum Token { Var(char), LParen, RParen, Lambda, Dot, Invalid, } #[derive(Clone)] pub struct Lexer<'s> { input: Peekable>, current: Location, } impl<'s> Lexer<'s> { pub fn new(input: Chars<'s>) -> Lexer<'s> { Lexer { input: input.peekable(), current: Location { line: 0, col: 0, abs: 0, }, } } fn peek(&mut self) -> Option { self.input.peek().copied() } /// Consume the next [`char`] and advance internal source position fn consume(&mut self) -> Option { match self.input.next() { Some('\n') => { self.current.line += 1; self.current.col = 0; self.current.abs += 1; Some('\n') } Some(ch) => { self.current.col += 1; self.current.abs += 1; Some(ch) } None => None, } } fn consume_while bool>(&mut self, pred: F) -> Spanned { let mut s = String::new(); let start = self.current; while let Some(n) = self.peek() { if pred(n) { match self.consume() { Some(ch) => s.push(ch), None => break, } } else { break; } } Spanned::new(Span::new(start, self.current), s) } /// Eat whitespace fn consume_delimiter(&mut self) { let _ = self.consume_while(char::is_whitespace); } fn eat(&mut self, ch: char, token: Token) -> Option> { let loc = self.current; let n = self.consume()?; let kind = if n == ch { token } else { Token::Invalid }; Some(Spanned::new(Span::new(loc, self.current), kind)) } fn lex(&mut self) -> Option> { self.consume_delimiter(); match self.peek()? 
{ '(' => self.eat('(', Token::LParen), ')' => self.eat(')', Token::RParen), 'λ' => self.eat('λ', Token::Lambda), '.' => self.eat('.', Token::Dot), ch => self.eat(ch, Token::Var(ch)), } } } impl<'s> Iterator for Lexer<'s> { type Item = Spanned; fn next(&mut self) -> Option { self.lex() } } ================================================ FILE: 02_lambda/src/main.rs ================================================ mod context; mod lexer; mod parser; use parser::Parser; use context::Context; use parser::{RcTerm, Term}; fn shift1(d: isize, c: isize, tm: RcTerm) -> RcTerm { match &tm as &Term { Term::TmVar(sp, x) => { if *x as isize >= c { Term::TmVar(*sp, *x + d as usize).into() } else { Term::TmVar(*sp, *x).into() } } Term::TmAbs(sp, x) => Term::TmAbs(*sp, shift1(d, c + 1, x.clone())).into(), Term::TmApp(sp, a, b) => Term::TmApp(*sp, shift1(d, c, a.clone()), shift1(d, c, b.clone())).into(), } } fn shift(d: isize, tm: RcTerm) -> RcTerm { shift1(d, 0, tm) } fn subst_walk(j: isize, s: RcTerm, c: isize, t: RcTerm) -> RcTerm { match &t as &Term { Term::TmVar(_, x) => { if *x as isize == j + c { shift(c, s) } else { t } } Term::TmAbs(sp, tm) => Term::TmAbs(*sp, subst_walk(j, s, c + 1, tm.clone())).into(), Term::TmApp(sp, lhs, rhs) => Term::TmApp( *sp, subst_walk(j, s.clone(), c, lhs.clone()), subst_walk(j, s, c, rhs.clone()), ) .into(), } } fn subst(j: isize, s: RcTerm, tm: RcTerm) -> RcTerm { subst_walk(j, s, 0, tm) } fn term_subst_top(s: RcTerm, tm: RcTerm) -> RcTerm { shift(-1, subst(0, shift(1, s), tm)) } fn isval(_ctx: &Context, tm: RcTerm) -> bool { match &tm as &Term { Term::TmAbs(_, _) => true, _ => false, } } fn eval1(ctx: &Context, tm: RcTerm) -> RcTerm { match &tm as &Term { Term::TmApp(_, t, v) if isval(ctx, v.clone()) => { if let Term::TmAbs(_, t2) = &t as &Term { term_subst_top(v.clone(), t2.clone()) } else { panic!("No rule applies!") } } Term::TmApp(sp, v, t) if isval(ctx, v.clone()) => { let t_prime = eval1(ctx, t.clone()); Term::TmApp(*sp, v.clone(), 
t_prime).into() } Term::TmApp(sp, t1, t2) => { let t_prime = eval1(ctx, t1.clone()); Term::TmApp(*sp, t_prime, t2.clone()).into() } _ => panic!("No rule applies!"), } } fn main() { // let input = "(λ x. x x) (λ x. x x) λ x. λ y. y λ x. λ x. x"; // let input = "(λ x. (λ y. y) x) (λ x. x)"; let mut p = Parser::new(input); while let Some(tm) = p.parse_term() { println!("{:?}", tm); dbg!(eval1(p.ctx(), tm)); // dbg!(term_subst_top(Term::TmVar(Span::default(), 0).into(), tm)); } dbg!(p.ctx()); let diag = p.diagnostic(); if diag.error_count() > 0 { println!("\n{} error(s) detected while parsing!", diag.error_count()); println!("{}", diag.emit()); } } ================================================ FILE: 02_lambda/src/parser.rs ================================================ use crate::context::Context; use crate::lexer::{Lexer, Token}; use std::iter::Peekable; use std::ops::Deref; use std::rc::Rc; use util::diagnostic::Diagnostic; use util::span::*; #[derive(Clone, PartialEq, PartialOrd)] pub struct RcTerm(pub Rc); impl From for RcTerm { fn from(term: Term) -> RcTerm { RcTerm(Rc::new(term)) } } impl std::fmt::Debug for RcTerm { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{:?}", self.0) } } impl std::fmt::Debug for Term { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { Term::TmVar(_, v) => write!(f, "{}", v), Term::TmAbs(_, tm) => write!(f, "λ.{:?}", tm), Term::TmApp(_, t, b) => write!(f, "{:?} {:?}", t, b), } } } impl Deref for RcTerm { type Target = Term; fn deref(&self) -> &Self::Target { &self.0 } } #[derive(Clone, PartialEq, PartialOrd)] pub enum Term { TmVar(Span, usize), TmAbs(Span, RcTerm), TmApp(Span, RcTerm, RcTerm), } pub struct Parser<'s> { ctx: Context, diagnostic: Diagnostic<'s>, /// [`Lexer`] impls [`Iterator`] over [`TokenSpan`], /// so we can just directly wrap it in a [`Peekable`] lexer: Peekable>, span: Span, } impl<'s> Parser<'s> { /// Create a new [`Parser`] for the input `&str` pub fn 
new(input: &'s str) -> Parser<'s> { Parser { ctx: Context::default(), diagnostic: Diagnostic::new(input), lexer: Lexer::new(input.chars()).peekable(), span: Span::default(), } } fn consume(&mut self) -> Option> { let ts = self.lexer.next()?; self.span = ts.span; Some(ts) } fn expect(&mut self, token: Token) -> Option> { let spanned = self.consume()?; match spanned.data { t if t == token => Some(spanned), t => { self.diagnostic .push(format!("Expected token {:?}, found {:?}", token, t), spanned.span); None } } } fn peek(&mut self) -> Option { self.lexer.peek().map(|s| s.data) } fn lambda(&mut self) -> Option { let start = self.expect(Token::Lambda)?.span; let var = self.consume()?; // Bind variable into a new context before parsing the body // of the lambda abstraction let prev_ctx = self.ctx.clone(); let (ctx, _) = match var.data { Token::Var(ch) => { let (ctx, idx) = self.ctx.bind(format!("{}", ch)); (ctx, Term::TmVar(var.span, idx)) } x => { self.diagnostic .push(format!("Expected variable, found {:?}", x), var.span); return None; } }; self.ctx = ctx; let _ = self.expect(Token::Dot)?; let body = self.term()?; let end = self.span; // Return to previous context self.ctx = prev_ctx; Some(Term::TmAbs(start + end, body).into()) } fn term(&mut self) -> Option { match self.peek()? { Token::Lambda => self.lambda(), _ => self.application(), } } /// Parse an application of form: /// application = atom application' | atom /// application' = atom application' | empty fn application(&mut self) -> Option { let mut lhs = self.atom()?; let span = self.span; while let Some(rhs) = self.atom() { lhs = Term::TmApp(span + self.span, lhs, rhs).into(); } Some(lhs) } /// Parse an atomic term /// LPAREN term RPAREN | var fn atom(&mut self) -> Option { match self.peek()? 
{ Token::LParen => { self.expect(Token::LParen)?; let term = self.term()?; self.expect(Token::RParen)?; Some(term) } Token::Var(ch) => { let sp = self.consume()?.span; match self.ctx.lookup(format!("{}", ch)) { Some(idx) => Some(Term::TmVar(sp, idx).into()), None => { self.diagnostic.push(format!("Unbound variable {}", ch), sp); None } } } _ => None, } } pub fn parse_term(&mut self) -> Option { self.term() } pub fn ctx(&self) -> &Context { &self.ctx } pub fn diagnostic(self) -> Diagnostic<'s> { self.diagnostic } } ================================================ FILE: 03_typedarith/Cargo.toml ================================================ [package] name = "typedarith" version = "0.1.0" authors = ["Michael Lazear "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] util = { path = "../util" } ================================================ FILE: 03_typedarith/src/ast.rs ================================================ use std::ops::Deref; use std::rc::Rc; #[derive(Copy, Clone, Debug, PartialEq, PartialOrd)] pub enum Type { Nat, Bool, } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum Term { TmTrue, TmFalse, TmIf(RcTerm, RcTerm, RcTerm), TmZero, TmSucc(RcTerm), TmPred(RcTerm), TmIsZero(RcTerm), } #[derive(Copy, Clone, Debug, PartialEq, PartialOrd)] pub enum TyError { TypingError, } pub fn typing(tm: RcTerm) -> Result { match &tm as &Term { Term::TmTrue => Ok(Type::Bool), Term::TmFalse => Ok(Type::Bool), Term::TmZero => Ok(Type::Nat), Term::TmSucc(t) => match typing(t.clone()) { Ok(Type::Nat) => Ok(Type::Nat), _ => Err(TyError::TypingError), }, Term::TmPred(t) => match typing(t.clone()) { Ok(Type::Nat) => Ok(Type::Nat), _ => Err(TyError::TypingError), }, Term::TmIsZero(t) => match typing(t.clone()) { Ok(Type::Nat) => Ok(Type::Bool), _ => Err(TyError::TypingError), }, Term::TmIf(a, b, c) => match typing(a.clone()) { Ok(Type::Bool) => { let ty_b = typing(b.clone())?; let ty_c = 
typing(c.clone())?; if ty_b == ty_c { Ok(ty_b) } else { Err(TyError::TypingError) } } _ => Err(TyError::TypingError), }, } } #[derive(Clone, PartialEq, PartialOrd)] pub struct RcTerm(pub Rc); impl std::fmt::Debug for RcTerm { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{:?}", self.0) } } impl From for RcTerm { fn from(term: Term) -> RcTerm { RcTerm(Rc::new(term)) } } impl Deref for RcTerm { type Target = Term; fn deref(&self) -> &Self::Target { &self.0 } } ================================================ FILE: 03_typedarith/src/lexer.rs ================================================ use util::span::{Location, Span, Spanned}; use std::char; use std::iter::Peekable; use std::str::Chars; #[derive(Copy, Clone, Debug, PartialEq, PartialOrd)] pub enum Token { Int(u32), Succ, Pred, If, Then, Else, True, False, IsZero, Semicolon, LParen, RParen, Invalid, } #[derive(Copy, Clone, Debug, PartialEq, PartialOrd)] pub struct TokenSpan { pub kind: Token, pub span: Span, } impl std::ops::Deref for TokenSpan { type Target = Token; fn deref(&self) -> &Self::Target { &self.kind } } #[derive(Clone)] pub struct Lexer<'s> { input: Peekable>, current: Location, } impl<'s> Lexer<'s> { pub fn new(input: Chars<'s>) -> Lexer<'s> { Lexer { input: input.peekable(), current: Location { line: 0, col: 0, abs: 0, }, } } fn peek(&mut self) -> Option { self.input.peek().cloned() } /// Consume the next [`char`] and advance internal source position fn consume(&mut self) -> Option { match self.input.next() { Some('\n') => { self.current.line += 1; self.current.col = 0; self.current.abs += 1; Some('\n') } Some(ch) => { self.current.col += 1; self.current.abs += 1; Some(ch) } None => None, } } fn consume_while bool>(&mut self, pred: F) -> Spanned { let mut s = String::new(); let start = self.current; while let Some(n) = self.peek() { if pred(n) { match self.consume() { Some(ch) => s.push(ch), None => break, } } else { break; } } Spanned::new(Span::new(start, 
self.current), s) } /// Eat whitespace fn consume_delimiter(&mut self) { let _ = self.consume_while(char::is_whitespace); } fn number(&mut self) -> Option { let Spanned { data, span } = self.consume_while(char::is_numeric); let kind = Token::Int(data.parse::().expect("only numeric chars")); Some(TokenSpan { kind, span }) } fn keyword(&mut self) -> Option { let Spanned { data, span } = self.consume_while(|ch| ch.is_ascii_alphanumeric()); let kind = match data.as_ref() { "if" => Token::If, "then" => Token::Then, "else" => Token::Else, "true" => Token::True, "false" => Token::False, "succ" => Token::Succ, "pred" => Token::Pred, "iszero" => Token::IsZero, "zero" => Token::Int(0), _ => Token::Invalid, }; Some(TokenSpan { kind, span }) } fn eat(&mut self, ch: char, token: Token) -> Option { let loc = self.current; let n = self.consume()?; let kind = if n == ch { token } else { Token::Invalid }; Some(TokenSpan { span: Span::new(loc, self.current), kind, }) } fn lex(&mut self) -> Option { self.consume_delimiter(); match self.peek()? 
{ x if x.is_ascii_alphabetic() => self.keyword(), x if x.is_numeric() => self.number(), '(' => self.eat('(', Token::LParen), ')' => self.eat(')', Token::RParen), ';' => self.eat(';', Token::Semicolon), _ => self.eat(' ', Token::Invalid), } } } impl<'s> Iterator for Lexer<'s> { type Item = TokenSpan; fn next(&mut self) -> Option { self.lex() } } #[cfg(test)] mod test { use super::*; use Token::*; #[test] fn valid() { let input = "succ(succ(succ(0)))"; let expected = vec![Succ, LParen, Succ, LParen, Succ, LParen, Int(0), RParen, RParen, RParen]; let output = Lexer::new(input.chars()) .into_iter() .map(|t| t.kind) .collect::>(); assert_eq!(expected, output); } #[test] fn invalid() { let input = "succ(succ(succ(xyz)))"; let expected = vec![ Succ, LParen, Succ, LParen, Succ, LParen, Invalid, RParen, RParen, RParen, ]; let output = Lexer::new(input.chars()) .into_iter() .map(|t| t.kind) .collect::>(); assert_eq!(expected, output); } } ================================================ FILE: 03_typedarith/src/main.rs ================================================ mod ast; mod lexer; mod parser; use ast::*; use parser::Parser; fn main() { let input = "if iszero(succ(zero)) then pred(0) else succ(4)"; let mut p = Parser::new(input); while let Some(tm) = p.parse_term() { print!("{:?} ==> ", tm); println!("{:?}", typing(tm)); } let diag = p.diagnostic(); if diag.error_count() > 0 { println!("\n{} error(s) detected while parsing!", diag.error_count()); println!("{}", diag.emit()); } } ================================================ FILE: 03_typedarith/src/parser.rs ================================================ use crate::ast::{RcTerm, Term}; use crate::lexer::{Lexer, Token}; use std::iter::Peekable; use util::diagnostic::Diagnostic; use util::span::Span; pub struct Parser<'s> { diagnostic: Diagnostic<'s>, /// [`Lexer`] impls [`Iterator`] over [`TokenSpan`], /// so we can just directly wrap it in a [`Peekable`] lexer: Peekable>, span: Span, } impl<'s> Parser<'s> { /// 
Create a new [`Parser`] for the input `&str` pub fn new(input: &'s str) -> Parser<'s> { Parser { diagnostic: Diagnostic::new(input), lexer: Lexer::new(input.chars()).peekable(), span: Span::default(), } } fn consume(&mut self) -> Option { let ts = self.lexer.next()?; self.span = ts.span; Some(ts.kind) } fn expect(&mut self, token: Token) -> Option { match self.consume()? { t if t == token => Some(t), _ => None, } } fn parse_paren(&mut self) -> Option { let e = self.parse_term(); self.expect(Token::RParen); e } fn parse_if(&mut self) -> Option { let cond = self.parse_term()?; let _ = self.expect(Token::Then)?; let csq = self.parse_term()?; let _ = self.expect(Token::Else)?; let alt = self.parse_term()?; Some(Term::TmIf(cond, csq, alt).into()) } pub fn parse_term(&mut self) -> Option { let kind = match self.consume()? { Token::False => Term::TmFalse, Token::True => Term::TmTrue, Token::Succ => Term::TmSucc(self.parse_term()?), Token::Pred => Term::TmPred(self.parse_term()?), Token::IsZero => Term::TmIsZero(self.parse_term()?), Token::If => return self.parse_if(), Token::LParen => return self.parse_paren(), Token::Semicolon => return self.parse_term(), Token::Int(x) => baptize(x), Token::Then | Token::Else | Token::RParen => { self.diagnostic.push("Out of place token", self.span); return self.parse_term(); } Token::Invalid => { self.diagnostic.push("Invalid token", self.span); return self.parse_term(); } }; Some(kind.into()) } pub fn diagnostic(self) -> Diagnostic<'s> { self.diagnostic } } /// Convert from natural number to church encoding fn baptize(int: u32) -> Term { let mut num = Term::TmZero; for _ in 0..int { num = Term::TmSucc(num.into()); } num } ================================================ FILE: 04_stlc/.gitignore ================================================ /target **/*.rs.bk .vscode/ ================================================ FILE: 04_stlc/Cargo.toml ================================================ [package] name = "stlc" version = "0.1.0" 
authors = ["Michael Lazear "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] util = { path = "../util" } ================================================ FILE: 04_stlc/src/eval.rs ================================================ use super::term::*; use super::typing::Context; use super::visitor::{Direction, MutVisitor, Shifting, Substitution}; #[derive(Debug)] pub enum Error { NoRuleApplies, } #[inline] fn subst(mut val: Term, body: &mut Term) { Shifting::new(Direction::Up).visit_term(&mut val); Substitution::new(val).visit_term(body); Shifting::new(Direction::Down).visit_term(body); } fn value(ctx: &Context, term: &Term) -> bool { match term { Term::Unit | Term::True | Term::False | Term::Abs(_, _) | Term::Zero => true, Term::Succ(t) | Term::Pred(t) | Term::IsZero(t) => value(ctx, t), Term::Record(fields) => { for field in fields { if !value(ctx, &field.term) { return false; } } true } _ => false, } } fn eval1(ctx: &Context, term: Term) -> Result, Error> { match term { Term::App(t1, t2) => { if value(ctx, &t2) { match *t1 { Term::Abs(_, mut abs) => { subst(*t2, abs.as_mut()); Ok(abs) } _ => { let t_prime = eval1(ctx, *t1)?; Ok(Term::App(t_prime, t2).into()) } } } else if value(ctx, &t1) { let t_prime = eval1(ctx, *t2)?; Ok(Term::App(t1.clone(), t_prime).into()) } else { let t_prime = eval1(ctx, *t1)?; Ok(Term::App(t_prime, t2.clone()).into()) } } Term::If(guard, csq, alt) => match &*guard { Term::True => Ok(csq), Term::False => Ok(alt), _ => { let t_prime = eval1(ctx, *guard)?; Ok(Term::If(t_prime, csq, alt).into()) } }, Term::Let(bind, mut body) => { if value(ctx, &bind) { subst(*bind, body.as_mut()); Ok(body) } else { let t = eval1(ctx, *bind)?; Ok(Term::Let(t, body).into()) } } Term::Succ(t) => { let t_prime = eval1(ctx, *t)?; Ok(Term::Succ(t_prime).into()) } Term::Pred(t) => match t.as_ref() { Term::Zero => Ok(t.clone()), Term::Succ(n) => Ok(n.clone()), _ => 
Ok(Term::Pred(eval1(ctx, *t)?).into()), }, Term::IsZero(t) => match t.as_ref() { Term::Zero => Ok(Term::True.into()), Term::Succ(_) => Ok(Term::False.into()), _ => Ok(Term::IsZero(eval1(ctx, *t)?).into()), }, Term::Projection(rec, proj) => { if value(ctx, &rec) { match rec.as_ref() { Term::Record(rec) => crate::term::record_access(rec, &proj).ok_or(Error::NoRuleApplies), _ => Ok(Term::Projection(eval1(ctx, *rec)?, proj).into()), } } else { Ok(Term::Projection(eval1(ctx, *rec)?, proj).into()) } } _ => Err(Error::NoRuleApplies), } } pub fn eval(ctx: &Context, term: Term) -> Result { let mut tp = term; loop { println!(" -> {}", &tp); match eval1(ctx, tp.clone()) { Ok(r) => tp = *r, Err(e) => { return Ok(tp); } } } } ================================================ FILE: 04_stlc/src/lexer.rs ================================================ use util::span::{Location, Span}; use std::char; use std::iter::Peekable; use std::str::Chars; #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum TokenKind { Ident(String), Nat(u32), TyNat, TyBool, TyArrow, TyUnit, TypeDecl, Unit, True, False, Lambda, Succ, Pred, If, Then, Else, Let, In, IsZero, Semicolon, Colon, Comma, Proj, LParen, RParen, LBrace, RBrace, Equals, Bar, Invalid(char), Eof, } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Token { pub kind: TokenKind, pub span: Span, } impl Token { pub const fn new(kind: TokenKind, span: Span) -> Token { Token { kind, span } } } #[derive(Clone)] pub struct Lexer<'s> { input: Peekable>, current: Location, } impl<'s> Lexer<'s> { pub fn new(input: Chars<'s>) -> Lexer<'s> { Lexer { input: input.peekable(), current: Location { line: 0, col: 0, abs: 0, }, } } /// Peek at the next [`char`] in the input stream fn peek(&mut self) -> Option { self.input.peek().cloned() } /// Consume the next [`char`] and advance internal source position fn consume(&mut self) -> Option { match self.input.next() { Some('\n') => { self.current.line += 1; self.current.col = 0; self.current.abs += 1; 
Some('\n') } Some(ch) => { self.current.col += 1; self.current.abs += 1; Some(ch) } None => None, } } /// Consume characters from the input stream while pred(peek()) is true, /// collecting the characters into a string. fn consume_while bool>(&mut self, pred: F) -> (String, Span) { let mut s = String::new(); let start = self.current; while let Some(n) = self.peek() { if pred(n) { match self.consume() { Some(ch) => s.push(ch), None => break, } } else { break; } } (s, Span::new(start, self.current)) } /// Eat whitespace fn consume_delimiter(&mut self) { let _ = self.consume_while(char::is_whitespace); } /// Lex a natural number fn number(&mut self) -> Token { // Since we peeked at least one numeric char, we should always // have a string containing at least 1 single digit, as such // it is safe to call unwrap() on str::parse let (data, span) = self.consume_while(char::is_numeric); let n = data.parse::().unwrap(); Token::new(TokenKind::Nat(n), span) } /// Lex a reserved keyword or an identifier fn keyword(&mut self) -> Token { let (data, span) = self.consume_while(|ch| ch.is_ascii_alphanumeric()); let kind = match data.as_ref() { "if" => TokenKind::If, "then" => TokenKind::Then, "else" => TokenKind::Else, "true" => TokenKind::True, "false" => TokenKind::False, "succ" => TokenKind::Succ, "pred" => TokenKind::Pred, "iszero" => TokenKind::IsZero, "zero" => TokenKind::Nat(0), "Bool" => TokenKind::TyBool, "Nat" => TokenKind::TyNat, "Unit" => TokenKind::TyUnit, "unit" => TokenKind::Unit, "let" => TokenKind::Let, "in" => TokenKind::In, "type" => TokenKind::TypeDecl, _ => TokenKind::Ident(data), }; Token::new(kind, span) } /// Consume the next input character, expecting to match `ch`. 
/// Return a [`TokenKind::Invalid`] if the next character does not match, /// or the argument `kind` if it does fn eat(&mut self, ch: char, kind: TokenKind) -> Token { let loc = self.current; // Lexer::eat() should only be called internally after calling peek() // so we know that it's safe to unwrap the result of Lexer::consume() let n = self.consume().unwrap(); let kind = if n == ch { kind } else { TokenKind::Invalid(n) }; Token::new(kind, Span::new(loc, self.current)) } /// Return the next lexeme in the input as a [`Token`] pub fn lex(&mut self) -> Token { self.consume_delimiter(); let next = match self.peek() { Some(ch) => ch, None => return Token::new(TokenKind::Eof, Span::dummy()), }; match next { x if x.is_ascii_alphabetic() => self.keyword(), x if x.is_numeric() => self.number(), '(' => self.eat('(', TokenKind::LParen), ')' => self.eat(')', TokenKind::RParen), ';' => self.eat(';', TokenKind::Semicolon), ':' => self.eat(':', TokenKind::Colon), ',' => self.eat(',', TokenKind::Comma), '{' => self.eat('{', TokenKind::LBrace), '}' => self.eat('}', TokenKind::RBrace), '\\' => self.eat('\\', TokenKind::Lambda), 'λ' => self.eat('λ', TokenKind::Lambda), '.' => self.eat('.', TokenKind::Proj), '=' => self.eat('=', TokenKind::Equals), '|' => self.eat('|', TokenKind::Bar), '-' => { self.consume(); self.eat('>', TokenKind::TyArrow) } ch => self.eat(' ', TokenKind::Invalid(ch)), } } } impl<'s> Iterator for Lexer<'s> { type Item = Token; fn next(&mut self) -> Option { match self.lex() { Token { kind: TokenKind::Eof, .. 
} => None, tok => Some(tok), } } } #[cfg(test)] mod test { use super::*; use TokenKind::*; #[test] fn valid() { let input = "succ(succ(succ(0)))"; let expected = vec![Succ, LParen, Succ, LParen, Succ, LParen, Nat(0), RParen, RParen, RParen]; let output = Lexer::new(input.chars()) .into_iter() .map(|t| t.kind) .collect::>(); assert_eq!(expected, output); } #[test] fn invalid() { let input = "succ(succ(succ(xyz)))"; let expected = vec![ Succ, LParen, Succ, LParen, Succ, LParen, Ident("xyz".into()), RParen, RParen, RParen, ]; let output = Lexer::new(input.chars()) .into_iter() .map(|t| t.kind) .collect::>(); assert_eq!(expected, output); } } ================================================ FILE: 04_stlc/src/main.rs ================================================ #![allow(unused_variables)] mod eval; mod lexer; mod parser; mod term; mod typing; mod visitor; use term::Term; use typing::{Context, Type}; fn ev(ctx: &mut Context, term: Term) -> Result { let ty = match ctx.type_of(&term) { Ok(ty) => ty, Err(err) => { println!("Mistyped term {} => {:?}", term, err); return Err(eval::Error::NoRuleApplies); } }; let r = eval::eval(&ctx, term)?; // This is safe by our typing inference/induction rules // any well typed term t (checked previously) that evaluates to // a term t' [ t -> t' ] is also well typed // // Furthermore, Γ t:T, t ->* t' => t':T let ty_ = ctx.type_of(&r); // assert_eq!(ty_, ty); println!("===> {} -- {:?}\n", r, ty_); Ok(r) } fn parse(ctx: &mut Context, input: &str) { let mut p = parser::Parser::new(input); while let Some(tok) = p.parse_term() { let _ = ev(ctx, *tok); } let diag = p.diagnostic(); if diag.error_count() > 0 { println!("\n{} error(s) detected while parsing!", diag.error_count()); println!("{}", diag.emit()); } } fn main() { let mut root: Context = Context::default(); // parse( // &mut root, // "let not = (\\x: Bool. 
if x then false else true) in // let x = not false in // let y = not x in // if y then succ 0 else succ succ 0", // ); parse(&mut root, "let x = (\\y: Nat. y) in x"); parse(&mut root, "(\\x: Nat. (\\y: Nat. iszero x)) (succ 0) 0"); parse( &mut root, "(\\x: {a: Bool, b: Bool, c: Nat}. x.b) {a: true, b: false, c: 0}", ); // parse(&mut root, "let not = \\x: Bool. if x then false else true in {a: // 0, b: \\x: Bool. not x, c: unit}.b "); parse(&mut root, "type Struct // = {valid: Bool, number: Nat}"); parse(&mut root, "(\\x: Struct. // x.number) {valid: true, number: succ 0}"); parse( // &mut root, // "(\\x: Struct. x.number) {valid: false, number: succ 0}", // ) // dbg!(root); } ================================================ FILE: 04_stlc/src/parser.rs ================================================ use crate::lexer::{Lexer, Token, TokenKind}; use crate::term::{Field, Term}; use crate::typing::{Record, RecordField, Type}; use std::collections::VecDeque; use std::iter::Peekable; use util::diagnostic::Diagnostic; use util::span::*; #[derive(Clone, Debug, Default)] pub struct DeBruijnIndexer { inner: VecDeque, } impl DeBruijnIndexer { pub fn push(&mut self, hint: String) -> usize { if self.inner.contains(&hint) { self.push(format!("{}'", hint)) } else { let idx = self.inner.len(); self.inner.push_front(hint); idx } } pub fn pop(&mut self) { self.inner.pop_front(); } pub fn lookup(&self, key: &str) -> Option { for (idx, s) in self.inner.iter().enumerate() { if key == s { return Some(idx); } } None } } pub struct Parser<'s> { ctx: DeBruijnIndexer, diagnostic: Diagnostic<'s>, /// [`Lexer`] impls [`Iterator`] over [`TokenSpan`], /// so we can just directly wrap it in a [`Peekable`] lexer: Peekable>, span: Span, } impl<'s> Parser<'s> { /// Create a new [`Parser`] for the input `&str` pub fn new(input: &'s str) -> Parser<'s> { Parser { ctx: DeBruijnIndexer::default(), diagnostic: Diagnostic::new(input), lexer: Lexer::new(input.chars()).peekable(), span: Span::dummy(), } } 
fn consume(&mut self) -> Option<Token> {
        let ts = self.lexer.next()?;
        self.span = ts.span;
        Some(ts)
    }

    fn expect(&mut self, kind: TokenKind) -> Option<Token> {
        let tk = self.consume()?;
        match &tk.kind {
            t if t == &kind => Some(tk),
            _ => {
                self.diagnostic
                    .push(format!("Expected token {:?}, found {:?}", kind, tk.kind), tk.span);
                None
            }
        }
    }

    fn expect_term(&mut self) -> Option<Box<Term>> {
        match self.term() {
            Some(term) => Some(term),
            None => {
                let sp = self.peek_span();
                self.diagnostic.push("Expected term".to_string(), sp);
                None
            }
        }
    }

    fn peek(&mut self) -> Option<TokenKind> {
        self.lexer.peek().map(|tk| tk.kind.clone())
    }

    fn peek_span(&mut self) -> Span {
        self.lexer.peek().map(|s| s.span).unwrap_or(self.span)
    }

    fn lambda(&mut self) -> Option<Box<Term>> {
        let start = self.expect(TokenKind::Lambda)?;
        // Bind variable into a new context before parsing the body
        let var = self.ident()?;
        self.ctx.push(var);
        let _ = self.expect(TokenKind::Colon)?;
        let ty = self.ty()?;
        let _ = self.expect(TokenKind::Proj)?;
        let body = self.term()?;
        // Return to previous context
        self.ctx.pop();
        Some(Term::Abs(ty, body).into())
    }

    fn let_expr(&mut self) -> Option<Box<Term>> {
        let start = self.expect(TokenKind::Let)?;
        let var = self.ident()?;
        let _ = self.expect(TokenKind::Equals)?;
        // Parse the bound expression BEFORE bringing `var` into scope:
        // `let` is not recursive, so variables inside `bind` must refer to
        // enclosing binders. Pushing `var` first (as the old code did)
        // shifted every de Bruijn index inside `bind` by one, which
        // disagrees with Context::type_of, where the binder is typed in
        // the *outer* context.
        let bind = self.expect_term()?;
        // `var` is in scope only for the body.
        self.ctx.push(var);
        let _ = self.expect(TokenKind::In)?;
        let body = self.expect_term()?;
        self.ctx.pop();
        Some(Term::Let(bind, body).into())
    }

    fn ty_record_field(&mut self) -> Option<RecordField> {
        let ident = self.ident()?;
        self.expect(TokenKind::Colon)?;
        let ty = self.ty()?;
        Some(RecordField {
            ident,
            ty: Box::new(ty),
        })
    }

    fn ty_atom(&mut self) -> Option<Type> {
        match &self.peek()?
{ TokenKind::TyBool => { self.consume()?; Some(Type::Bool) } TokenKind::TyNat => { self.consume()?; Some(Type::Nat) } TokenKind::TyUnit => { self.consume()?; Some(Type::Unit) } TokenKind::LBrace => { self.consume()?; let mut fields = vec![self.ty_record_field()?]; while let Some(TokenKind::Comma) = self.peek() { self.expect(TokenKind::Comma)?; fields.push(self.ty_record_field()?); } self.expect(TokenKind::RBrace)?; Some(Type::Record(Record { // span, ident: String::new(), fields, })) } TokenKind::LParen => { self.consume()?; let r = self.ty()?; self.expect(TokenKind::RParen)?; Some(r) } _ => None, } } fn ty(&mut self) -> Option { let span = self.span; let mut lhs = match self.ty_atom() { Some(ty) => ty, None => { let sp = self.peek_span(); self.diagnostic.push("Expected type".to_string(), sp); return None; } }; if let Some(TokenKind::TyArrow) = self.peek() { self.consume()?; } while let Some(rhs) = self.ty_atom() { lhs = Type::Arrow(Box::new(lhs), Box::new(rhs)); if let Some(TokenKind::TyArrow) = self.peek() { self.consume()?; } else { break; } } Some(lhs) } /// Parse an application of form: /// application = atom application' | atom /// application' = atom application' | empty fn application(&mut self) -> Option> { let mut lhs = self.atom()?; let span = self.span; while let Some(rhs) = self.atom() { lhs = Term::App(lhs, rhs).into(); } if let Some(TokenKind::Proj) = self.peek() { self.expect(TokenKind::Proj)?; let accessor = self.ident()?; lhs = Term::Projection(lhs, accessor.into()).into(); } Some(lhs) } fn ident(&mut self) -> Option { let Token { kind, span } = self.consume()?; match kind { TokenKind::Ident(s) => Some(s), _ => { self.diagnostic .push(format!("Expected identifier, found {:?}", kind), span); None } } } fn record_field(&mut self) -> Option { let span = self.span; let ident = self.ident()?; self.expect(TokenKind::Colon)?; let term = self.expect_term()?; Some(Field { span: span + self.span, ident, term, }) } fn record(&mut self) -> Option> { let mut 
fields = vec![self.record_field()?];
        let span = self.span;
        while let Some(TokenKind::Comma) = self.peek() {
            self.expect(TokenKind::Comma)?;
            fields.push(self.record_field()?);
        }
        Some(Term::Record(fields).into())
    }

    fn if_expr(&mut self) -> Option<Box<Term>> {
        let _ = self.expect(TokenKind::If)?;
        let guard = self.expect_term()?;
        let _ = self.expect(TokenKind::Then)?;
        let csq = self.expect_term()?;
        let _ = self.expect(TokenKind::Else)?;
        let alt = self.expect_term()?;
        Some(Term::If(guard, csq, alt).into())
    }

    /// Parse an atomic term
    /// LPAREN term RPAREN | var
    fn atom(&mut self) -> Option<Box<Term>> {
        match self.peek()? {
            TokenKind::True => {
                self.expect(TokenKind::True)?;
                Some(Term::True.into())
            }
            TokenKind::False => {
                self.expect(TokenKind::False)?;
                Some(Term::False.into())
            }
            TokenKind::If => self.if_expr(),
            TokenKind::Let => self.let_expr(),
            TokenKind::Nat(i) => {
                self.consume()?;
                // Lower the literal to Peano form: `i` becomes i
                // applications of Succ to Zero. The previous code discarded
                // `i` and produced Zero for *every* numeric literal.
                let nat = (0..i).fold(Term::Zero, |t, _| Term::Succ(Box::new(t)));
                Some(nat.into())
            }
            TokenKind::Succ => {
                self.expect(TokenKind::Succ)?;
                Some(Term::Succ(self.term()?).into())
            }
            TokenKind::Pred => {
                self.expect(TokenKind::Pred)?;
                Some(Term::Pred(self.term()?).into())
            }
            TokenKind::IsZero => {
                self.expect(TokenKind::IsZero)?;
                Some(Term::IsZero(self.term()?).into())
            }
            TokenKind::LParen => {
                self.expect(TokenKind::LParen)?;
                let term = self.term()?;
                self.expect(TokenKind::RParen)?;
                Some(term)
            }
            TokenKind::LBrace => {
                self.expect(TokenKind::LBrace)?;
                let term = self.record()?;
                self.expect(TokenKind::RBrace)?;
                Some(term)
            }
            TokenKind::Unit => {
                self.expect(TokenKind::Unit)?;
                Some(Term::Unit.into())
            }
            TokenKind::Lambda => self.lambda(),
            TokenKind::Ident(s) => {
                let sp = self.consume()?.span;
                match self.ctx.lookup(&s) {
                    Some(idx) => Some(Term::Var(idx).into()),
                    None => {
                        self.diagnostic.push(format!("Unbound variable {}", s), sp);
                        None
                    }
                }
            }
            _ => None,
        }
    }

    fn term(&mut self) -> Option<Box<Term>> {
        match self.peek()?
{ // TokenKind::Lambda => self.lambda(), _ => self.application(), } } pub fn parse_term(&mut self) -> Option> { self.term() } pub fn diagnostic(self) -> Diagnostic<'s> { self.diagnostic } } ================================================ FILE: 04_stlc/src/term.rs ================================================ use crate::typing::Type; use std::fmt; use util::span::Span; #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Field { pub span: Span, pub ident: String, pub term: Box, } // pub enum Item { // Variant(VariantDecl), // Record(RecordDecl) // } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum Term { Unit, True, False, Zero, Succ(Box), Pred(Box), IsZero(Box), // DeBrujin index Var(usize), // Type of bound variable, and body of abstraction Abs(Type, Box), // Application (t1 t2) App(Box, Box), If(Box, Box, Box), Let(Box, Box), Record(Vec), Projection(Box, Box), } pub fn record_access(fields: &[Field], projection: &str) -> Option> { for f in fields { if f.ident == projection { return Some(f.term.clone()); } } None } impl fmt::Display for Term { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Term::Unit => write!(f, "unit"), Term::True => write!(f, "true"), Term::False => write!(f, "false"), Term::Zero => write!(f, "Z"), Term::Succ(t) => write!(f, "S({})", t), Term::Pred(t) => write!(f, "P({})", t), Term::IsZero(t) => write!(f, "IsZero({})", t), Term::Var(idx) => write!(f, "#{}", idx), Term::Abs(ty, body) => write!(f, "λ_:{:?}. 
{}", ty, body), Term::App(t1, t2) => write!(f, "({}) {}", t1, t2), Term::If(a, b, c) => write!(f, "if {} then {} else {}", a, b, c), Term::Let(bind, body) => write!(f, "let x={} in {}", bind, body), Term::Record(rec) => write!( f, "{{{}}}", rec.iter() .map(|x| format!("{}:{}", x.ident, x.term)) .collect::>() .join(",") ), Term::Projection(rec, idx) => write!(f, "{}.{}", rec, idx), } } } ================================================ FILE: 04_stlc/src/typing.rs ================================================ use crate::term::Term; use std::fmt; #[derive(Clone, PartialEq, PartialOrd)] pub enum Type { Unit, Bool, Nat, Arrow(Box, Box), Record(Record), } #[derive(Clone, PartialEq, PartialOrd)] pub struct Record { // pub span: Span, pub ident: String, pub fields: Vec, } #[derive(Clone, PartialEq, PartialOrd)] pub struct RecordField { // pub span: Span, pub ident: String, pub ty: Box, } impl fmt::Debug for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Type::Unit => write!(f, "Unit"), Type::Bool => write!(f, "Bool"), Type::Nat => write!(f, "Nat"), Type::Arrow(a, b) => write!(f, "({:?}->{:?})", a, b), Type::Record(r) => write!( f, "{} {{{}}}", r.ident, r.fields .iter() .map(|x| format!("{}:{:?}", x.ident, x.ty)) .collect::>() .join(",") ), } } } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum TypeError { Guard, ArmMismatch, ParameterMismatch, UnknownVariable(usize), ExpectedArrow, InvalidProjection, NotRecordType, } #[derive(Clone, Debug, Default)] /// A typing context, Γ /// /// Much simpler than the binding list suggested in the book, and used /// in the other directories, but this should be more efficient, and /// a vec is really overkill here pub struct Context<'a> { parent: Option<&'a Context<'a>>, ty: Option, } impl<'a> Context<'a> { pub fn add(&self, ty: Type) -> Context { if self.ty.is_none() { Context { parent: None, ty: Some(ty), } } else { Context { parent: Some(self), ty: Some(ty), } } } pub fn get(&self, idx: usize) -> 
Option<&Type> { if idx == 0 { self.ty.as_ref() } else if let Some(ctx) = self.parent { ctx.get(idx - 1) } else { None } } pub fn type_of(&self, term: &Term) -> Result { use Term::*; match term { Unit => Ok(Type::Unit), True => Ok(Type::Bool), False => Ok(Type::Bool), Zero => Ok(Type::Nat), Record(fields) => { let fields: Vec = fields .iter() .map(|f| { self.type_of(&f.term).map(|ty| { RecordField { // span: f.span, ident: f.ident.clone(), ty: Box::new(ty), } }) }) .collect::, TypeError>>()?; Ok(Type::Record(crate::typing::Record { // span: Span::dummy(), ident: String::new(), fields, })) } Projection(r, proj) => match self.type_of(r)? { Type::Record(self::Record { fields, .. }) => { for f in &fields { if &f.ident == proj.as_ref() { return Ok(*f.ty.clone()); } } Err(TypeError::InvalidProjection) } _ => Err(TypeError::NotRecordType), }, IsZero(t) => { if let Ok(Type::Nat) = self.type_of(t) { Ok(Type::Bool) } else { Err(TypeError::ParameterMismatch) } } Succ(t) | Pred(t) => { if let Ok(Type::Nat) = self.type_of(t) { Ok(Type::Nat) } else { Err(TypeError::ParameterMismatch) } } If(guard, csq, alt) => { if let Ok(Type::Bool) = self.type_of(guard) { let ty1 = self.type_of(csq)?; let ty2 = self.type_of(alt)?; if ty1 == ty2 { Ok(ty2) } else { Err(TypeError::ArmMismatch) } } else { Err(TypeError::Guard) } } Let(bind, body) => { let ty = self.type_of(bind)?; let ctx = self.add(ty); ctx.type_of(body) } Var(s) => match self.get(*s) { Some(ty) => Ok(ty.clone()), _ => Err(TypeError::UnknownVariable(*s)), }, Abs(ty, body) => { let ctx = self.add(ty.clone()); let ty_body = ctx.type_of(body)?; Ok(Type::Arrow(Box::new(ty.clone()), Box::new(ty_body))) } App(t1, t2) => { let ty1 = self.type_of(t1)?; let ty2 = self.type_of(t2)?; match ty1 { Type::Arrow(ty11, ty12) => { if *ty11 == ty2 { Ok(*ty12) } else { Err(TypeError::ParameterMismatch) } } _ => Err(TypeError::ExpectedArrow), } } } } } // impl<'a> Visitor for Context<'a> { // fn visit_var(&mut self, var: usize) { // self.get(var) // 
.cloned() // .ok_or(TypeError::UnknownVariable(var)) // } // fn visit_abs(&mut self, ty: Type, body: &Term) { // let ty = match ty { // Type::Var(name) => self // .types // .borrow() // .get(&name) // .cloned() // .ok_or(TypeError::Undefined(name))?, // x => x, // }; // let mut ctx = self.add(ty.clone()); // let ty_body: Result = body.accept(&mut ctx); // Ok(Type::Arrow(Box::new(ty), Box::new(ty_body?))) // } // fn visit_app(&mut self, t1: &Term, t2: &Term) { // let ty1 = t1.accept(self)?; // let ty2 = t2.accept(self)?; // match ty1 { // Type::Arrow(ty11, ty12) => { // if *ty11 == ty2 { // Ok(*ty12) // } else { // Err(TypeError::ParameterMismatch) // } // } // _ => Err(TypeError::ExpectedArrow), // } // } // fn visit_if( // &mut self, // guard: &Term, // csq: &Term, // alt: &Term, // ) { // if let Ok(Type::Bool) = guard.accept(self) { // let ty1 = csq.accept(self)?; // let ty2 = alt.accept(self)?; // if ty1 == ty2 { // Ok(ty2) // } else { // Err(TypeError::ArmMismatch) // } // } else { // Err(TypeError::Guard) // } // } // fn visit_let(&mut self, bind: &Term, body: &Term) { // // Dirty hack or correct behavior? 
// // // // We definitely need to correct var indices or how the context is // // working so that let binders can access names defined in an // // enclosing let-bound scope // let ty = bind // // .accept(&mut Shifting::new(Direction::Down)) // .accept(self)?; // let mut ctx = self.add(ty); // body.accept(&mut ctx) // } // fn visit_succ(&mut self, t: &Term) { // Ok(Type::Nat) // } // fn visit_pred(&mut self, t: &Term) { // Ok(Type::Nat) // } // fn visit_iszero(&mut self, t: &Term) { // Ok(Type::Bool) // } // fn visit_const(&mut self, c: &Term) { // match c.as_ref() { // Term::Unit => Ok(Type::Unit), // Term::Zero => Ok(Type::Nat), // Term::True | Term::False => Ok(Type::Bool), // _ => unreachable!(), // } // } // fn visit_record(&mut self, rec: &[RecordField]) { // let tys = rec // .iter() // .map(|f| f.data.accept(self).map(|ty| (f.label.clone(), ty))) // .collect::, Type)>, TypeError>>()?; // Ok(Type::Record(tys)) // } // fn visit_proj(&mut self, c: &Term, proj: Rc) { // match c.accept(self)? 
{ // Type::Record(fields) => { // for f in &fields { // if f.0 == proj { // return Ok(f.1.clone()); // } // } // Err(TypeError::InvalidProjection) // } // _ => Err(TypeError::NotRecordType), // } // } // fn visit_typedecl(&mut self, name: Rc, ty: &Type) { // self.bind(name.to_string(), ty.clone()); // Ok(Type::Unit) // } // } ================================================ FILE: 04_stlc/src/visitor.rs ================================================ use super::*; use crate::term::{Field, Term}; use std::default::Default; pub trait Visitor: Sized { fn visit_var(&mut self, var: usize); fn visit_abs(&mut self, ty: Type, body: &Term); fn visit_app(&mut self, t1: &Term, t2: &Term); fn visit_if(&mut self, guard: &Term, csq: &Term, alt: &Term); fn visit_let(&mut self, bind: &Term, body: &Term); fn visit_succ(&mut self, t: &Term); fn visit_pred(&mut self, t: &Term); fn visit_iszero(&mut self, t: &Term); fn visit_const(&mut self, c: &Term); fn visit_record(&mut self, c: &[Field]); fn visit_proj(&mut self, c: &Term, proj: &str); fn visit_typedecl(&mut self, name: &str, ty: &Type); } pub trait MutVisitor: Sized { fn visit_var(&mut self, var: &mut Term) {} fn visit_abs(&mut self, ty: &mut Type, body: &mut Term) { self.visit_term(body); } fn visit_app(&mut self, t1: &mut Term, t2: &mut Term) { self.visit_term(t1); self.visit_term(t2); } fn visit_if(&mut self, guard: &mut Term, csq: &mut Term, alt: &mut Term) { self.visit_term(guard); self.visit_term(csq); self.visit_term(alt); } fn visit_let(&mut self, bind: &mut Term, body: &mut Term) { self.visit_term(bind); self.visit_term(body); } fn visit_succ(&mut self, t: &mut Term) { self.visit_term(t); } fn visit_pred(&mut self, t: &mut Term) { self.visit_term(t); } fn visit_iszero(&mut self, t: &mut Term) { self.visit_term(t); } fn visit_const(&mut self, t: &mut Term) {} fn visit_record(&mut self, c: &mut [Field]) { for t in c { self.visit_term(t.term.as_mut()); } } fn visit_proj(&mut self, t: &mut Term, proj: &mut String) { 
self.visit_term(t); } fn visit_typedecl(&mut self, name: &mut String, ty: &mut Type) {} fn visit_term(&mut self, term: &mut Term) { walk_mut_term(self, term); } } fn walk_mut_term(visitor: &mut V, var: &mut Term) { match var { Term::Unit | Term::True | Term::False | Term::Zero => visitor.visit_const(var), Term::Succ(t) => visitor.visit_succ(t), Term::Pred(t) => visitor.visit_pred(t), Term::IsZero(t) => visitor.visit_iszero(t), Term::Var(_) => visitor.visit_var(var), Term::Abs(ty, body) => visitor.visit_abs(ty, body), Term::App(t1, t2) => visitor.visit_app(t1, t2), Term::If(a, b, c) => visitor.visit_if(a, b, c), Term::Let(bind, body) => visitor.visit_let(bind, body), Term::Record(rec) => visitor.visit_record(rec), Term::Projection(rec, idx) => visitor.visit_proj(rec, idx), } } #[derive(Copy, Clone, Debug)] pub enum Direction { Up, Down, } #[derive(Copy, Clone, Debug)] pub struct Shifting { pub cutoff: usize, pub direction: Direction, } impl Default for Shifting { fn default() -> Self { Shifting { cutoff: 0, direction: Direction::Up, } } } impl Shifting { pub fn new(direction: Direction) -> Self { Shifting { cutoff: 0, direction } } } impl MutVisitor for Shifting { fn visit_var(&mut self, var: &mut Term) { let n = match var { Term::Var(n) => n, _ => unreachable!(), }; if *n >= self.cutoff { // NB: Substracting 1 from the usize here is safe, as long as // a shift Down is only called *after* a shift/substitute cycle match self.direction { Direction::Up => *n += 1, Direction::Down => *n -= 1, } } } fn visit_abs(&mut self, ty_: &mut Type, body: &mut Term) { self.cutoff += 1; self.visit_term(body); self.cutoff -= 1; } fn visit_let(&mut self, bind: &mut Term, body: &mut Term) { self.cutoff += 1; self.visit_term(bind); self.visit_term(body); self.cutoff -= 1; } } #[derive(Debug)] pub struct Substitution { pub cutoff: usize, pub term: Term, } impl Substitution { pub fn new(term: Term) -> Substitution { Substitution { cutoff: 0, term } } } impl MutVisitor for Substitution { 
fn visit_var(&mut self, var: &mut Term) {
    // Substitute exactly the variable bound by the binder being
    // eliminated. Two defects fixed here:
    //   * a failed match guard falls through to the *next* arm, so the
    //     old `_ => unreachable!()` panicked on any Var below the cutoff
    //     (e.g. substituting under a nested abstraction);
    //   * the old guard `*n >= self.cutoff` also replaced variables
    //     *above* the cutoff, clobbering references to enclosing binders.
    //     The shift/substitute/unshift cycle (see Shifting) replaces only
    //     the index equal to the cutoff.
    match var {
        Term::Var(n) if *n == self.cutoff => {
            *var = self.term.clone();
        }
        // Variables bound by other binders are left untouched.
        Term::Var(_) => {}
        // walk_mut_term only dispatches Term::Var to visit_var.
        _ => unreachable!(),
    }
}
fn visit_abs(&mut self, ty_: &mut Type, body: &mut Term) {
    self.cutoff += 1;
    walk_mut_term(self, body);
    self.cutoff -= 1;
}
fn visit_let(&mut self, bind: &mut Term, body: &mut Term) {
    self.cutoff += 1;
    walk_mut_term(self, bind);
    walk_mut_term(self, body);
    self.cutoff -= 1;
}
}

================================================ FILE: 05_recon/Cargo.toml ================================================
[package]
name = "recon"
version = "0.1.0"
authors = ["Michael Lazear "]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
util = { path = "../util" }

================================================ FILE: 05_recon/src/disjoint.rs ================================================
//! A disjoint set using the union-find algorithm with path-compression
use std::cell::Cell;
use std::cmp::Ordering;
// NOTE(review): `partition` uses HashSet, which the extracted source never
// imported; both collections are brought into scope here.
use std::collections::{HashMap, HashSet};

/// One slot in the forest. `data` is Some only on representative elements
/// (union() moves merged payloads into the new representative).
struct SetElement<T> {
    data: Option<T>,
    rank: Cell<usize>,
    parent: Cell<usize>,
}

pub struct DisjointSet<T> {
    elements: Vec<SetElement<T>>,
    components: Cell<usize>,
}

impl<T> Default for DisjointSet<T> {
    fn default() -> Self {
        DisjointSet {
            elements: Vec::new(),
            components: Cell::new(0),
        }
    }
}

/// Opaque handle into the set, returned by [`DisjointSet::singleton`].
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Eq, Hash)]
pub struct Element(usize);

pub enum Choice {
    Left,
    Right,
}

impl<T> DisjointSet<T> {
    pub fn new() -> DisjointSet<T> {
        DisjointSet {
            elements: Vec::new(),
            components: Cell::new(0),
        }
    }

    /// Insert `data` as a new one-element component and return its handle.
    pub fn singleton(&mut self, data: T) -> Element {
        let n = self.elements.len();
        let elem = SetElement {
            data: Some(data),
            rank: Cell::new(0),
            parent: Cell::new(n),
        };
        self.elements.push(elem);
        self.components.replace(self.components.get() + 1);
        Element(n)
    }

    fn find_set(&self, id: usize) -> usize {
        // locate parent set
        let mut ptr = id;
        while ptr != self.elements[ptr].parent.get() {
            ptr = self.elements[ptr].parent.get();
        }
        // id is the representative element, return
        if ptr == id {
            return id;
        }
        // perform path compression
        let parent = ptr;
        ptr = id;
        while ptr != self.elements[ptr].parent.get() {
            ptr = self.elements[ptr].parent.replace(parent);
        }
        parent
    }

    pub fn find_repr(&self, element: Element) -> Element {
        Element(self.find_set(element.0))
    }

    pub fn data(&self, element: Element) -> Option<&T> {
        self.elements[element.0].data.as_ref()
    }

    pub fn find(&self, element: Element) -> &T {
        // Invariant that the representative element is always "Some"
        self.elements[self.find_set(element.0)]
            .data
            .as_ref()
            .expect("Invariant violated")
    }

    /// Merge the components of `a` and `b`, combining their payloads with
    /// `f`. No-op when they already share a representative.
    pub fn union<F: Fn(T, T) -> T>(&mut self, f: F, a: Element, b: Element) {
        let pa = self.find_set(a.0);
        let pb = self.find_set(b.0);
        if pa == pb {
            return;
        }
        // Move data out first to appease borrowck
        let a_data = self.elements[pa].data.take().expect("Invariant violated");
        let b_data = self.elements[pb].data.take().expect("Invariant violated");
        self.components.replace(self.components.get() - 1);
        // NOTE(review): textbook union-by-rank bumps rank only in the Equal
        // case; this code bumps it on every merge. Preserved as-is — it
        // affects only tree depth, not correctness.
        match self.elements[pa].rank.cmp(&self.elements[pb].rank) {
            Ordering::Equal => {
                self.elements[pa].data = Some(f(a_data, b_data));
                self.elements[pb].parent.replace(pa);
                self.elements[pa].rank.replace(self.elements[pa].rank.get() + 1);
            }
            Ordering::Less => {
                self.elements[pb].data = Some(f(a_data, b_data));
                self.elements[pa].parent.replace(pb);
                self.elements[pb].rank.replace(self.elements[pb].rank.get() + 1);
            }
            Ordering::Greater => {
                self.elements[pa].data = Some(f(a_data, b_data));
                self.elements[pb].parent.replace(pa);
                self.elements[pa].rank.replace(self.elements[pa].rank.get() + 1);
            }
        }
    }

    /// Return one payload reference per live component.
    pub fn partition(&self) -> Vec<&T> {
        let mut v = HashSet::new();
        for idx in 0..self.elements.len() {
            v.insert(self.find_set(idx));
        }
        v.into_iter()
            .map(|idx| self.elements[idx].data.as_ref().unwrap())
            .collect()
    }
}

use super::*;

type Variable = Element;

/// A node in the unification graph: either a still-unknown type variable
/// or a constructor applied to interned arguments.
#[derive(Debug, Clone)]
pub enum Unification {
    Unknown(TypeVar),
    Constr(Tycon, Vec<Variable>),
}

impl Unification {
    fn is_var(&self) -> bool {
        match self {
            Self::Unknown(_) => true,
            _ => false,
        }
    }
} #[derive(Debug, Default)] pub struct Unifier { set: disjoint::DisjointSet, map: HashMap, } impl Unifier { pub fn new() -> Unifier { Unifier { set: DisjointSet::new(), map: HashMap::default(), } } pub fn occurs_check(&self, v: TypeVar, u: &Unification) -> bool { match u { Unification::Unknown(x) => *x == v, Unification::Constr(_, vars) => vars.iter().any(|x| self.occurs_check(v, self.set.find(*x))), } } pub fn decode(&self, uni: &Unification) -> Type { match uni { Unification::Unknown(x) => Type::Var(*x), Unification::Constr(tc, vars) => { Type::Con(*tc, vars.into_iter().map(|v| self.decode(self.set.find(*v))).collect()) } } } pub fn intern(&mut self, ty: Type) -> Variable { if let Some(v) = self.map.get(&ty) { return *v; } let v = match &ty { Type::Var(x) => self.set.singleton(Unification::Unknown(*x)), Type::Con(tc, vars) => { let vars = vars.into_iter().cloned().map(|v| self.intern(v)).collect(); self.set.singleton(Unification::Constr(*tc, vars)) } }; self.map.insert(ty, v); v } fn var_bind(&mut self, v: TypeVar, v_: Variable, u: &Unification, u_: Variable) -> Result<(), String> { if self.occurs_check(v, u) { return Err(format!("Failed occurs check {:?} {:?}", v, u)); } self.set.union( |a, b| match (a, b) { (a @ Unification::Constr(_, _), _) => a, (_, b) => b, }, u_, v_, ); Ok(()) } pub fn subst(&self) -> HashMap { let mut map = HashMap::new(); for (ty, var) in &self.map { match ty { Type::Var(x) => { map.insert(*x, self.decode(self.set.find(*var))); } _ => {} } } map } pub fn unify(&mut self, a_: Variable, b_: Variable) -> Result<(), String> { if a_ == b_ { return Ok(()); } if a_ == self.set.find_repr(b_) || b_ == self.set.find_repr(a_) { return Ok(()); } let a = self.set.find(a_).clone(); let b = self.set.find(b_).clone(); use Unification::*; match (a, b) { (Unknown(a), b) => self.var_bind(a, a_, &b, b_), (a, Unknown(b)) => self.var_bind(b, b_, &a, a_), (Constr(a, a_vars), Constr(b, b_vars)) => { if a != b { return Err(format!("Can't unify constructors {:?} 
and {:?}", a, b)); } if a_vars.len() != b_vars.len() { return Err(format!("Can't unify argument lists {:?} and {:?}", a_vars, b_vars)); } for (c, d) in a_vars.into_iter().zip(b_vars) { self.set.union( |a, b| match (a, b) { (a @ Unification::Constr(_, _), _) => a, (_, b) => b, }, c, d, ); } Ok(()) } } } } pub fn solve>(iter: I) -> Result, String> { let mut un = Unifier::new(); for (a, b) in iter { let a = un.intern(a); let b = un.intern(b); un.unify(a, b)?; } let mut map = HashMap::new(); for (ty, var) in &un.map { match ty { Type::Var(x) => { map.insert(*x, un.decode(un.set.find(*var))); } _ => {} } } Ok(map) } impl std::fmt::Debug for DisjointSet { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let part = self.partition(); writeln!(f, "{{")?; for values in part { write!(f, "\t{:?}\n", values)?; } writeln!(f, "}}") } } ================================================ FILE: 05_recon/src/main.rs ================================================ use std::collections::{HashMap, HashSet}; pub mod disjoint; pub mod mutation; pub mod naive; pub mod parser; pub mod types; use types::*; #[derive(Debug)] pub enum Term { Unit, Bool(bool), Int(usize), Var(usize, String), Abs(Box), App(Box, Box), Let(Box, Box), If(Box, Box, Box), } #[derive(Debug)] pub enum TypedTerm { Unit, Bool(bool), Int(usize), Var(usize, String), Abs(Box), App(Box, Box), Let(Box, Box), If(Box, Box, Box), } #[derive(Debug)] pub struct SystemF { expr: TypedTerm, ty: T, } pub enum Constraint { Eq(Type, Type), Inst(Type, Scheme), Gen(Type, Vec, Type), } #[derive(Default, Debug)] struct Elaborator { exist: TypeVar, context: Vec, constraints: Vec<(Type, Type)>, uni: disjoint::Unifier, } impl SystemF { fn new(expr: TypedTerm, ty: Type) -> SystemF { SystemF { expr, ty } } fn de(self) -> (TypedTerm, Type) { (self.expr, self.ty) } } impl Elaborator { fn fresh(&mut self) -> TypeVar { let ex = self.exist; self.exist.0 += 1; ex } fn ftv(&self) -> HashSet { let mut set = HashSet::new(); for s in 
&self.context { set.extend(s.ftv()); } set } fn get_scheme(&self, index: usize) -> Option<&Scheme> { for (idx, scheme) in self.context.iter().rev().enumerate() { if idx == index { return Some(scheme); } } None } fn generalize(&mut self, ty: Type) -> Scheme { let set: HashSet = ty.ftv().difference(&self.ftv()).copied().collect(); if set.is_empty() { Scheme::Mono(ty) } else { Scheme::Poly(set.into_iter().collect(), ty) } } fn instantiate(&mut self, scheme: Scheme) -> Type { match scheme { Scheme::Mono(ty) => ty, Scheme::Poly(vars, ty) => { let freshv: Vec = (0..vars.len()).map(|_| self.fresh()).collect(); let map = vars .into_iter() .zip(freshv.iter()) .map(|(v, f)| (v, Type::Var(*f))) .collect::>(); ty.apply(&map) } } } fn push(&mut self, ty: (Type, Type)) { let a = self.uni.intern(ty.0); let b = self.uni.intern(ty.1); self.uni.unify(a, b).unwrap(); } fn elaborate(&mut self, term: &Term) -> SystemF { // dbg!(term); match term { Term::Unit => SystemF::new(TypedTerm::Unit, Type::Con(T_UNIT, vec![])), Term::Bool(b) => SystemF::new(TypedTerm::Bool(*b), Type::Con(T_BOOL, vec![])), Term::Int(i) => SystemF::new(TypedTerm::Int(*i), Type::Con(T_INT, vec![])), // x has type T iff T is an instance of the type scheme associated with x Term::Var(x, s) => { let scheme = self.get_scheme(*x).cloned().expect("Unbound variable!"); let ty = self.instantiate(scheme.clone()); SystemF::new(TypedTerm::Var(*x, s.clone()), ty) } Term::Abs(body) => { let arg = self.fresh(); self.context.push(Scheme::Mono(Type::Var(arg))); let (body, ty) = self.elaborate(body).de(); self.context.pop(); let arrow = Type::arrow(Type::Var(arg), ty.clone()); SystemF::new(TypedTerm::Abs(Box::new(SystemF::new(body, ty))), arrow) } // t1 t2 has type T iff for some X2, t1 has type X2 -> T and t2 has type X2 Term::App(t1, t2) => { let (t1, ty1) = self.elaborate(t1).de(); let (t2, ty2) = self.elaborate(t2).de(); let v = self.fresh(); self.push((ty1.clone(), Type::arrow(ty2.clone(), Type::Var(v)))); SystemF::new( 
TypedTerm::App(Box::new(SystemF::new(t1, ty1)), Box::new(SystemF::new(t2, ty2))), Type::Var(v), ) } Term::Let(t1, t2) => { let (t1, ty1) = self.elaborate(t1).de(); // let sub = disjoint::solve(self.constraints.drain(..)).unwrap(); // for (a, b) in self.constraints.drain(..) { // let a = self.uni.intern(a); // let b = self.uni.intern(b); // self.uni.unify(a, b).unwrap(); // } let sub = self.uni.subst(); self.context = self.context.drain(..).map(|sch| sch.apply(&sub)).collect(); let scheme = self.generalize(ty1.clone().apply(&sub)); self.context.push(scheme); let (t2, ty2) = self.elaborate(t2).de(); self.context.pop(); SystemF::new( TypedTerm::Let(Box::new(SystemF::new(t1, ty1)), Box::new(SystemF::new(t2, ty2.clone()))), ty2, ) } Term::If(t1, t2, t3) => { let (t1, ty1) = self.elaborate(t1).de(); let (t2, ty2) = self.elaborate(t2).de(); let (t3, ty3) = self.elaborate(t3).de(); let fresh = self.fresh(); self.push((ty1.clone(), Type::bool())); self.push((ty2.clone(), Type::Var(fresh))); self.push((ty3.clone(), Type::Var(fresh))); SystemF::new( TypedTerm::If( Box::new(SystemF::new(t1, ty1)), Box::new(SystemF::new(t2, ty2)), Box::new(SystemF::new(t3, ty3)), ), Type::Var(fresh), ) } } } } impl TypedTerm { fn subst(self, s: &HashMap) -> TypedTerm { use TypedTerm::*; match self { Abs(a) => Abs(Box::new(a.subst(s))), App(a, b) => App(Box::new(a.subst(s)), Box::new(b.subst(s))), Let(a, b) => Let(Box::new(a.subst(s)), Box::new(b.subst(s))), If(a, b, c) => If(Box::new(a.subst(s)), Box::new(b.subst(s)), Box::new(c.subst(s))), x => x, } } } impl SystemF { fn subst(self, s: &HashMap) -> SystemF { SystemF { expr: self.expr.subst(s), ty: self.ty.apply(s), } } } fn main() { use std::io::prelude::*; use std::time::{Duration, Instant}; let input = "fn m. let y = m in let x = y true in x"; let input = " let id = fn x. 
x in let g = id id in let f = id true in let h = (id id) 1 in let j = id 10 in g f"; let tm = parser::Parser::new(input).parse_term().unwrap(); let start = Instant::now(); let mut gen = mutation::Elaborator::default(); let tm = gen.elaborate(&tm); // let sub = gen.uni.subst(); // let sub = disjoint.solve(gen.constraints); // let sub = disjoint::solve(gen.constraints.into_iter()); let end1 = start.elapsed().as_micros(); println!("{:?} {:?}", end1, tm); loop { let mut buffer = String::new(); print!("repl: "); std::io::stdout().flush().unwrap(); std::io::stdin().read_to_string(&mut buffer).unwrap(); // let mut gen = Elaborator::default(); match parser::Parser::new(&buffer).parse_term() { Some(tm) => { // let (tm, ty) = gen.elaborate(&tm).de(); let mut e = mutation::Elaborator::default(); dbg!(e.elaborate(&tm)); // let mut sub = HashMap::new(); // println!("{:?}", gen.constraints); // for (a, b) in &gen.constraints { // let tmp = unify(a.clone().apply(&sub), b.clone().apply(&sub)).unwrap(); // sub = compose(tmp, sub); // } // let sub = disjoint::solve(gen.constraints.clone()); // println!("{:?}", sub); // println!("tm {:#?} :{:?}", tm, ty); // println!("tm {:#?} :{:?}", tm.subst(&sub), ty.apply(&sub)); // dbg!(sub); } None => println!("parse error!"), } } } ================================================ FILE: 05_recon/src/mutation/mod.rs ================================================ use super::{Term, T_ARROW, T_BOOL, T_INT, T_UNIT}; use std::collections::{HashMap, HashSet, VecDeque}; use std::rc::Rc; mod write_once; use write_once::WriteOnce; #[derive(Debug, Clone, PartialEq)] pub struct TypeVar { exist: usize, data: Rc>, } #[derive(Debug, Clone, PartialEq)] pub enum Type { Var(TypeVar), Con(super::Tycon, Vec), } #[derive(Debug, Clone)] pub enum Scheme { Mono(Type), Poly(Vec, Type), } #[derive(Debug)] pub enum TypedTerm { Unit, Bool(bool), Int(usize), Var(usize, String), Abs(Box), App(Box, Box), Let(Box, Box), If(Box, Box, Box), } #[derive(Debug)] pub struct 
SystemF { expr: TypedTerm, ty: Type, } impl SystemF { fn new(expr: TypedTerm, ty: Type) -> SystemF { SystemF { expr, ty } } } impl Type { fn ftv(&self, rank: usize) -> HashSet { let mut set = HashSet::new(); let mut queue = VecDeque::new(); queue.push_back(self); while let Some(ty) = queue.pop_front() { match ty { Type::Var(x) => match x.data.get() { None => { if x.data.get_rank() > rank { set.insert(x.exist); } } Some(link) => { queue.push_back(link); } }, Type::Con(_, tys) => { for ty in tys { queue.push_back(ty); } } } } set } fn apply(self, map: &HashMap) -> Type { match self { Type::Var(x) => match x.data.get() { Some(ty) => ty.clone().apply(map), None => map.get(&x.exist).cloned().unwrap_or(Type::Var(x)), }, Type::Con(tc, vars) => Type::Con(tc, vars.into_iter().map(|ty| ty.apply(map)).collect()), } } } impl Type { pub fn arrow(a: Type, b: Type) -> Type { Type::Con(T_ARROW, vec![a, b]) } pub fn bool() -> Type { Type::Con(T_BOOL, vec![]) } pub fn de_arrow(&self) -> (&Type, &Type) { match self { Type::Con(T_ARROW, v) => (&v[0], &v[1]), _ => panic!("Not arrow type! 
{:?}", self), } } } pub fn occurs_check(v: &TypeVar, ty: &Type) -> bool { match ty { Type::Var(x) => { if let Some(info) = x.data.get() { occurs_check(v, &info) } else { let min_rank = x.data.get_rank().min(v.data.get_rank()); if min_rank != x.data.get_rank() { println!("promoting type var {:?} {}->{}", x, x.data.get_rank(), min_rank); x.data.set_rank(min_rank); } x.exist == v.exist } } Type::Con(_, vars) => vars.iter().any(|x| occurs_check(v, x)), } } fn var_bind(v: &TypeVar, ty: &Type) -> Result<(), String> { if occurs_check(&v, ty) { return Err(format!("Failed occurs check {:?} {:?}", v, ty)); } v.data.set(ty.clone()).unwrap(); Ok(()) } fn unify_type(a: &Type, b: &Type) -> Result<(), String> { match (a, b) { (Type::Var(a), b) => match a.data.get() { Some(ty) => unify_type(ty, b), None => var_bind(a, b), }, (a, Type::Var(b)) => match b.data.get() { Some(ty) => unify_type(a, ty), None => var_bind(b, a), }, (Type::Con(a, a_args), Type::Con(b, b_args)) => { if a != b { return Err(format!("Can't unify constructors {:?} and {:?}", a, b)); } if a_args.len() != b_args.len() { return Err(format!("Can't unify argument lists {:?} and {:?}", a_args, b_args)); } for (c, d) in a_args.into_iter().zip(b_args) { unify_type(c, d)?; } Ok(()) } } } #[derive(Default, Debug)] pub struct Elaborator { exist: usize, rank: usize, context: Vec, } impl Elaborator { fn fresh(&mut self) -> TypeVar { let ex = self.exist; self.exist += 1; TypeVar { exist: ex, data: Rc::new(WriteOnce::with_rank(self.rank)), } } fn get_scheme(&self, index: usize) -> Option<&Scheme> { for (idx, scheme) in self.context.iter().rev().enumerate() { if idx == index { return Some(scheme); } } None } fn generalize(&mut self, ty: Type) -> Scheme { let set: HashSet = ty.ftv(self.rank); if set.is_empty() { Scheme::Mono(ty) } else { Scheme::Poly(set.into_iter().collect(), ty) } } fn instantiate(&mut self, scheme: Scheme) -> Type { match scheme { Scheme::Mono(ty) => ty, Scheme::Poly(vars, ty) => { let map = vars .into_iter() 
.map(|v| (v, Type::Var(self.fresh()))) .collect::>(); ty.apply(&map) } } } pub fn elaborate(&mut self, term: &Term) -> SystemF { match term { Term::Unit => SystemF::new(TypedTerm::Unit, Type::Con(T_UNIT, vec![])), Term::Bool(b) => SystemF::new(TypedTerm::Bool(*b), Type::Con(T_BOOL, vec![])), Term::Int(i) => SystemF::new(TypedTerm::Int(*i), Type::Con(T_INT, vec![])), Term::Var(x, s) => { let scheme = self.get_scheme(*x).cloned().expect("Unbound variable!"); let ty = self.instantiate(scheme.clone()); SystemF::new(TypedTerm::Var(*x, s.clone()), ty) } Term::Abs(body) => { let arg = self.fresh(); self.context.push(Scheme::Mono(Type::Var(arg.clone()))); let body = self.elaborate(body); self.context.pop(); let arrow = Type::arrow(Type::Var(arg), body.ty.clone()); SystemF::new(TypedTerm::Abs(Box::new(body)), arrow) } Term::App(t1, t2) => { let t1 = self.elaborate(t1); let t2 = self.elaborate(t2); let v = self.fresh(); unify_type(&t1.ty, &Type::arrow(t2.ty.clone(), Type::Var(v.clone()))).unwrap(); SystemF::new(TypedTerm::App(Box::new(t1), Box::new(t2)), Type::Var(v)) } Term::Let(t1, t2) => { self.rank += 1; let t1 = self.elaborate(t1); self.rank -= 1; let scheme = self.generalize(t1.ty.clone()); self.context.push(scheme); let t2 = self.elaborate(t2); self.context.pop(); let ty = t2.ty.clone(); SystemF::new(TypedTerm::Let(Box::new(t1), Box::new(t2)), ty) } Term::If(t1, t2, t3) => { let t1 = self.elaborate(t1); let t2 = self.elaborate(t2); let t3 = self.elaborate(t3); unify_type(&t1.ty, &Type::bool()).unwrap(); unify_type(&t2.ty, &t3.ty).unwrap(); let ty = t2.ty.clone(); SystemF::new(TypedTerm::If(Box::new(t1), Box::new(t2), Box::new(t3)), ty) } } } } ================================================ FILE: 05_recon/src/mutation/write_once.rs ================================================ use std::cell::{Cell, UnsafeCell}; use std::rc::Rc; use std::sync::atomic::{AtomicBool, Ordering}; pub struct WriteOnce { inner: UnsafeCell>, rank: Cell, init: AtomicBool, } pub type 
WriteOnceCell<T> = Rc<WriteOnce<T>>;

impl<T> Default for WriteOnce<T> {
    fn default() -> Self {
        WriteOnce {
            inner: UnsafeCell::new(None),
            rank: Cell::new(0),
            init: false.into(),
        }
    }
}

impl<T: PartialEq> PartialEq for WriteOnce<T> {
    /// Two cells compare equal when their (optional) contents compare equal.
    fn eq(&self, other: &Self) -> bool {
        self.get() == other.get()
    }
}

impl<T: std::fmt::Debug> std::fmt::Debug for WriteOnce<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}#{}", self.get(), self.get_rank())
    }
}

impl<T> WriteOnce<T> {
    /// Create an empty cell tagged with the given let-rank.
    pub fn with_rank(rank: usize) -> Self {
        WriteOnce {
            inner: UnsafeCell::new(None),
            rank: Cell::new(rank),
            init: false.into(),
        }
    }

    /// Write `data` into the cell exactly once.
    ///
    /// Returns `Err(data)` (giving the value back) if the cell was already
    /// initialized.
    ///
    /// `compare_and_swap` is deprecated since Rust 1.50; `compare_exchange`
    /// makes both orderings explicit. Note the payload is written *after*
    /// the flag is claimed — that is sound here only because the
    /// `UnsafeCell`/`Cell` fields make this type `!Sync`, so no other
    /// thread can observe the cell mid-write (it is shared via `Rc`).
    pub fn set(&self, data: T) -> Result<(), T> {
        if self
            .init
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
        {
            // SAFETY: the CAS succeeds exactly once, so we are the only
            // writer ever; `!Sync` rules out concurrent readers.
            unsafe {
                let ptr = &mut *self.inner.get();
                *ptr = Some(data);
            }
            Ok(())
        } else {
            Err(data)
        }
    }

    /// Read the cell, or `None` if [`WriteOnce::set`] has not been called.
    pub fn get(&self) -> Option<&T> {
        // The original used `compare_and_swap(false, false, ..)`, which
        // never changes the flag — it was only ever a load in disguise.
        if !self.init.load(Ordering::Acquire) {
            None
        } else {
            // SAFETY: the flag is only raised by `set`, and `!Sync` means
            // the (single-threaded) write completed before any later read.
            unsafe { &*self.inner.get() }.as_ref()
        }
    }

    pub fn set_rank(&self, rank: usize) {
        self.rank.set(rank)
    }

    pub fn get_rank(&self) -> usize {
        self.rank.get()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn smoke() {
        let cell = WriteOnce::default();
        assert_eq!(cell.get(), None);
        assert_eq!(cell.set(10), Ok(()));
        assert_eq!(cell.set(12), Err(12));
        assert_eq!(cell.get(), Some(&10));
    }

    #[test]
    fn smoke_shared() {
        let cell = Rc::new(WriteOnce::default());
        let rc1 = cell.clone();
        let rc2 = cell.clone();
        assert_eq!(rc2.get(), None);
        rc1.set(12).unwrap();
        assert_eq!(rc2.get(), Some(&12));
        assert_eq!(rc2.set(10), Err(10));
    }
}

================================================ FILE: 05_recon/src/naive.rs ================================================

use super::*;

/// Bind `var := ty`, failing the occurs check if `var` appears in `ty`.
fn var_bind(var: TypeVar, ty: Type) -> Result<HashMap<TypeVar, Type>, String> {
    if ty.occurs(var) {
        return Err(format!("Fails occurs check!
{:?} {:?}", var, ty)); } let mut sub = HashMap::new(); match ty { Type::Var(x) if x == var => {} _ => { sub.insert(var, ty); } } Ok(sub) } pub fn unify(a: Type, b: Type) -> Result, String> { // println!("{:?} {:?}", a, b); match (a, b) { (Type::Con(a, a_args), Type::Con(b, b_args)) => { if a_args.len() == b_args.len() && a == b { solve(a_args.into_iter().zip(b_args.into_iter())) } else { Err(format!( "Can't unify types: {:?} {:?}", Type::Con(a, a_args), Type::Con(b, b_args) )) } } (Type::Var(tv), b) => var_bind(tv, b), (a, Type::Var(tv)) => var_bind(tv, a), } } pub fn solve>(iter: I) -> Result, String> { let mut sub = HashMap::new(); for (a, b) in iter { let tmp = unify(a.clone().apply(&sub), b.clone().apply(&sub))?; sub = compose(tmp, sub); } Ok(sub) } ================================================ FILE: 05_recon/src/parser.rs ================================================ use super::Term; use std::char; use std::collections::VecDeque; use std::iter::Peekable; use std::str::Chars; use util::span::{Location, Span}; #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum TokenKind { Ident(String), Int(u32), Unit, Lambda, Let, Equals, In, Dot, If, Then, Else, True, False, LParen, RParen, Invalid(char), Eof, } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Token { pub kind: TokenKind, pub span: Span, } impl Token { pub const fn new(kind: TokenKind, span: Span) -> Token { Token { kind, span } } } #[derive(Clone)] pub struct Lexer<'s> { input: Peekable>, current: Location, } impl<'s> Lexer<'s> { pub fn new(input: Chars<'s>) -> Lexer<'s> { Lexer { input: input.peekable(), current: Location { line: 0, col: 0, abs: 0, }, } } /// Peek at the next [`char`] in the input stream fn peek(&mut self) -> Option { self.input.peek().cloned() } /// Consume the next [`char`] and advance internal source position fn consume(&mut self) -> Option { match self.input.next() { Some('\n') => { self.current.line += 1; self.current.col = 0; self.current.abs += 1; Some('\n') } 
Some(ch) => {
                self.current.col += 1;
                self.current.abs += 1;
                Some(ch)
            }
            None => None,
        }
    }

    /// Consume characters from the input stream while pred(peek()) is true,
    /// collecting the characters into a string.
    fn consume_while<F: Fn(char) -> bool>(&mut self, pred: F) -> (String, Span) {
        let mut s = String::new();
        let start = self.current;
        while let Some(n) = self.peek() {
            if !pred(n) {
                break;
            }
            match self.consume() {
                Some(ch) => s.push(ch),
                None => break,
            }
        }
        (s, Span::new(start, self.current))
    }

    /// Eat whitespace
    fn consume_delimiter(&mut self) {
        let _ = self.consume_while(char::is_whitespace);
    }

    /// Lex a natural number.
    ///
    /// Scans ASCII digits only: `char::is_numeric` also accepts characters
    /// such as '¾' or non-ASCII decimal digits, which `str::parse::<u32>`
    /// rejects — the previous version panicked on such input.
    fn number(&mut self) -> Token {
        // lex() routed us here after peeking an ASCII digit, so the string
        // holds at least one digit; parse::<u32> can only fail on overflow.
        let (data, span) = self.consume_while(|ch: char| ch.is_ascii_digit());
        let n = data.parse::<u32>().unwrap();
        Token::new(TokenKind::Int(n), span)
    }

    /// Lex a reserved keyword or an identifier
    fn keyword(&mut self) -> Token {
        let (data, span) = self.consume_while(|ch: char| ch.is_ascii_alphanumeric());
        let kind = match data.as_ref() {
            "unit" => TokenKind::Unit,
            "let" => TokenKind::Let,
            "in" => TokenKind::In,
            "fn" => TokenKind::Lambda,
            "if" => TokenKind::If,
            "then" => TokenKind::Then,
            "else" => TokenKind::Else,
            "true" => TokenKind::True,
            "false" => TokenKind::False,
            _ => TokenKind::Ident(data),
        };
        Token::new(kind, span)
    }

    /// Consume the next input character, expecting to match `ch`.
    /// Return a [`TokenKind::Invalid`] if the next character does not match,
    /// or the argument `kind` if it does
    fn eat(&mut self, ch: char, kind: TokenKind) -> Token {
        let loc = self.current;
        // Lexer::eat() should only be called internally after calling peek()
        // so we know that it's safe to unwrap the result of Lexer::consume()
        let n = self.consume().unwrap();
        let kind = if n == ch { kind } else { TokenKind::Invalid(n) };
        Token::new(kind, Span::new(loc, self.current))
    }

    /// Return the next lexeme in the input as a [`Token`]
    pub fn lex(&mut self) -> Token {
        self.consume_delimiter();
        let next = match self.peek() {
            Some(ch) => ch,
            None => return Token::new(TokenKind::Eof, Span::dummy()),
        };
        match next {
            x if x.is_ascii_alphabetic() => self.keyword(),
            // Only ASCII digits reach number(); other Unicode "numeric"
            // characters fall through to the Invalid arm instead of
            // panicking inside parse().
            x if x.is_ascii_digit() => self.number(),
            '(' => self.eat('(', TokenKind::LParen),
            ')' => self.eat(')', TokenKind::RParen),
            '\\' => self.eat('\\', TokenKind::Lambda),
            'λ' => self.eat('λ', TokenKind::Lambda),
            '.' => self.eat('.', TokenKind::Dot),
            '=' => self.eat('=', TokenKind::Equals),
            ch => self.eat(' ', TokenKind::Invalid(ch)),
        }
    }
}

impl<'s> Iterator for Lexer<'s> {
    type Item = Token;
    fn next(&mut self) -> Option<Token> {
        match self.lex() {
            Token {
                kind: TokenKind::Eof,
                ..
} => None,
            tok => Some(tok),
        }
    }
}

/// Maps surface identifiers to De Bruijn indices as binders are
/// pushed and popped during parsing.
#[derive(Clone, Debug, Default)]
pub struct DeBruijnIndexer {
    inner: VecDeque<String>,
}

impl DeBruijnIndexer {
    /// Bind `hint`, appending `'` until the name is unique, and return
    /// the binding's index.
    pub fn push(&mut self, hint: String) -> usize {
        let mut name = hint;
        while self.inner.contains(&name) {
            name = format!("{}'", name);
        }
        let idx = self.inner.len();
        self.inner.push_front(name);
        idx
    }

    /// Discard the innermost binding.
    pub fn pop(&mut self) {
        self.inner.pop_front();
    }

    /// Look up the De Bruijn index of `key`, if it is bound.
    pub fn lookup(&self, key: &str) -> Option<usize> {
        self.inner.iter().position(|s| s == key)
    }
}

pub struct Parser<'s> {
    ctx: DeBruijnIndexer,
    /// [`Lexer`] impls [`Iterator`] over tokens, so it can be wrapped
    /// directly in a [`Peekable`]
    lexer: Peekable<Lexer<'s>>,
    span: Span,
}

impl<'s> Parser<'s> {
    /// Create a new [`Parser`] for the input `&str`
    pub fn new(input: &'s str) -> Parser<'s> {
        Parser {
            ctx: DeBruijnIndexer::default(),
            lexer: Lexer::new(input.chars()).peekable(),
            span: Span::dummy(),
        }
    }

    /// Advance the lexer, remembering the span of the consumed token.
    fn consume(&mut self) -> Option<Token> {
        let tok = self.lexer.next()?;
        self.span = tok.span;
        Some(tok)
    }

    /// Consume one token and require it to be exactly `kind`.
    fn expect(&mut self, kind: TokenKind) -> Option<Token> {
        let tok = self.consume()?;
        if tok.kind == kind {
            Some(tok)
        } else {
            eprintln!("Expected token {:?}, found {:?}", kind, tok.kind);
            None
        }
    }

    /// Parse a term, reporting the current location on failure.
    fn expect_term(&mut self) -> Option<Box<Term>> {
        let term = self.term();
        if term.is_none() {
            let sp = self.peek_span();
            eprintln!("Expected term at {:?}", sp);
        }
        term
    }

    fn peek(&mut self) -> Option<TokenKind> {
        self.lexer.peek().map(|tk| tk.kind.clone())
    }

    fn peek_span(&mut self) -> Span {
        match self.lexer.peek() {
            Some(tok) => tok.span,
            None => self.span,
        }
    }

    /// Parse a lambda abstraction: `fn x. body`
    fn lambda(&mut self) -> Option<Box<Term>> {
        let start = self.expect(TokenKind::Lambda)?;
        // Bind variable into a new context before parsing the body
        let var = self.ident()?;
        self.ctx.push(var);
        let _ = self.expect(TokenKind::Dot)?;
        let body = self.term()?;
        // Return to previous context
        self.ctx.pop();
        Some(Term::Abs(body).into())
    }

    /// Parse `let x = bind in body`
    fn let_expr(&mut self) -> Option<Box<Term>> {
        let start = self.expect(TokenKind::Let)?;
        let var = self.ident()?;
        let _ =
self.expect(TokenKind::Equals)?; let bind = self.expect_term()?; self.ctx.push(var); let _ = self.expect(TokenKind::In)?; let body = self.expect_term()?; self.ctx.pop(); Some(Term::Let(bind, body).into()) } /// Parse an application of form: /// application = atom application' | atom /// application' = atom application' | empty fn application(&mut self) -> Option> { let mut lhs = self.atom()?; let span = self.span; while let Some(rhs) = self.atom() { lhs = Term::App(lhs, rhs).into(); } Some(lhs) } fn ident(&mut self) -> Option { let Token { kind, span } = self.consume()?; match kind { TokenKind::Ident(s) => Some(s), _ => { eprintln!("Expected identifier, found {:?}", kind); None } } } fn if_expr(&mut self) -> Option> { let _ = self.expect(TokenKind::If)?; let guard = self.expect_term()?; let _ = self.expect(TokenKind::Then)?; let csq = self.expect_term()?; let _ = self.expect(TokenKind::Else)?; let alt = self.expect_term()?; Some(Term::If(guard, csq, alt).into()) } /// Parse an atomic term /// LPAREN term RPAREN | var fn atom(&mut self) -> Option> { match self.peek()? { TokenKind::Let => self.let_expr(), TokenKind::Int(i) => { self.consume()?; Some(Term::Int(i as usize).into()) } TokenKind::True => { self.consume(); Some(Term::Bool(true).into()) } TokenKind::False => { self.consume(); Some(Term::Bool(false).into()) } TokenKind::LParen => { self.expect(TokenKind::LParen)?; let term = self.term()?; self.expect(TokenKind::RParen)?; Some(term) } TokenKind::Unit => { self.expect(TokenKind::Unit)?; Some(Term::Unit.into()) } TokenKind::If => self.if_expr(), TokenKind::Lambda => self.lambda(), TokenKind::Ident(s) => { let sp = self.consume()?.span; match self.ctx.lookup(&s) { Some(idx) => Some(Term::Var(idx, s).into()), None => { eprintln!("Unbound variable {}", s); None } } } _ => None, } } fn term(&mut self) -> Option> { match self.peek()? 
{ // TokenKind::Lambda => self.lambda(), _ => self.application(), } } pub fn parse_term(&mut self) -> Option> { self.term() } } ================================================ FILE: 05_recon/src/types.rs ================================================ use std::collections::{HashMap, HashSet, VecDeque}; #[derive(Copy, Clone, Default, PartialEq, PartialOrd, Eq, Hash)] pub struct TypeVar(pub u32, pub u32); #[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Hash)] pub struct Tycon { id: usize, arity: usize, } #[derive(Clone, PartialEq, PartialOrd, Eq, Hash)] pub enum Type { Var(TypeVar), Con(Tycon, Vec), } #[derive(Debug, Clone)] pub enum Scheme { Mono(Type), Poly(Vec, Type), } pub trait Substitution { fn ftv(&self) -> HashSet; fn apply(self, s: &HashMap) -> Self; } impl Substitution for Type { fn ftv(&self) -> HashSet { let mut set = HashSet::new(); let mut queue = VecDeque::new(); queue.push_back(self); while let Some(ty) = queue.pop_front() { match ty { Type::Var(x) => { set.insert(*x); } Type::Con(_, tys) => { for ty in tys { queue.push_back(ty); } } } } set } fn apply(self, map: &HashMap) -> Type { match self { Type::Var(x) => map.get(&x).cloned().unwrap_or(Type::Var(x)), Type::Con(tc, vars) => Type::Con(tc, vars.into_iter().map(|ty| ty.apply(map)).collect()), } } } impl Type { pub fn arrow(a: Type, b: Type) -> Type { Type::Con(T_ARROW, vec![a, b]) } pub fn bool() -> Type { Type::Con(T_BOOL, vec![]) } pub fn occurs(&self, exist: TypeVar) -> bool { match self { Type::Var(x) => *x == exist, Type::Con(_, tys) => tys.iter().any(|ty| ty.occurs(exist)), } } pub fn de_arrow(&self) -> (&Type, &Type) { match self { Type::Con(T_ARROW, v) => (&v[0], &v[1]), _ => panic!("Not arrow type! 
{:?}", self), } } } pub fn compose(s1: HashMap, s2: HashMap) -> HashMap { let mut s2 = s2 .into_iter() .map(|(k, v)| (k, v.apply(&s1))) .collect::>(); for (k, v) in s1 { if !s2.contains_key(&k) { s2.insert(k, v); } } s2 } impl Substitution for Scheme { fn ftv(&self) -> HashSet { match self { Scheme::Mono(ty) => ty.ftv(), Scheme::Poly(vars, ty) => ty.ftv(), } } fn apply(self, map: &HashMap) -> Scheme { match self { Scheme::Mono(ty) => Scheme::Mono(ty.apply(map)), Scheme::Poly(vars, ty) => { let mut map: HashMap = map.clone(); for v in &vars { map.remove(v); } Scheme::Poly(vars, ty.apply(&map)) } } } } pub const T_ARROW: Tycon = Tycon { id: 0, arity: 2 }; pub const T_INT: Tycon = Tycon { id: 1, arity: 0 }; pub const T_UNIT: Tycon = Tycon { id: 2, arity: 0 }; pub const T_BOOL: Tycon = Tycon { id: 3, arity: 0 }; impl std::fmt::Debug for Tycon { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self.id { 0 => write!(f, "->"), 1 => write!(f, "int"), 2 => write!(f, "unit"), 3 => write!(f, "bool"), _ => write!(f, "??"), } } } fn fresh_name(x: u32) -> String { let last = ((x % 26) as u8 + 'a' as u8) as char; (0..x / 26) .map(|_| 'z') .chain(std::iter::once(last)) .collect::() } impl std::fmt::Debug for TypeVar { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.write_str(&fresh_name(self.0)) } } impl std::fmt::Debug for Type { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { Type::Var(x) => write!(f, "{:?}", x), Type::Con(T_ARROW, tys) => write!(f, "({:?} -> {:?})", tys[0], tys[1]), Type::Con(tc, _) => write!(f, "{:?}", tc,), } } } ================================================ FILE: 06_system_f/Cargo.toml ================================================ [package] name = "system_f" version = "0.1.0" authors = ["Michael Lazear "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] util = { path = "../util" } 
================================================ FILE: 06_system_f/README.md ================================================ # System F An extension of the simply typed lambda calculus with parametric polymorphism ================================================ FILE: 06_system_f/src/diagnostics.rs ================================================ use util::span::Span; #[derive(Debug, Copy, Clone)] pub enum Level { Warn, Error, } #[derive(Debug, Clone)] pub struct Annotation { pub span: Span, pub info: String, } #[derive(Debug, Clone)] pub struct Diagnostic { pub level: Level, pub primary: Annotation, pub info: Vec, pub other: Vec, } impl Annotation { pub fn new>(span: Span, message: S) -> Annotation { Annotation { span, info: message.into(), } } } impl Diagnostic { pub fn error>(span: Span, message: S) -> Diagnostic { Diagnostic { level: Level::Error, primary: Annotation::new(span, message), other: Vec::new(), info: Vec::new(), } } pub fn warn>(span: Span, message: S) -> Diagnostic { Diagnostic { level: Level::Warn, primary: Annotation::new(span, message), other: Vec::new(), info: Vec::new(), } } pub fn message>(mut self, span: Span, message: S) -> Diagnostic { self.other.push(Annotation::new(span, message)); self } pub fn info>(mut self, info: S) -> Diagnostic { self.info.push(info.into()); self } pub fn lines(&self) -> std::ops::Range { let mut range = std::ops::Range { start: self.primary.span.start.line, end: self.primary.span.end.line + 1, }; for addl in &self.other { if addl.span.start.line < range.start { range.start = addl.span.start.line; } if addl.span.end.line + 1 > range.end { range.end = addl.span.end.line + 1; } } range } } ================================================ FILE: 06_system_f/src/eval.rs ================================================ use crate::patterns::Pattern; use crate::terms::visit::{Shift, Subst, TyTermSubst}; use crate::terms::{Kind, Literal, Primitive, Term}; use crate::types::{Context, Type}; use crate::visit::MutTermVisitor; 
pub struct Eval<'ctx> { _context: &'ctx Context, } impl<'ctx> Eval<'ctx> { pub fn with_context(_context: &Context) -> Eval<'_> { Eval { _context } } fn normal_form(&self, term: &Term) -> bool { match &term.kind { Kind::Lit(_) => true, Kind::Abs(_, _) => true, Kind::TyAbs(_) => true, Kind::Primitive(_) => true, Kind::Injection(_, tm, _) => self.normal_form(tm), Kind::Product(fields) => fields.iter().all(|f| self.normal_form(f)), Kind::Fold(_, tm) => self.normal_form(tm), Kind::Pack(_, tm, _) => self.normal_form(tm), // Kind::Unpack(pack, tm) => self.normal_form(tm), _ => false, } } fn eval_primitive(&self, p: Primitive, term: Term) -> Option { fn map u32>(f: F, mut term: Term) -> Option { match &term.kind { Kind::Lit(Literal::Nat(n)) => { term.kind = Kind::Lit(Literal::Nat(f(*n))); Some(term) } _ => None, } } match p { Primitive::Succ => map(|l| l + 1, term), Primitive::Pred => map(|l| l.saturating_sub(1), term), Primitive::IsZero => match &term.kind { Kind::Lit(Literal::Nat(0)) => Some(Term::new(Kind::Lit(Literal::Bool(true)), term.span)), _ => Some(Term::new(Kind::Lit(Literal::Bool(false)), term.span)), }, } } pub fn small_step(&self, term: Term) -> Option { if self.normal_form(&term) { return None; } match term.kind { Kind::App(t1, t2) => { if self.normal_form(&t2) { match t1.kind { Kind::Abs(_, mut abs) => { term_subst(*t2, abs.as_mut()); Some(*abs) } Kind::Primitive(p) => self.eval_primitive(p, *t2), _ => { let t = self.small_step(*t1)?; Some(Term::new(Kind::App(Box::new(t), t2), term.span)) } } } else if self.normal_form(&t1) { // t1 is in normal form, but t2 is not, so we will // carry out the reducton t2 -> t2', and return // App(t1, t2') let t = self.small_step(*t2)?; Some(Term::new(Kind::App(t1, Box::new(t)), term.span)) } else { // Neither t1 nor t2 are in normal form, we reduce t1 first let t = self.small_step(*t1)?; Some(Term::new(Kind::App(Box::new(t), t2), term.span)) } } Kind::Let(pat, bind, mut body) => { if self.normal_form(&bind) { // 
term_subst(*bind, &mut body); self.case_subst(&pat, &bind, body.as_mut()); Some(*body) } else { let t = self.small_step(*bind)?; Some(Term::new(Kind::Let(pat, Box::new(t), body), term.span)) } } Kind::TyApp(tm, ty) => match tm.kind { Kind::TyAbs(mut tm2) => { type_subst(*ty, &mut tm2); Some(*tm2) } _ => { let t_prime = self.small_step(*tm)?; Some(Term::new(Kind::TyApp(Box::new(t_prime), ty), term.span)) } }, Kind::Injection(label, tm, ty) => { let t_prime = self.small_step(*tm)?; Some(Term::new(Kind::Injection(label, Box::new(t_prime), ty), term.span)) } Kind::Projection(tm, idx) => { if self.normal_form(&tm) { match tm.kind { // Typechecker ensures that idx is in bounds Kind::Product(terms) => terms.get(idx).cloned(), _ => None, } } else { let t_prime = self.small_step(*tm)?; Some(Term::new(Kind::Projection(Box::new(t_prime), idx), term.span)) } } Kind::Product(terms) => { let mut v = Vec::with_capacity(terms.len()); for term in terms { if self.normal_form(&term) { v.push(term); } else { v.push(self.small_step(term)?); } } Some(Term::new(Kind::Product(v), term.span)) } Kind::Fix(tm) => { if !self.normal_form(&tm) { let t_prime = self.small_step(*tm)?; return Some(Term::new(Kind::Fix(Box::new(t_prime)), term.span)); } let x = Term::new(Kind::Fix(tm.clone()), term.span); match tm.kind { Kind::Abs(_, mut body) => { term_subst(x, &mut body); Some(*body) } _ => None, } } Kind::Case(expr, arms) => { if !self.normal_form(&expr) { let t_prime = self.small_step(*expr)?; return Some(Term::new(Kind::Case(Box::new(t_prime), arms), term.span)); } for mut arm in arms { if arm.pat.matches(&expr) { self.case_subst(&arm.pat, &expr, arm.term.as_mut()); return Some(*arm.term); } } None } Kind::Fold(ty, tm) => { if !self.normal_form(&tm) { let t_prime = self.small_step(*tm)?; Some(Term::new(Kind::Fold(ty, Box::new(t_prime)), term.span)) } else { None } } Kind::Unfold(ty, tm) => { if !self.normal_form(&tm) { let t_prime = self.small_step(*tm)?; return Some(Term::new(Kind::Unfold(ty, 
Box::new(t_prime)), term.span)); } match tm.kind { Kind::Fold(ty2, inner) => Some(*inner), _ => None, } } Kind::Pack(wit, evidence, sig) => { if !self.normal_form(&evidence) { let t_prime = self.small_step(*evidence)?; return Some(Term::new(Kind::Pack(wit, Box::new(t_prime), sig), term.span)); } None } Kind::Unpack(package, mut body) => match package.kind { Kind::Pack(wit, evidence, sig) => { term_subst(*evidence, &mut body); type_subst(*wit, &mut body); Some(*body) } _ => { if !self.normal_form(&package) { let t_prime = self.small_step(*package)?; return Some(Term::new(Kind::Unpack(Box::new(t_prime), body), term.span)); } None } }, _ => None, } } fn case_subst(&self, pat: &Pattern, expr: &Term, term: &mut Term) { use Pattern::*; match pat { Any => {} Literal(_) => {} Variable(_) => { term_subst(expr.clone(), term); } Product(v) => { if let Kind::Product(terms) = &expr.kind { let mut idx = 0; for tm in terms.iter() { self.case_subst(&v[idx], tm, term); idx += 1; } } else { panic!("wrong type!") } } Constructor(label, v) => { if let Kind::Injection(label_, tm, _) = &expr.kind { if label == label_ { self.case_subst(&v, &tm, term); } } else { panic!("wrong type!") } } } } } fn term_subst(mut s: Term, t: &mut Term) { Shift::new(1).visit(&mut s); Subst::new(s).visit(t); Shift::new(-1).visit(t); } fn type_subst(s: Type, t: &mut Term) { TyTermSubst::new(s).visit(t); Shift::new(-1).visit(t); } #[cfg(test)] mod test { use super::*; use util::span::Span; #[test] fn literal() { let ctx = crate::types::Context::default(); let eval = Eval::with_context(&ctx); assert_eq!(eval.small_step(lit!(false)), None); } #[test] fn application() { let ctx = crate::types::Context::default(); let eval = Eval::with_context(&ctx); let tm = app!(abs!(Type::Nat, app!(prim!(Primitive::Succ), var!(0))), nat!(1)); let t1 = eval.small_step(tm); assert_eq!(t1, Some(app!(prim!(Primitive::Succ), nat!(1)))); let t2 = eval.small_step(t1.unwrap()); assert_eq!(t2, Some(nat!(2))); let t3 = 
eval.small_step(t2.unwrap()); assert_eq!(t3, None); } #[test] fn type_application() { let ctx = crate::types::Context::default(); let eval = Eval::with_context(&ctx); let tm = tyapp!( tyabs!(abs!(Type::Var(0), app!(prim!(Primitive::Succ), var!(0)))), Type::Nat ); let t1 = eval.small_step(tm); assert_eq!(t1, Some(abs!(Type::Nat, app!(prim!(Primitive::Succ), var!(0))))); let t2 = eval.small_step(t1.unwrap()); assert_eq!(t2, None); } #[test] fn projection() { let ctx = crate::types::Context::default(); let eval = Eval::with_context(&ctx); let product = Term::new(Kind::Product(vec![nat!(5), nat!(6), nat!(29)]), Span::zero()); let projection = Term::new(Kind::Projection(Box::new(product), 2), Span::zero()); let term = app!(prim!(Primitive::Succ), projection); let t1 = eval.small_step(term); assert_eq!(t1, Some(app!(prim!(Primitive::Succ), nat!(29)))); let t2 = eval.small_step(t1.unwrap()); assert_eq!(t2, Some(nat!(30))); let t3 = eval.small_step(t2.unwrap()); assert_eq!(t3, None); } } ================================================ FILE: 06_system_f/src/macros.rs ================================================ //! Macros to make writing tests easier /// Boolean term macro_rules! lit { ($x:expr) => { crate::terms::Term::new( crate::terms::Kind::Lit(crate::terms::Literal::Bool($x)), util::span::Span::dummy(), ) }; } /// Integer term macro_rules! nat { ($x:expr) => { crate::terms::Term::new( crate::terms::Kind::Lit(crate::terms::Literal::Nat($x)), util::span::Span::dummy(), ) }; } /// TmVar term macro_rules! var { ($x:expr) => { crate::terms::Term::new(crate::terms::Kind::Var($x), util::span::Span::dummy()) }; } /// Application term macro_rules! app { ($t1:expr, $t2:expr) => { crate::terms::Term::new( crate::terms::Kind::App(Box::new($t1), Box::new($t2)), util::span::Span::dummy(), ) }; } /// Lambda abstraction term macro_rules! 
abs { ($ty:expr, $t:expr) => { crate::terms::Term::new( crate::terms::Kind::Abs(Box::new($ty), Box::new($t)), util::span::Span::dummy(), ) }; } /// Type application term macro_rules! tyapp { ($t1:expr, $t2:expr) => { crate::terms::Term::new( crate::terms::Kind::TyApp(Box::new($t1), Box::new($t2)), util::span::Span::dummy(), ) }; } /// Type abstraction term macro_rules! tyabs { ( $t:expr) => { crate::terms::Term::new(crate::terms::Kind::TyAbs(Box::new($t)), util::span::Span::dummy()) }; } /// Primitive term macro_rules! prim { ($t:expr) => { crate::terms::Term::new(crate::terms::Kind::Primitive($t), util::span::Span::dummy()) }; } macro_rules! inj { ($label:expr, $t:expr, $ty:expr) => { crate::terms::Term::new( crate::terms::Kind::Injection($label.to_string(), Box::new($t), Box::new($ty)), util::span::Span::dummy(), ) }; } /// Product term macro_rules! tuple { ($($ex:expr),+) => { crate::terms::Term::new(crate::terms::Kind::Product(vec![$($ex),+]), util::span::Span::dummy()) } } /// Type arrow macro_rules! arrow { ($ty1:expr, $ty2:expr) => { crate::types::Type::Arrow(Box::new($ty1), Box::new($ty2)) }; } /// Boolean pattern macro_rules! boolean { ($ex:expr) => { crate::patterns::Pattern::Literal(crate::terms::Literal::Bool($ex)) }; } /// Numeric pattern macro_rules! num { ($ex:expr) => { crate::patterns::Pattern::Literal(crate::terms::Literal::Nat($ex)) }; } /// Product pattern macro_rules! prod { ($($ex:expr),+) => { crate::patterns::Pattern::Product(vec![$($ex),+]) } } /// Constructor pattern macro_rules! con { ($label:expr, $ex:expr) => { crate::patterns::Pattern::Constructor($label.to_string(), Box::new($ex)) }; } /// Variant type macro_rules! 
variant { ($label:expr, $ty:expr) => { crate::types::Variant { label: $label.to_string(), ty: $ty, } }; } ================================================ FILE: 06_system_f/src/main.rs ================================================ #![allow(unused_variables, unused_macros)] #[macro_use] pub mod macros; pub mod diagnostics; pub mod eval; pub mod patterns; pub mod syntax; pub mod terms; pub mod types; pub mod visit; use diagnostics::*; use std::env; use std::io::{Read, Write}; use syntax::parser::{self, Parser}; use terms::{visit::InjRewriter, Term}; use types::{Type, Variant}; use visit::MutTermVisitor; fn test_variant() -> Type { Type::Variant(vec![ Variant { label: "A".into(), ty: Type::Unit, }, Variant { label: "B".into(), ty: Type::Nat, }, Variant { label: "C".into(), ty: Type::Nat, }, ]) } pub fn code_format(src: &str, diag: Diagnostic) { // let lines = diag.ot // .iter() // .map(|(_, sp)| sp.start.line) // .collect::>(); let srcl = src.lines().collect::>(); let mut msgs = diag.other.clone(); msgs.insert(0, diag.primary.clone()); for line in diag.lines() { println!("| {} {}", line + 1, &srcl[line as usize]); for anno in &msgs { if anno.span.start.line != line { continue; } let empty = (0..anno.span.start.col + 3).map(|_| ' ').collect::(); let tilde = (1..anno.span.end.col.saturating_sub(anno.span.start.col)) .map(|_| '~') .collect::(); println!("{}^{}^ --- {}", empty, tilde, anno.info); } } } fn eval(ctx: &mut types::Context, mut term: Term, verbose: bool) -> Result { ctx.de_alias(&mut term); InjRewriter.visit(&mut term); let ty = ctx.type_check(&term)?; println!(" -: {:?}", ty); let ev = eval::Eval::with_context(ctx); let mut t = term; let fin = loop { if let Some(res) = ev.small_step(t.clone()) { t = res; } else { break t; } if verbose { println!("---> {}", t); } }; println!("===> {}", fin); let fty = ctx.type_check(&fin)?; if fty != ty { panic!( "Type of term after evaluation is different than before!\n1 {:?}\n2 {:?}", ty, fty ); } Ok(fin) } fn 
parse_and_eval(ctx: &mut types::Context, input: &str, verbose: bool) -> bool { let mut p = Parser::new(input); loop { let term = match p.parse() { Ok(term) => term, Err(parser::Error { kind: parser::ErrorKind::Eof, .. }) => break, Err(e) => { dbg!(e); break; } }; if let Err(diag) = eval(ctx, term, verbose) { code_format(input, diag); return false; } } let diag = p.diagnostic(); if diag.error_count() > 0 { println!("Parsing {}", diag.emit()); false } else { true } } fn nat_list() -> Type { Type::Rec(Box::new(Type::Variant(vec![ variant!("Nil", Type::Unit), variant!("Cons", Type::Product(vec![Type::Nat, Type::Var(0)])), ]))) } fn nat_list2() -> Type { Type::Variant(vec![ variant!("Nil", Type::Unit), variant!("Cons", Type::Product(vec![Type::Nat, Type::Var(0)])), ]) } fn main() { let mut ctx = types::Context::default(); ctx.alias("Var".into(), test_variant()); ctx.alias("NatList".into(), nat_list()); ctx.alias("NB".into(), nat_list2()); let args = env::args(); if args.len() > 1 { for f in args.skip(1) { println!("reading {}", f); let file = std::fs::read_to_string(&f).unwrap(); if !parse_and_eval(&mut ctx, &file, false) { panic!("test failed! 
{}", f); } } return; } loop { let mut buffer = String::new(); print!("repl: "); std::io::stdout().flush().unwrap(); std::io::stdin().read_to_string(&mut buffer).unwrap(); parse_and_eval(&mut ctx, &buffer, true); } } ================================================ FILE: 06_system_f/src/patterns/mod.rs ================================================ use crate::terms::{Kind, Literal, Term}; use crate::types::{variant_field, Type}; use crate::visit::PatternVisitor; use util::span::Span; /// Patterns for case and let expressions #[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Hash)] pub enum Pattern { /// Wildcard pattern, this always matches Any, /// Constant pattern Literal(Literal), /// Variable binding pattern, this always matches Variable(String), /// Tuple of pattern bindings Product(Vec), /// Algebraic datatype constructor, along with binding pattern Constructor(String, Box), } #[derive(Clone, Debug, Default)] pub struct PatVarStack { pub inner: Vec, } impl PatVarStack { pub fn collect(pat: &mut Pattern) -> Vec { let mut p = Self::default(); p.visit_pattern(pat); p.inner } } impl PatternVisitor for PatVarStack { fn visit_variable(&mut self, var: &String) { self.inner.push(var.clone()); } } /// Visitor that simply counts the number of binders (variables) within a /// pattern pub struct PatternCount(usize); impl PatternCount { pub fn collect(pat: &mut Pattern) -> usize { let mut p = PatternCount(0); p.visit_pattern(pat); p.0 } } impl PatternVisitor for PatternCount { fn visit_variable(&mut self, var: &String) { self.0 += 1; } } impl Pattern { /// Does this pattern match the given [`Term`]? 
pub fn matches(&self, term: &Term) -> bool { match self { Pattern::Any => return true, Pattern::Variable(_) => return true, Pattern::Literal(l) => { if let Kind::Lit(l2) = &term.kind { return l == l2; } } Pattern::Product(vec) => { if let Kind::Product(terms) = &term.kind { return vec.iter().zip(terms).all(|(p, t)| p.matches(t)); } } Pattern::Constructor(label, inner) => { if let Kind::Injection(label_, tm, _) = &term.kind { if label == label_ { return inner.matches(&tm); } } } } false } } /// Helper struct to traverse a [`Pattern`] and bind variables /// to the typing context as needed. /// /// It is the caller's responsibiliy to track stack growth and pop off /// types after calling this function pub struct PatTyStack<'ty> { pub ty: &'ty Type, pub inner: Vec<&'ty Type>, } impl<'ty> PatTyStack<'ty> { pub fn collect(ty: &'ty Type, pat: &Pattern) -> Vec<&'ty Type> { let mut p = PatTyStack { ty, inner: Vec::with_capacity(16), }; p.visit_pattern(pat); p.inner } } impl<'ty> PatternVisitor for PatTyStack<'_> { fn visit_product(&mut self, pats: &Vec) { if let Type::Product(tys) = self.ty { let ty = self.ty; for (ty, pat) in tys.iter().zip(pats.iter()) { self.ty = ty; self.visit_pattern(pat); } self.ty = ty; } } fn visit_constructor(&mut self, label: &String, pat: &Pattern) { if let Type::Variant(vs) = self.ty { let ty = self.ty; self.ty = variant_field(&vs, label, Span::zero()).unwrap(); self.visit_pattern(pat); self.ty = ty; } } fn visit_pattern(&mut self, pattern: &Pattern) { match pattern { Pattern::Any | Pattern::Literal(_) => {} Pattern::Variable(_) => self.inner.push(self.ty), Pattern::Constructor(label, pat) => self.visit_constructor(label, pat), Pattern::Product(pats) => self.visit_product(pats), } } } #[cfg(test)] mod test { use super::*; #[test] fn pattern_count() { let mut pat = Pattern::Variable(String::new()); assert_eq!(PatternCount::collect(&mut pat), 1); } #[test] fn pattern_ty_stack() { let mut pat = Pattern::Variable(String::new()); let ty = Type::Nat; 
assert_eq!(PatTyStack::collect(&ty, &mut pat), vec![&ty]); } #[test] fn pattern_var_stack() { let mut pat = Pattern::Variable("x".into()); assert_eq!(PatVarStack::collect(&mut pat), vec![String::from("x")]); } } ================================================ FILE: 06_system_f/src/syntax/lexer.rs ================================================ use super::{Token, TokenKind}; use std::char; use std::iter::Peekable; use std::str::Chars; use util::span::{Location, Span}; #[derive(Clone)] pub struct Lexer<'s> { input: Peekable>, current: Location, } impl<'s> Lexer<'s> { pub fn new(input: Chars<'s>) -> Lexer<'s> { Lexer { input: input.peekable(), current: Location { line: 0, col: 0, abs: 0, }, } } /// Peek at the next [`char`] in the input stream fn peek(&mut self) -> Option { self.input.peek().cloned() } /// Consume the next [`char`] and advance internal source position fn consume(&mut self) -> Option { match self.input.next() { Some('\n') => { self.current.line += 1; self.current.col = 0; self.current.abs += 1; Some('\n') } Some(ch) => { self.current.col += 1; self.current.abs += 1; Some(ch) } None => None, } } /// Consume characters from the input stream while pred(peek()) is true, /// collecting the characters into a string. 
fn consume_while bool>(&mut self, pred: F) -> (String, Span) { let mut s = String::new(); let start = self.current; while let Some(n) = self.peek() { if pred(n) { match self.consume() { Some(ch) => s.push(ch), None => break, } } else { break; } } (s, Span::new(start, self.current)) } /// Eat whitespace fn consume_delimiter(&mut self) { let _ = self.consume_while(char::is_whitespace); } /// Lex a natural number fn number(&mut self) -> Token { // Since we peeked at least one numeric char, we should always // have a string containing at least 1 single digit, as such // it is safe to call unwrap() on str::parse let (data, span) = self.consume_while(char::is_numeric); let n = data.parse::().unwrap(); Token::new(TokenKind::Nat(n), span) } /// Lex a reserved keyword or an identifier fn keyword(&mut self) -> Token { let (data, span) = self.consume_while(|ch| ch.is_ascii_alphanumeric()); let kind = match data.as_ref() { "if" => TokenKind::If, "then" => TokenKind::Then, "else" => TokenKind::Else, "true" => TokenKind::True, "false" => TokenKind::False, "succ" => TokenKind::Succ, "pred" => TokenKind::Pred, "iszero" => TokenKind::IsZero, "zero" => TokenKind::Nat(0), "Bool" => TokenKind::TyBool, "Nat" => TokenKind::TyNat, "Unit" => TokenKind::TyUnit, "unit" => TokenKind::Unit, "let" => TokenKind::Let, "in" => TokenKind::In, "fix" => TokenKind::Fix, "case" => TokenKind::Case, "of" => TokenKind::Of, "fold" => TokenKind::Fold, "unfold" => TokenKind::Unfold, "rec" => TokenKind::Rec, "lambda" => TokenKind::Lambda, "forall" => TokenKind::Forall, "exists" => TokenKind::Exists, "pack" => TokenKind::Pack, "unpack" => TokenKind::Unpack, "as" => TokenKind::As, _ => { if data.starts_with(|ch: char| ch.is_ascii_uppercase()) { TokenKind::Uppercase(data) } else { TokenKind::Lowercase(data) } } }; Token::new(kind, span) } /// Consume the next input character, expecting to match `ch`. 
/// Return a [`TokenKind::Invalid`] if the next character does not match, /// or the argument `kind` if it does fn eat(&mut self, ch: char, kind: TokenKind) -> Token { let loc = self.current; // Lexer::eat() should only be called internally after calling peek() // so we know that it's safe to unwrap the result of Lexer::consume() let n = self.consume().unwrap(); let kind = if n == ch { kind } else { TokenKind::Invalid(n) }; Token::new(kind, Span::new(loc, self.current)) } /// Return the next lexeme in the input as a [`Token`] pub fn lex(&mut self) -> Token { self.consume_delimiter(); let next = match self.peek() { Some(ch) => ch, None => return Token::new(TokenKind::Eof, Span::new(self.current, self.current)), }; match next { x if x.is_ascii_alphabetic() => self.keyword(), x if x.is_numeric() => self.number(), '(' => self.eat('(', TokenKind::LParen), ')' => self.eat(')', TokenKind::RParen), ';' => self.eat(';', TokenKind::Semicolon), ':' => self.eat(':', TokenKind::Colon), ',' => self.eat(',', TokenKind::Comma), '{' => self.eat('{', TokenKind::LBrace), '}' => self.eat('}', TokenKind::RBrace), '[' => self.eat('[', TokenKind::LSquare), ']' => self.eat(']', TokenKind::RSquare), '\\' => self.eat('\\', TokenKind::Lambda), 'λ' => self.eat('λ', TokenKind::Lambda), '∀' => self.eat('∀', TokenKind::Forall), '∃' => self.eat('∃', TokenKind::Exists), '.' => self.eat('.', TokenKind::Proj), '=' => self.eat('=', TokenKind::Equals), '|' => self.eat('|', TokenKind::Bar), '_' => self.eat('_', TokenKind::Wildcard), '>' => self.eat('>', TokenKind::Gt), '-' => { self.consume(); self.eat('>', TokenKind::TyArrow) } ch => self.eat(' ', TokenKind::Invalid(ch)), } } } impl<'s> Iterator for Lexer<'s> { type Item = Token; fn next(&mut self) -> Option { match self.lex() { Token { kind: TokenKind::Eof, .. 
} => None, tok => Some(tok), } } } #[cfg(test)] mod test { use super::*; use TokenKind::*; #[test] fn nested() { let input = "succ(succ(succ(0)))"; let expected = vec![Succ, LParen, Succ, LParen, Succ, LParen, Nat(0), RParen, RParen, RParen]; let output = Lexer::new(input.chars()) .into_iter() .map(|t| t.kind) .collect::>(); assert_eq!(expected, output); } #[test] fn case() { let input = "case x of | A _ => true | B x => (\\y: Nat. x)"; let expected = vec![ Case, Lowercase("x".into()), Of, Bar, Uppercase("A".into()), Wildcard, Equals, Gt, True, Bar, Uppercase("B".into()), Lowercase("x".into()), Equals, Gt, LParen, Lambda, Lowercase("y".into()), Colon, TyNat, Proj, Lowercase("x".into()), RParen, ]; let output = Lexer::new(input.chars()) .into_iter() .map(|t| t.kind) .collect::>(); assert_eq!(expected, output); } } ================================================ FILE: 06_system_f/src/syntax/mod.rs ================================================ //! Lexical analysis and recursive descent parser for System F pub mod lexer; pub mod parser; use util::span::Span; #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum TokenKind { Uppercase(String), Lowercase(String), Nat(u32), TyNat, TyBool, TyArrow, TyUnit, Unit, True, False, Lambda, Forall, Exists, As, Pack, Unpack, Succ, Pred, If, Then, Else, Let, In, IsZero, Semicolon, Colon, Comma, Proj, LParen, RParen, LBrace, RBrace, LSquare, RSquare, Equals, Bar, Wildcard, Gt, Case, Of, Fix, Fold, Unfold, Rec, Invalid(char), Dummy, Eof, } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Token { pub kind: TokenKind, pub span: Span, } impl Token { pub const fn dummy() -> Token { Token { kind: TokenKind::Dummy, span: Span::zero(), } } pub const fn new(kind: TokenKind, span: Span) -> Token { Token { kind, span } } } ================================================ FILE: 06_system_f/src/syntax/parser.rs ================================================ use super::lexer::Lexer; use super::{Token, TokenKind}; use 
std::collections::VecDeque;
use util::diagnostic::Diagnostic;
use util::span::*;

use crate::patterns::{PatVarStack, Pattern};
use crate::terms::*;
use crate::types::*;

/// Maps surface-syntax names to de Bruijn indices during parsing.
///
/// The most recently bound name sits at the front of the deque, so its
/// index is 0, the next-most-recent is 1, and so on.
#[derive(Clone, Debug, Default)]
pub struct DeBruijnIndexer {
    inner: VecDeque<String>,
}

impl DeBruijnIndexer {
    /// Bind `hint` as the innermost name, returning the stack depth
    /// prior to insertion
    pub fn push(&mut self, hint: String) -> usize {
        let depth = self.inner.len();
        self.inner.push_front(hint);
        depth
    }

    /// Discard the innermost binding, if any
    pub fn pop(&mut self) {
        self.inner.pop_front();
    }

    /// de Bruijn index of `key`: how many binders deep it was bound,
    /// or `None` if it is not in scope
    pub fn lookup(&self, key: &str) -> Option<usize> {
        self.inner.iter().position(|name| name.as_str() == key)
    }

    /// Number of names currently in scope
    pub fn len(&self) -> usize {
        self.inner.len()
    }
}

/// Recursive-descent parser for System F terms and types
pub struct Parser<'s> {
    tmvar: DeBruijnIndexer,
    tyvar: DeBruijnIndexer,
    diagnostic: Diagnostic<'s>,
    lexer: Lexer<'s>,
    span: Span,
    token: Token,
}

/// A parse error, carrying the offending token and its location
#[derive(Clone, Debug)]
pub struct Error {
    pub span: Span,
    pub tok: Token,
    pub kind: ErrorKind,
}

#[derive(Clone, Debug)]
pub enum ErrorKind {
    ExpectedAtom,
    ExpectedIdent,
    ExpectedType,
    ExpectedPattern,
    ExpectedToken(TokenKind),
    UnboundTypeVar,
    Unknown,
    Eof,
}

impl<'s> Parser<'s> {
    /// Create a new [`Parser`] for the input `&str`
    pub fn new(input: &'s str) -> Parser<'s> {
        let mut parser = Parser {
            tmvar: DeBruijnIndexer::default(),
            tyvar: DeBruijnIndexer::default(),
            diagnostic: Diagnostic::new(input),
            lexer: Lexer::new(input.chars()),
            span: Span::default(),
            token: Token::dummy(),
        };
        // Prime the one-token lookahead
        parser.bump();
        parser
    }

    /// Surrender the accumulated diagnostics, consuming the parser
    pub fn diagnostic(self) -> Diagnostic<'s> {
        self.diagnostic
    }
}

impl<'s> Parser<'s> {
    /// Kleene Plus combinator: parse one `func`, then zero or more
    /// further occurrences separated by `delimiter`
    fn once_or_more<T, F>(&mut self, func: F, delimiter: TokenKind) -> Result<Vec<T>, Error>
    where
        F: Fn(&mut Parser) -> Result<T, Error>,
    {
        let mut items = vec![func(self)?];
        while self.bump_if(&delimiter) {
            items.push(func(self)?);
        }
        Ok(items)
    }

    /// Expect combinator
    /// Combinator that must return Ok or a message will be pushed to
    /// diagnostic. This method should only be called after a token has
    /// already been bumped.
fn once(&mut self, func: F, message: &str) -> Result where F: Fn(&mut Parser) -> Result, { match func(self) { Ok(t) => Ok(t), Err(e) => { self.diagnostic.push(message, self.span); Err(e) } } } } impl<'s> Parser<'s> { fn error(&self, kind: ErrorKind) -> Result { Err(Error { span: self.token.span, tok: self.token.clone(), kind, }) } fn bump(&mut self) -> TokenKind { let prev = std::mem::replace(&mut self.token, self.lexer.lex()); self.span = prev.span; prev.kind } fn bump_if(&mut self, kind: &TokenKind) -> bool { if &self.token.kind == kind { self.bump(); true } else { false } } fn expect(&mut self, kind: TokenKind) -> Result<(), Error> { if self.token.kind == kind { self.bump(); Ok(()) } else { self.diagnostic.push( format!("expected token {:?}, found {:?}", kind, self.token.kind), self.span, ); self.error(ErrorKind::ExpectedToken(kind)) } } fn kind(&self) -> &TokenKind { &self.token.kind } fn ty_variant(&mut self) -> Result { let label = self.uppercase_id()?; let ty = match self.ty() { Ok(ty) => ty, _ => Type::Unit, }; Ok(Variant { label, ty }) } fn ty_app(&mut self) -> Result { if !self.bump_if(&TokenKind::LSquare) { return self.error(ErrorKind::ExpectedToken(TokenKind::LSquare)); } let ty = self.ty()?; self.expect(TokenKind::RSquare)?; Ok(ty) } fn ty_atom(&mut self) -> Result { match self.kind() { TokenKind::TyBool => { self.bump(); Ok(Type::Bool) } TokenKind::TyNat => { self.bump(); Ok(Type::Nat) } TokenKind::TyUnit => { self.bump(); Ok(Type::Unit) } TokenKind::LParen => { self.bump(); let r = self.ty()?; self.expect(TokenKind::RParen)?; Ok(r) } TokenKind::Forall => { self.bump(); Ok(Type::Universal(Box::new(self.ty()?))) } TokenKind::Exists => { self.bump(); let tvar = self.uppercase_id()?; self.expect(TokenKind::Proj)?; self.tyvar.push(tvar); let xs = Type::Existential(Box::new(self.ty()?)); self.tyvar.pop(); Ok(xs) } TokenKind::Uppercase(_) => { let ty = self.uppercase_id()?; match self.tyvar.lookup(&ty) { Some(idx) => Ok(Type::Var(idx)), None => 
Ok(Type::Alias(ty)),
                }
            }
            TokenKind::LBrace => {
                // Variant type: { A Ty1 | B Ty2 | ... }
                self.bump();
                let fields = self.once_or_more(|p| p.ty_variant(), TokenKind::Bar)?;
                self.expect(TokenKind::RBrace)?;
                Ok(Type::Variant(fields))
            }
            _ => self.error(ErrorKind::ExpectedType),
        }
    }

    /// Parse a parenthesized, comma-delimited product type, or fall back
    /// to a single atomic type
    fn ty_tuple(&mut self) -> Result<Type, Error> {
        if self.bump_if(&TokenKind::LParen) {
            let mut v = self.once_or_more(|p| p.ty(), TokenKind::Comma)?;
            self.expect(TokenKind::RParen)?;
            if v.len() > 1 {
                Ok(Type::Product(v))
            } else {
                // (T) is just grouping, not a 1-tuple
                Ok(v.remove(0))
            }
        } else {
            self.ty_atom()
        }
    }

    /// Parse a type, handling `rec` binders and (right-associative via the
    /// recursive call) arrow types
    pub fn ty(&mut self) -> Result<Type, Error> {
        if self.bump_if(&TokenKind::Rec) {
            let name = self.uppercase_id()?;
            self.expect(TokenKind::Equals)?;
            self.tyvar.push(name);
            let ty = self.ty()?;
            self.tyvar.pop();
            return Ok(Type::Rec(Box::new(ty)));
        }
        let mut lhs = self.ty_tuple()?;
        if let TokenKind::TyArrow = self.kind() {
            self.bump();
            while let Ok(rhs) = self.ty() {
                lhs = Type::Arrow(Box::new(lhs), Box::new(rhs));
                if let TokenKind::TyArrow = self.kind() {
                    self.bump();
                } else {
                    break;
                }
            }
        }
        Ok(lhs)
    }

    /// Parse a type abstraction body: `\X. term`
    fn tyabs(&mut self) -> Result<Term, Error> {
        let tyvar = self.uppercase_id()?;
        let sp = self.span;
        let ty = Box::new(Type::Var(self.tyvar.push(tyvar)));
        let body = self.once(|p| p.parse(), "abstraction body required")?;
        // BUGFIX: the type variable bound here was never popped (every other
        // binder — tmabs, letexpr, case_arm, unpack — unbinds its names),
        // leaking the binding into the scope of all subsequent terms
        self.tyvar.pop();
        Ok(Term::new(Kind::TyAbs(Box::new(body)), sp + self.span))
    }

    /// Parse a term abstraction body: `\x: Ty. term`
    fn tmabs(&mut self) -> Result<Term, Error> {
        let tmvar = self.lowercase_id()?;
        let sp = self.span;
        self.tmvar.push(tmvar);
        self.expect(TokenKind::Colon)?;
        let ty = self.once(|p| p.ty(), "type annotation required in abstraction")?;
        self.expect(TokenKind::Proj)?;
        let body = self.once(|p| p.parse(), "abstraction body required")?;
        self.tmvar.pop();
        Ok(Term::new(Kind::Abs(Box::new(ty), Box::new(body)), sp + self.span))
    }

    /// Parse `fold [Ty] term`, introducing a recursive type
    fn fold(&mut self) -> Result<Term, Error> {
        self.expect(TokenKind::Fold)?;
        let sp = self.span;
        let ty = self.once(|p| p.ty(), "type annotation required after `fold`")?;
        let tm = self.once(|p| p.parse(), "term required after `fold`")?;
        Ok(Term::new(Kind::Fold(Box::new(ty), Box::new(tm)), sp + self.span))
    }

    /// Parse `unfold [Ty] term`, eliminating a recursive type
    fn unfold(&mut self) -> Result<Term, Error> {
self.expect(TokenKind::Unfold)?; let sp = self.span; let ty = self.once(|p| p.ty(), "type annotation required after `unfold`")?; let tm = self.once(|p| p.parse(), "term required after `unfold`")?; Ok(Term::new(Kind::Unfold(Box::new(ty), Box::new(tm)), sp + self.span)) } fn fix(&mut self) -> Result { let sp = self.span; self.expect(TokenKind::Fix)?; let t = self.parse()?; Ok(Term::new(Kind::Fix(Box::new(t)), sp + self.span)) } fn letexpr(&mut self) -> Result { let sp = self.span; self.expect(TokenKind::Let)?; let mut pat = self.once(|p| p.pattern(), "missing pattern")?; self.expect(TokenKind::Equals)?; let t1 = self.once(|p| p.parse(), "let binder required")?; let len = self.tmvar.len(); for var in PatVarStack::collect(&mut pat).into_iter().rev() { self.tmvar.push(var); } self.expect(TokenKind::In)?; let t2 = self.once(|p| p.parse(), "let body required")?; while self.tmvar.len() > len { self.tmvar.pop(); } Ok(Term::new( Kind::Let(Box::new(pat), Box::new(t1), Box::new(t2)), sp + self.span, )) } fn lambda(&mut self) -> Result { self.expect(TokenKind::Lambda)?; match self.kind() { TokenKind::Uppercase(_) => self.tyabs(), TokenKind::Lowercase(_) => self.tmabs(), _ => { self.diagnostic .push("expected identifier after lambda, found".to_string(), self.span); self.error(ErrorKind::ExpectedIdent) } } } fn paren(&mut self) -> Result { self.expect(TokenKind::LParen)?; let span = self.span; let mut n = self.once_or_more(|p| p.parse(), TokenKind::Comma)?; self.expect(TokenKind::RParen)?; if n.len() > 1 { Ok(Term::new(Kind::Product(n), span + self.span)) } else { // invariant, n.len() >= 1 Ok(n.remove(0)) } } fn uppercase_id(&mut self) -> Result { match self.bump() { TokenKind::Uppercase(s) => Ok(s), tk => { self.diagnostic .push(format!("expected uppercase identifier, found {:?}", tk), self.span); self.error(ErrorKind::ExpectedIdent) } } } fn lowercase_id(&mut self) -> Result { match self.bump() { TokenKind::Lowercase(s) => Ok(s), tk => { self.diagnostic .push(format!("expected 
lowercase identifier, found {:?}", tk), self.span); self.error(ErrorKind::ExpectedIdent) } } } fn literal(&mut self) -> Result { let lit = match self.bump() { TokenKind::Nat(x) => Literal::Nat(x), TokenKind::True => Literal::Bool(true), TokenKind::False => Literal::Bool(false), TokenKind::Unit => Literal::Unit, _ => return self.error(ErrorKind::Unknown), }; Ok(Term::new(Kind::Lit(lit), self.span)) } fn primitive(&mut self) -> Result { let p = match self.bump() { TokenKind::IsZero => Primitive::IsZero, TokenKind::Succ => Primitive::Succ, TokenKind::Pred => Primitive::Pred, _ => return self.error(ErrorKind::Unknown), }; Ok(Term::new(Kind::Primitive(p), self.span)) } /// Important to note that this function can push variable names to the /// de Bruijn naming context. Callers of this function are responsible for /// making sure that the stack is balanced afterwards fn pat_atom(&mut self) -> Result { match self.kind() { TokenKind::LParen => self.pattern(), TokenKind::Wildcard => { self.bump(); Ok(Pattern::Any) } TokenKind::Uppercase(_) => { let tycon = self.uppercase_id()?; let inner = match self.pattern() { Ok(pat) => pat, _ => Pattern::Any, }; Ok(Pattern::Constructor(tycon, Box::new(inner))) } TokenKind::Lowercase(_) => { let var = self.lowercase_id()?; // self.tmvar.push(var.clone()); Ok(Pattern::Variable(var)) } TokenKind::True => { self.bump(); Ok(Pattern::Literal(Literal::Bool(true))) } TokenKind::False => { self.bump(); Ok(Pattern::Literal(Literal::Bool(false))) } TokenKind::Unit => { self.bump(); Ok(Pattern::Literal(Literal::Unit)) } TokenKind::Nat(n) => { // O great borrowck, may this humble offering appease thee let n = *n; self.bump(); Ok(Pattern::Literal(Literal::Nat(n))) } _ => self.error(ErrorKind::ExpectedPattern), } } fn pattern(&mut self) -> Result { match self.kind() { TokenKind::LParen => { self.bump(); let mut v = self.once_or_more(|p| p.pat_atom(), TokenKind::Comma)?; self.expect(TokenKind::RParen)?; if v.len() > 1 { Ok(Pattern::Product(v)) } else { 
// v must have length == 1, else we would have early returned assert_eq!(v.len(), 1); Ok(v.remove(0)) } } _ => self.pat_atom(), } } fn case_arm(&mut self) -> Result { // match self.kind() { // TokenKind::Bar => self.bump(), // _ => return self.error(ErrorKind::ExpectedToken(TokenKind::Bar)), // }; // We don't track the length of the debruijn index in other methods, // but we have a couple branches where variables might be bound, // and this is pretty much the easiest way of doing it let len = self.tmvar.len(); let mut span = self.span; let mut pat = self.once(|p| p.pattern(), "missing pattern")?; for var in PatVarStack::collect(&mut pat).into_iter().rev() { self.tmvar.push(var); } self.expect(TokenKind::Equals)?; self.expect(TokenKind::Gt)?; let term = Box::new(self.once(|p| p.application(), "missing case term")?); self.bump_if(&TokenKind::Comma); // Unbind any variables from the parsing context while self.tmvar.len() > len { self.tmvar.pop(); } span = span + self.span; Ok(Arm { span, pat, term }) } fn case(&mut self) -> Result { self.expect(TokenKind::Case)?; let span = self.span; let expr = self.once(|p| p.parse(), "missing case expression")?; self.expect(TokenKind::Of)?; self.bump_if(&TokenKind::Bar); let arms = self.once_or_more(|p| p.case_arm(), TokenKind::Bar)?; Ok(Term::new(Kind::Case(Box::new(expr), arms), span + self.span)) } fn injection(&mut self) -> Result { let label = self.uppercase_id()?; let sp = self.span; let term = match self.parse() { Ok(t) => t, _ => Term::new(Kind::Lit(Literal::Unit), self.span), }; self.expect(TokenKind::Of)?; let ty = self.ty()?; Ok(Term::new( Kind::Injection(label, Box::new(term), Box::new(ty)), sp + self.span, )) } fn pack(&mut self) -> Result { self.expect(TokenKind::Pack)?; let sp = self.span; let witness = self.ty()?; self.expect(TokenKind::Comma)?; let evidence = self.parse()?; self.expect(TokenKind::As)?; let signature = self.ty()?; Ok(Term::new( Kind::Pack(Box::new(witness), Box::new(evidence), Box::new(signature)), 
sp + self.span, )) } fn unpack(&mut self) -> Result { self.expect(TokenKind::Unpack)?; let sp = self.span; let package = self.parse()?; self.expect(TokenKind::As)?; let tyvar = self.uppercase_id()?; self.expect(TokenKind::Comma)?; let name = self.lowercase_id()?; self.tyvar.push(tyvar); self.tmvar.push(name); self.expect(TokenKind::In)?; let expr = self.parse()?; self.tmvar.pop(); self.tyvar.pop(); Ok(Term::new( Kind::Unpack(Box::new(package), Box::new(expr)), sp + self.span, )) } fn atom(&mut self) -> Result { match self.kind() { TokenKind::LParen => self.paren(), TokenKind::Fix => self.fix(), TokenKind::Fold => self.fold(), TokenKind::Unfold => self.unfold(), TokenKind::Pack => self.pack(), TokenKind::Unpack => self.unpack(), TokenKind::IsZero | TokenKind::Succ | TokenKind::Pred => self.primitive(), TokenKind::Uppercase(_) => self.injection(), TokenKind::Lowercase(s) => { let var = self.lowercase_id()?; match self.tmvar.lookup(&var) { Some(idx) => Ok(Term::new(Kind::Var(idx), self.span)), None => { self.diagnostic.push(format!("unbound variable {}", var), self.span); self.error(ErrorKind::UnboundTypeVar) } } } TokenKind::Nat(_) | TokenKind::True | TokenKind::False | TokenKind::Unit => self.literal(), TokenKind::Eof => self.error(ErrorKind::Eof), TokenKind::Semicolon => { self.bump(); self.error(ErrorKind::ExpectedAtom) } _ => self.error(ErrorKind::ExpectedAtom), } } /// Parse a term of form: /// projection = atom `.` projection /// projection = atom fn projection(&mut self) -> Result { let atom = self.atom()?; if self.bump_if(&TokenKind::Proj) { let idx = match self.bump() { TokenKind::Nat(idx) => idx, _ => { self.diagnostic .push(format!("expected integer index after {}", atom), self.span); return self.error(ErrorKind::ExpectedToken(TokenKind::Proj)); } }; let sp = atom.span + self.span; Ok(Term::new(Kind::Projection(Box::new(atom), idx as usize), sp)) } else { Ok(atom) } } /// Parse an application of form: /// application = atom application' | atom /// 
application' = atom application' | empty fn application(&mut self) -> Result { let mut app = self.projection()?; loop { let sp = app.span; if let Ok(ty) = self.ty_app() { // Full type inference for System F is undecidable // Additionally, even partial type reconstruction, // where only type application types are erased is also // undecidable, see TaPL 23.6.2, Boehm 1985, 1989 // // Partial erasure rules: // erasep(x) = x // erasep(λx:T. t) = λx:T. erasep(t) // erasep(t1 t2) = erasep(t1) erasep(t2) // erasep(λX. t) = λX. erasep(t) // erasep(t T) = erasep(t) [] <--- erasure of TyApp app = Term::new(Kind::TyApp(Box::new(app), Box::new(ty)), sp + self.span); } else if let Ok(term) = self.projection() { app = Term::new(Kind::App(Box::new(app), Box::new(term)), sp + self.span); } else { break; } } Ok(app) } pub fn parse(&mut self) -> Result { match self.kind() { TokenKind::Case => self.case(), TokenKind::Lambda => self.lambda(), TokenKind::Let => self.letexpr(), _ => self.application(), } } } ================================================ FILE: 06_system_f/src/terms/mod.rs ================================================ //! 
Representation lambda calculus terms use crate::patterns::Pattern; use crate::types::Type; use std::fmt; use util::span::Span; pub mod visit; #[derive(Clone, PartialEq, PartialOrd)] pub struct Term { pub span: Span, pub kind: Kind, } /// Primitive functions supported by this implementation #[derive(Copy, Clone, Debug, PartialEq, PartialOrd)] pub enum Primitive { Succ, Pred, IsZero, } /// Abstract syntax of the parametric polymorphic lambda calculus #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum Kind { /// A literal value Lit(Literal), /// A bound variable, represented by it's de Bruijn index Var(usize), /// Fixpoint operator/Y combinator Fix(Box), Primitive(Primitive), /// Injection into a sum type /// fields: type constructor tag, term, and sum type Injection(String, Box, Box), /// Product type (tuple) Product(Vec), /// Projection into a term Projection(Box, usize), /// A case expr, with case arms Case(Box, Vec), Let(Box, Box, Box), /// A lambda abstraction Abs(Box, Box), /// Application of a term to another term App(Box, Box), /// Type abstraction TyAbs(Box), /// Type application TyApp(Box, Box), Fold(Box, Box), Unfold(Box, Box), /// Introduce an existential type /// { *Ty1, Term } as {∃X.Ty} /// essentially, concrete representation as interface Pack(Box, Box, Box), /// Unpack an existential type /// open {∃X, bind} in body -- X is bound as a TyVar, and bind as Var(0) /// Eliminate an existential type Unpack(Box, Box), } /// Arm of a case expression #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Arm { pub span: Span, pub pat: Pattern, pub term: Box, } /// Constant literal expression or pattern #[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Eq, Hash)] pub enum Literal { Unit, Bool(bool), Nat(u32), } impl Term { pub fn new(kind: Kind, span: Span) -> Term { Term { span, kind } } #[allow(dead_code)] pub const fn unit() -> Term { Term { span: Span::dummy(), kind: Kind::Lit(Literal::Unit), } } #[allow(dead_code)] #[inline] pub fn span(&self) -> 
Span { self.span } #[inline] pub fn kind(&self) -> &Kind { &self.kind } } impl fmt::Display for Literal { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Literal::Nat(n) => write!(f, "{}", n), Literal::Bool(b) => write!(f, "{}", b), Literal::Unit => write!(f, "unit"), } } } impl fmt::Display for Term { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match &self.kind { Kind::Lit(lit) => write!(f, "{}", lit), Kind::Var(v) => write!(f, "#{}", v), Kind::Abs(ty, term) => write!(f, "(λ_:{:?}. {})", ty, term), Kind::Fix(term) => write!(f, "Fix {:?}", term), Kind::Primitive(p) => write!(f, "{:?}", p), Kind::Injection(label, tm, ty) => write!(f, "{}({})", label, tm), Kind::Projection(term, idx) => write!(f, "{}.{}", term, idx), Kind::Product(terms) => write!( f, "({})", terms .iter() .map(|t| format!("{}", t)) .collect::>() .join(",") ), Kind::Case(term, arms) => { writeln!(f, "case {} of", term)?; for arm in arms { writeln!(f, "\t| {:?} => {},", arm.pat, arm.term)?; } write!(f, "") } Kind::Let(pat, t1, t2) => write!(f, "let {:?} = {} in {}", pat, t1, t2), Kind::App(t1, t2) => write!(f, "({} {})", t1, t2), Kind::TyAbs(term) => write!(f, "(λTy {})", term), Kind::TyApp(term, ty) => write!(f, "({} [{:?}])", term, ty), Kind::Fold(ty, term) => write!(f, "fold [{:?}] {}", ty, term), Kind::Unfold(ty, term) => write!(f, "unfold [{:?}] {}", ty, term), Kind::Pack(witness, body, sig) => write!(f, "[|pack {{*{:?}, {}}} as {:?} |]", witness, body, sig), Kind::Unpack(m, n) => write!(f, "unpack {} as {}", m, n), } } } impl fmt::Debug for Term { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{:?}", self.kind) } } #[cfg(test)] mod test { use super::*; #[test] fn pattern_matches() { let ty = Type::Variant(vec![ variant!("A", Type::Nat), variant!("B", Type::Product(vec![Type::Nat, Type::Bool])), ]); let a_pats = vec![con!("A", Pattern::Any), con!("A", num!(9)), con!("A", num!(10))]; let b_pats = vec![ con!("B", Pattern::Any), con!("B", 
prod!(num!(1), boolean!(true))), con!("B", prod!(Pattern::Any, boolean!(false))), ]; let res = [true, false, true]; let a = inj!("A", nat!(10), ty.clone()); let b = inj!("B", tuple!(nat!(1), lit!(false)), ty.clone()); for (pat, result) in a_pats.iter().zip(&res) { assert_eq!(pat.matches(&a), *result); } for (pat, result) in b_pats.iter().zip(&res) { assert_eq!(pat.matches(&b), *result, "{:?}", pat); } } } ================================================ FILE: 06_system_f/src/terms/visit.rs ================================================ use crate::patterns::{Pattern, PatternCount}; use crate::terms::{Arm, Kind, Primitive, Term}; use crate::types::Type; use crate::visit::{MutTermVisitor, MutTypeVisitor}; use util::span::Span; pub struct Shift { cutoff: usize, shift: isize, } impl Shift { pub const fn new(shift: isize) -> Shift { Shift { cutoff: 0, shift } } } impl MutTermVisitor for Shift { fn visit_var(&mut self, sp: &mut Span, var: &mut usize) { if *var >= self.cutoff { *var = (*var as isize + self.shift) as usize; } } fn visit_abs(&mut self, sp: &mut Span, ty: &mut Type, term: &mut Term) { self.cutoff += 1; self.visit(term); self.cutoff -= 1; } fn visit_let(&mut self, sp: &mut Span, pat: &mut Pattern, t1: &mut Term, t2: &mut Term) { self.visit(t1); let c = PatternCount::collect(pat); self.cutoff += c; self.visit(t2); self.cutoff -= c; } fn visit_case(&mut self, sp: &mut Span, term: &mut Term, arms: &mut Vec) { self.visit(term); for arm in arms { let c = PatternCount::collect(&mut arm.pat); self.cutoff += c; self.visit(&mut arm.term); self.cutoff -= c; } } fn visit_unpack(&mut self, _: &mut Span, package: &mut Term, term: &mut Term) { self.visit(package); self.cutoff += 1; self.visit(term); self.cutoff -= 1; } } pub struct Subst { cutoff: usize, term: Term, } impl Subst { pub fn new(term: Term) -> Subst { Subst { cutoff: 0, term } } } impl MutTermVisitor for Subst { fn visit_abs(&mut self, sp: &mut Span, ty: &mut Type, term: &mut Term) { self.cutoff += 1; 
self.visit(term); self.cutoff -= 1; } fn visit_let(&mut self, sp: &mut Span, pat: &mut Pattern, t1: &mut Term, t2: &mut Term) { self.visit(t1); let c = PatternCount::collect(pat); self.cutoff += c; self.visit(t2); self.cutoff -= c; } fn visit_case(&mut self, sp: &mut Span, term: &mut Term, arms: &mut Vec) { self.visit(term); for arm in arms { let c = PatternCount::collect(&mut arm.pat); self.cutoff += c; self.visit(&mut arm.term); self.cutoff -= c; } } fn visit_unpack(&mut self, _: &mut Span, package: &mut Term, term: &mut Term) { self.visit(package); self.cutoff += 1; self.visit(term); self.cutoff -= 1; } fn visit(&mut self, term: &mut Term) { let sp = &mut term.span; match &mut term.kind { Kind::Var(v) if *v == self.cutoff => { Shift::new(self.cutoff as isize).visit(&mut self.term); *term = self.term.clone(); } _ => self.walk(term), } } } pub struct TyTermSubst { cutoff: usize, ty: Type, } impl TyTermSubst { pub fn new(ty: Type) -> TyTermSubst { use crate::types::visit::*; let mut ty = ty; Shift::new(1).visit(&mut ty); TyTermSubst { cutoff: 0, ty } } fn visit_ty(&mut self, ty: &mut Type) { let mut s = crate::types::visit::Subst { cutoff: self.cutoff, ty: self.ty.clone(), }; s.visit(ty); } } impl MutTermVisitor for TyTermSubst { fn visit_abs(&mut self, sp: &mut Span, ty: &mut Type, term: &mut Term) { // self.cutoff += 1; self.visit_ty(ty); self.visit(term); // self.cutoff -= 1; } fn visit_tyapp(&mut self, sp: &mut Span, term: &mut Term, ty: &mut Type) { self.visit_ty(ty); self.visit(term); } fn visit_tyabs(&mut self, sp: &mut Span, term: &mut Term) { self.cutoff += 1; self.visit(term); self.cutoff -= 1; } fn visit_fold(&mut self, sp: &mut Span, ty: &mut Type, term: &mut Term) { self.visit_ty(ty); self.visit(term); } fn visit_unfold(&mut self, sp: &mut Span, ty: &mut Type, term: &mut Term) { self.visit_ty(ty); self.visit(term); } fn visit_unpack(&mut self, _: &mut Span, package: &mut Term, term: &mut Term) { self.visit(package); self.cutoff += 1; self.visit(term); 
self.cutoff -= 1; } fn visit_pack(&mut self, _: &mut Span, wit: &mut Type, body: &mut Term, sig: &mut Type) { self.visit_ty(wit); self.visit(body); self.visit_ty(sig); } fn visit_injection(&mut self, sp: &mut Span, label: &mut String, term: &mut Term, ty: &mut Type) { self.visit_ty(ty); self.visit(term); } } /// Visitor for handling recursive variants automatically, by inserting a /// fold term /// /// Transform an [`Injection`] term of form: `Label tm of Rec(u.T)` into /// `fold [u.T] Label tm of [X->u.T] T` pub struct InjRewriter; impl MutTermVisitor for InjRewriter { fn visit(&mut self, term: &mut Term) { match &mut term.kind { Kind::Injection(label, val, ty) => { match *ty.clone() { Type::Rec(inner) => { let ty_prime = crate::types::subst(*ty.clone(), *inner.clone()); let rewrite_ty = Term::new( Kind::Injection(label.clone(), val.clone(), Box::new(ty_prime)), term.span, ); *term = Term::new(Kind::Fold(ty.clone(), Box::new(rewrite_ty)), term.span); } _ => {} } self.walk(term); } _ => self.walk(term), } } } ================================================ FILE: 06_system_f/src/types/mod.rs ================================================ //! Typechecking of the simply typed lambda calculus with parametric //! 
polymorphism pub mod patterns; pub mod visit; use crate::diagnostics::*; use crate::terms::{Kind, Literal, Primitive, Term}; use crate::visit::{MutTermVisitor, MutTypeVisitor}; use std::collections::{HashMap, VecDeque}; use std::fmt; use util::span::Span; use visit::{Shift, Subst}; #[derive(Clone, PartialEq, PartialOrd, Eq, Hash)] pub enum Type { Unit, Nat, Bool, Alias(String), Var(usize), Variant(Vec), Product(Vec), Arrow(Box, Box), Universal(Box), Existential(Box), Rec(Box), } #[derive(Clone, PartialEq, PartialOrd, Eq, Hash)] pub struct Variant { pub label: String, pub ty: Type, } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct TypeError { pub span: Span, pub kind: TypeErrorKind, } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum TypeErrorKind { ParameterMismatch(Box, Box, Span), InvalidProjection, NotArrow, NotUniversal, NotVariant, NotProduct, NotRec, IncompatibleArms, InvalidPattern, NotExhaustive, UnreachablePattern, UnboundVariable(usize), } #[derive(Clone, Debug, Default, PartialEq)] pub struct Context { stack: VecDeque, map: HashMap, } impl Context { fn push(&mut self, ty: Type) { self.stack.push_front(ty); } fn pop(&mut self) { self.stack.pop_front().expect("Context::pop() with empty type stack"); } fn find(&self, idx: usize) -> Option<&Type> { self.stack.get(idx) } pub fn alias(&mut self, alias: String, ty: Type) { self.map.insert(alias, ty); } fn aliaser(&self) -> Aliaser<'_> { Aliaser { map: &self.map } } pub fn de_alias(&mut self, term: &mut Term) { crate::visit::MutTermVisitor::visit(self, term) } } /// Helper function for extracting type from a variant pub fn variant_field<'vs>(var: &'vs [Variant], label: &str, span: Span) -> Result<&'vs Type, Diagnostic> { for f in var { if label == f.label { return Ok(&f.ty); } } Err(Diagnostic::error( span, format!("constructor {} doesn't appear in variant fields", label), )) // Err(TypeError { // span, // kind: TypeErrorKind::NotVariant, // }) } impl Context { pub fn type_check(&mut self, term: 
&Term) -> Result { // dbg!(&self.stack); // println!("{}", term); match term.kind() { Kind::Lit(Literal::Unit) => Ok(Type::Unit), Kind::Lit(Literal::Bool(_)) => Ok(Type::Bool), Kind::Lit(Literal::Nat(_)) => Ok(Type::Nat), Kind::Var(idx) => self .find(*idx) .cloned() .ok_or_else(|| Diagnostic::error(term.span, format!("unbound variable {}", idx))), Kind::Abs(ty, t2) => { self.push(*ty.clone()); let ty2 = self.type_check(t2)?; // Shift::new(-1).visit(&mut ty2); self.pop(); Ok(Type::Arrow(ty.clone(), Box::new(ty2))) } Kind::App(t1, t2) => { let ty1 = self.type_check(t1)?; let ty2 = self.type_check(t2)?; match ty1 { Type::Arrow(ty11, ty12) => { if *ty11 == ty2 { Ok(*ty12) } else { let d = Diagnostic::error(term.span, "Type mismatch in application") .message(t1.span, format!("Abstraction requires type {:?}", ty11)) .message(t2.span, format!("Value has a type of {:?}", ty2)); Err(d) } } _ => Err(Diagnostic::error(term.span, "Expected arrow type!") .message(t1.span, format!("operator has type {:?}", ty1))), } } Kind::Fix(inner) => { let ty = self.type_check(inner)?; match ty { Type::Arrow(ty1, ty2) => { if ty1 == ty2 { Ok(*ty1) } else { let d = Diagnostic::error(term.span, "Type mismatch in fix term") .message(inner.span, format!("Abstraction requires type {:?}->{:?}", ty1, ty1)); Err(d) } } _ => Err(Diagnostic::error(term.span, "Expected arrow type!") .message(inner.span, format!("operator has type {:?}", ty))), } } Kind::Primitive(prim) => match prim { Primitive::IsZero => Ok(Type::Arrow(Box::new(Type::Nat), Box::new(Type::Bool))), _ => Ok(Type::Arrow(Box::new(Type::Nat), Box::new(Type::Nat))), }, Kind::Injection(label, tm, ty) => match ty.as_ref() { Type::Variant(fields) => { for f in fields { if label == &f.label { let ty_ = self.type_check(tm)?; if ty_ == f.ty { return Ok(*ty.clone()); } else { let d = Diagnostic::error(term.span, "Invalid associated type in variant").message( tm.span, format!("variant {} requires type {:?}, but this is {:?}", label, f.ty, ty_), ); 
return Err(d); } } } Err(Diagnostic::error( term.span, format!( "constructor {} does not belong to the variant {:?}", label, fields .iter() .map(|f| f.label.clone()) .collect::>() .join(" | ") ), )) } _ => Err(Diagnostic::error( term.span, format!("Cannot injection {} into non-variant type {:?}", label, ty), )), }, Kind::Projection(term, idx) => match self.type_check(term)? { Type::Product(types) => match types.get(*idx) { Some(ty) => Ok(ty.clone()), None => Err(Diagnostic::error( term.span, format!("{} is out of range for product of length {}", idx, types.len()), )), }, ty => Err(Diagnostic::error( term.span, format!("Cannot project on non-product type {:?}", ty), )), }, Kind::Product(terms) => Ok(Type::Product( terms.iter().map(|t| self.type_check(t)).collect::>()?, )), Kind::Let(pat, t1, t2) => { let ty = self.type_check(t1)?; if !self.pattern_type_eq(&pat, &ty) { return Err(Diagnostic::error( t1.span, format!("pattern does not match type of binder"), )); } let height = self.stack.len(); let binds = crate::patterns::PatTyStack::collect(&ty, &pat); for b in binds.into_iter().rev() { self.push(b.clone()); } let y = self.type_check(t2); while self.stack.len() > height { self.pop(); } y } Kind::TyAbs(term) => { self.stack.iter_mut().for_each(|ty| match ty { Type::Var(v) => *v += 1, _ => {} }); let ty2 = self.type_check(term)?; self.stack.iter_mut().for_each(|ty| match ty { Type::Var(v) => *v -= 1, _ => {} }); Ok(Type::Universal(Box::new(ty2))) } Kind::TyApp(term, ty) => { let mut ty = ty.clone(); let ty1 = self.type_check(term)?; match ty1 { Type::Universal(mut ty12) => { Shift::new(1).visit(&mut ty); Subst::new(*ty).visit(&mut ty12); Shift::new(-1).visit(&mut ty12); Ok(*ty12) } _ => Err(Diagnostic::error( term.span, format!("Expected a universal type, not {:?}", ty1), )), } } // See src/types/patterns.rs for exhaustiveness and typechecking // of case expressions Kind::Case(expr, arms) => self.type_check_case(expr, arms), Kind::Unfold(rec, tm) => match rec.as_ref() 
{ Type::Rec(inner) => { let ty_ = self.type_check(&tm)?; if ty_ == *rec.clone() { let s = subst(*rec.clone(), *inner.clone()); Ok(s) } else { let d = Diagnostic::error(term.span, "Type mismatch in unfold") .message(term.span, format!("unfold requires type {:?}", rec)) .message(tm.span, format!("term has a type of {:?}", ty_)); Err(d) } } _ => Err(Diagnostic::error( term.span, format!("Expected a recursive type, not {:?}", rec), )), }, Kind::Fold(rec, tm) => match rec.as_ref() { Type::Rec(inner) => { let ty_ = self.type_check(&tm)?; let s = subst(*rec.clone(), *inner.clone()); if ty_ == s { Ok(*rec.clone()) } else { let d = Diagnostic::error(term.span, "Type mismatch in fold") .message(term.span, format!("unfold requires type {:?}", s)) .message(tm.span, format!("term has a type of {:?}", ty_)); Err(d) } } _ => Err(Diagnostic::error( term.span, format!("Expected a recursive type, not {:?}", rec), )), }, Kind::Pack(witness, evidence, signature) => { if let Type::Existential(exists) = signature.as_ref() { let sig_prime = subst(*witness.clone(), *exists.clone()); let evidence_ty = self.type_check(evidence)?; if evidence_ty == sig_prime { Ok(*signature.clone()) } else { let d = Diagnostic::error(term.span, "Type mismatch in pack") .message(term.span, format!("signature has type {:?}", sig_prime)) .message(evidence.span, format!("but term has a type {:?}", evidence_ty)); Err(d) } } else { Err(Diagnostic::error( term.span, format!("Expected an existential type signature, not {:?}", signature), )) } } Kind::Unpack(package, body) => { let p_ty = self.type_check(package)?; if let Type::Existential(xst) = p_ty { self.push(*xst); let body_ty = self.type_check(body)?; self.pop(); Ok(body_ty) } else { Err(Diagnostic::error( package.span, format!("Expected an existential type signature, not {:?}", p_ty), )) } } } } } pub fn subst(mut s: Type, mut t: Type) -> Type { Shift::new(1).visit(&mut s); Subst::new(s).visit(&mut t); Shift::new(-1).visit(&mut t); t } struct Aliaser<'ctx> { 
map: &'ctx HashMap, } impl<'ctx> MutTypeVisitor for Aliaser<'ctx> { fn visit(&mut self, ty: &mut Type) { match ty { Type::Unit | Type::Bool | Type::Nat => {} Type::Var(v) => {} Type::Alias(v) => { if let Some(aliased) = self.map.get(v) { *ty = aliased.clone(); } } Type::Variant(v) => self.visit_variant(v), Type::Product(v) => self.visit_product(v), Type::Arrow(ty1, ty2) => self.visit_arrow(ty1, ty2), Type::Universal(ty) => self.visit_universal(ty), Type::Existential(ty) => self.visit_existential(ty), Type::Rec(ty) => self.visit_rec(ty), } } } impl MutTermVisitor for Context { fn visit_abs(&mut self, sp: &mut Span, ty: &mut Type, term: &mut Term) { self.aliaser().visit(ty); self.visit(term); } fn visit_tyapp(&mut self, sp: &mut Span, term: &mut Term, ty: &mut Type) { self.aliaser().visit(ty); self.visit(term); } fn visit_injection(&mut self, sp: &mut Span, label: &mut String, term: &mut Term, ty: &mut Type) { self.aliaser().visit(ty); self.visit(term); } fn visit_fold(&mut self, sp: &mut Span, ty: &mut Type, tm: &mut Term) { self.aliaser().visit(ty); self.visit(tm); } fn visit_unfold(&mut self, sp: &mut Span, ty: &mut Type, tm: &mut Term) { self.aliaser().visit(ty); self.visit(tm); } } impl fmt::Debug for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Type::Unit => write!(f, "Unit"), Type::Bool => write!(f, "Bool"), Type::Nat => write!(f, "Nat"), Type::Var(v) => write!(f, "TyVar({})", v), Type::Variant(v) => write!( f, "{:?}", v.iter() .map(|x| format!("{}: {:?}", x.label, x.ty)) .collect::>() .join(" | ") ), Type::Product(v) => write!( f, "({})", v.iter().map(|x| format!("{:?}", x)).collect::>().join(",") ), Type::Alias(s) => write!(f, "{}", s), Type::Arrow(t1, t2) => write!(f, "({:?}->{:?})", t1, t2), Type::Universal(ty) => write!(f, "forall X.{:?}", ty), Type::Existential(ty) => write!(f, "exists X.{:?}", ty), Type::Rec(ty) => write!(f, "rec {:?}", ty), } } } ================================================ FILE: 
06_system_f/src/types/patterns.rs ================================================ //! Naive, inefficient exhaustiveness checking for pattern matching //! //! Inspired somewhat by the docs for the Rust compiler (and linked paper), we //! create a "usefulness" predicate. We store current patterns in a row-wise //! [`Matrix`], and iterate through each row in the matrix every time we want //! to add a new pattern. If no existing rows completely overlap the new row, //! then we can determine that the new row is "useful", and add it. //! //! To check for exhaustiveness, we simply create a row of Wildcard matches, //! and see if it would be useful to add //! //! https://doc.rust-lang.org/nightly/nightly-rustc/src/rustc_mir/hair/pattern/_match.rs.html //! http://moscova.inria.fr/~maranget/papers/warn/index.html //! use super::*; use crate::diagnostics::*; use crate::patterns::{PatTyStack, Pattern}; use crate::terms::*; use std::collections::HashSet; /// Return true if `existing` covers `new`, i.e. if new is a useful pattern /// then `overlap` will return `false` fn overlap(existing: &Pattern, new: &Pattern) -> bool { use Pattern::*; match (existing, new) { (Any, _) => true, (Variable(_), _) => true, (Constructor(l, a), Constructor(l2, b)) => { if l == l2 { overlap(a, b) } else { false } } (Product(a), Product(b)) => a.iter().zip(b.iter()).all(|(a, b)| overlap(a, b)), (Product(a), b) => a.iter().all(|a| overlap(a, b)), (x, y) => x == y, } } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Matrix<'pat> { pub expr_ty: Type, len: usize, matrix: Vec>, } impl<'pat> Matrix<'pat> { /// Create a new [`Matrix`] for a given type pub fn new(expr_ty: Type) -> Matrix<'pat> { let len = match &expr_ty { Type::Product(p) => p.len(), _ => 1, }; Matrix { expr_ty, len, matrix: Vec::new(), } } /// Is the pattern [`Matrix`] exhaustive for this type? 
/// /// For a boolean type, True, False, or a wildcard/variable match are /// required For a product type (tuple), the algorithm is slightly more /// complicated: We generate a tuple of length N (equal to the length of /// the case expressions' tuple) filled with Wildcard patterns, and see /// if addition of tuple is a useful pattern. If the pattern is not /// useful (i.e. it totally `overlap's` with an existing row), then the /// matrix is exhaustive /// /// For a sum type, a dummy constructor of pattern `Const_i _` is generated /// for all `i` of the possible constructors of the type. If none of the /// dummy constructors are useful, then the current patterns are exhaustive pub fn exhaustive(&self) -> bool { match &self.expr_ty { Type::Variant(v) => v.iter().all(|variant| { // For all constructors in the sum type, generate a constructor // pattern that will match all possible inhabitants of that // constructor let con = Pattern::Constructor(variant.label.clone(), Box::new(Pattern::Any)); let temp = [&con]; let mut ret = false; for row in &self.matrix { if row.iter().zip(&temp).all(|(a, b)| overlap(a, b)) { ret = true; break; } } ret }), Type::Product(_) | Type::Nat => { // Generate a tuple of wildcard patterns. 
If the pattern is // useful, then we do not have an exhaustive matrix let filler = (0..self.len).map(|_| Pattern::Any).collect::>(); for row in &self.matrix { if row.iter().zip(filler.iter()).all(|(a, b)| overlap(a, b)) { return true; } } false } Type::Bool => { // Boolean type is one of the simplest cases: we only need // to match `true` and `false`, or one of those + a wildcard, // or just a wildcard let tru = Pattern::Literal(Literal::Bool(true)); let fal = Pattern::Literal(Literal::Bool(false)); !(self.can_add_row(vec![&tru]) && self.can_add_row(vec![&fal])) } Type::Unit => { // Unit is a degenerate case let unit = Pattern::Literal(Literal::Unit); !self.can_add_row(vec![&unit]) } _ => false, } } /// Return true if a new pattern row is reachable fn can_add_row(&self, new_row: Vec<&'pat Pattern>) -> bool { assert_eq!(self.len, new_row.len()); for row in &self.matrix { if row.iter().zip(new_row.iter()).all(|(a, b)| overlap(a, b)) { return false; } } true } fn try_add_row(&mut self, new_row: Vec<&'pat Pattern>) -> bool { assert_eq!(self.len, new_row.len()); for row in &self.matrix { if row.iter().zip(new_row.iter()).all(|(a, b)| overlap(a, b)) { return false; } } self.matrix.push(new_row); true } /// Attempt to add a new [`Pattern`] to the [`Matrix`] /// /// Returns true on success, and false if the new pattern is /// unreachable pub fn add_pattern(&mut self, pat: &'pat Pattern) -> bool { match pat { Pattern::Any | Pattern::Variable(_) => { let filler = (0..self.len).map(|_| &Pattern::Any).collect::>(); self.try_add_row(filler) } Pattern::Product(tuple) => self.try_add_row(tuple.iter().collect()), Pattern::Literal(lit) => { if self.len == 1 { self.try_add_row(vec![pat]) } else { false } } Pattern::Constructor(label, inner) => self.try_add_row(vec![pat]), } } } impl Context { /// Type check a case expression, returning the Type of the arms, assuming /// that the case expression is exhaustive and well-typed /// /// This is one of the more complicated functions in the 
typechecker. /// 1) We first have to check that each arm in the case expression has /// a pattern that is proper for type of the case expression - we shouldn't /// have case arms with patterns that can be never be matched! /// /// 2) We then need to bind any variables referenced in the pattern into /// the typing context - `Cons (x, xs)` needs to bind both x and xs, as does /// just `Cons x`. /// /// 3) After variable binding, we need to typecheck the actual case arm's /// term, and store the result so that we can compare it to the types of /// the other arms in the case expression /// /// 4) If the arm is properly typed, then we need to add it to a matrix so /// that we can determine if the pattern is reachable, and if the case arms /// are exhaustive - one, and only one, pattern should be matchable /// /// 5) Finally, assuming all of the previous checks have passed, we return /// the shared type of all of the case arms - the term associated with each /// arm should have one type, and that type should be the same for all of /// the arms. pub(crate) fn type_check_case(&mut self, expr: &Term, arms: &[Arm]) -> Result { let ty = self.type_check(expr)?; let mut matrix = patterns::Matrix::new(ty); let mut set = HashSet::new(); for arm in arms { if self.pattern_type_eq(&arm.pat, &matrix.expr_ty) { let height = self.stack.len(); let binds = PatTyStack::collect(&matrix.expr_ty, &arm.pat); for b in binds.into_iter().rev() { self.push(b.clone()); } let arm_ty = self.type_check(&arm.term)?; while self.stack.len() > height { self.pop(); } set.insert(arm_ty); if !matrix.add_pattern(&arm.pat) { return Err(Diagnostic::error(arm.span, "unreachable pattern!")); } } else { return Err( Diagnostic::error(expr.span, format!("case binding has a type {:?}", &matrix.expr_ty)).message( arm.span, format!("but this pattern cannot bind a value of type {:?}", &matrix.expr_ty), ), ); } } if set.len() != 1 { return Err(Diagnostic::error(expr.span, format!("incompatible arms! 
{:?}", set))); } if matrix.exhaustive() { match set.into_iter().next() { Some(s) => Ok(s), None => Err(Diagnostic::error( expr.span, "probably unreachable - expected variant type!", )), } } else { Err(Diagnostic::error(expr.span, "patterns are not exhaustive!")) } } /// Helper function for pattern to type equivalence /// /// A `_` wildcard pattern is obviously valid for every type, as is a /// variable binding: /// case Some(10) of /// | None => None /// | x => x -- x will always match to Some(10) here /// /// A literal pattern should only be equal to the equivalent type, etc /// /// This function is primarily used as a first pass to ensure that a pattern /// is valid for a given case expression pub(crate) fn pattern_type_eq(&self, pat: &Pattern, ty: &Type) -> bool { match pat { Pattern::Any => true, Pattern::Variable(_) => true, Pattern::Literal(lit) => match (lit, ty) { (Literal::Bool(_), Type::Bool) => true, (Literal::Nat(_), Type::Nat) => true, (Literal::Unit, Type::Unit) => true, _ => false, }, Pattern::Product(patterns) => match ty { Type::Product(types) => { patterns.len() == types.len() && patterns .iter() .zip(types.iter()) .all(|(pt, tt)| self.pattern_type_eq(pt, tt)) } _ => false, }, Pattern::Constructor(label, inner) => match ty { Type::Variant(v) => { for discriminant in v { if label == &discriminant.label && self.pattern_type_eq(&inner, &discriminant.ty) { return true; } } false } _ => false, }, } } } #[cfg(test)] mod test { use super::*; use Pattern::*; #[test] fn product() { let ty = Type::Product(vec![Type::Bool, Type::Bool, Type::Nat]); let pat = prod!(boolean!(true), boolean!(true), num!(10)); let ctx = Context::default(); assert!(ctx.pattern_type_eq(&pat, &ty)); } #[test] #[should_panic] fn product_mistyped() { let ty = Type::Product(vec![Type::Bool, Type::Bool, Type::Bool]); let pat = prod!(boolean!(true), boolean!(true), num!(10)); let ctx = Context::default(); assert!(ctx.pattern_type_eq(&pat, &ty)); } #[test] fn constructor() { let ty = 
Type::Variant(vec![ Variant { label: "A".into(), ty: Type::Unit, }, Variant { label: "B".into(), ty: Type::Nat, }, ]); let pat1 = con!("A", Pattern::Any); let pat2 = con!("A", boolean!(true)); let pat3 = con!("B", num!(1)); let ctx = Context::default(); assert!(ctx.pattern_type_eq(&pat1, &ty)); assert!(!ctx.pattern_type_eq(&pat2, &ty)); assert!(ctx.pattern_type_eq(&pat3, &ty)); } #[test] fn constructor_product() { let ty = Type::Variant(vec![ Variant { label: "A".into(), ty: Type::Unit, }, Variant { label: "B".into(), ty: Type::Product(vec![Type::Nat, Type::Nat]), }, ]); let pat1 = con!("A", Any); let pat2 = con!("B", Any); let pat3 = con!("B", prod!(Any, Variable("x".into()))); let pat4 = con!("B", prod!(num!(1), Variable("x".into()))); let pat5 = con!("A", num!(1)); let ctx = Context::default(); assert!(ctx.pattern_type_eq(&pat1, &ty)); assert!(ctx.pattern_type_eq(&pat2, &ty)); assert!(ctx.pattern_type_eq(&pat3, &ty)); assert!(ctx.pattern_type_eq(&pat4, &ty)); assert!(!ctx.pattern_type_eq(&pat5, &ty)); } #[test] fn matrix_tuple() { let pats = vec![ prod!(num!(0), num!(1)), prod!(num!(1), num!(1)), prod!(Any, num!(2)), prod!(num!(2), Any), prod!(num!(1), num!(4)), prod!(Any, Variable(String::default())), ]; let ty = Type::Product(vec![Type::Nat, Type::Nat]); let mut matrix = Matrix::new(ty); for pat in &pats { assert!(matrix.add_pattern(pat)); } assert!(!matrix.add_pattern(&Any)); assert!(matrix.exhaustive()); } #[test] fn matrix_constructor() { let ty = Type::Variant(vec![ variant!("A", Type::Nat), variant!("B", Type::Nat), variant!("C", Type::Product(vec![Type::Nat, Type::Nat])), ]); let pats = vec![ con!("A", num!(20)), con!("A", Any), con!("B", Any), con!("C", prod!(num!(1), num!(1))), con!("C", prod!(Any, num!(1))), con!("C", prod!(num!(1), Any)), ]; let ctx = Context::default(); assert!(pats.iter().all(|p| ctx.pattern_type_eq(p, &ty))); let mut matrix = Matrix::new(ty); for p in &pats { assert!(matrix.add_pattern(p)); } let last = con!("C", Any); 
        assert!(!matrix.exhaustive());
        assert!(matrix.add_pattern(&last));
        assert!(matrix.exhaustive());
    }

    #[test]
    fn matrix_bool() {
        let pats = vec![boolean!(true), boolean!(false)];
        let ty = Type::Bool;
        let ctx = Context::default();
        // Every pattern must first be shape-compatible with the scrutinee type.
        assert!(pats.iter().all(|p| ctx.pattern_type_eq(p, &ty)));
        let mut matrix = Matrix::new(ty);
        for p in &pats {
            assert!(matrix.add_pattern(p));
        }
        // A duplicate row is unreachable and must be rejected.
        assert!(!matrix.add_pattern(&pats[1]));
        assert!(matrix.exhaustive());
    }
}
================================================ FILE: 06_system_f/src/types/visit.rs ================================================
use super::Type;
use crate::visit::MutTypeVisitor;
use std::convert::TryFrom;

/// de Bruijn shifting over types: adds `shift` to every type variable whose
/// index is >= `cutoff` (the variables free relative to the traversal root).
pub struct Shift {
    pub cutoff: usize,
    pub shift: isize,
}

impl Shift {
    pub const fn new(shift: isize) -> Shift {
        Shift { cutoff: 0, shift }
    }
}

impl MutTypeVisitor for Shift {
    fn visit_var(&mut self, var: &mut usize) {
        // Only free variables are renumbered. A net shift that would take an
        // index below zero is a broken invariant, hence the expect/panic.
        if *var >= self.cutoff { *var = usize::try_from(*var as isize + self.shift).expect("Variable has been shifted below 0! Fatal bug"); }
    }

    // Each binder (forall / exists / rec) introduces one bound type
    // variable, so the cutoff rises by one inside its body.
    fn visit_universal(&mut self, inner: &mut Type) {
        self.cutoff += 1;
        self.visit(inner);
        self.cutoff -= 1;
    }

    fn visit_existential(&mut self, inner: &mut Type) {
        self.cutoff += 1;
        self.visit(inner);
        self.cutoff -= 1;
    }

    fn visit_rec(&mut self, ty: &mut Type) {
        self.cutoff += 1;
        self.visit(ty);
        self.cutoff -= 1;
    }
}

/// Substitution of a type for a type variable, de Bruijn style: `ty` is the
/// replacement, `cutoff` tracks how many binders we have descended through.
pub struct Subst {
    pub cutoff: usize,
    pub ty: Type,
}

impl Subst {
    pub fn new(ty: Type) -> Subst {
        Subst { cutoff: 0, ty }
    }
}

impl MutTypeVisitor for Subst {
    fn visit_universal(&mut self, inner: &mut Type) {
        self.cutoff += 1;
        self.visit(inner);
        self.cutoff -= 1;
    }

    fn visit_existential(&mut self, inner: &mut Type) {
        self.cutoff += 1;
        self.visit(inner);
        self.cutoff -= 1;
    }

    fn visit_rec(&mut self, ty: &mut Type) {
        self.cutoff += 1;
        self.visit(ty);
        self.cutoff -= 1;
    }

    fn visit(&mut self, ty: &mut Type) {
        match ty {
            Type::Unit | Type::Bool | Type::Nat => {}
            // NOTE(review): this guard uses `>=` where the term-level Subst
            // (terms/visit.rs) uses `== cutoff`, so EVERY variable at or
            // above the cutoff is replaced, not just the one substituted
            // for. That is harmless when only one free variable exists (the
            // common case for instantiating a single binder) — confirm this
            // is intentional before relying on it for multi-variable types.
            //
            // NOTE(review): Shift mutates `self.ty` in place, so if this arm
            // fires more than once in a traversal under different cutoffs,
            // the shifts accumulate on the stored type — verify substituted
            // types with at most one matching Var occurrence, or clone.
            Type::Var(v) if *v >= self.cutoff => {
                Shift::new(self.cutoff as isize).visit(&mut self.ty);
                *ty =
self.ty.clone(); } Type::Var(v) => self.visit_var(v), Type::Variant(v) => self.visit_variant(v), Type::Product(v) => self.visit_product(v), Type::Alias(v) => self.visit_alias(v), Type::Arrow(ty1, ty2) => self.visit_arrow(ty1, ty2), Type::Universal(ty) => self.visit_universal(ty), Type::Existential(ty) => self.visit_existential(ty), Type::Rec(ty) => self.visit_rec(ty), } } } ================================================ FILE: 06_system_f/src/visit.rs ================================================ //! Visitor traits for [`Pattern`], [`Term`], and [`Type`] objects use crate::patterns::Pattern; use crate::terms::{Arm, Kind, Literal, Primitive, Term}; use crate::types::{Type, Variant}; use util::span::Span; pub trait MutTypeVisitor: Sized { fn visit_var(&mut self, var: &mut usize) {} fn visit_alias(&mut self, alias: &mut String) {} fn visit_arrow(&mut self, ty1: &mut Type, ty2: &mut Type) { self.visit(ty1); self.visit(ty2); } fn visit_universal(&mut self, inner: &mut Type) { self.visit(inner); } fn visit_existential(&mut self, inner: &mut Type) { self.visit(inner); } fn visit_variant(&mut self, variant: &mut Vec) { for v in variant { self.visit(&mut v.ty); } } fn visit_product(&mut self, product: &mut Vec) { for v in product { self.visit(v); } } fn visit_rec(&mut self, ty: &mut Type) { self.visit(ty); } fn visit(&mut self, ty: &mut Type) { match ty { Type::Unit | Type::Bool | Type::Nat => {} Type::Var(v) => self.visit_var(v), Type::Variant(v) => self.visit_variant(v), Type::Product(v) => self.visit_product(v), Type::Alias(s) => self.visit_alias(s), Type::Arrow(ty1, ty2) => self.visit_arrow(ty1, ty2), Type::Universal(ty) => self.visit_universal(ty), Type::Existential(ty) => self.visit_existential(ty), Type::Rec(ty) => self.visit_rec(ty), } } } pub trait MutTermVisitor: Sized { fn visit_lit(&mut self, sp: &mut Span, lit: &mut Literal) {} fn visit_var(&mut self, sp: &mut Span, var: &mut usize) {} fn visit_abs(&mut self, sp: &mut Span, ty: &mut Type, term: &mut Term) { 
self.visit(term); } fn visit_app(&mut self, sp: &mut Span, t1: &mut Term, t2: &mut Term) { self.visit(t1); self.visit(t2); } fn visit_let(&mut self, sp: &mut Span, pat: &mut Pattern, t1: &mut Term, t2: &mut Term) { self.visit(t1); self.visit(t2); } fn visit_tyabs(&mut self, sp: &mut Span, term: &mut Term) { self.visit(term); } fn visit_tyapp(&mut self, sp: &mut Span, term: &mut Term, ty: &mut Type) { self.visit(term); } fn visit_primitive(&mut self, sp: &mut Span, prim: &mut Primitive) {} fn visit_injection(&mut self, sp: &mut Span, label: &mut String, term: &mut Term, ty: &mut Type) { self.visit(term); } fn visit_case(&mut self, sp: &mut Span, term: &mut Term, arms: &mut Vec) { self.visit(term); for arm in arms { self.visit(&mut arm.term); } } fn visit_product(&mut self, sp: &mut Span, product: &mut Vec) { for t in product { self.visit(t); } } fn visit_projection(&mut self, sp: &mut Span, term: &mut Term, index: &mut usize) { self.visit(term); } fn visit_fold(&mut self, sp: &mut Span, ty: &mut Type, term: &mut Term) { self.visit(term); } fn visit_unfold(&mut self, sp: &mut Span, ty: &mut Type, term: &mut Term) { self.visit(term); } fn visit_pack(&mut self, sp: &mut Span, witness: &mut Type, evidence: &mut Term, signature: &mut Type) { self.visit(evidence); } fn visit_unpack(&mut self, sp: &mut Span, package: &mut Term, term: &mut Term) { self.visit(package); self.visit(term); } fn visit(&mut self, term: &mut Term) { self.walk(term); } fn walk(&mut self, term: &mut Term) { let sp = &mut term.span; match &mut term.kind { Kind::Lit(l) => self.visit_lit(sp, l), Kind::Var(v) => self.visit_var(sp, v), Kind::Abs(ty, term) => self.visit_abs(sp, ty, term), Kind::App(t1, t2) => self.visit_app(sp, t1, t2), // Do we need a separate branch? 
Kind::Fix(term) => self.visit(term), Kind::Primitive(p) => self.visit_primitive(sp, p), Kind::Injection(label, tm, ty) => self.visit_injection(sp, label, tm, ty), Kind::Projection(term, idx) => self.visit_projection(sp, term, idx), Kind::Product(terms) => self.visit_product(sp, terms), Kind::Case(term, arms) => self.visit_case(sp, term, arms), Kind::Let(pat, t1, t2) => self.visit_let(sp, pat, t1, t2), Kind::TyAbs(term) => self.visit_tyabs(sp, term), Kind::TyApp(term, ty) => self.visit_tyapp(sp, term, ty), Kind::Fold(ty, term) => self.visit_fold(sp, ty, term), Kind::Unfold(ty, term) => self.visit_unfold(sp, ty, term), Kind::Pack(wit, term, sig) => self.visit_pack(sp, wit, term, sig), Kind::Unpack(package, term) => self.visit_unpack(sp, package, term), } } } pub trait PatternVisitor: Sized { fn visit_literal(&mut self, lit: &Literal) {} fn visit_variable(&mut self, var: &String) {} fn visit_product(&mut self, pats: &Vec) { for p in pats { self.visit_pattern(p); } } fn visit_constructor(&mut self, label: &String, pat: &Pattern) { self.visit_pattern(pat); } fn visit_pattern(&mut self, pattern: &Pattern) { match pattern { Pattern::Any => {} Pattern::Constructor(label, pat) => self.visit_constructor(label, pat), Pattern::Product(pat) => self.visit_product(pat), Pattern::Literal(lit) => self.visit_literal(lit), Pattern::Variable(var) => self.visit_variable(var), } } } ================================================ FILE: 06_system_f/test.sf ================================================ let func = \X (\c: {None | Some X}. \x: X->(X, X). case c of | None => None of {None | Some (X, X)} | Some val => Some (val, val) of {None | Some (X, X)} ) in func [Nat] (Some 10 of {None|Some Nat}) (\x: Nat. (x, x)) ; let poly = \X \x: X. x in let x = poly [Nat] 0 in let y = poly [Bool] false in let z = poly [(Nat, Bool)] in z (x, y) ; let poly = \X \Y (\func: X->Y. \val: X. 
func val) in poly [Nat][Bool] ; case Some (5, 2) of {None | Some (Nat, Nat)} of | None => (0, 0) | Some (1, _) => (1, 1) | Some(x, y) => (y, x) ; case (1, (2, 3)) of | (x, (y, z)) => ((z, y), x) ; let x = \z: (Nat, Nat)->Nat. \y: (Nat, Nat). case y of | (0, x) => x, | x => z (pred y.0, succ (succ x.1)) in (fix x) (10, 0) ; let cdr = \list: NatList. case unfold NatList list of | Nil => Nil of NatList | Cons (x, xs) => xs in cdr Cons (10, Cons (20, Nil of NatList) of NatList) of NatList ; case unfold NatList Cons (10, Cons (20, Nil of NatList) of NatList) of NatList of | Nil => Nil of NatList | Cons (10, xs) => Cons (11, xs) of NatList | Cons (x, xs) => xs ; let nil = Nil of NatList in let cons = (\val: Nat. \list: NatList. Cons (val, list) of NatList) in case unfold NatList (cons 1 nil) of | Nil => nil | Cons (x, y) => y ; let x = 10 in let (y, _) = (x, 1) in y; let (x, y) = (0, 10) in let z = x in z ; (\x: Nat. \Y \y: Nat->Y. y x) 10 [Nat] succ ; let x = \struct: (Nat, Nat, Nat). let (_, q, _) = struct in q in x (10, 12, 13) let x = \A \B \C \tuple: (A, B, (C, C)). let (_, mid, (n, s)) = tuple in (n, mid, s, mid) in x [Nat] [Bool] [Nat] (10, true, (1, 11)) ; let package = (pack Nat, ((\x: Nat. succ (succ x)), 0) as exists X. (X->Nat, X)) in unpack package as T, mod in mod.0 ((\x: T. x) mod.1) ;; let package = (pack Bool, ((\x: Bool. case x of | true => 10 | false => 0), true) as exists REPR. (REPR->Nat, REPR)) in let x = (\x: exists T. (T->Nat, T). 
unpack x as T, mod in succ (mod.0 mod.1)) in x package ; ================================================ FILE: 07_system_fw/Cargo.toml ================================================ [package] name = "system_fw" version = "0.1.0" authors = ["Michael Lazear "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] util = { path = "../util" } ================================================ FILE: 07_system_fw/README.md ================================================ # System Fω This is an implementation (mostly of just the type system) of the higher-order polymorphic lambda calculus with explicit typing. This allows us to express functions with impredicative arguments, and we can emulate Haskell style typeclasses using functors over existential types (akin to 1ML) - something not expressible in Standard ML. #### References I've included some selected references on implementing System Fw, particularly with respect to the addition of recursive types (System F omega-mu): - Mendler, N.P.: Recursive types and type constraints in second-order lambda calculus. In: Proceedings of the Second Annual IEEE Symposium on Logic in Computer Science, Ithaca, N.Y., IEEE Computer Society Press (1987) 30–36 - One of the classic, highly reference papers - Andreas Abel, Ralph Matthes, Tarmo Uustalu, "Iteration and coiteration schemes for higher-order and nested datatypes", Theoretical Computer Science, Volume 333, Issues 1–2, 2005, Pages 3-66, - Interesting implementations, a paper that allowed me to better understand "Mendler iterations" - Andreas Abel, Ralph Matthes, "Fixed Points of Type Constructors and Primitive Recursion", International Workshop on Computer Science Logic (2004) - Ahn, Ki Yung, "The Nax Language: Unifying Functional Programming and Logical Reasoning in a Language based on Mendler-style Recursion Schemes and Term-indexed Types" (2014). Dissertations and Theses. Paper 2088. 
- This reference is quite useful as it restates much of the literature in easily understandable terms - Yufei Cai, Paolo G. Giarrusso, and Klaus Ostermann. 2016. System f-omega with equirecursive types for datatype-generic programming. SIGPLAN Not. 51, 1 (January 2016), 30-43. DOI: https://doi.org/10.1145/2914770.2837660 - Interesting, but I found it to be of not too much practical use - A Polymorphic Lambda-Calculus with Sized Higher-Order Types, Andreas Abel, PhD thesis - Page 76 seems to give some good hints on how to type Fold/Unfold operators (but for equirecursive, isorecursive on pp 85). It also seems that the kinding rules/type-equivalence can treat a type constructor abstraction (* => *) and a recursive type of kind (* => *) as the same. Iso-coinductive construtors are also discussed, pp 157. ================================================ FILE: 07_system_fw/src/diagnostics.rs ================================================ use std::fmt; use util::span::Span; #[derive(Debug, Clone, PartialEq)] pub enum Level { Warn, Error, } #[derive(Debug, Clone, PartialEq)] pub struct Annotation { pub span: Span, pub info: String, } #[derive(Clone, PartialEq)] pub struct Diagnostic { pub level: Level, pub primary: Annotation, pub info: Vec, pub other: Vec, } impl Annotation { pub fn new>(span: Span, message: S) -> Annotation { Annotation { span, info: message.into(), } } } impl Diagnostic { pub fn error>(span: Span, message: S) -> Diagnostic { Diagnostic { level: Level::Error, primary: Annotation::new(span, message), other: Vec::new(), info: Vec::new(), } } pub fn warn>(span: Span, message: S) -> Diagnostic { Diagnostic { level: Level::Warn, primary: Annotation::new(span, message), other: Vec::new(), info: Vec::new(), } } pub fn message>(mut self, span: Span, message: S) -> Diagnostic { self.other.push(Annotation::new(span, message)); self } pub fn info>(mut self, info: S) -> Diagnostic { self.info.push(info.into()); self } pub fn lines(&self) -> std::ops::Range { let 
mut range = std::ops::Range { start: self.primary.span.start.line, end: self.primary.span.end.line + 1, }; for addl in &self.other { if addl.span.start.line < range.start { range.start = addl.span.start.line; } if addl.span.end.line + 1 > range.end { range.end = addl.span.end.line + 1; } } range } } impl fmt::Debug for Diagnostic { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "\n{:?}: {} starting at line {}, col {}\n{}", self.level, self.primary.info, self.primary.span.start.line, self.primary.span.start.col, self.other.iter().map(|a| a.info.clone()).collect::>().join("\n") ) } } ================================================ FILE: 07_system_fw/src/elaborate.rs ================================================ use super::ast::*; use super::hir::{self, Constructor, DeBruijn, HirId}; use super::stack::Stack; use super::syntax::visit::*; use std::collections::{HashMap, HashSet}; use std::iter::IntoIterator; /// Validate that a [`Program`] is closed, e.g. it has no free /// term or type variables. 
We traverse the program in execution order, /// adding bindings for top-level declarations, and also keeping track of /// bindings that occur in local scopes for de Bruijn index tracking #[derive(Default)] pub struct ElaborationContext<'s> { tyvars: Stack<&'s str>, tmvars: Stack<&'s str>, namespaces: Vec, current: usize, constructors: HashMap, elaborated: HashMap, next_hir_id: HirId, } pub struct Elaborated { pub constructors: HashMap, pub elaborated: HashMap, pub decls: Vec, } #[derive(Default)] pub struct Namespace { id: usize, parent: Option, values: HashMap, types: HashMap, } #[derive(Clone, Debug, PartialEq)] pub enum ElabError { UndefinedType(String, util::span::Span), UndefinedValue(String, util::span::Span), UndefinedConstr(String, util::span::Span), InvalidBinding(String, util::span::Span), } /// Housekeeping, namespace methods impl<'s> ElaborationContext<'s> { pub fn new() -> Self { let mut ec = Self::default(); let global_ns = Namespace::default(); ec.namespaces.push(global_ns); ec } pub fn elaborate(program: &'s Program) -> Result { let mut ec = Self::new(); let decls = ec.elab_program(program)?; Ok(Elaborated { constructors: ec.constructors, elaborated: ec.elaborated, decls, }) } /// Keep track of the type variable stack, while executing the combinator /// function `f` on `self`. Any stack growth is popped off after `f` /// returns. fn with_tyvars) -> T>(&mut self, f: F) -> T { let n = self.tyvars.len(); let r = f(self); let to_pop = self.tyvars.len() - n; self.tyvars.popn(to_pop); r } /// Keep track of the term variable stack, while executing the combinator /// function `f` on `self`. Any stack growth is popped off after `f` /// returns. 
fn with_tmvars) -> T>(&mut self, f: F) -> T { let n = self.tmvars.len(); let r = f(self); let to_pop = self.tmvars.len() - n; self.tmvars.popn(to_pop); r } fn elab_kind(&self, k: &Kind) -> hir::Kind { match k { Kind::Star => hir::Kind::Star, Kind::Arrow(k1, k2) => hir::Kind::Arrow(Box::new(self.elab_kind(k1)), Box::new(self.elab_kind(k2))), } } fn allocate_hir_id(&mut self) -> HirId { let id = self.next_hir_id; self.next_hir_id = HirId(self.next_hir_id.0 + 1); id } fn define_value(&mut self, name: String, expr: hir::Expr) -> HirId { let id = self.allocate_hir_id(); self.elaborated.insert(id, hir::Decl::Value(expr)); self.namespaces[self.current].values.insert(name, id); id } fn define_type(&mut self, name: String, ty: hir::Type) -> HirId { let id = self.allocate_hir_id(); self.elaborated.insert(id, hir::Decl::Type(ty)); self.namespaces[self.current].types.insert(name, id); id } pub fn dump(&self) { for n in &self.namespaces { println!("Current value bindings:"); for (name, key) in &n.values { println!("\t[{:?}] {}: {:?}", key, name, self.elaborated.get(key).unwrap()); } println!("Current type bindings:"); for (name, key) in &n.types { println!("\t[{:?}] {}: {:?}", key, name, self.elaborated.get(key).unwrap()); } } println!("Current constr bindings:"); for (name, key) in &self.constructors { println!("{:?}", key,); } } /// Starting from the current [`Namespace`], search for a bound name. /// If it's not found, then recursively search parent namespaces fn lexical_value(&self, s: &str) -> Option { let mut ptr = &self.namespaces[self.current]; loop { match ptr.values.get(s) { Some(idx) => return Some(*idx), None => ptr = &self.namespaces[ptr.parent?], } } } /// Search for a variable bound in a temporary lexical scope (i.e. 
a function) fn debruijn_value(&self, s: &str) -> Option { self.tmvars .lookup(&s) .map(|idx| hir::Expr::LocalVar(DeBruijn { idx, name: s.into() })) } /// Search for a value binding, starting with any temporary lambda captures, /// and then working upwards through top level definitions fn lookup_value(&self, s: &str) -> Option { if let Some(db) = self.debruijn_value(s) { return Some(db); } self.lexical_value(s).map(hir::Expr::ProgramVar) } fn lexical_type(&self, s: &str) -> Option { let mut ptr = &self.namespaces[self.current]; loop { match ptr.types.get(s) { Some(idx) => return Some(*idx), None => ptr = &self.namespaces[ptr.parent?], } } } fn debruijn_type(&self, s: &str) -> Option { self.tyvars .lookup(&s) .map(|idx| hir::Type::Var(DeBruijn { idx, name: s.into() })) } fn enter_namespace(&mut self) -> usize { let id = self.namespaces.len(); self.namespaces.push(Namespace { id, parent: Some(self.current), types: HashMap::new(), values: HashMap::new(), }); self.current = id; id } fn leave_namespace(&mut self) { match self.namespaces[self.current].parent { Some(id) => self.current = id, None => panic!("trying to leave global namespace!"), } } /// Perform all bindings within `f` in a fresh [`Namespace`], /// and then return to the current one fn with_new_namespace) -> T>(&mut self, f: F) -> T { self.enter_namespace(); let t = f(self); self.leave_namespace(); t } } /// Type elaboration impl<'s> ElaborationContext<'s> { fn elab_ty_row(&mut self, row: &'s Row) -> Result { Ok(hir::Row { label: row.label.clone(), ty: self.elab_type(&row.ty)?, }) } fn elab_ty_inner(&mut self, tv: &'s str, ty: &'s Type) -> Result { self.with_tyvars(|f| { f.tyvars.push(tv); f.elab_type(ty) }) } fn elab_type(&mut self, ty: &'s Type) -> Result { use TypeKind::*; match &ty.kind { Int => Ok(hir::Type::Int), Bool => Ok(hir::Type::Bool), Unit => Ok(hir::Type::Unit), Infer => Ok(hir::Type::Infer), Defined(s) => self .lexical_type(s) .map(hir::Type::Defined) .ok_or_else(|| 
ElabError::UndefinedType(s.into(), ty.span)), Variable(s) => self .debruijn_type(s) .ok_or_else(|| ElabError::UndefinedType(s.into(), ty.span)), Function(ty1, ty2) => Ok(hir::Type::Arrow( Box::new(self.elab_type(ty1)?), Box::new(self.elab_type(ty2)?), )), // Sum types can only be constructed through a DeclKind::Datatype // so it's okay to be unreachable!() here, as it would be a fatal // bug if we reached this path somehow. Sum(_) => unreachable!(), Product(tys) => Ok(hir::Type::Product( tys.iter().map(|t| self.elab_type(t)).collect::>()?, )), Record(rows) => Ok(hir::Type::Record( rows.iter().map(|t| self.elab_ty_row(t)).collect::>()?, )), Existential(s, k, ty) => Ok(hir::Type::Existential( Box::new(self.elab_kind(k)), Box::new(self.elab_ty_inner(s, ty)?), )), Universal(s, k, ty) => Ok(hir::Type::Universal( Box::new(self.elab_kind(k)), Box::new(self.elab_ty_inner(s, ty)?), )), Abstraction(s, k, ty) => Ok(hir::Type::Abstraction( Box::new(self.elab_kind(k)), Box::new(self.elab_ty_inner(s, ty)?), )), Application(ty1, ty2) => Ok(hir::Type::Application( Box::new(self.elab_type(ty1)?), Box::new(self.elab_type(ty2)?), )), Recursive(ty) => self.elab_type(ty).map(|ty| hir::Type::Recursive(Box::new(ty))), } } } /// Expr elaboration impl<'s> ElaborationContext<'s> { fn elab_let(&mut self, decls: &'s [Decl], expr: &'s Expr) -> Result { self.with_new_namespace(|f| { for d in decls { f.elab_decl(d)?; } f.elab_expr(expr) }) } fn elab_arm(&mut self, arm: &'s Arm) -> Result { Ok(hir::Arm { pat: self.elab_pattern(&arm.pat, true)?, expr: self.elab_expr(&arm.expr)?, }) } fn elab_case(&mut self, expr: &'s Expr, arms: &'s [Arm]) -> Result { let ex = self.elab_expr(expr)?; let arms = arms.iter().map(|a| self.elab_arm(a)).collect::>()?; Ok(hir::Expr::Case(Box::new(ex), arms)) } fn elab_field(&mut self, field: &'s Field) -> Result { Ok(hir::Field { label: field.label.clone(), expr: self.elab_expr(&field.expr)?, }) } /// We desugar to a case expression /// fn (Some x) => x + 1 /// fn $x : 
Infer option => case $x of (Some x) => x + 1 fn elab_abs(&mut self, pat: &'s Pattern, body: &'s Expr) -> Result { // Wow we have a lot of bindings self.with_tmvars(|f| { let pat = f.elab_pattern(pat, true)?; let expr = f.elab_expr(body)?; let ty = f.naive_type_infer(&pat)?; let arm = hir::Arm { pat, expr }; let dummy = hir::Expr::LocalVar(DeBruijn { name: "$anon".into(), idx: 0, }); let case = hir::Expr::Case(Box::new(dummy), vec![arm]); Ok(hir::Expr::Abs(Box::new(ty), Box::new(case))) }) } fn elab_expr(&mut self, expr: &'s Expr) -> Result { use ExprKind::*; match &expr.kind { Unit => Ok(hir::Expr::Unit), Int(i) => Ok(hir::Expr::Int(*i)), Var(s) => self .lookup_value(s) .ok_or_else(|| ElabError::UndefinedValue(s.into(), expr.span)), Constr(s) => self .lexical_value(s) .map(|id| self.constructors.get(&id)) .flatten() .map(|c| hir::Expr::Constr(c.type_id, c.tag)) .ok_or_else(|| ElabError::UndefinedConstr(s.into(), expr.span)), If(e1, e2, e3) => Ok(hir::Expr::If( Box::new(self.elab_expr(e1)?), Box::new(self.elab_expr(e2)?), Box::new(self.elab_expr(e3)?), )), Abs(pat, expr) => self.elab_abs(pat, expr), App(e1, e2) => Ok(hir::Expr::App( Box::new(self.elab_expr(e1)?), Box::new(self.elab_expr(e2)?), )), TyAbs(s, k, e) => self.with_tyvars(|f| { f.tyvars.push(s); let e = f.elab_expr(e)?; Ok(hir::Expr::TyAbs(Box::new(f.elab_kind(k)), Box::new(e))) }), TyApp(e, t) => Ok(hir::Expr::TyApp( Box::new(self.elab_expr(e)?), Box::new(self.elab_type(t)?), )), Record(fields) => fields .iter() .map(|e| self.elab_field(e)) .collect::>() .map(hir::Expr::Record), Tuple(exprs) => exprs .iter() .map(|e| self.elab_expr(e)) .collect::>() .map(hir::Expr::Tuple), Projection(e1, e2) => match &e2.kind { ExprKind::Var(label) => Ok(hir::Expr::RecordProj(Box::new(self.elab_expr(e1)?), label.clone())), ExprKind::Int(idx) => Ok(hir::Expr::TupleProj(Box::new(self.elab_expr(e1)?), *idx)), _ => Err(ElabError::InvalidBinding( format!("attempt to project using {:?}", e2), expr.span, )), }, Case(e, arms) => 
self.elab_case(e, arms), Let(decls, e) => self.elab_let(decls, e), } } } /// Pattern elaboration impl<'s> ElaborationContext<'s> { fn naive_type_infer(&self, pat: &hir::Pattern) -> Result { use hir::Pattern::*; match pat { Any => Ok(hir::Type::Infer), Unit => Ok(hir::Type::Unit), Literal(_) => Ok(hir::Type::Int), Ascribe(_, ty) => Ok(*ty.clone()), Constructor(id) => { let con = self.constructors.get(&id).expect("internal error"); let cty = hir::Type::Defined(con.type_id); if con.type_arity != 0 { Ok(hir::Type::Application(Box::new(cty), Box::new(hir::Type::Infer))) } else { Ok(cty) } } Product(pats) => pats .into_iter() .map(|p| self.naive_type_infer(p)) .collect::>() .map(hir::Type::Product), // Maybe we should go back to sub pats... Record(s) => Ok(hir::Type::Record( s.into_iter() .map(|s| hir::Row { label: s.clone(), ty: hir::Type::Infer, }) .collect(), )), Application(id, arg) => { let con = self.constructors.get(&id).expect("internal error"); let cty = hir::Type::Defined(con.type_id); // Ok(cty) self.naive_type_infer(arg) .map(|ty| hir::Type::Application(Box::new(cty), Box::new(ty))) } Variable(_) => Ok(hir::Type::Infer), } } fn elab_pattern(&mut self, pat: &'s Pattern, bind: bool) -> Result { match &pat.kind { PatKind::Any => Ok(hir::Pattern::Any), PatKind::Unit => Ok(hir::Pattern::Unit), PatKind::Literal(i) => Ok(hir::Pattern::Literal(*i)), PatKind::Variable(s) => { if bind { self.tmvars.push(s); } Ok(hir::Pattern::Variable(s.clone())) } PatKind::Product(sub) => sub .iter() .map(|p| self.elab_pattern(p, bind)) .collect::>() .map(hir::Pattern::Product), PatKind::Record(sub) => Ok(hir::Pattern::Record(sub.clone())), PatKind::Ascribe(pat, ty) => Ok(hir::Pattern::Ascribe( Box::new(self.elab_pattern(pat, bind)?), Box::new(self.elab_type(ty)?), )), PatKind::Constructor(s) => self .lexical_value(s) .ok_or_else(|| ElabError::UndefinedConstr(s.clone(), pat.span)) .map(hir::Pattern::Constructor), PatKind::Application(con, arg) => { let econ = self.elab_pattern(con, 
bind)?; let earg = self.elab_pattern(arg, bind)?; let id = match econ { hir::Pattern::Constructor(id) => { let con_info = self.constructors.get(&id).unwrap(); if !con_info.arity { let name = match &con.as_ref().kind { PatKind::Constructor(s) => s.clone(), _ => panic!("interal error!"), }; return Err(ElabError::InvalidBinding( format!("constructor {} doesn't accept arguments!", name), pat.span, )); } id } _ => { return Err(ElabError::InvalidBinding( format!("cannot apply {:?} to non-constructor {:?}", arg, con), pat.span, )) } }; Ok(hir::Pattern::Application(id, Box::new(earg))) } } } } /// Decl elaboration impl<'s> ElaborationContext<'s> { fn elab_decl_type(&mut self, tyvars: &'s [Type], name: &'s str, ty: &'s Type) -> Result { let ty = self.elab_type(ty)?; let ty = tyvars.iter().fold(ty, |ty, var| { hir::Type::Abstraction(Box::new(hir::Kind::Star), Box::new(ty)) }); Ok(self.define_type(name.into(), ty)) } fn elab_constructor( &mut self, name: &'s str, tag: usize, tyvar_arity: usize, type_signature: Option<&hir::Type>, type_id: HirId, ) -> HirId { let expr = match type_signature { Some(ty) => hir::Expr::Abs( Box::new(ty.clone()), Box::new(hir::Expr::App( Box::new(hir::Expr::Constr(type_id, tag)), Box::new(hir::Expr::LocalVar(DeBruijn { name: String::from("x"), idx: 0, })), )), ), None => hir::Expr::Constr(type_id, tag), }; let expr = (0..tyvar_arity).fold(expr, |e, _| hir::Expr::TyAbs(Box::new(hir::Kind::Star), Box::new(e))); let arity = type_signature.is_some(); let con_id = self.define_value(name.into(), expr); self.constructors.insert( con_id, Constructor { type_id, con_id, tag, arity, type_arity: tyvar_arity as u8, }, ); con_id } fn elab_decl_datatype(&mut self, tyvars: &'s [Type], name: &'s str, ty: &'s Type) -> Result { // Quickly collection all names that this type points to let mut coll = TyNameCollector::default(); coll.visit_ty(&ty); let is_recur = coll.definitions.contains(name); // Insert first, so we can be recursive if we need to let id = 
self.allocate_hir_id(); self.namespaces[self.current].types.insert(name.into(), id); // We just do all of this inside of the closure, rather than delegrating // to visit_sum, because we need access to both `tyvars` for generating // value-bindings for the constructors let ty = self.with_tyvars(|f| { f.tyvars.extend(tyvars.iter().map(|t| t.kind.as_tyvar())); let mut elab = Vec::new(); for (idx, v) in ty.kind.variants().iter().enumerate() { let ty = v.ty.as_ref().map(|ty| f.elab_type(ty)).transpose()?; // Generate a function or constant value for the constructor f.elab_constructor(&v.label, idx, tyvars.len(), ty.as_ref(), id); elab.push(hir::Variant { label: v.label.clone(), ty, }); } Ok(hir::Type::Sum(elab)) })?; // We have the raw sum type, so now wrap it in type abstractions let ty = tyvars.iter().fold(ty, |ty, var| { hir::Type::Abstraction(Box::new(hir::Kind::Star), Box::new(ty)) }); let ty = if is_recur { hir::Type::Recursive(Box::new(ty)) } else { ty }; self.elaborated.insert(id, hir::Decl::Type(ty)); Ok(id) } /// Caller is responsible for checking tmvar and tyvar stack growth /// Note: this function directly adds bindings to the global definition states fn deconstruct_pat_binding( &mut self, pat: hir::Pattern, expr: hir::Expr, span: util::span::Span, ) -> Result { use hir::Pattern::*; match pat { Any | Unit => Ok(self.define_value(String::default(), expr)), Variable(s) => Ok(self.define_value(s, expr)), Product(sub) => { // No need for extra redirection let id = match expr { hir::Expr::ProgramVar(id) => id, _ => self.define_value(String::default(), expr), }; let base = Box::new(hir::Expr::ProgramVar(id)); for (idx, pat) in sub.into_iter().enumerate() { self.deconstruct_pat_binding(pat, hir::Expr::TupleProj(base.clone(), idx), span)?; } Ok(id) } Record(sub) => { let id = match expr { hir::Expr::ProgramVar(id) => id, _ => self.define_value(String::default(), expr), }; let base = Box::new(hir::Expr::ProgramVar(id)); for (idx, pat) in sub.into_iter().enumerate() { 
self.define_value(pat, hir::Expr::TupleProj(base.clone(), idx)); } Ok(id) } Ascribe(pat, _) => self.deconstruct_pat_binding(*pat, expr, span), Constructor(_) => Err(ElabError::InvalidBinding( format!("cannot bind constructor to a value!"), span, )), Application(con, arg) => { // con arg expr // val Some (x, y) = Some (10, 9) // val Some (x, y) = funct 10 // $anon = func 10 // case $anon of let con_info = self.constructors.get(&con).unwrap(); let e = hir::Expr::App( Box::new(hir::Expr::Deconstr(con_info.type_id, con_info.tag)), Box::new(expr), ); let id = self.with_new_namespace(|f| f.define_value("$anon_bind_decon".into(), e.clone())); let e = hir::Expr::ProgramVar(id); self.deconstruct_pat_binding(*arg, e, span) } Literal(_) => Err(ElabError::InvalidBinding( format!("cannot bind a literal pattern to a value!"), span, )), } } fn elab_decl_value(&mut self, tyvars: &'s [Type], pat: &'s Pattern, expr: &'s Expr) -> Result { self.with_tyvars(|f| { f.tyvars.extend(tyvars.iter().map(|t| t.kind.as_tyvar())); f.with_tmvars(|f| { let sp = pat.span; let pat = f.elab_pattern(pat, false)?; let ex = f.elab_expr(expr)?; f.deconstruct_pat_binding(pat, ex, sp) }) }) } fn build_pat_matrix(&mut self, arms: &'s [FnArm]) -> Result { let rows = arms.len(); let mut pats: Vec> = Vec::with_capacity(rows); let mut exprs = Vec::with_capacity(rows); let mut cols = 0; for arm in arms { cols = cols.max(arm.pats.len()); pats.push( arm.pats .iter() .map(|p| self.elab_pattern(p, true)) .collect::>()?, ); dbg!(&self.tmvars); exprs.push(self.elab_expr(&arm.expr)?); } for r in pats.iter_mut() { if r.len() < cols { r.extend(std::iter::repeat(hir::Pattern::Any).take(cols - r.len())); } } Ok(PatternMatrix { pats, exprs, rows, cols, }) } fn infer_type_matrix(&self, mat: &PatternMatrix) -> Result>, ElabError> { let mut cols: Vec> = (0..mat.cols).map(|_| HashSet::default()).collect(); for i in 0..mat.cols { for j in 0..mat.rows { let ty = self.naive_type_infer(&mat.pats[j][i])?; cols[i].insert(ty); } } 
Ok(cols) } fn try_unify_type_matrix(mat: Vec>) -> Vec { fn unify(a: hir::Type, b: hir::Type) -> hir::Type { use hir::Type::*; if a == b { return a; } match (a, b) { (Infer, x) => x, (x, Infer) => x, (Application(r1, r2), Application(r3, r4)) if r1 == r3 => Application(r1, Box::new(unify(*r2, *r4))), (Product(xs), Product(ys)) if xs.len() == ys.len() => { Product(xs.into_iter().zip(ys.into_iter()).map(|(x, y)| unify(x, y)).collect()) } _ => hir::Type::Unclear, } } mat.into_iter() .map(|col| col.into_iter().fold(hir::Type::Infer, |ty, x| unify(ty, x))) .collect() } fn elab_decl_fun(&mut self, tyvars: &'s [Type], name: &'s str, arms: &'s [FnArm]) -> Result { self.with_tyvars(|f| { f.tyvars.extend(tyvars.iter().map(|t| t.kind.as_tyvar())); f.with_tmvars(|f| { f.tmvars.push(name); let matrix = f.build_pat_matrix(arms)?; let tys = f.infer_type_matrix(&matrix)?; let tys = Self::try_unify_type_matrix(tys); let arms = matrix.collapse(); let expr = hir::Expr::Tuple( (0..tys.len()) .rev() .map(|idx| { hir::Expr::LocalVar(DeBruijn { name: String::new(), idx, }) }) .collect(), ); let case = hir::Expr::Case(Box::new(expr), arms); let fun = tys .into_iter() .rev() .fold(case, |acc, ty| hir::Expr::Abs(Box::new(ty), Box::new(acc))); // let fun = hir::Expr::Abs( // Box::new(hir::Type::Arrow( // Box::new(hir::Type::Infer), // Box::new(hir::Type::Infer), // )), // Box::new(fun), // ); // let fun = hir::Expr::Fix(Box::new(fun)); Ok(f.define_value(name.into(), fun)) }) }) } fn elab_decl_expr(&mut self, expr: &'s Expr) -> Result { let e = self.elab_expr(expr)?; Ok(self.define_value(String::default(), e)) } fn elab_decl_and(&mut self, a: &'s Decl, b: &'s Decl) -> Result { let mut names = DeclNames::default(); names.visit_decl(a); names.visit_decl(b); for name in names.values { // Insert first, so we can be recursive if we need to let id = self.allocate_hir_id(); self.namespaces[self.current].values.insert(name.into(), id); } unimplemented!() } fn elab_decl(&mut self, decl: &'s Decl) -> 
Result { match &decl.kind { DeclKind::Datatype(tyvars, name, ty) => self.elab_decl_datatype(tyvars, name, ty), DeclKind::Type(tyvars, name, ty) => self.elab_decl_type(tyvars, name, ty), DeclKind::Value(tyvars, pat, expr) => self.elab_decl_value(tyvars, pat, expr), DeclKind::And(d1, d2) => unimplemented!(), DeclKind::Function(tyvars, name, arms) => self.elab_decl_fun(tyvars, name, arms), DeclKind::Expr(e) => self.elab_decl_expr(e), } } pub fn elab_program(&mut self, prog: &'s Program) -> Result, ElabError> { let mut v = Vec::with_capacity(prog.decls.len()); for d in &prog.decls { v.push(self.elab_decl(d)?); self.dump(); } Ok(v) } } #[derive(Debug, Clone)] pub struct PatternMatrix { pats: Vec>, exprs: Vec, rows: usize, cols: usize, } impl PatternMatrix { fn collapse(self) -> Vec { let mut arms = Vec::new(); for (pat, expr) in self.pats.into_iter().zip(self.exprs.into_iter()) { arms.push(hir::Arm { pat: hir::Pattern::Product(pat), expr, }); } arms } } /// Helper struct for walking top-level declarations and extracting /// bound type and value names. 
This does no validation or checking #[derive(Default)] struct DeclNames<'s> { values: Vec<&'s str>, types: Vec<&'s str>, } impl<'s> DeclNames<'s> { fn visit_pat(&mut self, pat: &'s Pattern) { match &pat.kind { PatKind::Variable(s) => self.values.push(&s), PatKind::Product(sub) => { for p in sub { self.visit_pat(p); } } PatKind::Record(sub) => { for p in sub { self.values.push(p); } } PatKind::Ascribe(pat, ty) => self.visit_pat(&pat), PatKind::Application(con, arg) => self.visit_pat(&arg), _ => {} } } fn visit_decl(&mut self, d: &'s Decl) { match &d.kind { DeclKind::Datatype(_, name, ty) => self.types.push(&name), DeclKind::Type(_, name, ty) => self.types.push(&name), DeclKind::Value(_, pat, expr) => self.visit_pat(pat), DeclKind::And(d1, d2) => { self.visit_decl(d1); self.visit_decl(d2); } DeclKind::Function(_, name, arms) => self.values.push(&name), _ => {} } } } /// Collect type variables and references to defined names #[derive(Default, Debug, Clone)] pub struct TyNameCollector<'s> { pub tyvars: HashSet<&'s str>, pub definitions: HashSet<&'s str>, } impl<'s> TypeVisitor<'s> for TyNameCollector<'s> { fn visit_variable(&mut self, s: &'s str) { self.tyvars.insert(s); } fn visit_defined(&mut self, s: &'s str) { self.definitions.insert(s); } } ================================================ FILE: 07_system_fw/src/functor.rs ================================================ use super::*; pub fn parameterized_set() -> Type { tyop!(kind!(*), exist!(kind!(* => *), op_app!(Type::Var(0), Type::Var(1)))) } fn list_type() -> Type { let inner = tyop!( kind!(* => *), tyop!( kind!(*), sum!( ("Nil", Type::Unit), ( "Cons", record!(("head", Type::Var(0)), ("tail", op_app!(Type::Var(1), Type::Var(0)))) ) ) ) ); Type::Recursive(Box::new(inner)) } pub fn parameterized_set_term() -> Term { let body = Term::new( terms::Kind::Fold( Box::new(op_app!(list_type(), Type::Var(0))), Box::new(Term::new( terms::Kind::Injection( "Nil".into(), Box::new(unit!()), // Manually perform an unfold on 
list_type() // - In the System F language, we had an // InjRewriter macro that takes care of this, // and we could probably tack it directly // into the type-checker since we can do simplification // now Box::new(op_app!(unfold(list_type()), Type::Var(0))), ), Span::default(), )), ), Span::default(), ); // \X :: * => pack type 'a list = Nil | Cons 'a * 'a list with Nil as /\T::* // {*X::*=>*, X T} tyabs!( kind!(*), pack!(list_type(), body, op_app!(parameterized_set(), Type::Var(0))) ) } #[cfg(test)] mod tests { use super::*; #[test] fn parameterized_functor() { let mut ctx = typecheck::Context::default(); let func = tyapp!(functor::parameterized_set_term(), Type::Nat); let func_actual_ty = ctx.typecheck(&func).unwrap(); let func_described_ty = op_app!(functor::parameterized_set(), Type::Nat); assert_eq!(ctx.equiv(&func_actual_ty, &func_described_ty), Ok(true)); } } ================================================ FILE: 07_system_fw/src/hir/bidir.rs ================================================ use super::*; use crate::elaborate::Elaborated; use std::collections::HashMap; use super::Type::*; #[derive(Debug)] pub struct Context<'hir> { hir_map: &'hir HashMap, defs: HashMap, ctx: Vec, gen: usize, } /// An element in the typing context #[derive(Clone, Debug, PartialEq)] pub enum Element { /// Universal type variable Var, /// Term variable typing x : A. We differ from the paper in that we use /// de Bruijn indices for variables, so we don't need to mark which var /// this annotation belongs to - it always belongs to the innermost binding (idx 0) /// and we will find this by traversing the stack Ann(Type), /// Unsolved existential type variable Exist(usize), /// Existential type variable that has been solved /// to some monotype Solved(usize, Type), /// I am actually unsure if we really need a marker, due to how we structure /// scoping, see `with_scope` method. 
Marker(usize), } pub enum Error { UnboundVariable, } impl<'hir> Context<'hir> { /// Find the term annotation corresponding to de Bruijn index `idx`. /// We traverse the stack in a reversed order, counting each annotation /// we come across fn find_annotation(&self, idx: usize) -> Option<&Type> { let mut ix = 0; for elem in self.ctx.iter().rev() { match &elem { Element::Ann(ty) => { if ix == idx { return Some(&ty); } ix += 1 } _ => {} } } None } pub fn infer(&mut self, e: &'hir Expr) -> Result { use Expr::*; match e { Unit => Ok(Type::Unit), Int(usize) => Ok(Type::Int), LocalVar(db) => self.find_annotation(db.idx).cloned().ok_or(Error::UnboundVariable), ProgramVar(id) => unimplemented!(), // Datatype constructor, pointing to type def and tag of the constr Constr(id, tag) => unimplemented!(), Deconstr(id, tag) => unimplemented!(), If(e1, e2, e3) => unimplemented!(), Abs(ty, ex) => { dbg!(ty); unimplemented!() } App(e1, e2) => unimplemented!(), TyAbs(k, ex) => unimplemented!(), TyApp(ex, ty) => unimplemented!(), Record(fields) => fields .iter() .map(|f| { Ok(Row { label: f.label.clone(), ty: self.infer(&f.expr)?, }) }) .collect::, _>>() .map(Type::Record), Tuple(exprs) => exprs .iter() .map(|e| self.infer(e)) .collect::, _>>() .map(Type::Product), RecordProj(ex, idx) => unimplemented!(), TupleProj(ex, idx) => unimplemented!(), Case(ex, arms) => unimplemented!(), Let(decls, ex) => unimplemented!(), Fix(ex) => unimplemented!(), } } } pub fn test(prog: Elaborated) { let mut ctx = Context { hir_map: &prog.elaborated, defs: HashMap::new(), ctx: Vec::new(), gen: 0, }; for id in prog.decls { match prog.elaborated.get(&id) { Some(Decl::Value(e)) => { ctx.infer(e); } _ => {} } dbg!(&ctx); } } ================================================ FILE: 07_system_fw/src/hir/mod.rs ================================================ pub mod bidir; use std::fmt; #[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Hash)] pub struct DeBruijn { pub idx: usize, pub name: String, } 
#[derive(Copy, Clone, Debug, Default, PartialEq, PartialOrd, Eq, Hash)] pub struct HirId(pub(crate) u32); /// Arm of a case expression #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Arm { pub pat: Pattern, pub expr: Expr, } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Field { pub label: String, pub expr: Expr, } pub struct Program { pub decls: Vec, } // A lot of desugaring goes on here #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum Decl { Type(Type), Value(Expr), } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Constructor { // Points to the constructor function in defined_values pub con_id: HirId, // Points to a Type::Sum in the defined_types map pub type_id: HirId, // Index of this constructor into the sum variants pub tag: usize, // Whether this constr takes an argument or not pub arity: bool, pub type_arity: u8, } /// Patterns for case and let expressions #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum Pattern { /// Wildcard pattern, this always matches Any, Unit, Ascribe(Box, Box), /// Constant pattern Literal(usize), /// Datatype constructor, HirId points to the constructor value binding Constructor(HirId), /// Variable binding Variable(String), /// Tuple of pattern bindings (_, x) Product(Vec), /// Record pattern { label1, label2 } Record(Vec), /// Algebraic datatype constructor, along with binding pattern Application(HirId, Box), } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum Expr { Unit, Int(usize), LocalVar(DeBruijn), ProgramVar(HirId), // Datatype constructor, pointing to type def and tag of the constr Constr(HirId, usize), Deconstr(HirId, usize), If(Box, Box, Box), // Desugar into explicit type bindings Abs(Box, Box), App(Box, Box), TyAbs(Box, Box), TyApp(Box, Box), Record(Vec), Tuple(Vec), RecordProj(Box, String), TupleProj(Box, usize), Case(Box, Vec), Let(Vec, Box), Fix(Box), } #[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Hash)] pub enum Kind { Star, Arrow(Box, Box), } #[derive(Clone, 
PartialEq, PartialOrd, Eq, Hash)] pub enum Type { Int, Bool, Unit, Infer, Error, Unclear, /// Defined name Defined(HirId), /// Type variable 'a Var(DeBruijn), /// Type of functions from terms to terms Arrow(Box, Box), /// Sum type; None | Some of 'a Sum(Vec), /// Tuple type (ty * ty * ... tyN), invariant that N >= 1 Product(Vec), /// Record type { [label: ty],+ }, invariant that N >=1 Record(Vec), /// Existential type: exists (a :: K) of ty Existential(Box, Box), /// Universal type: forall (a :: K) of ty Universal(Box, Box), /// Type level function abstraction Abstraction(Box, Box), /// Type level function application Application(Box, Box), /// Recursive type Recursive(Box), } #[derive(Clone, PartialEq, PartialOrd, Eq, Hash)] pub struct Variant { pub label: String, pub ty: Option, } #[derive(Clone, PartialEq, PartialOrd, Eq, Hash)] pub struct Row { pub label: String, pub ty: Type, } impl fmt::Debug for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Type::Unit => write!(f, "unit"), Type::Int => write!(f, "int"), Type::Bool => write!(f, "bool"), Type::Infer => write!(f, "_"), Type::Error => write!(f, "!"), Type::Unclear => write!(f, "?"), Type::Var(v) => write!(f, "{}", &v.idx), Type::Sum(v) => write!( f, "{}", v.iter() .map(|x| format!( "{}{}", x.label, x.ty.as_ref().map(|i| format!(" of {:?}", i)).unwrap_or(String::new()) )) .collect::>() .join(" | ") ), Type::Product(v) => write!( f, "({})", v.iter().map(|x| format!("{:?}", x)).collect::>().join(",") ), Type::Record(v) => write!( f, "{{{}}}", v.iter() .map(|x| format!("{}: {:?}", x.label, x.ty)) .collect::>() .join(", ") ), Type::Defined(s) => write!(f, "tctx#{:?}", s), Type::Arrow(t1, t2) => write!(f, "({:?}->{:?})", t1, t2), Type::Universal(k, ty) => write!(f, "forall X :: {:?}.{:?}", k, ty), Type::Existential(k, ty) => write!(f, "exists X. :: {:?}. {:?}", k, ty), Type::Abstraction(k, ty) => write!(f, "fn (X. 
:: {:?}) => {:?}", k, ty), Type::Application(a, b) => write!(f, "{:?} {:?}", b, a), Type::Recursive(ty) => write!(f, "rec {:?}", ty), } } } ================================================ FILE: 07_system_fw/src/macros.rs ================================================ #![allow(unused_macros)] /// Boolean term macro_rules! bool { ($x:expr) => { crate::terms::Term::new( crate::terms::Kind::Const(crate::terms::Constant::Bool($x)), util::span::Span::dummy(), ) }; } /// Integer term macro_rules! nat { ($x:expr) => { crate::terms::Term::new( crate::terms::Kind::Const(crate::terms::Constant::Nat($x)), util::span::Span::dummy(), ) }; } macro_rules! unit { () => { crate::terms::Term::new( crate::terms::Kind::Const(crate::terms::Constant::Unit), util::span::Span::dummy(), ) }; } /// TmVar term macro_rules! var { ($x:expr) => { crate::terms::Term::new(crate::terms::Kind::Var($x), util::span::Span::dummy()) }; } /// Application term macro_rules! app { ($t1:expr, $t2:expr) => { crate::terms::Term::new( crate::terms::Kind::App(Box::new($t1), Box::new($t2)), util::span::Span::dummy(), ) }; } /// Lambda abstraction term macro_rules! abs { ($ty:expr, $t:expr) => { crate::terms::Term::new( crate::terms::Kind::Abs(Box::new($ty), Box::new($t)), util::span::Span::dummy(), ) }; } /// Type application term macro_rules! tyapp { ($t1:expr, $t2:expr) => { crate::terms::Term::new( crate::terms::Kind::TyApp(Box::new($t1), Box::new($t2)), util::span::Span::dummy(), ) }; } /// Type abstraction term macro_rules! tyabs { ($k:expr, $t:expr) => { crate::terms::Term::new( crate::terms::Kind::TyAbs(Box::new($k), Box::new($t)), util::span::Span::dummy(), ) }; } macro_rules! pack { ($ty1:expr, $t:expr, $ty2:expr) => { crate::terms::Term::new( crate::terms::Kind::Pack(Box::new($ty1), Box::new($t), Box::new($ty2)), util::span::Span::dummy(), ) }; } macro_rules! 
unpack {
    ($t1:expr, $t2:expr) => {
        crate::terms::Term::new(
            crate::terms::Kind::Unpack(Box::new($t1), Box::new($t2)),
            util::span::Span::dummy(),
        )
    };
}

/// Projection/indexing term: `access!(t, i)` indexes term `t` at `i`.
macro_rules! access {
    ($t1:expr, $t2:expr) => {
        crate::terms::Term::new(
            crate::terms::Kind::Index(Box::new($t1), $t2.into()),
            util::span::Span::dummy(),
        )
    };
}

/// Existential type: `exist!(K, ty)` builds ∃X::K. ty
macro_rules! exist {
    ($k:expr, $ty:expr) => {
        crate::types::Type::Existential(Box::new($k), Box::new($ty))
    };
}

/// Universal type; the one-argument form defaults the kind to `*`.
macro_rules! univ {
    ($ty:expr) => {
        crate::types::Type::Universal(Box::new(kind!(*)), Box::new($ty))
    };
    ($k:expr, $ty:expr) => {
        crate::types::Type::Universal(Box::new($k), Box::new($ty))
    };
}

/// Function (arrow) type: ty1 -> ty2
macro_rules! arrow {
    ($ty1:expr, $ty2:expr) => {
        crate::types::Type::Arrow(Box::new($ty1), Box::new($ty2))
    };
}

/// Labelled type field, used by `record!` and `sum!` below.
macro_rules! field {
    ($name:expr, $ty:expr) => {
        crate::types::TyField {
            label: $name.to_string(),
            ty: Box::new($ty),
        }
    };
}

/// Record type from one or more (label, ty) pairs.
macro_rules! record {
    ($($name:expr),+) => {
        crate::types::Type::Record(vec![$(field!($name.0, $name.1)),+])
    }
}

/// Sum (variant) type from one or more (label, ty) pairs.
macro_rules! sum {
    ($($name:expr),+) => {
        crate::types::Type::Sum(vec![$(field!($name.0, $name.1)),+])
    }
}

/// Product (tuple) type from one or more component types.
macro_rules! product {
    ($($name:expr),+) => {
        crate::types::Type::Product(vec![$($name),+])
    }
}

/// Type-level operator abstraction: λX::K. ty
macro_rules! tyop {
    ($k:expr, $ty:expr) => {
        crate::types::Type::Abs(Box::new($k), Box::new($ty))
    };
}

/// Type-level application of an operator to a type argument.
macro_rules! op_app {
    ($ty1:expr, $ty2:expr) => {
        crate::types::Type::App(Box::new($ty1), Box::new($ty2))
    };
}

/// Kind constructors: `kind!(*)`, the shorthand `kind!(* => *)`,
/// or `kind!(k1 => k2)` for arbitrary arrow kinds.
macro_rules! kind {
    (*) => {
        crate::types::TyKind::Star
    };
    (* => *) => {
        crate::types::TyKind::Arrow(Box::new(kind!(*)), Box::new(kind!(*)))
    };
    ($ex:expr => $ex2:expr) => {
        crate::types::TyKind::Arrow(Box::new($ex), Box::new($ex2))
    };
}

macro_rules!
diag { ($sp:expr, $str:expr) => { Err(crate::diagnostics::Diagnostic::error($sp, $str)) }; ($sp:expr, $fmt:expr, $($args:expr),+) => { Err(crate::diagnostics::Diagnostic::error($sp, format!($fmt, $($args),+))) }; } ================================================ FILE: 07_system_fw/src/main.rs ================================================ #![allow(dead_code)] #[macro_use] pub mod macros; pub mod diagnostics; pub mod elaborate; pub mod functor; pub mod hir; pub mod stack; pub mod syntax; pub mod terms; pub mod typecheck; pub mod types; use std::io::prelude::*; use syntax::ast; use syntax::parser::{Error, ErrorKind, Parser}; use terms::Term; use types::Type; use util::span::Span; fn main() { loop { let mut buffer = String::new(); print!("repl: "); std::io::stdout().flush().unwrap(); std::io::stdin().read_to_string(&mut buffer).unwrap(); let mut p = Parser::new(&buffer); // let mut ctx = elaborate::ElaborationContext::new(); // loop { match p.parse_program() { Ok(d) => { println!("====> {:?}", &d.decls); // println!("Validate: {:?}", validate::ProgramValidation::validate(&d)); let elab = elaborate::ElaborationContext::elaborate(&d).unwrap(); println!("-----"); hir::bidir::test(elab); } Err(Error { kind: ErrorKind::EOF, .. }) => {} Err(e) => { println!("[err] {:?}", e); } } } } fn unfold(ty: Type) -> Type { match &ty { Type::Recursive(inner) => op_app!(*inner.clone(), ty), Type::App(a, b) => match a.as_ref() { Type::Recursive(_) => op_app!(unfold(*a.clone()), *b.clone()), _ => ty, }, _ => ty, } } ================================================ FILE: 07_system_fw/src/stack.rs ================================================ //! Wrapper around a Vec for use as a de Bruijn indexed stack, e.g. index 0 //! 
returns the last item pushed onto the stack
use std::fmt;

/// De Bruijn-indexed stack: `get(0)` is the most recently pushed element.
/// NOTE(review): the element type parameter was stripped by extraction;
/// tokens kept as-is.
pub struct Stack {
    inner: Vec,
}

impl Stack {
    /// Push `val` onto the top of the stack (it becomes index 0).
    #[inline]
    pub fn push(&mut self, val: T) {
        self.inner.push(val);
    }

    /// Pop and return the top element, if any.
    #[inline]
    pub fn pop(&mut self) -> Option {
        self.inner.pop()
    }

    /// Pop `n` elements, silently stopping early if the stack runs dry.
    #[inline]
    pub fn popn(&mut self, n: usize) {
        for _ in 0..n {
            let _ = self.pop();
        }
    }

    /// De Bruijn lookup: index 0 is the last push. `checked_sub` turns an
    /// out-of-range index into `None` instead of an underflow panic.
    #[inline]
    pub fn get(&self, index: usize) -> Option<&T> {
        self.inner.get(self.inner.len().checked_sub(1 + index)?)
    }

    #[inline]
    pub fn with_capacity(size: usize) -> Self {
        Stack {
            inner: Vec::with_capacity(size),
        }
    }

    #[inline]
    pub fn new() -> Self {
        Stack { inner: Vec::new() }
    }

    #[inline]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Iterate in push (bottom-to-top) order — NOT de Bruijn order.
    pub fn iter(&self) -> std::slice::Iter {
        self.inner.iter()
    }

    pub fn iter_mut(&mut self) -> std::slice::IterMut {
        self.inner.iter_mut()
    }
}

impl Stack {
    /// Return the de Bruijn index of the first element equal to `key`,
    /// searching from the top of the stack downwards.
    pub fn lookup(&self, key: &T) -> Option {
        for (idx, s) in self.inner.iter().rev().enumerate() {
            if key == s {
                return Some(idx);
            }
        }
        None
    }
}

impl Extend for Stack {
    fn extend>(&mut self, iter: I) {
        for elem in iter {
            self.push(elem);
        }
    }
}

impl Clone for Stack {
    fn clone(&self) -> Self {
        Stack {
            inner: self.inner.clone(),
        }
    }
}

impl Default for Stack {
    fn default() -> Self {
        Stack { inner: Vec::default() }
    }
}

// Debug delegates to the underlying Vec (printed bottom-to-top).
impl fmt::Debug for Stack {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{:?}", self.inner)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    // Verifies de Bruijn ordering: index 0 tracks the most recent push.
    #[test]
    fn order() {
        let mut stack = Stack::default();
        for i in 0..16 {
            stack.push(i);
        }
        assert_eq!(stack.get(0), Some(&15));
        assert_eq!(stack.get(1), Some(&14));
        assert_eq!(stack.get(2), Some(&13));
        stack.pop();
        assert_eq!(stack.get(0), Some(&14));
    }
}

================================================ FILE: 07_system_fw/src/syntax/ast.rs ================================================

use util::span::Span;

/// Monotonically increasing id the parser assigns to each AST node.
#[derive(Copy, Clone, Debug, Default, PartialEq, PartialOrd, Eq, Hash)]
pub struct AstId(pub(crate) u32);

/// Sentinel id for nodes created without going through the parser.
pub const AST_DUMMY: AstId = AstId(std::u32::MAX);

macro_rules!
container { ($id:ident, $id2:ident) => { #[derive(Clone, PartialEq, PartialOrd)] pub struct $id { pub kind: $id2, pub span: Span, pub id: AstId, } impl $id { pub fn new(kind: $id2, span: Span) -> $id { $id { kind, span, id: AST_DUMMY, } } pub fn with_id(kind: $id2, span: Span, id: AstId) -> $id { $id { kind, span, id } } } impl std::fmt::Debug for $id { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{:?}", self.kind) } } }; } container!(Expr, ExprKind); container!(Type, TypeKind); container!(Decl, DeclKind); container!(Pattern, PatKind); pub struct Program { pub decls: Vec, } /// Arm of a case expression #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Arm { pub pat: Pattern, pub expr: Expr, pub span: Span, } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Field { pub label: String, pub expr: Expr, pub span: Span, } /// Patterns for case and let expressions #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum PatKind { /// Wildcard pattern, this always matches Any, Unit, Ascribe(Box, Box), /// Constant pattern Literal(usize), /// Datatype constructor Constructor(String), /// Variable binding Variable(String), /// Tuple of pattern bindings (_, x) Product(Vec), /// Record pattern { label1, label2 } Record(Vec), /// Algebraic datatype constructor, along with binding pattern Application(Box, Box), } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum ExprKind { Unit, Int(usize), Var(String), Constr(String), If(Box, Box, Box), Abs(Box, Box), App(Box, Box), /// Explicit type abstraction `fn 'x value (arg: 'x) = arg` TyAbs(String, Box, Box), /// Explicit type application `e @ty` TyApp(Box, Box), Record(Vec), Tuple(Vec), Projection(Box, Box), Case(Box, Vec), Let(Vec, Box), } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct FnArm { pub span: Span, pub pats: Vec, pub expr: Expr, } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum DeclKind { Type(Vec, String, Type), Datatype(Vec, String, Type), Value(Vec, 
Pattern, Expr), Function(Vec, String, Vec), And(Box, Box), Expr(Expr), } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum Kind { Star, Arrow(Box, Box), } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum TypeKind { Int, Bool, Unit, Infer, /// Defined name Defined(String), /// Type variable 'a Variable(String), /// Type of functions from terms to terms Function(Box, Box), /// Sum type; None | Some of 'a Sum(Vec), /// Tuple type (ty * ty * ... tyN), invariant that N >= 1 Product(Vec), /// Record type { [label: ty],+ }, invariant that N >=1 Record(Vec), /// Existential type: exists (a :: K) of ty Existential(String, Box, Box), /// Universal type: forall (a :: K) of ty Universal(String, Box, Box), /// Type level function abstraction Abstraction(String, Box, Box), /// Type level function application Application(Box, Box), /// Recursive type Recursive(Box), } #[derive(Clone, PartialEq, PartialOrd)] pub struct Variant { pub label: String, pub ty: Option, pub span: Span, } #[derive(Clone, PartialEq, PartialOrd)] pub struct Row { pub label: String, pub ty: Type, pub span: Span, } impl TypeKind { pub fn variants(&self) -> &[Variant] { match self { TypeKind::Sum(v) => &v, _ => panic!("Not a sum type!"), } } pub fn as_tyvar(&self) -> &str { match self { TypeKind::Variable(s) => s, _ => panic!("Not a type var!"), } } pub fn as_tyvar_d(self) -> String { match self { TypeKind::Variable(s) => s, _ => panic!("Not a type var!"), } } } impl std::fmt::Debug for Variant { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{} of {:?}", self.label, self.ty) } } impl std::fmt::Debug for Row { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}: {:?}", self.label, self.ty) } } ================================================ FILE: 07_system_fw/src/syntax/lexer.rs ================================================ use super::tokens::*; use std::char; use std::iter::Peekable; use std::str::Chars; use util::span::{Location, Span, 
Spanned};

/// Hand-rolled lexer over a peekable character stream; tracks
/// line / column / absolute offset so every token carries a `Span`.
#[derive(Clone, Debug)]
pub struct Lexer<'s> {
    input: Peekable>,
    current: Location,
}

impl<'s> Lexer<'s> {
    pub fn new(input: Chars<'s>) -> Lexer<'s> {
        Lexer {
            input: input.peekable(),
            current: Location {
                line: 0,
                col: 0,
                abs: 0,
            },
        }
    }

    /// Peek at the next [`char`] in the input stream
    fn peek(&mut self) -> Option {
        self.input.peek().copied()
    }

    /// Consume the next [`char`] and advance internal source position
    fn consume(&mut self) -> Option {
        match self.input.next() {
            Some('\n') => {
                // Newline: bump the line counter and reset the column.
                self.current.line += 1;
                self.current.col = 0;
                self.current.abs += 1;
                Some('\n')
            }
            Some(ch) => {
                self.current.col += 1;
                self.current.abs += 1;
                Some(ch)
            }
            None => None,
        }
    }

    /// Consume characters from the input stream while pred(peek()) is true,
    /// collecting the characters into a string.
    fn consume_while bool>(&mut self, pred: F) -> (String, Span) {
        let mut s = String::new();
        let start = self.current;
        while let Some(n) = self.peek() {
            if pred(n) {
                match self.consume() {
                    Some(ch) => s.push(ch),
                    None => break,
                }
            } else {
                break;
            }
        }
        (s, Span::new(start, self.current))
    }

    /// Eat whitespace
    fn consume_delimiter(&mut self) {
        let _ = self.consume_while(char::is_whitespace);
    }

    /// Lex a possibly dotted identifier (`a.b.c`): alphanumeric segments
    /// joined by `.`, accumulating one combined span via `+=`.
    fn id(&mut self) -> (String, Span) {
        let (mut s, mut span) = self.consume_while(|ch| ch.is_ascii_alphanumeric());
        while let Some(ch) = self.peek() {
            if ch != '.' {
                break;
            }
            self.consume();
            s.push('.');
            let (s2, sp2) = self.consume_while(|ch| ch.is_ascii_alphanumeric());
            s += &s2;
            span += sp2;
        }
        (s, span)
    }

    /// Characters permitted inside an identifier/keyword. Symbolic chars are
    /// allowed; `.` is excluded because it separates path segments in `id`.
    fn valid_id_char(c: char) -> bool {
        match c {
            '.' => false,
            x if x.is_alphanumeric() => true,
            '`' | '~' | '!' | '$' | '%' | '^' | '&' | '-' | '_' | '+' | '?'
| '<' | '>' => true, _ => false, } } /// Lex a reserved keyword or identifier fn keyword(&mut self) -> Spanned { let (word, sp) = self.consume_while(Self::valid_id_char); let kind = match word.as_ref() { "fun" => Token::Function, "fn" => Token::Lambda, "val" => Token::Val, "let" => Token::Let, "in" => Token::In, "case" => Token::Case, "of" => Token::Of, "end" => Token::End, "as" => Token::As, "if" => Token::If, "then" => Token::Then, "else" => Token::Else, "fix" => Token::Fix, "rec" => Token::Rec, "exists" => Token::Exists, "forall" => Token::Forall, "type" => Token::Type, "datatype" => Token::Datatype, "and" => Token::And, "int" => Token::TyInt, "unit" => Token::TyUnit, "bool" => Token::TyBool, s if s.starts_with(char::is_uppercase) => Token::UpperId(word), _ => Token::LowerId(word), }; Spanned::new(sp, kind) } /// Consume the next input character, expecting to match `ch`. /// Return a [`TokenKind::Invalid`] if the next character does not match, /// or the argument `kind` if it does fn eat(&mut self, ch: char, kind: Token) -> Spanned { let loc = self.current; // Lexer::eat() should only be called internally after calling peek() // so we know that it's safe to unwrap the result of Lexer::consume() let n = self.consume().unwrap(); let kind = if n == ch { kind } else { Token::Invalid(n) }; Spanned::new(Span::new(loc, self.current), kind) } /// Lex a natural number fn number(&mut self) -> Spanned { // Since we peeked at least one numeric char, we should always // have a string containing at least 1 single digit, as such // it is safe to call unwrap() on str::parse let (data, span) = self.consume_while(char::is_numeric); let n = data.parse::().unwrap(); Spanned::new(span, Token::Int(n)) } pub fn lex(&mut self) -> Spanned { self.consume_delimiter(); let next = match self.peek() { Some(ch) => ch, None => return Spanned::new(Span::new(self.current, self.current), Token::EOF), }; macro_rules! 
disamb { ($ch:expr, $($ch2:expr, $p:expr),+) => {{ self.consume(); match self.peek() { Some(ch) => match ch { $($ch2 => self.eat($ch2, $p)),+, _ => Spanned::new(Span::new(self.current, self.current), Token::Invalid($ch)) }, None => Spanned::new(Span::new(self.current, self.current), Token::Invalid($ch)), } }}; ($ch:expr, $p1:expr, $($ch2:expr, $p:expr),+) => {{ let fail = self.eat($ch, $p1); match self.peek() { Some(ch) => match ch { $($ch2 => self.eat($ch2, $p)),+, _ => fail, }, None => fail, } }}; } match next { '.' => self.eat('.', Token::Dot), ':' => disamb!(':', Token::Colon, '>', Token::Opaque), ';' => self.eat(';', Token::Semicolon), ',' => self.eat(',', Token::Comma), '\'' => self.eat('\'', Token::Apostrophe), '|' => self.eat('|', Token::Bar), '-' => disamb!('-', '>', Token::SingleArrow), '=' => disamb!('=', Token::Equals, '>', Token::DoubleArrow), '_' => self.eat('_', Token::Wildcard), '*' => self.eat('*', Token::Asterisk), '(' => disamb!('(', Token::LParen, ')', Token::Unit), ')' => self.eat(')', Token::RParen), '{' => self.eat('{', Token::LBrace), '}' => self.eat('}', Token::RBrace), '@' => self.eat('@', Token::TypeAppSigil), '\\' => self.eat('\\', Token::Lambda), 'λ' => self.eat('λ', Token::Lambda), '∀' => self.eat('∀', Token::Forall), '∃' => self.eat('∃', Token::Exists), x if x.is_ascii_alphabetic() => self.keyword(), x if x.is_numeric() => self.number(), _ => self.eat(' ', Token::EOF), // _ => Spanned::new(Token::Invalid(next), Span::new(self.current, self.current)), } } } impl<'s> Iterator for Lexer<'s> { type Item = Spanned; fn next(&mut self) -> Option { match self.lex() { Spanned { data: Token::EOF, .. 
} => None, tok => Some(tok), } } } ================================================ FILE: 07_system_fw/src/syntax/mod.rs ================================================ pub mod ast; pub mod lexer; pub mod parser; pub mod tokens; pub mod visit; ================================================ FILE: 07_system_fw/src/syntax/parser/README.md ================================================ # Parser We use a handwritten recursive descent parser. In general, there is a top-level entry function for parsing of `types`, `expressions`, `declarations`, and `patterns`. Each of these functions attempts to match against the current token, without popping it. If a suitable match is found, then we dispatch to a function specifically for parsing that token, which may then pop off the current token. Once the current token has been popped, we actually begin to return real errors. Errors that occur before the current token has been popped may be generally ignored. ================================================ FILE: 07_system_fw/src/syntax/parser/decls.rs ================================================ use super::*; impl<'s> Parser<'s> { fn decl_datatype(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::Datatype)?; let tyvars = self.parse_tyvar_sequence()?; let tyname = self.expect_lower_id()?; self.expect(Token::Equals)?; self.bump_if(&Token::Bar); let ty = self.type_sum()?; span += self.prev; Ok(Decl::with_id( DeclKind::Datatype(tyvars, tyname, ty), span, self.allocate_ast_id(), )) } fn decl_type(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::Type)?; let tyvars = self.parse_tyvar_sequence()?; let tyname = self.expect_lower_id()?; self.expect(Token::Equals)?; let ty = self.parse_type()?; span += self.prev; Ok(Decl::with_id( DeclKind::Type(tyvars, tyname, ty), span, self.allocate_ast_id(), )) } fn decl_value(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::Val)?; let tyvars = 
self.parse_tyvar_sequence()?; let pat = self.parse_pattern()?; self.expect(Token::Equals)?; let expr = self.parse_expr()?; span += self.prev; Ok(Decl::with_id( DeclKind::Value(tyvars, pat, expr), span, self.allocate_ast_id(), )) } fn decl_fun_arm(&mut self, ident: &str) -> Result { let mut span = self.current.span; let id = self.expect_lower_id()?; if id != ident { return self.error(ErrorKind::FunctionIdMismatch); } let pats = self.plus(|p| p.atomic_pattern(), None)?; self.expect(Token::Equals)?; let expr = self.parse_expr()?; span += self.prev; Ok(FnArm { pats, expr, span }) } fn decl_fun(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::Function)?; let tyvars = self.parse_tyvar_sequence()?; // Peek the id and clone it, since decl_fun_arm will expect it to be there let ident = match self.current() { Token::LowerId(id) => id.clone(), _ => return self.error(ErrorKind::ExpectedIdentifier), }; let arms = self.delimited(|p| p.decl_fun_arm(&ident), Token::Bar)?; span += self.prev; Ok(Decl::with_id( DeclKind::Function(tyvars, ident, arms), span, self.allocate_ast_id(), )) } fn decl_expr(&mut self) -> Result { let expr = self.parse_expr()?; let sp = expr.span; Ok(Decl::with_id(DeclKind::Expr(expr), sp, self.allocate_ast_id())) } /// Parse a simple declaration /// decl ::= type /// datatype /// val /// fun /// exp pub fn parse_decl_atom(&mut self) -> Result { match self.current() { Token::Type => self.decl_type(), Token::Datatype => self.decl_datatype(), Token::Val => self.decl_value(), Token::Function => self.decl_fun(), _ => self.decl_expr(), } } pub(crate) fn parse_decl(&mut self) -> Result { let mut span = self.current.span; let mut d = self.parse_decl_atom()?; while let Token::And = self.current.data { self.bump(); let d2 = self.once(|p| p.parse_decl_atom(), "expected declaration after `and`")?; span += self.prev; d = Decl::with_id(DeclKind::And(Box::new(d), Box::new(d2)), span, self.allocate_ast_id()); } span += self.prev; Ok(d) } pub fn 
parse_program(&mut self) -> Result { let mut decls = vec![self.parse_decl()?]; self.bump_if(&Token::Semicolon); while let Ok(d) = self.parse_decl() { decls.push(d); self.bump_if(&Token::Semicolon); } Ok(Program { decls }) } } ================================================ FILE: 07_system_fw/src/syntax/parser/exprs.rs ================================================ use super::*; impl<'s> Parser<'s> { fn record_row(&mut self) -> Result { let mut span = self.current.span; let label = self.expect_lower_id()?; self.expect(Token::Equals)?; let expr = self.once(|p| p.parse_expr(), "missing expr in record row")?; span += self.prev; Ok(Field { label, expr, span }) } fn record_expr(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::LBrace)?; let fields = self.delimited(|p| p.record_row(), Token::Comma)?; self.expect(Token::RBrace)?; span += self.prev; Ok(Expr::new(ExprKind::Record(fields), span)) } fn let_binding(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::Let)?; // let pat = self.once(|p| p.parse_pattern(), "missing pattern in let // binding")?; self.expect(Token::Equals)?; // let t1 = self.once(|p| p.parse_expr(), "let binder required")?; let decls = self.parse_program()?.decls; self.expect(Token::In)?; let t2 = self.once(|p| p.parse_expr(), "let body required")?; self.expect(Token::End)?; span += self.prev; Ok(Expr::new(ExprKind::Let(decls, Box::new(t2)), span)) } fn case_arm(&mut self) -> Result { let mut span = self.current.span; let pat = self.once(|p| p.parse_pattern(), "missing pattern in case arm")?; self.expect(Token::DoubleArrow)?; let expr = self.once(|p| p.parse_expr(), "missing expression in case arm")?; self.bump_if(&Token::Comma); span += self.prev; Ok(Arm { pat, expr, span }) } fn case_expr(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::Case)?; let expr = self.once(|p| p.parse_expr(), "missing case expression")?; self.expect(Token::Of)?; self.bump_if(&Token::Bar); let 
arms = self.delimited(|p| p.case_arm(), Token::Bar)?; self.expect(Token::End)?; span += self.prev; Ok(Expr::new(ExprKind::Case(Box::new(expr), arms), span)) } fn lambda_expr(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::Lambda)?; let arg = self.once(|p| p.parse_pattern(), "expected pattern binding in lambda expression!")?; self.expect(Token::DoubleArrow)?; let body = self.parse_expr()?; span += self.prev; Ok(Expr::new(ExprKind::Abs(Box::new(arg), Box::new(body)), span)) } fn if_expr(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::If)?; let guard = self.parse_expr()?; self.expect(Token::Then)?; let cond = self.parse_expr()?; self.expect(Token::Else)?; let alt = self.parse_expr()?; span += self.prev; Ok(Expr::new( ExprKind::If(Box::new(guard), Box::new(cond), Box::new(alt)), span, )) } /// atexp ::= constant /// id /// { [label = exp] } /// () /// ( exp, ... expN ) /// ( exp ) /// let decl in exp, ... expN end fn atomic_expr(&mut self) -> Result { let mut span = self.current.span; match self.current.data { Token::LowerId(_) => self.expect_lower_id().map(|e| Expr::new(ExprKind::Var(e), span)), Token::UpperId(_) => self.expect_upper_id().map(|e| Expr::new(ExprKind::Constr(e), span)), Token::LBrace => self.record_expr(), Token::Let => self.let_binding(), Token::Int(n) => { self.bump(); Ok(Expr::new(ExprKind::Int(n), span)) } Token::Unit => { self.bump(); Ok(Expr::new(ExprKind::Unit, span)) } Token::LParen => { self.expect(Token::LParen)?; let mut exprs = self.delimited(|p| p.parse_expr(), Token::Comma)?; let e = match exprs.len() { 1 => exprs.pop().unwrap(), _ => Expr::new(ExprKind::Tuple(exprs), span), }; self.expect(Token::RParen)?; span += self.prev; Ok(e) } _ => self.error(ErrorKind::ExpectedExpr), } } fn projection_expr(&mut self) -> Result { let mut span = self.current.span; let mut expr = self.atomic_expr()?; while self.bump_if(&Token::Dot) { span += self.prev; let p = self.once(|p| p.atomic_expr(), 
"expected expr after Dot")?;
            expr = Expr::new(ExprKind::Projection(Box::new(expr), Box::new(p)), span);
        }
        Ok(expr)
    }

    /// appexp ::= atexp
    ///            appexp atexp
    fn application_expr(&mut self) -> Result {
        let mut span = self.current.span;
        let mut expr = self.projection_expr()?;
        loop {
            // Stop before a declared infix operator so (future) infix
            // parsing can take over, rather than treating the operator as a
            // juxtaposed application argument.
            if let Token::LowerId(s) = &self.current() {
                if self.infix.get(&s).is_some() {
                    break;
                }
            }
            if let Ok(e) = self.projection_expr() {
                // Juxtaposition = application; fold left-associatively.
                span += self.prev;
                expr = Expr::new(ExprKind::App(Box::new(expr), Box::new(e)), span);
            } else if let Token::TypeAppSigil = self.current() {
                // Explicit type application: `e @ty`.
                self.bump();
                let ty = self.type_atom()?;
                span += self.prev;
                expr = Expr::new(ExprKind::TyApp(Box::new(expr), Box::new(ty)), span);
            } else {
                break;
            }
        }
        Ok(expr)
    }

    /// exp ::= appexp
    ///         exp path exp
    // fn infix_expr(&mut self) -> Result {
    //     let mut span = self.current.span;
    //     let mut expr = self.application_expr()?;
    //     while let Token::LowerId(s) = &self.current() {
    //         if self.infix.get(s).is_some() {
    //             let p = self.expect_lower_id()?;
    //             let e = self.application_expr()?;
    //             span += self.prev;
    //             expr = Expr::new(ExprKind::Infix(p, Box::new(expr), Box::new(e)), span)
    //         } else {
    //             break;
    //         }
    //     }
    //     Ok(expr)
    // }

    /// exp ::= if exp then exp2 else exp3
    ///         case exp of casearm end
    ///         fn x
    ///         infix
    pub fn parse_expr(&mut self) -> Result {
        match self.current() {
            Token::Case => self.case_expr(),
            Token::If => self.if_expr(),
            Token::Lambda => self.lambda_expr(),
            _ => self.application_expr(),
        }
    }
}

================================================ FILE: 07_system_fw/src/syntax/parser/infix.rs ================================================

use std::collections::HashMap;

/// Table of user-declared infix operators and their precedences.
/// NOTE(review): the map's type parameters were stripped by extraction.
#[derive(Clone, Default, Debug)]
pub struct Infix {
    precedence: HashMap,
}

impl Infix {
    /// Register (or overwrite) operator `s` with precedence `prec`.
    pub fn insert(&mut self, s: String, prec: usize) {
        self.precedence.insert(s, prec);
    }

    /// Look up the precedence of `s`, if it was declared infix.
    pub fn get(&self, s: &str) -> Option {
        self.precedence.get(s).copied()
    }
}

================================================ FILE: 07_system_fw/src/syntax/parser/mod.rs
================================================ pub mod decls; pub mod exprs; pub mod infix; pub mod patterns; pub mod types; use super::ast::*; use super::lexer::Lexer; use super::tokens::*; use infix::Infix; use util::span::{Span, Spanned}; pub struct Parser<'s> { tokens: Lexer<'s>, current: Spanned, prev: Span, infix: Infix, next_ast_id: AstId, } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum ErrorKind { ExpectedToken(Token), ExpectedIdentifier, ExpectedType, ExpectedExpr, ExpectedPattern, ExpectedDeclaration, ExpectedSpecification, ExpectedSignature, ExpectedStructure, UnboundTypeVar, UnboundExprVar, FunctionIdMismatch, EOF, } #[derive(Default)] pub struct InfixState(Infix); #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Error { pub span: Span, pub token: Token, pub kind: ErrorKind, } impl<'s> Parser<'s> { pub fn new(input: &'s str) -> Parser<'s> { Parser::with_infix_state(input, InfixState::default()) } pub fn with_infix_state(input: &'s str, state: InfixState) -> Parser<'s> { let mut p = Parser { tokens: Lexer::new(input.chars()), current: Spanned::new(Span::zero(), Token::Placeholder), infix: state.0, prev: Span::zero(), next_ast_id: AstId(0), }; p.bump(); p } pub fn top_level(&mut self) -> Result, Error> { let mut v = Vec::new(); while self.current() != &Token::EOF { v.push(self.parse_decl()?); self.bump_if(&Token::Semicolon); } Ok(v) } pub fn state(&self) -> InfixState { InfixState(self.infix.clone()) } fn allocate_ast_id(&mut self) -> AstId { let id = self.next_ast_id; self.next_ast_id = AstId(self.next_ast_id.0 + 1); id } /// Generate a parsing error. 
These are not necessarily fatal fn error(&self, k: ErrorKind) -> Result { Err(Error { span: self.current.span, token: self.current().clone(), kind: k, }) } fn current(&self) -> &Token { &self.current.data } /// Bump the current token, returning it, and pull a new token /// from the lexer fn bump(&mut self) -> Token { match self.tokens.next() { Some(t) => { #[cfg(test)] { let t = std::mem::replace(&mut self.current, t).data(); self.current.span = Span::default(); self.prev = Span::default(); t } #[cfg(not(test))] { self.prev = self.current.span; std::mem::replace(&mut self.current, t).data() } } None => std::mem::replace(&mut self.current.data, Token::EOF), } } /// Ignore a token matching `kind` fn bump_if(&mut self, kind: &Token) -> bool { if &self.current.data == kind { self.bump(); true } else { false } } fn expect(&mut self, kind: Token) -> Result<(), Error> { if self.current() == &kind { self.bump(); Ok(()) } else { self.error(ErrorKind::ExpectedToken(kind)) } } fn expect_lower_id(&mut self) -> Result { match self.current() { Token::LowerId(_) => Ok(self.bump().extract_string()), _ => self.error(ErrorKind::ExpectedIdentifier), } } fn expect_upper_id(&mut self) -> Result { match self.current() { Token::UpperId(_) => Ok(self.bump().extract_string()), _ => self.error(ErrorKind::ExpectedIdentifier), } } /// Call `func` once, returning the `Result` of the function. /// A failure of `func` may have side effects, including emitting /// diagnostics containing `message` /// /// Generally, this is just used to give better error messages fn once(&mut self, func: F, message: &str) -> Result where F: Fn(&mut Parser) -> Result, { match func(self) { Ok(t) => Ok(t), Err(e) => { eprintln!("[Parser] {}", message); Err(e) } } } /// Collect the result of `func` into a `Vec` as long as `func` returns /// an `Ok(T)`. A call to `func` must succeed on the first try, or an error /// is immediately returned. 
Subsequent calls to `func` may fail, in which /// case the error is discarded, and the results are returned. If `delimit` /// is supplied, the parser will discard matching tokens between each call /// to `func` fn plus(&mut self, func: F, delimit: Option<&Token>) -> Result, E> where F: Fn(&mut Parser) -> Result, { let mut v = vec![func(self)?]; if let Some(t) = delimit { if !self.bump_if(t) { return Ok(v); } } while let Ok(x) = func(self) { v.push(x); if let Some(t) = delimit { if !self.bump_if(t) { break; } } } Ok(v) } /// Collect the result of `func` into a `Vec` as long as `func` returns /// an `Ok(T)`. If an error is encountered, it is discarded and the results /// are immediately returned. If `delimit` is supplied, the parser will /// discard matching tokens between each call to `func` fn star(&mut self, func: F, delimit: Option<&Token>) -> Vec where F: Fn(&mut Parser) -> Result, { let mut v = Vec::new(); while let Ok(x) = func(self) { v.push(x); if let Some(t) = delimit { if !self.bump_if(t) { break; } } } v } /// Identical semantics to `Parser::plus`, except `delimit` must be supplied fn delimited(&mut self, func: F, delimit: Token) -> Result, E> where F: Fn(&mut Parser) -> Result, { let mut v = vec![func(self)?]; while self.bump_if(&delimit) { v.push(func(self)?); } Ok(v) } } ================================================ FILE: 07_system_fw/src/syntax/parser/patterns.rs ================================================ use super::*; use PatKind::*; impl<'s> Parser<'s> { fn tuple_pattern(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::LParen)?; let mut v = self.star(|p| p.parse_pattern(), Some(&Token::Comma)); self.expect(Token::RParen)?; span += self.prev; match v.len() { 0 => Ok(Pattern::new(Unit, span)), 1 => Ok(v.pop().unwrap()), _ => Ok(Pattern::new(Product(v), span)), } } fn record_pattern(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::LBrace)?; let v = self.delimited(|p| p.expect_lower_id(), 
Token::Comma)?; self.expect(Token::RBrace)?; span += self.prev; Ok(Pattern::new(Record(v), span)) } /// atpat ::= constant /// id /// wildcard /// ( pat ) /// ( pat, ... patN ) /// { [patrow] } pub(crate) fn atomic_pattern(&mut self) -> Result { let span = self.current.span; match self.current.data { Token::Wildcard => { self.bump(); Ok(Pattern::new(Any, span)) } Token::LowerId(_) => self.expect_lower_id().map(|s| Pattern::new(Variable(s), span)), Token::UpperId(_) => self.expect_upper_id().map(|s| Pattern::new(Constructor(s), span)), Token::Int(n) => { self.bump(); Ok(Pattern::new(Literal(n), span)) } Token::Unit => { self.bump(); Ok(Pattern::new(PatKind::Unit, span)) } Token::LParen => self.tuple_pattern(), Token::LBrace => self.record_pattern(), _ => self.error(ErrorKind::ExpectedPattern), } } /// app_pat ::= atpat /// app_pat atpat fn application_pattern(&mut self) -> Result { let mut span = self.current.span; let pat = self.atomic_pattern()?; if let PatKind::Constructor(_) = pat.kind { match self.atomic_pattern() { Ok(arg) => { span += self.prev; return Ok(Pattern::new(Application(Box::new(pat), Box::new(arg)), span)); } _ => return Ok(pat), } } // while let Ok(e) = self.atomic_pattern() { // span += self.prev; // pat = Pattern::new(Application(Box::new(pat), Box::new(e)), span); // } Ok(pat) } pub fn parse_pattern(&mut self) -> Result { let mut span = self.current.span; let pat = self.application_pattern()?; if self.bump_if(&Token::Colon) { let ty = self.once(|p| p.parse_type(), "expected type annotation after `pat :`")?; span += self.prev; return Ok(Pattern::new(Ascribe(Box::new(pat), Box::new(ty)), span)); } Ok(pat) } } ================================================ FILE: 07_system_fw/src/syntax/parser/types.rs ================================================ use super::*; use TypeKind::*; impl<'s> Parser<'s> { /// Parse a datatype Constructor [A-Z]+ pub fn variant(&mut self) -> Result { let mut span = self.current.span; let label = 
self.expect_upper_id()?; let ty = if self.bump_if(&Token::Of) { Some(self.parse_type()?) } else { None }; span += self.prev; Ok(Variant { label, ty, span }) } pub fn type_sum(&mut self) -> Result { let mut span = self.current.span; let vars = self.delimited(|p| p.variant(), Token::Bar)?; span += self.prev; Ok(Type::new(Sum(vars), span)) } /// Parse a single type variable '[a-z]+ fn parse_tyvar(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::Apostrophe)?; let v = self.expect_lower_id().map(Variable); span += self.prev; v.map(|t| Type::new(t, span)) } /// Parse a sequence of type variables ('a, 'b), which are N-arity arguments /// to a type constructor. /// This may return an empty Vec<> pub(crate) fn parse_tyvar_sequence(&mut self) -> Result, Error> { if self.bump_if(&Token::LParen) { let ret = self.delimited(|p| p.parse_tyvar(), Token::Comma)?; self.expect(Token::RParen)?; return Ok(ret); } Ok(self.star(|p| p.parse_tyvar(), Some(&Token::Comma))) } /// Parse a type sequence (ty1, ty2) , which are N-arity arguments to a /// type constructor. /// If an `Ok(Vec)` is returned, then Vec<> will always have N>=1 items fn parse_type_sequence(&mut self) -> Result, Error> { self.expect(Token::LParen)?; let ret = self.delimited(|p| p.parse_type(), Token::Comma)?; self.expect(Token::RParen)?; Ok(ret) } /// Parse a existential type of form `forall ('tv :: K) of ty` fn existential(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::Exists)?; let (name, kind) = self.once( |p| p.abstraction_arg(), "existential type requires an arg of form ('t :: K)", )?; self.expect(Token::Of)?; let body = self.once(|p| p.parse_type(), "existential type requires a body")?; span += self.prev; // We should probably just parse tyvars as string directly... 
Ok(Type::new( Existential(name.kind.as_tyvar_d(), Box::new(kind), Box::new(body)), span, )) } /// Parse a universal type of form `forall ('tv :: K) of ty` fn universal(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::Forall)?; let (name, kind) = self.once( |p| p.abstraction_arg(), "universal type requires an arg of form ('t :: K)", )?; self.expect(Token::Of)?; let body = self.once(|p| p.parse_type(), "universal type requires a body")?; span += self.prev; Ok(Type::new( Universal(name.kind.as_tyvar_d(), Box::new(kind), Box::new(body)), span, )) } /// Parse a type row of form `label: ty` fn row(&mut self) -> Result { let mut span = self.current.span; let label = self.expect_lower_id()?; self.expect(Token::Colon)?; let ty = self.once(|p| p.parse_type(), "record type row requires a type {label: ty, ...}")?; span += self.prev; Ok(Row { label, ty, span }) } /// Parse a type of form `{ label: ty, label2: ty2, ...}` fn record(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::LBrace)?; let rows = self.delimited(|p| p.row(), Token::Comma)?; self.expect(Token::RBrace)?; span += self.prev; Ok(Type::new(Record(rows), span)) } /// Parse a type of form: /// ty ::= 'var /// id /// ( ty ) /// ( ty1, ... 
tyN) ty /// fn (var :: kind) => ty /// exists (var :: kind) of ty /// forall (var :: kind) of ty /// rec ty /// { label: ty, ...} pub(crate) fn type_atom(&mut self) -> Result { let mut span = self.current.span; match self.current.data { Token::TyInt => { self.bump(); Ok(Type::new(Int, span)) } Token::TyBool => { self.bump(); Ok(Type::new(Bool, span)) } Token::TyUnit => { self.bump(); Ok(Type::new(Unit, span)) } Token::Apostrophe => self.parse_tyvar(), Token::LowerId(_) => self.expect_lower_id().map(|p| Type::new(Defined(p), span)), Token::Lambda => self.abstraction(), Token::Exists => self.existential(), Token::Forall => self.universal(), Token::Rec => { self.expect(Token::Rec)?; let ty = self.parse_type()?; span += self.prev; Ok(Type::new(Recursive(Box::new(ty)), span)) } Token::LBrace => self.record(), Token::LParen => { // Handle a set of N-arity arguments to constructors let mut v = self.parse_type_sequence()?; if v.len() == 1 { Ok(v.pop().unwrap()) } else { let tycon = self.type_atom()?; Ok(v.into_iter().fold(tycon, |ty, v| { let sp = ty.span + v.span; Type::new(Application(Box::new(ty), Box::new(v)), sp) })) } } Token::Wildcard => { self.bump(); Ok(Type::new(Infer, span)) } _ => self.error(ErrorKind::ExpectedType), } } /// Parse an argument of form: `('t :: K)` fn abstraction_arg(&mut self) -> Result<(Type, Kind), Error> { self.expect(Token::LParen)?; let tyvar = self.parse_tyvar()?; self.expect(Token::Colon)?; self.expect(Token::Colon)?; let k = self.kind()?; self.expect(Token::RParen)?; Ok((tyvar, k)) } /// Parse a type of form: `lambda ('t :: K) => ty` fn abstraction(&mut self) -> Result { let mut span = self.current.span; self.expect(Token::Lambda)?; // let args = self.plus(|p| p.abstraction_arg())?; let (name, kind) = self.once( |p| p.abstraction_arg(), "type abstraction requires an arg of form ('t :: K)", )?; self.expect(Token::DoubleArrow)?; let body = self.parse_type()?; span += self.prev; Ok(Type::new( Abstraction(name.kind.as_tyvar_d(), 
Box::new(kind), Box::new(body)), span, )) } /// Parse an application of form: `('a, 'b, ...) ty1 ty2 ty3` fn application(&mut self) -> Result { // TODO: Confirm this is incorrect for all cases let mut tys = self.plus(|p| p.type_atom(), None)?; tys.reverse(); let ty = tys.pop().unwrap(); Ok(tys.into_iter().rev().fold(ty, |ty, v| { let sp = ty.span + v.span; Type::new(Application(Box::new(v), Box::new(ty)), sp) })) } /// Parse a type of form: `ty` | `ty * ty2 * ...` fn product(&mut self) -> Result { let mut span = self.current.span; let mut v = self.delimited(|p| p.application(), Token::Asterisk)?; span += self.prev; match v.len() { 1 => Ok(v.pop().unwrap()), _ => Ok(Type::new(Product(v), span)), } } /// Parse a type of form: `ty * ty` | `ty -> ty` pub fn parse_type(&mut self) -> Result { let mut span = self.current.span; let ty = self.product()?; if self.bump_if(&Token::SingleArrow) { let ty2 = self.parse_type()?; span += ty2.span; return Ok(Type::new(Function(Box::new(ty), Box::new(ty2)), span)); } Ok(ty) } /// Parse a kind of form: `* | ( K )` fn kind_single(&mut self) -> Result { if self.bump_if(&Token::LParen) { let k = self.kind()?; self.expect(Token::RParen)?; return Ok(k); } self.expect(Token::Asterisk)?; Ok(Kind::Star) } /// Parse a kind of form: `K | K -> K` pub fn kind(&mut self) -> Result { let k = self.kind_single()?; if self.bump_if(&Token::SingleArrow) { let k2 = self.kind()?; return Ok(Kind::Arrow(Box::new(k), Box::new(k2))); } Ok(k) } } ================================================ FILE: 07_system_fw/src/syntax/tokens.rs ================================================ #[allow(dead_code)] #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum Token { Dot, Colon, Opaque, Semicolon, Comma, Apostrophe, Bar, SingleArrow, DoubleArrow, Wildcard, Asterisk, Equals, LParen, RParen, LBrace, RBrace, And, Function, Lambda, Val, Let, In, Case, Of, End, As, If, Then, Else, Type, Datatype, Fix, Rec, Exists, Forall, TyInt, TyBool, TyUnit, TypeAppSigil, 
LowerId(String), UpperId(String), Comment(String), Int(usize), Unit, Placeholder, Invalid(char), EOF, } impl Token { pub fn extract_string(self) -> String { match self { Token::LowerId(s) | Token::UpperId(s) | Token::Comment(s) => s, _ => panic!("Invalid token {:?}", self), } } } ================================================ FILE: 07_system_fw/src/syntax/visit/mod.rs ================================================ use super::*; mod types; pub use types::TypeVisitor; ================================================ FILE: 07_system_fw/src/syntax/visit/types.rs ================================================ use super::*; use ast::{Kind, Row, Type, TypeKind, Variant}; pub trait TypeVisitor<'t>: Sized { fn visit_defined(&mut self, _: &'t str) {} fn visit_variable(&mut self, _: &'t str) {} fn visit_function(&mut self, ty1: &'t Type, ty2: &'t Type) { self.visit_ty(ty1); self.visit_ty(ty2); } fn visit_application(&mut self, ty1: &'t Type, ty2: &'t Type) { self.visit_ty(ty1); self.visit_ty(ty2); } fn visit_sum(&mut self, var: &'t [Variant]) { for v in var { if let Some(ty) = &v.ty { self.visit_ty(ty); } } } fn visit_product(&mut self, var: &'t [Type]) { for v in var { self.visit_ty(v); } } fn visit_record(&mut self, var: &'t [Row]) { for v in var { self.visit_ty(&v.ty); } } fn visit_existential(&mut self, _: &'t str, _: &'t Kind, ty: &'t Type) { self.visit_ty(ty); } fn visit_universal(&mut self, _: &'t str, _: &'t Kind, ty: &'t Type) { self.visit_ty(ty); } fn visit_abstraction(&mut self, _: &'t str, _: &'t Kind, ty: &'t Type) { self.visit_ty(ty); } fn visit_recursive(&mut self, ty: &'t Type) { self.visit_ty(ty); } fn visit_ty(&mut self, ty: &'t Type) { self.walk_ty(ty); } fn walk_ty(&mut self, ty: &'t Type) { use TypeKind::*; match &ty.kind { Int => {} Bool => {} Unit => {} Infer => {} Defined(s) => self.visit_defined(s), Variable(s) => self.visit_variable(s), Function(ty1, ty2) => self.visit_function(ty1, ty2), Sum(var) => self.visit_sum(var), Product(tys) => 
self.visit_product(tys), Record(rows) => self.visit_record(rows), Existential(s, k, ty) => self.visit_existential(s, k, ty), Universal(s, k, ty) => self.visit_universal(s, k, ty), Abstraction(s, k, ty) => self.visit_abstraction(s, k, ty), Application(ty1, ty2) => self.visit_application(ty1, ty2), Recursive(ty) => self.visit_recursive(ty), } } } ================================================ FILE: 07_system_fw/src/terms.rs ================================================ use crate::types::{TyKind, Type}; use util::span::Span; /// Constant expression or pattern #[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Eq, Hash)] pub enum Constant { Unit, Bool(bool), Nat(u32), } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum Kind { /// Constant Const(Constant), /// Variable Var(usize), /// Term abstraction /// \x: Ty. x /// Introduce a lambda term Abs(Box, Box), /// Term application /// m n /// Eliminate a lambda term App(Box, Box), /// Type abstraction /// \X. \x: X. x /// Introduce a universally quantified type TyAbs(Box, Box), /// Type application /// id [Nat] 1 /// Eliminate a universally quantified type TyApp(Box, Box), /// Record term /// {label1 = Tm1, label2 = Tm2, etc} /// Invariant that all fields have unique labels Record(Record), Index(Box, String), /// Injection into a sum type /// fields: type constructor tag, term, and sum type Injection(String, Box, Box), Fold(Box, Box), Unfold(Box, Box), /// Introduce an existential type /// { *Ty1, Term } as {∃X.Ty} /// essentially, concrete representation as interface Pack(Box, Box, Box), /// Unpack an existential type /// open {∃X, bind} in body -- X is bound as a TyVar, and bind as Var(0) /// Eliminate an existential type Unpack(Box, Box), } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Field { pub span: Span, pub label: String, pub expr: Box, } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Record { pub fields: Vec, } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Term { pub 
span: Span, pub kind: Kind, } impl Term { pub fn new(kind: Kind, span: Span) -> Term { Term { kind, span } } } impl Record { pub fn get(&self, label: &str) -> Option<&Field> { for field in &self.fields { if field.label == label { return Some(field); } } None } } use std::fmt; impl fmt::Display for Term { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match &self.kind { Kind::Var(idx) => write!(f, "#{}", idx), Kind::Const(Constant::Bool(b)) => write!(f, "{}", b), Kind::Const(Constant::Nat(b)) => write!(f, "{}", b), Kind::Const(Constant::Unit) => write!(f, "()"), Kind::Abs(ty, body) => write!(f, "(λx:{}. {})", ty, body), Kind::App(m, n) => write!(f, "{} {}", m, n), Kind::TyAbs(kind, body) => write!(f, "ΛX::{}. {}", kind, body), Kind::TyApp(body, ty) => write!(f, "{} [{}]", body, ty), Kind::Pack(witness, body, sig) => write!(f, "{{*{}, {}}} as {}", witness, body, sig), Kind::Unpack(m, n) => write!(f, "unpack {} as {}", m, n), Kind::Record(rec) => write!( f, "{{\n{}\n}}", rec.fields .iter() .map(|fi| format!("\t{}: {}", fi.label, fi.expr)) .collect::>() .join(",\n") ), Kind::Index(t1, t2) => write!(f, "{}.{}", t1, t2), Kind::Injection(label, tm, ty) => write!(f, "{} of {} as {}", label, tm, ty), Kind::Fold(ty, term) => write!(f, "fold [{}] {}", ty, term), Kind::Unfold(ty, term) => write!(f, "unfold [{}] {}", ty, term), } } } ================================================ FILE: 07_system_fw/src/typecheck.rs ================================================ use crate::diagnostics::Diagnostic; use crate::stack::Stack; use crate::terms::{Constant, Field, Kind, Record, Term}; use crate::types::{MutTypeVisitor, TyField, TyKind, Type}; use util::span::Span; /// A typing context, Γ #[derive(Debug)] pub struct Context { stack: Stack, kstack: Stack, } impl Default for Context { fn default() -> Context { Context { stack: Stack::with_capacity(16), kstack: Stack::with_capacity(16), } } } #[derive(Debug, PartialEq)] pub enum KindError { Mismatch(TyKind, TyKind), 
NotArrow(TyKind), NotProduct, Unbound(usize), } struct TypeSimplifier<'a> { ctx: &'a mut Context, res: Result, } impl<'a> TypeSimplifier<'a> {} impl<'a> MutTypeVisitor for TypeSimplifier<'a> { fn visit_universal(&mut self, kind: &mut TyKind, ty: &mut Type) { self.ctx.kstack.push(kind.clone()); self.visit(ty); self.ctx.kstack.pop(); } fn visit_existential(&mut self, kind: &mut TyKind, ty: &mut Type) { self.ctx.kstack.push(kind.clone()); self.visit(ty); self.ctx.kstack.pop(); } fn visit_abs(&mut self, kind: &mut TyKind, ty: &mut Type) { self.ctx.kstack.push(kind.clone()); self.visit(ty); self.ctx.kstack.pop(); } fn visit(&mut self, ty: &mut Type) { match ty { Type::App(m, n) => { self.visit(m); self.visit(n); if let Type::Abs(k, t) = m.as_mut() { match self.ctx.kinding(&n) { Ok(n_kind) => { if k.as_ref() == &n_kind { t.subst(*n.clone()); *ty = *t.clone(); self.res = Ok(true); } else { self.res = Err(KindError::Mismatch(*k.clone(), n_kind)) } } Err(e) => self.res = Err(e), } } } Type::Projection(inner, idx) => { self.visit(inner); match inner.as_ref() { Type::Product(v) => match v.get(*idx) { Some(t) => { *ty = t.clone(); self.res = Ok(true); } None => { self.res = Err(KindError::Unbound(*idx)); } }, // Type::Var(_) => {} // _ => self.res = Err(KindError::NotProduct), _ => {} } } _ => self.walk(ty), } } } impl KindError { fn to_diag(self, span: Span) -> Diagnostic { match self { KindError::Mismatch(k1, k2) => Diagnostic::error( span, format!( "a type of kind {:?} is required, but a type of kind {:?} was supplied", k1, k2 ), ), KindError::NotArrow(k) => Diagnostic::error( span, format!( "a type of kind *->* was required, but a type of kind {:?} was supplied", k ), ), KindError::Unbound(idx) => { Diagnostic::error(span, format!("unbound type variable with de Bruijn index {}", idx)) } KindError::NotProduct => Diagnostic::error(span, format!("a product kind is required")), } } } impl Context { pub fn kinding(&mut self, ty: &Type) -> Result { match ty { Type::Var(idx) => 
self.kstack.get(*idx).cloned().ok_or(KindError::Unbound(*idx)), Type::Abs(kind, t) => { self.kstack.push(*kind.clone()); let k_ = self.kinding(&t)?; self.kstack.pop(); Ok(TyKind::Arrow(kind.clone(), Box::new(k_))) } Type::App(s, t) => match self.kinding(&s)? { TyKind::Arrow(a, b) => { let k = self.kinding(&t)?; if k == *a { Ok(*b) } else { Err(KindError::Mismatch(*a, k)) } } k => Err(KindError::NotArrow(k)), }, Type::Arrow(s, t) => match self.kinding(&s)? { TyKind::Star => match self.kinding(&t)? { TyKind::Star => Ok(TyKind::Star), k => Err(KindError::Mismatch(TyKind::Star, k)), }, k => Err(KindError::Mismatch(TyKind::Star, k)), }, Type::Universal(kind, t) => { self.kstack.push(*kind.clone()); let k_ = self.kinding(&t)?; self.kstack.pop(); match k_ { TyKind::Star => Ok(TyKind::Star), k => Err(KindError::Mismatch(TyKind::Star, k)), } } Type::Existential(kind, t) => { self.kstack.push(*kind.clone()); let k_ = self.kinding(&t)?; self.kstack.pop(); match k_ { TyKind::Star => Ok(TyKind::Star), k => Err(KindError::Mismatch(TyKind::Star, k)), } } Type::Record(fields) | Type::Sum(fields) => { for f in fields { let k = self.kinding(&f.ty)?; if k != TyKind::Star { return Err(KindError::Mismatch(TyKind::Star, k)); } } Ok(TyKind::Star) } Type::Product(tys) => { // for ty in tys { // let k = self.kinding(&ty)?; // if k != TyKind::Star { // return Err(KindError::Mismatch(TyKind::Star, k)); // } // } // Ok(TyKind::Star) tys.iter() .map(|t| self.kinding(t)) .collect::, _>>() .map(TyKind::Product) } Type::Projection(ty, idx) => { match self.kinding(ty)? 
{ TyKind::Product(v) => v.get(*idx).cloned().ok_or(KindError::Unbound(*idx)), k => Err(KindError::Mismatch(TyKind::Product(vec![]), k)), } // match ty.as_ref() { // Type::Product(v) => // v.get(idx).ok_or(KindError::Unbound(idx)), // _ => Err(KindError::Mismatch()) // } } Type::Recursive(inner) => { let k = self.kinding(inner)?; match k { TyKind::Arrow(k1, k2) => { if &k1 == &k2 { Ok(*k1) } else { Err(KindError::Mismatch(*k1, *k2)) } } _ => Err(KindError::NotArrow(k)), } } _ => Ok(TyKind::Star), } } pub fn simplify_ty(&mut self, ty: &mut Type) -> Result { let mut ts = TypeSimplifier { ctx: self, res: Ok(false), }; let mut work = false; loop { ts.res = Ok(false); ts.visit(ty); match ts.res { Ok(false) => return Ok(work), Ok(true) => work = true, Err(e) => return Err(e), } } } pub fn equiv(&mut self, lhs: &Type, rhs: &Type) -> Result { if lhs == rhs { Ok(true) } else { let mut lhs_ = lhs.clone(); let mut rhs_ = rhs.clone(); self.simplify_ty(&mut lhs_)?; self.simplify_ty(&mut rhs_)?; Ok(lhs_ == rhs_) } } fn is_star_kind(&mut self, ty: &Type, err_span: Span) -> Result<(), Diagnostic> { let kind = self.kinding(ty).map_err(|k| KindError::to_diag(k, err_span))?; if kind == TyKind::Star { Ok(()) } else { return Err(Diagnostic::error( err_span, format!( "type bound in type abstraction must have a kind *, {} has a kind of {}", ty, kind ), )); } } pub fn typecheck(&mut self, term: &Term) -> Result { match &term.kind { Kind::Const(c) => match c { Constant::Unit => Ok(Type::Unit), Constant::Nat(_) => Ok(Type::Nat), Constant::Bool(_) => Ok(Type::Bool), }, Kind::Var(idx) => self .stack .get(*idx) .cloned() .ok_or(Diagnostic::error(term.span, "unbound variable")), Kind::Abs(ty, tm) => { self.is_star_kind(ty, term.span)?; self.stack.push(*ty.clone()); let ty2 = self.typecheck(&tm)?; self.stack.pop(); Ok(Type::Arrow(ty.clone(), Box::new(ty2))) } Kind::App(m, n) => { let mut ty = self.typecheck(&m)?; self.simplify_ty(&mut ty).map_err(|ke| ke.to_diag(term.span))?; if let 
Type::Arrow(ty11, ty12) = ty { let ty2 = self.typecheck(&n)?; if self.equiv(&ty11, &ty2).map_err(|ke| ke.to_diag(term.span))? { Ok(*ty12) } else { dbg!(&self.stack); dbg!(&self.kstack); let d = Diagnostic::error(term.span, "type mismatch in application") .message( m.span, format!("abstraction {} requires type {} to return {}", m, ty11, ty12), ) .message(n.span, format!("term {} has a type of {}", n, ty2)); return Err(d); } } else { // dbg!(&self.stack); dbg!(&m); dbg!(&ty); let d = Diagnostic::error(term.span, "type mismatch in application") .message(m.span, format!("this term {} has a type {}, not T->U", m, ty)); return Err(d); } } Kind::TyAbs(tk, polymorphic) => { // Reference commit log eda3417 for explanation of below code // TODO: Do we need to the same thing for the kinding stack? self.kstack.push(*tk.clone()); self.stack.iter_mut().for_each(|ty| match ty { Type::Var(v) => *v += 1, _ => {} }); let ty = self.typecheck(&polymorphic)?; self.stack.iter_mut().for_each(|ty| match ty { Type::Var(v) => *v -= 1, _ => {} }); self.kstack.pop(); Ok(Type::Universal(tk.clone(), Box::new(ty))) } Kind::TyApp(tyabs, applied) => { let mut ty = self.typecheck(&tyabs)?; self.simplify_ty(&mut ty).map_err(|ke| ke.to_diag(term.span))?; match ty { Type::Universal(k1, u) => { let k2 = self.kinding(&applied).map_err(|k| KindError::to_diag(k, term.span))?; if k2 == *k1 { // actually do subst let mut u = *u; u.subst(*applied.clone()); Ok(u) } else { let d = Diagnostic::error(term.span, "type kind mismatch in term-level type application") .message( tyabs.span, format!( "universal type requires a type of kind {}, but a kind of {} is given", &k1, k2 ), ); Err(d) } } ty => { let d = Diagnostic::error(term.span, "type mismatch in term-level type application") .message(tyabs.span, format!("this term has a type {}, not forall. 
X::K", ty)); Err(d) } } } Kind::Record(rec) => { let mut tys = Vec::with_capacity(rec.fields.len()); for f in &rec.fields { tys.push(TyField { label: f.label.clone(), ty: Box::new(self.typecheck(&f.expr)?), }); } Ok(Type::Record(tys)) } Kind::Index(tm, field) => { let ty = self.typecheck(tm)?; match ty { Type::Record(fields) => { for f in fields { if &f.label == field { return Ok(*f.ty); } } Err(Diagnostic::error(term.span, "invalid field access") .message(tm.span, format!("term does not have a label named {}", field))) } _ => Err(Diagnostic::error(term.span, "invalid field access") .message(tm.span, format!("term has a type of {}, not Record", ty))), } } Kind::Injection(label, tm, ty) => { let mut ty = ty.clone(); self.simplify_ty(&mut ty).map_err(|ke| ke.to_diag(term.span))?; match ty.as_ref() { Type::Sum(fields) => { for f in fields { if label == &f.label { let mut ty_ = self.typecheck(tm)?; self.simplify_ty(&mut ty_).map_err(|ke| ke.to_diag(term.span))?; if &ty_ == f.ty.as_ref() { return Ok(*ty.clone()); } else { let d = Diagnostic::error(term.span, "Invalid associated type in variant").message( tm.span, format!("variant {} requires type {}, but this is {}", label, f.ty, ty_), ); return Err(d); } } } Err(Diagnostic::error( term.span, format!( "constructor {} does not belong to the variant {}", label, fields .iter() .map(|f| f.label.clone()) .collect::>() .join(" | ") ), )) } _ => Err(Diagnostic::error( term.span, format!("Cannot inject {} into non-variant type {}", label, ty), )), } } Kind::Unfold(rec, tm) => match rec.as_ref() { Type::Recursive(inner) => { let ty_ = self.typecheck(&tm)?; if self.equiv(&ty_, &rec).map_err(|ke| ke.to_diag(term.span))? 
{ let mut s = inner.clone(); s.subst(*rec.clone()); Ok(*s) } else { let d = Diagnostic::error(term.span, "Type mismatch in unfold") .message(term.span, format!("unfold requires type {}", rec)) .message(tm.span, format!("term has a type of {}", ty_)); Err(d) } } _ => Err(Diagnostic::error( term.span, format!("Expected a recursive type, not {}", rec), )), }, Kind::Fold(rec, tm) => { // Fold takes an argument of type μF, term where term types to F(μF) // and wraps the type into μF // Alternatively, Fold can also take an argument of type μF A, // term where term types to F(μF) A let rec = rec.clone(); // self.simplify_ty(&mut rec) // .map_err(|ke| ke.to_diag(term.span))?; match rec.as_ref() { Type::Recursive(inner) => { let ty_ = self.typecheck(&tm)?; let mut s = Type::App(inner.clone(), rec.clone()); self.simplify_ty(&mut s).map_err(|ke| ke.to_diag(term.span))?; if self.equiv(&ty_, &s).map_err(|ke| ke.to_diag(term.span))? { Ok(*rec) } else { let d = Diagnostic::error(term.span, "Type mismatch in fold") .message(term.span, format!("fold requires type {}", s)) .message(tm.span, format!("term has a type of {}", ty_)); Err(d) } } Type::App(ty1, ty2) => match ty1.as_ref() { Type::Recursive(rec_inner) => { let mut apped = Type::App(Box::new(Type::App(rec_inner.clone(), ty1.clone())), ty2.clone()); self.simplify_ty(&mut apped).map_err(|ke| ke.to_diag(term.span))?; let ty_ = self.typecheck(&tm)?; if self.equiv(&ty_, &apped).map_err(|ke| ke.to_diag(term.span))? 
{ Ok(*rec) } else { let d = Diagnostic::error(term.span, "Type mismatch in fold") .message(term.span, format!("fold requires type {}", apped)) .message(tm.span, format!("term has a type of {}", ty_)); Err(d) } } _ => { let d = Diagnostic::error( term.span, format!("Fold requires a recursive type in application of {} to {}", ty1, ty2), ); Err(d) } }, _ => Err(Diagnostic::error( term.span, format!("Expected a recursive type, not {}", rec), )), } } Kind::Pack(witness, packed, sig) => { // where sig = {∃X, T2} // Γ⊢ packed : [ X -> witness ] T2 and Γ⊢ {∃X, T2} :: * // then {*witness, packed} as {∃X, T2} let k = self.kinding(sig).map_err(|k| KindError::to_diag(k, term.span))?; if k != TyKind::Star { return diag!(term.span, "existential type definition does not have a kind of *"); } let mut sig = sig.clone(); self.simplify_ty(&mut sig).map_err(|ke| ke.to_diag(term.span))?; match sig.as_mut() { Type::Existential(kind, t2) => { let witness_kind = self.kinding(&witness).map_err(|k| KindError::to_diag(k, term.span))?; if &witness_kind != kind.as_ref() { return diag!( term.span, "existential type requires a type of kind {}, but implementation type has a kind of {}", kind, witness_kind ); } let ty_packed = self.typecheck(packed)?; let mut ty_packed_prime = *t2.clone(); ty_packed_prime.subst(*witness.clone()); // Pierce's code has kind checking before type-substitution, // does this matter? He also directly kind-checks the // witness type against kind if self .equiv(&ty_packed, &ty_packed_prime) .map_err(|ke| ke.to_diag(term.span))? 
{ Ok(*sig.clone()) } else { Err(Diagnostic::error(term.span, "type mismatch in existential package") .message(packed.span, format!("term has a type of {}", ty_packed)) .message( term.span, format!("but the existential package type is defined as {}", ty_packed_prime,), )) } } ty => diag!(term.span, "cannot pack an existential type into {}", ty), } } Kind::Unpack(packed, body) => { let ty = self.typecheck(packed)?; match ty { Type::Existential(kind, sig) => { self.kstack.push(*kind); self.stack.push(*sig); let body_ty = self.typecheck(body)?; self.stack.pop(); self.kstack.pop(); Ok(body_ty) } _ => Err(Diagnostic::error(term.span, "type mismatch during unpack") .message(packed.span, format!("term has a type of {}, not {{∃X::K, T}}", ty))), } } } } } #[cfg(test)] mod test { use super::*; #[test] fn ty_app() { // ΛX. λx: X. x should typecheck to ∀X. X->X let id = tyabs!(kind!(*), abs!(Type::Var(0), var!(0))); // Instantiations of polymorphic identity function ∀X. X->X let inst1 = tyapp!(id.clone(), Type::Nat); let inst2 = tyapp!(id.clone(), arrow!(Type::Unit, Type::Nat)); let inst3 = tyapp!(id.clone(), univ!(kind!(*), arrow!(Type::Var(0), Type::Var(0)))); let mut ctx = Context::default(); // ΛX. λx: X. x should typecheck to ∀X. 
X->X assert_eq!( ctx.typecheck(&id), Ok(univ!(kind!(*), arrow!(Type::Var(0), Type::Var(0)))) ); assert_eq!(ctx.typecheck(&inst1), Ok(arrow!(Type::Nat, Type::Nat))); assert_eq!( ctx.typecheck(&inst2), Ok(arrow!(arrow!(Type::Unit, Type::Nat), arrow!(Type::Unit, Type::Nat))) ); assert_eq!( ctx.typecheck(&inst3), Ok(arrow!( univ!(kind!(*), arrow!(Type::Var(0), Type::Var(0))), univ!(kind!(*), arrow!(Type::Var(0), Type::Var(0))) )) ); } #[test] fn ty_exist() { let interface_ty = exist!( kind!(*), record!(("new", Type::Var(0)), ("get", arrow!(Type::Var(0), Type::Nat))) ); let adt = Term::new( Kind::Record(Record { fields: vec![ Field { span: Span::zero(), label: "new".to_string(), expr: Box::new(nat!(0)), }, Field { span: Span::zero(), label: "get".to_string(), expr: Box::new(abs!(Type::Nat, var!(0))), }, ], }), Span::zero(), ); let counter = pack!(Type::Nat, adt, interface_ty.clone()); let mut ctx = Context::default(); assert_eq!(ctx.typecheck(&counter), Ok(interface_ty)); let unpacked = unpack!(counter, app!(access!(var!(0), "get"), access!(var!(0), "new"))); assert_eq!(ctx.typecheck(&unpacked), Ok(Type::Nat)); } #[test] fn ty_abs() { let ty = tyop!(kind!(*), Type::Var(0)); let mut ctx = Context::default(); assert_eq!(ctx.kinding(&ty), Ok(kind!(* => *))); assert_eq!(ctx.kinding(&Type::App(Box::new(ty), Box::new(Type::Nat))), Ok(kind!(*))); let pair = tyop!(kind!(*), tyop!(kind!(*), univ!(kind!(*), Type::Var(0)))); assert_eq!(ctx.kinding(&pair), Ok(kind!(kind!(*) => kind!(* => *)))); } #[test] fn ty_equivalence() { let ty1 = op_app!(tyop!(kind!(*), Type::Var(0)), Type::Nat); let ty2 = Type::Nat; let mut ctx = Context::default(); assert_eq!(ctx.equiv(&ty1, &ty2), Ok(true), "{:?}", ty1); } #[test] fn ty_record() { let adt = Term::new( Kind::Record(Record { fields: vec![ Field { span: Span::zero(), label: "new".to_string(), expr: Box::new(nat!(0)), }, Field { span: Span::zero(), label: "get".to_string(), expr: Box::new(abs!(Type::Nat, var!(0))), }, ], }), Span::zero(), ); 
let mut ctx = Context::default(); let rty = record!(("new", Type::Nat), ("get", arrow!(Type::Nat, Type::Nat))); assert_eq!(ctx.typecheck(&adt), Ok(rty)); let t1 = access!(adt.clone(), "new"); let t2 = access!(adt.clone(), "get"); assert_eq!(ctx.typecheck(&t1), Ok(Type::Nat)); assert_eq!(ctx.typecheck(&t2), Ok(arrow!(Type::Nat, Type::Nat))); } } ================================================ FILE: 07_system_fw/src/types.rs ================================================ use std::convert::TryFrom; use std::fmt; #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum Type { Unit, Nat, Bool, Var(usize), Record(Vec), Product(Vec), Projection(Box, usize), Arrow(Box, Box), Universal(Box, Box), Existential(Box, Box), Abs(Box, Box), App(Box, Box), Recursive(Box), Sum(Vec), } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct TyField { pub label: String, pub ty: Box, } #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum TyKind { Star, Arrow(Box, Box), Product(Vec), } impl Type { pub fn subst(&mut self, mut s: Type) { Shift::new(1).visit(&mut s); Subst::new(s).visit(self); Shift::new(-1).visit(self); } /// Support function for quickly accessing labelled subtypes pub fn label(&self, label: &str) -> Option<&Type> { match self { Type::Sum(fields) | Type::Record(fields) => { for f in fields { if f.label == label { return Some(&f.ty); } } None } _ => None, } } } impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match &self { Type::Var(idx) => write!(f, "TyVar({})", idx), Type::Unit => write!(f, "Unit"), Type::Nat => write!(f, "Nat"), Type::Bool => write!(f, "Bool"), Type::Abs(kind, ty) => write!(f, "(ΛX::{}. {})", kind, ty), Type::App(m, n) => write!(f, "{} {}", m, n), Type::Arrow(m, n) => write!(f, "{}->{}", m, n), Type::Universal(k, ty) => write!(f, "∀X::{}. 
{}", k, ty), /* NOTE(review): generic arguments in this extract appear to have been stripped (e.g. "collect::>()" was presumably "collect::<Vec<String>>()") — restore from version control before compiling. */ Type::Existential(k, ty) => write!(f, "{{∃X::{}, {}}}", k, ty), Type::Record(fields) => write!( f, "{{\n{}\n}}", fields .iter() .map(|fi| format!("\t{}: {}", fi.label, fi.ty)) .collect::>() .join(",\n") ), Type::Product(tys) => write!( f, "({})", tys.iter().map(|ty| ty.to_string()).collect::>().join(",") ), Type::Projection(ty, idx) => write!(f, "{}.{}", ty, idx), Type::Sum(fields) => write!( f, "{}", fields .iter() .map(|fi| format!("{} {}", fi.label, fi.ty)) .collect::>() .join("|") ), Type::Recursive(inner) => write!(f, "rec {}", inner), } } } /* Pretty-printer for kinds. Star prints as "*"; a left-nested arrow kind is parenthesized ("({}->{})->{}") and a right-nested one gets parentheses on the right ("{}->({})"), while products print as a comma-separated tuple. */ impl fmt::Display for TyKind { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match &self { TyKind::Star => write!(f, "*"), TyKind::Arrow(k1, k2) => match k2.as_ref() { TyKind::Star => match k1.as_ref() { TyKind::Star => write!(f, "{}->{}", k1, k2), TyKind::Arrow(k11, k12) => write!(f, "({}->{})->{}", k11, k12, k2), _ => write!(f, "{}->{}", k1, k2), }, TyKind::Arrow(_, _) => write!(f, "{}->({})", k1, k2), k => write!(f, "{}->{}", k1, k), }, TyKind::Product(v) => { let s = v.iter().map(|k| format!("{}", k)).collect::>().join(","); write!(f, "({})", s) } } } } /* Mutable visitor over Type trees. Every visit_* method has a default that simply recurses into child types; implementors override only the cases they need and fall back on visit/walk for the rest. */ pub trait MutTypeVisitor: Sized { fn visit_var(&mut self, _: &mut usize) {} fn visit_arrow(&mut self, ty1: &mut Type, ty2: &mut Type) { self.visit(ty1); self.visit(ty2); } fn visit_universal(&mut self, _: &mut TyKind, ty: &mut Type) { self.visit(ty); } fn visit_existential(&mut self, _: &mut TyKind, ty: &mut Type) { self.visit(ty); } fn visit_abs(&mut self, _: &mut TyKind, ty: &mut Type) { self.visit(ty); } fn visit_app(&mut self, s: &mut Type, t: &mut Type) { self.visit(s); self.visit(t); } fn visit_record(&mut self, fields: &mut [TyField]) { for f in fields { self.visit(f.ty.as_mut()); } } fn visit_product(&mut self, tys: &mut [Type]) { for ty in tys { self.visit(ty); } } fn visit_projection(&mut self, ty: &mut Type, _: usize) { self.visit(ty); } fn visit_recursive(&mut self, ty: &mut Type) { self.visit(ty); } /* visit is the entry point; the default delegates to walk, which dispatches on the Type variant. */ fn visit(&mut self, ty: &mut Type) {
self.walk(ty); } /* walk performs the structural dispatch: leaves (Unit/Bool/Nat) are no-ops; every other variant is forwarded to its visit_* hook. Sum shares visit_record since both variants carry labelled fields. */ fn walk(&mut self, ty: &mut Type) { match ty { Type::Unit | Type::Bool | Type::Nat => {} Type::Var(v) => self.visit_var(v), Type::Record(fields) => self.visit_record(fields), Type::Product(tys) => self.visit_product(tys), Type::Projection(ty, _) => self.visit(ty), Type::Sum(variants) => self.visit_record(variants), Type::Recursive(ty1) => self.visit_recursive(ty1), Type::Arrow(ty1, ty2) => self.visit_arrow(ty1, ty2), Type::Universal(k, ty) => self.visit_universal(k, ty), Type::Existential(k, ty) => self.visit_existential(k, ty), Type::Abs(s, t) => self.visit_abs(s, t), Type::App(k, t) => self.visit_app(k, t), } } } /* Shift adjusts free de Bruijn type-variable indices by `shift`. `cutoff` counts how many binders we are currently under: only variables with index >= cutoff are free and therefore shifted. */ pub struct Shift { pub cutoff: usize, pub shift: isize, } impl Shift { pub const fn new(shift: isize) -> Shift { Shift { cutoff: 0, shift } } } impl MutTypeVisitor for Shift { /* Shifting a free variable below zero breaks the representation invariant, hence the expect. */ fn visit_var(&mut self, var: &mut usize) { if *var >= self.cutoff { *var = usize::try_from(*var as isize + self.shift).expect("Type variable has been shifted below 0! Fatal bug"); } } /* Universal, Existential and Abs each bind one type variable, so the cutoff is raised while their bodies are visited. */ fn visit_universal(&mut self, _: &mut TyKind, ty: &mut Type) { self.cutoff += 1; self.visit(ty); self.cutoff -= 1; } fn visit_existential(&mut self, _: &mut TyKind, ty: &mut Type) { self.cutoff += 1; self.visit(ty); self.cutoff -= 1; } fn visit_abs(&mut self, _: &mut TyKind, ty: &mut Type) { self.cutoff += 1; self.visit(ty); self.cutoff -= 1; } /* NOTE(review): the Recursive override is commented out, so Recursive is not treated as a binder by Shift — confirm this matches how Type::Recursive indices are produced. */ // fn visit_recursive(&mut self, ty: &mut Type) { // self.cutoff += 1; // self.visit(ty); // self.cutoff -= 1; // } } /* Subst substitutes a type for the variable bound at `cutoff`; see the visit override that follows this span, which replaces Var(cutoff) with a (cutoff-shifted) copy of `ty`. */ pub struct Subst { pub cutoff: usize, pub ty: Type, } impl Subst { pub fn new(ty: Type) -> Subst { Subst { cutoff: 0, ty } } } impl MutTypeVisitor for Subst { fn visit_universal(&mut self, _: &mut TyKind, ty: &mut Type) { self.cutoff += 1; self.visit(ty); self.cutoff -= 1; } fn visit_existential(&mut self, _: &mut TyKind, ty: &mut Type) { self.cutoff += 1; self.visit(ty); self.cutoff -= 1; } fn visit_abs(&mut self, _: &mut TyKind, ty: &mut Type) { self.cutoff += 1; self.visit(ty); self.cutoff -= 1; } // fn visit_recursive(&mut
self, ty: &mut Type) { // self.cutoff += 1; // self.visit(ty); // self.cutoff -= 1; // } fn visit(&mut self, ty: &mut Type) { match ty { Type::Var(v) if *v == self.cutoff => { Shift::new(self.cutoff as isize).visit(&mut self.ty); *ty = self.ty.clone(); } _ => self.walk(ty), } } } ================================================ FILE: 07_system_fw/test.fw ================================================ datatype 'a list = Nil | Cons of 'a * 'a list datatype 'a option = None | Some of 'a datatype ('a, 'b) either = Left of 'a | Right of 'b fun t (Left 10) = 10 | t (Right _) = 9 fun test (Some _) 9 = 9 | test (Some 8) 8 = 8 | test (None) 7 = 0 datatype expr = Var of int | Abs of expr | App of expr * expr; type env = expr * list datatype kont = Mt | Fn of expr * env * kont | Arg of expr * env * kont fun lookup env i = i fun extend env x = x fun step (c: expr) (e: env) (k: kont) : (expr, env, kont) = case (c, e, k) of | (Var i, e, k) => (lookup e i, e, k) | (App (e1, e2), e, k) => (e1, e, Arg (e2, e, k)) | (Abs x, env, Arg (a, e, k)) => (a, e, Fn (x, env, k)) | (Abs x, env, Fn (a, e, k)) => (a, extend e x, k), | (c, e, Mt) => (c, e, Mt) | (_, _, _) => (Var 0, Nil, Mt) end ================================================ FILE: Cargo.toml ================================================ [workspace] members = ["01_arith", "02_lambda", "03_typedarith", "04_stlc", "05_recon", "06_system_f", "07_system_fw", "x1_bidir", "x2_dependent", "util"] ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2019 Michael Lazear Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished 
to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # types-and-programming-languages ![](https://github.com/lazear/types-and-programming-languages/workflows/Rust/badge.svg) Several Rust implementations of exercises from Benjamin Pierce's "Types and Programming Languages" are organized into different folders, as described below: - `arith` is an implementation of the untyped lambda calculus extended with simple numeric operations - `lambda` is an implementation of the untyped lambda calculus as presented in chapter 5, 6, and 7. - `typedarith` is the `arith` project extended with simple types: `Nat` and `Bool` - `stlc` is an implementation of the simply typed lambda calculus, as discussed in chapters 9 and 10 of TAPL. This simply typed calculus has the types, `Unit`, `Nat`, `Bool`, `T -> T` (arrow), and `Record` types. - `recon` contains several implementations of Hindley-Milner based type reconstruction from the untyped lambda calculus to System F, with let-polymorphism. Both Algorithm W (the more common) and Algorithm J (the more efficient) are presented. For Alg. W, both a naive equality constraint solver, and a faster union-find (with path compression) solver are provided. Algorithm J makes use shared mutable references to promote type sharing instead. 
- `system_f` contains a parser, typechecker, and evaluator for the simply typed lambda calculus with parametric polymorphism (System F). The implementation of System F is the most complete so far, and I've tried to write a parser, typechecker and diagnostic system that can given meaningful messages - `system_fw` contains a parser for a high-level, Standard ML like source language that is desugared into an HIR, and then System F-omega. This extends `system_f` with type operators and higher-kinded types. This is where most of the ongoing work is located, as I'd like to make this the basis of a toy (but powerful, and useable) programming language. Ideally we will have some form of bidirectional type inference. Work on this has accidentally turned into a full fledged [SML compiler](https://github.com/SomewhatML/sml-compiler), so it's likely that I will roll back the work on the system_fw project to just type checking - `bidir` is is an implementation of the bidirectional typechecker from 'Complete and Easy Bidirectional Typechecking', extended with booleans, product, and sum types. I make no claims on the correctness of the implementation for the extended features not present in the paper. - `dependent` is WIP, implementing a simple, dependently typed lambda calculus as discussed in ATAPL. ================================================ FILE: util/.gitignore ================================================ /target **/*.rs.bk .vscode/ ================================================ FILE: util/Cargo.toml ================================================ [package] name = "util" version = "0.1.0" authors = ["Michael Lazear "] edition = "2018" [dependencies] ================================================ FILE: util/src/arena.rs ================================================ //! A safe, fast, and space-efficient typed Arena allocator //! //! # Examples: //! //! ``` //! use util::arena::Arena; //! //! let mut arena: Arena = Arena::default(); //! 
let index = arena.insert(10); //! //! // Get a reference to the item stored in the arena //! let r = arena.get(index); //! assert_eq!(r, Some(&10)); //! let i = arena.remove(index); //! //! assert_eq!(i, Some(10)); //! //! // Attempting to access the Arena at this index should return None //! // since the item was removed //! assert_eq!(arena.get(index), None); //! ``` //! //! ## Invariants: //! //! - Arena must have a capacity >= `MIN_CAPACITY` (16). Calls to //! `Arena::with_capacity` that use a capacity less than this value will //! default to a capacity of `MIN_CAPACITY` //! //! - The first entry (index 0) in the Arena is used to store the head of the //! list of free/vacant entries in the Arena (free list). As such, all //! `Index`'s are wrappers around a `NonZeroU32`, since accessing the first //! entry in the Arena's internal data would likely cause data corruption #![forbid(unsafe_code)] #![allow(dead_code)] use std::num::NonZeroU32; /// Minimum capacity for an `Arena` pub const MIN_CAPACITY: u32 = 16; /// The `Arena`, an allocator pub struct Arena { data: Vec>, } /// An index into an `Arena` #[derive(PartialEq, PartialOrd, Debug, Copy, Clone)] pub struct Index(NonZeroU32); /// Internal entry data structure #[derive(PartialEq, PartialOrd)] enum Entry { Vacant(Option), Occupied(T), } use std::fmt; impl fmt::Debug for Entry { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Entry::Vacant(opt) => write!(f, "Vacant({:?})", opt), Entry::Occupied(item) => write!(f, "{:?}", item), } } } impl std::default::Default for Arena { fn default() -> Arena { Arena::with_capacity(MIN_CAPACITY) } } impl Arena { /// Allocate an `Arena` capable of storing `n` items before re-allocating /// The mininimum capacity for an `Arena` is specified in `MIN_CAPACITY`, /// which defaults to `16` /// /// # Examples /// ``` /// use util::arena::Arena; /// /// let arena: Arena = Arena::with_capacity(256); /// assert_eq!(arena.capacity(), 256); /// /// let arena2: Arena = 
Arena::with_capacity(8); /// assert_eq!(arena2.capacity(), 16); /// ``` pub fn with_capacity(n: u32) -> Arena { // Invariant that we have at least MIN_CAPACITY vacant entries when we // initialize the Arena, so setting the first element to point to // element 1 is safe. let mut arena = Arena { data: vec![Entry::Vacant(None)], }; arena.reserve(n.max(MIN_CAPACITY)); arena } /// Returns the number of items the `Arena` can store without allocating pub fn capacity(&self) -> u32 { self.data.capacity() as u32 - 1 } /// Return the index of the next free slot in the `Arena`, if one exists /// or return None. /// /// The free-list head resides in index 0 of the `Arena` #[inline] fn get_free(&self) -> Option { match self.data.get(0) { Some(Entry::Vacant(next)) => *next, _ => None, } } /// Convenience function to set the free-list pointer #[inline] fn set_free(&mut self, next: Option) { self.data[0] = Entry::Vacant(next); } /// Reserve capacity for `n` additional items in the `Arena`, updating /// the free-list head as well to point to the beginning of the new items. /// /// This may cause the Arena to become somewhat segmented, so it may be /// desirable to revisit this behavior fn reserve(&mut self, n: u32) { debug_assert!(n >= MIN_CAPACITY); let start = self.data.len() as u32; let end = start + n; let head = self.get_free(); self.data.reserve(n as usize); // Subtract one for the free-list pointer for idx in start..end - 1 { self.data.push(Entry::Vacant(Some(NonZeroU32::new(idx + 1).unwrap()))); } *self.data.last_mut().unwrap() = Entry::Vacant(head); self.set_free(Some(NonZeroU32::new(start).unwrap())); } /// Attempt to insert an item into the `Arena` without performing any /// additional allocations. 
/// /// Will return an `Err` if the `Arena` has no remaining capacity /// /// # May panic (should never happen) /// /// In the event that the `Arena`'s free list is corrupted, this /// function will panic pub fn try_insert(&mut self, item: T) -> Result { match self.get_free() { None => Err(item), Some(free) => { let index = free.get() as usize; let old = std::mem::replace(&mut self.data[index], Entry::Occupied(item)); match old { Entry::Occupied(_) => panic!("Corrupted arena!"), Entry::Vacant(next) => { self.set_free(next); Ok(Index(free)) } } } } } /// Insert an `item` into the `Arena`. Insertion will only allocate /// additional storage capacity in the event that there are no free /// slots in the currently allocated space #[inline] pub fn insert(&mut self, item: T) -> Index { match self.try_insert(item) { Ok(idx) => idx, Err(item) => self.reserve_insert(item), } } /// Reserve additional capacity, and then insert an item /// into the newly allocated space fn reserve_insert(&mut self, item: T) -> Index { self.reserve(self.capacity()); self.try_insert(item).map_err(|_| ()).expect("Out of memory") } /// Remove an `item` from the `Arena` at the specified index, /// returning `Some` if the index was occupied, or `None` if /// the index was vacant pub fn remove(&mut self, index: Index) -> Option { let i = index.0.get() as usize; let free = self.get_free(); let prev = std::mem::replace(&mut self.data[i], Entry::Vacant(free)); match prev { Entry::Occupied(item) => { self.set_free(Some(index.0)); Some(item) } Entry::Vacant(_) => { std::mem::replace(&mut self.data[i], prev); None } } } #[inline] /// Get a reference to the item stored at `index`, if it exists pub fn get(&self, index: Index) -> Option<&T> { match self.data.get(index.0.get() as usize) { Some(Entry::Occupied(ptr)) => Some(ptr), _ => None, } } #[inline] /// Return a mutable reference to the data stored at `index`, if it exists pub fn get_mut(&mut self, index: Index) -> Option<&mut T> { match 
self.data.get_mut(index.0.get() as usize) { Some(Entry::Occupied(ptr)) => Some(ptr), _ => None, } } /// Returns an iterator over the occupied data in the [`Arena`] pub fn iter(&self) -> Iter<'_, T> { Iter { data: &self.data, idx: 1, } } } /// Immutable iterator /// /// This struct is created by the [`iter`] method on [`Arena`] /// /// [`iter`]: Arena::iter() pub struct Iter<'a, T> { data: &'a [Entry], idx: usize, } impl<'a, T> Iterator for Iter<'a, T> { type Item = &'a T; fn next(&mut self) -> Option { while self.idx < self.data.len() { self.idx += 1; if let Entry::Occupied(t) = &self.data[self.idx - 1] { return Some(&t); } } None } } /// An owning iterator over the data in an [`Arena`] pub struct IntoIter { data: Vec>, idx: usize, } impl Iterator for IntoIter { type Item = T; fn next(&mut self) -> Option { while self.idx < self.data.len() { self.idx += 1; if let Entry::Occupied(t) = std::mem::replace(&mut self.data[self.idx - 1], Entry::Vacant(None)) { return Some(t); } } None } } impl IntoIterator for Arena { type Item = T; type IntoIter = IntoIter; /// Creates a consuming iterator fn into_iter(self) -> Self::IntoIter { IntoIter { data: self.data, idx: 1, } } } impl<'a, T> IntoIterator for &'a Arena { type Item = &'a T; type IntoIter = Iter<'a, T>; fn into_iter(self) -> Self::IntoIter { Iter { data: &self.data, idx: 1, } } } impl std::iter::FromIterator for Arena { /// Collect elements from an [`Iterator`] into an [`Arena`] /// /// # Note /// /// You won't be able to receive an [`Index`] for each item fn from_iter>(iter: I) -> Arena { let mut arena = Arena::default(); for i in iter { arena.insert(i); } arena } } #[cfg(test)] mod test { use super::*; #[test] fn index_size() { assert_eq!(std::mem::size_of::(), 4); } #[test] fn smoke_insert() { let mut a = Arena::default(); assert_eq!(a.insert(255u8).0.get(), 1); } #[test] fn smoke_remove() { let mut a = Arena::default(); let idx = a.insert(255u8); assert_eq!(a.remove(idx), Some(255u8)); assert_eq!(a.get(idx), 
None); } #[test] fn smoke_iter() { let mut a = Arena::with_capacity(32); for i in 0..32 { a.insert(i); } assert_eq!(a.iter().zip(0..32).fold(0, |acc, x| acc + (x.0 - x.1)), 0); } #[test] fn fill() { let mut arena = Arena::default(); assert_eq!(arena.capacity(), MIN_CAPACITY); dbg!(&arena.data); for i in 0..15 { arena.insert(i); } dbg!(&arena.data); assert_eq!(arena.data[0], Entry::Vacant(None)); assert_eq!(arena.capacity(), MIN_CAPACITY); } } ================================================ FILE: util/src/diagnostic.rs ================================================ //! Diagnostic handling for errors detected in source code. //! //! Dropping a [`Diagnostic`] without calling `emit` will cause a [`panic`]! use crate::span::*; /// Struct that handles collecting and reporting Parser errors and diagnostics pub struct Diagnostic<'s> { src: &'s str, messages: Vec>, } impl Diagnostic<'_> { pub fn new(src: &str) -> Diagnostic<'_> { Diagnostic { src, messages: Vec::new(), } } pub fn push>(&mut self, msg: S, span: Span) { self.messages.push(Spanned::new(span, msg.into())); } pub fn error_count(&self) -> usize { self.messages.len() } /// Remove the last error message pub fn pop(&mut self) -> Option { let msg = self.messages.pop()?; let line = self.src.lines().nth(msg.span.start.line as usize)?; Some(format!( "Error occuring at line {}, col: {}: {}\n{}\n{}^{}\n", msg.span.start.line, msg.span.start.col, msg.data, &line, (0..msg.span.start.col).map(|_| ' ').collect::(), (0..msg.span.end.col - msg.span.start.col) .map(|_| '~') .collect::(), )) } #[must_use] /// Emit all remaining error message, if there are any pub fn emit(mut self) -> String { let mut s = String::new(); let lines = self.src.lines().collect::>(); for i in 0..self.messages.len() { let msg: &Spanned = &self.messages[i]; let mut squiggly = (1..msg.span.end.col.saturating_sub(msg.span.start.col)) .map(|_| '~') .collect::(); squiggly.push('^'); s.push_str(&format!( "Error occuring at line {}, col: {}: 
{}\n{}\n{}^{}\n", msg.span.start.line, msg.span.start.col, msg.data, &lines[msg.span.start.line as usize], (0..msg.span.start.col).map(|_| ' ').collect::(), squiggly )); } self.messages.clear(); s } } impl Drop for Diagnostic<'_> { fn drop(&mut self) { if self.error_count() != 0 { panic!("Diagnostic dropped without handling {} errors!", self.error_count()); } } } ================================================ FILE: util/src/lib.rs ================================================ //! Source code locations and diagnostic reporting that can be shared //! across different projects pub mod arena; pub mod diagnostic; pub mod span; pub mod unsafe_arena; ================================================ FILE: util/src/span.rs ================================================ //! Source code locations and spans use std::fmt; #[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Default)] /// Struct representing a location in a source string pub struct Location { pub line: u32, pub col: u32, pub abs: u32, } impl Location { pub fn new(line: u32, col: u32, abs: u32) -> Location { Location { line, col, abs } } } impl fmt::Display for Location { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}:{}", self.line, self.col) } } #[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Default)] /// A span of code pub struct Span { pub start: Location, pub end: Location, } /// Data with associated code span pub struct Spanned { pub span: Span, pub data: T, } impl Span { pub fn new(start: Location, end: Location) -> Span { Span { start, end } } pub const fn dummy() -> Span { let max = Location { line: std::u32::MAX, col: std::u32::MAX, abs: std::u32::MAX, }; Span { start: max, end: max } } pub const fn zero() -> Span { let max = Location { line: 0, col: 0, abs: 0, }; Span { start: max, end: max } } } impl Spanned { /// Create a new [`Spanned`], representing a [`Span`] and data pair pub fn new(span: Span, data: T) -> Spanned { Spanned { span, data } } /// Map a [`Spanned`]'s 
wrapped data into a new [`Spanned`] pub fn map B>(self, f: F) -> Spanned { Spanned { span: self.span, data: f(self.data), } } /// Replace the data in self with `src`, dropping the previous data pub fn replace(self, src: V) -> Spanned { Spanned { span: self.span, data: src, } } /// Consume self, returning the wrapped `T` pub fn into_inner(self) -> T { self.data } #[inline] pub fn span(&self) -> Span { self.span } #[inline] pub fn data(self) -> T { self.data } } impl Spanned> { /// Transpose a [`Spanned>`] into a [`Result, /// Spanned>`] pub fn map_result(self) -> Result, Spanned> { let Spanned { span, data } = self; data.map(|t| Spanned::new(span, t)).map_err(|e| Spanned::new(span, e)) } } impl Spanned> { pub fn map_option(self) -> Option> { let Spanned { span, data } = self; data.map(|t| Spanned::new(span, t)) } } impl fmt::Display for Spanned { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}: {}", self.span.start, self.data) } } impl fmt::Debug for Spanned { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{:?}: {:?}", self.span, self.data) } } impl std::ops::Deref for Spanned { type Target = T; fn deref(&self) -> &Self::Target { &self.data } } impl Clone for Spanned { fn clone(&self) -> Self { Spanned { span: self.span, data: self.data.clone(), } } } impl Copy for Spanned {} impl std::ops::Add for Span { type Output = Self; fn add(self, rhs: Self) -> Self::Output { Span { start: self.start, end: rhs.end, } } } impl std::ops::AddAssign for Span { fn add_assign(&mut self, rhs: Self) { self.end = rhs.end; } } ================================================ FILE: util/src/unsafe_arena.rs ================================================ //! A fast and efficient typed arena //! //! Translated from rustc's TypedArena into stable rust //! //! 
https://doc.rust-lang.org/1.1.0/src/arena/lib.rs.html use std::alloc::{alloc, dealloc, Layout}; use std::cell::Cell; use std::cmp; use std::marker; use std::mem; use std::ptr; pub struct Arena { ptr: Cell<*mut T>, end: Cell<*mut T>, chunk: Cell<*mut Chunk>, marker: marker::PhantomData, } struct Chunk { capacity: usize, entries: usize, prev: *mut Chunk, marker: marker::PhantomData, // data stored here } #[derive(Debug, PartialEq, Copy, Clone)] struct Info { capacity: usize, used: usize, } impl Default for Arena { fn default() -> Arena { Arena { ptr: Cell::new(ptr::null_mut()), end: Cell::new(ptr::null_mut()), chunk: Cell::new(ptr::null_mut()), marker: marker::PhantomData, } } } impl Arena { pub fn with_capacity(capacity: usize) -> Arena { unsafe { let chunk: *mut Chunk = Chunk::new(ptr::null_mut(), capacity); Arena { ptr: Cell::new((*chunk).start()), end: Cell::new((*chunk).end()), chunk: Cell::new(chunk), marker: marker::PhantomData, } } } #[inline] fn can_alloc(&self, n: usize) -> bool { let remaining = self.end.get() as usize - self.ptr.get() as usize; let required = mem::size_of::().checked_mul(n).unwrap(); remaining >= required } #[inline] fn ensure_capacity(&self, n: usize) { if !self.can_alloc(n) { self.grow(n) } } #[inline] fn entries(&self) -> usize { unsafe { let bytes = self.ptr.get() as usize - (*self.chunk.get()).start() as usize; bytes / mem::size_of::() } } #[inline] fn chunks(&self) -> Vec { let mut count = Vec::new(); let mut ptr = self.chunk.get(); unsafe { if !ptr.is_null() { let cap = self.end.get() as usize - (*ptr).start() as usize; count.push(Info { capacity: cap, used: self.entries(), }); ptr = (*ptr).prev; } while !ptr.is_null() { count.push(Info { capacity: (*ptr).capacity, used: (*ptr).entries, }); ptr = (*ptr).prev; } } count } #[inline] fn grow(&self, n: usize) { unsafe { let mut chunk = self.chunk.get(); let mut new_cap; if !chunk.is_null() { (*chunk).entries = self.entries(); new_cap = (*chunk).capacity.checked_mul(2).unwrap(); while 
new_cap < (*chunk).capacity + n { new_cap = new_cap.checked_mul(2).unwrap(); } } else { new_cap = cmp::max(n, 0x1000 / mem::size_of::()); } // Allocate at least 1 page new_cap = cmp::max(new_cap, 0x1000 / mem::size_of::()); let chunk = Chunk::::new(chunk, new_cap); self.ptr.set((*chunk).start()); self.end.set((*chunk).end()); self.chunk.set(chunk); } } #[inline] pub fn alloc(&self, value: T) -> &mut T { if self.ptr == self.end { self.grow(1); } unsafe { let ptr: &mut T = mem::transmute(self.ptr.get()); ptr::write(ptr, value); self.ptr.set(self.ptr.get().offset(1 as isize)); ptr } } #[inline] unsafe fn alloc_raw_slice(&self, n: usize) -> *mut T { assert!(n != 0); self.ensure_capacity(n); let ptr = self.ptr.get(); self.ptr.set(ptr.offset(n as isize)); ptr } #[inline] pub fn alloc_slice(&self, slice: &[T]) -> &mut [T] where T: Copy, { unsafe { let len = slice.len(); let ptr = self.alloc_raw_slice(len); slice.as_ptr().copy_to_nonoverlapping(ptr, len); std::slice::from_raw_parts_mut(ptr, len) } } } impl std::ops::Drop for Arena { fn drop(&mut self) { unsafe { (*self.chunk.get()).destroy(self.entries()); } } } #[inline] fn padding_needed(l: &Layout, align: usize) -> usize { let len = l.size(); let rounded = len.wrapping_add(align).wrapping_sub(1) & !align.wrapping_sub(1); rounded.wrapping_sub(len) } #[inline] fn round(layout: &Layout, align: usize) -> usize { let pad = padding_needed(layout, align); let offset = layout.size().checked_add(pad).unwrap(); offset } #[inline] fn extend(a: Layout, b: Layout) -> Layout { let new_align = std::cmp::max(a.align(), b.align()); let pad = padding_needed(&a, b.align()); let off = a.size().checked_add(pad).unwrap(); let sz = off.checked_add(b.size()).unwrap(); Layout::from_size_align(sz, new_align).unwrap() } impl Chunk { #[inline] fn layout(capacity: usize) -> Layout { let chunk_layout = Layout::from_size_align(mem::size_of::>(), mem::align_of::>()).unwrap(); let size = mem::size_of::().checked_mul(capacity).unwrap(); let elem_layout 
= Layout::from_size_align(size, mem::align_of::()).unwrap(); extend(chunk_layout, elem_layout) } unsafe fn new(prev: *mut Chunk, capacity: usize) -> *mut Chunk { let layout = Self::layout(capacity); let chunk = alloc(layout) as *mut Chunk; if chunk.is_null() { panic!("out of memory!"); } (*chunk).prev = prev; (*chunk).capacity = capacity; chunk } #[inline] unsafe fn destroy(&mut self, len: usize) { for i in 0..len { // copy to stack, destructor will run ptr::read(self.start().offset(i as isize)); } let prev = self.prev; let layout = Self::layout(self.capacity); let ptr: *mut Chunk = self; dealloc(ptr as *mut u8, layout); if !prev.is_null() { let entries = (*prev).entries; (*prev).destroy(entries); } } #[inline] pub fn start(&self) -> *mut T { let ptr: *const Chunk = self; let layout = Layout::from_size_align(mem::size_of::>(), mem::align_of::>()).unwrap(); let r = round(&layout, mem::align_of::()); unsafe { let mut p = ptr as usize; p += r; mem::transmute(p) } } #[inline] pub fn end(&self) -> *mut T { unsafe { self.start().offset(self.capacity as isize) } } } #[cfg(test)] mod test { use super::*; struct DropGuard { ptr: *mut usize, } impl std::ops::Drop for DropGuard { fn drop(&mut self) { unsafe { *self.ptr += 1 } } } #[allow(dead_code)] struct Test { data_a: usize, data_b: usize, data_c: [usize; 16], data_d: Vec, data_e: Box, } #[test] fn new_chunk() { unsafe { let ptr: *mut Chunk = Chunk::new(ptr::null_mut(), 256); let mut start = ptr as usize; start += std::mem::size_of::>(); assert_eq!(start, (*ptr).start() as usize); assert_eq!((*ptr).start().offset(256 as isize), (*ptr).end()); } } #[test] fn drop_test() { let mut flag: usize = 0; let arena: Arena = Arena::default(); for _ in 0..32 { arena.alloc(DropGuard { ptr: &mut flag as *mut _, }); } assert_eq!(flag, 0); drop(arena); assert_eq!(flag, 32); } #[test] fn references() { #[derive(Debug, PartialEq)] struct Val(usize); struct Ref<'arena>(&'arena mut Val); let arena = Arena::::with_capacity(32); let r1: Ref = 
Ref(arena.alloc(Val(1))); let _r2: Ref = Ref(arena.alloc(Val(2))); let _r3: Ref = Ref(arena.alloc(Val(3))); let _r4: Ref = Ref(arena.alloc(Val(4))); (*r1.0) = Val(10); assert_eq!(*r1.0, Val(10)); } #[test] fn slice() { let c = 0x1000 / mem::size_of::(); let a = Arena::with_capacity(c); assert!(a.can_alloc(c)); let v = (0..c - 1).map(|i| i as usize).collect::>(); // chunk 1 will have 2 elements only a.alloc(1); a.alloc(2); assert!(a.can_alloc(c - 2)); // should be in chunk 2 a.alloc_slice(&v); a.alloc(3); let mut new_cap = c.checked_mul(2).unwrap(); while new_cap < c + v.len() { new_cap = new_cap.checked_mul(2).unwrap(); } new_cap = cmp::max(new_cap, 0x1000 / mem::size_of::()); // where is 0x2000 coming from?? assert_eq!( a.chunks(), vec![ Info { capacity: 0x2000, used: v.len() + 1 }, Info { capacity: c, used: 2 } ] ); } } ================================================ FILE: x1_bidir/Cargo.toml ================================================ [package] name = "bidir" version = "0.1.0" authors = ["Michael Lazear "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] ================================================ FILE: x1_bidir/src/helpers.rs ================================================ use super::*; macro_rules! var { ($x:expr) => { Expr::Var($x) }; } macro_rules! app { ($x:expr, $y:expr) => { Expr::App(Box::new($x), Box::new($y)) }; } macro_rules! abs { ($x:expr) => { Expr::Abs(Box::new($x)) }; } macro_rules! ann { ($x:expr, $t:expr) => { Expr::Ann(Box::new($x), Box::new($t)) }; } macro_rules! ife { ($a:expr, $b:expr, $c:expr) => { Expr::If(Box::new($a), Box::new($b), Box::new($c)) }; } macro_rules! arrow { ($t1:expr, $t2:expr) => { Type::Arrow(Box::new($t1), Box::new($t2)) }; } macro_rules! karrow { ($t1:expr, $t2:expr) => { Kind::Arrow(Box::new($t1), Box::new($t2)) }; } macro_rules! 
forall { ($t1:expr) => { Type::Univ(Box::new(Kind::Star), Box::new($t1)) }; ($k:expr, $t1:expr) => { Type::Univ(Box::new($k), Box::new($t1)) }; } macro_rules! case { ($ex:expr, $pat1:expr => $arm1:expr, $pat2:expr => $arm2:expr) => { Expr::Case( Box::new($ex), Arm { pat: Box::new($pat1), expr: Box::new($arm1), }, Arm { pat: Box::new($pat2), expr: Box::new($arm2), }, ) }; } macro_rules! inj { (l; $ex:expr, $ty:expr) => { Expr::Inj(LR::Left, Box::new($ex), Box::new($ty)) }; (r; $ex:expr, $ty:expr) => { Expr::Inj(LR::Right, Box::new($ex), Box::new($ty)) }; } macro_rules! proj { (l; $ex:expr) => { Expr::Proj(LR::Left, Box::new($ex)) }; (r; $ex:expr) => { Expr::Proj(LR::Right, Box::new($ex)) }; } macro_rules! pair { ($a:expr, $b:expr) => { Expr::Pair(Box::new($a), Box::new($b)) }; } macro_rules! sum { ($a:expr, $b:expr) => { Type::Sum(Box::new($a), Box::new($b)) }; } macro_rules! product { ($a:expr, $b:expr) => { Type::Product(Box::new($a), Box::new($b)) }; } fn ty_display(ty: &Type) -> String { use std::collections::HashMap; let mut map = HashMap::new(); fn walk(ty: &Type, map: &mut HashMap) -> String { let nc = ('a' as u8 + map.len() as u8) as char; let vc = ('A' as u8 + map.len() as u8) as char; match ty { Type::Unit => "()".into(), Type::Bool => "bool".into(), Type::Int => "int".into(), Type::Arrow(a, b) => format!("({}->{})", walk(a, map), walk(b, map)), Type::Univ(k, ty) => format!("forall {}. 
{}", vc, walk(ty, map)), Type::Exist(idx) => format!("{}", map.entry(*idx).or_insert(nc)), Type::Var(idx) => format!("{}", map.entry(0xdeadbeef + *idx).or_insert(vc)), Type::Sum(a, b) => format!("{} + {}", walk(a, map), walk(b, map)), Type::Product(a, b) => format!("({} x {})", walk(a, map), walk(b, map)), _ => "".into(), } } walk(ty, &mut map) } fn expr_display(ex: &Expr) -> String { use std::collections::HashMap; let mut map = HashMap::new(); fn walk(ex: &Expr, map: &mut HashMap) -> String { match ex { Expr::Unit => "()".into(), Expr::True => "true".into(), Expr::False => "false".into(), Expr::Int(i) => format!("{}", i), Expr::If(e1, e2, e3) => format!("if {} then {} else {}", walk(e1, map), walk(e2, map), walk(e3, map)), Expr::App(a, b) => format!("({} {})", walk(a, map), walk(b, map)), Expr::Abs(body) => { let vc = ('a' as u8 + map.len() as u8) as char; let vc = *map.entry(map.len()).or_insert(vc); format!("(\\{}. {})", vc, walk(body, map)) } Expr::Var(idx) => { let i = map.len() - (*idx + 1); let vc = ('a' as u8 + i as u8) as char; format!("{}", map.entry(i).or_insert(vc)) } Expr::Ann(e, ty) => format!("<{} : {}>", walk(e, map), ty_display(ty)), Expr::Inj(LR::Left, e, ty) => format!("inl {}", walk(e, map)), Expr::Inj(LR::Right, e, ty) => format!("inr {}", walk(e, map)), Expr::Case(e, la, ra) => format!( "case {} of {} => {}, {} => {}", walk(e, map), walk(&la.pat, map), walk(&la.expr, map), walk(&ra.pat, map), walk(&ra.expr, map) ), Expr::Proj(LR::Left, e) => format!("{}.0", walk(e, map)), Expr::Proj(LR::Right, e) => format!("{}.1", walk(e, map)), Expr::Pair(a, b) => format!("({},{})", walk(a, map), walk(b, map)), _ => "".into(), } } walk(ex, &mut map) } impl std::fmt::Display for Type { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}", ty_display(self)) } } impl std::fmt::Display for Expr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}", expr_display(self)) } } 
================================================ FILE: x1_bidir/src/main.rs ================================================ //! "Complete and Easy Bidirectional Typechecking for Higher-Rank Polymorphism" //! Paper by J. Dunfield and N. Krishnaswami //! //! Also see very useful Haskell implementation: //! https://github.com/lexi-lambda/higher-rank/ #![allow(non_snake_case)] #[macro_use] mod helpers; #[derive(Clone, Debug, PartialEq)] enum Kind { Star, Arrow(Box, Box), } /// A source-level type #[derive(Clone, Debug, PartialEq)] enum Type { Unit, Int, Bool, /// A type variable Var(usize), /// The type of functions Arrow(Box, Box), /// Existential type variable that can be instantiated to a monotype Exist(usize), /// Universally quantified type, forall. A Univ(Box, Box), /// Class left/right sum type Sum(Box, Box), /// Simple pair type Product(Box, Box), Abs(Box, Box), App(Box, Box), } impl Type { fn monotype(&self) -> bool { match &self { Type::Univ(_, _) => false, Type::Arrow(t1, t2) => t1.monotype() && t2.monotype(), _ => true, } } /// Collect the free existential variables of the type fn freevars(&self) -> Vec { fn walk(ty: &Type, vec: &mut Vec) { match ty { Type::Unit | Type::Int | Type::Bool | Type::Var(_) => {} Type::Exist(v) => vec.push(*v), Type::Arrow(a, b) => { walk(a, vec); walk(b, vec); } Type::Sum(a, b) => { walk(a, vec); walk(b, vec); } Type::Product(a, b) => { walk(a, vec); walk(b, vec); } Type::Univ(k, a) => walk(a, vec), Type::Abs(k, a) => walk(a, vec), Type::App(a, b) => { walk(a, vec); walk(b, vec); } } } let mut v = Vec::new(); walk(self, &mut v); v } /// Perform de Bruijn shifting, algorithm from TAPL fn shift(&mut self, s: isize) { fn walk(t: &mut Type, c: usize, s: isize) { match t { Type::Var(n) if *n >= c => *n = (*n as isize + s) as usize, Type::Arrow(a, b) => { walk(a, c, s); walk(b, c, s); } Type::Sum(a, b) => { walk(a, c, s); walk(b, c, s); } Type::Product(a, b) => { walk(a, c, s); walk(b, c, s); } Type::Univ(k, a) => walk(a, c + 1, s), 
Type::Abs(k, a) => walk(a, c + 1, s), Type::App(a, b) => { walk(a, c, s); walk(b, c, s); } Type::Unit | Type::Int | Type::Bool | Type::Var(_) => {} Type::Exist(_) => {} } } walk(self, 0, s); } /// Perform subsitution of type `s` into self, algorithm from TAPL fn subst(&mut self, s: &mut Type) { fn walk(t: &mut Type, c: usize, f: &F) { match t { Type::Var(n) if *n == c => f(t, c), Type::Arrow(a, b) => { walk(a, c, f); walk(b, c, f); } Type::Sum(a, b) => { walk(a, c, f); walk(b, c, f); } Type::Product(a, b) => { walk(a, c, f); walk(b, c, f); } Type::Univ(k, a) => walk(a, c + 1, f), Type::Abs(k, a) => walk(a, c + 1, f), Type::App(a, b) => { walk(a, c, f); walk(b, c, f); } Type::Unit | Type::Int | Type::Bool | Type::Var(_) => {} Type::Exist(_) => {} } } s.shift(1); walk(self, 0, &|f, c| { let mut s = s.clone(); s.shift(c as isize); *f = s }); self.shift(-1); } } #[derive(Copy, Clone, Debug, PartialEq)] enum LR { Left, Right, } /// An expression in our simply typed lambda calculus #[derive(Clone, Debug, PartialEq)] enum Expr { /// The unit expression, () Unit, True, False, If(Box, Box, Box), Int(usize), /// A term variable, given in de Bruijn notation Var(usize), /// A lambda abstraction, with it's body. (\x. body) Abs(Box), /// Application (e1 e2) App(Box, Box), /// Explicit type annotation of a term, (x : A) Ann(Box, Box), /// Injection left/right into a sum type x1 Inj(LR, Box, Box), /// Simplified case expr Case(Box, Arm, Arm), /// Introduction of a pair Pair(Box, Box), /// Projection left/right from a pair Proj(LR, Box), } #[derive(Clone, Debug, PartialEq)] struct Arm { pat: Box, expr: Box, } /// An element in the typing context #[derive(Clone, Debug, PartialEq)] enum Element { /// Universal type variable Var(Kind), /// Term variable typing x : A. 
We differ from the paper in that we use
/// de Bruijn indices for variables, so we don't need to mark which var
/// this annotation belongs to - it always belongs to the innermost binding (idx 0)
/// and we will find this by traversing the stack
Ann(Type),
/// Unsolved existential type variable
Exist(usize),
/// Existential type variable that has been solved
/// to some monotype
Solved(usize, Type),
/// I am actually unsure if we really need a marker, due to how we structure
/// scoping, see `with_scope` method.
Marker(usize),
}

// NOTE(review): generic parameters below appear stripped by extraction
// (`Vec` was presumably `Vec<Element>`) -- confirm against the original file.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Context {
    /// We model the algorithmic context as a simple stack of elements
    ctx: Vec,
    /// We assign fresh exist. variables a unique, strictly increasing number
    ev: usize,
}

impl Context {
    /// Generate a fresh identifier
    // Monotonically increasing counter: identifiers are never reused, so an
    // existential solved in one scope cannot collide with one minted later.
    fn fresh_ev(&mut self) -> usize {
        let e = self.ev;
        self.ev += 1;
        e
    }

    /// Requires a mutable reference to self because we need to push/pop onto the stack
    /// in the case of universally quantified variables.
However, this can be considered /// mostly immutable, since self should be equal before and after the call fn well_formed(&mut self, ty: &Type) -> bool { match ty { Type::Exist(alpha) => self.ctx.contains(&Element::Exist(*alpha)) || self.find_solved(*alpha).is_some(), Type::Univ(k, alpha) => self.with_scope(Element::Var(*k.clone()), |f| f.well_formed(&alpha)), Type::Var(idx) => self.find_type_var(*idx).is_some(), Type::Arrow(a, b) => self.well_formed(&a) && self.well_formed(&b), Type::Sum(a, b) => self.well_formed(&a) && self.well_formed(&b), Type::Product(a, b) => self.well_formed(&a) && self.well_formed(&b), Type::Abs(k, a) => self.with_scope(Element::Var(*k.clone()), |f| f.well_formed(&a)), Type::App(a, b) => self.well_formed(&a) && self.well_formed(&b), Type::Unit | Type::Int | Type::Bool => true, } } fn check_wf(&mut self, ty: &Type) -> Result { if self.well_formed(ty) { Ok(true) } else { Err(format!("Type {:?} is not well formed!", ty)) } } // Pop off any stack growth incurred from calling `f` fn with_scope T>(&mut self, e: Element, mut f: F) -> T { self.ctx.push(e.clone()); let t = f(self); while let Some(elem) = self.ctx.pop() { if elem == e { break; } } t } /// Apply the context to a type, replacing any solved existential variables /// in the context onto the type, if it contains a matching existential fn apply(&self, ty: Type) -> Type { match ty { Type::Unit | Type::Int | Type::Bool | Type::Var(_) => ty, Type::Arrow(a, b) => Type::Arrow(Box::new(self.apply(*a)), Box::new(self.apply(*b))), Type::Sum(a, b) => Type::Sum(Box::new(self.apply(*a)), Box::new(self.apply(*b))), Type::Product(a, b) => Type::Product(Box::new(self.apply(*a)), Box::new(self.apply(*b))), Type::Abs(k, a) => Type::Abs(k, Box::new(self.apply(*a))), Type::App(a, b) => Type::App(Box::new(self.apply(*a)), Box::new(self.apply(*b))), Type::Univ(k, ty) => Type::Univ(k, Box::new(self.apply(*ty))), Type::Exist(n) => { match self.find_solved(n) { // Apply to the solved variable also - this is 
important
// since we can have solved references deeper in the stack
Some(solved) => self.apply(solved.clone()),
None => ty,
}
}
}
}

/// Find the term annotation corresponding to de Bruijn index `idx`.
/// We traverse the stack in a reversed order, counting each annotation
/// we come across
fn find_annotation(&self, idx: usize) -> Option<&Type> {
    let mut ix = 0;
    for elem in self.ctx.iter().rev() {
        match &elem {
            Element::Ann(ty) => {
                // Only `Ann` entries count toward the de Bruijn index;
                // existentials/type vars interleaved on the stack are skipped.
                if ix == idx {
                    return Some(&ty);
                }
                ix += 1
            }
            _ => {}
        }
    }
    None
}

/// Find the universal *type variable* binding (its Kind) corresponding to
/// de Bruijn index `idx`, traversing the stack in reverse and counting each
/// `Element::Var` we come across.
// NOTE(review): the original doc comment here was copy-pasted from
// `find_annotation` ("term annotation"); this function looks up Kinds.
fn find_type_var(&self, idx: usize) -> Option<&Kind> {
    let mut ix = 0;
    for elem in self.ctx.iter().rev() {
        match &elem {
            Element::Var(k) => {
                if ix == idx {
                    return Some(&k);
                }
                ix += 1
            }
            _ => {}
        }
    }
    None
}

/// Find the monotype associated with a solved existential variable `alpha`
/// in the context, if it exists.
// Solved entries are unique per alpha (fresh_ev never reuses ids), so the
// first hit is the only hit.
fn find_solved(&self, alpha: usize) -> Option<&Type> {
    for elem in &self.ctx {
        match &elem {
            Element::Solved(n, ty) if *n == alpha => return Some(ty),
            _ => {}
        }
    }
    None
}

/// This is one of the more confusing parts of the paper, IMO.
We have to open
/// a 'hole' in the context, where we can replace/insert some arbitrary amount
/// of bindings where an unsolved existential (or marker, in the paper) was
/// previously located
// NOTE(review): the generic parameter list was mangled by extraction --
// presumably `fn splice_hole<F: FnOnce(&mut Vec<Element>)>`; confirm against
// the original source.
fn splice_hole)>(&mut self, exist: usize, f: F) -> Result<(), String> {
    let (l, r) = self.split_context(exist)?;
    // Run the caller's mutation at the hole, then stitch the tail back on.
    f(&mut l.ctx);
    l.ctx.extend(r);
    Ok(())
}

/// Split the context around the *last* occurrence of `Exist(exist)` (the
/// scan keeps overwriting `ret`), removing that element. Returns the
/// mutable left half and the owned tail that sat to the right of the hole.
fn split_context(&mut self, exist: usize) -> Result<(&mut Self, Vec), String> {
    let mut ret = None;
    for (idx, el) in self.ctx.iter().enumerate() {
        match el {
            Element::Exist(n) if *n == exist => ret = Some(idx),
            _ => {}
        }
    }
    let idx = ret.ok_or_else(|| format!("{} not bound in ctx", exist))?;
    let rest = self.ctx.split_off(idx + 1);
    // Drop the Exist(exist) element itself; callers re-push it (or a
    // Solved replacement) when they reform the context.
    self.ctx.pop();
    Ok((self, rest))
}

// NOTE(review): `kinding` is a stub -- it ignores `ty` and always answers
// Kind::Star, so the kind check in `beta_reduce` below can never actually
// reject a non-Star application. TODO: implement real kind synthesis.
fn kinding(&mut self, ty: &Type) -> Option { Some(Kind::Star) }

/// Normalize type-level applications: (\X::K. T) S  ~~>  [X:=S]T, applied
/// recursively until no App(Abs(..), ..) redexes remain.
fn beta_reduce(&mut self, ty: &mut Type) -> Result<(), String> {
    match ty {
        Type::Unit | Type::Int | Type::Bool | Type::Var(_) => Ok(()),
        Type::Exist(v) => Ok(()),
        Type::Arrow(a, b) => {
            self.beta_reduce(a)?;
            self.beta_reduce(b)
        }
        Type::Sum(a, b) => {
            self.beta_reduce(a)?;
            self.beta_reduce(b)
        }
        Type::Product(a, b) => {
            self.beta_reduce(a)?;
            self.beta_reduce(b)
        }
        Type::Univ(k, a) => self.with_scope(Element::Var(*k.clone()), |f| f.beta_reduce(a)),
        Type::Abs(k, a) => self.with_scope(Element::Var(*k.clone()), |f| f.beta_reduce(a)),
        Type::App(a, b) => {
            // Reduce both sides first so an Abs head can surface.
            self.beta_reduce(a)?;
            self.beta_reduce(b)?;
            if let Type::Abs(k, body) = a.as_mut() {
                match self.kinding(b) {
                    Some(arg_kind) => {
                        if &arg_kind == k.as_ref() {
                            body.subst(b);
                            *ty = *body.clone();
                            // Substitution may expose new redexes; recurse.
                            self.beta_reduce(ty)
                        } else {
                            Err(format!("kind mismatch"))
                        }
                    }
                    None => Err(format!("No kind!?")),
                }
            } else {
                Ok(())
            }
        }
    }
}

/// The subtyping judgement Γ |- A <: B. Both sides are beta-reduced first
/// so matching below works on normal forms.
fn subtype(&mut self, mut a: Type, mut b: Type) -> Result<(), String> {
    // NOTE(review): debug tracing left in; consider removing for release.
    println!("{:?} <: {:?}", a, b);
    self.beta_reduce(&mut a)?;
    self.beta_reduce(&mut b)?;
    println!("{:?} <: {:?}", a, b);
    use Type::*;
    match (a, b) {
        (Bool, Bool) => Ok(()),
        (Int, Int) => Ok(()),
        // Rule <: Unit
        (Unit, Unit) => Ok(()),
        // Rule <: Var
        (Var(a), Var(b)) if a == b
=> Ok(()),
// Rule <: Exvar
(Exist(a), Exist(b)) if a == b => Ok(()),
// Rule <: ->
// Contravariant in the argument (b1 <: a1), covariant in the result.
(Arrow(a1, a2), Arrow(b1, b2)) => {
    self.subtype(*b1, *a1)?;
    self.subtype(self.apply(*a2), self.apply(*b2))
}
(Sum(l1, r1), Sum(l2, r2)) => {
    self.subtype(*l1, *l2)?;
    self.subtype(self.apply(*r1), self.apply(*r2))
}
(Product(l1, r1), Product(l2, r2)) => {
    self.subtype(*l1, *l2)?;
    self.subtype(self.apply(*r1), self.apply(*r2))
}
// Rule <: forall. L
// NOTE(review): the paper's <=forall.L rule opens `a` with the fresh
// existential (i.e. a_.subst(&mut Type::Exist(alpha))); here the body is
// substituted with a clone of the supertype `b`, while Exist(alpha) is
// pushed but never referenced in the substitution. Looks suspicious --
// confirm intent against the Dunfield & Krishnaswami rules.
(Univ(k, a), b) => {
    let alpha = self.fresh_ev();
    let mut a_ = *a;
    a_.subst(&mut b.clone());
    self.with_scope(Element::Marker(alpha), |f| {
        f.ctx.push(Element::Exist(alpha));
        f.subtype(a_.clone(), b.clone())
    })
}
// Rule <: forall. R
(a, Univ(k, b)) => {
    // let alpha = self.fresh_ev();
    self.with_scope(Element::Var(*k.clone()), |f| f.subtype(a.clone(), *b.clone()))
}
// Rule <: InstantiateL
// The `freevars` guard is the occurs check: it prevents solving alpha to
// a type that mentions alpha (an infinite type).
(Exist(alpha), a) if !a.freevars().contains(&alpha) => self.instantiateL(alpha, &a),
// Rule <: InstantiateR
(a, Exist(alpha)) if !a.freevars().contains(&alpha) => self.instantiateR(&a, alpha),
(a, b) => Err(format!("{:?} is not a subtype of {:?}", a, b)),
}
}

/// Instantiate `Exist(alpha) <: a` (rule InstantiateL and its sub-rules).
fn instantiateL(&mut self, alpha: usize, a: &Type) -> Result<(), String> {
    // We need to split our context into Γ1, alpha, Γ2 so that we
    // can ensure that alpha is a well formed Existential in Γ1, e.g.
    // that alpha appears in Γ1. This ensures that alpha is declared
    // "to the left" (outer scope) of the type `a`
    let (l, r) = self.split_context(alpha)?;
    if a.monotype() && l.well_formed(a) {
        // Rule InstLSolve: solve alpha directly to the monotype `a`.
        l.ctx.push(Element::Solved(alpha, a.clone()));
        l.ctx.extend(r);
        return Ok(());
    }
    // Okay, alpha is *not* well-formed, but that's okay. `split_context`
    // removed alpha from the context, so we add it back and then reform
    // Γ1, alpha, Γ2 into a full context again.
When we add it back // and reform the context depends on how we dispatch below match a { // InstLArr Type::Arrow(A1, A2) => { let a1 = l.fresh_ev(); let a2 = l.fresh_ev(); // Rather than reforming, then calling splice, we can just // directly push to `l`, since it currently points at the // hole corresponding to [^ ] l.ctx.push(Element::Exist(a2)); l.ctx.push(Element::Exist(a1)); l.ctx.push(Element::Solved( alpha, Type::Arrow(Box::new(Type::Exist(a1)), Box::new(Type::Exist(a2))), )); l.ctx.extend(r); self.instantiateR(A1, a1)?; let A2_ = self.apply(*A2.clone()); self.instantiateL(a2, &A2_) } // InstLAllR Type::Univ(k, beta) => { l.ctx.push(Element::Exist(alpha)); l.ctx.extend(r); self.with_scope(Element::Var(*k.clone()), |f| { f.instantiateL(alpha, &Type::Univ(k.clone(), beta.clone())) }) } // InstLReach Type::Exist(beta) => { // We need to ensure that beta only appears to the right of alpha, // e.g. that beta is well-formed in Γ2, so we make a temporary // context so that we can call the splice_hole method let mut gamma = Context { ctx: r, ev: 0 }; gamma.splice_hole(*beta, |ctx| ctx.push(Element::Solved(*beta, Type::Exist(alpha))))?; // As explained above, Exist(alpha) was popped off of `l` in the // `split_context` method, so we need to add it back in. 
We // now have some context such that Γ[alpha][beta=alpha] l.ctx.push(Element::Exist(alpha)); l.ctx.extend(gamma.ctx); Ok(()) } _ => Err(format!("Could not instantiate Exist({}) to {:?}", alpha, a)), } } fn instantiateR(&mut self, a: &Type, alpha: usize) -> Result<(), String> { let (l, r) = self.split_context(alpha)?; if a.monotype() && l.well_formed(a) { l.ctx.push(Element::Solved(alpha, a.clone())); l.ctx.extend(r); return Ok(()); } match a { // InstRArr Type::Arrow(A1, A2) => { let a1 = l.fresh_ev(); let a2 = l.fresh_ev(); l.ctx.push(Element::Exist(a2)); l.ctx.push(Element::Exist(a1)); l.ctx.push(Element::Solved( alpha, Type::Arrow(Box::new(Type::Exist(a1)), Box::new(Type::Exist(a2))), )); l.ctx.extend(r); // Much the same as InstLArr, except the following lines are swapped self.instantiateL(a1, &A1)?; let A2_ = self.apply(*A2.clone()); self.instantiateR(&A2_, a2) } // InstRAllL Type::Univ(k, beta) => { l.ctx.push(Element::Exist(alpha)); l.ctx.extend(r); let b = self.fresh_ev(); let mut beta_prime = *beta.clone(); beta_prime.subst(&mut Type::Exist(b)); self.with_scope(Element::Exist(b), |f| { f.instantiateR(&Type::Univ(k.clone(), Box::new(beta_prime.clone())), alpha) }) } // InstRReach Type::Exist(beta) => { let mut gamma = Context { ctx: r, ev: 0 }; gamma.splice_hole(*beta, |ctx| ctx.push(Element::Solved(*beta, Type::Exist(alpha))))?; l.ctx.push(Element::Exist(alpha)); l.ctx.extend(gamma.ctx); Ok(()) } _ => Err(format!("Could not instantiate Exist({}) to {:?}", alpha, a)), } } fn infer(&mut self, e: &Expr) -> Result { match e { Expr::True | Expr::False => Ok(Type::Bool), Expr::Int(_) => Ok(Type::Int), // Rule 1l=> Expr::Unit => Ok(Type::Unit), // Rule Anno Expr::Ann(x, ty) => { self.check_wf(ty)?; self.check(x, ty)?; Ok(*ty.clone()) } // Rule Var Expr::Var(x) => self.find_annotation(*x).cloned().ok_or(format!("unbound db {:?}", x)), // Rule ->I => Expr::Abs(e) => { let alpha = self.fresh_ev(); let beta = self.fresh_ev(); // Fresh existential var for function 
domain self.ctx.push(Element::Exist(alpha)); // And for codomain self.ctx.push(Element::Exist(beta)); // Check the function body against Beta self.with_scope(Element::Ann(Type::Exist(alpha)), |f| f.check(e, &Type::Exist(beta)))?; // alpha and beta stay on the stack, since they appear in the output type Ok(Type::Arrow(Box::new(Type::Exist(alpha)), Box::new(Type::Exist(beta)))) } // Rule ->E Expr::App(e1, e2) => { let a = self.infer(&e1)?; let a = self.apply(a); println!("{:?} {:?} {:?} {:?}", a, &e1, &e2, &self.ctx); self.infer_app(&a, e2) } Expr::If(e1, e2, e3) => { self.check(e1, &Type::Bool)?; let alpha = self.fresh_ev(); self.ctx.push(Element::Exist(alpha)); let exist = Type::Exist(alpha); self.check(&e2, &exist)?; self.check(&e3, &exist)?; Ok(exist) } Expr::Inj(lr, e, ty) => { self.check_wf(ty)?; let mut ty = self.apply(*ty.clone()); self.beta_reduce(&mut ty)?; match &ty { Type::Sum(l, r) => { match lr { LR::Left => self.check(e, l)?, LR::Right => self.check(e, r)?, } Ok(ty.clone()) } Type::Abs(_, _) => { let arg_ty = self.infer(e)?; self.infer(&Expr::Inj( *lr, e.clone(), Box::new(Type::App(Box::new(ty.clone()), Box::new(arg_ty))), )) } _ => Err(format!("#Expr::Inj {:?} is not a sum type!", ty)), } } Expr::Case(scrutinee, la, ra) => { let ty = self.infer(scrutinee)?; let ty = self.apply(ty); if let Type::Sum(left, right) = &ty { self.check(&la.pat, &ty)?; self.check(&ra.pat, &ty)?; fn bind_infer(ctx: &mut Context, arm: &Arm, left: &Type, right: &Type) -> Result { match arm.pat.as_ref() { Expr::Inj(lr, ex, _) => match (ex.as_ref(), lr) { (Expr::Var(_), LR::Left) => { ctx.with_scope(Element::Ann(left.clone()), |f| f.infer(&arm.expr)) } (Expr::Var(_), LR::Right) => { ctx.with_scope(Element::Ann(right.clone()), |f| f.infer(&arm.expr)) } _ => ctx.infer(&arm.expr), }, _ => Err(format!("Not injection expressions!")), } } let l = bind_infer(self, la, &left, &right)?; let r = bind_infer(self, ra, &left, &right)?; if l == r { Ok(l) } else { Err(format!("Case arms have 
different return types!")) } } else { Err(format!("{:?} is not a sum type!", ty)) } } Expr::Pair(a, b) => { let ta = self.infer(a)?; let ta = self.apply(ta); let tb = self.infer(b)?; let tb = self.apply(tb); Ok(Type::Product(Box::new(ta), Box::new(tb))) } Expr::Proj(lr, ex) => { let ty = self.infer(ex)?; let ty = self.apply(ty); match ty { Type::Product(left, right) => match lr { LR::Left => Ok(*left), LR::Right => Ok(*right), }, _ => Err(format!("{:?} is not a pair!", ex)), } } // _ => panic!("cant infer {:?}", e) } } fn infer_app(&mut self, ty: &Type, e2: &Expr) -> Result { match ty { // Rule alpha_hat App Type::Exist(alpha) => { let a1 = self.fresh_ev(); let a2 = self.fresh_ev(); self.splice_hole(*alpha, |ctx| { ctx.push(Element::Exist(a2)); ctx.push(Element::Exist(a1)); ctx.push(Element::Solved( *alpha, Type::Arrow(Box::new(Type::Exist(a1)), Box::new(Type::Exist(a2))), )); })?; self.check(e2, &Type::Exist(a1))?; Ok(Type::Exist(a2)) } // Rule ->App Type::Arrow(a, b) => { self.check(e2, a)?; Ok(*b.clone()) } // Rule forall. 
App Type::Univ(k, a) => {
    // Open the universal with a fresh existential, then retry the
    // application against the opened body.
    let alpha = self.fresh_ev();
    let mut a_prime = a.clone();
    a_prime.subst(&mut Type::Exist(alpha));
    self.ctx.push(Element::Exist(alpha));
    self.infer_app(&a_prime, e2)
}
// BUGFIX(review): message previously read "Cannot appl ty {:?} to expr {:?}"
// but was passed (e2, ty) -- the expression printed where the type belongs
// and vice versa. Arguments and wording fixed.
_ => Err(format!("Cannot apply type {:?} to expression {:?}", ty, e2)),
}
}

/// The checking judgement Γ |- e <= A: check expression `e` against type
/// `a`. Syntax-directed rules first; falls through to the subsumption rule
/// (infer the type, then subtype it against `a`) for everything else.
fn check(&mut self, e: &Expr, a: &Type) -> Result<(), String> {
    match (e, a) {
        (Expr::Int(_), Type::Int) => Ok(()),
        (Expr::False, Type::Bool) => Ok(()),
        (Expr::True, Type::Bool) => Ok(()),
        (Expr::If(e1, e2, e3), a) => {
            // Condition must be Bool; both branches check against `a`.
            self.check(&e1, &Type::Bool)?;
            self.check(&e2, a)?;
            self.check(&e3, a)
        }
        (Expr::Inj(lr, ex, tagged), Type::Sum(left, right)) => {
            // The annotation on the injection must agree with the sum we
            // are checking against, then the payload checks against the
            // corresponding component.
            self.subtype(*tagged.clone(), a.clone())?;
            match lr {
                LR::Left => self.check(ex, left),
                LR::Right => self.check(ex, right),
            }
        }
        (Expr::Pair(a, b), Type::Product(t1, t2)) => {
            self.check(a, t1)?;
            self.check(b, t2)
        }
        // Rule 1l
        (Expr::Unit, Type::Unit) => Ok(()),
        // Rule ->I
        (Expr::Abs(body), Type::Arrow(a1, a2)) => self.with_scope(Element::Ann(*a1.clone()), |f| f.check(body, a2)),
        // Rule forall. I
        (e, Type::Univ(k, ty)) => self.with_scope(Element::Var(*k.clone()), |f| f.check(e, &ty)),
        // Rule Sub
        (e, b) => {
            // Infer a type for `e`, apply the current solutions to both
            // sides, and require A <: B.
            let a = self.infer(e)?;
            let a = self.apply(a);
            let b = self.apply(b.clone());
            self.subtype(a, b)?;
            Ok(())
        }
    }
}
}

/// Top-level inference entry point: infer under a fresh context, apply the
/// final context's solutions, then beta-reduce type-level applications away.
// NOTE(review): return type restored to `Result<Type, String>` -- the
// extraction had stripped the generic parameters; `ctx.infer` yields a
// `Type` on success and all errors in this module are `String`s.
fn infer(ex: &Expr) -> Result<Type, String> {
    let mut ctx = Context::default();
    let inf = ctx.infer(ex)?;
    let mut ty = ctx.apply(inf);
    ctx.beta_reduce(&mut ty)?;
    Ok(ty)
}

fn main() {
    // \x. (x 1, x True) : forall A.
(A -> A) -> (Int, Bool) let h = abs!(pair!(app!(var!(0), Expr::Int(1)), app!(var!(0), Expr::True))); let h = ann!( h, arrow!( forall!(arrow!(Type::Var(0), Type::Var(0))), product!(Type::Int, Type::Bool) ) ); let g = app!(h, abs!(var!(0))); let ty = Type::Abs(Box::new(Kind::Star), Box::new(Type::Var(0))); let ty = Type::App(Box::new(ty), Box::new(Type::Int)); let f = ann!(Expr::Int(99), ty); let ty_opt = Type::Abs(Box::new(Kind::Star), Box::new(sum!(Type::Var(0), Type::Unit))); // \x. inl x of [\X::* => X + ()] @ 'A let some = abs!(inj!(l; var!(0), ty_opt.clone())); let some = ann!( some, forall!(arrow!( Type::Var(0), Type::App(Box::new(ty_opt.clone()), Box::new(Type::Var(0))) )) ); // let some = app!(some, Expr::Int(10)); println!("{} : {:?}", some, infer(&some)); } #[cfg(test)] mod test { use super::*; use helpers::*; #[test] fn identity() { let id = abs!(var!(0)); let id_ann = ann!(id.clone(), forall!(arrow!(Type::Var(0), Type::Var(0)))); let id_ex = arrow!(Type::Exist(0), Type::Exist(0)); let id_ann_ex = forall!(arrow!(Type::Var(0), Type::Var(0))); assert_eq!(infer(&id), Ok(id_ex)); assert_eq!(infer(&id_ann), Ok(id_ann_ex)); } #[test] fn application() { // \f. \g. \x. f (g x) // x: C // g: C -> A // f: A -> B // (A -> B) -> (C -> A) -> C -> B let t = abs!(abs!(abs!(app!(var!(2), app!(var!(1), var!(0)))))); let ty = infer(&t).unwrap(); assert_eq!(ty.to_string(), "((a->b)->((c->a)->(c->b)))"); } #[test] fn application2() { // \x. if x then 1 else 0 : Bool -> Int let f = abs!(ife!(var!(0), Expr::Int(1), Expr::Int(0))); // \f. \g. \x. 
f (g x) // : (e6 -> e7) -> ((e14 -> e6) -> (e14 -> e7)) // : (a -> b) -> ((c -> a) -> (c -> b)) let t1 = abs!(abs!(abs!(app!(var!(2), app!(var!(1), var!(0)))))); // : ((c -> bool) -> (c -> int)) let f2 = app!(t1.clone(), f); // : (c -> int) let f3 = app!(f2, abs!(Expr::True)); let f4 = app!(f3, Expr::Unit); assert_eq!(infer(&f4), Ok(Type::Int)) } #[test] fn sum_type() { let ty = sum!(Type::Unit, Type::Bool); let f_ty = arrow!(ty.clone(), Type::Int); // \x. case x of inl () => 0 | inr True => 1 let f = abs!( case!(var!(0), inj!(l; Expr::Unit, ty.clone()) => Expr::Int(0), inj!(r; Expr::True, ty.clone()) => Expr::Int(1)) ); let f = ann!(f, f_ty); infer(&f).unwrap(); let a = arrow!(forall!(arrow!(Type::Var(0), Type::Var(0))), ty.clone()); // \x. if True then inl (x ()) else inr (x False) : (forall A. (A -> A)) -> (() + Bool) let f = abs!(ife!( Expr::True, inj!(l; app!(var!(0), Expr::Unit), ty.clone()), inj!(r; app!(var!(0), Expr::False), ty.clone()) )); let f = ann!(f, a); infer(&f).unwrap(); } #[test] fn product_type() { // \x. r.0 : forall A. (A, A) -> A let f = ann!( abs!(proj!(r; var!(0))), forall!(arrow!(product!(Type::Var(0), Type::Var(0)), Type::Var(0))) ); // (\x. fst 0) (1, 1) let f = app!(f, pair!(Expr::Int(1), Expr::Int(1))); assert_eq!(infer(&f), Ok(Type::Int)); // \x. (x 1, x True) : (forall A. 
(A -> A)) -> (Int, Bool) let h = abs!(pair!(app!(var!(0), Expr::Int(1)), app!(var!(0), Expr::True))); let h = ann!( h, arrow!( forall!(arrow!(Type::Var(0), Type::Var(0))), product!(Type::Int, Type::Bool) ) ); let g = app!(h, abs!(var!(0))); assert_eq!(infer(&g), Ok(product!(Type::Int, Type::Bool))) } } ================================================ FILE: x2_dependent/Cargo.toml ================================================ [package] name = "dependent" version = "0.1.0" authors = ["Michael Lazear "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] ================================================ FILE: x2_dependent/src/main.rs ================================================ #[derive(Debug, Clone, PartialEq)] enum Term { Universe(usize), Nat, Var(usize), Int(usize), App(Box, Box), Abs(Box, Box), Pi(Box, Box), } impl Term { fn normal(&self) -> bool { match self { Term::App(_, _) => false, Term::Pi(a, b) | Term::Abs(a, b) => a.normal() && b.normal(), _ => true, } } fn whnf(&self) -> bool { match self { Term::App(_, _) => false, Term::Pi(a, _) | Term::Abs(a, _) => a.normal(), _ => true, } } fn subst(&mut self, mut t2: Term) { println!("subst {:?} {:?}", self, t2); let mut v = Visitor::new(); fn shift(v: &mut Visitor, t: &mut Term, s: isize) { v.visit(t, &|f, c| { if let Term::Var(n) = f { if *n >= c { *n = (*n as isize + s) as usize; } } }); } shift(&mut v, &mut t2, 1); v.visit(self, &|f, i| { let mut t = t2.clone(); let mut v = Visitor::new(); v.visit(&mut t, &|f, _| { if let Term::Var(n) = f { *n += i; } }); if let Term::Var(n) = f { if *n == i { *f = t; } } }); shift(&mut v, self, -1); } } #[derive(Default, Debug, Clone)] struct Context { binding: Vec, } #[derive(Debug, Clone)] enum Error { Unbound, NotPi(Term), Mismatch(Term, Term), } struct Visitor { cutoff: usize, } impl Visitor { fn new() -> Visitor { Visitor { cutoff: 0 } } fn visit(&mut self, term: &mut Term, f: &F) { match 
term {
    Term::Universe(_) => {}
    Term::Nat | Term::Int(_) => {}
    Term::Var(_) => {
        // Only variables reach the callback, together with the current
        // binder cutoff.
        f(term, self.cutoff);
    }
    Term::Pi(t1, t2) | Term::Abs(t1, t2) => {
        self.visit(t1, f);
        // The body sits under one additional binder.
        self.cutoff += 1;
        self.visit(t2, f);
        self.cutoff -= 1;
    }
    Term::App(t1, t2) => {
        self.visit(t1, f);
        self.visit(t2, f);
    }
}
}
}

impl Context {
    /// de Bruijn lookup: index 0 is the innermost (most recently pushed)
    /// binding; returns None when the index is out of range.
    fn get(&self, idx: usize) -> Option<&Term> {
        self.binding.get(self.binding.len().checked_sub(idx + 1)?)
    }

    // NOTE(review): generic parameter list mangled by extraction --
    // presumably `fn with_bind<F: FnOnce(&mut Self) -> T, T>`.
    /// Run `f` with `bind` temporarily pushed onto the binding stack.
    fn with_bind T>(&mut self, bind: Term, f: F) -> T {
        self.binding.push(bind);
        let r = f(self);
        self.binding.pop();
        r
    }

    /// Definitional equality: beta-normalize both terms, then compare
    /// structurally (de Bruijn representation makes this alpha-blind).
    fn equiv(&mut self, t1: &Term, t2: &Term) -> bool {
        let mut t1p = t1.clone();
        let mut t2p = t2.clone();
        while !t1p.normal() {
            t1p = beta_reduce(t1p);
        }
        while !t2p.normal() {
            t2p = beta_reduce(t2p);
        }
        t1p == t2p
    }

    // NOTE(review): the `Result` return type lost its parameters in
    // extraction; from the Ok/Err values this is `Result<Term, Error>`.
    fn type_of(&mut self, term: &Term) -> Result {
        println!("type of {:?}", term);
        match &term {
            Term::Universe(n) => Ok(Term::Universe(*n + 1)),
            Term::Var(i) => self.get(*i).cloned().ok_or(Error::Unbound),
            Term::Abs(S, t) => {
                // NOTE(review): this binds `k = type_of(S)` (the *type of*
                // the annotation) and returns Pi(k, T). For \x:S. t one
                // would normally put x:S in scope and produce Pi(S, T);
                // as written, (\x:Nat. x) gets Pi(U0, U0), which makes the
                // application in `main` report a Mismatch. Confirm intent.
                let k = self.type_of(&S)?;
                let T = self.with_bind(k.clone(), |f| f.type_of(&t))?;
                Ok(Term::Pi(Box::new(k), Box::new(T)))
            }
            Term::App(t1, t2) => {
                let ty1 = self.type_of(&t1)?;
                let ty2 = self.type_of(&t2)?;
                match ty1 {
                    Term::Pi(S, mut T) => {
                        if self.equiv(&S, &ty2) {
                            // Dependent elimination: substitute the
                            // argument into the Pi codomain.
                            T.subst(*t2.clone());
                            Ok(*T)
                        } else {
                            Err(Error::Mismatch(*S, ty2))
                        }
                    }
                    _ => Err(Error::NotPi(ty1)),
                }
            }
            Term::Pi(t1, t2) => {
                // NOTE(review): returns the codomain's type directly; a
                // universe rule would normally combine (e.g. max) the
                // universes of domain and codomain -- confirm.
                let ty1 = self.type_of(&t1)?;
                let ty2 = self.with_bind(ty1.clone(), |f| f.type_of(t2))?;
                Ok(ty2)
            }
            Term::Nat => Ok(Term::Universe(0)),
            Term::Int(_) => Ok(Term::Nat),
        }
    }
}

/// Small step beta reduction
fn beta_reduce(mut term: Term) -> Term {
    match term {
        Term::App(mut abs, arg) => match (abs.normal(), arg.normal()) {
            // Reduce the head to normal form first, then the argument.
            (false, _) => Term::App(Box::new(beta_reduce(*abs)), arg),
            (_, false) => Term::App(abs, Box::new(beta_reduce(*arg))),
            _ => match *abs {
                Term::Abs(_, mut body) => {
                    body.subst(*arg.clone());
                    *body
                }
                // Head is not an abstraction: the application is stuck.
                x => Term::App(Box::new(x), arg),
            },
        },
        Term::Abs(ty, body) => {
            if body.normal() {
                Term::Abs(ty, body)
            } else {
Term::Abs(ty, Box::new(beta_reduce(*body))) } } _ => term, } } fn main() { macro_rules! term { (Abs; $ex1:expr, $ex2:expr) => { Term::Abs(Box::new($ex1), Box::new($ex2)) }; (App; $ex1:expr, $ex2:expr) => { Term::App(Box::new($ex1), Box::new($ex2)) }; (Pi; $ex1:expr, $ex2:expr) => { Term::Pi(Box::new($ex1), Box::new($ex2)) }; (Var; $ex1:expr) => { Term::Var($ex1) }; (Star) => { Term::Universe(0) }; (Universe; $ex:expr) => { Term::Universe($ex) }; (Int; $ex1:expr) => { Term::Int($ex1) }; }; println!("Hello, world!"); let mut ctx = Context::default(); // let tm = term!(Abs; term!(Star), term!(App; Term::Nat, term!(Var; 0))); // let tm = term!(App; tm, term!(Int; 10)); // Πx: Nat -> x let mut tm = term!(Abs; Term::Nat, Term::Var(0)); tm = term!(App; tm, Term::Int(10)); // tm.subst(Term::Int(10)); dbg!(ctx.type_of(&tm)); // dbg!(&tm); tm = beta_reduce(tm); dbg!(&tm); }