diff --git a/algorithms/linfa-ensemble/examples/adaboost_iris.rs b/algorithms/linfa-ensemble/examples/adaboost_iris.rs
new file mode 100644
index 000000000..1a76d57c6
--- /dev/null
+++ b/algorithms/linfa-ensemble/examples/adaboost_iris.rs
@@ -0,0 +1,134 @@
+use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
+use linfa_ensemble::AdaBoostParams;
+use linfa_trees::DecisionTree;
+use ndarray_rand::rand::SeedableRng;
+use rand::rngs::SmallRng;
+
+fn adaboost_with_stumps(n_estimators: usize, learning_rate: f64) {
+    // Load dataset
+    let mut rng = SmallRng::seed_from_u64(42);
+    let (train, test) = linfa_datasets::iris()
+        .shuffle(&mut rng)
+        .split_with_ratio(0.8);
+
+    // Train AdaBoost model with decision tree stumps (max_depth=1)
+    // Stumps are weak learners commonly used with AdaBoost
+    let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng)
+        .n_estimators(n_estimators)
+        .learning_rate(learning_rate)
+        .fit(&train)
+        .unwrap();
+
+    // Make predictions
+    let predictions = model.predict(&test);
+    println!("Final Predictions: \n{predictions:?}");
+
+    let cm = predictions.confusion_matrix(&test).unwrap();
+    println!("{cm:?}");
+    println!(
+        "Test accuracy: {:.2}%\nwith Decision Tree stumps (max_depth=1),\nn_estimators: {n_estimators},\nlearning_rate: {learning_rate}.\n",
+        100.0 * cm.accuracy()
+    );
+    println!("Number of models trained: {}", model.n_estimators());
+}
+
+fn adaboost_with_shallow_trees(n_estimators: usize, learning_rate: f64, max_depth: usize) {
+    let mut rng = SmallRng::seed_from_u64(42);
+    let (train, test) = linfa_datasets::iris()
+        .shuffle(&mut rng)
+        .split_with_ratio(0.8);
+
+    // Train AdaBoost model with shallow decision trees
+    let model =
+        AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(max_depth)), rng)
+            .n_estimators(n_estimators)
+            .learning_rate(learning_rate)
+            .fit(&train)
+            .unwrap();
+
+    // Make predictions
+    let predictions = model.predict(&test);
+    println!("Final Predictions: \n{predictions:?}");
+
+    let cm = predictions.confusion_matrix(&test).unwrap();
+    println!("{cm:?}");
+    println!(
+        "Test accuracy: {:.2}%\nwith Decision Trees (max_depth={max_depth}),\nn_estimators: {n_estimators},\nlearning_rate: {learning_rate}.\n",
+        100.0 * cm.accuracy()
+    );
+
+    // Display model weights
+    println!("Model weights (alpha values):");
+    for (i, weight) in model.weights().iter().enumerate() {
+        println!("  Model {}: {:.4}", i + 1, weight);
+    }
+    println!();
+}
+
+fn main() {
+    println!("{}", "=".repeat(80));
+    println!("AdaBoost Examples on Iris Dataset");
+    println!("{}", "=".repeat(80));
+    println!();
+
+    // Example 1: AdaBoost with decision stumps (most common configuration)
+    println!("Example 1: AdaBoost with Decision Stumps");
+    println!("{}", "-".repeat(80));
+    adaboost_with_stumps(50, 1.0);
+    println!();
+
+    // Example 2: AdaBoost with lower learning rate
+    println!("Example 2: AdaBoost with Lower Learning Rate");
+    println!("{}", "-".repeat(80));
+    adaboost_with_stumps(100, 0.5);
+    println!();
+
+    // Example 3: AdaBoost with shallow trees
+    println!("Example 3: AdaBoost with Shallow Decision Trees");
+    println!("{}", "-".repeat(80));
+    adaboost_with_shallow_trees(50, 1.0, 2);
+    println!();
+
+    // Example 4: Comparing different configurations
+    println!("Example 4: Comparing Configurations");
+    println!("{}", "-".repeat(80));
+    let configs = vec![
+        (25, 1.0, 1, "Few stumps, high learning rate"),
+        (50, 1.0, 1, "Medium stumps, high learning rate"),
+        (100, 0.5, 1, "Many stumps, low learning rate"),
+        (50, 1.0, 2, "Shallow trees, high learning rate"),
+    ];
+
+    for (n_est, lr, depth, desc) in configs {
+        let mut rng = SmallRng::seed_from_u64(42);
+        let (train, test) = linfa_datasets::iris()
+            .shuffle(&mut rng)
+            .split_with_ratio(0.8);
+
+        let model =
+            AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(depth)), rng)
+                .n_estimators(n_est)
+                .learning_rate(lr)
+                .fit(&train)
+                .unwrap();
+
+        let predictions = model.predict(&test);
+        let cm = predictions.confusion_matrix(&test).unwrap();
+
+        println!(
+            "{desc:50} => Accuracy: {:.2}% (models trained: {})",
+            100.0 * cm.accuracy(),
+            model.n_estimators()
+        );
+    }
+
+    println!();
+    println!("{}", "=".repeat(80));
+    println!("Notes:");
+    println!("- AdaBoost works by training weak learners sequentially");
+    println!("- Each learner focuses on samples misclassified by previous learners");
+    println!("- Decision stumps (depth=1) are the most common weak learners");
+    println!("- Lower learning_rate provides regularization but needs more estimators");
+    println!("- Model weights (alpha) reflect each learner's contribution to prediction");
+    println!("{}", "=".repeat(80));
+}
diff --git a/algorithms/linfa-ensemble/src/adaboost.rs b/algorithms/linfa-ensemble/src/adaboost.rs
new file mode 100644
index 000000000..ef60bf8ee
--- /dev/null
+++ b/algorithms/linfa-ensemble/src/adaboost.rs
@@ -0,0 +1,319 @@
+use crate::AdaBoostValidParams;
+use linfa::{
+    dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned},
+    error::Error,
+    traits::*,
+    DatasetBase,
+};
+use ndarray::{Array1, Array2, Axis};
+use ndarray_rand::rand::distributions::WeightedIndex;
+use ndarray_rand::rand::prelude::*;
+use ndarray_rand::rand::Rng;
+use std::{cmp::Eq, collections::HashMap, hash::Hash};
+
+/// A fitted AdaBoost ensemble classifier.
+///
+/// ## Structure
+///
+/// AdaBoost (Adaptive Boosting) is an ensemble learning method that combines multiple weak learners
+/// into a strong classifier. Unlike bagging methods (like Random Forest), AdaBoost trains learners
+/// sequentially, where each new learner focuses more on examples that previous learners misclassified.
+///
+/// Each fitted model `M` has an associated weight (alpha) that represents its contribution to the
+/// final prediction. Models that perform better on their training data receive higher weights.
+///
+/// ## Algorithm Overview
+///
+/// Given a [DatasetBase](DatasetBase) denoted as `D` with `n` samples:
+/// 1. Initialize sample weights uniformly: `w_i = 1/n` for all samples
+/// 2. For each iteration `t` from 1 to T (number of estimators):
+///    a. Train base learner on weighted dataset
+///    b. Calculate weighted error rate
+///    c. Compute model weight (alpha) based on error
+///    d. Update sample weights: increase weights for misclassified samples
+///    e. 
Normalize sample weights +/// +/// ## Prediction Algorithm +/// +/// The final prediction is computed using weighted majority voting: +/// - Each model's prediction is weighted by its alpha value +/// - The class with the highest weighted vote is selected +/// +/// ## Example +/// +/// ```no_run +/// use linfa::prelude::{Fit, Predict}; +/// use linfa_ensemble::AdaBoostParams; +/// use linfa_trees::DecisionTree; +/// use ndarray_rand::rand::SeedableRng; +/// use rand::rngs::SmallRng; +/// +/// // Load Iris dataset +/// let mut rng = SmallRng::seed_from_u64(42); +/// let (train, test) = linfa_datasets::iris() +/// .shuffle(&mut rng) +/// .split_with_ratio(0.8); +/// +/// // Train AdaBoost with decision tree stumps +/// let adaboost_model = AdaBoostParams::new(DecisionTree::params().max_depth(Some(1))) +/// .n_estimators(50) +/// .learning_rate(1.0) +/// .fit(&train) +/// .unwrap(); +/// +/// // Make predictions on the test set +/// let predictions = adaboost_model.predict(&test); +/// ``` +/// +/// ## References +/// +/// * Freund, Y., & Schapire, R. E. (1997). A decision-theoretic generalization of on-line learning +/// and an application to boosting. Journal of Computer and System Sciences, 55(1), 119-139. +/// * [Scikit-Learn AdaBoost Documentation](https://scikit-learn.org/stable/modules/ensemble.html#adaboost) +/// * [An Introduction to Statistical Learning](https://www.statlearning.com/), Chapter 8 +#[derive(Debug, Clone)] +pub struct AdaBoost { + /// The fitted base learner models + pub models: Vec, + /// The weight (alpha) for each model in the ensemble + pub model_weights: Vec, + /// The unique class labels seen during training + pub classes: Vec, +} + +impl AdaBoost { + /// Returns the number of estimators in the ensemble + pub fn n_estimators(&self) -> usize { + self.models.len() + } + + /// Returns the model weights (alpha values) + pub fn weights(&self) -> &[f64] { + &self.model_weights + } +} + +impl PredictInplace, T> for AdaBoost +where + M: PredictInplace, T>, + ::Elem: Copy + Eq + Hash + std::fmt::Debug + Into, + T: AsTargets + AsTargetsMut::Elem>, + usize: Into<::Elem>, +{ + fn predict_inplace(&self, x: &Array2, y: &mut T) { + let y_array = y.as_targets(); + assert_eq!( + x.nrows(), + y_array.len_of(Axis(0)), + "The number of data points must match the number of outputs." 
+ ); + + // Collect predictions from all models + let mut all_predictions = Vec::with_capacity(self.models.len()); + for model in &self.models { + let mut pred = model.default_target(x); + model.predict_inplace(x, &mut pred); + all_predictions.push(pred); + } + + // Create a map for each sample to accumulate weighted votes + let mut prediction_maps = y_array.map(|_| HashMap::new()); + + // Accumulate weighted predictions from each model + for (model_idx, prediction) in all_predictions.iter().enumerate() { + let pred_array = prediction.as_targets(); + let weight = self.model_weights[model_idx]; + + // For each sample, add the model's weighted prediction + for (vote_map, &pred_val) in prediction_maps.iter_mut().zip(pred_array.iter()) { + let class_idx: usize = pred_val.into(); + *vote_map.entry(class_idx).or_insert(0.0) += weight; + } + } + + // For each sample, select the class with the highest weighted vote + let final_predictions = prediction_maps.map(|votes| { + votes + .iter() + .max_by(|(_, v1), (_, v2)| v1.partial_cmp(v2).unwrap()) + .map(|(k, _)| (*k).into()) + .expect("No predictions found") + }); + + // Write final predictions to output + let mut y_array_mut = y.as_targets_mut(); + for (y, pred) in y_array_mut.iter_mut().zip(final_predictions.iter()) { + *y = *pred; + } + } + + fn default_target(&self, x: &Array2) -> T { + self.models[0].default_target(x) + } +} + +impl Fit, T, Error> for AdaBoostValidParams +where + D: Clone + ndarray::ScalarOperand, + T: FromTargetArrayOwned + AsTargets + Clone, + T::Elem: Copy + Eq + Hash + std::fmt::Debug + Into, + P: Fit, T, Error> + Clone, + P::Object: PredictInplace, T>, + R: Rng + Clone, + usize: Into, +{ + type Object = AdaBoost; + + fn fit( + &self, + dataset: &DatasetBase, T>, + ) -> core::result::Result { + let n_samples = dataset.records.nrows(); + + if n_samples == 0 { + return Err(Error::Parameters( + "Cannot fit AdaBoost on empty dataset".to_string(), + )); + } + + // Extract unique class labels from target array + let target_array = dataset.targets.as_targets(); + let mut classes_set: Vec = target_array + .iter() + .copied() + .collect::>() + .into_iter() + .collect(); + // Sort by converting to usize for ordering + classes_set.sort_unstable_by_key(|x| (*x).into()); + + if classes_set.len() < 2 { + return Err(Error::Parameters( + "AdaBoost requires at least 2 classes".to_string(), + )); + } + + // Initialize sample weights uniformly + let mut sample_weights = Array1::::from_elem(n_samples, 1.0 / n_samples as f64); + + let mut models = Vec::with_capacity(self.n_estimators); + let mut model_weights = Vec::with_capacity(self.n_estimators); + + let mut rng = self.rng.clone(); + + for iteration in 0..self.n_estimators { + // Normalize weights to sum to 1 + let weight_sum: f64 = sample_weights.sum(); + if weight_sum <= 0.0 { + return Err(Error::NotConverged(format!( + "Sample weights sum to zero at iteration {}", + iteration + ))); + } + sample_weights /= weight_sum; + + // Resample dataset according to sample weights + // This is the practical implementation of AdaBoost when base learners don't support weights + let dist = WeightedIndex::new(sample_weights.iter().copied()) + .map_err(|_| Error::Parameters("Invalid sample weights".to_string()))?; + + let bootstrap_indices: Vec = + (0..n_samples).map(|_| dist.sample(&mut rng)).collect(); + + // Create bootstrap dataset by selecting rows according to weights + let bootstrap_records = dataset.records.select(Axis(0), &bootstrap_indices); + let bootstrap_targets_array = 
target_array.select(Axis(0), &bootstrap_indices); + + // Convert to owned target type using new_targets + let bootstrap_targets = T::new_targets(bootstrap_targets_array); + let bootstrap_dataset = DatasetBase::new(bootstrap_records, bootstrap_targets); + + // Fit base learner on resampled dataset + let model = self.model_params.fit(&bootstrap_dataset).map_err(|e| { + Error::NotConverged(format!( + "Base learner failed to fit at iteration {}: {}", + iteration, e + )) + })?; + + // Make predictions on training data + let mut predictions = model.default_target(&dataset.records); + model.predict_inplace(&dataset.records, &mut predictions); + let pred_array = predictions.as_targets(); + + // Calculate weighted error + let mut weighted_error = 0.0f64; + for ((true_label, pred_label), weight) in target_array + .iter() + .zip(pred_array.iter()) + .zip(sample_weights.iter()) + { + let true_idx: usize = (*true_label).into(); + let pred_idx: usize = (*pred_label).into(); + + if true_idx != pred_idx { + weighted_error += *weight; + } + } + + // Handle edge cases for weighted error + if weighted_error <= 0.0 { + // Perfect prediction - add model with maximum weight and stop + model_weights.push(10.0); // Large weight for perfect classifier + models.push(model); + break; + } + + // For multi-class SAMME, check if error rate is above the random guessing threshold + let k = classes_set.len() as f64; + let error_threshold = (k - 1.0) / k; + + if weighted_error >= error_threshold { + // Worse than random guessing for multi-class - don't add this model + if models.is_empty() { + return Err(Error::NotConverged(format!( + "First base learner performs worse than random guessing (error: {:.4}, threshold: {:.4})", + weighted_error, error_threshold + ))); + } + break; + } + + // Calculate model weight (alpha) using SAMME algorithm + // For multi-class: alpha = learning_rate * (log((1 - error) / error) + log(K - 1)) + // where K is number of classes + let error_ratio = (1.0 - weighted_error) / weighted_error; + let alpha = self.learning_rate * (error_ratio.ln() + (k - 1.0).ln()); + + // Update sample weights + for ((true_label, pred_label), weight) in target_array + .iter() + .zip(pred_array.iter()) + .zip(sample_weights.iter_mut()) + { + let true_idx: usize = (*true_label).into(); + let pred_idx: usize = (*pred_label).into(); + + if true_idx != pred_idx { + // Increase weight for misclassified samples + *weight *= alpha.exp(); + } + } + + model_weights.push(alpha); + models.push(model); + } + + if models.is_empty() { + return Err(Error::NotConverged( + "No models were successfully trained".to_string(), + )); + } + + Ok(AdaBoost { + models, + model_weights, + classes: classes_set, + }) + } +} diff --git a/algorithms/linfa-ensemble/src/adaboost_hyperparams.rs b/algorithms/linfa-ensemble/src/adaboost_hyperparams.rs new file mode 100644 index 000000000..91912fd90 --- /dev/null +++ b/algorithms/linfa-ensemble/src/adaboost_hyperparams.rs @@ -0,0 +1,230 @@ +use linfa::{ + error::{Error, Result}, + ParamGuard, +}; +use rand::rngs::ThreadRng; +use rand::Rng; + +/// The set of valid hyperparameters for the [AdaBoost](crate::AdaBoost) algorithm. +/// +/// ## Parameters +/// +/// * `n_estimators`: The maximum number of weak learners to train sequentially. +/// More estimators generally improve performance but increase training time and risk overfitting. +/// Typical values range from 50 to 500. Default: 50. +/// +/// * `learning_rate`: Shrinks the contribution of each classifier. 
There is a trade-off between +/// `learning_rate` and `n_estimators`. Lower values require more estimators to achieve the same +/// performance but may generalize better. Must be positive. Default: 1.0. +/// +/// * `model_params`: The parameters for the base learner (weak classifier). Typically, shallow +/// decision trees (stumps with max_depth=1 or max_depth=2) are used as weak learners. +/// +/// * `rng`: Random number generator used for tie-breaking and reproducibility. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct AdaBoostValidParams { + /// The maximum number of estimators to train + pub n_estimators: usize, + /// The learning rate (shrinkage parameter) + pub learning_rate: f64, + /// The base learner parameters + pub model_params: P, + /// Random number generator + pub rng: R, +} + +/// A helper struct for building [AdaBoost](crate::AdaBoost) hyperparameters. +/// +/// This struct follows the builder pattern, allowing you to chain method calls to configure +/// the AdaBoost algorithm before fitting. +/// +/// ## Example +/// +/// ```no_run +/// use linfa_ensemble::AdaBoostParams; +/// use linfa_trees::DecisionTree; +/// +/// let params = AdaBoostParams::new(DecisionTree::::params().max_depth(Some(1))) +/// .n_estimators(100) +/// .learning_rate(0.5); +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct AdaBoostParams(AdaBoostValidParams); + +impl

AdaBoostParams { + /// Create a new AdaBoost parameter set with default values and a thread-local RNG. + /// + /// # Arguments + /// + /// * `model_params` - The parameters for the base learner (e.g., DecisionTreeParams) + /// + /// # Default Values + /// + /// * `n_estimators`: 50 + /// * `learning_rate`: 1.0 + pub fn new(model_params: P) -> AdaBoostParams { + Self::new_fixed_rng(model_params, rand::thread_rng()) + } +} + +impl AdaBoostParams { + /// Create a new AdaBoost parameter set with a fixed RNG for reproducibility. + /// + /// # Arguments + /// + /// * `model_params` - The parameters for the base learner + /// * `rng` - A seeded random number generator for reproducible results + /// + /// # Example + /// + /// ```no_run + /// use linfa_ensemble::AdaBoostParams; + /// use linfa_trees::DecisionTree; + /// use ndarray_rand::rand::SeedableRng; + /// use rand::rngs::SmallRng; + /// + /// let rng = SmallRng::seed_from_u64(42); + /// let params = AdaBoostParams::new_fixed_rng( + /// DecisionTree::::params().max_depth(Some(1)), + /// rng + /// ); + /// ``` + pub fn new_fixed_rng(model_params: P, rng: R) -> AdaBoostParams { + Self(AdaBoostValidParams { + n_estimators: 50, + learning_rate: 1.0, + model_params, + rng, + }) + } + + /// Set the maximum number of weak learners to train. + /// + /// # Arguments + /// + /// * `n_estimators` - Must be at least 1. Typical values: 50-500 + /// + /// # Notes + /// + /// Higher values generally lead to better training performance but: + /// * Increase training time linearly + /// * May lead to overfitting + /// * Should be balanced with `learning_rate` + pub fn n_estimators(mut self, n_estimators: usize) -> Self { + self.0.n_estimators = n_estimators; + self + } + + /// Set the learning rate (shrinkage parameter). + /// + /// # Arguments + /// + /// * `learning_rate` - Must be positive. 
Typical values: 0.01 to 2.0 + /// + /// # Notes + /// + /// * Values < 1.0 provide regularization and often improve generalization + /// * Lower values require more estimators to achieve similar performance + /// * A common strategy is to use learning_rate=0.1 with n_estimators=500 + pub fn learning_rate(mut self, learning_rate: f64) -> Self { + self.0.learning_rate = learning_rate; + self + } +} + +impl ParamGuard for AdaBoostParams { + type Checked = AdaBoostValidParams; + type Error = Error; + + fn check_ref(&self) -> Result<&Self::Checked> { + if self.0.n_estimators < 1 { + Err(Error::Parameters(format!( + "n_estimators must be at least 1, but was {}", + self.0.n_estimators + ))) + } else if self.0.learning_rate <= 0.0 { + Err(Error::Parameters(format!( + "learning_rate must be positive, but was {}", + self.0.learning_rate + ))) + } else if !self.0.learning_rate.is_finite() { + Err(Error::Parameters( + "learning_rate must be finite (not NaN or infinity)".to_string(), + )) + } else { + Ok(&self.0) + } + } + + fn check(self) -> Result { + self.check_ref()?; + Ok(self.0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use linfa_trees::DecisionTree; + use ndarray_rand::rand::SeedableRng; + use rand::rngs::SmallRng; + + #[test] + fn test_default_params() { + let rng = SmallRng::seed_from_u64(42); + let params = AdaBoostParams::new_fixed_rng(DecisionTree::::params(), rng); + assert_eq!(params.0.n_estimators, 50); + assert_eq!(params.0.learning_rate, 1.0); + } + + #[test] + fn test_custom_params() { + let rng = SmallRng::seed_from_u64(42); + let params = AdaBoostParams::new_fixed_rng(DecisionTree::::params(), rng) + .n_estimators(100) + .learning_rate(0.5); + assert_eq!(params.0.n_estimators, 100); + assert_eq!(params.0.learning_rate, 0.5); + } + + #[test] + fn test_invalid_n_estimators() { + let rng = SmallRng::seed_from_u64(42); + let params = AdaBoostParams::new_fixed_rng(DecisionTree::::params(), rng) + .n_estimators(0); + assert!(params.check_ref().is_err()); + } + + #[test] + fn test_invalid_learning_rate_negative() { + let rng = SmallRng::seed_from_u64(42); + let params = AdaBoostParams::new_fixed_rng(DecisionTree::::params(), rng) + .learning_rate(-0.5); + assert!(params.check_ref().is_err()); + } + + #[test] + fn test_invalid_learning_rate_zero() { + let rng = SmallRng::seed_from_u64(42); + let params = AdaBoostParams::new_fixed_rng(DecisionTree::::params(), rng) + .learning_rate(0.0); + assert!(params.check_ref().is_err()); + } + + #[test] + fn test_invalid_learning_rate_nan() { + let rng = SmallRng::seed_from_u64(42); + let params = AdaBoostParams::new_fixed_rng(DecisionTree::::params(), rng) + .learning_rate(f64::NAN); + assert!(params.check_ref().is_err()); + } + + #[test] + fn test_valid_params() { + let rng = SmallRng::seed_from_u64(42); + let params = AdaBoostParams::new_fixed_rng(DecisionTree::::params(), rng) + .n_estimators(100) + .learning_rate(0.5); + assert!(params.check_ref().is_ok()); + } +} diff --git a/algorithms/linfa-ensemble/src/lib.rs b/algorithms/linfa-ensemble/src/lib.rs index 26a94f616..87f129a78 100644 --- a/algorithms/linfa-ensemble/src/lib.rs +++ b/algorithms/linfa-ensemble/src/lib.rs @@ -6,6 +6,7 @@ //! This crate (`linfa-ensemble`), provides pure Rust implementations of popular ensemble techniques, such as //! * [Boostrap Aggregation](EnsembleLearner) //! * [Random Forest](RandomForest) +//! * [AdaBoost] //! //! ## Bootstrap Aggregation (aka Bagging) //! @@ -18,6 +19,14 @@ //! selection. 
A typical number of random prediction to be selected is $\sqrt{p}$ with $p$ being //! the number of available features. //! +//! ## AdaBoost +//! +//! AdaBoost (Adaptive Boosting) is a boosting ensemble method that trains weak learners sequentially. +//! Each subsequent learner focuses on the examples that previous learners misclassified by increasing +//! their sample weights. The final prediction is a weighted vote of all learners, where better-performing +//! learners receive higher weights. Unlike bagging methods, boosting creates a strong classifier from +//! weak learners (typically shallow decision trees or "stumps"). +//! //! ## Reference //! //! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html) @@ -81,9 +90,13 @@ //! let predictions = random_forest.predict(&test); //! ``` +mod adaboost; +mod adaboost_hyperparams; mod algorithm; mod hyperparams; +pub use adaboost::*; +pub use adaboost_hyperparams::*; pub use algorithm::*; pub use hyperparams::*; @@ -135,4 +148,193 @@ mod tests { let acc = cm.accuracy(); assert!(acc >= 0.9, "Expected accuracy to be above 90%, got {}", acc); } + + #[test] + fn test_adaboost_accuracy_on_iris_dataset() { + let mut rng = SmallRng::seed_from_u64(42); + let (train, test) = linfa_datasets::iris() + .shuffle(&mut rng) + .split_with_ratio(0.8); + + // Train AdaBoost with decision tree stumps (shallow trees) + let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng) + .n_estimators(50) + .learning_rate(1.0) + .fit(&train) + .unwrap(); + + let predictions = model.predict(&test); + + let cm = predictions.confusion_matrix(&test).unwrap(); + let acc = cm.accuracy(); + assert!( + acc >= 0.85, + "Expected accuracy to be above 85%, got {}", + acc + ); + } + + #[test] + fn test_adaboost_with_low_learning_rate() { + let mut rng = SmallRng::seed_from_u64(42); + let (train, test) = linfa_datasets::iris() + .shuffle(&mut rng) + .split_with_ratio(0.8); + + // Train AdaBoost with lower learning rate and more estimators + let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(2)), rng) + .n_estimators(100) + .learning_rate(0.5) + .fit(&train) + .unwrap(); + + let predictions = model.predict(&test); + + let cm = predictions.confusion_matrix(&test).unwrap(); + let acc = cm.accuracy(); + assert!( + acc >= 0.85, + "Expected accuracy to be above 85%, got {}", + acc + ); + } + + #[test] + fn test_adaboost_model_weights() { + let mut rng = SmallRng::seed_from_u64(42); + let (train, _) = linfa_datasets::iris() + .shuffle(&mut rng) + .split_with_ratio(0.8); + + let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng) + .n_estimators(10) + .fit(&train) + .unwrap(); + + // Verify that model weights are positive + for weight in model.weights() { + assert!( + *weight > 0.0, + "Model weight should be positive, got {}", + weight + ); + } + + // Verify we have the expected number of models + assert_eq!(model.n_estimators(), 10); + } + + #[test] + fn test_adaboost_different_learning_rates() { + // Test that different learning rates produce different model weights + let rng1 = SmallRng::seed_from_u64(42); + let rng2 = SmallRng::seed_from_u64(42); + let (train, _) = linfa_datasets::iris() + .shuffle(&mut SmallRng::seed_from_u64(42)) + .split_with_ratio(0.8); + + // Train with learning_rate = 1.0 + let model1 = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng1) + .n_estimators(10) + .learning_rate(1.0) + .fit(&train) + .unwrap(); + + // Train 
with learning_rate = 0.5 + let model2 = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng2) + .n_estimators(10) + .learning_rate(0.5) + .fit(&train) + .unwrap(); + + // Model weights should be different + let weights1 = model1.weights(); + let weights2 = model2.weights(); + + // At least one weight should be significantly different + let mut has_difference = false; + for (w1, w2) in weights1.iter().zip(weights2.iter()) { + if (w1 - w2).abs() > 0.01 { + has_difference = true; + break; + } + } + + assert!( + has_difference, + "Different learning rates should produce different model weights" + ); + } + + #[test] + fn test_adaboost_early_stopping_on_perfect_fit() { + use linfa::DatasetBase; + use ndarray::Array2; + + // Create a simple linearly separable dataset + let records = Array2::from_shape_vec( + (6, 2), + vec![ + 0.0, 0.0, // class 0 + 0.1, 0.1, // class 0 + 0.2, 0.2, // class 0 + 1.0, 1.0, // class 1 + 1.1, 1.1, // class 1 + 1.2, 1.2, // class 1 + ], + ) + .unwrap(); + let targets = ndarray::array![0, 0, 0, 1, 1, 1]; + let dataset = DatasetBase::new(records, targets); + + let rng = SmallRng::seed_from_u64(42); + let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(3)), rng) + .n_estimators(50) + .fit(&dataset) + .unwrap(); + + // Should stop early due to perfect classification + assert!( + model.n_estimators() < 50, + "Expected early stopping, but got {} estimators", + model.n_estimators() + ); + } + + #[test] + fn test_adaboost_single_class_error() { + use linfa::DatasetBase; + use ndarray::Array2; + + // Create dataset with only one class + let records = + Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 0.3, 0.3]).unwrap(); + let targets = ndarray::array![0, 0, 0, 0]; // All same class + let dataset = DatasetBase::new(records, targets); + + let rng = SmallRng::seed_from_u64(42); + let result = AdaBoostParams::new_fixed_rng(DecisionTree::params(), rng) + .n_estimators(10) + .fit(&dataset); + + assert!(result.is_err(), "Should fail with single class dataset"); + } + + #[test] + fn test_adaboost_classes_method() { + let mut rng = SmallRng::seed_from_u64(42); + let (train, _) = linfa_datasets::iris() + .shuffle(&mut rng) + .split_with_ratio(0.8); + + let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng) + .n_estimators(10) + .fit(&train) + .unwrap(); + + // Verify classes are properly stored + let classes = &model.classes; + assert_eq!(classes.len(), 3, "Iris has 3 classes"); + assert_eq!(classes, &vec![0, 1, 2], "Classes should be [0, 1, 2]"); + } }
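
Reviewer note (illustrative, not part of the patch): the model weight computed in `adaboost.rs` follows the multi-class SAMME rule quoted in its comments, `alpha = learning_rate * (ln((1 - error) / error) + ln(K - 1))`, and misclassified samples are then up-weighted via `*weight *= alpha.exp()`. The standalone sketch below just reproduces that arithmetic for a quick sanity check; the helper function name is hypothetical and does not exist in the crate.

```rust
/// Illustrative only: mirrors the SAMME formula quoted in adaboost.rs.
fn samme_alpha(weighted_error: f64, n_classes: usize, learning_rate: f64) -> f64 {
    let k = n_classes as f64;
    // alpha = learning_rate * (ln((1 - err) / err) + ln(K - 1))
    learning_rate * (((1.0 - weighted_error) / weighted_error).ln() + (k - 1.0).ln())
}

fn main() {
    // A stump that misclassifies 20% of the sample weight on a 3-class problem:
    let alpha = samme_alpha(0.2, 3, 1.0);
    println!("alpha = {alpha:.4}"); // ln(4) + ln(2) ≈ 2.0794

    // Misclassified samples are multiplied by exp(alpha) before renormalisation,
    // matching the `*weight *= alpha.exp()` update in the fit loop.
    let (mut w_wrong, w_right) = (0.1_f64, 0.1_f64);
    w_wrong *= alpha.exp();
    println!("misclassified weight grows from 0.1 to {w_wrong:.4}, correct stays {w_right:.1}");
}
```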