diff --git a/pyhealth/tasks/drug_recommendation.py b/pyhealth/tasks/drug_recommendation.py index ec113e1dd..4a576c9bd 100644 --- a/pyhealth/tasks/drug_recommendation.py +++ b/pyhealth/tasks/drug_recommendation.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Any, Dict, Iterable, List, Optional import polars as pl @@ -338,7 +339,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: return samples -def drug_recommendation_mimic3_fn(patient: Patient): +def drug_recommendation_mimic3_fn(patient: Patient, exclude_minors: bool = True): """Processes a single patient for the drug recommendation task. Drug recommendation aims at recommending a set of drugs given the patient health @@ -379,6 +380,21 @@ def drug_recommendation_mimic3_fn(patient: Patient): } """ samples = [] + + # Skip pediatric patients + if exclude_minors: + demographics = patient.get_events(event_type="patients") + admissions = patient.get_events(event_type="admissions") + if demographics and admissions: + dob_str = getattr(demographics[0], "dob", None) + if dob_str is not None: + try: + dob = datetime.strptime(str(dob_str)[:10], "%Y-%m-%d") + if (admissions[0].timestamp - dob).days / 365.25 < 18: + return [] + except (ValueError, TypeError): + pass + for i in range(len(patient)): visit: Visit = patient[i] conditions = visit.get_code_list(table="DIAGNOSES_ICD") @@ -388,7 +404,6 @@ def drug_recommendation_mimic3_fn(patient: Patient): # exclude: visits without condition, procedure, or drug code if len(conditions) * len(procedures) * len(drugs) == 0: continue - # TODO: should also exclude visit with age < 18 samples.append( { "visit_id": visit.visit_id, @@ -425,7 +440,7 @@ def drug_recommendation_mimic3_fn(patient: Patient): return samples -def drug_recommendation_mimic4_fn(patient: Patient): +def drug_recommendation_mimic4_fn(patient: Patient, exclude_minors: bool = True): """Processes a single patient for the drug recommendation task. Drug recommendation aims at recommending a set of drugs given the patient health @@ -459,6 +474,18 @@ def drug_recommendation_mimic4_fn(patient: Patient): [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '19', '122', '98', '663', '58', '51']], 'procedures': [['1']], 'label': [['2', '3', '4']]}] """ samples = [] + + # Skip pediatric patients + if exclude_minors: + demographics = patient.get_events(event_type="patients") + if demographics: + anchor_age = getattr(demographics[0], "anchor_age", None) + try: + if anchor_age is not None and int(float(anchor_age)) < 18: + return [] + except (ValueError, TypeError): + pass + for i in range(len(patient)): visit: Visit = patient[i] conditions = visit.get_code_list(table="diagnoses_icd") @@ -468,7 +495,6 @@ def drug_recommendation_mimic4_fn(patient: Patient): # exclude: visits without condition, procedure, or drug code if len(conditions) * len(procedures) * len(drugs) == 0: continue - # TODO: should also exclude visit with age < 18 samples.append( { "visit_id": visit.visit_id, @@ -644,7 +670,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: return samples -def drug_recommendation_omop_fn(patient: Patient): +def drug_recommendation_omop_fn(patient: Patient, exclude_minors: bool = True): """Processes a single patient for the drug recommendation task. Drug recommendation aims at recommending a set of drugs given the patient health @@ -671,6 +697,24 @@ def drug_recommendation_omop_fn(patient: Patient): """ samples = [] + + # Skip pediatric patients + if exclude_minors: + demographics = patient.get_events(event_type="person") + if demographics: + person = demographics[0] + birth_year = getattr(person, "year_of_birth", None) + if birth_year is not None: + try: + birth_month = int(getattr(person, "month_of_birth", None) or 1) + birth_day = int(getattr(person, "day_of_birth", None) or 1) + dob = datetime(int(birth_year), birth_month, birth_day) + visits = patient.get_events(event_type="visit_occurrence") + if visits and (visits[0].timestamp - dob).days / 365.25 < 18: + return [] + except (ValueError, TypeError): + pass + for i in range(len(patient)): visit: Visit = patient[i] conditions = visit.get_code_list(table="condition_occurrence") @@ -679,7 +723,6 @@ def drug_recommendation_omop_fn(patient: Patient): # exclude: visits without condition, procedure, or drug code if len(conditions) * len(procedures) * len(drugs) == 0: continue - # TODO: should also exclude visit with age < 18 samples.append( { "visit_id": visit.visit_id, diff --git a/pyhealth/tasks/length_of_stay_prediction.py b/pyhealth/tasks/length_of_stay_prediction.py index 25e0c3121..1f99895e7 100644 --- a/pyhealth/tasks/length_of_stay_prediction.py +++ b/pyhealth/tasks/length_of_stay_prediction.py @@ -73,6 +73,17 @@ class LengthOfStayPredictionMIMIC3(BaseTask): } output_schema: Dict[str, str] = {"los": "multiclass"} + def __init__(self, exclude_minors: bool = True, **kwargs) -> None: + """Initializes the task object. + + Args: + exclude_minors: Whether to exclude admissions where the patient + was under 18 years old. Defaults to True. + **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`. + """ + super().__init__(**kwargs) + self.exclude_minors = exclude_minors + def __call__(self, patient: Patient) -> List[Dict]: samples = [] @@ -116,7 +127,19 @@ def __call__(self, patient: Patient) -> List[Dict]: los_days = (discharge_time - admit_time).days los_category = categorize_los(los_days) - # TODO: should also exclude visit with age < 18 + # Skip pediatric admissions + if self.exclude_minors: + demographics = patient.get_events(event_type="patients") + if demographics: + dob_str = getattr(demographics[0], "dob", None) + if dob_str is not None: + try: + dob = datetime.strptime(str(dob_str)[:10], "%Y-%m-%d") + if (admit_time - dob).days / 365.25 < 18: + continue + except (ValueError, TypeError): + pass + samples.append( { "visit_id": admission.hadm_id, @@ -171,6 +194,17 @@ class LengthOfStayPredictionMIMIC4(BaseTask): } output_schema: Dict[str, str] = {"los": "multiclass"} + def __init__(self, exclude_minors: bool = True, **kwargs) -> None: + """Initializes the task object. + + Args: + exclude_minors: Whether to exclude admissions where the patient + was under 18 years old. Defaults to True. + **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`. + """ + super().__init__(**kwargs) + self.exclude_minors = exclude_minors + def __call__(self, patient: Patient) -> List[Dict]: samples = [] @@ -219,7 +253,17 @@ def __call__(self, patient: Patient) -> List[Dict]: los_days = (discharge_time - admit_time).days los_category = categorize_los(los_days) - # TODO: should also exclude visit with age < 18 + # Skip pediatric admissions + if self.exclude_minors: + demographics = patient.get_events(event_type="patients") + if demographics: + anchor_age = getattr(demographics[0], "anchor_age", None) + try: + if anchor_age is not None and int(float(anchor_age)) < 18: + continue + except (ValueError, TypeError): + pass + samples.append( { "visit_id": admission.hadm_id, @@ -275,6 +319,17 @@ class LengthOfStayPredictioneICU(BaseTask): } output_schema: Dict[str, str] = {"los": "multiclass"} + def __init__(self, exclude_minors: bool = True, **kwargs) -> None: + """Initializes the task object. + + Args: + exclude_minors: Whether to exclude stays where the patient + was under 18 years old. Defaults to True. + **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`. + """ + super().__init__(**kwargs) + self.exclude_minors = exclude_minors + def __call__(self, patient: Patient) -> List[Dict]: samples = [] @@ -333,6 +388,15 @@ def __call__(self, patient: Patient) -> List[Dict]: if len(conditions) * len(procedures) * len(drugs) == 0: continue + # Skip pediatric stays + if self.exclude_minors: + age = getattr(stay, "age", None) + try: + if age is not None and str(age) != "> 89" and int(float(age)) < 18: + continue + except (ValueError, TypeError): + pass + # --- Length of stay --- # unitdischargeoffset is the number of minutes from ICU admission # to ICU discharge. This directly gives us the ICU LOS. @@ -402,6 +466,17 @@ class LengthOfStayPredictionOMOP(BaseTask): } output_schema: Dict[str, str] = {"los": "multiclass"} + def __init__(self, exclude_minors: bool = True, **kwargs) -> None: + """Initializes the task object. + + Args: + exclude_minors: Whether to exclude visits where the patient + was under 18 years old. Defaults to True. + **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`. + """ + super().__init__(**kwargs) + self.exclude_minors = exclude_minors + def __call__(self, patient: Patient) -> List[Dict]: samples = [] @@ -447,7 +522,22 @@ def __call__(self, patient: Patient) -> List[Dict]: los_days = (discharge_time - admit_time).days los_category = categorize_los(los_days) - # TODO: should also exclude visit with age < 18 + # Skip pediatric visits + if self.exclude_minors: + demographics = patient.get_events(event_type="person") + if demographics: + person = demographics[0] + birth_year = getattr(person, "year_of_birth", None) + if birth_year is not None: + try: + birth_month = int(getattr(person, "month_of_birth", None) or 1) + birth_day = int(getattr(person, "day_of_birth", None) or 1) + dob = datetime(int(birth_year), birth_month, birth_day) + if (admit_time - dob).days / 365.25 < 18: + continue + except (ValueError, TypeError): + pass + samples.append( { "visit_id": visit.visit_occurrence_id, diff --git a/pyhealth/tasks/mortality_prediction.py b/pyhealth/tasks/mortality_prediction.py index 249f717f3..f4f1aac55 100644 --- a/pyhealth/tasks/mortality_prediction.py +++ b/pyhealth/tasks/mortality_prediction.py @@ -811,6 +811,17 @@ class MortalityPredictionEICU(BaseTask): } output_schema: Dict[str, str] = {"mortality": "binary"} + def __init__(self, exclude_minors: bool = True, **kwargs) -> None: + """Initializes the task object. + + Args: + exclude_minors: Whether to exclude stays where the patient + was under 18 years old. Defaults to True. + **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`. + """ + super().__init__(**kwargs) + self.exclude_minors = exclude_minors + def __call__(self, patient: Any) -> List[Dict[str, Any]]: """Processes a single patient for the mortality prediction task. @@ -875,7 +886,13 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: if len(conditions) * len(procedures_list) * len(drugs) == 0: continue - # TODO: Exclude visits with age < 18 + if self.exclude_minors: + age = getattr(stay, "age", None) + try: + if age is not None and str(age) != "> 89" and int(float(age)) < 18: + continue + except (ValueError, TypeError): + pass samples.append( { @@ -916,6 +933,17 @@ class MortalityPredictionEICU2(BaseTask): input_schema: Dict[str, str] = {"conditions": "sequence", "procedures": "sequence"} output_schema: Dict[str, str] = {"mortality": "binary"} + def __init__(self, exclude_minors: bool = True, **kwargs) -> None: + """Initializes the task object. + + Args: + exclude_minors: Whether to exclude stays where the patient + was under 18 years old. Defaults to True. + **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`. + """ + super().__init__(**kwargs) + self.exclude_minors = exclude_minors + def __call__(self, patient: Any) -> List[Dict[str, Any]]: """Processes a single patient for the mortality prediction task. @@ -991,7 +1019,13 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: if len(conditions) * len(treatment_codes) == 0: continue - # TODO: Exclude visits with age < 18 + if self.exclude_minors: + age = getattr(stay, "age", None) + try: + if age is not None and str(age) != "> 89" and int(float(age)) < 18: + continue + except (ValueError, TypeError): + pass samples.append( { diff --git a/pyhealth/trainer.py b/pyhealth/trainer.py index bc6a28677..dcc183e03 100644 --- a/pyhealth/trainer.py +++ b/pyhealth/trainer.py @@ -313,6 +313,15 @@ def inference(self, dataloader, additional_outputs=None, outputs.append(patient_ids) return outputs + def _get_mode(self) -> Optional[str]: + try: + label_key = self.model.label_keys[0] + return self.model._resolve_mode( + self.model.dataset.output_schema[label_key] + ) + except (AttributeError, IndexError, KeyError, ValueError): + return getattr(self.model, "mode", None) + def evaluate(self, dataloader) -> Dict[str, float]: """Evaluates the model. @@ -322,9 +331,9 @@ def evaluate(self, dataloader) -> Dict[str, float]: Returns: scores: a dictionary of scores. """ - if self.model.mode is not None: + mode = self._get_mode() + if mode is not None: y_true_all, y_prob_all, loss_mean = self.inference(dataloader) - mode = self.model.mode metrics_fn = get_metrics_fn(mode) scores = metrics_fn(y_true_all, y_prob_all, metrics=self.metrics) scores["loss"] = loss_mean