GPU-Accelerated Computing with Python 3 and CUDA
In this chapter, we explored how the high-level JAX library harnesses GPU computing, working through several optimization examples. We began by demonstrating the impact of JIT compilation, which yielded performance gains. Next, we explored automatic differentiation with JAX's autodiff features and combined it with JIT. Finally, vmap allowed us to auto-vectorize our functions, further improving performance.
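As a quick reminder of how these three transformations compose, here is a minimal sketch. The function `loss` and the sample arrays are hypothetical stand-ins, not the chapter's own examples; only `jax.jit`, `jax.grad`, and `jax.vmap` are actual JAX APIs.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical toy function: scalar output from parameters w and one input x
    return jnp.sum((w * x) ** 2)

# grad differentiates with respect to the first argument (w) by default
grad_loss = jax.grad(loss)

# vmap maps over the leading axis of x while sharing w across the batch;
# jit then compiles the whole batched gradient computation
batched_grad = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 1.0], [2.0, 0.5]])  # made-up batch of two inputs
grads = batched_grad(w, xs)  # one gradient per row of xs, shape (2, 2)
```

The order of composition here is one reasonable choice: differentiating the single-example function first and vectorizing afterwards keeps the original function free of batch dimensions.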
We then used JAX to build a linear regression model and calculate electrical resistance. Next, we extended our approach by implementing a neural network from scratch and used it to describe the response of an RLC circuit.
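The regression idea can be sketched in a few lines. This sketch assumes the resistance is estimated by fitting Ohm's law, V = R·I, to current and voltage readings with gradient descent; the data, the names `mse` and `step`, and the learning rate are all made up for illustration and may differ from the chapter's implementation.

```python
import jax
import jax.numpy as jnp

# Synthetic, noise-free measurements (made up): V = R * I with R = 5 ohms
current = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5])
voltage = 5.0 * current

def mse(r, i, v):
    # Mean squared error between the Ohm's-law prediction r * i and the data v
    return jnp.mean((r * i - v) ** 2)

@jax.jit
def step(r, i, v, lr):
    # One gradient-descent update on the single parameter r
    return r - lr * jax.grad(mse)(r, i, v)

r = 0.0  # initial guess for the resistance
for _ in range(500):
    r = step(r, current, voltage, 1.0)
```

With noise-free data the estimate `r` should settle close to the value used to generate the measurements; with real readings it would instead approach the least-squares fit.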
To further improve training, we incorporated physical laws into the loss function, creating a physics-informed neural network. This enabled our model to extrapolate beyond its training data, showcasing its advantage over the conventional neural network.
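To make the idea of a physics term in the loss concrete, here is a minimal sketch. It assumes an undriven series RLC circuit governed by L·q'' + R·q' + q/C = 0 and uses a tiny one-hidden-layer network as a stand-in model; the component values, the names `model`, `residual`, and `pinn_loss`, and the `weight` parameter are hypothetical and not taken from the chapter.

```python
import jax
import jax.numpy as jnp

# Hypothetical component values for a series RLC circuit
R, L, C = 1.0, 1.0, 1.0

def model(params, t):
    # Stand-in network: one hidden tanh layer mapping scalar time t to charge q
    w1, b1, w2 = params
    return jnp.dot(w2, jnp.tanh(w1 * t + b1))

def residual(params, t):
    # ODE residual L*q'' + R*q' + q/C, with time derivatives taken by autodiff
    dq = jax.grad(model, argnums=1)
    d2q = jax.grad(dq, argnums=1)
    return L * d2q(params, t) + R * dq(params, t) + model(params, t) / C

def pinn_loss(params, t_data, q_data, t_phys, weight=1.0):
    # Data term: fit the measured points
    pred = jax.vmap(model, in_axes=(None, 0))(params, t_data)
    data_term = jnp.mean((pred - q_data) ** 2)
    # Physics term: penalize ODE violations at collocation times t_phys,
    # which may lie outside the range covered by the data
    res = jax.vmap(residual, in_axes=(None, 0))(params, t_phys)
    return data_term + weight * jnp.mean(res ** 2)

# Made-up inputs just to show the call; zero output weights give q(t) = 0
params = (jnp.ones(4), jnp.zeros(4), jnp.zeros(4))
t_data = jnp.array([0.0, 1.0])
q_data = jnp.array([1.0, 0.0])
t_phys = jnp.linspace(0.0, 5.0, 20)
loss_value = pinn_loss(params, t_data, q_data, t_phys)
grads = jax.grad(pinn_loss)(params, t_data, q_data, t_phys)
```

Because the collocation times need no measured targets, they can extend past the data range, which is one plausible reason a physics-informed model extrapolates better than a purely data-driven one.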
In the next...