La guía completa para crear conjuntos de datos y cargadores de datos personalizados para diferentes modelos en PyTorch
Antes de poder crear un modelo de aprendizaje automático, debe cargar sus datos en un conjunto de datos. Afortunadamente, PyTorch tiene muchos comandos para ayudar con todo este proceso (si no está familiarizado con PyTorch, le recomiendo actualizar los conceptos básicos). aquí).
PyTorch tiene buena documentación para ayudar con este proceso, pero no he encontrado ninguna documentación completa ni tutoriales sobre conjuntos de datos personalizados. ¡Primero comenzaré con la creación de conjuntos de datos básicos prediseñados y luego avanzaré hasta crear conjuntos de datos desde cero para diferentes modelos!
Antes de profundizar en el código para diferentes casos de uso, comprendamos la diferencia entre los dos términos. Generalmente, primero crea su conjunto de datos y luego crea un cargador de datos. A conjunto de datos contiene las características y etiquetas de cada punto de datos que se introducirán en el modelo. A cargador de datos es un iterable personalizado de PyTorch que facilita la carga de datos con funciones adicionales.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
Los argumentos más comunes en el cargador de datos son tamaño del lote, barajar (normalmente sólo para los datos de entrenamiento), numero_trabajadores (para cargar los datos en múltiples procesos), y pin_memoria (para colocar los tensores de datos recuperados en la memoria fija y permitir una transferencia de datos más rápida a GPU habilitadas para CUDA).
Se recomienda establecer pin_memory = True en lugar de especificar num_workers debido a complicaciones de multiprocesamiento con CUDA.
En el caso de que su conjunto de datos se descargue en línea o localmente, será extremadamente sencillo crear el conjunto de datos. Creo que PyTorch tiene buenos documentación sobre esto, así que seré breve.
Si sabe que el conjunto de datos es de PyTorch o compatible con PyTorch, simplemente llame a las importaciones necesarias y al conjunto de datos de su elección:
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms imports ToTensordata = torchvision.datasets.CIFAR10('path', train=True, transform=ToTensor())
Cada conjunto de datos tendrá argumentos únicos para pasarle (encontrados aquí). En general, será la ruta en la que se almacena el conjunto de datos, un valor booleano que indica si es necesario descargarlo o no (convenientemente llamado descarga), si se trata de entrenamiento o prueba, y si es necesario aplicar transformaciones.
Mencioné que las transformaciones se pueden aplicar a un conjunto de datos al final de la última sección, pero ¿qué es realmente una transformación?
A transformar es un método de manipulación de datos para preprocesar una imagen. Hay muchas facetas diferentes de las transformaciones. La transformación más común, ATensor(), convertirá el conjunto de datos en tensores (necesarios para ingresar en cualquier modelo). Otras transformaciones integradas en PyTorch (torchvision.transforms) incluyen voltear, rotar, recortar, normalizar y cambiar imágenes. Por lo general, se utilizan para que el modelo pueda generalizarse mejor y no se ajuste demasiado a los datos de entrenamiento. Los aumentos de datos también se pueden utilizar para aumentar artificialmente el tamaño del conjunto de datos si es necesario.
Tenga en cuenta que la mayoría de las transformaciones de torchvision solo aceptan formatos de imagen Pillow o tensor (no numpy). Para convertir, simplemente use
Para convertir desde numpy, cree un tensor de antorcha o use lo siguiente:
From PIL import Image
# assume arr is a numpy array
# you may need to normalize and cast arr to np.uint8 depending on format
img = Image.fromarray(arr)
Las transformaciones se pueden aplicar simultáneamente usando torchvision.transforms.compose. Puede combinar tantas transformaciones como necesites para el conjunto de datos. A continuación se muestra un ejemplo:
import torchvision.transforms.Composedataset_transform = transforms.Compose([
transforms.RandomResizedCrop(256),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
Asegúrese de pasar la transformación guardada como argumento al conjunto de datos para que se aplique en el cargador de datos.
En la mayoría de los casos de desarrollo de su propio modelo, necesitará un conjunto de datos personalizado. Un caso de uso común sería el aprendizaje por transferencia para aplicar su propio conjunto de datos en un modelo previamente entrenado.
Hay 3 partes requeridas para una clase de conjunto de datos de PyTorch: inicialización, longitudy recuperando un elemento.
__en eso__: Para inicializar el conjunto de datos, pase los datos sin procesar y etiquetados. La mejor práctica es pasar los datos de la imagen sin procesar y los datos etiquetados por separado.
__len__: Devuelve la longitud del conjunto de datos. Antes de crear el conjunto de datos, se debe verificar que los datos sin procesar y etiquetados tengan el mismo tamaño.
__obtiene el objeto__: Aquí es donde ocurre todo el manejo de datos para devolver un índice determinado (idx) de los datos sin procesar y etiquetados. Si es necesario aplicar alguna transformación, los datos deben convertirse a un tensor y transformarse. Si la inicialización contenía una ruta al conjunto de datos, se debe abrir la ruta y acceder a los datos/preprocesarlos antes de poder devolverlos.
Conjunto de datos de ejemplo para un modelo de segmentación semántica:
from torch.utils.data import Dataset
from torchvision import transformsclass ExampleDataset(Dataset):
"""Example dataset"""
def __init__(self, raw_img, data_mask, transform=None):
self.raw_img = raw_img
self.data_mask = data_mask
self.transform = transform
def __len__(self):
return len(self.raw_img)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
image = self.raw_img[idx]
mask = self.data_mask[idx]
sample = {'image': image, 'mask': mask}
if self.transform:
sample = self.transform(sample)
return sample