Cree un efecto de modo retrato con Segment Anything Model 2 (SAM2)

¿Alguna vez has admirado cómo las cámaras de los teléfonos inteligentes aíslan al sujeto principal del fondo, agregando un sutil desenfoque al fondo según la profundidad? Este efecto de “modo retrato” brinda a las fotografías un aspecto profesional al simular una profundidad de campo reducida similar a las cámaras DSLR. En este tutorial, recrearemos este efecto mediante programación utilizando modelos de visión por computadora de código abierto, como SAM2 de Meta y MiDaS de Intel ISL.

Para construir nuestra tubería, usaremos:

  1. Modelo de segmentación de cualquier cosa (SAM2): Para segmentar objetos de interés y separar el primer plano del fondo.
  2. Modelo de estimación de profundidad: Para calcular un mapa de profundidad, habilitando el desenfoque basado en la profundidad.
  3. Desenfoque gaussiano: Para desenfocar el fondo con una intensidad que varía según la profundidad.

Paso 1: configurar el entorno

Para comenzar, instale las siguientes dependencias:

pip install matplotlib samv2 pytest opencv-python timm pillow

Paso 2: cargar una imagen de destino

Elija una imagen para aplicar este efecto y cárguela en Python usando el Almohada biblioteca.

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

image_path = "<path to your image>.jpg"
img = Image.open(image_path)
img_array = np.array(img)

# Display the image
plt.imshow(img)
plt.axis("off")
plt.show()

Paso 3: Inicialice el SAM2

Para inicializar el modelo, descargue el punto de control previamente entrenado. SAM2 ofrece cuatro variantes según el rendimiento y la velocidad de inferencia: minúsculo, pequeño, base_plus y grande. En este tutorial, usaremos tiny para una inferencia más rápida.

Descargue el modelo de punto de control desde: https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_.pt

Reemplace con el tipo de modelo que desee.

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.utils.misc import variant_to_config_mapping
from sam2.utils.visualization import show_masks

model = build_sam2(
    variant_to_config_mapping["tiny"],
    "sam2_hiera_tiny.pt",
)
image_predictor = SAM2ImagePredictor(model)

Paso 4: introduzca la imagen en SAM y seleccione el tema

Configure la imagen en SAM y proporcione puntos que se encuentren en el sujeto que desea aislar. SAM predice una máscara binaria del sujeto y el fondo.

image_predictor.set_image(img_array)
input_point = np.array([[2500, 1200], [2500, 1500], [2500, 2000]])
input_label = np.array([1, 1, 1])

masks, scores, logits = image_predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=None,
    multimask_output=True,
)
output_mask = show_masks(img_array, masks, scores)
sorted_ind = np.argsort(scores)[::-1]

Paso 5: Inicializar el modelo de estimación de profundidad

Para la estimación de profundidad, utilizamos Midas por Intel ISL. Al igual que SAM, puedes elegir diferentes variantes según la precisión y la velocidad.Nota: El mapa de profundidad previsto se invierte, lo que significa que los valores más grandes corresponden a objetos más cercanos. Lo invertiremos en el siguiente paso para una mejor intuición.

import torch
import torchvision.transforms as transforms

model_type = "DPT_Large"  # MiDaS v3 - Large (highest accuracy)

# Load MiDaS model
model = torch.hub.load("intel-isl/MiDaS", model_type)
model.eval()

# Load and preprocess image
transform = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform
input_batch = transform(img_array)

# Perform depth estimation
with torch.no_grad():
    prediction = model(input_batch)
    prediction = torch.nn.functional.interpolate(
        prediction.unsqueeze(1),
        size=img_array.shape[:2],
        mode="bicubic",
        align_corners=False,
    ).squeeze()

prediction = prediction.cpu().numpy()

# Visualize the depth map
plt.imshow(prediction, cmap="plasma")
plt.colorbar(label="Relative Depth")
plt.title("Depth Map Visualization")
plt.show()

Paso 6: aplicar desenfoque gaussiano basado en profundidad

Aquí optimizamos el desenfoque basado en la profundidad utilizando un enfoque de desenfoque gaussiano iterativo. En lugar de aplicar un único núcleo grande, aplicamos un núcleo más pequeño varias veces para los píxeles con valores de profundidad más altos.

import cv2

def apply_depth_based_blur_iterative(image, depth_map, base_kernel_size=7, max_repeats=10):
    if base_kernel_size % 2 == 0:
        base_kernel_size += 1

    # Invert depth map
    depth_map = np.max(depth_map) - depth_map

    # Normalize depth to range [0, max_repeats]
    depth_normalized = cv2.normalize(depth_map, None, 0, max_repeats, cv2.NORM_MINMAX).astype(np.uint8)

    blurred_image = image.copy()

    for repeat in range(1, max_repeats + 1):
        mask = (depth_normalized == repeat)
        if np.any(mask):
            blurred_temp = cv2.GaussianBlur(blurred_image, (base_kernel_size, base_kernel_size), 0)
            for c in range(image.shape[2]):
                blurred_image[..., c][mask] = blurred_temp[..., c][mask]

    return blurred_image

blurred_image = apply_depth_based_blur_iterative(img_array, prediction, base_kernel_size=35, max_repeats=20)

# Visualize the result
plt.figure(figsize=(20, 10))
plt.subplot(1, 2, 1)
plt.imshow(img)
plt.title("Original Image")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(blurred_image)
plt.title("Depth-based Blurred Image")
plt.axis("off")
plt.show()

Paso 7: combine el primer plano y el fondo

Finalmente, usa la máscara SAM para extraer el primer plano nítido y combinarlo con el fondo borroso.

def combine_foreground_background(foreground, background, mask):
    if mask.ndim == 2:
        mask = np.expand_dims(mask, axis=-1)
    return np.where(mask, foreground, background)

mask = masks[sorted_ind[0]].astype(np.uint8)
mask = cv2.resize(mask, (img_array.shape[1], img_array.shape[0]))
foreground = img_array
background = blurred_image

combined_image = combine_foreground_background(foreground, background, mask)

plt.figure(figsize=(20, 10))
plt.subplot(1, 2, 1)
plt.imshow(img)
plt.title("Original Image")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(combined_image)
plt.title("Final Portrait Mode Effect")
plt.axis("off")
plt.show()

Conclusión

Con sólo unas pocas herramientas, hemos recreado el efecto del modo retrato mediante programación. Esta técnica se puede ampliar para aplicaciones de edición de fotografías, simulación de efectos de cámara o proyectos creativos.

Mejoras futuras:

  1. Utilice algoritmos de detección de bordes para refinar mejor los bordes del sujeto.
  2. Experimente con el tamaño de los granos para mejorar el efecto de desenfoque.
  3. Cree una interfaz de usuario para cargar imágenes y seleccionar temas dinámicamente.

Recursos:

  1. Segmentar cualquier modelo por META (https://github.com/facebookresearch/sam2)
  2. Implementación compatible con CPU de SAM 2 (https://github.com/SauravMaheshkar/samv2/tree/main)
  3. Modelo de estimación de profundidad MIDas (https://pytorch.org/hub/intelisl_midas_v2/)


Vineet Kumar es pasante de consultoría en MarktechPost. Actualmente está cursando su licenciatura en el Instituto Indio de Tecnología (IIT), Kanpur. Es un entusiasta del aprendizaje automático. Le apasiona la investigación y los últimos avances en Deep Learning, Computer Vision y campos relacionados.