[c++] 템플릿 메타 프로그래밍으로 르장드르 다항식 계산하기

Posted by 적분 ∫2tdt=t²+c
2018.07.08 18:26 프로그래밍/테크닉

르장드르 다항식은 [-1, 1] 구간에서 직교(orthogonality)하는 다항식들의 집합을 가리킵니다. 다항식들간에 서로 직교하면서, 간단하게 곱셈과 덧셈만으로 계산이 가능하다는 특징 덕분에 물리학이나 공학 등의 분야에서 특정 상태를 근사하여 풀 때 직교 기저로 자주 사용합니다. 저 역시도 최근 연구에서 임의의 모양을 띈 다차원 공간의 함수를 근사하기 위해서 이 다항식을 사용했는데요, 이를 위해서 르장드르 다항식을 구현하여 L-BFGS 함수식에 넣어 그 값을 계산해야 할 일이 있었습니다.

이 다항식은 재귀적 방법이나 조합식을 이용하여 쉽게 계산될 수 있지만, 이를 매번 함수 값을 구할때마다 반복할 수 없으므로 컴파일 타임에 처리하고자 템플릿 메타 프로그래밍(template meta programming)으로 간단하게 구현해 보았습니다. 실제로 구한 식은 Legendre Polynomials이 아니라 Shifted Legendre Polynomials이었구요, 이 변형된 르장드르 다항식의 경우 x 대신 2x-1을 대입한 것인데, 정의역 구간이 [0, 1]로 바뀌고, 각 다항식의 계수가 모두 정수가 된다는 편리한 점이 있기에 이를 선택하였습니다.


Shifted Legendre Polynomials

이 변형된 르장드르 다항식은 다음과 같이 정의됩니다.

여기서 C(n, k)는 조합(이항계수)입니다. 실제로 몇개의 n에 대한 값을 구해보면 아래와 같습니다.


르장드르 다항식을 쓰는 이유 중 하나가 계산이 간단하다는 것인데, 매 번 C(n, k)를 계산하기엔 손해가 막심하므로, 아예 이 함수식을 생성하는 템플릿 코드를 작성해보도록 합시다. 먼저 이항계수를 계산하는 부분을 짜보도록 합시다.


이항계수 계산

C++ 템플릿 메타 프로그래밍의 핵심은 결국 재귀를 용하여 해당 값을 얼마나 잘 정의하느냐에 달려있습니다. 먼저 combination 클래스를 정의해보도록 합시다.



익히 알려진대로 C(n, k)는 다음과 같은 관계식을 만족합니다.

관계식을 뒤집으면 다음과 같은 관계를 이끌어 낼수도 있겠죠.


이걸 바탕으로 다음과 같이 value를 정의할 수 있습니다.



재귀적 정의를 할 때 가장 중요한 것은 탈출 조건을 명시하는 것입니다. 그렇지 않으면 무한루프에 빠질테니깐요.

C(n, 0) = C(n, n) = 1

이라는 것을 떠올리며 다음과 같이 탈출 조건을 명시할 수 있겠네요.




이렇게 하면 음수가 아닌 n, k에 대해 C(n, k)를 모두 컴파일 타임에 계산할 수 있습니다. 궁금하시면 combination<10, 5>::value를 찍어보시면 되겠습니다.

사실 약간의 문제가 있는데 이 경우 combination<0, 0>의 경우 combination<n, 0>을 바탕으로 instantiated 되어야하는지, combination<n, n>을 바탕으로 해야하는지 모호하기 때문에 컴파일러가 에러를 뱉어냅니다. 따라서 다음과 같은 특수화를 하나 더 추가해서 문제를 해결할 수 있겠죠?





다항식 계산

임의의 다항식을 계산하는 함수를 어떻게 템플릿으로 생성할 수 있을까요?

예로 a, b, c, d라는 계수로 만들어지는 위의 다항식을 생각해봅시다. x^3을 계산하고 a를 곱하고, x^2을 계산하고 b를 곱하고, x에 c를 곱하고, d를 더하는건 조금 비효율적으로 보입니다. 하지만 위의 다항식은 다음과 같이 변형될 수 있습니다.


a에 x를 곱하고 b를 더하고, x를 곱하고 c를 더하고, x를 곱하고 d를 더하는 식으로 반복적으로 곱셈과 덧셈을 반복하면 다항식을 쉽게 계산할 수 있습니다.


이를 바탕으로 먼저 다항식을 계산하는 struct 틀을 잡아봅시다.




_Order가 3이라고 가정해보면 at<0>(x)는 다음과 같이 전개될 겁니다.


at<3>(x) := a (최고차의 계수)

at<2>(x) := at<3> * x + b (적절한 계수) 

at<1>(x) := at<2> * x + c (적절한 계수)

at<0>(x) := at<1> * x + d (적절한 계수)


위의 P(x) 계산식과 동일하다는 걸 알 수 있습니다. 함수가 함수를 호출하는 식을 구현되기에 혹시나 오버헤드가 발생할까 걱정하실수도 있는데, 컴파일 타임에 훌륭하게 inlining이 되기에 그런 걱정은 접어두셔도 됩니다. 이제 계수자리만 적절하게 채워주면 구현이 끝나겠군요.


앞서 소개한 정의에 따라 Shifted Legendre Polynomials의 각 계수는 다음과 같이 계산할 수 있습니다.


이항 계수를 계산하는 코드는 구현했으니 홀짝 여부에 따라 -1 혹은 1을 반환하는 부분만 만들면 되겠군요.


위와 같이 정의할 경우 even_odd<0>은 1, 그 외의 모든 경우는 -1을 돌려주는 템플릿 struct가 됩니다. 이제 이걸 바탕으로


이와 같이 even_odd<(n+k) % 2>::value를 넣어주면 2로 나눈 나머지가 0인 경우 즉 2의 배수일때는 1, 2의 배수가 아닐때는 -1의 값을 갖게 되어 깔끔하게 구현이 마무리됩니다.


그리고 매번 함수 값을 계산하기 위해 shiftedLegendre<10, float>::at(0.5f)와 같이 사용하기엔 귀찮으므로, 도우미 함수를 넣어 편리하게 써봅시다.



템플릿 함수의 경우 템플릿 파라미터를 자동으로 추정해주기 때문에 shiftedLegendreFunc<10>(0.5f)와 같이 사용하면 바로 float 버전의 르장드르 다항식이 계산되고, shiftedLegendreFunc<10>(0.5)과 같이 사용하면 double 버전의 다항식이 계산되겠죠. Generic한 함수 만들기까지 깔끔하게 끝이 납니다.


c++의 템플릿 메타 프로그래밍은 미리 계산 가능한 부분들을 컴파일 타임에 계산해둠으로써 (코드 및 바이너리의 크기는 커지겠지만) 런타임 실행 속도를 극한으로 끌어올릴 수 있는 아주 유용한 기법이죠. 재귀적 정의와 탈출 조건만 잘 고려한다면 쉽게 템플릿 메타 프로그래밍을 하실 수 있으실 겁니다! 도움이 되셨길 바랍니다!

Tags
이 댓글을 비밀 댓글로
  1. 감사합니다 !!!