---
title: "Resampling Approaches to Model Validation"
output:
html_document: default
pdf_document: default
---
```{r libs, include=FALSE}
library(tidymodels)
library(tidyverse)
library(tidyr)
tidymodels_prefer()
library(doMC)
registerDoMC(cores = parallel::detectCores())
source("ames_snippets.R")
```
# Model Performance
Model performance can be measured by splitting a dataset into training and test sets, building the model on the training set, and then using the test set to obtain an unbiased estimate of performance. However, we usually need to understand the performance of a model, or even multiple models, _before using the test set_, because we build multiple models as part of our process of finding the best model.
# Resampling Methods
Resampling methods are empirical simulation systems that emulate the process of using some data for modeling and different data for evaluation. Most resampling methods are iterative, meaning that this process is repeated multiple times.
Resampling is only conducted on the training set. The test set is not involved. For each iteration of resampling, the data are partitioned into two subsamples:
* The model is fit with the *analysis set*.
* The model is evaluated with the *assessment set*.
These two subsamples are somewhat analogous to training and test sets. Our language of _analysis_ and _assessment_ avoids confusion with the initial split of the data. These data sets are mutually exclusive. The partitioning scheme used to create the analysis and assessment sets is usually the defining characteristic of the method.
Suppose twenty iterations of resampling are conducted. This means that twenty separate models are fit on the analysis sets and the corresponding assessment sets produce twenty sets of performance statistics. The final estimate of performance for a model is the average of the twenty replicates of the statistics. This average tends to generalize well and is usually far more stable than an estimate computed from a single held-out partition of the training data.
We examine several commonly used resampling methods and discuss their pros and cons.
# Cross-validation
Cross-validation is a well-established resampling method. While there are a number of variations, the most common cross-validation method is _V_-fold cross-validation. The data are randomly partitioned into _V_ sets of roughly equal size (called the "folds").
For 3-fold cross-validation, one fold is held out for assessment statistics and the remaining two folds are used to fit the model. This process continues for each fold so that three models produce three sets of performance statistics. When _V_ = 3, the analysis sets are 2/3 of the training set and each assessment set is a distinct 1/3. The final resampling estimate of performance averages each of the _V_ replicates.
Using _V_ = 3 is a good choice to illustrate cross-validation but is a poor choice in practice because it is too low to generate reliable estimates. In practice, values of _V_ are most often 5 or 10; we generally prefer 10-fold cross-validation as a default because it is large enough for good results in most situations.
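To make the partitioning concrete, here is a minimal base R sketch of how fold labels could be assigned by hand. The toy size `n_rows = 12` and all object names are illustrative assumptions; in practice `vfold_cv()` handles this bookkeeping and may differ in its details.

```{r sketch-vfold-manual}
# Hypothetical toy example: 12 rows split into V = 3 folds
set.seed(123)
n_rows <- 12
V <- 3

# Give each row a fold label, then shuffle the labels
fold_id <- sample(rep(seq_len(V), length.out = n_rows))

# For fold v, the assessment set is the rows labelled v;
# the analysis set is every other row
manual_folds <- lapply(seq_len(V), function(v) {
  list(
    analysis   = which(fold_id != v),
    assessment = which(fold_id == v)
  )
})

# Each assessment set holds 1/3 of the rows; each analysis set holds 2/3
sapply(manual_folds, function(f) length(f$assessment))
sapply(manual_folds, function(f) length(f$analysis))
```

Every row lands in exactly one assessment set, which is why the assessment sets in _V_-fold cross-validation are mutually exclusive.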
The primary input to `vfold_cv()` is the training set data frame, along with the number of folds (defaulting to 10):
```{r resampling-ames-cv}
set.seed(1001)
ames_folds <- vfold_cv(ames_train, v = 10)
ames_folds
```
The column named `splits` contains the information on how to split the data (similar to the object used to create the initial training/test partition). While each row of `splits` has an embedded copy of the entire training set, R is smart enough not to make copies of the data in memory. The print method inside of the tibble shows the frequency of each: `[2108/234]` indicates that over two thousand samples are in the analysis set and 234 are in that particular assessment set. These objects also always contain a character column called `id` that labels the partition.
To manually retrieve the partitioned data, the `analysis()` and `assessment()` functions return the corresponding data frames:
```{r resampling-analysis}
# For the first fold:
ames_folds$splits[[1]] %>% analysis() %>% dim()
```
The `r pkg(tidymodels)` packages, such as [`r pkg(tune)`](https://tune.tidymodels.org/), contain high-level user interfaces so that functions like `analysis()` are not generally needed for day-to-day work.
There are a variety of variations on cross-validation; we'll go through the most important ones.
## Repeated cross-validation
The most important variation on cross-validation is repeated _V_-fold cross-validation. Depending on the size or other characteristics of the data, the resampling estimate produced by _V_-fold cross-validation may be excessively noisy. As with many statistical problems, one way to reduce noise is to gather more data. For cross-validation, this means averaging more than _V_ statistics.
To create _R_ repeats of _V_-fold cross-validation, the same fold generation process is done _R_ times to generate _R_ collections of _V_ partitions. Now, instead of averaging _V_ statistics, $V \times R$ statistics produce the final resampling estimate. Due to the Central Limit Theorem, the summary statistics from each model tend toward a normal distribution, as long as we have a lot of data relative to $V \times R$.
Consider the Ames data. On average, 10-fold cross-validation uses assessment sets that contain roughly `r floor(nrow(ames_train) * .1)` properties. If RMSE is the statistic of choice, we can denote that estimate's standard deviation as $\sigma$. With simple 10-fold cross-validation, the standard error of the mean RMSE is $\sigma/\sqrt{10}$. If this is too noisy, repeats reduce the standard error to $\sigma/\sqrt{10R}$. For 10-fold cross-validation with $R$ replicates, the plot in Figure \@ref(fig:variance-reduction) shows how quickly the standard error decreases with replicates.
```{r variance-reduction}
y_lab <- expression(Multiplier ~ on ~ sigma)
cv_info <-
  tibble(replicates = 1:10, V = 10) %>%
  # The standard error of a mean of B statistics is sigma / sqrt(B)
  mutate(B = V * replicates, reduction = 1 / sqrt(B), V = format(V))

ggplot(cv_info, aes(x = replicates, y = reduction)) +
  geom_line() +
  geom_point() +
  labs(
    y = y_lab,
    x = "Number of 10F-CV Replicates"
  ) +
  theme_bw() +
  scale_x_continuous(breaks = 1:10)
```
Each additional replicate has less impact on the standard error than the one before it. However, if the baseline value of $\sigma$ is impractically large, the diminishing returns on replication may still be worth the extra computational costs.
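As a quick numeric check of the $\sigma/\sqrt{10R}$ formula, the following sketch computes the multiplier on $\sigma$ for a few values of $R$. The chosen replicate counts are arbitrary illustrations.

```{r sketch-se-multiplier}
V <- 10
R <- c(1, 2, 5, 10)

# Standard error of the mean of V * R statistics, as a multiple of sigma
se_multiplier <- 1 / sqrt(V * R)
round(se_multiplier, 3)
```

Going from one to two repeats cuts the multiplier from about 0.316 to about 0.224, while going from five to ten repeats only moves it from about 0.141 to 0.1.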
To create repeats, invoke `vfold_cv()` with an additional argument `repeats`:
```{r resampling-repeated}
vfold_cv(ames_train, v = 10, repeats = 5)
```
## Leave-one-out cross-validation
One variation of cross-validation is leave-one-out (LOO) cross-validation where _V_ is the number of data points in the training set. If there are $n$ training set samples, $n$ models are fit using $n-1$ rows of the training set. Each model predicts the single excluded data point. At the end of resampling, the $n$ predictions are pooled to produce a single performance statistic.
Leave-one-out methods are deficient compared to almost any other method. For anything but pathologically small samples, LOO is computationally excessive and it may not have good statistical properties. Although the `r pkg(rsample)` package contains a `loo_cv()` function, these objects are not generally integrated into the broader tidymodels frameworks.
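For illustration only, here is a base R sketch of the LOO procedure described above. It uses the built-in `mtcars` data and a simple linear model as stand-ins (both are assumptions, not part of the Ames analysis), fits $n$ models on $n-1$ rows each, and pools the $n$ held-out predictions into a single RMSE.

```{r sketch-loo-manual}
n <- nrow(mtcars)

# Fit n models, each leaving out one row, and predict the excluded row
loo_pred <- vapply(seq_len(n), function(i) {
  fit <- lm(mpg ~ wt, data = mtcars[-i, ])
  predict(fit, newdata = mtcars[i, , drop = FALSE])
}, numeric(1))

# Pool the n predictions into one performance statistic
loo_rmse <- sqrt(mean((mtcars$mpg - loo_pred)^2))
loo_rmse
```

Even on this tiny data set the loop fits 32 models, which hints at why LOO becomes computationally excessive as $n$ grows.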
## Monte Carlo cross-validation
Another variant of _V_-fold cross-validation is Monte Carlo cross-validation (MCCV, @xu2001monte). Like _V_-fold cross-validation, it allocates a fixed proportion of data to the assessment sets. The difference between MCCV and regular cross-validation is that, for MCCV, this proportion of the data is randomly selected each time. This results in assessment sets that are not mutually exclusive. To create these resampling objects:
```{r resampling-mccv}
mc_cv(ames_train, prop = 9/10, times = 20)
```
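The overlap between MCCV assessment sets can be seen in a small base R sketch. The toy size of 100 rows and the object names are assumptions for illustration; this mimics the idea behind `mc_cv()` rather than its implementation.

```{r sketch-mccv-overlap}
set.seed(456)
n_rows <- 100
times <- 20
prop <- 9 / 10

# Number of rows held out for assessment in each resample
n_assess <- round(n_rows * (1 - prop))

# Draw a fresh random assessment set for every resample
assessment_rows <- replicate(times, sample(n_rows, n_assess), simplify = FALSE)

# Count how often each row was used for assessment
assess_counts <- table(factor(unlist(assessment_rows), levels = seq_len(n_rows)))
table(assess_counts)
```

Because 20 resamples of 10 rows are drawn from only 100 rows, some rows necessarily appear in more than one assessment set, and others may never be assessed at all.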
# Validation sets
A validation set is a single partition that is set aside to estimate performance separate from the test set. When using a validation set, the initial available data set is split into a training set, a validation set, and a test set.
Validation sets are often used when the original pool of data is very large. In this case, a single large partition may be adequate to characterize model performance without having to do multiple iterations of resampling.
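A three-way split can be sketched in base R as follows. The 60/20/20 proportions and the row count are hypothetical choices made only for this illustration.

```{r sketch-three-way-split}
set.seed(202)
n_rows <- 1000
shuffled <- sample(n_rows)

# Hypothetical proportions: 60% training, 20% validation, 20% test
train_rows <- shuffled[1:600]
val_rows   <- shuffled[601:800]
test_rows  <- shuffled[801:1000]

# The three partitions do not share any rows
length(intersect(train_rows, val_rows))
length(intersect(val_rows, test_rows))
```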
With the `r pkg(rsample)` package, a validation set is like any other resampling object; this type is different only in that it has a single iteration. To create a validation set object that uses 3/4 of the data for model fitting:
```{r resampling-validation-split}
set.seed(1002)
val_set <- validation_split(ames_train, prop = 3/4)
val_set
```
# Bootstrapping
Bootstrap resampling was originally invented as a method for approximating the sampling distribution of statistics whose theoretical properties are intractable [@davison1997bootstrap]. Using it to estimate model performance is a secondary application of the method.
A bootstrap sample of the training set is a sample that is the same size as the training set but is drawn _with replacement_. This means that some training set data points are selected multiple times for the analysis set. Each data point has a `r round((1-exp(-1)) * 100, 1)`% chance of inclusion in the analysis set at least once. The assessment set contains all of the training set samples that were not selected for the analysis set (on average, `r round((exp(-1)) * 100, 1)`% of the training set). When bootstrapping, the assessment set is often called the "out-of-bag" sample. Note that the sizes of the assessment sets vary.
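The inclusion probability can be checked directly: with $n$ rows, a given row is missed by all $n$ draws with probability $(1 - 1/n)^n$, which approaches $e^{-1}$ as $n$ grows. The sketch below uses a hypothetical $n = 1000$ and also draws one bootstrap sample to show the out-of-bag rows.

```{r sketch-boot-inclusion}
n_rows <- 1000

# Probability that a given row appears at least once in a bootstrap sample
p_included <- 1 - (1 - 1 / n_rows)^n_rows
p_included

# One bootstrap sample: same size as the data, drawn with replacement
set.seed(789)
boot_rows <- sample(n_rows, n_rows, replace = TRUE)

# The out-of-bag (assessment) rows are those never selected
oob_rows <- setdiff(seq_len(n_rows), boot_rows)
length(oob_rows) / n_rows
```

The out-of-bag fraction varies from sample to sample, which is why the assessment set sizes are not constant.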
Using the `r pkg(rsample)` package, we can create such bootstrap resamples:
```{r resampling-boot-set}
bootstraps(ames_train, times = 5)
```
Bootstrap samples produce performance estimates that have very low variance (unlike cross-validation) but have significant pessimistic bias. This means that, if the true accuracy of a model is 90%, the bootstrap would tend to estimate the value to be less than 90%. The amount of bias cannot be empirically determined with sufficient accuracy. Additionally, the amount of bias changes over the scale of the performance metric. For example, the bias is likely to be different when the accuracy is 90% versus when it is 70%.
The bootstrap is also used inside of many models. For example, the random forest model mentioned earlier contained 1,000 individual decision trees. Each tree was the product of a different bootstrap sample of the training set.
# Rolling forecasting origin resampling
When the data have a strong time component, a resampling method should support modeling to estimate seasonal and other temporal trends within the data. A technique that randomly samples values from the training set can disrupt the model's ability to estimate these patterns.
Rolling forecast origin resampling [@hyndman2018forecasting] provides a method that emulates how time series data is often partitioned in practice, estimating the model with historical data and evaluating it with the most recent data. For this type of resampling, the sizes of the initial analysis and assessment sets are specified. The first iteration of resampling uses these sizes, starting from the beginning of the series. The second iteration uses the same data sizes but shifts over by a set number of samples.
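The indexing scheme can be sketched in base R. The series length, window sizes, and shift below are hypothetical values chosen for illustration, and the object names are not part of any package.

```{r sketch-rolling-origin}
# Hypothetical series of 10 time-ordered rows
n_rows <- 10
n_analysis <- 6    # size of each analysis set
n_assess <- 2      # size of each assessment set
shift_by <- 1      # number of rows to move forward per iteration

# First row of the analysis window for each resample
starts <- seq(1, n_rows - n_analysis - n_assess + 1, by = shift_by)

rolling_splits <- lapply(starts, function(s) {
  list(
    analysis   = s:(s + n_analysis - 1),
    assessment = (s + n_analysis):(s + n_analysis + n_assess - 1)
  )
})

# First resample: rows 1-6 for analysis, rows 7-8 for assessment
rolling_splits[[1]]
```

In every resample the assessment rows come strictly after the analysis rows, so the model is never evaluated on data older than what it was trained on.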
# Estimating Performance
Any of the resampling methods discussed in this chapter can be used to evaluate the modeling process (including preprocessing, model fitting, etc.). These methods are effective because different groups of data are used to train the model and assess the model. To reiterate, the process to use resampling is as follows:
1. During resampling, the analysis set is used to preprocess the data, apply the preprocessing to itself, and use these processed data to fit the model.
2. The preprocessing statistics produced by the analysis set are applied to the assessment set. The predictions from the assessment set estimate performance on new data.
This sequence repeats for every resample. If there are _B_ resamples, there are _B_ replicates of each of the performance metrics. The final resampling estimate is the average of these _B_ statistics. If _B_ = 1, as with a validation set, the individual statistics represent overall performance.
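The two steps above can be sketched by hand in base R. This is only an illustration of the sequence, using the built-in `mtcars` data, a simple standardization step, and a linear model as stand-ins; `fit_resamples()` automates the same pattern for workflows.

```{r sketch-manual-resampling}
set.seed(101)
V <- 5
fold_id <- sample(rep(seq_len(V), length.out = nrow(mtcars)))

rmse_by_fold <- vapply(seq_len(V), function(v) {
  analysis   <- mtcars[fold_id != v, ]
  assessment <- mtcars[fold_id == v, ]

  # Step 1: estimate preprocessing statistics on the analysis set only,
  # apply them to the analysis set, and fit the model
  wt_mean <- mean(analysis$wt)
  wt_sd   <- sd(analysis$wt)
  analysis$wt_std <- (analysis$wt - wt_mean) / wt_sd
  fit <- lm(mpg ~ wt_std, data = analysis)

  # Step 2: apply the same statistics to the assessment set and predict
  assessment$wt_std <- (assessment$wt - wt_mean) / wt_sd
  pred <- predict(fit, newdata = assessment)
  sqrt(mean((assessment$mpg - pred)^2))
}, numeric(1))

# The resampling estimate is the average of the B = V statistics
resampling_estimate <- mean(rmse_by_fold)
std_err <- sd(rmse_by_fold) / sqrt(V)
c(estimate = resampling_estimate, std_err = std_err)
```

Note that the mean and standard deviation used for standardization are never computed from the assessment rows, which keeps the assessment set a fair proxy for new data.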
Let's reconsider the random forest model for the Ames housing dataset contained in the `rf_wflow` object. The `fit_resamples()` function is analogous to `fit()`, but instead of having a `data` argument, `fit_resamples()` has `resamples` which expects an `rset` object like the ones shown in this chapter. The possible interfaces to the function are:
```{r resampling-usage, eval = FALSE}
model_spec %>% fit_resamples(formula, resamples, ...)
model_spec %>% fit_resamples(recipe, resamples, ...)
workflow %>% fit_resamples(resamples, ...)
```
There are a number of other optional arguments, such as:
* `metrics`: A metric set of performance statistics to compute. By default, regression models use RMSE and R<sup>2</sup> while classification models compute the area under the ROC curve and overall accuracy. Note that this choice also defines what predictions are produced during the evaluation of the model. For classification, if only accuracy is requested, class probability estimates are not generated for the assessment set (since they are not needed).
* `control`: A list created by `control_resamples()` with various options.
The control arguments include:
* `verbose`: A logical for printing logging.
* `extract`: A function for retaining objects from each model iteration (discussed later in this chapter).
* `save_pred`: A logical for saving the assessment set predictions.
For our example, we'll use the Ames housing dataset and the `ames_folds` variable that we created for 10-fold cross-validation. We'll save the predictions in order to visualize the model fit and residuals:
```{r resampling-cv-ames}
keep_pred <- control_resamples(save_pred = TRUE, save_workflow = TRUE)
set.seed(1003)
rf_res <-
  rf_wflow %>%
  fit_resamples(resamples = ames_folds, control = keep_pred)
rf_res
```
The return value is a tibble similar to the input resamples, along with some extra columns:
* `.metrics` is a list column of tibbles containing the assessment set performance statistics.
* `.notes` is another list column of tibbles cataloging any warnings or errors generated during resampling. Note that errors will not stop subsequent execution of resampling.
* `.predictions` is present when `save_pred = TRUE`. This list column contains tibbles with the out-of-sample predictions.
While these list columns may look daunting, they can be easily reconfigured using `r pkg(tidyr)` or with convenience functions that tidymodels provides. For example, the `collect_metrics()` function returns the performance metrics in a more usable format: the mean of each metric, with its standard error, computed across all resamples.
```{r}
collect_metrics(rf_res)
```
These are the resampling estimates averaged over the individual replicates. To get the metrics for each resample, use the option `summarize = FALSE`:
```{r}
collect_metrics(rf_res, summarize = FALSE)
```
To obtain the assessment set predictions:
```{r resampling-cv-pred}
assess_res <- collect_predictions(rf_res)
assess_res
```
The prediction column names follow the conventions discussed for `r pkg(parsnip)` models for consistency and ease of use. The observed outcome column always uses the original column name from the source data. The `.row` column is an integer that matches the row of the original training set so that these results can be properly arranged and joined with the original data.
For some resampling methods, such as the bootstrap or repeated cross-validation, there will be multiple predictions per row of the original training set. To obtain summarized values (averages of the replicate predictions) use `collect_predictions(object, summarize = TRUE)`.
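Conceptually, the summarized predictions are per-row averages. The base R sketch below shows that idea on a small made-up data frame (the values are hypothetical and only mirror the `.row` and `.pred` column names); it is not the implementation used by `collect_predictions()`.

```{r sketch-average-predictions}
# Hypothetical replicate predictions: rows 1 and 2 were each predicted twice
replicate_preds <- data.frame(
  .row  = c(1, 1, 2, 2, 3),
  .pred = c(10, 12, 20, 22, 30)
)

# Average the replicate predictions within each original row
avg_pred <- tapply(replicate_preds$.pred, replicate_preds$.row, mean)
avg_pred
```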
Since this analysis used 10-fold cross-validation, there is one unique prediction for each training set sample. These data can generate helpful plots of the model to understand where it potentially failed. For example, we compare the observed and held-out predicted values below.
```{r resampling-cv-pred-plot, eval=FALSE}
assess_res %>%
ggplot(aes(x = Sale_Price, y = .pred)) +
geom_point(alpha = .15) +
geom_abline(color = "red") +
coord_obs_pred() +
ylab("Predicted")
```
There are two houses in the training set with a low observed sale price that are significantly overpredicted by the model. Which houses are these? Let's find out from the `assess_res` result:
```{r resampling-investigate}
over_predicted <-
  assess_res %>%
  mutate(residual = Sale_Price - .pred) %>%
  arrange(desc(abs(residual))) %>%
  slice(1:2)
over_predicted

ames_train %>%
  slice(over_predicted$.row) %>%
  select(Gr_Liv_Area, Neighborhood, Year_Built, Bedroom_AbvGr, Full_Bath)
```
Identifying examples like these with especially poor performance can help us follow up and investigate why these specific predictions are so poor.
# Parallel Processing
The models created during resampling are independent of one another. Computations of this kind are sometimes called "embarrassingly parallel"; each model could be fit simultaneously without issues.^[@parallel gives a technical overview of these technologies.] The `r pkg(tune)` package uses the [`r pkg(foreach)`](https://CRAN.R-project.org/package=foreach) package to facilitate parallel computations. These computations could be split across processors on the same computer or across different computers, depending on the chosen technology.
For computations conducted on a single computer, the number of possible "worker processes" is determined by the `r pkg(parallel)` package:
```{r resampling-find-cores}
# The number of physical cores in the hardware:
parallel::detectCores(logical = FALSE)
# The number of possible independent processes that can
# be simultaneously used:
parallel::detectCores(logical = TRUE)
```
For `fit_resamples()` and other functions in `r pkg(tune)`, parallel processing occurs when the user registers a parallel backend package. These R packages define how to execute parallel processing. On Unix and macOS operating systems, one method of splitting computations is by forking processes. To enable this, load the `r pkg(doMC)` package and register the number of parallel cores with `r pkg(foreach)`:
```{r resampling-mc, eval = FALSE}
# Unix and macOS only
library(doMC)
registerDoMC(cores = 2)
# Now run fit_resamples()...
```
This instructs `fit_resamples()` to run half of the computations on each of two cores. To reset the computations to sequential processing:
```{r resampling-seq, eval = FALSE}
registerDoSEQ()
```
Alternatively, a different approach to parallelizing computations uses network sockets. The `r pkg(doParallel)` package enables this method (usable by all operating systems):
```{r resampling-psock, eval = FALSE}
# All operating systems
library(doParallel)
# Create a cluster object and then register:
cl <- makePSOCKcluster(2)
registerDoParallel(cl)
# Now run fit_resamples()...
stopCluster(cl)
```
Another R package that facilitates parallel processing is the [`r pkg(future)`](https://future.futureverse.org/) package. Like `r pkg(foreach)`, it provides a framework for parallelism. It is used in conjunction with `r pkg(foreach)` via the `r pkg(doFuture)` package.
:::rmdnote
The R packages with parallel backends for `r pkg(foreach)` start with the prefix `"do"`.
:::
Parallel processing with the `r pkg(tune)` package tends to provide linear speed-ups for the first few cores. This means that, with two cores, the computations are twice as fast. Depending on the data and type of model, the linear speedup deteriorates after 4-5 cores. Using more cores will still reduce the time it takes to complete the task; there are just diminishing returns for the additional cores.
Let's wrap up with one final note about parallelism. For each of these technologies, the memory requirements multiply for each additional core used. For example, if the current data set is 2 GB in memory and three cores are used, the total memory requirement is 8 GB (2 for each worker process plus the original). Using too many cores might cause the computations (and the computer) to slow considerably.
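The arithmetic in this example can be written out as a tiny sketch. It assumes, as the text does, that each worker holds a full copy of the data in addition to the original session; actual memory use depends on the backend.

```{r sketch-parallel-memory}
data_gb <- 2     # size of the data set in memory
n_workers <- 3   # number of worker processes

# One copy per worker plus the copy in the original session
total_gb <- data_gb * (n_workers + 1)
total_gb
```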
# Saving the Resampled Objects
The models created during resampling are not retained. These models are trained for the purpose of evaluating performance, and we typically do not need them after we have computed performance statistics. If a particular modeling approach does turn out to be the best option for our data set, then the best choice is to fit again to the whole training set so the model parameters can be estimated with more data.
While these models created during resampling are not preserved, there is a method for keeping them or some of their components. The `extract` option of `control_resamples()` specifies a function that takes a single argument; we'll use `x`. When executed, `x` results in a fitted workflow object, regardless of whether you provided `fit_resamples()` with a workflow. Recall that the `r pkg(workflows)` package has functions that can pull the different components of the objects (e.g. the model, recipe, etc.).
Let's fit a linear regression model for the Ames housing data.
```{r resampling-lm-ames}
ames_rec <-
  recipe(Sale_Price ~ Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type +
           Latitude + Longitude, data = ames_train) %>%
  step_other(Neighborhood, threshold = 0.01) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_interact( ~ Gr_Liv_Area:starts_with("Bldg_Type_") ) %>%
  step_ns(Latitude, Longitude, deg_free = 20)

lm_wflow <-
  workflow() %>%
  add_recipe(ames_rec) %>%
  add_model(linear_reg() %>% set_engine("lm"))

lm_fit <- lm_wflow %>% fit(data = ames_train)

# Select the recipe:
extract_recipe(lm_fit, estimated = TRUE)
```
We can save the linear model coefficients for a fitted model object from a workflow:
```{r resampling-extract-func}
get_model <- function(x) {
  extract_fit_parsnip(x) %>% tidy()
}
# Test it using:
# get_model(lm_fit)
```
Now let's apply this function to the ten resampled fits. The results of the extraction function are wrapped in a list object and returned in a tibble:
```{r resampling-extract-all}
ctrl <- control_resamples(extract = get_model)
lm_res <- lm_wflow %>% fit_resamples(resamples = ames_folds, control = ctrl)
lm_res
```
Now there is a `.extracts` column with nested tibbles. What do these contain? Let's find out by subsetting.
```{r resampling-extract-res}
lm_res$.extracts[[1]]
# To get the results
lm_res$.extracts[[1]][[1]]
```
This might appear to be a convoluted method for saving the model results. However, `extract` is flexible and does not assume that the user will only save a single tibble per resample. For example, the `tidy()` method might be run on the recipe as well as the model. In this case, a list of two tibbles will be returned.
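As a sketch of that two-tibble case, an extraction function might look like the following. The name `get_recipe_and_model` is hypothetical, and the function assumes the fitted workflow contains a recipe (as `lm_wflow` does) so that `extract_recipe()` succeeds.

```{r sketch-extract-both}
# Hypothetical extraction function returning two tibbles per resample
get_recipe_and_model <- function(x) {
  list(
    recipe = extract_recipe(x) %>% tidy(),      # one row per recipe step
    model  = extract_fit_parsnip(x) %>% tidy()  # one row per model coefficient
  )
}

# It could then be passed along in the same way as get_model():
# ctrl_both <- control_resamples(extract = get_recipe_and_model)
```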
For our simpler example, all of the results can be flattened and collected using:
```{r resampling-extract-fraction}
all_coef <- map_dfr(lm_res$.extracts, ~ .x[[1]][[1]])
# Show the replicates for a single predictor:
filter(all_coef, term == "Year_Built")
```
Chapters \@ref(grid-search) and \@ref(iterative-search) discuss a suite of functions for tuning models. Their interfaces are similar to `fit_resamples()` and many of the features described here apply to those functions.
# References
1. Tantithamthavorn, Chakkrit, et al. "An empirical comparison of model validation techniques for defect prediction models." IEEE Transactions on Software Engineering 43.1 (2016): 1-18.
2. B. Efron, "Estimating the error rate of a prediction rule: Some improvements on cross-validation," J. Amer. Statist. Assoc., vol. 78, no. 382, pp. 316–331, 1983.
3. Max Kuhn and Julia Silge, _Tidy Modeling with R_, chapter 10. https://www.tmwr.org/resampling.html