|
22 | 22 | }, |
23 | 23 | "outputs": [], |
24 | 24 | "source": [ |
25 | | - "from catboost import CatBoost, Pool, MetricVisualizer\n", |
| 25 | + "from catboost import CatBoostRanker, Pool, MetricVisualizer\n", |
26 | 26 | "from copy import deepcopy\n", |
27 | 27 | "import numpy as np\n", |
28 | 28 | "import os\n", |
|
365 | 365 | " if additional_params is not None:\n", |
366 | 366 | " parameters.update(additional_params)\n", |
367 | 367 | " \n", |
368 | | - " model = CatBoost(parameters)\n", |
| 368 | + " model = CatBoostRanker(**parameters)\n", |
369 | 369 | " model.fit(train_pool, eval_set=test_pool, plot=True)\n", |
370 | 370 | " \n", |
371 | 371 | " return model" |
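As a minimal sketch of the pattern used in the cell above (a parameter dict unpacked into `CatBoostRanker`, then `fit` with an evaluation set), here is a self-contained example on synthetic data; the data, parameter values, and variable names are assumptions for illustration and are not taken from the notebook.

```python
# Minimal sketch on synthetic data (assumed values, not from the notebook):
# build a parameter dict, unpack it into CatBoostRanker, and fit with eval_set.
import numpy as np
from catboost import CatBoostRanker, Pool

rng = np.random.default_rng(0)

def make_pool(n_docs, n_queries, n_features=6):
    X = rng.normal(size=(n_docs, n_features))
    y = rng.random(n_docs)
    group_id = np.sort(rng.integers(0, n_queries, size=n_docs))  # documents grouped by query
    return Pool(data=X, label=y, group_id=group_id)

train_pool = make_pool(300, 30)
test_pool = make_pool(100, 10)

parameters = {
    'loss_function': 'YetiRank',  # any ranking objective can be passed through the same dict
    'iterations': 100,
    'verbose': False,
    'random_seed': 0,
}

model = CatBoostRanker(**parameters)
model.fit(train_pool, eval_set=test_pool)
print(model.best_score_)
```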
|
395 | 395 | "source": [ |
396 | 396 | "### Group weights parameter\n", |
397 | 397 | "Suppose we know that some queries are more important to us than others.<br/>\n", |
398 | | - "The word \"importance\" used here in terms of accuracy or quality of CatBoost prediction for given queries.<br/>\n", |
| 398 | + "The word \"importance\" is used here in terms of the accuracy or quality of CatBoostRanker predictions for the given queries.<br/>\n", |
399 | 399 | "You can pass this additional information to the learner using the ``group_weights`` parameter.<br/>\n", |
400 | | - "Under the hood, CatBoost uses this weights in loss function simply multiplying it on a group summand.<br/>\n", |
| 400 | + "Under the hood, CatBoostRanker uses these weights in the loss function by simply multiplying each group summand by its weight.<br/>\n", |
401 | 401 | "So the bigger the weight $\rightarrow$ the more attention the query gets.<br/>\n", |
402 | 402 | "Let's show an example of the training procedure with random query weights." |
403 | 403 | ] |
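A short sketch of how per-query weights can be attached in practice, using synthetic data: the keyword on the `Pool` constructor is `group_weight`, and every document belonging to the same query must carry the same weight value. All names and values below are illustrative assumptions.

```python
# Sketch (synthetic data, assumed values): attach a per-query weight to each
# document through the Pool, then train a ranker as in the cells above.
import numpy as np
from catboost import CatBoostRanker, Pool

rng = np.random.default_rng(42)
n_docs, n_queries = 200, 20
X = rng.normal(size=(n_docs, 6))
y = rng.random(n_docs)
group_id = np.sort(rng.integers(0, n_queries, size=n_docs))

# One random weight per query, broadcast to every document of that query.
query_weight = rng.uniform(0.5, 2.0, size=n_queries)
group_weight = query_weight[group_id]

weighted_pool = Pool(data=X, label=y, group_id=group_id, group_weight=group_weight)

model = CatBoostRanker(loss_function='QueryRMSE', iterations=50, verbose=False)
model.fit(weighted_pool)
```

With a query-wise loss such as `QueryRMSE`, each query's summand is multiplied by its weight, which is the behaviour described above.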
|
450 | 450 | "### A special case: top-1 prediction\n", |
451 | 451 | "\n", |
452 | 452 | "Someday you may face a problem $-$ you will need to predict the single most relevant object for a given query.<br/>\n", |
453 | | - "For this purpose CatBoost has a mode called __QuerySoftMax__.\n", |
| 453 | + "For this purpose CatBoostRanker has a mode called __QuerySoftMax__.\n", |
454 | 454 | "\n", |
455 | 455 | "Suppose our dataset contains a binary target: 1 $-$ means the best document for a query, 0 $-$ all the others.<br/>\n", |
456 | 456 | "We will maximize the probability of being the best document for a given query.<br/>\n", |
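To make the setup concrete, here is a small sketch on synthetic data with exactly one positive document per query, trained with the `QuerySoftMax` objective; the data and parameter values are assumptions for illustration.

```python
# Sketch (synthetic data, assumed values): a binary target marking the single
# best document per query, trained with the QuerySoftMax objective.
import numpy as np
from catboost import CatBoostRanker, Pool

rng = np.random.default_rng(7)
n_queries, docs_per_query = 30, 8
X = rng.normal(size=(n_queries * docs_per_query, 5))
group_id = np.repeat(np.arange(n_queries), docs_per_query)

# Exactly one "best" document (label 1) per query, 0 for the rest.
y = np.zeros(n_queries * docs_per_query)
best = np.arange(n_queries) * docs_per_query + rng.integers(0, docs_per_query, size=n_queries)
y[best] = 1.0

pool = Pool(data=X, label=y, group_id=group_id)
model = CatBoostRanker(loss_function='QuerySoftMax', iterations=50, verbose=False)
model.fit(pool)

# Pick the predicted top-1 document for the first query.
scores = model.predict(X[:docs_per_query])
print('predicted best doc index in query 0:', int(np.argmax(scores)))
```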
|
572 | 572 | "\n", |
573 | 573 | "$$ - \sum_{(i,j) \in \text{Pairs}} \log \left( \frac{1}{1 + \exp\left( -(f(d_i) - f(d_j)) \right)} \right) $$\n", |
574 | 574 | "\n", |
575 | | - "Methods based on pair comparisons called __pairwise__ in CatBoost this objective called __PairLogit__.\n", |
| 575 | + "Methods based on pair comparisons are called __pairwise__; in CatBoostRanker this objective is called __PairLogit__.\n", |
576 | 576 | "\n", |
577 | 577 | "There's no need to change the dataset $-$ CatBoost generates the pairs for us. The number of generated pairs is managed via the max_size parameter." |
578 | 578 | ] |
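A minimal sketch of training with the `PairLogit` objective on synthetic graded labels: the pairs are generated internally from the labels within each query group, so only `group_id` is supplied. The data and values are illustrative assumptions.

```python
# Sketch (synthetic data, assumed values): PairLogit generates the pairs from
# graded labels inside each query group; no explicit pairs are passed.
import numpy as np
from catboost import CatBoostRanker, Pool

rng = np.random.default_rng(3)
n_docs = 160
X = rng.normal(size=(n_docs, 4))
y = rng.integers(0, 5, size=n_docs).astype(float)   # graded relevance labels
group_id = np.sort(rng.integers(0, 20, size=n_docs))

pool = Pool(data=X, label=y, group_id=group_id)
model = CatBoostRanker(loss_function='PairLogit', iterations=50, verbose=False)
model.fit(pool)
```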
|
615 | 615 | " for doc_id, line in enumerate(f):\n", |
616 | 616 | " line = line.split(',')[:2]\n", |
617 | 617 | " \n", |
618 | | - " label, query_id = tuple(map(float, line))\n", |
| 618 | + " label, query_id = float(line[0]), int(line[1])\n", |
619 | 619 | " if query_id not in groups:\n", |
620 | 620 | " groups[query_id] = []\n", |
621 | 621 | " groups[query_id].append((doc_id, label))\n", |
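For reference, here is the same grouping logic as in the cell above applied to a tiny in-memory example, so the resulting `groups` structure is easy to see; the two leading columns are assumed to be the label and the query id, matching the parsing above, and the sample lines are invented for illustration.

```python
# Sketch: group (doc_id, label) pairs by query id, as in the cell above,
# but on a small in-memory list instead of a file.
lines = [
    '1.0,101,feature_a,feature_b',
    '0.0,101,feature_a,feature_b',
    '2.0,102,feature_a,feature_b',
]

groups = {}
for doc_id, line in enumerate(lines):
    label_str, query_str = line.split(',')[:2]
    label, query_id = float(label_str), int(query_str)
    groups.setdefault(query_id, []).append((doc_id, label))

print(groups)   # {101: [(0, 1.0), (1, 0.0)], 102: [(2, 2.0)]}
```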
|