Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b2ab92d

Browse files
minsiixuhancn
authored andcommitted
[c10d] use allocator trace callbacks for NCCL PG register (pytorch#112850)
Summary: We need to register all cache segments allocated by allocator, so that NCCL can apply zero copy algorithms at collective and point-to-point operations. How to track and register all cache segments: - It registers a register and a deregister hook to cache allocator as action tracker callbacks, tracking SEGMENT_ALLOC and SEGMENT_FREE trace entries, respectively. When SEGMENT_ALLOC is tracked, the register hook will register to the PG's communicators on the same device. Similarly, when SEGMENT_FREE is tracked, the deregister hook handles deregistration before cudaFree. - When a new NCCL communicator is created, it dumps the snapspot from cache allocator to register all existing cache segments at once. - When a NCCL communicator is aborted, it deregisters all segments that have been registered by this communicator Test Plan: See test in D50726971 Reviewed By: wconstab Differential Revision: D50726970 Pull Request resolved: pytorch#112850 Approved by: https://github.com/wconstab
1 parent cd71185 commit b2ab92d

4 files changed

Lines changed: 248 additions & 0 deletions

File tree

test/distributed/test_c10d_nccl.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,6 +1249,29 @@ def _check_nccl_timeout(expected_timeout):
12491249
_check_nccl_timeout(timedelta(seconds=1240))
12501250
dist.destroy_process_group()
12511251

1252+
@requires_nccl()
1253+
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
1254+
def test_tensor_register_hook(self):
1255+
os.environ["NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"] = "1"
1256+
1257+
store = c10d.FileStore(self.file_name, self.world_size)
1258+
pg = self._create_process_group_nccl(store, self.opts())
1259+
local_device_id = self.rank_to_GPU[self.rank][0]
1260+
1261+
def allgather_base(output_t, input_t):
1262+
work = pg._allgather_base(output_t, input_t)
1263+
work.wait()
1264+
1265+
# allgather_base is GPU number agnostic.
1266+
# Each rank contribute one tensor regardless of GPU counts
1267+
tensor = torch.tensor([self.rank]).cuda(local_device_id)
1268+
output_t = torch.empty((self.world_size), dtype=tensor.dtype).cuda(local_device_id)
1269+
1270+
allgather_base(output_t, tensor)
1271+
1272+
# Verification
1273+
self.assertEqual(torch.arange(self.world_size), output_t)
1274+
12521275
class DistributedDataParallelTest(
12531276
test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
12541277
):

torch/csrc/distributed/c10d/NCCLUtils.hpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@
6060
#define NCCL_HAS_COMM_CTA_CGA
6161
#endif
6262

63+
#if defined(NCCL_REGISTRATION_SUPPORTED) || \
64+
((defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
65+
(NCCL_MINOR >= 19)))
66+
#define NCCL_HAS_COMM_REGISTER
67+
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
68+
#define NCCL_HAS_COMM_REGISTER
69+
#endif
70+
6371
// Macro to throw on a non-successful NCCL return value.
6472
#define C10D_NCCL_CHECK(cmd, failureReason) \
6573
do { \
@@ -264,6 +272,21 @@ class NCCLComm {
264272
return;
265273
}
266274

275+
#ifdef NCCL_HAS_COMM_REGISTER
276+
// Deregister all registered segments before aborting.
277+
for (auto& it : registeredSegmentHandles_) {
278+
void* handle = it.second;
279+
C10D_NCCL_CHECK(
280+
::ncclCommDeregister(ncclComm_, handle),
281+
c10::str(
282+
"Failed to deregister segment handle ",
283+
handle,
284+
" on ncclComm_ ",
285+
ncclComm_));
286+
}
287+
registeredSegmentHandles_.clear();
288+
#endif
289+
267290
// Set true failure reason if provided by ProcessGroupNCCL (e.g. work
268291
// timeout)
269292
commFailureReason_ = commFailureReason;
@@ -306,6 +329,62 @@ class NCCLComm {
306329
#endif
307330
}
308331

332+
ncclResult_t registerSegment(void* ptr, size_t size) {
333+
std::unique_lock<std::mutex> lock(mutex_);
334+
#ifdef NCCL_HAS_COMM_REGISTER
335+
// We register only segments from cache allocator
336+
// which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
337+
// maps to a unique handle and should not be registered before the current
338+
// ptr is deregistered and freed.
339+
TORCH_CHECK(
340+
registeredSegmentHandles_.count(ptr) == 0,
341+
"Segment with ptr ",
342+
ptr,
343+
" has already been registered on ncclComm_ ",
344+
ncclComm_);
345+
346+
void* handle;
347+
C10D_NCCL_CHECK(
348+
ncclCommRegister(ncclComm_, ptr, size, &handle),
349+
c10::str(
350+
"Failed to register segment with ptr ",
351+
ptr,
352+
", size ",
353+
size,
354+
" on ncclComm_ ",
355+
ncclComm_));
356+
registeredSegmentHandles_[ptr] = handle;
357+
return ncclSuccess;
358+
#else
359+
return ncclInvalidUsage;
360+
#endif
361+
}
362+
363+
ncclResult_t deregisterSegment(void* ptr) {
364+
std::unique_lock<std::mutex> lock(mutex_);
365+
#ifdef NCCL_HAS_COMM_REGISTER
366+
TORCH_CHECK(
367+
registeredSegmentHandles_.count(ptr) == 1,
368+
"Segment with ptr ",
369+
ptr,
370+
" is not registered on ncclComm_ ",
371+
ncclComm_);
372+
373+
void* handle = registeredSegmentHandles_[ptr];
374+
C10D_NCCL_CHECK(
375+
ncclCommDeregister(ncclComm_, handle),
376+
c10::str(
377+
"Failed to deregister segment handle ",
378+
handle,
379+
" on ncclComm_ ",
380+
ncclComm_));
381+
registeredSegmentHandles_.erase(ptr);
382+
return ncclSuccess;
383+
#else
384+
return ncclInvalidUsage;
385+
#endif
386+
}
387+
309388
protected:
310389
ncclComm_t ncclComm_;
311390
// Unique nccl_id for this communicator.
@@ -318,6 +397,10 @@ class NCCLComm {
318397
// Optional reason for communicator failure, provided by ProcessGroupNCCL for
319398
// better error messaging.
320399
c10::optional<std::string> commFailureReason_;
400+
#ifdef NCCL_HAS_COMM_REGISTER
401+
// Stores handlers for tensors registered by NCCL
402+
std::unordered_map<void*, void*> registeredSegmentHandles_;
403+
#endif
321404
};
322405

323406
// Helper that automatically cleans up premul sums.

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <ATen/cuda/CUDAContext.h>
1717
#include <ATen/cuda/CUDAGraph.h>
1818
#include <c10/core/DeviceType.h>
19+
#include <c10/cuda/CUDAAllocatorConfig.h>
1920
#include <c10/cuda/CUDAGraphsC10Utils.h>
2021
#include <c10/cuda/CUDAGuard.h>
2122
#include <c10/util/CallOnce.h>
@@ -317,6 +318,55 @@ c10::List<c10::IValue> new_list() {
317318

318319
} // namespace
319320

321+
// Map from each communicator to its device index.
322+
// This map is used when register/deregister cache segments from cache
323+
// allocator. See design notes below:
324+
// - Each segment should be registered only to the communicator on the
325+
// same device.
326+
// - We cannot reuse devNCCLCommMap_ in each ProcessGroup because the key may be
327+
// ranks rather than device in point-to-point case.
328+
// - This map has also to be maintained as global variable since the register
329+
// hooks are called outside the scope of any PG, thus we need traverse
330+
// communicators in all PGs.
331+
static std::unordered_map<std::shared_ptr<NCCLComm>, int> ncclCommDevIdxMap;
332+
static std::mutex ncclCommDevIdxMapMutex;
333+
static bool allocatorHooksAttached = false;
334+
void cacheAllocatorRegisterHook(
335+
const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
336+
// Register after SEGMENT_ALLOC
337+
if (te.action_ !=
338+
c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_ALLOC) {
339+
return;
340+
}
341+
342+
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
343+
for (auto& it : ncclCommDevIdxMap) {
344+
auto& ncclComm = it.first;
345+
auto& devIdx = it.second;
346+
if (te.device_ == devIdx) {
347+
ncclComm->registerSegment(reinterpret_cast<void*>(te.addr_), te.size_);
348+
}
349+
}
350+
}
351+
352+
void cacheAllocatorDeregisterHook(
353+
const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
354+
// deregister before SEGMENT_FREE
355+
if (te.action_ !=
356+
c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_FREE) {
357+
return;
358+
}
359+
360+
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
361+
for (auto& it : ncclCommDevIdxMap) {
362+
auto& ncclComm = it.first;
363+
auto& devIdx = it.second;
364+
if (te.device_ == devIdx) {
365+
ncclComm->deregisterSegment(reinterpret_cast<void*>(te.addr_));
366+
}
367+
}
368+
}
369+
320370
struct NCCLTraceBuffer {
321371
static NCCLTraceBuffer* get() {
322372
// intentionally leak on exit
@@ -817,6 +867,12 @@ void ProcessGroupNCCL::WorkNCCL::abort() {
817867
for (const auto& ncclComm : ncclComms_) {
818868
ncclComm->ncclCommAbort();
819869
}
870+
871+
ncclCommDevIdxMapMutex.lock();
872+
for (const auto& comm : ncclComms_) {
873+
ncclCommDevIdxMap.erase(comm);
874+
}
875+
ncclCommDevIdxMapMutex.unlock();
820876
}
821877

822878
ProcessGroupNCCL::CoalescedWorkNCCL::CoalescedWorkNCCL(
@@ -881,6 +937,17 @@ ProcessGroupNCCL::ProcessGroupNCCL(
881937
parseEnvVarIntDefault("TORCH_NCCL_TRACE_BUFFER_SIZE", 0) > 0);
882938
#endif
883939
avoidRecordStreams_ = parseEnvVarFlag(TORCH_NCCL_AVOID_RECORD_STREAMS);
940+
#ifdef NCCL_HAS_COMM_REGISTER
941+
useTensorRegisterAllocatorHook_ =
942+
parseEnvVarFlag(NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK);
943+
if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
944+
expandable_segments()) {
945+
useTensorRegisterAllocatorHook_ = false;
946+
LOG(INFO)
947+
<< "[Rank " << rank_
948+
<< "] disables NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode.";
949+
}
950+
#endif
884951

885952
if (blockingWait_) {
886953
if (asyncErrorHandling_ != NoHandling || desyncDebug_) {
@@ -932,6 +999,10 @@ ProcessGroupNCCL::ProcessGroupNCCL(
932999
<< options_->is_high_priority_stream
9331000
<< ", TORCH_DISTRIBUTED_DEBUG: "
9341001
<< std::string(torch_distributed_debug)
1002+
#ifdef NCCL_HAS_COMM_REGISTER
1003+
<< ", NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: "
1004+
<< useTensorRegisterAllocatorHook_
1005+
#endif
9351006
<< ", NCCL_DEBUG: " << std::string(nccl_debug)
9361007
<< ", ID=" << this->getID();
9371008

@@ -947,6 +1018,19 @@ ProcessGroupNCCL::ProcessGroupNCCL(
9471018
std::vector<int64_t>(), // outSplitSizes
9481019
size_); // worldSize
9491020

1021+
// Attach hooks to cache allocator to trigger the hooks whenever a traced
1022+
// action is called. In the following hooks, we register a newly allocated
1023+
// segment when SEGMENT_ALLOC action occurs, and deregister a segment when
1024+
// SEGMENT_FREE action occurs.
1025+
// We attach hooks only once at the first PG creation.
1026+
if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) {
1027+
c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
1028+
&cacheAllocatorRegisterHook);
1029+
c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
1030+
&cacheAllocatorDeregisterHook);
1031+
allocatorHooksAttached = true;
1032+
}
1033+
9501034
#ifdef USE_NCCL_WITH_UCC
9511035
static c10::once_flag initialize_ucc_lib_flag;
9521036
c10::call_once(initialize_ucc_lib_flag, [&] {
@@ -1124,6 +1208,20 @@ void abortCommsFromMap(
11241208

11251209
// Abort all communicators on this rank
11261210
void ProcessGroupNCCL::abort(c10::optional<std::string> abortReason) {
1211+
// Remove record from global ncclCommDevIdxMapMutex before aboarting,
1212+
// so that a new cache segment would not register to already aborded
1213+
// communicators. Note that ncclCommDevIdxMap is a global container which may
1214+
// contain other PG's communicators, thus we need to only erase communicators
1215+
// for the current PG.
1216+
ncclCommDevIdxMapMutex.lock();
1217+
for (auto& it : devNCCLCommMap_) {
1218+
auto& ncclComms = it.second;
1219+
for (const auto& ncclComm : ncclComms) {
1220+
ncclCommDevIdxMap.erase(ncclComm);
1221+
}
1222+
}
1223+
ncclCommDevIdxMapMutex.unlock();
1224+
11271225
std::lock_guard<std::mutex> lock(mutex_);
11281226
abortCommsFromMap(devNCCLCommMap_, rank_, abortReason);
11291227
abortCommsFromMap(inInitializationCommMap_, rank_, abortReason);
@@ -1498,6 +1596,12 @@ void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
14981596
devNCCLCommMap_.erase(devNCCLCommMapKey);
14991597
// Clear used device indices.
15001598
usedDeviceIdxs_.clear();
1599+
1600+
ncclCommDevIdxMapMutex.lock();
1601+
for (const auto& comm : ncclComms) {
1602+
ncclCommDevIdxMap.erase(comm);
1603+
}
1604+
ncclCommDevIdxMapMutex.unlock();
15011605
}
15021606

15031607
std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
@@ -1662,6 +1766,34 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
16621766
if (it != inInitializationCommMap_.end()) {
16631767
devNCCLCommMap_.emplace(devicesKey, std::move(it->second));
16641768
inInitializationCommMap_.erase(devicesKey);
1769+
1770+
// Now ncclComms are fully initialized.
1771+
// Register all active CUDA memory segments in cache allocator to
1772+
// the new NCCL communicators
1773+
if (useTensorRegisterAllocatorHook_) {
1774+
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
1775+
// Register the segment to a new NCCL communicator if on the same device
1776+
for (const auto& segmentInfo : snapshot.segments) {
1777+
for (const auto i : c10::irange(devices.size())) {
1778+
if (segmentInfo.device != devices[i].index())
1779+
continue;
1780+
ncclComms[i]->registerSegment(
1781+
reinterpret_cast<void*>(segmentInfo.address),
1782+
segmentInfo.total_size);
1783+
}
1784+
}
1785+
1786+
// Record the mapping between ncclComm and device index so that later
1787+
// register hook can register a newly allocated segment to communicators
1788+
// on the same device.
1789+
// NOTE: we need remove the communicator from this map when it is
1790+
// destroyed, otherwise may register onto an invalid communicator.
1791+
ncclCommDevIdxMapMutex.lock();
1792+
for (const auto i : c10::irange(devices.size())) {
1793+
ncclCommDevIdxMap.emplace(ncclComms[i], devices[i].index());
1794+
}
1795+
ncclCommDevIdxMapMutex.unlock();
1796+
}
16651797
}
16661798

16671799
it = devNCCLCommMap_.find(devicesKey);

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ enum ErrorHandlingMode {
7676
constexpr const char* TORCH_NCCL_AVOID_RECORD_STREAMS =
7777
"TORCH_NCCL_AVOID_RECORD_STREAMS";
7878

79+
// If set, ProcessGroupNCCL registers postAlloc and preFree hooks to cuda cache
80+
// allocator so that whenever a tensor is allocated or freed, ProcessGroupNCCL
81+
// can register/deregister the tensor on all available NCCL communicators.
82+
constexpr const char* NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK =
83+
"NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK";
84+
7985
// ProcessGroupNCCL implements NCCL bindings for c10d.
8086
//
8187
// All functions of the class are expected to be called in the same order
@@ -766,6 +772,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
766772
// for the operation to complete.
767773
bool blockingWait_ = false;
768774

775+
// Whether or not to hook the cache allocator to register all allocated
776+
// tensors
777+
bool useTensorRegisterAllocatorHook_ = false;
778+
769779
// Whether or not the workCleanupThread is used to perform async error
770780
// handling.
771781
ErrorHandlingMode asyncErrorHandling_ = NoHandling;

0 commit comments

Comments
 (0)