FlashKDA de código abierto de Moonshot AI: núcleos CUTLASS para la atención de Kimi Delta con procesamiento por lotes de longitud variable y puntos de referencia H20

El equipo detrás de Kimi.ai (Moonshot AI) acaba de hacer una contribución significativa al espacio de infraestructura de IA de código abierto. El equipo de investigación ha hecho una contribución significativa al espacio de infraestructura de IA de código abierto. Lanzaron FlashKDA (Flash Kimi Delta Attention), una implementación de kernel de alto rendimiento basada en CUTLASS del mecanismo Kimi Delta Attention (KDA). La biblioteca FlashKDA está disponible en GitHub bajo una licencia MIT. Ofrece aceleraciones de precarga de 1,72× a 2,22× sobre la línea base de atención lineal flash en las GPU NVIDIA H20 y funciona como backend directo para la popular biblioteca de atención lineal flash.

¿Qué es la atención Kimi Delta y por qué es importante?

Para comprender FlashKDA, primero es útil comprender dónde se encuentra en el panorama de atención de LLM.

La atención softmax estándar tiene una complejidad cuadrática con respecto a la longitud de la secuencia, lo que significa que a medida que se introduce un contexto más largo en un modelo, los costos de computación crecen extremadamente rápido. Esto ha impulsado una ola de investigación sobre mecanismos de atención lineal, que se aproximan o reemplazan la operación softmax para lograr un escalado lineal. Kimi Delta Attention (KDA) es la contribución de Moonshot AI a este espacio: un mecanismo de atención lineal que refina Gated DeltaNet con un mecanismo de activación por canales más fino, lo que permite un uso más efectivo de la memoria RNN de estado finito limitada.

KDA no es sólo un prototipo de investigación. Es el mecanismo de atención central en Kimi Linear, el modelo híbrido de código abierto de Moonshot AI con 48 mil millones de parámetros totales y 3 mil millones de parámetros activados. Kimi Linear utiliza una relación KDA a MLA (atención latente de múltiples cabezales) de 3:1 (tres capas KDA por cada capa de atención global), lo que reduce el uso de caché KV hasta en un 75 % durante la generación de secuencias largas y, al mismo tiempo, logra un rendimiento de decodificación hasta 6 veces mayor con 1 millón de longitud de contexto en comparación con la atención total. FlashKDA es el kernel CUDA de nivel de producción que hace que esa arquitectura sea rápida durante el precarga.

Concretamente, el paso directo de KDA toma consultas (q), claves (k), valores (v), una puerta antes de la activación (g) y logits beta (beta), junto con un factor de escala, un tensor de salida (out) y parámetros de puerta: A_log (parámetro de puerta de registro por cabeza), dt_bias (sesgo de puerta) y lower_bound (límite inferior de puerta, que va de -5,0 a 0). La activación sigmoidea en beta la aplica internamente el núcleo. El mecanismo también admite estados recurrentes iniciales y finales opcionales, lo que resulta útil para la inferencia de varios turnos en la que desea transferir el estado a través de solicitudes.

La formulación recurrente significa que el modelo puede procesar de manera eficiente secuencias largas durante la generación. Pero el precargado eficiente de estas arquitecturas aún requiere núcleos de GPU altamente optimizados, que es exactamente lo que ofrece FlashKDA.

Debajo del capó: CUTLASS en Hopper

FlashKDA se basa en CUTLASS, la biblioteca de código abierto de NVIDIA de abstracciones de plantillas CUDA C++ para álgebra lineal de alto rendimiento y desarrollo de kernel personalizado. CUTLASS permite a los desarrolladores escribir kernels que aprovechen al máximo la arquitectura Tensor Core de NVIDIA, y es la misma base utilizada por bibliotecas como FlashAttention-3.

La biblioteca está dirigida a SM90 y superiores, es decir, a la arquitectura Hopper de NVIDIA (H100, H20) y posteriores. Los requisitos mínimos son CUDA 12.9 y PyTorch 2.4. La base del código es predominantemente CUDA (56,4%), con enlaces Python (36,2%) y código adhesivo C++ (6,7%).

La API principal es flash_kda.fwd, que recibe las siguientes entradas:

q, k, v, g: todo en bf16 con forma [B, T, H, K] o [B, T, H, V] (donde g es la puerta antes de la activación) beta: forma logits beta bf16 [B, T, H] (sigmoide aplicado internamente) escala: fp32 factor de escala escalar: bf16 tensor de salida en forma [B, T, H, V]

A_log, dt_bias, lower_bound: parámetros de puerta fp32 estado_inicial, estado_final: estados recurrentes bf16 o fp32 opcionales cu_seqlens: longitudes de secuencia acumulativas int64 opcionales para procesamiento por lotes de longitud variable

Una restricción actual: el núcleo requiere K = V = 128 para la dimensión de la cabeza.

El soporte de procesamiento por lotes de longitud variable a través de cu_seqlens es particularmente notable para uso en producción. En el servicio de inferencia real, las solicitudes de un lote rara vez comparten la misma longitud de secuencia. Ser capaz de empaquetar múltiples secuencias de diferentes longitudes en una única llamada al kernel es un requisito clave para los sistemas de servicio de alto rendimiento.

Resultados de referencia: 1,72× a 2,22× en H20

Los resultados de referencia (al 20 de abril de 2026) comparan flash_kda con fla_chunk_kda (la implementación flash-linear-attention existente) en una longitud de secuencia de T=8192, dimensión de cabeza D=128 y dos configuraciones de recuento de cabezas: H=96 y H=64. Cada punto de referencia se ejecutó con 30 iteraciones de calentamiento, 200 iteraciones de medición y 5 repeticiones.

Para H=96:

Caseflash_kda (ms)fla_chunk_kda (ms)SpeedupFixed2.62194.50521.72×Varlen, seq_lens=[1300, 547, 2048, 963, 271, 3063]2.34204.57171.95×Varlen, lente_secuencia=1024 × 82.01004.46682.22×

Para H=64:

Caseflash_kda (ms)fla_chunk_kda (ms)SpeedupFixed1.61992.95871.83×Varlen, seq_lens=[1300, 547, 2048, 963, 271, 3063]1.70273.05951.80×Varlen, lente_secuencia=1024 × 81.39303.04122.18×

La aceleración máxima de 2,22 × aparece en el caso de longitud variable uniforme (seq_lens = 1024 × 8, ocho secuencias de longitud 1024 que suman T = 8192). La caja de longitud fija ofrece el mínimo de la gama con 1,72×. En ambas configuraciones de cabezales y en los tres escenarios de secuencia, FlashKDA supera consistentemente la línea base de atención lineal flash por un margen significativo.

Integración con atención lineal flash

Uno de los aspectos más prácticos de FlashKDA es su historia de integración. Una vez instalado, FlashKDA se envía automáticamente desde chunk_kda de flash-linear-attention, lo que significa que las bases de código existentes que utilizan flash-linear-attention no necesitan cableado manual para aprovechar el kernel más rápido. Se realiza un seguimiento de la integración en flash-linear-attention PR #852.

La instalación es sencilla:

git clone https://github.com/MoonshotAI/FlashKDA.git flash-kda cd flash-kda actualización del submódulo git –init –recursive pip install -v.

El conjunto de pruebas de corrección (tests/test_fwd.py) ejecuta una verificación de coincidencia exacta con una implementación de referencia de PyTorch y realiza una validación cruzada con flash-linear-attention. Esto brinda a los desarrolladores de IA una base confiable para auditar el comportamiento del kernel antes de implementarlo en producción.

Conclusiones clave

FlashKDA es el kernel CUDA de código abierto basado en CUTLASS de Moonshot AI para Kimi Delta Attention (KDA), que ofrece una velocidad de precarga de 1,72 × –2,22 × sobre la línea base de atención lineal flash en las GPU NVIDIA H20. KDA amplía Gated DeltaNet con compuerta detallada por canal: es el mecanismo de atención central detrás de Kimi Linear, un modelo híbrido de 48 B en total/3 B de parámetros activos que reduce el uso de caché KV hasta en un 75 % y logra un rendimiento de decodificación hasta 6 veces mayor con una longitud de contexto de 1 M. El kernel se dirige al hardware SM90+ (NVIDIA Hopper — H100, H20 y superiores), requiere CUDA 12.9+ y PyTorch 2.4+, y actualmente admite una dimensión de cabeza fija de K = V = 128. El procesamiento por lotes de longitud variable se admite de forma nativa a través del parámetro cu_seqlens, lo que permite empaquetar múltiples secuencias de diferentes longitudes en una sola llamada al kernel, una característica crítica para el servicio de inferencia de alto rendimiento. Una vez instalado, FlashKDA se envía automáticamente desde chunk_kda de flash-linear-attention, lo que lo convierte en una actualización de rendimiento inmediata para cualquier base de código existente que ya utilice la biblioteca flash-linear-attention; no se requieren cambios en la arquitectura.

Consulte el repositorio de GitHub. Además, no dude en seguirnos en Twitter y no olvide unirse a nuestro SubReddit de más de 130.000 ML y suscribirse a nuestro boletín. ¡Esperar! estas en telegrama? Ahora también puedes unirte a nosotros en Telegram.

¿Necesita asociarse con nosotros para promocionar su repositorio de GitHub O su página principal de Hugging O su lanzamiento de producto O seminario web, etc.? Conéctate con nosotros