(1)配置mxnet开发环境
最方便的方式是使用 PyCharm 开发:1. 安装 PyCharm。2. 打开 PyCharm 的 File > Default Settings > Project Interpreter,点击“+”搜索并安装 mxnet 包(也可以在命令行执行 pip install mxnet)。
(2)mxnet初试
import mxnet as mx
import numpy as np
import logging
logging.getLogger().setLevel(logging.DEBUG)
# Linear-regression "hello world" for MXNet's Module API.
# The network learns y = x0 + 2 * x1 from random samples, then predicts on a
# few hand-written points whose targets follow the same formula.

# ---- preparing the data ----
# Training data: num_samples rows of 2 features drawn uniformly from [0, 1).
num_samples = 100
train_data = np.random.uniform(0, 1, [num_samples, 2])
# Target is a fixed linear function of the features. Vectorised with NumPy
# instead of a per-row Python loop; produces the same values.
train_label = train_data[:, 0] + 2 * train_data[:, 1]
batch_size = 1

# Evaluation data: labels follow the same rule
# (7 + 2*2 = 11, 6 + 2*10 = 26, 12 + 2*2 = 16).
eval_data = np.array([[7, 2], [6, 10], [12, 2]])
eval_label = np.array([11, 26, 16])

# Both iterators name their label 'lin_reg_label' so it matches the label
# Variable Y and the Module's label_names below; NDArrayIter would otherwise
# use its default label name 'softmax_label'.
train_iter = mx.io.NDArrayIter(train_data, train_label, batch_size,
                               shuffle=True, label_name='lin_reg_label')
eval_iter = mx.io.NDArrayIter(eval_data, eval_label, batch_size,
                              shuffle=False, label_name='lin_reg_label')

# ---- define model ----
X = mx.sym.Variable('data')
Y = mx.sym.Variable('lin_reg_label')
# A single fully connected unit with no activation computes w . x + b.
fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
# Squared-error regression output trained against the label Y.
lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")

# Network structure: data_names / label_names must match the symbol's
# input Variable names ('data' and 'lin_reg_label').
model = mx.mod.Module(
    symbol=lro,
    data_names=['data'],
    label_names=['lin_reg_label'],
)
# Optional: render the network graph (requires graphviz).
# mx.viz.plot_network(symbol=lro).view()

# ---- train ----
model.fit(train_iter, eval_iter,
          optimizer_params={'learning_rate': 0.005, 'momentum': 0.9},
          num_epoch=20,
          eval_metric='mse',
          # Log training speed every 2 batches.
          batch_end_callback=mx.callback.Speedometer(batch_size, 2))

# Function-call form of print works on both Python 2 and Python 3; the
# original print statement is a SyntaxError on Python 3.
print(model.predict(eval_iter).asnumpy())
代码流程
1. 新建训练数据迭代器 mx.io.NDArrayIter。其构造函数如下,默认 data_name='data',label_name='softmax_label':
def __init__(self, data, label=None, batch_size=1, shuffle=False,
last_batch_handle='pad', data_name='data',
label_name='softmax_label'):
符号变量 X = mx.sym.Variable('data') 对应数据迭代器提供的输入数据 data;Y = mx.sym.Variable('lin_reg_label') 对应标签 lin_reg_label。因为标签名不是默认的 softmax_label,所以创建数据迭代器时需要显式指定 label_name='lin_reg_label',否则标签名与符号变量对不上。
2. 建立模型 mx.mod.Module。其构造函数如下,其中 data_names 和 label_names 必须与符号中的变量名('data'、'lin_reg_label')一致:
def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
logger=logging, context=ctx.cpu(), work_load_list=None,
fixed_param_names=None, state_names=None, group2ctxs=None,
compression_params=None):
3. 开始训练 model.fit。其签名如下,默认优化器为 'sgd',默认评估指标为 'acc';本例通过 optimizer_params 设置学习率和动量,并将评估指标改为 'mse':
def fit(self, train_data, eval_data=None, eval_metric='acc',
epoch_end_callback=None, batch_end_callback=None, kvstore='local',
optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
eval_end_callback=None,
eval_batch_end_callback=None, initializer=Uniform(0.01),
arg_params=None, aux_params=None, allow_missing=False,
force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
validation_metric=None, monitor=None):