Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 58 additions & 29 deletions src/xgboost/xgb_regressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,38 +248,47 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
}

// A split is only valid if it results in a positive gain.
if *best_split_score > 0.0 {
let mut left_idxs = Vec::new();
let mut right_idxs = Vec::new();
for idx in idxs.iter() {
if data.get((*idx, *best_feature_idx)).to_f64().unwrap() <= *best_threshold {
left_idxs.push(*idx);
} else {
right_idxs.push(*idx);
}
if *best_split_score <= 0.0 {
return;
}

let mut left_idxs = Vec::new();
let mut right_idxs = Vec::new();
for idx in idxs.iter() {
if data.get((*idx, *best_feature_idx)).to_f64().unwrap() <= *best_threshold {
left_idxs.push(*idx);
} else {
right_idxs.push(*idx);
}
}

*left = Some(Box::new(TreeRegressor::fit(
data,
g,
h,
&left_idxs,
max_depth - 1,
min_child_weight,
lambda,
gamma,
)));
*right = Some(Box::new(TreeRegressor::fit(
data,
g,
h,
&right_idxs,
max_depth - 1,
min_child_weight,
lambda,
gamma,
)));
if left_idxs.is_empty() || right_idxs.is_empty() {
// A degenerate split where all samples land on one side. This can happen when feature
// values are large enough that `(x_i + x_i_next) / 2.0` overflows to +inf, so that
// all samples satisfy `<= +inf` and `right_idxs` is empty.
return;
}

*left = Some(Box::new(TreeRegressor::fit(
data,
g,
h,
&left_idxs,
max_depth - 1,
min_child_weight,
lambda,
gamma,
)));
*right = Some(Box::new(TreeRegressor::fit(
data,
g,
h,
&right_idxs,
max_depth - 1,
min_child_weight,
lambda,
gamma,
)));
}

/// Iterates through a single feature to find the best possible split point.
Expand Down Expand Up @@ -733,6 +742,26 @@ mod tests {
assert!((tree.right.unwrap().value - (-0.833333333)).abs() < 1e-9);
}

/// Regression test for the degenerate-split guard in `insert_child_nodes`.
///
/// Each feature value is finite on its own, but the sum of the two exceeds `f64::MAX`,
/// so a threshold computed as the midpoint `(a + b) / 2.0` overflows to `+inf`.
/// Fitting and predicting on such data must both succeed.
#[test]
fn test_no_panic_on_degenerate_split_from_overflow() {
    // Two single-feature rows whose values are huge but still finite.
    let base = f64::MAX / 1.5;
    let rows = vec![vec![base], vec![base * 1.1]];
    let targets = vec![0.0, 1.0];
    let features = DenseMatrix::from_2d_vec(&rows).unwrap();

    let params = XGRegressorParameters::default()
        .with_n_estimators(10)
        .with_max_depth(3);

    // Training must return `Ok` rather than fail on the overflowing data.
    let fit_result = XGRegressor::fit(&features, &targets, params);
    assert!(
        fit_result.is_ok(),
        "Fit panicked or failed: {:?}",
        fit_result.err()
    );
    let model = fit_result.unwrap();

    // Prediction must also succeed and yield one value per input row.
    let predictions = model.predict(&features);
    assert!(predictions.is_ok());
    let predicted = predictions.unwrap();
    assert_eq!(predicted.len(), 2);
}

/// A "smoke test" to ensure the main XGRegressor can fit and predict on multidimensional data.
#[test]
fn test_xgregressor_fit_predict_multidimensional() {
Expand Down
Loading