NTT 的 C/C++ 实现

NTT (C ref)

ntt_ref.h

#ifndef NTT_H
#define NTT_H

/* Short integer aliases used throughout the NTT code.
   NOTE: plain char is signed or unsigned depending on the platform, so int8
   is only a signed type where char is signed. */
typedef char int8;
typedef short int16;
typedef int int32;
typedef long long int64;

typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int  uint32;
typedef unsigned long long uint64;


//################################### Parameters ###################################

#define NTT_NEG 1   // 0: cyclic NTT (mod X^n - 1). 1: negacyclic NTT (mod X^n + 1).

#define NTT_Q 12289
#define NTT_N 1024

#define NTT_ROUND 10                       // number of butterfly layers performed by ntt()
#define NTT_ORDER (1<<(NTT_ROUND+1))       // order of the root of unity the twiddle tables are built from
#define NTT_BASELEN (NTT_N>>NTT_ROUND)     // length of each residue polynomial left after ntt()

#define NTT_ZETA 7


//################################### Fast modular reduction ###################################

#define MONT_L  16
#define MONT_R  (1LL<<MONT_L)
#define MONT    4091                // MONT_R mod q
#define NEGQINV 12287               // -q^-1 mod MONT_R

#define BARR_R  (1LL<<32)
#define BARR    349497              // round(2^32/q)

/*
 * Montgomery reduction: evaluates to a * MONT_R^{-1} mod q as an int64.
 * m = (a * NEGQINV) mod MONT_R makes a + m*q an exact multiple of MONT_R, so
 * the shift is an exact division. For 0 <= a < q*MONT_R the result lies in
 * [0, 2q) (more precisely below a/MONT_R + q), i.e. it is NOT fully reduced.
 * The mask alone extracts the low MONT_L bits; no narrowing cast is needed.
 * The argument is evaluated twice: do not pass expressions with side effects.
 */
#define montgomery_reduce(a) (((a) + (((int64)(a) * NEGQINV) & (MONT_R - 1)) * NTT_Q) >> MONT_L)

/*
 * Barrett reduction: evaluates to a value congruent to a mod q (int64).
 * For 0 <= a < 2^35 the result lies in [-q, q), because BARR slightly
 * overestimates 2^32/q. The argument is evaluated twice.
 */
#define barrett_reduce(a) ((a)-((BARR*(int64)(a))>>32)*NTT_Q)


//################################### Function declarations ###################################

void get_ntt_param(int32 q, int32 n, int32 r);

void ntt(int16* f);
void intt(int16* f);
void nttmul(int16* r, const int16* a, const int16* b);

int32 print_bytes(int8* arr, int32 len);
int32 print_coeffs(int16* arr, int32 len);


#endif

ntt_ref.c

#include <stdio.h>
#include <stdlib.h>
#include "ntt_ref.h"


//################################### 参数设置 ###################################

// Precomputed tables. The empty initializers are placeholders: fill them with
// the lists printed by get_ntt_param(), whose helpers print initializers under
// these exact names. An empty "{ }" initializer is only valid from C23 onward
// (or as a compiler extension).

// zetas[i] = w^i mod q for i in [0, NTT_ORDER], w the root printed by find_root().
int16 zetas[NTT_ORDER + 1] = {
    
      };

// zetas_mont[i] = zetas[i] * MONT_R mod q (Montgomery form of zetas[i]).
int16 zetas_mont[NTT_ORDER + 1] = {
    
      };

// bitrev_list[i] = brv(i, NTT_ROUND + 1), the bit reversal of i over NTT_ROUND + 1 bits.
int16 bitrev_list[NTT_ORDER] = {
    
      };

// factor = 2^{-NTT_ROUND} mod q; factor_mont = factor * MONT_R mod q (see get_intt_factor).
int32 factor = 12277, factor_mont = 64;


//################################### 通用函数 ###################################

int32 brv(int32 b, int32 l)
{
    /* Bit-reverse the lowest l bits of b (higher bits of b are discarded). */
    int32 reversed = 0;

    for (int32 remaining = l; remaining > 0; remaining--)
    {
        reversed = (reversed << 1) | (b & 1);
        b >>= 1;
    }

    return reversed;
}

int64 fast_pow(int64 a, int64 b, int64 q)
{
    /* Square-and-multiply: returns a^b mod q for b >= 0 (1 when b == 0). */
    int64 acc = 1;
    int64 base = a;

    for (int64 e = b; e != 0; e >>= 1)
    {
        if (e % 2 == 1)
            acc = acc * base % q;
        base = base * base % q;
    }
    return acc;
}

int64 exgcd(int64* x, int64* y, int64 a, int64 b)
{
    /* Extended Euclid: returns g = gcd(a, b) and stores x, y with a*x + b*y == g.
       The recursive call fills the coefficients in swapped order, so only y
       needs the back-substitution step. */
    if (b != 0)
    {
        int64 g = exgcd(y, x, b, a % b);
        *y -= (a / b) * (*x);
        return g;
    }
    *x = 1;
    *y = 0;
    return a;
}

int32 print_bytes(int8* arr, int32 len)
{
    /*
     * Print arr[0..len) as "[ a, b, ... ]" (no trailing newline). Returns 0.
     * An empty or NULL array prints "[ ]"; previously arr[0] was read
     * unconditionally, an out-of-bounds access when len <= 0.
     */
    if (arr == NULL || len <= 0)
    {
        printf("[ ]");
        return 0;
    }

    printf("[ %d", arr[0]);
    for (int32 i = 1; i < len; i++)
        printf(", %d", arr[i]);
    printf(" ]");
    return 0;
}

int32 print_coeffs(int16* arr, int32 len)
{
    /*
     * Print the coefficients arr[0..len) as "[ a, b, ... ]" (no newline). Returns 0.
     * An empty or NULL array prints "[ ]" instead of reading arr[0] out of bounds.
     */
    if (arr == NULL || len <= 0)
    {
        printf("[ ]");
        return 0;
    }

    printf("[ %d", arr[0]);
    for (int32 i = 1; i < len; i++)
        printf(", %d", arr[i]);
    printf(" ]");
    return 0;
}


//############################## 预计算参数 ##############################

int64 find_root(int64 q, int64 ord)
{
    /*
     * Return the smallest w in [2, q) with w^ord == 1 and w^(ord/2) != 1 (mod q),
     * i.e. a primitive ord-th root of unity when ord is a power of two.
     * The root is printed; 0 is returned when no such w exists.
     */
    for (int64 w = 2; w < q; w++)
    {
        if (fast_pow(w, ord, q) != 1)
            continue;
        if (fast_pow(w, ord >> 1, q) == 1)
            continue;

        printf("%lld-th root = %lld\n\n", ord, w);
        return w;
    }
    return 0;
}

void get_zetas(int32* zetas, int32 zeta, int32 q, int32 ord)
{
    /* Fill zetas[0..ord] with zeta^i mod q and print them as a C initializer. */
    int64 power = 1;

    zetas[0] = 1;
    printf("zetas = { %d", zetas[0]);

    for (int64 i = 1; i <= ord; i++)
    {
        power = power * zeta % q;
        zetas[i] = power;
        printf(", %lld", power);
    }

    printf(" };\n\n");
}

void get_zetas_mont(int32* zetas_mont, int32* zetas, int64 q, int64 ord, int64 mont)
{
    /* zetas_mont[i] = mont * zetas[i] mod q for i in [0, ord] (entry 0 is
       always written); the values are printed as a C initializer. */
    int64 i = 0;
    do
    {
        int64 value = mont * zetas[i] % q;
        zetas_mont[i] = value;
        if (i == 0)
            printf("zetas_mont = { %lld", value);
        else
            printf(", %lld", value);
    } while (++i <= ord);

    printf(" };\n\n");
}

void get_brv_table(int32 bits)
{
    /* Print bitrev_list: brv(i, bits) for every i in [0, 2^bits); brv(0, .) is 0. */
    const int32 count = (1LL << bits);

    printf("bitrev_list = { 0");
    for (int32 idx = 1; idx < count; idx++)
    {
        printf(", %d", brv(idx, bits));
    }
    printf(" };\n\n");
}

void get_intt_factor(int64 q, int64 r, int64 mont)
{
    /*
     * Print factor = 2^{-r} mod q and factor_mont = factor * mont mod q, where
     * mont = MONT_R mod q. These are the final scaling constants used by intt().
     * Previously the gcd was computed but ignored; a missing inverse is now reported.
     */
    int64 factor, bezout_q;
    if (exgcd(&factor, &bezout_q, 1LL << r, q) != 1)
    {
        printf("gcd(2^r, NTT_Q) != 1\n\n");
        return;
    }
    if (factor < 0)
        factor += q;

    int64 factor_mont = (factor * mont) % q;
    printf("factor = %lld, factor_mont = %lld\n\n", factor, factor_mont);
}

void get_ntt_param(int32 q, int32 n, int32 r)
{
    /*
     * Print every precomputed constant the NTT needs for modulus q, length n
     * and r butterfly layers. This covers the Montgomery/Barrett constants, the
     * primitive 2^{r+1}-th root, the zetas / zetas_mont / bitrev_list initializers
     * and the inverse-NTT scaling factors.
     * Stops early (after a message) if q is not invertible mod MONT_R, if no
     * root exists, or if allocation fails.
     */
    printf("/******************************* get ntt params *******************************/\n\n");

    printf("NTT_Q = %d, NTT_N = %d, NTT_ROUND = %d, MONT_R = %lld, BARR_R = %lld\n\n", q, n, r, MONT_R, BARR_R);

    int64 x, y;
    int64 mont = MONT_R % q;
    if (exgcd(&x, &y, q, MONT_R) != 1)
    {
        printf("gcd(NTT_Q, MONT_R) != 1\n");
        return;
    }
    printf("mont = %lld mod q\n\nqinv = %lld mod R\n\n", mont, x);

    printf("barret = 2^32/q =  %lld\n\n", (BARR_R + (q >> 1)) / q);

    int32 order = 1 << (r + 1);
    int64 Zeta = find_root(q, order);
    if (Zeta == 0)
    {
        /* find_root() returns 0 when q has no primitive order-th root of unity. */
        printf("no primitive %d-th root of unity mod %d\n\n", order, q);
        return;
    }

    int32* Zetas = (int32*)malloc(sizeof(int32) * (order + 1));
    int32* Zetas_mont = (int32*)malloc(sizeof(int32) * (order + 1));
    if (Zetas == NULL || Zetas_mont == NULL)
    {
        printf("get_ntt_param: out of memory\n\n");
        free(Zetas);
        free(Zetas_mont);
        return;
    }

    get_zetas(Zetas, (int32)Zeta, q, order);
    get_zetas_mont(Zetas_mont, Zetas, q, order, mont);

    get_brv_table(r + 1);
    get_intt_factor(q, r, mont);

    printf("//******************************* get ntt params *******************************//\n\n");
    free(Zetas);
    free(Zetas_mont);
}


//################################### NTT变换 ###################################


void ntt(int16* f) {
    
    
    int32 Blocknum = 1;
    int32 Blocksize = NTT_N;
    int32 Round = 0;

    /*
        Radix-2
        X = X + WY
        Y = X - WY
    */
    if ((NTT_ROUND & 1) == 1) {
    
       
        int32 offset = Blocksize >> 1;
        int32 X, Y, WY;
        int32 zeta_mont = zetas_mont[Blocknum * NTT_NEG];
        int16* pf = f;

        for (int32 k = 0; k < offset; k++) {
    
    
            X = pf[k];
            WY = pf[k + offset] * zeta_mont;
            WY = montgomery_reduce(WY);

            pf[k] = X + WY;
            pf[k + offset] = X + NTT_Q - WY;
        }

        Blocknum <<= 1;
        Blocksize >>= 1;
        Round++;
    }

    /*
        Radix-4
        Harvey,输入输出范围[0,2q)
        X1 = (X1 + W*Y1) + W0*(X2 + W*Y2),范围[0,4q)
        X2 = (X1 + W*Y1) - W0*(X2 + W*Y2),范围[0,4q)
        Y1 = (X1 - W*Y1) + W1*(X2 - W*Y2),范围[0,4q)
        X2 = (X1 - W*Y1) - W1*(X2 - W*Y2),范围[0,4q)
        先约束X1范围[0,q),接着约束(X1 + W*Y1)和(X1 - W*Y1)范围[0,q),共三次模约减
    */
    for (; Round < NTT_ROUND; Round += 2, Blocksize >>= 2, Blocknum <<= 2) {
    
    
        int32 offset = Blocksize >> 2;
        int32 X1, X2, Y1, Y2, WY;
        int32 zeta_mont, zeta1_mont, zeta2_mont;

        for (int32 i = 0; i < Blocknum; i++) {
    
    
            int16* pf = f + i * Blocksize;

            /*
                j=0是原始数组,第j次迭代中,j-1层第i个分块使用的单位根,
                w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{2^{r-j}*brv_{j}(2i)}
                brv_{j}(2i) = brv_{r}/(r-j+1)
                因此 w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{brv_{r}(i)/2}
            */
            zeta_mont = zetas_mont[bitrev_list[Blocknum * NTT_NEG + i] >> 1];            //Round层第i块
            zeta1_mont = zetas_mont[bitrev_list[2 * Blocknum * NTT_NEG + i * 2] >> 1];       //Round+1层第2i块
            zeta2_mont = zetas_mont[bitrev_list[2 * Blocknum * NTT_NEG + i * 2 + 1] >> 1];   //Round+1层第2i+1块

            for (int k = 0; k < offset; k++) {
    
    
                X1 = pf[k];
                X2 = pf[k + offset];
                Y1 = pf[k + offset * 2];
                Y2 = pf[k + offset * 3];
                X1 -= ((NTT_Q - X1 - 1) >> 31) & NTT_Q;

                WY = montgomery_reduce(Y1 * zeta_mont);
                Y1 = X1 + NTT_Q - WY;
                X1 += WY;
                X1 -= ((NTT_Q - X1 - 1) >> 31) & NTT_Q;
                Y1 -= ((NTT_Q - Y1 - 1) >> 31) & NTT_Q;

                WY = montgomery_reduce(Y2 * zeta_mont);
                Y2 = X2 + NTT_Q - WY;
                X2 += WY;

                WY = montgomery_reduce(X2 * zeta1_mont);
                X2 = X1 + NTT_Q - WY;
                X1 += WY;

                WY = montgomery_reduce(Y2 * zeta2_mont);
                Y2 = Y1 + NTT_Q - WY;
                Y1 += WY;
                
                pf[k] = X1;
                pf[k + offset] = X2;
                pf[k + offset * 2] = Y1;
                pf[k + offset * 3] = Y2;
            }
        }
    }

    //for (int32 k = 0; k < NTT_N; k++) {
    
    
    //    f[k] -= ((NTT_Q - f[k] - 1) >> 31) & NTT_Q;  //模约减,从[0,2q)约减到[0,q)
    //}
}

void intt(int16* f) {
    
    
    int32 Blocknum = 1 << NTT_ROUND;
    int32 Blocksize = NTT_N >> NTT_ROUND;
    int32 Round = NTT_ROUND;
    int32 Qtimes2 = NTT_Q * 2;
    Blocksize <<= 2;
    Blocknum >>= 2;

    /*
        Radix-4
        Harvey,输入输出范围[0,2q)
        X1 = (X1 + X2) + (Y1 + Y2),范围[0,8q)
        X2 = IW0*(X1 - X2) + IW1*(Y1 - Y2),范围[0,2q)
        Y1 = IW*((X1 + X2) + (Y1 + Y2)),范围[0,q)
        Y2 = IW*(IW0*(X1 - X2) + IW1*(Y1 - Y2)),范围[0,q)
        先约束(X1 + X2)和(Y1 + Y2)范围[0,2q),接着约束(X1 + X2) + (Y1 + Y2)范围[0,2q),共三次模约减
    */
    for (; Round > 1; Round -= 2, Blocksize <<= 2, Blocknum >>= 2) {
    
    
        int32 offset = Blocksize >> 2;
        int32 X1, X2, Y1, Y2, T;
        int32 zeta_mont, zeta1_mont, zeta2_mont;

        for (int32 i = 0; i < Blocknum; i++) {
    
    
            int16* pf = f + i * Blocksize;

            /*
                j=0是原始数组,第j次迭代中,j-1层第i个分块使用的单位根,
                w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{2^{r-j}*brv_{j}(2i)}
                brv_{j}(2i) = brv_{r}/(r-j+1)
                因此 w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{brv_{r}(i)/2}
            */
            zeta_mont = zetas_mont[NTT_ORDER - (bitrev_list[Blocknum * NTT_NEG + i] >> 1)];            //Round层第i块
            zeta1_mont = zetas_mont[NTT_ORDER - (bitrev_list[2 * Blocknum * NTT_NEG + i * 2] >> 1)];       //Round+1层第2i块
            zeta2_mont = zetas_mont[NTT_ORDER - (bitrev_list[2 * Blocknum * NTT_NEG + i * 2 + 1] >> 1)];   //Round+1层第2i+1块

            for (int k = 0; k < offset; k++) {
    
    
                X1 = pf[k];
                X2 = pf[k + offset];
                Y1 = pf[k + offset * 2];
                Y2 = pf[k + offset * 3];

                T = (X1 - X2) * zeta1_mont;
                X1 += X2;
                X2 = montgomery_reduce(T);
                X1 -= ((Qtimes2 - X1 - 1) >> 31) & Qtimes2; //模约减
                
                T = (Y1 - Y2) * zeta2_mont;
                Y1 += Y2;
                Y2 = montgomery_reduce(T);
                Y1 -= ((Qtimes2 - Y1 - 1) >> 31) & Qtimes2; //模约减
                
                T = (X1 - Y1) * zeta_mont;
                X1 += Y1;
                Y1 = montgomery_reduce(T);
                X1 -= ((Qtimes2 - X1 - 1) >> 31) & Qtimes2; //模约减

                T = (X2 - Y2) * zeta_mont;
                X2 += Y2;
                Y2 = montgomery_reduce(T);

                pf[k] = X1;
                pf[k + offset] = X2;
                pf[k + offset * 2] = Y1;
                pf[k + offset * 3] = Y2;
            }
        }
    }

    /*
        Radix-2
        X = X + Y
        Y = IW*(X - Y)
    */
    if ((NTT_ROUND & 1) == 1) {
    
    
        int32 offset = Blocksize >> 1;
        int32 X, Y, T;
        int32 zeta_mont = zetas_mont[NTT_ORDER - (bitrev_list[Blocknum * NTT_NEG] >> 1)];
        int16* pf = f;

        for (int32 k = 0; k < offset; k++) {
    
    
            X = pf[k];
            Y = pf[k + offset];

            T = (X - Y) * zeta_mont;
            pf[k] = X + Y;
            pf[k + offset] = montgomery_reduce(T);
        }
    }

    //逆变换因子
    for (int32 k = 0; k < NTT_N; k++) {
    
    
        int32 X = f[k] * factor_mont;
        X = montgomery_reduce(X);
        f[k] = X - (((NTT_Q - X - 1) >> 31) & NTT_Q);
    }
}

inline void basemul(int16* r, const int16* a, const int16* b, int16 zeta)
{
    
    
    int32 res;  // 用更长的累加器,延迟取模运算
    int32 s;
    for (int32 i = 0; i < NTT_BASELEN; i++)
    {
    
    
        res = 0;
        s = NTT_BASELEN + i;
        for (int32 j = 0; j <= i; j++)
            res += b[j] * a[i - j];
        for (int32 j = i + 1; j < NTT_BASELEN; j++)
            res += zeta * barrett_reduce(b[j] * a[s - j]);
        r[i] = barrett_reduce(res);
    }
}

void nttmul(int16* r, const int16* a, const int16* b)
{
    
    
    // 2^{r-1} 个 n/2^{r-1} 长小多项式,NTT_ROUND = r-1
    int32 num = 1 << NTT_ROUND;
    for (int32 i = 0; i < num; i++)
    {
    
    
#if (NTT_BASELEN == 1)
        int32 tmp = *a * *b;
        *r = barrett_reduce(tmp);
#elif (NTT_BASELEN == 2)
        // 第r层第2^{r-1}+i个多项式使用的单位根,
        // w_{2^r}^{brv_r(2^{r-1}+i)}
        int32 zeta = zetas[bitrev_list[num * NTT_NEG + i]];
        int32 tmp0 = a[0] * b[0] + zeta * barrett_reduce(a[1] * b[1]);
        int32 tmp1 = a[0] * b[1] + a[1] * b[0];
        r[0] = barrett_reduce(tmp0);
        r[1] = barrett_reduce(tmp1);
#else
        // 第r层第2^{r-1}+i个多项式使用的单位根,
        // w_{2^r}^{brv_r(2^{r-1}+i)}
        int32 zeta = zetas[bitrev_list[num * NTT_NEG + i]];
        basemul(r, a, b, zeta);
#endif

        r += NTT_BASELEN;
        a += NTT_BASELEN;
        b += NTT_BASELEN;
    }
}

NTT (AVX2)

ntt_avx2.h

#ifndef NTT_H
#define NTT_H

/* Short integer aliases used throughout the NTT code.
   NOTE: plain char may be signed or unsigned depending on the platform. */
typedef char int8;
typedef short int16;
typedef int int32;
typedef long long int64;

typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int  uint32;
typedef unsigned long long uint64;


//################################### Parameters ###################################

#define NTT_NEG 1   // 0: cyclic NTT (mod X^n - 1). 1: negacyclic NTT (mod X^n + 1).

#define NTT_Q 12289
#define NTT_N 1024

#define NTT_ROUND 10                       // number of butterfly layers performed by ntt()
#define NTT_ORDER (1<<(NTT_ROUND+1))       // order of the root of unity the twiddle tables are built from
#define NTT_BASELEN (NTT_N>>NTT_ROUND)     // length of each residue polynomial left after ntt()

#define NTT_ZETA 7


//################################### Fast modular reduction ###################################

#define MONT_L  16
#define MONT_R  (1LL<<MONT_L)
#define MONT    4091                // MONT_R mod q
#define QINV    -12287              // q^-1 mod MONT_R, as a signed 16-bit value
#define NEGQINV 12287               // -q^-1 mod MONT_R

#define BARR_epi16    5         // round(2^16/q)
#define BARR_epi32    349497    // round(2^32/q)

// Montgomery reduction, computes a*R^{-1} mod q.
// Signed-qinv variant, kept for reference (result can be negative):
//#define montgomery_reduce(a) (((a) - (int32)((int16)((int64)(a)*QINV))*NTT_Q)>>MONT_L)

// With a negative qinv the result lies in [-q, q], which conflicts with the NTT
// logic (it requires non-negative values), so -q^-1 (NEGQINV) is used instead.
// NOTE(review): NewHope (q = 12289) uses an 18-bit MONT_R; with 16 bits the
// result can exceed q for large inputs -- confirm callers tolerate [0, 2q).
// The argument is evaluated twice.
#define montgomery_reduce(a) (((a) + (((int64)(a)*NEGQINV)&(MONT_R-1))*NTT_Q)>>MONT_L)  

// Barrett reduction, computes a value congruent to a mod q (not fully reduced).
#define barrett_reduce(a) ((a)-((BARR_epi32*(int64)(a))>>32)*NTT_Q)


//################################### Function declarations ###################################

void get_ntt_param(int32 q, int32 n, int32 r);

void ntt(int16* f);
// NOTE(review): the meaning of the `mont` flag is defined in ntt_avx2.c;
// presumably it selects the R/2^r vs R^2/2^r scaling (factor_mont / factor_mont2) -- TODO confirm.
void intt(int16* f, int8 mont);
void nttmul(int16* r, const int16* a, const int16* b, int8 mont);

int32 print_bytes(int8* arr, int32 len);
int32 print_coeffs(int16* arr, int32 len);


#endif

ntt_avx2.c

#include <stdio.h>
#include <stdlib.h>
#include <xmmintrin.h>  // __m128
#include <immintrin.h>  // __m256
//#include <zmmintrin.h>  // __m512
#include "ntt_avx2.h"


//################################### 参数设置 ###################################

// Precomputed tables. The empty initializers are placeholders: fill them with
// the lists printed by get_ntt_param(), whose helpers print initializers under
// these exact names. An empty "{ }" initializer is only valid from C23 onward
// (or as a compiler extension).

// zetas[i] = w^i mod q for i in [0, NTT_ORDER], w the root printed by find_root().
const int16 zetas[NTT_ORDER + 1] = {
    
      };

// zetas_mont[i] = zetas[i] * MONT_R mod q (Montgomery form of zetas[i]).
const int16 zetas_mont[NTT_ORDER + 1] = {
    
      };

// bitrev_list[i] = brv(i, NTT_ROUND + 1), the bit reversal of i over NTT_ROUND + 1 bits.
const int16 bitrev_list[NTT_ORDER] = {
    
      };

// 1/2^r, R/2^r and R^2/2^r mod q (r = NTT_ROUND, R = MONT_R); see get_intt_factor.
const int32 factor = 12277, factor_mont = 64, factor_mont2 = 3755;


//################################### 通用函数 ###################################

int32 brv(int32 b, int32 l)
{
    /* Bit-reverse the lowest l bits of b (higher bits of b are discarded). */
    int32 reversed = 0;

    for (int32 remaining = l; remaining > 0; remaining--)
    {
        reversed = (reversed << 1) | (b & 1);
        b >>= 1;
    }

    return reversed;
}

int64 fast_pow(int64 a, int64 b, int64 q)
{
    /* Square-and-multiply: returns a^b mod q for b >= 0 (1 when b == 0). */
    int64 acc = 1;
    int64 base = a;

    for (int64 e = b; e != 0; e >>= 1)
    {
        if (e % 2 == 1)
            acc = acc * base % q;
        base = base * base % q;
    }
    return acc;
}

int64 exgcd(int64* x, int64* y, int64 a, int64 b)
{
    /* Extended Euclid: returns g = gcd(a, b) and stores x, y with a*x + b*y == g.
       The recursive call fills the coefficients in swapped order, so only y
       needs the back-substitution step. */
    if (b != 0)
    {
        int64 g = exgcd(y, x, b, a % b);
        *y -= (a / b) * (*x);
        return g;
    }
    *x = 1;
    *y = 0;
    return a;
}

int32 print_bytes(int8* arr, int32 len)
{
    /*
     * Print arr[0..len) as "[ a, b, ... ]" (no trailing newline). Returns 0.
     * An empty or NULL array prints "[ ]" instead of reading arr[0] out of bounds.
     */
    if (arr == NULL || len <= 0)
    {
        printf("[ ]");
        return 0;
    }

    printf("[ %d", arr[0]);
    for (int32 i = 1; i < len; i++)
        printf(", %d", arr[i]);
    printf(" ]");
    return 0;
}

int32 print_coeffs(int16* arr, int32 len)
{
    /*
     * Print the coefficients arr[0..len) as "[ a, b, ... ]" (no newline). Returns 0.
     * An empty or NULL array prints "[ ]" instead of reading arr[0] out of bounds.
     */
    if (arr == NULL || len <= 0)
    {
        printf("[ ]");
        return 0;
    }

    printf("[ %d", arr[0]);
    for (int32 i = 1; i < len; i++)
        printf(", %d", arr[i]);
    printf(" ]");
    return 0;
}


//############################## 预计算参数 ##############################

int64 find_root(int64 q, int64 ord)
{
    /*
     * Return the smallest w in [2, q) with w^ord == 1 and w^(ord/2) != 1 (mod q),
     * i.e. a primitive ord-th root of unity when ord is a power of two.
     * The root is printed; 0 is returned when no such w exists.
     */
    for (int64 w = 2; w < q; w++)
    {
        if (fast_pow(w, ord, q) != 1)
            continue;
        if (fast_pow(w, ord >> 1, q) == 1)
            continue;

        printf("%lld-th root = %lld\n\n", ord, w);
        return w;
    }
    return 0;
}

void get_zetas(int32* zetas, int32 zeta, int32 q, int32 ord)
{
    /* Fill zetas[0..ord] with zeta^i mod q and print them as a C initializer. */
    int64 power = 1;

    zetas[0] = 1;
    printf("zetas = { %d", zetas[0]);

    for (int64 i = 1; i <= ord; i++)
    {
        power = power * zeta % q;
        zetas[i] = power;
        printf(", %lld", power);
    }

    printf(" };\n\n");
}

void get_zetas_mont(int32* zetas_mont, int32* zetas, int64 q, int64 ord, int64 mont)
{
    /* zetas_mont[i] = mont * zetas[i] mod q for i in [0, ord] (entry 0 is
       always written); the values are printed as a C initializer. */
    int64 i = 0;
    do
    {
        int64 value = mont * zetas[i] % q;
        zetas_mont[i] = value;
        if (i == 0)
            printf("zetas_mont = { %lld", value);
        else
            printf(", %lld", value);
    } while (++i <= ord);

    printf(" };\n\n");
}

void get_brv_table(int32 bits)
{
    /* Print bitrev_list: brv(i, bits) for every i in [0, 2^bits); brv(0, .) is 0. */
    const int32 count = (1LL << bits);

    printf("bitrev_list = { 0");
    for (int32 idx = 1; idx < count; idx++)
    {
        printf(", %d", brv(idx, bits));
    }
    printf(" };\n\n");
}

void get_intt_factor(int64 q, int64 r, int64 mont)
{
    /*
     * Print the inverse-NTT scaling constants, with mont = MONT_R mod q:
     *   factor       = 2^{-r} mod q
     *   factor_mont  = factor * mont mod q     (R / 2^r)
     *   factor_mont2 = factor_mont * mont mod q (R^2 / 2^r)
     * Previously the gcd was computed but ignored; a missing inverse is now reported.
     */
    int64 factor, bezout_q;
    if (exgcd(&factor, &bezout_q, 1LL << r, q) != 1)
    {
        printf("gcd(2^r, NTT_Q) != 1\n\n");
        return;
    }
    if (factor < 0)
        factor += q;

    int64 factor_mont = (factor * mont) % q;
    int64 factor_mont2 = (factor_mont * mont) % q;
    printf("factor = %lld, factor_mont = %lld, factor_mont2 = %lld\n\n", factor, factor_mont, factor_mont2);
}

void get_ntt_param(int32 q, int32 n, int32 r)
{
    /*
     * Print every precomputed constant the AVX2 NTT needs for modulus q,
     * length n and r butterfly layers: Montgomery/Barrett constants, the
     * primitive 2^{r+1}-th root, the table initializers and the inverse-NTT
     * scaling factors.
     *
     * The Barrett constants are round(2^16/q) and round(2^32/q); the old
     * expression (16 + q/2)/q and (32 + q/2)/q always printed 0.
     * Stops early (after a message) on a non-invertible q, a missing root, or
     * allocation failure.
     */
    printf("/******************************* get ntt params *******************************/\n\n");

    printf("NTT_Q = %d, NTT_N = %d, NTT_ROUND = %d\n\n", q, n, r);

    int64 x, y;
    int64 mont = MONT_R % q;
    if (exgcd(&x, &y, q, MONT_R) != 1)
    {
        printf("gcd(NTT_Q, MONT_R) != 1\n\n");
        return;
    }
    printf("MONT_R = %lld\nMONT = %lld mod q\nQINV = %lld mod R\n\n", MONT_R, mont, x);

    printf("BARR_R = 2^16, BARR_epi16 = 2^16/q = %lld\nBARR_R = 2^32, BARR_epi32 = 2^32/q =  %lld\n\n",
           ((1LL << 16) + (q >> 1)) / q, ((1LL << 32) + (q >> 1)) / q);

    int32 order = 1 << (r + 1);
    int64 Zeta = find_root(q, order);
    if (Zeta == 0)
    {
        /* find_root() returns 0 when q has no primitive order-th root of unity. */
        printf("no primitive %d-th root of unity mod %d\n\n", order, q);
        return;
    }

    int32* Zetas = (int32*)malloc(sizeof(int32) * (order + 1));
    int32* Zetas_mont = (int32*)malloc(sizeof(int32) * (order + 1));
    if (Zetas == NULL || Zetas_mont == NULL)
    {
        printf("get_ntt_param: out of memory\n\n");
        free(Zetas);
        free(Zetas_mont);
        return;
    }

    get_zetas(Zetas, (int32)Zeta, q, order);
    get_zetas_mont(Zetas_mont, Zetas, q, order, mont);

    get_brv_table(r + 1);
    get_intt_factor(q, r, mont);

    printf("//******************************* get ntt params *******************************//\n\n");
    free(Zetas);
    free(Zetas_mont);
}


//################################### Load/Store辅助函数 ###################################

/* Scratch register used by Half(). It has file scope, so these macros are not
   safe to use from several threads at once. */
__m256i NTT_TMP;

/* Half(X, Y): X <- [X.lo128 | Y.lo128], Y <- [X.hi128 | Y.hi128]. */
#define Half(X,Y)\
    NTT_TMP = _mm256_permute2x128_si256(X, Y, 0x31);\
    X = _mm256_permute2x128_si256(X, Y, 0x20);\
    Y = NTT_TMP;

/* Perm(X, Y): reorder 64-bit lanes 0,1,2,3 -> 0,2,1,3 in both registers
   (0xD8 == 0b11011000; hex keeps the code ISO C before C23). */
#define Perm(X,Y)\
    X = _mm256_permute4x64_epi64(X, 0xD8);\
    Y = _mm256_permute4x64_epi64(Y, 0xD8);

/* Coll_32(X, Y): inside each 128-bit lane, reorder 32-bit elements 0,1,2,3 -> 0,2,1,3. */
#define Coll_32(X,Y)\
    X = _mm256_shuffle_epi32(X, 0xD8);\
    Y = _mm256_shuffle_epi32(Y, 0xD8);

/* Byte indices that, inside each 128-bit lane, gather the even 16-bit
   elements into the low 8 bytes and the odd 16-bit elements into the high 8. */
const int8 CollIndex[32] = {
    0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15,
    0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15
};

/* The index vector is read with an unaligned load: CollIndex has no 32-byte
   alignment guarantee, and dereferencing it as __m256i* is an aligned access
   (undefined behaviour, and a fault with vmovdqa). */
#define Coll_16(X,Y)\
    X = _mm256_shuffle_epi8(X, _mm256_loadu_si256((const __m256i*)CollIndex));\
    Y = _mm256_shuffle_epi8(Y, _mm256_loadu_si256((const __m256i*)CollIndex));

/* Inverse of CollIndex: re-interleave low-half and high-half 16-bit elements. */
const int8 CollIndex_inv[32] = {
    0,1,8,9,2,3,10,11,4,5,12,13,6,7,14,15,
    0,1,8,9,2,3,10,11,4,5,12,13,6,7,14,15
};

#define Coll_16_inv(X,Y)\
    X = _mm256_shuffle_epi8(X, _mm256_loadu_si256((const __m256i*)CollIndex_inv));\
    Y = _mm256_shuffle_epi8(Y, _mm256_loadu_si256((const __m256i*)CollIndex_inv));

/* offsetK(X, Y) regroups two registers so that coefficients K apart end up in
   matching lanes of X and Y; offsetK_inv undoes the regrouping. */
#define offset8(X,Y) Half(X, Y);
#define offset8_inv(X,Y) Half(X, Y);

#define offset4(X,Y) Perm(X, Y); Half(X, Y); 
#define offset4_inv(X,Y) Half(X, Y); Perm(X, Y); 

#define offset2(X,Y) Coll_32(X,Y); Perm(X,Y); Half(X, Y); 
#define offset2_inv(X,Y) Half(X, Y); Perm(X, Y); Coll_32(X,Y); 

#define offset1(X,Y) Coll_16(X,Y); Perm(X,Y); Half(X, Y); 
#define offset1_inv(X,Y) Half(X, Y); Perm(X, Y); Coll_16_inv(X,Y); 


//################################### 快速模约减 ###################################

/*
 * Lane-wise Montgomery multiplication over 16 int16 lanes:
 *   result = (a*w - m*q) / 2^16 with m = (a*w) * q^{-1} mod 2^16,
 * then q is added to negative lanes so the result is non-negative.
 * w is normally a twiddle in Montgomery form (zeta * 2^16 mod q).
 */
__m256i montgomery_reduce_epi16(__m256i a, __m256i w) {
    const __m256i q_vec = _mm256_set1_epi16(NTT_Q);
    const __m256i qinv_vec = _mm256_set1_epi16(QINV);

    /* Signed 32-bit products a*w split as prod_hi * 2^16 + prod_lo (prod_lo read as unsigned). */
    __m256i prod_lo = _mm256_mullo_epi16(a, w);
    __m256i prod_hi = _mm256_mulhi_epi16(a, w);

    /* m*q shares its low 16 bits with a*w, so only the high halves need
       subtracting. m must be treated as unsigned, hence mulhi_epu16. */
    __m256i m = _mm256_mullo_epi16(prod_lo, qinv_vec);
    __m256i mq_hi = _mm256_mulhi_epu16(m, q_vec);

    __m256i r = _mm256_sub_epi16(prod_hi, mq_hi);

    /* Negative lanes (sign mask all ones) get q added back. */
    __m256i fix = _mm256_and_si256(_mm256_srai_epi16(r, 15), q_vec);
    return _mm256_add_epi16(r, fix);
}


/*
 * Lane-wise Barrett step over 16 int16 lanes: a - ((a * BARR_epi16) >> 16) * q.
 * The quotient comes from the signed high half of the product. The result is
 * congruent to a mod q but only roughly reduced (about [-q, 2q)).
 */
__m256i barrett_reduce_epi16(__m256i a) {
    const __m256i q_vec = _mm256_set1_epi16(NTT_Q);
    const __m256i quotient = _mm256_mulhi_epi16(a, _mm256_set1_epi16(BARR_epi16));

    return _mm256_sub_epi16(a, _mm256_mullo_epi16(quotient, q_vec));
}


/*
 * Lane-wise Barrett step over 8 int32 lanes: a - ((a * BARR_epi32) >> 32) * q.
 * _mm256_mul_epi32 only multiplies the even dwords, so the odd dwords are
 * swapped into even positions (0xB1 == 0b10110001), multiplied, and swapped
 * back. A logical 64-bit shift suffices: only the low 32 bits of each shifted
 * quotient are kept, and they match the arithmetic shift.
 * The result is congruent to a mod q but only roughly reduced (about [-q, 2q)).
 */
__m256i barrett_reduce_epi32(__m256i a) {
    const __m256i q_vec = _mm256_set1_epi32(NTT_Q);
    const __m256i m_vec = _mm256_set1_epi32(BARR_epi32);

    /* Even lanes: the quotient lands in the low dword of each qword. */
    __m256i quot_even = _mm256_srli_epi64(_mm256_mul_epi32(a, m_vec), 32);

    /* Odd lanes: swap, multiply, shift, swap back so the quotient sits in the high dword. */
    __m256i a_swapped = _mm256_shuffle_epi32(a, 0xB1);
    __m256i quot_odd = _mm256_srli_epi64(_mm256_mul_epi32(a_swapped, m_vec), 32);
    quot_odd = _mm256_shuffle_epi32(quot_odd, 0xB1);

    __m256i quotient = _mm256_or_si256(quot_even, quot_odd);
    return _mm256_sub_epi32(a, _mm256_mullo_epi32(quotient, q_vec));
}



__m256i iflt0_addq(__m256i a) {
    /* Add q to every negative int16 lane; non-negative lanes are unchanged. */
    const __m256i sign_mask = _mm256_srai_epi16(a, 15);
    return _mm256_add_epi16(a, _mm256_and_si256(sign_mask, _mm256_set1_epi16(NTT_Q)));
}

__m256i ifgeq_subq(__m256i a) {
    /* Subtract q from every int16 lane with a >= q. The test is the sign of
       (q - 1) - a, which is exact for a >= q - 1 - 32767 (no 16-bit wraparound). */
    const __m256i q_vec = _mm256_set1_epi16(NTT_Q);
    const __m256i diff = _mm256_sub_epi16(_mm256_set1_epi16(NTT_Q - 1), a);
    const __m256i ge_mask = _mm256_srai_epi16(diff, 15);
    return _mm256_sub_epi16(a, _mm256_and_si256(ge_mask, q_vec));
}

__m256i ifge2q_sub2q(__m256i a) {
    /* Subtract 2q from every int16 lane with a >= 2q. The test is the sign of
       (2q - 1) - a, which is exact for a >= 2q - 1 - 32767 (no 16-bit wraparound). */
    const __m256i two_q = _mm256_set1_epi16(2 * NTT_Q);
    const __m256i diff = _mm256_sub_epi16(_mm256_set1_epi16(2 * NTT_Q - 1), a);
    const __m256i ge_mask = _mm256_srai_epi16(diff, 15);
    return _mm256_sub_epi16(a, _mm256_and_si256(ge_mask, two_q));
}


//################################### NTT变换 ###################################


void ntt(int16* f) {
    
    
    int32 Blocknum = 1;
    int32 Blocksize = NTT_N;
    int32 Round = 0;
    __m256i T, Q = _mm256_set1_epi16(NTT_Q);

    /*
        Radix-2
        X = X + WY
        Y = X - WY
    */
    if ((NTT_ROUND & 1) == 1) {
    
    
        int32 offset = Blocksize >> 1;
        __m256i W = _mm256_set1_epi16(bitrev_list[Blocknum * NTT_NEG] >> 1);

        for (int32 k = 0; k < offset; k += 16) {
    
    
            __m256i X = _mm256_loadu_si256((__m256i*)(f + k));
            __m256i Y = _mm256_loadu_si256((__m256i*)(f + k + offset));

            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);

            _mm256_storeu_si256((__m256i*)(f + k), X);
            _mm256_storeu_si256((__m256i*)(f + k + offset), Y);
        }

        Blocknum <<= 1;
        Blocksize >>= 1;
        Round++;
    }

    /*
        Radix-4
        Harvey,输入输出范围[0,2q)
        X1 = (X1 + W*Y1) + W0*(X2 + W*Y2),范围[0,4q)
        X2 = (X1 + W*Y1) - W0*(X2 + W*Y2),范围[0,4q)
        Y1 = (X1 - W*Y1) + W1*(X2 - W*Y2),范围[0,4q)
        X2 = (X1 - W*Y1) - W1*(X2 - W*Y2),范围[0,4q)
        先约束X1范围[0,q),接着约束(X1 + W*Y1)和(X1 - W*Y1)范围[0,q),共三次模约减
    */
    for (; Round < NTT_ROUND; Round += 2, Blocksize >>= 2, Blocknum <<= 2) {
    
    
        if (Blocksize >= 64)
            goto Block64;
        else
            switch (Blocksize)
            {
    
    
            case 32: goto Block32;
            case 16: goto Block16;
            case 8: goto Block8;
            case 4: goto Block4;
            default:
                goto Error; //本代码仅处理:NTT_N 是2的幂次
            }

    Block64: //处理分块大小整除64的情况,使用4个YMM,处理1个块
        for (int32 i = 0; i < Blocknum; i++) {
    
    
            int32 offset = Blocksize >> 2;
            int32 num = offset >> 4; //16个系数1个YMM
            int16* pf = f + i * Blocksize;

            /*
                j=0是原始数组,第j次迭代中,j-1层第i个分块使用的单位根,
                w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{2^{r-j}*brv_{j}(2i)}
                brv_{j}(2i) = brv_{r}/(r-j+1)
                因此 w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{brv_{r}(i)/2}
            */
            int32 ind = Blocknum * NTT_NEG + i;
            __m256i W = _mm256_set1_epi16(zetas_mont[bitrev_list[ind] >> 1]);            //Round层第i块
            __m256i W0 = _mm256_set1_epi16(zetas_mont[bitrev_list[ind * 2] >> 1]);       //Round+1层第2i块
            __m256i W1 = _mm256_set1_epi16(zetas_mont[bitrev_list[ind * 2 + 1] >> 1]);   //Round+1层第2i+1块

            for (int32 k = 0; k < num; k++) {
    
    
                __m256i X1 = _mm256_loadu_si256((__m256i*)(pf + k * 16));
                __m256i X2 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset));
                __m256i Y1 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset * 2));
                __m256i Y2 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset * 3));

                X1 = ifgeq_subq(X1);
                T = montgomery_reduce_epi16(Y1, W);
                Y1 = _mm256_add_epi16(X1, Q);
                Y1 = _mm256_sub_epi16(Y1, T);
                X1 = _mm256_add_epi16(X1, T);
                
                X2 = ifgeq_subq(X2);
                T = montgomery_reduce_epi16(Y2, W);
                Y2 = _mm256_add_epi16(X2, Q);
                Y2 = _mm256_sub_epi16(Y2, T);
                X2 = _mm256_add_epi16(X2, T);

                X1 = ifgeq_subq(X1);
                T = montgomery_reduce_epi16(X2, W0);
                X2 = _mm256_add_epi16(X1, Q);
                X2 = _mm256_sub_epi16(X2, T);
                X1 = _mm256_add_epi16(X1, T);
                
                Y1 = ifgeq_subq(Y1);
                T = montgomery_reduce_epi16(Y2, W1);
                Y2 = _mm256_add_epi16(Y1, Q);
                Y2 = _mm256_sub_epi16(Y2, T);
                Y1 = _mm256_add_epi16(Y1, T);

                _mm256_storeu_si256((__m256i*)(pf + k * 16), X1);
                _mm256_storeu_si256((__m256i*)(pf + k * 16 + offset), X2);
                _mm256_storeu_si256((__m256i*)(pf + k * 16 + offset * 2), Y1);
                _mm256_storeu_si256((__m256i*)(pf + k * 16 + offset * 3), Y2);
            }
        }
        continue;


    Block32: //处理分块大小为32的情况,使用2个YMM,处理1个块
        for (int32 i = 0; i < Blocknum; i++) {
    
    
            int16* pf = f + i * Blocksize;
            __m256i X = _mm256_loadu_si256((__m256i*)(pf));
            __m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));

            int32 ind = Blocknum * NTT_NEG + i;
            int16 w = zetas_mont[bitrev_list[ind] >> 1];
            __m256i W = _mm256_set1_epi16(w);

            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);

            ind <<= 1;
            int16 w0 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w1 = zetas_mont[bitrev_list[ind + 1] >> 1];
            W = _mm256_setr_epi16(w0, w0, w0, w0, w0, w0, w0, w0, w1, w1, w1, w1, w1, w1, w1, w1);

            offset8(X, Y);            
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset8_inv(X, Y);

            _mm256_storeu_si256((__m256i*)(pf), X);
            _mm256_storeu_si256((__m256i*)(pf + 16), Y);
        }
        continue;


    Block16: //处理分块大小为16的情况,使用2个YMM,处理2个块
        for (int32 i = 0; i < Blocknum; i+=2) {
    
    
            int16* pf = f + i * Blocksize;
            __m256i X = _mm256_loadu_si256((__m256i*)(pf));
            __m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));

            int32 ind = Blocknum * NTT_NEG + i;
            int16 w0 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w1 = zetas_mont[bitrev_list[ind + 1] >> 1];
            __m256i W = _mm256_setr_epi16(w0, w0, w0, w0, w0, w0, w0, w0, w1, w1, w1, w1, w1, w1, w1, w1);

            offset8(X, Y);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset8_inv(X, Y);

            ind <<= 1;
            int16 w00 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w01 = zetas_mont[bitrev_list[ind + 1] >> 1];
            int16 w10 = zetas_mont[bitrev_list[ind + 2] >> 1];
            int16 w11 = zetas_mont[bitrev_list[ind + 3] >> 1];
            W = _mm256_setr_epi16(w00, w00, w00, w00, w01, w01, w01, w01, w10, w10, w10, w10, w11, w11, w11, w11);

            offset4(X, Y);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset4_inv(X, Y);

            _mm256_storeu_si256((__m256i*)(pf), X);
            _mm256_storeu_si256((__m256i*)(pf + 16), Y);
        }
        continue;


    Block8: //处理分块大小为8的情况,使用2个YMM,处理4个块
        for (int32 i = 0; i < Blocknum; i += 4) {
    
    
            int16* pf = f + i * Blocksize;
            __m256i X = _mm256_loadu_si256((__m256i*)(pf));
            __m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));

            int32 ind = Blocknum * NTT_NEG + i;
            int16 w0 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w1 = zetas_mont[bitrev_list[ind + 1] >> 1];
            int16 w2 = zetas_mont[bitrev_list[ind + 2] >> 1];
            int16 w3 = zetas_mont[bitrev_list[ind + 3] >> 1];
            __m256i W = _mm256_setr_epi16(w0, w0, w0, w0, w1, w1, w1, w1, w2, w2, w2, w2, w3, w3, w3, w3);

            offset4(X, Y);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset4_inv(X, Y);

            ind <<= 1;
            int16 w00 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w01 = zetas_mont[bitrev_list[ind + 1] >> 1];
            int16 w10 = zetas_mont[bitrev_list[ind + 2] >> 1];
            int16 w11 = zetas_mont[bitrev_list[ind + 3] >> 1];
            int16 w20 = zetas_mont[bitrev_list[ind + 4] >> 1];
            int16 w21 = zetas_mont[bitrev_list[ind + 5] >> 1];
            int16 w30 = zetas_mont[bitrev_list[ind + 6] >> 1];
            int16 w31 = zetas_mont[bitrev_list[ind + 7] >> 1];
            W = _mm256_setr_epi16(w00, w00, w01, w01, w10, w10, w11, w11, w20, w20, w21, w21, w30, w30, w31, w31);

            offset2(X, Y);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset2_inv(X, Y);
            
            _mm256_storeu_si256((__m256i*)(pf), X);
            _mm256_storeu_si256((__m256i*)(pf + 16), Y);
        }
        continue;


    Block4: //处理分块大小为4的情况,使用2个YMM,处理8个块
        for (int32 i = 0; i < Blocknum; i += 8) {
    
    
            int16* pf = f + i * Blocksize;
            __m256i X = _mm256_loadu_si256((__m256i*)(pf));
            __m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));

            int32 ind = Blocknum * NTT_NEG + i;
            int16 w0 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w1 = zetas_mont[bitrev_list[ind + 1] >> 1];
            int16 w2 = zetas_mont[bitrev_list[ind + 2] >> 1];
            int16 w3 = zetas_mont[bitrev_list[ind + 3] >> 1];
            int16 w4 = zetas_mont[bitrev_list[ind + 4] >> 1];
            int16 w5 = zetas_mont[bitrev_list[ind + 5] >> 1];
            int16 w6 = zetas_mont[bitrev_list[ind + 6] >> 1];
            int16 w7 = zetas_mont[bitrev_list[ind + 7] >> 1];
            __m256i W = _mm256_setr_epi16(w0, w0, w1, w1, w2, w2, w3, w3, w4, w4, w5, w5, w6, w6, w7, w7);

            offset2(X, Y);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset2_inv(X, Y);

            ind <<= 1;
            int16 w00 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w01 = zetas_mont[bitrev_list[ind + 1] >> 1];
            int16 w10 = zetas_mont[bitrev_list[ind + 2] >> 1];
            int16 w11 = zetas_mont[bitrev_list[ind + 3] >> 1];
            int16 w20 = zetas_mont[bitrev_list[ind + 4] >> 1];
            int16 w21 = zetas_mont[bitrev_list[ind + 5] >> 1];
            int16 w30 = zetas_mont[bitrev_list[ind + 6] >> 1];
            int16 w31 = zetas_mont[bitrev_list[ind + 7] >> 1];
            int16 w40 = zetas_mont[bitrev_list[ind + 8] >> 1];
            int16 w41 = zetas_mont[bitrev_list[ind + 9] >> 1];
            int16 w50 = zetas_mont[bitrev_list[ind + 10] >> 1];
            int16 w51 = zetas_mont[bitrev_list[ind + 11] >> 1];
            int16 w60 = zetas_mont[bitrev_list[ind + 12] >> 1];
            int16 w61 = zetas_mont[bitrev_list[ind + 13] >> 1];
            int16 w70 = zetas_mont[bitrev_list[ind + 14] >> 1];
            int16 w71 = zetas_mont[bitrev_list[ind + 15] >> 1];
            W = _mm256_setr_epi16(w00, w01, w10, w11, w20, w21, w30, w31, w40, w41, w50, w51, w60, w61, w70, w71);

            offset1(X, Y);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset1_inv(X, Y);

            _mm256_storeu_si256((__m256i*)(pf), X);
            _mm256_storeu_si256((__m256i*)(pf + 16), Y);
        }
        continue;

    Error: //捕获块大小错误
        printf("Blocksize isn't power of 2.\n");
    }

    for (int32 k = 0; k < NTT_N; k += 16) {
    
    
        __m256i X = _mm256_loadu_si256((__m256i*)(f + k));  //模约减,从[0,2q)约减到[0,q)
        X = ifgeq_subq(X);
        _mm256_storeu_si256((__m256i*)(f + k), X);
    }
}


void intt(int16* f, int8 mont) {
    /*
        In-place inverse NTT of f[0..NTT_N).

        f    : input in the NTT domain; overwritten with the result.
        mont : 0     -> final scaling by factor_mont;
               non-0 -> final scaling by factor_mont2, which additionally
                        multiplies by mont = R mod q, needed after the
                        Montgomery variant of nttmul (its products are a*b/R).

        Layers are merged two at a time (radix-4) starting from the smallest
        blocks; when NTT_ROUND is odd, one radix-2 layer over the whole array
        finishes the transform. Every twiddle is an inverse root, read as
        zetas_mont[NTT_ORDER - k].
    */
    int32 Blocknum = 1 << NTT_ROUND;
    int32 Blocksize = NTT_N >> NTT_ROUND;
    int32 Round = NTT_ROUND;
    // Q / Qtimes2 are not referenced directly below; NOTE(review): kept in case
    // helper macros (ifgeq_subq, montgomery_reduce_epi16) expect them in scope.
    int32 Qtimes2 = NTT_Q * 2;
    Blocksize <<= 2;    // each radix-4 step produces blocks 4x the base length
    Blocknum >>= 2;
    __m256i T, Q = _mm256_set1_epi16(NTT_Q);

    /*
        Radix-4 butterfly, as computed below (IW* are inverse twiddles):
            a = X1 + X2,   b = IW0*(X1 - X2)
            c = Y1 + Y2,   d = IW1*(Y1 - Y2)
            X1 = a + c,    Y1 = IW*(a - c)
            X2 = b + d,    Y2 = IW*(b - d)
        Every sum is passed through ifgeq_subq and every difference through
        montgomery_reduce_epi16. The original author's note describes this as
        Harvey-style lazy reduction with inputs/outputs kept in [0,2q).
    */
    for (; Round > 1; Round -= 2, Blocksize <<= 2, Blocknum >>= 2) {
        if (Blocksize >= 64)
            goto Block64;
        else
            switch (Blocksize)
            {
            case 32: goto Block32;
            case 16: goto Block16;
            case 8: goto Block8;
            case 4: goto Block4;
            default:
                goto Error; // this code only handles NTT_N being a power of 2
            }

    Block64: // block size divisible by 64: 4 YMM registers, 1 block per iteration
        for (int32 i = 0; i < Blocknum; i++) {
            int32 offset = Blocksize >> 2;
            int32 num = offset >> 4; // 16 coefficients per YMM register
            int16* pf = f + i * Blocksize;

            /*
                Author's derivation (j = 0 is the original array): in iteration j,
                block i of layer j-1 uses the root
                w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{2^{r-j}*brv_{j}(2i)},
                and therefore w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{brv_{r}(i)/2}.
                A layer with b blocks indexes its block i as b*NTT_NEG + i.
            */
            int32 ind = Blocknum * NTT_NEG + i;
            __m256i W = _mm256_set1_epi16(zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)]);            // layer Round, block i
            __m256i W0 = _mm256_set1_epi16(zetas_mont[NTT_ORDER - (bitrev_list[ind * 2] >> 1)]);       // layer Round+1, block 2i
            __m256i W1 = _mm256_set1_epi16(zetas_mont[NTT_ORDER - (bitrev_list[ind * 2 + 1] >> 1)]);   // layer Round+1, block 2i+1

            for (int32 k = 0; k < num; k++) {
                __m256i X1 = _mm256_loadu_si256((__m256i*)(pf + k * 16));
                __m256i X2 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset));
                __m256i Y1 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset * 2));
                __m256i Y2 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset * 3));

                T = _mm256_sub_epi16(X1, X2);
                X1 = _mm256_add_epi16(X1, X2);
                X2 = montgomery_reduce_epi16(T, W0);
                X1 = ifgeq_subq(X1);

                T = _mm256_sub_epi16(Y1, Y2);
                Y1 = _mm256_add_epi16(Y1, Y2);
                Y2 = montgomery_reduce_epi16(T, W1);
                Y1 = ifgeq_subq(Y1);

                T = _mm256_sub_epi16(X1, Y1);
                X1 = _mm256_add_epi16(X1, Y1);
                Y1 = montgomery_reduce_epi16(T, W);
                X1 = ifgeq_subq(X1);

                T = _mm256_sub_epi16(X2, Y2);
                X2 = _mm256_add_epi16(X2, Y2);
                Y2 = montgomery_reduce_epi16(T, W);
                X2 = ifgeq_subq(X2);

                _mm256_storeu_si256((__m256i*)(pf + k * 16), X1);
                _mm256_storeu_si256((__m256i*)(pf + k * 16 + offset), X2);
                _mm256_storeu_si256((__m256i*)(pf + k * 16 + offset * 2), Y1);
                _mm256_storeu_si256((__m256i*)(pf + k * 16 + offset * 3), Y2);
            }
        }
        continue;


    Block32: // block size 32: 2 YMM registers, 1 block per iteration
        for (int32 i = 0; i < Blocknum; i++) {
            int16* pf = f + i * Blocksize;
            __m256i X = _mm256_loadu_si256((__m256i*)(pf));
            __m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));

            // inner layer first (two half-blocks, indices 2i and 2i+1)
            int32 ind = (Blocknum * NTT_NEG + i) * 2;
            int16 w0 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
            int16 w1 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
            __m256i W = _mm256_setr_epi16(w0, w0, w0, w0, w0, w0, w0, w0, w1, w1, w1, w1, w1, w1, w1, w1);

            // offset8 / offset8_inv are helper macros defined elsewhere that
            // regroup lanes of X/Y around the butterfly — TODO confirm exact layout
            offset8(X, Y);
            T = _mm256_sub_epi16(X, Y);
            X = _mm256_add_epi16(X, Y);
            Y = montgomery_reduce_epi16(T, W);
            X = ifgeq_subq(X);
            offset8_inv(X, Y);

            // outer layer (block i)
            ind >>= 1;
            int16 w = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
            W = _mm256_set1_epi16(w);

            T = _mm256_sub_epi16(X, Y);
            X = _mm256_add_epi16(X, Y);
            Y = montgomery_reduce_epi16(T, W);
            X = ifgeq_subq(X);

            _mm256_storeu_si256((__m256i*)(pf), X);
            _mm256_storeu_si256((__m256i*)(pf + 16), Y);
        }
        continue;


    Block16: // block size 16: 2 YMM registers, 2 blocks per iteration
        for (int32 i = 0; i < Blocknum; i += 2) {
            int16* pf = f + i * Blocksize;
            __m256i X = _mm256_loadu_si256((__m256i*)(pf));
            __m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));

            int32 ind = (Blocknum * NTT_NEG + i) * 2;
            int16 w00 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
            int16 w01 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
            int16 w10 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];
            int16 w11 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];
            __m256i W = _mm256_setr_epi16(w00, w00, w00, w00, w01, w01, w01, w01, w10, w10, w10, w10, w11, w11, w11, w11);

            offset4(X, Y);
            T = _mm256_sub_epi16(X, Y);
            X = _mm256_add_epi16(X, Y);
            Y = montgomery_reduce_epi16(T, W);
            X = ifgeq_subq(X);
            offset4_inv(X, Y);

            ind >>= 1;
            int16 w0 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
            int16 w1 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
            W = _mm256_setr_epi16(w0, w0, w0, w0, w0, w0, w0, w0, w1, w1, w1, w1, w1, w1, w1, w1);

            offset8(X, Y);
            T = _mm256_sub_epi16(X, Y);
            X = _mm256_add_epi16(X, Y);
            Y = montgomery_reduce_epi16(T, W);
            X = ifgeq_subq(X);
            offset8_inv(X, Y);

            _mm256_storeu_si256((__m256i*)(pf), X);
            _mm256_storeu_si256((__m256i*)(pf + 16), Y);
        }
        continue;


    Block8: // block size 8: 2 YMM registers, 4 blocks per iteration
        for (int32 i = 0; i < Blocknum; i += 4) {
            int16* pf = f + i * Blocksize;
            __m256i X = _mm256_loadu_si256((__m256i*)(pf));
            __m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));

            int32 ind = (Blocknum * NTT_NEG + i) * 2;
            int16 w00 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
            int16 w01 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
            int16 w10 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];
            int16 w11 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];
            int16 w20 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 4] >> 1)];
            int16 w21 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 5] >> 1)];
            int16 w30 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 6] >> 1)];
            int16 w31 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 7] >> 1)];
            __m256i W = _mm256_setr_epi16(w00, w00, w01, w01, w10, w10, w11, w11, w20, w20, w21, w21, w30, w30, w31, w31);

            offset2(X, Y);
            T = _mm256_sub_epi16(X, Y);
            X = _mm256_add_epi16(X, Y);
            Y = montgomery_reduce_epi16(T, W);
            X = ifgeq_subq(X);
            offset2_inv(X, Y);

            ind >>= 1;
            int16 w0 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
            int16 w1 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
            int16 w2 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];
            int16 w3 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];
            W = _mm256_setr_epi16(w0, w0, w0, w0, w1, w1, w1, w1, w2, w2, w2, w2, w3, w3, w3, w3);

            offset4(X, Y);
            T = _mm256_sub_epi16(X, Y);
            X = _mm256_add_epi16(X, Y);
            Y = montgomery_reduce_epi16(T, W);
            X = ifgeq_subq(X);
            offset4_inv(X, Y);

            _mm256_storeu_si256((__m256i*)(pf), X);
            _mm256_storeu_si256((__m256i*)(pf + 16), Y);
        }
        continue;


    Block4: // block size 4: 2 YMM registers, 8 blocks per iteration
        for (int32 i = 0; i < Blocknum; i += 8) {
            int16* pf = f + i * Blocksize;
            __m256i X = _mm256_loadu_si256((__m256i*)(pf));
            __m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));

            int32 ind = (Blocknum * NTT_NEG + i) * 2;
            int16 w00 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
            int16 w01 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
            int16 w10 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];
            int16 w11 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];
            int16 w20 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 4] >> 1)];
            int16 w21 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 5] >> 1)];
            int16 w30 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 6] >> 1)];
            int16 w31 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 7] >> 1)];
            int16 w40 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 8] >> 1)];
            int16 w41 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 9] >> 1)];
            int16 w50 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 10] >> 1)];
            int16 w51 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 11] >> 1)];
            int16 w60 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 12] >> 1)];
            int16 w61 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 13] >> 1)];
            int16 w70 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 14] >> 1)];
            int16 w71 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 15] >> 1)];
            __m256i W = _mm256_setr_epi16(w00, w01, w10, w11, w20, w21, w30, w31, w40, w41, w50, w51, w60, w61, w70, w71);

            offset1(X, Y);
            T = _mm256_sub_epi16(X, Y);
            X = _mm256_add_epi16(X, Y);
            Y = montgomery_reduce_epi16(T, W);
            X = ifgeq_subq(X);
            offset1_inv(X, Y);

            ind >>= 1;
            int16 w0 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
            int16 w1 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
            int16 w2 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];
            int16 w3 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];
            int16 w4 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 4] >> 1)];
            int16 w5 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 5] >> 1)];
            int16 w6 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 6] >> 1)];
            int16 w7 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 7] >> 1)];
            W = _mm256_setr_epi16(w0, w0, w1, w1, w2, w2, w3, w3, w4, w4, w5, w5, w6, w6, w7, w7);

            offset2(X, Y);
            T = _mm256_sub_epi16(X, Y);
            X = _mm256_add_epi16(X, Y);
            Y = montgomery_reduce_epi16(T, W);
            X = ifgeq_subq(X);
            offset2_inv(X, Y);

            _mm256_storeu_si256((__m256i*)(pf), X);
            _mm256_storeu_si256((__m256i*)(pf + 16), Y);
        }
        continue;

    Error: // unsupported block size
        printf("Blocksize isn't power of 2.\n");
    }

    /*
        Radix-2 tail, only when NTT_ROUND is odd: one layer over the whole array.
            X = X + Y
            Y = IW*(X - Y)
        The radix-4 loop has already advanced Blocksize/Blocknum one step past
        the last layer (here Blocksize == 2*NTT_N and Blocknum == 0), so they
        must not be used for this layer: deriving the half-length from them
        addressed f[NTT_N..2*NTT_N) (out of bounds), and Blocknum*NTT_NEG picked
        the wrong root for the negacyclic case. The half-length is NTT_N/2, and
        the single output block has index 1*NTT_NEG + 0 in the same
        (blocks-in-layer * NTT_NEG + block) convention used for `ind` above.
    */
    if ((NTT_ROUND & 1) == 1) {
        int32 offset = NTT_N >> 1;
        __m256i W = _mm256_set1_epi16(zetas_mont[NTT_ORDER - (bitrev_list[NTT_NEG] >> 1)]);

        for (int32 k = 0; k < offset; k += 16) {
            __m256i X = _mm256_loadu_si256((__m256i*)(f + k));
            __m256i Y = _mm256_loadu_si256((__m256i*)(f + k + offset));

            T = _mm256_sub_epi16(X, Y);
            X = _mm256_add_epi16(X, Y);
            Y = montgomery_reduce_epi16(T, W);
            X = ifgeq_subq(X);  // reduce the sum, as every radix-4 butterfly does

            _mm256_storeu_si256((__m256i*)(f + k), X);
            _mm256_storeu_si256((__m256i*)(f + k + offset), Y);
        }
    }

    // Inverse-transform scaling factor.
    __m256i F = _mm256_set1_epi16(factor_mont);
    if(mont != 0)
        F = _mm256_set1_epi16(factor_mont2); // Montgomery nttmul was used: also multiply by mont = R mod q

    for (int32 k = 0; k < NTT_N; k += 16) {
        __m256i X = _mm256_loadu_si256((__m256i*)(f + k));
        X = montgomery_reduce_epi16(X, F);
        _mm256_storeu_si256((__m256i*)(f + k), X);
    }
}

/*
    Multiply two small polynomials of NTT_BASELEN coefficients with Montgomery
    reduction: r = a*b mod (x^NTT_BASELEN - zeta), with an extra R^{-1} factor
    (r = a*b/R).

    Coefficient i collects b[j]*a[i-j] for j <= i, plus the wrap-around terms
    b[j]*a[NTT_BASELEN + i - j] for j > i. Those terms come from degree
    i + NTT_BASELEN and are multiplied by zeta. This assumes zeta_mont is zeta
    in Montgomery form, so montgomery_reduce(b[j]*zeta_mont) == zeta*b[j] mod q.

    `static inline`: a plain C99 `inline` definition provides no external
    definition, so any call the compiler chose not to inline (e.g. at -O0)
    failed to link.
*/
static inline void basemul_mont(int16* r, const int16* a, const int16* b, int16 zeta_mont)
{
    for (int32 i = 0; i < NTT_BASELEN; i++)
    {
        int32 res = 0;                // wider accumulator: defer the modular reduction
        int32 s = NTT_BASELEN + i;    // a-index offset for the wrap-around terms
        for (int32 j = 0; j <= i; j++)
            res += b[j] * a[i - j];
        for (int32 j = i + 1; j < NTT_BASELEN; j++)
            res += montgomery_reduce(b[j] * zeta_mont) * a[s - j];
        r[i] = montgomery_reduce(res);  // result is r = a*b/R
    }
}

void nttmul_mont(int16* r, const int16* a, const int16* b)
{
    /*
        Point-wise NTT-domain product using Montgomery reduction.
        Each result carries an extra R^{-1} (r = a*b/R); intt(..., mont != 0)
        scales by factor_mont2 to compensate.

        r, a, b : NTT_N coefficients each, seen as 2^{NTT_ROUND} small
                  polynomials of NTT_BASELEN coefficients.
    */
    // 2^{r-1} small polynomials of length n/2^{r-1}, where NTT_ROUND = r-1
    int32 num = 1 << NTT_ROUND;

#if (NTT_BASELEN == 1)    // AVX2 implementation: plain lane-wise products
    /*
        _mm256_loadu_si256/_mm256_storeu_si256 take __m256i pointers. Passing
        int16 pointers is an incompatible-pointer-type conversion (a hard
        error by default since GCC 14), so cast explicitly like the rest of
        the file does.
    */
    for (int32 i = 0; i < num; i += 16) {
        __m256i X = _mm256_loadu_si256((const __m256i*)(a + i));
        __m256i Y = _mm256_loadu_si256((const __m256i*)(b + i));

        X = montgomery_reduce_epi16(X, Y);
        _mm256_storeu_si256((__m256i*)(r + i), X);
    }

#elif (NTT_BASELEN == 2)    // scalar implementation
    for (int32 i = 0; i < num; i++) {
        // Root used by polynomial 2^{r-1}+i of layer r:
        // w_{2^r}^{brv_r(2^{r-1}+i)}, where NTT_ROUND = r-1
        int32 zeta = zetas_mont[bitrev_list[num * NTT_NEG + i]];

        int32 tmp0 = a[0] * b[0] + montgomery_reduce(zeta * a[1]) * b[1];
        int32 tmp1 = a[0] * b[1] + a[1] * b[0];
        r[0] = montgomery_reduce(tmp0);
        r[1] = montgomery_reduce(tmp1);

        r += NTT_BASELEN;
        a += NTT_BASELEN;
        b += NTT_BASELEN;
    }

#else
    for (int32 i = 0; i < num; i++) {
        int32 zeta = zetas_mont[bitrev_list[num * NTT_NEG + i]];
        basemul_mont(r, a, b, zeta); // scalar implementation

        r += NTT_BASELEN;
        a += NTT_BASELEN;
        b += NTT_BASELEN;
    }

#endif
}

/*
    Multiply two small polynomials of NTT_BASELEN coefficients with Barrett
    reduction: r = a*b mod (x^NTT_BASELEN - zeta).

    Coefficient i collects b[j]*a[i-j] for j <= i, plus zeta times the
    wrap-around terms b[j]*a[NTT_BASELEN + i - j] for j > i (degree
    i + NTT_BASELEN folds back onto degree i).

    `static inline`: a plain C99 `inline` definition provides no external
    definition, so any call the compiler chose not to inline (e.g. at -O0)
    failed to link.
*/
static inline void basemul(int16* r, const int16* a, const int16* b, int16 zeta)
{
    for (int32 i = 0; i < NTT_BASELEN; i++)
    {
        int32 res = 0;                // wider accumulator: defer the modular reduction
        int32 s = NTT_BASELEN + i;    // a-index offset for the wrap-around terms
        for (int32 j = 0; j <= i; j++)
            res += b[j] * a[i - j];
        for (int32 j = i + 1; j < NTT_BASELEN; j++)
            res += zeta * barrett_reduce(b[j] * a[s - j]);
        r[i] = barrett_reduce(res);
    }
}

void nttmul(int16* r, const int16* a, const int16* b, int8 mont)
{
    /*
        Point-wise multiplication of two NTT-domain polynomials into r.

        The NTT_N coefficients form 2^{NTT_ROUND} small polynomials of
        NTT_BASELEN coefficients. Block i is multiplied modulo
        x^{NTT_BASELEN} - zeta_i, where
        zeta_i = zetas[bitrev_list[2^{NTT_ROUND} * NTT_NEG + i]].
        For NTT_BASELEN == 1 this is a plain coefficient-wise product.

        mont == 0 : Barrett-reduced products, computed here.
        mont != 0 : delegated to nttmul_mont, whose output carries an extra
                    R^{-1} that intt(..., mont != 0) compensates for.
    */
    if (mont != 0) {
        nttmul_mont(r, a, b);
        return;
    }

    // 2^{r-1} small polynomials of length n/2^{r-1}, where NTT_ROUND = r-1
    const int32 blocks = 1 << NTT_ROUND;
    for (int32 blk = 0; blk < blocks; blk++, r += NTT_BASELEN, a += NTT_BASELEN, b += NTT_BASELEN)
    {
#if (NTT_BASELEN == 1)
        const int32 prod = a[0] * b[0];
        *r = barrett_reduce(prod);
#else
        // Root of polynomial 2^{r-1}+blk in layer r: w_{2^r}^{brv_r(2^{r-1}+blk)}
        const int32 zeta = zetas[bitrev_list[blocks * NTT_NEG + blk]];
#if (NTT_BASELEN == 2)
        const int32 even = a[0] * b[0] + zeta * barrett_reduce(a[1] * b[1]);
        const int32 odd = a[0] * b[1] + a[1] * b[0];
        r[0] = barrett_reduce(even);
        r[1] = barrett_reduce(odd);
#else
        basemul(r, a, b, zeta);
#endif
#endif
    }
}

Test

cputimer.h

#ifndef CPUTIMER
#define CPUTIMER

#if defined(__linux__)
// Linux系统
#include <unistd.h>
#elif defined(_WIN32)
// Windows系统
#include <intrin.h>
#include <windows.h>
#endif

/*
    Block the calling thread for about `time` milliseconds.
    Linux: usleep (argument in microseconds); Windows: Sleep (milliseconds);
    any other platform: does nothing.
*/
void sleepms(int time) {
#if defined(__linux__)
    const int micros = time * 1000;   /* usleep() counts microseconds */
    usleep(micros);
#elif defined(_WIN32)
    Sleep(time);
#endif
}

/*
    Read the processor time-stamp counter (x86 RDTSC instruction).

    Windows: the __rdtsc() intrinsic. MSVC has no inline assembly on x64, and
    32-bit-only forms such as `__asm RDTSC` (or emitting bytes 0x0F 0x31) are
    equivalent alternatives there.
    Elsewhere: GNU inline assembly. RDTSC leaves the low 32 bits in EAX and the
    high 32 bits in EDX, which are combined into one 64-bit count.

    NOTE(review): this function was previously annotated "Needs echo 2 >
    /sys/devices/cpu/rdpmc". That sysfs setting governs RDPMC, which this code
    does not execute; confirm whether it is still required.
*/
unsigned long long cputimer() {
#if defined(_WIN32)
    return __rdtsc();
#else
    unsigned int tsc_low, tsc_high;
    __asm__ volatile ("rdtsc" : "=a" (tsc_low), "=d" (tsc_high));
    unsigned long long ticks = tsc_high;
    ticks <<= 32;
    ticks |= tsc_low;
    return ticks;
#endif
}

/*
    Equivalent standalone assembly (x64), for toolchains without inline asm;
    declare it in C as `unsigned long long cputimer();`

    align 16

    _cputimer:
    rdtsc
    shl rdx, 32
    or rax, rdx
    ret
*/

/* Last measured time-stamp-counter rate (ticks per second). It is set by
   GetFrequency() and used to convert cycle counts into time. It is zero until
   GetFrequency() has run. */
unsigned long long CPUFrequency;

/* Estimate the CPU clock by counting cputimer() ticks across a 1000 ms sleep.
   The estimate is stored in CPUFrequency and also returned. */
unsigned long long GetFrequency() {
    const unsigned long long begin_ticks = cputimer();
    sleepms(1000);
    const unsigned long long end_ticks = cputimer();
    CPUFrequency = end_ticks - begin_ticks;
    return CPUFrequency;
}

#define pn printf("\n\n")

unsigned long long TM_start, TM_end;
/*
    Timer(code): time one execution of `code` with cputimer(), then print the
    elapsed cycles and the equivalent seconds (using CPUFrequency).
    - The expansion is braced so that `if (c) Timer(x)` times and prints as a
      single statement. Plain braces are used instead of do/while(0) so that
      both `Timer(x)` and `Timer(x);` remain valid call forms.
    - %llu: the difference is unsigned long long (the former %lld mismatched it).
*/
#define Timer(code) { TM_start = cputimer(); code; TM_end = cputimer(); \
    printf("time = %llu cycles (%f s)\n", TM_end - TM_start, (double)(TM_end - TM_start)/CPUFrequency); }

unsigned long long TM_mem[10000];
/*
    Loop(loop, code): run `code` `loop` times, recording each run's cycle count
    in TM_mem[i], then print statistics via Analyis_TM.
    - The loop counter is deliberately named `i`, so `code` may refer to it.
    - The run count is capped at the capacity of TM_mem; previously a larger
      count wrote past the end of the array.
    - Nothing is analysed for a non-positive count, because Analyis_TM would
      read TM_mem[-1] and divide by zero.
    - This also restores the line continuation after the for-loop's opening
      brace, without which the macro did not compile.
*/
#define Loop(loop, code) { \
    int TM_runs_ = (loop); \
    if (TM_runs_ > (int)(sizeof TM_mem / sizeof TM_mem[0])) \
        TM_runs_ = (int)(sizeof TM_mem / sizeof TM_mem[0]); \
    for (int i = 0; i < TM_runs_; i++) { \
        TM_start = cputimer(); code; TM_end = cputimer(); TM_mem[i] = TM_end - TM_start; \
    } \
    if (TM_runs_ > 0) \
        Analyis_TM(TM_runs_); \
}



void __quick_sort(unsigned long long* arr, int begin, int end) // quicksort of arr[begin..end]
{
    /*
        Sort arr[begin..end] (inclusive bounds) in ascending order, in place.
        An empty or single-element range (begin >= end) is left untouched.

        Three-way partition around the middle element:
            arr[begin..lt-1] <  pivot
            arr[lt..gt]      == pivot
            arr[gt+1..end]   >  pivot

        The previous version inserted each smaller element by shifting the
        whole run arr[k..i-1] right by one, which made every partition pass
        quadratic. Swapping keeps a pass linear. Grouping equal keys avoids
        degenerate splits on timing samples with many repeated values.
        Recursing only into the smaller side (and looping on the larger one)
        bounds the stack depth to O(log n).
    */
    while (begin < end) {
        int mid = begin + (end - begin) / 2;
        unsigned long long pivot = arr[mid];
        int lt = begin, i = begin, gt = end;

        while (i <= gt) {
            unsigned long long v = arr[i];
            if (v < pivot) {
                arr[i] = arr[lt];
                arr[lt] = v;
                lt++;
                i++;
            } else if (v > pivot) {
                arr[i] = arr[gt];
                arr[gt] = v;
                gt--;
            } else {
                i++;
            }
        }

        if (lt - begin < end - gt) {
            __quick_sort(arr, begin, lt - 1);
            begin = gt + 1;
        } else {
            __quick_sort(arr, gt + 1, end);
            end = lt - 1;
        }
    }
}

/*
    Sort the first `size` entries of arr in ascending order, in place.
    Thin wrapper over __quick_sort, which takes inclusive index bounds;
    size <= 1 leaves arr unchanged.
*/
void quick_sort(unsigned long long* arr, int size)
{
    const int last = size - 1;
    __quick_sort(arr, 0, last);
}

void Analyis_TM(int loop) // analyse code performance
{
    /*
        Sort the first `loop` samples of TM_mem (in place) and print their
        minimum, maximum, median and average, in cycles and in milliseconds
        (converted with CPUFrequency).

        - loop <= 0 prints "Time: no samples" and returns. Previously this
          read TM_mem[-1] and divided by zero.
        - loop is capped at the capacity of TM_mem. Previously a larger value
          read past the end of the array.
        - The counts are printed with %llu to match their unsigned long long
          type (previously %lld).
    */
    const int capacity = (int)(sizeof TM_mem / sizeof TM_mem[0]);
    if (loop <= 0) {
        printf("Time: no samples\n");
        return;
    }
    if (loop > capacity)
        loop = capacity;

    unsigned long long min, max, med, aver = 0;
    quick_sort(TM_mem, loop);
    min = TM_mem[0];
    max = TM_mem[loop - 1];
    med = TM_mem[loop >> 1];
    for (int i = 0; i < loop; i++) {
        aver += TM_mem[i];
    }
    aver /= loop;
    printf("Time:\n\tMinimum\t%10llu cycles,%10.6f ms\n\tMaximum\t%10llu cycles,%10.6f ms\n\tMedian\t%10llu cycles,%10.6f ms\n\tAverage\t%10llu cycles,%10.6f ms\n",
        min, (double)min / CPUFrequency * 1000, max, (double)max / CPUFrequency * 1000, med, (double)med / CPUFrequency * 1000, aver, (double)aver / CPUFrequency * 1000);
}


#endif

Result

WSL 下用 gcc 编译:

gcc ntt_avx2.c test_ntt_avx2.c -o test_ntt_avx2 -O3 -fopt-info-vec-optimized -mavx2

执行 ./test_ntt_avx2 的结果为:

CPU Frequency = 2918844449

Time: ntt 
        Minimum       2588 cycles,  0.000887 ms
        Maximum      39908 cycles,  0.013672 ms
        Median        2636 cycles,  0.000903 ms
        Average       2881 cycles,  0.000987 ms

Time: intt
        Minimum       2546 cycles,  0.000872 ms
        Maximum      29334 cycles,  0.010050 ms
        Median        2586 cycles,  0.000886 ms
        Average       2650 cycles,  0.000908 ms

Time: nttmul
        Minimum       1828 cycles,  0.000626 ms
        Maximum      91908 cycles,  0.031488 ms
        Median        1838 cycles,  0.000630 ms
        Average       1865 cycles,  0.000639 ms

Time: nttmul_mont
        Minimum        164 cycles,  0.000056 ms
        Maximum       7684 cycles,  0.002632 ms
        Median         186 cycles,  0.000064 ms
        Average        199 cycles,  0.000068 ms

猜你喜欢

转载自blog.csdn.net/weixin_44885334/article/details/130298035
NTT
今日推荐