Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ add_eos: True
# Dataset
per_device_batch_size: 12.0
expansion_factor_real_data: -1 # if -1 then all hosts will load real data, else total_hosts//expansion_factor_real_data will pull data from GCS.
eval_per_device_batch_size: 0
eval_per_device_batch_size: 0.0
max_corpus_chars: 10_000_000
train_data_column: 'text'
eval_data_column: 'text'
Expand Down
7 changes: 4 additions & 3 deletions MaxText/inference_microbenchmark_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
import pyconfig

try:
JaxRuntimeError = jax.errors.JaxRuntimeError # added in JAX 0.4.34
JaxRuntimeError = jax.errors.JaxRuntimeError # added in JAX 0.4.34
except AttributeError:
from jax._src.lib import xla_extension
JaxRuntimeError = xla_extension.XlaRuntimeError
from jax._src.lib import xla_extension

JaxRuntimeError = xla_extension.XlaRuntimeError


def main():
Expand Down
49 changes: 27 additions & 22 deletions MaxText/input_pipeline/_grain_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def preprocessing_pipeline(
return multihost_gen


def make_grain_iterator(
def make_grain_train_iterator(
config: ml_collections.ConfigDict,
global_mesh,
process_indices,
Expand All @@ -130,25 +130,30 @@ def make_grain_iterator(
add_bos=config.add_bos,
add_eos=config.add_eos,
)
return train_iter

if config.eval_interval > 0:
eval_ds = get_datasets(config.grain_eval_files)
eval_iter = preprocessing_pipeline(
dataset=eval_ds,
tokenizer_path=config.tokenizer_path,
global_batch_size=config.global_batch_size_to_load,
global_mesh=global_mesh,
max_target_length=config.max_target_length,
grain_worker_count=config.grain_worker_count,
dataloading_host_index=process_indices.index(jax.process_index()),
dataloading_host_count=len(process_indices),
data_column=config.eval_data_column,
shuffle=False,
data_shuffle_seed=config.data_shuffle_seed,
tokenize=config.tokenize_eval_data,
add_bos=config.add_bos,
add_eos=config.add_eos,
)
else:
eval_iter = None
return train_iter, eval_iter

def make_grain_eval_iterator(
config: ml_collections.ConfigDict,
global_mesh,
process_indices,
):

eval_ds = get_datasets(config.grain_eval_files)
eval_iter = preprocessing_pipeline(
dataset=eval_ds,
tokenizer_path=config.tokenizer_path,
global_batch_size=config.global_batch_size_to_load_eval,
global_mesh=global_mesh,
max_target_length=config.max_target_length,
grain_worker_count=config.grain_worker_count,
dataloading_host_index=process_indices.index(jax.process_index()),
dataloading_host_count=len(process_indices),
data_column=config.eval_data_column,
shuffle=False,
data_shuffle_seed=config.data_shuffle_seed,
tokenize=config.tokenize_eval_data,
add_bos=config.add_bos,
add_eos=config.add_eos,
)
return eval_iter
83 changes: 41 additions & 42 deletions MaxText/input_pipeline/_hf_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ def preprocessing_pipeline(
return multihost_gen


def make_hf_iterator(
def make_hf_train_iterator(
config: ml_collections.ConfigDict,
global_mesh,
process_indices,
process_indices_train,
):
"""Load, preprocess dataset and return iterators"""
train_ds = datasets.load_dataset(
Expand All @@ -143,8 +143,8 @@ def make_hf_iterator(
token=config.hf_access_token,
)
train_iter = preprocessing_pipeline(
dataloading_host_index=process_indices.index(jax.process_index()),
dataloading_host_count=len(process_indices),
dataloading_host_index=process_indices_train.index(jax.process_index()),
dataloading_host_count=len(process_indices_train),
global_mesh=global_mesh,
dataset=train_ds,
data_column_name=config.train_data_column,
Expand All @@ -159,43 +159,42 @@ def make_hf_iterator(
add_eos=config.add_eos,
generate_padding_example=True,
)
return train_iter

if config.eval_interval > 0:
eval_ds = datasets.load_dataset(
config.hf_path,
data_dir=config.hf_data_dir,
data_files=config.hf_eval_files,
split=config.hf_eval_split,
streaming=True,
token=config.hf_access_token,
)
if config.eval_per_device_batch_size > 0:
eval_batch_size = config.eval_per_device_batch_size * global_mesh.size
else:
eval_batch_size = config.global_batch_size_to_load

if config.eval_steps > 0:
eval_generate_padding_example = True
else:
eval_generate_padding_example = False
eval_iter = preprocessing_pipeline(
dataloading_host_index=process_indices.index(jax.process_index()),
dataloading_host_count=len(process_indices),
global_mesh=global_mesh,
dataset=eval_ds,
data_column_name=config.eval_data_column,
tokenize=config.tokenize_eval_data,
tokenizer_path=config.tokenizer_path,
hf_access_token=config.hf_access_token,
global_batch_size=eval_batch_size,
max_target_length=config.max_target_length,
shuffle=False,
data_shuffle_seed=config.data_shuffle_seed,
add_bos=config.add_bos,
add_eos=config.add_eos,
generate_padding_example=eval_generate_padding_example,
)
else:
eval_iter = None

return train_iter, eval_iter
def make_hf_eval_iterator(
config: ml_collections.ConfigDict,
global_mesh,
process_indices_eval,
):
eval_ds = datasets.load_dataset(
config.hf_path,
data_dir=config.hf_data_dir,
data_files=config.hf_eval_files,
split=config.hf_eval_split,
streaming=True,
token=config.hf_access_token,
)

if config.eval_steps > 0:
eval_generate_padding_example = True
else:
eval_generate_padding_example = False
eval_iter = preprocessing_pipeline(
dataloading_host_index=process_indices_eval.index(jax.process_index()),
dataloading_host_count=len(process_indices_eval),
global_mesh=global_mesh,
dataset=eval_ds,
data_column_name=config.eval_data_column,
tokenize=config.tokenize_eval_data,
tokenizer_path=config.tokenizer_path,
hf_access_token=config.hf_access_token,
global_batch_size=config.global_batch_size_to_load_eval,
max_target_length=config.max_target_length,
shuffle=False,
data_shuffle_seed=config.data_shuffle_seed,
add_bos=config.add_bos,
add_eos=config.add_eos,
generate_padding_example=eval_generate_padding_example,
)
return eval_iter
67 changes: 33 additions & 34 deletions MaxText/input_pipeline/_tfds_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,19 +135,19 @@ def preprocessing_pipeline(
return multihost_gen


def make_tfds_iterator(
def make_tfds_train_iterator(
config: ml_collections.ConfigDict,
global_mesh,
process_indices,
process_indices_train,
):
"""load dataset, preprocess and return iterators"""
train_ds = get_datasets(
dataset_name=config.dataset_name,
data_split="train",
shuffle_files=config.enable_data_shuffling,
shuffle_seed=config.data_shuffle_seed,
dataloading_host_index=process_indices.index(jax.process_index()),
dataloading_host_count=len(process_indices),
dataloading_host_index=process_indices_train.index(jax.process_index()),
dataloading_host_count=len(process_indices_train),
)
train_iter = preprocessing_pipeline(
dataset=train_ds,
Expand All @@ -162,36 +162,35 @@ def make_tfds_iterator(
add_bos=config.add_bos,
add_eos=config.add_eos,
)
return train_iter

if config.eval_interval > 0:
eval_ds = get_datasets(
dataset_name=config.eval_dataset_name,
data_split=config.eval_split,
shuffle_files=False,
shuffle_seed=config.data_shuffle_seed,
dataloading_host_index=process_indices.index(jax.process_index()),
dataloading_host_count=len(process_indices),
)

if config.eval_per_device_batch_size > 0:
eval_batch_size = config.eval_per_device_batch_size * global_mesh.size
else:
eval_batch_size = config.global_batch_size_to_load

eval_iter = preprocessing_pipeline(
dataset=eval_ds,
tokenizer_path=config.tokenizer_path,
global_batch_size=eval_batch_size,
global_mesh=global_mesh,
max_target_length=config.max_target_length,
data_column_name=config.eval_data_column,
shuffle=False,
data_shuffle_seed=config.data_shuffle_seed,
tokenize=config.tokenize_eval_data,
add_bos=config.add_bos,
add_eos=config.add_eos,
)
else:
eval_iter = None
def make_tfds_eval_iterator(
config: ml_collections.ConfigDict,
global_mesh,
process_indices_eval,
):
eval_ds = get_datasets(
dataset_name=config.eval_dataset_name,
data_split=config.eval_split,
shuffle_files=False,
shuffle_seed=config.data_shuffle_seed,
dataloading_host_index=process_indices_eval.index(jax.process_index()),
dataloading_host_count=len(process_indices_eval),
)

eval_iter = preprocessing_pipeline(
dataset=eval_ds,
tokenizer_path=config.tokenizer_path,
global_batch_size=config.global_batch_size_to_load_eval,
global_mesh=global_mesh,
max_target_length=config.max_target_length,
data_column_name=config.eval_data_column,
shuffle=False,
data_shuffle_seed=config.data_shuffle_seed,
tokenize=config.tokenize_eval_data,
add_bos=config.add_bos,
add_eos=config.add_eos,
)

return train_iter, eval_iter
return eval_iter
11 changes: 1 addition & 10 deletions MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,6 @@ def _add_pad(x):
return ds.concatenate(pad_ds)


def get_eval_global_batch_size_to_load(config: ml_collections.ConfigDict, global_mesh) -> int:
"""Calculate the global batch size for evaluation."""
if config.eval_per_device_batch_size > 0:
return config.eval_per_device_batch_size * global_mesh.size
else:
return config.global_batch_size_to_load


def get_dataset(
dataset_name: str,
split: str,
Expand Down Expand Up @@ -347,11 +339,10 @@ def make_c4_mlperf_eval_iterator(
# note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and split to target_length
# mainly to avoid eval sequences change depending on the number of hosts
eval_ds = rekey(eval_ds, {"inputs": None, "targets": "ids"})
eval_global_batch_size_to_load = get_eval_global_batch_size_to_load(config, global_mesh)

eval_ds = preprocess_eval_dataset(
eval_ds,
eval_global_batch_size_to_load=eval_global_batch_size_to_load,
eval_global_batch_size_to_load=config.global_batch_size_to_load_eval,
max_target_length=config.max_target_length,
)

Expand Down
Loading