1. Introduction
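STARKs and SNARKs involve a large amount of finite-field arithmetic, for example:
- STARKs perform batch polynomial evaluation over the roots-of-unity domain $D=\{1,\omega,\cdots,\omega^{2^k-1}\}$.
Besides hardware acceleration on CPUs, GPUs and similar devices, several algorithms can also speed these operations up.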
For CPU acceleration, see:
Cryptographic libraries that use CPU acceleration include:
Previous blog posts in this series:
2. Using the FFT to accelerate batch polynomial evaluation over the roots-of-unity domain
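The Fast Fourier Transform (FFT) over a finite field is also called the Number Theoretic Transform (NTT). The FFT converts a polynomial from its coefficient representation to its point-value (evaluation) representation.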
Let $f(X)=\sum_{i=0}^{d}c_iX^i$ be a polynomial of degree at most $2^k-1$ with coefficients $c_i\in\mathbb{F}_p$.
Let $\omega$ be a primitive $2^k$-th root of unity. We want all of the following evaluations:
$$(f(\omega^i))_{i=0}^{2^k-1} = (f(1), f(\omega), f(\omega^2), \ldots, f(\omega^{2^k-1}))$$
Possible approaches:
- 1) The most direct approach evaluates each point separately. [With $N=2^k$, each evaluation costs $O(N)$ (e.g. with Horner's rule), so the total cost is $O(N^2)$.] [Naive interpolate, evaluate, divide and related operations: https://github.com/aszepieniec/stark-anatomy/blob/master/code/univariate.py]
- 2) A smarter approach follows the FFT divide-and-conquer strategy and splits the polynomial into its even and odd terms. [Applied recursively with $N=2^k$, the cost is $O(N\cdot \log N)$.] [Interpolate, evaluate, divide and related operations accelerated with FFT/IFFT: https://github.com/aszepieniec/stark-anatomy/blob/master/code/ntt.py]
$$f(X)=f_E(X^2)+X\cdot f_O(X^2)$$
where:
$$f_E(X^2)=\frac{f(X)+f(-X)}{2}=\sum_{i=0}^{\frac{d+1}{2}-1}c_{2i}X^{2i}$$
$$f_O(X^2)=\frac{f(X)-f(-X)}{2X}=\sum_{i=0}^{\frac{d+1}{2}-1}c_{2i+1}X^{2i}$$
Therefore:
$$f(\omega^i)=f_E(\omega^{2i})+\omega^i\cdot f_O(\omega^{2i})$$
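As a quick sanity check of this decomposition, the sketch below evaluates a polynomial naively (approach 1) with Horner's rule. It then verifies $f(\omega^i)=f_E(\omega^{2i})+\omega^i f_O(\omega^{2i})$ at every point. The code works on plain Python integers modulo the NTT-friendly prime $P = 998244353 = 119\cdot 2^{23}+1$, whose multiplicative group is generated by 3. That prime is a demo assumption, and the helper names are hypothetical; neither comes from stark-anatomy.

```python
# Demo assumption: P = 998244353 = 119 * 2^23 + 1, and 3 generates the multiplicative group F_P^*.
P = 998244353

def root_of_unity(n):
    """Primitive n-th root of unity in F_P (n a power of two dividing 2^23)."""
    assert n & (n - 1) == 0 and (P - 1) % n == 0
    return pow(3, (P - 1) // n, P)

def horner(coeffs, x):
    """Evaluate sum(coeffs[i] * x^i) mod P with Horner's rule in O(len(coeffs))."""
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % P
    return acc

def naive_evaluate(coeffs, w, n):
    """Approach 1): evaluate separately at 1, w, ..., w^(n-1), costing O(n^2) in total."""
    return [horner(coeffs, pow(w, i, P)) for i in range(n)]

def check_even_odd_split(coeffs, w, n):
    """Check f(w^i) == f_E(w^(2i)) + w^i * f_O(w^(2i)) for every i in [0, n)."""
    f_even, f_odd = coeffs[0::2], coeffs[1::2]  # (c_0, c_2, ...) and (c_1, c_3, ...)
    for i in range(n):
        x = pow(w, i, P)
        x2 = x * x % P
        if horner(coeffs, x) != (horner(f_even, x2) + x * horner(f_odd, x2)) % P:
            return False
    return True

n = 8
w = root_of_unity(n)
coeffs = [3, 1, 4, 1, 5, 9, 2, 6]
print(check_even_odd_split(coeffs, w, n))  # True
# w^2 has order n/2, so the squared points repeat: w^(2(i + n/2)) == w^(2i).
print(all(pow(w, 2 * i, P) == pow(w, 2 * (i + n // 2), P) for i in range(n // 2)))  # True
```

The second check is the reason the recursion only needs evaluations of $f_E$ and $f_O$ on a domain of half the size.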
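Because $\omega^2$ is a primitive $2^{k-1}$-th root of unity, $\omega^{2(i+2^{k-1})}=\omega^{2i}$. So $f_E$ and $f_O$ only need to be evaluated on a domain of half the size, which is the same problem again at half the size; this is why the code below indexes with `i % half`. Applying the decomposition $f(X)=f_E(X^2)+X\cdot f_O(X^2)$ recursively, pass the coefficients $c_i$ as the argument `values` to `ntt`. It returns the evaluations $(f(\omega^i))_{i=0}^{2^k-1} = (f(1), f(\omega), f(\omega^2), \ldots, f(\omega^{2^k-1}))$. [Code: https://github.com/aszepieniec/stark-anatomy/blob/master/code/ntt.py]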
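```python
# FieldElement / Field / Polynomial come from the stark-anatomy repository (algebra.py, univariate.py).
# Note: in that code base, `^` on a FieldElement means exponentiation, not XOR.
from algebra import *
from univariate import *

def ntt( primitive_root, values ):
    assert(len(values) & (len(values) - 1) == 0), "cannot compute ntt of non-power-of-two sequence"
    if len(values) <= 1:
        return values
    field = values[0].field
    assert(primitive_root^len(values) == field.one()), "primitive root must be nth root of unity, where n is len(values)"
    assert(primitive_root^(len(values)//2) != field.one()), "primitive root is not primitive nth root of unity, where n is len(values)"
    half = len(values) // 2
    # Recurse on the odd- and even-indexed coefficients with omega^2, a primitive (n/2)-th root.
    odds = ntt(primitive_root^2, values[1::2])
    evens = ntt(primitive_root^2, values[::2])
    # f(omega^i) = f_E(omega^(2i)) + omega^i * f_O(omega^(2i)); omega^(2i) repeats with period n/2.
    return [evens[i % half] + (primitive_root^i) * odds[i % half] for i in range(len(values))]
```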
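The recursive `ntt` above slices the input into new lists at every level. A widely used alternative, which is not what the source uses, is the iterative in-place radix-2 Cooley-Tukey form. It first applies a bit-reversal permutation and then performs the same butterflies $f(\pm x) = f_E(x^2) \pm x\, f_O(x^2)$ bottom-up. The sketch below works on plain integers modulo the demo prime $P = 998244353$ (an assumption) and checks itself against the $O(N^2)$ definition. `ntt_iterative` and `naive_ntt` are hypothetical names.

```python
# Demo assumption: P = 998244353 = 119 * 2^23 + 1, with generator 3.
P = 998244353

def root_of_unity(n):
    """Primitive n-th root of unity in F_P (n a power of two dividing 2^23)."""
    assert n & (n - 1) == 0 and (P - 1) % n == 0
    return pow(3, (P - 1) // n, P)

def ntt_iterative(coeffs, w):
    """Return [f(w^0), ..., f(w^(n-1))], where n = len(coeffs) and w is a primitive n-th root."""
    a = [c % P for c in coeffs]
    n = len(a)
    assert n & (n - 1) == 0, "length must be a power of two"
    # Bit-reversal permutation: afterwards each stage combines adjacent sub-transforms.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j |= bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    # Butterfly stages for sub-transform sizes 2, 4, ..., n.
    length = 2
    while length <= n:
        w_len = pow(w, n // length, P)  # primitive length-th root of unity
        half = length // 2
        for start in range(0, n, length):
            t = 1  # t = w_len^k
            for k in range(half):
                u = a[start + k]                      # value from the "even" sub-transform
                v = a[start + k + half] * t % P       # twiddle times the "odd" sub-transform value
                a[start + k] = (u + v) % P            # point x:  E + x * O
                a[start + k + half] = (u - v) % P     # point -x: E - x * O
                t = t * w_len % P
        length <<= 1
    return a

def naive_ntt(coeffs, w):
    """O(n^2) reference: f(w^i) computed directly from the definition."""
    n = len(coeffs)
    return [sum(c * pow(w, i * j, P) for j, c in enumerate(coeffs)) % P for i in range(n)]

coeffs = [3, 1, 4, 1, 5, 9, 2, 6]
w = root_of_unity(len(coeffs))
print(ntt_iterative(coeffs, w) == naive_ntt(coeffs, w))  # True
```

Both forms perform the same $O(N\log N)$ butterflies. The iterative one only trades the recursion and list copies for an index permutation.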
3. Recovering polynomial coefficients with the IFFT
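The Inverse Fast Fourier Transform (IFFT) over a finite field is also called the Inverse Number Theoretic Transform (INTT). The IFFT converts the point-value representation back into the coefficient representation.
It is computed by running the forward transform with $\omega^{-1}$ in place of $\omega$ and then multiplying every output by $N^{-1}$. This works because, for $0\le i,l<N$, $\sum_{j=0}^{N-1}\omega^{j(i-l)}$ equals $N$ when $i=l$ and $0$ otherwise.
Hence IFFT(FFT(x)) = x: running the FFT and then the IFFT returns the original values, which is what the following test checks: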
```python
import os

def test_intt( ):
    field = Field.main()
    logn = 7
    n = 1 << logn
    primitive_root = field.primitive_nth_root(n)
    values = [field.sample(os.urandom(1)) for i in range(n)]
    # Forward transform followed by the inverse transform must be the identity.
    coeffs = ntt(primitive_root, values)
    values_again = intt(primitive_root, coeffs)
    assert(values == values_again), "inverse ntt is different from forward ntt"
```
```python
def intt( primitive_root, values ):
    assert(len(values) & (len(values) - 1) == 0), "cannot compute intt of non-power-of-two sequence"
    if len(values) <= 1: # also covers the empty list, like ntt
        return values
    field = values[0].field
    ninv = FieldElement(len(values), field).inverse()
    # Run the forward NTT with the inverse root omega^(-1) ...
    transformed_values = ntt(primitive_root.inverse(), values)
    # ... then multiply by n^(-1) to obtain the actual coefficients.
    return [ninv*tv for tv in transformed_values]
```
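To see why `intt` only swaps in $\omega^{-1}$ and scales by $N^{-1}$, the sketch below re-implements both transforms on plain integers modulo the demo prime $P = 998244353$. That prime is an assumption and is not the field stark-anatomy uses. The sketch checks the round trip and evaluates the orthogonality sum $\sum_{j}\omega^{jd}$ directly. The function names mirror the source's `ntt`/`intt`, but these are standalone illustrations, and `orthogonality_sum` is a hypothetical helper.

```python
# Demo assumption: P = 998244353 = 119 * 2^23 + 1, with generator 3.
P = 998244353

def ntt(w, values):
    """Recursive radix-2 NTT on plain ints mod P; w must be a primitive len(values)-th root."""
    n = len(values)
    if n <= 1:
        return list(values)
    half = n // 2
    evens = ntt(w * w % P, values[0::2])
    odds = ntt(w * w % P, values[1::2])
    out = [0] * n
    x = 1  # x = w^i
    for i in range(half):
        t = x * odds[i] % P
        out[i] = (evens[i] + t) % P          # f(w^i)
        out[i + half] = (evens[i] - t) % P   # f(w^(i + n/2)) = f(-w^i)
        x = x * w % P
    return out

def intt(w, values):
    """Inverse NTT: forward NTT with w^(-1), then multiply every entry by n^(-1)."""
    n = len(values)
    w_inv = pow(w, P - 2, P)   # Fermat inverse, valid because P is prime
    n_inv = pow(n, P - 2, P)
    return [n_inv * v % P for v in ntt(w_inv, values)]

def orthogonality_sum(w, n, d):
    """sum_{j=0}^{n-1} w^(j*d) mod P: equals n when d == 0 (mod n), else 0."""
    return sum(pow(w, j * d, P) for j in range(n)) % P

n = 16
w = pow(3, (P - 1) // n, P)
values = [(i * i + 7) % P for i in range(n)]
print(intt(w, ntt(w, values)) == values)  # True
print(orthogonality_sum(w, n, 0), orthogonality_sum(w, n, 5))  # 16 0
```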
4. Fast polynomial multiplication with the FFT and IFFT
Given polynomials $f(X), g(X)$, compute $h(X) = f(X) \cdot g(X)$, where $\mathsf{deg}(h(X)) < 2^k$ so that $2^k$ point values determine $h(X)$ uniquely:
- 1) Run the FFT on $f(X)$ and on $g(X)$ to obtain their point-value representations.
- 2) Multiply the point values of $f(X)$ and $g(X)$ element-wise. The result is the point-value representation of $h(X)$.
- 3) Run the IFFT on the products to obtain the coefficients of $h(X)$.
```python
def fast_multiply( lhs, rhs, primitive_root, root_order ):
    assert(primitive_root^root_order == primitive_root.field.one()), "supplied root does not have supplied order"
    assert(primitive_root^(root_order//2) != primitive_root.field.one()), "supplied root is not primitive root of supplied order"
    if lhs.is_zero() or rhs.is_zero():
        return Polynomial([])
    field = lhs.coefficients[0].field
    root = primitive_root
    order = root_order
    degree = lhs.degree() + rhs.degree()
    if degree < 8: # For low degrees, multiply the polynomials directly.
        return lhs * rhs
    while degree < order // 2: # Shrink to the smallest power-of-two order above degree, with the matching root.
        root = root^2
        order = order // 2
    lhs_coefficients = lhs.coefficients[:(lhs.degree()+1)]
    while len(lhs_coefficients) < order: # Zero-pad up to order coefficients.
        lhs_coefficients += [field.zero()]
    rhs_coefficients = rhs.coefficients[:(rhs.degree()+1)]
    while len(rhs_coefficients) < order: # Zero-pad up to order coefficients.
        rhs_coefficients += [field.zero()]
    # 1) Run the FFT on f(X) and g(X) to obtain their point-value representations.
    lhs_codeword = ntt(root, lhs_coefficients)
    rhs_codeword = ntt(root, rhs_coefficients)
    # 2) Multiply the point values element-wise: the point-value representation of h(X).
    hadamard_product = [l * r for (l, r) in zip(lhs_codeword, rhs_codeword)]
    # 3) Run the IFFT on the products to obtain the coefficients of h(X).
    product_coefficients = intt(root, hadamard_product)
    # Only the first degree+1 entries can be nonzero; the rest are 0.
    return Polynomial(product_coefficients[0:(degree+1)])
```
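The same three steps can be sketched without the stark-anatomy classes. This standalone version works on plain coefficient lists modulo the demo prime $P=998244353$, which is an assumption. It picks the smallest power-of-two order above $\deg(h)$, as the `while degree < order // 2` loop does, and is checked against schoolbook multiplication. Unlike `fast_multiply`, it has no low-degree fallback. `fast_mul` and `schoolbook_mul` are hypothetical names.

```python
# Demo assumption: P = 998244353 = 119 * 2^23 + 1, with generator 3.
P = 998244353

def ntt(w, values):
    """Recursive radix-2 NTT on plain ints mod P; w must be a primitive len(values)-th root."""
    n = len(values)
    if n <= 1:
        return list(values)
    half = n // 2
    evens, odds = ntt(w * w % P, values[0::2]), ntt(w * w % P, values[1::2])
    out, x = [0] * n, 1
    for i in range(half):
        t = x * odds[i] % P
        out[i], out[i + half] = (evens[i] + t) % P, (evens[i] - t) % P
        x = x * w % P
    return out

def intt(w, values):
    """Inverse NTT: forward NTT with w^(-1), scaled by n^(-1)."""
    n_inv = pow(len(values), P - 2, P)
    return [n_inv * v % P for v in ntt(pow(w, P - 2, P), values)]

def schoolbook_mul(a, b):
    """Reference product of coefficient lists (low degree first), O(len(a) * len(b))."""
    if not a or not b:
        return []
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % P
    return out

def fast_mul(a, b):
    """h = f * g via NTT, element-wise product, INTT (coefficient lists, low degree first)."""
    if not a or not b:
        return []
    degree = (len(a) - 1) + (len(b) - 1)          # degree of the product
    order = 1
    while order <= degree:                        # smallest power of two with degree < order
        order *= 2
    w = pow(3, (P - 1) // order, P)               # primitive order-th root of unity
    a_vals = ntt(w, a + [0] * (order - len(a)))   # 1) zero-pad and evaluate f
    b_vals = ntt(w, b + [0] * (order - len(b)))   #    zero-pad and evaluate g
    h_vals = [x * y % P for x, y in zip(a_vals, b_vals)]  # 2) Hadamard product
    return intt(w, h_vals)[:degree + 1]           # 3) back to coefficients

print(fast_mul([1, 2, 3], [4, 5]))  # [4, 13, 22, 15]
```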
5. Fast zerofier via fast multiplication
Again using divide-and-conquer, `fast_multiply` can be called recursively to compute the zerofier (also called the vanishing polynomial):
$$z(X)=(X-1)(X-\omega)\cdots(X-\omega^{2^k-1})\quad\text{[roots-of-unity domain]}$$
or
$$z(X)=(X-x_0)(X-x_1)\cdots(X-x_{d})\quad\text{[arbitrary domain]}$$
The steps are:
- 1) Split the domain into two equal halves, left and right.
- 2) Compute the zerofier of each half.
- 3) Multiply the two zerofiers with fast multiplication.
```python
def fast_zerofier( domain, primitive_root, root_order ):
    assert(primitive_root^root_order == primitive_root.field.one()), "supplied root does not have supplied order"
    assert(primitive_root^(root_order//2) != primitive_root.field.one()), "supplied root is not primitive root of supplied order"
    if len(domain) == 0:
        return Polynomial([])
    if len(domain) == 1:
        # Base case: the linear factor X - x_0 (coefficients from low to high degree).
        return Polynomial([-domain[0], primitive_root.field.one()])
    # 1) Split the domain, 2) recurse on each half, 3) combine with fast multiplication.
    half = len(domain) // 2
    left = fast_zerofier(domain[:half], primitive_root, root_order)
    right = fast_zerofier(domain[half:], primitive_root, root_order)
    return fast_multiply(left, right, primitive_root, root_order)
```
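The divide-and-conquer structure can be checked on a small example. In this standalone sketch, schoolbook multiplication stands in for `fast_multiply` to keep it short. Coefficients are plain integers modulo the demo prime $P=998244353$, which is an assumption, and `zerofier`/`poly_mul` are hypothetical names. For the full roots-of-unity domain the result should be $X^{2^k}-1$: every $\omega^i$ is a root of $X^{2^k}-1$, and both polynomials are monic of degree $2^k$.

One difference from `fast_zerofier` is the empty domain. This sketch returns the constant $1$ (the empty product), while `fast_zerofier` returns the zero polynomial. For a nonempty domain the recursion never reaches that case.

```python
# Demo assumption: P = 998244353 = 119 * 2^23 + 1, with generator 3.
P = 998244353

def poly_mul(a, b):
    """Schoolbook product (stand-in for fast_multiply in this sketch); low degree first."""
    if not a or not b:
        return []
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % P
    return out

def zerofier(domain):
    """Coefficients of prod_{x in domain} (X - x), built by splitting the domain in half."""
    if len(domain) == 0:
        return [1]                       # empty product: the constant polynomial 1
    if len(domain) == 1:
        return [(-domain[0]) % P, 1]     # X - x_0
    half = len(domain) // 2
    return poly_mul(zerofier(domain[:half]), zerofier(domain[half:]))

n = 8
w = pow(3, (P - 1) // n, P)              # primitive 8th root of unity
z = zerofier([pow(w, i, P) for i in range(n)])
print(z == [P - 1] + [0] * (n - 1) + [1])  # True: z(X) = X^8 - 1
print(zerofier([2, 5]))                  # [10, 998244346, 1], i.e. X^2 - 7X + 10
```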
6. Fast evaluation over an arbitrary domain via the fast zerofier
For an arbitrary domain $\{x_0,x_1,\cdots,x_d\}$, evaluate $f(X)$ at every point:
$$f(x_0),f(x_1),\cdots, f(x_d)$$
By Lagrange interpolation, $f(X)$ can be written as:
$$f(X)=\sum_{i=0}^{d}f(x_i)\frac{\prod_{j=0,j\neq i}^{d}(X-x_j)}{\prod_{j=0,j\neq i}^{d}(x_i-x_j)}$$
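To compute $[f(x_0),f(x_1),\cdots, f(x_d)]$ quickly, split the domain into two halves with zerofiers $z_L$ and $z_R$. Write $f = q\cdot z_L + r_L$ with $r_L = f \bmod z_L$. Every $x$ in the left half has $z_L(x)=0$ and therefore $f(x)=r_L(x)$. The same holds for the right half with $z_R$. The fast evaluate algorithm therefore recurses on the lower-degree remainders: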
```python
def fast_evaluate( polynomial, domain, primitive_root, root_order ):
    assert(primitive_root^root_order == primitive_root.field.one()), "supplied root does not have supplied order"
    assert(primitive_root^(root_order//2) != primitive_root.field.one()), "supplied root is not primitive root of supplied order"
    if len(domain) == 0:
        return []
    if len(domain) == 1: # Return a list, even for a single point.
        return [polynomial.evaluate(domain[0])]
    half = len(domain) // 2
    left_zerofier = fast_zerofier(domain[:half], primitive_root, root_order)
    right_zerofier = fast_zerofier(domain[half:], primitive_root, root_order)
    # On each half, f agrees with its remainder modulo that half's zerofier.
    left = fast_evaluate(polynomial % left_zerofier, domain[:half], primitive_root, root_order)
    right = fast_evaluate(polynomial % right_zerofier, domain[half:], primitive_root, root_order)
    return left + right
```
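The remainder argument can be tested in isolation. This standalone sketch follows the same recursion on plain integers modulo the demo prime $P=998244353$, which is an assumption. It uses schoolbook arithmetic in place of the FFT-based helpers and a hypothetical `poly_mod` that only divides by monic polynomials, which is enough because zerofiers are monic. Like the source, it recomputes the zerofiers at every level instead of caching them.

```python
# Demo assumption: P = 998244353 = 119 * 2^23 + 1.
P = 998244353

def horner(coeffs, x):
    """Evaluate sum(coeffs[i] * x^i) mod P."""
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % P
    return acc

def poly_mul(a, b):
    """Schoolbook product of coefficient lists (low degree first)."""
    if not a or not b:
        return []
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % P
    return out

def zerofier(domain):
    """Coefficients of prod_{x in domain} (X - x), by divide and conquer."""
    if len(domain) <= 1:
        return [1] if not domain else [(-domain[0]) % P, 1]
    half = len(domain) // 2
    return poly_mul(zerofier(domain[:half]), zerofier(domain[half:]))

def poly_mod(f, z):
    """Remainder of f modulo the monic polynomial z (coefficient lists, low degree first)."""
    assert z and z[-1] == 1, "this sketch only divides by monic polynomials such as zerofiers"
    r = [c % P for c in f]
    dz = len(z) - 1
    for i in range(len(r) - 1, dz - 1, -1):   # cancel the top coefficient, highest degree first
        coef = r[i]
        if coef:
            for j in range(dz + 1):
                r[i - dz + j] = (r[i - dz + j] - coef * z[j]) % P
    return r[:dz]

def fast_evaluate(f, domain):
    """[f(x) for x in domain], recursing on remainders modulo the half-domain zerofiers."""
    if len(domain) == 0:
        return []
    if len(domain) == 1:
        return [horner(f, domain[0])]
    half = len(domain) // 2
    left, right = domain[:half], domain[half:]
    # z_L vanishes on `left`, so f(x) = (f mod z_L)(x) there; likewise for the right half.
    return (fast_evaluate(poly_mod(f, zerofier(left)), left)
            + fast_evaluate(poly_mod(f, zerofier(right)), right))

f = [3, 1, 4, 1, 5, 9, 2, 6]
domain = [2, 7, 11, 100, 5]
print(fast_evaluate(f, domain) == [horner(f, x) for x in domain])  # True
```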
7. Fast interpolation via the fast zerofier and fast evaluate
By Lagrange interpolation, $f(X)$ can also be written as:
$$f(X)=\sum_{i=0}^{d}\frac{f(x_i)}{\prod_{j=0,j\neq i}^{d}(x_i-x_j)}\prod_{j=0,j\neq i}^{d}(X-x_j)$$
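In `fast_interpolate`, the denominators $\prod_{j=0,j\neq i}^{d}(x_i-x_j)$ are assembled half by half through the variables `left_offset` and `right_offset`. For a point $x_i$ in the left half, its entry in `left_offset` is $z_R(x_i)$, where $z_R$ is `right_zerofier`. These are the factors of the denominator that come from the right half. `right_offset` plays the same role for the right half, and the remaining factors are handled inside the recursive calls.

After the values are divided by these offsets and each half is interpolated, `left_interpolant * right_zerofier + right_interpolant * left_zerofier` takes the original values on both halves, because each zerofier vanishes on its own half: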
```python
def fast_interpolate( domain, values, primitive_root, root_order ):
    assert(primitive_root^root_order == primitive_root.field.one()), "supplied root does not have supplied order"
    assert(primitive_root^(root_order//2) != primitive_root.field.one()), "supplied root is not primitive root of supplied order"
    assert(len(domain) == len(values)), "cannot interpolate over domain of different length than values list"
    if len(domain) == 0:
        return Polynomial([])
    if len(domain) == 1:
        return Polynomial([values[0]])
    half = len(domain) // 2
    left_zerofier = fast_zerofier(domain[:half], primitive_root, root_order)
    right_zerofier = fast_zerofier(domain[half:], primitive_root, root_order)
    # The part of each Lagrange denominator that comes from the other half of the domain.
    left_offset = fast_evaluate(right_zerofier, domain[:half], primitive_root, root_order)
    right_offset = fast_evaluate(left_zerofier, domain[half:], primitive_root, root_order)
    # A zero offset means a repeated domain point, for which interpolation is impossible.
    assert(all(not v.is_zero() for v in left_offset + right_offset)), "domain points must be distinct"
    left_targets = [n / d for (n,d) in zip(values[:half], left_offset)]
    right_targets = [n / d for (n,d) in zip(values[half:], right_offset)]
    left_interpolant = fast_interpolate(domain[:half], left_targets, primitive_root, root_order)
    right_interpolant = fast_interpolate(domain[half:], right_targets, primitive_root, root_order)
    return left_interpolant * right_zerofier + right_interpolant * left_zerofier
```
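A standalone check of the offset-and-recombine step, under the same assumptions as the earlier sketches: plain integers modulo the demo prime $P=998244353$, schoolbook multiplication instead of `fast_multiply`, and hypothetical helper names. Interpolating the values of a known polynomial of degree below the number of points should return exactly its coefficients.

```python
# Demo assumption: P = 998244353 = 119 * 2^23 + 1.
P = 998244353

def horner(coeffs, x):
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % P
    return acc

def poly_mul(a, b):
    if not a or not b:
        return []
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % P
    return out

def poly_add(a, b):
    n = max(len(a), len(b))
    a, b = a + [0] * (n - len(a)), b + [0] * (n - len(b))
    return [(x + y) % P for x, y in zip(a, b)]

def zerofier(domain):
    if len(domain) <= 1:
        return [1] if not domain else [(-domain[0]) % P, 1]
    half = len(domain) // 2
    return poly_mul(zerofier(domain[:half]), zerofier(domain[half:]))

def poly_mod(f, z):
    """Remainder of f modulo a monic z (low degree first)."""
    r, dz = [c % P for c in f], len(z) - 1
    for i in range(len(r) - 1, dz - 1, -1):
        coef = r[i]
        for j in range(dz + 1):
            r[i - dz + j] = (r[i - dz + j] - coef * z[j]) % P
    return r[:dz]

def fast_evaluate(f, domain):
    if len(domain) <= 1:
        return [horner(f, x) for x in domain]
    half = len(domain) // 2
    left, right = domain[:half], domain[half:]
    return (fast_evaluate(poly_mod(f, zerofier(left)), left)
            + fast_evaluate(poly_mod(f, zerofier(right)), right))

def fast_interpolate(domain, values):
    """Coefficients of the unique f with deg f < len(domain) and f(domain[i]) == values[i]."""
    assert len(domain) == len(values)
    if len(domain) == 0:
        return []
    if len(domain) == 1:
        return [values[0] % P]
    half = len(domain) // 2
    left, right = domain[:half], domain[half:]
    z_left, z_right = zerofier(left), zerofier(right)
    # left_offset[i] = z_R(x_i): the right half's factors of the Lagrange denominator.
    left_offset = fast_evaluate(z_right, left)
    right_offset = fast_evaluate(z_left, right)
    assert all(left_offset) and all(right_offset), "domain points must be distinct"
    left_targets = [v * pow(d, P - 2, P) % P for v, d in zip(values[:half], left_offset)]
    right_targets = [v * pow(d, P - 2, P) % P for v, d in zip(values[half:], right_offset)]
    # On `left`, z_L = 0 and I_L * z_R = values; symmetrically on `right`.
    return poly_add(poly_mul(fast_interpolate(left, left_targets), z_right),
                    poly_mul(fast_interpolate(right, right_targets), z_left))

domain = [1, 2, 3, 5, 8]
f = [3, 1, 4, 1, 5]
values = [horner(f, x) for x in domain]
print(fast_interpolate(domain, values) == f)  # True
```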
8. Fast coset evaluation with the FFT
Coset evaluation means: for a polynomial $f(X)=\sum_{i=0}^{2^k-1}c_iX^i$, compute
$$f(g\cdot 1), f(g\cdot \omega), \cdots, f(g\cdot \omega^{2^k-1})$$
[This corresponds to Section 4.3, "Coset-FRI", of the earlier blog post "STARK Basics" (STARK入门知识).]
The algorithm is:
- 1) Scale $f(X)$ to $f(g\cdot X)=\sum_{i=0}^{2^k-1}(c_i\cdot g^i)\cdot X^i$, i.e. multiply the $i$-th coefficient by $g^i$.
- 2) Use the FFT to batch-evaluate $f(g\cdot X)$ over the roots-of-unity domain. The value at $\omega^i$ is $f(g\cdot\omega^i)$.
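```python
def fast_coset_evaluate( polynomial, offset, generator, order ):
    assert(len(polynomial.coefficients) <= order), "polynomial has more coefficients than evaluation points"
    # 1) Scale: the i-th coefficient c_i becomes c_i * offset^i, giving f(offset * X).
    scaled_polynomial = polynomial.scale(offset)
    # 2) Zero-pad to `order` coefficients and run the NTT; entry i equals f(offset * generator^i).
    values = ntt(generator, scaled_polynomial.coefficients + [offset.field.zero()] * (order - len(polynomial.coefficients)))
    return values
```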
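The scale-then-NTT idea can be checked directly against pointwise evaluation on the coset. This standalone sketch uses plain integers modulo the demo prime $P=998244353$ and an arbitrary offset $g=7$; both are assumptions. `scale` here does what step 1) describes for `Polynomial.scale`, and all names are hypothetical.

```python
# Demo assumption: P = 998244353 = 119 * 2^23 + 1, with generator 3.
P = 998244353

def ntt(w, values):
    """Recursive radix-2 NTT on plain ints mod P; w must be a primitive len(values)-th root."""
    n = len(values)
    if n <= 1:
        return list(values)
    half = n // 2
    evens, odds = ntt(w * w % P, values[0::2]), ntt(w * w % P, values[1::2])
    out, x = [0] * n, 1
    for i in range(half):
        t = x * odds[i] % P
        out[i], out[i + half] = (evens[i] + t) % P, (evens[i] - t) % P
        x = x * w % P
    return out

def horner(coeffs, x):
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % P
    return acc

def scale(coeffs, g):
    """Coefficients of f(g * X): the i-th coefficient c_i becomes c_i * g^i."""
    return [c * pow(g, i, P) % P for i, c in enumerate(coeffs)]

def coset_evaluate(coeffs, g, n):
    """[f(g * w^i) for i in range(n)], w a primitive n-th root, via scale-then-NTT."""
    assert len(coeffs) <= n, "need at least as many points as coefficients"
    w = pow(3, (P - 1) // n, P)
    return ntt(w, scale(coeffs, g) + [0] * (n - len(coeffs)))

coeffs = [3, 1, 4, 1, 5]
g, n = 7, 8                                  # hypothetical offset g = 7
w = pow(3, (P - 1) // n, P)
direct = [horner(coeffs, g * pow(w, i, P) % P) for i in range(n)]
print(coset_evaluate(coeffs, g, n) == direct)  # True
```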
9. Fast coset division with the FFT and IFFT
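To compute $h(X)=\frac{f(X)}{z(X)}$, compute instead $h(g\cdot X)=\frac{f(g\cdot X)}{z(g\cdot X)}$. Then multiply the $i$-th coefficient of $h(g\cdot X)$ by $g^{-i}$ (i.e. scale by $g^{-1}$) to obtain the coefficients of $h(X)$.

The coset matters when $z(X)$ vanishes on points of the roots-of-unity domain, for example when $z$ is the zerofier of a subgroup of it. Element-wise division on the domain itself would then divide by zero, whereas on the coset $z(g\cdot\omega^i)\neq 0$ for a suitable offset $g$.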
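The implementation steps are:
- 1) Scale the coefficients of both the numerator and the denominator
- 2) NTT
- 3) Element-wise division
- 4) Inverse NTT
- 5) Unscale the resulting quotient coefficients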
```python
# Only handles exact division (zero remainder).
def fast_coset_divide( lhs, rhs, offset, primitive_root, root_order ): # clean division only!
    assert(primitive_root^root_order == primitive_root.field.one()), "supplied root does not have supplied order"
    assert(primitive_root^(root_order//2) != primitive_root.field.one()), "supplied root is not primitive root of supplied order"
    assert(not rhs.is_zero()), "cannot divide by zero polynomial"
    if lhs.is_zero():
        return Polynomial([])
    assert(rhs.degree() <= lhs.degree()), "cannot divide by polynomial of larger degree"
    field = lhs.coefficients[0].field
    root = primitive_root
    order = root_order
    degree = max(lhs.degree(),rhs.degree())
    if degree < 8: # For low degrees, divide the polynomials directly.
        return lhs / rhs
    while degree < order // 2: # Shrink to the smallest power-of-two order above degree.
        root = root^2
        order = order // 2
    # 1) Scale the coefficients of numerator and denominator
    scaled_lhs = lhs.scale(offset)
    scaled_rhs = rhs.scale(offset)
    lhs_coefficients = scaled_lhs.coefficients[:(lhs.degree()+1)]
    while len(lhs_coefficients) < order:
        lhs_coefficients += [field.zero()]
    rhs_coefficients = scaled_rhs.coefficients[:(rhs.degree()+1)]
    while len(rhs_coefficients) < order:
        rhs_coefficients += [field.zero()]
    # 2) NTT
    lhs_codeword = ntt(root, lhs_coefficients)
    rhs_codeword = ntt(root, rhs_coefficients)
    # 3) Element-wise division
    quotient_codeword = [l / r for (l, r) in zip(lhs_codeword, rhs_codeword)]
    # 4) Inverse NTT
    scaled_quotient_coefficients = intt(root, quotient_codeword)
    # Only the first (lhs.degree() - rhs.degree() + 1) entries can be nonzero; the rest are 0.
    scaled_quotient = Polynomial(scaled_quotient_coefficients[:(lhs.degree() - rhs.degree() + 1)])
    # 5) Unscale the quotient coefficients
    return scaled_quotient.scale(offset.inverse())
```
The corresponding test:
```python
import os

def test_divide( ):
    field = Field.main()
    logn = 6
    n = 1 << logn
    primitive_root = field.primitive_nth_root(n)
    for trial in range(20):
        lhs_degree = int(os.urandom(1)[0]) % (n // 2)
        rhs_degree = int(os.urandom(1)[0]) % (n // 2)
        lhs = Polynomial([field.sample(os.urandom(17)) for i in range(lhs_degree+1)])
        rhs = Polynomial([field.sample(os.urandom(17)) for i in range(rhs_degree+1)])
        fast_product = fast_multiply(lhs, rhs, primitive_root, n)
        quotient = fast_coset_divide(fast_product, lhs, field.generator(), primitive_root, n)
        assert(quotient == rhs), "fast divide does not equal original factor"
```
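A toy example shows why the division moves to a coset. Take $z(X)=X^4-1$, which vanishes on the 4th roots of unity. Element-wise division on the plain size-8 domain $\langle\omega\rangle$ hits $z(1)=0$. On the coset $g\cdot\langle\omega\rangle$ with $g=3$, every value $z(g\omega^i)=81\,\omega^{4i}-1$ is $80$ or $-82$, so none is zero. The standalone sketch follows the five steps above on plain integers modulo the demo prime $P=998244353$. The prime, the offset and all function names are illustrative assumptions, not stark-anatomy's API.

```python
# Demo assumption: P = 998244353 = 119 * 2^23 + 1, with generator 3.
P = 998244353

def ntt(w, values):
    """Recursive radix-2 NTT on plain ints mod P; w must be a primitive len(values)-th root."""
    n = len(values)
    if n <= 1:
        return list(values)
    half = n // 2
    evens, odds = ntt(w * w % P, values[0::2]), ntt(w * w % P, values[1::2])
    out, x = [0] * n, 1
    for i in range(half):
        t = x * odds[i] % P
        out[i], out[i + half] = (evens[i] + t) % P, (evens[i] - t) % P
        x = x * w % P
    return out

def intt(w, values):
    n_inv = pow(len(values), P - 2, P)
    return [n_inv * v % P for v in ntt(pow(w, P - 2, P), values)]

def scale(coeffs, g):
    """Coefficients of f(g * X): c_i -> c_i * g^i."""
    return [c * pow(g, i, P) % P for i, c in enumerate(coeffs)]

def poly_mul(a, b):
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % P
    return out

def coset_divide(f, z, g, n):
    """Exact quotient f / z from evaluations on the coset g * <w>, with |<w>| = n > deg f."""
    assert len(f) <= n
    w = pow(3, (P - 1) // n, P)
    f_vals = ntt(w, scale(f, g) + [0] * (n - len(f)))    # 1) scale, 2) NTT: f(g * w^i)
    z_vals = ntt(w, scale(z, g) + [0] * (n - len(z)))    #    z(g * w^i)
    assert all(z_vals), "z vanishes on the coset; choose another offset"
    q_vals = [a * pow(b, P - 2, P) % P for a, b in zip(f_vals, z_vals)]  # 3) divide
    scaled_q = intt(w, q_vals)[: len(f) - len(z) + 1]     # 4) INTT: coefficients of q(g * X)
    return scale(scaled_q, pow(g, P - 2, P))              # 5) unscale by g^(-1)

n = 8
w = pow(3, (P - 1) // n, P)
z = [P - 1, 0, 0, 0, 1]                 # z(X) = X^4 - 1
q = [2, 0, 5, 1]
f = poly_mul(q, z)                      # f = q * z, so the division is exact
print(ntt(w, z + [0] * 3)[0])           # 0: z(1) = 0, so dividing on <w> itself would fail
print(coset_divide(f, z, 3, n) == q)    # True with offset g = 3
```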
References
[1] Anatomy of a STARK, Part 6: Speeding Things Up
[2] Filecoin zk-SNARK Accelerating
[3] GPU-SNARKs
[4] Plonk with GPU acceleration
[5] ZKSwap GPU optimization