Tricks for Speeding Up STARKs/SNARKs

1. Introduction

STARKs and SNARKs involve a large amount of finite-field arithmetic, for example:

  • STARKs perform batch polynomial evaluation over the roots-of-unity domain $D=\{1,\omega,\cdots,\omega^{2^k-1}\}$.

Besides hardware acceleration on CPUs and GPUs, several algorithmic techniques can also speed these operations up.

For CPU acceleration, see:

Cryptographic libraries that use CPU acceleration include:

Previous posts in this series:

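2. Using the FFT to speed up batch polynomial evaluation over a roots-of-unity domain

The Fast Fourier Transform (FFT) over a finite field is also called the Number Theoretic Transform (NTT). The FFT converts a polynomial from its coefficient representation to its point-value (evaluation) representation.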

Let $f(X)=\sum_{i=0}^{d}c_iX^i$ be a polynomial of degree at most $2^k-1$ with coefficients $c_i\in\mathbb{F}_p$, and let $\omega$ be a primitive $2^k$-th root of unity. The goal is to compute all of the evaluations
$$(f(\omega^i))_{i=0}^{2^k-1} = (f(1), f(\omega), f(\omega^2), \ldots, f(\omega^{2^k-1}))$$

There are two approaches:

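  • 1) The most direct approach evaluates the points one at a time. With $N=2^k$, each evaluation costs $O(N)$ (e.g. with Horner's rule), so the total cost is $O(N^2)$. Naive versions of interpolate, evaluate, divide, etc. are implemented in https://github.com/aszepieniec/stark-anatomy/blob/master/code/univariate.py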
  • 2) A smarter approach follows the FFT's divide-and-conquer strategy and splits the polynomial into its even and odd terms (padding with zero coefficients so that $d+1=N$). Applied recursively with $N=2^k$, this costs $O(N\cdot \log N)$. FFT/IFFT-accelerated interpolate, evaluate, divide, etc. are implemented in https://github.com/aszepieniec/stark-anatomy/blob/master/code/ntt.py
    $$f(X)=f_E(X^2)+X\cdot f_O(X^2)$$
    where:
    $$f_E(X^2)=\frac{f(X)+f(-X)}{2}=\sum_{i=0}^{\frac{d+1}{2}-1}c_{2i}X^{2i}$$
    $$f_O(X^2)=\frac{f(X)-f(-X)}{2X}=\sum_{i=0}^{\frac{d+1}{2}-1}c_{2i+1}X^{2i}$$
    Hence:
    $$f(\omega^i)=f_E(\omega^{2i})+\omega^i\cdot f_O(\omega^{2i})$$
    Because $\omega^2$ is a primitive $2^{k-1}$-th root of unity, $\omega^{2i}$ takes only $N/2$ distinct values, so $f_E$ and $f_O$ (each with $N/2$ coefficients) only need to be evaluated on a domain of half the size. The cost therefore satisfies $T(N)=2T(N/2)+O(N)$, i.e. $T(N)=O(N\log N)$.

Applying this recursively: pass the coefficients $c_i$ as the values argument, and ntt returns the evaluations $(f(\omega^i))_{i=0}^{2^k-1} = (f(1), f(\omega), f(\omega^2), \ldots, f(\omega^{2^k-1}))$. (Code from https://github.com/aszepieniec/stark-anatomy/blob/master/code/ntt.py)

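def ntt( primitive_root, values ):
    # values: coefficients c_0, ..., c_{n-1} of f; returns [f(primitive_root^i) for i in range(n)].
    # In stark-anatomy's FieldElement, ^ is overloaded as exponentiation (not XOR).
    assert(len(values) & (len(values) - 1) == 0), "cannot compute ntt of non-power-of-two sequence"
    if len(values) <= 1:
        return values

    field = values[0].field

    assert(primitive_root^len(values) == field.one()), "primitive root must be nth root of unity, where n is len(values)"
    assert(primitive_root^(len(values)//2) != field.one()), "primitive root is not primitive nth root of unity, where n is len(values)"

    half = len(values) // 2

    # primitive_root^2 is a primitive (n/2)-th root of unity
    odds = ntt(primitive_root^2, values[1::2])   # f_O evaluated on the half-size domain
    evens = ntt(primitive_root^2, values[::2])   # f_E evaluated on the half-size domain

    # f(w^i) = f_E(w^{2i}) + w^i * f_O(w^{2i}); w^{2i} repeats with period half
    return [evens[i % half] + (primitive_root^i) * odds[i % half] for i in range(len(values))]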
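To experiment without the stark-anatomy classes, here is a minimal self-contained sketch that checks the even/odd recursion against the naive $O(N^2)$ method. It assumes the NTT-friendly prime $p = 998244353 = 119\cdot 2^{23}+1$ with multiplicative generator 3 in place of the STARK field; the helper names are hypothetical. Unlike ntt above, it uses $\omega^{N/2}=-1$ to obtain $f(\omega^i)$ and $f(\omega^{i+N/2})$ from one product (the usual butterfly).

```python
# Sketch only: plain Python ints mod P stand in for the STARK field (assumption).
P = 998244353  # 119 * 2^23 + 1, an NTT-friendly prime
G = 3          # generator of the multiplicative group mod P


def primitive_root_of_unity(n):
    """Return a primitive n-th root of unity mod P; n must be a power of two dividing 2^23."""
    assert n & (n - 1) == 0 and (P - 1) % n == 0
    return pow(G, (P - 1) // n, P)


def naive_evaluate(coeffs, points):
    """Horner's rule at every point: O(N) per point, O(N^2) overall."""
    out = []
    for x in points:
        acc = 0
        for c in reversed(coeffs):
            acc = (acc * x + c) % P
        out.append(acc)
    return out


def fft_evaluate(coeffs, omega):
    """Return [f(omega^i) for i in range(N)] via f(X) = f_E(X^2) + X * f_O(X^2)."""
    n = len(coeffs)
    if n == 1:
        return list(coeffs)
    half = n // 2
    omega_sq = omega * omega % P                  # primitive (N/2)-th root of unity
    evens = fft_evaluate(coeffs[0::2], omega_sq)  # f_E on the half-size domain
    odds = fft_evaluate(coeffs[1::2], omega_sq)   # f_O on the half-size domain
    out = [0] * n
    w = 1                                         # w = omega^i
    for i in range(half):
        t = w * odds[i] % P
        out[i] = (evens[i] + t) % P               # f(omega^i)
        out[i + half] = (evens[i] - t) % P        # f(omega^(i+N/2)), as omega^(N/2) = -1
        w = w * omega % P
    return out


N = 8
omega = primitive_root_of_unity(N)
f = [3, 1, 4, 1, 5, 9, 2, 6]
domain = [pow(omega, i, P) for i in range(N)]
assert fft_evaluate(f, omega) == naive_evaluate(f, domain)
```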

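3. Using the IFFT to recover polynomial coefficients

The Inverse Fast Fourier Transform (IFFT) over a finite field is also called the Inverse Number Theoretic Transform (INTT). The IFFT converts a polynomial from its point-value representation back to its coefficient representation.

Running the IFFT after the FFT, i.e. IFFT(FFT(x)), returns the original input x.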
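Why running the forward NTT with $\omega^{-1}$ and then multiplying by $N^{-1}$ inverts the NTT (which is what the intt code in this section does): a short derivation, assuming $\omega$ is a primitive $N$-th root of unity and $N$ is invertible in $\mathbb{F}_p$ (true for odd $p$ and $N=2^k$).

```latex
\begin{aligned}
&\text{NTT:} && y_i=\sum_{j=0}^{N-1}c_j\,\omega^{ij},\quad i=0,\dots,N-1\\
&\text{orthogonality:} && \sum_{i=0}^{N-1}\omega^{i(j-l)}=
  \begin{cases} N, & j=l\\ \dfrac{\omega^{N(j-l)}-1}{\omega^{j-l}-1}=0, & j\neq l \end{cases}\\
&\text{INTT:} && \frac{1}{N}\sum_{i=0}^{N-1}y_i\,\omega^{-il}
  =\frac{1}{N}\sum_{j=0}^{N-1}c_j\sum_{i=0}^{N-1}\omega^{i(j-l)}=c_l
\end{aligned}
```

The $j\neq l$ case uses $\omega^{j-l}\neq 1$ (primitivity, since $0<|j-l|<N$) and $\omega^{N}=1$.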

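# Assumes stark-anatomy's code/algebra.py (Field, FieldElement) plus the ntt/intt defined in this post.
import os
from algebra import Field, FieldElement

def test_intt( ):
    field = Field.main()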

    logn = 7
    n = 1 << logn
    primitive_root = field.primitive_nth_root(n)

    values = [field.sample(os.urandom(1)) for i in range(n)]
    coeffs = ntt(primitive_root, values)
    values_again = intt(primitive_root, coeffs)

    assert(values == values_again), "inverse ntt is different from forward ntt"

def intt( primitive_root, values ):
    assert(len(values) & (len(values) - 1) == 0), "cannot compute intt of non-power-of-two sequence"

    if len(values) <= 1:
        return values

    field = values[0].field
    ninv = FieldElement(len(values), field).inverse()
    # run the forward NTT with the inverse root primitive_root^{-1}
    transformed_values = ntt(primitive_root.inverse(), values)
    # multiplying by n^{-1} yields the actual coefficients
    return [ninv*tv for tv in transformed_values]
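A self-contained round-trip check of the same idea over plain integers, as a sketch: it assumes the prime 998244353 with generator 3 instead of the STARK field, and ntt/intt here are hypothetical stand-ins taking (values, root) rather than the stark-anatomy signatures above. Inverses are computed with Fermat's little theorem, $a^{-1}=a^{p-2}$.

```python
# Sketch: plain ints mod P stand in for the STARK field (assumption).
P = 998244353  # NTT-friendly prime 119 * 2^23 + 1
G = 3          # multiplicative generator mod P


def ntt(a, w):
    """Recursive radix-2 NTT: returns [f(w^i) for i in range(len(a))]."""
    n = len(a)
    if n == 1:
        return list(a)
    e = ntt(a[0::2], w * w % P)   # f_E at the (n/2)-th roots of unity
    o = ntt(a[1::2], w * w % P)   # f_O at the (n/2)-th roots of unity
    out, x = [0] * n, 1
    for i in range(n // 2):
        t = x * o[i] % P          # x = w^i
        out[i], out[i + n // 2] = (e[i] + t) % P, (e[i] - t) % P
        x = x * w % P
    return out


def intt(v, w):
    """Inverse NTT: forward NTT with w^{-1}, then multiply every entry by n^{-1}."""
    n_inv = pow(len(v), P - 2, P)
    return [y * n_inv % P for y in ntt(v, pow(w, P - 2, P))]


N = 16
omega = pow(G, (P - 1) // N, P)
coeffs = [(7 * i * i + 3) % P for i in range(N)]
assert intt(ntt(coeffs, omega), omega) == coeffs  # IFFT(FFT(c)) == c
```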

4. Fast polynomial multiplication via FFT and IFFT

Given polynomials $f(X), g(X)$, compute $h(X) = f(X) \cdot g(X)$, where $\mathsf{deg}(h(X)) < 2^k$. Since $h(X)$ then has at most $2^k$ coefficients, it is uniquely determined by its values at the $2^k$ roots of unity.

  • 1) Apply the FFT to $f(X)$ and $g(X)$ separately to obtain their point-value representations.
  • 2) Multiply the point values of $f(X)$ and $g(X)$ element-wise; the result is the point-value representation of $h(X)$.
  • 3) Apply the IFFT to the products to obtain the coefficients of $h(X)$.
def fast_multiply( lhs, rhs, primitive_root, root_order ):
    assert(primitive_root^root_order == primitive_root.field.one()), "supplied root does not have supplied order"
    assert(primitive_root^(root_order//2) != primitive_root.field.one()), "supplied root is not primitive root of supplied order"

    if lhs.is_zero() or rhs.is_zero():
        return Polynomial([])

    field = lhs.coefficients[0].field
    root = primitive_root
    order = root_order
    degree = lhs.degree() + rhs.degree()

    if degree < 8: # low degree: multiply directly instead
        return lhs * rhs

    while degree < order // 2: # shrink to the smallest power-of-two order above degree, with its matching root
        root = root^2
        order = order // 2

    lhs_coefficients = lhs.coefficients[:(lhs.degree()+1)]
    while len(lhs_coefficients) < order: # zero-pad up to order coefficients
        lhs_coefficients += [field.zero()]
    rhs_coefficients = rhs.coefficients[:(rhs.degree()+1)]
    while len(rhs_coefficients) < order: # zero-pad up to order coefficients
        rhs_coefficients += [field.zero()]
    # 1) FFT f(X) and g(X) to get their point-value representations
    lhs_codeword = ntt(root, lhs_coefficients)
    rhs_codeword = ntt(root, rhs_coefficients)
    # 2) the element-wise product is the point-value representation of h(X)
    hadamard_product = [l * r for (l, r) in zip(lhs_codeword, rhs_codeword)]
    # 3) IFFT back to the coefficients of h(X)
    product_coefficients = intt(root, hadamard_product)
    # only the first degree+1 entries can be nonzero; the rest are 0
    return Polynomial(product_coefficients[0:(degree+1)])
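The three steps can be tried end to end with plain integers. A minimal sketch, assuming the prime 998244353 (generator 3) and coefficient lists ordered from the constant term up; the function names are hypothetical and the result is checked against schoolbook multiplication. As in fast_multiply, the transform size is the smallest power of two exceeding deg(h).

```python
P = 998244353  # assumption: NTT-friendly prime 119 * 2^23 + 1
G = 3          # multiplicative generator mod P


def ntt(a, w):
    """Recursive radix-2 NTT: returns [f(w^i) for i in range(len(a))]."""
    n = len(a)
    if n == 1:
        return list(a)
    e = ntt(a[0::2], w * w % P)
    o = ntt(a[1::2], w * w % P)
    out, x = [0] * n, 1
    for i in range(n // 2):
        t = x * o[i] % P
        out[i], out[i + n // 2] = (e[i] + t) % P, (e[i] - t) % P
        x = x * w % P
    return out


def intt(v, w):
    """Inverse NTT: forward NTT with w^{-1}, then scale by n^{-1}."""
    n_inv = pow(len(v), P - 2, P)
    return [y * n_inv % P for y in ntt(v, pow(w, P - 2, P))]


def fast_multiply(a, b):
    """h = a * b via FFT -> element-wise product -> IFFT."""
    if not a or not b:
        return []
    result_len = len(a) + len(b) - 1          # deg(h) + 1
    n = 1
    while n < result_len:                     # smallest power of two > deg(h)
        n *= 2
    assert (P - 1) % n == 0, "transform too large for this prime"
    w = pow(G, (P - 1) // n, P)
    fa = ntt(a + [0] * (n - len(a)), w)       # 1) point values of a
    fb = ntt(b + [0] * (n - len(b)), w)       #    and of b
    fh = [x * y % P for x, y in zip(fa, fb)]  # 2) point values of h
    return intt(fh, w)[:result_len]           # 3) coefficients of h


def schoolbook_multiply(a, b):
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % P
    return out


a = [1, 2, 3, 4, 5]
b = [6, 7, 8, 9, 10, 11, 12]
assert fast_multiply(a, b) == schoolbook_multiply(a, b)
```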

5. Fast zerofier via fast multiplication

Again using divide-and-conquer, the zerofier (also called the vanishing polynomial) can be computed by calling fast_multiply recursively:
$$z(X)=(X-1)(X-\omega)\cdots(X-\omega^{2^k-1})$$ (roots-of-unity domain; this product equals $X^{2^k}-1$)

$$z(X)=(X-x_0)(X-x_1)\cdots(X-x_{d})$$ (arbitrary domain)

The steps are:

  • 1) Split the domain into two equal halves.
  • 2) Compute the zerofier of each half (recursively).
  • 3) Multiply the two zerofiers with fast multiplication.
def fast_zerofier( domain, primitive_root, root_order ):
    assert(primitive_root^root_order == primitive_root.field.one()), "supplied root does not have supplied order"
    assert(primitive_root^(root_order//2) != primitive_root.field.one()), "supplied root is not primitive root of supplied order"

    if len(domain) == 0:
        return Polynomial([])

    if len(domain) == 1:
        return Polynomial([-domain[0], primitive_root.field.one()])

    half = len(domain) // 2

    left = fast_zerofier(domain[:half], primitive_root, root_order)
    right = fast_zerofier(domain[half:], primitive_root, root_order)
    return fast_multiply(left, right, primitive_root, root_order)
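Continuing the plain-integer sketch (same assumptions: prime 998244353, generator 3, hypothetical helper names), the recursion can be checked on the full roots-of-unity domain, where the zerofier must come out as $X^{N}-1$. One design difference: this sketch returns 1 (the empty product) for an empty domain, whereas fast_zerofier above returns the zero polynomial there.

```python
P = 998244353  # assumption: stand-in prime field, 119 * 2^23 + 1
G = 3          # multiplicative generator mod P


def ntt(a, w):
    n = len(a)
    if n == 1:
        return list(a)
    e, o = ntt(a[0::2], w * w % P), ntt(a[1::2], w * w % P)
    out, x = [0] * n, 1
    for i in range(n // 2):
        t = x * o[i] % P
        out[i], out[i + n // 2] = (e[i] + t) % P, (e[i] - t) % P
        x = x * w % P
    return out


def intt(v, w):
    n_inv = pow(len(v), P - 2, P)
    return [y * n_inv % P for y in ntt(v, pow(w, P - 2, P))]


def fast_multiply(a, b):
    result_len = len(a) + len(b) - 1
    n = 1
    while n < result_len:
        n *= 2
    w = pow(G, (P - 1) // n, P)
    fa = ntt(a + [0] * (n - len(a)), w)
    fb = ntt(b + [0] * (n - len(b)), w)
    return intt([x * y % P for x, y in zip(fa, fb)], w)[:result_len]


def fast_zerofier(domain):
    """z(X) = prod (X - x) over the domain, by recursively splitting it in half."""
    if len(domain) == 0:
        return [1]                          # empty product
    if len(domain) == 1:
        return [(-domain[0]) % P, 1]        # X - x_0
    half = len(domain) // 2                 # 1) split the domain
    left = fast_zerofier(domain[:half])     # 2) zerofier of each half
    right = fast_zerofier(domain[half:])
    return fast_multiply(left, right)       # 3) combine with fast multiplication


N = 8
omega = pow(G, (P - 1) // N, P)
subgroup = [pow(omega, i, P) for i in range(N)]
assert fast_zerofier(subgroup) == [P - 1] + [0] * (N - 1) + [1]  # X^8 - 1
```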

6. Fast evaluation over an arbitrary domain via the fast zerofier

For an arbitrary domain $\{x_0,x_1,\cdots,x_d\}$, evaluate $f(X)$ at every point:
$$f(x_0),f(x_1),\cdots, f(x_d)$$

By Lagrange interpolation, when $\deg f\le d$, $f(X)$ can be written as:
$$f(X)=\sum_{i=0}^{d}f(x_i)\frac{\prod_{j=0,j\neq i}^{d}(X-x_j)}{\prod_{j=0,j\neq i}^{d}(x_i-x_j)}$$

To compute $[f(x_0),f(x_1),\cdots, f(x_d)]$, the fast_evaluate algorithm applies divide-and-conquer to the domain: split it into halves $D_L$ and $D_R$ with zerofiers $z_L(X)$ and $z_R(X)$. Writing $f(X)=q(X)\,z_L(X)+r_L(X)$, every $x_i\in D_L$ satisfies $z_L(x_i)=0$ and hence $f(x_i)=r_L(x_i)$, where $r_L=f \bmod z_L$ has degree below $|D_L|$. Evaluating $f$ on $D_L$ therefore reduces to evaluating the smaller polynomial $f \bmod z_L$ on $D_L$, and likewise for $D_R$:

def fast_evaluate( polynomial, domain, primitive_root, root_order ):
    assert(primitive_root^root_order == primitive_root.field.one()), "supplied root does not have supplied order"
    assert(primitive_root^(root_order//2) != primitive_root.field.one()), "supplied root is not primitive root of supplied order"

    if len(domain) == 0:
        return []

    if len(domain) == 1: # return a one-element list
        return [polynomial.evaluate(domain[0])]

    half = len(domain) // 2

    left_zerofier = fast_zerofier(domain[:half], primitive_root, root_order)
    right_zerofier = fast_zerofier(domain[half:], primitive_root, root_order)

    left = fast_evaluate(polynomial % left_zerofier, domain[:half], primitive_root, root_order)
    right = fast_evaluate(polynomial % right_zerofier, domain[half:], primitive_root, root_order)

    return left + right
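The remainder idea behind fast_evaluate can be checked with a small sketch over plain integers mod 998244353 (an assumption, as are the helper names). For brevity it builds zerofiers and remainders with schoolbook arithmetic, so it demonstrates that the recursion is correct, not that it is fast; a fast version would use FFT-based multiplication and division.

```python
P = 998244353  # assumption: stand-in prime field


def poly_mul(a, b):
    """Schoolbook product of coefficient lists (constant term first)."""
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % P
    return out


def zerofier(domain):
    z = [1]
    for x in domain:
        z = poly_mul(z, [(-x) % P, 1])
    return z


def poly_mod(f, m):
    """Remainder of f modulo a monic polynomial m with deg m >= 1."""
    r = list(f)
    dm = len(m) - 1
    for k in range(len(r) - 1, dm - 1, -1):
        c = r[k]
        if c:
            for j in range(dm + 1):  # subtract c * X^(k-dm) * m(X)
                r[k - dm + j] = (r[k - dm + j] - c * m[j]) % P
    return r[:dm]


def horner(f, x):
    acc = 0
    for c in reversed(f):
        acc = (acc * x + c) % P
    return acc


def fast_evaluate(f, domain):
    if len(domain) == 0:
        return []
    if len(domain) == 1:
        return [horner(f, domain[0])]
    half = len(domain) // 2
    left, right = domain[:half], domain[half:]
    # f mod z_L agrees with f on the left half because z_L vanishes there (same for right).
    return (fast_evaluate(poly_mod(f, zerofier(left)), left)
            + fast_evaluate(poly_mod(f, zerofier(right)), right))


f = [5, 0, 3, 1, 4, 1, 5, 9, 2, 6]   # degree 9
domain = [2, 7, 11, 13, 100, 12345]   # arbitrary distinct points
assert fast_evaluate(f, domain) == [horner(f, x) for x in domain]
```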

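7. Fast interpolation via the fast zerofier and fast evaluation

By Lagrange interpolation, $f(X)$ can also be written as:
$$f(X)=\sum_{i=0}^{d}\frac{f(x_i)}{\prod_{j=0,j\neq i}^{d}(x_i-x_j)}\prod_{j=0,j\neq i}^{d}(X-x_j)$$

For $x_i$ in the left half $D_L$, the denominator $\prod_{j=0,j\neq i}^{d}(x_i-x_j)$ splits into the factors from $D_R$, whose product is $z_R(x_i)$ and is stored in the left_offset variable of fast_interpolate, and the factors from $D_L$, which are handled by the recursive call on $D_L$; right_offset plays the symmetric role for $D_R$. The two halves are combined as $f(X)=L(X)\,z_R(X)+R(X)\,z_L(X)$, where $L$ and $R$ interpolate the values divided by the offsets on $D_L$ and $D_R$: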

def fast_interpolate( domain, values, primitive_root, root_order ):
    assert(primitive_root^root_order == primitive_root.field.one()), "supplied root does not have supplied order"
    assert(primitive_root^(root_order//2) != primitive_root.field.one()), "supplied root is not primitive root of supplied order"
    assert(len(domain) == len(values)), "cannot interpolate over domain of different length than values list"

    if len(domain) == 0:
        return Polynomial([])

    if len(domain) == 1:
        return Polynomial([values[0]])

    half = len(domain) // 2

    left_zerofier = fast_zerofier(domain[:half], primitive_root, root_order)
    right_zerofier = fast_zerofier(domain[half:], primitive_root, root_order)

    left_offset = fast_evaluate(right_zerofier, domain[:half], primitive_root, root_order)
    right_offset = fast_evaluate(left_zerofier, domain[half:], primitive_root, root_order)

    # a zero offset means two domain points coincide, so the divisions below would fail
    assert all(not v.is_zero() for v in left_offset + right_offset), "domain contains duplicate points"

    left_targets = [n / d for (n,d) in zip(values[:half], left_offset)]
    right_targets = [n / d for (n,d) in zip(values[half:], right_offset)]

    left_interpolant = fast_interpolate(domain[:half], left_targets, primitive_root, root_order)
    right_interpolant = fast_interpolate(domain[half:], right_targets, primitive_root, root_order)

    return left_interpolant * right_zerofier + right_interpolant * left_zerofier
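A plain-integer sketch of the same recursion (assumptions as before: prime 998244353, hypothetical helper names, schoolbook helpers for brevity). It relies on $f = L\cdot z_R + R\cdot z_L$: at a left-half point $z_L$ vanishes, so $f(x_i)=L(x_i)\,z_R(x_i)$, which is why the left targets are the values divided by left_offset. Here the offsets are computed with Horner's rule instead of fast_evaluate.

```python
P = 998244353  # assumption: stand-in prime field


def poly_mul(a, b):
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % P
    return out


def poly_add(a, b):
    n = max(len(a), len(b))
    a, b = a + [0] * (n - len(a)), b + [0] * (n - len(b))
    return [(x + y) % P for x, y in zip(a, b)]


def zerofier(domain):
    z = [1]
    for x in domain:
        z = poly_mul(z, [(-x) % P, 1])
    return z


def horner(f, x):
    acc = 0
    for c in reversed(f):
        acc = (acc * x + c) % P
    return acc


def fast_interpolate(domain, values):
    """Coefficients (length len(domain)) of the unique f with f(domain[i]) = values[i]."""
    assert len(domain) == len(values)
    if len(domain) == 0:
        return []
    if len(domain) == 1:
        return [values[0] % P]
    half = len(domain) // 2
    dl, dr = domain[:half], domain[half:]
    zl, zr = zerofier(dl), zerofier(dr)
    left_offset = [horner(zr, x) for x in dl]    # z_R(x_i) on the left half
    right_offset = [horner(zl, x) for x in dr]   # z_L(x_i) on the right half
    assert all(left_offset) and all(right_offset), "domain points must be distinct"
    left_targets = [v * pow(d, P - 2, P) % P for v, d in zip(values[:half], left_offset)]
    right_targets = [v * pow(d, P - 2, P) % P for v, d in zip(values[half:], right_offset)]
    left = fast_interpolate(dl, left_targets)
    right = fast_interpolate(dr, right_targets)
    return poly_add(poly_mul(left, zr), poly_mul(right, zl))


f = [3, 1, 4, 1, 5]        # degree 4
domain = [1, 2, 3, 5, 8]
assert fast_interpolate(domain, [horner(f, x) for x in domain]) == f
```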

8. Fast coset evaluation via the FFT

Coset evaluation means: given a polynomial $f(X)=\sum_{i=0}^{2^k-1}c_iX^i$ and an offset $g$, compute
$$f(g\cdot 1), f(g\cdot \omega), \cdots, f(g\cdot \omega^{2^k-1})$$
(This corresponds to Section 4.3, "Coset-FRI", of the earlier post "STARK入门知识" (Introduction to STARKs).)

The algorithm is:

  • 1) Scale $f(X)$ to $f(g\cdot X)=\sum_{i=0}^{2^k-1}(c_i\cdot g^i)\cdot X^i$.
  • 2) Use the FFT to batch-evaluate $f(g\cdot X)$ over the roots-of-unity domain; its value at $\omega^i$ is $f(g\cdot\omega^i)$.
def fast_coset_evaluate( polynomial, offset, generator, order ):
    scaled_polynomial = polynomial.scale(offset)
    values = ntt(generator, scaled_polynomial.coefficients + [offset.field.zero()] * (order - len(polynomial.coefficients)))
    return values
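A quick numerical check of the scaling identity $f(g\cdot X)=\sum_i (c_i g^i)X^i$, again as a sketch over plain integers mod 998244353 with hypothetical helper names. The offset $g=3$ is an assumption; it lies outside the subgroup of order $N$ because 3 generates the whole multiplicative group.

```python
P = 998244353  # assumption: stand-in prime field, 119 * 2^23 + 1
G = 3          # multiplicative generator mod P


def ntt(a, w):
    """Recursive radix-2 NTT: returns [f(w^i) for i in range(len(a))]."""
    n = len(a)
    if n == 1:
        return list(a)
    e, o = ntt(a[0::2], w * w % P), ntt(a[1::2], w * w % P)
    out, x = [0] * n, 1
    for i in range(n // 2):
        t = x * o[i] % P
        out[i], out[i + n // 2] = (e[i] + t) % P, (e[i] - t) % P
        x = x * w % P
    return out


def horner(f, x):
    acc = 0
    for c in reversed(f):
        acc = (acc * x + c) % P
    return acc


def coset_evaluate(coeffs, offset, n):
    """Return [f(offset * w^i) for i in range(n)], w a primitive n-th root of unity."""
    assert len(coeffs) <= n and (P - 1) % n == 0
    w = pow(G, (P - 1) // n, P)
    # 1) scale: f(offset * X) has coefficients c_i * offset^i
    scaled = [c * pow(offset, i, P) % P for i, c in enumerate(coeffs)]
    # 2) zero-pad and run an ordinary NTT over the roots of unity
    return ntt(scaled + [0] * (n - len(scaled)), w)


N = 8
f = [3, 1, 4, 1, 5]
w = pow(G, (P - 1) // N, P)
expected = [horner(f, 3 * pow(w, i, P) % P) for i in range(N)]
assert coset_evaluate(f, 3, N) == expected
```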

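9. Fast coset division via FFT and IFFT

To compute $h(X)=\frac{f(X)}{z(X)}$, compute instead $h(g\cdot X)=\frac{f(g\cdot X)}{z(g\cdot X)}$, then multiply the $i$-th coefficient of $h(g\cdot X)$ by $g^{-i}$ to obtain the coefficients of $h(X)$. The detour through the coset is needed because $z(X)$ is usually a zerofier of (a subset of) the roots of unity, so $z(\omega^i)=0$ at some points of the NTT domain and element-wise division there would divide by zero; on the coset $\{g\cdot\omega^i\}$, with $g$ chosen so that $z$ has no roots on it, every $z(g\cdot\omega^i)$ is nonzero.

The implementation steps are:

  • 1) Scale the coefficients of both numerator and denominator.
  • 2) NTT
  • 3) Element-wise division
  • 4) Inverse NTT
  • 5) Unscale the resulting quotient coefficients (scale by $g^{-1}$).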
def fast_coset_divide( lhs, rhs, offset, primitive_root, root_order ): # exact division only: rhs must divide lhs with zero remainder
    assert(primitive_root^root_order == primitive_root.field.one()), "supplied root does not have supplied order"
    assert(primitive_root^(root_order//2) != primitive_root.field.one()), "supplied root is not primitive root of supplied order"
    assert(not rhs.is_zero()), "cannot divide by zero polynomial"

    if lhs.is_zero():
        return Polynomial([])

    assert(rhs.degree() <= lhs.degree()), "cannot divide by polynomial of larger degree"

    field = lhs.coefficients[0].field
    root = primitive_root
    order = root_order
    degree = max(lhs.degree(),rhs.degree())

    if degree < 8:
        return lhs / rhs

    while degree < order // 2:
        root = root^2
        order = order // 2
    # 1) scale the coefficients of numerator and denominator
    scaled_lhs = lhs.scale(offset)
    scaled_rhs = rhs.scale(offset)

    lhs_coefficients = scaled_lhs.coefficients[:(lhs.degree()+1)]
    while len(lhs_coefficients) < order:
        lhs_coefficients += [field.zero()]
    rhs_coefficients = scaled_rhs.coefficients[:(rhs.degree()+1)]
    while len(rhs_coefficients) < order:
        rhs_coefficients += [field.zero()]
    # 2) NTT
    lhs_codeword = ntt(root, lhs_coefficients)
    rhs_codeword = ntt(root, rhs_coefficients)
    # 3) element-wise divide
    quotient_codeword = [l / r for (l, r) in zip(lhs_codeword, rhs_codeword)]
    # 4) inverse NTT
    scaled_quotient_coefficients = intt(root, quotient_codeword)
    # only the first lhs.degree() - rhs.degree() + 1 entries can be nonzero; the rest are 0
    scaled_quotient = Polynomial(scaled_quotient_coefficients[:(lhs.degree() - rhs.degree() + 1)])
    # 5) unscale the quotient coefficients
    return scaled_quotient.scale(offset.inverse())

The corresponding test case is:

def test_divide( ):
    field = Field.main()

    logn = 6
    n = 1 << logn
    primitive_root = field.primitive_nth_root(n)

    for trial in range(20):
        lhs_degree = int(os.urandom(1)[0]) % (n // 2)
        rhs_degree = int(os.urandom(1)[0]) % (n // 2)

        lhs = Polynomial([field.sample(os.urandom(17)) for i in range(lhs_degree+1)])
        rhs = Polynomial([field.sample(os.urandom(17)) for i in range(rhs_degree+1)])

        fast_product = fast_multiply(lhs, rhs, primitive_root, n)
        quotient = fast_coset_divide(fast_product, lhs, field.generator(), primitive_root, n)

        assert(quotient == rhs), "fast divide does not equal original factor"
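To see why the coset offset matters, here is a plain-integer sketch (assumptions: prime 998244353, offset $g=3$ playing the role field.generator() plays in test_divide, hypothetical helper names). $z(X)=X^4-1$ vanishes at half of the 8th roots of unity, so element-wise division on the subgroup itself would divide by zero, while on the coset $3\cdot\omega^i$ every value of $z$ is nonzero and the quotient comes out exactly.

```python
P = 998244353  # assumption: stand-in prime field, 119 * 2^23 + 1
G = 3          # multiplicative generator mod P


def ntt(a, w):
    n = len(a)
    if n == 1:
        return list(a)
    e, o = ntt(a[0::2], w * w % P), ntt(a[1::2], w * w % P)
    out, x = [0] * n, 1
    for i in range(n // 2):
        t = x * o[i] % P
        out[i], out[i + n // 2] = (e[i] + t) % P, (e[i] - t) % P
        x = x * w % P
    return out


def intt(v, w):
    n_inv = pow(len(v), P - 2, P)
    return [y * n_inv % P for y in ntt(v, pow(w, P - 2, P))]


def scale(c, s):
    """Coefficients of p(s * X) given those of p(X)."""
    return [x * pow(s, i, P) % P for i, x in enumerate(c)]


def coset_divide(f, z, offset):
    """Exact quotient f / z (zero remainder assumed), computed on the coset offset * <w>."""
    n = 1
    while n < len(f):
        n *= 2
    w = pow(G, (P - 1) // n, P)
    fv = ntt(scale(f, offset) + [0] * (n - len(f)), w)       # 1)+2) scale, then NTT
    zv = ntt(scale(z, offset) + [0] * (n - len(z)), w)
    assert all(zv), "z vanishes on this coset; choose another offset"
    qv = [a * pow(b, P - 2, P) % P for a, b in zip(fv, zv)]  # 3) element-wise divide
    q_scaled = intt(qv, w)[: len(f) - len(z) + 1]            # 4) inverse NTT
    return scale(q_scaled, pow(offset, P - 2, P))            # 5) unscale by offset^{-1}


h = [1, 2, 3, 4]
z = [P - 1, 0, 0, 0, 1]              # z(X) = X^4 - 1
f = [(-c) % P for c in h] + h        # f(X) = h(X) * (X^4 - 1)
w8 = pow(G, (P - 1) // 8, P)
assert 0 in ntt(z + [0] * 3, w8)     # on the plain subgroup, z has zeros
assert coset_divide(f, z, 3) == h
```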



Reposted from blog.csdn.net/mutourend/article/details/121651654