9 changes: 8 additions & 1 deletion docs/source/reference/collectors.rst
@@ -159,8 +159,15 @@ transformed, and applied, ensuring seamless integration with their existing infr
VanillaWeightUpdater
MultiProcessedWeightUpdater
RayWeightUpdater
DistributedWeightUpdater

.. currentmodule:: torchrl.collectors.distributed

.. autosummary::
:toctree: generated/
:template: rl_template.rst

RPCWeightUpdater
DistributedWeightUpdater

Collectors and replay buffers interoperability
----------------------------------------------
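For reference, a minimal import sketch matching the documented layout above; the names and module paths are taken from the autosummary entries and from the package exports updated later in this diff:

    from torchrl.collectors import (
        MultiProcessedWeightUpdater,
        RayWeightUpdater,
        VanillaWeightUpdater,
    )
    from torchrl.collectors.distributed import (
        DistributedWeightUpdater,
        RPCWeightUpdater,
    )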
4 changes: 2 additions & 2 deletions torchrl/collectors/__init__.py
@@ -13,7 +13,7 @@
SyncDataCollector,
)
from .weight_update import (
MultiProcessedWeightUpdate,
MultiProcessedWeightUpdater,
RayWeightUpdater,
VanillaWeightUpdater,
WeightUpdaterBase,
@@ -24,7 +24,7 @@
"WeightUpdaterBase",
"VanillaWeightUpdater",
"RayWeightUpdater",
"MultiProcessedWeightUpdate",
"MultiProcessedWeightUpdater",
"aSyncDataCollector",
"DataCollectorBase",
"MultiaSyncDataCollector",
4 changes: 2 additions & 2 deletions torchrl/collectors/collectors.py
@@ -51,7 +51,7 @@
)
from torchrl.collectors.utils import split_trajectories
from torchrl.collectors.weight_update import (
MultiProcessedWeightUpdate,
MultiProcessedWeightUpdater,
VanillaWeightUpdater,
WeightUpdaterBase,
)
@@ -2010,7 +2010,7 @@ def __init__(
self._policy_weights_dict[policy_device] = weights
self._get_weights_fn = get_weights_fn
if weight_updater is None:
weight_updater = MultiProcessedWeightUpdate(
weight_updater = MultiProcessedWeightUpdater(
get_server_weights=self._get_weights_fn,
policy_weights=self._policy_weights_dict,
)
4 changes: 2 additions & 2 deletions torchrl/collectors/distributed/__init__.py
@@ -9,7 +9,7 @@
DistributedWeightUpdater,
)
from .ray import RayCollector
from .rpc import RPCDataCollector
from .rpc import RPCDataCollector, RPCWeightUpdater
from .sync import DistributedSyncDataCollector
from .utils import submitit_delayed_launcher

@@ -19,7 +19,7 @@
"DistributedWeightUpdater",
"DistributedSyncDataCollector",
"RPCDataCollector",
"RPCDataCollector",
"RPCWeightUpdater",
"RayCollector",
"submitit_delayed_launcher",
]
4 changes: 2 additions & 2 deletions torchrl/collectors/distributed/rpc.py
@@ -412,7 +412,7 @@ def __init__(
)
self._init()
if weight_updater is None:
weight_updater = RPCWeightUpdaterBase(
weight_updater = RPCWeightUpdater(
collector_infos=self.collector_infos,
collector_class=self.collector_class,
collector_rrefs=self.collector_rrefs,
@@ -810,7 +810,7 @@ def shutdown(self, timeout: float | None = None) -> None:
self._shutdown = True


class RPCWeightUpdaterBase(WeightUpdaterBase):
class RPCWeightUpdater(WeightUpdaterBase):
"""A remote weight updater for synchronizing policy weights across remote workers using RPC.

The `RPCWeightUpdater` class provides a mechanism for updating the weights of a policy
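A minimal usage sketch of the renamed class, mirroring the fallback shown in __init__ above: when no weight_updater is passed, the collector builds an RPCWeightUpdater itself and update_policy_weights_() delegates to it. The environment factory, policy, training step, and numeric arguments below are illustrative placeholders, not part of this diff:

    from torchrl.collectors.distributed import RPCDataCollector

    collector = RPCDataCollector(
        [make_env] * num_workers,  # placeholder env constructors
        policy,                    # placeholder policy module
        frames_per_batch=200,
        total_frames=2_000,
        # weight_updater=None -> an RPCWeightUpdater is constructed internally
    )
    for batch in collector:
        train_step(policy, batch)           # placeholder optimization step
        collector.update_policy_weights_()  # pushed to workers via the RPCWeightUpdater
    collector.shutdown()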
13 changes: 12 additions & 1 deletion torchrl/collectors/llm/weight_update/vllm.py
@@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util

import torch
import torch.cuda
import torch.distributed
@@ -13,7 +15,16 @@

from torchrl.collectors import WeightUpdaterBase
from torchrl.modules.llm.backends.vllm import stateless_init_process_group
from vllm.utils import get_open_port

_has_vllm = importlib.util.find_spec("vllm") is not None
if _has_vllm:
from vllm.utils import get_open_port
else:

def get_open_port(): # noqa: D103
raise ImportError(
"vllm is not installed. Please install it with `pip install vllm`."
)


class vLLMUpdater(WeightUpdaterBase):
2 changes: 1 addition & 1 deletion torchrl/collectors/weight_update.py
@@ -206,7 +206,7 @@ def _sync_weights_with_worker(
self.policy_weights.update_(server_weights)


class MultiProcessedWeightUpdate(WeightUpdaterBase):
class MultiProcessedWeightUpdater(WeightUpdaterBase):
"""A remote weight updater for synchronizing policy weights across multiple processes or devices.

The `MultiProcessedWeightUpdater` class provides a mechanism for updating the weights
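A minimal construction sketch for the renamed updater, using only the keyword arguments that appear in the collectors.py hunk above (get_server_weights, policy_weights); the policy module and worker devices are illustrative placeholders:

    import torch
    from tensordict import TensorDict
    from torchrl.collectors import MultiProcessedWeightUpdater

    policy = torch.nn.Linear(4, 2)           # placeholder policy
    worker_devices = [torch.device("cpu")]   # placeholder worker devices
    server_weights = TensorDict.from_module(policy)
    policy_weights = {d: server_weights.to(d) for d in worker_devices}

    updater = MultiProcessedWeightUpdater(
        get_server_weights=lambda: TensorDict.from_module(policy),
        policy_weights=policy_weights,
    )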
5 changes: 2 additions & 3 deletions torchrl/envs/llm/chat.py
@@ -7,7 +7,6 @@
from typing import Any, Callable, Literal

import torch
import transformers
from tensordict import lazy_stack, TensorDict, TensorDictBase
from torch.utils.data import DataLoader
from torchrl.data import Composite, NonTensor
@@ -116,7 +115,7 @@ def __init__(
batch_size: tuple | torch.Size | None = None,
system_prompt: str | None = None,
apply_template: bool | None = None,
tokenizer: transformers.AutoTokenizer | None = None,
tokenizer: transformers.AutoTokenizer | None = None, # noqa: F821
template_kwargs: dict[str, Any] | None = None,
system_role: str = "system",
user_role: str = "user",
@@ -309,7 +308,7 @@ def __init__(
batch_size_dl: int = 1,
seed: int | None = None,
group_repeats: bool = False,
tokenizer: transformers.AutoTokenizer | None = None,
tokenizer: transformers.AutoTokenizer | None = None, # noqa: F821
device: torch.device | None = None,
template_kwargs: dict[str, Any] | None = None,
apply_template: bool | None = None,
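These two chat.py hunks rely on postponed annotation evaluation: with `from __future__ import annotations` the `transformers.AutoTokenizer` annotation stays a string at runtime, so the module-level `transformers` import can be dropped and only the linter's undefined-name warning needs a `# noqa: F821`. A small sketch of the pattern with a hypothetical helper:

    from __future__ import annotations  # annotations are never evaluated at runtime


    def encode(
        text: str,
        tokenizer: transformers.AutoTokenizer | None = None,  # noqa: F821
    ) -> list[int]:
        # Hypothetical helper: transformers is only needed when a tokenizer is given.
        if tokenizer is None:
            return [ord(c) for c in text]
        return tokenizer(text)["input_ids"]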