[小白系列][线性回归模型]股票回归分析实例代码详解

代码详解

P.S:记录下第一个搞明白的模型哦!

import statsmodels.api as sm  # 基本api
import statsmodels.formula.api as smf  # 公式api
import statsmodels.graphics.api as smg  # 图形界面api
import patsy  # 主要类似 R 语言的公式转成 statsmodels 可以识别的形式
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pandas import Series,DataFrame
from scipy import stats  # Python中一个很好的统计推断包
import seaborn as sns
import datetime, os, warnings  # os模块是与计算机操作系统交互的一个接口,Python 的 os 模块封装了常见的文件和目录操作
                               # 和exception异常要求用户立刻进行处理不同,warnings通常用于提示用户一些错误或者过时的用法

warnings.filterwarnings('ignore')  # 通过警告过滤器进行控制是否发出警告消息,“ignore”参数的作用是:忽略匹配的警告
plt.rcParams['font.sans-serif']=['SimHei'] # 用来正常显示中文标签,设置为黑体
plt.rcParams['axes.unicode_minus'] = False  # 可以显示负号

# 设置起始时间
start = datetime.datetime(2019,1,1)
end = datetime.datetime(2019,12,31)
#print(start)
# Python提供了专门从财经网站获取金融数据的API接口:DataReader
from pandas_datareader.data import DataReader
# 读取上证综指 及 探路者数据
def load_data():
    # 判断在当前的网站目录下是否存在000001.csv和300005.csv的文件,如果没有的话,则利用DataReader接口,用pandas自带的read_csv方法下载
	if os.path.exists('000001.csv'):  # os.path.exists()就是判断括号里的文件是否存在的意思,括号内的可以是文件路径
		data_ss = pd.read_csv('000001.csv')
		data_tlz = pd.read_csv('300005.csv')
	else:
		# 上证综指
		data_ss = DataReader("000001.SS", "yahoo",start,end)
		# 300005 探路者股票 深证
		data_tlz = DataReader("300005.SZ", "yahoo",start,end)
		data_ss.to_csv('000001.csv')
		data_tlz.to_csv('300005.csv')
	return data_ss, data_tlz

data_ss, data_tlz = load_data()
# 数据探索,查看前五行数据
print(data_ss.head())
print(data_tlz.head())

# 探路者与上证综指
# 这里是把股票的收盘价close拿了出来
close_ss = data_ss["Close"]
close_tlz = data_tlz["Close"]
# 对两支股票分别做数据探索
print(close_ss.head()) 
print(close_tlz.head()) 
# 将探路者与上证综指进行数据合并
# 此处merge函数利用的是左右连接键名不一样时候的情况,详见[merge函数左右键名不一样的情况](https://blog.csdn.net/KaelCui/article/details/105156974)
stock = pd.merge(data_ss, data_tlz, left_index = True, right_index = True)
stock = stock[["Close_x","Close_y"]]  # 把合并之后的闭盘价格拿出来
stock.columns = ["上证综指","探路者"]
# 对合并之后的数据做数据探索
print(stock.head())
# 统计每日收益率
# .diff()函数,是用来将数据进行某种移动之后与原来的数据进行比较得出差异的数据
# .shift()函数,是可以把数据移动指定的位数,period = -1代表往上/往左,period = 1代表往下/往右
daily_return = (stock.diff()/stock.shift(periods = 1)).dropna()  # dropna代表去除空值,因为有除法,否则程序没有意义
print(daily_return.head())
# 找出当天收益率大于10%的,应该是没有,因为涨停为10%
print(daily_return[daily_return["探路者"] > 0.1])

# 每日收益率可视化
fig,ax = plt.subplots(nrows=1,ncols=2,figsize=(15,6))
daily_return["上证综指"].plot(ax=ax[0])
ax[0].set_title("上证综指")
daily_return["探路者"].plot(ax=ax[1])
ax[1].set_title("探路者")
plt.show()

# 散点图
fig,ax = plt.subplots(nrows=1,ncols=1,figsize=(12,6))
plt.scatter(daily_return["探路者"],daily_return["上证综指"])
plt.title("每日收益率散点图 from 探路者 & 上证综指")
plt.show()

# 回归分析
import statsmodels.api as sm
# 加入截距项
daily_return["intercept"]=1.0  # 此处默认选择截距项为1
model = sm.OLS(daily_return["探路者"],daily_return[["上证综指","intercept"]])  # OLS(Ordinary Least Squares):最小二乘法进行回归分析
results = model.fit()  # 做拟合
print(results.summary())  # 拟合结果展示
#  print( result.params())  # 对应的各个变量的权重


输出图例

最小二乘法分析结果

发布了5 篇原创文章 · 获赞 5 · 访问量 123

猜你喜欢

转载自blog.csdn.net/KaelCui/article/details/105156091
今日推荐