|
99 | 99 | # cu_seq_k: torch.Tensor, |
100 | 100 | # max_q: int, |
101 | 101 | # max_k: int, |
102 | | -# is_causal: bool = False, |
| 102 | +# *, |
103 | 103 | # return_aux: AuxRequest | None = None, |
| 104 | +# scale: float | None = None, |
| 105 | +# window_size: tuple[int, int] = (-1, -1), |
104 | 106 | # ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: |
105 | 107 | # |
106 | 108 | # ``query``, ``key``, and ``value`` correspond to the ``q``, ``k``, and |
107 | 109 | # ``v`` of the packed input. ``cu_seq_q`` and ``cu_seq_k`` are the |
108 | 110 | # cumulative indices for query and key/value, respectively. These mark the |
109 | 111 | # logical boundaries that separate the documents in our input. ``max_q`` |
110 | 112 | # and ``max_k`` are the maximum sequence lengths of query and key, |
111 | | -# respectively. ``is_causal`` applies causal masking if set to True and |
112 | | -# ``return_aux`` specifies which auxiliary outputs to return (ie ``lse``). |
| 113 | +# respectively. ``return_aux`` specifies which auxiliary outputs to return |
| 114 | +# (i.e. ``lse``). ``scale`` is an optional scaling factor applied to the |
| 115 | +# attention scores before softmax. ``window_size`` is a ``(left, right)`` |
| 116 | +# tuple that controls sliding window attention: use ``(-1, -1)`` for full |
| 117 | +# attention (default), ``(-1, 0)`` for causal attention, or ``(W, 0)`` |
| 118 | +# for causal attention with a sliding window of size ``W``. |
113 | 119 |
|
114 | 120 | ###################################################################### |
115 | 121 | # **Note on causal masking** |
116 | | -# When ``is_causal`` is set to True, causal masking is applied which means |
117 | | -# that tokens can only attend to previous tokens. For bidirectional |
118 | | -# attention, set this flag to False. |
| 122 | +# When ``window_size`` is set to ``(-1, 0)``, causal masking is applied |
| 123 | +# which means that tokens can only attend to previous tokens. For |
| 124 | +# bidirectional (full) attention, use the default ``(-1, -1)``. |
119 | 125 | # |
120 | 126 | # In torchtitan (PyTorch's pretraining framework), we set |
121 | | -# ``is_causal = True`` uniformly to prevent the model from cheating and |
122 | | -# artificially driving the loss down too quickly. |
| 127 | +# ``window_size = (-1, 0)`` uniformly to prevent the model from cheating |
| 128 | +# and artificially driving the loss down too quickly. |
123 | 129 |
|
124 | 130 |
|
125 | 131 | ###################################################################### |
@@ -241,7 +247,7 @@ def forward( |
241 | 247 | cu_seq_k=cu_seq, |
242 | 248 | max_q=max_len, |
243 | 249 | max_k=max_len, |
244 | | - is_causal=True, |
| 250 | + window_size=(-1, 0), |
245 | 251 | ) |
246 | 252 | attn_out = attn_out.view(-1, self.embed_dim) |
247 | 253 | attn_out = self.out_proj(attn_out) |
|
0 commit comments