[Python] 디리클레 분포 추정하기

Posted by 적분 ∫2tdt=t²+c
2018.09.03 17:59 그냥 공부

토픽 모델링 이론들을 공부하다 보니 종종 깁스 샘플링 이후에 디리클레 분포를 추정하는 방법을 사용하는걸 봤었는데, 매번 봐도 잘 이해도 못하고 계속 까먹길래 아예 까먹지 포스팅을 하나 파둡니다.


디리클레 분포(Dirichlet Distribution)은 다항 분포에 대한 분포라는 건 잘 알고 계실 겁니다. 예를 들어 토픽 모델링과 같은 상황에서, 각 주제는 단어들에 대한 다항 확률 분포이므로, 주제의 분포는 다항 분포에 대한 분포, 즉 디리클레 분포가 되죠. 디리클레 분포에서 임의의 다항 (확률) 분포를 뽑아낼 수 있고, 다항 분포에서는 n지선다에서 하나를 뽑아낼 수 있습니다. 이런 유용성 때문에 토픽 모델링에서 디리클레 분포는 널리 쓰입니다.


디리클레 분포는 하이퍼 파라미터를 하나 가집니다. 흔히 α라고 적는 녀석이고, n차원 벡터의 형태를 띕니다. 이 값이 해당 디리클레 분포에서 뽑힐 다항분포들의 분포를 결정하는 중요한 역할을 합니다. 그렇기에 우리의 모델이 적절한 분포를 생성하려면 데이터에 잘 맞는 α을 지정해주어야 하는데, 이게 쉽지 않다보니 그냥 0.1을 준다던지, 50/K를 준다던지 여러 heuristic한 방법을 사용하는게 현실입니다.


하지만 데이터가 주어지면 적당한 횟수의 연산으로 데이터에 꼭 맞는 디리클레 분포의 하이퍼 파라미터를 찾아주는 방법이 이미 2000년에 제시되었습니다. 이 방법을 모르고 지나갈 수가 없겠죠. 


Minka, T. (2000). Estimating a Dirichlet distribution.

(수학 논문은 간결하고 핵심적인 내용만 담겨있어서 보기에 좋습니다. 이해하는건 별개지만)


단도직입적으로 방법부터 정리하면 다음과 같습니다.


1~N 사이의 수를 뽑는 확률분포가 총 M개 있습니다. 이를 x_i라고 합시다. (i=1~M) 각각의 x_i는 총 n_i개의 관측 데이터(1~N 사이의 수)를 생성합니다. 이 관측 데이터 중 k의 개수를 n_ik라고 합시다. 이 때 x_i들의 분포인 디리클레 분포(하이퍼파라미터 α)는 다음과 같이 추정할 수 있습니다. 

하이퍼파라미터 α는 총 N개의 요소로 이뤄진 벡터이고, 각각의 요소는 위와 같이 추정될 수 있습니다. ψ는 디감마 함수(digamma function)입니다. 이를 계산하면 새로운 하이퍼파라미터 α`를 얻게 됩니다. 이를 충분히 반복하면 α가 특정한 값으로 수렴하는데, 이것이 주어진 데이터에 가장 잘 맞는(수학적으로 말하면, 가능도를 최대로 하는, Maximizing likelihood)  값이라는 게 Minka가 제안한 내용입니다.


실제로 잘 작동하는지 감이 안 오므로 파이썬 코드를 작성해보았습니다.




결과는 다음과 같습니다.

# of iteration α LL
0 [0.10000, 0.10000] -49.6977
1 [0.13129, 0.11945] -48.74299
2 [0.16297, 0.14339] -47.98219
3 [0.19671, 0.16842] -47.37222
4 [0.23216, 0.19353] -46.88417
5 [0.26874, 0.21821] -46.4926
6 [0.30592, 0.24216] -46.17654
7 [0.34331, 0.26521] -45.91938
8 [0.38057, 0.28731] -45.70827
9 [0.41748, 0.30845] -45.53337
10 [0.45387, 0.32864] -45.38714

. . .

700 [2.1131, 1.0462] -44.27781
701 [2.1131, 1.0462] -44.27781
702 [2.1131, 1.0462] -44.27781
703 [2.1131, 1.0462] -44.27781
704 [2.1131, 1.0462] -44.27781
705 [2.1131, 1.0462] -44.27781
706 [2.1131, 1.0462] -44.27781
707 [2.1131, 1.0462] -44.27781
708 [2.1131, 1.0462] -44.27781
709 [2.1131, 1.0462] -44.27781


처음에는 로그 가능도(LL)가 -50에 가까웠지만 계속 iteration을 거듭할 수록 하이퍼파라미터는 [2.1131, 1.0462]에 수렴하고, LL값은 -44까지 증가했습니다. 아주 잘 작동하네요.

참 쉽죠?

Tags
이 댓글을 비밀 댓글로