
Incorrect exponential calculation on Jetson devices with float32 dtype #61110


Open
BrettRyland opened this issue Jul 1, 2021 · 12 comments
Labels
module: arm (Related to ARM architectures builds of PyTorch. Includes Apple M1)
module: cuda (Related to torch.cuda, and CUDA support in general)
module: jetson (Related to the Jetson builds by NVIDIA)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


BrettRyland commented Jul 1, 2021

πŸ› Bug

The exp function in torch and libtorch (https://pytorch.org/docs/stable/generated/torch.exp.html#torch.exp) gives incorrect output when run on Jetson devices (tested on Xavier NX and Nano) with dtype=torch.float32.

To Reproduce

test.cpp:

#include <iostream>
#include <torch/script.h> // One-stop header.

int main(int argc, char *argv[])
{
        c10::InferenceMode guard(true); // Note: gives the same result without this.
        auto t = torch::ones({3, 3}, torch::dtype(torch::kFloat32));
        std::cout << "t:\n" << t << "\n";
        std::cout << "t.exp():\n" << t.exp() << "\n";
        return 0;
}

CMakeLists.txt:

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(test)

set(CMAKE_CXX_STANDARD 14)

find_package(Torch REQUIRED
        HINTS "/usr/local/lib/python3.8/dist-packages/torch")

add_executable(test test.cpp)
target_link_libraries(test PUBLIC "${TORCH_LIBRARIES}")

Compile and run:

br@NX:~/tmp/test$ Torch_DIR=/usr/local/lib/python3.6/dist-packages/torch cmake -GNinja .
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE  
-- Found CUDA: /usr/local/cuda (found version "10.2") 
-- Caffe2: CUDA detected: 10.2
-- Caffe2: CUDA nvcc is: /usr/local/cuda/bin/nvcc
-- Caffe2: CUDA toolkit directory: /usr/local/cuda
-- Caffe2: Header version is: 10.2
-- Found CUDNN: /usr/lib/aarch64-linux-gnu/libcudnn.so  
-- Found cuDNN: v8.0.0  (include: /usr/include, library: /usr/lib/aarch64-linux-gnu/libcudnn.so)
-- /usr/local/cuda/lib64/libnvrtc.so shorthash is c13e41e1
-- Autodetected CUDA architecture(s):  7.2
-- Added CUDA NVCC flags for: -gencode;arch=compute_72,code=sm_72
CMake Warning at /usr/local/lib/python3.6/dist-packages/torch/share/cmake/Torch/TorchConfig.cmake:22 (message):
  static library kineto_LIBRARY-NOTFOUND not found.
Call Stack (most recent call first):
  /usr/local/lib/python3.6/dist-packages/torch/share/cmake/Torch/TorchConfig.cmake:127 (append_torchlib_if_found)
  CMakeLists.txt:6 (find_package)


-- Found Torch: /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch.so  
-- Configuring done
-- Generating done
-- Build files have been written to: /home/br/tmp/test
br@NX:~/tmp/test$ ninja && ./test
[2/2] Linking CXX executable test
t:
 1  1  1
 1  1  1
 1  1  1
[ CPUFloatType{3,3} ]
t.exp():
 2.7183  1.0000  1.0000
 1.0000  9.4901  9.4901
 9.4901  9.4901  2.7183
[ CPUFloatType{3,3} ]

In Python:

br@NX:~/tmp/test$ ipython
Python 3.6.9 (default, Jan 26 2021, 15:33:00) 
Type 'copyright', 'credits' or 'license' for more information
IPython 7.16.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import torch

In [2]: t=torch.ones((3,3), dtype=torch.float64); t.exp()
Out[2]: 
tensor([[2.7183, 2.7183, 2.7183],
        [2.7183, 2.7183, 2.7183],
        [2.7183, 2.7183, 2.7183]], dtype=torch.float64)

In [3]: t=torch.ones((3,3), dtype=torch.float32); t.exp()
Out[3]: 
tensor([[2.7183, 1.0000, 1.0000],
        [1.0000, 0.0000, 1.0000],
        [1.0000, 1.0000, 2.7183]])

Expected behavior

The same values as when calculated with float64: every element should be exp(1) = e ≈ 2.7183.
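
A quick programmatic check of this invariant (a minimal sketch; the reference is computed via the float64 path, and torch::allclose compares it against the native float32 exp):

#include <iostream>
#include <torch/torch.h>

int main() {
        auto t = torch::ones({3, 3}, torch::dtype(torch::kFloat32));
        // Reference: compute exp in float64 and cast back; the float32 exp
        // should agree with this to within float32 precision.
        auto ref = t.to(torch::kFloat64).exp().to(torch::kFloat32);
        std::cout << (torch::allclose(t.exp(), ref) ? "exp OK" : "exp MISMATCH") << "\n";
        return 0;
}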

Environment

$ python3 collect_env.py 
Collecting environment information...
PyTorch version: 1.9.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (aarch64)
GCC version: (Ubuntu/Linaro 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.21.0-rc1
Libc version: glibc-2.25

Python version: 3.6.9 (default, Jan 26 2021, 15:33:00)  [GCC 8.4.0] (64-bit runtime)
Python platform: Linux-4.9.201-tegra-aarch64-with-Ubuntu-18.04-bionic
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/aarch64-linux-gnu/libcudnn.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_adv_infer.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_adv_train.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_infer.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_train.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_ops_infer.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_ops_train.so.8.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] torch==1.9.0
[pip3] torchvision==0.10.0
[conda] Could not collect

Additional Context

I had noticed this behaviour on older versions of libtorch as strange results in a custom NMS function, but have only now traced those results back to exp, so I'm not sure how old this error is.

cc @malfet @ngimel

BrettRyland (Author)

Note: my current workaround is to call

t.to(torch::kFloat64).exp().to(torch::kFloat32)

which gives the correct result but involves two type-casts.
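
As a stopgap this can be wrapped in a small helper; a minimal sketch (the name safe_exp is hypothetical, not part of the report):

#include <torch/torch.h>

// Hypothetical helper: route float32 exp() through float64 to avoid the
// broken vectorized path, at the cost of two type-casts per call.
torch::Tensor safe_exp(const torch::Tensor& t) {
        if (t.scalar_type() == torch::kFloat32) {
                return t.to(torch::kFloat64).exp().to(torch::kFloat32);
        }
        return t.exp();
}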

VitalyFedyunin added the module: jetson, module: cuda, and triage review labels on Jul 1, 2021
VitalyFedyunin (Contributor)

Triage review: Maybe promote to High Pri?

malfet (Contributor) commented Jul 12, 2021

@BrettRyland can you please check whether the torch.exp problems are reproducible with the official PyTorch cpu-only builds, which can be downloaded from https://pypi.org/project/torch/#files

gcc-7.5 has several known compiler bugs that yield incorrect code for NEON-optimized operations (that is, almost everything using float32).

soulitzer added the triaged label and removed the triage review label on Jul 12, 2021
malfet (Contributor) commented Jul 12, 2021

We should disable builds with NEON acceleration on known broken compilers, similar to what we used to have in #47099
cc: @shmsong @ptrblck
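
(As a toy illustration of the idea, not PyTorch's actual source, such a change would make the float32 kernels fall back to plain scalar code when the compiler is known to be broken; AT_BROKEN_NEON_COMPILER is an illustrative macro name:)

#include <cmath>
#include <cstddef>

// Toy sketch: compile the scalar fallback when the toolchain is known to
// miscompile the NEON-vectorized kernels; slower, but numerically correct.
void exp_inplace(float* data, std::size_t n) {
#ifdef AT_BROKEN_NEON_COMPILER
        for (std::size_t i = 0; i < n; ++i) data[i] = std::exp(data[i]);
#else
        // The NEON-accelerated kernel would be selected here; the scalar
        // loop stands in for it in this sketch.
        for (std::size_t i = 0; i < n; ++i) data[i] = std::exp(data[i]);
#endif
}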

malfet added the module: arm label on Jul 12, 2021
BrettRyland (Author)

@BrettRyland can you please check whether the torch.exp problems are reproducible with the official PyTorch cpu-only builds, which can be downloaded from https://pypi.org/project/torch/#files

gcc-7.5 has several known compiler bugs that yield incorrect code for NEON-optimized operations (that is, almost everything using float32).

Using the cpu-only build torch-1.9.0-cp36-cp36m-manylinux2014_aarch64.whl on the Xavier NX gives the correct results in python:

br@NX:~/tmp$ source bin/activate
(tmp) br@NX:~/tmp$ ipython
Python 3.6.9 (default, Jan 26 2021, 15:33:00) 
Type 'copyright', 'credits' or 'license' for more information
IPython 7.16.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import torch

In [2]: t=torch.ones((3,3), dtype=torch.float64); t.exp()
Out[2]: 
tensor([[2.7183, 2.7183, 2.7183],
        [2.7183, 2.7183, 2.7183],
        [2.7183, 2.7183, 2.7183]], dtype=torch.float64)

In [3]: t=torch.ones((3,3), dtype=torch.float32); t.exp()
Out[3]: 
tensor([[2.7183, 2.7183, 2.7183],
        [2.7183, 2.7183, 2.7183],
        [2.7183, 2.7183, 2.7183]])

and similarly in C++:

(tmp) br@NX:~/tmp/test$ rm CMakeCache.txt 
(tmp) br@NX:~/tmp/test$ cmake -GNinja
CMake Warning:
  No source or binary directory provided.  Both will be assumed to be the
  same as the current working directory, but note that this warning will
  become a fatal error in future CMake releases.


-- The C compiler identification is GNU 7.5.0
-- The CXX compiler identification is GNU 7.5.0
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE  
CMake Warning at /home/br/tmp/lib/python3.6/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:22 (message):
  static library kineto_LIBRARY-NOTFOUND not found.
Call Stack (most recent call first):
  /home/br/tmp/lib/python3.6/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:127 (append_torchlib_if_found)
  CMakeLists.txt:6 (find_package)


-- Found Torch: /home/br/tmp/lib/python3.6/site-packages/torch/lib/libtorch.so  
-- Configuring done
-- Generating done
-- Build files have been written to: /home/br/tmp/test
(tmp) br@NX:~/tmp/test$ ninja
[2/2] Linking CXX executable test
(tmp) br@NX:~/tmp/test$ ./test 
t:
 1  1  1
 1  1  1
 1  1  1
[ CPUFloatType{3,3} ]
t.exp():
 2.7183  2.7183  2.7183
 2.7183  2.7183  2.7183
 2.7183  2.7183  2.7183
[ CPUFloatType{3,3} ]
t.to(float64).exp().to(float32):
 2.7183  2.7183  2.7183
 2.7183  2.7183  2.7183
 2.7183  2.7183  2.7183
[ CPUFloatType{3,3} ]
(tmp) br@NX:~/tmp/test$ ldd test
        linux-vdso.so.1 (0x0000007fb07d9000)
        libtorch.so => /home/br/tmp/lib/python3.6/site-packages/torch/lib/libtorch.so (0x0000007fb0742000)
        libc10.so => /home/br/tmp/lib/python3.6/site-packages/torch/lib/libc10.so (0x0000007fb067f000)
        libtorch_cpu.so => /home/br/tmp/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so (0x0000007fab347000)
        libstdc++.so.6 => /usr/lib/aarch64-linux-gnu/libstdc++.so.6 (0x0000007fab18f000)
        libgcc_s.so.1 => /lib/aarch64-linux-gnu/libgcc_s.so.1 (0x0000007fab16b000)
        libc.so.6 => /lib/aarch64-linux-gnu/libc.so.6 (0x0000007fab012000)
        /lib/ld-linux-aarch64.so.1 (0x0000007fb07ad000)
        libm.so.6 => /lib/aarch64-linux-gnu/libm.so.6 (0x0000007faaf59000)
        libgomp-d22c30c5.so.1 => /home/br/tmp/lib/python3.6/site-packages/torch/lib/libgomp-d22c30c5.so.1 (0x0000007faaf05000)
        libpthread.so.0 => /lib/aarch64-linux-gnu/libpthread.so.0 (0x0000007faaed9000)
        librt.so.1 => /lib/aarch64-linux-gnu/librt.so.1 (0x0000007faaec2000)
        libdl.so.2 => /lib/aarch64-linux-gnu/libdl.so.2 (0x0000007faaead000)

However, I need the CUDA version for my project.

Note: numpy was not included in the dependencies of the wheel, giving

In [1]: import torch
/home/br/tmp/lib/python3.6/site-packages/torch/package/_mock_zipreader.py:17: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at  ../torch/csrc/utils/tensor_numpy.cpp:67.)
  _dtype_to_storage = {data_type(0).dtype: data_type for data_type in _storages}

on my first attempt, and I then ran into numpy/numpy#18131 when installing the default version (1.19.5). I fixed that by using numpy 1.19.4:

(tmp) br@NX:~/tmp$ pip freeze
backcall==0.2.0
dataclasses==0.8
decorator==5.0.9
ipython==7.16.1
ipython-genutils==0.2.0
jedi==0.18.0
numpy==1.19.4
parso==0.8.2
pexpect==4.8.0
pickleshare==0.7.5
pkg-resources==0.0.0
prompt-toolkit==3.0.19
ptyprocess==0.7.0
Pygments==2.9.0
six==1.16.0
torch @ file:///home/br/tmp/torch-1.9.0-cp36-cp36m-manylinux2014_aarch64.whl
traitlets==4.3.3
typing-extensions==3.10.0.0
wcwidth==0.2.5

BrettRyland (Author)

I've tried recompiling with gcc 8.4.0 (which is installable through apt on the Xavier NX) using

MAX_JOBS=4 CC=/usr/bin/gcc-8 CXX=/usr/bin/g++-8 sudo -HE python3 setup.py install

and it also gives the incorrect values for the exponential with float32 dtype.

Build summary:

-- ******** Summary ********
-- General:
--   CMake version         : 3.21.0-rc1
--   CMake command         : /usr/local/bin/cmake
--   System                : Linux
--   C++ compiler          : /usr/bin/g++-8
--   C++ compiler id       : GNU
--   C++ compiler version  : 8.4.0
--   Using ccache if found : ON
--   Found ccache          : CCACHE_PROGRAM-NOTFOUND
--   CXX flags             :  -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -DMISSING_ARM_VST1 -Wno-stringop-overflow
--   Build type            : Release
--   Compile definitions   : ONNX_ML=1;ONNXIFI_ENABLE_EXT=1;ONNX_NAMESPACE=onnx_torch;HAVE_MMAP=1;_FILE_OFFSET_BITS=64;HAVE_SHM_OPEN=1;HAVE_SHM_UNLINK=1;HAVE_MALLOC_USABLE_SIZE=1;USE_EXTERNAL_MZCRC;MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS
--   CMAKE_PREFIX_PATH     : /usr/lib/python3.6/site-packages;/usr/local/cuda
--   CMAKE_INSTALL_PREFIX  : /home/br/staging/pytorch/torch
--   USE_GOLD_LINKER       : OFF
-- 
--   TORCH_VERSION         : 1.9.0
--   CAFFE2_VERSION        : 1.9.0
--   BUILD_CAFFE2          : ON
--   BUILD_CAFFE2_OPS      : ON
--   BUILD_CAFFE2_MOBILE   : OFF
--   BUILD_STATIC_RUNTIME_BENCHMARK: OFF
--   BUILD_TENSOREXPR_BENCHMARK: OFF
--   BUILD_BINARY          : OFF
--   BUILD_CUSTOM_PROTOBUF : ON
--     Link local protobuf : ON
--   BUILD_DOCS            : OFF
--   BUILD_PYTHON          : True
--     Python version      : 3.6.9
--     Python executable   : /usr/bin/python3
--     Pythonlibs version  : 3.6.9
--     Python library      : /usr/lib/libpython3.6m.so.1.0
--     Python includes     : /usr/include/python3.6m
--     Python site-packages: lib/python3.6/site-packages
--   BUILD_SHARED_LIBS     : ON
--   CAFFE2_USE_MSVC_STATIC_RUNTIME     : OFF
--   BUILD_TEST            : True
--   BUILD_JNI             : OFF
--   BUILD_MOBILE_AUTOGRAD : OFF
--   BUILD_LITE_INTERPRETER: OFF
--   INTERN_BUILD_MOBILE   : 
--   USE_BLAS              : 1
--     BLAS                : open
--   USE_LAPACK            : 1
--     LAPACK              : open
--   USE_ASAN              : OFF
--   USE_CPP_CODE_COVERAGE : OFF
--   USE_CUDA              : ON
--     Split CUDA          : OFF
--     CUDA static link    : OFF
--     USE_CUDNN           : ON
--     CUDA version        : 10.2
--     cuDNN version       : 8.0.0
--     CUDA root directory : /usr/local/cuda
--     CUDA library        : /usr/local/cuda/lib64/stubs/libcuda.so
--     cudart library      : /usr/local/cuda/lib64/libcudart.so
--     cublas library      : /usr/lib/aarch64-linux-gnu/libcublas.so
--     cufft library       : /usr/local/cuda/lib64/libcufft.so
--     curand library      : /usr/local/cuda/lib64/libcurand.so
--     cuDNN library       : /usr/lib/aarch64-linux-gnu/libcudnn.so
--     nvrtc               : /usr/local/cuda/lib64/libnvrtc.so
--     CUDA include path   : /usr/local/cuda/include
--     NVCC executable     : /usr/local/cuda/bin/nvcc
--     NVCC flags          : -Xfatbin;-compress-all;-DONNX_NAMESPACE=onnx_torch;-gencode;arch=compute_72,code=sm_72;-Xcudafe;--diag_suppress=cc_clobber_ignored,--diag_suppress=integer_sign_change,--diag_suppress=useless_using_declaration,--diag_suppress=set_but_not_used,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=implicit_return_from_non_void_function,--diag_suppress=unsigned_compare_with_zero,--diag_suppress=declared_but_not_referenced,--diag_suppress=bad_friend_decl;-std=c++14;-Xcompiler;-fPIC;--expt-relaxed-constexpr;--expt-extended-lambda;-Wno-deprecated-gpu-targets;--expt-extended-lambda;-Xcompiler;-fPIC;-DCUDA_HAS_FP16=1;-D__CUDA_NO_HALF_OPERATORS__;-D__CUDA_NO_HALF_CONVERSIONS__;-D__CUDA_NO_BFLOAT16_CONVERSIONS__;-D__CUDA_NO_HALF2_OPERATORS__
--     CUDA host compiler  : /usr/bin/gcc-8
--     NVCC --device-c     : OFF
--     USE_TENSORRT        : OFF
--   USE_ROCM              : OFF
--   USE_EIGEN_FOR_BLAS    : ON
--   USE_FBGEMM            : OFF
--     USE_FAKELOWP          : OFF
--   USE_KINETO            : ON
--   USE_FFMPEG            : OFF
--   USE_GFLAGS            : OFF
--   USE_GLOG              : OFF
--   USE_LEVELDB           : OFF
--   USE_LITE_PROTO        : OFF
--   USE_LMDB              : OFF
--   USE_METAL             : OFF
--   USE_PYTORCH_METAL     : OFF
--   USE_FFTW              : OFF
--   USE_MKL               : OFF
--   USE_MKLDNN            : OFF
--   USE_NCCL              : ON
--     USE_SYSTEM_NCCL     : OFF
--   USE_NNPACK            : ON
--   USE_NUMPY             : ON
--   USE_OBSERVERS         : ON
--   USE_OPENCL            : OFF
--   USE_OPENCV            : OFF
--   USE_OPENMP            : ON
--   USE_TBB               : OFF
--   USE_VULKAN            : OFF
--   USE_PROF              : OFF
--   USE_QNNPACK           : ON
--   USE_PYTORCH_QNNPACK   : ON
--   USE_REDIS             : OFF
--   USE_ROCKSDB           : OFF
--   USE_ZMQ               : OFF
--   USE_DISTRIBUTED       : ON
--     USE_MPI             : ON
--     USE_GLOO            : ON
--     USE_TENSORPIPE      : ON
--   USE_DEPLOY           : OFF
--   Public Dependencies  : Threads::Threads
--   Private Dependencies : pthreadpool;cpuinfo;qnnpack;pytorch_qnnpack;nnpack;XNNPACK;/usr/lib/aarch64-linux-gnu/libnuma.so;fp16;/usr/lib/aarch64-linux-gnu/openmpi/lib/libmpi_cxx.so;/usr/lib/aarch64-linux-gnu/openmpi/lib/libmpi.so;gloo;tensorpipe;aten_op_header_gen;foxi_loader;rt;fmt::fmt-header-only;kineto;gcc_s;gcc;dl
-- Configuring done
-- Generating done
-- Build files have been written to: /home/br/staging/pytorch/build
cmake --build . --target install --config Release -- -j 4
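
(Aside: one way to confirm which compiler and flags a given libtorch build used is to print the configuration baked into the library. A minimal sketch, assuming at::show_config() from ATen/Version.h, the function behind torch.__config__.show() in Python:)

#include <iostream>
#include <ATen/Version.h> // declares at::show_config()

int main() {
        // Prints the build configuration (compiler, CXX flags, CUDA/cuDNN
        // versions, ...) recorded in the linked libtorch.
        std::cout << at::show_config() << "\n";
        return 0;
}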

Also, I noticed this during the configuring stage:

-- Could not find hardware support for NEON on this machine.

so simply disabling NEON acceleration probably won't help.

I've attached the CMakeCache.txt file in case that's useful.
CMakeCache.txt
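
(Side note on the NEON message: on AArch64, Advanced SIMD support is exposed via the "asimd" hwcap rather than an ARM32-style "neon" flag, so that probe failing does not mean the hardware lacks NEON. A minimal runtime check, assuming an aarch64 Linux target:)

#include <iostream>
#include <sys/auxv.h>  // getauxval, AT_HWCAP
#include <asm/hwcap.h> // HWCAP_ASIMD

int main() {
        // The kernel reports Advanced SIMD (the AArch64 form of NEON) here.
        bool has_asimd = getauxval(AT_HWCAP) & HWCAP_ASIMD;
        std::cout << "ASIMD (NEON): " << (has_asimd ? "yes" : "no") << "\n";
        return 0;
}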

dusty-nv

I've tried recompiling with gcc 8.4.0 (which is installable through apt on the Xavier NX) using

@BrettRyland I tried the same with GCC 8.4.0 and found the same behavior. Then I cherry-picked #47099 and this did fix it, with your test now passing.

Here is the updated Jetson wheel: https://nvidia.box.com/shared/static/h1z9sw4bb1ybi0rm3tu8qdj8hs05ljbm.whl (torch-1.9.0-cp36-cp36m-linux_aarch64.whl)

Interestingly this is not a problem with the PyTorch v1.8 wheel and appears to be a regression in v1.9.

BrettRyland (Author)

OK, so cmake does eventually find NEON through testing for ASIMD:

-- Could not find hardware support for NEON on this machine.
-- No OMAP3 processor on this machine.
-- No OMAP4 processor on this machine.
-- asimd/Neon found with compiler flag : -D__NEON__

I see the wheel you compiled was built with gcc-7.5.0 (which works, thanks!). However, cherry-picking #47099 and compiling with gcc-8.4.0 still fails, because the gcc version check in that commit exempts gcc >= 8.4 from the workaround, and gcc-8.4.0 passes that check.
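
(For reference, a hypothetical sketch of a stricter guard; the real #47099 check lives in the build files, and the macro name here is illustrative:)

// Illustrative preprocessor form of the guard: treat all GCC < 9 on
// AArch64 as suspect, so gcc-8.4.x no longer slips through a check that
// assumes ">= 8.4 is fixed".
#if defined(__aarch64__) && defined(__GNUC__) && !defined(__clang__) && __GNUC__ < 9
#define DISABLE_NEON_FAST_PATHS 1
#endif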

malfet (Contributor) commented Jul 19, 2021

@dusty-nv I'm glad that fix works for you. Please note that disabling NEON acceleration will have a negative effect on PyTorch's CPU performance.

Qengineering

The best way to solve the whole issue is to use clang instead of GCC; NEON acceleration then stays enabled. As the screenshot below shows, the error also occurs on 64-bit Raspberry Pi OS.
[image: the same incorrect float32 exp() results reproduced on 64-bit Raspberry Pi OS]
The solution also works for the Jetson family. I've placed some clang-compiled PyTorch wheels for the Jetson Nano on GitHub.

Garfield2005

@malfet @dusty-nv

Hi, the exp function with dtype=torch.float32 in libtorch still gives incorrect output when running on the Jetson Xavier.

I recompiled libtorch 1.10 with gcc 8.4.0:

git clone -b v1.10.0 --recursive https://github.com/pytorch/pytorch.git
# install dependencies: apt install ...
# export env: export USE_CUDA=ON ...
python3 ./tools/build_libtorch.py

test code:

#include <iostream>
#include <torch/torch.h>

void check_torch_exp(){
   c10::InferenceMode guard(true); // Note: gives the same result without this.
   auto t = torch::ones({3, 3}, torch::dtype(torch::kFloat32));
   std::cout << "t:\n" << t << "\n";
   std::cout << "t.exp():\n" << t.exp() << "\n";
   std::cout << "t.hypot:\n" << torch::hypot(t, t) << "\n";
   std::cout << "t.sigmoid_:\n" << t.sigmoid_() << "\n";
   std::cout << "---------------------------------------------" << std::endl;

   auto t_cuda = torch::ones({3, 3}, torch::dtype(torch::kFloat32)).to(torch::kCUDA);
   std::cout << "t_CUDA:\n" << t_cuda << "\n";
   std::cout << "t_CUDA.exp():\n" << t_cuda.exp() << "\n";
   std::cout << "t_CUDAhypot:\n" << torch::hypot(t_cuda, t_cuda) << "\n";
   std::cout << "t_CUDA.sigmoid_:\n" << t_cuda.sigmoid_() << "\n";
   std::cout << "---------------------------------------------" << std::endl;
}

int main(){
   check_torch_exp();
   return 0;
}

cmake output:

-- The C compiler identification is GNU 8.4.0 
-- The CXX compiler identification is GNU 8.4.0
-- The CUDA compiler identification is NVIDIA 10.2.300
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Check for working CUDA compiler: /usr/local/cuda-10.2/bin/nvcc
-- Check for working CUDA compiler: /usr/local/cuda-10.2/bin/nvcc -- works
-- Detecting CUDA compiler ABI info
-- Detecting CUDA compiler ABI info - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE  
-- Found CUDA: /usr/local/cuda-10.2 (found version "10.2") 
-- Caffe2: CUDA detected: 10.2
-- Caffe2: CUDA nvcc is: /usr/local/cuda-10.2/bin/nvcc
-- Caffe2: CUDA toolkit directory: /usr/local/cuda-10.2
-- Caffe2: Header version is: 10.2
-- Found CUDNN: /usr/lib/aarch64-linux-gnu/libcudnn.so  
-- Found cuDNN: v8.2.1  (include: /usr/include, library: /usr/lib/aarch64-linux-gnu/libcudnn.so)
-- /usr/local/cuda-10.2/lib64/libnvrtc.so shorthash is 7d272a04
-- Autodetected CUDA architecture(s):  7.2
-- Added CUDA NVCC flags for: -gencode;arch=compute_72,code=sm_72
-- Found Torch: /data/zhangmm/opensource/pytorch/pytorch_1.10_GCC8.4/build_libtorch/lib/libtorch.so  
-- Configuring done
-- Generating done
-- Build files have been written to: /home/nvidia/zhangmm/kd-distress-processor/test/debug/build

final output:

t:
 1  1  1
 1  1  1
 1  1  1
[ CPUFloatType{3,3} ]
t.exp():
 2.7183  1.0000  1.0000
 1.0000  1.0000  1.0000
 0.0000  1.0000  2.7183
[ CPUFloatType{3,3} ]
t.hypot:
 1.4142  1.4142  1.4142
 1.4142  1.4142  1.4142
 1.4142  1.4142  1.4142
[ CPUFloatType{3,3} ]
t.sigmoid_:
 0.7311  0.7311  0.7311
 0.7311  0.7311  0.7311
 0.7311  0.7311  0.7311
[ CPUFloatType{3,3} ]
---------------------------------------------
t_CUDA:
 1  1  1
 1  1  1
 1  1  1
[ CUDAFloatType{3,3} ]
t_CUDA.exp():
 2.7183  2.7183  2.7183
 2.7183  2.7183  2.7183
 2.7183  2.7183  2.7183
[ CUDAFloatType{3,3} ]
t_CUDA.hypot:
 1.4142  1.4142  1.4142
 1.4142  1.4142  1.4142
 1.4142  1.4142  1.4142
[ CUDAFloatType{3,3} ]
t_CUDA.sigmoid_:
 0.7311  0.7311  0.7311
 0.7311  0.7311  0.7311
 0.7311  0.7311  0.7311
[ CUDAFloatType{3,3} ]

Qengineering

@Garfield2005, these errors are typical when using a GNU compiler like gcc 8.4.0. Please use the clang compilers instead; see our tutorial.
