Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ecc18dc

Browse files
committed
Added support & tests for subs attribute
1 parent e378b4e commit ecc18dc

File tree

6 files changed

+109
-5
lines changed

6 files changed

+109
-5
lines changed

integration_tests/symbolics_05.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,12 @@ def test_operations():
4040
assert(c.args[0] == x)
4141
assert(d.args[0] == x)
4242

43+
# test subs
44+
b1: S = b.subs(x, y)
45+
b1 = b1.subs(z, y)
46+
assert(a.subs(x, y) == S(4)*y**S(2))
47+
assert(b1 == S(27)*y**S(3))
48+
assert(c.subs(x, y) == sin(y))
49+
assert(d.subs(x, z) == cos(z))
50+
4351
test_operations()

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ inline std::string get_intrinsic_name(int x) {
148148
INTRINSIC_NAME_CASE(SymbolicInteger)
149149
INTRINSIC_NAME_CASE(SymbolicDiff)
150150
INTRINSIC_NAME_CASE(SymbolicExpand)
151+
INTRINSIC_NAME_CASE(SymbolicSubs)
151152
INTRINSIC_NAME_CASE(SymbolicSin)
152153
INTRINSIC_NAME_CASE(SymbolicCos)
153154
INTRINSIC_NAME_CASE(SymbolicLog)
@@ -435,6 +436,8 @@ namespace IntrinsicElementalFunctionRegistry {
435436
{nullptr, &SymbolicDiff::verify_args}},
436437
{static_cast<int64_t>(IntrinsicElementalFunctions::SymbolicExpand),
437438
{nullptr, &SymbolicExpand::verify_args}},
439+
{static_cast<int64_t>(IntrinsicElementalFunctions::SymbolicSubs),
440+
{nullptr, &SymbolicSubs::verify_args}},
438441
{static_cast<int64_t>(IntrinsicElementalFunctions::SymbolicSin),
439442
{nullptr, &SymbolicSin::verify_args}},
440443
{static_cast<int64_t>(IntrinsicElementalFunctions::SymbolicCos),
@@ -724,6 +727,8 @@ namespace IntrinsicElementalFunctionRegistry {
724727
"SymbolicDiff"},
725728
{static_cast<int64_t>(IntrinsicElementalFunctions::SymbolicExpand),
726729
"SymbolicExpand"},
730+
{static_cast<int64_t>(IntrinsicElementalFunctions::SymbolicSubs),
731+
"SymbolicSubs"},
727732
{static_cast<int64_t>(IntrinsicElementalFunctions::SymbolicSin),
728733
"SymbolicSin"},
729734
{static_cast<int64_t>(IntrinsicElementalFunctions::SymbolicCos),
@@ -889,6 +894,7 @@ namespace IntrinsicElementalFunctionRegistry {
889894
{"SymbolicInteger", {&SymbolicInteger::create_SymbolicInteger, &SymbolicInteger::eval_SymbolicInteger}},
890895
{"diff", {&SymbolicDiff::create_SymbolicDiff, &SymbolicDiff::eval_SymbolicDiff}},
891896
{"expand", {&SymbolicExpand::create_SymbolicExpand, &SymbolicExpand::eval_SymbolicExpand}},
897+
{"subs", {&SymbolicSubs::create_SymbolicSubs, &SymbolicSubs::eval_SymbolicSubs}},
892898
{"SymbolicSin", {&SymbolicSin::create_SymbolicSin, &SymbolicSin::eval_SymbolicSin}},
893899
{"SymbolicCos", {&SymbolicCos::create_SymbolicCos, &SymbolicCos::eval_SymbolicCos}},
894900
{"SymbolicLog", {&SymbolicLog::create_SymbolicLog, &SymbolicLog::eval_SymbolicLog}},

src/libasr/pass/intrinsic_functions.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ enum class IntrinsicElementalFunctions : int64_t {
149149
SymbolicInteger,
150150
SymbolicDiff,
151151
SymbolicExpand,
152+
SymbolicSubs,
152153
SymbolicSin,
153154
SymbolicCos,
154155
SymbolicLog,
@@ -5651,6 +5652,64 @@ create_symbolic_binary_macro(SymbolicDiv)
56515652
create_symbolic_binary_macro(SymbolicPow)
56525653
create_symbolic_binary_macro(SymbolicDiff)
56535654

5655+
#define create_symbolic_ternary_macro(X) \
5656+
namespace X{ \
5657+
static inline void verify_args(const ASR::IntrinsicElementalFunction_t& x, \
5658+
diag::Diagnostics& diagnostics) { \
5659+
ASRUtils::require_impl(x.n_args == 3, "Intrinsic function `"#X"` accepts" \
5660+
"exactly 3 arguments", x.base.base.loc, diagnostics); \
5661+
\
5662+
ASR::ttype_t* arg1_type = ASRUtils::expr_type(x.m_args[0]); \
5663+
ASR::ttype_t* arg2_type = ASRUtils::expr_type(x.m_args[1]); \
5664+
ASR::ttype_t* arg3_type = ASRUtils::expr_type(x.m_args[2]); \
5665+
\
5666+
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*arg1_type) && \
5667+
ASR::is_a<ASR::SymbolicExpression_t>(*arg2_type) && \
5668+
ASR::is_a<ASR::SymbolicExpression_t>(*arg3_type), \
5669+
"All arguments of `"#X"` must be of type SymbolicExpression", \
5670+
x.base.base.loc, diagnostics); \
5671+
} \
5672+
\
5673+
static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \
5674+
ASR::ttype_t *, Vec<ASR::expr_t*> &/*args*/, diag::Diagnostics& /*diag*/) { \
5675+
/*TODO*/ \
5676+
return nullptr; \
5677+
} \
5678+
\
5679+
static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \
5680+
Vec<ASR::expr_t*>& args, \
5681+
diag::Diagnostics& diag) { \
5682+
if (args.size() != 3) { \
5683+
append_error(diag, "Intrinsic function `"#X"` accepts exactly 3 arguments", \
5684+
loc); \
5685+
return nullptr; \
5686+
} \
5687+
\
5688+
for (size_t i = 0; i < args.size(); i++) { \
5689+
ASR::ttype_t* argtype = ASRUtils::expr_type(args[i]); \
5690+
if(!ASR::is_a<ASR::SymbolicExpression_t>(*argtype)) { \
5691+
append_error(diag, \
5692+
"Arguments of `"#X"` function must be of type SymbolicExpression", \
5693+
args[i]->base.loc); \
5694+
return nullptr; \
5695+
} \
5696+
} \
5697+
\
5698+
Vec<ASR::expr_t*> arg_values; \
5699+
arg_values.reserve(al, args.size()); \
5700+
for( size_t i = 0; i < args.size(); i++ ) { \
5701+
arg_values.push_back(al, ASRUtils::expr_value(args[i])); \
5702+
} \
5703+
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc)); \
5704+
ASR::expr_t* compile_time_value = eval_##X(al, loc, to_type, arg_values, diag); \
5705+
return ASR::make_IntrinsicElementalFunction_t(al, loc, \
5706+
static_cast<int64_t>(IntrinsicElementalFunctions::X), \
5707+
args.p, args.size(), 0, to_type, compile_time_value); \
5708+
} \
5709+
} // namespace X
5710+
5711+
create_symbolic_ternary_macro(SymbolicSubs)
5712+
56545713
#define create_symbolic_constants_macro(X) \
56555714
namespace X { \
56565715
static inline void verify_args(const ASR::IntrinsicElementalFunction_t& x, \

src/libasr/pass/replace_symbolic.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
5757
"basic_const_"#name, target)); \
5858
break; }
5959

60+
#define BASIC_TERNARYOP(SYM, name) \
61+
case LCompilers::ASRUtils::IntrinsicElementalFunctions::Symbolic##SYM: { \
62+
pass_result.push_back(al, basic_ternaryop(loc, "basic_"#name, target, \
63+
x->m_args[0], x->m_args[1], x->m_args[2])); \
64+
break; }
65+
6066
#define BASIC_BINOP(SYM, name) \
6167
case LCompilers::ASRUtils::IntrinsicElementalFunctions::Symbolic##SYM: { \
6268
pass_result.push_back(al, basic_binop(loc, "basic_"#name, target, \
@@ -241,6 +247,16 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
241247
return SubroutineCall(loc, basic_const_sym, {value});
242248
}
243249

250+
ASR::stmt_t *basic_ternaryop(const Location &loc, const std::string &fn_name,
251+
ASR::expr_t* target, ASR::expr_t* op_01, ASR::expr_t* op_02, ASR::expr_t* op_03) {
252+
ASR::ttype_t *cptr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, loc));
253+
ASR::symbol_t* basic_ternaryop_sym = create_bindc_function(loc, fn_name,
254+
{cptr_type, cptr_type, cptr_type, cptr_type});
255+
return SubroutineCall(loc, basic_ternaryop_sym, {target,
256+
handle_argument(al, loc, op_01), handle_argument(al, loc, op_02),
257+
handle_argument(al, loc, op_03)});
258+
}
259+
244260
ASR::stmt_t *basic_binop(const Location &loc, const std::string &fn_name,
245261
ASR::expr_t* target, ASR::expr_t* op_01, ASR::expr_t* op_02) {
246262
ASR::ttype_t *cptr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, loc));
@@ -462,6 +478,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
462478
BASIC_UNARYOP(Exp, exp)
463479
BASIC_UNARYOP(Abs, abs)
464480
BASIC_UNARYOP(Expand, expand)
481+
BASIC_TERNARYOP(Subs, subs2)
465482
case LCompilers::ASRUtils::IntrinsicElementalFunctions::SymbolicGetArgument: {
466483
// Define necessary function symbols
467484
ASR::expr_t* value1 = handle_argument(al, loc, x->m_args[0]);

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7565,7 +7565,7 @@ we will have to use something else.
75657565
} else {
75667566
st = current_scope->resolve_symbol(mod_name);
75677567
std::set<std::string> symbolic_attributes = {
7568-
"diff", "expand", "has"
7568+
"diff", "expand", "has", "subs"
75697569
};
75707570
std::set<std::string> symbolic_constants = {
75717571
"pi", "E", "oo"
@@ -7640,7 +7640,7 @@ we will have to use something else.
76407640
} else if (AST::is_a<AST::BinOp_t>(*at->m_value)) {
76417641
AST::BinOp_t* bop = AST::down_cast<AST::BinOp_t>(at->m_value);
76427642
std::set<std::string> symbolic_attributes = {
7643-
"diff", "expand", "has"
7643+
"diff", "expand", "has", "subs"
76447644
};
76457645
if (symbolic_attributes.find(at->m_attr) != symbolic_attributes.end()){
76467646
switch (bop->m_op) {
@@ -7687,7 +7687,7 @@ we will have to use something else.
76877687
} else if (AST::is_a<AST::Call_t>(*at->m_value)) {
76887688
AST::Call_t* call = AST::down_cast<AST::Call_t>(at->m_value);
76897689
std::set<std::string> symbolic_attributes = {
7690-
"diff", "expand", "has"
7690+
"diff", "expand", "has", "subs"
76917691
};
76927692
if (symbolic_attributes.find(at->m_attr) != symbolic_attributes.end()){
76937693
std::set<std::string> symbolic_functions = {
@@ -7819,7 +7819,7 @@ we will have to use something else.
78197819
if (!s) {
78207820
std::string intrinsic_name = call_name;
78217821
std::set<std::string> not_cpython_builtin = {
7822-
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand", "trunc", "fix",
7822+
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand", "trunc", "fix", "subs",
78237823
"sum" // For sum called over lists
78247824
};
78257825
std::set<std::string> symbolic_functions = {

src/lpython/semantics/python_attribute_eval.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ struct AttributeHandler {
4949
{"expand", &eval_symbolic_expand},
5050
{"has", &eval_symbolic_has_symbol},
5151
{"is_integer", &eval_symbolic_is_integer},
52-
{"is_positive", &eval_symbolic_is_positive}
52+
{"is_positive", &eval_symbolic_is_positive},
53+
{"subs", &eval_symbolic_subs}
5354
};
5455
}
5556

@@ -590,6 +591,19 @@ struct AttributeHandler {
590591
return create_function(al, loc, args_with_list, diag);
591592
}
592593

594+
static ASR::asr_t* eval_symbolic_subs(ASR::expr_t *s, Allocator &al, const Location &loc,
595+
Vec<ASR::expr_t*> &args, diag::Diagnostics &diag) {
596+
Vec<ASR::expr_t*> args_with_list;
597+
args_with_list.reserve(al, args.size() + 1);
598+
args_with_list.push_back(al, s);
599+
for(size_t i = 0; i < args.size(); i++) {
600+
args_with_list.push_back(al, args[i]);
601+
}
602+
ASRUtils::create_intrinsic_function create_function =
603+
ASRUtils::IntrinsicElementalFunctionRegistry::get_create_function("subs");
604+
return create_function(al, loc, args_with_list, diag);
605+
}
606+
593607
}; // AttributeHandler
594608

595609
} // namespace LCompilers::LPython

0 commit comments

Comments
 (0)