0qrtxxvyp8lddvl9u.jpeg

Este es el tipo de cosas que cualquiera que haya pasado mucho tiempo trabajando con transformadores y con atención personal habrá escuchado cientos de veces. Es absolutamente cierto, todos hemos experimentado esto cuando intentas aumentar el tamaño del contexto de tu modelo, todo de repente se detiene. Pero al mismo tiempo, al parecer prácticamente cada semana, hay un nuevo modelo de última generación con una longitud de contexto que bate un nuevo récord. (¡Géminis tiene una longitud de contexto de 2 millones de tokens!)

Hay muchos métodos sofisticados como RingAttention que hacen posible el entrenamiento de longitudes de contexto increíblemente largas en grandes sistemas distribuidos, pero lo que me interesa hoy es una pregunta más simple.

¿Hasta dónde podemos llegar sólo con la atención lineal?

Esta será una gira breve, pero tengan paciencia mientras tocamos algunos puntos clave antes de profundizar en los resultados.

Básicamente podemos resumir el mecanismo de atención tradicional en dos puntos clave:

  • Primero, la expresión de atención típica de softmax toma el producto de las matrices de consulta y clave, normaliza la estabilidad y luego toma softmax (por filas) para obtener las puntuaciones de atención entre cada elemento de la secuencia.
  • En segundo lugar, la complejidad temporal está dominada por los productos escalares N², y el que está dentro de softmax es el factor limitante. Ahí es donde calculamos las puntuaciones de atención.

Esto se expresa en la forma tradicional como:

Formulación tradicional del mecanismo de atención softmax.

Resulta que si les preguntamos a nuestros amigos matemáticos podemos pensar sobre esto de manera ligeramente diferente. Se puede considerar el softmax como una de las muchas formas de describir la distribución de probabilidad que relaciona los tokens entre sí. Podemos usar cualquier medida de similitud que queramos (el producto escalar es uno de los más simples) y siempre que lo normalicemos, estamos bien.

Expresión general de atención utilizando cualquier función de similitud.

Es un poco descuidado decir esto es atención, ya que de hecho es solo la atención que conocemos y amamos cuando la función de similitud es el exponencial del producto escalar de consultas y claves (que se muestra a continuación) como encontramos en softmax. Pero aquí es donde se pone interesante, si en lugar de usar esta expresión, ¿qué pasaría si pudiéramos aproximarnos a ella?

Aproxima la función de similitud a partir de la autoatención con dos mapas de características.

Podemos asumir que hay algún mapa de características «fi”lo que nos da un resultado cerca de lo mismo que tomar la exponencial del producto escalar. Y, lo que es más importante, escribir la expresión de esta manera nos permite jugar con el orden de las operaciones de multiplicación de matrices.

En el papel Proponen la Unidad Lineal Exponencial (ELU) como mapa de características debido a una serie de propiedades útiles:

  1. Para valores superiores a 0, el ELU(x) da un resultado lineal, que si bien no es el mismo que el exponencial, preserva el orden relativo entre las puntuaciones.
  2. Para valores menores o iguales a 0, el término exponencial preserva la naturaleza continua de la función y garantiza que los gradientes no desaparezcan.

No dedicaremos mucho más tiempo a esto aquí, pero está bastante bien verificado empíricamente como una aproximación justa a la función softmax.

Esto lo que nos permite es cambiar el orden de las operaciones. Podemos tomar el producto de nuestro mapa de características de K con V primero para hacer un bloque KV, luego el producto con Q. El producto cuadrado supera el tamaño de la dimensión del modelo en lugar de la longitud de la secuencia.

Poner todo esto junto en la expresión de atención lineal nos da:

Atención lineal utilizando mapas de características para aproximar la puntuación de similitud de softmax.

Donde solo necesitamos calcular los términos entre paréntesis una vez por fila de consulta.

(Si desea profundizar en cómo encaja el enmascaramiento casual en esto y cómo se calculan los gradientes, eche un vistazo al artículo. O mire este espacio para un blog futuro).

El argumento matemático es sólido, pero personalmente, hasta que veo algunos puntos de referencia, siempre desconfío un poco.

Comencemos mirando los fragmentos de código para describir cada uno de estos términos. La atención de softmax le resultará muy familiar, no estamos haciendo nada sofisticado aquí.

class TraditionalAttention(nn.Module):
def __init__(self, d_k):
super(TraditionalAttention, self).__init__()
self.d_k = d_k

def forward(self, Q, K, V):
Z = torch.sqrt(torch.tensor(self.d_k, device=Q.device, dtype=torch.float32))
scores = torch.matmul(Q, K.transpose(-2, -1)) / Z
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
return output

Luego, para la atención lineal, comenzamos obteniendo las matrices de consulta, clave y valor, luego aplicamos el mapeo de características ELU (x) a la consulta y las claves. Luego usamos la notación einsum para realizar las multiplicaciones.

class LinearAttention(nn.Module):
def __init__(self):
super(LinearAttention, self).__init__()
self.eps = 1e-6

def elu_feature_map(self, x):
return F.elu(x) + 1

def forward(self, Q, K, V):
Q = self.elu_feature_map(Q)
K = self.elu_feature_map(K)
KV = torch.einsum("nsd,nsd->ns", K, V)
# Compute the normalizer
Z = 1/(torch.einsum("nld,nd->nl", Q, K.sum(dim=1))+self.eps)
# Finally compute and return the new values
V = torch.einsum("nld,ns,nl->nd", Q, KV, Z)
return V.contiguous()

Ver esto escrito en código está muy bien, pero ¿qué significa realmente experimentalmente? ¿De qué aumento de rendimiento estamos hablando aquí? Puede ser difícil apreciar el grado de aceleración al pasar de un cuello de botella cuadrático a uno lineal, por lo que realicé el siguiente experimento.

Tomaremos una única capa de atención, con una dimensión de modelo d_k fija de 64, y compararemos el tiempo necesario para un avance de un conjunto de secuencias de tamaño de lote 32. La única variable que se cambiará será la longitud de la secuencia, que abarca desde 128 hasta 6000 (la longitud del contexto GPT-3 como referencia es 2048). Cada ejecución se realiza 100 veces para obtener una media y una desviación estándar, y los experimentos se realizan utilizando una GPU Nvidia T4.

Para un experimento tan simple, los resultados son bastante sorprendentes.

Puntos de referencia: Medición del tiempo por iteración para una secuencia única con atención tradicional (softmax) y atención lineal. La longitud de cada secuencia se promedia en 100 iteraciones y se traza la desviación estándar. Las longitudes de secuencia utilizadas varían de 128 a 6000. También se muestra que la relación mide más fácilmente el aumento del rendimiento.

Los resultados muestran que incluso para un ejemplo de juguete increíblemente pequeño obtenemos una velocidad de hasta 60x.

Discusión

Aquí hay algunas conclusiones obvias:

  1. La ventaja de la atención lineal es enorme: ya sea en velocidad, un mayor rendimiento siempre es algo bueno. O en términos de requisitos de memoria para procesar secuencias largas. En entornos con poca memoria, esto podría ser una gran ventaja.
  2. El gráfico de proporción tiene un problema sorprendente: nos lleva a sospechar que aquí se está produciendo alguna optimización adicional de nivel inferior, lo que significa que la proporción esperada no se materializa del todo. Por eso debemos tomar este resultado con una pizca de sal.

Para completar, tampoco confunda esto con decir «La atención lineal es 60 veces más rápida para modelos pequeños». En realidad, las capas de retroalimentación suelen ser una parte más grande de los parámetros en un Transformer y la codificación/decodificación suele ser también un componente de tamaño limitante. Pero en este problema tan bien definido, ¡bastante impresionante!