diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..6af7de5eb 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -238,6 +238,9 @@ Available Datasets datasets/pyhealth.datasets.BMDHSDataset datasets/pyhealth.datasets.COVID19CXRDataset datasets/pyhealth.datasets.ChestXray14Dataset + datasets/pyhealth.datasets.ISIC2018Dataset + datasets/pyhealth.datasets.ISIC2018ArtifactsDataset + datasets/pyhealth.datasets.PH2Dataset datasets/pyhealth.datasets.TUABDataset datasets/pyhealth.datasets.TUEVDataset datasets/pyhealth.datasets.ClinVarDataset diff --git a/docs/api/datasets/pyhealth.datasets.ISIC2018ArtifactsDataset.rst b/docs/api/datasets/pyhealth.datasets.ISIC2018ArtifactsDataset.rst new file mode 100644 index 000000000..477387a6e --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.ISIC2018ArtifactsDataset.rst @@ -0,0 +1,34 @@ +pyhealth.datasets.ISIC2018ArtifactsDataset +====================================== + +A dataset class for dermoscopy images paired with per-image artifact +annotations. The default annotation file is ``isic_bias.csv`` from +Bissoto et al. (2020), but any CSV following the same column format can be +supplied via the ``annotations_csv`` parameter. + +**Default data sources (Bissoto et al. 2020)** + +Using ``ISIC2018ArtifactsDataset`` with the default annotation CSV requires +**two separate downloads**: + +1. **Artifact annotations** (``isic_bias.csv``): + https://github.com/alceubissoto/debiasing-skin + + See ``artefacts-annotation/`` in that repository for the annotation files. + + Reference: + Bissoto et al. "Debiasing Skin Lesion Datasets and Models? Not So Fast", + ISIC Skin Image Analysis Workshop @ CVPR 2020. + +2. **ISIC 2018 Task 1/2 images & masks** (~8 GB): + https://challenge.isic-archive.com/data/#2018 + + * Training images: ``ISIC2018_Task1-2_Training_Input.zip`` + * Segmentation masks: ``ISIC2018_Task1_Training_GroundTruth.zip`` + +Both can be fetched automatically by passing ``download=True``. 
+ +.. autoclass:: pyhealth.datasets.ISIC2018ArtifactsDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/datasets/pyhealth.datasets.ISIC2018Dataset.rst b/docs/api/datasets/pyhealth.datasets.ISIC2018Dataset.rst new file mode 100644 index 000000000..4d22314cb --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.ISIC2018Dataset.rst @@ -0,0 +1,21 @@ +pyhealth.datasets.ISIC2018Dataset +=================================== + +Unified dataset class for the ISIC 2018 challenge, supporting both +**Task 1/2** (lesion segmentation & attribute detection) and **Task 3** +(7-class skin lesion classification) via the ``task`` argument. + +For more information see `the ISIC 2018 challenge page `_. + +.. note:: + **Licenses differ by task:** + + * ``task="task1_2"`` — `CC-0 (Public Domain) `_. + No attribution required. + * ``task="task3"`` — `CC-BY-NC 4.0 `_. + Attribution is required; commercial use is not permitted. + +.. autoclass:: pyhealth.datasets.ISIC2018Dataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/datasets/pyhealth.datasets.PH2Dataset.rst b/docs/api/datasets/pyhealth.datasets.PH2Dataset.rst new file mode 100644 index 000000000..8d77bf3cd --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.PH2Dataset.rst @@ -0,0 +1,47 @@ +pyhealth.datasets.PH2Dataset +============================ + +A dataset class for the **PH2 dermoscopy database** — 200 dermoscopic +images of melanocytic lesions labelled as Common Nevus (80), Atypical +Nevus (80), or Melanoma (40). + +Two source formats are supported automatically: + +* **Mirror format** (recommended for quick start): flat JPEG images and a + ``PH2_simple_dataset.csv`` from the community mirror + `vikaschouhan/PH2-dataset `_. + Pass ``download=True`` to fetch this automatically. + +* **Original format**: nested BMP images and expert annotations from the + `official ADDI project release `_ + (requires registration). 
Place the extracted ``PH2_Dataset_images/`` + directory inside *root* and the dataset will use it automatically. + +**Data source** + +Teresa Mendonça, Pedro M. Ferreira, Jorge Marques, Andre R.S. Marcal, +Jorge Rozeira. "PH² - A dermoscopic image database for research and +benchmarking", 35th International Conference on Engineering in Medicine +and Biology Society, EMBC'13, pp. 5437-5440, IEEE, 2013. + +License: free for non-commercial research purposes. See the +`ADDI project page `_ for +full terms. + +**Quick start** + +.. code-block:: python + + from pyhealth.datasets import PH2Dataset + from pyhealth.tasks import PH2MelanomaClassification + + # Download mirror automatically + dataset = PH2Dataset(root="~/ph2", download=True) + + # Apply 3-class melanoma classification task + samples = dataset.set_task(PH2MelanomaClassification()) + +.. autoclass:: pyhealth.datasets.PH2Dataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..7d41d9884 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -225,6 +225,9 @@ Available Tasks Benchmark EHRShot ChestX-ray14 Binary Classification ChestX-ray14 Multilabel Classification + ISIC 2018 Skin Lesion Classification + ISIC 2018 Dermoscopic Artifacts Classification + PH2 Melanoma Classification Variant Classification (ClinVar) Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) diff --git a/docs/api/tasks/pyhealth.tasks.ISIC2018ArtifactsBinaryClassification.rst b/docs/api/tasks/pyhealth.tasks.ISIC2018ArtifactsBinaryClassification.rst new file mode 100644 index 000000000..bdca7d59d --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ISIC2018ArtifactsBinaryClassification.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.ISIC2018ArtifactsBinaryClassification +===================================================== + +.. 
autoclass:: pyhealth.tasks.ISIC2018ArtifactsBinaryClassification + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks/pyhealth.tasks.ISIC2018Classification.rst b/docs/api/tasks/pyhealth.tasks.ISIC2018Classification.rst new file mode 100644 index 000000000..1579439ea --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ISIC2018Classification.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.ISIC2018Classification +======================================= + +.. autoclass:: pyhealth.tasks.ISIC2018Classification + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks/pyhealth.tasks.PH2MelanomaClassification.rst b/docs/api/tasks/pyhealth.tasks.PH2MelanomaClassification.rst new file mode 100644 index 000000000..8ff7bbc91 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.PH2MelanomaClassification.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.PH2MelanomaClassification +========================================= + +.. autoclass:: pyhealth.tasks.PH2MelanomaClassification + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/isic2018_artifacts_train_lora_sd15.py b/examples/isic2018_artifacts_train_lora_sd15.py new file mode 100644 index 000000000..480b13b99 --- /dev/null +++ b/examples/isic2018_artifacts_train_lora_sd15.py @@ -0,0 +1,535 @@ +"""DreamBooth + LoRA fine-tuning on ISIC 2018 dermoscopic artifact images. + +Replicates the diffusion model training from: + Jin et al. "A Study of Artifacts on Melanoma Classification under + Diffusion-Based Perturbations", CHIL 2025. + +Method +------ +For each artifact type, fine-tune Stable Diffusion 1.5 with DreamBooth and +LoRA to associate a rare token with that artifact's visual appearance. The +trained LoRA adapters are later loaded into an SD inpainting pipeline to +augment PH2 images (see ph2_diffusion_sd.py). 
+ +Training setup (matches paper §3.3) +------------------------------------ + Base model : runwayml/stable-diffusion-v1-5 + LoRA targets : UNet attention projections + CLIP text-encoder projections + LoRA rank : 64 (paper: 64) + LoRA alpha : 32 (paper: 32) + Epochs : 4 (paper: 3–5) + LR : 1e-4 (paper: 1e-4) + Batch size : 2 (paper: 2) + Resolution : 512 (paper: 512) + Prior weight : 0.3 (paper: 0.3) + Prior images : 200 (paper: 200) + Seed : 0 (paper: 0) + Precision : fp16 + +Rare tokens (paper Table / §3.3) +--------------------------------- + patches → olis + dark_corner → lun + ruler → dits + ink → httr + gel_bubble → sown + +Instance prompt : "a dermoscopic image of {token} {class}" +Class prompt : "a dermoscopic image of {class}" + +Instance image selection +------------------------ +We sample --n_instance images per artifact from the Bissoto et al. (2020) +artifact annotations included in isic-artifact-metadata-pyhealth.csv. +Single-artifact images are preferred (cleanest signal); if fewer than +--n_instance such images exist the remaining slots are filled from +multi-artifact images that contain the target artifact. 
+ +Outputs +------- + ~/lora_checkpoints/{artifact}/unet/ — LoRA adapter (PEFT format) + ~/lora_checkpoints/{artifact}/text_encoder/ — LoRA adapter (PEFT format) + ~/lora_checkpoints/{artifact}/prior_images/ — 200 generated class images + +Usage +----- + # Train one artifact + pixi run -e base python examples/isic2018_artifacts_train_lora_sd15.py --artifact gel_bubble + + # Train all five artifacts sequentially + pixi run -e base python examples/isic2018_artifacts_train_lora_sd15.py --artifact all + + # Quick smoke test (3 instance images, 1 epoch, 10 prior images) + pixi run -e base python examples/isic2018_artifacts_train_lora_sd15.py --artifact ink --test + +Optional flags +-------------- + --artifact {hair,dark_corner,ruler,gel_bubble,ink,patches,all} + --n_instance Number of instance images per artifact (default 50) + --epochs Training epochs (default 4) + --lr Learning rate (default 1e-4) + --rank LoRA rank (default 64) + --alpha LoRA alpha (default 32) + --prior_weight Prior preservation loss weight (default 0.3) + --n_prior Prior images to generate (default 200) + --output_dir Checkpoint root (default ~/lora_checkpoints) + --data_csv Artifact annotation CSV + --image_dir ISIC image directory + --model Base SD model (default runwayml/stable-diffusion-v1-5) + --test Smoke-test mode: 3 images, 1 epoch, 10 prior images +""" + +import argparse +import csv +import math +import os +import random +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +RARE_TOKENS = { + "patches": "olis", + "dark_corner": "lun", + "ruler": "dits", + "ink": "httr", + "gel_bubble": "sown", + "hair": "helo", # not in paper; we define our own rare token +} + 
+ALL_ARTIFACTS = list(RARE_TOKENS.keys()) + +UNET_LORA_TARGETS = ["to_q", "to_k", "to_v", "to_out.0"] +TEXT_ENCODER_LORA_TARGETS = ["q_proj", "k_proj", "v_proj", "out_proj"] + + +# --------------------------------------------------------------------------- +# Data helpers +# --------------------------------------------------------------------------- + +def sample_instance_images(artifact: str, n: int, data_csv: str, image_dir: str, + seed: int = 0) -> list: + """Return up to n image paths for images labelled with `artifact`. + + Prefers single-artifact images (cleanest visual signal); fills remaining + slots with multi-artifact images that contain the target artifact. + """ + rng = random.Random(seed) + rows = [] + with open(data_csv) as f: + reader = csv.DictReader(f) + for row in reader: + rows.append(row) + + artifact_cols = ["dark_corner", "hair", "gel_bubble", "ruler", "ink", "patches"] + + single, multi = [], [] + for row in rows: + if int(row[artifact]) != 1: + continue + path = row.get("path") or os.path.join(image_dir, row["image"]) + if not os.path.exists(path): + path = os.path.join(image_dir, row["image"]) + if not os.path.exists(path): + continue + n_arts = sum(int(row[a]) for a in artifact_cols if a in row) + label = row.get("label_string", "benign") + if n_arts == 1: + single.append((path, label)) + else: + multi.append((path, label)) + + rng.shuffle(single) + rng.shuffle(multi) + combined = (single + multi)[:n] + print(f" {artifact}: {len(single)} single-artifact, {len(multi)} multi; " + f"selected {len(combined)} / {n} requested") + return combined # list of (path, label) + + +def generate_prior_images(pipe, class_prompts: list, out_dir: Path, seed: int = 0): + """Generate prior class images using the unmodified base model.""" + out_dir.mkdir(parents=True, exist_ok=True) + existing = list(out_dir.glob("*.jpg")) + if len(existing) >= len(class_prompts): + print(f" Prior images already exist ({len(existing)}), skipping generation.") + return [str(p) 
for p in existing[:len(class_prompts)]] + + print(f" Generating {len(class_prompts)} prior images…") + generator = torch.Generator("cuda").manual_seed(seed) + paths = [] + for i, prompt in enumerate(class_prompts): + out_path = out_dir / f"prior_{i:04d}.jpg" + if out_path.exists(): + paths.append(str(out_path)) + continue + result = pipe( + prompt, + num_inference_steps=30, + guidance_scale=7.5, + generator=generator, + height=512, + width=512, + ).images[0] + result.save(out_path, quality=95) + paths.append(str(out_path)) + if (i + 1) % 20 == 0: + print(f" {i+1}/{len(class_prompts)}") + return paths + + +# --------------------------------------------------------------------------- +# Dataset +# --------------------------------------------------------------------------- + +IMAGE_TRANSFORMS = transforms.Compose([ + transforms.Resize(512, interpolation=transforms.InterpolationMode.LANCZOS), + transforms.CenterCrop(512), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), +]) + + +class DreamBoothDataset(Dataset): + """Yields (instance_pixel, instance_prompt, prior_pixel, prior_prompt).""" + + def __init__(self, instance_items, instance_prompt_fn, + prior_paths, prior_prompt): + self.instance_items = instance_items # list of (path, label) + self.instance_prompt_fn = instance_prompt_fn # fn(label) -> str + self.prior_paths = prior_paths + self.prior_prompt = prior_prompt + + def __len__(self): + return max(len(self.instance_items), len(self.prior_paths)) + + def __getitem__(self, idx): + inst_path, label = self.instance_items[idx % len(self.instance_items)] + prior_path = self.prior_paths[idx % len(self.prior_paths)] + + inst_img = Image.open(inst_path).convert("RGB") + prior_img = Image.open(prior_path).convert("RGB") + + return { + "instance_pixel": IMAGE_TRANSFORMS(inst_img), + "instance_prompt": self.instance_prompt_fn(label), + "prior_pixel": IMAGE_TRANSFORMS(prior_img), + "prior_prompt": self.prior_prompt, + } + + +def collate_fn(batch): + return { + 
"instance_pixels": torch.stack([b["instance_pixel"] for b in batch]), + "instance_prompts": [b["instance_prompt"] for b in batch], + "prior_pixels": torch.stack([b["prior_pixel"] for b in batch]), + "prior_prompts": [b["prior_prompt"] for b in batch], + } + + +# --------------------------------------------------------------------------- +# LoRA helpers +# --------------------------------------------------------------------------- + +def add_lora(model, target_modules: list, rank: int, alpha: int): + from peft import LoraConfig, get_peft_model + config = LoraConfig( + r=rank, + lora_alpha=alpha, + target_modules=target_modules, + lora_dropout=0.0, + bias="none", + ) + return get_peft_model(model, config) + + +def encode_prompts(tokenizer, text_encoder, prompts: list, device): + tokens = tokenizer( + prompts, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + return text_encoder(tokens.input_ids.to(device))[0] + + +# --------------------------------------------------------------------------- +# Training +# --------------------------------------------------------------------------- + +def train_artifact(args, artifact: str): + from diffusers import ( + DDPMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, + ) + from transformers import CLIPTextModel, CLIPTokenizer + + device = torch.device("cuda") + token = RARE_TOKENS[artifact] + out_dir = Path(args.output_dir) / artifact + out_dir.mkdir(parents=True, exist_ok=True) + unet_dir = out_dir / "unet" + te_dir = out_dir / "text_encoder" + + if unet_dir.exists() and te_dir.exists(): + print(f"\n[{artifact}] LoRA already trained, skipping.") + return + + print(f"\n{'='*60}") + print(f"Training LoRA for artifact: {artifact} (token: '{token}')") + print(f"{'='*60}") + + # ----------------------------------------------------------------------- + # 1. 
Sample instance images + # ----------------------------------------------------------------------- + n_inst = 3 if args.test else args.n_instance + instance_items = sample_instance_images( + artifact, n_inst, args.data_csv, args.image_dir, seed=args.seed + ) + if not instance_items: + print(f" WARNING: no images found for {artifact}, skipping.") + return + + # ----------------------------------------------------------------------- + # 2. Build prompts + # ----------------------------------------------------------------------- + def instance_prompt(label: str) -> str: + return f"a dermoscopic image of {token} {label}" + + def class_prompt(label: str) -> str: + return f"a dermoscopic image of {label}" + + labels_used = list({lbl for _, lbl in instance_items}) + + # ----------------------------------------------------------------------- + # 3. Generate prior images using unmodified base pipeline + # ----------------------------------------------------------------------- + n_prior = 10 if args.test else args.n_prior + prior_prompts = [] + for i in range(n_prior): + lbl = "benign" if i < n_prior // 2 else "malignant" + prior_prompts.append(class_prompt(lbl)) + + pipe_txt2img = StableDiffusionPipeline.from_pretrained( + args.model, + torch_dtype=torch.float16, + safety_checker=None, + ).to(device) + pipe_txt2img.set_progress_bar_config(disable=True) + + prior_dir = out_dir / "prior_images" + prior_paths = generate_prior_images( + pipe_txt2img, prior_prompts, prior_dir, seed=args.seed + ) + del pipe_txt2img + torch.cuda.empty_cache() + + # ----------------------------------------------------------------------- + # 4. 
Load model components + # ----------------------------------------------------------------------- + print(" Loading model components…") + tokenizer = CLIPTokenizer.from_pretrained(args.model, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(args.model, subfolder="text_encoder") + unet = UNet2DConditionModel.from_pretrained(args.model, subfolder="unet") + scheduler = DDPMScheduler.from_pretrained(args.model, subfolder="scheduler") + + # Load VAE separately in fp16 (frozen, not trained) + from diffusers import AutoencoderKL + vae = AutoencoderKL.from_pretrained(args.model, subfolder="vae", + torch_dtype=torch.float16).to(device) + vae.requires_grad_(False) + + # ----------------------------------------------------------------------- + # 5. Add LoRA adapters + # ----------------------------------------------------------------------- + text_encoder = add_lora(text_encoder, TEXT_ENCODER_LORA_TARGETS, + args.rank, args.alpha).to(device) + unet = add_lora(unet, UNET_LORA_TARGETS, args.rank, args.alpha).to(device) + + # Cast trainable LoRA params to fp32, rest to fp16 + for name, param in unet.named_parameters(): + if param.requires_grad: + param.data = param.data.float() + else: + param.data = param.data.half() + for name, param in text_encoder.named_parameters(): + if param.requires_grad: + param.data = param.data.float() + else: + param.data = param.data.half() + + # ----------------------------------------------------------------------- + # 6. 
Dataset & dataloader + # ----------------------------------------------------------------------- + prior_label_list = ["benign" if i < n_prior // 2 else "malignant" + for i in range(len(prior_paths))] + # Use a single class prompt (mixed) for simplicity + dataset = DreamBoothDataset( + instance_items=instance_items, + instance_prompt_fn=instance_prompt, + prior_paths=prior_paths, + prior_prompt="a dermoscopic image of benign", + ) + loader = DataLoader(dataset, batch_size=args.batch_size, + shuffle=True, collate_fn=collate_fn, + num_workers=0, drop_last=True) + + # ----------------------------------------------------------------------- + # 7. Optimizer + # ----------------------------------------------------------------------- + trainable = ( + [p for p in unet.parameters() if p.requires_grad] + + [p for p in text_encoder.parameters() if p.requires_grad] + ) + optimizer = torch.optim.AdamW(trainable, lr=args.lr, + betas=(0.9, 0.999), weight_decay=1e-2) + + epochs = 1 if args.test else args.epochs + total_steps = epochs * math.ceil(len(dataset) / args.batch_size) + print(f" Instance: {len(instance_items)} Prior: {len(prior_paths)} " + f"Steps: {total_steps}") + + # ----------------------------------------------------------------------- + # 8. 
Training loop + # ----------------------------------------------------------------------- + scaler = torch.cuda.amp.GradScaler() + global_step = 0 + + for epoch in range(epochs): + unet.train() + text_encoder.train() + epoch_loss = 0.0 + + for batch in loader: + # Concatenate instance + prior along batch dim for a single fwd pass + pixels = torch.cat([ + batch["instance_pixels"].to(device, dtype=torch.float16), + batch["prior_pixels"].to(device, dtype=torch.float16), + ]) + prompts = batch["instance_prompts"] + batch["prior_prompts"] + + with torch.cuda.amp.autocast(): + # Encode images to latents + latents = vae.encode(pixels).latent_dist.sample() * 0.18215 + + # Sample noise and timesteps + noise = torch.randn_like(latents) + bsz = latents.shape[0] + timesteps = torch.randint( + 0, scheduler.config.num_train_timesteps, (bsz,), + device=device, dtype=torch.long + ) + noisy_latents = scheduler.add_noise(latents, noise, timesteps) + + # Text conditioning + enc = encode_prompts(tokenizer, text_encoder, prompts, device) + + # UNet prediction + pred = unet(noisy_latents, timesteps, enc).sample + + # Split instance vs prior halves + half = bsz // 2 + pred_inst, pred_prior = pred[:half], pred[half:] + noise_inst, noise_prior = noise[:half], noise[half:] + + if scheduler.config.prediction_type == "epsilon": + target_inst, target_prior = noise_inst, noise_prior + else: + target_inst = scheduler.get_velocity( + latents[:half], noise_inst, timesteps[:half]) + target_prior = scheduler.get_velocity( + latents[half:], noise_prior, timesteps[half:]) + + loss_inst = F.mse_loss(pred_inst.float(), target_inst.float()) + loss_prior = F.mse_loss(pred_prior.float(), target_prior.float()) + loss = loss_inst + args.prior_weight * loss_prior + + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(trainable, 1.0) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + epoch_loss += loss.item() + global_step += 1 + + avg = epoch_loss / 
max(1, len(loader)) + print(f" Epoch {epoch+1}/{epochs} loss={avg:.4f}") + + # ----------------------------------------------------------------------- + # 9. Save LoRA weights + # ----------------------------------------------------------------------- + unet.save_pretrained(str(unet_dir)) + text_encoder.save_pretrained(str(te_dir)) + print(f" Saved LoRA adapters → {out_dir}") + + del unet, text_encoder, vae + torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def parse_args(): + p = argparse.ArgumentParser( + description="DreamBooth + LoRA training for dermoscopic artifacts" + ) + p.add_argument("--artifact", default="all", + choices=ALL_ARTIFACTS + ["all"], + help="Artifact to train (default: all)") + p.add_argument("--n_instance", type=int, default=50, + help="Instance images per artifact (default 50)") + p.add_argument("--epochs", type=int, default=4) + p.add_argument("--lr", type=float, default=1e-4) + p.add_argument("--batch_size", type=int, default=2) + p.add_argument("--rank", type=int, default=64) + p.add_argument("--alpha", type=int, default=32) + p.add_argument("--prior_weight", type=float, default=0.3) + p.add_argument("--n_prior", type=int, default=200) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--output_dir", + default=os.path.expanduser("~/lora_checkpoints")) + p.add_argument("--data_csv", + default=os.path.expanduser( + "~/isic2018_data/isic-artifact-metadata-pyhealth.csv")) + p.add_argument("--image_dir", + default=os.path.expanduser( + "~/isic2018_data/ISIC2018_Task1-2_Training_Input")) + p.add_argument("--model", default="runwayml/stable-diffusion-v1-5") + p.add_argument("--test", action="store_true", + help="Smoke-test: 3 images, 1 epoch, 10 prior images") + return p.parse_args() + + +def main(): + args = parse_args() + torch.manual_seed(args.seed) + + artifacts = ALL_ARTIFACTS 
if args.artifact == "all" else [args.artifact] + print(f"Training LoRA for: {artifacts}") + print(f"Instance images: {3 if args.test else args.n_instance} " + f"Epochs: {1 if args.test else args.epochs} " + f"Prior: {10 if args.test else args.n_prior}") + + for artifact in artifacts: + train_artifact(args, artifact) + + print("\nAll done.") + + +if __name__ == "__main__": + main() diff --git a/examples/isic2018_train_resnet50.py b/examples/isic2018_train_resnet50.py new file mode 100644 index 000000000..a370eccb3 --- /dev/null +++ b/examples/isic2018_train_resnet50.py @@ -0,0 +1,468 @@ +""" +ISIC 2018 — Binary Melanoma Classification Under Artifact Modes +================================================================ + +Replicates image-mode experiments from: + "A Study of Artifacts on Melanoma Classification under Diffusion-Based + Perturbations" + +Supports 14 preprocessing modes via ``--mode``: + whole, lesion, background, bbox, bbox70, bbox90, + high_whole, high_lesion, high_background, + low_whole, low_lesion, low_background, + blur_bg, gray_whole + +Dataset +------- +Images / masks (~9 GB): https://challenge.isic-archive.com/data/#2018 +Artifact annotations: https://github.com/alceubissoto/debiasing-skin + +Expected layout under ``--root``:: + + / + isic_bias.csv (--annotations_csv) + ISIC2018_Task1-2_Training_Input/ (--image_dir) + ISIC2018_Task1_Training_GroundTruth/ (--mask_dir) + +Pass ``--download`` to fetch automatically. + +Usage +----- + # Whole-image mode (default) + python isic2018_train_resnet50.py --root /path/to/data + + # Specific mode + python isic2018_train_resnet50.py --root /path/to/data --mode lesion + + # Sigma ablation + python isic2018_train_resnet50.py --root /path/to/data --mode low_whole --sigma 2.0 + +Experimental Setup +------------------ +All runs use ResNet-50 with ImageNet pretrained weights. 
+ +Common parameters: + + Model : ResNet-50, ImageNet pretrained (weights="DEFAULT") + Optimizer : Adam, lr=1e-4, weight_decay=0.0 + Epochs : 10 + Batch size : 32 + CV : 5-fold KFold (shuffle=True, random_state=42) + Sigma : 1.0 [GaussianBlur for high_* / low_* modes] + Filter backend: scipy.ndimage.gaussian_filter (reference-faithful) + +Two validation strategies are supported via ``--val_strategy``: + + none (default) Train on full train_val split, evaluate at last epoch. + Matches reference methodology. Use for replication. + best Hold out 10% val per fold, load best checkpoint by val + AUROC. Use for ablation / model selection. + +All results are 5-fold CV on the ISIC 2018 *training* partition only +(no independent test set); metrics may overestimate generalization. + +Replication Results (10 epochs, val_strategy=none, sigma=1.0) +---------------------------------------------------------- +AUROC per fold (canonical mode order): + + Mode F1 F2 F3 F4 F5 Mean +-Std + ------------------------------------------------------------------ + whole 0.742 0.760 0.717 0.750 0.771 0.748 0.021 + lesion 0.708 0.687 0.732 0.722 0.805 0.731 0.045 + background 0.744 0.717 0.733 0.731 0.741 0.733 0.010 + bbox 0.764 0.655 0.673 0.720 0.661 0.695 0.046 + bbox70 0.707 0.624 0.611 0.620 0.634 0.639 0.039 + bbox90 0.653 0.599 0.563 0.632 0.612 0.612 0.034 + high_whole 0.650 0.670 0.680 0.602 0.639 0.648 0.031 + high_lesion 0.645 0.714 0.652 0.682 0.741 0.687 0.041 + high_background 0.723 0.681 0.655 0.684 0.685 0.686 0.024 + low_whole 0.710 0.779 0.761 0.726 0.782 0.751 0.032 + low_lesion 0.701 0.690 0.728 0.691 0.764 0.715 0.032 + low_background 0.749 0.637 0.716 0.755 0.718 0.715 0.047 + +Key observations: +- low_whole (0.751) matches whole (0.748); diff +0.003, p=0.83 (n.s.). +- low_whole vs high_whole: diff +0.103, p=0.002 (*) — the low/high-pass gap + is the only significant within-region contrast; low-frequency colour/texture + carries the signal, not fine-grained edges. 
+- whole vs bbox90: diff +0.136, p=0.001 (*) — aggressive context removal + significantly degrades performance. + +Sigma Ablation Results (low_whole, val_strategy=none, filter_backend=scipy) +--------------------------------------------------------------------------- +AUROC per fold across Gaussian blur sigma values: + + Sigma F1 F2 F3 F4 F5 Mean +-Std + ----------------------------------------------------------- + 0.5 0.761 0.768 0.723 0.730 0.738 0.744 0.020 + 1.0 0.710 0.779 0.761 0.726 0.782 0.751 0.032 + 2.0 0.753 0.705 0.720 0.765 0.775 0.744 0.030 + 4.0 0.690 0.736 0.730 0.753 0.746 0.731 0.025 + 8.0 0.742 0.703 0.652 0.696 0.771 0.713 0.046 + 16.0 0.783 0.689 0.716 0.704 0.706 0.720 0.037 + +Key observations: +- Performance peaks at sigma=1.0 (0.751) and degrades monotonically at higher + sigmas; sigma=8.0 shows the largest drop (0.713) and highest variance (±0.041), + suggesting aggressive smoothing removes diagnostically useful features. +- sigma=0.5 (0.744) is competitive but slightly below sigma=1.0. +- sigma=1.0 (low_whole, 0.751) vs whole (0.748): paired t-test diff=+0.003, + t=0.229, p=0.830 — not significant; low-pass at sigma=1.0 retains full + performance, confirming low-frequency features carry the diagnostic signal. + Note: resizing to 224×224 already acts as an implicit low-pass filter, which + may explain why an additional Gaussian at sigma=1.0 has negligible effect. +- sigma=1.0 vs sigma=16.0 (low_whole): paired t-test diff=+0.032, + t=1.095, p=0.335 — not significant (df=4); the trend of degradation at + high sigma is consistent but underpowered with 5 folds. + +Mode Ablation Results +--------------------- +gray_whole, blur_bg, whole_norm, whole_stratified: val_strategy=none. +whole_best, whole_best_stratified: val_strategy=best. +blur_bg_best_stratified: val_strategy=best, stratified (pending). 
+ +AUROC per fold: + + Mode F1 F2 F3 F4 F5 Mean +-Std val_strategy + -------------------------------------------------------------------------------------- + gray_whole 0.777 0.749 0.691 0.763 0.772 0.750 0.035 none + blur_bg 0.767 0.801 0.747 0.752 0.756 0.765 0.022 none + whole_norm 0.734 0.738 0.718 0.730 0.775 0.739 0.022 none + whole_stratified 0.772 0.799 0.745 0.737 0.778 0.766 0.025 none + whole_best 0.787 0.791 0.795 0.744 0.846 0.792 0.036 best + whole_best_stratified 0.824 0.814 0.738 0.710 0.807 0.779 0.051 best + +Key observations: +- whole_best (0.792) vs whole/none (0.748): diff=+0.044, t=2.884, p=0.045 (*) — + early stopping meaningfully improves generalisation. +- whole_best_stratified (0.779) vs whole/none (0.748): diff=+0.031, t=1.501, + p=0.208 (n.s., df=4) — combining best+stratified is not significantly better + than none alone; high variance across folds. +- whole_best_stratified (0.779) vs whole_best (0.792): diff=-0.013, t=-0.750, + p=0.495 (n.s., df=4) — stratified folding does not add benefit over best alone. +- whole_stratified (0.766) vs whole/none (0.748): diff=+0.018, t=1.931, + p=0.126 (n.s., df=4) — class balance is not a confound. +- gray_whole (0.750) vs whole/none (0.748): diff=+0.002, t=0.224, p=0.834 (n.s., df=4) — + colour loss does not degrade training performance; the model learns effectively + from grayscale images alone. +- blur_bg (0.765) vs whole/none (0.748): diff=+0.016, t=1.634, p=0.178 (n.s., df=4) — + background blurring does not significantly improve performance. +- whole_norm (0.739) vs whole/none (0.748): diff=-0.009, t=-1.754, p=0.154 (n.s., df=4) — + per-image normalisation does not significantly affect performance. + +Per-Artifact AUROC Analysis (whole/none baseline, 5-fold CV on training set) +---------------------------------------------------------------------------- +AUROC computed on held-out test folds of the ISIC 2018 training partition, +separately on images with vs. without each artifact type. 
+Paired t-test (df varies per artifact; folds with single-class subsets excluded). + + Artifact With Artifact Without Artifact Diff p + ───────────────────────────────────────────────────────────── + dark_corner 0.7543 0.7460 +0.008 0.807 + hair 0.7363 0.7567 -0.020 0.373 + gel_border 0.7658 0.7414 +0.024 0.295 + gel_bubble 0.7569 0.7461 +0.011 0.675 + ruler 0.7412 0.7593 -0.018 0.602 + ink 0.7256 0.7518 -0.026 0.518 + patches 0.9762 0.7370 +0.239 0.048 (*) + +Key observations: +- patches is a strong outlier: the model achieves near-perfect AUROC (0.976) on + images containing patches, vs 0.737 on images without (diff=+0.239, p=0.048). + This suggests the model exploits patch presence as a diagnostic shortcut rather + than learning true lesion features. Most folds have single-class patch subsets + (AUROC undefined), so the significant result is based on limited valid pairs. +- All other artifacts show negligible and non-significant effects (|diff| <= 0.026, + p >= 0.295), indicating the model is not strongly biased by common dermoscopic + artifacts such as hair, ruler marks, or gel borders. + +PH2 Classifier Training (whole mode, val_strategy=none) +------------------------------------------------------- +Script: examples/ph2_train_resnet50.py +Same KFold(n_splits=5, shuffle=True, random_state=42) as paper. +Binary: melanoma=1, common_nevus/atypical_nevus=0. 200 images. +Paper (Jin CHIL 2025) reports whole=0.975 ±0.012. + + Mode F1 F2 F3 F4 F5 Mean +-Std + ---------------------------------------------------------- + whole 0.926 0.939 0.961 0.996 0.980 0.961 0.029 + +Slight gap vs paper (0.961 vs 0.975) likely due to flat JPEG input +vs paper's original BMP format; fold assignments match (same KFold seed). + +Diffusion-Augmented Evaluation — Partial Table 4 Replication +------------------------------------------------------------- +Script: examples/ph2_artifacts_test_resnet50.py +Classifier sources: ISIC whole_sigma1.0_none and PH2 whole (above). 
+Augmented PH2 images: ~/ph2_augmented/{clean,dark_corner,ruler}/ + dark_corner — programmatic vignette (deterministic) + ruler — SD-inpainted ruler marks (base runwayml/stable-diffusion-inpainting) +Note: paper uses DreamBooth+LoRA fine-tuned models per artifact type; our +augmentation is a simplified approximation. Mode: whole only (no PH2 masks). + +ISIC classifiers → PH2: + + Condition F1 F2 F3 F4 F5 Mean +-Std + -------------------------------------------------------------- + clean 0.878 0.941 0.917 0.898 0.868 0.900 0.029 + dark_corner 0.702 0.920 0.583 0.643 0.561 0.682 0.144 t=-3.885 p=0.018 (*) + ruler 0.847 0.953 0.929 0.906 0.902 0.907 0.039 t=+0.664 p=0.543 (n.s.) + +PH2 classifiers → PH2: + + Condition F1 F2 F3 F4 F5 Mean +-Std + -------------------------------------------------------------- + clean 0.999 0.997 0.996 1.000 0.999 0.998 0.002 + dark_corner 0.994 0.990 0.990 0.994 0.998 0.993 0.004 t=-4.767 p=0.009 (**) + ruler 0.978 0.994 0.995 0.992 0.999 0.992 0.008 t=-1.712 p=0.162 (n.s.) + +Key observations (cf. paper Table 4, whole mode): +- PH2-trained classifiers are highly robust to both artifacts (0.998→0.993/0.992), + matching the paper's finding (0.975→0.978/0.966). +- ISIC dark_corner causes a large drop (0.900→0.682, p=0.018). Paper shows a + smaller drop (0.858→0.816) because their DreamBooth-generated dark corners are + more subtle than our programmatic vignette. +- ISIC ruler is unaffected (0.907, p=0.543), consistent with paper results. +- ISIC classifiers are more vulnerable to artifact perturbation than PH2-trained + classifiers, replicating the paper's core finding. 
+ +""" + +import argparse +import logging +import os +import sys + +import numpy as np +from sklearn.model_selection import KFold, StratifiedKFold + +from pyhealth.datasets import ISIC2018ArtifactsDataset, get_dataloader +from pyhealth.models import TorchvisionModel +from pyhealth.processors import DermoscopicImageProcessor +from pyhealth.processors.dermoscopic_image_processor import VALID_MODES +from pyhealth.tasks import ISIC2018ArtifactsBinaryClassification +from pyhealth.trainer import Trainer + +parser = argparse.ArgumentParser( + description="Train ISIC2018 artifact classifier") +parser.add_argument( + "--root", + type=str, + required=True, + help="Root directory containing the annotation CSV, images, and masks.", +) +parser.add_argument( + "--image_dir", + type=str, + default="ISIC2018_Task1-2_Training_Input", + help="Sub-directory (relative to root, or absolute path) for ISIC images.", +) +parser.add_argument( + "--mask_dir", + type=str, + default="ISIC2018_Task1_Training_GroundTruth", + help="Sub-directory (relative to root, or absolute path) for segmentation masks.", +) +parser.add_argument( + "--annotations_csv", + type=str, + default="isic_bias.csv", + help="Annotation CSV filename (relative to root, or absolute path).", +) +parser.add_argument( + "--mode", + type=str, + default="whole", + choices=VALID_MODES, + help="Image preprocessing mode.", +) +parser.add_argument( + "--model", + type=str, + default="resnet50", + help="Torchvision model backbone (e.g. 
resnet50, vit_b_16).", +) +parser.add_argument("--epochs", type=int, default=10) +parser.add_argument("--batch_size", type=int, default=32) +parser.add_argument("--lr", type=float, default=1e-4) +parser.add_argument("--n_splits", type=int, default=5) +parser.add_argument("--seed", type=int, default=42) +parser.add_argument( + "--stratified", + action="store_true", + help="Use StratifiedKFold instead of KFold to preserve class balance per fold.") +parser.add_argument( + "--sigma", + type=float, + default=1.0, + help="Gaussian sigma for high_* / low_* filter modes (default: 1.0).") +parser.add_argument( + "--high_grayscale", + action=argparse.BooleanOptionalAction, + default=True, + help="If True (default), apply high-pass filter in grayscale then stack to 3 channels " + "(matches reference grayscale=True). " + "Use --no-high-grayscale to apply HPF per RGB channel instead.") +parser.add_argument( + "--val_strategy", + type=str, + default="none", + choices=["none", "best"], + help="Validation strategy. " + "'none' (default): train on full train_val split, no val holdout, " + "evaluate last epoch — matches the reference implementation. " + "'best': hold out 10%% of train_val as validation and load the " + "best-scoring checkpoint at the end (ablation).") +parser.add_argument( + "--resume", + action="store_true", + help="Skip folds whose checkpoint already exists in the output directory.") +parser.add_argument( + "--cache_only", + action="store_true", + help="Build the litdata sample cache and exit without training. " + "Use this to pre-warm caches in parallel before training runs.") +parser.add_argument( + "--download", + action="store_true", + help="Auto-download data.") +parser.add_argument( + "--num_workers", + type=int, + default=8, + help="Number of worker processes for building the litdata cache (default: 8).") +args = parser.parse_args() + +# Route PyHealth trainer logs to stdout so per-epoch metrics are visible. 
+_handler = logging.StreamHandler() +_handler.setFormatter(logging.Formatter("%(message)s")) +logging.getLogger("pyhealth.trainer").addHandler(_handler) +logging.getLogger("pyhealth.trainer").setLevel(logging.INFO) + + + +if __name__ == "__main__": + # ------------------------------------------------------------------ + # 1. Build dataset — all path resolution delegated to the loader + # ------------------------------------------------------------------ + dataset = ISIC2018ArtifactsDataset( + root=args.root, + annotations_csv=args.annotations_csv, + image_dir=args.image_dir, + mask_dir=args.mask_dir, + download=args.download, + num_workers=args.num_workers, + ) + dataset.stats() + + # ------------------------------------------------------------------ + # 2. Apply task → SampleDataset with binary labels + # ------------------------------------------------------------------ + processor = DermoscopicImageProcessor( + mode=args.mode, + sigma=args.sigma, + mask_dir=dataset.mask_dir, + high_grayscale=args.high_grayscale, + ) + task = ISIC2018ArtifactsBinaryClassification() + samples = dataset.set_task(task, input_processors={"image": processor}) + + if args.cache_only: + print(f"Cache built for mode={args.mode} sigma={args.sigma}. Exiting (--cache_only).") + samples.close() + sys.exit(0) + + # ------------------------------------------------------------------ + # 3. 
Generate K-fold splits (stratified when --stratified) from sample labels
num_workers=args.num_workers, + ) + + # -------------------------------------------------------------- + # 4. Fresh model per fold + # -------------------------------------------------------------- + model = TorchvisionModel( + dataset=samples, + model_name=args.model, + model_config={"weights": "DEFAULT"}, + ) + + # -------------------------------------------------------------- + # 5. Train + # -------------------------------------------------------------- + trainer = Trainer( + model=model, + metrics=["accuracy", "roc_auc"], + ) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=args.epochs, + optimizer_params={"lr": args.lr}, + load_best_model_at_last=(args.val_strategy == "best"), + ) + + # -------------------------------------------------------------- + # 6. Evaluate + # -------------------------------------------------------------- + scores = trainer.evaluate(test_loader) + print(f"Mode: {args.mode} Split {fold} test results: {scores}") + + # -------------------------------------------------------------- + # 7. Save checkpoint + # -------------------------------------------------------------- + trainer.save_ckpt(ckpt_path) + print(f"Checkpoint saved → {ckpt_path}") + + samples.close() diff --git a/examples/ph2_artifacts_test_resnet50.py b/examples/ph2_artifacts_test_resnet50.py new file mode 100644 index 000000000..8f1f670a5 --- /dev/null +++ b/examples/ph2_artifacts_test_resnet50.py @@ -0,0 +1,242 @@ +""" +PH2 Artifacts Evaluation — ISIC and PH2 ResNet-50 classifiers on clean + perturbed PH2. +======================================================================================== + +Replication of Table 4 from Jin et al. (CHIL 2025): + Evaluates classifiers trained on PH2 or ISIC on DreamBooth-augmented PH2 images + where a specific artifact is added to each image. + +Paper artifacts: dark_corner, gel_bubble, ink, patches, ruler (hair not in paper). 
+ +Checkpoints +----------- + ISIC : ~/isic2018_data/checkpoints/whole_sigma1.0_none/fold{1..5}.pt (1-indexed) + PH2 : ~/ph2_checkpoints/whole/resnet50_fold{0..4}.pt (0-indexed) + +Augmented images +---------------- + ~/ph2_augmented/{artifact}/{image_id}.jpg + Produced by ph2_artifacts_augment_sd.py + +PH2 labels +---------- + melanoma → 1 ; common_nevus / atypical_nevus → 0 + +Results (Table 4 replication — whole mode, mean AUROC ± std over 5 folds) +-------------------------------------------------------------------------- + Artifact PH2-trained Paper† ISIC-trained Paper† + ───────────────────────────────────────────────────────────────── + original 0.998±0.002 0.975 0.900±0.029 0.858 + dark_corner 0.992±0.004 0.978 0.847±0.083 0.816 + gel_bubble 0.996±0.003 0.973 0.892±0.045 0.841 + ink 0.994±0.007 0.959 0.905±0.029 0.788 + patches 0.995±0.005 0.976 0.909±0.037 0.848 + ruler 0.992±0.010 0.966 0.904±0.040 0.752 + hair (ours) 0.972±0.007 — 0.866±0.041 — + +Usage +----- + pixi run -e base python examples/ph2_artifacts_test_resnet50.py + pixi run -e base python examples/ph2_artifacts_test_resnet50.py --source isic + pixi run -e base python examples/ph2_artifacts_test_resnet50.py --source ph2 + pixi run -e base python examples/ph2_artifacts_test_resnet50.py \\ + --artifacts clean dark_corner ruler +""" + +import argparse +import csv +import os +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image +from sklearn.metrics import roc_auc_score +from torch.utils.data import DataLoader, Dataset +from torchvision import models, transforms + +# --------------------------------------------------------------------------- +# Paths +# --------------------------------------------------------------------------- + +AUG_DIR = Path(os.path.expanduser("~/ph2_augmented")) +META_PATH = AUG_DIR / "augmented_metadata.csv" + +ISIC_CKPT_DIR = Path(os.path.expanduser("~/isic2018_data/checkpoints/whole_sigma1.0_none")) +PH2_CKPT_DIR = 
Path(os.path.expanduser("~/ph2_checkpoints/whole")) + +TRANSFORM = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), +]) + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +ARTIFACTS = ["clean", "dark_corner", "gel_bubble", "ink", "patches", "ruler", "hair"] + + +# --------------------------------------------------------------------------- +# Dataset +# --------------------------------------------------------------------------- + +class PH2AugDataset(Dataset): + def __init__(self, records, transform=None): + self.records = records + self.transform = transform + + def __len__(self): + return len(self.records) + + def __getitem__(self, idx): + path, label = self.records[idx] + img = Image.open(path).convert("RGB") + if self.transform: + img = self.transform(img) + return img, label + + +# --------------------------------------------------------------------------- +# Model helpers +# --------------------------------------------------------------------------- + +def build_resnet50(): + model = models.resnet50(weights=None) + model.fc = nn.Linear(model.fc.in_features, 1) + return model.to(DEVICE) + + +def load_isic_checkpoint(fold_k: int) -> nn.Module: + """Load ISIC whole_none checkpoint (1-indexed: fold1..fold5). + + PyHealth saves with ``model.`` prefix and ``_dummy_param`` key; strip both. 
+ """ + path = ISIC_CKPT_DIR / f"fold{fold_k + 1}.pt" + raw = torch.load(path, map_location=DEVICE) + state = {k.removeprefix("model."): v for k, v in raw.items() if k != "_dummy_param"} + model = build_resnet50() + model.load_state_dict(state) + model.eval() + return model + + +def load_ph2_checkpoint(fold_k: int) -> nn.Module: + """Load PH2 whole checkpoint (0-indexed: fold0..fold4).""" + path = PH2_CKPT_DIR / f"resnet50_fold{fold_k}.pt" + model = build_resnet50() + model.load_state_dict(torch.load(path, map_location=DEVICE)) + model.eval() + return model + + +# --------------------------------------------------------------------------- +# Evaluation +# --------------------------------------------------------------------------- + +@torch.no_grad() +def compute_auroc(model: nn.Module, records) -> float: + ds = PH2AugDataset(records, TRANSFORM) + loader = DataLoader(ds, batch_size=32, shuffle=False, num_workers=4) + probs, labels = [], [] + for imgs, lbls in loader: + p = torch.sigmoid(model(imgs.to(DEVICE))).squeeze(1).cpu().numpy() + probs.extend(p) + labels.extend(lbls.numpy()) + return roc_auc_score(labels, probs) + + +def evaluate_source(source: str, records_by_artifact: dict): + """Evaluate all 5 folds of `source` (isic|ph2) across artifacts.""" + print(f"\n{'='*60}") + print(f"Source: {source.upper()} classifiers → PH2 (whole mode)") + print(f"{'='*60}") + + loader_fn = load_isic_checkpoint if source == "isic" else load_ph2_checkpoint + + results = {} + for artifact, recs in records_by_artifact.items(): + fold_aurocs = [] + for k in range(5): + model = loader_fn(k) + auroc = compute_auroc(model, recs) + fold_aurocs.append(auroc) + del model + torch.cuda.empty_cache() + mean = np.mean(fold_aurocs) + std = np.std(fold_aurocs, ddof=1) + results[artifact] = (mean, std, fold_aurocs) + print(f" {artifact:15s} AUROC={mean:.3f} ±{std:.3f} " + f"folds={[f'{a:.3f}' for a in fold_aurocs]}") + return results + + +# 
--------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def parse_args(): + p = argparse.ArgumentParser( + description="Evaluate ISIC/PH2 classifiers on clean + artifact-augmented PH2 (Table 4)") + p.add_argument( + "--source", choices=["isic", "ph2", "both"], default="both", + help="Which trained classifiers to evaluate", + ) + p.add_argument( + "--artifacts", nargs="+", default=ARTIFACTS, + help="Artifact conditions to evaluate", + ) + p.add_argument( + "--aug_dir", type=str, default=str(AUG_DIR), + help="Directory containing augmented images and augmented_metadata.csv", + ) + return p.parse_args() + + +def main(): + args = parse_args() + meta_path = Path(args.aug_dir) / "augmented_metadata.csv" + + # Load metadata grouped by artifact + records_by_artifact: dict[str, list] = {} + with open(meta_path) as f: + for row in csv.DictReader(f): + artifact = row["artifact"] + label = 1 if row["diagnosis"] == "melanoma" else 0 + records_by_artifact.setdefault(artifact, []).append((row["path"], label)) + + artifact_subset = {a: records_by_artifact[a] for a in args.artifacts if a in records_by_artifact} + + all_results = {} + if args.source in ("isic", "both"): + if not any(ISIC_CKPT_DIR.glob("fold*.pt")): + print(f"[WARN] No ISIC checkpoints found in {ISIC_CKPT_DIR}") + else: + all_results["isic"] = evaluate_source("isic", artifact_subset) + + if args.source in ("ph2", "both"): + if not any(PH2_CKPT_DIR.glob("resnet50_fold*.pt")): + print(f"[WARN] No PH2 checkpoints found in {PH2_CKPT_DIR}") + print(" Run ph2_train_resnet50.py first.") + else: + all_results["ph2"] = evaluate_source("ph2", artifact_subset) + + if all_results: + print(f"\n{'='*60}") + print("Summary (mean AUROC)") + print(f"{'Artifact':15s} " + " ".join(f"{s:>12}" for s in all_results)) + print("-" * 60) + for artifact in args.artifacts: + row_vals = [] + for src, res in all_results.items(): + if 
artifact in res: + m, s, _ = res[artifact] + row_vals.append(f"{m:.3f}±{s:.3f}") + else: + row_vals.append(" — ") + print(f"{artifact:15s} " + " ".join(f"{v:>12}" for v in row_vals)) + + +if __name__ == "__main__": + main() + diff --git a/examples/ph2_diffusion_sd.py b/examples/ph2_diffusion_sd.py new file mode 100644 index 000000000..7b8e621f9 --- /dev/null +++ b/examples/ph2_diffusion_sd.py @@ -0,0 +1,518 @@ +"""PH2 Dermoscopic Artifact Augmentation via Stable Diffusion Inpainting. + +Adds synthetic artifacts to the 200-image PH2 dataset using a two-step pipeline: + + 1. Mask generation — each artifact function returns the original image + unchanged plus a binary mask that defines *where* the artifact should + appear (hair paths, edge strips, bubble circles, corner gradients, etc.) + 2. SD inpainting — the model generates the artifact appearance from scratch + inside the mask, guided purely by a text prompt. No programmatic drawing + is blended into the image. + +Supported artifact types +------------------------ + hair — Bezier-path mask, SD generates hair strands (paper: excluded) + dark_corner — Radial peripheral mask, SD generates dark vignette (paper: ✓) + ruler — Edge-strip mask, SD generates ruler tick marks (paper: ✓) + gel_bubble — Circular masks, SD generates gel bubble discs (paper: ✓) + ink — Small ellipse masks, SD generates ink dot marks (paper: ✓) + patches — Edge-strip mask, SD generates colour-checker swatches (paper: ✓) + +Outputs +------- + ~/ph2_augmented/ + clean/ — resized originals (512×512) with no artifact + hair/ — hair artifact variants + dark_corner/ — dark corner vignette variants + ruler/ — ruler mark variants + gel_bubble/ — gel bubble variants + ink/ — ink marking variants + patches/ — colour calibration patch variants + augmented_metadata.csv — columns: image_id, path, diagnosis, artifact + +Usage +----- + # All artifact types (GPU required — all types use SD inpainting) + pixi run -e base python examples/ph2_diffusion_sd.py \\ + 
--src ~/ph2/PH2-dataset-master \\ + --out ~/ph2_augmented + + # Specific artifacts + pixi run -e base python examples/ph2_diffusion_sd.py \\ + --artifacts gel_bubble ink dark_corner + + [--model runwayml/stable-diffusion-inpainting] + [--n_aug 1] # augmented copies per image per artifact type + [--test 3] # only process first N images (smoke test) +""" + +import argparse +import csv +import os +import random +import sys +from pathlib import Path + +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFilter + +# --------------------------------------------------------------------------- +# Artifact generators +# --------------------------------------------------------------------------- + +def _bezier_point(p0, p1, p2, t): + """Quadratic Bezier interpolation.""" + return ( + (1 - t) ** 2 * p0[0] + 2 * (1 - t) * t * p1[0] + t ** 2 * p2[0], + (1 - t) ** 2 * p0[1] + 2 * (1 - t) * t * p1[1] + t ** 2 * p2[1], + ) + + +def make_hair_overlay(img: Image.Image, n_strands: int = 12, seed: int = 0): + """Draw dark Bezier hair strands as seeds; mask those paths for SD refinement.""" + rng = random.Random(seed) + w, h = img.size + overlay = img.copy() + draw = ImageDraw.Draw(overlay) + mask = Image.new("L", (w, h), 0) + mdraw = ImageDraw.Draw(mask) + + for _ in range(n_strands): + x0, y0 = rng.randint(0, w), rng.randint(0, h) + cx, cy = rng.randint(0, w), rng.randint(0, h) + x1, y1 = rng.randint(0, w), rng.randint(0, h) + thickness = rng.randint(2, 4) + darkness = rng.randint(5, 25) + hair_col = (darkness, darkness, darkness) + pts = [_bezier_point((x0, y0), (cx, cy), (x1, y1), t / 60) for t in range(61)] + pts_int = [(int(p[0]), int(p[1])) for p in pts] + draw.line(pts_int, fill=hair_col, width=thickness) + mdraw.line(pts_int, fill=255, width=thickness + 8) + + mask = mask.filter(ImageFilter.MaxFilter(5)) + return overlay, mask + + +def make_dark_corner_overlay(img: Image.Image): + """Apply a mild dark vignette; mask the affected region for SD 
refinement.""" + w, h = img.size + arr = np.array(img).astype(np.float32) + cx, cy = w / 2, h / 2 + y_idx, x_idx = np.ogrid[:h, :w] + dist = np.sqrt(((x_idx - cx) / (w / 2)) ** 2 + ((y_idx - cy) / (h / 2)) ** 2) + # Gentler onset at 0.80, softer falloff + vignette = np.clip(1.0 - np.maximum(0, dist - 0.80) * 1.5, 0, 1) + arr *= vignette[:, :, None] + overlay = Image.fromarray(arr.astype(np.uint8)) + mask = Image.fromarray((255 * (1 - vignette)).astype(np.uint8)) + return overlay, mask + + +def make_ruler_overlay(img: Image.Image, seed: int = 0): + """Draw a semi-transparent ruler strip offset from one edge; mask for SD refinement.""" + rng = random.Random(seed) + w, h = img.size + overlay = img.copy() + mask = Image.new("L", (w, h), 0) + mdraw = ImageDraw.Draw(mask) + + edge = rng.choice(["top", "bottom", "left", "right"]) + n_ticks = rng.randint(10, 20) + tick_len_major = rng.randint(18, 28) + tick_len_minor = tick_len_major // 2 + strip_w = tick_len_major + 16 + offset = 10 + tick_color = (30, 30, 30) + + # Semi-transparent background strip (alpha=120 ≈ 47% opaque) + strip_layer = Image.new("RGBA", (w, h), (0, 0, 0, 0)) + sdraw = ImageDraw.Draw(strip_layer) + + if edge in ("top", "bottom"): + y_bg = offset if edge == "top" else h - offset - strip_w + sdraw.rectangle([0, y_bg, w, y_bg + strip_w], fill=(240, 240, 240, 120)) + base = overlay.convert("RGBA") + base.alpha_composite(strip_layer) + overlay = base.convert("RGB") + draw = ImageDraw.Draw(overlay) + for i in range(n_ticks): + x = int(w * (i + 1) / (n_ticks + 1)) + ln = tick_len_major if i % 5 == 0 else tick_len_minor + draw.line([(x, y_bg + 4), (x, y_bg + 4 + ln)], fill=tick_color, width=2) + mdraw.rectangle([0, y_bg, w, y_bg + strip_w], fill=255) + else: + x_bg = offset if edge == "left" else w - offset - strip_w + sdraw.rectangle([x_bg, 0, x_bg + strip_w, h], fill=(240, 240, 240, 120)) + base = overlay.convert("RGBA") + base.alpha_composite(strip_layer) + overlay = base.convert("RGB") + draw = 
ImageDraw.Draw(overlay) + for i in range(n_ticks): + y = int(h * (i + 1) / (n_ticks + 1)) + ln = tick_len_major if i % 5 == 0 else tick_len_minor + draw.line([(x_bg + 4, y), (x_bg + 4 + ln, y)], fill=tick_color, width=2) + mdraw.rectangle([x_bg, 0, x_bg + strip_w, h], fill=255) + + return overlay, mask + + +def make_gel_bubble_overlay(img: Image.Image, n_bubbles: int = 5, seed: int = 0): + """Draw semi-transparent bubble fills as seeds; mask those regions for SD refinement.""" + rng = random.Random(seed) + w, h = img.size + max_r = max(5, min(w, h) // 12) # cap radius to ~1/12 of shortest side + overlay = img.copy() + mask = Image.new("L", (w, h), 0) + mdraw = ImageDraw.Draw(mask) + + bubble_layer = Image.new("RGBA", (w, h), (0, 0, 0, 0)) + bdraw = ImageDraw.Draw(bubble_layer) + + for _ in range(n_bubbles): + r = rng.randint(max(5, max_r // 2), max_r) + cx = rng.randint(r, max(r + 1, w - r)) + cy = rng.randint(r, max(r + 1, h - r)) + bdraw.ellipse([cx - r, cy - r, cx + r, cy + r], + fill=(210, 220, 235, 100), outline=(100, 120, 160, 255), width=2) + mdraw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=255) + + base = overlay.convert("RGBA") + base.alpha_composite(bubble_layer) + overlay = base.convert("RGB") + + mask = mask.filter(ImageFilter.MaxFilter(3)) + return overlay, mask + + +def make_ink_overlay(img: Image.Image, n_marks: int = 5, seed: int = 0): + """Draw ink marks biased toward lines; mask those spots for SD refinement.""" + import math + rng = random.Random(seed) + w, h = img.size + overlay = img.copy() + draw = ImageDraw.Draw(overlay) + mask = Image.new("L", (w, h), 0) + mdraw = ImageDraw.Draw(mask) + + for _ in range(n_marks): + cx = rng.randint(int(0.1 * w), int(0.9 * w)) + cy = rng.randint(int(0.1 * h), int(0.9 * h)) + # Bias: 60% line, 25% cross, 15% dot + roll = rng.random() + style = "line" if roll < 0.60 else ("cross" if roll < 0.85 else "dot") + ink_col = (rng.randint(0, 20), rng.randint(0, 20), rng.randint(80, 140)) + r = rng.randint(16, 36) + 
if style == "dot": + draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=ink_col) + mdraw.ellipse([cx - r - 8, cy - r - 8, cx + r + 8, cy + r + 8], fill=255) + elif style == "cross": + draw.line([(cx - r, cy), (cx + r, cy)], fill=ink_col, width=3) + draw.line([(cx, cy - r), (cx, cy + r)], fill=ink_col, width=3) + mdraw.ellipse([cx - r - 8, cy - r - 8, cx + r + 8, cy + r + 8], fill=255) + else: + angle = rng.uniform(0, math.pi) + x1, y1 = int(cx + r * math.cos(angle)), int(cy + r * math.sin(angle)) + x2, y2 = int(cx - r * math.cos(angle)), int(cy - r * math.sin(angle)) + draw.line([(x1, y1), (x2, y2)], fill=ink_col, width=4) + mdraw.ellipse([cx - r - 8, cy - r - 8, cx + r + 8, cy + r + 8], fill=255) + + mask = mask.filter(ImageFilter.MaxFilter(3)) + return overlay, mask + + +def make_patches_overlay(img: Image.Image, seed: int = 0): + """Draw round colour-calibration stickers half-cut-off at a corner edge; mask for SD.""" + rng = random.Random(seed) + w, h = img.size + overlay = img.copy() + draw = ImageDraw.Draw(overlay) + mask = Image.new("L", (w, h), 0) + mdraw = ImageDraw.Draw(mask) + + colours = [ + (210, 180, 170), (170, 200, 180), (180, 180, 210), + (200, 195, 170), (190, 175, 200), (170, 205, 210), + ] + rng.shuffle(colours) + r = max(10, min(w, h) // 8) + gap = max(4, r // 4) + n_patches = 1 + + # Always start from bottom-left corner, march right + # cy = h places centre on edge so upper half is visible + cy = h + cx = r # first circle starts at x=r from left + for i in range(n_patches): + if cx - r >= w: + break + col = colours[i % len(colours)] + draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=col, + outline=(20, 20, 20), width=1) + mdraw.ellipse([cx - r - 4, cy - r - 4, cx + r + 4, cy + r + 4], fill=255) + cx += 2 * r + gap + + return overlay, mask + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + +# Rare tokens used during 
+# Inpainting strength per artifact (paper §3.3; hair not in paper — ours uses 0.55)
pipe.set_progress_bar_config(disable=True) + return pipe + + +def load_lora_pipeline(model_id: str, artifact: str, lora_dir: str): + """Load a fresh SD inpainting pipeline with LoRA adapters applied. + + Loads a new pipeline instance per artifact to avoid PEFT adapter + accumulation warnings when swapping adapters on the same model. + Returns (pipe, use_lora). + """ + from diffusers import StableDiffusionInpaintPipeline + from peft import PeftModel + + pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_id, + torch_dtype=torch.float16, + safety_checker=None, + ).to("cuda") + pipe.set_progress_bar_config(disable=True) + + unet_path = os.path.join(lora_dir, artifact, "unet") + te_path = os.path.join(lora_dir, artifact, "text_encoder") + + if os.path.isdir(unet_path): + print(f" Loading LoRA for {artifact} from {lora_dir}/{artifact}/") + pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_path).to("cuda", dtype=torch.float16) + pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, te_path).to("cuda", dtype=torch.float16) + return pipe, True + else: + print(f" [no LoRA] {artifact}: using base model + fallback prompt") + return pipe, False + + +def inpaint(pipe, image: Image.Image, mask: Image.Image, + prompt: str, strength: float) -> Image.Image: + """Run SD inpainting at 512×512, resize result back to original size.""" + orig_size = image.size + img_512 = image.resize((512, 512), Image.LANCZOS) + mask_512 = mask.resize((512, 512), Image.NEAREST) + + result = pipe( + prompt=prompt, + image=img_512, + mask_image=mask_512, + num_inference_steps=30, + guidance_scale=10.0, + strength=strength, + ).images[0] + + return result.resize(orig_size, Image.LANCZOS) + + +def augment_image(pipe, img: Image.Image, artifact: str, aug_idx: int, + diagnosis: str = "benign", use_lora: bool = False) -> Image.Image: + """Generate one augmented copy of img with the given artifact type. 
+ + When use_lora=True the DreamBooth rare-token prompt and paper-matched + guidance/strength values are used; otherwise falls back to the + descriptive prompt with the same strength. + """ + seed = aug_idx * 1000 + hash(artifact) % 1000 + + if artifact == "hair": + img_in, mask = make_hair_overlay(img, n_strands=12, seed=seed) + elif artifact == "ruler": + img_in, mask = make_ruler_overlay(img, seed=seed) + elif artifact == "gel_bubble": + img_in, mask = make_gel_bubble_overlay(img, n_bubbles=5, seed=seed) + elif artifact == "ink": + img_in, mask = make_ink_overlay(img, n_marks=4, seed=seed) + elif artifact == "dark_corner": + img_in, mask = make_dark_corner_overlay(img) + elif artifact == "patches": + img_in, mask = make_patches_overlay(img, seed=seed) + else: + raise ValueError(f"Unknown artifact: {artifact}") + + prompt = lora_prompt(artifact, diagnosis) if use_lora else FALLBACK_PROMPTS[artifact] + strength = ARTIFACT_STRENGTH[artifact] + return inpaint(pipe, img_in, mask, prompt, strength) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def parse_args(): + p = argparse.ArgumentParser(description="PH2 artifact augmentation via SD inpainting") + p.add_argument("--src", default=os.path.expanduser("~/ph2/PH2-dataset-master")) + p.add_argument("--out", default=os.path.expanduser("~/ph2_augmented")) + p.add_argument("--model", default="runwayml/stable-diffusion-inpainting") + p.add_argument("--lora_dir", default=os.path.expanduser("~/lora_checkpoints"), + help="Root directory of trained LoRA adapters (default ~/lora_checkpoints)") + p.add_argument("--n_aug", type=int, default=1, help="Augmented copies per artifact type") + p.add_argument("--test", type=int, default=0, help="Only process first N images (0=all)") + p.add_argument("--artifacts", nargs="+", + default=["hair", "dark_corner", "ruler", "gel_bubble", "ink", "patches"]) + return 
p.parse_args()


def main():
    """Generate clean + artifact-augmented copies of the PH2 images."""
    args = parse_args()
    src = Path(args.src)
    out = Path(args.out)

    # Load metadata CSV (prefer the mirror CSV, fall back to PyHealth's)
    meta_path = src / "PH2_simple_dataset.csv"
    if not meta_path.exists():
        meta_path = src / "ph2_metadata_pyhealth.csv"
    rows = []
    with open(meta_path) as f:
        reader = csv.DictReader(f)
        for row in reader:
            rows.append(row)

    # NOTE(review): rows[0] assumes the metadata CSV is non-empty — an empty
    # file raises IndexError here; confirm whether that is acceptable.
    id_col = "image_name" if "image_name" in rows[0] else "image_id"
    diag_col = "diagnosis"
    images_dir = src / "images"

    if args.test:
        rows = rows[: args.test]

    # One output subdirectory per artifact, plus "clean" for untouched copies
    out.mkdir(parents=True, exist_ok=True)
    (out / "clean").mkdir(exist_ok=True)
    for art in args.artifacts:
        (out / art).mkdir(exist_ok=True)

    # Build list of valid (img_id, diagnosis, img_path) tuples
    valid_rows = []
    for row in rows:
        img_id = row[id_col]
        diagnosis = row[diag_col].lower().replace(" ", "_")
        img_path = images_dir / f"{img_id}.jpg"
        if not img_path.exists():
            print(f"  [skip] {img_id}: image not found")
            continue
        valid_rows.append((img_id, diagnosis, img_path))

    results = []  # (image_id, path, diagnosis, artifact)

    # -----------------------------------------------------------------------
    # Save clean copies
    # -----------------------------------------------------------------------
    for img_id, diagnosis, img_path in valid_rows:
        clean_path = out / "clean" / f"{img_id}.jpg"
        if not clean_path.exists():
            Image.open(img_path).convert("RGB").save(clean_path)
        # Recorded even when cached so the metadata CSV is always complete.
        results.append((img_id, str(clean_path), diagnosis, "clean"))

    # -----------------------------------------------------------------------
    # Process one artifact at a time: load fresh pipeline+LoRA, run all images
    # -----------------------------------------------------------------------
    for artifact in args.artifacts:
        pipe, use_lora = load_lora_pipeline(args.model, artifact, args.lora_dir)

        for img_id, diagnosis, img_path in valid_rows:
            for i in range(args.n_aug):
                suffix = f"_aug{i}" if args.n_aug > 1 else ""
                out_name = f"{img_id}{suffix}.jpg"
                out_path = out / artifact / out_name

                if out_path.exists():
                    # Existing outputs are kept (resumable runs)
                    print(f"  [cached] {img_id} {artifact}{suffix}")
                else:
                    print(f"  Processing {img_id} → {artifact}{suffix} ...",
                          end=" ", flush=True)
                    img = Image.open(img_path).convert("RGB")
                    aug = augment_image(pipe, img, artifact, aug_idx=i,
                                        diagnosis=diagnosis, use_lora=use_lora)
                    aug.save(out_path)
                    print("✓")

                results.append((img_id, str(out_path), diagnosis, artifact))

        # Free GPU memory before loading next artifact's pipeline
        del pipe
        torch.cuda.empty_cache()

    # Write metadata CSV
    csv_path = out / "augmented_metadata.csv"
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["image_id", "path", "diagnosis", "artifact"])
        writer.writerows(results)

    print(f"\nDone. {len(results)} records → {csv_path}")


if __name__ == "__main__":
    main()
diff --git a/examples/ph2_train_resnet50.py b/examples/ph2_train_resnet50.py
new file mode 100644
index 000000000..222827ceb
--- /dev/null
+++ b/examples/ph2_train_resnet50.py
@@ -0,0 +1,198 @@
"""PH2 Melanoma Classifier — 5-fold CV (whole mode, val_strategy=none).

Replicates the PH2-trained classifiers from Jin et al. (CHIL 2025) for the
whole-image mode. Results are used in:
  - Table 2: PH2-trained classifiers evaluated on PH2/ISIC/HAM10000
  - Table 4: PH2-trained classifiers evaluated on diffusion-augmented PH2

Setup
-----
  PH2 images : ~/ph2/PH2-dataset-master/images/{ID}.jpg (200 images)
  Metadata   : ~/ph2/PH2-dataset-master/ph2_metadata_pyhealth.csv
               columns: image_id, path, diagnosis
  Labels     : melanoma → 1 ; common_nevus / atypical_nevus → 0

Splits
------
  KFold(n_splits=5, shuffle=True, random_state=42) — identical to paper.
  Training uses the full train fold (no validation holdout).
+ +Hyperparameters (paper-faithful) +--------------------------------- + Model : ResNet-50, ImageNet pretrained (weights="DEFAULT") + FC : Linear(2048 → 1) + Loss : BCEWithLogitsLoss + Optim : Adam lr=1e-4 + Epochs : 10 + Batch : 32 + Input : 224×224, ImageNet normalised + +Outputs +------- + Checkpoints: ~/ph2_checkpoints/whole/resnet50_fold{k}.pt (k = 0..4) + Per-fold AUROC printed to stdout. + +Usage +----- + pixi run -e base python examples/ph2_train_resnet50.py + pixi run -e base python examples/ph2_train_resnet50.py --test 20 +""" + +import argparse +import os +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from PIL import Image +from sklearn.metrics import roc_auc_score +from sklearn.model_selection import KFold +from torch.utils.data import DataLoader, Dataset +from torchvision import models, transforms + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +CKPT_DIR = Path(os.path.expanduser("~/ph2_checkpoints/whole")) +META_PATH = Path(os.path.expanduser("~/ph2/PH2-dataset-master/ph2_metadata_pyhealth.csv")) + +TRANSFORM = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), +]) + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# --------------------------------------------------------------------------- +# Dataset +# --------------------------------------------------------------------------- + +class PH2WholeDataset(Dataset): + def __init__(self, records, transform=None): + """ + Args: + records: list of (image_path, label) tuples. + transform: torchvision transform to apply. 
+ """ + self.records = records + self.transform = transform + + def __len__(self): + return len(self.records) + + def __getitem__(self, idx): + path, label = self.records[idx] + img = Image.open(path).convert("RGB") + if self.transform: + img = self.transform(img) + return img, torch.tensor(label, dtype=torch.float32) + + +# --------------------------------------------------------------------------- +# Training helpers +# --------------------------------------------------------------------------- + +def build_model(): + model = models.resnet50(weights="DEFAULT") + model.fc = nn.Linear(model.fc.in_features, 1) + return model.to(DEVICE) + + +def train_one_epoch(model, loader, criterion, optimizer): + model.train() + total_loss = 0.0 + for imgs, labels in loader: + imgs = imgs.to(DEVICE) + labels = labels.unsqueeze(1).to(DEVICE) + optimizer.zero_grad() + loss = criterion(model(imgs), labels) + loss.backward() + optimizer.step() + total_loss += loss.item() + return total_loss / len(loader) + + +@torch.no_grad() +def evaluate(model, loader): + model.eval() + all_probs, all_labels = [], [] + for imgs, labels in loader: + probs = torch.sigmoid(model(imgs.to(DEVICE))).squeeze(1).cpu().numpy() + all_probs.extend(probs) + all_labels.extend(labels.numpy()) + return roc_auc_score(all_labels, all_probs) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--epochs", type=int, default=10) + p.add_argument("--batch", type=int, default=32) + p.add_argument("--lr", type=float, default=1e-4) + p.add_argument("--test", type=int, default=0, help="Smoke-test: only use first N images") + p.add_argument("--resume", action="store_true", help="Skip folds whose checkpoint exists") + return p.parse_args() + + +def main(): + args = parse_args() + CKPT_DIR.mkdir(parents=True, exist_ok=True) + + # Load 
metadata + import csv + records = [] + with open(META_PATH) as f: + for row in csv.DictReader(f): + label = 1 if row["diagnosis"] == "melanoma" else 0 + records.append((row["path"], label)) + + if args.test: + records = records[: args.test] + + records = np.array(records, dtype=object) + kf = KFold(n_splits=5, shuffle=True, random_state=42) + aurocs = [] + + for fold, (train_idx, test_idx) in enumerate(kf.split(records)): + ckpt_path = CKPT_DIR / f"resnet50_fold{fold}.pt" + if args.resume and ckpt_path.exists(): + print(f"[Fold {fold}] checkpoint exists — skipping training") + model = build_model() + model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE)) + else: + print(f"\n{'='*50}\nFold {fold}\n{'='*50}") + train_ds = PH2WholeDataset(records[train_idx].tolist(), TRANSFORM) + train_loader = DataLoader(train_ds, batch_size=args.batch, shuffle=True, num_workers=4) + + model = build_model() + criterion = nn.BCEWithLogitsLoss() + optimizer = optim.Adam(model.parameters(), lr=args.lr) + + for epoch in range(args.epochs): + loss = train_one_epoch(model, train_loader, criterion, optimizer) + print(f" Epoch {epoch+1:2d}/{args.epochs} loss={loss:.4f}") + + torch.save(model.state_dict(), ckpt_path) + print(f" Saved → {ckpt_path}") + + test_ds = PH2WholeDataset(records[test_idx].tolist(), TRANSFORM) + test_loader = DataLoader(test_ds, batch_size=args.batch, shuffle=False, num_workers=4) + auroc = evaluate(model, test_loader) + aurocs.append(auroc) + print(f" Fold {fold} test AUROC: {auroc:.4f}") + + mean, std = np.mean(aurocs), np.std(aurocs, ddof=1) + print(f"\nPH2 whole — 5-fold AUROC: {mean:.4f} ±{std:.4f}") + print("Per-fold:", [f"{a:.4f}" for a in aurocs]) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..39dc73942 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -55,6 +55,9 @@ def __init__(self, *args, **kwargs): from .dreamt import 
DREAMTDataset from .ehrshot import EHRShotDataset from .eicu import eICUDataset +from .isic2018 import ISIC2018Dataset +from .isic2018_artifacts import ISIC2018ArtifactsDataset +from .ph2 import PH2Dataset from .isruc import ISRUCDataset from .medical_transcriptions import MedicalTranscriptionsDataset from .mimic3 import MIMIC3Dataset diff --git a/pyhealth/datasets/configs/isic2018.yaml b/pyhealth/datasets/configs/isic2018.yaml new file mode 100644 index 000000000..2c0f33b7f --- /dev/null +++ b/pyhealth/datasets/configs/isic2018.yaml @@ -0,0 +1,16 @@ +version: "1.0" +tables: + isic2018: + file_path: "isic2018-metadata-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "path" + - "image_id" + - "mel" + - "nv" + - "bcc" + - "akiec" + - "bkl" + - "df" + - "vasc" diff --git a/pyhealth/datasets/configs/ph2.yaml b/pyhealth/datasets/configs/ph2.yaml new file mode 100644 index 000000000..23b5693a2 --- /dev/null +++ b/pyhealth/datasets/configs/ph2.yaml @@ -0,0 +1,13 @@ +version: "1.0" +dataset_name: ph2 + +tables: + ph2: + file_path: ph2_metadata_pyhealth.csv + patient_id: image_id + visit_id: image_id + timestamp: null + timestamp_format: null + attributes: + - path + - diagnosis diff --git a/pyhealth/datasets/isic2018.py b/pyhealth/datasets/isic2018.py new file mode 100644 index 000000000..21e80b9c3 --- /dev/null +++ b/pyhealth/datasets/isic2018.py @@ -0,0 +1,491 @@ +""" +Unified PyHealth dataset for ISIC 2018 Tasks. + +This module provides :class:`ISIC2018Dataset`, a single dataset class that +covers both: + +* ``task="task3"`` — 7-class skin lesion **classification** (HAM10000 / Task 3). + Downloads images and ``ISIC2018_Task3_Training_GroundTruth.csv``. + +* ``task="task1_2"`` — Lesion **segmentation** & attribute detection (Task 1/2). + Downloads images and binary segmentation masks. + +Both modes support ``download=True`` for automatic data acquisition from the +official ISIC 2018 challenge S3 archive. 
+ +The module also exports the URL / directory-name constants and helper functions +(``_download_file``, ``_extract_zip``) that are re-used by +:class:`~pyhealth.datasets.ISIC2018ArtifactsDataset`. + +Dataset link: + https://challenge.isic-archive.com/data/#2018 + +Licenses: + Task 1/2 (segmentation & attribute detection): + CC-0 (Public Domain) — https://creativecommons.org/public-domain/cc0/ + + Task 3 (classification): + CC-BY-NC 4.0 — https://creativecommons.org/licenses/by-nc/4.0/ + Attribution required — see references below. + +References: + [1] Noel Codella et al. "Skin Lesion Analysis Toward Melanoma Detection + 2018: A Challenge Hosted by the International Skin Imaging Collaboration + (ISIC)", 2018; https://arxiv.org/abs/1902.03368 + + [2] Tschandl et al. "The HAM10000 dataset, a large collection of + multi-source dermatoscopic images of common pigmented skin lesions." + Sci. Data 5, 180161 (2018). https://doi.org/10.1038/sdata.2018.161 +""" + +import hashlib +import logging +import os +import shutil +import zipfile +from functools import wraps +from pathlib import Path +from typing import Dict, List, Optional + +import pandas as pd +import requests +import yaml + +from pyhealth.datasets import BaseDataset + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Public constants (also imported by isic2018_artifacts.py) +# --------------------------------------------------------------------------- + +TASK12_IMAGES_URL: str = ( + "https://isic-archive.s3.amazonaws.com/challenges/2018/" + "ISIC2018_Task1-2_Training_Input.zip" +) +TASK12_MASKS_URL: str = ( + "https://isic-archive.s3.amazonaws.com/challenges/2018/" + "ISIC2018_Task1_Training_GroundTruth.zip" +) +TASK12_IMAGES_DIR: str = "ISIC2018_Task1-2_Training_Input" +TASK12_MASKS_DIR: str = "ISIC2018_Task1_Training_GroundTruth" +_T12_IMAGES_ZIP = "ISIC2018_Task1-2_Training_Input.zip" +_T12_MASKS_ZIP = 
"ISIC2018_Task1_Training_GroundTruth.zip" + +_T3_IMAGES_URL = ( + "https://isic-archive.s3.amazonaws.com/challenges/2018/" + "ISIC2018_Task3_Training_Input.zip" +) +_T3_LABELS_URL = ( + "https://isic-archive.s3.amazonaws.com/challenges/2018/" + "ISIC2018_Task3_Training_GroundTruth.zip" +) +_T3_IMAGES_DIR = "ISIC2018_Task3_Training_Input" +_T3_GROUNDTRUTH_CSV = "ISIC2018_Task3_Training_GroundTruth.csv" +_T3_IMAGES_ZIP = "ISIC2018_Task3_Training_Input.zip" +_T3_LABELS_ZIP = "ISIC2018_Task3_Training_GroundTruth.zip" + +VALID_TASKS = ("task3", "task1_2") + +# MD5 checksums for ISIC 2018 files. Update with values from archive. +# To compute: python -c "import hashlib; h=hashlib.md5();" +# "f=open('file.zip','rb'); h.update(f.read()); print(h.hexdigest())" +# +# Verified checksums (downloaded and computed): +# - ISIC2018_Task3_Training_GroundTruth.zip: verified ✓ +# - ISIC2018_Task1_Training_GroundTruth.zip: verified ✓ +# +# MD5 checksums for all four ZIP archives (sourced from archive.org metadata). +_CHECKSUMS: Dict[str, str] = { + "ISIC2018_Task3_Training_GroundTruth.zip": "8302427e4ce0c107559531b9f444abe9", + "ISIC2018_Task3_Training_Input.zip": "0c281f121070a8d63457caffcdec439a", + "ISIC2018_Task1-2_Training_Input.zip": "8b5be801f37b58ccf533df2928a5906b", + "ISIC2018_Task1_Training_GroundTruth.zip": "ee5e5db7771d48fa2613abc7cb5c24e2", +} + + +# --------------------------------------------------------------------------- +# Public download helpers (also imported by isic2018_artifacts.py) +# --------------------------------------------------------------------------- + +def _download_file( + url: str, + dest: str, + expected_md5: Optional[str] = None) -> None: + """Stream *url* to *dest* with 1 MB chunks, logging % progress. + + Args: + url: Source URL to download from. + dest: Destination file path. + expected_md5: Expected MD5 checksum (optional). If provided, verifies + downloaded file integrity and raises ValueError if mismatch. 
+ + Raises: + requests.HTTPError: If the HTTP request fails. + ValueError: If MD5 checksum verification fails. + """ + with requests.get(url, stream=True, timeout=300) as response: + response.raise_for_status() + total = int(response.headers.get("content-length", 0)) + downloaded = 0 + md5_hash = hashlib.md5() + + with open(dest, "wb") as fh: + for chunk in response.iter_content(chunk_size=1024 * 1024): + fh.write(chunk) + md5_hash.update(chunk) + downloaded += len(chunk) + if total: + logger.info( + " %.1f%% (%d / %d bytes)", + downloaded / total * 100, + downloaded, + total, + ) + + if expected_md5 is not None: + actual_md5 = md5_hash.hexdigest() + if actual_md5 != expected_md5: + os.remove(dest) + raise ValueError( + f"MD5 checksum mismatch for {os.path.basename(dest)}\n" + f" Expected: {expected_md5}\n" + f" Got: {actual_md5}\n" + f"Download was corrupted or incomplete. File removed." + ) + + +def _extract_zip(zip_path: str, dest_dir: str, flatten: bool = False) -> None: + """Safely extract zip, guarding against path-traversal attacks. + + Args: + zip_path: Path to zip file to extract. + dest_dir: Destination directory. + flatten: If True and zip has a single top-level directory, extract + its contents to dest_dir (flattening structure). If False, + extract normally (preserving directory structure). 
+ """ + abs_dest = os.path.abspath(dest_dir) + with zipfile.ZipFile(zip_path, "r") as zf: + # Security check: prevent path traversal + for member in zf.infolist(): + member_path = os.path.abspath( + os.path.join(abs_dest, member.filename)) + if not member_path.startswith(abs_dest + os.sep): + raise ValueError( + f"Unsafe path in zip archive: '{ + member.filename}'") + + if flatten: + # Check if all files are in a single top-level directory + names = zf.namelist() + if names: + # Get top-level entries + top_level = set() + for name in names: + parts = name.split('/') + if parts[0]: # Skip empty parts from trailing slashes + top_level.add(parts[0]) + + # If only one top-level item and it's a directory, flatten it + if len(top_level) == 1: + top_dir = list(top_level)[0] + # Check if it's a directory (has trailing slash or contains + # files) + is_dir = any(name.startswith(top_dir + '/') + for name in names) + if is_dir: + # Extract to temp location and move contents up + temp_dir = os.path.join(abs_dest, '.extract_temp') + os.makedirs(temp_dir, exist_ok=True) + zf.extractall(temp_dir) + + # Move contents of top_dir to dest_dir + source_dir = os.path.join(temp_dir, top_dir) + for item in os.listdir(source_dir): + src = os.path.join(source_dir, item) + dst = os.path.join(abs_dest, item) + if os.path.isdir(src): + os.makedirs(dst, exist_ok=True) + shutil.copytree(src, dst, dirs_exist_ok=True) + else: + os.makedirs( + os.path.dirname(dst), exist_ok=True) + shutil.copy2(src, dst) + + # Clean up temp dir + shutil.rmtree(temp_dir) + return + + # Otherwise, extract normally + zf.extractall(dest_dir) + + +# --------------------------------------------------------------------------- +# Dataset class +# --------------------------------------------------------------------------- + +class ISIC2018Dataset(BaseDataset): + """Unified ISIC 2018 dataset for Task 1/2 (segmentation) or Task 3 (classification). + + Args: + root (str): Root directory. Defaults to ".". 
+ task (str): Which ISIC 2018 task to load. One of: + + - ``"task3"`` (default) — 7-class skin lesion classification. + Downloads images + ``ISIC2018_Task3_Training_GroundTruth.csv``. + - ``"task1_2"`` — Lesion segmentation & attribute detection. + Downloads images + binary segmentation masks. + + download (bool): Download missing data automatically. Defaults to False. + **kwargs: Forwarded to BaseDataset. + + .. note:: + **Licenses differ by task:** + + * ``task="task1_2"`` — **CC-0** (public domain). + No attribution required. + * ``task="task3"`` — **CC-BY-NC 4.0**. + Attribution is required; commercial use is not permitted. + See https://challenge.isic-archive.com/data/#2018 for citation details. + + Raises: + ValueError: If task is not one of VALID_TASKS. + FileNotFoundError: If required paths are missing and download=False. + requests.HTTPError: If download fails. + + task="task3" directory layout:: + + / + ISIC2018_Task3_Training_GroundTruth.csv + ISIC2018_Task3_Training_Input/ + ISIC_0024306.jpg ... + + task="task1_2" directory layout:: + + / + ISIC2018_Task1-2_Training_Input/ + ISIC_0024306.jpg ... + ISIC2018_Task1_Training_GroundTruth/ + ISIC_0024306_segmentation.png ... 
+ + Event attributes for task="task3" (table "isic2018"): + image_id, path, mel, nv, bcc, akiec, bkl, df, vasc + + Event attributes for task="task1_2" (table "isic2018_task12"): + image_id, path, mask_path (empty string if mask absent) + + Example:: + >>> dataset = ISIC2018Dataset(root="/data/isic", task="task3", download=True) + >>> dataset = ISIC2018Dataset(root="/data/isic", task="task1_2", download=True) + """ + + classes: List[str] = ["mel", "nv", "bcc", "akiec", "bkl", "df", "vasc"] + + def __init__(self, root=".", task="task3", download=False, **kwargs): + if task not in VALID_TASKS: + raise ValueError( + f"task must be one of {VALID_TASKS}, got '{task}'") + self.task = task + + if task == "task3": + self._image_dir = os.path.join(root, _T3_IMAGES_DIR) + self._label_path = os.path.join(root, _T3_GROUNDTRUTH_CSV) + else: # task1_2 + self._image_dir = os.path.join(root, TASK12_IMAGES_DIR) + self._mask_dir = os.path.join(root, TASK12_MASKS_DIR) + + if download: + self._download(root) + + self._verify_data(root) + config_path = self._index_data(root) + + table = "isic2018" if task == "task3" else "isic2018_task12" + super().__init__( + root=root, + tables=[table], + dataset_name="ISIC2018", + config_path=config_path, + **kwargs, + ) + + @property + def default_task(self): + if self.task == "task3": + from pyhealth.tasks import ISIC2018Classification + return ISIC2018Classification() + return None # No native segmentation task yet + + @wraps(BaseDataset.set_task) + def set_task(self, *args, **kwargs): + return super().set_task(*args, **kwargs) + + def _download(self, root): + os.makedirs(root, exist_ok=True) + if self.task == "task3": + if not os.path.isfile(self._label_path): + zip_path = os.path.join(root, _T3_LABELS_ZIP) + # Skip download if ZIP already exists (may be + # partial/incomplete) + if not os.path.isfile(zip_path): + logger.info("Downloading ISIC 2018 Task 3 labels...") + _download_file( + _T3_LABELS_URL, + zip_path, + 
_CHECKSUMS.get(_T3_LABELS_ZIP)) + if os.path.isfile(zip_path): + _extract_zip(zip_path, root, flatten=True) + os.remove(zip_path) + if not os.path.isdir(self._image_dir): + zip_path = os.path.join(root, _T3_IMAGES_ZIP) + # Skip download if ZIP already exists (may be + # partial/incomplete) + if not os.path.isfile(zip_path): + logger.info( + "Downloading ISIC 2018 Task 3 images (~8 GB)...") + _download_file( + _T3_IMAGES_URL, + zip_path, + _CHECKSUMS.get(_T3_IMAGES_ZIP)) + if os.path.isfile(zip_path): + _extract_zip(zip_path, root, flatten=False) + os.remove(zip_path) + else: # task1_2 + if not os.path.isdir(self._image_dir): + zip_path = os.path.join(root, _T12_IMAGES_ZIP) + # Skip download if ZIP already exists (may be + # partial/incomplete) + if not os.path.isfile(zip_path): + logger.info( + "Downloading ISIC 2018 Task 1/2 images (~8 GB)...") + _download_file( + TASK12_IMAGES_URL, + zip_path, + _CHECKSUMS.get(_T12_IMAGES_ZIP)) + if os.path.isfile(zip_path): + _extract_zip(zip_path, root, flatten=False) + os.remove(zip_path) + if not os.path.isdir(self._mask_dir): + zip_path = os.path.join(root, _T12_MASKS_ZIP) + # Skip download if ZIP already exists (may be + # partial/incomplete) + if not os.path.isfile(zip_path): + logger.info("Downloading ISIC 2018 Task 1 masks...") + _download_file( + TASK12_MASKS_URL, + zip_path, + _CHECKSUMS.get(_T12_MASKS_ZIP)) + if os.path.isfile(zip_path): + _extract_zip(zip_path, root, flatten=False) + os.remove(zip_path) + + def _verify_data(self, root): + if not os.path.exists(root): + raise FileNotFoundError(f"Dataset root not found: {root}") + if not os.path.isdir(self._image_dir): + raise FileNotFoundError( + f"Image directory not found: {self._image_dir}\n" + "Use download=True or obtain manually from " + "https://challenge.isic-archive.com/data/#2018" + ) + if self.task == "task3": + if not os.path.isfile(self._label_path): + raise FileNotFoundError( + f"Ground-truth CSV not found: {self._label_path}\n" + "Use download=True or 
obtain manually from " + "https://challenge.isic-archive.com/data/#2018" + ) + if not list(Path(self._image_dir).glob("*.jpg")): + raise ValueError( + f"No JPEG images found in '{ + self._image_dir}'") + else: # task1_2 + if not os.path.isdir(self._mask_dir): + raise FileNotFoundError( + f"Mask directory not found: {self._mask_dir}\n" + "Use download=True or obtain manually from " + "https://challenge.isic-archive.com/data/#2018" + ) + + def _index_data(self, root): + if self.task == "task3": + return self._index_task3(root) + return self._index_task12(root) + + def _index_task3(self, root): + df = pd.read_csv(self._label_path) + image_names = {f.stem for f in Path(self._image_dir).glob("*.jpg")} + df = df[df["image"].isin(image_names)].copy() + df.rename(columns={c.upper(): c for c in self.classes}, inplace=True) + df.rename(columns={"image": "image_id"}, inplace=True) + df["patient_id"] = df["image_id"] + df["path"] = df["image_id"].apply( + lambda img_id: os.path.join(self._image_dir, f"{img_id}.jpg") + ) + metadata_path = os.path.join(root, "isic2018-metadata-pyhealth.csv") + df.to_csv(metadata_path, index=False) + + config = { + "version": "1.0", + "tables": { + "isic2018": { + "file_path": "isic2018-metadata-pyhealth.csv", + "patient_id": "patient_id", + "timestamp": None, + "attributes": ["path", "image_id"] + list(self.classes), + } + }, + } + config_path = os.path.join(root, "isic2018-config-pyhealth.yaml") + with open(config_path, "w") as fh: + yaml.dump(config, fh, default_flow_style=False, sort_keys=False) + logger.info( + "ISIC2018Dataset (task3): indexed %d images → %s", + len(df), + metadata_path) + return config_path + + def _index_task12(self, root): + image_dir = Path(self._image_dir) + mask_dir = Path(self._mask_dir) + images = sorted(image_dir.glob("*.jpg")) + \ + sorted(image_dir.glob("*.JPG")) + if not images: + raise ValueError(f"No images found in '{self._image_dir}'") + records = [] + for img_path in images: + image_id = img_path.stem + 
mask_path = mask_dir / f"{image_id}_segmentation.png" + records.append({ + "image_id": image_id, + "patient_id": image_id, + "path": str(img_path), + "mask_path": str(mask_path) if mask_path.exists() else None, + }) + df = pd.DataFrame(records) + metadata_path = os.path.join( + root, "isic2018-task12-metadata-pyhealth.csv") + df.to_csv(metadata_path, index=False) + config = { + "version": "1.0", + "tables": { + "isic2018_task12": { + "file_path": "isic2018-task12-metadata-pyhealth.csv", + "patient_id": "patient_id", + "timestamp": None, + "attributes": ["path", "mask_path", "image_id"], + } + }, + } + config_path = os.path.join( + root, "isic2018-task12-config-pyhealth.yaml") + with open(config_path, "w") as fh: + yaml.dump(config, fh, default_flow_style=False, sort_keys=False) + logger.info( + "ISIC2018Dataset (task1_2): indexed %d images (%d with masks) → %s", + len(df), + (df["mask_path"] != "").sum(), + metadata_path, + ) + return config_path diff --git a/pyhealth/datasets/isic2018_artifacts.py b/pyhealth/datasets/isic2018_artifacts.py new file mode 100644 index 000000000..bcb280261 --- /dev/null +++ b/pyhealth/datasets/isic2018_artifacts.py @@ -0,0 +1,504 @@ +""" +PyHealth dataset for dermoscopy images with per-image artifact annotations. + +Overview +-------- +Dermoscopy images frequently contain non-clinical artifacts — visual elements +introduced during image acquisition that are unrelated to the underlying +pathology. When these artifacts correlate with diagnostic labels in training +data they create *spurious shortcuts* that inflate reported model accuracy +without capturing genuine disease features. + +This dataset pairs any collection of dermoscopy images with a per-image +artifact annotation CSV. The default annotation file is ``isic_bias.csv`` +from Bissoto et al. (2020), which was created for the ISIC 2018 Task 1/2 +image set, but the dataset class accepts any CSV that follows the same column +format. 
Default annotation source
--------------------------
Using ``ISIC2018ArtifactsDataset`` requires **two separate downloads**:

1. **Artifact annotations** (``isic_bias.csv``) — Bissoto et al. (2020):
   https://github.com/alceubissoto/debiasing-skin

   See ``artefacts-annotation/`` in that repository for the annotation files.

   Reference:
       Bissoto et al. "Debiasing Skin Lesion Datasets and Models? Not So Fast"
       ISIC Skin Image Analysis Workshop @ CVPR 2020

2. **ISIC 2018 Task 1/2 images and segmentation masks** (~8 GB):
   https://challenge.isic-archive.com/data/#2018

   * Training images: ``ISIC2018_Task1-2_Training_Input.zip``
   * Segmentation masks: ``ISIC2018_Task1_Training_GroundTruth.zip``

Both can be fetched automatically by passing ``download=True`` to the
constructor (see class docs for details).

Artifact types
--------------
The default CSV provides seven binary artifact labels per image
(1 = present, 0 = absent). Any additional binary columns in a custom CSV
are also preserved and accessible on events.

=================  =============================================================
Label              Description
=================  =============================================================
``dark_corner``    Dark vignetting at the image periphery, typically from the
                   dermoscope lens edge.
``hair``           Hair strands crossing the field of view and obscuring skin
                   surface details.
``gel_border``     Visible boundary of the contact gel or immersion fluid used
                   during dermoscopic examination.
``gel_bubble``     Air bubbles trapped in the contact gel, appearing as
                   circular bright reflections.
``ruler``          Measurement scale or ruler placed in the frame for size
                   reference.
``ink``            Ink or marker pen markings drawn on the skin before
                   acquisition (e.g., for surgical planning).
``patches``        Adhesive patches or stickers visible in the image.
=================  =============================================================

Image preprocessing
--------------------
Image preprocessing is not handled by this dataset. Pass an
``input_processors={"image": <processor>}`` argument to ``set_task`` to
apply any image transformation. The dataset exposes ``dataset.mask_dir``
so downstream processors can locate segmentation masks without hard-coding
the path.

CSV format
----------
The annotation CSV must be **semicolon-delimited** (``sep=";"``), with an
unnamed integer index as the first column, and must contain:

* ``image`` — image filename (must match files present in ``image_dir``).
* ``label`` — binary classification label (1 = malignant, 0 = benign).

Any additional columns are preserved as event attributes. Columns beginning
with ``split_`` are treated as fold / trap-set assignment columns.

Cross-validation and trap-set protocol
---------------------------------------
When using the default Bissoto et al. CSV, five-fold splits
(``split_1`` … ``split_5``) support standard K-fold evaluation.

The *trap-set* protocol (Bissoto et al. 2020) studies whether models trained
on artifact-biased data learn spurious correlations.
"""

import logging
import os
from pathlib import Path
from typing import List

import pandas as pd
import requests
import yaml

from pyhealth.datasets import BaseDataset
from pyhealth.datasets.isic2018 import (
    TASK12_IMAGES_DIR as _IMAGES_DIR,
    TASK12_IMAGES_URL as _IMAGES_URL,
    TASK12_MASKS_DIR as _MASKS_DIR,
    TASK12_MASKS_URL as _MASKS_URL,
    _T12_IMAGES_ZIP as _IMAGES_ZIP,
    _T12_MASKS_ZIP as _MASKS_ZIP,
    _CHECKSUMS as _ISIC_CHECKSUMS,
    _download_file,
    _extract_zip,
)
logger = logging.getLogger(__name__)

_BIAS_CSV = "isic_bias.csv"  # default Bissoto et al. annotation filename
_BIAS_CSV_URL = (
    "https://raw.githubusercontent.com/alceubissoto/debiasing-skin/"
    "master/artefacts-annotation/isic_bias.csv"
)

#: The seven dermoscopic artifact categories annotated in ``isic_bias.csv``.
#: Each label is a binary column (1 = artifact present, 0 = absent).
#: See the module docstring for a detailed description of each type.
ARTIFACT_LABELS: List[str] = [
    "dark_corner",  # dark lens-edge vignetting
    "hair",         # hair strands overlapping the field of view
    "gel_border",   # visible contact-gel or immersion-fluid boundary
    "gel_bubble",   # air bubbles in the contact gel
    "ruler",        # measurement scale / ruler in the frame
    "ink",          # ink or marker-pen markings on the skin
    "patches",      # adhesive patches or stickers
]


class ISIC2018ArtifactsDataset(BaseDataset):
    """PyHealth dataset for dermoscopy images with per-image artifact annotations.

    Pairs a directory of dermoscopy images with an artifact annotation CSV.
    Any CSV that contains an ``image`` column (filenames), a ``label`` column
    (binary classification target), and one or more binary artifact columns is
    supported. The default CSV is ``isic_bias.csv`` from Bissoto et al. (2020),
    annotated on the ISIC 2018 Task 1/2 image set.

    Image preprocessing is decoupled from this dataset. Supply an image
    processor via ``input_processors={"image": <processor>}`` when calling
    ``set_task``. The resolved mask directory is available as
    ``dataset.mask_dir`` for convenience.

    Attributes:
        artifact_labels (List[str]): The seven well-known artifact types from
            Bissoto et al. (2020). Any subset present in the CSV is exposed.
        mask_dir (str): Resolved absolute path to the segmentation-mask
            directory.

    The expected directory structure under ``root`` is::

        <root>/
            <annotations_csv>        ← downloadable via download=True
            <image_dir>/             ← downloadable via download=True
                ISIC_0024306.jpg
                ISIC_0024307.jpg
                ...
+
+        <mask_dir>/              ← downloadable via download=True; only
+            ISIC_0024306_segmentation.png      required for non-"whole" modes
+            ...
+
+    When ``download=True`` is used with the default annotation CSV, the
+    following files are fetched automatically:
+
+    * ``isic_bias.csv`` — from the Bissoto et al. GitHub repository.
+    * ``ISIC2018_Task1-2_Training_Input/`` — ISIC 2018 Task 1/2 training
+      images (~8 GB), extracted from the ISIC S3 archive.
+    * ``ISIC2018_Task1_Training_GroundTruth/`` — ISIC 2018 Task 1
+      segmentation masks, extracted from the ISIC S3 archive.
+
+    Pass ``image_dir="ISIC2018_Task1-2_Training_Input"`` and
+    ``mask_dir="ISIC2018_Task1_Training_GroundTruth"`` to match the
+    extracted layout.
+
+    Example — Bissoto et al. default CSV with on-demand download::
+
+        >>> dataset = ISIC2018ArtifactsDataset(
+        ...     root="/data/isic",
+        ...     image_dir="ISIC2018_Task1-2_Training_Input",
+        ...     mask_dir="ISIC2018_Task1_Training_GroundTruth",
+        ...     download=True,  # fetches CSV + ~8 GB images on first run
+        ... )
+
+    Example — images already on disk::
+
+        >>> dataset = ISIC2018ArtifactsDataset(
+        ...     root="/data/isic",
+        ...     image_dir="2018_train_task1-2",
+        ...     mask_dir="2018_train_task1-2_segmentations",
+        ... )
+        >>> sample_ds = dataset.set_task(my_task, input_processors={"image": my_processor})
+
+    Example — custom annotation CSV::
+
+        >>> dataset = ISIC2018ArtifactsDataset(
+        ...     root="/data/my_dataset",
+        ...     annotations_csv="my_annotations.csv",
+        ...     image_dir="images",
+        ...     mask_dir="masks",
+        ... )
+    """
+
+    artifact_labels: List[str] = ARTIFACT_LABELS
+
+    def __init__(
+        self,
+        root: str = ".",
+        annotations_csv: str = _BIAS_CSV,
+        image_dir: str = "images",
+        mask_dir: str = "masks",
+        download: bool = False,
+        **kwargs,
+    ) -> None:
+        """Initialise the artifact dataset.
+
+        Args:
+            root (str): Root directory containing the annotation CSV, the
+                image directory, and the segmentation-mask directory. 
+ annotations_csv (str): Filename of the annotation CSV inside + ``root``. Defaults to ``"isic_bias.csv"`` (Bissoto et al. + 2020). The file must contain at minimum an ``image`` column + (filenames) and a ``label`` column (binary target). + image_dir (str): Sub-directory name (or absolute path) for the + dermoscopy images. Defaults to ``"images"``. + mask_dir (str): Sub-directory name (or absolute path) for the + segmentation masks. Defaults to ``"masks"``. The resolved + path is exposed as ``dataset.mask_dir``. + download (bool): If ``True`` and ``annotations_csv`` is the + default ``"isic_bias.csv"``, download all missing data + automatically: + + * ``isic_bias.csv`` — from the Bissoto et al. GitHub repo. + * ISIC 2018 Task 1/2 training images (~8 GB) — from the + ISIC S3 archive; extracted to + ``/ISIC2018_Task1-2_Training_Input/``. + * ISIC 2018 Task 1 segmentation masks — from the ISIC S3 + archive; extracted to + ``/ISIC2018_Task1_Training_GroundTruth/``. + + Pass ``image_dir="ISIC2018_Task1-2_Training_Input"`` and + ``mask_dir="ISIC2018_Task1_Training_GroundTruth"`` to use the + extracted directories. Raises :exc:`ValueError` if used + with a custom ``annotations_csv``. Defaults to ``False``. + **kwargs: Additional keyword arguments forwarded to + :class:`~pyhealth.datasets.BaseDataset`. + + Raises: + ValueError: If ``download=True`` is used with a custom + ``annotations_csv``, or no images match the CSV. + FileNotFoundError: If ``root``, the annotation CSV, the image + directory, or the mask directory is missing. + requests.HTTPError: If ``download=True`` and the CSV download + fails. + """ + if download and annotations_csv != _BIAS_CSV: + raise ValueError( + "download=True is only supported for the default " + f"annotations_csv='{_BIAS_CSV}'. " + "Provide your own CSV or omit the download flag." 
+ ) + + self.annotations_csv = annotations_csv + + self._image_dir = (image_dir if os.path.isabs( + image_dir) else os.path.join(root, image_dir)) + self.mask_dir = (mask_dir if os.path.isabs( + mask_dir) else os.path.join(root, mask_dir)) + self._bias_csv_path = os.path.join(root, annotations_csv) + + if download: + self._download_bias_csv(root) + self._download_images(root) + + self._verify_data(root) + config_path = self._index_data(root) + + super().__init__( + root=root, + tables=["isic_artifacts"], + dataset_name="ISICArtifact", + config_path=config_path, + **kwargs, + ) + + @property + def default_task(self): + """Return the default task for this dataset. + + Returns: + None: No default task is registered until the ISIC task classes + are available. Use ``dataset.set_task(task)`` directly. + """ + return None + + def _download_bias_csv(self, root: str) -> None: + """Download the default Bissoto et al. ``isic_bias.csv`` from GitHub. + + Skips the download if the file already exists. + + Args: + root: Dataset root directory where the CSV will be saved. + + Raises: + requests.HTTPError: If the HTTP request returns an error status. + """ + if os.path.isfile(self._bias_csv_path): + logger.info( + "%s already present, skipping download.", + self.annotations_csv) + return + + os.makedirs(root, exist_ok=True) + logger.info("Downloading %s from GitHub...", self.annotations_csv) + response = requests.get(_BIAS_CSV_URL, timeout=60) + response.raise_for_status() + with open(self._bias_csv_path, "wb") as fh: + fh.write(response.content) + logger.info("Saved %s → %s", self.annotations_csv, self._bias_csv_path) + + def _download_images(self, root: str) -> None: + """Download and extract ISIC 2018 Task 1/2 images and masks. + + Skips download if ZIP files already exist (they may be partial/incomplete). + Always attempts extraction if ZIP is present. The images archive is + ~8 GB — this may take several minutes. + + Args: + root: Dataset root directory. 
+
+        Raises:
+            requests.HTTPError: If any HTTP request returns an error status.
+            ValueError: If MD5 checksum verification fails.
+        """
+        os.makedirs(root, exist_ok=True)
+
+        images_dest = os.path.join(root, _IMAGES_DIR)
+        if not os.path.isdir(images_dest):
+            zip_path = os.path.join(root, _IMAGES_ZIP)
+            # Skip download if ZIP already exists (may be partial/incomplete)
+            if not os.path.isfile(zip_path):
+                logger.info(
+                    "Downloading ISIC 2018 Task 1/2 images (~8 GB): %s",
+                    _IMAGES_URL)
+                _download_file(
+                    _IMAGES_URL,
+                    zip_path,
+                    _ISIC_CHECKSUMS.get(_IMAGES_ZIP))
+            if os.path.isfile(zip_path):
+                logger.info("Extracting images to %s ...", root)
+                _extract_zip(zip_path, root)
+                os.remove(zip_path)
+            logger.info("Images ready at %s", images_dest)
+        else:
+            logger.info(
+                "Image directory already present, skipping: %s",
+                images_dest)
+
+        masks_dest = os.path.join(root, _MASKS_DIR)
+        if not os.path.isdir(masks_dest):
+            zip_path = os.path.join(root, _MASKS_ZIP)
+            # Skip download if ZIP already exists (may be partial/incomplete)
+            if not os.path.isfile(zip_path):
+                logger.info(
+                    "Downloading ISIC 2018 Task 1 segmentation masks: %s",
+                    _MASKS_URL)
+                _download_file(
+                    _MASKS_URL,
+                    zip_path,
+                    _ISIC_CHECKSUMS.get(_MASKS_ZIP))
+            if os.path.isfile(zip_path):
+                logger.info("Extracting masks to %s ...", root)
+                _extract_zip(zip_path, root)
+                os.remove(zip_path)
+            logger.info("Masks ready at %s", masks_dest)
+        else:
+            logger.info(
+                "Mask directory already present, skipping: %s",
+                masks_dest)
+
+    def _verify_data(self, root: str) -> None:
+        """Check required paths exist and raise informative errors if not."""
+        if not os.path.exists(root):
+            raise FileNotFoundError(f"Dataset root does not exist: {root}")
+
+        if not os.path.isfile(self._bias_csv_path):
+            msg = f"Annotation CSV not found: {self._bias_csv_path}"
+            if self.annotations_csv == _BIAS_CSV:
+                msg += (
+                    "\nDownload it from: "
+                    "https://github.com/alceubissoto/debiasing-skin"
+                    "/tree/master/artefacts-annotation"
+                    
"\nOr pass download=True to fetch it automatically." + ) + raise FileNotFoundError(msg) + + if not os.path.isdir(self._image_dir): + raise FileNotFoundError( + f"Image directory not found: {self._image_dir}\n" + "Download images with download=True (requires ~8 GB), or " + "obtain them manually from: " + "https://challenge.isic-archive.com/data/#2018" + ) + + if not os.path.isdir(self.mask_dir): + raise FileNotFoundError( + f"Mask directory not found: {self.mask_dir}\n" + "Download masks with download=True, or " + "obtain them manually from: " + "https://challenge.isic-archive.com/data/#2018" + ) + + def _index_data(self, root: str) -> str: + """Parse ``isic_bias.csv`` and write a metadata CSV + YAML config. + + All columns present in ``isic_bias.csv`` (artifact labels, split + columns, trap-set columns) are preserved so tasks can filter by any + of them. + + Args: + root: Dataset root directory. + + Returns: + str: Path to the generated YAML configuration file. + + Raises: + ValueError: If no images from the CSV are found on disk. + """ + df = pd.read_csv(self._bias_csv_path, sep=";", index_col=0) + + # Keep only rows whose image file exists on disk. + # Resolve extension mismatches (e.g. CSV lists .png but files are + # .jpg). + available = {f.name: f.name for f in Path( + self._image_dir).iterdir() if f.is_file()} + stem_to_actual = { + Path(name).stem: name for name in available + } + + def _resolve_filename(csv_name: str) -> str: + """Return the actual on-disk filename, resolving extension mismatches.""" + if csv_name in available: + return csv_name + stem = Path(csv_name).stem + return stem_to_actual.get(stem, "") + + df["image"] = df["image"].apply(_resolve_filename) + df = df[df["image"] != ""].copy() + + if df.empty: + raise ValueError( + f"No matching images found in '{self._image_dir}'. " + "Ensure image filenames in the annotation CSV correspond to " + "files present in the image directory." 
+ ) + + # Derive image_id (stem without extension) and patient_id + df["image_id"] = df["image"].str.replace( + r"\.[A-Za-z]+$", "", regex=True) + df["patient_id"] = df["image_id"] + + # Absolute path to the image file + df["path"] = df["image"].apply( + lambda name: os.path.join(self._image_dir, name) + ) + + metadata_path = os.path.join( + root, "isic-artifact-metadata-pyhealth.csv") + df.to_csv(metadata_path, index=False) + + # Build YAML config dynamically so all CSV columns are accessible + fixed_attrs = ["path", "image_id", "label"] + artifact_attrs = [c for c in ARTIFACT_LABELS if c in df.columns] + split_attrs = sorted(c for c in df.columns if c.startswith("split_")) + attributes = fixed_attrs + artifact_attrs + split_attrs + + config: dict = { + "version": "1.0", + "tables": { + "isic_artifacts": { + "file_path": "isic-artifact-metadata-pyhealth.csv", + "patient_id": "patient_id", + "timestamp": None, + "attributes": attributes, + } + }, + } + + config_path = os.path.join(root, "isic-artifact-config-pyhealth.yaml") + tmp_path = config_path + f".tmp.{os.getpid()}" + with open(tmp_path, "w") as fh: + yaml.dump(config, fh, default_flow_style=False, sort_keys=False) + os.replace(tmp_path, config_path) # atomic on Linux — safe for parallel workers + + logger.info( + "ISIC2018ArtifactsDataset: indexed %d images → %s", + len(df), + metadata_path, + ) + return config_path diff --git a/pyhealth/datasets/ph2.py b/pyhealth/datasets/ph2.py new file mode 100644 index 000000000..8943d9b39 --- /dev/null +++ b/pyhealth/datasets/ph2.py @@ -0,0 +1,332 @@ +""" +PyHealth dataset for the PH2 dermoscopic image database. + +The PH2 dataset contains 200 dermoscopic images of melanocytic lesions with +three diagnostic categories: common nevus, atypical nevus, and melanoma. 
+ +Dataset source +-------------- +Original dataset: + https://www.fc.up.pt/addi/ph2%20database.html (requires registration) + +Mirror used by ``download=True``: + https://github.com/vikaschouhan/PH2-dataset + (200 JPEGs in ``images/``, labels in ``PH2_simple_dataset.csv``) + + +Directory structure expected under ``root`` +------------------------------------------- +**Option A – downloaded via** ``download=True`` **(GitHub mirror)**:: + + / + images/ + IMD002.jpg + IMD003.jpg + ... + PH2_simple_dataset.csv # image_name, diagnosis + +**Option B – original PH2 release**:: + + / + PH2_dataset.xlsx # official Excel (12 header rows) + — OR — + PH2_dataset.csv # user-converted CSV + PH2_Dataset_images/ + IMD001/ + IMD001_Dermoscopic_Image/ + IMD001.bmp + ... + +After the first call, a ``ph2_metadata_pyhealth.csv`` file is written next to +the source files and reused on subsequent loads instead of re-parsing. +""" + +import logging +import os +import zipfile +from pathlib import Path +from typing import Optional + +import pandas as pd +import requests + +from pyhealth.datasets import BaseDataset +from pyhealth.datasets.isic2018 import _download_file, _extract_zip + +logger = logging.getLogger(__name__) + +_MIRROR_ZIP_URL = ( + "https://github.com/vikaschouhan/PH2-dataset/archive/refs/heads/master.zip" +) +_IMAGES_DIR = "images" +_SIMPLE_CSV = "PH2_simple_dataset.csv" + +# Canonical label strings stored in ph2_metadata_pyhealth.csv +_LABEL_MAP = { + "Common Nevus": "common_nevus", + "Atypical Nevus": "atypical_nevus", + "Melanoma": "melanoma", +} + + +class PH2Dataset(BaseDataset): + """Base image dataset for the PH2 dermoscopic image database. + + The PH2 dataset contains 200 dermoscopic images of melanocytic lesions + in three diagnostic categories: common nevus, atypical nevus, and melanoma. + + Args: + root: Path to the directory containing the PH2 source files. 
+ download: If ``True``, automatically download data from the GitHub + mirror when the image directory is absent. Requires ~30 MB. + Defaults to ``False``. + dataset_name: Optional override for the internal dataset identifier. + Defaults to ``"ph2"``. + config_path: Optional path to a custom YAML config. Defaults to the + bundled ``configs/ph2.yaml``. + cache_dir: Optional directory for litdata cache. + dev: If ``True``, load only the first 1 000 records (for quick testing). + + Raises: + FileNotFoundError: If required source files are missing and + ``download=False``. + requests.HTTPError: If ``download=True`` and the download fails. + + Examples: + >>> dataset = PH2Dataset(root="/path/to/ph2", download=True) + >>> from pyhealth.tasks import PH2MelanomaClassification + >>> samples = dataset.set_task(PH2MelanomaClassification()) + """ + + def __init__( + self, + root: str, + download: bool = False, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + cache_dir=None, + num_workers: int = 1, + dev: bool = False, + ) -> None: + if config_path is None: + config_path = str(Path(__file__).parent / "configs" / "ph2.yaml") + + if download: + self._download(root) + + self._verify_data(root) + + metadata_path = Path(root) / "ph2_metadata_pyhealth.csv" + if not metadata_path.exists(): + self._prepare_metadata(root) + + super().__init__( + root=root, + tables=["ph2"], + dataset_name=dataset_name or "ph2", + config_path=config_path, + cache_dir=cache_dir, + num_workers=num_workers, + dev=dev, + ) + + # ------------------------------------------------------------------ + # Download + # ------------------------------------------------------------------ + + def _download(self, root: str) -> None: + """Download PH2 data from the GitHub mirror. + + Downloads and extracts ~30 MB. Skips if ``images/`` already exists. + + Args: + root: Target directory. + + Raises: + requests.HTTPError: If the download request fails. 
+ """ + os.makedirs(root, exist_ok=True) + images_dest = os.path.join(root, _IMAGES_DIR) + + if os.path.isdir(images_dest): + logger.info("PH2 images already present at %s, skipping download.", images_dest) + return + + zip_path = os.path.join(root, "ph2_mirror.zip") + if not os.path.isfile(zip_path): + logger.info("Downloading PH2 mirror from GitHub: %s", _MIRROR_ZIP_URL) + _download_file(_MIRROR_ZIP_URL, zip_path) + + logger.info("Extracting PH2 archive to %s ...", root) + self._extract_mirror(zip_path, root) + os.remove(zip_path) + logger.info("PH2 data ready at %s", root) + + @staticmethod + def _extract_mirror(zip_path: str, dest: str) -> None: + """Extract GitHub mirror zip, flattening the top-level directory. + + The GitHub archive has a single top-level folder + (``PH2-dataset-master/``). This method moves ``images/`` and + ``PH2_simple_dataset.csv`` directly into *dest*. + """ + with zipfile.ZipFile(zip_path, "r") as zf: + members = zf.namelist() + # Determine top-level prefix (e.g. 
"PH2-dataset-master/") + prefix = members[0].split("/")[0] + "/" if members else "" + + for member in members: + # Strip the top-level prefix + rel = member[len(prefix):] + if not rel: + continue + target = os.path.join(dest, rel) + if member.endswith("/"): + os.makedirs(target, exist_ok=True) + else: + os.makedirs(os.path.dirname(target), exist_ok=True) + with zf.open(member) as src, open(target, "wb") as out: + out.write(src.read()) + + # ------------------------------------------------------------------ + # Validation + # ------------------------------------------------------------------ + + def _verify_data(self, root: str) -> None: + """Raise informative errors if required source files are missing.""" + if not os.path.exists(root): + raise FileNotFoundError(f"Dataset root does not exist: {root}") + + has_simple_csv = os.path.isfile(os.path.join(root, _SIMPLE_CSV)) + has_orig_xlsx = os.path.isfile(os.path.join(root, "PH2_dataset.xlsx")) + has_orig_csv = os.path.isfile(os.path.join(root, "PH2_dataset.csv")) + has_images = os.path.isdir(os.path.join(root, _IMAGES_DIR)) + has_bmp_images = os.path.isdir(os.path.join(root, "PH2_Dataset_images")) + + if not (has_simple_csv or has_orig_xlsx or has_orig_csv): + raise FileNotFoundError( + f"No PH2 metadata file found in {root}.\n" + "Expected one of: PH2_simple_dataset.csv, PH2_dataset.xlsx, PH2_dataset.csv\n" + "Pass download=True to fetch data automatically." + ) + + if not (has_images or has_bmp_images): + raise FileNotFoundError( + f"No PH2 image directory found in {root}.\n" + "Expected 'images/' (GitHub mirror) or 'PH2_Dataset_images/' (original).\n" + "Pass download=True to fetch data automatically." + ) + + # ------------------------------------------------------------------ + # Metadata preparation + # ------------------------------------------------------------------ + + def _prepare_metadata(self, root: str) -> None: + """Parse source files and write ``ph2_metadata_pyhealth.csv``. 
+ + Supports two source layouts: + + * **GitHub mirror**: ``PH2_simple_dataset.csv`` + ``images/IMDXXX.jpg`` + * **Original PH2**: ``PH2_dataset.xlsx`` / ``PH2_dataset.csv`` + + ``PH2_Dataset_images/IMDXXX/IMDXXX_Dermoscopic_Image/IMDXXX.bmp`` + + Args: + root: Directory containing the PH2 source files. + """ + logger.info("Processing PH2 metadata…") + root_path = Path(root) + + has_simple_csv = (root_path / _SIMPLE_CSV).exists() + has_orig_images = (root_path / "PH2_Dataset_images").exists() + + # Prefer original BMP format when PH2_Dataset_images/ is present, + # even if PH2_simple_dataset.csv is also in the directory. + if has_orig_images or not has_simple_csv: + df = self._load_original(root_path) + else: + df = self._load_simple_csv(root_path) + + output_path = root_path / "ph2_metadata_pyhealth.csv" + df[["image_id", "path", "diagnosis"]].to_csv(str(output_path), index=False) + logger.info("Saved PH2 metadata to %s (%d images)", output_path, len(df)) + + def _load_simple_csv(self, root: Path) -> pd.DataFrame: + """Load GitHub mirror format (flat JPEGs + PH2_simple_dataset.csv).""" + df = pd.read_csv(str(root / _SIMPLE_CSV)) + df = df.rename(columns={"image_name": "image_id"}) + + # Normalise diagnosis strings + df["diagnosis"] = df["diagnosis"].map( + lambda v: _LABEL_MAP.get(str(v).strip(), "Unknown") + ) + + image_dir = root / _IMAGES_DIR + + def _path(img_id: str) -> Optional[str]: + for ext in (".jpg", ".jpeg", ".png", ".bmp"): + p = image_dir / f"{img_id}{ext}" + if p.exists(): + return str(p) + return None + + df["path"] = df["image_id"].apply(_path) + df = df.dropna(subset=["path"]) + df = df[df["diagnosis"] != "Unknown"] + return df + + def _load_original(self, root: Path) -> pd.DataFrame: + """Load original PH2 format (nested BMPs + Excel/CSV).""" + xlsx = root / "PH2_dataset.xlsx" + csv = root / "PH2_dataset.csv" + + if xlsx.exists(): + raw = pd.read_excel(str(xlsx), header=12) + elif csv.exists(): + raw = pd.read_csv(str(csv)) + else: + raise 
FileNotFoundError( + f"Could not find PH2_dataset.xlsx or PH2_dataset.csv in {root}" + ) + + raw = raw.rename( + columns={ + "Image Name": "image_id", + "Common Nevus": "common_nevus", + "Atypical Nevus": "atypical_nevus", + "Melanoma": "melanoma", + } + ) + + def _diagnosis(row: pd.Series) -> str: + if row.get("melanoma") == "X": + return "melanoma" + if row.get("atypical_nevus") == "X": + return "atypical_nevus" + if row.get("common_nevus") == "X": + return "common_nevus" + return "Unknown" + + raw["diagnosis"] = raw.apply(_diagnosis, axis=1) + + image_root = root / "PH2_Dataset_images" + + def _path(img_id: str) -> Optional[str]: + p = image_root / img_id / f"{img_id}_Dermoscopic_Image" / f"{img_id}.bmp" + return str(p) if p.exists() else None + + raw["path"] = raw["image_id"].apply(_path) + raw = raw.dropna(subset=["path"]) + raw = raw[raw["diagnosis"] != "Unknown"] + return raw + + # ------------------------------------------------------------------ + # Default task + # ------------------------------------------------------------------ + + @property + def default_task(self): + """Return the default task (:class:`~pyhealth.tasks.PH2MelanomaClassification`).""" + from pyhealth.tasks import PH2MelanomaClassification + + return PH2MelanomaClassification() + diff --git a/pyhealth/datasets/utils.py b/pyhealth/datasets/utils.py index 24c87a1d5..84780ec33 100644 --- a/pyhealth/datasets/utils.py +++ b/pyhealth/datasets/utils.py @@ -329,7 +329,10 @@ def collate_fn_dict_with_padding(batch: List[dict]) -> dict: def get_dataloader( - dataset: litdata.StreamingDataset, batch_size: int, shuffle: bool = False + dataset: litdata.StreamingDataset, + batch_size: int, + shuffle: bool = False, + num_workers: int = 0, ) -> DataLoader: """Creates a DataLoader for a given dataset. @@ -337,6 +340,8 @@ def get_dataloader( dataset: The dataset to load data from. batch_size: The number of samples per batch. shuffle: Whether to shuffle the data at every epoch. 
+ num_workers: Number of subprocesses for data loading (default: 0, + meaning data is loaded in the main process). Returns: A DataLoader instance for the dataset. @@ -345,6 +350,7 @@ def get_dataloader( dataloader = DataLoader( dataset, batch_size=batch_size, + num_workers=num_workers, collate_fn=collate_fn_dict_with_padding, ) diff --git a/pyhealth/processors/__init__.py b/pyhealth/processors/__init__.py index b48072270..4fd5d329e 100644 --- a/pyhealth/processors/__init__.py +++ b/pyhealth/processors/__init__.py @@ -50,6 +50,7 @@ def get_processor(name: str): from .ignore_processor import IgnoreProcessor from .temporal_timeseries_processor import TemporalTimeseriesProcessor from .tuple_time_text_processor import TupleTimeTextProcessor +from .dermoscopic_image_processor import DermoscopicImageProcessor # Expose public API from .base_processor import ( @@ -79,4 +80,5 @@ def get_processor(name: str): "GraphProcessor", "AudioProcessor", "TupleTimeTextProcessor", + "DermoscopicImageProcessor", ] diff --git a/pyhealth/processors/dermoscopic_image_processor.py b/pyhealth/processors/dermoscopic_image_processor.py new file mode 100644 index 000000000..c48742288 --- /dev/null +++ b/pyhealth/processors/dermoscopic_image_processor.py @@ -0,0 +1,322 @@ +""" +Dermoscopic image processor for ISIC 2018 artifact experiments. + +Implements the 12 preprocessing modes from 'A Study of Artifacts on Melanoma Classification under +Diffusion-Based Perturbations', +adapted as a PyHealth :class:`~pyhealth.processors.base_processor.FeatureProcessor`. + +Modes +----- +``whole`` + Full image, no masking. +``lesion`` + Lesion region only (image multiplied by binary segmentation mask). +``background`` + Background region only (image multiplied by inverted mask). +``bbox`` + Full image with the lesion bounding-box blacked out. +``bbox70`` + Full image with an expanded bounding-box (≈70 % of image area) blacked out. +``bbox90`` + Same as ``bbox70`` but ≈90 % of image area. 
+``high_whole`` / ``high_lesion`` / ``high_background`` + High-pass–filtered version of the respective region. +``low_whole`` / ``low_lesion`` / ``low_background`` + Low-pass–filtered version of the respective region. +``blur_bg`` + Lesion region kept sharp; background blurred with a Gaussian low-pass + filter. Composite of the original lesion pixels and the blurred + background pixels. +``gray_whole`` + Full image converted to grayscale and broadcast back to 3 channels. + Removes all colour information while preserving spatial structure. +``whole_norm`` + Full image with per-channel min-max normalisation applied (each channel + stretched to [0, 255]). + +Filter backend +-------------- +Uses ``scipy.ndimage.gaussian_filter`` (truncate=4.0, σ=1 → effective ~9×9 +kernel). High-pass output is raw float residuals cast to uint8 with no +normalisation, faithfully replicating the reference implementation +(``dermoscopic_artifacts/datasets.py``). +""" + +import os +from pathlib import Path +from typing import Any, Union + +import numpy as np +import scipy.ndimage +import torchvision.transforms as transforms +from PIL import Image + +from .base_processor import FeatureProcessor + +#: All valid mode identifiers. +VALID_MODES = ( + "whole", + "lesion", + "background", + "bbox", + "bbox70", + "bbox90", + "high_whole", + "high_lesion", + "high_background", + "low_whole", + "low_lesion", + "low_background", + "blur_bg", + "gray_whole", + "whole_norm", +) + +#: Modes that operate on the full image and do not require a segmentation mask. +MASK_FREE_MODES = frozenset( + ("whole", "high_whole", "low_whole", "gray_whole", "whole_norm") +) + +_IMAGENET_MEAN = [0.485, 0.456, 0.406] +_IMAGENET_STD = [0.229, 0.224, 0.225] + + +def _high_pass_filter( + image: np.ndarray, + sigma: float = 1, + grayscale: bool = True, +) -> np.ndarray: + """Return a high-pass–filtered image (3-channel uint8 output). + + Args: + sigma: Gaussian sigma for the low-pass kernel. 
+ grayscale: If ``True`` (default), convert to BT.601 grayscale first, + apply HPF on the single channel, then stack to 3 channels — + matches ``high_pass_filter(image, grayscale=True)`` in the + reference. If ``False``, apply HPF independently on each RGB + channel. + """ + if grayscale: + image_gray = np.dot(image[..., :3], [0.2989, 0.587, 0.114]) + low_frequencies = scipy.ndimage.gaussian_filter(image_gray, sigma=sigma) + high_frequencies = image_gray - low_frequencies + out = np.stack([high_frequencies] * 3, axis=-1) + else: + out = np.empty(image.shape[:2] + (3,), dtype=np.float32) + for c in range(3): + ch = image[:, :, c].astype(np.float64) + low_frequencies = scipy.ndimage.gaussian_filter(ch, sigma=sigma) + out[:, :, c] = ch - low_frequencies + return out.astype(np.uint8) + + +def _low_pass_filter(image: np.ndarray, sigma: float = 1) -> np.ndarray: + """Return a Gaussian-blurred (low-pass) image (uint8 output).""" + return scipy.ndimage.gaussian_filter(image, sigma=sigma).astype(np.uint8) + + +class DermoscopicImageProcessor(FeatureProcessor): + """Load and preprocess a dermoscopy image according to a named mode. + + Mirrors the ``ISICDataset.__getitem__`` preprocessing logic from the + ``dermoscopic_artifacts`` experiment codebase so that PyHealth training + scripts reproduce the same pixel-level transformations. + + Args: + mask_dir: Directory containing ``*_segmentation.png`` masks. + Required for all modes except ``"whole"``. + mode: One of the valid mode strings (see module docstring). + Defaults to ``"whole"``. + image_size: Square resize target. Defaults to 224. + sigma: Standard deviation for the Gaussian filter used in ``high_*`` + and ``low_*`` modes. Defaults to ``1.0``. + + Raises: + ValueError: If *mode* is not in :data:`VALID_MODES`. 
+ """ + + def __init__( + self, + mask_dir: str = "", + mode: str = "whole", + image_size: int = 224, + sigma: float = 1.0, + high_grayscale: bool = True, + ) -> None: + if mode not in VALID_MODES: + raise ValueError( + f"Invalid mode '{mode}'. Choose from: {VALID_MODES}" + ) + self.mask_dir = mask_dir + self.mode = mode + self.image_size = image_size + self.sigma = sigma + self.high_grayscale = high_grayscale + + self.transform = transforms.Compose([ + transforms.Resize((image_size, image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD), + ]) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _load_image_and_mask(self, image_path: str): + """Return ``(image_rgb, mask_binary)`` as uint8 numpy arrays. + + For mask-free modes (``whole``, ``high_whole``, ``low_whole``, + ``gray_whole``, ``whole_norm``) the mask is a dummy all-ones array + and no mask file is read from disk. 
+ """ + try: + image = np.array(Image.open(image_path).convert("RGB")) + except Exception as exc: + raise FileNotFoundError(f"Image not found: {image_path}") from exc + + if self.mode in MASK_FREE_MODES: + mask = np.ones(image.shape[:2], dtype=np.uint8) + return image, mask + + img_name = os.path.basename(image_path) + stem = Path(img_name).stem + mask_path = os.path.join(self.mask_dir, f"{stem}_segmentation.png") + try: + mask = np.array(Image.open(mask_path).convert("L")) + except Exception as exc: + raise FileNotFoundError(f"Mask not found: {mask_path}") from exc + + if image.shape[:2] != mask.shape: + mask = np.array( + Image.fromarray(mask).resize( + (image.shape[1], image.shape[0]), Image.NEAREST + ) + ) + mask = (mask > 0).astype(np.uint8) + return image, mask + + def _apply_mode(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray: + """Apply the configured mode and return a uint8 RGB numpy array.""" + if self.mode == "whole": + return image + + if self.mode == "lesion": + return image * mask[:, :, np.newaxis] + + if self.mode == "background": + return image * (1 - mask[:, :, np.newaxis]) + + if self.mode in ("bbox", "bbox70", "bbox90"): + y_idxs, x_idxs = np.where(mask > 0) + if len(y_idxs) == 0: + return np.zeros_like(image) + + y_min, y_max = int(y_idxs.min()), int(y_idxs.max()) + x_min, x_max = int(x_idxs.min()), int(x_idxs.max()) + + if self.mode == "bbox": + out = image.copy() + out[y_min:y_max + 1, x_min:x_max + 1] = 0 + return out + + expand_ratio = 0.7 if self.mode == "bbox70" else 0.9 + img_h, img_w = image.shape[:2] + bbox_h = max(y_max - y_min, 1) + bbox_w = max(x_max - x_min, 1) + target_area = expand_ratio * img_h * img_w + cy, cx = (y_min + y_max) // 2, (x_min + x_max) // 2 + new_h = int(np.sqrt(target_area * bbox_h / bbox_w)) + new_w = int(np.sqrt(target_area * bbox_w / bbox_h)) + y_min = max(0, cy - new_h // 2) + y_max = min(img_h, cy + new_h // 2) + x_min = max(0, cx - new_w // 2) + x_max = min(img_w, cx + new_w // 2) + out = 
image.copy() + out[y_min:y_max + 1, x_min:x_max + 1] = 0 + return out + + # Blur background, keep lesion sharp — alpha blend across the boundary + if self.mode == "blur_bg": + blurred = _low_pass_filter(image.astype(np.uint8), sigma=self.sigma) + alpha = mask[:, :, np.newaxis].astype(np.float32) # 0.0 or 1.0 + sharp = image.astype(np.float32) + soft = blurred.astype(np.float32) + return (alpha * sharp + (1.0 - alpha) * soft).astype(np.uint8) + + # Grayscale whole image — broadcast single channel back to 3 + if self.mode == "gray_whole": + gray = np.dot(image[..., :3], [0.2989, 0.5870, 0.1140]).astype(np.uint8) + return np.stack([gray] * 3, axis=-1) + + # Per-channel min-max normalisation of whole image + if self.mode == "whole_norm": + out = np.empty_like(image) + for c in range(3): + ch = image[:, :, c].astype(np.float32) + mn, mx = ch.min(), ch.max() + out[:, :, c] = ((ch - mn) / (mx - mn) * 255).astype(np.uint8) if mx > mn else ch.astype(np.uint8) + return out + + # Frequency-filter modes + if "whole" in self.mode: + base = image + elif "lesion" in self.mode: + base = image * mask[:, :, np.newaxis] + else: # background + base = image * (1 - mask[:, :, np.newaxis]) + + if self.mode.startswith("high_"): + return _high_pass_filter( + base, + sigma=self.sigma, + grayscale=self.high_grayscale, + ) + # low_* modes + return _low_pass_filter(base, sigma=self.sigma) + + # ------------------------------------------------------------------ + # FeatureProcessor interface + # ------------------------------------------------------------------ + + def process(self, value: Union[str, Path]) -> Any: + """Load image at *value*, apply mode preprocessing, return tensor. + + Args: + value: Absolute path to the dermoscopy image. + + Returns: + Float32 tensor of shape ``(3, image_size, image_size)``, + normalised with ImageNet statistics. 
+ """ + image_path = str(value) + + if self.mode == "whole": + try: + image = np.array(Image.open(image_path).convert("RGB")) + except Exception as exc: + raise FileNotFoundError(f"Image not found: {image_path}") from exc + else: + image, mask = self._load_image_and_mask(image_path) + image = self._apply_mode(image, mask) + + pil_image = Image.fromarray(image.astype(np.uint8)) + return self.transform(pil_image) + + def is_token(self) -> bool: + return False + + def schema(self) -> tuple[str, ...]: + return ("value",) + + def dim(self) -> tuple[int, ...]: + return (3,) + + def spatial(self) -> tuple[bool, ...]: + return (False, True, True) + + def __repr__(self) -> str: + return ( + f"DermoscopicImageProcessor(mode={self.mode!r}, " + f"image_size={self.image_size}, mask_dir={self.mask_dir!r})" + ) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..def34f4cf 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -11,6 +11,8 @@ ) from .chestxray14_binary_classification import ChestXray14BinaryClassification from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification +from .isic2018_classification import ISIC2018Classification +from .isic2018_artifacts_classification import ISIC2018ArtifactsBinaryClassification from .covid19_cxr_classification import COVID19CXRClassification from .dka import DKAPredictionMIMIC4, T1DDKAPredictionMIMIC4 from .drug_recommendation import ( @@ -66,3 +68,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .ph2_melanoma_classification import PH2MelanomaClassification diff --git a/pyhealth/tasks/isic2018_artifacts_classification.py b/pyhealth/tasks/isic2018_artifacts_classification.py new file mode 100644 index 000000000..864864cd1 --- /dev/null +++ b/pyhealth/tasks/isic2018_artifacts_classification.py @@ -0,0 +1,85 @@ +""" +PyHealth task for binary melanoma classification using the ISIC 2018 artifact 
+annotation dataset (Bissoto et al. 2020). + +Dataset link: + https://challenge.isic-archive.com/data/#2018 + +Annotation source: + Bissoto et al. "Debiasing Skin Lesion Datasets and Models? Not So Fast" + ISIC Skin Image Analysis Workshop @ CVPR 2020 + https://github.com/alceubissoto/debiasing-skin + +License: + CC-0 (Public Domain) — https://creativecommons.org/public-domain/cc0/ +""" + +import logging +from typing import Dict, List + +from pyhealth.data import Event, Patient +from pyhealth.tasks import BaseTask + +logger = logging.getLogger(__name__) + + +class ISIC2018ArtifactsBinaryClassification(BaseTask): + """Binary melanoma classification task for the ISIC 2018 artifact dataset. + + Each dermoscopy image is mapped to a binary label (1 = malignant, + 0 = benign) as provided by the Bissoto et al. ``isic_bias.csv`` + annotation file. + + This task is designed for use with + :class:`~pyhealth.datasets.ISIC2018ArtifactsDataset`. The dataset's + ``set_task`` method automatically injects a + :class:`~pyhealth.processors.DermoscopicImageProcessor`, so the + ``mode`` (e.g. ``"whole"``, ``"lesion"``) is controlled at the dataset + level, not here. + + Attributes: + task_name (str): Unique task identifier. + input_schema (Dict[str, str]): Maps ``"image"`` to the ``"image"`` + processor type. + output_schema (Dict[str, str]): Maps ``"label"`` to ``"binary"``. + + Examples: + >>> from pyhealth.datasets import ISIC2018ArtifactsDataset + >>> from pyhealth.tasks import ISIC2018ArtifactsBinaryClassification + >>> dataset = ISIC2018ArtifactsDataset( + ... root="/path/to/data", + ... image_dir="ISIC2018_Task1-2_Training_Input", + ... mask_dir="ISIC2018_Task1_Training_GroundTruth", + ... mode="whole", + ... 
) + >>> task = ISIC2018ArtifactsBinaryClassification() + >>> samples = dataset.set_task(task) + """ + + task_name: str = "ISIC2018ArtifactsBinaryClassification" + input_schema: Dict[str, str] = {"image": "image"} + output_schema: Dict[str, str] = {"label": "binary"} + + def __call__(self, patient: Patient) -> List[Dict]: + """Generate binary classification samples for a single patient. + + Args: + patient: A :class:`~pyhealth.data.Patient` object containing at + least one ``"isic_artifacts"`` event. + + Returns: + A list with one dict per image:: + + [{"image": "/abs/path/to/ISIC_XXXXXXX.png", "label": 0}, ...] + """ + events: List[Event] = patient.get_events(event_type="isic_artifacts") + + samples = [] + for event in events: + samples.append( + { + "image": event["path"], + "label": int(event["label"]), + } + ) + return samples diff --git a/pyhealth/tasks/isic2018_classification.py b/pyhealth/tasks/isic2018_classification.py new file mode 100644 index 000000000..788959e21 --- /dev/null +++ b/pyhealth/tasks/isic2018_classification.py @@ -0,0 +1,89 @@ +""" +PyHealth task for multiclass classification using the ISIC 2018 dataset. + +Dataset link: + https://challenge.isic-archive.com/data/#2018 + +License: + CC-BY-NC 4.0 (https://creativecommons.org/licenses/by-nc/4.0/) + +Dataset paper: (please cite if you use this dataset) + [1] Noel Codella, Veronica Rotemberg, Philipp Tschandl, et al. "Skin Lesion + Analysis Toward Melanoma Detection 2018: A Challenge Hosted by the + International Skin Imaging Collaboration (ISIC)", 2018; + https://arxiv.org/abs/1902.03368 + + [2] Tschandl, P., Rosendahl, C. & Kittler, H. "The HAM10000 dataset, a large + collection of multi-source dermatoscopic images of common pigmented skin + lesions." Sci. Data 5, 180161 (2018). 
+ +Dataset paper link: + https://doi.org/10.1038/sdata.2018.161 +""" + +import logging +from typing import Dict, List + +from pyhealth.data import Event, Patient +from pyhealth.tasks import BaseTask + +logger = logging.getLogger(__name__) + + +class ISIC2018Classification(BaseTask): + """ + A PyHealth task class for multiclass skin lesion classification using the + ISIC 2018 Task 3 dataset. + + The task maps each dermoscopy image to one of seven skin lesion categories. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): The schema for the task input. + output_schema (Dict[str, str]): The schema for the task output. + + Examples: + >>> from pyhealth.datasets import ISIC2018Dataset + >>> from pyhealth.tasks import ISIC2018Classification + >>> dataset = ISIC2018Dataset(root="/path/to/isic2018") + >>> task = ISIC2018Classification() + >>> samples = dataset.set_task(task) + """ + + task_name: str = "ISIC2018Classification" + input_schema: Dict[str, str] = {"image": "image"} + output_schema: Dict[str, str] = {"label": "multiclass"} + + def __call__(self, patient: Patient) -> List[Dict]: + """ + Generates multiclass classification data samples for a single patient. + + Args: + patient (Patient): A patient object containing at least one + 'isic2018' event. + + Returns: + List[Dict]: A list containing a dictionary for each image with: + - 'image': path to the dermoscopy image. + - 'label': the skin lesion class label (str) from + ISIC2018Dataset.classes. 
+ """ + events: List[Event] = patient.get_events(event_type="isic2018") + + samples = [] + from pyhealth.datasets import ISIC2018Dataset # Avoid circular import + + for event in events: + label = next( + (cls for cls in ISIC2018Dataset.classes if float(event[cls])), + None, + ) + if label is not None: + samples.append( + { + "image": event["path"], + "label": label, + } + ) + + return samples diff --git a/pyhealth/tasks/ph2_melanoma_classification.py b/pyhealth/tasks/ph2_melanoma_classification.py new file mode 100644 index 000000000..cd1f42cfc --- /dev/null +++ b/pyhealth/tasks/ph2_melanoma_classification.py @@ -0,0 +1,64 @@ +""" +PyHealth task for multiclass melanoma classification using the PH2 dataset. + +Dataset source: + https://www.kaggle.com/datasets/spacesurfer/ph2-dataset +""" + +import logging +from typing import Dict, List + +from pyhealth.data import Event, Patient +from pyhealth.tasks import BaseTask + +logger = logging.getLogger(__name__) + + +class PH2MelanomaClassification(BaseTask): + """Multiclass lesion classification task for the PH2 dataset. + + Each dermoscopic image is classified into one of three categories: + ``"common_nevus"``, ``"atypical_nevus"``, or ``"melanoma"``. + + This task is designed for use with + :class:`~pyhealth.datasets.PH2Dataset`. + + Attributes: + task_name (str): Unique task identifier. + input_schema (Dict[str, str]): Maps ``"image"`` to the ``"image"`` + processor type. + output_schema (Dict[str, str]): Maps ``"label"`` to ``"multiclass"``. 
+ + Examples: + >>> from pyhealth.datasets import PH2Dataset + >>> from pyhealth.tasks import PH2MelanomaClassification + >>> dataset = PH2Dataset(root="/path/to/ph2") + >>> samples = dataset.set_task(PH2MelanomaClassification()) + """ + + task_name: str = "PH2MelanomaClassification" + input_schema: Dict[str, str] = {"image": "image"} + output_schema: Dict[str, str] = {"label": "multiclass"} + + def __call__(self, patient: Patient) -> List[Dict]: + """Generate multiclass classification samples for a single patient. + + Args: + patient: A :class:`~pyhealth.data.Patient` object containing at + least one ``"ph2"`` event. + + Returns: + A list with one dict per image:: + + [{"image": "/abs/path/img.bmp", "label": "melanoma"}, ...] + """ + events: List[Event] = patient.get_events(event_type="ph2") + + samples = [] + for event in events: + path = event["path"] + diagnosis = event["diagnosis"] + if not path or not diagnosis: + continue + samples.append({"image": path, "label": diagnosis}) + return samples diff --git a/test-resources/core/isic2018/ISIC2018_Task1-2_Training_Input/.gitkeep b/test-resources/core/isic2018/ISIC2018_Task1-2_Training_Input/.gitkeep new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/test-resources/core/isic2018/ISIC2018_Task1-2_Training_Input/.gitkeep @@ -0,0 +1 @@ + diff --git a/test-resources/core/isic2018/ISIC2018_Task1_Training_GroundTruth/.gitkeep b/test-resources/core/isic2018/ISIC2018_Task1_Training_GroundTruth/.gitkeep new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/test-resources/core/isic2018/ISIC2018_Task1_Training_GroundTruth/.gitkeep @@ -0,0 +1 @@ + diff --git a/test-resources/core/isic2018/ISIC2018_Task3_Training_GroundTruth.csv b/test-resources/core/isic2018/ISIC2018_Task3_Training_GroundTruth.csv new file mode 100644 index 000000000..b3ae5c58d --- /dev/null +++ b/test-resources/core/isic2018/ISIC2018_Task3_Training_GroundTruth.csv @@ -0,0 +1,11 @@ +image,MEL,NV,BCC,AKIEC,BKL,DF,VASC 
+ISIC_0024306,0.0,1.0,0.0,0.0,0.0,0.0,0.0 +ISIC_0024307,0.0,1.0,0.0,0.0,0.0,0.0,0.0 +ISIC_0024308,0.0,1.0,0.0,0.0,0.0,0.0,0.0 +ISIC_0024309,0.0,1.0,0.0,0.0,0.0,0.0,0.0 +ISIC_0024310,1.0,0.0,0.0,0.0,0.0,0.0,0.0 +ISIC_0024311,0.0,1.0,0.0,0.0,0.0,0.0,0.0 +ISIC_0024312,0.0,0.0,0.0,0.0,1.0,0.0,0.0 +ISIC_0024313,1.0,0.0,0.0,0.0,0.0,0.0,0.0 +ISIC_0024314,0.0,1.0,0.0,0.0,0.0,0.0,0.0 +ISIC_0024315,1.0,0.0,0.0,0.0,0.0,0.0,0.0 diff --git a/test-resources/core/isic2018/ISIC2018_Task3_Training_Input/.gitkeep b/test-resources/core/isic2018/ISIC2018_Task3_Training_Input/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/test-resources/core/isic2018/isic2018-config-pyhealth.yaml b/test-resources/core/isic2018/isic2018-config-pyhealth.yaml new file mode 100644 index 000000000..2e59bd345 --- /dev/null +++ b/test-resources/core/isic2018/isic2018-config-pyhealth.yaml @@ -0,0 +1,16 @@ +version: '1.0' +tables: + isic2018: + file_path: isic2018-metadata-pyhealth.csv + patient_id: patient_id + timestamp: null + attributes: + - path + - image_id + - mel + - nv + - bcc + - akiec + - bkl + - df + - vasc diff --git a/test-resources/core/isic2018_artifacts/images/.gitkeep b/test-resources/core/isic2018_artifacts/images/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/test-resources/core/isic2018_artifacts/isic_bias.csv b/test-resources/core/isic2018_artifacts/isic_bias.csv new file mode 100644 index 000000000..d49e21672 --- /dev/null +++ b/test-resources/core/isic2018_artifacts/isic_bias.csv @@ -0,0 +1,9 @@ +;image;dark_corner;hair;gel_border;gel_bubble;ruler;ink;patches;label;label_string +0;ISIC_0024306.png;0;0;0;0;1;0;0;0;benign +1;ISIC_0024307.png;0;1;0;0;0;0;0;1;malignant +2;ISIC_0024308.png;1;0;0;0;0;0;0;0;benign +3;ISIC_0024309.png;0;0;0;0;1;1;0;1;malignant +4;ISIC_0024310.png;0;0;0;0;0;1;0;0;benign +5;ISIC_0024311.png;0;0;0;0;0;0;0;1;malignant +6;ISIC_0024312.png;0;0;1;0;0;0;0;0;benign +7;ISIC_0024313.png;0;0;0;0;0;0;1;1;malignant \ No newline at 
end of file diff --git a/test-resources/core/isic2018_artifacts/masks/.gitkeep b/test-resources/core/isic2018_artifacts/masks/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/test-resources/core/ph2/PH2_dataset.csv b/test-resources/core/ph2/PH2_dataset.csv new file mode 100644 index 000000000..28326f8b2 --- /dev/null +++ b/test-resources/core/ph2/PH2_dataset.csv @@ -0,0 +1,6 @@ +Image Name,Common Nevus,Atypical Nevus,Melanoma +IMD003,X,, +IMD009,X,, +IMD002,,X, +IMD004,,X, +IMD058,,,X diff --git a/test-resources/core/ph2/PH2_simple_dataset.csv b/test-resources/core/ph2/PH2_simple_dataset.csv new file mode 100644 index 000000000..c32d5e265 --- /dev/null +++ b/test-resources/core/ph2/PH2_simple_dataset.csv @@ -0,0 +1,6 @@ +image_name,diagnosis +IMD003,Common Nevus +IMD009,Common Nevus +IMD002,Atypical Nevus +IMD004,Atypical Nevus +IMD058,Melanoma diff --git a/tests/core/test_isic2018.py b/tests/core/test_isic2018.py new file mode 100644 index 000000000..31438c2e8 --- /dev/null +++ b/tests/core/test_isic2018.py @@ -0,0 +1,490 @@ +""" +Unit tests for the ISIC2018Dataset and ISIC2018Classification classes. 
+""" +import os +import tempfile +import unittest +import zipfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +from PIL import Image + +import requests + +from pyhealth.datasets import ISIC2018Dataset +from pyhealth.tasks import ISIC2018Classification + + +class TestISIC2018Dataset(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.root = ( + Path(__file__).parent.parent.parent + / "test-resources" + / "core" + / "isic2018" + ) + cls.generate_fake_images() + cls.cache_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) + cls.dataset = ISIC2018Dataset(cls.root, cache_dir=cls.cache_dir.name) + + @classmethod + def tearDownClass(cls): + (cls.root / "isic2018-metadata-pyhealth.csv").unlink(missing_ok=True) + cls.delete_fake_images() + try: + cls.cache_dir.cleanup() + except Exception: + pass + cls.cache_dir = None + + @classmethod + def generate_fake_images(cls): + images_dir = cls.root / "ISIC2018_Task3_Training_Input" + with open(cls.root / "ISIC2018_Task3_Training_GroundTruth.csv", "r") as f: + lines = f.readlines() + + for line in lines[1:]: # Skip header row + image_id = line.split(",")[0].strip() + img = Image.fromarray( + np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) + ) + img.save(images_dir / f"{image_id}.jpg") + + @classmethod + def delete_fake_images(cls): + for jpg in (cls.root / "ISIC2018_Task3_Training_Input").glob("*.jpg"): + jpg.unlink() + + def test_stats(self): + self.dataset.stats() + + def test_num_patients(self): + # Each ISIC image is its own patient + self.assertEqual(len(self.dataset.unique_patient_ids), 10) + + def test_default_task(self): + self.assertIsInstance( + self.dataset.default_task, + ISIC2018Classification) + + def test_metadata_csv_created(self): + self.assertTrue( + (self.root / "isic2018-metadata-pyhealth.csv").exists() + ) + + def test_event_fields(self): + # Patient ID equals image ID for ISIC images + patient = 
self.dataset.get_patient("ISIC_0024307") + events = patient.get_events() + + self.assertEqual(len(events), 1) + self.assertEqual(events[0]["image_id"], "ISIC_0024307") + self.assertEqual(events[0]["mel"], "0.0") + self.assertEqual(events[0]["nv"], "1.0") + self.assertEqual(events[0]["bcc"], "0.0") + self.assertEqual(events[0]["akiec"], "0.0") + self.assertEqual(events[0]["bkl"], "0.0") + self.assertEqual(events[0]["df"], "0.0") + self.assertEqual(events[0]["vasc"], "0.0") + + def test_event_fields_mel(self): + patient = self.dataset.get_patient("ISIC_0024310") + events = patient.get_events() + + self.assertEqual(len(events), 1) + self.assertEqual(events[0]["mel"], "1.0") + self.assertEqual(events[0]["nv"], "0.0") + + def test_all_label_columns_present(self): + # All 7 class columns must be accessible on every event + for cls in ISIC2018Dataset.classes: + for pid in self.dataset.unique_patient_ids: + event = self.dataset.get_patient(pid).get_events()[0] + self.assertIn(cls, event) + + def test_image_paths_exist(self): + for pid in self.dataset.unique_patient_ids: + event = self.dataset.get_patient(pid).get_events()[0] + self.assertTrue(os.path.isfile(event["path"])) + + def test_verify_data_missing_root(self): + with self.assertRaises(FileNotFoundError): + ISIC2018Dataset(root="/nonexistent/path") + + def test_verify_data_missing_csv(self): + import tempfile + with tempfile.TemporaryDirectory() as tmpdir: + with self.assertRaises(FileNotFoundError): + ISIC2018Dataset(root=tmpdir) + + +class TestISIC2018Task12Dataset(unittest.TestCase): + """Tests for ISIC2018Dataset with task='task1_2'.""" + + @classmethod + def setUpClass(cls): + cls.root = ( + Path(__file__).parent.parent.parent + / "test-resources" + / "core" + / "isic2018" + ) + cls.images_dir = cls.root / "ISIC2018_Task1-2_Training_Input" + cls.masks_dir = cls.root / "ISIC2018_Task1_Training_GroundTruth" + cls._generated_images = [] + cls._generated_masks = [] + cls._generate_fake_data() + cls.cache_dir = 
tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
+        cls.dataset = ISIC2018Dataset(
+            str(cls.root), task="task1_2", cache_dir=cls.cache_dir.name
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        for p in cls._generated_images + cls._generated_masks:
+            p.unlink(missing_ok=True)
+        for f in ["isic2018-task12-metadata-pyhealth.csv",
+                  "isic2018-task12-config-pyhealth.yaml"]:
+            (cls.root / f).unlink(missing_ok=True)
+        try:
+            cls.cache_dir.cleanup()
+        except Exception:
+            pass
+
+    @classmethod
+    def _generate_fake_data(cls):
+        image_ids = [f"ISIC_002430{i}" for i in range(5)]
+        for img_id in image_ids:
+            img_path = cls.images_dir / f"{img_id}.jpg"
+            Image.fromarray(
+                np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
+            ).save(img_path)
+            cls._generated_images.append(img_path)
+
+            # Create a matching segmentation mask for every image
+            mask_path = cls.masks_dir / f"{img_id}_segmentation.png"
+            Image.fromarray(
+                np.random.randint(0, 2, (64, 64), dtype=np.uint8) * 255
+            ).save(mask_path)
+            cls._generated_masks.append(mask_path)
+
+    # ── Construction ──────────────────────────────────────────
+
+    def test_invalid_task_raises(self):
+        with self.assertRaises(ValueError):
+            ISIC2018Dataset(str(self.root), task="task99")
+
+    def test_num_patients(self):
+        self.assertEqual(len(self.dataset.unique_patient_ids), 5)
+
+    def test_default_task_is_none(self):
+        self.assertIsNone(self.dataset.default_task)
+
+    # ── Config / metadata files ───────────────────────────────
+
+    def test_metadata_csv_created(self):
+        self.assertTrue(
+            (self.root / "isic2018-task12-metadata-pyhealth.csv").exists())
+
+    def test_config_yaml_created(self):
+        self.assertTrue(
+            (self.root / "isic2018-task12-config-pyhealth.yaml").exists())
+
+    # ── Event attributes ──────────────────────────────────────
+
+    def test_event_has_image_id(self):
+        pid = sorted(self.dataset.unique_patient_ids)[0]
+        event = self.dataset.get_patient(pid).get_events()[0]
+        self.assertEqual(event["image_id"], pid)
+
+    
def test_event_has_path(self): + for pid in self.dataset.unique_patient_ids: + event = self.dataset.get_patient(pid).get_events()[0] + self.assertTrue(os.path.isfile(event["path"])) + + def test_event_has_mask_path(self): + # All 5 images have masks in our fixture + for pid in self.dataset.unique_patient_ids: + event = self.dataset.get_patient(pid).get_events()[0] + self.assertIn("mask_path", event) + self.assertIsNotNone(event["mask_path"]) + + def test_event_missing_mask_path_is_none(self): + # Temporarily remove one mask to simulate absence, re-index, check + pid = sorted(self.dataset.unique_patient_ids)[0] + mask_path = self.masks_dir / f"{pid}_segmentation.png" + tmp_path = self.masks_dir / f"{pid}_segmentation.png.bak" + mask_path.rename(tmp_path) + try: + with tempfile.TemporaryDirectory() as cache: + ds = ISIC2018Dataset( + str(self.root), task="task1_2", cache_dir=cache) + event = ds.get_patient(pid).get_events()[0] + self.assertIsNone(event["mask_path"]) + finally: + tmp_path.rename(mask_path) + + # ── Validation ──────────────────────────────────────────── + + def test_missing_mask_dir_raises(self): + with tempfile.TemporaryDirectory() as tmpdir: + # Copy only images, no masks dir + img_dir = Path(tmpdir) / "ISIC2018_Task1-2_Training_Input" + img_dir.mkdir() + Image.fromarray( + np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8) + ).save(img_dir / "ISIC_0000001.jpg") + with self.assertRaises(FileNotFoundError): + ISIC2018Dataset(tmpdir, task="task1_2") + + +class TestISIC2018VerifyDataEdgeCases(unittest.TestCase): + """Covers _verify_data raise paths not hit by the main test classes.""" + + def test_task3_missing_csv_raises(self): + """Image dir present but CSV absent → FileNotFoundError (line 267).""" + with tempfile.TemporaryDirectory() as tmpdir: + (Path(tmpdir) / "ISIC2018_Task3_Training_Input").mkdir() + with self.assertRaises(FileNotFoundError): + ISIC2018Dataset(tmpdir, task="task3") + + def test_task3_empty_image_dir_raises(self): + """Image 
dir present but no .jpg files → ValueError (line 273).""" + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + (root / "ISIC2018_Task3_Training_Input").mkdir() + (root / "ISIC2018_Task3_Training_GroundTruth.csv").write_text( + "image,MEL,NV,BCC,AKIEC,BKL,DF,VASC\n" + ) + with self.assertRaises(ValueError): + ISIC2018Dataset(tmpdir, task="task3") + + def test_task12_empty_image_dir_raises(self): + """Image dir present but no .jpg files → ValueError (line 322).""" + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + (root / "ISIC2018_Task1-2_Training_Input").mkdir() + (root / "ISIC2018_Task1_Training_GroundTruth").mkdir() + with self.assertRaises(ValueError): + ISIC2018Dataset(tmpdir, task="task1_2") + + +class TestISIC2018Download(unittest.TestCase): + """Covers _download_file, _extract_zip, and the _download dispatch method.""" + + # ── _download_file ──────────────────────────────────────── + + def test_download_file_writes_content(self): + from pyhealth.datasets.isic2018 import _download_file + + mock_resp = MagicMock() + mock_resp.__enter__ = lambda s: mock_resp + mock_resp.__exit__ = MagicMock(return_value=False) + mock_resp.headers = {"content-length": "5"} + mock_resp.iter_content.return_value = [b"hello"] + + with tempfile.TemporaryDirectory() as tmpdir: + dest = str(Path(tmpdir) / "out.bin") + with patch( + "pyhealth.datasets.isic2018.requests.get", return_value=mock_resp + ): + _download_file("http://example.com/file", dest) + self.assertEqual(Path(dest).read_bytes(), b"hello") + + def test_download_file_verifies_md5_checksum(self): + from pyhealth.datasets.isic2018 import _download_file + import hashlib + + mock_resp = MagicMock() + mock_resp.__enter__ = lambda s: mock_resp + mock_resp.__exit__ = MagicMock(return_value=False) + mock_resp.headers = {"content-length": "5"} + mock_resp.iter_content.return_value = [b"hello"] + + # Correct MD5 for "hello" + correct_md5 = hashlib.md5(b"hello").hexdigest() + + with 
tempfile.TemporaryDirectory() as tmpdir: + dest = str(Path(tmpdir) / "out.bin") + with patch( + "pyhealth.datasets.isic2018.requests.get", return_value=mock_resp + ): + _download_file( + "http://example.com/file", dest, expected_md5=correct_md5 + ) + self.assertEqual(Path(dest).read_bytes(), b"hello") + + def test_download_file_raises_on_md5_mismatch(self): + from pyhealth.datasets.isic2018 import _download_file + + mock_resp = MagicMock() + mock_resp.__enter__ = lambda s: mock_resp + mock_resp.__exit__ = MagicMock(return_value=False) + mock_resp.headers = {"content-length": "5"} + mock_resp.iter_content.return_value = [b"hello"] + + with tempfile.TemporaryDirectory() as tmpdir: + dest = str(Path(tmpdir) / "out.bin") + with patch( + "pyhealth.datasets.isic2018.requests.get", return_value=mock_resp + ): + with self.assertRaises(ValueError) as ctx: + _download_file( + "http://example.com/file", + dest, + expected_md5="wronghash123", + ) + self.assertIn("MD5 checksum mismatch", str(ctx.exception)) + self.assertFalse(Path(dest).exists()) # File should be removed + + def test_download_file_propagates_http_error(self): + from pyhealth.datasets.isic2018 import _download_file + + mock_resp = MagicMock() + mock_resp.__enter__ = lambda s: mock_resp + mock_resp.__exit__ = MagicMock(return_value=False) + mock_resp.raise_for_status.side_effect = requests.HTTPError("404") + + with tempfile.TemporaryDirectory() as tmpdir: + with patch( + "pyhealth.datasets.isic2018.requests.get", return_value=mock_resp + ): + with self.assertRaises(requests.HTTPError): + _download_file("http://example.com/bad", + str(Path(tmpdir) / "f")) + + # ── _extract_zip ────────────────────────────────────────── + + def test_extract_zip_normal(self): + from pyhealth.datasets.isic2018 import _extract_zip + + with tempfile.TemporaryDirectory() as tmpdir: + zip_path = str(Path(tmpdir) / "good.zip") + with zipfile.ZipFile(zip_path, "w") as z: + z.writestr("subdir/file.txt", "hello") + _extract_zip(zip_path, 
tmpdir) + self.assertTrue((Path(tmpdir) / "subdir" / "file.txt").exists()) + + def test_extract_zip_path_traversal_raises(self): + from pyhealth.datasets.isic2018 import _extract_zip + + with tempfile.TemporaryDirectory() as tmpdir: + zip_path = str(Path(tmpdir) / "evil.zip") + with zipfile.ZipFile(zip_path, "w") as z: + z.writestr("../evil.txt", "bad") + with self.assertRaises(ValueError): + _extract_zip(zip_path, tmpdir) + + # ── _download (skip when already present) ───────────────── + + def test_download_task3_skipped_when_present(self): + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + (root / "ISIC2018_Task3_Training_Input").mkdir() + (root / "ISIC2018_Task3_Training_GroundTruth.csv").write_text("x") + ds = ISIC2018Dataset.__new__(ISIC2018Dataset) + ds.task = "task3" + ds._image_dir = str(root / "ISIC2018_Task3_Training_Input") + ds._label_path = str( + root / "ISIC2018_Task3_Training_GroundTruth.csv") + with patch("pyhealth.datasets.isic2018._download_file") as mock_dl: + ds._download(str(root)) + mock_dl.assert_not_called() + + def test_download_task3_skipped_when_zip_present(self): + """If ZIP exists but not extracted, skip download and proceed to extract.""" + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + # Create fake ZIPs but don't extract + (root / "ISIC2018_Task3_Training_GroundTruth.zip").write_text("fake") + (root / "ISIC2018_Task3_Training_Input.zip").write_text("fake") + + ds = ISIC2018Dataset.__new__(ISIC2018Dataset) + ds.task = "task3" + ds._image_dir = str(root / "ISIC2018_Task3_Training_Input") + ds._label_path = str( + root / "ISIC2018_Task3_Training_GroundTruth.csv") + + with patch("pyhealth.datasets.isic2018._download_file") as mock_dl, \ + patch("pyhealth.datasets.isic2018._extract_zip"): + ds._download(str(root)) + # Should not call download if ZIP already exists + mock_dl.assert_not_called() + + def test_download_task12_skipped_when_present(self): + with tempfile.TemporaryDirectory() as 
tmpdir: + root = Path(tmpdir) + (root / "ISIC2018_Task1-2_Training_Input").mkdir() + (root / "ISIC2018_Task1_Training_GroundTruth").mkdir() + ds = ISIC2018Dataset.__new__(ISIC2018Dataset) + ds.task = "task1_2" + ds._image_dir = str(root / "ISIC2018_Task1-2_Training_Input") + ds._mask_dir = str(root / "ISIC2018_Task1_Training_GroundTruth") + with patch("pyhealth.datasets.isic2018._download_file") as mock_dl: + ds._download(str(root)) + mock_dl.assert_not_called() + + # ── _download (fetch when missing) ──────────────────────── + + def test_download_task3_fetches_both_when_missing(self): + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + ds = ISIC2018Dataset.__new__(ISIC2018Dataset) + ds.task = "task3" + ds._image_dir = str(root / "ISIC2018_Task3_Training_Input") + ds._label_path = str( + root / "ISIC2018_Task3_Training_GroundTruth.csv") + with patch("pyhealth.datasets.isic2018._download_file") as mock_dl, \ + patch("pyhealth.datasets.isic2018._extract_zip"), \ + patch("pyhealth.datasets.isic2018.os.remove"): + ds._download(str(root)) + self.assertEqual(mock_dl.call_count, 2) # labels zip + images zip + + def test_download_task3_skips_labels_when_present(self): + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + (root / "ISIC2018_Task3_Training_GroundTruth.csv").write_text("x") + ds = ISIC2018Dataset.__new__(ISIC2018Dataset) + ds.task = "task3" + ds._image_dir = str(root / "ISIC2018_Task3_Training_Input") + ds._label_path = str( + root / "ISIC2018_Task3_Training_GroundTruth.csv") + with patch("pyhealth.datasets.isic2018._download_file") as mock_dl, \ + patch("pyhealth.datasets.isic2018._extract_zip"), \ + patch("pyhealth.datasets.isic2018.os.remove"): + ds._download(str(root)) + self.assertEqual(mock_dl.call_count, 1) # only images zip + + def test_download_task12_fetches_both_when_missing(self): + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + ds = ISIC2018Dataset.__new__(ISIC2018Dataset) + ds.task = 
"task1_2" + ds._image_dir = str(root / "ISIC2018_Task1-2_Training_Input") + ds._mask_dir = str(root / "ISIC2018_Task1_Training_GroundTruth") + with patch("pyhealth.datasets.isic2018._download_file") as mock_dl, \ + patch("pyhealth.datasets.isic2018._extract_zip"), \ + patch("pyhealth.datasets.isic2018.os.remove"): + ds._download(str(root)) + self.assertEqual(mock_dl.call_count, 2) # images zip + masks zip + + def test_download_task12_skips_images_when_present(self): + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + (root / "ISIC2018_Task1-2_Training_Input").mkdir() + ds = ISIC2018Dataset.__new__(ISIC2018Dataset) + ds.task = "task1_2" + ds._image_dir = str(root / "ISIC2018_Task1-2_Training_Input") + ds._mask_dir = str(root / "ISIC2018_Task1_Training_GroundTruth") + with patch("pyhealth.datasets.isic2018._download_file") as mock_dl, \ + patch("pyhealth.datasets.isic2018._extract_zip"), \ + patch("pyhealth.datasets.isic2018.os.remove"): + ds._download(str(root)) + self.assertEqual(mock_dl.call_count, 1) # only masks zip + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_isic2018_artifacts.py b/tests/core/test_isic2018_artifacts.py new file mode 100644 index 000000000..bed99253e --- /dev/null +++ b/tests/core/test_isic2018_artifacts.py @@ -0,0 +1,351 @@ +""" +Unit tests for ISIC2018ArtifactsDataset and associated task classes. 
+ +Fixture layout (test-resources/core/isic2018_artifacts/): + isic_bias.csv — 8 images, 4 melanoma / 4 non-melanoma, semicolon-delimited + images/ — fake PNGs generated in setUpClass, deleted in tearDownClass + masks/ — fake segmentation PNGs generated in setUpClass + +Fixture summary +--------------- +image label ruler hair dark_corner gel_border gel_bubble ink patches +ISIC_0024306.png 0 1 0 0 0 0 0 0 +ISIC_0024307.png 1 0 1 0 0 0 0 0 +ISIC_0024308.png 0 0 0 1 0 0 0 0 +ISIC_0024309.png 1 1 0 0 0 1 1 0 +ISIC_0024310.png 0 0 0 0 0 0 1 0 +ISIC_0024311.png 1 0 0 0 0 0 0 0 +ISIC_0024312.png 0 0 0 0 1 0 0 0 +ISIC_0024313.png 1 0 0 0 0 0 0 1 +""" + +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +from PIL import Image + +from pyhealth.datasets import ISIC2018ArtifactsDataset +from pyhealth.datasets.isic2018_artifacts import ARTIFACT_LABELS + +_RESOURCES = ( + Path(__file__).parent.parent.parent + / "test-resources" + / "core" + / "isic2018_artifacts" +) + +_IMAGE_NAMES = [f"ISIC_{24306 + i:07d}.png" for i in range(8)] + + +def _make_fake_images(image_dir: Path) -> None: + image_dir.mkdir(exist_ok=True) + for name in _IMAGE_NAMES: + arr = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8) + Image.fromarray(arr).save(image_dir / name) + + +def _make_fake_masks(mask_dir: Path) -> None: + mask_dir.mkdir(exist_ok=True) + for name in _IMAGE_NAMES: + stem = Path(name).stem + arr = np.zeros((64, 64), dtype=np.uint8) + arr[16:48, 16:48] = 255 + Image.fromarray(arr).save(mask_dir / f"{stem}_segmentation.png") + + +def _delete_fake_images(image_dir: Path) -> None: + for png in image_dir.glob("*.png"): + png.unlink() + + +def _delete_fake_masks(mask_dir: Path) -> None: + for png in mask_dir.glob("*.png"): + png.unlink() + + +class TestISIC2018ArtifactsDataset(unittest.TestCase): + """Tests for ISIC2018ArtifactsDataset loading, indexing, and validation.""" + + @classmethod + def 
setUpClass(cls): + cls.root = _RESOURCES + _make_fake_images(cls.root / "images") + _make_fake_masks(cls.root / "masks") + cls.cache_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) + cls.dataset = ISIC2018ArtifactsDataset( + root=str(cls.root), + image_dir="images", + mask_dir="masks", + cache_dir=cls.cache_dir.name, + ) + + @classmethod + def tearDownClass(cls): + _delete_fake_images(cls.root / "images") + _delete_fake_masks(cls.root / "masks") + (cls.root / "isic-artifact-metadata-pyhealth.csv").unlink(missing_ok=True) + (cls.root / "isic-artifact-config-pyhealth.yaml").unlink(missing_ok=True) + try: + cls.cache_dir.cleanup() + except Exception: + pass + + # ── Dataset-level ──────────────────────────────────────────────────────── + + def test_stats(self): + self.dataset.stats() + + def test_num_patients(self): + # One patient per image + self.assertEqual(len(self.dataset.unique_patient_ids), 8) + + def test_default_task_is_none(self): + self.assertIsNone(self.dataset.default_task) + + def test_metadata_csv_created(self): + self.assertTrue( + (self.root / "isic-artifact-metadata-pyhealth.csv").exists()) + + def test_config_yaml_created(self): + self.assertTrue( + (self.root / "isic-artifact-config-pyhealth.yaml").exists()) + + def test_artifact_labels_class_attribute(self): + self.assertEqual( + ISIC2018ArtifactsDataset.artifact_labels, + ARTIFACT_LABELS) + + # ── Event field access ─────────────────────────────────────────────────── + + def test_event_fields_ruler_image(self): + events = self.dataset.get_patient("ISIC_0024306").get_events( + event_type="isic_artifacts" + ) + self.assertEqual(len(events), 1) + event = events[0] + self.assertEqual(event["image_id"], "ISIC_0024306") + self.assertEqual(event["ruler"], "1") + self.assertEqual(event["label"], "0") + + def test_event_fields_test_split(self): + event = self.dataset.get_patient("ISIC_0024311").get_events()[0] + self.assertEqual(event["label"], "1") + + def 
test_all_artifact_columns_present_on_every_patient(self):
+        for pid in self.dataset.unique_patient_ids:
+            event = self.dataset.get_patient(pid).get_events()[0]
+            for col in ARTIFACT_LABELS:
+                self.assertIn(col, event, f"Missing '{col}' for patient {pid}")
+
+    def test_image_paths_exist(self):
+        import os
+        for pid in self.dataset.unique_patient_ids:
+            event = self.dataset.get_patient(pid).get_events()[0]
+            self.assertTrue(os.path.isfile(event["path"]))
+
+    # ── Validation errors ──────────────────────────────────────────────────
+
+    def test_nonexistent_root_raises(self):
+        with self.assertRaises(FileNotFoundError):
+            ISIC2018ArtifactsDataset(root="/nonexistent/path/xyz")
+
+    def test_missing_csv_raises(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            Path(tmpdir, "images").mkdir()
+            Path(tmpdir, "masks").mkdir()
+            with self.assertRaises(FileNotFoundError):
+                ISIC2018ArtifactsDataset(root=tmpdir)
+
+    def test_missing_image_dir_raises(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            import shutil
+            shutil.copy(self.root / "isic_bias.csv", tmpdir)
+            Path(tmpdir, "masks").mkdir()
+            with self.assertRaises(FileNotFoundError):
+                ISIC2018ArtifactsDataset(root=tmpdir, image_dir="nonexistent")
+
+    def test_missing_mask_dir_raises(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            import shutil
+            shutil.copy(self.root / "isic_bias.csv", tmpdir)
+            Path(tmpdir, "images").mkdir()
+            with self.assertRaises(FileNotFoundError):
+                ISIC2018ArtifactsDataset(root=tmpdir, mask_dir="nonexistent")
+
+    def test_download_raises_for_custom_csv(self):
+        with self.assertRaises(ValueError):
+            ISIC2018ArtifactsDataset(
+                root=str(self.root),
+                annotations_csv="custom.csv",
+                download=True,
+            )
+
+    # ── Download behaviour ─────────────────────────────────────────────────
+
+    def test_download_skipped_when_csv_already_present(self):
+        """_download_bias_csv should not call requests.get if CSV exists."""
+        with patch("pyhealth.datasets.isic2018_artifacts.requests.get") as 
mock_get: + self.dataset._download_bias_csv(str(self.root)) + mock_get.assert_not_called() + + def test_download_fetches_csv_when_absent(self): + """_download_bias_csv should write the CSV when it is missing.""" + from pyhealth.datasets.isic2018_artifacts import _BIAS_CSV + + csv_bytes = (self.root / "isic_bias.csv").read_bytes() + mock_resp = MagicMock() + mock_resp.content = csv_bytes + mock_resp.raise_for_status = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + # Instantiate without calling __init__ to test the method in + # isolation + obj = ISIC2018ArtifactsDataset.__new__(ISIC2018ArtifactsDataset) + obj.annotations_csv = _BIAS_CSV + obj._bias_csv_path = str(Path(tmpdir) / "isic_bias.csv") + + with patch( + "pyhealth.datasets.isic2018_artifacts.requests.get", + return_value=mock_resp, + ): + obj._download_bias_csv(tmpdir) + + mock_resp.raise_for_status.assert_called_once() + self.assertTrue(Path(tmpdir, "isic_bias.csv").exists()) + + def test_download_images_skipped_when_dirs_exist(self): + """_download_images should not call requests.get if extracted dirs exist.""" + from pyhealth.datasets.isic2018_artifacts import _IMAGES_DIR, _MASKS_DIR + + with tempfile.TemporaryDirectory() as tmpdir: + os.makedirs(os.path.join(tmpdir, _IMAGES_DIR)) + os.makedirs(os.path.join(tmpdir, _MASKS_DIR)) + obj = ISIC2018ArtifactsDataset.__new__(ISIC2018ArtifactsDataset) + mock_path = "pyhealth.datasets.isic2018_artifacts._download_file" + with patch(mock_path) as mock_dl: + obj._download_images(tmpdir) + mock_dl.assert_not_called() + + + + +class TestISIC2018ArtifactsDownload(unittest.TestCase): + """Covers _download_bias_csv fetch, _download_images branches, and + the no-matching-images ValueError — the remaining coverage gaps.""" + + # ── _download_bias_csv fetch ──────────────────────────────────────────── + + def test_download_bias_csv_fetches_when_absent(self): + from pyhealth.datasets.isic2018_artifacts import _BIAS_CSV + + csv_bytes = (_RESOURCES / 
"isic_bias.csv").read_bytes() + mock_resp = MagicMock() + mock_resp.content = csv_bytes + mock_resp.raise_for_status = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + obj = ISIC2018ArtifactsDataset.__new__(ISIC2018ArtifactsDataset) + obj.annotations_csv = _BIAS_CSV + obj._bias_csv_path = str(Path(tmpdir) / _BIAS_CSV) + + with patch("pyhealth.datasets.isic2018_artifacts.requests.get", + return_value=mock_resp): + obj._download_bias_csv(tmpdir) + + mock_resp.raise_for_status.assert_called_once() + self.assertTrue(Path(tmpdir, _BIAS_CSV).exists()) + + # ── _download_images branches ─────────────────────────────────────────── + + def test_download_images_fetches_when_dirs_absent(self): + """Both image and mask dirs missing → _download_file called twice.""" + + with tempfile.TemporaryDirectory() as tmpdir: + obj = ISIC2018ArtifactsDataset.__new__(ISIC2018ArtifactsDataset) + dl_path = "pyhealth.datasets.isic2018_artifacts._download_file" + zip_path = "pyhealth.datasets.isic2018_artifacts._extract_zip" + rm_path = "pyhealth.datasets.isic2018_artifacts.os.remove" + with patch(dl_path) as mock_dl, patch(zip_path), patch(rm_path): + obj._download_images(tmpdir) + self.assertEqual(mock_dl.call_count, 2) + + def test_download_images_skips_images_when_present(self): + """Image dir already present → only mask zip fetched.""" + from pyhealth.datasets.isic2018_artifacts import _IMAGES_DIR + + with tempfile.TemporaryDirectory() as tmpdir: + os.makedirs(os.path.join(tmpdir, _IMAGES_DIR)) + obj = ISIC2018ArtifactsDataset.__new__(ISIC2018ArtifactsDataset) + dl_path = "pyhealth.datasets.isic2018_artifacts._download_file" + zip_path = "pyhealth.datasets.isic2018_artifacts._extract_zip" + rm_path = "pyhealth.datasets.isic2018_artifacts.os.remove" + with patch(dl_path) as mock_dl, patch(zip_path), patch(rm_path): + obj._download_images(tmpdir) + self.assertEqual(mock_dl.call_count, 1) + + def test_download_images_skips_masks_when_present(self): + """Mask dir already 
present → only image zip fetched.""" + from pyhealth.datasets.isic2018_artifacts import _MASKS_DIR + + with tempfile.TemporaryDirectory() as tmpdir: + os.makedirs(os.path.join(tmpdir, _MASKS_DIR)) + obj = ISIC2018ArtifactsDataset.__new__(ISIC2018ArtifactsDataset) + dl_path = "pyhealth.datasets.isic2018_artifacts._download_file" + zip_path = "pyhealth.datasets.isic2018_artifacts._extract_zip" + rm_path = "pyhealth.datasets.isic2018_artifacts.os.remove" + with patch(dl_path) as mock_dl, patch(zip_path), patch(rm_path): + obj._download_images(tmpdir) + self.assertEqual(mock_dl.call_count, 1) + + # ── no-matching-images ValueError ─────────────────────────────────────── + + def test_no_matching_images_raises(self): + """CSV present but no image files match → ValueError.""" + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + (root / "images").mkdir() + (root / "masks").mkdir() + # CSV references images that don't exist in the images dir + (_RESOURCES / "isic_bias.csv").read_bytes() + import shutil + shutil.copy(_RESOURCES / "isic_bias.csv", root / "isic_bias.csv") + with self.assertRaises(ValueError): + ISIC2018ArtifactsDataset( + root=str(root), + image_dir="images", + mask_dir="masks", + ) + + # ── download=True constructor path ────────────────────────────────────── + + def test_constructor_download_true_calls_both_downloads(self): + """download=True triggers both _download_bias_csv and _download_images.""" + with tempfile.TemporaryDirectory() as tmpdir: + with patch.object( + ISIC2018ArtifactsDataset, "_download_bias_csv" + ) as mock_csv, patch.object( + ISIC2018ArtifactsDataset, "_download_images" + ) as mock_img, patch.object( + ISIC2018ArtifactsDataset, "_verify_data" + ), patch.object( + ISIC2018ArtifactsDataset, "_index_data", return_value=None + ), patch("pyhealth.datasets.base_dataset.BaseDataset.__init__"): + obj = ISIC2018ArtifactsDataset.__new__( + ISIC2018ArtifactsDataset) + obj.annotations_csv = "isic_bias.csv" + obj._image_dir = 
tmpdir
+                obj._mask_dir = tmpdir
+                obj._bias_csv_path = str(Path(tmpdir) / "isic_bias.csv")
+                # Call __init__ manually with download=True
+                ISIC2018ArtifactsDataset.__init__(
+                    obj, root=tmpdir, download=True)
+                mock_csv.assert_called_once_with(tmpdir)
+                mock_img.assert_called_once_with(tmpdir)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/core/test_ph2.py b/tests/core/test_ph2.py
new file mode 100644
index 000000000..86f37d318
--- /dev/null
+++ b/tests/core/test_ph2.py
@@ -0,0 +1,313 @@
+"""
+Unit tests for PH2Dataset and PH2MelanomaClassification.
+
+Fixture layout (test-resources/core/ph2/):
+    PH2_dataset.csv — 5 images, original format: 2 common_nevus, 2 atypical_nevus, 1 melanoma
+    PH2_Dataset_images/ — fake BMPs generated in setUpClass, deleted in tearDownClass
+    PH2_simple_dataset.csv — same 5 images, GitHub mirror format
+
+Image IDs: IMD002, IMD003, IMD004, IMD009, IMD058
+BMP structure:
+    PH2_Dataset_images/IMDXXX/IMDXXX_Dermoscopic_Image/IMDXXX.bmp
+JPEG structure (mirror):
+    images/IMDXXX.jpg
+"""
+
+import io
+import os
+import shutil
+import tempfile
+import unittest
+import zipfile
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+from PIL import Image
+
+from pyhealth.datasets import PH2Dataset
+from pyhealth.tasks import PH2MelanomaClassification
+
+_RESOURCES = Path(__file__).parent.parent.parent / "test-resources" / "core" / "ph2"
+
+_IMAGE_IDS = ["IMD003", "IMD009", "IMD002", "IMD004", "IMD058"]
+_DIAGNOSES = {
+    "IMD003": "common_nevus",
+    "IMD009": "common_nevus",
+    "IMD002": "atypical_nevus",
+    "IMD004": "atypical_nevus",
+    "IMD058": "melanoma",
+}
+
+
+def _make_fake_bmp_images(root: Path) -> None:
+    """Create nested BMP structure (original PH2 format)."""
+    for img_id in _IMAGE_IDS:
+        img_dir = root / "PH2_Dataset_images" / img_id / f"{img_id}_Dermoscopic_Image"
+        img_dir.mkdir(parents=True, exist_ok=True)
+        arr = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
+        
Image.fromarray(arr).save(img_dir / f"{img_id}.bmp") + + +def _make_fake_jpg_images(root: Path) -> None: + """Create flat JPEG structure (GitHub mirror format).""" + (root / "images").mkdir(exist_ok=True) + for img_id in _IMAGE_IDS: + arr = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8) + Image.fromarray(arr).save(root / "images" / f"{img_id}.jpg") + + +def _delete_bmp_images(root: Path) -> None: + img_root = root / "PH2_Dataset_images" + if img_root.exists(): + shutil.rmtree(img_root) + + +def _delete_jpg_images(root: Path) -> None: + img_root = root / "images" + if img_root.exists(): + shutil.rmtree(img_root) + + +def _make_mirror_zip(root: Path) -> str: + """Build a fake GitHub mirror zip archive for download tests.""" + buf = io.BytesIO() + prefix = "PH2-dataset-master/" + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr(prefix, "") # top-level dir entry + zf.writestr(prefix + "images/", "") + + # Write simple CSV + csv_lines = "image_name,diagnosis\n" + for img_id, diag in _DIAGNOSES.items(): + label = { + "common_nevus": "Common Nevus", + "atypical_nevus": "Atypical Nevus", + "melanoma": "Melanoma", + }[diag] + csv_lines += f"{img_id},{label}\n" + zf.writestr(prefix + "PH2_simple_dataset.csv", csv_lines) + + # Write fake JPEGs + for img_id in _IMAGE_IDS: + arr = np.zeros((8, 8, 3), dtype=np.uint8) + img_buf = io.BytesIO() + Image.fromarray(arr).save(img_buf, format="JPEG") + zf.writestr(prefix + f"images/{img_id}.jpg", img_buf.getvalue()) + + zip_path = str(root / "ph2_mirror.zip") + with open(zip_path, "wb") as f: + f.write(buf.getvalue()) + return zip_path + + +class TestPH2Dataset(unittest.TestCase): + """Tests for PH2Dataset loading and indexing (original BMP format).""" + + @classmethod + def setUpClass(cls): + cls.root = _RESOURCES + _make_fake_bmp_images(cls.root) + (cls.root / "ph2_metadata_pyhealth.csv").unlink(missing_ok=True) + cls.cache_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) + cls.dataset = 
PH2Dataset(root=str(cls.root), cache_dir=cls.cache_dir.name) + + @classmethod + def tearDownClass(cls): + _delete_bmp_images(cls.root) + (cls.root / "ph2_metadata_pyhealth.csv").unlink(missing_ok=True) + (cls.root / "ph2-config-pyhealth.yaml").unlink(missing_ok=True) + try: + cls.cache_dir.cleanup() + except Exception: + pass + + def test_num_patients(self): + self.assertEqual(len(self.dataset.unique_patient_ids), 5) + + def test_metadata_csv_created(self): + self.assertTrue((self.root / "ph2_metadata_pyhealth.csv").exists()) + + def test_default_task_is_classification(self): + self.assertIsInstance(self.dataset.default_task, PH2MelanomaClassification) + + def test_event_has_path(self): + pid = next(iter(self.dataset.unique_patient_ids)) + event = self.dataset.get_patient(pid).get_events(event_type="ph2")[0] + self.assertTrue(os.path.isfile(event["path"])) + + def test_event_has_diagnosis(self): + for img_id, expected in _DIAGNOSES.items(): + event = self.dataset.get_patient(img_id).get_events(event_type="ph2")[0] + self.assertEqual(event["diagnosis"], expected) + + def test_image_paths_exist(self): + for pid in self.dataset.unique_patient_ids: + event = self.dataset.get_patient(pid).get_events(event_type="ph2")[0] + self.assertTrue(os.path.isfile(event["path"])) + + def test_missing_root_raises(self): + with self.assertRaises(FileNotFoundError): + PH2Dataset(root="/nonexistent/path/xyz") + + def test_missing_source_file_raises(self): + with tempfile.TemporaryDirectory() as tmpdir: + Path(tmpdir, "images").mkdir() + with self.assertRaises(FileNotFoundError): + PH2Dataset(root=tmpdir) + + def test_missing_images_dir_raises(self): + with tempfile.TemporaryDirectory() as tmpdir: + shutil.copy(_RESOURCES / "PH2_dataset.csv", tmpdir) + with self.assertRaises(FileNotFoundError): + PH2Dataset(root=tmpdir) + + +class TestPH2PrepareMetadata(unittest.TestCase): + """Tests for PH2Dataset._prepare_metadata from raw CSV sources.""" + + def 
test_prepare_metadata_from_original_csv(self): + """_prepare_metadata reads PH2_dataset.csv (original format).""" + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + shutil.copy(_RESOURCES / "PH2_dataset.csv", root / "PH2_dataset.csv") + _make_fake_bmp_images(root) + + obj = PH2Dataset.__new__(PH2Dataset) + obj._prepare_metadata(str(root)) + + out_csv = root / "ph2_metadata_pyhealth.csv" + self.assertTrue(out_csv.exists()) + + df = __import__("pandas").read_csv(out_csv) + self.assertEqual(len(df), 5) + self.assertIn("image_id", df.columns) + self.assertIn("path", df.columns) + self.assertIn("diagnosis", df.columns) + self.assertCountEqual( + df["diagnosis"].tolist(), + ["common_nevus", "common_nevus", "atypical_nevus", "atypical_nevus", "melanoma"], + ) + _delete_bmp_images(root) + + def test_prepare_metadata_from_simple_csv(self): + """_prepare_metadata reads PH2_simple_dataset.csv (GitHub mirror format).""" + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + shutil.copy(_RESOURCES / "PH2_simple_dataset.csv", root / "PH2_simple_dataset.csv") + _make_fake_jpg_images(root) + + obj = PH2Dataset.__new__(PH2Dataset) + obj._prepare_metadata(str(root)) + + out_csv = root / "ph2_metadata_pyhealth.csv" + df = __import__("pandas").read_csv(out_csv) + self.assertEqual(len(df), 5) + self.assertCountEqual( + df["diagnosis"].tolist(), + ["common_nevus", "common_nevus", "atypical_nevus", "atypical_nevus", "melanoma"], + ) + _delete_jpg_images(root) + + def test_missing_source_raises(self): + with tempfile.TemporaryDirectory() as tmpdir: + obj = PH2Dataset.__new__(PH2Dataset) + with self.assertRaises(FileNotFoundError): + obj._prepare_metadata(tmpdir) + + +class TestPH2Download(unittest.TestCase): + """Tests for PH2Dataset download functionality.""" + + def test_download_skipped_when_images_present(self): + """_download should not call _download_file if images/ already exists.""" + with tempfile.TemporaryDirectory() as tmpdir: + 
(Path(tmpdir) / "images").mkdir() + obj = PH2Dataset.__new__(PH2Dataset) + with patch("pyhealth.datasets.ph2._download_file") as mock_dl: + obj._download(tmpdir) + mock_dl.assert_not_called() + + def test_download_fetches_and_extracts(self): + """_download downloads zip and extracts images/ + PH2_simple_dataset.csv.""" + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + zip_path = _make_mirror_zip(root) + + obj = PH2Dataset.__new__(PH2Dataset) + with patch("pyhealth.datasets.ph2._download_file", + side_effect=lambda url, dest, **kw: shutil.copy(zip_path, dest)): + obj._download(tmpdir) + + self.assertTrue((root / "images").is_dir()) + self.assertTrue((root / "PH2_simple_dataset.csv").exists()) + # All 5 fake images extracted + jpgs = list((root / "images").glob("*.jpg")) + self.assertEqual(len(jpgs), 5) + + def test_download_true_loads_dataset(self): + """download=True followed by dataset loading works end-to-end.""" + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + zip_path = _make_mirror_zip(root) + + with patch("pyhealth.datasets.ph2._download_file", + side_effect=lambda url, dest, **kw: shutil.copy(zip_path, dest)): + ds = PH2Dataset(root=tmpdir, download=True) + + self.assertEqual(len(ds.unique_patient_ids), 5) + + +class TestPH2MelanomaClassification(unittest.TestCase): + """Tests for PH2MelanomaClassification task logic.""" + + @classmethod + def setUpClass(cls): + _make_fake_bmp_images(_RESOURCES) + (_RESOURCES / "ph2_metadata_pyhealth.csv").unlink(missing_ok=True) + cls.cache_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) + cls.dataset = PH2Dataset(root=str(_RESOURCES), cache_dir=cls.cache_dir.name) + cls.task = PH2MelanomaClassification() + + @classmethod + def tearDownClass(cls): + _delete_bmp_images(_RESOURCES) + (_RESOURCES / "ph2_metadata_pyhealth.csv").unlink(missing_ok=True) + (_RESOURCES / "ph2-config-pyhealth.yaml").unlink(missing_ok=True) + try: + cls.cache_dir.cleanup() + except Exception: 
+ pass + + def test_task_name(self): + self.assertEqual(self.task.task_name, "PH2MelanomaClassification") + + def test_output_schema_is_multiclass(self): + self.assertEqual(self.task.output_schema["label"], "multiclass") + + def test_call_returns_sample_per_image(self): + pid = next(iter(self.dataset.unique_patient_ids)) + patient = self.dataset.get_patient(pid) + samples = self.task(patient) + self.assertEqual(len(samples), 1) + self.assertIn("image", samples[0]) + self.assertIn("label", samples[0]) + + def test_correct_labels(self): + for img_id, expected_diag in _DIAGNOSES.items(): + patient = self.dataset.get_patient(img_id) + samples = self.task(patient) + self.assertEqual(samples[0]["label"], expected_diag) + + def test_melanoma_count(self): + mel_count = sum( + 1 + for pid in self.dataset.unique_patient_ids + for s in self.task(self.dataset.get_patient(pid)) + if s["label"] == "melanoma" + ) + self.assertEqual(mel_count, 1) + + +if __name__ == "__main__": + unittest.main()