matplotlib.pyplot: subplots, plot, xlabel, etc.

1. plt.subplots(nrows, ncols, ...)

import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, num='train', figsize=(12, 6))

The code above creates a figure named 'train' with 1 row and 3 columns of axes and a size of 12 x 6 inches (figsize=(12, 6)). Afterwards, plt points to the rightmost ax, because it was created last: plt.gca() returns it, and plt.xxx calls act on it.

The code above is equivalent to the following (here too, plt ends up pointing to the rightmost ax):

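import matplotlib.pyplot as plt
plt.figure("train", (12, 6))  # num="train", figsize=(12, 6)
plt.subplot(1, 3, 1)  # left ax
plt.subplot(1, 3, 2)  # middle ax
plt.subplot(1, 3, 3)  # right ax, created last, so it becomes the current ax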
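Both variants, and the "current ax" rule, can be checked with a minimal sketch. It assumes a headless run with the Agg backend; the second figure's label "train2" is hypothetical and only chosen so that it does not reactivate the existing 'train' figure (passing an existing label to plt.figure returns that figure instead of creating a new one).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, so the sketch runs without a display
import matplotlib.pyplot as plt

plt.close("all")  # start from a clean pyplot state

# Variant 1: create the figure and all three axes in one call
fig1, axes = plt.subplots(1, 3, num="train", figsize=(12, 6))
print(axes.shape)               # (3,): a 1-D array, because nrows == 1
print(plt.gca() is axes[-1])    # True: the last-created ax is the current one

# Variant 2: create the figure, then add the axes one by one
fig2 = plt.figure("train2", (12, 6))  # "train2" is a hypothetical label
ax1 = plt.subplot(1, 3, 1)
ax2 = plt.subplot(1, 3, 2)
ax3 = plt.subplot(1, 3, 3)
print(len(fig2.axes))           # 3
print(plt.gca() is ax3)         # True: again the rightmost ax
```

With both nrows > 1 and ncols > 1, axes comes back as a 2-D array and is indexed as axes[row, col].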

2. plt.xxx functions act on the current fig/ax, which is normally the most recently created one

When calling a plt.xxx function, pay attention to which ax of which fig it will operate on. (plt.show, by contrast, displays all open figures.)

import matplotlib.pyplot as plt
import numpy as np

np.random.seed(0)
epochs = 4
epoch_loss_values = np.random.randint(5, size=epochs)

fig, axes = plt.subplots(1, 3, num='train', figsize=(12, 6))
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
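axes[0].plot(x, y)  # Axes objects also have a plot method
axes[0].set_xlabel('aaa')  # Axes has set_xlabel, but no xlabel method
plt.xlabel("epoch")  # acts on the current ax, i.e. axes[2]
plt.title("Epoch Average Loss")  # also acts on axes[2]
plt.show()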

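In the result, the loss curve and the x-axis label 'aaa' appear on the leftmost ax, while the label "epoch" and the title "Epoch Average Loss" land on the rightmost ax, because plt.xlabel and plt.title act on the current ax.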
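Where each label ended up can be verified, and redirected, with a short sketch. It assumes the Agg backend; plt.sca (set the current axes), plt.gcf, and the Axes methods get_xlabel, get_title, and set_title are standard Matplotlib calls, while the titles "Left panel"/"Middle panel" and the figure label "val" are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

plt.close("all")  # avoid reusing a 'train' figure left over from an earlier run

np.random.seed(0)
epoch_loss_values = np.random.randint(5, size=4)
x = list(range(1, len(epoch_loss_values) + 1))

fig, axes = plt.subplots(1, 3, num="train", figsize=(12, 6))
axes[0].plot(x, epoch_loss_values)
axes[0].set_xlabel("aaa")
plt.xlabel("epoch")              # current ax is axes[2]
plt.title("Epoch Average Loss")  # also axes[2]
print(axes[0].get_xlabel(), "|", axes[2].get_xlabel(), "|", axes[2].get_title())
# aaa | epoch | Epoch Average Loss

# Option A: make axes[0] current first, then use the plt.xxx functions
plt.sca(axes[0])
plt.title("Left panel")          # hypothetical title
# Option B: call the Axes method directly; the "current ax" plays no role
axes[1].set_title("Middle panel")  # hypothetical title

# Figures follow the same rule: a new label creates and activates a figure,
# and an existing label re-activates that figure
val_fig = plt.figure("val")      # "val" is a hypothetical label
print(plt.gcf() is val_fig)      # True
plt.figure("train")
print(plt.gcf() is fig)          # True
```

Option B is usually the less error-prone choice once a figure has several axes, since the target ax is named explicitly.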

3. A newly created ax may overwrite the existing axes it overlaps in the same fig

import matplotlib.pyplot as plt
import numpy as np

np.random.seed(0)
epochs = 4
epoch_loss_values = np.random.randint(5, size=epochs)

fig, axes = plt.subplots(1, 3, num='train', figsize=(12, 6))
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
axes[0].plot(x, y)
axes[0].set_xlabel('aaa')
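plt.subplot(1, 2, 2)  # new ax covering the right half; it overlaps axes[1] and axes[2]
plt.xlabel("epoch")  # acts on the new ax, which is now current
plt.title("Epoch Average Loss")
plt.show()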

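In the result, axes[0] keeps its curve and its label 'aaa', while "epoch" and the title go to the new ax created by plt.subplot(1, 2, 2), which occupies the right half of the figure and overlaps axes[1] and axes[2]. What happens to the overlapped axes depends on the Matplotlib version: older releases silently delete them (the "overwrite" above), Matplotlib 3.6 deprecated this auto-removal, and newer releases keep them and draw the new ax on top.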
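If the goal is to replace the middle and right panels with one wide panel, a sketch that behaves the same across Matplotlib versions removes the covered axes explicitly with Axes.remove() before calling plt.subplot(1, 2, 2). It assumes the Agg backend and the default subplot spacing, under which axes[0] does not overlap the right half.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt

plt.close("all")  # start from a clean pyplot state
fig, axes = plt.subplots(1, 3, num="train", figsize=(12, 6))

# Explicitly remove the two axes that a right-half subplot would cover,
# instead of relying on version-dependent auto-removal
axes[1].remove()
axes[2].remove()

right = plt.subplot(1, 2, 2)     # new ax spanning the right half; becomes current
plt.xlabel("epoch")
plt.title("Epoch Average Loss")

print(len(fig.axes))             # 2: axes[0] plus the new right-half ax
print(plt.gca() is right)        # True
```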


Origin blog.csdn.net/qq_41021141/article/details/125973412