The Fast Fourier Transform and Its Inverse ((I)FFT) in Halo

1. FFT Background

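For background on the FFT, see the blog post "A Very Concise and Easy-to-Understand FFT (Fast Fourier Transform)" (十分简明易懂的FFT(快速傅里叶变换)).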

2. FFT Implementation in Halo

Run on a 4-core, 8 GB RAM Ubuntu 16.04 server:

cargo test test_fft -- --nocapture

The test_fft function multiplies two polynomials of degree 999 (1000 coefficients each). It computes the product twice: directly, with naive_product, and via the FFT, with multiply_polynomials.

2.1 Padding the coefficient lists of a and b to length $2^{exp}$

multiply_polynomials first rounds the number of coefficients of the product up to a power of two, $2^{exp}$, and then zero-pads the coefficient lists of a and b to that length:

    let degree_of_result = (a.len() - 1) + (b.len() - 1); // 1998
    let coeffs_of_result = degree_of_result + 1; // 1999

    // Compute the size of our evaluation domain
    let mut m = 1; // ends up as 2048
    let mut exp = 0; // ends up as 11
    while m < coeffs_of_result {
        m *= 2;
        exp += 1;

        // The pairing-friendly curve may not be able to support
        // large enough (radix2) evaluation domains.
        if exp >= F::S {
            panic!("polynomial too large");
        }
    }

    // Extend the vectors with zeroes: pad `a` and `b` to length m = 2^exp
    a.resize(m, F::zero());
    b.resize(m, F::zero());
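The sizing rule above can be isolated into a small helper. The sketch below is hypothetical (domain_size is not a Halo function): it takes the field's 2-adicity as a plain parameter s standing in for F::S, and returns None where the original code panics.

```rust
/// Hypothetical helper mirroring the sizing loop in multiply_polynomials:
/// returns (m, exp) with m = 2^exp the smallest power of two >= the number
/// of coefficients of the product, or None where the original panics.
/// `s` plays the role of F::S (assumption: passed in as a plain number).
fn domain_size(len_a: usize, len_b: usize, s: u32) -> Option<(usize, u32)> {
    let degree_of_result = (len_a - 1) + (len_b - 1);
    let coeffs_of_result = degree_of_result + 1;
    let (mut m, mut exp) = (1usize, 0u32);
    while m < coeffs_of_result {
        m *= 2;
        exp += 1;
        // Same bound as the `exp >= F::S` check in the original loop
        if exp >= s {
            return None;
        }
    }
    Some((m, exp))
}

fn main() {
    // Two degree-999 polynomials, as in test_fft, with S = 32
    println!("{:?}", domain_size(1000, 1000, 32)); // Some((2048, 11))
}
```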

2.2 Obtaining the $2^{exp}$-th primitive root of unity

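F::ALPHA is a $2^{32}$-th primitive root of unity (here F::S = 32). Squaring a primitive $2^k$-th root of unity yields a primitive $2^{k-1}$-th root, so squaring F::ALPHA $S-exp$ times gives the required $2^{exp}$-th primitive root of unity: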

    // Compute alpha, the 2^exp primitive root of unity
    let mut alpha = F::ALPHA;
    for _ in exp..F::S {
        alpha = alpha.square();
    }
    // alpha is now a 2^exp-th primitive root of unity
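To see that repeated squaring really produces a primitive root of the right order, here is a sketch over a toy field. The modulus 998244353 = 119·2^23 + 1 and the stand-in ALPHA = 3^119 are assumptions replacing Halo's field and F::ALPHA (3 generates the multiplicative group of this field, so 3^119 has order exactly 2^23, i.e. S = 23 here); root_of_unity and is_primitive_2exp_root are hypothetical helpers.

```rust
// Toy field (assumption): P = 998244353 = 119 * 2^23 + 1, so S = 23
const P: u64 = 998_244_353;
const S: u32 = 23;

fn mul(a: u64, b: u64) -> u64 {
    a * b % P // both operands < 2^30, so the product fits in u64
}

fn pow(mut base: u64, mut e: u64) -> u64 {
    let mut acc = 1;
    base %= P;
    while e > 0 {
        if e & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        e >>= 1;
    }
    acc
}

/// Same loop as in Halo: square ALPHA (S - exp) times.
fn root_of_unity(exp: u32) -> u64 {
    let mut alpha = pow(3, 119); // stand-in for F::ALPHA, order 2^23
    for _ in exp..S {
        alpha = mul(alpha, alpha);
    }
    alpha
}

/// For a power-of-two order: x is a primitive 2^exp-th root of unity
/// iff x^(2^exp) = 1 and x^(2^(exp-1)) = -1.
fn is_primitive_2exp_root(x: u64, exp: u32) -> bool {
    pow(x, 1u64 << exp) == 1 && (exp == 0 || pow(x, 1u64 << (exp - 1)) == P - 1)
}

fn main() {
    let alpha = root_of_unity(11);
    println!("primitive 2^11-th root: {alpha}, check = {}", is_primitive_2exp_root(alpha, 11));
}
```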

2.3 Applying the FFT to the coefficient lists of a and b

    // alpha is the 2^exp-th primitive root of unity; exp = 11
    best_fft(&mut a, alpha, exp);
    best_fft(&mut b, alpha, exp);

Note that after best_fft(&mut a, alpha, exp); the array $a$ no longer holds the coefficients of $A(x)=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}$ (with $n=2^{exp}$ and $w_n$ = alpha), but the values of $A(x)$ at $x=w_n^0,w_n^1,w_n^2,\dots,w_n^{n-1}$, in that order: $a=[A(w_n^0),A(w_n^1),A(w_n^2),\dots,A(w_n^{n-1})]$.
In other words, best_fft converts a polynomial from its coefficient representation to its point-value representation $(w_n^0,A(w_n^0)),\dots,(w_n^{n-1},A(w_n^{n-1}))$.
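As a reference for what best_fft returns, the following O(n²) sketch evaluates A(x) at every power of omega directly. naive_eval is a hypothetical helper, and the tiny field F_17 with n = 4 and omega = 4 (a primitive 4th root of unity, since 4² ≡ −1 mod 17) is chosen only for illustration.

```rust
/// Naive O(n^2) evaluation: out[k] = A(omega^k) mod p for k = 0..n-1,
/// i.e. the array best_fft leaves in `a`.
fn naive_eval(coeffs: &[u64], omega: u64, p: u64) -> Vec<u64> {
    let n = coeffs.len();
    let mut out = Vec::with_capacity(n);
    let mut x = 1; // x = omega^k
    for _ in 0..n {
        // Horner's rule for A(x)
        let y = coeffs.iter().rev().fold(0, |acc, &c| (acc * x + c) % p);
        out.push(y);
        x = x * omega % p;
    }
    out
}

fn main() {
    // A(x) = 1 + 2x + 3x^2 + 4x^3 at x = 1, 4, 16, 13 (the powers of 4 mod 17)
    println!("{:?}", naive_eval(&[1, 2, 3, 4], 4, 17)); // [10, 7, 15, 6]
}
```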

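best_fft compares exp (passed in as log_n) with $\log_2$ of the number of CPU cores to decide whether to call the serial serial_fft or the parallel parallel_fft: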

fn best_fft<F: Field>(a: &mut [F], omega: F, log_n: u32) {
    let cpus = num_cpus::get(); //4
    let log_cpus = log2_floor(cpus); //2

    if log_n <= log_cpus {
        serial_fft(a, omega, log_n);
    } else { //11>2
        parallel_fft(a, omega, log_n, log_cpus);
    }
}
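The dispatch rule can be sketched in isolation. This sketch assumes std::thread::available_parallelism as a stand-in for the num_cpus crate, and its log2_floor is a hypothetical reimplementation, not Halo's helper. Because log2_floor rounds down, a 6-core machine would still be treated as 2^2 = 4 cores.

```rust
/// floor(log2(x)) for x >= 1 (hypothetical stand-in for Halo's log2_floor).
fn log2_floor(x: usize) -> u32 {
    assert!(x > 0);
    usize::BITS - 1 - x.leading_zeros()
}

/// Mirrors the branch in best_fft: go parallel only when log_n > log_cpus.
fn use_parallel(log_n: u32, cpus: usize) -> bool {
    log_n > log2_floor(cpus)
}

fn main() {
    let cpus = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(1);
    println!("cpus = {cpus}, parallel for log_n = 11: {}", use_parallel(11, cpus));
}
```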

2.3.1 Parallel FFT: parallel_fft

// omega is a 2^exp-th primitive root of unity; here exp = 11, log_n = 11, log_cpus = 2
fn parallel_fft<F: Field>(a: &mut [F], omega: F, log_n: u32, log_cpus: u32) {
    assert!(log_n >= log_cpus);

    let num_cpus = 1 << log_cpus; // 4
    let log_new_n = log_n - log_cpus; // 11 - 2 = 9
    let mut tmp = vec![vec![F::zero(); 1 << log_new_n]; num_cpus]; // 2^2 x 2^9 matrix, one row per CPU core
    let new_omega = omega.pow(&[num_cpus as u64, 0, 0, 0]); // omega^num_cpus: a 2^(log_n - log_cpus) = 2^9-th primitive root of unity

    thread::scope(|scope| {
        let a = &*a;
        // 1) tmp is a 2^2 x 2^9 matrix; each row is filled by its own thread, all rows in parallel
        for (j, tmp) in tmp.iter_mut().enumerate() {
            scope.spawn(move |_| {
                // Shuffle into a sub-FFT
                let omega_j = omega.pow(&[j as u64, 0, 0, 0]); // omega^j
                let omega_step = omega.pow(&[(j as u64) << log_new_n, 0, 0, 0]); // omega^(j * 2^9)

                let mut elt = F::one();
                for i in 0..(1 << log_new_n) { // 2) iterate over the columns i of this row
                    for s in 0..num_cpus { // 3) accumulate over the num_cpus blocks of a
                        let idx = (i + (s << log_new_n)) % (1 << log_n);
                        // idx selects a coefficient of `a`: as i varies, s = 0 covers 0..=511,
                        // s = 1 covers 512..=1023, s = 2 covers 1024..=1535, s = 3 covers 1536..=2047
                        let mut t = a[idx];
                        t *= elt;
                        tmp[i] += t;
                        elt *= omega_step;
                    }
                    elt *= omega_j;
                }

                // Perform sub-FFT
                serial_fft(tmp, new_omega, log_new_n);
            });
        }
    })
    .unwrap();

    // Unshuffle
    let mask = (1 << log_cpus) - 1;
    for (idx, a) in a.iter_mut().enumerate() {
        *a = tmp[idx & mask][idx >> log_cpus];
    }
}

The essence of the FFT is converting a polynomial from its coefficient representation to its point-value representation. Take the polynomial
$$y=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}$$
where $n=2^{exp}$ (in the example above $exp=11$, $n=2048$) and omega, written $w_n^1$, is a primitive $n$-th root of unity (so $(w_n^1)^n=1$).
The FFT turns it into $n$ distinct point-value pairs $(x_0,y_0),(x_1,y_1),(x_2,y_2),\dots,(x_{n-1},y_{n-1})$ with $x_k=(w_n^1)^k=w_n^k$. In parallel_fft above, the tmp matrix ends up holding all of these $y$ values in a shuffled order, and the unshuffle step copies them back so that a[k] $=y_k$.

Taking a 4-core CPU ($p=4$) as an example, the degree-2047 polynomial ($n=2048$, $n/p=512$) can be split into four blocks of 512 coefficients, which lets the work be divided among four threads:
$$\begin{aligned}A(x)=\ &a_0+a_1x+a_2x^2+\dots+a_{511}x^{511}\\&+x^{512}(a_{512}+a_{513}x+a_{514}x^2+\dots+a_{1023}x^{511})\\&+x^{1024}(a_{1024}+a_{1025}x+a_{1026}x^2+\dots+a_{1535}x^{511})\\&+x^{1536}(a_{1536}+a_{1537}x+a_{1538}x^2+\dots+a_{2047}x^{511})\end{aligned}$$

Regrouping the terms column by column gives
$$A(x)=C_0(x^{512})+xC_1(x^{512})+x^2C_2(x^{512})+\dots+x^{511}C_{511}(x^{512})$$
where
$$C_0(x)=a_0+a_{512}x+a_{1024}x^2+a_{1536}x^3$$
$$C_1(x)=a_1+a_{513}x+a_{1025}x^2+a_{1537}x^3$$
$$\dots$$
$$C_{511}(x)=a_{511}+a_{1023}x+a_{1535}x^2+a_{2047}x^3$$
and in general $C_i(x)=a_i+a_{i+512}x+a_{i+1024}x^2+a_{i+1536}x^3$.

Now substitute $x=w_n^k=w_{2048}^k$ into $A(x)$, for any $0\le k<n$. Since $w_{2048}^{512}=w_4$:
$$\begin{aligned}A(w_{2048}^k)&=C_0(w_{2048}^{512k})+w_{2048}^kC_1(w_{2048}^{512k})+\dots+w_{2048}^{511k}C_{511}(w_{2048}^{512k})\\&=C_0(w_4^k)+w_{2048}^kC_1(w_4^k)+\dots+w_{2048}^{511k}C_{511}(w_4^k)\end{aligned}$$
Because $w_4^k$ depends only on $k \bmod 4$, it suffices to compute $C_0(w_4^j),C_1(w_4^j),\dots,C_{511}(w_4^j)$ for $j=0,1,2,3$; from these values, $A(w_n^k)$ can be obtained for every $k<2048$.
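The identity above can be checked numerically on a small case. This sketch assumes the field F_17 with n = 16 and omega = 3 (3 generates the multiplicative group mod 17, so it is a primitive 16th root of unity), splitting the coefficients into `blocks` groups the way 2048 is split into 4 × 512; decomposition_holds is a hypothetical helper and the coefficients are arbitrary test data.

```rust
// Toy field F_17 (assumption); 3 is a primitive 16th root of unity mod 17.
const P: u64 = 17;

fn pow(mut b: u64, mut e: u64) -> u64 {
    let mut acc = 1;
    b %= P;
    while e > 0 {
        if e & 1 == 1 {
            acc = acc * b % P;
        }
        b = b * b % P;
        e >>= 1;
    }
    acc
}

/// Evaluate a polynomial (coefficients in increasing degree) at x, by Horner's rule.
fn eval(c: &[u64], x: u64) -> u64 {
    c.iter().rev().fold(0, |acc, &ci| (acc * x + ci) % P)
}

/// Checks A(w_n^k) = sum_i w_n^(ik) * C_i(w_blocks^k) for every k < n, where
/// C_i(x) = a_i + a_{i+cols} x + a_{i+2cols} x^2 + ... and cols = n / blocks.
fn decomposition_holds(n: usize, blocks: usize, omega: u64) -> bool {
    let cols = n / blocks; // 512 in the text
    let w_blocks = pow(omega, cols as u64); // w_4 in the text
    let a: Vec<u64> = (0..n as u64).map(|i| (7 * i + 1) % P).collect();
    let c: Vec<Vec<u64>> = (0..cols)
        .map(|i| (0..blocks).map(|s| a[i + s * cols]).collect())
        .collect();
    (0..n as u64).all(|k| {
        let lhs = eval(&a, pow(omega, k));
        let rhs = (0..cols).fold(0, |acc, i| {
            (acc + pow(omega, i as u64 * k) * eval(&c[i], pow(w_blocks, k))) % P
        });
        lhs == rhs
    })
}

fn main() {
    // n = 16 split into 4 blocks of 4 columns, mirroring 2048 = 4 x 512
    println!("identity holds: {}", decomposition_holds(16, 4, 3));
}
```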

In the Halo code, $w_n^1$ corresponds to omega, and, with the thread index j playing the role of $k$, $w_4^j$ corresponds to omega_step (let omega_step = omega.pow(&[(j as u64) << log_new_n, 0, 0, 0]);, i.e. power_mod(omega, j*2^9, p)) and $w_n^j$ corresponds to omega_j.

In the Halo code, tmp is a $2^2\times 2^9$ matrix whose row j is filled by thread j. Before the sub-FFT its rows are:
$$tmp[0]=[C_0(w_4^0),\ w_{2048}^{0\cdot 1}C_1(w_4^0),\ \dots,\ w_{2048}^{0\cdot 511}C_{511}(w_4^0)]$$
$$tmp[1]=[C_0(w_4^1),\ w_{2048}^{1\cdot 1}C_1(w_4^1),\ \dots,\ w_{2048}^{1\cdot 511}C_{511}(w_4^1)]$$
$$tmp[2]=[C_0(w_4^2),\ w_{2048}^{2\cdot 1}C_1(w_4^2),\ \dots,\ w_{2048}^{2\cdot 511}C_{511}(w_4^2)]$$
$$tmp[3]=[C_0(w_4^3),\ w_{2048}^{3\cdot 1}C_1(w_4^3),\ \dots,\ w_{2048}^{3\cdot 511}C_{511}(w_4^3)]$$
serial_fft then transforms row j with new_omega $=w_{512}=w_{2048}^4$, so that $tmp[j][k']=\sum_{i=0}^{511}w_{2048}^{ij}C_i(w_4^j)\,w_{512}^{ik'}=A(w_{2048}^{j+4k'})$ (using $w_4^{j+4k'}=w_4^j$), and the unshuffle step a[idx] = tmp[idx & mask][idx >> log_cpus] stores this value at a[j+4k'].
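The shuffle, sub-FFT and unshuffle steps can be reproduced end to end on a toy field and compared against a direct evaluation. In this sketch the modulus 998244353 is an assumption, std::thread::scope replaces the crossbeam-style thread::scope used in Halo, and the sub-FFT is a naive O(n²) DFT rather than serial_fft so that the focus stays on the index arithmetic; parallel_fft_sketch and naive_dft are hypothetical names.

```rust
use std::thread;

const P: u64 = 998_244_353; // toy NTT-friendly prime (assumption)

fn pow(mut b: u64, mut e: u64) -> u64 {
    let mut acc = 1;
    b %= P;
    while e > 0 {
        if e & 1 == 1 {
            acc = acc * b % P;
        }
        b = b * b % P;
        e >>= 1;
    }
    acc
}

/// Naive DFT: out[k] = sum_i a[i] * omega^(ik). Stands in for serial_fft here.
fn naive_dft(a: &[u64], omega: u64) -> Vec<u64> {
    (0..a.len() as u64)
        .map(|k| {
            let wk = pow(omega, k);
            a.iter().rev().fold(0, |acc, &c| (acc * wk + c) % P)
        })
        .collect()
}

/// Same shuffle / sub-FFT / unshuffle structure as parallel_fft above.
fn parallel_fft_sketch(a: &mut [u64], omega: u64, log_n: u32, log_cpus: u32) {
    let num_cpus = 1usize << log_cpus;
    let log_new_n = log_n - log_cpus;
    let mut tmp = vec![vec![0u64; 1 << log_new_n]; num_cpus];
    let new_omega = pow(omega, num_cpus as u64);

    thread::scope(|scope| {
        let a = &*a;
        for (j, row) in tmp.iter_mut().enumerate() {
            scope.spawn(move || {
                let omega_j = pow(omega, j as u64);
                let omega_step = pow(omega, (j as u64) << log_new_n);
                let mut elt = 1u64;
                for i in 0..(1usize << log_new_n) {
                    for s in 0..num_cpus {
                        let idx = (i + (s << log_new_n)) % (1 << log_n);
                        row[i] = (row[i] + a[idx] * elt) % P;
                        elt = elt * omega_step % P;
                    }
                    elt = elt * omega_j % P;
                }
                // row[i] now holds w_n^(ij) * C_i(w_p^j); run the sub-DFT
                let out = naive_dft(&row[..], new_omega);
                *row = out;
            });
        }
    });

    // Unshuffle: a[j + num_cpus * k] = tmp[j][k]
    let mask = num_cpus - 1;
    for (idx, x) in a.iter_mut().enumerate() {
        *x = tmp[idx & mask][idx >> log_cpus];
    }
}

fn main() {
    let log_n = 6;
    let omega = pow(3, (P - 1) >> log_n); // primitive 2^6-th root of unity
    let coeffs: Vec<u64> = (0..1u64 << log_n).map(|i| (i * i + 5) % P).collect();
    let mut a = coeffs.clone();
    parallel_fft_sketch(&mut a, omega, log_n, 2);
    println!("matches naive DFT: {}", a == naive_dft(&coeffs, omega));
}
```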

2.3.2 Serial FFT: serial_fft

    // Perform sub-FFT
    serial_fft(tmp, new_omega, log_new_n);
    // Here tmp is row j (tmp[j]), new_omega is a 2^9-th primitive root of unity, and log_new_n = 9.

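The following snippet reorders the coefficient array $[a_0,a_1,a_2,\dots,a_{n-1}]$ into bit-reversed order: position $k$ receives $a_{\mathrm{bitreverse}(k)}$. The even-indexed coefficients end up in the first half and the odd-indexed ones in the second half, and the same even/odd split is applied recursively within each half. For $n=4$ this gives $[a_0,a_2,a_1,a_3]$; for $n=8$ it gives $[a_0,a_4,a_2,a_6,a_1,a_5,a_3,a_7]$.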

    for k in 0..n {
        let rk = bitreverse(k, log_n);
        if k < rk {
            a.swap(rk as usize, k as usize);
        }
    }
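The permutation produced by this loop is easy to inspect by running it on index labels instead of field elements. bitreverse is copied from serial_fft; bitreverse_permute is a hypothetical wrapper around the swap loop.

```rust
/// Same bit-reversal helper as in serial_fft: reverses the lowest l bits of n.
fn bitreverse(mut n: u32, l: u32) -> u32 {
    let mut r = 0;
    for _ in 0..l {
        r = (r << 1) | (n & 1);
        n >>= 1;
    }
    r
}

/// Applies the in-place swap loop from serial_fft to any slice.
fn bitreverse_permute<T>(a: &mut [T], log_n: u32) {
    let n = a.len() as u32;
    assert_eq!(n, 1 << log_n);
    for k in 0..n {
        let rk = bitreverse(k, log_n);
        if k < rk {
            a.swap(rk as usize, k as usize);
        }
    }
}

fn main() {
    let mut idx: Vec<u32> = (0..8).collect();
    bitreverse_permute(&mut idx, 3);
    println!("{:?}", idx); // [0, 4, 2, 6, 1, 5, 3, 7]
}
```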

[Figure]
A ( x ) = a 0 + a 1 x + a 2 x 2 + a 3 x 3 A(x)=a_0+a_1x+a_2x^2+a_3x^3 为例,下面程序的演示效果如上图所示。

fn serial_fft<F: Field>(a: &mut [F], omega: F, log_n: u32) {
    // Reverse the lowest l bits of n
    fn bitreverse(mut n: u32, l: u32) -> u32 {
        let mut r = 0;
        for _ in 0..l {
            r = (r << 1) | (n & 1);
            n >>= 1;
        }
        r
    }

    let n = a.len() as u32;
    assert_eq!(n, 1 << log_n);

    // Permute `a` into bit-reversed index order
    for k in 0..n {
        let rk = bitreverse(k, log_n);
        if k < rk {
            a.swap(rk as usize, k as usize);
        }
    }

    // log_n rounds of butterflies; each round merges pairs of blocks of size m
    let mut m = 1;
    for _ in 0..log_n {
        // w_m = omega^(n / 2m) is a primitive 2m-th root of unity
        let w_m = omega.pow(&[u64::from(n / (2 * m)), 0, 0, 0]);

        let mut k = 0;
        while k < n {
            let mut w = F::one();
            for j in 0..m {
                // Butterfly: (u, v) -> (u + w*v, u - w*v)
                let mut t = a[(k + j + m) as usize];
                t *= w;
                a[(k + j + m) as usize] = a[(k + j) as usize] - t;
                a[(k + j) as usize] += t;
                w *= w_m;
            }

            k += 2 * m;
        }

        m *= 2;
    }
}
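For a concrete check, here is a sketch of serial_fft with the generic F: Field replaced by u64 residues modulo a small prime p (an assumption made so the example runs without Halo; pow_mod is a hypothetical helper). On the example $A(x)=1+2x+3x^2+4x^3$ over F_17 with omega = 4, it reproduces A(1), A(4), A(16), A(13).

```rust
fn pow_mod(mut b: u64, mut e: u64, p: u64) -> u64 {
    let mut acc = 1 % p;
    b %= p;
    while e > 0 {
        if e & 1 == 1 {
            acc = acc * b % p;
        }
        b = b * b % p;
        e >>= 1;
    }
    acc
}

/// serial_fft with F replaced by u64 residues mod p (assumes p < 2^32).
fn serial_fft(a: &mut [u64], omega: u64, log_n: u32, p: u64) {
    fn bitreverse(mut n: u32, l: u32) -> u32 {
        let mut r = 0;
        for _ in 0..l {
            r = (r << 1) | (n & 1);
            n >>= 1;
        }
        r
    }

    let n = a.len() as u32;
    assert_eq!(n, 1 << log_n);
    for k in 0..n {
        let rk = bitreverse(k, log_n);
        if k < rk {
            a.swap(rk as usize, k as usize);
        }
    }

    let mut m = 1;
    for _ in 0..log_n {
        let w_m = pow_mod(omega, u64::from(n / (2 * m)), p);
        let mut k = 0;
        while k < n {
            let mut w = 1;
            for j in 0..m {
                let (lo, hi) = ((k + j) as usize, (k + j + m) as usize);
                let t = a[hi] * w % p;
                a[hi] = (a[lo] + p - t) % p; // a[lo] - t
                a[lo] = (a[lo] + t) % p; // a[lo] + t
                w = w * w_m % p;
            }
            k += 2 * m;
        }
        m *= 2;
    }
}

fn main() {
    let mut a = vec![1, 2, 3, 4];
    serial_fft(&mut a, 4, 2, 17);
    println!("{:?}", a); // [10, 7, 15, 6]
}
```

Running the same function with omega⁻¹ = 13 and then multiplying by n⁻¹ = 13 (mod 17) recovers the coefficients, which is exactly the IFFT recipe of section 2.5.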

2.4 Multiplying polynomials in point-value form


    best_fft(&mut a, alpha, exp); // coefficient form -> point-value form
    best_fft(&mut b, alpha, exp);

    // Multiply pairwise: polynomial multiplication in point-value form
    let num_cpus = num_cpus::get();
    if a.len() > num_cpus {
        thread::scope(|scope| {
            let chunk = a.len() / num_cpus::get();

            for (a, b) in a.chunks_mut(chunk).zip(b.chunks(chunk)) {
                scope.spawn(move |_| {
                    for (a, b) in a.iter_mut().zip(b.iter()) {
                        *a *= *b;
                    }
                });
            }
        })
        .unwrap();
    } else {
        for (a, b) in a.iter_mut().zip(b.iter()) {
            *a *= *b;
        }
    }
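The chunked pointwise product can be sketched with the standard library alone. Here std::thread::scope stands in for the crossbeam-style scope in Halo, values are u64 residues modulo an assumed prime, and mul_pointwise is a hypothetical name. When the length is not a multiple of num_cpus, chunks_mut yields one extra, shorter chunk, so one more thread is spawned; the result is still correct.

```rust
use std::thread;

const P: u64 = 998_244_353; // toy modulus (assumption)

/// Pointwise product a[i] *= b[i] (mod P), one chunk per core, as in Halo.
fn mul_pointwise(a: &mut [u64], b: &[u64], num_cpus: usize) {
    assert_eq!(a.len(), b.len());
    if a.len() > num_cpus {
        let chunk = a.len() / num_cpus;
        thread::scope(|scope| {
            for (a, b) in a.chunks_mut(chunk).zip(b.chunks(chunk)) {
                scope.spawn(move || {
                    for (x, y) in a.iter_mut().zip(b) {
                        *x = *x * *y % P;
                    }
                });
            }
        });
    } else {
        for (x, y) in a.iter_mut().zip(b) {
            *x = *x * *y % P;
        }
    }
}

fn main() {
    let mut a: Vec<u64> = (1..=10).collect();
    let b = vec![2u64; 10];
    mul_pointwise(&mut a, &b, 4); // chunk = 10 / 4 = 2, so 5 scoped threads
    println!("{:?}", a);
}
```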

2.5 Inverse Fourier transform (IFFT)

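The IFFT converts the point-value representation back into the coefficient representation. Given the values $y_k=A(w_n^k)$, the coefficients are $a_i=\frac{1}{n}\sum_{k=0}^{n-1}y_k\,w_n^{-ik}$, so the inverse transform is the same FFT run with $w_n^{-1}$ (alpha_inv in the code below), followed by multiplying every entry by $n^{-1}$ (minv, with $n=m$).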

    // Inverse FFT
    let alpha_inv = alpha.invert().unwrap();
    best_fft(&mut a, alpha_inv, exp);

    // Divide all elements by m = a.len()
    let minv = F::from_u64(m as u64).invert().unwrap();
    if a.len() > num_cpus {
        thread::scope(|scope| {
            let chunk = a.len() / num_cpus::get();

            for a in a.chunks_mut(chunk) {
                scope.spawn(move |_| {
                    for a in a.iter_mut() {
                        *a *= minv;
                    }
                });
            }
        })
        .unwrap();
    } else {
        for a in a.iter_mut() {
            *a *= minv;
        }
    }
    // `a` was zero-padded to length 2^exp for the radix-2 FFT; keep only the coefficients of the product
    a.truncate(coeffs_of_result);
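Putting sections 2.1 to 2.5 together, the following is a compact sketch of the whole multiply_polynomials pipeline over a toy field (the modulus 998244353 and ALPHA = 3^119 are assumptions standing in for Halo's field). The fft function is a serial radix-2 transform written for this example, equivalent in result to serial_fft but not Halo's code, and inverses are computed with Fermat's little theorem, $x^{-1}=x^{p-2}$.

```rust
// Toy field (assumption): P = 998244353 = 119 * 2^23 + 1, S = 23, ALPHA = 3^119.
const P: u64 = 998_244_353;
const S: u32 = 23;

fn pow(mut b: u64, mut e: u64) -> u64 {
    let mut acc = 1;
    b %= P;
    while e > 0 {
        if e & 1 == 1 {
            acc = acc * b % P;
        }
        b = b * b % P;
        e >>= 1;
    }
    acc
}

/// In-place radix-2 FFT (bit-reversal, then butterflies); len must be a power of two.
fn fft(a: &mut [u64], omega: u64) {
    let n = a.len();
    if n <= 1 {
        return;
    }
    let log_n = n.trailing_zeros();
    for k in 0..n {
        let rk = k.reverse_bits() >> (usize::BITS - log_n);
        if k < rk {
            a.swap(k, rk);
        }
    }
    let mut m = 1;
    while m < n {
        let w_m = pow(omega, (n / (2 * m)) as u64);
        for k in (0..n).step_by(2 * m) {
            let mut w = 1;
            for j in 0..m {
                let t = a[k + j + m] * w % P;
                let u = a[k + j];
                a[k + j] = (u + t) % P;
                a[k + j + m] = (u + P - t) % P;
                w = w * w_m % P;
            }
        }
        m *= 2;
    }
}

fn multiply_polynomials(mut a: Vec<u64>, mut b: Vec<u64>) -> Vec<u64> {
    let coeffs_of_result = a.len() + b.len() - 1;
    let m = coeffs_of_result.next_power_of_two(); // 2^exp
    let exp = m.trailing_zeros();
    assert!(exp < S, "polynomial too large");
    a.resize(m, 0);
    b.resize(m, 0);

    let mut alpha = pow(3, 119); // stand-in for F::ALPHA
    for _ in exp..S {
        alpha = alpha * alpha % P;
    }

    fft(&mut a, alpha); // coefficients -> values
    fft(&mut b, alpha);
    for (x, y) in a.iter_mut().zip(&b) {
        *x = *x * *y % P; // pointwise product
    }

    // Inverse FFT: same transform with alpha^{-1}, then divide by m
    fft(&mut a, pow(alpha, P - 2));
    let minv = pow(m as u64, P - 2);
    for x in a.iter_mut() {
        *x = *x * minv % P;
    }
    a.truncate(coeffs_of_result);
    a
}

fn naive_product(a: &[u64], b: &[u64]) -> Vec<u64> {
    let mut out = vec![0; a.len() + b.len() - 1];
    for (i, x) in a.iter().enumerate() {
        for (j, y) in b.iter().enumerate() {
            out[i + j] = (out[i + j] + x * y) % P;
        }
    }
    out
}

fn main() {
    // Two degree-999 polynomials, as in test_fft
    let a: Vec<u64> = (0..1000u64).map(|i| (i * i + 1) % P).collect();
    let b: Vec<u64> = (0..1000u64).map(|i| (3 * i + 7) % P).collect();
    println!("equal: {}", multiply_polynomials(a.clone(), b.clone()) == naive_product(&a, &b));
}
```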

2.6 Supplementary material


sage: root=15597833531057113827281213877381453461893547087947030063083486787097767520449
sage: p=41779350816691014953522156191564118732861004145504938180595018507596047843329
sage: omega=power_mod(root,2^21,p)   # 21 = 32 - exp, with exp = 11
sage: omega
35014335792849108923302692126549442116295992392289760687159465394416590439942
sage: hex(omega)
'4d696968d9c7e5b55e6a88fe57cbaa9e166872f777629c2cd200ba70d7cec606'
sage: new_omega=power_mod(omega,4,p)
sage: hex(new_omega)
'323286f8bcf390a8b2be7ef037eab34c127cf6fc0b1e0bc4866f1ba33bc1dc80'
# new_omega is a 2^(log_n - log_cpus) = 2^9-th primitive root of unity
sage: power_mod(new_omega, 2^9,p)
1
sage: R=2^256
sage: mod(R,p)   # the field's "one": Montgomery form stores x as x*R mod p, so 1 is stored as R mod p, and montgomery_reduce maps it back to R/R = 1
32233387603934165516526672625559670387547976374630687678267546992721033953278
sage: hex(32233387603934165516526672625559670387547976374630687678267546992721033953278)
'4743736b947db12c8a7ab15117a98d9efc82c7cb9bfdb6facc000305fffffffe'
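The Montgomery representation behind mod(R, p) can be illustrated with a smaller word size. This sketch assumes a 32-bit modulus and R = 2^32 instead of Halo's 256-bit R = 2^256; neg_p_inv, redc, to_mont, from_mont and mont_mul are hypothetical helpers. It shows that one is stored as R mod p and that Montgomery reduction maps it back to 1.

```rust
// Toy Montgomery arithmetic (assumption): 32-bit odd modulus P, R = 2^32.
const P: u32 = 998_244_353;

/// -P^{-1} mod 2^32 via Newton iteration; each step doubles the correct bits.
const fn neg_p_inv() -> u32 {
    let mut inv: u32 = 1; // correct modulo 2 because P is odd
    let mut i = 0;
    while i < 5 {
        inv = inv.wrapping_mul(2u32.wrapping_sub(P.wrapping_mul(inv)));
        i += 1;
    }
    inv.wrapping_neg()
}

const P_PRIME: u32 = neg_p_inv();

/// Montgomery reduction: returns t * R^{-1} mod P, for t < P * R.
fn redc(t: u64) -> u32 {
    let m = (t as u32).wrapping_mul(P_PRIME); // makes t + m*P divisible by 2^32
    let u = ((t + m as u64 * P as u64) >> 32) as u32; // u < 2P
    if u >= P { u - P } else { u }
}

/// x -> x * R mod P (Montgomery form).
fn to_mont(x: u32) -> u32 {
    (((x as u64) << 32) % P as u64) as u32
}

/// x * R mod P -> x.
fn from_mont(x: u32) -> u32 {
    redc(x as u64)
}

/// (aR)(bR)/R = abR: multiplication stays in Montgomery form.
fn mont_mul(a: u32, b: u32) -> u32 {
    redc(a as u64 * b as u64)
}

fn main() {
    let one = to_mont(1); // stored as R mod P, like mod(R, p) in the Sage session
    println!("R mod P = {one}, decoded = {}", from_mont(one));
    println!("3 * 5 = {}", from_mont(mont_mul(to_mont(3), to_mont(5))));
}
```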



Reposted from blog.csdn.net/mutourend/article/details/102697211