Skip to content

Commit 320a2db

Browse files
committed
predict script should reuse the Ensemble code
1 parent 50ea125 commit 320a2db

File tree

4 files changed

+45
-101
lines changed

4 files changed

+45
-101
lines changed

predict/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
# -*- coding: utf-8 -*-
33
__author__ = 'maxim'
44

5-
from ensemble import Ensemble
5+
from ensemble import Ensemble, predict_multiple
66
from model_io import get_model_info, ModelNotAvailable

predict/ensemble.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
# -*- coding: utf-8 -*-
33
__author__ = 'maxim'
44

5+
from itertools import izip, count
56
import os
67

78
import numpy as np
9+
import pandas as pd
810

911
from train.job_info import parse_model_infos
1012
from util import *
@@ -53,3 +55,32 @@ def ensemble_top_models(job_info, top_n=5):
5355
models = [get_model_info(path, strict=False) for path in model_paths]
5456
top_models = [model for model in models if model.is_available()][:top_n]
5557
return Ensemble(top_models)
58+
59+
60+
def predict_multiple(job_info, raw_df, rows_to_predict, top_models_num=5):
    """Predict the last `rows_to_predict` target values with an ensemble.

    The raw frame is converted to changes with `to_changes`; the ensemble
    returned by `Ensemble.ensemble_top_models` predicts one change per
    trailing row, and each predicted change is turned back into an absolute
    value using the preceding raw target value.

    Args:
      job_info: job descriptor; its `name` and `target` attributes are read.
      raw_df: raw data frame containing the `job_info.target` column; the
        frame produced by `to_changes(raw_df)` must also expose `date`.
      rows_to_predict: number of trailing rows to predict; must be >= 1.
      top_models_num: forwarded as `top_n` to `Ensemble.ensemble_top_models`.

    Returns:
      A `pandas.DataFrame` indexed by 'Time' with the columns 'Prediction'
      and 'True'.

    Raises:
      ValueError: if `rows_to_predict` is less than 1.
    """
    # A zero count would turn the `[-rows_to_predict:]` slices below into `[0:]`
    # (the whole series) and a negative one would slice from the front, so the
    # dates, changes and raw values would no longer line up.
    if rows_to_predict < 1:
        raise ValueError('rows_to_predict must be >= 1, got %r' % (rows_to_predict,))

    debug('Predicting %s target=%s' % (job_info.name, job_info.target))

    # One extra raw row is kept: raw_targets[idx] is the base value for the change at position idx.
    raw_targets = raw_df[job_info.target][-(rows_to_predict + 1):].reset_index(drop=True)
    changes_df = to_changes(raw_df)
    target_changes = changes_df[job_info.target][-rows_to_predict:].reset_index(drop=True)
    dates = changes_df.date[-rows_to_predict:].reset_index(drop=True)

    df = changes_df[:-1]  # the data for models is shifted by one: the target for the last row is unknown

    ensemble = Ensemble.ensemble_top_models(job_info, top_n=top_models_num)
    predictions = ensemble.predict_aggregated(df, last_rows=rows_to_predict)

    result = []
    for idx, (date, prediction_change, target_change) in enumerate(izip(dates, predictions, target_changes)):
        debug('%%-change on %s: predict=%+.5f target=%+.5f' % (date, prediction_change, target_change))

        # target_change is approx. raw_targets[idx + 1] / raw_targets[idx] - 1.0
        raw_target = raw_targets[idx + 1]
        raw_predicted = (1 + prediction_change) * raw_targets[idx]
        debug('   value on %s: predict= %.5f target= %.5f' % (date, raw_predicted, raw_target))

        result.append({'Time': date, 'Prediction': raw_predicted, 'True': raw_target})

    # Explicit columns keep `set_index` working even when nothing was predicted.
    result_df = pd.DataFrame(result, columns=['Time', 'Prediction', 'True'])
    result_df.set_index('Time', inplace=True)
    return result_df

run_predict.py

Lines changed: 8 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,80 +2,28 @@
22
# -*- coding: utf-8 -*-
33
__author__ = 'maxim'
44

5-
6-
import os
7-
import numpy as np
8-
95
import poloniex
106
from predict import *
117
from train import *
12-
from train.evaluator import Evaluator
138
from util import *
149

15-
16-
def try_model(path, data_dir='_data', zoo_dir='_zoo'):
    """Restore the saved model at `path` and log its evaluation result.

    The run parameters stored with the model ('name', 'target', 'k') select
    the source data, which is read, converted to changes and turned into a
    data set before the restored model is evaluated on it.

    Args:
      path: location of the saved model, as understood by `get_model_info`.
      data_dir: first argument passed to `JobInfo`.
      zoo_dir: second argument passed to `JobInfo`.
    """
    meta = get_model_info(path)
    params = meta.run_params

    job = JobInfo(data_dir, zoo_dir, params['name'], params['target'])
    changes = to_changes(read_df(job.get_source_name()))
    model_cls = meta.model_class
    data_set = to_dataset(changes, params['k'], params['target'], model_cls.DATA_WITH_BIAS)

    model = model_cls(**meta.model_params)
    evaluator = Evaluator()

    with model.session():
        model.restore(meta.path)
        test_eval, test_stats = evaluator.eval(model, data_set)
        summary = evaluator.stats_str(test_stats)
        info('Result:\n%sEval=%.6f\n' % (summary, test_eval))
31-
32-
33-
def predict_model(changes_df, path):
    """Restore the model saved at `path` and predict one value from `changes_df`.

    Args:
      changes_df: data frame of changes; its last row is excluded from the
        model input (`changes_df[:-1]`).
      path: location of the saved model, as understood by `get_model_info`.

    Returns:
      The model prediction for the last prepared sample, as a float.
    """
    meta = get_model_info(path)
    params = meta.run_params
    model_cls = meta.model_class
    model = model_cls(**meta.model_params)

    # Only the final sample of the prepared data set is fed to the model.
    samples = to_dataset_for_prediction(changes_df[:-1], params['k'], model_cls.DATA_WITH_BIAS)
    last_sample = samples[-1:]

    with model.session():
        model.restore(meta.path)
        predicted = float(model.predict(last_sample))
        info('Predicted change=%.5f' % predicted)
        return predicted
45-
46-
47-
def predict_all_models(changes_df, name, accept):
    """Run every accepted model under `_zoo/<name>` and log the mean prediction.

    Args:
      changes_df: data frame of changes, forwarded to `predict_model`.
      name: sub-directory of `_zoo` whose entries are candidate models.
      accept: callable taking a directory entry name and returning a truthy
        value when that entry should be used.

    Returns:
      None. Results are reported through `info` / `warn` only.
    """
    home_dir = '_zoo/%s' % name
    models = [entry for entry in os.listdir(home_dir) if accept(entry)]
    if not models:
        info('No models found for %s' % name)
        return

    predictions = []
    for model in models:
        try:
            value = predict_model(changes_df, os.path.join(home_dir, model))
        except ModelNotAvailable as e:
            warn('Cannot use model from "%s": class "%s" is not available on this system' % (model, e.model_class))
            warn('Most probable reason is that model dependencies are not met')
        else:
            predictions.append(value)

    # np.mean of an empty list is nan (with a RuntimeWarning), so report the
    # situation explicitly instead of logging a meaningless mean.
    if not predictions:
        warn('None of the models found for %s could be used' % name)
        return

    info()
    info('Mean predicted value for %s: %.5f' % (name, np.mean(predictions)))
    info()
65-
66-
6710
def main():
68-
tickers, periods, targets = parse_command_line(default_tickers=[],
11+
tickers, periods, targets = parse_command_line(default_tickers=['BTC_ETH'],
6912
default_periods=['day'],
7013
default_targets=['high'])
7114

7215
for ticker in tickers:
7316
for period in periods:
7417
for target in targets:
18+
job = JobInfo('_data', '_zoo', name='%s_%s' % (ticker, period), target=target)
7519
raw_df = poloniex.get_latest_data(ticker, period=period, depth=100)
76-
changes_df = to_changes(raw_df)
77-
predict_all_models(changes_df, '%s_%s' % (ticker, period), lambda name: name.startswith('%s_' % target))
20+
result_df = predict_multiple(job, raw_df=raw_df, rows_to_predict=1)
21+
22+
raw_df.set_index('date', inplace=True)
23+
result_df.rename(columns={"True": "Current-Truth"}, inplace=True)
7824

25+
info('Latest chart info:', raw_df.tail(2), '', sep='\n')
26+
info('Prediction for "%s":' % target, result_df, '', sep='\n')
7927

8028
if __name__ == '__main__':
8129
main()

run_visual.py

Lines changed: 5 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,50 +2,14 @@
22
# -*- coding: utf-8 -*-
33
__author__ = 'maxim'
44

5-
from itertools import izip, count
6-
75
import matplotlib.pyplot as plt
8-
import pandas as pd
96

10-
from predict import *
11-
from train import *
12-
from util import *
7+
from predict import predict_multiple
8+
from train import JobInfo
9+
from util import parse_command_line, read_df
1310

1411
plt.style.use('ggplot')
1512

16-
17-
def predict_multiple(job_info, last_rows):
    """Predict the last `last_rows` target values of the job's source data.

    The source named by `job_info.get_source_name()` is read with `read_df`,
    converted to changes with `to_changes`, and the ensemble returned by
    `Ensemble.ensemble_top_models` predicts one change per trailing row. Each
    predicted change is turned back into an absolute value using the
    preceding raw target value.

    Args:
      job_info: job descriptor; its `name` and `target` attributes and
        `get_source_name()` are used.
      last_rows: number of trailing rows to predict; must be >= 1.

    Returns:
      A `pandas.DataFrame` with the columns 'Prediction' and 'True', indexed
      by the prediction time (the index name is set to an empty string).

    Raises:
      ValueError: if `last_rows` is less than 1.
    """
    # A zero count would turn the `[-last_rows:]` slices below into `[0:]` (the
    # whole series) and a negative one would slice from the front, so the
    # dates, changes and raw values would no longer line up.
    if last_rows < 1:
        raise ValueError('last_rows must be >= 1, got %r' % (last_rows,))

    debug('Predicting %s target=%s' % (job_info.name, job_info.target))

    raw_df = read_df(job_info.get_source_name())
    # One extra raw row is kept: raw_targets[idx] is the base value for the change at position idx.
    raw_targets = raw_df[job_info.target][-(last_rows + 1):].reset_index(drop=True)

    changes_df = to_changes(raw_df)
    target_changes = changes_df[job_info.target][-last_rows:].reset_index(drop=True)
    dates = changes_df.date[-last_rows:].reset_index(drop=True)

    df = changes_df[:-1]  # the data for models is shifted by one: the target for the last row is unknown

    ensemble = Ensemble.ensemble_top_models(job_info)
    predictions = ensemble.predict_aggregated(df, last_rows=last_rows)

    result = []
    for idx, (date, prediction_change, target_change) in enumerate(izip(dates, predictions, target_changes)):
        debug('%%-change on %s: predict=%+.5f target=%+.5f' % (date, prediction_change, target_change))

        # target_change is approx. raw_targets[idx + 1] / raw_targets[idx] - 1.0
        raw_target = raw_targets[idx + 1]
        raw_predicted = (1 + prediction_change) * raw_targets[idx]
        debug('   value on %s: predict= %.5f target= %.5f' % (date, raw_predicted, raw_target))

        result.append({'Time': date, 'Prediction': raw_predicted, 'True': raw_target})

    # Explicit columns keep `set_index` working even when nothing was predicted.
    result_df = pd.DataFrame(result, columns=['Time', 'Prediction', 'True'])
    result_df.set_index('Time', inplace=True)
    result_df.index.names = ['']
    return result_df
47-
48-
4913
def main():
5014
train_date = None
5115
tickers, periods, targets = parse_command_line(default_tickers=['BTC_ETH', 'BTC_LTC'],
@@ -56,7 +20,8 @@ def main():
5620
for period in periods:
5721
for target in targets:
5822
job = JobInfo('_data', '_zoo', name='%s_%s' % (ticker, period), target=target)
59-
result_df = predict_multiple(job, last_rows=120)
23+
result_df = predict_multiple(job, raw_df=read_df(job.get_source_name()), rows_to_predict=120)
24+
result_df.index.names = ['']
6025
result_df.plot(title=job.name)
6126

6227
if train_date is not None:

0 commit comments

Comments
 (0)