diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index f58a9841bd..1051514f43 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -716,6 +716,7 @@ RUN(NAME symbolics_09 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_10 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_11 LABELS cpython_sym c_sym NOFAST) RUN(NAME symbolics_12 LABELS cpython_sym c_sym llvm_sym NOFAST) +RUN(NAME symbolics_13 LABELS cpython_sym c_sym NOFAST) RUN(NAME sizeof_01 LABELS llvm c EXTRAFILES sizeof_01b.c) diff --git a/integration_tests/symbolics_13.py b/integration_tests/symbolics_13.py new file mode 100644 index 0000000000..06f2c27599 --- /dev/null +++ b/integration_tests/symbolics_13.py @@ -0,0 +1,12 @@ +from lpython import S +from sympy import pi, Symbol + +def func() -> S: + return pi + +def test_func(): + z: S = func() + print(z) + assert z == pi + +test_func() \ No newline at end of file diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index ded8cb078b..27b5ee5bb8 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -116,7 +116,8 @@ namespace LCompilers { static inline bool is_aggregate_or_array_type(ASR::expr_t* var) { return (ASR::is_a(*ASRUtils::expr_type(var)) || - ASRUtils::is_array(ASRUtils::expr_type(var))); + ASRUtils::is_array(ASRUtils::expr_type(var)) || + ASR::is_a(*ASRUtils::expr_type(var))); } template @@ -775,7 +776,7 @@ namespace LCompilers { } static inline void handle_fn_return_var(Allocator &al, ASR::Function_t *x, - bool (*is_array_or_struct)(ASR::expr_t*)) { + bool (*is_array_or_struct_or_symbolic)(ASR::expr_t*)) { if (ASRUtils::get_FunctionType(x)->m_abi == ASR::abiType::BindPython) { return; } @@ -787,7 +788,7 @@ namespace LCompilers { * in avoiding deep copies and the destination memory directly gets * filled inside the function. */ - if( is_array_or_struct(x->m_return_var)) { + if( is_array_or_struct_or_symbolic(x->m_return_var)) { for( auto& s_item: x->m_symtab->get_scope() ) { ASR::symbol_t* curr_sym = s_item.second; if( curr_sym->type == ASR::symbolType::Variable ) { @@ -834,7 +835,7 @@ namespace LCompilers { for (auto &item : x->m_symtab->get_scope()) { if (ASR::is_a(*item.second)) { handle_fn_return_var(al, ASR::down_cast( - item.second), is_array_or_struct); + item.second), is_array_or_struct_or_symbolic); } } } diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 4b856c2fbe..7abf80f8fd 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -123,7 +123,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor((ASR::asr_t*)&xx)); } if(xx.m_intent == ASR::intentType::In){