diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 19ee76f3..3a2ed8a5 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -2,14 +2,12 @@ # Standard from copy import deepcopy -from pathlib import Path import argparse import datetime import functools import logging import math import os -import re import subprocess import time import warnings @@ -32,10 +30,8 @@ try: # Third Party from deepspeed.ops.adam import FusedAdam - from deepspeed.runtime.zero.utils import ZeRORuntimeException except ImportError: FusedAdam = None - ZeRORuntimeException = None local_rank = int(os.getenv("LOCAL_RANK", "0")) if __name__ == "__main__" and (not local_rank or local_rank == 0): warnings.warn( @@ -83,7 +79,6 @@ ensure_loadable_dolomite_checkpoint, load_latest_full_state, prepare_peft_model, - prepare_universal_checkpoint_from_latest, save_checkpoint, save_hf_format_accelerate, set_random_seed, @@ -298,63 +293,6 @@ def make_inputs_require_grad(module, input, output): # pylint: disable=unused-a return model, lr_scheduler, optimizer, accelerator -# this function is to check if the checkpoint provided can be resumed -def maybe_resume_training(args, model): - local_rank = int(os.environ["LOCAL_RANK"]) - - # DS's loading function will not raise if fails to reload a checkpoint - # - if lora is used, then the checkpoints will only be for the adapters - # so we need to disable load_module_strict - # - load checkpoint will find the latest checkpoint - # - it will also load the optimizer and scheduler states by default - load_module_strict = args.lora_r == 0 # can only be true if lora is not used - output_dir = Path(args.output_dir) / "ds_native" - - try: - # attempt to load a regular checkpoint first - model.load_checkpoint(output_dir, load_module_strict=load_module_strict) - except ZeRORuntimeException as e: - if str(e).startswith("The checkpoint being loaded used a DP world size of"): - # if it fails with the above exception, then a universal - # checkpoint is required - - # prepare the universal checkpoint - # - by reading 'latest' to get the resumable checkpoint - prepare_universal_checkpoint_from_latest(output_dir) - - # need to do this to trigger the universal checkpoint - # loading - model._config.load_universal_checkpoint = True - - # then attempt to load again - model.load_checkpoint(output_dir, load_module_strict=load_module_strict) - - # reset to regular checkpoint loading - model._config.load_universal_checkpoint = False - else: - raise e # reraise - - # do this to figure out the last_step - latest_file = output_dir / "latest" - try: - with open(latest_file) as f: - # there is some assumption here that the ds_native - # checkpoints are tagged as _(samples_seen) - step_folder = f.read() - (samples_seen,) = re.match("\w+_(\d+)", step_folder).groups() - samples_seen = int(samples_seen) - - last_step = samples_seen // args.effective_batch_size - args.__dict__["last_step"] = last_step - if local_rank == 0: - logger.info("Found checkpoint at %d, resuming training", last_step) - except FileNotFoundError: - pass - - # we will update the start step here - return model - - def train( args, model, @@ -512,16 +450,6 @@ def train( base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank) torch.distributed.barrier() - # if ( - # args.save_samples_ds is not None - # and global_step * batch_size % args.save_samples_ds == 0 - # ): - # save_model_ds_native( - # args, - # model, - # tokenizer, - # global_step * 
args.samples_per_gpu * world_size, - # ) global_step += 1 if local_rank == 0: inner_pb.update(1) diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 937c1f03..270bc33c 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -642,137 +642,6 @@ def make_inputs_require_grad(module, input, output): # pylint: disable=unused-a return model -def prepare_universal_checkpoint_from_latest(output_dir): - """Populate the universal checkpoint in output_dir/step_folder - - 1. read output_dir/latest to get step_folder - - 2. populate tmp dir in output_dir/step_folder/tmp - - 3. populate zero checkpoints in output_dir/step_folder/zero - - 4. create output_dir/latest_universal - - Items 1, 2, 3, 4 are idempotent. There is atomicity in the sense that - only after 4 is completed, then the output_dir/latest_universal - checkpoint is created in which then the universal checkpoint - can be loaded. - - Be aware that this creates an extra dir `zero/` in the checkpoint dir, - which doubles the DS checkpoint storage requirement. - - DS checkpoints store 3X model parameters in 32bit. - - e.g., will be 6X a model parameter-only checkpoint in 16bit. - - Note that this requires a latest version of deepspeed. It kind of works if - the model is not saving universal checkpoint info, but only in the - the case where advanced features like tensor parallel (TP) and - pipeline parallel (PP) are turned off. - """ - - log_rank_0( - f"\033[93mPreparing universal checkpoint in {output_dir}\033[0m", to_print=True - ) - # Third Party - from transformers.utils.import_utils import _is_package_available - - _, ds_version = _is_package_available("deepspeed", return_version=True) - if ds_version < "0.14.3": - raise ValueError("universal checkpoint only supported on deepspeed >= 0.14.3") - - start = time.time() - if torch.distributed.get_rank() == 0: - try: - # Third Party - from deepspeed.checkpoint import DeepSpeedCheckpoint - from deepspeed.checkpoint.ds_to_universal import ( - PARAM_SHAPES, - UNIVERSAL_CHECKPOINT_INFO, - _check_for_required_state, - _extract_zero_shard_files, - _merge_tp_slice_files, - _save_optimizer_state, - ) - except ImportError as exc: - raise ImportError( - "DeepSpeed-specific checkpoints cannot be saved without DeepSpeed>=0.14.3 installed" - ) from exc - - # read the latest file to get the step folder - latest_file = output_dir / "latest" - with open(latest_file) as f: - step_folder = f.read() - - # will process the checkpoint in the latest step folder - input_folder = os.path.join(output_dir, step_folder) - - # create args for the scripts below - class UniversalCheckpointArgs: - num_extract_workers: int = 1 - num_merge_workers: int = 1 - output_folder: str = input_folder # just put in same place - strict: bool = True # strict checkpoint - - args = UniversalCheckpointArgs() - - # get the checkpoint - ds_checkpoint = DeepSpeedCheckpoint(input_folder) - - # hack, force this to null if we did not properly save - # any universal checkpoint information - # - this will not support any pipeline replication and other - # replication such as TP, row parallelism, vocab, sub_params - if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state: - warnings.warn( - "Universal checkpoint information not found, setting it to " - "an empty dictionary." 
- ) - ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {} - assert ds_checkpoint.tp_degree == 1, ( - "if universal checkpointing info is missing, TP must be absent" - ) - assert ds_checkpoint.pp_degree == 1, ( - "if universal checkpointing info is missing, PP must be absent" - ) - _check_for_required_state(ds_checkpoint) - - slice_shapes = [] - for mp_rank_file in ds_checkpoint.mp_rank_files: - mp_sd = torch.load(mp_rank_file, map_location=torch.device("cpu")) - slice_shapes += mp_sd[PARAM_SHAPES] - - # fix back to normal flat dict, merge duplicates for tp>1 - slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items()) - temp_dir = os.path.join(args.output_folder, "tmp") - - log_rank_0( - f"\033[93m1. Extracting ZeRO fragments into {temp_dir}\033[0m", - to_print=True, - ) - _extract_zero_shard_files(args, ds_checkpoint, temp_dir) - - zero_output_folder = os.path.join(args.output_folder, "zero") - - log_rank_0( - f"\033[93m2. Merging slices into {zero_output_folder}\033[0m", to_print=True - ) - _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir) - - log_rank_0( - f"\033[93m3. Saving common optimizer states into {zero_output_folder}\033[0m", - to_print=True, - ) - _save_optimizer_state(args, ds_checkpoint) - - log_rank_0( - f"\033[93m4. Removing temp directory {temp_dir}\033[0m", to_print=True - ) - shutil.rmtree(temp_dir, ignore_errors=True) - - latest_file = os.path.join(output_dir, "latest_universal") - log_rank_0(f"\033[93m5. Creating {latest_file}\033[0m", to_print=True) - with open(latest_file, "w") as f: - f.write(step_folder) - - dist.barrier() - log_rank_0(f"Preparing universal checkpoint took {time.time() - start} seconds") - - @contextmanager def ensure_loadable_dolomite_checkpoint( model_name_or_path: str, @@ -1050,44 +919,6 @@ def _get_state_dict_patched(model, unwrap=False): accelerator.get_state_dict = get_state_dict_unpatched -# this is native deepspeed saving with optimizer, scheduler -def save_model_ds_native( - args, - model, - tokenizer, # pylint: disable=unused-argument - samples_seen, -): - # to get a statedict from a zero checkpoint, all you need to do is - # - from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint - # - sd = get_fp32_state_dict_from_zero_checkpoint('ckpt') - # - sum([math.prod(x.shape) for x in sd.values()]) # check the size (should be correct) - - log_rank_0( - f"\033[93mSaving model+optimizer+scheduler in format at samples_seen: {samples_seen}\033[0m", - to_print=True, - ) - start = time.time() - # used to save huggingface format, so we can use it for hf.from_pretrained - output_dir = Path(args.output_dir) / "ds_native" - tag = f"samples_{samples_seen}" - use_lora = args.lora_r > 0 - - # NOTE: this is a distributed save - # if its lora, we only save the adapters - # - so we exclude frozen if use_lora==True - model.save_checkpoint( - output_dir, - exclude_frozen_parameters=use_lora, - tag=tag, # this will create the subdirectory with the correct name - ) - - # for now we are not saving tokenizer, config, eg.. - # so it is not totally "HF compatible" - - log_rank_0(f"\033[93mModel saved in {output_dir}\033[0m", to_print=True) - log_rank_0(f"saving took {time.time() - start} seconds") - - def set_random_seed(seed): if seed is not None: random.seed(seed)
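
Note for anyone still holding checkpoints written by the removed save_model_ds_native() path: as the deleted inline comment points out, DeepSpeed's zero_to_fp32 helper can consolidate such a ZeRO checkpoint into a plain fp32 state dict without any of the removed code. A minimal sketch, assuming a checkpoint root of output/ds_native and a tag of the form samples_<n> (both mirrored from the removed code; adjust to the actual layout):

    # Sketch only: consolidate a legacy ZeRO checkpoint (as written by the removed
    # save_model_ds_native) into a single fp32 state dict. The path and tag below
    # are assumptions mirroring the removed code, not values fixed by this diff.
    import math

    import torch
    from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

    sd = get_fp32_state_dict_from_zero_checkpoint(
        "output/ds_native",   # <output_dir>/ds_native, the directory the removed save path used
        tag="samples_4096",   # hypothetical tag; the removed code wrote samples_<samples_seen>
    )
    # sanity check from the removed comment: total parameter count should match the model
    print(sum(math.prod(p.shape) for p in sd.values()))
    torch.save(sd, "consolidated_fp32.pt")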