Fredholm第二类积分方程的Python代码实现(1)

 这是《The Numerical Solution of Integral Equations of the Second Kind》书上的例题代码实现,基于Python来编写的,其中的干货还是有的!!!随上课的进度,将不断更新...... 
 
# -*- coding: utf-8 -*-
"""
Created on Mon Apr  9 21:42:44 2018
@author: shaowu

本代码主要实现Fredholm integral equations of the second kind,方程如下:
    lamda*x(t)-integrate(K(t,s)*x(s))ds=y(t),a=0<=t<=b,K(t,s)取exp(s*t)
首先,我们给定两个精确解exp(-t)cos(t)和sqrt(t),求出其对应的y(t),然后再来反解x(t).更详细说明可
参见《The Numerical Solution of Integral Equations of the Second Kind》P30-34.
"""
import sympy as sp
import scipy as scp
import numpy as np
import pandas as pd
import numpy.linalg as LA
import matplotlib.pyplot as plt
from scipy.special.orthogonal import p_roots
import time
start_time=time.time()

def factorial(n):  ##定义阶乘函数,返回一个数
    if n==1 or n==0:
        return 1
    else:
        return n*factorial(n-1)
def gauss_xw(m=400):##默认用400个点求高斯——勒让德节点xi和权weight,并返回x和w数组
    x,w=p_roots(m+1)
    return x,w
def gauss_solve_b(x,w,lamda,a,b,n):
    """
    用Gauss_Legendre积分公式求解方程的右端项,即从a到b积分y(s)*s^(i-1),i=1...n;
    首先把区间[a,b]变换到[-1,1];其中下面分成了两部分数值积分,第一个部分是一重积分,后面是二重积分;
    这里的第一个xt1=sqrt(t),第二个xt2=exp(-t)cos(t),核函数K(t,s)=exp(s*t),其都可以改成其他的函数!
    结果返回两个n维的右端项,对应于xt1和xt2
    """
    c=(b-a)/2
    s=(b-a)/2*x+(a+b)/2 ##对区间做变换:[a,b]-->[-1,1]
    return [((w*c*lamda*scp.exp(-s)*scp.cos(s)*s**i).sum()-
            sum([c*c*w[k]*sum([w[j]*scp.exp(s[k]*s[j])*scp.exp(-s[j])*scp.cos(s[j])*s[k]**i for j in range(len(s))]) for k in range(len(s))]))
            for i in range(n)],[((w*c*lamda*scp.sqrt(s)*s**i).sum()-
            sum([c*c*w[k]*sum([w[j]*scp.exp(s[k]*s[j])*scp.sqrt(s[j])*s[k]**i for j in range(len(s))]) for k in range(len(s))]))
            for i in range(n)]

def quad_romberg(lamda,a,b,n): 
    """
    这里是用quad函数和romberg函数求方程组的右端项,精度亚于Gauss_Legendre积分.
    其中b1是x(t)=exp(-t)*cos(t)对应的右端项;b2是x(t)=sqrt(t)对应的右端项;
    结果返回b1和b2,其维数都是n
    """
    b1=[]
    b2=[]
    for i in range(n):
        b1.append(scp.integrate.quad(lambda s:(lamda*scp.exp(-s)*scp.cos(s)-
            scp.integrate.quad(lambda t:scp.exp(s*t)*scp.exp(-t)*scp.cos(t),a,b)[0])*s**(i),a,b)[0])
    for i in range(n): ##龙贝格积分
        b2.append(scp.integrate.romberg(lambda s:(lamda*scp.sqrt(s)-
            scp.integrate.romberg(lambda t:scp.exp(s*t)*scp.sqrt(t),a,b))*s**(i),a,b))
    return b1,b2
def solve_AC(lamda,a,b,n,f): 
    """
    这里是计算方程组的系数矩阵,即A;同时使用numpy.linalg.solve函数计算出C;输入的参数中,f表示方程组的右端项.
    结果返回一个n*n维的矩阵A,和一个n维的向量C
    """
    A=[]#存放系数矩阵
    for i in range(1,n+1):
        guodu=[-b**(i+j-1)/(factorial(j-1)*(i+j-1)) for j in range(1,n+1)]
        guodu[i-1]=lamda+guodu[i-1]
        A.append(guodu)
    A=np.array(A)
    f=np.array(f,dtype='float')
    C=np.linalg.solve(A,f) #求解C
    return A,C
def solve_xn(lamda,a,b,n,c1,c2):
    """
    这里计算xn(t);xnt1和xnt2分别对应xt1=exp(-t)*cos(t)和xt2=sqrt(t)的情况.
    为了计算范数,这里把t属于[a,b]进了行离散
    结果返回误差E1和E2
    """
    tt=np.array([i for i in np.arange(a+0.1,b+0.1,0.1)]) #变量tt取10个点
    xt1=scp.exp(-tt)*scp.cos(tt)  ##第一个精确解xt1
    xt2=scp.sqrt(tt)   ##第二个精确解xt2
    xnt1=[];xnt2=[]
    """
    ################################################################################################
    #这部分代码中的y1和y2是用sympy中的符号积分直接求得,分别对应于精确解xt1和xt2;
    s=sp.Symbol('s')
    t=sp.Symbol('t')
    y1=-sp.Integral(sp.exp(-s)*sp.exp(s*t)*sp.cos(s), (s, a, b)) + lamda*sp.exp(-t)*sp.cos(t)
    #y1=5*sp.exp(-t)*sp.cos(t)-sp.integrate(sp.exp(s*t)*sp.exp(-s)*sp.cos(s),(s,0,1))##get the y1(t)
    y2=lamda*sp.sqrt(t)-sp.integrate(sp.exp(s*t)*sp.sqrt(s),(s,a,b))
    for j in tt:
        m1=0;m2=0
        for i in range(n):
            m1=m1+c1[i]*j**i/factorial(i)
            m2=m2+c2[i]*j**i/factorial(i)
        xnt1.append((m1+y1.subs(t,j).evalf())/lamda)
        xnt2.append((m2+y2.subs(t,j).evalf())/lamda)
    ################################################################################################

    """
    ################################################################################################
    #这部分代码是用scipy中的数值积分求对应的y1和y2,和上一部分的相差不大
    for j in tt:
        m1=0;m2=0
        for i in range(n):
            m1=m1+c1[i]*j**i/factorial(i)
            m2=m2+c2[i]*j**i/factorial(i)
        xnt1.append((m1+lamda*scp.exp(-j)*scp.cos(j)-
            scp.integrate.quad(lambda x:scp.exp(x*j)*scp.exp(-x)*scp.cos(x),a,b)[0])/lamda)
        xnt2.append((m2+lamda*scp.sqrt(j)-
            scp.integrate.quad(lambda x:scp.exp(x*j)*scp.sqrt(x),a,b)[0])/lamda)
    
    ################################################################################################
    
    E1=LA.norm(xt1-xnt1,np.inf)  ##计算无穷范数
    E2=LA.norm(xt2-xnt2,np.inf)
    #如果n=10,那么就画误差图
    if n==10:
        plt.plot(tt,xt1-xnt1,'-',label='x1')
        plt.plot(tt,xt2-xnt2,'--r',label='x2')
        plt.title('Figure1.(where a=0,b=1,n=10,lambda=5)')
        plt.xlabel('t')
        plt.ylabel('error')
        plt.legend()
    return E1,E2
if __name__ == '__main__':
    print('******************************程序入口*******************************')
    lamda=int(input('pleas input lambda:'))
    nn=int(input('please input degree n:'))
    a=int(input('please input the left value of interval:'))
    b=int(input('please input the right value of interval:'))
    m=int(input('please input the node of Gauss_Legendre:'))
    print('计算中...')
    #lamda=5;a=0;b=1;m=500
    error=[]
    for n in range(1,nn+1,1):#范围从1到nn+1,后面的1为步长,可以省略
        x,w=gauss_xw(m) ##取400个高斯点
        f1,f2=gauss_solve_b(x,w,lamda,a,b,n)##利用Gauss_Legendre求解右端项
        #ff1,ff2=quad_romberg(lamda,a,b,n)##利用quad和Romberg求解右端项
        A1,c1=solve_AC(lamda,a,b,n,f1) ##计算C
        A2,c2=solve_AC(lamda,a,b,n,f2)
        error.append(solve_xn(lamda,a,b,n,c1,c2)) ##计算无穷范数
    error=pd.DataFrame(error)
    error.columns=['E1','E2']
    error.insert(0,'n',[i for i in range(1,nn+1,1)])
    print(error)
    plt.figure(2)
    plt.plot(error['n'],-np.log10(error['E1']),'o',label='error in x1')
    plt.plot(error['n'],-np.log10(error['E2']),'*',label='error in x2')
    plt.title('Figure2.(where a=0,b=1,lambda=5)')
    plt.xlabel('degree n')
    plt.ylabel('-log10 of error')
    plt.legend()
    print('all_cost_time:',time.time()-start_time)

猜你喜欢

转载自blog.csdn.net/wushaowu2014/article/details/79895113
今日推荐