토픽 모델링 이론들을 공부하다 보니 종종 깁스 샘플링 이후에 디리클레 분포를 추정하는 방법을 사용하는걸 봤었는데, 매번 봐도 잘 이해도 못하고 계속 까먹길래 아예 까먹지 포스팅을 하나 파둡니다.
디리클레 분포(Dirichlet Distribution)은 다항 분포에 대한 분포라는 건 잘 알고 계실 겁니다. 예를 들어 토픽 모델링과 같은 상황에서, 각 주제는 단어들에 대한 다항 확률 분포이므로, 주제의 분포는 다항 분포에 대한 분포, 즉 디리클레 분포가 되죠. 디리클레 분포에서 임의의 다항 (확률) 분포를 뽑아낼 수 있고, 다항 분포에서는 n지선다에서 하나를 뽑아낼 수 있습니다. 이런 유용성 때문에 토픽 모델링에서 디리클레 분포는 널리 쓰입니다.
디리클레 분포는 하이퍼 파라미터를 하나 가집니다. 흔히 α라고 적는 녀석이고, n차원 벡터의 형태를 띕니다. 이 값이 해당 디리클레 분포에서 뽑힐 다항분포들의 분포를 결정하는 중요한 역할을 합니다. 그렇기에 우리의 모델이 적절한 분포를 생성하려면 데이터에 잘 맞는 α을 지정해주어야 하는데, 이게 쉽지 않다보니 그냥 0.1을 준다던지, 50/K를 준다던지 여러 heuristic한 방법을 사용하는게 현실입니다.
하지만 데이터가 주어지면 적당한 횟수의 연산으로 데이터에 꼭 맞는 디리클레 분포의 하이퍼 파라미터를 찾아주는 방법이 이미 2000년에 제시되었습니다. 이 방법을 모르고 지나갈 수가 없겠죠.
Minka, T. (2000). Estimating a Dirichlet distribution.
(수학 논문은 간결하고 핵심적인 내용만 담겨있어서 보기에 좋습니다. 이해하는건 별개지만)
단도직입적으로 방법부터 정리하면 다음과 같습니다.
1~N 사이의 수를 뽑는 확률분포가 총 M개 있습니다. 이를 x_i라고 합시다. (i=1~M) 각각의 x_i는 총 n_i개의 관측 데이터(1~N 사이의 수)를 생성합니다. 이 관측 데이터 중 k의 개수를 n_ik라고 합시다. 이 때 x_i들의 분포인 디리클레 분포(하이퍼파라미터 α)는 다음과 같이 추정할 수 있습니다.
하이퍼파라미터 α는 총 N개의 요소로 이뤄진 벡터이고, 각각의 요소는 위와 같이 추정될 수 있습니다. ψ는 디감마 함수(digamma function)입니다. 이를 계산하면 새로운 하이퍼파라미터 α`를 얻게 됩니다. 이를 충분히 반복하면 α가 특정한 값으로 수렴하는데, 이것이 주어진 데이터에 가장 잘 맞는(수학적으로 말하면, 가능도를 최대로 하는, Maximizing likelihood) 값이라는 게 Minka가 제안한 내용입니다.
실제로 잘 작동하는지 감이 안 오므로 파이썬 코드를 작성해보았습니다.
결과는 다음과 같습니다.
# of iteration | α | LL |
---|---|---|
0 | [0.10000, 0.10000] | -49.6977 |
1 | [0.13129, 0.11945] | -48.74299 |
2 | [0.16297, 0.14339] | -47.98219 |
3 | [0.19671, 0.16842] | -47.37222 |
4 | [0.23216, 0.19353] | -46.88417 |
5 | [0.26874, 0.21821] | -46.4926 |
6 | [0.30592, 0.24216] | -46.17654 |
7 | [0.34331, 0.26521] | -45.91938 |
8 | [0.38057, 0.28731] | -45.70827 |
9 | [0.41748, 0.30845] | -45.53337 |
10 | [0.45387, 0.32864] | -45.38714 |
. . . |
||
700 | [2.1131, 1.0462] | -44.27781 |
701 | [2.1131, 1.0462] | -44.27781 |
702 | [2.1131, 1.0462] | -44.27781 |
703 | [2.1131, 1.0462] | -44.27781 |
704 | [2.1131, 1.0462] | -44.27781 |
705 | [2.1131, 1.0462] | -44.27781 |
706 | [2.1131, 1.0462] | -44.27781 |
707 | [2.1131, 1.0462] | -44.27781 |
708 | [2.1131, 1.0462] | -44.27781 |
709 | [2.1131, 1.0462] | -44.27781 |
처음에는 로그 가능도(LL)가 -50에 가까웠지만 계속 iteration을 거듭할 수록 하이퍼파라미터는 [2.1131, 1.0462]에 수렴하고, LL값은 -44까지 증가했습니다. 아주 잘 작동하네요.
참 쉽죠?
[기계 번역] 이중 언어 데이터에서의 단어 임베딩 (Bilingual Word Embeddings from Non-Parallel Document Data) (0) | 2018.11.30 |
---|---|
단어 의미의 역사적 변천을 추적하기 (1) | 2018.11.12 |
상위어 자동 추출(Hypernym Detection) 기법 정리 (1) | 2018.10.10 |
코퍼스 내에서 알려지지 않은 새로운 명사(미등록어)를 추출하기 (3) | 2018.09.02 |
[토픽 모델링] 가우시안 LDA 모델 - Word2Vec과 결합한 LDA (5) | 2018.08.05 |
[토픽 모델링] LDA에 용어 가중치를 적용하기 (8) | 2018.06.26 |
댓글 영역