Conozca Flash-KMeans: un K-Means exacto y compatible con IO que se ejecuta más de 200 veces más rápido que FAISS en GPU

k-means ha sido una herramienta fuera de línea durante décadas. Lo ejecuta una vez para preprocesar los datos y luego continúa. Un equipo de investigadores de UC Berkeley y UT Austin lanzó Flash-KMeans, una nueva biblioteca de código abierto dirigida a un entorno diferente. Los canales de IA modernos ahora llaman k-means dentro de bucles de inferencia y entrenamiento. A esa frecuencia, la latencia por llamada importa más que los FLOP teóricos.

Flash-KMeans es una implementación compatible con IO del estándar k-means de Lloyd. No cambia las matemáticas y no se aproxima. Solo reestructura la forma en que el algoritmo mueve datos en una GPU. En una NVIDIA H200, el equipo de investigación informó una velocidad de hasta 17,9 veces de extremo a extremo con respecto a la mejor línea de base. Contra NVIDIA cuML reportan 33×. Contra FAISS reportan más de 200×.

¿Qué es Flash-KMeans?

Flash-KMeans es una biblioteca k-means por lotes escrita en núcleos de GPU Triton. Se envía bajo Apache 2.0 y se instala con pip install flash-kmeans.

La salida es matemáticamente idéntica a la k-media estándar de Lloyd. La aceleración proviene del flujo de datos a nivel de kernel, no de saltarse el trabajo. Eso lo separa de los métodos algorítmicos como la poda de desigualdad de triángulos o el muestreo de conjuntos de núcleos.

Una iteración estándar de Lloyd tiene dos etapas. La etapa de asignación calcula la distancia de cada punto a cada centroide y luego elige el más cercano. La etapa de actualización promedia los puntos en cada grupo para formar nuevos centroides. Ambas etapas son aritmética simple. En las GPU, ambas tienen cuellos de botella debido a la memoria, no a la computación.

Los dos cuellos de botella que ataca

El primer obstáculo es la etapa de asignación. El código estándar construye una matriz de distancia completa D de forma N×K en memoria de alto ancho de banda (HBM). Escribe la matriz y luego la vuelve a leer para ejecutar argmin. Para N=65536, K=1024, d=128, B=32, la distancia matemática toma 2,6 ms. Escribir y consumir D tarda unos 23 ms. La matriz es el costo, no la aritmética.

Flash-KMeans reemplaza esto con FlashAssign. El diseño toma prestado de FlashAttention. FlashAssign transmite mosaicos de puntos y centroides desde HBM a la SRAM en el chip. Fusiona el cálculo de distancias con un argmin en línea. La matriz N×K completa nunca se materializa. Esto reduce la complejidad IO dominante de O (NK) a O (Nd + Kd). A nivel de kernel, FlashAssign alcanza hasta 21,2×. En un caso, redujo la asignación de 122,5 ms a 5,8 ms.

El segundo cuello de botella es la etapa de actualización del centroide. El código estándar utiliza adiciones atómicas de estilo disperso. Cada hilo agrega su punto en un búfer de suma compartido codificado por la identificación del clúster. Muchos subprocesos llegan al mismo grupo “caliente” a la vez. Eso provoca contención atómica y serialización de hardware. El equipo de investigación midió sólo 50 GB/s de ancho de banda efectivo aquí en un H200.

Flash-KMeans reemplaza esto con Sort-Inverse Update. Ordena el vector de asignación 1D por ID de clúster mediante argsort. Los identificadores de clúster idénticos forman segmentos contiguos. Cada bloque de hilo reduce un segmento en el chip y luego emite una adición atómica por segmento. La matriz de puntos pesados ​​nunca se permuta físicamente. Las operaciones atómicas caen de (O((K+NBN)d))(O((K + \frac{N}{B_N})d)) . El kernel alcanza hasta 6,3×.

Punto de referencia

El equipo de investigación lo prueba en un H200 con CUDA 12.8, datos FP16 y d=128. Barren N, K y el tamaño de lote B. Se comparan con cuatro líneas de base optimizadas: fast_pytorch_kmeans, fastkmeans, cuML y FAISS.

ComparaciónAceleración reportadaContexto de carga de trabajoDe extremo a extremo frente a la mejor línea basehasta 17,9×N=8M, K=1024 (N grande, K pequeña)frente a NVIDIA cuML33×biblioteca industrial frente a FAISSmás de 200×biblioteca industrialFlashAsignar kernelhasta 21,2×N=1M, K=8192 (asignación)Actualización de kernel de orden inverso hasta 6,3 × N = 33 M, K = 4096 (actualización) Fuera del núcleo, gran escala hasta 10,5 × N = 400 M, K = 16384 frente a medios rápidos

Un modo de falla importa para el contexto. Las implementaciones estándar de PyTorch se quedan sin memoria en regímenes de K grande. No pueden materializar la matriz N×K. FAISS es la biblioteca estándar de la industria en muchos sistemas de búsqueda de vectores de producción.

La biblioteca también se queda sin núcleo. Con mil millones de puntos (K=32768, d=128), finaliza una iteración en 41,4 segundos, frente a 261,8 segundos para la línea base. Utiliza superposición de flujo fragmentado para ocultar la transferencia PCIe detrás de la computación. Una heurística de compilación con reconocimiento de caché también reduce la sobrecarga de ajuste hasta 175 veces, dentro del 0,3% de la velocidad sintonizada.

Explicador interactivo de MTP

Marktechpost · Explicador interactivo

Flash-KMeans: k-means exactos, reconstruidos alrededor de la memoria de la GPU

Las mismas matemáticas de Lloyd que las k-medias estándar: más rápidas sólo gracias al flujo de datos. Ejecute la agrupación en clústeres en vivo, observe el cuello de botella de la actualización y mida el IO que elimina.

17,9×de un extremo a otro frente a la mejor línea de base

33×frente a NVIDIA cuML

200×+frente a FAISS

1Bpuntos, fuera del núcleo

1 · Agrupación en vivo

2 · Contención de actualización

3 · calculadora de E/S

Puntos de datos (N) 800

Grupos (K) 5

Ejecutar paso Nuevos datos

Iteración0

cambio de centroide

Estadoinactivo

Esto ejecuta k-means reales de Lloyd en su navegador en puntos 2-D. El algoritmo es idéntico a lo que acelera Flash-KMeans; solo difiere el flujo de datos de la GPU. Cada paso = una asignación + una actualización del centroide.

Pulsa reproducir. La actualización de dispersión estándar se serializa cuando los bloques escriben el mismo centroide “activo” (puestos rojos). La actualización de clasificación inversa ordena primero los ID de los clústeres, de modo que cada bloque fusiona segmentos contiguos con una adición atómica, sin conflictos.

Reproducir línea de tiempo Restablecer

Atómicos estándarO(N·d)

Ordenamiento atómico inversoO((K+N/B)·d)

Ancho de banda estándar medido50 GB/s

aceleración del kernel6,3×

Las actualizaciones estándar emiten una adición atómica por token. Muchos hilos golpean el mismo centroide a la vez, lo que genera contención. La clasificación por ID de clúster convierte las dispersiones en reducciones a nivel de segmento en la memoria del chip.

Estándar: materializar la matriz N×K, O(NK)

FlashAssign: entradas de flujo, O(Nd+Kd)

menos tráfico de HBM para el paso de asignación (teórico)

Casos de uso

K-means exactos y más rápidos cambian lo que puede ejecutar en línea, no solo sin conexión.

Indexación de búsqueda vectorial: FAISS construye sus índices de búsqueda con k-medias. K-means más rápido le permite volver a indexar a medida que cambian los datos, en lugar de reconstruirlos de la noche a la mañana. Enrutamiento de atención escasa: enrutamiento de tokens de clúster de transformadores y tácticas para enrutar la atención. Las k-medias de milisegundos hacen que esto sea viable dentro del ciclo de inferencia. Compresión de caché KV: ClusterKV agrupa tokens en un espacio semántico para comprimir el caché. Una agrupación más económica hace que la compresión por capa y por paso sea práctica. Cuantización de KV de bits bajos: los métodos recientes agrupan las entradas de KV en libros de códigos, repetidamente. Una agrupación en clústeres más rápida reduce el costo de preprocesamiento. Transformadores de difusión: Sparse VideoGen2 llama k-medias por lotes durante pases hacia adelante. Permuta tokens por similitud semántica para explotar la escasez.

Usándolo

La API refleja faiss y sklearn. La siguiente llamada agrupa un tensor por lotes (B, N, d).

importar antorcha desde flash_kmeans importar lote_kmeans_Euclid x = torch.randn(32, 75600, 128, dispositivo=”cuda”, dtype=torch.float16) cluster_ids, centros, _ = lote_kmeans_Euclid( x, n_clusters=1000, tol=1e-4, verbose=True )

También está disponible una interfaz estilo scikit-learn.

from flash_kmeans import FlashKMeans km = FlashKMeans(d=128, k=8192, niter=100) etiquetas = km.fit_predict(large_cpu_tensor) # dispositivo=Ninguno usa todas las GPU visibles

El kernel se envía automáticamente por forma y tipo. Una ruta de D pequeña maneja d≤512. Una ruta dividida en D maneja d más grande sin materializar la matriz de distancias. Las ejecuciones de múltiples GPU se activan automáticamente para datos de N grandes almacenados en la memoria de la CPU.

Conclusiones clave

Flash-KMeans es exacto, no aproximado: las mismas matemáticas de Lloyd, aceleradas únicamente por el flujo de datos de la GPU. FlashAssign fusiona distancia + argmin en línea, cortando la asignación IO de O(NK) a O(Nd+Kd) — hasta 21,2×. La actualización Sort-Inverse clasifica los ID de los clústeres en segmentos, reemplazando los átomos dispersos, hasta 6,3×. Informa hasta 17,9 veces de extremo a extremo, 33 veces más que cuML y más de 200 veces más que FAISS en un H200. Escala fuera del núcleo a mil millones de puntos y reduce la sobrecarga de sintonización hasta 175×.

Consulte el documento y el repositorio. Además, no dude en seguirnos en Twitter y no olvide unirse a nuestro SubReddit de más de 150.000 ml y suscribirse a nuestro boletín. ¡Esperar! estas en telegrama? Ahora también puedes unirte a nosotros en Telegram.

¿Necesita asociarse con nosotros para promocionar su repositorio de GitHub O su página principal de Hugging O su lanzamiento de producto O seminario web, etc.? Conéctate con nosotros