refactor(sarsa): standard network arch and use global step #20
Conversation
Walkthrough
The changes refactor the SARSA reinforcement learning implementation by simplifying the policy network so that it outputs Q-values for all actions directly from the state, and by driving training with a global step budget instead of an episode count.
Sequence Diagram(s)
sequenceDiagram
participant Trainer as SarsaTrainer
participant Agent
participant PolicyNet
participant Env as GymEnv
Trainer->>Env: Reset environment (with wrappers)
loop For each training step up to max_training_steps
Agent->>PolicyNet: Compute Q-values for observation
PolicyNet-->>Agent: Q-values for all actions
Agent->>Agent: Select action (epsilon-greedy)
Agent->>Env: Take action
Env-->>Agent: Next observation, reward, done, info
Agent->>PolicyNet: Compute Q-values for next observation
PolicyNet-->>Agent: Next Q-values
Agent->>Agent: Policy update with gradient clipping
Trainer->>Trainer: Log reward and statistics
alt Episode done
Env->>Env: Auto-reset via wrapper
end
end
Pull Request Overview
This PR refactors the SARSA implementation to standardize the network architecture and to move the training loop from an episode-based to a global-step-based approach.
- Removed the unnecessary action_dim parameter and adjusted PolicyNet accordingly.
- Replaced the iterative action evaluation with a single forward pass using argmax for action selection.
- Updated training loop structure, configuration settings, and optimizer from RMSprop to Adam.
Comments suppressed due to low confidence (1)
toyrl/sarsa.py:24
- Removing the action_dim parameter alters the expected input to the network. Please verify that downstream components relying on the concatenated observation–action input are updated accordingly.
nn.Linear(self.env_dim, 128),
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files

    @@           Coverage Diff            @@
    ##           master      #20      +/-  ##
    =========================================
    + Coverage   97.08%   97.40%   +0.32%
    =========================================
      Files           4        4
      Lines         206      193      -13
    =========================================
    - Hits          200      188      -12
    + Misses          6        5       -1

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
♻️ Duplicate comments (2)
toyrl/sarsa.py (2)
163-164: Rationale for optimizer change needed
The optimizer was changed from RMSprop to Adam without documentation of the reasoning.
     policy_net = PolicyNet(env_dim=env_dim, action_num=action_num)
    -optimizer = optim.Adam(policy_net.parameters(), lr=config.train.learning_rate)
    +# Using the Adam optimizer instead of RMSprop for its adaptive learning rate and momentum,
    +# which generally lead to better convergence in reinforcement learning tasks.
    +optimizer = optim.Adam(policy_net.parameters(), lr=config.train.learning_rate)
208-213: Episodic reward extraction needs clarification
The code overrides the reward value from the info dictionary but doesn't explain how this differs from the regular reward.
     if terminated or truncated:
         if info and "episode" in info:
             reward = info["episode"]["r"]
    +        # Using the total episodic return from the RecordEpisodeStatistics wrapper
    +        # which accumulates rewards over the entire episode
         loss = self.agent.policy_update(gamma=self.gamma)
         self.agent.onpolicy_reset()
         print(f"global_step={global_step}, epsilon={epsilon}, episodic_return={reward}, loss={loss}")
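For reference, a minimal self-contained sketch of how the RecordEpisodeStatistics wrapper exposes the episodic return at episode end; CartPole-v1 and the random policy here are illustrative stand-ins, not the project's actual configuration:

    import gymnasium as gym

    env = gym.wrappers.RecordEpisodeStatistics(gym.make("CartPole-v1"))
    observation, info = env.reset(seed=0)
    terminated = truncated = False
    while not (terminated or truncated):
        # Random actions only to drive the episode to completion.
        observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
    # At episode end the wrapper adds the accumulated return "r" and length "l" to info.
    print(float(info["episode"]["r"]), int(info["episode"]["l"]))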
🧹 Nitpick comments (1)
tests/test_sarsa.py (1)
150-153: Minimal training test updated
The test comment and configuration have been updated to reflect the step-based approach, although the docstring still references "single episode" and should be updated to mention steps instead.

    -    """Test minimal training run with a single episode."""
    +    """Test minimal training run with a limited number of steps."""
         # Create minimal config with just one step
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- tests/test_sarsa.py (5 hunks)
- toyrl/sarsa.py (6 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (1)
- GitHub Check: ci (3.12, ubuntu-latest)
🔇 Additional comments (21)
toyrl/sarsa.py (10)
23-26: Architecture update improves efficiency
The network architecture now directly outputs Q-values for all actions from the state input, replacing the previous approach that required multiple forward passes for action selection.
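For illustration, a minimal sketch of the kind of architecture this comment describes; the 128-unit hidden layer matches the quoted nn.Linear(self.env_dim, 128), while the overall depth and attribute names are assumptions rather than the repository's exact code:

    import torch
    import torch.nn as nn

    class PolicyNet(nn.Module):
        """Maps a state vector directly to one Q-value per action."""

        def __init__(self, env_dim: int, action_num: int) -> None:
            super().__init__()
            self.env_dim = env_dim
            self.layers = nn.Sequential(
                nn.Linear(self.env_dim, 128),
                nn.ReLU(),
                nn.Linear(128, action_num),  # one output per action, no action input needed
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.layers(x)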
98-101: Action selection optimization
The refactored action selection now performs a single forward pass to get Q-values for all actions, rather than iterating through each action. This is more efficient and aligns with standard Q-learning implementations.
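A sketch of this single-forward-pass selection pattern; the free-function form and parameter names are illustrative (the repository implements this on the Agent's act method):

    import random
    import numpy as np
    import torch

    def act(policy_net: torch.nn.Module, observation: np.ndarray, epsilon: float, action_num: int) -> int:
        if random.random() < epsilon:
            return random.randrange(action_num)        # explore
        obs = torch.from_numpy(observation).float()
        with torch.no_grad():
            q_values = policy_net(obs)                 # Q-values for every action in one pass
        return int(torch.argmax(q_values).item())      # exploit the greedy action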
110-114: Fixed tensor structure
The updated tensor structure avoids unnecessary reshaping and ensures consistent tensor dimensions throughout the policy update logic.
117-121: Q-value computation streamlined
The Q-value computation now directly gathers the relevant values using tensor operations rather than constructing state-action pairs. This is both more efficient and easier to understand.
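To make the gather-based pattern concrete, a small standalone example with made-up shapes and values; the MSE loss and terminal masking are assumptions, not necessarily what toyrl/sarsa.py uses:

    import torch
    import torch.nn.functional as F

    q_all      = torch.randn(5, 2)                    # Q(s, .) for 5 transitions, 2 actions
    next_q_all = torch.randn(5, 2)                    # Q(s', .) for the next observations
    actions      = torch.tensor([0, 1, 1, 0, 1])      # actions actually taken
    next_actions = torch.tensor([1, 1, 0, 0, 1])      # next actions (on-policy SARSA target)
    rewards    = torch.ones(5)
    terminated = torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0])
    gamma = 0.99

    # gather selects Q(s, a) and Q(s', a') for the taken actions without any Python loop.
    q_taken = q_all.gather(1, actions.unsqueeze(1)).squeeze(1)
    next_q  = next_q_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
    target  = rewards + gamma * (1.0 - terminated) * next_q
    loss    = F.mse_loss(q_taken, target.detach())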
125-127: Gradient clipping prevents exploding gradients
Good addition of gradient clipping, which helps stabilize training by preventing large parameter updates that could destabilize learning.
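A minimal sketch of the clipping pattern referred to here; the tiny network, dummy loss, and max_norm=0.5 are placeholder choices rather than the repository's settings:

    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(net.parameters(), lr=2.5e-4)

    loss = net(torch.randn(8, 4)).pow(2).mean()        # dummy loss just to produce gradients
    optimizer.zero_grad()
    loss.backward()
    # Rescale gradients so their global norm never exceeds max_norm before the update.
    nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5)
    optimizer.step()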
140-146: Improved configuration documentation
Adding docstrings to the configuration parameters significantly enhances code readability and understanding. The switch from episode-based to step-based training (max_training_steps) is well documented.
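For orientation, a hypothetical sketch of such a documented configuration; the field names, the 2.5e-4 default, and the 2,000,000-step value appear elsewhere in this review, but the dataclass form and the log_wandb default are assumptions:

    from dataclasses import dataclass

    @dataclass
    class TrainConfig:
        max_training_steps: int = 2_000_000
        """Total number of environment steps, replacing the old episode count."""
        learning_rate: float = 2.5e-4
        """Learning rate passed to the Adam optimizer."""
        log_wandb: bool = False
        """Whether to log training metrics to Weights & Biases."""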
178-183: Environment wrapping enhances monitoring
Good refactoring to create a dedicated environment setup method with appropriate wrappers. The RecordEpisodeStatistics wrapper automatically tracks episode rewards, and Autoreset simplifies training loop management.
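A self-contained sketch of this wrapper stack; the function signature, CartPole-v1 default, and wrapper order are assumptions (Autoreset is the wrapper name in recent gymnasium releases):

    import gymnasium as gym

    def _make_env(env_name: str = "CartPole-v1", render_mode: str | None = None) -> gym.Env:
        env = gym.make(env_name, render_mode=render_mode)
        # Tracks episodic return/length and reports them in info["episode"] at episode end.
        env = gym.wrappers.RecordEpisodeStatistics(env)
        # Resets the environment automatically once an episode terminates or is truncated.
        env = gym.wrappers.Autoreset(env)
        return env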
186-192: Improved training loop structure
The step-based training loop with a global step counter is a standard approach that allows for better monitoring and control of the training process. The epsilon decay rate of 0.9999 per step results in a much slower exploration reduction than the previous episode-based approach.
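A quick back-of-the-envelope check of that schedule; the 0.9999 factor and 0.05 floor come from the reviewed loop, while the starting value of 1.0 is an assumption:

    import math

    # Starting from epsilon = 1.0 and multiplying by 0.9999 each step,
    # the 0.05 floor is reached after about ln(0.05) / ln(0.9999) steps.
    steps_to_floor = math.log(0.05) / math.log(0.9999)
    print(round(steps_to_floor))   # ~29956 steps, i.e. about 1.5% of a 2,000,000-step run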
214-221: Improved metric tracking
The wandb logging now includes the global step counter, which provides better tracking of training progress over time rather than just by episode.
227-231: Update configuration for step-based training
The configuration has been updated appropriately to use the new step-based approach with a reasonable number of maximum steps for the default configuration.
tests/test_sarsa.py (11)
21-22: Updated test to match refactored PolicyNet
Tests have been correctly updated to match the new PolicyNet constructor signature, which no longer takes an action_dim parameter.
25-30: Test case updated for new network output
The test now verifies the new output shape correctly, ensuring the network returns Q-values for all actions as expected.
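Illustratively, a shape check in the spirit of the updated test; the constructor signature is taken from this review, but the actual assertions in tests/test_sarsa.py may differ:

    import torch
    from toyrl.sarsa import PolicyNet   # constructor signature as quoted in this review

    def test_policy_net_outputs_one_q_value_per_action():
        net = PolicyNet(env_dim=4, action_num=2)      # CartPole-like dimensions
        q_values = net(torch.zeros(4))
        assert q_values.shape == (2,)                 # one Q-value per action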
42-45: Explicit dtype specification improves test robustness
Adding explicit dtype specifications to test numpy arrays ensures consistent behavior across different numpy versions and environments.
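For context, a tiny example of why the explicit dtype matters (the values are arbitrary):

    import numpy as np
    import torch

    # np.array defaults to float64; torch.from_numpy would then yield a double tensor,
    # which a float32 network rejects. Pinning float32 keeps behavior consistent
    # across numpy versions and platforms.
    observation = np.array([0.01, -0.02, 0.03, -0.04], dtype=np.float32)
    assert torch.from_numpy(observation).dtype == torch.float32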
56-59: Consistent dtype specifications
Good practice to maintain consistent explicit dtype specifications throughout the test cases.
78-79: Policy network test parameters updated
The test correctly uses the updated constructor signature for PolicyNet.
93-94: Agent act test updated
The test has been adjusted to work with the updated PolicyNet implementation.
124-125: Configuration test updated for step-based training
The test has been correctly updated to use max_training_steps instead of num_episodes.
129-130: Config assertion updated
The assertions have been properly updated to check the new configuration parameters.
137-138: Trainer test config updated
The trainer creation test has been updated to use the new max_training_steps parameter.
144-145: Updated trainer property checks
The tests now check for the gamma property instead of num_episodes, which is appropriate for the refactored implementation.
159-160: Updated test run comment
The test run comment has been updated to focus on steps rather than episodes, which is correct.
Actionable comments posted: 1
♻️ Duplicate comments (2)
toyrl/sarsa.py (2)
162-163: Consider documenting the optimizer change rationale
The optimizer was changed from RMSprop to Adam; consider adding a brief comment explaining the rationale behind this change to aid future maintainability.
     policy_net = PolicyNet(env_dim=env_dim, action_num=action_num)
    -optimizer = optim.Adam(policy_net.parameters(), lr=config.train.learning_rate)
    +# Using the Adam optimizer instead of RMSprop for its adaptive learning rate and momentum,
    +# which generally lead to better convergence in reinforcement learning tasks.
    +optimizer = optim.Adam(policy_net.parameters(), lr=config.train.learning_rate)
208-222: Consider clarifying episodic reward computation
Since the reward value is overridden from the info dictionary when available, it may be beneficial to clarify in a comment how episodic returns are computed to avoid future confusion.
Adding a brief comment would improve clarity:
     if terminated or truncated:
         if info and "episode" in info:
    +        # Extract episodic reward from RecordEpisodeStatistics wrapper
    +        # This is the sum of all rewards received during the episode
             episode_reward = info["episode"]["r"]
         loss = self.agent.policy_update(gamma=self.gamma)
🧹 Nitpick comments (1)
toyrl/sarsa.py (1)
228-232: Note inconsistency in learning rate values
There's an inconsistency between the default learning rate in TrainConfig (2.5e-4) and the one used in the main block (0.01). Consider either aligning these values or documenting why a different value is used for the CartPole environment specifically.

     train=TrainConfig(
         max_training_steps=2000000,
    -    learning_rate=0.01,
    +    # Higher learning rate works better for CartPole environment
    +    learning_rate=0.01,
         log_wandb=True,
     ),
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- tests/test_sarsa.py (5 hunks)
- toyrl/sarsa.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/test_sarsa.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
toyrl/sarsa.py (2)
toyrl/reinforce.py (11)
- act (74-80), PolicyNet (13-25), train (153-189), reset (49-50), Experience (29-36), add_experience (46-47), add_experience (68-69), policy_update (82-105), onpolicy_reset (65-66), EnvConfig (109-112), TrainConfig (116-121)
toyrl/dqn.py (11)
- act (93-99), _make_env (223-227), PolicyNet (14-34), train (234-288), reset (68-70), Experience (38-44), add_experience (56-66), add_experience (90-91), policy_update (104-132), EnvConfig (143-146), TrainConfig (150-180)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: ci (3.13, ubuntu-latest)
- GitHub Check: ci (3.12, ubuntu-latest)
🔇 Additional comments (6)
toyrl/sarsa.py (6)
23-26: Well-structured network architecture change!
The network architecture has been simplified to directly output Q-values for all actions from a state input, which aligns with standard practices in Q-learning implementations. This approach is more efficient than concatenating state-action pairs and matches the architecture seen in the DQN implementation.
93-100: Good refactoring of action selection logic
The act method now returns only the action instead of a tuple, and uses a single forward pass to compute Q-values for all actions at once. This is more efficient and aligns perfectly with the updated network architecture.
115-120: Q-value computation and target calculation improved
The Q-value computation has been refactored to use tensor operations more efficiently, using .gather() to select appropriate action values from the network output. This is more vectorized and better aligned with PyTorch best practices.
124-125: Gradient clipping is a valuable addition
Adding gradient clipping is a good practice in reinforcement learning algorithms as it helps prevent exploding gradients and stabilizes training, especially in environments with high variance in rewards.
139-145: Clear configuration improvements with helpful docstrings
The change from episode-based to step-based training with max_training_steps is a good approach that provides more fine-grained control over training duration. The added docstrings also improve code clarity and maintainability.
177-181: Great addition of environment wrappers
The introduction of a dedicated _make_env method with RecordEpisodeStatistics and Autoreset wrappers is a good practice. These wrappers simplify the training loop and provide built-in tracking of episodic metrics.
Reviewed hunk (training loop in toyrl/sarsa.py):

    observation, _ = self.env.reset()
    while global_step < self.config.train.max_training_steps:
        global_step += 1
        epsilon = max(0.05, epsilon * 0.9999)
        action = self.agent.act(observation, epsilon)
        next_observation, reward, terminated, truncated, info = self.env.step(action)
        experience = Experience(
            observation=observation,
            action=action,
            reward=float(reward),
            next_observation=next_observation,
            terminated=terminated,
            truncated=truncated,
        )
        if self.config.train.log_wandb:
            wandb.log(
                {
                    "episode": episode,
                    "loss": loss,
                    "q_value_mean": q_value_mean,
                    "total_reward": total_reward,
                }
            )
        self.agent.add_experience(experience)
        observation = next_observation
        if self.env.render_mode is not None:
            self.env.render()
        if terminated or truncated:
            if info and "episode" in info:
                episode_reward = info["episode"]["r"]
            loss = self.agent.policy_update(gamma=self.gamma)
            self.agent.onpolicy_reset()
            print(
                f"global_step={global_step}, epsilon={epsilon}, episodic_return={episode_reward}, loss={loss}"
            )
            if self.config.train.log_wandb:
                wandb.log(
                    {
                        "global_step": global_step,
                        "episode_reward": episode_reward,
                        "loss": loss,
                    }
                )
🛠️ Refactor suggestion
Robust step-based training loop implementation
The training loop refactoring to a step-based approach with a global step counter is excellent. This approach provides better control over training duration and aligns well with modern RL practices. The slower epsilon decay rate (0.9999) allows for more exploration throughout training.
However, there's a potential robustness issue in checking the info dictionary.
Consider adding an explicit check for the "episode" key in the info dictionary:
    -if terminated or truncated:
    -    if info and "episode" in info:
    +if terminated or truncated:
    +    # When using RecordEpisodeStatistics wrapper, episode statistics
    +    # are available in the info dict when an episode ends
    +    if info and isinstance(info, dict) and "episode" in info:

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    observation, _ = self.env.reset()
    while global_step < self.config.train.max_training_steps:
        global_step += 1
        epsilon = max(0.05, epsilon * 0.9999)
        action = self.agent.act(observation, epsilon)
        next_observation, reward, terminated, truncated, info = self.env.step(action)
        experience = Experience(
            observation=observation,
            action=action,
            reward=float(reward),
            next_observation=next_observation,
            terminated=terminated,
            truncated=truncated,
        )
        if self.config.train.log_wandb:
            wandb.log(
                {
                    "episode": episode,
                    "loss": loss,
                    "q_value_mean": q_value_mean,
                    "total_reward": total_reward,
                }
            )
        self.agent.add_experience(experience)
        observation = next_observation
        if self.env.render_mode is not None:
            self.env.render()
        if terminated or truncated:
            # When using RecordEpisodeStatistics wrapper, episode statistics
            # are available in the info dict when an episode ends
            if info and isinstance(info, dict) and "episode" in info:
                episode_reward = info["episode"]["r"]
            loss = self.agent.policy_update(gamma=self.gamma)
            self.agent.onpolicy_reset()
            print(
                f"global_step={global_step}, epsilon={epsilon}, episodic_return={episode_reward}, loss={loss}"
            )
            if self.config.train.log_wandb:
                wandb.log(
                    {
                        "global_step": global_step,
                        "episode_reward": episode_reward,
                        "loss": loss,
                    }
                )
Close #14
Summary by CodeRabbit
New Features
Improvements
Documentation