From d01b918b5e6d60642eebdba6aa2d50ae31503525 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Fri, 6 Oct 2023 15:07:56 +0530 Subject: [PATCH 1/9] Added Support for func attribute --- src/libasr/pass/intrinsic_function_registry.h | 45 ++++++ src/libasr/pass/replace_symbolic.cpp | 148 +++++++++++++++++- src/lpython/semantics/python_ast_to_asr.cpp | 6 + src/lpython/semantics/python_attribute_eval.h | 17 +- 4 files changed, 212 insertions(+), 4 deletions(-) diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index 1be030ef7b..862cac36cd 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -77,6 +77,7 @@ enum class IntrinsicScalarFunctions : int64_t { SymbolicExp, SymbolicAbs, SymbolicHasSymbolQ, + SymbolicFuncQ, // ... }; @@ -137,6 +138,7 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(SymbolicExp) INTRINSIC_NAME_CASE(SymbolicAbs) INTRINSIC_NAME_CASE(SymbolicHasSymbolQ) + INTRINSIC_NAME_CASE(SymbolicFuncQ) default : { throw LCompilersException("pickle: intrinsic_id not implemented"); } @@ -2960,6 +2962,44 @@ namespace SymbolicHasSymbolQ { } } // namespace SymbolicHasSymbolQ +namespace SymbolicFuncQ { + static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, + diag::Diagnostics& diagnostics) { + ASRUtils::require_impl(x.n_args == 1, "Intrinsic function SymbolicFuncQ" + "accepts exactly 1 argument", x.base.base.loc, diagnostics); + + ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); + ASRUtils::require_impl(ASR::is_a(*input_type), + "SymbolicFuncQ expects an argument of type SymbolicExpression", + x.base.base.loc, diagnostics); + } + + static inline ASR::expr_t* eval_SymbolicFuncQ(Allocator &/*al*/, + const Location &/*loc*/, ASR::ttype_t *, Vec &/*args*/) { + /*TODO*/ + return nullptr; + } + + static inline ASR::asr_t* create_SymbolicFuncQ(Allocator& al, + const Location& loc, Vec& args, + const std::function err) { + + if (args.size() != 1) { + err("Intrinsic function SymbolicFuncQ accepts exactly 1 argument", loc); + } + + ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]); + if (!ASR::is_a(*argtype)) { + err("Argument of SymbolicFuncQ function must be of type SymbolicExpression", + args[0]->base.loc); + } + + return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_SymbolicFuncQ, + static_cast(IntrinsicScalarFunctions::SymbolicFuncQ), 0, character(0)); + } +} // namespace SymbolicFuncQ + + #define create_symbolic_unary_macro(X) \ namespace X { \ static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \ @@ -3111,6 +3151,8 @@ namespace IntrinsicScalarFunctionRegistry { {nullptr, &SymbolicAbs::verify_args}}, {static_cast(IntrinsicScalarFunctions::SymbolicHasSymbolQ), {nullptr, &SymbolicHasSymbolQ::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicFuncQ), + {nullptr, &SymbolicFuncQ::verify_args}}, }; static const std::map& intrinsic_function_id_to_name = { @@ -3213,6 +3255,8 @@ namespace IntrinsicScalarFunctionRegistry { "SymbolicAbs"}, {static_cast(IntrinsicScalarFunctions::SymbolicHasSymbolQ), "SymbolicHasSymbolQ"}, + {static_cast(IntrinsicScalarFunctions::SymbolicFuncQ), + "SymbolicFuncQ"}, }; @@ -3267,6 +3311,7 @@ namespace IntrinsicScalarFunctionRegistry { {"SymbolicExp", {&SymbolicExp::create_SymbolicExp, &SymbolicExp::eval_SymbolicExp}}, {"SymbolicAbs", {&SymbolicAbs::create_SymbolicAbs, &SymbolicAbs::eval_SymbolicAbs}}, {"has", {&SymbolicHasSymbolQ::create_SymbolicHasSymbolQ, &SymbolicHasSymbolQ::eval_SymbolicHasSymbolQ}}, + {"func", {&SymbolicFuncQ::create_SymbolicFuncQ, &SymbolicFuncQ::eval_SymbolicFuncQ}}, }; static inline bool is_intrinsic_function(const std::string& name) { diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 044d944ecd..60cf04c7ad 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -626,6 +626,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } + ASR::symbol_t* declare_basic_get_type_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { + std::string name = "basic_get_type"; + symbolic_dependencies.push_back(name); + if (!module_scope->get_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + return module_scope->get_symbol(name); + } + ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr, SymbolTable* module_scope) { if (ASR::is_a(*expr)) { @@ -692,6 +731,97 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value1; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + + // Declare a temporary integer variable + ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)); + ASR::symbol_t* int_sym = ASR::down_cast(ASR::make_Variable_t(al, loc, current_scope, + s2c(al, "temp_integer"), nullptr, 0, + ASR::intentType::Local, nullptr, + nullptr, ASR::storage_typeType::Default, int_type, nullptr, + ASR::abiType::Source, ASR::Public, ASR::presenceType::Required, false)); + + if (!current_scope->get_symbol(s2c(al, "temp_integer"))) { + current_scope->add_symbol(s2c(al, "temp_integer"), int_sym); + } + ASR::symbol_t* temp_int_sym = current_scope->get_symbol("temp_integer"); + ASR::expr_t* target_int = ASRUtils::EXPR(ASR::make_Var_t(al, loc, temp_int_sym)); + ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target_int, function_call, nullptr)); + pass_result.push_back(al, stmt1); + + // Declare a temporary string variable + ASR::ttype_t* char_type = ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)); + ASR::symbol_t* str_sym = ASR::down_cast(ASR::make_Variable_t(al, loc, current_scope, + s2c(al, "temp_string"), nullptr, 0, + ASR::intentType::Local, nullptr, + nullptr, ASR::storage_typeType::Default, char_type, nullptr, + ASR::abiType::Source, ASR::Public, ASR::presenceType::Required, false)); + + if (!current_scope->get_symbol(s2c(al, "temp_string"))) { + current_scope->add_symbol(s2c(al, "temp_string"), str_sym); + } + ASR::symbol_t* temp_str_sym = current_scope->get_symbol("temp_string"); + ASR::expr_t* target_str = ASRUtils::EXPR(ASR::make_Var_t(al, loc, temp_str_sym)); + ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target_str, + ASRUtils::EXPR(ASR::make_StringConstant_t(al, loc, s2c(al, ""), + ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, 0, nullptr)))), nullptr)); + pass_result.push_back(al, stmt2); + + // If statement 1 + // Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM + ASR::expr_t* int_cmp_with_17 = ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, target_int, ASR::cmpopType::Eq, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 17, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); + ASR::ttype_t *str_type_len_3 = ASRUtils::TYPE(ASR::make_Character_t( + al, loc, 1, 3, nullptr)); + Vec if_body1; + if_body1.reserve(al, 1); + ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target_str, + ASRUtils::EXPR(ASR::make_StringConstant_t(al, loc, s2c(al, "Pow"), str_type_len_3)), nullptr)); + if_body1.push_back(al, stmt3); + ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_If_t(al, loc, int_cmp_with_17, if_body1.p, if_body1.n, nullptr, 0)); + pass_result.push_back(al, stmt4); + + // If statement 2 + // Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM + ASR::expr_t* int_cmp_with_15 = ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, target_int, ASR::cmpopType::Eq, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 15, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); + Vec if_body2; + if_body2.reserve(al, 1); + ASR::stmt_t* stmt5 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target_str, + ASRUtils::EXPR(ASR::make_StringConstant_t(al, loc, s2c(al, "MUL"), str_type_len_3)), nullptr)); + if_body2.push_back(al, stmt5); + ASR::stmt_t* stmt6 = ASRUtils::STMT(ASR::make_If_t(al, loc, int_cmp_with_15, if_body2.p, if_body2.n, nullptr, 0)); + pass_result.push_back(al, stmt6); + + // If statement 3 + // Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM + ASR::expr_t* int_cmp_with_16 = ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, target_int, ASR::cmpopType::Eq, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 16, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); + Vec if_body3; + if_body3.reserve(al, 1); + ASR::stmt_t* stmt7 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target_str, + ASRUtils::EXPR(ASR::make_StringConstant_t(al, loc, s2c(al, "Add"), str_type_len_3)), nullptr)); + if_body3.push_back(al, stmt7); + ASR::stmt_t* stmt8 = ASRUtils::STMT(ASR::make_If_t(al, loc, int_cmp_with_16, if_body3.p, if_body3.n, nullptr, 0)); + pass_result.push_back(al, stmt8); + + return target_str; + break; + } default: { throw LCompilersException("IntrinsicFunction: `" + ASRUtils::get_intrinsic_name(intrinsic_id) @@ -708,7 +838,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x.m_value); if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) { process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, x.m_target); - } else if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { + } else if ((intrinsic_func->m_type->type == ASR::ttypeType::Logical) || + (intrinsic_func->m_type->type == ASR::ttypeType::Character)) { ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, x.m_value, module_scope); ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr)); pass_result.push_back(al, stmt); @@ -882,7 +1013,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*ASRUtils::expr_type(val))) { + } else if (ASR::is_a(*ASRUtils::expr_type(val)) || + ASR::is_a(*ASRUtils::expr_type(val))) { ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, val, module_scope); print_tmp.push_back(function_call); } @@ -1049,7 +1181,17 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_test)) { + } else if (ASR::is_a(*x.m_test)) { + ASR::StringCompare_t *st = ASR::down_cast(x.m_test); + + left_tmp = process_attributes(al, x.base.base.loc, st->m_left, module_scope); + right_tmp = process_attributes(al, x.base.base.loc, st->m_right, module_scope); + ASR::expr_t* test = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp, + st->m_op, right_tmp, st->m_type, st->m_value)); + + ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); + pass_result.push_back(al, assert_stmt); + } else if (ASR::is_a(*x.m_test)) { ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_test); SymbolTable* module_scope = current_scope->parent; ASR::expr_t* left_tmp = nullptr; diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 50bb12d77f..67dc4c853f 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -5989,6 +5989,12 @@ class BodyVisitor : public CommonVisitor { } else if(ASR::is_a(*type)) { ASR::Pointer_t* p = ASR::down_cast(type); visit_AttributeUtil(p->m_type, attr_char, t, loc); + } else if(ASR::is_a(*type)) { + std::string attr = attr_char; + ASR::expr_t *se = ASR::down_cast(ASR::make_Var_t(al, loc, t)); + Vec args; + args.reserve(al, 0); + handle_symbolic_attribute(se, attr, loc, args); } else { throw SemanticError(ASRUtils::type_to_str_python(type) + " not supported yet in Attribute.", loc); diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index 1c8256cfc2..d1e392b1c4 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -46,7 +46,8 @@ struct AttributeHandler { symbolic_attribute_map = { {"diff", &eval_symbolic_diff}, {"expand", &eval_symbolic_expand}, - {"has", &eval_symbolic_has_symbol} + {"has", &eval_symbolic_has_symbol}, + {"func", &eval_symbolic_func} }; } @@ -457,6 +458,20 @@ struct AttributeHandler { { throw SemanticError(msg, loc); }); } + static ASR::asr_t* eval_symbolic_func(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("func"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + }; // AttributeHandler } // namespace LCompilers::LPython From 03ccc8025cf55d0bea80d0606b08f1db0e401e66 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sat, 7 Oct 2023 12:23:57 +0530 Subject: [PATCH 2/9] Simplified implementation --- src/libasr/pass/replace_symbolic.cpp | 139 +++++++++++---------------- 1 file changed, 55 insertions(+), 84 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 60cf04c7ad..abbd96a71b 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -665,6 +665,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } + ASR::symbol_t* declare_basic_get_class_from_id_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { + std::string name = "basic_get_class_from_id"; + symbolic_dependencies.push_back(name); + if (!module_scope->get_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + return module_scope->get_symbol(name); + } + ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr, SymbolTable* module_scope) { if (ASR::is_a(*expr)) { @@ -732,94 +771,26 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = loc; - call_arg.m_value = value1; - call_args.push_back(al, call_arg); - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, - basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, + Vec call_args1, call_args2; + call_args1.reserve(al, 1); + call_args2.reserve(al, 1); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = loc; + call_arg1.m_value = value1; + call_args1.push_back(al, call_arg1); + ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_type_sym, basic_get_type_sym, call_args1.p, call_args1.n, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); - // Declare a temporary integer variable - ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)); - ASR::symbol_t* int_sym = ASR::down_cast(ASR::make_Variable_t(al, loc, current_scope, - s2c(al, "temp_integer"), nullptr, 0, - ASR::intentType::Local, nullptr, - nullptr, ASR::storage_typeType::Default, int_type, nullptr, - ASR::abiType::Source, ASR::Public, ASR::presenceType::Required, false)); - - if (!current_scope->get_symbol(s2c(al, "temp_integer"))) { - current_scope->add_symbol(s2c(al, "temp_integer"), int_sym); - } - ASR::symbol_t* temp_int_sym = current_scope->get_symbol("temp_integer"); - ASR::expr_t* target_int = ASRUtils::EXPR(ASR::make_Var_t(al, loc, temp_int_sym)); - ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target_int, function_call, nullptr)); - pass_result.push_back(al, stmt1); - - // Declare a temporary string variable - ASR::ttype_t* char_type = ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)); - ASR::symbol_t* str_sym = ASR::down_cast(ASR::make_Variable_t(al, loc, current_scope, - s2c(al, "temp_string"), nullptr, 0, - ASR::intentType::Local, nullptr, - nullptr, ASR::storage_typeType::Default, char_type, nullptr, - ASR::abiType::Source, ASR::Public, ASR::presenceType::Required, false)); - - if (!current_scope->get_symbol(s2c(al, "temp_string"))) { - current_scope->add_symbol(s2c(al, "temp_string"), str_sym); - } - ASR::symbol_t* temp_str_sym = current_scope->get_symbol("temp_string"); - ASR::expr_t* target_str = ASRUtils::EXPR(ASR::make_Var_t(al, loc, temp_str_sym)); - ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target_str, - ASRUtils::EXPR(ASR::make_StringConstant_t(al, loc, s2c(al, ""), - ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, 0, nullptr)))), nullptr)); - pass_result.push_back(al, stmt2); - - // If statement 1 - // Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM - ASR::expr_t* int_cmp_with_17 = ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, target_int, ASR::cmpopType::Eq, - ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 17, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), - ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); - ASR::ttype_t *str_type_len_3 = ASRUtils::TYPE(ASR::make_Character_t( - al, loc, 1, 3, nullptr)); - Vec if_body1; - if_body1.reserve(al, 1); - ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target_str, - ASRUtils::EXPR(ASR::make_StringConstant_t(al, loc, s2c(al, "Pow"), str_type_len_3)), nullptr)); - if_body1.push_back(al, stmt3); - ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_If_t(al, loc, int_cmp_with_17, if_body1.p, if_body1.n, nullptr, 0)); - pass_result.push_back(al, stmt4); - - // If statement 2 - // Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM - ASR::expr_t* int_cmp_with_15 = ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, target_int, ASR::cmpopType::Eq, - ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 15, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), - ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); - Vec if_body2; - if_body2.reserve(al, 1); - ASR::stmt_t* stmt5 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target_str, - ASRUtils::EXPR(ASR::make_StringConstant_t(al, loc, s2c(al, "MUL"), str_type_len_3)), nullptr)); - if_body2.push_back(al, stmt5); - ASR::stmt_t* stmt6 = ASRUtils::STMT(ASR::make_If_t(al, loc, int_cmp_with_15, if_body2.p, if_body2.n, nullptr, 0)); - pass_result.push_back(al, stmt6); - - // If statement 3 - // Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM - ASR::expr_t* int_cmp_with_16 = ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, target_int, ASR::cmpopType::Eq, - ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 16, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), - ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); - Vec if_body3; - if_body3.reserve(al, 1); - ASR::stmt_t* stmt7 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target_str, - ASRUtils::EXPR(ASR::make_StringConstant_t(al, loc, s2c(al, "Add"), str_type_len_3)), nullptr)); - if_body3.push_back(al, stmt7); - ASR::stmt_t* stmt8 = ASRUtils::STMT(ASR::make_If_t(al, loc, int_cmp_with_16, if_body3.p, if_body3.n, nullptr, 0)); - pass_result.push_back(al, stmt8); - - return target_str; + call_arg2.loc = loc; + call_arg2.m_value = function_call1; + call_args2.push_back(al, call_arg2); + return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_class_from_id_sym, basic_get_class_from_id_sym, call_args2.p, call_args2.n, + ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), nullptr, nullptr)); break; } default: { From 6b8b8566b0723382c14e9d727d85c47045122ca2 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Mon, 16 Oct 2023 09:54:51 +0530 Subject: [PATCH 3/9] Added support for add, mul and pow types --- src/libasr/pass/intrinsic_function_registry.h | 102 +++++++++------ src/libasr/pass/replace_symbolic.cpp | 122 ++++++++---------- src/lpython/semantics/python_ast_to_asr.cpp | 37 +++++- src/lpython/semantics/python_attribute_eval.h | 46 ++++++- 4 files changed, 191 insertions(+), 116 deletions(-) diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index 862cac36cd..5e1e5ea0f7 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -77,7 +77,9 @@ enum class IntrinsicScalarFunctions : int64_t { SymbolicExp, SymbolicAbs, SymbolicHasSymbolQ, - SymbolicFuncQ, + SymbolicAddQ, + SymbolicMulQ, + SymbolicPowQ, // ... }; @@ -138,7 +140,9 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(SymbolicExp) INTRINSIC_NAME_CASE(SymbolicAbs) INTRINSIC_NAME_CASE(SymbolicHasSymbolQ) - INTRINSIC_NAME_CASE(SymbolicFuncQ) + INTRINSIC_NAME_CASE(SymbolicAddQ) + INTRINSIC_NAME_CASE(SymbolicMulQ) + INTRINSIC_NAME_CASE(SymbolicPowQ) default : { throw LCompilersException("pickle: intrinsic_id not implemented"); } @@ -2962,42 +2966,46 @@ namespace SymbolicHasSymbolQ { } } // namespace SymbolicHasSymbolQ -namespace SymbolicFuncQ { - static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, - diag::Diagnostics& diagnostics) { - ASRUtils::require_impl(x.n_args == 1, "Intrinsic function SymbolicFuncQ" - "accepts exactly 1 argument", x.base.base.loc, diagnostics); - - ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); - ASRUtils::require_impl(ASR::is_a(*input_type), - "SymbolicFuncQ expects an argument of type SymbolicExpression", - x.base.base.loc, diagnostics); - } - - static inline ASR::expr_t* eval_SymbolicFuncQ(Allocator &/*al*/, - const Location &/*loc*/, ASR::ttype_t *, Vec &/*args*/) { - /*TODO*/ - return nullptr; - } - - static inline ASR::asr_t* create_SymbolicFuncQ(Allocator& al, - const Location& loc, Vec& args, - const std::function err) { - - if (args.size() != 1) { - err("Intrinsic function SymbolicFuncQ accepts exactly 1 argument", loc); - } - - ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]); - if (!ASR::is_a(*argtype)) { - err("Argument of SymbolicFuncQ function must be of type SymbolicExpression", - args[0]->base.loc); - } +#define create_symbolic_query_macro(X) \ +namespace X { \ + static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \ + diag::Diagnostics& diagnostics) { \ + const Location& loc = x.base.base.loc; \ + ASRUtils::require_impl(x.n_args == 1, \ + #X " must have exactly 1 input argument", loc, diagnostics); \ + \ + ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); \ + ASRUtils::require_impl(ASR::is_a(*input_type), \ + #X " expects an argument of type SymbolicExpression", loc, diagnostics); \ + } \ + \ + static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \ + ASR::ttype_t *, Vec &/*args*/) { \ + /*TODO*/ \ + return nullptr; \ + } \ + \ + static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \ + Vec& args, \ + const std::function err) { \ + if (args.size() != 1) { \ + err("Intrinsic " #X " function accepts exactly 1 argument", loc); \ + } \ + \ + ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]); \ + if (!ASR::is_a(*argtype)) { \ + err("Argument of " #X " function must be of type SymbolicExpression", \ + args[0]->base.loc); \ + } \ + \ + return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_##X, \ + static_cast(IntrinsicScalarFunctions::X), 0, logical); \ + } \ +} // namespace X - return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_SymbolicFuncQ, - static_cast(IntrinsicScalarFunctions::SymbolicFuncQ), 0, character(0)); - } -} // namespace SymbolicFuncQ +create_symbolic_query_macro(SymbolicAddQ) +create_symbolic_query_macro(SymbolicMulQ) +create_symbolic_query_macro(SymbolicPowQ) #define create_symbolic_unary_macro(X) \ @@ -3151,8 +3159,12 @@ namespace IntrinsicScalarFunctionRegistry { {nullptr, &SymbolicAbs::verify_args}}, {static_cast(IntrinsicScalarFunctions::SymbolicHasSymbolQ), {nullptr, &SymbolicHasSymbolQ::verify_args}}, - {static_cast(IntrinsicScalarFunctions::SymbolicFuncQ), - {nullptr, &SymbolicFuncQ::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicAddQ), + {nullptr, &SymbolicAddQ::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicMulQ), + {nullptr, &SymbolicMulQ::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicPowQ), + {nullptr, &SymbolicPowQ::verify_args}}, }; static const std::map& intrinsic_function_id_to_name = { @@ -3255,8 +3267,12 @@ namespace IntrinsicScalarFunctionRegistry { "SymbolicAbs"}, {static_cast(IntrinsicScalarFunctions::SymbolicHasSymbolQ), "SymbolicHasSymbolQ"}, - {static_cast(IntrinsicScalarFunctions::SymbolicFuncQ), - "SymbolicFuncQ"}, + {static_cast(IntrinsicScalarFunctions::SymbolicAddQ), + "SymbolicAddQ"}, + {static_cast(IntrinsicScalarFunctions::SymbolicMulQ), + "SymbolicMulQ"}, + {static_cast(IntrinsicScalarFunctions::SymbolicPowQ), + "SymbolicPowQ"}, }; @@ -3311,7 +3327,9 @@ namespace IntrinsicScalarFunctionRegistry { {"SymbolicExp", {&SymbolicExp::create_SymbolicExp, &SymbolicExp::eval_SymbolicExp}}, {"SymbolicAbs", {&SymbolicAbs::create_SymbolicAbs, &SymbolicAbs::eval_SymbolicAbs}}, {"has", {&SymbolicHasSymbolQ::create_SymbolicHasSymbolQ, &SymbolicHasSymbolQ::eval_SymbolicHasSymbolQ}}, - {"func", {&SymbolicFuncQ::create_SymbolicFuncQ, &SymbolicFuncQ::eval_SymbolicFuncQ}}, + {"is_Add", {&SymbolicAddQ::create_SymbolicAddQ, &SymbolicAddQ::eval_SymbolicAddQ}}, + {"is_Mul", {&SymbolicMulQ::create_SymbolicMulQ, &SymbolicMulQ::eval_SymbolicMulQ}}, + {"is_Pow", {&SymbolicPowQ::create_SymbolicPowQ, &SymbolicPowQ::eval_SymbolicPowQ}}, }; static inline bool is_intrinsic_function(const std::string& name) { diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index abbd96a71b..8b8584684d 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -665,45 +665,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } - ASR::symbol_t* declare_basic_get_class_from_id_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { - std::string name = "basic_get_class_from_id"; - symbolic_dependencies.push_back(name); - if (!module_scope->get_symbol(name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); - fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); - ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, - return_var, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* symbol = ASR::down_cast(subrout); - module_scope->add_symbol(s2c(al, name), symbol); - } - return module_scope->get_symbol(name); - } - ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr, SymbolTable* module_scope) { if (ASR::is_a(*expr)) { @@ -770,27 +731,58 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); - Vec call_args1, call_args2; - call_args1.reserve(al, 1); - call_args2.reserve(al, 1); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = loc; - call_arg1.m_value = value1; - call_args1.push_back(al, call_arg1); - ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, - basic_get_type_sym, basic_get_type_sym, call_args1.p, call_args1.n, + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value1; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); - - call_arg2.loc = loc; - call_arg2.m_value = function_call1; - call_args2.push_back(al, call_arg2); - return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, - basic_get_class_from_id_sym, basic_get_class_from_id_sym, call_args2.p, call_args2.n, - ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), nullptr, nullptr)); + // Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM + return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 16, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); + break; + } + case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ: { + ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope); + ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value1; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + // Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM + return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 15, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); + break; + } + case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ: { + ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope); + ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value1; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + // Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM + return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 17, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); break; } default: { @@ -809,8 +801,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x.m_value); if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) { process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, x.m_target); - } else if ((intrinsic_func->m_type->type == ASR::ttypeType::Logical) || - (intrinsic_func->m_type->type == ASR::ttypeType::Character)) { + } else if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, x.m_value, module_scope); ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr)); pass_result.push_back(al, stmt); @@ -984,8 +975,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*ASRUtils::expr_type(val)) || - ASR::is_a(*ASRUtils::expr_type(val))) { + } else if (ASR::is_a(*ASRUtils::expr_type(val))) { ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, val, module_scope); print_tmp.push_back(function_call); } @@ -1150,16 +1140,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_op, right_tmp, l->m_type, l->m_value)); - ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); - pass_result.push_back(al, assert_stmt); - } else if (ASR::is_a(*x.m_test)) { - ASR::StringCompare_t *st = ASR::down_cast(x.m_test); - - left_tmp = process_attributes(al, x.base.base.loc, st->m_left, module_scope); - right_tmp = process_attributes(al, x.base.base.loc, st->m_right, module_scope); - ASR::expr_t* test = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp, - st->m_op, right_tmp, st->m_type, st->m_value)); - ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); pass_result.push_back(al, assert_stmt); } else if (ASR::is_a(*x.m_test)) { diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 67dc4c853f..f8f9277ed9 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -4791,6 +4791,7 @@ class BodyVisitor : public CommonVisitor { public: ASR::asr_t *asr; std::vector do_loop_variables; + bool using_func_attr = false; BodyVisitor(Allocator &al, LocationManager &lm, ASR::asr_t *unit, diag::Diagnostics &diagnostics, bool main_module, std::string module_name, std::map &ast_overload, @@ -5991,6 +5992,10 @@ class BodyVisitor : public CommonVisitor { visit_AttributeUtil(p->m_type, attr_char, t, loc); } else if(ASR::is_a(*type)) { std::string attr = attr_char; + if (attr == "func") { + using_func_attr = true; + return; + } ASR::expr_t *se = ASR::down_cast(ASR::make_Var_t(al, loc, t)); Vec args; args.reserve(al, 0); @@ -6188,8 +6193,6 @@ class BodyVisitor : public CommonVisitor { } void visit_Compare(const AST::Compare_t &x) { - this->visit_expr(*x.m_left); - ASR::expr_t *left = ASRUtils::EXPR(tmp); if (x.n_comparators > 1) { diag.add(diag::Diagnostic( "Only one comparison operator is supported for now", @@ -6200,6 +6203,36 @@ class BodyVisitor : public CommonVisitor { ); throw SemanticAbort(); } + this->visit_expr(*x.m_left); + if (using_func_attr) { + if (AST::is_a(*x.m_left) && AST::is_a(*x.m_comparators[0])) { + AST::Attribute_t *attr = AST::down_cast(x.m_left); + AST::Name_t *type_name = AST::down_cast(x.m_comparators[0]); + std::string symbolic_type = type_name->m_id; + if (AST::is_a(*attr->m_value)) { + AST::Name_t *var_name = AST::down_cast(attr->m_value); + std::string var = var_name->m_id; + ASR::symbol_t *st = current_scope->resolve_symbol(var); + ASR::expr_t *se = ASR::down_cast( + ASR::make_Var_t(al, x.base.base.loc, st)); + Vec args; + args.reserve(al, 0); + if (symbolic_type == "Add") { + handle_symbolic_attribute(se, "is_Add", x.base.base.loc, args); + return; + } else if (symbolic_type == "Mul") { + handle_symbolic_attribute(se, "is_Mul", x.base.base.loc, args); + return; + } else if (symbolic_type == "Pow") { + handle_symbolic_attribute(se, "is_Pow", x.base.base.loc, args); + return; + } else { + throw SemanticError(symbolic_type + " symbolic type not supported yet", x.base.base.loc); + } + } + } + } + ASR::expr_t *left = ASRUtils::EXPR(tmp); this->visit_expr(*x.m_comparators[0]); ASR::expr_t *right = ASRUtils::EXPR(tmp); diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index d1e392b1c4..3d0451a7d9 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -47,7 +47,9 @@ struct AttributeHandler { {"diff", &eval_symbolic_diff}, {"expand", &eval_symbolic_expand}, {"has", &eval_symbolic_has_symbol}, - {"func", &eval_symbolic_func} + {"is_Add", &eval_symbolic_is_Add}, + {"is_Mul", &eval_symbolic_is_Mul}, + {"is_Pow", &eval_symbolic_is_Pow} }; } @@ -472,6 +474,48 @@ struct AttributeHandler { { throw SemanticError(msg, loc); }); } + static ASR::asr_t* eval_symbolic_is_Add(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("is_Add"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + + static ASR::asr_t* eval_symbolic_is_Mul(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("is_Mul"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + + static ASR::asr_t* eval_symbolic_is_Pow(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("is_Pow"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + }; // AttributeHandler } // namespace LCompilers::LPython From 1c0f696b008094870b821e10d18efc876eec6c22 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Mon, 16 Oct 2023 10:10:23 +0530 Subject: [PATCH 4/9] Fixed failing tests --- src/libasr/pass/replace_symbolic.cpp | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 75bce011d5..45ddbffd59 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -711,27 +711,32 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } - ASR::symbol_t* declare_basic_assign_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { - std::string name = "basic_assign"; + ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { + std::string name = "basic_eq"; symbolic_dependencies.push_back(name); if (!module_scope->get_symbol(name)) { std::string header = "symengine/cwrapper.h"; SymbolTable* fn_symtab = al.make_new(module_scope); Vec args; - args.reserve(al, 2); + args.reserve(al, 1); ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "y"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + fn_symtab->add_symbol(s2c(al, "y"), arg3); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); Vec body; body.reserve(al, 1); @@ -739,9 +744,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor dep; dep.reserve(al, 1); + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, + return_var, ASR::abiType::BindC, ASR::accessType::Public, ASR::deftypeType::Interface, s2c(al, name), false, false, false, false, false, nullptr, 0, false, false, false, s2c(al, header)); ASR::symbol_t* symbol = ASR::down_cast(subrout); From 4e2cc0e6c89ab8009b0bb5eda0d2740075a4c04d Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Mon, 16 Oct 2023 10:24:15 +0530 Subject: [PATCH 5/9] Added tests --- integration_tests/symbolics_02.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/integration_tests/symbolics_02.py b/integration_tests/symbolics_02.py index 74f4a4af35..52c18c0e7c 100644 --- a/integration_tests/symbolics_02.py +++ b/integration_tests/symbolics_02.py @@ -1,35 +1,51 @@ -from sympy import Symbol, pi +from sympy import Symbol, pi, Add, Mul, Pow from lpython import S def test_symbolic_operations(): x: S = Symbol('x') y: S = Symbol('y') - p1: S = pi - p2: S = pi + pi1: S = pi + pi2: S = pi # Addition z: S = x + y + z1: bool = z.func == Add + z2: bool = z.func == Mul assert(z == x + y) + assert(z1 == True) + assert(z2 == False) print(z) # Subtraction w: S = x - y + w1: bool = w.func == Add assert(w == x - y) + assert(w1 == True) print(w) # Multiplication u: S = x * y + u1: bool = u.func == Mul assert(u == x * y) + assert(u1 == True) print(u) # Division v: S = x / y + v1: bool = v.func == Mul assert(v == x / y) + assert(v1 == True) print(v) # Power p: S = x ** y + p1: bool = p.func == Pow + p2: bool = p.func == Add + p3: bool = p.func == Mul assert(p == x ** y) + assert(p1 == True) + assert(p2 == False) + assert(p3 == False) print(p) # Casting @@ -40,13 +56,13 @@ def test_symbolic_operations(): print(c) # Comparison - b1: bool = p1 == p2 + b1: bool = pi1 == pi2 print(b1) assert(b1 == True) - b2: bool = p1 != pi + b2: bool = pi1 != pi print(b2) assert(b2 == False) - b3: bool = p1 != x + b3: bool = pi1 != x print(b3) assert(b3 == True) b4: bool = pi == Symbol("x") From e6f7b0a4225ab8f548b98a29c09092853ce38a7b Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Tue, 17 Oct 2023 12:19:53 +0530 Subject: [PATCH 6/9] Removed is_Add, is_Mul and is_Pow attributes --- src/lpython/semantics/python_ast_to_asr.cpp | 6 +++--- src/lpython/semantics/python_attribute_eval.h | 17 ----------------- 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index df53d3e373..a9d01fae4c 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -6032,13 +6032,13 @@ class BodyVisitor : public CommonVisitor { Vec args; args.reserve(al, 0); if (symbolic_type == "Add") { - handle_symbolic_attribute(se, "is_Add", x.base.base.loc, args); + tmp = attr_handler.eval_symbolic_is_Add(se, al, x.base.base.loc, args, diag); return; } else if (symbolic_type == "Mul") { - handle_symbolic_attribute(se, "is_Mul", x.base.base.loc, args); + tmp = attr_handler.eval_symbolic_is_Mul(se, al, x.base.base.loc, args, diag); return; } else if (symbolic_type == "Pow") { - handle_symbolic_attribute(se, "is_Pow", x.base.base.loc, args); + tmp = attr_handler.eval_symbolic_is_Pow(se, al, x.base.base.loc, args, diag); return; } else { throw SemanticError(symbolic_type + " symbolic type not supported yet", x.base.base.loc); diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index 3d0451a7d9..6fac11c811 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -47,9 +47,6 @@ struct AttributeHandler { {"diff", &eval_symbolic_diff}, {"expand", &eval_symbolic_expand}, {"has", &eval_symbolic_has_symbol}, - {"is_Add", &eval_symbolic_is_Add}, - {"is_Mul", &eval_symbolic_is_Mul}, - {"is_Pow", &eval_symbolic_is_Pow} }; } @@ -460,20 +457,6 @@ struct AttributeHandler { { throw SemanticError(msg, loc); }); } - static ASR::asr_t* eval_symbolic_func(ASR::expr_t *s, Allocator &al, const Location &loc, - Vec &args, diag::Diagnostics &/*diag*/) { - Vec args_with_list; - args_with_list.reserve(al, args.size() + 1); - args_with_list.push_back(al, s); - for(size_t i = 0; i < args.size(); i++) { - args_with_list.push_back(al, args[i]); - } - ASRUtils::create_intrinsic_function create_function = - ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("func"); - return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) - { throw SemanticError(msg, loc); }); - } - static ASR::asr_t* eval_symbolic_is_Add(ASR::expr_t *s, Allocator &al, const Location &loc, Vec &args, diag::Diagnostics &/*diag*/) { Vec args_with_list; From 4fce428b487b29e04395fad35f3101a5633badf1 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Tue, 17 Oct 2023 13:08:56 +0530 Subject: [PATCH 7/9] Implemented visit_If for supporting func attribute --- src/libasr/pass/replace_symbolic.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 45ddbffd59..f8059f50b1 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -1091,6 +1091,20 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x); + transform_stmts(xx.m_body, xx.n_body); + transform_stmts(xx.m_orelse, xx.n_orelse); + SymbolTable* module_scope = current_scope->parent; + if (ASR::is_a(*xx.m_test)) { + ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(xx.m_test); + if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { + ASR::expr_t* function_call = process_attributes(al, xx.base.base.loc, xx.m_test, module_scope); + xx.m_test = function_call; + } + } + } + void visit_SubroutineCall(const ASR::SubroutineCall_t &x) { SymbolTable* module_scope = current_scope->parent; Vec call_args; From 65e179b0a0c32b5ce505b1c1c328d1d33b7d008d Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Tue, 17 Oct 2023 13:17:56 +0530 Subject: [PATCH 8/9] added tests --- integration_tests/symbolics_02.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/integration_tests/symbolics_02.py b/integration_tests/symbolics_02.py index 52c18c0e7c..1bfda7e74d 100644 --- a/integration_tests/symbolics_02.py +++ b/integration_tests/symbolics_02.py @@ -14,6 +14,10 @@ def test_symbolic_operations(): assert(z == x + y) assert(z1 == True) assert(z2 == False) + if z.func == Add: + assert True + else: + assert False print(z) # Subtraction @@ -21,6 +25,10 @@ def test_symbolic_operations(): w1: bool = w.func == Add assert(w == x - y) assert(w1 == True) + if w.func == Add: + assert True + else: + assert False print(w) # Multiplication @@ -28,6 +36,10 @@ def test_symbolic_operations(): u1: bool = u.func == Mul assert(u == x * y) assert(u1 == True) + if u.func == Mul: + assert True + else: + assert False print(u) # Division @@ -35,6 +47,10 @@ def test_symbolic_operations(): v1: bool = v.func == Mul assert(v == x / y) assert(v1 == True) + if v.func == Mul: + assert True + else: + assert False print(v) # Power @@ -46,6 +62,10 @@ def test_symbolic_operations(): assert(p1 == True) assert(p2 == False) assert(p3 == False) + if p.func == Pow: + assert True + else: + assert False print(p) # Casting From ffbff02dd234110a82ba06fff9dca79533a27f45 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Tue, 17 Oct 2023 14:40:19 +0530 Subject: [PATCH 9/9] Addressed string names --- src/libasr/pass/intrinsic_function_registry.h | 6 +++--- src/lpython/semantics/python_attribute_eval.h | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index f99d61a650..7cd59fe8b9 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -3472,9 +3472,9 @@ namespace IntrinsicScalarFunctionRegistry { {"SymbolicExp", {&SymbolicExp::create_SymbolicExp, &SymbolicExp::eval_SymbolicExp}}, {"SymbolicAbs", {&SymbolicAbs::create_SymbolicAbs, &SymbolicAbs::eval_SymbolicAbs}}, {"has", {&SymbolicHasSymbolQ::create_SymbolicHasSymbolQ, &SymbolicHasSymbolQ::eval_SymbolicHasSymbolQ}}, - {"is_Add", {&SymbolicAddQ::create_SymbolicAddQ, &SymbolicAddQ::eval_SymbolicAddQ}}, - {"is_Mul", {&SymbolicMulQ::create_SymbolicMulQ, &SymbolicMulQ::eval_SymbolicMulQ}}, - {"is_Pow", {&SymbolicPowQ::create_SymbolicPowQ, &SymbolicPowQ::eval_SymbolicPowQ}}, + {"AddQ", {&SymbolicAddQ::create_SymbolicAddQ, &SymbolicAddQ::eval_SymbolicAddQ}}, + {"MulQ", {&SymbolicMulQ::create_SymbolicMulQ, &SymbolicMulQ::eval_SymbolicMulQ}}, + {"PowQ", {&SymbolicPowQ::create_SymbolicPowQ, &SymbolicPowQ::eval_SymbolicPowQ}}, }; static inline bool is_intrinsic_function(const std::string& name) { diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index 6fac11c811..6aef6f76cd 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -466,7 +466,7 @@ struct AttributeHandler { args_with_list.push_back(al, args[i]); } ASRUtils::create_intrinsic_function create_function = - ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("is_Add"); + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("AddQ"); return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }); } @@ -480,7 +480,7 @@ struct AttributeHandler { args_with_list.push_back(al, args[i]); } ASRUtils::create_intrinsic_function create_function = - ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("is_Mul"); + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("MulQ"); return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }); } @@ -494,7 +494,7 @@ struct AttributeHandler { args_with_list.push_back(al, args[i]); } ASRUtils::create_intrinsic_function create_function = - ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("is_Pow"); + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("PowQ"); return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }); }