PyTorch Lightning - Lógica de treinamento do LightningModule (training_step) tratamento de exceção try-except

Bem-vindo a seguir meu CSDN:https://spike.blog.csdn.net/
Endereço deste artigo:< a i =3>https://spike.blog.csdn.net/article/details/133673820

Módulo Lightning

Ao usar a estrutura LightningModule para treinar um modelo, os erros de treinamento causados ​​pelos dados afetam seriamente a estabilidade do treinamento, portanto, try-except precisa ser usado para detectar erros a tempo. Ou seja, quando ocorre um erro, a exceção retorna None em training_step. Ao mesmo tempo, on_before_zero_grad também precisa realizar o tratamento de exceções para tratar o retorno da exceção de < /span>training_step.Nenhum.

Da mesma forma, validation_step também pode ser tratado desta forma.

O código fonte é o seguinte:

class MyObject(pl.LightningModule):
	def __init__(self, config, args):
		# ...
		
	def training_step_wrapper(self, batch, batch_idx, log_interval=10):
		# train key process
		
	def training_step(self, batch, batch_idx, log_interval=10):
        """
        typically, each step costs 50 seconds
        参考: https://github.com/Lightning-AI/lightning/pull/3566
        """
        try:
            res = self.training_step_wrapper(batch, batch_idx, log_interval)
            return res
        except Exception as e:
            logger.info(f"[CL] training_step, exception: {
      
      e}")
            return None
            
	def on_before_zero_grad(self, *args, **kwargs):
        try:
            self.ema.update(self.model)
        except Exception as e:
            # 支持 training_step return None
            logger.info(f"[CL] on_before_zero_grad, exception: {
      
      e}")
            return
            
	def validation_step_wrapper(self, batch, batch_idx):
        # val key process

    def validation_step(self, batch, batch_idx):
        try:
            self.validation_step_wrapper(batch, batch_idx)
        except Exception as e:
            logger.info(f"[CL] validation_step, exception: {
      
      e}")
            return

Erros comuns são os seguintes

Matriz fora dos limites:

index 0 is out of bounds for dimension 0 with size 0

Campos de erro de dicionário:

num_res = int(np_example["seq_length"])
KeyError: 'seq_length'

Calcular o valor de entrada está vazio:

V, _, W = torch.linalg.svd(C)

exceção grátis():

free(): invalid next size (fast)

munmap_chunk()Ponteiro nulo:

munmap_chunk(): invalid pointer

Guess you like

Origin blog.csdn.net/u012515223/article/details/133673820