Pérdida focal frente a entropía cruzada binaria: una guía práctica para la clasificación desequilibrada

La entropía cruzada binaria (BCE) es la función de pérdida predeterminada para la clasificación binaria, pero se descompone gravemente en conjuntos de datos desequilibrados. La razón es sutil pero importante: BCE pondera los errores de ambas clases por igual, incluso cuando una clase es extremadamente rara.

Imaginemos dos predicciones: una muestra de clase minoritaria con la etiqueta verdadera 1 predicha en 0,3, y una muestra de clase mayoritaria con la etiqueta verdadera 0 predicha en 0,7. Ambos producen el mismo valor BCE: −log(0,3). ¿Pero deberían tratarse estos dos errores por igual? En un conjunto de datos desequilibrado, definitivamente no: el error en la muestra minoritaria es mucho más costoso.

Aquí es exactamente donde entra en juego la pérdida focal. Reduce la contribución de predicciones fáciles y seguras y amplifica el impacto de ejemplos difíciles de clases minoritarias. Como resultado, el modelo se centra menos en la clase mayoritaria abrumadoramente fácil y más en los patrones que realmente importan. Consulta los CÓDIGOS COMPLETOS aquí.

En este tutorial, demostramos este efecto entrenando dos redes neuronales idénticas en un conjunto de datos con una relación de desequilibrio de 99:1 (una usando BCE y la otra usando Focal Loss) y comparando su comportamiento, regiones de decisión y matrices de confusión. Consulta los CÓDIGOS COMPLETOS aquí.

Instalando las dependencias

pip instala numpy pandas matplotlib scikit-learn torch

Crear un conjunto de datos desequilibrado

Creamos un conjunto de datos de clasificación binaria sintética con un desequilibrio de 99:1 con 6000 muestras usando make_classification. Esto garantiza que casi todas las muestras pertenezcan a la clase mayoritaria, lo que la convierte en una configuración ideal para demostrar por qué BCE tiene problemas y cómo ayuda Focal Loss. Consulta los CÓDIGOS COMPLETOS aquí.

importar numpy como np importar matplotlib.pyplot como plt desde sklearn.datasets importar make_classification desde sklearn.model_selection importar train_test_split importar torch importar torch.nn as nn importar torch.optim as optim # Generar conjunto de datos desequilibrado X, y = make_classification( n_samples=6000, n_features=2, n_redundante=0, n_clusters_per_class=1, pesos=[0.99, 0.01]class_sep=1.5, random_state=42 ) X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.3, random_state=42 ) X_train = torch.tensor(X_train, dtype=torch.float32) y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1) X_test = torch.tensor(X_test, dtype=torch.float32) y_test = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)

Creando la red neuronal

Definimos una red neuronal simple con dos capas ocultas para mantener el experimento liviano y enfocado en las funciones de pérdida. Esta pequeña arquitectura es suficiente para conocer el límite de decisión en nuestro conjunto de datos 2D y al mismo tiempo resalta claramente las diferencias entre BCE y Focal Loss. Consulta los CÓDIGOS COMPLETOS aquí.

clase SimpleNN(nn.Module): def __init__(self): super().__init__() self.layers = nn.Sequential( nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 1), nn.Sigmoid() ) def adelante(self, x): regresar auto.capas(x)

Implementación de pérdida focal

Esta clase implementa la función Focal Loss, que modifica la entropía cruzada binaria al reducir el peso de los ejemplos sencillos y centrar el entrenamiento en muestras difíciles y mal clasificadas. El término gamma controla la agresividad con la que se suprimen las muestras fáciles, mientras que alfa asigna mayor peso a la clase minoritaria. Juntos, ayudan al modelo a aprender mejor sobre conjuntos de datos desequilibrados. Consulta los CÓDIGOS COMPLETOS aquí.

clase FocalLoss(nn.Module): def __init__(self, alpha=0.25, gamma=2): super().__init__() self.alpha = alpha self.gamma = gamma def adelante(self, preds, objetivos): eps = 1e-7 preds = torch.clamp(preds, eps, 1 – eps) pt = torch.where(targets == 1, preds, 1 – preds) pérdida = -self.alpha * (1 – pt) ** self.gamma * torch.log(pt) return pérdida.mean()

Entrenando el modelo

Definimos un ciclo de entrenamiento simple que optimiza el modelo usando la función de pérdida elegida y evalúa la precisión en el conjunto de prueba. Luego entrenamos dos redes neuronales idénticas, una con pérdida BCE estándar y la otra con pérdida focal, lo que nos permite comparar directamente cómo se desempeña cada función de pérdida en el mismo conjunto de datos desequilibrado. Las precisiones impresas resaltan la brecha de rendimiento entre BCE y Focal Loss.

Aunque BCE muestra una precisión muy alta (98%), esto es engañoso porque el conjunto de datos está muy desequilibrado: predecir casi todo como la clase mayoritaria aún produce una alta precisión. Focal Loss, por otro lado, mejora la detección de clases minoritarias, por lo que su precisión ligeramente mayor (99%) es mucho más significativa en este contexto. Consulta los CÓDIGOS COMPLETOS aquí.

def train(model, loss_fn, lr=0.01, epochs=30): opt = optim.Adam(model.parameters(), lr=lr) for _ in range(epochs): preds = model(X_train) loss = loss_fn(preds, y_train) opt.zero_grad() loss.backward() opt.step() con torch.no_grad(): test_preds = model(X_test) test_acc = ((test_preds > 0.5).float() == y_test).float().mean().item() return test_acc, test_preds.squeeze().detach().numpy() # Modelos model_bce = SimpleNN() model_focal = SimpleNN() acc_bce, preds_bce = train(model_bce, nn.BCELoss()) acc_focal, preds_focal = train(model_focal, FocalLoss(alpha=0.25, gamma=2)) print(“Precisión de la prueba (BCE):”, acc_bce) print(“Precisión de la prueba (pérdida focal):”, acc_focal)

Trazar el límite de decisión

El modelo BCE produce un límite de decisión casi plano que predice sólo la clase mayoritaria, ignorando por completo las muestras minoritarias. Esto sucede porque, en un conjunto de datos desequilibrado, BCE está dominado por ejemplos de clases mayoritarias y aprende a clasificar casi todo como esa clase. Por el contrario, el modelo de pérdida focal muestra un límite de decisión mucho más refinado y significativo, identificando con éxito más regiones de clases minoritarias y capturando patrones que BCE no logra aprender. Consulta los CÓDIGOS COMPLETOS aquí.

def plot_decision_boundary(modelo, título): # Crea una cuadrícula x_min, x_max = X[:,0].min()-1,X[:,0].max()+1 y_min, y_max = X[:,1].min()-1,X[:,1].max()+1 xx, yy = np.meshgrid( np.linspace(x_min, x_max, 300), np.linspace(y_min, y_max, 300) ) grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()]dtype=torch.float32) con torch.no_grad(): Z = model(grid).reshape(xx.shape) # Trazar plt.contourf(xx, yy, Z, niveles=[0,0.5,1]alfa=0,4) plt.scatter(X[:,0]X[:,1]c=y, cmap=’coolwarm’, s=10) plt.title(title) plt.show() plot_decision_boundary(model_bce, “Límite de decisión – Pérdida BCE”) plot_decision_boundary(model_focal, “Límite de decisión – Pérdida focal”)

Trazando la matriz de confusión

En la matriz de confusión del modelo BCE, la red identifica correctamente sólo una muestra de clase minoritaria, mientras que clasifica erróneamente a 27 de ellas como clase mayoritaria. Esto muestra que el BCE fracasa en predecir casi todo como clase mayoritaria debido al desequilibrio. Por el contrario, el modelo Focal Loss predice correctamente 14 muestras minoritarias y reduce las clasificaciones erróneas de 27 a 14. Esto demuestra cómo Focal Loss pone más énfasis en ejemplos concretos de clases minoritarias, lo que permite al modelo aprender un límite de decisión que realmente captura la clase rara en lugar de ignorarla. Consulta los CÓDIGOS COMPLETOS aquí.

de sklearn.metrics importe confusion_matrix, ConfusionMatrixDisplay def plot_conf_matrix(y_true, y_pred, title): cm = confusion_matrix(y_true, y_pred) disp = ConfusionMatrixDisplay(confusion_matrix=cm) disp.plot(cmap=”Blues”, value_format=”d”) plt.title(title) plt.show() # Convertir tensores de antorcha a numpy y_test_np = y_test.numpy().astype(int) preds_bce_label = (preds_bce > 0.5).astype(int) preds_focal_label = (preds_focal > 0.5).astype(int) plot_conf_matrix(y_test_np, preds_bce_label, “Matriz de confusión – Pérdida de BCE”) plot_conf_matrix(y_test_np, preds_focal_label, “Matriz de confusión – Pérdida focal”)

Consulta los CÓDIGOS COMPLETOS aquí. No dude en consultar nuestra página de GitHub para tutoriales, códigos y cuadernos. Además, no dude en seguirnos en Twitter y no olvide unirse a nuestro SubReddit de más de 100.000 ML y suscribirse a nuestro boletín. ¡Esperar! estas en telegrama? Ahora también puedes unirte a nosotros en Telegram.

Soy graduado en ingeniería civil (2022) de Jamia Millia Islamia, Nueva Delhi, y tengo un gran interés en la ciencia de datos, especialmente las redes neuronales y su aplicación en diversas áreas.

🙌 Siga MARKTECHPOST: agréguenos como fuente preferida en Google.