Cómo construir un modelo de incrustación de oraciones optimizado para Matryoshka para una recuperación ultrarrápida con truncamiento de 64 dimensiones

En este tutorial, ajustamos un modelo de incrustación de Transformadores de Oraciones utilizando el Aprendizaje de Representación Matryoshka para que las primeras dimensiones del vector transporten la señal semántica más útil. Entrenamos con MatryoshkaLoss en datos tripletes y luego validamos la promesa clave de MRL comparando la calidad de la recuperación después de truncar las incrustaciones a 64, 128 y 256 dimensiones. Al final, guardamos el modelo ajustado y demostramos cómo cargarlo con una pequeña configuración truncate_dim para una búsqueda vectorial rápida y eficiente en memoria. Consulta los CÓDIGOS COMPLETOS aquí.

!pip -q install -U conjuntos de datos de transformadores de oraciones acelerar importar matemáticas importar aleatorio importar numpy como np importar antorcha de conjuntos de datos importar load_dataset de torch.utils.data importar DataLoader de sentencia_transformers importar SentenceTransformer, InputEjemplo de sentencia_transformers importar pérdidas de sentencia_transformers.util importar cos_sim def set_seed(seed=42): random.seed(seed) np.random.seed(semilla) torch.manual_seed(semilla) torch.cuda.manual_seed_all(semilla) set_seed(42)

Instalamos las librerías requeridas e importamos todos los módulos necesarios para la formación y evaluación. Establecemos una semilla determinista, por lo que nuestro comportamiento de muestreo y entrenamiento se mantiene consistente en todas las ejecuciones. También nos aseguramos de que los RNG de PyTorch y CUDA estén alineados cuando hay una GPU disponible. Consulta los CÓDIGOS COMPLETOS aquí.

@torch.no_grad() def retrieval_metrics_mrr_recall_at_k( modelo, consultas, corpus, qrels, dims_list=(64, 128, 256, Ninguno), k=10, tamaño_lote=64,): dispositivo = “cuda” if torch.cuda.is_available() else “cpu” model.to(dispositivo) qids = lista(consultas.keys()) docids = lista(corpus.keys()) q_texts = [queries[qid] para qid en qids]d_texts = [corpus[did] for did in docids]q_emb = model.encode(q_texts, lote_size=batch_size, convert_to_tensor=True, normalize_embeddings=True) d_emb = model.encode(d_texts, lote_size=batch_size, convert_to_tensor=True, normalize_embeddings=True) resultados = {} para dim en dims_list: si dim es Ninguno: qe = q_emb de = d_emb dim_name = “full” else: qe = q_emb[:, :dim]
de = d_emb[:, :dim]
dim_name = str(dim) qe = torch.nn.functional.normalize(qe, p=2, dim=1) de = torch.nn.functional.normalize(de, p=2, dim=1) sims = cos_sim(qe, de) mrr_total = 0.0 recordar_total = 0.0 para i, qid en enumerate(qids): rel = qrels.get(qid, set()) si no es rel: continuar topk = torch.topk(sims[i]k=min(k, sims.forma[1]), mayor=True).indices.tolist() topk_docids = [docids[j] para j en topk]recordar_total += 1.0 si hay alguno (d en rel para d en topk_docids) más 0.0 rr = 0.0 para rango, d en enumerar(topk_docids, inicio=1): si d en rel: rr = 1.0 / salto de rango mrr_total += rr denom = max(1, len(qids)) resultados[dim_name] = {f”MRR@{k}”: mrr_total / denom, f”Recall@{k}”: recordar_total / denom} devuelve resultados def Pretty_print(resultados, título): print(“\n” + “=” * 80) print(title) print(“=” * 80) para atenuación, métricas en results.items(): print(f”dim={dim:>4} | ” + ” | “.unirse([f”{k}={v:.4f}” for k, v in metrics.items()]))

Implementamos un evaluador de recuperación liviano que codifica consultas y documentos, calcula la similitud de cosenos e informa MRR@10 y Recall@10. Volvemos a normalizar las incrustaciones después del truncamiento para que los prefijos más pequeños sigan siendo comparables en el espacio coseno. También agregamos una impresora compacta para que las comparaciones antes y después sean fáciles de leer. Consulta los CÓDIGOS COMPLETOS aquí.

DATASET_ID = “sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1” SUBSET = “triplet-hard” SPLIT = “entrenar” TRAIN_SAMPLES = 4000 EVAL_QUERIES = 300 flujo = load_dataset(DATASET_ID, SUBSET, split=SPLIT, streaming=True) train_examples = []
eval_queries = {} eval_corpus = {} eval_qrels = {} doc_id_counter = 0 qid_counter = 0 para la fila en la secuencia: q = (row.get(“query”) o “”).strip() pos = (row.get(“positive”) o “”).strip() neg = (row.get(“negative”) o “”).strip() si no q o no pos o no neg: continuar train_examples.append(InputExample(textos=[q, pos, neg])) if len(eval_queries) < EVAL_QUERIES: qid = f"q{qid_counter}" qid_counter += 1 pos_id = f"d{doc_id_counter}"; doc_id_counter += 1 neg_id = f"d{doc_id_counter}"; doc_id_counter += 1 consultas_eval[qid] = q eval_corpus[pos_id] = pos eval_corpus[neg_id] = neg eval_qrels[qid] = {pos_id} if len(train_examples) >= TRAIN_SAMPLES and len(eval_queries) >= EVAL_QUERIES: break print(len(train_examples), len(eval_queries), len(eval_corpus))

Transmitimos un conjunto de datos triplete de MS MARCO extraído y creamos un conjunto de entrenamiento (consultas, positivos, negativos) y un pequeño conjunto de referencia de IR. Asignamos cada consulta a un documento positivo relevante e incluimos un documento negativo para que la recuperación sea significativa. Nos detenemos temprano para que la ejecución sea compatible con Colab y al mismo tiempo sea lo suficientemente grande como para mostrar efectos de truncamiento.

MODEL_ID = “BAAI/bge-base-en-v1.5” dispositivo = “cuda” if torch.cuda.is_available() else “cpu” modelo = SentenceTransformer(MODEL_ID, dispositivo=dispositivo) full_dim = model.get_sentence_embedding_dimension() baseline = retrieval_metrics_mrr_recall_at_k( model, queries=eval_queries, corpus=eval_corpus, qrels=eval_qrels, dims_list=(64, 128, 256, Ninguno), k=10, ) Pretty_print(línea base, “ANTES”)

Cargamos un modelo de incrustación de base sólida y registramos su dimensión de incrustación completa. Realizamos la evaluación de referencia en 64/128/256/dimensiones completas para ver cómo se comporta el truncamiento antes de cualquier entrenamiento. Imprimimos los resultados para luego poder comparar si MRL mejora la calidad de las primeras dimensiones.

tamaño_por lotes = 16 épocas = 1 pasos_de_calentamiento = 100 cargador_de_trenes = DataLoader(ejemplos_de_trenes, tamaño_por_lotes=tamaño_por lotes, shuffle=True, drop_last=True) base_loss = pérdidas.MultipleNegativesRankingLoss(modelo=modelo) mrl_dims = [full_dim, 512, 256, 128, 64] si full_dim >= 768 más [full_dim, 256, 128, 64]
mrl_loss = pérdidas.MatryoshkaLoss( modelo=modelo, pérdida=base_loss, matryoshka_dims=mrl_dims ) model.fit( train_objectives=[(train_loader, mrl_loss)]épocas=épocas, pasos_calentamiento=pasos_calentamiento, show_progress_bar=True, ) después = retrieval_metrics_mrr_recall_at_k( modelo, consultas=eval_queries, corpus=eval_corpus, qrels=eval_qrels, dims_list=(64, 128, 256, Ninguno), k=10, ) Pretty_print(después, “DESPUÉS”) out_dir = “mrl-msmarco-demo” model.save(out_dir) m64 = SentenceTransformer(out_dir, truncate_dim=64) emb = m64.encode(
[“what is the liberal arts?”, “liberal arts covers humanities and sciences”]normalize_embeddings=True ) print(emb.shape)

Creamos MultipleNegativesRankingLoss y lo envolvemos con MatryoshkaLoss usando una lista descendente de dimensiones de prefijo de destino. Ajustamos el modelo en los tripletes y luego volvemos a ejecutar el mismo punto de referencia de truncamiento para medir la mejora en la retención. Además, guardamos el modelo y lo recargamos con truncate_dim=64 para confirmar el uso práctico para una recuperación compacta.

En conclusión, entrenamos con éxito un modelo de incrustación optimizado para Matryoshka que mantiene un sólido rendimiento de recuperación incluso cuando truncamos vectores a dimensiones de prefijo pequeñas, como 64. Verificamos el efecto comparando las métricas de recuperación de referencia con las posteriores al entrenamiento en múltiples tamaños de truncamiento y la incrustación completa. Con el modelo guardado y el patrón de carga truncate_dim, ahora tenemos un flujo de trabajo limpio para crear índices vectoriales más pequeños y más rápidos, manteniendo al mismo tiempo la opción de reclasificar con incrustaciones de dimensiones completas.

Consulta los CÓDIGOS COMPLETOS aquí. Además, no dude en seguirnos en Twitter y no olvide unirse a nuestro SubReddit de más de 100.000 ML y suscribirse a nuestro boletín. ¡Esperar! estas en telegrama? Ahora también puedes unirte a nosotros en Telegram.