From 9f8fd6ae9a4896dcab53f7d0b7e61673f035e327 Mon Sep 17 00:00:00 2001 From: Daryl Okeke Date: Sun, 24 May 2026 20:13:42 -0500 Subject: [PATCH 1/4] Exclude pediatric patients (age < 18) from clinical prediction tasks --- pyhealth/tasks/drug_recommendation.py | 46 +++++++++++++++++++-- pyhealth/tasks/length_of_stay_prediction.py | 40 ++++++++++++++++-- pyhealth/tasks/mortality_prediction.py | 16 ++++++- 3 files changed, 94 insertions(+), 8 deletions(-) diff --git a/pyhealth/tasks/drug_recommendation.py b/pyhealth/tasks/drug_recommendation.py index ec113e1dd..8d08375a9 100644 --- a/pyhealth/tasks/drug_recommendation.py +++ b/pyhealth/tasks/drug_recommendation.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Any, Dict, Iterable, List, Optional import polars as pl @@ -379,6 +380,20 @@ def drug_recommendation_mimic3_fn(patient: Patient): } """ samples = [] + + # Skip pediatric patients + demographics = patient.get_events(event_type="patients") + admissions = patient.get_events(event_type="admissions") + if demographics and admissions: + dob_str = getattr(demographics[0], "dob", None) + if dob_str is not None: + try: + dob = datetime.strptime(str(dob_str)[:10], "%Y-%m-%d") + if (admissions[0].timestamp - dob).days / 365.25 < 18: + return [] + except (ValueError, TypeError): + pass + for i in range(len(patient)): visit: Visit = patient[i] conditions = visit.get_code_list(table="DIAGNOSES_ICD") @@ -388,7 +403,6 @@ def drug_recommendation_mimic3_fn(patient: Patient): # exclude: visits without condition, procedure, or drug code if len(conditions) * len(procedures) * len(drugs) == 0: continue - # TODO: should also exclude visit with age < 18 samples.append( { "visit_id": visit.visit_id, @@ -459,6 +473,17 @@ def drug_recommendation_mimic4_fn(patient: Patient): [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '19', '122', '98', '663', '58', '51']], 'procedures': [['1']], 'label': [['2', '3', '4']]}] """ samples = [] + + # Skip pediatric patients + demographics = patient.get_events(event_type="patients") + if demographics: + anchor_age = getattr(demographics[0], "anchor_age", None) + try: + if anchor_age is not None and int(float(anchor_age)) < 18: + return [] + except (ValueError, TypeError): + pass + for i in range(len(patient)): visit: Visit = patient[i] conditions = visit.get_code_list(table="diagnoses_icd") @@ -468,7 +493,6 @@ def drug_recommendation_mimic4_fn(patient: Patient): # exclude: visits without condition, procedure, or drug code if len(conditions) * len(procedures) * len(drugs) == 0: continue - # TODO: should also exclude visit with age < 18 samples.append( { "visit_id": visit.visit_id, @@ -671,6 +695,23 @@ def drug_recommendation_omop_fn(patient: Patient): """ samples = [] + + # Skip pediatric patients + demographics = patient.get_events(event_type="person") + if demographics: + person = demographics[0] + birth_year = getattr(person, "year_of_birth", None) + if birth_year is not None: + try: + birth_month = int(getattr(person, "month_of_birth", None) or 1) + birth_day = int(getattr(person, "day_of_birth", None) or 1) + dob = datetime(int(birth_year), birth_month, birth_day) + visits = patient.get_events(event_type="visit_occurrence") + if visits and (visits[0].timestamp - dob).days / 365.25 < 18: + return [] + except (ValueError, TypeError): + pass + for i in range(len(patient)): visit: Visit = patient[i] conditions = visit.get_code_list(table="condition_occurrence") @@ -679,7 +720,6 @@ def drug_recommendation_omop_fn(patient: Patient): # exclude: visits without condition, procedure, or drug code if len(conditions) * len(procedures) * len(drugs) == 0: continue - # TODO: should also exclude visit with age < 18 samples.append( { "visit_id": visit.visit_id, diff --git a/pyhealth/tasks/length_of_stay_prediction.py b/pyhealth/tasks/length_of_stay_prediction.py index 25e0c3121..446c6f6ad 100644 --- a/pyhealth/tasks/length_of_stay_prediction.py +++ b/pyhealth/tasks/length_of_stay_prediction.py @@ -116,7 +116,18 @@ def __call__(self, patient: Patient) -> List[Dict]: los_days = (discharge_time - admit_time).days los_category = categorize_los(los_days) - # TODO: should also exclude visit with age < 18 + # Skip pediatric admissions + demographics = patient.get_events(event_type="patients") + if demographics: + dob_str = getattr(demographics[0], "dob", None) + if dob_str is not None: + try: + dob = datetime.strptime(str(dob_str)[:10], "%Y-%m-%d") + if (admit_time - dob).days / 365.25 < 18: + continue + except (ValueError, TypeError): + pass + samples.append( { "visit_id": admission.hadm_id, @@ -219,7 +230,16 @@ def __call__(self, patient: Patient) -> List[Dict]: los_days = (discharge_time - admit_time).days los_category = categorize_los(los_days) - # TODO: should also exclude visit with age < 18 + # Skip pediatric admissions + demographics = patient.get_events(event_type="patients") + if demographics: + anchor_age = getattr(demographics[0], "anchor_age", None) + try: + if anchor_age is not None and int(float(anchor_age)) < 18: + continue + except (ValueError, TypeError): + pass + samples.append( { "visit_id": admission.hadm_id, @@ -447,7 +467,21 @@ def __call__(self, patient: Patient) -> List[Dict]: los_days = (discharge_time - admit_time).days los_category = categorize_los(los_days) - # TODO: should also exclude visit with age < 18 + # Skip pediatric visits + demographics = patient.get_events(event_type="person") + if demographics: + person = demographics[0] + birth_year = getattr(person, "year_of_birth", None) + if birth_year is not None: + try: + birth_month = int(getattr(person, "month_of_birth", None) or 1) + birth_day = int(getattr(person, "day_of_birth", None) or 1) + dob = datetime(int(birth_year), birth_month, birth_day) + if (admit_time - dob).days / 365.25 < 18: + continue + except (ValueError, TypeError): + pass + samples.append( { "visit_id": visit.visit_occurrence_id, diff --git a/pyhealth/tasks/mortality_prediction.py b/pyhealth/tasks/mortality_prediction.py index 249f717f3..7c1a7c59d 100644 --- a/pyhealth/tasks/mortality_prediction.py +++ b/pyhealth/tasks/mortality_prediction.py @@ -875,7 +875,13 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: if len(conditions) * len(procedures_list) * len(drugs) == 0: continue - # TODO: Exclude visits with age < 18 + # Exclude stays with age < 18 + age = getattr(stay, "age", None) + try: + if age is not None and str(age) != "> 89" and int(float(age)) < 18: + continue + except (ValueError, TypeError): + pass samples.append( { @@ -991,7 +997,13 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: if len(conditions) * len(treatment_codes) == 0: continue - # TODO: Exclude visits with age < 18 + # Exclude stays with age < 18 + age = getattr(stay, "age", None) + try: + if age is not None and str(age) != "> 89" and int(float(age)) < 18: + continue + except (ValueError, TypeError): + pass samples.append( { From 8b2ba75d64cee2c51dde78887cc7be1329024643 Mon Sep 17 00:00:00 2001 From: Daryl Okeke Date: Mon, 25 May 2026 02:05:24 -0500 Subject: [PATCH 2/4] Add exclude_minors flag to clinical task age filters --- pyhealth/tasks/drug_recommendation.py | 75 ++++++++-------- pyhealth/tasks/length_of_stay_prediction.py | 98 ++++++++++++++------- pyhealth/tasks/mortality_prediction.py | 50 ++++++++--- 3 files changed, 142 insertions(+), 81 deletions(-) diff --git a/pyhealth/tasks/drug_recommendation.py b/pyhealth/tasks/drug_recommendation.py index 8d08375a9..4a576c9bd 100644 --- a/pyhealth/tasks/drug_recommendation.py +++ b/pyhealth/tasks/drug_recommendation.py @@ -339,7 +339,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: return samples -def drug_recommendation_mimic3_fn(patient: Patient): +def drug_recommendation_mimic3_fn(patient: Patient, exclude_minors: bool = True): """Processes a single patient for the drug recommendation task. Drug recommendation aims at recommending a set of drugs given the patient health @@ -382,17 +382,18 @@ def drug_recommendation_mimic3_fn(patient: Patient): samples = [] # Skip pediatric patients - demographics = patient.get_events(event_type="patients") - admissions = patient.get_events(event_type="admissions") - if demographics and admissions: - dob_str = getattr(demographics[0], "dob", None) - if dob_str is not None: - try: - dob = datetime.strptime(str(dob_str)[:10], "%Y-%m-%d") - if (admissions[0].timestamp - dob).days / 365.25 < 18: - return [] - except (ValueError, TypeError): - pass + if exclude_minors: + demographics = patient.get_events(event_type="patients") + admissions = patient.get_events(event_type="admissions") + if demographics and admissions: + dob_str = getattr(demographics[0], "dob", None) + if dob_str is not None: + try: + dob = datetime.strptime(str(dob_str)[:10], "%Y-%m-%d") + if (admissions[0].timestamp - dob).days / 365.25 < 18: + return [] + except (ValueError, TypeError): + pass for i in range(len(patient)): visit: Visit = patient[i] @@ -439,7 +440,7 @@ def drug_recommendation_mimic3_fn(patient: Patient): return samples -def drug_recommendation_mimic4_fn(patient: Patient): +def drug_recommendation_mimic4_fn(patient: Patient, exclude_minors: bool = True): """Processes a single patient for the drug recommendation task. Drug recommendation aims at recommending a set of drugs given the patient health @@ -475,14 +476,15 @@ def drug_recommendation_mimic4_fn(patient: Patient): samples = [] # Skip pediatric patients - demographics = patient.get_events(event_type="patients") - if demographics: - anchor_age = getattr(demographics[0], "anchor_age", None) - try: - if anchor_age is not None and int(float(anchor_age)) < 18: - return [] - except (ValueError, TypeError): - pass + if exclude_minors: + demographics = patient.get_events(event_type="patients") + if demographics: + anchor_age = getattr(demographics[0], "anchor_age", None) + try: + if anchor_age is not None and int(float(anchor_age)) < 18: + return [] + except (ValueError, TypeError): + pass for i in range(len(patient)): visit: Visit = patient[i] @@ -668,7 +670,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: return samples -def drug_recommendation_omop_fn(patient: Patient): +def drug_recommendation_omop_fn(patient: Patient, exclude_minors: bool = True): """Processes a single patient for the drug recommendation task. Drug recommendation aims at recommending a set of drugs given the patient health @@ -697,20 +699,21 @@ def drug_recommendation_omop_fn(patient: Patient): samples = [] # Skip pediatric patients - demographics = patient.get_events(event_type="person") - if demographics: - person = demographics[0] - birth_year = getattr(person, "year_of_birth", None) - if birth_year is not None: - try: - birth_month = int(getattr(person, "month_of_birth", None) or 1) - birth_day = int(getattr(person, "day_of_birth", None) or 1) - dob = datetime(int(birth_year), birth_month, birth_day) - visits = patient.get_events(event_type="visit_occurrence") - if visits and (visits[0].timestamp - dob).days / 365.25 < 18: - return [] - except (ValueError, TypeError): - pass + if exclude_minors: + demographics = patient.get_events(event_type="person") + if demographics: + person = demographics[0] + birth_year = getattr(person, "year_of_birth", None) + if birth_year is not None: + try: + birth_month = int(getattr(person, "month_of_birth", None) or 1) + birth_day = int(getattr(person, "day_of_birth", None) or 1) + dob = datetime(int(birth_year), birth_month, birth_day) + visits = patient.get_events(event_type="visit_occurrence") + if visits and (visits[0].timestamp - dob).days / 365.25 < 18: + return [] + except (ValueError, TypeError): + pass for i in range(len(patient)): visit: Visit = patient[i] diff --git a/pyhealth/tasks/length_of_stay_prediction.py b/pyhealth/tasks/length_of_stay_prediction.py index 446c6f6ad..06bfbc74d 100644 --- a/pyhealth/tasks/length_of_stay_prediction.py +++ b/pyhealth/tasks/length_of_stay_prediction.py @@ -73,6 +73,17 @@ class LengthOfStayPredictionMIMIC3(BaseTask): } output_schema: Dict[str, str] = {"los": "multiclass"} + def __init__(self, exclude_minors: bool = True, **kwargs) -> None: + """Initializes the task object. + + Args: + exclude_minors: Whether to exclude admissions where the patient + was under 18 years old. Defaults to True. + **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`. + """ + super().__init__(**kwargs) + self.exclude_minors = exclude_minors + def __call__(self, patient: Patient) -> List[Dict]: samples = [] @@ -117,16 +128,17 @@ def __call__(self, patient: Patient) -> List[Dict]: los_category = categorize_los(los_days) # Skip pediatric admissions - demographics = patient.get_events(event_type="patients") - if demographics: - dob_str = getattr(demographics[0], "dob", None) - if dob_str is not None: - try: - dob = datetime.strptime(str(dob_str)[:10], "%Y-%m-%d") - if (admit_time - dob).days / 365.25 < 18: - continue - except (ValueError, TypeError): - pass + if self.exclude_minors: + demographics = patient.get_events(event_type="patients") + if demographics: + dob_str = getattr(demographics[0], "dob", None) + if dob_str is not None: + try: + dob = datetime.strptime(str(dob_str)[:10], "%Y-%m-%d") + if (admit_time - dob).days / 365.25 < 18: + continue + except (ValueError, TypeError): + pass samples.append( { @@ -182,6 +194,17 @@ class LengthOfStayPredictionMIMIC4(BaseTask): } output_schema: Dict[str, str] = {"los": "multiclass"} + def __init__(self, exclude_minors: bool = True, **kwargs) -> None: + """Initializes the task object. + + Args: + exclude_minors: Whether to exclude admissions where the patient + was under 18 years old. Defaults to True. + **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`. + """ + super().__init__(**kwargs) + self.exclude_minors = exclude_minors + def __call__(self, patient: Patient) -> List[Dict]: samples = [] @@ -231,14 +254,15 @@ def __call__(self, patient: Patient) -> List[Dict]: los_category = categorize_los(los_days) # Skip pediatric admissions - demographics = patient.get_events(event_type="patients") - if demographics: - anchor_age = getattr(demographics[0], "anchor_age", None) - try: - if anchor_age is not None and int(float(anchor_age)) < 18: - continue - except (ValueError, TypeError): - pass + if self.exclude_minors: + demographics = patient.get_events(event_type="patients") + if demographics: + anchor_age = getattr(demographics[0], "anchor_age", None) + try: + if anchor_age is not None and int(float(anchor_age)) < 18: + continue + except (ValueError, TypeError): + pass samples.append( { @@ -422,6 +446,17 @@ class LengthOfStayPredictionOMOP(BaseTask): } output_schema: Dict[str, str] = {"los": "multiclass"} + def __init__(self, exclude_minors: bool = True, **kwargs) -> None: + """Initializes the task object. + + Args: + exclude_minors: Whether to exclude visits where the patient + was under 18 years old. Defaults to True. + **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`. + """ + super().__init__(**kwargs) + self.exclude_minors = exclude_minors + def __call__(self, patient: Patient) -> List[Dict]: samples = [] @@ -468,19 +503,20 @@ def __call__(self, patient: Patient) -> List[Dict]: los_category = categorize_los(los_days) # Skip pediatric visits - demographics = patient.get_events(event_type="person") - if demographics: - person = demographics[0] - birth_year = getattr(person, "year_of_birth", None) - if birth_year is not None: - try: - birth_month = int(getattr(person, "month_of_birth", None) or 1) - birth_day = int(getattr(person, "day_of_birth", None) or 1) - dob = datetime(int(birth_year), birth_month, birth_day) - if (admit_time - dob).days / 365.25 < 18: - continue - except (ValueError, TypeError): - pass + if self.exclude_minors: + demographics = patient.get_events(event_type="person") + if demographics: + person = demographics[0] + birth_year = getattr(person, "year_of_birth", None) + if birth_year is not None: + try: + birth_month = int(getattr(person, "month_of_birth", None) or 1) + birth_day = int(getattr(person, "day_of_birth", None) or 1) + dob = datetime(int(birth_year), birth_month, birth_day) + if (admit_time - dob).days / 365.25 < 18: + continue + except (ValueError, TypeError): + pass samples.append( { diff --git a/pyhealth/tasks/mortality_prediction.py b/pyhealth/tasks/mortality_prediction.py index 7c1a7c59d..f4f1aac55 100644 --- a/pyhealth/tasks/mortality_prediction.py +++ b/pyhealth/tasks/mortality_prediction.py @@ -811,6 +811,17 @@ class MortalityPredictionEICU(BaseTask): } output_schema: Dict[str, str] = {"mortality": "binary"} + def __init__(self, exclude_minors: bool = True, **kwargs) -> None: + """Initializes the task object. + + Args: + exclude_minors: Whether to exclude stays where the patient + was under 18 years old. Defaults to True. + **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`. + """ + super().__init__(**kwargs) + self.exclude_minors = exclude_minors + def __call__(self, patient: Any) -> List[Dict[str, Any]]: """Processes a single patient for the mortality prediction task. @@ -875,13 +886,13 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: if len(conditions) * len(procedures_list) * len(drugs) == 0: continue - # Exclude stays with age < 18 - age = getattr(stay, "age", None) - try: - if age is not None and str(age) != "> 89" and int(float(age)) < 18: - continue - except (ValueError, TypeError): - pass + if self.exclude_minors: + age = getattr(stay, "age", None) + try: + if age is not None and str(age) != "> 89" and int(float(age)) < 18: + continue + except (ValueError, TypeError): + pass samples.append( { @@ -922,6 +933,17 @@ class MortalityPredictionEICU2(BaseTask): input_schema: Dict[str, str] = {"conditions": "sequence", "procedures": "sequence"} output_schema: Dict[str, str] = {"mortality": "binary"} + def __init__(self, exclude_minors: bool = True, **kwargs) -> None: + """Initializes the task object. + + Args: + exclude_minors: Whether to exclude stays where the patient + was under 18 years old. Defaults to True. + **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`. + """ + super().__init__(**kwargs) + self.exclude_minors = exclude_minors + def __call__(self, patient: Any) -> List[Dict[str, Any]]: """Processes a single patient for the mortality prediction task. @@ -997,13 +1019,13 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: if len(conditions) * len(treatment_codes) == 0: continue - # Exclude stays with age < 18 - age = getattr(stay, "age", None) - try: - if age is not None and str(age) != "> 89" and int(float(age)) < 18: - continue - except (ValueError, TypeError): - pass + if self.exclude_minors: + age = getattr(stay, "age", None) + try: + if age is not None and str(age) != "> 89" and int(float(age)) < 18: + continue + except (ValueError, TypeError): + pass samples.append( { From face2002a7cb2fbfa043e21065515e460748aaec Mon Sep 17 00:00:00 2001 From: Daryl Okeke Date: Mon, 25 May 2026 13:29:11 -0500 Subject: [PATCH 3/4] Add exclude_minors to LengthOfStayPredictioneICU --- pyhealth/tasks/length_of_stay_prediction.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/pyhealth/tasks/length_of_stay_prediction.py b/pyhealth/tasks/length_of_stay_prediction.py index 06bfbc74d..1f99895e7 100644 --- a/pyhealth/tasks/length_of_stay_prediction.py +++ b/pyhealth/tasks/length_of_stay_prediction.py @@ -319,6 +319,17 @@ class LengthOfStayPredictioneICU(BaseTask): } output_schema: Dict[str, str] = {"los": "multiclass"} + def __init__(self, exclude_minors: bool = True, **kwargs) -> None: + """Initializes the task object. + + Args: + exclude_minors: Whether to exclude stays where the patient + was under 18 years old. Defaults to True. + **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`. + """ + super().__init__(**kwargs) + self.exclude_minors = exclude_minors + def __call__(self, patient: Patient) -> List[Dict]: samples = [] @@ -377,6 +388,15 @@ def __call__(self, patient: Patient) -> List[Dict]: if len(conditions) * len(procedures) * len(drugs) == 0: continue + # Skip pediatric stays + if self.exclude_minors: + age = getattr(stay, "age", None) + try: + if age is not None and str(age) != "> 89" and int(float(age)) < 18: + continue + except (ValueError, TypeError): + pass + # --- Length of stay --- # unitdischargeoffset is the number of minutes from ICU admission # to ICU discharge. This directly gives us the ICU LOS. From 1f52e3c9aa6dbcf79d1d2783825c62bc5a999f66 Mon Sep 17 00:00:00 2001 From: Daryl Okeke Date: Mon, 25 May 2026 22:48:18 -0400 Subject: [PATCH 4/4] Resolve trainer mode from output_schema instead of model.mode --- pyhealth/trainer.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pyhealth/trainer.py b/pyhealth/trainer.py index bc6a28677..dcc183e03 100644 --- a/pyhealth/trainer.py +++ b/pyhealth/trainer.py @@ -313,6 +313,15 @@ def inference(self, dataloader, additional_outputs=None, outputs.append(patient_ids) return outputs + def _get_mode(self) -> Optional[str]: + try: + label_key = self.model.label_keys[0] + return self.model._resolve_mode( + self.model.dataset.output_schema[label_key] + ) + except (AttributeError, IndexError, KeyError, ValueError): + return getattr(self.model, "mode", None) + def evaluate(self, dataloader) -> Dict[str, float]: """Evaluates the model. @@ -322,9 +331,9 @@ def evaluate(self, dataloader) -> Dict[str, float]: Returns: scores: a dictionary of scores. """ - if self.model.mode is not None: + mode = self._get_mode() + if mode is not None: y_true_all, y_prob_all, loss_mean = self.inference(dataloader) - mode = self.model.mode metrics_fn = get_metrics_fn(mode) scores = metrics_fn(y_true_all, y_prob_all, metrics=self.metrics) scores["loss"] = loss_mean