Clasificación se erige como una de las aplicaciones más básicas pero más importantes del procesamiento del lenguaje natural. Tiene un papel vital en muchas aplicaciones del mundo real que van a filtrar correos electrónicos no deseados como spam, detectar categorías de productos o clasificar la intención del usuario en una aplicación de botón de chat. La forma predeterminada de construir clasificadores de texto es recopilar grandes cantidades de datos etiquetados, lo que significa textos de entrada y sus etiquetas correspondientes, y luego capacitar a un modelo de aprendizaje automático personalizado. Las cosas cambiaron un poco a medida que los LLM se volvieron más potentes, donde a menudo puede obtener un rendimiento decente utilizando modelos de lenguaje de gran propósito general como clasificadores de disparo cero o de pocos disparos, reduciendo significativamente el tiempo de despliegue de los servicios de clasificación de texto. Sin embargo, la precisión puede retrasarse detrás de los modelos personalizados y depende en gran medida de elaborar indicaciones personalizadas para definir mejor la tarea de clasificación para el LLM. En este blog, nuestro objetivo es minimizar la brecha entre los modelos ML personalizados para la clasificación y las LLM de propósito general, al tiempo que minimizamos el esfuerzo necesario para adaptar la solicitud de LLM a su tarea.
LLMS vs modelos ML personalizados para la clasificación de texto
Pros:
Primero exploremos los profesionales y los contras de cada uno de los dos enfoques para hacer la clasificación de texto.
Modelos de idiomas grandes como clasificadores de propósito general:
- Alta capacidad de generalización dada el vasto corpus de pre-entrenamiento y las habilidades de razonamiento de la LLM.
- Un solo propósito general LLM puede manejar múltiples tareas de clasificaciones sin la necesidad de implementar un modelo para cada uno.
- Como LLMS Continúe mejorando, puede mejorar potencialmente la precisión con un esfuerzo mínimo simplemente mediante la adopción de modelos más nuevos y poderosos a medida que estén disponibles.
- La disponibilidad de la mayoría de los LLM como servicios administrados reduce significativamente el conocimiento y el esfuerzo de despliegue requeridos para comenzar.
- Las LLM a menudo superan a los modelos ML personalizados en escenarios de baja datos donde los datos etiquetados son limitados o costosos de obtener.
- Los LLM se generalizan a varios idiomas.
- Los LLM pueden ser más baratos cuando tienen volúmenes de predicciones bajos o impredecibles si paga por token.
- Las definiciones de clase se pueden cambiar dinámicamente sin reentrenarse simplemente modificando las indicaciones.
Contras:
- Los LLM son propensos a las alucinaciones.
- Los LLM pueden ser lentos, o al menos más lentos que los pequeños modelos ML personalizados.
- Requieren un esfuerzo rápido de ingeniería.
- Las aplicaciones de alto rendimiento utilizando LLMS-as-A-Service pueden encontrar rápidamente limitaciones de cuotas.
- Este enfoque se vuelve menos efectivo con una gran cantidad de clases potenciales debido a las limitaciones de tamaño de contexto. Definir todas las clases consumiría una parte significativa del contexto de entrada disponible y efectivo.
- Los LLM generalmente tienen peor precisión que los modelos personalizados en el régimen de datos altos.
Costumbre Aprendizaje automático Modelos:
Pros:
- Eficiente y rápido.
- Más flexible en elección de arquitectura, capacitación y servicio.
- Capacidad para agregar la interpretabilidad y los aspectos de estimación de incertidumbre al modelo.
- Mayor precisión en el régimen de datos altos.
- Usted mantiene el control de su modelo y la infraestructura de servicio.
Contras:
- Requiere re-entradas frecuentes para adaptarse a nuevos datos o cambios de distribución.
- Puede necesitar cantidades significativas de datos etiquetados.
- Generalización limitada.
- Sensible al vocabulario o formulaciones fuera de dominio.
- Requiere conocimiento de MLOPS para la implementación.
Pinchar la brecha entre el clasificador de texto personalizado y las LLM:
Trabajemos en forma de mantener a los profesionales del uso de LLM para la clasificación mientras aliviamos algunos de los contras. Nos inspiraremos en el trapo y utilizaremos una técnica de solicitación llamada pocos disparos.
Definamos ambos:
TRAPO
La generación de recuperación aumentada es un método popular que aumenta el contexto de LLM con conocimiento externo antes de hacer una pregunta. Esto reduce la probabilidad de alucinación y mejora la calidad de las respuestas.
Pocas de disparo
En cada tarea de clasificación, mostramos los ejemplos LLM de entradas y salidas esperadas como parte de la solicitud para ayudarlo a comprender la tarea.
Ahora, la idea principal de este proyecto es mezclar ambos. Obtuvimos ejemplos dinámicamente que son los más similares a la consulta de texto que se clasificará y los inyectan como pocos indicios de ejemplo. También limitamos el alcance de las posibles clases de forma dinámica utilizando las de los vecinos K-Nears. Esto libera una cantidad significativa de tokens en el contexto de entrada cuando se trabaja con un problema de clasificación con una gran cantidad de clases posibles.
Así es como funcionaría eso:
Pasemos por los pasos prácticos de hacer que este enfoque se ejecute:
- Construyendo una base de conocimiento de los pares de texto / categoría de entrada etiquetados. Esta será nuestra fuente de conocimiento externo para el LLM. Usaremos ChromAdB.
from typing import List
from uuid import uuid4
from langchain_core.documents import Document
from chromadb import PersistentClient
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import torch
from tqdm import tqdm
from chromadb.config import Settings
from retrieval_augmented_classification.logger import logger
class DatasetVectorStore:
"""ChromaDB vector store for PublicationModel objects with SentenceTransformers embeddings."""
def __init__(
self,
db_name: str = "retrieval_augmented_classification", # Using db_name as collection name in Chroma
collection_name: str = "classification_dataset",
persist_directory: str = "chroma_db", # Directory to persist ChromaDB
):
self.db_name = db_name
self.collection_name = collection_name
self.persist_directory = persist_directory
# Determine if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
self.embeddings = HuggingFaceBgeEmbeddings(
model_name="BAAI/bge-small-en-v1.5",
model_kwargs={"device": device},
encode_kwargs={
"device": device,
"batch_size": 100,
}, # Adjust batch_size as needed
)
# Initialize Chroma vector store
self.client = PersistentClient(
path=self.persist_directory, settings=Settings(anonymized_telemetry=False)
)
self.vector_store = Chroma(
client=self.client,
collection_name=self.collection_name,
embedding_function=self.embeddings,
persist_directory=self.persist_directory,
)
def add_documents(self, documents: List) -> None:
"""
Add multiple documents to the vector store.
Args:
documents: List of dictionaries containing document data. Each dict needs a "text" key.
"""
local_documents = []
ids = []
for doc_data in documents:
if not doc_data.get("id"):
doc_data["id"] = str(uuid4())
local_documents.append(
Document(
page_content=doc_data["text"],
metadata={k: v for k, v in doc_data.items() if k != "text"},
)
)
ids.append(doc_data["id"])
batch_size = 100 # Adjust batch size as needed
for i in tqdm(range(0, len(documents), batch_size)):
batch_docs = local_documents[i : i + batch_size]
batch_ids = ids[i : i + batch_size]
# Chroma's add_documents doesn't directly support pre-defined IDs. Upsert instead.
self._upsert_batch(batch_docs, batch_ids)
def _upsert_batch(self, batch_docs: List[Document], batch_ids: List[str]):
"""Upsert a batch of documents into Chroma. If the ID exists, it updates; otherwise, it creates."""
texts = [doc.page_content for doc in batch_docs]
metadatas = [doc.metadata for doc in batch_docs]
self.vector_store.add_texts(texts=texts, metadatas=metadatas, ids=batch_ids)
Esta clase maneja la creación de una colección e incrustando cada documento antes de insertarlo en el índice Vector. Usamos Baai/BGE-Small-en-V1.5, pero cualquier modelo de incrustación funcionaría, incluso aquellos disponibles como servicio de Gemini, OpenAi o Nebius.
- Encontrar los k vecinos más cercanos para un texto de entrada
def search(self, query: str, k: int = 5) -> List[Document]:
"""Search documents by semantic similarity."""
results = self.vector_store.similarity_search(query, k=k)
return results
Este método devuelve los documentos en la base de datos Vector que son más similares a nuestra entrada.
- Construyendo el clasificador aumentado de recuperación
from typing import Optional
from pydantic import BaseModel, Field
from collections import Counter
from retrieval_augmented_classification.vector_store import DatasetVectorStore
from tenacity import retry, stop_after_attempt, wait_exponential
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
class PredictedCategories(BaseModel):
"""
Pydantic model for the predicted categories from the LLM.
"""
reasoning: str = Field(description="Explain your reasoning")
predicted_category: str = Field(description="Category")
class RAC:
"""
A hybrid classifier combining K-Nearest Neighbors retrieval with an LLM for multi-class prediction.
Finds top K neighbors, uses top few-shot for context, and uses all neighbor categories
as potential prediction candidates for the LLM.
"""
def __init__(
self,
vector_store: DatasetVectorStore,
llm_client,
knn_k_search: int = 30,
knn_k_few_shot: int = 5,
):
"""
Initializes the classifier.
Args:
vector_store: An instance of DatasetVectorStore with a search method.
llm_client: An instance of the LLM client capable of structured output.
knn_k_search: The number of nearest neighbors to retrieve from the vector store.
knn_k_few_shot: The number of top neighbors to use as few-shot examples for the LLM.
Must be less than or equal to knn_k_search.
"""
self.vector_store = vector_store
self.llm_client = llm_client
self.knn_k_search = knn_k_search
self.knn_k_few_shot = knn_k_few_shot
@retry(
stop=stop_after_attempt(3), # Retry LLM call a few times
wait=wait_exponential(multiplier=1, min=2, max=5), # Shorter waits for demo
)
def predict(self, document_text: str) -> Optional[str]:
"""
Predicts the relevant categories for a given document text using KNN retrieval and an LLM.
Args:
document_text: The text content of the document to classify.
Returns:
The predicted category
"""
neighbors = self.vector_store.search(document_text, k=self.knn_k_search)
all_neighbor_categories = set()
valid_neighbors = [] # Store neighbors that have metadata and categories
for neighbor in neighbors:
if (
hasattr(neighbor, "metadata")
and isinstance(neighbor.metadata, dict)
and "category" in neighbor.metadata
):
all_neighbor_categories.add(neighbor.metadata["category"])
valid_neighbors.append(neighbor)
else:
pass # Suppress warnings for cleaner demo output
if not valid_neighbors:
return None
category_counts = Counter(all_neighbor_categories)
ranked_categories = [
category for category, count in category_counts.most_common()
]
if not ranked_categories:
return None
few_shot_neighbors = valid_neighbors[: self.knn_k_few_shot]
messages = []
system_prompt = f"""You are an expert multi-class classifier. Your task is to analyze the provided document text and assign the most relevant category from the list of allowed categories.
You MUST only return categories that are present in the following list: {ranked_categories}.
If none of the allowed categories are relevant, return an empty list.
Return the categories by likelihood (more confident to least confident).
Output your prediction as a JSON object matching the Pydantic schema: {PredictedCategories.model_json_schema()}.
"""
messages.append(SystemMessage(content=system_prompt))
for i, neighbor in enumerate(few_shot_neighbors):
messages.append(
HumanMessage(content=f"Document: {neighbor.page_content}")
)
expected_output_json = PredictedCategories(
reasoning="Your reasoning here",
predicted_category=neighbor.metadata["category"]
).model_dump_json()
# Simulate the structure often used with tool calling/structured output
ai_message_with_tool = AIMessage(
content=expected_output_json,
)
messages.append(ai_message_with_tool)
# Final user message: The document text to classify
messages.append(HumanMessage(content=f"Document: {document_text}"))
# Configure the client for structured output with the Pydantic schema
structured_client = self.llm_client.with_structured_output(PredictedCategories)
llm_response: PredictedCategories = structured_client.invoke(messages)
predicted_category = llm_response.predicted_category
return predicted_category if predicted_category in ranked_categories else None
La primera parte del código define la estructura de la salida que esperamos del LLM. La clase Pydantic tiene dos campos, el razonamiento, utilizado para la cadena de impulso (https://www.prompptingguide.ai/techniques/cot) y la categoría predicha.
El método Predicto primero encuentra los vecinos K más cercanos y los usa como pocas indicaciones al crear un historial de mensajes sintéticos como si el LLM diera las categorías correctas para cada KNN, luego inyectamos el texto de la consulta como el último mensaje humano.
Filtramos el valor para verificar si es válido y, si es así, devuélvalo.
_rac = RAC(
vector_store=store,
llm_client=llm_client,
knn_k_search=50,
knn_k_few_shot=10,
)
print(
f"Initialized rac with knn_k_search={_rac.knn_k_search}, knn_k_few_shot={_rac.knn_k_few_shot}."
)
text = """Ivanoe Bonomi [iˈvaːnoe boˈnɔːmi] (18 October 1873 – 20 April 1951) was an Italian politician and statesman before and after World War II. Bonomi was born in Mantua. He was elected to the Italian Chamber of Deputies in ...
"""
category = _rac.predict(text)
print(text)
print(category)
text = """Michel Rocard, né le 23 août 1930 à Courbevoie et mort le 2 juillet 2016 à Paris, est un haut fonctionnaire et ...
"""
category = _rac.predict(text)
print(text)
print(category)
Ambas entradas devuelven la predicción “Primeminister” a pesar de que el segundo ejemplo está en francés, mientras que el conjunto de datos de capacitación está completamente en inglés. Esto ilustra las habilidades de generalización de este enfoque incluso en idiomas similares.
Usamos el Clases de dbpedia Categorías L3 del conjunto de datos (https://www.kaggle.com/datasets/danofer/dbpedia-classes ,Licencia CC BY-SA 3.0.) para nuestra evaluación. Este conjunto de datos tiene más de 200 categorías y 240000 muestras de entrenamiento.
Comparamos el enfoque de clasificación aumentada de recuperación contra un clasificador KNN simple con voto mayoritario y obtenemos los siguientes resultados del conjunto de datos DBPEDIA categorías L3:
| Exactitud | Latencia promedio | Rendimiento (múltiples subprocesos) | |
| Clasificador KNN | 87% | 24 ms | 108 predicciones / s |
| Clasificador solo llm | 88% | ~ 600 ms | 47 predicciones / s |
| RAC | 96% | ~ 1s | 27 predicciones / s |
Por referencia, la mejor precisión que encontré en los cuadernos de Kaggle para el nivel L3 de este conjunto de datos fue 94% Uso de modelos ML personalizados.
Observamos que combinar una búsqueda de KNN con las habilidades de razonamiento de un LLM nos permite obtener puntos de precisión de +9%, pero tiene un costo de menor rendimiento y mayor latencia.
Conclusión
En este proyecto creamos un clasificador de texto que aprovecha la “recuperación” para aumentar la capacidad de un LLM para encontrar la categoría correcta del contenido de entrada. Este enfoque ofrece varias ventajas sobre los clasificadores de texto ML tradicionales. Estos incluyen la capacidad de cambiar dinámicamente el conjunto de datos de capacitación sin capacitación, una mayor capacidad de generalización debido al razonamiento y el conocimiento general de las LLM, la implementación fácil al usar servicios LLM administrados en comparación con los modelos ML personalizados y la capacidad de manejar tareas de clasificación múltiple con un modelo LLM único. Esto tiene un costo de mayor latencia y menor rendimiento y un riesgo de bloqueo del proveedor de LLM.
Este método no debería ser su primera opción cuando se trabaja en una tarea de clasificación, pero aún así sería útil como parte de su caja de herramientas cuando su aplicación puede beneficiarse de la flexibilidad de no tener que volver a entrenar un clasificador cada vez que los datos cambian o cuando trabajan con una pequeña cantidad de datos etiquetados. También puede permitirle obtener el objetivo de tener un servicio de clasificación en funcionamiento muy rápidamente cuando se avecina una fecha límite 😃.
Fuentes:
- [1] G. Yu, L. Liu, H. Jiang, S. Shi y X. Ao, Clasificación de texto de pocos disparos acuático de recuperación (2023), Hallazgos de la Asociación de Lingüística Computacional: EMNLP 2023
- [2] A. Long, W. Yin, T. Ajanthan, V. Nguyen, P. Purkait, R. Garg, C. Shen y A. van den Hengel, Recuperación de clasificación aumentada para el reconocimiento visual de cola larga (2022)
Código: https://github.com/cvxtz/retrieval_augmented_classification