From 5673514b6669d8d54a12f97eaa28ed516b7104d9 Mon Sep 17 00:00:00 2001 From: Shylock Hg Date: Sun, 26 Jan 2025 17:31:01 +0800 Subject: [PATCH] feat: support recusive type/function. --- Cargo.lock | 207 +++++++++++++++++ tigerc/Cargo.toml | 2 + tigerc/src/escape.rs | 8 +- tigerc/src/lib.rs | 10 + tigerc/src/main.rs | 12 +- tigerc/src/type_ast.rs | 19 +- tigerc/src/type_inference.rs | 435 ++++++++++++++++++++++++----------- tigerc/tests/test_compile.rs | 17 ++ 8 files changed, 573 insertions(+), 137 deletions(-) create mode 100644 tigerc/tests/test_compile.rs diff --git a/Cargo.lock b/Cargo.lock index cd89bca..7d8f6db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,102 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" +dependencies = [ + "anstyle", + "once_cell", + "windows-sys", +] + +[[package]] +name = "clap" +version = "4.5.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "769b0145982b4b48713e01ec42d61614425f27b7058bda7180a3a41f30104796" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" + +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + [[package]] name = "equivalent" version = "1.0.1" @@ -14,6 +110,12 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "indexmap" version = "2.7.0" @@ -24,6 +126,24 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + +[[package]] +name = "log" +version = "0.4.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" + +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + [[package]] name = "proc-macro2" version = "1.0.92" @@ -42,6 +162,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "syn" version = "2.0.90" @@ -57,7 +183,9 @@ dependencies = [ name = "tiger_c" version = "0.1.0" dependencies = [ + "clap", "indexmap", + "log", "tigerc-macros", "unicode-xid", ] @@ -82,3 +210,82 @@ name = "unicode-xid" version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" diff --git a/tigerc/Cargo.toml b/tigerc/Cargo.toml index 11e90ca..bdd1e14 100644 --- a/tigerc/Cargo.toml +++ b/tigerc/Cargo.toml @@ -7,3 +7,5 @@ edition = "2021" unicode-xid = "0.2.4" indexmap = "2.2.6" tigerc-macros = { path = "../tigerc-macros" } +clap = { version = "4.5.27", features = ["derive"] } +log = "0.4.25" diff --git a/tigerc/src/escape.rs b/tigerc/src/escape.rs index 0ca0477..1feab1e 100644 --- a/tigerc/src/escape.rs +++ b/tigerc/src/escape.rs @@ -5,7 +5,7 @@ use std::vec; use crate::ast; use crate::ident_pool::Symbol; -struct Escape { +pub struct Escape { e: Option>>, } @@ -152,7 +152,7 @@ impl Escape { ast::LeftValue::Variable(s) => { // escape!(depth, var_depth, s, var, self.e); if let Some(v) = var { - if depth > var_depth && dbg!(v) == s { + if depth > var_depth && v == s { if let Some(e) = &self.e { e.replace(true); } @@ -238,7 +238,7 @@ mod tests { let mut parser = Parser::new(Box::new(it)); let mut ast = parser.parse(); Escape::new().escape(&mut ast); - let escape = match dbg!(ast) { + let escape = match ast { ast::Ast::Expr(ast::Expr::Let(ast::Let { decls, .. })) => { let decl = decls.get(1).unwrap(); match decl { @@ -267,7 +267,7 @@ mod tests { let mut parser = Parser::new(Box::new(it)); let mut ast = parser.parse(); Escape::new().escape(&mut ast); - let escape = match dbg!(ast) { + let escape = match ast { ast::Ast::Expr(ast::Expr::Let(ast::Let { decls, .. })) => { let decl = decls.get(2).unwrap(); match decl { diff --git a/tigerc/src/lib.rs b/tigerc/src/lib.rs index 4914663..13bb238 100644 --- a/tigerc/src/lib.rs +++ b/tigerc/src/lib.rs @@ -8,3 +8,13 @@ pub mod symbol_table; pub mod tokenizer; pub mod type_ast; pub mod type_inference; + +pub fn compile_file(f: &str) { + let content = std::fs::read_to_string(f).unwrap(); + let it = tokenizer::tokenize(&content); + let mut parser = parser::Parser::new(Box::new(it)); + let mut e = parser.parse(); + escape::Escape::new().escape(&mut e); + let mut ti = type_inference::TypeInference::new(); + let _te = ti.infer(&e).unwrap(); +} diff --git a/tigerc/src/main.rs b/tigerc/src/main.rs index e7a11a9..b369423 100644 --- a/tigerc/src/main.rs +++ b/tigerc/src/main.rs @@ -1,3 +1,13 @@ +use clap::Parser as cParser; + +use tiger_c; + +#[derive(cParser, Debug)] +struct Args { + input: String, +} + fn main() { - println!("Hello, world!"); + let args = Args::parse(); + tiger_c::compile_file(&args.input); } diff --git a/tigerc/src/type_ast.rs b/tigerc/src/type_ast.rs index d11b983..d4d1559 100644 --- a/tigerc/src/type_ast.rs +++ b/tigerc/src/type_ast.rs @@ -10,9 +10,26 @@ pub enum Type { Record(Record), Array(Box), Function(Function), + // just place holder for recursive type/function Name(Symbol), } +impl Type { + // Can assign to other? + pub fn can_assign(&self, other: &Type) -> bool { + match self { + Type::Nil => { + if let Type::Record(_) = other { + true + } else { + false + } + } + _ => self == other, + } + } +} + #[derive(Debug, PartialEq, Eq, Clone)] pub struct Record { pub fields: IndexMap, @@ -200,7 +217,7 @@ pub struct FuncDecl { #[derive(Debug, PartialEq, Eq)] pub struct TyDecl { pub type_name: Symbol, - pub ty: Ty, + pub ty: Type, } #[derive(Debug, PartialEq, Eq)] diff --git a/tigerc/src/type_inference.rs b/tigerc/src/type_inference.rs index 9ebc153..ac94465 100644 --- a/tigerc/src/type_inference.rs +++ b/tigerc/src/type_inference.rs @@ -1,11 +1,12 @@ use std::result; use indexmap::IndexMap; +use log; -use crate::ast; use crate::ident_pool::{kw, Symbol}; use crate::symbol_table::SymbolTable; use crate::type_ast; +use crate::{ast, ident_pool}; #[derive(Debug)] pub struct InferError(pub String); @@ -27,17 +28,14 @@ struct SymbolValue { } macro_rules! infer_logical_op { - ($self:ident, $l:ident, $r:ident) => {{ + ($self:ident, $l:ident, $r:ident, $op:path) => {{ let ty_left = $self.infer_expr(&$l)?; let ty_right = $self.infer_expr(&$r)?; match (&ty_left.ty, &ty_right.ty) { (type_ast::Type::Int, type_ast::Type::Int) => { let ty = ty_left.ty.clone(); Ok(type_ast::TypeExpr { - expr: type_ast::TypeExpr_::Binary(type_ast::Binary::And( - Box::new(ty_left), - Box::new(ty_right), - )), + expr: type_ast::TypeExpr_::Binary($op(Box::new(ty_left), Box::new(ty_right))), ty, }) } @@ -50,7 +48,7 @@ macro_rules! infer_logical_op { } macro_rules! infer_compare_gl { - ($self:ident, $l:ident, $r:ident) => {{ + ($self:ident, $l:ident, $r:ident, $op:path) => {{ let ty_left = $self.infer_expr(&$l)?; let ty_right = $self.infer_expr(&$r)?; if ty_left.ty != ty_right.ty { @@ -58,10 +56,7 @@ macro_rules! infer_compare_gl { } match ty_left.ty { type_ast::Type::Int => Ok(type_ast::TypeExpr { - expr: type_ast::TypeExpr_::Binary(type_ast::Binary::Eq( - Box::new(ty_left), - Box::new(ty_right), - )), + expr: type_ast::TypeExpr_::Binary($op(Box::new(ty_left), Box::new(ty_right))), ty: type_ast::Type::Int, }), _ => Err(InferError::new(format!( @@ -73,24 +68,26 @@ macro_rules! infer_compare_gl { } macro_rules! infer_compare_eq_ne { - ($self:ident, $l:ident, $r:ident) => {{ + ($self:ident, $l:ident, $r:ident, $op:path) => {{ let ty_left = $self.infer_expr(&$l)?; let ty_right = $self.infer_expr(&$r)?; - if ty_left.ty != ty_right.ty { - return Err(InferError::new(format!("Expect same type for equality."))); + if !(ty_left.ty.can_assign(&ty_right.ty) || ty_right.ty.can_assign(&ty_left.ty)) { + return Err(InferError::new(format!( + "Expect same type for equality, but got {:?} and {:?}.", + ty_left.ty, ty_right.ty + ))); } match ty_left.ty { - type_ast::Type::Int | type_ast::Type::Record(..) | type_ast::Type::Array(..) => { - Ok(type_ast::TypeExpr { - expr: type_ast::TypeExpr_::Binary(type_ast::Binary::Eq( - Box::new(ty_left), - Box::new(ty_right), - )), - ty: type_ast::Type::Int, - }) - } + type_ast::Type::Int + | type_ast::Type::Str + | type_ast::Type::Nil + | type_ast::Type::Record(..) + | type_ast::Type::Array(..) => Ok(type_ast::TypeExpr { + expr: type_ast::TypeExpr_::Binary($op(Box::new(ty_left), Box::new(ty_right))), + ty: type_ast::Type::Int, + }), _ => Err(InferError::new(format!( - "Expect int/record/array type for (not)equality, got {:?}.", + "Expect int/string/record/array type for (not)equality, got {:?}.", ty_left.ty ))), } @@ -98,17 +95,14 @@ macro_rules! infer_compare_eq_ne { } macro_rules! infer_arithmetic { - ($self:ident, $l:ident, $r:ident) => {{ + ($self:ident, $l:ident, $r:ident, $op:path) => {{ let ty_l = $self.infer_expr(&$l)?; let ty_r = $self.infer_expr(&$r)?; match (&ty_l.ty, &ty_r.ty) { (type_ast::Type::Int, type_ast::Type::Int) => { let ty = ty_l.ty.clone(); Ok(type_ast::TypeExpr { - expr: type_ast::TypeExpr_::Binary(type_ast::Binary::Add( - Box::new(ty_l), - Box::new(ty_r), - )), + expr: type_ast::TypeExpr_::Binary($op(Box::new(ty_l), Box::new(ty_r))), ty, }) } @@ -142,9 +136,60 @@ impl TypeInference { ty: type_ast::Type::Str, }, ); + let mut var_table = SymbolTable::new(); + var_table.insert_symbol( + ident_pool::create_symbol("print"), + SymbolValue { + ty: type_ast::Type::Function(type_ast::Function { + name: ident_pool::create_symbol("print"), + params: vec![type_ast::Type::Str], + return_ty: Box::new(type_ast::Type::Nothing), + }), + }, + ); + var_table.insert_symbol( + ident_pool::create_symbol("getchar"), + SymbolValue { + ty: type_ast::Type::Function(type_ast::Function { + name: ident_pool::create_symbol("getchar"), + params: vec![], + return_ty: Box::new(type_ast::Type::Str), + }), + }, + ); + var_table.insert_symbol( + ident_pool::create_symbol("flush"), + SymbolValue { + ty: type_ast::Type::Function(type_ast::Function { + name: ident_pool::create_symbol("flush"), + params: vec![], + return_ty: Box::new(type_ast::Type::Nothing), + }), + }, + ); + var_table.insert_symbol( + ident_pool::create_symbol("ord"), + SymbolValue { + ty: type_ast::Type::Function(type_ast::Function { + name: ident_pool::create_symbol("ord"), + params: vec![type_ast::Type::Str], + return_ty: Box::new(type_ast::Type::Int), + }), + }, + ); + var_table.insert_symbol( + ident_pool::create_symbol("chr"), + SymbolValue { + ty: type_ast::Type::Function(type_ast::Function { + name: ident_pool::create_symbol("chr"), + params: vec![type_ast::Type::Int], + return_ty: Box::new(type_ast::Type::Str), + }), + }, + ); Self { type_symbol_table: ty_table, - variable_symbol_table: SymbolTable::new(), + variable_symbol_table: var_table, } } @@ -155,6 +200,7 @@ impl TypeInference { } } + // TODO maybe return void fn infer_decl(&mut self, decl: &ast::Decl) -> Result { match decl { ast::Decl::Type(t) => { @@ -164,21 +210,24 @@ impl TypeInference { type_ast::Type::Array(Box::new(sub)) } ast::Ty::Name(n) => self.type_symbol_table.get_symbol(n).unwrap().ty.clone(), - ast::Ty::Struct(s) => type_ast::Type::Record(type_ast::Record { - fields: self.infer_field_list(&s.0), - }), + ast::Ty::Struct(s) => { + // just placeholder for recursive type checking + self.type_symbol_table.insert_symbol( + t.type_name, + SymbolValue { + ty: type_ast::Type::Name(t.type_name), + }, + ); + type_ast::Type::Record(type_ast::Record { + fields: self.infer_field_list(&s.0), + }) + } }; self.type_symbol_table - .insert_symbol(t.type_name, SymbolValue { ty }); + .insert_symbol(t.type_name, SymbolValue { ty: ty.clone() }); Ok(type_ast::TypeDecl::Type(type_ast::TyDecl { type_name: t.type_name, - ty: match &t.ty { - ast::Ty::Array(s) => type_ast::Ty::Array(*s), - ast::Ty::Name(s) => type_ast::Ty::Name(*s), - ast::Ty::Struct(s) => { - type_ast::Ty::Struct(type_ast::TyStruct(self.infer_field_list(&s.0))) - } - }, + ty, })) } ast::Decl::Var(v) => { @@ -199,22 +248,11 @@ impl TypeInference { })) } ast::Decl::Func(f) => { - // scope of function body - self.variable_symbol_table.begin_scope(); let typed_args = self.infer_parameter_list(&f.args); - for (name, ty) in typed_args.iter() { - self.variable_symbol_table - .insert_symbol(*name, SymbolValue { ty: ty.ty.clone() }); - } - let typed_body = self.infer_expr(&f.body)?; - self.variable_symbol_table.end_scope(); let ret_ty = f .ret_ty .map(|t| self.type_symbol_table.get_symbol(&t).unwrap().ty.clone()) - .unwrap_or(typed_body.ty.clone()); - if typed_body.ty != ret_ty { - return Err(InferError::conflict_type(&ret_ty, &typed_body.ty)); - } + .unwrap_or(type_ast::Type::Nothing); self.variable_symbol_table.insert_symbol( f.name, SymbolValue { @@ -228,6 +266,18 @@ impl TypeInference { }), }, ); + // scope of function body + self.variable_symbol_table.begin_scope(); + for (name, ty) in typed_args.iter() { + self.variable_symbol_table + .insert_symbol(*name, SymbolValue { ty: ty.ty.clone() }); + } + let typed_body = self.infer_expr(&f.body)?; + self.variable_symbol_table.end_scope(); + + if typed_body.ty != ret_ty { + return Err(InferError::conflict_type(&ret_ty, &typed_body.ty)); + } Ok(type_ast::TypeDecl::Func(type_ast::FuncDecl { name: f.name, args: typed_args, @@ -298,40 +348,40 @@ impl TypeInference { }, ast::Expr::Binary(binary) => match binary { ast::Binary::Add(l, r) => { - infer_arithmetic!(self, l, r) + infer_arithmetic!(self, l, r, type_ast::Binary::Add) } ast::Binary::Minus(l, r) => { - infer_arithmetic!(self, l, r) + infer_arithmetic!(self, l, r, type_ast::Binary::Minus) } ast::Binary::Multiply(l, r) => { - infer_arithmetic!(self, l, r) + infer_arithmetic!(self, l, r, type_ast::Binary::Multiply) } ast::Binary::Divide(l, r) => { - infer_arithmetic!(self, l, r) + infer_arithmetic!(self, l, r, type_ast::Binary::Divide) } ast::Binary::Eq(l, r) => { - infer_compare_eq_ne!(self, l, r) + infer_compare_eq_ne!(self, l, r, type_ast::Binary::Eq) } ast::Binary::Ne(l, r) => { - infer_compare_eq_ne!(self, l, r) + infer_compare_eq_ne!(self, l, r, type_ast::Binary::Ne) } ast::Binary::Gt(l, r) => { - infer_compare_gl!(self, l, r) + infer_compare_gl!(self, l, r, type_ast::Binary::Gt) } ast::Binary::Ge(l, r) => { - infer_compare_gl!(self, l, r) + infer_compare_gl!(self, l, r, type_ast::Binary::Ge) } ast::Binary::Lt(l, r) => { - infer_compare_gl!(self, l, r) + infer_compare_gl!(self, l, r, type_ast::Binary::Lt) } ast::Binary::Le(l, r) => { - infer_compare_gl!(self, l, r) + infer_compare_gl!(self, l, r, type_ast::Binary::Le) } ast::Binary::And(l, r) => { - infer_logical_op!(self, l, r) + infer_logical_op!(self, l, r, type_ast::Binary::And) } ast::Binary::Or(l, r) => { - infer_logical_op!(self, l, r) + infer_logical_op!(self, l, r, type_ast::Binary::Or) } }, ast::Expr::FuncCall(f, args) => { @@ -342,39 +392,40 @@ impl TypeInference { .ty .clone(); - if let type_ast::Type::Function(tf) = ty_func { - if tf.params.len() != args.len() { - Err(InferError::new(format!( - "Expect {} arguments for function `{}`, got {}.", - tf.params.len(), - tf.name, - args.len(), - ))) - } else { - let mut ty_args = vec![]; - for a in args.iter() { - let ty_arg = self.infer_expr(a)?; - ty_args.push(ty_arg); - } + match ty_func { + type_ast::Type::Function(tf) => { + if tf.params.len() != args.len() { + Err(InferError::new(format!( + "Expect {} arguments for function `{}`, got {}.", + tf.params.len(), + tf.name, + args.len(), + ))) + } else { + let mut ty_args = vec![]; + for a in args.iter() { + let ty_arg = self.infer_expr(a)?; + ty_args.push(ty_arg); + } - for (p, a) in tf.params.iter().zip(ty_args.iter()) { - if p != &a.ty { - return Err(InferError::new(format!( - "Expect {:?} type for argument {} but got {:?}.", - p, f, a.ty - ))); + for (p, a) in tf.params.iter().zip(ty_args.iter()) { + if p != &a.ty { + return Err(InferError::new(format!( + "Expect {:?} type for argument {} but got {:?}.", + p, f, a.ty + ))); + } } + Ok(type_ast::TypeExpr { + expr: type_ast::TypeExpr_::FuncCall(tf.name, ty_args), + ty: *tf.return_ty, + }) } - Ok(type_ast::TypeExpr { - expr: type_ast::TypeExpr_::FuncCall(tf.name, ty_args), - ty: *tf.return_ty, - }) } - } else { - Err(InferError::new(format!( + _ => Err(InferError::new(format!( "Expect function type for function call, got {:?}.", ty_func - ))) + ))), } } ast::Expr::RecordExpr(record) => { @@ -397,10 +448,20 @@ impl TypeInference { } let ty_arg = self.infer_expr(&ee)?; if t != &ty_arg.ty { - return Err(InferError::new(format!( - "Expect {:?} type for field {}, got {:?}.", - t, f, ty_arg.ty, - ))); + if let type_ast::Type::Name(n) = t { + let t = self.type_symbol_table.get_symbol(n).unwrap().ty.clone(); + if &t != &ty_arg.ty { + return Err(InferError::new(format!( + "Expect {:?} type for field {}, got {:?}.", + t, f, ty_arg.ty, + ))); + } + } else { + return Err(InferError::new(format!( + "Expect {:?} type for field {}, got {:?}.", + t, f, ty_arg.ty, + ))); + } } init.push((*f, ty_arg)); } @@ -480,7 +541,7 @@ impl TypeInference { } let ty_then = self.infer_expr(then)?; let ty_el = self.infer_expr(el)?; - if ty_then.ty != ty_el.ty { + if !(ty_then.ty.can_assign(&ty_el.ty) || ty_el.ty.can_assign(&ty_then.ty)) { Err(InferError::new(format!( "Expect same type for then and else, but got {:?} and {:?}", ty_then.ty, ty_el.ty @@ -555,6 +616,12 @@ impl TypeInference { ty_upper.ty, ))); } + self.variable_symbol_table.insert_symbol( + for_.local, + SymbolValue { + ty: type_ast::Type::Int, + }, + ); let ty_body = self.infer_expr(&for_.body)?; if !matches![ty_body.ty, type_ast::Type::Nothing] { return Err(InferError::new(format!( @@ -634,6 +701,12 @@ impl TypeInference { let ty_left = self.infer_left_value(l)?; if let type_ast::Type::Record(r) = &ty_left.ty { let ty = r.fields.get(f).unwrap().clone(); + let ty = if let type_ast::Type::Name(_) = ty { + // the name typed field is just record itself + ty_left.ty.clone() + } else { + ty + }; Ok(type_ast::LeftValue { left: type_ast::LeftValue_::Field(Box::new(ty_left), *f), ty, @@ -860,40 +933,6 @@ mod tests { } } - #[test] - fn test_let_higher_order_function() { - let doc = " - let - function f(x: int) = - let function g(y: int) = x+y - in g - end - in f - end - "; - let it = tokenize(doc); - let mut parser = Parser::new(Box::new(it)); - let e = parser.parse(); - let mut ti = TypeInference::new(); - let te = ti.infer(&e).unwrap(); - if let type_ast::TypeAst::TypeExpr(te) = te { - assert_eq!( - te.ty, - type_ast::Type::Function(type_ast::Function { - name: ident_pool::create_symbol("f"), - params: vec![type_ast::Type::Int], - return_ty: Box::new(type_ast::Type::Function(type_ast::Function { - name: ident_pool::create_symbol("g"), - params: vec![type_ast::Type::Int], - return_ty: Box::new(type_ast::Type::Int), - })), - }) - ); - } else { - panic!("Unexpected decl."); - } - } - #[test] #[should_panic] fn test_use_undefined_type() { @@ -921,4 +960,138 @@ mod tests { let mut ti = TypeInference::new(); let _te = ti.infer(&e).unwrap(); } + + #[test] + fn test_recursive_record() { + let doc = " + let type list = {num: int, rest: list} in + end + "; + let it = tokenize(doc); + let mut parser = Parser::new(Box::new(it)); + let e = parser.parse(); + let mut ti = TypeInference::new(); + let te = ti.infer(&e).unwrap(); + let expected = type_ast::TypeAst::TypeExpr(type_ast::TypeExpr { + ty: type_ast::Type::Nothing, + expr: type_ast::TypeExpr_::Let(type_ast::Let { + decls: vec![type_ast::TypeDecl::Type(type_ast::TyDecl { + type_name: ident_pool::create_symbol("list"), + ty: type_ast::Type::Record(type_ast::Record { + fields: vec![ + (ident_pool::create_symbol("num"), type_ast::Type::Int), + ( + ident_pool::create_symbol("rest"), + type_ast::Type::Name(ident_pool::create_symbol("list")), + ), + ] + .into_iter() + .collect::>(), + }), + })], + sequence: vec![], + }), + }); + assert_eq!(te, expected); + } + + #[test] + fn test_recursive_function() { + let doc = " + let function f(x: int): int = + if x < 0 then 0 else f(x-1) + in + f(10) + end + "; + let it = tokenize(doc); + let mut parser = Parser::new(Box::new(it)); + let e = dbg!(parser.parse()); + let mut ti = TypeInference::new(); + let te = ti.infer(&e).unwrap(); + let expected = type_ast::TypeAst::TypeExpr(type_ast::TypeExpr { + ty: type_ast::Type::Int, + expr: type_ast::TypeExpr_::Let(type_ast::Let { + decls: vec![type_ast::TypeDecl::Func(type_ast::FuncDecl { + name: ident_pool::create_symbol("f"), + args: [( + ident_pool::create_symbol("x"), + type_ast::Parameter { + ty: type_ast::Type::Int, + escape: false, + }, + )] + .into_iter() + .collect::>(), + ret_ty: type_ast::Type::Int, + body: type_ast::TypeExpr { + ty: type_ast::Type::Int, + expr: type_ast::TypeExpr_::IfThenElse(type_ast::IfThenElseExpr { + condition: Box::new(type_ast::TypeExpr { + ty: type_ast::Type::Int, + expr: type_ast::TypeExpr_::Binary(type_ast::Binary::Lt( + Box::new(type_ast::TypeExpr { + ty: type_ast::Type::Int, + expr: type_ast::TypeExpr_::LeftValue(type_ast::LeftValue { + ty: type_ast::Type::Int, + left: type_ast::LeftValue_::Variable( + ident_pool::create_symbol("x"), + ), + }), + }), + Box::new(type_ast::TypeExpr { + ty: type_ast::Type::Int, + expr: type_ast::TypeExpr_::Literal(ast::Value::Int(0)), + }), + )), + }), + then: Box::new(type_ast::TypeExpr { + ty: type_ast::Type::Int, + expr: type_ast::TypeExpr_::Literal(ast::Value::Int(0)), + }), + el: Box::new(type_ast::TypeExpr { + ty: type_ast::Type::Int, + expr: type_ast::TypeExpr_::FuncCall( + ident_pool::create_symbol("f"), + vec![type_ast::TypeExpr { + ty: type_ast::Type::Int, + expr: type_ast::TypeExpr_::Binary(type_ast::Binary::Minus( + Box::new(type_ast::TypeExpr { + ty: type_ast::Type::Int, + expr: type_ast::TypeExpr_::LeftValue( + type_ast::LeftValue { + ty: type_ast::Type::Int, + left: type_ast::LeftValue_::Variable( + ident_pool::create_symbol("x"), + ), + }, + ), + }), + Box::new(type_ast::TypeExpr { + ty: type_ast::Type::Int, + expr: type_ast::TypeExpr_::Literal( + ast::Value::Int(1), + ), + }), + )), + }], + ), + }), + }), + }, + })], + sequence: vec![type_ast::TypeExpr { + ty: type_ast::Type::Int, + expr: type_ast::TypeExpr_::FuncCall( + ident_pool::create_symbol("f"), + vec![type_ast::TypeExpr { + ty: type_ast::Type::Int, + expr: type_ast::TypeExpr_::Literal(ast::Value::Int(10)), + }], + ), + }], + }), + }); + assert_eq!(te, expected); + } } diff --git a/tigerc/tests/test_compile.rs b/tigerc/tests/test_compile.rs new file mode 100644 index 0000000..a44087f --- /dev/null +++ b/tigerc/tests/test_compile.rs @@ -0,0 +1,17 @@ +#[cfg(test)] +mod test { + + use tiger_c::compile_file; + + #[test] + fn test_compile() { + let path = "tests/testcases/merge.tig"; + compile_file(path); + } + + #[test] + fn test_compile2() { + let path = "tests/testcases/queens.tig"; + compile_file(path); + } +}