-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathapp.py
More file actions
151 lines (121 loc) · 4.8 KB
/
Copy pathapp.py
File metadata and controls
151 lines (121 loc) · 4.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""Gradio Space: upload a NIfTI brain activation map, get NiCLIP task predictions."""
import logging
import os.path as op
import time
import numpy as np
from nilearn.image import new_img_like
import gradio as gr
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("niclip-space")

# Tags embedded in the bundle's asset file names (see load_resources):
MODEL_NAME = "BrainGPT-7B-v0.2"  # "embedding-..." tag in model/vocabulary file names
SECTION = "body"  # "section-..." tag in the model and prior file names
SOURCE = "cogatlasred"  # "vocabulary-..." tag in the vocabulary file names

# Filled by load_resources(); read by predict().
_RESOURCES: dict = {}
def load_resources(*, force=False):
    """Download the example bundle and build everything ``predict`` needs.

    Fetches the ``example_prediction`` bundle, builds the CLIP model and the
    image embedder, and loads the vocabulary with its embeddings and priors.
    It then runs the embedder once on an all-zero image to warm its atlas
    cache (presumably the DiFuMo atlas; confirm in braindec.embedding). The
    results are stored in the module-level ``_RESOURCES`` dict.

    The work is done once. If ``_RESOURCES`` is already populated, the call
    returns immediately unless ``force`` is true.

    Parameters
    ----------
    force : bool, optional
        Rebuild even when resources are already loaded. Defaults to False.

    Raises
    ------
    RuntimeError
        If the CLIP model cannot be built from the downloaded checkpoint.
    """
    # _RESOURCES is filled by a single update() at the end, so it is either
    # empty or complete. A non-empty dict means a previous call succeeded.
    if _RESOURCES and not force:
        logger.info("resources already loaded; skipping startup")
        return

    t0 = time.perf_counter()
    # Project imports are deferred until startup actually runs.
    from braindec.fetcher import download_bundle, get_data_dir

    work_dir = get_data_dir()
    download_bundle("example_prediction", destination_root=work_dir)
    logger.info("bundle ready in %.1fs", time.perf_counter() - t0)

    # Asset locations inside the bundle. The file names encode the
    # module-level MODEL_NAME, SECTION and SOURCE tags.
    data_dir = op.join(work_dir, "data")
    results_dir = op.join(work_dir, "results")
    voc_dir = op.join(data_dir, "vocabulary")
    voc_label = f"vocabulary-{SOURCE}_task-combined_embedding-{MODEL_NAME}"
    model_fn = op.join(
        results_dir, "pubmed", f"model-clip_section-{SECTION}_embedding-{MODEL_NAME}_best.pth"
    )
    vocabulary_fn = op.join(voc_dir, f"vocabulary-{SOURCE}_task.txt")
    vocabulary_emb_fn = op.join(voc_dir, f"{voc_label}.npy")
    vocabulary_prior_fn = op.join(voc_dir, f"{voc_label}_section-{SECTION}_prior.npy")

    from braindec.model import build_model
    from braindec.utils import _get_device

    device = _get_device()
    t1 = time.perf_counter()
    try:
        model = build_model(model_fn, device=device)
    except Exception as exc:
        raise RuntimeError(f"Failed to load CLIP model from {model_fn}") from exc
    logger.info("model built in %.1fs (device=%s)", time.perf_counter() - t1, device)

    from braindec.embedding import ImageEmbedding

    image_embedder = ImageEmbedding(
        standardize=False, nilearn_dir=op.join(data_dir, "nilearn"), space="MNI152"
    )

    # One label per line, whitespace-stripped. Explicit UTF-8 keeps decoding
    # independent of the host's locale encoding.
    with open(vocabulary_fn, encoding="utf-8") as fh:
        vocabulary = [line.strip() for line in fh]
    vocabulary_emb = np.load(vocabulary_emb_fn)
    vocabulary_prior = np.load(vocabulary_prior_fn)

    # Warm-up: embed an all-zero image on the atlas's spatial grid (the first
    # three axes of atlas_maps). First-use setup then happens at startup
    # rather than on the first user request.
    t2 = time.perf_counter()
    dummy = new_img_like(
        image_embedder.atlas_maps,
        np.zeros(image_embedder.atlas_maps.shape[:3], dtype=np.float32),
    )
    image_embedder(dummy)
    logger.info("image_embedder warmup in %.1fs", time.perf_counter() - t2)

    _RESOURCES.update(
        model=model,
        device=device,
        image_embedder=image_embedder,
        vocabulary=vocabulary,
        vocabulary_emb=vocabulary_emb,
        vocabulary_prior=vocabulary_prior,
        data_dir=data_dir,
    )
    logger.info("TOTAL startup: %.1fs", time.perf_counter() - t0)
def predict(nifti_path, topk):
    """Decode an uploaded activation map into its top-k predicted tasks.

    Parameters
    ----------
    nifti_path : str or None
        Path of the uploaded file, as supplied by ``gr.File(type="filepath")``.
        None when nothing was uploaded.
    topk : int or float
        Number of tasks to return (the slider value, cast to int).

    Returns
    -------
    The ``pred``, ``prob`` and ``bayes_factor`` columns of the table returned
    by ``image_to_labels``, rounded to 4 decimals (presumably a pandas
    DataFrame).

    Raises
    ------
    gr.Error
        Raised for a missing or non-NIfTI upload, when resources have not been
        loaded yet, or when decoding fails. Gradio shows the message to the user.
    """
    if nifti_path is None:
        raise gr.Error("Please upload a .nii or .nii.gz file.")
    # Case-insensitive suffix check. str() also accepts path-like objects.
    if not str(nifti_path).lower().endswith((".nii", ".nii.gz")):
        raise gr.Error("File must be a NIfTI image (.nii or .nii.gz).")

    r = _RESOURCES
    if not r:
        # Without this check, the missing keys would surface as a cryptic
        # "Decoding failed: 'vocabulary'" message.
        raise gr.Error("The model is not loaded yet; please try again shortly.")

    try:
        from braindec.predict import image_to_labels

        task_df = image_to_labels(
            nifti_path,
            model_path=None,
            vocabulary=r["vocabulary"],
            vocabulary_emb=r["vocabulary_emb"],
            prior_probability=r["vocabulary_prior"],
            topk=int(topk),
            logit_scale=20.0,
            model=r["model"],
            image_emb_gene=r["image_embedder"],
            data_dir=r["data_dir"],
            device=r["device"],
        )
        # Column selection stays inside the try so an import failure or a
        # schema mismatch is reported to the user like any other failure.
        return task_df[["pred", "prob", "bayes_factor"]].round(4)
    except gr.Error:
        raise
    except Exception as exc:
        logger.exception("prediction failed")
        raise gr.Error(f"Decoding failed: {exc}") from exc
def build_ui():
    """Assemble the Gradio interface and wire the Decode button to ``predict``.

    Returns
    -------
    gr.Blocks
        The app, not yet queued or launched. The left column holds the file
        upload, top-k slider and button; the right column holds the results table.
    """
    intro_md = (
        "# NiCLIP — Functional Brain Decoding\n"
        "Upload a NIfTI brain activation map (group- or subject-level "
        "z-stat/t-stat) to predict the cognitive tasks most associated with it. "
        "[Paper](https://doi.org/10.1101/2025.06.14.659706)"
    )
    with gr.Blocks(title="NiCLIP: Brain Activation Decoder") as app:
        gr.Markdown(intro_md)
        with gr.Row():
            with gr.Column(scale=1):
                nifti_upload = gr.File(label="Activation map (.nii / .nii.gz)", type="filepath")
                topk_slider = gr.Slider(3, 20, value=10, step=1, label="Top-k tasks")
                decode_button = gr.Button("Decode", variant="primary")
            with gr.Column(scale=2):
                results_table = gr.Dataframe(label="Predicted tasks P(T|A)")
        # At most one concurrent run of this event; every call uses the same
        # model object held in _RESOURCES.
        decode_button.click(
            fn=predict,
            inputs=[nifti_upload, topk_slider],
            outputs=[results_table],
            concurrency_limit=1,
        )
    return app
def main():
    """Load model assets, then build, queue and launch the Gradio app."""
    # Resources are loaded before the UI exists, so predict() never runs
    # ahead of them.
    load_resources()
    app = build_ui()
    # Queue with a default of one concurrent job per event.
    app.queue(default_concurrency_limit=1)
    app.launch()


if __name__ == "__main__":
    main()