2.3 Exploring the Gaussian process

As we’ve seen in the previous section, sampling quickly becomes prohibitively expensive. To address this, we can use ML models specifically designed to produce uncertainty estimates – the gold standard of which is the Gaussian process.

The Gaussian process, or GP, has become a staple probabilistic ML model, seeing use in a broad variety of applications from pharmacology through to robotics. Its success is largely down to its ability to produce high-quality uncertainty estimates over its predictions in a well-principled fashion. So, what do we mean by a Gaussian process?

In essence, a GP is a distribution over functions. To understand what we mean by this, let’s take a typical ML use case. We want to learn some function f(x), which maps a set of inputs x onto a set of outputs y, such that we can approximate our output via y = f(x). Before we see any data, we know nothing about our underlying function; there are infinitely many possible functions it could be:

Figure 2.10: Illustration of space of possible functions before seeing data

Here, the black line is the true function we wish to learn, while the dotted lines are the possible functions given the data (in this case, no data). Once we observe some data, we see that the number of possible functions becomes more constrained, as we see here:

Figure 2.11: Illustration of space of possible functions after seeing some data

Here, we see that our possible functions all pass through our observed data points, but outside of those data points, our functions take on a range of very different values. In a simple linear model, we don’t care about these deviations in possible values: we’re happy to interpolate from one data point to another, as we see in Figure 2.12:

Figure 2.12: Illustration of linearly interpolating through our observations

But this interpolation can lead to wildly inaccurate predictions, and has no way of accounting for the degree of uncertainty associated with our model predictions. The deviations that we see here in the regions without data points are exactly what we want to capture with our GP. When there are a variety of possible values our function can take, then there is uncertainty – and through capturing the degree of uncertainty, we are able to estimate what the possible variation in these regions may be.
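
Before formalizing this, it can help to see what “a distribution over functions” looks like in code. The following is a minimal sketch (plain NumPy, with purely illustrative values) that draws a handful of random functions from a zero-mean GP prior using a squared exponential kernel – a kernel we’ll meet properly shortly. Each draw corresponds to one of the “possible functions” in Figure 2.10:

    import numpy as np

    def rbf_kernel(x1, x2, sigma=1.0, length_scale=1.0):
        # Squared exponential kernel between two sets of 1D points
        sq_dist = (x1[:, None] - x2[None, :]) ** 2
        return sigma ** 2 * np.exp(-sq_dist / (2 * length_scale ** 2))

    # Input locations at which we evaluate the sampled functions
    x = np.linspace(0, 5, 100)

    # Covariance of the GP prior; a small jitter keeps it positive definite
    K = rbf_kernel(x, x) + 1e-8 * np.eye(len(x))

    # Each draw from this multivariate Gaussian is one "possible function"
    rng = np.random.default_rng(0)
    samples = rng.multivariate_normal(np.zeros(len(x)), K, size=5)
    print(samples.shape)  # (5, 100): five functions, each evaluated at 100 points

Plotting each row of samples against x gives curves much like the dotted lines in Figure 2.10; conditioning on observed data, which we’ll do shortly, is what narrows this space down to something like Figure 2.11.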

Formally, a GP can be defined as a function:

f(x) ∼ GP(m(x), k(x,x′))

Here, m(x) is simply the mean of our possible function values for a given point x:

m(x) = 𝔼[f(x)]

The next term, k(x,x′), is a covariance function, or kernel. This is a fundamental component of the GP, as it defines the way we model the relationship between different points in our data. GPs use the mean and covariance functions to model the space of possible functions, and thus to produce predictions as well as their associated uncertainties. Now that we’ve introduced some of the high-level concepts, let’s dig a little deeper and understand exactly how they model the space of possible functions, and thus estimate uncertainty. To do this, we need to understand GP priors.

2.3.1 Defining our prior beliefs with kernels

GP kernels describe the prior beliefs we have about our data, and so you’ll often see them referred to as GP priors. In the same way that the prior in equation 2.3 tells us something about the probability of the outcome of our two dice rolls, the GP prior tells us something important about the relationship we expect from our data.

While there are advanced methods for inferring a prior from our data, they are beyond the scope of this book. We will instead focus on more traditional uses of GPs, for which we select a prior using our knowledge of the data we’re working with.

In the literature and any implementations you encounter, you’ll see that the GP prior is often referred to as the kernel or covariance function (just as we have here). These three terms are interchangeable, but for consistency with other work, we will henceforth refer to this as the kernel. Kernels simply provide a means of quantifying how similar two data points are, and are expressed as k(x,x′), where x and x′ are data points, and k(⋅) represents the kernel function. While the kernel can take on many forms, there are a small number of fundamental kernels that are used in a large proportion of GP applications.

Perhaps the most commonly encountered kernel is the squared exponential or radial basis function (RBF) kernel. This kernel takes the form:

k(x,x′) = σ² exp(−(x − x′)² / (2l²))

This introduces us to a couple of common kernel parameters: l and σ². The output variance parameter σ² is simply a scaling factor, used to control how far the function can deviate from its mean. The length scale parameter l controls the smoothness of the function – in other words, how quickly the function is expected to vary across particular dimensions. This parameter can either be a scalar that is applied to all input dimensions, or a vector with a different value for each input dimension. The latter is often achieved using Automatic Relevance Determination, or ARD, which learns a separate length scale for each input dimension and so identifies which dimensions are most relevant.
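
To get a feel for what these parameters do, here’s a small sketch (plain NumPy, with illustrative values) that evaluates the RBF kernel between a point at 0 and points progressively further away, for two different length scales:

    import numpy as np

    def rbf(x, x_prime, sigma=1.0, length_scale=1.0):
        # k(x, x') = sigma^2 * exp(-(x - x')^2 / (2 * l^2))
        return sigma ** 2 * np.exp(-((x - x_prime) ** 2) / (2 * length_scale ** 2))

    # Covariance between a point at 0 and points progressively further away
    distances = np.array([0.0, 0.5, 1.0, 2.0, 4.0])
    for l in (0.5, 2.0):
        print(f"l = {l}:", np.round(rbf(0.0, distances, length_scale=l), 3))

With the larger length scale, distant points remain strongly correlated, which corresponds to smoother functions; σ² simply scales all of these covariance values up or down.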

GPs make predictions via a covariance matrix based on the kernel – essentially comparing a new data point to previously observed data points. However, just as with all ML models, GPs need to be trained, and this is where the length scale comes in. The kernel hyperparameters, such as the length scale, are the parameters of our GP, and the training process learns their optimal values. This is typically done using a nonlinear optimizer, such as the Broyden-Fletcher-Goldfarb-Shanno (BFGS) optimizer, though many others can be used, including optimizers you may be familiar with from deep learning, such as stochastic gradient descent and its variants.
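
In practice, you’ll rarely implement this optimization yourself. As a rough sketch of what training looks like with an off-the-shelf library (the toy data and parameter values here are purely illustrative), here’s how you might fit a GP with scikit-learn’s GaussianProcessRegressor, which by default maximizes the log marginal likelihood with an L-BFGS-B optimizer to find the kernel hyperparameters:

    import numpy as np
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.gaussian_process.kernels import RBF, ConstantKernel

    # Toy 1D regression data, standing in for a real dataset
    rng = np.random.default_rng(0)
    X_train = rng.uniform(0, 5, size=(8, 1))
    y_train = np.sin(X_train).ravel()

    # ConstantKernel provides the output variance (sigma^2); RBF provides the length scale (l)
    kernel = ConstantKernel(1.0) * RBF(length_scale=1.0)

    # fit() tunes the kernel hyperparameters by maximizing the log marginal likelihood
    gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=5, random_state=0)
    gp.fit(X_train, y_train)
    print("Optimized kernel:", gp.kernel_)

    # Predictions come with a standard deviation: the predictive uncertainty
    X_test = np.linspace(0, 5, 50).reshape(-1, 1)
    mean, std = gp.predict(X_test, return_std=True)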

Let’s take a look at how different kernels affect GP predictions. We’ll start with a straightforward example – a simple sine wave:

Figure 2.13: Plot of sine wave with four sampled points

We can see the function illustrated here, as well as some points sampled from this function. Now, let’s fit a GP with a periodic kernel to the data. The periodic kernel is defined as:

kper(x,x′) = σ² exp(−2 sin²(π|x − x′| / p) / l²)

Here, we see a new parameter: p. This is simply the period of the periodic function. Setting p = 1 and applying a GP with a periodic kernel to the preceding example, we get the following:

Figure 2.14: Plot of posterior predictions from a periodic kernel with p = 1

This looks pretty noisy, but you should be able to see that there is clear periodicity in the functions produced by the posterior. It’s noisy for a couple of reasons: a lack of data, and a poor prior. If we’re limited on data, we can try to fix the problem by improving our prior. In this case, we can use our knowledge of the periodicity of the function to improve our prior by setting p = 6:

Figure 2.15: Plot of posterior predictions from a periodic kernel with p = 6

We see that this fits the data pretty well: we’re still uncertain in regions for which we have little data, but the periodicity of our posterior now looks sensible. This is possible because we’re using an informative prior; that is, a prior that incorporates information that describes the data well. This prior is composed of two key components:

  • Our periodic kernel

  • Our knowledge about the periodicity of the function
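
If you want to see the effect of the period for yourself, the p = 1 versus p = 6 comparison can be sketched with scikit-learn’s ExpSineSquared kernel, whose periodicity parameter plays the role of p (the training points below are illustrative stand-ins for those in Figure 2.13, not the exact data behind these plots):

    import numpy as np
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.gaussian_process.kernels import ExpSineSquared

    # Four points sampled from a sine wave, roughly mirroring Figure 2.13
    X_train = np.array([[0.5], [2.0], [4.0], [5.5]])
    y_train = np.sin(X_train).ravel()
    X_test = np.linspace(0, 10, 200).reshape(-1, 1)

    for p in (1.0, 6.0):
        kernel = ExpSineSquared(length_scale=1.0, periodicity=p)
        # optimizer=None keeps the hyperparameters fixed, so p stays at the value we set
        gp = GaussianProcessRegressor(kernel=kernel, optimizer=None)
        gp.fit(X_train, y_train)
        mean, std = gp.predict(X_test, return_std=True)
        print(f"p = {p}: average predictive std = {std.mean():.3f}")

Plotting the mean and standard deviation over X_test for each value of p should reproduce the qualitative difference between Figures 2.14 and 2.15.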

We can see how important this is if we modify our GP to use an RBF kernel:

Figure 2.16: Plot of posterior predictions from an RBF kernel

With an RBF kernel, we see that things are looking pretty chaotic again: because we have limited data and a poor prior, we’re unable to appropriately constrain the space of possible functions to fit our true function. In the ideal case, we’d fix this by using a more appropriate prior, as we saw in Figure 2.15 – but this isn’t always possible. Another solution is to sample more data. Sticking with our RBF kernel, we sample 10 data points from our function and re-train our GP:

Figure 2.17: Plot of posterior predictions from an RBF kernel, trained on 10 observations

This is looking much better – but what if we have more data and an informative prior?

Figure 2.18: Plot of posterior predictions from a periodic kernel with p = 6, trained on 10 observations

The posterior now fits our true function very closely. Because we don’t have infinite data, there are still some areas of uncertainty, but the uncertainty is relatively small.
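
To reproduce this kind of comparison yourself, a quick sketch along the following lines (again scikit-learn, with illustrative data rather than the exact samples used in the figures) fits the same RBF-kernel GP on 4 and then 10 observations, keeping the hyperparameters fixed so that the only thing changing is the amount of data:

    import numpy as np
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.gaussian_process.kernels import RBF

    rng = np.random.default_rng(1)
    X_test = np.linspace(0, 10, 200).reshape(-1, 1)

    for n_obs in (4, 10):
        X_train = np.sort(rng.uniform(0, 10, size=(n_obs, 1)), axis=0)
        y_train = np.sin(X_train).ravel()
        # optimizer=None keeps the kernel fixed; only the data changes between runs
        gp = GaussianProcessRegressor(kernel=RBF(length_scale=1.0), optimizer=None)
        gp.fit(X_train, y_train)
        _, std = gp.predict(X_test, return_std=True)
        print(f"{n_obs} observations: average predictive std = {std.mean():.3f}")

More observations shrink the predictive uncertainty across the input range, mirroring the move from Figure 2.16 to Figure 2.17.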

Now that we’ve seen some of the core principles in action, let’s return to our example from Figures 2.10-2.12. Here’s a quick reminder of our target function, our posterior samples, and the linear interpolation we saw earlier:

Figure 2.19: Plot illustrating the difference between linear interpolation and the true function

Now that we’ve got some idea of how a GP will affect our predictive posterior, it’s easy to see that linear interpolation falls very short of what we achieve with a GP. To illustrate this more clearly, let’s take a look at what the GP prediction would be for this function given the three samples:

Figure 2.20: Plot illustrating the difference between GP predictions and the true function

Here, the dotted lines are our mean (μ) predictions from the GP, and the shaded area is the uncertainty associated with those predictions – the standard deviation (σ) around the mean. Let’s contrast what we see in Figure 2.20 with Figure 2.19. The differences may seem subtle at first, but we can clearly see that this is no longer a straightforward linear interpolation: the predicted values from the GP are being “pulled” toward our actual function values. As with our earlier sine wave examples, the behavior of the GP predictions is affected by two key factors: the prior (or kernel) and the data.

But there’s another crucial detail illustrated in Figure 2.20: the predictive uncertainties from our GP. We see that, unlike many typical ML models, a GP gives us uncertainties associated with its predictions. This means we can make better decisions about what we do with the model’s predictions – having this information will help us to ensure that our systems are more robust. For example, if the uncertainty is too great, we can fall back to a manual system. We can even keep track of data points with high predictive uncertainty so that we can continuously refine our models.
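
As a concrete (and deliberately simplified) illustration, a fallback rule might look something like the following sketch – the threshold and function name are illustrative, and any fitted model whose predict method returns a standard deviation alongside the mean would work here:

    import numpy as np

    def predict_with_fallback(model, X, std_threshold=0.2):
        # `model` is any fitted regressor that can return predictive standard
        # deviations, e.g. scikit-learn's GaussianProcessRegressor
        mean, std = model.predict(X, return_std=True)
        needs_review = std > std_threshold
        # Keep the uncertain inputs so they can be handled manually and, later,
        # labelled and used to refine the model
        review_queue = X[needs_review]
        return mean, std, review_queue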

We can see how this refinement affects our predictions by adding a few more observations – just as we did in the earlier examples:

Figure 2.21: Plot illustrating the difference between GP predictions and the true function, trained on 5 observations

Figure 2.21 illustrates how our uncertainty changes over regions with different numbers of observations. We see here that between x = 3 and x = 4 our uncertainty is quite high. This makes a lot of sense, as we can also see that our GP’s mean predictions deviate significantly from the true function values. Conversely, if we look at the region between x = 0.5 and x = 2, we can see that our GP’s predictions follow the true function fairly closely, and our model is also more confident about these predictions, as we can see from the smaller interval of uncertainty in this region.

What we’re seeing here is a great example of a very important concept: well-calibrated uncertainty – also termed high-quality uncertainty. This means that, in regions where our predictions are inaccurate, our uncertainty is also high. Our uncertainty estimates are poorly calibrated if we’re very confident in regions with inaccurate predictions, or very uncertain in regions with accurate predictions.
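
A simple, if rough, way to check this on held-out data is to measure how often the true values fall inside the predicted uncertainty interval. The sketch below uses a 2σ interval, which for Gaussian predictions should cover roughly 95% of the true values if the uncertainty is well calibrated (the function name and interval width are illustrative choices):

    import numpy as np

    def coverage_at_2_sigma(y_true, mean, std):
        # Fraction of true values that fall inside mean +/- 2*std.
        # Much lower than ~95% suggests overconfidence; much higher suggests
        # the model is more uncertain than it needs to be.
        inside = np.abs(np.asarray(y_true) - np.asarray(mean)) <= 2 * np.asarray(std)
        return inside.mean()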

GPs are what we can term a well-principled method – this means that they have solid mathematical foundations, and thus come with strong theoretical guarantees. One of these guarantees is that they are well calibrated, and this is what makes GPs so popular: if we use GPs, we know we can rely on their uncertainty estimates.

Unfortunately, however, GPs are not without their shortcomings – we’ll learn more about these in the following section.

2.3.2 Limitations of Gaussian processes

Given that GPs are well-principled and capable of producing high-quality uncertainty estimates, you’d be forgiven for thinking they’re the perfect uncertainty-aware ML model. Unfortunately, GPs struggle in a few key situations:

  • High-dimensional data

  • Large amounts of data

  • Highly complex data

The first two points here are largely down to the inability of GPs to scale well. To understand this, we just need to look at the training and inference procedures for GPs. While it’s beyond the scope of this book to cover this in detail, the key point here is in the matrix operations required for GP training.

During training, it is necessary to invert an n × n matrix, where n is the number of points in our training set. Because of this, GP training quickly becomes computationally prohibitive. This can be somewhat alleviated through the use of Cholesky decomposition, rather than direct matrix inversion. As well as being more computationally efficient, Cholesky decomposition is also more numerically stable. Unfortunately, Cholesky decomposition also has its weaknesses: computationally, its complexity is still O(n³). This means that, as the size of our dataset increases, GP training becomes more and more expensive.
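
For reference, the standard Cholesky-based formulation of GP regression looks roughly like the following sketch (plain NumPy/SciPy, with illustrative argument names). The O(n³) cost sits in the factorization of the n × n training covariance matrix, but no explicit matrix inverse is ever formed:

    import numpy as np
    from scipy.linalg import cholesky, solve_triangular

    def gp_fit_predict(K, y, K_star, k_star_star, noise=1e-6):
        # K:           (n, n) covariance between the training points
        # K_star:      (m, n) covariance between test and training points
        # k_star_star: (m,)   prior variance at each test point
        n = K.shape[0]
        L = cholesky(K + noise * np.eye(n), lower=True)  # O(n^3)
        # Solve K alpha = y using two triangular solves instead of inverting K
        alpha = solve_triangular(L.T, solve_triangular(L, y, lower=True), lower=False)
        mean = K_star @ alpha                            # predictive mean
        v = solve_triangular(L, K_star.T, lower=True)
        var = k_star_star - np.sum(v ** 2, axis=0)       # predictive variance
        return mean, var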

But it’s not only training that’s affected: because we need to compute the covariance between a new data point and all observed data points at inference, GPs have an O(n²) computational complexity at inference.

As well as the computational cost, GPs aren’t light on memory: because we need to store our covariance matrix K, GPs have an O(n²) memory complexity. Thus, in the case of large datasets, even if we have the compute resources necessary to train them, it may not be practical to use them in real-world applications due to their memory requirements.
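
A quick back-of-the-envelope calculation makes this concrete, assuming a dense n × n covariance matrix of 64-bit floats:

    # Memory needed to store a dense n x n covariance matrix of float64 values
    for n in (1_000, 10_000, 100_000):
        gib = n * n * 8 / 1024 ** 3  # 8 bytes per float64
        print(f"n = {n:>7,}: ~{gib:,.2f} GiB")

At 100,000 training points, the covariance matrix alone already needs roughly 75 GiB.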

The last point in our list concerns the complexity of data. As you are probably aware – and as we’ll touch on in Chapter 3, Fundamentals of Deep Learning – one of the major advantages of DNNs is their ability to process complex, high-dimensional data through layers of non-linear transformations. While GPs are powerful, they’re also relatively simple models, and they’re not able to learn the kinds of powerful feature representations that are possible with DNNs.

All of these factors mean that, while GPs are an excellent choice for relatively low-dimensional data and reasonably small datasets, they aren’t practical for many of the complex problems we face in ML. And so, we turn to Bayesian Deep Learning (BDL) methods: methods that have the flexibility and scalability of deep learning, while also producing model uncertainty estimates.