Deep Learning (5) Softmax Regression: Introduction to Classification Algorithms, How to Load the Fashion-MNIST Dataset

Softmax regression

Fundamentals

Regression and classification are the two most common kinds of prediction task in deep learning. Regression predicts a continuous value (for example, predicting the next lottery numbers from past draws), while classification predicts a discrete category (for example, handwriting recognition or image recognition).

Now that we have some understanding of regression, how do we move on to classification?

Suppose we have n classes. First we need to encode the classes as numbers. Each label becomes a column vector with one entry per class:

$$y = [y_1, y_2, \ldots, y_n]^T$$

For a sample that belongs to the i-th class, the column vector is:

$$y = [0, 0, \ldots, 1, \ldots, 0, 0]^T$$

That is, only the element for the class it belongs to equals 1, and every other element is 0. This is called one-hot encoding.
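The encoding above can be sketched in a few lines of plain Python. The helper name `one_hot` is an assumption made for illustration, and the code uses 0-based class indices, whereas the text counts classes from 1.

```python
def one_hot(index, num_classes):
    """Return a list with a 1 at position `index` and 0 everywhere else."""
    if not 0 <= index < num_classes:
        raise ValueError("index must be in [0, num_classes)")
    return [1 if k == index else 0 for k in range(num_classes)]

# A sample from class 2 (0-based) out of 5 classes.
y = one_hot(2, 5)
print(y)  # [0, 0, 1, 0, 0]
```

Libraries such as PyTorch ship their own one-hot helpers; this version only shows what the vector looks like.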

With this encoding you could train with the mean squared loss and take the class with the largest output as the prediction.

Softmax regression is a classification method (it generalizes regression to multi-class classification). First determine the number of input features and the number of output categories. For example, with 4 features and 3 possible categories, computing one output per category takes 3 linear functions:

$$o_1 = x_1 w_{11} + x_2 w_{12} + x_3 w_{13} + x_4 w_{14} + b_1$$

$$o_2 = x_1 w_{21} + x_2 w_{22} + x_3 w_{23} + x_4 w_{24} + b_2$$

$$o_3 = x_1 w_{31} + x_2 w_{32} + x_3 w_{33} + x_4 w_{34} + b_3$$

Because every output depends on every input, softmax regression is a single-layer, fully connected neural network.
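As a rough sketch of the 4-feature, 3-category example, the three outputs can be computed like this. The function name `linear_outputs` and all weight and bias values are made up for illustration. `W[j][i]` is the weight connecting input i to output j.

```python
def linear_outputs(x, W, b):
    """Compute o_j = sum_i x_i * W[j][i] + b[j] for every output j."""
    return [sum(w_ji * x_i for w_ji, x_i in zip(row, x)) + b_j
            for row, b_j in zip(W, b)]

x = [1.0, 2.0, 3.0, 4.0]        # 4 input features
W = [[0.1, 0.0, 0.0, 0.0],      # weights for output 1 (illustrative values)
     [0.0, 0.5, 0.0, 0.0],      # weights for output 2
     [0.0, 0.0, 0.0, 1.0]]      # weights for output 3
b = [0.0, 1.0, -1.0]            # one bias per output

o = linear_outputs(x, W, b)     # one raw score per category
print(o)
```

Each row of `W` plays the role of one linear regression, which is why the whole model is a single fully connected layer.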

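After normalizing the outputs into probabilities, we choose the class with the highest probability as the prediction:

$$\hat{y} = \mathrm{softmax}(o), \qquad \hat{y}_i = \frac{\exp(o_i)}{\sum_k \exp(o_k)}$$

Exponentiating with base e makes every value positive, and dividing by the sum makes the values add up to 1.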
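A minimal plain-Python sketch of this normalization step follows. The input scores are made up, and subtracting the maximum before exponentiating is a common numerical-stability trick added here, not something the text describes. It leaves the result unchanged.

```python
import math

def softmax(o):
    """Map raw outputs to probabilities: exp(o_i) / sum_k exp(o_k)."""
    m = max(o)  # shifting by the max avoids overflow and does not change the result
    exps = [math.exp(v - m) for v in o]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([0.1, 2.0, 3.0])
# The predicted class is the index with the highest probability.
predicted = max(range(len(probs)), key=probs.__getitem__)
print(probs, predicted)
```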

The loss measures the difference between the true probability vector and the predicted probability vector. The true vector is a one-hot column vector, with a single 1.

Cross-entropy loss:

$$l(y, \hat{y}) = -\sum_i y_i \log \hat{y}_i$$

In a classification problem, then, we do not care about the probabilities predicted for the incorrect classes. We only care whether the probability predicted for the correct class is large enough, because the true label is a one-hot vector and only that one term survives in the sum.
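The point that only the correct class matters can be checked with a small sketch. The function name `cross_entropy` and the probability vectors are illustrative assumptions. The two predictions below spread their probability differently over the wrong classes but give the correct class the same probability.

```python
import math

def cross_entropy(y_true, y_pred):
    """-sum_i y_i * log(yhat_i); entries with y_i == 0 contribute nothing."""
    return -sum(t * math.log(p) for t, p in zip(y_true, y_pred) if t > 0)

y_true = [0, 0, 1]                                 # one-hot: class 2 is correct
loss_a = cross_entropy(y_true, [0.1, 0.2, 0.7])
loss_b = cross_entropy(y_true, [0.3, 0.0, 0.7])    # wrong classes differ, correct class equal
print(loss_a, loss_b)                              # both equal -log(0.7)
```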

Commonly used loss functions

L2 Loss: mean squared loss.

$$l(y, \hat{y}) = \frac{1}{2}(y - \hat{y})^2$$

L1 Loss: absolute value loss.

$$l(y, \hat{y}) = |y - \hat{y}|$$

The gradient of the L2 loss is a sloping straight line: its magnitude shrinks as the prediction approaches the target, which suits gradient descent. The gradient of the L1 loss is a step: it is either -1 or 1, and it jumps at zero error.

[Figure: gradients of the L1 and L2 losses]
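To make the comparison concrete, here is a small sketch of both gradients with respect to the prediction. It assumes the L2 loss is written with the conventional factor of 1/2, and it returns 0 for the L1 gradient at exactly zero error, which is an arbitrary choice because the derivative is undefined there. The function names are illustrative.

```python
def l2_grad(y, y_hat):
    """Derivative of 0.5 * (y - y_hat)**2 with respect to y_hat."""
    return y_hat - y

def l1_grad(y, y_hat):
    """Derivative of |y - y_hat| with respect to y_hat (0 at the kink by convention)."""
    diff = y_hat - y
    return (diff > 0) - (diff < 0)   # sign of the error: -1, 0 or 1

for y_hat in (3.0, 0.1, -2.0):
    print(y_hat, l2_grad(0.0, y_hat), l1_grad(0.0, y_hat))
```

The L2 gradient scales with the error, while the L1 gradient keeps the same magnitude however close the prediction gets.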

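We can combine the two to get a new loss function, Huber's robust loss, which behaves like L1 for large errors and like L2 for small ones:

$$l(y, \hat{y}) = \begin{cases} |y - \hat{y}| - \frac{1}{2} & \text{if } |y - \hat{y}| > 1 \\ \frac{1}{2}(y - \hat{y})^2 & \text{otherwise} \end{cases}$$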
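A minimal sketch of such a combined loss is given below. The `delta` parameter is an assumption added for generality: it sets the error size at which the loss switches from quadratic to linear, and the default of 1 corresponds to the usual textbook form.

```python
def huber_loss(y, y_hat, delta=1.0):
    """Quadratic for small errors, linear for large ones, continuous at `delta`."""
    err = abs(y - y_hat)
    if err <= delta:
        return 0.5 * err ** 2            # L2-like region near the target
    return delta * (err - 0.5 * delta)   # L1-like region far from the target

print(huber_loss(0.0, 0.5))  # 0.125
print(huber_loss(0.0, 3.0))  # 2.5
```

The two branches agree at an error of exactly `delta`, so neither the loss nor its gradient jumps there.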

Image classification dataset

MNIST is a commonly used image classification dataset, but it is too simple. Fashion-MNIST (clothing classification) was released later as a more challenging alternative.

First, we look at how to load the dataset so that we can use it to test algorithms later.

# Import packages
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

# Download the dataset and load it into memory
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)  # training set
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)  # test set, used to evaluate performance

# Define a function that maps label indices to text labels
def get_fashion_mnist_labels(labels):  #@save
    """Return the text labels of the Fashion-MNIST dataset."""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

# Visualize the images so that the results are easier to inspect
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Image tensor
            ax.imshow(img.numpy())
        else:
            # PIL image
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

# Read a small batch first to see what the data looks like
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

# Read minibatches with the built-in data loader, which shuffles the data automatically, so we do not need to write our own
batch_size = 256

def get_dataloader_workers():  #@save
    """Use 4 processes to read the data."""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

Iterating once over the training set with this loader takes roughly 2 to 3 seconds.
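The timing code itself is not shown here. One way such a measurement could be taken is sketched below with only the standard library. `time_one_pass` is a hypothetical helper, and `range(235)` is a stand-in for the real `train_iter` (60000 training images in batches of 256 give 235 batches), so the number it prints says nothing about actual loading speed.

```python
import time

def time_one_pass(batches):
    """Iterate once over `batches` and return (elapsed_seconds, num_batches)."""
    start = time.perf_counter()
    count = 0
    for _ in batches:        # with a real loader, each item would be an (X, y) pair
        count += 1
    return time.perf_counter() - start, count

# Stand-in for a data loader; pass `train_iter` instead to time real loading.
elapsed, n = time_one_pass(range(235))
print(f"{n} batches in {elapsed:.4f} sec")
```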

To summarize, the data-reading steps above can be combined into a single function:

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """Download the Fashion-MNIST dataset and load it into memory."""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

The optional resize argument lets you resize the images as they are loaded.

Origin blog.csdn.net/jtwqwq/article/details/134471315