
Commit 910236f

cast SE weights and activations to fp32 (NVIDIA-NeMo#14743)
Signed-off-by: Elena Rastorgueva <[email protected]>
1 parent 350ec2d · commit 910236f

File tree: 1 file changed (+9, −1 lines)

  • nemo/collections/asr/parts/submodules/jasper.py


nemo/collections/asr/parts/submodules/jasper.py

Lines changed: 9 additions & 1 deletion
@@ -477,7 +477,13 @@ def forward_for_export(self, x, lengths):
         # Create sample mask - 1 represents value, 0 represents pad
         mask = self.make_pad_mask(lengths, max_audio_length=max_len, device=x.device)
         mask = ~mask  # 0 represents value, 1 represents pad
-        x = x.float()  # For stable AMP, SE must be computed at fp32.
+
+        # Ensure SE runs in FP32: cast fc weights and activations to float32
+        if self.fc[0].weight.dtype != torch.float32:
+            self.fc.float()
+        if x.dtype != torch.float32:
+            x = x.float()
+
         x = x.masked_fill(mask, 0.0)  # mask padded values explicitly to 0
         y = self._se_pool_step(x, mask)  # [B, C, 1]
         y = y.transpose(1, -1)  # [B, 1, C]
@@ -490,6 +496,8 @@ def forward_for_export(self, x, lengths):
 
         y = torch.sigmoid(y)
         y = x * y
+        # Cast back to original dtype for downstream consistency
+        y = y.to(dtype)
         return y, lengths
 
     def _se_pool_step(self, x, mask):
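
For context, the change applies a common mixed-precision pattern: remember the incoming dtype, force the numerically sensitive SE (squeeze-and-excitation) computation to float32, then cast the result back (the `dtype` used in `y.to(dtype)` is presumably captured earlier in `forward_for_export`, outside the hunks shown). Below is a minimal, self-contained sketch of that pattern; the `SqueezeExcite` class, its channel and reduction sizes, and the mean-pool squeeze are illustrative assumptions, not the actual NeMo code.

import torch
import torch.nn as nn


class SqueezeExcite(nn.Module):
    """Minimal SE block showing the FP32-casting pattern from the commit.

    Illustrative sketch only: the class name, channel/reduction sizes,
    and the mean-pool squeeze are assumptions, not the NeMo implementation.
    """

    def __init__(self, channels: int, reduction: int = 8):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B, C, T]
        dtype = x.dtype  # remember the incoming dtype (may be fp16/bf16)

        # Ensure SE runs in FP32: cast fc weights and activations to
        # float32, as in the diff above. This guards against a model
        # that was converted to half precision for export.
        if self.fc[0].weight.dtype != torch.float32:
            self.fc.float()
        if x.dtype != torch.float32:
            x = x.float()

        y = x.mean(dim=-1)             # squeeze over time: [B, C]
        y = torch.sigmoid(self.fc(y))  # excitation gates in (0, 1): [B, C]
        y = x * y.unsqueeze(-1)        # rescale channels: [B, C, T]

        # Cast back to the original dtype for downstream consistency.
        return y.to(dtype)


if __name__ == "__main__":
    # A half-precision module still computes the SE internals in float32.
    se = SqueezeExcite(channels=256).half()
    x = torch.randn(4, 256, 100, dtype=torch.float16)
    out = se(x)
    assert out.dtype == torch.float16  # output matches the incoming dtype

Casting the `fc` weights in place, rather than only the activations, matters when the whole model has been converted to half precision (e.g. for export): a float32 activation hitting fp16 weights would otherwise raise a dtype-mismatch error.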
