orizuru

つながる.見える.わかる IoTソリュ-ション

CatBoostのランク学習(Learning to rank)をためそう

約 4 分
CatBoostのランク学習(Learning to rank)をためそう

エンジニアのtetsuです。

勾配ブースティング法のライブラリとして、去年の夏頃、ロシアのYandex社からCatBoostがリリースされました。売りとしては他の勾配ブースティングのライブラリよりも過学習を抑えることで、精度が良いという点のようです。ベンチマークはこちらです。
勾配ブースティング法の他のライブラリも同様ですが、CatBoostでは回帰、分類、ランキングが対象の問題となっています。CatBoostの回帰や分類のチュートリアルはありますが、ランク学習をおこなうものは見つかりませんでした。せっかくなので、ここではランク学習を試していこうと思います。

ランク学習とは?

ランク学習とは文書や商品などのランキングを学習する方法のことをいいます。たとえば、どこかの検索エンジンでキーワードを入力して検索をおこなうと色々なウェブページがでてきますが、これらのウェブページをどういう順番で表示するのが良いのかを学習するのがランク学習になります。ランク学習を利用することで、検索エンジンではユーザーに見られる確率が高いページを上位に表示したり、ショッピングサイトではおすすめの商品を提示することができるようになります。

対象とするデータ

今回の実験ではSUSHI Preference Data Setsを利用させていただきます。このデータセットのなかには以下のデータが含まれています。

  • お寿司の平均価格やこってり度などのお寿司に関するデータ
  • 各ユーザの5段階評価でのお寿司の好みに関するデータ(ほとんどのお寿司が無回答となっていて、10種類のみ評価がついています)
  • 各ユーザの年齢や出身地、性別などの個人に関するデータ

今回このデータセットを使った場合にランク学習によって達成したいことは、ユーザ情報と各お寿司の情報が与えられた時にどのお寿司がおすすめなのか(好きそうか)を順位付けて出力できるようになることです。

このデータセットを組み合わせてPandasのDataFrameの形式の学習データ、テストデータを生成します。学習データとテストデータの各行は単純に”ユーザ情報+お寿司の情報+お寿司の評価値(score)”となっています。1人のユーザに対してそのユーザが評価したお寿司の数と同じ行数のデータがDataFrameに存在します。たとえばuser_idが5695のユーザに対応している部分のDataFrameの行は次のようになっています(なおカラム名は適当につけたので、意味は正確ではありません)。

sushi_dataframe

CatBoostのランク学習

CatBoostのランク学習をおこなうコードを下記に示します。train_dfとtest_dfは上で示した形式のDataFrameとなっており、それぞれ訓練用とテスト用です。cat_featuresはDataFrameのなかのカテゴリ変数のインデックスをあらわしています。このようなリストを用意して、Poolに与えるだけでカテゴリ変数として扱ってくれるようになります。また、ランク学習では各データがどのグループに属するかを指定する必要があります。たとえば、1つのグループは1つのクエリに対する検索結果一覧になります。CatBoostでは各データがどのグループに属するかを示すgroup_idをリストの形式で与えることになっています。今回は一人のユーザ単位でグループにすればよいため、user_idをgroup_idとして与えます。

学習したモデルで予測をおこなう場合には次のようにします。これにより各個人のそれぞれのお寿司のスコア値が算出されます。このスコア値が大きいほど、よりおすすめのお寿司となります。

私の個人情報を入力して予測をさせたところ、まぐろ、トロ、鉄火巻の順でおすすめという結果になりました。まぐろづくしですね。たしかにまぐろは好きなのですが。

まとめ

今回はCatBoostでランク学習を動かすということを試してみました。使用させて頂いたお寿司のデータセットはなかなかおもしろいですね。
勾配ブースティング法のライブラリが色々存在するなかでどれを選択するか迷いどころになりそうです。今後の動向も要チェックですね。

Leave A Reply

*
*
* (公開されません)