RankNetの実装
はじめに
ランク学習に興味が湧いたのでRankNetで遊びます。 ランク学習とは、ある入力データをランクモデルに入力することで、任意の順序に並び替えて出力してくるものです.(らしい)
こちらを参考にPytorchで実装してみようかと思います。
RankNetの概要
まず、2つのデータとそれらに対するランク をサンプリングします。 これを PairWise手法と言うらしいです。ここで、2つのデータのランクの関係が と仮定します。
次に、あるランクモデル\(f\) を用いて2つのデータのスコア を計算します。(上記の仮定より となるような関数 を学習します。)
そして、2つのスコアを用いて となる確率を計算し、このモデルの予測値となります.
直感的には、 が大きい場合入力は正なので確率は大きくなり、逆に が大きい場合入力は負なので、 確率は小さくなることがわかります。
最後におなじみの損失関数(CrossEntropyLoss)を導入し、これを最小化するように学習します。
そして、2つのランクの大小によって以下のようにラベル\(\bar{P_{ij}} \)を定義します。
実験
今回は猫の画像のランクを出力するモデルを構築します。
データセット
オックスフォード大学が公開している動物画像の Visual Geometry Group - University of Oxfordを用います。
このデータセットには12種類(各200枚ずつ)の猫の画像が含まれており、簡潔化のために3種類に絞った物を学習データとします。
(左から、Bengal、Russian Blue、Brimanの3種類)
ランクは、0 : Russian Blue, 1: Brima , 2 : Briman と定義します。なので、ランクモデルにこの3種を入力した時、この順番で出力されることを期待します。
学習モデル
学習モデルも簡潔化のためにResNet18を用いて、最後の全結合層の出力数を1000 → 1 に変えて、転移学習させます。
コードはこちら
結果
テストとして、3種のネコを2枚ずつ合計6枚をランクモデルに入力した結果が以下となります。
定義したランクの通りに出力されました.
まとめ
今回は RankNetと言うランク学習の1つを実験しました。ネットワーク自体は転移学習でしたのでよくあるパターンのもので、一番の肝は誤差関数をどう設計するかでした。RankNetの後に出た ListNet と言うのもあるので、実装してみたいですね。