Focal lossの実装(PyTorch)

Focal lossとは教師データに含まれるクラスごとのインスタンスが不均一であるときに学習がうまくいかないことを是正するために提案されたものだ。 One stageのObject detectionで背景クラスが大半を占めることで発生する問題に対して効果的に働くらしい。 仕組みがシンプルなので適用先はObject detectionには限らない汎用的な仕組みだといえる。

詳しいことは元の論文を読むなり、Qiitaを読むなりすることをお勧めするが、 この記事ではPyTorchでFocal lossを使用するために私が行った修正等について説明したい。

PyTorchのコミュニティでFocal lossについて議論されており、以下がおすすめされていたので使ってみたが、途中でone_hotを作って計算するところがGPUに載せ替えないと動かないのが不満になり、そこだけ直した。(Skorchを使っているとlossの計算時に to(device) とかやりづらいので。) github.com

直した結果は以下である。修正前の実装をFocalLossWithOneHot とし、修正後の実装をFocalLossWithoutOneHotとしている。 main()to("cuda")してFocalLossWithOneHot はエラーを吐き、FocalLossWithoutOneHotはエラーを吐かないこと確認した。 gist.github.com