
Conversation

@shenxiangzhuang
Contributor

@shenxiangzhuang shenxiangzhuang commented Apr 30, 2025

Close #14

Summary by CodeRabbit

  • New Features

    • Training progress is now tracked by steps instead of episodes, allowing for more precise control over training duration.
    • Improved logging of episodic rewards and training progress.
  • Improvements

    • Enhanced training stability with gradient clipping and a switch to the Adam optimizer.
    • Epsilon decay is now more gradual, supporting better exploration during training.
    • Environment handling is more robust, with improved episode tracking and automatic resets.
    • Simplified policy network outputs Q-values for all actions in a single pass, improving efficiency.
    • Updated action selection to choose the best action based on computed Q-values in one forward pass.
  • Documentation

    • Configuration options now include clearer descriptions and updated parameter names.

@shenxiangzhuang shenxiangzhuang requested a review from Copilot April 30, 2025 08:41
@shenxiangzhuang shenxiangzhuang self-assigned this Apr 30, 2025
@coderabbitai

coderabbitai bot commented Apr 30, 2025

"""

Walkthrough

The changes refactor the SARSA reinforcement learning implementation by simplifying the PolicyNet output to return Q-values for all actions, updating the agent's action selection and policy update logic to use these outputs, and introducing gradient clipping. The training loop is restructured from episode-based to step-based, with a cap on the total number of training steps. Environment handling is improved with new wrappers for episodic statistics and auto-resetting, and the optimizer is switched from RMSprop to Adam. Configuration dataclasses are updated for clarity and to reflect the new step-based training approach.

Changes

| File(s) | Change Summary |
| --- | --- |
| toyrl/sarsa.py | Refactored PolicyNet to output Q-values for all actions; removed action_dim parameter; updated Agent methods for action selection and policy update using new output; added gradient clipping; changed optimizer to Adam; restructured training loop to be step-based; improved environment handling with wrappers; updated logging; added _make_env method; updated config dataclasses to use max_training_steps and added docstrings. |
| tests/test_sarsa.py | Updated tests to match new PolicyNet signature and output shape; removed total_reward assertions; replaced num_episodes with max_training_steps in config tests; adjusted comments and assertions accordingly; no changes to control flow or error handling. |

Sequence Diagram(s)

sequenceDiagram
    participant Trainer as SarsaTrainer
    participant Agent
    participant PolicyNet
    participant Env as GymEnv

    Trainer->>Env: Reset environment (with wrappers)
    loop For each training step up to max_training_steps
        Agent->>PolicyNet: Compute Q-values for observation
        PolicyNet-->>Agent: Q-values for all actions
        Agent->>Agent: Select action (epsilon-greedy)
        Agent->>Env: Take action
        Env-->>Agent: Next observation, reward, done, info
        Agent->>PolicyNet: Compute Q-values for next observation
        PolicyNet-->>Agent: Next Q-values
        Agent->>Agent: Policy update with gradient clipping
        Trainer->>Trainer: Log reward and statistics
        alt Episode done
            Env->>Env: Auto-reset via wrapper
        end
    end

Assessment against linked issues

Objective Addressed Explanation
Refactor network architecture to take only state as input and output Q-values for all actions (#14)

Possibly related PRs

  • feat: sarsa #9: Refactored and simplified the original SARSA implementation, including changes to the PolicyNet output, action selection, policy update logic, training loop, optimizer, and environment handling.

Poem

In the gym where rabbits leap and bound,
SARSA’s code has turned around—
Q-values now in a single hop,
Training steps that never stop.
With Adam’s help and wrappers neat,
This bunny’s learning can’t be beat!
🐇✨
"""




Copilot AI left a comment


Pull Request Overview

This PR refactors the SARSA implementation to standardize the network architecture and transition the training loop from an episode-based to a global-step approach.

  • Removed the unnecessary action_dim parameter and adjusted PolicyNet accordingly.
  • Replaced the iterative action evaluation with a single forward pass using argmax for action selection.
  • Updated training loop structure, configuration settings, and optimizer from RMSprop to Adam.
Comments suppressed due to low confidence (1)

toyrl/sarsa.py:24

  • Removing the action_dim parameter alters the expected input to the network. Please verify that downstream components relying on the concatenated observation–action input are updated accordingly.
nn.Linear(self.env_dim, 128),

@codecov

codecov bot commented Apr 30, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.40%. Comparing base (097acdb) to head (016b1fb).
Report is 1 commit behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master      #20      +/-   ##
==========================================
+ Coverage   97.08%   97.40%   +0.32%     
==========================================
  Files           4        4              
  Lines         206      193      -13     
==========================================
- Hits          200      188      -12     
+ Misses          6        5       -1     

☔ View full report in Codecov by Sentry.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
toyrl/sarsa.py (2)

163-164: Rationale for optimizer change needed

The optimizer was changed from RMSprop to Adam without documentation of the reasoning.

 policy_net = PolicyNet(env_dim=env_dim, action_num=action_num)
-optimizer = optim.Adam(policy_net.parameters(), lr=config.train.learning_rate)
+# Using the Adam optimizer instead of RMSprop for its adaptive learning rate and momentum,
+# which generally lead to better convergence in reinforcement learning tasks.
+optimizer = optim.Adam(policy_net.parameters(), lr=config.train.learning_rate)

208-213: Episodic reward extraction needs clarification

The code overrides the reward value from the info dictionary but doesn't explain how this differs from the regular reward.

            if terminated or truncated:
                if info and "episode" in info:
                    reward = info["episode"]["r"]
+                    # Using the total episodic return from the RecordEpisodeStatistics wrapper 
+                    # which accumulates rewards over the entire episode
                    loss = self.agent.policy_update(gamma=self.gamma)
                    self.agent.onpolicy_reset()
                    print(f"global_step={global_step}, epsilon={epsilon}, episodic_return={reward}, loss={loss}")
🧹 Nitpick comments (1)
tests/test_sarsa.py (1)

150-153: Minimal training test updated

The test comment and configuration have been updated to reflect the step-based approach, although the comment still references "single episode," which should be updated to mention steps instead.

-    """Test minimal training run with a single episode."""
+    """Test minimal training run with a limited number of steps."""
     # Create minimal config with just one step
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 33b0658 and f530e30.

📒 Files selected for processing (2)
  • tests/test_sarsa.py (5 hunks)
  • toyrl/sarsa.py (6 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (1)
  • GitHub Check: ci (3.12, ubuntu-latest)
🔇 Additional comments (21)
toyrl/sarsa.py (10)

23-26: Architecture update improves efficiency

The network architecture now directly outputs Q-values for all actions from the state input, replacing the previous approach that required multiple forward passes for action selection.
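For readers skimming the diff, a minimal sketch of such a state-in, Q-values-out network is shown below; apart from the 128-unit hidden layer quoted earlier in this conversation, the layer structure and names are illustrative assumptions rather than the exact contents of toyrl/sarsa.py:

```python
import torch
import torch.nn as nn


class PolicyNet(nn.Module):
    """Maps a state vector to one Q-value per action in a single forward pass."""

    def __init__(self, env_dim: int, action_num: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(env_dim, 128),       # hidden width matches the nn.Linear(self.env_dim, 128) line quoted above
            nn.ReLU(),
            nn.Linear(128, action_num),    # one output per action; no action input is concatenated anymore
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.net(state)  # shape: (..., action_num)
```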


98-101: Action selection optimization

The refactored action selection now performs a single forward pass to get Q-values for all actions, rather than iterating through each action. This is more efficient and aligns with standard Q-learning implementations.
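A rough sketch of what single-forward-pass epsilon-greedy selection looks like in PyTorch; the function signature and variable names are assumptions for illustration, not the Agent.act code itself:

```python
import random

import torch


def epsilon_greedy_act(policy_net, observation, epsilon: float, action_num: int) -> int:
    # Explore: with probability epsilon pick a uniformly random action.
    if random.random() < epsilon:
        return random.randrange(action_num)
    # Exploit: one forward pass returns Q-values for every action; take the argmax.
    with torch.no_grad():
        q_values = policy_net(torch.as_tensor(observation, dtype=torch.float32))
    return int(torch.argmax(q_values).item())
```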


110-114: Fixed tensor structure

The updated tensor structure avoids unnecessary reshaping and ensures consistent tensor dimensions throughout the policy update logic.


117-121: Q-value computation streamlined

The Q-value computation now directly gathers relevant values using tensor operations rather than constructing state-action pairs. This is both more efficient and easier to understand.
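The gather-based pattern being described is roughly the following; this is a generic SARSA update sketch (the batch tensors, MSE loss, and names are assumptions, not the exact policy_update implementation):

```python
import torch
import torch.nn.functional as F


def sarsa_loss(policy_net, obs, actions, rewards, next_obs, next_actions, dones, gamma: float):
    # Q(s, a): select the Q-value of the action actually taken from each row of Q-values.
    q_sa = policy_net(obs).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # On-policy target: use Q(s', a') for the action actually taken at the next step.
        q_next = policy_net(next_obs).gather(1, next_actions.unsqueeze(1)).squeeze(1)
        target = rewards + gamma * q_next * (1.0 - dones)
    return F.mse_loss(q_sa, target)
```

Here actions and next_actions are int64 tensors of shape (batch,), and dones is a float mask that zeroes the bootstrap term on terminal transitions.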


125-127: Gradient clipping prevents exploding gradients

Good addition of gradient clipping, which helps stabilize training by preventing large parameter updates that could destabilize learning.
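As a concrete reference, gradient clipping in PyTorch is typically a one-liner between backward() and step(); the max_norm value below is an assumption, since the PR text does not state which threshold is used:

```python
optimizer.zero_grad()
loss.backward()
# Rescale all gradients so their global L2 norm does not exceed max_norm.
torch.nn.utils.clip_grad_norm_(policy_net.parameters(), max_norm=1.0)
optimizer.step()
```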


140-146: Improved configuration documentation

Adding docstrings to the configuration parameters significantly enhances code readability and understanding. The switch from episode-based to step-based training (max_training_steps) is well documented.
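For illustration, a step-based training config with per-field docstrings might look like the sketch below; max_training_steps, learning_rate, and log_wandb all appear in this PR, but the defaults and docstring wording here are assumptions:

```python
from dataclasses import dataclass


@dataclass
class TrainConfig:
    """Training hyperparameters for the SARSA agent."""

    max_training_steps: int = 500_000
    """Total number of environment steps to train for (replaces the old num_episodes)."""

    learning_rate: float = 2.5e-4
    """Step size for the Adam optimizer."""

    log_wandb: bool = False
    """Whether to log metrics to Weights & Biases."""
```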


178-183: Environment wrapping enhances monitoring

Good refactoring to create a dedicated environment setup method with appropriate wrappers. The RecordEpisodeStatistics wrapper automatically tracks episode rewards, and Autoreset simplifies training loop management.
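A sketch of such a helper using gymnasium (assuming gymnasium >= 1.0, where the auto-reset wrapper is named Autoreset; the env id and function name are illustrative):

```python
import gymnasium as gym


def make_env(env_id: str = "CartPole-v1", render_mode: str | None = None) -> gym.Env:
    env = gym.make(env_id, render_mode=render_mode)
    # Accumulates per-episode return/length and exposes them via info["episode"] at episode end.
    env = gym.wrappers.RecordEpisodeStatistics(env)
    # Resets the environment automatically once an episode terminates or truncates.
    env = gym.wrappers.Autoreset(env)
    return env
```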


186-192: Improved training loop structure

The step-based training loop with a global step counter is a standard approach that allows for better monitoring and control of the training process. The epsilon decay rate of 0.9999 per step results in a much slower exploration reduction than the previous episode-based approach.
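To put the 0.9999 per-step decay in perspective, here is a quick back-of-the-envelope check (assuming epsilon starts at 1.0 and is floored at 0.05, as in the quoted training loop):

```python
import math

# Number of steps until 1.0 * 0.9999**n reaches the 0.05 floor.
n = math.log(0.05) / math.log(0.9999)
print(round(n))  # ~29,956 steps -- a small fraction of the 2,000,000-step default run
```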


214-221: Improved metric tracking

The wandb logging now includes the global step counter, which provides better tracking of training progress over time rather than just by episode.
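As a side note, wandb.log also accepts an explicit step argument, which is an alternative way to anchor metrics to the global step counter; this is only a sketch of that option, not what the PR does:

```python
import wandb

wandb.log({"episode_reward": episode_reward, "loss": loss}, step=global_step)
```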


227-231: Update configuration for step-based training

The configuration has been updated appropriately to use the new step-based approach with a reasonable number of maximum steps for the default configuration.

tests/test_sarsa.py (11)

21-22: Updated test to match refactored PolicyNet

Tests have been correctly updated to match the new PolicyNet constructor signature, which no longer takes an action_dim parameter.


25-30: Test case updated for new network output

The test now verifies the new output shape correctly, ensuring the network returns Q-values for all actions as expected.


42-45: Explicit dtype specification improves test robustness

Adding explicit dtype specifications to test numpy arrays ensures consistent behavior across different numpy versions and environments.


56-59: Consistent dtype specifications

Good practice to maintain consistent explicit dtype specifications throughout the test cases.


78-79: Policy network test parameters updated

The test correctly uses the updated constructor signature for PolicyNet.


93-94: Agent act test updated

The test has been adjusted to work with the updated PolicyNet implementation.


124-125: Configuration test updated for step-based training

The test has been correctly updated to use max_training_steps instead of num_episodes.


129-130: Config assertion updated

The assertions have been properly updated to check the new configuration parameters.


137-138: Trainer test config updated

The trainer creation test has been updated to use the new max_training_steps parameter.


144-145: Updated trainer property checks

The tests now check for the gamma property instead of num_episodes, which is appropriate for the refactored implementation.


159-160: Updated test run comment

The test run comment has been updated to focus on steps rather than episodes, which is correct.

@shenxiangzhuang shenxiangzhuang merged commit a7ea8ee into master Apr 30, 2025
4 of 5 checks passed
@shenxiangzhuang shenxiangzhuang deleted the refactor/sarsa branch April 30, 2025 09:10

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
toyrl/sarsa.py (2)

162-163: Consider documenting the optimizer change rationale

The optimizer was changed from RMSprop to Adam; consider adding a brief comment explaining the rationale behind this change to aid future maintainability.

 policy_net = PolicyNet(env_dim=env_dim, action_num=action_num)
-optimizer = optim.Adam(policy_net.parameters(), lr=config.train.learning_rate)
+# Using the Adam optimizer instead of RMSprop for its adaptive learning rate and momentum,
+# which generally lead to better convergence in reinforcement learning tasks.
+optimizer = optim.Adam(policy_net.parameters(), lr=config.train.learning_rate)

208-222: Consider clarifying episodic reward computation

Since the reward value is overridden from the info dictionary when available, it may be beneficial to clarify in a comment how episodic returns are computed to avoid future confusion.

Adding a brief comment would improve clarity:

 if terminated or truncated:
     if info and "episode" in info:
+        # Extract episodic reward from RecordEpisodeStatistics wrapper
+        # This is the sum of all rewards received during the episode
         episode_reward = info["episode"]["r"]
         loss = self.agent.policy_update(gamma=self.gamma)
🧹 Nitpick comments (1)
toyrl/sarsa.py (1)

228-232: Note inconsistency in learning rate values

There's an inconsistency between the default learning rate in TrainConfig (2.5e-4) and the one used in the main block (0.01). Consider either aligning these values or documenting why a different value is used for the CartPole environment specifically.

 train=TrainConfig(
     max_training_steps=2000000,
-    learning_rate=0.01,
+    # Higher learning rate works better for CartPole environment
+    learning_rate=0.01,
     log_wandb=True,
 ),
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f530e30 and 016b1fb.

📒 Files selected for processing (2)
  • tests/test_sarsa.py (5 hunks)
  • toyrl/sarsa.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/test_sarsa.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
toyrl/sarsa.py (2)
toyrl/reinforce.py (11)
  • act (74-80)
  • PolicyNet (13-25)
  • train (153-189)
  • reset (49-50)
  • Experience (29-36)
  • add_experience (46-47)
  • add_experience (68-69)
  • policy_update (82-105)
  • onpolicy_reset (65-66)
  • EnvConfig (109-112)
  • TrainConfig (116-121)
toyrl/dqn.py (11)
  • act (93-99)
  • _make_env (223-227)
  • PolicyNet (14-34)
  • train (234-288)
  • reset (68-70)
  • Experience (38-44)
  • add_experience (56-66)
  • add_experience (90-91)
  • policy_update (104-132)
  • EnvConfig (143-146)
  • TrainConfig (150-180)
⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: ci (3.13, ubuntu-latest)
  • GitHub Check: ci (3.12, ubuntu-latest)
🔇 Additional comments (6)
toyrl/sarsa.py (6)

23-26: Well-structured network architecture change!

The network architecture has been simplified to directly output Q-values for all actions from a state input, which aligns with standard practices in Q-learning implementations. This approach is more efficient than concatenating state-action pairs and matches the architecture seen in the DQN implementation.


93-100: Good refactoring of action selection logic

The act method now returns only the action instead of a tuple, and uses a single forward pass to compute Q-values for all actions at once. This is more efficient and aligns perfectly with the updated network architecture.


115-120: Q-value computation and target calculation improved

The Q-value computation has been refactored to use tensor operations more efficiently, using .gather() to select appropriate action values from the network output. This is more vectorized and better aligned with PyTorch best practices.


124-125: Gradient clipping is a valuable addition

Adding gradient clipping is a good practice in reinforcement learning algorithms as it helps prevent exploding gradients and stabilizes training, especially in environments with high variance in rewards.


139-145: Clear configuration improvements with helpful docstrings

The change from episode-based to step-based training with max_training_steps is a good approach that provides more fine-grained control over training duration. The added docstrings also improve code clarity and maintainability.


177-181: Great addition of environment wrappers

The introduction of a dedicated _make_env method with RecordEpisodeStatistics and Autoreset wrappers is a good practice. These wrappers simplify the training loop and provide built-in tracking of episodic metrics.

Comment on lines +187 to 223
        observation, _ = self.env.reset()
        while global_step < self.config.train.max_training_steps:
            global_step += 1
            epsilon = max(0.05, epsilon * 0.9999)

            action = self.agent.act(observation, epsilon)
            next_observation, reward, terminated, truncated, info = self.env.step(action)
            experience = Experience(
                observation=observation,
                action=action,
                reward=float(reward),
                next_observation=next_observation,
                terminated=terminated,
                truncated=truncated,
            )
            if self.config.train.log_wandb:
                wandb.log(
                    {
                        "episode": episode,
                        "loss": loss,
                        "q_value_mean": q_value_mean,
                        "total_reward": total_reward,
                    }
                )
            self.agent.add_experience(experience)
            observation = next_observation
            if self.env.render_mode is not None:
                self.env.render()

            if terminated or truncated:
                if info and "episode" in info:
                    episode_reward = info["episode"]["r"]
                    loss = self.agent.policy_update(gamma=self.gamma)
                    self.agent.onpolicy_reset()
                    print(
                        f"global_step={global_step}, epsilon={epsilon}, episodic_return={episode_reward}, loss={loss}"
                    )
                    if self.config.train.log_wandb:
                        wandb.log(
                            {
                                "global_step": global_step,
                                "episode_reward": episode_reward,
                                "loss": loss,
                            }
                        )


🛠️ Refactor suggestion

Robust step-based training loop implementation

The training loop refactoring to a step-based approach with a global step counter is excellent. This approach provides better control over training duration and aligns well with modern RL practices. The slower epsilon decay rate (0.9999) allows for more exploration throughout training.

However, there's a potential robustness issue in checking the info dictionary.

Consider adding an explicit check for the "episode" key in the info dictionary:

-            if terminated or truncated:
-                if info and "episode" in info:
+            if terminated or truncated:
+                # When using RecordEpisodeStatistics wrapper, episode statistics 
+                # are available in the info dict when an episode ends
+                if info and isinstance(info, dict) and "episode" in info:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
        observation, _ = self.env.reset()
        while global_step < self.config.train.max_training_steps:
            global_step += 1
            epsilon = max(0.05, epsilon * 0.9999)
            action = self.agent.act(observation, epsilon)
            next_observation, reward, terminated, truncated, info = self.env.step(action)
            experience = Experience(
                observation=observation,
                action=action,
                reward=float(reward),
                next_observation=next_observation,
                terminated=terminated,
                truncated=truncated,
            )
            if self.config.train.log_wandb:
                wandb.log(
                    {
                        "episode": episode,
                        "loss": loss,
                        "q_value_mean": q_value_mean,
                        "total_reward": total_reward,
                    }
                )
            self.agent.add_experience(experience)
            observation = next_observation
            if self.env.render_mode is not None:
                self.env.render()
            if terminated or truncated:
                # When using RecordEpisodeStatistics wrapper, episode statistics
                # are available in the info dict when an episode ends
                if info and isinstance(info, dict) and "episode" in info:
                    episode_reward = info["episode"]["r"]
                    loss = self.agent.policy_update(gamma=self.gamma)
                    self.agent.onpolicy_reset()
                    print(
                        f"global_step={global_step}, epsilon={epsilon}, episodic_return={episode_reward}, loss={loss}"
                    )
                    if self.config.train.log_wandb:
                        wandb.log(
                            {
                                "global_step": global_step,
                                "episode_reward": episode_reward,
                                "loss": loss,
                            }
                        )



Development

Successfully merging this pull request may close these issues.

Network architecture change for SARSA & DQN

2 participants