Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 140282d

Browse files
committed
Bug fixes
1 parent bec558b commit 140282d

2 files changed

Lines changed: 30 additions & 24 deletions

File tree

apex/contrib/bottleneck/bottleneck.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -289,11 +289,6 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
289289
out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
290290
stream1.wait_stream(torch.cuda.current_stream())
291291
if spatial_method != 2: stream3.wait_stream(torch.cuda.current_stream())
292-
with torch.cuda.stream(stream3):
293-
if explicit_nhwc:
294-
out1_pad[:,1:Hs+1,:,:].copy_(out1)
295-
else:
296-
out1_pad[:,:,1:Hs+1,:].copy_(out1)
297292
with torch.cuda.stream(stream1):
298293
if explicit_nhwc:
299294
top_out1_halo = out1_pad[:,:1,:,:]
@@ -343,11 +338,11 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
343338
out1_pad[:,:,1:Hs+1,:].copy_(out1)
344339
elif spatial_method == 2:
345340
# wait for halo transfer to finish before doing a full convolution of padded x
346-
torch.cuda.current_stream().wait_stream(stream1)
347341
if explicit_nhwc:
348342
out1_pad[:,1:Hs+1,:,:].copy_(out1)
349343
else:
350344
out1_pad[:,:,1:Hs+1,:].copy_(out1)
345+
torch.cuda.current_stream().wait_stream(stream1)
351346
fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
352347
elif spatial_method == 3:
353348
fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom)
@@ -705,8 +700,6 @@ def forward(self, x):
705700
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
706701
w_scale.append(s4)
707702
w_bias.append(b4)
708-
self.w_scale = w_scale
709-
self.w_bias = w_bias
710703
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
711704
else:
712705
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)

apex/contrib/csrc/peer_memory/peer_memory_cuda.cu

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -153,23 +153,36 @@ __device__ void checked_signal(
153153
const int v1, const int v2, const int v3, const int v4
154154
)
155155
{
156-
if (blockIdx.x == 0) {
157-
register int r1, r2, r3, r4;
158-
if (threadIdx.x == 0) {
159-
// wait for top neighbor to clear bottom signal (indicating ready for new input)
160-
do {
161-
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
162-
} while (r1 == v1 && r2 == v2 && r3 == v3 && r4 == v4);
163-
// signal to top neighbor my output is ready
164-
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
165-
} else if (threadIdx.x == 1) {
166-
// wait for bottom neighbor to clear top signal (indicating ready for new input)
156+
cg::this_grid().sync();
157+
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
158+
if (is_main_thread) {
159+
// flush all writes to global memory
160+
__threadfence_system();
161+
// wait for top or bottom neighbor to clear signal
162+
register int r1, r2, r3, r4;
163+
bool top_zeroed=false, btm_zeroed=false, top_done=false, btm_done=false;
164+
do {
167165
do {
168-
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
169-
} while (r1 == v1 && r2 == v2 && r3 == v3 && r4 == v4);
170-
// signal to bottom neighbor my output is ready
171-
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
172-
}
166+
if (!top_zeroed) {
167+
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
168+
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
169+
}
170+
if (!btm_zeroed) {
171+
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
172+
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
173+
}
174+
} while((top_zeroed == top_done) && (btm_zeroed == btm_done));
175+
if (!top_done && top_zeroed) {
176+
// signal to top neighbor my output is ready
177+
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
178+
top_done = true;
179+
}
180+
if (!btm_done && btm_zeroed) {
181+
// signal to bottom neighbor my output is ready
182+
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
183+
btm_done = true;
184+
}
185+
} while (!top_done || !btm_done);
173186
}
174187
}
175188

0 commit comments

Comments
 (0)