Colección métrica eficiente en Pytorch: evitando las trampas de rendimiento de Torchmetrics

La colección métrica es una parte esencial de cada proyecto de aprendizaje automático, lo que nos permite rastrear el rendimiento del modelo y monitorear el progreso de la capacitación. Idealmente, Métrica debe recolectarse y calcularse sin introducir ninguna sobrecarga adicional en el proceso de capacitación. Sin embargo, al igual que otros componentes del bucle de entrenamiento, el cálculo métrico ineficiente puede introducir gastos generales innecesarios, aumentar los tiempos de los pasos de entrenamiento e inflar los costos de capacitación.

Esta publicación es la séptima de nuestra serie en Perfil de rendimiento y optimización en Pytorch. La serie ha tenido como objetivo enfatizar el papel crítico del análisis del rendimiento y Mejoramiento En el desarrollo del aprendizaje automático. Cada publicación se ha centrado en diferentes etapas de la tubería de capacitación, demostrando herramientas y técnicas prácticas para analizar y aumentar la utilización de recursos y la eficiencia del tiempo de ejecución.

En esta entrega, nos centramos en la colección métrica. Demostraremos cómo una implementación ingenua de la colección métrica puede afectar negativamente el rendimiento del tiempo de ejecución y explorar herramientas y técnicas para su análisis y optimización.

Para implementar nuestra colección métrica, usaremos Torchmetrics una biblioteca popular diseñada para simplificar y estandarizar el cálculo métrico en Pytorch. Nuestros objetivos serán:

  1. Demuestre la sobrecarga de tiempo de ejecución causada por una implementación ingenua de la colección métrica.
  2. Use Pytorch Profiler para identificar los cuellos de botella de rendimiento introducidos por el cálculo métrico.
  3. Demostrar técnicas de optimización Para reducir la sobrecarga de la colección métrica.

Para facilitar nuestra discusión, definiremos un modelo de Pytorch de juguete y evaluaremos cómo la colección métrica puede afectar su rendimiento en tiempo de ejecución. Ejecutaremos nuestros experimentos en una GPU NVIDIA A40, con una Pytorch 2.5.1 Docker imagen y Torchmetrics 1.6.1.

Es importante tener en cuenta que el comportamiento de la recolección métrica puede variar mucho según el hardware, el entorno de tiempo de ejecución y la arquitectura de modelos. Los fragmentos de código proporcionados en esta publicación están destinados solo a fines demostrativos. No interprete nuestra mención de ninguna herramienta o técnica como un respaldo para su uso.

Modelo de Toy Resnet

En el bloque de código a continuación definimos un modelo de clasificación de imagen simple con un Resnet-18 columna vertebral.

import time
import torch
import torchvision

device = "cuda"

model = torchvision.models.resnet18().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters())

Definimos un conjunto de datos sintético que usaremos para entrenar nuestro modelo de juguete.

from torch.utils.data import Dataset, DataLoader

# A dataset with random images and labels
class FakeDataset(Dataset):
    def __len__(self):
        return 100000000

    def __getitem__(self, index):
        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
        label = torch.tensor(data=index % 1000, dtype=torch.int64)
        return rand_image, label

train_set = FakeDataset()

batch_size = 128
num_workers = 12

train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True
)

Definimos una colección de métricas estándar de Torchmetrics, junto con un indicador de control para habilitar o deshabilitar el cálculo métrico.

from torchmetrics import (
    MeanMetric,
    Accuracy,
    Precision,
    Recall,
    F1Score,
)

# toggle to enable/disable metric collection
capture_metrics = False

if capture_metrics:
        metrics = {
        "avg_loss": MeanMetric(),
        "accuracy": Accuracy(task="multiclass", num_classes=1000),
        "precision": Precision(task="multiclass", num_classes=1000),
        "recall": Recall(task="multiclass", num_classes=1000),
        "f1_score": F1Score(task="multiclass", num_classes=1000),
    }

    # Move all metrics to the device
    metrics = {name: metric.to(device) for name, metric in metrics.items()}

A continuación, definimos un Perfilador de pytorch Instancia, junto con una bandera de control que nos permite habilitar o deshabilitar el perfil. Para obtener un tutorial detallado sobre el uso de Pytorch Profiler, consulte el primera publicación En esta serie.

from torch import profiler

# toggle to enable/disable profiling
enable_profiler = True

if enable_profiler:
    prof = profiler.profile(
        schedule=profiler.schedule(wait=10, warmup=2, active=3, repeat=1),
        on_trace_ready=profiler.tensorboard_trace_handler("./logs/"),
        profile_memory=True,
        with_stack=True
    )
    prof.start()

Por último, definimos un paso de entrenamiento estándar:

model.train()

t0 = time.perf_counter()
total_time = 0
count = 0

for idx, (data, target) in enumerate(train_loader):
    data = data.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    if capture_metrics:
        # update metrics
        metrics["avg_loss"].update(loss)
        for name, metric in metrics.items():
            if name != "avg_loss":
                metric.update(output, target)

        if (idx + 1) % 100 == 0:
            # compute metrics
            metric_results = {
                name: metric.compute().item() 
                    for name, metric in metrics.items()
            }
            # print metrics
            print(f"Step {idx + 1}: {metric_results}")
            # reset metrics
            for metric in metrics.values():
                metric.reset()

    elif (idx + 1) % 100 == 0:
        # print last loss value
        print(f"Step {idx + 1}: Loss = {loss.item():.4f}")

    batch_time = time.perf_counter() - t0
    t0 = time.perf_counter()
    if idx > 10:  # skip first steps
        total_time += batch_time
        count += 1

    if enable_profiler:
        prof.step()

    if idx > 200:
        break

if enable_profiler:
    prof.stop()

avg_time = total_time/count
print(f'Average step time: {avg_time}')
print(f'Throughput: {batch_size/avg_time:.2f} images/sec')

Colección métrica

Para medir el impacto de la colección métrica en el paso de entrenamiento, ejecutamos nuestro script de entrenamiento con y sin cálculo métrico. Los resultados se resumen en la siguiente tabla.

La sobrecarga de la colección métrica ingenua (por autor)

¡Nuestra ingenua colección métrica resultó en una caída de casi el 10% en el rendimiento del tiempo de ejecución! Si bien la colección métrica es esencial para el desarrollo del aprendizaje automático, generalmente implica operaciones matemáticas relativamente simples y apenas garantiza una sobrecarga tan significativa. ¡¡¿Qué está pasando?!!

Identificar problemas de rendimiento con Pytorch Profiler

Para comprender mejor la fuente de la degradación del desempeño, volvemos al script de entrenamiento con el Pytorch Profiler habilitado. La traza resultante se muestra a continuación:

Trace de experimento de colección métrica (por autor)

La traza revela operaciones recurrentes de “cudastreamsynchronize” que coinciden con gotas notables en la utilización de GPU. Estos tipos de eventos de “sincronización de cpu-gpu” se discutieron en detalle en parte segunda de nuestra serie. En un paso de entrenamiento típico, la CPU y la GPU funcionan en paralelo: la CPU administra tareas como transferencias de datos a la carga de GPU y el núcleo, y la GPU ejecuta el modelo en los datos de entrada y actualiza sus pesos. Idealmente, nos gustaría minimizar los puntos de sincronización entre la CPU y la GPU para maximizar el rendimiento. Aquí, sin embargo, podemos ver que la colección métrica ha activado un evento de sincronización realizando una copia de datos de CPU a GPU. Esto requiere que la CPU suspenda su procesamiento hasta que la GPU se ponga al día, lo que, a su vez, hace que la GPU espere a que la CPU reanude la carga de las operaciones posteriores del núcleo. La conclusión es que estos puntos de sincronización conducen a una utilización ineficiente de la CPU y la GPU. Nuestra implicación de la colección métrica agrega ocho eventos de sincronización de estos a cada paso de entrenamiento.

Un examen más detallado de la traza muestra que los eventos de sincronización provienen del actualizar Llamada de la Medio metálico Torchmetric. Para el experto en perfil de perfil, esto puede ser suficiente para identificar la causa raíz, pero iremos un paso más allá y usaremos el torch.profiler.record_function utilidad para identificar la línea de código ofensiva exacta.

Perfil con registro_function

Para identificar la fuente exacta del evento de sincronización, extendimos el Medio metálico clase y anular el actualizar método utilizando registro_function Bloques de contexto. Este enfoque nos permite perfilar operaciones individuales dentro del método e identificar cuellos de botella de rendimiento.

class ProfileMeanMetric(MeanMetric):
    def update(self, value, weight = 1.0):
        # broadcast weight to value shape
        with profiler.record_function("process value"):
            if not isinstance(value, torch.Tensor):
                value = torch.as_tensor(value, dtype=self.dtype,
                                        device=self.device)
        with profiler.record_function("process weight"):
            if weight is not None and not isinstance(weight, torch.Tensor):
                weight = torch.as_tensor(weight, dtype=self.dtype,
                                         device=self.device)
        with profiler.record_function("broadcast weight"):
            weight = torch.broadcast_to(weight, value.shape)
        with profiler.record_function("cast_and_nan_check"):
            value, weight = self._cast_and_nan_check_input(value, weight)

        if value.numel() == 0:
            return

        with profiler.record_function("update value"):
            self.mean_value += (value * weight).sum()
        with profiler.record_function("update weight"):
            self.weight += weight.sum()

Luego actualizamos nuestra métrica AVG_LOSS para usar el recién creado ProfilineMeMeMetric y Reran el script de entrenamiento.

Trace de colección métrica con registro_function (por el autor)

La traza actualizada revela que el evento de sincronización se origina en la siguiente línea:

weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)

Esta operación convierte el valor escalar predeterminado weight=1.0 en un tensor de Pytorch y lo coloca en la GPU. El evento de sincronización ocurre porque esta acción desencadena una copia de datos de CPU a GPU, que requiere que la CPU espere a que la GPU procese el valor copiado.

Optimización 1: especificar el valor de peso

Ahora que hemos encontrado la fuente del problema, podemos superarlo fácilmente especificando un peso valor en nuestro actualizar llamar. Esto evita que el tiempo de ejecución convierta el escalar predeterminado weight=1.0 en un tensor en la GPU, evitando el evento de sincronización:

# update metrics
 if capture_metric:
     metrics["avg_loss"].update(loss, weight=torch.ones_like(loss))

Volver a ejecutar el guión después de aplicar este cambio revela que hemos logrado eliminar el evento de sincronización inicial … solo para haber descubierto uno nuevo, esta vez proveniente del _cast_and_nan_check_input función:

Trace de colección métrica después de la optimización 1 (por el autor)

Perfil con registro_function – Parte 2

Para explorar nuestro nuevo evento de sincronización, ampliamos nuestra métrica personalizada con sondas de perfiles adicionales y Reran nuestro script.

class ProfileMeanMetric(MeanMetric):
    def update(self, value, weight = 1.0):
        # broadcast weight to value shape
        with profiler.record_function("process value"):
            if not isinstance(value, torch.Tensor):
                value = torch.as_tensor(value, dtype=self.dtype,
                                        device=self.device)
        with profiler.record_function("process weight"):
            if weight is not None and not isinstance(weight, torch.Tensor):
                weight = torch.as_tensor(weight, dtype=self.dtype,
                                         device=self.device)
        with profiler.record_function("broadcast weight"):
            weight = torch.broadcast_to(weight, value.shape)
        with profiler.record_function("cast_and_nan_check"):
            value, weight = self._cast_and_nan_check_input(value, weight)

        if value.numel() == 0:
            return

        with profiler.record_function("update value"):
            self.mean_value += (value * weight).sum()
        with profiler.record_function("update weight"):
            self.weight += weight.sum()

    def _cast_and_nan_check_input(self, x, weight = None):
        """Convert input ``x`` to a tensor and check for Nans."""
        with profiler.record_function("process x"):
            if not isinstance(x, torch.Tensor):
                x = torch.as_tensor(x, dtype=self.dtype,
                                    device=self.device)
        with profiler.record_function("process weight"):
            if weight is not None and not isinstance(weight, torch.Tensor):
                weight = torch.as_tensor(weight, dtype=self.dtype,
                                         device=self.device)
            nans = torch.isnan(x)
            if weight is not None:
                nans_weight = torch.isnan(weight)
            else:
                nans_weight = torch.zeros_like(nans).bool()
                weight = torch.ones_like(x)

        with profiler.record_function("any nans"):
            anynans = nans.any() or nans_weight.any()

        with profiler.record_function("process nans"):
            if anynans:
                if self.nan_strategy == "error":
                    raise RuntimeError("Encountered `nan` values in tensor")
                if self.nan_strategy in ("ignore", "warn"):
                    if self.nan_strategy == "warn":
                        print("Encountered `nan` values in tensor."
                              " Will be removed.")
                    x = x[~(nans | nans_weight)]
                    weight = weight[~(nans | nans_weight)]
                else:
                    if not isinstance(self.nan_strategy, float):
                        raise ValueError(f"`nan_strategy` shall be float"
                                         f" but you pass {self.nan_strategy}")
                    x[nans | nans_weight] = self.nan_strategy
                    weight[nans | nans_weight] = self.nan_strategy

        with profiler.record_function("return value"):
            retval = x.to(self.dtype), weight.to(self.dtype)
        return retval

La traza resultante se captura a continuación:

Trace de colección métrica con registro_function – Parte 2 (por autor)

El rastreo apunta directamente a la línea ofensiva:

anynans = nans.any() or nans_weight.any()

Esta operación verifica para NaN Valores en los tensores de entrada, pero introduce un costoso evento de sincronización de CPU-GPU porque la operación implica copiar datos de la GPU a la CPU.

Sobre una inspección más cercana de la torchmetric Baseggregador Clase, encontramos varias opciones para manejar actualizaciones de valor NAN, todas las cuales pasan a través de la línea de código ofensiva. Sin embargo, para nuestro caso de uso, calculando la métrica de pérdida promedio, esta verificación es innecesaria y no justifica la penalización de rendimiento del tiempo de ejecución.

Optimización 2: Deshabilitar las comprobaciones de valor NAN

Para eliminar la sobrecarga, proponemos deshabilitar el NaN verificaciones de valor anulando el _cast_and_nan_check_input función. En lugar de una anulación estática, implementamos una solución dinámica que puede aplicarse de manera flexible a cualquier descendiente del Baseggregadorclase.

from torchmetrics.aggregation import BaseAggregator

def suppress_nan_check(MetricClass):
    assert issubclass(MetricClass, BaseAggregator), MetricClass
    class DisableNanCheck(MetricClass):
        def _cast_and_nan_check_input(self, x, weight=None):
            if not isinstance(x, torch.Tensor):
                x = torch.as_tensor(x, dtype=self.dtype, 
                                    device=self.device)
            if weight is not None and not isinstance(weight, torch.Tensor):
                weight = torch.as_tensor(weight, dtype=self.dtype,
                                         device=self.device)
            if weight is None:
                weight = torch.ones_like(x)
            return x.to(self.dtype), weight.to(self.dtype)
    return DisableNanCheck

NoNanMeanMetric = suppress_nan_check(MeanMetric)

metrics["avg_loss"] = NoNanMeanMetric().to(device)

Resultados de la optimización posterior: éxito

Después de implementar las dos optimizaciones, especificar el valor de peso y deshabilitar el NaN Verificaciones: encontramos el rendimiento del tiempo de paso y la utilización de la GPU para que coincida con las de nuestro experimento de referencia. Además, el rastro de perfilador de Pytorch resultante muestra que se han eliminado todos los eventos adicionales de “cudastreamsynchronize” que estaban asociados con la colección métrica. Con algunos pequeños cambios, hemos reducido el costo de la capacitación en ~ 10% sin ningún cambio en el comportamiento de la colección métrica.

En la siguiente sección exploraremos una optimización de colección métrica adicional.

Ejemplo 2: Optimización de la colocación de dispositivos métricos

En la sección anterior, los valores métricos residían en la GPU, lo que hace que sea lógico almacenar y calcular las métricas en la GPU. Sin embargo, en escenarios en los que los valores que deseamos residir en la CPU, podría ser preferible almacenar las métricas en la CPU para evitar transferencias innecesarias de dispositivos.

En el bloque de código a continuación, modificamos nuestro script para calcular el tiempo de paso promedio usando un Medio metálico en la CPU. Este cambio no tiene impacto en el rendimiento del tiempo de ejecución de nuestro paso de entrenamiento:

avg_time = NoNanMeanMetric()
t0 = time.perf_counter()

for idx, (data, target) in enumerate(train_loader):
    # move data to device
    data = data.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    if capture_metrics:
        metrics["avg_loss"].update(loss)
        for name, metric in metrics.items():
            if name != "avg_loss":
                metric.update(output, target)

        if (idx + 1) % 100 == 0:
            # compute metrics
            metric_results = {
                name: metric.compute().item()
                    for name, metric in metrics.items()
            }
            # print metrics
            print(f"Step {idx + 1}: {metric_results}")
            # reset metrics
            for metric in metrics.values():
                metric.reset()

    elif (idx + 1) % 100 == 0:
        # print last loss value
        print(f"Step {idx + 1}: Loss = {loss.item():.4f}")

    batch_time = time.perf_counter() - t0
    t0 = time.perf_counter()
    if idx > 10:  # skip first steps
        avg_time.update(batch_time)

    if enable_profiler:
        prof.step()

    if idx > 200:
        break

if enable_profiler:
    prof.stop()

avg_time = avg_time.compute().item()
print(f'Average step time: {avg_time}')
print(f'Throughput: {batch_size/avg_time:.2f} images/sec')

El problema surge cuando intentamos extender nuestro guión para apoyar la capacitación distribuida. Para demostrar el problema, modificamos nuestra definición de modelo para usar DistributedDataparallel (DDP) :

# toggle to enable/disable ddp
use_ddp = True

if use_ddp:
    import os
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)
    model = DDP(torchvision.models.resnet18().to(device))
else:
    model = torchvision.models.resnet18().to(device)

# insert training loop

# append to end of the script:
if use_ddp:
    # destroy the process group
    dist.destroy_process_group()

La modificación DDP da como resultado el siguiente error:

RuntimeError: No backend type associated with device type cpu

Por defecto, las métricas en la capacitación distribuida están programadas para sincronizar en todos los dispositivos en uso. Sin embargo, el backend de sincronización utilizado por DDP no admite métricas almacenadas en la CPU.

Una forma de resolver esto es deshabilitar la sincronización métrica de servicio cruzado:

avg_time = NoNanMeanMetric(sync_on_compute=False)

En nuestro caso, donde estamos midiendo el tiempo promedio, esta solución es aceptable. Sin embargo, en algunos casos, la sincronización métrica es esencial, y es posible que no tengamos más remedio que mover la métrica a la GPU:

avg_time = NoNanMeanMetric().to(device)

Desafortunadamente, esta situación da lugar a un nuevo evento de sincronización de CPU-GPU que proviene del actualizar función.

Trace de colección métrica AVG_Time (por autor)

Este evento de sincronización difícilmente debería ser una sorpresa: después de todo, estamos actualizando una métrica de GPU con un valor que reside en la CPU, que debería requerir una copia de memoria. Sin embargo, en el caso de una métrica escalar, esta transferencia de datos se puede evitar completamente con una optimización simple.

Optimización 3: Realice actualizaciones métricas con tensores en lugar de escalar

La solución es sencilla: en lugar de actualizar la métrica con un valor flotante, nos convertimos a un tensor antes de llamar update.

batch_time = torch.as_tensor(batch_time)
avg_time.update(batch_time, torch.ones_like(batch_time))

Este cambio menor evita la línea problemática de código, elimina el evento de sincronización y restaura el paso de paso al rendimiento de línea de base.

A primera vista, este resultado puede parecer sorprendente: esperaríamos que actualizar una métrica de GPU con un tensor de CPU aún requiera una copia de memoria. Sin embargo, Pytorch optimiza las operaciones en tensores escalares mediante el uso de un núcleo dedicado que realiza la adición sin una transferencia de datos explícita. Esto evita el costoso evento de sincronización que de otro modo ocurriría.

Resumen

En esta publicación, exploramos cómo un enfoque ingenuo de la antorchmetrics puede introducir eventos de sincronización de CPU-GPU y degradar significativamente el rendimiento de entrenamiento de Pytorch. Usando Pytorch Profiler, identificamos las líneas de código responsables de estos eventos de sincronización y aplicó optimizaciones dirigidas para eliminarlos:

  • Especifique explícitamente un tensor de peso al llamar al MeanMetric.update función en lugar de confiar en el valor predeterminado.
  • Desactivar los controles NAN en la base Aggregator Clase o reemplácelos con una alternativa más eficiente.
  • Administre cuidadosamente la colocación del dispositivo de cada métrica para minimizar las transferencias innecesarias.
  • Deshabilite la sincronización métrica de servicio cruzado cuando no sea necesario.
  • Cuando la métrica reside en una GPU, convierta los escalares de punto flotante en tensores antes de pasarlos al update función para evitar la sincronización implícita.

Hemos creado un dedicado Solicitud de solicituden el Torchmetrics GithubPágina que cubre algunas de las optimizaciones discutidas en esta publicación. ¡No dude en contribuir con sus propias mejoras y optimizaciones!