(PyTorch Advanced Road) Informer

Paper: Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting (AAAI'21 Best Paper)

Following the previous papers and write-ups, my focus here is on application, so there will be more code, and the theory will only be touched on briefly.

The paper's authors also thoughtfully provide a Colab notebook, which makes it much quicker to see how the source code is implemented: https://colab.research.google.com/drive/1_X7O2BkFLvqyCdZzDZvV2MB0aAvYALLC

So what should you mainly look at in the source code? First, the GitHub issues. If the issues show the code cannot run at all, don't spend time on it; if no major errors are reported, the code probably has no fatal bugs.

The second step is to look at the data: what the source data is and how it is preprocessed.

The third step is to look at the model implementation, usually under the models folder. This step is relatively simple; the focus is on how the innovative parts are realized.

The fourth step is the pretrained .pth weights, to check whether the reported results can be reproduced.



Model framework

[Figure: Informer model framework]
Innovation point: ProbSparse Attention. The main idea is to use a top-k selection to keep only the most informative queries.
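For reference, this is the query sparsity measurement from the paper, in the max-mean form that the code actually computes:

$$
M(\mathbf{q}_i, \mathbf{K}) = \max_{j}\frac{\mathbf{q}_i \mathbf{k}_j^{\top}}{\sqrt{d}} \;-\; \frac{1}{L_K}\sum_{j=1}^{L_K}\frac{\mathbf{q}_i \mathbf{k}_j^{\top}}{\sqrt{d}}
$$

Queries with a large M dominate the attention distribution and are worth computing exactly; the rest can be approximated.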


Code address

https://github.com/zhouhaoyi/Informer2020

Download the code and data. Reading the Data section of the README carefully, we learn that the data has to go under the data/ETT folder.

Skim the argparse section to get the rough meaning of model, data, root_path, data_path, the single-GPU/multi-GPU and num_workers settings, inferring from context. GitHub also provides the data dictionary. At a minimum we need to modify the data and data_path parameters.
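For orientation, a minimal sketch of the flags involved (the defaults here are my own placeholders, not the repo's):

    import argparse

    parser = argparse.ArgumentParser(description='Informer args (sketch)')
    parser.add_argument('--model', type=str, default='informer')
    parser.add_argument('--data', type=str, default='custom')            # key into data_dict
    parser.add_argument('--root_path', type=str, default='./data/ETT/')  # folder holding the csv
    parser.add_argument('--data_path', type=str, default='my_data.csv')  # hypothetical file name
    parser.add_argument('--num_workers', type=int, default=0)            # 0 avoids Windows dataloader issues
    args = parser.parse_args()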

Since I am debugging on Windows, arguments with required=True are troublesome to fill in manually, so personally I change them to False first.

Once right-click → Run succeeds, you can debug step by step.


Running main_informer.py, we step along until we reach

    exp.train(setting)

and enter the train function:

    train_data, train_loader = self._get_data(flag='train')
    vali_data, vali_loader = self._get_data(flag='val')
    test_data, test_loader = self._get_data(flag='test')

First, _get_data fetches the data; step into it and have a look. Seeing Dataset_Custom in data_dict tells us the repo supports custom data. It instantiates the dataset and then wraps it in a dataloader, and the data pipeline is ready.
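The flow is the standard PyTorch pattern; a toy sketch (TinyDataset is a stand-in for Dataset_Custom, not the repo's code):

    import torch
    from torch.utils.data import Dataset, DataLoader

    class TinyDataset(Dataset):              # stand-in for Dataset_Custom
        def __init__(self, n=100):
            self.x = torch.randn(n, 96, 12)      # seq_x: encoder window
            self.y = torch.randn(n, 72, 12)      # seq_y: label + prediction window
            self.x_mark = torch.randn(n, 96, 4)  # time-encoding features
            self.y_mark = torch.randn(n, 72, 4)
        def __len__(self):
            return len(self.x)
        def __getitem__(self, i):
            return self.x[i], self.y[i], self.x_mark[i], self.y_mark[i]

    data_set = TinyDataset()
    data_loader = DataLoader(data_set, batch_size=32, shuffle=True, num_workers=0)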

Look at how the dataset preprocesses the data. The dataset has __read_data__ and __getitem__ functions; from context, __read_data__ is the preprocessing step, and seeing StandardScaler there tells us it standardizes the data.
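The scaler itself is only a few lines; a minimal sketch of what it does (the repo defines its own StandardScaler rather than using sklearn's, and fits it on the training split only):

    import numpy as np

    class StandardScaler:                # sketch of the repo's own scaler class
        def fit(self, data):
            self.mean = data.mean(0)
            self.std = data.std(0)
        def transform(self, data):
            return (data - self.mean) / self.std

    data = np.random.randn(96, 12).astype(np.float32)   # toy series: 96 steps, 12 features
    scaler = StandardScaler()
    scaler.fit(data[:70])                # fit on the training portion only
    data = scaler.transform(data)        # standardized model input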

The time_features function encodes features along the time dimension. The idea is very simple, but the code is written in a particularly convoluted way.
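The gist, in a much simpler form than the repo's class hierarchy (the feature set and normalization here are illustrative, not the repo's exact choices for every freq):

    import pandas as pd

    def simple_time_features(dates: pd.DatetimeIndex):
        # encode each timestamp as a few calendar features scaled to about [-0.5, 0.5]
        return pd.DataFrame({
            'month':   dates.month / 12.0 - 0.5,
            'day':     dates.day / 31.0 - 0.5,
            'weekday': dates.weekday / 6.0 - 0.5,
            'hour':    dates.hour / 23.0 - 0.5,
        }).values  # shape (len(dates), 4) -> the "*_mark" model inputs

    stamps = pd.date_range('2020-01-01', periods=96, freq='h')
    print(simple_time_features(stamps).shape)  # (96, 4)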

Finally construct the dataloader


Moving down to the epoch loop, we start iterating over the training data and reach the _process_one_batch function:

    pred, true = self._process_one_batch(
        train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)

_process_one_batch further processes the batch and feeds it to the model; dec_input is initialized as all zeros or all ones (depending on the padding argument).

Then the last 48 observed steps of the target sequence and this placeholder are concatenated along dim=1 to form dec_input.

The 48 steps at the front of dec_input are observed values of the time series; we want to predict the following 24.

The model inputs are enc_input of shape (96, 12) and enc_mark, the (96, 4) time-encoding features; dec_input is (72, 12) and dec_mark is (72, 4) (batch dimension omitted).
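Put together, the decoder input assembly looks like this (a sketch with toy tensors, assuming padding=0, label_len=48, pred_len=24):

    import torch

    batch_y = torch.randn(32, 48 + 24, 12)        # last 48 observed steps + 24 future targets
    dec_inp = torch.zeros(32, 24, 12)             # zero placeholder for the 24 predictions
    dec_inp = torch.cat([batch_y[:, :48, :], dec_inp], dim=1)
    print(dec_inp.shape)                          # torch.Size([32, 72, 12])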


Model part

The main thing is the attention module (the rest is relatively simple). In models/attn.py, find the ProbAttention class and go straight to its forward function.

First Q, K and V are split; out of the 96-step sequence, 25 keys are sampled (U_part).
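Where does 25 come from? The code derives it as factor · ⌈ln L⌉; a quick check with the default factor=5:

    import numpy as np

    factor, L_K, L_Q = 5, 96, 96
    U_part = factor * int(np.ceil(np.log(L_K)))   # sampled keys:     5 * 5 = 25
    u      = factor * int(np.ceil(np.log(L_Q)))   # selected queries: 5 * 5 = 25
    print(U_part, u)                              # 25 25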

Here comes the key part: the _prob_QK function

    scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

Step into _prob_QK.

First K is expanded along dimension -3: K_expand = (32, 8, 96, 96, 64).

index_sample randomly draws a 96×25 matrix of indices in [0, 96), and K_sample gathers the corresponding keys: (32, 8, 96, 25, 64).

Compute the inner product of Q and K_sample to get Q_K_sample (32, 8, 96, 25).
On Q_K_sample, compute the sparsity measure M = max − mean for each query, select the n_top queries with the largest M (M_top), and gather Q_reduce (25 queries).

Q_reduce then takes the inner product with all 96 K:

    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
        # Q [B, H, L, D]
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape

        # calculate the sampled Q_K
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k))  # real U = U_part(factor*ln(L_k))*L_q
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)

        # find the Top_k query with sparsity measurement: max minus mean
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]

        # use the reduced Q to calculate Q_K
        Q_reduce = Q[torch.arange(B)[:, None, None],
                     torch.arange(H)[None, :, None],
                     M_top, :]  # factor*ln(L_q)
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k

        return Q_K, M_top

The _get_initial_context function shows that the queries that were not selected are treated as ordinary: their output is simply initialized with the mean of V.

    V_sum = V.mean(dim=-2)
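With toy tensors, the non-masked initialization looks roughly like this (every query's context starts as the mean of V; the selected rows get overwritten next):

    import torch

    B, H, L_V, D, L_Q = 32, 8, 96, 64, 96
    V = torch.randn(B, H, L_V, D)
    V_sum = V.mean(dim=-2)                                       # (32, 8, 64)
    context = V_sum.unsqueeze(-2).expand(B, H, L_Q, D).clone()   # (32, 8, 96, 64)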

_update_context then overwrites only the 25 selected queries:

    context_in[torch.arange(B)[:, None, None],
               torch.arange(H)[None, :, None],
               index, :] = torch.matmul(attn, V).type_as(context_in)
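A tiny self-contained demo of that fancy-indexing assignment (toy shapes; index stands in for the selected query positions M_top):

    import torch

    B, H, L_Q, D, n_top = 2, 3, 96, 64, 25
    context_in = torch.zeros(B, H, L_Q, D)
    index = torch.randint(L_Q, (B, H, n_top))      # stand-in for the selected query positions
    new_rows = torch.randn(B, H, n_top, D)         # stand-in for matmul(attn, V)
    context_in[torch.arange(B)[:, None, None],
               torch.arange(H)[None, :, None],
               index, :] = new_rows
    print((context_in.abs().sum(-1) > 0).sum(-1))  # ≈25 rows written per (batch, head)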

After the attention is done, back in forward there is a distilling operation: a Conv1d followed by MaxPool1d with stride=2 downsamples the sequence from length 96 to length 48:

ConvLayer(
  (downConv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,), padding_mode=circular)
  (norm): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activation): ELU(alpha=1.0)
  (maxPool): MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
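A quick shape check of just the pooling step (toy input):

    import torch
    import torch.nn as nn

    x = torch.randn(32, 512, 96)                  # (batch, channels, length) for Conv1d/MaxPool1d
    pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
    print(pool(x).shape)                          # torch.Size([32, 512, 48])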

After the encoder comes the decoder. It uses the same modules as the encoder plus a cross-attention, which is standard fare, so I skip it...


Original post: https://blog.csdn.net/qq_19841133/article/details/129234584