MiniMax Sparse Attention (MSA): una atención dispersa en bloques de dos ramas entrenada en un MoE de parámetros 109B con un presupuesto de 3T tokens

MiniMax lanzó MSA (MiniMax Sparse Attention), un método de atención dispersa creado directamente en Grouped Query Attention (GQA). Apunta a un cuello de botella: el costo cuadrático de la atención softmax en un contexto prolongado. El equipo de investigación de MiniMax lo probó dentro de un modelo de mezcla de expertos con parámetros 109B entrenado con datos multimodales nativos. También abrieron un núcleo de inferencia y enviaron un modelo de producción, MiniMax-M3.

¿Qué es MSA (Atención dispersa MiniMax)?

MSA (MiniMax Sparse Attention) divide la atención en dos etapas: una rama de índice y una rama principal. La rama de índice decide qué bloques clave-valor debe leer cada consulta. Luego, la rama principal ejecuta la atención softmax exacta solo sobre esos bloques.

La selección ocurre con granularidad de bloque, no por token. El tamaño de bloque predeterminado es Bk = 128 tokens. Cada consulta y grupo GQA mantiene k = 16 bloques. Eso fija el presupuesto por consulta en kBk = 2048 tokens de valor-clave.

Las dos estructuras de costos difieren. La atención densa de GQA se escala por consulta como O(N), el contexto completo. MSA escala como O (kBk), que permanece fijo a medida que N crece. Por lo tanto, la brecha informática se amplía a medida que aumenta la longitud del contexto.

La selección se comparte dentro de cada grupo de GQA, pero es independiente entre los grupos. Un encabezado de valor clave sirve a varios encabezados de consulta y comparten un conjunto de bloques. Diferentes grupos pueden atender a diferentes regiones de largo alcance.

Cómo funcionan las dos ramas

La rama Index agrega solo dos matrices de proyección a una capa GQA estándar. Define un encabezado de consulta de índice por grupo de GQA y un encabezado de clave de índice compartido. Puntúa tokens clave visibles y luego agrupa al máximo esas puntuaciones al nivel de bloque.

Luego, un operador Top-k selecciona los bloques con la puntuación más alta por consulta y grupo. El bloque local que contiene la consulta siempre se incluye. Esto evita que el selector elimine la vecindad inmediata de la consulta.

La rama principal reúne tokens causalmente visibles de los bloques seleccionados. Aplica atención softmax de producto escalado restringida a esos tokens. Cada encabezado de consulta mantiene su propia proyección de consulta pero comparte el conjunto de bloques del grupo.

Una visualización en el informe muestra lo que selecciona el indexador aprendido. Los cabezazos se concentran en la diagonal local y el primer bloque. El resto del presupuesto lo reservan para unos cuantos franjas de largo alcance.

https://arxiv.org/pdf/2606.13392v1
https://arxiv.org/pdf/2606.13392v1

Cómo se entrena MSA

La selección top-k no es diferenciable, por lo que la pérdida de modelado del lenguaje no puede entrenar las proyecciones del índice. MSA resuelve esto con una pérdida de alineación de KL. La pérdida coincide con la distribución de la sucursal del índice con el patrón de atención de la sucursal principal. El maestro es la distribución de la rama principal promediada por el grupo sobre los tokens seleccionados.

Tres mecanismos estabilizan el entrenamiento escaso. La separación de gradiente aplica un gradiente detenido a la entrada de la rama de índice. Esto limita la pérdida de KL a las proyecciones del índice, no a la columna vertebral. Sin él, los coeficientes KL más grandes provocaron picos de gradiente y divergencia de pérdidas.

Indexer Warmup presta plena atención en ambas ramas durante las primeras iteraciones. El indexador aprende de la pérdida de KL antes de controlar el enrutamiento. El bloque local forzado reserva un espacio para el contexto cercano.

Las ablaciones dieron forma a la receta final. Una de las primeras variantes agregó un encabezado de valor de Index Branch con su propia salida. Una vez que se utiliza el calentamiento, ese valor ya no es necesario. El diseño final lo descarta por motivos de eficiencia.

MSA admite dos rutas de formación. MSA-PT entrena desde cero después de un calentamiento del indexador de 40B tokens. MSA-CPT convierte un denso punto de control GQA entrenado en tokens de 2,6T. Luego continúa con 400 mil millones de tokens, incluidos 40 mil millones de tokens de calentamiento.

El codiseño del kernel

La escasez teórica no se convierte en velocidad sin una ruta de GPU correspondiente. MSA combina el algoritmo con dos ideas del núcleo.

La primera es la selección Top-k sin exp. Softmax preserva el orden, por lo que clasificar las puntuaciones brutas produce índices idénticos. El kernel omite los pasos max, exp y sum antes de la selección. En un contexto de 128K con k = 16, se ejecutó 5,1 veces más rápido que torch.topk. También superó al kernel de selección de raíz de TileLang por 3,7 ×.

El segundo es la atención escasa externa de KV con recopilación de consultas. La iteración sobre bloques KV aumenta la intensidad aritmética en comparación con la iteración sobre consultas. El kernel empaqueta posiciones de consulta ⌈128/G⌉ en un MMA de puntuación de 128 × 128. Un avance de dos fases divide la atención y combina pasos entre CTA.

El kernel de código abierto, fmha_sm100, está dirigido a las GPU NVIDIA SM100. Incluye FlashAttention denso y escasos núcleos Top-k bajo una licencia del MIT. Admite precisión BF16, FP8, NVFP4 y FP4.

Cómo se compara MSA con otros métodos dispersos

El equipo de investigación posiciona a MSA frente a cuatro diseños dispersos entrenados de forma nativa.

La siguiente tabla resume las diferencias que describe.

MétodoBackboneGranularidad de selecciónIndexador/señal de selecciónMSAGQABNivel de bloque (B_k = 128), pérdida de alineación kKL superior por grupo GQANSAMQA / MHAComprimido + bloques seleccionados + ventana deslizanteEntrenamiento nativo (de extremo a extremo)InfLLM-V2Denso↔conmutable dispersoSelección de bloque sin parámetros + ventana deslizanteSin parámetros (sin indexador entrenado)MoBAGQABloques KV muy grandes (claves promediadas por bloques) Solo gradiente LM DSAMLA (modo MQA) Nivel de token; único Top-k compartido entre cabezalesIndizador relámpago ReLU

El par distintivo de MSA es el intercambio Top-k por grupo GQA combinado con la selección a nivel de bloque. Esto mantiene las lecturas de KV contiguas y le da a cada grupo su propia recuperación.

El lado de la calidad se mantiene. Ambos modelos dispersos siguen siendo ampliamente competitivos con la línea base de Atención Total.

La siguiente tabla muestra resultados representativos del presupuesto de tokens 3T.

BenchmarkFullMSA-PTMSA-CPTMMLU67.067.266.8GSM8K76.277.773.7HumanEval61.064.057.9RULER-8K79.884.277.2RULER-32K75.077.575.7VideoMME41.1145.4839.65

Después de una extensión de contexto prolongado, MSA-CPT se mantuvo cerca de Full en HELMET-128K y RULER-128K. Cada consulta todavía atiende solo 2048 tokens clave-valor.

Zona de juegos explicativa