[Numerical Analysis] Implementing Lagrange Interpolation in Python

I have been meaning to implement these interpolation formulas in code for a while. With nothing else to do today, I gave it a try.

Let's start with the simplest one: Lagrange interpolation! I won't go into the details of the Lagrange interpolation formula; a quick search on Baidu turns up plenty.
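
For reference, given n+1 sample points (x_0, y_0), ..., (x_n, y_n) with distinct x-coordinates, the standard form of the Lagrange interpolating polynomial is shown below. The lagrangeFit and showbasis functions in the code compute exactly these basis terms l_i(x).

$$
L_n(x) = \sum_{i=0}^{n} y_i \, \ell_i(x), \qquad
\ell_i(x) = \prod_{\substack{j=0 \\ j \neq i}}^{n} \frac{x - x_j}{x_i - x_j}
$$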

The basic idea is as follows. First, read the given sample points from a file. Then, based on the desired number of interpolation points and the prediction point x entered by the user, select a suitable group of sample points around x as the interpolation interval. Finally, compute the result from the basis functions. Let's go straight to the code! (Note: "sample points" is not a very precise term here, but I was really stuck and could not find a better description...)

str2double

One small problem is how to convert a Python str into a float. After all, the sample points we are given are not necessarily integers, and some fault tolerance is also needed: input with repeated signs such as ++ or -- should still be recognized as the correct number. So I implemented a str2double function.
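
Python's built-in float() already parses strings such as "3.14", "-2" or "+7", so a shorter route is possible: collapse a leading run of + or - signs and hand the rest to float(). The sketch below is a hypothetical alternative (the name str2double_float is made up for illustration), not the method used in this post. Like str2double, it treats a run of minus signs as a single minus; unlike str2double, it always returns a float, even for integer input.

```python
import re

def str2double_float(str_num):
    """Hypothetical alternative to str2double: validate the string, then
    let float() do the actual parsing."""
    s = str_num.strip()
    # Optional run of '+' or of '-', digits, optional '.digits'
    m = re.fullmatch(r'(\++|-+)?(\d+(?:\.\d+)?)', s)
    if m is None:
        return None  # not a recognizable number
    # A leading run of '-' means negative; '+' or no sign means positive
    sign = -1.0 if m.group(1) and m.group(1)[0] == '-' else 1.0
    return sign * float(m.group(2))
```

For example, str2double_float("+++3.5") gives 3.5 and str2double_float("1a5") gives None.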

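import re

def str2double(str_num):
    # Optional run of '+' or of '-', an integer part, optional '.fraction'.
    # The dot is escaped; a bare '.' would match any character (e.g. "1a5").
    pattern = re.compile(r'^((\+*)|(\-*))?(\d+)(\.(\d+))?$')
    m = pattern.match(str_num)
    if m is None:
        return None  # not a recognizable number
    # A leading '-' makes the number negative; '+' or a digit means positive
    sign = 1 if str_num[0] == '+' or '0' <= str_num[0] <= '9' else -1
    # Strip the sign characters so only digits (and the dot) remain
    num = re.sub(r'(\++)|(\-+)', "", m.group(0))
    matchObj = re.match(r'^\d+$', num)
    if matchObj is not None:
        # Integer: a plain int() cast is enough
        num = sign * int(matchObj.group(0))
    else:
        # Float: split into integer and fractional digits
        matchObj = re.match(r'^(\d+)\.(\d+)$', num)
        if matchObj is not None:
            integer = int(matchObj.group(1))
            fraction = int(matchObj.group(2)) * pow(10, -len(matchObj.group(2)))
            num = sign * (integer + fraction)
    return num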

I implemented it with regular expressions. The pattern compiled in str2double matches all the integer and floating-point forms mentioned above (in a regular expression the decimal point should be escaped as \., because a bare . matches any character). Once the match succeeds: if the number is an integer, a plain int() cast on the digits is enough; if it is a floating-point number, a second regular expression, (\d+)\.(\d+), splits it into its integer and fractional parts. The integer part is handled as above, the fractional part is multiplied by pow(10, -number of decimal digits), and the two are added. To support repeated + or - signs, the signs are removed with re.sub, so the sign has to be recorded beforehand and multiplied into the result at the end.
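
A possible refinement of the fractional-part step, offered as a sketch rather than as the code above: int(fraction) * pow(10, -k) involves a rounded power, a rounded product and a rounded sum, so the result may differ from float() in the last bit. Joining the integer and fractional digits into a single integer and dividing once by 10**k avoids that, because Python's int / int true division is correctly rounded. The helper name digits_to_float is hypothetical.

```python
def digits_to_float(sign, int_digits, frac_digits):
    """Hypothetical helper: combine the pieces str2double extracts.

    "3" and "14" become the integer 314, which is divided once by 10**2.
    One correctly rounded division gives the same value as float("3.14").
    """
    k = len(frac_digits)                       # number of decimal digits
    value = int(int_digits + frac_digits) / 10 ** k
    return sign * value                        # multiplying by +/-1 is exact
```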

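def binary_search(point_set, n, x):
    # point_set must be sorted by x coordinate in ascending order
    length = len(point_set)
    n = min(n, length)  # cannot choose more points than we have
    # Binary search for the first sample point whose x is >= the prediction point
    first = 0
    last = length
    while first < last:
        mid = (first + last) // 2
        if point_set[mid][0] < x:
            first = mid + 1
        else:
            last = mid
    # Grow the window point_set[head:tail] outward from that position,
    # alternating left and right, until it holds n points
    head = tail = first
    while tail - head < n:
        if head > 0:
            head -= 1
        if tail - head < n and tail < length:
            tail += 1
    # Return the inclusive [start, end] indices of the chosen points
    return [head, tail - 1]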

Here point_set is the set of all sample points, n is the number of points to use for interpolation, and x is the prediction point entered by the user. The function returns a suitable interpolation interval, i.e., one that surrounds x as closely as possible.

Because the interpolation interval has to be chosen from the input, this is a search problem. I use binary search: first sort the sample point set point_set in ascending order of x, find the first sample point that is not less than the prediction point, and then extend the interval outward on both sides until it contains the required number of points. The expansion step is easy to get wrong. If the loop decrements n once for each side in the same iteration, n can overshoot to -1 when tail advances one step too many, and patching that with tail-2 at the end is fragile. Returning mid directly when a sample point equals x yields a single index instead of an interval, and if n is larger than the number of sample points, a loop that only stops at n == 0 never terminates.
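
If the goal is simply the n sample points nearest to x, the standard library's bisect module can do the binary search, and the expansion can compare distances instead of strictly alternating sides. This is a swapped-in variant, not the method described above; nearest_window is a hypothetical name, and it takes a sorted list of x-values rather than point_set.

```python
import bisect

def nearest_window(xs, n, x):
    """Hypothetical variant: inclusive [head, tail] indices of the n
    values in xs closest to x. xs must be sorted in ascending order."""
    n = min(n, len(xs))                # cannot pick more points than exist
    pos = bisect.bisect_left(xs, x)    # first index with xs[pos] >= x
    head, tail = pos, pos              # current window is xs[head:tail]
    while tail - head < n:
        if head == 0:                  # nothing left on the left side
            tail += 1
        elif tail == len(xs):          # nothing left on the right side
            head -= 1
        elif x - xs[head - 1] <= xs[tail] - x:
            head -= 1                  # left neighbour is at least as close
        else:
            tail += 1                  # right neighbour is closer
    return [head, tail - 1]
```

With the post's data structures, nearest_window([p[0] for p in point_set], n + 1, pred_x) would stand in for binary_search(point_set, n + 1, pred_x), returning the same kind of inclusive [start, end] pair.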

The rest is easy to follow, so here is the complete code.

import re
import matplotlib.pyplot as plt
import numpy as np

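def str2double(str_num):
    # Optional run of '+' or of '-', an integer part, optional '.fraction'.
    # The dot is escaped; a bare '.' would match any character (e.g. "1a5").
    pattern = re.compile(r'^((\+*)|(\-*))?(\d+)(\.(\d+))?$')
    m = pattern.match(str_num)
    if m is None:
        return None  # not a recognizable number
    # A leading '-' makes the number negative; '+' or a digit means positive
    sign = 1 if str_num[0] == '+' or '0' <= str_num[0] <= '9' else -1
    # Strip the sign characters so only digits (and the dot) remain
    num = re.sub(r'(\++)|(\-+)', "", m.group(0))
    matchObj = re.match(r'^\d+$', num)
    if matchObj is not None:
        # Integer: a plain int() cast is enough
        num = sign * int(matchObj.group(0))
    else:
        # Float: split into integer and fractional digits
        matchObj = re.match(r'^(\d+)\.(\d+)$', num)
        if matchObj is not None:
            integer = int(matchObj.group(1))
            fraction = int(matchObj.group(2)) * pow(10, -len(matchObj.group(2)))
            num = sign * (integer + fraction)
    return num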

def preprocess():
    # Read one sample point per line ("x y", separated by spaces) from input.txt
    point_set = list()
    with open("input.txt", "r") as f:  # the with block closes the file
        for line in f:
            fields = line.split()  # split() also drops the trailing newline
            if not fields:
                continue  # skip blank lines
            point = [str2double(pos) for pos in fields]
            point_set.append(point)
    return point_set

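def lagrangeFit(point_set, x):
    # Evaluate the Lagrange interpolating polynomial through point_set at x
    res = 0
    for i in range(len(point_set)):
        # Basis function l_i(x) = prod over j != i of (x - x_j) / (x_i - x_j)
        L = 1
        for j in range(len(point_set)):
            if i == j:
                continue
            L = L * (x - point_set[j][0]) / (point_set[i][0] - point_set[j][0])
        # Add y_i * l_i(x) to the sum
        res += L * point_set[i][1]
    return res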

def showbasis(point_set):
    # Print each basis function l_i(x) as numerator over denominator
    print("Lagrange Basis Function:\n")
    for i in range(len(point_set)):
        top = ""
        bottom = ""
        for j in range(len(point_set)):
            if i == j:
                continue
            top += "(x-{})".format(point_set[j][0])
            bottom += "({}-{})".format(point_set[i][0], point_set[j][0])
        print("Basis function {}:".format(i))
        print("\t\t{}".format(top))
        print("\t\t{}".format(bottom))

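def binary_search(point_set, n, x):
    # point_set must be sorted by x coordinate in ascending order
    length = len(point_set)
    n = min(n, length)  # cannot choose more points than we have
    # Binary search for the first sample point whose x is >= the prediction point
    first = 0
    last = length
    while first < last:
        mid = (first + last) // 2
        if point_set[mid][0] < x:
            first = mid + 1
        else:
            last = mid
    # Grow the window point_set[head:tail] outward from that position,
    # alternating left and right, until it holds n points
    head = tail = first
    while tail - head < n:
        if head > 0:
            head -= 1
        if tail - head < n and tail < length:
            tail += 1
    # Return the inclusive [start, end] indices of the chosen points
    return [head, tail - 1]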

if __name__ == '__main__':
    pred_x = float(input("Predict x: "))
    n = int(input("Interpolation degree: "))
    point_set = preprocess()
    point_set = sorted(point_set, key=lambda a: a[0])
    # A degree-n interpolating polynomial needs n+1 sample points
    span = binary_search(point_set, n+1, pred_x)
    chosen = point_set[span[0]:span[1]+1]
    print("Chosen points: {}".format(chosen))
    showbasis(chosen)
    print("Predicted y at x = {}: {}".format(pred_x, lagrangeFit(chosen, pred_x)))

    # Compare sin(x) with the fit through all points and through the chosen points
    X = np.linspace(-np.pi, np.pi, 256, endpoint=True)
    S = np.sin(X)
    L = [lagrangeFit(point_set, x) for x in X]
    L1 = [lagrangeFit(chosen, x) for x in X]
    
    plt.figure(figsize=(8, 4))
    plt.plot(X, S, label="$sin(x)$", color="red", linewidth=2)
    plt.plot(X, L, label="$LagrangeFit-all$", color="blue", linewidth=2)
    plt.plot(X, L1, label="$LagrangeFit-special$", color="green", linewidth=2)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title("$sin(x)$ and Lagrange Fit")
    plt.legend()
    plt.show()
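
A quick sanity check for lagrangeFit that does not need input.txt: the interpolating polynomial of degree at most n through n+1 points is unique, so interpolating three samples of a quadratic should reproduce that quadratic everywhere, not only at the sample points. The function is copied from the listing so this block runs on its own; the quadratic and the test points are arbitrary choices for illustration.

```python
def lagrangeFit(point_set, x):
    # Same computation as lagrangeFit in the full listing
    res = 0
    for i in range(len(point_set)):
        L = 1
        for j in range(len(point_set)):
            if i != j:
                L = L * (x - point_set[j][0]) / (point_set[i][0] - point_set[j][0])
        res += L * point_set[i][1]
    return res

# Three samples of the quadratic p(x) = x^2 - 2x + 3
pts = [[0, 3], [1, 2], [2, 3]]
for x in [-1.5, 0.5, 5]:
    exact = x * x - 2 * x + 3
    # The two values should agree up to floating-point rounding
    print(x, lagrangeFit(pts, x), exact)
```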

About Input

input.txt holds the sample points to be read: one point per line, with the x and y values separated by a space.
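
The post does not show the contents of input.txt. Since the plot compares the fit with sin(x) on [-π, π], the samples were presumably taken from sin(x). The hypothetical helpers below produce such a file in the format described above; the number of samples and the six-decimal formatting are assumptions.

```python
import math

def sample_lines(count=9):
    """Hypothetical helper: `count` evenly spaced samples of sin(x) on
    [-pi, pi], formatted as "x y" lines (no trailing newline).
    count must be at least 2."""
    lines = []
    for k in range(count):
        x = -math.pi + 2 * math.pi * k / (count - 1)
        lines.append("{:.6f} {:.6f}".format(x, math.sin(x)))
    return lines

def write_sample_file(path="input.txt", count=9):
    # Write one sample point per line, the format preprocess() reads
    with open(path, "w") as f:
        for line in sample_lines(count):
            f.write(line + "\n")
```

Calling write_sample_file() once before running the main script creates input.txt in the current directory.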

result

That was quite fun, hhh. I'll try Newton interpolation in a few days! Bye!


Original post: www.cnblogs.com/LuoboLiam/p/11706151.html