GPU setup

If you have access to a GPU with CUDA support, dolphin can run considerably faster. dolphin uses both Numba and JAX, each of which has a slightly different setup:

  • Numba instructions: https://numba.readthedocs.io/en/stable/cuda/overview.html#software
  • JAX instructions: https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu

Both require you to install packages that match your CUDA version. You can check the CUDA version with nvidia-smi, which reports the highest CUDA version your installed driver supports:

$ nvidia-smi | grep -i version
| NVIDIA-SMI 510.39.01    Driver Version: 510.39.01    CUDA Version: 11.6     |

Here the reported version is 11.6, so for Numba we would install

mamba install cudatoolkit=11.6
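
The version check above can also be scripted. The following is a minimal sketch, not part of dolphin: the helper names `parse_cuda_version` and `detect_cuda_version` are hypothetical, and the regular expression assumes the nvidia-smi header keeps the "CUDA Version: X.Y" format shown above.

```python
import re
import subprocess
from typing import Optional


def parse_cuda_version(smi_output: str) -> Optional[str]:
    """Return the 'CUDA Version' field from nvidia-smi output, or None."""
    match = re.search(r"CUDA Version:\s*(\d+\.\d+)", smi_output)
    return match.group(1) if match else None


def detect_cuda_version() -> Optional[str]:
    """Run nvidia-smi and parse its header; None if it cannot be run."""
    try:
        result = subprocess.run(
            ["nvidia-smi"], capture_output=True, text=True, check=True
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None
    return parse_cuda_version(result.stdout)


# Parse the sample header line shown above (no GPU needed for this part).
sample = "| NVIDIA-SMI 510.39.01    Driver Version: 510.39.01    CUDA Version: 11.6     |"
version = parse_cuda_version(sample)
print(f"mamba install cudatoolkit={version}")  # prints: mamba install cudatoolkit=11.6
```

On a machine without an NVIDIA driver, `detect_cuda_version()` returns None rather than raising.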

For JAX, we would run

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Finally, installing pynvml enables extra GPU memory tracking:

mamba install pynvml
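
To illustrate the kind of information pynvml exposes, here is a minimal sketch that queries used and total memory for one GPU. It is not how dolphin performs its tracking; `gpu_memory_mib` is a hypothetical helper that returns None when pynvml or the NVIDIA driver library is unavailable.

```python
from typing import Optional


def gpu_memory_mib(index: int = 0) -> Optional[dict]:
    """Return used/total memory in MiB for one GPU, or None if NVML is unavailable."""
    try:
        import pynvml
    except ImportError:
        return None  # pynvml is not installed

    try:
        pynvml.nvmlInit()
    except pynvml.NVMLError:
        return None  # no NVIDIA driver library found

    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(index)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return {"used": info.used / 2**20, "total": info.total / 2**20}
    except pynvml.NVMLError:
        return None  # e.g. no device at this index
    finally:
        pynvml.nvmlShutdown()


print(gpu_memory_mib())
```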

JAX GPU notes

In shared environments, note that by default JAX preallocates 75% of the GPU's memory when the first JAX operation runs, which reduces allocation overhead and memory fragmentation. See the JAX documentation on GPU memory allocation for configuration options that disable or limit this.
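
As a sketch of those options, JAX reads the environment variables `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` when it initializes the GPU backend. The helper name `configure_jax_memory` below is hypothetical; the important assumption is that it runs before the first `import jax` in the process.

```python
import os
from typing import Optional


def configure_jax_memory(preallocate: bool = False, fraction: Optional[float] = None) -> None:
    """Set JAX GPU memory options; call this before importing jax."""
    # "false" makes JAX allocate GPU memory on demand instead of up front.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if fraction is not None:
        # With preallocation on, this caps the share of GPU memory reserved.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"


configure_jax_memory(preallocate=False)
# import jax  # import only after the variables are set
```

The same variables can instead be exported in the shell before launching the program. Disabling preallocation makes it easier to share a GPU, at the cost of more fragmentation risk.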

For further optimizations and tuning for your specific GPU, see the profiling notes.