From ecc18dc989aaa42f465f45594df6000f5af0a92d Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Fri, 10 May 2024 14:52:16 +0530 Subject: [PATCH] Added support & tests for subs attribute --- integration_tests/symbolics_05.py | 8 +++ src/libasr/pass/intrinsic_function_registry.h | 6 ++ src/libasr/pass/intrinsic_functions.h | 59 +++++++++++++++++++ src/libasr/pass/replace_symbolic.cpp | 17 ++++++ src/lpython/semantics/python_ast_to_asr.cpp | 8 +-- src/lpython/semantics/python_attribute_eval.h | 16 ++++- 6 files changed, 109 insertions(+), 5 deletions(-) diff --git a/integration_tests/symbolics_05.py b/integration_tests/symbolics_05.py index 46a6d39860..b503bbcdda 100644 --- a/integration_tests/symbolics_05.py +++ b/integration_tests/symbolics_05.py @@ -40,4 +40,12 @@ def test_operations(): assert(c.args[0] == x) assert(d.args[0] == x) + # test subs + b1: S = b.subs(x, y) + b1 = b1.subs(z, y) + assert(a.subs(x, y) == S(4)*y**S(2)) + assert(b1 == S(27)*y**S(3)) + assert(c.subs(x, y) == sin(y)) + assert(d.subs(x, z) == cos(z)) + test_operations() \ No newline at end of file diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index e81feeeabd..9d26a37954 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -148,6 +148,7 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(SymbolicInteger) INTRINSIC_NAME_CASE(SymbolicDiff) INTRINSIC_NAME_CASE(SymbolicExpand) + INTRINSIC_NAME_CASE(SymbolicSubs) INTRINSIC_NAME_CASE(SymbolicSin) INTRINSIC_NAME_CASE(SymbolicCos) INTRINSIC_NAME_CASE(SymbolicLog) @@ -435,6 +436,8 @@ namespace IntrinsicElementalFunctionRegistry { {nullptr, &SymbolicDiff::verify_args}}, {static_cast(IntrinsicElementalFunctions::SymbolicExpand), {nullptr, &SymbolicExpand::verify_args}}, + {static_cast(IntrinsicElementalFunctions::SymbolicSubs), + {nullptr, &SymbolicSubs::verify_args}}, {static_cast(IntrinsicElementalFunctions::SymbolicSin), {nullptr, &SymbolicSin::verify_args}}, {static_cast(IntrinsicElementalFunctions::SymbolicCos), @@ -724,6 +727,8 @@ namespace IntrinsicElementalFunctionRegistry { "SymbolicDiff"}, {static_cast(IntrinsicElementalFunctions::SymbolicExpand), "SymbolicExpand"}, + {static_cast(IntrinsicElementalFunctions::SymbolicSubs), + "SymbolicSubs"}, {static_cast(IntrinsicElementalFunctions::SymbolicSin), "SymbolicSin"}, {static_cast(IntrinsicElementalFunctions::SymbolicCos), @@ -889,6 +894,7 @@ namespace IntrinsicElementalFunctionRegistry { {"SymbolicInteger", {&SymbolicInteger::create_SymbolicInteger, &SymbolicInteger::eval_SymbolicInteger}}, {"diff", {&SymbolicDiff::create_SymbolicDiff, &SymbolicDiff::eval_SymbolicDiff}}, {"expand", {&SymbolicExpand::create_SymbolicExpand, &SymbolicExpand::eval_SymbolicExpand}}, + {"subs", {&SymbolicSubs::create_SymbolicSubs, &SymbolicSubs::eval_SymbolicSubs}}, {"SymbolicSin", {&SymbolicSin::create_SymbolicSin, &SymbolicSin::eval_SymbolicSin}}, {"SymbolicCos", {&SymbolicCos::create_SymbolicCos, &SymbolicCos::eval_SymbolicCos}}, {"SymbolicLog", {&SymbolicLog::create_SymbolicLog, &SymbolicLog::eval_SymbolicLog}}, diff --git a/src/libasr/pass/intrinsic_functions.h b/src/libasr/pass/intrinsic_functions.h index 23b5be711d..a644854e12 100644 --- a/src/libasr/pass/intrinsic_functions.h +++ b/src/libasr/pass/intrinsic_functions.h @@ -149,6 +149,7 @@ enum class IntrinsicElementalFunctions : int64_t { SymbolicInteger, SymbolicDiff, SymbolicExpand, + SymbolicSubs, SymbolicSin, SymbolicCos, SymbolicLog, @@ -5651,6 +5652,64 @@ create_symbolic_binary_macro(SymbolicDiv) create_symbolic_binary_macro(SymbolicPow) create_symbolic_binary_macro(SymbolicDiff) +#define create_symbolic_ternary_macro(X) \ +namespace X{ \ + static inline void verify_args(const ASR::IntrinsicElementalFunction_t& x, \ + diag::Diagnostics& diagnostics) { \ + ASRUtils::require_impl(x.n_args == 3, "Intrinsic function `"#X"` accepts" \ + "exactly 3 arguments", x.base.base.loc, diagnostics); \ + \ + ASR::ttype_t* arg1_type = ASRUtils::expr_type(x.m_args[0]); \ + ASR::ttype_t* arg2_type = ASRUtils::expr_type(x.m_args[1]); \ + ASR::ttype_t* arg3_type = ASRUtils::expr_type(x.m_args[2]); \ + \ + ASRUtils::require_impl(ASR::is_a(*arg1_type) && \ + ASR::is_a(*arg2_type) && \ + ASR::is_a(*arg3_type), \ + "All arguments of `"#X"` must be of type SymbolicExpression", \ + x.base.base.loc, diagnostics); \ + } \ + \ + static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \ + ASR::ttype_t *, Vec &/*args*/, diag::Diagnostics& /*diag*/) { \ + /*TODO*/ \ + return nullptr; \ + } \ + \ + static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \ + Vec& args, \ + diag::Diagnostics& diag) { \ + if (args.size() != 3) { \ + append_error(diag, "Intrinsic function `"#X"` accepts exactly 3 arguments", \ + loc); \ + return nullptr; \ + } \ + \ + for (size_t i = 0; i < args.size(); i++) { \ + ASR::ttype_t* argtype = ASRUtils::expr_type(args[i]); \ + if(!ASR::is_a(*argtype)) { \ + append_error(diag, \ + "Arguments of `"#X"` function must be of type SymbolicExpression", \ + args[i]->base.loc); \ + return nullptr; \ + } \ + } \ + \ + Vec arg_values; \ + arg_values.reserve(al, args.size()); \ + for( size_t i = 0; i < args.size(); i++ ) { \ + arg_values.push_back(al, ASRUtils::expr_value(args[i])); \ + } \ + ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc)); \ + ASR::expr_t* compile_time_value = eval_##X(al, loc, to_type, arg_values, diag); \ + return ASR::make_IntrinsicElementalFunction_t(al, loc, \ + static_cast(IntrinsicElementalFunctions::X), \ + args.p, args.size(), 0, to_type, compile_time_value); \ + } \ +} // namespace X + +create_symbolic_ternary_macro(SymbolicSubs) + #define create_symbolic_constants_macro(X) \ namespace X { \ static inline void verify_args(const ASR::IntrinsicElementalFunction_t& x, \ diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index f71bf6b2c0..58daede218 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -57,6 +57,12 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0], x->m_args[1], x->m_args[2])); \ + break; } + #define BASIC_BINOP(SYM, name) \ case LCompilers::ASRUtils::IntrinsicElementalFunctions::Symbolic##SYM: { \ pass_result.push_back(al, basic_binop(loc, "basic_"#name, target, \ @@ -241,6 +247,16 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 9a06a8453b..169c6c957f 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -7565,7 +7565,7 @@ we will have to use something else. } else { st = current_scope->resolve_symbol(mod_name); std::set symbolic_attributes = { - "diff", "expand", "has" + "diff", "expand", "has", "subs" }; std::set symbolic_constants = { "pi", "E", "oo" @@ -7640,7 +7640,7 @@ we will have to use something else. } else if (AST::is_a(*at->m_value)) { AST::BinOp_t* bop = AST::down_cast(at->m_value); std::set symbolic_attributes = { - "diff", "expand", "has" + "diff", "expand", "has", "subs" }; if (symbolic_attributes.find(at->m_attr) != symbolic_attributes.end()){ switch (bop->m_op) { @@ -7687,7 +7687,7 @@ we will have to use something else. } else if (AST::is_a(*at->m_value)) { AST::Call_t* call = AST::down_cast(at->m_value); std::set symbolic_attributes = { - "diff", "expand", "has" + "diff", "expand", "has", "subs" }; if (symbolic_attributes.find(at->m_attr) != symbolic_attributes.end()){ std::set symbolic_functions = { @@ -7819,7 +7819,7 @@ we will have to use something else. if (!s) { std::string intrinsic_name = call_name; std::set not_cpython_builtin = { - "sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand", "trunc", "fix", + "sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand", "trunc", "fix", "subs", "sum" // For sum called over lists }; std::set symbolic_functions = { diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index 369b8486a5..f8926a3eb8 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -49,7 +49,8 @@ struct AttributeHandler { {"expand", &eval_symbolic_expand}, {"has", &eval_symbolic_has_symbol}, {"is_integer", &eval_symbolic_is_integer}, - {"is_positive", &eval_symbolic_is_positive} + {"is_positive", &eval_symbolic_is_positive}, + {"subs", &eval_symbolic_subs} }; } @@ -590,6 +591,19 @@ struct AttributeHandler { return create_function(al, loc, args_with_list, diag); } + static ASR::asr_t* eval_symbolic_subs(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &diag) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicElementalFunctionRegistry::get_create_function("subs"); + return create_function(al, loc, args_with_list, diag); + } + }; // AttributeHandler } // namespace LCompilers::LPython