Pytorch 函数查看

版权声明:本文为博主原创文章,如需转载请附上博文链接 https://blog.csdn.net/wendygelin/article/details/89154548

刚接触pytorch,对其结构之类的都不熟悉。在改一个程序的时候,需要将RGB图改成灰度图,所以在torchvision.transforms.Compose中加入

        transforms.ToPILImage(),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),

但是问题出现了,ToPILImage需要输入时uint8的形式,而原本的输入时float32,总是报错。

为了解决这个问题,我找了transforms的所有函数,没有在compose中将输入的数据变为uint8格式的函数。我又想自己添加一个函数,找了好久,找到pytorch库的位置,在我这里的位置是:

/usr/local/lib/python3.5/dist-packages/torchvision/transforms
transfroms.py和functional.py都在这里,可以查看函数,修改函数和添加函数。看了看,也不好改。

于是想在引用数据的时候对数据进行处理,找到了用到input_transform = transforms.Compose(...) 的地方是:

    train_set, test_set = datasets.__dict__[args.dataset](     
        args.data,
        transform=input_transform,
        target_transform=target_transform, # 为flow的标签
        co_transform=co_transform,
        split=args.split_file if args.split_file else args.split_value
    )

本来以为datsets为一个自带的函数,但是发现并不是,自带的这里有解释:

https://zhuanlan.zhihu.com/p/30934236

我琢磨之后搞懂了,datasets为我程序文件夹下的datasets文件夹,而[args.dataset]的意思是,在datasets文件夹下有很多的.py文件,每个文件中定义了若干个函数,这里_dict_是在全部这些里找的意思,找args.dataset的值的函数。我这里args.dataset用的是default的值‘flying_chairs’,而datasets文件夹下有个flyingchairs.py文件中定义了flying_chairs这个函数。所以,这里datasets.__dict__[args.dataset](。。。)的意思是,引用flying_chairs这个函数,后面()中的参数也是传递到flying_chairs中的参数。

这就可以了,逐步追根溯源,在input_transform的input前加上np.uint8就不报错了。

小白自己琢磨,记录下来。老手勿喷

猜你喜欢

转载自blog.csdn.net/wendygelin/article/details/89154548