1pohi3ogfv7133bb2rlmmqq.png

Los MDN toman su vieja y aburrida red neuronal y la convierten en una potencia de predicción. ¿Por qué conformarse con una predicción cuando puedes tener toda una variedad de resultados potenciales?

Si la vida le presenta escenarios complejos e impredecibles, los MDN están preparados con una red de seguridad cargada de probabilidades.

La idea central

En un MDN, la densidad de probabilidad de la variable objetivo t dada la entrada X se representa como una combinación lineal de funciones del núcleo, típicamente funciones gaussianas, aunque no se limitan a. En matemáticas habla:

donde 𝛼ᵢ(x) son los coeficientes de mezcla, y ¿a quién no le gusta una buena mezcla, verdad? 🎛️ Estos determinan cuánto peso cada componente 𝜙ᵢ(t|x) — cada gaussiano en nuestro caso se cumple en el modelo.

Elaborando los gaussianos ☕

Cada componente gaussiano 𝜙ᵢ(t|x) tiene su propia media 𝜇ᵢ(x) y varianza 𝜎ᵢ².

Mezclando 🎧 con coeficientes

Los coeficientes de mezcla 𝛼ᵢ son cruciales ya que equilibran la influencia de cada componente gaussiano, regido por un softmax función para garantizar que sumen 1:

Parámetros mágicos ✨ Medias y variaciones

Significa 𝜇ᵢ y variaciones 𝜎ᵢ² define cada gaussiano. ¿Y adivina qué? ¡Las variaciones tienen que ser positivas! Logramos esto usando el exponencial de las salidas de la red:

Muy bien, entonces, ¿cómo entrenamos a esta bestia? Bueno, se trata de maximizar la probabilidad de nuestros datos observados. Términos elegantes, lo sé. Veámoslo en acción.

El hechizo de la probabilidad logarítmica ✨

La probabilidad de nuestros datos bajo el modelo MDN es el producto de las probabilidades asignadas a cada punto de datos. En matemáticas habla:

Esto básicamente dice, «Oye, ¿cuál es la probabilidad de que tengamos estos datos dado nuestro modelo?». Pero los productos pueden complicarse, así que tomamos el registro (porque a las matemáticas les encantan los registros), lo que convierte nuestro producto en una suma:

Ahora, aquí está el truco: en realidad queremos minimizar la probabilidad de registro negativo porque a nuestros algoritmos de optimización les gusta minimizar las cosas. Entonces, conectando la definición de p(t|x)la función de error que realmente minimizamos es:

Esta fórmula puede parecer intimidante, pero solo dice que sumamos las probabilidades de registro en todos los puntos de datos y luego agregamos un signo negativo porque la minimización es nuestro problema.

Ahora aquí le mostramos cómo traducir nuestra magia a Python y puede encontrar el código completo. aquí:

La función de pérdida

def mdn_loss(alpha, sigma, mu, target, eps=1e-8):
target = target.unsqueeze(1).expand_as(mu)
m = torch.distributions.Normal(loc=mu, scale=sigma)
log_prob = m.log_prob(target)
log_prob = log_prob.sum(dim=2)
log_alpha = torch.log(alpha + eps) # Avoid log(0) disaster
loss = -torch.logsumexp(log_alpha + log_prob, dim=1)
return loss.mean()

Aquí está el desglose:

  1. target = target.unsqueeze(1).expand_as(mu): expande el objetivo para que coincida con la forma de mu.
  2. m = torch.distributions.Normal(loc=mu, scale=sigma): Crea una distribución normal.
  3. log_prob = m.log_prob(target): Calcule la probabilidad logarítmica.
  4. log_prob = log_prob.sum(dim=2): Suma de probabilidades logarítmicas.
  5. log_alpha = torch.log(alpha + eps): Calcular el registro de los coeficientes de mezcla.
  6. loss = -torch.logsumexp(log_alpha + log_prob, dim=1): Combine y log-sum-exp las probabilidades.
  7. return loss.mean(): Devuelve la pérdida promedio.

La red neuronal

Creemos una red neuronal que esté lista para manejar la magia:

class MDN(nn.Module):
def __init__(self, input_dim, output_dim, num_hidden, num_mixtures):
super(MDN, self).__init__()
self.hidden = nn.Sequential(
nn.Linear(input_dim, num_hidden),
nn.Tanh(),
nn.Linear(num_hidden, num_hidden),
nn.Tanh(),
)
self.z_alpha = nn.Linear(num_hidden, num_mixtures)
self.z_sigma = nn.Linear(num_hidden, num_mixtures * output_dim)
self.z_mu = nn.Linear(num_hidden, num_mixtures * output_dim)
self.num_mixtures = num_mixtures
self.output_dim = output_dim

def forward(self, x):
hidden = self.hidden(x)
alpha = F.softmax(self.z_alpha(hidden), dim=-1)
sigma = torch.exp(self.z_sigma(hidden)).view(-1, self.num_mixtures, self.output_dim)
mu = self.z_mu(hidden).view(-1, self.num_mixtures, self.output_dim)
return alpha, sigma, mu

Observe la softmax siendo aplicado a 𝛼ᵢ alpha = F.softmax(self.z_alpha(hidden), dim=-1)entonces suman 1 y la exponencial a 𝜎ᵢ sigma = torch.exp(self.z_sigma(hidden)).view(-1, self.num_mixtures, self.output_dim)para garantizar que sigan siendo positivos, como se explicó anteriormente.

La predicción

Obtener predicciones de los MDN es un poco complicado. Así es como se toma una muestra del modelo de mezcla:

def get_sample_preds(alpha, sigma, mu, samples=10):
N, K, T = mu.shape
sampled_preds = torch.zeros(N, samples, T)
uniform_samples = torch.rand(N, samples)
cum_alpha = alpha.cumsum(dim=1)
for i, j in itertools.product(range(N), range(samples)):
u = uniform_samples[i, j]
k = torch.searchsorted(cum_alpha[i], u).item()
sampled_preds[i, j] = torch.normal(mu[i, k], sigma[i, k])
return sampled_preds

Aquí está el desglose:

  1. N, K, T = mu.shape: Obtenga la cantidad de puntos de datos, componentes de la mezcla y dimensiones de salida.
  2. sampled_preds = torch.zeros(N, samples, T): Inicializa el tensor para almacenar predicciones muestreadas.
  3. uniform_samples = torch.rand(N, samples): Genera números aleatorios uniformes para muestreo.
  4. cum_alpha = alpha.cumsum(dim=1): Calcule la suma acumulada de los pesos de la mezcla.
  5. for i, j in itertools.product(range(N), range(samples)): recorra cada combinación de puntos de datos y muestras.
  6. u = uniform_samples[i, j]: obtiene un número aleatorio para la muestra actual.
  7. k = torch.searchsorted(cum_alpha[i], u).item(): Encuentre el índice de componentes de la mezcla.
  8. sampled_preds[i, j] = torch.normal(mu[i, k], sigma[i, k]): Muestra del componente gaussiano seleccionado.
  9. return sampled_preds: Devuelve el tensor de predicciones muestreadas.

Apliquemos MDN para predecir ‘Temperatura aparente’ usando un simple Conjunto de datos meteorológicos. Entrené un MDN con una red de 50 capas ocultas y ¿adivinen qué? ¡Es genial! 🎸

Encuentra el código completo aquí. Aquí hay algunos resultados: