84 changes: 84 additions & 0 deletions docs/source/reference/llms.rst
@@ -46,6 +46,23 @@ When fine-tuning an LLM using TorchRL, the environment is a crucial component of the pipeline, alongside the
policy and collector. Environments manage operations that are not handled by the LLM itself, such as interacting with
tools, loading prompts from datasets, computing rewards (when necessary), and formatting data.

Therefore, the fundamental structure of an LLM post-training pipeline is:

- A policy that wraps the LLM, and only the LLM

- An environment that handles the world around the LLM:

  - Loading data (through :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer`)
  - Formatting data (through :class:`~torchrl.envs.llm.transforms.TemplateTransform`)
  - Executing tools (through :class:`~torchrl.envs.llm.transforms.PythonInterpreter`)
  - Computing rewards online, if needed (through :class:`~torchrl.envs.llm.transforms.KLRewardTransform`)

- A data collector that takes the policy (the LLM) and the environment, and handles the inference part of the pipeline:

  - Running resets and steps, and gathering actions;
  - Yielding the data in a consistent format, or populating a buffer;
  - Updating the policy weights (through :class:`~torchrl.collectors.WeightUpdaterBase` classes)

- A replay buffer that stores the data collected using the collector

- A loss module that takes the LLM's output and returns a loss value (for example, through :class:`~torchrl.objectives.llm.GRPOLoss`)

These elements are presented in the GRPO scripts in the `sota-implementations/llm` directory, and sketched in the example below.
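
The snippet below is a minimal sketch of how these pieces fit together, assuming ``model`` and ``tokenizer`` are a
Hugging Face model and tokenizer. The wrapper and collector names (:class:`~torchrl.modules.llm.TransformersWrapper`,
:class:`~torchrl.collectors.llm.LLMCollector`) and the keyword arguments are illustrative and may differ from the
actual API; the GRPO scripts remain the reference implementation.

>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.data import ReplayBuffer
>>> from torchrl.envs.llm import ChatEnv
>>> from torchrl.modules.llm import TransformersWrapper
>>> from torchrl.objectives.llm import GRPOLoss
>>> policy = TransformersWrapper(model)  # the policy wraps the LLM, and only the LLM
>>> env = ChatEnv(tokenizer=tokenizer, batch_size=[1])  # the world around the LLM
>>> collector = LLMCollector(env, policy=policy)  # runs resets and steps, yields data
>>> buffer = ReplayBuffer()  # stores the collected trajectories
>>> loss_fn = GRPOLoss(actor_network=policy)  # maps the LLM's output to a loss
>>> for data in collector:
...     buffer.extend(data)
...     sample = buffer.sample(batch_size=16)
...     loss_vals = loss_fn(sample)  # a tensordict of loss terms
...     loss_vals["loss_objective"].backward()  # key name may differ; follow with optimizer step and weight update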

The design of environments in TorchRL allows for flexibility and modularity. By framing tasks as environments, users can
easily extend or modify existing environments using transforms. This approach enables the isolation of individual
components within specific :class:`~torchrl.envs.EnvBase` or :class:`~torchrl.envs.Transform` subclasses, making it
@@ -87,6 +104,73 @@ These components can be used to create customized environments tailored to specific use cases.
Transforms
~~~~~~~~~~

Transforms are used to modify the data before it is passed to the LLM.
Tools are usually implemented as transforms, and appended to a base environment
such as :class:`~torchrl.envs.llm.ChatEnv`.

An example of a tool transform is the :class:`~torchrl.envs.llm.transforms.PythonInterpreter` transform, which is used
to execute Python code in the context of the LLM.

>>> from torchrl.envs.llm.transforms import PythonInterpreter
>>> from torchrl.envs.llm import ChatEnv
>>> from tensordict import TensorDict, set_list_to_stack
>>> from transformers import AutoTokenizer
>>> from pprint import pprint
>>> set_list_to_stack(True).set()
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> base_env = ChatEnv(
... tokenizer=tokenizer,
... system_prompt="You are an assistant that can execute Python code. Decorate your code with ```python``` tags.",
... user_role="user",
... system_role="system",
... batch_size=[1],
... )
>>> env = base_env.append_transform(PythonInterpreter())
>>> env.set_seed(0)
>>> # Pass the reset data - the prompt - to the environment
>>> reset_data = env.reset(TensorDict(
... text="Let's write a Python function that returns the square of a number.",
... batch_size=[1])
... )
>>> # Simulate an action - i.e., a response from the LLM (as if we were an LLM)
>>> action = """Here is a block of code to be executed in python:
... ```python
... def square(x):
...     return x * x
... print('testing the square function with input 2:', square(2))
... ```
... <|im_end|>
... """
>>> step_data = reset_data.set("text_response", [action])
>>> s, s_ = env.step_and_maybe_reset(step_data)
>>> # The history is a stack of chat messages.
>>> # The python interpreter transform has executed the code in the last message.
>>> pprint(s_["history"].apply_chat_template(tokenizer=tokenizer))
['<|im_start|>system\n'
 'You are an assistant that can execute Python code. Decorate your code with '
 '```python``` tags.<|im_end|>\n'
 '<|im_start|>user\n'
 "Let's write a Python function that returns the square of a "
 'number.<|im_end|>\n'
 '<|im_start|>assistant\n'
 'Here is a block of code to be executed in python:\n'
 '```python\n'
 'def square(x):\n'
 '    return x * x\n'
 "print('testing the square function with input 2:', square(2))\n"
 '```<|im_end|>\n'
 '<|im_start|>user\n'
 '<tool_response>\n'
 'Code block 1 executed successfully:\n'
 'testing the square function with input 2: 4\n'
 '\n'
 '</tool_response><|im_end|>\n'
 '<|im_start|>assistant\n']

Similarly, environments that load data from a dataset are just special instances of the :class:`~torchrl.envs.llm.ChatEnv`
augmented with a :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer` transform (and some dedicated reward-parsing
transforms), as sketched below.
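
A minimal sketch of this pattern is shown below. Here ``dataloader`` stands for a hypothetical iterable yielding
prompt entries, and the ``dataloader`` keyword of :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer` is an
assumption for illustration; the exact constructor signature may differ.

>>> from torchrl.envs.llm import ChatEnv
>>> from torchrl.envs.llm.transforms import DataLoadingPrimer
>>> env = ChatEnv(tokenizer=tokenizer, batch_size=[1]).append_transform(
...     DataLoadingPrimer(dataloader=dataloader)
... )
>>> reset_data = env.reset()  # the primer injects a prompt from the dataloader at reset time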

.. currentmodule:: torchrl.envs.llm.transforms

.. autosummary::
4 changes: 4 additions & 0 deletions torchrl/envs/llm/chat.py
@@ -123,6 +123,10 @@ def __init__(
    ):
        if batch_size is None:
            batch_size = (1,)
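        # Also accept an int (wrapped into a 1-element tuple) or a list (cast to torch.Size)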
        if isinstance(batch_size, int):
            batch_size = (batch_size,)
        if isinstance(batch_size, list):
            batch_size = torch.Size(batch_size)
        if batch_size == ():
            raise ValueError(f"{type(self).__name__} must have at least one dimension")
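
With this normalization, integer and list batch sizes should behave like their tuple equivalents. A quick sketch of
the expected behavior (assuming a ``tokenizer`` is available; ``batch_size`` is exposed by :class:`~torchrl.envs.EnvBase`):

>>> ChatEnv(tokenizer=tokenizer, batch_size=1).batch_size
torch.Size([1])
>>> ChatEnv(tokenizer=tokenizer, batch_size=[2]).batch_size
torch.Size([2])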
