diff --git a/integration_tests/symbolics_02.py b/integration_tests/symbolics_02.py index 74f4a4af35..1bfda7e74d 100644 --- a/integration_tests/symbolics_02.py +++ b/integration_tests/symbolics_02.py @@ -1,35 +1,71 @@ -from sympy import Symbol, pi +from sympy import Symbol, pi, Add, Mul, Pow from lpython import S def test_symbolic_operations(): x: S = Symbol('x') y: S = Symbol('y') - p1: S = pi - p2: S = pi + pi1: S = pi + pi2: S = pi # Addition z: S = x + y + z1: bool = z.func == Add + z2: bool = z.func == Mul assert(z == x + y) + assert(z1 == True) + assert(z2 == False) + if z.func == Add: + assert True + else: + assert False print(z) # Subtraction w: S = x - y + w1: bool = w.func == Add assert(w == x - y) + assert(w1 == True) + if w.func == Add: + assert True + else: + assert False print(w) # Multiplication u: S = x * y + u1: bool = u.func == Mul assert(u == x * y) + assert(u1 == True) + if u.func == Mul: + assert True + else: + assert False print(u) # Division v: S = x / y + v1: bool = v.func == Mul assert(v == x / y) + assert(v1 == True) + if v.func == Mul: + assert True + else: + assert False print(v) # Power p: S = x ** y + p1: bool = p.func == Pow + p2: bool = p.func == Add + p3: bool = p.func == Mul assert(p == x ** y) + assert(p1 == True) + assert(p2 == False) + assert(p3 == False) + if p.func == Pow: + assert True + else: + assert False print(p) # Casting @@ -40,13 +76,13 @@ def test_symbolic_operations(): print(c) # Comparison - b1: bool = p1 == p2 + b1: bool = pi1 == pi2 print(b1) assert(b1 == True) - b2: bool = p1 != pi + b2: bool = pi1 != pi print(b2) assert(b2 == False) - b3: bool = p1 != x + b3: bool = pi1 != x print(b3) assert(b3 == True) b4: bool = pi == Symbol("x") diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index 4878d242f5..7cd59fe8b9 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -79,6 +79,9 @@ enum class IntrinsicScalarFunctions : int64_t { SymbolicExp, SymbolicAbs, SymbolicHasSymbolQ, + SymbolicAddQ, + SymbolicMulQ, + SymbolicPowQ, // ... }; @@ -140,6 +143,9 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(SymbolicExp) INTRINSIC_NAME_CASE(SymbolicAbs) INTRINSIC_NAME_CASE(SymbolicHasSymbolQ) + INTRINSIC_NAME_CASE(SymbolicAddQ) + INTRINSIC_NAME_CASE(SymbolicMulQ) + INTRINSIC_NAME_CASE(SymbolicPowQ) default : { throw LCompilersException("pickle: intrinsic_id not implemented"); } @@ -3100,6 +3106,48 @@ namespace SymbolicHasSymbolQ { } } // namespace SymbolicHasSymbolQ +#define create_symbolic_query_macro(X) \ +namespace X { \ + static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \ + diag::Diagnostics& diagnostics) { \ + const Location& loc = x.base.base.loc; \ + ASRUtils::require_impl(x.n_args == 1, \ + #X " must have exactly 1 input argument", loc, diagnostics); \ + \ + ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); \ + ASRUtils::require_impl(ASR::is_a(*input_type), \ + #X " expects an argument of type SymbolicExpression", loc, diagnostics); \ + } \ + \ + static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \ + ASR::ttype_t *, Vec &/*args*/) { \ + /*TODO*/ \ + return nullptr; \ + } \ + \ + static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \ + Vec& args, \ + const std::function err) { \ + if (args.size() != 1) { \ + err("Intrinsic " #X " function accepts exactly 1 argument", loc); \ + } \ + \ + ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]); \ + if (!ASR::is_a(*argtype)) { \ + err("Argument of " #X " function must be of type SymbolicExpression", \ + args[0]->base.loc); \ + } \ + \ + return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_##X, \ + static_cast(IntrinsicScalarFunctions::X), 0, logical); \ + } \ +} // namespace X + +create_symbolic_query_macro(SymbolicAddQ) +create_symbolic_query_macro(SymbolicMulQ) +create_symbolic_query_macro(SymbolicPowQ) + + #define create_symbolic_unary_macro(X) \ namespace X { \ static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \ @@ -3253,6 +3301,12 @@ namespace IntrinsicScalarFunctionRegistry { {nullptr, &SymbolicAbs::verify_args}}, {static_cast(IntrinsicScalarFunctions::SymbolicHasSymbolQ), {nullptr, &SymbolicHasSymbolQ::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicAddQ), + {nullptr, &SymbolicAddQ::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicMulQ), + {nullptr, &SymbolicMulQ::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicPowQ), + {nullptr, &SymbolicPowQ::verify_args}}, }; static const std::map& intrinsic_function_id_to_name = { @@ -3357,6 +3411,12 @@ namespace IntrinsicScalarFunctionRegistry { "SymbolicAbs"}, {static_cast(IntrinsicScalarFunctions::SymbolicHasSymbolQ), "SymbolicHasSymbolQ"}, + {static_cast(IntrinsicScalarFunctions::SymbolicAddQ), + "SymbolicAddQ"}, + {static_cast(IntrinsicScalarFunctions::SymbolicMulQ), + "SymbolicMulQ"}, + {static_cast(IntrinsicScalarFunctions::SymbolicPowQ), + "SymbolicPowQ"}, }; @@ -3412,6 +3472,9 @@ namespace IntrinsicScalarFunctionRegistry { {"SymbolicExp", {&SymbolicExp::create_SymbolicExp, &SymbolicExp::eval_SymbolicExp}}, {"SymbolicAbs", {&SymbolicAbs::create_SymbolicAbs, &SymbolicAbs::eval_SymbolicAbs}}, {"has", {&SymbolicHasSymbolQ::create_SymbolicHasSymbolQ, &SymbolicHasSymbolQ::eval_SymbolicHasSymbolQ}}, + {"AddQ", {&SymbolicAddQ::create_SymbolicAddQ, &SymbolicAddQ::eval_SymbolicAddQ}}, + {"MulQ", {&SymbolicMulQ::create_SymbolicMulQ, &SymbolicMulQ::eval_SymbolicMulQ}}, + {"PowQ", {&SymbolicPowQ::create_SymbolicPowQ, &SymbolicPowQ::eval_SymbolicPowQ}}, }; static inline bool is_intrinsic_function(const std::string& name) { diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 1433c00410..f8059f50b1 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -672,6 +672,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } + ASR::symbol_t* declare_basic_get_type_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { + std::string name = "basic_get_type"; + symbolic_dependencies.push_back(name); + if (!module_scope->get_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + return module_scope->get_symbol(name); + } + ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { std::string name = "basic_eq"; symbolic_dependencies.push_back(name); @@ -828,6 +867,60 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value1; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + // Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM + return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 16, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); + break; + } + case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ: { + ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope); + ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value1; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + // Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM + return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 15, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); + break; + } + case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ: { + ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope); + ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value1; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + // Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM + return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 17, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); + break; + } default: { throw LCompilersException("IntrinsicFunction: `" + ASRUtils::get_intrinsic_name(intrinsic_id) @@ -998,6 +1091,20 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x); + transform_stmts(xx.m_body, xx.n_body); + transform_stmts(xx.m_orelse, xx.n_orelse); + SymbolTable* module_scope = current_scope->parent; + if (ASR::is_a(*xx.m_test)) { + ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(xx.m_test); + if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { + ASR::expr_t* function_call = process_attributes(al, xx.base.base.loc, xx.m_test, module_scope); + xx.m_test = function_call; + } + } + } + void visit_SubroutineCall(const ASR::SubroutineCall_t &x) { SymbolTable* module_scope = current_scope->parent; Vec call_args; @@ -1298,7 +1405,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_test)) { + } else if (ASR::is_a(*x.m_test)) { ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_test); SymbolTable* module_scope = current_scope->parent; ASR::expr_t* left_tmp = nullptr; diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 66c7bbf978..a9d01fae4c 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -4605,6 +4605,7 @@ class BodyVisitor : public CommonVisitor { public: ASR::asr_t *asr; std::vector do_loop_variables; + bool using_func_attr = false; BodyVisitor(Allocator &al, LocationManager &lm, ASR::asr_t *unit, diag::Diagnostics &diagnostics, bool main_module, std::string module_name, std::map &ast_overload, @@ -5803,6 +5804,16 @@ class BodyVisitor : public CommonVisitor { } else if(ASR::is_a(*type)) { ASR::Pointer_t* p = ASR::down_cast(type); visit_AttributeUtil(p->m_type, attr_char, t, loc); + } else if(ASR::is_a(*type)) { + std::string attr = attr_char; + if (attr == "func") { + using_func_attr = true; + return; + } + ASR::expr_t *se = ASR::down_cast(ASR::make_Var_t(al, loc, t)); + Vec args; + args.reserve(al, 0); + handle_symbolic_attribute(se, attr, loc, args); } else { throw SemanticError(ASRUtils::type_to_str_python(type) + " not supported yet in Attribute.", loc); @@ -5996,8 +6007,6 @@ class BodyVisitor : public CommonVisitor { } void visit_Compare(const AST::Compare_t &x) { - this->visit_expr(*x.m_left); - ASR::expr_t *left = ASRUtils::EXPR(tmp); if (x.n_comparators > 1) { diag.add(diag::Diagnostic( "Only one comparison operator is supported for now", @@ -6008,6 +6017,36 @@ class BodyVisitor : public CommonVisitor { ); throw SemanticAbort(); } + this->visit_expr(*x.m_left); + if (using_func_attr) { + if (AST::is_a(*x.m_left) && AST::is_a(*x.m_comparators[0])) { + AST::Attribute_t *attr = AST::down_cast(x.m_left); + AST::Name_t *type_name = AST::down_cast(x.m_comparators[0]); + std::string symbolic_type = type_name->m_id; + if (AST::is_a(*attr->m_value)) { + AST::Name_t *var_name = AST::down_cast(attr->m_value); + std::string var = var_name->m_id; + ASR::symbol_t *st = current_scope->resolve_symbol(var); + ASR::expr_t *se = ASR::down_cast( + ASR::make_Var_t(al, x.base.base.loc, st)); + Vec args; + args.reserve(al, 0); + if (symbolic_type == "Add") { + tmp = attr_handler.eval_symbolic_is_Add(se, al, x.base.base.loc, args, diag); + return; + } else if (symbolic_type == "Mul") { + tmp = attr_handler.eval_symbolic_is_Mul(se, al, x.base.base.loc, args, diag); + return; + } else if (symbolic_type == "Pow") { + tmp = attr_handler.eval_symbolic_is_Pow(se, al, x.base.base.loc, args, diag); + return; + } else { + throw SemanticError(symbolic_type + " symbolic type not supported yet", x.base.base.loc); + } + } + } + } + ASR::expr_t *left = ASRUtils::EXPR(tmp); this->visit_expr(*x.m_comparators[0]); ASR::expr_t *right = ASRUtils::EXPR(tmp); diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index 1c8256cfc2..6aef6f76cd 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -46,7 +46,7 @@ struct AttributeHandler { symbolic_attribute_map = { {"diff", &eval_symbolic_diff}, {"expand", &eval_symbolic_expand}, - {"has", &eval_symbolic_has_symbol} + {"has", &eval_symbolic_has_symbol}, }; } @@ -457,6 +457,48 @@ struct AttributeHandler { { throw SemanticError(msg, loc); }); } + static ASR::asr_t* eval_symbolic_is_Add(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("AddQ"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + + static ASR::asr_t* eval_symbolic_is_Mul(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("MulQ"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + + static ASR::asr_t* eval_symbolic_is_Pow(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("PowQ"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + }; // AttributeHandler } // namespace LCompilers::LPython