From 0c9cbcd94b1938f6eab5dad47de96933186a62f2 Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Sat, 2 May 2026 18:41:03 -0700 Subject: [PATCH 01/15] adds bias intervention algorithm and test --- pyproject.toml | 1 + .../block_based_trial_generator.py | 125 +++++++- .../test_block_based_trial_generator.py | 267 +++++++++++++++++- uv.lock | 86 ++++++ 4 files changed, 474 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d1e74e1..d8e8597 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ version = "0.0.2rc32" readme = {file = "README.md", content-type = "text/markdown"} dependencies = [ + "aind-dynamic-foraging-models>=0.13.1", "aind_behavior_services>=0.13.5", "pydantic-settings", ] diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index 3729188..db4e933 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -10,6 +10,7 @@ TruncationParameters, ) from aind_behavior_services.task.distributions_utils import draw_sample +from aind_dynamic_foraging_models.logistic_regression import fit_logistic_regression from pydantic import BaseModel, Field from ..trial_models import Trial @@ -33,6 +34,22 @@ class AutoWaterParameters(BaseModel): ) +class BiasThreshold(BaseModel): + upper: float = Field(default=0.7, ge=0, le=1, description="Absolute value of the upper bias threshold.") + lower: float = Field(default=0.3, ge=0, le=1, description="Absolute value of the lower bias threshold.") + + +class AntiBiasParameters(BaseModel): + threshold: BiasThreshold = Field( + default=BiasThreshold(), validate_default=True, description="Thresholds for bias correction." + ) + intervention_interval: int = Field(default=10, ge=0, description="Trials between bias intervention.") + maximum_water_corrections: int = Field(default=5, ge=0, description="Number of water correction to attempt.") + volume: int = Field(default=1, ge=0, description="Volume in ul of water given.") + bias_window_length: int = Field(default=200, ge=0, description="Trials to calculate bias over.") + lickspout_offset_delta: float = Field(default=0.5, ge=0, description="Absolute value of delta to move stage.") + + class Block(BaseModel): p_right_reward: float = Field(ge=0, le=1, description="Reward probability for right side during block.") p_left_reward: float = Field(ge=0, le=1, description="Reward probability for left side during block.") @@ -81,6 +98,12 @@ class BlockBasedTrialGeneratorSpec(BaseTrialGeneratorSpecModel): description="Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", ) + antibias_parameters: Optional[AntiBiasParameters] = Field( + default=AntiBiasParameters(), + validate_default=True, + description="Antibias settings. If set, trial generator with give water and move lickspouts to combat bias.", + ) + is_baiting: bool = Field(default=False, description="Whether uncollected rewards carry over to the next trial.") def create_generator(self) -> "BlockBasedTrialGenerator": @@ -100,7 +123,9 @@ class BlockBasedTrialGenerator(ITrialGenerator, ABC): reward_history: Record of whether each trial resulted in a reward. is_left_baited: Whether the left port currently has a baited reward. is_right_baited: Whether the right port currently has a baited reward. - + trials_in_bias_intervention: trials elapsed since last bias intervention + water_corrections: number of water corrections applied to combat bias + bias: bias of session. Negative values correspond to left bias, positive right. """ def __init__(self, spec: BlockBasedTrialGeneratorSpec) -> None: @@ -117,6 +142,12 @@ def __init__(self, spec: BlockBasedTrialGeneratorSpec) -> None: self.is_right_baited: bool = False self.block: Block + # antibias parameters + self.trials_in_bias_intervention = 0 + self.water_corrections = 0 + self.bias: float + self.total_lickspout_offset = 0 + def update(self, outcome: TrialOutcome | str): """Updates generator state from the previous trial outcome. Records choice and reward history and manages baiting state. Args: @@ -170,10 +201,21 @@ def next(self) -> Trial | None: self.is_right_baited = self.block.p_right_reward > random_numbers[1] or self.is_right_baited logger.debug("Right baited: %s" % self.is_right_baited) + is_auto_response_right = None + # determine autowater - is_right_autowater = None if self._are_autowater_conditions_met(): - is_right_autowater = True if self.block.p_right_reward > self.block.p_left_reward else False + is_auto_response_right = True if self.block.p_right_reward > self.block.p_left_reward else False + logger.debug("Delivering autowater: is_auto_response_right = %s" % is_auto_response_right) + + # determine bias correction. Overrides autowater + lickspout_offset_delta = 0 + if self._are_antibias_conditions_met(): + is_auto_response_right, lickspout_offset_delta = self._determine_antibias_intervention() + logger.debug( + "Performing bias intervention: is_auto_response_right = %s, lickspout_offset_delta = %s." + % (is_auto_response_right, lickspout_offset_delta) + ) return Trial( p_reward_left=1 if (self.is_left_baited and self.spec.is_baiting) else self.block.p_left_reward, @@ -182,7 +224,8 @@ def next(self) -> Trial | None: response_deadline_duration=self.spec.response_duration, quiescence_period_duration=quiescent, inter_trial_interval_duration=iti, - is_auto_response_right=is_right_autowater, + is_auto_response_right=is_auto_response_right, + lickspout_offset_delta=lickspout_offset_delta, ) def _are_autowater_conditions_met(self) -> bool: @@ -193,6 +236,7 @@ def _are_autowater_conditions_met(self) -> bool: """ if self.spec.autowater_parameters is None: # autowater disabled + logger.debug("Autowater not configured.") return False min_ignore = self.spec.autowater_parameters.min_ignored_trials @@ -200,14 +244,86 @@ def _are_autowater_conditions_met(self) -> bool: is_ignored = [choice is None for choice in self.is_right_choice_history] if all(is_ignored[-min_ignore:]): + logger.debug("Past %s trials ignored." % min_ignore) return True is_unrewarded = [not reward for reward in self.reward_history] if all(is_unrewarded[-min_unreward:]): + logger.debug("Past %s trials unrewarded." % min_unreward) return True return False + def _are_antibias_conditions_met(self) -> bool: + """Checks whether antibias conditions are met. + + Returns: + True if antibias conditions are met, False otherwise. + """ + + if self.spec.antibias_parameters is None: # antibias disabled + logger.debug("Anitbias not configured.") + return False + + if self.trials_in_bias_intervention > self.spec.antibias_parameters.intervention_interval: + # update bias + choice_history = ( + np.array( + self.is_right_choice_history[-self.spec.antibias_parameters.bias_window_length :], dtype=float + ), + ) + reward_history = self.reward_history[-self.spec.antibias_parameters.bias_window_length :] + lr = fit_logistic_regression( + choice_history=np.array(choice_history, dtype=float), + reward_history=np.array(reward_history, dtype=float), + n_trial_back=5, + cv=10, + fit_exponential=10, + ) + self.bias = lr["df_beta"].loc["bias"]["cross_validation"].values[0] + + if self.bias <= self.spec.antibias_parameters.threshold.lower: + logger.debug("Bias calculated below threshold: %s." % self.bias) + return True + + if self.bias >= self.spec.antibias_parameters.threshold.upper: + logger.debug("Bias calculated above threshold: %s." % self.bias) + return True + + return False + + def _determine_antibias_intervention(self) -> tuple[bool | None, float]: + """Determine anitbias interventions to perform: give water or move lickspouts + + Returns: + Tuple dictating is_auto_response_right and lickspout_offset_delta of trial + """ + + is_right_autowater = None + lickspout_offset_delta = 0 + if abs(self.bias) > self.spec.antibias_parameters.threshold.upper: + if self.water_corrections < self.spec.antibias_parameters.maximum_water_corrections: + logger.debug("Correcting bias with water.") + is_right_autowater = ( + True if self.bias < 0 else False + ) # - bias values corresponds to left, so give right and vice versa + self.water_corrections += 1 + else: + logger.debug("Correcting bias with lickspout offset.") + lickspout_offset_delta = 0.5 if self.bias < 0 else -0.5 # + values move lickspout right + self.water_corrections = 0 + + elif ( + abs(self.bias) < self.spec.antibias_parameters.threshold.lower and self.total_lickspout_offset != 0 + ): # bias below lower threshold, move back towards center + logger.debug("Moving lickspout back toward center.") + delta = min(0.5, abs(self.total_lickspout_offset)) + lickspout_offset_delta = -delta if self.total_lickspout_offset > 0 else delta + + self.total_lickspout_offset += lickspout_offset_delta + + return is_right_autowater, lickspout_offset_delta + @abstractmethod def _are_end_conditions_met(self) -> bool: """Checks whether the session should end. @@ -218,6 +334,7 @@ def _are_end_conditions_met(self) -> bool: """ pass + @abstractmethod def _generate_next_block(*args, **kwargs) -> Block: """Abstract method. Subclasses must implement their own block switching logic. diff --git a/tests/trial_generators/test_block_based_trial_generator.py b/tests/trial_generators/test_block_based_trial_generator.py index 60efd2d..f063813 100644 --- a/tests/trial_generators/test_block_based_trial_generator.py +++ b/tests/trial_generators/test_block_based_trial_generator.py @@ -1,6 +1,7 @@ import logging import unittest -from unittest.mock import patch +from unittest.mock import patch, MagicMock +import pandas as pd import numpy as np @@ -8,6 +9,9 @@ Block, BlockBasedTrialGenerator, BlockBasedTrialGeneratorSpec, + AntiBiasParameters, + AutoWaterParameters, + BiasThreshold ) from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial @@ -60,6 +64,267 @@ def test_baiting_disabled_reward_prob_unchanged(self): self.assertEqual(trial.p_reward_right, 0.8) self.assertEqual(trial.p_reward_left, 0.2) +class TestAntiBiasBlockBasedTrialGenerator(unittest.TestCase): + + def _patch_bias(self, bias_value: float) -> dict: + + return_value ={"df_beta": MagicMock()} + return_value["df_beta"].loc = {"bias": {"cross_validation": MagicMock()}} + return_value["df_beta"].loc["bias"]["cross_validation"].values = [bias_value] + + return patch( + "aind_behavior_dynamic_foraging.task_logic.trial_generators.block_based_trial_generator.fit_logistic_regression", + return_value=return_value + ) + + def _make_generator(self, + bias: float, + trials_in_bias_intervention: int = 15, + water_corrections: int = 0, + maximum_water_corrections: int = 5, + bias_window_length: int = 5, + intervention_interval: int = 10, + total_offset: float = 0.0, + threshold: BiasThreshold = BiasThreshold(upper=0.7, lower=.3) + ) -> ConcreteBlockBasedTrialGenerator: + ab = AntiBiasParameters( + maximum_water_corrections=maximum_water_corrections, + bias_window_length=bias_window_length, + intervention_interval=intervention_interval, + threshold=threshold + ) + spec = ConcreteBlockBasedTrialGeneratorSpec(antibias_parameters=ab) + gen = spec.create_generator() + gen.block = Block(p_left_reward=0.2, p_right_reward=0.8, left_length=10, right_length=10) + gen.total_lickspout_offset = total_offset + gen.bias = bias + gen.trials_in_bias_intervention = trials_in_bias_intervention + gen.water_corrections = water_corrections + gen.is_right_choice_history = [True] * 100 + gen.reward_history = [True] * 100 + return gen + + def test_returns_false_when_antibias_disabled(self): + """Antibias should never trigger when antibias_parameters is None.""" + spec = ConcreteBlockBasedTrialGeneratorSpec(antibias_parameters=None) + gen = spec.create_generator() + self.assertFalse(gen._are_antibias_conditions_met()) + + def test_returns_false_before_intervention_interval(self): + """Condition should not trigger before the intervention interval is exceeded.""" + + gen = self._make_generator(bias=.5, intervention_interval=10) + gen.trials_in_bias_intervention = 5 + self.assertFalse(gen._are_antibias_conditions_met()) + + def test_returns_false_when_bias_within_thresholds(self): + """No intervention when bias sits between lower and upper thresholds.""" + gen = self._make_generator( + bias=.5, + intervention_interval=10, + threshold=BiasThreshold(upper=0.7, lower=0.3), + bias_window_length=5, + ) + gen.trials_in_bias_intervention = 15 + gen.is_right_choice_history = [True, False] * 50 + gen.reward_history = [True] * 100 + + with self._patch_bias(0.5): + result = gen._are_antibias_conditions_met() + + self.assertFalse(result) + + def test_returns_true_when_bias_above_upper_threshold(self): + """Intervention when bias is above threshold""" + gen = self._make_generator( + bias = .9, + intervention_interval=10, + threshold=BiasThreshold(upper=0.7, lower=0.3), + bias_window_length=5, + ) + gen.trials_in_bias_intervention = 15 + gen.is_right_choice_history = [True] * 100 + gen.reward_history = [True] * 100 + + with self._patch_bias(0.9): + result = gen._are_antibias_conditions_met() + + self.assertTrue(result) + + def test_returns_true_when_bias_below_lower_threshold(self): + """Intervention when bias is below threshold""" + gen = self._make_generator( + bias=.2, + intervention_interval=10, + threshold=BiasThreshold(upper=0.7, lower=0.3), + bias_window_length=5, + ) + gen.trials_in_bias_intervention = 15 + gen.is_right_choice_history = [False] * 100 + gen.reward_history = [False] * 100 + + with self._patch_bias(0.9): + result = gen._are_antibias_conditions_met() + + self.assertTrue(result) + + def test_bias_stored_on_generator_after_check(self): + """The computed bias value should be saved on the generator.""" + gen = self._make_generator( + bias = 0, + intervention_interval=10, + bias_window_length=5, + ) + + with self._patch_bias(0.42): + gen._are_antibias_conditions_met() + + self.assertAlmostEqual(gen.bias, 0.42) + + def test_gives_right_water_on_left_bias(self): + """Negative bias (left bias) → give right water.""" + gen = self._make_generator(bias=-0.9, maximum_water_corrections=5) + is_right, delta = gen._determine_antibias_intervention() + self.assertTrue(is_right) + self.assertEqual(delta, 0.0) + + def test_gives_left_water_on_right_bias(self): + """Positive bias (right bias) → give left water.""" + gen = self._make_generator(bias=0.9, maximum_water_corrections=5) + is_right, delta = gen._determine_antibias_intervention() + self.assertFalse(is_right) + self.assertEqual(delta, 0.0) + + def test_water_corrections_counter_increments(self): + gen = self._make_generator(bias=-0.9, water_corrections=2, maximum_water_corrections=5) + gen._determine_antibias_intervention() + self.assertEqual(gen.water_corrections, 3) + + def test_switches_to_lickspout_after_max_corrections_left_bias(self): + """After exhausting water corrections, move lickspout right (combat left bias).""" + gen = self._make_generator(bias=-0.9, water_corrections=5, maximum_water_corrections=5) + is_right, delta = gen._determine_antibias_intervention() + self.assertIsNone(is_right) + self.assertGreater(delta, 0) + + def test_switches_to_lickspout_after_max_corrections_right_bias(self): + """After exhausting water corrections, move lickspout left (combat right bias).""" + gen = self._make_generator(bias=0.9, water_corrections=5, maximum_water_corrections=5) + is_right, delta = gen._determine_antibias_intervention() + self.assertIsNone(is_right) + self.assertLess(delta, 0) + + def test_water_corrections_reset_after_lickspout_move(self): + gen = self._make_generator(bias=-0.9, water_corrections=5, maximum_water_corrections=5) + gen._determine_antibias_intervention() + self.assertEqual(gen.water_corrections, 0) + + # #### Test lickspout centering #### + + def test_no_centering_when_offset_is_zero(self): + """No correction when already centered, even if bias drops below lower threshold.""" + gen = self._make_generator( + bias=0.1, + total_offset=0.0, + threshold=BiasThreshold(upper=0.7, lower=0.3), + ) + _, delta = gen._determine_antibias_intervention() + self.assertEqual(delta, 0.0) + + def test_centering_moves_toward_zero_from_positive_offset(self): + """Positive offset + low bias → negative delta (move back left).""" + gen = self._make_generator( + bias=0.1, + total_offset=1.0, + threshold=BiasThreshold(upper=0.7, lower=0.3), + ) + _, delta = gen._determine_antibias_intervention() + self.assertLess(delta, 0) + + def test_centering_moves_toward_zero_from_negative_offset(self): + """Negative offset + low bias → positive delta (move back right).""" + gen = self._make_generator( + bias=0.1, + total_offset=-1.0, + threshold=BiasThreshold(upper=0.7, lower=0.3), + ) + _, delta = gen._determine_antibias_intervention() + self.assertGreater(delta, 0) + + def test_centering_step_capped_at_offset_magnitude(self): + """Centering delta should not overshoot: capped at min(0.5, |offset|).""" + gen = self._make_generator( + bias=0.1, + total_offset=0.2, + threshold=BiasThreshold(upper=0.7, lower=0.3), + ) + _, delta = gen._determine_antibias_intervention() + self.assertLessEqual(abs(delta), 0.2) + + def test_total_lickspout_offset_updated_after_move(self): + """total_lickspout_offset should accumulate the delta applied.""" + gen = self._make_generator( + bias=-0.9, water_corrections=5, maximum_water_corrections=5, total_offset=0.0 + ) + _, delta = gen._determine_antibias_intervention() + self.assertAlmostEqual(gen.total_lickspout_offset, delta) + + #### Test next #### + + def test_next_gives_right_autowater_on_left_bias(self): + gen = self._make_generator(bias=-0.9) + with self._patch_bias(-0.9): + trial = gen.next() + self.assertIsNotNone(trial) + self.assertTrue(trial.is_auto_response_right) + + def test_next_gives_left_autowater_on_right_bias(self): + gen = self._make_generator(bias=0.9) + with self._patch_bias(0.9): + trial = gen.next() + self.assertIsNotNone(trial) + self.assertFalse(trial.is_auto_response_right) + + def test_next_no_antibias_when_below_interval(self): + """No antibias effect when trials_in_bias_intervention has not exceeded interval.""" + gen = self._make_generator(bias=-0.9, trials_in_bias_intervention=5) + trial = gen.next() + self.assertIsNone(trial.is_auto_response_right) + + def test_next_antibias_overrides_autowater(self): + """When both autowater and antibias conditions are met, antibias takes precedence.""" + ab = AntiBiasParameters( + intervention_interval=10, + threshold=BiasThreshold(upper=0.7, lower=0.3), + maximum_water_corrections=5, + bias_window_length=5, + ) + aw = AutoWaterParameters(min_ignored_trials=1, min_unrewarded_trials=1, reward_fraction=0.8) + spec = ConcreteBlockBasedTrialGeneratorSpec(antibias_parameters=ab, autowater_parameters=aw) + gen = spec.create_generator() + gen.block = Block(p_left_reward=0.2, p_right_reward=0.8, left_length=10, right_length=10) + gen.bias = -0.9 + gen.trials_in_bias_intervention = 15 + gen.is_right_choice_history = [None] # ignored trial → autowater would also fire + gen.reward_history = [False] + + with self._patch_bias(-0.9): + trial = gen.next() + + # Antibias (left bias → give right water) should win + self.assertTrue(trial.is_auto_response_right) + + def test_next_lickspout_delta_nonzero_after_corrections_exhausted(self): + """After max water corrections, next() should produce a nonzero lickspout delta.""" + gen = self._make_generator(bias=-0.9, water_corrections=5) + with self._patch_bias(-0.9): + trial = gen.next() + self.assertNotEqual(trial.lickspout_offset_delta, 0) + + def test_next_no_lickspout_delta_when_antibias_not_triggered(self): + gen = self._make_generator(bias=-0.9, trials_in_bias_intervention=5) + trial = gen.next() + self.assertEqual(trial.lickspout_offset_delta, 0) class TestBlockBaseBaitingTrialGenerator(unittest.TestCase): def setUp(self): diff --git a/uv.lock b/uv.lock index 16af59d..d8a6db8 100644 --- a/uv.lock +++ b/uv.lock @@ -51,6 +51,7 @@ version = "0.0.2rc32" source = { editable = "." } dependencies = [ { name = "aind-behavior-services" }, + { name = "aind-dynamic-foraging-models" }, { name = "pydantic-settings" }, ] @@ -79,6 +80,7 @@ docs = [ [package.metadata] requires-dist = [ { name = "aind-behavior-services", specifier = ">=0.13.5" }, + { name = "aind-dynamic-foraging-models", specifier = ">=0.13.1" }, { name = "contraqctor", marker = "extra == 'data'", specifier = ">=0.5.3" }, { name = "pydantic-settings" }, ] @@ -242,6 +244,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/bc/c180bff4ab1ac67aeec5206afe75cc0a785980d3f38ba7482daae36231c5/aind_data_transfer_service-2.0.3-py3-none-any.whl", hash = "sha256:4429048c1896dba2fb4688f6fc2ca6bbae5e4e1d645417417e9ad45fd3736bc6", size = 23147, upload-time = "2026-03-23T23:51:32.237Z" }, ] +[[package]] +name = "aind-dynamic-foraging-models" +version = "0.13.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "numpy" }, + { name = "pydantic" }, + { name = "scikit-learn" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/47/a7b35997f9767cc4911ce65db39a06492192d338175487916d98b70b1640/aind_dynamic_foraging_models-0.13.1.tar.gz", hash = "sha256:260e8cacb8b55b43f12b2443f876ced4d1cde364f9938199c1b2466425cd2419", size = 6648444, upload-time = "2026-02-21T01:05:25.913Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/12/e299389a360dae91b07b936ebed085fa1a718e210319faa50d4766054b26/aind_dynamic_foraging_models-0.13.1-py3-none-any.whl", hash = "sha256:34602f2ee02fa1d5b48449f05980b2ddc37c4f770608a2df3989875856571a13", size = 50200, upload-time = "2026-02-21T01:05:24.632Z" }, +] + [[package]] name = "aind-watchdog-service" version = "0.1.6" @@ -1086,6 +1104,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] +[[package]] +name = "joblib" +version = "1.5.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/41/f2/d34e8b3a08a9cc79a50b2208a93dce981fe615b64d5a4d4abee421d898df/joblib-1.5.3.tar.gz", hash = "sha256:8561a3269e6801106863fd0d6d84bb737be9e7631e33aaed3fb9ce5953688da3", size = 331603, upload-time = "2025-12-15T08:41:46.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl", hash = "sha256:5fc3c5039fc5ca8c0276333a188bbd59d6b7ab37fe6632daa76bc7f9ec18e713", size = 309071, upload-time = "2025-12-15T08:41:44.973Z" }, +] + [[package]] name = "jsonpointer" version = "3.1.1" @@ -2506,6 +2533,56 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/e8/726643a3ea68c727da31570bde48c7a10f1aa60eddd628d94078fec586ff/ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2", size = 11023304, upload-time = "2026-03-19T16:26:51.669Z" }, ] +[[package]] +name = "scikit-learn" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "joblib" }, + { name = "numpy" }, + { name = "scipy" }, + { name = "threadpoolctl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/d4/40988bf3b8e34feec1d0e6a051446b1f66225f8529b9309becaeef62b6c4/scikit_learn-1.8.0.tar.gz", hash = "sha256:9bccbb3b40e3de10351f8f5068e105d0f4083b1a65fa07b6634fbc401a6287fd", size = 7335585, upload-time = "2025-12-10T07:08:53.618Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/92/53ea2181da8ac6bf27170191028aee7251f8f841f8d3edbfdcaf2008fde9/scikit_learn-1.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:146b4d36f800c013d267b29168813f7a03a43ecd2895d04861f1240b564421da", size = 8595835, upload-time = "2025-12-10T07:07:39.385Z" }, + { url = "https://files.pythonhosted.org/packages/01/18/d154dc1638803adf987910cdd07097d9c526663a55666a97c124d09fb96a/scikit_learn-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f984ca4b14914e6b4094c5d52a32ea16b49832c03bd17a110f004db3c223e8e1", size = 8080381, upload-time = "2025-12-10T07:07:41.93Z" }, + { url = "https://files.pythonhosted.org/packages/8a/44/226142fcb7b7101e64fdee5f49dbe6288d4c7af8abf593237b70fca080a4/scikit_learn-1.8.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5e30adb87f0cc81c7690a84f7932dd66be5bac57cfe16b91cb9151683a4a2d3b", size = 8799632, upload-time = "2025-12-10T07:07:43.899Z" }, + { url = "https://files.pythonhosted.org/packages/36/4d/4a67f30778a45d542bbea5db2dbfa1e9e100bf9ba64aefe34215ba9f11f6/scikit_learn-1.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ada8121bcb4dac28d930febc791a69f7cb1673c8495e5eee274190b73a4559c1", size = 9103788, upload-time = "2025-12-10T07:07:45.982Z" }, + { url = "https://files.pythonhosted.org/packages/89/3c/45c352094cfa60050bcbb967b1faf246b22e93cb459f2f907b600f2ceda5/scikit_learn-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:c57b1b610bd1f40ba43970e11ce62821c2e6569e4d74023db19c6b26f246cb3b", size = 8081706, upload-time = "2025-12-10T07:07:48.111Z" }, + { url = "https://files.pythonhosted.org/packages/3d/46/5416595bb395757f754feb20c3d776553a386b661658fb21b7c814e89efe/scikit_learn-1.8.0-cp311-cp311-win_arm64.whl", hash = "sha256:2838551e011a64e3053ad7618dda9310175f7515f1742fa2d756f7c874c05961", size = 7688451, upload-time = "2025-12-10T07:07:49.873Z" }, + { url = "https://files.pythonhosted.org/packages/90/74/e6a7cc4b820e95cc38cf36cd74d5aa2b42e8ffc2d21fe5a9a9c45c1c7630/scikit_learn-1.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5fb63362b5a7ddab88e52b6dbb47dac3fd7dafeee740dc6c8d8a446ddedade8e", size = 8548242, upload-time = "2025-12-10T07:07:51.568Z" }, + { url = "https://files.pythonhosted.org/packages/49/d8/9be608c6024d021041c7f0b3928d4749a706f4e2c3832bbede4fb4f58c95/scikit_learn-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:5025ce924beccb28298246e589c691fe1b8c1c96507e6d27d12c5fadd85bfd76", size = 8079075, upload-time = "2025-12-10T07:07:53.697Z" }, + { url = "https://files.pythonhosted.org/packages/dd/47/f187b4636ff80cc63f21cd40b7b2d177134acaa10f6bb73746130ee8c2e5/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4496bb2cf7a43ce1a2d7524a79e40bc5da45cf598dbf9545b7e8316ccba47bb4", size = 8660492, upload-time = "2025-12-10T07:07:55.574Z" }, + { url = "https://files.pythonhosted.org/packages/97/74/b7a304feb2b49df9fafa9382d4d09061a96ee9a9449a7cbea7988dda0828/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0bcfe4d0d14aec44921545fd2af2338c7471de9cb701f1da4c9d85906ab847a", size = 8931904, upload-time = "2025-12-10T07:07:57.666Z" }, + { url = "https://files.pythonhosted.org/packages/9f/c4/0ab22726a04ede56f689476b760f98f8f46607caecff993017ac1b64aa5d/scikit_learn-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:35c007dedb2ffe38fe3ee7d201ebac4a2deccd2408e8621d53067733e3c74809", size = 8019359, upload-time = "2025-12-10T07:07:59.838Z" }, + { url = "https://files.pythonhosted.org/packages/24/90/344a67811cfd561d7335c1b96ca21455e7e472d281c3c279c4d3f2300236/scikit_learn-1.8.0-cp312-cp312-win_arm64.whl", hash = "sha256:8c497fff237d7b4e07e9ef1a640887fa4fb765647f86fbe00f969ff6280ce2bb", size = 7641898, upload-time = "2025-12-10T07:08:01.36Z" }, + { url = "https://files.pythonhosted.org/packages/03/aa/e22e0768512ce9255eba34775be2e85c2048da73da1193e841707f8f039c/scikit_learn-1.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0d6ae97234d5d7079dc0040990a6f7aeb97cb7fa7e8945f1999a429b23569e0a", size = 8513770, upload-time = "2025-12-10T07:08:03.251Z" }, + { url = "https://files.pythonhosted.org/packages/58/37/31b83b2594105f61a381fc74ca19e8780ee923be2d496fcd8d2e1147bd99/scikit_learn-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:edec98c5e7c128328124a029bceb09eda2d526997780fef8d65e9a69eead963e", size = 8044458, upload-time = "2025-12-10T07:08:05.336Z" }, + { url = "https://files.pythonhosted.org/packages/2d/5a/3f1caed8765f33eabb723596666da4ebbf43d11e96550fb18bdec42b467b/scikit_learn-1.8.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:74b66d8689d52ed04c271e1329f0c61635bcaf5b926db9b12d58914cdc01fe57", size = 8610341, upload-time = "2025-12-10T07:08:07.732Z" }, + { url = "https://files.pythonhosted.org/packages/38/cf/06896db3f71c75902a8e9943b444a56e727418f6b4b4a90c98c934f51ed4/scikit_learn-1.8.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8fdf95767f989b0cfedb85f7ed8ca215d4be728031f56ff5a519ee1e3276dc2e", size = 8900022, upload-time = "2025-12-10T07:08:09.862Z" }, + { url = "https://files.pythonhosted.org/packages/1c/f9/9b7563caf3ec8873e17a31401858efab6b39a882daf6c1bfa88879c0aa11/scikit_learn-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:2de443b9373b3b615aec1bb57f9baa6bb3a9bd093f1269ba95c17d870422b271", size = 7989409, upload-time = "2025-12-10T07:08:12.028Z" }, + { url = "https://files.pythonhosted.org/packages/49/bd/1f4001503650e72c4f6009ac0c4413cb17d2d601cef6f71c0453da2732fc/scikit_learn-1.8.0-cp313-cp313-win_arm64.whl", hash = "sha256:eddde82a035681427cbedded4e6eff5e57fa59216c2e3e90b10b19ab1d0a65c3", size = 7619760, upload-time = "2025-12-10T07:08:13.688Z" }, + { url = "https://files.pythonhosted.org/packages/d2/7d/a630359fc9dcc95496588c8d8e3245cc8fd81980251079bc09c70d41d951/scikit_learn-1.8.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:7cc267b6108f0a1499a734167282c00c4ebf61328566b55ef262d48e9849c735", size = 8826045, upload-time = "2025-12-10T07:08:15.215Z" }, + { url = "https://files.pythonhosted.org/packages/cc/56/a0c86f6930cfcd1c7054a2bc417e26960bb88d32444fe7f71d5c2cfae891/scikit_learn-1.8.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:fe1c011a640a9f0791146011dfd3c7d9669785f9fed2b2a5f9e207536cf5c2fd", size = 8420324, upload-time = "2025-12-10T07:08:17.561Z" }, + { url = "https://files.pythonhosted.org/packages/46/1e/05962ea1cebc1cf3876667ecb14c283ef755bf409993c5946ade3b77e303/scikit_learn-1.8.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:72358cce49465d140cc4e7792015bb1f0296a9742d5622c67e31399b75468b9e", size = 8680651, upload-time = "2025-12-10T07:08:19.952Z" }, + { url = "https://files.pythonhosted.org/packages/fe/56/a85473cd75f200c9759e3a5f0bcab2d116c92a8a02ee08ccd73b870f8bb4/scikit_learn-1.8.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:80832434a6cc114f5219211eec13dcbc16c2bac0e31ef64c6d346cde3cf054cb", size = 8925045, upload-time = "2025-12-10T07:08:22.11Z" }, + { url = "https://files.pythonhosted.org/packages/cc/b7/64d8cfa896c64435ae57f4917a548d7ac7a44762ff9802f75a79b77cb633/scikit_learn-1.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ee787491dbfe082d9c3013f01f5991658b0f38aa8177e4cd4bf434c58f551702", size = 8507994, upload-time = "2025-12-10T07:08:23.943Z" }, + { url = "https://files.pythonhosted.org/packages/5e/37/e192ea709551799379958b4c4771ec507347027bb7c942662c7fbeba31cb/scikit_learn-1.8.0-cp313-cp313t-win_arm64.whl", hash = "sha256:bf97c10a3f5a7543f9b88cbf488d33d175e9146115a451ae34568597ba33dcde", size = 7869518, upload-time = "2025-12-10T07:08:25.71Z" }, + { url = "https://files.pythonhosted.org/packages/24/05/1af2c186174cc92dcab2233f327336058c077d38f6fe2aceb08e6ab4d509/scikit_learn-1.8.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:c22a2da7a198c28dd1a6e1136f19c830beab7fdca5b3e5c8bba8394f8a5c45b3", size = 8528667, upload-time = "2025-12-10T07:08:27.541Z" }, + { url = "https://files.pythonhosted.org/packages/a8/25/01c0af38fe969473fb292bba9dc2b8f9b451f3112ff242c647fee3d0dfe7/scikit_learn-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:6b595b07a03069a2b1740dc08c2299993850ea81cce4fe19b2421e0c970de6b7", size = 8066524, upload-time = "2025-12-10T07:08:29.822Z" }, + { url = "https://files.pythonhosted.org/packages/be/ce/a0623350aa0b68647333940ee46fe45086c6060ec604874e38e9ab7d8e6c/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:29ffc74089f3d5e87dfca4c2c8450f88bdc61b0fc6ed5d267f3988f19a1309f6", size = 8657133, upload-time = "2025-12-10T07:08:31.865Z" }, + { url = "https://files.pythonhosted.org/packages/b8/cb/861b41341d6f1245e6ca80b1c1a8c4dfce43255b03df034429089ca2a2c5/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fb65db5d7531bccf3a4f6bec3462223bea71384e2cda41da0f10b7c292b9e7c4", size = 8923223, upload-time = "2025-12-10T07:08:34.166Z" }, + { url = "https://files.pythonhosted.org/packages/76/18/a8def8f91b18cd1ba6e05dbe02540168cb24d47e8dcf69e8d00b7da42a08/scikit_learn-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:56079a99c20d230e873ea40753102102734c5953366972a71d5cb39a32bc40c6", size = 8096518, upload-time = "2025-12-10T07:08:36.339Z" }, + { url = "https://files.pythonhosted.org/packages/d1/77/482076a678458307f0deb44e29891d6022617b2a64c840c725495bee343f/scikit_learn-1.8.0-cp314-cp314-win_arm64.whl", hash = "sha256:3bad7565bc9cf37ce19a7c0d107742b320c1285df7aab1a6e2d28780df167242", size = 7754546, upload-time = "2025-12-10T07:08:38.128Z" }, + { url = "https://files.pythonhosted.org/packages/2d/d1/ef294ca754826daa043b2a104e59960abfab4cf653891037d19dd5b6f3cf/scikit_learn-1.8.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:4511be56637e46c25721e83d1a9cea9614e7badc7040c4d573d75fbe257d6fd7", size = 8848305, upload-time = "2025-12-10T07:08:41.013Z" }, + { url = "https://files.pythonhosted.org/packages/5b/e2/b1f8b05138ee813b8e1a4149f2f0d289547e60851fd1bb268886915adbda/scikit_learn-1.8.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:a69525355a641bf8ef136a7fa447672fb54fe8d60cab5538d9eb7c6438543fb9", size = 8432257, upload-time = "2025-12-10T07:08:42.873Z" }, + { url = "https://files.pythonhosted.org/packages/26/11/c32b2138a85dcb0c99f6afd13a70a951bfdff8a6ab42d8160522542fb647/scikit_learn-1.8.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c2656924ec73e5939c76ac4c8b026fc203b83d8900362eb2599d8aee80e4880f", size = 8678673, upload-time = "2025-12-10T07:08:45.362Z" }, + { url = "https://files.pythonhosted.org/packages/c7/57/51f2384575bdec454f4fe4e7a919d696c9ebce914590abf3e52d47607ab8/scikit_learn-1.8.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15fc3b5d19cc2be65404786857f2e13c70c83dd4782676dd6814e3b89dc8f5b9", size = 8922467, upload-time = "2025-12-10T07:08:47.408Z" }, + { url = "https://files.pythonhosted.org/packages/35/4d/748c9e2872637a57981a04adc038dacaa16ba8ca887b23e34953f0b3f742/scikit_learn-1.8.0-cp314-cp314t-win_amd64.whl", hash = "sha256:00d6f1d66fbcf4eba6e356e1420d33cc06c70a45bb1363cd6f6a8e4ebbbdece2", size = 8774395, upload-time = "2025-12-10T07:08:49.337Z" }, + { url = "https://files.pythonhosted.org/packages/60/22/d7b2ebe4704a5e50790ba089d5c2ae308ab6bb852719e6c3bd4f04c3a363/scikit_learn-1.8.0-cp314-cp314t-win_arm64.whl", hash = "sha256:f28dd15c6bb0b66ba09728cf09fd8736c304be29409bd8445a080c1280619e8c", size = 8002647, upload-time = "2025-12-10T07:08:51.601Z" }, +] + [[package]] name = "scipy" version = "1.17.1" @@ -2817,6 +2894,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/a7/d2782e4e3f77c8450f727ba74a8f12756d5ba823d81b941f1b04da9d033a/sphinxcontrib_serializinghtml-2.0.0-py3-none-any.whl", hash = "sha256:6e2cb0eef194e10c27ec0023bfeb25badbbb5868244cf5bc5bdc04e4464bf331", size = 92072, upload-time = "2024-07-29T01:10:08.203Z" }, ] +[[package]] +name = "threadpoolctl" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274, upload-time = "2025-03-13T13:49:23.031Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" }, +] + [[package]] name = "tomli" version = "2.4.0" From f4adb533b1e7a1b210a0d08674952dcd74589df1 Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Sat, 2 May 2026 18:46:49 -0700 Subject: [PATCH 02/15] specify delta --- schema/aind_behavior_dynamic_foraging.json | 181 ++++++++ .../AindBehaviorDynamicForaging.Generated.cs | 397 ++++++++++++++++++ .../block_based_trial_generator.py | 7 +- .../test_block_based_trial_generator.py | 139 +++--- 4 files changed, 651 insertions(+), 73 deletions(-) diff --git a/schema/aind_behavior_dynamic_foraging.json b/schema/aind_behavior_dynamic_foraging.json index 2107795..cd391aa 100644 --- a/schema/aind_behavior_dynamic_foraging.json +++ b/schema/aind_behavior_dynamic_foraging.json @@ -421,6 +421,55 @@ "type": "object", "x-sgen-typename": "AllenNeuralDynamics.AindManipulator.AindManipulatorCalibration" }, + "AntiBiasParameters": { + "properties": { + "threshold": { + "$ref": "#/$defs/BiasThreshold", + "default": { + "upper": 0.7, + "lower": 0.3 + }, + "description": "Thresholds for bias correction." + }, + "intervention_interval": { + "default": 10, + "description": "Trials between bias intervention.", + "minimum": 0, + "title": "Intervention Interval", + "type": "integer" + }, + "maximum_water_corrections": { + "default": 5, + "description": "Number of water correction to attempt.", + "minimum": 0, + "title": "Maximum Water Corrections", + "type": "integer" + }, + "volume": { + "default": 1, + "description": "Volume in ul of water given.", + "minimum": 0, + "title": "Volume", + "type": "integer" + }, + "bias_window_length": { + "default": 200, + "description": "Trials to calculate bias over.", + "minimum": 0, + "title": "Bias Window Length", + "type": "integer" + }, + "lickspout_offset_delta": { + "default": 0.05, + "description": "Absolute value of delta (mm) to move stage.", + "minimum": 0, + "title": "Lickspout Offset Delta", + "type": "number" + } + }, + "title": "AntiBiasParameters", + "type": "object" + }, "AuditorySecondaryReinforcer": { "description": "Represents an auditory secondary reinforcer.", "properties": { @@ -629,6 +678,28 @@ } ] }, + "antibias_parameters": { + "default": { + "threshold": { + "lower": 0.3, + "upper": 0.7 + }, + "intervention_interval": 10, + "maximum_water_corrections": 5, + "volume": 1, + "bias_window_length": 200, + "lickspout_offset_delta": 0.05 + }, + "description": "Antibias settings. If set, trial generator with give water and move lickspouts to combat bias.", + "oneOf": [ + { + "$ref": "#/$defs/AntiBiasParameters" + }, + { + "type": "null" + } + ] + }, "is_baiting": { "default": false, "description": "Whether uncollected rewards carry over to the next trial.", @@ -764,6 +835,28 @@ "type": "object", "x-sgen-typename": "AllenNeuralDynamics.AindBehaviorServices.Distributions.BetaDistributionParameters" }, + "BiasThreshold": { + "properties": { + "upper": { + "default": 0.7, + "description": "Absolute value of the upper bias threshold.", + "maximum": 1, + "minimum": 0, + "title": "Upper", + "type": "number" + }, + "lower": { + "default": 0.3, + "description": "Absolute value of the lower bias threshold.", + "maximum": 1, + "minimum": 0, + "title": "Lower", + "type": "number" + } + }, + "title": "BiasThreshold", + "type": "object" + }, "BinomialDistribution": { "description": "A binomial probability distribution.\n\nModels the number of successes in a fixed number of independent\nBernoulli trials with constant success probability.", "properties": { @@ -929,6 +1022,28 @@ } ] }, + "antibias_parameters": { + "default": { + "threshold": { + "lower": 0.3, + "upper": 0.7 + }, + "intervention_interval": 10, + "maximum_water_corrections": 5, + "volume": 1, + "bias_window_length": 200, + "lickspout_offset_delta": 0.05 + }, + "description": "Antibias settings. If set, trial generator with give water and move lickspouts to combat bias.", + "oneOf": [ + { + "$ref": "#/$defs/AntiBiasParameters" + }, + { + "type": "null" + } + ] + }, "is_baiting": { "default": false, "description": "Whether uncollected rewards carry over to the next trial.", @@ -1195,6 +1310,28 @@ } ] }, + "antibias_parameters": { + "default": { + "threshold": { + "lower": 0.3, + "upper": 0.7 + }, + "intervention_interval": 10, + "maximum_water_corrections": 5, + "volume": 1, + "bias_window_length": 200, + "lickspout_offset_delta": 0.05 + }, + "description": "Antibias settings. If set, trial generator with give water and move lickspouts to combat bias.", + "oneOf": [ + { + "$ref": "#/$defs/AntiBiasParameters" + }, + { + "type": "null" + } + ] + }, "is_baiting": { "default": false, "description": "Whether uncollected rewards carry over to the next trial.", @@ -1384,6 +1521,28 @@ } ] }, + "antibias_parameters": { + "default": { + "threshold": { + "lower": 0.3, + "upper": 0.7 + }, + "intervention_interval": 10, + "maximum_water_corrections": 5, + "volume": 1, + "bias_window_length": 200, + "lickspout_offset_delta": 0.05 + }, + "description": "Antibias settings. If set, trial generator with give water and move lickspouts to combat bias.", + "oneOf": [ + { + "$ref": "#/$defs/AntiBiasParameters" + }, + { + "type": "null" + } + ] + }, "is_baiting": { "const": true, "default": true, @@ -3538,6 +3697,28 @@ } ] }, + "antibias_parameters": { + "default": { + "threshold": { + "lower": 0.3, + "upper": 0.7 + }, + "intervention_interval": 10, + "maximum_water_corrections": 5, + "volume": 1, + "bias_window_length": 200, + "lickspout_offset_delta": 0.05 + }, + "description": "Antibias settings. If set, trial generator with give water and move lickspouts to combat bias.", + "oneOf": [ + { + "$ref": "#/$defs/AntiBiasParameters" + }, + { + "type": "null" + } + ] + }, "is_baiting": { "default": false, "description": "Whether uncollected rewards carry over to the next trial.", diff --git a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs index 0479d06..58b8ddc 100644 --- a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs +++ b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs @@ -703,6 +703,183 @@ public override string ToString() } + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")] + [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] + [Bonsai.CombinatorAttribute(MethodName="Generate")] + public partial class AntiBiasParameters + { + + private BiasThreshold _threshold; + + private int _interventionInterval; + + private int _maximumWaterCorrections; + + private int _volume; + + private int _biasWindowLength; + + private double _lickspoutOffsetDelta; + + public AntiBiasParameters() + { + _threshold = new BiasThreshold(); + _interventionInterval = 10; + _maximumWaterCorrections = 5; + _volume = 1; + _biasWindowLength = 200; + _lickspoutOffsetDelta = 0.05D; + } + + protected AntiBiasParameters(AntiBiasParameters other) + { + _threshold = other._threshold; + _interventionInterval = other._interventionInterval; + _maximumWaterCorrections = other._maximumWaterCorrections; + _volume = other._volume; + _biasWindowLength = other._biasWindowLength; + _lickspoutOffsetDelta = other._lickspoutOffsetDelta; + } + + /// + /// Thresholds for bias correction. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("threshold")] + [System.ComponentModel.DescriptionAttribute("Thresholds for bias correction.")] + public BiasThreshold Threshold + { + get + { + return _threshold; + } + set + { + _threshold = value; + } + } + + /// + /// Trials between bias intervention. + /// + [Newtonsoft.Json.JsonPropertyAttribute("intervention_interval")] + [System.ComponentModel.DescriptionAttribute("Trials between bias intervention.")] + public int InterventionInterval + { + get + { + return _interventionInterval; + } + set + { + _interventionInterval = value; + } + } + + /// + /// Number of water correction to attempt. + /// + [Newtonsoft.Json.JsonPropertyAttribute("maximum_water_corrections")] + [System.ComponentModel.DescriptionAttribute("Number of water correction to attempt.")] + public int MaximumWaterCorrections + { + get + { + return _maximumWaterCorrections; + } + set + { + _maximumWaterCorrections = value; + } + } + + /// + /// Volume in ul of water given. + /// + [Newtonsoft.Json.JsonPropertyAttribute("volume")] + [System.ComponentModel.DescriptionAttribute("Volume in ul of water given.")] + public int Volume + { + get + { + return _volume; + } + set + { + _volume = value; + } + } + + /// + /// Trials to calculate bias over. + /// + [Newtonsoft.Json.JsonPropertyAttribute("bias_window_length")] + [System.ComponentModel.DescriptionAttribute("Trials to calculate bias over.")] + public int BiasWindowLength + { + get + { + return _biasWindowLength; + } + set + { + _biasWindowLength = value; + } + } + + /// + /// Absolute value of delta (mm) to move stage. + /// + [Newtonsoft.Json.JsonPropertyAttribute("lickspout_offset_delta")] + [System.ComponentModel.DescriptionAttribute("Absolute value of delta (mm) to move stage.")] + public double LickspoutOffsetDelta + { + get + { + return _lickspoutOffsetDelta; + } + set + { + _lickspoutOffsetDelta = value; + } + } + + public System.IObservable Generate() + { + return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new AntiBiasParameters(this))); + } + + public System.IObservable Generate(System.IObservable source) + { + return System.Reactive.Linq.Observable.Select(source, _ => new AntiBiasParameters(this)); + } + + protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder) + { + stringBuilder.Append("Threshold = " + _threshold + ", "); + stringBuilder.Append("InterventionInterval = " + _interventionInterval + ", "); + stringBuilder.Append("MaximumWaterCorrections = " + _maximumWaterCorrections + ", "); + stringBuilder.Append("Volume = " + _volume + ", "); + stringBuilder.Append("BiasWindowLength = " + _biasWindowLength + ", "); + stringBuilder.Append("LickspoutOffsetDelta = " + _lickspoutOffsetDelta); + return true; + } + + public override string ToString() + { + System.Text.StringBuilder stringBuilder = new System.Text.StringBuilder(); + stringBuilder.Append(GetType().Name); + stringBuilder.Append(" { "); + if (PrintMembers(stringBuilder)) + { + stringBuilder.Append(" "); + } + stringBuilder.Append("}"); + return stringBuilder.ToString(); + } + } + + /// /// Represents an auditory secondary reinforcer. /// @@ -897,6 +1074,8 @@ public partial class BaseCoupledTrialGeneratorSpec : TrialGeneratorSpec private AutoWaterParameters _autowaterParameters; + private AntiBiasParameters _antibiasParameters; + private bool _isBaiting; private RewardProbabilityParameters _rewardProbabilityParameters; @@ -909,6 +1088,7 @@ public BaseCoupledTrialGeneratorSpec() _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _autowaterParameters = new AutoWaterParameters(); + _antibiasParameters = new AntiBiasParameters(); _isBaiting = false; _rewardProbabilityParameters = new RewardProbabilityParameters(); } @@ -922,6 +1102,7 @@ protected BaseCoupledTrialGeneratorSpec(BaseCoupledTrialGeneratorSpec other) : _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; _autowaterParameters = other._autowaterParameters; + _antibiasParameters = other._antibiasParameters; _isBaiting = other._isBaiting; _rewardProbabilityParameters = other._rewardProbabilityParameters; } @@ -1034,6 +1215,25 @@ public AutoWaterParameters AutowaterParameters } } + /// + /// Antibias settings. If set, trial generator with give water and move lickspouts to combat bias. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator with give water and move lickspouts to" + + " combat bias.")] + public AntiBiasParameters AntibiasParameters + { + get + { + return _antibiasParameters; + } + set + { + _antibiasParameters = value; + } + } + /// /// Whether uncollected rewards carry over to the next trial. /// @@ -1091,6 +1291,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters); return true; @@ -1259,6 +1460,94 @@ public override string ToString() } + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")] + [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] + [Bonsai.CombinatorAttribute(MethodName="Generate")] + public partial class BiasThreshold + { + + private double _upper; + + private double _lower; + + public BiasThreshold() + { + _upper = 0.7D; + _lower = 0.3D; + } + + protected BiasThreshold(BiasThreshold other) + { + _upper = other._upper; + _lower = other._lower; + } + + /// + /// Absolute value of the upper bias threshold. + /// + [Newtonsoft.Json.JsonPropertyAttribute("upper")] + [System.ComponentModel.DescriptionAttribute("Absolute value of the upper bias threshold.")] + public double Upper + { + get + { + return _upper; + } + set + { + _upper = value; + } + } + + /// + /// Absolute value of the lower bias threshold. + /// + [Newtonsoft.Json.JsonPropertyAttribute("lower")] + [System.ComponentModel.DescriptionAttribute("Absolute value of the lower bias threshold.")] + public double Lower + { + get + { + return _lower; + } + set + { + _lower = value; + } + } + + public System.IObservable Generate() + { + return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new BiasThreshold(this))); + } + + public System.IObservable Generate(System.IObservable source) + { + return System.Reactive.Linq.Observable.Select(source, _ => new BiasThreshold(this)); + } + + protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder) + { + stringBuilder.Append("Upper = " + _upper + ", "); + stringBuilder.Append("Lower = " + _lower); + return true; + } + + public override string ToString() + { + System.Text.StringBuilder stringBuilder = new System.Text.StringBuilder(); + stringBuilder.Append(GetType().Name); + stringBuilder.Append(" { "); + if (PrintMembers(stringBuilder)) + { + stringBuilder.Append(" "); + } + stringBuilder.Append("}"); + return stringBuilder.ToString(); + } + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")] [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] [Bonsai.CombinatorAttribute(MethodName="Generate")] @@ -1277,6 +1566,8 @@ public partial class BlockBasedTrialGeneratorSpec : TrialGeneratorSpec private AutoWaterParameters _autowaterParameters; + private AntiBiasParameters _antibiasParameters; + private bool _isBaiting; public BlockBasedTrialGeneratorSpec() @@ -1287,6 +1578,7 @@ public BlockBasedTrialGeneratorSpec() _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _autowaterParameters = new AutoWaterParameters(); + _antibiasParameters = new AntiBiasParameters(); _isBaiting = false; } @@ -1299,6 +1591,7 @@ protected BlockBasedTrialGeneratorSpec(BlockBasedTrialGeneratorSpec other) : _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; _autowaterParameters = other._autowaterParameters; + _antibiasParameters = other._antibiasParameters; _isBaiting = other._isBaiting; } @@ -1410,6 +1703,25 @@ public AutoWaterParameters AutowaterParameters } } + /// + /// Antibias settings. If set, trial generator with give water and move lickspouts to combat bias. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator with give water and move lickspouts to" + + " combat bias.")] + public AntiBiasParameters AntibiasParameters + { + get + { + return _antibiasParameters; + } + set + { + _antibiasParameters = value; + } + } + /// /// Whether uncollected rewards carry over to the next trial. /// @@ -1449,6 +1761,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting); return true; } @@ -1976,6 +2289,8 @@ public partial class CoupledTrialGeneratorSpec : TrialGeneratorSpec private AutoWaterParameters _autowaterParameters; + private AntiBiasParameters _antibiasParameters; + private bool _isBaiting; private RewardProbabilityParameters _rewardProbabilityParameters; @@ -1998,6 +2313,7 @@ public CoupledTrialGeneratorSpec() _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _autowaterParameters = new AutoWaterParameters(); + _antibiasParameters = new AntiBiasParameters(); _isBaiting = false; _rewardProbabilityParameters = new RewardProbabilityParameters(); _trialGenerationEndParameters = new CoupledTrialGenerationEndConditions(); @@ -2016,6 +2332,7 @@ protected CoupledTrialGeneratorSpec(CoupledTrialGeneratorSpec other) : _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; _autowaterParameters = other._autowaterParameters; + _antibiasParameters = other._antibiasParameters; _isBaiting = other._isBaiting; _rewardProbabilityParameters = other._rewardProbabilityParameters; _trialGenerationEndParameters = other._trialGenerationEndParameters; @@ -2133,6 +2450,25 @@ public AutoWaterParameters AutowaterParameters } } + /// + /// Antibias settings. If set, trial generator with give water and move lickspouts to combat bias. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator with give water and move lickspouts to" + + " combat bias.")] + public AntiBiasParameters AntibiasParameters + { + get + { + return _antibiasParameters; + } + set + { + _antibiasParameters = value; + } + } + /// /// Whether uncollected rewards carry over to the next trial. /// @@ -2275,6 +2611,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters + ", "); stringBuilder.Append("TrialGenerationEndParameters = " + _trialGenerationEndParameters + ", "); @@ -2437,6 +2774,8 @@ public partial class CoupledWarmupTrialGeneratorSpec : TrialGeneratorSpec private AutoWaterParameters _autowaterParameters; + private AntiBiasParameters _antibiasParameters; + private bool _isBaiting; private RewardProbabilityParameters _rewardProbabilityParameters; @@ -2451,6 +2790,7 @@ public CoupledWarmupTrialGeneratorSpec() _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _autowaterParameters = new AutoWaterParameters(); + _antibiasParameters = new AntiBiasParameters(); _isBaiting = true; _rewardProbabilityParameters = new RewardProbabilityParameters(); _trialGenerationEndParameters = new CoupledWarmupTrialGenerationEndConditions(); @@ -2465,6 +2805,7 @@ protected CoupledWarmupTrialGeneratorSpec(CoupledWarmupTrialGeneratorSpec other) _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; _autowaterParameters = other._autowaterParameters; + _antibiasParameters = other._antibiasParameters; _isBaiting = other._isBaiting; _rewardProbabilityParameters = other._rewardProbabilityParameters; _trialGenerationEndParameters = other._trialGenerationEndParameters; @@ -2578,6 +2919,25 @@ public AutoWaterParameters AutowaterParameters } } + /// + /// Antibias settings. If set, trial generator with give water and move lickspouts to combat bias. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator with give water and move lickspouts to" + + " combat bias.")] + public AntiBiasParameters AntibiasParameters + { + get + { + return _antibiasParameters; + } + set + { + _antibiasParameters = value; + } + } + /// /// Whether uncollected rewards carry over to the next trial. /// @@ -2653,6 +3013,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters + ", "); stringBuilder.Append("TrialGenerationEndParameters = " + _trialGenerationEndParameters); @@ -6192,6 +6553,8 @@ public partial class UncoupledTrialGeneratorSpec : TrialGeneratorSpec private AutoWaterParameters _autowaterParameters; + private AntiBiasParameters _antibiasParameters; + private bool _isBaiting; private UncoupledTrialGenerationEndConditions _trialGenerationEndParameters; @@ -6208,6 +6571,7 @@ public UncoupledTrialGeneratorSpec() _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.UniformDistribution(); _autowaterParameters = new AutoWaterParameters(); + _antibiasParameters = new AntiBiasParameters(); _isBaiting = false; _trialGenerationEndParameters = new UncoupledTrialGenerationEndConditions(); _rewardProbabilities = new System.Collections.Generic.List(); @@ -6223,6 +6587,7 @@ protected UncoupledTrialGeneratorSpec(UncoupledTrialGeneratorSpec other) : _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; _autowaterParameters = other._autowaterParameters; + _antibiasParameters = other._antibiasParameters; _isBaiting = other._isBaiting; _trialGenerationEndParameters = other._trialGenerationEndParameters; _rewardProbabilities = other._rewardProbabilities; @@ -6337,6 +6702,25 @@ public AutoWaterParameters AutowaterParameters } } + /// + /// Antibias settings. If set, trial generator with give water and move lickspouts to combat bias. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator with give water and move lickspouts to" + + " combat bias.")] + public AntiBiasParameters AntibiasParameters + { + get + { + return _antibiasParameters; + } + set + { + _antibiasParameters = value; + } + } + /// /// Whether uncollected rewards carry over to the next trial. /// @@ -6429,6 +6813,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("TrialGenerationEndParameters = " + _trialGenerationEndParameters + ", "); stringBuilder.Append("RewardProbabilities = " + _rewardProbabilities + ", "); @@ -7569,6 +7954,11 @@ public System.IObservable Process(System.IObservable(source); } + public System.IObservable Process(System.IObservable source) + { + return Process(source); + } + public System.IObservable Process(System.IObservable source) { return Process(source); @@ -7594,6 +7984,11 @@ public System.IObservable Process(System.IObservable(source); } + public System.IObservable Process(System.IObservable source) + { + return Process(source); + } + public System.IObservable Process(System.IObservable source) { return Process(source); @@ -7791,11 +8186,13 @@ public System.IObservable Process(System.IObservable source) [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] + [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] + [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index db4e933..1a1d488 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -47,7 +47,7 @@ class AntiBiasParameters(BaseModel): maximum_water_corrections: int = Field(default=5, ge=0, description="Number of water correction to attempt.") volume: int = Field(default=1, ge=0, description="Volume in ul of water given.") bias_window_length: int = Field(default=200, ge=0, description="Trials to calculate bias over.") - lickspout_offset_delta: float = Field(default=0.5, ge=0, description="Absolute value of delta to move stage.") + lickspout_offset_delta: float = Field(default=0.05, ge=0, description="Absolute value of delta (mm) to move stage.") class Block(BaseModel): @@ -301,6 +301,7 @@ def _determine_antibias_intervention(self) -> tuple[bool | None, float]: is_right_autowater = None lickspout_offset_delta = 0 + ab_delta = self.spec.antibias_parameters.lickspout_offset_delta if abs(self.bias) > self.spec.antibias_parameters.threshold.upper: if self.water_corrections < self.spec.antibias_parameters.maximum_water_corrections: logger.debug("Correcting bias with water.") @@ -310,14 +311,14 @@ def _determine_antibias_intervention(self) -> tuple[bool | None, float]: self.water_corrections += 1 else: logger.debug("Correcting bias with lickspout offset.") - lickspout_offset_delta = 0.5 if self.bias < 0 else -0.5 # + values move lickspout right + lickspout_offset_delta = ab_delta if self.bias < 0 else -ab_delta # + values move lickspout right self.water_corrections = 0 elif ( abs(self.bias) < self.spec.antibias_parameters.threshold.lower and self.total_lickspout_offset != 0 ): # bias below lower threshold, move back towards center logger.debug("Moving lickspout back toward center.") - delta = min(0.5, abs(self.total_lickspout_offset)) + delta = min(ab_delta, abs(self.total_lickspout_offset)) lickspout_offset_delta = -delta if self.total_lickspout_offset > 0 else delta self.total_lickspout_offset += lickspout_offset_delta diff --git a/tests/trial_generators/test_block_based_trial_generator.py b/tests/trial_generators/test_block_based_trial_generator.py index f063813..9bf05e8 100644 --- a/tests/trial_generators/test_block_based_trial_generator.py +++ b/tests/trial_generators/test_block_based_trial_generator.py @@ -1,17 +1,16 @@ import logging import unittest -from unittest.mock import patch, MagicMock -import pandas as pd +from unittest.mock import MagicMock, patch import numpy as np from aind_behavior_dynamic_foraging.task_logic.trial_generators.block_based_trial_generator import ( + AntiBiasParameters, + AutoWaterParameters, + BiasThreshold, Block, BlockBasedTrialGenerator, BlockBasedTrialGeneratorSpec, - AntiBiasParameters, - AutoWaterParameters, - BiasThreshold ) from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial @@ -64,34 +63,35 @@ def test_baiting_disabled_reward_prob_unchanged(self): self.assertEqual(trial.p_reward_right, 0.8) self.assertEqual(trial.p_reward_left, 0.2) -class TestAntiBiasBlockBasedTrialGenerator(unittest.TestCase): +class TestAntiBiasBlockBasedTrialGenerator(unittest.TestCase): def _patch_bias(self, bias_value: float) -> dict: - - return_value ={"df_beta": MagicMock()} + + return_value = {"df_beta": MagicMock()} return_value["df_beta"].loc = {"bias": {"cross_validation": MagicMock()}} return_value["df_beta"].loc["bias"]["cross_validation"].values = [bias_value] return patch( - "aind_behavior_dynamic_foraging.task_logic.trial_generators.block_based_trial_generator.fit_logistic_regression", - return_value=return_value - ) - - def _make_generator(self, - bias: float, - trials_in_bias_intervention: int = 15, - water_corrections: int = 0, - maximum_water_corrections: int = 5, - bias_window_length: int = 5, - intervention_interval: int = 10, - total_offset: float = 0.0, - threshold: BiasThreshold = BiasThreshold(upper=0.7, lower=.3) - ) -> ConcreteBlockBasedTrialGenerator: + "aind_behavior_dynamic_foraging.task_logic.trial_generators.block_based_trial_generator.fit_logistic_regression", + return_value=return_value, + ) + + def _make_generator( + self, + bias: float, + trials_in_bias_intervention: int = 15, + water_corrections: int = 0, + maximum_water_corrections: int = 5, + bias_window_length: int = 5, + intervention_interval: int = 10, + total_offset: float = 0.0, + threshold: BiasThreshold = BiasThreshold(upper=0.7, lower=0.3), + ) -> ConcreteBlockBasedTrialGenerator: ab = AntiBiasParameters( maximum_water_corrections=maximum_water_corrections, bias_window_length=bias_window_length, intervention_interval=intervention_interval, - threshold=threshold + threshold=threshold, ) spec = ConcreteBlockBasedTrialGeneratorSpec(antibias_parameters=ab) gen = spec.create_generator() @@ -109,18 +109,18 @@ def test_returns_false_when_antibias_disabled(self): spec = ConcreteBlockBasedTrialGeneratorSpec(antibias_parameters=None) gen = spec.create_generator() self.assertFalse(gen._are_antibias_conditions_met()) - + def test_returns_false_before_intervention_interval(self): """Condition should not trigger before the intervention interval is exceeded.""" - - gen = self._make_generator(bias=.5, intervention_interval=10) + + gen = self._make_generator(bias=0.5, intervention_interval=10) gen.trials_in_bias_intervention = 5 self.assertFalse(gen._are_antibias_conditions_met()) - + def test_returns_false_when_bias_within_thresholds(self): """No intervention when bias sits between lower and upper thresholds.""" gen = self._make_generator( - bias=.5, + bias=0.5, intervention_interval=10, threshold=BiasThreshold(upper=0.7, lower=0.3), bias_window_length=5, @@ -128,16 +128,16 @@ def test_returns_false_when_bias_within_thresholds(self): gen.trials_in_bias_intervention = 15 gen.is_right_choice_history = [True, False] * 50 gen.reward_history = [True] * 100 - + with self._patch_bias(0.5): result = gen._are_antibias_conditions_met() - + self.assertFalse(result) - + def test_returns_true_when_bias_above_upper_threshold(self): """Intervention when bias is above threshold""" gen = self._make_generator( - bias = .9, + bias=0.9, intervention_interval=10, threshold=BiasThreshold(upper=0.7, lower=0.3), bias_window_length=5, @@ -145,16 +145,16 @@ def test_returns_true_when_bias_above_upper_threshold(self): gen.trials_in_bias_intervention = 15 gen.is_right_choice_history = [True] * 100 gen.reward_history = [True] * 100 - + with self._patch_bias(0.9): result = gen._are_antibias_conditions_met() - + self.assertTrue(result) - + def test_returns_true_when_bias_below_lower_threshold(self): """Intervention when bias is below threshold""" gen = self._make_generator( - bias=.2, + bias=0.2, intervention_interval=10, threshold=BiasThreshold(upper=0.7, lower=0.3), bias_window_length=5, @@ -162,65 +162,65 @@ def test_returns_true_when_bias_below_lower_threshold(self): gen.trials_in_bias_intervention = 15 gen.is_right_choice_history = [False] * 100 gen.reward_history = [False] * 100 - + with self._patch_bias(0.9): result = gen._are_antibias_conditions_met() - + self.assertTrue(result) - + def test_bias_stored_on_generator_after_check(self): """The computed bias value should be saved on the generator.""" gen = self._make_generator( - bias = 0, + bias=0, intervention_interval=10, bias_window_length=5, ) - + with self._patch_bias(0.42): gen._are_antibias_conditions_met() - + self.assertAlmostEqual(gen.bias, 0.42) def test_gives_right_water_on_left_bias(self): - """Negative bias (left bias) → give right water.""" - gen = self._make_generator(bias=-0.9, maximum_water_corrections=5) - is_right, delta = gen._determine_antibias_intervention() - self.assertTrue(is_right) - self.assertEqual(delta, 0.0) - + """Negative bias (left bias) → give right water.""" + gen = self._make_generator(bias=-0.9, maximum_water_corrections=5) + is_right, delta = gen._determine_antibias_intervention() + self.assertTrue(is_right) + self.assertEqual(delta, 0.0) + def test_gives_left_water_on_right_bias(self): """Positive bias (right bias) → give left water.""" gen = self._make_generator(bias=0.9, maximum_water_corrections=5) is_right, delta = gen._determine_antibias_intervention() self.assertFalse(is_right) self.assertEqual(delta, 0.0) - + def test_water_corrections_counter_increments(self): gen = self._make_generator(bias=-0.9, water_corrections=2, maximum_water_corrections=5) gen._determine_antibias_intervention() self.assertEqual(gen.water_corrections, 3) - + def test_switches_to_lickspout_after_max_corrections_left_bias(self): """After exhausting water corrections, move lickspout right (combat left bias).""" gen = self._make_generator(bias=-0.9, water_corrections=5, maximum_water_corrections=5) is_right, delta = gen._determine_antibias_intervention() self.assertIsNone(is_right) self.assertGreater(delta, 0) - + def test_switches_to_lickspout_after_max_corrections_right_bias(self): """After exhausting water corrections, move lickspout left (combat right bias).""" gen = self._make_generator(bias=0.9, water_corrections=5, maximum_water_corrections=5) is_right, delta = gen._determine_antibias_intervention() self.assertIsNone(is_right) self.assertLess(delta, 0) - + def test_water_corrections_reset_after_lickspout_move(self): gen = self._make_generator(bias=-0.9, water_corrections=5, maximum_water_corrections=5) gen._determine_antibias_intervention() self.assertEqual(gen.water_corrections, 0) - + # #### Test lickspout centering #### - + def test_no_centering_when_offset_is_zero(self): """No correction when already centered, even if bias drops below lower threshold.""" gen = self._make_generator( @@ -230,7 +230,7 @@ def test_no_centering_when_offset_is_zero(self): ) _, delta = gen._determine_antibias_intervention() self.assertEqual(delta, 0.0) - + def test_centering_moves_toward_zero_from_positive_offset(self): """Positive offset + low bias → negative delta (move back left).""" gen = self._make_generator( @@ -240,7 +240,7 @@ def test_centering_moves_toward_zero_from_positive_offset(self): ) _, delta = gen._determine_antibias_intervention() self.assertLess(delta, 0) - + def test_centering_moves_toward_zero_from_negative_offset(self): """Negative offset + low bias → positive delta (move back right).""" gen = self._make_generator( @@ -250,7 +250,7 @@ def test_centering_moves_toward_zero_from_negative_offset(self): ) _, delta = gen._determine_antibias_intervention() self.assertGreater(delta, 0) - + def test_centering_step_capped_at_offset_magnitude(self): """Centering delta should not overshoot: capped at min(0.5, |offset|).""" gen = self._make_generator( @@ -260,37 +260,35 @@ def test_centering_step_capped_at_offset_magnitude(self): ) _, delta = gen._determine_antibias_intervention() self.assertLessEqual(abs(delta), 0.2) - + def test_total_lickspout_offset_updated_after_move(self): """total_lickspout_offset should accumulate the delta applied.""" - gen = self._make_generator( - bias=-0.9, water_corrections=5, maximum_water_corrections=5, total_offset=0.0 - ) + gen = self._make_generator(bias=-0.9, water_corrections=5, maximum_water_corrections=5, total_offset=0.0) _, delta = gen._determine_antibias_intervention() self.assertAlmostEqual(gen.total_lickspout_offset, delta) - + #### Test next #### - + def test_next_gives_right_autowater_on_left_bias(self): gen = self._make_generator(bias=-0.9) with self._patch_bias(-0.9): trial = gen.next() self.assertIsNotNone(trial) self.assertTrue(trial.is_auto_response_right) - + def test_next_gives_left_autowater_on_right_bias(self): gen = self._make_generator(bias=0.9) with self._patch_bias(0.9): trial = gen.next() self.assertIsNotNone(trial) self.assertFalse(trial.is_auto_response_right) - + def test_next_no_antibias_when_below_interval(self): """No antibias effect when trials_in_bias_intervention has not exceeded interval.""" gen = self._make_generator(bias=-0.9, trials_in_bias_intervention=5) trial = gen.next() self.assertIsNone(trial.is_auto_response_right) - + def test_next_antibias_overrides_autowater(self): """When both autowater and antibias conditions are met, antibias takes precedence.""" ab = AntiBiasParameters( @@ -307,25 +305,26 @@ def test_next_antibias_overrides_autowater(self): gen.trials_in_bias_intervention = 15 gen.is_right_choice_history = [None] # ignored trial → autowater would also fire gen.reward_history = [False] - + with self._patch_bias(-0.9): trial = gen.next() - + # Antibias (left bias → give right water) should win self.assertTrue(trial.is_auto_response_right) - + def test_next_lickspout_delta_nonzero_after_corrections_exhausted(self): """After max water corrections, next() should produce a nonzero lickspout delta.""" gen = self._make_generator(bias=-0.9, water_corrections=5) with self._patch_bias(-0.9): trial = gen.next() self.assertNotEqual(trial.lickspout_offset_delta, 0) - + def test_next_no_lickspout_delta_when_antibias_not_triggered(self): gen = self._make_generator(bias=-0.9, trials_in_bias_intervention=5) trial = gen.next() self.assertEqual(trial.lickspout_offset_delta, 0) + class TestBlockBaseBaitingTrialGenerator(unittest.TestCase): def setUp(self): self.spec = ConcreteBlockBasedTrialGeneratorSpec(is_baiting=True) From a8f4ff29e2df5032daca869eea7be123e9d3cc16 Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Sun, 3 May 2026 08:14:08 -0700 Subject: [PATCH 03/15] fixes comments and descriptions --- .../block_based_trial_generator.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index 1a1d488..92d2315 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -41,7 +41,7 @@ class BiasThreshold(BaseModel): class AntiBiasParameters(BaseModel): threshold: BiasThreshold = Field( - default=BiasThreshold(), validate_default=True, description="Thresholds for bias correction." + default=BiasThreshold(), validate_default=True, description="Thresholds for bias correction intervention." ) intervention_interval: int = Field(default=10, ge=0, description="Trials between bias intervention.") maximum_water_corrections: int = Field(default=5, ge=0, description="Number of water correction to attempt.") @@ -101,7 +101,7 @@ class BlockBasedTrialGeneratorSpec(BaseTrialGeneratorSpecModel): antibias_parameters: Optional[AntiBiasParameters] = Field( default=AntiBiasParameters(), validate_default=True, - description="Antibias settings. If set, trial generator with give water and move lickspouts to combat bias.", + description="Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.", ) is_baiting: bool = Field(default=False, description="Whether uncollected rewards carry over to the next trial.") @@ -235,7 +235,7 @@ def _are_autowater_conditions_met(self) -> bool: True if autowater conditions are met, False otherwise. """ - if self.spec.autowater_parameters is None: # autowater disabled + if self.spec.autowater_parameters is None: logger.debug("Autowater not configured.") return False @@ -261,17 +261,13 @@ def _are_antibias_conditions_met(self) -> bool: True if antibias conditions are met, False otherwise. """ - if self.spec.antibias_parameters is None: # antibias disabled + if self.spec.antibias_parameters is None: logger.debug("Anitbias not configured.") return False if self.trials_in_bias_intervention > self.spec.antibias_parameters.intervention_interval: # update bias - choice_history = ( - np.array( - self.is_right_choice_history[-self.spec.antibias_parameters.bias_window_length :], dtype=float - ), - ) + choice_history = self.is_right_choice_history[-self.spec.antibias_parameters.bias_window_length :] reward_history = self.reward_history[-self.spec.antibias_parameters.bias_window_length :] lr = fit_logistic_regression( choice_history=np.array(choice_history, dtype=float), @@ -302,7 +298,7 @@ def _determine_antibias_intervention(self) -> tuple[bool | None, float]: is_right_autowater = None lickspout_offset_delta = 0 ab_delta = self.spec.antibias_parameters.lickspout_offset_delta - if abs(self.bias) > self.spec.antibias_parameters.threshold.upper: + if abs(self.bias) >= self.spec.antibias_parameters.threshold.upper: if self.water_corrections < self.spec.antibias_parameters.maximum_water_corrections: logger.debug("Correcting bias with water.") is_right_autowater = ( @@ -315,7 +311,7 @@ def _determine_antibias_intervention(self) -> tuple[bool | None, float]: self.water_corrections = 0 elif ( - abs(self.bias) < self.spec.antibias_parameters.threshold.lower and self.total_lickspout_offset != 0 + abs(self.bias) <= self.spec.antibias_parameters.threshold.lower and self.total_lickspout_offset != 0 ): # bias below lower threshold, move back towards center logger.debug("Moving lickspout back toward center.") delta = min(ab_delta, abs(self.total_lickspout_offset)) From dc0d022ceeb5b1494e1dd320d352829131d80ff6 Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Sun, 3 May 2026 08:19:09 -0700 Subject: [PATCH 04/15] regenerates --- schema/aind_behavior_dynamic_foraging.json | 12 +++++----- .../AindBehaviorDynamicForaging.Generated.cs | 24 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/schema/aind_behavior_dynamic_foraging.json b/schema/aind_behavior_dynamic_foraging.json index cd391aa..d511e8f 100644 --- a/schema/aind_behavior_dynamic_foraging.json +++ b/schema/aind_behavior_dynamic_foraging.json @@ -429,7 +429,7 @@ "upper": 0.7, "lower": 0.3 }, - "description": "Thresholds for bias correction." + "description": "Thresholds for bias correction intervention." }, "intervention_interval": { "default": 10, @@ -690,7 +690,7 @@ "bias_window_length": 200, "lickspout_offset_delta": 0.05 }, - "description": "Antibias settings. If set, trial generator with give water and move lickspouts to combat bias.", + "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.", "oneOf": [ { "$ref": "#/$defs/AntiBiasParameters" @@ -1034,7 +1034,7 @@ "bias_window_length": 200, "lickspout_offset_delta": 0.05 }, - "description": "Antibias settings. If set, trial generator with give water and move lickspouts to combat bias.", + "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.", "oneOf": [ { "$ref": "#/$defs/AntiBiasParameters" @@ -1322,7 +1322,7 @@ "bias_window_length": 200, "lickspout_offset_delta": 0.05 }, - "description": "Antibias settings. If set, trial generator with give water and move lickspouts to combat bias.", + "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.", "oneOf": [ { "$ref": "#/$defs/AntiBiasParameters" @@ -1533,7 +1533,7 @@ "bias_window_length": 200, "lickspout_offset_delta": 0.05 }, - "description": "Antibias settings. If set, trial generator with give water and move lickspouts to combat bias.", + "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.", "oneOf": [ { "$ref": "#/$defs/AntiBiasParameters" @@ -3709,7 +3709,7 @@ "bias_window_length": 200, "lickspout_offset_delta": 0.05 }, - "description": "Antibias settings. If set, trial generator with give water and move lickspouts to combat bias.", + "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.", "oneOf": [ { "$ref": "#/$defs/AntiBiasParameters" diff --git a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs index 58b8ddc..c5dbf2b 100644 --- a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs +++ b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs @@ -742,11 +742,11 @@ protected AntiBiasParameters(AntiBiasParameters other) } /// - /// Thresholds for bias correction. + /// Thresholds for bias correction intervention. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("threshold")] - [System.ComponentModel.DescriptionAttribute("Thresholds for bias correction.")] + [System.ComponentModel.DescriptionAttribute("Thresholds for bias correction intervention.")] public BiasThreshold Threshold { get @@ -1216,11 +1216,11 @@ public AutoWaterParameters AutowaterParameters } /// - /// Antibias settings. If set, trial generator with give water and move lickspouts to combat bias. + /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] - [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator with give water and move lickspouts to" + + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" + " combat bias.")] public AntiBiasParameters AntibiasParameters { @@ -1704,11 +1704,11 @@ public AutoWaterParameters AutowaterParameters } /// - /// Antibias settings. If set, trial generator with give water and move lickspouts to combat bias. + /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] - [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator with give water and move lickspouts to" + + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" + " combat bias.")] public AntiBiasParameters AntibiasParameters { @@ -2451,11 +2451,11 @@ public AutoWaterParameters AutowaterParameters } /// - /// Antibias settings. If set, trial generator with give water and move lickspouts to combat bias. + /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] - [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator with give water and move lickspouts to" + + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" + " combat bias.")] public AntiBiasParameters AntibiasParameters { @@ -2920,11 +2920,11 @@ public AutoWaterParameters AutowaterParameters } /// - /// Antibias settings. If set, trial generator with give water and move lickspouts to combat bias. + /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] - [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator with give water and move lickspouts to" + + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" + " combat bias.")] public AntiBiasParameters AntibiasParameters { @@ -6703,11 +6703,11 @@ public AutoWaterParameters AutowaterParameters } /// - /// Antibias settings. If set, trial generator with give water and move lickspouts to combat bias. + /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] - [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator with give water and move lickspouts to" + + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" + " combat bias.")] public AntiBiasParameters AntibiasParameters { From 53561a2eaab1a954abdd8143a97881ac69867718 Mon Sep 17 00:00:00 2001 From: Micah Woodard <110491290+micahwoodard@users.noreply.github.com> Date: Sun, 3 May 2026 11:23:34 -0700 Subject: [PATCH 05/15] Update src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py Co-authored-by: Bruno Cruz <7049351+bruno-f-cruz@users.noreply.github.com> --- .../task_logic/trial_generators/block_based_trial_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index 92d2315..16d1b2d 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -47,7 +47,7 @@ class AntiBiasParameters(BaseModel): maximum_water_corrections: int = Field(default=5, ge=0, description="Number of water correction to attempt.") volume: int = Field(default=1, ge=0, description="Volume in ul of water given.") bias_window_length: int = Field(default=200, ge=0, description="Trials to calculate bias over.") - lickspout_offset_delta: float = Field(default=0.05, ge=0, description="Absolute value of delta (mm) to move stage.") + lickspout_offset_delta: float = Field(default=0.05, ge=0, description="Distance (mm) to move the stage spouts by. This is a relative distance to the current value, not absolute.") class Block(BaseModel): From 4118245d75d23618458bda64d866bf05d2741b38 Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Sun, 3 May 2026 12:03:10 -0700 Subject: [PATCH 06/15] regenerates --- schema/aind_behavior_dynamic_foraging.json | 2 +- src/Extensions/AindBehaviorDynamicForaging.Generated.cs | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/schema/aind_behavior_dynamic_foraging.json b/schema/aind_behavior_dynamic_foraging.json index d511e8f..dc29385 100644 --- a/schema/aind_behavior_dynamic_foraging.json +++ b/schema/aind_behavior_dynamic_foraging.json @@ -461,7 +461,7 @@ }, "lickspout_offset_delta": { "default": 0.05, - "description": "Absolute value of delta (mm) to move stage.", + "description": "Distance (mm) to move the stage spouts by. This is a relative distance to the current value, not absolute.", "minimum": 0, "title": "Lickspout Offset Delta", "type": "number" diff --git a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs index c5dbf2b..95543e0 100644 --- a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs +++ b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs @@ -828,10 +828,11 @@ public int BiasWindowLength } /// - /// Absolute value of delta (mm) to move stage. + /// Distance (mm) to move the stage spouts by. This is a relative distance to the current value, not absolute. /// [Newtonsoft.Json.JsonPropertyAttribute("lickspout_offset_delta")] - [System.ComponentModel.DescriptionAttribute("Absolute value of delta (mm) to move stage.")] + [System.ComponentModel.DescriptionAttribute("Distance (mm) to move the stage spouts by. This is a relative distance to the cur" + + "rent value, not absolute.")] public double LickspoutOffsetDelta { get From b4a0660bfc51933767a486eeb31fa6f29fafafe7 Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Sun, 3 May 2026 12:11:05 -0700 Subject: [PATCH 07/15] lints --- .../trial_generators/block_based_trial_generator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index 16d1b2d..9dee9e9 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -47,7 +47,11 @@ class AntiBiasParameters(BaseModel): maximum_water_corrections: int = Field(default=5, ge=0, description="Number of water correction to attempt.") volume: int = Field(default=1, ge=0, description="Volume in ul of water given.") bias_window_length: int = Field(default=200, ge=0, description="Trials to calculate bias over.") - lickspout_offset_delta: float = Field(default=0.05, ge=0, description="Distance (mm) to move the stage spouts by. This is a relative distance to the current value, not absolute.") + lickspout_offset_delta: float = Field( + default=0.05, + ge=0, + description="Distance (mm) to move the stage spouts by. This is a relative distance to the current value, not absolute.", + ) class Block(BaseModel): From 65d65b3898d1aa91206a2da55506c547c40b6c7a Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Sun, 3 May 2026 15:51:51 -0700 Subject: [PATCH 08/15] rewrites bias calculation --- pyproject.toml | 2 +- .../block_based_trial_generator.py | 16 ++-- .../task_logic/utils/__init__.py | 3 + .../task_logic/utils/calculate_bias.py | 79 +++++++++++++++++++ .../test_block_based_trial_generator.py | 10 +-- tests/trial_generators/test_calculate_bias.py | 74 +++++++++++++++++ uv.lock | 20 +---- 7 files changed, 167 insertions(+), 37 deletions(-) create mode 100644 src/aind_behavior_dynamic_foraging/task_logic/utils/__init__.py create mode 100644 src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py create mode 100644 tests/trial_generators/test_calculate_bias.py diff --git a/pyproject.toml b/pyproject.toml index d8e8597..a861bea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,9 +19,9 @@ version = "0.0.2rc32" readme = {file = "README.md", content-type = "text/markdown"} dependencies = [ - "aind-dynamic-foraging-models>=0.13.1", "aind_behavior_services>=0.13.5", "pydantic-settings", + "scikit-learn>=1.8.0", ] [tool.uv.workspace] diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index 9dee9e9..8290312 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -10,9 +10,10 @@ TruncationParameters, ) from aind_behavior_services.task.distributions_utils import draw_sample -from aind_dynamic_foraging_models.logistic_regression import fit_logistic_regression from pydantic import BaseModel, Field +from aind_behavior_dynamic_foraging.task_logic.utils.calculate_bias import calculate_bias + from ..trial_models import Trial from ._base import BaseTrialGeneratorSpecModel, ITrialGenerator, TrialOutcome @@ -140,6 +141,7 @@ def __init__(self, spec: BlockBasedTrialGeneratorSpec) -> None: """ self.spec = spec + self.outcome_history: list[TrialOutcome] = [] self.is_right_choice_history: list[bool | None] = [] self.reward_history: list[bool] = [] self.is_left_baited: bool = False @@ -161,6 +163,7 @@ def update(self, outcome: TrialOutcome | str): if isinstance(outcome, str): outcome = TrialOutcome.model_validate_json(outcome) + self.outcome_history.append(outcome) self.is_right_choice_history.append(outcome.is_right_choice) self.reward_history.append(outcome.is_rewarded) @@ -271,16 +274,7 @@ def _are_antibias_conditions_met(self) -> bool: if self.trials_in_bias_intervention > self.spec.antibias_parameters.intervention_interval: # update bias - choice_history = self.is_right_choice_history[-self.spec.antibias_parameters.bias_window_length :] - reward_history = self.reward_history[-self.spec.antibias_parameters.bias_window_length :] - lr = fit_logistic_regression( - choice_history=np.array(choice_history, dtype=float), - reward_history=np.array(reward_history, dtype=float), - n_trial_back=5, - cv=10, - fit_exponential=10, - ) - self.bias = lr["df_beta"].loc["bias"]["cross_validation"].values[0] + self.bias = calculate_bias(outcome_history=self.outcome_history) if self.bias <= self.spec.antibias_parameters.threshold.lower: logger.debug("Bias calculated below threshold: %s." % self.bias) diff --git a/src/aind_behavior_dynamic_foraging/task_logic/utils/__init__.py b/src/aind_behavior_dynamic_foraging/task_logic/utils/__init__.py new file mode 100644 index 0000000..ed9492e --- /dev/null +++ b/src/aind_behavior_dynamic_foraging/task_logic/utils/__init__.py @@ -0,0 +1,3 @@ +from .calculate_bias import calculate_bias + +__all__ = ["calculate_bias"] diff --git a/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py b/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py new file mode 100644 index 0000000..0fcb769 --- /dev/null +++ b/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py @@ -0,0 +1,79 @@ +import logging +from typing import List + +import numpy as np +from sklearn.linear_model import LogisticRegressionCV + +from aind_behavior_dynamic_foraging.task_logic.trial_models import TrialOutcome + +logger = logging.getLogger(__name__) + + +def calculate_bias(outcomes: List[TrialOutcome]) -> float: + """Estimate the side bias of an animal using logistic regression on recent trial history. + + Fits a Su2022-style logistic regression model using rewarded and unrewarded choice + history as predictors. The intercept of the fitted model is returned as the bias, + representing the animal's baseline tendency to choose right independent of reward history. + + Parameters + ---------- + outcomes : List[TrialOutcome] + List of trial outcomes. Auto-response and ignored trials are excluded. + Only the most recent 200 trials are used. + + Returns + ------- + float + The logistic regression intercept, representing side bias. + Positive values indicate a bias toward right, negative toward left. + """ + + solver = "liblinear" + l1_ratios = (0,) + lag = 5 + cv = 10 + cs = 10 + + # reduce outcomes to last 200 + outcomes = outcomes[-200:] + + # exclude auto response and ignored trials + filtered = [t for t in outcomes if t.is_right_choice is not None and t.trial.is_auto_response_right is None] + + if len(filtered) <= lag: + logger.warning("Not enough choices to calculate bias.") + return np.nan + + is_right_choice_history = np.array([t.is_right_choice for t in filtered], dtype=float) + is_rewarded_history = np.array([t.is_rewarded for t in filtered], dtype=float) + + # transform to +-1 space for zero centered logistic regression + choice_signed = 2 * is_right_choice_history - 1 # left=0 → -1, right=1 → +1 + reward_signed = 2 * is_rewarded_history - 1 # unrewarded=0 → -1, rewarded=1 → +1 + + rewarded_choice = choice_signed * (reward_signed == 1) + unrewarded_choice = choice_signed * (reward_signed == -1) + + trial_length = len(is_rewarded_history) + x = np.zeros((trial_length - lag, 2 * lag)) + for i in range(lag, trial_length): + x[i - lag] = np.hstack([choice[i - lag : i] for choice in [rewarded_choice, unrewarded_choice]]) + + y = choice_signed[lag:] + + if len(np.unique(y)) < 2: # all choices are the same, return max bias in that direction + logger.warning("All choices are the same, cannot calculate bias.") + return np.nan + + logistic_reg = LogisticRegressionCV( + solver=solver, + l1_ratios=l1_ratios, + Cs=cs, + cv=cv, + use_legacy_attributes=True, + ) + logistic_reg.fit(x, y) + + bias = logistic_reg.intercept_[0] + return bias diff --git a/tests/trial_generators/test_block_based_trial_generator.py b/tests/trial_generators/test_block_based_trial_generator.py index 9bf05e8..076dbcc 100644 --- a/tests/trial_generators/test_block_based_trial_generator.py +++ b/tests/trial_generators/test_block_based_trial_generator.py @@ -1,6 +1,6 @@ import logging import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import patch import numpy as np @@ -67,13 +67,9 @@ def test_baiting_disabled_reward_prob_unchanged(self): class TestAntiBiasBlockBasedTrialGenerator(unittest.TestCase): def _patch_bias(self, bias_value: float) -> dict: - return_value = {"df_beta": MagicMock()} - return_value["df_beta"].loc = {"bias": {"cross_validation": MagicMock()}} - return_value["df_beta"].loc["bias"]["cross_validation"].values = [bias_value] - return patch( - "aind_behavior_dynamic_foraging.task_logic.trial_generators.block_based_trial_generator.fit_logistic_regression", - return_value=return_value, + "aind_behavior_dynamic_foraging.task_logic.trial_generators.block_based_trial_generator.calculate_bias", + return_value=bias_value, ) def _make_generator( diff --git a/tests/trial_generators/test_calculate_bias.py b/tests/trial_generators/test_calculate_bias.py new file mode 100644 index 0000000..752346b --- /dev/null +++ b/tests/trial_generators/test_calculate_bias.py @@ -0,0 +1,74 @@ +import math +import unittest + +import numpy as np + +from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome +from aind_behavior_dynamic_foraging.task_logic.utils.calculate_bias import calculate_bias + + +def make_outcomes(n, right_prob=0.5, reward_prob=0.5, **trial_kwargs) -> list[TrialOutcome]: + outcomes = [] + for _ in range(n): + is_right = bool(np.random.rand() < right_prob) + is_rewarded = bool(np.random.rand() < reward_prob) + outcomes.append(TrialOutcome(is_right_choice=is_right, is_rewarded=is_rewarded, trial=Trial(**trial_kwargs))) + return outcomes + + +class TestCalculateBias(unittest.TestCase): + def setUp(self): + np.random.seed(42) + + def test_right_bias_is_positive(self): + """Always choosing right with alternating rewards should produce positive bias.""" + outcomes = make_outcomes(200, 0.9, 0.5) + bias = calculate_bias(outcomes) + self.assertGreater(bias, 0) + + def test_left_bias_is_negative(self): + """Always choosing left with alternating rewards should produce negative bias.""" + outcomes = make_outcomes(100, 0.1, 0.5) + bias = calculate_bias(outcomes) + print(bias) + self.assertLess(bias, 0) + + def test_unbiased_near_zero(self): + """Alternating choices should produce bias near zero.""" + outcomes = make_outcomes(100) + bias = calculate_bias(outcomes) + self.assertAlmostEqual(bias, 0, delta=1.0) + + def test_uniform_choices_returns_nan(self): + """All same choice and reward should return nan since bias is undefined.""" + outcomes = make_outcomes(100, 1, 1) + bias = calculate_bias(outcomes) + self.assertTrue(math.isnan(bias)) + + def test_ignored_trials_excluded(self): + """Adding ignored trials should not change the result.""" + valid = make_outcomes(80) + ignored = [TrialOutcome(is_right_choice=None, is_rewarded=False, trial=Trial()) for _ in range(20)] + self.assertEqual(calculate_bias(valid + ignored), calculate_bias(valid)) + + def test_auto_response_trials_excluded(self): + """Adding auto response trials should not change the result.""" + valid = make_outcomes(80) + auto = make_outcomes(20, 1, 1, is_auto_response_right=True) + self.assertEqual(calculate_bias(valid + auto), calculate_bias(valid)) + + def test_only_uses_last_200_trials(self): + """Trials beyond the last 200 should not affect the result.""" + old = make_outcomes(100, 0) + recent = make_outcomes(200, 0.9) + self.assertEqual(calculate_bias(old + recent), calculate_bias(recent)) + + def test_too_few_trials_returns_nan(self): + """Fewer trials than lag should return nan since regression cannot be fit.""" + outcomes = make_outcomes(5) + bias = calculate_bias(outcomes) + self.assertTrue(math.isnan(bias)) + + +if __name__ == "__main__": + unittest.main() diff --git a/uv.lock b/uv.lock index d8a6db8..e062634 100644 --- a/uv.lock +++ b/uv.lock @@ -51,8 +51,8 @@ version = "0.0.2rc32" source = { editable = "." } dependencies = [ { name = "aind-behavior-services" }, - { name = "aind-dynamic-foraging-models" }, { name = "pydantic-settings" }, + { name = "scikit-learn" }, ] [package.optional-dependencies] @@ -80,9 +80,9 @@ docs = [ [package.metadata] requires-dist = [ { name = "aind-behavior-services", specifier = ">=0.13.5" }, - { name = "aind-dynamic-foraging-models", specifier = ">=0.13.1" }, { name = "contraqctor", marker = "extra == 'data'", specifier = ">=0.5.3" }, { name = "pydantic-settings" }, + { name = "scikit-learn", specifier = ">=1.8.0" }, ] provides-extras = ["data"] @@ -244,22 +244,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/bc/c180bff4ab1ac67aeec5206afe75cc0a785980d3f38ba7482daae36231c5/aind_data_transfer_service-2.0.3-py3-none-any.whl", hash = "sha256:4429048c1896dba2fb4688f6fc2ca6bbae5e4e1d645417417e9ad45fd3736bc6", size = 23147, upload-time = "2026-03-23T23:51:32.237Z" }, ] -[[package]] -name = "aind-dynamic-foraging-models" -version = "0.13.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "matplotlib" }, - { name = "numpy" }, - { name = "pydantic" }, - { name = "scikit-learn" }, - { name = "scipy" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/eb/47/a7b35997f9767cc4911ce65db39a06492192d338175487916d98b70b1640/aind_dynamic_foraging_models-0.13.1.tar.gz", hash = "sha256:260e8cacb8b55b43f12b2443f876ced4d1cde364f9938199c1b2466425cd2419", size = 6648444, upload-time = "2026-02-21T01:05:25.913Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/61/12/e299389a360dae91b07b936ebed085fa1a718e210319faa50d4766054b26/aind_dynamic_foraging_models-0.13.1-py3-none-any.whl", hash = "sha256:34602f2ee02fa1d5b48449f05980b2ddc37c4f770608a2df3989875856571a13", size = 50200, upload-time = "2026-02-21T01:05:24.632Z" }, -] - [[package]] name = "aind-watchdog-service" version = "0.1.6" From 46c741f7f54d71fd324af031d56ee252185d2c6c Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Sun, 3 May 2026 15:53:55 -0700 Subject: [PATCH 09/15] removes print --- tests/trial_generators/test_calculate_bias.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/trial_generators/test_calculate_bias.py b/tests/trial_generators/test_calculate_bias.py index 752346b..2851878 100644 --- a/tests/trial_generators/test_calculate_bias.py +++ b/tests/trial_generators/test_calculate_bias.py @@ -22,7 +22,7 @@ def setUp(self): def test_right_bias_is_positive(self): """Always choosing right with alternating rewards should produce positive bias.""" - outcomes = make_outcomes(200, 0.9, 0.5) + outcomes = make_outcomes(100, 0.9, 0.5) bias = calculate_bias(outcomes) self.assertGreater(bias, 0) @@ -30,7 +30,6 @@ def test_left_bias_is_negative(self): """Always choosing left with alternating rewards should produce negative bias.""" outcomes = make_outcomes(100, 0.1, 0.5) bias = calculate_bias(outcomes) - print(bias) self.assertLess(bias, 0) def test_unbiased_near_zero(self): From ec58e76cacf9e40292f14b8639a7742761bff302 Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Wed, 27 May 2026 10:14:20 -0700 Subject: [PATCH 10/15] add metrics function --- .../task_logic/trial_generators/_base.py | 5 ++++- .../task_logic/trial_models.py | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py index 2405daa..85ffdee 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py @@ -3,7 +3,7 @@ from pydantic import BaseModel -from ..trial_models import Trial, TrialOutcome +from ..trial_models import Trial, TrialOutcome, TrialMetrics class BaseTrialGeneratorSpecModel(BaseModel, abc.ABC): @@ -24,3 +24,6 @@ def next(self) -> Trial | None: def update(self, outcome: TrialOutcome | str) -> None: """Update the trial generator with the outcome of the previous trial.""" + + def metrics(self) -> TrialMetrics: + """Return metrics of session at current state of the trial generator.""" diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py index adb8d07..a633d94 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py @@ -98,3 +98,8 @@ class TrialOutcome(BaseModel): description="Reports the choice made by the subject. True for right, False for left, None for no choice." ) is_rewarded: bool = Field(description="Indicates whether the subject received a reward on this trial.") + +class TrialMetrics(BaseModel): + """Represents metrics of session""" + + bias: Optional[float] = Field(default=None, description="Bias of session. Negative values correspond to left bias, positive right.") \ No newline at end of file From a07b08611be41eacecfaaf29416cdd623a526bef Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Thu, 28 May 2026 10:23:10 -0700 Subject: [PATCH 11/15] adds metrics function --- .../block_based_trial_generator.py | 23 ++++++++---- .../task_logic/utils/calculate_bias.py | 13 ++++--- .../test_block_based_trial_generator.py | 31 ++++++---------- tests/trial_generators/test_calculate_bias.py | 36 ++++++++++++++++++- .../test_coupled_trial_generator.py | 2 +- .../test_uncoupled_trial_generator.py | 2 +- .../test_warmup_trial_generator.py | 2 +- 7 files changed, 74 insertions(+), 35 deletions(-) diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index 71e9e0b..c3e3046 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -16,7 +16,7 @@ from aind_behavior_dynamic_foraging.task_logic.utils.calculate_bias import calculate_bias -from ..trial_models import Metadata, Trial +from ..trial_models import Metadata, Trial, TrialMetrics from ._base import BaseTrialGeneratorSpecModel, ITrialGenerator, TrialOutcome logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ class AutoWaterParameters(BaseModel): ge=0, le=1, description="Fraction of full reward volume delivered during auto water (0=none, 1=full).", - ) + ) # TODO: Not implemented yet class BiasThreshold(BaseModel): @@ -54,7 +54,12 @@ class AntiBiasParameters(BaseModel): ) intervention_interval: int = Field(default=10, ge=0, description="Trials between bias intervention.") maximum_water_corrections: int = Field(default=5, ge=0, description="Number of water correction to attempt.") - volume: int = Field(default=1, ge=0, description="Volume in ul of water given.") + reward_fraction: float = Field( + default=0.8, + ge=0, + le=1, + description="Fraction of full reward volume delivered during auto water (0=none, 1=full).", + ) # TODO: Not implemented yet bias_window_length: int = Field(default=200, ge=0, description="Trials to calculate bias over.") lickspout_offset_delta: float = Field( default=0.05, @@ -108,7 +113,7 @@ class BlockBasedTrialGeneratorSpec(BaseTrialGeneratorSpecModel): autowater_parameters: Optional[AutoWaterParameters] = Field( default=AutoWaterParameters(), validate_default=True, - description="Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + description="Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", ) antibias_parameters: Optional[AntiBiasParameters] = Field( @@ -185,6 +190,8 @@ def update(self, outcome: TrialOutcome | str): else: # trial ignored so current baiting state retained pass + + self.bias = calculate_bias(outcomes=self.outcome_history) def next(self) -> Trial | None: """Generates the next trial in the session. @@ -247,6 +254,12 @@ def next(self) -> Trial | None: extra=BlockBasedTrialMetadata(is_autowater=is_autowater), ), ) + + def metrics(self) -> TrialMetrics: + """Return metrics of session at current state of the trial generator.""" + + return TrialMetrics(bias=self.bias) + def _are_autowater_conditions_met(self) -> bool: """Checks whether autowater should be given. @@ -286,8 +299,6 @@ def _are_antibias_conditions_met(self) -> bool: return False if self.trials_in_bias_intervention > self.spec.antibias_parameters.intervention_interval: - # update bias - self.bias = calculate_bias(outcome_history=self.outcome_history) if self.bias <= self.spec.antibias_parameters.threshold.lower: logger.debug("Bias calculated below threshold: %s." % self.bias) diff --git a/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py b/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py index 0fcb769..3efe33f 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py @@ -7,7 +7,7 @@ from aind_behavior_dynamic_foraging.task_logic.trial_models import TrialOutcome logger = logging.getLogger(__name__) - + def calculate_bias(outcomes: List[TrialOutcome]) -> float: """Estimate the side bias of an animal using logistic regression on recent trial history. @@ -62,9 +62,14 @@ def calculate_bias(outcomes: List[TrialOutcome]) -> float: y = choice_signed[lag:] - if len(np.unique(y)) < 2: # all choices are the same, return max bias in that direction - logger.warning("All choices are the same, cannot calculate bias.") - return np.nan + n_right_choice = np.sum(y == 1) + n_left_choice = np.sum(y == -1) + if min(n_right_choice, n_left_choice) < cv: + logger.warning( + "Not enough trials per class to fit logistic regression (need %d per class, got %d right, %d left).", + cv, n_right_choice, n_left_choice + ) + return np.nan logistic_reg = LogisticRegressionCV( solver=solver, diff --git a/tests/trial_generators/test_block_based_trial_generator.py b/tests/trial_generators/test_block_based_trial_generator.py index f61e123..29c966e 100644 --- a/tests/trial_generators/test_block_based_trial_generator.py +++ b/tests/trial_generators/test_block_based_trial_generator.py @@ -12,7 +12,7 @@ BlockBasedTrialGenerator, BlockBasedTrialGeneratorSpec, ) -from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial +from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome logging.basicConfig(level=logging.DEBUG) @@ -112,9 +112,7 @@ def test_returns_false_when_bias_within_thresholds(self): gen.trials_in_bias_intervention = 15 gen.is_right_choice_history = [True, False] * 50 gen.reward_history = [True] * 100 - - with self._patch_bias(0.5): - result = gen._are_antibias_conditions_met() + result = gen._are_antibias_conditions_met() self.assertFalse(result) @@ -129,9 +127,7 @@ def test_returns_true_when_bias_above_upper_threshold(self): gen.trials_in_bias_intervention = 15 gen.is_right_choice_history = [True] * 100 gen.reward_history = [True] * 100 - - with self._patch_bias(0.9): - result = gen._are_antibias_conditions_met() + result = gen._are_antibias_conditions_met() self.assertTrue(result) @@ -146,9 +142,7 @@ def test_returns_true_when_bias_below_lower_threshold(self): gen.trials_in_bias_intervention = 15 gen.is_right_choice_history = [False] * 100 gen.reward_history = [False] * 100 - - with self._patch_bias(0.9): - result = gen._are_antibias_conditions_met() + result = gen._are_antibias_conditions_met() self.assertTrue(result) @@ -161,7 +155,7 @@ def test_bias_stored_on_generator_after_check(self): ) with self._patch_bias(0.42): - gen._are_antibias_conditions_met() + gen.update(TrialOutcome(is_rewarded=True, is_right_choice=True, trial=Trial())) self.assertAlmostEqual(gen.bias, 0.42) @@ -255,15 +249,13 @@ def test_total_lickspout_offset_updated_after_move(self): def test_next_gives_right_autowater_on_left_bias(self): gen = self._make_generator(bias=-0.9) - with self._patch_bias(-0.9): - trial = gen.next() + trial = gen.next() self.assertIsNotNone(trial) self.assertTrue(trial.is_auto_response_right) def test_next_gives_left_autowater_on_right_bias(self): gen = self._make_generator(bias=0.9) - with self._patch_bias(0.9): - trial = gen.next() + trial = gen.next() self.assertIsNotNone(trial) self.assertFalse(trial.is_auto_response_right) @@ -289,9 +281,7 @@ def test_next_antibias_overrides_autowater(self): gen.trials_in_bias_intervention = 15 gen.is_right_choice_history = [None] # ignored trial → autowater would also fire gen.reward_history = [False] - - with self._patch_bias(-0.9): - trial = gen.next() + trial = gen.next() # Antibias (left bias → give right water) should win self.assertTrue(trial.is_auto_response_right) @@ -299,9 +289,8 @@ def test_next_antibias_overrides_autowater(self): def test_next_lickspout_delta_nonzero_after_corrections_exhausted(self): """After max water corrections, next() should produce a nonzero lickspout delta.""" gen = self._make_generator(bias=-0.9, water_corrections=5) - with self._patch_bias(-0.9): - trial = gen.next() - self.assertNotEqual(trial.lickspout_offset_delta, 0) + trial = gen.next() + self.assertEqual(trial.lickspout_offset_delta, 0.05) def test_next_no_lickspout_delta_when_antibias_not_triggered(self): gen = self._make_generator(bias=-0.9, trials_in_bias_intervention=5) diff --git a/tests/trial_generators/test_calculate_bias.py b/tests/trial_generators/test_calculate_bias.py index 2851878..2577891 100644 --- a/tests/trial_generators/test_calculate_bias.py +++ b/tests/trial_generators/test_calculate_bias.py @@ -1,6 +1,6 @@ import math import unittest - +import time import numpy as np from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome @@ -68,6 +68,40 @@ def test_too_few_trials_returns_nan(self): bias = calculate_bias(outcomes) self.assertTrue(math.isnan(bias)) +TIME_LIMIT_MS = 170 +N_REPEATS = 20 + + +class TestCalculateBiasTiming(unittest.TestCase): + + def _assert_within_time_limit(self, n_trials: int): + times = [] + for _ in range(N_REPEATS): + outcomes = make_outcomes(n_trials) + t0 = time.perf_counter() + calculate_bias(outcomes) + times.append((time.perf_counter() - t0) * 1000) + mean_ms = float(np.mean(times)) + self.assertLess( + mean_ms, + TIME_LIMIT_MS, + f"calculate_bias with {n_trials} trials too slow: " + f"{mean_ms:.1f}ms (limit {TIME_LIMIT_MS}ms)" + ) + + def test_timing_50_trials(self): + self._assert_within_time_limit(50) + + def test_timing_100_trials(self): + self._assert_within_time_limit(100) + + def test_timing_200_trials(self): + self._assert_within_time_limit(200) + + def test_timing_1000_trials(self): + # should be similar to 200 since only last 200 trials evaluated + self._assert_within_time_limit(1000) + if __name__ == "__main__": unittest.main() diff --git a/tests/trial_generators/test_coupled_trial_generator.py b/tests/trial_generators/test_coupled_trial_generator.py index e4cc59e..f22c77d 100644 --- a/tests/trial_generators/test_coupled_trial_generator.py +++ b/tests/trial_generators/test_coupled_trial_generator.py @@ -7,7 +7,7 @@ from aind_behavior_dynamic_foraging.task_logic.trial_generators import CoupledTrialGeneratorSpec from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome -from .util import simulate_response +from tests.trial_generators.util import simulate_response logging.basicConfig(level=logging.DEBUG) diff --git a/tests/trial_generators/test_uncoupled_trial_generator.py b/tests/trial_generators/test_uncoupled_trial_generator.py index 2b21422..cde2774 100644 --- a/tests/trial_generators/test_uncoupled_trial_generator.py +++ b/tests/trial_generators/test_uncoupled_trial_generator.py @@ -11,7 +11,7 @@ ) from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial -from .util import simulate_response +from tests.trial_generators.util import simulate_response logging.basicConfig(level=logging.DEBUG) diff --git a/tests/trial_generators/test_warmup_trial_generator.py b/tests/trial_generators/test_warmup_trial_generator.py index e478685..3f8f5fb 100644 --- a/tests/trial_generators/test_warmup_trial_generator.py +++ b/tests/trial_generators/test_warmup_trial_generator.py @@ -7,7 +7,7 @@ ) from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome -from .util import simulate_response +from tests.trial_generators.util import simulate_response def make_outcome(is_right_choice: bool | None, is_rewarded: bool) -> TrialOutcome: From e54ac89629fc00ae82d9c7efe90b4feb9ae51bef Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Thu, 28 May 2026 10:24:53 -0700 Subject: [PATCH 12/15] lints --- schema/aind_behavior_dynamic_foraging.json | 31 ++++++------ .../AindBehaviorDynamicForaging.Generated.cs | 50 +++++++++---------- .../task_logic/trial_generators/_base.py | 2 +- .../block_based_trial_generator.py | 14 ++---- .../task_logic/trial_models.py | 5 +- .../task_logic/utils/calculate_bias.py | 8 +-- tests/trial_generators/test_calculate_bias.py | 14 +++--- .../test_coupled_trial_generator.py | 1 - .../test_uncoupled_trial_generator.py | 1 - .../test_warmup_trial_generator.py | 1 - 10 files changed, 63 insertions(+), 64 deletions(-) diff --git a/schema/aind_behavior_dynamic_foraging.json b/schema/aind_behavior_dynamic_foraging.json index 7127ebf..10b53c3 100644 --- a/schema/aind_behavior_dynamic_foraging.json +++ b/schema/aind_behavior_dynamic_foraging.json @@ -337,12 +337,13 @@ "title": "Maximum Water Corrections", "type": "integer" }, - "volume": { - "default": 1, - "description": "Volume in ul of water given.", + "reward_fraction": { + "default": 0.8, + "description": "Fraction of full reward volume delivered during auto water (0=none, 1=full).", + "maximum": 1, "minimum": 0, - "title": "Volume", - "type": "integer" + "title": "Reward Fraction", + "type": "number" }, "bias_window_length": { "default": 200, @@ -560,7 +561,7 @@ "min_unrewarded_trials": 3, "reward_fraction": 0.8 }, - "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { "$ref": "#/$defs/AutoWaterParameters" @@ -578,7 +579,7 @@ }, "intervention_interval": 10, "maximum_water_corrections": 5, - "volume": 1, + "reward_fraction": 0.8, "bias_window_length": 200, "lickspout_offset_delta": 0.05 }, @@ -904,7 +905,7 @@ "min_unrewarded_trials": 3, "reward_fraction": 0.8 }, - "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { "$ref": "#/$defs/AutoWaterParameters" @@ -922,7 +923,7 @@ }, "intervention_interval": 10, "maximum_water_corrections": 5, - "volume": 1, + "reward_fraction": 0.8, "bias_window_length": 200, "lickspout_offset_delta": 0.05 }, @@ -1192,7 +1193,7 @@ "min_unrewarded_trials": 3, "reward_fraction": 0.8 }, - "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { "$ref": "#/$defs/AutoWaterParameters" @@ -1210,7 +1211,7 @@ }, "intervention_interval": 10, "maximum_water_corrections": 5, - "volume": 1, + "reward_fraction": 0.8, "bias_window_length": 200, "lickspout_offset_delta": 0.05 }, @@ -1403,7 +1404,7 @@ "min_unrewarded_trials": 3, "reward_fraction": 0.8 }, - "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { "$ref": "#/$defs/AutoWaterParameters" @@ -1421,7 +1422,7 @@ }, "intervention_interval": 10, "maximum_water_corrections": 5, - "volume": 1, + "reward_fraction": 0.8, "bias_window_length": 200, "lickspout_offset_delta": 0.05 }, @@ -3745,7 +3746,7 @@ "min_unrewarded_trials": 3, "reward_fraction": 0.8 }, - "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { "$ref": "#/$defs/AutoWaterParameters" @@ -3763,7 +3764,7 @@ }, "intervention_interval": 10, "maximum_water_corrections": 5, - "volume": 1, + "reward_fraction": 0.8, "bias_window_length": 200, "lickspout_offset_delta": 0.05 }, diff --git a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs index 9264560..7314e3b 100644 --- a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs +++ b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs @@ -715,7 +715,7 @@ public partial class AntiBiasParameters private int _maximumWaterCorrections; - private int _volume; + private double _rewardFraction; private int _biasWindowLength; @@ -726,7 +726,7 @@ public AntiBiasParameters() _threshold = new BiasThreshold(); _interventionInterval = 10; _maximumWaterCorrections = 5; - _volume = 1; + _rewardFraction = 0.8D; _biasWindowLength = 200; _lickspoutOffsetDelta = 0.05D; } @@ -736,7 +736,7 @@ protected AntiBiasParameters(AntiBiasParameters other) _threshold = other._threshold; _interventionInterval = other._interventionInterval; _maximumWaterCorrections = other._maximumWaterCorrections; - _volume = other._volume; + _rewardFraction = other._rewardFraction; _biasWindowLength = other._biasWindowLength; _lickspoutOffsetDelta = other._lickspoutOffsetDelta; } @@ -794,19 +794,19 @@ public int MaximumWaterCorrections } /// - /// Volume in ul of water given. + /// Fraction of full reward volume delivered during auto water (0=none, 1=full). /// - [Newtonsoft.Json.JsonPropertyAttribute("volume")] - [System.ComponentModel.DescriptionAttribute("Volume in ul of water given.")] - public int Volume + [Newtonsoft.Json.JsonPropertyAttribute("reward_fraction")] + [System.ComponentModel.DescriptionAttribute("Fraction of full reward volume delivered during auto water (0=none, 1=full).")] + public double RewardFraction { get { - return _volume; + return _rewardFraction; } set { - _volume = value; + _rewardFraction = value; } } @@ -860,7 +860,7 @@ protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("Threshold = " + _threshold + ", "); stringBuilder.Append("InterventionInterval = " + _interventionInterval + ", "); stringBuilder.Append("MaximumWaterCorrections = " + _maximumWaterCorrections + ", "); - stringBuilder.Append("Volume = " + _volume + ", "); + stringBuilder.Append("RewardFraction = " + _rewardFraction + ", "); stringBuilder.Append("BiasWindowLength = " + _biasWindowLength + ", "); stringBuilder.Append("LickspoutOffsetDelta = " + _lickspoutOffsetDelta); return true; @@ -1198,12 +1198,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block } /// - /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. + /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] - [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " + - "ignored or unrewarded trial thresholds.")] + [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + + "gnored or unrewarded trial thresholds.")] public AutoWaterParameters AutowaterParameters { get @@ -1686,12 +1686,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block } /// - /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. + /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] - [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " + - "ignored or unrewarded trial thresholds.")] + [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + + "gnored or unrewarded trial thresholds.")] public AutoWaterParameters AutowaterParameters { get @@ -2433,12 +2433,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block } /// - /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. + /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] - [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " + - "ignored or unrewarded trial thresholds.")] + [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + + "gnored or unrewarded trial thresholds.")] public AutoWaterParameters AutowaterParameters { get @@ -2902,12 +2902,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block } /// - /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. + /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] - [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " + - "ignored or unrewarded trial thresholds.")] + [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + + "gnored or unrewarded trial thresholds.")] public AutoWaterParameters AutowaterParameters { get @@ -6971,12 +6971,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.UniformDistributio } /// - /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. + /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] - [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " + - "ignored or unrewarded trial thresholds.")] + [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + + "gnored or unrewarded trial thresholds.")] public AutoWaterParameters AutowaterParameters { get diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py index 85ffdee..285adaf 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py @@ -3,7 +3,7 @@ from pydantic import BaseModel -from ..trial_models import Trial, TrialOutcome, TrialMetrics +from ..trial_models import Trial, TrialMetrics, TrialOutcome class BaseTrialGeneratorSpecModel(BaseModel, abc.ABC): diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index c3e3046..4adc9de 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -14,8 +14,6 @@ from aind_behavior_dynamic_foraging.task_logic.utils.calculate_bias import calculate_bias -from aind_behavior_dynamic_foraging.task_logic.utils.calculate_bias import calculate_bias - from ..trial_models import Metadata, Trial, TrialMetrics from ._base import BaseTrialGeneratorSpecModel, ITrialGenerator, TrialOutcome @@ -40,7 +38,7 @@ class AutoWaterParameters(BaseModel): ge=0, le=1, description="Fraction of full reward volume delivered during auto water (0=none, 1=full).", - ) # TODO: Not implemented yet + ) # TODO: Not implemented yet class BiasThreshold(BaseModel): @@ -59,7 +57,7 @@ class AntiBiasParameters(BaseModel): ge=0, le=1, description="Fraction of full reward volume delivered during auto water (0=none, 1=full).", - ) # TODO: Not implemented yet + ) # TODO: Not implemented yet bias_window_length: int = Field(default=200, ge=0, description="Trials to calculate bias over.") lickspout_offset_delta: float = Field( default=0.05, @@ -190,7 +188,7 @@ def update(self, outcome: TrialOutcome | str): else: # trial ignored so current baiting state retained pass - + self.bias = calculate_bias(outcomes=self.outcome_history) def next(self) -> Trial | None: @@ -254,12 +252,11 @@ def next(self) -> Trial | None: extra=BlockBasedTrialMetadata(is_autowater=is_autowater), ), ) - + def metrics(self) -> TrialMetrics: """Return metrics of session at current state of the trial generator.""" - - return TrialMetrics(bias=self.bias) + return TrialMetrics(bias=self.bias) def _are_autowater_conditions_met(self) -> bool: """Checks whether autowater should be given. @@ -299,7 +296,6 @@ def _are_antibias_conditions_met(self) -> bool: return False if self.trials_in_bias_intervention > self.spec.antibias_parameters.intervention_interval: - if self.bias <= self.spec.antibias_parameters.threshold.lower: logger.debug("Bias calculated below threshold: %s." % self.bias) return True diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py index 9bfe4e4..ee0a234 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py @@ -121,7 +121,10 @@ class TrialOutcome(BaseModel): ) is_rewarded: bool = Field(description="Indicates whether the subject received a reward on this trial.") + class TrialMetrics(BaseModel): """Represents metrics of session""" - bias: Optional[float] = Field(default=None, description="Bias of session. Negative values correspond to left bias, positive right.") \ No newline at end of file + bias: Optional[float] = Field( + default=None, description="Bias of session. Negative values correspond to left bias, positive right." + ) diff --git a/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py b/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py index 3efe33f..e2767f8 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py @@ -7,7 +7,7 @@ from aind_behavior_dynamic_foraging.task_logic.trial_models import TrialOutcome logger = logging.getLogger(__name__) - + def calculate_bias(outcomes: List[TrialOutcome]) -> float: """Estimate the side bias of an animal using logistic regression on recent trial history. @@ -67,9 +67,11 @@ def calculate_bias(outcomes: List[TrialOutcome]) -> float: if min(n_right_choice, n_left_choice) < cv: logger.warning( "Not enough trials per class to fit logistic regression (need %d per class, got %d right, %d left).", - cv, n_right_choice, n_left_choice + cv, + n_right_choice, + n_left_choice, ) - return np.nan + return np.nan logistic_reg = LogisticRegressionCV( solver=solver, diff --git a/tests/trial_generators/test_calculate_bias.py b/tests/trial_generators/test_calculate_bias.py index 2577891..c3d23a4 100644 --- a/tests/trial_generators/test_calculate_bias.py +++ b/tests/trial_generators/test_calculate_bias.py @@ -1,6 +1,7 @@ import math -import unittest import time +import unittest + import numpy as np from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome @@ -68,12 +69,12 @@ def test_too_few_trials_returns_nan(self): bias = calculate_bias(outcomes) self.assertTrue(math.isnan(bias)) + TIME_LIMIT_MS = 170 N_REPEATS = 20 class TestCalculateBiasTiming(unittest.TestCase): - def _assert_within_time_limit(self, n_trials: int): times = [] for _ in range(N_REPEATS): @@ -85,16 +86,15 @@ def _assert_within_time_limit(self, n_trials: int): self.assertLess( mean_ms, TIME_LIMIT_MS, - f"calculate_bias with {n_trials} trials too slow: " - f"{mean_ms:.1f}ms (limit {TIME_LIMIT_MS}ms)" + f"calculate_bias with {n_trials} trials too slow: {mean_ms:.1f}ms (limit {TIME_LIMIT_MS}ms)", ) - + def test_timing_50_trials(self): self._assert_within_time_limit(50) - + def test_timing_100_trials(self): self._assert_within_time_limit(100) - + def test_timing_200_trials(self): self._assert_within_time_limit(200) diff --git a/tests/trial_generators/test_coupled_trial_generator.py b/tests/trial_generators/test_coupled_trial_generator.py index f22c77d..cae55c8 100644 --- a/tests/trial_generators/test_coupled_trial_generator.py +++ b/tests/trial_generators/test_coupled_trial_generator.py @@ -6,7 +6,6 @@ from aind_behavior_dynamic_foraging.task_logic.trial_generators import CoupledTrialGeneratorSpec from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome - from tests.trial_generators.util import simulate_response logging.basicConfig(level=logging.DEBUG) diff --git a/tests/trial_generators/test_uncoupled_trial_generator.py b/tests/trial_generators/test_uncoupled_trial_generator.py index cde2774..d952172 100644 --- a/tests/trial_generators/test_uncoupled_trial_generator.py +++ b/tests/trial_generators/test_uncoupled_trial_generator.py @@ -10,7 +10,6 @@ UncoupledTrialGeneratorSpec, ) from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial - from tests.trial_generators.util import simulate_response logging.basicConfig(level=logging.DEBUG) diff --git a/tests/trial_generators/test_warmup_trial_generator.py b/tests/trial_generators/test_warmup_trial_generator.py index 3f8f5fb..92a89eb 100644 --- a/tests/trial_generators/test_warmup_trial_generator.py +++ b/tests/trial_generators/test_warmup_trial_generator.py @@ -6,7 +6,6 @@ CoupledWarmupTrialGeneratorSpec, ) from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome - from tests.trial_generators.util import simulate_response From 5a98537fc861386d323d6fdb9861f5d4a8faad07 Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Fri, 29 May 2026 09:15:02 -0700 Subject: [PATCH 13/15] integrates metrics into bonsai workflow --- schema/aind_behavior_dynamic_foraging.json | 20 +++++ .../AindBehaviorDynamicForaging.Generated.cs | 75 +++++++++++++++++++ src/Extensions/DeserializeFromPyObject.cs | 2 + src/Extensions/TaskEngine.bonsai | 43 ++++++++++- .../regenerate.py | 1 + .../block_based_trial_generator.py | 2 +- src/main.bonsai | 25 ++++--- 7 files changed, 155 insertions(+), 13 deletions(-) diff --git a/schema/aind_behavior_dynamic_foraging.json b/schema/aind_behavior_dynamic_foraging.json index 10b53c3..38cf055 100644 --- a/schema/aind_behavior_dynamic_foraging.json +++ b/schema/aind_behavior_dynamic_foraging.json @@ -3563,6 +3563,26 @@ } ] }, + "TrialMetrics": { + "description": "Represents metrics of session", + "properties": { + "bias": { + "default": null, + "description": "Bias of session. Negative values correspond to left bias, positive right.", + "oneOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "title": "Bias" + } + }, + "title": "TrialMetrics", + "type": "object" + }, "TrialOutcome": { "description": "Represents the outcome of a single trial.", "properties": { diff --git a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs index 7314e3b..021a37d 100644 --- a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs +++ b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs @@ -6509,6 +6509,75 @@ public override string ToString() } + /// + /// Represents metrics of session + /// + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")] + [System.ComponentModel.DescriptionAttribute("Represents metrics of session")] + [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] + [Bonsai.CombinatorAttribute(MethodName="Generate")] + public partial class TrialMetrics + { + + private double? _bias; + + public TrialMetrics() + { + } + + protected TrialMetrics(TrialMetrics other) + { + _bias = other._bias; + } + + /// + /// Bias of session. Negative values correspond to left bias, positive right. + /// + [Newtonsoft.Json.JsonPropertyAttribute("bias")] + [System.ComponentModel.DescriptionAttribute("Bias of session. Negative values correspond to left bias, positive right.")] + public double? Bias + { + get + { + return _bias; + } + set + { + _bias = value; + } + } + + public System.IObservable Generate() + { + return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new TrialMetrics(this))); + } + + public System.IObservable Generate(System.IObservable source) + { + return System.Reactive.Linq.Observable.Select(source, _ => new TrialMetrics(this)); + } + + protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder) + { + stringBuilder.Append("Bias = " + _bias); + return true; + } + + public override string ToString() + { + System.Text.StringBuilder stringBuilder = new System.Text.StringBuilder(); + stringBuilder.Append(GetType().Name); + stringBuilder.Append(" { "); + if (PrintMembers(stringBuilder)) + { + stringBuilder.Append(" "); + } + stringBuilder.Append("}"); + return stringBuilder.ToString(); + } + } + + /// /// Represents the outcome of a single trial. /// @@ -8421,6 +8490,11 @@ public System.IObservable Process(System.IObservable return Process(source); } + public System.IObservable Process(System.IObservable source) + { + return Process(source); + } + public System.IObservable Process(System.IObservable source) { return Process(source); @@ -8519,6 +8593,7 @@ public System.IObservable Process(System.IObservable source) [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] + [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] diff --git a/src/Extensions/DeserializeFromPyObject.cs b/src/Extensions/DeserializeFromPyObject.cs index 6b19749..558e0a0 100644 --- a/src/Extensions/DeserializeFromPyObject.cs +++ b/src/Extensions/DeserializeFromPyObject.cs @@ -15,6 +15,8 @@ [System.Xml.Serialization.XmlIncludeAttribute(typeof(TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(TypeMapping))] +[System.Xml.Serialization.XmlIncludeAttribute(typeof(TypeMapping))] + public class DeserializeFromPyObject : SingleArgumentExpressionBuilder { public DeserializeFromPyObject() diff --git a/src/Extensions/TaskEngine.bonsai b/src/Extensions/TaskEngine.bonsai index 46fea6c..87c0914 100644 --- a/src/Extensions/TaskEngine.bonsai +++ b/src/Extensions/TaskEngine.bonsai @@ -1358,6 +1358,45 @@ + + thisTrialOutcome + + + Metrics + + + + + + + + trial_generator + + + + + metrics + true + + + + + + + + + + + + + + + + + + + GlobalTrialMetrics + ThisTrial @@ -1423,11 +1462,13 @@ - + + + diff --git a/src/aind_behavior_dynamic_foraging/regenerate.py b/src/aind_behavior_dynamic_foraging/regenerate.py index 4dd1717..53a8389 100644 --- a/src/aind_behavior_dynamic_foraging/regenerate.py +++ b/src/aind_behavior_dynamic_foraging/regenerate.py @@ -20,6 +20,7 @@ def main(): Session, aind_behavior_dynamic_foraging.task_logic.trial_models.Trial, aind_behavior_dynamic_foraging.task_logic.trial_models.TrialOutcome, + aind_behavior_dynamic_foraging.task_logic.trial_models.TrialMetrics, ] model = pydantic.RootModel[Union[tuple(models)]] diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index 4adc9de..84b2b72 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -162,7 +162,7 @@ def __init__(self, spec: BlockBasedTrialGeneratorSpec) -> None: # antibias parameters self.trials_in_bias_intervention = 0 self.water_corrections = 0 - self.bias: float + self.bias: Optional[float] = None self.total_lickspout_offset = 0 def update(self, outcome: TrialOutcome | str): diff --git a/src/main.bonsai b/src/main.bonsai index f64dd49..a5e1dbe 100644 --- a/src/main.bonsai +++ b/src/main.bonsai @@ -150,6 +150,9 @@ GlobalTrial + + GlobalTrialMetrics + TaskLogicParameters @@ -230,21 +233,21 @@ - - - - + + + - - + + - - - - - + + + + + + From 670b6f99ecdbb6a9f33d579046a8106715fbf38b Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Fri, 29 May 2026 10:05:25 -0700 Subject: [PATCH 14/15] moves metrics update --- src/Extensions/TaskEngine.bonsai | 81 +++++++++++++++----------------- 1 file changed, 39 insertions(+), 42 deletions(-) diff --git a/src/Extensions/TaskEngine.bonsai b/src/Extensions/TaskEngine.bonsai index 87c0914..849641b 100644 --- a/src/Extensions/TaskEngine.bonsai +++ b/src/Extensions/TaskEngine.bonsai @@ -1217,6 +1217,42 @@ true + + Metrics + + + + + + + + trial_generator + + + + + metrics + true + + + + + + + + + + + + + + + + + + + GlobalTrialMetrics + SampleNextTrial @@ -1262,6 +1298,8 @@ + + @@ -1358,45 +1396,6 @@ - - thisTrialOutcome - - - Metrics - - - - - - - - trial_generator - - - - - metrics - true - - - - - - - - - - - - - - - - - - - GlobalTrialMetrics - ThisTrial @@ -1462,13 +1461,11 @@ + - - - From 46c0e06dcae2142b5673b120144291e6c2fa62a1 Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Fri, 29 May 2026 12:29:48 -0700 Subject: [PATCH 15/15] adds TrialMetrics software event --- src/Extensions/TaskEngine.bonsai | 12 ++++++++---- .../data_contract/_dataset.py | 7 +++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/Extensions/TaskEngine.bonsai b/src/Extensions/TaskEngine.bonsai index 849641b..4367cfd 100644 --- a/src/Extensions/TaskEngine.bonsai +++ b/src/Extensions/TaskEngine.bonsai @@ -1240,6 +1240,12 @@ + + GlobalTrialMetrics + + + TrialMetrics + @@ -1247,12 +1253,11 @@ + + - - GlobalTrialMetrics - SampleNextTrial @@ -1299,7 +1304,6 @@ - diff --git a/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py b/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py index 3f2e238..36b5435 100644 --- a/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py +++ b/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py @@ -194,6 +194,13 @@ def make_dataset( name="SoftwareEvents", description="Software events generated by the workflow. The timestamps of these events are low precision and should not be used to align to physiology data.", data_streams=[ + SoftwareEvents( + name="TrialMetrics", + description="An event emitted with the metrics of a trial.", + reader_params=SoftwareEvents.make_params( + root_path / "behavior/SoftwareEvents/TrialMetrics.json" + ), + ), SoftwareEvents( name="TrialGeneratorSpec", description="An event emitted with the specification for the trial generator.",