
PyCall is not using the packages in Conda #234

@karimn

Hi,

I'm new to Julia and PyCall. I'm trying to import the Python `transformers` package without success, and the problem appears to be that PyCall is not using the packages installed by Conda. I'm using Julia 1.8.5.

I don't understand why `pyimport("transformers")` fails with an error saying jaxlib is the wrong version (0.1.75) when I have confirmed that Conda actually has jaxlib v0.4.4 installed.
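
To see which copy of each package the interpreter would actually pick up, a short script like the following could be run with the interpreter PyCall reports (`/home/karim/.julia/conda/3/bin/python`). It is only a sketch. It uses `importlib.util.find_spec` to locate each top-level module without importing it (importing jax is what raises the error), and `importlib.metadata.version` to read the version of the first matching distribution on `sys.path`. It assumes the import names match the distribution names, which holds for `transformers`, `jax` and `jaxlib`. The function name `describe` is hypothetical.

```python
import importlib.metadata
import importlib.util
import sys


def describe(module_name):
    """Report where `module_name` would be imported from, without importing it.

    Only top-level names are looked up, so no package code (such as jax's own
    jaxlib version check) is executed.
    """
    spec = importlib.util.find_spec(module_name)
    origin = spec.origin if spec is not None else None
    try:
        # Version of the first matching distribution found on sys.path.
        version = importlib.metadata.version(module_name)
    except importlib.metadata.PackageNotFoundError:
        version = None
    return {"module": module_name, "origin": origin, "version": version}


if __name__ == "__main__":
    print("interpreter:", sys.executable)
    for name in ("transformers", "jax", "jaxlib"):
        print(describe(name))
```

If the reported `origin` for jax or jaxlib is outside `/home/karim/.julia/conda/3`, the interpreter is loading a different copy than the one Conda installed.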

Here are the steps I tried.

```julia
julia> import Pkg

julia> ENV["PYTHON"] = ""
""

julia> Pkg.build("Conda")
    Building Conda → `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/e32a90da027ca45d84678b826fffd3110bb3fc90/build.log`

julia> Pkg.build("PyCall")
    Building Conda ─→ `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/e32a90da027ca45d84678b826fffd3110bb3fc90/build.log`
    Building PyCall → `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/62f417f6ad727987c755549e9cd88c46578da562/build.log`

julia> exit()

# restart Julia so that the rebuilt PyCall is loaded

julia> import Conda

julia> import PyCall

julia> Conda.ROOTENV
"/home/karim/.julia/conda/3"

julia> PyCall.conda
true

julia> PyCall.python
"/home/karim/.julia/conda/3/bin/python"

julia> PyCall.pyprogramname
"/home/karim/.julia/conda/3/bin/python"

julia> Conda.add("transformers")
[ Info: Running `conda install -y transformers` in root environment
Collecting package metadata (current_repodata.json): done
Solving environment: done

# All requested packages already installed.

julia> PyCall.pyimport("transformers")
ERROR: PyError (PyImport_ImportModule) <class 'RuntimeError'>
RuntimeError('jaxlib is version 0.1.75, but this version of jax requires version >= 0.4.2.')
  File "/home/karim/.local/lib/python3.10/site-packages/transformers/__init__.py", line 30, in <module>
    from . import dependency_versions_check
  File "/home/karim/.local/lib/python3.10/site-packages/transformers/dependency_versions_check.py", line 17, in <module>
    from .utils.versions import require_version, require_version_core
  File "/home/karim/.local/lib/python3.10/site-packages/transformers/utils/__init__.py", line 34, in <module>
    from .generic import (
  File "/home/karim/.local/lib/python3.10/site-packages/transformers/utils/generic.py", line 36, in <module>
    import jax.numpy as jnp
  File "/home/karim/.local/lib/python3.10/site-packages/jax/__init__.py", line 35, in <module>
    from jax import config as _config_module
  File "/home/karim/.local/lib/python3.10/site-packages/jax/config.py", line 17, in <module>
    from jax._src.config import config  # noqa: F401
  File "/home/karim/.local/lib/python3.10/site-packages/jax/_src/config.py", line 28, in <module>
    from jax._src import lib
  File "/home/karim/.local/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 74, in <module>
    version = check_jaxlib_version(
  File "/home/karim/.local/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 63, in check_jaxlib_version
    raise RuntimeError(msg)

Stacktrace:
 [1] pyimport(name::String)
   @ PyCall ~/.julia/packages/PyCall/twYvK/src/PyCall.jl:558
 [2] top-level scope
   @ REPL[6]:1
```
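
Every file in the traceback is under `~/.local/lib/python3.10/site-packages` (the user site-packages directory) rather than under `/home/karim/.julia/conda/3`. As far as I understand, CPython adds the user site directory to `sys.path` ahead of the environment's own site-packages unless it is disabled (for example with the `PYTHONNOUSERSITE` environment variable), so a copy of jax installed there would shadow the Conda one. A minimal sketch, assuming it is run with the same `PyCall.python` interpreter, that lists `sys.path` in search order and flags the user site entry (the helper name `label_sys_path` is hypothetical):

```python
import os
import site
import sys


def label_sys_path(paths=None):
    """Return (position, entry, is_user_site) for each entry, in search order."""
    if paths is None:
        paths = sys.path
    # Normalize both sides so trailing slashes do not defeat the comparison.
    user_site = os.path.normpath(site.getusersitepackages())
    return [
        (position, entry, os.path.normpath(entry) == user_site)
        for position, entry in enumerate(paths)
    ]


if __name__ == "__main__":
    print("user site enabled:", site.ENABLE_USER_SITE)
    for position, entry, is_user_site in label_sys_path():
        marker = "  <-- user site-packages" if is_user_site else ""
        print(f"{position:2d} {entry}{marker}")
```

An entry flagged as the user site that appears before the Conda `site-packages` entry would be consistent with the paths in the traceback above.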
