Day10:字符串相乘(Karatsuba 乘法)

leetcode地址:https://leetcode-cn.com/problems/multiply-strings/

Day10:字符串相乘(Karatsuba 乘法)

一. 问题背景:

           

因为字符串相乘问题涉及字符串相加,因此先引入字符串相加问题。

    leetcode地址:https://leetcode-cn.com/problems/add-strings/

            

 

二. 解决思路:

1. 字符串相加

    模拟竖式相加。将输入字符串逆序并返回字符对应数字的ASCII 数值或 Unicode 数值,对较短字符串补‘0’。相加,分别保留进位和结果。

2.字符串相乘(Karatsuba 乘法)

    模拟竖式乘法,并使用Karatsuba 乘法提升效率。

                                                    

    Karatsuba 乘法是一种快速乘法。此算法在1960年由 Anatolii Alexeevitch Karatsuba 提出,并于1962年得以发表。此算法主要用于两个大数相乘。普通乘法的复杂度是 n²,而 Karatsuba 算法的复杂度仅为 3nlog2(3) 

    可以注意到 AD +BC 这个计算需要两个 O(n²/4) 的乘法和一个 O(n) 的加法。

                     (A+B)(C+D)-AC-BD=AC+AD+BC+BD-AC-BD=AD+BC

    Karatsuba 把这个原始的计算改成上面这个计算,因为 AC 和 BD 是已知的,因此现在的 AD +BC 的时间复杂度变成了 一个 O(n²/4) 的乘法和四个 O(n) 的加法或减法,达到 O(n^log2(3)) 的时间复杂度。

三. 算法实现:

1. 字符串相加

#1 借用int()处理每一位字符
def add(x, y):
    if len(x) > len(y):
       y = '0'*(len(x)-len(y)) + y
    else:
       x = '0'*(len(y)-len(x)) + x
       
    x, y = x[::-1], y[::-1]
    carry = 0
    z = ''
    
    for i in range(len(x)):
        z += str((int(x[i]) + int(y[i]) + carry) % 10)
        carry = (int(x[i]) + int(y[i]) + carry) // 10
    else:
        if carry == 1:
            z += '1'
    
    return z[::-1]


x, y = '9', '99'
print(add(x, y))
    

#2 使用ord()返回字符对应ASCII或Unicode
def addStrings(num1: str, num2: str) -> str:
    a = [ord(i)-ord('0') for i in list(num1[::-1])]  #返回对应的 ASCII 数值,或者 Unicode 数值
    b = [ord(i)-ord('0') for i in list(num2[::-1])]
    n1,n2 = len(num1),len(num2)
    n = max(n1,n2)
    c = [0]*(n+1)
    #print(c)
    for i in range(n):
        x = a[i] if i<n1 else 0
        y = b[i] if i<n2 else 0         #补0
        j = x+y+c[i]
        c[i] = j%10
        c[i+1] = j//10
    c = [chr(ord('0')+i) for i in c[::-1]]
    num3 = ''.join(c)
    #print(n,c,num3)
    if n>=1 and num3[0]=='0':
        return num3[1:]
    else:
        return num3
    
print(addStrings('6994','36'))


#3 双指针法
def addStrings(num1: str, num2: str) -> str:
    res = ""
    i, j, carry = len(num1) - 1, len(num2) - 1, 0
    while i >= 0 or j >= 0:
        n1 = ord(num1[i]) - ord('0') if i >= 0 else 0
        n2 = ord(num2[j]) - ord('0') if j >= 0 else 0
        tmp = n1 + n2 + carry
        carry = tmp // 10
        res = str(tmp % 10) + res
        i, j = i - 1, j - 1
    if carry: res = "1" + res
    return res

print(addStrings('6994','36'))

2. 字符串相乘

#karatsuba算法
from time import perf_counter
def add(num1, num2):
    z, carry = [], 0
    #补齐长度
    n = len(num1) - len(num2)
    if n > 0:
        num2 += [0]*n
    else:
        num1 += [0]*-n
    
    for i in range(len(num1)):
        tmp = num1[i] + num2[i] + carry
        z += [tmp % 10]
        carry = tmp // 10
    if carry:
        z += [carry]
    
    return z


def sub(num1, num2):
    z, carry = [], 0
    #补齐长度
    n = len(num1) - len(num2)
    if n > 0:
        num2 += [0]*n
        
    for i in range(len(num1)):
        tmp = num1[i] - num2[i] + carry
        z += [tmp % 10]
        carry = tmp // 10
#    while(not z[-1] and len(z) > 1):
#        del z[-1]
            
    return z

def karatsuba(num1, num2):    
    #补齐长度
    n = len(num1) - len(num2)
    if n > 0:
        num2 += [0]*n
    else:
        num1 += [0]*-n
    
    #基链
    if len(num1) == 1:
        return add([num1[0]*num2[0]],[0])
    
    #分割
    n_2 = (len(num1) + 1) >> 1
    A, B = num1[n_2:], num1[:n_2]
    C, D = num2[n_2:], num2[:n_2]
    
    # karatsuba乘法计算
    tmp0 = karatsuba(A, C)
    tmp1 = karatsuba(B, D)
    tmp2 = karatsuba(add(A, B), add(C, D))
    tmp2 = sub(sub(tmp2, tmp0), tmp1)
    #print(tmp2)
    #print(tmp2)
    
    z = add(tmp1, [0] * (n_2 << 1) + tmp0)   #移位相加
    z = add(z, [0] * n_2 + tmp2)
    
    return z

def mult(num1, num2):
    #返回对应的 ASCII 数值,或者 Unicode 数值
    num1 = [ord(i)-ord('0') for i in list(num1[::-1])]  
    num2 = [ord(i)-ord('0') for i in list(num2[::-1])]
    z = karatsuba(num1, num2)        
    while len(z) > 1 and z[-1] == 0:
        del z[-1]
    return ''.join(map(str, reversed(z)))

#print(mult('2345','123'))

x, y = '2345', '123'
start = perf_counter()
for i in range(100):
    mult(x,y)
print(mult(x,y))
print("运行时间是: {:.5f}s".format(perf_counter() - start))





#模拟竖式相乘
def multiply(num1: str, num2: str) -> str:
    if num1 == "0" or num2 == "0":
        return "0"
    l1 = len(num1)
    l2 = len(num2)
    mul = [0] * (l1 + l2)
    v0 = ord('0')
    for i in range(l1 - 1, -1, -1):
        for j in range(l2 -1, -1, -1):
            bitmul = (ord(num1[i]) - v0) * (ord(num2[j]) - v0)
            bitmul += mul[i + j + 1]
            mul[i + j] += bitmul // 10
            mul[i + j + 1] = bitmul % 10
    n = 0
    while n < len(mul):
        if mul[n] != 0:
            break
        n += 1
    s = ''
    for i in range(n, len(mul)):
        s += str(mul[i])
    return s

x, y = '2345', '123'
start = perf_counter()
for i in range(100):
    multiply(x,y)
print(multiply(x,y))
print("运行时间是: {:.5f}s".format(perf_counter() - start))

 

注:关于算法复杂度

比如一种算法的时间复杂度是 O(N),那么运行时间就是正比于要素个数N,另一种排序算法的时间复杂度是O(N*LogN),那么运行时间就正比于N*LogN。所以N足够大的情况下,总是第一种算法快。

但是,如果N不是很大,那么具体的运算时间并不一定都是前一种算法快,比如刚才的第一种算法的实际速度是 100×N, 第二种算法的实际速度是 2× N × LogN,N=100,就会是第二种算法快。

猜你喜欢

转载自blog.csdn.net/A994850014/article/details/95177369