[암호 SW 구현] 몽고메리 연산 구현
0. 몽고메리 곱셈을 왜 함??
일단 이부분을 이해하기 위해서는 큰 수에 대한 나눗셈에 대한 기본적인 이해가 필요되야함
(안됐으면 말해주길)
큰 수 나눗셈은 연산 구조상 계산 부하가 높음. 그래서 mod연산을 하기 위해서는 원래는 나눠서 그냥 나머지를 계산하는 방식으로 하는데, 이경우 곱셈을 할 때마다 나눗셈을 일일이 구현하기 어렵다.
그래서 곱셈을 할 때 마다 나눗셈 연산 대신 , bit shift연산을 이용해서 그 역할을 대신하겠다는 것임.
예를 들면
몽고메리 곱셈의 결과 : 11111111000000000
이런식으로 나오기 때문에 11111111
이런 식으로 뒤의 0만 없애주면 되기 때문에 나눗셈을 굳이 할 필요가 없어지는 것
1. Intro
몽고메리 곱셈은 일단 기본적으로 몽고메리 도메인과 정수 도메인을 이해한다.
가) 정수 도메인
우리가 평소에 알고 있던 정수 곱셈이 이루어지는 곳을 말한다. A * B = T mod m이 나오는 그냥 일반 정수 도메인
나) 몽고메리 도메인
몽고메리 도메인은 A * B = T * R^(-1) mod m이 나오는 도메인을 말한다.
다) 계산은 어떻게 하나요?
일단 각각의 수를 몽고메리 도메인으로 올려야 한다. 정수 도메인에 있던 A와 B를 몽고메리 도메인으로 변환을 해줘야 한다.
AR mod m = Mont(A , R^2)
BR mod m = Mont(B , R^2)
이렇게 몽고메리 도메인으로 두수를 올리게 되면 AR * BR = ABR mod m 이 된다.
몽고 도메인의 ABR을 다시 정수 도메인으로 내리려면 그냥 1을 몽고메리 곱셈을 해주면 된다.
2. Montgomery 감산 연산 동작
몽고메리 연산은 다음과 같다. 일단 모듈러 연산은 기본적으로 역원이 존재해야 하기 때문에, R과 m은 최대 공약수가 1이 되어야한다.
여기서 T는 A 와 B를 곱해지는 수이고, 최종적으로 1, 2,3 과정을 진행한다면 , TR^-1 이 나오게 된다.
EX )
여기서 감산한 것을 예를 들어보자. 세팅은 다음과 같다. 몽고메리 감산의 최종적인 목적은 나눗셈을 비트 쉬프트 연산으로 대체하는 것 이기 때문에, 뒷자리를 0으로 계속해서 만들어줘야한다. 그래서 저 과정에 써져있는 U라는 값은 계속해서 계산한 값의 뒷자리를 0으로 만드는 역할을 한다.
이렇게 되면, 최종적으로 ABR의 값은 3979600000이고 R^-1을 곱해주면, 39796가 된다.
3. Montgomery 곱셈
가) Input.
A : 최종 결과값 x , y : 곱해지는 수 m : mod 해주는 값 B : 워드의 크기는 2^32이다. R : B^n |
일단 input의 있는 수를 보면 A와 X , Y , M은 굳이 주어지는 값이고, M'을 만들어줘야 한다. 이때는 확장된 유클리드 알고리즘을 사용해서, 구해야한다. 그러나 기존의 구현했던 알고리즘은 한 워드에 대한 역원을 구한 것이기 때문에 기본적으로 큰 정수의 역원을 구하는 것으로 해야한다.
void MPZ_modinv(MPZ* result, MPZ* a, MPZ* mod) {
MPZ r1, r2, r, q;
MPZ t1, t2, t, one ,two , temp;
COPY_MPZ(&r1, mod);
COPY_MPZ(&r2, a);
t1.dat[0] = 0;
t1.len = 0;
t1.sign = 0;
t2.dat[0] = 1;
t2.len = 1;
t2.sign = 0;
two.dat[0] = 2;
two.len = 1;
two.sign = 0;
while (Compare_MPZ(&r1, &two)) { // r1>2
MPZ_UDIV(&q, &r, &r1, &r2);
MPZ_MODMUL( &temp ,&q, &t2 , mod);
MPZ_MODSUB(&t, &t1, &temp , mod);
COPY_MPZ(&r1, &r2);
COPY_MPZ(&r2, &r);
COPY_MPZ(&t1, &t2);
COPY_MPZ(&t2, &t);
}
COPY_MPZ(result, &t1);
}
(코드에 오타는 있을 수 있음)
기본적으로 구현했던 것들을 기존의 구현했던 큰 정수의 나눗셈과 곱셈으로 구현하면 된다.
(이부분 혹시 이해안되면 연락주셈!)
나) 동작
일단 2번에서 A(결과값) 에 0을 넣어준다. 몽고메리 곱셈의 기본적인 목표(꼭기억하자!) 나눗셈을 비트 쉬프트 연산으로 대체하는 것이 목표이기 때문에 u값을 계속해서 만들어줘야한다.
그래서 2.1에서는 y0와 xi의 값을 u에 담고 이렇게 구헌 ui값은 2.2에서 m값을 곱해줘서 더해주게 된다. 즉
ui * m은 마지막 자리의 워드를 0으로 만들어 주는 수라고 생각하면된다.
그래서 x의 한 워드와 y를 곱하고 마지막 자리의 워드를 0으로 만드는 ui * m을 더해주고 마지막 b를 나누어서 그 0으로 된부분을 없애준다. ( B = 2^32이므로 이때 bit shift를 하면 된다.)
그러나 3의 과정에서 혹시 결과값이 M을 넘을 수 있기 때문에 빼준다.
코드(이해 안되면 물어보고 꼭 주석을 이해하고 이해해야함)
void Mont_Mul(MPZ *r , MPZ *x , MPZ *y , MPZ *mod , UINT32 minv) // = r = xyR^-1
{
// minv = M' = -m^-1 mod b
SINT32 i;
UINT32 u;
MPZ temp ,A;
A.len = 0;
A.sign = 0;
A.dat[0] = 0; // 최종결과값에 0을 넣어주는 역할을 함.
for (i = 0; i < x->len; i++) {
u = (A.dat[0] + x->dat[i] * y->dat[0]) * minv; // 어짜피 32bit 짜리 변수여서 자동으로 넘어가는 것은 없어진다.
// 기본적으로 알고리즘에 mod b가 있는데 32bit 변수에 담기 때문에 넘어가는 값들은 자동으로 없어진다.
MUL_WORD_MPZ(&temp, x->dat[i], y);
MPZ_UADD(&A, &A, &temp);
// 여기까지 2.1
MUL_WORD_MPZ(&temp, u, mod);
MPZ_UADD(&A, &A, &temp);
MPZ_WORD_SHIFT(&A, &A, -1); //여기서 b는 2^32이므로 워드 쉬프트를 사용하면 한 워드가 없어진다.
// 여기까지 2.2
}
if (Compare_MPZ(&A, mod) == 1) {
MPZ_USUB(r, &A, mod); // 더 크면 없애줘야한다.
}
else
COPY_MPZ(r, &A);
}
4. 몽고메리 지수승.
void LtoR_Exp(MPZ* r, MPZ* input, MPZ* exp, MPZ* mod)
{
MPZ x, y;
SINT32 i, j;
COPY_MPZ(&x, input);
for (i = 31; i >= 0; i--)
{
if ((exp->dat[exp->len - 1] & (1 << i)))
break;
}
for (j = i - 1; j >= 0; j--)
{
MPZ_MODMUL(r, &x, &x, mod);
COPY_MPZ(&x, r);
if (exp->dat[exp->len - 1] & (1 << j))
{
MPZ_MODMUL(r, &x, input, mod);
COPY_MPZ(&x, r);
}
}
for (i = exp->len - 2; i >= 0; i--)
{
for (j = 31; j >= 0; j--)
{
MPZ_MODMUL(r, &x, &x, mod);
COPY_MPZ(&x, r);
if (exp->dat[i] & (1 << j))
{
MPZ_MODMUL(r, &x, input, mod);
COPY_MPZ(&x, r);
}
}
}
}
기존의 구현했던 코드는 기억이 나는지....
일단 지수의 1비트씩 스캔하면서 1이면 제곱하고 그 곱해주고 0이면 제곱만하는 알고리즘이다.
void LtoR_Mont_Exp(MPZ* r, MPZ* input, MPZ* exp, MPZ* mod , int tlqkf)
{
MPZ A, x_,x, y , R2, R2n, temp, one;
SINT32 i, j;
UINT32 minv;
R2.len = mod->len * 2 + 1;
memset(R2.dat, 0, sizeof(UINT32) * mod->len * 2);
R2.dat[2 * mod->len] = 1;//최상위 워드는 1로 채우기
R2.sign = 0;
MPZ_UDIV(&temp, &R2n, &R2, mod);
minv = minus_modinv(mod->dat[0], (UINT64)0x100000000);
MPZ_UDIV(&A, &temp, input, mod);
Mont_Mul(&A,input, &R2n, mod, minv);
COPY_MPZ(&x_, &A); //A = x_ = input * R
for (i = 31; i >= 0; i--)
{
if ((exp->dat[exp->len - 1] & (1 << i)))
break;
}
for (j = i - 1; j >= 0; j--)
{
Mont_Mul(&A, &A, &A, mod , minv);
if (exp->dat[exp->len - 1] & (1 << j))
{
Mont_Mul(&A , &A, &x_,mod, minv);
}
}
for (i = exp->len - 2; i >= 0; i--)
{
for (j = 31; j >= 0; j--)
{
Mont_Mul(&A, &A, &A, mod, minv);
if (exp->dat[i] & (1 << j))
{
Mont_Mul(&A, &A, &x_, mod, minv);
}
}
}
one.dat[0] = 1;
one.len = 1;
one.sign = 0;
Mont_Mul(r, &A, &one, mod, minv);
}
다음은 몽고메리 연산을 활용한 지수승 연산이다.
일단 기본적으로 몽고메리 연산에 필요한것(항상 알고리즘을 생각할 때 무엇무엇이 입력으로 들어가고 출력은 어떤 형태로 나오는지 이해해야함.) 따라서 아까 필요했던 것 역원 계산! M'이 필요하고 b라는 워드가 필요하니까 아래와 같은 코드를 이용해서 M'를 계산하게 된다.
minv = minus_modinv(mod->dat[0], (UINT64)0x100000000);
그리고 나머지는 mod_MUL을 mont_mul로 변환하기 만한다.
항상 알고리즘 공부할 때 위의 하이라이트 친게 중요한것 같음
1. 연산에 필요한 입력값이 무엇인가?
2. 연산후 나오는 출력값은 어떤 형태인가?
모르는 거있으면 톡하셈