diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index cf003ae577..b30fca76a3 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -7,9 +7,6 @@ #include #include -#include - - namespace LCompilers { using ASR::down_cast; @@ -294,405 +291,197 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorparent; - if (ASR::is_a(*x.m_value)) { - ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(x.m_value); - int64_t intrinsic_id = intrinsic_func->m_intrinsic_id; - if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) { - switch (static_cast(intrinsic_id)) { - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { - std::string new_name = "basic_const_pi"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg))); + ASR::expr_t* handle_argument(Allocator &al, const Location &loc, ASR::expr_t* arg) { + if (ASR::is_a(*arg)) { + return arg; + } else if (ASR::is_a(*arg)) { + this->visit_IntrinsicFunction(*ASR::down_cast(arg)); + } else if (ASR::is_a(*arg)) { + this->visit_Cast(*ASR::down_cast(arg)); + } else { + LCOMPILERS_ASSERT(false); + } + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + return ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym)); + } - Vec body; - body.reserve(al, 1); + void process_binary_operator(Allocator &al, const Location &loc, ASR::IntrinsicFunction_t* x, SymbolTable* module_scope, + const std::string& new_name, ASR::expr_t* target) { + ASR::expr_t* value1 = handle_argument(al, loc, x->m_args[0]); + ASR::expr_t* value2 = handle_argument(al, loc, x->m_args[1]); + perform_symbolic_binary_operation(al, loc, module_scope, new_name, target, value1, value2); + } - Vec dep; - dep.reserve(al, 1); + void process_unary_operator(Allocator &al, const Location &loc, ASR::IntrinsicFunction_t* x, SymbolTable* module_scope, + const std::string& new_name, ASR::expr_t* target) { + ASR::expr_t* value1 = handle_argument(al, loc, x->m_args[0]); + perform_symbolic_unary_operation(al, loc, module_scope, new_name, target, value1); + } - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } + void process_intrinsic_function(Allocator &al, const Location &loc, ASR::IntrinsicFunction_t* x, SymbolTable* module_scope, + ASR::expr_t* target){ + int64_t intrinsic_id = x->m_intrinsic_id; + switch (static_cast(intrinsic_id)) { + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { + std::string new_name = "basic_const_pi"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); - // Create the function call statement for basic_const_pi - ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = x.base.base.loc; - call_arg.m_value = x.m_target; - call_args.push_back(al, call_arg); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_const_pi_sym, - basic_const_pi_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: { - std::string new_name = "symbol_set"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg))); - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "s"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); + Vec body; + body.reserve(al, 1); - Vec body; - body.reserve(al, 1); + Vec dep; + dep.reserve(al, 1); - Vec dep; - dep.reserve(al, 1); + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } + // Create the function call statement for basic_const_pi + ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); - ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = x.m_target; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = intrinsic_func->m_args[0]; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_const_pi_sym, + basic_const_pi_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: { + std::string new_name = "symbol_set"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, symbol_set_sym, - symbol_set_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "s"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_add", - x.m_target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } + Vec body; + body.reserve(al, 1); - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_sub", - x.m_target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } + Vec dep; + dep.reserve(al, 1); - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_mul", - x.m_target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_div", - x.m_target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } + ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = loc; + call_arg1.m_value = target; + call_arg2.loc = loc; + call_arg2.m_value = x->m_args[0]; + call_args.push_back(al, call_arg1); + call_args.push_back(al, call_arg2); - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_pow", - x.m_target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, symbol_set_sym, + symbol_set_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: { + process_binary_operator(al, loc, x, module_scope, "basic_add", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { + process_binary_operator(al, loc, x, module_scope, "basic_sub", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { + process_binary_operator(al, loc, x, module_scope, "basic_mul", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { + process_binary_operator(al, loc, x, module_scope, "basic_div", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { + process_binary_operator(al, loc, x, module_scope, "basic_pow", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { + process_binary_operator(al, loc, x, module_scope, "basic_diff", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { + process_unary_operator(al, loc, x, module_scope, "basic_sin", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { + process_unary_operator(al, loc, x, module_scope, "basic_cos", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { + process_unary_operator(al, loc, x, module_scope, "basic_log", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { + process_unary_operator(al, loc, x, module_scope, "basic_exp", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { + process_unary_operator(al, loc, x, module_scope, "basic_abs", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { + process_unary_operator(al, loc, x, module_scope, "basic_expand", target); + break; + } + default: { + throw LCompilersException("IntrinsicFunction: `" + + ASRUtils::get_intrinsic_name(intrinsic_id) + + "` is not implemented"); + } + } + } - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_diff", - x.m_target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_sin", - x.m_target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_cos", - x.m_target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_log", - x.m_target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_exp", - x.m_target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_abs", - x.m_target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_expand", - x.m_target, value); - break; - } - default: { - throw LCompilersException("IntrinsicFunction: `" - + ASRUtils::get_intrinsic_name(intrinsic_id) - + "` is not implemented"); - } - } + void visit_Assignment(const ASR::Assignment_t &x) { + SymbolTable* module_scope = current_scope->parent; + if (ASR::is_a(*x.m_value)) { + ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(x.m_value); + if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) { + process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, x.m_target); } } else if (ASR::is_a(*x.m_value)) { ASR::Cast_t* cast_t = ASR::down_cast(x.m_value); @@ -851,394 +640,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_intrinsic_id; - switch (static_cast(intrinsic_id)) { - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { - std::string new_name = "basic_const_pi"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } - - // Create the function call statement for basic_const_pi - ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = x.base.base.loc; - call_arg.m_value = target; - call_args.push_back(al, call_arg); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_const_pi_sym, - basic_const_pi_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: { - std::string new_name = "symbol_set"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "s"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); + process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target); - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } - - ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = target; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = intrinsic_func->m_args[0]; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, symbol_set_sym, - symbol_set_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_add", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_sub", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_mul", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_div", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_pow", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_diff", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_sin", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_cos", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_log", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_exp", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_abs", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_expand", - target, value); - break; - } - default: { - throw LCompilersException("IntrinsicFunction: `" - + ASRUtils::get_intrinsic_name(intrinsic_id) - + "` is not implemented"); - } - } std::string new_name = "basic_str"; symbolic_dependencies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { @@ -1322,395 +725,9 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x); ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); - switch (static_cast(intrinsic_id)) { - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { - std::string new_name = "basic_const_pi"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } - - // Create the function call statement for basic_const_pi - ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = x.base.base.loc; - call_arg.m_value = target; - call_args.push_back(al, call_arg); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_const_pi_sym, - basic_const_pi_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: { - std::string new_name = "symbol_set"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "s"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } - - ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = target; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = x.m_args[0]; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, symbol_set_sym, - symbol_set_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_add", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_sub", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_mul", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_div", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_pow", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_diff", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_sin", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_cos", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_log", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_exp", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_abs", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_expand", - target, value); - break; - } - default: { - throw LCompilersException("IntrinsicFunction: `" - + ASRUtils::get_intrinsic_name(intrinsic_id) - + "` is not implemented"); - } - } + process_intrinsic_function(al, x.base.base.loc, &xx, module_scope, target); } } @@ -1808,6 +825,35 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*expr)) { + var_sym = ASR::down_cast(expr)->m_v; + } else if (ASR::is_a(*expr)) { + ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(expr); + this->visit_IntrinsicFunction(*intrinsic_func); + var_sym = current_scope->get_symbol(symengine_stack.pop()); + } else if (ASR::is_a(*expr)) { + ASR::Cast_t* cast_t = ASR::down_cast(expr); + this->visit_Cast(*cast_t); + var_sym = current_scope->get_symbol(symengine_stack.pop()); + } + + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym)); + // Now create the FunctionCall node for basic_str + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_str_sym, basic_str_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), nullptr, nullptr)); + return function_call; + } + void visit_Assert(const ASR::Assert_t &x) { if (!ASR::is_a(*x.m_test)) return; ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_test); @@ -1851,108 +897,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitoradd_symbol(s2c(al, new_name), new_symbol); } ASR::symbol_t* basic_str_sym = module_scope->get_symbol(new_name); - - if(ASR::is_a(*s->m_left)) { - ASR::symbol_t *var_sym1 = ASR::down_cast(s->m_left)->m_v; - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - - // Now create the FunctionCall node for basic_str - Vec call_args1; - call_args1.reserve(al, 1); - ASR::call_arg_t call_arg1; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = target; - call_args1.push_back(al, call_arg1); - ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args1.p, call_args1.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - left_tmp = function_call1; - } else if(ASR::is_a(*s->m_left)) { - ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(s->m_left); - this->visit_IntrinsicFunction(*intrinsic_func); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - ASR::expr_t* left_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - - // Now create the FunctionCall node for basic_str - Vec call_args1; - call_args1.reserve(al, 1); - ASR::call_arg_t call_arg1; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = left_var; - call_args1.push_back(al, call_arg1); - ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args1.p, call_args1.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - left_tmp = function_call1; - } else if (ASR::is_a(*s->m_left)) { - ASR::Cast_t* cast_t = ASR::down_cast(s->m_left); - this->visit_Cast(*cast_t); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - ASR::expr_t* left_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - - // Now create the FunctionCall node for basic_str - Vec call_args1; - call_args1.reserve(al, 1); - ASR::call_arg_t call_arg1; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = left_var; - call_args1.push_back(al, call_arg1); - ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args1.p, call_args1.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - left_tmp = function_call1; - } - - if(ASR::is_a(*s->m_right)) { - ASR::symbol_t *var_sym1 = ASR::down_cast(s->m_right)->m_v; - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - - // Now create the FunctionCall node for basic_str - Vec call_args2; - call_args2.reserve(al, 1); - ASR::call_arg_t call_arg2; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = target; - call_args2.push_back(al, call_arg2); - ASR::expr_t* function_call2 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args2.p, call_args2.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - right_tmp = function_call2; - } else if(ASR::is_a(*s->m_right)) { - ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(s->m_right); - this->visit_IntrinsicFunction(*intrinsic_func); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - ASR::expr_t* right_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - - // Now create the FunctionCall node for basic_str - Vec call_args2; - call_args2.reserve(al, 1); - ASR::call_arg_t call_arg2; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = right_var; - call_args2.push_back(al, call_arg2); - ASR::expr_t* function_call2 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args2.p, call_args2.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - right_tmp = function_call2; - } else if (ASR::is_a(*s->m_right)) { - ASR::Cast_t* cast_t = ASR::down_cast(s->m_right); - this->visit_Cast(*cast_t); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - ASR::expr_t* right_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - - // Now create the FunctionCall node for basic_str - Vec call_args2; - call_args2.reserve(al, 1); - ASR::call_arg_t call_arg2; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = right_var; - call_args2.push_back(al, call_arg2); - ASR::expr_t* function_call2 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args2.p, call_args2.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - right_tmp = function_call2; - } + left_tmp = process_with_basic_str(al, x.base.base.loc, s->m_left, basic_str_sym); + right_tmp = process_with_basic_str(al, x.base.base.loc, s->m_right, basic_str_sym); ASR::expr_t* test = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp, s->m_op, right_tmp, s->m_type, s->m_value));