의미론적 및 인스턴스 분할: U-Net, Mask R-CNN 및 SAM
이미지 분할은 시각적 이해의 가장 세부적인 수준을 나타냅니다. "이 이미지에 종양이 있습니다"(분류) 또는 "종양이 발견되었습니다"라는 것을 아는 대신 이 영역에서"(탐지)에 대해 알고 싶습니다. 정확히 어떤 픽셀이 종양에 속해 있는지. 이 픽셀 단위의 완벽한 정밀도는 의학, 로봇 수술, 자율 주행의 기본입니다. 및 산업 품질 관리.
이 기사에서는 분할을 위한 가장 중요한 아키텍처를 살펴보겠습니다. 유넷 (의료 세분화에 혁명을 일으킨 모델), 마스크 R-CNN (인스턴스 분할의 최적 표준) e SAM (가능한 한계를 재정의한 Meta AI의 Segment Anything Model)
무엇을 배울 것인가
- U-Net 아키텍처: 의료 분할을 위한 스킵 연결이 있는 인코더-디코더
- 의료 데이터 세트에 대한 교육을 통해 PyTorch에서 처음부터 U-Net 구현
- 마스크 R-CNN: 경계 상자 + 바이너리 마스크를 사용한 인스턴스 분할
- SAM(Segment Anything Model): 시각적 프롬프트를 통한 제로샷 분할
- 평가 지표: 주사위 점수, IoU, 분할을 위한 정밀도/재현율
- 후처리 기술: CRF, 수학적 형태학
- 사례 연구: 방사선 사진을 통한 폐 분할(오픈 소스 데이터세트)
- 프로덕션에 세분화 모델 배포
1. 분할의 기본
1.1 분할 유형
세분화 분류
| 유형 | 인스턴스 구별 | 순위 배경 | 출력 | 아키텍처 |
|---|---|---|---|---|
| 의미론 | No | Si | 픽셀당 레이블이 있는 HxW 맵 | U-Net, DeepLabv3, SegFormer |
| 예를 들어 | Si | 아니요(그냥 "사물") | 객체의 바이너리 마스크 | 마스크 R-CNN, SOLOv2, YOLACT |
| 파놉틱 | 예('사물'의 경우) | 예('물건'의 경우) | 통합 인스턴스+의미 체계 맵 | Panoptic FPN, Mask2Former |
| 인터랙티브 | 예(프롬프트 포함) | 프롬프트에 따라 다릅니다. | 클릭/bbox 기반 마스크 | SAM, SAM2, ClickSEG |
1.2 평가 지표
세분화를 위해 픽셀별로 겹치는 부분을 측정하는 특정 측정항목이 사용됩니다. 예측된 마스크와 실제 진실 사이:
import torch
import numpy as np
from typing import Union
def compute_iou(pred: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> float:
"""
Intersection over Union per segmentazione binaria.
pred, target: tensori [H, W] o [B, H, W] con valori in [0,1]
"""
pred_binary = (pred >= threshold).bool()
target_binary = target.bool()
intersection = (pred_binary & target_binary).float().sum()
union = (pred_binary | target_binary).float().sum()
if union == 0:
return 1.0 # caso degenere: entrambe vuote
return float(intersection / union)
def dice_score(pred: torch.Tensor, target: torch.Tensor, threshold: float = 0.5,
smooth: float = 1.0) -> float:
"""
Dice Score (F1 per segmentazione): 2*|X intersect Y| / (|X| + |Y|)
Preferito in ambito medico perchè meno sensibile agli sbilanciamenti.
Valore: 0 (peggio) -> 1 (perfetto)
"""
pred_binary = (pred >= threshold).float()
target_binary = target.float()
intersection = (pred_binary * target_binary).sum()
dice = (2.0 * intersection + smooth) / (pred_binary.sum() + target_binary.sum() + smooth)
return float(dice)
def compute_multiclass_miou(pred_logits: torch.Tensor, targets: torch.Tensor,
num_classes: int, ignore_index: int = 255) -> float:
"""
mIoU per segmentazione semantica multi-classe.
pred_logits: [B, C, H, W] - logit grezzi
targets: [B, H, W] - indici di classe 0..num_classes-1
"""
preds = pred_logits.argmax(dim=1) # [B, H, W]
ious = []
for cls in range(num_classes):
pred_cls = preds == cls
true_cls = targets == cls
valid = targets != ignore_index
pred_cls = pred_cls & valid
true_cls = true_cls & valid
intersection = (pred_cls & true_cls).sum().float()
union = (pred_cls | true_cls).sum().float()
if union > 0:
ious.append(float(intersection / union))
return float(np.mean(ious)) if ious else 0.0
def hausdorff_distance(pred: np.ndarray, target: np.ndarray) -> float:
"""
Hausdorff Distance: misura la distanza massima tra i bordi delle maschere.
Utile in medicina per valutare la precisione dei contorni.
"""
from scipy.spatial.distance import directed_hausdorff
pred_points = np.argwhere(pred)
target_points = np.argwhere(target)
if len(pred_points) == 0 or len(target_points) == 0:
return float('inf')
d1 = directed_hausdorff(pred_points, target_points)[0]
d2 = directed_hausdorff(target_points, pred_points)[0]
return max(d1, d2)
print("Esempio metriche:")
pred = torch.sigmoid(torch.randn(256, 256))
target = (torch.randn(256, 256) > 0).float()
iou = compute_iou(pred, target)
dice = dice_score(pred, target)
print(f"IoU: {iou:.3f} | Dice: {dice:.3f}")
2. U-Net: 의료 세분화를 위한 네트워크
유넷 (Ronnberger et al., 2015)은 원래 분할을 위해 제안되었습니다. 생체 의학 이미지. "U"자 모양의 아키텍처 연결 건너뛰기 인코더와 디코더 사이에서 모든 분할 작업에 대한 지배적인 템플릿이 되었습니다. 의료용 픽셀부터 위성 지도, 산업 이미지부터 야외 장면까지 밀도가 높습니다.
2.1 U-Net 아키텍처
아키텍처는 세 부분으로 나누어집니다.
- 인코더(수축 경로): 일련의 컨벌루션 블록 + 해상도를 줄이고 채널을 늘리는 최대 풀링, 의미상 풍부하지만 공간적으로 부정확한 특징 추출
- 병목: 가장 깊은 블록은 가장 낮은 해상도에서 작동합니다.
- 디코더(확장 경로): 원래 해상도를 복원하는 일련의 업샘플링 + 변환, 건너뛰기 연결을 통해 인코더 기능 맵을 연결하여 공간 세부 정보 복구
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""Blocco base U-Net: Conv-BN-ReLU-Conv-BN-ReLU."""
def __init__(self, in_channels: int, out_channels: int, mid_channels: int | None = None):
super().__init__()
if mid_channels is None:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.double_conv(x)
class DownBlock(nn.Module):
"""Encoder block: MaxPool2d + DoubleConv."""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.maxpool_conv(x)
class UpBlock(nn.Module):
"""Decoder block: Upsample + concatenazione skip + DoubleConv."""
def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True):
super().__init__()
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2,
kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
x1 = self.up(x1)
# Padding se le dimensioni non coincidono esattamente
diff_h = x2.size(2) - x1.size(2)
diff_w = x2.size(3) - x1.size(3)
x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
diff_h // 2, diff_h - diff_h // 2])
# Skip connection: concatena feature encoder + decoder
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class UNet(nn.Module):
"""
U-Net originale per segmentazione binaria o multi-classe.
Architettura:
Input -> [64] -> [128] -> [256] -> [512] -> [1024] (bottleneck)
-> [512] -> [256] -> [128] -> [64] -> Output
Le frecce verso il basso sono encoder (+ maxpool)
Le frecce verso l'alto sono decoder (+ skip connections)
"""
def __init__(self, in_channels: int = 1, num_classes: int = 1,
features: list[int] = [64, 128, 256, 512], bilinear: bool = True):
super().__init__()
self.in_conv = DoubleConv(in_channels, features[0])
# Encoder
self.downs = nn.ModuleList([
DownBlock(features[i], features[i+1])
for i in range(len(features) - 1)
])
# Bottleneck
factor = 2 if bilinear else 1
self.bottleneck = DownBlock(features[-1], features[-1] * 2 // factor)
# Decoder
self.ups = nn.ModuleList([
UpBlock(features[-1] * 2 // factor + features[-(i+1)],
features[-(i+2)] if i < len(features)-1 else features[0],
bilinear)
for i in range(len(features))
])
# Semplifichiamo con lista esplicita
self.ups = nn.ModuleList([
UpBlock(1024, 512 // factor, bilinear),
UpBlock(512, 256 // factor, bilinear),
UpBlock(256, 128 // factor, bilinear),
UpBlock(128, 64, bilinear),
])
self.out_conv = nn.Conv2d(64, num_classes, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Encoder path (salva skip connections)
x1 = self.in_conv(x)
x2 = self.downs[0](x1)
x3 = self.downs[1](x2)
x4 = self.downs[2](x3)
# Bottleneck
x5 = self.bottleneck(x4)
# Decoder path (usa skip connections)
x = self.ups[0](x5, x4)
x = self.ups[1](x, x3)
x = self.ups[2](x, x2)
x = self.ups[3](x, x1)
return self.out_conv(x)
# Test architettura
model = UNet(in_channels=3, num_classes=1)
x = torch.randn(2, 3, 256, 256)
y = model(x)
print(f"Input: {x.shape} -> Output: {y.shape}")
# Input: torch.Size([2, 3, 256, 256]) -> Output: torch.Size([2, 1, 256, 256])
total_params = sum(p.numel() for p in model.parameters())
print(f"Parametri: {total_params:,}")
2.2 주사위 손실을 이용한 U-Net 훈련
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
"""
Dice Loss per segmentazione binaria.
Gestisce naturalmente lo sbilanciamento di classe tipico delle immagini mediche
(es. 95% sfondo, 5% lesione).
"""
def __init__(self, smooth: float = 1.0):
super().__init__()
self.smooth = smooth
def forward(self, pred_logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
# Applica sigmoid per ottenere probabilità
preds = torch.sigmoid(pred_logits)
# Flatten per calcolo efficiente
preds_flat = preds.view(-1)
targets_flat = targets.view(-1)
intersection = (preds_flat * targets_flat).sum()
dice = (2.0 * intersection + self.smooth) / (
preds_flat.sum() + targets_flat.sum() + self.smooth
)
return 1.0 - dice # loss = 1 - Dice (minimizzare)
class CombinedLoss(nn.Module):
"""
Combinazione BCE + Dice: il compromesso migliore per segmentazione medica.
BCE: ottimizza ogni pixel individualmente
Dice: ottimizza l'overlap globale tra predizione e ground truth
"""
def __init__(self, bce_weight: float = 0.5, dice_weight: float = 0.5):
super().__init__()
self.bce = nn.BCEWithLogitsLoss()
self.dice = DiceLoss()
self.bce_weight = bce_weight
self.dice_weight = dice_weight
def forward(self, pred_logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
bce_loss = self.bce(pred_logits, targets.float())
dice_loss = self.dice(pred_logits, targets.float())
return self.bce_weight * bce_loss + self.dice_weight * dice_loss
def train_unet(
model: UNet,
train_loader,
val_loader,
num_epochs: int = 50,
learning_rate: float = 1e-4
) -> dict:
"""
Training completo di U-Net con:
- Combined BCE+Dice loss
- AdamW + CosineAnnealingLR
- Early stopping su Dice score di validazione
- Checkpoint del modello migliore
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
criterion = CombinedLoss(bce_weight=0.5, dice_weight=0.5)
optimizer = torch.optim.AdamW(
model.parameters(), lr=learning_rate, weight_decay=1e-5
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=num_epochs, eta_min=1e-6
)
history = {'train_loss': [], 'val_loss': [], 'val_dice': []}
best_dice = 0.0
patience = 15
no_improve = 0
for epoch in range(num_epochs):
# Training
model.train()
train_loss = 0.0
for images, masks in train_loader:
images, masks = images.to(device), masks.to(device)
pred_logits = model(images)
loss = criterion(pred_logits, masks)
optimizer.zero_grad(set_to_none=True)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
train_loss += loss.item()
scheduler.step()
# Validation
model.eval()
val_loss = 0.0
val_dice_scores = []
with torch.no_grad():
for images, masks in val_loader:
images, masks = images.to(device), masks.to(device)
pred_logits = model(images)
val_loss += criterion(pred_logits, masks).item()
preds = torch.sigmoid(pred_logits)
for p, m in zip(preds, masks):
val_dice_scores.append(dice_score(p, m))
avg_train_loss = train_loss / len(train_loader)
avg_val_loss = val_loss / len(val_loader)
avg_val_dice = sum(val_dice_scores) / len(val_dice_scores)
history['train_loss'].append(avg_train_loss)
history['val_loss'].append(avg_val_loss)
history['val_dice'].append(avg_val_dice)
if avg_val_dice > best_dice:
best_dice = avg_val_dice
torch.save(model.state_dict(), 'best_unet.pth')
no_improve = 0
else:
no_improve += 1
print(f"Epoch {epoch+1:2d}/{num_epochs} | "
f"Loss: {avg_train_loss:.4f}/{avg_val_loss:.4f} | "
f"Dice: {avg_val_dice:.4f} | Best: {best_dice:.4f}")
if no_improve >= patience:
print(f"Early stopping at epoch {epoch+1}")
break
print(f"Training completato. Best Dice Score: {best_dice:.4f}")
return history
3. SAM(Segment Anything Model)
메타 AI가 출시되었습니다 SAM (Kirillov et al., 2023) 야심찬 목표를 가지고 일반 분할 모델 구축: 10억 개의 마스크에 대해 훈련된 모델 분할할 수 있는 것 아무것 in 모든 이미지 유연한 프롬프트 사용 (점, 경계 상자, 텍스트를 클릭합니다). SAM2(2024)는 모델을 비디오에도 확장했습니다.
3.1 SAM 아키텍처
SAM은 세 가지 주요 구성 요소로 구성됩니다.
- 이미지 인코더: 조밀한 이미지 임베딩을 생성하는 Vision Transformer(632M 매개변수의 ViT-H). 이미지당 한 번만 실행됩니다.
- 프롬프트 인코더: 다양한 유형(포인트, 상자, 마스크, 텍스트)의 프롬프트를 디코더 호환 임베딩으로 인코딩합니다.
- 마스크 디코더: 이미지 임베딩과 마스크 생성 프롬프트를 결합한 경량 변환기입니다. 신뢰도 점수를 사용하여 3개의 후보 마스크를 생성합니다.
# pip install segment-anything
# Download checkpoint: https://github.com/facebookresearch/segment-anything
import numpy as np
import cv2
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
def load_sam_model(
model_type: str = 'vit_h',
checkpoint_path: str = 'sam_vit_h_4b8939.pth',
device: str = 'cuda'
):
"""
Carica il modello SAM.
Tipi disponibili: 'vit_h' (default, max accuratezza), 'vit_l', 'vit_b' (più veloce)
"""
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to(device=device)
return sam
def segment_with_point_prompt(
sam_model,
image: np.ndarray,
point_coords: list[tuple[int, int]],
point_labels: list[int] # 1=foreground, 0=background
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Segmenta con prompt a punti.
Restituisce: (maschere, score, logits) - 3 proposte ordinate per score.
"""
predictor = SamPredictor(sam_model)
predictor.set_image(image)
masks, scores, logits = predictor.predict(
point_coords=np.array(point_coords),
point_labels=np.array(point_labels),
multimask_output=True # genera 3 maschere candidate
)
# Ordina per score decrescente
sorted_idx = np.argsort(scores)[::-1]
return masks[sorted_idx], scores[sorted_idx], logits[sorted_idx]
def segment_with_box_prompt(
sam_model,
image: np.ndarray,
box: tuple[int, int, int, int] # [x1, y1, x2, y2]
) -> tuple[np.ndarray, float]:
"""
Segmenta con prompt bounding box.
Il box definisce la regione di interesse da segmentare.
"""
predictor = SamPredictor(sam_model)
predictor.set_image(image)
masks, scores, _ = predictor.predict(
box=np.array([box]),
multimask_output=False # 1 sola maschera con box prompt
)
return masks[0], float(scores[0])
def automatic_segmentation(sam_model, image: np.ndarray) -> list[dict]:
"""
Segmentazione automatica: SAM segmenta TUTTO nell'immagine
senza nessun prompt. Usa una griglia di punti come seed.
"""
mask_generator = SamAutomaticMaskGenerator(
model=sam_model,
points_per_side=32, # griglia 32x32 = 1024 punti seed
pred_iou_thresh=0.88, # filtra maschere con IoU basso
stability_score_thresh=0.95, # filtra maschere instabili
crop_n_layers=1, # multi-crop per oggetti piccoli
crop_n_points_downscale_factor=2,
min_mask_region_area=100 # rimuovi regioni molto piccole
)
masks = mask_generator.generate(image)
# Ordina per area decrescente
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
print(f"SAM ha trovato {len(masks)} segmenti")
for i, mask in enumerate(masks[:5]):
print(f" Segmento {i+1}: area={mask['area']} "
f"score={mask['predicted_iou']:.3f}")
return masks
def visualize_sam_results(image: np.ndarray, masks: list[dict],
alpha: float = 0.4) -> np.ndarray:
"""Visualizza tutte le maschere SAM con colori random."""
result = image.copy()
np.random.seed(42)
for mask_info in masks:
mask = mask_info['segmentation'] # bool array [H, W]
color = np.random.randint(50, 255, 3)
overlay = result.copy()
overlay[mask] = color
result = cv2.addWeighted(result, 1 - alpha, overlay, alpha, 0)
# Contorno
contours, _ = cv2.findContours(
mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
cv2.drawContours(result, contours, -1, color.tolist(), 2)
return result
# Esempio d'uso
sam = load_sam_model('vit_b', 'sam_vit_b_01ec64.pth') # versione più leggera
image = cv2.imread('image.jpg')
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Segmenta con un click (punto foreground)
masks, scores, _ = segment_with_point_prompt(
sam, image_rgb,
point_coords=[(320, 240)], # click al centro dell'oggetto
point_labels=[1] # 1 = foreground
)
best_mask = masks[0]
print(f"Maschera trovata con score: {scores[0]:.3f}")
3.2 비디오용 SAM2
# pip install sam2
# SAM2 rilasciato da Meta AI nell'agosto 2024
import torch
from sam2.build_sam import build_sam2_video_predictor
def segment_video_with_sam2(
video_path: str,
initial_frame: int,
initial_points: list[tuple[int, int]],
checkpoint: str = 'sam2_hiera_large.pt',
config: str = 'sam2_hiera_l.yaml'
) -> dict[int, np.ndarray]:
"""
Segmenta e traccia un oggetto attraverso i frame di un video.
Inizializza con punti sul primo frame, poi traccia automaticamente.
Returns:
Dict frame_idx -> maschera binaria [H, W]
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
predictor = build_sam2_video_predictor(config, checkpoint, device=device)
with torch.inference_mode(), torch.autocast('cuda', dtype=torch.bfloat16):
# Inizializza sul video
state = predictor.init_state(video_path=video_path)
predictor.reset_state(state)
# Aggiungi prompt sul frame iniziale
frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
inference_state=state,
frame_idx=initial_frame,
obj_id=1,
points=np.array(initial_points),
labels=np.ones(len(initial_points), dtype=np.int32)
)
# Propaga su tutto il video
video_masks = {}
for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
mask = (masks[0][0] > 0.0).cpu().numpy()
video_masks[frame_idx] = mask
print(f"Segmentazione completata: {len(video_masks)} frame processati")
return video_masks
4. 사례 연구: 방사선 사진을 통한 폐 분할
우리는 흉부 방사선 사진의 폐 분할에 U-Net을 적용합니다. 몽고메리 카운티 X-Ray 데이터세트 (분할 마스크가 있는 138개의 방사선 사진 방사선 전문의가 수동으로 주석을 추가한 폐).
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
class LungXrayDataset(Dataset):
"""Dataset per segmentazione polmoni da radiografie (Montgomery CXR)."""
def __init__(self, image_dir: str, mask_dir: str, img_size: int = 512,
augment: bool = True):
self.image_paths = sorted(Path(image_dir).glob('*.png'))
self.mask_dir = Path(mask_dir)
self.img_size = img_size
if augment:
self.transform = A.Compose([
A.RandomResizedCrop(img_size, img_size, scale=(0.8, 1.0)),
A.HorizontalFlip(p=0.5),
A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1,
rotate_limit=15, p=0.7),
A.OneOf([
A.GaussNoise(var_limit=(10, 50)),
A.GaussianBlur(blur_limit=3),
A.MedianBlur(blur_limit=3)
], p=0.3),
A.RandomBrightnessContrast(brightness_limit=0.2,
contrast_limit=0.2, p=0.5),
A.CLAHE(clip_limit=2, p=0.3), # Contrast Limited AHE per RX
A.Normalize(mean=[0.485], std=[0.229]), # Grayscale normalization
ToTensorV2()
])
else:
self.transform = A.Compose([
A.Resize(img_size, img_size),
A.Normalize(mean=[0.485], std=[0.229]),
ToTensorV2()
])
def __len__(self) -> int:
return len(self.image_paths)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
img_path = self.image_paths[idx]
mask_path = self.mask_dir / img_path.name
# Carica immagine (grayscale)
image = np.array(Image.open(img_path).convert('L'))
mask = np.array(Image.open(mask_path).convert('L'))
# Binarizza maschera
mask = (mask > 127).astype(np.float32)
transformed = self.transform(image=image, mask=mask)
return transformed['image'], transformed['mask'].unsqueeze(0)
def run_lung_segmentation_pipeline():
"""Pipeline completa: dataset -> training -> valutazione -> salvataggio."""
# Data loading
train_dataset = LungXrayDataset(
'data/train/images', 'data/train/masks', augment=True
)
val_dataset = LungXrayDataset(
'data/val/images', 'data/val/masks', augment=False
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False,
num_workers=4, pin_memory=True)
# Modello: U-Net per immagini grayscale
model = UNet(in_channels=1, num_classes=1, features=[32, 64, 128, 256])
# Training
history = train_unet(model, train_loader, val_loader, num_epochs=100)
# Valutazione finale
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('best_unet.pth', map_location=device))
model.eval()
all_dice = []
all_iou = []
with torch.no_grad():
for images, masks in val_loader:
images, masks = images.to(device), masks.to(device)
preds = torch.sigmoid(model(images))
for p, m in zip(preds, masks):
all_dice.append(dice_score(p, m))
all_iou.append(compute_iou(p, m))
print(f"\n=== Risultati Finali ===")
print(f"Dice Score: {np.mean(all_dice):.4f} ± {np.std(all_dice):.4f}")
print(f"IoU: {np.mean(all_iou):.4f} ± {np.std(all_iou):.4f}")
# Risultati attesi per U-Net su Montgomery: Dice ~0.97, IoU ~0.94
5. 무엇이든 분할 모델 2: 제로샷 비디오 분할
SAM2 (Meta AI, 2024년 7월) SAM을 비디오 시퀀스로 확장: 그 이상 대화형 프롬프트(점, 상자, 마스크)를 사용하여 정적 이미지의 개체를 분할합니다. SAM2는 모듈 덕분에 비디오 프레임을 따라 마스크를 자동으로 전파합니다. 기억. 비디오에서 제로샷 분할을 안정적으로 수행하는 최초의 모델입니다.
SAM과 SAM2: 주요 차이점
| 특징 | 샘(2023) | SAM2 (2024) |
|---|---|---|
| 비디오 지원 | 아니요(이미지만 해당) | 예(시간 전파) |
| 메모리 모듈 | 결석한 | 크로스 프레임 어텐션을 갖춘 메모리 뱅크 |
| 프롬프트 유형 | 점, 상자, 마스크, 텍스트(CLIP을 통해) | 포인트, 박스, 마스크(+비디오 추적) |
| 속도 | ~50ms/이미지(ViT-H) | ~44ms/프레임(Hiera-L), ~8ms(Hiera-T) |
| 훈련 데이터 | SA-1B(1B 마스크) | SA-V(50.9K 비디오, 642K 마스크) |
| 다중 객체 | 제한된 | 예, 동시 다중 객체 추적 |
import torch
import numpy as np
import cv2
from PIL import Image
# pip install git+https://github.com/facebookresearch/segment-anything-2.git
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
# ============================================================
# PARTE 1: SAM2 su singola immagine
# ============================================================
def sam2_image_segment(image_path: str,
point_coords: list[list[int]],
point_labels: list[int], # 1=foreground, 0=background
model_cfg: str = 'sam2_hiera_large.yaml',
checkpoint: str = 'sam2_hiera_large.pt') -> np.ndarray:
"""
Segmentazione con SAM2 su singola immagine.
point_coords: [[x1, y1], [x2, y2], ...] - punti prompt
point_labels: [1, 1, 0, ...] - 1=foreground, 0=background
Returns: maschera binaria [H, W] bool
"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = build_sam2(model_cfg, checkpoint, device=device)
predictor = SAM2ImagePredictor(model)
# Carica immagine
image = np.array(Image.open(image_path).convert('RGB'))
predictor.set_image(image)
# Predici maschera con prompt
masks, scores, logits = predictor.predict(
point_coords=np.array(point_coords),
point_labels=np.array(point_labels),
multimask_output=True, # 3 maschere con confidenze diverse
)
# Prendi la maschera con score più alto
best_idx = np.argmax(scores)
best_mask = masks[best_idx]
print(f"Maschera selezionata: score={scores[best_idx]:.3f}, "
f"area={best_mask.sum()} pixel")
return best_mask # [H, W] bool
def sam2_box_prompt(image_np: np.ndarray,
box: list[int],
predictor: SAM2ImagePredictor) -> np.ndarray:
"""
Segmentazione con prompt box (x1, y1, x2, y2).
Più preciso dei punti per oggetti con bordi definiti.
"""
predictor.set_image(image_np)
masks, scores, _ = predictor.predict(
box=np.array(box),
multimask_output=False, # Box prompt -> singola maschera ottimale
)
return masks[0] # [H, W] bool
# ============================================================
# PARTE 2: SAM2 su video - propagazione temporale
# ============================================================
def sam2_video_segment(video_dir: str,
frame_idx: int,
points: list[list[int]],
labels: list[int],
model_cfg: str = 'sam2_hiera_large.yaml',
checkpoint: str = 'sam2_hiera_large.pt') -> dict:
"""
SAM2 video predictor: segmenta un oggetto nel frame 'frame_idx'
e propaga la maschera automaticamente lungo tutto il video.
video_dir: cartella con frame del video (frame_*.jpg)
Returns: dict {frame_idx: {obj_id: mask}}
"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
predictor = build_sam2_video_predictor(model_cfg, checkpoint, device=device)
with torch.inference_mode(), torch.autocast(device, dtype=torch.bfloat16):
# Inizializza predictor con la directory video
inference_state = predictor.init_state(video_path=video_dir)
# Aggiungi prompt nel frame di annotazione
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=frame_idx,
obj_id=1, # ID oggetto da trackare
points=np.array(points, dtype=np.float32),
labels=np.array(labels, dtype=np.int32),
)
# Propaga la segmentazione su tutto il video
all_masks = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
inference_state
):
for obj_id, mask_logit in zip(out_obj_ids, out_mask_logits):
mask = (mask_logit > 0).squeeze().cpu().numpy()
if out_frame_idx not in all_masks:
all_masks[out_frame_idx] = {}
all_masks[out_frame_idx][int(obj_id)] = mask
return all_masks
# ============================================================
# PARTE 3: SAM2 come labeling tool automatizzato
# ============================================================
class SAM2AutoLabeler:
"""
Usa SAM2 per generare automaticamente maschere di training.
Riduce i costi di annotazione del 60-80% rispetto all'annotazione manuale.
Human-in-the-loop: un umano valida e corregge le predizioni SAM2.
"""
def __init__(self, checkpoint: str = 'sam2_hiera_base_plus.pt',
model_cfg: str = 'sam2_hiera_base_plus.yaml'):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = build_sam2(model_cfg, checkpoint, device=device)
self.predictor = SAM2ImagePredictor(model)
def auto_label_from_yolo_boxes(self,
image_np: np.ndarray,
yolo_boxes: list[tuple],
min_score: float = 0.7) -> list[dict]:
"""
Genera maschere SAM2 usando bounding box di YOLO come prompt.
Workflow: YOLO rileva oggetti -> SAM2 affina con maschera pixel-perfect.
yolo_boxes: lista di (x1, y1, x2, y2, class_id, confidence)
Returns: lista di {box, class_id, mask, sam_score}
"""
self.predictor.set_image(image_np)
results = []
for x1, y1, x2, y2, class_id, conf in yolo_boxes:
if conf < 0.5:
continue
masks, scores, _ = self.predictor.predict(
box=np.array([x1, y1, x2, y2]),
multimask_output=True,
)
best_idx = np.argmax(scores)
if scores[best_idx] < min_score:
continue
results.append({
'box': (x1, y1, x2, y2),
'class_id': class_id,
'mask': masks[best_idx],
'sam_score': float(scores[best_idx]),
'yolo_conf': float(conf)
})
return results
def save_masks_coco_format(self, results: list[dict],
image_id: int,
output_path: str) -> None:
"""Salva maschere in formato COCO per training Mask R-CNN."""
import json
from pycocotools import mask as coco_mask
annotations = []
for ann_id, r in enumerate(results):
binary_mask = r['mask'].astype(np.uint8)
rle = coco_mask.encode(np.asfortranarray(binary_mask))
rle['counts'] = rle['counts'].decode('utf-8')
area = float(np.sum(binary_mask))
x1, y1, x2, y2 = r['box']
annotations.append({
'id': ann_id,
'image_id': image_id,
'category_id': r['class_id'],
'segmentation': rle,
'area': area,
'bbox': [x1, y1, x2-x1, y2-y1],
'iscrowd': 0
})
with open(output_path, 'w') as f:
json.dump(annotations, f, indent=2)
6. 분할 모범 사례
주요 권장 사항
- 손실 선택: 불균형 데이터세트(예: 넓은 배경에 있는 작은 병변)의 경우 순수 BCE 대신 Dice Loss 또는 Focal Loss를 사용합니다. ECB+주사위 결합이 가장 좋은 절충안인 경우가 많습니다.
- 도메인별 정규화: 의료 이미지(회색조)의 경우 ImageNet이 아닌 특정 데이터세트에서 계산된 통계를 사용합니다. 방사선 사진의 경우 CLAHE 전처리를 통해 결과가 크게 향상됩니다.
- 보수적인 데이터 증대: 의학에서는 해부학적으로 이해가 되지 않는다면 수직 뒤집기를 적용하지 마세요. 너무 많이 왜곡하지 마십시오. 해부학적 구조는 정확한 방향을 가지고 있습니다.
- 입력 해상도: U-Net은 해상도에 민감합니다. 엑스레이: 최소 512x512. 세부적인 세부 사항(조직학, 세포학): 1024x1024 또는 자르기 접근 방식.
- 후처리: CRF(조건부 무작위 필드) 또는 형태학적 연산(닫기, 열기)을 적용하여 마스크 가장자리를 선명하게 합니다.
- 라벨링을 위한 SAM: SAM을 사용하면 훈련 마스크(Human-In-The-Loop 라벨링) 생성을 가속화하여 주석 비용을 60-80% 절감할 수 있습니다.
일반적인 실수
- 분포가 다른 데이터를 검증하지 마세요. 의료 세분화 모델은 도메인 이동(다른 스캐너, 프로토콜, 모집단)에 매우 취약합니다. 항상 다른 센터의 데이터를 검증하십시오.
- 품질이 낮은 마스크를 무시합니다. 훈련에서 사람의 주석은 관찰자 간 가변성을 갖습니다. 가능하다면 주석 신뢰도를 기반으로 다중 주석자 합의 또는 체중 감량을 사용하세요.
- 주사위를 손실로만 사용하세요: Dice Loss는 소규모 배치에서는 불안정하며 변화도에 불연속성이 있습니다. 항상 BCE와 결합하거나 일반화된 주사위 손실 변형을 사용하십시오.
- 희귀 클래스 무시: 다중 클래스 분할에서는 희귀 클래스(몇 픽셀)가 모델에서 무시되는 경향이 있습니다. 희귀 클래스가 포함된 이미지의 클래스 가중치 손실 또는 오버샘플링을 사용합니다.
결론
우리는 주요 분할 아키텍처와 실제 적용을 살펴보았습니다.
- U-Net: 스킵 연결이 있는 인코더-디코더 아키텍처, 폐 방사선 사진에서 Dice ~0.97을 사용한 의료 분할을 위한 사실상의 표준
- 마스크 R-CNN: 경계 상자를 사용한 인스턴스 분할 + 각 인스턴스에 대한 마스크로 조밀하고 자연스러운 장면에 적합
- SAM 및 SAM2: 대화형 프롬프트(SAM) 및 임시 비디오 전파(SAM2)를 사용한 범용 제로샷 분할, 빠른 라벨링에 혁명적
- 자동 라벨링 도구로서의 SAM2: 주석 비용을 60-80% 절감하는 YOLO+SAM2 파이프라인
- 주사위 손실 및 결합된 BCE+Dice: 소규모 지역의 불균형 데이터 세트에 대한 최적의 손실
- 후처리: 마스크 가장자리를 다듬기 위한 수학적 형태학 및 CRF







