diff --git a/schema/aind_behavior_dynamic_foraging.json b/schema/aind_behavior_dynamic_foraging.json
index 1d73d69..681846e 100644
--- a/schema/aind_behavior_dynamic_foraging.json
+++ b/schema/aind_behavior_dynamic_foraging.json
@@ -2105,6 +2105,54 @@
"title": "Measurement",
"type": "object"
},
+ "Metadata": {
+ "description": "Metadata for trial. These fields will NOT be used by the task engine.",
+ "properties": {
+ "p_reward_left": {
+ "default": null,
+ "description": "Metadata for block probability of reward on the left side if response is made.",
+ "oneOf": [
+ {
+ "maximum": 1,
+ "minimum": 0,
+ "type": "number"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "P Reward Left"
+ },
+ "p_reward_right": {
+ "default": null,
+ "description": "Metadata for the probability of reward on the right side if response is made.",
+ "oneOf": [
+ {
+ "maximum": 1,
+ "minimum": 0,
+ "type": "number"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "P Reward Right"
+ },
+ "extra": {
+ "default": null,
+ "description": "Additional metadata to include with the trial. This field will NOT be used or validated by the task engine.",
+ "oneOf": [
+ {},
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Extra"
+ }
+ },
+ "title": "Metadata",
+ "type": "object"
+ },
"MicrostepResolution": {
"description": "Microstep resolution available",
"enum": [
@@ -3277,16 +3325,17 @@
"title": "Lickspout Offset Delta",
"type": "number"
},
- "extra_metadata": {
+ "metadata": {
"default": null,
- "description": "Additional metadata to include with the trial. This field will NOT be used or validated by the task engine.",
+ "description": "Metadata fields that will not be used by task engine such as block information.",
"oneOf": [
- {},
+ {
+ "$ref": "#/$defs/Metadata"
+ },
{
"type": "null"
}
- ],
- "title": "Extra Metadata"
+ ]
}
},
"title": "Trial",
diff --git a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs
index 7fdde9f..fd4cb0f 100644
--- a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs
+++ b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs
@@ -3894,6 +3894,119 @@ public override string ToString()
}
+ ///
+ /// Metadata for trial. These fields will NOT be used by the task engine.
+ ///
+ [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")]
+ [System.ComponentModel.DescriptionAttribute("Metadata for trial. These fields will NOT be used by the task engine.")]
+ [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)]
+ [Bonsai.CombinatorAttribute(MethodName="Generate")]
+ public partial class Metadata
+ {
+
+ private double? _pRewardLeft;
+
+ private double? _pRewardRight;
+
+ private object _extra;
+
+ public Metadata()
+ {
+ }
+
+ protected Metadata(Metadata other)
+ {
+ _pRewardLeft = other._pRewardLeft;
+ _pRewardRight = other._pRewardRight;
+ _extra = other._extra;
+ }
+
+ ///
+ /// Metadata for block probability of reward on the left side if response is made.
+ ///
+ [Newtonsoft.Json.JsonPropertyAttribute("p_reward_left")]
+ [System.ComponentModel.DescriptionAttribute("Metadata for block probability of reward on the left side if response is made.")]
+ public double? PRewardLeft
+ {
+ get
+ {
+ return _pRewardLeft;
+ }
+ set
+ {
+ _pRewardLeft = value;
+ }
+ }
+
+ ///
+ /// Metadata for the probability of reward on the right side if response is made.
+ ///
+ [Newtonsoft.Json.JsonPropertyAttribute("p_reward_right")]
+ [System.ComponentModel.DescriptionAttribute("Metadata for the probability of reward on the right side if response is made.")]
+ public double? PRewardRight
+ {
+ get
+ {
+ return _pRewardRight;
+ }
+ set
+ {
+ _pRewardRight = value;
+ }
+ }
+
+ ///
+ /// Additional metadata to include with the trial. This field will NOT be used or validated by the task engine.
+ ///
+ [System.Xml.Serialization.XmlIgnoreAttribute()]
+ [Newtonsoft.Json.JsonPropertyAttribute("extra")]
+ [System.ComponentModel.DescriptionAttribute("Additional metadata to include with the trial. This field will NOT be used or val" +
+ "idated by the task engine.")]
+ public object Extra
+ {
+ get
+ {
+ return _extra;
+ }
+ set
+ {
+ _extra = value;
+ }
+ }
+
+ public System.IObservable Generate()
+ {
+ return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new Metadata(this)));
+ }
+
+ public System.IObservable Generate(System.IObservable source)
+ {
+ return System.Reactive.Linq.Observable.Select(source, _ => new Metadata(this));
+ }
+
+ protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder)
+ {
+ stringBuilder.Append("PRewardLeft = " + _pRewardLeft + ", ");
+ stringBuilder.Append("PRewardRight = " + _pRewardRight + ", ");
+ stringBuilder.Append("Extra = " + _extra);
+ return true;
+ }
+
+ public override string ToString()
+ {
+ System.Text.StringBuilder stringBuilder = new System.Text.StringBuilder();
+ stringBuilder.Append(GetType().Name);
+ stringBuilder.Append(" { ");
+ if (PrintMembers(stringBuilder))
+ {
+ stringBuilder.Append(" ");
+ }
+ stringBuilder.Append("}");
+ return stringBuilder.ToString();
+ }
+ }
+
+
///
/// Settings for the quick retract feature.
///
@@ -5639,7 +5752,7 @@ public partial class Trial
private double _lickspoutOffsetDelta;
- private object _extraMetadata;
+ private Metadata _metadata;
public Trial()
{
@@ -5666,7 +5779,7 @@ protected Trial(Trial other)
_interTrialIntervalDuration = other._interTrialIntervalDuration;
_isAutoResponseRight = other._isAutoResponseRight;
_lickspoutOffsetDelta = other._lickspoutOffsetDelta;
- _extraMetadata = other._extraMetadata;
+ _metadata = other._metadata;
}
///
@@ -5863,21 +5976,20 @@ public double LickspoutOffsetDelta
}
///
- /// Additional metadata to include with the trial. This field will NOT be used or validated by the task engine.
+ /// Metadata fields that will not be used by task engine such as block information.
///
[System.Xml.Serialization.XmlIgnoreAttribute()]
- [Newtonsoft.Json.JsonPropertyAttribute("extra_metadata")]
- [System.ComponentModel.DescriptionAttribute("Additional metadata to include with the trial. This field will NOT be used or val" +
- "idated by the task engine.")]
- public object ExtraMetadata
+ [Newtonsoft.Json.JsonPropertyAttribute("metadata")]
+ [System.ComponentModel.DescriptionAttribute("Metadata fields that will not be used by task engine such as block information.")]
+ public Metadata Metadata
{
get
{
- return _extraMetadata;
+ return _metadata;
}
set
{
- _extraMetadata = value;
+ _metadata = value;
}
}
@@ -5904,7 +6016,7 @@ protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder)
stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", ");
stringBuilder.Append("IsAutoResponseRight = " + _isAutoResponseRight + ", ");
stringBuilder.Append("LickspoutOffsetDelta = " + _lickspoutOffsetDelta + ", ");
- stringBuilder.Append("ExtraMetadata = " + _extraMetadata);
+ stringBuilder.Append("Metadata = " + _metadata);
return true;
}
@@ -7853,6 +7965,11 @@ public System.IObservable Process(System.IObservable source
return Process(source);
}
+ public System.IObservable Process(System.IObservable source)
+ {
+ return Process(source);
+ }
+
public System.IObservable Process(System.IObservable source)
{
return Process(source);
@@ -7992,6 +8109,7 @@ public System.IObservable Process(System.IObservable source)
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
+ [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py
index 3729188..39ddc92 100644
--- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py
+++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py
@@ -12,12 +12,18 @@
from aind_behavior_services.task.distributions_utils import draw_sample
from pydantic import BaseModel, Field
-from ..trial_models import Trial
+from ..trial_models import Metadata, Trial
from ._base import BaseTrialGeneratorSpecModel, ITrialGenerator, TrialOutcome
logger = logging.getLogger(__name__)
+class BlockBasedTrialMetadata(BaseModel):
+ """Metadata for block based trial. These fields will NOT be used by the task engine."""
+
+ is_autowater: bool = Field(default=False, description="Flag indicating if autowater is given for trial.")
+
+
class AutoWaterParameters(BaseModel):
min_ignored_trials: int = Field(
default=3, ge=0, description="Minimum consecutive ignored trials before auto water is triggered."
@@ -171,18 +177,23 @@ def next(self) -> Trial | None:
logger.debug("Right baited: %s" % self.is_right_baited)
# determine autowater
- is_right_autowater = None
- if self._are_autowater_conditions_met():
- is_right_autowater = True if self.block.p_right_reward > self.block.p_left_reward else False
+ is_auto_response_right = None
+ if is_autowater := self._are_autowater_conditions_met():
+ is_auto_response_right = True if self.block.p_right_reward > self.block.p_left_reward else False
return Trial(
- p_reward_left=1 if (self.is_left_baited and self.spec.is_baiting) else self.block.p_left_reward,
- p_reward_right=1 if (self.is_right_baited and self.spec.is_baiting) else self.block.p_right_reward,
+ p_reward_left=1 if (self.is_left_baited or is_auto_response_right is False) else self.block.p_left_reward,
+ p_reward_right=1 if (self.is_right_baited or is_auto_response_right) else self.block.p_right_reward,
reward_consumption_duration=self.spec.reward_consumption_duration,
response_deadline_duration=self.spec.response_duration,
quiescence_period_duration=quiescent,
inter_trial_interval_duration=iti,
- is_auto_response_right=is_right_autowater,
+ is_auto_response_right=is_auto_response_right,
+ metadata=Metadata(
+ p_reward_left=self.block.p_left_reward,
+ p_reward_right=self.block.p_right_reward,
+ extra=BlockBasedTrialMetadata(is_autowater=is_autowater),
+ ),
)
def _are_autowater_conditions_met(self) -> bool:
@@ -199,11 +210,11 @@ def _are_autowater_conditions_met(self) -> bool:
min_unreward = self.spec.autowater_parameters.min_unrewarded_trials
is_ignored = [choice is None for choice in self.is_right_choice_history]
- if all(is_ignored[-min_ignore:]):
+ if len(is_ignored) > min_ignore and all(is_ignored[-min_ignore:]):
return True
is_unrewarded = [not reward for reward in self.reward_history]
- if all(is_unrewarded[-min_unreward:]):
+ if len(is_unrewarded) > min_unreward and all(is_unrewarded[-min_unreward:]):
return True
return False
diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py
index adb8d07..7f28f4d 100644
--- a/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py
+++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py
@@ -43,6 +43,27 @@ class QuickRetractSettings(BaseModel):
)
+class Metadata(BaseModel):
+ """Metadata for trial. These fields will NOT be used by the task engine."""
+
+ p_reward_left: Optional[float] = Field(
+ default=None,
+ ge=0,
+ le=1,
+ description="Metadata for block probability of reward on the left side if response is made.",
+ )
+ p_reward_right: Optional[float] = Field(
+ default=None,
+ ge=0,
+ le=1,
+ description="Metadata for the probability of reward on the right side if response is made.",
+ )
+ extra: Optional[SerializeAsAny[Any]] = Field(
+ default=None,
+ description="Additional metadata to include with the trial. This field will NOT be used or validated by the task engine.",
+ )
+
+
class Trial(BaseModel):
"""Represents a single trial that can be instantiated by the Bonsai state machine."""
@@ -84,9 +105,10 @@ class Trial(BaseModel):
default=0.0,
description="Horizontal delta offset of the lickspouts (in mm) applied in this trial. Positive values move the lickspouts right.",
)
- extra_metadata: Optional[SerializeAsAny[Any]] = Field(
+ metadata: Optional[Metadata] = Field(
default=None,
- description="Additional metadata to include with the trial. This field will NOT be used or validated by the task engine.",
+ validate_default=True,
+ description="Metadata fields that will not be used by task engine such as block information.",
)
diff --git a/tests/trial_generators/test_block_based_trial_generator.py b/tests/trial_generators/test_block_based_trial_generator.py
index 60efd2d..661439d 100644
--- a/tests/trial_generators/test_block_based_trial_generator.py
+++ b/tests/trial_generators/test_block_based_trial_generator.py
@@ -48,18 +48,6 @@ def test_next_returns_correct_reward_probs(self):
self.assertEqual(trial.p_reward_left, self.generator.block.p_left_reward)
self.assertEqual(trial.p_reward_right, self.generator.block.p_right_reward)
- #### Test unbaited ####
-
- def test_baiting_disabled_reward_prob_unchanged(self):
- """Without baiting, reward probs should equal block probs exactly."""
- self.generator.block = Block(p_right_reward=0.8, p_left_reward=0.2, right_length=10, left_length=10)
- self.generator.is_left_baited = True
- self.generator.is_right_baited = True
- trial = self.generator.next()
-
- self.assertEqual(trial.p_reward_right, 0.8)
- self.assertEqual(trial.p_reward_left, 0.2)
-
class TestBlockBaseBaitingTrialGenerator(unittest.TestCase):
def setUp(self):
@@ -80,14 +68,14 @@ def test_baiting_sets_prob_to_1_when_baited(self):
def test_baiting_accumulates_when_random_exceeds_prob(self):
"""Bait should carry over when random number exceeds reward prob."""
self.generator.block = Block(p_right_reward=0.5, p_left_reward=0.5, right_length=10, left_length=10)
- self.generator.is_right_baited = False
- self.generator.is_left_baited = False
+ self.generator.is_right_baited = True
+ self.generator.is_left_baited = True
with patch("numpy.random.random", return_value=np.array([0.9, 0.9])):
trial = self.generator.next()
- self.assertEqual(trial.p_reward_right, 0.5)
- self.assertEqual(trial.p_reward_left, 0.5)
+ self.assertEqual(trial.p_reward_right, 1.0)
+ self.assertEqual(trial.p_reward_left, 1.0)
def test_baiting_triggers_when_random_below_prob(self):
"""Bait should trigger reward prob of 1.0 when random number is below reward prob."""