Fix reduce_blocks_into_lanes race condition #1798

Merged: crcrpar merged 2 commits into NVIDIA:master from Fuzzkatt:Fuzzkatt/reduce_block_into_lanes_sync_threads_fix on Apr 26, 2024
@Fuzzkatt commented Apr 19, 2024

We are seeing numerical mismatches on GH and H100 when running the following unit tests: https://github.com/NVIDIA/apex/blob/master/tests/L0/run_optimizers/test_lamb.py#L251 and https://github.com/NVIDIA/apex/blob/master/tests/L0/run_optimizers/test_lamb.py#L315. Running compute-sanitizer with the racecheck tool on these tests reports races:

    root@6b09a87d0bbc:/opt/pytorch/apex/tests/L0/run_optimizers# compute-sanitizer --tool racecheck python test_lamb.py -v -k test_multi_params
    ========= COMPUTE-SANITIZER
    test_multi_params (__main__.TestFusedLAMB) ...
    ========= Error: Race reported between Read access at T1 reduce_block_into_lanes<float>(T1 *, T1, int, bool)+0x2b0 in /opt/pytorch/apex/csrc/type_shim.h:350
    =========     and Write access at T1 reduce_block_into_lanes<float>(T1 *, T1, int, bool)+0x5f0 in /opt/pytorch/apex/csrc/type_shim.h:333 [128 hazards]

Comparing with the PyTorch version of reduce_block_into_lanes, one major difference is the location of the final __syncthreads(): https://github.com/pytorch/pytorch/blob/1ec05c769b7e1c6ab5ba75f86b4ae6d43d77ac96/aten/src/ATen/native/cuda/WeightNorm.cu#L96. Looking at the usage (https://github.com/search?q=repo%3ANVIDIA%2Fapex%20reduce_block_into_lanes&type=code), we note that share_results=False is always used, so the final __syncthreads() is never reached in apex use cases. We hypothesize that the unit tests call reduce_block_into_lanes multiple times in succession. Because there is no barrier after the read at line 350, the write at line 333 in the second call can race ahead of the read at line 350 from the first call.

This PR attempts to fix the race by moving the final __syncthreads() to its proper location, so that it is reached even when share_results=False.
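To make the hypothesized hazard concrete, here is a minimal CUDA sketch of the pattern described above. It is not the apex source: `block_sum_sketch`, `two_reductions`, the 256-thread block, and the buffer size are all assumptions made for illustration. The helper stores to shared memory, reduces, and has warp 0 read the last partial sums. Warps other than warp 0 skip that final read, so without a trailing barrier they could start a second call and overwrite the buffer while warp 0 is still reading it.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical, simplified block-wide sum; NOT the apex implementation.
// Assumes a 1-D block whose size is a power of two and at least 64.
__device__ float block_sum_sketch(float* smem, float val, bool share_result)
{
    int tid = threadIdx.x;

    smem[tid] = val;   // store phase (plays the role of the write in the report)
    __syncthreads();

    // Tree reduction in shared memory down to 64 partial sums.
    for (int i = blockDim.x >> 1; i >= 64; i >>= 1) {
        if (tid < i) smem[tid] += smem[tid + i];
        __syncthreads();
    }

    float result = 0.0f;
    if (tid < 32) {
        // load phase (plays the role of the read in the report)
        result = smem[tid] + smem[tid + 32];
        for (int offset = 16; offset >= 1; offset >>= 1)
            result += __shfl_down_sync(0xffffffffu, result, offset);
    }

    if (share_result && tid == 0) smem[0] = result;

    // Unconditional barrier: no thread may begin a later call (and its
    // store phase) until warp 0 has finished reading smem in this call.
    // Placing this inside `if (share_result)` would skip it whenever
    // share_result is false, which is the situation the PR describes.
    __syncthreads();
    return result;   // only thread 0 holds the full sum
}

// Two back-to-back reductions that reuse the same shared-memory buffer.
__global__ void two_reductions(const float* a, const float* b, float* out)
{
    __shared__ float smem[256];   // assumed block size for this sketch
    int tid = threadIdx.x;
    float sum_a = block_sum_sketch(smem, a[tid], false);
    float sum_b = block_sum_sketch(smem, b[tid], false);
    if (tid == 0) { out[0] = sum_a; out[1] = sum_b; }
}

int main()
{
    const int n = 256;
    float *a, *b, *out;
    cudaMallocManaged(&a, n * sizeof(float));
    cudaMallocManaged(&b, n * sizeof(float));
    cudaMallocManaged(&out, 2 * sizeof(float));
    for (int i = 0; i < n; ++i) { a[i] = 1.0f; b[i] = 2.0f; }

    two_reductions<<<1, n>>>(a, b, out);
    cudaError_t err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        printf("CUDA error: %s\n", cudaGetErrorString(err));
        return 1;
    }

    printf("sum_a = %.0f, sum_b = %.0f\n", out[0], out[1]);
    bool ok = (out[0] == 256.0f) && (out[1] == 512.0f);
    cudaFree(a); cudaFree(b); cudaFree(out);
    return ok ? 0 : 1;
}
```

The sketch can be compiled with nvcc and run under `compute-sanitizer --tool racecheck`, the same tool used in the report above. Moving the final barrier back inside the `share_result` branch is one way to experiment with whether the tool flags a hazard between the two calls.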

cc @eqy, @crcrpar

@crcrpar self-requested a review April 19, 2024 02:17
Comment thread csrc/type_shim.h
@crcrpar left a comment

Thanks, and sorry for the slow response.

@crcrpar merged commit a7de60e into NVIDIA:master on Apr 26, 2024