Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Implementing query method for exp class #2386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

anutosh491
Copy link
Collaborator

We start with something like the following in the frontend

(lf) anutosh491@spbhat68:~/lpython/lpython$ cat examples/expr2.py 
from lpython import S
from sympy import pi, exp, Pow, Symbol

def main0():
    x: S = Symbol("x")
    y: S = exp(x)
    z: bool = y.func == exp
    print(z)

And I think we need to transform this into something like the following through the pass

def main0():
    _x: i64 = i64(0)
    x: CPtr = empty_c_void_p()
    p_c_pointer(pointer(_x, i64), x)
    basic_new_stack(x)
    symbol_set(x, "x")
    _y: i64 = i64(0)
    y: CPtr = empty_c_void_p()
    p_c_pointer(pointer(_y, i64), y)
    basic_new_stack(y)
    basic_exp(y, x)
    args: CPtr = vecbasic_new()
    basic_get_args(y, args)
    _base: i64 = i64(0)
    base: CPtr = empty_c_void_p()
    p_c_pointer(pointer(_base, i64), base)
    vecbasic_get(args, 0, base)
    _const: i64 = i64(0)
    const: CPtr = empty_c_void_p()
    p_c_pointer(pointer(_const, i64), const)
    basic_new_stack(const)
    basic_const_E(const)
    z: bool = basic_get_type(y) == 17 and basic_eq(base, const)
    print(z)

@anutosh491
Copy link
Collaborator Author

So basically after extracting the base we need to compare it against the constant E through basic_eq which tells me that along with pi, we would also need to introduce the constants E in Lpython which would be my next step.

@anutosh491
Copy link
Collaborator Author

Okay I've made a separate PR (#2387) for addressing the above comment, so that the reviewing is easier on this PR and we could just focus on implementing expQ through this PR. @certik once you merge the above PR, I shall rebase this one !

@anutosh491
Copy link
Collaborator Author

anutosh491 commented Oct 22, 2023

I was able to frame a patch for this and get an output

(lf) anutosh491@spbhat68:~/lpython/lpython$ cat examples/expr2.py 
from lpython import S
from sympy import pi, exp, Pow, Symbol, E

def main0():
    x: S = Symbol("x")
    y: S = exp(x)
    z: bool = y.func == exp
    print(z)

(lf) anutosh491@spbhat68:~/lpython/lpython$ lpython --enable-symengine --backend=llvm examples/expr2.py
True

But the code looks quite odd in comparison to other class query methods (because unfortunately symengine doesn't support an exp class)

So essentially just one statement in the frontend e.func == exp is replaced by the following

    args: CPtr = vecbasic_new()
    basic_get_args(y, args)
    _base: i64 = i64(0)
    base: CPtr = empty_c_void_p()
    p_c_pointer(pointer(_base, i64), base)
    vecbasic_get(args, 0, base)
    _const: i64 = i64(0)
    const: CPtr = empty_c_void_p()
    p_c_pointer(pointer(_const, i64), const)
    basic_new_stack(const)
    basic_const_E(const)
    z: bool = basic_get_type(y) == 17 and basic_eq(base, const)

Which involves replacing a single statement in the frontend with around 6 statements through the pass

  1. Defining an argument vector
  2. Extracting args from the expression
  3. Defining a base to hold the 0th index of the argument vector (using vecbasic_get(args, 0, base))
  4. Defining a const to hold E
  5. checking type as Pow and equating base against const

I'll paste the working patch below but just as I pointed out I am a bit doubtful if this is the best approach here. Because check the following case out . Through symengine we would get

(lf) anutosh491@spbhat68:~/lpython/lpython$ cat examples/expr2.py 
from lpython import S
from sympy import pi, exp, Pow, Symbol, E

def main0():
    x: S = Symbol("x")
    y: S = exp(x)
    z: bool = y.func == Pow
    print(z)

(lf) anutosh491@spbhat68:~/lpython/lpython$ lpython --enable-symengine --backend=llvm examples/expr2.py
True

Whereas if through sympy we would get

>>> exp(x).func == Pow
False

Hence we would have to tackle such cases too if we continue without introducing the exp class in SymEngine.

@anutosh491
Copy link
Collaborator Author

But yeah if we are interested in the approach that I've pasted above, this is the way to go

                case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicExpQ: {
                    // Define necessary function symbols
                    ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
                    ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
                    ASR::symbol_t* basic_get_args_sym = declare_basic_get_args_function(al, loc, module_scope);
                    ASR::symbol_t* vecbasic_new_sym = declare_vecbasic_new_function(al, loc, module_scope);
                    ASR::symbol_t* vecbasic_get_sym = declare_vecbasic_get_function(al, loc, module_scope);
                    ASR::symbol_t* basic_eq_sym = declare_basic_eq_function(al, loc, module_scope);

                    // Define necessary variables
                    ASR::ttype_t* CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, loc));
                    ASR::ttype_t* symbolic_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc));
                    ASR::symbol_t* args_sym = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
                        al, loc, current_scope, s2c(al, "args"), nullptr, 0, ASR::intentType::Local,
                        nullptr, nullptr, ASR::storage_typeType::Default, CPtr_type, nullptr,
                        ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));

                    current_scope->add_symbol(s2c(al, "args"), args_sym);
                    std::string symengine_var1 = symengine_stack.push();
                    ASR::symbol_t* base_sym = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
                        al, loc, current_scope, s2c(al, symengine_var1), nullptr, 0, ASR::intentType::Local,
                        nullptr, nullptr, ASR::storage_typeType::Default, symbolic_type, nullptr,
                        ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
                    current_scope->add_symbol(s2c(al, symengine_var1), base_sym);
                    std::string symengine_var2 = symengine_stack.push();
                    ASR::symbol_t* const_sym = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
                        al, loc, current_scope, s2c(al, symengine_var2), nullptr, 0, ASR::intentType::Local,
                        nullptr, nullptr, ASR::storage_typeType::Default, symbolic_type, nullptr,
                        ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
                    current_scope->add_symbol(s2c(al, symengine_var2), const_sym);
                    for (auto &item : current_scope->get_scope()) {
                        if (ASR::is_a<ASR::Variable_t>(*item.second)) {
                            ASR::Variable_t *s = ASR::down_cast<ASR::Variable_t>(item.second);
                            this->visit_Variable(*s);
                        }
                    }

                    ASR::expr_t* constant = ASRUtils::EXPR(ASR::make_Var_t(al, loc,
                        current_scope->get_symbol(symengine_stack.pop())));
                    ASR::expr_t* base = ASRUtils::EXPR(ASR::make_Var_t(al, loc,
                        current_scope->get_symbol(symengine_stack.pop())));

                    // Statement 1
                    ASR::expr_t* args = ASRUtils::EXPR(ASR::make_Var_t(al, loc, args_sym));
                    Vec<ASR::call_arg_t> call_args1;
                    call_args1.reserve(al, 1);
                    ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
                        vecbasic_new_sym, vecbasic_new_sym, call_args1.p, call_args1.n,
                        ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), nullptr, nullptr));
                    ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, args, function_call1, nullptr));
                    pass_result.push_back(al, stmt1);

                    // Statement 2
                    Vec<ASR::call_arg_t> call_args2;
                    call_args2.reserve(al, 2);
                    ASR::call_arg_t call_arg1, call_arg2;
                    call_arg1.loc = loc;
                    call_arg1.m_value = value1;
                    call_arg2.loc = loc;
                    call_arg2.m_value = args;
                    call_args2.push_back(al, call_arg1);
                    call_args2.push_back(al, call_arg2);
                    ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_get_args_sym,
                        basic_get_args_sym, call_args2.p, call_args2.n, nullptr));
                    pass_result.push_back(al, stmt2);

                    // Statement 3
                    Vec<ASR::call_arg_t> call_args3;
                    call_args3.reserve(al, 3);
                    ASR::call_arg_t call_arg3, call_arg4, call_arg5;
                    call_arg3.loc = loc;
                    call_arg3.m_value = args;
                    call_arg4.loc = loc;
                    call_arg4.m_value = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 0, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4))));
                    call_arg5.loc = loc;
                    call_arg5.m_value = base;
                    call_args3.push_back(al, call_arg3);
                    call_args3.push_back(al, call_arg4);
                    call_args3.push_back(al, call_arg5);
                    ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, vecbasic_get_sym,
                        vecbasic_get_sym, call_args3.p, call_args3.n, nullptr));
                    pass_result.push_back(al, stmt3);

                    // Statement 4
                    ASR::expr_t* E = ASRUtils::EXPR(ASR::make_IntrinsicScalarFunction_t(al, loc,
                        static_cast<int64_t>(LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicE), nullptr,
                        0, 0, symbolic_type, nullptr));
                    ASR::IntrinsicScalarFunction_t* intrinsic_func_E = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(E);
                    process_intrinsic_function(al, loc, intrinsic_func_E, module_scope, constant);

                    // Statement 5
                    Vec<ASR::call_arg_t> call_args4, call_args5;
                    call_args4.reserve(al, 1);
                    call_args5.reserve(al, 2);
                    ASR::call_arg_t call_arg6;
                    ASR::call_arg_t call_arg7, call_arg8;
                    call_arg6.loc = loc;
                    call_arg6.m_value = value1;
                    call_args4.push_back(al, call_arg6);
                    call_arg7.loc = loc;
                    call_arg7.m_value = base;
                    call_arg8.loc = loc;
                    call_arg8.m_value = constant;
                    call_args5.push_back(al, call_arg7);
                    call_args5.push_back(al, call_arg8);
                    ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
                        basic_get_type_sym, basic_get_type_sym, call_args4.p, call_args4.n,
                        ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
                    // Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM
                    ASR::expr_t* left = ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
                        ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 17, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
                        ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
                    ASR::expr_t* right = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
                        basic_eq_sym, basic_eq_sym, call_args5.p, call_args5.n,
                        ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr));
                    return ASRUtils::EXPR(ASR::make_LogicalBinOp_t(al, loc, left,
                        ASR::logicalbinopType::And, right, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
                    break;
                }

It works perfectly fine but confuses me if we would like to have such a big implementation or should we maybe just introduce the exp class in SymEngine .

@certik
Copy link
Contributor

certik commented Oct 24, 2023

The question of introducing the exp class in SymEngine is a design question which has pros/cons, but I suspect we should not do it, it seems it does not speed things up, and Mathematica also doesn't do it, only SymPy. I opened up an issue symengine/symengine#1984 to discuss that more.

Given that, and given that matching and querying also fails in Mathematica if you query for Exp, you have to query for Power, I would go with that design.

The problem is how to keep compatibility with SymPy. I think we can support y.func == exp, and internally check for Pow(E,...), so that shouldn't be a problem. The problem is what to do about y.func == Pow and y.args[0] == E, which will work with LPython, but fail in SymPy. We have three options:

  • Try to disallow it in LPython. However, what about just doing exp(x).func == Pow, which will pass in LPython but fail in SymPy. I don't see how we can disallow this, since the following must be allowed and passes in both SymPy and LPython: (x**x).func == Pow. So far I don't see a path forward here.
  • Introduce exp(x) into SymEngine, but given Design question: should exp(x) be its own type/function? symengine/symengine#1984, I don't see this viable either.
  • Allow .func == Pow in LPython, and yes, exp(x) will differ, which breaks the contract of "if it works in LPython, it must work in Python/SymPy`.

Any other option?

It seems the last option is the only way forward. This means there might not be any way we can enforce at compile time this difference. The solution might be that we might need to create a pure Python implementation of the Symbolic API, since SymPy itself differs slightly. This might be a good idea to do, and SymPy could later use it as well. For this pure Python implementation we will 100% guarantee that if it works in LPython, it will work with this pure Python implementation. SymPy will only work 99% or so. We will document that if you want to be compatible with sympy, to use .func == exp, then we can guarantee 100% compatibility. But if you use y.func == Pow, then if you want SymPy to work, you must ensure y is not exp(x).

@anutosh491
Copy link
Collaborator Author

As I have worked with SymPy's gruntz.py, I can say that differentiating exp and Pow in SymPy does play a role atleast in implementation of gruntz algorithm. I'm not sure of all advantages but this is actually a prominent one.

There would be cases where we would like to use exp._eval_nseries rather than the standard Pow._eval_nseries in gruntz. This is also pointed out by the note stating Important here

But yeah if we have don't see any other viable option, we can try not following the compilation rule (if works in Lpython, should work in SymPy) for this corner case and support only the Pow class and not the exp class.

@anutosh491
Copy link
Collaborator Author

This can now be closed as we have decided not to go ahead with a query method for exp rather just use the Pow class and check if the base of the expression is E.

@anutosh491 anutosh491 closed this Feb 20, 2024
@anutosh491 anutosh491 deleted the Implement_expQ branch February 20, 2024 13:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants