Getting and printing TShape values in the MXNet source

Following on from the previous post (https://blog.csdn.net/zhqh100/article/details/91438657), this time we try to print the values stored in a TShape.

Again in the file incubator-mxnet/src/c_api/c_api.cc, modify the function MXNDArrayReshape64 as follows:

MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
                                 int ndim,
                                 dim_t *dims,
                                 bool reverse,
                                 NDArrayHandle *out) {
  NDArray *ptr = new NDArray();
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  mxnet::Tuple<dim_t> shape(dims, dims+ndim);
  CHECK_GT(arr->shape().Size(), 0) << "Source ndarray's shape is undefined. Input shape: "
    << arr->shape();
  mxnet::TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse);
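  // Debug output added for this post: print the inferred shape,
  // first its number of dimensions, then each dimension in turn.
  std::cout << "ndim =" << new_shape.ndim() << std::endl;
  for (int i = 0; i < new_shape.ndim(); ++i) {
    std::cout << "new_shape[" << i << "] = " << new_shape[i] << std::endl;
  }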

  *ptr = arr->ReshapeWithRecord(new_shape);
  *out = ptr;
  API_END_HANDLE_ERROR(delete ptr);
}
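Since the CHECK message above already streams `arr->shape()` with `<<`, TShape evidently supports `operator<<`, so `std::cout << new_shape` would also print the whole shape at once. If you prefer the per-dimension loop but want the result on a single log line, a small helper does the job. The sketch below is standalone so it can be compiled outside MXNet: `ToyShape` and `FormatShape` are hypothetical names, and `ToyShape` only imitates the two TShape members the loop relies on, `ndim()` and `operator[]`.

```cpp
#include <cassert>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for mxnet::TShape, exposing only the two members
// the debug loop relies on: ndim() and operator[].
struct ToyShape {
  std::vector<int64_t> dims;
  int ndim() const { return static_cast<int>(dims.size()); }
  int64_t operator[](int i) const {
    assert(i >= 0 && i < ndim());  // same bounds the debug loop respects
    return dims[i];
  }
};

// Formats any shape-like object (anything with ndim() and operator[])
// as "(d0,d1,...)", so the whole shape fits on one log line.
template <typename Shape>
std::string FormatShape(const Shape& s) {
  std::ostringstream os;
  os << '(';
  for (int i = 0; i < s.ndim(); ++i) {
    if (i > 0) os << ',';
    os << s[i];
  }
  os << ')';
  return os.str();
}
```

Assuming the template is pasted into c_api.cc, it could be called inside MXNDArrayReshape64 as `std::cout << "new_shape = " << FormatShape(new_shape) << std::endl;`, which for the test below would print `new_shape = (6,7)`.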

The Python test code:

from mxnet import nd
import mxnet
print(mxnet.__version__)

x = nd.arange(42)
x = x.reshape((6, 7))
print(x[2:4].asnumpy())
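In this test, `x.reshape((6, 7))` on the 42-element array reaches MXNDArrayReshape64 (the `ndim =2` line in the output confirms it), where InferReshapeShape combines the requested dims with the source shape to produce new_shape. The sketch below is a deliberately simplified, hypothetical version of that step (`InferSimpleReshape` is not an MXNet function), assuming only positive sizes and at most one `-1`. MXNet's real function also handles the special codes 0, -2, -3 and -4 as well as the `reverse` flag, none of which is modeled here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical, simplified reshape inference: `target` may contain positive
// sizes and at most one -1 ("infer this dimension from the total size").
std::vector<int64_t> InferSimpleReshape(const std::vector<int64_t>& target,
                                        const std::vector<int64_t>& src) {
  int64_t total = 1;
  for (int64_t d : src) {
    assert(d > 0);  // source shape must be defined, like the CHECK_GT above
    total *= d;
  }

  int64_t known = 1;  // product of the explicitly given sizes
  int infer_at = -1;  // index of the -1 entry, if any
  std::vector<int64_t> out(target);
  for (std::size_t i = 0; i < out.size(); ++i) {
    if (out[i] == -1) {
      if (infer_at != -1) throw std::invalid_argument("more than one -1");
      infer_at = static_cast<int>(i);
    } else if (out[i] > 0) {
      known *= out[i];
    } else {
      throw std::invalid_argument("unsupported dimension code");
    }
  }

  if (infer_at != -1) {
    if (total % known != 0) throw std::invalid_argument("cannot infer -1");
    out[infer_at] = total / known;
    known *= out[infer_at];
  }
  if (known != total) throw std::invalid_argument("size mismatch");
  return out;
}
```

For example, `InferSimpleReshape({-1, 7}, {42})` yields `{6, 7}`, the same new_shape printed in the output, while `{5, 7}` throws because 5 × 7 = 35 does not match the 42 source elements.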

The output is:

# python3 mxnet_test.py 
1.5.0
ndim =2
new_shape[0] = 6
new_shape[1] = 7
slice_begin:2
slice_end:4
p[0] = 14
p[1] = 15
p[2] = 16
p[3] = 17
[[14. 15. 16. 17. 18. 19. 20.]
 [21. 22. 23. 24. 25. 26. 27.]]

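The ndim and new_shape lines come from the code added above; the slice_begin/slice_end and p[i] lines presumably come from the changes made in the previous post. Either way, we can now read the actual NDArray information from the C++ side.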
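The printed values are easy to cross-check. Because `x = nd.arange(42)`, every element's value equals its flat offset, and in a row-major 6×7 array the row slice `[2:4]` is one contiguous block starting at offset 2 × 7 = 14 and holding 2 × 7 = 14 elements (offsets 14 to 27). That is consistent with p[0] = 14 onward and with the printed matrix. The helper below is a hypothetical illustration of that arithmetic, not MXNet code.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical helper: in a row-major array with `cols` columns, the row
// slice [begin, end) is one contiguous block of the flat buffer.
// Returns {first flat offset, number of elements in the slice}.
std::pair<int64_t, int64_t> RowSliceRange(int64_t rows, int64_t cols,
                                          int64_t begin, int64_t end) {
  assert(0 <= begin && begin <= end && end <= rows);
  return {begin * cols, (end - begin) * cols};
}
```

`RowSliceRange(6, 7, 2, 4)` returns `{14, 14}`: the slice starts at offset 14 and its last element sits at offset 14 + 14 - 1 = 27, matching the first and last values of the printed matrix.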


Reposted from blog.csdn.net/zhqh100/article/details/91441675