Implementación de Softmax desde cero: evitando la trampa de la estabilidad numérica

En el aprendizaje profundo, los modelos de clasificación no sólo necesitan hacer predicciones, sino que también deben expresar confianza. Ahí es donde entra en juego la función de activación de Softmax. Softmax toma las puntuaciones crudas e ilimitadas producidas por una red neuronal y las transforma en una distribución de probabilidad bien definida, lo que permite interpretar cada resultado como la probabilidad de una clase específica.

Esta propiedad convierte a Softmax en una piedra angular de las tareas de clasificación de clases múltiples, desde el reconocimiento de imágenes hasta el modelado de lenguaje. En este artículo, desarrollaremos una comprensión intuitiva de cómo funciona Softmax y por qué los detalles de su implementación son más importantes de lo que parecen a primera vista. Consulta los CÓDIGOS COMPLETOS aquí.

Implementación de Softmax ingenuo

importar antorcha def softmax_naive(logits): exp_logits = torch.exp(logits) return exp_logits / exp_logits.sum(dim=1, keepdim=True)

Esta función implementa la activación de Softmax en su forma más sencilla. Exponencia cada logit y lo normaliza mediante la suma de todos los valores exponenciados entre clases, produciendo una distribución de probabilidad para cada muestra de entrada.

Si bien esta implementación es matemáticamente correcta y fácil de leer, es numéricamente inestable: los logits positivos grandes pueden provocar un desbordamiento y los logits negativos grandes pueden desbordarse hasta cero. Como resultado, esta versión debe evitarse en procesos de capacitación reales. Consulta los CÓDIGOS COMPLETOS aquí.

Logits de muestra y etiquetas de destino

Este ejemplo define un lote pequeño con tres muestras y tres clases para ilustrar casos normales y de falla. La primera y tercera muestras contienen valores logit razonables y se comportan como se esperaba durante el cálculo de Softmax. La segunda muestra incluye intencionalmente valores extremos (1000 y -1000) para demostrar inestabilidad numérica; aquí es donde la ingenua implementación de Softmax falla.

El tensor de objetivos especifica el índice de clase correcto para cada muestra y se utilizará para calcular la pérdida de clasificación y observar cómo se propaga la inestabilidad durante la retropropagación. Consulta los CÓDIGOS COMPLETOS aquí.

# Lote de 3 muestras, 3 clases logits = torch.tensor([
[2.0, 1.0, 0.1],
[1000.0, 1.0, -1000.0],
[3.0, 2.0, 1.0]
]require_grad=True) objetivos = torch.tensor([0, 2, 1])

Pase hacia adelante: salida de Softmax y el caso de falla

Durante el pase hacia adelante, la ingenua función Softmax se aplica a los logits para producir probabilidades de clase. Para valores logit normales (primera y tercera muestras), el resultado es una distribución de probabilidad válida donde los valores se encuentran entre 0 y 1 y la suma es 1.

Sin embargo, el segundo ejemplo expone claramente el problema numérico: exponenciar 1000 desbordamientos al infinito, mientras que -1000 subflujos a cero. Esto da como resultado operaciones no válidas durante la normalización, lo que produce valores NaN y probabilidades cero. Una vez que NaN aparece en esta etapa, contamina todos los cálculos posteriores, lo que hace que el modelo sea inutilizable para el entrenamiento. Consulta los CÓDIGOS COMPLETOS aquí.

# Problemas de pase directo = softmax_naive(logits) print(“Probabilidades de Softmax:”) print(probs)

Probabilidades objetivo y desglose de pérdidas

Aquí, extraemos la probabilidad predicha correspondiente a la clase verdadera para cada muestra. Mientras que la primera y tercera muestras devuelven probabilidades válidas, la probabilidad objetivo de la segunda muestra es 0,0, debido a un desbordamiento numérico insuficiente en el cálculo de Softmax. Cuando la pérdida se calcula usando -log(p), tomar el logaritmo de 0,0 da como resultado +∞.

Esto hace que la pérdida total sea infinita, lo cual es un fallo crítico durante el entrenamiento. Una vez que la pérdida se vuelve infinita, el cálculo del gradiente se vuelve inestable, lo que genera NaN durante la propagación hacia atrás y detiene efectivamente el aprendizaje. Consulta los CÓDIGOS COMPLETOS aquí.

# Extraer probabilidades objetivo target_probs = probs[torch.arange(len(targets)), targets]

print(“\nProbabilidades de destino:”) print(target_probs) # Calcular pérdida pérdida = -torch.log(target_probs).mean() print(“\nPérdida:”, pérdida)

Propagación hacia atrás: corrupción de gradiente

Cuando se activa la retropropagación, el impacto de la pérdida infinita se vuelve inmediatamente visible. Los gradientes para la primera y tercera muestra siguen siendo finitos porque sus salidas Softmax se comportaron bien. Sin embargo, la segunda muestra produce gradientes de NaN en todas las clases debido a la operación log(0) en la pérdida.

Estos NaN se propagan hacia atrás a través de la red, contaminando las actualizaciones de peso e interrumpiendo efectivamente el entrenamiento. Esta es la razón por la que la inestabilidad numérica en el límite de pérdida de Softmax es tan peligrosa: una vez que aparecen los NaN, la recuperación es casi imposible sin reiniciar el entrenamiento. Consulta los CÓDIGOS COMPLETOS aquí.

loss.backward() print(“\nGradientes:”) print(logits.grad)

Inestabilidad numérica y sus consecuencias

La separación de Softmax y la entropía cruzada crea un grave riesgo de estabilidad numérica debido al desbordamiento y el desbordamiento exponencial. Los logits grandes pueden llevar las probabilidades al infinito o a cero, provocando log(0) y generando gradientes de NaN que corrompen rápidamente el entrenamiento. A escala de producción, este no es un caso raro, sino una certeza: sin implementaciones estables y fusionadas, las grandes ejecuciones de entrenamiento de múltiples GPU fallarían de manera impredecible.

El problema numérico central proviene del hecho de que las computadoras no pueden representar números infinitamente grandes o infinitamente pequeños. Los formatos de punto flotante como FP32 tienen límites estrictos sobre qué tan grande o pequeño se puede almacenar un valor. Cuando Softmax calcula exp(x), los valores positivos grandes crecen tan rápido que exceden el número máximo representable y se vuelven infinitos, mientras que los valores negativos grandes se reducen tanto que se vuelven cero. Una vez que un valor se vuelve infinito o cero, las operaciones posteriores, como la división o los logaritmos, se descomponen y producen resultados no válidos. Consulta los CÓDIGOS COMPLETOS aquí.

Implementación de una pérdida de entropía cruzada estable mediante LogSumExp

Esta implementación calcula la pérdida de entropía cruzada directamente a partir de logits sin calcular explícitamente las probabilidades de Softmax. Para mantener la estabilidad numérica, los logits primero se desplazan restando el valor máximo por muestra, asegurando que los exponenciales se mantengan dentro de un rango seguro.

Luego se utiliza el truco LogSumExp para calcular el término de normalización, después del cual se resta el logit objetivo original (sin desplazamiento) para obtener la pérdida correcta. Este enfoque evita el desbordamiento, el desbordamiento insuficiente y los gradientes de NaN, y refleja cómo se implementa la entropía cruzada en marcos de aprendizaje profundo de nivel de producción. Consulta los CÓDIGOS COMPLETOS aquí.

def stable_cross_entropy(logits, objetivos): # Encontrar logit máximo por muestra max_logits, _ = torch.max(logits, dim=1, keepdim=True) # Desplazar logits para estabilidad numérica shifted_logits = logits – max_logits # Calcular LogSumExp log_sum_exp = torch.log(torch.sum(torch.exp(shifted_logits), dim=1)) + max_logits.squeeze(1) # Calcular la pérdida usando logits ORIGINALES loss = log_sum_exp – logits[torch.arange(len(targets)), targets]

devolver pérdida.media()

Pase estable hacia adelante y hacia atrás

Ejecutar la implementación estable de entropía cruzada en los mismos logits extremos produce una pérdida finita y gradientes bien definidos. Aunque una muestra contiene valores muy grandes (1000 y -1000), la formulación LogSumExp mantiene todos los cálculos intermedios en un rango numérico seguro. Como resultado, la retropropagación se completa con éxito sin producir NaN y cada clase recibe una señal de gradiente significativa.

Esto confirma que la inestabilidad vista anteriormente no fue causada por los datos en sí, sino por la ingenua separación de Softmax y la entropía cruzada, un problema completamente resuelto mediante el uso de una formulación de pérdida fusionada numéricamente estable. Consulta los CÓDIGOS COMPLETOS aquí.

logits = antorcha.tensor([
[2.0, 1.0, 0.1],
[1000.0, 1.0, -1000.0],
[3.0, 2.0, 1.0]
]require_grad=True) objetivos = torch.tensor([0, 2, 1]) pérdida = stable_cross_entropy(logits, objetivos) print(“Pérdida estable:”, pérdida) pérdida.backward() print(“\nGradientes:”) print(logits.grad)

Conclusión

En la práctica, la brecha entre las fórmulas matemáticas y el código del mundo real es donde se originan muchos fracasos en la capacitación. Si bien Softmax y la entropía cruzada están matemáticamente bien definidos, su implementación ingenua ignora los límites de precisión finitos del hardware IEEE 754, lo que hace inevitable el desbordamiento y el desbordamiento.

La solución clave es simple pero crítica: desplazar los logits antes de la exponenciación y operar en el dominio logarítmico siempre que sea posible. Lo más importante es que el entrenamiento rara vez requiere probabilidades explícitas: las probabilidades logarítmicas estables son suficientes y mucho más seguras. Cuando una pérdida repentinamente se convierte en NaN en producción, a menudo es una señal de que Softmax se está calculando manualmente en algún lugar donde no debería.

Consulta los CÓDIGOS COMPLETOS aquí. Además, no dude en seguirnos en Twitter y no olvide unirse a nuestro SubReddit de más de 100.000 ML y suscribirse a nuestro boletín. ¡Esperar! estas en telegrama? Ahora también puedes unirte a nosotros en Telegram.

Consulte nuestra última versión de ai2025.dev, una plataforma de análisis centrada en 2025 que convierte los lanzamientos de modelos, los puntos de referencia y la actividad del ecosistema en un conjunto de datos estructurado que puede filtrar, comparar y exportar.

Soy graduado en ingeniería civil (2022) de Jamia Millia Islamia, Nueva Delhi, y tengo un gran interés en la ciencia de datos, especialmente las redes neuronales y su aplicación en diversas áreas.