Skip to content
Prev Previous commit
Next Next commit
Spacing
  • Loading branch information
LukeMathWalker committed Jul 17, 2019
commit 53f8124511c1a884e048d76bcb692e58b72458f7
3 changes: 3 additions & 0 deletions examples/linear_regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,14 @@ pub fn main() {
let n_train_samples = 5000;
let n_test_samples = 1000;
let n_features = 3;

let (X, y) = get_data(n_train_samples + n_test_samples, n_features);
let (X_train, X_test) = X.view().split_at(Axis(0), n_train_samples);
let (y_train, y_test) = y.view().split_at(Axis(0), n_train_samples);

let mut linear_regressor = LinearRegression::new(false);
linear_regressor.fit(X_train, y_train);

let test_predictions = linear_regressor.predict(&X_test);
let mean_squared_error = test_predictions.mean_sq_err(&y_test.to_owned()).unwrap();
println!("Beta estimated from the training data: {:.3}", linear_regressor.beta.unwrap());
Expand Down