diff --git a/requirements.txt b/requirements.txt
index 0aeb60d988..bd2f3665b8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,7 +14,7 @@ instructlab-schema>=0.4.2
 instructlab-sdg>=0.7.3
 # XXX(osilkin): we need to pin this version for now while we improve tokenizer logic
 # see: https://github.com/instructlab/training/pull/428
-instructlab-training>=0.8.0
+instructlab-training>=0.9.0
 llama_cpp_python[server]==0.3.6
 mlx>=0.5.1,<0.6.0; sys_platform == 'darwin' and platform_machine == 'arm64'
 numpy>=1.26.4,<2.0.0
@@ -34,9 +34,7 @@ toml>=0.10.2
 # Default version. Can be overridden in extra requirements
 torch>=2.3.0,<2.6.0
 tqdm>=4.66.2
-# temporary cap until https://github.com/instructlab/training/pull/443 is merged and consumed within instructlab
-# above PR fixes interactions with newer versions of transformers through the training library
-transformers>=4.41.2,<4.51.0
+transformers>=4.41.2
 trl>=0.12.2,<0.15.0
 wandb>=0.16.4
 xdg-base-dirs>=6.0.1
diff --git a/requirements/cuda.txt b/requirements/cuda.txt
index 475645a44a..5b44608292 100644
--- a/requirements/cuda.txt
+++ b/requirements/cuda.txt
@@ -1,2 +1,2 @@
 # Extra dependencies for NVIDIA CUDA
-instructlab-training[cuda]>=0.8.0
+instructlab-training[cuda]>=0.9.0
diff --git a/requirements/hpu.txt b/requirements/hpu.txt
index 697bfeccd6..a06a0d138d 100644
--- a/requirements/hpu.txt
+++ b/requirements/hpu.txt
@@ -10,4 +10,4 @@ habana_gpu_migration>=1.18.0
 #habana-torch-dataloader

 # Extra dependencies for Intel Gaudi cards
-instructlab-training[hpu]>=0.8.0
+instructlab-training[hpu]>=0.9.0
diff --git a/requirements/rocm.txt b/requirements/rocm.txt
index 041d4e2a97..2de71b182b 100644
--- a/requirements/rocm.txt
+++ b/requirements/rocm.txt
@@ -1,2 +1,2 @@
 # Extra dependencies for AMD ROCm
-instructlab-training[rocm]>=0.8.0
+instructlab-training[rocm]>=0.9.0
diff --git a/src/instructlab/configuration.py b/src/instructlab/configuration.py
index 6c263fffb4..ef72e6ec4f 100644
--- a/src/instructlab/configuration.py
+++ b/src/instructlab/configuration.py
@@ -7,7 +7,6 @@
 import enum
 import logging
 import os
-import pathlib
 import sys
 import textwrap
 import typing
@@ -38,9 +37,6 @@
 from typing_extensions import deprecated as Deprecated
 import click

-# First Party
-from instructlab.utils import get_model_arch, use_legacy_pretraining_format
-
 # Local
 from . import log
 from .defaults import (
@@ -1572,6 +1568,7 @@ def map_train_to_library(ctx, params):
             click.secho(f"failed to get model with `--model-id`: {ve}", fg="red")
             raise click.exceptions.Exit(1)
         params["model_path"] = model_cfg.path
+        train_args.model_path = model_cfg.path

     ds_args = DeepSpeedOptions(
         cpu_offload_optimizer=params["deepspeed_cpu_offload_optimizer"],
@@ -1617,14 +1614,8 @@ def map_train_to_library(ctx, params):
     if params["pipeline"] == "full":
         train_args.disable_flash_attn = True

-    student_model_arch = get_model_arch(pathlib.Path(params["model_path"]))
     if ctx.obj.config.general.use_legacy_tmpl:
         train_args.use_legacy_tmpl = True
-    else:
-        train_args.use_legacy_tmpl = use_legacy_pretraining_format(
-            params["model_path"],
-            student_model_arch,
-        )

     return train_args, torch_args

diff --git a/src/instructlab/model/full_train.py b/src/instructlab/model/full_train.py
index baf29e0860..735abfb76b 100644
--- a/src/instructlab/model/full_train.py
+++ b/src/instructlab/model/full_train.py
@@ -20,6 +20,87 @@
 logger = logging.getLogger(__name__)


+def convert_loss_to_reduce_sum(model):
+    """
+    This is necessary because multipack changes the number of samples per GPU, which biases the gradients to be larger for batches with fewer samples but longer sequences.
+    """
+    # Standard
+    from typing import List, Optional
+
+    # Third Party
+    import torch
+
+    def reduce_sum_forward(
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict_in_generate: Optional[bool] = None,
+        # pylint: disable=unused-argument
+        **deprecated_arguments,
+    ):
+        output = model.__original_forward__(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict_in_generate=return_dict_in_generate,
+        )
+
+        logits = None
+        loss = None
+        return_dict = isinstance(output, dict)
+        if return_dict:
+            logits = output.logits
+        else:
+            # just checks that the output from the model is in the shape we expect,
+            # and that one of the tuple elements is the loss and one is the logits
+            if not (
+                len(output) == 2
+                and (
+                    (len(output[0].shape) == 3 and len(output[1].shape) == 0)
+                    or (len(output[1].shape) == 3 and len(output[0].shape) == 0)
+                )
+            ):
+                raise ValueError(
+                    "Output does not match the expected structure. "
+                    "Expected a tuple of length 2 with one element having shape of rank 3 and the other of rank 0."
+                )
+            logits = output[0] if len(output[0].shape) == 3 else output[1]
+
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            shift_logits = shift_logits.view(-1, model.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Ensure tensors are on the same device
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            return ((loss,) + output) if loss is not None else output
+
+        output.loss = loss
+        return output
+
+    model.__original_forward__ = model.forward
+    model.forward = reduce_sum_forward
+    return model
+
+
 def setup_model(
     train_args,
     tokenizer,
@@ -30,7 +111,7 @@ def setup_model(
 ):
     # pylint: disable=no-name-in-module
     # Third Party
-    from instructlab.training import multipack_sampler, utils
+    from instructlab.training import multipack_sampler
     from torch.utils.data import DataLoader

     collate_fn = partial(pad_collate_fn, pad_token_id=tokenizer.pad_token_id)
@@ -121,8 +202,7 @@ def setup_model(
     if tokenizer.eos_token_id is not None and model.config.eos_token_id is None:
         model.config.eos_token_id = tokenizer.eos_token_id

-    model = utils.convert_loss_to_reduce_sum(model)
-    model = utils.add_noisy_embeddings(model, noise_alpha=None)
+    model = convert_loss_to_reduce_sum(model)

     return model, dataloader

@@ -212,6 +292,12 @@ def train(train_args, device, optimize_memory):
         inner_pb = tqdm(range(len(dataloader)), desc=f"Epoch {epoch}")
         aggregated_values = torch.zeros(3, dtype=torch.float32).to(dev)

+        # in order to correctly calculate the loss, we need to divide each microbatch by
+        # a constant factor, so that we can later correct it by the actual `total_minibatch_tokens` amount
+        total_minibatch_tokens = 0
+        interim_batch_denominator = packing_max_batch_len * accum
+        loss_accum = 0.0  # track this for logging purposes
+
         for step, batch in enumerate(dataloader):
             aggregated_values[0] = batch.pop("num_loss_counted_tokens")
             aggregated_values[1] = len(batch["input_ids"])
@@ -220,16 +306,38 @@ def train(train_args, device, optimize_memory):
             for k in batch:
                 batch[k] = batch[k].to(device=dev)

-            output = model(**batch, use_cache=False, return_dict=True)
-            loss = output.loss
+            output = model(**batch, use_cache=False, return_dict=False)
+            loss = None
+            if isinstance(output, tuple):
+                loss = output[0]
+                if len(output[0].shape) != 0:
+                    raise ValueError(
+                        "When output is a tuple, the loss should be the first element"
+                    )
+            else:
+                loss = output.loss
+            if loss is None:
+                raise ValueError(
+                    "Loss is None. Ensure the model's output contains a valid loss."
+                )
+
             aggregated_values[2] = loss.item()

             num_loss_counted_tokens = aggregated_values[0]
-            loss = loss / num_loss_counted_tokens
+            total_minibatch_tokens += num_loss_counted_tokens

-            loss = loss / accum  # Scale the loss for accumulation steps
+            # here we need to correctly rescale the loss: we divide by a constant
+            # denominator (packing_max_batch_len * accum) that over-estimates the token
+            # count, and later multiply each accumulated gradient by a correction term
+            loss_orig = loss.detach().cpu().item()
+            loss_accum += loss
+            loss = loss / interim_batch_denominator

-            logger.info(f"\nEpoch: {epoch}, Step: {step + 1}, Rank: 0, loss = {loss}")
+            per_batch_loss = loss_orig / num_loss_counted_tokens
+
+            logger.info(
+                f"\nEpoch: {epoch}, Step: {step + 1}, Loss per batch: {loss.detach().item()}, Actual Loss Per Batch = {per_batch_loss}, accumulated loss: {loss_accum.item()}"
+            )

             # Gradient accumulation
             loss.backward()  # Backward pass
@@ -243,9 +351,20 @@ def train(train_args, device, optimize_memory):

             # if we are on a step which is divisible by 4, step and zero gradients
             if (step + 1) % accum == 0:
+                # let's correct all of the gradients
+                for param in model.parameters():
+                    grad = param.grad
+                    assert grad is not None
+                    correction_term = interim_batch_denominator / total_minibatch_tokens
+                    param.grad *= correction_term
+
                 optimizer.step()  # Optimizer step
                 optimizer.zero_grad()  # Zero gradients

+                # reset all of the accumulated data
+                total_minibatch_tokens = 0.0
+                loss_accum = 0.0
+
                 # Clear cache after optimizer step
                 if dev.type == "mps":
                     torch.mps.empty_cache()
@@ -257,7 +376,7 @@ def train(train_args, device, optimize_memory):
             torch.mps.empty_cache()

         output_dir = (
-            Path(train_args.ckpt_output_dir) / "hf_format" / f"samples_{(epoch*8)}"
+            Path(train_args.ckpt_output_dir) / "hf_format" / f"samples_{(epoch * 8)}"
         )

         logger.info(f"Saving Model to: {output_dir}")
@@ -329,7 +448,7 @@ def pad_collate_fn(batch, pad_token_id):
         ]
     )
     logger.info(
-        f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()} - rank: {0} "
+        f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()} "
        f"max len: 20,652 min len: {min(lens)} avg len: {lens.mean()} "
         f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m"
    )
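
Note (reviewer aside, not part of the patch): a minimal sanity check of the gradient-accumulation rescaling used in `train()` above. Each microbatch's summed loss is divided by the constant `packing_max_batch_len * accum`, and the accumulated gradient is then multiplied by `(packing_max_batch_len * accum) / total_minibatch_tokens`; for a `reduction="sum"` loss this should be equivalent to dividing the combined loss by the actual number of loss-counted tokens. The toy quadratic loss, tensor sizes, and constants below are illustrative assumptions, not taken from the patch.

```python
# Sketch: verify that "divide by a constant, correct the gradient later" matches
# "divide the combined summed loss by the true token count" for one accumulation window.
import torch

torch.manual_seed(0)
packing_max_batch_len, accum = 8, 2           # illustrative values only
w = torch.randn(4, requires_grad=True)        # stand-in for a model parameter

# two "microbatches" with different numbers of loss-counted tokens
batches = [(torch.randn(3, 4), 3), (torch.randn(5, 4), 5)]
total_minibatch_tokens = sum(n for _, n in batches)
interim_batch_denominator = packing_max_batch_len * accum

# patch-style accumulation: scale by the constant denominator, correct at step time
for x, _ in batches:
    summed_loss = (x @ w).pow(2).sum()        # stand-in for a reduction="sum" loss
    (summed_loss / interim_batch_denominator).backward()
w.grad *= interim_batch_denominator / total_minibatch_tokens
grad_patch = w.grad.clone()
w.grad = None

# reference: divide the combined summed loss by the true token count directly
reference_loss = sum((x @ w).pow(2).sum() for x, _ in batches) / total_minibatch_tokens
reference_loss.backward()

print(torch.allclose(grad_patch, w.grad))     # expected: True
```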