From c9de942f1ab59e3e5ec9a8d720ce954acd94f226 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Thu, 10 Aug 2023 11:28:52 +0530 Subject: [PATCH 1/6] Refactoring the symbolic ASR pass --- src/libasr/pass/replace_symbolic.cpp | 831 +++------------------------ 1 file changed, 83 insertions(+), 748 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index cf003ae577..a76f6fef17 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -7,9 +7,6 @@ #include #include -#include - - namespace LCompilers { using ASR::down_cast; @@ -294,6 +291,51 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]; + ASR::expr_t* value2 = x->m_args[1]; + + if (ASR::is_a(*value1) || ASR::is_a(*value1)) { + if (ASR::is_a(*value1)) { + this->visit_IntrinsicFunction(*ASR::down_cast(value1)); + } else { + this->visit_Cast(*ASR::down_cast(value1)); + } + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym1)); + } + + if (ASR::is_a(*value2) || ASR::is_a(*value2)) { + if (ASR::is_a(*value2)) { + this->visit_IntrinsicFunction(*ASR::down_cast(value2)); + } else { + this->visit_Cast(*ASR::down_cast(value2)); + } + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym2)); + } + + perform_symbolic_binary_operation(al, loc, module_scope, new_name, target, value1, value2); + } + + void process_unary_operator(Allocator &al, const Location &loc, ASR::IntrinsicFunction_t* x, SymbolTable* module_scope, + const std::string& new_name, ASR::expr_t* target) { + ASR::expr_t* value1 = x->m_args[0]; + + if (ASR::is_a(*value1) || ASR::is_a(*value1)) { + if (ASR::is_a(*value1)) { + this->visit_IntrinsicFunction(*ASR::down_cast(value1)); + } else { + this->visit_Cast(*ASR::down_cast(value1)); + } + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym1)); + } + + perform_symbolic_unary_operation(al, loc, module_scope, new_name, target, value1); + } + void visit_Assignment(const ASR::Assignment_t &x) { SymbolTable* module_scope = current_scope->parent; if (ASR::is_a(*x.m_value)) { @@ -400,291 +442,51 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_add", - x.m_target, value1, value2); + process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_add", x.m_target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_sub", - x.m_target, value1, value2); + process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_sub", x.m_target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_mul", - x.m_target, value1, value2); + process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_mul", x.m_target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_div", - x.m_target, value1, value2); + process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_div", x.m_target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_pow", - x.m_target, value1, value2); + process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_pow", x.m_target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_diff", - x.m_target, value1, value2); + process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_diff", x.m_target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_sin", - x.m_target, value); + process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_sin", x.m_target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_cos", - x.m_target, value); + process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_cos", x.m_target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_log", - x.m_target, value); + process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_log", x.m_target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_exp", - x.m_target, value); + process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_exp", x.m_target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_abs", - x.m_target, value); + process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_abs", x.m_target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_expand", - x.m_target, value); + process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_expand", x.m_target); break; } default: { @@ -952,285 +754,51 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_add", - target, value1, value2); + process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_add", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_sub", - target, value1, value2); + process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_sub", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_mul", - target, value1, value2); + process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_mul", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_div", - target, value1, value2); + process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_div", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_pow", - target, value1, value2); + process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_pow", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { - ASR::expr_t* value1 = intrinsic_func->m_args[0]; - ASR::expr_t* value2 = intrinsic_func->m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_diff", - target, value1, value2); + process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_diff", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_sin", - target, value); + process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_sin", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_cos", - target, value); + process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_cos", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_log", - target, value); + process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_log", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_exp", - target, value); + process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_exp", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_abs", - target, value); + process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_abs", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { - ASR::expr_t* value = intrinsic_func->m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t* s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_expand", - target, value); + process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_expand", target); break; } default: { @@ -1322,7 +890,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x); + int64_t intrinsic_id = xx.m_intrinsic_id; ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); switch (static_cast(intrinsic_id)) { case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { @@ -1424,285 +993,51 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_add", - target, value1, value2); + process_binary_operator(al, x.base.base.loc, &xx, module_scope, "basic_add", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_sub", - target, value1, value2); + process_binary_operator(al, x.base.base.loc, &xx, module_scope, "basic_sub", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_mul", - target, value1, value2); + process_binary_operator(al, x.base.base.loc, &xx, module_scope, "basic_mul", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_div", - target, value1, value2); + process_binary_operator(al, x.base.base.loc, &xx, module_scope, "basic_div", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_pow", - target, value1, value2); + process_binary_operator(al, x.base.base.loc, &xx, module_scope, "basic_pow", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } - - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_diff", - target, value1, value2); + process_binary_operator(al, x.base.base.loc, &xx, module_scope, "basic_diff", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_sin", - target, value); + process_unary_operator(al, x.base.base.loc, &xx, module_scope, "basic_sin", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_cos", - target, value); + process_unary_operator(al, x.base.base.loc, &xx, module_scope, "basic_cos", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_log", - target, value); + process_unary_operator(al, x.base.base.loc, &xx, module_scope, "basic_log", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_exp", - target, value); + process_unary_operator(al, x.base.base.loc, &xx, module_scope, "basic_exp", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_abs", - target, value); + process_unary_operator(al, x.base.base.loc, &xx, module_scope, "basic_abs", target); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_expand", - target, value); + process_unary_operator(al, x.base.base.loc, &xx, module_scope, "basic_expand", target); break; } default: { From 6e5a44a21983a516c5037b9253af76b914d9c029 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Thu, 10 Aug 2023 12:13:55 +0530 Subject: [PATCH 2/6] refactored switch case --- src/libasr/pass/replace_symbolic.cpp | 610 +++++++-------------------- 1 file changed, 155 insertions(+), 455 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index a76f6fef17..b8c6884b9b 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -336,165 +336,170 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorparent; - if (ASR::is_a(*x.m_value)) { - ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(x.m_value); - int64_t intrinsic_id = intrinsic_func->m_intrinsic_id; - if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) { - switch (static_cast(intrinsic_id)) { - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { - std::string new_name = "basic_const_pi"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg))); + void process_intrinsic_function(Allocator &al, const Location &loc, ASR::IntrinsicFunction_t* x, SymbolTable* module_scope, + ASR::expr_t* target){ + int64_t intrinsic_id = x->m_intrinsic_id; + switch (static_cast(intrinsic_id)) { + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { + std::string new_name = "basic_const_pi"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); - Vec body; - body.reserve(al, 1); + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg))); - Vec dep; - dep.reserve(al, 1); + Vec body; + body.reserve(al, 1); - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } + Vec dep; + dep.reserve(al, 1); - // Create the function call statement for basic_const_pi - ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = x.base.base.loc; - call_arg.m_value = x.m_target; - call_args.push_back(al, call_arg); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_const_pi_sym, - basic_const_pi_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: { - std::string new_name = "symbol_set"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "s"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); + // Create the function call statement for basic_const_pi + ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); - Vec body; - body.reserve(al, 1); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_const_pi_sym, + basic_const_pi_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: { + std::string new_name = "symbol_set"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); - Vec dep; - dep.reserve(al, 1); + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "s"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } + Vec body; + body.reserve(al, 1); - ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = x.m_target; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = intrinsic_func->m_args[0]; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); + Vec dep; + dep.reserve(al, 1); - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, symbol_set_sym, - symbol_set_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: { - process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_add", x.m_target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { - process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_sub", x.m_target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { - process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_mul", x.m_target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { - process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_div", x.m_target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { - process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_pow", x.m_target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { - process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_diff", x.m_target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { - process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_sin", x.m_target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { - process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_cos", x.m_target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { - process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_log", x.m_target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { - process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_exp", x.m_target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { - process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_abs", x.m_target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { - process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_expand", x.m_target); - break; - } - default: { - throw LCompilersException("IntrinsicFunction: `" - + ASRUtils::get_intrinsic_name(intrinsic_id) - + "` is not implemented"); - } + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); } + + ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = loc; + call_arg1.m_value = target; + call_arg2.loc = loc; + call_arg2.m_value = x->m_args[0]; + call_args.push_back(al, call_arg1); + call_args.push_back(al, call_arg2); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, symbol_set_sym, + symbol_set_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: { + process_binary_operator(al, loc, x, module_scope, "basic_add", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { + process_binary_operator(al, loc, x, module_scope, "basic_sub", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { + process_binary_operator(al, loc, x, module_scope, "basic_mul", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { + process_binary_operator(al, loc, x, module_scope, "basic_div", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { + process_binary_operator(al, loc, x, module_scope, "basic_pow", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { + process_binary_operator(al, loc, x, module_scope, "basic_diff", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { + process_unary_operator(al, loc, x, module_scope, "basic_sin", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { + process_unary_operator(al, loc, x, module_scope, "basic_cos", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { + process_unary_operator(al, loc, x, module_scope, "basic_log", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { + process_unary_operator(al, loc, x, module_scope, "basic_exp", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { + process_unary_operator(al, loc, x, module_scope, "basic_abs", target); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { + process_unary_operator(al, loc, x, module_scope, "basic_expand", target); + break; + } + default: { + throw LCompilersException("IntrinsicFunction: `" + + ASRUtils::get_intrinsic_name(intrinsic_id) + + "` is not implemented"); + } + } + } + + void visit_Assignment(const ASR::Assignment_t &x) { + SymbolTable* module_scope = current_scope->parent; + if (ASR::is_a(*x.m_value)) { + ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(x.m_value); + if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) { + process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, x.m_target); } } else if (ASR::is_a(*x.m_value)) { ASR::Cast_t* cast_t = ASR::down_cast(x.m_value); @@ -653,160 +658,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_intrinsic_id; - switch (static_cast(intrinsic_id)) { - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { - std::string new_name = "basic_const_pi"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); + process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target); - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } - - // Create the function call statement for basic_const_pi - ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = x.base.base.loc; - call_arg.m_value = target; - call_args.push_back(al, call_arg); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_const_pi_sym, - basic_const_pi_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: { - std::string new_name = "symbol_set"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "s"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } - - ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = target; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = intrinsic_func->m_args[0]; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, symbol_set_sym, - symbol_set_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: { - process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_add", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { - process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_sub", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { - process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_mul", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { - process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_div", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { - process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_pow", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { - process_binary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_diff", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { - process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_sin", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { - process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_cos", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { - process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_log", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { - process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_exp", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { - process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_abs", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { - process_unary_operator(al, x.base.base.loc, intrinsic_func, module_scope, "basic_expand", target); - break; - } - default: { - throw LCompilersException("IntrinsicFunction: `" - + ASRUtils::get_intrinsic_name(intrinsic_id) - + "` is not implemented"); - } - } std::string new_name = "basic_str"; symbolic_dependencies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { @@ -891,161 +744,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x); - int64_t intrinsic_id = xx.m_intrinsic_id; ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); - switch (static_cast(intrinsic_id)) { - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { - std::string new_name = "basic_const_pi"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } - - // Create the function call statement for basic_const_pi - ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = x.base.base.loc; - call_arg.m_value = target; - call_args.push_back(al, call_arg); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_const_pi_sym, - basic_const_pi_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: { - std::string new_name = "symbol_set"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "s"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } - - ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = target; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = x.m_args[0]; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, symbol_set_sym, - symbol_set_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: { - process_binary_operator(al, x.base.base.loc, &xx, module_scope, "basic_add", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { - process_binary_operator(al, x.base.base.loc, &xx, module_scope, "basic_sub", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { - process_binary_operator(al, x.base.base.loc, &xx, module_scope, "basic_mul", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { - process_binary_operator(al, x.base.base.loc, &xx, module_scope, "basic_div", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { - process_binary_operator(al, x.base.base.loc, &xx, module_scope, "basic_pow", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { - process_binary_operator(al, x.base.base.loc, &xx, module_scope, "basic_diff", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { - process_unary_operator(al, x.base.base.loc, &xx, module_scope, "basic_sin", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { - process_unary_operator(al, x.base.base.loc, &xx, module_scope, "basic_cos", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { - process_unary_operator(al, x.base.base.loc, &xx, module_scope, "basic_log", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { - process_unary_operator(al, x.base.base.loc, &xx, module_scope, "basic_exp", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { - process_unary_operator(al, x.base.base.loc, &xx, module_scope, "basic_abs", target); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { - process_unary_operator(al, x.base.base.loc, &xx, module_scope, "basic_expand", target); - break; - } - default: { - throw LCompilersException("IntrinsicFunction: `" - + ASRUtils::get_intrinsic_name(intrinsic_id) - + "` is not implemented"); - } - } + process_intrinsic_function(al, x.base.base.loc, &xx, module_scope, target); } } From f49d3cefc6d749df4b250cada20b1d8177a99743 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Thu, 10 Aug 2023 13:16:53 +0530 Subject: [PATCH 3/6] Removed repetetive code from visit_assert --- src/libasr/pass/replace_symbolic.cpp | 133 +++++++-------------------- 1 file changed, 31 insertions(+), 102 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index b8c6884b9b..6a0976c989 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -843,6 +843,35 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*expr)) { + var_sym = ASR::down_cast(expr)->m_v; + } else if (ASR::is_a(*expr)) { + ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(expr); + this->visit_IntrinsicFunction(*intrinsic_func); + var_sym = current_scope->get_symbol(symengine_stack.pop()); + } else if (ASR::is_a(*expr)) { + ASR::Cast_t* cast_t = ASR::down_cast(expr); + this->visit_Cast(*cast_t); + var_sym = current_scope->get_symbol(symengine_stack.pop()); + } + + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym)); + // Now create the FunctionCall node for basic_str + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_str_sym, basic_str_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), nullptr, nullptr)); + return function_call; + } + void visit_Assert(const ASR::Assert_t &x) { if (!ASR::is_a(*x.m_test)) return; ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_test); @@ -886,108 +915,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitoradd_symbol(s2c(al, new_name), new_symbol); } ASR::symbol_t* basic_str_sym = module_scope->get_symbol(new_name); - - if(ASR::is_a(*s->m_left)) { - ASR::symbol_t *var_sym1 = ASR::down_cast(s->m_left)->m_v; - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - - // Now create the FunctionCall node for basic_str - Vec call_args1; - call_args1.reserve(al, 1); - ASR::call_arg_t call_arg1; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = target; - call_args1.push_back(al, call_arg1); - ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args1.p, call_args1.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - left_tmp = function_call1; - } else if(ASR::is_a(*s->m_left)) { - ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(s->m_left); - this->visit_IntrinsicFunction(*intrinsic_func); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - ASR::expr_t* left_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - - // Now create the FunctionCall node for basic_str - Vec call_args1; - call_args1.reserve(al, 1); - ASR::call_arg_t call_arg1; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = left_var; - call_args1.push_back(al, call_arg1); - ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args1.p, call_args1.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - left_tmp = function_call1; - } else if (ASR::is_a(*s->m_left)) { - ASR::Cast_t* cast_t = ASR::down_cast(s->m_left); - this->visit_Cast(*cast_t); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - ASR::expr_t* left_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - - // Now create the FunctionCall node for basic_str - Vec call_args1; - call_args1.reserve(al, 1); - ASR::call_arg_t call_arg1; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = left_var; - call_args1.push_back(al, call_arg1); - ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args1.p, call_args1.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - left_tmp = function_call1; - } - - if(ASR::is_a(*s->m_right)) { - ASR::symbol_t *var_sym1 = ASR::down_cast(s->m_right)->m_v; - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - - // Now create the FunctionCall node for basic_str - Vec call_args2; - call_args2.reserve(al, 1); - ASR::call_arg_t call_arg2; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = target; - call_args2.push_back(al, call_arg2); - ASR::expr_t* function_call2 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args2.p, call_args2.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - right_tmp = function_call2; - } else if(ASR::is_a(*s->m_right)) { - ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(s->m_right); - this->visit_IntrinsicFunction(*intrinsic_func); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - ASR::expr_t* right_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - - // Now create the FunctionCall node for basic_str - Vec call_args2; - call_args2.reserve(al, 1); - ASR::call_arg_t call_arg2; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = right_var; - call_args2.push_back(al, call_arg2); - ASR::expr_t* function_call2 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args2.p, call_args2.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - right_tmp = function_call2; - } else if (ASR::is_a(*s->m_right)) { - ASR::Cast_t* cast_t = ASR::down_cast(s->m_right); - this->visit_Cast(*cast_t); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - ASR::expr_t* right_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - - // Now create the FunctionCall node for basic_str - Vec call_args2; - call_args2.reserve(al, 1); - ASR::call_arg_t call_arg2; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = right_var; - call_args2.push_back(al, call_arg2); - ASR::expr_t* function_call2 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args2.p, call_args2.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - right_tmp = function_call2; - } + left_tmp = process_with_basic_str(al, x.base.base.loc, s->m_left, basic_str_sym); + right_tmp = process_with_basic_str(al, x.base.base.loc, s->m_right, basic_str_sym); ASR::expr_t* test = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp, s->m_op, right_tmp, s->m_type, s->m_value)); From 4c1d8429091023d5b3e4c3a8d5ba0a9a86a46dd4 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Thu, 10 Aug 2023 13:23:02 +0530 Subject: [PATCH 4/6] Fixed failing test --- src/libasr/pass/replace_symbolic.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 6a0976c989..db967bd0f3 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -845,7 +845,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*expr)) { var_sym = ASR::down_cast(expr)->m_v; } else if (ASR::is_a(*expr)) { From 287086b5aa322e81cbe2c2d71b9a1e30828ce598 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Thu, 10 Aug 2023 13:40:54 +0530 Subject: [PATCH 5/6] Implemented handle argument function --- src/libasr/pass/replace_symbolic.cpp | 55 +++++++++------------------- 1 file changed, 18 insertions(+), 37 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index db967bd0f3..7a039fb3be 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -291,49 +291,30 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]; - ASR::expr_t* value2 = x->m_args[1]; - - if (ASR::is_a(*value1) || ASR::is_a(*value1)) { - if (ASR::is_a(*value1)) { - this->visit_IntrinsicFunction(*ASR::down_cast(value1)); - } else { - this->visit_Cast(*ASR::down_cast(value1)); - } - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym1)); - } - - if (ASR::is_a(*value2) || ASR::is_a(*value2)) { - if (ASR::is_a(*value2)) { - this->visit_IntrinsicFunction(*ASR::down_cast(value2)); - } else { - this->visit_Cast(*ASR::down_cast(value2)); - } - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym2)); + ASR::expr_t* handle_argument(Allocator &al, const Location &loc, ASR::expr_t* arg) { + if (ASR::is_a(*arg) || ASR::is_a(*arg)) { + if (ASR::is_a(*arg)) { + this->visit_IntrinsicFunction(*ASR::down_cast(arg)); + } else if (ASR::is_a(*arg)) { + this->visit_Cast(*ASR::down_cast(arg)); } + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + return ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym)); + } + return arg; + } - perform_symbolic_binary_operation(al, loc, module_scope, new_name, target, value1, value2); + void process_binary_operator(Allocator &al, const Location &loc, ASR::IntrinsicFunction_t* x, SymbolTable* module_scope, + const std::string& new_name, ASR::expr_t* target) { + ASR::expr_t* value1 = handle_argument(al, loc, x->m_args[0]); + ASR::expr_t* value2 = handle_argument(al, loc, x->m_args[1]); + perform_symbolic_binary_operation(al, loc, module_scope, new_name, target, value1, value2); } void process_unary_operator(Allocator &al, const Location &loc, ASR::IntrinsicFunction_t* x, SymbolTable* module_scope, const std::string& new_name, ASR::expr_t* target) { - ASR::expr_t* value1 = x->m_args[0]; - - if (ASR::is_a(*value1) || ASR::is_a(*value1)) { - if (ASR::is_a(*value1)) { - this->visit_IntrinsicFunction(*ASR::down_cast(value1)); - } else { - this->visit_Cast(*ASR::down_cast(value1)); - } - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym1)); - } - - perform_symbolic_unary_operation(al, loc, module_scope, new_name, target, value1); + ASR::expr_t* value1 = handle_argument(al, loc, x->m_args[0]); + perform_symbolic_unary_operation(al, loc, module_scope, new_name, target, value1); } void process_intrinsic_function(Allocator &al, const Location &loc, ASR::IntrinsicFunction_t* x, SymbolTable* module_scope, From f9205b1f1213f109a315a1462bc9369ff4411d5b Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Fri, 11 Aug 2023 09:18:26 +0530 Subject: [PATCH 6/6] structured handle_argument function --- src/libasr/pass/replace_symbolic.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 7a039fb3be..b30fca76a3 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -292,16 +292,17 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*arg) || ASR::is_a(*arg)) { - if (ASR::is_a(*arg)) { - this->visit_IntrinsicFunction(*ASR::down_cast(arg)); - } else if (ASR::is_a(*arg)) { - this->visit_Cast(*ASR::down_cast(arg)); - } - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - return ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym)); + if (ASR::is_a(*arg)) { + return arg; + } else if (ASR::is_a(*arg)) { + this->visit_IntrinsicFunction(*ASR::down_cast(arg)); + } else if (ASR::is_a(*arg)) { + this->visit_Cast(*ASR::down_cast(arg)); + } else { + LCOMPILERS_ASSERT(false); } - return arg; + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + return ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym)); } void process_binary_operator(Allocator &al, const Location &loc, ASR::IntrinsicFunction_t* x, SymbolTable* module_scope,