Package Details: python-jaxlib-cuda 0.6.0-1

Git Clone URL: https://aur.archlinux.org/python-jaxlib-cuda.git (read-only)
Package Base: python-jaxlib-cuda
Description: XLA library for JAX
Upstream URL: https://github.com/jax-ml/jax/
Keywords: deep-learning google jax maching-learning xla
Licenses: Apache-2.0
Groups: jax
Conflicts: python-jaxlib
Provides: python-jaxlib
Submitter: daskol
Maintainer: daskol
Last Packager: daskol
Votes: 9
Popularity: 0.004777
First Submitted: 2023-02-12 23:18 (UTC)
Last Updated: 2025-04-20 20:54 (UTC)

Latest Comments


medaminezghal commented on 2025-10-03 08:09 (UTC)

@daskol @truncs Using the PKGBUILD file below and this merge request, I was able to get past all the problems related to building JAX against a local CUDA install, but I hit problems related to architectures and options:

clang-20: error: unknown argument: '-Xcuda-fatbinary=--compress-all'
clang-20: error: unknown argument: '-nvcc_options=expt-relaxed-constexpr'
clang-20: warning: CUDA version is newer than the latest partially supported version 12.8 [-Wunknown-cuda-version]
clang-20: error: unsupported CUDA gpu architecture: sm_88

I think that if I find a way to use NVCC instead of Clang, the problem will be solved.

This is the PKGBUILD file:

# Maintainer: Daniel Bershatsky <bepshatsky@yandex.ru>

pkgname=python-jaxlib
pkgver=0.7.2
pkgrel=1
pkgdesc='XLA library for JAX'
arch=('x86_64')
url='https://github.com/jax-ml/jax/'
license=('Apache-2.0')
groups=('jax')
depends=(
    'python-ml-dtypes'
    'python-numpy'
    'python-scipy'
)
makedepends=('clang' 'python-build' 'python-installer' 'python-setuptools' 'python-wheel')
_bazel_ver=7.4.1
_xla_commit=0fccb8a6037019b20af2e502ba4b8f5e0f98c8f6  # https://github.com/jax-ml/jax/blob/jax-v0.7.2/third_party/xla/revision.bzl
source=("jax-${pkgver}.tar.gz::$url/archive/refs/tags/jax-v${pkgver}.tar.gz"
        "openxla-xla-${_xla_commit:0:7}.tar.gz::https://api.github.com/repos/openxla/xla/tarball/$_xla_commit"
        "bazel-${_bazel_ver}-linux-x86_64::https://github.com/bazelbuild/bazel/releases/download/${_bazel_ver}/bazel-${_bazel_ver}-linux-x86_64")
noextract=("bazel-${_bazel_ver}-linux-x86_64")
sha256sums=('56d92604f1bb60bb3dbd7dc7c7dc21502d10b3474b8b905ce29ce06db6a26e45'
            '504315851ae676bf27122f20f68980fafb2a2c37e10113f58b03f6c284c55cfd'
            'c97f02133adce63f0c28678ac1f21d65fa8255c80429b588aeeba8a1fac6202b')

prepare() {
    ln -sf $(readlink bazel-${_bazel_ver}-linux-x86_64) $srcdir/jax-jax-v${pkgver}/build
    chmod +x $srcdir/bazel-${_bazel_ver}-linux-x86_64
    cd $srcdir/openxla-xla-${_xla_commit:0:7}
    sed -i 's/5354032ea08eadd7fc4456477f7f7c6308818509/54cbae0d3a67fa890b4c3d9ee162b7860315e341/g' third_party/gloo/workspace.bzl
    sed -i 's/5759a06e6c8863c58e8ceadeb56f7c701fec89b2559ba33a103a447207bf69c7/61089361dbdbc9d6f75e297148369b13f615a3e6b78de1be56cce74ca2f64940/g' third_party/gloo/workspace.bzl
}

build() {
    local CUDA_MAJOR_VERSION=$(/opt/cuda/bin/nvcc --version | sed -n 's/^.*release \([0-9]\+\).*/\1/p')
    # Override default version.
    export JAXLIB_RELEASE=$pkgver

    cd $srcdir/jax-jax-v$pkgver
    build/build.py build \
        --bazel_path="$srcdir/bazel-${_bazel_ver}-linux-x86_64" \
        --bazel_startup_options="--output_user_root=$srcdir/bazel" \
        --bazel_options='--repo_env=LOCAL_CUDA_PATH="/opt/cuda"' \
        --bazel_options='--repo_env=LOCAL_CUDNN_PATH="/opt/cuda"' \
        --bazel_options='--repo_env=LOCAL_NCCL_PATH="/usr"' \
        --bazel_options='--repo_env=LOCAL_NVSHMEM_PATH="/usr"' \
        --bazel_options='--action_env=TF_NVCC_CLANG="0"' \
        --verbose \
        --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt \
        --cuda_major_version=$CUDA_MAJOR_VERSION \
        --cuda_compute_capabilities=$(echo sm_{75,80,86,87,88,89,90,90a,100,100a,103,103a,110,110a,120,120a,121,121a} | tr ' ' ','),compute_121 \
        --clang_path=/usr/bin/clang \
        --target_cpu_features=release \
        --local_xla_path=$srcdir/openxla-xla-${_xla_commit:0:7}
}

package() {
    cd $srcdir/jax-jax-v$pkgver
    install -Dm 644 LICENSE "$pkgdir/usr/share/licenses/$pkgname/LICENSE"
    # NB: CUDA_MAJOR_VERSION is declared local to build(), so it is empty here;
    # it must be exported or recomputed, and braced so it does not merge with
    # the following '_pjrt'/'_plugin' into a different variable name.
    python -m installer --compile-bytecode=1 --destdir=$pkgdir \
        $srcdir/jax-jax-v$pkgver/dist/jaxlib-$pkgver-*.whl
    python -m installer --compile-bytecode=1 --destdir=$pkgdir \
        $srcdir/jax-jax-v$pkgver/dist/jax_cuda${CUDA_MAJOR_VERSION}_pjrt-$pkgver-*.whl
    python -m installer --compile-bytecode=1 --destdir=$pkgdir \
        $srcdir/jax-jax-v$pkgver/dist/jax_cuda${CUDA_MAJOR_VERSION}_plugin-$pkgver-*.whl
}

Could you help me to fix the problem?
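For reference, the two shell idioms the PKGBUILD above relies on can be checked in isolation. This is a sketch only: the `nvcc --version` output is a hard-coded sample so the snippet runs without a CUDA toolkit, the capability list is shortened, and bash is assumed for the brace expansion.

```shell
# 1. Extract the CUDA major version the way build() does, here from a
#    captured sample line rather than a live nvcc invocation.
nvcc_output='Cuda compilation tools, release 12.8, V12.8.93'
major=$(printf '%s\n' "$nvcc_output" | sed -n 's/^.*release \([0-9]\+\).*/\1/p')
echo "$major"    # prints 12

# 2. Build the comma-separated list for --cuda_compute_capabilities from a
#    bash brace expansion, as the PKGBUILD does (shortened list here).
caps=$(echo sm_{80,86,89,90} | tr ' ' ','),compute_90
echo "$caps"     # prints sm_80,sm_86,sm_89,sm_90,compute_90
```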

truncs commented on 2025-09-11 00:16 (UTC)

Using jax 0.7.0, clang-20, and including <cstdint> as @wheelsofindustry suggested, I was able to build jaxlib-cuda, but the jaxlib-plugin build still fails. This one also looks like it is related to the toolchain:

# Configuration: a1e1d9399d833527724a5906f125d5456703d1b9bc091a7b2e8a1e4f7ddfaf4c
# Execution platform: @@local_execution_config_platform//:platform
clang-20: warning: CUDA version is newer than the latest partially supported version 12.8 [-Wunknown-cuda-version]
In file included from <built-in>:1:
In file included from /usr/lib/clang/20/include/__clang_cuda_runtime_wrapper.h:41:
In file included from /usr/lib/clang/20/include/cuda_wrappers/cmath:28:
/usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/15.2.1/../../../../include/c++/15.2.1/cmath:55:15: fatal error: 'math.h' file not found
   55 | #include_next <math.h>
      |               ^~~~~~~~
1 error generated when compiling for sm_100.
Target //jaxlib/tools:build_gpu_kernels_wheel failed to build
INFO: Elapsed time: 1.658s, Critical Path: 0.91s
INFO: 37 processes: 34 internal, 3 local.
ERROR: Build did NOT complete successfully
ERROR: Build failed. Not running target
2025-09-10 17:09:22,680 - DEBUG - Command finished with return code 1
Traceback (most recent call last):

I will probably remove the Arch jax package and install jax with pip in a virtualenv.

truncs commented on 2025-08-27 23:36 (UTC)

This still doesn't build

Traceback (most recent call last):
  File "/home/aditya/.cache/yay/python-jaxlib-cuda/src/jax-jax-v0.6.0/build/build.py", line 778, in <module>
    asyncio.run(main())
    ~~~~~~~~~~~^^^^^^^
  File "/usr/lib/python3.13/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ~~~~~~~~~~^^^^^^
  File "/usr/lib/python3.13/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^
  File "/usr/lib/python3.13/asyncio/base_events.py", line 725, in run_until_complete
    return future.result()
           ~~~~~~~~~~~~~^^
  File "/home/aditya/.cache/yay/python-jaxlib-cuda/src/jax-jax-v0.6.0/build/build.py", line 723, in main
    raise RuntimeError(f"Command failed with return code {result.return_code}")
RuntimeError: Command failed with return code 1
==> ERROR: A failure occurred in build().
    Aborting...
 -> error making: python-jaxlib-cuda-exit status 4

What is the alternative to have jax working with cuda?

wheelsofindustry commented on 2025-05-24 20:50 (UTC) (edited on 2025-05-24 20:52 (UTC) by wheelsofindustry)

@truncs Adding an include statement for <cstdint> to types.h got me past the same issue:

--- /python-jaxlib/src/bazel/*/gloo/gloo/types.h    2023-12-02 17:32:51.000000000 -0800
+++ /python-jaxlib-cuda/src/bazel/*/external/gloo/gloo/types.h  2025-05-24 06:28:59.597042640 -0700
@@ -5,6 +5,7 @@
 #pragma once
 
 #include <iostream>
+#include <cstdint>
 
 #ifdef __CUDA_ARCH__
 #include <cuda.h>

This works on python-jaxlib and python-jaxlib-cuda, but after that it complains about the latest CUDA version (12.8) being too new. Rolling CUDA back to 12.2 requires gcc12; trying that now.

daskol commented on 2025-05-07 08:27 (UTC)

@truncs It seems that the issue is gcc-libs>=15 again (other AUR packages do not build now either). Clang uses GCC's standard library, which sometimes causes odd issues like broken include order, missing type definitions, #warning directives unexpectedly turning into errors, etc.

I'll try to build it and patch sources if possible.

truncs commented on 2025-05-06 17:46 (UTC)

Actually, forcing it to use gcc is something I am not able to do. I tried removing the clang options and setting the repo_env to point to gcc, but it still uses clang.

Additional Bazel build options: ['--action_env=JAXLIB_RELEASE', '--action_env=TF_CUDA_COMPUTE_CAPABILITIES=sm_80,sm_86,sm_89,sm_90,compute_90', '--repo_env=HERMETIC_CUDA_VERSION=12.8.0', '--repo_env=LOCAL_NCCL_PATH=/usr', '--repo_env=CC=/usr/bin/gcc', '--repo_env=CXX=/usr/bin/g++']
2025-05-06 10:41:34,195 - INFO - Bazel options written to .jax_configure.bazelrc
2025-05-06 10:41:34,195 - DEBUG - Artifacts output directory: /home/aditya/.cache/yay/python-jaxlib-cuda/src/jax-jax-v0.6.0/dist
2025-05-06 10:41:34,195 - INFO - Building jaxlib for linux x86_64...
2025-05-06 10:41:34,195 - INFO - [EXECUTING] .cache/yay/python-jaxlib-cuda/src/bazel-7.4.1-linux-x86_64 --output_user_root=.cache/yay/python-jaxlib-cuda/src/bazel run --repo_env=HERMETIC_PYTHON_VERSION=3.13 --verbose_failures=true --action_env=CLANG_COMPILER_PATH="/usr/bin/clang-19" --repo_env=CC="/usr/bin/clang-19" --repo_env=CXX="/usr/bin/clang++-19" --repo_env=BAZEL_COMPILER="/usr/bin/clang-19" --config=clang --config=mkl_open_source_only --config=avx_posix --config=cuda --config=cuda_libraries_from_stubs --action_env=CLANG_CUDA_COMPILER_PATH="/usr/bin/clang-19" --config=build_cuda_with_clang --repo_env=HERMETIC_CUDNN_VERSION=9.8.0 --action_env=JAXLIB_RELEASE --action_env=TF_CUDA_COMPUTE_CAPABILITIES=sm_80,sm_86,sm_89,sm_90,compute_90 --repo_env=HERMETIC_CUDA_VERSION=12.8.0 --repo_env=LOCAL_NCCL_PATH=/usr --repo_env=CC=/usr/bin/gcc --repo_env=CXX=/usr/bin/g++ --config=cuda_libraries_from_stubs //jaxlib/tools:build_wheel -- --output_path=".cache/yay/python-jaxlib-cuda/src/jax-jax-v0.6.0/dist" --cpu=x86_64 --jaxlib_git_hash=72273970ff20769d87749c73d893ea28171fd53d

daskol commented on 2025-05-06 15:12 (UTC) (edited on 2025-05-06 15:13 (UTC) by daskol)

@truncs It seems to be a compiler issue. Try enforcing the use of gcc13; it could solve the issue.

truncs commented on 2025-05-05 20:25 (UTC)

Getting this error for 0.6.0-1:

TF_CUDA_COMPUTE_CAPABILITIES=sm_80,sm_86,sm_89,sm_90,compute_90 \
  /usr/lib/llvm18/bin/clang-18 -MD -MF bazel-out/k8-opt/bin/external/gloo/_objs/gloo/types.pic.d '-frandom-seed=bazel-out/k8-opt/bin/external/gloo/_objs/gloo/types.pic.o' -iquote external/gloo -iquote bazel-out/k8-opt/bin/external/gloo -isystem external/gloo -isystem bazel-out/k8-opt/bin/external/gloo -fmerge-all-constants -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -Wno-invalid-partial-specialization -fno-omit-frame-pointer -no-canonical-prefixes -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections '--cuda-path=external/cuda_nvcc' '-fvisibility=hidden' -Wno-sign-compare -Wno-unknown-warning-option -Wno-stringop-truncation -Wno-array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' '-DNB_DOMAIN=jax' -Wno-gnu-offsetof-extensions -Qunused-arguments '-Werror=mismatched-tags' '-Wno-error=c23-extensions' -mavx -Wno-gnu-offsetof-extensions -Qunused-arguments '-Werror=mismatched-tags' '-Wno-error=c23-extensions' -mavx '-std=c++17' -fexceptions -Wno-unused-variable -c external/gloo/gloo/types.cc -o bazel-out/k8-opt/bin/external/gloo/_objs/gloo/types.pic.o)
# Configuration: 6f9d4bb27a6c9bc488a0cd126309e9bf2df92d2e54fc24d1a8d5d26084fd8413
# Execution platform: @@local_execution_config_platform//:platform
In file included from external/gloo/gloo/types.cc:9:
external/gloo/gloo/types.h:66:11: error: unknown type name 'uint8_t'
   66 | constexpr uint8_t kGatherSlotPrefix = 0x01;
      |           ^
external/gloo/gloo/types.h:67:11: error: unknown type name 'uint8_t'
   67 | constexpr uint8_t kAllgatherSlotPrefix = 0x02;
      |           ^
external/gloo/gloo/types.h:68:11: error: unknown type name 'uint8_t'
   68 | constexpr uint8_t kReduceSlotPrefix = 0x03;
      |           ^
external/gloo/gloo/types.h:69:11: error: unknown type name 'uint8_t'
   69 | constexpr uint8_t kAllreduceSlotPrefix = 0x04;
      |           ^
external/gloo/gloo/types.h:70:11: error: unknown type name 'uint8_t'
   70 | constexpr uint8_t kScatterSlotPrefix = 0x05;
      |           ^
external/gloo/gloo/types.h:71:11: error: unknown type name 'uint8_t'
   71 | constexpr uint8_t kBroadcastSlotPrefix = 0x06;
      |           ^
external/gloo/gloo/types.h:72:11: error: unknown type name 'uint8_t'
   72 | constexpr uint8_t kBarrierSlotPrefix = 0x07;
      |           ^
external/gloo/gloo/types.h:73:11: error: unknown type name 'uint8_t'
   73 | constexpr uint8_t kAlltoallSlotPrefix = 0x08;
      |           ^
external/gloo/gloo/types.h:77:21: error: unknown type name 'uint8_t'
   77 |   static Slot build(uint8_t prefix, uint32_t tag);
      |                     ^
external/gloo/gloo/types.h:77:37: error: unknown type name 'uint32_t'
   77 |   static Slot build(uint8_t prefix, uint32_t tag);
      |                                     ^
external/gloo/gloo/types.h:79:12: error: unknown type name 'uint64_t'
   79 |   operator uint64_t() const {
      |            ^
external/gloo/gloo/types.h:86:17: error: unknown type name 'uint64_t'
   86 |   explicit Slot(uint64_t base, uint64_t delta) : base_(base), delta_(delta) {}
      |                 ^
external/gloo/gloo/types.h:86:32: error: unknown type name 'uint64_t'
   86 |   explicit Slot(uint64_t base, uint64_t delta) : base_(base), delta_(delta) {}
      |                                ^
external/gloo/gloo/types.h:88:9: error: unknown type name 'uint64_t'
   88 |   const uint64_t base_;
      |         ^
external/gloo/gloo/types.h:89:9: error: unknown type name 'uint64_t'
   89 |   const uint64_t delta_;
      |         ^
external/gloo/gloo/types.h:97:3: error: unknown type name 'uint16_t'
   97 |   uint16_t x;
      |   ^
external/gloo/gloo/types.cc:16:18: error: unknown type name 'uint8_t'
   16 | Slot Slot::build(uint8_t prefix, uint32_t tag) {
      |                  ^
external/gloo/gloo/types.cc:16:34: error: unknown type name 'uint32_t'
   16 | Slot Slot::build(uint8_t prefix, uint32_t tag) {
      |                                  ^
external/gloo/gloo/types.cc:17:3: error: unknown type name 'uint64_t'
   17 |   uint64_t u64prefix = ((uint64_t)prefix) << 56;
      |   ^
fatal error: too many errors emitted, stopping now [-ferror-limit=]
20 errors generated.

I did a quick search upstream and I can't find anything related to this. Thoughts?

medaminezghal commented on 2025-04-22 19:02 (UTC)

@daskol The build fails because you use --bazel_options="--action_env=TF_CUDA_COMPUTE_CAPABILITIES=${CUDA_COMPUTE_CAPABILITIES}".

It will not succeed because clang can't compile for sm_100 and sm_120, which are listed in .bazelrc.

Instead, use --bazel_options="--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES=${CUDA_COMPUTE_CAPABILITIES}".

And I have some suggestions:

1. Use the default Clang provided by the system (the previous problem was fixed).

2. Add support for more cards by using: export CUDA_COMPUTE_CAPABILITIES=sm_50,sm_52,sm_53,sm_60,sm_61,sm_62,sm_70,sm_72,sm_75,sm_80,sm_86,sm_87,sm_89,sm_90,sm_90a,compute_90
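As a sanity check of the suggestion above, the replacement option string can be constructed and inspected on its own before wiring it into the PKGBUILD. This sketch only builds and prints the string (shortened capability list; nothing is passed to build.py here):

```shell
# HERMETIC_CUDA_COMPUTE_CAPABILITIES is a repo_env consumed by XLA's
# hermetic CUDA rules, replacing the action_env TF_CUDA_COMPUTE_CAPABILITIES.
CUDA_COMPUTE_CAPABILITIES=sm_75,sm_80,sm_86,sm_89,sm_90,compute_90
opt="--bazel_options=--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES=${CUDA_COMPUTE_CAPABILITIES}"
echo "$opt"
```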

daskol commented on 2025-03-20 22:39 (UTC)

@medaminezghal Thanks for your PKGBUILD. My build server was under maintenance for a while.

However, there is a build issue for 0.5.3 regarding the transition to bazel 7. It turns out that the Python rules do not handle symlink creation properly, which is why python -m build fails with a missing Lorem ipsum.txt file.

I'm looking for a solution. Manually downgrading to bazel 6 will probably solve the issue.