Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Add a `SaveArgs` option that allows disabling pinned host transfer on a per-array basis.

## [0.7.0] - 2024-10-07

### Removed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,28 @@ def _get_kvstore_for_gcs(ckpt_path: str) -> JsonSpec:
return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket}


def _get_kvstore_for_yt(ckpt_path: str):
  """Builds a `tsgrpc_kvstore` spec dict for a `yt:`-prefixed checkpoint path.

  The gRPC endpoint is taken from the `TS_GRPC_ADDRESS` environment variable.

  Args:
    ckpt_path: Checkpoint path. A leading `yt:` is stripped; everything after
      it (including any leading slashes) is kept verbatim as the kvstore path.

  Returns:
    A dict with the keys 'driver', 'address', 'path' and 'timeout'.

  Raises:
    ValueError: If `TS_GRPC_ADDRESS` is unset or set to an empty string.
  """
  # Look the variable up once. An empty value is as unusable as a missing one,
  # so both cases are rejected here instead of yielding a spec with no address.
  grpc_address = os.environ.get('TS_GRPC_ADDRESS')
  if not grpc_address:
    raise ValueError(
        'yt:// scheme requires TS_GRPC_ADDRESS environment variable to be set.'
    )

  return {
      'driver': 'tsgrpc_kvstore',
      'address': grpc_address,
      # Only the scheme marker is removed; the rest of the path is untouched.
      'path': ckpt_path.removeprefix('yt:'),
      'timeout': '1h',
  }


def _get_default_kvstore(ckpt_path: str):
  """Returns a kvstore spec dict pairing `DEFAULT_DRIVER` with `ckpt_path`."""
  spec = dict(driver=DEFAULT_DRIVER, path=ckpt_path)
  return spec


def build_kvstore_tspec(
directory: str,
name: Optional[str] = None,
Expand All @@ -79,15 +101,19 @@ def build_kvstore_tspec(
Returns:
A Tensorstore KvStore spec in dictionary form.
"""
default_driver = DEFAULT_DRIVER
# Normalize path to exclude trailing '/'. In GCS path case, we will need to
# fix the path prefix to add back the stripped '/'.
directory = os.path.normpath(directory).replace('gs:/', 'gs://')

directory = os.path.normpath(directory)\
.replace('gs:/', 'gs://')\
.replace('yt:/', 'yt://')
is_gcs_path = directory.startswith('gs://')
is_yt_path = directory.startswith('yt://')
is_special_path = is_gcs_path or is_yt_path
kv_spec = {}

if use_ocdbt:
if not is_gcs_path and not os.path.isabs(directory):
if not is_special_path and not os.path.isabs(directory):
raise ValueError(f'Checkpoint path should be absolute. Got {directory}')
if process_id is not None:
process_id = str(process_id)
Expand All @@ -99,11 +125,14 @@ def build_kvstore_tspec(
directory = os.path.join(
directory, f'{PROCESS_SUBDIR_PREFIX}{process_id}'
)
base_driver_spec = (
directory
if is_gcs_path
else {'driver': default_driver, 'path': str(directory)}
)

if is_gcs_path:
base_driver_spec = directory
elif is_yt_path:
base_driver_spec = _get_kvstore_for_yt(directory)
else:
base_driver_spec = _get_default_kvstore(str(directory))

kv_spec.update({
'driver': 'ocdbt',
'base': base_driver_spec,
Expand All @@ -126,10 +155,13 @@ def build_kvstore_tspec(
path = directory
else:
path = os.path.join(directory, name)

if is_gcs_path:
kv_spec = _get_kvstore_for_gcs(path)
elif is_yt_path:
kv_spec = _get_kvstore_for_yt(path)
else:
kv_spec = {'driver': default_driver, 'path': path}
kv_spec = _get_default_kvstore(path)

return kv_spec

Expand Down
18 changes: 14 additions & 4 deletions checkpoint/orbax/checkpoint/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,19 @@ def block_until_ready(self):
)


def _transfer_shard_to_host(shard: jax.Shard) -> jax.Array:
def _transfer_shard_to_host(
shard: jax.Shard, enable_pinned_host_transfer: bool
) -> jax.Array:
"""Asynchronously transfers a shard to host memory. Does not block."""
data = shard.data
has_pinned_host = any(
m.kind == 'pinned_host' for m in shard.device.addressable_memories()
)
if jax._src.config.enable_memories.value and has_pinned_host: # pylint: disable=protected-access
if (
enable_pinned_host_transfer
and has_pinned_host
and jax._src.config.enable_memories.value # pylint: disable=protected-access
):
# If available, transfer to pinned host memory
sharding = jax.sharding.SingleDeviceSharding(
shard.device, memory_kind='pinned_host'
Expand All @@ -293,12 +299,16 @@ def get_shards_for_host_transfer(
]


def transfer_array_to_host(arr: jax.Array, replica_id: int) -> Shards:
def transfer_array_to_host(
arr: jax.Array, replica_id: int, *, enable_pinned_host_transfer: bool = True
) -> Shards:
"""Transfers a jax.Array to host memory."""
shard_data = []
dedup_shards = get_shards_for_host_transfer(arr, replica_id)
for shard in dedup_shards:
shard_data.append(_transfer_shard_to_host(shard))
shard_data.append(
_transfer_shard_to_host(shard, enable_pinned_host_transfer)
)

return Shards(
[
Expand Down
12 changes: 10 additions & 2 deletions checkpoint/orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,15 @@ class SaveArgs:
specified chunk_byte_size. Both the write_chunk_shape and read_chunk_shape
are automatically set to the chosen shape. This uses a greedy algorithm that
prioritizes splitting the largest dimensions first.
enable_pinned_host_transfer:
True by default. If False, disables transfer to pinned host when copying
from device to host, regardless of the presence of pinned host memory.
"""

aggregate: bool = False
dtype: Optional[jnp.dtype] = None
chunk_byte_size: Optional[int] = None
enable_pinned_host_transfer: bool = True

def __post_init__(self):
if self.aggregate:
Expand Down Expand Up @@ -1315,8 +1319,12 @@ async def serialize(

# Start D2H transfer in parallel for each array.
host_shards = [
serialization.transfer_array_to_host(value, self._get_replica_id(value))
for value in values
serialization.transfer_array_to_host(
value,
self._get_replica_id(value),
enable_pinned_host_transfer=arg.enable_pinned_host_transfer,
)
for value, arg in zip(values, args)
]
jax.tree.map(lambda x: x.block_until_ready(), host_shards)

Expand Down