From 26601172ab89ba570db3cecf77b06b0227512a0a Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sat, 30 Sep 2023 18:37:33 +0530 Subject: [PATCH 1/3] Introduced query methods: is_Add, is_Mul & is_Pow --- src/libasr/pass/intrinsic_function_registry.h | 62 +++++++++++++++++++ src/lpython/semantics/python_ast_to_asr.cpp | 6 ++ src/lpython/semantics/python_attribute_eval.h | 47 +++++++++++++- 3 files changed, 114 insertions(+), 1 deletion(-) diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index 1be030ef7b..b97420ee34 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -77,6 +77,9 @@ enum class IntrinsicScalarFunctions : int64_t { SymbolicExp, SymbolicAbs, SymbolicHasSymbolQ, + SymbolicAddQ, + SymbolicMulQ, + SymbolicPowQ, // ... }; @@ -137,6 +140,9 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(SymbolicExp) INTRINSIC_NAME_CASE(SymbolicAbs) INTRINSIC_NAME_CASE(SymbolicHasSymbolQ) + INTRINSIC_NAME_CASE(SymbolicAddQ) + INTRINSIC_NAME_CASE(SymbolicMulQ) + INTRINSIC_NAME_CASE(SymbolicPowQ) default : { throw LCompilersException("pickle: intrinsic_id not implemented"); } @@ -2960,6 +2966,47 @@ namespace SymbolicHasSymbolQ { } } // namespace SymbolicHasSymbolQ +#define create_symbolic_query_macro(X) \ +namespace X { \ + static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \ + diag::Diagnostics& diagnostics) { \ + const Location& loc = x.base.base.loc; \ + ASRUtils::require_impl(x.n_args == 1, \ + #X " must have exactly 1 input argument", loc, diagnostics); \ + \ + ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); \ + ASRUtils::require_impl(ASR::is_a(*input_type), \ + #X " expects an argument of type SymbolicExpression", loc, diagnostics); \ + } \ + \ + static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \ + ASR::ttype_t *, Vec &/*args*/) { \ + /*TODO*/ \ + return nullptr; \ + } \ + \ + static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \ + Vec& args, \ + const std::function err) { \ + if (args.size() != 1) { \ + err("Intrinsic " #X " function accepts exactly 1 argument", loc); \ + } \ + \ + ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]); \ + if (!ASR::is_a(*argtype)) { \ + err("Argument of " #X " function must be of type SymbolicExpression", \ + args[0]->base.loc); \ + } \ + \ + return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_##X, \ + static_cast(IntrinsicScalarFunctions::X), 0, logical); \ + } \ +} // namespace X + +create_symbolic_query_macro(SymbolicAddQ) +create_symbolic_query_macro(SymbolicMulQ) +create_symbolic_query_macro(SymbolicPowQ) + #define create_symbolic_unary_macro(X) \ namespace X { \ static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \ @@ -3111,6 +3158,12 @@ namespace IntrinsicScalarFunctionRegistry { {nullptr, &SymbolicAbs::verify_args}}, {static_cast(IntrinsicScalarFunctions::SymbolicHasSymbolQ), {nullptr, &SymbolicHasSymbolQ::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicAddQ), + {nullptr, &SymbolicAddQ::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicMulQ), + {nullptr, &SymbolicMulQ::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicPowQ), + {nullptr, &SymbolicPowQ::verify_args}}, }; static const std::map& intrinsic_function_id_to_name = { @@ -3213,6 +3266,12 @@ namespace IntrinsicScalarFunctionRegistry { "SymbolicAbs"}, {static_cast(IntrinsicScalarFunctions::SymbolicHasSymbolQ), "SymbolicHasSymbolQ"}, + {static_cast(IntrinsicScalarFunctions::SymbolicAddQ), + "SymbolicAddQ"}, + {static_cast(IntrinsicScalarFunctions::SymbolicMulQ), + "SymbolicMulQ"}, + {static_cast(IntrinsicScalarFunctions::SymbolicPowQ), + "SymbolicPowQ"}, }; @@ -3267,6 +3326,9 @@ namespace IntrinsicScalarFunctionRegistry { {"SymbolicExp", {&SymbolicExp::create_SymbolicExp, &SymbolicExp::eval_SymbolicExp}}, {"SymbolicAbs", {&SymbolicAbs::create_SymbolicAbs, &SymbolicAbs::eval_SymbolicAbs}}, {"has", {&SymbolicHasSymbolQ::create_SymbolicHasSymbolQ, &SymbolicHasSymbolQ::eval_SymbolicHasSymbolQ}}, + {"is_Add", {&SymbolicAddQ::create_SymbolicAddQ, &SymbolicAddQ::eval_SymbolicAddQ}}, + {"is_Mul", {&SymbolicMulQ::create_SymbolicMulQ, &SymbolicMulQ::eval_SymbolicMulQ}}, + {"is_Pow", {&SymbolicPowQ::create_SymbolicPowQ, &SymbolicPowQ::eval_SymbolicPowQ}}, }; static inline bool is_intrinsic_function(const std::string& name) { diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 31e59c1089..88a098f640 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -5989,6 +5989,12 @@ class BodyVisitor : public CommonVisitor { } else if(ASR::is_a(*type)) { ASR::Pointer_t* p = ASR::down_cast(type); visit_AttributeUtil(p->m_type, attr_char, t, loc); + } else if(ASR::is_a(*type)) { + std::string attr = attr_char; + ASR::expr_t *se = ASR::down_cast(ASR::make_Var_t(al, loc, t)); + Vec args; + args.reserve(al, 0); + handle_symbolic_attribute(se, attr, loc, args); } else { throw SemanticError(ASRUtils::type_to_str_python(type) + " not supported yet in Attribute.", loc); diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index 1c8256cfc2..11afc71b79 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -46,7 +46,10 @@ struct AttributeHandler { symbolic_attribute_map = { {"diff", &eval_symbolic_diff}, {"expand", &eval_symbolic_expand}, - {"has", &eval_symbolic_has_symbol} + {"has", &eval_symbolic_has_symbol}, + {"is_Add", &eval_symbolic_is_Add}, + {"is_Mul", &eval_symbolic_is_Mul}, + {"is_Pow", &eval_symbolic_is_Pow} }; } @@ -457,6 +460,48 @@ struct AttributeHandler { { throw SemanticError(msg, loc); }); } + static ASR::asr_t* eval_symbolic_is_Add(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("is_Add"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + + static ASR::asr_t* eval_symbolic_is_Mul(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("is_Mul"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + + static ASR::asr_t* eval_symbolic_is_Pow(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("is_Pow"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + }; // AttributeHandler } // namespace LCompilers::LPython From e1a056998614bbd88a3999a9123364a3832c8ff9 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Tue, 3 Oct 2023 15:05:32 +0530 Subject: [PATCH 2/3] Added support for query functions in the symbolic pass --- src/libasr/pass/replace_symbolic.cpp | 93 ++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 044d944ecd..2e3d8e9a3e 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -626,6 +626,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } + ASR::symbol_t* declare_basic_get_type_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { + std::string name = "basic_get_type"; + symbolic_dependencies.push_back(name); + if (!module_scope->get_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + return module_scope->get_symbol(name); + } + ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr, SymbolTable* module_scope) { if (ASR::is_a(*expr)) { @@ -692,6 +731,60 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value1; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + // Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM + return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 16, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); + break; + } + case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ: { + ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope); + ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value1; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + // Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM + return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 15, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); + break; + } + case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ: { + ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope); + ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value1; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + // Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM + return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 17, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); + break; + } default: { throw LCompilersException("IntrinsicFunction: `" + ASRUtils::get_intrinsic_name(intrinsic_id) From 8733e79ce64cdbeba438be4107b6b56e9c44b636 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Tue, 3 Oct 2023 15:19:19 +0530 Subject: [PATCH 3/3] Added tests --- integration_tests/symbolics_10.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/integration_tests/symbolics_10.py b/integration_tests/symbolics_10.py index c833c1e59f..0b636c71cb 100644 --- a/integration_tests/symbolics_10.py +++ b/integration_tests/symbolics_10.py @@ -23,4 +23,19 @@ def test_attributes(): assert(sin(Symbol("x")).has(Symbol("x")) != sin(Symbol("x")).has(Symbol("y"))) assert(sin(Symbol("x")).has(Symbol("x")) == sin(Symbol("y")).has(Symbol("y"))) + # test is_Add, is_Mul & is_Pow + a: S = x**pi + b: S = x + pi + c: S = x * pi + assert(a.is_Pow == True) + assert(b.is_Pow == False) + assert(c.is_Pow == False) + assert(a.is_Add == False) + assert(b.is_Add == True) + assert(c.is_Add == False) + assert(a.is_Mul == False) + assert(b.is_Mul == False) + assert(c.is_Mul == True) + + test_attributes() \ No newline at end of file