diff --git a/R/data.table.R b/R/data.table.R index a989538b1..04584e10b 100644 --- a/R/data.table.R +++ b/R/data.table.R @@ -521,6 +521,58 @@ replace_dot_alias = function(e) { list(GForce=GForce, jsub=jsub, jvnames=jvnames) } +# Helper function to process SDcols +.processSDcols = function(SDcols_sub, SDcols_missing, x, jsub, by, enclos = parent.frame(), bynames = character(0L)) { + names_x = names(x) + allbyvars = intersect(all.vars(by), names_x) + usesSD = ".SD" %chin% all.vars(jsub) + if (!usesSD) { + return(NULL) + } + if (SDcols_missing) { + ansvars = sdvars = setdiff(unique(names_x), union(by, allbyvars)) + ansvals = match(ansvars, names_x) + return(list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals)) + } + sub.result = SDcols_sub + if (sub.result %iscall% ':' && length(sub.result) == 3L) { + return(NULL) + } + if (sub.result %iscall% c("!", "-") && length(sub.result) == 2L) { + negate_sdcols = TRUE + sub.result = sub.result[[2L]] + } else negate_sdcols = FALSE + if (sub.result %iscall% "patterns") { + .SDcols = eval_with_cols(sub.result, names_x) + } else { + .SDcols = eval(sub.result, enclos) + } + if (!is.character(.SDcols) && !is.numeric(.SDcols) && !is.logical(.SDcols)) { + return(NULL) + } + if (anyNA(.SDcols)) + stopf(".SDcols missing at the following indices: %s", brackify(which(is.na(.SDcols)))) + if (is.character(.SDcols)) { + idx = .SDcols %chin% names_x + if (!all(idx)) + stopf("Some items of .SDcols are not column names: %s", toString(.SDcols[!idx])) + ansvars = sdvars = .SDcols + ansvals = match(ansvars, names_x) + } else if (is.numeric(.SDcols)) { + ansvals = as.integer(.SDcols) + if (length(unique(sign(.SDcols))) > 1L) stopf(".SDcols is numeric but has both +ve and -ve indices") + if (any(idx <- abs(.SDcols) > ncol(x) | abs(.SDcols) < 1L)) stopf(".SDcols is numeric but out of bounds [1, %d] at: %s", ncol(x), brackify(which(idx))) + ansvals = if (negate_sdcols) setdiff(seq_along(names(x)), c(.SDcols, which(names(x) %chin% bynames))) else as.integer(.SDcols) + ansvars = sdvars = names_x[ansvals] + } else if (is.logical(.SDcols)) { + if (length(.SDcols) != length(names_x)) + stopf(".SDcols is a logical vector of length %d but there are %d columns", length(.SDcols), length(names_x)) + ansvals = which(.SDcols) + ansvars = sdvars = names_x[ansvals] + } + list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals) +} + "[.data.table" = function(x, i, j, by, keyby, with=TRUE, nomatch=NA, mult="all", roll=FALSE, rollends=if (roll=="nearest") c(TRUE,TRUE) else if (roll>=0.0) c(FALSE,TRUE) else c(TRUE,FALSE), which=FALSE, .SDcols, verbose=getOption("datatable.verbose"), allow.cartesian=getOption("datatable.allow.cartesian"), drop=NULL, on=NULL, env=NULL, showProgress=getOption("datatable.showProgress", interactive())) { # ..selfcount <<- ..selfcount+1 # in dev, we check no self calls, each of which doubles overhead, or could @@ -1413,10 +1465,28 @@ replace_dot_alias = function(e) { while(colsub %iscall% "(") colsub = as.list(colsub)[[-1L]] # fix for R-Forge #5190. colsub[[1L]] gave error when it's a symbol. # NB: _unary_ '-', not _binary_ '-' (#5826). Test for '!' length-2 should be redundant but low-cost & keeps code concise. - if (colsub %iscall% c("!", "-") && length(colsub) == 2L) { - negate_sdcols = TRUE - colsub = colsub[[2L]] - } else negate_sdcols = FALSE + try_processSDcols = !(colsub %iscall% c("!", "-") && length(colsub) == 2L) && !(colsub %iscall% ':') && !(colsub %iscall% 'patterns') + if (try_processSDcols) { + sdcols_result = .processSDcols( + SDcols_sub = colsub, + SDcols_missing = FALSE, + x = x, + jsub = jsub, + by = substitute(by), + enclos = parent.frame() + ) + if (!is.null(sdcols_result)) { + ansvars = sdvars = sdcols_result$ansvars + ansvals = sdcols_result$ansvals + } else { + try_processSDcols = FALSE + } + } + if (!try_processSDcols) { + if (colsub %iscall% c("!", "-") && length(colsub) == 2L) { + negate_sdcols = TRUE + colsub = colsub[[2L]] + } else negate_sdcols = FALSE # fix for #1216, make sure the parentheses are peeled from expr of the form (((1:4))) while(colsub %iscall% "(") colsub = as.list(colsub)[[-1L]] if (colsub %iscall% ':' && length(colsub)==3L && !is.call(colsub[[2L]]) && !is.call(colsub[[3L]])) { @@ -1465,6 +1535,7 @@ replace_dot_alias = function(e) { ansvals = chmatch(ansvars, names_x) } } + } # fix for long standing FR/bug, #495 and #484 allcols = c(names_x, xdotprefix, names_i, idotprefix) non_sdvars = setdiff(intersect(av, allcols), c(bynames, ansvars)) diff --git a/R/groupingsets.R b/R/groupingsets.R index 885a64830..e31284831 100644 --- a/R/groupingsets.R +++ b/R/groupingsets.R @@ -29,6 +29,17 @@ cube.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) { stopf("Argument 'id' must be a logical scalar.") if (missing(j)) stopf("Argument 'j' is required") + # Implementing NSE in cube using the helper, .processSDcols + jj = substitute(j) + sdcols_result = .processSDcols(SDcols_sub = substitute(.SDcols), SDcols_missing = missing(.SDcols), x = x, jsub = jj, by = by, enclos = parent.frame()) + if (is.null(sdcols_result)) { + .SDcols = NULL + } else { + ansvars = sdcols_result$ansvars + sdvars = sdcols_result$sdvars + ansvals = sdcols_result$ansvals + .SDcols = sdvars + } # generate grouping sets for cube - power set: http://stackoverflow.com/a/32187892/2490497 n = length(by) keepBool = sapply(2L^(seq_len(n)-1L), function(k) rep(c(FALSE, TRUE), times=k, each=((2L^n)/(2L*k)))) diff --git a/inst/tests/tests.Rraw b/inst/tests/tests.Rraw index 443487c6a..95dcbb867 100644 --- a/inst/tests/tests.Rraw +++ b/inst/tests/tests.Rraw @@ -11114,6 +11114,36 @@ test(1750.34, character(0)), id = TRUE) ) +test(1750.35, + cube(dt, j = lapply(.SD, sum), by = c("color","year","status"), id=TRUE, .SDcols=patterns("value")), + groupingsets(dt, j = lapply(.SD, sum), by = c("color","year","status"), .SDcols = "value", + sets = list(c("color","year","status"), + c("color","year"), + c("color","status"), + "color", + c("year","status"), + "year", + "status", + character(0)), + id = TRUE) +) +test(1750.36, + names(cube(dt, j = lapply(.SD, sum), by = "color", + .SDcols = -c(1L, 2L, 3L), id = TRUE)), + c("grouping", "color", "amount", "value") +) +test(1750.37, + cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c(TRUE, FALSE)), + error = "\\.SDcols is a logical vector of length" +) +test(1750.38, + cube(dt, j = lapply(.SD, mean), by = "color", .SDcols = c(FALSE, FALSE, FALSE, TRUE, FALSE), id=TRUE)[grouping==0L, .(color, amount)], + dt[, lapply(.SD, mean), by = "color", .SDcols = "amount"] +) +test(1750.39, + cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = c(1, 99)), + error = "out of bounds" +) # grouping sets with integer64 if (test_bit64) { set.seed(26) @@ -11159,6 +11189,27 @@ if (test_bit64) { } # end Grouping Sets +# extra cube tests +test(1750.49, + cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c(NA_character_, "amount")), + error = "\\.SDcols missing at the following indices: \\[1\\]" +) +test(1750.50, + cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c(4L, 5L), id=TRUE)[grouping==0L, .(year, amount, value)], + dt[, lapply(.SD, sum), by = "year", .SDcols = c("amount", "value")] +) +test(1750.51, + data.table:::.processSDcols(quote(a:b), FALSE, dt, quote(lapply(.SD, sum)), "color"), + NULL +) +test(1750.52, + cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c(1L, -2L)), + error = "\\.SDcols is numeric but has both \\+ve and -ve indices" +) +test(1750.53, + cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c("amount", "nonexistent")), + error = "Some items of .SDcols are not column names: nonexistent" +) # for completeness, added test for NA problem to close #1837. DT = data.table(x=NA) test(1751.1, capture.output(fwrite(DT, verbose=FALSE)), c("x","")) @@ -13085,7 +13136,10 @@ test(1956.3, DT[, .SD, .SDcols = NA_real_], error = 'missing at the following') test(1956.4, DT[, .SD, .SDcols = 2L], error = 'out of bounds.*1.*1.*at') test(1956.5, DT[, .SD, .SDcols = 'b'], error = 'not column names') test(1956.6, DT[, .SD, .SDcols = 3i], error = '.SDcols should be column numbers or names') - +test(1956.7, DT[, .SD, .SDcols = -c(1L, NA_integer_)], error = 'missing at the following') +test(1956.8, DT[, .SD, .SDcols = 1:-1], error = 'both.*ve and.*ve') +test(1956.9, DT[, .SD, .SDcols = 1:99], error = 'out of bounds') +test(1956.91, DT[, .SD, .SDcols = -c("a", "nonexistent")], error = 'not column names') # added brackify to utils for #3116 test(1957.1, brackify(1:3), '[1, 2, 3]') test(1957.2, brackify(1:11), "[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ...]")