# 记录常使用的函数避免遗忘 (A collection of frequently used helper functions, recorded here so they are not forgotten.)
def upsample(x, scale=2, features=64, activation=tf.nn.relu):
    """Upsample a feature map into a 3-channel output via sub-pixel convolution.

    A 3x3 convolution first maps ``x`` to ``features`` channels. The result is
    then enlarged by ``scale``:

    * scale 2 or 3: one 3x3 convolution to ``3 * scale**2`` channels
      (12 or 27), followed by a colour pixel-shuffle ``PS(..., scale)``;
    * scale 4: two successive x2 stages, each a 3x3 convolution to 12
      channels followed by ``PS(..., 2)``.

    Args:
        x: Input feature map. Presumably NHWC, since ``PS`` splits and
            concatenates on axis 3 -- TODO confirm against callers.
        scale: Upscaling factor; must be 2, 3 or 4.
        features: Number of filters in the first convolution.
        activation: Activation passed to every ``slim.conv2d`` call.

    Returns:
        The tensor produced by the last ``PS`` call.

    Raises:
        AssertionError: if ``scale`` is not 2, 3 or 4. (This is an
            ``assert``, so it is skipped under ``python -O``.)
    """
    assert scale in [2,3,4]
    x = slim.conv2d(x, features, [3, 3], activation_fn=activation)

    # x4 is built from two x2 stages; x2 and x3 use a single stage.
    factor, n_stages = (2, 2) if scale == 4 else (scale, 1)
    # Three colour planes, each needing factor*factor sub-pixel channels.
    shuffle_channels = 3 * factor ** 2

    for _ in range(n_stages):
        x = slim.conv2d(x, shuffle_channels, [3, 3], activation_fn=activation)
        x = PS(x, factor, color=True)
    return x
def PS(X, r, color=False):
    """Pixel-shuffle ``X`` by factor ``r`` (sub-pixel upsampling).

    Args:
        X: 4-D tensor; channels live on axis 3. Without ``color`` the channel
            count must be ``r*r`` (required by ``_phase_shift``'s reshape).
            With ``color`` it must be ``3*r*r``: axis 3 is split into three
            equal groups.
        r: Integer upscaling factor.
        color: If true, shuffle each of the three channel groups separately
            and stack the three single-channel results on axis 3.

    Returns:
        The shuffled tensor: one output channel without ``color``, three with.
    """
    if not color:
        return _phase_shift(X, r)

    # One group of r*r channels per colour plane; shuffle each group into a
    # single-channel plane, then re-join the planes along the channel axis.
    colour_groups = tf.split(X, 3, 3)
    shuffled_planes = [_phase_shift(group, r) for group in colour_groups]
    return tf.concat(shuffled_planes, 3)
def _phase_shift(I, r):
    """Pixel-shuffle an ``r*r``-channel tensor into a single-channel tensor ``r`` times larger.

    The mapping is::

        out[n, i*r + q, j*r + p, 0] == I[n, i, j, p*r + q]

    The *high* digit ``p`` of the channel index selects the column offset and
    the *low* digit ``q`` the row offset. This mapping is identical to the
    previous split/squeeze/concat implementation. NOTE(review): it is
    transposed relative to ``tf.nn.depth_to_space``, which puts the row
    offset in the high digit.

    The previous version built ``a + b`` split/squeeze ops plus two large
    concats just to permute axes. A single transpose between two reshapes
    yields the same element order with a constant number of ops.

    Args:
        I: 4-D tensor of shape ``(batch, a, b, r*r)``. ``a`` and ``b`` must be
            statically known (read from ``I.get_shape()``). The batch dim may be
            ``None``; it is taken dynamically from ``tf.shape``.
        r: Integer upscaling factor.

    Returns:
        Tensor of shape ``(batch, a*r, b*r, 1)``.
    """
    _, a, b, _ = I.get_shape().as_list()
    bsize = tf.shape(I)[0]  # Handling Dimension(None) type for undefined batch dim
    # X[n, i, j, p, q] = I[n, i, j, p*r + q]
    X = tf.reshape(I, (bsize, a, b, r, r))
    # Reorder to [n, i, q, j, p] so rows become (i, q) and columns (j, p).
    X = tf.transpose(X, (0, 1, 4, 2, 3))
    # Merging (i, q) -> i*r + q and (j, p) -> j*r + p gives the upscaled grid.
    return tf.reshape(X, (bsize, a * r, b * r, 1))
def my_anti_shuffle(input_image, ratio, dtype=np.uint8):
    """Inverse pixel-shuffle (space-to-depth) an ``(H, W, C)`` image.

    Every ``ratio x ratio`` spatial block is folded into the channel axis::

        out[Y, X, c*ratio*ratio + x*ratio + y] == input_image[Y*ratio + x, X*ratio + y, c]

    Args:
        input_image: 3-D array-like of shape ``(H, W, C)``.
        ratio: Positive integer block size. ``H`` and ``W`` must be divisible
            by it.
        dtype: Output dtype. Defaults to ``np.uint8``, which matches the
            historical behaviour: non-uint8 input is cast with NumPy's unsafe
            casting, so floats are truncated. Pass ``None`` to keep the input's
            dtype.

    Returns:
        A new array of shape ``(H // ratio, W // ratio, C * ratio**2)``. If
        ``H`` or ``W`` is not divisible by ``ratio``, an error message is
        printed and ``None`` is returned.
    """
    image = np.asarray(input_image)
    shape = image.shape
    ori_height = int(shape[0])
    ori_width = int(shape[1])
    ori_channels = int(shape[2])
    if ori_height % ratio != 0 or ori_width % ratio != 0:
        print("Error! Height and width must be divided by ratio!")
        return
    height = ori_height // ratio
    width = ori_width // ratio
    channels = ori_channels * ratio * ratio

    # blocks[Y, x, X, y, c] == image[Y*ratio + x, X*ratio + y, c]
    blocks = image.reshape(height, ratio, width, ratio, ori_channels)
    # Reorder to [Y, X, c, x, y]; flattening (c, x, y) gives the channel
    # index c*ratio*ratio + x*ratio + y.
    out = blocks.transpose(0, 2, 4, 1, 3).reshape(height, width, channels)
    # Always return a fresh array, never a view of the caller's input.
    return out.astype(dtype) if dtype is not None else out.copy()
def shuffle(input_image, ratio, dtype=np.uint8):
    """Pixel-shuffle (depth-to-space) an ``(H, W, C_in)`` image.

    Each output pixel is taken from the channel axis of the input::

        out[i, j, k] == input_image[i // ratio, j // ratio,
                                    k*ratio*ratio + (i % ratio)*ratio + (j % ratio)]

    This is the inverse of ``my_anti_shuffle``'s layout.

    The previous version was broken on Python 3: ``/`` produced a float
    channel count, which ``np.zeros`` and ``range`` reject, and float indices.
    It was also an O(H*W*C) pure-Python loop. The rewrite uses integer
    arithmetic and a single reshape/transpose.

    Args:
        input_image: 3-D array-like of shape ``(H, W, C_in)``.
        ratio: Positive integer upscaling factor.
        dtype: Output dtype. Defaults to ``np.uint8``, which matches the
            historical behaviour: other input dtypes are cast with NumPy's
            unsafe casting. Pass ``None`` to keep the input's dtype.

    Returns:
        A new array of shape ``(H*ratio, W*ratio, C_in // ratio**2)``. As with
        the original floor division, trailing channels beyond
        ``(C_in // ratio**2) * ratio**2`` are ignored.
    """
    image = np.asarray(input_image)
    shape = image.shape
    in_height = int(shape[0])
    in_width = int(shape[1])
    channels = int(shape[2]) // (ratio * ratio)

    # Only the first channels*ratio**2 input channels are ever read.
    used = image[:, :, :channels * ratio * ratio]
    # blocks[Y, X, k, dy, dx] == image[Y, X, k*ratio*ratio + dy*ratio + dx]
    blocks = used.reshape(in_height, in_width, channels, ratio, ratio)
    # Reorder to [Y, dy, X, dx, k]; merging (Y, dy) and (X, dx) yields rows
    # Y*ratio + dy and columns X*ratio + dx.
    out = blocks.transpose(0, 3, 1, 4, 2).reshape(
        in_height * ratio, in_width * ratio, channels)
    # Always return a fresh array, never a view of the caller's input.
    return out.astype(dtype) if dtype is not None else out.copy()