
bug: JumanjiToGymWrapper step function returns wrong value for "terminated" #275

Open
georgkruse opened this issue Feb 27, 2025 · 4 comments
Labels
bug Something isn't working


@georgkruse

Description

I've noticed a potential bug in the implementation of JumanjiToGymWrapper (line 615):

def step(
    state: State, action: chex.Array
) -> Tuple[State, Observation, chex.Array, chex.Array, chex.Array, Optional[Any]]:
    """Step function of a Jumanji environment to be jitted."""
    state, timestep = self._env.step(state, action)
    term = ~timestep.discount.astype(bool)
    trunc = timestep.last().astype(bool)
    return state, timestep.observation, timestep.reward, term, trunc, timestep.extras

self._step = jax.jit(step, backend=self.backend)

It derives "terminated" from the boolean value of timestep.discount. However, the dm_env documentation states:

NOTE: The discount does not determine when a sequence ends. The discount may be 0 in the middle of a sequence and ≥0 at the end of a sequence.
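To illustrate the dm_env note, here is a minimal pure-Python sketch. StepType and TimeStep below are simplified stand-ins that mirror dm_env's FIRST/MID/LAST step types; they are not the dm_env or Jumanji classes. The two hypothetical helpers contrast inferring termination from the discount alone with using the step type to decide whether the episode ended at all.

```python
from enum import IntEnum
from typing import NamedTuple


class StepType(IntEnum):
    # Stand-in mirroring dm_env.StepType.
    FIRST = 0
    MID = 1
    LAST = 2


class TimeStep(NamedTuple):
    # Simplified stand-in for a dm_env-style timestep (observation omitted).
    step_type: StepType
    reward: float
    discount: float

    def last(self) -> bool:
        return self.step_type == StepType.LAST


def terminated_from_discount(ts: TimeStep) -> bool:
    # Infers termination from the discount alone.
    return not bool(ts.discount)


def terminated_from_step_type(ts: TimeStep) -> bool:
    # Only a LAST step can end the episode; a zero discount marks it as a true termination.
    return ts.last() and not bool(ts.discount)


# Under the general dm_env contract, a MID step may carry a discount of 0.
mid_with_zero_discount = TimeStep(StepType.MID, reward=1.0, discount=0.0)
print(terminated_from_discount(mid_with_zero_discount))   # True: wrongly ends the episode
print(terminated_from_step_type(mid_with_zero_discount))  # False: the sequence continues
```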

When I use the wrapper, e.g. for the TSP or Knapsack environments, the environment always terminates after the first step.

Here is an example:

import jumanji.wrappers
from jumanji.environments import Knapsack
from jumanji.environments.packing.knapsack.generator import RandomGenerator

# Small instance: 5 items and a total budget of 2.
generator_knapsack = RandomGenerator(total_budget=2, num_items=5)
env = Knapsack(generator=generator_knapsack)
env = jumanji.wrappers.JumanjiToGymWrapper(env)
obs, _ = env.reset()

# Select items 0..4 in order and print the Gymnasium-style step outputs.
for i in range(5):
    obs, reward, terminate, truncate, info = env.step(i)
    print(reward, terminate, truncate, info)

This prints:

0.6626307964324951 True False
0.8461408615112305 True False
0.503207802772522 False True
0.0 False True
0.0 False True

However, if the wrapper followed Gymnasium's semantics, terminate should never be True here.
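A small sketch makes the effect of these flags concrete. It takes the (terminated, truncated) pairs from the output printed above and finds the first step at which a Gymnasium-style caller would consider the episode over. The helper name first_done_step is hypothetical; it relies only on the Gymnasium convention that once either flag is True, the caller is expected to call reset().

```python
from typing import List, Optional, Tuple

# (terminated, truncated) pairs copied from the output printed above.
reported_flags: List[Tuple[bool, bool]] = [
    (True, False),
    (True, False),
    (False, True),
    (False, True),
    (False, True),
]


def first_done_step(flags: List[Tuple[bool, bool]]) -> Optional[int]:
    # Gymnasium treats the episode as over as soon as terminated or truncated is True.
    for i, (terminated, truncated) in enumerate(flags):
        if terminated or truncated:
            return i
    return None


print(first_done_step(reported_flags))  # 0: the episode already ends after the first step
```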

What Jumanji version are you using?

v1.1.0

Which accelerator(s) are you using?

CPU

Additional System Info

Windows

Additional Context

No response

(Optional) Suggestion

No response

@georgkruse georgkruse added the bug Something isn't working label Feb 27, 2025
@sash-a
Collaborator

sash-a commented Feb 27, 2025

Hi @georgkruse, thanks for the issue.

Jumanji defines discount as "not terminated". If you look here, you will see how we construct timesteps for terminated steps, truncated steps, and normal steps.
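Under the convention described here (discount = not terminated), the three kinds of steps map to Gymnasium flags as in the following sketch. The constructors below are simplified, hypothetical stand-ins named after the transition/termination/truncation idea; they are assumptions, not Jumanji's actual implementations. gym_flags reproduces the logic of the wrapper excerpt quoted in the issue (term = ~discount, trunc = last()).

```python
from enum import IntEnum
from typing import NamedTuple, Tuple


class StepType(IntEnum):
    FIRST = 0
    MID = 1
    LAST = 2


class TimeStep(NamedTuple):
    step_type: StepType
    reward: float
    discount: float

    def last(self) -> bool:
        return self.step_type == StepType.LAST


# Hypothetical stand-ins: the discount encodes "not terminated".
def transition(reward: float) -> TimeStep:
    return TimeStep(StepType.MID, reward, discount=1.0)


def termination(reward: float) -> TimeStep:
    return TimeStep(StepType.LAST, reward, discount=0.0)


def truncation(reward: float) -> TimeStep:
    return TimeStep(StepType.LAST, reward, discount=1.0)


def gym_flags(ts: TimeStep) -> Tuple[bool, bool]:
    # Same logic as the quoted wrapper excerpt.
    term = not bool(ts.discount)
    trunc = ts.last()
    return term, trunc


for make in (transition, termination, truncation):
    print(make.__name__, gym_flags(make(0.0)))
# transition (False, False)
# termination (True, True)
# truncation (False, True)
```

With this logic, a normal step yields (False, False), and a genuine termination yields (True, True), which is the pattern in the last rows of the output in this comment.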

I could not reproduce your results. Could it be that the random seed selected high-weight items, so the episode was immediately done? If you set your budget to 10, do you get the same behavior? Below is the result I get with exactly the same code you used:

0.6626307964324951 False False {}
0.8461408615112305 False False {}
0.503207802772522 True True {}
0.0 True True {}
0.0 True True {}

I have checked that Knapsack handles its discounts in the way we'd expect.

@sash-a sash-a closed this as completed Feb 27, 2025
@sash-a sash-a reopened this Feb 27, 2025
@sash-a
Collaborator

sash-a commented Feb 27, 2025

Sorry, I didn't mean to close the issue. Please let me know if this solves your problem, and if so I will close it 😄

@georgkruse
Author

Hey sash-a, thanks for the quick reply. I think the reward is the weight of the item in the knapsack problem (and I set the max weight to 2). For other runs I get:

0.6355215311050415 True False {} 
0.8265048265457153 True False {} 
0.45131349563598633 True False {}
0.11718368530273438 True False {}
0.5285080671310425 False True {} 

This is the expected behavior except for the "terminate" flag.
If it's working for you, maybe it's a version problem? I'm running Python 3.12.9 and simply ran pip install ray jumanji.
Here is my environment.yaml:

name: test
channels:
  - conda-forge
dependencies:
  - bzip2=1.0.8=h2466b09_7
  - ca-certificates=2025.1.31=h56e8100_0
  - libexpat=2.6.4=he0c23c2_0
  - libffi=3.4.6=h537db12_0
  - liblzma=5.6.4=h2466b09_0
  - libsqlite=3.49.1=h67fdade_1
  - libzlib=1.3.1=h2466b09_2
  - openssl=3.4.1=ha4e3fda_0
  - pip=25.0.1=pyh8b19718_0
  - python=3.12.9=h3f84c4b_0_cpython
  - setuptools=75.8.2=pyhff2d567_0
  - tk=8.6.13=h5226925_1
  - tzdata=2025a=h78e105d_0
  - ucrt=10.0.22621.0=h57928b3_1
  - vc=14.3=h5fd82a7_24
  - vc14_runtime=14.42.34433=h6356254_24
  - wheel=0.45.1=pyhd8ed1ab_1
  - pip:
      - absl-py==2.1.0
      - attrs==25.1.0
      - certifi==2025.1.31
      - charset-normalizer==3.4.1
      - chex==0.1.89
      - cloudpickle==3.1.1
      - colorama==0.4.6
      - contourpy==1.3.1
      - cycler==0.12.1
      - dm-env==1.6
      - dm-tree==0.1.9
      - farama-notifications==0.0.4
      - filelock==3.17.0
      - fonttools==4.56.0
      - fsspec==2025.2.0
      - gymnasium==1.1.0
      - huggingface-hub==0.29.1
      - idna==3.10
      - jax==0.5.1
      - jaxlib==0.5.1
      - jumanji==1.1.0
      - kiwisolver==1.4.8
      - matplotlib==3.7.5
      - ml-dtypes==0.5.1
      - numpy==1.26.4
      - opt-einsum==3.4.0
      - packaging==24.2
      - pillow==11.1.0
      - pyparsing==3.2.1
      - python-dateutil==2.9.0.post0
      - pyyaml==6.0.2
      - requests==2.32.3
      - scipy==1.15.2
      - six==1.17.0
      - toolz==1.0.0
      - tqdm==4.67.1
      - typing-extensions==4.12.2
      - urllib3==2.3.0
      - wrapt==1.17.2

Thanks for the help.
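Since the discrepancy may be version-related, a stdlib-only sketch like the one below can print the installed versions of the relevant packages so that two environments can be compared side by side. The package list is an assumption picked from the environment file above; installed_versions is a hypothetical helper built on importlib.metadata.

```python
import platform
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Optional

# Packages relevant to this issue (assumed list; adjust as needed).
PACKAGES = ["jumanji", "jax", "jaxlib", "gymnasium", "dm-env", "chex"]


def installed_versions(packages: List[str]) -> Dict[str, Optional[str]]:
    # Map each distribution name to its installed version, or None if it is absent.
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


print("python", platform.python_version())
for name, ver in installed_versions(PACKAGES).items():
    print(name, ver or "not installed")
```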

@sash-a
Collaborator

sash-a commented Mar 5, 2025

Hi @georgkruse, sorry for the late reply; I must have missed this message.

I think I must have had a different version of the wrapper. It is fixed on the main branch, as you can see here and in this PR, but it isn't fixed in our current release, sorry about that. I want to close this PR before doing another release, so for now an easy fix should be pip install git+https://github.com/instadeepai/jumanji. I've tested this and it works; please let me know if it works for you. Hopefully the release will be out soon 😄
