Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Update FMA/flip_sign pass #2313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1884,6 +1884,30 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
break ;
}
case ASRUtils::IntrinsicScalarFunctions::FlipSign: {
Vec<ASR::call_arg_t> args;
args.reserve(al, 2);
ASR::call_arg_t arg0_, arg1_;
arg0_.loc = x.m_args[0]->base.loc, arg0_.m_value = x.m_args[0];
args.push_back(al, arg0_);
arg1_.loc = x.m_args[1]->base.loc, arg1_.m_value = x.m_args[1];
args.push_back(al, arg1_);
generate_flip_sign(args.p);
break;
}
case ASRUtils::IntrinsicScalarFunctions::FMA: {
Vec<ASR::call_arg_t> args;
args.reserve(al, 3);
ASR::call_arg_t arg0_, arg1_, arg2_;
arg0_.loc = x.m_args[0]->base.loc, arg0_.m_value = x.m_args[0];
args.push_back(al, arg0_);
arg1_.loc = x.m_args[1]->base.loc, arg1_.m_value = x.m_args[1];
args.push_back(al, arg1_);
arg2_.loc = x.m_args[2]->base.loc, arg2_.m_value = x.m_args[2];
args.push_back(al, arg2_);
generate_fma(args.p);
break;
}
default: {
throw CodeGenError( ASRUtils::IntrinsicScalarFunctionRegistry::
get_intrinsic_function_name(x.m_intrinsic_id) +
Expand Down Expand Up @@ -7372,7 +7396,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm::Value* int_var = builder->CreateBitCast(CreateLoad(variable), shifted_signal->getType());
tmp = builder->CreateXor(shifted_signal, int_var);
llvm::Type* variable_type = llvm_utils->get_type_from_ttype_t_util(asr_variable->m_type, module.get());
builder->CreateStore(builder->CreateBitCast(tmp, variable_type->getPointerTo()), variable);
tmp = builder->CreateBitCast(tmp, variable_type);
}

void generate_fma(ASR::call_arg_t* m_args) {
Expand Down Expand Up @@ -8300,7 +8324,12 @@ Result<std::unique_ptr<LLVMModule>> asr_to_llvm(ASR::TranslationUnit_t &asr,
pass_options.run_fun = run_fn;
pass_options.always_run = false;
pass_options.verbose = co.verbose;
std::vector<int64_t> skip_optimization_func_instantiation;
skip_optimization_func_instantiation.push_back(static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FlipSign));
skip_optimization_func_instantiation.push_back(static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA));
pass_options.skip_optimization_func_instantiation = skip_optimization_func_instantiation;
pass_manager.rtlib = co.rtlib;

pass_manager.apply_passes(al, &asr, pass_options, diagnostics);

// Uncomment for debugging the ASR after the transformation
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/flip_sign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class FlipSignVisitor : public PassUtils::SkipOptimizationFunctionVisitor<FlipSi
LCOMPILERS_ASSERT(flip_sign_signal_variable);
LCOMPILERS_ASSERT(flip_sign_variable);
ASR::expr_t* flip_sign_result = PassUtils::get_flipsign(flip_sign_signal_variable,
flip_sign_variable, al, unit, x.base.base.loc);
flip_sign_variable, al, unit, x.base.base.loc, pass_options);
pass_result.push_back(al, ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc,
flip_sign_variable, flip_sign_result, nullptr)));
}
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/fma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class FMAVisitor : public PassUtils::SkipOptimizationFunctionVisitor<FMAVisitor>
}

fma_var = PassUtils::get_fma(other_expr, first_arg, second_arg,
al, unit, x.base.base.loc);
al, unit, x.base.base.loc, pass_options);
from_fma = false;
}

Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -2451,7 +2451,7 @@ namespace IntrinsicScalarFunctionRegistry {
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
{&FMA::instantiate_FMA, &FMA::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::FlipSign),
{&FlipSign::instantiate_FlipSign, &FMA::verify_args}},
{&FlipSign::instantiate_FlipSign, &FlipSign::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Abs),
{&Abs::instantiate_Abs, &Abs::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Partition),
Expand Down
40 changes: 35 additions & 5 deletions src/libasr/pass/pass_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,14 +587,34 @@ namespace LCompilers {
int32_type, bound_type, nullptr));
}

bool skip_instantiation(PassOptions pass_options, int64_t id) {
if (!pass_options.skip_optimization_func_instantiation.empty()) {
for (size_t i=0; i<pass_options.skip_optimization_func_instantiation.size(); i++) {
if (pass_options.skip_optimization_func_instantiation[i] == id) {
return true;
}
}
}
return false;
}

ASR::expr_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc){
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc,
PassOptions pass_options){
ASR::ttype_t* type = ASRUtils::expr_type(arg1);
int64_t fp_s = static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FlipSign);
if (skip_instantiation(pass_options, fp_s)) {
Vec<ASR::expr_t*> args;
args.reserve(al, 2);
args.push_back(al, arg0);
args.push_back(al, arg1);
return ASRUtils::EXPR(ASRUtils::make_IntrinsicScalarFunction_t_util(al, loc, fp_s,
args.p, args.n, 0, type, nullptr));
}
ASRUtils::impl_function instantiate_function =
ASRUtils::IntrinsicScalarFunctionRegistry::get_instantiate_function(
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FlipSign));
Vec<ASR::ttype_t*> arg_types;
ASR::ttype_t* type = ASRUtils::expr_type(arg1);
arg_types.reserve(al, 2);
arg_types.push_back(al, ASRUtils::expr_type(arg0));
arg_types.push_back(al, ASRUtils::expr_type(arg1));
Expand Down Expand Up @@ -667,13 +687,23 @@ namespace LCompilers {
}

ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc){

Allocator& al, ASR::TranslationUnit_t& unit, Location& loc,
PassOptions pass_options){
int64_t fma_id = static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA);
ASR::ttype_t* type = ASRUtils::expr_type(arg0);
if (skip_instantiation(pass_options, fma_id)) {
Vec<ASR::expr_t*> args;
args.reserve(al, 3);
args.push_back(al, arg0);
args.push_back(al, arg1);
args.push_back(al, arg2);
return ASRUtils::EXPR(ASRUtils::make_IntrinsicScalarFunction_t_util(al, loc, fma_id,
args.p, args.n, 0, type, nullptr));
}
ASRUtils::impl_function instantiate_function =
ASRUtils::IntrinsicScalarFunctionRegistry::get_instantiate_function(
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA));
Vec<ASR::ttype_t*> arg_types;
ASR::ttype_t* type = ASRUtils::expr_type(arg0);
arg_types.reserve(al, 3);
arg_types.push_back(al, ASRUtils::expr_type(arg0));
arg_types.push_back(al, ASRUtils::expr_type(arg1));
Expand Down
6 changes: 4 additions & 2 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ namespace LCompilers {
Allocator& al);

ASR::expr_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc);
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc,
PassOptions pass_options);

ASR::expr_t* to_int32(ASR::expr_t* x, ASR::ttype_t* int32type, Allocator& al);

Expand All @@ -86,7 +87,8 @@ namespace LCompilers {
ASR::intentType var_intent=ASR::intentType::Local);

ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc);
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc,
PassOptions pass_options);

ASR::expr_t* get_sign_from_value(ASR::expr_t* arg0, ASR::expr_t* arg1,
Allocator& al, ASR::TranslationUnit_t& unit,
Expand Down
1 change: 1 addition & 0 deletions src/libasr/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ namespace LCompilers {
bool verbose = false; // For developer debugging
bool pass_cumulative = false; // Apply passes cumulatively
bool disable_main = false;
std::vector<int64_t> skip_optimization_func_instantiation;
bool module_name_mangling = false;
bool global_symbols_mangling = false;
bool intrinsic_symbols_mangling = false;
Expand Down