[Code error record] Displaying dataset images: image tensor dimension problem

matplotlib's imshow() function raises the error "TypeError: Invalid dimensions for image data".

Code that raised the error:

plt.imshow((img[6, :, :, :].moveaxis(0, 2)))

I changed it to:

plt.imshow((img[6, :, :, :]))

This still raised the error:
TypeError: Invalid dimensions for image data

The version that finally worked:

plt.imshow((img[6, :, :, :].squeeze().numpy().transpose(1,2,0)))
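The effect of transpose(1, 2, 0) is easiest to see from the shapes. The following is a minimal sketch that uses a plain NumPy array as a stand-in for one image taken from the batch; the sizes (3, 32, 48) are made up for illustration and are not from my data set.

```python
import numpy as np

# Stand-in for a single image: (channel, height, width) = (3, 32, 48).
# These sizes are hypothetical, chosen only to make the axes easy to tell apart.
chw = np.zeros((3, 32, 48), dtype=np.float32)

# transpose(1, 2, 0) reorders the axes to (height, width, channel),
# so the channel axis of size 3 ends up last.
hwc = chw.transpose(1, 2, 0)

print(chw.shape)  # (3, 32, 48)
print(hwc.shape)  # (32, 48, 3)
```

With the channel axis first, the last dimension of the array is the image width rather than 3 or 4, which is why imshow rejects it.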

Reference
The key to solving this problem is understanding what the imshow function accepts.
The input to matplotlib.pyplot.imshow() must be a two-dimensional NumPy array, or a three-dimensional NumPy array whose third dimension has size 3 or 4 (RGB or RGBA).

  • When the third dimension has size 1, use the np.squeeze() function to reduce the data to a two-dimensional array.
  • Because I am working in PyTorch, the result is a tensor of shape (batch_size, channel, height, width). I first need the detach() function to cut the tensor off from backpropagation, since .numpy() cannot be called on a tensor that requires gradients.
  • It should be pointed out that imshow cannot display a tensor that lives on the GPU, so I need the .cpu() function to move it to the CPU before converting it with .numpy().
  • As mentioned earlier, the input to imshow must be a two-dimensional NumPy array, or a three-dimensional NumPy array whose third dimension has size 3 or 4.
  • My use case is a little unusual because there is an extra batch_size dimension. Fortunately, I set batch_size to 1, so I can use the .squeeze() function to remove that dimension and get a NumPy array of shape (channel, height, width). This still does not match what imshow expects, so the channel axis (=3) has to be moved to the end with transpose, which is where .transpose(1,2,0) comes from. Of course, if the image to be displayed has channel=1, squeeze() removes that axis as well, and the resulting two-dimensional array can be passed to imshow directly.
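The steps above can be collected into one small helper. This is only a sketch under my own assumptions: the function name to_imshow_array and the tensor sizes are hypothetical, and the input is assumed to be a (batch_size, channel, height, width) tensor with 1, 3, or 4 channels.

```python
import numpy as np
import torch


def to_imshow_array(batch: torch.Tensor, index: int = 0) -> np.ndarray:
    """Return image `index` of a (batch, channel, height, width) tensor
    as a NumPy array laid out the way imshow expects.

    Hypothetical helper for illustration only.
    """
    image = batch[index]            # (channel, height, width)
    image = image.detach().cpu()    # drop autograd history, move to CPU memory
    array = image.numpy()           # torch.Tensor -> numpy.ndarray
    if array.shape[0] == 1:
        return array.squeeze(0)     # single channel: (height, width)
    return array.transpose(1, 2, 0) # colour: (height, width, channel)


# Made-up batches: 8 colour images and 8 single-channel images, each 32 x 48.
rgb_batch = torch.rand(8, 3, 32, 48, requires_grad=True)
gray_batch = torch.rand(8, 1, 32, 48)

print(to_imshow_array(rgb_batch, 6).shape)   # (32, 48, 3)
print(to_imshow_array(gray_batch, 6).shape)  # (32, 48)
```

The result can then be passed straight to matplotlib, for example plt.imshow(to_imshow_array(img, 6)). Handling the single-channel case separately avoids calling transpose(1, 2, 0) on an array that squeeze() has already reduced to two dimensions.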

Origin blog.csdn.net/zhe470719/article/details/127163650