Skip to content

Commit 788bc6f

Browse files
authored
Merge pull request #173 from AnMakc/regression_test
Add regression tests
2 parents eeb3168 + 94f3212 commit 788bc6f

File tree

5 files changed

+2750
-0
lines changed

5 files changed

+2750
-0
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import random
2+
from pathlib import Path
3+
4+
import numpy as np
5+
import pandas as pd
6+
import torch
7+
8+
from model import Kronos, KronosPredictor, KronosTokenizer
9+
10+
11+
TEST_DATA_ROOT = Path(__file__).parent
12+
INPUT_DATA_PATH = TEST_DATA_ROOT / "regression_input.csv"
13+
OUTPUT_DATA_DIR = TEST_DATA_ROOT
14+
MAX_CTX_LEN = 512
15+
TEST_CTX_LEN = [512, 256]
16+
PRED_LEN = 8
17+
FEATURE_NAMES = ["open", "high", "low", "close", "volume", "amount"]
18+
19+
MODEL_REVISION = "901c26c1332695a2a8f243eb2f37243a37bea320"
20+
TOKENIZER_REVISION = "0e0117387f39004a9016484a186a908917e22426"
21+
SEED = 123
22+
23+
DEVICE = "cpu"
24+
25+
26+
def set_seed(seed: int) -> None:
27+
random.seed(seed)
28+
np.random.seed(seed)
29+
torch.manual_seed(seed)
30+
if torch.backends.cudnn.is_available():
31+
torch.backends.cudnn.deterministic = True
32+
torch.backends.cudnn.benchmark = False
33+
34+
35+
def generate_output(ctx_len: int) -> None:
36+
if ctx_len > MAX_CTX_LEN:
37+
raise ValueError(
38+
f"Context length for output generation ({ctx_len}) "
39+
f"cannot exceed maximum context length ({MAX_CTX_LEN})."
40+
)
41+
42+
context_df = df.iloc[:ctx_len].copy()
43+
future_timestamps = df["timestamps"].iloc[
44+
ctx_len : ctx_len + PRED_LEN
45+
].reset_index(drop=True)
46+
47+
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base", revision=TOKENIZER_REVISION)
48+
model = Kronos.from_pretrained("NeoQuasar/Kronos-small", revision=MODEL_REVISION)
49+
tokenizer.eval()
50+
model.eval()
51+
52+
predictor = KronosPredictor(
53+
model, tokenizer, device=DEVICE, max_context=MAX_CTX_LEN
54+
)
55+
56+
with torch.no_grad():
57+
pred_df = predictor.predict(
58+
df=context_df[FEATURE_NAMES].reset_index(drop=True),
59+
x_timestamp=context_df["timestamps"].reset_index(drop=True),
60+
y_timestamp=future_timestamps,
61+
pred_len=PRED_LEN,
62+
T=1.0,
63+
top_k=1,
64+
top_p=1.0,
65+
verbose=False,
66+
sample_count=1,
67+
)
68+
69+
if pred_df.shape != (PRED_LEN, len(FEATURE_NAMES)):
70+
raise ValueError(f"Unexpected prediction shape: {pred_df.shape}")
71+
72+
output_df = pred_df.reset_index(drop=True)
73+
output_df["timestamps"] = future_timestamps
74+
output_df = output_df[["timestamps"] + FEATURE_NAMES]
75+
output_df.to_csv(OUTPUT_DATA_DIR / f"regression_output_{ctx_len}.csv", index=False)
76+
print(f"Saved {ctx_len} fixture to {OUTPUT_DATA_DIR / f'regression_output_{ctx_len}.csv'}")
77+
78+
79+
if __name__ == "__main__":
80+
set_seed(SEED)
81+
82+
83+
df = pd.read_csv(INPUT_DATA_PATH, parse_dates=["timestamps"])
84+
if df.shape[0] < MAX_CTX_LEN + PRED_LEN:
85+
raise ValueError(
86+
f"Input data must have at least {MAX_CTX_LEN + PRED_LEN} rows, "
87+
f"found {df.shape[0]} instead."
88+
)
89+
90+
for ctx_len in TEST_CTX_LEN:
91+
generate_output(ctx_len)

0 commit comments

Comments
 (0)