import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
import PIL
from torchvision import transforms
from transforms import *
from PIL import Image
import matplotlib.pyplot as plt
# Module-level containers. `subpolicies` is not used in the visible code;
# `sub` is rebound further down to the list of crop pipelines.
subpolicies = []
sub = []


def _baseline_augmentation(resize_size):
    """Build the shared baseline pipeline: resize, random crop, h-flip, to tensor.

    NOTE(review): the RandomCrop size [120, 720] (height, width) is larger
    than every resize target passed in below. With pad_if_needed left at its
    default (False), torchvision's RandomCrop raises ValueError when applied
    to such an image. Neither pipeline is applied in the visible code;
    confirm the intended crop size (e.g. Pad(4) + RandomCrop(32)) before use.
    """
    stages = [
        transforms.Resize(resize_size),
        transforms.RandomCrop([120, 720]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    return transforms.Compose(stages)


# Baseline pipeline with a square 32x32 resize.
subpolicy1 = _baseline_augmentation([32, 32])
# Same baseline pipeline with a 32x64 (height x width) resize.
subpolicy2 = _baseline_augmentation([32, 64])
# Two deterministic crop pipelines. Each resizes the input and then keeps a
# centred patch. Every spec entry is (Resize size [h, w], CenterCrop size (h, w)).
_crop_specs = (
    ([64, 64], (12, 12)),
    ([32, 32], (20, 12)),
)
tran1, tran2 = (
    transforms.Compose([transforms.Resize(resize_hw), transforms.CenterCrop(crop_hw)])
    for resize_hw, crop_hw in _crop_specs
)
sub = [tran1, tran2]

# Debug output of the candidate pipelines.
print("subpolicy2 222", sub)
print("subpolicies", sub[0])

# Each call of transform1 applies one of the two pipelines, chosen at random.
transform1 = transforms.RandomChoice(sub)
print("trans1 ", transform1)
# Load the sample image. The context manager closes the file handle;
# convert("RGB") returns a fully loaded 3-channel copy that remains usable
# after the file is closed.
with Image.open("./cifar-10-images/temp/1_dog.jpg") as _img_file:
    img_rgb = _img_file.convert("RGB")

# Apply one randomly chosen crop pipeline (tran1 or tran2).
img_store = transform1(img_rgb)

# Note: if tran1/tran2 are extended with ToTensor, the result is a (C, H, W)
# float tensor in [0, 1]. It must be converted back before it can be displayed;
# casting it straight to uint8 would zero it out and keep the wrong axis order.
if not isinstance(img_store, Image.Image):
    img_store = transforms.ToPILImage()(img_store)

# matplotlib's imshow takes an (H, W, C) array; uint8 keeps the 0-255 range.
img_store = np.array(img_store, dtype=np.uint8)

plt.figure("dog")
plt.imshow(img_store)
plt.show()