Skip to content

Commit d4f558b

Browse files
Pull request "add prediction diff for nonsymmetric trees" by @felixandrer from catboost/catboost#1237
MERGED FROM catboost/catboost#1237 ref:faf9cfb196f2f35f93640d56ccd812800c2e2759
1 parent efa1073 commit d4f558b

File tree

1 file changed

+168
-0
lines changed

1 file changed

+168
-0
lines changed
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# $$CatBoost\\ Feature\\ Importance\\ Tutorial$$"
8+
]
9+
},
10+
{
11+
"cell_type": "markdown",
12+
"metadata": {},
13+
"source": [
14+
"#### Sometimes it is very important to understand which feature made the greatest contribution to the final result. To do this, the CatBoost model has a get_feature_importance method."
15+
]
16+
},
17+
{
18+
"cell_type": "code",
19+
"execution_count": 3,
20+
"metadata": {},
21+
"outputs": [],
22+
"source": [
23+
"import numpy as np\n",
24+
"from catboost import CatBoost, Pool, datasets\n",
25+
"from sklearn.model_selection import train_test_split"
26+
]
27+
},
28+
{
29+
"cell_type": "markdown",
30+
"metadata": {
31+
"collapsed": true
32+
},
33+
"source": [
34+
"#### First, let's prepare the dataset:"
35+
]
36+
},
37+
{
38+
"cell_type": "code",
39+
"execution_count": 4,
40+
"metadata": {},
41+
"outputs": [],
42+
"source": [
43+
"train_df, _ = datasets.higgs()"
44+
]
45+
},
46+
{
47+
"cell_type": "code",
48+
"execution_count": 5,
49+
"metadata": {},
50+
"outputs": [],
51+
"source": [
52+
"X, y = np.array(train_df.drop(0, axis=1))[:1000], np.array(train_df[0])[:1000]\n",
53+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)\n",
54+
"train_pool = Pool(X_train, y_train)\n",
55+
"test_pool = Pool(X_test, y_test)"
56+
]
57+
},
58+
{
59+
"cell_type": "markdown",
60+
"metadata": {},
61+
"source": [
62+
"#### Let's train CatBoost:"
63+
]
64+
},
65+
{
66+
"cell_type": "code",
67+
"execution_count": 6,
68+
"metadata": {
69+
"scrolled": true
70+
},
71+
"outputs": [],
72+
"source": [
73+
"cb = CatBoost({'iterations': 20, 'verbose': False, 'random_seed': 42, 'grow_policy': 'Lossguide'})\n",
74+
"cb.fit(train_pool);"
75+
]
76+
},
77+
{
78+
"cell_type": "markdown",
79+
"metadata": {},
80+
"source": [
81+
"#### Catboost provides several types of feature importances. One of them is PredictionDiff: A vector with contributions of each feature to the RawFormulaVal difference for each pair of objects."
82+
]
83+
},
84+
{
85+
"cell_type": "markdown",
86+
"metadata": {},
87+
"source": [
88+
"#### Let's find two objects with incorrect labels on test data:"
89+
]
90+
},
91+
{
92+
"cell_type": "code",
93+
"execution_count": 16,
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"prediction = np.argmax(cb.predict(X_test, prediction_type='Probability'), axis=1)"
98+
]
99+
},
100+
{
101+
"cell_type": "code",
102+
"execution_count": 20,
103+
"metadata": {},
104+
"outputs": [],
105+
"source": [
106+
"wrong_prediction_idxs = np.arange(prediction.size)[y_test != prediction]\n",
107+
"test_pool_slice = test_pool.slice(wrong_prediction_idxs[:2])"
108+
]
109+
},
110+
{
111+
"cell_type": "markdown",
112+
"metadata": {},
113+
"source": [
114+
"#### Let's calculate PredictionDiff for these two objects:"
115+
]
116+
},
117+
{
118+
"cell_type": "code",
119+
"execution_count": 37,
120+
"metadata": {},
121+
"outputs": [
122+
{
123+
"name": "stdout",
124+
"output_type": "stream",
125+
"text": [
126+
"22: 0.590958854452\n",
127+
"25: 0.706977071538\n"
128+
]
129+
}
130+
],
131+
"source": [
132+
"prediction_diff = cb.get_feature_importance(type='PredictionDiff', data=test_pool_slice)\n",
133+
"\n",
134+
"for feature_id, diff in np.ndenumerate(prediction_diff):\n",
135+
" if diff > 0.:\n",
136+
" print('{}: {}'.format(feature_id[0], diff))"
137+
]
138+
},
139+
{
140+
"cell_type": "markdown",
141+
"metadata": {},
142+
"source": [
143+
"#### As you can see, feature 25 is most important for getting the right prediction."
144+
]
145+
}
146+
],
147+
"metadata": {
148+
"kernelspec": {
149+
"display_name": "Python 2",
150+
"language": "python",
151+
"name": "python2"
152+
},
153+
"language_info": {
154+
"codemirror_mode": {
155+
"name": "ipython",
156+
"version": 2
157+
},
158+
"file_extension": ".py",
159+
"mimetype": "text/x-python",
160+
"name": "python",
161+
"nbconvert_exporter": "python",
162+
"pygments_lexer": "ipython2",
163+
"version": "2.7.17"
164+
}
165+
},
166+
"nbformat": 4,
167+
"nbformat_minor": 1
168+
}

0 commit comments

Comments
 (0)