The goal of this vignette is to explain the older resamplers,
ResamplingVariableSizeTrainCV
and ResamplingSameOtherCV,
which output additional data that are useful for visualizing the
train/test splits. If you do not need to visualize the train/test
splits, we recommend the newer resampler,
ResamplingSameOtherSizesCV
(see the other vignette).
The goal of this section is to explain how to quantify the extent to which it is possible to train on one data subset and predict on another data subset. This kind of problem occurs frequently in many different problem domains.
The ideas are similar to my previous blog posts about how to do this
in
python
and R. Below
we explain how to use mlr3resampling
for this purpose, in simulated
regression and classification problems. To use this method on
real data, the important sections to read below are named "Benchmark:
computing test error," which show how to create these cross-validation
experiments using mlr3 code.
We begin by generating some data which can be used with regression algorithms. Assume there is a data set with some rows from one person and some rows from another.
N <- 300
library(data.table)
set.seed(1)
abs.x <- 2
reg.dt <- data.table(
  x=runif(N, -abs.x, abs.x),
  person=rep(1:2, each=0.5*N))
reg.pattern.list <- list(
  easy=function(x, person)x^2,
  impossible=function(x, person)(x^2+person*3)*(-1)^person)
reg.task.list <- list()
for(task_id in names(reg.pattern.list)){
  f <- reg.pattern.list[[task_id]]
  yname <- paste0("y_", task_id)
  reg.dt[, (yname) := f(x, person)+rnorm(N)][]
  task.dt <- reg.dt[, c("x", "person", yname), with=FALSE]
  reg.task <- mlr3::TaskRegr$new(
    task_id, task.dt, target=yname)
  reg.task$col_roles$subset <- "person"
  reg.task$col_roles$stratum <- "person"
  reg.task$col_roles$feature <- "x"
  reg.task.list[[task_id]] <- reg.task
}
reg.dt
#> x person y_easy y_impossible
#> <num> <int> <num> <num>
#> 1: -0.9379653 1 1.32996609 -2.918082
#> 2: -0.5115044 1 0.24307692 -3.866062
#> 3: 0.2914135 1 -0.23314657 -3.837799
#> 4: 1.6328312 1 1.73677545 -7.221749
#> 5: -1.1932723 1 -0.06356159 -5.877792
#> ---
#> 296: 0.7257701 2 -2.48130642 5.180948
#> 297: -1.6033236 2 1.20453459 9.604312
#> 298: -1.5243898 2 1.89966190 7.511988
#> 299: -1.7982414 2 3.47047566 11.035397
#> 300: 1.7170157 2 0.60541972 10.719685
The table above shows some simulated data for two regression problems,
easy and impossible. Note that the code above uses an
mlr3::TaskRegr
line, which tells mlr3 what data set to use, what the
target column is, and what the subset/stratum column is. First we reshape the data using the code below,
(reg.tall <- nc::capture_melt_single(
  reg.dt,
  task_id="easy|impossible",
  value.name="y"))
#> x person task_id y
#> <num> <int> <char> <num>
#> 1: -0.9379653 1 easy 1.32996609
#> 2: -0.5115044 1 easy 0.24307692
#> 3: 0.2914135 1 easy -0.23314657
#> 4: 1.6328312 1 easy 1.73677545
#> 5: -1.1932723 1 easy -0.06356159
#> ---
#> 596: 0.7257701 2 impossible 5.18094849
#> 597: -1.6033236 2 impossible 9.60431191
#> 598: -1.5243898 2 impossible 7.51198770
#> 599: -1.7982414 2 impossible 11.03539747
#> 600: 1.7170157 2 impossible 10.71968480
The table above is a more convenient form for the visualization which we create using the code below,
if(require(animint2)){
  ggplot()+
    geom_point(aes(
      x, y),
      data=reg.tall)+
    facet_grid(
      task_id ~ person,
      labeller=label_both,
      space="free",
      scales="free")+
    scale_y_continuous(
      breaks=seq(-100, 100, by=2))
}
#> Loading required package: animint2
#> Registered S3 methods overwritten by 'animint2':
#>   method                 from   
#>   [.uneval               ggplot2
#>   ... (many more S3 methods) ...
#> 
#> Attaching package: 'animint2'
#> The following objects are masked from 'package:ggplot2':
#> 
#>     aes, facet_grid, geom_point, ggplot, label_both,
#>     scale_y_continuous, theme_bw, ... (and many more)
In the simulated data above, we can see that in the easy pattern, the relationship between x and y is the same for both people, so it should be possible to train on one person and accurately predict on the other; whereas in the impossible pattern, the relationship differs between the two people, so training on one person should not yield accurate predictions on the other.
In the code below, we define a K-fold cross-validation experiment (K=3 folds).
(reg_same_other <- mlr3resampling::ResamplingSameOtherCV$new())
#> <ResamplingSameOtherCV> : Same versus Other Cross-Validation
#> * Iterations:
#> * Instantiated: FALSE
#> * Parameters:
#> List of 1
#> $ folds: int 3
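If you want to inspect the splits directly, a resampling object can be instantiated on a task; the sketch below is an illustration assuming the reg.task.list defined earlier, using the standard mlr3 Resampling methods train_set() and test_set() to get the row ids of a given split.

```r
## Sketch: instantiate the resampler on the easy task, then
## inspect the row ids of the first train/test split.
reg_same_other$instantiate(reg.task.list$easy)
reg_same_other$iters                # total number of train/test splits
head(reg_same_other$train_set(1))   # row ids used for training in split 1
head(reg_same_other$test_set(1))    # row ids used for testing in split 1
```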
In the code below, we define two learners to compare,
(reg.learner.list <- list(
  if(requireNamespace("rpart"))mlr3::LearnerRegrRpart$new(),
  mlr3::LearnerRegrFeatureless$new()))
#> [[1]]
#> <LearnerRegrRpart:regr.rpart>: Regression Tree
#> * Model: -
#> * Parameters: xval=0
#> * Packages: mlr3, rpart
#> * Predict Types: [response]
#> * Feature Types: logical, integer, numeric, factor, ordered
#> * Properties: importance, missings, selected_features, weights
#>
#> [[2]]
#> <LearnerRegrFeatureless:regr.featureless>: Featureless Regression Learner
#> * Model: -
#> * Parameters: robust=FALSE
#> * Packages: mlr3, stats
#> * Predict Types: [response], se, quantiles
#> * Feature Types: logical, integer, numeric, character, factor, ordered,
#> POSIXct
#> * Properties: featureless, importance, missings, selected_features
In the code below, we define the benchmark grid, which is all combinations of tasks (easy and impossible), learners (rpart and featureless), and the one resampling method.
(reg.bench.grid <- mlr3::benchmark_grid(
  reg.task.list,
  reg.learner.list,
  reg_same_other))
#> task learner resampling
#> <char> <char> <char>
#> 1: easy regr.rpart same_other_cv
#> 2: easy regr.featureless same_other_cv
#> 3: impossible regr.rpart same_other_cv
#> 4: impossible regr.featureless same_other_cv
In the code below, we execute the benchmark experiment (in parallel using the multisession future plan).
if(FALSE){ # for CRAN.
  if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
#> Loading required package: lgr
#>
#> Attaching package: 'lgr'
#> The following object is masked from 'package:ggplot2':
#>
#> Layout
(reg.bench.result <- mlr3::benchmark(
  reg.bench.grid, store_models = TRUE))
#> <BenchmarkResult> of 72 rows with 4 resampling runs
#> nr task_id learner_id resampling_id iters warnings errors
#> 1 easy regr.rpart same_other_cv 18 0 0
#> 2 easy regr.featureless same_other_cv 18 0 0
#> 3 impossible regr.rpart same_other_cv 18 0 0
#> 4 impossible regr.featureless same_other_cv 18 0 0
The code below computes the test error for each split,
reg.bench.score <- mlr3resampling::score(reg.bench.result)
reg.bench.score[1]
#> train.subsets test.fold test.subset person iteration test
#> <char> <int> <int> <int> <int> <list>
#> 1: all 1 1 1 1 1, 3, 5, 6,12,13,...
#> train uhash nr
#> <list> <char> <int>
#> 1: 4, 7, 9,10,18,20,... 5162c460-c881-4393-8b16-b7297db4d2bd 1
#> task task_id learner learner_id
#> <list> <char> <list> <char>
#> 1: <TaskRegr:easy> easy <LearnerRegrRpart:regr.rpart> regr.rpart
#> resampling resampling_id prediction_test regr.mse algorithm
#> <list> <char> <list> <num> <char>
#> 1: <ResamplingSameOtherCV> same_other_cv <PredictionRegr> 1.638015 rpart
The code below visualizes the resulting test accuracy numbers.
if(require(animint2)){
  ggplot()+
    scale_x_log10()+
    geom_point(aes(
      regr.mse, train.subsets, color=algorithm),
      shape=1,
      data=reg.bench.score)+
    facet_grid(
      task_id ~ person,
      labeller=label_both,
      scales="free")
}
It is clear from the plot above that rpart is more accurate than featureless on the easy task, no matter which subsets are used for training, whereas on the impossible task, rpart is only more accurate than featureless when trained on the same subset (or on all subsets), not when trained on the other subset.
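The test error can also be summarized numerically with a data.table aggregation; this is a sketch assuming reg.bench.score as computed above.

```r
## Mean test error over splits, for each combination of
## task, train subset, and learning algorithm.
reg.bench.score[, .(
  mean.mse=mean(regr.mse),
  n.splits=.N
), by=.(task_id, train.subsets, algorithm)]
```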
The code below can be used to create an interactive data visualization which allows exploring how different functions are learned during different splits.
inst <- reg.bench.score$resampling[[1]]$instance
rect.expand <- 0.2
grid.dt <- data.table(x=seq(-abs.x, abs.x, l=101), y=0)
grid.task <- mlr3::TaskRegr$new("grid", grid.dt, target="y")
pred.dt.list <- list()
point.dt.list <- list()
for(score.i in 1:nrow(reg.bench.score)){
  reg.bench.row <- reg.bench.score[score.i]
  task.dt <- data.table(
    reg.bench.row$task[[1]]$data(),
    reg.bench.row$resampling[[1]]$instance$id.dt)
  names(task.dt)[1] <- "y"
  set.ids <- data.table(
    set.name=c("test", "train")
  )[
  , data.table(row_id=reg.bench.row[[set.name]][[1]])
  , by=set.name]
  i.points <- set.ids[
    task.dt, on="row_id"
  ][
    is.na(set.name), set.name := "unused"
  ]
  point.dt.list[[score.i]] <- data.table(
    reg.bench.row[, .(task_id, iteration)],
    i.points)
  i.learner <- reg.bench.row$learner[[1]]
  pred.dt.list[[score.i]] <- data.table(
    reg.bench.row[, .(
      task_id, iteration, algorithm
    )],
    as.data.table(
      i.learner$predict(grid.task)
    )[, .(x=grid.dt$x, y=response)]
  )
}
(pred.dt <- rbindlist(pred.dt.list))
#> task_id iteration algorithm x y
#> <char> <int> <char> <num> <num>
#> 1: easy 1 rpart -2.00 3.557968
#> 2: easy 1 rpart -1.96 3.557968
#> 3: easy 1 rpart -1.92 3.557968
#> 4: easy 1 rpart -1.88 3.557968
#> 5: easy 1 rpart -1.84 3.557968
#> ---
#> 7268: impossible 18 featureless 1.84 7.204232
#> 7269: impossible 18 featureless 1.88 7.204232
#> 7270: impossible 18 featureless 1.92 7.204232
#> 7271: impossible 18 featureless 1.96 7.204232
#> 7272: impossible 18 featureless 2.00 7.204232
(point.dt <- rbindlist(point.dt.list))
#> task_id iteration set.name row_id y x fold person
#> <char> <int> <char> <int> <num> <num> <int> <int>
#> 1: easy 1 test 1 1.32996609 -0.9379653 1 1
#> 2: easy 1 train 2 0.24307692 -0.5115044 3 1
#> 3: easy 1 test 3 -0.23314657 0.2914135 1 1
#> 4: easy 1 train 4 1.73677545 1.6328312 2 1
#> 5: easy 1 test 5 -0.06356159 -1.1932723 1 1
#> ---
#> 21596: impossible 18 train 296 5.18094849 0.7257701 1 2
#> 21597: impossible 18 train 297 9.60431191 -1.6033236 1 2
#> 21598: impossible 18 test 298 7.51198770 -1.5243898 3 2
#> 21599: impossible 18 train 299 11.03539747 -1.7982414 1 2
#> 21600: impossible 18 test 300 10.71968480 1.7170157 3 2
#> subset display_row
#> <int> <int>
#> 1: 1 1
#> 2: 1 101
#> 3: 1 2
#> 4: 1 51
#> 5: 1 3
#> ---
#> 21596: 2 198
#> 21597: 2 199
#> 21598: 2 299
#> 21599: 2 200
#> 21600: 2 300
set.colors <- c(
  train="#1B9E77",
  test="#D95F02",
  unused="white")
algo.colors <- c(
  featureless="blue",
  rpart="red")
make_person_subset <- function(DT){
  DT[, "person/subset" := person]
}
make_person_subset(point.dt)
make_person_subset(reg.bench.score)
if(require(animint2)){
  viz <- animint(
    title="Train/predict on subsets, regression",
    pred=ggplot()+
      ggtitle("Predictions for selected train/test split")+
      theme_animint(height=400)+
      scale_fill_manual(values=set.colors)+
      geom_point(aes(
        x, y, fill=set.name),
        showSelected="iteration",
        size=3,
        shape=21,
        data=point.dt)+
      scale_color_manual(values=algo.colors)+
      geom_line(aes(
        x, y, color=algorithm, subset=paste(algorithm, iteration)),
        showSelected="iteration",
        data=pred.dt)+
      facet_grid(
        task_id ~ `person/subset`,
        labeller=label_both,
        space="free",
        scales="free")+
      scale_y_continuous(
        breaks=seq(-100, 100, by=2)),
    err=ggplot()+
      ggtitle("Test error for each split")+
      theme_animint(height=400)+
      scale_y_log10(
        "Mean squared error on test set")+
      scale_fill_manual(values=algo.colors)+
      scale_x_discrete(
        "People/subsets in train set")+
      geom_point(aes(
        train.subsets, regr.mse, fill=algorithm),
        shape=1,
        size=5,
        stroke=2,
        color="black",
        color_off=NA,
        clickSelects="iteration",
        data=reg.bench.score)+
      facet_grid(
        task_id ~ `person/subset`,
        labeller=label_both,
        scales="free"),
    diagram=ggplot()+
      ggtitle("Select train/test split")+
      theme_bw()+
      theme_animint(height=300)+
      facet_grid(
        . ~ train.subsets,
        scales="free",
        space="free")+
      scale_size_manual(values=c(subset=3, fold=1))+
      scale_color_manual(values=c(subset="orange", fold="grey50"))+
      geom_rect(aes(
        xmin=-Inf, xmax=Inf,
        color=rows,
        size=rows,
        ymin=display_row, ymax=display_end),
        fill=NA,
        data=inst$viz.rect.dt)+
      scale_fill_manual(values=set.colors)+
      geom_rect(aes(
        xmin=iteration-rect.expand, ymin=display_row,
        xmax=iteration+rect.expand, ymax=display_end,
        fill=set.name),
        clickSelects="iteration",
        data=inst$viz.set.dt)+
      geom_text(aes(
        ifelse(rows=="subset", Inf, -Inf),
        (display_row+display_end)/2,
        hjust=ifelse(rows=="subset", 1, 0),
        label=paste0(rows, "=", ifelse(rows=="subset", subset, fold))),
        data=data.table(train.name="same", inst$viz.rect.dt))+
      scale_x_continuous(
        "Split number / cross-validation iteration")+
      scale_y_continuous(
        "Row number"),
    source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/ResamplingSameOtherCV.Rmd")
  viz
}
// selector, then there is no way to select it other than the
// selectize widgets (and possibly legends). So in this case
// we show the selectize widgets by default.
var selector_widgets_hidden =
show_hide_selector_widgets.text() == toggle_message;
var has_no_clickSelects =
!Selectors[s_name].hasOwnProperty("clickSelects")
var has_no_legend =
!Selectors[s_name].hasOwnProperty("legend")
if(selector_widgets_hidden && has_no_clickSelects && has_no_legend){
var node = show_hide_selector_widgets.node();
show_or_hide_fun.apply(node);
}
// removing "." from name so it can be used in ids
var s_name_id = legend_class_name(s_name);
// adding a row for each selector
var selector_widget_row = selector_table
.append("tr")
.attr("class", function() { return s_name_id + "_selector_widget"; })
;
selector_widget_row.append("td").text(s_name);
// adding the selector
var selector_widget_select = selector_widget_row
.append("td")
.append("select")
.attr("class", function() { return s_name_id + "_input"; })
.attr("placeholder", function() { return "Toggle " + s_name; });
// adding an option for each level of the variable
selector_widget_select.selectAll("option")
.data(s_info.levels)
.enter()
.append("option")
.attr("value", function(d) { return d; })
.text(function(d) { return d; });
// making sure that the first option is blank
selector_widget_select
.insert("option")
.attr("value", "")
.text(function() { return "Toggle " + s_name; });
// calling selectize
var selectize_selector = to_select + ' .' + s_name_id + "_input";
if(s_info.type == "single") {
// setting up array of selector and options
var selector_values = [];
for(i in s_info.levels) {
selector_values[i] = {
id: s_name.concat("___", s_info.levels[i]),
text: s_info.levels[i]
};
}
// the id of the first selector
var selected_id = s_name.concat("___", s_info.selected);
// if single selection, only allow one item
var $temp = $(selectize_selector)
.selectize({
create: false,
valueField: 'id',
labelField: 'text',
searchField: ['text'],
options: selector_values,
items: [selected_id],
maxItems: 1,
allowEmptyOption: true,
onChange: function(value) {
// extracting the name and the level to update
var selector_name = value.split("___")[0];
var selected_level = value.split("___")[1];
// updating the selector
update_selector(selector_name, selected_level);
}
})
;
} else { // multiple selection:
// setting up array of selector and options
var selector_values = [];
if(typeof s_info.levels == "object") {
for(i in s_info.levels) {
selector_values[i] = {
id: s_name.concat("___", s_info.levels[i]),
text: s_info.levels[i]
};
}
} else {
selector_values[0] = {
id: s_name.concat("___", s_info.levels),
text: s_info.levels
};
}
// setting up an array to contain the initally selected elements
var initial_selections = [];
for(i in s_info.selected) {
initial_selections[i] = s_name.concat("___", s_info.selected[i]);
}
// construct the selectize
var $temp = $(selectize_selector)
.selectize({
create: false,
valueField: 'id',
labelField: 'text',
searchField: ['text'],
options: selector_values,
items: initial_selections,
maxItems: s_info.levels.length,
allowEmptyOption: true,
onChange: function(value) {
// if nothing is selected, remove what is currently selected
if(value == null) {
// extracting the selector ids from the options
var the_ids = Object.keys($(this)[0].options);
// the name of the appropriate selector
var selector_name = the_ids[0].split("___")[0];
// the previously selected elements
var old_selections = Selectors[selector_name].selected;
// updating the selector for each of the old selections
old_selections.forEach(function(element) {
update_selector(selector_name, element);
});
} else { // value is not null:
// grabbing the name of the selector from the selected value
var selector_name = value[0].split("___")[0];
// identifying the levels that should be selected
var specified_levels = [];
for(i in value) {
specified_levels[i] = value[i].split("___")[1];
}
// the previously selected entries
old_selections = Selectors[selector_name].selected;
// the levels that need to have selections turned on
specified_levels
.filter(function(n) {
return old_selections.indexOf(n) == -1;
})
.forEach(function(element) {
update_selector(selector_name, element);
})
;
// the levels that need to be turned off
// - same approach
old_selections
.filter(function(n) {
return specified_levels.indexOf(n) == -1;
})
.forEach(function(element) {
update_selector(selector_name, element);
})
;
}//value==null
}//onChange
})//selectize
;
}//single or multiple selection.
selectized_array[s_name] = $temp[0].selectize;
}//levels, is.variable.value
} // close for loop through selector widgets
// If this is an animation, then start downloading all the rest of
// the data, and start the animation.
if (response.time) {
var i, prev, cur;
for (var i = 0; i < Animation.sequence.length; i++) {
if (i == 0) {
prev = Animation.sequence[Animation.sequence.length-1];
} else {
prev = Animation.sequence[i - 1];
}
cur = Animation.sequence[i];
Animation.next[prev] = cur;
}
Animation.timer = null;
Animation.play = function(){
if(Animation.timer == null){ // only play if not already playing.
// as shown on http://bl.ocks.org/mbostock/3808234
Animation.timer = setInterval(update_next_animation, Animation.ms);
Widgets["play_pause"].text("Pause");
}
};
Animation.play_after_visible = false;
Animation.pause = function(play_after_visible){
Animation.play_after_visible = play_after_visible;
clearInterval(Animation.timer);
Animation.timer = null;
Widgets["play_pause"].text("Play");
};
var s_info = Selectors[Animation.variable];
Animation.done_geoms = {};
s_info.update.forEach(function(g_name){
var g_info = Geoms[g_name];
if(g_info.chunk_order.length == 1 &&
g_info.chunk_order[0] == Animation.variable){
g_info.seq_i = Animation.sequence.indexOf(s_info.selected);
g_info.seq_count = 0;
Animation.done_geoms[g_name] = 0;
download_next(g_name);
}
});
Animation.play();
all_geom_names = d3.keys(response.geoms);
// This code starts/stops the animation timer when the page is
// hidden, inspired by
// http://stackoverflow.com/questions/1060008
function onchange (evt) {
if(document.visibilityState == "visible"){
if(Animation.play_after_visible){
Animation.play();
}
}else{
if(Widgets["play_pause"].text() == "Pause"){
Animation.pause(true);
}
}
};
document.addEventListener("visibilitychange", onchange);
}
// update_selector_url()
var check_func=function(){
var status_array = $('.status').map(function(){
return $.trim($(this).text());
}).get();
status_array=status_array.slice(1)
return status_array.every(function(elem){ return elem === "displayed"});
}
if(window.location.hash) {
var fragment=window.location.hash;
fragment=fragment.slice(1);
fragment=decodeURI(fragment)
var frag_array=fragment.split(/(.*?})/);
frag_array=frag_array.filter(function(x){ return x!=""})
frag_array.forEach(function(selector_string){
var selector_hash=selector_string.split("=");
var selector_nam=selector_hash[0];
var selector_values=selector_hash[1];
var re = /\{(.*?)\}/;
selector_values = re.exec(selector_values)[1];
var array_values = selector_values.split(',');
if(Selectors.hasOwnProperty(selector_nam)){
var s_info = Selectors[selector_nam]
if(s_info.type=="single"){//TODO fix
array_values.forEach(function(element) {
wait_until_then(100, check_func, update_selector,selector_nam,element)
if(response.time)Animation.pause(true)
});
}else{
var old_selections = Selectors[selector_nam].selected;
// the levels that need to have selections turned on
array_values
.filter(function(n) {
return old_selections.indexOf(n) == -1;
})
.forEach(function(element) {
wait_until_then(100, check_func, update_selector,selector_nam,element)
if(response.time){
Animation.pause(true)
}
});
old_selections
.filter(function(n) {
return array_values.indexOf(n) == -1;
})
.forEach(function(element) {
wait_until_then(100, check_func, update_selector,selector_nam,element)
if(response.time){
Animation.pause(true)
}
});
}//if(single) else multiple selection
}//if(Selectors.hasOwnProperty(selector_nam))
})//frag_array.forEach
}//if(window.location.hash)
});
};
If you are viewing this in an installed package or on CRAN, then there will be no data viz on this page, but you can view it on: https://tdhock.github.io/2023-12-13-train-predict-subsets-regression/

Simulated classification problems

The previous section investigated a simulated regression problem, whereas in this section we simulate a binary classification problem. Assume there is a data set with some rows from one person, and some rows from another,
Above, each row has a person ID, either 1 or 2. We can imagine a spam filtering system that has training data for multiple people (here just two). Each row in the table above represents a message which has been labeled as spam or not by one of the two people. Can we train on one person, and accurately predict on the other? To do that we will need some features, which we generate/simulate below:
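As a concrete sketch (not necessarily the vignette's exact code), the feature simulation could look like the following; the feature names, sample size, and patterns are assumptions, chosen so that the easy features relate to the label in the same way for both people, while the impossible features relate to the label in an opposite way for each person.

```r
library(data.table)
set.seed(1)
N <- 200
class.dt <- data.table(
  person = rep(1:2, each = 0.5 * N),
  label = factor(rep(c("spam", "not_spam"), l = N)))
## easy features: same relationship with the label for both people.
class.dt[, x1_easy := rnorm(.N, mean = ifelse(label == "spam", 2, -2))]
class.dt[, x2_easy := rnorm(.N)]
## impossible features: opposite relationship for each person.
class.dt[, x1_impossible := rnorm(
  .N, mean = ifelse(label == "spam", 2, -2) * (-1)^person)]
class.dt[, x2_impossible := rnorm(.N)]
class.dt
```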
In the table above, there are two sets of two features:
Static visualization of simulated data

Below we reshape the data to a table which is more suitable for visualization:
Below we visualize the pattern for each person and feature type:
In the plot above, it is apparent that the easy features have the same relationship with the label for both people, whereas the impossible features have a different (opposite) relationship for each person.
Benchmark: computing test error

We use the code below to create a list of classification tasks, for use in the mlr3 framework.
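A minimal sketch of that task-creation loop, assuming the simulated data are in a data.table named class.dt with features x1_easy, x2_easy, x1_impossible, x2_impossible (hypothetical names):

```r
## Sketch: one mlr3 classification task per feature set (easy/impossible).
## The col_roles$subset role is provided by mlr3resampling.
class.task.list <- list()
for (task_id in c("easy", "impossible")) {
  feature.names <- paste0(c("x1_", "x2_"), task_id)
  task.dt <- class.dt[, c(feature.names, "person", "label"), with = FALSE]
  class.task <- mlr3::TaskClassif$new(task_id, task.dt, target = "label")
  class.task$col_roles$subset <- "person"
  class.task$col_roles$stratum <- c("person", "label")
  class.task$col_roles$feature <- feature.names
  class.task.list[[task_id]] <- class.task
}
```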
Note in the code above that person is assigned the roles subset and stratum, whereas label is assigned the roles target and stratum. When adapting the code above to real data, the important part is the col_roles assignment, which tells mlr3 which column defines the subsets to train/test on, and which column is the prediction target.
The code below is used to define a K-fold cross-validation experiment,
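For example (a sketch; folds is the number of cross-validation folds within each subset):

```r
## Same/other cross-validation from mlr3resampling, with 3 folds.
same_other_cv <- mlr3resampling::ResamplingSameOtherCV$new()
same_other_cv$param_set$values$folds <- 3
same_other_cv
```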
The code below is used to define the learning algorithms to test,
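A sketch, using the rpart and featureless learners that ship with mlr3:

```r
## A decision tree that can learn the patterns, and a featureless baseline
## that always predicts the most frequent label in the train set.
class.learner.list <- list(
  mlr3::lrn("classif.rpart"),
  mlr3::lrn("classif.featureless"))
```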
The code below defines the grid of tasks, learners, and resamplings.
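Assuming the task list, learner list, and resampling objects sketched earlier (hypothetical names), the grid could be defined as:

```r
## All combinations of tasks, learners, and the same/other resampling.
class.bench.grid <- mlr3::benchmark_grid(
  class.task.list,
  class.learner.list,
  same_other_cv)
```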
The code below runs the benchmark experiment grid. Note that each iteration can be parallelized by declaring a future plan.
Below we compute scores (test error) for each resampling iteration, and show the first row of the result.
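The two steps above (running the grid, then scoring) could be sketched as follows, assuming the class.bench.grid object from the earlier sketch; mlr3resampling::score() adds columns describing the train subsets to the usual mlr3 scores:

```r
## Optional: parallelize over resampling iterations with a future plan.
if (requireNamespace("future", quietly = TRUE)) {
  future::plan("multisession")
}
class.bench.result <- mlr3::benchmark(class.bench.grid)
class.bench.score <- mlr3resampling::score(class.bench.result)
class.bench.score[1]
```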
Finally we plot the test error values below.
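One way to plot the scores (a sketch; the train.subsets and classif.ce column names are assumptions about the score table):

```r
library(ggplot2)
## Test error for each train subset (same/other/all), one panel per
## combination of algorithm and task.
ggplot() +
  geom_point(aes(
    classif.ce, train.subsets),
    shape = 1,
    data = class.bench.score) +
  facet_grid(algorithm ~ task_id, labeller = label_both)
```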
It is clear from the plot above that the easy pattern can be accurately predicted using a model trained on the other person, whereas the impossible pattern can only be accurately predicted using a model trained on the same person.
Interactive visualization of data, test error, and splits

The code below can be used to create an interactive data visualization which allows exploring how different functions are learned during different splits.
If you are viewing this in an installed package or on CRAN, then there will be no data viz on this page, but you can view it on: https://tdhock.github.io/2023-12-13-train-predict-subsets-classification/

Conclusion

In this section we have shown how to use mlr3resampling for comparing test error of models trained on same/all/other subsets.

Variable size train resampler

The goal of this section is to explain how to quantify the extent to which test error decreases as the number of train samples increases, using ResamplingVariableSizeTrainCV.
Simulated regression problems

The code below creates data for simulated regression problems. First we define a vector of input values,
Below we define a list of two true regression functions (tasks in mlr3 terminology) for our simulated data,
The constant function represents a regression problem which can be solved by always predicting the mean value of outputs (featureless is the best possible learning algorithm). The sin function will be used to generate data with a non-linear pattern that will need to be learned. Below we use a for loop over these two functions/tasks, to simulate the data which will be used as input to the learning algorithms:
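The three steps above (input vector, list of true functions, simulation loop) could be sketched as follows; the grid range and noise level are assumptions:

```r
library(data.table)
set.seed(1)
## Input values on a regular grid.
x.vec <- seq(-2, 2, by = 0.05)
## Two true regression functions: one non-linear, one constant.
reg.pattern.list <- list(
  sin = sin,
  constant = function(x) 0 * x)
reg.data.list <- list()
for (task_id in names(reg.pattern.list)) {
  f <- reg.pattern.list[[task_id]]
  reg.data.list[[task_id]] <- data.table(
    task_id,
    x = x.vec,
    y = f(x.vec) + rnorm(length(x.vec), sd = 0.5))
}
reg.dt <- rbindlist(reg.data.list)
reg.dt
```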
In the table above, the input is x, and the output is y. Below we visualize these data, with one task in each facet/panel:
In the plot above we can see two different simulated data sets (constant and sin), generated by the code above.

Visualizing instance table

In the code below, we define a K-fold cross-validation experiment, with K=3 folds.
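A sketch of the resampling definition; ResamplingVariableSizeTrainCV is from mlr3resampling, and printing the object shows its parameters (the parameter names below follow the mlr3resampling documentation):

```r
## Variable train size cross-validation, with 3 test folds.
variable_size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
variable_size_cv$param_set$values$folds <- 3
variable_size_cv
```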
In the output above we can see the parameters of the resampling object, all of which should be integer scalars:
Below we instantiate the resampling on one of the tasks:
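For example, assuming the simulated table reg.dt and the resampling object variable_size_cv from the sketches above (hypothetical names):

```r
## Sketch: create a task from the simulated sin data, then instantiate.
sin.task <- mlr3::TaskRegr$new(
  "sin", reg.dt[task_id == "sin", .(x, y)], target = "y")
variable_size_cv$instantiate(sin.task)
variable_size_cv$instance
```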
Above we see the instance, which need not be examined by the user, but for informational purposes, it contains the following data:
Benchmark: computing test error

In the code below, we define two learners to compare,
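A sketch, using the rpart and featureless regression learners that ship with mlr3:

```r
## rpart can learn the non-linear sin pattern; featureless is a baseline
## that always predicts the mean of the train output values.
reg.learner.list <- list(
  mlr3::lrn("regr.rpart"),
  mlr3::lrn("regr.featureless"))
```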
The code above defines two learners: rpart, a regression tree which is able to learn non-linear patterns such as sin, and featureless, a baseline which always predicts the mean of the train output values.
In the code below, we define the benchmark grid, which is all combinations of tasks (constant and sin), learners (rpart and featureless), and the one resampling method.
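Assuming a list of the two tasks (hypothetical name reg.task.list), along with the learner and resampling sketches above, the grid could be defined as:

```r
## All combinations of tasks, learners, and the one resampling method.
reg.bench.grid <- mlr3::benchmark_grid(
  reg.task.list,
  reg.learner.list,
  variable_size_cv)
```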
In the code below, we execute the benchmark experiment (optionally in parallel using the multisession future plan).
The code below computes the test error for each split, and visualizes the information stored in the first row of the result:
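The execution and scoring steps could be sketched as follows, assuming the reg.bench.grid object from the previous sketch:

```r
## Optional: parallelize over resampling iterations with a future plan.
if (requireNamespace("future", quietly = TRUE)) {
  future::plan("multisession")
}
reg.bench.result <- mlr3::benchmark(reg.bench.grid)
reg.bench.score <- mlr3resampling::score(reg.bench.result)
reg.bench.score[1]
```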
The output above contains all of the results related to a particular train/test split. In particular for our purposes, the interesting columns are:
The code below visualizes the resulting test accuracy numbers.
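A sketch of that plot (the train_size, seed, and test.fold column names are assumptions about the score table):

```r
library(ggplot2)
## One line per (algorithm, random seed), one dot per train set size,
## one panel per (test fold, task).
ggplot() +
  geom_line(aes(
    train_size, regr.mse,
    group = paste(algorithm, seed),
    color = algorithm),
    data = reg.bench.score) +
  geom_point(aes(
    train_size, regr.mse, color = algorithm),
    data = reg.bench.score) +
  facet_grid(test.fold ~ task_id, labeller = label_both) +
  scale_x_log10()
```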
Above we plot the test error for each fold and train set size. There is a different panel for each task and test fold. Each line represents a random seed (ordering of data in train set), and each dot represents a specific train set size. So the plot above shows that some variation in test error, for a given test fold, is due to the random ordering of the train data. Below we summarize each train set size, by taking the mean and standard deviation over each random seed.
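The mean/SD summary over random seeds can be sketched as below, applied here to a small synthetic score table so the aggregation logic is self-contained (in the vignette this would be the benchmark score table, and the column names are assumptions):

```r
library(data.table)
## Synthetic stand-in for the benchmark score table: three seeds for each
## of two train set sizes.
score.dt <- data.table(
  task_id = "sin", algorithm = "rpart", test.fold = 1,
  train_size = rep(c(10, 100), each = 3),
  seed = rep(1:3, 2),
  regr.mse = c(1.2, 1.0, 1.1, 0.5, 0.4, 0.6))
## Mean and SD over seeds, for each train set size.
summary.dt <- score.dt[, .(
  mean.mse = mean(regr.mse),
  sd.mse = sd(regr.mse)
), by = .(task_id, algorithm, test.fold, train_size)]
summary.dt
```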
The plot above shows a line for the mean, and a ribbon for the standard deviation, over the three random seeds. It is clear from the plot above that for the sin task, the rpart test error decreases as the train set size increases (more data helps), whereas for the constant task, featureless is as accurate as rpart for every train set size.
Interactive data viz

The code below can be used to create an interactive data visualization which allows exploring how different functions are learned during different splits.
If you are viewing this in an installed package or on CRAN, then there will be no data viz on this page, but you can view it on: https://tdhock.github.io/2023-12-26-train-sizes-regression/ The interactive data viz consists of two plots:
Simulated classification problems

Whereas the section above focused on regression (output is a real number), in this section we simulate a binary classification problem (output is a factor with two levels).
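A minimal sketch of such a simulation (the feature names and the decision boundary are assumptions, not the vignette's exact code):

```r
library(data.table)
set.seed(1)
N <- 200
class.dt <- data.table(
  x1 = runif(N, -2, 2),
  x2 = runif(N, -2, 2))
## output y is a factor with two levels, determined by a circular boundary.
class.dt[, y := factor(ifelse(x1^2 + x2^2 < 2, "inside", "outside"))]
class.dt
```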
The simulated data table above consists of two input features, along with an output/label to predict.

The table above shows the distribution of the two classes in the output/label column.

The plot above shows how the output depends on the two input features.
In the mlr3 code below, we define a list of learners, our resampling method, and a benchmark grid:
Below we run the learning algorithm for each of the train/test splits defined by our benchmark grid:
Below we compute scores (test error) for each resampling iteration, and show the first row of the result.
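The three steps above (learners, resampling, and grid; running; scoring) can be sketched in one block, assuming the simulated table class.dt with features x1, x2 and label y (hypothetical names):

```r
## Sketch: task, learners, resampling, grid, run, and score.
class.task <- mlr3::TaskClassif$new("simulated", class.dt, target = "y")
class.learner.list <- list(
  mlr3::lrn("classif.rpart"),
  mlr3::lrn("classif.featureless"))
variable_size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
class.bench.grid <- mlr3::benchmark_grid(
  class.task, class.learner.list, variable_size_cv)
class.bench.result <- mlr3::benchmark(class.bench.grid)
class.bench.score <- mlr3resampling::score(class.bench.result)
class.bench.score[1]
```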
The output above has columns which are very similar to the regression example in the previous section. The main difference is the classif.ce column, which is the classification error (proportion of incorrectly predicted labels) on the test set, instead of the mean squared error used for regression.
Finally we plot the test error values below.
It is clear from the plot above that the rpart test error decreases as the train set size increases, whereas the featureless baseline error stays roughly constant.
Exercise for the reader: compute and plot mean and SD for these classification tasks, similar to the plot for the regression tasks in the previous section.

Interactive visualization of data, test error, and splits

The code below can be used to create an interactive data visualization which allows exploring how different functions are learned during different splits.
If you are viewing this in an installed package or on CRAN, then there will be no data viz on this page, but you can view it on: https://tdhock.github.io/2023-12-27-train-sizes-classification/

The interactive data viz consists of two plots:
Conclusion

In this section we have shown how to use mlr3resampling for comparing test error of models trained on different sized train sets.

Session info
|