
Commit 9fef05e

refactored random forest regressor into reusable components (#318)

1 parent c5816b0 commit 9fef05e
File tree

5 files changed: +239 −157 lines

src/ensemble/base_forest_regressor.rs

Lines changed: 214 additions & 0 deletions

@@ -0,0 +1,214 @@
use rand::Rng;
use std::fmt::Debug;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;

use crate::rand_custom::get_rng_impl;
use crate::tree::base_tree_regressor::{BaseTreeRegressor, BaseTreeRegressorParameters, Splitter};

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Forest Regressor.
/// Some parameters here are passed directly to the base estimator.
pub struct BaseForestRegressorParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub max_depth: Option<u16>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_leaf: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_split: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of trees in the forest.
    pub n_trees: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Number of randomly sampled predictors to use as split candidates.
    pub m: Option<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Whether to keep the samples used for tree generation. This is required for OOB prediction.
    pub keep_samples: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Seed used for bootstrap sampling and feature selection for each tree.
    pub seed: u64,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Whether to sample the training rows with replacement (bootstrap) for each tree.
    pub bootstrap: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Splitting strategy passed to the base tree estimator.
    pub splitter: Splitter,
}
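
For orientation, a minimal sketch of how a caller might fill these parameters in. The values are illustrative only, and `Splitter::Random` is an assumed variant name — the `Splitter` enum is imported from `base_tree_regressor`, but its variants are not shown in this diff:

let params = BaseForestRegressorParameters {
    max_depth: None,        // grow each tree until the leaf criteria stop it
    min_samples_leaf: 1,
    min_samples_split: 2,
    n_trees: 100,
    m: None,                // defaults to floor(sqrt(num_attributes)) in fit()
    keep_samples: true,     // required later for predict_oob
    seed: 42,
    bootstrap: true,
    splitter: Splitter::Random, // assumed variant name
};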

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
    for BaseForestRegressor<TX, TY, X, Y>
{
    fn eq(&self, other: &Self) -> bool {
        if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
            false
        } else {
            self.trees
                .iter()
                .zip(other.trees.iter())
                .all(|(a, b)| a == b)
        }
    }
}

/// Forest Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct BaseForestRegressor<
    TX: Number + FloatNumber + PartialOrd,
    TY: Number,
    X: Array2<TX>,
    Y: Array1<TY>,
> {
    trees: Option<Vec<BaseTreeRegressor<TX, TY, X, Y>>>,
    samples: Option<Vec<Vec<bool>>>,
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    BaseForestRegressor<TX, TY, X, Y>
{
    /// Build a forest of trees from the training set.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - the target values
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: BaseForestRegressorParameters,
    ) -> Result<BaseForestRegressor<TX, TY, X, Y>, Failed> {
        let (n_rows, num_attributes) = x.shape();

        if n_rows != y.shape() {
            return Err(Failed::fit("Number of rows in x should equal the length of y"));
        }

        // Default to floor(sqrt(p)) split candidates, the usual random forest heuristic.
        let mtry = parameters
            .m
            .unwrap_or((num_attributes as f64).sqrt().floor() as usize);

        let mut rng = get_rng_impl(Some(parameters.seed));
        let mut trees: Vec<BaseTreeRegressor<TX, TY, X, Y>> = Vec::new();

        let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
        if parameters.keep_samples {
            // TODO: use with_capacity here
            maybe_all_samples = Some(Vec::new());
        }

        // samples[i] holds how many times row i is used by the current tree;
        // without bootstrapping every row is used exactly once.
        let mut samples: Vec<usize> = vec![1; n_rows];

        for _ in 0..parameters.n_trees {
            if parameters.bootstrap {
                samples =
                    BaseForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
            }

            // Keep the in-bag mask if the flag is on; rows with a zero count are out-of-bag.
            if let Some(ref mut all_samples) = maybe_all_samples {
                all_samples.push(samples.iter().map(|x| *x != 0).collect())
            }

            let params = BaseTreeRegressorParameters {
                max_depth: parameters.max_depth,
                min_samples_leaf: parameters.min_samples_leaf,
                min_samples_split: parameters.min_samples_split,
                seed: Some(parameters.seed),
                splitter: parameters.splitter.clone(),
            };
            let tree = BaseTreeRegressor::fit_weak_learner(x, y, samples.clone(), mtry, params)?;
            trees.push(tree);
        }

        Ok(BaseForestRegressor {
            trees: Some(trees),
            samples: maybe_all_samples,
        })
    }

    /// Predict regression values for `x`.
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let (n, _) = x.shape();
        let mut result = Y::zeros(n);

        for i in 0..n {
            result.set(i, self.predict_for_row(x, i));
        }

        Ok(result)
    }

    // Average of the per-tree predictions for a single row.
    fn predict_for_row(&self, x: &X, row: usize) -> TY {
        let n_trees = self.trees.as_ref().unwrap().len();

        let mut result = TY::zero();

        for tree in self.trees.as_ref().unwrap().iter() {
            result += tree.predict_for_row(x, row);
        }

        result / TY::from_usize(n_trees).unwrap()
    }

    /// Predict OOB values for `x`. `x` is expected to be equal to the dataset used in training.
    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
        let (n, _) = x.shape();
        if self.samples.is_none() {
            Err(Failed::because(
                FailedError::PredictFailed,
                "Need keep_samples=true for OOB predictions.",
            ))
        } else if self.samples.as_ref().unwrap()[0].len() != n {
            Err(Failed::because(
                FailedError::PredictFailed,
                "Prediction matrix must match matrix used in training for OOB predictions.",
            ))
        } else {
            let mut result = Y::zeros(n);

            for i in 0..n {
                result.set(i, self.predict_for_row_oob(x, i));
            }

            Ok(result)
        }
    }

    // Average over only the trees whose bootstrap sample did not contain this row.
    fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
        let mut n_trees = 0;
        let mut result = TY::zero();

        for (tree, samples) in self
            .trees
            .as_ref()
            .unwrap()
            .iter()
            .zip(self.samples.as_ref().unwrap())
        {
            if !samples[row] {
                result += tree.predict_for_row(x, row);
                n_trees += 1;
            }
        }

        // TODO: What to do if there are no OOB trees? With n_trees == 0 this
        // division is undefined (NaN/inf for float targets).
        result / TY::from(n_trees).unwrap()
    }

    // Draw nrows row indices with replacement; the result holds, for each row,
    // the number of times it was drawn (0 means the row is out-of-bag).
    fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
        let mut samples = vec![0; nrows];
        for _ in 0..nrows {
            let xi = rng.gen_range(0..nrows);
            samples[xi] += 1;
        }
        samples
    }
}
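
To make the flow concrete, a usage sketch under stated assumptions: `DenseMatrix` (path assumed) is the crate's dense `Array2` implementation and is not part of this diff, `Vec<f64>` is assumed to implement `Array1<f64>`, and `Splitter::Random` is an assumed variant name. Note that mod.rs below declares this module with `mod`, not `pub mod`, so in practice the base is reached through the public random forest types; the sketch only illustrates the component's flow. With `bootstrap: true` each tree trains on a multiset of rows drawn by `sample_with_replacement`, and with `keep_samples: true` the in-bag masks are kept so `predict_oob` can average, for each row, only the trees that never saw it:

use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::matrix::DenseMatrix; // assumed path, not shown in this diff
use crate::tree::base_tree_regressor::Splitter;

fn example() -> Result<(), Failed> {
    // Tiny toy dataset: 4 rows, 2 features.
    let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
        &[1.0, 2.0],
        &[2.0, 3.0],
        &[3.0, 4.0],
        &[4.0, 5.0],
    ]);
    let y: Vec<f64> = vec![3.0, 5.0, 7.0, 9.0];

    let params = BaseForestRegressorParameters {
        max_depth: None,
        min_samples_leaf: 1,
        min_samples_split: 2,
        n_trees: 50,
        m: None,
        keep_samples: true, // required for predict_oob below
        seed: 42,
        bootstrap: true,
        splitter: Splitter::Random, // assumed variant name
    };

    let forest = BaseForestRegressor::fit(&x, &y, params)?;
    let _y_hat = forest.predict(&x)?;     // per row: mean over all trees
    let _y_oob = forest.predict_oob(&x)?; // per row: mean over out-of-bag trees only
    Ok(())
}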

src/ensemble/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
 //!
 //! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)
 
+mod base_forest_regressor;
 /// Random forest classifier
 pub mod random_forest_classifier;
 /// Random forest regressor
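
The remaining changed files (the −157 lines elsewhere) are not shown on this page, but per the commit message the point of the refactor is that the public regressor can become a thin wrapper over this reusable base. A hypothetical sketch of such a delegating wrapper — the names and the extent of the delegation are illustrative, not taken from this diff:

use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;

/// Hypothetical public wrapper: owns a base forest and forwards calls to it.
pub struct RandomForestRegressor<
    TX: Number + FloatNumber + PartialOrd,
    TY: Number,
    X: Array2<TX>,
    Y: Array1<TY>,
> {
    base: BaseForestRegressor<TX, TY, X, Y>,
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    RandomForestRegressor<TX, TY, X, Y>
{
    pub fn fit(x: &X, y: &Y, parameters: BaseForestRegressorParameters) -> Result<Self, Failed> {
        // All fitting logic lives in the shared base component.
        let base = BaseForestRegressor::fit(x, y, parameters)?;
        Ok(Self { base })
    }

    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.base.predict(x)
    }
}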

0 commit comments