chore: remove transformers cap and raise training floor #3264
The CUDA extras requirements raise the training floor:

```diff
@@ -1,2 +1,2 @@
 # Extra dependencies for NVIDIA CUDA
-instructlab-training[cuda]>=0.8.0
+instructlab-training[cuda]>=0.9.0
```
And likewise for ROCm:

```diff
@@ -1,2 +1,2 @@
 # Extra dependencies for AMD ROCm
-instructlab-training[rocm]>=0.8.0
+instructlab-training[rocm]>=0.9.0
```
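A quick way to verify the new floor in an installed environment (a hedged sketch, not part of the diff; it assumes the distribution name matches the requirement spec above):

```python
# Check that the installed instructlab-training satisfies the raised >=0.9.0 floor.
from importlib.metadata import PackageNotFoundError, version

try:
    v = version("instructlab-training")
except PackageNotFoundError:
    raise SystemExit("instructlab-training is not installed")

major, minor = (int(part) for part in v.split(".")[:2])
if (major, minor) < (0, 9):
    raise SystemExit(f"instructlab-training {v} is below the 0.9.0 floor")
print(f"instructlab-training {v} satisfies the floor")
```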
In the CLI training module, `model_path` is now propagated onto `train_args`, and the legacy pretraining-format auto-detection (along with the `pathlib` and `instructlab.utils` imports it needed) is removed:

```diff
@@ -7,7 +7,6 @@
 import enum
 import logging
 import os
-import pathlib
 import sys
 import textwrap
 import typing
@@ -38,9 +37,6 @@
 from typing_extensions import deprecated as Deprecated
 import click

-# First Party
-from instructlab.utils import get_model_arch, use_legacy_pretraining_format
-
 # Local
 from . import log
 from .defaults import (
@@ -1572,6 +1568,7 @@ def map_train_to_library(ctx, params):
         click.secho(f"failed to get model with `--model-id`: {ve}", fg="red")
         raise click.exceptions.Exit(1)
     params["model_path"] = model_cfg.path
+    train_args.model_path = model_cfg.path

     ds_args = DeepSpeedOptions(
         cpu_offload_optimizer=params["deepspeed_cpu_offload_optimizer"],
@@ -1617,14 +1614,8 @@ def map_train_to_library(ctx, params):
     if params["pipeline"] == "full":
         train_args.disable_flash_attn = True

-    student_model_arch = get_model_arch(pathlib.Path(params["model_path"]))
     if ctx.obj.config.general.use_legacy_tmpl:
         train_args.use_legacy_tmpl = True
-    else:
-        train_args.use_legacy_tmpl = use_legacy_pretraining_format(
-            params["model_path"],
-            student_model_arch,
-        )
     return train_args, torch_args
```

Review thread on the removed auto-detection:

Contributor: Does this change mean that legacy tokenizer config won't be detected anymore?

Member: @booxter We don't need it anymore.

Contributor: We don't need the change, or we don't need support for legacy auto-detection? And is this change related to the transformers cap removal / training dependency update?
The training backend now carries its own copy of `convert_loss_to_reduce_sum` rather than importing it from `instructlab.training.utils`:

```diff
@@ -20,6 +20,87 @@
 logger = logging.getLogger(__name__)


+def convert_loss_to_reduce_sum(model):
+    """
+    this is necessary because multipack changes the samples per gpu, which biases the gradients to be larger for batches with less samples but longer lengths.
+    """
+    # Standard
+    from typing import List, Optional
+
+    # Third Party
+    import torch
+
+    def reduce_sum_forward(
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict_in_generate: Optional[bool] = None,
+        # pylint: disable=unused-argument
+        **deprecated_arguments,
+    ):
+        output = model.__original_forward__(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict_in_generate=return_dict_in_generate,
+        )
+
+        logits = None
+        loss = None
+        return_dict = isinstance(output, dict)
+        if return_dict:
+            logits = output.logits
+        else:
+            # just checks that the output from the model is in the shape we expect,
+            # and that one of the tuple elements is the loss and one is the logits
+            if not (
+                len(output) == 2
+                and (
+                    (len(output[0].shape) == 3 and len(output[1].shape) == 0)
+                    or (len(output[1].shape) == 3 and len(output[0].shape) == 0)
+                )
+            ):
+                raise ValueError(
+                    "Output does not match the expected structure. "
+                    "Expected a tuple of length 2 with one element having shape of rank 3 and the other of rank 0."
+                )
+            logits = output[0] if len(output[0].shape) == 3 else output[1]
+
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            shift_logits = shift_logits.view(-1, model.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Ensure tensors are on the same device
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            return ((loss,) + output) if loss is not None else output
+
+        output.loss = loss
+        return output
+
+    model.__original_forward__ = model.forward
+    model.forward = reduce_sum_forward
+    return model
+
+
 def setup_model(
     train_args,
     tokenizer,
```
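For intuition about the `reduction="sum"` choice, here is a minimal numerical sketch (made-up values, not code from this PR): with multipack, microbatches carry different numbers of loss-counted tokens, so a per-microbatch mean over-weights tokens in short microbatches, while summing and dividing once by the global token count weights every token equally.

```python
import torch

short_mb = torch.tensor([2.0, 2.0])           # microbatch with 2 loss-counted tokens
long_mb = torch.tensor([1.0, 1.0, 1.0, 1.0])  # microbatch with 4 loss-counted tokens

# Per-microbatch mean (reduction="mean") weights both microbatches equally:
mean_of_means = (short_mb.mean() + long_mb.mean()) / 2   # (2.0 + 1.0) / 2 = 1.50

# Summed loss divided by the total token count weights every token equally:
per_token_mean = (short_mb.sum() + long_mb.sum()) / 6    # 8.0 / 6 ≈ 1.33

print(mean_of_means.item(), per_token_mean.item())
```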
```diff
@@ -30,7 +111,7 @@ def setup_model(
 ):
     # pylint: disable=no-name-in-module
     # Third Party
-    from instructlab.training import multipack_sampler, utils
+    from instructlab.training import multipack_sampler
     from torch.utils.data import DataLoader

     collate_fn = partial(pad_collate_fn, pad_token_id=tokenizer.pad_token_id)
@@ -121,8 +202,7 @@ def setup_model(
     if tokenizer.eos_token_id is not None and model.config.eos_token_id is None:
        model.config.eos_token_id = tokenizer.eos_token_id

-    model = utils.convert_loss_to_reduce_sum(model)
-    model = utils.add_noisy_embeddings(model, noise_alpha=None)
+    model = convert_loss_to_reduce_sum(model)

     return model, dataloader
@@ -212,6 +292,12 @@ def train(train_args, device, optimize_memory):
     inner_pb = tqdm(range(len(dataloader)), desc=f"Epoch {epoch}")
     aggregated_values = torch.zeros(3, dtype=torch.float32).to(dev)

+    # in order to correctly calculate the loss, we need to divide each microbatch by
+    # a constant factor, so that we can later correct it by the actual `total_minibatch_tokens` amount
+    total_minibatch_tokens = 0
+    interim_batch_denominator = packing_max_batch_len * accum
+    loss_accum = 0.0  # track this for logging purposes
+
     for step, batch in enumerate(dataloader):
         aggregated_values[0] = batch.pop("num_loss_counted_tokens")
         aggregated_values[1] = len(batch["input_ids"])
```
```diff
@@ -220,16 +306,38 @@ def train(train_args, device, optimize_memory):
     for k in batch:
         batch[k] = batch[k].to(device=dev)

-    output = model(**batch, use_cache=False, return_dict=True)
-    loss = output.loss
+    output = model(**batch, use_cache=False, return_dict=False)
+    loss = None
+    if isinstance(output, tuple):
+        loss = output[0]
+        if len(output[0].shape) != 0:
+            raise ValueError(
+                "When output is a tuple, the loss should be the first element"
+            )
+    else:
+        loss = output.loss
+    if loss is None:
+        raise ValueError(
+            "Loss is None. Ensure the model's output contains a valid loss."
+        )
```

Review thread on these error messages:

Contributor: These errors will bubble up to users through a CLI error. I am not sure they have the context to interpret the errors meaningfully. As a user, how do I "ensure the model's output contains a valid loss"? Are you asking the user to check their model inputs or config, perhaps? If so, that is what should be communicated.

Contributor: This applies to the other errors here, I think?

Member: @booxter This shouldn't be happening, so if it does, we want users to report it to us. They are also welcome to look at the internals of the CLI and potentially contribute back. We shouldn't assume that our users are incapable of being technical.

Contributor: I see. So it's always a bug in the code and not in the user's inputs? If so, the suggestion to "ensure the model's output" seems misplaced, and we may instead want to direct the user to report the issue. (And maybe dump some more info to include with the report?)

Member: It could be a few things: the model implementation could have changed, or the transformers API itself could have changed (as in the case here). So when that happens, we just have this as a safeguard.
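A self-contained sketch of the defensive unpacking above (the helper name and stand-in tensors are illustrative, not from the PR):

```python
import torch

def unpack_loss(output):
    # Tuple outputs (return_dict=False) put the scalar loss first when labels were supplied.
    if isinstance(output, tuple):
        loss = output[0]
        if len(loss.shape) != 0:
            raise ValueError("When output is a tuple, the loss should be the first element")
    else:
        # ModelOutput-style results expose the loss as an attribute instead.
        loss = output.loss
    if loss is None:
        raise ValueError("Loss is None. Ensure the model's output contains a valid loss.")
    return loss

# Tuple case mimicking the patched forward: (rank-0 loss, rank-3 logits).
print(unpack_loss((torch.tensor(1.23), torch.zeros(1, 4, 10))))
```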
The hunk continues after the thread:

```diff
     aggregated_values[2] = loss.item()

     num_loss_counted_tokens = aggregated_values[0]
-    loss = loss / num_loss_counted_tokens
+    total_minibatch_tokens += num_loss_counted_tokens

-    loss = loss / accum  # Scale the loss for accumulation steps
+    # here we need to correctly rescale the loss, so we divide by the packing_max_batch_len
+    # in order to overshoot the average, and then we will later multiply each gradient
+    # by a correction term
+    loss_orig = loss.detach().cpu().item()
+    loss_accum += loss
+    loss = loss / interim_batch_denominator

-    logger.info(f"\nEpoch: {epoch}, Step: {step + 1}, Rank: 0, loss = {loss}")
+    per_batch_loss = loss_orig / num_loss_counted_tokens
+
+    logger.info(
+        f"\nEpoch: {epoch}, Step: {step + 1}, Loss per batch: {loss.detach().item()}, Actual Loss Per Batch = {per_batch_loss}, accumulated loss: {loss_accum.item()}"
+    )

     # Gradient accumulation
     loss.backward()  # Backward pass
@@ -243,9 +351,20 @@ def train(train_args, device, optimize_memory):

     # if we are on a step which is divisible by 4, step and zero gradients
     if (step + 1) % accum == 0:
+        # let's correct all of the gradients
+        for param in model.parameters():
+            grad = param.grad
+            assert grad is not None
+            correction_term = interim_batch_denominator / total_minibatch_tokens
+            param.grad *= correction_term
+
         optimizer.step()  # Optimizer step
         optimizer.zero_grad()  # Zero gradients

+        # reset all of the accumulated data
+        total_minibatch_tokens = 0.0
+        loss_accum = 0.0
+
         # Clear cache after optimizer step
         if dev.type == "mps":
             torch.mps.empty_cache()
```
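To see why the constant "overshoot" divisor plus the gradient correction recovers the true per-token average, here is a worked example with made-up numbers (gradients scale linearly with the loss, so the arithmetic can be checked on loss values directly):

```python
packing_max_batch_len, accum = 100, 4
interim_batch_denominator = packing_max_batch_len * accum  # constant: 400

mb_sum_losses = [30.0, 50.0, 20.0, 40.0]  # reduce-sum loss of each microbatch
mb_tokens = [60, 90, 40, 80]              # loss-counted tokens per microbatch
total_minibatch_tokens = sum(mb_tokens)   # 270, only known after the last microbatch

# Each backward pass contributes loss / interim_batch_denominator ...
accumulated = sum(loss / interim_batch_denominator for loss in mb_sum_losses)
# ... and the optimizer-step correction multiplies every gradient by this term:
corrected = accumulated * (interim_batch_denominator / total_minibatch_tokens)

# The result equals the total summed loss divided by the actual token count.
assert abs(corrected - sum(mb_sum_losses) / total_minibatch_tokens) < 1e-12
print(corrected)  # ≈ 0.5185
```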
```diff
@@ -257,7 +376,7 @@ def train(train_args, device, optimize_memory):
         torch.mps.empty_cache()

     output_dir = (
-        Path(train_args.ckpt_output_dir) / "hf_format" / f"samples_{(epoch*8)}"
+        Path(train_args.ckpt_output_dir) / "hf_format" / f"samples_{(epoch * 8)}"
     )

     logger.info(f"Saving Model to: {output_dir}")
```
```diff
@@ -329,7 +448,7 @@ def pad_collate_fn(batch, pad_token_id):
         ]
     )
     logger.info(
-        f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()} - rank: {0} "
+        f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()}"
         f"max len: 27,720 min len: {min(lens)} avg len: {lens.mean()} "
         f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m"
     )
```