(Quantitative) Implementing a simple backtesting framework with the Tushare package

【References】
1. Bilibili: "Learn Python financial quantitative analysis with a Tsinghua Computer Science PhD" (video course)
2. Tushare official website (author ID: 492952)

The project code in this post is based mainly on Reference 1, with changes to follow updates to the tushare API. All the code has been uploaded to the author's GitHub; what follows only outlines the ideas behind the code.

1. Backtest requirements and results

Using China Ping An (601318.SH) as the traded stock, we test the returns of a double moving average strategy from 2020-05-10 to 2021-01-01. The result is shown below:
[Figure: strategy return vs. benchmark return over the backtest period]

2. Code Framework

2.1 Objects

(1) Store account information and backtest information

Account information: cash, stocks held

Backtest information:

  1. Start and end dates
  2. Current date
  3. Benchmark: usually a stock or an index, used to judge how well the strategy performs
  4. All trading days between the start and end dates
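class Context:
    def __init__(self, cash, start_date, end_date):
        # Account information
        self.cash = cash          # available cash
        self.positions = {}       # held stocks: {ts_code: pd.Series with an 'amount' entry}
        # Backtest information
        self.start_date = start_date   # 'YYYYMMDD'
        self.end_date = end_date
        self.dt = start_date           # current trading day
        self.benchmark = None
        # Open trading days in [start_date, end_date]; trade_cal comes from pro.trade_cal().
        # astype(int) accepts is_open stored as 0/1 or '0'/'1'; sorting keeps the loop in date order.
        is_open = trade_cal['is_open'].astype(int) == 1
        in_range = (trade_cal['cal_date'] >= start_date) & (trade_cal['cal_date'] <= end_date)
        self.date_range = trade_cal.loc[is_open & in_range, 'cal_date'].sort_values().values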

Here, trade_cal stores the trading calendar (each calendar date and whether the exchange is open that day). It can be obtained from tushare with pro.trade_cal().
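As a rough sketch of that filtering step, assuming pro.trade_cal() returns `cal_date` as 'YYYYMMDD' strings and `is_open` as 0/1 (sometimes stored as strings), the helper below pulls the open days out of a calendar DataFrame. `open_days` and `fetch_trade_cal` are hypothetical names; the tushare call needs your own token and network access, so the demo uses a small hand-made calendar instead of real API output.

```python
import pandas as pd


def open_days(trade_cal: pd.DataFrame, start_date: str, end_date: str) -> list:
    """Open trading days in [start_date, end_date] as sorted 'YYYYMMDD' strings.

    Assumes the column layout of tushare's pro.trade_cal(): 'cal_date' and 'is_open'.
    """
    is_open = trade_cal['is_open'].astype(int) == 1  # handles 1 and '1'
    in_range = (trade_cal['cal_date'] >= start_date) & (trade_cal['cal_date'] <= end_date)
    return sorted(trade_cal.loc[is_open & in_range, 'cal_date'].tolist())


def fetch_trade_cal(token: str, start_date: str, end_date: str) -> pd.DataFrame:
    """Fetch the SSE calendar from tushare (needs a valid token and network access)."""
    import tushare as ts  # imported lazily so the demo below runs without tushare
    pro = ts.pro_api(token)
    return pro.trade_cal(exchange='SSE', start_date=start_date, end_date=end_date)


# Hand-made calendar shaped like pro.trade_cal() output (illustrative only)
toy_cal = pd.DataFrame({
    'exchange': ['SSE'] * 5,
    'cal_date': ['20200508', '20200509', '20200510', '20200511', '20200512'],
    'is_open': ['1', '0', '0', '1', '1'],
})
print(open_days(toy_cal, '20200510', '20210101'))  # ['20200511', '20200512']
```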

(2) Store other global variables

class G:
    pass

g = G()  # global container: g.security, g.p1, g.p2, g.hist

2.2 Functions

(0) Main function

import matplotlib.pyplot as plt
import pandas as pd

def run():
    initialize(context)  # initialization

    # DataFrame holding the data needed for plotting (float dtype so it can be plotted)
    plt_df = pd.DataFrame(index=pd.to_datetime(context.date_range), columns=['value'], dtype=float)
    # Last known price of each stock, used when a stock is suspended and has no price today
    last_prices = {}
    # Account value at the start of the backtest, used to compute the return curve
    initial_value = context.cash

    for dt in context.date_range:
        context.dt = dt  # update the current date
        handle_data(context)
        # Account value = cash + market value of held stocks
        value = context.cash
        for stock in context.positions.keys():
            try:
                data = get_today_data(stock)
                price = data['open']
                last_prices[stock] = price
            except KeyError:
                # No data today means the stock is suspended: use the previous trading day's price
                price = last_prices[stock]
            value += price * context.positions[stock]['amount']
        plt_df.loc[pd.to_datetime(dt), 'value'] = value

    # Strategy return
    plt_df['ratio'] = (plt_df['value'] - initial_value) / initial_value
    # Benchmark return
    bm_df = attribute_daterange_history(context.benchmark, context.start_date, context.end_date)
    bm_init = bm_df['open'].iloc[0]
    plt_df['benchmark_ratio'] = (bm_df['open'] - bm_init) / bm_init

    plt_df[['ratio', 'benchmark_ratio']].plot()
    plt.show()
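The valuation inside the loop (cash plus each holding at today's open, falling back to the last known price when a stock is suspended) and the return ratio can also be written in vectorized form. This is a minimal sketch under assumed inputs: daily cash, share counts and open prices already laid out as date-indexed pandas objects, with NaN marking suspended days. `account_value_curve` and `return_ratio` are hypothetical helpers, and the prices are toy numbers, not market data.

```python
import pandas as pd


def account_value_curve(cash: pd.Series, shares: pd.DataFrame, prices: pd.DataFrame) -> pd.Series:
    """Daily account value = cash + sum over stocks of (shares held * price)."""
    # NaN prices mark suspended days; ffill reuses the last known price,
    # like the last_prices fallback in run()
    filled = prices.ffill()
    market_value = (shares * filled).sum(axis=1)
    return cash + market_value


def return_ratio(value: pd.Series, initial_value: float) -> pd.Series:
    """Cumulative return relative to the starting account value."""
    return (value - initial_value) / initial_value


days = pd.to_datetime(['2020-05-11', '2020-05-12', '2020-05-13'])
prices = pd.DataFrame({'601318.SH': [80.0, None, 82.0]}, index=days)  # toy prices, day 2 suspended
shares = pd.DataFrame({'601318.SH': [100, 100, 100]}, index=days)
cash = pd.Series([2000.0, 2000.0, 2000.0], index=days)

value = account_value_curve(cash, shares, prices)
ratio = return_ratio(value, initial_value=10000.0)
print(value.tolist())  # [10000.0, 10000.0, 10200.0]
print(ratio.tolist())  # [0.0, 0.0, 0.02]
```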

(1) Initialization function

  1. Set the target stock
  2. Set the benchmark
  3. Set the short and long moving-average windows (g.p1 and g.p2)
  4. Fetch the target stock's historical data: hist_1 holds the g.p2 days of data before the backtest starts, and hist_2 holds the data from the start to the end of the backtest. Combining the two gives all the price data the double moving average backtest needs, so the long average is already defined on the first backtest day
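def initialize(context):
    g.security = '601318.SH'     # target stock: China Ping An
    set_benchmark('601318.SH')   # use the same stock as the benchmark
    g.p1 = 5                     # short moving-average window (days)
    g.p2 = 60                    # long moving-average window (days)

    # g.p2 days of history before the backtest starts
    hist_1 = attribute_history(g.security, g.p2)
    # Data from the start to the end of the backtest
    hist_2 = attribute_daterange_history(g.security, context.start_date, context.end_date)
    # DataFrame.append was removed in pandas 2.0; pd.concat does the same job
    g.hist = pd.concat([hist_1, hist_2])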
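A small sketch of the combining step, assuming both frames are indexed by trade date. It uses pd.concat and, as an extra safeguard that is not in the original, drops a duplicated date in case the two fetches overlap on the start date, then sorts by date. `with_warmup` is a hypothetical helper and the frames are toy data.

```python
import pandas as pd


def with_warmup(warmup: pd.DataFrame, backtest: pd.DataFrame) -> pd.DataFrame:
    """Stack warm-up bars before backtest bars, keep one row per date, sort by date."""
    hist = pd.concat([warmup, backtest])
    # If a date appears in both frames, keep the backtest-period row
    hist = hist[~hist.index.duplicated(keep='last')]
    return hist.sort_index()


warmup = pd.DataFrame({'close': [1.0, 2.0, 3.0]},
                      index=pd.to_datetime(['2020-05-06', '2020-05-07', '2020-05-08']))
backtest = pd.DataFrame({'close': [3.0, 4.0]},
                        index=pd.to_datetime(['2020-05-08', '2020-05-11']))
hist = with_warmup(warmup, backtest)
print(hist['close'].tolist())  # [1.0, 2.0, 3.0, 4.0]
```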

(2) Function executed on every trading day

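import dateutil.parser

def handle_data(context):
    # Only use bars strictly before today: orders fill at today's open,
    # so today's close must not feed today's decision (avoids look-ahead bias)
    hist = g.hist[g.hist.index < dateutil.parser.parse(context.dt)].iloc[-g.p2:]
    ma5 = hist['close'].iloc[-g.p1:].mean()   # short moving average
    ma60 = hist['close'].mean()               # long moving average

    # Double moving average strategy:
    # if the short MA is above the long MA and the stock is not held, buy with all cash;
    # if the short MA is below the long MA and the stock is held, sell it all
    if ma5 > ma60 and g.security not in context.positions:
        order_value(g.security, context.cash)
    elif ma5 < ma60 and g.security in context.positions:
        order_target(g.security, 0)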
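The same double moving average rule can be expressed over a whole price series at once, which is handy for cross-checking the loop. A sketch, assuming a date-ordered Series of closing prices: `shift(1)` makes each day's decision see only closes up to the previous day (matching a fill at today's open), and the position becomes 1 when the short average is above the long one, 0 when it is below, and otherwise keeps its previous state. `crossover_positions` is a hypothetical name, not part of the original code.

```python
import pandas as pd


def crossover_positions(close: pd.Series, short_window: int = 5, long_window: int = 60) -> pd.Series:
    """Target position per day (1.0 = hold, 0.0 = flat) for a double moving average rule."""
    # Averages known at today's open: computed on closes up to yesterday
    ma_short = close.rolling(short_window).mean().shift(1)
    ma_long = close.rolling(long_window).mean().shift(1)
    signal = pd.Series(float('nan'), index=close.index)
    signal[ma_short > ma_long] = 1.0
    signal[ma_short < ma_long] = 0.0
    # Equal or undefined averages: keep the previous state, starting flat
    return signal.ffill().fillna(0.0)


close = pd.Series([1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0, 1.0])
print(crossover_positions(close, short_window=2, long_window=3).tolist())
# [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
```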

(3) Order functions

The code on GitHub contains four order functions:

  1. order: buy or sell a given number of shares
  2. order_value: buy or sell a given monetary value of stock
  3. order_target: adjust the position to a target number of shares
  4. order_target_value: adjust the position to a target monetary value
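The GitHub implementations of these wrappers are not reproduced here. One plausible way to reduce each of them to the signed share count that `_order` expects is sketched below, assuming the price is today's open and the current holding is known; the function names are hypothetical, and the 100-share lot rounding is left to `_order`.

```python
def shares_for_value(value: float, price: float) -> int:
    """order_value-style: shares that `value` buys at `price` (negative value = sell)."""
    return int(value / price)  # int() truncates toward zero


def delta_to_target(current: int, target: int) -> int:
    """order_target-style: shares to trade so the position ends at `target`."""
    return target - current


def delta_to_target_value(current: int, target_value: float, price: float) -> int:
    """order_target_value-style: shares to trade so the position is worth about target_value."""
    return delta_to_target(current, shares_for_value(target_value, price))


print(shares_for_value(10000.0, 80.0))             # 125
print(delta_to_target(300, 0))                     # -300
print(delta_to_target_value(100, 16000.0, 80.0))   # 100
```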

All four are built on the following helper:

def _order(today_data, security, amount):
    if today_data.empty:  # no data today (e.g. suspended): skip
        return
    # Fill at today's open price
    price = today_data['open']
    # Check whether the stock is already held
    if security not in context.positions:
        # Selling a stock that is not held: nothing to do
        if amount <= 0:
            return
        # Buying: create the position
        context.positions[security] = pd.Series(dtype=float)
        context.positions[security]['amount'] = 0

    # Trade in multiples of 100 shares (one lot), unless selling the whole position
    if (amount % 100 != 0) and (amount != -context.positions[security]['amount']):
        amount = int(amount / 100) * 100

    # Update the position; drop it once it reaches zero shares
    context.positions[security]['amount'] += amount
    if context.positions[security]['amount'] == 0:
        del context.positions[security]
    # Update cash
    context.cash -= amount * price
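To see how the cash and position bookkeeping in `_order` plays out, here is a self-contained toy version of the same rules on a plain dict instead of the Context object and pandas Series (an assumption made so it runs without tushare data). `toy_order` is a hypothetical name. Like the `_order` above, it does not check for insufficient cash or for selling more shares than are held.

```python
def toy_order(account: dict, security: str, amount: int, price: float) -> None:
    """Apply _order's rules to {'cash': float, 'positions': {code: shares}} in place."""
    positions = account['positions']
    held = positions.get(security, 0)
    if held == 0 and amount <= 0:
        return  # nothing to sell
    # Multiples of 100 shares unless the whole position is being sold
    if amount % 100 != 0 and amount != -held:
        amount = int(amount / 100) * 100  # truncates toward zero
    new_held = held + amount
    if new_held == 0:
        positions.pop(security, None)  # drop empty positions
    else:
        positions[security] = new_held
    account['cash'] -= amount * price


account = {'cash': 10000.0, 'positions': {}}
toy_order(account, '601318.SH', 125, 80.0)   # rounded down to 100 shares
print(account)  # {'cash': 2000.0, 'positions': {'601318.SH': 100}}
toy_order(account, '601318.SH', -100, 82.0)  # sell the whole position
print(account)  # {'cash': 10200.0, 'positions': {}}
```

Buying 125 shares is rounded down to one lot, and the final sale is exempt from rounding because it closes the whole position.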


Original post: blog.csdn.net/weixin_43728138/article/details/122859016