セマンティックおよびインスタンスのセグメンテーション: U-Net、マスク R-CNN、および SAM
画像のセグメンテーションは、最も詳細なレベルの視覚的理解を表します。 「この画像には腫瘍がある」(分類)または「腫瘍が見つかった」と知るのではなく、 このエリアで」(検出)、知りたい 正確にどのピクセルが腫瘍に属しているか。 このピクセル完璧な精度は、医療、ロボット手術、自動運転の基礎となります。 そして工業品質管理。
この記事では、セグメンテーションにとって最も重要なアーキテクチャについて説明します。 ユーネット (医療セグメンテーションに革命をもたらしたモデル)、 マスク R-CNN (インスタンスセグメンテーションのゴールドスタンダード) e サム (可能な限界を再定義したメタ AI によるセグメント何でもモデル)。
何を学ぶか
- U-Net アーキテクチャ: 医療セグメンテーションのためのスキップ接続を備えたエンコーダ/デコーダ
- 医療データセットのトレーニングを使用した PyTorch での U-Net の最初からの実装
- マスク R-CNN: 境界ボックス + バイナリ マスクによるインスタンスのセグメンテーション
- Segment Anything Model (SAM): 視覚的なプロンプトによるゼロショット セグメンテーション
- 評価指標: ダイススコア、IoU、セグメンテーションの精度/再現率
- 後処理技術: CRF、数学的形態学
- ケーススタディ: X 線写真からの肺のセグメンテーション (オープンソース データセット)
- 本番環境へのセグメンテーション モデルの導入
1. セグメンテーションの基礎
1.1 セグメンテーションの種類
セグメンテーション分類法
| タイプ | インスタンスの区別 | ランキングの背景 | 出力 | アーキテクチャ |
|---|---|---|---|---|
| セマンティクス | No | Si | ピクセルごとのラベルを含む HxW マップ | U-Net、DeepLabv3、SegFormer |
| インスタンスの | Si | いいえ(単なる「もの」) | オブジェクトのバイナリマスク | マスク R-CNN、SOLOv2、YOLACT |
| パノプティック | はい(「物」について) | はい(「もの」について) | 統合されたインスタンス + セマンティクス マップ | パノプティック FPN、Mask2Former |
| 相互の作用 | はい (プロンプト付き) | それはプロンプトによって異なります | クリック/Bボックス駆動マスク | SAM、SAM2、ClickSEG |
1.2 評価指標
セグメンテーションには、オーバーラップをピクセルごとに測定する特定のメトリックが使用されます。 予測されたマスクとグランドトゥルースの間:
import torch
import numpy as np
from typing import Union
def compute_iou(pred: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> float:
"""
Intersection over Union per segmentazione binaria.
pred, target: tensori [H, W] o [B, H, W] con valori in [0,1]
"""
pred_binary = (pred >= threshold).bool()
target_binary = target.bool()
intersection = (pred_binary & target_binary).float().sum()
union = (pred_binary | target_binary).float().sum()
if union == 0:
return 1.0 # caso degenere: entrambe vuote
return float(intersection / union)
def dice_score(pred: torch.Tensor, target: torch.Tensor, threshold: float = 0.5,
smooth: float = 1.0) -> float:
"""
Dice Score (F1 per segmentazione): 2*|X intersect Y| / (|X| + |Y|)
Preferito in ambito medico perchè meno sensibile agli sbilanciamenti.
Valore: 0 (peggio) -> 1 (perfetto)
"""
pred_binary = (pred >= threshold).float()
target_binary = target.float()
intersection = (pred_binary * target_binary).sum()
dice = (2.0 * intersection + smooth) / (pred_binary.sum() + target_binary.sum() + smooth)
return float(dice)
def compute_multiclass_miou(pred_logits: torch.Tensor, targets: torch.Tensor,
num_classes: int, ignore_index: int = 255) -> float:
"""
mIoU per segmentazione semantica multi-classe.
pred_logits: [B, C, H, W] - logit grezzi
targets: [B, H, W] - indici di classe 0..num_classes-1
"""
preds = pred_logits.argmax(dim=1) # [B, H, W]
ious = []
for cls in range(num_classes):
pred_cls = preds == cls
true_cls = targets == cls
valid = targets != ignore_index
pred_cls = pred_cls & valid
true_cls = true_cls & valid
intersection = (pred_cls & true_cls).sum().float()
union = (pred_cls | true_cls).sum().float()
if union > 0:
ious.append(float(intersection / union))
return float(np.mean(ious)) if ious else 0.0
def hausdorff_distance(pred: np.ndarray, target: np.ndarray) -> float:
"""
Hausdorff Distance: misura la distanza massima tra i bordi delle maschere.
Utile in medicina per valutare la precisione dei contorni.
"""
from scipy.spatial.distance import directed_hausdorff
pred_points = np.argwhere(pred)
target_points = np.argwhere(target)
if len(pred_points) == 0 or len(target_points) == 0:
return float('inf')
d1 = directed_hausdorff(pred_points, target_points)[0]
d2 = directed_hausdorff(target_points, pred_points)[0]
return max(d1, d2)
print("Esempio metriche:")
pred = torch.sigmoid(torch.randn(256, 256))
target = (torch.randn(256, 256) > 0).float()
iou = compute_iou(pred, target)
dice = dice_score(pred, target)
print(f"IoU: {iou:.3f} | Dice: {dice:.3f}")
2. U-Net: 医療セグメンテーションのためのネットワーク
ユーネット (Ronneberger et al., 2015) はもともとセグメンテーションのために提案されました 生物医学画像の。その「U」字型のアーキテクチャは、 接続をスキップする エンコーダとデコーダの間で使用され、あらゆるセグメンテーション タスクの主要なテンプレートとなっています。 医療用ピクセルから衛星地図、産業用画像から屋外シーンまで、高密度です。
2.1 U-Net アーキテクチャ
アーキテクチャは 3 つの部分に分かれています。
- エンコーダー (収縮パス): 一連の畳み込みブロック + 最大プーリングにより、解像度を下げてチャネルを増やし、意味的に豊富だが空間的に不正確な特徴を抽出します
- ボトルネック: 最も深いブロックは最低の解像度で動作します
- デコーダ(拡張パス): 一連のアップサンプリング + 変換により元の解像度が復元され、スキップ接続を介してエンコーダーの特徴マップが連結され、空間の詳細が復元されます。
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""Blocco base U-Net: Conv-BN-ReLU-Conv-BN-ReLU."""
def __init__(self, in_channels: int, out_channels: int, mid_channels: int | None = None):
super().__init__()
if mid_channels is None:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.double_conv(x)
class DownBlock(nn.Module):
"""Encoder block: MaxPool2d + DoubleConv."""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.maxpool_conv(x)
class UpBlock(nn.Module):
"""Decoder block: Upsample + concatenazione skip + DoubleConv."""
def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True):
super().__init__()
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2,
kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
x1 = self.up(x1)
# Padding se le dimensioni non coincidono esattamente
diff_h = x2.size(2) - x1.size(2)
diff_w = x2.size(3) - x1.size(3)
x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
diff_h // 2, diff_h - diff_h // 2])
# Skip connection: concatena feature encoder + decoder
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class UNet(nn.Module):
"""
U-Net originale per segmentazione binaria o multi-classe.
Architettura:
Input -> [64] -> [128] -> [256] -> [512] -> [1024] (bottleneck)
-> [512] -> [256] -> [128] -> [64] -> Output
Le frecce verso il basso sono encoder (+ maxpool)
Le frecce verso l'alto sono decoder (+ skip connections)
"""
def __init__(self, in_channels: int = 1, num_classes: int = 1,
features: list[int] = [64, 128, 256, 512], bilinear: bool = True):
super().__init__()
self.in_conv = DoubleConv(in_channels, features[0])
# Encoder
self.downs = nn.ModuleList([
DownBlock(features[i], features[i+1])
for i in range(len(features) - 1)
])
# Bottleneck
factor = 2 if bilinear else 1
self.bottleneck = DownBlock(features[-1], features[-1] * 2 // factor)
# Decoder
self.ups = nn.ModuleList([
UpBlock(features[-1] * 2 // factor + features[-(i+1)],
features[-(i+2)] if i < len(features)-1 else features[0],
bilinear)
for i in range(len(features))
])
# Semplifichiamo con lista esplicita
self.ups = nn.ModuleList([
UpBlock(1024, 512 // factor, bilinear),
UpBlock(512, 256 // factor, bilinear),
UpBlock(256, 128 // factor, bilinear),
UpBlock(128, 64, bilinear),
])
self.out_conv = nn.Conv2d(64, num_classes, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Encoder path (salva skip connections)
x1 = self.in_conv(x)
x2 = self.downs[0](x1)
x3 = self.downs[1](x2)
x4 = self.downs[2](x3)
# Bottleneck
x5 = self.bottleneck(x4)
# Decoder path (usa skip connections)
x = self.ups[0](x5, x4)
x = self.ups[1](x, x3)
x = self.ups[2](x, x2)
x = self.ups[3](x, x1)
return self.out_conv(x)
# Test architettura
model = UNet(in_channels=3, num_classes=1)
x = torch.randn(2, 3, 256, 256)
y = model(x)
print(f"Input: {x.shape} -> Output: {y.shape}")
# Input: torch.Size([2, 3, 256, 256]) -> Output: torch.Size([2, 1, 256, 256])
total_params = sum(p.numel() for p in model.parameters())
print(f"Parametri: {total_params:,}")
2.2 ダイス損失のある U-Net トレーニング
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
"""
Dice Loss per segmentazione binaria.
Gestisce naturalmente lo sbilanciamento di classe tipico delle immagini mediche
(es. 95% sfondo, 5% lesione).
"""
def __init__(self, smooth: float = 1.0):
super().__init__()
self.smooth = smooth
def forward(self, pred_logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
# Applica sigmoid per ottenere probabilità
preds = torch.sigmoid(pred_logits)
# Flatten per calcolo efficiente
preds_flat = preds.view(-1)
targets_flat = targets.view(-1)
intersection = (preds_flat * targets_flat).sum()
dice = (2.0 * intersection + self.smooth) / (
preds_flat.sum() + targets_flat.sum() + self.smooth
)
return 1.0 - dice # loss = 1 - Dice (minimizzare)
class CombinedLoss(nn.Module):
"""
Combinazione BCE + Dice: il compromesso migliore per segmentazione medica.
BCE: ottimizza ogni pixel individualmente
Dice: ottimizza l'overlap globale tra predizione e ground truth
"""
def __init__(self, bce_weight: float = 0.5, dice_weight: float = 0.5):
super().__init__()
self.bce = nn.BCEWithLogitsLoss()
self.dice = DiceLoss()
self.bce_weight = bce_weight
self.dice_weight = dice_weight
def forward(self, pred_logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
bce_loss = self.bce(pred_logits, targets.float())
dice_loss = self.dice(pred_logits, targets.float())
return self.bce_weight * bce_loss + self.dice_weight * dice_loss
def train_unet(
model: UNet,
train_loader,
val_loader,
num_epochs: int = 50,
learning_rate: float = 1e-4
) -> dict:
"""
Training completo di U-Net con:
- Combined BCE+Dice loss
- AdamW + CosineAnnealingLR
- Early stopping su Dice score di validazione
- Checkpoint del modello migliore
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
criterion = CombinedLoss(bce_weight=0.5, dice_weight=0.5)
optimizer = torch.optim.AdamW(
model.parameters(), lr=learning_rate, weight_decay=1e-5
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=num_epochs, eta_min=1e-6
)
history = {'train_loss': [], 'val_loss': [], 'val_dice': []}
best_dice = 0.0
patience = 15
no_improve = 0
for epoch in range(num_epochs):
# Training
model.train()
train_loss = 0.0
for images, masks in train_loader:
images, masks = images.to(device), masks.to(device)
pred_logits = model(images)
loss = criterion(pred_logits, masks)
optimizer.zero_grad(set_to_none=True)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
train_loss += loss.item()
scheduler.step()
# Validation
model.eval()
val_loss = 0.0
val_dice_scores = []
with torch.no_grad():
for images, masks in val_loader:
images, masks = images.to(device), masks.to(device)
pred_logits = model(images)
val_loss += criterion(pred_logits, masks).item()
preds = torch.sigmoid(pred_logits)
for p, m in zip(preds, masks):
val_dice_scores.append(dice_score(p, m))
avg_train_loss = train_loss / len(train_loader)
avg_val_loss = val_loss / len(val_loader)
avg_val_dice = sum(val_dice_scores) / len(val_dice_scores)
history['train_loss'].append(avg_train_loss)
history['val_loss'].append(avg_val_loss)
history['val_dice'].append(avg_val_dice)
if avg_val_dice > best_dice:
best_dice = avg_val_dice
torch.save(model.state_dict(), 'best_unet.pth')
no_improve = 0
else:
no_improve += 1
print(f"Epoch {epoch+1:2d}/{num_epochs} | "
f"Loss: {avg_train_loss:.4f}/{avg_val_loss:.4f} | "
f"Dice: {avg_val_dice:.4f} | Best: {best_dice:.4f}")
if no_improve >= patience:
print(f"Early stopping at epoch {epoch+1}")
break
print(f"Training completato. Best Dice Score: {best_dice:.4f}")
return history
3. セグメント何でもモデル (SAM)
メタAIがリリースされました サム (Kirillov et al., 2023) という野心的な目標を掲げています。 汎用的なセグメンテーション モデルを構築する: 10 億のマスクでトレーニングされたモデル セグメント化できるもの 何でも in 任意の画像 柔軟なプロンプトを使用する (点、境界ボックス、テキストをクリックします)。 SAM2 (2024) では、モデルをビデオにも拡張しました。
3.1 SAM アーキテクチャ
SAM は 3 つの主要コンポーネントで構成されます。
- 画像エンコーダ: 高密度の画像埋め込みを生成する Vision Transformer (ViT-H、632M パラメータ)。イメージごとに 1 回だけ実行されます。
- プロンプトエンコーダー: さまざまなタイプのプロンプト (ポイント、ボックス、マスク、テキスト) をデコーダ互換の埋め込みにエンコードします。
- マスクデコーダー: 画像の埋め込みとマスクを生成するプロンプトを組み合わせた軽量のトランスフォーマー。信頼スコアを持つ 3 つの候補マスクを生成します。
# pip install segment-anything
# Download checkpoint: https://github.com/facebookresearch/segment-anything
import numpy as np
import cv2
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
def load_sam_model(
model_type: str = 'vit_h',
checkpoint_path: str = 'sam_vit_h_4b8939.pth',
device: str = 'cuda'
):
"""
Carica il modello SAM.
Tipi disponibili: 'vit_h' (default, max accuratezza), 'vit_l', 'vit_b' (più veloce)
"""
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to(device=device)
return sam
def segment_with_point_prompt(
sam_model,
image: np.ndarray,
point_coords: list[tuple[int, int]],
point_labels: list[int] # 1=foreground, 0=background
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Segmenta con prompt a punti.
Restituisce: (maschere, score, logits) - 3 proposte ordinate per score.
"""
predictor = SamPredictor(sam_model)
predictor.set_image(image)
masks, scores, logits = predictor.predict(
point_coords=np.array(point_coords),
point_labels=np.array(point_labels),
multimask_output=True # genera 3 maschere candidate
)
# Ordina per score decrescente
sorted_idx = np.argsort(scores)[::-1]
return masks[sorted_idx], scores[sorted_idx], logits[sorted_idx]
def segment_with_box_prompt(
sam_model,
image: np.ndarray,
box: tuple[int, int, int, int] # [x1, y1, x2, y2]
) -> tuple[np.ndarray, float]:
"""
Segmenta con prompt bounding box.
Il box definisce la regione di interesse da segmentare.
"""
predictor = SamPredictor(sam_model)
predictor.set_image(image)
masks, scores, _ = predictor.predict(
box=np.array([box]),
multimask_output=False # 1 sola maschera con box prompt
)
return masks[0], float(scores[0])
def automatic_segmentation(sam_model, image: np.ndarray) -> list[dict]:
"""
Segmentazione automatica: SAM segmenta TUTTO nell'immagine
senza nessun prompt. Usa una griglia di punti come seed.
"""
mask_generator = SamAutomaticMaskGenerator(
model=sam_model,
points_per_side=32, # griglia 32x32 = 1024 punti seed
pred_iou_thresh=0.88, # filtra maschere con IoU basso
stability_score_thresh=0.95, # filtra maschere instabili
crop_n_layers=1, # multi-crop per oggetti piccoli
crop_n_points_downscale_factor=2,
min_mask_region_area=100 # rimuovi regioni molto piccole
)
masks = mask_generator.generate(image)
# Ordina per area decrescente
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
print(f"SAM ha trovato {len(masks)} segmenti")
for i, mask in enumerate(masks[:5]):
print(f" Segmento {i+1}: area={mask['area']} "
f"score={mask['predicted_iou']:.3f}")
return masks
def visualize_sam_results(image: np.ndarray, masks: list[dict],
alpha: float = 0.4) -> np.ndarray:
"""Visualizza tutte le maschere SAM con colori random."""
result = image.copy()
np.random.seed(42)
for mask_info in masks:
mask = mask_info['segmentation'] # bool array [H, W]
color = np.random.randint(50, 255, 3)
overlay = result.copy()
overlay[mask] = color
result = cv2.addWeighted(result, 1 - alpha, overlay, alpha, 0)
# Contorno
contours, _ = cv2.findContours(
mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
cv2.drawContours(result, contours, -1, color.tolist(), 2)
return result
# Esempio d'uso
sam = load_sam_model('vit_b', 'sam_vit_b_01ec64.pth') # versione più leggera
image = cv2.imread('image.jpg')
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Segmenta con un click (punto foreground)
masks, scores, _ = segment_with_point_prompt(
sam, image_rgb,
point_coords=[(320, 240)], # click al centro dell'oggetto
point_labels=[1] # 1 = foreground
)
best_mask = masks[0]
print(f"Maschera trovata con score: {scores[0]:.3f}")
3.2 ビデオ用の SAM2
# pip install sam2
# SAM2 rilasciato da Meta AI nell'agosto 2024
import torch
from sam2.build_sam import build_sam2_video_predictor
def segment_video_with_sam2(
video_path: str,
initial_frame: int,
initial_points: list[tuple[int, int]],
checkpoint: str = 'sam2_hiera_large.pt',
config: str = 'sam2_hiera_l.yaml'
) -> dict[int, np.ndarray]:
"""
Segmenta e traccia un oggetto attraverso i frame di un video.
Inizializza con punti sul primo frame, poi traccia automaticamente.
Returns:
Dict frame_idx -> maschera binaria [H, W]
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
predictor = build_sam2_video_predictor(config, checkpoint, device=device)
with torch.inference_mode(), torch.autocast('cuda', dtype=torch.bfloat16):
# Inizializza sul video
state = predictor.init_state(video_path=video_path)
predictor.reset_state(state)
# Aggiungi prompt sul frame iniziale
frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
inference_state=state,
frame_idx=initial_frame,
obj_id=1,
points=np.array(initial_points),
labels=np.ones(len(initial_points), dtype=np.int32)
)
# Propaga su tutto il video
video_masks = {}
for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
mask = (masks[0][0] > 0.0).cpu().numpy()
video_masks[frame_idx] = mask
print(f"Segmentazione completata: {len(video_masks)} frame processati")
return video_masks
4. ケーススタディ: X 線写真からの肺のセグメンテーション
私たちは、胸部 X 線写真からの肺のセグメンテーションに U-Net を適用します。 モンゴメリー郡の X 線データセット (セグメンテーションマスクを使用した 138 枚の X 線写真 放射線科医が手動で注釈を付けた肺)。
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
class LungXrayDataset(Dataset):
"""Dataset per segmentazione polmoni da radiografie (Montgomery CXR)."""
def __init__(self, image_dir: str, mask_dir: str, img_size: int = 512,
augment: bool = True):
self.image_paths = sorted(Path(image_dir).glob('*.png'))
self.mask_dir = Path(mask_dir)
self.img_size = img_size
if augment:
self.transform = A.Compose([
A.RandomResizedCrop(img_size, img_size, scale=(0.8, 1.0)),
A.HorizontalFlip(p=0.5),
A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1,
rotate_limit=15, p=0.7),
A.OneOf([
A.GaussNoise(var_limit=(10, 50)),
A.GaussianBlur(blur_limit=3),
A.MedianBlur(blur_limit=3)
], p=0.3),
A.RandomBrightnessContrast(brightness_limit=0.2,
contrast_limit=0.2, p=0.5),
A.CLAHE(clip_limit=2, p=0.3), # Contrast Limited AHE per RX
A.Normalize(mean=[0.485], std=[0.229]), # Grayscale normalization
ToTensorV2()
])
else:
self.transform = A.Compose([
A.Resize(img_size, img_size),
A.Normalize(mean=[0.485], std=[0.229]),
ToTensorV2()
])
def __len__(self) -> int:
return len(self.image_paths)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
img_path = self.image_paths[idx]
mask_path = self.mask_dir / img_path.name
# Carica immagine (grayscale)
image = np.array(Image.open(img_path).convert('L'))
mask = np.array(Image.open(mask_path).convert('L'))
# Binarizza maschera
mask = (mask > 127).astype(np.float32)
transformed = self.transform(image=image, mask=mask)
return transformed['image'], transformed['mask'].unsqueeze(0)
def run_lung_segmentation_pipeline():
"""Pipeline completa: dataset -> training -> valutazione -> salvataggio."""
# Data loading
train_dataset = LungXrayDataset(
'data/train/images', 'data/train/masks', augment=True
)
val_dataset = LungXrayDataset(
'data/val/images', 'data/val/masks', augment=False
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False,
num_workers=4, pin_memory=True)
# Modello: U-Net per immagini grayscale
model = UNet(in_channels=1, num_classes=1, features=[32, 64, 128, 256])
# Training
history = train_unet(model, train_loader, val_loader, num_epochs=100)
# Valutazione finale
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('best_unet.pth', map_location=device))
model.eval()
all_dice = []
all_iou = []
with torch.no_grad():
for images, masks in val_loader:
images, masks = images.to(device), masks.to(device)
preds = torch.sigmoid(model(images))
for p, m in zip(preds, masks):
all_dice.append(dice_score(p, m))
all_iou.append(compute_iou(p, m))
print(f"\n=== Risultati Finali ===")
print(f"Dice Score: {np.mean(all_dice):.4f} ± {np.std(all_dice):.4f}")
print(f"IoU: {np.mean(all_iou):.4f} ± {np.std(all_iou):.4f}")
# Risultati attesi per U-Net su Montgomery: Dice ~0.97, IoU ~0.94
5. 何でもセグメント化モデル 2: ゼロショット ビデオ セグメント化
SAM2 (Meta AI、2024 年 7 月) SAM をビデオ シーケンスに拡張: その先へ インタラクティブなプロンプト (ポイント、ボックス、マスク) を使用して静的画像内のオブジェクトをセグメント化します。 SAM2 は、モジュールのおかげでビデオ フレームに沿ってマスクを自動的に伝播します。 記憶。これは、ビデオ上でゼロショット セグメンテーションを確実に実行する最初のモデルです。
SAM と SAM2: 主な違い
| 特徴 | サム (2023) | SAM2 (2024) |
|---|---|---|
| ビデオサポート | いいえ(画像のみ) | はい (時間伝播) |
| メモリモジュール | 不在 | クロスフレームアテンションを備えたメモリバンク |
| プロンプトタイプ | ポイント、ボックス、マスク、テキスト(CLIP経由) | ポイント、ボックス、マスク (+ ビデオ トラッキング) |
| スピード | ~50ms/画像 (ViT-H) | ~44ms/フレーム (Hiera-L)、~8ms/フレーム (Hiera-T) |
| トレーニングデータ | SA-1B(1Bマスク) | SA-V (50.9K ビデオ、642K マスク) |
| マルチオブジェクト | 限定 | はい、同時複数オブジェクト追跡 |
import torch
import numpy as np
import cv2
from PIL import Image
# pip install git+https://github.com/facebookresearch/segment-anything-2.git
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
# ============================================================
# PARTE 1: SAM2 su singola immagine
# ============================================================
def sam2_image_segment(image_path: str,
point_coords: list[list[int]],
point_labels: list[int], # 1=foreground, 0=background
model_cfg: str = 'sam2_hiera_large.yaml',
checkpoint: str = 'sam2_hiera_large.pt') -> np.ndarray:
"""
Segmentazione con SAM2 su singola immagine.
point_coords: [[x1, y1], [x2, y2], ...] - punti prompt
point_labels: [1, 1, 0, ...] - 1=foreground, 0=background
Returns: maschera binaria [H, W] bool
"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = build_sam2(model_cfg, checkpoint, device=device)
predictor = SAM2ImagePredictor(model)
# Carica immagine
image = np.array(Image.open(image_path).convert('RGB'))
predictor.set_image(image)
# Predici maschera con prompt
masks, scores, logits = predictor.predict(
point_coords=np.array(point_coords),
point_labels=np.array(point_labels),
multimask_output=True, # 3 maschere con confidenze diverse
)
# Prendi la maschera con score più alto
best_idx = np.argmax(scores)
best_mask = masks[best_idx]
print(f"Maschera selezionata: score={scores[best_idx]:.3f}, "
f"area={best_mask.sum()} pixel")
return best_mask # [H, W] bool
def sam2_box_prompt(image_np: np.ndarray,
box: list[int],
predictor: SAM2ImagePredictor) -> np.ndarray:
"""
Segmentazione con prompt box (x1, y1, x2, y2).
Più preciso dei punti per oggetti con bordi definiti.
"""
predictor.set_image(image_np)
masks, scores, _ = predictor.predict(
box=np.array(box),
multimask_output=False, # Box prompt -> singola maschera ottimale
)
return masks[0] # [H, W] bool
# ============================================================
# PARTE 2: SAM2 su video - propagazione temporale
# ============================================================
def sam2_video_segment(video_dir: str,
frame_idx: int,
points: list[list[int]],
labels: list[int],
model_cfg: str = 'sam2_hiera_large.yaml',
checkpoint: str = 'sam2_hiera_large.pt') -> dict:
"""
SAM2 video predictor: segmenta un oggetto nel frame 'frame_idx'
e propaga la maschera automaticamente lungo tutto il video.
video_dir: cartella con frame del video (frame_*.jpg)
Returns: dict {frame_idx: {obj_id: mask}}
"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
predictor = build_sam2_video_predictor(model_cfg, checkpoint, device=device)
with torch.inference_mode(), torch.autocast(device, dtype=torch.bfloat16):
# Inizializza predictor con la directory video
inference_state = predictor.init_state(video_path=video_dir)
# Aggiungi prompt nel frame di annotazione
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=frame_idx,
obj_id=1, # ID oggetto da trackare
points=np.array(points, dtype=np.float32),
labels=np.array(labels, dtype=np.int32),
)
# Propaga la segmentazione su tutto il video
all_masks = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
inference_state
):
for obj_id, mask_logit in zip(out_obj_ids, out_mask_logits):
mask = (mask_logit > 0).squeeze().cpu().numpy()
if out_frame_idx not in all_masks:
all_masks[out_frame_idx] = {}
all_masks[out_frame_idx][int(obj_id)] = mask
return all_masks
# ============================================================
# PARTE 3: SAM2 come labeling tool automatizzato
# ============================================================
class SAM2AutoLabeler:
"""
Usa SAM2 per generare automaticamente maschere di training.
Riduce i costi di annotazione del 60-80% rispetto all'annotazione manuale.
Human-in-the-loop: un umano valida e corregge le predizioni SAM2.
"""
def __init__(self, checkpoint: str = 'sam2_hiera_base_plus.pt',
model_cfg: str = 'sam2_hiera_base_plus.yaml'):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = build_sam2(model_cfg, checkpoint, device=device)
self.predictor = SAM2ImagePredictor(model)
def auto_label_from_yolo_boxes(self,
image_np: np.ndarray,
yolo_boxes: list[tuple],
min_score: float = 0.7) -> list[dict]:
"""
Genera maschere SAM2 usando bounding box di YOLO come prompt.
Workflow: YOLO rileva oggetti -> SAM2 affina con maschera pixel-perfect.
yolo_boxes: lista di (x1, y1, x2, y2, class_id, confidence)
Returns: lista di {box, class_id, mask, sam_score}
"""
self.predictor.set_image(image_np)
results = []
for x1, y1, x2, y2, class_id, conf in yolo_boxes:
if conf < 0.5:
continue
masks, scores, _ = self.predictor.predict(
box=np.array([x1, y1, x2, y2]),
multimask_output=True,
)
best_idx = np.argmax(scores)
if scores[best_idx] < min_score:
continue
results.append({
'box': (x1, y1, x2, y2),
'class_id': class_id,
'mask': masks[best_idx],
'sam_score': float(scores[best_idx]),
'yolo_conf': float(conf)
})
return results
def save_masks_coco_format(self, results: list[dict],
image_id: int,
output_path: str) -> None:
"""Salva maschere in formato COCO per training Mask R-CNN."""
import json
from pycocotools import mask as coco_mask
annotations = []
for ann_id, r in enumerate(results):
binary_mask = r['mask'].astype(np.uint8)
rle = coco_mask.encode(np.asfortranarray(binary_mask))
rle['counts'] = rle['counts'].decode('utf-8')
area = float(np.sum(binary_mask))
x1, y1, x2, y2 = r['box']
annotations.append({
'id': ann_id,
'image_id': image_id,
'category_id': r['class_id'],
'segmentation': rle,
'area': area,
'bbox': [x1, y1, x2-x1, y2-y1],
'iscrowd': 0
})
with open(output_path, 'w') as f:
json.dump(annotations, f, indent=2)
6. セグメンテーションのベストプラクティス
主な推奨事項
- 損失の選択: 不均衡なデータセット (大きな背景上の小さな病変など) の場合は、純粋な BCE の代わりに Dice Loss または Focal Loss を使用します。多くの場合、ECB と Dice の組み合わせが最良の妥協策となります。
- ドメイン固有の正規化: 医療画像 (グレースケール) の場合は、ImageNet ではなく、特定のデータセットで計算された統計を使用します。 X 線写真の場合、CLAHE 前処理により結果が大幅に向上します。
- 保守的なデータ拡張: 医学では、解剖学的に意味が無い場合は、垂直方向の反転を適用しないでください。あまり歪めないでください。解剖学的構造には正確な方向があります。
- 入力解像度: U-Net は解像度に依存します。 X 線: 最小 512x512。細かい詳細(組織学、細胞学)の場合: 1024x1024 またはクロップアプローチ。
- 後処理: 条件付きランダム フィールド (CRF) または形態学的操作 (閉じる、開く) を適用して、マスクのエッジをシャープにします。
- ラベル付け用の SAM: SAM を使用してトレーニング マスク (人間参加型ラベル付け) の生成を加速し、アノテーション コストを 60 ~ 80% 削減します。
よくある間違い
- 異なる分布のデータを検証しないでください。 医療セグメンテーション モデルは、ドメインの変化 (異なるスキャナー、プロトコル、母集団) に対して脆弱であることで知られています。常にさまざまなセンターからのデータを検証します。
- 低品質のマスクを無視します。 トレーニングでは、人間によるアノテーションには観察者間でのばらつきがあります。可能であれば、複数のアノテーターのコンセンサスを使用するか、アノテーションの信頼性に基づいたウェイトロスを使用します。
- ダイスは損失としてのみ使用します。 ダイス損失はバッチが小さい場合は不安定であり、勾配に不連続性があります。常に BCE と組み合わせるか、一般化されたダイス損失バリアントを使用してください。
- まれなクラスを無視する: マルチクラス セグメンテーションでは、まれなクラス (数ピクセル) がモデルによって無視される傾向があります。まれなクラスを含む画像のクラス重み付け損失またはオーバーサンプリングを使用します。
結論
主なセグメンテーション アーキテクチャとその実際のアプリケーションを検討しました。
- U-Net: スキップ接続を備えたエンコーダ/デコーダ アーキテクチャ。肺 X 線写真で Dice ~0.97 の医療セグメンテーションの事実上の標準
- マスク R-CNN: 各インスタンスのバウンディング ボックス + マスクを使用したインスタンスのセグメンテーション。密集した自然のシーンに最適です。
- SAM および SAM2: インタラクティブ プロンプト (SAM) と時間ビデオ伝播 (SAM2) を備えたユニバーサル ゼロショット セグメンテーション。高速ラベル付けに革新的です。
- 自動ラベル付けツールとしての SAM2: アノテーション コストを 60 ~ 80% 削減する YOLO+SAM2 パイプライン
- Dice 損失と BCE+Dice の組み合わせ: 領域が小さい不均衡なデータセットの最適な損失
- 後処理: マスクのエッジを調整するための数学的形態学と CRF







