
bug: Dependency conflict for most recent version #276

Open
thomfoster opened this issue Mar 5, 2025 · 19 comments
Labels
bug Something isn't working

Comments

@thomfoster

Description

I cloned the latest version of the repo, and am trying to install in editable mode in order to run and edit the training scripts.
I am running the install via a requirements file that includes "jax[cuda12]" and "-e ./jumanji[dev,train]"
This results in a chex dependency conflict:

11.92 The conflict is caused by:
11.92     jumanji 1.1.0 depends on chex>=0.1.3
11.92     jumanji[dev,train] 1.1.0 depends on chex>=0.1.3
11.92     distrax 0.1.5 depends on chex>=0.1.8
11.92     esquilax 2.0.0 depends on chex<0.2.0 and >=0.1.86

Specifying chex==0.1.89 in requirements.txt leads to:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-a1724fd13edc-ae6ee8a8cc84c00f-29-62f95acc2671c, line 5; fatal   : Unsupported .version 8.3; current version is '8.1'
ptxas fatal   : Ptx assembly aborted due to errors

I tried this with Python 3.10 and Python 3.11.

CUDA information:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0
Wed Mar  5 10:11:43 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |

What Jumanji version are you using?

jumanji 1.1.0

Which accelerator(s) are you using?

GPU

Additional System Info

No response

Additional Context

No response

(Optional) Suggestion

No response

@thomfoster thomfoster added the bug Something isn't working label Mar 5, 2025
@sash-a
Collaborator

sash-a commented Mar 5, 2025

Thanks for this! @zombie-einstein, can you please unpin chex in esquilax? That should fix it, I think.

@sash-a
Collaborator

sash-a commented Mar 5, 2025

I just tested this and it seems to work fine for me @thomfoster. Can you try running the following:

uv venv -p=3.12
uv pip install -e ".[dev,train]"
uv pip install "jax[cuda12]"
python jumanji/training/train.py

Let me know if that works 😄

@thomfoster
Author

thomfoster commented Mar 5, 2025 via email

@zombie-einstein
Contributor

> Thanks for this! @zombie-einstein, can you please unpin chex in esquilax? That should fix it, I think.

Hey, just saw this, still need me to make a change?

@thomfoster
Author

> > Thanks for this! @zombie-einstein, can you please unpin chex in esquilax? That should fix it, I think.
>
> Hey, just saw this, still need me to make a change?

Yes please

@zombie-einstein
Contributor

Looking at this more closely, what is the source of the issue here? There are several releases that satisfy the bounds, unless I'm missing something?

@thomfoster
Author

thomfoster commented Mar 5, 2025

@sash-a, I created a fork with my current install setup at https://github.com/thomfoster/jumanji. It currently fails with the ptxas error if I run the jumanji install separately from the jax[cuda12] install:

# RUN python3 -m pip install -e ".[dev,train]"
# RUN python3 -m pip install "jax[cuda12]"

@zombie-einstein, if I run the fork above but do the jumanji install at the same time as the jax[cuda12] install, i.e. RUN python3 -m pip install "jax[cuda12]" -e ".[dev,train]" (forcing the requirements to be resolved together), I get the chex error above.

@zombie-einstein
Contributor

Is there more to the version conflict message though?

11.92 The conflict is caused by:
11.92     jumanji 1.1.0 depends on chex>=0.1.3
11.92     jumanji[dev,train] 1.1.0 depends on chex>=0.1.3
11.92     distrax 0.1.5 depends on chex>=0.1.8
11.92     esquilax 2.0.0 depends on chex<0.2.0 and >=0.1.86

This doesn't seem like a conflict, right? Versions 0.1.86 to 0.1.89 (the latest release) satisfy all of these bounds, don't they? Is chex forcing some requirement on the JAX version that causes an issue?

Happy to make the change; I'm just wondering what the best change is, given that distrax requires chex>=0.1.8 anyway.
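The bounds in the resolver message can be checked by hand. Below is a minimal stdlib-only Python sketch, assuming plain X.Y.Z release strings (it is not a full PEP 440 comparison, which the `packaging` library provides). The constraint table is copied from the log above; the candidate list is just the chex releases mentioned in this thread, and the helper names are illustrative:

```python
from typing import Tuple


def parse(version: str) -> Tuple[int, ...]:
    """Turn '0.1.86' into (0, 1, 86) so versions compare numerically."""
    return tuple(int(part) for part in version.split("."))


# chex constraints copied from the resolver output: (operator, bound) pairs.
CONSTRAINTS = {
    "jumanji 1.1.0": [(">=", "0.1.3")],
    "distrax 0.1.5": [(">=", "0.1.8")],
    "esquilax 2.0.0": [(">=", "0.1.86"), ("<", "0.2.0")],
}

OPS = {
    ">=": lambda a, b: a >= b,
    "<": lambda a, b: a < b,
}


def satisfies_all(version: str) -> bool:
    """True if `version` meets every bound from every package above."""
    v = parse(version)
    return all(
        OPS[op](v, parse(bound))
        for bounds in CONSTRAINTS.values()
        for op, bound in bounds
    )


# chex releases mentioned in this thread.
candidates = ["0.1.85", "0.1.86", "0.1.87", "0.1.88", "0.1.89"]
compatible = [c for c in candidates if satisfies_all(c)]
print(compatible)  # ['0.1.86', '0.1.87', '0.1.88', '0.1.89']
```

A non-empty result means these four declared ranges overlap, which hints that the resolver failure comes from some other constraint (for example, one pulled in via JAX) rather than from the chex bounds themselves.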

@thomfoster
Author

It's weird: pip looks at versions 0.1.86 to 0.1.89 but doesn't use them:

#16 7.256 INFO: pip is looking at multiple versions of esquilax to determine which version is compatible with other requirements. This could take a while.
#16 7.257 Collecting chex>=0.1.3 (from jumanji==1.1.0)
#16 7.267   Downloading chex-0.1.88-py3-none-any.whl.metadata (17 kB)
#16 7.287   Downloading chex-0.1.87-py3-none-any.whl.metadata (17 kB)
#16 7.305   Downloading chex-0.1.86-py3-none-any.whl.metadata (17 kB)
#16 7.325   Downloading chex-0.1.85-py3-none-any.whl.metadata (17 kB)
#16 7.340 ERROR: Cannot install None, jumanji and jumanji[dev,train]==1.1.0 because these package versions have conflicting dependencies.
#16 7.340 
#16 7.340 The conflict is caused by:
#16 7.340     jumanji 1.1.0 depends on chex>=0.1.3
#16 7.340     jumanji[dev,train] 1.1.0 depends on chex>=0.1.3
#16 7.340     distrax 0.1.5 depends on chex>=0.1.8
#16 7.340     esquilax 2.0.0 depends on chex<0.2.0 and >=0.1.86
#16 7.340 
#16 7.340 To fix this you could try to:
#16 7.340 1. loosen the range of package versions you've specified
#16 7.340 2. remove package versions to allow pip to attempt to solve the dependency conflict
#16 7.340 
#16 7.433 ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts
#16 ERROR: process "/bin/sh -c python3 -m pip install \"jax[cuda12]\" -e \".[dev,train]\"" did not complete successfully: exit code: 1

@zombie-einstein
Contributor

Yeah, strange. I just tested the same install command on my machine and it worked OK, so I'm not sure what is going on here. Is it easy to switch esquilax to install from a local build and update the dependency? Just to check the solution before running the release process.

If that's a massive pain, I can make a release, though it's still not clear what the fix is. Maybe remove the upper bound, and lower the lower bound to match distrax? I can't see what that would actually change, though.

@sash-a
Collaborator

sash-a commented Mar 5, 2025

Yeah, I don't think it's a version conflict, sorry @zombie-einstein. It looked like one from the original error message, where esquilax was the only package with an upper bound on chex.

@thomfoster, can you please try using uv? I've found it works perfectly for me, and you can install uv in your Dockerfile or via pip itself. Then remember to install the CUDA version of JAX as the last dependency 😄
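Since install order matters in this setup, it can help to confirm which versions actually ended up in the environment after the final JAX install. Here is a minimal sketch using only the standard library's `importlib.metadata` (Python 3.8+); the package list is simply the distributions named in this thread, and it assumes the script is run with the same interpreter (e.g. inside the uv virtual environment) that the packages were installed into:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Distributions discussed in this thread; extend the list as needed.
PACKAGES = ["jax", "jaxlib", "chex", "distrax", "esquilax", "jumanji"]


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    report: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:10} {ver or 'not installed'}")
```

Running this in both a working and a failing image would show which JAX version each resolve actually picked.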

@thomfoster
Author

@sash-a, thanks for the continued support. I ran the install with uv. It now installs without the chex error (I'm so impressed that pip couldn't resolve the conflict but uv could!). However, we get the same ptxas error as when I manually specify the chex version:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-5a8224235f31-9401eefcb4d32fa4-1-62f99c3a31a87, line 5; fatal   : Unsupported .version 8.3; current version is '8.1'
ptxas fatal   : Ptx assembly aborted due to errors

My fork here has instructions to reproduce: https://github.com/thomfoster/jumanji

@zombie-einstein
Contributor

> Yeah, I don't think it's a version conflict, sorry @zombie-einstein. It looked like one from the original error message, where esquilax was the only package with an upper bound on chex.

No worries, give me a shout if it does turn out to be anything on my end.

Could it be a CUDA 12.1 issue along the lines of jax-ml/jax#25718? That may be a clue.

@thomfoster
Author

thomfoster commented Mar 5, 2025

@sash-a, I realised I had cloned the current main branch of jumanji, not 1.1.0. https://github.com/thomfoster/jumanji/tree/1-1-0%2Bdocker contains my fork and a Dockerfile that now runs successfully with 1.1.0. Note that this also works with pip rather than uv.

Sadly I wanted to use search and rescue! Looking at the changes from 1.1.0 -> main perhaps it is the commit "fix: Require jax below 0.4.36" causing the issue?

@zombie-einstein
Contributor

> Sadly I wanted to use search and rescue! Looking at the changes from 1.1.0 -> main perhaps it is the commit "fix: Require jax below 0.4.36" causing the issue?

@thomfoster, that was just a temporary change due to a short-lived bug in JAX 0.4.36. I believe the current main branch has no upper bound.

@zombie-einstein
Contributor

I wanted to run this to check I'm not responsible for breaking something 😂. This truncated version of your Dockerfile seems to build OK from the current main with the new environment, and I can see the install in the uv virtual env:

FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04

ENV CUDA_PATH /usr/local/cuda
ENV CUDA_INCLUDE_PATH /usr/local/cuda/include
ENV CUDA_LIBRARY_PATH /usr/local/cuda/lib64

# Set timezone
ENV TZ=Europe/London DEBIAN_FRONTEND=noninteractive

RUN apt update
RUN apt install -y software-properties-common && add-apt-repository ppa:deadsnakes/ppa
RUN apt install -y \
    git \
    python3.10 \
    python3-pip \
    python3.10-venv \
    python3-setuptools \
    python3-wheel

RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
RUN update-alternatives --set python3 /usr/bin/python3.10

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install uv
RUN uv venv -p=3.10

RUN git clone https://github.com/instadeepai/jumanji.git
WORKDIR /jumanji
RUN uv pip install "jax[cuda12]" -e ".[dev,train]"
RUN uv pip list

@zombie-einstein
Contributor

Apologies! I see now that this is a runtime error; I'm seeing the same ptxas error when I try to import jumanji in an image built from that Dockerfile.

@zombie-einstein
Contributor

zombie-einstein commented Mar 6, 2025

@thomfoster, I think this is a conflict between an older version of CUDA (12.1 in this case) and a newer version of JAX. The image above was selecting JAX 0.4.38. If you downgrade JAX to 0.4.36 using uv pip install "jax[cuda12]==0.4.36", the ptxas issue is no longer present.

Another option could be upgrading the CUDA version, but the maximum for this particular image is 12.2, and that still has the same issue.

I think that because Esquilax pins JAX to a minor version (^0.4.30, i.e. below 0.5.0), it doesn't allow newer versions of JAX (0.5.0 and later) to be installed, and those also fix the ptxas issue (according to this issue, version 0.5.0 fixed some ptxas issues).

@sash-a, I'll relax some of the bounds in Esquilax; that's probably better for its use as a library.
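The ptxas message itself spells out the mismatch: the PTX handed to ptxas declares ISA version 8.3, while the installed ptxas (from CUDA 12.1, per the nvcc output earlier in the thread) only accepts up to 8.1. Below is a small Python sketch that extracts both numbers from the error text. The helper names are hypothetical, and the PTX-to-CUDA mapping is an assumption based on NVIDIA's PTX ISA release notes, namely that PTX ISA 8.N (for N from 0 to 5) first shipped with CUDA 12.N:

```python
import re
from typing import Optional, Tuple

Version = Tuple[int, int]

# Matches e.g. "Unsupported .version 8.3; current version is '8.1'"
PTXAS_PATTERN = re.compile(
    r"Unsupported \.version (\d+)\.(\d+); current version is '(\d+)\.(\d+)'"
)


def ptx_versions(message: str) -> Optional[Tuple[Version, Version]]:
    """Return (requested_ptx_isa, max_supported_ptx_isa), or None if absent."""
    match = PTXAS_PATTERN.search(message)
    if match is None:
        return None
    req_major, req_minor, sup_major, sup_minor = (int(g) for g in match.groups())
    return (req_major, req_minor), (sup_major, sup_minor)


def cuda_release_for_ptx(isa: Version) -> Optional[str]:
    """Assumed mapping: PTX ISA 8.N (N = 0..5) first shipped with CUDA 12.N."""
    major, minor = isa
    if major == 8 and 0 <= minor <= 5:
        return f"12.{minor}"
    return None  # outside the range this sketch knows about


error = (
    "jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero "
    "error code 65280, output: ptxas /tmp/tempfile-a1724fd13edc-ae6ee8a8cc84c00f-"
    "29-62f95acc2671c, line 5; fatal   : Unsupported .version 8.3; current version is '8.1'"
)
requested, supported = ptx_versions(error)
print(requested, supported)  # (8, 3) (8, 1)
print(cuda_release_for_ptx(requested), cuda_release_for_ptx(supported))  # 12.3 12.1
```

Under that assumed mapping, the generated code needs a ptxas from CUDA 12.3 or later, which would also be consistent with a 12.2 image still failing.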

@sash-a
Collaborator

sash-a commented Mar 9, 2025

@tomdbar so cool that you want to use search and rescue!

Given the ptxas error, I'd guess it is a CUDA and JAX version mismatch. Thanks for looking into this, @zombie-einstein! If you can unpin your jax and chex versions, I think that would be great; maybe just keep a minimum version for both of them? I've found that a lot of the newer JAX versions introduce serious fixes, so I think Jumanji's strategy going forward will be to specify only a minimum version.
