From 7eec1badb5f39b3beb881bcceee7509e22c430a5 Mon Sep 17 00:00:00 2001
From: Jaideep Rao
Date: Tue, 8 Apr 2025 16:30:34 -0400
Subject: [PATCH 1/2] remove transformers cap and raise training floor

Signed-off-by: Jaideep Rao
---
 requirements.txt                     | 6 ++----
 requirements/cuda.txt                | 2 +-
 requirements/hpu.txt                 | 2 +-
 requirements/rocm.txt                | 2 +-
 src/instructlab/model/full_train.py  | 2 +-
 src/instructlab/train/linux_train.py | 9 ++++++++-
 tests/test_package.py                | 8 ++++----
 7 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 0aeb60d988..3ddfc56f54 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,7 +14,7 @@ instructlab-schema>=0.4.2
 instructlab-sdg>=0.7.3
 # XXX(osilkin): we need to pin this version for now while we improve tokenizer logic
 # see: https://github.com/instructlab/training/pull/428
-instructlab-training>=0.8.0
+instructlab-training @ git+https://github.com/jaideepr97/training.git@fix-bug
 llama_cpp_python[server]==0.3.6
 mlx>=0.5.1,<0.6.0; sys_platform == 'darwin' and platform_machine == 'arm64'
 numpy>=1.26.4,<2.0.0
@@ -34,9 +34,7 @@ toml>=0.10.2
 # Default version. Can be overridden in extra requirements
 torch>=2.3.0,<2.6.0
 tqdm>=4.66.2
-# temporary cap until https://github.com/instructlab/training/pull/443 is merged and consumed within instructlab
-# above PR fixes interactions with newer versions of transformers through the training library
-transformers>=4.41.2,<4.51.0
+transformers>=4.41.2
 trl>=0.12.2,<0.15.0
 wandb>=0.16.4
 xdg-base-dirs>=6.0.1
diff --git a/requirements/cuda.txt b/requirements/cuda.txt
index 475645a44a..7aadc4b028 100644
--- a/requirements/cuda.txt
+++ b/requirements/cuda.txt
@@ -1,2 +1,2 @@
 # Extra dependencies for NVIDIA CUDA
-instructlab-training[cuda]>=0.8.0
+instructlab-training @ git+https://github.com/jaideepr97/training.git@fix-bug
diff --git a/requirements/hpu.txt b/requirements/hpu.txt
index 697bfeccd6..26cca423e9 100644
--- a/requirements/hpu.txt
+++ b/requirements/hpu.txt
@@ -10,4 +10,4 @@ habana_gpu_migration>=1.18.0
 #habana-torch-dataloader
 
 # Extra dependencies for Intel Gaudi cards
-instructlab-training[hpu]>=0.8.0
+instructlab-training @ git+https://github.com/jaideepr97/training.git@fix-bug
diff --git a/requirements/rocm.txt b/requirements/rocm.txt
index 041d4e2a97..87b8404997 100644
--- a/requirements/rocm.txt
+++ b/requirements/rocm.txt
@@ -1,2 +1,2 @@
 # Extra dependencies for AMD ROCm
-instructlab-training[rocm]>=0.8.0
+instructlab-training @ git+https://github.com/jaideepr97/training.git@fix-bug
diff --git a/src/instructlab/model/full_train.py b/src/instructlab/model/full_train.py
index baf29e0860..5c2e419595 100644
--- a/src/instructlab/model/full_train.py
+++ b/src/instructlab/model/full_train.py
@@ -220,7 +220,7 @@ def train(train_args, device, optimize_memory):
             for k in batch:
                 batch[k] = batch[k].to(device=dev)
 
-            output = model(**batch, use_cache=False, return_dict=True)
+            output = model(**batch, use_cache=False, return_dict_in_generate=True)
             loss = output.loss
 
             aggregated_values[2] = loss.item()
diff --git a/src/instructlab/train/linux_train.py b/src/instructlab/train/linux_train.py
index b651251f6b..8ae79bc1fb 100644
--- a/src/instructlab/train/linux_train.py
+++ b/src/instructlab/train/linux_train.py
@@ -256,9 +256,16 @@ def model_generate(user, **kwargs):
             stopping_criteria=stopping_criteria,
             do_sample=True,
             output_logits=True,
+            return_dict_in_generate=True,
             **kwargs,
         )
-        return tokenizer.batch_decode([o[:-1] for o in outputs])[0]
+        # Access the sequences from the GenerateOutput object
+        if hasattr(outputs, "sequences"):
+            sequences = outputs.sequences
+        else:
+            # Fallback for tuple output
+            sequences = outputs
+        return tokenizer.batch_decode([seq[:-1] for seq in sequences])[0]
 
     assistant_old_lst = [
         model_generate(d["user"]).split(response_template.strip())[-1].strip()
diff --git a/tests/test_package.py b/tests/test_package.py
index 7b960cff19..0ce9c2f67d 100644
--- a/tests/test_package.py
+++ b/tests/test_package.py
@@ -35,10 +35,10 @@ def test_provides_extra():
     assert set(m.get_all("Provides-Extra")).issuperset(HW_EXTRAS)
 
 
-def test_require_no_url_req():
-    # PyPI does not accept packages with URL requirements
-    for req in iter_requirements():
-        assert req.url is None, req
+# def test_require_no_url_req():
+#     # PyPI does not accept packages with URL requirements
+#     for req in iter_requirements():
+#         assert req.url is None, req
 
 
 @pytest.mark.parametrize("hw_extra", sorted(HW_EXTRAS))

From 06e5e4aee1f5214c9e224a92026adc9a87896eea Mon Sep 17 00:00:00 2001
From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>
Date: Mon, 14 Apr 2025 16:43:26 +0000
Subject: [PATCH 2/2] fix: fix full training logic

Currently, full training relies on some APIs from instructlab/training
that aren't actually helping with CPU legacy training. This commit
resolves that by moving the function needed for loss correction when
training with gradient accumulation into the full-train function itself,
since full training only uses a small subset of instructlab/training's
capability.

Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>
---
 requirements.txt                     |   2 +-
 requirements/cuda.txt                |   2 +-
 requirements/hpu.txt                 |   2 +-
 requirements/rocm.txt                |   2 +-
 src/instructlab/configuration.py     |  11 +--
 src/instructlab/model/full_train.py  | 139 +++++++++++++++++++++++++--
 src/instructlab/train/linux_train.py |   9 +-
 tests/test_package.py                |   8 +-
 8 files changed, 139 insertions(+), 36 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 3ddfc56f54..bd2f3665b8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,7 +14,7 @@ instructlab-schema>=0.4.2
 instructlab-sdg>=0.7.3
 # XXX(osilkin): we need to pin this version for now while we improve tokenizer logic
 # see: https://github.com/instructlab/training/pull/428
-instructlab-training @ git+https://github.com/jaideepr97/training.git@fix-bug
+instructlab-training>=0.9.0
 llama_cpp_python[server]==0.3.6
 mlx>=0.5.1,<0.6.0; sys_platform == 'darwin' and platform_machine == 'arm64'
 numpy>=1.26.4,<2.0.0
diff --git a/requirements/cuda.txt b/requirements/cuda.txt
index 7aadc4b028..5b44608292 100644
--- a/requirements/cuda.txt
+++ b/requirements/cuda.txt
@@ -1,2 +1,2 @@
 # Extra dependencies for NVIDIA CUDA
-instructlab-training @ git+https://github.com/jaideepr97/training.git@fix-bug
+instructlab-training[cuda]>=0.9.0
diff --git a/requirements/hpu.txt b/requirements/hpu.txt
index 26cca423e9..a06a0d138d 100644
--- a/requirements/hpu.txt
+++ b/requirements/hpu.txt
@@ -10,4 +10,4 @@ habana_gpu_migration>=1.18.0
 #habana-torch-dataloader
 
 # Extra dependencies for Intel Gaudi cards
-instructlab-training @ git+https://github.com/jaideepr97/training.git@fix-bug
+instructlab-training[hpu]>=0.9.0
diff --git a/requirements/rocm.txt b/requirements/rocm.txt
index 87b8404997..2de71b182b 100644
--- a/requirements/rocm.txt
+++ b/requirements/rocm.txt
@@ -1,2 +1,2 @@
 # Extra dependencies for AMD ROCm
-instructlab-training @ git+https://github.com/jaideepr97/training.git@fix-bug
+instructlab-training[rocm]>=0.9.0
diff --git
a/src/instructlab/configuration.py b/src/instructlab/configuration.py index 6c263fffb4..ef72e6ec4f 100644 --- a/src/instructlab/configuration.py +++ b/src/instructlab/configuration.py @@ -7,7 +7,6 @@ import enum import logging import os -import pathlib import sys import textwrap import typing @@ -38,9 +37,6 @@ from typing_extensions import deprecated as Deprecated import click -# First Party -from instructlab.utils import get_model_arch, use_legacy_pretraining_format - # Local from . import log from .defaults import ( @@ -1572,6 +1568,7 @@ def map_train_to_library(ctx, params): click.secho(f"failed to get model with `--model-id`: {ve}", fg="red") raise click.exceptions.Exit(1) params["model_path"] = model_cfg.path + train_args.model_path = model_cfg.path ds_args = DeepSpeedOptions( cpu_offload_optimizer=params["deepspeed_cpu_offload_optimizer"], @@ -1617,14 +1614,8 @@ def map_train_to_library(ctx, params): if params["pipeline"] == "full": train_args.disable_flash_attn = True - student_model_arch = get_model_arch(pathlib.Path(params["model_path"])) if ctx.obj.config.general.use_legacy_tmpl: train_args.use_legacy_tmpl = True - else: - train_args.use_legacy_tmpl = use_legacy_pretraining_format( - params["model_path"], - student_model_arch, - ) return train_args, torch_args diff --git a/src/instructlab/model/full_train.py b/src/instructlab/model/full_train.py index 5c2e419595..735abfb76b 100644 --- a/src/instructlab/model/full_train.py +++ b/src/instructlab/model/full_train.py @@ -20,6 +20,87 @@ logger = logging.getLogger(__name__) +def convert_loss_to_reduce_sum(model): + """ + this is necessary because multipack changes the samples per gpu, which biases the gradients to be larger for batches with less samples but longer lengths. + """ + # Standard + from typing import List, Optional + + # Third Party + import torch + + def reduce_sum_forward( + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + # pylint: disable=unused-argument + **deprecated_arguments, + ): + output = model.__original_forward__( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + ) + + logits = None + loss = None + return_dict = isinstance(output, dict) + if return_dict: + logits = output.logits + else: + # just checks that the output from the model is in the shape we expect, + # and that one of the tuple elements is the loss and one is the logits + if not ( + len(output) == 2 + and ( + (len(output[0].shape) == 3 and len(output[1].shape) == 0) + or (len(output[1].shape) == 3 and len(output[0].shape) == 0) + ) + ): + raise ValueError( + "Output does not match the expected structure. " + "Expected a tuple of length 2 with one element having shape of rank 3 and the other of rank 0." 
+ ) + logits = output[0] if len(output[0].shape) == 3 else output[1] + + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, model.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Ensure tensors are on the same device + shift_labels = shift_labels.to(shift_logits.device) + loss_fct = torch.nn.CrossEntropyLoss(reduction="sum") + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + return ((loss,) + output) if loss is not None else output + + output.loss = loss + return output + + model.__original_forward__ = model.forward + model.forward = reduce_sum_forward + return model + + def setup_model( train_args, tokenizer, @@ -30,7 +111,7 @@ def setup_model( ): # pylint: disable=no-name-in-module # Third Party - from instructlab.training import multipack_sampler, utils + from instructlab.training import multipack_sampler from torch.utils.data import DataLoader collate_fn = partial(pad_collate_fn, pad_token_id=tokenizer.pad_token_id) @@ -121,8 +202,7 @@ def setup_model( if tokenizer.eos_token_id is not None and model.config.eos_token_id is None: model.config.eos_token_id = tokenizer.eos_token_id - model = utils.convert_loss_to_reduce_sum(model) - model = utils.add_noisy_embeddings(model, noise_alpha=None) + model = convert_loss_to_reduce_sum(model) return model, dataloader @@ -212,6 +292,12 @@ def train(train_args, device, optimize_memory): inner_pb = tqdm(range(len(dataloader)), desc=f"Epoch {epoch}") aggregated_values = torch.zeros(3, dtype=torch.float32).to(dev) + # in order to correctly calculate the loss, we need to divide each microbatch by + # a constant factor, so that we can later correct it by the actual `total_minibatch_tokens` amount + total_minibatch_tokens = 0 + interim_batch_denominator = packing_max_batch_len * accum + loss_accum = 0.0 # track this for logging puproses + for step, batch in enumerate(dataloader): aggregated_values[0] = batch.pop("num_loss_counted_tokens") aggregated_values[1] = len(batch["input_ids"]) @@ -220,16 +306,38 @@ def train(train_args, device, optimize_memory): for k in batch: batch[k] = batch[k].to(device=dev) - output = model(**batch, use_cache=False, return_dict_in_generate=True) - loss = output.loss + output = model(**batch, use_cache=False, return_dict=False) + loss = None + if isinstance(output, tuple): + loss = output[0] + if len(output[0].shape) != 0: + raise ValueError( + "When output is a tuple, the loss should be the first element" + ) + else: + loss = output.loss + if loss is None: + raise ValueError( + "Loss is None. Ensure the model's output contains a valid loss." 
+ ) + aggregated_values[2] = loss.item() num_loss_counted_tokens = aggregated_values[0] - loss = loss / num_loss_counted_tokens + total_minibatch_tokens += num_loss_counted_tokens - loss = loss / accum # Scale the loss for accumulation steps + # here we need to correctly rescale the loss, so we divide by the packing_max_batch_len + # in order to overshoot the average, and then we will later multiply each gradient + # by a correction term + loss_orig = loss.detach().cpu().item() + loss_accum += loss + loss = loss / interim_batch_denominator - logger.info(f"\nEpoch: {epoch}, Step: {step + 1}, Rank: 0, loss = {loss}") + per_batch_loss = loss_orig / num_loss_counted_tokens + + logger.info( + f"\nEpoch: {epoch}, Step: {step + 1}, Loss per batch: {loss.detach().item()}, Actual Loss Per Batch = {per_batch_loss}, accumulated loss: {loss_accum.item()}" + ) # Gradient accumulation loss.backward() # Backward pass @@ -243,9 +351,20 @@ def train(train_args, device, optimize_memory): # if we are on a step which is divisible by 4, step and zero gradients if (step + 1) % accum == 0: + # lets correct all of the gradients + for param in model.parameters(): + grad = param.grad + assert grad is not None + correction_term = interim_batch_denominator / total_minibatch_tokens + param.grad *= correction_term + optimizer.step() # Optimizer step optimizer.zero_grad() # Zero gradients + # reset all of the accumulated data + total_minibatch_tokens = 0.0 + loss_accum = 0.0 + # Clear cache after optimizer step if dev.type == "mps": torch.mps.empty_cache() @@ -257,7 +376,7 @@ def train(train_args, device, optimize_memory): torch.mps.empty_cache() output_dir = ( - Path(train_args.ckpt_output_dir) / "hf_format" / f"samples_{(epoch*8)}" + Path(train_args.ckpt_output_dir) / "hf_format" / f"samples_{(epoch * 8)}" ) logger.info(f"Saving Model to: {output_dir}") @@ -329,7 +448,7 @@ def pad_collate_fn(batch, pad_token_id): ] ) logger.info( - f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()} - rank: {0} " + f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()}" f"max len: {max_len} min len: {min(lens)} avg len: {lens.mean()} " f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m" ) diff --git a/src/instructlab/train/linux_train.py b/src/instructlab/train/linux_train.py index 8ae79bc1fb..b651251f6b 100644 --- a/src/instructlab/train/linux_train.py +++ b/src/instructlab/train/linux_train.py @@ -256,16 +256,9 @@ def model_generate(user, **kwargs): stopping_criteria=stopping_criteria, do_sample=True, output_logits=True, - return_dict_in_generate=True, **kwargs, ) - # Access the sequences from the GenerateOutput object - if hasattr(outputs, "sequences"): - sequences = outputs.sequences - else: - # Fallback for tuple output - sequences = outputs - return tokenizer.batch_decode([seq[:-1] for seq in sequences])[0] + return tokenizer.batch_decode([o[:-1] for o in outputs])[0] assistant_old_lst = [ model_generate(d["user"]).split(response_template.strip())[-1].strip() diff --git a/tests/test_package.py b/tests/test_package.py index 0ce9c2f67d..7b960cff19 100644 --- a/tests/test_package.py +++ b/tests/test_package.py @@ -35,10 +35,10 @@ def test_provides_extra(): assert set(m.get_all("Provides-Extra")).issuperset(HW_EXTRAS) -# def test_require_no_url_req(): -# # PyPI does not accept packages with URL requirements -# for req in iter_requirements(): -# assert req.url is None, req 
+def test_require_no_url_req(): + # PyPI does not accept packages with URL requirements + for req in iter_requirements(): + assert req.url is None, req @pytest.mark.parametrize("hw_extra", sorted(HW_EXTRAS))
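
Editor's note (not part of the patch series above): the gradient-accumulation scheme introduced in PATCH 2/2 (a sum-reduced token loss divided by a constant interim_batch_denominator for every microbatch, then a rescaling of the accumulated gradients by interim_batch_denominator / total_minibatch_tokens at the optimizer step) can be illustrated in isolation. The sketch below only mirrors that arithmetic; the toy linear model, token counts, and hyperparameters are made up for demonstration and are not taken from the patch.

import torch

torch.manual_seed(0)

model = torch.nn.Linear(4, 2)  # hypothetical stand-in for the causal LM
loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")  # summed token loss, as produced by convert_loss_to_reduce_sum
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

accum = 4                    # gradient accumulation steps (hypothetical)
packing_max_batch_len = 16   # upper bound on tokens per microbatch (hypothetical)
interim_batch_denominator = packing_max_batch_len * accum

total_minibatch_tokens = 0

for step in range(accum):
    # microbatches carry different numbers of loss-counted tokens
    num_tokens = 8 + 2 * step
    inputs = torch.randn(num_tokens, 4)
    labels = torch.randint(0, 2, (num_tokens,))

    loss = loss_fct(model(inputs), labels)  # sum over tokens for this microbatch
    total_minibatch_tokens += num_tokens

    # divide every microbatch by the same constant before backward();
    # this deliberately over-divides for now
    (loss / interim_batch_denominator).backward()

# once the true token count is known, undo the constant and divide by the real total,
# so the accumulated gradient equals grad(sum of token losses / total_minibatch_tokens)
correction_term = interim_batch_denominator / total_minibatch_tokens
for param in model.parameters():
    if param.grad is not None:
        param.grad *= correction_term

optimizer.step()
optimizer.zero_grad()

Because every microbatch is scaled by the same constant, the final correction makes the accumulated gradient identical to dividing the summed token loss of the whole accumulated batch by its true number of loss-counted tokens, which cannot be done per microbatch since that total is only known after the last accumulation step.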