diff --git a/R/data.table.R b/R/data.table.R index 27c985e44..9c6eaa847 100644 --- a/R/data.table.R +++ b/R/data.table.R @@ -1036,56 +1036,80 @@ replace_dot_alias = function(e) { while(colsub %iscall% "(") colsub = as.list(colsub)[[-1L]] # fix for R-Forge #5190. colsub[[1L]] gave error when it's a symbol. # NB: _unary_ '-', not _binary_ '-' (#5826). Test for '!' length-2 should be redundant but low-cost & keeps code concise. - if (colsub %iscall% c("!", "-") && length(colsub) == 2L) { - negate_sdcols = TRUE - colsub = colsub[[2L]] - } else negate_sdcols = FALSE - # fix for #1216, make sure the parentheses are peeled from expr of the form (((1:4))) - while(colsub %iscall% "(") colsub = as.list(colsub)[[-1L]] - if (colsub %iscall% ':' && length(colsub)==3L && !is.call(colsub[[2L]]) && !is.call(colsub[[3L]])) { - # .SDcols is of the format a:b, ensure none of : arguments is a call data.table(V1=-1L, V2=-2L, V3=-3L)[,.SD,.SDcols=-V2:-V1] #4231 - .SDcols = eval(colsub, setattr(as.list(seq_along(x)), 'names', names_x), parent.frame()) - } else { - if (colsub %iscall% 'patterns') { - patterns_list_or_vector = eval_with_cols(colsub, names_x) - .SDcols = if (is.list(patterns_list_or_vector)) { - # each pattern gives a new filter condition, intersect the end result - Reduce(intersect, patterns_list_or_vector) + try_processSDcols = !(colsub %iscall% c("!", "-") && length(colsub) == 2L) && !(colsub %iscall% ':') && !(colsub %iscall% 'patterns') + if (try_processSDcols) { + sdcols_result = tryCatch({ + .processSDcols( + SDcols_sub = colsub, + SDcols_missing = FALSE, + x = x, + jsub = jsub, + by = substitute(by), + enclos = parent.frame() + ) + }, error = function(e) { + NULL + }) + if (!is.null(sdcols_result)) { + ansvars = sdvars = sdcols_result$ansvars + ansvals = sdcols_result$ansvals + try_processSDcols = TRUE } else { - patterns_list_or_vector + try_processSDcols = FALSE } + } + if (!try_processSDcols) { + if (colsub %iscall% c("!", "-") && length(colsub) == 2L) { + negate_sdcols = TRUE + colsub = colsub[[2L]] + } else negate_sdcols = FALSE + # fix for #1216, make sure the parentheses are peeled from expr of the form (((1:4))) + while(colsub %iscall% "(") colsub = as.list(colsub)[[-1L]] + if (colsub %iscall% ':' && length(colsub)==3L && !is.call(colsub[[2L]]) && !is.call(colsub[[3L]])) { + # .SDcols is of the format a:b, ensure none of : arguments is a call data.table(V1=-1L, V2=-2L, V3=-3L)[,.SD,.SDcols=-V2:-V1] #4231 + .SDcols = eval(colsub, setattr(as.list(seq_along(x)), 'names', names_x), parent.frame()) } else { - .SDcols = eval(colsub, parent.frame(), parent.frame()) - # allow filtering via function in .SDcols, #3950 - if (is.function(.SDcols)) { - .SDcols = lapply(x, .SDcols) - if (any(idx <- lengths(.SDcols) > 1L | vapply_1c(.SDcols, typeof) != 'logical' | vapply_1b(.SDcols, anyNA))) - stopf("When .SDcols is a function, it is applied to each column; the output of this function must be a non-missing boolean scalar signalling inclusion/exclusion of the column. However, these conditions were not met for: %s", brackify(names(x)[idx])) - .SDcols = unlist(.SDcols, use.names = FALSE) + if (colsub %iscall% 'patterns') { + patterns_list_or_vector = eval_with_cols(colsub, names_x) + .SDcols = if (is.list(patterns_list_or_vector)) { + # each pattern gives a new filter condition, intersect the end result + Reduce(intersect, patterns_list_or_vector) + } else { + patterns_list_or_vector + } + } else { + .SDcols = eval(colsub, parent.frame(), parent.frame()) + # allow filtering via function in .SDcols, #3950 + if (is.function(.SDcols)) { + .SDcols = lapply(x, .SDcols) + if (any(idx <- lengths(.SDcols) > 1L | vapply_1c(.SDcols, typeof) != 'logical' | vapply_1b(.SDcols, anyNA))) + stopf("When .SDcols is a function, it is applied to each column; the output of this function must be a non-missing boolean scalar signalling inclusion/exclusion of the column. However, these conditions were not met for: %s", brackify(names(x)[idx])) + .SDcols = unlist(.SDcols, use.names = FALSE) + } } } - } - if (anyNA(.SDcols)) - stopf(".SDcols missing at the following indices: %s", brackify(which(is.na(.SDcols)))) - if (is.logical(.SDcols)) { - if (length(.SDcols)!=length(x)) stopf(".SDcols is a logical vector of length %d but there are %d columns", length(.SDcols), length(x)) - ansvals = which_(.SDcols, !negate_sdcols) - ansvars = sdvars = names_x[ansvals] - } else if (is.numeric(.SDcols)) { - .SDcols = as.integer(.SDcols) - # if .SDcols is numeric, use 'dupdiff' instead of 'setdiff' - if (length(unique(sign(.SDcols))) > 1L) stopf(".SDcols is numeric but has both +ve and -ve indices") - if (any(idx <- abs(.SDcols)>ncol(x) | abs(.SDcols)<1L)) - stopf(".SDcols is numeric but out of bounds [1, %d] at: %s", ncol(x), brackify(which(idx))) - ansvars = sdvars = if (negate_sdcols) dupdiff(names_x[-.SDcols], bynames) else names_x[.SDcols] - ansvals = if (negate_sdcols) setdiff(seq_along(names(x)), c(.SDcols, which(names(x) %chin% bynames))) else .SDcols - } else { - if (!is.character(.SDcols)) stopf(".SDcols should be column numbers or names") - if (!all(idx <- .SDcols %chin% names_x)) - stopf("Some items of .SDcols are not column names: %s", brackify(.SDcols[!idx])) - ansvars = sdvars = if (negate_sdcols) setdiff(names_x, c(.SDcols, bynames)) else .SDcols - # dups = FALSE here. DT[, .SD, .SDcols=c("x", "x")] again doesn't really help with which 'x' to keep (and if '-' which x to remove) - ansvals = chmatch(ansvars, names_x) + if (anyNA(.SDcols)) + stopf(".SDcols missing at the following indices: %s", brackify(which(is.na(.SDcols)))) + if (is.logical(.SDcols)) { + if (length(.SDcols)!=length(x)) stopf(".SDcols is a logical vector of length %d but there are %d columns", length(.SDcols), length(x)) + ansvals = which_(.SDcols, !negate_sdcols) + ansvars = sdvars = names_x[ansvals] + } else if (is.numeric(.SDcols)) { + .SDcols = as.integer(.SDcols) + # if .SDcols is numeric, use 'dupdiff' instead of 'setdiff' + if (length(unique(sign(.SDcols))) > 1L) stopf(".SDcols is numeric but has both +ve and -ve indices") + if (any(idx <- abs(.SDcols)>ncol(x) | abs(.SDcols)<1L)) + stopf(".SDcols is numeric but out of bounds [1, %d] at: %s", ncol(x), brackify(which(idx))) + ansvars = sdvars = if (negate_sdcols) dupdiff(names_x[-.SDcols], bynames) else names_x[.SDcols] + ansvals = if (negate_sdcols) setdiff(seq_along(names(x)), c(.SDcols, which(names(x) %chin% bynames))) else .SDcols + } else { + if (!is.character(.SDcols)) stopf(".SDcols should be column numbers or names") + if (!all(idx <- .SDcols %chin% names_x)) + stopf("Some items of .SDcols are not column names: %s", brackify(.SDcols[!idx])) + ansvars = sdvars = if (negate_sdcols) setdiff(names_x, c(.SDcols, bynames)) else .SDcols + # dups = FALSE here. DT[, .SD, .SDcols=c("x", "x")] again doesn't really help with which 'x' to keep (and if '-' which x to remove) + ansvals = chmatch(ansvars, names_x) + } } } # fix for long standing FR/bug, #495 and #484 diff --git a/R/groupingsets.R b/R/groupingsets.R index f5fc2101f..29105e316 100644 --- a/R/groupingsets.R +++ b/R/groupingsets.R @@ -16,6 +16,50 @@ rollup.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) { groupingsets.data.table(x, by=by, sets=sets, .SDcols=.SDcols, id=id, jj=jj, label=label, enclos = parent.frame()) } +# Helper function to process SDcols +.processSDcols = function(SDcols_sub, SDcols_missing, x, jsub, by, enclos = parent.frame()) { + names_x = names(x) + bysub = substitute(by) + allbyvars = intersect(all.vars(bysub), names_x) + usesSD = ".SD" %chin% all.vars(jsub) + if (!usesSD) { + return(NULL) + } + if (SDcols_missing) { + ansvars = sdvars = setdiff(unique(names_x), union(by, allbyvars)) + ansvals = match(ansvars, names_x) + return(list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals)) + } + sub.result = SDcols_sub + if (sub.result %iscall% "patterns") { + .SDcols = eval_with_cols(sub.result, names_x) + } else { + .SDcols = eval(sub.result, enclos) + } + if (anyNA(.SDcols)) + stopf(".SDcols missing at the following indices: %s", brackify(which(is.na(.SDcols)))) + if (is.character(.SDcols)) { + idx = .SDcols %chin% names_x + if (!all(idx)) + stopf("Some items of .SDcols are not column names: %s", toString(.SDcols[!idx])) + ansvars = sdvars = .SDcols + ansvals = match(ansvars, names_x) + } else if (is.numeric(.SDcols)) { + ansvals = as.integer(.SDcols) + if (any(ansvals < 1L | ansvals > length(names_x))) + stopf(".SDcols contains indices out of bounds") + ansvars = sdvars = names_x[ansvals] + } else if (is.logical(.SDcols)) { + if (length(.SDcols) != length(names_x)) + stopf(".SDcols is a logical vector of length %d but there are %d columns", length(.SDcols), length(names_x)) + ansvals = which(.SDcols) + ansvars = sdvars = names_x[ansvals] + } else { + stopf(".SDcols must be character, numeric, or logical") + } + list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals) +} + cube = function(x, ...) { UseMethod("cube") } @@ -29,6 +73,17 @@ cube.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) { stopf("Argument 'id' must be a logical scalar.") if (missing(j)) stopf("Argument 'j' is required") + # Implementing NSE in cube using the helper, .processSDcols + jj = substitute(j) + sdcols_result = .processSDcols(SDcols_sub = substitute(.SDcols), SDcols_missing = missing(.SDcols), x = x, jsub = jj, by = by, enclos = parent.frame()) + if (is.null(sdcols_result)) { + .SDcols = NULL + } else { + ansvars = sdcols_result$ansvars + sdvars = sdcols_result$sdvars + ansvals = sdcols_result$ansvals + .SDcols = sdvars + } # generate grouping sets for cube - power set: http://stackoverflow.com/a/32187892/2490497 n = length(by) keepBool = sapply(2L^(seq_len(n)-1L), function(k) rep(c(FALSE, TRUE), times=k, each=((2L^n)/(2L*k)))) diff --git a/inst/tests/tests.Rraw b/inst/tests/tests.Rraw index fcf78e9f3..23ef7ff12 100644 --- a/inst/tests/tests.Rraw +++ b/inst/tests/tests.Rraw @@ -11468,6 +11468,11 @@ sets = local({ by=c("color","year","status") lapply(length(by):0, function(i) by[0:i]) }) +test(1750.25, + cube(copy(dt), j = lapply(.SD, mean), by = "color", .SDcols = 4, id=TRUE), + groupingsets(dt, j = lapply(.SD, mean), by = "color", .SDcols = "amount", + sets = list("color", character(0)), id = TRUE) +) test(1750.31, rollup(dt, j = c(list(cnt=.N), lapply(.SD, sum)), by = c("color","year","status"), id=TRUE), groupingsets(dt, j = c(list(cnt=.N), lapply(.SD, sum)), by = c("color","year","status"), sets=sets, id=TRUE) @@ -11503,6 +11508,41 @@ test(1750.34, character(0)), id = TRUE) ) +test(1750.35, + cube(dt, j = lapply(.SD, sum), by = c("color","year","status"), id=TRUE, .SDcols=patterns("value")), + groupingsets(dt, j = lapply(.SD, sum), by = c("color","year","status"), .SDcols = "value", + sets = list(c("color","year","status"), + c("color","year"), + c("color","status"), + "color", + c("year","status"), + "year", + "status", + character(0)), + id = TRUE) +) +test(1750.36, + cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c("value", "BADCOL")), + error = "Some items of \\.SDcols are not column names" +) +test(1750.37, + cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c(TRUE, FALSE)), + error = "\\.SDcols is a logical vector of length" +) +test(1750.38, + cube(dt, j = lapply(.SD, mean), by = "color", .SDcols = c(FALSE, FALSE, FALSE, TRUE, FALSE), id=TRUE), + groupingsets(dt, j = lapply(.SD, mean), by = "color", .SDcols = "amount", + sets = list("color", character(0)), + id = TRUE) +) +test(1750.39, + cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = list("amount")), + error = ".SDcols must be character, numeric, or logical" +) +test(1750.40, + cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = c(1, 99)), + error = "out of bounds" +) # grouping sets with integer64 if (test_bit64) { set.seed(26)