Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f76dff7

Browse files
authored
Merge pull request #2335 from anutosh491/Implement_visit_SubroutineCall
Implemented `visit_SubroutineCall` for the symbolic pass
2 parents f9b09dd + 37cc6eb commit f76dff7

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

integration_tests/symbolics_09.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sympy import Symbol, pi, S
1+
from sympy import Symbol, pi, sin, cos
22
from lpython import S, i32
33

44
def addInteger(x: S, y: S, z: S, i: i32):
@@ -9,7 +9,11 @@ def call_addInteger():
99
a: S = Symbol("x")
1010
b: S = Symbol("y")
1111
c: S = pi
12-
addInteger(a, b, c, 2)
12+
d: S = sin(a)
13+
e: S = cos(b)
14+
addInteger(c, d, e, 2)
15+
addInteger(c, sin(a), cos(b), 2)
16+
addInteger(pi, sin(Symbol("x")), cos(Symbol("y")), 2)
1317

1418
def main0():
1519
call_addInteger()

src/libasr/pass/replace_symbolic.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,56 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
695695
}
696696
}
697697

698+
void visit_SubroutineCall(const ASR::SubroutineCall_t &x) {
699+
SymbolTable* module_scope = current_scope->parent;
700+
Vec<ASR::call_arg_t> call_args;
701+
call_args.reserve(al, 1);
702+
703+
for (size_t i=0; i<x.n_args; i++) {
704+
ASR::expr_t* val = x.m_args[i].m_value;
705+
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*val) && ASR::is_a<ASR::SymbolicExpression_t>(*ASRUtils::expr_type(val))) {
706+
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(val);
707+
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc));
708+
std::string symengine_var = symengine_stack.push();
709+
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
710+
al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local,
711+
nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr,
712+
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
713+
current_scope->add_symbol(s2c(al, symengine_var), arg);
714+
for (auto &item : current_scope->get_scope()) {
715+
if (ASR::is_a<ASR::Variable_t>(*item.second)) {
716+
ASR::Variable_t *s = ASR::down_cast<ASR::Variable_t>(item.second);
717+
this->visit_Variable(*s);
718+
}
719+
}
720+
721+
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg));
722+
process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target);
723+
724+
ASR::call_arg_t call_arg;
725+
call_arg.loc = x.base.base.loc;
726+
call_arg.m_value = target;
727+
call_args.push_back(al, call_arg);
728+
} else if (ASR::is_a<ASR::Cast_t>(*val)) {
729+
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
730+
if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return;
731+
this->visit_Cast(*cast_t);
732+
ASR::symbol_t *var_sym = current_scope->get_symbol(symengine_stack.pop());
733+
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym));
734+
735+
ASR::call_arg_t call_arg;
736+
call_arg.loc = x.base.base.loc;
737+
call_arg.m_value = target;
738+
call_args.push_back(al, call_arg);
739+
} else {
740+
call_args.push_back(al, x.m_args[i]);
741+
}
742+
}
743+
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, x.m_name,
744+
x.m_name, call_args.p, call_args.n, nullptr));
745+
pass_result.push_back(al, stmt);
746+
}
747+
698748
void visit_Print(const ASR::Print_t &x) {
699749
std::vector<ASR::expr_t*> print_tmp;
700750
SymbolTable* module_scope = current_scope->parent;
@@ -739,6 +789,25 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
739789
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg));
740790
process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target);
741791

792+
// Now create the FunctionCall node for basic_str
793+
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
794+
Vec<ASR::call_arg_t> call_args;
795+
call_args.reserve(al, 1);
796+
ASR::call_arg_t call_arg;
797+
call_arg.loc = x.base.base.loc;
798+
call_arg.m_value = target;
799+
call_args.push_back(al, call_arg);
800+
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
801+
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
802+
ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr));
803+
print_tmp.push_back(function_call);
804+
} else if (ASR::is_a<ASR::Cast_t>(*val)) {
805+
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
806+
if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return;
807+
this->visit_Cast(*cast_t);
808+
ASR::symbol_t *var_sym = current_scope->get_symbol(symengine_stack.pop());
809+
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym));
810+
742811
// Now create the FunctionCall node for basic_str
743812
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
744813
Vec<ASR::call_arg_t> call_args;

0 commit comments

Comments
 (0)