Add a fast path for batch-norm CPU inference. #19152
soumith
left a comment
remember that this file is not compiled with AVX2 or AVX enabled (internally or externally).
We have to use dispatch for that. The same folder (ATen/native) has some examples on that.
soumith
left a comment
apparently it's already saturating memory bandwidth even without being vectorized
|
I believe we actually do compile with |
|
@zheng-xq - If you want to make use of the dispatch we use in OSS you need to move 'native/Normalization.cpp' into 'native/cpu/Normalization.cpp'. We can also do that as a separate step afterwards. |
There was a problem hiding this comment.
The RHS of this statement is computed in double precision when scalar_t=float, then it is being narrowed to float on assignment. Is that intended?
There was a problem hiding this comment.
This is what the LLVM vectorizer generates for it:
vpslld $31, %ymm5, %ymm5
vpsrad $31, %ymm5, %ymm5
xorl %r10d, %r10d
movq -64(%rbp), %r11 ## 8-byte Reload
movq -56(%rbp), %r14 ## 8-byte Reload
.p2align 4, 0x90
LBB0_33: ## =>This Inner Loop Header: Depth=1
vcvtps2pd (%rdi,%rbx), %ymm6
vcvtps2pd 16(%rdi,%rbx), %ymm7
vaddpd %ymm7, %ymm2, %ymm7
vaddpd %ymm6, %ymm2, %ymm6
vsqrtpd %ymm6, %ymm6
vsqrtpd %ymm7, %ymm7
vdivpd %ymm7, %ymm3, %ymm7
vdivpd %ymm6, %ymm3, %ymm6
vcvtpd2ps %ymm6, %xmm6
vcvtpd2ps %ymm7, %xmm7
vinsertf128 $1, %xmm7, %ymm6, %ymm6
vmovups (%rsi,%rbx), %ymm7
vmaskmovps (%rdx,%rbx), %ymm4, %ymm8
vpandn %ymm8, %ymm5, %ymm8
vmulps %ymm6, %ymm7, %ymm9
vmovups %ymm9, (%r14,%rbx)
vmulps (%rcx,%rbx), %ymm6, %ymm6
vmulps %ymm6, %ymm7, %ymm6
vsubps %ymm6, %ymm8, %ymm6
vmovups %ymm6, (%r11,%rbx)
vcvtps2pd 32(%rdi,%rbx), %ymm6
vcvtps2pd 48(%rdi,%rbx), %ymm7
vaddpd %ymm7, %ymm2, %ymm7
vaddpd %ymm6, %ymm2, %ymm6
vsqrtpd %ymm6, %ymm6
vsqrtpd %ymm7, %ymm7
vdivpd %ymm7, %ymm3, %ymm7
vdivpd %ymm6, %ymm3, %ymm6
vcvtpd2ps %ymm6, %xmm6
vcvtpd2ps %ymm7, %xmm7
vinsertf128 $1, %xmm7, %ymm6, %ymm6
vmovups 32(%rsi,%rbx), %ymm7
vmaskmovps 32(%rdx,%rbx), %ymm4, %ymm8
vpandn %ymm8, %ymm5, %ymm8
vmulps %ymm6, %ymm7, %ymm9
vmovups %ymm9, 32(%r14,%rbx)
vmulps 32(%rcx,%rbx), %ymm6, %ymm6
vmulps %ymm6, %ymm7, %ymm6
vsubps %ymm6, %ymm8, %ymm6
vmovups %ymm6, 32(%r11,%rbx)
addq $16, %r10
addq $64, %rbx
addq $2, %rax
jne LBB0_33
## %bb.34:
testq %r8, %r8
je LBB0_36
There was a problem hiding this comment.
I pulled it into godbolt here: https://godbolt.org/z/OdsS-W
There was a problem hiding this comment.
I think we should figure out which flags we should use when analyzing this stuff. IIRC the configuration is:
- FBCode: -mavx
- OSS build (default): (none)
- I build with -march=native -mavx -mavx2 on my Mac
Do you know what flags are used in the distributed binaries?
There was a problem hiding this comment.
I just confirmed that the OSS build with default settings emits SSE instructions
There was a problem hiding this comment.
The distributed binaries use runtime dispatch and are compiled with -mavx and -mavx2 -mfma if I remember this correctly.
There was a problem hiding this comment.
OSS only compiles code with -mavx (and -mavx2) if it's in the cpu/ subdirectory. This isn't in that directory so you'll only get standard x86-64 instructions (including SSE)
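To make the runtime-dispatch idea concrete, here is a minimal stand-alone sketch. It is not ATen's dispatch mechanism: the names `affine`, `affine_default`, `affine_avx2`, `select_affine_kernel`, and the `SKETCH_X86` macro are all made up for illustration, and it assumes GCC or Clang on x86 for the `target("avx2")` attribute and the `__builtin_cpu_supports` builtin. On other compilers or architectures it simply falls back to the baseline kernel.

```cpp
#include <cassert>
#include <cstdint>

// Assumption for this sketch: only GCC/Clang on x86 get the AVX2 variant.
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
#define SKETCH_X86 1
#else
#define SKETCH_X86 0
#endif

using affine_fn = void (*)(const float*, float, float, float*, int64_t);

// Baseline kernel: built with the translation unit's default flags.
static void affine_default(const float* in, float a, float b, float* out, int64_t n) {
  for (int64_t i = 0; i < n; ++i) out[i] = in[i] * a + b;
}

#if SKETCH_X86
// Same loop, but the compiler may use AVX2 instructions for this function only.
__attribute__((target("avx2")))
static void affine_avx2(const float* in, float a, float b, float* out, int64_t n) {
  for (int64_t i = 0; i < n; ++i) out[i] = in[i] * a + b;
}
#endif

// Pick a kernel based on what the running CPU reports.
static affine_fn select_affine_kernel() {
#if SKETCH_X86
  __builtin_cpu_init();
  if (__builtin_cpu_supports("avx2")) return affine_avx2;
#endif
  return affine_default;
}

// Public entry point: the kernel is selected once, on first call.
void affine(const float* in, float a, float b, float* out, int64_t n) {
  assert(n >= 0);
  static const affine_fn kernel = select_affine_kernel();
  kernel(in, a, b, out, n);
}
```

The point of the sketch is only that the choice of instruction set is made at run time per function, so the rest of the file can stay on baseline x86-64 flags.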
There was a problem hiding this comment.
Regarding "The RHS of this statement is computed in double precision when scalar_t=float, then it is being narrowed to float on assignment. Is that intended?".
It was intended for better accuracy, because epsilon is often very small, and it has no performance implication when C << total_size. However, in degenerate cases where N==1 && image_size==1, meaning C==total_size, it is quite slow. So I changed the default to casting epsilon to scalar_t in this case.
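The promotion being discussed can be reproduced in isolation. The sketch below uses hypothetical helper names (`inv_std_promoted`, `inv_std_cast`) and assumes epsilon arrives as a `double`; it contrasts the promoted computation with casting epsilon to `scalar_t` first.

```cpp
#include <cassert>
#include <cmath>
#include <type_traits>

// Hypothetical helper: eps stays double, so for scalar_t = float the sum,
// sqrt, and division all run in double before narrowing on return.
template <typename scalar_t>
scalar_t inv_std_promoted(scalar_t var, double eps) {
  static_assert(
      std::is_same<decltype(var + eps), double>::value,
      "scalar_t + double is evaluated in double");
  assert(var + eps > 0);
  return static_cast<scalar_t>(1 / std::sqrt(var + eps));
}

// Hypothetical helper: eps is cast to scalar_t first, so the whole
// expression stays in scalar_t precision.
template <typename scalar_t>
scalar_t inv_std_cast(scalar_t var, double eps) {
  const scalar_t eps_t = static_cast<scalar_t>(eps);
  static_assert(
      std::is_same<decltype(var + eps_t), scalar_t>::value,
      "scalar_t + scalar_t stays in scalar_t");
  assert(var + eps_t > 0);
  return scalar_t(1) / std::sqrt(var + eps_t);
}
```

For `float` inputs the first variant is what produces the `vcvtps2pd` / `vsqrtpd` / `vcvtpd2ps` sequence in the listing above, while the second stays in single precision throughout.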
|
FWIW for BatchNorm1D this falls back to scalar code because only the inner loop at line 101 is vectorized, and image_size=1 |
|
This specialization gives me a 1.6x speedup on BatchNorm1D: |
There was a problem hiding this comment.
can you make output "Tensor &" instead? That's the usual pattern for output parameters currently (you can search in this directory for functions ending in "_out").
jamesr66a
left a comment
We should address the things discussed in the comments
There was a problem hiding this comment.
I don't know how much it matters, but computing the rsqrt as a distinct expression loses precision compared to doing an fdiv for each of alpha_data and beta_data. It's probably fine if that is necessary for performance, but if it's not then it would be nice to preserve the extra precision.
There was a problem hiding this comment.
This portion is performance critical, and from other experiments we ran before, it shouldn't cause any accuracy problem.
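For a rough sense of the trade-off, here is a small sketch that folds the per-channel terms both ways and measures how far apart the results are. The struct and function names are hypothetical, and the formulas (`alpha = weight * inv_std`, `beta = bias - mean * inv_std * weight`) are my reading of the assembly listing earlier in the thread, not a copy of the PR's code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container for the folded terms: output = input * alpha[c] + beta[c].
struct FoldedTerms {
  std::vector<float> alpha;
  std::vector<float> beta;
};

// Variant A: one reciprocal square root per channel, reused by multiplication.
FoldedTerms fold_with_inv_std(
    const std::vector<float>& weight, const std::vector<float>& bias,
    const std::vector<float>& mean, const std::vector<float>& var, float eps) {
  const std::size_t n = weight.size();
  assert(bias.size() == n && mean.size() == n && var.size() == n);
  FoldedTerms t{std::vector<float>(n), std::vector<float>(n)};
  for (std::size_t c = 0; c < n; ++c) {
    const float inv_std = 1.0f / std::sqrt(var[c] + eps);
    t.alpha[c] = weight[c] * inv_std;
    t.beta[c] = bias[c] - mean[c] * inv_std * weight[c];
  }
  return t;
}

// Variant B: a separate division for alpha and for beta.
FoldedTerms fold_with_div(
    const std::vector<float>& weight, const std::vector<float>& bias,
    const std::vector<float>& mean, const std::vector<float>& var, float eps) {
  const std::size_t n = weight.size();
  assert(bias.size() == n && mean.size() == n && var.size() == n);
  FoldedTerms t{std::vector<float>(n), std::vector<float>(n)};
  for (std::size_t c = 0; c < n; ++c) {
    const float stddev = std::sqrt(var[c] + eps);
    t.alpha[c] = weight[c] / stddev;
    t.beta[c] = bias[c] - mean[c] * weight[c] / stddev;
  }
  return t;
}

// Largest absolute difference between the two variants' terms.
float max_abs_diff(const FoldedTerms& a, const FoldedTerms& b) {
  float worst = 0.0f;
  for (std::size_t c = 0; c < a.alpha.size(); ++c) {
    worst = std::max(worst, std::fabs(a.alpha[c] - b.alpha[c]));
    worst = std::max(worst, std::fabs(a.beta[c] - b.beta[c]));
  }
  return worst;
}
```

Running `max_abs_diff` on representative weights and statistics is one cheap way to check whether the extra rounding step matters for a given model.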
There was a problem hiding this comment.
LLVM does not vectorize this loop because of the call to sqrt (unless -ffast-math is set). Benchmark for the image_size=1 case:
batch_norm: data shape: [1, 65536, 1], bandwidth: 6.23 GB/s
For debugging, I also printed the timing for this loop ("first" loop) versus the loop below ("second" loop):
first time 0.000150609
second time 2.1667e-05
So this loop takes ~7x longer than the loop below. I am working with @ZolotukhinM to see if we can fix this.
There was a problem hiding this comment.
@cpuhrsch we've found that setting -fno-math-errno allows the sqrt call to be vectorized. This actually seems fairly safe. Do you think we can turn this flag on in ATen?
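For anyone who wants to check this outside of ATen, a stand-alone version of the loop is enough. The function name below is hypothetical and the signature is simplified to plain `float` pointers. As I understand it, when math-errno semantics are in effect the compiler has to treat `sqrt` as a call that may set `errno`, which is what gets in the way of vectorizing it.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical stand-alone version of the "first" loop.
// Compare the generated assembly of, for example:
//   clang++ -O2 -S sketch.cpp
//   clang++ -O2 -fno-math-errno -S sketch.cpp
void compute_inv_std(const float* var, float eps, float* inv_std, int64_t n) {
  assert(n >= 0);
  for (int64_t c = 0; c < n; ++c) {
    // One sqrt and one division per channel.
    inv_std[c] = 1.0f / std::sqrt(var[c] + eps);
  }
}
```

Looking for a packed square root instruction (such as `sqrtps` or `vsqrtps` on x86) in the second build is a quick way to confirm the flag had the intended effect on a given toolchain.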
Wouldn't flattening the loop nest be a simpler fix? |
|
Probably, yeah
On Apr 11, 2019, at 11:20 PM, Owen Anderson wrote:
This specialization gives me a 1.6x speedup on BatchNorm1D:
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -96,13 +96,24 @@ void batch_norm_cpu_inference_contiguous(Tensor* output, const Tensor& input,
   // No need to use parallel_for as this function is supposed to be
   // memory-limited.
   // Keep the loop struture simple to make sure compiler vetorization kicks in.
-  for (int64_t n = 0; n < n_batch; ++n) {
-    for (int64_t c = 0; c < n_channel; ++c) {
-      for (int64_t i = 0; i < image_size; ++i) {
-        // Keep all the offset calculation within the inner loop for simplicity.
-        // Compilers are very good at hoisting the common part outside.
-        int64_t offset = n * n_channel * image_size + c * image_size + i;
-        output_data[offset] = input_data[offset] * alpha_data[c] + beta_data[c];
+  if (image_size != 1) {
+    for (int64_t n = 0; n < n_batch; ++n) {
+      for (int64_t c = 0; c < n_channel; ++c) {
+        for (int64_t i = 0; i < image_size; ++i) {
+          // Keep all the offset calculation within the inner loop for simplicity.
+          // Compilers are very good at hoisting the common part outside.
+          int64_t offset = n * n_channel * image_size + c * image_size + i;
+          output_data[offset] = input_data[offset] * alpha_data[c] + beta_data[c];
+        }
+      }
+    }
+  } else {
+    for (int64_t n = 0; n < n_batch; ++n) {
+      for (int64_t c = 0; c < n_channel; ++c) {
+        // Keep all the offset calculation within the inner loop for simplicity.
+        // Compilers are very good at hoisting the common part outside.
+        int64_t offset = n * n_channel + c;
+        output_data[offset] = input_data[offset] * alpha_data[c] + beta_data[c];
|
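A stand-alone version of the image_size == 1 specialization from the diff above, next to the original triple loop so the two can be checked against each other. The function names and the float-only, raw-pointer signatures are assumptions made for this sketch; the loop bodies follow the diff.

```cpp
#include <cassert>
#include <cstdint>

// Reference: the original triple loop nest over (n, c, i).
void affine_nested(
    const float* in, const float* alpha, const float* beta, float* out,
    int64_t n_batch, int64_t n_channel, int64_t image_size) {
  assert(n_batch >= 0 && n_channel >= 0 && image_size >= 0);
  for (int64_t n = 0; n < n_batch; ++n) {
    for (int64_t c = 0; c < n_channel; ++c) {
      for (int64_t i = 0; i < image_size; ++i) {
        const int64_t offset = n * n_channel * image_size + c * image_size + i;
        out[offset] = in[offset] * alpha[c] + beta[c];
      }
    }
  }
}

// Specialization for image_size == 1: the innermost loop now runs over
// channels, so input, output, alpha, and beta are all read with unit stride.
void affine_image_size_1(
    const float* in, const float* alpha, const float* beta, float* out,
    int64_t n_batch, int64_t n_channel) {
  assert(n_batch >= 0 && n_channel >= 0);
  for (int64_t n = 0; n < n_batch; ++n) {
    for (int64_t c = 0; c < n_channel; ++c) {
      const int64_t offset = n * n_channel + c;
      out[offset] = in[offset] * alpha[c] + beta[c];
    }
  }
}
```

With image_size == 1 the nested version's inner loop has a trip count of one, which is why only the second form gives the vectorizer a loop worth widening.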
|
Thanks for the comments! PTAL |
jamesr66a
left a comment
Looking good, but a few more comments.
There was a problem hiding this comment.
Can you post the results from this benchmark in the PR description?
There was a problem hiding this comment.
I confirmed that having this file outside of cpu/ emits SSE instructions, but surprisingly (to me) we're still hitting machine BW peak in my tests. For future proofing, maybe we should add a TODO to indicate we should enable newer vector extensions in the future, in case the balance of BW vs compute throughput on future CPU SKUs changes
Summary: Pull Request resolved: #19152
Adding a fast path for batch-norm CPU inference when all tensors are contiguous.
* Leverage vectorization through simple loops.
* Folding linear terms before computation.
* For resnext-101, this version gets 18.95 times faster.
* Add a microbenchmark:
* (buck build mode/opt -c python.package_style=inplace --show-output //caffe2/benchmarks/operator_benchmark:batchnorm_benchmark) && \
(OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 buck-out/gen/caffe2/benchmarks/operator_benchmark/batchnorm_benchmark#binary.par)
* batch_norm: data shape: [1, 256, 3136], bandwidth: 22.26 GB/s
* batch_norm: data shape: [1, 65536, 1], bandwidth: 5.57 GB/s
* batch_norm: data shape: [128, 2048, 1], bandwidth: 18.21 GB/s
Reviewed By: soumith, BIT-silence
Differential Revision: D14889728
fbshipit-source-id: b2264ada175410f06c505ae57bd50d651410ced5
Summary: Pull Request resolved: pytorch/pytorch#19152
Adding a fast path for batch-norm CPU inference when all tensors are contiguous.
* Leverage vectorization through simple loops.
* Folding linear terms before computation.
* For resnext-101, this version gets 18.95 times faster.
* Add a microbenchmark:
* (buck build mode/opt -c python.package_style=inplace --show-output //caffe2/benchmarks/operator_benchmark:batchnorm_benchmark) && \
(OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 buck-out/gen/caffe2/benchmarks/operator_benchmark/batchnorm_benchmark#binary.par)
* batch_norm: data shape: [1, 256, 3136], bandwidth: 22.26 GB/s
* batch_norm: data shape: [1, 65536, 1], bandwidth: 5.57 GB/s
* batch_norm: data shape: [128, 2048, 1], bandwidth: 18.21 GB/s
Reviewed By: soumith, BIT-silence
Differential Revision: D14889728
fbshipit-source-id: 20c9e567e38ff7dbb9097873b85160eca2b0a795
|
This pull request has been merged in 5627940. |
|
Do we actually want to compare the PT and Caffe2 implementations as well? I see that in Caffe2 we use Eigen; could it have worked here? |
|
About Eigen in particular: we put some effort into developing our own packet abstraction (vec256.h) and a switchable threadpool, so Eigen doesn't add a lot of value and may not be necessary. |
|
@salexspb - Eigen is an entirely separate topic in itself. So far we haven't seen any advantages in using it. Ping me directly if you want to talk more and see what we've done already. In general, a rigorous comparison to evaluate whether adding that dependency is worthwhile is a whole project in and of itself. |
|
I am confused about the dependency part. Don't we already use it in Caffe2? How is the unified build working then? |
|
@salexspb - yes, that's true we already have it as part of our build chain. In terms of dependency, I mean includes. It's adding complexity to the code. In general, we should have comparison benchmarks for operators like this which are already supported by Caffe2 and then see if we can recycle them. But it's often also worthwhile to see whether we can write faster code by using ATen's abstractions. For this particular operator, I agree that it could be good to write a side-by-side comparison in a single script to compare Caffe2's and PyTorch's implementation. |
Summary: As suggested in #19152 (comment), this may give the compiler more opportunities for auto-vectorization
Pull Request resolved: #19552
Differential Revision: D15048358
Pulled By: jamesr66a
fbshipit-source-id: db2c2c515c3e9f7d22305c039ab0c8a867fc43a2
Summary:
Adding a fast path for batch-norm CPU inference when all tensors are contiguous.
Differential Revision: D14889728
== Benchmark Results ==
batch_norm: data shape: [1, 256, 3136], bandwidth: 22.26 GB/s
batch_norm: data shape: [1, 65536, 1], bandwidth: 5.57 GB/s
batch_norm: data shape: [128, 2048, 1], bandwidth: 18.21 GB/s