GPU-Accelerated Computing with Python 3 and CUDA
JAX offers a clean path to distributed machine learning. Its high-level API handles device placement and parallel execution for us, so data processing or model training can be spread across multiple GPUs with minimal changes to the code. We will specifically use the jax.pmap (parallel map) function, which parallelizes a computation across multiple devices. pmap maps a function over the leading axis of its inputs and runs each slice on a separate device, so a full batch is split into one shard per GPU by reshaping it so that its leading dimension equals the number of devices. Inside the mapped function, collective operations such as jax.lax.pmean average gradients across devices, which gives us synchronous data-parallel training with very little boilerplate. This chapter focuses on multi-GPU computing; readers new to JAX should refer to Chapter 10 for an introduction.
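The data-parallel pattern described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the book's own example: it trains a toy linear-regression model that does not appear in the text, and the names `shard_batch`, `loss_fn`, `train_step`, `LR` and the axis name `"devices"` are hypothetical, chosen only for illustration. It reshapes a full batch so each device receives one shard, replicates the weights, and calls `jax.lax.pmean` inside a `pmap`-ed step so every replica applies the same averaged update. It assumes the batch size divides evenly by the device count.

```python
import jax
import jax.numpy as jnp

# Number of devices JAX can see on this host (1 on a CPU-only machine).
n_devices = jax.local_device_count()
LR = 0.1  # hypothetical learning rate for the toy example


def shard_batch(x):
    """Reshape a full batch to (n_devices, batch // n_devices, ...) for pmap."""
    batch = x.shape[0]
    assert batch % n_devices == 0, "batch size must divide evenly across devices"
    return x.reshape((n_devices, batch // n_devices) + x.shape[1:])


def loss_fn(w, x, y):
    # Mean squared error of a toy linear model (illustration only).
    return jnp.mean((x @ w - y) ** 2)


def train_step(w, x, y):
    # Each device computes the loss and gradients on its own shard...
    loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
    # ...then pmean averages them over the "devices" axis, so every
    # replica applies the same synchronized update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    return w - LR * grads, loss


# axis_name links the pmean collectives to the mapped device axis.
p_train_step = jax.pmap(train_step, axis_name="devices")

# Synthetic data: 8 examples per device, 3 features.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8 * n_devices, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])

w = jnp.zeros(3)
initial_loss = loss_fn(w, x, y)

# Replicate the weights (one copy per device) and shard the batch.
w_rep = jnp.broadcast_to(w, (n_devices,) + w.shape)
xs, ys = shard_batch(x), shard_batch(y)

for _ in range(100):
    w_rep, loss = p_train_step(w_rep, xs, ys)
```

On a CPU-only or single-GPU machine `jax.local_device_count()` returns 1, so the same code runs with a single shard; on a multi-GPU host it splits the batch across all visible GPUs without modification. Because the shards are equal in size, averaging the per-shard losses with `pmean` gives the same value as the mean loss over the full batch.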
We will demonstrate how the previous matrix multiplication example can be executed across multiple GPUs using JAX. We follow the same parallelization strategy by distributing the rows...