Crosstalk-Generation/code/train.py
-
checkpoint = None
Initialize a variable checkpoint that may be used to store the saved model state. -
print('Building encoder and decoder ...')
Output a message that the encoder and decoder are being built. -
embedding = nn.Embedding(voc.n_words, hidden_size)
Create an embedding layer whose input dimension is the vocabulary size (voc.n_words) and whose output dimension is the hidden size. -
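As a quick illustration, nn.Embedding maps integer token indices to dense vectors; the sizes below are placeholders standing in for voc.n_words and hidden_size:

```python
import torch
import torch.nn as nn

# Placeholder sizes; in train.py these come from voc.n_words and hidden_size.
voc_n_words = 10   # vocabulary size (input dimension)
hidden_size = 8    # embedding dimension (output dimension)

embedding = nn.Embedding(voc_n_words, hidden_size)

# A (seq_len=3, batch=2) tensor of token indices becomes a tensor of vectors.
tokens = torch.tensor([[1, 4], [2, 5], [3, 6]])
vectors = embedding(tokens)
print(vectors.shape)  # torch.Size([3, 2, 8])
```

Because the encoder and decoder receive the same embedding object, both share one set of word vectors.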
encoder = EncoderRNN(voc.n_words, hidden_size, embedding, n_layers, dropout)
Use the EncoderRNN class to build the encoder. -
attn_model = 'dot'
Set the type of attention model to 'dot'. -
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.n_words, n_layers, dropout)
Use the LuongAttnDecoderRNN class to build the decoder. -
if loadFilename:
checkpoint = torch.load(loadFilename)
encoder.load_state_dict(checkpoint['en'])
decoder.load_state_dict(checkpoint['de'])
If loadFilename is provided, the checkpoint is loaded from that path, and the saved states are loaded into the encoder and decoder. -
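The keys used here ('en', 'de', and, later, 'en_opt', 'de_opt', 'iteration', 'plt') imply a checkpoint dict of roughly the following shape. The toy Linear modules below are stand-ins (assumptions) for the real EncoderRNN and LuongAttnDecoderRNN:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Toy stand-ins for the script's EncoderRNN and LuongAttnDecoderRNN.
encoder = nn.Linear(4, 4)
decoder = nn.Linear(4, 4)
enc_opt = optim.Adam(encoder.parameters())
dec_opt = optim.Adam(decoder.parameters())

# Save a checkpoint dict with the same keys train.py reads back.
checkpoint = {
    'iteration': 100,
    'en': encoder.state_dict(),
    'de': decoder.state_dict(),
    'en_opt': enc_opt.state_dict(),
    'de_opt': dec_opt.state_dict(),
    'plt': [42.0],  # perplexity history
}
torch.save(checkpoint, 'ckpt_sketch.tar')

# Restore exactly as the script does.
loaded = torch.load('ckpt_sketch.tar')
encoder.load_state_dict(loaded['en'])
decoder.load_state_dict(loaded['de'])
```

Saving state_dicts rather than whole module objects keeps the checkpoint portable across code changes.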
encoder = encoder.to(device)
decoder = decoder.to(device)
Place the encoder and decoder on the specified device (CPU or GPU). -
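The script relies on a device object defined earlier; a common definition (an assumption here, not quoted from train.py) is:

```python
import torch
import torch.nn as nn

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# .to(device) moves all of a module's parameters onto that device.
layer = nn.Linear(4, 4).to(device)
print(next(layer.parameters()).device)
```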
print('Building optimizers ...')
Output a message that the optimizers are being built. -
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
Create an Adam optimizer for the encoder and another for the decoder, with the decoder's learning rate scaled by decoder_learning_ratio. -
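A minimal sketch of the two-optimizer setup. The values of learning_rate and decoder_learning_ratio below are placeholders (assumptions), and the Linear modules stand in for the real encoder and decoder:

```python
import torch.nn as nn
import torch.optim as optim

learning_rate = 0.0001        # placeholder value
decoder_learning_ratio = 5.0  # placeholder value

# Toy modules standing in for the real encoder and decoder.
enc = nn.Linear(4, 4)
dec = nn.Linear(4, 4)

encoder_optimizer = optim.Adam(enc.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(dec.parameters(),
                               lr=learning_rate * decoder_learning_ratio)

print(encoder_optimizer.param_groups[0]['lr'])
print(decoder_optimizer.param_groups[0]['lr'])
```

Giving the decoder a larger learning rate is a common trick in attention-based seq2seq training, since the decoder must learn both language modeling and attention.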
if loadFilename:
encoder_optimizer.load_state_dict(checkpoint['en_opt'])
decoder_optimizer.load_state_dict(checkpoint['de_opt'])
If loadFilename is provided, the optimizer states are loaded from the checkpoint into the encoder and decoder optimizers. -
print('Initializing ...')
Print a message indicating that an initialization operation is in progress. -
start_iteration = 1
perplexity = []
print_loss = 0
Initialize bookkeeping variables: the starting iteration number, the perplexity history, and the accumulated loss for printing. -
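These three variables typically work together in the training loop: the loss is accumulated into print_loss, and every reporting interval the exponential of the average loss is appended to perplexity. The interval and loss values below are hypothetical, but the pattern is standard:

```python
import math

print_every = 5                       # assumed reporting interval
losses = [4.2, 4.0, 3.9, 3.8, 3.7]    # hypothetical per-iteration losses

print_loss = 0
perplexity = []
for it, loss in enumerate(losses, start=1):
    print_loss += loss
    if it % print_every == 0:
        # Perplexity is exp(average cross-entropy loss).
        avg_loss = print_loss / print_every
        perplexity.append(math.exp(avg_loss))
        print_loss = 0

print(perplexity)  # one entry per reporting interval
```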
if loadFilename:
start_iteration = checkpoint['iteration'] + 1
perplexity = checkpoint['plt']
If loadFilename is provided, training resumes from the iteration after the saved one, and the perplexity history is restored from the checkpoint.
In general, the main function of this code is to build and initialize the model (the encoder and decoder), build the optimizers, and, if loadFilename is provided, restore the model and optimizer states from the checkpoint.