Pytorch empilhando várias perdas causa explosão de memória

Quando executei o código esses dias, mostrou que fui morto (a pessoa inteira não é boa). Verifique o log do sistema e constate que a memória não é suficiente (sem memória), não há... jeito... método..., apenas desista! Claro que isso é impossível, como pode o autor ser uma pessoa que desiste levianamente, haha.

Mais perto de casa, o dispositivo usado pelo autor é um disco rígido de 3T e o programa em execução batch_size=1024, que é dividido em 2.000 lotes no total, e a taxa de uso de memória de cada lote aumentará em cerca de 0,5%. , só posso depurar uma frase por vez. Por fim, foi descoberto que foi causado pelo acúmulo de perdas. Conforme mostrado abaixo, o código calculou três perdas: BPR Loss, Reg Loss e InfoNCE Loss, que não podem ser diretamente acumulado como total_loss! Em vez disso, o valor da perda é retirado por meio de .item() e, em seguida, acumulado.

# BPR Loss
bpr_loss = -torch.sum(F.logsigmoid(sup_logits)) 

# Reg Loss
reg_loss = l2_loss(
self.lightgcn.user_embeddings(bat_users),
self.lightgcn.item_embeddings(bat_pos_items),
self.lightgcn.item_embeddings(bat_neg_items),
)
                
# InfoNCE Loss
clogits_user = torch.logsumexp(ssl_logits_user / self.ssl_temp, dim=1)
clogits_item = torch.logsumexp(ssl_logits_item / self.ssl_temp, dim=1)
infonce_loss = torch.sum(clogits_user + clogits_item)
    
loss = bpr_loss + self.ssl_reg * infonce_loss + self.reg * reg_loss

total_loss = total_loss + loss.item() 
total_bpr_loss += bpr_loss.item()
total_reg_loss += self.reg * reg_loss.item()
               
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

Supongo que te gusta

Origin blog.csdn.net/qq_42018521/article/details/131612362
Recomendado
Clasificación