Estás entrenando a tu último modelo de IA, observando ansiosamente cómo la pérdida disminuye constantemente cuando de repente: ¡BOOM! Sus registros están inundados de NANS (no un número): su modelo está irreparablemente corrompido y te quedan mirando tu pantalla con desesperación. Para empeorar las cosas, los Nans no aparecen constantemente. A veces tu modelo entrena bien; Otras veces, falla inexplicablemente. A veces se estrellará de inmediato, a veces después de muchos días de entrenamiento.
Nans en Aprendizaje profundo Las cargas de trabajo se encuentran entre los problemas más frustrantes para encontrar. Y debido a que a menudo aparecen esporádicamente, desencadenados por una combinación específica de estado modelo, datos de entrada y factores estocásticos, pueden ser increíblemente difícil de reproducir y depurar.
Dado el considerable costo de capacitación de modelos de IA y los posibles residuos causados por fallas de NAN, se recomienda tener herramientas dedicadas para capturar y analizar los sucesos NAN. En publicación anteriordiscutimos el desafío de depurar NANS en una carga de trabajo de entrenamiento TensorFlow. Propusimos un esquema eficiente para capturar y reproducir NANS y compartimos una implementación de flujo de tensor de muestra. En esta publicación, adoptamos y demostramos un mecanismo similar para la depuración de NANS en las cargas de trabajo de Pytorch. El esquema general es el siguiente:
En cada paso de entrenamiento:
- Guarde una copia del lote de entrada de entrenamiento.
- Verifique los gradientes en busca de valores de NAN. Si aparece alguno, guarde un punto de control con los pesos del modelo actual antes de que el modelo esté dañado. Además, guarde el lote de entrada y, si es necesario, el estado estocástico. Discontinuar el trabajo de capacitación.
- Reproducir y depurar la ocurrencia NAN cargando el estado de experimento guardado.
Aunque este esquema se puede implementar fácilmente en Pytorch nativo, aprovecharemos la oportunidad para demostrar algunas de las comodidades de Pytorch Lightning -Un potente marco de código abierto diseñado para optimizar el desarrollo de modelos de aprendizaje automático (ML). Construido en Pytorch, Lightning resume muchos de los componentes de placa de caldera de un experimento ML, como bucles de entrenamiento, distribución de datos, registro y más, lo que permite a los desarrolladores centrarse en la lógica central de sus modelos.
Para implementar nuestro esquema de captura nan, usaremos Backback de Lightning Interfaz: una estructura dedicada que permite insertar una lógica personalizada en puntos específicos durante el flujo de ejecución.
Es importante destacar que no vea nuestra elección de rayos o ninguna otra herramienta o técnica que mencionamos como un respaldo de su uso. El código que compartiremos está destinado a fines demostrativos: no confíe en su corrección u optimización.
Muchas gracias a Rom Maltser por sus contribuciones a esta publicación.
Devolución de llamada de Nancapture
Para implementar nuestra solución de captura NAN, creamos una devolución de llamada Nancapture Lightning. El constructor recibe una ruta de directorio para almacenar/cargar puntos de control y establece el estado de Nancapture. También definimos los servicios públicos para verificar los NANS, almacenar puntos de control y detener el trabajo de capacitación.
import os
import torch
from copy import deepcopy
import lightning.pytorch as pl
class NaNCapture(pl.Callback):
def __init__(self, dirpath: str):
# path to checkpoint
self.dirpath = dirpath
# update to True when Nan is identified
self.nan_captured = False
# stores a copy of the last batch
self.last_batch = None
self.batch_idx = None
@staticmethod
def contains_nan(tensor):
return torch.isnan(tensor).any().item()
# alternatively check for finite
# return not torch.isfinite(tensor).item()
@staticmethod
def halt_training(trainer):
trainer.should_stop = True
# communicate stop command to all other ranks
trainer.strategy.reduce_boolean_decision(trainer.should_stop,
all=False)
def save_ckpt(self, trainer):
os.makedirs(self.dirpath, exist_ok=True)
# include trainer.global_rank to avoid conflict
filename = f"nan_checkpoint_rank_{trainer.global_rank}.ckpt"
full_path = os.path.join(self.dirpath, filename)
print(f"saving ckpt to {full_path}")
trainer.save_checkpoint(full_path, False)
Función de devolución de llamada: on_train_batch_start
Comenzamos implementando el on_train_batch_start enganchar para almacenar una copia de cada lote de entrada. En caso de un evento NAN, este lote se almacenará en el punto de control.
Función de devolución de llamada: on_before_optimizer_step
A continuación implementamos el on_before_optimizer_step gancho. Aquí, verificamos las entradas de NAN en todos los tensores de gradiente. Si se encuentra, almacenamos un punto de control con los pesos del modelo no corruptos y detenemos el entrenamiento.
Python"> def on_before_optimizer_step(self, trainer, pl_module, optimizer):
if not self.nan_captured:
# Check if gradients contain NaN
grads = [p.grad.view(-1) for p in pl_module.parameters()
if p.grad is not None]
all_grads = torch.cat(grads)
if self.contains_nan(all_grads):
print("nan found")
self.save_ckpt(trainer)
self.halt_training(trainer)
Capturando el estado de entrenamiento
Para habilitar la reproducibilidad, incluimos el estado de Nancapture en el punto de control al agregarlo al diccionario estatal de capacitación. Lightning proporciona utilidades dedicadas para guardar y cargar un Estado de devolución de llamada:
def state_dict(self):
d = {"nan_captured": self.nan_captured}
if self.nan_captured:
d["last_batch"] = self.last_batch
return d
def load_state_dict(self, state_dict):
self.nan_captured = state_dict.get("nan_captured", False)
if self.nan_captured:
self.last_batch = state_dict["last_batch"]
Reproduciendo la ocurrencia nan
Hemos descrito cómo nuestra devolución de llamada Nancapture puede usarse para almacenar el estado de capacitación que resultó en una NAN, pero ¿cómo recargamos este estado para reproducir el problema y depurarlo? Para lograr esto, aprovechamos la clase de carga de datos dedicada de Lightning, Lightningdatamodule.
Función DataModule: on_before_batch_transfer
En el bloque de código a continuación, extendemos el Lightningdatamodule clase para permitir inyectar un lote de entrada de entrenamiento fijo. Esto se logra anulando el on_before_batch_transfer gancho, como se muestra a continuación:
from lightning.pytorch import LightningDataModule
class InjectableDataModule(LightningDataModule):
def __init__(self):
super().__init__()
self.cached_batch = None
def set_custom_batch(self, batch):
self.cached_batch = batch
def on_before_batch_transfer(self, batch, dataloader_idx):
if self.cached_batch:
return self.cached_batch
return batch
Función de devolución de llamada: on_train_start
El paso final es modificar el on_train_start gancho de nuestra devolución de llamada de Nancapture para inyectar el lote de entrenamiento almacenado en el Lightningdatamodule.
def on_train_start(self, trainer, pl_module):
if self.nan_captured:
datamodule = trainer.datamodule
datamodule.set_custom_batch(self.last_batch)
En la siguiente sección demostraremos la solución de extremo a extremo utilizando un ejemplo de juguete.
Ejemplo de juguete
Para probar nuestra nueva devolución de llamada, creamos un resnet50-El modelo de clasificación de imagen basada en una función de pérdida diseñada deliberadamente para activar ocurrencias NAN.
En lugar de usar el estándar Cruzado Pérdida, calculamos binary_cross_entropy_with_logits para cada clase de forma independiente y divide el resultado por el número de muestras pertenecientes a esa clase. Inevitablemente, nos encontraremos con un lote en el que faltan una o más clases, lo que lleva a una operación divide por cero, lo que resulta en valores de NAN y corrompe el modelo.
La implementación a continuación sigue a los rayos tutorial introductorio.
import lightning.pytorch as pl
import torch
import torchvision
import torch.nn.functional as F
num_classes = 20
# define a lightning module
class ResnetModel(pl.LightningModule):
def __init__(self):
"""Initializes a new instance of the MNISTModel class."""
super().__init__()
self.model = torchvision.models.resnet50(num_classes=num_classes)
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_nb):
x, y = batch
outputs = self(x)
# uncomment for default loss
# return F.cross_entropy(outputs, y)
# calculate binary_cross_entropy for each class individually
losses = []
for c in range(num_classes):
count = torch.count_nonzero(y==c)
masked = torch.where(y==c, 1., 0.)
loss = F.binary_cross_entropy_with_logits(
outputs[..., c],
masked,
reduction='sum'
)
mean_loss = loss/count # could result in NaN
losses.append(mean_loss)
total_loss = torch.stack(losses).mean()
return total_loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
Definimos un conjunto de datos sintético y lo encapsulamos en nuestro InjectableDataModule clase:
import os
import random
from torch.utils.data import Dataset, DataLoader
batch_size = 128
num_steps = 800
# A dataset with random images and labels
class FakeDataset(Dataset):
def __len__(self):
return batch_size*num_steps
def __getitem__(self, index):
rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
label = torch.tensor(random.randint(0, num_classes-1),
dtype=torch.int64)
return rand_image, label
# define a lightning datamodule
class FakeDataModule(InjectableDataModule):
def train_dataloader(self):
dataset = FakeDataset()
return DataLoader(
dataset,
batch_size=batch_size,
num_workers=os.cpu_count(),
pin_memory=True
)
Finalmente, inicializamos un rayo Entrenador con nuestra devolución de llamada de Nancapture y entrenador de llamadas con nuestro módulo Lightning y Lightning DataModule.
import time
if __name__ == "__main__":
# Initialize a lightning module
lit_module = ResnetModel()
# Initialize a DataModule
mnist_data = FakeDataModule()
# Train the model
ckpt_dir = "./ckpt_dir"
trainer = pl.Trainer(
max_epochs=1,
callbacks=[NaNCapture(ckpt_dir)]
)
ckpt_path = None
# check is nan ckpt exists
if os.path.isdir(ckpt_dir):
# check if nan ckpt exists
if os.path.isdir(ckpt_dir):
dir_contents = [os.path.join(ckpt_dir, f)
for f in os.listdir(ckpt_dir)]
ckpts = [f for f in dir_contents
if os.path.isfile(f) and f.endswith('.ckpt')]
if ckpts:
ckpt_path = ckpts[0]
t0 = time.perf_counter()
trainer.fit(lit_module, mnist_data, ckpt_path=ckpt_path)
print(f"total runtime: {time.perf_counter() - t0}")
Después de varios pasos de entrenamiento, ocurrirá un evento NAN. En este punto, se guarda un punto de control con el estado de entrenamiento completo y se detiene la capacitación.
Cuando el script se ejecute nuevamente, el estado exacto que causó el NAN se volverá a cargarnos, lo que nos permite reproducir fácilmente el problema y depurar su causa raíz.
Sobrecarga de rendimiento
Para evaluar el impacto de nuestra devolución de llamada Nancapture en el rendimiento del tiempo de ejecución, modificamos nuestro experimento para usar Cruzado (para evitar NANS) y midió el rendimiento promedio cuando se ejecuta con y sin devolución de llamada de Nancapture. Los experimentos se realizaron en un GPU NVIDIA L40Scon un Pytorch 2.5.1 Docker imagen.
Para nuestro modelo de juguete, la devolución de llamada de Nancapture agrega una sobrecarga mínima del 1.5% al rendimiento del tiempo de ejecución, un pequeño precio a pagar por las valiosas capacidades de depuración que proporciona.
Naturalmente, la sobrecarga real dependerá de los detalles del modelo y el entorno de tiempo de ejecución.
Cómo manejar la estocasticidad
La solución que hemos descrito en adelante tendrá éxito en reproducir el estado de entrenamiento siempre que el modelo no incluya ninguna aleatoriedad. Sin embargo, la introducción de estocasticidad en la definición del modelo a menudo es crítica para la convergencia. Un ejemplo común de una capa estocástica es antorch.nn.dropout.
Puede encontrar que su evento NAN depende del estado preciso de aleatoriedad cuando ocurrió la falla. En consecuencia, nos gustaría mejorar nuestra devolución de llamada Nancapture para capturar y restaurar el estado aleatorio en el punto de falla. El estado aleatorio está determinado por varias bibliotecas. En el bloque de código a continuación, intentamos capturar el estado completo de aleatoriedad:
import os
import torch
import random
import numpy as np
from copy import deepcopy
import lightning.pytorch as pl
class NaNCapture(pl.Callback):
def __init__(self, dirpath: str):
# path to checkpoint
self.dirpath = dirpath
# update to True when Nan is identified
self.nan_captured = False
# stores a copy of the last batch
self.last_batch = None
self.batch_idx = None
# rng state
self.rng_state = {
"torch": None,
"torch_cuda": None,
"numpy": None,
"random": None
}
@staticmethod
def contains_nan(tensor):
return torch.isnan(tensor).any().item()
# alternatively check for finite
# return not torch.isfinite(tensor).item()
@staticmethod
def halt_training(trainer):
trainer.should_stop = True
trainer.strategy.reduce_boolean_decision(trainer.should_stop,
all=False)
def save_ckpt(self, trainer):
os.makedirs(self.dirpath, exist_ok=True)
# include trainer.global_rank to avoid conflict
filename = f"nan_checkpoint_rank_{trainer.global_rank}.ckpt"
full_path = os.path.join(self.dirpath, filename)
print(f"saving ckpt to {full_path}")
trainer.save_checkpoint(full_path, False)
def on_train_start(self, trainer, pl_module):
if self.nan_captured:
# inject batch
datamodule = trainer.datamodule
datamodule.set_custom_batch(self.last_batch)
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
if self.nan_captured:
# restore random state
torch.random.set_rng_state(self.rng_state["torch"])
torch.cuda.set_rng_state_all(self.rng_state["torch_cuda"])
np.random.set_state(self.rng_state["numpy"])
random.setstate(self.rng_state["random"])
else:
# capture current batch
self.last_batch= deepcopy(batch)
self.batch_idx = batch_idx
# capture current random state
self.rng_state["torch"] = torch.random.get_rng_state()
self.rng_state["torch_cuda"] = torch.cuda.get_rng_state_all()
self.rng_state["numpy"] = np.random.get_state()
self.rng_state["random"] = random.getstate()
def on_before_optimizer_step(self, trainer, pl_module, optimizer):
if not self.nan_captured:
# Check if gradients contain NaN
grads = [p.grad.view(-1) for p in pl_module.parameters()
if p.grad is not None]
all_grads = torch.cat(grads)
if self.contains_nan(all_grads):
print("nan found")
self.save_ckpt(trainer)
self.halt_training(trainer)
def state_dict(self):
d = {"nan_captured": self.nan_captured}
if self.nan_captured:
d["last_batch"] = self.last_batch
d["rng_state"] = self.rng_state
return d
def load_state_dict(self, state_dict):
self.nan_captured = state_dict.get("nan_captured", False)
if self.nan_captured:
self.last_batch = state_dict["last_batch"]
self.rng_state = state_dict["rng_state"]
Es importante destacar que establecer el estado aleatorio puede no garantizar completos reproducibilidad. La GPU debe su poder a su paralelismo masivo. En algunas operaciones de GPU, múltiples subprocesos pueden leer o escribir simultáneamente en las mismas ubicaciones de memoria, lo que resulta en un no determinismo. Pytorch permite cierto control sobre esto a través de su use_deterministic_algorithmspero esto puede afectar el rendimiento del tiempo de ejecución. Además, existe la posibilidad de que el evento NAN no se reproduzca una vez que se cambie esta configuración. Consulte la documentación de Pytorch en reproducibilidad Para más detalles.
Resumen
Encontrar fallas de NAN es uno de los eventos más desalentadores que pueden ocurrir en el desarrollo del aprendizaje automático. Estos errores no solo desperdician valiosos recursos de cálculo y desarrollo, sino que a menudo indican problemas fundamentales en la arquitectura del modelo o el diseño del experimento. Debido a su naturaleza esporádica, a veces esquiva, la depuración de fallas NAN puede ser una pesadilla.
Esta publicación introdujo un enfoque proactivo para capturar y reproducir errores NAN utilizando una devolución de llamada dedicada. La solución que compartimos es una propuesta que puede modificarse y extenderse para su caso de uso específico.
Si bien esta solución puede no abordar cada escenario NAN posible, reduce significativamente el tiempo de depuración cuando corresponde, lo que puede ahorrar a los desarrolladores innumerables horas de frustración y esfuerzo desperdiciado.