Skip to content

Commit 2fc8412

Browse files
Pull request "add CatBoostRanker in the ranking tutorial" by @AnnaAraslanova from catboost/catboost#1704
MERGED FROM catboost/catboost#1704 ref:3ec9b295e347d91a16dc7409794bd57ffa7c68ee
1 parent 2a1fd8a commit 2fc8412

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

ranking/ranking_tutorial.ipynb

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
},
2323
"outputs": [],
2424
"source": [
25-
"from catboost import CatBoost, Pool, MetricVisualizer\n",
25+
"from catboost import CatBoostRanker, Pool, MetricVisualizer\n",
2626
"from copy import deepcopy\n",
2727
"import numpy as np\n",
2828
"import os\n",
@@ -365,7 +365,7 @@
365365
" if additional_params is not None:\n",
366366
" parameters.update(additional_params)\n",
367367
" \n",
368-
" model = CatBoost(parameters)\n",
368+
" model = CatBoostRanker(**parameters)\n",
369369
" model.fit(train_pool, eval_set=test_pool, plot=True)\n",
370370
" \n",
371371
" return model"
@@ -395,9 +395,9 @@
395395
"source": [
396396
"### Group weights parameter\n",
397397
"Suppose we know that some queries are more important than others for us.<br/>\n",
398-
"The word \"importance\" used here in terms of accuracy or quality of CatBoost prediction for given queries.<br/>\n",
398+
"The word \"importance\" used here in terms of accuracy or quality of CatBoostRanker prediction for given queries.<br/>\n",
399399
"You can pass this additional information for learner using a ``group_weights`` parameter.<br/>\n",
400-
"Under the hood, CatBoost uses this weights in loss function simply multiplying it on a group summand.<br/>\n",
400+
"Under the hood, CatBoostRanker uses this weights in loss function simply multiplying it on a group summand.<br/>\n",
401401
"So the bigger weight $\\rightarrow$ the more attention for query.<br/>\n",
402402
"Let's show an example of training procedure with random query weights."
403403
]
@@ -450,7 +450,7 @@
450450
"### A special case: top-1 prediction\n",
451451
"\n",
452452
"Someday you may face with a problem $-$ you will need to predict the top one most relevant object for a given query.<br/>\n",
453-
"For this purpose CatBoost has a mode called __QuerySoftMax__.\n",
453+
"For this purpose CatBoostRanker has a mode called __QuerySoftMax__.\n",
454454
"\n",
455455
"Suppose our dataset contain a binary target: 1 $-$ mean best document for a query, 0 $-$ others.<br/>\n",
456456
"We will maximize the probability of being the best document for given query.<br/>\n",
@@ -572,7 +572,7 @@
572572
"\n",
573573
"$$ - \\sum_{i,j \\in Pairs} \\log \\left( \\frac{1}{1 + \\exp{-(f(d_i) - f(d_j))}} \\right) $$\n",
574574
"\n",
575-
"Methods based on pair comparisons called __pairwise__ in CatBoost this objective called __PairLogit__.\n",
575+
"Methods based on pair comparisons called __pairwise__ in CatBoostRanker this objective called __PairLogit__.\n",
576576
"\n",
577577
"There's no need to change the dataset CatBoost generate the pairs for us. The number of generating pairs managed via parameter max_size."
578578
]
@@ -615,7 +615,7 @@
615615
" for doc_id, line in enumerate(f):\n",
616616
" line = line.split(',')[:2]\n",
617617
" \n",
618-
" label, query_id = tuple(map(float, line))\n",
618+
" label, query_id = float(line[0]), int(line[1])\n",
619619
" if query_id not in groups:\n",
620620
" groups[query_id] = []\n",
621621
" groups[query_id].append((doc_id, label))\n",

0 commit comments

Comments
 (0)