Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ in the relevant functions:
>>> print(env2._env.env.env)
<gym.envs.classic_control.pendulum.PendulumEnv at 0x1629916a0>

We can see that the two libraries modify the value returned by :func:`~.gym.gym_backend()`
We can see that the two libraries modify the value returned by :func:`~torchrl.envs.gym.gym_backend()`
which can be further used to indicate which library needs to be used for
the current computation. :class:`~.gym.set_gym_backend` is also a decorator:
we can use it to tell a specific function which gym backend needs to be used
Expand Down Expand Up @@ -1188,3 +1188,4 @@ the following function will return ``1`` when queried:
VmasWrapper
gym_backend
set_gym_backend
register_gym_spec_conversion
35 changes: 35 additions & 0 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
Composite,
MultiCategorical,
MultiOneHot,
NonTensor,
OneHot,
ReplayBuffer,
ReplayBufferEnsemble,
Expand Down Expand Up @@ -119,6 +120,7 @@
GymWrapper,
MOGymEnv,
MOGymWrapper,
register_gym_spec_conversion,
set_gym_backend,
)
from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
Expand Down Expand Up @@ -337,6 +339,39 @@ def test_gym_spec_cast(self, categorical):
assert spec == recon
assert recon.shape == spec.shape

def test_gym_new_spec_reg(self):
    """Check that custom gym space conversions are dispatched per space class.

    Two conversions are registered through ``register_gym_spec_conversion``:
    one for a parent space class and one for a subclass of it. The subclass
    conversion is registered first. The test then checks that:

    * a parent-space instance is converted by the parent's function;
    * a child-space instance is converted by the child's function;
    * a space class with no registered conversion makes
      ``_gym_to_torchrl_spec_transform`` raise a ``KeyError``.
    """
    BaseSpace = gym_backend("spaces").Space

    class MySpaceParent(BaseSpace):
        ...

    parent_space = MySpaceParent()

    class MySpaceChild(MySpaceParent):
        ...

    # The child conversion is deliberately registered BEFORE the parent one,
    # so the parent's function is the most recent registration when the
    # child instance is converted below.
    @register_gym_spec_conversion(MySpaceChild)
    def convert_myspace_child(spec, **kwargs):
        return NonTensor((), example_data="child")

    @register_gym_spec_conversion(MySpaceParent)
    def convert_myspace_parent(spec, **kwargs):
        return NonTensor((), example_data="parent")

    child_space = MySpaceChild()

    parent_spec = _gym_to_torchrl_spec_transform(parent_space)
    assert parent_spec.example_data == "parent"

    child_spec = _gym_to_torchrl_spec_transform(child_space)
    assert child_spec.example_data == "child"

    # A Space subclass that nobody registered a conversion for.
    class NoConversionSpace(BaseSpace):
        ...

    unregistered_space = NoConversionSpace()
    with pytest.raises(
        KeyError, match="No conversion tool could be found with the gym space"
    ):
        _gym_to_torchrl_spec_transform(unregistered_space)

@pytest.mark.parametrize("order", ["tuple_seq"])
@implement_for("gym")
def test_gym_spec_cast_tuple_sequential(self, order):
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
OpenSpielWrapper,
PettingZooEnv,
PettingZooWrapper,
register_gym_spec_conversion,
RoboHiveEnv,
set_gym_backend,
SMACv2Env,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/libs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
GymWrapper,
MOGymEnv,
MOGymWrapper,
register_gym_spec_conversion,
set_gym_backend,
)
from .habitat import HabitatEnv
Expand Down
Loading
Loading