diff --git a/R/PipeOp.R b/R/PipeOp.R index 3da95e78..e98cbfa6 100644 --- a/R/PipeOp.R +++ b/R/PipeOp.R @@ -61,11 +61,11 @@ PipeOp = R6::R6Class("PipeOp", private$.id = id private$.param_set = param_set #FIXME: we really need a function in paradox now to get defaults - private$.param_vals = param_set$data$default - names(private$.param_vals) = param_set$ids - private$.param_vals = insert_named(private$.param_vals, param_vals) - if (!param_set$test(private$.param_vals)) { - stop("Parameters out of bounds") + private$.param_vals = param_set$defaults + if(!is.null(private$.param_vals)) { + if (!param_set$test(private$.param_vals)) { + stop("Parameters out of bounds") + } } }, diff --git a/R/PipeOpDT.R b/R/PipeOpDT.R index d292add1..9b14a2b6 100644 --- a/R/PipeOpDT.R +++ b/R/PipeOpDT.R @@ -50,10 +50,10 @@ PipeOpDT = R6Class("PipeOpDT", list(TaskClassif$new(id = task$id, backend = db, target = tn)) }, - predict = function() { - assert_list(self$inputs, len = 1L, type = "Task") + predict = function(inputs) { + assert_list(inputs, len = 1L, type = "Task") assert_function(self$predict_dt, args = "newdt") - task = self$inputs[[1L]] + task = inputs[[1L]] fn = task$feature_names d = task$data() @@ -66,7 +66,9 @@ PipeOpDT = R6Class("PipeOpDT", d[, (colnames(dt)) := dt] d[, "..row_id" := seq_len(nrow(d))] - list(task$overwrite(d)) + db = DataBackendDataTable$new(d, primary_key = task$backend$primary_key) + tn = task$target_names + list(TaskClassif$new(id = task$id, backend = db, target = tn)) } ) ) diff --git a/R/PipeOpLearner.R b/R/PipeOpLearner.R index ea7db536..543687a2 100644 --- a/R/PipeOpLearner.R +++ b/R/PipeOpLearner.R @@ -40,9 +40,10 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp, # private$.result }, - predict2 = function() { + predict = function(inputs) { assert_list(inputs, len = 1L, type = "Task") - predict(self$state) + task = inputs[[1L]] + list(self$state$learner$predict(task)) } ), diff --git a/tests/testthat/test_usecases.R b/tests/testthat/test_usecases.R index 637894bf..5714e7c4 100644 --- a/tests/testthat/test_usecases.R +++ b/tests/testthat/test_usecases.R @@ -4,10 +4,17 @@ test_that("scale + pca", { task = mlr_tasks$get("iris") g = PipeOpScale$new() %>>% PipeOpPCA$new() res1 = g$train(task) - assert_list(res1) res2 = g$predict(task) }) +test_that("scale + pca + PipeOpLearner", { + task = mlr_tasks$get("iris") + g = PipeOpScale$new() %>>% + PipeOpPCA$new() %>>% + PipeOpLearner$new(mlr_learners$get("classif.rpart")) + res1 = g$train(task) + res2 = g$predict(task) +})