安装适用于 CUDA 11.8 的 JAXlib 需要执行以下步骤:
确保您的系统已经安装了 CUDA 11.8 和相应的 GPU 驱动程序。
创建一个虚拟环境(可选但推荐),以防止与现有的 Python 环境产生冲突。执行以下命令创建虚拟环境:
python -m venv jax_env
激活虚拟环境。在 Windows 上执行以下命令:
jax_env\Scripts\activate
在 macOS/Linux 上执行以下命令:
source jax_env/bin/activate
安装 JAXlib。执行以下命令:
pip install jax jaxlib==0.1.69+cuda11.8 -f https://storage.googleapis.com/jax-releases/jax_releases.html
这将安装 JAX 和适用于 CUDA 11.8 的 JAXlib。
验证安装。在 Python 解释器中执行以下命令:
import jax
import jaxlib
print("JAX version:", jax.__version__)
print("JAXlib version:", jaxlib.__version__)
如果没有错误,并且输出中显示了正确的版本号,则说明 JAXlib 已成功安装并与 CUDA 11.8 兼容。
请注意,上述安装方法假设您已经正确安装了 CUDA 11.8 和相应的 GPU 驱动程序。如果遇到安装问题,请根据具体错误信息检查您的 CUDA 和 GPU 驱动程序版本,并确保它们与 JAXlib 兼容。