1818from torch ._inductor .codegen .wrapper import IndentedBuffer
1919from torch ._inductor .codegen .wrapper import WrapperLine
2020
21+ import apex .contrib .torchsched .config as torchsched_config
22+ from apex .contrib .torchsched .inductor ._utils import DEFAULT_STREAM_IDX
2123from apex .contrib .torchsched .inductor ._utils import ENTRANCE_EVENT
2224from apex .contrib .torchsched .inductor ._utils import EVENT_NAME_TEMPLATE
2325from apex .contrib .torchsched .inductor ._utils import get_stream_name
@@ -31,6 +33,7 @@ class CudaEventSym:
3133 Args:
3234 factory: The CUDAEventFactory that generate this event.
3335 idx: Indexing number assigned in chronological order during scheduling.
36+ originate_stream_idx: The index of the CUDA stream that this event originated from.
3437 ref_count: Reference count of this event instance.
3538 materialized_event: The actual CUDA Event name that will be used in the final PyTorch
3639 program. Only symbolic event with reference count larger than one will be materialized.
@@ -42,22 +45,30 @@ class CudaEventSym:
4245
4346 factory : CudaEventFactory
4447 idx : int
48+ originate_stream_idx : int
4549 ref_count : int = 0
4650 materialized_event : str | None = None
4751
4852 def __lt__ (self , rhs : CudaEventSym ) -> bool :
4953 """Whether the current event is generated before the rhs event."""
50- return self .idx < rhs .idx and self .factory is rhs .factory
54+ if self .factory is not rhs .factory :
55+ return NotImplemented
56+ return (self .idx , self .originate_stream_idx ) < (rhs .idx , rhs .originate_stream_idx )
5157
5258 def __eq__ (self , rhs : object ) -> bool :
5359 """Whether the current event is identical to the rhs event."""
5460 if not isinstance (rhs , CudaEventSym ):
5561 return NotImplemented
56- return self .idx == rhs .idx and self .factory is rhs .factory
62+ return (
63+ self .idx == rhs .idx
64+ and self .originate_stream_idx == rhs .originate_stream_idx
65+ and self .factory is rhs .factory
66+ )
5767
5868 def __str__ (self ) -> str :
5969 """Represent this symbolic event in string."""
6070 ret = f"{ self .__class__ .__name__ } (idx={ self .idx } "
71+ ret += f", originate_stream_idx={ self .originate_stream_idx } "
6172 if self .ref_count :
6273 ret += f", ref_count={ self .ref_count } "
6374 if self .materialized_event :
@@ -67,7 +78,7 @@ def __str__(self) -> str:
6778
6879 def __hash__ (self ) -> int :
6980 """Hash this symbolic event."""
70- return hash (f" { id (self .factory )= } , { self .idx = } " )
81+ return hash (( id (self .factory ), self .idx , self . originate_stream_idx ) )
7182
7283 def record (self , stream_idx : int ) -> _CudaEventRecordLine :
7384 """Record this event on a given stream.
@@ -103,6 +114,7 @@ def wait(self, stream_idx: int) -> _CudaEventWaitLine:
103114 the reference count of this event. If an event object has called this method, it is
104115 guaranteed to be generated in the final program.
105116 """
117+ assert stream_idx != self .originate_stream_idx
106118 self .ref_count += 1
107119 stream = get_stream_name (stream_idx )
108120 return _CudaEventWaitLine (self , stream )
@@ -113,11 +125,12 @@ class _CudaEventRecordLine(WrapperLine):
113125
114126 event : CudaEventSym
115127 stream : str
128+ _reuse_cuda_event : bool = torchsched_config .reuse_cuda_event
116129
117130 def codegen (self , code : IndentedBuffer ) -> None :
118131 assert 0 <= self .event .ref_count
119132 assert self .event .materialized_event is None
120- if self .event .ref_count :
133+ if self .event .ref_count or not self . _reuse_cuda_event :
121134 self .event .materialized_event = self .event .factory .get_materialized_event (code )
122135 code .writeline (f"{ self .event .materialized_event } .record({ self .stream } )" )
123136
@@ -131,12 +144,13 @@ class _CudaEventWaitLine(WrapperLine):
131144 def codegen (self , code : IndentedBuffer ) -> None :
132145 assert 0 < self .event .ref_count
133146 assert self .event .materialized_event is not None
134- code . writeline ( f"{ self .event .materialized_event } .wait({ self .stream } )" )
147+ code_line = f"{ self .event .materialized_event } .wait({ self .stream } )"
135148 self .event .ref_count -= 1
136149 if self .event .ref_count == 0 :
137150 self .event .factory .deposit_materialized_event (self .event .materialized_event )
138151 self .event .materialized_event = None
139- code .writeline (f"# End lifecycle of { self .event } " )
152+ code_line += f" # End lifecycle of { self .event } "
153+ code .writeline (code_line )
140154
141155
142156class CudaEventFactory :
@@ -153,23 +167,32 @@ def __init__(self) -> None:
153167 self .materialized_event_idx : itertools .count = itertools .count (start = 1 )
154168 self .available_materialized_events : set [str ] = set ()
155169 self ._entrance_event : CudaEventSym | None = None
170+ self ._reuse_cuda_event : bool = torchsched_config .reuse_cuda_event
156171
157172 def get_entrance_event (self ) -> CudaEventSym :
158173 """Return the cuda event that corresponding to compute graph entering."""
159174 if self ._entrance_event is None :
160- self ._entrance_event = CudaEventSym (factory = self , idx = 0 )
175+ self ._entrance_event = CudaEventSym (
176+ factory = self ,
177+ idx = 0 ,
178+ originate_stream_idx = DEFAULT_STREAM_IDX ,
179+ )
161180 # Code-gen for entrance event is almost hard-coded in device guard enter so the
162181 # materialization is slightly different here.
163182 self ._entrance_event .materialized_event = ENTRANCE_EVENT
164183 return self ._entrance_event
165184
166- def get_sym_event (self ) -> CudaEventSym :
185+ def get_sym_event (self , originate_stream_idx : int ) -> CudaEventSym :
167186 """Allocate a symbolic cuda event."""
168- return CudaEventSym (factory = self , idx = next (self .symbolic_event_idx ))
187+ return CudaEventSym (
188+ factory = self ,
189+ idx = next (self .symbolic_event_idx ),
190+ originate_stream_idx = originate_stream_idx ,
191+ )
169192
170193 def get_materialized_event (self , code : IndentedBuffer ) -> str :
171194 """Allocate or reuse a materialized cuda event."""
172- if self .available_materialized_events :
195+ if self ._reuse_cuda_event and self . available_materialized_events :
173196 return self .available_materialized_events .pop ()
174197 else :
175198 event = EVENT_NAME_TEMPLATE .format (event_idx = next (self .materialized_event_idx ))
0 commit comments