diff --git a/schema/aind_behavior_dynamic_foraging.json b/schema/aind_behavior_dynamic_foraging.json index 1d73d69..681846e 100644 --- a/schema/aind_behavior_dynamic_foraging.json +++ b/schema/aind_behavior_dynamic_foraging.json @@ -2105,6 +2105,54 @@ "title": "Measurement", "type": "object" }, + "Metadata": { + "description": "Metadata for trial. These fields will NOT be used by the task engine.", + "properties": { + "p_reward_left": { + "default": null, + "description": "Metadata for block probability of reward on the left side if response is made.", + "oneOf": [ + { + "maximum": 1, + "minimum": 0, + "type": "number" + }, + { + "type": "null" + } + ], + "title": "P Reward Left" + }, + "p_reward_right": { + "default": null, + "description": "Metadata for the probability of reward on the right side if response is made.", + "oneOf": [ + { + "maximum": 1, + "minimum": 0, + "type": "number" + }, + { + "type": "null" + } + ], + "title": "P Reward Right" + }, + "extra": { + "default": null, + "description": "Additional metadata to include with the trial. This field will NOT be used or validated by the task engine.", + "oneOf": [ + {}, + { + "type": "null" + } + ], + "title": "Extra" + } + }, + "title": "Metadata", + "type": "object" + }, "MicrostepResolution": { "description": "Microstep resolution available", "enum": [ @@ -3277,16 +3325,17 @@ "title": "Lickspout Offset Delta", "type": "number" }, - "extra_metadata": { + "metadata": { "default": null, - "description": "Additional metadata to include with the trial. This field will NOT be used or validated by the task engine.", + "description": "Metadata fields that will not be used by task engine such as block information.", "oneOf": [ - {}, + { + "$ref": "#/$defs/Metadata" + }, { "type": "null" } - ], - "title": "Extra Metadata" + ] } }, "title": "Trial", diff --git a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs index 7fdde9f..fd4cb0f 100644 --- a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs +++ b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs @@ -3894,6 +3894,119 @@ public override string ToString() } + /// + /// Metadata for trial. These fields will NOT be used by the task engine. + /// + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")] + [System.ComponentModel.DescriptionAttribute("Metadata for trial. These fields will NOT be used by the task engine.")] + [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] + [Bonsai.CombinatorAttribute(MethodName="Generate")] + public partial class Metadata + { + + private double? _pRewardLeft; + + private double? _pRewardRight; + + private object _extra; + + public Metadata() + { + } + + protected Metadata(Metadata other) + { + _pRewardLeft = other._pRewardLeft; + _pRewardRight = other._pRewardRight; + _extra = other._extra; + } + + /// + /// Metadata for block probability of reward on the left side if response is made. + /// + [Newtonsoft.Json.JsonPropertyAttribute("p_reward_left")] + [System.ComponentModel.DescriptionAttribute("Metadata for block probability of reward on the left side if response is made.")] + public double? PRewardLeft + { + get + { + return _pRewardLeft; + } + set + { + _pRewardLeft = value; + } + } + + /// + /// Metadata for the probability of reward on the right side if response is made. + /// + [Newtonsoft.Json.JsonPropertyAttribute("p_reward_right")] + [System.ComponentModel.DescriptionAttribute("Metadata for the probability of reward on the right side if response is made.")] + public double? PRewardRight + { + get + { + return _pRewardRight; + } + set + { + _pRewardRight = value; + } + } + + /// + /// Additional metadata to include with the trial. This field will NOT be used or validated by the task engine. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("extra")] + [System.ComponentModel.DescriptionAttribute("Additional metadata to include with the trial. This field will NOT be used or val" + + "idated by the task engine.")] + public object Extra + { + get + { + return _extra; + } + set + { + _extra = value; + } + } + + public System.IObservable Generate() + { + return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new Metadata(this))); + } + + public System.IObservable Generate(System.IObservable source) + { + return System.Reactive.Linq.Observable.Select(source, _ => new Metadata(this)); + } + + protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder) + { + stringBuilder.Append("PRewardLeft = " + _pRewardLeft + ", "); + stringBuilder.Append("PRewardRight = " + _pRewardRight + ", "); + stringBuilder.Append("Extra = " + _extra); + return true; + } + + public override string ToString() + { + System.Text.StringBuilder stringBuilder = new System.Text.StringBuilder(); + stringBuilder.Append(GetType().Name); + stringBuilder.Append(" { "); + if (PrintMembers(stringBuilder)) + { + stringBuilder.Append(" "); + } + stringBuilder.Append("}"); + return stringBuilder.ToString(); + } + } + + /// /// Settings for the quick retract feature. /// @@ -5639,7 +5752,7 @@ public partial class Trial private double _lickspoutOffsetDelta; - private object _extraMetadata; + private Metadata _metadata; public Trial() { @@ -5666,7 +5779,7 @@ protected Trial(Trial other) _interTrialIntervalDuration = other._interTrialIntervalDuration; _isAutoResponseRight = other._isAutoResponseRight; _lickspoutOffsetDelta = other._lickspoutOffsetDelta; - _extraMetadata = other._extraMetadata; + _metadata = other._metadata; } /// @@ -5863,21 +5976,20 @@ public double LickspoutOffsetDelta } /// - /// Additional metadata to include with the trial. This field will NOT be used or validated by the task engine. + /// Metadata fields that will not be used by task engine such as block information. /// [System.Xml.Serialization.XmlIgnoreAttribute()] - [Newtonsoft.Json.JsonPropertyAttribute("extra_metadata")] - [System.ComponentModel.DescriptionAttribute("Additional metadata to include with the trial. This field will NOT be used or val" + - "idated by the task engine.")] - public object ExtraMetadata + [Newtonsoft.Json.JsonPropertyAttribute("metadata")] + [System.ComponentModel.DescriptionAttribute("Metadata fields that will not be used by task engine such as block information.")] + public Metadata Metadata { get { - return _extraMetadata; + return _metadata; } set { - _extraMetadata = value; + _metadata = value; } } @@ -5904,7 +6016,7 @@ protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("IsAutoResponseRight = " + _isAutoResponseRight + ", "); stringBuilder.Append("LickspoutOffsetDelta = " + _lickspoutOffsetDelta + ", "); - stringBuilder.Append("ExtraMetadata = " + _extraMetadata); + stringBuilder.Append("Metadata = " + _metadata); return true; } @@ -7853,6 +7965,11 @@ public System.IObservable Process(System.IObservable source return Process(source); } + public System.IObservable Process(System.IObservable source) + { + return Process(source); + } + public System.IObservable Process(System.IObservable source) { return Process(source); @@ -7992,6 +8109,7 @@ public System.IObservable Process(System.IObservable source) [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] + [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index 3729188..39ddc92 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -12,12 +12,18 @@ from aind_behavior_services.task.distributions_utils import draw_sample from pydantic import BaseModel, Field -from ..trial_models import Trial +from ..trial_models import Metadata, Trial from ._base import BaseTrialGeneratorSpecModel, ITrialGenerator, TrialOutcome logger = logging.getLogger(__name__) +class BlockBasedTrialMetadata(BaseModel): + """Metadata for block based trial. These fields will NOT be used by the task engine.""" + + is_autowater: bool = Field(default=False, description="Flag indicating if autowater is given for trial.") + + class AutoWaterParameters(BaseModel): min_ignored_trials: int = Field( default=3, ge=0, description="Minimum consecutive ignored trials before auto water is triggered." @@ -171,18 +177,23 @@ def next(self) -> Trial | None: logger.debug("Right baited: %s" % self.is_right_baited) # determine autowater - is_right_autowater = None - if self._are_autowater_conditions_met(): - is_right_autowater = True if self.block.p_right_reward > self.block.p_left_reward else False + is_auto_response_right = None + if is_autowater := self._are_autowater_conditions_met(): + is_auto_response_right = True if self.block.p_right_reward > self.block.p_left_reward else False return Trial( - p_reward_left=1 if (self.is_left_baited and self.spec.is_baiting) else self.block.p_left_reward, - p_reward_right=1 if (self.is_right_baited and self.spec.is_baiting) else self.block.p_right_reward, + p_reward_left=1 if (self.is_left_baited or is_auto_response_right is False) else self.block.p_left_reward, + p_reward_right=1 if (self.is_right_baited or is_auto_response_right) else self.block.p_right_reward, reward_consumption_duration=self.spec.reward_consumption_duration, response_deadline_duration=self.spec.response_duration, quiescence_period_duration=quiescent, inter_trial_interval_duration=iti, - is_auto_response_right=is_right_autowater, + is_auto_response_right=is_auto_response_right, + metadata=Metadata( + p_reward_left=self.block.p_left_reward, + p_reward_right=self.block.p_right_reward, + extra=BlockBasedTrialMetadata(is_autowater=is_autowater), + ), ) def _are_autowater_conditions_met(self) -> bool: @@ -199,11 +210,11 @@ def _are_autowater_conditions_met(self) -> bool: min_unreward = self.spec.autowater_parameters.min_unrewarded_trials is_ignored = [choice is None for choice in self.is_right_choice_history] - if all(is_ignored[-min_ignore:]): + if len(is_ignored) > min_ignore and all(is_ignored[-min_ignore:]): return True is_unrewarded = [not reward for reward in self.reward_history] - if all(is_unrewarded[-min_unreward:]): + if len(is_unrewarded) > min_unreward and all(is_unrewarded[-min_unreward:]): return True return False diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py index adb8d07..7f28f4d 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py @@ -43,6 +43,27 @@ class QuickRetractSettings(BaseModel): ) +class Metadata(BaseModel): + """Metadata for trial. These fields will NOT be used by the task engine.""" + + p_reward_left: Optional[float] = Field( + default=None, + ge=0, + le=1, + description="Metadata for block probability of reward on the left side if response is made.", + ) + p_reward_right: Optional[float] = Field( + default=None, + ge=0, + le=1, + description="Metadata for the probability of reward on the right side if response is made.", + ) + extra: Optional[SerializeAsAny[Any]] = Field( + default=None, + description="Additional metadata to include with the trial. This field will NOT be used or validated by the task engine.", + ) + + class Trial(BaseModel): """Represents a single trial that can be instantiated by the Bonsai state machine.""" @@ -84,9 +105,10 @@ class Trial(BaseModel): default=0.0, description="Horizontal delta offset of the lickspouts (in mm) applied in this trial. Positive values move the lickspouts right.", ) - extra_metadata: Optional[SerializeAsAny[Any]] = Field( + metadata: Optional[Metadata] = Field( default=None, - description="Additional metadata to include with the trial. This field will NOT be used or validated by the task engine.", + validate_default=True, + description="Metadata fields that will not be used by task engine such as block information.", ) diff --git a/tests/trial_generators/test_block_based_trial_generator.py b/tests/trial_generators/test_block_based_trial_generator.py index 60efd2d..661439d 100644 --- a/tests/trial_generators/test_block_based_trial_generator.py +++ b/tests/trial_generators/test_block_based_trial_generator.py @@ -48,18 +48,6 @@ def test_next_returns_correct_reward_probs(self): self.assertEqual(trial.p_reward_left, self.generator.block.p_left_reward) self.assertEqual(trial.p_reward_right, self.generator.block.p_right_reward) - #### Test unbaited #### - - def test_baiting_disabled_reward_prob_unchanged(self): - """Without baiting, reward probs should equal block probs exactly.""" - self.generator.block = Block(p_right_reward=0.8, p_left_reward=0.2, right_length=10, left_length=10) - self.generator.is_left_baited = True - self.generator.is_right_baited = True - trial = self.generator.next() - - self.assertEqual(trial.p_reward_right, 0.8) - self.assertEqual(trial.p_reward_left, 0.2) - class TestBlockBaseBaitingTrialGenerator(unittest.TestCase): def setUp(self): @@ -80,14 +68,14 @@ def test_baiting_sets_prob_to_1_when_baited(self): def test_baiting_accumulates_when_random_exceeds_prob(self): """Bait should carry over when random number exceeds reward prob.""" self.generator.block = Block(p_right_reward=0.5, p_left_reward=0.5, right_length=10, left_length=10) - self.generator.is_right_baited = False - self.generator.is_left_baited = False + self.generator.is_right_baited = True + self.generator.is_left_baited = True with patch("numpy.random.random", return_value=np.array([0.9, 0.9])): trial = self.generator.next() - self.assertEqual(trial.p_reward_right, 0.5) - self.assertEqual(trial.p_reward_left, 0.5) + self.assertEqual(trial.p_reward_right, 1.0) + self.assertEqual(trial.p_reward_left, 1.0) def test_baiting_triggers_when_random_below_prob(self): """Bait should trigger reward prob of 1.0 when random number is below reward prob."""