Una simple implementación del mecanismo de atención desde cero

El Mecanismo de atención a menudo se asocia con la arquitectura del transformador, pero ya se usó en RNN. En tareas de traducción automática o tareas MT (por ejemplo, inglés-italiano), cuando desea predecir la próxima palabra italiana, necesita que su modelo se concentre o preste atención a las palabras en inglés más importantes que son útiles para hacer una buena traducción.

No entraré en detalles de los RNN, pero la atención ayudó a estos modelos a mitigar el problema de gradiente de desaparición y capturar más dependencias de largo alcance entre las palabras.

En cierto punto, entendimos que lo único importante era el mecanismo de atención, y toda la arquitectura RNN era exagerada. Por eso, ¡La atención es todo lo que necesitas!

Autoatención en transformadores

La atención clásica indica cuando las palabras en la secuencia de salida deben centrar la atención en relación con las palabras en la secuencia de entrada. Esto es importante en tareas de secuencia a secuencia como MT.

El autoenvío es un tipo específico de atención. Funciona entre dos elementos en la misma secuencia. Proporciona información sobre cómo “correlacionadas” las palabras están en la misma oración.

Para un token (o palabra) dado en una secuencia, la autoatición genera una lista de pesos de atención correspondientes a todos los demás tokens en la secuencia. Este proceso se aplica a cada token en la oración, obteniendo una matriz de pesos de atención (como en la imagen).

Esta es la idea general, en la práctica, las cosas son un poco más complicadas porque queremos agregar muchos parámetros aprendibles a nuestra red neuronal, veamos cómo.

K, V, Q Representaciones

Nuestra entrada de modelo es una oración como “Mi nombre es Marcello Politi “. Con el proceso de tokenizaciónuna oración se convierte en una lista de números como [2, 6, 8, 3, 1].

Antes de alimentar la oración al transformador, necesitamos crear una representación densa para cada token.

¿Cómo crear esta representación? Multiplicamos cada token por una matriz. La matriz se aprende durante el entrenamiento.

Agreguemos algo de complejidad ahora.

Para cada token, creamos 3 vectores en lugar de uno, llamamos a estos vectores: clave, valor y consulta. (Vemos más tarde cómo creamos estos 3 vectores).

Conceptualmente, estos 3 tokens tienen un significado particular:

  • La clave vector representa la información central capturada por el token
  • El valor vectorial captura la información completa de un token
  • La consulta vectorial, es una pregunta sobre la relevancia del token para la tarea actual.

Entonces, la idea es que nos centramos en una token en particular I, y queremos preguntar cuál es la importancia de los otros tokens en la oración con respecto a la ficha que estamos teniendo en cuenta.

Esto significa que tomamos el vector Q_i (hacemos una pregunta sobre I) para el token I, y hacemos algunas operaciones matemáticas con todos los otros tokens k_j (j! = I). Esto es como preguntarse a primera vista cuáles son los otros tokens en la secuencia que parecen realmente importantes para comprender el significado del token i.

¿Qué es esta operación matemática mágica?

Necesitamos multiplicar (producir puntos) el vector de consulta por los vectores clave y dividir por un factor de escala. Hacemos esto para cada token K_J.

De esta manera, obtenemos una puntuación para cada par (Q_i, K_J). Hacemos que esta lista se convierta en una distribución de probabilidad aplicando una operación Softmax en ella. Genial ahora hemos obtenido el Pesos de atención!

Con los pesos de atención, sabemos cuál es la importancia de cada token k_j para para desanimar el token i. Entonces, ahora multiplicamos el Vector V_J de valor asociado con cada token por su peso y sumamos los vectores. De esta manera obtenemos la final vector consciente de contexto de token_i.

Si estamos calculando el vector denso contextual de token_1, calculamos:

z1 = a11*v1 + a12*v2 +… + a15*v5

Donde A1J son los pesos de atención de la computadora, y V_J son los vectores de valor.

¡Hecho! Casi…

No cubrí cómo obtuvimos los vectores K, V y Q de cada token. Necesitamos definir algunas matrices W_K, W_V y W_Q para que cuando multipliquemos:

  • Token * W_K -> K
  • Token * W_Q -> Q
  • Token * W_V -> V

Estas 3 matrices se establecen al azar y se aprenden durante la capacitación, es por eso que tenemos muchos parámetros en modelos modernos como LLMS.

Autoatención de múltiples cabezas en Transformers (MHSA)

¿Estamos seguros de que el mecanismo de autoatención anterior es capaz de capturar todas las relaciones importantes entre los tokens (palabras) y crear vectores densos de esas fichas que realmente tienen sentido?

En realidad, no podría funcionar siempre perfectamente. ¿Qué pasa si para mitigar el error volvemos a ejecutar todo 2 veces con las nuevas matrices W_Q, W_K y W_V y de alguna manera fusionamos los 2 vectores densos obtenidos? De esta manera, tal vez una autoatención logró capturar alguna relación y la otra logró capturar otra relación.

Bueno, esto es lo que sucede exactamente en MHSA. El caso que acabamos de discutir contiene dos cabezas porque tiene dos conjuntos de matrices W_Q, W_K y W_V. Podemos tener aún más cabezas: 4, 8, 16 etc.

Lo único complicado es que todas estas cabezas se gestionan en paralelo, procesamos todo en el mismo cálculo usando tensores.

La forma en que fusionamos los vectores densos de cada cabeza es simple, los concatenamos (por lo tanto, la dimensión de cada vector será más pequeña para que cuando los concatezcamos la dimensión original que queríamos), y pasamos el vector obtenido a través de otra matriz aprendida W_O.

Práctico

Supongamos que tienes una oración. Después de la tokenización, cada token (palabra por simplicidad) corresponde a un índice (número):

Antes de alimentar la oración a la transferencia, necesitamos crear una representación densa para cada token.

¿Cómo crear esta representación? Multiplicamos cada token por matriz. Esta matriz se aprende durante el entrenamiento.

Construyamos esta matriz de incrustación.

Si multiplicamos nuestra oración tokenizada con los incrustaciones, obtenemos una representación densa de la dimensión 16 para cada token

Para usar el mecanismo de atención, necesitamos crear 3 nuevos, definimos 3 matrices w_q, w_k y w_v. Cuando multiplicamos un tiempo de token de entrada, W_Q obtenemos el vector q. Lo mismo con W_K y W_V.

Calcular pesas de atención

Ahora calculemos los pesos de atención solo para el primer token de entrada de la oración.

Necesitamos multiplicar el vector de consulta asociado a Token1 (Query_1) con todas las claves de los otros vectores.

Entonces, ahora necesitamos calcular todas las teclas (Key_2, Key_2, Key_4, Key_5). Pero espere, podemos calcular todo en una vez multiplicando la oración_mbed de la matriz W_K.

Hagamos lo mismo con los valores

Calculemos la primera parte de la fórmula Attions.

import torch.nn.functional as F

Con los pesos de atención sabemos cuál es la importancia de cada token. Así que ahora multiplicamos el vector de valor asociado a cada token según su peso.

Para obtener el contexto final de vector consciente de token_1.

De la misma manera, podríamos calcular los vectores densos conscientes del contexto de todos los otros tokens. Ahora siempre estamos usando las mismas matrices W_K, W_Q, W_V. Decimos que usamos una cabeza.

Pero podemos tener múltiples trillizos de matrices, por lo que de múltiples cabezas. Por eso se llama atención múltiple.

Los vectores densos de un tokens de entrada, dado en Oputut desde cada cabeza, se concatenan y se transforman linealmente para obtener el vector denso final.

Implementación de la atención múltiple

Los mismos pasos que antes …

Definiremos un mecanismo de atención de múltiples cabezas con cabezas H (digamos 4 cabezas para este ejemplo). Cada cabezal tendrá su propia matrices W_Q, W_K y W_V, y la salida de cada cabezal se concatenará y pasará a través de una capa lineal final.

Dado que la salida de la cabeza se concatenará, y queremos una dimensión final de D, la dimensión de cada cabeza debe ser d/h. Además, cada vector concatenado irá a través de una transformación lineal, por lo que necesitamos otra matriz w_ouptut como puede ver en la fórmula.

Como tenemos 4 cabezas, queremos 4 copias para cada matriz. En lugar de copias, agregamos una dimensión, que es lo mismo, pero solo hacemos una operación. (Imagine las matrices de apilamiento uno encima del otro, es lo mismo).

Estoy usando para la simplicidad de Einsum de Torch. Si no estás familiarizado con él, mira mi blog.

La operación Einsum torch.einsum('sd,hde->hse', sentence_embed, w_query) En Pytorch usa letras para definir cómo multiplicar y reorganizar números. Esto es lo que significa cada parte:

  1. Tensores de entrada:
    • sentence_embed con la notación 'sd':
      • s representa el número de palabras (longitud de secuencia), que es 5.
      • d representa el número de números por palabra (tamaño de incrustación), que es 16.
      • La forma de este tensor es [5, 16].
    • w_query con la notación 'hde':
      • h Representa el número de cabezas, que es 4.
      • d Representa el tamaño de incrustación, que nuevamente es 16.
      • e Representa el nuevo tamaño de número por cabeza (D_K), que es 4.
      • La forma de este tensor es [4, 16, 4].
  2. Tensor de salida:
    • La salida tiene la notación 'hse':
      • h representa 4 cabezas.
      • s representa 5 palabras.
      • e representa 4 números por cabeza.
      • La forma del tensor de salida es [4, 5, 4].

Esta ecuación de Einsum realiza un producto DOT entre las consultas (HSE) y las claves transpuestas (HEK) para obtener decenas de forma [h, seq_len, seq_len]dónde:

  • H -> Número de cabezas.
  • S y K -> Longitud de secuencia (número de tokens).
  • E -> Dimensión de cada cabezal (D_K).

La división de (D_K ** 0.5) escala los puntajes para estabilizar los gradientes. Softmax se aplica para obtener pesos de atención:

Ahora concatenamos todas las cabezas de Token 1

Finalmente multiplicemos por la última matriz W_Output como en la fórmula anterior

Pensamientos finales

En esta publicación de blog he implementado una versión simple del mecanismo de atención. No es así como se implementa realmente en los marcos modernos, pero mi alcance es proporcionar algunas ideas para permitir a cualquier persona comprender cómo funciona. En futuros artículos pasaré por toda la implementación de una arquitectura de transformador.

Seguirme TDS ¡Si te gusta este artículo! 😁

💼 LinkedIn ️ | 🐦 X (Twitter) | 💻 Sitio web


A menos que se indique lo contrario, las imágenes son del autor