bug: JumanjiToGymWrapper step function returns wrong value for "terminated" #275
Comments
Hi @georgkruse, thanks for the issue. The way jumanji defines discount is
I could not reproduce your results, but that may be because a random seed selected high weight values for your items, so the episode was immediately done. If you set your budget to 10, do you get the same behavior? Below is the result I get with exactly the same code you used:
I have checked that knapsack handles its discounts in the way we'd expect.
Sorry, I didn't mean to close the issue. Please let me know if this solves your problem, and if so I will close it 😄
Hey sash-a, thanks for the quick reply. I think the reward is the weight of the item in the knapsack problem (and I set the max weight to 2). For other runs I get:
which is the expected behavior except for the "terminated" flag.
name: test
channels:
- conda-forge
dependencies:
- bzip2=1.0.8=h2466b09_7
- ca-certificates=2025.1.31=h56e8100_0
- libexpat=2.6.4=he0c23c2_0
- libffi=3.4.6=h537db12_0
- liblzma=5.6.4=h2466b09_0
- libsqlite=3.49.1=h67fdade_1
- libzlib=1.3.1=h2466b09_2
- openssl=3.4.1=ha4e3fda_0
- pip=25.0.1=pyh8b19718_0
- python=3.12.9=h3f84c4b_0_cpython
- setuptools=75.8.2=pyhff2d567_0
- tk=8.6.13=h5226925_1
- tzdata=2025a=h78e105d_0
- ucrt=10.0.22621.0=h57928b3_1
- vc=14.3=h5fd82a7_24
- vc14_runtime=14.42.34433=h6356254_24
- wheel=0.45.1=pyhd8ed1ab_1
- pip:
- absl-py==2.1.0
- attrs==25.1.0
- certifi==2025.1.31
- charset-normalizer==3.4.1
- chex==0.1.89
- cloudpickle==3.1.1
- colorama==0.4.6
- contourpy==1.3.1
- cycler==0.12.1
- dm-env==1.6
- dm-tree==0.1.9
- farama-notifications==0.0.4
- filelock==3.17.0
- fonttools==4.56.0
- fsspec==2025.2.0
- gymnasium==1.1.0
- huggingface-hub==0.29.1
- idna==3.10
- jax==0.5.1
- jaxlib==0.5.1
- jumanji==1.1.0
- kiwisolver==1.4.8
- matplotlib==3.7.5
- ml-dtypes==0.5.1
- numpy==1.26.4
- opt-einsum==3.4.0
- packaging==24.2
- pillow==11.1.0
- pyparsing==3.2.1
- python-dateutil==2.9.0.post0
- pyyaml==6.0.2
- requests==2.32.3
- scipy==1.15.2
- six==1.17.0
- toolz==1.0.0
- tqdm==4.67.1
- typing-extensions==4.12.2
- urllib3==2.3.0
- wrapt==1.17.2
Thanks for the help.
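Since the behavior depends on which release of the wrapper is installed, it can help to print the versions actually in use. The snippet below is a minimal sketch using only the standard library; `installed_version` is a hypothetical helper name, not part of Jumanji or Gymnasium.

```python
from importlib import metadata


def installed_version(package: str):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Packages most relevant to this issue.
for name in ("jumanji", "gymnasium", "jax"):
    print(name, installed_version(name))
```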
Hi @georgkruse, sorry for the late reply, I must have missed this message. I think I must have had a different version of the wrapper. It is fixed on the main branch, as you can see here and in this PR, but looking at our current release it isn't fixed, sorry about that. I want to close this PR before doing another release, so an easy fix should be
Description
I've noticed a potential bug in the implementation of the JumanjiToGymWrapper, line 615.
It returns the bool of the timestep.discount variable. However, the dm_env documentation states:
When I'm using the wrapper, e.g. for the TSP or the Knapsack environment, the environment always terminates after the first step.
Here is an example:
Which prints:
However, "terminated" should not be set to True here if the wrapper followed the Gymnasium logic.
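The distinction at stake can be sketched without Jumanji installed. This is a minimal illustration, not the wrapper's actual code: `StepInfo`, `naive_terminated` and `to_gym_flags` are hypothetical names. The mapping assumes the dm_env convention that a terminating step carries a discount of 0, while a truncated or ongoing step keeps a nonzero discount.

```python
from dataclasses import dataclass


@dataclass
class StepInfo:
    """Hypothetical stand-in for the relevant fields of a dm_env-style timestep."""
    last: bool        # True when this is the final step of the episode
    discount: float   # in [0, 1]; assumed to be 0 only at a true terminal state


def naive_terminated(step: StepInfo) -> bool:
    # Casting the discount to bool, as described in this issue:
    # any nonzero discount reads as "terminated".
    return bool(step.discount)


def to_gym_flags(step: StepInfo) -> tuple[bool, bool]:
    """Return (terminated, truncated) in the Gymnasium sense."""
    terminated = step.last and step.discount == 0.0
    truncated = step.last and step.discount != 0.0
    return terminated, truncated


mid_episode = StepInfo(last=False, discount=1.0)
print(naive_terminated(mid_episode))  # True
print(to_gym_flags(mid_episode))      # (False, False)
```

With the naive cast, an ordinary mid-episode step with discount 1.0 is reported as terminated, which matches the symptom of the episode ending after the first step.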
What Jumanji version are you using?
v1.1.0
Which accelerator(s) are you using?
CPU
Additional System Info
Windows
Additional Context
No response
(Optional) Suggestion
No response