Problem: When a batch of single-channel images is passed through torchvision.utils.make_grid, the resulting grid has three channels instead of one.
torch.Size([4, 1, 256, 256]) → torch.Size([3, 518, 1034])
Reason: make_grid contains a branch that handles single-channel input. When the images being stitched have only one channel, it concatenates the tensor with itself along the channel dimension, so that one channel is copied three times.
Solution: Modify the code to: