1616#include < ATen/cuda/CUDAContext.h>
1717#include < ATen/cuda/CUDAGraph.h>
1818#include < c10/core/DeviceType.h>
19+ #include < c10/cuda/CUDAAllocatorConfig.h>
1920#include < c10/cuda/CUDAGraphsC10Utils.h>
2021#include < c10/cuda/CUDAGuard.h>
2122#include < c10/util/CallOnce.h>
@@ -317,6 +318,55 @@ c10::List<c10::IValue> new_list() {
317318
318319} // namespace
319320
321+ // Map from each communicator to its device index.
322+ // This map is used when register/deregister cache segments from cache
323+ // allocator. See design notes below:
324+ // - Each segment should be registered only to the communicator on the
325+ // same device.
326+ // - We cannot reuse devNCCLCommMap_ in each ProcessGroup because the key may be
327+ // ranks rather than device in point-to-point case.
328+ // - This map has also to be maintained as global variable since the register
329+ // hooks are called outside the scope of any PG, thus we need traverse
330+ // communicators in all PGs.
331+ static std::unordered_map<std::shared_ptr<NCCLComm>, int > ncclCommDevIdxMap;
332+ static std::mutex ncclCommDevIdxMapMutex;
333+ static bool allocatorHooksAttached = false ;
334+ void cacheAllocatorRegisterHook (
335+ const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
336+ // Register after SEGMENT_ALLOC
337+ if (te.action_ !=
338+ c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_ALLOC) {
339+ return ;
340+ }
341+
342+ std::lock_guard<std::mutex> lock (ncclCommDevIdxMapMutex);
343+ for (auto & it : ncclCommDevIdxMap) {
344+ auto & ncclComm = it.first ;
345+ auto & devIdx = it.second ;
346+ if (te.device_ == devIdx) {
347+ ncclComm->registerSegment (reinterpret_cast <void *>(te.addr_ ), te.size_ );
348+ }
349+ }
350+ }
351+
352+ void cacheAllocatorDeregisterHook (
353+ const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
354+ // deregister before SEGMENT_FREE
355+ if (te.action_ !=
356+ c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_FREE) {
357+ return ;
358+ }
359+
360+ std::lock_guard<std::mutex> lock (ncclCommDevIdxMapMutex);
361+ for (auto & it : ncclCommDevIdxMap) {
362+ auto & ncclComm = it.first ;
363+ auto & devIdx = it.second ;
364+ if (te.device_ == devIdx) {
365+ ncclComm->deregisterSegment (reinterpret_cast <void *>(te.addr_ ));
366+ }
367+ }
368+ }
369+
320370struct NCCLTraceBuffer {
321371 static NCCLTraceBuffer* get () {
322372 // intentionally leak on exit
@@ -817,6 +867,12 @@ void ProcessGroupNCCL::WorkNCCL::abort() {
817867 for (const auto & ncclComm : ncclComms_) {
818868 ncclComm->ncclCommAbort ();
819869 }
870+
871+ ncclCommDevIdxMapMutex.lock ();
872+ for (const auto & comm : ncclComms_) {
873+ ncclCommDevIdxMap.erase (comm);
874+ }
875+ ncclCommDevIdxMapMutex.unlock ();
820876}
821877
822878ProcessGroupNCCL::CoalescedWorkNCCL::CoalescedWorkNCCL (
@@ -881,6 +937,17 @@ ProcessGroupNCCL::ProcessGroupNCCL(
881937 parseEnvVarIntDefault (" TORCH_NCCL_TRACE_BUFFER_SIZE" , 0 ) > 0 );
882938#endif
883939 avoidRecordStreams_ = parseEnvVarFlag (TORCH_NCCL_AVOID_RECORD_STREAMS);
940+ #ifdef NCCL_HAS_COMM_REGISTER
941+ useTensorRegisterAllocatorHook_ =
942+ parseEnvVarFlag (NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK);
943+ if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
944+ expandable_segments ()) {
945+ useTensorRegisterAllocatorHook_ = false ;
946+ LOG (INFO)
947+ << " [Rank " << rank_
948+ << " ] disables NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode." ;
949+ }
950+ #endif
884951
885952 if (blockingWait_) {
886953 if (asyncErrorHandling_ != NoHandling || desyncDebug_) {
@@ -932,6 +999,10 @@ ProcessGroupNCCL::ProcessGroupNCCL(
932999 << options_->is_high_priority_stream
9331000 << " , TORCH_DISTRIBUTED_DEBUG: "
9341001 << std::string (torch_distributed_debug)
1002+ #ifdef NCCL_HAS_COMM_REGISTER
1003+ << " , NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: "
1004+ << useTensorRegisterAllocatorHook_
1005+ #endif
9351006 << " , NCCL_DEBUG: " << std::string (nccl_debug)
9361007 << " , ID=" << this ->getID ();
9371008
@@ -947,6 +1018,19 @@ ProcessGroupNCCL::ProcessGroupNCCL(
9471018 std::vector<int64_t >(), // outSplitSizes
9481019 size_); // worldSize
9491020
1021+ // Attach hooks to cache allocator to trigger the hooks whenever a traced
1022+ // action is called. In the following hooks, we register a newly allocated
1023+ // segment when SEGMENT_ALLOC action occurs, and deregister a segment when
1024+ // SEGMENT_FREE action occurs.
1025+ // We attach hooks only once at the first PG creation.
1026+ if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) {
1027+ c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker (
1028+ &cacheAllocatorRegisterHook);
1029+ c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker (
1030+ &cacheAllocatorDeregisterHook);
1031+ allocatorHooksAttached = true ;
1032+ }
1033+
9501034#ifdef USE_NCCL_WITH_UCC
9511035 static c10::once_flag initialize_ucc_lib_flag;
9521036 c10::call_once (initialize_ucc_lib_flag, [&] {
@@ -1124,6 +1208,20 @@ void abortCommsFromMap(
11241208
11251209// Abort all communicators on this rank
11261210void ProcessGroupNCCL::abort (c10::optional<std::string> abortReason) {
1211+ // Remove record from global ncclCommDevIdxMapMutex before aboarting,
1212+ // so that a new cache segment would not register to already aborded
1213+ // communicators. Note that ncclCommDevIdxMap is a global container which may
1214+ // contain other PG's communicators, thus we need to only erase communicators
1215+ // for the current PG.
1216+ ncclCommDevIdxMapMutex.lock ();
1217+ for (auto & it : devNCCLCommMap_) {
1218+ auto & ncclComms = it.second ;
1219+ for (const auto & ncclComm : ncclComms) {
1220+ ncclCommDevIdxMap.erase (ncclComm);
1221+ }
1222+ }
1223+ ncclCommDevIdxMapMutex.unlock ();
1224+
11271225 std::lock_guard<std::mutex> lock (mutex_);
11281226 abortCommsFromMap (devNCCLCommMap_, rank_, abortReason);
11291227 abortCommsFromMap (inInitializationCommMap_, rank_, abortReason);
@@ -1498,6 +1596,12 @@ void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
14981596 devNCCLCommMap_.erase (devNCCLCommMapKey);
14991597 // Clear used device indices.
15001598 usedDeviceIdxs_.clear ();
1599+
1600+ ncclCommDevIdxMapMutex.lock ();
1601+ for (const auto & comm : ncclComms) {
1602+ ncclCommDevIdxMap.erase (comm);
1603+ }
1604+ ncclCommDevIdxMapMutex.unlock ();
15011605}
15021606
15031607std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm (
@@ -1662,6 +1766,34 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
16621766 if (it != inInitializationCommMap_.end ()) {
16631767 devNCCLCommMap_.emplace (devicesKey, std::move (it->second ));
16641768 inInitializationCommMap_.erase (devicesKey);
1769+
1770+ // Now ncclComms are fully initialized.
1771+ // Register all active CUDA memory segments in cache allocator to
1772+ // the new NCCL communicators
1773+ if (useTensorRegisterAllocatorHook_) {
1774+ auto snapshot = c10::cuda::CUDACachingAllocator::snapshot ();
1775+ // Register the segment to a new NCCL communicator if on the same device
1776+ for (const auto & segmentInfo : snapshot.segments ) {
1777+ for (const auto i : c10::irange (devices.size ())) {
1778+ if (segmentInfo.device != devices[i].index ())
1779+ continue ;
1780+ ncclComms[i]->registerSegment (
1781+ reinterpret_cast <void *>(segmentInfo.address ),
1782+ segmentInfo.total_size );
1783+ }
1784+ }
1785+
1786+ // Record the mapping between ncclComm and device index so that later
1787+ // register hook can register a newly allocated segment to communicators
1788+ // on the same device.
1789+ // NOTE: we need remove the communicator from this map when it is
1790+ // destroyed, otherwise may register onto an invalid communicator.
1791+ ncclCommDevIdxMapMutex.lock ();
1792+ for (const auto i : c10::irange (devices.size ())) {
1793+ ncclCommDevIdxMap.emplace (ncclComms[i], devices[i].index ());
1794+ }
1795+ ncclCommDevIdxMapMutex.unlock ();
1796+ }
16651797 }
16661798
16671799 it = devNCCLCommMap_.find (devicesKey);
0 commit comments