Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit df414fb

Browse files
authored
[numpy] add trunc (#2316)
1 parent a6c965e commit df414fb

35 files changed

+1349
-1189
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,7 @@ RUN(NAME elemental_09 LABELS cpython llvm c NOFAST)
596596
RUN(NAME elemental_10 LABELS cpython llvm c NOFAST)
597597
RUN(NAME elemental_11 LABELS cpython llvm c NOFAST)
598598
RUN(NAME elemental_12 LABELS cpython llvm c NOFAST)
599+
RUN(NAME elemental_13 LABELS cpython llvm c NOFAST)
599600
RUN(NAME test_random LABELS cpython llvm NOFAST)
600601
RUN(NAME test_os LABELS cpython llvm c NOFAST)
601602
RUN(NAME test_builtin LABELS cpython llvm c)

integration_tests/elemental_13.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from lpython import f32, f64
2+
from numpy import trunc, empty, sqrt, reshape, int32, float32, float64
3+
4+
5+
def elemental_trunc64():
6+
i: i32
7+
j: i32
8+
k: i32
9+
l: i32
10+
eps: f32
11+
eps = f32(1e-6)
12+
13+
arraynd: f64[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float64)
14+
15+
newshape: i32[1] = empty(1, dtype = int32)
16+
newshape[0] = 16384
17+
18+
for i in range(32):
19+
for j in range(16):
20+
for k in range(8):
21+
for l in range(4):
22+
arraynd[i, j, k, l] = f64((-1)**l) * sqrt(float(i + j + j + l))
23+
24+
observed: f64[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float64)
25+
observed = trunc(arraynd)
26+
27+
observed1d: f64[16384] = empty(16384, dtype=float64)
28+
observed1d = reshape(observed, newshape)
29+
30+
array: f64[16384] = empty(16384, dtype=float64)
31+
array = reshape(arraynd, newshape)
32+
33+
for i in range(16384):
34+
assert f32(abs(trunc(array[i]) - observed1d[i])) <= eps
35+
36+
37+
def elemental_trunc32():
38+
i: i32
39+
j: i32
40+
k: i32
41+
l: i32
42+
eps: f32
43+
eps = f32(1e-6)
44+
45+
arraynd: f32[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float32)
46+
47+
for i in range(32):
48+
for j in range(16):
49+
for k in range(8):
50+
for l in range(4):
51+
arraynd[i, j, k, l] = f32(f64((-1)**l) * sqrt(float(i + j + j + l)))
52+
53+
observed: f32[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float32)
54+
observed = trunc(arraynd)
55+
56+
for i in range(32):
57+
for j in range(16):
58+
for k in range(8):
59+
for l in range(4):
60+
assert abs(trunc(arraynd[i, j, k, l]) - observed[i, j, k, l]) <= eps
61+
62+
63+
elemental_trunc64()
64+
elemental_trunc32()

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2790,6 +2790,7 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
27902790
SET_INTRINSIC_NAME(Exp, "exp");
27912791
SET_INTRINSIC_NAME(Exp2, "exp2");
27922792
SET_INTRINSIC_NAME(Expm1, "expm1");
2793+
SET_INTRINSIC_NAME(Trunc, "trunc");
27932794
default : {
27942795
throw LCompilersException("IntrinsicScalarFunction: `"
27952796
+ ASRUtils::get_intrinsic_name(x.m_intrinsic_id)

src/libasr/codegen/asr_to_julia.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1899,6 +1899,7 @@ class ASRToJuliaVisitor : public ASR::BaseVisitor<ASRToJuliaVisitor>
18991899
SET_INTRINSIC_NAME(Exp, "exp");
19001900
SET_INTRINSIC_NAME(Exp2, "exp2");
19011901
SET_INTRINSIC_NAME(Expm1, "expm1");
1902+
SET_INTRINSIC_NAME(Trunc, "trunc");
19021903
default : {
19031904
throw LCompilersException("IntrinsicFunction: `"
19041905
+ ASRUtils::get_intrinsic_name(x.m_intrinsic_id)

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ enum class IntrinsicScalarFunctions : int64_t {
3838
Atan2,
3939
Gamma,
4040
LogGamma,
41+
Trunc,
4142
Abs,
4243
Exp,
4344
Exp2,
@@ -96,6 +97,7 @@ inline std::string get_intrinsic_name(int x) {
9697
INTRINSIC_NAME_CASE(Atan2)
9798
INTRINSIC_NAME_CASE(Gamma)
9899
INTRINSIC_NAME_CASE(LogGamma)
100+
INTRINSIC_NAME_CASE(Trunc)
99101
INTRINSIC_NAME_CASE(Abs)
100102
INTRINSIC_NAME_CASE(Exp)
101103
INTRINSIC_NAME_CASE(Exp2)
@@ -1142,6 +1144,44 @@ static inline ASR::expr_t* instantiate_LogGamma (Allocator &al,
11421144

11431145
} // namespace LogGamma
11441146

1147+
#define create_trunc_macro(X, stdeval) \
1148+
namespace X { \
1149+
static inline ASR::expr_t *eval_##X(Allocator &al, const Location &loc, \
1150+
ASR::ttype_t *t, Vec<ASR::expr_t*>& args) { \
1151+
LCOMPILERS_ASSERT(args.size() == 1); \
1152+
double rv = ASR::down_cast<ASR::RealConstant_t>(args[0])->m_r; \
1153+
if (ASRUtils::extract_value(args[0], rv)) { \
1154+
double val = std::stdeval(rv); \
1155+
return make_ConstantWithType(make_RealConstant_t, val, t, loc); \
1156+
} \
1157+
return nullptr; \
1158+
} \
1159+
static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \
1160+
Vec<ASR::expr_t*>& args, \
1161+
const std::function<void (const std::string &, const Location &)> err) { \
1162+
ASR::ttype_t *type = ASRUtils::expr_type(args[0]); \
1163+
if (args.n != 1) { \
1164+
err("Intrinsic `#X` accepts exactly one argument", loc); \
1165+
} else if (!ASRUtils::is_real(*type)) { \
1166+
err("`x` argument of `#X` must be real", \
1167+
args[0]->base.loc); \
1168+
} \
1169+
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, \
1170+
eval_##X, static_cast<int64_t>(IntrinsicScalarFunctions::Trunc), \
1171+
0, type); \
1172+
} \
1173+
static inline ASR::expr_t* instantiate_##X (Allocator &al, \
1174+
const Location &loc, SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, \
1175+
ASR::ttype_t *return_type, Vec<ASR::call_arg_t>& new_args, \
1176+
int64_t overload_id) { \
1177+
ASR::ttype_t* arg_type = arg_types[0]; \
1178+
return UnaryIntrinsicFunction::instantiate_functions(al, loc, scope, \
1179+
"#X", arg_type, return_type, new_args, overload_id); \
1180+
} \
1181+
} // namespace X
1182+
1183+
create_trunc_macro(Trunc, trunc)
1184+
11451185
// `X` is the name of the function in the IntrinsicScalarFunctions enum and
11461186
// we use the same name for `create_X` and other places
11471187
// `stdeval` is the name of the function in the `std` namespace for compile
@@ -2879,6 +2919,8 @@ namespace IntrinsicScalarFunctionRegistry {
28792919
verify_function>>& intrinsic_function_by_id_db = {
28802920
{static_cast<int64_t>(IntrinsicScalarFunctions::LogGamma),
28812921
{&LogGamma::instantiate_LogGamma, &UnaryIntrinsicFunction::verify_args}},
2922+
{static_cast<int64_t>(IntrinsicScalarFunctions::Trunc),
2923+
{&Trunc::instantiate_Trunc, &UnaryIntrinsicFunction::verify_args}},
28822924
{static_cast<int64_t>(IntrinsicScalarFunctions::Sin),
28832925
{&Sin::instantiate_Sin, &UnaryIntrinsicFunction::verify_args}},
28842926
{static_cast<int64_t>(IntrinsicScalarFunctions::Cos),
@@ -2977,6 +3019,8 @@ namespace IntrinsicScalarFunctionRegistry {
29773019
{static_cast<int64_t>(IntrinsicScalarFunctions::LogGamma),
29783020
"log_gamma"},
29793021

3022+
{static_cast<int64_t>(IntrinsicScalarFunctions::Trunc),
3023+
"trunc"},
29803024
{static_cast<int64_t>(IntrinsicScalarFunctions::Sin),
29813025
"sin"},
29823026
{static_cast<int64_t>(IntrinsicScalarFunctions::Cos),
@@ -3074,6 +3118,7 @@ namespace IntrinsicScalarFunctionRegistry {
30743118
std::tuple<create_intrinsic_function,
30753119
eval_intrinsic_function>>& intrinsic_function_by_name_db = {
30763120
{"log_gamma", {&LogGamma::create_LogGamma, &LogGamma::eval_log_gamma}},
3121+
{"trunc", {&Trunc::create_Trunc, &Trunc::eval_Trunc}},
30773122
{"sin", {&Sin::create_Sin, &Sin::eval_Sin}},
30783123
{"cos", {&Cos::create_Cos, &Cos::eval_Cos}},
30793124
{"tan", {&Tan::create_Tan, &Tan::eval_Tan}},
@@ -3134,6 +3179,7 @@ namespace IntrinsicScalarFunctionRegistry {
31343179
id_ == IntrinsicScalarFunctions::Cos ||
31353180
id_ == IntrinsicScalarFunctions::Gamma ||
31363181
id_ == IntrinsicScalarFunctions::LogGamma ||
3182+
id_ == IntrinsicScalarFunctions::Trunc ||
31373183
id_ == IntrinsicScalarFunctions::Sin ||
31383184
id_ == IntrinsicScalarFunctions::Exp ||
31393185
id_ == IntrinsicScalarFunctions::Exp2 ||

src/libasr/runtime/lfortran_intrinsics.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,18 @@ LFORTRAN_API double_complex_t _lfortran_zatanh(double_complex_t x)
11341134
return catanh(x);
11351135
}
11361136

1137+
// trunc -----------------------------------------------------------------------
1138+
1139+
LFORTRAN_API float _lfortran_strunc(float x)
1140+
{
1141+
return truncf(x);
1142+
}
1143+
1144+
LFORTRAN_API double _lfortran_dtrunc(double x)
1145+
{
1146+
return trunc(x);
1147+
}
1148+
11371149
// phase --------------------------------------------------------------------
11381150

11391151
LFORTRAN_API float _lfortran_cphase(float_complex_t x)

src/libasr/runtime/lfortran_intrinsics.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ LFORTRAN_API float _lfortran_satanh(float x);
168168
LFORTRAN_API double _lfortran_datanh(double x);
169169
LFORTRAN_API float_complex_t _lfortran_catanh(float_complex_t x);
170170
LFORTRAN_API double_complex_t _lfortran_zatanh(double_complex_t x);
171+
LFORTRAN_API float _lfortran_strunc(float x);
172+
LFORTRAN_API double _lfortran_dtrunc(double x);
171173
LFORTRAN_API float _lfortran_cphase(float_complex_t x);
172174
LFORTRAN_API double _lfortran_zphase(double_complex_t x);
173175
LFORTRAN_API bool _lpython_str_compare_eq(char** s1, char** s2);

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7311,7 +7311,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
73117311
if (!s) {
73127312
std::string intrinsic_name = call_name;
73137313
std::set<std::string> not_cpython_builtin = {
7314-
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand",
7314+
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand", "trunc",
73157315
"sum" // For sum called over lists
73167316
};
73177317
std::set<std::string> symbolic_functions = {

src/runtime/lpython_intrinsic_numpy.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,3 +410,23 @@ def ceil(x: f32) -> f32:
410410
if x <= f32(0) or x == resultf:
411411
return resultf
412412
return resultf + f32(1)
413+
414+
########## trunc ##########
415+
416+
@ccall
417+
def _lfortran_dtrunc(x: f64) -> f64:
418+
pass
419+
420+
@overload
421+
@vectorize
422+
def trunc(x: f64) -> f64:
423+
return _lfortran_dtrunc(x)
424+
425+
@ccall
426+
def _lfortran_strunc(x: f32) -> f32:
427+
pass
428+
429+
@overload
430+
@vectorize
431+
def trunc(x: f32) -> f32:
432+
return _lfortran_strunc(x)

tests/reference/asr-array_01_decl-39cf894.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-array_01_decl-39cf894.stdout",
9-
"stdout_hash": "137a0c427925ba7da2e7151f2cf52bfa9a64fede11fe8d2653f20b64",
9+
"stdout_hash": "2aa47467473392c970bb1ddde961e3007d4c157bb0ea507b5e0db4a4",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

0 commit comments

Comments
 (0)