Merged
74 changes: 24 additions & 50 deletions docs/source/reference/collectors.rst
@@ -118,75 +118,49 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
Policy copy decision tree in Collectors.

Weight Synchronization in Distributed Environments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--------------------------------------------------
Review comment from @mikaylagawarecki (Contributor), Apr 24, 2025:


I read the diagram above as:

  1. CollectorServer: the main thread of the RayCollector
  2. Collector Worker {i}: a remote DataCollector

If this read is correct, it might sometimes make sense to have the receiver on the collector worker rather than on the collector server.
For example, if the number of remote workers is sufficiently high, a collector worker might not be colocated with the collector server; in that case it might not make sense to pass the weights "two hops" to get to the worker.

Separate question: from the diagram it looks like the collector server chooses when to pull from the parameter server and then "forcefully pushes" to all the workers at once. Is this design intentional? (e.g., is the purpose to batch up workers under different collector servers and update them in batches?)


In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.

Local and Remote Weight Updaters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Sending and receiving model weights with WeightUpdaters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.WeightUpdateReceiverBase`
and :class:`~torchrl.collectors.WeightUpdateSenderBase`. These base classes provide a structured interface for
The weight synchronization process is facilitated by one dedicated extension point:
:class:`~torchrl.collectors.WeightUpdaterBase`. This base class provides a structured interface for
implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.

- :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for updating the policy weights on
the local inference worker. It is particularly useful when the training and inference occur on the same machine but on
different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with
situations where the server decides when to update the worker policies).
- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with
the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of
devices or processes.
:class:`~torchrl.collectors.WeightUpdaterBase` handles the distribution of policy weights to the local policy or to
remote inference workers, as well as formatting and gathering the weights from a server when necessary.
Every collector -- server or worker -- should have a `WeightUpdaterBase` instance to handle the
weight synchronization with the policy.
Even the simplest collectors use a :class:`~torchrl.collectors.VanillaWeightUpdater` instance to update the policy
state-dict (assuming it is a :class:`~torch.nn.Module` instance).
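
As a minimal sketch (the environment name and layer shapes below are assumptions made for illustration, not part of
the original documentation), a single-process collector needs no explicit updater: calling
``update_policy_weights_()`` lets the default :class:`~torchrl.collectors.VanillaWeightUpdater` copy the trained
state-dict into the collector's policy.

.. code-block:: python

    from torch import nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    # Pendulum-v1 has a 3-dimensional observation and a 1-dimensional action (assumption).
    policy = TensorDictModule(
        nn.Sequential(nn.Linear(3, 1), nn.Tanh()),
        in_keys=["observation"],
        out_keys=["action"],
    )
    collector = SyncDataCollector(
        GymEnv("Pendulum-v1"), policy, frames_per_batch=64, total_frames=128
    )
    for data in collector:
        # ... an optimizer step updating `policy` would normally go here ...
        # No `weight_updater` was passed: the default VanillaWeightUpdater simply
        # copies the trained state-dict into the collector's copy of the policy.
        collector.update_policy_weights_()
    collector.shutdown()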

Extending the Updater Classes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Extending the Updater Class
~~~~~~~~~~~~~~~~~~~~~~~~~~~

To accommodate diverse use cases, the API allows users to extend the updater class with custom implementations.
The goal is to be able to customize the weight sync strategy while leaving the collector and policy implementation
untouched.
This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware
setups.
By implementing the abstract methods of this base class, users can define how weights are retrieved,
transformed, and applied, ensuring seamless integration with their existing infrastructure.
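
As a sketch of what such an extension might look like (the hook names below are assumptions made for illustration;
refer to the :class:`~torchrl.collectors.WeightUpdaterBase` API reference for the actual abstract methods), a custom
updater typically (a) fetches the latest weights from the trainer, (b) maps or casts them, and (c) pushes them to
each worker:

.. code-block:: python

    from torchrl.collectors import WeightUpdaterBase


    class CPUCastingWeightUpdater(WeightUpdaterBase):
        """Illustrative updater that casts weights to CPU before sending them.

        The hook names used here (``_get_server_weights``, ``_maybe_map_weights``,
        ``all_worker_ids``, ``_sync_weights_with_worker``) are assumptions for the
        sake of the sketch; check the base-class documentation for the exact API.
        """

        def __init__(self, policy_weights, num_workers):
            self.policy_weights = policy_weights  # trainer-side weights, e.g. a TensorDict
            self.num_workers = num_workers

        def _get_server_weights(self):
            # (a) fetch the latest trained weights
            return self.policy_weights

        def _maybe_map_weights(self, server_weights):
            # (b) format them so they can cross a process boundary
            return server_weights.to("cpu")

        def all_worker_ids(self):
            return list(range(self.num_workers))

        def _sync_weights_with_worker(self, worker_id, server_weights):
            # (c) the transport depends on the collector type (pipes, Ray, RPC, ...)
            raise NotImplementedError("transport left out of this sketch")

Such an updater is handed to the collector through the ``weight_updater`` keyword argument, as done in the
``mp_collector_mps.py`` example and the tests modified in this PR.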

Default Implementations
~~~~~~~~~~~~~~~~~~~~~~~

For common scenarios, the API provides default implementations of these updaters, such as
:class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`,
:class:`~torchrl.collectors.RayWeightUpdateSender`, :class:`~torchrl.collectors.RPCWeightUpdateSender`, and
:class:`~torchrl.collectors.DistributedWeightUpdateSender`.
These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
distributed systems.

Practical Considerations
~~~~~~~~~~~~~~~~~~~~~~~~

When designing a system that leverages this API, consider the following:

- Network Latency: In distributed environments, network latency can impact the speed of weight updates. Ensure that your
implementation accounts for potential delays and optimizes data transfer where possible.
- Consistency: Ensure that all workers receive the updated weights in a timely manner to maintain consistency across
the system. This is particularly important in reinforcement learning scenarios where stale weights can lead to
suboptimal policy performance.
- Scalability: As your system grows, the weight synchronization mechanism should scale efficiently. Consider the
overhead of broadcasting weights to a large number of workers and optimize the process to minimize bottlenecks.

By leveraging the API, users can achieve robust and efficient weight synchronization across a variety of deployment
scenarios, ensuring that their policies remain up-to-date and performant.
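
To balance the latency and consistency points above, a common pattern is to refresh the workers' weights on a fixed
cadence rather than after every optimizer step. A sketch under the same illustrative assumptions as before (the
environment name, shapes, and update interval are placeholders):

.. code-block:: python

    from torch import nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.envs.libs.gym import GymEnv


    def make_env():
        return GymEnv("Pendulum-v1")


    if __name__ == "__main__":  # guard required for multiprocessed collectors
        policy = TensorDictModule(
            nn.Sequential(nn.Linear(3, 1), nn.Tanh()),
            in_keys=["observation"],
            out_keys=["action"],
        )
        collector = MultiSyncDataCollector(
            [make_env, make_env],
            policy,
            frames_per_batch=64,
            total_frames=320,
        )
        update_interval = 2  # refresh workers every few batches to limit broadcast overhead
        for i, data in enumerate(collector):
            # ... optimizer step on the trainer-side policy would go here ...
            if i % update_interval == 0:
                collector.update_policy_weights_()  # pushes weights to all workers
        collector.shutdown()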

.. currentmodule:: torchrl.collectors

.. autosummary::
:toctree: generated/
:template: rl_template.rst

WeightUpdateReceiverBase
WeightUpdateSenderBase
VanillaLocalWeightUpdater
MultiProcessedRemoteWeightUpdate
RayWeightUpdateSender
DistributedWeightUpdateSender
RPCWeightUpdateSender
WeightUpdaterBase
VanillaWeightUpdater
MultiProcessedWeightUpdater
RayWeightUpdater
DistributedWeightUpdater
RPCWeightUpdater

Collectors and replay buffers interoperability
----------------------------------------------
6 changes: 3 additions & 3 deletions examples/collectors/mp_collector_mps.py
@@ -45,12 +45,12 @@ class is necessary because MPS tensors cannot be sent over a pipe due to seriali
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import MultiSyncDataCollector, WeightUpdateSenderBase
from torchrl.collectors import MultiSyncDataCollector, WeightUpdaterBase

from torchrl.envs.libs.gym import GymEnv


class MPSWeightUpdaterBase(WeightUpdateSenderBase):
class MPSWeightUpdaterBase(WeightUpdaterBase):
def __init__(self, policy_weights, num_workers):
# Weights are on mps device, which cannot be shared
self.policy_weights = policy_weights.data
@@ -101,7 +101,7 @@ def policy_factory(device=device):
reset_at_each_iter=False,
device=device,
storing_device="cpu",
weight_update_sender=MPSWeightUpdaterBase(policy_weights, 2),
weight_updater=MPSWeightUpdaterBase(policy_weights, 2),
# use_buffers=False,
# cat_results="stack",
)
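
The collapsed part of this example is where ``policy_weights`` comes from; one plausible pattern (an assumption made
for illustration, not necessarily the code hidden in the fold) is to build the policy once and extract its parameters
as a TensorDict that the updater can later map to CPU and send:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn

device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

def make_policy(device=device):
    # Each worker rebuilds the module locally from a factory like this one,
    # so only CPU-mapped weights ever need to cross a pipe.
    return TensorDictModule(
        nn.Linear(3, 1, device=device), in_keys=["observation"], out_keys=["action"]
    )

policy = make_policy()
# Detached TensorDict view of the parameters; the updater maps it to CPU before sending.
policy_weights = TensorDict.from_module(policy).data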
10 changes: 3 additions & 7 deletions test/test_collector.py
@@ -39,11 +39,7 @@
prod,
seed_generator,
)
from torchrl.collectors import (
aSyncDataCollector,
SyncDataCollector,
WeightUpdateSenderBase,
)
from torchrl.collectors import aSyncDataCollector, SyncDataCollector, WeightUpdaterBase
from torchrl.collectors.collectors import (
_Interruptor,
MultiaSyncDataCollector,
@@ -3489,7 +3485,7 @@ def __deepcopy_error__(*args, **kwargs):


class TestPolicyFactory:
class MPSWeightUpdaterBase(WeightUpdateSenderBase):
class MPSWeightUpdaterBase(WeightUpdaterBase):
def __init__(self, policy_weights, num_workers):
# Weights are on mps device, which cannot be shared
self.policy_weights = policy_weights.data
@@ -3533,7 +3529,7 @@ def test_weight_update(self):
reset_at_each_iter=False,
device=device,
storing_device="cpu",
weight_update_sender=self.MPSWeightUpdaterBase(policy_weights, 2),
weight_updater=self.MPSWeightUpdaterBase(policy_weights, 2),
)

collector.update_policy_weights_()
6 changes: 2 additions & 4 deletions torchrl/collectors/__init__.py
@@ -16,14 +16,12 @@
MultiProcessedWeightUpdate,
RayWeightUpdater,
VanillaWeightUpdater,
WeightUpdateReceiverBase,
WeightUpdateSenderBase,
WeightUpdaterBase,
)

__all__ = [
"RandomPolicy",
"WeightUpdateReceiverBase",
"WeightUpdateSenderBase",
"WeightUpdaterBase",
"VanillaWeightUpdater",
"RayWeightUpdater",
"MultiProcessedWeightUpdate",