6 changes: 2 additions & 4 deletions requirements.txt
@@ -14,7 +14,7 @@ instructlab-schema>=0.4.2
instructlab-sdg>=0.7.3
# XXX(osilkin): we need to pin this version for now while we improve tokenizer logic
# see: https://github.com/instructlab/training/pull/428
-instructlab-training>=0.8.0
+instructlab-training>=0.9.0
llama_cpp_python[server]==0.3.6
mlx>=0.5.1,<0.6.0; sys_platform == 'darwin' and platform_machine == 'arm64'
numpy>=1.26.4,<2.0.0
@@ -34,9 +34,7 @@ toml>=0.10.2
# Default version. Can be overridden in extra requirements
torch>=2.3.0,<2.6.0
tqdm>=4.66.2
-# temporary cap until https://github.com/instructlab/training/pull/443 is merged and consumed within instructlab
-# above PR fixes interactions with newer versions of transformers through the training library
-transformers>=4.41.2,<4.51.0
+transformers>=4.41.2
trl>=0.12.2,<0.15.0
wandb>=0.16.4
xdg-base-dirs>=6.0.1
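
Note (illustration, not part of this PR): a minimal sketch for sanity-checking that an installed environment satisfies the new floors. It assumes the listed packages and `packaging` are already installed; the helper itself is hypothetical.

# Hypothetical version check, for illustration only.
from importlib.metadata import version

from packaging.version import Version

floors = {
    "instructlab-training": "0.9.0",
    "transformers": "4.41.2",  # the <4.51.0 cap is dropped by this PR
}

for pkg, floor in floors.items():
    installed = Version(version(pkg))
    status = "ok" if installed >= Version(floor) else "too old"
    print(f"{pkg}: installed {installed}, floor {floor} -> {status}")
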
2 changes: 1 addition & 1 deletion requirements/cuda.txt
@@ -1,2 +1,2 @@
# Extra dependencies for NVIDIA CUDA
-instructlab-training[cuda]>=0.8.0
+instructlab-training[cuda]>=0.9.0
2 changes: 1 addition & 1 deletion requirements/hpu.txt
@@ -10,4 +10,4 @@ habana_gpu_migration>=1.18.0
#habana-torch-dataloader

# Extra dependencies for Intel Gaudi cards
-instructlab-training[hpu]>=0.8.0
+instructlab-training[hpu]>=0.9.0
2 changes: 1 addition & 1 deletion requirements/rocm.txt
@@ -1,2 +1,2 @@
# Extra dependencies for AMD ROCm
-instructlab-training[rocm]>=0.8.0
+instructlab-training[rocm]>=0.9.0
11 changes: 1 addition & 10 deletions src/instructlab/configuration.py
@@ -7,7 +7,6 @@
import enum
import logging
import os
-import pathlib
import sys
import textwrap
import typing
@@ -38,9 +37,6 @@
from typing_extensions import deprecated as Deprecated
import click

-# First Party
-from instructlab.utils import get_model_arch, use_legacy_pretraining_format
-
# Local
from . import log
from .defaults import (
@@ -1572,6 +1568,7 @@ def map_train_to_library(ctx, params):
        click.secho(f"failed to get model with `--model-id`: {ve}", fg="red")
        raise click.exceptions.Exit(1)
    params["model_path"] = model_cfg.path
+    train_args.model_path = model_cfg.path

    ds_args = DeepSpeedOptions(
        cpu_offload_optimizer=params["deepspeed_cpu_offload_optimizer"],
@@ -1617,14 +1614,8 @@ def map_train_to_library(ctx, params):
    if params["pipeline"] == "full":
        train_args.disable_flash_attn = True

-    student_model_arch = get_model_arch(pathlib.Path(params["model_path"]))
    if ctx.obj.config.general.use_legacy_tmpl:
        train_args.use_legacy_tmpl = True
-    else:

Review thread:
Contributor: does this change mean that legacy tokenizer config won't be detected anymore?
Member: @booxter We don't need it anymore
Contributor: We don't need the change or we don't need support for legacy auto-detection? And is this change related to transformers cap removal / training dependency update?

-        train_args.use_legacy_tmpl = use_legacy_pretraining_format(
-            params["model_path"],
-            student_model_arch,
-        )
    return train_args, torch_args


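Note (illustration, not part of the diff): a standalone sketch of the simplified flow discussed in the thread above, using hypothetical stub objects rather than instructlab code. After this change, only the explicit config flag enables the legacy template; there is no architecture-based auto-detection.

# Hypothetical stubs standing in for ctx.obj.config and the training args.
from types import SimpleNamespace

config = SimpleNamespace(general=SimpleNamespace(use_legacy_tmpl=False))
train_args = SimpleNamespace(use_legacy_tmpl=False)

# The only remaining path for enabling the legacy template.
if config.general.use_legacy_tmpl:
    train_args.use_legacy_tmpl = True

print(train_args.use_legacy_tmpl)  # stays False unless the config flag is set
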
139 changes: 129 additions & 10 deletions src/instructlab/model/full_train.py
@@ -20,6 +20,87 @@
logger = logging.getLogger(__name__)


def convert_loss_to_reduce_sum(model):
    """
    this is necessary because multipack changes the samples per gpu, which biases the gradients to be larger for batches with less samples but longer lengths.
    """
    # Standard
    from typing import List, Optional

    # Third Party
    import torch

    def reduce_sum_forward(
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        # pylint: disable=unused-argument
        **deprecated_arguments,
    ):
        output = model.__original_forward__(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict_in_generate=return_dict_in_generate,
        )

        logits = None
        loss = None
        return_dict = isinstance(output, dict)
        if return_dict:
            logits = output.logits
        else:
            # just checks that the output from the model is in the shape we expect,
            # and that one of the tuple elements is the loss and one is the logits
            if not (
                len(output) == 2
                and (
                    (len(output[0].shape) == 3 and len(output[1].shape) == 0)
                    or (len(output[1].shape) == 3 and len(output[0].shape) == 0)
                )
            ):
                raise ValueError(
                    "Output does not match the expected structure. "
                    "Expected a tuple of length 2 with one element having shape of rank 3 and the other of rank 0."
                )
            logits = output[0] if len(output[0].shape) == 3 else output[1]

        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Ensure tensors are on the same device
            shift_labels = shift_labels.to(shift_logits.device)
            loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            return ((loss,) + output) if loss is not None else output

        output.loss = loss
        return output

    model.__original_forward__ = model.forward
    model.forward = reduce_sum_forward
    return model
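
Note (illustration, not part of the diff): a brief standalone sketch of why sum reduction matters with uneven microbatches. With reduction="mean", each microbatch's loss is averaged over its own token count, so short microbatches contribute disproportionately large gradients; with reduction="sum", losses can be accumulated and divided once by the total token count. Tensors below are hypothetical and not tied to any instructlab API.

# Standalone comparison of mean- vs sum-reduced cross entropy over uneven microbatches.
import torch

torch.manual_seed(0)
vocab = 8
logits_a = torch.randn(3, vocab)    # microbatch A: 3 loss-counted tokens
labels_a = torch.randint(0, vocab, (3,))
logits_b = torch.randn(10, vocab)   # microbatch B: 10 loss-counted tokens
labels_b = torch.randint(0, vocab, (10,))

sum_ce = torch.nn.CrossEntropyLoss(reduction="sum")
mean_ce = torch.nn.CrossEntropyLoss(reduction="mean")
total_tokens = labels_a.numel() + labels_b.numel()

# Correct global average: sum per-token losses across microbatches, divide once.
global_avg = (sum_ce(logits_a, labels_a) + sum_ce(logits_b, labels_b)) / total_tokens

# Averaging per-microbatch means over-weights the shorter microbatch A.
naive_avg = (mean_ce(logits_a, labels_a) + mean_ce(logits_b, labels_b)) / 2

print(global_avg.item(), naive_avg.item())  # differ whenever token counts differ
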


def setup_model(
    train_args,
    tokenizer,
@@ -30,7 +111,7 @@ ):
):
    # pylint: disable=no-name-in-module
    # Third Party
    from instructlab.training import multipack_sampler, utils
    from instructlab.training import multipack_sampler
    from torch.utils.data import DataLoader

    collate_fn = partial(pad_collate_fn, pad_token_id=tokenizer.pad_token_id)
@@ -121,8 +202,7 @@ def setup_model(
    if tokenizer.eos_token_id is not None and model.config.eos_token_id is None:
        model.config.eos_token_id = tokenizer.eos_token_id

    model = utils.convert_loss_to_reduce_sum(model)
    model = utils.add_noisy_embeddings(model, noise_alpha=None)
    model = convert_loss_to_reduce_sum(model)
    return model, dataloader


@@ -212,6 +292,12 @@ def train(train_args, device, optimize_memory):
        inner_pb = tqdm(range(len(dataloader)), desc=f"Epoch {epoch}")
        aggregated_values = torch.zeros(3, dtype=torch.float32).to(dev)

        # in order to correctly calculate the loss, we need to divide each microbatch by
        # a constant factor, so that we can later correct it by the actual `total_minibatch_tokens` amount
        total_minibatch_tokens = 0
        interim_batch_denominator = packing_max_batch_len * accum
        loss_accum = 0.0  # track this for logging purposes

        for step, batch in enumerate(dataloader):
            aggregated_values[0] = batch.pop("num_loss_counted_tokens")
            aggregated_values[1] = len(batch["input_ids"])
@@ -220,16 +306,38 @@
            for k in batch:
                batch[k] = batch[k].to(device=dev)

            output = model(**batch, use_cache=False, return_dict=True)
            loss = output.loss
            output = model(**batch, use_cache=False, return_dict=False)
            loss = None
            if isinstance(output, tuple):
                loss = output[0]
                if len(output[0].shape) != 0:

Review thread:
Contributor: nit: if len(loss.shape) != 0? (Or even if loss.shape?)

                    raise ValueError(
                        "When output is a tuple, the loss should be the first element"
                    )
            else:
                loss = output.loss
            if loss is None:
                raise ValueError(
                    "Loss is None. Ensure the model's output contains a valid loss."

Review thread:
Contributor: These errors will bubble up to users through CLI error. I am not sure they have the context to interpret the errors meaningfully. As a user, how do I "ensure the model's output contains a valid loss"? Are you asking the user to check their model inputs / config perhaps? If so, this is what should be communicated.
Contributor: This applies to other errors here I think?
Member: @booxter This shouldn't be happening, so if this happens then we want users to report this to us. They are also welcome to look at the internals of the CLI and potentially contribute back. We shouldn't assume that our users are incapable of being technical.
Contributor: I see. So it's always a bug in the code and not the inputs from the user? If so, the suggestion to "ensure the model's output" seems misplaced and we may instead want to direct the user to report the issue. (And maybe dump some more info to include with the report?)
Member: It could be a few things: the model implementation could have changed, or the transformers API itself could have changed (as in the case here). So when that happens we just have this as a safeguard.

                )

            aggregated_values[2] = loss.item()

            num_loss_counted_tokens = aggregated_values[0]
            loss = loss / num_loss_counted_tokens
            total_minibatch_tokens += num_loss_counted_tokens

            loss = loss / accum  # Scale the loss for accumulation steps
            # here we need to correctly rescale the loss, so we divide by the packing_max_batch_len
            # in order to overshoot the average, and then we will later multiply each gradient
            # by a correction term
            loss_orig = loss.detach().cpu().item()
            loss_accum += loss
            loss = loss / interim_batch_denominator

            logger.info(f"\nEpoch: {epoch}, Step: {step + 1}, Rank: 0, loss = {loss}")
            per_batch_loss = loss_orig / num_loss_counted_tokens

            logger.info(
                f"\nEpoch: {epoch}, Step: {step + 1}, Loss per batch: {loss.detach().item()}, Actual Loss Per Batch = {per_batch_loss}, accumulated loss: {loss_accum.item()}"
            )

            # Gradient accumulation
            loss.backward()  # Backward pass
@@ -243,9 +351,20 @@

            # if we are on a step which is divisible by 4, step and zero gradients
            if (step + 1) % accum == 0:
                # let's correct all of the gradients
                for param in model.parameters():
                    grad = param.grad
                    assert grad is not None
                    correction_term = interim_batch_denominator / total_minibatch_tokens
                    param.grad *= correction_term

                optimizer.step()  # Optimizer step
                optimizer.zero_grad()  # Zero gradients

                # reset all of the accumulated data
                total_minibatch_tokens = 0.0
                loss_accum = 0.0

                # Clear cache after optimizer step
                if dev.type == "mps":
                    torch.mps.empty_cache()
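
Note (illustration, not part of the diff): a standalone arithmetic check of the rescaling above. Each microbatch's summed loss is first divided by the constant interim_batch_denominator = packing_max_batch_len * accum, and at the optimizer step every gradient is multiplied by interim_batch_denominator / total_minibatch_tokens, so the net effect is one division by the actual number of loss-counted tokens in the accumulated minibatch. Names mirror the diff; the numbers are made up.

# Standalone arithmetic sketch; not tied to the model or optimizer code above.
packing_max_batch_len = 60_000
accum = 4
interim_batch_denominator = packing_max_batch_len * accum

# Hypothetical per-microbatch summed losses and loss-counted token counts.
microbatch_losses = [1200.0, 950.0, 1800.0, 700.0]
microbatch_tokens = [3000, 2400, 4500, 1800]
total_minibatch_tokens = sum(microbatch_tokens)

# What backward() accumulates: each loss scaled by the constant denominator.
accumulated = sum(l / interim_batch_denominator for l in microbatch_losses)

# The correction applied to gradients before optimizer.step().
corrected = accumulated * (interim_batch_denominator / total_minibatch_tokens)

# Equivalent to dividing the summed loss once by the true token count.
expected = sum(microbatch_losses) / total_minibatch_tokens
assert abs(corrected - expected) < 1e-9
print(corrected, expected)
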
@@ -257,7 +376,7 @@
            torch.mps.empty_cache()

        output_dir = (
            Path(train_args.ckpt_output_dir) / "hf_format" / f"samples_{(epoch*8)}"
            Path(train_args.ckpt_output_dir) / "hf_format" / f"samples_{(epoch * 8)}"
        )

        logger.info(f"Saving Model to: {output_dir}")
@@ -329,7 +448,7 @@ def pad_collate_fn(batch, pad_token_id):
        ]
    )
    logger.info(
        f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()} - rank: {0} "
        f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()}"
f"max len: {max_len} min len: {min(lens)} avg len: {lens.mean()} "
f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m"
)