k-means ha sido una herramienta fuera de línea durante décadas. Lo ejecuta una vez para preprocesar los datos y luego continúa. Un equipo de investigadores de UC Berkeley y UT Austin lanzó Flash-KMeans, una nueva biblioteca de código abierto dirigida a un entorno diferente. Los canales de IA modernos ahora llaman k-means dentro de bucles de inferencia y entrenamiento. A esa frecuencia, la latencia por llamada importa más que los FLOP teóricos.
Flash-KMeans es una implementación compatible con IO del estándar k-means de Lloyd. No cambia las matemáticas y no se aproxima. Solo reestructura la forma en que el algoritmo mueve datos en una GPU. En una NVIDIA H200, el equipo de investigación informó una velocidad de hasta 17,9 veces de extremo a extremo con respecto a la mejor línea de base. Contra NVIDIA cuML reportan 33×. Contra FAISS reportan más de 200×.
¿Qué es Flash-KMeans?
Flash-KMeans es una biblioteca k-means por lotes escrita en núcleos de GPU Triton. Se envía bajo Apache 2.0 y se instala con pip install flash-kmeans.
La salida es matemáticamente idéntica a la k-media estándar de Lloyd. La aceleración proviene del flujo de datos a nivel de kernel, no de saltarse el trabajo. Eso lo separa de los métodos algorítmicos como la poda de desigualdad de triángulos o el muestreo de conjuntos de núcleos.
Una iteración estándar de Lloyd tiene dos etapas. La etapa de asignación calcula la distancia de cada punto a cada centroide y luego elige el más cercano. La etapa de actualización promedia los puntos en cada grupo para formar nuevos centroides. Ambas etapas son aritmética simple. En las GPU, ambas tienen cuellos de botella debido a la memoria, no a la computación.
Los dos cuellos de botella que ataca
El primer obstáculo es la etapa de asignación. El código estándar construye una matriz de distancia completa D de forma N×K en memoria de alto ancho de banda (HBM). Escribe la matriz y luego la vuelve a leer para ejecutar argmin. Para N=65536, K=1024, d=128, B=32, la distancia matemática toma 2,6 ms. Escribir y consumir D tarda unos 23 ms. La matriz es el costo, no la aritmética.
Flash-KMeans reemplaza esto con FlashAssign. El diseño toma prestado de FlashAttention. FlashAssign transmite mosaicos de puntos y centroides desde HBM a la SRAM en el chip. Fusiona el cálculo de distancias con un argmin en línea. La matriz N×K completa nunca se materializa. Esto reduce la complejidad IO dominante de O (NK) a O (Nd + Kd). A nivel de kernel, FlashAssign alcanza hasta 21,2×. En un caso, redujo la asignación de 122,5 ms a 5,8 ms.
El segundo cuello de botella es la etapa de actualización del centroide. El código estándar utiliza adiciones atómicas de estilo disperso. Cada hilo agrega su punto en un búfer de suma compartido codificado por la identificación del clúster. Muchos subprocesos llegan al mismo grupo “caliente” a la vez. Eso provoca contención atómica y serialización de hardware. El equipo de investigación midió sólo 50 GB/s de ancho de banda efectivo aquí en un H200.
Flash-KMeans reemplaza esto con Sort-Inverse Update. Ordena el vector de asignación 1D por ID de clúster mediante argsort. Los identificadores de clúster idénticos forman segmentos contiguos. Cada bloque de hilo reduce un segmento en el chip y luego emite una adición atómica por segmento. La matriz de puntos pesados nunca se permuta físicamente. Las operaciones atómicas caen de (O((K+NBN)d))(O((K + \frac{N}{B_N})d)) . El kernel alcanza hasta 6,3×.
Punto de referencia
El equipo de investigación lo prueba en un H200 con CUDA 12.8, datos FP16 y d=128. Barren N, K y el tamaño de lote B. Se comparan con cuatro líneas de base optimizadas: fast_pytorch_kmeans, fastkmeans, cuML y FAISS.
Un modo de falla importa para el contexto. Las implementaciones estándar de PyTorch se quedan sin memoria en regímenes de K grande. No pueden materializar la matriz N×K. FAISS es la biblioteca estándar de la industria en muchos sistemas de búsqueda de vectores de producción.
La biblioteca también se queda sin núcleo. Con mil millones de puntos (K=32768, d=128), finaliza una iteración en 41,4 segundos, frente a 261,8 segundos para la línea base. Utiliza superposición de flujo fragmentado para ocultar la transferencia PCIe detrás de la computación. Una heurística de compilación con reconocimiento de caché también reduce la sobrecarga de ajuste hasta 175 veces, dentro del 0,3% de la velocidad sintonizada.
Explicador interactivo de MTP
Marktechpost · Explicador interactivo
Flash-KMeans: k-means exactos, reconstruidos alrededor de la memoria de la GPU
Las mismas matemáticas de Lloyd que las k-medias estándar: más rápidas sólo gracias al flujo de datos. Ejecute la agrupación en clústeres en vivo, observe el cuello de botella de la actualización y mida el IO que elimina.
17,9×de un extremo a otro frente a la mejor línea de base
33×frente a NVIDIA cuML
200×+frente a FAISS
1Bpuntos, fuera del núcleo
1 · Agrupación en vivo
2 · Contención de actualización
3 · calculadora de E/S
Puntos de datos (N) 800
Grupos (K) 5
Ejecutar paso Nuevos datos
Iteración0
cambio de centroide—
Estadoinactivo
Esto ejecuta k-means reales de Lloyd en su navegador en puntos 2-D. El algoritmo es idéntico a lo que acelera Flash-KMeans; solo difiere el flujo de datos de la GPU. Cada paso = una asignación + una actualización del centroide.
Pulsa reproducir. La actualización de dispersión estándar se serializa cuando los bloques escriben el mismo centroide “activo” (puestos rojos). La actualización de clasificación inversa ordena primero los ID de los clústeres, de modo que cada bloque fusiona segmentos contiguos con una adición atómica, sin conflictos.
Reproducir línea de tiempo Restablecer
Atómicos estándarO(N·d)
Ordenamiento atómico inversoO((K+N/B)·d)
Ancho de banda estándar medido50 GB/s
aceleración del kernel6,3×
Las actualizaciones estándar emiten una adición atómica por token. Muchos hilos golpean el mismo centroide a la vez, lo que genera contención. La clasificación por ID de clúster convierte las dispersiones en reducciones a nivel de segmento en la memoria del chip.
—menos tráfico de HBM para el paso de asignación (teórico)
Casos de uso
K-means exactos y más rápidos cambian lo que puede ejecutar en línea, no solo sin conexión.
Indexación de búsqueda vectorial: FAISS construye sus índices de búsqueda con k-medias. K-means más rápido le permite volver a indexar a medida que cambian los datos, en lugar de reconstruirlos de la noche a la mañana. Enrutamiento de atención escasa: enrutamiento de tokens de clúster de transformadores y tácticas para enrutar la atención. Las k-medias de milisegundos hacen que esto sea viable dentro del ciclo de inferencia. Compresión de caché KV: ClusterKV agrupa tokens en un espacio semántico para comprimir el caché. Una agrupación más económica hace que la compresión por capa y por paso sea práctica. Cuantización de KV de bits bajos: los métodos recientes agrupan las entradas de KV en libros de códigos, repetidamente. Una agrupación en clústeres más rápida reduce el costo de preprocesamiento. Transformadores de difusión: Sparse VideoGen2 llama k-medias por lotes durante pases hacia adelante. Permuta tokens por similitud semántica para explotar la escasez.
Usándolo
La API refleja faiss y sklearn. La siguiente llamada agrupa un tensor por lotes (B, N, d).
También está disponible una interfaz estilo scikit-learn.
El kernel se envía automáticamente por forma y tipo. Una ruta de D pequeña maneja d≤512. Una ruta dividida en D maneja d más grande sin materializar la matriz de distancias. Las ejecuciones de múltiples GPU se activan automáticamente para datos de N grandes almacenados en la memoria de la CPU.
Conclusiones clave
Flash-KMeans es exacto, no aproximado: las mismas matemáticas de Lloyd, aceleradas únicamente por el flujo de datos de la GPU. FlashAssign fusiona distancia + argmin en línea, cortando la asignación IO de O(NK) a O(Nd+Kd) — hasta 21,2×. La actualización Sort-Inverse clasifica los ID de los clústeres en segmentos, reemplazando los átomos dispersos, hasta 6,3×. Informa hasta 17,9 veces de extremo a extremo, 33 veces más que cuML y más de 200 veces más que FAISS en un H200. Escala fuera del núcleo a mil millones de puntos y reduce la sobrecarga de sintonización hasta 175×.
Consulte el documento y el repositorio. Además, no dude en seguirnos en Twitter y no olvide unirse a nuestro SubReddit de más de 150.000 ml y suscribirse a nuestro boletín. ¡Esperar! estas en telegrama? Ahora también puedes unirte a nosotros en Telegram.
¿Necesita asociarse con nosotros para promocionar su repositorio de GitHub O su página principal de Hugging O su lanzamiento de producto O seminario web, etc.? Conéctate con nosotros