La colección métrica es una parte esencial de cada proyecto de aprendizaje automático, lo que nos permite rastrear el rendimiento del modelo y monitorear el progreso de la capacitación. Idealmente, Métrica debe recolectarse y calcularse sin introducir ninguna sobrecarga adicional en el proceso de capacitación. Sin embargo, al igual que otros componentes del bucle de entrenamiento, el cálculo métrico ineficiente puede introducir gastos generales innecesarios, aumentar los tiempos de los pasos de entrenamiento e inflar los costos de capacitación.
Esta publicación es la séptima de nuestra serie en Perfil de rendimiento y optimización en Pytorch. La serie ha tenido como objetivo enfatizar el papel crítico del análisis del rendimiento y Mejoramiento En el desarrollo del aprendizaje automático. Cada publicación se ha centrado en diferentes etapas de la tubería de capacitación, demostrando herramientas y técnicas prácticas para analizar y aumentar la utilización de recursos y la eficiencia del tiempo de ejecución.
En esta entrega, nos centramos en la colección métrica. Demostraremos cómo una implementación ingenua de la colección métrica puede afectar negativamente el rendimiento del tiempo de ejecución y explorar herramientas y técnicas para su análisis y optimización.
Para implementar nuestra colección métrica, usaremos Torchmetrics una biblioteca popular diseñada para simplificar y estandarizar el cálculo métrico en Pytorch. Nuestros objetivos serán:
- Demuestre la sobrecarga de tiempo de ejecución causada por una implementación ingenua de la colección métrica.
- Use Pytorch Profiler para identificar los cuellos de botella de rendimiento introducidos por el cálculo métrico.
- Demostrar técnicas de optimización Para reducir la sobrecarga de la colección métrica.
Para facilitar nuestra discusión, definiremos un modelo de Pytorch de juguete y evaluaremos cómo la colección métrica puede afectar su rendimiento en tiempo de ejecución. Ejecutaremos nuestros experimentos en una GPU NVIDIA A40, con una Pytorch 2.5.1 Docker imagen y Torchmetrics 1.6.1.
Es importante tener en cuenta que el comportamiento de la recolección métrica puede variar mucho según el hardware, el entorno de tiempo de ejecución y la arquitectura de modelos. Los fragmentos de código proporcionados en esta publicación están destinados solo a fines demostrativos. No interprete nuestra mención de ninguna herramienta o técnica como un respaldo para su uso.
Modelo de Toy Resnet
En el bloque de código a continuación definimos un modelo de clasificación de imagen simple con un Resnet-18 columna vertebral.
import time
import torch
import torchvision
device = "cuda"
model = torchvision.models.resnet18().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters())
Definimos un conjunto de datos sintético que usaremos para entrenar nuestro modelo de juguete.
from torch.utils.data import Dataset, DataLoader
# A dataset with random images and labels
class FakeDataset(Dataset):
def __len__(self):
return 100000000
def __getitem__(self, index):
rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
label = torch.tensor(data=index % 1000, dtype=torch.int64)
return rand_image, label
train_set = FakeDataset()
batch_size = 128
num_workers = 12
train_loader = DataLoader(
dataset=train_set,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True
)
Definimos una colección de métricas estándar de Torchmetrics, junto con un indicador de control para habilitar o deshabilitar el cálculo métrico.
from torchmetrics import (
MeanMetric,
Accuracy,
Precision,
Recall,
F1Score,
)
# toggle to enable/disable metric collection
capture_metrics = False
if capture_metrics:
metrics = {
"avg_loss": MeanMetric(),
"accuracy": Accuracy(task="multiclass", num_classes=1000),
"precision": Precision(task="multiclass", num_classes=1000),
"recall": Recall(task="multiclass", num_classes=1000),
"f1_score": F1Score(task="multiclass", num_classes=1000),
}
# Move all metrics to the device
metrics = {name: metric.to(device) for name, metric in metrics.items()}
A continuación, definimos un Perfilador de pytorch Instancia, junto con una bandera de control que nos permite habilitar o deshabilitar el perfil. Para obtener un tutorial detallado sobre el uso de Pytorch Profiler, consulte el primera publicación En esta serie.
from torch import profiler
# toggle to enable/disable profiling
enable_profiler = True
if enable_profiler:
prof = profiler.profile(
schedule=profiler.schedule(wait=10, warmup=2, active=3, repeat=1),
on_trace_ready=profiler.tensorboard_trace_handler("./logs/"),
profile_memory=True,
with_stack=True
)
prof.start()
Por último, definimos un paso de entrenamiento estándar:
model.train()
t0 = time.perf_counter()
total_time = 0
count = 0
for idx, (data, target) in enumerate(train_loader):
data = data.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if capture_metrics:
# update metrics
metrics["avg_loss"].update(loss)
for name, metric in metrics.items():
if name != "avg_loss":
metric.update(output, target)
if (idx + 1) % 100 == 0:
# compute metrics
metric_results = {
name: metric.compute().item()
for name, metric in metrics.items()
}
# print metrics
print(f"Step {idx + 1}: {metric_results}")
# reset metrics
for metric in metrics.values():
metric.reset()
elif (idx + 1) % 100 == 0:
# print last loss value
print(f"Step {idx + 1}: Loss = {loss.item():.4f}")
batch_time = time.perf_counter() - t0
t0 = time.perf_counter()
if idx > 10: # skip first steps
total_time += batch_time
count += 1
if enable_profiler:
prof.step()
if idx > 200:
break
if enable_profiler:
prof.stop()
avg_time = total_time/count
print(f'Average step time: {avg_time}')
print(f'Throughput: {batch_size/avg_time:.2f} images/sec')
Colección métrica
Para medir el impacto de la colección métrica en el paso de entrenamiento, ejecutamos nuestro script de entrenamiento con y sin cálculo métrico. Los resultados se resumen en la siguiente tabla.
¡Nuestra ingenua colección métrica resultó en una caída de casi el 10% en el rendimiento del tiempo de ejecución! Si bien la colección métrica es esencial para el desarrollo del aprendizaje automático, generalmente implica operaciones matemáticas relativamente simples y apenas garantiza una sobrecarga tan significativa. ¡¡¿Qué está pasando?!!
Identificar problemas de rendimiento con Pytorch Profiler
Para comprender mejor la fuente de la degradación del desempeño, volvemos al script de entrenamiento con el Pytorch Profiler habilitado. La traza resultante se muestra a continuación:
La traza revela operaciones recurrentes de “cudastreamsynchronize” que coinciden con gotas notables en la utilización de GPU. Estos tipos de eventos de “sincronización de cpu-gpu” se discutieron en detalle en parte segunda de nuestra serie. En un paso de entrenamiento típico, la CPU y la GPU funcionan en paralelo: la CPU administra tareas como transferencias de datos a la carga de GPU y el núcleo, y la GPU ejecuta el modelo en los datos de entrada y actualiza sus pesos. Idealmente, nos gustaría minimizar los puntos de sincronización entre la CPU y la GPU para maximizar el rendimiento. Aquí, sin embargo, podemos ver que la colección métrica ha activado un evento de sincronización realizando una copia de datos de CPU a GPU. Esto requiere que la CPU suspenda su procesamiento hasta que la GPU se ponga al día, lo que, a su vez, hace que la GPU espere a que la CPU reanude la carga de las operaciones posteriores del núcleo. La conclusión es que estos puntos de sincronización conducen a una utilización ineficiente de la CPU y la GPU. Nuestra implicación de la colección métrica agrega ocho eventos de sincronización de estos a cada paso de entrenamiento.
Un examen más detallado de la traza muestra que los eventos de sincronización provienen del actualizar Llamada de la Medio metálico Torchmetric. Para el experto en perfil de perfil, esto puede ser suficiente para identificar la causa raíz, pero iremos un paso más allá y usaremos el torch.profiler.record_function utilidad para identificar la línea de código ofensiva exacta.
Perfil con registro_function
Para identificar la fuente exacta del evento de sincronización, extendimos el Medio metálico clase y anular el actualizar método utilizando registro_function Bloques de contexto. Este enfoque nos permite perfilar operaciones individuales dentro del método e identificar cuellos de botella de rendimiento.
class ProfileMeanMetric(MeanMetric):
def update(self, value, weight = 1.0):
# broadcast weight to value shape
with profiler.record_function("process value"):
if not isinstance(value, torch.Tensor):
value = torch.as_tensor(value, dtype=self.dtype,
device=self.device)
with profiler.record_function("process weight"):
if weight is not None and not isinstance(weight, torch.Tensor):
weight = torch.as_tensor(weight, dtype=self.dtype,
device=self.device)
with profiler.record_function("broadcast weight"):
weight = torch.broadcast_to(weight, value.shape)
with profiler.record_function("cast_and_nan_check"):
value, weight = self._cast_and_nan_check_input(value, weight)
if value.numel() == 0:
return
with profiler.record_function("update value"):
self.mean_value += (value * weight).sum()
with profiler.record_function("update weight"):
self.weight += weight.sum()
Luego actualizamos nuestra métrica AVG_LOSS para usar el recién creado ProfilineMeMeMetric y Reran el script de entrenamiento.
La traza actualizada revela que el evento de sincronización se origina en la siguiente línea:
weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)
Esta operación convierte el valor escalar predeterminado weight=1.0
en un tensor de Pytorch y lo coloca en la GPU. El evento de sincronización ocurre porque esta acción desencadena una copia de datos de CPU a GPU, que requiere que la CPU espere a que la GPU procese el valor copiado.
Optimización 1: especificar el valor de peso
Ahora que hemos encontrado la fuente del problema, podemos superarlo fácilmente especificando un peso valor en nuestro actualizar llamar. Esto evita que el tiempo de ejecución convierta el escalar predeterminado weight=1.0
en un tensor en la GPU, evitando el evento de sincronización:
# update metrics
if capture_metric:
metrics["avg_loss"].update(loss, weight=torch.ones_like(loss))
Volver a ejecutar el guión después de aplicar este cambio revela que hemos logrado eliminar el evento de sincronización inicial … solo para haber descubierto uno nuevo, esta vez proveniente del _cast_and_nan_check_input función:
Perfil con registro_function – Parte 2
Para explorar nuestro nuevo evento de sincronización, ampliamos nuestra métrica personalizada con sondas de perfiles adicionales y Reran nuestro script.
class ProfileMeanMetric(MeanMetric):
def update(self, value, weight = 1.0):
# broadcast weight to value shape
with profiler.record_function("process value"):
if not isinstance(value, torch.Tensor):
value = torch.as_tensor(value, dtype=self.dtype,
device=self.device)
with profiler.record_function("process weight"):
if weight is not None and not isinstance(weight, torch.Tensor):
weight = torch.as_tensor(weight, dtype=self.dtype,
device=self.device)
with profiler.record_function("broadcast weight"):
weight = torch.broadcast_to(weight, value.shape)
with profiler.record_function("cast_and_nan_check"):
value, weight = self._cast_and_nan_check_input(value, weight)
if value.numel() == 0:
return
with profiler.record_function("update value"):
self.mean_value += (value * weight).sum()
with profiler.record_function("update weight"):
self.weight += weight.sum()
def _cast_and_nan_check_input(self, x, weight = None):
"""Convert input ``x`` to a tensor and check for Nans."""
with profiler.record_function("process x"):
if not isinstance(x, torch.Tensor):
x = torch.as_tensor(x, dtype=self.dtype,
device=self.device)
with profiler.record_function("process weight"):
if weight is not None and not isinstance(weight, torch.Tensor):
weight = torch.as_tensor(weight, dtype=self.dtype,
device=self.device)
nans = torch.isnan(x)
if weight is not None:
nans_weight = torch.isnan(weight)
else:
nans_weight = torch.zeros_like(nans).bool()
weight = torch.ones_like(x)
with profiler.record_function("any nans"):
anynans = nans.any() or nans_weight.any()
with profiler.record_function("process nans"):
if anynans:
if self.nan_strategy == "error":
raise RuntimeError("Encountered `nan` values in tensor")
if self.nan_strategy in ("ignore", "warn"):
if self.nan_strategy == "warn":
print("Encountered `nan` values in tensor."
" Will be removed.")
x = x[~(nans | nans_weight)]
weight = weight[~(nans | nans_weight)]
else:
if not isinstance(self.nan_strategy, float):
raise ValueError(f"`nan_strategy` shall be float"
f" but you pass {self.nan_strategy}")
x[nans | nans_weight] = self.nan_strategy
weight[nans | nans_weight] = self.nan_strategy
with profiler.record_function("return value"):
retval = x.to(self.dtype), weight.to(self.dtype)
return retval
La traza resultante se captura a continuación:
El rastreo apunta directamente a la línea ofensiva:
anynans = nans.any() or nans_weight.any()
Esta operación verifica para NaN
Valores en los tensores de entrada, pero introduce un costoso evento de sincronización de CPU-GPU porque la operación implica copiar datos de la GPU a la CPU.
Sobre una inspección más cercana de la torchmetric Baseggregador Clase, encontramos varias opciones para manejar actualizaciones de valor NAN, todas las cuales pasan a través de la línea de código ofensiva. Sin embargo, para nuestro caso de uso, calculando la métrica de pérdida promedio, esta verificación es innecesaria y no justifica la penalización de rendimiento del tiempo de ejecución.
Optimización 2: Deshabilitar las comprobaciones de valor NAN
Para eliminar la sobrecarga, proponemos deshabilitar el NaN
verificaciones de valor anulando el _cast_and_nan_check_input
función. En lugar de una anulación estática, implementamos una solución dinámica que puede aplicarse de manera flexible a cualquier descendiente del Baseggregadorclase.
from torchmetrics.aggregation import BaseAggregator
def suppress_nan_check(MetricClass):
assert issubclass(MetricClass, BaseAggregator), MetricClass
class DisableNanCheck(MetricClass):
def _cast_and_nan_check_input(self, x, weight=None):
if not isinstance(x, torch.Tensor):
x = torch.as_tensor(x, dtype=self.dtype,
device=self.device)
if weight is not None and not isinstance(weight, torch.Tensor):
weight = torch.as_tensor(weight, dtype=self.dtype,
device=self.device)
if weight is None:
weight = torch.ones_like(x)
return x.to(self.dtype), weight.to(self.dtype)
return DisableNanCheck
NoNanMeanMetric = suppress_nan_check(MeanMetric)
metrics["avg_loss"] = NoNanMeanMetric().to(device)
Resultados de la optimización posterior: éxito
Después de implementar las dos optimizaciones, especificar el valor de peso y deshabilitar el NaN
Verificaciones: encontramos el rendimiento del tiempo de paso y la utilización de la GPU para que coincida con las de nuestro experimento de referencia. Además, el rastro de perfilador de Pytorch resultante muestra que se han eliminado todos los eventos adicionales de “cudastreamsynchronize” que estaban asociados con la colección métrica. Con algunos pequeños cambios, hemos reducido el costo de la capacitación en ~ 10% sin ningún cambio en el comportamiento de la colección métrica.
En la siguiente sección exploraremos una optimización de colección métrica adicional.
Ejemplo 2: Optimización de la colocación de dispositivos métricos
En la sección anterior, los valores métricos residían en la GPU, lo que hace que sea lógico almacenar y calcular las métricas en la GPU. Sin embargo, en escenarios en los que los valores que deseamos residir en la CPU, podría ser preferible almacenar las métricas en la CPU para evitar transferencias innecesarias de dispositivos.
En el bloque de código a continuación, modificamos nuestro script para calcular el tiempo de paso promedio usando un Medio metálico en la CPU. Este cambio no tiene impacto en el rendimiento del tiempo de ejecución de nuestro paso de entrenamiento:
avg_time = NoNanMeanMetric()
t0 = time.perf_counter()
for idx, (data, target) in enumerate(train_loader):
# move data to device
data = data.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if capture_metrics:
metrics["avg_loss"].update(loss)
for name, metric in metrics.items():
if name != "avg_loss":
metric.update(output, target)
if (idx + 1) % 100 == 0:
# compute metrics
metric_results = {
name: metric.compute().item()
for name, metric in metrics.items()
}
# print metrics
print(f"Step {idx + 1}: {metric_results}")
# reset metrics
for metric in metrics.values():
metric.reset()
elif (idx + 1) % 100 == 0:
# print last loss value
print(f"Step {idx + 1}: Loss = {loss.item():.4f}")
batch_time = time.perf_counter() - t0
t0 = time.perf_counter()
if idx > 10: # skip first steps
avg_time.update(batch_time)
if enable_profiler:
prof.step()
if idx > 200:
break
if enable_profiler:
prof.stop()
avg_time = avg_time.compute().item()
print(f'Average step time: {avg_time}')
print(f'Throughput: {batch_size/avg_time:.2f} images/sec')
El problema surge cuando intentamos extender nuestro guión para apoyar la capacitación distribuida. Para demostrar el problema, modificamos nuestra definición de modelo para usar DistributedDataparallel (DDP) :
# toggle to enable/disable ddp
use_ddp = True
if use_ddp:
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)
model = DDP(torchvision.models.resnet18().to(device))
else:
model = torchvision.models.resnet18().to(device)
# insert training loop
# append to end of the script:
if use_ddp:
# destroy the process group
dist.destroy_process_group()
La modificación DDP da como resultado el siguiente error:
RuntimeError: No backend type associated with device type cpu
Por defecto, las métricas en la capacitación distribuida están programadas para sincronizar en todos los dispositivos en uso. Sin embargo, el backend de sincronización utilizado por DDP no admite métricas almacenadas en la CPU.
Una forma de resolver esto es deshabilitar la sincronización métrica de servicio cruzado:
avg_time = NoNanMeanMetric(sync_on_compute=False)
En nuestro caso, donde estamos midiendo el tiempo promedio, esta solución es aceptable. Sin embargo, en algunos casos, la sincronización métrica es esencial, y es posible que no tengamos más remedio que mover la métrica a la GPU:
avg_time = NoNanMeanMetric().to(device)
Desafortunadamente, esta situación da lugar a un nuevo evento de sincronización de CPU-GPU que proviene del actualizar función.
Este evento de sincronización difícilmente debería ser una sorpresa: después de todo, estamos actualizando una métrica de GPU con un valor que reside en la CPU, que debería requerir una copia de memoria. Sin embargo, en el caso de una métrica escalar, esta transferencia de datos se puede evitar completamente con una optimización simple.
Optimización 3: Realice actualizaciones métricas con tensores en lugar de escalar
La solución es sencilla: en lugar de actualizar la métrica con un valor flotante, nos convertimos a un tensor antes de llamar update
.
batch_time = torch.as_tensor(batch_time)
avg_time.update(batch_time, torch.ones_like(batch_time))
Este cambio menor evita la línea problemática de código, elimina el evento de sincronización y restaura el paso de paso al rendimiento de línea de base.
A primera vista, este resultado puede parecer sorprendente: esperaríamos que actualizar una métrica de GPU con un tensor de CPU aún requiera una copia de memoria. Sin embargo, Pytorch optimiza las operaciones en tensores escalares mediante el uso de un núcleo dedicado que realiza la adición sin una transferencia de datos explícita. Esto evita el costoso evento de sincronización que de otro modo ocurriría.
Resumen
En esta publicación, exploramos cómo un enfoque ingenuo de la antorchmetrics puede introducir eventos de sincronización de CPU-GPU y degradar significativamente el rendimiento de entrenamiento de Pytorch. Usando Pytorch Profiler, identificamos las líneas de código responsables de estos eventos de sincronización y aplicó optimizaciones dirigidas para eliminarlos:
- Especifique explícitamente un tensor de peso al llamar al
MeanMetric.update
función en lugar de confiar en el valor predeterminado. - Desactivar los controles NAN en la base
Aggregator
Clase o reemplácelos con una alternativa más eficiente. - Administre cuidadosamente la colocación del dispositivo de cada métrica para minimizar las transferencias innecesarias.
- Deshabilite la sincronización métrica de servicio cruzado cuando no sea necesario.
- Cuando la métrica reside en una GPU, convierta los escalares de punto flotante en tensores antes de pasarlos al
update
función para evitar la sincronización implícita.
Hemos creado un dedicado Solicitud de solicituden el Torchmetrics GithubPágina que cubre algunas de las optimizaciones discutidas en esta publicación. ¡No dude en contribuir con sus propias mejoras y optimizaciones!