diff --git a/src/main.rs b/src/main.rs index f0c3389..6ebe6c3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ use group3::parser::*; use group3::scanner::Scanner; use group3::token::{Token, Tokens}; +use nom::{Err, Finish}; use pretty_trait::to_string; use std::error::Error; @@ -21,7 +22,7 @@ pub fn main() -> Result<(), Box> { let tokens = Tokens::new(&tokens); - match program_parser(tokens) { + match program_parser(tokens).finish() { Ok((t, ast)) => { if t.is_empty() { let max_line = Some(40); @@ -30,12 +31,25 @@ pub fn main() -> Result<(), Box> { println!("// Successfully parsed input file."); println!("{}", to_string(&ast.to_pretty(), max_line, tab_size)); } else { - println!("Parser did not complete, remaining tokens: {:?}", t); + match program_parser(t).finish() { + Ok(_) => { + // There was unparsed input, so we know parsing went wrong somewhere + unreachable!(); + } + Err(err) => { + eprintln!("{:#?}", err); + } + } } } - Err(e) => { - println!("{:?}", e); + Err(err) => { + if err.input.is_empty() { + eprintln!("{:#?}", err); + } else { + let head = &err.input[0]; + eprintln!("Error at line {}, column {}: expected {}, found {}", head.line, head.column, err.context.unwrap(), head.kind); + } } } diff --git a/src/parser/ast.rs b/src/parser/ast.rs index a9ca80a..a21f736 100644 --- a/src/parser/ast.rs +++ b/src/parser/ast.rs @@ -1,15 +1,30 @@ use num_bigint::BigUint; +use std::ops::{Deref, DerefMut}; + #[derive(PartialEq, Debug)] pub struct Program(pub Vec); +impl Deref for Program { + type Target = Vec; + fn deref(&self) -> &Vec { + &self.0 + } +} + +impl DerefMut for Program { + fn deref_mut(&mut self) -> &mut Vec { + &mut self.0 + } +} + #[derive(PartialEq, Debug)] pub enum Decl { VarDecl(VarDecl), FunDecl(FunDecl), } -#[derive(PartialEq, Debug, Clone)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Id(pub String); #[derive(PartialEq, Debug)] @@ -23,30 +38,22 @@ pub struct VarDecl { pub struct FunDecl { pub name: Id, pub params: Vec, - pub fun_type: Option, + pub fun_type: Option, pub statements: Vec, } -#[derive(PartialEq, Debug)] -pub struct FunType { - pub param_types: Vec, - pub return_type: ReturnType, -} - -#[derive(PartialEq, Debug)] -pub enum ReturnType { - Type(Type), - Void, -} - -#[derive(PartialEq, Debug)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum Type { Int, Bool, Char, + String, + Void, + Function(Vec, Box), Tuple(Box, Box), - Array(Box), - Generic(Id), + List(Box), + + Var(Id), } #[derive(PartialEq, Debug)] diff --git a/src/parser/error.rs b/src/parser/error.rs new file mode 100644 index 0000000..61746ce --- /dev/null +++ b/src/parser/error.rs @@ -0,0 +1,29 @@ +use nom::error::{ErrorKind, ParseError, ContextError}; + +#[derive(Debug, PartialEq)] +pub struct Error { + pub input: I, + pub kind: ErrorKind, + pub context: Option<&'static str>, +} + +impl ParseError for Error { + fn from_error_kind(input: I, kind: ErrorKind) -> Self { + Self { + input, + kind, + context: None, + } + } + + fn append(_input: I, _kind: ErrorKind, other: Self) -> Self { + other + } +} + +impl ContextError for Error { + fn add_context(_input: I, ctx: &'static str, mut other: Self) -> Self { + other.context.get_or_insert(ctx); + other + } +} diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 9afba2e..d773813 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1,94 +1,114 @@ mod ast; +mod tc; +pub mod error; pub mod pp; #[cfg(test)] mod test; use crate::token::*; -use ast::*; +use self::ast::*; +use self::error::Error; use nom::branch::alt; use nom::bytes::complete::take; use nom::combinator::{map, opt, verify}; -use nom::error::{Error, ErrorKind}; +use nom::error::{context, ErrorKind, ParseError}; use nom::multi::{fold_many0, many0, many1, many_till, separated_list0}; use nom::sequence::{delimited, pair, preceded, separated_pair, terminated, tuple}; -use nom::{Err, IResult}; +use nom::{Err, IResult, Parser}; macro_rules! token_parser ( - ($name:ident, $kind:expr) => ( - fn $name(tokens: Tokens) -> IResult { - verify(take(1usize), |t: &Tokens| { - t[0].kind == $kind} + ($name:ident, $kind:expr, $text:expr) => ( + fn $name(tokens: Tokens) -> IResult> { + context( + $text, + verify(take(1usize), |t: &Tokens| { + t[0].kind == $kind + }) )(tokens) } ) ); -token_parser!(var_parser, TokenKind::Var); -token_parser!(assignment_parser, TokenKind::Assignment); -token_parser!(semicolon_parser, TokenKind::Semicolon); -token_parser!(opening_brace_parser, TokenKind::OpeningBrace); -token_parser!(closing_brace_parser, TokenKind::ClosingBrace); +token_parser!(var_parser, TokenKind::Var, "var"); +token_parser!(assignment_parser, TokenKind::Assignment, "="); +token_parser!(semicolon_parser, TokenKind::Semicolon, ";"); +token_parser!(opening_brace_parser, TokenKind::OpeningBrace, "{"); +token_parser!(closing_brace_parser, TokenKind::ClosingBrace, "}"); -token_parser!(double_colon_parser, TokenKind::DoubleColon); -token_parser!(right_arrow_parser, TokenKind::RightArrow); +token_parser!(double_colon_parser, TokenKind::DoubleColon, "::"); +token_parser!(right_arrow_parser, TokenKind::RightArrow, "->"); -token_parser!(int_type_parser, TokenKind::IntType); -token_parser!(bool_type_parser, TokenKind::BoolType); -token_parser!(char_type_parser, TokenKind::CharType); +token_parser!(int_type_parser, TokenKind::IntType, "Int"); +token_parser!(bool_type_parser, TokenKind::BoolType, "Bool"); +token_parser!(char_type_parser, TokenKind::CharType, "Char"); +token_parser!(void_type_parser, TokenKind::VoidType, "Void"); -token_parser!(void_type_parser, TokenKind::VoidType); +token_parser!(opening_square_parser, TokenKind::OpeningSquare, "["); +token_parser!(closing_square_parser, TokenKind::ClosingSquare, "]"); -token_parser!(opening_square_parser, TokenKind::OpeningSquare); -token_parser!(closing_square_parser, TokenKind::ClosingSquare); - -token_parser!(if_parser, TokenKind::If); -token_parser!(else_parser, TokenKind::Else); -token_parser!(while_parser, TokenKind::While); -token_parser!(return_parser, TokenKind::Return); +token_parser!(if_parser, TokenKind::If, "if"); +token_parser!(else_parser, TokenKind::Else, "else"); +token_parser!(while_parser, TokenKind::While, "while"); +token_parser!(return_parser, TokenKind::Return, "return"); // Non-terminal Disjun -token_parser!(or_parser, TokenKind::Or); +token_parser!(or_parser, TokenKind::Or, "||"); // Non-terminal Conjun -token_parser!(and_parser, TokenKind::And); +token_parser!(and_parser, TokenKind::And, "&&"); // Non-terminal Compare -token_parser!(equals_parser, TokenKind::Equals); -token_parser!(not_equals_parser, TokenKind::NotEquals); -token_parser!(lt_parser, TokenKind::Lt); -token_parser!(le_parser, TokenKind::Le); -token_parser!(gt_parser, TokenKind::Gt); -token_parser!(ge_parser, TokenKind::Ge); +token_parser!(equals_parser, TokenKind::Equals, "=="); +token_parser!(not_equals_parser, TokenKind::NotEquals, "!="); +token_parser!(lt_parser, TokenKind::Lt, "<"); +token_parser!(le_parser, TokenKind::Le, "<="); +token_parser!(gt_parser, TokenKind::Gt, ">"); +token_parser!(ge_parser, TokenKind::Ge, ">="); // Non-terminal Concat -token_parser!(cons_parser, TokenKind::Cons); +token_parser!(cons_parser, TokenKind::Cons, ":"); // Non-terminal Term -token_parser!(plus_parser, TokenKind::Plus); -token_parser!(minus_parser, TokenKind::Minus); +token_parser!(plus_parser, TokenKind::Plus, "+"); +token_parser!(minus_parser, TokenKind::Minus, "-"); // Non-terminal Factor -token_parser!(times_parser, TokenKind::Times); -token_parser!(divide_parser, TokenKind::Divide); -token_parser!(modulo_parser, TokenKind::Modulo); +token_parser!(times_parser, TokenKind::Times, "*"); +token_parser!(divide_parser, TokenKind::Divide, "/"); +token_parser!(modulo_parser, TokenKind::Modulo, "%"); // Non-terminal Unary -token_parser!(not_parser, TokenKind::Not); +token_parser!(not_parser, TokenKind::Not, "!"); // Non-terminal Atom -token_parser!(opening_paren_parser, TokenKind::OpeningParen); -token_parser!(closing_paren_parser, TokenKind::ClosingParen); -token_parser!(empty_list_parser, TokenKind::EmptyList); -token_parser!(comma_parser, TokenKind::Comma); +token_parser!(opening_paren_parser, TokenKind::OpeningParen, "("); +token_parser!(closing_paren_parser, TokenKind::ClosingParen, ")"); +token_parser!(empty_list_parser, TokenKind::EmptyList, "[]"); +token_parser!(comma_parser, TokenKind::Comma, ","); // Non-terminal Field -token_parser!(hd_parser, TokenKind::Hd); -token_parser!(tl_parser, TokenKind::Tl); -token_parser!(fst_parser, TokenKind::Fst); -token_parser!(snd_parser, TokenKind::Snd); +token_parser!(hd_parser, TokenKind::Hd, ".hd"); +token_parser!(tl_parser, TokenKind::Tl, ".tl"); +token_parser!(fst_parser, TokenKind::Fst, ".fst"); +token_parser!(snd_parser, TokenKind::Snd, ".snd"); + +/// Turn a [`nom::Err::Error`] into a [`nom::Err::Failure`] if the given parser fails. +fn require<'a, P, I, O, E>(mut p: P) -> impl FnMut(I) -> IResult +where + E: ParseError, + P: Parser +{ + move |tokens: I| { + match p.parse(tokens) { + Ok(res) => Ok(res), + Err(Err::Error(e)) => Err(Err::Failure(e)), + Err(err) => Err(err), + } + } +} -pub fn program_parser(tokens: Tokens) -> IResult { +pub fn program_parser(tokens: Tokens) -> IResult> { map( many1(alt(( map(var_decl_parser, Decl::VarDecl), @@ -99,7 +119,7 @@ pub fn program_parser(tokens: Tokens) -> IResult { } /// Parses a tuple type, ie. "(" type "," type ")". -fn tuple_type_parser(tokens: Tokens) -> IResult { +fn tuple_type_parser(tokens: Tokens) -> IResult> { map( delimited( opening_paren_parser, @@ -110,33 +130,47 @@ fn tuple_type_parser(tokens: Tokens) -> IResult { )(tokens) } +fn function_type_parser(tokens: Tokens) -> IResult> { + map( + tuple(( + double_colon_parser, + many0(type_parser), + right_arrow_parser, + type_parser, + )), + |(_, param_types, _, return_type)| Type::Function(param_types, Box::new(return_type)), + )(tokens) +} + /// Parses an array type, ie. "[" type "]". -fn array_type_parser(tokens: Tokens) -> IResult { +fn array_type_parser(tokens: Tokens) -> IResult> { map( delimited(opening_square_parser, type_parser, closing_square_parser), - |t| Type::Array(Box::new(t)), + |t| Type::List(Box::new(t)), )(tokens) } /// Parses a type. -fn type_parser(tokens: Tokens) -> IResult { +fn type_parser(tokens: Tokens) -> IResult> { alt(( map(int_type_parser, |_| Type::Int), map(bool_type_parser, |_| Type::Bool), map(char_type_parser, |_| Type::Char), + map(void_type_parser, |_| Type::Void), tuple_type_parser, + function_type_parser, array_type_parser, - map(identifier_parser, Type::Generic), + map(identifier_parser, Type::Var), ))(tokens) } /// Parses the type of a variable declaration, either "var" or a type. -fn var_decl_type_parser(tokens: Tokens) -> IResult> { +fn var_decl_type_parser(tokens: Tokens) -> IResult, Error> { alt((map(var_parser, |_| None), map(type_parser, Some)))(tokens) } /// Parses a variable declaration. -fn var_decl_parser(tokens: Tokens) -> IResult { +fn var_decl_parser(tokens: Tokens) -> IResult> { map( tuple(( var_decl_type_parser, @@ -153,30 +187,8 @@ fn var_decl_parser(tokens: Tokens) -> IResult { )(tokens) } -fn fun_ret_type_parser(tokens: Tokens) -> IResult { - alt(( - map(void_type_parser, |_| ReturnType::Void), - map(type_parser, ReturnType::Type), - ))(tokens) -} - -fn fun_decl_type_parser(tokens: Tokens) -> IResult> { - opt(map( - tuple(( - double_colon_parser, - many0(type_parser), - right_arrow_parser, - fun_ret_type_parser, - )), - |(_, param_types, _, return_type)| FunType { - param_types, - return_type, - }, - ))(tokens) -} - /// Parses a function declaration. -fn fun_decl_parser(tokens: Tokens) -> IResult { +fn fun_decl_parser(tokens: Tokens) -> IResult> { map( tuple(( identifier_parser, @@ -185,7 +197,7 @@ fn fun_decl_parser(tokens: Tokens) -> IResult { separated_list0(comma_parser, identifier_parser), closing_paren_parser, ), - fun_decl_type_parser, + opt(type_parser), delimited( opening_brace_parser, many1(statement_parser), @@ -201,22 +213,22 @@ fn fun_decl_parser(tokens: Tokens) -> IResult { )(tokens) } -fn if_statement_parser(tokens: Tokens) -> IResult { +fn if_statement_parser(tokens: Tokens) -> IResult> { map( tuple(( if_parser, - delimited(opening_paren_parser, expr_parser, closing_paren_parser), + delimited(require(opening_paren_parser), expr_parser, require(closing_paren_parser)), delimited( - opening_brace_parser, + require(opening_brace_parser), many0(statement_parser), - closing_brace_parser, + require(closing_brace_parser), ), opt(preceded( else_parser, delimited( - opening_brace_parser, + require(opening_brace_parser), many0(statement_parser), - closing_brace_parser, + require(closing_brace_parser), ), )), )), @@ -230,10 +242,10 @@ fn if_statement_parser(tokens: Tokens) -> IResult { )(tokens) } -fn while_statement_parser(tokens: Tokens) -> IResult { +fn while_statement_parser(tokens: Tokens) -> IResult> { map( preceded( - while_parser, + require(while_parser), pair( delimited(opening_paren_parser, expr_parser, closing_paren_parser), delimited( @@ -247,7 +259,7 @@ fn while_statement_parser(tokens: Tokens) -> IResult { )(tokens) } -fn assign_statement_parser(tokens: Tokens) -> IResult { +fn assign_statement_parser(tokens: Tokens) -> IResult> { map( tuple(( pair(identifier_parser, field_parser), @@ -264,21 +276,21 @@ fn assign_statement_parser(tokens: Tokens) -> IResult { )(tokens) } -fn fun_call_statement_parser(tokens: Tokens) -> IResult { +fn fun_call_statement_parser(tokens: Tokens) -> IResult> { map( terminated(fun_call_parser, semicolon_parser), Statement::FunCall, )(tokens) } -fn return_statement_parser(tokens: Tokens) -> IResult { +fn return_statement_parser(tokens: Tokens) -> IResult> { map( delimited(return_parser, opt(expr_parser), semicolon_parser), Statement::Return, )(tokens) } -fn statement_parser(tokens: Tokens) -> IResult { +fn statement_parser(tokens: Tokens) -> IResult> { alt(( if_statement_parser, while_statement_parser, @@ -289,11 +301,11 @@ fn statement_parser(tokens: Tokens) -> IResult { ))(tokens) } -fn expr_parser(tokens: Tokens) -> IResult { +fn expr_parser(tokens: Tokens) -> IResult> { disjun_expr_parser(tokens) } -fn disjun_expr_parser(tokens: Tokens) -> IResult { +fn disjun_expr_parser(tokens: Tokens) -> IResult> { let (rest, start) = conjun_expr_parser(tokens)?; fold_many0( @@ -303,7 +315,7 @@ fn disjun_expr_parser(tokens: Tokens) -> IResult { )(rest) } -fn conjun_expr_parser(tokens: Tokens) -> IResult { +fn conjun_expr_parser(tokens: Tokens) -> IResult> { let (rest, start) = compare_expr_parser(tokens)?; fold_many0( @@ -313,7 +325,7 @@ fn conjun_expr_parser(tokens: Tokens) -> IResult { )(rest) } -fn compare_expr_parser(tokens: Tokens) -> IResult { +fn compare_expr_parser(tokens: Tokens) -> IResult> { let (rest, start) = concat_expr_parser(tokens)?; fold_many0( @@ -345,7 +357,7 @@ fn compare_expr_parser(tokens: Tokens) -> IResult { )(rest) } -fn concat_expr_parser(tokens: Tokens) -> IResult { +fn concat_expr_parser(tokens: Tokens) -> IResult> { let (rest, last) = term_expr_parser(tokens)?; let (tail, pairs) = many0(tuple((cons_parser, term_expr_parser)))(rest)?; @@ -365,7 +377,7 @@ fn concat_expr_parser(tokens: Tokens) -> IResult { } } -fn term_expr_parser(tokens: Tokens) -> IResult { +fn term_expr_parser(tokens: Tokens) -> IResult> { let (rest, start) = factor_expr_parser(tokens)?; fold_many0( @@ -383,7 +395,7 @@ fn term_expr_parser(tokens: Tokens) -> IResult { )(rest) } -fn factor_expr_parser(tokens: Tokens) -> IResult { +fn factor_expr_parser(tokens: Tokens) -> IResult> { let (rest, start) = unary_expr_parser(tokens)?; fold_many0( @@ -405,7 +417,7 @@ fn factor_expr_parser(tokens: Tokens) -> IResult { )(rest) } -fn unary_expr_parser(tokens: Tokens) -> IResult { +fn unary_expr_parser(tokens: Tokens) -> IResult> { map( many_till(alt((minus_parser, not_parser)), atom_expr_parser), |(unary_symbols, atom)| { @@ -426,7 +438,7 @@ fn unary_expr_parser(tokens: Tokens) -> IResult { )(tokens) } -fn atom_expr_parser(tokens: Tokens) -> IResult { +fn atom_expr_parser(tokens: Tokens) -> IResult> { alt(( // '(' Expr ')' // '(' Expr ',' Expr ')' @@ -442,7 +454,7 @@ fn atom_expr_parser(tokens: Tokens) -> IResult { ))(tokens) } -fn tuple_parenthesized_expr_parser(tokens: Tokens) -> IResult { +fn tuple_parenthesized_expr_parser(tokens: Tokens) -> IResult> { let (rest, expr) = preceded(opening_paren_parser, expr_parser)(tokens)?; let res = alt(( @@ -456,22 +468,22 @@ fn tuple_parenthesized_expr_parser(tokens: Tokens) -> IResult { res } -fn identifier_parser(tokens: Tokens) -> IResult { +fn identifier_parser(tokens: Tokens) -> IResult> { let (tail, mat) = take(1usize)(tokens)?; match mat[0].kind { TokenKind::Identifier(i) => Ok((tail, Id(i.to_string()))), - _ => Err(Err::Error(Error::new(tokens, ErrorKind::Tag))), + _ => Err(Err::Error(Error::from_error_kind(tokens, ErrorKind::Tag))), } } -fn variable_atom_parser(tokens: Tokens) -> IResult { +fn variable_atom_parser(tokens: Tokens) -> IResult> { map(tuple((identifier_parser, field_parser)), |(id, fields)| { Atom::Variable(Variable::new(id, fields)) })(tokens) } -fn field_parser(tokens: Tokens) -> IResult> { +fn field_parser(tokens: Tokens) -> IResult, Error> { fold_many0( alt((hd_parser, tl_parser, fst_parser, snd_parser)), Vec::new, @@ -490,7 +502,7 @@ fn field_parser(tokens: Tokens) -> IResult> { )(tokens) } -fn literal_atom_parser(tokens: Tokens) -> IResult { +fn literal_atom_parser(tokens: Tokens) -> IResult> { let (tail, mat) = take(1usize)(tokens)?; let atom = match mat[0].kind { @@ -498,13 +510,13 @@ fn literal_atom_parser(tokens: Tokens) -> IResult { TokenKind::Char(c) => Atom::CharLiteral(c), TokenKind::Integer(ref i) => Atom::IntLiteral(i.clone()), TokenKind::String(ref string) => Atom::StringLiteral(string.clone()), - _ => return Err(Err::Error(Error::new(tokens, ErrorKind::Tag))), + _ => return Err(Err::Error(Error::from_error_kind(tokens, ErrorKind::Tag))), }; Ok((tail, atom)) } -fn fun_call_parser(tokens: Tokens) -> IResult { +fn fun_call_parser(tokens: Tokens) -> IResult> { map( tuple(( identifier_parser, diff --git a/src/parser/pp.rs b/src/parser/pp.rs index d5a4211..d973559 100644 --- a/src/parser/pp.rs +++ b/src/parser/pp.rs @@ -84,12 +84,10 @@ impl PrettyPrintable for VarDecl { } } -impl PrettyPrintable for Option { - fn to_pretty(&self) -> Box { - match self { - Some(t) => Box::new(" :: ".join(t.to_pretty())), - None => Box::new("".join("")), - } +fn pretty_fun_type(ty: &Option) -> Box { + match ty { + Some(t) => Box::new(" :: ".join(t.to_pretty())), + None => Box::new("".join("")), } } @@ -99,7 +97,7 @@ impl PrettyPrintable for FunDecl { self.name.to_pretty().join("(").join( delimited(&", ", self.params.iter().map(Id::to_pretty)) .join(")") - .join(self.fun_type.to_pretty()) + .join(pretty_fun_type(&self.fun_type)) .join(Newline) .join("{") .join(block(delimited( @@ -112,28 +110,6 @@ impl PrettyPrintable for FunDecl { } } -impl PrettyPrintable for FunType { - fn to_pretty(&self) -> Box { - let delim = if self.param_types.is_empty() { "" } else { " " }; - - Box::new( - delimited(&" ", self.param_types.iter().map(Type::to_pretty)) - .join(delim) - .join("-> ") - .join(self.return_type.to_pretty()), - ) - } -} - -impl PrettyPrintable for ReturnType { - fn to_pretty(&self) -> Box { - match self { - ReturnType::Type(t) => t.to_pretty(), - ReturnType::Void => Box::new("Void"), - } - } -} - impl PrettyPrintable for Id { fn to_pretty(&self) -> Box { Box::new(self.0.to_string()) @@ -146,14 +122,26 @@ impl PrettyPrintable for Type { Type::Int => Box::new("Int"), Type::Bool => Box::new("Bool"), Type::Char => Box::new("Char"), + Type::Void => Box::new("Void"), + Type::String => Box::new("String"), Type::Tuple(t1, t2) => Box::new( "(".join(t1.to_pretty()) .join(",") .join(t2.to_pretty()) .join(")"), ), - Type::Array(t) => Box::new("[".join(t.to_pretty()).join("]")), - Type::Generic(id) => Box::new(id.to_pretty()), + Type::Function(args, return_type) => { + let delim = if args.is_empty() { "" } else { " " }; + + Box::new( + delimited(&" ", args.iter().map(Type::to_pretty)) + .join(delim) + .join("-> ") + .join(return_type.to_pretty()), + ) + } + Type::List(t) => Box::new("[".join(t.to_pretty()).join("]")), + Type::Var(id) => Box::new(id.to_pretty()), } } } diff --git a/src/parser/tc.rs b/src/parser/tc.rs new file mode 100644 index 0000000..b3a473b --- /dev/null +++ b/src/parser/tc.rs @@ -0,0 +1,399 @@ +use super::ast::*; + +use std::ops::{Deref, DerefMut}; +use std::{ + collections::{HashMap, HashSet}, + hash::Hash, +}; + +impl Type { + pub fn mgu(&self, other: &Self) -> Result { + match (self, other) { + (Type::Int, Type::Int) + | (Type::Bool, Type::Bool) + | (Type::Char, Type::Char) + | (Type::String, Type::String) => Ok(Subst::default()), + + (Type::Var(t1), t2) => t1.bind(t2), + (t1, Type::Var(t2)) => t2.bind(t1), + + (Type::Function(at1, rt1), Type::Function(at2, rt2)) => { + let s1 = rt1.mgu(&rt2)?; + let mut composed_subst = s1.clone(); + + if at1.len() != at2.len() { + return Err(String::from( + "Functions with different argument length cannot be unified", + )); + } + + let args_to_unify = at1.iter().zip(at2.iter()); + + for (t1, t2) in args_to_unify { + let s = t1.apply(&composed_subst).mgu(&t2.apply(&composed_subst))?; + composed_subst = composed_subst.compose(s); + } + + Ok(composed_subst) + } + + (Type::Tuple(t1, t2), Type::Tuple(t3, t4)) => { + let s1 = t1.mgu(t3)?; + let s2 = t2.apply(&s1).mgu(&t4.apply(&s1))?; + + Ok(s1.compose(s2)) + } + + (Type::List(t1), Type::List(t2)) => t1.mgu(&t2), + + (t1, t2) => Err(String::from(format!( + "Unification error: {:?} and {:?}", + t1, t2 + ))), + } + } +} + +#[derive(Default)] +pub struct VarGenerator { + counter: usize, +} + +impl VarGenerator { + pub fn new_var(&mut self) -> TypeVar { + let var = String::from(format!("t{}", self.counter)); + self.counter += 1; + + Id(var) + } +} + +// #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +// pub struct TypeVar(usize); + +type TypeVar = Id; + +impl TypeVar { + fn bind(&self, ty: &Type) -> Result { + if let Type::Var(v) = ty { + if v == self { + println!("Ehh"); + return Ok(Subst::default()); + } + } + + if ty.ftv().contains(self) { + return Err(String::from(format!( + "Occur check cannot construct infinite type: {:?}", + self + ))); + } + + let mut s = Subst::default(); + s.insert(self.clone(), ty.clone()); + + Ok(s) + } +} + +// Finite mapping from type variables to types. + +#[derive(Clone, Debug, Default)] +pub struct Subst(HashMap); + +impl Deref for Subst { + type Target = HashMap; + fn deref(&self) -> &HashMap { + &self.0 + } +} +impl DerefMut for Subst { + fn deref_mut(&mut self) -> &mut HashMap { + &mut self.0 + } +} + +impl Subst { + pub fn compose(&self, mut other: Subst) -> Subst { + let applied = other + .iter_mut() + .map(|(k, v)| (k.clone(), v.apply(self))) + .collect(); + + self.union(&Subst(applied)) + } + + fn union(&self, other: &Subst) -> Subst { + let mut unified = Subst::default(); + + for (k, v) in self.iter() { + unified.insert(k.clone(), v.clone()); + } + for (k, v) in other.iter() { + unified.insert(k.clone(), v.clone()); + } + + unified + } +} + +trait TypeInstance { + // Determines the free type variables of a type. + fn ftv(&self) -> HashSet; + + // Apply a substitution. + fn apply(&self, subst: &Subst) -> Self; +} + +impl<'a, T> TypeInstance for Vec +where + T: TypeInstance, +{ + // The free type variables of a vector of types is the union of the free type variables of each + // of the types in the vector. + fn ftv(&self) -> HashSet { + self.iter() + .map(|x| x.ftv()) + .fold(HashSet::new(), |set, x| set.union(&x).cloned().collect()) + } + + // To apply a substitution to a vector of types, just apply to each type in the vector. + fn apply(&self, s: &Subst) -> Vec { + self.iter().map(|x| x.apply(s)).collect() + } +} + +impl TypeInstance for Type { + fn ftv(&self) -> HashSet { + match self { + // Primitive types have no ftv. + Type::Int | Type::Bool | Type::Char | Type::String | Type::Void => HashSet::new(), + + // A TypeVar has one ftv: itself + Type::Var(t) => HashSet::from([t.to_owned()]), + + // A tuple has the unification of ftv of the inner types. + Type::Tuple(a, b) => a.ftv().union(&b.ftv()).cloned().collect(), + + Type::Function(arg_types, ret_type) => { + let mut ftv = ret_type.ftv(); + for at in arg_types.clone() { + ftv.extend(at.ftv()); + } + + ftv + } + + // A list has the inner type as ftv. + Type::List(inner) => inner.ftv(), + } + } + + fn apply(&self, substitution: &Subst) -> Type { + match self { + Type::Int | Type::Bool | Type::Char | Type::String | Type::Void => self.clone(), + + Type::Var(t) => substitution.get(t).cloned().unwrap_or(self.clone()), + + Type::Function(a, b) => Type::Function( + a.iter().map(|t| t.apply(substitution)).collect(), + Box::new(b.apply(substitution)), + ), + + Type::Tuple(a, b) => Type::Tuple( + Box::new(a.apply(substitution)), + Box::new(b.apply(substitution)), + ), + + Type::List(inner) => Type::List(Box::new(inner.apply(substitution))), + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct TypeScheme { + pub ty_vars: Vec, + pub ty: Type, +} + +impl TypeInstance for TypeScheme { + fn ftv(&self) -> HashSet { + self.ty + .ftv() + .difference(&self.ty_vars.iter().cloned().collect()) + .cloned() + .collect() + } + + fn apply(&self, s: &Subst) -> TypeScheme { + let mut filtered_substitution = s.clone(); + + filtered_substitution.retain(|k, _| !self.ty_vars.contains(k)); + + Self::new(self.ty_vars.clone(), self.ty.apply(&filtered_substitution)) + } +} + +impl TypeScheme { + pub fn new(vars: Vec, ty: Type) -> Self { + Self { + ty_vars: vars, + ty: ty, + } + } + + fn instantiate(&self, generator: &mut VarGenerator) -> Type { + let newvars = self.ty_vars.iter().map(|_| Type::Var(generator.new_var())); + self.ty + .apply(&Subst(self.ty_vars.iter().cloned().zip(newvars).collect())) + } +} + +#[derive(Default, Debug)] +pub struct TypeEnv(HashMap); + +impl TypeInstance for TypeEnv { + fn ftv(&self) -> HashSet { + self.0 + .values() + .map(|x| x.clone()) + .collect::>() + .ftv() + } + + fn apply(&self, s: &Subst) -> TypeEnv { + TypeEnv( + self.0 + .iter() + .map(|(k, v)| (k.clone(), v.apply(s))) + .collect(), + ) + } +} + +impl Deref for TypeEnv { + type Target = HashMap; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for TypeEnv { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl TypeEnv { + fn generalize(&self, ty: &Type) -> TypeScheme { + TypeScheme::new( + ty.ftv().difference(&self.ftv()).cloned().collect(), + ty.clone(), + ) + } + + // Function M from slides + fn ti( + &self, + expr: &Expr, + expected: &Type, + generator: &mut VarGenerator, + ) -> Result { + // match d { + // // VarDecl + // _ => Err(String::from("Unimplemented")), + // } + + match expr { + Expr::Add(e1, e2) + | Expr::Sub(e1, e2) + | Expr::Mul(e1, e2) + | Expr::Div(e1, e2) + | Expr::Mod(e1, e2) => { + let s1 = self.ti(e1, &Type::Int, generator)?; + self.apply(&s1); + + let s2 = self.ti(e2, &Type::Int, generator)?.compose(s1); + let expected_applied = expected.apply(&s2); + + Ok(expected_applied.mgu(&Type::Int)?.compose(s2)) + } + + Expr::Atom(ref a) => match a { + // Atom::Variable(v) => { + // // TODO: this doesnt work for field calls on a variable. needs to be extended. + // match self.get(&v.name) { + // Some(s) => Ok((Subst::default(), s.instantiate(generator))), + + // None => Err(String::from(format!("Unbounded variable: {:?}", v.name))), + // } + // }, + Atom::IntLiteral(_) => expected.mgu(&Type::Int), + + Atom::BoolLiteral(_) => expected.mgu(&Type::Bool), + + _ => Err(String::from("Unimplemented")), + }, + _ => Err(String::from("Unimplemented")), + } + } +} + +pub fn run(program: &mut Program) -> Result<(), String> { + let mut context = TypeEnv::default(); + let mut generator = VarGenerator::default(); + + for p in program.iter_mut() { + match p { + Decl::VarDecl(v) => { + let fresh = generator.new_var(); + + if let Ok(s) = context.ti(&v.value, &Type::Var(fresh.clone()), &mut generator) { + context.apply(&s); + context.insert(v.name.clone(), TypeScheme::new(vec![], s[&fresh].clone())); + } + + // match &v.var_type { + // // Type check required. + // Some(t) => { + + // } + // // Type inference required. + // None => { + + // } + // } + } + Decl::FunDecl(f) => return Err(String::from("Unimplemented")), + } + } + + println!("Context: {:?}", context); + + Ok(()) +} + +#[cfg(test)] + +mod test { + + use super::*; + + use crate::parser::*; + use crate::scanner::*; + use crate::token::*; + + #[test] + fn test_literal_inference() { + const PROGRAM: &str = r" + var myInt = 123 + 456; + "; + + let tokens: Vec = Scanner::new(PROGRAM).collect(); + let tokens = Tokens::new(&tokens); + + let (rest, mut program) = program_parser(tokens).unwrap(); + assert!(rest.is_empty()); + + assert_eq!(Ok(()), run(&mut program)); + } +} diff --git a/src/parser/test.rs b/src/parser/test.rs index 4678660..44174fd 100644 --- a/src/parser/test.rs +++ b/src/parser/test.rs @@ -491,99 +491,103 @@ fn test_var_decl_statement_parser() { let (rest, fun_decl) = fun_decl_parser(tokens).unwrap(); assert!(rest.is_empty()); - assert_eq!(fun_decl, FunDecl { - name: Id("f".to_owned()), - params: vec![Id("x".to_string())], - fun_type: None, - statements: vec![ - Statement::If(If { - cond: Expr::Atom(Atom::Variable(Variable::new(Id("x".to_string()), Vec::new()))), + assert_eq!( + fun_decl, + FunDecl { + name: Id("f".to_owned()), + params: vec![Id("x".to_string())], + fun_type: None, + statements: vec![Statement::If(If { + cond: Expr::Atom(Atom::Variable(Variable::new( + Id("x".to_string()), + Vec::new() + ))), if_true: vec![Statement::VarDecl(VarDecl { var_type: None, name: Id("y".to_string()), value: Expr::Atom(Atom::BoolLiteral(false)), })], if_false: Vec::new(), - }), - ] - }) + }),] + } + ) } -#[test] -fn test_fun_decl_type_parser() { - // No type - - let tokens: Vec = Scanner::new("").collect(); - let tokens = Tokens::new(&tokens); - - let (rest, fun_type) = fun_decl_type_parser(tokens).unwrap(); - - assert!(rest.is_empty()); - assert_eq!(fun_type, None); - - // No params, void return type - - let tokens: Vec = Scanner::new(":: -> Void").collect(); - let tokens = Tokens::new(&tokens); - - let (rest, fun_type) = fun_decl_type_parser(tokens).unwrap(); - - assert!(rest.is_empty()); - assert_eq!( - fun_type, - Some(FunType { - param_types: Vec::new(), - return_type: ReturnType::Void, - }) - ); - - // Params and return type - - let tokens: Vec = Scanner::new(":: Int Bool -> Char").collect(); - let tokens = Tokens::new(&tokens); - - let (rest, fun_type) = fun_decl_type_parser(tokens).unwrap(); - - assert!(rest.is_empty()); - assert_eq!( - fun_type, - Some(FunType { - param_types: vec![Type::Int, Type::Bool], - return_type: ReturnType::Type(Type::Char), - }) - ); - - // Complex types - - let tokens: Vec = Scanner::new(":: (a, [b]) [(Int,c)] -> ((a,b),c)").collect(); - let tokens = Tokens::new(&tokens); - - let (rest, fun_type) = fun_decl_type_parser(tokens).unwrap(); - - assert!(rest.is_empty()); - assert_eq!( - fun_type, - Some(FunType { - param_types: vec![ - Type::Tuple( - Box::new(Type::Generic(Id("a".to_string()))), - Box::new(Type::Array(Box::new(Type::Generic(Id("b".to_string()))))), - ), - Type::Array(Box::new(Type::Tuple( - Box::new(Type::Int), - Box::new(Type::Generic(Id("c".to_string()))), - )),), - ], - return_type: ReturnType::Type(Type::Tuple( - Box::new(Type::Tuple( - Box::new(Type::Generic(Id("a".to_string()))), - Box::new(Type::Generic(Id("b".to_string()))), - )), - Box::new(Type::Generic(Id("c".to_string()))), - )) - }) - ); -} +// #[test] +// fn test_fun_decl_type_parser() { +// // No type + +// let tokens: Vec = Scanner::new("").collect(); +// let tokens = Tokens::new(&tokens); + +// let (rest, fun_type) = fun_decl_type_parser(tokens).unwrap(); + +// assert!(rest.is_empty()); +// assert_eq!(fun_type, None); + +// // No params, void return type + +// let tokens: Vec = Scanner::new(":: -> Void").collect(); +// let tokens = Tokens::new(&tokens); + +// let (rest, fun_type) = fun_decl_type_parser(tokens).unwrap(); + +// assert!(rest.is_empty()); +// assert_eq!( +// fun_type, +// Some(FunType { +// param_types: Vec::new(), +// return_type: ReturnType::Void, +// }) +// ); + +// // Params and return type + +// let tokens: Vec = Scanner::new(":: Int Bool -> Char").collect(); +// let tokens = Tokens::new(&tokens); + +// let (rest, fun_type) = fun_decl_type_parser(tokens).unwrap(); + +// assert!(rest.is_empty()); +// assert_eq!( +// fun_type, +// Some(FunType { +// param_types: vec![Type::Int, Type::Bool], +// return_type: ReturnType::Type(Type::Char), +// }) +// ); + +// // Complex types + +// let tokens: Vec = Scanner::new(":: (a, [b]) [(Int,c)] -> ((a,b),c)").collect(); +// let tokens = Tokens::new(&tokens); + +// let (rest, fun_type) = fun_decl_type_parser(tokens).unwrap(); + +// assert!(rest.is_empty()); +// assert_eq!( +// fun_type, +// Some(FunType { +// param_types: vec![ +// Type::Tuple( +// Box::new(Type::Generic(Id("a".to_string()))), +// Box::new(Type::Array(Box::new(Type::Generic(Id("b".to_string()))))), +// ), +// Type::Array(Box::new(Type::Tuple( +// Box::new(Type::Int), +// Box::new(Type::Generic(Id("c".to_string()))), +// )),), +// ], +// return_type: ReturnType::Type(Type::Tuple( +// Box::new(Type::Tuple( +// Box::new(Type::Generic(Id("a".to_string()))), +// Box::new(Type::Generic(Id("b".to_string()))), +// )), +// Box::new(Type::Generic(Id("c".to_string()))), +// )) +// }) +// ); +// } #[test] fn test_fun_call_in_return() { @@ -673,10 +677,10 @@ fn test_fun_decl_parser() { FunDecl { name: Id("someFunction".to_string()), params: vec![Id("a".to_string()), Id("b".to_string())], - fun_type: Some(FunType { - param_types: vec![Type::Int, Type::Array(Box::new(Type::Int))], - return_type: ReturnType::Type(Type::Array(Box::new(Type::Int))), - }), + fun_type: Some(Type::Function( + vec![Type::Int, Type::List(Box::new(Type::Int))], + Box::new(Type::List(Box::new(Type::Int))) + )), statements: vec![ Statement::VarDecl(VarDecl { var_type: Some(Type::Int), @@ -700,8 +704,8 @@ fn test_type_parser() { "(Int,Bool)", Type::Tuple(Box::new(Type::Int), Box::new(Type::Bool)), ), - ("[Char]", Type::Array(Box::new(Type::Char))), - ("a", Type::Generic(Id("a".to_string()))), + ("[Char]", Type::List(Box::new(Type::Char))), + ("a", Type::Var(Id("a".to_string()))), // More complex ( "((Int,Bool),Char)", @@ -712,9 +716,9 @@ fn test_type_parser() { ), ( "[(a,b)]", - Type::Array(Box::new(Type::Tuple( - Box::new(Type::Generic(Id("a".to_string()))), - Box::new(Type::Generic(Id("b".to_string()))), + Type::List(Box::new(Type::Tuple( + Box::new(Type::Var(Id("a".to_string()))), + Box::new(Type::Var(Id("b".to_string()))), ))), ), ]; diff --git a/src/token.rs b/src/token.rs index 264673d..429c3fb 100644 --- a/src/token.rs +++ b/src/token.rs @@ -4,6 +4,7 @@ use nom::{InputIter, InputLength, InputTake, Needed, Slice}; use num_bigint::BigUint; +use std::fmt; use std::iter::Enumerate; use std::ops::{Range, RangeFrom, RangeFull, RangeTo}; @@ -407,3 +408,61 @@ impl<'a> TokenKind<'a> { } } } + +impl<'a> fmt::Display for TokenKind<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let s = match self { + Self::Plus => "+", + Self::Minus => "-", + Self::Divide => "/", + Self::Times => "*", + Self::Modulo => "%", + Self::Equals => "==", + Self::Lt => "<", + Self::Gt => ">", + Self::Le => "<=", + Self::Ge => ">=", + Self::NotEquals => "!=", + Self::And => "&&", + Self::Assignment => "=", + Self::Or => "||", + Self::Cons => ":", + Self::Not => "!", + Self::Integer(i) => return write!(f, "{i}"), + Self::Bool(b) => { + let s = if *b { "True" } else { "False" }; + return write!(f, "{s}"); + } + Self::Char(c) => return write!(f, "'{c}'"), + Self::String(s) => return write!(f, r#""{s}""#), + Self::Identifier(i) => i, + Self::Var => "var", + Self::If => "if", + Self::Else => "else", + Self::While => "while", + Self::Return => "return", + Self::IntType => "Int", + Self::BoolType => "Bool", + Self::CharType => "Char", + Self::VoidType => "Void", + Self::DoubleColon => "::", + Self::RightArrow => "->", + Self::Hd => ".hd", + Self::Tl => ".tl", + Self::Fst => ".fst", + Self::Snd => ".snd", + Self::EmptyList => "[]", + Self::Semicolon => ";", + Self::Comma => ",", + Self::OpeningParen => "(", + Self::ClosingParen => ")", + Self::OpeningBrace => "{", + Self::ClosingBrace => "}", + Self::OpeningSquare => "[", + Self::ClosingSquare => "]", + Self::Error(e) => return write!(f, "{e:?}"), + }; + + write!(f, "{s}") + } +}