Skip to content

Commit 52bf301

Browse files
author
ivankozlov98
committed
MLTOOLS-4119: add tutorial about using ONNX models in CatBoost
ref:7991db7cc612c7762c46148c0922960aa47cd474
1 parent 0f2efbc commit 52bf301

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Using ONNX models in CatBoost\n",
8+
"\n",
9+
"It is easy to apply ONNX models using CatBoost.\n",
10+
"+ Save your model in the ONNX format\n",
11+
"+ Load the ONNX model into CatBoost using the load_model() method\n",
12+
"+ Apply your model in CatBoost using the predict() method\n",
13+
"\n",
14+
"Let us follow this scenario step-by-step for a LightGBM model."
15+
]
16+
},
17+
{
18+
"cell_type": "markdown",
19+
"metadata": {},
20+
"source": [
21+
"\n",
22+
"Download the MSRank dataset and import the necessary packages:"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 1,
28+
"metadata": {},
29+
"outputs": [],
30+
"source": [
31+
"from catboost import datasets, CatBoostRegressor\n",
32+
"\n",
33+
"from lightgbm import LGBMRegressor\n",
34+
"\n",
35+
"import onnxmltools\n",
36+
"from onnxconverter_common import *\n",
37+
"\n",
38+
"\n",
39+
"train_df, _ = datasets.msrank()\n",
40+
"X, Y = train_df[train_df.columns[1:]], train_df[train_df.columns[0]]"
41+
]
42+
},
43+
{
44+
"cell_type": "markdown",
45+
"metadata": {},
46+
"source": [
47+
"\n",
48+
"Build a model:"
49+
]
50+
},
51+
{
52+
"cell_type": "code",
53+
"execution_count": 2,
54+
"metadata": {},
55+
"outputs": [
56+
{
57+
"name": "stdout",
58+
"output_type": "stream",
59+
"text": [
60+
"[1.30604501 1.60390655 0.35207384 ... 1.18672199 0.55631924 0.54655847]\n"
61+
]
62+
}
63+
],
64+
"source": [
65+
"model = LGBMRegressor()\n",
66+
"model.fit(X, Y)\n",
67+
"predict = model.predict(X)\n",
68+
"print(predict)"
69+
]
70+
},
71+
{
72+
"cell_type": "markdown",
73+
"metadata": {},
74+
"source": [
75+
"\n",
76+
"Save the model in the ONNX format:"
77+
]
78+
},
79+
{
80+
"cell_type": "code",
81+
"execution_count": 3,
82+
"metadata": {
83+
"scrolled": false
84+
},
85+
"outputs": [
86+
{
87+
"name": "stderr",
88+
"output_type": "stream",
89+
"text": [
90+
"The maximum opset needed by this model is only 1.\n",
91+
"The maximum opset needed by this model is only 1.\n"
92+
]
93+
}
94+
],
95+
"source": [
96+
"features_count = len(X.columns)\n",
97+
"onnx_model = onnxmltools.convert_lightgbm(model, name='LightGBM', initial_types=[['input', FloatTensorType([0, features_count])]])\n",
98+
"onnxmltools.utils.save_model(onnx_model, 'model.onnx')"
99+
]
100+
},
101+
{
102+
"cell_type": "markdown",
103+
"metadata": {},
104+
"source": [
105+
"\n",
106+
"Load the ONNX model into CatBoost and compare the CatBoost and LightGBM predictions:"
107+
]
108+
},
109+
{
110+
"cell_type": "code",
111+
"execution_count": 4,
112+
"metadata": {},
113+
"outputs": [
114+
{
115+
"name": "stdout",
116+
"output_type": "stream",
117+
"text": [
118+
"[1.30604502 1.60390654 0.35207381 ... 1.18672202 0.55631925 0.54655849]\n"
119+
]
120+
}
121+
],
122+
"source": [
123+
"catboost_model = CatBoostRegressor()\n",
124+
"catboost_model.load_model('model.onnx', format='onnx')\n",
125+
"catboost_predict = catboost_model.predict(X)\n",
126+
"print(catboost_predict)"
127+
]
128+
}
129+
],
130+
"metadata": {
131+
"kernelspec": {
132+
"display_name": "Python 3",
133+
"language": "python",
134+
"name": "python3"
135+
},
136+
"language_info": {
137+
"codemirror_mode": {
138+
"name": "ipython",
139+
"version": 3
140+
},
141+
"file_extension": ".py",
142+
"mimetype": "text/x-python",
143+
"name": "python",
144+
"nbconvert_exporter": "python",
145+
"pygments_lexer": "ipython3",
146+
"version": "3.7.3"
147+
}
148+
},
149+
"nbformat": 4,
150+
"nbformat_minor": 2
151+
}

0 commit comments

Comments
 (0)