diff --git a/compiler/codegen/src/compile.rs b/compiler/codegen/src/compile.rs index 6d3a1eb6d6..b75e977b5d 100644 --- a/compiler/codegen/src/compile.rs +++ b/compiler/codegen/src/compile.rs @@ -18,7 +18,10 @@ use num_complex::Complex64; use num_traits::ToPrimitive; use rustpython_ast::located::{self as located_ast, Located}; use rustpython_compiler_core::{ - bytecode::{self, Arg as OpArgMarker, CodeObject, ConstantData, Instruction, OpArg, OpArgType}, + bytecode::{ + self, Arg as OpArgMarker, CodeObject, ComparisonOperator, ConstantData, Instruction, OpArg, + OpArgType, + }, Mode, }; use rustpython_parser_core::source_code::{LineNumber, SourceLocation}; @@ -211,6 +214,12 @@ macro_rules! emit { }; } +struct PatternContext { + current_block: usize, + blocks: Vec, + allow_irrefutable: bool, +} + impl Compiler { fn new(opts: CompileOpts, source_path: String, code_name: String) -> Self { let module_code = ir::CodeInfo { @@ -1755,14 +1764,152 @@ impl Compiler { Ok(()) } + fn compile_pattern_value( + &mut self, + value: &located_ast::PatternMatchValue, + _pattern_context: &mut PatternContext, + ) -> CompileResult<()> { + self.compile_expression(&value.value)?; + emit!( + self, + Instruction::CompareOperation { + op: ComparisonOperator::Equal + } + ); + Ok(()) + } + + fn compile_pattern_as( + &mut self, + as_pattern: &located_ast::PatternMatchAs, + pattern_context: &mut PatternContext, + ) -> CompileResult<()> { + if as_pattern.pattern.is_none() && !pattern_context.allow_irrefutable { + // TODO: better error message + if let Some(_name) = as_pattern.name.as_ref() { + return Err( + self.error_loc(CodegenErrorType::InvalidMatchCase, as_pattern.location()) + ); + } + return Err(self.error_loc(CodegenErrorType::InvalidMatchCase, as_pattern.location())); + } + // Need to make a copy for (possibly) storing later: + emit!(self, Instruction::Duplicate); + if let Some(pattern) = &as_pattern.pattern { + self.compile_pattern_inner(pattern, pattern_context)?; + } + if let Some(name) = as_pattern.name.as_ref() { + self.store_name(name.as_str())?; + } else { + emit!(self, Instruction::Pop); + } + Ok(()) + } + + fn compile_pattern_inner( + &mut self, + pattern_type: &located_ast::Pattern, + pattern_context: &mut PatternContext, + ) -> CompileResult<()> { + match &pattern_type { + located_ast::Pattern::MatchValue(value) => { + self.compile_pattern_value(value, pattern_context) + } + located_ast::Pattern::MatchAs(as_pattern) => { + self.compile_pattern_as(as_pattern, pattern_context) + } + _ => { + eprintln!("not implemented pattern type: {pattern_type:?}"); + Err(self.error(CodegenErrorType::NotImplementedYet)) + } + } + } + + fn compile_pattern( + &mut self, + pattern_type: &located_ast::Pattern, + pattern_context: &mut PatternContext, + ) -> CompileResult<()> { + self.compile_pattern_inner(pattern_type, pattern_context)?; + emit!( + self, + Instruction::JumpIfFalse { + target: pattern_context.blocks[pattern_context.current_block + 1] + } + ); + Ok(()) + } + + fn compile_match_inner( + &mut self, + subject: &located_ast::Expr, + cases: &[located_ast::MatchCase], + pattern_context: &mut PatternContext, + ) -> CompileResult<()> { + self.compile_expression(subject)?; + pattern_context.blocks = std::iter::repeat_with(|| self.new_block()) + .take(cases.len() + 1) + .collect::>(); + let end_block = *pattern_context.blocks.last().unwrap(); + + let _match_case_type = cases.last().expect("cases is not empty"); + // TODO: get proper check for default case + // let has_default = match_case_type.pattern.is_match_as() && 1 < cases.len(); + let has_default = false; + for i in 0..cases.len() - (has_default as usize) { + self.switch_to_block(pattern_context.blocks[i]); + pattern_context.current_block = i; + pattern_context.allow_irrefutable = cases[i].guard.is_some() || i == cases.len() - 1; + let m = &cases[i]; + // Only copy the subject if we're *not* on the last case: + if i != cases.len() - has_default as usize - 1 { + emit!(self, Instruction::Duplicate); + } + self.compile_pattern(&m.pattern, pattern_context)?; + self.compile_statements(&m.body)?; + emit!(self, Instruction::Jump { target: end_block }); + } + // TODO: below code is not called and does not work + if has_default { + // A trailing "case _" is common, and lets us save a bit of redundant + // pushing and popping in the loop above: + let m = &cases.last().unwrap(); + self.switch_to_block(*pattern_context.blocks.last().unwrap()); + if cases.len() == 1 { + // No matches. Done with the subject: + emit!(self, Instruction::Pop); + } else { + // Show line coverage for default case (it doesn't create bytecode) + // emit!(self, Instruction::Nop); + } + self.compile_statements(&m.body)?; + } + + self.switch_to_block(end_block); + + let code = self.current_code_info(); + pattern_context + .blocks + .iter() + .zip(pattern_context.blocks.iter().skip(1)) + .for_each(|(a, b)| { + code.blocks[a.0 as usize].next = *b; + }); + Ok(()) + } + fn compile_match( &mut self, subject: &located_ast::Expr, cases: &[located_ast::MatchCase], ) -> CompileResult<()> { - eprintln!("match subject: {subject:?}"); - eprintln!("match cases: {cases:?}"); - Err(self.error(CodegenErrorType::NotImplementedYet)) + let mut pattern_context = PatternContext { + current_block: usize::MAX, + blocks: Vec::new(), + allow_irrefutable: false, + }; + self.compile_match_inner(subject, cases, &mut pattern_context)?; + Ok(()) } fn compile_chained_comparison( diff --git a/compiler/codegen/src/error.rs b/compiler/codegen/src/error.rs index 017f735105..27333992df 100644 --- a/compiler/codegen/src/error.rs +++ b/compiler/codegen/src/error.rs @@ -30,6 +30,8 @@ pub enum CodegenErrorType { TooManyStarUnpack, EmptyWithItems, EmptyWithBody, + DuplicateStore(String), + InvalidMatchCase, NotImplementedYet, // RustPython marker for unimplemented features } @@ -75,6 +77,12 @@ impl fmt::Display for CodegenErrorType { EmptyWithBody => { write!(f, "empty body on With") } + DuplicateStore(s) => { + write!(f, "duplicate store {s}") + } + InvalidMatchCase => { + write!(f, "invalid match case") + } NotImplementedYet => { write!(f, "RustPython does not implement this feature yet") } diff --git a/compiler/codegen/src/symboltable.rs b/compiler/codegen/src/symboltable.rs index 8522c82037..bbb134facf 100644 --- a/compiler/codegen/src/symboltable.rs +++ b/compiler/codegen/src/symboltable.rs @@ -886,11 +886,13 @@ impl SymbolTableBuilder { self.scan_statements(orelse)?; self.scan_statements(finalbody)?; } - Stmt::Match(StmtMatch { subject, .. }) => { - return Err(SymbolTableError { - error: "match expression is not implemented yet".to_owned(), - location: Some(subject.location()), - }); + Stmt::Match(StmtMatch { subject, cases, .. }) => { + self.scan_expression(subject, ExpressionContext::Load)?; + for case in cases { + // TODO: below + // self.scan_pattern(&case.pattern, ExpressionContext::Load)?; + self.scan_statements(&case.body)?; + } } Stmt::Raise(StmtRaise { exc, cause, .. }) => { if let Some(expression) = exc {