Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export(createDefaultExecuteSettings)
export(createDefaultSplitSetting)
export(createExecuteSettings)
export(createExistingSplitSettings)
export(createOutcomeLimitedSplitSettings)
export(createFeatureEngineeringSettings)
export(createGlmModel)
export(createHyperparameterSettings)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ PatientLevelPrediction 6.6.0
- Expanded imputation support and hardened the missing-indicator and predictive mean matching workflow (#622).
- Added support for using logits / linear predictors in rank-based metrics (#615).
- Persisted hyperparameter settings and model names in the results data model to improve downstream model identification and viewing (#633, #632).
- Added outcome-limited split settings for large data sets where model training should use a target number of outcome-positive rows (#396).

## Bug fixes
- Fixed cross-validation prediction generation for iterative hard thresholding models by reusing fitted per-covariate prior variances.
Expand Down
280 changes: 280 additions & 0 deletions R/DataSplitting.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,75 @@ createExistingSplitSettings <- function(splitIds) {
return(splitSettings)
}

#' Create outcome-limited split settings
#'
#' @description
#' Create split settings for large data sets in which the training set is
#' capped at a target number of outcome-positive rows. If the population does
#' not have more outcome rows than the cap, the ordinary split of 25 percent
#' test and 75 percent train is used instead.
#'
#' John et al. found that, for LASSO logistic regression models fit on large
#' observational health care data sets, learning curves often plateau after a
#' certain number of outcome events. When a study has many more outcome events
#' than are needed to approach the full-data model performance, this split can
#' be useful to reduce computational cost while preserving a separate test set.
#'
#' In row-level stratified mode, the cap can be applied exactly. In
#' subject-level mode, whole subjects are assigned to either training or
#' testing, so the training outcome count can fall below the cap, or exceed it
#' when the first sampled subject alone has more outcome-positive rows than the
#' cap. If the selected subjects cannot leave outcome-positive rows for
#' testing, splitting fails with an actionable error.
#'
#' @param maxTrainingOutcomes Maximum number of outcome-positive rows to include
#' in the training set when the cap is triggered. There is no universal default:
#' the adequate number of outcomes depends on the prediction problem, database,
#' model, and acceptable performance loss. In subject-level mode this is
#' approximate because all rows for selected subjects are kept together.
#' @param splitSeed A seed to use when splitting the data for reproducibility.
#' @param nfold An integer > 1 specifying the number of folds used in cross
#' validation. `maxTrainingOutcomes` must allow at least 5 outcomes per fold.
#' @param type Either `stratified` for row-level splitting or `subject` for
#' subject-level splitting.
#'
#' @return An object of class `splitSettings`: a list holding
#' `maxTrainingOutcomes`, `seed`, `nfold` and `type`, with the attribute `fun`
#' set to `"outcomeLimitedSplitter"`.
#' @examples
#' createOutcomeLimitedSplitSettings(maxTrainingOutcomes = 1000, splitSeed = 42)
#' @references John LH, Kors JA, Reps JM, Ryan PB, Rijnbeek PR. Logistic
#' regression models for patient-level prediction based on massive observational
#' data: Do we need all data? International Journal of Medical Informatics.
#' 2022;163:104762. \doi{10.1016/j.ijmedinf.2022.104762}
#' @export
createOutcomeLimitedSplitSettings <- function(maxTrainingOutcomes,
                                              splitSeed = sample(100000, 1),
                                              nfold = 3,
                                              type = "stratified") {
  # The cap has to be a single positive whole number.
  checkIsClass(maxTrainingOutcomes, c("numeric", "integer"))
  checkHigher(maxTrainingOutcomes, 0)
  checkIsWholeNumber(maxTrainingOutcomes)

  checkIsClass(splitSeed, c("numeric", "integer"))

  # Cross validation needs a whole number of folds greater than one.
  checkIsClass(nfold, c("numeric", "integer"))
  checkHigher(nfold, 1)
  checkIsWholeNumber(nfold)

  # Refuse caps that would leave fewer than 5 outcomes per fold.
  outcomesPerFold <- floor(maxTrainingOutcomes / nfold)
  if (outcomesPerFold < 5) {
    stop("Insufficient maxTrainingOutcomes for chosen nfold value, please reduce nfold or increase the cap")
  }

  checkIsClass(type, "character")
  checkInStringVector(type, c("stratified", "subject"))

  # The "fun" attribute names the splitter that consumes these settings.
  structure(
    list(
      maxTrainingOutcomes = maxTrainingOutcomes,
      seed = splitSeed,
      nfold = nfold,
      type = type
    ),
    fun = "outcomeLimitedSplitter",
    class = "splitSettings"
  )
}


#' Split the plpData into test/train sets using a splitting settings of class
#' \code{splitSettings}
Expand Down Expand Up @@ -555,6 +624,217 @@ subjectSplitter <- function(population, splitSettings) {
return(split)
}

# Entry point for the outcome-limited split.
#
# population:    data frame with at least the columns rowId and outcomeCount
#                (subjectId is also needed when type is "subject").
# splitSettings: list with maxTrainingOutcomes, nfold, seed and type, as
#                created by createOutcomeLimitedSplitSettings().
#
# When the number of outcome-positive rows does not exceed the cap, a warning
# is logged and the default 25% test / 75% train splitter for the requested
# type is used. Otherwise the work is handed to the subject-level or row-level
# outcome-limited splitter. Returns whatever the chosen splitter returns.
#
# Side effect: calls set.seed() when a seed is supplied, which changes the
# state of the global random number generator.
outcomeLimitedSplitter <- function(population, splitSettings) {
  checkColumnNames(population, c("rowId", "outcomeCount"))

  maxTrainingOutcomes <- splitSettings$maxTrainingOutcomes
  nfold <- splitSettings$nfold
  seed <- splitSettings$seed
  type <- splitSettings$type

  if (!is.null(seed)) {
    set.seed(seed)
  }

  # Outcomes are counted per row: any row with outcomeCount > 0 counts once.
  totalOutcomes <- sum(population$outcomeCount > 0)
  if (totalOutcomes <= maxTrainingOutcomes) {
    ParallelLogger::logWarn(
      paste0(
        "Outcome-limited split did not trigger because the population has ",
        totalOutcomes,
        " outcome rows, which is not greater than maxTrainingOutcomes = ",
        maxTrainingOutcomes,
        ". Using the default 25% test and 75% train split."
      )
    )
    fallbackSettings <- createDefaultSplitSetting(
      testFraction = 0.25,
      trainFraction = 0.75,
      splitSeed = seed,
      nfold = nfold,
      type = type
    )
    # The default settings name their splitter in the "fun" attribute.
    # do.call() resolves that name from this function's scope directly, so no
    # text has to be parsed and evaluated as code.
    fallbackFun <- attr(fallbackSettings, "fun")
    return(do.call(
      fallbackFun,
      list(
        population = population,
        splitSettings = fallbackSettings
      )
    ))
  }

  if (type == "subject") {
    return(outcomeLimitedSubjectSplitter(population, splitSettings))
  }
  outcomeLimitedRowSplitter(population, splitSettings)
}

# Row-level outcome-limited split.
#
# Randomly keeps up to maxTrainingOutcomes outcome-positive rows for training
# and the same fraction of the non-outcome rows; every other row goes to the
# test set. Returns a data.frame with columns rowId and index, sorted by
# descending rowId, where index is -1 for test rows and the fold number for
# training rows.
outcomeLimitedRowSplitter <- function(population, splitSettings) {
  outcomeCap <- splitSettings$maxTrainingOutcomes
  foldCount <- splitSettings$nfold

  # Shuffle rows by sorting on one fresh uniform draw per row.
  shuffleRows <- function(rows) {
    rows %>%
      dplyr::mutate(.randomOrder = stats::runif(dplyr::n())) %>%
      dplyr::arrange(.data$.randomOrder)
  }

  # Outcome rows are shuffled before non-outcome rows so the random number
  # stream is consumed in a fixed order.
  positives <- shuffleRows(dplyr::filter(population, .data$outcomeCount > 0))
  negatives <- shuffleRows(dplyr::filter(population, .data$outcomeCount == 0))

  trainPositiveIds <- positives %>%
    dplyr::slice_head(n = outcomeCap) %>%
    dplyr::pull(.data$rowId)

  # Keep the same share of non-outcome rows as of outcome rows.
  keptFraction <- length(trainPositiveIds) / nrow(positives)
  trainNegativeCount <- floor(nrow(negatives) * keptFraction)
  trainNegativeIds <- negatives %>%
    dplyr::slice_head(n = trainNegativeCount) %>%
    dplyr::pull(.data$rowId)

  # Folds are assigned separately within outcome and non-outcome rows.
  foldAssignments <- dplyr::bind_rows(
    createRowFoldSplit(trainPositiveIds, foldCount),
    createRowFoldSplit(trainNegativeIds, foldCount)
  )

  # Rows without a fold assignment were not picked for training: mark as test.
  joined <- population %>%
    dplyr::select("rowId") %>%
    dplyr::left_join(foldAssignments, by = "rowId")
  split <- joined %>%
    dplyr::mutate(index = dplyr::coalesce(.data$index, -1L)) %>%
    dplyr::arrange(dplyr::desc(.data$rowId)) %>%
    as.data.frame()

  logOutcomeLimitedSplit(
    split = split,
    splitSettings = splitSettings,
    trainingOutcomeRows = length(trainPositiveIds)
  )
  split
}

# Subject-level outcome-limited split.
#
# All rows of a subject end up on the same side of the split. Subjects with
# outcomes are shuffled and selected until the cap is reached; the same
# fraction (measured in outcome rows) of the subjects without outcomes is added
# to training. Returns a data.frame with columns rowId and index, sorted by
# descending rowId, where index is -1 for test rows and the fold number for
# training rows.
outcomeLimitedSubjectSplitter <- function(population, splitSettings) {
  checkColumnNames(population, "subjectId")
  foldCount <- splitSettings$nfold

  # Shuffle subjects by sorting on one uniform draw each, then drop the draw.
  shuffleSubjects <- function(subjects) {
    subjects %>%
      dplyr::mutate(.randomOrder = stats::runif(dplyr::n())) %>%
      dplyr::arrange(.data$.randomOrder) %>%
      dplyr::select(-".randomOrder")
  }

  # Number of outcome-positive rows per subject.
  perSubject <- population %>%
    dplyr::group_by(.data$subjectId) %>%
    dplyr::summarise(outcomes = sum(.data$outcomeCount > 0), .groups = "drop")

  # Outcome subjects are shuffled before the others so the random number
  # stream is consumed in a fixed order.
  outcomeSubjects <- shuffleSubjects(
    dplyr::filter(perSubject, .data$outcomes > 0)
  )
  nonOutcomeSubjects <- dplyr::pull(
    shuffleSubjects(dplyr::filter(perSubject, .data$outcomes == 0)),
    .data$subjectId
  )

  trainOutcomeSubjects <- selectOutcomeLimitedSubjects(
    outcomeSubjects = outcomeSubjects,
    maxTrainingOutcomes = splitSettings$maxTrainingOutcomes
  )
  # Stops when the selection would leave no outcome rows for testing.
  checkOutcomeLimitedSubjectSelection(
    outcomeSubjects = outcomeSubjects,
    trainOutcomeSubjects = trainOutcomeSubjects
  )

  trainingOutcomeRows <- sum(trainOutcomeSubjects$outcomes)
  keptFraction <- trainingOutcomeRows / sum(outcomeSubjects$outcomes)
  trainNonOutcomeCount <- floor(length(nonOutcomeSubjects) * keptFraction)
  trainNonOutcomeSubjects <- nonOutcomeSubjects[seq_len(trainNonOutcomeCount)]

  # Folds are assigned separately within outcome and non-outcome subjects.
  subjectFolds <- dplyr::bind_rows(
    createSubjectFoldSplit(trainOutcomeSubjects$subjectId, foldCount),
    createSubjectFoldSplit(trainNonOutcomeSubjects, foldCount)
  )

  # Subjects without a fold assignment were not selected: mark rows as test.
  joined <- population %>%
    dplyr::select("rowId", "subjectId") %>%
    dplyr::left_join(subjectFolds, by = "subjectId")
  split <- joined %>%
    dplyr::mutate(index = dplyr::coalesce(.data$index, -1L)) %>%
    dplyr::select("rowId", "index") %>%
    dplyr::arrange(dplyr::desc(.data$rowId)) %>%
    as.data.frame()

  logOutcomeLimitedSplit(
    split = split,
    splitSettings = splitSettings,
    trainingOutcomeRows = trainingOutcomeRows
  )
  split
}

# Guard for the subject-level split: stops with an error when the subjects
# chosen for training hold every outcome-positive row, which would leave the
# test set without outcomes. Both arguments are data frames with an `outcomes`
# column holding the outcome row count per subject.
checkOutcomeLimitedSubjectSelection <- function(outcomeSubjects, trainOutcomeSubjects) {
  allOutcomeRows <- sum(outcomeSubjects$outcomes)
  trainOutcomeRows <- sum(trainOutcomeSubjects$outcomes)
  testOutcomeRows <- allOutcomeRows - trainOutcomeRows
  if (testOutcomeRows <= 0) {
    problem <- paste0(
      "Outcome-limited subject split would leave no outcome-positive rows in the test set. ",
      "Reduce maxTrainingOutcomes, use type = 'stratified', or use a population with more outcome-positive subjects."
    )
    stop(problem)
  }
}

# Greedily pick subjects, in the order given, whose outcome rows fit under the
# cap.
#
# outcomeSubjects:     data frame with an `outcomes` column (outcome rows per
#                      subject); the row order decides the selection order.
# maxTrainingOutcomes: cap on the summed `outcomes` of the selected subjects.
#
# A subject that would push the running total over the cap is skipped and
# later, smaller subjects are still considered. If nothing has been counted
# yet and a subject alone exceeds the cap, that subject is selected on its own
# and the search stops. Returns the selected rows of outcomeSubjects.
selectOutcomeLimitedSubjects <- function(outcomeSubjects, maxTrainingOutcomes) {
  subjectCount <- nrow(outcomeSubjects)
  keep <- logical(subjectCount)
  runningTotal <- 0L
  position <- 0L
  while (position < subjectCount) {
    position <- position + 1L
    candidate <- outcomeSubjects$outcomes[position]

    # Nothing counted yet and this subject alone is over the cap: take only it.
    if (runningTotal == 0L && candidate > maxTrainingOutcomes) {
      keep[position] <- TRUE
      break
    }

    # Take the subject only when it still fits under the cap.
    if (runningTotal + candidate <= maxTrainingOutcomes) {
      keep[position] <- TRUE
      runningTotal <- runningTotal + candidate
    }

    # Cap reached exactly: no further subject can fit.
    if (runningTotal == maxTrainingOutcomes) {
      break
    }
  }
  outcomeSubjects[keep, , drop = FALSE]
}

# Internal helpers for the outcome-limited splitters. Existing splitters keep
# their fold assignment logic to avoid changing legacy rounding behavior here.
# Pair each row id with a cross-validation fold number (1..nfold), in the
# order the ids are given.
createRowFoldSplit <- function(rowIds, nfold) {
  foldNumbers <- createFoldIndexes(length(rowIds), nfold)
  data.frame(
    rowId = rowIds,
    index = foldNumbers
  )
}

# Pair each subject id with a cross-validation fold number (1..nfold), in the
# order the ids are given.
createSubjectFoldSplit <- function(subjectIds, nfold) {
  foldNumbers <- createFoldIndexes(length(subjectIds), nfold)
  data.frame(
    subjectId = subjectIds,
    index = foldNumbers
  )
}

# Build a vector of n fold numbers drawn from 1..nfold.
#
# Each fold first receives floor(n / nfold) consecutive positions (fold 1,
# then fold 2, ...); the n %% nfold remaining positions are then given to
# folds 1, 2, ... one each. Returns integer(0) when n is 0.
createFoldIndexes <- function(n, nfold) {
  if (n == 0) {
    return(integer(0))
  }
  perFold <- floor(n / nfold)
  remainder <- n %% nfold

  # Equal-sized consecutive blocks, one per fold.
  blocks <- if (perFold > 0) {
    rep(seq_len(nfold), each = perFold)
  } else {
    integer(0)
  }

  # Positions that did not fit in the blocks go to the first folds.
  if (remainder > 0) {
    c(blocks, seq_len(remainder))
  } else {
    blocks
  }
}

# Log a one-line summary of an outcome-limited split.
#
# split:               data frame with an `index` column (negative = test row,
#                      positive = training fold number).
# splitSettings:       settings list; nfold and maxTrainingOutcomes are read.
# trainingOutcomeRows: number of outcome-positive rows placed in training.
#
# The summary reports test and train row counts, the training outcome rows,
# the cap, and the size of every fold (0 for folds that received no rows).
logOutcomeLimitedSplit <- function(split, splitSettings, trainingOutcomeRows) {
  # Row counts for the folds that actually occur in the split.
  observedFoldSizes <- split %>%
    dplyr::filter(.data$index > 0) %>%
    dplyr::group_by(.data$index) %>%
    dplyr::summarise(n = dplyr::n(), .groups = "drop")

  # Join against every requested fold so empty folds are listed with size 0.
  everyFold <- data.frame(index = seq_len(splitSettings$nfold))
  foldSizesTrain <- observedFoldSizes %>%
    dplyr::right_join(everyFold, by = "index") %>%
    dplyr::arrange(.data$index) %>%
    dplyr::mutate(n = dplyr::coalesce(.data$n, 0L)) %>%
    dplyr::pull(.data$n)

  testRows <- sum(split$index < 0)
  trainRows <- sum(split$index > 0)
  summaryLine <- paste0(
    "Data split into ",
    testRows,
    " test cases and ",
    trainRows,
    " train cases with ",
    trainingOutcomeRows,
    " training outcome rows (maxTrainingOutcomes = ",
    splitSettings$maxTrainingOutcomes,
    "; folds: ",
    toString(foldSizesTrain),
    ")"
  )
  ParallelLogger::logInfo(summaryLine)
}


# this is not needed for each function - just the setting where is it not used? (fix in future)
checkInputsSplit <- function(test, train, nfold, seed) {
Expand Down
12 changes: 11 additions & 1 deletion R/ParamChecks.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,19 @@ checkIsClass <- function(parameter, classes) {
return(TRUE)
}

# Validate that `parameter` is one finite, non-missing numeric value without a
# fractional part. On failure the problem is logged and an error is raised,
# both naming the expression the caller passed in. Returns TRUE on success.
checkIsWholeNumber <- function(parameter) {
  # Capture the caller's expression so the message names the argument.
  name <- deparse(substitute(parameter))
  isSingleWholeNumber <- length(parameter) == 1 &&
    is.numeric(parameter) &&
    !is.na(parameter) &&
    is.finite(parameter) &&
    parameter == floor(parameter)
  if (!isSingleWholeNumber) {
    problem <- paste0(name, " must be a single whole number")
    ParallelLogger::logError(problem)
    stop(problem)
  }
  TRUE
}

checkInStringVector <- function(parameter, values) {
name <- deparse(substitute(parameter))
if (!parameter %in% values) {
if (length(parameter) != 1 || is.na(parameter) || !parameter %in% values) {
ParallelLogger::logError(paste0(name, " should be ", paste0(as.character(values), collapse = " or ")))
stop(paste0(name, " has incorrect value"))
}
Expand Down
58 changes: 58 additions & 0 deletions man/createOutcomeLimitedSplitSettings.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading