diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..3b5e57901 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -204,5 +204,6 @@ API Reference models/pyhealth.models.VisionEmbeddingModel models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT + models/pyhealth.models.CBraMod_Wrapper models/pyhealth.models.unified_multimodal_embedding_docs models/pyhealth.models.califorest diff --git a/docs/api/models/pyhealth.models.cbramod.rst b/docs/api/models/pyhealth.models.cbramod.rst new file mode 100644 index 000000000..b4bd5b547 --- /dev/null +++ b/docs/api/models/pyhealth.models.cbramod.rst @@ -0,0 +1,67 @@ +pyhealth.models.CBraMod_Wrapper +=================================== + +CBraMod model for EEG signal classification. + +Overview +-------- + +CBraMod is a criss-cross attention transformer tailored for EEG decoding. The +wrapper integrates the model into the PyHealth ``BaseModel`` pipeline so it can +be trained with the standard ``Trainer`` APIs. + +Input/Output +------------ + +- **Input:** ``signal`` tensor shaped ``(batch, channels, timesteps)`` where + ``timesteps`` is a multiple of 200 (the patch size used by CBraMod). +- **Output (classifier_head=True):** dict with ``loss``, ``y_prob``, ``y_true``, + ``logit``, and ``embeddings``. +- **Output (classifier_head=False):** dict with ``logit`` and ``embeddings``. + +Example Usage +------------- + +.. code-block:: python + + import torch + from pyhealth.datasets import create_sample_dataset, get_dataloader + from pyhealth.models import CBraMod_Wrapper + + n_channels = 16 + patch_size = 200 + n_patches = 10 + n_samples = patch_size * n_patches + + samples = [ + { + "patient_id": f"patient-{i}", + "visit_id": "visit-0", + "signal": torch.randn(n_channels, n_samples).numpy().tolist(), + "label": i % 6, + } + for i in range(8) + ] + + dataset = create_sample_dataset( + samples=samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "multiclass"}, + dataset_name="test_cbramod", + ) + + model = CBraMod_Wrapper( + dataset=dataset, + seq_len=n_patches, + n_classes=6, + classifier_head=True, + ) + + batch = next(iter(get_dataloader(dataset, batch_size=2, shuffle=True))) + output = model(**batch) + print(output["logit"].shape) + +.. autoclass:: pyhealth.models.CBraMod_Wrapper + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/tutorials.rst b/docs/tutorials.rst index 9193c86c0..153a3513d 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -187,7 +187,9 @@ EEG and Sleep Analysis * - ``EEG_events_SparcNet.py`` - SparcNet for EEG event detection * - ``EEG_isAbnormal_SparcNet.py`` - - SparcNet for EEG abnormality detection + - SparcNet for EEG abnormality detection + * - ``CBraMod_tuab_eeg_abnormal_classification.py`` + - CBraMod for EEG abnormality detection on TUAB * - ``cardiology_detection_isAR_SparcNet.py`` - SparcNet for cardiology arrhythmia detection diff --git a/examples/eeg/eeg_models/CBraMod_tuab_eeg_abnormal_classification.py b/examples/eeg/eeg_models/CBraMod_tuab_eeg_abnormal_classification.py new file mode 100644 index 000000000..b029bcde7 --- /dev/null +++ b/examples/eeg/eeg_models/CBraMod_tuab_eeg_abnormal_classification.py @@ -0,0 +1,68 @@ +from pyhealth.datasets import TUABDataset, split_by_visit, get_dataloader +from pyhealth.tasks import EEGAbnormalTUAB +from pyhealth.models import CBraMod_Wrapper +from pyhealth.trainer import Trainer + +# step 1: load signal data +dataset = TUABDataset( + root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/", + dev=True, + refresh_cache=True, +) +print(dataset.stats()) + +# step 2: set task (disable STFT for CBraMod) +TUAB_ds = dataset.set_task( + EEGAbnormalTUAB( + resample_rate=200, + bandpass_filter=(0.1, 75.0), + notch_filter=50.0, + compute_stft=False, + ) +) + +print(f"Total task samples: {len(TUAB_ds)}") +print(f"Input schema: {TUAB_ds.input_schema}") +print(f"Output schema: {TUAB_ds.output_schema}") + +# Inspect a sample to infer sequence length +sample = TUAB_ds[0] +print(f"\nSample keys: {sample.keys()}") +print(f"Signal shape: {sample['signal'].shape}") +print(f"Label: {sample['label']}") + +seq_len = sample["signal"].shape[-1] // 200 + +# split dataset +train_dataset, val_dataset, test_dataset = split_by_visit( + TUAB_ds, [0.6, 0.2, 0.2] +) +train_dataloader = get_dataloader(train_dataset, batch_size=16, shuffle=True) +val_dataloader = get_dataloader(val_dataset, batch_size=16, shuffle=False) +test_dataloader = get_dataloader(test_dataset, batch_size=16, shuffle=False) +print( + "loader size: train/val/test", + len(train_dataset), + len(val_dataset), + len(test_dataset), +) + +# step 3: define model +model = CBraMod_Wrapper( + dataset=TUAB_ds, + seq_len=seq_len, + n_classes=2, + classifier_head=True, +) + +# step 4: define trainer +trainer = Trainer(model=model, device="cuda:0") +trainer.train( + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + epochs=10, + optimizer_params={"lr": 1e-4}, +) + +# step 5: evaluate +print(trainer.evaluate(test_dataloader)) diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 4c168d3e3..4825e8c18 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -3,6 +3,7 @@ from .base_model import BaseModel from .transformer_deid import TransformerDeID from .biot import BIOT +from .cbramod import CBraMod_Wrapper from .cnn import CNN, CNNLayer from .concare import ConCare, ConCareLayer from .contrawr import ContraWR, ResBlock2D @@ -45,4 +46,4 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding -from .califorest import CaliForest \ No newline at end of file +from .califorest import CaliForest diff --git a/pyhealth/models/cbramod.py b/pyhealth/models/cbramod.py new file mode 100644 index 000000000..5ad6d453e --- /dev/null +++ b/pyhealth/models/cbramod.py @@ -0,0 +1,521 @@ +import copy +from typing import Optional, Any, Union, Callable, Dict +import torch +import torch.nn as nn +from einops import rearrange +from torch import Tensor +from torch.nn import functional as F + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True): + super().__init__() + torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + is_causal: Optional[bool] = None) -> Tensor: + + output = src + for mod in self.layers: + output = mod(output, src_mask=mask) + if self.norm is not None: + output = self.norm(output) + return output + + +class TransformerEncoderLayer(nn.Module): + __constants__ = ['norm_first'] + + def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, + bias: bool = True, device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.self_attn_s = nn.MultiheadAttention(d_model//2, nhead // 2, dropout=dropout, + bias=bias, batch_first=batch_first, + **factory_kwargs) + self.self_attn_t = nn.MultiheadAttention(d_model//2, nhead // 2, dropout=dropout, + bias=bias, batch_first=batch_first, + **factory_kwargs) + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) + + self.norm_first = norm_first + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + activation = _get_activation_fn(activation) + + # We can't test self.activation in forward() in TorchScript, + # so stash some information about it instead. + if activation is F.relu or isinstance(activation, torch.nn.ReLU): + self.activation_relu_or_gelu = 1 + elif activation is F.gelu or isinstance(activation, torch.nn.GELU): + self.activation_relu_or_gelu = 2 + else: + self.activation_relu_or_gelu = 0 + self.activation = activation + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, 'activation'): + self.activation = F.relu + + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + is_causal: bool = False) -> Tensor: + + x = src + x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal) + x = x + self._ff_block(self.norm2(x)) + return x + + # self-attention block + def _sa_block(self, x: Tensor, + attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor: + bz, ch_num, patch_num, patch_size = x.shape + xs = x[:, :, :, :patch_size // 2] + xt = x[:, :, :, patch_size // 2:] + xs = xs.transpose(1, 2).contiguous().view(bz*patch_num, ch_num, patch_size // 2) + xt = xt.contiguous().view(bz*ch_num, patch_num, patch_size // 2) + xs = self.self_attn_s(xs, xs, xs, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False)[0] + xs = xs.contiguous().view(bz, patch_num, ch_num, patch_size//2).transpose(1, 2) + xt = self.self_attn_t(xt, xt, xt, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False)[0] + xt = xt.contiguous().view(bz, ch_num, patch_num, patch_size//2) + x = torch.concat((xs, xt), dim=3) + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError(f"activation should be relu/gelu, not {activation}") + +def _get_clones(module, N): + # FIXME: copy.deepcopy() is not defined on nn.module + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_seq_len( + src: Tensor, + batch_first: bool +) -> Optional[int]: + + if src.is_nested: + return None + else: + src_size = src.size() + if len(src_size) == 2: + # unbatched: S, E + return src_size[0] + else: + # batched: B, S, E if batch_first else S, B, E + seq_len_pos = 1 if batch_first else 0 + return src_size[seq_len_pos] + + +def _detect_is_causal_mask( + mask: Optional[Tensor], + is_causal: Optional[bool] = None, + size: Optional[int] = None, +) -> bool: + """Return whether the given attention mask is causal. + + Warning: + If ``is_causal`` is not ``None``, its value will be returned as is. If a + user supplies an incorrect ``is_causal`` hint, + + ``is_causal=False`` when the mask is in fact a causal attention.mask + may lead to reduced performance relative to what would be achievable + with ``is_causal=True``; + ``is_causal=True`` when the mask is in fact not a causal attention.mask + may lead to incorrect and unpredictable execution - in some scenarios, + a causal mask may be applied based on the hint, in other execution + scenarios the specified mask may be used. The choice may not appear + to be deterministic, in that a number of factors like alignment, + hardware SKU, etc influence the decision whether to use a mask or + rely on the hint. + ``size`` if not None, check whether the mask is a causal mask of the provided size + Otherwise, checks for any causal mask. + """ + # Prevent type refinement + make_causal = (is_causal is True) + + if is_causal is None and mask is not None: + sz = size if size is not None else mask.size(-2) + causal_comparison = _generate_square_subsequent_mask( + sz, device=mask.device, dtype=mask.dtype) + + # Do not use `torch.equal` so we handle batched masks by + # broadcasting the comparison. + if mask.size() == causal_comparison.size(): + make_causal = bool((mask == causal_comparison).all()) + else: + make_causal = False + + return make_causal + + +def _generate_square_subsequent_mask( + sz: int, + device: torch.device = torch.device(torch._C._get_default_device()), # torch.device('cpu'), + dtype: torch.dtype = torch.get_default_dtype(), +) -> Tensor: + r"""Generate a square causal mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). + """ + return torch.triu( + torch.full((sz, sz), float('-inf'), dtype=dtype, device=device), + diagonal=1, + ) + + + + +class CBraMod(nn.Module): + def __init__(self, in_dim=200, out_dim=200, d_model=200, dim_feedforward=800, seq_len=30, n_layer=12, + nhead=8): + super().__init__() + self.patch_embedding = PatchEmbedding(in_dim, out_dim, d_model, seq_len) + encoder_layer = TransformerEncoderLayer( + d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True, norm_first=True, + activation=F.gelu + ) + self.encoder = TransformerEncoder(encoder_layer, num_layers=n_layer, enable_nested_tensor=False) + self.proj_out = nn.Sequential( + # nn.Linear(d_model, d_model*2), + # nn.GELU(), + # nn.Linear(d_model*2, d_model), + # nn.GELU(), + nn.Linear(d_model, out_dim), + ) + self.apply(_weights_init) + + def forward(self, x, mask=None): + patch_emb = self.patch_embedding(x, mask) + feats = self.encoder(patch_emb) + + out = self.proj_out(feats) + + return out + +class PatchEmbedding(nn.Module): + def __init__(self, in_dim, out_dim, d_model, seq_len): + super().__init__() + self.d_model = d_model + self.positional_encoding = nn.Sequential( + nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=(19, 7), stride=(1, 1), padding=(9, 3), + groups=d_model), + ) + self.mask_encoding = nn.Parameter(torch.zeros(in_dim), requires_grad=False) + # self.mask_encoding = nn.Parameter(torch.randn(in_dim), requires_grad=True) + + self.proj_in = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=25, kernel_size=(1, 49), stride=(1, 25), padding=(0, 24)), + nn.GroupNorm(5, 25), + nn.GELU(), + + nn.Conv2d(in_channels=25, out_channels=25, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)), + nn.GroupNorm(5, 25), + nn.GELU(), + + nn.Conv2d(in_channels=25, out_channels=25, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)), + nn.GroupNorm(5, 25), + nn.GELU(), + ) + self.spectral_proj = nn.Sequential( + nn.Linear(101, d_model), + nn.Dropout(0.1), + # nn.LayerNorm(d_model, eps=1e-5), + ) + # self.norm1 = nn.LayerNorm(d_model, eps=1e-5) + # self.norm2 = nn.LayerNorm(d_model, eps=1e-5) + # self.proj_in = nn.Sequential( + # nn.Linear(in_dim, d_model, bias=False), + # ) + + + def forward(self, x, mask=None): + bz, ch_num, patch_num, patch_size = x.shape + if mask == None: + mask_x = x + else: + mask_x = x.clone() + mask_x[mask == 1] = self.mask_encoding + + mask_x = mask_x.contiguous().view(bz, 1, ch_num * patch_num, patch_size) + patch_emb = self.proj_in(mask_x) + patch_emb = patch_emb.permute(0, 2, 1, 3).contiguous().view(bz, ch_num, patch_num, self.d_model) + + mask_x = mask_x.contiguous().view(bz*ch_num*patch_num, patch_size) + spectral = torch.fft.rfft(mask_x, dim=-1, norm='forward') + spectral = torch.abs(spectral).contiguous().view(bz, ch_num, patch_num, 101) + spectral_emb = self.spectral_proj(spectral) + # print(patch_emb[5, 5, 5, :]) + # print(spectral_emb[5, 5, 5, :]) + patch_emb = patch_emb + spectral_emb + + positional_embedding = self.positional_encoding(patch_emb.permute(0, 3, 1, 2)) + positional_embedding = positional_embedding.permute(0, 2, 3, 1) + + patch_emb = patch_emb + positional_embedding + + return patch_emb + + +def _weights_init(m): + if isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + + +class CBraMod_Wrapper(BaseModel): + """Cbramod: A criss-cross brain foundation model for eeg decoding + + Citation: Wang, Jiquan, et al. "Cbramod: A criss-cross brain foundation model for eeg decoding." arXiv preprint arXiv:2412.07236 (2024). + + The CBraMod model expects as input: + - Biosignal data in the format: (batch_size, n_channels, n_time) + + Args: + dataset: A SampleDataset instance that this model is trained or evaluated on. + in_dim: Input patch size or feature dimension from pre-processing (default: 200). + emb_size: Transformer embedding dimension (default: 200). + dim_feedforward: Feedforward network dimension inside the transformer encoder (default: 800). + seq_len: Number of patches per channel sequence (default: 30). + n_layer: Number of transformer encoder layers (default: 12). + nhead: Number of transformer attention heads (default: 8). + classifier_head: Whether to attach a classification head (default: True). + n_classes: Number of output classes for the classifier (default: 6). + + Examples: + >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.models import CBraMod_Wrapper + >>> dataset = SampleDataset(...) # Load your dataset compatible with SampleDataset + >>> model = CBraMod_Wrapper( + >>> dataset=dataset, + >>> in_dim=200, + >>> emb_size=200, + >>> dim_feedforward=800, + >>> seq_len=30, + >>> n_layer=12, + >>> nhead=8, + >>> classifier_head=True, + >>> n_classes=6, + >>> ) + >>> sample = torch.randn(8, 18, 6000) # batch of 8, 18 channels, 6000 timesteps + >>> output = model(signal=sample) + >>> # 'output' is a dictionary with 'loss', 'y_prob', 'y_true', 'logits', and 'embeddings' + """ + def __init__(self, + dataset: SampleDataset, + in_dim: int = 200, + emb_size: int = 200, + dim_feedforward: int = 800, + seq_len: int = 30, + n_layer: int = 12, + nhead: int = 8, + classifier_head: bool = True, + n_classes: int = 6, + **kwargs): + super().__init__(dataset=dataset) + + self.in_dim = in_dim + self.emb_size = emb_size + self.dim_feedforward = dim_feedforward + self.seq_len = seq_len + self.n_layer = n_layer + self.nhead = nhead + self.classifier_head = classifier_head + self.n_classes = n_classes + + + self.cbramod = CBraMod(in_dim=in_dim, d_model=emb_size, dim_feedforward=dim_feedforward, seq_len=seq_len, n_layer=n_layer, + nhead=nhead) + + + if classifier_head: + self.pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1))) + self.classifier = nn.Linear(emb_size, n_classes) + + def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]: + """Forward propagation. + + Args: + **kwargs: keyword arguments containing 'signal'. + + Returns: + a dictionary containing loss, y_prob, y_true, logit, embeddings. + """ + signal = kwargs.get("signal") + if signal is None: + raise ValueError("'signal' must be provided in inputs") + signal = signal.to(self.device) + B,C,T = signal.shape + signal = rearrange(signal, 'B C (S T) -> B C S T',T = 200) + eeg_embed = self.cbramod(signal) + if eeg_embed.shape[0] != B: + eeg_embed = eeg_embed.unsqueeze(0) + + label_key = self.label_keys[0] + y_true = kwargs[label_key].to(self.device) + + + + if self.classifier_head: + eeg_embed = rearrange(eeg_embed,'B C S E -> B E C S') + eeg_embed = self.pool(eeg_embed) + eeg_embed = eeg_embed.view(eeg_embed.shape[0], -1) + logits = self.classifier(eeg_embed) + y_prob = self.prepare_y_prob(logits) + loss_fn = self.get_loss_function() + loss = loss_fn(logits, y_true) + return { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + "logit": logits, + "embeddings": eeg_embed, + } + else: + return { + "logit": eeg_embed, + "embeddings": eeg_embed, + } + + def load_pretrained_weights(self, path: str): + """Load pre-trained weights from checkpoint. + + Args: + path: path to the checkpoint file. + """ + self.cbramod.load_state_dict(torch.load(path, map_location=self.device),strict=False) + print(f"Loaded pretrained weights from {path}") + + + +if __name__ == "__main__": + from pyhealth.datasets import create_sample_dataset, get_dataloader + + print("Testing CBraMod model...") + + n_channels = 16 + n_time = 10 + patch_size = 200 + n_samples = patch_size * n_time # 2000 + + # Create sample dataset + samples = [ + { + "patient_id": f"patient-{i}", + "visit_id": "visit-0", + "signal": torch.randn(n_channels, n_samples).numpy().tolist(), + "label": i % 6, + } + for i in range(4) + ] + + dataset = create_sample_dataset( + samples=samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "multiclass"}, + dataset_name="test_cbramod", + ) + + # Use small model for quick testing + model = CBraMod_Wrapper( + dataset=dataset, + in_dim=200, + emb_size=200, + dim_feedforward=800, + seq_len=n_time, + n_layer=2, # small for testing + nhead=8, + classifier_head=True, + n_classes=6, + ) + print(f"✓ Created CBraMod_Wrapper") + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f" Total params: {total_params:,}") + print(f" Trainable params: {trainable_params:,}") + + # Forward pass + train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + print(f"✓ Forward pass:") + print(f" Loss: {ret['loss'].item():.4f}") + print(f" Logit shape: {ret['logit'].shape}") + print(f" y_prob shape: {ret['y_prob'].shape}") + print(f" y_true shape: {ret['y_true'].shape}") + print(f" Embeddings shape: {ret['embeddings'].shape}") + + # Backward pass + ret2 = model(**data_batch) + ret2["loss"].backward() + has_grad = any(p.grad is not None for p in model.parameters() if p.requires_grad) + print(f"✓ Backward pass: gradients={'yes' if has_grad else 'no'}") + + # Test without classifier + model_no_cls = CBraMod_Wrapper( + dataset=dataset, + in_dim=200, + emb_size=200, + dim_feedforward=800, + seq_len=n_time, + n_layer=2, + nhead=8, + classifier_head=False, + n_classes=6, + ) + with torch.no_grad(): + ret3 = model_no_cls(**data_batch) + print(f"✓ Encoder-only mode: embeddings shape = {ret3['embeddings'].shape}") + + print("\n✓ All tests passed!") \ No newline at end of file diff --git a/tests/core/test_cbramod.py b/tests/core/test_cbramod.py new file mode 100644 index 000000000..0833ffd97 --- /dev/null +++ b/tests/core/test_cbramod.py @@ -0,0 +1,234 @@ +import unittest +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import CBraMod_Wrapper + + +class TestCBraMod(unittest.TestCase): + """Test cases for the CBraMod_Wrapper model.""" + + def setUp(self): + """Set up test data and model.""" + n_channels = 16 + patch_size = 200 + n_patches = 10 + n_samples = patch_size * n_patches # 2000 + + self.samples = [ + { + "patient_id": f"patient-{i}", + "visit_id": "visit-0", + "signal": torch.randn(n_channels, n_samples).numpy().tolist(), + "label": i % 6, + } + for i in range(4) + ] + + self.input_schema = {"signal": "tensor"} + self.output_schema = {"label": "multiclass"} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test_cbramod", + ) + + # Use small model for fast testing + self.model = CBraMod_Wrapper( + dataset=self.dataset, + in_dim=200, + emb_size=200, + dim_feedforward=800, + seq_len=n_patches, + n_layer=2, + nhead=8, + classifier_head=True, + n_classes=6, + ) + + def test_model_initialization(self): + """Test that the CBraMod_Wrapper model initializes correctly.""" + self.assertIsInstance(self.model, CBraMod_Wrapper) + self.assertTrue(self.model.classifier_head) + self.assertEqual(self.model.emb_size, 200) + self.assertEqual(self.model.n_classes, 6) + self.assertEqual(len(self.model.feature_keys), 1) + self.assertIn("signal", self.model.feature_keys) + + def test_model_forward(self): + """Test that the CBraMod_Wrapper forward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + self.assertIn("embeddings", ret) + + self.assertEqual(ret["y_prob"].shape[0], 2) + self.assertEqual(ret["y_true"].shape[0], 2) + self.assertEqual(ret["logit"].shape[0], 2) + self.assertEqual(ret["logit"].shape[1], 6) # n_classes + self.assertEqual(ret["embeddings"].shape[0], 2) + self.assertEqual(ret["embeddings"].shape[1], 200) # emb_size + self.assertEqual(ret["loss"].dim(), 0) + + def test_model_backward(self): + """Test that the CBraMod_Wrapper backward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + ret = self.model(**data_batch) + ret["loss"].backward() + + has_gradient = any( + param.requires_grad and param.grad is not None + for param in self.model.parameters() + ) + self.assertTrue(has_gradient, "No parameters have gradients after backward pass") + + def test_model_without_classifier(self): + """Test CBraMod_Wrapper without classifier head (encoder only).""" + model_no_cls = CBraMod_Wrapper( + dataset=self.dataset, + in_dim=200, + emb_size=200, + dim_feedforward=800, + seq_len=10, + n_layer=2, + nhead=8, + classifier_head=False, + n_classes=6, + ) + + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model_no_cls(**data_batch) + + self.assertIn("logit", ret) + self.assertIn("embeddings", ret) + self.assertNotIn("loss", ret) + self.assertNotIn("y_prob", ret) + self.assertNotIn("y_true", ret) + + def test_model_different_batch_sizes(self): + """Test CBraMod_Wrapper with different batch sizes.""" + for batch_size in [1, 2, 4]: + train_loader = get_dataloader(self.dataset, batch_size=batch_size, shuffle=False) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + actual_batch = min(batch_size, len(self.samples)) + self.assertEqual(ret["y_prob"].shape[0], actual_batch) + self.assertEqual(ret["y_true"].shape[0], actual_batch) + self.assertEqual(ret["logit"].shape[0], actual_batch) + self.assertEqual(ret["embeddings"].shape[0], actual_batch) + + def test_model_output_probabilities(self): + """Test that output probabilities are valid.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + y_prob = ret["y_prob"] + self.assertTrue(torch.all(y_prob >= 0), "Probabilities contain negative values") + self.assertTrue(torch.all(y_prob <= 1), "Probabilities exceed 1") + + def test_missing_signal_raises_error(self): + """Test that missing 'signal' input raises ValueError.""" + with self.assertRaises((ValueError, KeyError)): + self.model(label=torch.tensor([0, 1])) + + def test_model_different_n_classes(self): + """Test CBraMod_Wrapper with different number of classes.""" + n_channels = 16 + patch_size = 200 + n_patches = 10 + n_samples = patch_size * n_patches + + binary_samples = [ + { + "patient_id": f"patient-{i}", + "visit_id": "visit-0", + "signal": torch.randn(n_channels, n_samples).numpy().tolist(), + "label": i % 2, + } + for i in range(4) + ] + + binary_dataset = create_sample_dataset( + samples=binary_samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "binary"}, + dataset_name="test_cbramod_binary", + ) + + model_binary = CBraMod_Wrapper( + dataset=binary_dataset, + in_dim=200, + emb_size=200, + dim_feedforward=800, + seq_len=n_patches, + n_layer=2, + nhead=8, + classifier_head=True, + n_classes=1, + ) + + train_loader = get_dataloader(binary_dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model_binary(**data_batch) + + self.assertEqual(ret["logit"].shape[1], 1) + + def test_model_smaller_config(self): + """Test CBraMod_Wrapper with a smaller configuration.""" + model_small = CBraMod_Wrapper( + dataset=self.dataset, + in_dim=200, + emb_size=200, + dim_feedforward=400, + seq_len=10, + n_layer=1, + nhead=4, + classifier_head=True, + n_classes=6, + ) + + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model_small(**data_batch) + + self.assertIn("loss", ret) + self.assertEqual(ret["logit"].shape[1], 6) + self.assertEqual(ret["embeddings"].shape[1], 200) # smaller emb_size + + def test_embedding_shape(self): + """Test that embeddings have the correct shape.""" + train_loader = get_dataloader(self.dataset, batch_size=4, shuffle=False) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertEqual(ret["embeddings"].shape, (4, 200)) # (batch, emb_size) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/core/test_jamba_ehr.py b/tests/core/test_jamba_ehr.py index 01d0534e0..342fb6049 100644 --- a/tests/core/test_jamba_ehr.py +++ b/tests/core/test_jamba_ehr.py @@ -343,4 +343,4 @@ def test_model_initialization(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()