Parallax: una atención lineal local parametrizada que mantiene Softmax y agrega una rama de corrección de covarianza aprendida

El mecanismo de atención de Transformer apenas ha cambiado desde 2017. La mayor parte del trabajo de eficiencia ha intentado reemplazar directamente la atención de softmax. Un nuevo periódico toma una ruta diferente. Mantiene la atención de Softmax y se atornilla a una rama de corrección.

Un equipo de investigadores de la Universidad Northwestern, Tilde Research y la Universidad de Washington introducen una atención lineal local parametrizada llamada “Parallax” que se adapta al preentrenamiento de LLM y codiseña con Muon.

Parallax no persigue la eficiencia recortando la computación. Agrega computación deliberadamente y luego hace que esa computación sea más barata de ejecutar en GPU modernas.

¿Qué es el paralaje?

Parallax se basa en la atención lineal local (LLA). LLA proviene del marco de regresión en el momento de la prueba. Ese marco interpreta la atención como un solucionador de regresión sobre pares clave-valor.

Desde este punto de vista, las claves son puntos de datos de entrenamiento. Los valores son etiquetas. La consulta es el punto de prueba. La atención Softmax es un estimador no paramétrico llamado Nadaraya-Watson. Se ajusta a una función constante local para cada consulta.

LLA actualiza esa estimación constante local a una estimación lineal local. El equipo de investigación demuestra que esto produce un error cuadrático medio integrado estrictamente más pequeño. El beneficio son mejores compensaciones entre sesgo y varianza para la memoria asociativa.

Pero LLA tiene un problema a gran escala. Su avance exacto requiere resolver un sistema lineal para cada consulta. Eso utiliza un solucionador de gradiente conjugado paralelo (CG). El solucionador CG crea tres problemas: E/S intensiva, una difícil compensación entre regularización y expresividad e incompatibilidad de baja precisión.

Parallax elimina el solucionador. En cambio, aprende una matriz de proyección adicional. El equipo de investigación escribe esto como ρi = WRxi. Aquí, WR es una matriz que se puede aprender y que prueba la covarianza KV directamente desde la entrada de la capa.

Entonces Parallax mantiene el principio lineal local. Simplemente reemplaza la resolución por consulta con un proyector similar a una consulta aprendida. Eso lo hace más simple, más eficiente y más fácil de implementar.

Cómo funciona el mecanismo

Parallax reformula LLA como atención softmax más una corrección aditiva. El resultado es igual al resultado de atención de softmax menos un término de covarianza proyectado. En la notación del artículo de investigación, ese término es la covarianza KV multiplicada por la sonda aprendida ρi.

El equipo de investigación también elimina una parte de LLA llamada factor de amplificación de límites, establecida en cero. Esto es necesario para la estabilidad. Una vez que la sonda es paramétrica, la interpretación geométrica original se rompe. Dejar el factor activado podría hacer que la escala diverja o cambie de signo.

Parallax se encuentra dentro de una familia de mecanismos de atención. El equipo de investigación los organiza en el artículo según tres ejes: el ancho de banda, la construcción de la sonda y la estructura afín. En un extremo, Parallax degenera exactamente a atención softmax cuando la norma de la sonda llega a cero.

Establecer WR = 0 hace que una capa Parallax se comporte de manera idéntica a la atención de softmax. Por lo tanto, un punto de control de Transformer previamente entrenado se puede convertir agregando WR y ajustándolo.

El argumento del hardware

Parallax hereda la estructura de transmisión de FlashAttention. Agrega una rama de covarianza que reutiliza el mismo flujo clave-valor.

El equipo de investigación amplía la delantera en dos ramas de puntuación paralelas. Ambas ramas comparten el máximo en línea, el factor de reescalado y los mosaicos K y V. Entonces Parallax no necesita E/S adicionales por iteración.

La propiedad clave es una mayor intensidad aritmética (IA). La IA es la relación entre operaciones de punto flotante y tráfico de memoria de gran ancho de banda. En el régimen donde domina el trabajo KV, Parallax aproximadamente duplica la intensidad aritmética. Agrega computación mientras reutiliza el mismo flujo de memoria.

Esto desvía la atención hacia un régimen más vinculado a la informática. Ese es exactamente el régimen en el que la optimización del kernel ayuda en el hardware moderno.

El equipo de investigación creó un prototipo de un núcleo de decodificación en CuTeDSL en las GPU NVIDIA Hopper. Las instrucciones matmul del núcleo tensor de Hopper funcionan en mosaicos de al menos 64 filas. Un paso de decodificación proporciona sólo una fila de consulta. Por lo tanto, los productos QK y RK se pueden calcular conjuntamente, dentro de las instrucciones que ya se emiten con atención estándar.

Realizaron perfiles contra FlashAttention 2 y 3 en GPU H200 con precisión BF16. Barrieron los tamaños de lote de 1 a 2048 y las longitudes de contexto de 128 a 32 768. El kernel prototipo iguala o supera a FlashAttention en todas las configuraciones. La siguiente figura muestra aceleraciones de 1,54× en la configuración de coincidencia de cálculo y de 1,14× en la configuración de coincidencia de E/S.

https://arxiv.org/pdf/2605.29157

Lo que muestran los experimentos

El equipo de investigación validó Parallax en tareas sintéticas y en preentrenamiento LLM en escalas 0.6B y 1.7B. Los modelos utilizaron la arquitectura Qwen-3 en el repositorio de torchtitan. Se entrenaron en el conjunto de datos Ultra-FineWeb con una longitud de contexto de 4096. Las líneas de base incluyeron atención softmax (Transformer), Mamba, Gated DeltaNet, MesaNet y Kimi DeltaAttention.

En MAD-Benchmark, Parallax logró la precisión general más alta con un promedio de 0,716. Mejoró constantemente las tareas orientadas a la recuperación, como la recuperación en contexto y la copia selectiva. Se mantuvo competitivo en tareas de compresión y memorización.

En el modelado del lenguaje, Parallax con Muon logró la mayor perplejidad en ambas escalas. También registró la precisión descendente promedio más alta. Con 1.7B, Parallax obtuvo un promedio de 62,45 frente al 61,43 del Transformer.

Dos controles prueban de dónde proviene la ganancia. Un transformador con parámetros coincidentes cerró solo una pequeña fracción de la brecha. Un Parallax emparejado por computadora aún superó ambas líneas de base. El artículo sostiene que esto apunta al mecanismo en sí, no a parámetros o cálculos adicionales.

El giro del optimizador

Un hallazgo fundamental es la interacción optimizador-arquitectura. Parallax muestra una gran ventaja bajo Muon. Con AdamW, la ventaja se reduce notablemente o incluso desaparece.

Muon es un optimizador reciente para parámetros matriciales en capas ocultas. Utiliza el factor polar del buffer de impulso, por lo que las actualizaciones tienen el número de condición exactamente uno. Trabajos anteriores muestran que esto produce matrices de peso mejor acondicionadas.

El equipo de investigación del artículo rastrea la brecha hasta la rama de corrección. Definen una relación de corrección a salida (COR). Bajo Muon, COR supera los 8 en las capas más profundas. Con AdamW, se mantiene por debajo de 4.

La proyección WR se ve afectada desproporcionadamente. Su rango estable colapsa bajo AdamW pero se mantiene alto bajo Muon. Un experimento de activación confirma el patrón. Bajo AdamW, el modelo aprende a suprimir la rama de corrección en lugar de usarla.

El equipo de investigación llama a esto la primera demostración empírica de un sólido diseño de código optimizador de arquitectura para mecanismos de atención. No afirman que Muon con WSD sea la receta óptima. Una ablación del apéndice muestra que la ventaja se reduce durante la fase de decadencia.

En qué se diferencian las puntuaciones

Parallax también produce distribuciones de puntuación diferentes a las de la atención softmax. Sus pesos por token pueden tomar valores negativos y superar uno en magnitud. Los pesos softmax estándar no pueden hacer esto.

El equipo de investigación informa tres efectos. Parallax puede restar activamente componentes de valor de tokens irrelevantes. Reduce sustancialmente la pérdida de atención en el primer token. Su entropía base softmax se mantiene más alta, dando pesos de atención más difusos.

Fortalezas y debilidades y preguntas abiertas

Fortalezas

Mantiene intacta la atención de softmax, por lo que un Transformer previamente entrenado puede realizar la conversión agregando WR y realizando ajustes. No agrega E/S adicionales por iteración al reutilizar el flujo clave-valor de FlashAttention. Duplica la intensidad aritmética, con un kernel prototipo que iguala o supera a FlashAttention 2/3 en decodificación. Muestra perplejidad constante y ganancias posteriores bajo controles de comparación de parámetros y de cálculo.

Debilidades y preguntas abiertas

Las ganancias dependen en gran medida de Muon; Con AdamW la ventaja desaparece en gran medida. La causa precisa de la dependencia del optimizador sigue siendo una cuestión abierta. Los resultados se detienen en una escala de 1.700 millones, sin MoE, contexto más largo o ejecuciones más grandes. La ventaja se erosiona durante la fase de desintegración de WSD, y solo se fija parcialmente mediante el recocido de desintegración de peso.

Conclusiones clave

Parallax mantiene la atención de softmax y agrega una rama de corrección de covarianza aprendida, reemplazando el solucionador de gradiente conjugado por consulta de LLA. Duplica la intensidad aritmética mientras reutiliza el mismo flujo KV, con un núcleo de decodificación que iguala o supera a FlashAttention 2/3. Perplejidad constante y ganancias posteriores en 0,6 mil millones y 1,7 mil millones, manteniéndose bajo controles de parámetros y cálculos. Las ganancias dependen en gran medida de Muon; bajo AdamW la ventaja se reduce notablemente o desaparece. Establecer WR = 0 recupera exactamente la atención de softmax, por lo que los Transformers previamente entrenados pueden realizar conversiones agregando WR y realizando ajustes.

Consulte el documento y el repositorio. Además, no dude en seguirnos en Twitter y no olvide unirse a nuestro SubReddit de más de 150.000 ML y suscribirse a nuestro boletín. ¡Esperar! estas en telegrama? Ahora también puedes unirte a nosotros en Telegram.

¿Necesita asociarse con nosotros para promocionar su repositorio de GitHub O su página principal de Hugging O su lanzamiento de producto O seminario web, etc.? Conéctate con nosotros