The Regularization Cookbook
In this recipe, we will explain what hyperparameter optimization is, along with some related concepts: the definition of a hyperparameter, cross-validation, and various hyperparameter optimization methods. We will then perform a grid search to optimize the hyperparameters of a logistic regression model on the Titanic dataset.
Most of the time, in ML, we do not simply train a model on the training set and evaluate it against the test set.
This is because, like most other algorithms, ML algorithms can be fine-tuned. This fine-tuning process allows us to optimize the hyperparameters to achieve the best possible results, and it sometimes gives us leverage to regularize a model.
Note
In ML, hyperparameters are set by humans before training, unlike parameters, which are learned by the model during training and thus can't be set directly.
To properly optimize hyperparameters, a third split has to be introduced: the validation set.
This means there are now three splits:

- The training set, used to train the model
- The validation set, used to optimize the hyperparameters
- The test set, used for the final evaluation
You could create such a set by splitting X_train into X_train and X_valid with the train_test_split() function from scikit-learn.
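For illustration, here is a minimal sketch of such a split (the array shapes are made up; in the recipe, X_train and y_train would come from the Titanic data):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy arrays standing in for the Titanic training features and labels
X_train = np.arange(100).reshape(50, 2)
y_train = np.arange(50) % 2

# Carve a validation set (20% here) out of the training set
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train, y_train, test_size=0.2, random_state=0
)
print(X_train.shape, X_valid.shape)  # → (40, 2) (10, 2)
```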
But in practice, most people just use cross-validation and do not bother creating this validation set. The k-fold cross-validation method allows us to make k splits out of the training set and divide it, as presented in Figure 2.8:
Figure 2.8 – Typical split between training, validation, and test sets, without cross-validation (top) and with cross-validation (bottom)
In doing so, not just one model is trained but k models, for a given set of hyperparameters. The performances are then averaged over those k models, based on a chosen metric (for example, accuracy or MSE).
Several sets of hyperparameters can then be tested, and the one that shows the best average performance is selected. After selecting the best hyperparameter set, the model is trained one more time on the entire training set to make the most of the available data.
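As a sketch of this loop, scikit-learn's cross_val_score trains k models and returns their k scores, which can then be averaged (synthetic data is used here in place of the Titanic set):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# Synthetic data standing in for the Titanic training set
X, y = make_classification(n_samples=200, random_state=0)

# 5-fold cross-validation: 5 models are trained, 5 accuracy scores returned
scores = cross_val_score(LogisticRegression(), X, y, cv=5, scoring='accuracy')

# The averaged performance for this hyperparameter set
print(scores.mean())
```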
Finally, you can implement several strategies to optimize the hyperparameters, such as the following:

- Grid search, which exhaustively tries every combination in a predefined grid of values
- Random search, which samples combinations at random from given distributions
- Bayesian optimization, which uses previous evaluations to choose the next combination to try
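As a sketch of one such alternative strategy, random search can be run with scikit-learn's RandomizedSearchCV, sampling C from a log-uniform distribution (the data here is synthetic, standing in for the Titanic set):

```python
from scipy.stats import loguniform
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV

# Synthetic data standing in for the Titanic training set
X, y = make_classification(n_samples=200, random_state=0)

# Random search: sample 10 values of C instead of testing a fixed grid
search = RandomizedSearchCV(
    LogisticRegression(),
    param_distributions={'C': loguniform(1e-3, 1e1)},
    n_iter=10,
    cv=5,
    scoring='accuracy',
    random_state=0,
)
search.fit(X, y)
print(search.best_params_)
```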
While rather complicated to explain conceptually, hyperparameter optimization with cross-validation is super easy to implement. In this recipe, we'll assume that we want to optimize a logistic regression model to predict whether a passenger would have survived:
First, import the GridSearchCV class from sklearn.model_selection. We want to test three values of C: [0.01, 0.03, 0.1], so we must define a parameter grid with the hyperparameter name as the key and the list of values to test as the value. The C hyperparameter is the inverse of the penalization strength: the higher C is, the lower the regularization. See the next chapter for more details:
# Define the hyperparameters we want to test
param_grid = { 'C': [0.01, 0.03, 0.1] }
Next, instantiate the GridSearchCV object and provide the following arguments:

- A LogisticRegression instance
- param_grid, which we defined previously
- scoring='accuracy', the metric to optimize
- cv=5, the number of cross-validation folds
- return_train_score set to True, to get some useful information we can use later:

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Instantiate the grid search object
grid = GridSearchCV(
LogisticRegression(),
param_grid,
scoring='accuracy',
cv=5,
return_train_score=True
)
# Fit and wait
grid.fit(X_train, y_train)
GridSearchCV(cv=5, estimator=LogisticRegression(),
             param_grid={'C': [0.01, 0.03, 0.1]},
             return_train_score=True, scoring='accuracy')
Note
Depending on the input dataset and the number of tested hyperparameters, the fit may take some time.
Once the fit has been completed, you can get a lot of useful information, such as the following:
- The best hyperparameters, via the .best_params_ attribute
- The corresponding best score, via the .best_score_ attribute
- The full cross-validation results, via the .cv_results_ attribute

You can also compute predictions on the test set directly with the .predict() method:

y_pred = grid.predict(X_test)
print('Hyperparameter optimized accuracy:',
    accuracy_score(y_test, y_pred))
This provides the following output:
Hyperparameter optimized accuracy: 0.781229050279329
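As a self-contained sketch of how those attributes can be inspected after fitting (using synthetic data in place of the Titanic set):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Synthetic data standing in for the Titanic training set
X, y = make_classification(n_samples=200, random_state=0)

grid = GridSearchCV(
    LogisticRegression(),
    {'C': [0.01, 0.03, 0.1]},
    scoring='accuracy',
    cv=5,
    return_train_score=True,
)
grid.fit(X, y)

print(grid.best_params_)                     # best C of the three tested
print(grid.best_score_)                      # mean cross-validated accuracy
print(grid.cv_results_['mean_train_score'])  # present because return_train_score=True
```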
Thanks to the tools provided by scikit-learn, it is fairly easy to have a well-optimized model and evaluate it against several metrics. In the next recipe, we’ll learn how to diagnose bias and variance based on such an evaluation.
See also
The documentation for GridSearchCV can be found at https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html.