checkpoint/CHANGELOG.md (3 additions, 0 deletions)
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+### Added
+- Add a `SaveArgs` option that allows disabling pinned host transfer on a per-array basis.
+
 ## [0.7.0] - 2024-10-07
 
 ### Removed
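For context, here is a usage sketch of the new option (not part of the diff). It assumes the exported `SaveArgs` name and the legacy `save_args=` keyword on `PyTreeCheckpointer.save`; the checkpoint path and pytree contents are illustrative.

```python
# Usage sketch (illustrative): disable pinned host transfer for every array
# in a pytree save. The checkpoint path and tree contents are made up.
import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

tree = {'w': jnp.ones((1024, 1024)), 'b': jnp.zeros((1024,))}

save_args = jax.tree.map(
    lambda _: ocp.SaveArgs(enable_pinned_host_transfer=False), tree
)
ocp.PyTreeCheckpointer().save('/tmp/ckpt', tree, save_args=save_args)
```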
checkpoint/orbax/checkpoint/serialization.py (14 additions, 4 deletions)
@@ -265,13 +265,19 @@ def block_until_ready(self):
   )
 
 
-def _transfer_shard_to_host(shard: jax.Shard) -> jax.Array:
+def _transfer_shard_to_host(
+    shard: jax.Shard, enable_pinned_host_transfer: bool
+) -> jax.Array:
   """Asynchronously transfers a shard to host memory. Does not block."""
   data = shard.data
   has_pinned_host = any(
       m.kind == 'pinned_host' for m in shard.device.addressable_memories()
   )
-  if jax._src.config.enable_memories.value and has_pinned_host:  # pylint: disable=protected-access
+  if (
+      enable_pinned_host_transfer
+      and has_pinned_host
+      and jax._src.config.enable_memories.value  # pylint: disable=protected-access
+  ):
     # If available, transfer to pinned host memory
     sharding = jax.sharding.SingleDeviceSharding(
         shard.device, memory_kind='pinned_host'
@@ -293,12 +299,16 @@ def get_shards_for_host_transfer(
   ]
 
 
-def transfer_array_to_host(arr: jax.Array, replica_id: int) -> Shards:
+def transfer_array_to_host(
+    arr: jax.Array, replica_id: int, *, enable_pinned_host_transfer: bool = True
+) -> Shards:
   """Transfers a jax.Array to host memory."""
   shard_data = []
   dedup_shards = get_shards_for_host_transfer(arr, replica_id)
   for shard in dedup_shards:
-    shard_data.append(_transfer_shard_to_host(shard))
+    shard_data.append(
+        _transfer_shard_to_host(shard, enable_pinned_host_transfer)
+    )
 
   return Shards(
       [
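The flag threads straight through `transfer_array_to_host`. A minimal sketch of calling the patched helper directly, assuming the `orbax.checkpoint.serialization` module path implied by the diff and the `Shards.block_until_ready` method visible in the hunk context; most callers should go through `SaveArgs` instead.

```python
# Sketch (illustrative): exercising the new keyword-only flag on the
# internal helper with a single-device array.
import jax.numpy as jnp
from orbax.checkpoint import serialization

arr = jnp.arange(1024)  # any jax.Array works; single-device sharding here
host_shards = serialization.transfer_array_to_host(
    arr, 0, enable_pinned_host_transfer=False  # force plain host memory
)
host_shards.block_until_ready()  # join the asynchronous D2H copies
```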
checkpoint/orbax/checkpoint/type_handlers.py (10 additions, 2 deletions)
@@ -238,11 +238,15 @@ class SaveArgs:
     specified chunk_byte_size. Both the write_chunk_shape and read_chunk_shape
     are automatically set to the chosen shape. This uses a greedy algorithm that
     prioritizes splitting the largest dimensions first.
+  enable_pinned_host_transfer:
+    True by default. If False, disables transfer to pinned host when copying
+    from device to host, regardless of the presence of pinned host memory.
   """
 
   aggregate: bool = False
   dtype: Optional[jnp.dtype] = None
   chunk_byte_size: Optional[int] = None
+  enable_pinned_host_transfer: bool = True
 
   def __post_init__(self):
     if self.aggregate:
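Because the flag lives on `SaveArgs`, it can differ per leaf. A minimal sketch, assuming the top-level `SaveArgs` export; the size threshold is arbitrary, not a recommendation.

```python
# Sketch (illustrative): large arrays skip pinned host staging, small ones
# keep the default behavior.
import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

tree = {
    'embeddings': jnp.zeros((100_000, 512)),  # large: plain host transfer
    'scale': jnp.ones(()),                    # small: default (pinned allowed)
}
save_args = jax.tree.map(
    lambda x: ocp.SaveArgs(enable_pinned_host_transfer=x.size < 1_000_000),
    tree,
)
```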
@@ -1321,8 +1325,12 @@ async def serialize(
 
     # Start D2H transfer in parallel for each array.
     host_shards = [
-        serialization.transfer_array_to_host(value, self._get_replica_id(value))
-        for value in values
+        serialization.transfer_array_to_host(
+            value,
+            self._get_replica_id(value),
+            enable_pinned_host_transfer=arg.enable_pinned_host_transfer,
+        )
+        for value, arg in zip(values, args)
     ]
     jax.tree.map(lambda x: x.block_until_ready(), host_shards)
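One consequence of the new `zip(values, args)` pairing: `serialize` now reads exactly one `SaveArgs` per array, matched positionally, and plain `zip` truncates silently on a length mismatch. A hedged sketch of that contract, with illustrative values:

```python
# Sketch (illustrative): values[i] is transferred with args[i], so the two
# sequences must be the same length; zip() would truncate silently otherwise.
import jax.numpy as jnp
import orbax.checkpoint as ocp

values = [jnp.ones((4, 4)), jnp.zeros((8,))]
args = [ocp.SaveArgs(enable_pinned_host_transfer=False), ocp.SaveArgs()]
assert len(values) == len(args), 'exactly one SaveArgs per array'
```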