Why Simple Models Generalise Better
An exploration of model complexity and generalisation bounds
With the advent of neural networks and GPUs, there has been an implicit race to build increasingly complex network architectures, some with billions of trainable weights. However, students of statistical learning theory will see the flaw in this approach: increasing model complexity results in a weaker generalisation bound.
Model Complexity
Let’s break that statement down. Firstly, what do we mean by model complexity? Intuitively, a model is complex when it has more trainable weights, because the combinations of those weights can produce a wider variety of functions. Restricting ourselves to polynomials, we could say that the complexity of a model is the degree of the polynomial. However, once we consider arbitrary functions, the notion of complexity becomes harder to quantify. The issue with using the number of weights as a proxy is that it does not account for the way the weights affect the function itself. In other words, consider two functions that each have two trainable weights.
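One illustrative pair (the exact functional forms here are a choice of convenience, not essential to the argument) is

$$
f_1(x) \;=\; w_1 x + w_2,
\qquad
f_2(x) \;=\; \sin(w_1 x + w_2).
$$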
Even though both functions only have two trainable weights, f1 is only able to capture linear patterns, while f2 is able to capture nonlinear patterns (especially when the data is normalised to fit in the interval [0, 1]). This points us to define model complexity in a more general way.
In the statistical learning theory literature, there exist several notions of model complexity: Rademacher complexity, Gaussian complexity, VC dimension, Fisher information, to name a few. However, for the purpose of this article, we will focus on Rademacher complexity, which I feel is the most intuitive definition of model complexity.
Rademacher Complexity
The empirical Rademacher complexity, for a discriminative binary classifier, f(x), is defined as follows
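In the standard notation, writing the sample as S = {x_1, …, x_m} and letting F denote the function class that f belongs to,

$$
\hat{\mathfrak{R}}_S(\mathcal{F})
\;=\;
\mathbb{E}_{\sigma}\!\left[\, \sup_{f \in \mathcal{F}} \frac{1}{m} \sum_{i=1}^{m} \sigma_i\, f(x_i) \right],
$$

where σ_1, …, σ_m are independent Rademacher variables, each taking the values +1 and -1 with probability 1/2; they play the role of the random noise labels that the class is trying to match.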
Note that we can extend this definition to multi-class classifiers and to regression models by changing the number of possible values and the distribution of the noise variables, respectively.
The intuition behind Rademacher complexity is as follows: the complexity of a function is given by the complexity of the function class it belongs to, which in turn is measured by the ability of the functions in that class to fit noise. In particular, it measures how well the best function in the class can correlate with a random labelling of the sample. Let’s take a minute to understand why the ability to fit noise is a good measure of the complexity of a model. Another way to describe model complexity is as the expressivity of a model: how complex a pattern can the model express? This translates directly into how much noise the model can effectively capture.
Note that although we have an easy-to-understand definition of model complexity, computing the actual empirical Rademacher complexity of an arbitrary model is very hard, because we have to optimise over all functions in the function class. Another remark concerns choosing the function class. The astute reader will have noticed that we can choose a function class as complex as we want, as long as it contains the function we are considering. Which function class should we choose, then? The short answer is that we typically want the function class with the smallest Rademacher complexity that still contains the function we are considering. The reasoning behind this will become clearer when we consider how the complexity of a function relates to its errors.
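To make the definition a little more concrete, here is a minimal numerical sketch for one of the few cases where the supremum can be evaluated in closed form: linear functions x ↦ ⟨w, x⟩ with a bounded weight norm ‖w‖ ≤ B (the function name, the choice of B, and the toy data below are purely illustrative). By the Cauchy-Schwarz inequality, the supremum for a fixed noise vector σ equals (B/m)·‖Σ_i σ_i x_i‖, so a Monte Carlo average over draws of σ estimates the empirical Rademacher complexity without any explicit optimisation. For richer classes such as neural networks no such closed form exists, which is exactly why the quantity is hard to compute in general.

```python
import numpy as np

def empirical_rademacher_linear(X, B=1.0, n_draws=2000, seed=0):
    """Monte Carlo estimate of the empirical Rademacher complexity of the
    class {x -> <w, x> : ||w||_2 <= B} on a fixed sample X of shape (m, d).

    For this class, Cauchy-Schwarz gives the supremum in closed form:
        sup_{||w|| <= B} (1/m) * sum_i sigma_i * <w, x_i>
            = (B/m) * || sum_i sigma_i * x_i ||_2
    so we only need to average that norm over random sign vectors sigma.
    """
    rng = np.random.default_rng(seed)
    m = X.shape[0]
    sups = np.empty(n_draws)
    for k in range(n_draws):
        sigma = rng.choice([-1.0, 1.0], size=m)   # Rademacher noise labels
        sups[k] = (B / m) * np.linalg.norm(sigma @ X)
    return sups.mean()

# Toy check: for a fixed class, the complexity shrinks roughly like 1/sqrt(m)
# as the sample grows, which is what drives the generalisation bound below.
rng = np.random.default_rng(1)
for m in (50, 500, 5000):
    X = rng.normal(size=(m, 5))
    print(m, round(empirical_rademacher_linear(X), 4))
```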
Generalisation Bounds
Having defined model complexity, we go back to our initial statement. The generalisation bound is an inequality that allows us to find an upper bound on the generalisation error of a model in terms of quantities that are easier to compute. This is useful because
Generalisation error is typically not computable. We can have proxies such as error on a held-out test dataset, but several papers point to this being inadequate.
Upper bounding the generalisation error allows us to quantify worst-case performance. This is particularly useful in situations where errors carry a high cost, because the bound lets us estimate the cost we would incur in the worst-case scenarios.
Before we are able to derive the generalisation bound, we need some important mathematical machinery. Deriving generalisation bounds hinges on the Hoeffding Inequality and McDiarmid Inequality.
Hoeffding Inequality
Given some independent, bounded random variables and their sum
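$$
X_1, \dots, X_n \ \text{independent, with } X_i \in [a_i, b_i],
\qquad
S_n \;=\; \sum_{i=1}^{n} X_i,
$$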
Hoeffding inequality states that
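$$
\mathbb{P}\big(\,\lvert S_n - \mathbb{E}[S_n] \rvert \ge t\,\big)
\;\le\;
2\exp\!\left(-\,\frac{2t^2}{\sum_{i=1}^{n}(b_i - a_i)^2}\right)
\qquad \text{for every } t > 0,
$$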
which bounds the probability that the sum of the random variables deviates from its expected value by more than t, with the squared deviation weighted by the sum of the squared ranges of the variables. Intuitively, this means that the probability of the sum deviating from its mean by a given amount decays exponentially in the square of that amount.
McDiarmid Inequality
Similar to Hoeffding inequality, McDiarmid inequality extends the idea from sums to any function of a set of independent random variables, provided that changing a single variable changes the function’s value by a bounded amount. Defining c as that per-coordinate bound (for a sum of variables each lying in [a, b], this is simply the range c = b - a), we get the following
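$$
\mathbb{P}\Big(\,\big\lvert \varphi(X_1, \dots, X_n) - \mathbb{E}\big[\varphi(X_1, \dots, X_n)\big] \big\rvert \ge t\,\Big)
\;\le\;
2\exp\!\left(-\,\frac{2t^2}{n\,c^2}\right)
\qquad \text{for every } t > 0,
$$

for any function φ of the n independent variables whose value changes by at most c when any single argument is changed (taking φ to be the sum S_n recovers Hoeffding inequality with a common range c = b - a).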
Now we are in a good position to derive the generalisation bound.
Deriving the Generalisation Bound
Defining our sample loss and generalisation loss for a model f in terms of some arbitrary loss function L,
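$$
\hat{L}_S(f) \;=\; \frac{1}{m}\sum_{i=1}^{m} L\big(f(x_i),\, y_i\big),
\qquad
L_D(f) \;=\; \mathbb{E}_{(x,y)\sim D}\Big[L\big(f(x),\, y\big)\Big],
$$

where S = {(x_1, y_1), …, (x_m, y_m)} is a sample drawn i.i.d. from the data distribution D,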
we want an expression of the form
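$$
\mathbb{P}\Big(\, L_D(f) \;\le\; \hat{L}_S(f) + \varepsilon(\mathcal{F}, m, \delta) \ \ \text{for every } f \in \mathcal{F} \,\Big) \;\ge\; 1 - \delta,
$$

i.e. (schematically) a bound that holds simultaneously for every model in the class F, with probability at least 1 - δ over the draw of the sample, where the slack term ε collects the complexity-dependent quantities.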
In order to form such an expression, we notice that, directly from the definition of a supremum, for a fixed function g(x, y)
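$$
\mathbb{E}\big[g(x, y)\big] - \hat{\mathbb{E}}_S\big[g(x, y)\big]
\;\le\;
\sup_{h \in \mathcal{G}} \Big( \mathbb{E}\big[h(x, y)\big] - \hat{\mathbb{E}}_S\big[h(x, y)\big] \Big),
$$

where Ê_S denotes the empirical average over the sample S and G can be any function class containing g.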
This already looks similar to our desired expression. We need to do two things now: convert the fixed function into our loss function, and express the supremum in terms of the Rademacher complexity. Since the supremum term is a random variable that depends on the draw of the sample S, if we let
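$$
\Phi(S) \;=\; \sup_{h \in \mathcal{G}} \Big( \mathbb{E}\big[h(x, y)\big] - \hat{\mathbb{E}}_S\big[h(x, y)\big] \Big)
$$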
and assume, without loss of generality, that every function h maps into an interval [a, a + 1] (i.e. our loss function has range of width 1, as is typically the case), then we can show that
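$$
\big\lvert \Phi(S) - \Phi(S') \big\rvert \;\le\; \frac{1}{m}
\qquad \text{whenever } S \text{ and } S' \text{ differ in a single example,}
$$

since each h has range of width 1, so swapping one of the m sample points changes the empirical average Ê_S[h] by at most 1/m, and a difference of suprema is bounded by the supremum of the differences.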
This is exactly the bounded-difference condition that McDiarmid inequality requires, with c = 1/m. Applying it, together with the standard symmetrisation argument that bounds the expectation of Φ(S) by twice the Rademacher complexity of the function class, we obtain the following
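$$
\mathbb{E}\big[g(x, y)\big]
\;\le\;
\hat{\mathbb{E}}_S\big[g(x, y)\big] \;+\; 2\,\mathfrak{R}_m(\mathcal{G}) \;+\; \sqrt{\frac{\ln(1/\delta)}{2m}}
\qquad \text{for every } g \in \mathcal{G},
$$

which holds with probability at least 1 - δ over the draw of S; here ℜ_m(G) = E_S[ℜ̂_S(G)] denotes the (expected) Rademacher complexity of the class, the expectation over samples of the empirical complexity defined earlier.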
This is promising: we have managed to bound the expectation of a function over the data distribution by its sample estimate plus the Rademacher complexity of the function class. To get what we desire, we need to replace the generic function g with our loss function. A critical step in this is relating the Rademacher complexity of the loss-function class to the Rademacher complexity of the model’s function class. With a little thought, we see that the two are related rather simply, because of the way our loss function is constructed for binary classification.
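For labels y in {-1, +1} and the zero-one loss, L(f(x), y) = 1_{f(x) ≠ y} = (1 - y f(x)) / 2; because the Rademacher variables are symmetric in sign, multiplying them by y does not change their distribution, so the two complexities differ only by the factor of one half:

$$
\mathfrak{R}_m(\mathcal{G}) \;=\; \tfrac{1}{2}\,\mathfrak{R}_m(\mathcal{F}).
$$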
Putting these all together, we obtain that
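$$
L_D(f) \;\le\; \hat{L}_S(f) \;+\; \mathfrak{R}_m(\mathcal{F}) \;+\; \sqrt{\frac{\ln(1/\delta)}{2m}}
\qquad \text{for every } f \in \mathcal{F},
$$

with probability at least 1 - δ over the draw of the sample (the factor of 2 from the previous bound and the factor of 1/2 from the loss-class relation cancel).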
This is the final form of our generalisation bound.
Interpreting the Bound
Given what we have derived, a natural step is to go back to the beginning and ask ourselves how everything fits together. We see that the generalisation loss of any model is bounded by its sample loss plus a term driven by its Rademacher complexity. This means that if we use a very complex model, we may achieve a very low sample error, yet we remain uncertain about how the model will perform on data it has not seen before, because the complexity term keeps the bound loose. Another way to interpret the bound is in terms of overfitting: when we use very complicated models, we are essentially allowing the model to overfit the sample data, which yields a low loss on the sample but does not help the model generalise to unseen data. Both interpretations lead to the same conclusion: the complex model is not always the best choice. I typically phrase it the opposite way: the simplest model usually performs best on real-world datasets.
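To get a rough sense of scale (with numbers chosen purely for illustration): with m = 10,000 samples and confidence level δ = 0.05, the confidence term is

$$
\sqrt{\frac{\ln(1/0.05)}{2 \cdot 10000}} \;\approx\; 0.012,
$$

so for any reasonably rich model class the gap between sample loss and generalisation loss is typically dominated by the Rademacher complexity term rather than by the confidence term.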
Please give your thoughts and questions in the comments, and keep an eye out for similar articles in the future!