diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index ddfe84d1f9..f58a9841bd 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -715,6 +715,7 @@ RUN(NAME symbolics_08 LABELS cpython_sym c_sym llvm_sym) RUN(NAME symbolics_09 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_10 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_11 LABELS cpython_sym c_sym NOFAST) +RUN(NAME symbolics_12 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME sizeof_01 LABELS llvm c EXTRAFILES sizeof_01b.c) diff --git a/integration_tests/symbolics_12.py b/integration_tests/symbolics_12.py new file mode 100644 index 0000000000..afc0c277f5 --- /dev/null +++ b/integration_tests/symbolics_12.py @@ -0,0 +1,39 @@ +from sympy import Symbol, E, log, exp +from lpython import S + +def main0(): + # Define symbolic variables + x: S = Symbol('x') + y: S = Symbol('y') + + # Assign E to the variable x + x = E + + # Check if x is equal to E + assert x == E + + # Perform some symbolic operations + z: S = x + y + + # Check if z is equal to E + y + assert z == E + y + + # Check if x is not equal to 2E + y + assert x != S(2) * E + y + + # Evaluate some mathematical expressions + expr1: S = log(E) + expr2: S = exp(S(1)) + + # Check the results + assert expr1 == S(1) + assert expr2 == E ** S(1) + + # Print the results + print("x =", x) + print("z =", z) + print("log(E) =", expr1) + print("exp(1) =", expr2) + + +main0() diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index 8b310a32f3..4cf8cf36b0 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -70,6 +70,7 @@ enum class IntrinsicScalarFunctions : int64_t { SymbolicDiv, SymbolicPow, SymbolicPi, + SymbolicE, SymbolicInteger, SymbolicDiff, SymbolicExpand, @@ -135,6 +136,7 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(SymbolicDiv) INTRINSIC_NAME_CASE(SymbolicPow) INTRINSIC_NAME_CASE(SymbolicPi) + INTRINSIC_NAME_CASE(SymbolicE) INTRINSIC_NAME_CASE(SymbolicInteger) INTRINSIC_NAME_CASE(SymbolicDiff) INTRINSIC_NAME_CASE(SymbolicExpand) @@ -3004,30 +3006,34 @@ create_symbolic_binary_macro(SymbolicDiv) create_symbolic_binary_macro(SymbolicPow) create_symbolic_binary_macro(SymbolicDiff) -namespace SymbolicPi { - - static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, diag::Diagnostics& diagnostics) { - ASRUtils::require_impl(x.n_args == 0, "SymbolicPi does not take arguments", - x.base.base.loc, diagnostics); - } - - static inline ASR::expr_t *eval_SymbolicPi(Allocator &/*al*/, - const Location &/*loc*/, ASR::ttype_t *, Vec& /*args*/) { - // TODO - return nullptr; - } - - static inline ASR::asr_t* create_SymbolicPi(Allocator& al, const Location& loc, - Vec& args, - const std::function /*err*/) { - ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc)); - ASR::expr_t* compile_time_value = eval_SymbolicPi(al, loc, to_type, args); - return ASR::make_IntrinsicScalarFunction_t(al, loc, - static_cast(IntrinsicScalarFunctions::SymbolicPi), - nullptr, 0, 0, to_type, compile_time_value); - } +#define create_symbolic_constants_macro(X) \ +namespace X { \ + static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \ + diag::Diagnostics& diagnostics) { \ + const Location& loc = x.base.base.loc; \ + ASRUtils::require_impl(x.n_args == 0, \ + #X " does not take arguments", loc, diagnostics); \ + } \ + \ + static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \ + ASR::ttype_t *, Vec &/*args*/) { \ + /*TODO*/ \ + return nullptr; \ + } \ + \ + static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \ + Vec& args, \ + const std::function /*err*/) { \ + ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc)); \ + ASR::expr_t* compile_time_value = eval_##X(al, loc, to_type, args); \ + return ASR::make_IntrinsicScalarFunction_t(al, loc, \ + static_cast(IntrinsicScalarFunctions::X), \ + nullptr, 0, 0, to_type, compile_time_value); \ + } \ +} // namespace X -} // namespace SymbolicPi +create_symbolic_constants_macro(SymbolicPi) +create_symbolic_constants_macro(SymbolicE) namespace SymbolicInteger { @@ -3286,6 +3292,8 @@ namespace IntrinsicScalarFunctionRegistry { {nullptr, &SymbolicPow::verify_args}}, {static_cast(IntrinsicScalarFunctions::SymbolicPi), {nullptr, &SymbolicPi::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicE), + {nullptr, &SymbolicE::verify_args}}, {static_cast(IntrinsicScalarFunctions::SymbolicInteger), {nullptr, &SymbolicInteger::verify_args}}, {static_cast(IntrinsicScalarFunctions::SymbolicDiff), @@ -3398,6 +3406,8 @@ namespace IntrinsicScalarFunctionRegistry { "SymbolicPow"}, {static_cast(IntrinsicScalarFunctions::SymbolicPi), "pi"}, + {static_cast(IntrinsicScalarFunctions::SymbolicE), + "E"}, {static_cast(IntrinsicScalarFunctions::SymbolicInteger), "SymbolicInteger"}, {static_cast(IntrinsicScalarFunctions::SymbolicDiff), @@ -3470,6 +3480,7 @@ namespace IntrinsicScalarFunctionRegistry { {"SymbolicDiv", {&SymbolicDiv::create_SymbolicDiv, &SymbolicDiv::eval_SymbolicDiv}}, {"SymbolicPow", {&SymbolicPow::create_SymbolicPow, &SymbolicPow::eval_SymbolicPow}}, {"pi", {&SymbolicPi::create_SymbolicPi, &SymbolicPi::eval_SymbolicPi}}, + {"E", {&SymbolicE::create_SymbolicE, &SymbolicE::eval_SymbolicE}}, {"SymbolicInteger", {&SymbolicInteger::create_SymbolicInteger, &SymbolicInteger::eval_SymbolicInteger}}, {"diff", {&SymbolicDiff::create_SymbolicDiff, &SymbolicDiff::eval_SymbolicDiff}}, {"expand", {&SymbolicExpand::create_SymbolicExpand, &SymbolicExpand::eval_SymbolicExpand}}, diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 062dafdccb..3885b67667 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -370,6 +370,50 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } + + ASR::symbol_t* func_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = loc; + call_arg.m_value = value; + call_args.push_back(al, call_arg); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, func_sym, + func_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + } + ASR::expr_t* handle_argument(Allocator &al, const Location &loc, ASR::expr_t* arg) { if (ASR::is_a(*arg)) { return arg; @@ -397,55 +441,15 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_intrinsic_id; switch (static_cast(intrinsic_id)) { - case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPi: { - std::string new_name = "basic_const_pi"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( - al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); - - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); - } - - // Create the function call statement for basic_const_pi - ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = loc; - call_arg.m_value = target; - call_args.push_back(al, call_arg); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_const_pi_sym, - basic_const_pi_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSymbol: { std::string new_name = "symbol_set"; symbolic_dependencies.push_back(new_name); @@ -499,6 +503,14 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor { std::string name = x.m_id; ASR::symbol_t *s = current_scope->resolve_symbol(name); std::set not_cpython_builtin = { - "pi"}; + "pi", "E"}; if (s) { tmp = ASR::make_Var_t(al, x.base.base.loc, s); } else if (name == "i32" || name == "i64" || name == "f32" || @@ -7116,7 +7116,7 @@ we will have to use something else. "diff", "expand", "has" }; std::set symbolic_constants = { - "pi" + "pi", "E" }; if (symbolic_attributes.find(call_name) != symbolic_attributes.end() && symbolic_constants.find(mod_name) != symbolic_constants.end()){