72 changes: 0 additions & 72 deletions src/instructlab/training/main_ds.py
@@ -2,14 +2,12 @@

# Standard
from copy import deepcopy
from pathlib import Path
import argparse
import datetime
import functools
import logging
import math
import os
import re
import subprocess
import time
import warnings
@@ -32,10 +30,8 @@
try:
# Third Party
from deepspeed.ops.adam import FusedAdam
from deepspeed.runtime.zero.utils import ZeRORuntimeException
except ImportError:
FusedAdam = None
ZeRORuntimeException = None
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if __name__ == "__main__" and (not local_rank or local_rank == 0):
warnings.warn(
@@ -83,7 +79,6 @@
ensure_loadable_dolomite_checkpoint,
load_latest_full_state,
prepare_peft_model,
prepare_universal_checkpoint_from_latest,
save_checkpoint,
save_hf_format_accelerate,
set_random_seed,
@@ -298,63 +293,6 @@ def make_inputs_require_grad(module, input, output): # pylint: disable=unused-a
return model, lr_scheduler, optimizer, accelerator


# check whether training can be resumed from the provided checkpoint
def maybe_resume_training(args, model):
local_rank = int(os.environ["LOCAL_RANK"])

# DS's loading function will not raise if it fails to reload a checkpoint
# - if lora is used, then the checkpoints will only be for the adapters
# so we need to disable load_module_strict
# - load checkpoint will find the latest checkpoint
# - it will also load the optimizer and scheduler states by default
load_module_strict = args.lora_r == 0 # can only be true if lora is not used
output_dir = Path(args.output_dir) / "ds_native"

try:
# attempt to load a regular checkpoint first
model.load_checkpoint(output_dir, load_module_strict=load_module_strict)
except ZeRORuntimeException as e:
if str(e).startswith("The checkpoint being loaded used a DP world size of"):
# if it fails with the above exception, then a universal
# checkpoint is required

# prepare the universal checkpoint
# - by reading 'latest' to get the resumable checkpoint
prepare_universal_checkpoint_from_latest(output_dir)

# need to do this to trigger the universal checkpoint
# loading
model._config.load_universal_checkpoint = True

# then attempt to load again
model.load_checkpoint(output_dir, load_module_strict=load_module_strict)

# reset to regular checkpoint loading
model._config.load_universal_checkpoint = False
else:
raise e # reraise

# do this to figure out the last_step
latest_file = output_dir / "latest"
try:
with open(latest_file) as f:
# we assume here that the ds_native checkpoints
# are tagged as <something>_(samples_seen)
step_folder = f.read()
(samples_seen,) = re.match(r"\w+_(\d+)", step_folder).groups()
samples_seen = int(samples_seen)

last_step = samples_seen // args.effective_batch_size
args.__dict__["last_step"] = last_step
if local_rank == 0:
logger.info("Found checkpoint at %d, resuming training", last_step)
except FileNotFoundError:
pass

# we will update the start step here
return model
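
The removed maybe_resume_training above derives the step to resume from by reading DeepSpeed's `latest` tag file and dividing the samples seen by the effective batch size. A minimal sketch of that tag parsing and arithmetic, assuming the `samples_<N>` tag format written by the save_model_ds_native helper removed from utils.py further down (the helper name here is illustrative, not part of the codebase):

import re
from pathlib import Path
from typing import Optional


def last_step_from_latest(output_dir: str, effective_batch_size: int) -> Optional[int]:
    """Hypothetical helper: read the DeepSpeed `latest` marker and convert the
    samples-seen count encoded in the tag (e.g. `samples_128000`) into the step
    to resume training from."""
    latest = Path(output_dir) / "latest"
    if not latest.exists():
        return None
    tag = latest.read_text().strip()
    match = re.match(r"\w+_(\d+)", tag)
    if match is None:
        return None
    samples_seen = int(match.group(1))
    # e.g. 128000 samples seen / effective batch size 128 -> resume at step 1000
    return samples_seen // effective_batch_size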


def train(
args,
model,
@@ -512,16 +450,6 @@ def train(
base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank)
torch.distributed.barrier()

# if (
# args.save_samples_ds is not None
# and global_step * batch_size % args.save_samples_ds == 0
# ):
# save_model_ds_native(
# args,
# model,
# tokenizer,
# global_step * args.samples_per_gpu * world_size,
# )
global_step += 1
if local_rank == 0:
inner_pb.update(1)
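
The commented-out block above gated DeepSpeed-native saves on the number of samples processed so far. A minimal sketch of that trigger condition (the function and argument names are illustrative only):

from typing import Optional


def should_save_ds_native(
    global_step: int, batch_size: int, save_samples_ds: Optional[int]
) -> bool:
    # Save whenever the running sample count lands exactly on a multiple of the
    # interval, e.g. batch_size=128, save_samples_ds=1024 -> save at steps 8, 16, ...
    if save_samples_ds is None:
        return False
    return (global_step * batch_size) % save_samples_ds == 0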
169 changes: 0 additions & 169 deletions src/instructlab/training/utils.py
@@ -642,137 +642,6 @@ def make_inputs_require_grad(module, input, output): # pylint: disable=unused-a
return model


def prepare_universal_checkpoint_from_latest(output_dir):
"""Populate the universal checkpoint in output_dir/step_folder
- 1. read output_dir/latest to get step_folder
- 2. populate tmp dir in output_dir/step_folder/tmp
- 3. populate zero checkpoints in output_dir/step_folder/zero
- 4. create output_dir/latest_universal

Steps 1-4 are idempotent. Atomicity holds in the sense that
output_dir/latest_universal is only created after step 4 completes,
and only once it exists can the universal checkpoint be loaded.

Be aware that this creates an extra dir `zero/` in the checkpoint dir,
which doubles the DS checkpoint storage requirement.
- DS checkpoints store 3X the model parameters in 32-bit.
- i.e., roughly 6X the size of a parameter-only checkpoint in 16-bit.

Note that this requires a recent version of deepspeed. It also works if
the model did not save universal checkpoint info, but only in the
case where advanced features such as tensor parallelism (TP) and
pipeline parallelism (PP) are turned off.
"""

log_rank_0(
f"\033[93mPreparing universal checkpoint in {output_dir}\033[0m", to_print=True
)
# Third Party
from packaging import version
from transformers.utils.import_utils import _is_package_available

_, ds_version = _is_package_available("deepspeed", return_version=True)
# compare parsed versions rather than strings so e.g. "0.9.x" fails the check
if version.parse(ds_version) < version.parse("0.14.3"):
raise ValueError("universal checkpoint only supported on deepspeed >= 0.14.3")

start = time.time()
if torch.distributed.get_rank() == 0:
try:
# Third Party
from deepspeed.checkpoint import DeepSpeedCheckpoint
from deepspeed.checkpoint.ds_to_universal import (
PARAM_SHAPES,
UNIVERSAL_CHECKPOINT_INFO,
_check_for_required_state,
_extract_zero_shard_files,
_merge_tp_slice_files,
_save_optimizer_state,
)
except ImportError as exc:
raise ImportError(
"DeepSpeed-specific checkpoints cannot be saved without DeepSpeed>=0.14.3 installed"
) from exc

# read the latest file to get the step folder
latest_file = output_dir / "latest"
with open(latest_file) as f:
step_folder = f.read()

# will process the checkpoint in the latest step folder
input_folder = os.path.join(output_dir, step_folder)

# create args for the scripts below
class UniversalCheckpointArgs:
num_extract_workers: int = 1
num_merge_workers: int = 1
output_folder: str = input_folder # just put in same place
strict: bool = True # strict checkpoint

args = UniversalCheckpointArgs()

# get the checkpoint
ds_checkpoint = DeepSpeedCheckpoint(input_folder)

# hack: force this to an empty dict if we did not properly save
# any universal checkpoint information
# - this will not support pipeline replication or other kinds of
# replication such as TP, row parallelism, vocab, sub_params
if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
warnings.warn(
"Universal checkpoint information not found, setting it to "
"an empty dictionary."
)
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
assert ds_checkpoint.tp_degree == 1, (
"if universal checkpointing info is missing, TP must be absent"
)
assert ds_checkpoint.pp_degree == 1, (
"if universal checkpointing info is missing, PP must be absent"
)
_check_for_required_state(ds_checkpoint)

slice_shapes = []
for mp_rank_file in ds_checkpoint.mp_rank_files:
mp_sd = torch.load(mp_rank_file, map_location=torch.device("cpu"))
slice_shapes += mp_sd[PARAM_SHAPES]

# fix back to normal flat dict, merge duplicates for tp>1
slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items())
temp_dir = os.path.join(args.output_folder, "tmp")

log_rank_0(
f"\033[93m1. Extracting ZeRO fragments into {temp_dir}\033[0m",
to_print=True,
)
_extract_zero_shard_files(args, ds_checkpoint, temp_dir)

zero_output_folder = os.path.join(args.output_folder, "zero")

log_rank_0(
f"\033[93m2. Merging slices into {zero_output_folder}\033[0m", to_print=True
)
_merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)

log_rank_0(
f"\033[93m3. Saving common optimizer states into {zero_output_folder}\033[0m",
to_print=True,
)
_save_optimizer_state(args, ds_checkpoint)

log_rank_0(
f"\033[93m4. Removing temp directory {temp_dir}\033[0m", to_print=True
)
shutil.rmtree(temp_dir, ignore_errors=True)

latest_file = os.path.join(output_dir, "latest_universal")
log_rank_0(f"\033[93m5. Creating {latest_file}\033[0m", to_print=True)
with open(latest_file, "w") as f:
f.write(step_folder)

dist.barrier()
log_rank_0(f"Preparing universal checkpoint took {time.time() - start} seconds")


@contextmanager
def ensure_loadable_dolomite_checkpoint(
model_name_or_path: str,
@@ -1050,44 +919,6 @@ def _get_state_dict_patched(model, unwrap=False):
accelerator.get_state_dict = get_state_dict_unpatched


# this is native deepspeed saving with optimizer, scheduler
def save_model_ds_native(
args,
model,
tokenizer, # pylint: disable=unused-argument
samples_seen,
):
# to get a state dict from a ZeRO checkpoint, all you need to do is:
# - from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
# - sd = get_fp32_state_dict_from_zero_checkpoint('ckpt')
# - sum([math.prod(x.shape) for x in sd.values()])  # check the size (should match the model)

log_rank_0(
f"\033[93mSaving model+optimizer+scheduler in format at samples_seen: {samples_seen}\033[0m",
to_print=True,
)
start = time.time()
# used to save huggingface format, so we can use it for hf.from_pretrained
output_dir = Path(args.output_dir) / "ds_native"
tag = f"samples_{samples_seen}"
use_lora = args.lora_r > 0

# NOTE: this is a distributed save
# if it's lora, we only save the adapters
# - so we exclude frozen if use_lora==True
model.save_checkpoint(
output_dir,
exclude_frozen_parameters=use_lora,
tag=tag, # this will create the subdirectory with the correct name
)

# for now we are not saving the tokenizer, config, etc.,
# so it is not fully "HF compatible"

log_rank_0(f"\033[93mModel saved in {output_dir}\033[0m", to_print=True)
log_rank_0(f"saving took {time.time() - start} seconds")


def set_random_seed(seed):
if seed is not None:
random.seed(seed)