【chainer】 crf1d, CRF1dの使い方

chainerのlinks.CRF1dとfunctions.crf1dの使い方メモ。
crfと同様に系列を扱うNStepLSTMとかと同じ気分で使うと事故る。

概要

  • functions.crf1d(cost, xs, ts)

遷移スコア、タグスコア、正解ラベルをそれぞれ受け取って、最適パスをVariableのlistで返す。
遷移スコアは

\ 1,2,3
1 . . .
2 . . .
3 . . .

って感じの行列で、例えばタグ1から2への遷移スコアを5にするなら以下のようにする。

\ 1,2,3
1 . 5 .
2 . . .
3 . . .

他のスコアも適当に埋めて渡せばいい。

  • links.CRF1d(xs, ts)

最適な遷移スコアをヒューリスティックに探すのはしんどいので遷移スコアを学習する。
中でfunctions.crf1dを呼ぶので入出力は同じ形。

入力(xsの形)

# xs for crf1d
# batch size = 2
# batch = [(x,x,x), (x,x)]
[
1st :((1,0,0),(1,0,0))
2nd :((0,1,0),(0,1,0))
3rd :((0,0,1)
]

[]はリスト、()はVariableかnumpyのarray。
リストのインデックスが各時刻、行列がバッチ内の各タグスコア。

実際のコードだと以下のような感じ。
links.CRF1dを使っているが、functionsの方を使うなら上で述べた行列を第一引数に渡せばいい。

>>> import chainer
>>> from chainer import links as L
>>> import numpy as xp

>>> crfL = L.CRF1d(3) #tag sizeを3に設定

#縦方向が時刻
>>> xs = [xp.array([[1,0,0]], xp.float32), 
	  xp.array([[0,1,0]], xp.float32), 
	  xp.array([[0,0,1]], xp.float32)]

>>> crfL.argmax(xs)[1] #.argmaxはscoreとpathのタプルを返す
[array([0], dtype=int32), 
 array([1], dtype=int32), 
 array([2], dtype=int32)]
# batch処理の場合はこう。
# 長さ3の系列を二つまとめて入力。
# タグサイズは3
>>> xs = [xp.array([[1,0,0],[0,0,1]], xp.float32),
...       xp.array([[0,1,0],[0,1,0]], xp.float32),
...       xp.array([[0,0,1],[1,0,0]], xp.float32)]
>>> crfL.argmax(xs)[1]
[array([0, 2], dtype=int32), 
 array([1, 1], dtype=int32), 
 array([2, 0], dtype=int32)]


#可変長なbatch入力をするときは長い順にソートする必要がある。
>>> xs = [xp.array([[1,0,0],[1,0,0]], xp.float32),
...       xp.array([[0,1,0],[0,1,0]], xp.float32),
...       xp.array([[0,1,0]], xp.float32)]
>>> crfL.argmax(xs)[1]
[array([0, 0], dtype=int32), 
 array([1, 1], dtype=int32), 
 array([1], dtype=int32)]

入力(tsの形)

教師データもxsと同じで、縦にして大きい順に横に並べる。

# ts for crf1d
# batch size = 2
# batch = [(x,x,x), (x,x)]
1st :[(0,0)]
2nd :[(1,1)]
3rd :[(2)]
# input
>>> xs = [xp.array([[1,0,0],[1,0,0]], xp.float32),
...       xp.array([[0,1,0],[0,1,0]], xp.float32),
...       xp.array([[0,1,0]], xp.float32)]

# teacher
>>> ts = [xp.array([0,0],xp.int32),
...       xp.array([1,1],xp.int32),
...       xp.array([2],xp.int32)]

>>> crfL(xs, ts)
<variable at 0x10ea6cac8> # variableがかえってくる
>>> crfL(xs, ts).data
array(1.87861168384552, dtype=float32)

ちなみに僕が混乱したのは、同じ系列を扱うNStepLSTMの入力が以下のようだから。

xs = [(1,2,3),
      (1,2,3,4,5),
      (1,2,3,4)]
len(xs): batch Size

BiLSTMでのリストがバッチサイズを表しているのに対して、crf1dのリストは各時刻の入力を表している。
超紛らわしい。


(余談)
CRFを使うときにsoftmaxを噛ませると学習できないので、pre-trainingとかしてるひとは気をつけてね!
functions使うときにスコアをある程度正規化してから渡したいなぁ〜とか思ってると、普通に学習できなくて死ぬ。


わかりづらいとか、間違いとかあったらコメントでお願いします。