JAX安装

本文安装版本为:jax0.2.21、jaxlib0.1.72、tensorflow2.7
jax下载地址:https://github.com/google/jax/releases
tensorflow下载地址:https://github.com/tensorflow/tensorflow

1 安装CPU版本

pip install "jax[cpu]===0.2.21" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

2 安装GPU版本

2.1 pip安装jax和jaxlib

(1)安装jax:pip install -v jax==0.2.21

(2)安装jaxlib:
版本对应链接:https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
下载链接:https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.72+cuda111-cp38-none-manylinux2010_x86_64.whl

1>在线安装:pip install jaxlib==0.1.72+cuda111 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2>本地安装:pip install jaxlib-0.1.72+cuda111-cp38-none-manylinux2010_x86_64.whl

2.2 源码安装

(1)安装必要的python依赖:

pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install numpy==1.19.5 six wheel scipy==1.7.3 -i https://pypi.tuna.tsinghua.edu.cn/simple/

(2)使用修改后的tensorflow存储库从源代码构建jaxlib:

cd /public/home/xxx/data/jax/build/build.py
python build.py --enable_cuda --bazel_options=--override_repository=org_tensorflow=/public/home/xxx/data/tensorflow

(3)安装新构建的jaxlib轮子:

pip install /public/home/xxx/data/jax/build/dist/jaxlib-0.3.15-cp38-none-manylinux2014_x86_64.whl

编译成功如下图:
在这里插入图片描述
(4)安装jax:
在jax主目录执行:pip install -e .

猜你喜欢

转载自blog.csdn.net/weixin_50008473/article/details/126589113