use rand::Rng;
use std::fmt::Debug;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;

use crate::rand_custom::get_rng_impl;
use crate::tree::base_tree_regressor::{BaseTreeRegressor, BaseTreeRegressorParameters, Splitter};

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Forest Regressor.
/// Some parameters here are passed directly to the base estimator.
pub struct BaseForestRegressorParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub max_depth: Option<u16>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_leaf: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_split: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of trees in the forest.
    pub n_trees: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Number of predictors to sample at random as split candidates.
    pub m: Option<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
    pub keep_samples: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Seed used for bootstrap sampling and feature selection for each tree.
    pub seed: u64,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Whether each tree is trained on a bootstrap sample of the rows.
    pub bootstrap: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The strategy used to split nodes in each tree.
    pub splitter: Splitter,
}
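
// A minimal construction sketch (the values below are illustrative only, not
// the crate's documented defaults, and `Splitter::Random` assumes that variant
// exists on the imported `Splitter` enum):
//
//     let params = BaseForestRegressorParameters {
//         max_depth: None,         // grow each tree until leaves are pure
//         min_samples_leaf: 1,
//         min_samples_split: 2,
//         n_trees: 100,
//         m: Option::None,         // `fit` falls back to floor(sqrt(M))
//         keep_samples: true,      // required later for `predict_oob`
//         seed: 42,
//         bootstrap: true,
//         splitter: Splitter::Random,
//     };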

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
    for BaseForestRegressor<TX, TY, X, Y>
{
    fn eq(&self, other: &Self) -> bool {
        // Compare the two forests tree by tree. An untrained forest (`None`)
        // only equals another untrained forest, and forests of different
        // sizes are never equal; this avoids panicking on `unwrap`.
        match (self.trees.as_ref(), other.trees.as_ref()) {
            (Some(a), Some(b)) => {
                a.len() == b.len() && a.iter().zip(b.iter()).all(|(t1, t2)| t1 == t2)
            }
            (None, None) => true,
            _ => false,
        }
    }
}

/// Forest Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct BaseForestRegressor<
    TX: Number + FloatNumber + PartialOrd,
    TY: Number,
    X: Array2<TX>,
    Y: Array1<TY>,
> {
    trees: Option<Vec<BaseTreeRegressor<TX, TY, X, Y>>>,
    samples: Option<Vec<Vec<bool>>>,
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    BaseForestRegressor<TX, TY, X, Y>
{
    /// Build a forest of trees from the training set.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - the target values
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: BaseForestRegressorParameters,
    ) -> Result<BaseForestRegressor<TX, TY, X, Y>, Failed> {
        let (n_rows, num_attributes) = x.shape();

        if n_rows != y.shape() {
            return Err(Failed::fit("Number of rows in x should equal the length of y"));
        }

        // Number of features to consider at each split; defaults to floor(sqrt(M)).
        let mtry = parameters
            .m
            .unwrap_or((num_attributes as f64).sqrt().floor() as usize);

        let mut rng = get_rng_impl(Some(parameters.seed));
        let mut trees: Vec<BaseTreeRegressor<TX, TY, X, Y>> =
            Vec::with_capacity(parameters.n_trees);

        let mut maybe_all_samples: Option<Vec<Vec<bool>>> = if parameters.keep_samples {
            Some(Vec::with_capacity(parameters.n_trees))
        } else {
            None
        };

        // Without bootstrapping, every row is used exactly once for every tree.
        let mut samples: Vec<usize> = vec![1; n_rows];

        for _ in 0..parameters.n_trees {
            if parameters.bootstrap {
                samples =
                    BaseForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
            }

            // Keep the samples if the flag is on.
            if let Some(ref mut all_samples) = maybe_all_samples {
                all_samples.push(samples.iter().map(|x| *x != 0).collect())
            }

            let params = BaseTreeRegressorParameters {
                max_depth: parameters.max_depth,
                min_samples_leaf: parameters.min_samples_leaf,
                min_samples_split: parameters.min_samples_split,
                seed: Some(parameters.seed),
                splitter: parameters.splitter.clone(),
            };
            let tree = BaseTreeRegressor::fit_weak_learner(x, y, samples.clone(), mtry, params)?;
            trees.push(tree);
        }

        Ok(BaseForestRegressor {
            trees: Some(trees),
            samples: maybe_all_samples,
        })
    }
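
    // A minimal usage sketch of `fit` (hedged: `DenseMatrix` and its
    // `from_2d_array` constructor are assumed to live at
    // `crate::linalg::basic::matrix`, `Vec<f64>` is assumed to implement
    // `Array1`, and `params` is the hypothetical value built in the
    // construction sketch near the top of this file):
    //
    //     let x = DenseMatrix::from_2d_array(&[
    //         &[1.0, 2.0],
    //         &[2.0, 3.0],
    //         &[3.0, 4.0],
    //         &[4.0, 5.0],
    //     ]);
    //     let y: Vec<f64> = vec![1.5, 2.5, 3.5, 4.5];
    //     let forest = BaseForestRegressor::fit(&x, &y, params)?;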

    /// Predict target values for `x`.
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let (n, _) = x.shape();
        let mut result = Y::zeros(n);

        for i in 0..n {
            result.set(i, self.predict_for_row(x, i));
        }

        Ok(result)
    }

    fn predict_for_row(&self, x: &X, row: usize) -> TY {
        let n_trees = self.trees.as_ref().unwrap().len();

        let mut result = TY::zero();

        // The forest prediction is the mean of the individual tree predictions.
        for tree in self.trees.as_ref().unwrap().iter() {
            result += tree.predict_for_row(x, row);
        }

        result / TY::from_usize(n_trees).unwrap()
    }

    /// Predict OOB (out-of-bag) target values for `x`. `x` is expected to be equal to the dataset used in training.
    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
        let (n, _) = x.shape();
        if self.samples.is_none() {
            Err(Failed::because(
                FailedError::PredictFailed,
                "Need keep_samples=true for OOB predictions.",
            ))
        } else if self.samples.as_ref().unwrap()[0].len() != n {
            Err(Failed::because(
                FailedError::PredictFailed,
                "Prediction matrix must match matrix used in training for OOB predictions.",
            ))
        } else {
            let mut result = Y::zeros(n);

            for i in 0..n {
                result.set(i, self.predict_for_row_oob(x, i));
            }

            Ok(result)
        }
    }
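
    // A usage sketch (hedged: continues the hypothetical `x`, `y`, and `params`
    // from the `fit` sketch above; `params.keep_samples` must be `true`):
    //
    //     let forest = BaseForestRegressor::fit(&x, &y, params)?;
    //     // Each row is predicted only by trees whose bootstrap sample omitted
    //     // it, giving a built-in estimate of generalization error.
    //     let oob_predictions = forest.predict_oob(&x)?;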

    fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
        let mut n_trees = 0usize;
        let mut result = TY::zero();

        // Average only over trees whose bootstrap sample did not contain this row.
        for (tree, samples) in self
            .trees
            .as_ref()
            .unwrap()
            .iter()
            .zip(self.samples.as_ref().unwrap())
        {
            if !samples[row] {
                result += tree.predict_for_row(x, row);
                n_trees += 1;
            }
        }

        // TODO: What to do if there are no oob trees?
        result / TY::from_usize(n_trees).unwrap()
    }

    /// Draw `nrows` row indices with replacement; `samples[i]` counts how many
    /// times row `i` was drawn (0 means the row is out-of-bag for that tree).
    fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
        let mut samples = vec![0; nrows];
        for _ in 0..nrows {
            let xi = rng.gen_range(0..nrows);
            samples[xi] += 1;
        }
        samples
    }
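
    // Illustration (values are made up): with `nrows = 4`, one call might
    // return `[2, 0, 1, 1]`: row 0 was drawn twice, rows 2 and 3 once each,
    // and row 1 never, so row 1 is out-of-bag for that tree.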
}
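
// A smoke-test sketch under stated assumptions: `DenseMatrix` lives at
// `crate::linalg::basic::matrix`, `DenseMatrix::from_2d_array` builds it from
// row slices, and `Splitter` has a `Random` variant. Adjust the paths and
// constructor to the crate's actual API before relying on this.
#[cfg(test)]
mod base_forest_regressor_sketch_tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;

    #[test]
    fn fit_and_predict_return_one_value_per_row() {
        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
            &[1.0, 1.0],
            &[2.0, 2.0],
            &[3.0, 3.0],
            &[4.0, 4.0],
            &[5.0, 5.0],
            &[6.0, 6.0],
        ]);
        let y: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let params = BaseForestRegressorParameters {
            max_depth: None,
            min_samples_leaf: 1,
            min_samples_split: 2,
            n_trees: 10,
            m: Option::None,
            keep_samples: true, // needed below for `predict_oob`
            seed: 87,
            bootstrap: true,
            splitter: Splitter::Random,
        };

        let forest = BaseForestRegressor::fit(&x, &y, params).unwrap();

        // One prediction per input row, for both regular and OOB prediction.
        assert_eq!(forest.predict(&x).unwrap().shape(), 6);
        assert_eq!(forest.predict_oob(&x).unwrap().shape(), 6);
    }
}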