

Commit e1fcd75

liangel-02, svekars, and sekyondaMeta authored

update varlen tutorial to remove deprecated args (#3778)

Fixes #3775

Co-authored-by: Svetlana Karslioglu <[email protected]>
Co-authored-by: sekyondaMeta <[email protected]>

1 parent 4ce8511

2 files changed: 15 additions & 10 deletions

File tree:
.jenkins/validate_tutorials_built.py
intermediate_source/variable_length_attention_tutorial.py

.jenkins/validate_tutorials_built.py (0 additions & 1 deletion)

```diff
@@ -42,7 +42,6 @@
     "intermediate_source/torchrec_intro_tutorial.py", #failing with 2.8 reenable after 3498
     "intermediate_source/torch_export_tutorial.py", # failing with 2.11 issue #3773
     "beginner_source/mosaic_memory_profiling_tutorial.py", # failing with 2.11 issue #3774
-    "intermediate_source/variable_length_attention_tutorial.py", # failing with 2.11 issue #3775
 ]


 def tutorial_source_dirs() -> List[Path]:
```

intermediate_source/variable_length_attention_tutorial.py (15 additions & 9 deletions)
```diff
@@ -99,27 +99,33 @@
 # cu_seq_k: torch.Tensor,
 # max_q: int,
 # max_k: int,
-# is_causal: bool = False,
+# *,
 # return_aux: AuxRequest | None = None,
+# scale: float | None = None,
+# window_size: tuple[int, int] = (-1, -1),
 # ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
 #
 # ``query``, ``key``, and ``value`` correspond to the ``q``, ``k``, and
 # ``v`` of the packed input. ``cu_seq_q`` and ``cu_seq_k`` are the
 # cumulative indices for query and key/value, respectively. These mark the
 # logical boundaries that separate the documents in our input. ``max_q``
 # and ``max_k`` are the maximum sequence lengths of query and key,
-# respectively. ``is_causal`` applies causal masking if set to True and
-# ``return_aux`` specifies which auxiliary outputs to return (ie ``lse``).
+# respectively. ``return_aux`` specifies which auxiliary outputs to return
+# (ie ``lse``). ``scale`` is an optional scaling factor applied to the
+# attention scores before softmax. ``window_size`` is a ``(left, right)``
+# tuple that controls sliding window attention: use ``(-1, -1)`` for full
+# attention (default), ``(-1, 0)`` for causal attention, or ``(W, 0)``
+# for causal attention with a sliding window of size ``W``.

 ######################################################################
 # **Note on causal masking**
-# When ``is_causal`` is set to True, causal masking is applied which means
-# that tokens can only attend to previous tokens. For bidirectional
-# attention, set this flag to False.
+# When ``window_size`` is set to ``(-1, 0)``, causal masking is applied
+# which means that tokens can only attend to previous tokens. For
+# bidirectional (full) attention, use the default ``(-1, -1)``.
 #
 # In torchtitan (PyTorch's pretraining framework), we set
-# ``is_causal = True`` uniformly to prevent the model from cheating and
-# artificially driving the loss down too quickly.
+# ``window_size = (-1, 0)`` uniformly to prevent the model from cheating
+# and artificially driving the loss down too quickly.


 ######################################################################
```
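The hunk above replaces the boolean ``is_causal`` flag with a ``(left, right)`` ``window_size`` tuple. As a rough sketch of the documented masking semantics (this illustrates the mask shape only, not the actual varlen kernel, which applies it per packed document), a window of ``(-1, 0)`` reproduces the causal mask that ``is_causal=True`` used to apply:

```python
import torch


def window_mask(seq_len: int, window_size: tuple[int, int]) -> torch.Tensor:
    """Boolean mask where True means query position i may attend to key j.

    window_size = (left, right); -1 means unbounded on that side.
    Illustrative sketch of the ``window_size`` semantics described above.
    """
    left, right = window_size
    i = torch.arange(seq_len).unsqueeze(1)  # query positions, column vector
    j = torch.arange(seq_len).unsqueeze(0)  # key positions, row vector
    full = torch.ones(seq_len, seq_len, dtype=torch.bool)
    ok_left = (j >= i - left) if left >= 0 else full   # how far back j may be
    ok_right = (j <= i + right) if right >= 0 else full  # how far ahead j may be
    return ok_left & ok_right


# (-1, 0): each token sees only itself and earlier tokens (causal).
causal = window_mask(4, (-1, 0))
assert torch.equal(causal, torch.tril(torch.ones(4, 4, dtype=torch.bool)))

# (-1, -1): full bidirectional attention, nothing masked.
assert window_mask(4, (-1, -1)).all()
```

With ``(W, 0)`` the lower-triangular region is additionally clipped to the last ``W`` positions, giving causal sliding-window attention.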
```diff
@@ -241,7 +247,7 @@ def forward(
             cu_seq_k=cu_seq,
             max_q=max_len,
             max_k=max_len,
-            is_causal=True,
+            window_size=(-1, 0),
         )
         attn_out = attn_out.view(-1, self.embed_dim)
         attn_out = self.out_proj(attn_out)
```
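The ``forward`` call above passes ``cu_seq`` and ``max_len`` for the packed input. A minimal sketch of how such cumulative boundaries are typically derived from per-document lengths (the ``doc_lens`` values here are made up for illustration and are not from the tutorial):

```python
import torch

# Hypothetical per-document token counts in one packed batch.
doc_lens = torch.tensor([3, 5, 2])

# Cumulative boundaries: tokens cu_seq[i] .. cu_seq[i+1] belong to document i.
cu_seq = torch.zeros(len(doc_lens) + 1, dtype=torch.int32)
cu_seq[1:] = torch.cumsum(doc_lens, dim=0)

# Longest document, used for max_q / max_k.
max_len = int(doc_lens.max())

print(cu_seq.tolist())  # [0, 3, 8, 10]
print(max_len)          # 5
```

These are the "logical boundaries that separate the documents" referred to in the parameter description above.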
