diff --git a/integration_tests/symbolics_02.py b/integration_tests/symbolics_02.py index 1bfda7e74d..975f9c541a 100644 --- a/integration_tests/symbolics_02.py +++ b/integration_tests/symbolics_02.py @@ -18,6 +18,7 @@ def test_symbolic_operations(): assert True else: assert False + assert(z.func == Add) print(z) # Subtraction @@ -29,6 +30,7 @@ def test_symbolic_operations(): assert True else: assert False + assert(w.func == Add) print(w) # Multiplication @@ -40,6 +42,7 @@ def test_symbolic_operations(): assert True else: assert False + assert(u.func == Mul) print(u) # Division @@ -51,6 +54,7 @@ def test_symbolic_operations(): assert True else: assert False + assert(v.func == Mul) print(v) # Power @@ -66,6 +70,7 @@ def test_symbolic_operations(): assert True else: assert False + assert(p.func == Pow) print(p) # Casting diff --git a/integration_tests/symbolics_06.py b/integration_tests/symbolics_06.py index f56aa52c76..b9733cf763 100644 --- a/integration_tests/symbolics_06.py +++ b/integration_tests/symbolics_06.py @@ -30,8 +30,15 @@ def test_elementary_functions(): b: S = sin(a) c: S = cos(b) d: S = log(c) + d1: bool = d.func == log e: S = Abs(d) print(e) + assert(d1 == True) + if d.func == log: + assert True + else: + assert False + assert(d.func == log) assert(e == Abs(log(cos(sin(exp(x)))))) test_elementary_functions() \ No newline at end of file diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index 7cd59fe8b9..8b310a32f3 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -82,6 +82,7 @@ enum class IntrinsicScalarFunctions : int64_t { SymbolicAddQ, SymbolicMulQ, SymbolicPowQ, + SymbolicLogQ, // ... }; @@ -146,6 +147,7 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(SymbolicAddQ) INTRINSIC_NAME_CASE(SymbolicMulQ) INTRINSIC_NAME_CASE(SymbolicPowQ) + INTRINSIC_NAME_CASE(SymbolicLogQ) default : { throw LCompilersException("pickle: intrinsic_id not implemented"); } @@ -3146,6 +3148,7 @@ namespace X { create_symbolic_query_macro(SymbolicAddQ) create_symbolic_query_macro(SymbolicMulQ) create_symbolic_query_macro(SymbolicPowQ) +create_symbolic_query_macro(SymbolicLogQ) #define create_symbolic_unary_macro(X) \ @@ -3307,6 +3310,8 @@ namespace IntrinsicScalarFunctionRegistry { {nullptr, &SymbolicMulQ::verify_args}}, {static_cast(IntrinsicScalarFunctions::SymbolicPowQ), {nullptr, &SymbolicPowQ::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicLogQ), + {nullptr, &SymbolicLogQ::verify_args}}, }; static const std::map& intrinsic_function_id_to_name = { @@ -3417,6 +3422,8 @@ namespace IntrinsicScalarFunctionRegistry { "SymbolicMulQ"}, {static_cast(IntrinsicScalarFunctions::SymbolicPowQ), "SymbolicPowQ"}, + {static_cast(IntrinsicScalarFunctions::SymbolicLogQ), + "SymbolicLogQ"}, }; @@ -3475,6 +3482,7 @@ namespace IntrinsicScalarFunctionRegistry { {"AddQ", {&SymbolicAddQ::create_SymbolicAddQ, &SymbolicAddQ::eval_SymbolicAddQ}}, {"MulQ", {&SymbolicMulQ::create_SymbolicMulQ, &SymbolicMulQ::eval_SymbolicMulQ}}, {"PowQ", {&SymbolicPowQ::create_SymbolicPowQ, &SymbolicPowQ::eval_SymbolicPowQ}}, + {"LogQ", {&SymbolicLogQ::create_SymbolicLogQ, &SymbolicLogQ::eval_SymbolicLogQ}}, }; static inline bool is_intrinsic_function(const std::string& name) { diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index a8c6fdc3d5..062dafdccb 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -921,6 +921,24 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value1; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + // Using 29 as the right value of the IntegerCompare node as it represents SYMENGINE_LOG through SYMENGINE_ENUM + return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 29, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))), + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); + break; + } default: { throw LCompilersException("IntrinsicFunction: `" + ASRUtils::get_intrinsic_name(intrinsic_id) @@ -1437,6 +1455,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_test)) { + ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(x.m_test); + if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { + ASR::expr_t* test = process_attributes(al, x.base.base.loc, x.m_test, module_scope); + ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); + pass_result.push_back(al, assert_stmt); + } } } }; diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index a9d01fae4c..5093fff717 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -6040,6 +6040,9 @@ class BodyVisitor : public CommonVisitor { } else if (symbolic_type == "Pow") { tmp = attr_handler.eval_symbolic_is_Pow(se, al, x.base.base.loc, args, diag); return; + } else if (symbolic_type == "log") { + tmp = attr_handler.eval_symbolic_is_log(se, al, x.base.base.loc, args, diag); + return; } else { throw SemanticError(symbolic_type + " symbolic type not supported yet", x.base.base.loc); } diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index 6aef6f76cd..2f70c308e8 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -499,6 +499,20 @@ struct AttributeHandler { { throw SemanticError(msg, loc); }); } + static ASR::asr_t* eval_symbolic_is_log(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("LogQ"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + }; // AttributeHandler } // namespace LCompilers::LPython