libtorch中tensor与vector的转换方法

tensor与vector的转换,是通过数据的指针来完成的。下面以ATen为例讲解,其他的如torch等只是命名空间不一样,其他的是一样的。

tensor转vector

#include<ATen/ATen.h>//引入头文件
#include<iostream>
using namespace std;
int main(){
    at::Tensor t=at::ones({2,2},at::kInt);//建立一个2X2的tensor
    vector<int> v(t.data_ptr<int>(),t.data_ptr<int>()+t.numel());//将tensor转换为vector
    //输出转换后的结果
    for(auto val:v){
        cout<<val<<" ";
    }
    cout<<endl;
}

t是一个类型为at::kInt的tensor,其中kInt可以用其他数据类型替换如kFloat等,t.data_ptr<int>()返回int类型的指针,返回的地址是数据存储的起始位置。t.numel()返回t中的元素个数。

vector转tensor

#include<ATen/ATen.h>
#include<iostream>
using namespace std;
int main(){
    vector<int> v={1,2,3,4};
    at::TensorOptions opts=at::TensorOptions().dtype(at::kInt);
    c10::IntArrayRef s={2,2};//设置返回的tensor的大小
    at::Tensor t=at::from_blob(v.data(),s,opts).clone();
    cout<<t<<endl;
}

opts用来对返回的tensor作一些额外的解释,例如类型。s用来指定返回的tensor的维度。clone是为了深复制,让后面对t的操作不会收到v的影响。v.data()返回vector中的数据的指针。

猜你喜欢

转载自blog.csdn.net/watqw/article/details/123363618